vinx13 commented on code in PR #16764:
URL: https://github.com/apache/tvm/pull/16764#discussion_r1539988421


##########
src/runtime/contrib/cublas/cublas_json_runtime.cc:
##########
@@ -129,14 +132,50 @@ class CublasJSONRuntime : public JSONRuntimeBase {
 
         auto [a_ptr, b_ptr, bias_ptr] = get_inputs(node, epilogue != 
CUBLASLT_EPILOGUE_DEFAULT);
 
+        const cublasLtMatmulAlgo_t* predef_algo_ptr = nullptr;
+        int64_t dyn_dim_val = 
dl_tensors[std::get<0>(dyn_dim_position)]->shape[std::get<1>(dyn_dim_position)];
+        auto algo_desc = algo_collection(dyn_dim_val);
+        if (algo_desc.defined())
+          predef_algo_ptr = &algo_desc->algo;
+
         tvm::contrib::CallCublasLt(entry_ptr->handle, stream, 
entry_ptr->matmul_pref_desc, a_ptr,
                                    b_ptr, bias_ptr, out_ptr, transa, transb,
-                                   entry_ptr->workspace_ptr, 
entry_ptr->workspace_size, epilogue);
+                                   entry_ptr->workspace_ptr, 
entry_ptr->workspace_size, epilogue,
+                                   predef_algo_ptr);
       }
     }
   }
 
   void Run() override { LOG(FATAL) << "Unreachable"; }
+
+ protected:
+  void LoadPredefAlgoCollection() {
+    for (const auto& node : nodes_) {
+      if (node.GetOpType() == "kernel" && node.HasAttr("predefined_algos")) {
+        // Load algo collection
+        auto predef_algos_str = 
node.GetAttr<std::vector<std::string>>("predefined_algos");
+        ICHECK_EQ(predef_algos_str.size(), 1);
+        algo_collection = 
tvm::contrib::AlgoCollection::FromJSON(predef_algos_str[0]);
+
+        // Define dynamic dimension position
+        for (const auto& ne : node.GetInputs()) {
+          auto shape = nodes_[ne.id_].GetOpShape()[ne.index_];
+          auto found = std::find(shape.begin(), shape.end(), -1);
+          if (found != shape.end()) {
+            uint32_t dyn_dim_idx = std::distance(shape.begin(), found);
+            uint32_t dyn_dim_eid = EntryID(ne);
+            dyn_dim_position = {dyn_dim_eid, dyn_dim_idx};

Review Comment:
   When there are multiple nodes with predefined algos, does this overwrite the 
results of previous iterations?



##########
src/relax/backend/contrib/cublas/algo_db.h:
##########
@@ -0,0 +1,104 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \brief Codegen part of tuning capabilities for cublas matmul primitives. 
+ */
+
+#include <dmlc/json.h>
+
+#include "../../../../runtime/contrib/cublas/cublas_algo.h"
+
+namespace tvm {
+namespace relax {
+namespace contrib {
+
+using AlgoCollection = tvm::contrib::AlgoCollection;
+using AlgoDesc = tvm::contrib::AlgoDesc;
+
+/*! \brief Algo database with predefined Algo objects. */
+class AlgoDatabaseNode: public runtime::Object {
+  /*! \brief Mapping from composite function structural hash to algo collection. */
+  std::map<uint64_t, AlgoCollection> collections;
+
+public:
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    // v->Visit("collections", &collections);

Review Comment:
   Remove this commented-out line.



##########
src/runtime/contrib/cublas/cublas_json_runtime.cc:
##########
@@ -129,14 +132,50 @@ class CublasJSONRuntime : public JSONRuntimeBase {
 
         auto [a_ptr, b_ptr, bias_ptr] = get_inputs(node, epilogue != 
CUBLASLT_EPILOGUE_DEFAULT);
 
+        const cublasLtMatmulAlgo_t* predef_algo_ptr = nullptr;
+        int64_t dyn_dim_val = 
dl_tensors[std::get<0>(dyn_dim_position)]->shape[std::get<1>(dyn_dim_position)];
+        auto algo_desc = algo_collection(dyn_dim_val);
+        if (algo_desc.defined())
+          predef_algo_ptr = &algo_desc->algo;

Review Comment:
   nit
   ```suggestion
           if (algo_desc.defined()) {
             predef_algo_ptr = &algo_desc->algo;
           }
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscr...@tvm.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

Reply via email to