# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

# pylint: skip-file
import sys
sys.path.insert(0, '../../python')
import mxnet as mx
import numpy as np
import os, pickle, gzip
import logging
from mxnet.test_utils import get_cifar10

batch_size = 128


# small mlp network
def get_net():
    """Build the symbol graph of a small three-layer MLP classifier.

    The input variable ``data`` is first cast to float32 so the same
    network can be fed by either the uint8 or the float32 iterators
    defined in this file.

    Returns
    -------
    The ``SoftmaxOutput`` symbol (named ``softmax``) sitting on top of
    fc1(128) -> relu -> fc2(64) -> relu -> fc3(10).
    """
    data = mx.symbol.Variable('data')
    float_data = mx.symbol.Cast(data=data, dtype="float32")
    fc1 = mx.symbol.FullyConnected(float_data, name='fc1', num_hidden=128)
    act1 = mx.symbol.Activation(fc1, name='relu1', act_type="relu")
    fc2 = mx.symbol.FullyConnected(act1, name='fc2', num_hidden=64)
    act2 = mx.symbol.Activation(fc2, name='relu2', act_type="relu")
    fc3 = mx.symbol.FullyConnected(act2, name='fc3', num_hidden=10)
    softmax = mx.symbol.SoftmaxOutput(fc3, name="softmax")
    return softmax


# check data
# NOTE(review): runs at import time; presumably fetches the CIFAR-10
# record files under data/cifar/ that the iterators below read -- confirm.
get_cifar10()


def get_iterator_uint8(kv):
    """Create (train, val) iterators built on ``ImageRecordUInt8Iter``.

    Parameters
    ----------
    kv : object exposing ``num_workers`` and ``rank``
        Used to pick which partition of the record files this process
        reads (``num_parts`` / ``part_index``).

    Returns
    -------
    tuple
        ``(train, val)``; the training iterator is wrapped in a
        ``PrefetchingIter`` and uses random crop and mirror, the
        validation iterator uses neither.
    """
    data_shape = (3, 28, 28)

    train = mx.io.ImageRecordUInt8Iter(
        path_imgrec="data/cifar/train.rec",
        data_shape=data_shape,
        batch_size=batch_size,
        rand_crop=True,
        rand_mirror=True,
        num_parts=kv.num_workers,
        part_index=kv.rank)
    train = mx.io.PrefetchingIter(train)

    val = mx.io.ImageRecordUInt8Iter(
        path_imgrec="data/cifar/test.rec",
        rand_crop=False,
        rand_mirror=False,
        data_shape=data_shape,
        batch_size=batch_size,
        num_parts=kv.num_workers,
        part_index=kv.rank)

    return (train, val)


def get_iterator_float32(kv):
    """Create (train, val) iterators built on ``ImageRecordIter``.

    Same layout as ``get_iterator_uint8`` except that both iterators are
    additionally given ``mean_img="data/cifar/mean.bin"``.

    Parameters
    ----------
    kv : object exposing ``num_workers`` and ``rank``
        Used to pick which partition of the record files this process
        reads (``num_parts`` / ``part_index``).

    Returns
    -------
    tuple
        ``(train, val)``; the training iterator is wrapped in a
        ``PrefetchingIter``.
    """
    data_shape = (3, 28, 28)

    train = mx.io.ImageRecordIter(
        path_imgrec="data/cifar/train.rec",
        mean_img="data/cifar/mean.bin",
        data_shape=data_shape,
        batch_size=batch_size,
        rand_crop=True,
        rand_mirror=True,
        num_parts=kv.num_workers,
        part_index=kv.rank)
    train = mx.io.PrefetchingIter(train)

    val = mx.io.ImageRecordIter(
        path_imgrec="data/cifar/test.rec",
        mean_img="data/cifar/mean.bin",
        rand_crop=False,
        rand_mirror=False,
        data_shape=data_shape,
        batch_size=batch_size,
        num_parts=kv.num_workers,
        part_index=kv.rank)

    return (train, val)


num_epoch = 1


def run_cifar10(train, val, use_module):
    """Train the MLP for ``num_epoch`` epochs on CPU and check accuracy.

    Parameters
    ----------
    train, val : data iterators
        Both are reset before use, so the same iterators can be passed
        to several consecutive calls.
    use_module : bool
        If true, train through ``mx.mod.Module.fit``; otherwise through
        ``mx.model.FeedForward.create``.

    Raises
    ------
    AssertionError
        If the accuracy scored on ``val`` is not above 0.08.
    """
    train.reset()
    val.reset()
    devs = [mx.cpu(0)]
    net = get_net()
    optim_args = {'learning_rate': 0.001, 'wd': 0.00001, 'momentum': 0.9}
    eval_metrics = ['accuracy']
    if use_module:
        executor = mx.mod.Module(net, context=devs)
        executor.fit(
            train,
            eval_data=val,
            optimizer_params=optim_args,
            eval_metric=eval_metrics,
            num_epoch=num_epoch,
            arg_params=None,
            aux_params=None,
            begin_epoch=0,
            batch_end_callback=mx.callback.Speedometer(batch_size, 50),
            epoch_end_callback=None)
    else:
        executor = mx.model.FeedForward.create(
            net,
            train,
            ctx=devs,
            eval_data=val,
            eval_metric=eval_metrics,
            num_epoch=num_epoch,
            arg_params=None,
            aux_params=None,
            begin_epoch=0,
            batch_end_callback=mx.callback.Speedometer(batch_size, 50),
            epoch_end_callback=None,
            **optim_args)

    ret = executor.score(val, eval_metrics)
    if use_module:
        # Materialise the result so it can be indexed; each entry is
        # read below as a pair whose second element is the metric value.
        ret = list(ret)
        logging.info('final accuracy = %f', ret[0][1])
        assert (ret[0][1] > 0.08)
    else:
        logging.info('final accuracy = %f', ret[0])
        assert (ret[0] > 0.08)


class CustomDataIter(mx.io.DataIter):
    """Thin wrapper that delegates every call to an underlying iterator.

    The only difference from the wrapped iterator is that
    ``provide_data`` and ``provide_label`` are rebuilt as lists of plain
    two-element tuples, which lets the test exercise the legacy tuple
    form of those descriptors.
    """

    def __init__(self, data):
        super(CustomDataIter, self).__init__()
        self.data = data
        # presumably the leading dimension of the first data entry's
        # shape -- each entry is unpacked as a (name, shape) pair below
        self.batch_size = data.provide_data[0][1][0]

        # use legacy tuple
        self.provide_data = [(n, s) for n, s in data.provide_data]
        self.provide_label = [(n, s) for n, s in data.provide_label]

    def reset(self):
        self.data.reset()

    def next(self):
        return self.data.next()

    def iter_next(self):
        return self.data.iter_next()

    def getdata(self):
        return self.data.getdata()

    def getlabel(self):
        # Fixed: previously called the misspelled ``getlable`` and would
        # raise AttributeError whenever this method was invoked.
        return self.data.getlabel()

    def getindex(self):
        return self.data.getindex()

    def getpad(self):
        return self.data.getpad()


def test_cifar10():
    """Run the training check for float32, legacy-tuple and uint8 inputs.

    Each input flavour is run twice: once through the ``FeedForward``
    path and once through the ``Module`` path of ``run_cifar10``.
    """
    # print logging by default
    logging.basicConfig(level=logging.DEBUG)
    console = logging.StreamHandler()
    console.setLevel(logging.DEBUG)
    logging.getLogger('').addHandler(console)
    kv = mx.kvstore.create("local")
    # test float32 input
    (train, val) = get_iterator_float32(kv)
    run_cifar10(train, val, use_module=False)
    run_cifar10(train, val, use_module=True)

    # test legacy tuple in provide_data and provide_label
    run_cifar10(CustomDataIter(train), CustomDataIter(val), use_module=False)
    run_cifar10(CustomDataIter(train), CustomDataIter(val), use_module=True)

    # test uint8 input
    (train, val) = get_iterator_uint8(kv)
    run_cifar10(train, val, use_module=False)
    run_cifar10(train, val, use_module=True)


if __name__ == "__main__":
    test_cifar10()