#!/usr/bin/env python # -*- coding: utf-8 -*- # Copyright (c) 2017 Hiroaki Santo from __future__ import absolute_import from __future__ import division from __future__ import print_function import glob import os import cv2 import numpy as np import tqdm class DpsnDataset(object): def __init__(self, dataset_path="./dataset/", name="blobs_merl"): self.dataset_path = dataset_path self.data_list = glob.glob(os.path.join(self.dataset_path, "*/", "*/")) self.name = name tmp_list = glob.glob(os.path.join(self.data_list[0], "[0-9]*.png")) self.light_num = len(tmp_list) self.img_size = cv2.imread(tmp_list[0])[:, :, 0].shape print("light_num: {}".format(self.light_num)) print("image_size: {}".format(self.img_size)) print("data_num: {}".format(len(self))) np.random.seed(1) self.random_indices = np.random.permutation(len(self)) def data_path2name(self, path): dir_path, _ = os.path.split(path) dir_path, brdf_name = os.path.split(dir_path) _, obj_name = os.path.split(dir_path) return obj_name, brdf_name def __load_normal_png(self, n_path): n_img = cv2.imread(n_path)[:, :, ::-1] m, n, _ = n_img.shape N = n_img.reshape(-1, 3).T N = N.astype(np.float32) / 255. * 2. - 1. for i in range(m * n): norm = np.linalg.norm(N[:, i]) if norm != 0: N[:, i] /= norm mask = np.ones(shape=(m * n)) n_img = n_img.reshape(-1, 3).T for i in range(m * n): if np.linalg.norm(n_img[:, i]) == 0: mask[i] = 0 N[:, mask == 0] = 0 return N, m, n, mask def __len__(self): return len(self.data_list) def load_data(self, root_path): def_png_path = os.path.join(root_path, "{light_index}.png") m, n = self.img_size M = np.zeros(shape=(m * n, self.light_num, 3), dtype=np.float32) for l in range(self.light_num): m_img = cv2.imread(def_png_path.format(light_index=l), cv2.IMREAD_UNCHANGED)[:, :, ::-1] # m_img = cv2.imread(def_png_path.format(light_index=l))[:, :, ::-1] M[:, l, :] = m_img.reshape(-1, 3) obj_name, brdf_name = self.data_path2name(root_path + "/") N, m, n, mask = self.__load_normal_png(os.path.join(self.dataset_path, obj_name, "{}.png".format(obj_name))) return M, N, mask def get_batch(self, index, image_num): """ :param index: :param image_num: This does not mean the number of batch data. Number of pixels in image_num images becomes number of data. :return: """ indices = np.arange(index, index + image_num) indices %= len(self) indices = self.random_indices[indices] batch_normal = [] batch_mess = [] for i in indices: M, N, mask = self.load_data(self.data_list[i]) M = M.astype(float) / np.max(M) for p in range(N.shape[1]): if mask[p] == 0: continue if np.min(np.linalg.norm(M[p, :, :], axis=0)) == 0: continue for color in range(3): batch_normal.append(N[:, p]) batch_mess.append(M[p, :, color]) return np.array(batch_normal, dtype=np.float32), np.array(batch_mess, dtype=np.float32) def save_as_tfrecord(self): print("[*] save_as_tfrecord()") import tensorflow as tf tfwriter = tf.python_io.TFRecordWriter( os.path.join(self.dataset_path, "{}_{}.tfrecord".format(type(self).__name__, self.name))) try: for i in tqdm.tqdm(range(0, len(self), 30)): normal, mess = self.get_batch(index=i, image_num=30) print("serialize data: {}, {}".format(normal.shape, mess.shape)) for j in np.random.permutation(len(normal)): n_ = normal[j, :].astype(np.float32) m_ = mess[j, :].astype(np.float32) record = tf.train.Example(features=tf.train.Features(feature={ 'normal': tf.train.Feature(float_list=tf.train.FloatList(value=n_.reshape(-1).tolist())), 'mess': tf.train.Feature( float_list=tf.train.FloatList(value=m_.reshape(-1).tolist())), })) tfwriter.write(record.SerializeToString()) finally: tfwriter.close() def load_from_tfrecord(self): import tensorflow as tf data_path = os.path.join(self.dataset_path, "{}_{}.tfrecord".format(type(self).__name__, self.name)) assert os.path.exists(data_path), data_path filename_queue = tf.train.string_input_producer([data_path], num_epochs=None) reader = tf.TFRecordReader() _, serialized_example = reader.read(filename_queue) features = tf.parse_single_example( serialized_example, features={ 'normal': tf.FixedLenFeature([3], tf.float32), 'mess': tf.FixedLenFeature([self.light_num], tf.float32), }) return features["normal"], features["mess"] if __name__ == '__main__': import params dataset = DpsnDataset(dataset_path=os.path.join(params.DATASET_PATH, "train")) dataset.save_as_tfrecord()