This is an automated email from the ASF dual-hosted git repository.

riteshghorse pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new a13749fd529 [Python] Allow users to pass service name for profiler (#26220)
a13749fd529 is described below

commit a13749fd5296452a65969e0a277aab4554d11dc1
Author: Ritesh Ghorse <riteshgho...@gmail.com>
AuthorDate: Wed May 3 10:38:11 2023 -0400

    [Python] Allow users to pass service name for profiler (#26220)
    
    * handle service name for profiler
    
    * add debug messages
    
    * remove logs
    
    * move import statement and logs
    
    * rm unnecessary check
    
    * add helper to GCloudOption
    
    * rm lint
    
    * added new helper and unit test
    
    * separate functions
    
    * updated unit tests
    
    * refactored plus check for envvar
    
    * updated changes.md
---
 CHANGES.md                                         |  1 +
 .../python/apache_beam/options/pipeline_options.py | 16 ++++++
 .../apache_beam/runners/worker/data_sampler.py     |  1 -
 .../apache_beam/runners/worker/sdk_worker_main.py  | 58 +++++++++++++---------
 .../runners/worker/sdk_worker_main_test.py         | 29 +++++++++++
 5 files changed, 80 insertions(+), 25 deletions(-)

diff --git a/CHANGES.md b/CHANGES.md
index bcd74f7dd51..85b3d03c146 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -66,6 +66,7 @@
 
 ## New Features / Improvements
 
+* Allow passing service name for google-cloud-profiler (Python) ([#26280](https://github.com/apache/beam/issues/26280)).
 * Dead letter queue support added to RunInference in Python ([#24209](https://github.com/apache/beam/issues/24209)).
 * Support added for defining pre/postprocessing operations on the RunInference transform ([#26308](https://github.com/apache/beam/issues/26308))
 
diff --git a/sdks/python/apache_beam/options/pipeline_options.py b/sdks/python/apache_beam/options/pipeline_options.py
index a602912cd97..283c2caa49d 100644
--- a/sdks/python/apache_beam/options/pipeline_options.py
+++ b/sdks/python/apache_beam/options/pipeline_options.py
@@ -22,6 +22,7 @@
 import argparse
 import json
 import logging
+import os
 from typing import Any
 from typing import Callable
 from typing import Dict
@@ -876,6 +877,21 @@ class GoogleCloudOptions(PipelineOptions):
 
     return errors
 
+  def get_cloud_profiler_service_name(self):
+    _ENABLE_GOOGLE_CLOUD_PROFILER = 'enable_google_cloud_profiler'
+    if self.dataflow_service_options:
+      if _ENABLE_GOOGLE_CLOUD_PROFILER in self.dataflow_service_options:
+        return os.environ["JOB_NAME"]
+      for option_name in self.dataflow_service_options:
+        if option_name.startswith(_ENABLE_GOOGLE_CLOUD_PROFILER + '='):
+          return option_name.split('=', 1)[1]
+
+    experiments = self.view_as(DebugOptions).experiments or []
+    if _ENABLE_GOOGLE_CLOUD_PROFILER in experiments:
+      return os.environ["JOB_NAME"]
+
+    return None
+
 
 class AzureOptions(PipelineOptions):
   """Azure Blob Storage options."""
diff --git a/sdks/python/apache_beam/runners/worker/data_sampler.py b/sdks/python/apache_beam/runners/worker/data_sampler.py
index 2b37d008040..7cc8152693d 100644
--- a/sdks/python/apache_beam/runners/worker/data_sampler.py
+++ b/sdks/python/apache_beam/runners/worker/data_sampler.py
@@ -29,7 +29,6 @@ from typing import Dict
 from typing import Iterable
 from typing import List
 from typing import Optional
-from typing import Tuple
 from typing import Union
 
 from apache_beam.coders.coder_impl import CoderImpl
diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker_main.py b/sdks/python/apache_beam/runners/worker/sdk_worker_main.py
index b643034899d..e0545e1f007 100644
--- a/sdks/python/apache_beam/runners/worker/sdk_worker_main.py
+++ b/sdks/python/apache_beam/runners/worker/sdk_worker_main.py
@@ -173,32 +173,42 @@ def create_harness(environment, dry_run=False):
   return fn_log_handler, sdk_harness, sdk_pipeline_options
 
 
+def _start_profiler(gcp_profiler_service_name, gcp_profiler_service_version):
+  try:
+    import googlecloudprofiler
+    if gcp_profiler_service_name and gcp_profiler_service_version:
+      googlecloudprofiler.start(
+          service=gcp_profiler_service_name,
+          service_version=gcp_profiler_service_version,
+          verbose=1)
+      _LOGGER.info('Turning on Google Cloud Profiler.')
+    else:
+      raise RuntimeError('Unable to find the job id or job name from envvar.')
+  except Exception as e:  # pylint: disable=broad-except
+    _LOGGER.warning(
+        'Unable to start google cloud profiler due to error: %s. For how to '
+        'enable Cloud Profiler with Dataflow see '
+        'https://cloud.google.com/dataflow/docs/guides/profiling-a-pipeline.'
+        'For troubleshooting tips with Cloud Profiler see '
+        'https://cloud.google.com/profiler/docs/troubleshooting.' % e)
+
+
+def _get_gcp_profiler_name_if_enabled(sdk_pipeline_options):
+  gcp_profiler_service_name = sdk_pipeline_options.view_as(
+      GoogleCloudOptions).get_cloud_profiler_service_name()
+
+  return gcp_profiler_service_name
+
+
 def main(unused_argv):
   """Main entry point for SDK Fn Harness."""
-  fn_log_handler, sdk_harness, sdk_pipeline_options = create_harness(os.environ)
-  experiments = sdk_pipeline_options.view_as(DebugOptions).experiments or []
-  dataflow_service_options = (
-      sdk_pipeline_options.view_as(GoogleCloudOptions).dataflow_service_options
-      or [])
-  if (_ENABLE_GOOGLE_CLOUD_PROFILER in experiments) or (
-      _ENABLE_GOOGLE_CLOUD_PROFILER in dataflow_service_options):
-    try:
-      import googlecloudprofiler
-      job_id = os.environ["JOB_ID"]
-      job_name = os.environ["JOB_NAME"]
-      if job_id and job_name:
-        googlecloudprofiler.start(
-            service=job_name, service_version=job_id, verbose=1)
-        _LOGGER.info('Turning on Google Cloud Profiler.')
-      else:
-        raise RuntimeError('Unable to find the job id or job name from envvar.')
-    except Exception as e:  # pylint: disable=broad-except
-      _LOGGER.warning(
-          'Unable to start google cloud profiler due to error: %s. For how to '
-          'enable Cloud Profiler with Dataflow see '
-          'https://cloud.google.com/dataflow/docs/guides/profiling-a-pipeline.'
-          'For troubleshooting tips with Cloud Profiler see '
-          'https://cloud.google.com/profiler/docs/troubleshooting.' % e)
+  (fn_log_handler, sdk_harness,
+   sdk_pipeline_options) = create_harness(os.environ)
+
+  gcp_profiler_name = _get_gcp_profiler_name_if_enabled(sdk_pipeline_options)
+  if gcp_profiler_name:
+    _start_profiler(gcp_profiler_name, os.environ["JOB_ID"])
+
   try:
     _LOGGER.info('Python sdk harness starting.')
     sdk_harness.run()
diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker_main_test.py b/sdks/python/apache_beam/runners/worker/sdk_worker_main_test.py
index c7ec4220850..00e09840787 100644
--- a/sdks/python/apache_beam/runners/worker/sdk_worker_main_test.py
+++ b/sdks/python/apache_beam/runners/worker/sdk_worker_main_test.py
@@ -21,6 +21,7 @@
 
 import io
 import logging
+import os
 import unittest
 
 from hamcrest import all_of
@@ -205,6 +206,34 @@ class SdkWorkerMainTest(unittest.TestCase):
         sdk_worker_main._set_log_level_overrides(overrides)
         self.assertIn(expected, cm.output[0])
 
+  def test_gcp_profiler_uses_provided_service_name_when_specified(self):
+    options = PipelineOptions(
+        ['--dataflow_service_options=enable_google_cloud_profiler=sample'])
+    gcp_profiler_name = sdk_worker_main._get_gcp_profiler_name_if_enabled(
+        options)
+    sdk_worker_main._start_profiler = unittest.mock.MagicMock()
+    sdk_worker_main._start_profiler(gcp_profiler_name, "version")
+    sdk_worker_main._start_profiler.assert_called_with("sample", "version")
+
+  @unittest.mock.patch.dict(os.environ, {"JOB_NAME": "sample_job"}, clear=True)
+  def test_gcp_profiler_uses_job_name_when_service_name_not_specified(self):
+    options = PipelineOptions(
+        ['--dataflow_service_options=enable_google_cloud_profiler'])
+    gcp_profiler_name = sdk_worker_main._get_gcp_profiler_name_if_enabled(
+        options)
+    sdk_worker_main._start_profiler = unittest.mock.MagicMock()
+    sdk_worker_main._start_profiler(gcp_profiler_name, "version")
+    sdk_worker_main._start_profiler.assert_called_with("sample_job", "version")
+
+  @unittest.mock.patch.dict(os.environ, {"JOB_NAME": "sample_job"}, clear=True)
+  def test_gcp_profiler_uses_job_name_when_enabled_as_experiment(self):
+    options = PipelineOptions(['--experiment=enable_google_cloud_profiler'])
+    gcp_profiler_name = sdk_worker_main._get_gcp_profiler_name_if_enabled(
+        options)
+    sdk_worker_main._start_profiler = unittest.mock.MagicMock()
+    sdk_worker_main._start_profiler(gcp_profiler_name, "version")
+    sdk_worker_main._start_profiler.assert_called_with("sample_job", "version")
+
 
 if __name__ == '__main__':
   logging.getLogger().setLevel(logging.INFO)

Reply via email to