# -*- coding: utf-8 -*- """ File file_util.py @author:ZhengYuwei """ import os import logging import functools import tensorflow as tf from dataset.dataset_util import DatasetUtil class FileUtil(object): """ 从标签文件中,构造返回(image, label)的tf.data.Dataset数据集 标签文件内容如下: image_name label0,label1,label2,... """ @staticmethod def _parse_string_line(string_line, root_path): """ 解析文本中的一行字符串行,得到图片路径(拼接图片根目录)和标签 :param string_line: 文本中的一行字符串,image_name label0 label1 label2 label3 ... :param root_path: 图片根目录 :return: DatasetV1Adapter<(图片路径Tensor(shape=(), dtype=string),标签Tensor(shape=(?,), dtype=float32))> """ strings = tf.string_split([string_line], delimiter=' ').values image_path = tf.string_join([root_path, strings[0]], separator=os.sep) labels = tf.string_to_number(strings[1:]) return image_path, labels @staticmethod def _parse_image(image_path, _, image_size): """ 根据图片路径和标签,读取图片 :param image_path: 图片路径, Tensor(shape=(), dtype=string) :param _: 标签Tensor(shape(?,), dtype=float32)),本函数只产生图像dataset,故不需要 :param image_size: 图像需要resize到的大小 :return: 归一化的图片 Tensor(shape=(48, 144, ?), dtype=float32) """ # 图片 image = tf.read_file(image_path) image = tf.image.decode_jpeg(image) image = tf.image.resize_images(image, image_size, method=tf.image.ResizeMethod.NEAREST_NEIGHBOR) # 这里使用tf.float32会将照片归一化,也就是 *1/255 image = tf.image.convert_image_dtype(image, dtype=tf.float32) image = tf.reverse(image, axis=[2]) # 读取的是rgb,需要转为bgr return image @staticmethod def _parse_labels(_, labels, num_labels): """ 根据图片路径和标签,解析标签 :param _: 图片路径, Tensor(shape=(), dtype=string),本函数只产生标签dataset,故不需要 :param labels: 标签,Tensor(shape=(?,), dtype=float32) :param num_labels: 每个图像对于输出的标签数(多标签分类模型) :return: 标签 DatasetV1Adapter<(多个标签Tensor(shape=(), dtype=float32), ...)> """ label_list = list() for label_index in range(num_labels): label_list.append(labels[label_index]) return label_list @staticmethod def get_dataset(file_path, root_path, image_size, num_labels, batch_size, is_augment=True, is_test=False): """ 从标签文件读取数据,并解析为(image_path, labels)形式的列表 标签文件内容格式为: image_name label0,label1,label2,label3,... :param file_path: 标签文件路径 :param root_path: 图片路径的根目录,用于和标签文件中的image_name拼接 :param image_size: 图像需要resize到的尺寸 :param num_labels: 每个图像对于输出的标签数(多标签分类模型) :param batch_size: 批次大小 :param is_augment: 是否对图片进行数据增强 :param is_test: 是否为测试阶段,测试阶段的话,输出的dataset中多包含image_path :return: tf.data.Dataset对象 """ logging.info('利用标签文件、图片根目录生成tf.data数据集对象:') logging.info('1. 解析标签文件;') dataset = tf.data.TextLineDataset(file_path) dataset = DatasetUtil.shuffle_repeat(dataset, batch_size) dataset = dataset.map(functools.partial(FileUtil._parse_string_line, root_path=root_path), num_parallel_calls=tf.data.experimental.AUTOTUNE) logging.info('2. 读取图片数据,构造image set和label set;') image_set = dataset.map(functools.partial(FileUtil._parse_image, image_size=image_size), num_parallel_calls=tf.data.experimental.AUTOTUNE) labels_set = dataset.map(functools.partial(FileUtil._parse_labels, num_labels=num_labels), num_parallel_calls=tf.data.experimental.AUTOTUNE) if is_augment: logging.info('2.1 image set数据增强;') image_set = DatasetUtil.augment_image(image_set) logging.info('3. image set数据标准化;') image_set = image_set.map(lambda image: tf.image.per_image_standardization(image), num_parallel_calls=tf.data.experimental.AUTOTUNE) if is_test: logging.info('4. 完成tf.data (image, label, path) 测试数据集构造;') path_set = dataset.map(lambda image_path, label: image_path, num_parallel_calls=tf.data.experimental.AUTOTUNE) dataset = tf.data.Dataset.zip((image_set, labels_set, path_set)) else: logging.info('4. 完成tf.data (image, label) 训练数据集构造;') # 合并image、labels: # DatasetV1Adapter<shapes:((48,144,?), ((), ..., ())), types:(float32,(float32,...,flout32))> dataset = tf.data.Dataset.zip((image_set, labels_set)) logging.info('5. 构造tf.data多epoch训练模式;') dataset = DatasetUtil.batch_prefetch(dataset, batch_size) return dataset if __name__ == '__main__': import cv2 import numpy as np import time # 开启eager模式进行图片读取、增强和展示 tf.enable_eager_execution() train_file_path = './test_sample/label.txt' # 标签文件 image_root_path = './test_sample' # 图片根目录 train_batch = 100 train_set = FileUtil.get_dataset(train_file_path, image_root_path, image_size=(48, 144), num_labels=10, batch_size=train_batch, is_augment=True) start = time.time() for count, data in enumerate(train_set): for i in range(data[0].shape[0]): cv2.imshow('a', np.array(data[0][i])) cv2.waitKey(1) for count, data in enumerate(train_set): print('一批(%d)图像 shape:' % train_batch, data[0].shape) for i in range(data[0].shape[0]): cv2.imshow('a', np.array(data[0][i])) cv2.waitKey(1) print('一批(%d)标签 shape:' % train_batch, len(data[1])) for i in range(len(data[1])): print(data[1][i]) if count == 100: break print('耗时:', time.time() - start)