#!/usr/bin/env python
# -*- coding: utf-8 -*-
import time
import math
import numpy as np
import torch
import torch.nn as nn
from torch.nn import init
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.autograd import Variable
import random
import os

from cnn_model import CGP2CNN
from my_data_loader import get_train_valid_loader, get_test_loader


def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)


def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find('Conv2d') != -1:
        m.apply(weights_init_normal_)
    elif classname.find('Linear') != -1:
        init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm2d') != -1:
        init.normal_(m.weight.data, 1.0, 0.02)
        init.constant_(m.bias.data, 0.0)


def weights_init_normal_(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('Linear') != -1:
        init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm2d') != -1:
        init.normal_(m.weight.data, 1.0, 0.02)
        init.constant_(m.bias.data, 0.0)


def weights_init_xavier(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        init.xavier_normal_(m.weight.data, gain=1)
    elif classname.find('Linear') != -1:
        init.xavier_normal_(m.weight.data, gain=1)
    elif classname.find('BatchNorm2d') != -1:
        init.normal_(m.weight.data, 1.0, 0.02)
        init.constant_(m.bias.data, 0.0)


def weights_init_kaiming(m):
    classname = m.__class__.__name__
    if classname.find('Conv2d') != -1:
        init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
    elif classname.find('Linear') != -1:
        init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
    elif classname.find('BatchNorm2d') != -1:
        init.normal_(m.weight.data, 1.0, 0.02)
        init.constant_(m.bias.data, 0.0)


def weights_init_orthogonal(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        init.orthogonal_(m.weight.data, gain=1)
    elif classname.find('Linear') != -1:
        init.orthogonal_(m.weight.data, gain=1)
    elif classname.find('BatchNorm2d') != -1:
        init.normal_(m.weight.data, 1.0, 0.02)
        init.constant_(m.bias.data, 0.0)


def init_weights(net, init_type='normal'):
    print('initialization method [%s]' % init_type)
    if init_type == 'normal':
        net.apply(weights_init_normal)
    elif init_type == 'xavier':
        net.apply(weights_init_xavier)
    elif init_type == 'kaiming':
        net.apply(weights_init_kaiming)
    elif init_type == 'orthogonal':
        net.apply(weights_init_orthogonal)
    else:
        raise NotImplementedError('initialization method [%s] is not implemented' % init_type)


# __init__: load the dataset
# __call__: train the CNN defined by the CGP list
class CNN_train():
    def __init__(self, dataset_name, validation=True, verbose=True, imgSize=32, batchsize=128):
        # dataset_name: name of the data set ('cifar10' or 'mnist')
        # validation:   [True]  model train/validation mode
        #               [False] model test mode for the final evaluation of the evolved model
        #                       (training data: all training data, test data: all test data)
        # verbose:      flag for displaying progress
        self.verbose = verbose
        self.imgSize = imgSize
        self.validation = validation
        self.batchsize = batchsize
        self.dataset_name = dataset_name

        # load dataset (only the CIFAR-10 branch is fully configured here)
        if dataset_name == 'cifar10' or dataset_name == 'mnist':
            if dataset_name == 'cifar10':
                self.n_class = 10
                self.channel = 3
                if self.validation:
                    self.dataloader, self.test_dataloader = get_train_valid_loader(
                        data_dir='./', batch_size=self.batchsize, augment=True,
                        random_seed=2018, num_workers=1, pin_memory=True)
                else:
                    train_dataset = dset.CIFAR10(root='./', train=True, download=True,
                                                 transform=transforms.Compose([
                                                     transforms.RandomHorizontalFlip(),
                                                     transforms.Resize(self.imgSize),
                                                     transforms.ToTensor(),
                                                     transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
                                                 ]))
                    test_dataset = dset.CIFAR10(root='./', train=False, download=True,
                                                transform=transforms.Compose([
                                                    transforms.Resize(self.imgSize),
                                                    transforms.ToTensor(),
                                                    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
                                                ]))
                    self.dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=self.batchsize,
                                                                  shuffle=True, num_workers=2)
                    self.test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=self.batchsize,
                                                                       shuffle=True, num_workers=2)
                print('train num ', len(self.dataloader.dataset))
        else:
            print('\tInvalid input dataset name at CNN_train()')
            exit(1)

    def __call__(self, cgp, gpuID, epoch_num=200, out_model='mymodel.model'):
        if self.verbose:
            print('GPUID :', gpuID)
            print('epoch_num :', epoch_num)
            print('batch_size:', self.batchsize)

        # model
        torch.backends.cudnn.benchmark = True
        model = CGP2CNN(cgp, self.channel, self.n_class, self.imgSize)
        init_weights(model, 'kaiming')
        model.cuda(gpuID)

        # loss and optimizer
        criterion = nn.CrossEntropyLoss()
        criterion.cuda(gpuID)
        optimizer = optim.Adam(model.parameters(), lr=0.01, betas=(0.5, 0.999))
        # optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, dampening=0, weight_decay=0.0005)

        # reusable GPU buffers for the input batch and the labels
        input = torch.FloatTensor(self.batchsize, self.channel, self.imgSize, self.imgSize)
        input = input.cuda(gpuID)
        label = torch.LongTensor(self.batchsize)
        label = label.cuda(gpuID)

        # train loop
        for epoch in range(1, epoch_num + 1):
            start_time = time.time()
            if self.verbose:
                print('epoch', epoch)
            train_loss = 0
            total = 0
            correct = 0
            ite = 0
            for module in model.children():
                module.train(True)
            for _, (data, target) in enumerate(self.dataloader):
                if self.dataset_name == 'mnist':
                    data = data[:, 0:1, :, :]  # keep only one channel for gray-scale images
                data = data.cuda(gpuID)
                target = target.cuda(gpuID)
                input.resize_as_(data).copy_(data)
                input_ = Variable(input)
                label.resize_as_(target).copy_(target)
                label_ = Variable(label)
                optimizer.zero_grad()
                try:
                    output = model(input_, None)
                except Exception:
                    # an architecture that fails in the forward pass is scored as 0
                    import traceback
                    traceback.print_exc()
                    return 0.
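                # Standard supervised update for the mini-batch: cross-entropy loss,
                # backpropagation, one optimizer step, plus running loss/accuracy
                # counters for the per-epoch summary printed below.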
                loss = criterion(output, label_)
                train_loss += loss.item()
                loss.backward()
                optimizer.step()
                _, predicted = torch.max(output.data, 1)
                total += label_.size(0)
                correct += predicted.eq(label_.data).cpu().sum().item()
                ite += 1
            print('Train set : Average loss: {:.4f}'.format(train_loss))
            print('Train set : Average Acc : {:.4f}'.format(correct / total))
            print('time ', time.time() - start_time)
            if self.validation:
                # validation mode: decay the learning rate once, evaluate at the last epoch
                if epoch == 30:
                    for param_group in optimizer.param_groups:
                        param_group['lr'] *= 0.1
                if epoch == epoch_num:
                    for module in model.children():
                        module.train(False)
                    t_loss = self.__test_per_std(model, criterion, gpuID, input, label)
            else:
                # test mode: longer schedule with periodic evaluation
                if epoch == 5:
                    for param_group in optimizer.param_groups:
                        param_group['lr'] *= 10
                if epoch % 10 == 0:
                    for module in model.children():
                        module.train(False)
                    t_loss = self.__test_per_std(model, criterion, gpuID, input, label)
                if epoch == 250:
                    for param_group in optimizer.param_groups:
                        param_group['lr'] *= 0.1
                if epoch == 375:
                    for param_group in optimizer.param_groups:
                        param_group['lr'] *= 0.1
        # save the model
        torch.save(model.state_dict(), './model_%d.pth' % int(gpuID))
        return t_loss

    # for validation/test
    def __test_per_std(self, model, criterion, gpuID, input, label):
        test_loss = 0
        total = 0
        correct = 0
        ite = 0
        for _, (data, target) in enumerate(self.test_dataloader):
            if self.dataset_name == 'mnist':
                data = data[:, 0:1, :, :]
            data = data.cuda(gpuID)
            target = target.cuda(gpuID)
            input.resize_as_(data).copy_(data)
            input_ = Variable(input)
            label.resize_as_(target).copy_(target)
            label_ = Variable(label)
            try:
                output = model(input_, None)
            except Exception:
                import traceback
                traceback.print_exc()
                return 0.
            loss = criterion(output, label_)
            test_loss += loss.item()
            _, predicted = torch.max(output.data, 1)
            total += label_.size(0)
            correct += predicted.eq(label_.data).cpu().sum().item()
            ite += 1
        print('Test set : Average loss: {:.4f}'.format(test_loss))
        print('Test set : (%d/%d)' % (correct, total))
        print('Test set : Average Acc : {:.4f}'.format(correct / total))
        return correct / total
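

# A minimal usage sketch: the CGP genotype is produced by the evolutionary
# search elsewhere in this repository, so `evolved_cgp` below is a hypothetical
# placeholder and the hyper-parameter values are illustrative only.
if __name__ == '__main__':
    trainer = CNN_train('cifar10', validation=True, verbose=True, imgSize=32, batchsize=128)
    # evolved_cgp = ...  # CGP list decoded by CGP2CNN, e.g. reloaded from a search log
    # accuracy = trainer(evolved_cgp, gpuID=0, epoch_num=50, out_model='mymodel.model')
    # print('validation accuracy:', accuracy)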