import torch
import numpy as np
from sklearn.metrics.cluster import normalized_mutual_info_score as nmi_score
from sklearn.metrics import adjusted_rand_score as ari_score
from sklearn.metrics import silhouette_score
from sklearn.cluster import KMeans
from utils.faster_mix_k_means_pytorch import K_Means
from utils.util import cluster_acc, Identity, seed_torch
from data.cifarloader import CIFAR100Loader
from models.vgg import VGG
from tqdm import tqdm
from collections import Counter
import random
import math
import warnings

warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)


def estimate_k(model, unlabeled_loader, labeled_loader, args):
    # Extract 512-d features for the unlabeled data, written into the feature
    # matrix by sample index so it stays aligned with the dataset order.
    # Note: relies on the module-level `device` defined in the __main__ block.
    u_num = len(unlabeled_loader.dataset)
    u_targets = np.zeros(u_num)
    u_feats = np.zeros((u_num, 512))
    print('extracting features for unlabeled data')
    for _, (x, label, idx) in enumerate(tqdm(unlabeled_loader)):
        x = x.to(device)
        _, feat = model(x)
        feat = feat.view(x.size(0), -1)
        idx = idx.data.cpu().numpy()
        u_feats[idx, :] = feat.data.cpu().numpy()
        u_targets[idx] = label.data.cpu().numpy()

    cand_k = np.arange(args.max_cand_k)

    # Extract features for the labeled data, used to score each candidate k.
    l_num = len(labeled_loader.dataset)
    l_targets = np.zeros(l_num)
    l_feats = np.zeros((l_num, 512))
    print('extracting features for labeled data')
    for _, (x, label, idx) in enumerate(tqdm(labeled_loader)):
        x = x.to(device)
        _, feat = model(x)
        feat = feat.view(x.size(0), -1)
        idx = idx.data.cpu().numpy()
        l_feats[idx, :] = feat.data.cpu().numpy()
        l_targets[idx] = label.data.cpu().numpy()

    # Split the labeled classes into an anchor set (lt), whose labels stay
    # fixed during constrained k-means, and a probe set (lv), which is treated
    # as unlabeled and used to measure clustering accuracy per candidate k.
    l_classes = set(l_targets)
    num_lt_cls = int(round(len(l_classes) * args.split_ratio))
    # Randomly sample a split_ratio fraction of the labeled classes as anchors
    # (sorted first: random.sample no longer accepts sets as of Python 3.11).
    lt_classes = set(random.sample(sorted(l_classes), num_lt_cls))
    lv_classes = l_classes - lt_classes

    lt_feats = np.empty((0, l_feats.shape[1]))
    lt_targets = np.empty(0)
    for c in lt_classes:
        lt_feats = np.vstack((lt_feats, l_feats[l_targets == c]))
        lt_targets = np.append(lt_targets, l_targets[l_targets == c])

    lv_feats = np.empty((0, l_feats.shape[1]))
    lv_targets = np.empty(0)
    for c in lv_classes:
        lv_feats = np.vstack((lv_feats, l_feats[l_targets == c]))
        lv_targets = np.append(lv_targets, l_targets[l_targets == c])

    # For each candidate k, run constrained k-means on the union of probe and
    # unlabeled features, recording (i) the silhouette score on the unlabeled
    # part and (ii) the clustering accuracy on the held-out probe classes.
    cvi_list = np.zeros(len(cand_k))
    acc_list = np.zeros(len(cand_k))
    cat_pred_list = np.zeros([len(cand_k), u_num + l_num])
    print('estimating K ...')
    for i in range(len(cand_k)):
        cvi_list[i], cat_pred_i = labeled_val_fun(np.concatenate((lv_feats, u_feats)), lt_feats, lt_targets, cand_k[i] + args.num_val_cls)
        cat_pred_list[i, :] = cat_pred_i
        acc_list[i] = cluster_acc(lv_targets, cat_pred_i[len(lt_targets): len(lt_targets) + len(lv_targets)])
        best_k = get_best_k(cvi_list[:i + 1], acc_list[:i + 1], cat_pred_list[:i + 1], l_num)
        print('current best K {}'.format(best_k))

    # Evaluate the chosen k with plain k-means on the unlabeled features.
    kmeans = KMeans(n_clusters=int(best_k))
    u_pred = kmeans.fit_predict(u_feats).astype(np.int32)
    acc, nmi, ari = cluster_acc(u_targets, u_pred), nmi_score(u_targets, u_pred), ari_score(u_targets, u_pred)
    print('Final K {}, acc {:.4f}, nmi {:.4f}, ari {:.4f}'.format(best_k, acc, nmi, ari))
    return best_k


def labeled_val_fun(u_feats, l_feats, l_targets, k):
    # Run constrained k-means with k clusters: the labeled (anchor) points keep
    # their ground-truth assignments while the unlabeled points are clustered.
    # Returns the silhouette score of the unlabeled partition together with
    # the predicted labels for all points (labeled first, then unlabeled).
    if device.type == 'cuda':
        torch.cuda.empty_cache()
    l_num = len(l_targets)
    kmeans = K_Means(k, pairwise_batch_size=256)
    kmeans.fit_mix(torch.from_numpy(u_feats).to(device), torch.from_numpy(l_feats).to(device), torch.from_numpy(l_targets).to(device))
    cat_pred = kmeans.labels_.cpu().numpy()
    u_pred = cat_pred[l_num:]
    silh_score = silhouette_score(u_feats, u_pred)
    return silh_score, cat_pred
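# ----------------------------------------------------------------------------
# For reference only: a minimal NumPy sketch of the constrained k-means idea
# that fit_mix implements above. This is NOT the actual
# utils.faster_mix_k_means_pytorch implementation (which batches pairwise
# distances on the GPU, see pairwise_batch_size=256); the function name and
# initialisation scheme here are illustrative assumptions. Labeled points keep
# their ground-truth assignments on every iteration, anchoring one centroid
# per labeled class, while unlabeled points are re-assigned to the nearest
# centroid. The returned label vector is ordered labeled-first, matching how
# labeled_val_fun slices cat_pred. Assumes k >= number of labeled classes.
def _constrained_kmeans_sketch(u_feats, l_feats, l_targets, k, n_iter=20):
    classes = np.unique(l_targets)
    d = u_feats.shape[1]
    centers = np.zeros((k, d))
    # Labeled class means initialise the first centroids; the remaining
    # centroids start from randomly chosen unlabeled points.
    for j, c in enumerate(classes):
        centers[j] = l_feats[l_targets == c].mean(axis=0)
    extra = np.random.choice(len(u_feats), k - len(classes), replace=False)
    centers[len(classes):] = u_feats[extra]
    # Labeled assignments are fixed for all iterations.
    l_assign = np.searchsorted(classes, l_targets)
    for _ in range(n_iter):
        # Assign each unlabeled point to its nearest centroid.
        dist = np.linalg.norm(u_feats[:, None, :] - centers[None, :, :], axis=2)
        u_assign = dist.argmin(axis=1)
        # Update every centroid from all of its members, labeled and unlabeled.
        all_feats = np.concatenate((l_feats, u_feats))
        all_assign = np.concatenate((l_assign, u_assign))
        for j in range(k):
            members = all_feats[all_assign == j]
            if len(members) > 0:
                centers[j] = members.mean(axis=0)
    return np.concatenate((l_assign, u_assign))
# ----------------------------------------------------------------------------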
def get_best_k(cvi_list, acc_list, cat_pred_list, l_num):
    # Combine the two criteria: take the (largest) index with the best
    # silhouette score and the (largest) index with the best probe-class
    # accuracy, then settle on the rounded-up midpoint between them.
    idx_cvi = np.max(np.argwhere(cvi_list == np.max(cvi_list)))
    idx_acc = np.max(np.argwhere(acc_list == np.max(acc_list)))
    idx_best = int(math.ceil((idx_cvi + idx_acc) * 1.0 / 2))
    cat_pred = cat_pred_list[idx_best, :]
    cnt_cat = Counter(cat_pred.tolist())
    cnt_l = Counter(cat_pred[:l_num].tolist())
    cnt_ul = Counter(cat_pred[l_num:].tolist())
    bin_cat = [x[1] for x in sorted(cnt_cat.items())]  # overall cluster sizes (unused)
    bin_l = [x[1] for x in sorted(cnt_l.items())]      # labeled cluster sizes (unused)
    bin_ul = [x[1] for x in sorted(cnt_ul.items())]    # unlabeled cluster sizes
    # Count only clusters whose unlabeled size exceeds min_max_ratio of the
    # largest cluster, pruning near-empty clusters from the estimate.
    # Note: args is read from module scope, set in the __main__ block below.
    best_k = np.sum(np.array(bin_ul) / np.max(bin_ul).astype(float) > args.min_max_ratio)
    return best_k


if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser(
        description='cluster',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--batch_size', default=128, type=int)
    parser.add_argument('--n_clusters', default=10, type=int)
    parser.add_argument('--num_val_cls', default=10, type=int)
    parser.add_argument('--max_cand_k', default=100, type=int)
    parser.add_argument('--split_ratio', type=float, default=0.6)
    parser.add_argument('--min_max_ratio', type=float, default=0.01)
    parser.add_argument('--pretrain_dir', type=str, default='./data/experiments/pretrained/vgg6_cifar100_classif_80.pth')
    parser.add_argument('--dataset_root', type=str, default='./data/datasets/CIFAR/')
    parser.add_argument('--seed', default=1, type=int)
    args = parser.parse_args()
    args.cuda = torch.cuda.is_available()
    device = torch.device("cuda" if args.cuda else "cpu")
    seed_torch(args.seed)

    # Load the classifier pretrained on the 80 labeled CIFAR-100 classes and
    # replace its last layer with an identity so it outputs features, not logits.
    model = VGG(n_layer='5+1', out_dim=80).to(device)
    model.load_state_dict(torch.load(args.pretrain_dir), strict=False)
    model.last = Identity()

    val_loader = CIFAR100Loader(root=args.dataset_root, batch_size=args.batch_size, split='train', labeled=True, aug=None, shuffle=True, mode='probe')
    eval_loader = CIFAR100Loader(root=args.dataset_root, batch_size=args.batch_size, split='train', labeled=False, aug=None, shuffle=False)

    args.n_clusters = estimate_k(model, eval_loader, val_loader, args)
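# Example invocation (assuming this script is saved as estimate_k.py; the
# filename is not fixed by the code above). With the defaults, the search
# probes candidate cluster counts from num_val_cls up to
# num_val_cls + max_cand_k - 1:
#
#   python estimate_k.py --max_cand_k 100 --split_ratio 0.6 --seed 1
#
# The script prints the running best K after each candidate and finishes with
# the final estimate plus the acc/nmi/ari of plain k-means at that K on the
# unlabeled split.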