package com.alibaba.alink.operator.batch.associationrule;

import com.alibaba.alink.common.utils.DataSetConversionUtil;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.common.associationrule.ParallelPrefixSpan;
import com.alibaba.alink.operator.common.associationrule.SequenceRule;
import com.alibaba.alink.params.associationrule.PrefixSpanParams;

import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.MapPartitionFunction;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.common.functions.RichFilterFunction;
import org.apache.flink.api.common.functions.RichFlatMapFunction;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.aggregation.Aggregations;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.api.java.tuple.Tuple4;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.Table;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;
import org.apache.flink.util.StringUtils;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * The PrefixSpan algorithm mines frequent sequential patterns.
 * It is described in J. Pei et al., "Mining Sequential Patterns by
 * Pattern-Growth: The PrefixSpan Approach".
 */
public final class PrefixSpanBatchOp extends BatchOperator<PrefixSpanBatchOp>
    implements PrefixSpanParams<PrefixSpanBatchOp> {

    private static final Logger LOG = LoggerFactory.getLogger(PrefixSpanBatchOp.class);

    /**
     * The separator between items in an element.
     */
    public static final String ITEM_SEPARATOR = ",";

    /**
     * The separator between elements in a sequence.
     */
    public static final String ELEMENT_SEPARATOR = ";";

    /**
     * The separator between the antecedent and the consequent in a rule.
     */
    public static final String RULE_SEPARATOR = "=>";

    private static final String[] ITEMSETS_COL_NAMES = new String[] {"itemset", "supportcount", "itemcount"};
    private static final String[] RULES_COL_NAMES = new String[] {
        "rule", "chain_length", "support", "confidence", "transaction_count"};

    private static final TypeInformation[] ITEMSETS_COL_TYPES = new TypeInformation[] {
        Types.STRING, Types.LONG, Types.LONG};
    private static final TypeInformation[] RULES_COL_TYPES = new TypeInformation[] {
        Types.STRING, Types.LONG, Types.DOUBLE, Types.DOUBLE, Types.LONG};

    public PrefixSpanBatchOp() {
        this(new Params());
    }

    public PrefixSpanBatchOp(Params params) {
        super(params);
    }
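    /*
     * A minimal usage sketch (not part of the original source), assuming a
     * memory source such as MemSourceBatchOp with a single string column
     * "sequence". Items inside an element are joined by ITEM_SEPARATOR (",")
     * and elements inside a sequence by ELEMENT_SEPARATOR (";"):
     *
     *   BatchOperator<?> data = new MemSourceBatchOp(
     *       new Object[] {"a;a,b,c;a,c", "a,d;c;b,c"}, "sequence");
     *   PrefixSpanBatchOp prefixSpan = new PrefixSpanBatchOp()
     *       .setItemsCol("sequence")
     *       .setMinSupportCount(2);
     *   prefixSpan.linkFrom(data).print();    // main output: frequent patterns
     *   prefixSpan.getSideOutput(0).print();  // side output: sequence rules
     */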
    @Override
    public PrefixSpanBatchOp linkFrom(BatchOperator<?>... inputs) {
        BatchOperator<?> in = checkAndGetFirst(inputs);
        final String itemsColName = getItemsCol();
        final double minSupportPercent = getMinSupportPercent();
        final int minSupportCount = getMinSupportCount();
        final int maxPatternLength = getMaxPatternLength();
        final double minConfidence = getMinConfidence();
        final int itemsColIdx = TableUtil.findColIndexWithAssertAndHint(in.getSchema(), itemsColName);

        DataSet<Long> sequenceCount = count(in.getDataSet());
        DataSet<Long> minSupportCnt = getMinSupportCnt(sequenceCount, minSupportCount, minSupportPercent);

        // Parse each sequence string into a list of elements, each element being a list of items.
        DataSet<List<List<String>>> inputSequences = ((DataSet<Row>) in.getDataSet())
            .map(new MapFunction<Row, List<List<String>>>() {
                @Override
                public List<List<String>> map(Row row) throws Exception {
                    String sequence = (String) row.getField(itemsColIdx);
                    if (StringUtils.isNullOrWhitespaceOnly(sequence)) {
                        return new ArrayList<>();
                    }
                    String[] elements = sequence.split(ELEMENT_SEPARATOR);
                    List<List<String>> ret = new ArrayList<>(elements.length);
                    for (String element : elements) {
                        String[] items = element.trim().split(ITEM_SEPARATOR);
                        ret.add(Arrays.asList(items));
                    }
                    return ret;
                }
            })
            .name("split_sequences");

        // Count the support of each item.
        DataSet<Tuple2<String, Integer>> itemCounts = inputSequences
            .flatMap(new FlatMapFunction<List<List<String>>, Tuple2<String, Integer>>() {
                @Override
                public void flatMap(List<List<String>> sequence, Collector<Tuple2<String, Integer>> out)
                    throws Exception {
                    sequence.forEach(s -> s.forEach(t -> out.collect(Tuple2.of(t, 1))));
                }
            })
            .groupBy(0)
            .aggregate(Aggregations.SUM, 1);

        // Drop items whose support is below the threshold.
        DataSet<Tuple2<String, Integer>> qualifiedItems = itemCounts
            .filter(new RichFilterFunction<Tuple2<String, Integer>>() {
                transient Long minSupportCount;

                @Override
                public void open(Configuration parameters) throws Exception {
                    List<Long> bc = getRuntimeContext().getBroadcastVariable("minSupportCnt");
                    minSupportCount = bc.get(0);
                    LOG.info("minSupportCnt {}", minSupportCount);
                }

                @Override
                public boolean filter(Tuple2<String, Integer> value) throws Exception {
                    return value.f1 >= minSupportCount;
                }
            })
            .withBroadcastSet(minSupportCnt, "minSupportCnt")
            .name("get_qualified_items");

        // Assign each item an index, ordered by descending support (ties broken lexicographically).
        DataSet<Tuple2<String, Integer>> itemIndex = in.getDataSet().getExecutionEnvironment().fromElements(0)
            .flatMap(new RichFlatMapFunction<Integer, Tuple2<String, Integer>>() {
                @Override
                public void flatMap(Integer value, Collector<Tuple2<String, Integer>> out) throws Exception {
                    List<Tuple2<String, Integer>> bc = getRuntimeContext().getBroadcastVariable("qualifiedItems");
                    Integer[] order = new Integer[bc.size()];
                    for (int i = 0; i < order.length; i++) {
                        order[i] = i;
                    }
                    Arrays.sort(order, (o1, o2) -> {
                        Integer cnt1 = bc.get(o1).f1;
                        Integer cnt2 = bc.get(o2).f1;
                        if (cnt1.equals(cnt2)) {
                            return bc.get(o1).f0.compareTo(bc.get(o2).f0);
                        }
                        return Integer.compare(cnt2, cnt1);
                    });
                    for (int i = 0; i < order.length; i++) {
                        out.collect(Tuple2.of(bc.get(order[i]).f0, i + 1)); // the index starts from 1
                    }
                }
            })
            .withBroadcastSet(qualifiedItems, "qualifiedItems");

        // Map each sequence to an int array, using 0 to separate elements.
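        // A worked example (not in the original source): assuming the broadcast
        // "itemIndex" maps a -> 1 and b -> 2, the sequence "a;a,b" is encoded as
        // [0, 1, 0, 1, 2, 0]: a leading 0, then each element's item ids followed
        // by a separating 0. Items that failed the support threshold are dropped,
        // and an element whose items were all dropped contributes no ids at all.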
        DataSet<int[]> sequences = inputSequences
            .map(new RichMapFunction<List<List<String>>, int[]>() {
                transient Map<String, Integer> tokenToId;

                @Override
                public void open(Configuration parameters) throws Exception {
                    tokenToId = new HashMap<>();
                    List<Tuple2<String, Integer>> bc = getRuntimeContext().getBroadcastVariable("itemIndex");
                    bc.forEach(t -> tokenToId.put(t.f0, t.f1));
                }

                @Override
                public int[] map(List<List<String>> elements) throws Exception {
                    List<Integer> seq = new ArrayList<>();
                    seq.add(0);
                    for (List<String> element : elements) {
                        int cnt = 0;
                        for (String it : element) {
                            Integer id = tokenToId.get(it);
                            if (id != null) {
                                cnt++;
                                seq.add(id);
                            }
                        }
                        if (cnt > 0) {
                            seq.add(0);
                        }
                    }
                    int[] sequence = new int[seq.size()];
                    for (int i = 0; i < sequence.length; i++) {
                        sequence[i] = seq.get(i);
                    }
                    return sequence;
                }
            })
            .withBroadcastSet(itemIndex, "itemIndex")
            .name("map_seq_to_int_array");

        DataSet<Tuple2<Integer, Integer>> qualifiedItemCount = itemCounts.join(itemIndex)
            .where(0).equalTo(0).projectSecond(1).projectFirst(1);

        ParallelPrefixSpan ps = new ParallelPrefixSpan(sequences, minSupportCnt, qualifiedItemCount, maxPatternLength);
        DataSet<Tuple2<int[], Integer>> freqPatterns = ps.run();
        DataSet<Tuple4<int[], int[], Integer, double[]>> rules =
            SequenceRule.extractSequenceRules(freqPatterns, sequenceCount, minConfidence);

        // Map the indices in the frequent patterns and rules back to the original strings.
        DataSet<Row> patternsOutput = patternsIndexToString(freqPatterns, itemIndex);
        DataSet<Row> rulesOutput = rulesIndexToString(rules, itemIndex);

        Table table0 = DataSetConversionUtil.toTable(getMLEnvironmentId(),
            patternsOutput, ITEMSETS_COL_NAMES, ITEMSETS_COL_TYPES);
        Table table1 = DataSetConversionUtil.toTable(getMLEnvironmentId(),
            rulesOutput, RULES_COL_NAMES, RULES_COL_TYPES);
        this.setOutputTable(table0);
        this.setSideOutputTables(new Table[] {table1});
        return this;
    }

    /**
     * Count the number of records in the dataset.
     *
     * @return a dataset with a single record: the number of records in "dataSet".
     */
    private static <T> DataSet<Long> count(DataSet<T> dataSet) {
        return dataSet
            .mapPartition(new MapPartitionFunction<T, Long>() {
                @Override
                public void mapPartition(Iterable<T> values, Collector<Long> out) throws Exception {
                    long cnt = 0L;
                    for (T v : values) {
                        cnt++;
                    }
                    out.collect(cnt);
                }
            })
            .name("count_dataset")
            .returns(Types.LONG)
            .reduce(new ReduceFunction<Long>() {
                @Override
                public Long reduce(Long value1, Long value2) throws Exception {
                    return value1 + value2;
                }
            });
    }

    private static DataSet<Long> getMinSupportCnt(DataSet<Long> transactionsCnt,
                                                  final int minSupportCount,
                                                  final double minSupportPercent) {
        return transactionsCnt.map(new MapFunction<Long, Long>() {
            @Override
            public Long map(Long value) throws Exception {
                // A non-negative minSupportCount takes precedence over the percentage.
                if (minSupportCount >= 0) {
                    return (long) minSupportCount;
                } else {
                    return (long) Math.floor(value * minSupportPercent);
                }
            }
        });
    }

    /**
     * Encode a sequence pattern as a string, together with its item count and chain length.
     */
    private static Tuple3<String, Long, Long> encodeSequence(int[] sequence, String[] indexToString) {
        StringBuilder sbd = new StringBuilder();
        int itemSetSize = 0;
        long chainLength = 1L;
        long itemCount = 0L;
        // The leading and trailing zeros of the sequence are skipped.
        for (int i = 1; i < sequence.length - 1; i++) {
            if (sequence[i] == 0) {
                sbd.append(ELEMENT_SEPARATOR);
                chainLength++;
                itemSetSize = 0;
            } else {
                if (itemSetSize > 0) {
                    sbd.append(ITEM_SEPARATOR);
                }
                sbd.append(indexToString[sequence[i]]);
                itemSetSize++;
                itemCount++;
            }
        }
        return Tuple3.of(sbd.toString(), itemCount, chainLength);
    }
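    // A worked example (not in the original source): with indexToString mapping
    // index 1 -> "a" and index 2 -> "b",
    //   encodeSequence(new int[] {0, 1, 0, 1, 2, 0}, new String[] {null, "a", "b"})
    // yields ("a;a,b", 3L, 2L): three items in total, over a chain of two elements.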
    /**
     * Maps item ids back to strings in the frequent patterns.
     *
     * @param patterns  A dataset of: frequent pattern (represented as an int array), support count.
     * @param itemIndex A dataset mapping item names to indices.
     * @return A dataset of frequent patterns, ready for output.
     */
    private static DataSet<Row> patternsIndexToString(DataSet<Tuple2<int[], Integer>> patterns,
                                                      DataSet<Tuple2<String, Integer>> itemIndex) {
        return patterns
            .map(new RichMapFunction<Tuple2<int[], Integer>, Row>() {
                transient String[] itemNames;

                @Override
                public void open(Configuration parameters) throws Exception {
                    List<Tuple2<String, Integer>> bc = getRuntimeContext().getBroadcastVariable("itemIndex");
                    itemNames = new String[bc.size() + 1]; // indices start from 1
                    bc.forEach(t -> itemNames[t.f1] = t.f0);
                }

                @Override
                public Row map(Tuple2<int[], Integer> value) throws Exception {
                    int[] sequence = value.f0;
                    Tuple3<String, Long, Long> encoded = encodeSequence(sequence, itemNames);
                    return Row.of(encoded.f0, value.f1.longValue(), encoded.f1);
                }
            })
            .withBroadcastSet(itemIndex, "itemIndex")
            .name("patternsIndexToString");
    }

    /**
     * Maps item ids back to strings in the sequence rules.
     *
     * @param rules     A dataset of: antecedent, consequent, support count, [support, confidence].
     * @param itemIndex A dataset mapping item names to indices.
     * @return A dataset of sequence rules, ready for output.
     */
    private static DataSet<Row> rulesIndexToString(DataSet<Tuple4<int[], int[], Integer, double[]>> rules,
                                                   DataSet<Tuple2<String, Integer>> itemIndex) {
        return rules
            .map(new RichMapFunction<Tuple4<int[], int[], Integer, double[]>, Row>() {
                transient String[] itemNames;

                @Override
                public void open(Configuration parameters) throws Exception {
                    List<Tuple2<String, Integer>> bc = getRuntimeContext().getBroadcastVariable("itemIndex");
                    itemNames = new String[bc.size() + 1]; // indices start from 1
                    bc.forEach(t -> itemNames[t.f1] = t.f0);
                }

                @Override
                public Row map(Tuple4<int[], int[], Integer, double[]> value) throws Exception {
                    Tuple3<String, Long, Long> antecedent = encodeSequence(value.f0, this.itemNames);
                    Tuple3<String, Long, Long> consequent = encodeSequence(value.f1, this.itemNames);
                    return Row.of(antecedent.f0 + RULE_SEPARATOR + consequent.f0,
                        antecedent.f2 + consequent.f2,
                        value.f3[0], value.f3[1], value.f2.longValue());
                }
            })
            .withBroadcastSet(itemIndex, "itemIndex")
            .name("rulesIndexToString");
    }
}