# %%
# Change working directory from the workspace root to the ipynb file location.
# Turn this addition off with the DataScience.changeDirOnImportExport setting
import dill
from random import shuffle
import random
import os
try:
    os.chdir(os.path.join(os.getcwd(), 'data'))
    print(os.getcwd())
except:
    pass

# %%
import ast
import pandas as pd
from collections import defaultdict
import numpy as np

med_file = 'PRESCRIPTIONS.csv'
diag_file = 'DIAGNOSES_ICD.csv'
ndc2atc_file = 'ndc2atc_level4.csv'
ddi_file = 'drug-DDI.csv'
cid_atc = 'drug-atc.csv'
patient_info_file = './gather_firstday.csv'


def process_med():
    print('process_med')
    med_pd = pd.read_csv(med_file, dtype={'NDC': 'category'})
    # filter
    med_pd.drop(columns=['ROW_ID', 'DRUG_TYPE', 'DRUG_NAME_POE', 'DRUG_NAME_GENERIC',
                         'FORMULARY_DRUG_CD', 'GSN', 'PROD_STRENGTH', 'DOSE_VAL_RX',
                         'DOSE_UNIT_RX', 'FORM_VAL_DISP', 'FORM_UNIT_DISP',
                         'ROUTE', 'ENDDATE', 'DRUG'], inplace=True)
    med_pd.drop(index=med_pd[med_pd['NDC'] == '0'].index, inplace=True)
    med_pd.fillna(method='pad', inplace=True)
    med_pd.dropna(inplace=True)
    med_pd.drop_duplicates(inplace=True)
    med_pd['ICUSTAY_ID'] = med_pd['ICUSTAY_ID'].astype('int64')
    med_pd['STARTDATE'] = pd.to_datetime(
        med_pd['STARTDATE'], format='%Y-%m-%d %H:%M:%S')
    med_pd.sort_values(
        by=['SUBJECT_ID', 'HADM_ID', 'ICUSTAY_ID', 'STARTDATE'], inplace=True)
    med_pd = med_pd.reset_index(drop=True)

    def filter_first24hour_med(med_pd):
        # keep the earliest STARTDATE per ICU stay, then join back to recover
        # every drug prescribed at that first timestamp
        med_pd_new = med_pd.drop(columns=['NDC'])
        med_pd_new = med_pd_new.groupby(
            by=['SUBJECT_ID', 'HADM_ID', 'ICUSTAY_ID']).head(1).reset_index(drop=True)
        med_pd_new = pd.merge(med_pd_new, med_pd, on=[
            'SUBJECT_ID', 'HADM_ID', 'ICUSTAY_ID', 'STARTDATE'])
        med_pd_new = med_pd_new.drop(columns=['STARTDATE'])
        return med_pd_new

    med_pd = filter_first24hour_med(med_pd)  # or next line
    # med_pd = med_pd.drop(columns=['STARTDATE'])
    med_pd = med_pd.drop(columns=['ICUSTAY_ID'])
    med_pd = med_pd.drop_duplicates()
    return med_pd.reset_index(drop=True)


def process_diag():
    print('process_diag')
    diag_pd = pd.read_csv(diag_file)
    diag_pd.dropna(inplace=True)
    diag_pd.drop(columns=['SEQ_NUM', 'ROW_ID'], inplace=True)
    diag_pd.drop_duplicates(inplace=True)
    diag_pd.sort_values(by=['SUBJECT_ID', 'HADM_ID'], inplace=True)
    return diag_pd.reset_index(drop=True)


def process_side():
    print('process_side')
    side_pd = pd.read_csv(patient_info_file)
    # use only demographic information to avoid leaking future information
    # such as lab tests and lab measurements
    side_pd = side_pd[['subject_id', 'hadm_id', 'icustay_id', 'gender_male',
                       'admission_type', 'first_icu_stay', 'admission_age',
                       'ethnicity', 'weight', 'height']]
    # process side information
    side_pd = side_pd.dropna(thresh=4)
    side_pd.fillna(side_pd.mean(), inplace=True)
    side_pd = side_pd.groupby(
        by=['subject_id', 'hadm_id']).head(1).reset_index(drop=True)
    side_pd = pd.concat(
        [side_pd, pd.get_dummies(side_pd['ethnicity'])], axis=1)
    side_pd.drop(columns=['ethnicity', 'icustay_id'], inplace=True)
    side_pd.rename(columns={'subject_id': 'SUBJECT_ID',
                            'hadm_id': 'HADM_ID'}, inplace=True)
    return side_pd.reset_index(drop=True)


def ndc2atc4(med_pd):
    # map NDC -> RXCUI -> ATC level 4
    with open('ndc2rxnorm_mapping.txt', 'r') as f:
        # the file holds a dict literal; literal_eval avoids executing code
        ndc2rxnorm = ast.literal_eval(f.read())
    med_pd['RXCUI'] = med_pd['NDC'].map(ndc2rxnorm)
    med_pd.dropna(inplace=True)

    rxnorm2atc = pd.read_csv(ndc2atc_file)
    rxnorm2atc = rxnorm2atc.drop(columns=['YEAR', 'MONTH', 'NDC'])
    rxnorm2atc.drop_duplicates(subset=['RXCUI'], inplace=True)
    med_pd.drop(index=med_pd[med_pd['RXCUI'].isin([''])].index, inplace=True)

    med_pd['RXCUI'] = med_pd['RXCUI'].astype('int64')
    med_pd = med_pd.reset_index(drop=True)
    med_pd = med_pd.merge(rxnorm2atc, on=['RXCUI'])
    med_pd.drop(columns=['NDC', 'RXCUI'], inplace=True)
    # med_pd = med_pd.rename(columns={'ATC4': 'NDC'})
    # keep the 5-character ATC level-4 prefix only
    med_pd['ATC4'] = med_pd['ATC4'].map(lambda x: x[:5])
    med_pd = med_pd.drop_duplicates()
    med_pd = med_pd.reset_index(drop=True)
    return med_pd
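# %%
# Illustrative sketch (not part of the original pipeline): ndc2atc4 above
# truncates each ATC code to its first five characters, the ATC level-4
# class, so variants of the same drug class collapse into one token.
# The codes below are made-up examples.
_example_atc = pd.Series(['N02BE01', 'N02BE51', 'A02BC02'])
print(_example_atc.map(lambda x: x[:5]).unique())  # -> ['N02BE' 'A02BC']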
def filter_pro(pro_pd):
    pro_count = pro_pd.groupby(by=['ICD9_CODE']).size().reset_index().rename(
        columns={0: 'count'}).sort_values(by=['count'], ascending=False).reset_index(drop=True)
    pro_pd = pro_pd[pro_pd['ICD9_CODE'].isin(
        pro_count.loc[:1000, 'ICD9_CODE'])]
    return pro_pd.reset_index(drop=True)


def filter_diag(diag_pd, num=128):
    print('filter diag')
    diag_count = diag_pd.groupby(by=['ICD9_CODE']).size().reset_index().rename(
        columns={0: 'count'}).sort_values(by=['count'], ascending=False).reset_index(drop=True)
    diag_pd = diag_pd[diag_pd['ICD9_CODE'].isin(
        diag_count.loc[:num, 'ICD9_CODE'])]
    return diag_pd.reset_index(drop=True)


def filter_med(med_pd):
    med_count = med_pd.groupby(by=['ATC4']).size().reset_index().rename(
        columns={0: 'count'}).sort_values(by=['count'], ascending=False).reset_index(drop=True)
    med_pd = med_pd[med_pd['ATC4'].isin(med_count.loc[:299, 'ATC4'])]
    return med_pd.reset_index(drop=True)


# visit filter
def filter_by_visit_range(data_pd, v_range=(1, 2)):
    # keep subjects whose number of distinct admissions lies in
    # [v_range[0], v_range[1])
    a = data_pd[['SUBJECT_ID', 'HADM_ID']].groupby(
        by='SUBJECT_ID')['HADM_ID'].unique().reset_index()
    a['HADM_ID_Len'] = a['HADM_ID'].map(lambda x: len(x))
    a = a[(a['HADM_ID_Len'] >= v_range[0]) & (a['HADM_ID_Len'] < v_range[1])]
    data_pd_filter = a.reset_index(drop=True)
    data_pd = data_pd.merge(
        data_pd_filter[['SUBJECT_ID']], on='SUBJECT_ID', how='inner')
    return data_pd.reset_index(drop=True)


def process_all(visit_range=(1, 2)):
    # build med and diag tables for subjects inside the requested visit range
    med_pd = process_med()
    med_pd = ndc2atc4(med_pd)
    # med_pd = filter_300_most_med(med_pd)
    med_pd = filter_by_visit_range(med_pd, visit_range)

    diag_pd = process_diag()
    diag_pd = filter_diag(diag_pd, num=1999)

    # side_pd = process_side()
    # pro_pd = process_procedure()
    # pro_pd = filter_1000_most_pro(pro_pd)

    med_pd_key = med_pd[['SUBJECT_ID', 'HADM_ID']].drop_duplicates()
    diag_pd_key = diag_pd[['SUBJECT_ID', 'HADM_ID']].drop_duplicates()
    # pro_pd_key = pro_pd[['SUBJECT_ID', 'HADM_ID']].drop_duplicates()
    # side_pd_key = side_pd[['SUBJECT_ID', 'HADM_ID']].drop_duplicates()

    combined_key = med_pd_key.merge(
        diag_pd_key, on=['SUBJECT_ID', 'HADM_ID'], how='inner')
    # combined_key = combined_key.merge(pro_pd_key, on=['SUBJECT_ID', 'HADM_ID'], how='inner')
    # combined_key = combined_key.merge(side_pd_key, on=['SUBJECT_ID', 'HADM_ID'], how='inner')

    diag_pd = diag_pd.merge(
        combined_key, on=['SUBJECT_ID', 'HADM_ID'], how='inner')
    med_pd = med_pd.merge(
        combined_key, on=['SUBJECT_ID', 'HADM_ID'], how='inner')
    # side_pd = side_pd.merge(combined_key, on=['SUBJECT_ID', 'HADM_ID'], how='inner')
    # pro_pd = pro_pd.merge(combined_key, on=['SUBJECT_ID', 'HADM_ID'], how='inner')

    # flatten and merge
    diag_pd = diag_pd.groupby(by=['SUBJECT_ID', 'HADM_ID'])[
        'ICD9_CODE'].unique().reset_index()
    med_pd = med_pd.groupby(by=['SUBJECT_ID', 'HADM_ID'])[
        'ATC4'].unique().reset_index()
    # pro_pd = pro_pd.groupby(by=['SUBJECT_ID', 'HADM_ID'])['ICD9_CODE'].unique().reset_index().rename(columns={'ICD9_CODE': 'PRO_CODE'})
    diag_pd['ICD9_CODE'] = diag_pd['ICD9_CODE'].map(lambda x: list(x))
    med_pd['ATC4'] = med_pd['ATC4'].map(lambda x: list(x))
    # pro_pd['PRO_CODE'] = pro_pd['PRO_CODE'].map(lambda x: list(x))

    data = diag_pd.merge(med_pd, on=['SUBJECT_ID', 'HADM_ID'], how='inner')
    # data = data.merge(side_pd, on=['SUBJECT_ID', 'HADM_ID'], how='inner')
    # data = data.merge(pro_pd, on=['SUBJECT_ID', 'HADM_ID'], how='inner')
    # data['ICD9_CODE_Len'] = data['ICD9_CODE'].map(lambda x: len(x))
    # data['NDC_Len'] = data['NDC'].map(lambda x: len(x))

    return data
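# %%
# Illustrative sketch (not part of the original pipeline): on a toy frame,
# filter_by_visit_range with v_range=(1, 2) keeps only single-admission
# subjects, while v_range=(2, np.inf) keeps the multi-admission ones.
_toy = pd.DataFrame({'SUBJECT_ID': [1, 1, 2, 2, 2, 3],
                     'HADM_ID': [10, 10, 20, 21, 21, 30]})
print(filter_by_visit_range(_toy, v_range=(1, 2))['SUBJECT_ID'].unique())      # [1 3]
print(filter_by_visit_range(_toy, v_range=(2, np.inf))['SUBJECT_ID'].unique())  # [2]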
def filter_patient(data, dx_range=(2, np.inf), rx_range=(2, np.inf)):
    # drop subjects that have any visit whose diagnosis or medication count
    # falls outside the given ranges
    print('filter_patient')
    drop_subject_ls = []
    for subject_id in data['SUBJECT_ID'].unique():
        item_data = data[data['SUBJECT_ID'] == subject_id]
        for index, row in item_data.iterrows():
            dx_len = len(list(row['ICD9_CODE']))
            rx_len = len(list(row['ATC4']))
            if dx_len < dx_range[0] or dx_len > dx_range[1]:
                drop_subject_ls.append(subject_id)
                break
            if rx_len < rx_range[0] or rx_len > rx_range[1]:
                drop_subject_ls.append(subject_id)
                break
    data.drop(index=data[data['SUBJECT_ID'].isin(
        drop_subject_ls)].index, inplace=True)
    return data.reset_index(drop=True)


def statistics(data):
    print('#patients ', data['SUBJECT_ID'].unique().shape)
    print('#clinical events ', len(data))

    diag = data['ICD9_CODE'].values
    med = data['ATC4'].values
    unique_diag = set([j for i in diag for j in list(i)])
    unique_med = set([j for i in med for j in list(i)])
    print('#diagnosis ', len(unique_diag))
    print('#med ', len(unique_med))

    avg_diag = 0
    avg_med = 0
    max_diag = 0
    max_med = 0
    cnt = 0
    max_visit = 0
    avg_visit = 0
    for subject_id in data['SUBJECT_ID'].unique():
        item_data = data[data['SUBJECT_ID'] == subject_id]
        x = []
        y = []
        visit_cnt = 0
        for index, row in item_data.iterrows():
            visit_cnt += 1
            cnt += 1
            x.extend(list(row['ICD9_CODE']))
            y.extend(list(row['ATC4']))
        x = set(x)
        y = set(y)
        avg_diag += len(x)
        avg_med += len(y)
        avg_visit += visit_cnt
        if len(x) > max_diag:
            max_diag = len(x)
        if len(y) > max_med:
            max_med = len(y)
        if visit_cnt > max_visit:
            max_visit = visit_cnt

    # note: the averages below divide per-patient unique counts by the total
    # number of visits (cnt), not by the number of patients
    print('#avg of diagnoses ', avg_diag / cnt)
    print('#avg of medicines ', avg_med / cnt)
    print('#avg of visits ', avg_visit / len(data['SUBJECT_ID'].unique()))
    print('#max of diagnoses ', max_diag)
    print('#max of medicines ', max_med)
    print('#max of visit ', max_visit)


def run(visit_range=(1, 2)):
    data = process_all(visit_range)
    data = filter_patient(data)
    # unique code save
    diag = data['ICD9_CODE'].values
    med = data['ATC4'].values
    unique_diag = set([j for i in diag for j in list(i)])
    unique_med = set([j for i in med for j in list(i)])
    return data, unique_diag, unique_med


def load_gamenet_multi_visit_data(file_name='data_gamenet.pkl'):
    data = pd.read_pickle(file_name)
    data.rename(columns={'NDC': 'ATC4'}, inplace=True)
    data.drop(columns=['PRO_CODE', 'NDC_Len'], inplace=True)
    # unique code save
    diag = data['ICD9_CODE'].values
    med = data['ATC4'].values
    unique_diag = set([j for i in diag for j in list(i)])
    unique_med = set([j for i in med for j in list(i)])
    return data, unique_diag, unique_med


def load_gamenet_multi_visit_data_with_pro(file_name='data_gamenet.pkl'):
    data = pd.read_pickle(file_name)
    data.rename(columns={'NDC': 'ATC4'}, inplace=True)
    data.drop(columns=['NDC_Len'], inplace=True)
    # unique code save
    diag = data['ICD9_CODE'].values
    med = data['ATC4'].values
    pro = data['PRO_CODE'].values
    unique_diag = set([j for i in diag for j in list(i)])
    unique_med = set([j for i in med for j in list(i)])
    unique_pro = set([j for i in pro for j in list(i)])
    return data, unique_pro, unique_diag, unique_med


def main():
    print('-' * 20 + '\ndata-single processing')
    data_single_visit, diag1, med1 = run(visit_range=(1, 2))
    print('-' * 20 + '\ndata-multi processing')
    data_multi_visit, pro, diag2, med2 = load_gamenet_multi_visit_data_with_pro()
    # med_diag_pair = gen_med_diag_pair(data)

    unique_diag = diag1 | diag2
    unique_med = med1 | med2
    with open('dx-vocab.txt', 'w') as fout:
        for code in unique_diag:
            fout.write(code + '\n')
    with open('rx-vocab.txt', 'w') as fout:
        for code in unique_med:
            fout.write(code + '\n')
    with open('rx-vocab-multi.txt', 'w') as fout:
        for code in med2:
            fout.write(code + '\n')
    with open('dx-vocab-multi.txt', 'w') as fout:
        for code in diag2:
            fout.write(code + '\n')
    with open('px-vocab-multi.txt', 'w') as fout:
        for code in pro:
            fout.write(code + '\n')

    # save data
    data_single_visit.to_pickle('data-single-visit.pkl')
    data_multi_visit.to_pickle('data-multi-visit.pkl')
    # med_diag_pair.to_pickle('med_diag.pkl')
    # print('med2diag len:', len(med_diag_pair))

    print('-' * 20 + '\ndata-single stat')
    statistics(data_single_visit)
    print('-' * 20 + '\ndata_multi stat')
    statistics(data_multi_visit)
    return data_single_visit, data_multi_visit


data_single_visit, data_multi_visit = main()
data_multi_visit.head(10)
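# %%
# Illustrative sketch (not part of the original pipeline): the split_dataset
# cell below carves subject ids into 2/3 train, 1/6 eval and 1/6 test by
# index position. With 12 toy ids the cut points land at 8 and 10.
_ids = list(range(12))
print(_ids[:int(12 * 2 / 3)],
      _ids[int(12 * 2 / 3):int(12 * 5 / 6)],
      _ids[int(12 * 5 / 6):])
# -> [0, 1, 2, 3, 4, 5, 6, 7] [8, 9] [10, 11]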
# %%
# split train, eval and test dataset
random.seed(1203)


def split_dataset(data_path='data-multi-visit.pkl'):
    data = pd.read_pickle(data_path)
    sample_id = data['SUBJECT_ID'].unique()

    random_number = [i for i in range(len(sample_id))]
    # shuffle(random_number)  # shuffling is disabled, so the split follows subject order
    train_id = sample_id[random_number[:int(len(sample_id) * 2 / 3)]]
    eval_id = sample_id[random_number[int(
        len(sample_id) * 2 / 3): int(len(sample_id) * 5 / 6)]]
    test_id = sample_id[random_number[int(len(sample_id) * 5 / 6):]]

    def ls2file(list_data, file_name):
        with open(file_name, 'w') as fout:
            for item in list_data:
                fout.write(str(item) + '\n')

    ls2file(train_id, 'train-id.txt')
    ls2file(eval_id, 'eval-id.txt')
    ls2file(test_id, 'test-id.txt')

    print('train size: %d, eval size: %d, test size: %d' %
          (len(train_id), len(eval_id), len(test_id)))


split_dataset()

# %%
# generate ehr graph for gamenet


def generate_ehr_graph():
    data_multi = pd.read_pickle('data-multi-visit.pkl')
    data_single = pd.read_pickle('data-single-visit.pkl')

    rx_voc_size = 0
    rx_voc = {}
    with open('rx-vocab.txt', 'r') as fin:
        for line in fin:
            rx_voc[line.rstrip('\n')] = rx_voc_size
            rx_voc_size += 1

    # mark every pair of medications that co-occur within a visit
    ehr_adj = np.zeros((rx_voc_size, rx_voc_size))
    for idx, row in data_multi.iterrows():
        med_set = list(map(lambda x: rx_voc[x], row['ATC4']))
        for i, med_i in enumerate(med_set):
            for j, med_j in enumerate(med_set):
                if j <= i:
                    continue
                ehr_adj[med_i, med_j] = 1
                ehr_adj[med_j, med_i] = 1
    for idx, row in data_single.iterrows():
        med_set = list(map(lambda x: rx_voc[x], row['ATC4']))
        for i, med_i in enumerate(med_set):
            for j, med_j in enumerate(med_set):
                if j <= i:
                    continue
                ehr_adj[med_i, med_j] = 1
                ehr_adj[med_j, med_i] = 1

    # average node degree of the co-occurrence graph
    print('avg med for one ', np.mean(np.sum(ehr_adj, axis=-1)))
    return ehr_adj


ehr_adj = generate_ehr_graph()
dill.dump(ehr_adj, open('ehr_adj.pkl', 'wb'))
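# %%
# Design-note sketch (not part of the original pipeline): the nested loops
# above can be replaced by indexing with np.ix_, which sets a whole
# co-occurrence block at once; the diagonal is cleared afterwards. Shown on
# a tiny 4-code vocabulary with one visit containing codes 0, 2 and 3.
_adj = np.zeros((4, 4))
_visit = [0, 2, 3]
_adj[np.ix_(_visit, _visit)] = 1
np.fill_diagonal(_adj, 0)
print(_adj)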
# %%
# max len medical codes
data = data_multi_visit
max_len = 0
for subject_id in data['SUBJECT_ID'].unique():
    item_df = data[data['SUBJECT_ID'] == subject_id]
    len_tmp = 0
    for index, row in item_df.iterrows():
        len_tmp += (len(row['ICD9_CODE']) + len(row['ATC4']))
    if len_tmp > max_len:
        max_len = len_tmp
print(max_len)

# %%
print(max_len)

# %%
# scratch: reload the raw GameNet pickle (the same file read by
# load_gamenet_multi_visit_data) and redo its renaming/dropping by hand
data = pd.read_pickle('data_gamenet.pkl')

# %%
data.rename(columns={'NDC': 'ATC4'}, inplace=True)
data.drop(columns=['PRO_CODE', 'NDC_Len'], inplace=True)

# %%
data.shape

# %%
data_dir = './data/'
print('multi visit')
multi_file = data_dir + 'data-multi-visit.pkl'
multi_pkl = pd.read_pickle(multi_file)
multi_pkl.iloc[0, 4:]

# %%
# stat
rx_cnt_ls = []
dx_cnt_ls = []
visit_cnt_ls = []
for subject_id in multi_pkl['SUBJECT_ID'].unique():
    visit_cnt = 0
    for idx, visit in multi_pkl[multi_pkl['SUBJECT_ID'] == subject_id].iterrows():
        rx_cnt_ls.append(len(visit['ATC4']))
        dx_cnt_ls.append(len(visit['ICD9_CODE']))
        visit_cnt += 1
    visit_cnt_ls.append(visit_cnt)

print('mean')
print('dx', np.mean(dx_cnt_ls))
print('rx', np.mean(rx_cnt_ls))
print('visit', np.mean(visit_cnt_ls))
print('max')
print('dx', np.max(dx_cnt_ls))
print('rx', np.max(rx_cnt_ls))
print('visit', np.max(visit_cnt_ls))
print('min')
print('dx', np.min(dx_cnt_ls))
print('rx', np.min(rx_cnt_ls))
print('visit', np.min(visit_cnt_ls))

print('single visit')

# %%
single_file = data_dir + 'data-single-visit.pkl'
single_pkl = pd.read_pickle(single_file)
single_pkl.head()

# %%
rx_cnt_ls = []
dx_cnt_ls = []
visit_cnt_ls = []
for subject_id in single_pkl['SUBJECT_ID'].unique():
    visit_cnt = 0
    for idx, visit in single_pkl[single_pkl['SUBJECT_ID'] == subject_id].iterrows():
        rx_cnt_ls.append(len(visit['ATC4']))
        dx_cnt_ls.append(len(visit['ICD9_CODE']))
        visit_cnt += 1
    visit_cnt_ls.append(visit_cnt)

print('mean')
print('dx', np.mean(dx_cnt_ls))
print('rx', np.mean(rx_cnt_ls))
print('visit', np.mean(visit_cnt_ls))
print('max')
print('dx', np.max(dx_cnt_ls))
print('rx', np.max(rx_cnt_ls))
print('visit', np.max(visit_cnt_ls))
print('min')
print('dx', np.min(dx_cnt_ls))
print('rx', np.min(rx_cnt_ls))
print('visit', np.min(visit_cnt_ls))

# recorded output:
# multi visit
# mean
# dx 13.640849760255728
# rx 13.930074587107086
# visit 2.3647244094488187
# max
# dx 39
# rx 36
# visit 29
# min
# dx 1
# rx 1
# visit 1
# single visit
# mean
# dx 10.820458611156285
# rx 13.759277931370955
# visit 1.0
# max
# dx 39
# rx 52
# visit 1
# min
# dx 2
# rx 2
# visit 1

# %%
data_dir = './data/'
single_pkl.head()
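# %%
# Refactor sketch (not part of the original pipeline): the two stat cells
# above run the same loop on multi_pkl and single_pkl; a small helper keeps
# the logic in one place. Assumes a frame with 'SUBJECT_ID', 'ICD9_CODE'
# and 'ATC4' columns holding lists of codes.
def code_stats(df):
    dx = df['ICD9_CODE'].map(len)          # diagnoses per visit
    rx = df['ATC4'].map(len)               # medications per visit
    visits = df.groupby('SUBJECT_ID').size()  # visits per subject
    for name, s in [('dx', dx), ('rx', rx), ('visit', visits)]:
        print(name, 'mean', s.mean(), 'max', s.max(), 'min', s.min())


code_stats(multi_pkl)
code_stats(single_pkl)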