================ @@ -2810,9 +2819,117 @@ LogicalResult WinogradInputTransformOp::verify() { if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) { return emitOpError("the output shape is not expected"); } + return success(); } +SmallVector<Range> +WinogradInputTransformOp::getIterationDomain(OpBuilder &builder) { + Location loc = getLoc(); + auto indexType = builder.getIndexType(); + auto zeroAttr = builder.getIntegerAttr(indexType, 0); + auto oneAttr = builder.getIntegerAttr(indexType, 1); + Value output = getOutput(); + SmallVector<Range> loopBounds(6); + for (unsigned dim = 0; dim < 6; ++dim) { + loopBounds[dim].offset = zeroAttr; + loopBounds[dim].size = getDimValue(builder, loc, output, dim); + loopBounds[dim].stride = oneAttr; + } + return loopBounds; +} + +SmallVector<utils::IteratorType> +WinogradInputTransformOp::getLoopIteratorTypes() { + SmallVector<utils::IteratorType> iteratorTypes(6, + utils::IteratorType::parallel); + return iteratorTypes; +} + +LogicalResult WinogradInputTransformOp::getResultTilePosition( + OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets, + ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets, + SmallVector<OpFoldResult> &resultSizes) { + auto zeroAttr = builder.getI64IntegerAttr(0); + auto oneAttr = builder.getI64IntegerAttr(1); + + resultOffsets.push_back(zeroAttr); + resultOffsets.push_back(zeroAttr); + resultOffsets.push_back(offsets[2]); + resultOffsets.push_back(offsets[3]); + resultOffsets.push_back(zeroAttr); + resultOffsets.push_back(zeroAttr); + resultSizes.push_back(sizes[0]); + resultSizes.push_back(sizes[1]); + resultSizes.push_back(oneAttr); + resultSizes.push_back(oneAttr); + resultSizes.push_back(sizes[4]); + resultSizes.push_back(sizes[5]); + + return success(); +} + +FailureOr<TilingResult> +WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder, + ArrayRef<OpFoldResult> offsets, + ArrayRef<OpFoldResult> sizes) { + auto oneAttr = 
builder.getI64IntegerAttr(1); + auto zeroAttr = builder.getI64IntegerAttr(0); + Value input = getInput(); + auto inputType = cast<ShapedType>(input.getType()); + auto inputShape = inputType.getShape(); + int64_t inputH = inputShape[1]; + int64_t inputW = inputShape[2]; + int64_t m = getM(); + int64_t r = getR(); + int64_t alpha = m + r - 1; + int64_t alphaH = inputH != 1 ? alpha : 1; + int64_t alphaW = inputW != 1 ? alpha : 1; + auto alphaHAttr = builder.getI64IntegerAttr(alphaH); + auto alphaWAttr = builder.getI64IntegerAttr(alphaW); + + Location loc = getLoc(); + SmallVector<Value> tiledOperands; + SmallVector<OpFoldResult> sliceOffsets, sliceSizes; + + auto context = builder.getContext(); + auto affineMap = + AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context); + Value mappedOffset1 = builder.create<affine::AffineApplyOp>( + loc, affineMap, getValueFromOpFoldResult(offsets[2], builder, loc)); + Value mappedOffset2 = builder.create<affine::AffineApplyOp>( + loc, affineMap, getValueFromOpFoldResult(offsets[3], builder, loc)); + + sliceOffsets.push_back(zeroAttr); + sliceOffsets.push_back(mappedOffset1); + sliceOffsets.push_back(mappedOffset2); + sliceOffsets.push_back(zeroAttr); + sliceSizes.push_back(sizes[4]); + sliceSizes.push_back(alphaHAttr); + sliceSizes.push_back(alphaWAttr); + sliceSizes.push_back(sizes[5]); + SmallVector<OpFoldResult> inputStrides(4, oneAttr); + tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>( + loc, getInput(), sliceOffsets, sliceSizes, inputStrides)); + + sliceOffsets.clear(); + sliceSizes.clear(); ---------------- ftynse wrote:
I'd rather declare new vectors for this than clear and reuse `sliceOffsets` and `sliceSizes`. https://github.com/llvm/llvm-project/pull/96184 _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits