[llvm-branch-commits] [mlir] [mlir][linalg] Decompose winograd operators (PR #96183)
Hsiangkai Wang via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Thu Jun 27 01:39:46 PDT 2024
================
@@ -48,6 +287,261 @@ Value collapse2DData(RewriterBase &rewriter, Location loc, Value data) {
reassociation);
}
+// This function transforms the filter. The data layout of the filter is FHWC.
+// The transformation matrix is two-dimensional. We need to extract H x W from
+// FHWC first. We need to generate 2 levels of loops to iterate on F and C.
+// After the transformation, we get
+//
+// scf.for %f = lo_f to hi_f step 1
+// scf.for %c = lo_c to hi_c step 1
+// %extracted = extract filter<h x w> from filter<f x h x w x c>
+// %ret = linalg.matmul G, %extracted
+// %ret = linalg.matmul %ret, GT
+// %inserted = insert %ret into filter<tile_h x tile_w x h x w x c x f>
+//
+Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
+ Value retValue, int64_t m, int64_t r,
+ bool leftTransform = true, bool rightTransform = true) {
+ // Map from (m, r) to G transform matrix.
+ static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
+ GMatrices = {
+ {F_2_3, TransformMatrix(G_2x2_3x3, 4, 3)},
+ {F_4_3, TransformMatrix(G_4x4_3x3, 6, 3)},
+ {F_2_5, TransformMatrix(G_2x2_5x5, 6, 5)},
+ };
+
+ // Map from (m, r) to GT transform matrix.
+ static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
+ GTMatrices = {
+ {F_2_3, TransformMatrix(GT_2x2_3x3, 3, 4)},
+ {F_4_3, TransformMatrix(GT_4x4_3x3, 3, 6)},
+ {F_2_5, TransformMatrix(GT_2x2_5x5, 5, 6)},
+ };
+
+ auto filterType = cast<ShapedType>(filter.getType());
+ Type elementType = filterType.getElementType();
+ auto filterShape = filterType.getShape(); // F, H, W, C
+ int64_t filterF = filterShape[0];
+ int64_t filterH = filterShape[1];
+ int64_t filterW = filterShape[2];
+ int64_t filterC = filterShape[3];
+
+ if (filterH != r && filterH != 1)
+ return Value();
+ if (filterW != r && filterW != 1)
+ return Value();
+
+ // Return shape is <H x W x C x F>
+ auto zeroIdx = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ auto fUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, filterF);
+ auto cUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, filterC);
+ auto oneStep = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+ auto outerForOp =
+ rewriter.create<scf::ForOp>(loc, zeroIdx, fUpperBound, oneStep, retValue);
+ Block *outerForBody = outerForOp.getBody();
+ rewriter.setInsertionPointToStart(outerForBody);
+ Value FIter = outerForBody->getArgument(0);
----------------
Hsiangkai wrote:
I use buildLoopNest to create the loops and a callback to construct the innermost loop body.
https://github.com/llvm/llvm-project/pull/96183
More information about the llvm-branch-commits
mailing list