/* * Copyright (c) [2016-2018] [University of Minnesota] * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in all * copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. 
*/
package org.grouplens.samantha.server.predictor;

import org.grouplens.samantha.modeler.featurizer.FeatureExtractor;
import org.grouplens.samantha.modeler.common.LearningData;
import org.grouplens.samantha.modeler.model.SpaceMode;
import org.grouplens.samantha.modeler.tree.RegressionTree;
import org.grouplens.samantha.modeler.tree.RegressionTreeProducer;
import org.grouplens.samantha.modeler.tree.TreeLearningMethod;
import org.grouplens.samantha.server.common.AbstractModelManager;
import org.grouplens.samantha.server.common.ModelManager;
import org.grouplens.samantha.server.config.ConfigKey;
import org.grouplens.samantha.server.expander.ExpanderUtilities;
import org.grouplens.samantha.server.featurizer.FeatureExtractorConfig;
import org.grouplens.samantha.server.featurizer.FeatureExtractorListConfigParser;
import org.grouplens.samantha.server.featurizer.FeaturizerConfigParser;
import org.grouplens.samantha.server.expander.EntityExpander;
import org.grouplens.samantha.server.io.RequestContext;
import play.Configuration;
import play.inject.Injector;

import java.util.ArrayList;
import java.util.List;

/**
 * {@link PredictorConfig} for a predictor backed by a {@link RegressionTree}.
 *
 * <p>Instances are obtained through {@link #getPredictorConfig(Configuration, Injector)},
 * which reads the settings from a predictor {@link Configuration}. A call to
 * {@link #getPredictor(RequestContext)} then hands the tree model to a
 * {@link PredictiveModelBasedPredictor}.
 */
public class RegressionTreePredictorConfig implements PredictorConfig {

    // Model identity.
    private final String modelName;
    private final String modelFile;

    // Featurization and learning-target settings.
    private final List<FeatureExtractorConfig> feaExtConfigs;
    private final List<String> features;
    private final String labelName;
    private final String weightName;

    // Data access and entity expansion settings.
    private final Configuration daoConfigs;
    private final List<Configuration> expandersConfig;
    private final String daoConfigKey;

    // Collaborators and the raw predictor configuration.
    private final Injector injector;
    private final TreeLearningMethod method;
    private final Configuration config;

    /** Stores the already-parsed settings; use {@link #getPredictorConfig} to create instances. */
    private RegressionTreePredictorConfig(String modelName,
                                          List<FeatureExtractorConfig> feaExtConfigs,
                                          List<String> features,
                                          String labelName,
                                          String weightName,
                                          Configuration daoConfigs,
                                          List<Configuration> expandersConfig,
                                          Injector injector,
                                          TreeLearningMethod method,
                                          String modelFile,
                                          String daoConfigKey,
                                          Configuration config) {
        this.modelName = modelName;
        this.modelFile = modelFile;
        this.feaExtConfigs = feaExtConfigs;
        this.features = features;
        this.labelName = labelName;
        this.weightName = weightName;
        this.daoConfigs = daoConfigs;
        this.expandersConfig = expandersConfig;
        this.daoConfigKey = daoConfigKey;
        this.injector = injector;
        this.method = method;
        this.config = config;
    }

    /**
     * Builds a regression-tree predictor configuration from the given predictor settings.
     *
     * <p>Reads the string keys {@code modelName}, {@code labelName}, {@code weightName},
     * {@code modelFile} and {@code daoConfigKey}, the string list {@code features}, and the
     * sub-configurations addressed by {@link ConfigKey#ENTITY_DAOS_CONFIG} and
     * {@link ConfigKey#PREDICTOR_FEATURIZER_CONFIG}.
     *
     * @param predictorConfig the configuration section describing this predictor
     * @param injector        injector used to obtain the featurizer parser and the
     *                        {@link TreeLearningMethod}
     * @return the populated predictor configuration
     */
    public static PredictorConfig getPredictorConfig(Configuration predictorConfig,
                                                     Injector injector) {
        FeaturizerConfigParser featurizerParser =
                injector.instanceOf(FeatureExtractorListConfigParser.class);
        Configuration entityDaos =
                predictorConfig.getConfig(ConfigKey.ENTITY_DAOS_CONFIG.get());
        List<Configuration> expanderConfigs =
                ExpanderUtilities.getEntityExpandersConfig(predictorConfig);
        Configuration featurizerSection =
                predictorConfig.getConfig(ConfigKey.PREDICTOR_FEATURIZER_CONFIG.get());
        List<FeatureExtractorConfig> extractorConfigs = featurizerParser.parse(featurizerSection);

        // The remaining values are read in the same order as the constructor parameters.
        String name = predictorConfig.getString("modelName");
        List<String> featureNames = predictorConfig.getStringList("features");
        String label = predictorConfig.getString("labelName");
        String weight = predictorConfig.getString("weightName");
        TreeLearningMethod learningMethod = injector.instanceOf(TreeLearningMethod.class);
        String file = predictorConfig.getString("modelFile");
        String daoKey = predictorConfig.getString("daoConfigKey");

        return new RegressionTreePredictorConfig(name, extractorConfigs, featureNames, label,
                weight, entityDaos, expanderConfigs, injector, learningMethod, file, daoKey,
                predictorConfig);
    }

    /**
     * Model manager that creates and trains the {@link RegressionTree} using the settings
     * held by the enclosing configuration.
     */
    private class RegressionTreeModelManager extends AbstractModelManager {

        public RegressionTreeModelManager(String modelName, String modelFile, Injector injector) {
            super(injector, modelName, modelFile, new ArrayList<>());
        }

        /** Instantiates one feature extractor per configured extractor, in configuration order. */
        private List<FeatureExtractor> instantiateExtractors(RequestContext requestContext) {
            List<FeatureExtractor> extractors = new ArrayList<>();
            for (FeatureExtractorConfig extractorConfig : feaExtConfigs) {
                extractors.add(extractorConfig.getFeatureExtractor(requestContext));
            }
            return extractors;
        }

        /**
         * Creates a new, untrained regression tree through the injected
         * {@link RegressionTreeProducer}.
         *
         * @param requestContext the request used to instantiate the feature extractors
         * @param spaceMode      the space mode forwarded to the producer
         * @return the newly created {@link RegressionTree}
         */
        public Object createModel(RequestContext requestContext, SpaceMode spaceMode) {
            List<FeatureExtractor> extractors = instantiateExtractors(requestContext);
            RegressionTreeProducer treeProducer =
                    injector.instanceOf(RegressionTreeProducer.class);
            return treeProducer.createRegressionTree(modelName, spaceMode, features,
                    extractors, labelName, weightName);
        }

        /**
         * Trains the given model with learning data obtained for the request.
         *
         * @param model          the model to train; must be a {@link RegressionTree}
         * @param requestContext the request whose body entry under {@code daoConfigKey}
         *                       is passed on to obtain the learning data
         * @return the same {@code model} instance after learning
         */
        public Object buildModel(Object model, RequestContext requestContext) {
            RegressionTree tree = (RegressionTree) model;
            // NOTE(review): the meaning of the trailing (true, null, 128) arguments is defined
            // by PredictorUtilities.getLearningData and is not visible here.
            LearningData learningData = PredictorUtilities.getLearningData(
                    tree,
                    requestContext,
                    requestContext.getRequestBody().get(daoConfigKey),
                    daoConfigs,
                    expandersConfig,
                    injector,
                    true,
                    null,
                    128);
            method.learn(tree, learningData, null);
            return model;
        }
    }

    /**
     * Creates the predictor for a request.
     *
     * <p>The entity expanders are resolved first, then the regression tree is obtained from a
     * {@link RegressionTreeModelManager}. The tree is passed twice to
     * {@link PredictiveModelBasedPredictor}.
     *
     * @param requestContext the current request
     * @return a {@link PredictiveModelBasedPredictor} backed by the managed regression tree
     */
    public Predictor getPredictor(RequestContext requestContext) {
        List<EntityExpander> expanders =
                ExpanderUtilities.getEntityExpanders(requestContext, expandersConfig, injector);
        ModelManager manager = new RegressionTreeModelManager(modelName, modelFile, injector);
        RegressionTree tree = (RegressionTree) manager.manage(requestContext);
        return new PredictiveModelBasedPredictor(config, tree, tree, daoConfigs, injector,
                expanders, daoConfigKey);
    }
}