"""Helpers for listing dataset files, locating trained model weights and
reporting the versions of the packages and system components in use."""

__author__ = 'Jiri Fajtl'
__email__ = 'ok1zjf@gmail.com'
__version__= '2.2'
__status__ = "Research"
__date__ = "28/1/2018"
__license__= "MIT License"

import os
import numpy as np
import glob
import subprocess
import platform
import sys
import torch

# pkg_resources and PIL are only needed for version reporting, so a missing
# package must not make the file utilities below unimportable.
try:
    import pkg_resources
except ImportError:
    pkg_resources = None

try:
    import PIL as Image
except ImportError:
    Image = None

try:
    import cv2
except Exception:
    # Best effort: OpenCV is optional and may fail to load for reasons other
    # than a plain ImportError (e.g. binary incompatibilities).
    print("WARNING: Could not load OpenCV python package. Some functionality may not be available.")


def list_files(path, extensions=None, sort=True, max_len=-1):
    """List the files in directory `path` whose names end with one of `extensions`.

    :param path: directory to scan (not recursive).
    :param extensions: iterable of filename suffixes, e.g. ['jpg', 'png'].
        When empty or None nothing matches and an empty list is returned.
    :param sort: sort the resulting paths alphabetically.
    :param max_len: keep at most this many entries; any value below 0 keeps all.
    :return: list of paths (`path` joined with the filename). An empty list is
        returned, after printing an error, if `path` is not a directory.
    """
    if not os.path.isdir(path):
        print("ERROR. ", path,' is not a directory!')
        return []

    # str.endswith accepts a tuple of suffixes; an empty tuple never matches.
    suffixes = tuple(extensions) if extensions else ()
    filenames = [os.path.join(path, fn) for fn in os.listdir(path) if fn.endswith(suffixes)]

    if sort:
        filenames.sort()

    if max_len > -1:
        filenames = filenames[:max_len]

    return filenames


def get_video_list(video_path, max_len=-1):
    """Return the sorted video files (avi, flv, mpg, mp4) found in `video_path`."""
    return list_files(video_path, extensions=['avi', 'flv', 'mpg', 'mp4'], sort=True, max_len=max_len)


def get_image_list(video_path, max_len=-1):
    """Return the sorted image files (jpg, jpeg, png) found in `video_path`."""
    return list_files(video_path, extensions=['jpg', 'jpeg', 'png'], sort=True, max_len=max_len)


def get_split_files(dataset_path, splits_path, split_name, absolute_path=False):
    """Return the sorted split files matching the glob pattern `split_name`.

    :param dataset_path: root directory of the dataset.
    :param splits_path: sub-directory of `dataset_path` holding the split files.
    :param split_name: glob pattern of the split filenames, e.g. 'test_*.txt'.
    :param absolute_path: if True return the paths as matched by glob,
        otherwise only the bare filenames.
    :return: sorted list of paths or filenames.
    """
    files = sorted(glob.glob(os.path.join(dataset_path, splits_path, split_name)))
    if absolute_path:
        return files
    return [os.path.basename(f) for f in files]


def get_max_rc_weights(experiment_path):
    """Find the checkpoint of the epoch with the highest validation RC.

    Reads 'train_log_0.csv' in `experiment_path`. Only rows whose first column
    is 'val' are considered; column 1 is read as the epoch, column 4 as the RC
    and column 6 as the MSE. Rows where any of these cannot be parsed are
    skipped as a whole, so the returned RC, MSE and epoch always originate
    from the same row.

    :param experiment_path: directory with the training log and the
        'weights_<epoch>.pkl' checkpoint files.
    :return: tuple (checkpoint_path, max_rc, mse, epoch), or ('', 0, 0, 0) if
        the log cannot be read or the selected checkpoint file does not exist.
    """
    log_filename = 'train_log_0.csv'
    log_path = os.path.join(experiment_path, log_filename)

    max_rc = 0
    max_epoch = -1
    max_mse = -1
    try:
        with open(log_path, 'rt') as f:
            for line in f:
                toks = line.split(',')
                if toks[0] != 'val':
                    continue
                # Parse every field before committing any of them so that a
                # malformed row cannot leave a partially updated result.
                try:
                    epoch = int(toks[1])
                    rc = float(toks[4])
                    mse = float(toks[6])
                except (IndexError, ValueError):
                    continue
                if rc > max_rc:
                    max_rc, max_epoch, max_mse = rc, epoch, mse
    except (IOError, OSError, UnicodeError):
        print('WARNING: Could not open ' + log_path)
        return '', 0, 0, 0

    chkpt_file = experiment_path + '/' + 'weights_' + str(max_epoch) + '.pkl'
    if not os.path.isfile(chkpt_file):
        print("WARNING: File ",chkpt_file," does not exists!")
        return '', 0, 0, 0

    return chkpt_file, max_rc, max_mse, max_epoch


def get_split_index(split_filename):
    """Return the integer that follows the last '_' in the filename stem.

    E.g. 'test_3.txt' -> 3 and 'weights_12.pkl' -> 12.
    Raises ValueError if that part of the name is not an integer.
    """
    stem, _ = os.path.splitext(split_filename)
    return int(stem.split('_')[-1])


def get_weight_files(split_files, experiment_name, max_rc_checkpoints=True):
    """Select one weights file per split from 'data/<experiment_name>_train_<split_id>/'.

    For every split the directory is searched for '*.pkl' files:
      * no file      -> '' is stored for the split,
      * one file     -> that file is used,
      * several files -> if `max_rc_checkpoints` is set, the checkpoint with
        the highest validation RC according to the training log is used;
        otherwise, or when that lookup fails, the file with the highest
        trailing index (the last epoch) is used.

    :param split_files: split filenames such as 'test_1.txt'; the split id is
        the part of the stem after the last '_'.
    :param experiment_name: prefix of the per-split training directories.
    :param max_rc_checkpoints: prefer the best-RC checkpoint over the last one.
    :return: list of weight file paths, one entry per item of `split_files`.
    """
    data_dir = 'data'
    weight_files = []
    for split_filename in split_files:
        split_name, _ = os.path.splitext(split_filename)
        # rsplit keeps working when the path or name holds further underscores.
        split_id = split_name.rsplit('_', 1)[-1]

        weights_dir = os.path.join(data_dir, experiment_name + '_train_' + split_id)
        files = glob.glob(weights_dir + '/*.pkl')

        if not files:
            # No trained model weights for this split
            weight_files.append('')
            continue

        if len(files) == 1:
            weight_files.append(files[0])
            continue

        # Multiple weights
        if max_rc_checkpoints:
            print("Selecting model weights with the highest RC on validation set in ",weights_dir)
            weight_file, max_rc, max_mse, max_epoch = get_max_rc_weights(weights_dir)
            if weight_file != '':
                print('Found: ',weight_file, ' RC=', max_rc, ' MSE=', max_mse, ' epoch=', max_epoch)
                weight_files.append(weight_file)
                continue

        # Get the weights from the last training epoch
        weight_files.append(max(files, key=get_split_index))

    return weight_files


def _format_output_lines(raw_lines):
    """Decode byte lines, strip them and join them with a leading tab each."""
    return '\n'.join(['\t' + line.decode("utf-8", "replace").strip() for line in raw_lines])


def _read_version_file(path):
    """Return the tab-indented content of the text file `path`, or 'NA' if unreadable."""
    try:
        with open(path, 'rb') as f:
            return _format_output_lines(f.readlines())
    except (IOError, OSError):
        return 'NA'


def run_command(command):
    """Run `command` and return its combined stdout/stderr output.

    :param command: command line; it is split on whitespace, no shell is used.
    :return: the output with every line stripped and prefixed with a tab.
    """
    p = subprocess.Popen(command.split(), stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
    try:
        raw_lines = p.stdout.readlines()
    finally:
        # Release the pipe and reap the child process.
        p.stdout.close()
        p.wait()
    return _format_output_lines(raw_lines)


def _get_pil_version():
    """Return the PIL/Pillow version string, or 'NA' if it cannot be determined."""
    if Image is None:
        return 'NA'
    # PIL.VERSION is no longer provided by current Pillow releases.
    for attr in ('__version__', 'PILLOW_VERSION', 'VERSION'):
        version = getattr(Image, attr, None)
        if version:
            return version
    return 'NA'


def _get_torchvision_version():
    """Return the installed torchvision version string, or 'NA' if not found."""
    if pkg_resources is not None:
        try:
            return pkg_resources.get_distribution("torchvision").version
        except pkg_resources.ResolutionError:
            return 'NA'
    try:
        from importlib import metadata
        return metadata.version("torchvision")
    except ImportError:
        # Covers a missing importlib.metadata as well as PackageNotFoundError.
        return 'NA'


def ge_pkg_versions():
    """Collect the versions of the display driver, CUDA, cuDNN and key packages.

    The CUDA version is read from 'version.txt' in the directory given by the
    CUDA_HOME environment variable (default '/usr/local/cuda/').

    :return: dict mapping component name to its version; 'NA' is stored for
        components whose version cannot be determined.
    """
    dep_versions = {}

    dep_versions['display'] = _read_version_file('/proc/driver/nvidia/version')

    cuda_home = os.environ.get('CUDA_HOME', '/usr/local/cuda/')
    dep_versions['cuda'] = _read_version_file(os.path.join(cuda_home, 'version.txt'))

    dep_versions['cudnn'] = torch.backends.cudnn.version()
    dep_versions['platform'] = platform.platform()
    dep_versions['python'] = sys.version_info[0]
    dep_versions['torch'] = torch.__version__
    dep_versions['numpy'] = np.__version__
    dep_versions['PIL'] = _get_pil_version()

    dep_versions['OpenCV'] = 'NA'
    if 'cv2' in sys.modules:
        dep_versions['OpenCV'] = cv2.__version__

    dep_versions['torchvision'] = _get_torchvision_version()

    return dep_versions


def print_pkg_versions():
    """Print the versions collected by ge_pkg_versions(), one per line."""
    print("Packages & system versions:")
    print("----------------------------------------------------------------------")
    versions = ge_pkg_versions()

    for key, val in versions.items():
        print(key,": ",val)
    print("")


if __name__ == "__main__":
    print_pkg_versions()

    split_files = get_split_files('datasets/lamem', 'splits', 'test_*.txt')
    print(split_files)
    weight_files = get_weight_files(split_files, experiment_name='lamem_ResNet50FC_lstm3_last', max_rc_checkpoints=True)
    # weight_files = get_weight_files(split_files, experiment_name='lamem_ResNet50FC_lstm3')
    print(weight_files)