This is an automated email from the ASF dual-hosted git repository.

guanmingchiu pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/mahout.git
commit b8647b42dc6f5b35284f5060a7936d1e3bf72aa2 Author: rich7420 <[email protected]> AuthorDate: Fri Jan 2 08:19:44 2026 +0800 [QDP] Add TensorFlow tensor input support --- qdp/Cargo.lock | 368 +++++++++++++++++++---- qdp/Cargo.toml | 6 + qdp/qdp-core/Cargo.toml | 6 + qdp/qdp-core/{src/readers/mod.rs => build.rs} | 32 +- qdp/qdp-core/proto/tensor.proto | 32 ++ qdp/qdp-core/src/io.rs | 19 ++ qdp/qdp-core/src/lib.rs | 35 +++ qdp/qdp-core/src/readers/mod.rs | 3 + qdp/qdp-core/src/readers/tensorflow.rs | 268 +++++++++++++++++ qdp/qdp-core/src/{readers/mod.rs => tf_proto.rs} | 19 +- qdp/qdp-core/tests/tensorflow_io.rs | 354 ++++++++++++++++++++++ qdp/qdp-python/pyproject.toml | 3 +- qdp/qdp-python/src/lib.rs | 32 ++ 13 files changed, 1090 insertions(+), 87 deletions(-) diff --git a/qdp/Cargo.lock b/qdp/Cargo.lock index 9e902660e..d316707e3 100644 --- a/qdp/Cargo.lock +++ b/qdp/Cargo.lock @@ -55,6 +55,12 @@ dependencies = [ "libc", ] +[[package]] +name = "anyhow" +version = "1.0.100" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a23eb6b1614318a8071c9b2521f36b424b2c83db5eb3a0fead4a6c0809af6e61" + [[package]] name = "arbitrary" version = "1.4.2" @@ -299,6 +305,12 @@ version = "1.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" +[[package]] +name = "bitflags" +version = "2.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "812e12b5285cc515a9c72a5c1d3b6d46a19dac5acfef5265968c166106e31dd3" + [[package]] name = "block-buffer" version = "0.10.4" @@ -331,9 +343,9 @@ dependencies = [ [[package]] name = "bumpalo" -version = "3.19.0" +version = "3.19.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "46c5e41b57b8bba42a04676d81cb89e9ee8e859a1a66f80a5a72e1cb76b34d43" +checksum = "5dd9dc738b7a8311c7ade152424974d8115f2cdad61e8dab8dac9f2362298510" [[package]] name = "byteorder" @@ -349,9 +361,9 @@ checksum = "b35204fbdc0b3f4446b89fc1ac2cf84a8a68971995d0bf2e925ec7cd960f9cb3" [[package]] name = "cc" -version = "1.2.48" +version = "1.2.52" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c481bdbf0ed3b892f6f806287d72acd515b352a4ec27a208489b8c1bc839633a" +checksum = "cd4932aefd12402b36c60956a4fe0035421f544799057659ff86f923657aada3" dependencies = [ "find-msvc-tools", "jobserver", @@ -391,7 +403,7 @@ version = "0.1.16" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f9d839f2a20b0aee515dc581a6172f2321f96cab76c1a38a4c584a194955390e" dependencies = [ - "getrandom 0.2.16", + "getrandom 0.2.17", "once_cell", "tiny-keccak", ] @@ -535,11 +547,33 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" +[[package]] +name = "errno" +version = "0.3.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb" +dependencies = [ + "libc", + "windows-sys", +] + +[[package]] +name = "fastrand" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" + [[package]] name = "find-msvc-tools" -version = "0.1.5" +version = "0.1.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3a3076410a55c90011c298b04d0cfa770b00fa04e1e3c97d3f6c9de105a03844" +checksum = 
"f449e6c6c08c865631d4890cfacf252b3d396c9bcc83adb6623cdb02a8336c41" + +[[package]] +name = "fixedbitset" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ce7134b9999ecaf8bcd65542e436736ef32ddca1b3e06094cb6ec5755203b80" [[package]] name = "flatbuffers" @@ -547,7 +581,7 @@ version = "24.12.23" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4f1baf0dbf96932ec9a3038d57900329c015b0bfb7b63d904f3bc27e2b02a096" dependencies = [ - "bitflags", + "bitflags 1.3.2", "rustc_version", ] @@ -573,9 +607,9 @@ dependencies = [ [[package]] name = "getrandom" -version = "0.2.16" +version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "335ff9f135e4384c8150d6f27c6daed433577f86b4750418338c01a1a2528592" +checksum = "ff2abc00be7fca6ebc474524697ae276ad847ad0a6b3faa4bcb027e9a4614ad0" dependencies = [ "cfg-if", "libc", @@ -650,9 +684,9 @@ dependencies = [ [[package]] name = "indexmap" -version = "2.12.1" +version = "2.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0ad4bb2b565bca0645f4d68c5c9af97fba094e9791da685bf83cb5f3ce74acf2" +checksum = "7714e70437a7dc3ac8eb7e6f8df75fd8eb422675fc7678aff7364301092b1017" dependencies = [ "equivalent", "hashbrown 0.16.1", @@ -673,11 +707,20 @@ version = "3.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8bb03732005da905c88227371639bf1ad885cc712789c011c31c5fb3ab3ccf02" +[[package]] +name = "itertools" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba291022dbbd398a455acf126c1e341954079855bc60dfdda641363bd6922569" +dependencies = [ + "either", +] + [[package]] name = "itoa" -version = "1.0.15" +version = "1.0.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4a5f13b858c8d314ee3e8f639011f7ccefe71f97f96e50151fb991f267928e2c" +checksum = "92ecc6618181def0457392ccd0ee51198e065e016d1d527a7ac1b6dc7c1f09d2" [[package]] name = "jobserver" @@ -764,9 +807,9 @@ dependencies = [ [[package]] name = "libc" -version = "0.2.177" +version = "0.2.180" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2874a2af47a2325c2001a6e6fad9b16a53b802102b528163885171cf92b15976" +checksum = "bcc35a38544a891a5f7c865aca548a982ccb3b8650a5b06d0fd33a10283c56fc" [[package]] name = "libloading" @@ -784,6 +827,12 @@ version = "0.2.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f9fbbcab51052fe104eb5e5d351cf728d30a5be1fe14d9be8a3b097481fb97de" +[[package]] +name = "linux-raw-sys" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df1d3c3b53da64cf5760482273a98e575c651a67eec7f77df96b5b642de8f039" + [[package]] name = "log" version = "0.4.29" @@ -834,6 +883,12 @@ dependencies = [ "simd-adler32", ] +[[package]] +name = "multimap" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d87ecb2933e8aeadb3e3a02b828fed80a7528047e68b4f424523a0981a3a084" + [[package]] name = "ndarray" version = "0.16.1" @@ -849,6 +904,21 @@ dependencies = [ "rawpointer", ] +[[package]] +name = "ndarray" +version = "0.17.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "520080814a7a6b4a6e9070823bb24b4531daac8c4627e08ba5de8c5ef2f2752d" +dependencies = [ + "matrixmultiply", + "num-complex", + "num-integer", + "num-traits", + "portable-atomic", + "portable-atomic-util", + "rawpointer", +] + [[package]] name 
= "ndarray-npy" version = "0.9.1" @@ -856,7 +926,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7b313788c468c49141a9d9b6131fc15f403e6ef4e8446a0b2e18f664ddb278a9" dependencies = [ "byteorder", - "ndarray", + "ndarray 0.16.1", "num-complex", "num-traits", "py_literal", @@ -944,7 +1014,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7aac2e6a6e4468ffa092ad43c39b81c79196c2bb773b8db4085f695efe3bba17" dependencies = [ "libc", - "ndarray", + "ndarray 0.17.2", "num-complex", "num-integer", "num-traits", @@ -1018,9 +1088,9 @@ checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" [[package]] name = "pest" -version = "2.8.4" +version = "2.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cbcfd20a6d4eeba40179f05735784ad32bdaef05ce8e8af05f180d45bb3e7e22" +checksum = "2c9eb05c21a464ea704b53158d358a31e6425db2f63a1a7312268b05fe2b75f7" dependencies = [ "memchr", "ucd-trie", @@ -1028,9 +1098,9 @@ dependencies = [ [[package]] name = "pest_derive" -version = "2.8.4" +version = "2.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "51f72981ade67b1ca6adc26ec221be9f463f2b5839c7508998daa17c23d94d7f" +checksum = "68f9dbced329c441fa79d80472764b1a2c7e57123553b8519b36663a2fb234ed" dependencies = [ "pest", "pest_generator", @@ -1038,9 +1108,9 @@ dependencies = [ [[package]] name = "pest_generator" -version = "2.8.4" +version = "2.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dee9efd8cdb50d719a80088b76f81aec7c41ed6d522ee750178f83883d271625" +checksum = "3bb96d5051a78f44f43c8f712d8e810adb0ebf923fc9ed2655a7f66f63ba8ee5" dependencies = [ "pest", "pest_meta", @@ -1051,14 +1121,24 @@ dependencies = [ [[package]] name = "pest_meta" -version = "2.8.4" +version = "2.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bf1d70880e76bdc13ba52eafa6239ce793d85c8e43896507e43dd8984ff05b82" +checksum = "602113b5b5e8621770cfd490cfd90b9f84ab29bd2b0e49ad83eb6d186cef2365" dependencies = [ "pest", "sha2", ] +[[package]] +name = "petgraph" +version = "0.6.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b4c5cc86750666a3ed20bdaf5ca2a0344f9c67674cae0515bec2da16fbaa47db" +dependencies = [ + "fixedbitset", + "indexmap", +] + [[package]] name = "pkg-config" version = "0.3.32" @@ -1067,9 +1147,9 @@ checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c" [[package]] name = "portable-atomic" -version = "1.11.1" +version = "1.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f84267b20a16ea918e43c6a88433c2d54fa145c92a811b5b047ccbe153674483" +checksum = "f89776e4d69bb58bc6993e99ffa1d11f228b839984854c7daeb5d37f87cbe950" [[package]] name = "portable-atomic-util" @@ -1080,15 +1160,142 @@ dependencies = [ "portable-atomic", ] +[[package]] +name = "prettyplease" +version = "0.2.37" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "479ca8adacdd7ce8f1fb39ce9ecccbfe93a3f1344b3d0d97f20bc0196208f62b" +dependencies = [ + "proc-macro2", + "syn", +] + [[package]] name = "proc-macro2" -version = "1.0.103" +version = "1.0.105" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5ee95bc4ef87b8d5ba32e8b7714ccc834865276eab0aed5c9958d00ec45f49e8" +checksum = "535d180e0ecab6268a3e718bb9fd44db66bbbc256257165fc699dadf70d16fe7" dependencies = [ "unicode-ident", ] +[[package]] +name = "prost" +version = "0.12.6" 
+source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "deb1435c188b76130da55f17a466d252ff7b1418b2ad3e037d127b94e3411f29" +dependencies = [ + "bytes", + "prost-derive", +] + +[[package]] +name = "prost-build" +version = "0.12.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22505a5c94da8e3b7c2996394d1c933236c4d743e81a410bcca4e6989fc066a4" +dependencies = [ + "bytes", + "heck", + "itertools", + "log", + "multimap", + "once_cell", + "petgraph", + "prettyplease", + "prost", + "prost-types", + "regex", + "syn", + "tempfile", +] + +[[package]] +name = "prost-derive" +version = "0.12.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "81bddcdb20abf9501610992b6759a4c888aef7d1a7247ef75e2404275ac24af1" +dependencies = [ + "anyhow", + "itertools", + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "prost-types" +version = "0.12.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9091c90b0a32608e984ff2fa4091273cbdd755d54935c51d520887f4a1dbd5b0" +dependencies = [ + "prost", +] + +[[package]] +name = "protoc-bin-vendored" +version = "3.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d1c381df33c98266b5f08186583660090a4ffa0889e76c7e9a5e175f645a67fa" +dependencies = [ + "protoc-bin-vendored-linux-aarch_64", + "protoc-bin-vendored-linux-ppcle_64", + "protoc-bin-vendored-linux-s390_64", + "protoc-bin-vendored-linux-x86_32", + "protoc-bin-vendored-linux-x86_64", + "protoc-bin-vendored-macos-aarch_64", + "protoc-bin-vendored-macos-x86_64", + "protoc-bin-vendored-win32", +] + +[[package]] +name = "protoc-bin-vendored-linux-aarch_64" +version = "3.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c350df4d49b5b9e3ca79f7e646fde2377b199e13cfa87320308397e1f37e1a4c" + +[[package]] +name = "protoc-bin-vendored-linux-ppcle_64" +version = "3.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a55a63e6c7244f19b5c6393f025017eb5d793fd5467823a099740a7a4222440c" + +[[package]] +name = "protoc-bin-vendored-linux-s390_64" +version = "3.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1dba5565db4288e935d5330a07c264a4ee8e4a5b4a4e6f4e83fad824cc32f3b0" + +[[package]] +name = "protoc-bin-vendored-linux-x86_32" +version = "3.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8854774b24ee28b7868cd71dccaae8e02a2365e67a4a87a6cd11ee6cdbdf9cf5" + +[[package]] +name = "protoc-bin-vendored-linux-x86_64" +version = "3.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b38b07546580df720fa464ce124c4b03630a6fb83e05c336fea2a241df7e5d78" + +[[package]] +name = "protoc-bin-vendored-macos-aarch_64" +version = "3.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "89278a9926ce312e51f1d999fee8825d324d603213344a9a706daa009f1d8092" + +[[package]] +name = "protoc-bin-vendored-macos-x86_64" +version = "3.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "81745feda7ccfb9471d7a4de888f0652e806d5795b61480605d4943176299756" + +[[package]] +name = "protoc-bin-vendored-win32" +version = "3.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "95067976aca6421a523e491fce939a3e65249bac4b977adee0ee9771568e8aa3" + [[package]] name = "py_literal" version = "0.4.0" @@ -1104,9 +1311,9 @@ dependencies = [ [[package]] name = "pyo3" -version = "0.27.1" 
+version = "0.27.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "37a6df7eab65fc7bee654a421404947e10a0f7085b6951bf2ea395f4659fb0cf" +checksum = "ab53c047fcd1a1d2a8820fe84f05d6be69e9526be40cb03b73f86b6b03e6d87d" dependencies = [ "indoc", "libc", @@ -1121,18 +1328,18 @@ dependencies = [ [[package]] name = "pyo3-build-config" -version = "0.27.1" +version = "0.27.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f77d387774f6f6eec64a004eac0ed525aab7fa1966d94b42f743797b3e395afb" +checksum = "b455933107de8642b4487ed26d912c2d899dec6114884214a0b3bb3be9261ea6" dependencies = [ "target-lexicon", ] [[package]] name = "pyo3-ffi" -version = "0.27.1" +version = "0.27.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2dd13844a4242793e02df3e2ec093f540d948299a6a77ea9ce7afd8623f542be" +checksum = "1c85c9cbfaddf651b1221594209aed57e9e5cff63c4d11d1feead529b872a089" dependencies = [ "libc", "pyo3-build-config", @@ -1140,9 +1347,9 @@ dependencies = [ [[package]] name = "pyo3-macros" -version = "0.27.1" +version = "0.27.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eaf8f9f1108270b90d3676b8679586385430e5c0bb78bb5f043f95499c821a71" +checksum = "0a5b10c9bf9888125d917fb4d2ca2d25c8df94c7ab5a52e13313a07e050a3b02" dependencies = [ "proc-macro2", "pyo3-macros-backend", @@ -1152,9 +1359,9 @@ dependencies = [ [[package]] name = "pyo3-macros-backend" -version = "0.27.1" +version = "0.27.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "70a3b2274450ba5288bc9b8c1b69ff569d1d61189d4bff38f8d22e03d17f932b" +checksum = "03b51720d314836e53327f5871d4c0cfb4fb37cc2c4a11cc71907a86342c40f9" dependencies = [ "heck", "proc-macro2", @@ -1168,11 +1375,15 @@ name = "qdp-core" version = "0.1.0" dependencies = [ "arrow", + "bytes", "cudarc", - "ndarray", + "ndarray 0.16.1", "ndarray-npy", "nvtx", "parquet", + "prost", + "prost-build", + "protoc-bin-vendored", "qdp-kernels", "rayon", "thiserror", @@ -1197,9 +1408,9 @@ dependencies = [ [[package]] name = "quote" -version = "1.0.42" +version = "1.0.43" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a338cc41d27e6cc6dce6cefc13a0729dfbb81c262b1f519331575dd80ef3067f" +checksum = "dc74d9a594b72ae6656596548f56f667211f8a97b3d4c3d467150794690dc40a" dependencies = [ "proc-macro2", ] @@ -1280,6 +1491,19 @@ dependencies = [ "semver", ] +[[package]] +name = "rustix" +version = "1.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "146c9e247ccc180c1f61615433868c99f3de3ae256a30a43b49f67c2d9171f34" +dependencies = [ + "bitflags 2.10.0", + "errno", + "libc", + "linux-raw-sys", + "windows-sys", +] + [[package]] name = "rustversion" version = "1.0.22" @@ -1288,9 +1512,9 @@ checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d" [[package]] name = "ryu" -version = "1.0.20" +version = "1.0.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "28d3b2b1366ec20994f1fd18c3c594f05c5dd4bc44d8bb0c1c632c8d6829481f" +checksum = "a50f4cf475b65d88e057964e0e9bb1f0aa9bbb2036dc65c64596b42932536984" [[package]] name = "semver" @@ -1335,15 +1559,15 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.145" +version = "1.0.149" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "402a6f66d8c709116cf22f558eab210f5a50187f702eb4d7e5ef38d9a7f1c79c" +checksum = "83fc039473c5595ace860d8c4fafa220ff474b3fc6bfdb4293327f1a37e94d86" 
dependencies = [ "itoa", "memchr", - "ryu", "serde", "serde_core", + "zmij", ] [[package]] @@ -1365,9 +1589,9 @@ checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" [[package]] name = "simd-adler32" -version = "0.3.7" +version = "0.3.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d66dc143e6b11c1eddc06d5c423cfc97062865baf299914ab64caa38182078fe" +checksum = "e320a6c5ad31d271ad523dcf3ad13e2767ad8b1cb8f047f75a8aeaf8da139da2" [[package]] name = "simdutf8" @@ -1389,9 +1613,9 @@ checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" [[package]] name = "syn" -version = "2.0.111" +version = "2.0.114" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "390cc9a294ab71bdb1aa2e99d13be9c753cd2d7bd6560c77118597410c4d2e87" +checksum = "d4d107df263a3013ef9b1879b0df87d706ff80f65a86ea879bd9c31f9b307c2a" dependencies = [ "proc-macro2", "quote", @@ -1400,9 +1624,22 @@ dependencies = [ [[package]] name = "target-lexicon" -version = "0.13.3" +version = "0.13.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1dd07eb858a2067e2f3c7155d54e929265c264e6f37efe3ee7a8d1b5a1dd0ba" + +[[package]] +name = "tempfile" +version = "3.24.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "df7f62577c25e07834649fc3b39fafdc597c0a3527dc1c60129201ccfcbaa50c" +checksum = "655da9c7eb6305c55742045d5a8d2037996d61d8de95806335c7c86ce0f82e9c" +dependencies = [ + "fastrand", + "getrandom 0.3.4", + "once_cell", + "rustix", + "windows-sys", +] [[package]] name = "thiserror" @@ -1609,6 +1846,15 @@ dependencies = [ "windows-link", ] +[[package]] +name = "windows-sys" +version = "0.61.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae137229bcbd6cdf0f7b80a31df61766145077ddf49416a728b02cb3921ff3fc" +dependencies = [ + "windows-link", +] + [[package]] name = "wit-bindgen" version = "0.46.0" @@ -1617,18 +1863,18 @@ checksum = "f17a85883d4e6d00e8a97c586de764dabcc06133f7f1d55dce5cdc070ad7fe59" [[package]] name = "zerocopy" -version = "0.8.31" +version = "0.8.33" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fd74ec98b9250adb3ca554bdde269adf631549f51d8a8f8f0a10b50f1cb298c3" +checksum = "668f5168d10b9ee831de31933dc111a459c97ec93225beb307aed970d1372dfd" dependencies = [ "zerocopy-derive", ] [[package]] name = "zerocopy-derive" -version = "0.8.31" +version = "0.8.33" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d8a8d209fdf45cf5138cbb5a506f6b52522a25afccc534d1475dad8e31105c6a" +checksum = "2c7962b26b0a8685668b671ee4b54d007a67d4eaf05fda79ac0ecf41e32270f1" dependencies = [ "proc-macro2", "quote", @@ -1652,6 +1898,12 @@ dependencies = [ "zopfli", ] +[[package]] +name = "zmij" +version = "1.0.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac93432f5b761b22864c774aac244fa5c0fd877678a4c37ebf6cf42208f9c9ec" + [[package]] name = "zopfli" version = "0.8.3" diff --git a/qdp/Cargo.toml b/qdp/Cargo.toml index 7f98ac5a4..7c9571833 100644 --- a/qdp/Cargo.toml +++ b/qdp/Cargo.toml @@ -31,6 +31,12 @@ parquet = "54" # NumPy file format support ndarray = "0.16" ndarray-npy = "0.9" +# Protocol Buffer support for TensorFlow TensorProto +prost = "0.12" +prost-build = "0.12" +bytes = "1.5" +# Optional: vendored protoc to avoid build failures when protoc is missing +protoc-bin-vendored = "3" # Release profile optimizations [profile.release] diff --git a/qdp/qdp-core/Cargo.toml 
b/qdp/qdp-core/Cargo.toml index fe0ae647c..c4c27533c 100644 --- a/qdp/qdp-core/Cargo.toml +++ b/qdp/qdp-core/Cargo.toml @@ -13,6 +13,12 @@ arrow = { workspace = true } parquet = { workspace = true } ndarray = { workspace = true } ndarray-npy = { workspace = true } +prost = { workspace = true } +bytes = { workspace = true } + +[build-dependencies] +prost-build = { workspace = true } +protoc-bin-vendored = { workspace = true } [lib] name = "qdp_core" diff --git a/qdp/qdp-core/src/readers/mod.rs b/qdp/qdp-core/build.rs similarity index 54% copy from qdp/qdp-core/src/readers/mod.rs copy to qdp/qdp-core/build.rs index 4ca199e37..311ea139d 100644 --- a/qdp/qdp-core/src/readers/mod.rs +++ b/qdp/qdp-core/build.rs @@ -14,19 +14,23 @@ // See the License for the specific language governing permissions and // limitations under the License. -//! Format-specific data reader implementations. -//! -//! This module contains concrete implementations of the [`DataReader`] and -//! [`StreamingDataReader`] traits for various file formats. -//! -//! # Fully Implemented Formats -//! - **Parquet**: [`ParquetReader`], [`ParquetStreamingReader`] -//! - **Arrow IPC**: [`ArrowIPCReader`] +fn main() { + // Use vendored protoc to avoid missing protoc in CI/dev environments + unsafe { + std::env::set_var("PROTOC", protoc_bin_vendored::protoc_bin_path().unwrap()); + } -pub mod arrow_ipc; -pub mod numpy; -pub mod parquet; + let mut config = prost_build::Config::new(); -pub use arrow_ipc::ArrowIPCReader; -pub use numpy::NumpyReader; -pub use parquet::{ParquetReader, ParquetStreamingReader}; + // Generate tensor_content as bytes::Bytes (avoids copy during protobuf decode) + config.bytes([".tensorflow.TensorProto.tensor_content"]); + + // Generate fixed filename include file to avoid guessing output filename/module path + config.include_file("tensorflow_proto_mod.rs"); + + config + .compile_protos(&["proto/tensor.proto"], &["proto"]) + .unwrap(); + + println!("cargo:rerun-if-changed=proto/tensor.proto"); +} diff --git a/qdp/qdp-core/proto/tensor.proto b/qdp/qdp-core/proto/tensor.proto new file mode 100644 index 000000000..c727e403b --- /dev/null +++ b/qdp/qdp-core/proto/tensor.proto @@ -0,0 +1,32 @@ +syntax = "proto3"; + +package tensorflow; + +// TensorProto - only define necessary fields, field numbers match TensorFlow official +message TensorProto { + // Field 1: dtype (enum DataType in TF, but varint in wire format) + // DT_DOUBLE = 2 (see tensorflow/core/framework/types.proto) + int32 dtype = 1; + + // Field 2: tensor_shape + TensorShapeProto tensor_shape = 2; + + // Field 4: tensor_content (preferred for efficient parsing) + bytes tensor_content = 4; + + // Field 6: double_val (fallback, only used when tensor_content is empty) + repeated double double_val = 6 [packed = true]; +} + +message TensorShapeProto { + // Field 2: dim (field number matches official) + repeated Dim dim = 2; + // Field 3: unknown_rank (optional; helps with better error messages) + bool unknown_rank = 3; +} + +message Dim { + // Field 1: size + int64 size = 1; + // Skip name field (field number 2) to reduce parsing overhead +} diff --git a/qdp/qdp-core/src/io.rs b/qdp/qdp-core/src/io.rs index f3715f04a..4e3cbdd07 100644 --- a/qdp/qdp-core/src/io.rs +++ b/qdp/qdp-core/src/io.rs @@ -267,3 +267,22 @@ pub fn read_numpy_batch<P: AsRef<Path>>(path: P) -> Result<(Vec<f64>, usize, usi /// /// This is a type alias for backward compatibility. Use [`crate::readers::ParquetStreamingReader`] directly. 
pub type ParquetBlockReader = crate::readers::ParquetStreamingReader; + +/// Reads batch data from a TensorFlow TensorProto file. +/// +/// Supports Float64 tensors with shape [batch_size, feature_size] or [n]. +/// Prefers tensor_content for efficient parsing, but still requires one copy to Vec<f64>. +/// +/// # Byte Order +/// Assumes little-endian byte order (standard on x86_64). +/// +/// # Returns +/// Tuple of `(flattened_data, num_samples, sample_size)` +/// +/// # TODO +/// Add OOM protection for very large files +pub fn read_tensorflow_batch<P: AsRef<Path>>(path: P) -> Result<(Vec<f64>, usize, usize)> { + use crate::reader::DataReader; + let mut reader = crate::readers::TensorFlowReader::new(path)?; + reader.read_batch() +} diff --git a/qdp/qdp-core/src/lib.rs b/qdp/qdp-core/src/lib.rs index 8d117ce1b..e748f41ea 100644 --- a/qdp/qdp-core/src/lib.rs +++ b/qdp/qdp-core/src/lib.rs @@ -21,6 +21,7 @@ pub mod io; pub mod preprocessing; pub mod reader; pub mod readers; +pub mod tf_proto; #[macro_use] mod profiling; @@ -489,6 +490,40 @@ impl QdpEngine { encoding_method, ) } + + /// Load data from TensorFlow TensorProto file and encode into quantum state + /// + /// Supports Float64 tensors with shape [batch_size, feature_size] or [n]. + /// Uses efficient parsing with tensor_content when available. + /// + /// # Arguments + /// * `path` - Path to TensorProto file (.pb) + /// * `num_qubits` - Number of qubits + /// * `encoding_method` - Strategy: "amplitude", "angle", or "basis" + /// + /// # Returns + /// Single DLPack pointer containing all encoded states (shape: [num_samples, 2^num_qubits]) + pub fn encode_from_tensorflow( + &self, + path: &str, + num_qubits: usize, + encoding_method: &str, + ) -> Result<*mut DLManagedTensor> { + crate::profile_scope!("Mahout::EncodeFromTensorFlow"); + + let (batch_data, num_samples, sample_size) = { + crate::profile_scope!("IO::ReadTensorFlowBatch"); + crate::io::read_tensorflow_batch(path)? + }; + + self.encode_batch( + &batch_data, + num_samples, + sample_size, + num_qubits, + encoding_method, + ) + } } // Re-export key types for convenience diff --git a/qdp/qdp-core/src/readers/mod.rs b/qdp/qdp-core/src/readers/mod.rs index 4ca199e37..c3ffc6efe 100644 --- a/qdp/qdp-core/src/readers/mod.rs +++ b/qdp/qdp-core/src/readers/mod.rs @@ -22,11 +22,14 @@ //! # Fully Implemented Formats //! - **Parquet**: [`ParquetReader`], [`ParquetStreamingReader`] //! - **Arrow IPC**: [`ArrowIPCReader`] +//! - **TensorFlow TensorProto**: [`TensorFlowReader`] pub mod arrow_ipc; pub mod numpy; pub mod parquet; +pub mod tensorflow; pub use arrow_ipc::ArrowIPCReader; pub use numpy::NumpyReader; pub use parquet::{ParquetReader, ParquetStreamingReader}; +pub use tensorflow::TensorFlowReader; diff --git a/qdp/qdp-core/src/readers/tensorflow.rs b/qdp/qdp-core/src/readers/tensorflow.rs new file mode 100644 index 000000000..d2ed9e103 --- /dev/null +++ b/qdp/qdp-core/src/readers/tensorflow.rs @@ -0,0 +1,268 @@ +// +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! TensorFlow TensorProto format reader implementation. + +use crate::error::{MahoutError, Result}; +use crate::reader::DataReader; +use bytes::Bytes; +use prost::Message; +use std::fs::File; +use std::io::Read; +use std::path::Path; + +/// Reader for TensorFlow TensorProto files. +/// +/// Supports Float64 tensors with shape [batch_size, feature_size] or [n]. +/// Prefers tensor_content for efficient parsing, but still requires one copy to Vec<f64>. +/// +/// # Byte Order +/// This implementation assumes little-endian byte order, which is the standard +/// on x86_64 platforms. TensorFlow typically uses host byte order. +pub struct TensorFlowReader { + // Store either raw bytes or f64 values to avoid unnecessary conversions + payload: TensorPayload, + num_samples: usize, + sample_size: usize, + read: bool, +} + +enum TensorPayload { + Bytes(Bytes), + F64(Vec<f64>), +} + +impl TensorFlowReader { + /// Create a new TensorFlow reader from a file path. + pub fn new<P: AsRef<Path>>(path: P) -> Result<Self> { + // Read entire file into memory (single read to avoid multiple I/O operations) + let mut file = File::open(path.as_ref()) + .map_err(|e| MahoutError::Io(format!("Failed to open TensorFlow file: {}", e)))?; + + let mut buffer = Vec::new(); + file.read_to_end(&mut buffer) + .map_err(|e| MahoutError::Io(format!("Failed to read TensorFlow file: {}", e)))?; + + // Use Bytes for decode input; with build.rs config.bytes(...) 
this avoids copying tensor_content during decode + let buffer = Bytes::from(buffer); + + // Parse TensorProto + let mut tensor_proto = crate::tf_proto::tensorflow::TensorProto::decode(buffer) + .map_err(|e| MahoutError::Io(format!("Failed to parse TensorProto: {}", e)))?; + + // Validate dtype == DT_DOUBLE (2) + // Official TensorFlow: DT_DOUBLE = 2 (not 9) + const DT_DOUBLE: i32 = 2; + if tensor_proto.dtype != DT_DOUBLE { + return Err(MahoutError::InvalidInput(format!( + "Expected DT_DOUBLE (2), got {}", + tensor_proto.dtype + ))); + } + + // Parse shape + let shape = tensor_proto.tensor_shape.as_ref().ok_or_else(|| { + MahoutError::InvalidInput("TensorProto.tensor_shape is missing".into()) + })?; + let (num_samples, sample_size) = Self::parse_shape(shape)?; + + // Extract data (prefer tensor_content, fallback to double_val) + // Check for integer overflow + let expected_elems = num_samples.checked_mul(sample_size).ok_or_else(|| { + MahoutError::InvalidInput(format!( + "Tensor shape too large: {} * {} would overflow", + num_samples, sample_size + )) + })?; + let expected_bytes = expected_elems + .checked_mul(std::mem::size_of::<f64>()) + .ok_or_else(|| { + MahoutError::InvalidInput(format!( + "Tensor size too large: {} elements * {} bytes would overflow", + expected_elems, + std::mem::size_of::<f64>() + )) + })?; + let payload = Self::extract_payload(&mut tensor_proto, expected_elems, expected_bytes)?; + + Ok(Self { + payload, + num_samples, + sample_size, + read: false, + }) + } + + /// Parse shape, supports 1D and 2D tensors + fn parse_shape( + shape: &crate::tf_proto::tensorflow::TensorShapeProto, + ) -> Result<(usize, usize)> { + if shape.unknown_rank { + return Err(MahoutError::InvalidInput( + "Unsupported tensor shape: unknown_rank=true".into(), + )); + } + + let dims = &shape.dim; + + match dims.len() { + 1 => { + // 1D: [n] -> single sample + let size = dims[0].size; + if size <= 0 { + return Err(MahoutError::InvalidInput(format!( + "Invalid dimension size: {}", + size + ))); + } + Ok((1, size as usize)) + } + 2 => { + // 2D: [batch_size, feature_size] + let batch_size = dims[0].size; + let feature_size = dims[1].size; + if batch_size <= 0 || feature_size <= 0 { + return Err(MahoutError::InvalidInput(format!( + "Invalid shape: [{}, {}]", + batch_size, feature_size + ))); + } + Ok((batch_size as usize, feature_size as usize)) + } + _ => Err(MahoutError::InvalidInput(format!( + "Unsupported tensor rank: {} (only 1D and 2D supported)", + dims.len() + ))), + } + } + + /// Safely extract tensor_content, handling alignment and byte order + /// + /// Prefers tensor_content (efficient parsing), falls back to double_val if unavailable. 
+ fn extract_payload( + tensor_proto: &mut crate::tf_proto::tensorflow::TensorProto, + expected_elems: usize, + expected_bytes: usize, + ) -> Result<TensorPayload> { + if !tensor_proto.tensor_content.is_empty() { + let content = std::mem::take(&mut tensor_proto.tensor_content); + if content.len() != expected_bytes { + return Err(MahoutError::InvalidInput(format!( + "tensor_content size mismatch: expected {} bytes, got {}", + expected_bytes, + content.len() + ))); + } + // With build.rs config.bytes(...), this is Bytes (avoids copy during decode) + Ok(TensorPayload::Bytes(content)) + } else if !tensor_proto.double_val.is_empty() { + let values = std::mem::take(&mut tensor_proto.double_val); + if values.len() != expected_elems { + return Err(MahoutError::InvalidInput(format!( + "double_val length mismatch: expected {} values, got {}", + expected_elems, + values.len() + ))); + } + Ok(TensorPayload::F64(values)) + } else { + Err(MahoutError::InvalidInput( + "TensorProto has no data (both tensor_content and double_val are empty)" + .to_string(), + )) + } + } + + /// Convert `tensor_content` bytes to `Vec<f64>`. + /// + /// Note: Even though `tensor_content` can be zero-copy, `DataReader` requires `Vec<f64>`, + /// so one copy is still needed. Uses memcpy (instead of element-wise `from_le_bytes`) for best performance. + /// + /// # Safety + /// This function uses `unsafe` for memory copy, but performs the following safety checks: + /// 1. Byte order check (little-endian only) + /// 2. Length check (must be multiple of 8) + /// 3. Alignment check (f64 needs 8-byte alignment, Vec handles this automatically) + /// 4. Overflow check (ensures no overflow) + fn bytes_to_f64_vec(bytes: &Bytes) -> Result<Vec<f64>> { + if !cfg!(target_endian = "little") { + return Err(MahoutError::NotImplemented( + "Big-endian platforms are not supported for TensorFlow tensor_content".into(), + )); + } + if !bytes.len().is_multiple_of(8) { + return Err(MahoutError::InvalidInput(format!( + "tensor_content length {} is not a multiple of 8", + bytes.len() + ))); + } + + let n = bytes.len() / 8; + // Check overflow: ensure n doesn't exceed Vec's maximum capacity + if n > (usize::MAX / std::mem::size_of::<f64>()) { + return Err(MahoutError::InvalidInput( + "tensor_content too large: would exceed maximum vector size".into(), + )); + } + + let mut data = Vec::<f64>::with_capacity(n); + unsafe { + // Safety: We've checked: + // 1. bytes.len() % 8 == 0 (ensures divisible) + // 2. n <= usize::MAX / size_of::<f64>() (ensures no overflow) + // 3. Vec::with_capacity(n) ensures alignment (Rust Vec guarantees this) + // 4. copy_nonoverlapping is safe because source and destination don't overlap + // 5. 
Copy data first, then set length, ensuring memory is initialized + std::ptr::copy_nonoverlapping( + bytes.as_ptr(), + data.as_mut_ptr() as *mut u8, + bytes.len(), + ); + data.set_len(n); + } + Ok(data) + } +} + +impl DataReader for TensorFlowReader { + fn read_batch(&mut self) -> Result<(Vec<f64>, usize, usize)> { + if self.read { + return Err(MahoutError::InvalidInput( + "Reader already consumed".to_string(), + )); + } + self.read = true; + + match std::mem::replace(&mut self.payload, TensorPayload::F64(Vec::new())) { + TensorPayload::F64(data) => { + // Already Vec<f64>, return directly + Ok((data, self.num_samples, self.sample_size)) + } + TensorPayload::Bytes(bytes) => { + let data = Self::bytes_to_f64_vec(&bytes)?; + Ok((data, self.num_samples, self.sample_size)) + } + } + } + + fn get_sample_size(&self) -> Option<usize> { + Some(self.sample_size) + } + + fn get_num_samples(&self) -> Option<usize> { + Some(self.num_samples) + } +} diff --git a/qdp/qdp-core/src/readers/mod.rs b/qdp/qdp-core/src/tf_proto.rs similarity index 61% copy from qdp/qdp-core/src/readers/mod.rs copy to qdp/qdp-core/src/tf_proto.rs index 4ca199e37..e3c87b909 100644 --- a/qdp/qdp-core/src/readers/mod.rs +++ b/qdp/qdp-core/src/tf_proto.rs @@ -14,19 +14,10 @@ // See the License for the specific language governing permissions and // limitations under the License. -//! Format-specific data reader implementations. +//! TensorFlow TensorProto protobuf definitions. //! -//! This module contains concrete implementations of the [`DataReader`] and -//! [`StreamingDataReader`] traits for various file formats. -//! -//! # Fully Implemented Formats -//! - **Parquet**: [`ParquetReader`], [`ParquetStreamingReader`] -//! - **Arrow IPC**: [`ArrowIPCReader`] - -pub mod arrow_ipc; -pub mod numpy; -pub mod parquet; +//! This module contains the generated protobuf code for TensorFlow TensorProto format. +//! The code is generated at build time by prost-build from `proto/tensor.proto`. -pub use arrow_ipc::ArrowIPCReader; -pub use numpy::NumpyReader; -pub use parquet::{ParquetReader, ParquetStreamingReader}; +// Generated by build.rs to OUT_DIR (see build.rs include_file) +include!(concat!(env!("OUT_DIR"), "/tensorflow_proto_mod.rs")); diff --git a/qdp/qdp-core/tests/tensorflow_io.rs b/qdp/qdp-core/tests/tensorflow_io.rs new file mode 100644 index 000000000..87b3edc3d --- /dev/null +++ b/qdp/qdp-core/tests/tensorflow_io.rs @@ -0,0 +1,354 @@ +// +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +use bytes::Bytes; +use qdp_core::io::read_tensorflow_batch; +use qdp_core::reader::DataReader; +use qdp_core::readers::TensorFlowReader; +use std::fs; + +mod common; + +/// Helper function to create a TensorProto file using prost +/// This creates a minimal TensorProto with tensor_content (preferred path) +fn create_tensorflow_file_tensor_content( + path: &str, + data: &[f64], + shape: &[i64], +) -> Result<(), Box<dyn std::error::Error>> { + use prost::Message; + use qdp_core::tf_proto::tensorflow; + + // Convert f64 data to bytes (little-endian) + let mut tensor_content = Vec::with_capacity(data.len() * 8); + for &value in data { + tensor_content.extend_from_slice(&value.to_le_bytes()); + } + + let dims: Vec<tensorflow::Dim> = shape.iter().map(|&size| tensorflow::Dim { size }).collect(); + + let tensor_proto = tensorflow::TensorProto { + dtype: 2, // DT_DOUBLE = 2 + tensor_shape: Some(tensorflow::TensorShapeProto { + dim: dims, + unknown_rank: false, + }), + tensor_content: tensor_content.into(), + double_val: vec![], + }; + + let mut buf = Vec::new(); + tensor_proto.encode(&mut buf)?; + fs::write(path, buf)?; + Ok(()) +} + +/// Helper function to create a TensorProto file using double_val (fallback path) +fn create_tensorflow_file_double_val( + path: &str, + data: &[f64], + shape: &[i64], +) -> Result<(), Box<dyn std::error::Error>> { + use prost::Message; + use qdp_core::tf_proto::tensorflow; + + let dims: Vec<tensorflow::Dim> = shape.iter().map(|&size| tensorflow::Dim { size }).collect(); + + let tensor_proto = tensorflow::TensorProto { + dtype: 2, // DT_DOUBLE = 2 + tensor_shape: Some(tensorflow::TensorShapeProto { + dim: dims, + unknown_rank: false, + }), + tensor_content: Bytes::new(), + double_val: data.to_vec(), + }; + + let mut buf = Vec::new(); + tensor_proto.encode(&mut buf)?; + fs::write(path, buf)?; + Ok(()) +} + +#[test] +fn test_read_tensorflow_2d_tensor_content() { + let temp_path = "/tmp/test_tensorflow_2d_tc.pb"; + let num_samples = 10; + let sample_size = 16; + + // Create test data + let mut data = Vec::new(); + for i in 0..num_samples { + for j in 0..sample_size { + data.push((i * sample_size + j) as f64); + } + } + + // Create TensorProto file with tensor_content + create_tensorflow_file_tensor_content( + temp_path, + &data, + &[num_samples as i64, sample_size as i64], + ) + .unwrap(); + + // Read and verify + let (read_data, samples, size) = read_tensorflow_batch(temp_path).unwrap(); + + assert_eq!(samples, num_samples); + assert_eq!(size, sample_size); + assert_eq!(read_data.len(), num_samples * sample_size); + + for (i, &val) in read_data.iter().enumerate() { + assert_eq!(val, i as f64); + } + + // Cleanup + fs::remove_file(temp_path).unwrap(); +} + +#[test] +fn test_read_tensorflow_2d_double_val() { + let temp_path = "/tmp/test_tensorflow_2d_dv.pb"; + let num_samples = 5; + let sample_size = 8; + + // Create test data + let mut data = Vec::new(); + for i in 0..num_samples { + for j in 0..sample_size { + data.push((i * sample_size + j) as f64); + } + } + + // Create TensorProto file with double_val (fallback path) + create_tensorflow_file_double_val(temp_path, &data, &[num_samples as i64, sample_size as i64]) + .unwrap(); + + // Read and verify + let (read_data, samples, size) = read_tensorflow_batch(temp_path).unwrap(); + + assert_eq!(samples, num_samples); + assert_eq!(size, sample_size); + assert_eq!(read_data.len(), num_samples * sample_size); + + for (i, &val) in read_data.iter().enumerate() { + assert_eq!(val, i as f64); + } + + // Cleanup + 
fs::remove_file(temp_path).unwrap(); +} + +#[test] +fn test_read_tensorflow_1d_tensor() { + let temp_path = "/tmp/test_tensorflow_1d.pb"; + let sample_size = 16; + + // Create test data (1D tensor = single sample) + let data: Vec<f64> = (0..sample_size).map(|i| i as f64).collect(); + + // Create TensorProto file + create_tensorflow_file_tensor_content(temp_path, &data, &[sample_size as i64]).unwrap(); + + // Read and verify + let (read_data, samples, size) = read_tensorflow_batch(temp_path).unwrap(); + + assert_eq!(samples, 1); // 1D tensor is treated as single sample + assert_eq!(size, sample_size); + assert_eq!(read_data.len(), sample_size); + + for (i, &val) in read_data.iter().enumerate() { + assert_eq!(val, i as f64); + } + + // Cleanup + fs::remove_file(temp_path).unwrap(); +} + +#[test] +fn test_read_tensorflow_large_batch() { + let temp_path = "/tmp/test_tensorflow_large.pb"; + let num_samples = 100; + let sample_size = 64; + + // Create large dataset + let mut data = Vec::with_capacity(num_samples * sample_size); + for i in 0..num_samples { + for j in 0..sample_size { + data.push((i * sample_size + j) as f64 / (num_samples * sample_size) as f64); + } + } + + // Create TensorProto file + create_tensorflow_file_tensor_content( + temp_path, + &data, + &[num_samples as i64, sample_size as i64], + ) + .unwrap(); + + // Read and verify + let (read_data, samples, size) = read_tensorflow_batch(temp_path).unwrap(); + + assert_eq!(samples, num_samples); + assert_eq!(size, sample_size); + assert_eq!(read_data.len(), data.len()); + + for i in 0..data.len() { + assert!((data[i] - read_data[i]).abs() < 1e-10); + } + + // Cleanup + fs::remove_file(temp_path).unwrap(); +} + +#[test] +fn test_tensorflow_invalid_dtype() { + let temp_path = "/tmp/test_tensorflow_invalid_dtype.pb"; + use prost::Message; + use qdp_core::tf_proto::tensorflow; + + // Create TensorProto with wrong dtype (DT_FLOAT = 1 instead of DT_DOUBLE = 2) + let tensor_proto = tensorflow::TensorProto { + dtype: 1, // DT_FLOAT, not DT_DOUBLE + tensor_shape: Some(tensorflow::TensorShapeProto { + dim: vec![tensorflow::Dim { size: 4 }], + unknown_rank: false, + }), + tensor_content: Bytes::from(vec![0u8; 32]), // 4 * 8 bytes + double_val: vec![], + }; + + let mut buf = Vec::new(); + tensor_proto.encode(&mut buf).unwrap(); + fs::write(temp_path, buf).unwrap(); + + // Should fail with InvalidInput error + let result = read_tensorflow_batch(temp_path); + assert!(result.is_err()); + if let Err(e) = result { + assert!(e.to_string().contains("DT_DOUBLE")); + } + + // Cleanup + fs::remove_file(temp_path).unwrap(); +} + +#[test] +fn test_tensorflow_empty_file_fails() { + let result = read_tensorflow_batch("/tmp/nonexistent_tensorflow_file_12345.pb"); + assert!(result.is_err()); +} + +#[test] +fn test_tensorflow_3d_shape_fails() { + let temp_path = "/tmp/test_tensorflow_3d.pb"; + use prost::Message; + use qdp_core::tf_proto::tensorflow; + + // Create TensorProto with 3D shape (unsupported) + let tensor_proto = tensorflow::TensorProto { + dtype: 2, // DT_DOUBLE + tensor_shape: Some(tensorflow::TensorShapeProto { + dim: vec![ + tensorflow::Dim { size: 2 }, + tensorflow::Dim { size: 3 }, + tensorflow::Dim { size: 4 }, + ], + unknown_rank: false, + }), + tensor_content: Bytes::from(vec![0u8; 2 * 3 * 4 * 8]), + double_val: vec![], + }; + + let mut buf = Vec::new(); + tensor_proto.encode(&mut buf).unwrap(); + fs::write(temp_path, buf).unwrap(); + + // Should fail with InvalidInput error + let result = read_tensorflow_batch(temp_path); + 
assert!(result.is_err()); + if let Err(e) = result { + assert!(e.to_string().contains("Unsupported tensor rank")); + } + + // Cleanup + fs::remove_file(temp_path).unwrap(); +} + +#[test] +fn test_tensorflow_reader_direct() { + let temp_path = "/tmp/test_tensorflow_reader_direct.pb"; + let num_samples = 3; + let sample_size = 4; + + let data: Vec<f64> = (0..num_samples * sample_size).map(|i| i as f64).collect(); + create_tensorflow_file_tensor_content( + temp_path, + &data, + &[num_samples as i64, sample_size as i64], + ) + .unwrap(); + + // Test direct reader usage + let mut reader = TensorFlowReader::new(temp_path).unwrap(); + assert_eq!(reader.get_num_samples(), Some(num_samples)); + assert_eq!(reader.get_sample_size(), Some(sample_size)); + + let (read_data, samples, size) = reader.read_batch().unwrap(); + assert_eq!(samples, num_samples); + assert_eq!(size, sample_size); + assert_eq!(read_data, data); + + // Reader should be consumed + assert!(reader.read_batch().is_err()); + + // Cleanup + fs::remove_file(temp_path).unwrap(); +} + +#[test] +fn test_tensorflow_size_mismatch_fails() { + let temp_path = "/tmp/test_tensorflow_size_mismatch.pb"; + use prost::Message; + use qdp_core::tf_proto::tensorflow; + + // Create TensorProto with shape [2, 4] but wrong data size + let tensor_proto = tensorflow::TensorProto { + dtype: 2, // DT_DOUBLE + tensor_shape: Some(tensorflow::TensorShapeProto { + dim: vec![tensorflow::Dim { size: 2 }, tensorflow::Dim { size: 4 }], + unknown_rank: false, + }), + tensor_content: Bytes::from(vec![0u8; 16]), // Only 16 bytes, should be 2*4*8 = 64 bytes + double_val: vec![], + }; + + let mut buf = Vec::new(); + tensor_proto.encode(&mut buf).unwrap(); + fs::write(temp_path, buf).unwrap(); + + // Should fail with size mismatch error + let result = read_tensorflow_batch(temp_path); + assert!(result.is_err()); + if let Err(e) = result { + assert!(e.to_string().contains("size mismatch")); + } + + // Cleanup + fs::remove_file(temp_path).unwrap(); +} diff --git a/qdp/qdp-python/pyproject.toml b/qdp/qdp-python/pyproject.toml index 6ef0ab389..c52662564 100644 --- a/qdp/qdp-python/pyproject.toml +++ b/qdp/qdp-python/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "maturin" [project] name = "qdp-python" -requires-python = ">=3.11" +requires-python = ">=3.11,<3.13" classifiers = [ "Programming Language :: Rust", "Programming Language :: Python :: Implementation :: CPython", @@ -24,6 +24,7 @@ benchmark = [ "numpy>=1.24,<2.0", "pandas>=2.0", "pyarrow>=14.0", + "tensorflow>=2.20", "torch>=2.2", "qiskit>=1.0", "qiskit-aer>=0.17.2", diff --git a/qdp/qdp-python/src/lib.rs b/qdp/qdp-python/src/lib.rs index 1dc60da70..2f5d9af40 100644 --- a/qdp/qdp-python/src/lib.rs +++ b/qdp/qdp-python/src/lib.rs @@ -421,6 +421,38 @@ impl QdpEngine { consumed: false, }) } + + /// Encode from TensorFlow TensorProto file + /// + /// Args: + /// path: Path to TensorProto file (.pb) + /// num_qubits: Number of qubits for encoding + /// encoding_method: Encoding strategy (currently only "amplitude") + /// + /// Returns: + /// QuantumTensor: DLPack tensor containing all encoded states + /// + /// Example: + /// >>> engine = QdpEngine(device_id=0) + /// >>> batched = engine.encode_from_tensorflow("data.pb", 16, "amplitude") + /// >>> torch_tensor = torch.from_dlpack(batched) # Shape: [200, 65536] + fn encode_from_tensorflow( + &self, + path: &str, + num_qubits: usize, + encoding_method: &str, + ) -> PyResult<QuantumTensor> { + let ptr = self + .engine + .encode_from_tensorflow(path, num_qubits, 
encoding_method) + .map_err(|e| { + PyRuntimeError::new_err(format!("Encoding from TensorFlow failed: {}", e)) + })?; + Ok(QuantumTensor { + ptr, + consumed: false, + }) + } } /// Mahout QDP Python module
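
Usage sketch (not part of the patch above): the new encode_from_tensorflow binding is meant to consume a serialized DT_DOUBLE TensorProto, such as one produced by tf.make_tensor_proto. The snippet below is a minimal illustration only; it assumes the Python package is importable as `qdp`, that a CUDA-capable device 0 is available, and that tensorflow and torch are installed — none of these names are fixed by this commit.

    import numpy as np
    import tensorflow as tf
    import torch
    from qdp import QdpEngine  # import path assumed; not defined in this commit

    # Serialize an [8, 16] float64 batch as a TensorProto. For float64 NumPy input,
    # tf.make_tensor_proto produces dtype DT_DOUBLE; the new reader accepts either
    # tensor_content (preferred) or the double_val fallback.
    data = np.random.rand(8, 16).astype(np.float64)
    with open("data.pb", "wb") as f:
        f.write(tf.make_tensor_proto(data).SerializeToString())

    # Encode all 8 samples on the GPU; amplitude encoding with 4 qubits
    # expects 2^4 = 16 features per sample.
    engine = QdpEngine(device_id=0)
    batched = engine.encode_from_tensorflow("data.pb", 4, "amplitude")
    states = torch.from_dlpack(batched)  # shape: [8, 16]

On the Rust side, TensorFlowReader decodes tensor_content as bytes::Bytes (per the config.bytes setting in build.rs) and converts it to Vec<f64> with a single memcpy, so writing tensor_content rather than double_val keeps ingestion on the fast branch.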
