[Mlir-commits] [mlir] [MLIR][ArmSVE] Add an ArmSVE dialect operation which maps to `svusmmla` (PR #135358)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Apr 11 05:31:10 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-sve

Author: Momchil Velikov (momchil-velikov)

<details>
<summary>Changes</summary>



---
Full diff: https://github.com/llvm/llvm-project/pull/135358.diff


5 Files Affected:

- (modified) mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td (+94-2) 
- (modified) mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp (+8) 
- (modified) mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir (+55) 
- (modified) mlir/test/Dialect/ArmSVE/roundtrip.mlir (+11) 
- (modified) mlir/test/Target/LLVMIR/arm-sve.mlir (+46) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
index cdcf4d8752e87..a678d9b09e66d 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
+++ b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
@@ -61,6 +61,13 @@ class Scalable1DVectorOfLength<int length, list<Type> elementTypes> : ShapedCont
   "a 1-D scalable vector with length " # length,
   "::mlir::VectorType">;
 
+def SVEVector : AnyTypeOf<[
+  Scalable1DVectorOfLength<2, [I64, F64]>,
+  Scalable1DVectorOfLength<4, [I32, F32]>,
+  Scalable1DVectorOfLength<8, [I16, F16, BF16]>,
+  Scalable1DVectorOfLength<16, [I8]>],
+  "an SVE vector with element size <= 64-bit">;
+
 //===----------------------------------------------------------------------===//
 // ArmSVE op definitions
 //===----------------------------------------------------------------------===//
@@ -72,14 +79,22 @@ class ArmSVE_IntrOp<string mnemonic,
                     list<Trait> traits = [],
                     list<int> overloadedOperands = [],
                     list<int> overloadedResults = [],
-                    int numResults = 1> :
+                    int numResults = 1,
+                    list<int> immArgPositions = [],
+                    list<string> immArgAttrNames = []> :
   LLVM_IntrOpBase</*Dialect dialect=*/ArmSVE_Dialect,
                   /*string opName=*/"intr." # mnemonic,
                   /*string enumName=*/"aarch64_sve_" # !subst(".", "_", mnemonic),
                   /*list<int> overloadedResults=*/overloadedResults,
                   /*list<int> overloadedOperands=*/overloadedOperands,
                   /*list<Trait> traits=*/traits,
-                  /*int numResults=*/numResults>;
+                  /*int numResults=*/numResults,
+                  /*bit requiresAccessGroup=*/0,
+                  /*bit requiresAliasAnalysis=*/0,
+                  /*bit requiresFastmath=*/0,
+                  /*bit requiresOpBundles=*/0,
+                  /*list<int> immArgPositions=*/immArgPositions,
+                  /*list<string> immArgAttrNames=*/immArgAttrNames>;
 
 class ArmSVE_IntrBinaryOverloadedOp<string mnemonic,
                                     list<Trait> traits = []>:
@@ -258,6 +273,34 @@ def UmmlaOp : ArmSVE_Op<"ummla",
     "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
 }
 
+def UsmmlaOp : ArmSVE_Op<"usmmla", [Pure,
+                                    AllTypesMatch<["src1", "src2"]>,
+                                    AllTypesMatch<["acc", "dst"]>]> {
+  let summary = "Matrix-matrix multiply and accumulate op";
+  let description = [{
+    USMMLA: Unsigned by signed integer matrix multiply-accumulate.
+
+    The unsigned by signed integer matrix multiply-accumulate operation
+    multiplies the 2×8 matrix of unsigned 8-bit integer values held in
+    the first source vector by the 8×2 matrix of signed 8-bit integer
+    values in the second source vector. The resulting 2×2 widened 32-bit
+    integer matrix product is then added to the 32-bit integer matrix
+    accumulator.
+
+    Source:
+    https://developer.arm.com/documentation/100987/0000
+  }];
+  // Supports (vector<16xi8>, vector<16xi8>) -> (vector<4xi32>)
+  let arguments = (ins
+          ScalableVectorOfLengthAndType<[4], [I32]>:$acc,
+          ScalableVectorOfLengthAndType<[16], [I8]>:$src1,
+          ScalableVectorOfLengthAndType<[16], [I8]>:$src2
+  );
+  let results = (outs ScalableVectorOfLengthAndType<[4], [I32]>:$dst);
+  let assemblyFormat =
+    "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
+}
+
 class SvboolTypeConstraint<string lhsArg, string rhsArg> : TypesMatchWith<
       "expected corresponding svbool type widened to [16]xi1",
       lhsArg, rhsArg,
@@ -509,6 +552,41 @@ def ScalableMaskedUDivIOp : ScalableMaskedIOp<"masked.divi_unsigned",
 
 def ScalableMaskedDivFOp : ScalableMaskedFOp<"masked.divf", "division">;
 
+def DupQLaneOp : ArmSVE_Op<"dupq_lane", [Pure, AllTypesMatch<["src", "dst"]>]> {
+  let summary = "Broadcast indexed 128-bit segment to vector";
+
+  let description = [{
+    This operation fills each 128-bit segment of a vector with the elements
+    from the indexed 128-bit segment of the source vector. If the VL is
+    128 bits the operation is a NOP.
+
+    Example:
+    ```mlir
+    // VL == 256
+    // %X = [A B C D x x x x]
+    %Y = arm_sve.dupq_lane %X[0] : vector<[4]xi32>
+    // %Y = [A B C D A B C D]
+
+    // %U = [x x x x x x x x A B C D E F G H]
+    %V = arm_sve.dupq_lane %U[1] : vector<[8]xf16>
+    // %V = [A B C D E F G H A B C D E F G H]
+    ```
+  }];
+
+  let arguments = (ins SVEVector:$src,
+                       I64Attr:$lane);
+  let results = (outs SVEVector:$dst);
+
+  let builders = [
+    OpBuilder<(ins "Value":$src, "int64_t":$lane), [{
+      build($_builder, $_state, src.getType(), src, lane);
+    }]>];
+
+  let assemblyFormat = [{
+    $src `[` $lane `]` attr-dict `:` type($dst)
+  }];
+}
+
 def UmmlaIntrOp :
   ArmSVE_IntrBinaryOverloadedOp<"ummla">,
   Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
@@ -517,6 +595,10 @@ def SmmlaIntrOp :
   ArmSVE_IntrBinaryOverloadedOp<"smmla">,
   Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
 
+def UsmmlaIntrOp :
+  ArmSVE_IntrBinaryOverloadedOp<"usmmla">,
+  Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
+
 def SdotIntrOp :
   ArmSVE_IntrBinaryOverloadedOp<"sdot">,
   Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
@@ -610,4 +692,14 @@ def WhileLTIntrOp :
     /*overloadedResults=*/[0]>,
   Arguments<(ins I64:$base, I64:$n)>;
 
+def DupQLaneIntrOp : ArmSVE_IntrOp<"dupq_lane",
+    /*traits=*/[],
+    /*overloadedOperands=*/[0],
+    /*overloadedResults=*/[],
+    /*numResults=*/1,
+    /*immArgPositions*/[1],
+    /*immArgAttrNames*/["lane"]>,
+    Arguments<(ins Arg<ScalableVectorOfRank<[1]>, "v">:$v,
+                   Arg<I64Attr, "lane">:$lane)>;
+
 #endif // ARMSVE_OPS
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
index 2bdb640699d03..0bbe4717e0d17 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
@@ -24,6 +24,8 @@ using SdotOpLowering = OneToOneConvertToLLVMPattern<SdotOp, SdotIntrOp>;
 using SmmlaOpLowering = OneToOneConvertToLLVMPattern<SmmlaOp, SmmlaIntrOp>;
 using UdotOpLowering = OneToOneConvertToLLVMPattern<UdotOp, UdotIntrOp>;
 using UmmlaOpLowering = OneToOneConvertToLLVMPattern<UmmlaOp, UmmlaIntrOp>;
+using UsmmlaOpLowering = OneToOneConvertToLLVMPattern<UsmmlaOp, UsmmlaIntrOp>;
+using DupQLaneLowering = OneToOneConvertToLLVMPattern<DupQLaneOp, DupQLaneIntrOp>;
 using ScalableMaskedAddIOpLowering =
     OneToOneConvertToLLVMPattern<ScalableMaskedAddIOp,
                                  ScalableMaskedAddIIntrOp>;
@@ -192,6 +194,8 @@ void mlir::populateArmSVELegalizeForLLVMExportPatterns(
                SmmlaOpLowering,
                UdotOpLowering,
                UmmlaOpLowering,
+               UsmmlaOpLowering,
+               DupQLaneLowering,
                ScalableMaskedAddIOpLowering,
                ScalableMaskedAddFOpLowering,
                ScalableMaskedSubIOpLowering,
@@ -219,6 +223,8 @@ void mlir::configureArmSVELegalizeForExportTarget(
                     SmmlaIntrOp,
                     UdotIntrOp,
                     UmmlaIntrOp,
+                    UsmmlaIntrOp,
+                    DupQLaneIntrOp,
                     ScalableMaskedAddIIntrOp,
                     ScalableMaskedAddFIntrOp,
                     ScalableMaskedSubIIntrOp,
@@ -238,6 +244,8 @@ void mlir::configureArmSVELegalizeForExportTarget(
                       SmmlaOp,
                       UdotOp,
                       UmmlaOp,
+                      UsmmlaOp,
+                      DupQLaneOp,
                       ScalableMaskedAddIOp,
                       ScalableMaskedAddFOp,
                       ScalableMaskedSubIOp,
diff --git a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
index bdb69a95a52de..47587aa26506c 100644
--- a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
@@ -48,6 +48,18 @@ func.func @arm_sve_ummla(%a: vector<[16]xi8>,
 
 // -----
 
+func.func @arm_sve_usmmla(%a: vector<[16]xi8>,
+                    %b: vector<[16]xi8>,
+                    %c: vector<[4]xi32>)
+    -> vector<[4]xi32> {
+  // CHECK: arm_sve.intr.usmmla
+  %0 = arm_sve.usmmla %c, %a, %b :
+               vector<[16]xi8> to vector<[4]xi32>
+  return %0 : vector<[4]xi32>
+}
+
+// -----
+
 func.func @arm_sve_arithi_masked(%a: vector<[4]xi32>,
                             %b: vector<[4]xi32>,
                             %c: vector<[4]xi32>,
@@ -271,3 +283,46 @@ func.func @arm_sve_psel_mixed_predicate_types(%p0: vector<[8]xi1>, %p1: vector<[
   %0 = arm_sve.psel %p0, %p1[%index] : vector<[8]xi1>, vector<[16]xi1>
   return %0 : vector<[8]xi1>
 }
+
+// -----
+
+// CHECK-LABEL: @arm_sve_dupq_lane(
+// CHECK-SAME: %[[A0:[a-z0-9]+]]: vector<[16]xi8>
+// CHECK-SAME: %[[A1:[a-z0-9]+]]: vector<[8]xi16>
+// CHECK-SAME: %[[A2:[a-z0-9]+]]: vector<[8]xf16>
+// CHECK-SAME: %[[A3:[a-z0-9]+]]: vector<[8]xbf16>
+// CHECK-SAME: %[[A4:[a-z0-9]+]]: vector<[4]xi32>
+// CHECK-SAME: %[[A5:[a-z0-9]+]]: vector<[4]xf32>
+// CHECK-SAME: %[[A6:[a-z0-9]+]]: vector<[2]xi64>
+// CHECK-SAME: %[[A7:[a-z0-9]+]]: vector<[2]xf64>
+// CHECK-SAME: -> !llvm.struct<(vector<[16]xi8>, vector<[8]xi16>, vector<[8]xf16>, vector<[8]xbf16>, vector<[4]xi32>, vector<[4]xf32>, vector<[2]xi64>, vector<[2]xf64>)> {
+
+// CHECK: "arm_sve.intr.dupq_lane"(%[[A0]]) <{lane = 0 : i64}> : (vector<[16]xi8>) -> vector<[16]xi8>
+// CHECK: "arm_sve.intr.dupq_lane"(%[[A1]]) <{lane = 1 : i64}> : (vector<[8]xi16>) -> vector<[8]xi16>
+// CHECK: "arm_sve.intr.dupq_lane"(%[[A2]]) <{lane = 2 : i64}> : (vector<[8]xf16>) -> vector<[8]xf16>
+// CHECK: "arm_sve.intr.dupq_lane"(%[[A3]]) <{lane = 3 : i64}> : (vector<[8]xbf16>) -> vector<[8]xbf16>
+// CHECK: "arm_sve.intr.dupq_lane"(%[[A4]]) <{lane = 4 : i64}> : (vector<[4]xi32>) -> vector<[4]xi32>
+// CHECK: "arm_sve.intr.dupq_lane"(%[[A5]]) <{lane = 5 : i64}> : (vector<[4]xf32>) -> vector<[4]xf32>
+// CHECK: "arm_sve.intr.dupq_lane"(%[[A6]]) <{lane = 6 : i64}> : (vector<[2]xi64>) -> vector<[2]xi64>
+// CHECK: "arm_sve.intr.dupq_lane"(%[[A7]]) <{lane = 7 : i64}> : (vector<[2]xf64>) -> vector<[2]xf64>
+func.func @arm_sve_dupq_lane(
+  %v16i8: vector<[16]xi8>, %v8i16: vector<[8]xi16>,
+  %v8f16: vector<[8]xf16>, %v8bf16: vector<[8]xbf16>,
+  %v4i32: vector<[4]xi32>, %v4f32: vector<[4]xf32>,
+  %v2i64: vector<[2]xi64>, %v2f64: vector<[2]xf64>)
+  -> (vector<[16]xi8>, vector<[8]xi16>, vector<[8]xf16>, vector<[8]xbf16>,
+      vector<[4]xi32>, vector<[4]xf32>, vector<[2]xi64>, vector<[2]xf64>) {
+
+  %0 = arm_sve.dupq_lane %v16i8[0]  : vector<[16]xi8>
+  %1 = arm_sve.dupq_lane %v8i16[1]  : vector<[8]xi16>
+  %2 = arm_sve.dupq_lane %v8f16[2]  : vector<[8]xf16>
+  %3 = arm_sve.dupq_lane %v8bf16[3] : vector<[8]xbf16>
+  %4 = arm_sve.dupq_lane %v4i32[4]  : vector<[4]xi32>
+  %5 = arm_sve.dupq_lane %v4f32[5]  : vector<[4]xf32>
+  %6 = arm_sve.dupq_lane %v2i64[6]  : vector<[2]xi64>
+  %7 = arm_sve.dupq_lane %v2f64[7]  : vector<[2]xf64>
+
+  return %0, %1, %2, %3, %4, %5, %6, %7
+    : vector<[16]xi8>, vector<[8]xi16>, vector<[8]xf16>, vector<[8]xbf16>,
+      vector<[4]xi32>, vector<[4]xf32>, vector<[2]xi64>, vector<[2]xf64>
+}
diff --git a/mlir/test/Dialect/ArmSVE/roundtrip.mlir b/mlir/test/Dialect/ArmSVE/roundtrip.mlir
index 0f0c5a8575772..64e0cff39eb06 100644
--- a/mlir/test/Dialect/ArmSVE/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmSVE/roundtrip.mlir
@@ -44,6 +44,17 @@ func.func @arm_sve_ummla(%a: vector<[16]xi8>,
 
 // -----
 
+func.func @arm_sve_usmmla(%a: vector<[16]xi8>,
+                    %b: vector<[16]xi8>,
+                    %c: vector<[4]xi32>) -> vector<[4]xi32> {
+  // CHECK: arm_sve.usmmla {{.*}}: vector<[16]xi8> to vector<[4]xi3
+  %0 = arm_sve.usmmla %c, %a, %b :
+             vector<[16]xi8> to vector<[4]xi32>
+  return %0 : vector<[4]xi32>
+}
+
+// -----
+
 func.func @arm_sve_masked_arithi(%a: vector<[4]xi32>,
                             %b: vector<[4]xi32>,
                             %c: vector<[4]xi32>,
diff --git a/mlir/test/Target/LLVMIR/arm-sve.mlir b/mlir/test/Target/LLVMIR/arm-sve.mlir
index ed5a1fc7ba2e4..2061f938b4c57 100644
--- a/mlir/test/Target/LLVMIR/arm-sve.mlir
+++ b/mlir/test/Target/LLVMIR/arm-sve.mlir
@@ -48,6 +48,18 @@ llvm.func @arm_sve_ummla(%arg0: vector<[16]xi8>,
   llvm.return %0 : vector<[4]xi32>
 }
 
+// CHECK-LABEL: define <vscale x 4 x i32> @arm_sve_usmmla
+llvm.func @arm_sve_usmmla(%arg0: vector<[16]xi8>,
+                         %arg1: vector<[16]xi8>,
+                         %arg2: vector<[4]xi32>)
+                         -> vector<[4]xi32> {
+  // CHECK: call <vscale x 4 x i32> @llvm.aarch64.sve.usmmla.nxv4i32(<vscale x 4
+  %0 = "arm_sve.intr.usmmla"(%arg2, %arg0, %arg1) :
+    (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>)
+        -> vector<[4]xi32>
+  llvm.return %0 : vector<[4]xi32>
+}
+
 // CHECK-LABEL: define <vscale x 4 x i32> @arm_sve_arithi
 llvm.func @arm_sve_arithi(%arg0: vector<[4]xi32>,
                           %arg1: vector<[4]xi32>,
@@ -390,3 +402,37 @@ llvm.func @arm_sve_psel(%pn: vector<[16]xi1>, %p1: vector<[2]xi1>, %p2: vector<[
   "arm_sve.intr.psel"(%pn, %p4, %index) : (vector<[16]xi1>, vector<[16]xi1>, i32) -> vector<[16]xi1>
   llvm.return
 }
+
+// CHECK-LABEL: @arm_sve_dupq_lane
+// CHECK-SAME: <vscale x 16 x i8> %0
+// CHECK-SAME: <vscale x 8 x i16> %1
+// CHECK-SAME: <vscale x 8 x half> %2
+// CHECK-SAME: <vscale x 8 x bfloat> %3
+// CHECK-SAME: <vscale x 4 x i32> %4
+// CHECK-SAME: <vscale x 4 x float> %5
+// CHECK-SAME: <vscale x 2 x i64> %6
+// CHECK-SAME: <vscale x 2 x double> %7
+
+// CHECK: call <vscale x 16 x i8> @llvm.aarch64.sve.dupq.lane.nxv16i8(<vscale x 16 x i8> %0, i64 0)
+// CHECK: call <vscale x 8 x i16> @llvm.aarch64.sve.dupq.lane.nxv8i16(<vscale x 8 x i16> %1, i64 1)
+// CHECK: call <vscale x 8 x half> @llvm.aarch64.sve.dupq.lane.nxv8f16(<vscale x 8 x half> %2, i64 2)
+// CHECK: call <vscale x 8 x bfloat> @llvm.aarch64.sve.dupq.lane.nxv8bf16(<vscale x 8 x bfloat> %3, i64 3)
+// CHECK: call <vscale x 4 x i32> @llvm.aarch64.sve.dupq.lane.nxv4i32(<vscale x 4 x i32> %4, i64 4)
+// CHECK: call <vscale x 4 x float> @llvm.aarch64.sve.dupq.lane.nxv4f32(<vscale x 4 x float> %5, i64 5)
+// CHECK: call <vscale x 2 x i64> @llvm.aarch64.sve.dupq.lane.nxv2i64(<vscale x 2 x i64> %6, i64 6)
+// CHECK: call <vscale x 2 x double> @llvm.aarch64.sve.dupq.lane.nxv2f64(<vscale x 2 x double> %7, i64 7)
+
+llvm.func @arm_sve_dupq_lane(%arg0: vector<[16]xi8>, %arg1: vector<[8]xi16>,
+                             %arg2: vector<[8]xf16>, %arg3: vector<[8]xbf16>,
+                             %arg4: vector<[4]xi32>,%arg5: vector<[4]xf32>,
+                             %arg6: vector<[2]xi64>, %arg7: vector<[2]xf64>) {
+  %0 = "arm_sve.intr.dupq_lane"(%arg0) <{lane = 0 : i64}> : (vector<[16]xi8>) -> vector<[16]xi8>
+  %1 = "arm_sve.intr.dupq_lane"(%arg1) <{lane = 1 : i64}> : (vector<[8]xi16>) -> vector<[8]xi16>
+  %2 = "arm_sve.intr.dupq_lane"(%arg2) <{lane = 2 : i64}> : (vector<[8]xf16>) -> vector<[8]xf16>
+  %3 = "arm_sve.intr.dupq_lane"(%arg3) <{lane = 3 : i64}> : (vector<[8]xbf16>) -> vector<[8]xbf16>
+  %4 = "arm_sve.intr.dupq_lane"(%arg4) <{lane = 4 : i64}> : (vector<[4]xi32>) -> vector<[4]xi32>
+  %5 = "arm_sve.intr.dupq_lane"(%arg5) <{lane = 5 : i64}> : (vector<[4]xf32>) -> vector<[4]xf32>
+  %6 = "arm_sve.intr.dupq_lane"(%arg6) <{lane = 6 : i64}> : (vector<[2]xi64>) -> vector<[2]xi64>
+  %7 = "arm_sve.intr.dupq_lane"(%arg7) <{lane = 7 : i64}> : (vector<[2]xf64>) -> vector<[2]xf64>
+  llvm.return
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/135358


More information about the Mlir-commits mailing list