import os
import random
from typing import List

import numpy as np
import torch
from nobos_commons.data_structures.dimension import ImageSize
from nobos_torch_lib.configs.training_configs.training_config_base import TrainingConfigBase
from nobos_torch_lib.datasets.action_recognition_datasets.ehpi_dataset import ScaleEhpi, TranslateEhpi, \
    FlipEhpi, NormalizeEhpi, \
    RemoveJointsOutsideImgEhpi
from torch.utils.data import ConcatDataset, DataLoader
from torchvision.transforms import transforms

from ehpi_action_recognition.config import models_dir, ehpi_dataset_path
from ehpi_action_recognition.paper_reproduction_code.datasets.ehpi_lstm_dataset import EhpiLSTMDataset
from ehpi_action_recognition.paper_reproduction_code.models.ehpi_lstm import EhpiLSTM
from ehpi_action_recognition.trainer_ehpi import TrainerEhpi

# 15-joint skeleton used by all datasets in this script. The left/right index
# lists tell FlipEhpi which joints to swap when an example is mirrored.
NUM_JOINTS = 15
LEFT_JOINT_INDEXES: List[int] = [3, 4, 5, 9, 10, 11]
RIGHT_JOINT_INDEXES: List[int] = [6, 7, 8, 12, 13, 14]


def _make_dataset(dataset_path: str, dataset_dir_name: str, image_size: ImageSize) -> EhpiLSTMDataset:
    """Build one EhpiLSTMDataset with the standard augmentation pipeline.

    The pipeline (remove out-of-image joints, scale, translate, flip,
    normalize) is identical for every dataset in this script, so it is
    defined once here instead of being repeated per training-set factory.

    Args:
        dataset_path: Root directory that contains the dataset folders.
        dataset_dir_name: Name of the dataset folder under dataset_path.
        image_size: Source image dimensions used by the transforms.
    """
    return EhpiLSTMDataset(
        os.path.join(dataset_path, dataset_dir_name),
        transform=transforms.Compose([
            RemoveJointsOutsideImgEhpi(image_size),
            ScaleEhpi(image_size),
            TranslateEhpi(image_size),
            FlipEhpi(left_indexes=LEFT_JOINT_INDEXES, right_indexes=RIGHT_JOINT_INDEXES),
            NormalizeEhpi(image_size)
        ]),
        num_joints=NUM_JOINTS)


def _concat_with_statistics(datasets: List[EhpiLSTMDataset]) -> ConcatDataset:
    """Print per-dataset label statistics, then concatenate the datasets."""
    for dataset in datasets:
        dataset.print_label_statistics()
    return ConcatDataset(datasets)


def get_training_set_gt(dataset_path: str, image_size: ImageSize):
    """Return the training set built from ground-truth (GT) poses only."""
    return _concat_with_statistics([
        _make_dataset(dataset_path, "JOURNAL_2019_03_GT_30fps", image_size),
    ])


def get_training_posealgo(dataset_path: str, image_size: ImageSize):
    """Return the training set built from pose-algorithm detections only."""
    return _concat_with_statistics([
        _make_dataset(dataset_path, "JOURNAL_2019_03_POSEALGO_30fps", image_size),
    ])


def get_training_set_both(dataset_path: str, image_size: ImageSize):
    """Return the combined training set (pose-algorithm + ground-truth)."""
    return _concat_with_statistics([
        _make_dataset(dataset_path, "JOURNAL_2019_03_POSEALGO_30fps", image_size),
        _make_dataset(dataset_path, "JOURNAL_2019_03_GT_30fps", image_size),
    ])


def set_seed(seed):
    """Seed every RNG involved in training (python, torch CPU/CUDA, numpy).

    Args:
        seed: Seed value applied to all random number generators.
    """
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # BUG FIX: the original called np.random.seed(0), ignoring `seed`, so
    # numpy-driven randomness (e.g. augmentation draws) was identical for
    # every run. Use the given seed like the other RNGs.
    # NOTE(review): this changes exact reproduction of runs trained with
    # the old constant-0 numpy seed.
    np.random.seed(seed)


if __name__ == '__main__':
    batch_size = 256
    seeds = [0, 104, 123, 142, 200]
    # Each entry maps a run-name suffix to the factory for its training set.
    datasets = {
        "gt": get_training_set_gt,
        "pose": get_training_posealgo,
        "both": get_training_set_both
    }
    # Train every dataset variant once per seed (5 seeds x 3 variants).
    for seed in seeds:
        for dataset_name, get_dataset in datasets.items():
            set_seed(seed)
            train_set = get_dataset(ehpi_dataset_path, ImageSize(1280, 720))
            train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=1)

            # config
            train_config = TrainingConfigBase("ehpi_journal_2019_03_{}_seed_{}".format(dataset_name, seed),
                                              os.path.join(models_dir, "train_its_journal"))
            train_config.weight_decay = 0
            train_config.num_epochs = 200

            trainer = TrainerEhpi()
            # EhpiLSTM(15, 5): 15 joints in, 5 action classes out.
            trainer.train(train_loader, train_config, model=EhpiLSTM(15, 5))