================ @@ -729,3 +729,92 @@ Value sparse_tensor::createOrFoldSliceStrideOp(OpBuilder &builder, Location loc, return constantIndex(builder, loc, *stride); return builder.create<ToSliceStrideOp>(loc, tensor, APInt(64, dim)); } + +void sparse_tensor::fillDimShape(OpBuilder &builder, Location loc, + SparseTensorType stt, + SmallVectorImpl<Value> &out) { + out.clear(); + out.reserve(stt.getDimRank()); + for (const DynSize sh : stt.getDimShape()) { + const auto s = ShapedType::isDynamic(sh) ? 0 : sh; + out.push_back(constantIndex(builder, loc, s)); + } +} + +Value sparse_tensor::genReader(OpBuilder &builder, Location loc, + SparseTensorType stt, Value tensor, + /*out*/ SmallVectorImpl<Value> &dimShapesValues, + /*out*/ Value &dimSizesBuffer) { + // Construct the dimShapes buffer. The buffer contains the static size + // per dimension, or otherwise a zero for a dynamic size. + fillDimShape(builder, loc, stt, dimShapesValues); + Value dimShapesBuffer = allocaBuffer(builder, loc, dimShapesValues); + // Create the `CheckedSparseTensorReader`. This reader performs a + // consistency check on the static sizes, but accepts any size + // of each dimension with a dynamic size. + Type opaqueTp = getOpaquePointerType(builder); + Type eltTp = stt.getElementType(); + Value valTp = constantPrimaryTypeEncoding(builder, loc, eltTp); + Value reader = + createFuncCall(builder, loc, "createCheckedSparseTensorReader", opaqueTp, + {tensor, dimShapesBuffer, valTp}, EmitCInterface::On) + .getResult(0); + // For static shapes, the shape buffer can be used right away. For dynamic + // shapes, use the information from the reader to construct a buffer that + // supplies the actual size for each dynamic dimension. 
+ dimSizesBuffer = dimShapesBuffer; + if (stt.hasDynamicDimShape()) { + Type indexTp = builder.getIndexType(); + auto memTp = MemRefType::get({ShapedType::kDynamic}, indexTp); + dimSizesBuffer = + createFuncCall(builder, loc, "getSparseTensorReaderDimSizes", memTp, + reader, EmitCInterface::On) + .getResult(0); + } + return reader; +} + +Value sparse_tensor::genReaderBuffers(OpBuilder &builder, Location loc, + SparseTensorType stt, + ArrayRef<Value> dimShapesValues, + Value dimSizesBuffer, + /*out*/ Value &dim2lvlBuffer, + /*out*/ Value &lvl2dimBuffer) { + const Dimension dimRank = stt.getDimRank(); + const Level lvlRank = stt.getLvlRank(); + // For an identify mapping, the dim2lvl and lvl2dim mappings are ---------------- yinying-lisa-li wrote:
identity? [i.e., "identify mapping" in the quoted comment above looks like a typo for "identity mapping"] https://github.com/llvm/llvm-project/pull/68360 _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits