import collections
import heapq
import itertools
import pprint

import bitstring # needs to be installed via pip


# read entire file as string into 'input'
# (an explicit encoding makes decoding independent of the platform's locale
# default, e.g. cp1252 on Windows, which could garble or reject non-ASCII text)
with open("input.txt", encoding="utf-8") as f:
    input = f.read()

# save for later: measure the UTF-8 encoded size so the value really is in
# bytes, as printed and as compared against the compressed size further down
# (len() of the str counts characters, which differs for non-ASCII text)
input_len = len(input.encode("utf-8"))
print(f"Input text: {input_len} bytes")

# TEXT_PROCESSING_AND_TOKENIZATION
# at this point the 'input' variable holds the input
# text (the Bee Movie script)

input = input.lower() # turn everything lowercase

# add spaces before and after punctuation and newlines
# so that the .split(" ") below separates them from
# the words they "belong" to, making them individual tokens
#
# before: "yellow, black." -> ["yellow,", "black."]
# after:  "yellow ,  black . " -> ["yellow", ",", "black", "."]
for separator in (",", ".", "!", "?", "\n"):
    input = input.replace(separator, f" {separator} ")

# splitting on single spaces yields empty strings wherever spaces were
# adjacent (e.g. " ,  " above); drop them in one linear pass. Calling
# list.remove("") until it raises ValueError would rescan the list from the
# start for every empty token, which is quadratic in the input size.
input_tokens = [token for token in input.split(" ") if token]
# TEXT_PROCESSING_AND_TOKENIZATION_END

print(input_tokens[:100])

# TREE_BUILDING
# collections.Counter (among other things) acts as
# a dictionary of {value: count}, where value
# is each unique element of the list we feed in
# (in our case tokens) and the count is the number
# of times it occurs in the list
input_counted = collections.Counter(input_tokens)


def _build_huffman_tree(counts):
    """Build a Huffman tree from a {token: count} mapping and return its stem.

    Leaves are {'token': ..., 'count': ...} dicts; inner nodes are
    {'one': ..., 'zero': ..., 'count': ...} dicts whose count is the sum of
    their two children's counts.

    Raises ValueError if 'counts' is empty, since there is nothing to encode.
    """
    if not counts:
        raise ValueError("no tokens to build a Huffman tree from")

    # A min-heap ordered by count hands out the two nodes with the fewest
    # count in O(log n), instead of re-sorting the whole list on every merge.
    # The sequence number breaks ties between equal counts so heapq never
    # falls back to comparing the node dicts (dicts are not orderable).
    sequence = itertools.count()
    # create a dictionary for each unique token that we have
    # this acts as a leaf in our tree
    heap = [(count, next(sequence), {'token': token, 'count': count})
            for token, count in counts.items()]
    heapq.heapify(heap)

    # until only the stem is left, grab the two nodes/leafs
    # with the fewest count and group them together in a node
    while len(heap) > 1:
        _, _, zero = heapq.heappop(heap)
        _, _, one = heapq.heappop(heap)
        # the count of a node is the sum of its two children
        node = {'one': one, 'zero': zero,
                'count': one['count'] + zero['count']}
        heapq.heappush(heap, (node['count'], next(sequence), node))

    stem = heap[0][2]

    # With only one distinct token the stem would itself be a leaf, giving
    # that token a zero-length code: the compressed data would be empty and
    # decompression could not recover any tokens. Hang the leaf under a real
    # stem on both branches so it gets a 1-bit code whichever branch is used.
    if 'token' in stem:
        stem = {'one': stem, 'zero': stem, 'count': stem['count']}
    return stem


# 'tree' is the stem itself, so we don't need to
# specify a redundant [0] every time
tree = _build_huffman_tree(input_counted)
# TREE_BUILDING_END

# DICTIONARY_BUILDING
# maps every token to the bits (its code) that stand for it
translation = {}

def add_to_translation(bits, node):
    """Store the code of every leaf below 'node' in 'translation'.

    'bits' is the BitArray describing the path from the stem to 'node'.
    Going down the 'one' branch appends a 1 bit and the 'zero' branch a 0
    bit; on reaching a leaf, its token is stored with the collected bits.
    """
    # walk the tree with an explicit stack of (path bits, node) pairs;
    # 'zero' is pushed before 'one' so the 'one' subtree is handled first
    pending = [(bits, node)]
    while pending:
        path, current = pending.pop()
        if 'token' not in current:
            pending.append((path + "0b0", current['zero']))
            pending.append((path + "0b1", current['one']))
            continue
        # a leaf: the path that led here is the token's code
        translation[current['token']] = path

# start at the stem with an empty BitArray
add_to_translation(bitstring.BitArray(), tree)
# DICTIONARY_BUILDING_END

# COMPRESSING
# the compressed data is the codes of all tokens, in input order,
# concatenated into one BitArray (joined with an empty separator)
compressed = bitstring.BitArray().join(
    translation[token] for token in input_tokens
)
# COMPRESSING_END

# Approximate the tree size as the characters of every token plus the bits
# of every code (in bytes); this is only a rough estimate of the real cost.
token_chars = sum(len(word) for word in translation)
code_bits = sum(len(code) for code in translation.values())
tree_len = token_chars + code_bits / 8
compressed_len = len(compressed) / 8

print(f"Compressed text (without tree): {compressed_len}")
print(f"Tree estimate: {tree_len}")
total_len = compressed_len + tree_len
compression_ratio = total_len / input_len
print(f"Compression ratio (lower is better): {round(compression_ratio * 100, 2)}%")









# DECOMPRESSING
# wrap the compressed data in a stream so its bits can be read one by one
compressed_stream = bitstring.ConstBitStream(compressed)

# tokens recovered so far
decompressed_tokens = []
# current node in the tree; every code starts at the stem
pos = tree

while compressed_stream.pos < compressed_stream.length:
    # follow one edge: a 1 bit leads to the 'one' child,
    # a 0 bit to the 'zero' child
    pos = pos['one' if compressed_stream.read('bin:1') == "1" else 'zero']

    if 'token' not in pos:
        # still on an inner node, the current code isn't complete
        continue

    # reached a leaf: emit its token and restart at the stem
    decompressed_tokens.append(pos['token'])
    pos = tree
# DECOMPRESSING_END

# join our tokens together to text
decompressed = " ".join(decompressed_tokens)

# undo the text processing we did at the start of the script:
# remove the separator spaces the join put around newlines and
# in front of punctuation.
#
# The space on each side of a newline is removed separately. Replacing
# " \n " as a whole misses spaces between consecutive newlines and at the
# very start/end of the text, e.g. ["a", "\n", "\n", "b"] joins to
# "a \n \n b", which " \n " -> "\n" only turns into "a\n\n b".
decompressed = decompressed.replace(" \n", "\n").replace("\n ", "\n")
for mark in ("?", "!", ".", ","):
    decompressed = decompressed.replace(" " + mark, mark)

# print the start of our decompressed data
print("\n" + decompressed[:500])
