"""
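Driver dataset loader. The data is divided into only two splits, a training set and a test set, listed in `train.csv` and `test.csv` under the dataset root (there is no separate validation split).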
"""

import csv
import os

import numpy as np
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms


class Driver(Dataset):
    """Image-classification dataset whose samples are listed in a CSV file.

    The split is selected by ``test``: ``test=True`` reads ``<root>/test.csv``,
    otherwise ``<root>/train.csv`` is read. The first CSV row is a header and is
    skipped; every following non-blank row holds an image path (relative to
    ``root``) in column 0 and an integer class label in column 1.

    Args:
        root: Dataset root directory; the CSV files and image paths are
            resolved against it.
        transform: Callable applied to each PIL image. When ``None`` a default
            pipeline is used: resize + random resized crop + horizontal flip
            for the training split, resize + center crop for the test split,
            both followed by ``ToTensor`` and ``Normalize``.
        target_transform: Optional callable applied to each integer label.
        train: Stored as ``self.train``; it does not affect which CSV is read.
        test: If true, load the test split instead of the training split.

    Raises:
        ValueError: If a non-blank CSV row has fewer than two fields.
    """

    # Standard ImageNet per-channel mean/std (as used in torchvision examples),
    # applied by the default Normalize transform.
    _MEAN = [0.485, 0.456, 0.406]
    _STD = [0.229, 0.224, 0.225]

    def __init__(self, root, transform=None, target_transform=None, train=True, test=False):
        self.root = root
        self.transform = transform
        self.target_transform = target_transform
        self.train = train
        self.test = test

        csv_name = 'test.csv' if self.test else 'train.csv'
        rows = self._read_rows(os.path.join(self.root, csv_name))
        self.imgs = [os.path.join(self.root, path) for path, _ in rows]
        self.target = [int(label) for _, label in rows]

        if transform is None:
            self.transform = self._default_transform(self.test)

    @staticmethod
    def _read_rows(csv_path):
        """Return ``(path, label)`` string pairs from *csv_path*, skipping the header and blank rows."""
        rows = []
        with open(csv_path, 'r', newline='') as f:
            reader = csv.reader(f)
            next(reader, None)  # skip the header row (no-op for an empty file)
            for row in reader:
                if not any(field.strip() for field in row):
                    continue  # tolerate blank lines, e.g. a trailing newline
                if len(row) < 2:
                    raise ValueError(
                        f'{csv_path}:{reader.line_num}: expected "path,label", got {row!r}')
                rows.append((row[0].strip(), row[1].strip()))
        return rows

    @classmethod
    def _default_transform(cls, test):
        """Build the default image pipeline: deterministic for test, augmented for train."""
        normalize = transforms.Normalize(mean=cls._MEAN, std=cls._STD)
        if test:
            return transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(), normalize
            ])
        return transforms.Compose([
            transforms.Resize(256),
            transforms.RandomResizedCrop(224, scale=(0.25, 1)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(), normalize
        ])

    def __getitem__(self, index):
        """Return ``(image, target)`` for *index*, with both transforms applied when set."""
        img_path = self.imgs[index]
        target = self.target[index]
        # Image.open is lazy and would otherwise keep the file handle open;
        # reading inside a context manager releases it immediately. convert('RGB')
        # forces 3 channels so grayscale/RGBA/palette images match the
        # 3-channel Normalize used by the default pipeline.
        with open(img_path, 'rb') as f:
            img = Image.open(f).convert('RGB')

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.imgs)


if __name__ == '__main__':
    import sys

    # Smoke test: the dataset root may be passed as the first CLI argument;
    # otherwise fall back to the original hard-coded location.
    root = sys.argv[1] if len(sys.argv) > 1 else '/home/husencd/Downloads/dataset/driver'

    train_set = Driver(root, train=True)
    print(train_set[1])
    print(len(train_set))  # expected: 12977
    test_set = Driver(root, train=False, test=True)
    print(len(test_set))  # expected: 4331