[Mlir-commits] [mlir] [MLIR][AArch64] Refactor lowering of vector.contract to Neon I8MM (PR #149810)

Momchil Velikov llvmlistbot at llvm.org
Tue Jul 22 07:23:05 PDT 2025


https://github.com/momchil-velikov updated https://github.com/llvm/llvm-project/pull/149810

>From ced3ddf5ffb8ef8933238f266683b1124d637841 Mon Sep 17 00:00:00 2001
From: Momchil Velikov <momchil.velikov at arm.com>
Date: Thu, 10 Jul 2025 11:10:19 +0000
Subject: [PATCH 1/2] [MLIR][AArch64] Refactor lowering of vector.contract to
 Neon I8MM

This patch refactors the pattern in `Transforms/LowerContractionToNeonI8MMPattern.cpp`
using an approach similar to the one in https://github.com/llvm/llvm-project/pull/147052
to prepare for adding BF16 support.
---
 .../LowerContractionToNeonI8MMPattern.cpp     | 431 ++++++++++--------
 1 file changed, 247 insertions(+), 184 deletions(-)

diff --git a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToNeonI8MMPattern.cpp b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToNeonI8MMPattern.cpp
index 15de736480c5e..7e6a2bab59a83 100644
--- a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToNeonI8MMPattern.cpp
+++ b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToNeonI8MMPattern.cpp
@@ -31,23 +31,15 @@ using namespace mlir;
 using namespace mlir::arm_neon;
 
 namespace {
-
-/// Return the shaped type with new element type.
-static Type matchContainerType(Type element, Type container) {
-  if (auto shapedTy = dyn_cast<ShapedType>(container)) {
-    return shapedTy.clone(element);
-  }
-  return element;
-}
-
-// Get the operand of a `vector.contract`. This function is intended to abstract
-// away from the particular way a value is extended before feeding it into the
-// `vector.contract` - via zero-extend or an explicit or implicit sign-extend
-// (for implicit sign-extension see `vector.contract` documentation).
-//
-// The template parameter `Op` indicates the extension operation (explicit or
-// implicit) for which we are checking.
-//
+/// Get the operand of a `vector.contract`. This function is intended to
+/// abstract away from the particular way a value is extended before feeding it
+/// into the `vector.contract` - via zero-extend or an explicit or implicit
+/// sign-extend (for implicit sign-extension see `vector.contract`
+/// documentation).
+///
+/// The template parameter `Op` indicates the extension operation (explicit or
+/// implicit) for which we are checking.
+///
 // Return success only for extensions from `iN` (N <= 8) to `i32`.
 template <typename Op>
 std::optional<Value> getExtOperand(Value v) {
@@ -85,202 +77,186 @@ std::optional<Value> getExtOperand(Value v) {
   return inOp;
 }
 
-// Designate the operation (resp. instruction) used to do sub-tile matrix
-// multiplications.
-enum class MMLA {
-  Signed,      // smmla
-  Unsigned,    // ummla
-  Mixed,       // usmmla
-  MixedSwapped // usmmla with LHS and RHS swapped
-};
+/// Helper function to extend a vector with elements iN, N < 8 to
+/// a vector of i8. Do sign extension if the parameter `signExt` is true,
+/// zero extension otherwise.
+Value extendSmallIntVector(Location loc, VectorType srcTy, Value val,
+                           bool signExt, PatternRewriter &rewriter) {
+  Type targetTy = srcTy.clone(rewriter.getI8Type());
+  return signExt ? rewriter.createOrFold<arith::ExtSIOp>(loc, targetTy, val)
+                 : rewriter.createOrFold<arith::ExtUIOp>(loc, targetTy, val);
+}
 
-// Create the matrix mulitply and accumulate operation according to `op`.
-Value createMMLA(PatternRewriter &rewriter, MMLA op, Location loc,
-                 mlir::Type accType, Value acc, Value lhs, Value rhs) {
-  switch (op) {
-  case MMLA::Signed:
-    return rewriter.createOrFold<arm_neon::SmmlaOp>(loc, accType, acc, lhs,
-                                                    rhs);
-  case MMLA::Unsigned:
-    return rewriter.createOrFold<arm_neon::UmmlaOp>(loc, accType, acc, lhs,
-                                                    rhs);
-  case MMLA::Mixed:
-    return rewriter.createOrFold<arm_neon::UsmmlaOp>(loc, accType, acc, lhs,
-                                                     rhs);
-  case MMLA::MixedSwapped:
-    // The accumulator comes transposed and the result will be transposed
-    // later, so all we have to do here is swap the operands.
-    return rewriter.createOrFold<arm_neon::UsmmlaOp>(loc, accType, acc, rhs,
-                                                     lhs);
+class VectorContractRewriter {
+protected:
+  // Designate the operation (resp. instruction) used to do sub-tile matrix
+  // multiplications.
+  enum class MMLA {
+    Nop,
+    Signed,      // smmla
+    Unsigned,    // ummla
+    Mixed,       // usmmla
+    MixedSwapped // usmmla with LHS and RHS swapped
+  };
+
+  // Lower-level operation to be emitted.
+  MMLA mmlaOp = MMLA::Nop;
+
+  // The operand tiles. These are not necessarily the operands of
+  // `vector.contract`, for example they could be operands to `arith.extsi`
+  // that is in turn fed into `vector.contract`.
+  Value lhs;
+  Value rhs;
+  Value acc;
+
+  // The dimensions logically corresponding to matrix multiplication of
+  // MxK * KxN -> MxN. The operands and the result do not necessarily have these
+  // shapes, for example RHS could be NxK with a transposing indexing map.
+  int64_t dimM = 0;
+  int64_t dimN = 0;
+  int64_t dimK = 0;
+
+  // Unroll iteration bounds. See documentation for `StaticTileOffsetRange`.
+  SmallVector<int64_t> iterationBounds;
+
+  // Sub-tile shape. The algorithm handles operand shapes, which are multiples
+  // of this shape.
+  SmallVector<int64_t> subTileShape;
+
+  // Create the matrix multiply and accumulate operation according to `mmlaOp`.
+  Value createMMLA(PatternRewriter &rewriter, Location loc, Value acc,
+                   Value lhs, Value rhs) {
+    switch (mmlaOp) {
+    case MMLA::Signed:
+      return rewriter.createOrFold<arm_neon::SmmlaOp>(loc, acc.getType(), acc,
+                                                      lhs, rhs);
+    case MMLA::Unsigned:
+      return rewriter.createOrFold<arm_neon::UmmlaOp>(loc, acc.getType(), acc,
+                                                      lhs, rhs);
+    case MMLA::Mixed:
+      return rewriter.createOrFold<arm_neon::UsmmlaOp>(loc, acc.getType(), acc,
+                                                       lhs, rhs);
+    case MMLA::MixedSwapped:
+      // The accumulator comes transposed and the result will be transposed
+      // later, so all we have to do here is swap the operands.
+      return rewriter.createOrFold<arm_neon::UsmmlaOp>(loc, acc.getType(), acc,
+                                                       rhs, lhs);
+    case MMLA::Nop:
+      llvm_unreachable("Uninitialized operation type");
+    }
   }
-}
 
-/// Lowering from a vector::contractOp arm neon smmla intrinsic. This will tile
-/// any vector.contract into multiple smmla instructions with unrolling so long
-/// as [2,2,8] is a divisor of its shape. It can also process vecmats with dimM
-/// = 1 (either explicitly or inferred if LHS has only dimK) If no unrolling is
-/// necessary, a single smmla instruction is emitted.
-class LowerContractionToNeonI8MMPattern
-    : public OpRewritePattern<vector::ContractionOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-  LogicalResult matchAndRewrite(vector::ContractionOp op,
-                                PatternRewriter &rewriter) const override {
-    Location loc = op.getLoc();
-    // Infer tile sizes from operands. For vecmat, LHS may only have 1 dim.
-    // Note: RHS is not transposed.
-    mlir::VectorType lhsType = op.getLhsType();
-    mlir::VectorType rhsType = op.getRhsType();
+  // Check common preconditions for applying the patterns and initialize
+  // logical dimensions.
+  LogicalResult matchAndInit(vector::ContractionOp op,
+                             PatternRewriter &rewriter) {
+    // Check iterator types for matrix multiplication.
+    SmallVector<vector::IteratorType> itTypes = op.getIteratorTypesArray();
+    if (!((itTypes.size() == 3 &&
+           (itTypes[0] == vector::IteratorType::parallel &&
+            itTypes[1] == vector::IteratorType::parallel &&
+            itTypes[2] == vector::IteratorType::reduction)) ||
+          (itTypes.size() == 2 &&
+           (itTypes[0] == vector::IteratorType::parallel &&
+            itTypes[1] == vector::IteratorType::reduction))))
+      return rewriter.notifyMatchFailure(
+          op, "iterator types do not correspond to matrix multiplication");
+
     // Avoid 0-D vectors and 1-D rhs:
-    if (!lhsType.hasRank() || !rhsType.hasRank() || rhsType.getRank() < 2)
-      return failure();
+    VectorType lhsType = op.getLhsType();
+    VectorType rhsType = op.getRhsType();
+    if (!lhsType.hasRank() || !rhsType.hasRank() || lhsType.getRank() > 2 ||
+        rhsType.getRank() != 2)
+      return rewriter.notifyMatchFailure(op, "Invalid operand rank");
+
     // This codegen does not work for scalable vectors. Return failure so this
     // pattern is not accidentally chosen over patterns that lower to ArmSVE.
     if (lhsType.isScalable() || rhsType.isScalable())
-      return failure();
-    auto dimM = lhsType.getRank() == 1 ? 1 : lhsType.getDimSize(0);
-    auto dimN = rhsType.getDimSize(0);
-    auto dimK = rhsType.getDimSize(1);
-    bool isVecmat = dimM == 1 ? true : false;
-    if (lhsType.getDimSize(lhsType.getRank() - 1) !=
-        rhsType.getDimSize(rhsType.getRank() - 1)) {
-      return failure(); // dimK mismatch
-    }
-    // Unrolling patterns can handle any [2, 2, 8] shaped multiple of inputs for
-    // tiling.
-    if ((dimM % 2 != 0 && !isVecmat) || dimN % 2 != 0 || dimK % 8 != 0) {
-      return failure();
-    }
-
-    // Check iterator types for contract. All iterators except inner-most
-    // dimension must be parallel.
-    auto iteratorTypes = op.getIteratorTypesArray();
-    if (iteratorTypes.size() > 3 || iteratorTypes[iteratorTypes.size() - 1] !=
-                                        vector::IteratorType::reduction) {
-      return failure();
-    }
-    if (llvm::any_of(ArrayRef<vector::IteratorType>(iteratorTypes).drop_back(1),
-                     [](vector::IteratorType iteratorType) {
-                       return iteratorType != vector::IteratorType::parallel;
-                     })) {
-      return failure();
+      return rewriter.notifyMatchFailure(op,
+                                         "Not applicable to scalable vectors");
+
+    // Initialize dimensions and check for a matching K dimension.
+    dimM = lhsType.getDimSize(0);
+    dimN = rhsType.getDimSize(0);
+    dimK = rhsType.getDimSize(1);
+
+    int64_t lhsDimK;
+    if (lhsType.getRank() == 1) {
+      dimM = 1;
+      lhsDimK = lhsType.getDimSize(0);
+    } else {
+      lhsDimK = lhsType.getDimSize(1);
     }
 
-    // Check inputs are sign-/zero- extensions from iN (N <= 8) to i32. Get the
-    // values before the extension. All four signed/unsigned combinations for
-    // input operands are supported, but they are lowered to different
-    // operations. Determine which is the appropriate operation to lower to.
-    MMLA mmlaOp = MMLA::Signed;
-    auto maybeLhs = getExtOperand<arith::ExtSIOp>(op.getLhs());
-    if (!maybeLhs) {
-      mmlaOp = MMLA::Unsigned;
-      maybeLhs = getExtOperand<arith::ExtUIOp>(op.getLhs());
-    }
-    if (!maybeLhs)
-      return failure();
+    if (lhsDimK != dimK)
+      return rewriter.notifyMatchFailure(op, "Dimensions mismatch");
 
-    auto maybeRhs = getExtOperand<arith::ExtSIOp>(op.getRhs());
-    if (maybeRhs) {
-      if (mmlaOp == MMLA::Unsigned)
-        mmlaOp = MMLA::Mixed;
-    } else {
-      if (mmlaOp == MMLA::Signed)
-        mmlaOp = MMLA::MixedSwapped;
-      maybeRhs = getExtOperand<arith::ExtUIOp>(op.getRhs());
-    }
-    if (!maybeRhs)
-      return failure();
+    return success();
+  }
 
-    Value origLhs = *maybeLhs;
-    Value origRhs = *maybeRhs;
-
-    // Match any iX to i32 for X<8 then turn into an i8 output. Feed into
-    // following neon instruction. Check inputs for extsi are <=i8
-    Value extLhs;
-    Value extRhs;
-    if (auto lhsExtInType = dyn_cast<mlir::VectorType>(origLhs.getType())) {
-      if (lhsExtInType.getElementTypeBitWidth() <= 8) {
-        Type targetLhsExtTy =
-            matchContainerType(rewriter.getI8Type(), lhsExtInType);
-        if (mmlaOp == MMLA::Signed || mmlaOp == MMLA::Mixed)
-          extLhs = rewriter.createOrFold<arith::ExtSIOp>(loc, targetLhsExtTy,
-                                                         origLhs);
-        else
-          extLhs = rewriter.createOrFold<arith::ExtUIOp>(loc, targetLhsExtTy,
-                                                         origLhs);
-      }
-    }
-    if (auto rhsExtInType = dyn_cast<mlir::VectorType>(origRhs.getType())) {
-      if (rhsExtInType.getElementTypeBitWidth() <= 8) {
-        Type targetRhsExtTy =
-            matchContainerType(rewriter.getI8Type(), rhsExtInType);
-        if (mmlaOp == MMLA::Unsigned || mmlaOp == MMLA::Mixed)
-          extRhs = rewriter.createOrFold<arith::ExtUIOp>(loc, targetRhsExtTy,
-                                                         origRhs);
-        else
-          extRhs = rewriter.createOrFold<arith::ExtSIOp>(loc, targetRhsExtTy,
-                                                         origRhs);
-      }
-    }
+public:
+  void rewrite(vector::ContractionOp op, PatternRewriter &rewriter) {
+    // Create some convenience types.
+    auto inputElementType = cast<ShapedType>(lhs.getType()).getElementType();
+    auto accElementType = cast<ShapedType>(acc.getType()).getElementType();
+    auto inputExpandedType =
+        VectorType::get({2, subTileShape.back()}, inputElementType);
+    auto outputExpandedType = VectorType::get({2, 2}, accElementType);
+
+    // One-dimensional representation of logical sub-tiles as required by the
+    // ArmNeon ops.
+    auto collapsedInputType =
+        VectorType::get(inputExpandedType.getNumElements(), inputElementType);
+    auto collapsedOutputType =
+        VectorType::get(outputExpandedType.getNumElements(), accElementType);
+
+    // Get indexing maps for a more concise/convenient access.
+    auto indexingMaps = op.getIndexingMapsArray();
+    AffineMap &lhsPermutationMap = indexingMaps[0];
+    AffineMap &rhsPermutationMap = indexingMaps[1];
+    AffineMap &accPermutationMap = indexingMaps[2];
 
-    if (!extLhs || !extRhs) {
-      return failure();
-    }
+    Location loc = op.getLoc();
 
     // Initial accumulator for the final result. This is the un-tiled result if
     // tiling is done.
     Value result = rewriter.create<arith::ConstantOp>(
         loc, op.getResultType(), rewriter.getZeroAttr(op.getResultType()));
 
-    SmallVector<int64_t> unrolledSize = *op.getShapeForUnroll();
-    SmallVector<int64_t> smmlaShape = {2, 8};
-    SmallVector<int64_t> loopOrder = {0, 1};
-    if (unrolledSize.size() == 3) {
-      smmlaShape.insert(smmlaShape.begin(), isVecmat ? 1 : 2);
+    SmallVector<int64_t, 3> loopOrder = {0, 1};
+    if (iterationBounds.size() == 3)
       loopOrder.push_back(2);
-    }
 
     // Keep track of the previous accumulator when tiling over K.
     Value kAcc;
     for (SmallVector<int64_t> offsets :
-         StaticTileOffsetRange(unrolledSize, smmlaShape, loopOrder)) {
+         StaticTileOffsetRange(iterationBounds, subTileShape, loopOrder)) {
       // Helper to compute the new shape of each operand and extract the slice.
       auto extractOperand = [&](Value operand, AffineMap permutationMap,
                                 ArrayRef<int64_t> operandOffsets) {
-        SmallVector<int64_t> operandShape =
-            applyPermutationMap(permutationMap, ArrayRef<int64_t>(smmlaShape));
+        SmallVector<int64_t> operandShape = applyPermutationMap(
+            permutationMap, ArrayRef<int64_t>(subTileShape));
         SmallVector<int64_t> operandStrides(operandOffsets.size(), 1);
         return rewriter.createOrFold<vector::ExtractStridedSliceOp>(
             loc, operand, operandOffsets, operandShape, operandStrides);
       };
 
       // Extract tiled lhs, rhs, and acc
-      AffineMap lhsPermutationMap = op.getIndexingMapsArray()[0];
       SmallVector<int64_t> lhsOffsets =
           applyPermutationMap(lhsPermutationMap, ArrayRef<int64_t>(offsets));
-      Value tiledLhs = extractOperand(extLhs, lhsPermutationMap, lhsOffsets);
-      AffineMap rhsPermutationMap = op.getIndexingMapsArray()[1];
+      Value tiledLhs = extractOperand(lhs, lhsPermutationMap, lhsOffsets);
       SmallVector<int64_t> rhsOffsets =
           applyPermutationMap(rhsPermutationMap, ArrayRef<int64_t>(offsets));
-      Value tiledRhs = extractOperand(extRhs, rhsPermutationMap, rhsOffsets);
-      AffineMap accPermutationMap = op.getIndexingMapsArray()[2];
+      Value tiledRhs = extractOperand(rhs, rhsPermutationMap, rhsOffsets);
       SmallVector<int64_t> accOffsets =
           applyPermutationMap(accPermutationMap, ArrayRef<int64_t>(offsets));
-      Value tiledAcc =
-          extractOperand(op.getAcc(), accPermutationMap, accOffsets);
-
-      auto inputElementType =
-          cast<ShapedType>(tiledLhs.getType()).getElementType();
-      auto accElementType =
-          cast<ShapedType>(tiledAcc.getType()).getElementType();
-      auto inputExpandedType = VectorType::get({2, 8}, inputElementType);
-      auto outputExpandedType = VectorType::get({2, 2}, accElementType);
+      Value tiledAcc = extractOperand(acc, accPermutationMap, accOffsets);
 
       // With vecmat, tiled LHS and ACC will contain only one of 2 necessary
-      // rows along dimM. Expand their shapes to match the smmla op.
-      if (isVecmat) {
-        auto expandForSMMLA = [&](Value tiledOperand,
-                                  VectorType expandedTypeType) {
+      // rows along dimM. Expand their shapes to match the ArmNeon op.
+      if (dimM == 1) {
+        auto expandRowVector = [&](Value tiledOperand,
+                                   VectorType expandedTypeType) {
           auto emptyOperand = rewriter.create<arith::ConstantOp>(
               loc, expandedTypeType, rewriter.getZeroAttr(expandedTypeType));
           SmallVector<int64_t> offsets(
@@ -290,8 +266,8 @@ class LowerContractionToNeonI8MMPattern
           return rewriter.createOrFold<vector::InsertStridedSliceOp>(
               loc, tiledOperand, emptyOperand, offsets, strides);
         };
-        tiledLhs = expandForSMMLA(tiledLhs, inputExpandedType);
-        tiledAcc = expandForSMMLA(tiledAcc, outputExpandedType);
+        tiledLhs = expandRowVector(tiledLhs, inputExpandedType);
+        tiledAcc = expandRowVector(tiledAcc, outputExpandedType);
       }
 
       // Transpose ACC if doing signed by unsigned multiplication, because we're
@@ -301,15 +277,11 @@ class LowerContractionToNeonI8MMPattern
         tiledAcc = rewriter.create<vector::TransposeOp>(
             loc, tiledAcc, ArrayRef<int64_t>({1, 0}));
 
-      // Collapse tiled operands to 1D vectors required by smmla intrinsic
-      auto collapsedInputType =
-          VectorType::get(inputExpandedType.getNumElements(), inputElementType);
+      // Collapse tiled operands to 1D vectors required by the ArmNeon ops
       auto collapsedLhs = rewriter.createOrFold<vector::ShapeCastOp>(
           tiledLhs.getLoc(), collapsedInputType, tiledLhs);
       auto collapsedRhs = rewriter.createOrFold<vector::ShapeCastOp>(
           tiledRhs.getLoc(), collapsedInputType, tiledRhs);
-      auto collapsedOutputType =
-          VectorType::get(outputExpandedType.getNumElements(), accElementType);
 
       bool initialKAcc = offsets.back() == 0;
       Value collapsedRes;
@@ -321,8 +293,8 @@ class LowerContractionToNeonI8MMPattern
       }
 
       // Insert contract op
-      kAcc = createMMLA(rewriter, mmlaOp, op.getLoc(), collapsedRes.getType(),
-                        collapsedRes, collapsedLhs, collapsedRhs);
+      kAcc =
+          createMMLA(rewriter, loc, collapsedRes, collapsedLhs, collapsedRhs);
 
       // Reshape output back to 2D
       Value tiledRes = rewriter.createOrFold<vector::ShapeCastOp>(
@@ -336,9 +308,8 @@ class LowerContractionToNeonI8MMPattern
 
       // With vecmat, only one row of tiled ACC can be inserted into the final
       // result
-      if (isVecmat) {
+      if (dimM == 1)
         tiledRes = rewriter.createOrFold<vector::ExtractOp>(loc, tiledRes, 0);
-      }
 
       // Insert the tiled result back into the non tiled result of the
       // contract op.
@@ -349,6 +320,98 @@ class LowerContractionToNeonI8MMPattern
     }
 
     rewriter.replaceOp(op, result);
+  }
+};
+
+class VectorContractRewriterI8MM : public VectorContractRewriter {
+public:
+  LogicalResult matchAndInit(vector::ContractionOp op,
+                             PatternRewriter &rewriter) {
+    if (failed(VectorContractRewriter::matchAndInit(op, rewriter)))
+      return failure();
+
+    // Unrolling patterns can handle any [2, 2, 8] shaped multiple of inputs for
+    // tiling.
+    if ((dimM != 1 && dimM % 2 != 0) || dimN % 2 != 0 || dimK % 8 != 0)
+      return rewriter.notifyMatchFailure(op, "Unsupported operand shapes");
+
+    // Check inputs are sign-/zero- extensions from iN (N <= 8) to i32. Get the
+    // values before the extension. All four signed/unsigned combinations for
+    // input operands are supported, but they are lowered to different
+    // operations. Determine which is the appropriate operation to lower to.
+    mmlaOp = MMLA::Signed;
+    auto maybeLhs = getExtOperand<arith::ExtSIOp>(op.getLhs());
+    if (!maybeLhs) {
+      mmlaOp = MMLA::Unsigned;
+      maybeLhs = getExtOperand<arith::ExtUIOp>(op.getLhs());
+    }
+    if (!maybeLhs)
+      return rewriter.notifyMatchFailure(
+          op, "LHS is not a sign- or zero- extended iN, N <= 8");
+
+    auto maybeRhs = getExtOperand<arith::ExtSIOp>(op.getRhs());
+    if (maybeRhs) {
+      if (mmlaOp == MMLA::Unsigned)
+        mmlaOp = MMLA::Mixed;
+    } else {
+      if (mmlaOp == MMLA::Signed)
+        mmlaOp = MMLA::MixedSwapped;
+      maybeRhs = getExtOperand<arith::ExtUIOp>(op.getRhs());
+    }
+
+    if (!maybeRhs)
+      return rewriter.notifyMatchFailure(
+          op, "RHS is not a sign- or zero- extended iN, N <= 8");
+
+    lhs = *maybeLhs;
+    rhs = *maybeRhs;
+    acc = op.getAcc();
+
+    // Extend inputs from iN, N < 8 to i8.
+    Location loc = op.getLoc();
+    auto lhsExtInType = cast<VectorType>(lhs.getType());
+    if (lhsExtInType.getElementTypeBitWidth() < 8)
+      lhs = extendSmallIntVector(loc, lhsExtInType, lhs,
+                                 /* signExt */ mmlaOp == MMLA::Signed ||
+                                     mmlaOp == MMLA::Mixed,
+                                 rewriter);
+
+    auto rhsExtInType = cast<VectorType>(rhs.getType());
+    if (rhsExtInType.getElementTypeBitWidth() < 8)
+
+      rhs = extendSmallIntVector(loc, rhsExtInType, rhs,
+                                 /* signExt */ mmlaOp != MMLA::Unsigned &&
+                                     mmlaOp != MMLA::Mixed,
+                                 rewriter);
+
+    // Initialize parameters for unrolling.
+    iterationBounds = *op.getShapeForUnroll();
+    if (iterationBounds.size() == 3)
+      subTileShape = SmallVector<int64_t>({dimM == 1 ? 1 : 2, 2, 8});
+    else
+      subTileShape = SmallVector<int64_t>({2, 8});
+
+    return success();
+  }
+};
+
+/// Lowering from a vector::contractOp to the arm neon smmla intrinsic. This
+/// will tile any vector.contract into multiple smmla instructions with
+/// unrolling so long as [2,2,8] is a divisor of its shape. It can also process
+/// vecmats with dimM = 1 (either explicitly or inferred if LHS has only dimK).
+/// If no unrolling is necessary, a single smmla instruction is emitted.
+class LowerContractionToNeonI8MMPattern
+    : public OpRewritePattern<vector::ContractionOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(vector::ContractionOp op,
+                                PatternRewriter &rewriter) const override {
+
+    VectorContractRewriterI8MM vcr;
+    if (failed(vcr.matchAndInit(op, rewriter)))
+      return failure();
+    vcr.rewrite(op, rewriter);
+
     return success();
   }
 };

>From 0f64f8665174ac14738aae32e28d127266a7e6c7 Mon Sep 17 00:00:00 2001
From: Momchil Velikov <momchil.velikov at arm.com>
Date: Tue, 22 Jul 2025 14:06:43 +0000
Subject: [PATCH 2/2] [fixup] Rename a member function

---
 .../ArmNeon/Transforms/LowerContractionToNeonI8MMPattern.cpp  | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToNeonI8MMPattern.cpp b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToNeonI8MMPattern.cpp
index 7e6a2bab59a83..59acb362191a7 100644
--- a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToNeonI8MMPattern.cpp
+++ b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToNeonI8MMPattern.cpp
@@ -195,7 +195,7 @@ class VectorContractRewriter {
   }
 
 public:
-  void rewrite(vector::ContractionOp op, PatternRewriter &rewriter) {
+  void lower(vector::ContractionOp op, PatternRewriter &rewriter) {
     // Create some convenience types.
     auto inputElementType = cast<ShapedType>(lhs.getType()).getElementType();
     auto accElementType = cast<ShapedType>(acc.getType()).getElementType();
@@ -410,7 +410,7 @@ class LowerContractionToNeonI8MMPattern
     VectorContractRewriterI8MM vcr;
     if (failed(vcr.matchAndInit(op, rewriter)))
       return failure();
-    vcr.rewrite(op, rewriter);
+    vcr.lower(op, rewriter);
 
     return success();
   }



More information about the Mlir-commits mailing list