Owen-CH-Leung commented on code in PR #53821:
URL: https://github.com/apache/airflow/pull/53821#discussion_r2262026037


##########
providers/elasticsearch/tests/unit/elasticsearch/log/test_es_task_handler.py:
##########
@@ -963,3 +942,147 @@ def test_self_not_valid_arg():
     Test if self is not a valid argument.
     """
     assert "self" not in VALID_ES_CONFIG_KEYS
+
+
[email protected]_test
+class TestElasticsearchRemoteLogIO:
+    DAG_ID = "dag_for_testing_es_log_handler"
+    TASK_ID = "task_for_testing_es_log_handler"
+    LOGICAL_DATE = datetime(2016, 1, 1)
+    FILENAME_TEMPLATE = "{try_number}.log"
+
+    @pytest.fixture(autouse=True)
+    def setup_tests(self, ti, es_8_container_url):
+        self.elasticsearch_8_url = es_8_container_url
+        self.elasticsearch_io = ElasticsearchRemoteLogIO(
+            write_to_es=True,
+            write_stdout=True,
+            delete_local_copy=True,
+            host=es_8_container_url,
+            base_log_folder=Path(""),
+        )
+
+    @pytest.fixture
+    def tmp_json_file(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            os.makedirs(tmpdir, exist_ok=True)
+
+            file_path = os.path.join(tmpdir, "1.log")
+            self.tmp_file = file_path
+
+            sample_logs = [
+                {"message": "start"},
+                {"message": "processing"},
+                {"message": "end"},
+            ]
+            with open(file_path, "w") as f:
+                for log in sample_logs:
+                    f.write(json.dumps(log) + "\n")
+
+            yield file_path
+
+            del self.tmp_file
+
+    @pytest.fixture
+    def ti(self, create_task_instance, create_log_template):
+        create_log_template(
+            self.FILENAME_TEMPLATE,
+            (
+                "{dag_id}-{task_id}-{logical_date}-{try_number}"
+                if AIRFLOW_V_3_0_PLUS
+                else "{dag_id}-{task_id}-{execution_date}-{try_number}"
+            ),
+        )
+        yield get_ti(
+            dag_id=self.DAG_ID,
+            task_id=self.TASK_ID,
+            logical_date=self.LOGICAL_DATE,
+            create_task_instance=create_task_instance,
+        )
+        clear_db_runs()
+        clear_db_dags()
+
+    @pytest.fixture
+    def unique_index(self):
+        """Generate a unique index name for each test."""
+        return f"airflow-logs-{uuid.uuid4()}"
+
+    @pytest.fixture
+    def write_to_es(self, tmp_json_file, ti, unique_index):
+        self.elasticsearch_io.target_index = unique_index
+        self.elasticsearch_io.index_pattern = unique_index
+        self.elasticsearch_io.upload(tmp_json_file, ti)
+        self.elasticsearch_io.client.indices.refresh(index="_all")
+
+    def test_write_to_es(self, tmp_json_file, ti):
+        self.elasticsearch_io.write_stdout = False
+        self.elasticsearch_io.upload(tmp_json_file, ti)
+        self.elasticsearch_io.client.indices.refresh(index="_all")
+        res = self.elasticsearch_io.client.search(index="_all", query={"match_all": {}})
+
+        offset = 1
+        expected_msg = ["start", "processing", "end"]
+        expected_log_id = f"{ti.dag_id}-{ti.task_id}-{ti.run_id}-{ti.map_index}-{ti.try_number}"
+        assert res["hits"]["total"]["value"] == 3
+        for msg, hit in zip(expected_msg, res["hits"]["hits"]):
+            assert hit["_index"] == "airflow-logs"
+            assert hit["_source"]["message"] == msg
+            assert hit["_source"]["offset"] == offset
+            assert hit["_source"]["log_id"] == expected_log_id
+            offset += 1
+        self.elasticsearch_io.client.indices.delete(index="airflow-logs")
+
+    def test_write_to_stdout(self, tmp_json_file, ti, capsys):
+        self.elasticsearch_io.write_to_es = False
+        self.elasticsearch_io.upload(tmp_json_file, ti)
+
+        captured = capsys.readouterr()
+        stdout_lines = captured.out.strip().splitlines()
+
+        log_entries = [json.loads(line) for line in stdout_lines]
+        assert log_entries[0]["message"] == "start"
+        assert log_entries[1]["message"] == "processing"
+        assert log_entries[2]["message"] == "end"
+
+    def test_invalid_task_log_file_path(self, ti):
+        with (
+            patch.object(self.elasticsearch_io, "_parse_raw_log") as mock_parse,
+            patch.object(self.elasticsearch_io, "_write_to_es") as mock_write,
+        ):
+            self.elasticsearch_io.upload(Path("/invalid/path"), ti)
+
+            mock_parse.assert_not_called()
+            mock_write.assert_not_called()
+
+    def test_raw_log_should_contain_log_id_and_offset(self, tmp_json_file, ti):
+        with open(self.tmp_file) as f:
+            raw_log = f.read()
+        json_log_lines = self.elasticsearch_io._parse_raw_log(raw_log, ti)
+        assert len(json_log_lines) == 3
+        for json_log_line in json_log_lines:
+            assert "log_id" in json_log_line
+            assert "offset" in json_log_line
+
+    @patch(
+        "airflow.providers.elasticsearch.log.es_task_handler.TASK_LOG_FIELDS",
+        ["message"],
+    )
+    def test_read_es_log(self, write_to_es, ti):
+        log_source_info, log_messages = self.elasticsearch_io.read("", ti)
+        assert log_source_info[0] == self.elasticsearch_8_url
+        assert len(log_messages) == 3
+
+        expected_msg = ["start", "processing", "end"]
+        for msg, log_message in zip(expected_msg, log_messages):
+            json_log = json.loads(log_message)
+            assert "message" in json_log
+            assert json_log["message"] == msg
+
+    # @patch.object(ElasticsearchRemoteLogIO, "_get_index_patterns", return_value="invalid")

Review Comment:
   Revised. Thanks!


