jason810496 commented on code in PR #60543:
URL: https://github.com/apache/airflow/pull/60543#discussion_r2692815430


##########
providers/databricks/src/airflow/providers/databricks/operators/databricks_sql.py:
##########
@@ -127,41 +148,151 @@ def get_db_hook(self) -> DatabricksSqlHook:
     def _should_run_output_processing(self) -> bool:
         return self.do_xcom_push or bool(self._output_path)
 
+    def _is_gcs_path(self, path: str) -> bool:
+        """Check if the path is a GCS URI."""
+        return path.startswith("gs://")
+
+    def _parse_gcs_path(self, path: str) -> tuple[str, str]:
+        """Parse a GCS URI into bucket and object name."""
+        parsed = urlparse(path)
+        bucket = parsed.netloc
+        object_name = parsed.path.lstrip("/")
+        return bucket, object_name
+
+    def _upload_to_gcs(self, local_path: str, gcs_path: str) -> None:
+        """Upload a local file to GCS."""
+        try:
+            from airflow.providers.google.cloud.hooks.gcs import GCSHook
+        except ImportError:
+            raise AirflowException(

Review Comment:
   ```suggestion
               raise AirflowOptionalProviderFeatureException(
   ```



##########
providers/databricks/src/airflow/providers/databricks/operators/databricks_sql.py:
##########
@@ -127,41 +148,151 @@ def get_db_hook(self) -> DatabricksSqlHook:
     def _should_run_output_processing(self) -> bool:
         return self.do_xcom_push or bool(self._output_path)
 
+    def _is_gcs_path(self, path: str) -> bool:
+        """Check if the path is a GCS URI."""
+        return path.startswith("gs://")
+
+    def _parse_gcs_path(self, path: str) -> tuple[str, str]:
+        """Parse a GCS URI into bucket and object name."""
+        parsed = urlparse(path)
+        bucket = parsed.netloc
+        object_name = parsed.path.lstrip("/")
+        return bucket, object_name
+
+    def _upload_to_gcs(self, local_path: str, gcs_path: str) -> None:
+        """Upload a local file to GCS."""
+        try:
+            from airflow.providers.google.cloud.hooks.gcs import GCSHook
+        except ImportError:
+            raise AirflowException(
+                "The 'apache-airflow-providers-google' package is required for 
GCS output. "
+                "Install it with: pip install apache-airflow-providers-google"
+            )
+
+        bucket, object_name = self._parse_gcs_path(gcs_path)
+        hook = GCSHook(
+            gcp_conn_id=self._gcp_conn_id,
+            impersonation_chain=self._gcs_impersonation_chain,
+        )
+        hook.upload(
+            bucket_name=bucket,
+            object_name=object_name,
+            filename=local_path,
+        )
+        self.log.info("Uploaded output to %s", gcs_path)
+
+    def _write_parquet(self, file_path: str, field_names: list[str], rows: 
list[Any]) -> None:
+        """Write data to a Parquet file."""
+        import pyarrow as pa
+        import pyarrow.parquet as pq
+
+        data: dict[str, list] = {name: [] for name in field_names}
+        for row in rows:
+            row_dict = row._asdict()
+            for name in field_names:
+                data[name].append(row_dict[name])
+
+        table = pa.Table.from_pydict(data)
+        pq.write_table(table, file_path)
+
+    def _write_avro(self, file_path: str, field_names: list[str], rows: 
list[Any]) -> None:
+        """Write data to an Avro file using fastavro."""
+        try:
+            from fastavro import writer
+        except ImportError:
+            raise AirflowException(
+                "The 'fastavro' package is required for Avro output. Install 
it with: pip install fastavro"
+            )
+
+        data: dict[str, list] = {name: [] for name in field_names}
+        for row in rows:
+            row_dict = row._asdict()
+            for name in field_names:
+                data[name].append(row_dict[name])
+
+        schema_fields = []
+        for name in field_names:
+            sample_val = next(
+                (data[name][i] for i in range(len(data[name])) if 
data[name][i] is not None), None
+            )
+            if sample_val is None:
+                avro_type = ["null", "string"]
+            elif isinstance(sample_val, bool):
+                avro_type = ["null", "boolean"]
+            elif isinstance(sample_val, int):
+                avro_type = ["null", "long"]
+            elif isinstance(sample_val, float):
+                avro_type = ["null", "double"]
+            else:
+                avro_type = ["null", "string"]
+            schema_fields.append({"name": name, "type": avro_type})
+
+        avro_schema = {
+            "type": "record",
+            "name": "QueryResult",
+            "fields": schema_fields,
+        }
+
+        records = [row._asdict() for row in rows]
+        with open(file_path, "wb") as f:
+            writer(f, avro_schema, records)
+
     def _process_output(self, results: list[Any], descriptions: 
list[Sequence[Sequence] | None]) -> list[Any]:
         if not self._output_path:
             return list(zip(descriptions, results))
         if not self._output_format:
             raise AirflowException("Output format should be specified!")
-        # Output to a file only the result of last query
+
         last_description = descriptions[-1]
         last_results = results[-1]
         if last_description is None:
-            raise AirflowException("There is missing description present for 
the output file. .")
+            raise AirflowException("There is missing description present for 
the output file.")
         field_names = [field[0] for field in last_description]
-        if self._output_format.lower() == "csv":
-            with open(self._output_path, "w", newline="") as file:
-                if self._csv_params:
-                    csv_params = self._csv_params
-                else:
-                    csv_params = {}
-                write_header = csv_params.get("header", True)
-                if "header" in csv_params:
-                    del csv_params["header"]
-                writer = csv.DictWriter(file, fieldnames=field_names, 
**csv_params)
-                if write_header:
-                    writer.writeheader()
-                for row in last_results:
-                    writer.writerow(row._asdict())
-        elif self._output_format.lower() == "json":
-            with open(self._output_path, "w") as file:
-                file.write(json.dumps([row._asdict() for row in last_results]))
-        elif self._output_format.lower() == "jsonl":
-            with open(self._output_path, "w") as file:
-                for row in last_results:
-                    file.write(json.dumps(row._asdict()))
-                    file.write("\n")
+
+        is_gcs = self._is_gcs_path(self._output_path)
+        if is_gcs:

Review Comment:
   Would it be better to turn this into a property?



##########
providers/databricks/src/airflow/providers/databricks/operators/databricks_sql.py:
##########
@@ -127,41 +148,151 @@ def get_db_hook(self) -> DatabricksSqlHook:
     def _should_run_output_processing(self) -> bool:
         return self.do_xcom_push or bool(self._output_path)
 
+    def _is_gcs_path(self, path: str) -> bool:
+        """Check if the path is a GCS URI."""
+        return path.startswith("gs://")
+
+    def _parse_gcs_path(self, path: str) -> tuple[str, str]:
+        """Parse a GCS URI into bucket and object name."""
+        parsed = urlparse(path)
+        bucket = parsed.netloc
+        object_name = parsed.path.lstrip("/")
+        return bucket, object_name
+
+    def _upload_to_gcs(self, local_path: str, gcs_path: str) -> None:
+        """Upload a local file to GCS."""
+        try:
+            from airflow.providers.google.cloud.hooks.gcs import GCSHook
+        except ImportError:
+            raise AirflowException(
+                "The 'apache-airflow-providers-google' package is required for 
GCS output. "
+                "Install it with: pip install apache-airflow-providers-google"
+            )
+
+        bucket, object_name = self._parse_gcs_path(gcs_path)
+        hook = GCSHook(
+            gcp_conn_id=self._gcp_conn_id,
+            impersonation_chain=self._gcs_impersonation_chain,
+        )
+        hook.upload(
+            bucket_name=bucket,
+            object_name=object_name,
+            filename=local_path,
+        )
+        self.log.info("Uploaded output to %s", gcs_path)
+
+    def _write_parquet(self, file_path: str, field_names: list[str], rows: 
list[Any]) -> None:
+        """Write data to a Parquet file."""
+        import pyarrow as pa
+        import pyarrow.parquet as pq
+
+        data: dict[str, list] = {name: [] for name in field_names}
+        for row in rows:
+            row_dict = row._asdict()
+            for name in field_names:
+                data[name].append(row_dict[name])
+
+        table = pa.Table.from_pydict(data)
+        pq.write_table(table, file_path)
+
+    def _write_avro(self, file_path: str, field_names: list[str], rows: 
list[Any]) -> None:
+        """Write data to an Avro file using fastavro."""
+        try:
+            from fastavro import writer
+        except ImportError:
+            raise AirflowException(
+                "The 'fastavro' package is required for Avro output. Install 
it with: pip install fastavro"
+            )
+
+        data: dict[str, list] = {name: [] for name in field_names}
+        for row in rows:
+            row_dict = row._asdict()
+            for name in field_names:
+                data[name].append(row_dict[name])
+
+        schema_fields = []
+        for name in field_names:
+            sample_val = next(
+                (data[name][i] for i in range(len(data[name])) if 
data[name][i] is not None), None
+            )
+            if sample_val is None:
+                avro_type = ["null", "string"]
+            elif isinstance(sample_val, bool):
+                avro_type = ["null", "boolean"]
+            elif isinstance(sample_val, int):
+                avro_type = ["null", "long"]
+            elif isinstance(sample_val, float):
+                avro_type = ["null", "double"]
+            else:
+                avro_type = ["null", "string"]
+            schema_fields.append({"name": name, "type": avro_type})
+
+        avro_schema = {
+            "type": "record",
+            "name": "QueryResult",
+            "fields": schema_fields,
+        }
+
+        records = [row._asdict() for row in rows]
+        with open(file_path, "wb") as f:
+            writer(f, avro_schema, records)
+
     def _process_output(self, results: list[Any], descriptions: 
list[Sequence[Sequence] | None]) -> list[Any]:
         if not self._output_path:
             return list(zip(descriptions, results))
         if not self._output_format:
             raise AirflowException("Output format should be specified!")
-        # Output to a file only the result of last query
+
         last_description = descriptions[-1]
         last_results = results[-1]
         if last_description is None:
-            raise AirflowException("There is missing description present for 
the output file. .")
+            raise AirflowException("There is missing description present for 
the output file.")
         field_names = [field[0] for field in last_description]
-        if self._output_format.lower() == "csv":
-            with open(self._output_path, "w", newline="") as file:
-                if self._csv_params:
-                    csv_params = self._csv_params
-                else:
-                    csv_params = {}
-                write_header = csv_params.get("header", True)
-                if "header" in csv_params:
-                    del csv_params["header"]
-                writer = csv.DictWriter(file, fieldnames=field_names, 
**csv_params)
-                if write_header:
-                    writer.writeheader()
-                for row in last_results:
-                    writer.writerow(row._asdict())
-        elif self._output_format.lower() == "json":
-            with open(self._output_path, "w") as file:
-                file.write(json.dumps([row._asdict() for row in last_results]))
-        elif self._output_format.lower() == "jsonl":
-            with open(self._output_path, "w") as file:
-                for row in last_results:
-                    file.write(json.dumps(row._asdict()))
-                    file.write("\n")
+
+        is_gcs = self._is_gcs_path(self._output_path)
+        if is_gcs:
+            suffix = f".{self._output_format.lower()}"
+            tmp_file = NamedTemporaryFile(mode="w", suffix=suffix, 
delete=False, newline="")
+            local_path = tmp_file.name
+            tmp_file.close()
         else:
-            raise AirflowException(f"Unsupported output format: 
'{self._output_format}'")
+            local_path = self._output_path
+
+        try:
+            output_format = self._output_format.lower()
+            if output_format == "csv":
+                with open(local_path, "w", newline="") as file:
+                    if self._csv_params:
+                        csv_params = self._csv_params.copy()
+                    else:
+                        csv_params = {}
+                    write_header = csv_params.pop("header", True)
+                    writer = csv.DictWriter(file, fieldnames=field_names, 
**csv_params)
+                    if write_header:
+                        writer.writeheader()
+                    for row in last_results:
+                        writer.writerow(row._asdict())
+            elif output_format == "json":
+                with open(local_path, "w") as file:
+                    file.write(json.dumps([row._asdict() for row in 
last_results]))
+            elif output_format == "jsonl":
+                with open(local_path, "w") as file:
+                    for row in last_results:
+                        file.write(json.dumps(row._asdict()))
+                        file.write("\n")
+            elif output_format == "parquet":
+                self._write_parquet(local_path, field_names, last_results)
+            elif output_format == "avro":
+                self._write_avro(local_path, field_names, last_results)
+            else:
+                raise AirflowException(f"Unsupported output format: 
'{self._output_format}'")

Review Comment:
   ```suggestion
                   raise ValueError(f"Unsupported output format: '{self._output_format}'")
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to