MasterJH5574 commented on code in PR #17730:
URL: https://github.com/apache/tvm/pull/17730#discussion_r1987336679
##########
python/tvm/relax/backend/cuda/flashinfer.py:
##########
@@ -37,7 +39,43 @@ def _compile_flashinfer_kernels(
FLASHINFER_TVM_BINDING_DIR,
)
- # Todo(tvm-team): enable compilation cache
+ # ------------------------------------------------------------------------
+ # Caching Flow: create build_directory and compute cache hash.
+ # ------------------------------------------------------------------------
+ build_directory = FLASHINFER_JIT_DIR / name
+ build_directory.mkdir(parents=True, exist_ok=True)
+
+ # Compute latest modification time among all source files
+ latest_src_mtime = max(src.stat().st_mtime for src in source_paths)
+
+ # Get modification time for the current file (the one that contains this
function)
+ current_file_mtime = Path(__file__).stat().st_mtime
+
+ # Build the hash key from metadata
+ hash_key = {
+ "name": name,
+ "target": str(target),
+ "latest_src_mtime": latest_src_mtime,
+ "current_file_mtime": current_file_mtime,
+ }
+
+ system_lib_hash_value = hashlib.md5(
+ json.dumps(hash_key, sort_keys=True, indent=2).encode("utf-8")
+ ).hexdigest()
+
+ # Check if a valid hash exists in the build directory
+ hash_file = build_directory / "hash.txt"
Review Comment:
Let's use the more conventional name `hash.md5` instead of `hash.txt`.
##########
python/tvm/relax/backend/cuda/flashinfer.py:
##########
@@ -37,7 +39,43 @@ def _compile_flashinfer_kernels(
FLASHINFER_TVM_BINDING_DIR,
)
- # Todo(tvm-team): enable compilation cache
+ # ------------------------------------------------------------------------
+ # Caching Flow: create build_directory and compute cache hash.
+ # ------------------------------------------------------------------------
+ build_directory = FLASHINFER_JIT_DIR / name
+ build_directory.mkdir(parents=True, exist_ok=True)
+
+ # Compute latest modification time among all source files
+ latest_src_mtime = max(src.stat().st_mtime for src in source_paths)
+
+ # Get modification time for the current file (the one that contains this
function)
+ current_file_mtime = Path(__file__).stat().st_mtime
+
+ # Build the hash key from metadata
+ hash_key = {
+ "name": name,
+ "target": str(target),
+ "latest_src_mtime": latest_src_mtime,
+ "current_file_mtime": current_file_mtime,
+ }
+
+ system_lib_hash_value = hashlib.md5(
+ json.dumps(hash_key, sort_keys=True, indent=2).encode("utf-8")
+ ).hexdigest()
+
+ # Check if a valid hash exists in the build directory
+ hash_file = build_directory / "hash.txt"
+ if hash_file.exists():
+ with open(hash_file, "r") as f:
+ cached_hash = f.read().strip()
+ if cached_hash == system_lib_hash_value:
+ # Cache hit: return all object files in build_directory
+ return list(build_directory.glob("*.o"))
Review Comment:
For a cache hit, we also need to make sure that all `.o` files exist and have
not been modified. If any of the `.o` files is missing or has been modified
since `latest_object_mtime` (i.e., the latest modification time among the
object files produced by the last build), we need to recompile. So let's add
this to the hash key as well.
##########
python/tvm/relax/backend/cuda/flashinfer.py:
##########
@@ -202,7 +237,7 @@ def gen_flashinfer_prefill_module(
)
jit_args = {
"backend": backend,
- "uri": "batch_prefill_tvm",
+ "uri":
f"batch_prefill_tvm_dtype_q_{dtype_q}_dtype_kv_{dtype_kv}_dtype_o_{dtype_o}_qk_head_dim_{qk_head_dim}_v_head_dim_{v_head_dim}_enable_inline_rope_{enable_inline_rope}",
Review Comment:
Let's split this long string across multiple lines.
##########
python/tvm/relax/backend/cuda/flashinfer.py:
##########
@@ -37,7 +39,43 @@ def _compile_flashinfer_kernels(
FLASHINFER_TVM_BINDING_DIR,
)
- # Todo(tvm-team): enable compilation cache
+ # ------------------------------------------------------------------------
+ # Caching Flow: create build_directory and compute cache hash.
+ # ------------------------------------------------------------------------
+ build_directory = FLASHINFER_JIT_DIR / name
+ build_directory.mkdir(parents=True, exist_ok=True)
+
+ # Compute latest modification time among all source files
+ latest_src_mtime = max(src.stat().st_mtime for src in source_paths)
+
+ # Get modification time for the current file (the one that contains this
function)
+ current_file_mtime = Path(__file__).stat().st_mtime
+
+ # Build the hash key from metadata
+ hash_key = {
+ "name": name,
+ "target": str(target),
+ "latest_src_mtime": latest_src_mtime,
+ "current_file_mtime": current_file_mtime,
+ }
+
+ system_lib_hash_value = hashlib.md5(
Review Comment:
```suggestion
hash_value = hashlib.md5(
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]