[Mlir-commits] [mlir] [mlir][ArmNeon] Implements unrolling patterns for LowerContractionToSMMLAPattern (PR #84848)
Diego Caballero
llvmlistbot at llvm.org
Fri Mar 15 10:18:13 PDT 2024
================
@@ -113,26 +114,73 @@ class LowerContractionToSMMLAPattern
return failure();
}
- // Collapse to 1D vectors required by smmla intrinsic
- auto collapsedInputType = VectorType::get(
- {16}, extsiLhs.getType().cast<ShapedType>().getElementType());
- auto collapsedOutputType =
- VectorType::get({4}, res.getType().cast<ShapedType>().getElementType());
- auto collapsedLhs = rewriter.createOrFold<vector::ShapeCastOp>(
- extsiLhs.getLoc(), collapsedInputType, extsiLhs);
- auto collapsedRhs = rewriter.createOrFold<vector::ShapeCastOp>(
- extsiRhs.getLoc(), collapsedInputType, extsiRhs);
- auto collapsedRes = rewriter.createOrFold<vector::ShapeCastOp>(
- res.getLoc(), collapsedOutputType, res);
-
- // Replace the contract with a neon op
- auto smmlaOp = rewriter.createOrFold<arm_neon::SmmlaOp>(
- op.getLoc(), collapsedRes.getType(), collapsedRes, collapsedLhs,
- collapsedRhs);
-
- // Reshape output back to 2D
- rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getResultType(),
- smmlaOp);
+ // Initial accumulator for the final result. This is the un-tiled result if
+ // tiling is done.
+ Value result = rewriter.create<arith::ConstantOp>(
+ loc, op.getResultType(), rewriter.getZeroAttr(op.getResultType()));
+
+ SmallVector<int64_t> unrolledSize = *op.getShapeForUnroll();
+ SmallVector<int64_t> smmlaShape{2, 2, 8};
+ SmallVector<int64_t> loopOrder{0, 1, 2};
+ for (SmallVector<int64_t> offsets :
+ StaticTileOffsetRange(unrolledSize, smmlaShape, loopOrder)) {
+
+ // Helper to compute the new shape of each operand and extract the slice.
+ auto extractOperand = [&](Value operand, AffineMap permutationMap,
+ ArrayRef<int64_t> operandOffsets) {
+ SmallVector<int64_t> operandShape =
+ applyPermutationMap(permutationMap, ArrayRef<int64_t>(smmlaShape));
+ SmallVector<int64_t> operandStrides(operandOffsets.size(), 1);
+ return rewriter.createOrFold<vector::ExtractStridedSliceOp>(
+ loc, operand, operandOffsets, operandShape, operandStrides);
+ };
+
+ // Extract tiled lhs, rhs, and acc
+ AffineMap lhsPermutationMap = op.getIndexingMapsArray()[0];
+ SmallVector<int64_t> lhsOffsets =
+ applyPermutationMap(lhsPermutationMap, ArrayRef<int64_t>(offsets));
+ auto tiledLhs = extractOperand(extsiLhs, lhsPermutationMap, lhsOffsets);
----------------
dcaballe wrote:
Could you spell out the type for some of these `auto` declarations where it is not explicit anywhere in the initializer?
https://github.com/llvm/llvm-project/pull/84848
More information about the Mlir-commits
mailing list