""" Copyright 2016 Yahoo Inc. Licensed under the terms of the 2 clause BSD license. Please see LICENSE file in the project root for terms. """ import sys import os if os.path.isfile('../caffe_path.txt'): fid = open('../caffe_path.txt', 'r') caffe_root = fid.readline().strip('\n') fid.close() else: caffe_root = '/home/luojh2/Software/caffe-master/python/' sys.path.insert(0, caffe_root) import caffe from caffe.proto import caffe_pb2 from google.protobuf import text_format import tempfile caffe.set_mode_cpu() def _create_file_from_netspec(netspec): f = tempfile.NamedTemporaryFile(mode='w+', delete=False) f.write(str(netspec.to_proto())) return f.name def get_complexity(netspec=None, prototxt_file=None, mode=None): # One of netspec, or prototxt_path params should not be None assert (netspec is not None) or (prototxt_file is not None) if netspec is not None: prototxt_file = _create_file_from_netspec(netspec) net = caffe.Net(prototxt_file, caffe.TEST) total_params = 0 total_flops = 0 net_params = caffe_pb2.NetParameter() text_format.Merge(open(prototxt_file).read(), net_params) print '\n ########### output ###########' for layer in net_params.layer: if layer.name in net.params: params = net.params[layer.name][0].data.size # If convolution layer, multiply flops with receptive field # i.e. #params * datawidth * dataheight if layer.type == 'Convolution': # 'conv' in layer: data_width = net.blobs[layer.name].data.shape[2] data_height = net.blobs[layer.name].data.shape[3] flops = net.params[layer.name][ 0].data.size * data_width * data_height # print >> sys.stderr, layer.name, params, flops else: flops = net.params[layer.name][0].data.size flops *= 2 print('%s: #params: %s, #FLOPs: %s') % ( layer.name, digit2string(params), digit2string(flops)) total_params += params total_flops += flops if netspec is not None: os.remove(prototxt_file) return total_params, total_flops def digit2string(x): x = float(x) if x < 10 ** 3: return "%.2f" % float(x) elif x < 10 ** 6: x = x / 10 ** 3 return "%.2f" % float(x) + 'K' elif x < 10 ** 9: x = x / 10 ** 6 return "%.2f" % float(x) + 'M' else: x = x / 10 ** 9 return "%.2f" % float(x) + 'B' if __name__ == '__main__': length = len(sys.argv) if length == 1: filepath = 'deploy.prototxt' else: filepath = sys.argv[1] params, flops = get_complexity(prototxt_file=filepath, mode='Test') print '\n ########### result ###########' print '#params=%s, #FLOPs=%s' % (digit2string(params), digit2string(flops))