import numpy as np import os import random from PIL import Image def load(path): return PIL2Chainer(Image.open(path)) def save(path, data, rescale=True): Chainer2PIL(data, rescale).save(path) def PIL2Chainer(img, scale=True): img = np.array(img) if len(img.shape) == 2: img = img.astype(np.float32).reshape((1, img.shape[0], img.shape[1])) else: buf = np.zeros((img.shape[2], img.shape[0], img.shape[1]), dtype=np.uint8) for i in range(3): buf[i,::] = img[:,:,i] img = buf.astype(np.float32) if scale: # img -= 128 img /= 256 return img def Chainer2PIL(data, rescale=True): data = np.array(data) if rescale: data *= 256 # data += 128 if data.dtype != np.uint8: data = np.clip(data, 0, 255) data = data.astype(np.uint8) if data.shape[0] == 1: buf = data.astype(np.uint8).reshape((data.shape[1], data.shape[2])) else: buf = np.zeros((data.shape[1], data.shape[2], data.shape[0]), dtype=np.uint8) for i in range(3): a = data[i,:,:] buf[:,:,i] = a img = Image.fromarray(buf) return img