"""Example script: gather ImageNet images and train/predict with AlexNet on Chainer.

Tasks (first positional argument):
    g  - gather (download) images for a WordNet id
    l  - build label files from the downloaded folder structure
    s  - show training images (use "ss" to shuffle)
    t  - train the model
    p  - predict with the trained model
"""
import os
import sys

# Make the parent directory importable (provides the local `alex` module).
sys.path.append(os.path.join(os.path.dirname(__file__), "../"))

import argparse

import numpy as np
import chainer
from chainer import cuda
from chainer import optimizers
from chainer import serializers

import alex
from mlimages.gather.imagenet import ImagenetAPI
from mlimages.label import LabelingMachine
from mlimages.training import TrainingData
from mlimages.model import ImageProperty


DATA_DIR = os.path.join(os.path.dirname(__file__), "./data/imagenet/")
IMAGES_ROOT = os.path.join(DATA_DIR, "./images")
LABEL_FILE = os.path.join(os.path.dirname(__file__), "./data/imagenet/label.txt")
LABEL_DEF_FILE = os.path.join(os.path.dirname(__file__), "./data/imagenet/label_def.txt")
MEAN_IMAGE_FILE = os.path.join(os.path.dirname(__file__), "./data/imagenet/mean_image.png")
MODEL_FILE = os.path.join(os.path.dirname(__file__), "./data/imagenet/chainer_alex.model")

# AlexNet expects 227x227 inputs.
IMAGE_PROP = ImageProperty(width=227, resize_by_downscale=True)


def download_imagenet(wnid, limit=-1):
    """Download ImageNet images for ``wnid`` (and its subsets) into IMAGES_ROOT.

    Args:
        wnid: WordNet id of the synset to gather (e.g. ``"n02121808"``).
        limit: maximum number of images to download; negative means no limit.
    """
    api = ImagenetAPI(data_root=DATA_DIR, limit=limit, debug=True)
    api.logger.info("start to gather the ImageNet images.")
    folders = api.gather(wnid, include_subset=True)

    # The gatherer names the folder after the synset; rename it to the
    # fixed path the rest of this script expects.
    images_root = os.path.join(DATA_DIR, folders[0])
    os.rename(images_root, IMAGES_ROOT)
    print("Download has finished.")  # fixed message (was "Down load has done.")


def make_label():
    """Create the label file and label-definition file from IMAGES_ROOT's folders."""
    machine = LabelingMachine(data_root=IMAGES_ROOT)
    # label_dir_auto writes LABEL_FILE/LABEL_DEF_FILE as a side effect;
    # its return value is not needed here.
    machine.label_dir_auto(label_file=LABEL_FILE, label_def_file=LABEL_DEF_FILE)


def show(limit, shuffle=True):
    """Display up to ``limit`` training images (default 5 when limit <= 0).

    Args:
        limit: number of images to show; non-positive falls back to 5.
        shuffle: when True, show images in random order.
    """
    td = TrainingData(LABEL_FILE, img_root=IMAGES_ROOT,
                      mean_image_file=MEAN_IMAGE_FILE, image_property=IMAGE_PROP)

    _limit = limit if limit > 0 else 5
    iterator = td.generate()
    if shuffle:
        import random
        # Materialize the generator so it can be shuffled.
        shuffled = list(iterator)
        random.shuffle(shuffled)
        iterator = iter(shuffled)

    i = 0
    for arr, im in iterator:
        # raw=True restores the image without mean subtraction applied.
        restored = td.data_to_image(arr, im.label, raw=True)
        print(im.path)
        restored.image.show()
        i += 1
        if i >= _limit:
            break


def train(epoch=10, batch_size=32, gpu=False):
    """Train AlexNet on the gathered images and save the model after each epoch.

    Args:
        epoch: number of passes over the training data.
        batch_size: minibatch size.
        gpu: when True, run on GPU (requires CUDA).
    """
    if gpu:
        cuda.check_cuda_available()
    xp = cuda.cupy if gpu else np

    td = TrainingData(LABEL_FILE, img_root=IMAGES_ROOT, image_property=IMAGE_PROP)

    # Compute the mean image once; reuse it on subsequent runs.
    if not os.path.isfile(MEAN_IMAGE_FILE):
        print("make mean image...")
        td.make_mean_image(MEAN_IMAGE_FILE)
    else:
        td.mean_image_file = MEAN_IMAGE_FILE

    # Build the model sized to the number of label classes.
    label_def = LabelingMachine.read_label_def(LABEL_DEF_FILE)
    model = alex.Alex(len(label_def))
    optimizer = optimizers.MomentumSGD(lr=0.01, momentum=0.9)
    optimizer.setup(model)

    print("Now our model is {0} classification task.".format(len(label_def)))
    print("begin training the model. epoch:{0} batch size:{1}.".format(epoch, batch_size))

    if gpu:
        model.to_gpu()

    for i in range(epoch):
        print("epoch {0}/{1}: (learning rate={2})".format(i + 1, epoch, optimizer.lr))
        td.shuffle(overwrite=True)

        for x_batch, y_batch in td.generate_batches(batch_size):
            x = chainer.Variable(xp.asarray(x_batch))
            t = chainer.Variable(xp.asarray(y_batch))

            optimizer.update(model, x, t)
            print("loss: {0}, accuracy: {1}".format(
                float(model.loss.data), float(model.accuracy.data)))

        # Checkpoint after every epoch, then decay the learning rate.
        serializers.save_npz(MODEL_FILE, model)
        optimizer.lr *= 0.97


def predict(limit):
    """Predict labels for up to ``limit`` images (default 5) and show each image.

    Args:
        limit: number of images to predict; non-positive falls back to 5.
    """
    _limit = limit if limit > 0 else 5

    td = TrainingData(LABEL_FILE, img_root=IMAGES_ROOT,
                      mean_image_file=MEAN_IMAGE_FILE, image_property=IMAGE_PROP)
    label_def = LabelingMachine.read_label_def(LABEL_DEF_FILE)
    model = alex.Alex(len(label_def))
    serializers.load_npz(MODEL_FILE, model)

    i = 0
    for arr, im in td.generate():
        # Wrap the single image in a batch dimension of 1.
        x = np.ndarray((1,) + arr.shape, arr.dtype)
        x[0] = arr
        # volatile="on" disables backprop graph construction (old Chainer API).
        x = chainer.Variable(np.asarray(x), volatile="on")
        y = model.predict(x)
        p = np.argmax(y.data)
        print("predict {0}, actual {1}".format(label_def[p], label_def[im.label]))
        im.image.show()
        i += 1
        if i >= _limit:
            break


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Example of Imagenet x AlexNet")
    parser.add_argument("task", type=str,
                        help="task of script. " + ", ".join([
                            "g: gather images",
                            "l: make label file",
                            "s: show training images (shuffle data when 'ss')",
                            "t: train model",
                            "p: predict"
                        ]))
    parser.add_argument("-wnid", type=str,
                        help="imagenet id (default is cats(n02121808))",
                        default="n02121808")
    parser.add_argument("-limit", type=int,
                        help="g: download image limit, s,p: show/predict image limit",
                        default=-1)
    parser.add_argument("-epoch", type=int, help="when t: epoch count", default=10)
    parser.add_argument("-batchsize", type=int, help="when t: batch size", default=32)
    parser.add_argument("-gpu", action="store_true", help="when t: use gpu")

    args = parser.parse_args()

    if args.task == "g":
        download_imagenet(args.wnid, args.limit)
    elif args.task == "l":
        print("create label data automatically.")
        make_label()
    elif args.task == "s":
        show(args.limit, shuffle=False)
    elif args.task == "ss":
        show(args.limit, shuffle=True)
    elif args.task == "t":
        train(epoch=args.epoch, batch_size=args.batchsize, gpu=args.gpu)
    elif args.task == "p":
        predict(args.limit)
    else:
        # Previously an unknown task silently did nothing.
        parser.error("unknown task: {0}".format(args.task))