''' author: Johannes Wagner <wagner@hcm-lab.de> created: 2018/05/04 Copyright (C) University of Augsburg, Lab for Human Centered Multimedia ''' import sys, os, json, glob #os.environ['CUDA_VISIBLE_DEVICES'] = '-1' import tensorflow as tf import numpy as np def getOptions(opts,vars): opts['path'] = '' def getSampleDimensionOut(dim, opts, vars): vars['loaded'] = False vars['n_classes'] = 0 try: load_model(opts, vars) vars['loaded'] = True except Exception as ex: print(ex) return vars['n_classes'] def getSampleTypeOut(type, types, opts, vars): if type != types.FLOAT: print('types other than float are not supported') return types.UNDEF return type def load_model(opts, vars): print('load model ', opts['path']) if os.path.isdir(opts['path']): files = glob.glob(os.path.join(opts['path'], 'model.ckpt-*.meta')) if files: files.sort() checkpoint_path, _ = os.path.splitext(files[-1]) else: checkpoint_path = opts['path'] if not all([os.path.exists(checkpoint_path + x) for x in ['.data-00000-of-00001', '.index', '.meta']]): print('ERROR: could not load model') raise FileNotFoundError vocabulary_path = checkpoint_path + '.json' if not os.path.exists(vocabulary_path): vocabulary_path = os.path.join(os.path.dirname(checkpoint_path), 'vocab.json') if not os.path.exists(vocabulary_path): print('ERROR: could not load vocabulary') raise FileNotFoundError graph = tf.Graph() with graph.as_default(): print('loading model {}'.format(checkpoint_path)) saver = tf.train.import_meta_graph(checkpoint_path + '.meta') with open(vocabulary_path, 'r') as fp: vocab = json.load(fp) x = graph.get_tensor_by_name(vocab['x']) y = graph.get_tensor_by_name(vocab['y']) init = graph.get_operation_by_name(vocab['init']) logits = graph.get_tensor_by_name(vocab['logits']) ph_n_shuffle = graph.get_tensor_by_name(vocab['n_shuffle']) ph_n_repeat = graph.get_tensor_by_name(vocab['n_repeat']) ph_n_batch = graph.get_tensor_by_name(vocab['n_batch']) vars['n_classes'] = len(vocab['targets']) sess = tf.Session() saver.restore(sess, checkpoint_path) vars['sess'] = sess vars['x'] = x vars['y'] = y vars['ph_n_shuffle'] = ph_n_shuffle vars['ph_n_repeat'] = ph_n_repeat vars['ph_n_batch'] = ph_n_batch vars['init'] = init vars['logits'] = logits def transform_enter(sin, sout, sxtra, board, opts, vars): pass def transform(info, sin, sout, sxtra, board, opts, vars): if vars['loaded']: sess = vars['sess'] x = vars['x'] y = vars['y'] ph_n_shuffle = vars['ph_n_shuffle'] ph_n_repeat = vars['ph_n_repeat'] ph_n_batch = vars['ph_n_batch'] init = vars['init'] logits = vars['logits'] input = np.asmatrix(sin).reshape(-1, x.shape[1]) dummy = np.zeros((input.shape[0],), dtype=np.int32) sess.run(init, feed_dict = { x : input, y : dummy, ph_n_shuffle : 1, ph_n_repeat : 1, ph_n_batch : input.shape[0] }) output = sess.run(logits) output = np.mean(output, axis=0) for i in range(sout.dim): sout[i] = output[i] def transform_flush(sin, sout, sxtra, board, opts, vars): if vars['loaded']: sess = vars['sess'] sess.close()