import tensorflow as tf

from lm_human_preferences.language import model, sample
from lm_human_preferences.utils import core as utils
from lm_human_preferences.utils.core import Schema


class Policy:
    """Transformer language-model policy with an extra scalar ``'value'`` head.

    The underlying ``model.Model`` lives under its own TF variable scope.
    Two graph functions are exposed:

    * ``respond(queries, length)`` samples continuations of the embedded
      queries with ``sample.sample_sequence``. It is only built when
      ``build_respond`` is true.
    * ``analyze_responses(queries, responses)`` scores given responses. It
      returns log-probabilities, entropies and values.

    Pretrained-weight initialization is set up on the first build via
    ``_set_initializers``.
    """

    def __init__(
            self,
            trained_model, *,
            scope=None, use_resource=False,
            embed_queries=lambda queries: queries,
            temperature=1.0, is_root=True,
            build_respond=True,
    ):
        """Create the policy's variable scope, network and graph functions.

        Args:
            trained_model: supplies the hyperparameters (``hparams()``), the
                encoder (``encoding.get_encoder()``) and ``init_op``, which is
                used to initialize from the pretrained model.
            scope: variable scope passed to ``tf.variable_scope``. When None,
                the default name ``'transformer_policy'`` is used.
            use_resource: forwarded to ``tf.variable_scope``.
            embed_queries: maps query tokens to the context tokens fed to the
                model. Defaults to the identity.
            temperature: passed to ``sample.sample_sequence`` when sampling.
                ``analyze_responses_op`` also divides logits by it.
            is_root: only the root instance initializes from pretrained weights.
            build_respond: whether to build the ``respond`` graph function.
        """
        self.trained_model = trained_model
        hparams = trained_model.hparams()
        self.model_hparams = hparams
        self.is_root = is_root
        self.use_resource = use_resource
        self.encoder = trained_model.encoding.get_encoder()

        with tf.variable_scope(scope, 'transformer_policy',
                               use_resource=use_resource) as policy_scope:
            self.scope = policy_scope
            self.model = model.Model(hparams=hparams, scalar_heads=['value'])

        # Becomes True after the first call to step_core.
        self.built = False
        self.embed_queries = embed_queries
        self.temperature = temperature
        self.padding_token = self.encoder.padding_token

        # NOTE(review): if graph_function builds its graph immediately, the
        # first op wrapped here is the one that creates the variables, so
        # keep respond before analyze_responses.
        if build_respond:
            respond_decorator = utils.graph_function(
                queries=Schema(tf.int32, (None, None)),
                length=Schema(tf.int32, ()),
            )
            self.respond = respond_decorator(self.respond_op)
        analyze_decorator = utils.graph_function(
            queries=Schema(tf.int32, (None, None)),
            responses=Schema(tf.int32, (None, None)),
        )
        self.analyze_responses = analyze_decorator(self.analyze_responses_op)

    def get_encoder(self):
        """Return the encoder obtained from ``trained_model`` at construction."""
        return self.encoder

    def step_core(self, model_hparams, tokens, past=None, past_tokens=None, do_dropout=False, name=None):
        """Run the policy network on ``tokens``.

        The first call creates the variables and calls ``_set_initializers``.
        Later calls reuse the existing variables.

        Args:
            model_hparams: not read here; ``self.model_hparams`` is used
                instead. It is presumably kept to match the ``step`` signature
                that ``sample.sample_sequence`` calls.
            tokens: token ids passed to the model as ``X``.
            past, past_tokens: forwarded unchanged to the model (presumably
                cached attention state; verify against ``model.Model``).
            do_dropout: forwarded to the model.
            name: optional name-scope name. Defaults to ``'step'``.

        Returns:
            A dict with three entries:

            * ``'logits'``: the model's ``lm_logits``, truncated to
              ``n_vocab`` entries on the last axis.
            * ``'values'``: the model's ``value`` head output.
            * ``'presents'``: the model's ``present`` output.
        """
        first_build = not self.built
        with tf.name_scope(name, 'step'):
            with tf.variable_scope(
                    self.scope,
                    reuse=not first_build,
                    auxiliary_name_scope=first_build,
                    use_resource=self.use_resource):
                outputs = self.model(
                    X=tokens,
                    past=past,
                    past_tokens=past_tokens,
                    do_dropout=do_dropout,
                    padding_token=self.padding_token,
                )
                # Drop logits beyond n_vocab so special tokens are never generated.
                vocab_size = self.model_hparams.n_vocab
                logits = outputs['lm_logits'][:, :, :vocab_size]
                step_result = {
                    'logits': logits,
                    'values': outputs['value'],
                    'presents': outputs['present'],
                }
                if first_build:
                    self._set_initializers()
                self.built = True
        return step_result

    def ensure_built(self):
        """Build the network with an empty (0, 0) int32 token batch, unless already built."""
        if self.built:
            return
        with tf.name_scope('dummy'):
            empty_tokens = tf.zeros([0, 0], dtype=tf.int32)
            self.step_core(self.model_hparams, tokens=empty_tokens)

    def get_params(self):
        """Return the trainable variables under this policy's scope.

        The network is built first if needed. Asserts that at least one
        trainable variable was found.
        """
        self.ensure_built()
        trainable = utils.find_trainable_variables(self.scope.name)
        assert len(trainable) > 0
        return trainable

    def _set_initializers(self):
        """Change initializers to load the language model held by ``trained_model``.

        This does nothing in two cases:
          1. This instance is not the root (rank 0). Values will be copied
             from there.
          2. The model is named ``'test'``. Random initialization is wanted,
             and the normal initializers already provide it.
        """
        if not self.is_root:
            return
        if self.trained_model.name == 'test':
            return

        with tf.init_scope():
            scope_name = self.scope.name
            variables_by_name = {
                var.op.name: var
                for var in utils.find_trainable_variables(scope_name)
            }
            self.trained_model.init_op(variables_by_name, new_scope=scope_name)

    def respond_op(self, queries, length):
        """Sample continuations of the embedded ``queries``.

        ``length`` is forwarded to ``sample.sample_sequence``.

        Returns:
            A dict with three entries:

            * ``responses``: the sampled tokens with the context prefix
              removed along axis 1.
            * ``logprobs``: as reported by the sampler.
            * ``values``: as reported by the sampler.
        """
        context = self.embed_queries(queries)
        prefix_len = tf.shape(context)[1]
        sampled = sample.sample_sequence(
            step=self.step_core,
            context=context,
            length=length,
            model_hparams=self.model_hparams,
            temperature=self.temperature,
            extra_outputs={'values': tf.float32},
        )
        return {
            # Keep only the columns after the context prefix.
            'responses': sampled['tokens'][:, prefix_len:],
            'logprobs': sampled['logprobs'],
            'values': sampled['values'],
        }

    def analyze_responses_op(self, queries, responses):
        """Score ``responses`` given ``queries`` in a single forward pass.

        The context and responses are concatenated along axis 1 and run
        through ``step_core``. Outputs at positions ``[context_len - 1, -1)``
        are kept, giving one entry per response token. This assumes position
        ``t`` holds the prediction for token ``t + 1``.

        Returns:
            A dict with three entries:

            * ``logprobs``: log-probabilities of ``responses`` under the
              temperature-divided logits.
            * ``entropies``: entropies of those logits.
            * ``values``: the correspondingly sliced values.
        """
        context = self.embed_queries(queries)
        prefix_len = tf.shape(context)[1]

        def align(t):
            # Positions prefix_len - 1 up to (excluding) the last, along axis 1.
            return t[:, prefix_len - 1:-1]

        full_sequence = tf.concat([context, responses], axis=1)
        step_out = self.step_core(self.model_hparams, full_sequence)
        scaled_logits = align(step_out['logits']) / self.temperature
        return {
            'logprobs': utils.logprobs_from_logits(logits=scaled_logits, labels=responses),
            'entropies': utils.entropy_from_logits(scaled_logits),
            'values': align(step_out['values']),
        }