import os
import glob
import scipy
import numpy as np
from gpnn.utils.logger import get_logger

logger = get_logger()


def read_idx_file(file_name):
    """Read a file holding one integer index per line.

    Whitespace-only lines are skipped so that a stray blank line does not
    abort the load.

    Args:
      file_name: path of the index file

    Returns:
      list of int, in file order

    Raises:
      ValueError: if a non-blank line cannot be parsed by int()
    """
    with open(file_name) as f:
        return [int(line) for line in f if line.strip()]


def get_filenames_from_dir(dirname, pattern):
    """Return the sorted paths in ``dirname`` that match a glob pattern.

    Args:
      dirname: directory to search
      pattern: glob pattern relative to dirname (e.g., "Node_*.txt")

    Returns:
      sorted list of matching paths, each prefixed with dirname
    """
    return sorted(glob.glob(os.path.join(dirname, pattern)))


def read_txt_file(file_name):
    """Read a whitespace-delimited table into a dictionary of columns.

    The first line supplies the column keys. Each later line is split on
    whitespace and its tokens are appended, as strings, to the matching
    columns.

    Rows that do not match the header width are handled as follows:
      * blank rows are skipped;
      * rows with fewer tokens than keys use the node-file heuristic: the
        first token goes to column 0, the literal "missing_token" goes to
        column 1 and the second token goes to column 2. This presumes three
        columns and exactly two tokens; other shapes can raise IndexError or
        leave the columns with unequal lengths;
      * rows with more tokens than keys are reported through the logger and
        dropped (no interactive pause).

    Args:
      file_name: path of the text file

    Returns:
      dict mapping each header key to a list of string tokens; empty dict
      if the file has no lines
    """
    data_dict = {}
    key_list = []
    num_key = 0

    with open(file_name, "r") as fin:
        for count, line in enumerate(fin):
            token_list = line.split()

            # the first line should contain the key of each column
            if count == 0:
                key_list = token_list
                num_key = len(key_list)
                for key in key_list:
                    data_dict[key] = []
                continue

            # a blank row carries no data; indexing into it below would fail
            if not token_list:
                continue

            if len(token_list) < num_key:
                # sanity check
                logger.warning("There are missing tokens")
                # handle corner cases (missing token) in node file
                data_dict[key_list[0]].append(token_list[0])
                data_dict[key_list[1]].append("missing_token")
                data_dict[key_list[2]].append(token_list[1])
            elif len(token_list) > num_key:
                # There is no safe column assignment for surplus tokens, so
                # the row is reported and left out of the result.
                logger.warning("There is something wrong!")
                logger.warning("file name = {}".format(file_name))
            else:
                for key, token in zip(key_list, token_list):
                    data_dict[key].append(token)

    return data_dict


def gen_split_idx(num_all, num_train, num_val, num_test, seed=1234):
    """Generate train/val/test split indices

    A permutation of range(num_all) is drawn from a RandomState seeded with
    ``seed`` and cut into three consecutive, non-overlapping slices.

    Args:
      num_all: total number of data
      num_train: number of training data
      num_val: number of validation data
      num_test: number of testing data
      seed: seed of random number generator

    Returns:
      train_idx: index of training data
      val_idx: index of validation data
      test_idx: index of testing data

    Raises:
      AssertionError: if the three split sizes add up to more than num_all
    """
    assert num_train + num_val + num_test <= num_all

    prng = np.random.RandomState(seed)
    perm_idx = prng.permutation(num_all)

    val_start = num_train
    test_start = num_train + num_val
    test_end = test_start + num_test

    train_idx = perm_idx[:val_start]
    val_idx = perm_idx[val_start:test_start]
    test_idx = perm_idx[test_start:test_end]

    return train_idx, val_idx, test_idx


def read_list_from_file(filename):
    """Read a text file into a list of lines.

    Args:
      filename: path of the text file

    Returns:
      list of str, one entry per line with trailing whitespace removed
    """
    with open(filename, "r") as ff:
        return [line.rstrip() for line in ff]


def read_csv_file(file_name):
    """Read a comma-separated file of integers column by column.

    The number of columns is taken from the first non-blank line. Blank
    lines are skipped.

    Args:
      file_name: path of the CSV file

    Returns:
      list of columns, each a list of int in row order; an empty list if
      the file contains no data

    Raises:
      ValueError: if a field cannot be parsed by int()
      IndexError: if a row has more fields than the first row
    """
    results = []

    with open(file_name, "r") as ff:
        for line in ff:
            stripped = line.rstrip()
            if not stripped:
                continue

            fields = stripped.split(",")

            # size the column lists from the first data row
            if not results:
                results = [[] for _ in range(len(fields))]

            for ii, xx in enumerate(fields):
                results[ii].append(int(xx))

    return results