[llvm-branch-commits] [mlir] [MLIR][ArmSVE] Add an ArmSVE dialect operation which maps to svusmmla (PR #135634)

Momchil Velikov via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Tue Apr 15 08:54:34 PDT 2025


https://github.com/momchil-velikov updated https://github.com/llvm/llvm-project/pull/135634

>From 5e91c2eb411cba43794fa7db918e88099885849e Mon Sep 17 00:00:00 2001
From: Momchil Velikov <momchil.velikov at arm.com>
Date: Thu, 10 Apr 2025 14:38:27 +0000
Subject: [PATCH] [MLIR][ArmSVE] Add an ArmSVE dialect operation which maps to
 `svusmmla`

---
 mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td | 95 +++++++++++--------
 .../Transforms/LegalizeForLLVMExport.cpp      |  4 +
 .../Dialect/ArmSVE/legalize-for-llvm.mlir     | 12 +++
 mlir/test/Dialect/ArmSVE/roundtrip.mlir       | 11 +++
 mlir/test/Target/LLVMIR/arm-sve.mlir          | 12 +++
 5 files changed, 96 insertions(+), 38 deletions(-)

diff --git a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
index 3a990f8464ef8..7385bb73b449a 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
+++ b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
@@ -147,11 +147,9 @@ class ScalableMaskedIOp<string mnemonic, string op_description,
     "$mask `,` $src1 `,` $src2 attr-dict `:` type($mask) `,` type($res)";
 }
 
-def SdotOp : ArmSVE_Op<"sdot",
-               [Pure,
-               AllTypesMatch<["src1", "src2"]>,
-               AllTypesMatch<["acc", "dst"]>,
-             ]> {
+def SdotOp : ArmSVE_Op<"sdot", [Pure,
+                                AllTypesMatch<["src1", "src2"]>,
+                                AllTypesMatch<["acc", "dst"]>]> {
   let summary = "Vector-vector dot product and accumulate op";
   let description = [{
     SDOT: Signed integer addition of dot product.
@@ -178,11 +176,9 @@ def SdotOp : ArmSVE_Op<"sdot",
     "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
 }
 
-def SmmlaOp : ArmSVE_Op<"smmla",
-                [Pure,
-                AllTypesMatch<["src1", "src2"]>,
-                AllTypesMatch<["acc", "dst"]>,
-              ]> {
+def SmmlaOp : ArmSVE_Op<"smmla", [Pure,
+                                  AllTypesMatch<["src1", "src2"]>,
+                                  AllTypesMatch<["acc", "dst"]>]> {
   let summary = "Matrix-matrix multiply and accumulate op";
   let description = [{
     SMMLA: Signed integer matrix multiply-accumulate.
@@ -210,11 +206,9 @@ def SmmlaOp : ArmSVE_Op<"smmla",
     "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
 }
 
-def UdotOp : ArmSVE_Op<"udot",
-               [Pure,
-               AllTypesMatch<["src1", "src2"]>,
-               AllTypesMatch<["acc", "dst"]>,
-             ]> {
+def UdotOp : ArmSVE_Op<"udot", [Pure,
+                                AllTypesMatch<["src1", "src2"]>,
+                                AllTypesMatch<["acc", "dst"]>]> {
   let summary = "Vector-vector dot product and accumulate op";
   let description = [{
     UDOT: Unsigned integer addition of dot product.
@@ -241,11 +235,9 @@ def UdotOp : ArmSVE_Op<"udot",
     "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
 }
 
-def UmmlaOp : ArmSVE_Op<"ummla",
-                [Pure,
-                AllTypesMatch<["src1", "src2"]>,
-                AllTypesMatch<["acc", "dst"]>,
-              ]> {
+def UmmlaOp : ArmSVE_Op<"ummla", [Pure,
+                                  AllTypesMatch<["src1", "src2"]>,
+                                  AllTypesMatch<["acc", "dst"]>]> {
   let summary = "Matrix-matrix multiply and accumulate op";
   let description = [{
     UMMLA: Unsigned integer matrix multiply-accumulate.
@@ -273,14 +265,42 @@ def UmmlaOp : ArmSVE_Op<"ummla",
     "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
 }
 
+def UsmmlaOp : ArmSVE_Op<"usmmla", [Pure,
+                                    AllTypesMatch<["src1", "src2"]>,
+                                    AllTypesMatch<["acc", "dst"]>]> {
+  let summary = "Matrix-matrix multiply and accumulate op";
+  let description = [{
+    USMMLA: Unsigned by signed integer matrix multiply-accumulate.
+
+    The unsigned by signed integer matrix multiply-accumulate operation
+    multiplies the 2×8 matrix of unsigned 8-bit integer values held in
+    the first source vector by the 8×2 matrix of signed 8-bit integer
+    values in the second source vector. The resulting 2×2 widened 32-bit
+    integer matrix product is then added to the 32-bit integer matrix
+    accumulator.
+
+    Source:
+    https://developer.arm.com/documentation/100987/0000
+  }];
+  // Supports (vector<[16]xi8>, vector<[16]xi8>) -> (vector<[4]xi32>)
+  let arguments = (ins
+          ScalableVectorOfLengthAndType<[4], [I32]>:$acc,
+          ScalableVectorOfLengthAndType<[16], [I8]>:$src1,
+          ScalableVectorOfLengthAndType<[16], [I8]>:$src2
+  );
+  let results = (outs ScalableVectorOfLengthAndType<[4], [I32]>:$dst);
+  let assemblyFormat =
+    "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
+}
+
 class SvboolTypeConstraint<string lhsArg, string rhsArg> : TypesMatchWith<
       "expected corresponding svbool type widened to [16]xi1",
       lhsArg, rhsArg,
       "VectorType(VectorType::Builder(::llvm::cast<VectorType>($_self)).setDim(::llvm::cast<VectorType>($_self).getRank() - 1, 16))">;
 
 def ConvertFromSvboolOp : ArmSVE_Op<"convert_from_svbool",
-                            [Pure, SvboolTypeConstraint<"result", "source">]>
-{
+                                    [Pure,
+                                     SvboolTypeConstraint<"result", "source">]> {
   let summary = "Convert a svbool type to a SVE predicate type";
   let description = [{
     Converts svbool types (`vector<[16]xi1>` or vectors of that type, e.g.
@@ -313,8 +333,8 @@ def ConvertFromSvboolOp : ArmSVE_Op<"convert_from_svbool",
 }
 
 def ConvertToSvboolOp : ArmSVE_Op<"convert_to_svbool",
-                            [Pure, SvboolTypeConstraint<"source", "result">]>
-{
+                                  [Pure,
+                                   SvboolTypeConstraint<"source", "result">]> {
   let summary = "Convert a SVE predicate type to a svbool type";
   let description = [{
     Converts SVE predicate types (or vectors of predicate types, e.g.
@@ -356,10 +376,9 @@ def ZipInputVectorType : AnyTypeOf<[
   Scalable1DVectorOfLength<16, [I8]>],
   "an SVE vector with element size <= 64-bit">;
 
-def ZipX2Op  : ArmSVE_Op<"zip.x2", [
-  Pure,
-  AllTypesMatch<["sourceV1", "sourceV2", "resultV1", "resultV2"]>]
-> {
+def ZipX2Op : ArmSVE_Op<"zip.x2", [Pure,
+                                   AllTypesMatch<["sourceV1", "sourceV2",
+                                                  "resultV1", "resultV2"]>]> {
   let summary = "Multi-vector two-way zip op";
 
   let description = [{
@@ -400,12 +419,11 @@ def ZipX2Op  : ArmSVE_Op<"zip.x2", [
   }];
 }
 
-def ZipX4Op  : ArmSVE_Op<"zip.x4", [
-  Pure,
-  AllTypesMatch<[
-    "sourceV1", "sourceV2", "sourceV3", "sourceV4",
-    "resultV1", "resultV2", "resultV3", "resultV4"]>]
-> {
+def ZipX4Op
+  : ArmSVE_Op<"zip.x4",
+              [Pure,
+               AllTypesMatch<["sourceV1", "sourceV2", "sourceV3", "sourceV4",
+                              "resultV1", "resultV2", "resultV3", "resultV4"]>]> {
   let summary = "Multi-vector four-way zip op";
 
   let description = [{
@@ -463,10 +481,7 @@ def ZipX4Op  : ArmSVE_Op<"zip.x4", [
   }];
 }
 
-def PselOp : ArmSVE_Op<"psel", [
-  Pure,
-  AllTypesMatch<["p1", "result"]>,
-]> {
+def PselOp : ArmSVE_Op<"psel", [Pure, AllTypesMatch<["p1", "result"]>]> {
   let summary = "Predicate select";
 
   let description = [{
@@ -571,6 +586,10 @@ def SmmlaIntrOp :
   ArmSVE_IntrBinaryOverloadedOp<"smmla">,
   Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
 
+def UsmmlaIntrOp :
+  ArmSVE_IntrBinaryOverloadedOp<"usmmla">,
+  Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>; // (acc, src1, src2)
+
 def SdotIntrOp :
   ArmSVE_IntrBinaryOverloadedOp<"sdot">,
   Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
index 536373b82c67f..35f2a02cc4ec6 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
@@ -24,6 +24,7 @@ using SdotOpLowering = OneToOneConvertToLLVMPattern<SdotOp, SdotIntrOp>;
 using SmmlaOpLowering = OneToOneConvertToLLVMPattern<SmmlaOp, SmmlaIntrOp>;
 using UdotOpLowering = OneToOneConvertToLLVMPattern<UdotOp, UdotIntrOp>;
 using UmmlaOpLowering = OneToOneConvertToLLVMPattern<UmmlaOp, UmmlaIntrOp>;
+using UsmmlaOpLowering = OneToOneConvertToLLVMPattern<UsmmlaOp, UsmmlaIntrOp>;
 using DupQLaneLowering =
     OneToOneConvertToLLVMPattern<DupQLaneOp, DupQLaneIntrOp>;
 using ScalableMaskedAddIOpLowering =
@@ -206,6 +207,7 @@ void mlir::populateArmSVELegalizeForLLVMExportPatterns(
                SmmlaOpLowering,
                UdotOpLowering,
                UmmlaOpLowering,
+               UsmmlaOpLowering,
                ZipX2OpLowering,
                ZipX4OpLowering,
                SdotOpLowering>(converter);
@@ -234,6 +236,7 @@ void mlir::configureArmSVELegalizeForExportTarget(
                     SmmlaIntrOp,
                     UdotIntrOp,
                     UmmlaIntrOp,
+                    UsmmlaIntrOp,
                     WhileLTIntrOp,
                     ZipX2IntrOp,
                     ZipX4IntrOp,
@@ -254,6 +257,7 @@ void mlir::configureArmSVELegalizeForExportTarget(
                       SmmlaOp,
                       UdotOp,
                       UmmlaOp,
+                      UsmmlaOp,
                       ZipX2Op,
                       ZipX4Op,
                       SdotOp>();
diff --git a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
index 650b3e72d4ecd..8c658db009adf 100644
--- a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
@@ -48,6 +48,18 @@ func.func @arm_sve_ummla(%a: vector<[16]xi8>,
 
 // -----
 
+func.func @arm_sve_usmmla(%lhs: vector<[16]xi8>,
+                          %rhs: vector<[16]xi8>,
+                          %acc: vector<[4]xi32>)
+    -> vector<[4]xi32> {
+  // CHECK: arm_sve.intr.usmmla
+  %res = arm_sve.usmmla %acc, %lhs, %rhs :
+    vector<[16]xi8> to vector<[4]xi32>
+  return %res : vector<[4]xi32>
+}
+
+// -----
+
 func.func @arm_sve_arithi_masked(%a: vector<[4]xi32>,
                             %b: vector<[4]xi32>,
                             %c: vector<[4]xi32>,
diff --git a/mlir/test/Dialect/ArmSVE/roundtrip.mlir b/mlir/test/Dialect/ArmSVE/roundtrip.mlir
index 0f0c5a8575772..64e0cff39eb06 100644
--- a/mlir/test/Dialect/ArmSVE/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmSVE/roundtrip.mlir
@@ -44,6 +44,17 @@ func.func @arm_sve_ummla(%a: vector<[16]xi8>,
 
 // -----
 
+func.func @arm_sve_usmmla(%a: vector<[16]xi8>,
+                          %b: vector<[16]xi8>,
+                          %c: vector<[4]xi32>) -> vector<[4]xi32> {
+  // CHECK: arm_sve.usmmla {{.*}}: vector<[16]xi8> to vector<[4]xi32>
+  %0 = arm_sve.usmmla %c, %a, %b :
+             vector<[16]xi8> to vector<[4]xi32>
+  return %0 : vector<[4]xi32>
+}
+
+// -----
+
 func.func @arm_sve_masked_arithi(%a: vector<[4]xi32>,
                             %b: vector<[4]xi32>,
                             %c: vector<[4]xi32>,
diff --git a/mlir/test/Target/LLVMIR/arm-sve.mlir b/mlir/test/Target/LLVMIR/arm-sve.mlir
index 14c68b21fd86c..da71cb5a63bd2 100644
--- a/mlir/test/Target/LLVMIR/arm-sve.mlir
+++ b/mlir/test/Target/LLVMIR/arm-sve.mlir
@@ -48,6 +48,18 @@ llvm.func @arm_sve_ummla(%arg0: vector<[16]xi8>,
   llvm.return %0 : vector<[4]xi32>
 }
 
+// CHECK-LABEL: define <vscale x 4 x i32> @arm_sve_usmmla
+llvm.func @arm_sve_usmmla(%lhs: vector<[16]xi8>,
+                          %rhs: vector<[16]xi8>,
+                          %acc: vector<[4]xi32>)
+                          -> vector<[4]xi32> {
+  // CHECK: call <vscale x 4 x i32> @llvm.aarch64.sve.usmmla.nxv4i32(<vscale x 4
+  %res = "arm_sve.intr.usmmla"(%acc, %lhs, %rhs)
+           : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>)
+             -> vector<[4]xi32>
+  llvm.return %res : vector<[4]xi32>
+}
+
 // CHECK-LABEL: define <vscale x 4 x i32> @arm_sve_arithi
 llvm.func @arm_sve_arithi(%arg0: vector<[4]xi32>,
                           %arg1: vector<[4]xi32>,



More information about the llvm-branch-commits mailing list