import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from torch.optim import Adam
import numpy as np
from sklearn.cluster import KMeans
from sklearn.metrics.cluster import normalized_mutual_info_score as nmi_score
from sklearn.metrics import adjusted_rand_score as ari_score
from sklearn.decomposition import PCA
from data.omniglotloader import omniglot_alphabet_func, omniglot_evaluation_alphabets_mapping
from utils.util import cluster_acc, Identity, AverageMeter, seed_torch, str2bool
from utils import ramps
from models.vgg import VGG
from modules.module import feat2prob, target_distribution
import warnings
import os

warnings.filterwarnings("ignore", category=UserWarning)


def init_prob_kmeans(model, eval_loader, args):
    torch.manual_seed(1)
    model = model.to(device)
    # initialize cluster parameters from the frozen feature extractor
    model.eval()
    targets = np.zeros(len(eval_loader.dataset))
    feats = np.zeros((len(eval_loader.dataset), 1024))  # 1024-d features from the VGG backbone
    for _, (x, _, label, idx) in enumerate(eval_loader):
        x = x.to(device)
        _, feat = model(x)
        feat = feat.view(x.size(0), -1)
        idx = idx.data.cpu().numpy()
        feats[idx, :] = feat.data.cpu().numpy()
        targets[idx] = label.data.cpu().numpy()
    # reduce features to n_clusters dims with PCA, then run k-means for initial centers
    pca = PCA(n_components=args.n_clusters)
    feats = pca.fit_transform(feats)
    kmeans = KMeans(n_clusters=args.n_clusters, n_init=20)
    y_pred = kmeans.fit_predict(feats)
    # evaluate initial clustering performance
    acc, nmi, ari = cluster_acc(targets, y_pred), nmi_score(targets, y_pred), ari_score(targets, y_pred)
    print('Init acc {:.4f}, nmi {:.4f}, ari {:.4f}'.format(acc, nmi, ari))
    probs = feat2prob(torch.from_numpy(feats), torch.from_numpy(kmeans.cluster_centers_))
    return kmeans.cluster_centers_, probs


def warmup_train(model, alphabetStr, train_loader, eval_loader, args):
    optimizer = Adam(model.parameters(), lr=args.warmup_lr)
    for epoch in range(args.warmup_epochs):
        loss_record = AverageMeter()
        model.train()
        for batch_idx, (x, g_x, _, idx) in enumerate(train_loader):
            _, feat = model(x.to(device))
            prob = feat2prob(feat, model.center)
            loss = F.kl_div(prob.log(), args.p_targets[idx].float().to(device))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss_record.update(loss.item(), x.size(0))
        print('Warmup Train Epoch: {} Avg Loss: {:.4f}'.format(epoch, loss_record.avg))
        test(model, eval_loader, args)


def Baseline_train(model, alphabetStr, train_loader, eval_loader, args):
    optimizer = Adam(model.parameters(), lr=args.lr)
    for epoch in range(args.epochs):
        loss_record = AverageMeter()
        model.train()
        for batch_idx, (x, g_x, _, idx) in enumerate(train_loader):
            _, feat = model(x.to(device))
            prob = feat2prob(feat, model.center)
            loss = F.kl_div(prob.log(), args.p_targets[idx].float().to(device))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss_record.update(loss.item(), x.size(0))
        print('Train Epoch: {} Avg Loss: {:.4f}'.format(epoch, loss_record.avg))
        _, _, _, probs = test(model, eval_loader, args)
        if epoch % args.update_interval == 0:
            args.p_targets = target_distribution(probs)
    torch.save(model.state_dict(), args.model_dir)
    print("model saved to {}.".format(args.model_dir))


def PI_train(model, alphabetStr, train_loader, eval_loader, args):
    optimizer = Adam(model.parameters(), lr=args.lr)
    for epoch in range(args.epochs):
        loss_record = AverageMeter()
        model.train()
        # ramp up the consistency weight over the first rampup_length epochs
        w = args.rampup_coefficient * ramps.sigmoid_rampup(epoch, args.rampup_length)
        for batch_idx, (x, g_x, _, idx) in enumerate(train_loader):
            _, feat = model(x.to(device))
            _, feat_g = model(g_x.to(device))
            prob = feat2prob(feat, model.center)
            prob_g = feat2prob(feat_g, model.center)
            loss = F.kl_div(prob.log(), args.p_targets[idx].float().to(device))
            mse_loss = F.mse_loss(prob, prob_g)  # consistency between the two augmented views
            loss = loss + w * mse_loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss_record.update(loss.item(), x.size(0))
        print('Train Epoch: {} Avg Loss: {:.4f}'.format(epoch, loss_record.avg))
        _, _, _, probs = test(model, eval_loader, args)
        if epoch % args.update_interval == 0:
            args.p_targets = target_distribution(probs)
    torch.save(model.state_dict(), args.model_dir)
    print("model saved to {}.".format(args.model_dir))


def TE_train(model, alphabetStr, train_loader, eval_loader, args):
    optimizer = Adam(model.parameters(), lr=args.lr)
    alpha = 0.6
    ntrain = len(train_loader.dataset)
    Z = torch.zeros(ntrain, args.n_clusters).float().to(device)        # intermediate values
    z_ema = torch.zeros(ntrain, args.n_clusters).float().to(device)    # temporal outputs
    z_epoch = torch.zeros(ntrain, args.n_clusters).float().to(device)  # current outputs
    for epoch in range(args.epochs):
        loss_record = AverageMeter()
        model.train()
        w = args.rampup_coefficient * ramps.sigmoid_rampup(epoch, args.rampup_length)
        for batch_idx, (x, _, _, idx) in enumerate(train_loader):
            _, feat = model(x.to(device))
            prob = feat2prob(feat, model.center)
            # detach so the ensembling buffers stay out of the autograd graph
            z_epoch[idx, :] = prob.detach()
            prob_bar = z_ema[idx, :].detach()
            loss = F.kl_div(prob.log(), args.p_targets[idx].float().to(device))
            mse_loss = F.mse_loss(prob, prob_bar)
            loss = loss + w * mse_loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss_record.update(loss.item(), x.size(0))
        # exponential moving average of per-sample predictions, with bias correction
        Z = alpha * Z + (1. - alpha) * z_epoch
        z_ema = Z * (1. / (1. - alpha ** (epoch + 1)))
        print('Train Epoch: {} Avg Loss: {:.4f}'.format(epoch, loss_record.avg))
        _, _, _, probs = test(model, eval_loader, args)
        if epoch % args.update_interval == 0:
            args.p_targets = target_distribution(probs)
    torch.save(model.state_dict(), args.model_dir)
    print("model saved to {}.".format(args.model_dir))


def TEP_train(model, alphabetStr, train_loader, eval_loader, args):
    optimizer = Adam(model.parameters(), lr=args.lr)
    alpha = 0.6
    ntrain = len(train_loader.dataset)
    Z = torch.zeros(ntrain, args.n_clusters).float().to(device)        # intermediate values
    z_epoch = torch.zeros(ntrain, args.n_clusters).float().to(device)  # current outputs
    for epoch in range(args.epochs):
        loss_record = AverageMeter()
        model.train()
        for batch_idx, (x, g_x, _, idx) in enumerate(train_loader):
            _, feat = model(x.to(device))
            prob = feat2prob(feat, model.center)
            loss = F.kl_div(prob.log(), args.p_targets[idx].float().to(device))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss_record.update(loss.item(), x.size(0))
        print('Train Epoch: {} Avg Loss: {:.4f}'.format(epoch, loss_record.avg))
        _, _, _, probs = test(model, eval_loader, args)
        # ensemble the evaluation-set predictions rather than per-batch outputs
        z_epoch = probs.float().to(device)
        Z = alpha * Z + (1. - alpha) * z_epoch
        z_bars = Z * (1. / (1. - alpha ** (epoch + 1)))  # bias-corrected ensemble predictions
        if epoch % args.update_interval == 0:
            args.p_targets = target_distribution(z_bars).float().to(device)
    torch.save(model.state_dict(), args.model_dir)
    print("model saved to {}.".format(args.model_dir))


def test(model, eval_loader, args):
    model.eval()
    targets = np.zeros(len(eval_loader.dataset))
    y_pred = np.zeros(len(eval_loader.dataset))
    probs = np.zeros((len(eval_loader.dataset), args.n_clusters))
    for _, (x, _, label, idx) in enumerate(eval_loader):
        x = x.to(device)
        _, feat = model(x)
        prob = feat2prob(feat, model.center)
        idx = idx.data.cpu().numpy()
        y_pred[idx] = prob.detach().cpu().numpy().argmax(1)
        targets[idx] = label.data.cpu().numpy()
        probs[idx, :] = prob.detach().cpu().numpy()
    # evaluate clustering performance
    y_pred = y_pred.astype(np.int64)
    acc, nmi, ari = cluster_acc(targets, y_pred), nmi_score(targets, y_pred), ari_score(targets, y_pred)
    print('Test acc {:.4f}, nmi {:.4f}, ari {:.4f}'.format(acc, nmi, ari))
    probs = torch.from_numpy(probs)
    return acc, nmi, ari, probs


if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser(
        description='cluster',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--warmup_lr', type=float, default=0.001)
    parser.add_argument('--warmup_epochs', default=10, type=int)
    parser.add_argument('--epochs', default=90, type=int)
    parser.add_argument('--batch_size', default=100, type=int)
    parser.add_argument('--update_interval', default=1, type=int)
    parser.add_argument('--seed', default=1, type=int)
    parser.add_argument('--save_txt', default=False, type=str2bool, help='save txt or not', metavar='BOOL')
    parser.add_argument('--rampup_length', default=5, type=int)
    parser.add_argument('--rampup_coefficient', default=100.0, type=float)
    parser.add_argument('--n_clusters', default=10, type=int)
    parser.add_argument('--num_workers', default=2, type=int)
    parser.add_argument('--pretrain_dir', type=str, default='./data/experiments/pretrained/vgg6_omniglot_proto.pth')
    parser.add_argument('--dataset_root', type=str, default='./data/datasets')
    parser.add_argument('--exp_root', type=str, default='./data/experiments/')
    parser.add_argument('--subfolder_name', type=str, default='run')
    parser.add_argument('--save_txt_name', type=str, default='result.txt')
    parser.add_argument('--DTC', type=str, default='PI')
    args = parser.parse_args()
    args.cuda = torch.cuda.is_available()
    print("use cuda: {}".format(args.cuda))
    device = torch.device("cuda" if args.cuda else "cpu")
    seed_torch(args.seed)

    # feature extractor pretrained on the Omniglot background set
    model = VGG(n_layer='4+2', in_channels=1).to(device)
    model.load_state_dict(torch.load(args.pretrain_dir), strict=False)
    model.last = Identity()
    init_feat_extractor = model

    acc = {}
    nmi = {}
    ari = {}
    for _, alphabetStr in omniglot_evaluation_alphabets_mapping.items():
        runner_name = os.path.basename(__file__).split(".")[0]
        model_dir = args.exp_root + '{}/{}/{}'.format(runner_name, args.DTC, args.subfolder_name)
        if not os.path.exists(model_dir):
            os.makedirs(model_dir)
        args.model_dir = model_dir + '/' + 'vgg6_{}.pth'.format(alphabetStr)
        args.save_txt_path = args.exp_root + '{}/{}/{}'.format(runner_name, args.DTC, args.save_txt_name)
        train_Dloader, eval_Dloader = omniglot_alphabet_func(alphabet=alphabetStr, background=False, root=args.dataset_root)(batch_size=args.batch_size, num_workers=args.num_workers)
        args.n_clusters = train_Dloader.num_classes
        model = VGG(n_layer='4+2', out_dim=args.n_clusters, in_channels=1).to(device)
        model.load_state_dict(torch.load(args.pretrain_dir), strict=False)
        model.center = Parameter(torch.Tensor(args.n_clusters, args.n_clusters))
        init_centers, init_probs = init_prob_kmeans(init_feat_extractor, eval_Dloader, args)
        args.p_targets = target_distribution(init_probs)
        model.center.data = torch.tensor(init_centers).float().to(device)
        warmup_train(model, alphabetStr, train_Dloader, eval_Dloader, args)
        if args.DTC == 'Baseline':
            Baseline_train(model, alphabetStr, train_Dloader, eval_Dloader, args)
        elif args.DTC == 'PI':
            PI_train(model, alphabetStr, train_Dloader, eval_Dloader, args)
        elif args.DTC == 'TE':
            TE_train(model, alphabetStr, train_Dloader, eval_Dloader, args)
        elif args.DTC == 'TEP':
            TEP_train(model, alphabetStr, train_Dloader, eval_Dloader, args)
        acc[alphabetStr], nmi[alphabetStr], ari[alphabetStr], _ = test(model, eval_Dloader, args)
    print('ACC for all alphabets:', acc)
    print('NMI for all alphabets:', nmi)
    print('ARI for all alphabets:', ari)
    avg_acc = sum(acc.values()) / float(len(acc))
    avg_nmi = sum(nmi.values()) / float(len(nmi))
    avg_ari = sum(ari.values()) / float(len(ari))
    print('avg ACC {:.4f}, NMI {:.4f} ARI {:.4f}'.format(avg_acc, avg_nmi, avg_ari))
    if args.save_txt:
        with open(args.save_txt_path, 'a') as f:
            f.write("{:.4f}, {:.4f}, {:.4f}\n".format(avg_acc, avg_nmi, avg_ari))
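
# ---------------------------------------------------------------------------
# Reference notes (assumptions, not executed as part of the training script):
#
# feat2prob and target_distribution are imported from modules.module. Assuming
# they follow the standard DEC-style formulation this script builds on, they
# compute, roughly:
#
#   def feat2prob(feat, center, alpha=1.0):
#       # Student's t soft assignment q_ij between sample i and center j
#       q = 1.0 / (1.0 + torch.sum((feat.unsqueeze(1) - center) ** 2, dim=2) / alpha)
#       q = q ** ((alpha + 1.0) / 2.0)
#       return (q.t() / torch.sum(q, dim=1)).t()
#
#   def target_distribution(q):
#       # sharpened target p_ij = (q_ij^2 / f_j) / sum_j' (q_ij'^2 / f_j'),
#       # where f_j = sum_i q_ij is the soft cluster frequency
#       weight = q ** 2 / q.sum(0)
#       return (weight.t() / weight.sum(1)).t()
#
# Example invocations (assuming this file is saved as omniglot_DTC.py; the
# flag names match the argparse definitions above):
#   python omniglot_DTC.py --DTC Baseline
#   python omniglot_DTC.py --DTC PI --rampup_coefficient 100.0 --rampup_length 5
#   python omniglot_DTC.py --DTC TEP --save_txt true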