#
# Deep Knowledge Tracing (DKT) Implementation
# Mohammad M H Khajah <mohammad.khajah@colorado.edu>
# Copyright (c) 2016 all rights reserved.
#
# How to use:
#   python dkt.py --dataset dataset.txt --splitfile dataset_split.txt
#
# Optional flags: --hiddenunits, --batchsize, --timewindow, --epochs
#
# Script saves 3 files:
#   dataset.txt.model_weights   trained model weights
#   dataset.txt.history         training history (training LL, test AUC)
#   dataset.txt.preds           predictions for test trials
#
import os
import sys
import numpy as np
import random
import math
import argparse


def main():
    """Train a stateful LSTM knowledge-tracing model and evaluate it each epoch.

    Command-line arguments select the dataset file, the train/test split file
    and the hyper-parameters. After every epoch the model weights, the
    (training loss, test AUC) history and the test predictions are written to
    ``<dataset>.model_weights``, ``<dataset>.history`` and ``<dataset>.preds``.
    """
    # The deep-learning stack is imported here rather than at module level so
    # that the data helpers below can be imported without Keras/Theano.
    # The import list itself is retained from the original file.
    from keras.preprocessing import sequence
    from keras.utils import np_utils
    from keras.models import Sequential, Graph
    from keras.layers.core import TimeDistributedDense, Masking
    from keras.layers.recurrent import LSTM
    from keras import backend as K
    from sklearn.metrics import roc_auc_score
    import theano.tensor as Th

    parser = argparse.ArgumentParser(
        description='Train and evaluate a Deep Knowledge Tracing (DKT) model.')
    parser.add_argument('--dataset', type=str, help='Dataset file',
                        required=True)
    parser.add_argument('--splitfile', type=str, help='Split file',
                        required=True)
    parser.add_argument('--hiddenunits', type=int,
                        help='Number of LSTM hidden units.',
                        default=200, required=False)
    parser.add_argument('--batchsize', type=int,
                        help='Number of sequences to process in a batch.',
                        default=5, required=False)
    parser.add_argument('--timewindow', type=int,
                        help='Number of timesteps to process in a batch.',
                        default=100, required=False)
    parser.add_argument('--epochs', type=int, help='Number of epochs.',
                        default=50, required=False)
    args = parser.parse_args()

    dataset = args.dataset
    split_file = args.splitfile
    hidden_units = args.hiddenunits
    batch_size = args.batchsize
    time_window = args.timewindow
    epochs = args.epochs

    model_file = dataset + '.model_weights'
    history_file = dataset + '.history'
    preds_file = dataset + '.preds'

    # One-element list so the nested trainer callback can accumulate into it.
    overall_loss = [0.0]
    preds = []
    history = []

    # load dataset
    training_seqs, testing_seqs, num_skills = load_dataset(dataset, split_file)
    print("Training Sequences: %d" % len(training_seqs))
    print("Testing Sequences: %d" % len(testing_seqs))
    print("Number of skills: %d" % num_skills)

    # Our loss function
    # The model gives predictions for all skills so we need to get the
    # prediction for the skill at time t. We do that by taking the column-wise
    # dot product between the predictions at each time slice and a
    # one-hot encoding of the skill at time t.
    # y_true: (nsamples x nsteps x nskills+1)
    # y_pred: (nsamples x nsteps x nskills)
    def loss_function(y_true, y_pred):
        skill = y_true[:, :, 0:num_skills]
        obs = y_true[:, :, num_skills]
        rel_pred = Th.sum(y_pred * skill, axis=2)

        # keras implementation does a mean on the last dimension (axis=-1)
        # which it assumes is a singleton dimension. But in our context that
        # would be wrong.
        return K.binary_crossentropy(rel_pred, obs)

    # build model
    model = Sequential()

    # ignore padding
    model.add(Masking(-1.0, batch_input_shape=(batch_size, time_window,
                                               num_skills*2)))

    # lstm configured to keep states between batches
    model.add(LSTM(input_dim=num_skills*2,
                   output_dim=hidden_units,
                   return_sequences=True,
                   batch_input_shape=(batch_size, time_window, num_skills*2),
                   stateful=True))

    # readout layer. TimeDistributedDense uses the same weights for all
    # time steps.
    model.add(TimeDistributedDense(input_dim=hidden_units,
                                   output_dim=num_skills,
                                   activation='sigmoid'))

    # optimize with rmsprop which dynamically adapts the learning
    # rate of each weight.
    model.compile(loss=loss_function, optimizer='rmsprop', class_mode="binary")

    # training function
    def trainer(X, Y):
        overall_loss[0] += model.train_on_batch(X, Y)[0]

    # prediction
    def predictor(X, Y):
        batch_activations = model.predict_on_batch(X)
        skill = Y[:, :, 0:num_skills]
        obs = Y[:, :, num_skills]
        # Reshape to the known (batch, step, skill) layout instead of using
        # np.squeeze: squeeze also drops a genuine singleton axis (e.g. a
        # time window of 1), after which the product below would broadcast
        # to the wrong shape and pair predictions with the wrong rows.
        y_pred = np.array(batch_activations).reshape(skill.shape)

        rel_pred = np.sum(y_pred * skill, axis=2)

        for b in range(0, X.shape[0]):
            for t in range(0, X.shape[1]):
                # -1 in the first feature marks a padded (masked) timestep
                if X[b, t, 0] == -1.0:
                    continue
                preds.append((rel_pred[b][t], obs[b][t]))

    # call when prediction batch is finished
    # resets LSTM state because we are done with all sequences in the batch
    def finished_prediction_batch(percent_done):
        model.reset_states()

    # similar to the above, but also reports progress and the running loss
    def finished_batch(percent_done):
        print("(%4.3f %%) %f" % (percent_done, overall_loss[0]))
        model.reset_states()

    # run the model
    for e in range(0, epochs):
        model.reset_states()

        # train
        run_func(training_seqs, num_skills, trainer, batch_size, time_window,
                 finished_batch)

        model.reset_states()

        # test
        run_func(testing_seqs, num_skills, predictor, batch_size, time_window,
                 finished_prediction_batch)

        # compute AUC
        auc = roc_auc_score([p[1] for p in preds], [p[0] for p in preds])

        # log
        history.append((overall_loss[0], auc))

        # save model
        model.save_weights(model_file, overwrite=True)
        print("==== Epoch: %d, Test AUC: %f" % (e, auc))

        # reset loss
        overall_loss[0] = 0.0

        # save predictions
        with open(preds_file, 'w') as f:
            f.write('was_heldout\tprob_recall\tstudent_recalled\n')
            for pred in preds:
                f.write('1\t%f\t%d\n' % (pred[0], pred[1]))

        with open(history_file, 'w') as f:
            for h in history:
                f.write('\t'.join([str(he) for he in h]))
                f.write('\n')

        # clear preds in place so the predictor closure keeps the same list
        del preds[:]


def run_func(seqs, num_skills, f, batch_size, time_window, batch_done=None):
    """Encode sequences into padded batches and feed them to ``f`` in windows.

    Args:
        seqs: list of sequences; each sequence is a non-empty list of
            ``(skill, is_correct)`` pairs. The list is shuffled on a copy, so
            the caller's list is left untouched.
        num_skills: number of distinct skills; inputs have ``num_skills*2``
            features and targets ``num_skills+1``.
        f: callable ``f(X, Y)`` invoked once per window of ``time_window``
            steps for every batch.
        batch_size: number of sequences per batch. Short batches are filled
            up with fully padded (-1) dummy sequences.
        time_window: number of timesteps handed to ``f`` per call.
        batch_done: optional callable receiving the percentage of sequences
            processed so far, called after each batch.

    Raises:
        ValueError: if any sequence is empty.

    An empty ``seqs`` list results in no calls at all.
    """
    if any(len(s) == 0 for s in seqs):
        raise ValueError("every sequence must contain at least one trial")

    # randomize samples
    seqs = seqs[:]
    random.shuffle(seqs)

    # all-zero templates, copied for every timestep
    xt_zeros = [0] * (num_skills * 2)
    ct_zeros = [0] * (num_skills + 1)

    processed = 0
    for start_from in range(0, len(seqs), batch_size):
        end_before = min(len(seqs), start_from + batch_size)
        x = []
        y = []
        for seq in seqs[start_from:end_before]:
            x_seq = []
            y_seq = []
            # the input at time t encodes the trial at time t-1, so the
            # first input is all zeros
            xt = xt_zeros[:]
            for skill, is_correct in seq:
                x_seq.append(xt)

                # target: one-hot skill plus the observed correctness
                ct = ct_zeros[:]
                ct[skill] = 1
                ct[num_skills] = is_correct
                y_seq.append(ct)

                # one hot encoding of (last_skill, is_correct)
                pos = skill * 2 + is_correct
                xt = xt_zeros[:]
                xt[pos] = 1

            x.append(x_seq)
            y.append(y_seq)

        maxlen = max([len(s) for s in x])
        maxlen = round_to_multiple(maxlen, time_window)

        # fill up the batch if necessary
        for _ in range(0, batch_size - len(x)):
            x.append([[-1.0] * (num_skills * 2) for _t in range(time_window)])
            y.append([[0.0] * (num_skills + 1) for _t in range(time_window)])

        X = pad_sequences(x, padding='post', maxlen=maxlen,
                          dim=num_skills*2, value=-1.0)
        Y = pad_sequences(y, padding='post', maxlen=maxlen,
                          dim=num_skills+1, value=-1.0)

        for t in range(0, maxlen, time_window):
            f(X[:, t:(t+time_window), :], Y[:, t:(t+time_window), :])

        processed += end_before - start_from

        # reset the states for the next batch of sequences
        if batch_done:
            batch_done((processed * 100.0) / len(seqs))


def round_to_multiple(x, base):
    """Return the smallest multiple of ``base`` that is >= ``x``, as an int."""
    return int(base * math.ceil(float(x)/base))


def load_dataset(dataset, split_file):
    """Read the dataset and divide its sequences into training and test sets.

    The split file holds one whitespace-separated flag per student, in the
    order returned by ``read_file`` (ascending student id): ``'1'`` puts the
    student in the training set and ``'0'`` in the test set. Any other flag
    leaves the student out of both.

    Returns:
        ``(training_seqs, testing_seqs, num_skills)``.

    Raises:
        IndexError: if the split file has fewer flags than there are students.
    """
    seqs, num_skills = read_file(dataset)

    with open(split_file, 'r') as f:
        # split() rather than split(' '): a trailing newline must not stay
        # glued to the last flag, otherwise that student matches neither set
        student_assignment = f.read().split()

    training_seqs = [seqs[i] for i in range(0, len(seqs))
                     if student_assignment[i] == '1']
    testing_seqs = [seqs[i] for i in range(0, len(seqs))
                    if student_assignment[i] == '0']

    return training_seqs, testing_seqs, num_skills


def read_file(dataset_path):
    """Parse a dataset file into per-student trial sequences.

    Each non-blank line holds three whitespace-separated fields:
    ``student problem is_correct``. ``student`` must be an integer; problems
    are mapped to consecutive ids in order of first appearance;
    ``is_correct`` counts as 1 only when it is exactly ``'1'``.

    Returns:
        ``(seqs, num_problems)`` where ``seqs`` is a list of sequences ordered
        by ascending student id, each a list of ``(problem_id, is_correct)``
        tuples in file order.
    """
    seqs_by_student = {}
    problem_ids = {}
    next_problem_id = 0
    with open(dataset_path, 'r') as f:
        for line in f:
            fields = line.split()
            if not fields:
                # tolerate blank lines (e.g. a trailing empty line)
                continue
            student, problem, is_correct = fields
            student = int(student)
            if student not in seqs_by_student:
                seqs_by_student[student] = []
            if problem not in problem_ids:
                problem_ids[problem] = next_problem_id
                next_problem_id += 1
            seqs_by_student[student].append(
                (problem_ids[problem], int(is_correct == '1')))

    sorted_keys = sorted(seqs_by_student.keys())
    return [seqs_by_student[k] for k in sorted_keys], next_problem_id


# https://groups.google.com/forum/#!msg/keras-users/7sw0kvhDqCw/QmDMX952tq8J
def pad_sequences(sequences, maxlen=None, dim=1, dtype='int32',
                  padding='pre', truncating='pre', value=0.):
    """Pad or truncate sequences of feature vectors to a common length.

    Override of the keras method to allow multiple feature dimensions.

    Args:
        sequences: list of sequences, each a list of ``dim``-long vectors.
        maxlen: target length; defaults to the longest sequence.
        dim: input feature dimension (number of features per timestep).
        dtype: dtype of the returned array.
        padding: ``'pre'`` or ``'post'`` - where the fill values go.
        truncating: ``'pre'`` or ``'post'`` - which end of an over-long
            sequence is dropped.
        value: fill value.

    Returns:
        Array of shape ``(len(sequences), maxlen, dim)``.

    Raises:
        ValueError: on an unknown ``padding`` or ``truncating`` mode.
    """
    lengths = [len(s) for s in sequences]

    nb_samples = len(sequences)
    if maxlen is None:
        maxlen = np.max(lengths)

    x = (np.ones((nb_samples, maxlen, dim)) * value).astype(dtype)
    for idx, s in enumerate(sequences):
        if len(s) == 0:
            # Nothing to copy: the row stays fully padded. Without this,
            # 'pre' padding would address x[idx, -0:], i.e. the whole row.
            continue

        if truncating == 'pre':
            trunc = s[-maxlen:]
        elif truncating == 'post':
            trunc = s[:maxlen]
        else:
            raise ValueError("Truncating type '%s' not understood"
                             % truncating)

        if padding == 'post':
            x[idx, :len(trunc)] = trunc
        elif padding == 'pre':
            x[idx, -len(trunc):] = trunc
        else:
            raise ValueError("Padding type '%s' not understood" % padding)
    return x


if __name__ == "__main__":
    main()