/* * Copyright (c) [2016-2018] [University of Minnesota] * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in all * copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. 
*/
package org.grouplens.samantha.server.expander;

import com.fasterxml.jackson.databind.node.ObjectNode;
import com.google.common.collect.Lists;
import it.unimi.dsi.fastutil.ints.IntArrayList;
import it.unimi.dsi.fastutil.ints.IntList;
import it.unimi.dsi.fastutil.ints.IntOpenHashSet;
import it.unimi.dsi.fastutil.ints.IntSet;
import org.apache.commons.lang3.StringUtils;
import org.grouplens.samantha.modeler.featurizer.FeatureExtractorUtilities;
import org.grouplens.samantha.modeler.model.AbstractLearningModel;
import org.grouplens.samantha.server.common.ModelService;
import org.grouplens.samantha.server.config.SamanthaConfigService;
import org.grouplens.samantha.server.io.RequestContext;
import play.Configuration;
import play.inject.Injector;

import java.util.List;
import java.util.Map;
import java.util.Random;

/**
 * Entity expander that appends randomly drawn negative items to each entity.
 *
 * <p>For every entity whose {@code itemAttr} text is non-empty, the expander
 * looks up the index of each listed item in the model's {@code itemIndex} key
 * map, draws random indices that are not among them, and appends the items
 * behind those indices to {@code itemAttr} with a label of {@code "0"} appended
 * to {@code labelAttr}. Attributes named in {@code fillInAttrs} are padded with
 * their own last element so they keep the same number of elements.
 *
 * <p>Entities are modified in place.
 */
public class NegativeSamplingExpander implements EntityExpander {
    private final String itemAttr;
    private final String itemIndex;
    private final String keyPrefix;
    private final String labelAttr;
    private final String separator;
    private final String joiner;
    private final List<String> fillInAttrs;
    private final Integer maxIdx;
    private final Integer maxNumSample;
    private final AbstractLearningModel model;
    // One generator per expander instead of a freshly seeded one per sample.
    private final Random random = new Random();

    /**
     * @param itemAttr     entity attribute holding the separated list of true items
     * @param itemIndex    name of the model key map (index space) to sample from
     * @param keyPrefix    attribute name used when composing/decomposing index keys
     * @param labelAttr    entity attribute holding the labels; "0" is appended per sample
     * @param fillInAttrs  attributes to pad with their last element, may be null or empty
     * @param separator    regular expression used to split attribute values
     * @param joiner       string used to join the expanded values back together
     * @param maxIdx       optional exclusive upper bound on sampled indices, may be null
     * @param maxNumSample optional number of sampling attempts per entity, may be null
     * @param model        model providing the key map named by {@code itemIndex}
     */
    public NegativeSamplingExpander(String itemAttr, String itemIndex, String keyPrefix,
                                    String labelAttr, List<String> fillInAttrs,
                                    String separator, String joiner,
                                    Integer maxIdx, Integer maxNumSample,
                                    AbstractLearningModel model) {
        this.separator = separator;
        this.labelAttr = labelAttr;
        this.joiner = joiner;
        this.itemAttr = itemAttr;
        this.fillInAttrs = fillInAttrs;
        this.itemIndex = itemIndex;
        this.keyPrefix = keyPrefix;
        this.model = model;
        this.maxIdx = maxIdx;
        this.maxNumSample = maxNumSample;
    }

    /**
     * Builds an expander from its configuration section.
     *
     * <p>Reads {@code predictorName}, {@code modelName}, {@code itemAttr},
     * {@code itemIndex}, {@code keyPrefix}, {@code labelAttr}, {@code fillInAttrs},
     * {@code separator}, {@code joiner}, {@code maxIdx} and {@code maxNumSample}.
     * When {@code keyPrefix} is absent, the value of {@code itemAttr} is used.
     *
     * @param expanderConfig configuration section of this expander
     * @param injector       injector used to obtain the model and config services
     * @param requestContext context of the current request
     * @return a configured {@link NegativeSamplingExpander}
     */
    public static EntityExpander getExpander(Configuration expanderConfig,
                                             Injector injector,
                                             RequestContext requestContext) {
        ModelService modelService = injector.instanceOf(ModelService.class);
        SamanthaConfigService configService =
                injector.instanceOf(SamanthaConfigService.class);
        // The returned predictor is discarded; the call is kept ahead of the model
        // lookup as in the original (presumably so the model exists before it is
        // fetched -- confirm against SamanthaConfigService).
        configService.getPredictor(expanderConfig.getString("predictorName"), requestContext);
        AbstractLearningModel model = (AbstractLearningModel) modelService.getModel(
                requestContext.getEngineName(), expanderConfig.getString("modelName"));
        String keyPrefix = expanderConfig.getString("keyPrefix");
        if (keyPrefix == null) {
            keyPrefix = expanderConfig.getString("itemAttr");
        }
        return new NegativeSamplingExpander(
                expanderConfig.getString("itemAttr"),
                expanderConfig.getString("itemIndex"),
                keyPrefix,
                expanderConfig.getString("labelAttr"),
                expanderConfig.getStringList("fillInAttrs"),
                expanderConfig.getString("separator"),
                expanderConfig.getString("joiner"),
                expanderConfig.getInt("maxIdx"),
                expanderConfig.getInt("maxNumSample"),
                model);
    }

    /**
     * Draws random indices in {@code [0, maxVal)} that are not in {@code trues}.
     *
     * <p>The number of attempts is {@code maxNumSample} when configured, otherwise
     * the size of {@code trues}. Attempts that hit a true index are dropped, so
     * fewer samples than attempts may be returned; the same index may be
     * returned more than once.
     *
     * @param trues  indices that must not be sampled
     * @param maxVal exclusive upper bound of the sampled indices
     * @return the sampled indices; empty when {@code maxVal} is not positive
     */
    private IntList getSampledIndices(IntSet trues, int maxVal) {
        IntList samples = new IntArrayList();
        if (maxVal <= 0) {
            // Random.nextInt rejects a non-positive bound; nothing can be sampled.
            return samples;
        }
        int attempts = (maxNumSample != null) ? maxNumSample : trues.size();
        for (int i = 0; i < attempts; i++) {
            int candidate = random.nextInt(maxVal);
            if (!trues.contains(candidate)) {
                samples.add(candidate);
            }
        }
        return samples;
    }

    /**
     * Appends sampled negative items to every entity with a non-empty item list.
     *
     * <p>Sampled indices are bounded by the size of the {@code itemIndex} key map,
     * or by {@code maxIdx} when that is configured and smaller.
     *
     * @param initialResult  entities to expand; modified in place
     * @param requestContext context of the current request (not used here)
     * @return the same list instance that was passed in
     */
    public List<ObjectNode> expand(List<ObjectNode> initialResult,
                                   RequestContext requestContext) {
        int indexSize = model.getKeyMapSize(itemIndex);
        int maxVal = indexSize;
        if (maxIdx != null && maxIdx < indexSize) {
            maxVal = maxIdx;
        }
        for (ObjectNode entity : initialResult) {
            // NOTE(review): assumes itemAttr is present on every entity; a missing
            // field would presumably fail here with a NullPointerException.
            String itemStr = entity.get(itemAttr).asText();
            if (!"".equals(itemStr)) {
                expandEntity(entity, itemStr, maxVal);
            }
        }
        return initialResult;
    }

    /** Samples negatives for one entity and rewrites its item, label and fill-in attributes. */
    private void expandEntity(ObjectNode entity, String itemStr, int maxVal) {
        IntList samples = getSampledIndices(collectTrueIndices(itemStr), maxVal);
        List<String> itemArr = Lists.newArrayList(itemStr);
        List<String> labelArr = Lists.newArrayList(entity.get(labelAttr).asText());
        for (int sample : samples) {
            Map<String, String> key2val = FeatureExtractorUtilities.decomposeKey(
                    (String) model.getKeyForIndex(itemIndex, sample));
            itemArr.add(key2val.get(keyPrefix));
            labelArr.add("0");
        }
        if (fillInAttrs != null && fillInAttrs.size() > 0 && samples.size() > 0) {
            for (String fillAttr : fillInAttrs) {
                String fillStr = entity.get(fillAttr).asText();
                String[] fills = fillStr.split(separator, -1);
                // Repeat the last element once per sample.
                String last = fills[fills.length - 1];
                List<String> fillEls = Lists.newArrayList(fillStr);
                for (int j = 0; j < samples.size(); j++) {
                    fillEls.add(last);
                }
                entity.put(fillAttr, StringUtils.join(fillEls, joiner));
            }
        }
        entity.put(labelAttr, StringUtils.join(labelArr, joiner));
        entity.put(itemAttr, StringUtils.join(itemArr, joiner));
    }

    /** Maps each item in the separated list to its model index; unknown items map to index 0. */
    private IntSet collectTrueIndices(String itemStr) {
        IntSet trues = new IntOpenHashSet();
        for (String item : itemStr.split(separator, -1)) {
            String key = FeatureExtractorUtilities.composeKey(keyPrefix, item);
            if (model.containsKey(itemIndex, key)) {
                trues.add(model.getIndexForKey(itemIndex, key));
            } else {
                trues.add(0);
            }
        }
        return trues;
    }
}