moomindani commented on code in PR #68519:
URL: https://github.com/apache/airflow/pull/68519#discussion_r3421242209
##########
providers/databricks/src/airflow/providers/databricks/operators/databricks.py:
##########
@@ -281,6 +284,50 @@ def _inject_airflow_params_into_task(task: dict, params:
dict) -> None:
task_def[field] = dict(params)
+def _coerce_json_to_dict(json: Any) -> dict[str, Any]:
+ if json is None:
+ return {}
+ if isinstance(json, Mapping):
+ return dict(json)
+ if isinstance(json, str):
+ return _parse_json_string_to_dict(json)
+ raise DatabricksOperatorPayloadError(
+ f"Databricks json payload must resolve to a mapping, not
{type(json).__name__}."
+ )
+
+
+def _parse_json_string_to_dict(json: str) -> dict[str, Any]:
+ if not json:
+ return {}
+ try:
+ parsed_json = json_utils.loads(json)
+ except json_utils.JSONDecodeError:
+ try:
+ parsed_json = ast.literal_eval(json)
Review Comment:
The `ast.literal_eval` fallback silently accepts Python dict literals (e.g.
`{'key': 'value'}` with single quotes) that are not valid JSON. This widens the
accepted input format beyond JSON without documentation.
If accepting Python dict literals is intentional, please add a docstring to the
function that documents it. Otherwise, consider re-raising the `JSONDecodeError`
with a clear message stating that only JSON is supported.
##########
providers/databricks/src/airflow/providers/databricks/operators/databricks.py:
##########
@@ -430,14 +498,16 @@ def _hook(self):
)
def execute(self, context: Context) -> int:
- if "name" not in self.json:
+ json = cast("dict[str, Any]",
normalise_json_content(self._get_merged_json()))
+ if "name" not in json:
raise AirflowException("Missing required parameter: name")
- job_id = self._hook.find_job_id_by_name(self.json["name"])
- if not self.json.get("parameters") and self.params:
- self.json["parameters"] = [{"name": k, "default": v} for k, v in
dict(self.params).items()]
+ job_id = self._hook.find_job_id_by_name(json["name"])
+ if not json.get("parameters") and self.params:
+ json["parameters"] = [{"name": k, "default": v} for k, v in
dict(self.params).items()]
+ self.json = json
Review Comment:
Overwriting `self.json` with the merged, normalised dict breaks template
re-rendering on retry: if the original `self.json` was an XCom template string
(e.g. `"{{ ti.xcom_pull(task_ids='payload') }}"`), Airflow calls
`render_template_fields` again before re-running `execute()`. Because
`self.json` is now a plain dict rather than the original template, the
re-render is a no-op and the retry silently reuses the stale value from the
first run.
Consider keeping the original input in a non-template-field attribute (e.g.
`self._json_input`) and deriving the merged payload from it each time, rather
than storing the merged result back in `self.json`.
##########
providers/databricks/tests/unit/databricks/operators/test_databricks.py:
##########
@@ -867,7 +893,37 @@ def test_init_with_templating(self):
"run_name": TASK_ID,
}
)
- assert expected == utils.normalise_json_content(op.json)
+ assert expected == utils.normalise_json_content(op._get_merged_json())
+
+
@mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
+ def test_exec_with_xcom_arg_json_and_templated_named_parameters(self,
db_mock_class):
+ with DAG("test", schedule=None, start_date=datetime.now()):
Review Comment:
`datetime.now()` is prohibited in tests by the project's testing standards.
Use a fixed `datetime` constant for `start_date` instead (e.g. `DEFAULT_DATE`,
which is already defined in this file).
##########
providers/databricks/src/airflow/providers/databricks/operators/databricks.py:
##########
@@ -992,26 +1121,32 @@ def _get_hook(self, caller: str) -> DatabricksHook:
)
def execute(self, context: Context):
+ json = self._get_merged_json()
+ self._validate_merged_json(json)
hook = self._hook
- if "job_name" in self.json:
- job_id = hook.find_job_id_by_name(self.json["job_name"])
+ if "job_name" in json:
+ job_id = hook.find_job_id_by_name(json["job_name"])
if job_id is None:
- raise AirflowException(f"Job ID for job name
{self.json['job_name']} can not be found")
- self.json["job_id"] = job_id
- del self.json["job_name"]
+ raise DatabricksOperatorPayloadError(
+ f"Job ID for job name {json['job_name']} can not be found"
+ )
+ json["job_id"] = job_id
+ del json["job_name"]
if self.cancel_previous_runs:
- if (job_id := self.json.get("job_id")) is None:
+ if (job_id := json.get("job_id")) is None:
raise ValueError(
"cancel_previous_runs=True requires either job_id or
job_name to be provided."
)
hook.cancel_all_runs(job_id)
- if not self.json.get("job_parameters") and self.params:
- self.json["job_parameters"] = dict(self.params)
+ json = cast("dict[str, Any]", normalise_json_content(json))
+ if not json.get("job_parameters") and self.params:
+ json["job_parameters"] = dict(self.params)
- self.run_id = hook.run_now(self.json)
+ self.json = json
Review Comment:
Same issue as in the other operators: storing the merged result in
`self.json` means a retry will not re-render the original template.
Note: the deferrable path depends on `self.json` being set here so that
`execute_complete` can read `job_parameters` from it. If you keep this
assignment for the deferrable path, make sure the original template is
preserved separately for the retry path.
##########
providers/databricks/src/airflow/providers/databricks/operators/databricks.py:
##########
@@ -674,28 +779,31 @@ def _get_hook(self, caller: str) -> DatabricksHook:
)
def execute(self, context: Context):
+ json = self._get_merged_json()
+ self._validate_merged_json(json)
if (
- "pipeline_task" in self.json
- and self.json["pipeline_task"].get("pipeline_id") is None
- and self.json["pipeline_task"].get("pipeline_name")
+ isinstance(json.get("pipeline_task"), Mapping)
+ and json["pipeline_task"].get("pipeline_id") is None
+ and json["pipeline_task"].get("pipeline_name")
):
# If pipeline_id is not provided, we need to fetch it from the
pipeline_name
- pipeline_name = self.json["pipeline_task"]["pipeline_name"]
- self.json["pipeline_task"]["pipeline_id"] =
self._hook.find_pipeline_id_by_name(pipeline_name)
- del self.json["pipeline_task"]["pipeline_name"]
+ pipeline_name = json["pipeline_task"]["pipeline_name"]
+ json["pipeline_task"] = dict(json["pipeline_task"])
+ json["pipeline_task"]["pipeline_id"] =
self._hook.find_pipeline_id_by_name(pipeline_name)
+ del json["pipeline_task"]["pipeline_name"]
if self.params:
params_dump = dict(self.params)
- tasks = self.json.get("tasks")
+ tasks = json.get("tasks")
if isinstance(tasks, list):
for task in tasks:
if isinstance(task, dict):
_inject_airflow_params_into_task(task, params_dump)
else:
- _inject_airflow_params_into_task(self.json, params_dump)
+ _inject_airflow_params_into_task(json, params_dump)
- json_normalised = normalise_json_content(self.json)
- self.run_id = self._hook.submit_run(json_normalised)
+ self.json = normalise_json_content(json)
Review Comment:
Same issue as in `DatabricksCreateJobsOperator.execute()`: overwriting
`self.json` with the merged normalised dict loses the original template string
and breaks re-rendering on retry.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]