package edu.sdsc.mmtf.spark.mappers; import java.util.ArrayList; import java.util.HashMap; import java.util.HashSet; import java.util.Iterator; import java.util.List; import java.util.Map; import javax.vecmath.Point3d; import org.apache.spark.api.java.function.PairFlatMapFunction; import org.biojava.nbio.structure.symmetry.geometry.DistanceBox; import org.rcsb.mmtf.api.StructureDataInterface; import org.rcsb.mmtf.decoder.DecoderUtils; import org.rcsb.mmtf.encoder.AdapterToStructureData; import org.rcsb.mmtf.encoder.EncoderUtils; import edu.sdsc.mmtf.spark.filters.ContainsPolymerChainType; import scala.Tuple2; import javax.vecmath.*; /** * TODO * @author Yue Yu */ public class StructureToProteinDimers implements PairFlatMapFunction<Tuple2<String,StructureDataInterface>,String, StructureDataInterface> { private static final long serialVersionUID = 590629701792189982L; private double cutoffDistance = 8.0; private int contacts = 20; private boolean useAllAtoms = false; private boolean exclusive = false; /** * * */ public StructureToProteinDimers() {} /** * * */ public StructureToProteinDimers(double cutoffDistance, int contacts) { this.cutoffDistance = cutoffDistance; this.contacts = contacts; } /** * * */ public StructureToProteinDimers(double cutoffDistance, int contacts, boolean useAllAtoms, boolean exclusive) { this.cutoffDistance = cutoffDistance; this.contacts = contacts; this.useAllAtoms = useAllAtoms; this.exclusive = exclusive; } @Override public Iterator<Tuple2<String, StructureDataInterface>> call(Tuple2<String, StructureDataInterface> t) throws Exception { StructureDataInterface structure = t._2; List<Tuple2<String, StructureDataInterface>> resList = new ArrayList<>(); //split the structure into a list of structure of chains List<StructureDataInterface> chains = splitToChains(structure); List<Vector3d> chainVectors = getChainVectors(chains); //for each chain, create a distance box List<DistanceBox<Integer>> boxes; if(useAllAtoms == true) boxes = getAllAtomsDistanceBoxes(chains, cutoffDistance); else boxes = getCBetaAtomsDistanceBoxes(chains, cutoffDistance); List<Vector3d> exclusiveList = new ArrayList<Vector3d>(); //loop through chains for(int i = 0; i < chains.size(); i++) { for(int j = 0; j < i; j++) { //for each pair of chains, check if they are in contact or not if(checkPair(boxes.get(i), boxes.get(j), chains.get(i), chains.get(j), cutoffDistance, contacts)) { if(exclusive) { // String es1 = chains.get(i).getEntitySequence(getChainToEntityIndex(chains.get(i))[0]); // String es2 = chains.get(j).getEntitySequence(getChainToEntityIndex(chains.get(j))[0]); Vector3d newVec = calcDiff(chainVectors.get(i), chainVectors.get(j)); // System.out.println(newVec); if(!checkList(newVec, exclusiveList)) { resList.add(combineChains(chains.get(i), chains.get(j))); exclusiveList.add(newVec); } } else resList.add(combineChains(chains.get(i), chains.get(j))); } } } // System.out.println(exclusiveList); return resList.iterator(); } private static double distance(StructureDataInterface s1, StructureDataInterface s2, Integer index1, Integer index2 ) { double xCoord = s1.getxCoords()[index1]; double yCoord = s1.getyCoords()[index1]; double zCoord = s1.getzCoords()[index1]; Point3d newPoint1 = new Point3d(xCoord, yCoord, zCoord); xCoord = s2.getxCoords()[index2]; yCoord = s2.getyCoords()[index2]; zCoord = s2.getzCoords()[index2]; Point3d newPoint2 = new Point3d(xCoord, yCoord, zCoord); return newPoint1.distance(newPoint2); } private static boolean checkPair(DistanceBox<Integer> box1, DistanceBox<Integer> box2, StructureDataInterface s1, StructureDataInterface s2, double cutoffDistance, int contacts) { List<Integer> pointsInBox2= box1.getIntersection(box2); List<Integer> pointsInBox1= box2.getIntersection(box1); HashSet<Integer> hs1 = new HashSet<Integer>(); HashSet<Integer> hs2 = new HashSet<Integer>(); int num = 0; for(int i = 0; i < pointsInBox2.size(); i++) { for(int j = 0; j < pointsInBox1.size(); j++) { if(hs1.contains(i) || hs2.contains(j)) continue; if(distance(s1, s2, pointsInBox2.get(i), pointsInBox1.get(j)) < cutoffDistance) { num++; hs1.add(i); hs2.add(j); } if(num > contacts) return true; } } return false; } private static boolean checkList(Vector3d vec, List<Vector3d> exclusiveList) { for(int i = 0; i < exclusiveList.size(); i++) { if(calcDiff(vec, exclusiveList.get(i)).length() < 0.1 && vec.angle(exclusiveList.get(i)) < 0.1) return true; vec.negate(); if(calcDiff(vec, exclusiveList.get(i)).length() < 0.1 && vec.angle(exclusiveList.get(i)) < 0.1) return true; vec.negate(); } return false; } private static List<Vector3d> getChainVectors(List<StructureDataInterface> chains) { List<Vector3d> chainVectors = new ArrayList<Vector3d>(); for(int i = 0; i < chains.size(); i++) { chainVectors.add(calcAverageVec(chains.get(i))); } return chainVectors; } private static Vector3d calcAverageVec(StructureDataInterface s1) { double totX = 0; double totY = 0; double totZ = 0; for(int i = 0; i <s1.getNumAtoms(); i++) { totX += s1.getxCoords()[i]; totY += s1.getyCoords()[i]; totZ += s1.getzCoords()[i]; } return new Vector3d(totX/s1.getNumAtoms(), totY/s1.getNumAtoms(), totZ/s1.getNumAtoms()); } private static Vector3d calcDiff(Vector3d v1, Vector3d v2) { v1.sub(v2); return v1; } private static List<DistanceBox<Integer>> getAllAtomsDistanceBoxes(List<StructureDataInterface> chains, double cutoffDistance) { List<DistanceBox<Integer>> distanceBoxes = new ArrayList<DistanceBox<Integer>>(); for(int i = 0; i < chains.size(); i++) { StructureDataInterface tmp = chains.get(i); DistanceBox<Integer> newBox = new DistanceBox<Integer>(cutoffDistance); //System.out.println(tmp.getNumAtoms()); for(int j = 0; j <tmp.getNumAtoms(); j++) { double xCoord = tmp.getxCoords()[j]; double yCoord = tmp.getyCoords()[j]; double zCoord = tmp.getzCoords()[j]; Point3d newPoint = new Point3d(xCoord, yCoord, zCoord); //System.out.println(newPoint); newBox.addPoint(newPoint, j); } distanceBoxes.add(newBox); } return distanceBoxes; } private static List<DistanceBox<Integer>> getCBetaAtomsDistanceBoxes(List<StructureDataInterface> chains, double cutoffDistance) { List<DistanceBox<Integer>> distanceBoxes = new ArrayList<DistanceBox<Integer>>(); for(int i = 0; i < chains.size(); i++) { StructureDataInterface tmp = chains.get(i); DistanceBox<Integer> newBox = new DistanceBox<Integer>(cutoffDistance); int groupIndex = 0; int atomIndex = 0; for (int k = 0; k < tmp.getGroupsPerChain()[0]; k++) { int groupType = tmp.getGroupTypeIndices()[groupIndex]; for (int m = 0; m < tmp.getNumAtomsInGroup(groupType); m++) { String atomName = tmp.getGroupAtomNames(groupType)[m]; if(atomName.equals("CB")) { double xCoord = tmp.getxCoords()[atomIndex]; double yCoord = tmp.getyCoords()[atomIndex]; double zCoord =tmp.getzCoords()[atomIndex]; Point3d newPoint = new Point3d(xCoord, yCoord, zCoord); newBox.addPoint(newPoint, atomIndex); } atomIndex++; } groupIndex++; } distanceBoxes.add(newBox); } return distanceBoxes; } private static List<StructureDataInterface> splitToChains(StructureDataInterface s) { List<StructureDataInterface> chains = new ArrayList<StructureDataInterface>(); int numChains = s.getChainsPerModel()[0]; int[] chainToEntityIndex = getChainToEntityIndex(s); int[] atomsPerChain = new int[numChains]; int[] bondsPerChain = new int[numChains]; getNumAtomsAndBonds(s, atomsPerChain, bondsPerChain); for (int i = 0, atomCounter = 0, groupCounter = 0; i < numChains; i++){ AdapterToStructureData newChain = new AdapterToStructureData(); int entityToChainIndex = chainToEntityIndex[i]; Map<Integer, Integer> atomMap = new HashMap<>(); // to avoid of information loss, add chainName/IDs and entity id // this required by some queries String structureId = s.getStructureId() + "." + s.getChainNames()[i] + "." + s.getChainIds()[i] + "." + (entityToChainIndex+1); // set header newChain.initStructure(bondsPerChain[i], atomsPerChain[i], s.getGroupsPerChain()[i], 1, 1, structureId); DecoderUtils.addXtalographicInfo(s, newChain); DecoderUtils.addHeaderInfo(s, newChain); // set model info (only one model: 0) newChain.setModelInfo(0, 1); // set entity and chain info newChain.setEntityInfo(new int[]{0}, s.getEntitySequence(entityToChainIndex), s.getEntityDescription(entityToChainIndex), s.getEntityType(entityToChainIndex)); newChain.setChainInfo(s.getChainIds()[i], s.getChainNames()[i], s.getGroupsPerChain()[i]); for (int j = 0; j < s.getGroupsPerChain()[i]; j++, groupCounter++){ int groupIndex = s.getGroupTypeIndices()[groupCounter]; // set group info newChain.setGroupInfo(s.getGroupName(groupIndex), s.getGroupIds()[groupCounter], s.getInsCodes()[groupCounter], s.getGroupChemCompType(groupIndex), s.getNumAtomsInGroup(groupIndex), s.getGroupBondOrders(groupIndex).length, s.getGroupSingleLetterCode(groupIndex), s.getGroupSequenceIndices()[groupCounter], s.getSecStructList()[groupCounter]); for (int k = 0; k < s.getNumAtomsInGroup(groupIndex); k++, atomCounter++){ newChain.setAtomInfo(s.getGroupAtomNames(groupIndex)[k], s.getAtomIds()[atomCounter], s.getAltLocIds()[atomCounter], s.getxCoords()[atomCounter], s.getyCoords()[atomCounter], s.getzCoords()[atomCounter], s.getOccupancies()[atomCounter], s.getbFactors()[atomCounter], s.getGroupElementNames(groupIndex)[k], s.getGroupAtomCharges(groupIndex)[k]); } // add intra-group bond info for (int l = 0; l < s.getGroupBondOrders(groupIndex).length; l++) { int bondIndOne = s.getGroupBondIndices(groupIndex)[l*2]; int bondIndTwo = s.getGroupBondIndices(groupIndex)[l*2+1]; int bondOrder = s.getGroupBondOrders(groupIndex)[l]; newChain.setGroupBond(bondIndOne, bondIndTwo, bondOrder); } } // Add inter-group bond info for(int ii = 0; ii < s.getInterGroupBondOrders().length; ii++){ int bondIndOne = s.getInterGroupBondIndices()[ii*2]; int bondIndTwo = s.getInterGroupBondIndices()[ii*2+1]; int bondOrder = s.getInterGroupBondOrders()[ii]; Integer indexOne = atomMap.get(bondIndOne); if (indexOne != null) { Integer indexTwo = atomMap.get(bondIndTwo); if (indexTwo != null) { newChain.setInterGroupBond(indexOne, indexTwo, bondOrder); } } } newChain.finalizeStructure(); if(EncoderUtils.getTypeFromChainId(newChain, 0).equals("polymer")) { boolean match = true; for (int j = 0; j < newChain.getGroupsPerChain()[0]; j++) { if (match) { int groupIndex = newChain.getGroupTypeIndices()[j]; String type = newChain.getGroupChemCompType(groupIndex); //System.out.println(j + " " + type); match = type.equals(ContainsPolymerChainType.L_PEPTIDE_LINKING) || type.equals(ContainsPolymerChainType.PEPTIDE_LINKING); } } if(match) chains.add(newChain); } } return chains; } /** * A method that takes two structure of chains and return a single structur of two chains. */ private static Tuple2<String, StructureDataInterface> combineChains(StructureDataInterface s1, StructureDataInterface s2) { int groupCounter = 0; int atomCounter = 0; String structureId = s1.getStructureId() + "_append_" + s2.getStructureId(); AdapterToStructureData combinedStructure = new AdapterToStructureData(); combinedStructure.initStructure(s1.getNumBonds() + s2.getNumBonds(), s1.getNumAtoms() + s2.getNumAtoms(), s1.getNumGroups() + s2.getNumGroups(), 2, 1, structureId); DecoderUtils.addXtalographicInfo(s1, combinedStructure); DecoderUtils.addHeaderInfo(s1, combinedStructure); combinedStructure.setModelInfo(0, 2); // set entity and chain info combinedStructure.setEntityInfo(new int[]{0}, s1.getEntitySequence(getChainToEntityIndex(s1)[0]), s1.getEntityDescription(getChainToEntityIndex(s1)[0]), s1.getEntityType(getChainToEntityIndex(s1)[0])); combinedStructure.setChainInfo(s1.getChainIds()[0], s1.getChainNames()[0], s1.getGroupsPerChain()[0]); for (int i = 0; i < s1.getGroupsPerChain()[0]; i++, groupCounter++){ int groupIndex = s1.getGroupTypeIndices()[groupCounter]; // set group info combinedStructure.setGroupInfo(s1.getGroupName(groupIndex), s1.getGroupIds()[groupCounter], s1.getInsCodes()[groupCounter], s1.getGroupChemCompType(groupIndex), s1.getNumAtomsInGroup(groupIndex), s1.getGroupBondOrders(groupIndex).length, s1.getGroupSingleLetterCode(groupIndex), s1.getGroupSequenceIndices()[groupCounter], s1.getSecStructList()[groupCounter]); for (int j = 0; j < s1.getNumAtomsInGroup(groupIndex); j++, atomCounter++){ combinedStructure.setAtomInfo(s1.getGroupAtomNames(groupIndex)[j], s1.getAtomIds()[atomCounter], s1.getAltLocIds()[atomCounter], s1.getxCoords()[atomCounter], s1.getyCoords()[atomCounter], s1.getzCoords()[atomCounter], s1.getOccupancies()[atomCounter], s1.getbFactors()[atomCounter], s1.getGroupElementNames(groupIndex)[j], s1.getGroupAtomCharges(groupIndex)[j]); } //TODO : not sure if we should add bonds like this. for (int j = 0; j < s1.getGroupBondOrders(groupIndex).length; j++) { int bondIndOne = s1.getGroupBondIndices(groupIndex)[j*2]; int bondIndTwo = s1.getGroupBondIndices(groupIndex)[j*2+1]; int bondOrder = s1.getGroupBondOrders(groupIndex)[j]; combinedStructure.setGroupBond(bondIndOne, bondIndTwo, bondOrder); } } // set entity and chain info combinedStructure.setEntityInfo(new int[]{1}, s1.getEntitySequence(getChainToEntityIndex(s2)[0]), s2.getEntityDescription(getChainToEntityIndex(s2)[0]), s2.getEntityType(getChainToEntityIndex(s2)[0])); combinedStructure.setChainInfo(s2.getChainIds()[0], s2.getChainNames()[0], s2.getGroupsPerChain()[0]); groupCounter = 0; atomCounter = 0; for (int i = 0; i < s2.getGroupsPerChain()[0]; i++, groupCounter++){ int groupIndex = s2.getGroupTypeIndices()[groupCounter]; // set group info combinedStructure.setGroupInfo(s2.getGroupName(groupIndex), s2.getGroupIds()[groupCounter], s2.getInsCodes()[groupCounter], s2.getGroupChemCompType(groupIndex), s2.getNumAtomsInGroup(groupIndex), s2.getGroupBondOrders(groupIndex).length, s2.getGroupSingleLetterCode(groupIndex), s2.getGroupSequenceIndices()[groupCounter], s2.getSecStructList()[groupCounter]); for (int j = 0; j < s2.getNumAtomsInGroup(groupIndex); j++, atomCounter++){ combinedStructure.setAtomInfo(s2.getGroupAtomNames(groupIndex)[j], s2.getAtomIds()[atomCounter], s2.getAltLocIds()[atomCounter], s2.getxCoords()[atomCounter], s2.getyCoords()[atomCounter], s2.getzCoords()[atomCounter], s2.getOccupancies()[atomCounter], s2.getbFactors()[atomCounter], s2.getGroupElementNames(groupIndex)[j], s2.getGroupAtomCharges(groupIndex)[j]); } //TODO : not sure if we should add bonds like this. for (int j = 0; j < s2.getGroupBondOrders(groupIndex).length; j++) { int bondIndOne = s2.getGroupBondIndices(groupIndex)[j*2]; int bondIndTwo = s2.getGroupBondIndices(groupIndex)[j*2+1]; int bondOrder = s2.getGroupBondOrders(groupIndex)[j]; combinedStructure.setGroupBond(bondIndOne, bondIndTwo, bondOrder); } } combinedStructure.finalizeStructure(); return (new Tuple2<String, StructureDataInterface>(structureId, combinedStructure)); } /** * Gets the number of atoms and bonds per chain. */ private static void getNumAtomsAndBonds(StructureDataInterface structure, int[] atomsPerChain, int[] bondsPerChain) { int numChains = structure.getChainsPerModel()[0]; for (int i = 0, groupCounter = 0; i < numChains; i++){ for (int j = 0; j < structure.getGroupsPerChain()[i]; j++, groupCounter++){ int groupIndex = structure.getGroupTypeIndices()[groupCounter]; atomsPerChain[i] += structure.getNumAtomsInGroup(groupIndex); bondsPerChain[i] += structure.getGroupBondOrders(groupIndex).length; } } } /** * Returns an array that maps a chain index to an entity index. * @param structureDataInterface * @return */ private static int[] getChainToEntityIndex(StructureDataInterface structure) { int[] entityChainIndex = new int[structure.getNumChains()]; for (int i = 0; i < structure.getNumEntities(); i++) { for (int j: structure.getEntityChainIndexList(i)) { entityChainIndex[j] = i; } } return entityChainIndex; } }