This is an automated email from the ASF dual-hosted git repository.

ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 3f27aa8db7 [WebGPU][CodeGen] Override PrintVecElemLoad and Store for 
WebGPU (#17917)
3f27aa8db7 is described below

commit 3f27aa8db7daf3b1e286c88b616274b6703319b1
Author: Charlie Ruan <[email protected]>
AuthorDate: Sat May 3 22:45:24 2025 -0400

    [WebGPU][CodeGen] Override PrintVecElemLoad and Store for WebGPU (#17917)
    
    This PR overrides `PrintVecElemLoad()` and `PrintVecElemStore()`
    for the WebGPU backend.
    
    Otherwise, we would generate things like `(QK_local[0i].s0)` for
    WebGPU, which is not valid syntax in WGSL.
    Instead, we generate `(QK_local[0i][0])` after this PR. `QK_local` here
    is an `array<vec4<f32>, 1>`.
    
    This issue prevented WebLLM from generating the correct kernel
    after https://github.com/apache/tvm/pull/17748
    
    Co-authored-by: Ruihang Lai <[email protected]>
---
 src/target/source/codegen_webgpu.cc | 11 +++++++++++
 src/target/source/codegen_webgpu.h  |  4 ++++
 2 files changed, 15 insertions(+)

diff --git a/src/target/source/codegen_webgpu.cc 
b/src/target/source/codegen_webgpu.cc
index 1d1df91dc4..90be766638 100644
--- a/src/target/source/codegen_webgpu.cc
+++ b/src/target/source/codegen_webgpu.cc
@@ -348,6 +348,17 @@ void CodeGenWebGPU::PrintSSAAssign(const std::string& 
target, const std::string&
   stream << " = " << src << ";\n";
 }
 
+void CodeGenWebGPU::PrintVecElemLoad(const std::string& vec, DataType t, int i,
+                                     std::ostream& os) {  // NOLINT(*)
+  os << vec << "[" << i << "]";  // writes `vec[i]` to os; the `.s0`-style accessor is not valid WGSL
+}
+
+void CodeGenWebGPU::PrintVecElemStore(const std::string& vec, DataType t, int 
i,
+                                      const std::string& value) {
+  this->PrintIndent();
+  stream << vec << "[" << i << "] = " << value << ";\n";  // writes the statement `vec[i] = value;` to stream
+}
+
 void CodeGenWebGPU::VisitExpr_(const BroadcastNode* op, std::ostream& os) {  
// NOLINT(*)
   std::string v = PrintExpr(op->value);
   int lanes = op->dtype.lanes();
diff --git a/src/target/source/codegen_webgpu.h 
b/src/target/source/codegen_webgpu.h
index 09f99fb886..b8f2f9a79d 100644
--- a/src/target/source/codegen_webgpu.h
+++ b/src/target/source/codegen_webgpu.h
@@ -58,6 +58,10 @@ class CodeGenWebGPU final : public CodeGenC {
   // assignment printing
   void PrintSSAAssign(const std::string& target, const std::string& src, 
DataType type) final;
 
+  // overload printing vector element load/store
+  void PrintVecElemLoad(const std::string& vec, DataType t, int i, 
std::ostream& os) final;
+  void PrintVecElemStore(const std::string& vec, DataType t, int i, const 
std::string& value) final;
+
   // overload visitor
   void VisitExpr_(const BroadcastNode* op, std::ostream& os) final;   // 
NOLINT(*)
   void VisitExpr_(const CallNode* op, std::ostream& os) final;        // 
NOLINT(*)

Reply via email to