"""Predict a title for a recipe.""" from os import path import random import json import pickle import h5py import numpy as np from utils import str_shape import keras.backend as K import argparse from config import path_models, path_data from constants import FN1, FN0, nb_unknown_words, eos from model import create_model from sample_gen import gensamples # set seeds in random libraries seed = 42 random.seed(seed) np.random.seed(seed) def load_weights(model, filepath): """Load all weights possible into model from filepath. This is a modified version of keras load_weights that loads as much as it can if there is a mismatch between file and model. It returns the weights of the first layer in which the mismatch has happened """ print('Loading', filepath, 'to', model.name) with h5py.File(filepath, mode='r') as f: # new file format layer_names = [n.decode('utf8') for n in f.attrs['layer_names']] # we batch weight value assignments in a single backend call # which provides a speedup in TensorFlow. weight_value_tuples = [] for name in layer_names: print(name) g = f[name] weight_names = [n.decode('utf8') for n in g.attrs['weight_names']] if len(weight_names): weight_values = [g[weight_name] for weight_name in weight_names] try: layer = model.get_layer(name=name) except: layer = None if not layer: print('failed to find layer', name, 'in model') print('weights', ' '.join(str_shape(w) for w in weight_values)) print('stopping to load all other layers') weight_values = [np.array(w) for w in weight_values] break symbolic_weights = layer.trainable_weights + layer.non_trainable_weights weight_value_tuples += zip(symbolic_weights, weight_values) weight_values = None K.batch_set_value(weight_value_tuples) return weight_values def main(sample_str=None): """Predict a title for a recipe.""" # load model parameters used for training with open(path.join(path_models, 'model_params.json'), 'r') as f: model_params = json.load(f) # create placeholder model model = create_model(**model_params) # load weights from training run load_weights(model, path.join(path_models, '{}.hdf5'.format(FN1))) # load recipe titles and descriptions with open(path.join(path_data, 'vocabulary-embedding.data.pkl'), 'rb') as fp: X_data, Y_data = pickle.load(fp) # load vocabulary with open(path.join(path_data, '{}.pkl'.format(FN0)), 'rb') as fp: embedding, idx2word, word2idx, glove_idx2idx = pickle.load(fp) vocab_size, embedding_size = embedding.shape oov0 = vocab_size - nb_unknown_words if sample_str is None: # load random recipe description if none provided i = np.random.randint(len(X_data)) sample_str = '' sample_title = '' for w in X_data[i]: sample_str += idx2word[w] + ' ' for w in Y_data[i]: sample_title += idx2word[w] + ' ' y = Y_data[i] print('Randomly sampled recipe:') print(sample_title) print(sample_str) else: sample_title = '' y = [eos] x = [word2idx[w.rstrip('^')] for w in sample_str.split()] samples = gensamples( skips=2, k=1, batch_size=2, short=False, temperature=1., use_unk=True, model=model, data=(x, y), idx2word=idx2word, oov0=oov0, glove_idx2idx=glove_idx2idx, vocab_size=vocab_size, nb_unknown_words=nb_unknown_words, ) headline = samples[0][0][len(samples[0][1]):] ' '.join(idx2word[w] for w in headline) if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--sample-str', type=str, default=None, help='Sample recipe description') args = parser.parse_args() main(sample_str=args.sample_str)