package org.tensorframes.dsl

import java.io.{BufferedReader, InputStreamReader, File}
import java.nio.file.Files
import java.nio.charset.StandardCharsets

import org.tensorframes.Logging
import org.scalatest.Matchers

import scala.collection.JavaConverters._

/**
 * Test helper that compares graph nodes built through the Scala DSL with the nodes
 * produced by an equivalent Python/TensorFlow snippet.
 */
object ExtractNodes extends Matchers with Logging {

  /**
   * Runs a Python snippet and returns the text dump of every node of the default
   * TensorFlow graph, keyed by node name.
   *
   * The snippet is wrapped in a small script that prints, for each node, a marker line
   * `>>>>> <name> <<<<<<` followed by the printed node. The script is written to a
   * temporary file, executed with the `python` found in the PATH (it needs to have
   * TensorFlow installed), and its standard output is parsed.
   *
   * @param py Python statements that populate the default TensorFlow graph
   * @return a map from node name to the printed form of the node
   */
  def executeCommand(py: String): Map[String, String] = {
    val content =
      s"""
         |from __future__ import print_function
         |import tensorflow as tf
         |
         |$py
         |g = tf.get_default_graph().as_graph_def()
         |for n in g.node:
         |    print(">>>>>", str(n.name), "<<<<<<")
         |    print(n)
       """.stripMargin
    val f = File.createTempFile("pythonTest", ".py")
    logTrace(s"Created temp file ${f.getAbsolutePath}")
    try {
      Files.write(f.toPath, content.getBytes(StandardCharsets.UTF_8))
      // Using the standard python installation in the PATH. It needs to have TensorFlow installed.
      // stderr is forwarded to this process instead of being left as an unread pipe:
      // diagnostics from python become visible, and a chatty child cannot stall on a
      // full pipe while we are only reading its stdout. It is deliberately NOT merged
      // into stdout, which must contain only the node dump parsed below.
      val p = new ProcessBuilder("python", f.getAbsolutePath)
        .redirectError(ProcessBuilder.Redirect.INHERIT)
        .start()
      val br = new BufferedReader(new InputStreamReader(p.getInputStream))
      val output =
        try {
          // readLine() returns null at end of stream.
          Iterator.continually(br.readLine()).takeWhile(_ != null).mkString("\n")
        } finally {
          br.close()
        }
      val exitCode = p.waitFor()
      // The script is part of the clue, so a failure report shows exactly what was run.
      assert(exitCode === 0, (exitCode, s"===========\n$content\n==========="))
      parseNodes(output)
    } finally {
      // The script is only needed while the process runs; do not leave it behind.
      f.delete()
    }
  }

  /**
   * Splits the standard output of the generated script into one entry per node.
   *
   * Each chunk starts with the remainder of the marker line (`<name> <<<<<<`); the
   * following lines are the printed node.
   */
  private def parseNodes(output: String): Map[String, String] = {
    output.split(">>>>>").map(_.trim).filterNot(_.isEmpty).map { chunk =>
      val lines = chunk.split("\n")
      // Drop the trailing " <<<<<<" (one space and six '<') to keep only the node name.
      val node = lines.head.dropRight(7)
      node -> lines.tail.mkString("\n")
    }.toMap
  }

  /**
   * Asserts that the graph built from `nodes` has the same node names, with the same
   * printed form, as the graph produced by the Python snippet `py`.
   *
   * @param py    Python statements building the reference graph (see [[executeCommand]])
   * @param nodes the operations to build the Scala graph from; the first one is passed
   *              separately to the graph builder, so at least one is expected
   */
  def compareOutput(py: String, nodes: Operation*): Unit = {
    val g = TestUtilities.buildGraph(nodes.head, nodes.tail: _*)
    val m1 = g.getNodeList.asScala.map { n =>
      n.getName -> n.toString.trim
    }.toMap
    val pym = executeCommand(py)
    logTrace(s"m1 = '$m1'")
    logTrace(s"pym = '$pym'")

    // Check the node names in both directions before comparing node contents.
    val scalaOnly = (m1.keySet -- pym.keySet).toSeq.sorted
    assert(scalaOnly.isEmpty, s"Found extra nodes in scala: $scalaOnly")
    val pythonOnly = (pym.keySet -- m1.keySet).toSeq.sorted
    assert(pythonOnly.isEmpty, s"Found extra nodes in python: $pythonOnly")

    for (k <- m1.keySet) {
      assert(m1(k) === pym(k), s"scala=${m1(k)}\npython=${pym(k)}")
    }
  }
}