import pickle

import numpy as np
import torch
from torch.utils.data import TensorDataset
from .folded_dataset import FoldedDataset

from options import opt


def _msda_split_to_tensors(x, y, feature_num):
    """Convert one (features, labels) split into float/long tensors on ``opt.device``.

    Args:
        x: Feature matrix. It must support 2-D slicing and ``.toarray()``
            (presumably a scipy sparse matrix -- TODO confirm against the
            pickle producer).
        y: Label array. It must be a numpy ndarray, as ``torch.from_numpy``
            requires.
        feature_num: If > 0, keep only the first ``feature_num`` feature
            columns; otherwise keep all of them.

    Returns:
        Tuple ``(x, y)`` of a dense float32 feature tensor and an int64 label
        tensor, both moved to ``opt.device``.
    """
    if feature_num > 0:
        # Truncate before densifying so that only the kept columns are materialized.
        x = x[:, :feature_num]
    x = torch.from_numpy(x.toarray()).float().to(opt.device)
    y = torch.from_numpy(y).long().to(opt.device)
    return x, y


def get_msda_amazon_datasets(data_file, domain, kfold, feature_num):
    """Load the mSDA-preprocessed Amazon data for one domain.

    Args:
        data_file: Path to a pickle whose top-level object is indexed by
            domain name. Each domain entry provides ``'labeled'`` and
            ``'unlabeled'`` ``(features, labels)`` pairs.
        domain: Key selecting the domain inside the pickle.
        kfold: Number of folds, forwarded to ``FoldedDataset``.
        feature_num: If > 0, keep only the first ``feature_num`` feature
            columns; otherwise keep all of them.

    Returns:
        Tuple ``(labeled_set, unlabeled_set)``:
        - ``labeled_set`` is a ``FoldedDataset`` wrapping a ``TensorDataset``.
        - ``unlabeled_set`` is a plain ``TensorDataset``.
    """
    print(f'Loading mSDA Preprocessed Multi-Domain Amazon data for {domain} Domain')
    # NOTE(security): unpickling can execute arbitrary code. Only load trusted data files.
    # Use a context manager so the file handle is closed (the original leaked it).
    with open(data_file, 'rb') as f:
        dataset = pickle.load(f)[domain]

    lx, ly = _msda_split_to_tensors(*dataset['labeled'], feature_num)
    print(f'{domain} Domain has {len(ly)} labeled instances.')
    labeled_set = FoldedDataset(TensorDataset, kfold, lx, ly)

    ux, uy = _msda_split_to_tensors(*dataset['unlabeled'], feature_num)
    print(f'{domain} Domain has {len(uy)} unlabeled instances.')
    unlabeled_set = TensorDataset(ux, uy)

    return labeled_set, unlabeled_set