[Mlir-commits] [mlir] [MLIR][AArch64] Refactor lowering of vector.contract to Neon I8MM (PR #149810)

Andrzej Warzyński llvmlistbot at llvm.org
Mon Jul 21 06:45:34 PDT 2025


================
@@ -85,202 +77,186 @@ std::optional<Value> getExtOperand(Value v) {
   return inOp;
 }
 
-// Designate the operation (resp. instruction) used to do sub-tile matrix
-// multiplications.
-enum class MMLA {
-  Signed,      // smmla
-  Unsigned,    // ummla
-  Mixed,       // usmmla
-  MixedSwapped // usmmla with LHS and RHS swapped
-};
+/// Extend a vector of iN, N < 8, to an i8 vector of the same shape: sign-extend
+/// (`arith.extsi`) if `signExt` is true, zero-extend (`arith.extui`) otherwise.
+/// N must be < 8 because both ops require a strictly wider result type.
+Value extendSmallIntVector(Location loc, VectorType srcTy, Value val,
+                           bool signExt, PatternRewriter &rewriter) {
+  Type targetTy = srcTy.clone(rewriter.getI8Type());
+  return signExt ? rewriter.createOrFold<arith::ExtSIOp>(loc, targetTy, val)
+                 : rewriter.createOrFold<arith::ExtUIOp>(loc, targetTy, val);
+}
 
-// Create the matrix mulitply and accumulate operation according to `op`.
-Value createMMLA(PatternRewriter &rewriter, MMLA op, Location loc,
-                 mlir::Type accType, Value acc, Value lhs, Value rhs) {
-  switch (op) {
-  case MMLA::Signed:
-    return rewriter.createOrFold<arm_neon::SmmlaOp>(loc, accType, acc, lhs,
-                                                    rhs);
-  case MMLA::Unsigned:
-    return rewriter.createOrFold<arm_neon::UmmlaOp>(loc, accType, acc, lhs,
-                                                    rhs);
-  case MMLA::Mixed:
-    return rewriter.createOrFold<arm_neon::UsmmlaOp>(loc, accType, acc, lhs,
-                                                     rhs);
-  case MMLA::MixedSwapped:
-    // The accumulator comes transposed and the result will be transposed
-    // later, so all we have to do here is swap the operands.
-    return rewriter.createOrFold<arm_neon::UsmmlaOp>(loc, accType, acc, rhs,
-                                                     lhs);
+class VectorContractRewriter {
+protected:
+  // Designate the operation (resp. instruction) used to do sub-tile matrix
+  // multiplications.
+  enum class MMLA {
+    Nop,
+    Signed,      // smmla
+    Unsigned,    // ummla
+    Mixed,       // usmmla
+    MixedSwapped // usmmla with LHS and RHS swapped
+  };
+
+  // Lower-level operation to be emitted.
+  MMLA mmlaOp = MMLA::Nop;
+
+  // The operand tiles. These are not necessarily the operands of
+  // `vector.contract`, for example they could be operands to `arith.extsi`
+  // that is in turn fed into `vector.contract`.
+  Value lhs;
+  Value rhs;
+  Value acc;
+
+  // The dimensions logically corresponding to matrix multiplication of
+  // MxK * KxN -> MxN. The operands and the result do not necessarily have these
+  // shapes, for example RHS could be NxK with a transposing indexing map.
+  int64_t dimM = 0;
+  int64_t dimN = 0;
+  int64_t dimK = 0;
+
+  // Unroll iteration bounds. See documentation for `StaticTileOffsetRange`.
+  SmallVector<int64_t> iterationBounds;
+
+  // Sub-tile shape. The algorithm handles operand shapes that are multiples of
+  // this shape.
+  SmallVector<int64_t> subTileShape;
+
+  // Create the matrix multiply and accumulate operation according to `mmlaOp`.
+  Value createMMLA(PatternRewriter &rewriter, Location loc, Value acc,
+                   Value lhs, Value rhs) {
+    switch (mmlaOp) {
+    case MMLA::Signed:
+      return rewriter.createOrFold<arm_neon::SmmlaOp>(loc, acc.getType(), acc,
+                                                      lhs, rhs);
+    case MMLA::Unsigned:
+      return rewriter.createOrFold<arm_neon::UmmlaOp>(loc, acc.getType(), acc,
+                                                      lhs, rhs);
+    case MMLA::Mixed:
+      return rewriter.createOrFold<arm_neon::UsmmlaOp>(loc, acc.getType(), acc,
+                                                       lhs, rhs);
+    case MMLA::MixedSwapped:
+      // The accumulator comes transposed and the result will be transposed
+      // later, so all we have to do here is swap the operands.
+      return rewriter.createOrFold<arm_neon::UsmmlaOp>(loc, acc.getType(), acc,
+                                                       rhs, lhs);
+    case MMLA::Nop:
+      llvm_unreachable("Uninitialized operation type");
+    }
   }
-}
 
-/// Lowering from a vector::contractOp arm neon smmla intrinsic. This will tile
-/// any vector.contract into multiple smmla instructions with unrolling so long
-/// as [2,2,8] is a divisor of its shape. It can also process vecmats with dimM
-/// = 1 (either explicitly or inferred if LHS has only dimK) If no unrolling is
-/// necessary, a single smmla instruction is emitted.
-class LowerContractionToNeonI8MMPattern
-    : public OpRewritePattern<vector::ContractionOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-  LogicalResult matchAndRewrite(vector::ContractionOp op,
-                                PatternRewriter &rewriter) const override {
-    Location loc = op.getLoc();
-    // Infer tile sizes from operands. For vecmat, LHS may only have 1 dim.
-    // Note: RHS is not transposed.
-    mlir::VectorType lhsType = op.getLhsType();
-    mlir::VectorType rhsType = op.getRhsType();
+  // Check common preconditions for applying the patterns and initialize
+  // logical dimensions.
+  LogicalResult matchAndInit(vector::ContractionOp op,
+                             PatternRewriter &rewriter) {
+    // Check iterator types for matrix multiplication.
+    SmallVector<vector::IteratorType> itTypes = op.getIteratorTypesArray();
+    if (!((itTypes.size() == 3 &&
+           (itTypes[0] == vector::IteratorType::parallel &&
+            itTypes[1] == vector::IteratorType::parallel &&
+            itTypes[2] == vector::IteratorType::reduction)) ||
+          (itTypes.size() == 2 &&
+           (itTypes[0] == vector::IteratorType::parallel &&
+            itTypes[1] == vector::IteratorType::reduction))))
+      return rewriter.notifyMatchFailure(
+          op, "iterator types do not correspond to matrix multiplication");
+
     // Avoid 0-D vectors and 1-D rhs:
-    if (!lhsType.hasRank() || !rhsType.hasRank() || rhsType.getRank() < 2)
-      return failure();
+    VectorType lhsType = op.getLhsType();
+    VectorType rhsType = op.getRhsType();
+    if (!lhsType.hasRank() || !rhsType.hasRank() || lhsType.getRank() > 2 ||
+        rhsType.getRank() != 2)
+      return rewriter.notifyMatchFailure(op, "Invalid operand rank");
+
     // This codegen does not work for scalable vectors. Return failure so this
     // pattern is not accidentally chosen over patterns that lower to ArmSVE.
     if (lhsType.isScalable() || rhsType.isScalable())
-      return failure();
-    auto dimM = lhsType.getRank() == 1 ? 1 : lhsType.getDimSize(0);
-    auto dimN = rhsType.getDimSize(0);
-    auto dimK = rhsType.getDimSize(1);
-    bool isVecmat = dimM == 1 ? true : false;
-    if (lhsType.getDimSize(lhsType.getRank() - 1) !=
-        rhsType.getDimSize(rhsType.getRank() - 1)) {
-      return failure(); // dimK mismatch
-    }
-    // Unrolling patterns can handle any [2, 2, 8] shaped multiple of inputs for
-    // tiling.
-    if ((dimM % 2 != 0 && !isVecmat) || dimN % 2 != 0 || dimK % 8 != 0) {
-      return failure();
-    }
-
-    // Check iterator types for contract. All iterators except inner-most
-    // dimension must be parallel.
-    auto iteratorTypes = op.getIteratorTypesArray();
-    if (iteratorTypes.size() > 3 || iteratorTypes[iteratorTypes.size() - 1] !=
-                                        vector::IteratorType::reduction) {
-      return failure();
-    }
-    if (llvm::any_of(ArrayRef<vector::IteratorType>(iteratorTypes).drop_back(1),
-                     [](vector::IteratorType iteratorType) {
-                       return iteratorType != vector::IteratorType::parallel;
-                     })) {
-      return failure();
+      return rewriter.notifyMatchFailure(op,
+                                         "Not applicable to scalable vectors");
+
+    // Initialize dimensions and check for a matching K dimension.
+    dimM = lhsType.getDimSize(0);
+    dimN = rhsType.getDimSize(0);
+    dimK = rhsType.getDimSize(1);
+
+    int64_t lhsDimK;
+    if (lhsType.getRank() == 1) {
+      dimM = 1;
+      lhsDimK = lhsType.getDimSize(0);
+    } else {
+      lhsDimK = lhsType.getDimSize(1);
     }
 
-    // Check inputs are sign-/zero- extensions from iN (N <= 8) to i32. Get the
-    // values before the extension. All four signed/unsigned combinations for
-    // input operands are supported, but they are lowered to different
-    // operations. Determine which is the appropriate operation to lower to.
-    MMLA mmlaOp = MMLA::Signed;
-    auto maybeLhs = getExtOperand<arith::ExtSIOp>(op.getLhs());
-    if (!maybeLhs) {
-      mmlaOp = MMLA::Unsigned;
-      maybeLhs = getExtOperand<arith::ExtUIOp>(op.getLhs());
-    }
-    if (!maybeLhs)
-      return failure();
+    if (lhsDimK != dimK)
+      return rewriter.notifyMatchFailure(op, "Dimensions mismatch");
 
-    auto maybeRhs = getExtOperand<arith::ExtSIOp>(op.getRhs());
-    if (maybeRhs) {
-      if (mmlaOp == MMLA::Unsigned)
-        mmlaOp = MMLA::Mixed;
-    } else {
-      if (mmlaOp == MMLA::Signed)
-        mmlaOp = MMLA::MixedSwapped;
-      maybeRhs = getExtOperand<arith::ExtUIOp>(op.getRhs());
-    }
-    if (!maybeRhs)
-      return failure();
+    return success();
+  }
 
-    Value origLhs = *maybeLhs;
-    Value origRhs = *maybeRhs;
-
-    // Match any iX to i32 for X<8 then turn into an i8 output. Feed into
-    // following neon instruction. Check inputs for extsi are <=i8
-    Value extLhs;
-    Value extRhs;
-    if (auto lhsExtInType = dyn_cast<mlir::VectorType>(origLhs.getType())) {
-      if (lhsExtInType.getElementTypeBitWidth() <= 8) {
-        Type targetLhsExtTy =
-            matchContainerType(rewriter.getI8Type(), lhsExtInType);
-        if (mmlaOp == MMLA::Signed || mmlaOp == MMLA::Mixed)
-          extLhs = rewriter.createOrFold<arith::ExtSIOp>(loc, targetLhsExtTy,
-                                                         origLhs);
-        else
-          extLhs = rewriter.createOrFold<arith::ExtUIOp>(loc, targetLhsExtTy,
-                                                         origLhs);
-      }
-    }
-    if (auto rhsExtInType = dyn_cast<mlir::VectorType>(origRhs.getType())) {
-      if (rhsExtInType.getElementTypeBitWidth() <= 8) {
-        Type targetRhsExtTy =
-            matchContainerType(rewriter.getI8Type(), rhsExtInType);
-        if (mmlaOp == MMLA::Unsigned || mmlaOp == MMLA::Mixed)
-          extRhs = rewriter.createOrFold<arith::ExtUIOp>(loc, targetRhsExtTy,
-                                                         origRhs);
-        else
-          extRhs = rewriter.createOrFold<arith::ExtSIOp>(loc, targetRhsExtTy,
-                                                         origRhs);
-      }
-    }
+public:
+  void rewrite(vector::ContractionOp op, PatternRewriter &rewriter) {
----------------
banach-space wrote:

```suggestion
  void lower(vector::ContractionOp op, PatternRewriter &rewriter) {
```

https://github.com/llvm/llvm-project/pull/149810


More information about the Mlir-commits mailing list