[Mlir-commits] [mlir] [MLIR][ArmSVE] Add lowering of `vector.contract` to SVE `*MMLA` instructions (PR #135359)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Apr 11 05:31:48 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-sve

@llvm/pr-subscribers-mlir-neon

Author: Momchil Velikov (momchil-velikov)

<details>
<summary>Changes</summary>



---

Patch is 93.07 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/135359.diff


21 Files Affected:

- (modified) mlir/include/mlir/Conversion/Passes.td (+4) 
- (modified) mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td (+94-2) 
- (modified) mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h (+3) 
- (modified) mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt (+1) 
- (modified) mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp (+7) 
- (modified) mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp (+4-1) 
- (modified) mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt (+1) 
- (modified) mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp (+8) 
- (added) mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp (+304) 
- (modified) mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir (+55) 
- (modified) mlir/test/Dialect/ArmSVE/roundtrip.mlir (+11) 
- (added) mlir/test/Dialect/Vector/CPU/ArmSVE/vector-smmla.mlir (+94) 
- (added) mlir/test/Dialect/Vector/CPU/ArmSVE/vector-summla.mlir (+85) 
- (added) mlir/test/Dialect/Vector/CPU/ArmSVE/vector-ummla.mlir (+94) 
- (added) mlir/test/Dialect/Vector/CPU/ArmSVE/vector-usmmla.mlir (+95) 
- (added) mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-4x8x4.mlir (+117) 
- (added) mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-8x8x8-vs2.mlir (+159) 
- (added) mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-summla-4x8x4.mlir (+118) 
- (added) mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-ummla-4x8x4.mlir (+119) 
- (added) mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-usmmla-4x8x4.mlir (+117) 
- (modified) mlir/test/Target/LLVMIR/arm-sve.mlir (+46) 


``````````diff
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index bbba495e613b2..930d8b44abca0 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1406,6 +1406,10 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> {
            "bool", /*default=*/"false",
            "Enables the use of ArmSVE dialect while lowering the vector "
        "dialect.">,
+    Option<"armI8MM", "enable-arm-i8mm",
+           "bool", /*default=*/"false",
+           "Enables the use of Arm FEAT_I8MM instructions while lowering "
+           "the vector dialect.">,
     Option<"x86Vector", "enable-x86vector",
            "bool", /*default=*/"false",
            "Enables the use of X86Vector dialect while lowering the vector "
diff --git a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
index cdcf4d8752e87..a678d9b09e66d 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
+++ b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
@@ -61,6 +61,13 @@ class Scalable1DVectorOfLength<int length, list<Type> elementTypes> : ShapedCont
   "a 1-D scalable vector with length " # length,
   "::mlir::VectorType">;
 
+// Any 1-D scalable vector whose minimum length fills exactly one 128-bit
+// segment: [2] x 64-bit, [4] x 32-bit, [8] x 16-bit or [16] x 8-bit elements.
+def SVEVector : AnyTypeOf<[
+  Scalable1DVectorOfLength<2, [I64, F64]>,
+  Scalable1DVectorOfLength<4, [I32, F32]>,
+  Scalable1DVectorOfLength<8, [I16, F16, BF16]>,
+  Scalable1DVectorOfLength<16, [I8]>],
+  "an SVE vector with element size <= 64-bit">;
+
 //===----------------------------------------------------------------------===//
 // ArmSVE op definitions
 //===----------------------------------------------------------------------===//
@@ -72,14 +79,22 @@ class ArmSVE_IntrOp<string mnemonic,
                     list<Trait> traits = [],
                     list<int> overloadedOperands = [],
                     list<int> overloadedResults = [],
-                    int numResults = 1> :
+                    int numResults = 1,
+                    list<int> immArgPositions = [],
+                    list<string> immArgAttrNames = []> :
   LLVM_IntrOpBase</*Dialect dialect=*/ArmSVE_Dialect,
                   /*string opName=*/"intr." # mnemonic,
                   /*string enumName=*/"aarch64_sve_" # !subst(".", "_", mnemonic),
                   /*list<int> overloadedResults=*/overloadedResults,
                   /*list<int> overloadedOperands=*/overloadedOperands,
                   /*list<Trait> traits=*/traits,
-                  /*int numResults=*/numResults>;
+                  /*int numResults=*/numResults,
+                  /*bit requiresAccessGroup=*/0,
+                  /*bit requiresAliasAnalysis=*/0,
+                  /*bit requiresFastmath=*/0,
+                  /*bit requiresOpBundles=*/0,
+                  /*list<int> immArgPositions=*/immArgPositions,
+                  /*list<string> immArgAttrNames=*/immArgAttrNames>;
 
 class ArmSVE_IntrBinaryOverloadedOp<string mnemonic,
                                     list<Trait> traits = []>:
@@ -258,6 +273,34 @@ def UmmlaOp : ArmSVE_Op<"ummla",
     "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
 }
 
+def UsmmlaOp : ArmSVE_Op<"usmmla", [Pure,
+                                    AllTypesMatch<["src1", "src2"]>,
+                                    AllTypesMatch<["acc", "dst"]>]> {
+  let summary = "Matrix-matrix multiply and accumulate op";
+  let description = [{
+    USMMLA: Unsigned by signed integer matrix multiply-accumulate.
+
+    The unsigned by signed integer matrix multiply-accumulate operation
+    multiplies the 2×8 matrix of unsigned 8-bit integer values held in
+    each 128-bit segment of the first source vector by the 8×2 matrix of
+    signed 8-bit integer values in the corresponding segment of the second
+    source vector. The resulting 2×2 widened 32-bit integer matrix product
+    is then added to the 32-bit integer matrix accumulator.
+
+    Source:
+    https://developer.arm.com/documentation/100987/0000
+  }];
+  // Supports (vector<[16]xi8>, vector<[16]xi8>) -> (vector<[4]xi32>)
+  let arguments = (ins
+          ScalableVectorOfLengthAndType<[4], [I32]>:$acc,
+          ScalableVectorOfLengthAndType<[16], [I8]>:$src1,
+          ScalableVectorOfLengthAndType<[16], [I8]>:$src2
+  );
+  let results = (outs ScalableVectorOfLengthAndType<[4], [I32]>:$dst);
+  let assemblyFormat =
+    "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
+}
+
 class SvboolTypeConstraint<string lhsArg, string rhsArg> : TypesMatchWith<
       "expected corresponding svbool type widened to [16]xi1",
       lhsArg, rhsArg,
@@ -509,6 +552,41 @@ def ScalableMaskedUDivIOp : ScalableMaskedIOp<"masked.divi_unsigned",
 
 def ScalableMaskedDivFOp : ScalableMaskedFOp<"masked.divf", "division">;
 
+def DupQLaneOp : ArmSVE_Op<"dupq_lane", [Pure, AllTypesMatch<["src", "dst"]>]> {
+  let summary = "Broadcast indexed 128-bit segment to vector";
+
+  let description = [{
+    This operation fills each 128-bit segment of a vector with the elements
+    from the indexed 128-bit segment of the source vector. If the VL is
+    128 bits, the operation is a NOP.
+
+    Example:
+    ```mlir
+    // VL == 256
+    // %X = [A B C D x x x x]
+    %Y = arm_sve.dupq_lane %X[0] : vector<[4]xi32>
+    // %Y = [A B C D A B C D]
+
+    // %U = [x x x x x x x x A B C D E F G H]
+    %V = arm_sve.dupq_lane %U[1] : vector<[8]xf16>
+    // %V = [A B C D E F G H A B C D E F G H]
+    ```
+  }];
+
+  let arguments = (ins SVEVector:$src,
+                       I64Attr:$lane);
+  let results = (outs SVEVector:$dst);
+
+  let builders = [
+    OpBuilder<(ins "Value":$src, "int64_t":$lane), [{
+      build($_builder, $_state, src.getType(), src, lane);
+    }]>];
+
+  let assemblyFormat = [{
+    $src `[` $lane `]` attr-dict `:` type($dst)
+  }];
+}
+
 def UmmlaIntrOp :
   ArmSVE_IntrBinaryOverloadedOp<"ummla">,
   Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
@@ -517,6 +595,10 @@ def SmmlaIntrOp :
   ArmSVE_IntrBinaryOverloadedOp<"smmla">,
   Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
 
+// Intrinsic-level counterpart of arm_sve.usmmla; UsmmlaOp is lowered to this
+// op one-to-one during legalization for LLVM export.
+def UsmmlaIntrOp :
+  ArmSVE_IntrBinaryOverloadedOp<"usmmla">,
+  Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
+
 def SdotIntrOp :
   ArmSVE_IntrBinaryOverloadedOp<"sdot">,
   Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
@@ -610,4 +692,14 @@ def WhileLTIntrOp :
     /*overloadedResults=*/[0]>,
   Arguments<(ins I64:$base, I64:$n)>;
 
+// Intrinsic op for `aarch64_sve_dupq_lane`. Operand 1 (`lane`) is an
+// attribute that is passed to the intrinsic as an immediate argument.
+def DupQLaneIntrOp : ArmSVE_IntrOp<"dupq_lane",
+    /*traits=*/[],
+    /*overloadedOperands=*/[0],
+    /*overloadedResults=*/[],
+    /*numResults=*/1,
+    /*immArgPositions=*/[1],
+    /*immArgAttrNames=*/["lane"]>,
+    Arguments<(ins Arg<ScalableVectorOfRank<[1]>, "v">:$v,
+                   Arg<I64Attr, "lane">:$lane)>;
+
 #endif // ARMSVE_OPS
diff --git a/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h b/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h
index 8665c8224cc45..232e2be29e574 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h
@@ -20,6 +20,9 @@ class RewritePatternSet;
 void populateArmSVELegalizeForLLVMExportPatterns(
     const LLVMTypeConverter &converter, RewritePatternSet &patterns);
 
+/// Populate `patterns` with patterns that lower `vector.contract` ops whose
+/// operands are sign-/zero-extended from i8 to i32 into the ArmSVE I8MM
+/// operations (smmla, ummla, usmmla).
+void populateLowerContractionToSVEI8MMPatternPatterns(
+    RewritePatternSet &patterns);
+
 /// Configure the target to support lowering ArmSVE ops to ops that map to LLVM
 /// intrinsics.
 void configureArmSVELegalizeForExportTarget(LLVMConversionTarget &target);
diff --git a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
index 330474a718e30..8e2620029c354 100644
--- a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
+++ b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
@@ -35,6 +35,7 @@ add_mlir_conversion_library(MLIRVectorToLLVMPass
   MLIRVectorToLLVM
 
   MLIRArmNeonDialect
+  MLIRArmNeonTransforms
   MLIRArmSVEDialect
   MLIRArmSVETransforms
   MLIRAMXDialect
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 7082b92c95d1d..1e6c8122b1d0e 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -14,6 +14,7 @@
 #include "mlir/Dialect/AMX/Transforms.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
+#include "mlir/Dialect/ArmNeon/Transforms.h"
 #include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
 #include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -82,6 +83,12 @@ void ConvertVectorToLLVMPass::runOnOperation() {
     populateVectorStepLoweringPatterns(patterns);
     populateVectorRankReducingFMAPattern(patterns);
     populateVectorGatherLoweringPatterns(patterns);
+    if (armI8MM) {
+      if (armNeon)
+        arm_neon::populateLowerContractionToSMMLAPatternPatterns(patterns);
+      if (armSVE)
+        populateLowerContractionToSVEI8MMPatternPatterns(patterns);
+    }
     (void)applyPatternsGreedily(getOperation(), std::move(patterns));
   }
 
diff --git a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
index 2a1271dfd6bdf..e807b233aa7aa 100644
--- a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
+++ b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
@@ -56,6 +56,9 @@ class LowerContractionToSMMLAPattern
     // Avoid 0-D vectors and 1-D rhs:
     if (!lhsType.hasRank() || !rhsType.hasRank() || rhsType.getRank() < 2)
       return failure();
+    // Avoid scalable vectors.
+    if (lhsType.isScalable() || rhsType.isScalable())
+      return failure();
     auto dimM = lhsType.getRank() == 1 ? 1 : lhsType.getDimSize(0);
     auto dimN = rhsType.getDimSize(0);
     auto dimK = rhsType.getDimSize(1);
@@ -238,5 +241,5 @@ class LowerContractionToSMMLAPattern
 void mlir::arm_neon::populateLowerContractionToSMMLAPatternPatterns(
     RewritePatternSet &patterns) {
   MLIRContext *context = patterns.getContext();
-  patterns.add<LowerContractionToSMMLAPattern>(context, /*benefit=*/1);
+  // NOTE(review): benefit raised from 1 to 2 without a stated rationale;
+  // presumably so this pattern is preferred over other vector.contract
+  // patterns populated into the same set (ConvertVectorToLLVMPass now adds it
+  // when armI8MM && armNeon). Confirm the intended ordering.
+  patterns.add<LowerContractionToSMMLAPattern>(context, /*benefit=*/2);
 }
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt
index a70c489a51fea..65f98b44b1b69 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_mlir_dialect_library(MLIRArmSVETransforms
   LegalizeForLLVMExport.cpp
   LegalizeVectorStorage.cpp
+  LowerContractionToSVEI8MMPattern.cpp
 
   DEPENDS
   MLIRArmSVEConversionsIncGen
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
index 2bdb640699d03..0bbe4717e0d17 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
@@ -24,6 +24,8 @@ using SdotOpLowering = OneToOneConvertToLLVMPattern<SdotOp, SdotIntrOp>;
 using SmmlaOpLowering = OneToOneConvertToLLVMPattern<SmmlaOp, SmmlaIntrOp>;
 using UdotOpLowering = OneToOneConvertToLLVMPattern<UdotOp, UdotIntrOp>;
 using UmmlaOpLowering = OneToOneConvertToLLVMPattern<UmmlaOp, UmmlaIntrOp>;
+// One-to-one lowerings of the new ArmSVE ops to their intrinsic ops.
+using UsmmlaOpLowering = OneToOneConvertToLLVMPattern<UsmmlaOp, UsmmlaIntrOp>;
+using DupQLaneLowering =
+    OneToOneConvertToLLVMPattern<DupQLaneOp, DupQLaneIntrOp>;
 using ScalableMaskedAddIOpLowering =
     OneToOneConvertToLLVMPattern<ScalableMaskedAddIOp,
                                  ScalableMaskedAddIIntrOp>;
@@ -192,6 +194,8 @@ void mlir::populateArmSVELegalizeForLLVMExportPatterns(
                SmmlaOpLowering,
                UdotOpLowering,
                UmmlaOpLowering,
+               UsmmlaOpLowering,
+               DupQLaneLowering,
                ScalableMaskedAddIOpLowering,
                ScalableMaskedAddFOpLowering,
                ScalableMaskedSubIOpLowering,
@@ -219,6 +223,8 @@ void mlir::configureArmSVELegalizeForExportTarget(
                     SmmlaIntrOp,
                     UdotIntrOp,
                     UmmlaIntrOp,
+                    UsmmlaIntrOp,
+                    DupQLaneIntrOp,
                     ScalableMaskedAddIIntrOp,
                     ScalableMaskedAddFIntrOp,
                     ScalableMaskedSubIIntrOp,
@@ -238,6 +244,8 @@ void mlir::configureArmSVELegalizeForExportTarget(
                       SmmlaOp,
                       UdotOp,
                       UmmlaOp,
+                      UsmmlaOp,
+                      DupQLaneOp,
                       ScalableMaskedAddIOp,
                       ScalableMaskedAddFOp,
                       ScalableMaskedSubIOp,
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp
new file mode 100644
index 0000000000000..c0620c71440bc
--- /dev/null
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp
@@ -0,0 +1,304 @@
+//===- LowerContractionToSVEI8MMPattern.cpp - Contract to I8MM --*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements lowering patterns from vector.contract to
+// SVE I8MM operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
+#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#include "llvm/Support/ErrorHandling.h"
+
+#define DEBUG_TYPE "lower-contract-to-arm-sve-i8mm"
+
+using namespace mlir;
+using namespace mlir::arm_sve;
+
+namespace {
+// Look through an extension op of kind `T` (restricted to arith.extsi or
+// arith.extui by the enable_if) that widens a vector of `i8Ty` elements to a
+// vector of `i32Ty` elements.
+//
+// Returns the pre-extension operand when `v` is defined by such an op, and
+// std::nullopt otherwise (including when `v` has no defining op, e.g. a block
+// argument).
+template <typename T>
+inline std::enable_if_t<(std::is_base_of_v<arith::ExtSIOp, T> ||
+                         std::is_base_of_v<arith::ExtUIOp, T>),
+                        std::optional<Value>>
+extractExtOperand(Value v, Type i8Ty, Type i32Ty) {
+  // True iff `ty` is a vector type whose element type is `elemTy`.
+  auto isVectorOf = [](Type ty, Type elemTy) {
+    auto vecTy = dyn_cast<VectorType>(ty);
+    return vecTy && vecTy.getElementType() == elemTy;
+  };
+
+  auto ext = dyn_cast_or_null<T>(v.getDefiningOp());
+  if (!ext)
+    return std::nullopt;
+
+  Value narrow = ext.getIn();
+  if (isVectorOf(narrow.getType(), i8Ty) && isVectorOf(ext.getType(), i32Ty))
+    return narrow;
+  return std::nullopt;
+}
+
+// Selects which SVE I8MM instruction performs each sub-tile matrix
+// multiply-accumulate, based on the signedness of the extended operands.
+enum class MMLA {
+  // smmla: signed LHS x signed RHS.
+  Signed,
+  // ummla: unsigned LHS x unsigned RHS.
+  Unsigned,
+  // usmmla: unsigned LHS x signed RHS.
+  Mixed,
+  // usmmla with LHS and RHS swapped: signed LHS x unsigned RHS.
+  MixedSwapped
+};
+
+// Create the SVE matrix multiply-accumulate operation selected by `op`,
+// accumulating the product of `lhs` and `rhs` into `acc`, and return its
+// result. `accType` is the result type of the created operation.
+Value createMMLA(PatternRewriter &rewriter, MMLA op, Location loc,
+                 mlir::VectorType accType, Value acc, Value lhs, Value rhs) {
+  switch (op) {
+  case MMLA::Signed:
+    return rewriter.create<arm_sve::SmmlaOp>(loc, accType, acc, lhs, rhs);
+  case MMLA::Unsigned:
+    return rewriter.create<arm_sve::UmmlaOp>(loc, accType, acc, lhs, rhs);
+  case MMLA::Mixed:
+    // usmmla takes the unsigned operand first and the signed one second.
+    return rewriter.create<arm_sve::UsmmlaOp>(loc, accType, acc, lhs, rhs);
+  case MMLA::MixedSwapped:
+    // The accumulator comes transposed and the result will be transposed
+    // later, so all we have to do here is swap the operands.
+    return rewriter.create<arm_sve::UsmmlaOp>(loc, accType, acc, rhs, lhs);
+  }
+  // All enumerators are handled above; this silences -Wreturn-type and traps
+  // on an out-of-range value instead of falling off the end of the function.
+  llvm_unreachable("unknown MMLA kind");
+}
+
+class LowerContractionToSVEI8MMPattern
+    : public OpRewritePattern<vector::ContractionOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(vector::ContractionOp op,
+                                PatternRewriter &rewriter) const override {
+
+    Location loc = op.getLoc();
+    mlir::VectorType lhsType = op.getLhsType();
+    mlir::VectorType rhsType = op.getRhsType();
+
+    // For now handle LHS<Mx8> and RHS<8x[N]> - these are the types we
+    // eventually expect from MMT4D. M and N dimensions must be even and at
+    // least 2.
+    if (!lhsType.hasRank() || lhsType.getRank() != 2 || !rhsType.hasRank() ||
+        rhsType.getRank() != 2)
+      return failure();
+
+    if (lhsType.isScalable() || !rhsType.isScalable())
+      return failure();
+
+    // M, N, and K are the conventional names for matrix dimensions in the
+    // context of matrix multiplication.
+    auto M = lhsType.getDimSize(0);
+    auto N = rhsType.getDimSize(0);
+    auto K = rhsType.getDimSize(1);
+
+    if (lhsType.getDimSize(1) != K || K != 8 || M < 2 || M % 2 != 0 || N < 2 ||
+        N % 2 != 0 || !rhsType.getScalableDims()[0])
+      return failure();
+
+    // Check permutation maps. For now only accept
+    //   lhs: (d0, d1, d2) -> (d0, d2)
+    //   rhs: (d0, d1, d2) -> (d1, d2)
+    //   acc: (d0, d1, d2) -> (d0, d1)
+    // Note: RHS is transposed.
+    if (op.getIndexingMapsArray()[0] !=
+            AffineMap::getMultiDimMapWithTargets(3, ArrayRef{0u, 2u},
+                                                 op.getContext()) ||
+        op.getIndexingMapsArray()[1] !=
+            AffineMap::getMultiDimMapWithTargets(3, ArrayRef{1u, 2u},
+                                                 op.getContext()) ||
+        op.getIndexingMapsArray()[2] !=
+            AffineMap::getMultiDimMapWithTargets(3, ArrayRef{0u, 1u},
+                                                 op.getContext()))
+      return failure();
+
+    // Check iterator types for matrix multiplication.
+    auto itTypes = op.getIteratorTypesArray();
+    if (itTypes.size() != 3 || itTypes[0] != vector::IteratorType::parallel ||
+        itTypes[1] != vector::IteratorType::parallel ||
+        itTypes[2] != vector::IteratorType::reduction)
+      return failure();
+
+    // Check the combining kind is addition.
+    if (op.getKind() != vector::CombiningKind::ADD)
+      return failure();
+
+    // Check the output is a vector of i32 elements.
+    auto outTy = dyn_cast<VectorType>(op.getType());
+    if (!outTy || outTy.getElementType() != rewriter.getI32Type())
+      return failure();
+
+    // Check inputs are sign-/zero- extensions from i8 to i32. Get the values
+    // before the extension. All four signed/unsigned combinations for input
+    // operands are supported, but they are lowered to different operations.
+    // Determina which is the appropriate operation to lower to.
+    MMLA mmlaOp = MMLA::Signed;
+    auto maybeLhs = extractExtOperand<arith::ExtSIOp>(
+        op.getLhs(), rewriter.getI8Type(), rewriter.getI32Type());
+    if (!maybeLhs) {
+      mmlaOp = MMLA::Unsigned;
+      maybeLhs = extractExtOperand<arith::ExtUIOp>(
+          op.getLhs(), rewriter.getI8Type(), rewriter.getI32Type());
+    }
+    if (!maybeLhs)
+      return failure();
+
+    auto maybeRhs = extractExtOperand<arith::ExtSIOp>(
+        op.getRhs(), rewriter.getI8Type(), rewriter.getI32Type());
+    if (maybeRhs) {
+      if (mmlaOp == MMLA::Unsigned)
+        mmlaOp = MMLA::Mixed;
+    } else {
+      if (mmlaOp == MMLA::Signed)
+        mmlaOp = MMLA::MixedSwapped;
+      maybeRhs = extractExtOperand<arith::ExtUIOp>(
+          op.getRhs(), rewriter.getI8Type(), rewriter.getI32Type());
+    }
+    if (!maybeRhs)
+      return failure();
+
+    // One-dimensional vector types for arm_sve.*mmla
+    auto nxv16i8 = VectorType::get(16, rewriter.getI8Type(), {true});
+    auto nxv4i32 = VectorType::get(4, rewriter.getI32Type(), {true});
+
+    // Extract LHS sub-tiles.
+    SmallVector<Value> lhsTile;
+    for (int64_t i = 0; i < M; i += 2) {
+      // Exract two consective rows of the LHS tile.
+      auto r0 = rew...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/135359


More information about the Mlir-commits mailing list