[Mlir-commits] [mlir] [mlir][ArmNeon] Update `LowerContractionToSMMLAPattern` to support proper unrolling for k dimension (PR #88591)

Kojo Acquah llvmlistbot at llvm.org
Fri Apr 12 16:07:16 PDT 2024


https://github.com/KoolJBlack created https://github.com/llvm/llvm-project/pull/88591

None

>From b7e8882c6369cd814803e51fd94b037ca4b745d3 Mon Sep 17 00:00:00 2001
From: Kojo Acquah <kooljblack at google.com>
Date: Fri, 12 Apr 2024 23:02:41 +0000
Subject: [PATCH] implement proper unrolling for k dim

---
 .../Transforms/LowerContractionToSMMLAPattern.cpp   | 13 +++++++++++--
 1 file changed, 11 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
index 3ae894692089b3..8284110f08abd9 100644
--- a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
+++ b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
@@ -133,8 +133,12 @@ class LowerContractionToSMMLAPattern
       smmlaShape.insert(smmlaShape.begin(), isVecmat ? 1 : 2);
       loopOrder.push_back(2);
     }
+
+    // Keep track of the previous accumulator when tiling over K
+    Value kAcc;
     for (SmallVector<int64_t> offsets :
          StaticTileOffsetRange(unrolledSize, smmlaShape, loopOrder)) {
+      auto kTileIndex = offsets[offsets.size() - 1] / 8;
       // Helper to compute the new shape of each operand and extract the slice.
       auto extractOperand = [&](Value operand, AffineMap permutationMap,
                                 ArrayRef<int64_t> operandOffsets) {
@@ -197,16 +201,21 @@ class LowerContractionToSMMLAPattern
       auto collapsedRes = rewriter.createOrFold<vector::ShapeCastOp>(
           tiledAcc.getLoc(), collapsedOutputType, tiledAcc);
 
+      if (kTileIndex != 0) {
+        collapsedRes = kAcc;
+      }
+
       // Insert contract op
       auto smmlaOp = rewriter.createOrFold<arm_neon::SmmlaOp>(
           op.getLoc(), collapsedRes.getType(), collapsedRes, collapsedLhs,
           collapsedRhs);
+      kAcc = smmlaOp;
 
       // Reshape output back to 2D
       Value tiledRes = rewriter.createOrFold<vector::ShapeCastOp>(
           smmlaOp.getLoc(), tiledAcc.getType(), smmlaOp);
 
-      // With vecmat, only one row of tiled ACC can be inserted inot file result
+      // With vecmat, only one row of tiled ACC can be inserted into the final result
       if (isVecmat) {
         tiledRes = rewriter.createOrFold<vector::ExtractOp>(loc, tiledRes, 0);
       }
@@ -214,7 +223,7 @@ class LowerContractionToSMMLAPattern
       // Insert the tiled result back into the non tiled result of the
       // contract op.
       SmallVector<int64_t> strides(
-          tiledRes.getType().cast<ShapedType>().getRank(), 1);
+          tiledAcc.getType().cast<ShapedType>().getRank(), 1);
       result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
           loc, tiledRes, result, accOffsets, strides);
     }



More information about the Mlir-commits mailing list