/*
 *
 *  * ******************************************************************************
 *  *  * Copyright (c) 2015-2019 Skymind Inc.
 *  *  * Copyright (c) 2019 Konduit AI.
 *  *  *
 *  *  * This program and the accompanying materials are made available under the
 *  *  * terms of the Apache License, Version 2.0 which is available at
 *  *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *  *
 *  *  * Unless required by applicable law or agreed to in writing, software
 *  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  *  * License for the specific language governing permissions and limitations
 *  *  * under the License.
 *  *  *
 *  *  * SPDX-License-Identifier: Apache-2.0
 *  *  *****************************************************************************
 *
 *
 */

package ai.konduit.serving.output.adapter;

import ai.konduit.serving.output.types.RegressionOutput;
import io.vertx.ext.web.RoutingContext;
import lombok.AllArgsConstructor;
import lombok.NoArgsConstructor;
import org.datavec.api.transform.schema.Schema;
import org.dmg.pmml.FieldName;
import org.nd4j.linalg.api.ndarray.INDArray;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;

/**
 * Convert the input based on the input
 * {@link Schema} to {@link RegressionOutput}
 * representing real valued output.
 * <p>
 * Two input forms are supported: an {@link INDArray}, which is converted directly to a
 * {@code double[][]}, and a list of PMML result maps, where each map is one example and
 * each {@link Schema} column name selects one output value from that map.
 *
 * @author Adam Gibson
 */
@AllArgsConstructor
@NoArgsConstructor
public class RegressionOutputAdapter implements OutputAdapter<RegressionOutput> {

    /** Output schema; required for the PMML path. Null when the no-arg constructor is used. */
    private Schema schema;
    /** Keys used to read values out of PMML result maps, one per schema column. */
    private List<FieldName> fieldNames;

    /**
     * Create the output adapter
     * with the output inputSchema
     *
     * @param schema the inputSchema of the output; one {@link FieldName} is created per column name
     */
    public RegressionOutputAdapter(Schema schema) {
        this.schema = schema;
        this.fieldNames = fieldNamesFor(schema);
    }

    /**
     * Convert an array to a {@link RegressionOutput} via {@link INDArray#toDoubleMatrix()}.
     *
     * @param array          the array to convert
     * @param routingContext the request context (unused here)
     * @return the regression output wrapping the array's values
     */
    @Override
    public RegressionOutput adapt(INDArray array, RoutingContext routingContext) {
        return RegressionOutput.builder()
                .values(array.toDoubleMatrix())
                .build();
    }

    /**
     * Convert PMML result maps to a {@link RegressionOutput}: row {@code i} holds the values of
     * example {@code i}, column {@code j} the value stored under the name of schema column {@code j}.
     *
     * @param pmmlExamples   one result map per example; an empty list yields zero rows
     * @param routingContext the request context (unused here)
     * @return the regression output
     * @throws IllegalStateException if no schema is set, or a value is missing or not numeric
     */
    @Override
    public RegressionOutput adapt(List<? extends Map<FieldName, ?>> pmmlExamples, RoutingContext routingContext) {
        if (schema == null) {
            throw new IllegalStateException("No inputSchema found. A inputSchema is required in order to create results.");
        }

        int numColumns = schema.numColumns();
        // Fall back to schema-derived names if the all-args constructor was given no field names.
        List<FieldName> names = fieldNames != null ? fieldNames : fieldNamesFor(schema);

        // Rows are sized by the schema, which drives the inner loop; sizing by the first map's
        // entry count overflowed (or zero-padded) when the map and schema widths differed.
        double[][] values = new double[pmmlExamples.size()][numColumns];
        for (int i = 0; i < pmmlExamples.size(); i++) {
            Map<FieldName, ?> example = pmmlExamples.get(i);
            for (int j = 0; j < numColumns; j++) {
                FieldName name = names.get(j);
                values[i][j] = toDouble(example.get(name), name, i);
            }
        }

        return RegressionOutput.builder().values(values).build();
    }

    /**
     * Dispatch to the {@link INDArray} or PMML-list overload based on the runtime type of {@code input}.
     *
     * @param input          an {@link INDArray} or a {@link List} of PMML result maps
     * @param routingContext the request context
     * @return the regression output
     * @throws UnsupportedOperationException if {@code input} is neither supported type (including null)
     */
    @Override
    public RegressionOutput adapt(Object input, RoutingContext routingContext) {
        if (input instanceof INDArray) {
            return adapt((INDArray) input, routingContext);
        }
        if (input instanceof List) {
            // Element types cannot be checked at runtime (erasure); lists are assumed to hold PMML result maps.
            @SuppressWarnings("unchecked")
            List<? extends Map<FieldName, ?>> pmmlExamples = (List<? extends Map<FieldName, ?>>) input;
            return adapt(pmmlExamples, routingContext);
        }

        String typeName = input == null ? "null" : input.getClass().getName();
        throw new UnsupportedOperationException("Unable to convert input of type " + typeName);
    }

    @Override
    public Class<RegressionOutput> outputAdapterType() {
        return RegressionOutput.class;
    }

    /** Build one {@link FieldName} per column of {@code schema}, in column order. */
    private static List<FieldName> fieldNamesFor(Schema schema) {
        List<FieldName> names = new ArrayList<>(schema.numColumns());
        for (int i = 0; i < schema.numColumns(); i++) {
            names.add(FieldName.create(schema.getName(i)));
        }
        return names;
    }

    /** Convert a PMML result value to a double, failing with a descriptive message if it is missing or non-numeric. */
    private static double toDouble(Object value, FieldName fieldName, int exampleIndex) {
        if (value instanceof Number) {
            return ((Number) value).doubleValue();
        }
        if (value == null) {
            throw new IllegalStateException("Missing value for field " + fieldName + " in example " + exampleIndex);
        }
        throw new IllegalStateException("Value for field " + fieldName + " in example " + exampleIndex
                + " is not numeric: " + value.getClass().getName());
    }
}