# -*- coding: utf-8 -*-
import argparse
import cPickle as pickle
import json
import numpy as np
import scipy.io
import random
import chainer
from chainer import cuda, optimizers, serializers, functions as F
from chainer.functions.evaluation import accuracy
from net import ImageCaption
import time

parser = argparse.ArgumentParser(description='Train image caption model')
parser.add_argument('--gpu', '-g', default=-1, type=int,
                    help='GPU ID (negative value indicates CPU)')
parser.add_argument('--sentence', '-s', required=True, type=str,
                    help='input sentences dataset file path')
parser.add_argument('--image', '-i', required=True, type=str,
                    help='input images file path')
parser.add_argument('--model', '-m', default=None, type=str,
                    help='input model and state file path without extension')
parser.add_argument('--output', '-o', required=True, type=str,
                    help='output model and state file path without extension')
parser.add_argument('--iter', default=100, type=int,
                    help='number of training epochs')
args = parser.parse_args()

gpu_device = None
xp = np
if args.gpu >= 0:
    cuda.check_cuda_available()
    gpu_device = args.gpu
    cuda.get_device(gpu_device).use()
    xp = cuda.cupy

with open(args.sentence, 'rb') as f:
    sentence_dataset = pickle.load(f)
image_dataset = scipy.io.loadmat(args.image)
images = image_dataset['feats'].transpose((1, 0))

train_image_ids = sentence_dataset['images']['train']
train_sentences = sentence_dataset['sentences']['train']
test_image_ids = sentence_dataset['images']['test']
test_sentences = sentence_dataset['sentences']['test']
word_ids = sentence_dataset['word_ids']
feature_num = images.shape[1]
hidden_num = 512
batch_size = 128

print 'word count: ', len(word_ids)

caption_net = ImageCaption(len(word_ids), feature_num, hidden_num)
if gpu_device is not None:
    caption_net.to_gpu(gpu_device)
optimizer = optimizers.Adam()
optimizer.setup(caption_net)
if args.model is not None:
    serializers.load_hdf5(args.model + '.model', caption_net)
    serializers.load_hdf5(args.model + '.state', optimizer)

bos = word_ids['<S>']
eos = word_ids['</S>']
unknown = word_ids['<UNK>']


def random_batches(image_groups, sentence_groups):
    # Shuffle each length group, split it into mini-batches, then shuffle
    # the resulting batches across groups.
    batches = []
    for image_ids, sentences in zip(image_groups, sentence_groups):
        length = len(sentences)
        index = np.arange(length, dtype=np.int32)
        np.random.shuffle(index)
        for n in range(0, length, batch_size):
            batch_index = index[n:n + batch_size]
            batches.append((image_ids[batch_index], sentences[batch_index]))
    random.shuffle(batches)
    return batches


def make_groups(image_ids, sentences, train=True):
    # Bucket the dataset by sentence length so that every row in a group
    # shares the same padded width.
    if train:
        boundaries = [1, 6, 11, 16, 21, 31, 41, 51]
    else:
        boundaries = range(1, 41)
    sentence_groups = []
    image_groups = []
    for begin, end in zip(boundaries[:-1], boundaries[1:]):
        size = sum(map(lambda x: len(sentences[x]), range(begin, end)))
        # Each row is <S>, the words, then </S> padding up to the bucket width.
        sub_sentences = np.full((size, end + 1), eos, dtype=np.int32)
        sub_sentences[:, 0] = bos
        sub_image_ids = np.zeros((size,), dtype=np.int32)
        offset = 0
        for n in range(begin, end):
            length = len(sentences[n])
            if length > 0:
                sub_sentences[offset:offset + length, 1:n + 1] = sentences[n]
                sub_image_ids[offset:offset + length] = image_ids[n]
                offset += length
        sentence_groups.append(sub_sentences)
        image_groups.append(sub_image_ids)
    return image_groups, sentence_groups


def forward(net, image_batch, sentence_batch, train=True):
    images = xp.asarray(image_batch)
    n, sentence_length = sentence_batch.shape
    net.initialize(images)
    loss = 0
    acc = 0
    size = 0
    for i in range(sentence_length - 1):
        # target is 1 for positions that still have words to predict and
        # 0 once a sentence has reached </S>.
        target = xp.where(xp.asarray(sentence_batch[:, i]) != eos, 1, 0).astype(np.float32)
        if (target == 0).all():
            break
        with chainer.using_config('train', train):
            with chainer.using_config('enable_backprop', train):
                x = xp.asarray(sentence_batch[:, i])
                t = xp.asarray(sentence_batch[:, i + 1])
                y = net(x)
                y_max_index = xp.argmax(y.data, axis=1)
                # Zero out the logits of already-finished sentences before
                # accumulating the loss.
                mask = target.reshape((len(target), 1)).repeat(y.data.shape[1], axis=1)
                y = y * mask
                loss += F.softmax_cross_entropy(y, t)
                acc += xp.sum((y_max_index == t) * target)
                size += xp.sum(target)
    return loss / size, float(acc) / size, float(size)


def train(epoch_num):
    image_groups, sentence_groups = make_groups(train_image_ids, train_sentences)
    test_image_groups, test_sentence_groups = make_groups(test_image_ids, test_sentences, train=False)
    for epoch in range(epoch_num):
        batches = random_batches(image_groups, sentence_groups)
        sum_loss = 0
        sum_acc = 0
        sum_size = 0
        batch_num = len(batches)
        for i, (image_id_batch, sentence_batch) in enumerate(batches):
            loss, acc, size = forward(caption_net, images[image_id_batch], sentence_batch)
            caption_net.cleargrads()
            loss.backward()
            loss.unchain_backward()
            optimizer.update()
            sum_loss += float(loss.data) * size
            sum_acc += acc * size
            sum_size += size
            if (i + 1) % 500 == 0:
                print '{} / {} loss: {} accuracy: {}'.format(i + 1, batch_num, sum_loss / sum_size, sum_acc / sum_size)
        print 'epoch: {} done'.format(epoch + 1)
        print 'train loss: {} accuracy: {}'.format(sum_loss / sum_size, sum_acc / sum_size)

        # Evaluate on the test groups with backprop and dropout disabled.
        sum_loss = 0
        sum_acc = 0
        sum_size = 0
        for image_ids, sentences in zip(test_image_groups, test_sentence_groups):
            if len(sentences) == 0:
                continue
            size = len(sentences)
            for i in range(0, size, batch_size):
                image_id_batch = image_ids[i:i + batch_size]
                sentence_batch = sentences[i:i + batch_size]
                loss, acc, size = forward(caption_net, images[image_id_batch], sentence_batch, train=False)
                sum_loss += float(loss.data) * size
                sum_acc += acc * size
                sum_size += size
        print 'test loss: {} accuracy: {}'.format(sum_loss / sum_size, sum_acc / sum_size)
        serializers.save_hdf5(args.output + '_{0:04d}.model'.format(epoch), caption_net)
        serializers.save_hdf5(args.output + '_{0:04d}.state'.format(epoch), optimizer)


train(args.iter)
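# Example invocation (a sketch only; the script and dataset file names below
# are placeholders and assume the sentence pickle and image-feature .mat file
# were produced by the project's preprocessing step):
#
#   python train_image_caption.py --sentence dataset/sentences.pkl \
#       --image dataset/image_features.mat --output model/caption \
#       --gpu 0 --iter 100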