================ @@ -48,6 +287,261 @@ Value collapse2DData(RewriterBase &rewriter, Location loc, Value data) { reassociation); } +// This function transforms the filter. The data layout of the filter is FHWC. +// The transformation matrix is 2-dimension. We need to extract H x W from +// FHWC first. We need to generate 2 levels of loops to iterate on F and C. +// After the transformation, we get +// +// scf.for %f = lo_f to hi_f step 1 +// scf.for %c = lo_c to hi_c step 1 +// %extracted = extract filter<h x w> from filter<f x h x w x c> +// %ret = linalg.matmul G, %extracted +// %ret = linalg.matmul %ret, GT +// %inserted = insert %ret into filter<tile_h x tile_w x h x w x c x f> +// +Value filterTransform(RewriterBase &rewriter, Location loc, Value filter, + Value retValue, int64_t m, int64_t r, + bool leftTransform = true, bool rightTransform = true) { + // Map from (m, r) to G transform matrix. + static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix> + GMatrices = { + {F_2_3, TransformMatrix(G_2x2_3x3, 4, 3)}, + {F_4_3, TransformMatrix(G_4x4_3x3, 6, 3)}, + {F_2_5, TransformMatrix(G_2x2_5x5, 6, 5)}, + }; + + // Map from (m, r) to GT transform matrix. 
+ static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix> + GTMatrices = { + {F_2_3, TransformMatrix(GT_2x2_3x3, 3, 4)}, + {F_4_3, TransformMatrix(GT_4x4_3x3, 3, 6)}, + {F_2_5, TransformMatrix(GT_2x2_5x5, 5, 6)}, + }; + + auto filterType = cast<ShapedType>(filter.getType()); + Type elementType = filterType.getElementType(); + auto filterShape = filterType.getShape(); // F, H, W, C + int64_t filterF = filterShape[0]; + int64_t filterH = filterShape[1]; + int64_t filterW = filterShape[2]; + int64_t filterC = filterShape[3]; + + if (filterH != r && filterH != 1) + return Value(); + if (filterW != r && filterW != 1) + return Value(); + + // Return shape is <H x W x C x F> + auto zeroIdx = rewriter.create<arith::ConstantIndexOp>(loc, 0); + auto fUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, filterF); + auto cUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, filterC); + auto oneStep = rewriter.create<arith::ConstantIndexOp>(loc, 1); + auto outerForOp = + rewriter.create<scf::ForOp>(loc, zeroIdx, fUpperBound, oneStep, retValue); + Block *outerForBody = outerForOp.getBody(); + rewriter.setInsertionPointToStart(outerForBody); + Value FIter = outerForBody->getArgument(0); + + auto innerForOp = rewriter.create<scf::ForOp>( + loc, zeroIdx, cUpperBound, oneStep, outerForOp.getRegionIterArgs()[0]); ---------------- ftynse wrote:
Ditto. There must be a better-named function for this. https://github.com/llvm/llvm-project/pull/96183 _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits