================
@@ -729,3 +729,92 @@ Value sparse_tensor::createOrFoldSliceStrideOp(OpBuilder 
&builder, Location loc,
     return constantIndex(builder, loc, *stride);
   return builder.create<ToSliceStrideOp>(loc, tensor, APInt(64, dim));
 }
+
+/// Rebuilds `out` so that it holds one entry per dimension of `stt`: the
+/// value produced by `constantIndex` for that dimension's static size, or
+/// for zero when the size is dynamic. Previous contents of `out` are dropped.
+void sparse_tensor::fillDimShape(OpBuilder &builder, Location loc,
+                                 SparseTensorType stt,
+                                 SmallVectorImpl<Value> &out) {
+  const auto dimShape = stt.getDimShape();
+  out.clear();
+  out.reserve(stt.getDimRank());
+  for (const DynSize extent : dimShape) {
+    // A dynamic extent is encoded as zero; static extents pass through.
+    DynSize staticExtent = extent;
+    if (ShapedType::isDynamic(extent))
+      staticExtent = 0;
+    out.push_back(constantIndex(builder, loc, staticExtent));
+  }
+}
+
+/// Emits a call to the runtime function "createCheckedSparseTensorReader" for
+/// `tensor` and returns the resulting opaque reader handle.
+///
+/// Out-parameters:
+/// - `dimShapesValues` is filled by `fillDimShape`: per dimension, the static
+///   size, or zero for a dynamic size. NOTE(review): these entries are not
+///   updated after the reader is created, so for dynamic dimensions they stay
+///   zero; callers that need actual sizes must read them from
+///   `dimSizesBuffer`.
+/// - `dimSizesBuffer` is the buffer returned by `allocaBuffer` for the static
+///   shape when `stt` has no dynamic dimension; otherwise it is the result of
+///   the runtime call "getSparseTensorReaderDimSizes" on the reader.
+///
+/// The calls below emit IR through `builder`, so their order is kept as is.
+Value sparse_tensor::genReader(OpBuilder &builder, Location loc,
+                               SparseTensorType stt, Value tensor,
+                               /*out*/ SmallVectorImpl<Value> &dimShapesValues,
+                               /*out*/ Value &dimSizesBuffer) {
+  // Construct the dimShapes buffer. The buffer contains the static size
+  // per dimension, or otherwise a zero for a dynamic size.
+  fillDimShape(builder, loc, stt, dimShapesValues);
+  Value dimShapesBuffer = allocaBuffer(builder, loc, dimShapesValues);
+  // Create the `CheckedSparseTensorReader`. This reader performs a
+  // consistency check on the static sizes, but accepts any size
+  // of each dimension with a dynamic size (signalled by the zero entries).
+  Type opaqueTp = getOpaquePointerType(builder);
+  Type eltTp = stt.getElementType();
+  // The element type is passed to the runtime as a constant produced by
+  // `constantPrimaryTypeEncoding` (presumably an enum code; see that helper).
+  Value valTp = constantPrimaryTypeEncoding(builder, loc, eltTp);
+  Value reader =
+      createFuncCall(builder, loc, "createCheckedSparseTensorReader", opaqueTp,
+                     {tensor, dimShapesBuffer, valTp}, EmitCInterface::On)
+          .getResult(0);
+  // For static shapes, the shape buffer can be used right away. For dynamic
+  // shapes, use the information from the reader to construct a buffer that
+  // supplies the actual size for each dynamic dimension.
+  dimSizesBuffer = dimShapesBuffer;
+  if (stt.hasDynamicDimShape()) {
+    Type indexTp = builder.getIndexType();
+    // Result type of the runtime query: a 1-D memref of `index` whose length
+    // is not known statically.
+    auto memTp = MemRefType::get({ShapedType::kDynamic}, indexTp);
+    dimSizesBuffer =
+        createFuncCall(builder, loc, "getSparseTensorReaderDimSizes", memTp,
+                       reader, EmitCInterface::On)
+            .getResult(0);
+  }
+  return reader;
+}
+
+Value sparse_tensor::genReaderBuffers(OpBuilder &builder, Location loc,
+                                      SparseTensorType stt,
+                                      ArrayRef<Value> dimShapesValues,
+                                      Value dimSizesBuffer,
+                                      /*out*/ Value &dim2lvlBuffer,
+                                      /*out*/ Value &lvl2dimBuffer) {
+  const Dimension dimRank = stt.getDimRank();
+  const Level lvlRank = stt.getLvlRank();
+  // For an identify mapping, the dim2lvl and lvl2dim mappings are
----------------
yinying-lisa-li wrote:

identity?

https://github.com/llvm/llvm-project/pull/68360
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to