[llvm-branch-commits] [mlir] [mlir][linalg] Decompose winograd operators (PR #96183)
Oleksandr Alex Zinenko via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Mon Jun 24 06:08:20 PDT 2024
================
@@ -100,6 +594,161 @@ Value matrixMultiply(RewriterBase &rewriter, Location loc,
return expandOutput;
}
+// This function transforms the output. The data layout of the output is HWNF.
+// The transformation matrices are 2-dimensional, so we need to extract H x W
+// from HWNF first. We generate 2 levels of loops to iterate over N and F.
+// After the transformation, we get
+//
+// scf.for %n = lo_n to hi_n step 1
+// scf.for %f = lo_f to hi_f step 1
+// %extracted = extract input<h x w> from result<h x w x n x f>
+// %ret = linalg.matmul AT, %extracted
+// %ret = linalg.matmul %ret, A
+// %ret = linalg.generic (multiply %ret by the scalar factor)
+// %inserted = insert %ret into ret<n x h x w x f>
+//
+Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
+ Value output, int64_t m, int64_t r,
+ bool leftTransform = true, bool rightTransform = true) {
+ // Map from (m, r) to AT transform matrix.
+ static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
+ ATMatrices = {
+ {F_2_3, TransformMatrix(AT_2x2_3x3, 2, 4)},
+ {F_4_3, TransformMatrix(AT_4x4_3x3, 4, 6, 32)},
+ {F_2_5, TransformMatrix(AT_2x2_5x5, 2, 6, 16)},
+ };
+
+ // Map from (m, r) to A transform matrix.
+ static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
+ AMatrices = {
+ {F_2_3, TransformMatrix(A_2x2_3x3, 4, 2)},
+ {F_4_3, TransformMatrix(A_4x4_3x3, 6, 4, 32)},
+ {F_2_5, TransformMatrix(A_2x2_5x5, 6, 2, 16)},
+ };
+
+ auto valueType = cast<ShapedType>(value.getType());
+ Type elementType = valueType.getElementType();
+ auto valueShape = valueType.getShape(); // TileH, TileW, H, W, N, F
+ int64_t valueH = valueShape[2];
+ int64_t valueW = valueShape[3];
+ int64_t valueN = valueShape[4];
+ int64_t valueF = valueShape[5];
+ int64_t alphaH = leftTransform ? m + r - 1 : 1;
+ int64_t alphaW = rightTransform ? m + r - 1 : 1;
+
+ if (valueH != alphaH && valueH != 1)
+ return Value();
+ if (valueW != alphaW && valueW != 1)
+ return Value();
+
+ auto zeroIdx = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ auto nUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, valueN);
+ auto fUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, valueF);
+ auto oneStep = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+
+ auto outerForOp =
+ rewriter.create<scf::ForOp>(loc, zeroIdx, nUpperBound, oneStep, output);
+ Block *outerForBody = outerForOp.getBody();
+ rewriter.setInsertionPointToStart(outerForBody);
+ Value NIter = outerForBody->getArgument(0);
+
+ auto innerForOp = rewriter.create<scf::ForOp>(
+ loc, zeroIdx, fUpperBound, oneStep, outerForOp.getRegionIterArgs()[0]);
+ Block *innerForBody = innerForOp.getBody();
+ rewriter.setInsertionPointToStart(innerForBody);
+ Value FIter = innerForBody->getArgument(0);
+
+ // Extract (H, W) from (1, 1, H, W, N, F)
+ auto extractValue = extract2DData(
+ rewriter, loc, value, NIter, FIter, /*outLoopIdx=*/4,
+ /*inLoopIdx=*/5, /*heightIdx=*/2, /*widthIdx=*/3, /*srcSize=*/6);
+
+ TransformMapKeyTy key = {m, r};
+ int64_t retRows = 1;
+ int64_t retCols = 1;
+ int64_t leftScalarFactor = 1;
+ int64_t rightScalarFactor = 1;
+ Value matmulRetValue = extractValue;
+ if (leftTransform) {
+ // Get constant transform matrix AT
+ auto it = ATMatrices.find(key);
+ if (it == ATMatrices.end())
+ return Value();
+ const TransformMatrix &ATMatrix = it->second;
+
+ leftScalarFactor = ATMatrix.scalarFactor;
+ retRows = ATMatrix.rows;
+ auto matmulType = RankedTensorType::get({retRows, valueW}, elementType);
+ auto init = rewriter.create<tensor::EmptyOp>(loc, matmulType.getShape(),
+ elementType);
+
+ Value AT = create2DTransformMatrix(rewriter, loc, ATMatrix, elementType);
+ // Multiply AT x m
+ auto matmulOp = rewriter.create<linalg::MatmulOp>(
+ loc, matmulType, ValueRange{AT, matmulRetValue}, ValueRange{init});
+ matmulRetValue = matmulOp.getResult(0);
+ }
+
+ if (rightTransform) {
+ // Get constant transform matrix T
+ auto it = AMatrices.find(key);
+ if (it == AMatrices.end())
+ return Value();
+ const TransformMatrix &AMatrix = it->second;
+
+ rightScalarFactor = AMatrix.scalarFactor;
+ auto matmulType =
+ RankedTensorType::get({retRows, AMatrix.cols}, elementType);
+ retCols = AMatrix.cols;
+ auto init = rewriter.create<tensor::EmptyOp>(loc, matmulType.getShape(),
+ elementType);
+
+ Value A = create2DTransformMatrix(rewriter, loc, AMatrix, elementType);
+ // Multiply y = (AT x m) x A
+ auto matmulOp = rewriter.create<linalg::MatmulOp>(
+ loc, matmulType, ValueRange{matmulRetValue, A}, ValueRange{init});
+ matmulRetValue = matmulOp.getResult(0);
+ }
+
+ // Multiply scalar factor.
+ Value scalarFactor = rewriter.create<arith::ConstantOp>(
+ loc, FloatAttr::get(elementType, leftScalarFactor * rightScalarFactor));
+ auto matmulType = RankedTensorType::get({retRows, retCols}, elementType);
+ auto init =
+ rewriter.create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType);
+
+ auto identityAffineMap = rewriter.getMultiDimIdentityMap(2);
+ SmallVector<AffineMap> affineMaps = {AffineMap::get(2, 0, init.getContext()),
+ identityAffineMap, identityAffineMap};
+ auto scalarMatrixOp = rewriter.create<linalg::GenericOp>(
+ loc, matmulType, ValueRange{scalarFactor, matmulRetValue},
+ ValueRange{init}, affineMaps, tosa::getNParallelLoopsAttrs(2),
+ [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
+ Value scalarVal = args[0];
+ Value matrixVal = args[1];
+ Value result = nestedBuilder.create<arith::MulFOp>(nestedLoc, scalarVal,
----------------
ftynse wrote:
Can we just use `linalg.mul`?
https://github.com/llvm/llvm-project/pull/96183
More information about the llvm-branch-commits
mailing list