weibozhao commented on a change in pull request #24: URL: https://github.com/apache/flink-ml/pull/24#discussion_r755678350
########## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModel.java ########## @@ -0,0 +1,594 @@ +package org.apache.flink.ml.classification.knn; + +import org.apache.flink.api.common.eventtime.WatermarkStrategy; +import org.apache.flink.api.common.functions.RichMapFunction; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.connector.source.Source; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.connector.file.sink.FileSink; +import org.apache.flink.connector.file.src.FileSource; +import org.apache.flink.core.fs.Path; +import org.apache.flink.ml.api.core.Model; +import org.apache.flink.ml.common.broadcast.BroadcastUtils; +import org.apache.flink.ml.linalg.DenseMatrix; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.shaded.curator4.com.google.common.base.Preconditions; +import org.apache.flink.shaded.curator4.com.google.common.collect.ImmutableMap; +import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.core.type.TypeReference; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.streaming.api.functions.sink.filesystem.bucketassigners.BasePathBucketAssigner; +import org.apache.flink.streaming.api.functions.sink.filesystem.rollingpolicies.OnCheckpointRollingPolicy; +import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.table.api.DataTypes; +import org.apache.flink.table.api.Schema; +import org.apache.flink.table.api.Table; +import 
org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.table.catalog.ResolvedSchema; +import org.apache.flink.table.types.DataType; +import org.apache.flink.table.types.logical.utils.LogicalTypeParser; +import org.apache.flink.table.types.utils.LogicalTypeDataTypeConverter; +import org.apache.flink.types.Row; + +import org.apache.commons.lang3.ArrayUtils; +import sun.reflect.generics.reflectiveObjects.ParameterizedTypeImpl; + +import java.io.IOException; +import java.lang.reflect.Type; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.PriorityQueue; +import java.util.TreeMap; +import java.util.function.Function; + +/** Knn classification model fitted by estimator. */ +public class KnnModel implements Model<KnnModel>, KnnParams<KnnModel> { + + private static final long serialVersionUID = 1303892137143865652L; + + private static final String BROADCAST_STR = "broadcastModelKey"; + private static final int FASTDISTANCE_TYPE_INDEX = 0; + private static final int DATA_INDEX = 1; + + protected Map<Param<?>, Object> params = new HashMap<>(); + + private Table[] modelData; + + /** constructor. */ + public KnnModel() { + ParamUtils.initializeMapWithDefaultValues(params, this); + } + + /** + * constructor. + * + * @param params parameters for algorithm. + */ + public KnnModel(Map<Param<?>, Object> params) { + this.params = params; + } + + /** + * Set model data for knn prediction. + * + * @param modelData knn model. + * @return knn classification model. + */ + @Override + public KnnModel setModelData(Table... modelData) { + this.modelData = modelData; + return this; + } + + /** + * get model data. + * + * @return list of tables. + */ + @Override + public Table[] getModelData() { + return modelData; + } + + /** + * @param inputs a list of tables. + * @return result. 
+ */ + @Override + public Table[] transform(Table... inputs) { + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + DataStream<Row> input = tEnv.toDataStream(inputs[0]); + DataStream<Row> model = tEnv.toDataStream(modelData[0]); + + Map<String, DataStream<?>> broadcastMap = new HashMap<>(1); + broadcastMap.put(BROADCAST_STR, model); + ResolvedSchema modelSchema = modelData[0].getResolvedSchema(); + DataType idType = + modelSchema.getColumnDataTypes().get(modelSchema.getColumnNames().size() - 1); + + ResolvedSchema outputSchema = + getOutputSchema(inputs[0].getResolvedSchema(), getParamMap(), idType); + + DataType[] dataTypes = outputSchema.getColumnDataTypes().toArray(new DataType[0]); + TypeInformation<?>[] typeInformations = new TypeInformation[dataTypes.length]; + + for (int i = 0; i < dataTypes.length; ++i) { + typeInformations[i] = TypeInformation.of(dataTypes[i].getLogicalType().getClass()); + } + + Function<List<DataStream<?>>, DataStream<Row>> function = + dataStreams -> { + DataStream stream = dataStreams.get(0); + return stream.transform( + "mapFunc", + new RowTypeInfo( + typeInformations, + outputSchema.getColumnNames().toArray(new String[0])), + new PredictOperator( + new KnnRichFunction( + getParamMap(), inputs[0].getResolvedSchema()))); + }; + + DataStream<Row> output = + BroadcastUtils.withBroadcastStream( + Collections.singletonList(input), broadcastMap, function); + return new Table[] {tEnv.fromDataStream(output, resolvedSchema2Schema(outputSchema))}; + } + + /** + * transform resolvedSchema to schema. + * + * @param resolvedSchema input resolvedSchema. + * @return output schema. 
+ */ + private static Schema resolvedSchema2Schema(ResolvedSchema resolvedSchema) { + Schema.Builder builder = Schema.newBuilder(); + List<String> colNames = resolvedSchema.getColumnNames(); + List<DataType> colTypes = resolvedSchema.getColumnDataTypes(); + for (int i = 0; i < colNames.size(); ++i) { + builder.column(colNames.get(i), colTypes.get(i).getLogicalType().toString()); + } + return builder.build(); + } + + private static class KnnRichFunction extends RichMapFunction<Row, Row> { + private boolean firstEle = true; + private final String[] reservedCols; + private String[] selectedCols; + private String vectorCol; + private DataType idType; + private transient KnnModelData modelData; + private final Integer topN; + private Map<String, Object> meta; + + public KnnRichFunction(Map<Param<?>, Object> params, ResolvedSchema dataSchema) { + reservedCols = dataSchema.getColumnNames().toArray(new String[0]); + this.topN = (Integer) params.get(KnnParams.K); + } + + @Override + public Row map(Row row) throws Exception { + if (firstEle) { + loadModel(getRuntimeContext().getBroadcastVariable(BROADCAST_STR)); + firstEle = false; + } + DenseVector vector; + if (null != selectedCols) { + vector = new DenseVector(new double[selectedCols.length]); + for (int i = 0; i < selectedCols.length; i++) { + Preconditions.checkNotNull( + row.getField(selectedCols[i]), "There is NULL in featureCols!"); + vector.set(i, ((Number) row.getField(selectedCols[i])).doubleValue()); + } + } else { + vector = DenseVector.fromString(row.getField(vectorCol).toString()); + } + String s = findNeighbor(vector, topN, modelData).toLowerCase(); + + Row ret = new Row(reservedCols.length + 1); + for (int i = 0; i < reservedCols.length; ++i) { + ret.setField(i, row.getField(reservedCols[i])); + } + + Tuple2<Object, String> tuple2 = getResultFormat(extractObject(s, idType)); + ret.setField(reservedCols.length, tuple2.f0); + return ret; + } + + /** + * find the nearest topN neighbors from whole nodes. 
+ * + * @param input input node. + * @param topN top N. + * @return neighbor. + */ + private String findNeighbor(Object input, Integer topN, KnnModelData modelData) { + PriorityQueue<Tuple2<Double, Object>> priorityQueue = + new PriorityQueue<>(modelData.getQueueComparator()); + search(input, topN, priorityQueue, modelData); + List<Object> items = new ArrayList<>(); + List<Double> metrics = new ArrayList<>(); + while (!priorityQueue.isEmpty()) { + Tuple2<Double, Object> result = priorityQueue.poll(); + items.add(castTo(result.f1, idType)); + metrics.add(result.f0); + } + Collections.reverse(items); + Collections.reverse(metrics); + priorityQueue.clear(); + return serializeResult(items, ImmutableMap.of("METRIC", metrics)); + } + + /** + * serialize result to json format. + * + * @param objectValue the nearest nodes found. + * @param others the metric of nodes. + * @return serialize result. + */ + private String serializeResult(List<Object> objectValue, Map<String, List<Double>> others) { + final String id = "ID"; + Map<String, String> result = + new TreeMap<>( + (o1, o2) -> { + if (id.equals(o1) && id.equals(o2)) { + return 0; + } else if (id.equals(o1)) { + return -1; + } else if (id.equals(o2)) { + return 1; + } + + return o1.compareTo(o2); + }); + + result.put(id, ReadWriteUtils.OBJECT_MAPPER.toJson(objectValue)); + + if (others != null) { + for (Map.Entry<String, List<Double>> other : others.entrySet()) { + result.put(other.getKey(), ReadWriteUtils.OBJECT_MAPPER.toJson(other.getValue())); + } + } + return ReadWriteUtils.OBJECT_MAPPER.toJson(result); + } + + /** + * @param input input node. + * @param topN top N. + * @param priorityQueue priority queue. 
+ */ + private void search( + Object input, + Integer topN, + PriorityQueue<Tuple2<Double, Object>> priorityQueue, + KnnModelData modelData) { + Object sample = prepareSample(input, modelData); + Tuple2<Double, Object> head = null; + for (int i = 0; i < modelData.getLength(); i++) { + ArrayList<Tuple2<Double, Object>> values = computeDistance(sample, i); + if (null == values || values.size() == 0) { + continue; + } + for (Tuple2<Double, Object> currentValue : values) { + if (null == topN) { + priorityQueue.add(Tuple2.of(currentValue.f0, currentValue.f1)); + } else { + head = updateQueue(priorityQueue, topN, currentValue, head); + } + } + } + } + + /** + * update queue. + * + * @param map queue. + * @param topN top N. + * @param newValue new value. + * @param head head value. + * @param <T> id type. + * @return head value. + */ + private <T> Tuple2<Double, T> updateQueue( + PriorityQueue<Tuple2<Double, T>> map, + int topN, + Tuple2<Double, T> newValue, + Tuple2<Double, T> head) { + if (null == newValue) { + return head; + } + if (map.size() < topN) { + map.add(Tuple2.of(newValue.f0, newValue.f1)); + head = map.peek(); + } else { + if (map.comparator().compare(head, newValue) < 0) { + Tuple2<Double, T> peek = map.poll(); + peek.f0 = newValue.f0; + peek.f1 = newValue.f1; + map.add(peek); + head = map.peek(); + } + } + return head; + } + + /** + * prepare sample. + * + * @param input sample to parse. 
+ * @return + */ + private Object prepareSample(Object input, KnnModelData modelData) { + return modelData + .getFastDistance() + .prepareVectorData(Tuple2.of(DenseVector.fromString(input.toString()), null)); + } + + private ArrayList<Tuple2<Double, Object>> computeDistance(Object input, Integer index) { + FastDistanceMatrixData data = modelData.getDictData().get(index); + DenseMatrix res = + modelData.getFastDistance().calc((FastDistanceVectorData) input, data); + ArrayList<Tuple2<Double, Object>> list = new ArrayList<>(0); + Row[] curRows = data.getRows(); + for (int i = 0; i < data.getRows().length; i++) { + Tuple2<Double, Object> tuple = Tuple2.of(res.getData()[i], curRows[i].getField(0)); + list.add(tuple); + } + return list; + } + + /** + * get output format of knn predict result. + * + * @param tuple initial result from knn predictor. + * @return output format result. + */ + private Tuple2<Object, String> getResultFormat(Tuple2<List<Object>, List<Object>> tuple) { + double percent = 1.0 / tuple.f0.size(); + Map<Object, Double> detail = new HashMap<>(0); + + for (Object obj : tuple.f0) { + detail.merge(obj, percent, Double::sum); + } + + double max = 0.0; + Object prediction = null; + + for (Map.Entry<Object, Double> entry : detail.entrySet()) { + if (entry.getValue() > max) { + max = entry.getValue(); + prediction = entry.getKey(); + } + } + + return Tuple2.of(prediction, ReadWriteUtils.OBJECT_MAPPER.toJson(detail)); + } + + /** + * @param json json format result of knn prediction. + * @param idType id type. + * @return List format result. 
+ */ + private Tuple2<List<Object>, List<Object>> extractObject(String json, DataType idType) { + Map<String, String> deserializedJson; + try { + deserializedJson = + ReadWriteUtils.OBJECT_MAPPER.fromJson(json, new TypeReference<Map<String, String>>() {}.getType()); + } catch (Exception e) { + throw new IllegalStateException( + "Fail to deserialize json '" + json + "', please check the input!"); + } + + Map<String, String> lowerCaseDeserializedJson = new HashMap<>(0); + + for (Map.Entry<String, String> entry : deserializedJson.entrySet()) { + lowerCaseDeserializedJson.put( + entry.getKey().trim().toLowerCase(), entry.getValue()); + } + + Map<String, List<Object>> map = new HashMap<>(2); + + Type type = idType.getLogicalType().getDefaultConversion(); + String ids = lowerCaseDeserializedJson.get("id"); + String metric = lowerCaseDeserializedJson.get("metric"); + if (ids == null) { + map.put("id", null); + } else { + map.put( + "id", + ReadWriteUtils.OBJECT_MAPPER.fromJson( + ids, + ParameterizedTypeImpl.make(List.class, new Type[] {type}, null))); + } + + if (ids == null) { + map.put("metric", null); + } else { + map.put( + "metric", + ReadWriteUtils.OBJECT_MAPPER.fromJson( + metric, + ParameterizedTypeImpl.make( + List.class, new Type[] {Double.class}, null))); + } + return Tuple2.of(map.get("id"), map.get("metric")); + } + + private void loadModel(List<Object> broadcastVar) { + List<FastDistanceMatrixData> dictData = new ArrayList<>(); + for (Object obj : broadcastVar) { + Row row = (Row) obj; + if (row.getField(row.getArity() - 2) != null) { + meta = ReadWriteUtils.OBJECT_MAPPER.fromJson((String) row.getField(row.getArity() - 2), HashMap.class); + } + } + for (Object obj : broadcastVar) { + Row row = (Row) obj; + if (row.getField(FASTDISTANCE_TYPE_INDEX) != null) { + long type = (long) row.getField(FASTDISTANCE_TYPE_INDEX); + if (type == 1L) { + dictData.add( + FastDistanceMatrixData.fromString( + (String) row.getField(DATA_INDEX))); + } + } + } + if 
(meta.containsKey(KnnParams.FEATURE_COLS.name)) { + selectedCols = + ReadWriteUtils.OBJECT_MAPPER.fromJson( + (String) meta.get(KnnParams.FEATURE_COLS.name), String[].class); + } else { + vectorCol = + ReadWriteUtils.OBJECT_MAPPER.fromJson((String) meta.get(KnnParams.VECTOR_COL.name), String.class); + } + + modelData = new KnnModelData(dictData, new EuclideanDistance()); + idType = + LogicalTypeDataTypeConverter.toDataType( + LogicalTypeParser.parse((String) this.meta.get("idType"))); + } + } + + /** + * this operator use mapper to load the model data and do the prediction. if you want to write a + * prediction operator, you need implement a special mapper for this operator. + */ + private static class PredictOperator + extends AbstractUdfStreamOperator<Row, RichMapFunction<Row, Row>> + implements OneInputStreamOperator<Row, Row> { + + public PredictOperator(RichMapFunction<Row, Row> userFunction) { + super(userFunction); + } + + @Override + public void processElement(StreamRecord<Row> streamRecord) throws Exception { + Row value = streamRecord.getValue(); + output.collect(new StreamRecord<>(userFunction.map(value))); + } + } + + private ResolvedSchema getOutputSchema( + ResolvedSchema dataSchema, Map<Param<?>, Object> params, DataType idType) { + String[] reservedCols = dataSchema.getColumnNames().toArray(new String[0]); + DataType[] reservedTypes = dataSchema.getColumnDataTypes().toArray(new DataType[0]); + String[] resultCols = new String[] {(String) params.get(KnnParams.PREDICTION_COL)}; + DataType[] resultTypes = new DataType[] {idType}; + return ResolvedSchema.physical( + ArrayUtils.addAll(reservedCols, resultCols), + ArrayUtils.addAll(reservedTypes, resultTypes)); + } + + /** + * cast data x to t type. + * + * @param x data. + * @param t type. + * @return + */ + private static Object castTo(Object x, DataType t) { Review comment: OK, I will refine it later. -- This is an automated message from the Apache Git Service. 
To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: issues-unsubscribe@flink.apache.org. For queries about this service, please contact Infrastructure at: users@infra.apache.org