Java Code Examples for org.apache.solr.client.solrj.io.Tuple#getDouble()

The following examples show how to use org.apache.solr.client.solrj.io.Tuple#getDouble(). You can vote up the examples you like or vote down the ones you don't, and go to the original project or source file by following the links above each example. Related API usage is listed in the sidebar.
Example 1
Source File: StreamExpressionTest.java    From lucene-solr with Apache License 2.0 5 votes vote down vote up
/**
 * Asserts that the named field of {@code tuple} holds exactly the expected double.
 *
 * <p>The comparison is exact ({@code !=}); no tolerance is applied.
 *
 * @param tuple tuple to read the field from
 * @param fieldName name of the field expected to hold a double
 * @param d expected value
 * @return {@code true} when the values match; the method never returns {@code false}
 * @throws Exception if the field has no value or its value differs from {@code d}
 */
public boolean assertDouble(Tuple tuple, String fieldName, double d) throws Exception {
  // getDouble hands back a boxed Double; check it before unboxing so a missing field
  // is reported by name instead of surfacing as a bare NullPointerException.
  Double dv = tuple.getDouble(fieldName);
  if (dv == null) {
    throw new Exception("Doubles not equal:"+d+" : null (field "+fieldName+" is missing)");
  }
  if (dv.doubleValue() != d) {
    throw new Exception("Doubles not equal:"+d+" : "+dv);
  }

  return true;
}
 
Example 2
Source File: TestSQLHandler.java    From lucene-solr with Apache License 2.0 5 votes vote down vote up
/**
 * Verifies that {@code tuple} carries the expected double under {@code fieldName}.
 *
 * <p>Values are compared exactly, without a tolerance.
 *
 * @param tuple tuple whose field is checked
 * @param fieldName field to read as a double
 * @param d value the field is expected to hold
 * @return always {@code true}; a mismatch is reported by throwing
 * @throws Exception if the field has no value or the value is not equal to {@code d}
 */
public boolean assertDouble(Tuple tuple, String fieldName, double d) throws Exception {
  // Keep the boxed result so an absent field produces a descriptive failure
  // rather than a NullPointerException from implicit unboxing.
  Double dv = tuple.getDouble(fieldName);
  if (dv == null) {
    throw new Exception("Doubles not equal:"+d+" : null (field "+fieldName+" is missing)");
  }
  if (dv.doubleValue() != d) {
    throw new Exception("Doubles not equal:"+d+" : "+dv);
  }

  return true;
}
 
Example 3
Source File: TupleStreamDataSetIterator.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Returns the double stored under {@code key} in {@code tuple}.
 *
 * <p>When the tuple has no value for {@code key}, the situation is logged (including the
 * value under {@code idKey} when one is given) and a {@link NullPointerException} is then
 * raised on purpose.
 *
 * @param tuple tuple to read from
 * @param key field holding the double
 * @param idKey optional field identifying the tuple in the log output; may be {@code null}
 * @return the primitive value of the field
 */
private static double getValue(Tuple tuple, String key, String idKey) {
    final Double boxed = tuple.getDouble(key);
    if (boxed != null) {
      return boxed.doubleValue();
    }

    // No value present: record whatever helps debugging ...
    if (idKey != null) {
      log.info("tuple[{}]={} tuple[{}]={}", key, boxed, idKey, tuple.get(idKey));
    } else {
      log.info("tuple[{}]={}", key, boxed);
    }

    // ... and then dereference the null deliberately so the caller still gets the NullPointerException.
    return boxed.doubleValue();
}
 
Example 4
Source File: TextLogitStream.java    From lucene-solr with Apache License 2.0 4 votes vote down vote up
/**
 * Runs one training iteration and returns a tuple describing the resulting model state.
 *
 * <p>Each call increments the iteration counter. Once it exceeds {@code maxIterations}
 * the EOF tuple is returned. Otherwise the terms are loaded on first use (while
 * {@code idfs} is still null), every shard returned by {@code callShards} is consulted,
 * the per-shard weights are averaged, the per-shard errors are summed and the shard
 * evaluations are merged. The emitted tuple contains the learning rate as it was before
 * this call adjusts it: from the second iteration on, the rate is halved when the error
 * did not drop below the previous iteration's error and multiplied by 1.05 otherwise.
 *
 * @return the tuple for this iteration, or the EOF tuple once all iterations are done
 * @throws IOException if the supplied weights do not match the number of terms plus one,
 *         or wrapping any other exception raised while loading terms or calling the shards
 */
@SuppressWarnings({"unchecked"})
public Tuple read() throws IOException {
  try {

    if(++iteration > maxIterations) {
      return Tuple.EOF();
    } else {

      if (this.idfs == null) {
        loadTerms();

        if (weights != null && terms.size() + 1 != weights.size()) {
          // The format string must have exactly as many specifiers as arguments; the previous
          // text had an extra %s, which made String.format throw instead of building the message.
          throw new IOException(String.format(Locale.ROOT,"invalid expression - the number of weights must be %d, found %d", terms.size()+1, weights.size()));
        }
      }

      List<List<Double>> allWeights = new ArrayList<>();
      this.evaluation = new ClassificationEvaluation();

      this.error = 0;
      for (Future<Tuple> logitCall : callShards(getShardUrls())) {

        // Blocks until the shard has answered.
        Tuple tuple = logitCall.get();
        List<Double> shardWeights = (List<Double>) tuple.get("weights");
        allWeights.add(shardWeights);
        this.error += tuple.getDouble("error");
        @SuppressWarnings({"rawtypes"})
        Map shardEvaluation = (Map) tuple.get("evaluation");
        this.evaluation.addEvaluation(shardEvaluation);
      }

      this.weights = averageWeights(allWeights);
      @SuppressWarnings({"rawtypes"})
      Map map = new HashMap();
      map.put(ID, name+"_"+iteration);
      map.put("name_s", name);
      map.put("field_s", field);
      map.put("terms_ss", terms);
      map.put("iteration_i", iteration);

      if(weights != null) {
        map.put("weights_ds", weights);
      }

      map.put("error_d", error);
      evaluation.putToMap(map);
      map.put("alpha_d", this.learningRate);
      map.put("idfs_ds", this.idfs);

      // There is no previous error to compare against on the first iteration.
      if (iteration != 1) {
        if (lastError <= error) {
          this.learningRate *= 0.5;
        } else {
          this.learningRate *= 1.05;
        }
      }

      lastError = error;

      return new Tuple(map);
    }

  } catch(Exception e) {
    // Waiting on a shard future may be interrupted; keep the interrupt flag set for callers.
    if (e instanceof InterruptedException) {
      Thread.currentThread().interrupt();
    }
    throw new IOException(e);
  }
}
 
Example 5
Source File: StreamingTest.java    From lucene-solr with Apache License 2.0 4 votes vote down vote up
/**
 * Indexes ten documents and checks the sum, min, max, mean and count metrics that a
 * {@code StatsStream} over all of them reports for the {@code a_i} and {@code a_f} fields.
 */
@Test
public void testStatsStream() throws Exception {

  // Ten documents; a_i sums to 70 and a_f to 55.
  new UpdateRequest()
      .add(id, "0", "a_s", "hello0", "a_i", "0", "a_f", "1")
      .add(id, "2", "a_s", "hello0", "a_i", "2", "a_f", "2")
      .add(id, "3", "a_s", "hello3", "a_i", "3", "a_f", "3")
      .add(id, "4", "a_s", "hello4", "a_i", "4", "a_f", "4")
      .add(id, "1", "a_s", "hello0", "a_i", "1", "a_f", "5")
      .add(id, "5", "a_s", "hello3", "a_i", "10", "a_f", "6")
      .add(id, "6", "a_s", "hello4", "a_i", "11", "a_f", "7")
      .add(id, "7", "a_s", "hello3", "a_i", "12", "a_f", "8")
      .add(id, "8", "a_s", "hello3", "a_i", "13", "a_f", "9")
      .add(id, "9", "a_s", "hello0", "a_i", "14", "a_f", "10")
      .commit(cluster.getSolrClient(), COLLECTIONORALIAS);

  StreamContext context = new StreamContext();
  SolrClientCache clientCache = new SolrClientCache();
  context.setSolrClientCache(clientCache);

  try {
    SolrParams matchAll = mapParams("q", "*:*");

    Metric[] requested = {
        new SumMetric("a_i"), new SumMetric("a_f"),
        new MinMetric("a_i"), new MinMetric("a_f"),
        new MaxMetric("a_i"), new MaxMetric("a_f"),
        new MeanMetric("a_i"), new MeanMetric("a_f"),
        new CountMetric()};

    StatsStream stream = new StatsStream(zkHost, COLLECTIONORALIAS, matchAll, requested);
    stream.setStreamContext(context);
    List<Tuple> results = getTuples(stream);

    // The stream is expected to produce a single tuple.
    assertEquals(1, results.size());
    Tuple stats = results.get(0);

    // Read every metric first, then verify them.
    Double intSum = stats.getDouble("sum(a_i)");
    Double floatSum = stats.getDouble("sum(a_f)");
    Double intMin = stats.getDouble("min(a_i)");
    Double floatMin = stats.getDouble("min(a_f)");
    Double intMax = stats.getDouble("max(a_i)");
    Double floatMax = stats.getDouble("max(a_f)");
    Double intMean = stats.getDouble("avg(a_i)");
    Double floatMean = stats.getDouble("avg(a_f)");
    Double docCount = stats.getDouble("count(*)");

    assertEquals(70, intSum.longValue());
    assertEquals(55.0, floatSum.doubleValue(), 0.01);
    assertEquals(0.0, intMin.doubleValue(), 0.01);
    assertEquals(1.0, floatMin.doubleValue(), 0.01);
    assertEquals(14.0, intMax.doubleValue(), 0.01);
    assertEquals(10.0, floatMax.doubleValue(), 0.01);
    assertEquals(7.0, intMean.doubleValue(), 0.01);
    assertEquals(5.5, floatMean.doubleValue(), 0.001);
    assertEquals(10, docCount.doubleValue(), 0.01);
  } finally {
    clientCache.close();
  }
}
 
Example 6
Source File: StreamingTest.java    From lucene-solr with Apache License 2.0 4 votes vote down vote up
/**
 * Indexes one document with single- and multi-valued fields and checks that the typed
 * accessors of the first returned tuple ({@code getString}, {@code getLong},
 * {@code getDouble}, {@code getStrings}, {@code getLongs}, {@code getDoubles}) give back
 * the indexed values.
 */
@Test
public void testTuple() throws Exception {

  new UpdateRequest()
      .add(id, "0", "a_s", "hello0", "a_i", "0", "a_f", "5.1", "s_multi", "a", "s_multi", "b", "i_multi",
          "1", "i_multi", "2", "f_multi", "1.2", "f_multi", "1.3")
      .commit(cluster.getSolrClient(), COLLECTIONORALIAS);

  StreamContext context = new StreamContext();
  SolrClientCache clientCache = new SolrClientCache();
  context.setSolrClientCache(clientCache);

  try {
    SolrParams query = mapParams("q", "*:*", "fl", "id,a_s,a_i,a_f,s_multi,i_multi,f_multi", "sort", "a_s asc");
    CloudSolrStream cloudStream = new CloudSolrStream(zkHost, COLLECTIONORALIAS, query);
    cloudStream.setStreamContext(context);
    Tuple first = getTuples(cloudStream).get(0);

    // Single-valued fields.
    assertEquals("hello0", first.getString("a_s"));
    assertEquals(0, first.getLong("a_i").longValue());
    assertEquals(5.1, first.getDouble("a_f").doubleValue(), 0.001);

    // Multi-valued fields.
    List<String> strings = first.getStrings("s_multi");
    assertEquals("a", strings.get(0));
    assertEquals("b", strings.get(1));

    List<Long> longs = first.getLongs("i_multi");
    assertEquals(1, longs.get(0).longValue());
    assertEquals(2, longs.get(1).longValue());

    List<Double> doubles = first.getDoubles("f_multi");
    assertEquals(1.2, doubles.get(0).doubleValue(), 0.001);
    assertEquals(1.3, doubles.get(1).doubleValue(), 0.001);
  } finally {
    clientCache.close();
  }
}
 
Example 7
Source File: StreamExpressionTest.java    From lucene-solr with Apache License 2.0 4 votes vote down vote up
/**
 * Indexes ten documents spread over three {@code a_s} buckets and runs a
 * {@code rollup(select(drill(...)))} streaming expression through the {@code /stream}
 * handler. The buckets are expected in descending {@code a_s} order, each with the
 * summed document count and the summed {@code a_f} values.
 *
 * <p>Assertions use the (expected, actual) argument order so that failure messages
 * report the values the right way round.
 */
@Test
public void testDrillStream() throws Exception {

  new UpdateRequest()
      .add(id, "0", "a_s", "hello0", "a_i", "0", "a_f", "1")
      .add(id, "2", "a_s", "hello0", "a_i", "2", "a_f", "2")
      .add(id, "3", "a_s", "hello3", "a_i", "3", "a_f", "3")
      .add(id, "4", "a_s", "hello4", "a_i", "4", "a_f", "4")
      .add(id, "1", "a_s", "hello0", "a_i", "1", "a_f", "5")
      .add(id, "5", "a_s", "hello3", "a_i", "10", "a_f", "6")
      .add(id, "6", "a_s", "hello4", "a_i", "11", "a_f", "7")
      .add(id, "7", "a_s", "hello3", "a_i", "12", "a_f", "8")
      .add(id, "8", "a_s", "hello3", "a_i", "13", "a_f", "9")
      .add(id, "9", "a_s", "hello0", "a_i", "14", "a_f", "10")
      .commit(cluster.getSolrClient(), COLLECTIONORALIAS);

  List<Tuple> tuples;

  ModifiableSolrParams paramsLoc = new ModifiableSolrParams();
  String expr = "rollup(select(drill("
      + "                            collection1, "
      + "                            q=\"*:*\", "
      + "                            fl=\"a_s, a_f\", "
      + "                            sort=\"a_s desc\", "
      + "                            rollup(input(), over=\"a_s\", count(*), sum(a_f)))," +
      "                        a_s, count(*) as cnt, sum(a_f) as saf)," +
      "                  over=\"a_s\"," +
      "                  sum(cnt), sum(saf)"
      + ")";
  paramsLoc.set("expr", expr);
  paramsLoc.set("qt", "/stream");

  String url = cluster.getJettySolrRunners().get(0).getBaseUrl().toString()+"/"+COLLECTIONORALIAS;
  TupleStream solrStream = new SolrStream(url, paramsLoc);

  StreamContext context = new StreamContext();
  solrStream.setStreamContext(context);
  tuples = getTuples(solrStream);

  // hello4: two documents, a_f 4 + 7.
  Tuple tuple = tuples.get(0);
  String bucket = tuple.getString("a_s");
  Double count = tuple.getDouble("sum(cnt)");
  Double saf = tuple.getDouble("sum(saf)");

  assertEquals("hello4", bucket);
  assertEquals(2, count.doubleValue(), 0);
  assertEquals(11, saf.doubleValue(), 0);

  // hello3: four documents, a_f 3 + 6 + 8 + 9.
  tuple = tuples.get(1);
  bucket = tuple.getString("a_s");
  count = tuple.getDouble("sum(cnt)");
  saf = tuple.getDouble("sum(saf)");

  assertEquals("hello3", bucket);
  assertEquals(4, count.doubleValue(), 0);
  assertEquals(26, saf.doubleValue(), 0);

  // hello0: four documents, a_f 1 + 2 + 5 + 10.
  tuple = tuples.get(2);
  bucket = tuple.getString("a_s");
  count = tuple.getDouble("sum(cnt)");
  saf = tuple.getDouble("sum(saf)");

  assertEquals("hello0", bucket);
  assertEquals(4, count.doubleValue(), 0);
  assertEquals(18, saf.doubleValue(), 0);

}
 
Example 8
Source File: ModelTupleStreamTest.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Round-trips {@code originalModel} through serialization and a {@code model(...)}
 * streaming expression, and checks that all three agree.
 *
 * <p>The model is written to a temporary file and loaded back. For every input vector
 * from {@code floatsList(numInputs)}, a {@code model(tuple(...))} expression is built and
 * read once. For every output key, the original model's score must match both the
 * restored model's score and the value found in the tuple, within {@code 1e-5}. The
 * stream must then report EOF. Finally the stream's expression and explanation are
 * handed to {@code doToExpressionTest} and {@code doToExplanationTest}.
 *
 * <p>The Solr client cache and each opened stream are closed in {@code finally} blocks
 * so that a failing assertion does not leave them open.
 *
 * @param originalModel model under test
 * @param numInputs number of input keys/values fed to the model
 * @param numOutputs number of output keys expected from the model
 * @throws Exception on any I/O, parsing or stream failure
 */
private void doTest(Model originalModel, int numInputs, int numOutputs) throws Exception {

    final Path tempDirPath = Files.createTempDirectory(null);
    final File tempDirFile = tempDirPath.toFile();
    tempDirFile.deleteOnExit();

    final SolrResourceLoader solrResourceLoader = new SolrResourceLoader(tempDirPath);

    final File tempFile = File.createTempFile("prefix", "suffix", tempDirFile);
    tempFile.deleteOnExit();

    final String serializedModelFileName = tempFile.getPath();

    ModelSerializer.writeModel(originalModel, serializedModelFileName, false);

    final Model restoredModel = ModelGuesser.loadModelGuess(serializedModelFileName);

    final StreamContext streamContext = new StreamContext();
    final SolrClientCache solrClientCache = new SolrClientCache();
    streamContext.setSolrClientCache(solrClientCache);

    try {
      final String[] inputKeys = new String[numInputs];
      final String inputKeysList = fillArray(inputKeys, "input", ",");

      final String[] outputKeys = new String[numOutputs];
      final String outputKeysList = fillArray(outputKeys, "output", ",");

      for (final float[] floats : floatsList(numInputs)) {

        // Build the "key=value,key=value" argument list of the tuple(...) expression.
        final String inputValuesList;
        {
          final StringBuilder sb = new StringBuilder();
          for (int ii=0; ii<inputKeys.length; ++ii) {
            if (0 < ii) sb.append(',');
            sb.append(inputKeys[ii]).append('=').append(floats[ii]);
          }
          inputValuesList = sb.toString();
        }

        final StreamFactory streamFactory = new SolrDefaultStreamFactory()
            .withSolrResourceLoader(solrResourceLoader)
            .withFunctionName("model", ModelTupleStream.class);

        final StreamExpression streamExpression = StreamExpressionParser.parse("model("
          + "tuple(" + inputValuesList + ")"
          + ",serializedModelFileName=\"" + serializedModelFileName + "\""
          + ",inputKeys=\"" + inputKeysList + "\""
          + ",outputKeys=\"" + outputKeysList + "\""
          + ")");

        final TupleStream tupleStream = streamFactory.constructStream(streamExpression);
        tupleStream.setStreamContext(streamContext);

        assertTrue(tupleStream instanceof ModelTupleStream);
        final ModelTupleStream modelTupleStream = (ModelTupleStream)tupleStream;

        modelTupleStream.open();
        try {
          final Tuple tuple1 = modelTupleStream.read();
          assertNotNull(tuple1);
          assertFalse(tuple1.EOF);

          for (int ii=0; ii<outputKeys.length; ++ii)
          {
            final INDArray inputs = Nd4j.create(new float[][] { floats });
            final double originalScore = NetworkUtils.output(originalModel, inputs).getDouble(ii);
            final double restoredScore = NetworkUtils.output(restoredModel, inputs).getDouble(ii);
            assertEquals(
              originalModel.getClass().getSimpleName()+" (originalScore-restoredScore)="+(originalScore-restoredScore),
              originalScore, restoredScore, 1e-5);

            final Double outputValue = tuple1.getDouble(outputKeys[ii]);
            assertNotNull(outputValue);
            final double tupleScore = outputValue.doubleValue();
            assertEquals(
              originalModel.getClass().getSimpleName()+" (originalScore-tupleScore["+ii+"])="+(originalScore-tupleScore),
              originalScore, tupleScore, 1e-5);
          }

          // Only one input tuple was supplied, so the next read must signal end of stream.
          final Tuple tuple2 = modelTupleStream.read();
          assertNotNull(tuple2);
          assertTrue(tuple2.EOF);
        } finally {
          modelTupleStream.close();
        }

        doToExpressionTest(streamExpression,
          modelTupleStream.toExpression(streamFactory),
          inputKeys.length);

        doToExplanationTest(modelTupleStream.toExplanation(streamFactory));
      }
    } finally {
      solrClientCache.close();
    }

}