================ @@ -100,6 +594,161 @@ Value matrixMultiply(RewriterBase &rewriter, Location loc, return expandOutput; } +// This function transforms the output. The data layout of the output is HWNF. +// The transformation matrix is 2-dimension. We need to extract H x W from +// HWNF first. We need to generate 2 levels of loops to iterate on N and F. +// After the transformation, we get +// +// scf.for %n = lo_n to hi_n step 1 +// scf.for %f = lo_f to hi_f step 1 +// %extracted = extract input<h x w> from result<h x w x n x f> +// %ret = linalg.matmul AT, %extracted +// %ret = linalg.matmul %ret, A +// %inserted = insert %ret into ret<n x h x w x f> +// +Value outputTransform(RewriterBase &rewriter, Location loc, Value value, + Value output, int64_t m, int64_t r, + bool leftTransform = true, bool rightTransform = true) { + // Map from (m, r) to AT transform matrix. + static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix> + ATMatrices = { + {F_2_3, TransformMatrix(AT_2x2_3x3, 2, 4)}, + {F_4_3, TransformMatrix(AT_4x4_3x3, 4, 6, 32)}, + {F_2_5, TransformMatrix(AT_2x2_5x5, 2, 6, 16)}, + }; + + // Map from (m, r) to A transform matrix. + static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix> + AMatrices = { + {F_2_3, TransformMatrix(A_2x2_3x3, 4, 2)}, + {F_4_3, TransformMatrix(A_4x4_3x3, 6, 4, 32)}, + {F_2_5, TransformMatrix(A_2x2_5x5, 6, 2, 16)}, + }; + + auto valueType = cast<ShapedType>(value.getType()); + Type elementType = valueType.getElementType(); + auto valueShape = valueType.getShape(); // TileH, TileW, H, W, N, F + int64_t valueH = valueShape[2]; + int64_t valueW = valueShape[3]; + int64_t valueN = valueShape[4]; + int64_t valueF = valueShape[5]; + int64_t alphaH = leftTransform ? m + r - 1 : 1; + int64_t alphaW = rightTransform ? m + r - 1 : 1; + + if (valueH != alphaH && valueH != 1) + return Value(); + if (valueW != alphaW && valueW != 1) + return Value(); + + auto zeroIdx = rewriter.create<arith::ConstantIndexOp>(loc, 0); + auto nUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, valueN); + auto fUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, valueF); + auto oneStep = rewriter.create<arith::ConstantIndexOp>(loc, 1); + + auto outerForOp = + rewriter.create<scf::ForOp>(loc, zeroIdx, nUpperBound, oneStep, output); + Block *outerForBody = outerForOp.getBody(); + rewriter.setInsertionPointToStart(outerForBody); + Value NIter = outerForBody->getArgument(0); + + auto innerForOp = rewriter.create<scf::ForOp>( + loc, zeroIdx, fUpperBound, oneStep, outerForOp.getRegionIterArgs()[0]); + Block *innerForBody = innerForOp.getBody(); + rewriter.setInsertionPointToStart(innerForBody); + Value FIter = innerForBody->getArgument(0); + + // Extract (H, W) from (1, 1, H, W, N, F) + auto extractValue = extract2DData( + rewriter, loc, value, NIter, FIter, /*outLoopIdx=*/4, + /*inLoopIdx=*/5, /*heightIdx=*/2, /*widthIdx=*/3, /*srcSize=*/6); + + TransformMapKeyTy key = {m, r}; + int64_t retRows = 1; + int64_t retCols = 1; + int64_t leftScalarFactor = 1; + int64_t rightScalarFactor = 1; + Value matmulRetValue = extractValue; + if (leftTransform) { + // Get constant transform matrix AT + auto it = ATMatrices.find(key); + if (it == ATMatrices.end()) + return Value(); + const TransformMatrix &ATMatrix = it->second; + + leftScalarFactor = ATMatrix.scalarFactor; + retRows = ATMatrix.rows; + auto matmulType = RankedTensorType::get({retRows, valueW}, elementType); + auto init = rewriter.create<tensor::EmptyOp>(loc, matmulType.getShape(), + elementType); + + Value AT = create2DTransformMatrix(rewriter, loc, ATMatrix, elementType); + // Multiply AT x m + auto matmulOp = rewriter.create<linalg::MatmulOp>( + loc, matmulType, ValueRange{AT, matmulRetValue}, ValueRange{init}); + matmulRetValue = matmulOp.getResult(0); + } + + if (rightTransform) { + // Get constant transform matrix T + auto it = AMatrices.find(key); + if (it == AMatrices.end()) + return Value(); + const TransformMatrix &AMatrix = it->second; + + rightScalarFactor = AMatrix.scalarFactor; + auto matmulType = + RankedTensorType::get({retRows, AMatrix.cols}, elementType); + retCols = AMatrix.cols; + auto init = rewriter.create<tensor::EmptyOp>(loc, matmulType.getShape(), + elementType); + + Value A = create2DTransformMatrix(rewriter, loc, AMatrix, elementType); + // Multiply y = (AT x m) x A + auto matmulOp = rewriter.create<linalg::MatmulOp>( + loc, matmulType, ValueRange{matmulRetValue, A}, ValueRange{init}); + matmulRetValue = matmulOp.getResult(0); + } + + // Multiply scalar factor. + Value scalarFactor = rewriter.create<arith::ConstantOp>( + loc, FloatAttr::get(elementType, leftScalarFactor * rightScalarFactor)); + auto matmulType = RankedTensorType::get({retRows, retCols}, elementType); + auto init = + rewriter.create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType); + + auto identityAffineMap = rewriter.getMultiDimIdentityMap(2); + SmallVector<AffineMap> affineMaps = {AffineMap::get(2, 0, init.getContext()), + identityAffineMap, identityAffineMap}; + auto scalarMatrixOp = rewriter.create<linalg::GenericOp>( + loc, matmulType, ValueRange{scalarFactor, matmulRetValue}, + ValueRange{init}, affineMaps, tosa::getNParallelLoopsAttrs(2), ---------------- ftynse wrote:
Let's not use TOSA from here. If this helper is needed, let's move it elsewhere. https://github.com/llvm/llvm-project/pull/96183 _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits