package regression; import org.junit.AfterClass; import org.junit.Assert; import org.junit.BeforeClass; import org.junit.Test; import org.neo4j.graphdb.GraphDatabaseService; import org.neo4j.kernel.impl.proc.Procedures; import org.neo4j.kernel.internal.GraphDatabaseAPI; import org.neo4j.test.TestGraphDatabaseFactory; import java.io.BufferedReader; import java.io.FileReader; import java.io.IOException; import java.util.*; import static org.neo4j.helpers.collection.MapUtil.map; import org.apache.mahout.common.RandomUtils; public class LogisticTest { private static GraphDatabaseService db; //TODO: larger data set, correct random function @BeforeClass public static void setUp() throws Exception { db = new TestGraphDatabaseFactory().newImpermanentDatabase(); Procedures procedures = ((GraphDatabaseAPI) db).getDependencyResolver().resolveDependency(Procedures.class); procedures.registerProcedure(Logistic.class); procedures.registerFunction(Logistic.class); } @AfterClass public static void tearDown() throws Exception { db.shutdown(); } @Test public void makeModel() throws Exception { String csvFile = "/Users/laurenshin/documents/linreg-graph-analytics/src/test/resources/iris-full.csv"; String line = ""; String csvSplitBy = ","; List<Map<String,Double>> data = new ArrayList<>(); List<String> target = new ArrayList<>(); List<Integer> order = new ArrayList<>(); /*Map<String, Integer> stringToInt = new HashMap<>(); Map<Integer, String> intToString = new HashMap<>(); stringToInt.put("Iris-setosa", 0); stringToInt.put("Iris-versicolor", 1); stringToInt.put("Iris-virginica", 2); intToString.put(0, "Iris-setosa"); intToString.put(1, "Iris-versicolor"); intToString.put(2, "Iris-virginica");*/ try (BufferedReader br = new BufferedReader(new FileReader(csvFile))){ br.readLine(); //skip headers int i = 0; while ((line = br.readLine()) != null) { String[] flower = line.split(csvSplitBy); Map<String, Double> v = new HashMap<>(4); v.put("sepallength", Double.parseDouble(flower[1])); //sepal length v.put("sepalwidth", Double.parseDouble(flower[2])); //sepal width v.put("petallength", Double.parseDouble(flower[3])); //petal length v.put("petalwidth", Double.parseDouble(flower[4])); //petal width data.add(v); target.add(flower[5]); //class order.add(i++); } } catch (IOException e) { e.printStackTrace(); Assert.fail("unable to read csv file for test data"); } RandomUtils.useTestSeed(); Random random = RandomUtils.getRandom(); Collections.shuffle(order, random); List<Integer> train = order.subList(0, 100); List<Integer> test = order.subList(100, 150); db.execute("CALL regression.logistic.create('model', ['Iris-setosa', 'Iris-versicolor', 'Iris-virginica'], " + "{sepallength:'float', sepalwidth:'float', petallength:'float', petalwidth:'float'}, {prior:'L2'})").close(); for (int pass = 0; pass < 30; pass++) { Collections.shuffle(train, random); for (int j : train) { db.execute("CALL regression.logistic.add('model', {output}, {inputs})", map("inputs", data.get(j), "output", target.get(j))); } } int successes = 0; int failures = 0; for (int k : test) { String t; String guess = ((String) db.execute("RETURN regression.logistic.predict('model', {inputs}) as prediction", map("inputs", data.get(k))).next().get("prediction")); if (guess.equals(target.get(k))) { t = "SUCCESS!"; successes++; } else { t = "FAIL!"; failures++; } System.out.format("Expected: %s, Actual: %s %s%n", target.get(k), guess, t); } System.out.format("SUCCESSES: %d%n", successes); System.out.format("FAILURES: %d%n", failures); db.execute("CALL regression.logistic.delete('model')"); } }