smile.data.Attribute Java Examples

The following examples show how to use smile.data.Attribute. You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example #1
Source File: SmileRandomForest.java    From kogito-runtimes with Apache License 2.0 6 votes vote down vote up
/**
 * Creates the random forest wrapper and prepares the SMILE attribute metadata for it.
 *
 * <p>For every entry of {@code inputFeatures} a SMILE attribute is created through
 * {@link #createAttribute(String, AttributeType)}, registered in {@code smileAttributes} and its
 * name appended to {@code attributeNames}. An attribute for the output feature is created the same
 * way, and an empty {@code AttributeDataset} named {@code "dataset"} is built from both.
 *
 * @param inputFeatures       input feature names mapped to their attribute types
 * @param outputFeatureName   name of the output (outcome) feature
 * @param outputFeatureType   attribute type of the output feature
 * @param confidenceThreshold forwarded unchanged to the superclass constructor
 * @param numberTrees         stored in {@code numberTrees}
 */
public SmileRandomForest(Map<String, AttributeType> inputFeatures,
                         String outputFeatureName,
                         AttributeType outputFeatureType,
                         double confidenceThreshold,
                         int numberTrees) {
    super(inputFeatures, outputFeatureName, outputFeatureType, confidenceThreshold);
    this.numberTrees = numberTrees;
    // Insertion-ordered on purpose: the attribute array handed to the dataset below is taken from
    // values(), and it must enumerate attributes in the same order in which their names are
    // appended to attributeNames. A plain HashMap gives no such guarantee.
    // NOTE(review): presumably attributeNames is used positionally when feature vectors are
    // built elsewhere in this class - confirm.
    smileAttributes = new java.util.LinkedHashMap<>();
    for (Entry<String, AttributeType> inputFeature : inputFeatures.entrySet()) {
        final String name = inputFeature.getKey();
        final AttributeType type = inputFeature.getValue();
        smileAttributes.put(name, createAttribute(name, type));
        attributeNames.add(name);
    }
    numAttributes = smileAttributes.size();
    outcomeAttribute = createAttribute(outputFeatureName, outputFeatureType);
    outcomeAttributeType = outputFeatureType;

    // Attribute columns of the dataset now follow the iteration order of inputFeatures.
    final Attribute[] orderedAttributes = smileAttributes.values().toArray(new Attribute[numAttributes]);
    dataset = new AttributeDataset("dataset", orderedAttributes, outcomeAttribute);
}
 
Example #2
Source File: SmileRandomForest.java    From kogito-runtimes with Apache License 2.0 5 votes vote down vote up
/**
 * Maps a feature's declared type to the matching SMILE attribute implementation.
 *
 * <p>{@code NOMINAL} and {@code BOOLEAN} yield a {@code NominalAttribute}, {@code NUMERIC} yields
 * a {@code NumericAttribute}, and every other value (including {@code null}) yields a
 * {@code StringAttribute}.
 *
 * @param name name given to the created attribute
 * @param type declared type of the feature
 * @return a new attribute instance named {@code name}
 */
protected Attribute createAttribute(String name, AttributeType type) {
    final boolean categorical = type == AttributeType.NOMINAL || type == AttributeType.BOOLEAN;
    if (categorical) {
        return new NominalAttribute(name);
    }
    if (type == AttributeType.NUMERIC) {
        return new NumericAttribute(name);
    }
    // Fallback for anything that is neither categorical nor numeric.
    return new StringAttribute(name);
}
 
Example #3
Source File: LinkClassifierTrainer.java    From ache with Apache License 2.0 5 votes vote down vote up
/**
 * Builds the {@code AttributeDataset} used to train a SMILE link classifier.
 *
 * <p>Every feature name becomes a numeric attribute and the response is a nominal attribute
 * named {@code "y"} whose values are {@code classValues}. Each sample is converted to a feature
 * vector through {@code wrapper}, and the index of the sampler it came from within
 * {@code instances} is used as its class label. Samples whose conversion fails with a
 * {@link MalformedURLException} are logged and left out of the dataset.
 *
 * @param instances   one sampler per class; the list index is the class label
 * @param features    feature names, used both as attribute names and for feature extraction
 * @param classValues values of the nominal response attribute
 * @param wrapper     converts a link neighborhood into an instance holding the feature values
 * @return the populated dataset
 */
private AttributeDataset createDataset(List<Sampler<LinkNeighborhood>> instances,
        String[] features, List<String> classValues, LinkNeighborhoodWrapper wrapper) {

    // One numeric attribute per feature, in the order the feature names were given.
    Attribute[] featureAttributes = new Attribute[features.length];
    for (int i = 0; i < features.length; i++) {
        featureAttributes[i] = new NumericAttribute(features[i]);
    }

    String[] labels = classValues.toArray(new String[classValues.size()]);
    Attribute response =
            new NominalAttribute("y", "If link leads to relevant page or not.", labels);
    AttributeDataset dataset =
            new AttributeDataset("link_classifier", featureAttributes, response);

    int level = 0; // position of the current sampler; doubles as the class label
    for (Sampler<LinkNeighborhood> levelSamples : instances) {
        for (LinkNeighborhood ln : levelSamples.getSamples()) {
            Instance instance;
            try {
                instance = wrapper.extractToInstance(ln, features);
            } catch (MalformedURLException e) {
                // Skip this sample only; the remaining ones are still added.
                logger.warn("Failed to process intance: "+ln.getLink().toString(), e);
                continue;
            }
            dataset.add(instance.getValues(), level);
        }
        level++;
    }
    return dataset;
}
 
Example #4
Source File: DecisionTreeTest.java    From incubator-hivemall with Apache License 2.0 4 votes vote down vote up
/**
 * Trains a decision tree on an ARFF dataset and checks that tracing a prediction reports the
 * path taken through the tree.
 *
 * <p>The dataset is downloaded from {@code datasetUrl}, shuffled with a fixed seed and split
 * 70/30 into training and test rows. For every test row the tree is asked to predict with a
 * {@code PredictionHandler}; the test asserts that the handler recorded at least one branch and
 * produced non-empty text.
 *
 * @param datasetUrl    URL of the ARFF file to load
 * @param responseIndex index of the response column, passed to the ARFF parser
 * @param numLeafs      passed to the {@code DecisionTree} constructor
 * @throws IOException    if the dataset cannot be opened, read or closed
 * @throws ParseException if the ARFF content cannot be parsed
 */
private static void runTracePredict(String datasetUrl, int responseIndex, int numLeafs)
        throws IOException, ParseException {
    URL url = new URL(datasetUrl);

    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(responseIndex);

    // try-with-resources so the network stream is released even when parsing fails.
    final AttributeDataset ds;
    try (InputStream is = new BufferedInputStream(url.openStream())) {
        ds = arffParser.parse(is);
    }
    final Attribute[] attrs = ds.attributes();
    final Attribute targetAttr = ds.response();

    double[][] x = ds.toArray(new double[ds.size()][]);
    int[] y = ds.toArray(new int[ds.size()]);

    // Fixed seed keeps the train/test split reproducible between runs.
    Random rnd = new Random(43L);
    int numTrain = (int) (x.length * 0.7);
    int[] index = ArrayUtils.shuffle(MathUtils.permutation(x.length), rnd);
    int[] cvTrain = Arrays.copyOf(index, numTrain);
    int[] cvTest = Arrays.copyOfRange(index, numTrain, index.length);

    double[][] trainx = Math.slice(x, cvTrain);
    int[] trainy = Math.slice(y, cvTrain);
    double[][] testx = Math.slice(x, cvTest);

    DecisionTree tree = new DecisionTree(SmileExtUtils.convertAttributeTypes(attrs),
        matrix(trainx, false), trainy, numLeafs, RandomNumberGeneratorFactory.createPRNG(43L));

    // Both collectors are reused across test rows and reset at the end of each iteration.
    final LinkedHashMap<String, Double> map = new LinkedHashMap<>();
    final StringBuilder buf = new StringBuilder();
    for (int i = 0; i < testx.length; i++) {
        final DenseVector test = new DenseVector(testx[i]);
        tree.predict(test, new PredictionHandler() {

            @Override
            public void visitBranch(Operator op, int splitFeatureIndex, double splitFeature,
                    double splitValue) {
                buf.append(attrs[splitFeatureIndex].name);
                buf.append(" [" + splitFeature + "] ");
                buf.append(op);
                buf.append(' ');
                buf.append(splitValue);
                buf.append('\n');

                map.put(attrs[splitFeatureIndex].name + " [" + splitFeature + "] " + op,
                    splitValue);
            }

            @Override
            public void visitLeaf(int output, double[] posteriori) {
                buf.append(targetAttr.toString(output));
            }
        });

        Assert.assertTrue(buf.length() > 0);
        Assert.assertFalse(map.isEmpty());

        StringUtils.clear(buf);
        map.clear();
    }

}