package uclmr.util

import java.io.FileWriter

import uclmr.hack.EntityHackNormalization
import uclmr.{DefaultIx, TensorKB}

/**
 * Predicts the top-scoring unobserved facts for a set of relations from a
 * serialized [[TensorKB]] and writes them as a TSV report.
 *
 * Usage: Predictor [pathToMatrix] [outFile] [relation ...]
 *
 * @author rockt
 */
object Predictor extends App {
  val pathToMatrix = args.lift(0).getOrElse("./data/out/bbc/serialized/")
  val outFile = args.lift(1).getOrElse("./data/out/bbc/predictions.txt")

  // Relations to score; defaults cover location/nationality Freebase relations.
  // BUG FIX: was `args.tail`, which wrongly included the outFile argument
  // (args(1)) as a relation whenever extra relations were passed.
  val relations =
    if (args.size > 2) args.drop(2)
    else Array(
      "REL$/location/administrative_division/country",
      "REL$/base/biblioness/bibs_location/country",
      "REL$/location/location/contains",
      "REL$/people/person/nationality",
      "REL$/base/aareas/schema/administrative_area/administrative_parent",
      "REL$/location/country/first_level_divisions",
      "REL$/location/country/capital"
    )

  println("Loading db...")
  val kb = new TensorKB(100)
  kb.deserialize(pathToMatrix)
  println(kb.toInfoString)

  println("Predicting facts...")
  // For each relation: score every entity tuple that is NOT already a
  // training fact, then sort by descending probability.
  val predictions = relations.map(rel =>
    rel -> kb.keys2
      .filterNot(t => kb.getFact(rel, t, DefaultIx).exists(_.train))
      .map(t => (kb.prob(rel, t), t))
      .sortBy(-_._1)
  ).toMap

  println("Reporting predictions...")
  // NOTE: the original guarded this with `if (true || args.size > 1)`, which
  // is always true; the console-only else branch was unreachable dead code
  // and has been removed. Behavior is unchanged.
  val writer = new FileWriter(outFile)
  try {
    EntityHackNormalization.init()
    predictions.foreach { case (rel, scored) =>
      scored.take(100).foreach { case (score, es) =>
        // Tuples print as "(e1,e2)"; strip the parens and split on the comma.
        // NOTE(review): this breaks if an entity id itself contains a comma
        // — TODO confirm ids are comma-free.
        val Array(e1, e2) = es.toString.tail.init.split(",")
        // Only Freebase mids ("/m/...") get canonicalized; other ids pass through.
        val can1 = if (e1.startsWith("/m/")) EntityHackNormalization.getCanonical(e1) else e1
        val can2 = if (e2.startsWith("/m/")) EntityHackNormalization.getCanonical(e2) else e2
        writer.write(s"$score\t$e1\t$can1\t$e2\t$can2\t$rel\n")
      }
    }
  } finally writer.close() // release the handle even if prediction writing fails
}