This is an automated email from the ASF dual-hosted git repository.

colinlee pushed a commit to branch colin_support_read_tree
in repository https://gitbox.apache.org/repos/asf/tsfile.git

commit 95830e53bc8f42d1990a289108a61c39db143592
Author: ColinLee <[email protected]>
AuthorDate: Sun Nov 30 11:09:29 2025 +0800

    fix some bugs.
---
 python/tests/test_write_and_read.py | 141 ++++++++++++++++++++++++++++++++++++
 python/tsfile/constants.py          |   4 +-
 python/tsfile/utils.py              |  29 +++++---
 3 files changed, 161 insertions(+), 13 deletions(-)

diff --git a/python/tests/test_write_and_read.py b/python/tests/test_write_and_read.py
index da6cc5c9..cf163c1f 100644
--- a/python/tests/test_write_and_read.py
+++ b/python/tests/test_write_and_read.py
@@ -19,6 +19,7 @@
 import os
 
 import numpy as np
+import pandas as pd
 import pytest
 from tsfile import ColumnSchema, TableSchema, TSEncoding
 from tsfile import Compressor
@@ -70,6 +71,146 @@ def test_row_record_write_and_read():
             os.remove("record_write_and_read.tsfile")
 
 
+def test_tree_query_to_dataframe_variants():
+    file_path = "tree_query_to_dataframe.tsfile"
+    device_ids = [
+        "root.db1.t1",
+        "root.db2.t1",
+        "root.db3.t2.t3",
+        "root.db3.t3",
+        "device",
+        "device.ln",
+        "device2.ln1.tmp",
+        "device3.ln2.tmp.v1.v2",
+        "device3.ln2.tmp.v1.v3",
+    ]
+    measurement_ids1 = ["temperature", "hudi", "level"]
+    measurement_ids2 = ["level", "vol"]
+    rows_per_device = 2
+    expected_values = {}
+    all_measurements = set()
+
+    def _is_null(value):
+        return value is None or pd.isna(value)
+
+    def _extract_device(row, path_columns):
+        parts = []
+        for col in path_columns:
+            value = row[col]
+            if not _is_null(value):
+                parts.append(str(value))
+        return ".".join(parts)
+
+    try:
+        writer = TsFileWriter(file_path)
+        for idx, device_id in enumerate(device_ids):
+            measurements = measurement_ids1 if idx % 2 == 0 else measurement_ids2
+            all_measurements.update(measurements)
+            for measurement in measurements:
+                writer.register_timeseries(
+                    device_id, TimeseriesSchema(measurement, TSDataType.INT32)
+                )
+            for ts in range(rows_per_device):
+                fields = []
+                measurement_snapshot = {}
+                for m_idx, measurement in enumerate(measurements):
+                    value = idx * 100 + ts * 10 + m_idx
+                    fields.append(Field(measurement, value, TSDataType.INT32))
+                    measurement_snapshot[measurement] = value
+                writer.write_row_record(RowRecord(device_id, ts, fields))
+                expected_values[(device_id, ts)] = measurement_snapshot
+        writer.close()
+
+        df_all = to_dataframe(file_path, start_time=0, end_time=rows_per_device)
+        print(df_all)
+        total_rows = len(device_ids) * rows_per_device
+        assert df_all.shape[0] == total_rows
+        for measurement in all_measurements:
+            assert measurement in df_all.columns
+        assert "time" in df_all.columns
+        path_columns = sorted(
+            [col for col in df_all.columns if col.startswith("col_")],
+            key=lambda name: int(name.split("_")[1]),
+        )
+        assert len(path_columns) > 0
+
+        for _, row in df_all.iterrows():
+            device = _extract_device(row, path_columns)
+            timestamp = int(row["time"])
+            assert (device, timestamp) in expected_values
+            expected_row = expected_values[(device, timestamp)]
+            for measurement in all_measurements:
+                value = row.get(measurement)
+                if measurement in expected_row:
+                    assert value == expected_row[measurement]
+                else:
+                    assert _is_null(value)
+
+        requested_columns = ["level", "temperature"]
+        df_subset = to_dataframe(
+            file_path, column_names=requested_columns, start_time=0, end_time=rows_per_device
+        )
+        for column in requested_columns:
+            assert column in df_subset.columns
+        for measurement in all_measurements:
+            if measurement not in requested_columns:
+                assert measurement not in df_subset.columns
+        for _, row in df_subset.iterrows():
+            device = _extract_device(row, path_columns)
+            timestamp = int(row["time"])
+            expected_row = expected_values[(device, timestamp)]
+            for measurement in requested_columns:
+                value = row.get(measurement)
+                if measurement in expected_row:
+                    assert value == expected_row[measurement]
+                else:
+                    assert _is_null(value)
+
+        df_limited = to_dataframe(
+            file_path, column_names=["level"], max_row_num=5, start_time=0, end_time=rows_per_device
+        )
+        assert df_limited.shape[0] == 5
+        assert "level" in df_limited.columns
+
+        iterator = to_dataframe(
+            file_path,
+            column_names=["level", "temperature"],
+            max_row_num=3,
+            start_time=0,
+            end_time=rows_per_device,
+            as_iterator=True,
+        )
+        iter_rows = 0
+        for batch in iterator:
+            assert isinstance(batch, pd.DataFrame)
+            assert set(batch.columns).issuperset({"time", "level"})
+            iter_rows += len(batch)
+            print(batch)
+        assert iter_rows == 18
+
+        iterator = to_dataframe(
+            file_path,
+            column_names=["level", "temperature"],
+            max_row_num=3,
+            start_time=0,
+            end_time=0,
+            as_iterator=True,
+        )
+        iter_rows = 0
+        for batch in iterator:
+            assert isinstance(batch, pd.DataFrame)
+            assert set(batch.columns).issuperset({"time", "level"})
+            iter_rows += len(batch)
+            print(batch)
+        assert iter_rows == 9
+
+        with pytest.raises(ColumnNotExistError):
+            to_dataframe(file_path, column_names=["level", "not_exists"])
+    finally:
+        if os.path.exists(file_path):
+            os.remove(file_path)
+
+
 @pytest.mark.skip(reason="API not match")
 def test_tablet_write_and_read():
     try:
diff --git a/python/tsfile/constants.py b/python/tsfile/constants.py
index 5eaa2470..72ac434b 100644
--- a/python/tsfile/constants.py
+++ b/python/tsfile/constants.py
@@ -53,9 +53,9 @@ class TSDataType(IntEnum):
         if self == TSDataType.BOOLEAN:
             return "bool"
         elif self == TSDataType.INT32:
-            return "int32"
+            return "Int32"
         elif self == TSDataType.INT64:
-            return "int64"
+            return "Int64"
         elif self == TSDataType.FLOAT:
             return "float32"
         elif self == TSDataType.DOUBLE:
diff --git a/python/tsfile/utils.py b/python/tsfile/utils.py
index e2baed5f..f7bce9bc 100644
--- a/python/tsfile/utils.py
+++ b/python/tsfile/utils.py
@@ -31,7 +31,7 @@ def to_dataframe(file_path: str,
                  max_row_num: int | None = None,
                  as_iterator: bool = False) -> Union[pd.DataFrame, Iterator[pd.DataFrame]]:
 
-    def _gen() -> Iterator[pd.DataFrame]:
+    def _gen(is_iterator: bool) -> Iterator[pd.DataFrame]:
         _table_name = table_name
         _column_names = column_names
+        _start_time = start_time if start_time is not None else np.iinfo(np.int64).min
@@ -72,23 +72,30 @@ def to_dataframe(file_path: str,
 
             with query_result as result:
                 while result.next():
-                    if max_row_num is not None:
+                    if max_row_num is None:
+                        df = result.read_data_frame()
+                    elif is_iterator:
+                        df = result.read_data_frame(max_row_num)
+                    else:
                         remaining_rows = max_row_num - total_rows
                         if remaining_rows <= 0:
                             break
-                        else:
-                            batch_rows = min(remaining_rows, 1024)
-                        df = result.read_data_frame(batch_rows)
-                        total_rows += len(df)
-                    else:
-                        df = result.read_data_frame()
+                        df = result.read_data_frame(remaining_rows)
+                    if df is None or df.empty:
+                        continue
+                    total_rows += len(df)
                     yield df
+                    if (not is_iterator) and max_row_num is not None and total_rows >= max_row_num:
+                        break
 
     if as_iterator:
-        return _gen()
+        return _gen(True)
     else:
-        df_list = list(_gen())
+        df_list = list(_gen(False))
         if df_list:
-            return pd.concat(df_list, ignore_index=True)
+            df = pd.concat(df_list, ignore_index=True)
+            if max_row_num is not None and len(df) > max_row_num:
+                df = df.iloc[:max_row_num]
+            return df
         else:
             return pd.DataFrame()
\ No newline at end of file

Reply via email to