AnandInguva commented on code in PR #24965:
URL: https://github.com/apache/beam/pull/24965#discussion_r1067299492


##########
sdks/python/apache_beam/ml/inference/xgboost_inference_it_test.py:
##########
@@ -0,0 +1,314 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging
+
+import pytest
+import unittest
+import uuid
+
+from apache_beam.examples.inference import xgboost_iris_classification
+from apache_beam.io.filesystems import FileSystems
+from apache_beam.testing.test_pipeline import TestPipeline
+
+
+def process_outputs(filepath):
+  with FileSystems().open(filepath) as f:
+    lines = f.readlines()
+  lines = [l.decode('utf-8').strip('\n') for l in lines]
+  return lines
+
+
+@pytest.mark.uses_xgboost
+@pytest.mark.it_postcommit
+class XGBoostInference(unittest.TestCase):
+  def test_iris_classification_numpy_single_batch(self):
+    test_pipeline = TestPipeline(is_integration_test=True)
+    input_type = 'numpy'
+    output_file_dir = '/tmp'

Review Comment:
   Can we set the output file directory to GCS instead, since /tmp gets cleaned up periodically?
https://github.com/apache/beam/blob/c0e689331c2a6573ecf267b9bef133a85ea8a36c/sdks/python/apache_beam/ml/inference/pytorch_inference_it_test.py#L74



##########
sdks/python/apache_beam/ml/inference/xgboost_inference_test.py:
##########
@@ -0,0 +1,443 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import shutil
+import sys
+import tempfile
+import unittest
+import zipfile
+from typing import Any
+from typing import Tuple
+
+import datatable
+import numpy
+import pandas
+import pytest
+import scipy
+
+import apache_beam as beam
+from apache_beam.ml.inference import RunInference
+from apache_beam.ml.inference.base import KeyedModelHandler
+from apache_beam.ml.inference.base import PredictionResult
+from apache_beam.ml.inference.xgboost_inference import XGBoostModelHandlerDatatable
+from apache_beam.ml.inference.xgboost_inference import XGBoostModelHandlerNumpy
+from apache_beam.ml.inference.xgboost_inference import XGBoostModelHandlerPandas
+from apache_beam.ml.inference.xgboost_inference import XGBoostModelHandlerSciPy
+from apache_beam.testing.test_pipeline import TestPipeline
+from apache_beam.testing.util import assert_that
+from apache_beam.testing.util import equal_to
+
+try:
+  import xgboost
+except ImportError:
+  raise unittest.SkipTest('XGBoost dependencies are not installed')
+
+
+def _compare_prediction_result(a: PredictionResult, b: PredictionResult):
+  if isinstance(a.example, scipy.sparse.csr_matrix) and isinstance(

Review Comment:
   For the unit tests to run, let's add these to the tox.ini file with the dependencies needed for the XGBoost tests, similar to torch.
   
   
https://github.com/apache/beam/blob/3a98ecbdbfad59ae8cb96d7eaef544d25d9c437f/sdks/python/tox.ini#L317



##########
sdks/python/apache_beam/examples/inference/xgboost_iris_classification.py:
##########
@@ -0,0 +1,169 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import argparse
+import logging
+from typing import Callable
+from typing import Iterable
+from typing import List
+from typing import Tuple
+from typing import Union
+
+import datatable
+import numpy
+import pandas
+import scipy
+import xgboost
+from sklearn.datasets import load_iris
+from sklearn.model_selection import train_test_split
+
+import apache_beam as beam
+from apache_beam.ml.inference.base import KeyedModelHandler
+from apache_beam.ml.inference.base import PredictionResult
+from apache_beam.ml.inference.base import RunInference
+from apache_beam.ml.inference.xgboost_inference import XGBoostModelHandlerDatatable
+from apache_beam.ml.inference.xgboost_inference import XGBoostModelHandlerNumpy
+from apache_beam.ml.inference.xgboost_inference import XGBoostModelHandlerPandas
+from apache_beam.ml.inference.xgboost_inference import XGBoostModelHandlerSciPy
+from apache_beam.options.pipeline_options import PipelineOptions
+from apache_beam.options.pipeline_options import SetupOptions
+from apache_beam.runners.runner import PipelineResult
+
+
+def _train_model(model_state_output_path: str = '/tmp/model.json', seed=999):
+  """Function to train an XGBoost Classifier using the sklearn Iris dataset"""
+  dataset = load_iris()
+  x_train, _, y_train, _ = train_test_split(
+      dataset['data'], dataset['target'], test_size=.2, random_state=seed)
+  booster = xgboost.XGBClassifier(
+      n_estimators=2, max_depth=2, learning_rate=1, objective='binary:logistic')
+  booster.fit(x_train, y_train)
+  booster.save_model(model_state_output_path)
+  return booster
+
+
+class PostProcessor(beam.DoFn):
+  """Process the PredictionResult to get the predicted label.
+  Returns a comma separated string with true label and predicted label.
+  """
+  def process(self, element: Tuple[int, PredictionResult]) -> Iterable[str]:
+    label, prediction_result = element
+    prediction = prediction_result.inference
+    yield '{},{}'.format(label, prediction)
+
+
+def parse_known_args(argv):
+  """Parses args for the workflow."""
+  parser = argparse.ArgumentParser()
+  parser.add_argument(
+      '--input-type',
+      dest='input_type',
+      required=True,
+      choices=['numpy', 'pandas', 'scipy', 'datatable'],
+      help=
+      'Datatype of the input data.'
+  )
+  parser.add_argument(
+      '--output',
+      dest='output',
+      required=True,
+      help='Path to save output predictions.')
+  parser.add_argument(
+      '--model-state',

Review Comment:
   ```suggestion
         '--model-state',
         '--model_state'
   ```
   
   I would replace `-` with `_`



##########
sdks/python/container/py310/base_image_requirements.txt:
##########
@@ -154,4 +155,5 @@ urllib3==1.26.13
 websocket-client==1.4.2
 Werkzeug==2.2.2
 wrapt==1.14.1
+xgboost==1.7.1

Review Comment:
   I would like to push back on adding these dependencies to the container: they would increase the size of the container image, and they are not dependencies of Apache Beam.



##########
sdks/python/apache_beam/examples/inference/xgboost_iris_classification.py:
##########
@@ -0,0 +1,169 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import argparse
+import logging
+from typing import Callable
+from typing import Iterable
+from typing import List
+from typing import Tuple
+from typing import Union
+
+import datatable
+import numpy
+import pandas
+import scipy
+import xgboost
+from sklearn.datasets import load_iris
+from sklearn.model_selection import train_test_split
+
+import apache_beam as beam
+from apache_beam.ml.inference.base import KeyedModelHandler
+from apache_beam.ml.inference.base import PredictionResult
+from apache_beam.ml.inference.base import RunInference
+from apache_beam.ml.inference.xgboost_inference import XGBoostModelHandlerDatatable
+from apache_beam.ml.inference.xgboost_inference import XGBoostModelHandlerNumpy
+from apache_beam.ml.inference.xgboost_inference import XGBoostModelHandlerPandas
+from apache_beam.ml.inference.xgboost_inference import XGBoostModelHandlerSciPy
+from apache_beam.options.pipeline_options import PipelineOptions
+from apache_beam.options.pipeline_options import SetupOptions
+from apache_beam.runners.runner import PipelineResult
+
+
+def _train_model(model_state_output_path: str = '/tmp/model.json', seed=999):
+  """Function to train an XGBoost Classifier using the sklearn Iris dataset"""
+  dataset = load_iris()
+  x_train, _, y_train, _ = train_test_split(
+      dataset['data'], dataset['target'], test_size=.2, random_state=seed)
+  booster = xgboost.XGBClassifier(
+      n_estimators=2, max_depth=2, learning_rate=1, objective='binary:logistic')
+  booster.fit(x_train, y_train)
+  booster.save_model(model_state_output_path)
+  return booster
+
+
+class PostProcessor(beam.DoFn):
+  """Process the PredictionResult to get the predicted label.
+  Returns a comma separated string with true label and predicted label.
+  """
+  def process(self, element: Tuple[int, PredictionResult]) -> Iterable[str]:
+    label, prediction_result = element
+    prediction = prediction_result.inference
+    yield '{},{}'.format(label, prediction)
+
+
+def parse_known_args(argv):
+  """Parses args for the workflow."""
+  parser = argparse.ArgumentParser()
+  parser.add_argument(
+      '--input-type',
+      dest='input_type',
+      required=True,
+      choices=['numpy', 'pandas', 'scipy', 'datatable'],
+      help=
+      'Datatype of the input data.'
+  )
+  parser.add_argument(
+      '--output',
+      dest='output',
+      required=True,
+      help='Path to save output predictions.')
+  parser.add_argument(
+      '--model-state',

Review Comment:
   Please apply the same change to the other arguments as well.



##########
sdks/python/apache_beam/examples/inference/xgboost_iris_classification.py:
##########
@@ -0,0 +1,169 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import argparse
+import logging
+from typing import Callable
+from typing import Iterable
+from typing import List
+from typing import Tuple
+from typing import Union
+
+import datatable
+import numpy
+import pandas
+import scipy
+import xgboost
+from sklearn.datasets import load_iris
+from sklearn.model_selection import train_test_split
+
+import apache_beam as beam
+from apache_beam.ml.inference.base import KeyedModelHandler
+from apache_beam.ml.inference.base import PredictionResult
+from apache_beam.ml.inference.base import RunInference
+from apache_beam.ml.inference.xgboost_inference import XGBoostModelHandlerDatatable
+from apache_beam.ml.inference.xgboost_inference import XGBoostModelHandlerNumpy
+from apache_beam.ml.inference.xgboost_inference import XGBoostModelHandlerPandas
+from apache_beam.ml.inference.xgboost_inference import XGBoostModelHandlerSciPy
+from apache_beam.options.pipeline_options import PipelineOptions
+from apache_beam.options.pipeline_options import SetupOptions
+from apache_beam.runners.runner import PipelineResult
+
+
+def _train_model(model_state_output_path: str = '/tmp/model.json', seed=999):
+  """Function to train an XGBoost Classifier using the sklearn Iris dataset"""
+  dataset = load_iris()
+  x_train, _, y_train, _ = train_test_split(
+      dataset['data'], dataset['target'], test_size=.2, random_state=seed)
+  booster = xgboost.XGBClassifier(
+      n_estimators=2, max_depth=2, learning_rate=1, objective='binary:logistic')
+  booster.fit(x_train, y_train)
+  booster.save_model(model_state_output_path)
+  return booster
+
+
+class PostProcessor(beam.DoFn):
+  """Process the PredictionResult to get the predicted label.
+  Returns a comma separated string with true label and predicted label.
+  """
+  def process(self, element: Tuple[int, PredictionResult]) -> Iterable[str]:
+    label, prediction_result = element
+    prediction = prediction_result.inference
+    yield '{},{}'.format(label, prediction)
+
+
+def parse_known_args(argv):
+  """Parses args for the workflow."""
+  parser = argparse.ArgumentParser()
+  parser.add_argument(
+      '--input-type',
+      dest='input_type',
+      required=True,
+      choices=['numpy', 'pandas', 'scipy', 'datatable'],
+      help=
+      'Datatype of the input data.'
+  )
+  parser.add_argument(
+      '--output',
+      dest='output',
+      required=True,
+      help='Path to save output predictions.')
+  parser.add_argument(
+      '--model-state',
+      dest='model_state',
+      required=True,
+      help='Path to the state of the XGBoost model loaded for Inference.'
+  )
+  group = parser.add_mutually_exclusive_group(required=True)
+  group.add_argument('--split', action='store_true', dest='split')
+  group.add_argument('--no-split', action='store_false', dest='split')
+  return parser.parse_known_args(argv)
+
+
+def load_sklearn_iris_test_data(
+    data_type: Callable,
+    split: bool = True,
+    seed: int = 999) -> List[Union[numpy.array, pandas.DataFrame]]:
+  """
+    Loads test data from the sklearn Iris dataset in a given format,
+    either in a single or multiple batches.
+    Args:
+      data_type: Datatype of the iris test dataset.
+      split: Split the dataset in different batches or return single batch.
+      seed: Random state for splitting the train and test set.
+  """
+  dataset = load_iris()
+  _, x_test, _, _ = train_test_split(
+      dataset['data'], dataset['target'], test_size=.2, random_state=seed)
+
+  if split:
+    return [(index, data_type(sample.reshape(1, -1))) for index,
+            sample in enumerate(x_test)]
+  return [(0, data_type(x_test))]
+
+
+def run(
+    argv=None, save_main_session=True, test_pipeline=None) -> PipelineResult:
+  """
+    Args:
+      argv: Command line arguments defined for this example.
+      save_main_session: Used for internal testing.
+      test_pipeline: Used for internal testing.
+  """
+  known_args, pipeline_args = parse_known_args(argv)
+  pipeline_options = PipelineOptions(pipeline_args)
+  pipeline_options.view_as(SetupOptions).save_main_session = save_main_session
+
+  data_types = {
+      'numpy': (numpy.array, XGBoostModelHandlerNumpy),
+      'pandas': (pandas.DataFrame, XGBoostModelHandlerPandas),
+      'scipy': (scipy.sparse.csr_matrix, XGBoostModelHandlerSciPy),
+      'datatable': (datatable.Frame, XGBoostModelHandlerDatatable),
+  }
+
+  input_data_type, model_handler = data_types[known_args.input_type]
+
+  xgboost_model_handler = KeyedModelHandler(
+      model_handler(
+          model_class=xgboost.XGBClassifier,
+          model_state=known_args.model_state))
+
+  input_data = load_sklearn_iris_test_data(
+      data_type=input_data_type, split=known_args.split)
+
+  pipeline = test_pipeline
+  if not test_pipeline:
+    pipeline = beam.Pipeline(options=pipeline_options)
+
+  predictions = (
+      pipeline
+      | "ReadInputData" >> beam.Create(input_data)
+      | "RunInference" >> RunInference(xgboost_model_handler)
+      | "PostProcessOutputs" >> beam.ParDo(PostProcessor()))
+
+  _ = predictions | "WriteOutput" >> beam.io.WriteToText(
+      known_args.output, shard_name_template='', append_trailing_newlines=True)
+
+  result = pipeline.run()
+  result.wait_until_finish()
+  return result
+
+
+if __name__ == '__main__':
+  logging.getLogger().setLevel(logging.INFO)
+  _train_model()

Review Comment:
   How long does it take to train the model? I hope it is quick, since it will run on the user's machine.



##########
sdks/python/apache_beam/ml/inference/xgboost_inference.py:
##########
@@ -0,0 +1,212 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import sys
+from abc import ABC
+from typing import Any
+from typing import Callable
+from typing import Dict
+from typing import Iterable
+from typing import Optional
+from typing import Sequence
+from typing import Union
+
+import datatable
+import numpy
+import pandas
+import scipy
+import xgboost
+
+from apache_beam.ml.inference.base import ExampleT
+from apache_beam.ml.inference.base import ModelHandler
+from apache_beam.ml.inference.base import ModelT
+from apache_beam.ml.inference.base import PredictionResult
+from apache_beam.ml.inference.base import PredictionT
+
+
+class XGBoostModelHandler(ModelHandler[ExampleT, PredictionT, ModelT], ABC):

Review Comment:
   Users could (maybe) use it if they want to implement their own run_inference method.
   
   We can also document this in the docstring, in a similar way to
   
https://github.com/apache/beam/blob/c0e689331c2a6573ecf267b9bef133a85ea8a36c/sdks/python/apache_beam/pvalue.py#L351



##########
sdks/python/container/py310/base_image_requirements.txt:
##########
@@ -154,4 +155,5 @@ urllib3==1.26.13
 websocket-client==1.4.2
 Werkzeug==2.2.2
 wrapt==1.14.1
+xgboost==1.7.1

Review Comment:
   I wouldn't put `xgboost` here, since it is not a dependency of Beam.
   
   We can have a requirements file and a Gradle task, and add this Gradle task to the PostCommit suites, similar to how the torch tests are run.
   
   For direct runner, 
https://github.com/apache/beam/blob/c0e689331c2a6573ecf267b9bef133a85ea8a36c/sdks/python/test-suites/direct/common.gradle#L284
   
   For Dataflow runner,  
https://github.com/apache/beam/blob/c0e689331c2a6573ecf267b9bef133a85ea8a36c/sdks/python/test-suites/dataflow/common.gradle#L403
   
   If the IT tests are lightweight and easy to run, we can add them to the DirectRunner suite.
   
   



##########
sdks/python/apache_beam/ml/inference/xgboost_inference_it_test.py:
##########
@@ -0,0 +1,314 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging
+
+import pytest
+import unittest
+import uuid
+
+from apache_beam.examples.inference import xgboost_iris_classification
+from apache_beam.io.filesystems import FileSystems
+from apache_beam.testing.test_pipeline import TestPipeline
+
+
+def process_outputs(filepath):
+  with FileSystems().open(filepath) as f:
+    lines = f.readlines()
+  lines = [l.decode('utf-8').strip('\n') for l in lines]
+  return lines
+
+
+@pytest.mark.uses_xgboost
+@pytest.mark.it_postcommit
+class XGBoostInference(unittest.TestCase):
+  def test_iris_classification_numpy_single_batch(self):
+    test_pipeline = TestPipeline(is_integration_test=True)
+    input_type = 'numpy'
+    output_file_dir = '/tmp'

Review Comment:
   Can we have this change for all the tests?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to