## Copyright (c) 2017 Robert Bosch GmbH
## All rights reserved.
##
## This source code is licensed under the MIT license found in the
## LICENSE file in the root directory of this source tree.

from graph_transformer import sqdist
import torch
import torch.nn.functional as F
import numpy as np

def _match_rank(mask, target):
    """View a (bsz, L) mask as (bsz, L, 1, ..., 1) so it broadcasts element-wise over ``target``.

    ``target`` is expected to have shape (bsz, L, ...); for a 2-D target the mask is
    returned with its shape unchanged.
    """
    return mask.view(*mask.shape, *([1] * (target.dim() - 2)))


def tta(x_atom, x_atom_pos, x_bond, x_bond_dist, x_triplet, x_triplet_angle, x_quad, x_quad_angle, args):
    """Randomly drop atoms, bonds, triplets (and optionally quads) as an augmentation.

    Every element (atom / bond / triplet / quad) of every sample in the batch is kept
    independently with probability ``1 - args.elem_drop`` and zeroed otherwise; the same
    per-element decision is applied to all tensors describing that element.

    Integer feature tensors (``x_atom``, ``x_bond``, ``x_triplet``, ``x_quad``) are
    multiplied by a binary 0/1 mask, so kept rows are passed through unchanged.
    Continuous tensors (``x_atom_pos``, ``x_bond_dist``, ``x_triplet_angle``,
    ``x_quad_angle``) are multiplied by the inverted-dropout mask, i.e. kept entries are
    scaled by ``1 / (1 - elem_drop)``.

    Args:
        x_atom, x_atom_pos: atom features / positions, shape (bsz, N, ...).
        x_bond, x_bond_dist: bond features (bsz, M, ...) and distances (bsz, M, ...).
        x_triplet, x_triplet_angle: triplet features (bsz, P, ...) and angles (bsz, P, ...).
        x_quad, x_quad_angle: quad features / angles (bsz, Q, ...); only touched when
            ``args.use_quad`` is true.
        args: namespace providing ``elem_drop`` (drop probability) and ``use_quad``.

    Returns:
        The eight tensors in the same order. When ``args.elem_drop == 0`` the inputs
        are returned as-is (same objects).

    Raises:
        ValueError: if ``args.elem_drop`` is not in ``[0, 1)`` (a value of 1 would
            otherwise divide by zero and produce NaN masks).
    """
    elem_drop = args.elem_drop
    if elem_drop == 0.0:
        return x_atom, x_atom_pos, x_bond, x_bond_dist, x_triplet, x_triplet_angle, x_quad, x_quad_angle
    if not 0.0 <= elem_drop < 1.0:
        raise ValueError(f"args.elem_drop must be in [0, 1), got {elem_drop!r}")

    keep_prob = 1 - elem_drop
    dev = x_atom.device
    bsz = x_atom.shape[0]

    def sample_masks(length):
        # Sample on the CPU, then move, so the draws do not depend on the target device.
        # Returns (binary keep mask, inverted-dropout scaled mask), both of shape (bsz, length).
        keep = torch.zeros(bsz, length, 1).bernoulli_(keep_prob)[:, :, 0]
        return keep.to(dev), (keep / keep_prob).to(dev)

    atom_keep, atom_scale = sample_masks(x_atom.shape[1])
    bond_keep, bond_scale = sample_masks(x_bond.shape[1])
    trip_keep, trip_scale = sample_masks(x_triplet.shape[1])

    # Integer features use the pure 0/1 mask: casting the *scaled* mask with .long()
    # yields 2, 3, ... once elem_drop >= 0.5 and would corrupt the kept values.
    x_atom = x_atom * _match_rank(atom_keep, x_atom).long()
    x_bond = x_bond * _match_rank(bond_keep, x_bond).long()
    x_triplet = x_triplet * _match_rank(trip_keep, x_triplet).long()

    # NOTE(review): continuous geometry (positions, distances, angles) keeps the original
    # inverted-dropout scaling of kept entries; confirm this is intended for coordinates.
    x_atom_pos = x_atom_pos * _match_rank(atom_scale, x_atom_pos)
    x_bond_dist = x_bond_dist * _match_rank(bond_scale, x_bond_dist)
    x_triplet_angle = x_triplet_angle * _match_rank(trip_scale, x_triplet_angle)

    if args.use_quad:
        # Quads are handled exactly like triplets (assumes x_quad holds integer
        # features like x_triplet -- confirm against the data loader).
        quad_keep, quad_scale = sample_masks(x_quad.shape[1])
        x_quad = x_quad * _match_rank(quad_keep, x_quad).long()
        x_quad_angle = x_quad_angle * _match_rank(quad_scale, x_quad_angle)
    return x_atom, x_atom_pos, x_bond, x_bond_dist, x_triplet, x_triplet_angle, x_quad, x_quad_angle


def _nearest_other_atom(dist_row, center, n_valid_atoms):
    """Return the index of the valid atom closest to ``center``, excluding ``center`` itself.

    ``dist_row`` holds distances from ``center`` to every atom slot; only the first
    ``n_valid_atoms`` slots are considered (assumes valid atoms occupy a prefix --
    the same assumption the caller makes). The self-distance is masked with +inf
    rather than filtered by ``> 0``, so a tiny non-zero self-distance from floating
    point error can never select the centre atom itself.
    """
    candidates = dist_row[:n_valid_atoms].clone()
    candidates[center] = float("inf")
    return int(torch.argmin(candidates).item())


def subgraph_filter(x_atom, x_atom_pos, x_bond, x_bond_dist, x_triplet, x_triplet_angle, args):
    """Cutout augmentation: remove a random atom and its nearest neighbour from some molecules.

    For each molecule, with probability ``args.cutout``, a uniformly random valid atom
    and the valid atom nearest to it (by ``sqdist`` on the first three position
    coordinates) are zeroed, together with every bond and triplet that references
    either of them. Molecules with fewer than 10 valid atoms are left untouched.

    An atom is treated as valid when ``x_atom[..., 0] > 0``; valid atoms are assumed to
    occupy a prefix of the atom axis (TODO confirm against the data loader).
    Columns 3-4 of ``x_bond`` and columns 2-4 of ``x_triplet`` are compared against atom
    indices to find the affected bonds/triplets.

    Args:
        x_atom, x_atom_pos, x_bond, x_bond_dist, x_triplet, x_triplet_angle: batched
            graph tensors; the batch dimension is dim 0.
        args: namespace providing ``cutout`` (per-molecule probability) and ``use_quad``.

    Returns:
        Detached copies of the six input tensors with the cut-out elements zeroed,
        followed by ``bonds_mask`` of shape (bsz, M, 1): 1.0 for kept bonds, 0.0 for
        bonds that were cut out.

    Raises:
        AssertionError: if a molecule is selected for cutout while ``args.use_quad``
            is set (quads are not handled yet).
    """
    # Pairwise squared distances from the original (pre-cutout) coordinates.
    D = sqdist(x_atom_pos[:, :, :3], x_atom_pos[:, :, :3])
    # Work on detached copies so the caller's tensors are never modified in place.
    x_atom, x_atom_pos, x_bond, x_bond_dist, x_triplet, x_triplet_angle = (
        t.detach().clone()
        for t in (x_atom, x_atom_pos, x_bond, x_bond_dist, x_triplet, x_triplet_angle)
    )
    bsz = x_atom.shape[0]
    bonds_mask = torch.ones(bsz, x_bond.shape[1], 1).to(x_atom.device)
    for mol_id in range(bsz):
        if np.random.uniform(0, 1) > args.cutout:
            continue
        assert not args.use_quad, "Quads are NOT cut out yet"
        n_valid_atoms = (x_atom[mol_id, :, 0] > 0).sum().item()
        if n_valid_atoms < 10:
            continue
        # randint's upper bound is exclusive: this samples every valid atom 0..n_valid_atoms-1.
        center = int(np.random.randint(n_valid_atoms))
        neighbor = _nearest_other_atom(D[mol_id, center], center, n_valid_atoms)

        for atom_idx in (center, neighbor):
            x_atom[mol_id, atom_idx] = 0
            x_atom_pos[mol_id, atom_idx] = 0

        # A bond/triplet is removed if any of its atom-index columns names a dropped atom.
        bond_atoms = x_bond[mol_id, :, 3:5]
        bond_pos_to_drop = ((bond_atoms == center) | (bond_atoms == neighbor)).any(dim=-1)
        trip_atoms = x_triplet[mol_id, :, 2:5]
        trip_pos_to_drop = ((trip_atoms == center) | (trip_atoms == neighbor)).any(dim=-1)

        x_bond[mol_id, bond_pos_to_drop] = 0
        x_bond_dist[mol_id, bond_pos_to_drop] = 0
        bonds_mask[mol_id, bond_pos_to_drop] = 0
        x_triplet[mol_id, trip_pos_to_drop] = 0
        x_triplet_angle[mol_id, trip_pos_to_drop] = 0
    return x_atom, x_atom_pos, x_bond, x_bond_dist, x_triplet, x_triplet_angle, bonds_mask