"""Train a VAE on MNIST with a Bayesian Gaussian mixture prior over z.

Each epoch first refits a scikit-learn ``BayesianGaussianMixture`` on latent
codes sampled for the whole training set, then optimises the VAE with a loss
made of the summed binary cross-entropy plus a KL-style penalty between each
sample's posterior and the mixture component the sample is assigned to.

After training the script writes ``z.txt``, ``img.pickle`` and the network /
mixture parameters under ``model/``.
"""
from __future__ import print_function

import argparse
import os

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
from sklearn.mixture import BayesianGaussianMixture
from torch.autograd import Variable
from torchvision import datasets, transforms

try:  # Python 2
    import cPickle as pickle
except ImportError:  # Python 3
    import pickle

parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
parser.add_argument('--batch-size', type=int, default=128, metavar='N',
                    help='input batch size for training (default: 128)')
parser.add_argument('--epochs', type=int, default=10, metavar='N',
                    help='number of epochs to train (default: 10)')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: 1)')
# NOTE: accepted for command-line compatibility; per-batch logging is
# currently disabled, so this value is not read anywhere below.
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                    help='how many batches to wait before logging training status')
parser.add_argument('--hidden', type=int, default=10, metavar='N',
                    help='number of dimension for z')
parser.add_argument('--comp', type=int, default=100, metavar='N',
                    help='maximum number of components in DP')
args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()

torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)

kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.ToTensor()),
    batch_size=args.batch_size, shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transforms.ToTensor()),
    batch_size=args.batch_size, shuffle=True, **kwargs)


class VAE(nn.Module):
    """Fully connected VAE: 784 -> 400 -> ``args.hidden`` -> 400 -> 784."""

    def __init__(self):
        super(VAE, self).__init__()

        self.fc1 = nn.Linear(784, 400)
        self.fc21 = nn.Linear(400, args.hidden)  # posterior mean head
        self.fc22 = nn.Linear(400, args.hidden)  # posterior log-variance head
        self.fc3 = nn.Linear(args.hidden, 400)
        self.fc4 = nn.Linear(400, 784)

        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def encode(self, x):
        """Return ``(mu, logvar)`` of the approximate posterior for ``x``."""
        h1 = self.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)

    def reparametrize(self, mu, logvar):
        """Draw ``z = mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``."""
        std = logvar.mul(0.5).exp_()
        if args.cuda:
            eps = torch.cuda.FloatTensor(std.size()).normal_()
        else:
            eps = torch.FloatTensor(std.size()).normal_()
        eps = Variable(eps)
        return eps.mul(std).add_(mu)

    def decode(self, z):
        """Map latent codes ``z`` to per-pixel values in (0, 1)."""
        h3 = self.relu(self.fc3(z))
        return self.sigmoid(self.fc4(h3))

    def forward(self, x):
        """Return ``(reconstruction, mu, logvar, z)`` for a batch ``x``."""
        mu, logvar = self.encode(x.view(-1, 784))
        z = self.reparametrize(mu, logvar)
        return self.decode(z), mu, logvar, z

    def sample(self, model, n):
        """Decode ``n`` latent codes drawn from the fitted mixture ``model``.

        ``model.sample(n)`` returns ``(samples, component_labels)``; only the
        samples are used.
        """
        draws = model.sample(n)[0].astype(np.float32)
        z = Variable(torch.from_numpy(draws))
        # Only move to the GPU when CUDA is in use, so sampling also works
        # with --no-cuda or on CPU-only machines.
        if args.cuda:
            z = z.cuda()
        return self.decode(z)


model = VAE()
if args.cuda:
    model.cuda()

# Summed (not averaged) binary cross-entropy. The flag is passed to the
# constructor because assigning the attribute afterwards is not honoured by
# every PyTorch version.
reconstruction_function = nn.BCELoss(size_average=False)

# Component assignments of the most recent batch passed through KL().
# test() reads this to build the component/label contingency table.
C = None


def _loss_value(loss):
    """Return the Python number held by a scalar loss Variable/tensor."""
    data = loss.data
    # 0-dim results cannot be indexed with [0]; 1-element results have no
    # need for item().
    return data.item() if data.dim() == 0 else data[0]


def KL(model, mu, logvar, z):
    """Return the KL-style penalty between q(z|x) and each sample's component.

    Every row of ``z`` is assigned to a mixture component with
    ``model.predict``; the penalty is then computed against that component's
    diagonal Gaussian. The assignments are also stored in the global ``C``.

    The constant ``-1`` per latent dimension of the exact Gaussian KL is
    omitted, so the value is shifted by a constant relative to the true KL.
    """
    global C
    C = model.predict(z.cpu().data.numpy())
    muc = Variable(torch.from_numpy(model.means_[C].astype(np.float32)))
    logvarc = Variable(torch.from_numpy(
        np.log(model.covariances_[C], dtype=np.float32)))
    if args.cuda:
        muc = muc.cuda()
        logvarc = logvarc.cuda()
    varc = logvarc.exp()
    # 0.5 * sum((mu_c - mu)^2 / var_c + log var_c - log var + var / var_c)
    terms = (muc - mu).pow(2) / varc + logvarc - logvar + logvar.exp() / varc
    return torch.sum(terms).mul(0.5)


def loss_function(recon_x, x, mu, logvar, model, z):
    """Return summed reconstruction BCE plus the mixture KL penalty.

    For the standard-normal-prior KL this replaces, see Appendix B of
    Kingma and Welling, "Auto-Encoding Variational Bayes", ICLR 2014,
    https://arxiv.org/abs/1312.6114
    """
    BCE = reconstruction_function(recon_x, x)
    return BCE + KL(model, mu, logvar, z)


optimizer = optim.Adam(model.parameters(), lr=1e-3)


def getz():
    """Return sampled latent codes for the training set, stacked row-wise."""
    tmp = []
    for (data, _) in train_loader:
        data = Variable(data)
        if args.cuda:
            data = data.cuda()
        recon_batch, mu, logvar, z = model(data)
        tmp.append(z.cpu().data.numpy())
    return np.vstack(tmp)


def train(epoch, prior):
    """Refit ``prior`` on current latent codes, then train for one epoch.

    Returns the refitted ``prior``.
    """
    model.train()

    print('Update Prior')
    prior.fit(getz())
    print('prior: ' + str(prior.weights_))

    train_loss = 0
    for data, _ in train_loader:
        data = Variable(data)
        if args.cuda:
            data = data.cuda()
        optimizer.zero_grad()
        recon_batch, mu, logvar, z = model(data)
        loss = loss_function(recon_batch, data, mu, logvar, prior, z)
        loss.backward()
        train_loss += _loss_value(loss)
        optimizer.step()

    print('====> Epoch: {} Average loss: {:.4f}'.format(
        epoch, train_loss / len(train_loader.dataset)))
    return prior


def _mutual_information(counts):
    """Return the mutual information (in nats) of a contingency table."""
    joint = counts / np.sum(counts)
    rows = joint.sum(axis=1, keepdims=True)
    cols = joint.sum(axis=0, keepdims=True)
    nonzero = joint > 0  # empty cells contribute nothing (0 * log 0 := 0)
    return np.sum(joint[nonzero] * np.log(joint[nonzero] / (rows * cols)[nonzero]))


def test(epoch, prior):
    """Report test loss and component/digit mutual information."""
    model.eval()
    test_loss = 0
    # ans[c, d] counts test samples assigned to component c with digit d.
    ans = np.zeros((args.comp, 10))
    for data, lab in test_loader:
        if args.cuda:
            data = data.cuda()
        data = Variable(data, volatile=True)
        recon_batch, mu, logvar, z = model(data)
        test_loss += _loss_value(
            loss_function(recon_batch, data, mu, logvar, prior, z))
        # C was just set by KL() (via loss_function) for this batch.
        # add.at accumulates correctly when (component, label) pairs repeat.
        np.add.at(ans, (C, lab.numpy()), 1)
    print(ans)
    print("Mutual information: " + str(_mutual_information(ans)))

    test_loss /= len(test_loader.dataset)
    print('====> Test set loss: {:.4f}'.format(test_loss))


prior = BayesianGaussianMixture(n_components=args.comp, covariance_type='diag')
for epoch in range(1, args.epochs + 1):
    prior = train(epoch, prior)
    test(epoch, prior)

np.savetxt('z.txt', getz())
regen = model.sample(prior, 1024).cpu().data.numpy()
with open('img.pickle', 'wb') as f:
    pickle.dump(regen, f)

# Create the output directory up front so a missing folder cannot discard a
# finished training run.
if not os.path.isdir('model'):
    os.makedirs('model')
for name in ('fc1', 'fc21', 'fc22', 'fc3', 'fc4'):
    layer = getattr(model, name)
    np.savetxt('model/%s_W' % name, layer.weight.cpu().data.numpy())
    np.savetxt('model/%s_b' % name, layer.bias.cpu().data.numpy())
np.savetxt('model/weights', prior.weights_)
np.savetxt('model/means', prior.means_)
np.savetxt('model/covars', prior.covariances_)