package ai.libs.jaicore.ml.core.dataset.serialization;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.nio.file.Files;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import org.api4.java.ai.ml.core.dataset.descriptor.IDatasetDescriptor;
import org.api4.java.ai.ml.core.dataset.descriptor.IFileDatasetDescriptor;
import org.api4.java.ai.ml.core.dataset.schema.ILabeledInstanceSchema;
import org.api4.java.ai.ml.core.dataset.schema.attribute.IAttribute;
import org.api4.java.ai.ml.core.dataset.schema.attribute.ICategoricalAttribute;
import org.api4.java.ai.ml.core.dataset.schema.attribute.INumericAttribute;
import org.api4.java.ai.ml.core.dataset.serialization.DatasetDeserializationFailedException;
import org.api4.java.ai.ml.core.dataset.serialization.IDatasetDeserializer;
import org.api4.java.ai.ml.core.dataset.serialization.UnsupportedAttributeTypeException;
import org.api4.java.ai.ml.core.dataset.supervised.ILabeledDataset;
import org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import ai.libs.jaicore.basic.OptionsParser;
import ai.libs.jaicore.basic.kvstore.KVStore;
import ai.libs.jaicore.ml.core.dataset.Dataset;
import ai.libs.jaicore.ml.core.dataset.DenseInstance;
import ai.libs.jaicore.ml.core.dataset.SparseInstance;
import ai.libs.jaicore.ml.core.dataset.schema.LabeledInstanceSchema;
import ai.libs.jaicore.ml.core.dataset.schema.attribute.IntBasedCategoricalAttribute;
import ai.libs.jaicore.ml.core.dataset.schema.attribute.NumericAttribute;
import ai.libs.jaicore.ml.core.dataset.schema.attribute.StringAttribute;
import ai.libs.jaicore.ml.core.dataset.serialization.arff.EArffAttributeType;
import ai.libs.jaicore.ml.core.dataset.serialization.arff.EArffItem;

/**
 * Reads labeled datasets from ARFF files and writes labeled datasets to ARFF files.
 *
 * An instance of this class carries a default parsing mode (sparse or dense) and, optionally, a dataset descriptor
 * that is used by the parameterless {@link #deserializeDataset()}. All parsing and writing logic is additionally
 * available through static methods.
 */
public class ArffDatasetAdapter implements IDatasetDeserializer<ILabeledDataset<ILabeledInstance>> {

	private static final Logger LOGGER = LoggerFactory.getLogger(ArffDatasetAdapter.class);

	/** Key under which the relation name is stored in the relation meta data. */
	public static final String K_RELATION_NAME = "relationName";
	/** Key under which the index of the class attribute is stored in the relation meta data. */
	public static final String K_CLASS_INDEX = "classIndex";

	/* option flag in the relation declaration that carries the class index */
	private static final String F_CLASS_INDEX = "C";
	private static final String SEPARATOR_RELATIONNAME = ":";
	private static final String SEPARATOR_ATTRIBUTE_DESCRIPTION = " ";
	private static final String SEPARATOR_DENSE_INSTANCE_VALUES = ",";

	private final boolean sparseMode;
	private IDatasetDescriptor datasetDescriptor = null;

	/**
	 * @param sparseMode Whether data lines are to be parsed as sparse instances by default.
	 * @param datasetDescriptor The descriptor used by {@link #deserializeDataset()}.
	 */
	public ArffDatasetAdapter(final boolean sparseMode, final IDatasetDescriptor datasetDescriptor) {
		this(sparseMode);
		this.datasetDescriptor = datasetDescriptor;
	}

	/**
	 * @param sparseMode Whether data lines are to be parsed as sparse instances by default.
	 */
	public ArffDatasetAdapter(final boolean sparseMode) {
		this.sparseMode = sparseMode;
	}

	/**
	 * Creates an adapter that parses data lines in dense mode by default.
	 */
	public ArffDatasetAdapter() {
		this(false);
	}

	/**
	 * Scans the attribute declarations of the described file for an attribute with the given name.
	 *
	 * @param datasetFile The descriptor of the ARFF file to scan.
	 * @param nameOfAttribute The name of the attribute to look up.
	 * @return The first attribute whose name equals the given name.
	 * @throws DatasetDeserializationFailedException If the file cannot be read, an attribute declaration cannot be parsed, or no attribute with this name exists.
	 */
	public IAttribute getAttributeWithName(final IFileDatasetDescriptor datasetFile, final String nameOfAttribute) throws DatasetDeserializationFailedException {
		try (BufferedReader br = Files.newBufferedReader(datasetFile.getDatasetDescription().toPath())) {
			String line;
			while ((line = br.readLine()) != null) {
				if (line.toLowerCase().startsWith(EArffItem.ATTRIBUTE.getValue().toLowerCase())) {
					IAttribute att = parseAttribute(line);
					if (att.getName().equals(nameOfAttribute)) {
						return att;
					}
				}
			}
			throw new NoSuchElementException("No attribute with name " + nameOfAttribute + " found.");
		} catch (Exception e) {
			throw new DatasetDeserializationFailedException(e);
		}
	}

	/**
	 * Deserializes the described ARFF file using the attribute with the given name as the class attribute.
	 *
	 * @param datasetFile The descriptor of the ARFF file to read.
	 * @param nameOfClassAttribute The name of the attribute that is to be used as the label.
	 * @return The dataset read from the file.
	 * @throws DatasetDeserializationFailedException If the file cannot be read or parsed, or if it declares no attribute with the given name.
	 */
	public ILabeledDataset<ILabeledInstance> deserializeDataset(final IFileDatasetDescriptor datasetFile, final String nameOfClassAttribute) throws DatasetDeserializationFailedException {
		Objects.requireNonNull(datasetFile, "No dataset has been configured.");
		/* read the file until the class parameter is found and count the params */
		int numAttributes = 0;
		boolean classAttributeFound = false;
		try (BufferedReader br = Files.newBufferedReader(datasetFile.getDatasetDescription().toPath())) {
			String line;
			while (!classAttributeFound && (line = br.readLine()) != null) {
				if (line.toLowerCase().startsWith(EArffItem.ATTRIBUTE.getValue().toLowerCase())) {
					IAttribute att = parseAttribute(line);
					if (att.getName().equals(nameOfClassAttribute)) {
						classAttributeFound = true;
					} else {
						numAttributes++;
					}
				}
			}
		} catch (Exception e) {
			throw new DatasetDeserializationFailedException(e);
		}
		/* without this check, the number of all attributes would be used as an (out-of-range) class index */
		if (!classAttributeFound) {
			throw new DatasetDeserializationFailedException("No attribute with name " + nameOfClassAttribute + " found.");
		}
		LOGGER.info("Successfully identified class attribute index {} for attribute with name {}", numAttributes, nameOfClassAttribute);
		return this.deserializeDataset(datasetFile, numAttributes);
	}

	/**
	 * Deserializes the described ARFF file using the attribute at the given index as the class attribute.
	 *
	 * @param datasetDescriptor The descriptor of the ARFF file to read.
	 * @param columnWithClassIndex The index of the class attribute; a negative value leaves the choice to {@link #readDataset(boolean, File, int)}.
	 * @return The dataset read from the file.
	 * @throws DatasetDeserializationFailedException If the file cannot be read or parsed.
	 */
	public ILabeledDataset<ILabeledInstance> deserializeDataset(final IFileDatasetDescriptor datasetDescriptor, final int columnWithClassIndex) throws DatasetDeserializationFailedException {
		Objects.requireNonNull(datasetDescriptor, "No dataset has been configured.");
		return readDataset(this.sparseMode, datasetDescriptor.getDatasetDescription(), columnWithClassIndex);
	}

	/**
	 * Deserializes the dataset behind the given descriptor, which must be an {@link IFileDatasetDescriptor}.
	 *
	 * @param datasetDescriptor The descriptor of the dataset to read.
	 * @return The dataset read from the described file.
	 * @throws DatasetDeserializationFailedException If the descriptor is not a file descriptor or the file cannot be read or parsed.
	 * @throws NullPointerException If the descriptor is null.
	 */
	@Override
	public ILabeledDataset<ILabeledInstance> deserializeDataset(final IDatasetDescriptor datasetDescriptor) throws DatasetDeserializationFailedException, InterruptedException {
		Objects.requireNonNull(datasetDescriptor, "No dataset has been configured.");
		if (!(datasetDescriptor instanceof IFileDatasetDescriptor)) {
			throw new DatasetDeserializationFailedException("Cannot handle dataset descriptor of type " + datasetDescriptor.getClass().getName());
		}
		return this.deserializeDataset((IFileDatasetDescriptor) datasetDescriptor, -1);
	}

	/**
	 * Deserializes the dataset behind the descriptor this adapter has been constructed with.
	 *
	 * @return The dataset read from the described file.
	 * @throws DatasetDeserializationFailedException If the descriptor is not a file descriptor or the file cannot be read or parsed.
	 * @throws NullPointerException If this adapter has been constructed without a descriptor.
	 */
	public ILabeledDataset<ILabeledInstance> deserializeDataset() throws InterruptedException, DatasetDeserializationFailedException {
		return this.deserializeDataset(this.datasetDescriptor);
	}

	/**
	 * Extracts meta data about a relation from a string.
	 *
	 * If the relation description is enclosed in single quotes, the quoted text is split at ':' into the relation name
	 * and an options string from which the class index option is read. Otherwise, the whole description is used as the
	 * relation name.
	 *
	 * @param line The line which is to be parsed to extract the necessary information from the relation name.
	 * @return A KVStore containing the parsed meta data.
	 */
	protected static KVStore parseRelation(final String line) {
		KVStore metaData = new KVStore();

		// cut off relation tag
		String relationDescription = line.substring(EArffItem.RELATION.getValue().length()).trim();

		if (relationDescription.startsWith("'") && relationDescription.endsWith("'")) {
			String[] relationNameAndOptions = line.substring(line.indexOf('\'') + 1, line.lastIndexOf('\'')).split(SEPARATOR_RELATIONNAME);
			metaData.put(K_RELATION_NAME, relationNameAndOptions[0].trim());
			if (relationNameAndOptions.length > 1) {
				OptionsParser optParser = new OptionsParser(relationNameAndOptions[1]);
				metaData.put(K_CLASS_INDEX, optParser.get(F_CLASS_INDEX));
			}
		} else {
			metaData.put(K_RELATION_NAME, relationDescription);
		}
		return metaData;
	}

	/**
	 * Parses an attribute declaration line into an attribute object.
	 *
	 * The name may be enclosed in single or double quotes; a single-quoted name may contain spaces. A type of the form
	 * <code>{a,b,c}</code> yields a categorical attribute; otherwise the type is resolved via {@link EArffAttributeType}.
	 *
	 * @param line The attribute declaration line.
	 * @return The attribute described by the line.
	 * @throws UnsupportedAttributeTypeException If the declared type is unknown or not supported.
	 * @throws IllegalArgumentException If a single-quoted attribute name is not terminated.
	 */
	protected static IAttribute parseAttribute(final String line) throws UnsupportedAttributeTypeException {
		String attributeDefinitionSplit = line.replaceAll("\\t", " ").substring(EArffItem.ATTRIBUTE.getValue().length() + 1).trim();
		String name = attributeDefinitionSplit.substring(0, attributeDefinitionSplit.indexOf(SEPARATOR_ATTRIBUTE_DESCRIPTION));
		if (name.trim().startsWith("'") && !name.trim().endsWith("'")) {
			/* the quoted name contains spaces: extend it up to (and including) the closing quote */
			int cutIndex = attributeDefinitionSplit.indexOf('\'', name.length());
			if (cutIndex < 0) {
				throw new IllegalArgumentException("Unterminated quoted attribute name. (line: " + line + ")");
			}
			name += attributeDefinitionSplit.substring(name.length(), cutIndex + 1);
		}
		String type = attributeDefinitionSplit.substring(name.length() + 1).trim();
		name = name.trim();
		if ((name.startsWith("'") && name.endsWith("'")) || (name.startsWith("\"") && name.endsWith("\""))) {
			name = name.substring(1, name.length() - 1);
		}

		EArffAttributeType attType;
		String[] values = null;
		if (type.startsWith("{") && type.endsWith("}")) {
			values = type.substring(1, type.length() - 1).split(SEPARATOR_DENSE_INSTANCE_VALUES);
			attType = EArffAttributeType.NOMINAL;
		} else {
			try {
				attType = EArffAttributeType.valueOf(type.toUpperCase());
			} catch (IllegalArgumentException e) {
				throw new UnsupportedAttributeTypeException("The attribute type " + type.toUpperCase() + " is not supported in the EArffAttributeType ENUM. (line: " + line + ")");
			}
		}

		switch (attType) {
		case NUMERIC:
		case REAL:
		case INTEGER:
			return new NumericAttribute(name);
		case STRING:
			return new StringAttribute(name);
		case NOMINAL:
			if (values != null) {
				/* trim each value and strip enclosing single or double quotes */
				return new IntBasedCategoricalAttribute(name,
						Arrays.stream(values).map(String::trim).map(x -> (((x.startsWith("'") && x.endsWith("'")) || x.startsWith("\"") && x.endsWith("\"")) ? x.substring(1, x.length() - 1) : x)).collect(Collectors.toList()));
			} else {
				throw new IllegalStateException("Identified a nominal attribute but it seems to have no values.");
			}
		default:
			throw new UnsupportedAttributeTypeException("Can not deal with attribute type " + type);
		}
	}

	/**
	 * Parses a single data line.
	 *
	 * The line is treated as sparse if sparse mode is requested, if it is enclosed in curly braces, or if it has fewer
	 * comma-separated values than there are attributes.
	 *
	 * @param sparseData Whether the line is to be parsed as a sparse instance in any case.
	 * @param attributes All attributes declared in the file, including the target attribute.
	 * @param targetIndex The index of the target attribute within the attribute list.
	 * @param line The data line to parse.
	 * @return For dense instances a two-element list holding the array of attribute values and the target value; for sparse instances a map from attribute index to value (which may still contain the target index).
	 * @throws IllegalArgumentException If the line is a comment, has too many values, or contains a malformed sparse entry.
	 */
	protected static Object parseInstance(final boolean sparseData, final List<IAttribute> attributes, final int targetIndex, final String line) {
		if (line.trim().startsWith("%")) {
			throw new IllegalArgumentException("Cannot create object for commented line!");
		}

		boolean sparseMode = sparseData;
		String curLine = line;
		String trimmedLine = line.trim();
		if (trimmedLine.startsWith("{") && trimmedLine.endsWith("}")) {
			/* strip the braces from the trimmed line so that surrounding whitespace cannot shift the cut positions */
			curLine = trimmedLine.substring(1, trimmedLine.length() - 1);
			sparseMode = true;
			if (curLine.trim().isEmpty()) { // the instance does not mention any explicit values => return an empty map.
				return new HashMap<>();
			}
		}

		String[] lineSplit = curLine.split(",");
		if (lineSplit.length < attributes.size()) {
			sparseMode = true;
		}

		if (!sparseMode) {
			if (lineSplit.length != attributes.size()) {
				throw new IllegalArgumentException("Cannot parse instance as this is not a sparse instance but has less columns than there are attributes defined. Expected values: " + attributes.size() + ". Actual number of values: "
						+ lineSplit.length + ". Values: " + Arrays.toString(lineSplit));
			}

			Object[] parsedDenseInstance = new Object[lineSplit.length - 1];
			Object target = null;
			int cI = 0;
			for (int i = 0; i < lineSplit.length; i++) {
				if (i == targetIndex) {
					target = attributes.get(i).deserializeAttributeValue(lineSplit[i]);
				} else {
					parsedDenseInstance[cI++] = attributes.get(i).deserializeAttributeValue(lineSplit[i]);
				}
			}
			return Arrays.asList(parsedDenseInstance, target);
		} else {
			Map<Integer, Object> parsedSparseInstance = new HashMap<>();
			for (String sparseValue : lineSplit) {
				/* entries may be preceded by whitespace, e.g. "{0 1, 3 x}" */
				String entry = sparseValue.trim();
				int indexOfFirstSpace = entry.indexOf(' ');
				if (indexOfFirstSpace < 0) {
					throw new IllegalArgumentException("Cannot parse sparse entry \"" + sparseValue + "\": expected an attribute index and a value separated by a space. Values: " + Arrays.toString(lineSplit));
				}
				int indexOfAttribute = Integer.parseInt(entry.substring(0, indexOfFirstSpace));
				String attributeValue = entry.substring(indexOfFirstSpace + 1).trim();
				parsedSparseInstance.put(indexOfAttribute, attributes.get(indexOfAttribute).deserializeAttributeValue(attributeValue));
			}
			return parsedSparseInstance;
		}
	}

	/**
	 * Creates an empty dataset whose schema is derived from the relation meta data and the list of attributes.
	 *
	 * @param relationMetaData The relation meta data, which must contain a non-negative class index.
	 * @param attributes All attributes including the class attribute; the given list is not modified.
	 * @return An empty dataset with the attribute at the class index as label attribute.
	 * @throws IllegalArgumentException If no class index is given or the class index is negative.
	 */
	protected static ILabeledDataset<ILabeledInstance> createDataset(final KVStore relationMetaData, final List<IAttribute> attributes) {
		if (!relationMetaData.containsKey(K_CLASS_INDEX) || relationMetaData.getAsInt(K_CLASS_INDEX) < 0) {
			throw new IllegalArgumentException("No (valid) class index given!");
		}
		List<IAttribute> attributeList = new ArrayList<>(attributes);
		/* the cast selects List#remove(int) (removal by index) */
		IAttribute labelAttribute = attributeList.remove((int) relationMetaData.getAsInt(K_CLASS_INDEX));
		ILabeledInstanceSchema schema = new LabeledInstanceSchema(relationMetaData.getAsString(K_RELATION_NAME), attributeList, labelAttribute);
		return new Dataset(schema);
	}

	/**
	 * Reads the given ARFF file in dense mode without an explicitly given class index.
	 *
	 * @param datasetFile The ARFF file to read.
	 * @return The dataset read from the file.
	 * @throws DatasetDeserializationFailedException If the file cannot be read or parsed.
	 */
	public static ILabeledDataset<ILabeledInstance> readDataset(final File datasetFile) throws DatasetDeserializationFailedException {
		return readDataset(false, datasetFile);
	}

	/**
	 * Reads the given ARFF file without an explicitly given class index.
	 *
	 * @param sparseMode Whether data lines are to be parsed as sparse instances in any case.
	 * @param datasetFile The ARFF file to read.
	 * @return The dataset read from the file.
	 * @throws DatasetDeserializationFailedException If the file cannot be read or parsed.
	 */
	public static ILabeledDataset<ILabeledInstance> readDataset(final boolean sparseMode, final File datasetFile) throws DatasetDeserializationFailedException {
		return readDataset(sparseMode, datasetFile, -1);
	}

	/**
	 * Reads the given ARFF file.
	 *
	 * A non-negative <code>columnWithClassIndex</code> overrides a class index found in the relation declaration. If no
	 * valid class index is known when the data section starts, the last attribute is used as the class attribute.
	 *
	 * @param sparseMode Whether data lines are to be parsed as sparse instances in any case.
	 * @param datasetFile The ARFF file to read.
	 * @param columnWithClassIndex The index of the class attribute, or a negative value if none is given explicitly.
	 * @return The dataset read from the file, or null if the file contains no data declaration.
	 * @throws DatasetDeserializationFailedException If the file cannot be read or parsed.
	 */
	public static ILabeledDataset<ILabeledInstance> readDataset(final boolean sparseMode, final File datasetFile, final int columnWithClassIndex) throws DatasetDeserializationFailedException {
		try (BufferedReader br = Files.newBufferedReader(datasetFile.toPath())) {
			ILabeledDataset<ILabeledInstance> dataset = null;
			KVStore relationMetaData = new KVStore();
			List<IAttribute> attributes = new ArrayList<>();

			boolean instanceReadMode = false;
			String line;
			long lineCounter = 1;

			while ((line = br.readLine()) != null) {
				if (!instanceReadMode) {
					if (line.toLowerCase().startsWith(EArffItem.RELATION.getValue())) {
						// parse relation meta data
						relationMetaData = parseRelation(line);
						if (columnWithClassIndex >= 0) {
							relationMetaData.put(K_CLASS_INDEX, columnWithClassIndex);
						}
					} else if (line.toLowerCase().startsWith(EArffItem.ATTRIBUTE.getValue())) {
						// parse attribute meta data
						attributes.add(parseAttribute(line));
					} else if (line.toLowerCase().startsWith(EArffItem.DATA.getValue())) {
						// switch to instance read mode
						if (!line.toLowerCase().trim().equals(EArffItem.DATA.getValue())) {
							throw new IllegalArgumentException(
									"Error while parsing arff-file on line " + lineCounter + ": There is more in this line than just the data declaration " + EArffItem.DATA.getValue() + ", which is not supported");
						}
						instanceReadMode = true;
						if (relationMetaData.containsKey(K_CLASS_INDEX) && relationMetaData.getAsInt(K_CLASS_INDEX) >= 0) {
							dataset = createDataset(relationMetaData, attributes);
						} else {
							LOGGER.warn("Invalid class index in the dataset's meta data ({}): Assuming last column to be the target attribute!", relationMetaData.get(K_CLASS_INDEX));
							relationMetaData.put(K_CLASS_INDEX, attributes.size() - 1);
							dataset = createDataset(relationMetaData, attributes);
						}
					}
				} else {
					line = line.trim();
					if (!line.isEmpty() && !line.startsWith("%")) { // ignore empty and comment lines
						Object parsedInstance = parseInstance(sparseMode, attributes, relationMetaData.getAsInt(K_CLASS_INDEX), line);
						ILabeledInstance newI;
						if (parsedInstance instanceof List<?>) {
							newI = new DenseInstance((Object[]) ((List<?>) parsedInstance).get(0), ((List<?>) parsedInstance).get(1));
						} else if (parsedInstance instanceof Map) {
							@SuppressWarnings("unchecked")
							Map<Integer, Object> parsedSparseInstance = (Map<Integer, Object>) parsedInstance;
							// in sparse instance, the class attribute may be missing; it is then assumed to be 0
							Object label = (parsedSparseInstance).containsKey(relationMetaData.getAsInt(K_CLASS_INDEX)) ? parsedSparseInstance.remove(relationMetaData.getAsInt(K_CLASS_INDEX)) : 0;
							if (label == null) {
								throw new IllegalArgumentException("Cannot identify label for instance " + line);
							}
							newI = new SparseInstance(dataset.getNumAttributes(), parsedSparseInstance, label);
						} else {
							throw new IllegalStateException("Severe Error: The format of the parsed instance is not as expected.");
						}
						if (newI.getNumAttributes() != dataset.getNumAttributes()) {
							throw new IllegalStateException("Instance has " + newI.getNumAttributes() + " attributes, but the dataset defines " + dataset.getNumAttributes() + " attributes.");
						}
						dataset.add(newI);
					}
				}
				/* count every line read so that error messages report the actual line number */
				lineCounter++;
			}
			return dataset;
		} catch (Exception e) {
			throw new DatasetDeserializationFailedException("Could not deserialize dataset from ARFF file.", e);
		}
	}

	/**
	 * Writes the given dataset to the given file in ARFF format.
	 *
	 * The file is written in UTF-8, which is the charset the reading methods of this class use.
	 *
	 * @param arffOutputFile The file to write to.
	 * @param data The dataset to serialize.
	 * @throws IOException If the file cannot be written.
	 */
	public static void serializeDataset(final File arffOutputFile, final ILabeledDataset<? extends ILabeledInstance> data) throws IOException {
		try (BufferedWriter bw = Files.newBufferedWriter(arffOutputFile.toPath())) {
			// write metadata
			serializeMetaData(bw, data);
			bw.write("\n\n");
			// write actual data (payload)
			serializeData(bw, data);
		}
	}

	/**
	 * Writes the data declaration followed by one line per instance.
	 *
	 * Dense instances are written as comma-separated values with the label last. All other instances are written in
	 * sparse notation, where the label is written with the index following the last attribute.
	 */
	private static void serializeData(final BufferedWriter bw, final ILabeledDataset<? extends ILabeledInstance> data) throws IOException {
		bw.write(EArffItem.DATA.getValue() + "\n");
		for (ILabeledInstance instance : data) {
			if (instance instanceof DenseInstance) {
				Object[] atts = instance.getAttributes();
				bw.write(IntStream.range(0, atts.length).mapToObj(x -> serializeAttributeValue(data.getInstanceSchema().getAttribute(x), atts[x])).collect(Collectors.joining(",")));
				bw.write(",");
				bw.write(serializeAttributeValue(data.getInstanceSchema().getLabelAttribute(), instance.getLabel()));
				bw.write("\n");
			} else {
				String attributePart = ((SparseInstance) instance).getAttributeMap().entrySet().stream().map(x -> x.getKey() + " " + serializeAttributeValue(data.getInstanceSchema().getAttribute(x.getKey()), x.getValue()))
						.collect(Collectors.joining(","));
				bw.write("{");
				bw.write(attributePart);
				/* separate the label entry from preceding entries only; an empty map must not produce a leading comma */
				if (!attributePart.isEmpty()) {
					bw.write(",");
				}
				/* the index must be written as text: Writer#write(int) would emit a single character with that code */
				bw.write(String.valueOf(data.getNumAttributes()));
				bw.write(" ");
				bw.write(serializeAttributeValue(data.getInstanceSchema().getLabelAttribute(), instance.getLabel()));
				bw.write("}\n");
			}
		}
	}

	/**
	 * Serializes a single value via its attribute and encloses values of categorical attributes in single quotes.
	 */
	private static String serializeAttributeValue(final IAttribute att, final Object value) {
		String returnValue = att.serializeAttributeValue(value);
		if (att instanceof ICategoricalAttribute) {
			returnValue = "'" + returnValue + "'";
		}
		return returnValue;
	}

	/**
	 * Writes the relation declaration and one attribute declaration per attribute; the label attribute is written last.
	 */
	private static void serializeMetaData(final BufferedWriter bw, final ILabeledDataset<? extends ILabeledInstance> data) throws IOException {
		StringBuilder sb = new StringBuilder();
		sb.append(EArffItem.RELATION.getValue() + " " + data.getRelationName());
		sb.append("\n");
		sb.append("\n");
		for (IAttribute att : data.getInstanceSchema().getAttributeList()) {
			sb.append(serializeAttribute(att));
			sb.append("\n");
		}
		sb.append(serializeAttribute(data.getInstanceSchema().getLabelAttribute()));
		bw.write(sb.toString());
	}

	/**
	 * Builds the declaration line of an attribute with its name in single quotes.
	 *
	 * Categorical attributes are followed by their quoted labels in curly braces and numeric attributes by the numeric
	 * type name; for any other attribute no type is appended.
	 */
	private static String serializeAttribute(final IAttribute att) {
		StringBuilder sb = new StringBuilder();
		sb.append(EArffItem.ATTRIBUTE.getValue() + " '" + att.getName() + "' ");
		if (att instanceof ICategoricalAttribute) {
			sb.append("{'" + ((ICategoricalAttribute) att).getLabels().stream().collect(Collectors.joining("','")) + "'}");
		} else if (att instanceof INumericAttribute) {
			sb.append(EArffAttributeType.NUMERIC.getName());
		}
		return sb.toString();
	}
}