/* * Copyright 2018 org.LTR4L * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.ltr4l.lucene.solr.client; import com.fasterxml.jackson.databind.ObjectMapper; import org.apache.http.HttpEntity; import org.apache.http.HttpResponse; import org.apache.http.client.HttpClient; import org.apache.http.client.methods.HttpGet; import org.apache.http.client.methods.HttpPost; import org.apache.http.entity.StringEntity; import org.apache.http.impl.client.HttpClients; import org.apache.http.util.EntityUtils; import org.ltr4l.click.CMQueryHandler; import org.ltr4l.click.LTRResponse; import org.ltr4l.click.LTRResponseHandler; import org.ltr4l.click.ClickRateClassifier; import java.io.InputStreamReader; import java.nio.charset.StandardCharsets; import java.util.*; public class FeatureExtractor { private final CMQueryHandler cmQueryHandler; private LTRResponseHandler ltrResponseHandler; private final String url; private final String confName; private final String idField; private final long extractionTimeout; private String procId; public FeatureExtractor(String url, String confName, CMQueryHandler cmQueryHandler) { this(url, confName, cmQueryHandler, "id", 10000L); } public FeatureExtractor(String url, String confName, CMQueryHandler cmQueryHandler, Long extractionTimeout) { this(url, confName, cmQueryHandler, "id", extractionTimeout); } public FeatureExtractor(String url, String confName, CMQueryHandler cmQueryHandler, String idField) { this(url, confName, cmQueryHandler, idField, 10000L); } public FeatureExtractor(String url, String confName, CMQueryHandler cmQueryHandler, String idField, Long extractionTimeout) { this.url = url; this.confName = confName; this.cmQueryHandler = cmQueryHandler; this.idField = idField; this.extractionTimeout = extractionTimeout; } public void execute() throws Exception { if( url == null || url.equals("")) return; postTrainingData(); if(isFinished()) { download(); } else { System.err.println("Could not finish feature extraction.\nPlease set longer extraction timeout period or confirm the url and conf name are valid."); } } private void postTrainingData() throws Exception { HttpClient httpClient = HttpClients.createDefault(); CMQueryHandler.CMQueries cmQueries = new CMQueryHandler.CMQueries(cmQueryHandler.getClickRates(), idField); String url = this.url + "?command=extract&conf=" + confName + "&wt=json"; StringEntity trainingJson = new StringEntity(cmQueries.toString(), "UTF-8"); HttpPost httpPost = new HttpPost(url); trainingJson.setContentType("application/json; charset=UTF-8"); httpPost.setEntity(trainingJson); httpPost.addHeader("Content-type", "application/json; charset=UTF-8"); httpPost.addHeader("Accept", "application/json"); HttpResponse response = httpClient.execute(httpPost); HttpEntity entity = response.getEntity(); if (entity != null) { ObjectMapper mapper = new ObjectMapper(); Map<String, Object> entityMap = mapper.readValue(EntityUtils.toString(entity), Map.class); Map<String, Object> results = (Map<String, Object>) entityMap.get("results"); procId = results == null ? null : String.valueOf(results.get("procId")); } } private boolean isFinished() throws Exception { String url = this.url + "?command=progress&procId=" + procId + "&wt=json"; HttpGet httpGet = new HttpGet(url); HttpClient httpClient = HttpClients.createDefault(); int progress = 0; long startTime = System.nanoTime(); long duration = 0; //TODO: smarter code while (progress < 100 && duration < extractionTimeout * 10000) { Thread.sleep(1000); HttpResponse response = httpClient.execute(httpGet); HttpEntity entity = response.getEntity(); if (entity != null) { ObjectMapper mapper = new ObjectMapper(); Map<String, Object> entityMap = mapper.readValue(EntityUtils.toString(entity), Map.class); Map<String, Object> results = (Map<String, Object>) entityMap.get("results"); progress = results == null ? 0 : (Integer)results.get("progress"); } else { return false; } duration = System.nanoTime() - startTime; } return progress == 100; } private void download() throws Exception { String url = this.url + "?command=download&procId=" + procId + "&wt=json"; HttpClient httpClient = HttpClients.createDefault(); HttpGet httpGet = new HttpGet(url); httpGet.addHeader("Content-type", "application/json; charset=UTF-8"); HttpResponse response = httpClient.execute(httpGet); InputStreamReader inputStreamReader = new InputStreamReader(response.getEntity().getContent(), StandardCharsets.UTF_8); ltrResponseHandler = new LTRResponseHandler(inputStreamReader); } public Map<String, LTRResponse.Doc[]> getTrainingData() { return ltrResponseHandler == null ? null : ltrResponseHandler.mergeClickRates(cmQueryHandler.getClickRates()); } //TODO: Large training data may cause OOM, we should parse & write Doc into output file one by one. public String getMSFormatTrainingData(String borderListStr) { if (ltrResponseHandler == null) return null; ClickRateClassifier crc = new ClickRateClassifier(borderListStr); Map<String, LTRResponse.Doc[]> trainingData = ltrResponseHandler.mergeClickRates(cmQueryHandler.getClickRates()); StringBuilder sb = new StringBuilder(); long qid = 0; for (Map.Entry<String, LTRResponse.Doc[]> entry : trainingData.entrySet()) { LTRResponse.Doc[] docs = entry.getValue(); for (LTRResponse.Doc doc : docs) { sb.append(String.valueOf(crc.classify(doc.getClickrate())) + " " + "qid:" + String.valueOf(qid)); double[] features = doc.features; int len = features.length; for (int i = 0; i < len; i++) { sb.append(" " + String.valueOf(i+1) + ":" + String.valueOf(features[i])); } sb.append(" #docid = " + doc.id + "\n"); } qid++; } return sb.toString(); } }