[Mlir-commits] [mlir] [mlir][ArmNeon] Update `LowerContractionToSMMLAPattern` to support proper unrolling for k dimension (PR #88591)
Kojo Acquah
llvmlistbot at llvm.org
Fri Apr 12 16:07:16 PDT 2024
https://github.com/KoolJBlack created https://github.com/llvm/llvm-project/pull/88591
None
>From b7e8882c6369cd814803e51fd94b037ca4b745d3 Mon Sep 17 00:00:00 2001
From: Kojo Acquah <kooljblack at google.com>
Date: Fri, 12 Apr 2024 23:02:41 +0000
Subject: [PATCH] Implement proper unrolling for the k dimension
---
.../Transforms/LowerContractionToSMMLAPattern.cpp | 13 +++++++++++--
1 file changed, 11 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
index 3ae894692089b3..8284110f08abd9 100644
--- a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
+++ b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
@@ -133,8 +133,12 @@ class LowerContractionToSMMLAPattern
smmlaShape.insert(smmlaShape.begin(), isVecmat ? 1 : 2);
loopOrder.push_back(2);
}
+
+ // Keep track of the previous accumulator when tiling over K
+ Value kAcc;
for (SmallVector<int64_t> offsets :
StaticTileOffsetRange(unrolledSize, smmlaShape, loopOrder)) {
+ auto kTileIndex = offsets[offsets.size() - 1] / 8;
// Helper to compute the new shape of each operand and extract the slice.
auto extractOperand = [&](Value operand, AffineMap permutationMap,
ArrayRef<int64_t> operandOffsets) {
@@ -197,16 +201,21 @@ class LowerContractionToSMMLAPattern
auto collapsedRes = rewriter.createOrFold<vector::ShapeCastOp>(
tiledAcc.getLoc(), collapsedOutputType, tiledAcc);
+ if (kTileIndex != 0) {
+ collapsedRes = kAcc;
+ }
+
// Insert contract op
auto smmlaOp = rewriter.createOrFold<arm_neon::SmmlaOp>(
op.getLoc(), collapsedRes.getType(), collapsedRes, collapsedLhs,
collapsedRhs);
+ kAcc = smmlaOp;
// Reshape output back to 2D
Value tiledRes = rewriter.createOrFold<vector::ShapeCastOp>(
smmlaOp.getLoc(), tiledAcc.getType(), smmlaOp);
- // With vecmat, only one row of tiled ACC can be inserted inot file result
+ // With vecmat, only one row of tiled ACC can be inserted into the final result
if (isVecmat) {
tiledRes = rewriter.createOrFold<vector::ExtractOp>(loc, tiledRes, 0);
}
@@ -214,7 +223,7 @@ class LowerContractionToSMMLAPattern
// Insert the tiled result back into the non tiled result of the
// contract op.
SmallVector<int64_t> strides(
- tiledRes.getType().cast<ShapedType>().getRank(), 1);
+ tiledAcc.getType().cast<ShapedType>().getRank(), 1);
result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
loc, tiledRes, result, accOffsets, strides);
}
More information about the Mlir-commits
mailing list