[llvm-branch-commits] [mlir] [MLIR][ArmSVE] Add an ArmSVE dialect operation which maps to svusmmla (PR #135634)
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Mon Apr 14 08:45:24 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-sve
Author: Momchil Velikov (momchil-velikov)
<details>
<summary>Changes</summary>
Supersedes https://github.com/llvm/llvm-project/pull/135358
---
Full diff: https://github.com/llvm/llvm-project/pull/135634.diff
5 Files Affected:
- (modified) mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td (+32)
- (modified) mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp (+4)
- (modified) mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir (+12)
- (modified) mlir/test/Dialect/ArmSVE/roundtrip.mlir (+11)
- (modified) mlir/test/Target/LLVMIR/arm-sve.mlir (+12)
``````````diff
diff --git a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
index 1a59062ccc93d..da2a8f89b4cfd 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
+++ b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
@@ -273,6 +273,34 @@ def UmmlaOp : ArmSVE_Op<"ummla",
"$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
}
+def UsmmlaOp : ArmSVE_Op<"usmmla", [Pure,
+ AllTypesMatch<["src1", "src2"]>,
+ AllTypesMatch<["acc", "dst"]>]> {
+ let summary = "Matrix-matrix multiply and accumulate op";
+ let description = [{
+ USMMLA: Unsigned by signed integer matrix multiply-accumulate.
+
+ The unsigned by signed integer matrix multiply-accumulate operation
+ multiplies the 2×8 matrix of unsigned 8-bit integer values held
+ in the first source vector by the 8×2 matrix of signed 8-bit integer
+ values in the second source vector. The resulting 2×2 widened 32-bit
+ integer matrix product is then added to the 32-bit integer matrix
+ accumulator.
+
+ Source:
+ https://developer.arm.com/documentation/100987/0000
+ }];
+ // Supports (vector<[16]xi8>, vector<[16]xi8>) -> (vector<[4]xi32>)
+ let arguments = (ins
+ ScalableVectorOfLengthAndType<[4], [I32]>:$acc,
+ ScalableVectorOfLengthAndType<[16], [I8]>:$src1,
+ ScalableVectorOfLengthAndType<[16], [I8]>:$src2
+ );
+ let results = (outs ScalableVectorOfLengthAndType<[4], [I32]>:$dst);
+ let assemblyFormat =
+ "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
+}
+
class SvboolTypeConstraint<string lhsArg, string rhsArg> : TypesMatchWith<
"expected corresponding svbool type widened to [16]xi1",
lhsArg, rhsArg,
@@ -568,6 +596,10 @@ def SmmlaIntrOp :
ArmSVE_IntrBinaryOverloadedOp<"smmla">,
Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
+def UsmmlaIntrOp :
+ ArmSVE_IntrBinaryOverloadedOp<"usmmla">,
+ Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
+
def SdotIntrOp :
ArmSVE_IntrBinaryOverloadedOp<"sdot">,
Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
index fe13ed03356b2..b1846e15196fc 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
@@ -24,6 +24,7 @@ using SdotOpLowering = OneToOneConvertToLLVMPattern<SdotOp, SdotIntrOp>;
using SmmlaOpLowering = OneToOneConvertToLLVMPattern<SmmlaOp, SmmlaIntrOp>;
using UdotOpLowering = OneToOneConvertToLLVMPattern<UdotOp, UdotIntrOp>;
using UmmlaOpLowering = OneToOneConvertToLLVMPattern<UmmlaOp, UmmlaIntrOp>;
+using UsmmlaOpLowering = OneToOneConvertToLLVMPattern<UsmmlaOp, UsmmlaIntrOp>;
using DupQLaneLowering =
OneToOneConvertToLLVMPattern<DupQLaneOp, DupQLaneIntrOp>;
using ScalableMaskedAddIOpLowering =
@@ -194,6 +195,7 @@ void mlir::populateArmSVELegalizeForLLVMExportPatterns(
SmmlaOpLowering,
UdotOpLowering,
UmmlaOpLowering,
+ UsmmlaOpLowering,
DupQLaneLowering,
ScalableMaskedAddIOpLowering,
ScalableMaskedAddFOpLowering,
@@ -222,6 +224,7 @@ void mlir::configureArmSVELegalizeForExportTarget(
SmmlaIntrOp,
UdotIntrOp,
UmmlaIntrOp,
+ UsmmlaIntrOp,
DupQLaneIntrOp,
ScalableMaskedAddIIntrOp,
ScalableMaskedAddFIntrOp,
@@ -242,6 +245,7 @@ void mlir::configureArmSVELegalizeForExportTarget(
SmmlaOp,
UdotOp,
UmmlaOp,
+ UsmmlaOp,
DupQLaneOp,
ScalableMaskedAddIOp,
ScalableMaskedAddFOp,
diff --git a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
index 5d044517e0ea8..47587aa26506c 100644
--- a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
@@ -48,6 +48,18 @@ func.func @arm_sve_ummla(%a: vector<[16]xi8>,
// -----
+func.func @arm_sve_usmmla(%a: vector<[16]xi8>,
+ %b: vector<[16]xi8>,
+ %c: vector<[4]xi32>)
+ -> vector<[4]xi32> {
+ // CHECK: arm_sve.intr.usmmla
+ %0 = arm_sve.usmmla %c, %a, %b :
+ vector<[16]xi8> to vector<[4]xi32>
+ return %0 : vector<[4]xi32>
+}
+
+// -----
+
func.func @arm_sve_arithi_masked(%a: vector<[4]xi32>,
%b: vector<[4]xi32>,
%c: vector<[4]xi32>,
diff --git a/mlir/test/Dialect/ArmSVE/roundtrip.mlir b/mlir/test/Dialect/ArmSVE/roundtrip.mlir
index 0f0c5a8575772..64e0cff39eb06 100644
--- a/mlir/test/Dialect/ArmSVE/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmSVE/roundtrip.mlir
@@ -44,6 +44,17 @@ func.func @arm_sve_ummla(%a: vector<[16]xi8>,
// -----
+func.func @arm_sve_usmmla(%a: vector<[16]xi8>,
+ %b: vector<[16]xi8>,
+ %c: vector<[4]xi32>) -> vector<[4]xi32> {
+ // CHECK: arm_sve.usmmla {{.*}}: vector<[16]xi8> to vector<[4]xi3
+ %0 = arm_sve.usmmla %c, %a, %b :
+ vector<[16]xi8> to vector<[4]xi32>
+ return %0 : vector<[4]xi32>
+}
+
+// -----
+
func.func @arm_sve_masked_arithi(%a: vector<[4]xi32>,
%b: vector<[4]xi32>,
%c: vector<[4]xi32>,
diff --git a/mlir/test/Target/LLVMIR/arm-sve.mlir b/mlir/test/Target/LLVMIR/arm-sve.mlir
index ced59eb513b57..4d9b0da611cb0 100644
--- a/mlir/test/Target/LLVMIR/arm-sve.mlir
+++ b/mlir/test/Target/LLVMIR/arm-sve.mlir
@@ -48,6 +48,18 @@ llvm.func @arm_sve_ummla(%arg0: vector<[16]xi8>,
llvm.return %0 : vector<[4]xi32>
}
+// CHECK-LABEL: define <vscale x 4 x i32> @arm_sve_usmmla
+llvm.func @arm_sve_usmmla(%arg0: vector<[16]xi8>,
+ %arg1: vector<[16]xi8>,
+ %arg2: vector<[4]xi32>)
+ -> vector<[4]xi32> {
+ // CHECK: call <vscale x 4 x i32> @llvm.aarch64.sve.usmmla.nxv4i32(<vscale x 4
+ %0 = "arm_sve.intr.usmmla"(%arg2, %arg0, %arg1) :
+ (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>)
+ -> vector<[4]xi32>
+ llvm.return %0 : vector<[4]xi32>
+}
+
// CHECK-LABEL: define <vscale x 4 x i32> @arm_sve_arithi
llvm.func @arm_sve_arithi(%arg0: vector<[4]xi32>,
%arg1: vector<[4]xi32>,
``````````
</details>
https://github.com/llvm/llvm-project/pull/135634
More information about the llvm-branch-commits
mailing list