This is an automated email from the ASF dual-hosted git repository.

guanmingchiu pushed a commit to branch dev-qdp
in repository https://gitbox.apache.org/repos/asf/mahout.git


The following commit(s) were added to refs/heads/dev-qdp by this push:
     new 97932d078 [QDP] Choose between different frameworks in benchmark (#741)
97932d078 is described below

commit 97932d0781b1db4abdd4c272427a9e118a9e9fc0
Author: Ping <[email protected]>
AuthorDate: Thu Dec 18 14:35:05 2025 +0800

    [QDP] Choose between different frameworks in benchmark (#741)
    
    Signed-off-by: 400Ping <[email protected]>
---
 docs/benchmarks/dataloader_throughput.md         |  4 ++
 qdp/benchmark/benchmark_dataloader_throughput.py | 71 +++++++++++++++++++-----
 2 files changed, 60 insertions(+), 15 deletions(-)

diff --git a/docs/benchmarks/dataloader_throughput.md b/docs/benchmarks/dataloader_throughput.md
index aa6a9901c..242340c26 100644
--- a/docs/benchmarks/dataloader_throughput.md
+++ b/docs/benchmarks/dataloader_throughput.md
@@ -17,6 +17,9 @@ cargo run -p qdp-core --example dataloader_throughput --release
 
 # Cross-framework comparison (requires deps in qdp/benchmark/requirements.txt)
 python qdp/benchmark/benchmark_dataloader_throughput.py --qubits 16 --batches 200 --batch-size 64 --prefetch 16
+
+# Run only Mahout + PennyLane legs
+python qdp/benchmark/benchmark_dataloader_throughput.py --frameworks mahout,pennylane
 ```
 
 ## Example Output
@@ -67,3 +70,4 @@ Speedup vs Qiskit:          8.44x
   - `--batches`: number of host-side batches to stream.
   - `--batch-size`: vectors per batch; raises total samples (`batches * batch-size`).
   - `--prefetch`: CPU queue depth; higher values help hide slow CPU-side prep (e.g., Qiskit state prep) and keep GPU fed.
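+  - `--frameworks`: comma-separated list of legs to execute (`pennylane,qiskit,mahout`) or `all` (the default). Selected legs always run in the order PennyLane, Qiskit, Mahout, regardless of the order given.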
diff --git a/qdp/benchmark/benchmark_dataloader_throughput.py b/qdp/benchmark/benchmark_dataloader_throughput.py
index bc8a39291..9ce974084 100644
--- a/qdp/benchmark/benchmark_dataloader_throughput.py
+++ b/qdp/benchmark/benchmark_dataloader_throughput.py
@@ -39,6 +39,7 @@ from mahout_qdp import QdpEngine
 
 BAR = "=" * 70
 SEP = "-" * 70
+FRAMEWORK_CHOICES = ("pennylane", "qiskit", "mahout")
 
 try:
     import pennylane as qml
@@ -91,6 +92,26 @@ def normalize_batch(batch: np.ndarray) -> np.ndarray:
     return batch / norms
 
 
+def parse_frameworks(raw: str) -> list[str]:
+    if raw.lower() == "all":
+        return list(FRAMEWORK_CHOICES)
+
+    selected: list[str] = []
+    for part in raw.split(","):
+        name = part.strip().lower()
+        if not name:
+            continue
+        if name not in FRAMEWORK_CHOICES:
+            raise ValueError(
+                f"Unknown framework '{name}'. Choose from: "
+                f"{', '.join(FRAMEWORK_CHOICES)} or 'all'."
+            )
+        if name not in selected:
+            selected.append(name)
+
+    return selected if selected else list(FRAMEWORK_CHOICES)
+
+
 def run_mahout(num_qubits: int, total_batches: int, batch_size: int, prefetch: 
int):
     try:
         engine = QdpEngine(0)
@@ -211,8 +232,22 @@ def main():
     parser.add_argument(
         "--prefetch", type=int, default=16, help="CPU-side prefetch depth."
     )
+    parser.add_argument(
+        "--frameworks",
+        type=str,
+        default="all",
+        help=(
+            "Comma-separated list of frameworks to run "
+            "(pennylane,qiskit,mahout) or 'all'."
+        ),
+    )
     args = parser.parse_args()
 
+    try:
+        frameworks = parse_frameworks(args.frameworks)
+    except ValueError as exc:
+        parser.error(str(exc))
+
     total_vectors = args.batches * args.batch_size
     vector_len = 1 << args.qubits
 
@@ -221,6 +256,7 @@ def main():
     print(f"  Vector length: {vector_len}")
     print(f"  Batches      : {args.batches}")
     print(f"  Prefetch     : {args.prefetch}")
+    print(f"  Frameworks   : {', '.join(frameworks)}")
     bytes_per_vec = vector_len * 8
     print(f"  Generated {total_vectors} samples")
     print(
@@ -235,23 +271,28 @@ def main():
     )
     print(BAR)
 
-    print()
-    print("[PennyLane] Full Pipeline (DataLoader -> GPU)...")
-    t_pl, th_pl = run_pennylane(
-        args.qubits, args.batches, args.batch_size, args.prefetch
-    )
+    t_pl = th_pl = t_qiskit = th_qiskit = t_mahout = th_mahout = 0.0
 
-    print()
-    print("[Qiskit] Full Pipeline (DataLoader -> GPU)...")
-    t_qiskit, th_qiskit = run_qiskit(
-        args.qubits, args.batches, args.batch_size, args.prefetch
-    )
+    if "pennylane" in frameworks:
+        print()
+        print("[PennyLane] Full Pipeline (DataLoader -> GPU)...")
+        t_pl, th_pl = run_pennylane(
+            args.qubits, args.batches, args.batch_size, args.prefetch
+        )
 
-    print()
-    print("[Mahout] Full Pipeline (DataLoader -> GPU)...")
-    t_mahout, th_mahout = run_mahout(
-        args.qubits, args.batches, args.batch_size, args.prefetch
-    )
+    if "qiskit" in frameworks:
+        print()
+        print("[Qiskit] Full Pipeline (DataLoader -> GPU)...")
+        t_qiskit, th_qiskit = run_qiskit(
+            args.qubits, args.batches, args.batch_size, args.prefetch
+        )
+
+    if "mahout" in frameworks:
+        print()
+        print("[Mahout] Full Pipeline (DataLoader -> GPU)...")
+        t_mahout, th_mahout = run_mahout(
+            args.qubits, args.batches, args.batch_size, args.prefetch
+        )
 
     print()
     print(BAR)

Reply via email to