import json
import os
import os.path as osp
import warnings

import numpy as np
from torch.utils import data

from utils.cluster import compute_kmedoids
from .wider_face import WIDERFace


def get_dataloader(datapath, args, num_templates=25, template_file="templates.json",
                   img_transforms=None, train=True, split="train"):
    """Build a WIDERFace DataLoader together with its template boxes.

    Templates are cached as JSON at ``datasets/<template_file>``, a path
    relative to the current working directory. If that file does not exist,
    the templates are computed with ``compute_kmedoids`` over every bounding
    box returned by ``WIDERFace(args.traindata, []).get_all_bboxes()``. They
    are then written to the cache file for later runs.

    Args:
        datapath: Path of the dataset to load; ``~`` is expanded.
        args: Namespace read for ``traindata``, ``dataset_root``, ``debug``,
            ``batch_size`` and ``workers``.
        num_templates: Number of medoids requested when computing templates.
        template_file: File name of the template cache inside ``datasets/``.
        img_transforms: Forwarded to ``WIDERFace``.
        train: Forwarded to ``WIDERFace``; also enables shuffling.
        split: Forwarded to ``WIDERFace``.

    Returns:
        Tuple ``(data_loader, templates)``. ``templates`` is a NumPy array
        rounded to 8 decimal places. The same array is passed to the
        ``WIDERFace`` dataset.
    """
    template_file = osp.join("datasets", template_file)

    if osp.exists(template_file):
        with open(template_file) as f:
            templates = json.load(f)
        # The cache is keyed only by file name, so a stale file from a run
        # with a different num_templates would otherwise be used silently.
        if len(templates) != num_templates:
            warnings.warn(
                "Template cache {} holds {} templates but num_templates={}; "
                "using the cached templates. Delete the file to recompute."
                .format(template_file, len(templates), num_templates))
    else:
        # Cluster the bounding boxes to get the templates
        dataset = WIDERFace(osp.expanduser(args.traindata), [])
        clustering = compute_kmedoids(dataset.get_all_bboxes(), 1,
                                      indices=num_templates,
                                      option='pyclustering',
                                      max_clusters=num_templates)

        print("Canonical bounding boxes computed")
        # NOTE(review): result appears to be indexed by cluster count, with a
        # 'medoids' array per entry — confirm against utils.cluster.
        templates = clustering[num_templates]['medoids'].tolist()

        # Record templates; make sure the cache directory exists and the file
        # is flushed and closed deterministically.
        template_dir = osp.dirname(template_file)
        if template_dir:
            os.makedirs(template_dir, exist_ok=True)
        with open(template_file, "w") as f:
            json.dump(templates, f)

    # np.round_ was removed in NumPy 2.0; np.round is the supported spelling.
    templates = np.round(np.array(templates), decimals=8)

    dataset = WIDERFace(osp.expanduser(datapath), templates,
                        train=train, split=split, img_transforms=img_transforms,
                        dataset_root=osp.expanduser(args.dataset_root),
                        debug=args.debug)
    data_loader = data.DataLoader(dataset, batch_size=args.batch_size,
                                  shuffle=train, num_workers=args.workers,
                                  pin_memory=True)

    return data_loader, templates