This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new c481950807 [Relax][Refactor] Phase out FewShotTuning (#18864)
c481950807 is described below
commit c481950807791f0d3c9e005381f76555b0ceb5aa
Author: Tianqi Chen <[email protected]>
AuthorDate: Mon Mar 2 12:49:42 2026 -0500
[Relax][Refactor] Phase out FewShotTuning (#18864)
## Summary
- Remove `FewShotTuning` pass from Relax transform (C++ implementation,
Python bindings, and test file)
- The pass is unused in the current codebase and can be safely removed
## Files Changed
- `include/tvm/relax/transform.h` — Remove declaration
- `python/tvm/relax/transform/__init__.py` — Remove from imports
- `python/tvm/relax/transform/transform.py` — Remove Python function
- `src/relax/transform/few_shot_tuning.cc` — Delete (C++ implementation)
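- `tests/python/relax/test_transform_few_shot_tuning.py` — Delete (test
file)
- `tests/lint/check_asf_header.py` — Exclude files reported by
`git ls-files --deleted` from the list of files to check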
---
include/tvm/relax/transform.h | 12 -
python/tvm/relax/transform/__init__.py | 1 -
python/tvm/relax/transform/transform.py | 24 --
src/relax/transform/few_shot_tuning.cc | 188 ----------
tests/lint/check_asf_header.py | 22 +-
.../python/relax/test_transform_few_shot_tuning.py | 392 ---------------------
6 files changed, 18 insertions(+), 621 deletions(-)
diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h
index 0e660292c4..9ffeb05f8f 100644
--- a/include/tvm/relax/transform.h
+++ b/include/tvm/relax/transform.h
@@ -673,18 +673,6 @@ ToMixedPrecision(const DataType& out_dtype,
*/
TVM_DLL Pass RewriteCUDAGraph();
-/*!
- * \brief The pass is designed for few shot tuning for static shape PrimFuncs.
It examines all the
- * blocks within the PrimFunc and conducts loop fusion, splitting, and other
transformations based
- * on MetaSchedule schedule rules but directly samples from the search space
instead of using the
- * tuning algorithm. User can specify the number of valid counts to try and
whether to use runner
- * for benchmarking.
- * \param valid_count The number of valid counts to try.
- * \param benchmark Whether to use runner for benchmarking.
- * \return The Pass.
- */
-TVM_DLL Pass FewShotTuning(int valid_count, bool benchmark);
-
/*!
* \brief This pass updates the var_buffer mapping of PrimFunctions from the
call_tir info.
* Primarily used to update the VDevice information if any changes occured
from the caller.
diff --git a/python/tvm/relax/transform/__init__.py
b/python/tvm/relax/transform/__init__.py
index 5bf79dc7c8..c3188adf50 100644
--- a/python/tvm/relax/transform/__init__.py
+++ b/python/tvm/relax/transform/__init__.py
@@ -41,7 +41,6 @@ from .transform import (
EliminateCommonSubexpr,
ExpandMatmulOfSum,
ExpandTupleArguments,
- FewShotTuning,
FoldConstant,
FunctionPass,
FuseOps,
diff --git a/python/tvm/relax/transform/transform.py
b/python/tvm/relax/transform/transform.py
index 8911423c65..e70392e88d 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -1264,30 +1264,6 @@ def MetaScheduleTuneIRMod(
) # type: ignore
-def FewShotTuning(
- valid_count: int = 1,
- benchmark: bool = False,
-) -> tvm.ir.transform.Pass:
- """The pass is designed for few shot tuning for static shape PrimFuncs. It
examines all the
- blocks within the PrimFunc and conducts loop fusion, splitting, and other
transformations based
- on MetaSchedule schedule rules but directly samples from the search space
instead of using the
- tuning algorithm. User can specify the number of valid counts to try and
whether to use runner
- for benchmarking.
-
- Parameters
- ----------
- valid_count: int
- The number of valid counts to try.
- benchmark: bool
- Whether to use runner for benchmarking.
-
- Returns
- -------
- ret: tvm.ir.transform.Pass
- """
- return _ffi_api.FewShotTuning(valid_count, benchmark) # type: ignore
-
-
def DecomposeOpsForInference(func_name: str | None = None) ->
tvm.ir.transform.Pass:
"""Decompose composite operators that are composed by other operators
during inference.
For example, the result of batch norm (a triple) will be simplified.
Attention, tensor_to_shape,
diff --git a/src/relax/transform/few_shot_tuning.cc
b/src/relax/transform/few_shot_tuning.cc
deleted file mode 100644
index a88b92e8e4..0000000000
--- a/src/relax/transform/few_shot_tuning.cc
+++ /dev/null
@@ -1,188 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-#include <tvm/ffi/reflection/registry.h>
-#include <tvm/relax/transform.h>
-
-#include "../../s_tir/meta_schedule/utils.h"
-
-namespace tvm {
-namespace relax {
-namespace transform {
-
-tir::PrimFunc FewShotTunePrimFunc(const tir::PrimFunc& prim_func, const
Target& target,
- int64_t valid_count, bool benchmark) {
- // fetch a local builder
- static const auto f_get_local_builder =
-
tvm::ffi::Function::GetGlobalRequired("s_tir.meta_schedule.builder.get_local_builder");
- s_tir::meta_schedule::Builder builder =
- f_get_local_builder().cast<s_tir::meta_schedule::Builder>();
- TVM_FFI_CHECK(builder.defined(), ValueError) << "The local builder is not
defined!";
- // fetch a local runner
- s_tir::meta_schedule::Runner runner{ffi::UnsafeInit()};
- if (benchmark) {
- static const auto f_get_local_runner =
-
tvm::ffi::Function::GetGlobalRequired("s_tir.meta_schedule.runner.get_local_runner");
- runner = f_get_local_runner().cast<s_tir::meta_schedule::Runner>();
- TVM_FFI_CHECK(runner.defined(), ValueError) << "The local runner is not
defined!";
- }
- // create an IRModule
- IRModule mod = IRModule(ffi::Map<GlobalVar, BaseFunc>(
- {{GlobalVar("main"), WithAttr(prim_func, tvm::attr::kGlobalSymbol,
ffi::String("main"))}}));
- // fetch the number of physical cores
- static const auto f_cpu_count =
- tvm::ffi::Function::GetGlobalRequired("s_tir.meta_schedule.cpu_count");
- int num_threads = f_cpu_count(false).cast<int>();
- // store the results
- ffi::Array<IRModule> results;
- std::vector<double> costs;
- // create a TuneContext
- s_tir::meta_schedule::TuneContext task = s_tir::meta_schedule::TuneContext(
- /*mod=*/mod,
- /*target=*/target,
- /*space_generator=*/
-
s_tir::meta_schedule::SpaceGenerator::PostOrderApply(/*f_block_filter=*/nullptr,
-
/*sch_rules=*/std::nullopt,
-
/*postprocs=*/std::nullopt,
-
/*mutator_probs=*/std::nullopt),
-
/*search_strategy=*/s_tir::meta_schedule::SearchStrategy::ReplayTrace(/*max_fail_count=*/100),
- /*task_name=*/std::nullopt,
- /*num_threads=*/num_threads, // use all available local threads
- /*rand_state=*/-1, // -1 means use random seed
- /*logger=*/nullptr);
- task->Initialize();
- task->search_strategy.value()->PreTuning(
- /*max_trials=*/valid_count, /*num_trials_per_iter=*/valid_count,
-
/*design_spaces=*/task->space_generator.value()->GenerateDesignSpace(mod),
- /*database=*/std::nullopt,
- /*cost_model=*/std::nullopt);
- int fail_count = 0, max_fail_count = 100;
- while (valid_count > 0 && fail_count < max_fail_count) {
- ffi::Optional<ffi::Array<s_tir::meta_schedule::MeasureCandidate>>
candidates =
- task->search_strategy.value()->GenerateMeasureCandidates();
- if (!candidates.defined()) break;
- ffi::Array<s_tir::meta_schedule::BuilderInput> builder_inputs;
- for (const s_tir::meta_schedule::MeasureCandidate& candidate :
candidates.value()) {
- builder_inputs.push_back(s_tir::meta_schedule::BuilderInput(
- /*mod=*/candidate->sch->mod(),
- /*target=*/target));
- }
- ffi::Array<s_tir::meta_schedule::BuilderResult> builder_results =
- builder->Build(builder_inputs);
- TVM_FFI_ICHECK_EQ(builder_results.size(), candidates.value().size());
- int idx = 0;
- bool no_valid = true; // whether there is no valid schedule in this
iteration
- for (const s_tir::meta_schedule::BuilderResult& builder_result :
builder_results) {
- if (!builder_result->error_msg.has_value()) {
- results.push_back(candidates.value()[idx]->sch->mod());
- valid_count--;
- no_valid = false;
- }
- idx++;
- }
- fail_count += no_valid; // increase fail_count if there is no valid
schedule
- if (benchmark) {
- ffi::Array<s_tir::meta_schedule::RunnerInput> runner_inputs;
- int idx = 0;
- for (const s_tir::meta_schedule::BuilderResult& builder_result :
builder_results) {
- if (!builder_result->error_msg.has_value()) {
- runner_inputs.push_back(s_tir::meta_schedule::RunnerInput(
- /*artifact_path=*/builder_result->artifact_path.value(),
- /*device_type=*/target->kind->name,
- /*args_info=*/candidates.value()[idx]->args_info));
- }
- idx++;
- }
- ffi::Array<s_tir::meta_schedule::RunnerFuture> runner_futures =
runner->Run(runner_inputs);
- for (const s_tir::meta_schedule::RunnerFuture& runner_future :
runner_futures) {
- s_tir::meta_schedule::RunnerResult runner_result =
runner_future->Result();
- if (runner_result->error_msg.has_value()) {
- costs.push_back(1e10);
- } else {
- double sum = 0;
- for (const FloatImm& cost : runner_result->run_secs.value()) {
- sum += cost->value;
- }
- costs.push_back(sum / runner_result->run_secs.value().size());
- }
- }
- TVM_FFI_ICHECK_EQ(costs.size(), results.size());
- }
- }
- if (results.size() == 0) {
- LOG(WARNING) << "No valid schedule found";
- return prim_func;
- }
- if (fail_count >= max_fail_count) {
- LOG(WARNING) << "Reached the maximum number of failed trials";
- }
- int best_idx = 0;
- if (benchmark) {
- for (size_t i = 1; i < costs.size(); ++i) {
- if (costs[i] < costs[best_idx]) {
- best_idx = i;
- }
- }
- } else {
- best_idx = results.size() - 1;
- }
- return WithAttr(Downcast<tir::PrimFunc>(results[best_idx]->Lookup("main")),
- tvm::tir::attr::kIsScheduled, Bool(true));
-}
-
-Pass FewShotTuning(int valid_count, bool benchmark) {
- auto pass_func = //
- [=](IRModule m, PassContext pc) {
- // input check
- TVM_FFI_ICHECK(valid_count > 0) << "Valid_count must be positive.";
- TVM_FFI_ICHECK(valid_count > 1 || !benchmark)
- << "Benchmarking requires at least two valid trials.";
- // get the target from context.
- tvm::Target target = tvm::Target::Current();
- TVM_FFI_ICHECK(target.defined()) << "Target is not set in current
context";
- // generate the few shot tuned prim funcs.
- ffi::Map<GlobalVar, BaseFunc> result;
- for (const auto& [gv, func] : m->functions) {
- if (func->IsInstance<tir::PrimFuncNode>() &&
- !func->HasNonzeroAttr(tir::attr::kIsScheduled)) {
- result.Set(gv,
-
FewShotTunePrimFunc(ffi::GetRef<tir::PrimFunc>(func.as<tir::PrimFuncNode>()),
- target, valid_count, benchmark));
- } else {
- result.Set(gv, func);
- }
- }
- return IRModule(result, // functions
- m->source_map, // map
- m->attrs); // attrs);
- };
- return CreateModulePass(/*pass_function=*/pass_func, //
- /*opt_level=*/0, //
- /*pass_name=*/"FewShotTuning", //
- /*required=*/{});
-}
-
-TVM_FFI_STATIC_INIT_BLOCK() {
- namespace refl = tvm::ffi::reflection;
- refl::GlobalDef().def("relax.transform.FewShotTuning", FewShotTuning);
-}
-
-} // namespace transform
-} // namespace relax
-} // namespace tvm
diff --git a/tests/lint/check_asf_header.py b/tests/lint/check_asf_header.py
index f5dcf22fcd..f0bfdc6a87 100644
--- a/tests/lint/check_asf_header.py
+++ b/tests/lint/check_asf_header.py
@@ -216,17 +216,31 @@ def should_skip_file(filepath: str) -> bool:
def get_git_files() -> list[str] | None:
- """Get list of files tracked by git."""
+ """Get list of files tracked by git (excluding files staged for
deletion)."""
try:
result = subprocess.run(
["git", "ls-files"], check=False, capture_output=True, text=True,
cwd=Path.cwd()
)
- if result.returncode == 0:
- return [line.strip() for line in result.stdout.split("\n") if
line.strip()]
- else:
+ if result.returncode != 0:
print("Error: Could not get git files. Make sure you're in a git
repository.")
print("Git command failed:", result.stderr.strip())
return None
+ all_files = {line.strip() for line in result.stdout.split("\n") if
line.strip()}
+ # Exclude files staged for deletion so the header check does not
+ # report errors for files that are intentionally being removed.
+ deleted_result = subprocess.run(
+ ["git", "ls-files", "--deleted"],
+ check=False,
+ capture_output=True,
+ text=True,
+ cwd=Path.cwd(),
+ )
+ if deleted_result.returncode == 0:
+ deleted = {line.strip() for line in
deleted_result.stdout.split("\n") if line.strip()}
+ all_files -= deleted
+ elif deleted_result.stderr:
+ print(f"Warning: 'git ls-files --deleted' failed:
{deleted_result.stderr.strip()}")
+ return sorted(all_files)
except FileNotFoundError:
print("Error: Git not found. This tool requires git to be installed.")
return None
diff --git a/tests/python/relax/test_transform_few_shot_tuning.py
b/tests/python/relax/test_transform_few_shot_tuning.py
deleted file mode 100644
index 6c8ee37290..0000000000
--- a/tests/python/relax/test_transform_few_shot_tuning.py
+++ /dev/null
@@ -1,392 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-# pylint: disable=invalid-name,,missing-function-docstring
-# ruff: noqa: E501, F403
-
-import numpy as np
-import pytest
-
-import tvm
-import tvm.testing
-from tvm.relax.transform import FewShotTuning
-from tvm.s_tir.meta_schedule.arg_info import ArgInfo
-from tvm.s_tir.meta_schedule.testing.tune_utils import generate_input_data
-from tvm.s_tir.tensor_intrin.cuda import * # pylint:
disable=wildcard-import,unused-wildcard-import
-from tvm.s_tir.tensor_intrin.x86 import * # pylint:
disable=wildcard-import,unused-wildcard-import
-from tvm.script import tir as T
-
-
-# pylint: disable=no-self-argument,missing-class-docstring,line-too-long
-# fmt: off
[email protected]_module
-class MatMul:
- @T.prim_func
- def matmul(
- A: T.Buffer((32, 32), "float16"),
- B: T.Buffer((32, 32), "float16"),
- C: T.Buffer((32, 32), "float16"),
- ):
- T.func_attr({"tir.noalias": True})
- # with T.sblock("root"):
- for i, j, k in T.grid(32, 32, 32):
- with T.sblock("C"):
- v_i, v_j, v_k = T.axis.remap("SSR", [i, j, k])
- T.reads(A[v_i, v_k], B[v_k, v_j])
- T.writes(C[v_i, v_j])
- with T.init():
- C[v_i, v_j] = T.float16(0)
- C[v_i, v_j] = C[v_i, v_j] + A[v_i, v_k] * B[v_k, v_j]
-
[email protected]_module
-class Softmax:
- @T.prim_func
- def softmax(rxplaceholder: T.Buffer((T.int64(8), T.int64(3456),
T.int64(3456)), "float32"), T_softmax_norm: T.Buffer((T.int64(8),
T.int64(3456), T.int64(3456)), "float32")):
- T.func_attr({"op_pattern": 4, "tir.noalias": True})
- # with T.sblock("root"):
- T_softmax_maxelem = T.alloc_buffer((T.int64(8), T.int64(3456)),
"float32")
- T_softmax_exp = T.alloc_buffer((T.int64(8), T.int64(3456),
T.int64(3456)), "float32")
- T_softmax_expsum = T.alloc_buffer((T.int64(8), T.int64(3456)),
"float32")
- for i0, i1, k in T.grid(T.int64(8), T.int64(3456), T.int64(3456)):
- with T.sblock("T_softmax_maxelem"):
- v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
- T.reads(rxplaceholder[v_i0, v_i1, v_k])
- T.writes(T_softmax_maxelem[v_i0, v_i1])
- with T.init():
- T_softmax_maxelem[v_i0, v_i1] = T.float16(-65504)
- T_softmax_maxelem[v_i0, v_i1] = T.max(T_softmax_maxelem[v_i0,
v_i1], rxplaceholder[v_i0, v_i1, v_k])
- for i0, i1, i2 in T.grid(T.int64(8), T.int64(3456), T.int64(3456)):
- with T.sblock("T_softmax_exp"):
- v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
- T.reads(rxplaceholder[v_i0, v_i1, v_i2],
T_softmax_maxelem[v_i0, v_i1])
- T.writes(T_softmax_exp[v_i0, v_i1, v_i2])
- T_softmax_exp[v_i0, v_i1, v_i2] = T.exp(rxplaceholder[v_i0,
v_i1, v_i2] - T_softmax_maxelem[v_i0, v_i1])
- for i0, i1, k in T.grid(T.int64(8), T.int64(3456), T.int64(3456)):
- with T.sblock("T_softmax_expsum"):
- v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
- T.reads(T_softmax_exp[v_i0, v_i1, v_k])
- T.writes(T_softmax_expsum[v_i0, v_i1])
- with T.init():
- T_softmax_expsum[v_i0, v_i1] = T.float16(0)
- T_softmax_expsum[v_i0, v_i1] = T_softmax_expsum[v_i0, v_i1] +
T_softmax_exp[v_i0, v_i1, v_k]
- for i0, i1, i2 in T.grid(T.int64(8), T.int64(3456), T.int64(3456)):
- with T.sblock("T_softmax_norm"):
- v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
- T.reads(T_softmax_exp[v_i0, v_i1, v_i2],
T_softmax_expsum[v_i0, v_i1])
- T.writes(T_softmax_norm[v_i0, v_i1, v_i2])
- T.sblock_attr({"axis": 2})
- T_softmax_norm[v_i0, v_i1, v_i2] = T_softmax_exp[v_i0, v_i1,
v_i2] / T_softmax_expsum[v_i0, v_i1]
-
[email protected]_module
-class Fused_Variance_Cast1:
- @T.prim_func
- def main(lv3: T.Buffer((T.int64(1), T.int64(32), T.int64(34560)),
"float32"), compute: T.Buffer((T.int64(1), T.int64(32), T.int64(1)),
"float16")):
- T.func_attr({"tir.noalias": True})
- # with T.sblock("root"):
- rxplaceholder_red = T.alloc_buffer((T.int64(1), T.int64(32),
T.int64(1)))
- T_divide = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1)))
- T_subtract = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(34560)))
- T_multiply = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(34560)))
- T_multiply_red = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1)))
- T_divide_1 = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1)))
- for ax0, ax1, ax2, k2 in T.grid(T.int64(1), T.int64(32), T.int64(1),
T.int64(34560)):
- with T.sblock("rxplaceholder_red"):
- v_ax0, v_ax1, v_ax2, v_k2 = T.axis.remap("SSSR", [ax0, ax1,
ax2, k2])
- T.reads(lv3[v_ax0, v_ax1, v_k2])
- T.writes(rxplaceholder_red[v_ax0, v_ax1, v_ax2])
- with T.init():
- rxplaceholder_red[v_ax0, v_ax1, v_ax2] = T.float32(0)
- rxplaceholder_red[v_ax0, v_ax1, v_ax2] =
rxplaceholder_red[v_ax0, v_ax1, v_ax2] + lv3[v_ax0, v_ax1, v_k2]
- for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(32), T.int64(1)):
- with T.sblock("T_divide"):
- v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
- T.reads(rxplaceholder_red[v_ax0, v_ax1, v_ax2])
- T.writes(T_divide[v_ax0, v_ax1, v_ax2])
- T_divide[v_ax0, v_ax1, v_ax2] = rxplaceholder_red[v_ax0,
v_ax1, v_ax2] * T.float32(2.8935185185185186e-05)
- for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(32), T.int64(34560)):
- with T.sblock("T_subtract"):
- v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
- T.reads(lv3[v_ax0, v_ax1, v_ax2], T_divide[v_ax0, v_ax1,
T.int64(0)])
- T.writes(T_subtract[v_ax0, v_ax1, v_ax2])
- T_subtract[v_ax0, v_ax1, v_ax2] = lv3[v_ax0, v_ax1, v_ax2] -
T_divide[v_ax0, v_ax1, T.int64(0)]
- for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(32), T.int64(34560)):
- with T.sblock("T_multiply"):
- v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
- T.reads(T_subtract[v_ax0, v_ax1, v_ax2])
- T.writes(T_multiply[v_ax0, v_ax1, v_ax2])
- T_multiply[v_ax0, v_ax1, v_ax2] = T_subtract[v_ax0, v_ax1,
v_ax2] * T_subtract[v_ax0, v_ax1, v_ax2]
- for ax0, ax1, ax2, k2 in T.grid(T.int64(1), T.int64(32), T.int64(1),
T.int64(34560)):
- with T.sblock("T_multiply_red"):
- v_ax0, v_ax1, v_ax2, v_k2 = T.axis.remap("SSSR", [ax0, ax1,
ax2, k2])
- T.reads(T_multiply[v_ax0, v_ax1, v_k2])
- T.writes(T_multiply_red[v_ax0, v_ax1, v_ax2])
- with T.init():
- T_multiply_red[v_ax0, v_ax1, v_ax2] = T.float32(0)
- T_multiply_red[v_ax0, v_ax1, v_ax2] = T_multiply_red[v_ax0,
v_ax1, v_ax2] + T_multiply[v_ax0, v_ax1, v_k2]
- for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(32), T.int64(1)):
- with T.sblock("T_divide_1"):
- v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
- T.reads(T_multiply_red[v_ax0, v_ax1, v_ax2])
- T.writes(T_divide_1[v_ax0, v_ax1, v_ax2])
- T_divide_1[v_ax0, v_ax1, v_ax2] = T_multiply_red[v_ax0, v_ax1,
v_ax2] * T.float32(2.8935185185185186e-05)
- for i0, i1, i2 in T.grid(T.int64(1), T.int64(32), T.int64(1)):
- with T.sblock("compute"):
- v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
- T.reads(T_divide_1[v_i0, v_i1, v_i2])
- T.writes(compute[v_i0, v_i1, v_i2])
- compute[v_i0, v_i1, v_i2] = T.Cast("float16", T_divide_1[v_i0,
v_i1, v_i2])
-
[email protected]_module
-class Fuse_Mean_Cast1:
- @T.prim_func
- def main(lv: T.Buffer((T.int64(1), T.int64(32), T.int64(34560)),
"float32"), compute: T.Buffer((T.int64(1), T.int64(32), T.int64(1)),
"float16")):
- T.func_attr({"tir.noalias": True})
- # with T.sblock("root"):
- rxplaceholder_red = T.alloc_buffer((T.int64(1), T.int64(32),
T.int64(1)))
- T_divide = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1)))
- for ax0, ax1, ax2, k2 in T.grid(T.int64(1), T.int64(32), T.int64(1),
T.int64(34560)):
- with T.sblock("rxplaceholder_red"):
- v_ax0, v_ax1, v_ax2, v_k2 = T.axis.remap("SSSR", [ax0, ax1,
ax2, k2])
- T.reads(lv[v_ax0, v_ax1, v_k2])
- T.writes(rxplaceholder_red[v_ax0, v_ax1, v_ax2])
- with T.init():
- rxplaceholder_red[v_ax0, v_ax1, v_ax2] = T.float32(0)
- rxplaceholder_red[v_ax0, v_ax1, v_ax2] =
rxplaceholder_red[v_ax0, v_ax1, v_ax2] + lv[v_ax0, v_ax1, v_k2]
- for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(32), T.int64(1)):
- with T.sblock("T_divide"):
- v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
- T.reads(rxplaceholder_red[v_ax0, v_ax1, v_ax2])
- T.writes(T_divide[v_ax0, v_ax1, v_ax2])
- T_divide[v_ax0, v_ax1, v_ax2] = rxplaceholder_red[v_ax0,
v_ax1, v_ax2] * T.float32(2.8935185185185186e-05)
- for i0, i1, i2 in T.grid(T.int64(1), T.int64(32), T.int64(1)):
- with T.sblock("compute"):
- v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
- T.reads(T_divide[v_i0, v_i1, v_i2])
- T.writes(compute[v_i0, v_i1, v_i2])
- compute[v_i0, v_i1, v_i2] = T.Cast("float16", T_divide[v_i0,
v_i1, v_i2])
-
[email protected]_module
-class Module:
- @T.prim_func
- def main(lv26: T.Buffer((T.int64(1), T.int64(3456), T.int64(2560)),
"float16"), T_multiply: T.Buffer((T.int64(1), T.int64(3456), T.int64(1280)),
"float16")):
- T.func_attr({"tir.noalias": True})
- # with T.sblock("root"):
- T_strided_slice_with_axes = T.alloc_buffer((T.int64(1), T.int64(3456),
T.int64(1280)), "float16")
- T_divide = T.alloc_buffer((T.int64(1), T.int64(3456), T.int64(1280)),
"float16")
- T_multiply_1 = T.alloc_buffer((T.int64(1), T.int64(3456),
T.int64(1280)), "float16")
- T_multiply_2 = T.alloc_buffer((T.int64(1), T.int64(3456),
T.int64(1280)), "float16")
- compute = T.alloc_buffer((T.int64(1), T.int64(3456), T.int64(1280)))
- compute_1 = T.alloc_buffer((T.int64(1), T.int64(3456), T.int64(1280)))
- compute_2 = T.alloc_buffer((T.int64(1), T.int64(3456), T.int64(1280)),
"float16")
- T_multiply_3 = T.alloc_buffer((T.int64(1), T.int64(3456),
T.int64(1280)), "float16")
- T_add = T.alloc_buffer((T.int64(1), T.int64(3456), T.int64(1280)),
"float16")
- T_multiply_4 = T.alloc_buffer((T.int64(1), T.int64(3456),
T.int64(1280)), "float16")
- T_multiply_5 = T.alloc_buffer((T.int64(1), T.int64(3456),
T.int64(1280)), "float16")
- T_divide_1 = T.alloc_buffer((T.int64(1), T.int64(3456),
T.int64(1280)), "float16")
- T_add_1 = T.alloc_buffer((T.int64(1), T.int64(3456), T.int64(1280)),
"float16")
- T_add_2 = T.alloc_buffer((T.int64(1), T.int64(3456), T.int64(1280)),
"float16")
- T_multiply_6 = T.alloc_buffer((T.int64(1), T.int64(3456),
T.int64(1280)), "float16")
- T_strided_slice_with_axes_1 = T.alloc_buffer((T.int64(1),
T.int64(3456), T.int64(1280)), "float16")
- T_multiply_7 = T.alloc_buffer((T.int64(1), T.int64(3456),
T.int64(1280)), "float16")
- for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(3456), T.int64(1280)):
- with T.sblock("T_strided_slice_with_axes"):
- v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
- T.reads(lv26[v_ax0, v_ax1, v_ax2 + T.int64(1280)])
- T.writes(T_strided_slice_with_axes[v_ax0, v_ax1, v_ax2])
- T_strided_slice_with_axes[v_ax0, v_ax1, v_ax2] = lv26[v_ax0,
v_ax1, v_ax2 + T.int64(1280)]
- for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(3456), T.int64(1280)):
- with T.sblock("T_divide"):
- v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
- T.reads(T_strided_slice_with_axes[v_ax0, v_ax1, v_ax2])
- T.writes(T_divide[v_ax0, v_ax1, v_ax2])
- T_divide[v_ax0, v_ax1, v_ax2] =
T_strided_slice_with_axes[v_ax0, v_ax1, v_ax2] * T.float16(0.70718232044198892)
- for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(3456), T.int64(1280)):
- with T.sblock("T_multiply"):
- v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
- T.reads(T_divide[v_ax0, v_ax1, v_ax2])
- T.writes(T_multiply_1[v_ax0, v_ax1, v_ax2])
- T_multiply_1[v_ax0, v_ax1, v_ax2] = T_divide[v_ax0, v_ax1,
v_ax2] * T.float16(1.4140625)
- for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(3456), T.int64(1280)):
- with T.sblock("T_multiply_1"):
- v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
- T.reads(T_multiply_1[v_ax0, v_ax1, v_ax2])
- T.writes(T_multiply_2[v_ax0, v_ax1, v_ax2])
- T_multiply_2[v_ax0, v_ax1, v_ax2] = T_multiply_1[v_ax0, v_ax1,
v_ax2] * T.float16(0.70710678118654757)
- for i0, i1, i2 in T.grid(T.int64(1), T.int64(3456), T.int64(1280)):
- with T.sblock("compute"):
- v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
- T.reads(T_multiply_2[v_i0, v_i1, v_i2])
- T.writes(compute[v_i0, v_i1, v_i2])
- compute[v_i0, v_i1, v_i2] = T.Cast("float32",
T_multiply_2[v_i0, v_i1, v_i2])
- for i0, i1, i2 in T.grid(T.int64(1), T.int64(3456), T.int64(1280)):
- with T.sblock("compute_1"):
- v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
- T.reads(compute[v_i0, v_i1, v_i2])
- T.writes(compute_1[v_i0, v_i1, v_i2])
- compute_1[v_i0, v_i1, v_i2] = T.erf(compute[v_i0, v_i1, v_i2])
- for i0, i1, i2 in T.grid(T.int64(1), T.int64(3456), T.int64(1280)):
- with T.sblock("compute_2"):
- v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
- T.reads(compute_1[v_i0, v_i1, v_i2])
- T.writes(compute_2[v_i0, v_i1, v_i2])
- compute_2[v_i0, v_i1, v_i2] = T.Cast("float16",
compute_1[v_i0, v_i1, v_i2])
- for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(3456), T.int64(1280)):
- with T.sblock("T_multiply_1_1"):
- v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
- T.reads(compute_2[v_ax0, v_ax1, v_ax2])
- T.writes(T_multiply_3[v_ax0, v_ax1, v_ax2])
- T_multiply_3[v_ax0, v_ax1, v_ax2] = compute_2[v_ax0, v_ax1,
v_ax2] * T.float16(0.5)
- for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(3456), T.int64(1280)):
- with T.sblock("T_add"):
- v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
- T.reads(T_multiply_3[v_ax0, v_ax1, v_ax2])
- T.writes(T_add[v_ax0, v_ax1, v_ax2])
- T_add[v_ax0, v_ax1, v_ax2] = T.float16(0.5) +
T_multiply_3[v_ax0, v_ax1, v_ax2]
- for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(3456), T.int64(1280)):
- with T.sblock("T_multiply_2"):
- v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
- T.reads(T_multiply_1[v_ax0, v_ax1, v_ax2], T_add[v_ax0, v_ax1,
v_ax2])
- T.writes(T_multiply_4[v_ax0, v_ax1, v_ax2])
- T_multiply_4[v_ax0, v_ax1, v_ax2] = T_multiply_1[v_ax0, v_ax1,
v_ax2] * T_add[v_ax0, v_ax1, v_ax2]
- for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(3456), T.int64(1280)):
- with T.sblock("T_multiply_3"):
- v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
- T.reads(T_multiply_4[v_ax0, v_ax1, v_ax2])
- T.writes(T_multiply_5[v_ax0, v_ax1, v_ax2])
- T_multiply_5[v_ax0, v_ax1, v_ax2] = T_multiply_4[v_ax0, v_ax1,
v_ax2] * T.float16(1.4140625)
- for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(3456), T.int64(1280)):
- with T.sblock("T_divide_1"):
- v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
- T.reads(T_multiply_5[v_ax0, v_ax1, v_ax2], T_divide[v_ax0,
v_ax1, v_ax2])
- T.writes(T_divide_1[v_ax0, v_ax1, v_ax2])
- T_divide_1[v_ax0, v_ax1, v_ax2] = T_multiply_5[v_ax0, v_ax1,
v_ax2] / T_divide[v_ax0, v_ax1, v_ax2]
- for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(3456), T.int64(1280)):
- with T.sblock("T_add_1"):
- v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
- T.reads(T_divide_1[v_ax0, v_ax1, v_ax2])
- T.writes(T_add_1[v_ax0, v_ax1, v_ax2])
- T_add_1[v_ax0, v_ax1, v_ax2] = T_divide_1[v_ax0, v_ax1, v_ax2]
+ T.float16(-1)
- for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(3456), T.int64(1280)):
- with T.sblock("T_add_2"):
- v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
- T.reads(T_add_1[v_ax0, v_ax1, v_ax2])
- T.writes(T_add_2[v_ax0, v_ax1, v_ax2])
- T_add_2[v_ax0, v_ax1, v_ax2] = T_add_1[v_ax0, v_ax1, v_ax2] +
T.float16(1)
- for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(3456), T.int64(1280)):
- with T.sblock("T_multiply_4"):
- v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
- T.reads(T_strided_slice_with_axes[v_ax0, v_ax1, v_ax2],
T_add_2[v_ax0, v_ax1, v_ax2])
- T.writes(T_multiply_6[v_ax0, v_ax1, v_ax2])
- T_multiply_6[v_ax0, v_ax1, v_ax2] =
T_strided_slice_with_axes[v_ax0, v_ax1, v_ax2] * T_add_2[v_ax0, v_ax1, v_ax2]
- for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(3456), T.int64(1280)):
- with T.sblock("T_strided_slice_with_axes_1"):
- v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
- T.reads(lv26[v_ax0, v_ax1, v_ax2])
- T.writes(T_strided_slice_with_axes_1[v_ax0, v_ax1, v_ax2])
- T_strided_slice_with_axes_1[v_ax0, v_ax1, v_ax2] = lv26[v_ax0,
v_ax1, v_ax2]
- for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(3456), T.int64(1280)):
- with T.sblock("T_multiply_5"):
- v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
- T.reads(T_multiply_6[v_ax0, v_ax1, v_ax2])
- T.writes(T_multiply_7[v_ax0, v_ax1, v_ax2])
- T_multiply_7[v_ax0, v_ax1, v_ax2] = T_multiply_6[v_ax0, v_ax1,
v_ax2] * T.float16(0.5)
- for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(3456), T.int64(1280)):
- with T.sblock("T_multiply_6"):
- v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
- T.reads(T_strided_slice_with_axes_1[v_ax0, v_ax1, v_ax2],
T_multiply_7[v_ax0, v_ax1, v_ax2])
- T.writes(T_multiply[v_ax0, v_ax1, v_ax2])
- T_multiply[v_ax0, v_ax1, v_ax2] =
T_strided_slice_with_axes_1[v_ax0, v_ax1, v_ax2] * T_multiply_7[v_ax0, v_ax1,
v_ax2]
-# fmt: on
-# pylint: enable=no-self-argument,missing-class-docstring,line-too-long
-
-
-def _target() -> tvm.target.Target:
- return tvm.target.Target({"kind": "llvm", "num-cores": 4})
- # for local testing only
- # return tvm.target.Target("nvidia/geforce-rtx-3070")
-
-
-def _acc() -> float:
- return 1e-2 if _target().kind.name == "cuda" else 1e-7
-
-
-def _get_single_prim_func(mod: tvm.ir.IRModule) -> tvm.tir.PrimFunc:
- funcs = [func for func in mod.functions.values()]
- assert len(funcs) == 1, "Only one function is supported."
- return funcs[0]
-
-
-def _get_input_output_info(func: tvm.tir.PrimFunc) -> tuple[list[np.ndarray],
tuple, str]:
- args = ArgInfo.from_prim_func(func)
- inputs = [generate_input_data(x.shape, x.dtype) for x in args[:-1]]
- output_shape = args[-1].shape
- output_dtype = args[-1].dtype
- return inputs, output_shape, output_dtype
-
-
-def _expected_results(
- mod: tvm.ir.IRModule, inputs: list[np.ndarray], output_shape: tuple,
output_dtype: str
-) -> np.ndarray:
- func = _get_single_prim_func(mod)
- func = func.with_attr("global_symbol", "main")
- rt_mod = tvm.compile(func, target="llvm")
- data = [
- tvm.runtime.tensor(x)
- for x in [
- *inputs,
- np.zeros(output_shape, dtype=output_dtype),
- ]
- ]
- rt_mod(*data)
- return data[-1].numpy()
-
-
-def _actual_results(
- actual: tvm.ir.IRModule, inputs: list[np.ndarray], output_shape: tuple,
output_dtype: str
-):
- target = _target()
- actual_rt_mod = tvm.compile(actual, target=target)
- actual_data = [
- tvm.runtime.tensor(x, device=tvm.cuda() if target.kind.name == "cuda"
else tvm.cpu())
- for x in [
- *inputs,
- np.zeros(output_shape, dtype=output_dtype),
- ]
- ]
- actual_rt_mod(*actual_data)
- return actual_data[-1].numpy()
-
-
-def _assert_allclose(mod: tvm.ir.IRModule, actual: tvm.ir.IRModule) -> None:
- inputs, output_shape, output_dtype =
_get_input_output_info(_get_single_prim_func(mod))
- expected_output = _expected_results(mod, inputs, output_shape,
output_dtype)
- actual_output = _actual_results(actual, inputs, output_shape, output_dtype)
- tvm.testing.assert_allclose(expected_output, actual_output, rtol=1e-3,
atol=1e-3)
-
-
-# Fused_Variance_Cast1 not added due to
https://github.com/apache/tvm/issues/14791
[email protected]("mod", [Softmax, MatMul, Fuse_Mean_Cast1, Module])
[email protected]("benchmark", [False, True])
-def test_funcs(mod: tvm.ir.IRModule, benchmark: bool) -> None:
- valid_count = 10 if benchmark else 1
- with _target(), tvm.transform.PassContext(opt_level=3):
- actual = FewShotTuning(valid_count=valid_count)(mod)
- assert _get_single_prim_func(actual).attrs["tir.is_scheduled"], "Schedule
is not applied."
- _assert_allclose(mod, actual)
-
-
-if __name__ == "__main__":
- tvm.testing.main()