package ai.libs.jaicore.ml.classification.multilabel.dataset;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import org.api4.java.ai.ml.core.dataset.schema.ILabeledInstanceSchema;
import org.api4.java.ai.ml.core.dataset.schema.attribute.IAttribute;
import org.api4.java.ai.ml.core.dataset.schema.attribute.ICategoricalAttribute;
import org.api4.java.ai.ml.core.dataset.schema.attribute.INumericAttribute;
import org.api4.java.ai.ml.core.dataset.serialization.UnsupportedAttributeTypeException;
import org.api4.java.ai.ml.core.dataset.supervised.ILabeledDataset;
import org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance;

import ai.libs.jaicore.ml.core.dataset.schema.LabeledInstanceSchema;
import ai.libs.jaicore.ml.core.dataset.schema.attribute.IntBasedCategoricalAttribute;
import ai.libs.jaicore.ml.core.dataset.schema.attribute.MultiLabelAttribute;
import ai.libs.jaicore.ml.core.dataset.schema.attribute.NumericAttribute;
import weka.core.Attribute;
import weka.core.DenseInstance;
import weka.core.Instance;
import weka.core.Instances;

public class MekaInstancesUtil {

	private MekaInstancesUtil() {
		/* Intentionally blank, hiding standard constructor for this util class. */
	}

	public static ILabeledInstanceSchema extractSchema(final Instances dataset) {
		int targetIndex = dataset.classIndex();
		if (targetIndex < 0) {
			throw new IllegalArgumentException("Class index of Instances object is not set!");
		}
		List<IAttribute> attributes = IntStream.range(dataset.classIndex(), dataset.numAttributes()).mapToObj(dataset::attribute).map(MekaInstancesUtil::transformWEKAAttributeToAttributeType).collect(Collectors.toList());

		List<String> values = IntStream.range(0, dataset.classIndex()).mapToObj(x -> dataset.attribute(x).name()).collect(Collectors.toList());
		IAttribute labelAttribute = new MultiLabelAttribute("labels", values);
		return new LabeledInstanceSchema(dataset.relationName(), attributes, labelAttribute);
	}

	public static Instances datasetToWekaInstances(final ILabeledDataset<? extends ILabeledInstance> dataset) throws UnsupportedAttributeTypeException {
		Instances wekaInstances = createDatasetFromSchema(dataset.getInstanceSchema());
		for (ILabeledInstance inst : dataset) {
			double[] point = inst.getPoint();
			double[] pointWithLabel = Arrays.copyOf(point, point.length + 1);
			DenseInstance iNew = new DenseInstance(1, pointWithLabel);
			iNew.setDataset(wekaInstances);
			if (dataset.getLabelAttribute() instanceof ICategoricalAttribute) {
				iNew.setClassValue(((ICategoricalAttribute) dataset.getLabelAttribute()).getLabelOfCategory((int) inst.getLabel()));
			} else {
				iNew.setClassValue((Double) inst.getLabel());
			}
			wekaInstances.add(iNew); // this MUST come here AFTER having set the class value; otherwise, the class is not registered correctly in the Instances object!!
		}
		return wekaInstances;
	}

	public static Instances createDatasetFromSchema(final ILabeledInstanceSchema schema) throws UnsupportedAttributeTypeException {
		Objects.requireNonNull(schema);
		List<Attribute> attributes = new LinkedList<>();

		for (int i = 0; i < schema.getNumAttributes(); i++) {
			IAttribute attType = schema.getAttributeList().get(i);
			if (attType instanceof NumericAttribute) {
				attributes.add(new Attribute("att" + i));
			} else if (attType instanceof IntBasedCategoricalAttribute) {
				attributes.add(new Attribute("att" + i, ((IntBasedCategoricalAttribute) attType).getLabels()));
			} else {
				throw new UnsupportedAttributeTypeException("The class attribute has an unsupported attribute type " + attType.getName() + ".");
			}
		}

		IAttribute classType = schema.getLabelAttribute();
		Attribute classAttribute;

		if (classType instanceof INumericAttribute) {
			classAttribute = new Attribute("class");
		} else if (classType instanceof ICategoricalAttribute) {
			classAttribute = new Attribute("class", ((IntBasedCategoricalAttribute) classType).getLabels());
		} else {
			throw new UnsupportedAttributeTypeException("The class attribute has an unsupported attribute type.");
		}

		ArrayList<Attribute> attributeList = new ArrayList<>(attributes);
		attributeList.add(classAttribute);

		Instances wekaInstances = new Instances("weka-instances", attributeList, 0);
		wekaInstances.setClassIndex(wekaInstances.numAttributes() - 1);
		return wekaInstances;

	}

	public static IAttribute transformWEKAAttributeToAttributeType(final Attribute att) {
		String attributeName = att.name();
		if (att.isNumeric()) {
			return new NumericAttribute(attributeName);
		} else if (att.isNominal()) {
			List<String> domain = new LinkedList<>();
			for (int i = 0; i < att.numValues(); i++) {
				domain.add(att.value(i));
			}
			return new IntBasedCategoricalAttribute(attributeName, domain);
		}
		throw new IllegalArgumentException("Can only transform numeric or categorical attributes");
	}

	public static Instance transformInstanceToWekaInstance(final ILabeledInstanceSchema schema, final ILabeledInstance instance) throws UnsupportedAttributeTypeException {
		if (instance.getNumAttributes() != schema.getNumAttributes()) {
			throw new IllegalArgumentException("Schema and instance do not coincide. The schema defines " + schema.getNumAttributes() + " attributes but the instance has " + instance.getNumAttributes() + " attributes.");
		}
		if (instance instanceof MekaInstance) {
			return ((MekaInstance) instance).getElement();
		}
		Objects.requireNonNull(schema);
		Instances dataset = createDatasetFromSchema(schema);
		Instance iNew = new DenseInstance(dataset.numAttributes());
		iNew.setDataset(dataset);
		for (int i = 0; i < instance.getNumAttributes(); i++) {
			if (schema.getAttribute(i) instanceof INumericAttribute) {
				iNew.setValue(i, ((INumericAttribute) schema.getAttribute(i)).getAsAttributeValue(instance.getAttributeValue(i)).getValue());
			} else if (schema.getAttribute(i) instanceof ICategoricalAttribute) {
				iNew.setValue(i, ((ICategoricalAttribute) schema.getAttribute(i)).getAsAttributeValue(instance.getAttributeValue(i)).getValue());
			} else {
				throw new UnsupportedAttributeTypeException("Only categorical and numeric attributes are supported!");
			}
		}

		if (schema.getLabelAttribute() instanceof INumericAttribute) {
			iNew.setValue(iNew.numAttributes() - 1, ((INumericAttribute) schema.getLabelAttribute()).getAsAttributeValue(instance.getLabel()).getValue());
		} else if (schema.getLabelAttribute() instanceof ICategoricalAttribute) {
			iNew.setValue(iNew.numAttributes() - 1, ((ICategoricalAttribute) schema.getLabelAttribute()).getAsAttributeValue(instance.getLabel()).getValue());
		} else {
			throw new UnsupportedAttributeTypeException("Only categorical and numeric attributes are supported!");
		}
		return iNew;
	}

}