from __future__ import absolute_import

import os
from enum import Enum

import torch
from scipy.io import loadmat
from torch.utils import data


class Tree(object):
    """Binary part hierarchy rebuilt from a postfix sequence of node ops.

    The ops are replayed on a stack:
    - BOX pushes a leaf holding the next box row.
    - ADJ pops two nodes and pushes their parent.
    - SYM pops one node and pushes a symmetry node holding the next
      symmetry-parameter row.
    """

    class NodeType(Enum):
        BOX = 0  # box node
        ADJ = 1  # adjacency (adjacent part assembly) node
        SYM = 2  # symmetry (symmetric part grouping) node

    class Node(object):
        """One node of a :class:`Tree`.

        ``node_type`` must be given. Its enum value is stored in ``label``,
        so constructing a node with ``node_type=None`` fails.
        """

        def __init__(self, box=None, left=None, right=None, node_type=None, sym=None):
            self.box = box  # box feature row for a BOX (leaf) node
            self.sym = sym  # symmetry parameter row for a SYM node
            self.left = left  # left child for ADJ, or the symmetry generator for SYM
            self.right = right  # right child (set only for ADJ nodes)
            self.node_type = node_type
            self.label = torch.LongTensor([self.node_type.value])

        def is_leaf(self):
            """Return True for a BOX node that actually carries a box."""
            return self.node_type == Tree.NodeType.BOX and self.box is not None

        def is_adj(self):
            """Return True for an adjacency (two-child) node."""
            return self.node_type == Tree.NodeType.ADJ

        def is_sym(self):
            """Return True for a symmetry node."""
            return self.node_type == Tree.NodeType.SYM

    def __init__(self, boxes, ops, syms):
        """Build the tree by replaying ``ops`` on a stack.

        Args:
            boxes: tensor split along dim 0 into rows. BOX ops consume the
                rows in order.
            ops: 2-D tensor whose row 0 holds the op codes, read left to
                right.
            syms: tensor split along dim 0 into rows. SYM ops consume the
                rows in order.

        Raises:
            ValueError: if the ops do not reduce to exactly one root node.
        """
        # Reverse so that list.pop() (cheap at the tail) yields rows in
        # their original order.
        box_list = list(torch.split(boxes, 1, 0))
        sym_param = list(torch.split(syms, 1, 0))
        box_list.reverse()
        sym_param.reverse()

        stack = []
        for pos in range(ops.size()[1]):
            code = ops[0, pos]
            if code == Tree.NodeType.BOX.value:
                stack.append(Tree.Node(box=box_list.pop(), node_type=Tree.NodeType.BOX))
            elif code == Tree.NodeType.ADJ.value:
                # The most recently built node becomes the left child.
                left_node = stack.pop()
                right_node = stack.pop()
                stack.append(Tree.Node(left=left_node, right=right_node,
                                       node_type=Tree.NodeType.ADJ))
            elif code == Tree.NodeType.SYM.value:
                node = stack.pop()
                stack.append(Tree.Node(left=node, sym=sym_param.pop(),
                                       node_type=Tree.NodeType.SYM))
            # Any other code is skipped on purpose.
            # NOTE(review): presumably padding in the batched .mat data; confirm.

        if len(stack) != 1:
            raise ValueError(
                'ops must reduce to exactly one root node, got %d' % len(stack))
        self.root = stack[0]


class GRASSDataset(data.Dataset):
    """Dataset of :class:`Tree` objects loaded from three .mat files.

    ``dir`` must contain:
    - ``box_data.mat`` with key ``boxes``
    - ``op_data.mat`` with key ``ops``
    - ``sym_data.mat`` with key ``syms``

    The number of examples is the number of columns of the ops matrix. Each
    matrix is chunked column-wise into that many pieces. Each piece is
    transposed and handed to :class:`Tree`.

    NOTE(review): this presumes every example occupies the same number of
    columns in the box and sym matrices; confirm against the data export.
    """

    def __init__(self, dir, transform=None):
        self.dir = dir
        # Optional callable applied to each tree in __getitem__.
        self.transform = transform
        self.trees = []

        box_data = torch.from_numpy(
            loadmat(os.path.join(self.dir, u'box_data.mat'))[u'boxes']).float()
        op_data = torch.from_numpy(
            loadmat(os.path.join(self.dir, u'op_data.mat'))[u'ops']).int()
        sym_data = torch.from_numpy(
            loadmat(os.path.join(self.dir, u'sym_data.mat'))[u'syms']).float()

        num_examples = op_data.size()[1]
        if num_examples == 0:
            # torch.chunk rejects 0 chunks; an empty ops matrix is an empty dataset.
            return

        box_chunks = torch.chunk(box_data, num_examples, 1)
        op_chunks = torch.chunk(op_data, num_examples, 1)
        sym_chunks = torch.chunk(sym_data, num_examples, 1)
        # torch.chunk may return fewer pieces than requested when the column
        # count does not split evenly; fail clearly instead of misaligning.
        if len(box_chunks) != num_examples or len(sym_chunks) != num_examples:
            raise ValueError(
                'box/sym data split into %d/%d chunks, expected %d examples'
                % (len(box_chunks), len(sym_chunks), num_examples))

        self.trees = [
            Tree(torch.t(boxes), torch.t(ops), torch.t(syms))
            for boxes, ops, syms in zip(box_chunks, op_chunks, sym_chunks)
        ]

    def __getitem__(self, index):
        """Return the tree at ``index``, passed through ``transform`` if set."""
        tree = self.trees[index]
        if self.transform is not None:
            tree = self.transform(tree)
        return tree

    def __len__(self):
        return len(self.trees)