from typing import Callable, Optional, Tuple

import numpy as np
import torch
from torch import FloatTensor
from torch.utils.data.dataset import Dataset
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, ToTensor, Normalize

from .settings import DATA_ROOT

# Default preprocessing pipeline: convert the image to a tensor, then
# normalize its single channel with the given mean and standard deviation
# (presumably the MNIST training-set statistics -- verify before changing).
MNIST_TRANSFORM = Compose(
    [
        ToTensor(),
        Normalize(mean=(0.1307,), std=(0.3081,)),
    ]
)


class MNISTSummation(Dataset):
    """Dataset of randomly sized groups of MNIST digits labelled with their sum.

    Each sample is a stack of between ``min_len`` and ``max_len`` MNIST images
    (drawn with replacement) together with the sum of their labels. The groups
    of MNIST indices are drawn once, at construction time, so a given position
    always refers to the same underlying digits.

    Args:
        min_len: Smallest number of digits in a sample; must be at least 1.
        max_len: Largest number of digits in a sample (inclusive); must not be
            smaller than ``min_len``.
        dataset_len: Number of samples in the dataset; must be non-negative.
        train: Forwarded to ``MNIST`` as its ``train`` argument.
        transform: Forwarded to ``MNIST`` as its ``transform`` argument.
        seed: Optional seed for the sampling of groups. With ``None`` (the
            default) the draws come from NumPy's global random state; with an
            integer a private ``RandomState`` is used, which makes the dataset
            reproducible without touching global state.

    Raises:
        ValueError: If ``min_len < 1``, ``max_len < min_len`` or
            ``dataset_len < 0``.
    """

    def __init__(
        self,
        min_len: int,
        max_len: int,
        dataset_len: int,
        train: bool = True,
        transform: Optional[Callable] = None,
        seed: Optional[int] = None,
    ):
        # Validate before MNIST is touched so bad arguments fail fast instead
        # of surfacing later as an empty ``torch.stack`` or an opaque NumPy error.
        if min_len < 1:
            raise ValueError(f"min_len must be at least 1, got {min_len}")
        if max_len < min_len:
            raise ValueError(f"max_len ({max_len}) must not be smaller than min_len ({min_len})")
        if dataset_len < 0:
            raise ValueError(f"dataset_len must be non-negative, got {dataset_len}")

        self.min_len = min_len
        self.max_len = max_len
        self.dataset_len = dataset_len
        self.train = train
        self.transform = transform
        self.seed = seed

        self.mnist = MNIST(DATA_ROOT, train=self.train, transform=self.transform, download=True)

        # ``np.random`` and ``RandomState`` expose the same ``choice`` API, so
        # the unseeded path keeps drawing from the global state as before.
        rng = np.random if seed is None else np.random.RandomState(seed)

        mnist_items_range = np.arange(len(self.mnist))
        items_len_range = np.arange(self.min_len, self.max_len + 1)

        # First draw the size of every group, then the MNIST indices of each
        # group (with replacement, so a digit may repeat within a sample).
        items_len = rng.choice(items_len_range, size=self.dataset_len, replace=True)
        self.mnist_items = [
            rng.choice(mnist_items_range, size=group_len, replace=True) for group_len in items_len
        ]

    def __len__(self) -> int:
        """Return the number of samples in the dataset."""
        return self.dataset_len

    def __getitem__(self, item: int) -> Tuple[FloatTensor, FloatTensor]:
        """Return the image stack and label sum of the sample at ``item``.

        Returns:
            A pair ``(images, total)``: ``images`` stacks the sample's images
            along a new leading dimension, and ``total`` is a one-element
            ``FloatTensor`` holding the sum of their labels.
        """
        samples = [self.mnist[int(index)] for index in self.mnist_items[item]]
        images = [image for image, _ in samples]
        the_sum = sum(target for _, target in samples)
        return torch.stack(images, dim=0), torch.FloatTensor([the_sum])