import torch

from slender.prune.vanilla import prune_vanilla_elementwise
from slender.quantize.linear import quantize_linear_fix_zeros
from slender.quantize.fixed_point import quantize_fixed_point
from slender.quantize.quantizer import Quantizer
from slender.coding.encode import EncodedParam
from slender.coding.codec import Codec


def test_encode_param():
    """Round-trip a pruned, quantized tensor through each EncodedParam method.

    Covers the 'huffman' method (including a state_dict save/load into a
    fresh ``EncodedParam``), the 'vanilla' method, and the 'fixed_point'
    method after re-quantizing the tensor to a 4-bit fixed-point grid.
    Each encoded object's ``.data`` must reproduce the source tensor exactly.
    """
    # Seed so that a failure reproduces with the same input tensor.
    torch.manual_seed(0)
    param = torch.rand(256, 128, 3, 3)
    # Return values are unused; these presumably prune/quantize ``param``
    # in place (the comparisons below rely on that).
    prune_vanilla_elementwise(sparsity=0.7, param=param)
    quantize_linear_fix_zeros(param, k=16)

    def assert_matches(decoded):
        # Check the shape explicitly: torch.eq broadcasts, so a decode with
        # a wrong-but-broadcastable shape could otherwise compare equal.
        assert decoded.shape == param.shape
        assert torch.eq(param, decoded).all()

    huffman = EncodedParam(param=param, method='huffman',
                           encode_indices=True, bit_length_zero_run_length=4)
    print(huffman.stats)
    assert_matches(huffman.data)

    # The encoding must survive a state_dict save/load into a fresh object.
    state_dict = huffman.state_dict()
    restored = EncodedParam()
    restored.load_state_dict(state_dict)
    assert_matches(restored.data)

    vanilla = EncodedParam(param=param, method='vanilla',
                           encode_indices=True, bit_length_zero_run_length=4)
    print(vanilla.stats)
    assert_matches(vanilla.data)

    # Re-quantize with the same bit_length / bit_length_integer that are
    # passed to the 'fixed_point' encoder below.
    quantize_fixed_point(param=param, bit_length=4, bit_length_integer=0)
    fixed_point = EncodedParam(param=param, method='fixed_point',
                               bit_length=4, bit_length_integer=0,
                               encode_indices=True, bit_length_zero_run_length=4)
    print(fixed_point.stats)
    assert_matches(fixed_point.data)


def test_codec():
    """Encode a pruned, quantized model with Codec and decode it into a new model.

    Layer '0.weight' is k-means quantized and Huffman-coded; '1.weight' is
    fixed-point quantized and fixed-point coded. After decoding the encoded
    state_dict into a freshly constructed model of the same architecture,
    every weight (dim > 1) must match the original model exactly.
    """
    # Seed torch's RNG so weight init (and any torch-side randomness) is
    # reproducible when this test fails.
    torch.manual_seed(0)
    quantize_rule = [
        ('0.weight', 'k-means', 4, 'k-means++'),
        ('1.weight', 'fixed_point', 6, 1),
    ]
    model = torch.nn.Sequential(torch.nn.Conv2d(256, 128, 3, bias=True),
                                torch.nn.Conv2d(128, 512, 1, bias=False))
    # Returned masks are not needed by this test, so they are discarded.
    for p in model.parameters():
        prune_vanilla_elementwise(sparsity=0.6, param=p.data)
    quantizer = Quantizer(rule=quantize_rule, fix_zeros=True)
    quantizer.quantize(model, update_labels=False, verbose=True)
    rule = [
        ('0.weight', 'huffman', 0, 0, 4),
        ('1.weight', 'fixed_point', 6, 1, 4)
    ]
    codec = Codec(rule=rule)
    encoded_module = codec.encode(model)
    print(codec.stats)
    state_dict = encoded_module.state_dict()
    model_2 = torch.nn.Sequential(torch.nn.Conv2d(256, 128, 3, bias=True),
                                  torch.nn.Conv2d(128, 512, 1, bias=False))
    model_2 = Codec.decode(model_2, state_dict)
    # Only weights appear in the codec rule; the 1-D bias is not encoded, so
    # model_2's bias presumably keeps its own init and is skipped here.
    for (name, p1), (_, p2) in zip(model.named_parameters(),
                                   model_2.named_parameters()):
        if p1.dim() > 1:
            # Shape check first: torch.eq broadcasts and could mask a
            # mis-shaped decode.
            assert p1.shape == p2.shape, name
            assert torch.eq(p1, p2).all(), name