import bisect
import operator

import numpy as np
import torch
from torch.utils import data

from multilayer_perceptron import *
from utils import *


def preprocess_weights(weights):
    """Aggregate a trained MLP's weights for interaction scoring.

    Args:
        weights: list of layer weight matrices ordered input -> output,
            as returned by ``get_weights``.

    Returns:
        Tuple ``(w_input, w_later)`` where ``w_input = |W_1|`` (first
        hidden layer) and ``w_later = |W_out| @ |W_L| @ ... @ |W_2|``,
        i.e. the aggregated absolute influence of each first-hidden-layer
        unit on the output.
    """
    w_later = np.abs(weights[-1])
    w_input = np.abs(weights[0])
    # Fold absolute weights backwards through the hidden layers, stopping
    # before the input-layer matrix at index 0 (handled as w_input).
    for i in range(len(weights) - 2, 0, -1):
        w_later = np.matmul(w_later, np.abs(weights[i]))
    return w_input, w_later


def make_one_indexed(interaction_ranking):
    """Shift every feature index in a ranking from 0-based to 1-based."""
    return [(tuple(np.array(i) + 1), s) for i, s in interaction_ranking]


def interpret_interactions(w_input, w_later, get_main_effects=False):
    """Rank arbitrary-order feature interactions from aggregated weights.

    For each first-hidden-layer unit i, features are visited in decreasing
    input-weight order; each growing prefix of features forms a candidate
    interaction whose strength is min(prefix weights) times the unit's
    aggregated outgoing weight. Strengths are summed across hidden units.

    Args:
        w_input: |W_1| matrix, shape (hidden_units, num_features).
        w_later: aggregated later-layer matrix with hidden_units columns.
        get_main_effects: if True, also score single-feature "interactions".

    Returns:
        List of ``(feature_index_tuple, strength)`` sorted by descending
        strength. Indices are 0-based.
    """
    interaction_strengths = {}
    for i in range(w_later.shape[1]):
        sorted_hweights = sorted(
            enumerate(w_input[i]), key=lambda x: x[1], reverse=True
        )
        # Hoisted out of the inner loop: the unit's aggregated outgoing
        # weight does not depend on the candidate prefix j.
        unit_outgoing = np.sum(w_later[:, i])
        interaction_candidate = []
        candidate_weights = []
        for j in range(w_input.shape[1]):
            # Keep the candidate index tuple sorted so equal feature sets
            # always map to the same dictionary key.
            bisect.insort(interaction_candidate, sorted_hweights[j][0])
            candidate_weights.append(sorted_hweights[j][1])
            if not get_main_effects and len(interaction_candidate) == 1:
                continue  # skip main effects unless explicitly requested
            interaction_tup = tuple(interaction_candidate)
            if interaction_tup not in interaction_strengths:
                interaction_strengths[interaction_tup] = 0
            # Strength of the prefix is limited by its weakest member.
            interaction_strength = min(candidate_weights) * unit_outgoing
            interaction_strengths[interaction_tup] += interaction_strength
    interaction_ranking = sorted(
        interaction_strengths.items(), key=operator.itemgetter(1), reverse=True
    )
    return interaction_ranking


def interpret_pairwise_interactions(w_input, w_later):
    """Rank all feature pairs (i, j), i < j, by interaction strength.

    Pair strength is the sum over hidden units of
    ``min(|w_input[:, i]|, |w_input[:, j]|) * w_later``.

    Returns:
        List of ``((i, j), strength)`` sorted by descending strength,
        with 0-based feature indices.
    """
    p = w_input.shape[1]
    interaction_ranking = []
    for i in range(p):
        for j in range(p):
            if i < j:
                strength = (
                    np.minimum(w_input[:, i], w_input[:, j]) * w_later
                ).sum()
                interaction_ranking.append(((i, j), strength))
    interaction_ranking.sort(key=lambda x: x[1], reverse=True)
    return interaction_ranking


def get_interactions(weights, pairwise=False, one_indexed=False):
    """Compute a ranked list of feature interactions from MLP weights.

    Args:
        weights: list of layer weight matrices (input -> output order).
        pairwise: if True, score only feature pairs (no pruning applied);
            otherwise score arbitrary-order interactions and prune
            redundant subsets.
        one_indexed: if True, report 1-based feature indices.

    Returns:
        List of ``(feature_index_tuple, strength)`` in descending order.
    """
    w_input, w_later = preprocess_weights(weights)
    if pairwise:
        interaction_ranking = interpret_pairwise_interactions(w_input, w_later)
    else:
        interaction_ranking = interpret_interactions(w_input, w_later)
        interaction_ranking = prune_redundant_interactions(interaction_ranking)
    if one_indexed:
        return make_one_indexed(interaction_ranking)
    else:
        return interaction_ranking


def prune_redundant_interactions(interaction_ranking, max_interactions=100):
    """Drop interactions that are strict subsets of a stronger one.

    Walks the ranking in descending-strength order, keeping an interaction
    only if it is not a strict subset of an already-kept (stronger)
    interaction. A kept interaction that is a strict superset of an earlier
    one replaces it in the tracking list. At most ``max_interactions``
    entries are returned.

    Args:
        interaction_ranking: list of ``(index_tuple, strength)``, already
            sorted by descending strength.
        max_interactions: cap on the number of surviving interactions.

    Returns:
        Pruned list of ``(index_tuple, strength)``, original order kept.
    """
    interaction_ranking_pruned = []
    current_superset_inters = []
    for inter, strength in interaction_ranking:
        set_inter = set(inter)
        if len(interaction_ranking_pruned) >= max_interactions:
            break
        subset_inter_skip = False
        update_superset_inters = []
        for superset_inter in current_superset_inters:
            if set_inter < superset_inter:
                # Strict subset of a stronger interaction -> redundant.
                subset_inter_skip = True
                break
            elif not (set_inter > superset_inter):
                # Neither subset nor superset: the tracked set survives.
                # (Strict supersets are dropped, superseded by set_inter.)
                update_superset_inters.append(superset_inter)
        if subset_inter_skip:
            continue
        current_superset_inters = update_superset_inters
        current_superset_inters.append(set_inter)
        interaction_ranking_pruned.append((inter, strength))
    return interaction_ranking_pruned


def detect_interactions(
    Xd,
    Yd,
    arch=None,
    batch_size=100,
    device=torch.device("cpu"),
    seed=None,
    **kwargs
):
    """Train an MLP on (Xd, Yd) and return its ranked feature interactions.

    Args:
        Xd: dict of feature matrices keyed by split (fed to
            ``convert_to_torch_loaders``).
        Yd: dict of targets keyed by split.
        arch: hidden-layer sizes; defaults to [256, 128, 64].
        batch_size: mini-batch size for the data loaders.
        device: torch device to train on.
        seed: optional RNG seed passed to ``set_seed``.
        **kwargs: forwarded to ``train``.

    Returns:
        Tuple ``(interactions, mlp_loss)``.
    """
    # Avoid a mutable default argument; [256, 128, 64] is the historical
    # default architecture.
    arch = [256, 128, 64] if arch is None else list(arch)
    if seed is not None:
        set_seed(seed)
    data_loaders = convert_to_torch_loaders(Xd, Yd, batch_size)
    # BUG FIX: the original read `feats.shape[1]` where `feats` was never
    # defined (NameError on every call). The input dimension is taken from
    # the feature matrices in Xd instead.
    # NOTE(review): assumes every split in Xd is a 2-D (samples, features)
    # array/tensor — confirm against convert_to_torch_loaders' contract.
    num_features = next(iter(Xd.values())).shape[1]
    model = create_mlp([num_features] + arch + [1]).to(device)
    model, mlp_loss = train(model, data_loaders, device=device, **kwargs)
    inters = get_interactions(get_weights(model))
    return inters, mlp_loss