[Mlir-commits] [mlir] [MLIR][AArch64] Lower `vector.contract` with mixed signed/unsigned arguments to Neon FEAT_I8MM (PR #144698)

Momchil Velikov llvmlistbot at llvm.org
Wed Jun 25 02:31:14 PDT 2025


https://github.com/momchil-velikov updated https://github.com/llvm/llvm-project/pull/144698

>From 8f97d89e13760517b478560fc9715f69a5eafa8a Mon Sep 17 00:00:00 2001
From: Momchil Velikov <momchil.velikov at arm.com>
Date: Wed, 18 Jun 2025 11:38:24 +0000
Subject: [PATCH 1/3] [MLIR][AArch64] Lower `vector.contract` with mixed
 signed/unsigned arguments to Neon FEAT_I8MM

---
 .../LowerContractionToSMMLAPattern.cpp        | 159 +++++++++++++++---
 .../Dialect/ArmNeon/lower-to-arm-neon.mlir    |  80 +++++++--
 2 files changed, 203 insertions(+), 36 deletions(-)

diff --git a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
index 5ce3d2b28aeb3..967aff579227b 100644
--- a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
+++ b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
@@ -37,6 +37,81 @@ static Type matchContainerType(Type element, Type container) {
   return element;
 }
 
+// Get the operand of a `vector.contract`. This function is intended to abstract
+// away from the particular way a value is extended before feeding it into the
+// `vector.contract` - via zero-extend or an explicit or implicit sign-extend
+// (for implicit sign-extension see `vector.contract` documentation).
+//
+// The template parameter `Op` indicates the extension operation (explicit or
+// implicit) for which we are checking.
+//
+// Return success only for extensions from `iN` (N <= 8) to `i32`.
+template <typename Op>
+std::optional<Value> getExtOperand(Value v) {
+
+  static_assert(llvm::is_one_of<Op, arith::ExtSIOp, arith::ExtUIOp>::value,
+                "Must be instantiated with either sign- or zero- extension op");
+
+  // If the operand is not defined by an explicit extend operation of the
+// accepted operation type, allow for an implicit sign-extension.
+  auto extOp = dyn_cast_or_null<Op>(v.getDefiningOp());
+  if (!extOp) {
+    if constexpr (std::is_same<Op, arith::ExtSIOp>::value) {
+      auto eltTy = cast<VectorType>(v.getType()).getElementType();
+      if (!eltTy.isSignlessInteger() || eltTy.getIntOrFloatBitWidth() > 8)
+        return {};
+      return v;
+    }
+    return {};
+  }
+
+  // If the operand is defined by an explicit extend operation of the accepted
+  // operation type, check it's extended from `iN` (N <= 8) to `i32`.
+  auto inOp = extOp.getIn();
+  auto inTy = dyn_cast<VectorType>(inOp.getType());
+  if (!inTy)
+    return {};
+  auto inEltTy = inTy.getElementType();
+  if (!inEltTy.isSignlessInteger() || inEltTy.getIntOrFloatBitWidth() > 8)
+    return {};
+
+  auto outTy = dyn_cast<VectorType>(extOp.getType());
+  if (!(outTy && outTy.getElementType().isSignlessInteger(32)))
+    return {};
+
+  return inOp;
+}
+
+// Designate the operation (resp. instruction) used to do sub-tile matrix
+// multiplications.
+enum class MMLA {
+  Signed,      // smmla
+  Unsigned,    // ummla
+  Mixed,       // usmmla
+  MixedSwapped // usmmla with LHS and RHS swapped
+};
+
+// Create the matrix multiply and accumulate operation according to `op`.
+Value createMMLA(PatternRewriter &rewriter, MMLA op, Location loc,
+                 mlir::Type accType, Value acc, Value lhs, Value rhs) {
+  switch (op) {
+  case MMLA::Signed:
+    return rewriter.createOrFold<arm_neon::SmmlaOp>(loc, accType, acc, lhs,
+                                                    rhs);
+  case MMLA::Unsigned:
+    return rewriter.createOrFold<arm_neon::UmmlaOp>(loc, accType, acc, lhs,
+                                                    rhs);
+  case MMLA::Mixed:
+    return rewriter.createOrFold<arm_neon::UsmmlaOp>(loc, accType, acc, lhs,
+                                                     rhs);
+  case MMLA::MixedSwapped:
+    // The accumulator comes transposed and the result will be transposed
+    // later, so all we have to do here is swap the operands.
+    return rewriter.createOrFold<arm_neon::UsmmlaOp>(loc, accType, acc, rhs,
+                                                     lhs);
+  }
+}
+
 /// Lowering from a vector::contractOp arm neon smmla intrinsic. This will tile
 /// any vector.contract into multiple smmla instructions with unrolling so long
 /// as [2,2,8] is a divisor of its shape. It can also process vecmats with dimM
@@ -88,39 +163,64 @@ class LowerContractionToSMMLAPattern
       return failure();
     }
 
-    // Check two extsi inputs Rhs Lhs for contract.
-    arith::ExtSIOp origLhsExtOp =
-        dyn_cast_or_null<arith::ExtSIOp>(op.getLhs().getDefiningOp());
-    arith::ExtSIOp origRhsExtOp =
-        dyn_cast_or_null<arith::ExtSIOp>(op.getRhs().getDefiningOp());
-    if (!origLhsExtOp || !origRhsExtOp) {
+    // Check inputs are sign-/zero- extensions from iN (N <= 8) to i32. Get the
+    // values before the extension. All four signed/unsigned combinations for
+    // input operands are supported, but they are lowered to different
+    // operations. Determine which is the appropriate operation to lower to.
+    MMLA mmlaOp = MMLA::Signed;
+    auto maybeLhs = getExtOperand<arith::ExtSIOp>(op.getLhs());
+    if (!maybeLhs) {
+      mmlaOp = MMLA::Unsigned;
+      maybeLhs = getExtOperand<arith::ExtUIOp>(op.getLhs());
+    }
+    if (!maybeLhs)
       return failure();
+
+    auto maybeRhs = getExtOperand<arith::ExtSIOp>(op.getRhs());
+    if (maybeRhs) {
+      if (mmlaOp == MMLA::Unsigned)
+        mmlaOp = MMLA::Mixed;
+    } else {
+      if (mmlaOp == MMLA::Signed)
+        mmlaOp = MMLA::MixedSwapped;
+      maybeRhs = getExtOperand<arith::ExtUIOp>(op.getRhs());
     }
+    if (!maybeRhs)
+      return failure();
+
+    Value origLhs = *maybeLhs;
+    Value origRhs = *maybeRhs;
 
     // Match any iX to i32 for X<8 then turn into an i8 output. Feed into
     // following neon instruction. Check inputs for extsi are <=i8
-    Value extsiLhs;
-    Value extsiRhs;
-    if (auto lhsExtInType =
-            dyn_cast<mlir::VectorType>(origLhsExtOp.getIn().getType())) {
+    Value extLhs;
+    Value extRhs;
+    if (auto lhsExtInType = dyn_cast<mlir::VectorType>(origLhs.getType())) {
       if (lhsExtInType.getElementTypeBitWidth() <= 8) {
         Type targetLhsExtTy =
             matchContainerType(rewriter.getI8Type(), lhsExtInType);
-        extsiLhs = rewriter.createOrFold<arith::ExtSIOp>(loc, targetLhsExtTy,
-                                                         origLhsExtOp.getIn());
+        if (mmlaOp == MMLA::Signed || mmlaOp == MMLA::Mixed)
+          extLhs = rewriter.createOrFold<arith::ExtSIOp>(loc, targetLhsExtTy,
+                                                         origLhs);
+        else
+          extLhs = rewriter.createOrFold<arith::ExtUIOp>(loc, targetLhsExtTy,
+                                                         origLhs);
       }
     }
-    if (auto rhsExtInType =
-            dyn_cast<mlir::VectorType>(origRhsExtOp.getIn().getType())) {
+    if (auto rhsExtInType = dyn_cast<mlir::VectorType>(origRhs.getType())) {
       if (rhsExtInType.getElementTypeBitWidth() <= 8) {
         Type targetRhsExtTy =
             matchContainerType(rewriter.getI8Type(), rhsExtInType);
-        extsiRhs = rewriter.createOrFold<arith::ExtSIOp>(loc, targetRhsExtTy,
-                                                         origRhsExtOp.getIn());
+        if (mmlaOp == MMLA::Unsigned || mmlaOp == MMLA::Mixed)
+          extRhs = rewriter.createOrFold<arith::ExtUIOp>(loc, targetRhsExtTy,
+                                                         origRhs);
+        else
+          extRhs = rewriter.createOrFold<arith::ExtSIOp>(loc, targetRhsExtTy,
+                                                         origRhs);
       }
     }
 
-    if (!extsiLhs || !extsiRhs) {
+    if (!extLhs || !extRhs) {
       return failure();
     }
 
@@ -155,11 +255,11 @@ class LowerContractionToSMMLAPattern
       AffineMap lhsPermutationMap = op.getIndexingMapsArray()[0];
       SmallVector<int64_t> lhsOffsets =
           applyPermutationMap(lhsPermutationMap, ArrayRef<int64_t>(offsets));
-      Value tiledLhs = extractOperand(extsiLhs, lhsPermutationMap, lhsOffsets);
+      Value tiledLhs = extractOperand(extLhs, lhsPermutationMap, lhsOffsets);
       AffineMap rhsPermutationMap = op.getIndexingMapsArray()[1];
       SmallVector<int64_t> rhsOffsets =
           applyPermutationMap(rhsPermutationMap, ArrayRef<int64_t>(offsets));
-      Value tiledRhs = extractOperand(extsiRhs, rhsPermutationMap, rhsOffsets);
+      Value tiledRhs = extractOperand(extRhs, rhsPermutationMap, rhsOffsets);
       AffineMap accPermutationMap = op.getIndexingMapsArray()[2];
       SmallVector<int64_t> accOffsets =
           applyPermutationMap(accPermutationMap, ArrayRef<int64_t>(offsets));
@@ -191,6 +291,13 @@ class LowerContractionToSMMLAPattern
         tiledAcc = expandForSMMLA(tiledAcc, outputExpandedType);
       }
 
+      // Transpose ACC if doing signed by unsigned multiplication, because we're
+      // using the instruction for unsigned by signed multiplication with
+      // reversed operands.
+      if (mmlaOp == MMLA::MixedSwapped)
+        tiledAcc = rewriter.create<vector::TransposeOp>(
+            loc, tiledAcc, ArrayRef<int64_t>({1, 0}));
+
       // Collapse tiled operands to 1D vectors required by smmla intrinsic
       auto collapsedInputType =
           VectorType::get(inputExpandedType.getNumElements(), inputElementType);
@@ -211,15 +318,21 @@ class LowerContractionToSMMLAPattern
       }
 
       // Insert contract op
-      kAcc = rewriter.createOrFold<arm_neon::SmmlaOp>(
-          op.getLoc(), collapsedRes.getType(), collapsedRes, collapsedLhs,
-          collapsedRhs);
+      kAcc = createMMLA(rewriter, mmlaOp, op.getLoc(), collapsedRes.getType(),
+                        collapsedRes, collapsedLhs, collapsedRhs);
 
       // Reshape output back to 2D
       Value tiledRes = rewriter.createOrFold<vector::ShapeCastOp>(
           kAcc.getLoc(), tiledAcc.getType(), kAcc);
 
-      // With vecmat, only one row of tiled ACC can be inserted into file result
+      // Because of the reversed operands the result is obtained transposed.
+// Transpose it back.
+      if (mmlaOp == MMLA::MixedSwapped)
+        tiledRes = rewriter.create<vector::TransposeOp>(
+            loc, tiledRes, ArrayRef<int64_t>({1, 0}));
+
+      // With vecmat, only one row of tiled ACC can be inserted into the final
+      // result
       if (isVecmat) {
         tiledRes = rewriter.createOrFold<vector::ExtractOp>(loc, tiledRes, 0);
       }
diff --git a/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir b/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir
index e4f7ea150c850..5fc29c6442602 100644
--- a/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir
+++ b/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir
@@ -17,14 +17,28 @@ func.func @vector_arm_neon_mixed_types(%lhs: vector<2x8xi8>, %rhs: vector<2x8xi4
 
 // -----
 
-// CHECK-LABEL: vector_arm_neon_same_types
-// CHECK-SAME:    %[[A0:.*]]: vector<2x8xi8>, %[[A1:.*]]: vector<2x8xi8>, %[[A2:.*]]: vector<2x2xi32>
-// CHECK-DAG: %[[D0:.*]] = vector.shape_cast %[[A0]] : vector<2x8xi8> to vector<16xi8>
-// CHECK-DAG: %[[D1:.*]] = vector.shape_cast %[[A1]] : vector<2x8xi8> to vector<16xi8>
-// CHECK-DAG: %[[D2:.*]] = vector.shape_cast %[[A2]] : vector<2x2xi32> to vector<4xi32>
-// CHECK-DAG: %[[D3:.*]] = arm_neon.intr.smmla %[[D2]], %[[D0]], %[[D1]] : vector<16xi8> to vector<4xi32>
-// CHECK-DAG: %[[D4:.*]] = vector.shape_cast %[[D3]] : vector<4xi32> to vector<2x2xi32>
-func.func @vector_arm_neon_same_types(%lhs: vector<2x8xi8>, %rhs: vector<2x8xi8>, %acc : vector<2x2xi32>) -> vector<2x2xi32> {
+// CHECK-LABEL: vector_arm_neon_implicit_extsi
+// CHECK-SAME:    %[[LHS:.+]]: vector<2x8xi8>, %[[RHS:.+]]: vector<2x8xi8>, %[[ACC:.+]]: vector<2x2xi32>
+// CHECK:       %[[L:.+]] = vector.shape_cast %[[LHS]] : vector<2x8xi8> to vector<16xi8>
+// CHECK:       %[[R:.+]] = vector.shape_cast %[[RHS]] : vector<2x8xi8> to vector<16xi8>
+// CHECK:       %[[A:.+]] = vector.shape_cast %[[ACC]] : vector<2x2xi32> to vector<4xi32>
+// CHECK:       %[[M:.+]] = arm_neon.intr.smmla %[[A]], %[[L]], %[[R]] : vector<16xi8> to vector<4xi32>
+// CHECK:       %{{.+}} = vector.shape_cast %[[M]] : vector<4xi32> to vector<2x2xi32>
+func.func @vector_arm_neon_implicit_extsi(%lhs: vector<2x8xi8>, %rhs: vector<2x8xi8>, %acc : vector<2x2xi32>) -> vector<2x2xi32> {
+  %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs, %rhs, %acc : vector<2x8xi8>, vector<2x8xi8> into vector<2x2xi32>
+  return %res : vector<2x2xi32>
+}
+
+// -----
+
+// CHECK-LABEL: vector_arm_neon_signed_signed
+// CHECK-SAME:    %[[LHS:.+]]: vector<2x8xi8>, %[[RHS:.+]]: vector<2x8xi8>, %[[ACC:.+]]: vector<2x2xi32>
+// CHECK:       %[[L:.+]] = vector.shape_cast %[[LHS]] : vector<2x8xi8> to vector<16xi8>
+// CHECK:       %[[R:.+]] = vector.shape_cast %[[RHS]] : vector<2x8xi8> to vector<16xi8>
+// CHECK:       %[[A:.+]] = vector.shape_cast %[[ACC]] : vector<2x2xi32> to vector<4xi32>
+// CHECK:       %[[M:.+]] = arm_neon.intr.smmla %[[A]], %[[L]], %[[R]] : vector<16xi8> to vector<4xi32>
+// CHECK:       %{{.+}} = vector.shape_cast %[[M]] : vector<4xi32> to vector<2x2xi32>
+func.func @vector_arm_neon_signed_signed(%lhs: vector<2x8xi8>, %rhs: vector<2x8xi8>, %acc : vector<2x2xi32>) -> vector<2x2xi32> {
   %lhs_extsi = arith.extsi %lhs : vector<2x8xi8> to vector<2x8xi32>
   %rhs_extsi = arith.extsi %rhs : vector<2x8xi8> to vector<2x8xi32>
   %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs_extsi, %rhs_extsi, %acc : vector<2x8xi32>, vector<2x8xi32> into vector<2x2xi32>
@@ -33,11 +47,51 @@ func.func @vector_arm_neon_same_types(%lhs: vector<2x8xi8>, %rhs: vector<2x8xi8>
 
 // -----
 
-// CHECK-LABEL: vector_arm_neon_without_extsi
-// CHECK-SAME:    %[[A0:.*]]: vector<2x8xi32>, %[[A1:.*]]: vector<2x8xi32>, %[[A2:.*]]: vector<2x2xi32>
-// CHECK-DAG: %[[D0:.*]] = vector.contract
-func.func @vector_arm_neon_without_extsi(%lhs: vector<2x8xi32>, %rhs: vector<2x8xi32>, %acc : vector<2x2xi32>) -> vector<2x2xi32> {
-  %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs, %rhs, %acc : vector<2x8xi32>, vector<2x8xi32> into vector<2x2xi32>
+// CHECK-LABEL: vector_arm_neon_unsigned_signed
+// CHECK-SAME:    %[[LHS:.+]]: vector<2x8xi8>, %[[RHS:.+]]: vector<2x8xi8>, %[[ACC:.+]]: vector<2x2xi32>
+// CHECK:       %[[L:.+]] = vector.shape_cast %[[LHS]] : vector<2x8xi8> to vector<16xi8>
+// CHECK:       %[[R:.+]] = vector.shape_cast %[[RHS]] : vector<2x8xi8> to vector<16xi8>
+// CHECK:       %[[A:.+]] = vector.shape_cast %[[ACC]] : vector<2x2xi32> to vector<4xi32>
+// CHECK:       %[[M:.+]] = arm_neon.intr.usmmla %[[A]], %[[L]], %[[R]] : vector<16xi8> to vector<4xi32>
+// CHECK:       %{{.+}} = vector.shape_cast %[[M]] : vector<4xi32> to vector<2x2xi32>
+func.func @vector_arm_neon_unsigned_signed(%lhs: vector<2x8xi8>, %rhs: vector<2x8xi8>, %acc : vector<2x2xi32>) -> vector<2x2xi32> {
+  %lhs_extsi = arith.extui %lhs : vector<2x8xi8> to vector<2x8xi32>
+  %rhs_extsi = arith.extsi %rhs : vector<2x8xi8> to vector<2x8xi32>
+  %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs_extsi, %rhs_extsi, %acc : vector<2x8xi32>, vector<2x8xi32> into vector<2x2xi32>
+  return %res : vector<2x2xi32>
+}
+
+// -----
+
+// CHECK-LABEL: vector_arm_neon_unsigned_unsigned
+// CHECK-SAME:    %[[LHS:.+]]: vector<2x8xi8>, %[[RHS:.+]]: vector<2x8xi8>, %[[ACC:.+]]: vector<2x2xi32>
+// CHECK:       %[[L:.+]] = vector.shape_cast %[[LHS]] : vector<2x8xi8> to vector<16xi8>
+// CHECK:       %[[R:.+]] = vector.shape_cast %[[RHS]] : vector<2x8xi8> to vector<16xi8>
+// CHECK:       %[[A:.+]] = vector.shape_cast %[[ACC]] : vector<2x2xi32> to vector<4xi32>
+// CHECK:       %[[M:.+]] = arm_neon.intr.ummla %[[A]], %[[L]], %[[R]] : vector<16xi8> to vector<4xi32>
+// CHECK:       %{{.+}} = vector.shape_cast %[[M]] : vector<4xi32> to vector<2x2xi32>
+func.func @vector_arm_neon_unsigned_unsigned(%lhs: vector<2x8xi8>, %rhs: vector<2x8xi8>, %acc : vector<2x2xi32>) -> vector<2x2xi32> {
+  %lhs_extsi = arith.extui %lhs : vector<2x8xi8> to vector<2x8xi32>
+  %rhs_extsi = arith.extui %rhs : vector<2x8xi8> to vector<2x8xi32>
+  %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs_extsi, %rhs_extsi, %acc : vector<2x8xi32>, vector<2x8xi32> into vector<2x2xi32>
+  return %res : vector<2x2xi32>
+}
+
+// -----
+
+// CHECK-LABEL: vector_arm_neon_signed_unsigned
+// CHECK-SAME:    %[[LHS:.+]]: vector<2x8xi8>, %[[RHS:.+]]: vector<2x8xi8>, %[[ACC:.+]]: vector<2x2xi32>
+// CHECK:       %[[ACC_T:.+]] = vector.transpose %[[ACC]], [1, 0] : vector<2x2xi32> to vector<2x2xi32>
+// CHECK:       %[[L:.+]] = vector.shape_cast %[[LHS]] : vector<2x8xi8> to vector<16xi8>
+// CHECK:       %[[R:.+]] = vector.shape_cast %[[RHS]] : vector<2x8xi8> to vector<16xi8>
+// CHECK:       %[[A:.+]] = vector.shape_cast %[[ACC_T]] : vector<2x2xi32> to vector<4xi32>
+// CHECK:       %[[M:.+]] = arm_neon.intr.usmmla %[[A]], %[[R]], %[[L]] : vector<16xi8> to vector<4xi32>
+// CHECK:       %[[OUT_T:.+]] = vector.shape_cast %[[M]] : vector<4xi32> to vector<2x2xi32>
+// CHECK:       %{{.+}} = vector.transpose %[[OUT_T]], [1, 0] : vector<2x2xi32> to vector<2x2xi32>
+func.func @vector_arm_neon_signed_unsigned(%lhs: vector<2x8xi8>, %rhs: vector<2x8xi8>, %acc : vector<2x2xi32>) -> vector<2x2xi32> {
+  %lhs_extsi = arith.extsi %lhs : vector<2x8xi8> to vector<2x8xi32>
+  %rhs_extsi = arith.extui %rhs : vector<2x8xi8> to vector<2x8xi32>
+  %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs_extsi, %rhs_extsi, %acc : vector<2x8xi32>, vector<2x8xi32> into vector<2x2xi32>
   return %res : vector<2x2xi32>
 }
 

>From d52eb9bb2e6876c503e0f7a13dd7d3aa8621c76b Mon Sep 17 00:00:00 2001
From: Momchil Velikov <momchil.velikov at arm.com>
Date: Wed, 18 Jun 2025 13:42:28 +0000
Subject: [PATCH 2/3] [fixup] Remove "SMMLA" from some names

... since we generate now all of smmla/ummla/usmmla.
---
 mlir/include/mlir/Dialect/ArmNeon/Transforms.h       |  2 +-
 .../VectorToLLVM/ConvertVectorToLLVMPass.cpp         |  2 +-
 .../TransformOps/ArmNeonVectorTransformOps.cpp       |  2 +-
 mlir/lib/Dialect/ArmNeon/Transforms/CMakeLists.txt   |  2 +-
 ...ern.cpp => LowerContractionToNeonI8MMPattern.cpp} | 12 ++++++------
 .../Transforms/LowerContractionToSVEI8MMPattern.cpp  |  2 +-
 6 files changed, 11 insertions(+), 11 deletions(-)
 rename mlir/lib/Dialect/ArmNeon/Transforms/{LowerContractionToSMMLAPattern.cpp => LowerContractionToNeonI8MMPattern.cpp} (97%)

diff --git a/mlir/include/mlir/Dialect/ArmNeon/Transforms.h b/mlir/include/mlir/Dialect/ArmNeon/Transforms.h
index 52ebea2d0ffd9..2f0f634a96770 100644
--- a/mlir/include/mlir/Dialect/ArmNeon/Transforms.h
+++ b/mlir/include/mlir/Dialect/ArmNeon/Transforms.h
@@ -13,7 +13,7 @@ namespace mlir {
 class RewritePatternSet;
 
 namespace arm_neon {
-void populateLowerContractionToSMMLAPatternPatterns(
+void populateLowerContractionToNeonI8MMPatternPatterns(
     RewritePatternSet &patterns);
 } // namespace arm_neon
 
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 293e01a5bf4d4..67c0eca15638a 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -85,7 +85,7 @@ void ConvertVectorToLLVMPass::runOnOperation() {
     populateVectorGatherLoweringPatterns(patterns);
     if (armI8MM) {
       if (armNeon)
-        arm_neon::populateLowerContractionToSMMLAPatternPatterns(patterns);
+        arm_neon::populateLowerContractionToNeonI8MMPatternPatterns(patterns);
       if (armSVE)
         populateLowerContractionToSVEI8MMPatternPatterns(patterns);
     }
diff --git a/mlir/lib/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.cpp b/mlir/lib/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.cpp
index e81fc6a8b5980..d07e6a52d8b5f 100644
--- a/mlir/lib/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.cpp
+++ b/mlir/lib/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.cpp
@@ -20,7 +20,7 @@ using namespace mlir;
 
 void transform::ApplyArmNeonContractionToI8MMPatternsOp::populatePatterns(
     RewritePatternSet &patterns) {
-  arm_neon::populateLowerContractionToSMMLAPatternPatterns(patterns);
+  arm_neon::populateLowerContractionToNeonI8MMPatternPatterns(patterns);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/ArmNeon/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ArmNeon/Transforms/CMakeLists.txt
index 84fb1b0116d2a..06bafde451cbb 100644
--- a/mlir/lib/Dialect/ArmNeon/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmNeon/Transforms/CMakeLists.txt
@@ -1,5 +1,5 @@
 add_mlir_dialect_library(MLIRArmNeonTransforms
-  LowerContractionToSMMLAPattern.cpp
+  LowerContractionToNeonI8MMPattern.cpp
 
   DEPENDS
   MLIRArmNeonIncGen
diff --git a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToNeonI8MMPattern.cpp
similarity index 97%
rename from mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
rename to mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToNeonI8MMPattern.cpp
index 967aff579227b..9c4f526b13c89 100644
--- a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
+++ b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToNeonI8MMPattern.cpp
@@ -1,4 +1,4 @@
-//===- LowerContractionToSMMLAPattern.cpp - Contract to SMMLA ---*- C++ -*-===//
+//===- LowerContractionToNeonI8MMPattern.cpp - Contract to I8MM -*- C++ -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -6,8 +6,8 @@
 //
 //===----------------------------------------------------------------------===//
 //
-// This file implements lowering patterns from vector.contract to
-// arm_neon.intr.smmla
+// This file implements lowering patterns from vector.contract to operations
+// that map to instructions from the Neon FEAT_I8MM extension.
 //
 //===---
 
@@ -117,7 +117,7 @@ Value createMMLA(PatternRewriter &rewriter, MMLA op, Location loc,
 /// as [2,2,8] is a divisor of its shape. It can also process vecmats with dimM
 /// = 1 (either explicitly or inferred if LHS has only dimK) If no unrolling is
 /// necessary, a single smmla instruction is emitted.
-class LowerContractionToSMMLAPattern
+class LowerContractionToNeonI8MMPattern
     : public OpRewritePattern<vector::ContractionOp> {
 public:
   using OpRewritePattern::OpRewritePattern;
@@ -352,8 +352,8 @@ class LowerContractionToSMMLAPattern
 
 } // namespace
 
-void mlir::arm_neon::populateLowerContractionToSMMLAPatternPatterns(
+void mlir::arm_neon::populateLowerContractionToNeonI8MMPatternPatterns(
     RewritePatternSet &patterns) {
   MLIRContext *context = patterns.getContext();
-  patterns.add<LowerContractionToSMMLAPattern>(context, /*benefit=*/2);
+  patterns.add<LowerContractionToNeonI8MMPattern>(context, /*benefit=*/2);
 }
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp
index b1233c5c06eb4..9a75fee845655 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp
@@ -1,4 +1,4 @@
-//===- LowerContractionToSMMLAPattern.cpp - Contract to SMMLA ---*- C++ -*-===//
+//===- LowerContractionToSVEI8MMPattern.cpp - Contract to I8MM --*- C++ -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.

>From 54f26cef643545d6bfd435c06be89de5307b6435 Mon Sep 17 00:00:00 2001
From: Momchil Velikov <momchil.velikov at arm.com>
Date: Wed, 25 Jun 2025 09:17:16 +0000
Subject: [PATCH 3/3] [fixup] Add links to GitHub issue #145559

---
 .../Transforms/LowerContractionToNeonI8MMPattern.cpp       | 7 ++++++-
 .../ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp | 5 +++++
 2 files changed, 11 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToNeonI8MMPattern.cpp b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToNeonI8MMPattern.cpp
index 9c4f526b13c89..7180884c77e98 100644
--- a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToNeonI8MMPattern.cpp
+++ b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToNeonI8MMPattern.cpp
@@ -9,7 +9,12 @@
 // This file implements lowering patterns from vector.contract to operations
 // that map to instructions from the Neon FEAT_I8MM extension.
 //
-//===---
+// TODO: There may be opportunities to unify this with a similar pattern
+// for SVE. See:
+//   https://github.com/llvm/llvm-project/issues/145559
+//   LowerContractionToSVEI8MMPattern.cpp
+//
+//===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp
index 9a75fee845655..f03534a81ed20 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp
@@ -9,6 +9,11 @@
 // This file implements lowering patterns from vector.contract to operations
 // that map to instructions from the SVE FEAT_I8MM extension.
 //
+// TODO: There may be opportunities to unify this with a similar pattern
+// for Neon. See:
+//   https://github.com/llvm/llvm-project/issues/145559
+//   LowerContractionToNeonI8MMPattern.cpp
+//
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Arith/IR/Arith.h"



More information about the Mlir-commits mailing list