roastduck commented on a change in pull request #5483:
URL: https://github.com/apache/incubator-tvm/pull/5483#discussion_r426130269



##########
File path: src/printer/tir_text_printer.cc
##########
@@ -0,0 +1,597 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tir_text_printer.cc
+ * \brief Printer to print out the IR text format
+ *        that can be parsed by a parser.
+ */
+
+#include <tvm/ir/module.h>
+#include <tvm/ir/type.h>
+#include <tvm/ir/type_functor.h>
+#include <tvm/node/serialization.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/op.h>
+#include <tvm/tir/stmt.h>
+
+#include <algorithm>
+#include <string>
+
+#include "doc.h"
+#include "meta_data.h"
+#include "text_printer.h"
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief Print an arbitrary IR object by dispatching on its runtime type.
+ *
+ * Undefined references print as "(nullptr)". AnyNode is tested before the
+ * generic PrimExprNode branch, so Any is rendered as "?" rather than being
+ * visited as an expression. Objects matching none of the known kinds are
+ * emitted through the meta-data section.
+ */
+Doc TIRTextPrinter::Print(const ObjectRef& node) {
+  if (!node.defined()) {
+    return Doc::Text("(nullptr)");
+  }
+  // Statements, expressions and types go through the functor visitors.
+  if (node->IsInstance<StmtNode>()) {
+    return VisitStmt(Downcast<Stmt>(node));
+  }
+  if (node->IsInstance<AnyNode>()) {
+    return Doc::Text("?");
+  }
+  if (node->IsInstance<PrimExprNode>()) {
+    return VisitExpr(Downcast<PrimExpr>(node));
+  }
+  if (node->IsInstance<TypeNode>()) {
+    return VisitType(Downcast<Type>(node));
+  }
+  // Function and module containers.
+  if (node->IsInstance<PrimFuncNode>()) {
+    return PrintPrimFunc(Downcast<PrimFunc>(node));
+  }
+  if (node->IsInstance<IRModuleNode>()) {
+    return PrintIRModule(Downcast<IRModule>(node));
+  }
+  // Auxiliary node kinds that have dedicated printers taking raw node pointers.
+  if (const auto* array = node.as<ArrayNode>()) {
+    return PrintArray(array);
+  }
+  if (const auto* iter_var = node.as<IterVarNode>()) {
+    return PrintIterVar(iter_var);
+  }
+  if (const auto* range = node.as<RangeNode>()) {
+    return PrintRange(range);
+  }
+  if (const auto* buffer = node.as<BufferNode>()) {
+    return PrintBuffer(buffer);
+  }
+  if (const auto* str = node.as<StringObj>()) {
+    return PrintString(str);
+  }
+  // Anything else is referenced through the meta section.
+  return meta_->GetMetaNode(node);
+}
+
+/*!
+ * \brief Print a PrimFunc: its signature, attributes, buffer declarations,
+ *        buffer_map and body.
+ *
+ * The "attr = {...}" section is emitted only when the function's attrs
+ * object is defined; an undefined DictAttrs is never dereferenced.
+ */
+Doc TIRTextPrinter::PrintPrimFunc(const PrimFunc& primFunc) {
+  const auto* op = primFunc.operator->();
+  const auto& signature = op->func_type_annotation();
+  // attrs is an ObjectRef and may be undefined (e.g. a PrimFunc constructed
+  // without attributes), so every access to attrs->dict is guarded.
+  const bool has_attrs = op->attrs.defined();
+  // collect Meta in DictAttr
+  if (has_attrs) {
+    for (const auto& it : op->attrs->dict) {
+      meta_collector_.Collect(it.second);
+    }
+  }
+  // collect buffers in buffer_map
+  memo_var_.clear();
+  memo_buf_.clear();
+  for (const auto& it : op->buffer_map) {
+    memo_buf_[it.second] = AllocBuf(it.second);
+  }
+  // print PrimFunc
+  Doc doc;
+  doc << "primfn" << "(";
+  // print params and its type annotation
+  std::vector<Doc> params;
+  for (const auto& param : op->params) {
+    params.push_back(Print(param));
+  }
+  doc << PrintSep(params, Doc::Indent(9, Doc::Text(", "))) << ")";
+  // print return type
+  doc << " -> " << Print(signature->ret_type);
+  // print attr
+  if (has_attrs) {
+    Doc attr_doc;
+    std::vector<Doc> attr_docs;
+    for (const auto& it : op->attrs->dict) {
+      attr_docs.push_back(Doc::StrLiteral(it.first) << ": " << Print(it.second));
+    }
+    attr_doc << Doc::NewLine() << "attr = {" << PrintSep(attr_docs, Doc::Text(", ")) << "}";
+    doc << Doc::Indent(2, attr_doc);
+  }
+  // print all the buffers in the tree
+  Doc buffer_doc;
+  std::vector<Doc> buffer_docs;
+  for (const auto& it : memo_buf_) {
+    const auto& buf = it.first;
+    buffer_docs.push_back(Print(buf) << Doc::Text(": Buffer(") << Print(buf->data) << ", "
+                                     << PrintDType(buf->dtype) << ", " << Print(buf->shape)
+                                     << ", " << Print(buf->strides));
+    // Optional fields are printed only when they differ from the values this
+    // printer treats as defaults: elem_offset 0, scope "global", alignment
+    // 128, offset_factor 1 and buffer_type 1.
+    if (!is_zero(buf->elem_offset)) {
+      buffer_docs.back() << ", elem_offset=" << Print(buf->elem_offset);
+    }
+    if (buf->scope != "global") {
+      buffer_docs.back() << ", scope=" << Doc::StrLiteral(buf->scope);
+    }
+    if (buf->data_alignment != 128) {
+      buffer_docs.back() << ", align=" << buf->data_alignment;
+    }
+    if (buf->offset_factor != 1) {
+      buffer_docs.back() << ", offset_factor=" << buf->offset_factor;
+    }
+    // NOTE(review): any buffer_type other than 1 is printed as "auto";
+    // presumably 1 is the default kind and the other is auto-broadcast — confirm.
+    if (buf->buffer_type != 1) {
+      buffer_docs.back() << ", type=" << Doc::StrLiteral("auto");
+    }
+    buffer_docs.back() << ")";
+  }
+  buffer_doc << Doc::NewLine() << "buffers = {";
+  buffer_doc << PrintSep(buffer_docs, Doc::Indent(11, Doc::Text(",") << Doc::NewLine()));
+  doc << Doc::Indent(2, buffer_doc) << "}";
+  // print buffer_map
+  std::vector<Doc> buffer_map_doc;
+  for (const auto& it : op->buffer_map) {
+    buffer_map_doc.push_back(Print(it.first) << ": " << Print(it.second));
+  }
+  doc << Doc::Indent(
+      2, Doc::NewLine() << "buffer_map = {" << PrintSep(buffer_map_doc, Doc::Text(", ")) << "}");
+  doc << PrintBody(op->body);
+  return doc;
+}
+
+/*!
+ * \brief Print every PrimFunc contained in an IRModule, separated by a blank
+ *        line. Functions that are not PrimFuncs are skipped.
+ */
+Doc TIRTextPrinter::PrintIRModule(const IRModule& module) {
+  std::vector<Doc> prim_funcs;
+  for (const auto& kv : module->functions) {
+    if (kv.second.as<PrimFuncNode>() != nullptr) {
+      prim_funcs.push_back(Print(kv.second));
+    }
+  }
+  Doc body;
+  body << Doc::NewLine() << PrintSep(prim_funcs, Doc::NewLine() << Doc::NewLine());
+  Doc doc;
+  doc << Doc::Indent(0, body);
+  return doc;
+}
+
+/*! \brief Print an Array as a comma-separated list enclosed in square brackets. */
+Doc TIRTextPrinter::PrintArray(const ArrayNode* op) {
+  Doc doc;
+  doc << '[';
+  bool first = true;
+  for (const auto& elem : op->data) {
+    if (!first) {
+      doc << ", ";
+    }
+    first = false;
+    doc << Print(elem);
+  }
+  doc << ']';
+  return doc;
+}
+
+/*!
+ * \brief Print an IterVar as IterVar(var, [dom], "iter_type", "thread_tag").
+ *        An undefined dom is printed via Print without the surrounding brackets.
+ */
+Doc TIRTextPrinter::PrintIterVar(const IterVarNode* op) {
+  Doc doc;
+  doc << "IterVar(" << Print(op->var) << ", ";
+  if (!op->dom.defined()) {
+    doc << Print(op->dom);
+  } else {
+    doc << "[" << Print(op->dom) << "]";
+  }
+  doc << ", " << Doc::StrLiteral(IterVarType2String(op->iter_type)) << ", "
+      << Doc::StrLiteral(op->thread_tag) << ")";
+  return doc;
+}
+
+/*! \brief Print a Range as "min:end", where end is the expression min + extent. */
+Doc TIRTextPrinter::PrintRange(const RangeNode* op) {
+  const PrimExpr end = op->min + op->extent;
+  Doc doc;
+  doc << Print(op->min) << ":" << Print(end);
+  return doc;
+}
+
+/*!
+ * \brief Print a reference to a Buffer.
+ *
+ * The buffer must already be registered in memo_buf_ (checked). Buffers held
+ * in the meta section are printed as meta references; all others use the Doc
+ * stored for them in memo_buf_.
+ */
+Doc TIRTextPrinter::PrintBuffer(const BufferNode* op) {
+  const Buffer buffer = GetRef<Buffer>(op);
+  CHECK_GT(memo_buf_.count(buffer), 0);
+  if (meta_->InMeta(buffer)) {
+    return meta_->GetMetaNode(buffer);
+  }
+  return memo_buf_[buffer];
+}
+
+// Fallback for expression kinds without a dedicated visitor: the node is
+// emitted through the meta-data section.
+Doc TIRTextPrinter::VisitExprDefault_(const Object* op) {
+  const ObjectRef node = GetRef<ObjectRef>(op);
+  return meta_->GetMetaNode(node);
+}
+
+// Fallback for statement kinds without a dedicated visitor: the node is
+// emitted through the meta-data section.
+Doc TIRTextPrinter::VisitStmtDefault_(const Object* op) {
+  const ObjectRef node = GetRef<ObjectRef>(op);
+  return meta_->GetMetaNode(node);
+}
+
+// Integer immediates: delegate to PrintConstScalar instantiated with int64_t.
+Doc TIRTextPrinter::VisitExpr_(const IntImmNode* op) {
+  const DataType& dtype = op->dtype;
+  return PrintConstScalar<int64_t>(dtype, op->value);
+}
+
+// Floating-point immediates: delegate to PrintConstScalar instantiated with double.
+Doc TIRTextPrinter::VisitExpr_(const FloatImmNode* op) {
+  const DataType& dtype = op->dtype;
+  return PrintConstScalar<double>(dtype, op->value);
+}
+
+// String immediates are printed as string literals.
+Doc TIRTextPrinter::VisitExpr_(const StringImmNode* op) {
+  return Doc::StrLiteral(op->value);
+}
+
+/*! \brief Print a Cast as cast(dtype, value). */
+Doc TIRTextPrinter::VisitExpr_(const CastNode* op) {
+  Doc doc;
+  doc << "cast(" << PrintDType(op->dtype) << ", ";
+  doc << Print(op->value) << ")";
+  return doc;
+}
+
+/*!
+ * \brief Print a Var: as a meta reference if it is in the meta section,
+ *        otherwise through AllocVar.
+ */
+Doc TIRTextPrinter::VisitExpr_(const VarNode* op) {
+  const Var var = GetRef<Var>(op);
+  if (meta_->InMeta(var)) {
+    return meta_->GetMetaNode(var);
+  }
+  return AllocVar(var);
+}
+
+// Defines VisitExpr_ for a binary node as parenthesized infix text:
+// "(" a OpString b ")". OpString carries its own surrounding spaces.
+// The left operand is printed before the right one.
+#define TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(OpName, OpString) \
+  Doc TIRTextPrinter::VisitExpr_(const OpName* op) {           \
+    Doc lhs = Print(op->a);                                    \
+    Doc rhs = Print(op->b);                                    \
+    Doc doc;                                                   \
+    doc << "(" << lhs << OpString << rhs << ")";               \
+    return doc;                                                \
+  }
+
+// Arithmetic.
+TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(AddNode, " + ")
+TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(SubNode, " - ")
+TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(MulNode, "*")
+TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(DivNode, " / ")
+TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(ModNode, " % ")
+// Comparisons.
+TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(EQNode, " == ")
+TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(NENode, " != ")
+TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(LTNode, " < ")
+TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(LENode, " <= ")
+TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(GTNode, " > ")
+TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(GENode, " >= ")
+// Logical connectives.
+TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(AndNode, " && ")
+TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(OrNode, " || ")
+
+// The following binary operations are printed in call syntax: name(a, b).
+Doc TIRTextPrinter::VisitExpr_(const FloorDivNode* op) {
+  Doc lhs = Print(op->a);
+  Doc rhs = Print(op->b);
+  Doc doc;
+  doc << "floordiv(" << lhs << ", " << rhs << ")";
+  return doc;
+}
+
+Doc TIRTextPrinter::VisitExpr_(const FloorModNode* op) {
+  Doc lhs = Print(op->a);
+  Doc rhs = Print(op->b);
+  Doc doc;
+  doc << "floormod(" << lhs << ", " << rhs << ")";
+  return doc;
+}
+
+Doc TIRTextPrinter::VisitExpr_(const MinNode* op) {
+  Doc lhs = Print(op->a);
+  Doc rhs = Print(op->b);
+  Doc doc;
+  doc << "min(" << lhs << ", " << rhs << ")";
+  return doc;
+}
+
+Doc TIRTextPrinter::VisitExpr_(const MaxNode* op) {
+  Doc lhs = Print(op->a);
+  Doc rhs = Print(op->b);
+  Doc doc;
+  doc << "max(" << lhs << ", " << rhs << ")";
+  return doc;
+}
+
+/*! \brief Print logical negation as "!" immediately followed by the operand. */
+Doc TIRTextPrinter::VisitExpr_(const NotNode* op) {
+  Doc operand = Print(op->a);
+  Doc doc;
+  doc << "!" << operand;
+  return doc;
+}
+
+/*!
+ * \brief Print a Select as select(condition, true_value, false_value).
+ *        The call is closed with ")" so the output has balanced parentheses.
+ */
+Doc TIRTextPrinter::VisitExpr_(const SelectNode* op) {
+  Doc doc;
+  doc << "select(" << Print(op->condition) << ", " << Print(op->true_value) << ", "
+      << Print(op->false_value) << ")";
+  return doc;
+}
+
+/*!
+ * \brief Print a BufferLoad as the buffer reference immediately followed by
+ *        its indices; the indices go through Print, which renders arrays in
+ *        square brackets.
+ */
+Doc TIRTextPrinter::VisitExpr_(const BufferLoadNode* op) {
+  Doc buffer_doc = Print(op->buffer);
+  Doc index_doc = Print(op->indices);
+  Doc doc;
+  doc << buffer_doc << index_doc;
+  return doc;
+}
+
+Doc TIRTextPrinter::VisitExpr_(const LoadNode* op) {
+  Doc doc;
+  doc << "(" << PrintDType(op->dtype) << "*)"
+      << Print(op->buffer_var) << "[" << Print(op->index) << "])";

Review comment:
       Parentheses are not balanced here: the printed text opens one `(` but 
closes two. Please consider printing `((dtype*)buffer_var)[index]` instead.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to