Github user markap14 commented on a diff in the pull request: https://github.com/apache/nifi/pull/2686#discussion_r187353262 --- Diff: nifi-nar-bundles/nifi-deeplearning4j-bundle/nifi-deeplearning4j-processors/src/main/java/org/apache/nifi/processors/deeplearning4j/DeepLearning4JPredictor.java --- @@ -0,0 +1,218 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.nifi.processors.deeplearning4j; +import org.apache.nifi.annotation.behavior.EventDriven; +import org.apache.nifi.annotation.behavior.InputRequirement; +import org.apache.nifi.annotation.behavior.InputRequirement.Requirement; +import org.apache.nifi.annotation.behavior.SupportsBatching; +import org.apache.nifi.annotation.behavior.WritesAttribute; +import org.apache.nifi.annotation.behavior.WritesAttributes; +import org.apache.nifi.annotation.documentation.CapabilityDescription; +import org.apache.nifi.annotation.documentation.Tags; +import org.apache.nifi.components.PropertyDescriptor; +import org.apache.nifi.flowfile.FlowFile; +import org.apache.nifi.processor.ProcessContext; +import org.apache.nifi.processor.ProcessSession; +import org.apache.nifi.processor.Relationship; +import org.apache.nifi.processor.exception.ProcessException; +import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; +import com.google.gson.Gson; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.nio.charset.Charset; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +@EventDriven +@SupportsBatching +@InputRequirement(Requirement.INPUT_REQUIRED) +@Tags({"deeplearning4j", "dl4j", "predict", "classification", "regression", "deep", "learning"}) +@CapabilityDescription("The DeepLearning4JPredictor predicts one or more value(s) based on provided deeplearning4j (https://github.com/deeplearning4j) model and the content of a FlowFile. " + + "The processor supports both classification and regression by extracting the record from the FlowFile body and applying the model. 
" + + "The processor supports batch by allowing multiple records to be passed in the FlowFile body with each record separated by the 'Record Separator' property. " + + "Each record can contain multiple fields with each field separated by the 'Field Separator' property." + ) +@WritesAttributes({ + @WritesAttribute(attribute = AbstractDeepLearning4JProcessor.DEEPLEARNING4J_ERROR_MESSAGE, description = "Deeplearning4J error message"), + @WritesAttribute(attribute = AbstractDeepLearning4JProcessor.DEEPLEARNING4J_OUTPUT_SHAPE, description = "Deeplearning4J output shape"), + }) +public class DeepLearning4JPredictor extends AbstractDeepLearning4JProcessor { + + static final Relationship REL_SUCCESS = new Relationship.Builder().name("success") + .description("Successful DeepLearning4j results are routed to this relationship").build(); + + static final Relationship REL_FAILURE = new Relationship.Builder().name("failure") + .description("Failed DeepLearning4j results are routed to this relationship").build(); + + protected Gson gson = new Gson(); + + private static final Set<Relationship> relationships; + private static final List<PropertyDescriptor> propertyDescriptors; + static { + final Set<Relationship> tempRelationships = new HashSet<>(); + tempRelationships.add(REL_SUCCESS); + tempRelationships.add(REL_FAILURE); + relationships = Collections.unmodifiableSet(tempRelationships); + final List<PropertyDescriptor> tempDescriptors = new ArrayList<>(); + tempDescriptors.add(MODEL_FILE); + tempDescriptors.add(RECORD_DIMENSIONS); + tempDescriptors.add(CHARSET); + tempDescriptors.add(FIELD_SEPARATOR); + tempDescriptors.add(RECORD_SEPARATOR); + propertyDescriptors = Collections.unmodifiableList(tempDescriptors); + } + + @Override + public Set<Relationship> getRelationships() { + return relationships; + } + + @Override + public final List<PropertyDescriptor> getSupportedPropertyDescriptors() { + return propertyDescriptors; + } + + @Override + public void onTrigger(final 
ProcessContext context, final ProcessSession session) throws ProcessException { + FlowFile flowFile = session.get(); + if ( flowFile == null ) { + return; + } + + Charset charset = Charset.forName(context.getProperty(CHARSET).evaluateAttributeExpressions(flowFile).getValue()); + if ( flowFile.getSize() == 0 ) { + String message = "FlowFile query is empty"; + getLogger().error(message); + flowFile = session.putAttribute(flowFile, DEEPLEARNING4J_ERROR_MESSAGE, message); + session.transfer(flowFile, REL_FAILURE); + return; + } + + String input = null; + try { + input = getFlowFileContents(session, charset, flowFile); + String fieldSeparator = context.getProperty(FIELD_SEPARATOR).evaluateAttributeExpressions(flowFile).getValue(); + String recordSeparator = context.getProperty(RECORD_SEPARATOR).evaluateAttributeExpressions(flowFile).getValue(); + + int [] dimensions = getInputDimensions(context, charset, flowFile, fieldSeparator); + + if ( getLogger().isDebugEnabled() ) { + getLogger().debug("Received input {} with dimensions {}", new Object[] { input, dimensions }); + } + + MultiLayerNetwork model = getModel(context); + + long startTimeMillis = System.currentTimeMillis(); + + String [] inputRecords = input.split(recordSeparator); + + List<INDArray> features = Arrays.stream(inputRecords).map( + record -> { + double [] parameters = Arrays.stream(record.split(fieldSeparator)).mapToDouble( + field -> Double.parseDouble(field)).toArray(); + + INDArray featureInput = Nd4j.create(parameters, dimensions); + + if ( getLogger().isDebugEnabled() ) { + getLogger().debug("Features for record {} parameters {} dims {} featureInput {} ", + new Object[] {record, parameters, dimensions, featureInput}); + } + + return featureInput; + + }).collect(Collectors.toList()); + + INDArray allFeatures = Nd4j.vstack(features); + + INDArray results = model.output(allFeatures); + + double [][] partitionedResults = new double[inputRecords.length][]; + for (int row = 0; row < inputRecords.length; 
row++) { + INDArray result = results.getRow(row); + partitionedResults[row] = Nd4j.toFlattened(result).toDoubleVector(); + } + + String jsonResult = gson.toJson(partitionedResults); + int [] shape = results.shape(); + String jsonShape = gson.toJson(Arrays.copyOfRange(shape, 1, shape.length)); + + if ( getLogger().isDebugEnabled() ) { + getLogger().debug("Prediction for inputRecords {}, dims {}, results {}, result.shape {}, partitionedResults {}, jsonResult {}, shape {}, jsonShape {}", + new Object[] {inputRecords, dimensions, results, Arrays.toString(results.shape()), partitionedResults, jsonResult, shape, jsonShape}); + } + + try (ByteArrayInputStream bais = new ByteArrayInputStream(jsonResult.getBytes(charset))) { --- End diff -- Is it more efficient here (and probably more straightforward?) to simply do: `flowFile = session.write(flowFile, out -> out.write(jsonResult.getBytes(charset)));`
---