[Mlir-commits] [mlir] [MLIR][AArch64] Check indexing maps before checking for dimensions compatibility (PR #145702)

llvmlistbot at llvm.org
Wed Jun 25 06:48:04 PDT 2025


llvmbot wrote:



@llvm/pr-subscribers-mlir-sve

Author: Momchil Velikov (momchil-velikov)

<details>
<summary>Changes</summary>

In `LowerContractionToSVEI8MMPattern`, check that the contraction has the expected indexing maps before deciding which operand dimensions must match each other.

For example, with indexing maps like:

    lhs: (m, n, k) -> (m, k)
    rhs: (m, n, k) -> (n, k)
    acc: (m, n, k) -> (m, n)

we want the second `lhs` dimension (columns, `k`) to match the second `rhs` dimension (`k`, since the RHS is transposed), whereas with indexing maps like

    lhs: (m, n, k) -> (m, k)
    rhs: (m, n, k) -> (k, n)
    acc: (m, n, k) -> (m, n)

we want the second `lhs` dimension (columns, `k`) to match the first `rhs` dimension (rows, `k`; canonical matrix multiplication).

Since only the first kind of indexing maps is supported, the patch does not change behaviour in any significant way; it only changes which match-failure message is reported when the pattern would fail to apply anyway.

---
Full diff: https://github.com/llvm/llvm-project/pull/145702.diff


1 File Affected:

- (modified) mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp (+19-19) 


``````````diff
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp
index a1209fe8230e2..70d2e06f48902 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp
@@ -136,11 +136,26 @@ class LowerContractionToSVEI8MMPattern
   LogicalResult matchAndRewrite(vector::ContractionOp op,
                                 PatternRewriter &rewriter) const override {
 
-    Location loc = op.getLoc();
+    // Check permutation maps. For now only accept
+    //   lhs: (d0, d1, d2) -> (d0, d2)
+    //   rhs: (d0, d1, d2) -> (d1, d2)
+    //   acc: (d0, d1, d2) -> (d0, d1)
+    // This corresponds to matrix multiplication with transposed RHS.
+    if (op.getIndexingMapsArray()[0] !=
+            AffineMap::getMultiDimMapWithTargets(3, ArrayRef{0u, 2u},
+                                                 op.getContext()) ||
+        op.getIndexingMapsArray()[1] !=
+            AffineMap::getMultiDimMapWithTargets(3, ArrayRef{1u, 2u},
+                                                 op.getContext()) ||
+        op.getIndexingMapsArray()[2] !=
+            AffineMap::getMultiDimMapWithTargets(3, ArrayRef{0u, 1u},
+                                                 op.getContext()))
+      return rewriter.notifyMatchFailure(op, "non-matching permutation maps");
+
     mlir::VectorType lhsType = op.getLhsType();
     mlir::VectorType rhsType = op.getRhsType();
 
-    // Check the rank the types so we can safely examine their dimensions.
+    // Check the rank of the types so we can safely examine their dimensions.
     if (lhsType.getRank() != 2 || rhsType.getRank() != 2)
       return rewriter.notifyMatchFailure(op, "non-matching operand shape");
 
@@ -159,22 +174,6 @@ class LowerContractionToSVEI8MMPattern
         !rhsType.getScalableDims()[0])
       return rewriter.notifyMatchFailure(op, "non-matching operand shape");
 
-    // Check permutation maps. For now only accept
-    //   lhs: (d0, d1, d2) -> (d0, d2)
-    //   rhs: (d0, d1, d2) -> (d1, d2)
-    //   acc: (d0, d1, d2) -> (d0, d1)
-    // This corresponds to matrix multiplication with transposed RHS.
-    if (op.getIndexingMapsArray()[0] !=
-            AffineMap::getMultiDimMapWithTargets(3, ArrayRef{0u, 2u},
-                                                 op.getContext()) ||
-        op.getIndexingMapsArray()[1] !=
-            AffineMap::getMultiDimMapWithTargets(3, ArrayRef{1u, 2u},
-                                                 op.getContext()) ||
-        op.getIndexingMapsArray()[2] !=
-            AffineMap::getMultiDimMapWithTargets(3, ArrayRef{0u, 1u},
-                                                 op.getContext()))
-      return rewriter.notifyMatchFailure(op, "non-matching permutation maps");
-
     // Check iterator types for matrix multiplication.
     auto itTypes = op.getIteratorTypesArray();
     if (itTypes.size() != 3 || itTypes[0] != vector::IteratorType::parallel ||
@@ -228,6 +227,7 @@ class LowerContractionToSVEI8MMPattern
                                    /*scalableDims=*/{true});
 
     // Extract LHS sub-tiles with logicall shape <2x8>.
+    Location loc = op.getLoc();
     SmallVector<Value> lhsTile;
     for (int64_t i = 0; i < M; i += 2) {
       // Extract two consecutive rows of the LHS tile.
@@ -283,7 +283,7 @@ class LowerContractionToSVEI8MMPattern
       if (mmlaOp == MMLA::MixedSwapped) {
         // We need to swap the positions of the LHS and RHS (since we don't have
         // a signed * unsigned operation), but then each individual 2x2 tile of
-        // the acumulator and (later) the result need to be transposed.
+        // the accumulator and (later) the result need to be transposed.
         accTileVec = rewriter.create<vector::InterleaveOp>(loc, r0, r1);
       } else {
         // Bitcast them to 64-bit elements, so subsequent

``````````

</details>


https://github.com/llvm/llvm-project/pull/145702

