mohamedawnallah commented on code in PR #35473:
URL: https://github.com/apache/beam/pull/35473#discussion_r2332204148


##########
sdks/python/apache_beam/examples/snippets/transforms/elementwise/enrichment_test.py:
##########
@@ -60,52 +82,227 @@ def validate_enrichment_with_vertex_ai_legacy():
   return expected
 
 
-def std_out_to_dict(stdout_lines, row_key):
-  output_dict = {}
-  for stdout_line in stdout_lines:
-    # parse the stdout in a dictionary format so that it can be
-    # evaluated/compared as one. This allows us to compare without
-    # considering the order of the stdout or the order that the fields of the
-    # row are arranged in.
-    fmtd = '{\"' + stdout_line[4:-1].replace('=', '\": ').replace(
-        ', ', ', \"').replace('\"\'', '\'') + "}"
-    stdout_dict = eval(fmtd)  # pylint: disable=eval-used
-    output_dict[stdout_dict[row_key]] = stdout_dict
-  return output_dict
+def validate_enrichment_with_google_cloudsql_pg():
+  expected = '''[START enrichment_with_google_cloudsql_pg]
+Row(product_id=1, name='A', quantity=2, region_id=3)
+Row(product_id=2, name='B', quantity=3, region_id=1)
+Row(product_id=3, name='C', quantity=10, region_id=4)
+  [END enrichment_with_google_cloudsql_pg]'''.splitlines()[1:-1]
+  return expected
+
+
+def validate_enrichment_with_external_pg():
+  expected = '''[START enrichment_with_external_pg]
+Row(product_id=1, name='A', quantity=2, region_id=3)
+Row(product_id=2, name='B', quantity=3, region_id=1)
+Row(product_id=3, name='C', quantity=10, region_id=4)
+  [END enrichment_with_external_pg]'''.splitlines()[1:-1]
+  return expected
+
+
+def validate_enrichment_with_external_mysql():
+  expected = '''[START enrichment_with_external_mysql]
+Row(product_id=1, name='A', quantity=2, region_id=3)
+Row(product_id=2, name='B', quantity=3, region_id=1)
+Row(product_id=3, name='C', quantity=10, region_id=4)
+  [END enrichment_with_external_mysql]'''.splitlines()[1:-1]
+  return expected
+
+
+def validate_enrichment_with_external_sqlserver():
+  expected = '''[START enrichment_with_external_sqlserver]
+Row(product_id=1, name='A', quantity=2, region_id=3)
+Row(product_id=2, name='B', quantity=3, region_id=1)
+Row(product_id=3, name='C', quantity=10, region_id=4)
+  [END enrichment_with_external_sqlserver]'''.splitlines()[1:-1]
+  return expected
 
 
 @mock.patch('sys.stdout', new_callable=StringIO)
[email protected]_testcontainer
 class EnrichmentTest(unittest.TestCase):
   def test_enrichment_with_bigtable(self, mock_stdout):
     enrichment_with_bigtable()
     output = mock_stdout.getvalue().splitlines()
     expected = validate_enrichment_with_bigtable()
-
-    self.assertEqual(len(output), len(expected))
-    self.assertEqual(
-        std_out_to_dict(output, 'sale_id'),
-        std_out_to_dict(expected, 'sale_id'))
+    self.assertEqual(output, expected)
 
   def test_enrichment_with_vertex_ai(self, mock_stdout):
     enrichment_with_vertex_ai()
     output = mock_stdout.getvalue().splitlines()
     expected = validate_enrichment_with_vertex_ai()
 
-    self.assertEqual(len(output), len(expected))
-    self.assertEqual(
-        std_out_to_dict(output, 'user_id'),
-        std_out_to_dict(expected, 'user_id'))
+    for i in range(len(expected)):
+      self.assertEqual(set(output[i].split(',')), set(expected[i].split(',')))
 
   def test_enrichment_with_vertex_ai_legacy(self, mock_stdout):
     enrichment_with_vertex_ai_legacy()
     output = mock_stdout.getvalue().splitlines()
     expected = validate_enrichment_with_vertex_ai_legacy()
     self.maxDiff = None
+    self.assertEqual(output, expected)
+
+  def test_enrichment_with_google_cloudsql_pg(self, mock_stdout):
+    db_adapter = DatabaseTypeAdapter.POSTGRESQL
+    with EnrichmentTestHelpers.sql_test_context(True, db_adapter):
+      try:
+        enrichment_with_google_cloudsql_pg()
+        output = mock_stdout.getvalue().splitlines()
+        expected = validate_enrichment_with_google_cloudsql_pg()
+        self.assertEqual(output, expected)
+      except Exception as e:
+        self.fail(f"Test failed with unexpected error: {e}")
+
+  def test_enrichment_with_external_pg(self, mock_stdout):
+    db_adapter = DatabaseTypeAdapter.POSTGRESQL
+    with EnrichmentTestHelpers.sql_test_context(False, db_adapter):
+      try:
+        enrichment_with_external_pg()
+        output = mock_stdout.getvalue().splitlines()
+        expected = validate_enrichment_with_external_pg()
+        self.assertEqual(output, expected)
+      except Exception as e:
+        self.fail(f"Test failed with unexpected error: {e}")
+
+  def test_enrichment_with_external_mysql(self, mock_stdout):
+    db_adapter = DatabaseTypeAdapter.MYSQL
+    with EnrichmentTestHelpers.sql_test_context(False, db_adapter):
+      try:
+        enrichment_with_external_mysql()
+        output = mock_stdout.getvalue().splitlines()
+        expected = validate_enrichment_with_external_mysql()
+        self.assertEqual(output, expected)
+      except Exception as e:
+        self.fail(f"Test failed with unexpected error: {e}")
+
+  def test_enrichment_with_external_sqlserver(self, mock_stdout):
+    db_adapter = DatabaseTypeAdapter.SQLSERVER
+    with EnrichmentTestHelpers.sql_test_context(False, db_adapter):
+      try:
+        enrichment_with_external_sqlserver()
+        output = mock_stdout.getvalue().splitlines()
+        expected = validate_enrichment_with_external_sqlserver()
+        self.assertEqual(output, expected)
+      except Exception as e:
+        self.fail(f"Test failed with unexpected error: {e}")
+
+
+@dataclass
+class CloudSQLEnrichmentTestDataConstruct:
+  client_handler: Callable[[], DBAPIConnection]
+  engine: Engine
+  metadata: MetaData
+  db: SQLDBContainerInfo = None
+
+
+class EnrichmentTestHelpers:
+  @contextmanager
+  def sql_test_context(is_cloudsql: bool, db_adapter: DatabaseTypeAdapter):
+    result: Optional[CloudSQLEnrichmentTestDataConstruct] = None
+    try:
+      result = EnrichmentTestHelpers.pre_sql_enrichment_test(
+          is_cloudsql, db_adapter)
+      yield
+    finally:
+      if result:
+        EnrichmentTestHelpers.post_sql_enrichment_test(result)
+
+  @staticmethod
+  def pre_sql_enrichment_test(
+      is_cloudsql: bool,
+      db_adapter: DatabaseTypeAdapter) -> CloudSQLEnrichmentTestDataConstruct:
+    table_id = "products"
+    columns = [
+        Column("product_id", Integer, primary_key=True),
+        Column("name", VARCHAR(255), nullable=False),
+        Column("quantity", Integer, nullable=False),
+        Column("region_id", Integer, nullable=False),
+    ]
+    table_data = [
+        {
+            "product_id": 1, "name": "A", 'quantity': 2, 'region_id': 3
+        },
+        {
+            "product_id": 2, "name": "B", 'quantity': 3, 'region_id': 1
+        },
+        {
+            "product_id": 3, "name": "C", 'quantity': 10, 'region_id': 4
+        },
+    ]
+    metadata = MetaData()
+
+    connection_config: ConnectionConfig
+    if is_cloudsql:
+      gcp_project_id = "apache-beam-testing"
+      region = "us-central1"
+      instance_name = "beam-integration-tests"
+      instance_connection_uri = f"{gcp_project_id}:{region}:{instance_name}"
+      db_id = "postgres"
+      user = "postgres"
+      password = os.getenv("ALLOYDB_PASSWORD")
+      os.environ['GOOGLE_CLOUD_SQL_DB_URI'] = instance_connection_uri
+      os.environ['GOOGLE_CLOUD_SQL_DB_ID'] = db_id
+      os.environ['GOOGLE_CLOUD_SQL_DB_USER'] = user
+      os.environ['GOOGLE_CLOUD_SQL_DB_PASSWORD'] = password
+      os.environ['GOOGLE_CLOUD_SQL_DB_TABLE_ID'] = table_id
+      connection_config = CloudSQLConnectionConfig(
+          db_adapter=db_adapter,
+          instance_connection_uri=instance_connection_uri,
+          user=user,
+          password=password,
+          db_id=db_id)
+    else:
+      db = SQLEnrichmentTestHelper.start_sql_db_container(db_adapter)

Review Comment:
   > I think we need something like this in both paths. The new error I'm 
seeing in CI is:
   > 
   > ```
   >       SQLEnrichmentTestHelper.create_table(
   >           table_id=table_id,
   >           engine=engine,
   >           columns=columns,
   >           table_data=table_data,
   >           metadata=metadata)
   >     
   >       result = CloudSQLEnrichmentTestDataConstruct(
   > >         db=db, client_handler=conenctor, engine=engine, 
metadata=metadata)
   > E     UnboundLocalError: local variable 'db' referenced before assignment
   > ```
   > 
   > Looks like the pass-through config did work though
   
   Fixed the unbound local variable (`db`) error.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to