[Mlir-commits] [mlir] 9586121 - [mlir][ArmSVE] Add masked arithmetic operations

Javier Setoain llvmlistbot at llvm.org
Wed May 5 09:47:18 PDT 2021


Author: Javier Setoain
Date: 2021-05-05T17:41:58+01:00
New Revision: 95861216ac6558dc0dbcf638902feb9072c84661

URL: https://github.com/llvm/llvm-project/commit/95861216ac6558dc0dbcf638902feb9072c84661
DIFF: https://github.com/llvm/llvm-project/commit/95861216ac6558dc0dbcf638902feb9072c84661.diff

LOG: [mlir][ArmSVE] Add masked arithmetic operations

These operations map to SVE-specific intrinsics that accept a
predicate operand to support control flow in vector code.

Differential Revision: https://reviews.llvm.org/D100982

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td
    mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp
    mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
    mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
    mlir/test/Dialect/ArmSVE/roundtrip.mlir
    mlir/test/Target/LLVMIR/arm-sve.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td
index 33c60ba7c8a5..e34177bb5094 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td
+++ b/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td
@@ -95,6 +95,13 @@ def ScalableVectorType : ArmSVE_Type<"ScalableVector"> {
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// Additional LLVM type constraints
+//===----------------------------------------------------------------------===//
+def LLVMScalableVectorType :
+  Type<CPred<"$_self.isa<::mlir::LLVM::LLVMScalableVectorType>()">,
+       "LLVM dialect scalable vector type">;
+
 //===----------------------------------------------------------------------===//
 // ArmSVE op definitions
 //===----------------------------------------------------------------------===//
@@ -158,6 +165,52 @@ class ScalableIOp<string mnemonic, string op_description,
     "$src1 `,` $src2 attr-dict `:` type($src1)";
 }
 
+class ScalableMaskedFOp<string mnemonic, string op_description,
+                        list<OpTrait> traits = []> :
+  ArmSVE_Op<mnemonic, !listconcat(traits,
+                       [AllTypesMatch<["src1", "src2", "res"]>,
+                        TypesMatchWith<
+                          "mask has i1 element type and same shape as operands",
+                          "src1", "mask", "getI1SameShape($_self)">])> {
+  let summary = "masked " # op_description # " for scalable vectors of floats";
+  let description = [{
+    The `arm_sve.}] # mnemonic # [{` operation takes one scalable vector mask
+    and two scalable vector operands, and performs floating point }] #
+    op_description # [{ on active lanes. Inactive lanes will keep the value of
+    the first operand.}];
+  let arguments = (ins
+          ScalableVectorOf<[I1]>:$mask,
+          ScalableVectorOf<[AnyFloat]>:$src1,
+          ScalableVectorOf<[AnyFloat]>:$src2
+  );
+  let results = (outs ScalableVectorOf<[AnyFloat]>:$res);
+  let assemblyFormat =
+    "$mask `,` $src1 `,` $src2 attr-dict `:` type($mask) `,` type($res)";
+}
+
+class ScalableMaskedIOp<string mnemonic, string op_description,
+                        list<OpTrait> traits = []> :
+  ArmSVE_Op<mnemonic, !listconcat(traits,
+                       [AllTypesMatch<["src1", "src2", "res"]>,
+                        TypesMatchWith<
+                          "mask has i1 element type and same shape as operands",
+                          "src1", "mask", "getI1SameShape($_self)">])> {
+  let summary = "masked " # op_description # " for scalable vectors of integers";
+  let description = [{
+    The `arm_sve.}] # mnemonic # [{` operation takes one scalable vector mask
+    and two scalable vector operands, and performs integer }] #
+    op_description # [{ on active lanes. Inactive lanes will keep the value of
+    the first operand.}];
+  let arguments = (ins
+          ScalableVectorOf<[I1]>:$mask,
+          ScalableVectorOf<[I8, I16, I32, I64]>:$src1,
+          ScalableVectorOf<[I8, I16, I32, I64]>:$src2
+  );
+  let results = (outs ScalableVectorOf<[I8, I16, I32, I64]>:$res);
+  let assemblyFormat =
+    "$mask `,` $src1 `,` $src2 attr-dict `:` type($mask) `,` type($res)";
+}
+
 def SdotOp : ArmSVE_Op<"sdot",
                [NoSideEffect,
                AllTypesMatch<["src1", "src2"]>,
@@ -321,21 +374,94 @@ def ScalableUDivIOp : ScalableIOp<"divi_unsigned", "unsigned division">;
 
 def ScalableDivFOp : ScalableFOp<"divf", "division">;
 
+// Masked ops must not be `Commutative`: inactive lanes keep the value of
+// `src1`, so swapping `src1` and `src2` changes the result.
+def ScalableMaskedAddIOp : ScalableMaskedIOp<"masked.addi", "addition">;
+
+def ScalableMaskedAddFOp : ScalableMaskedFOp<"masked.addf", "addition">;
+
+def ScalableMaskedSubIOp : ScalableMaskedIOp<"masked.subi", "subtraction">;
+
+def ScalableMaskedSubFOp : ScalableMaskedFOp<"masked.subf", "subtraction">;
+
+def ScalableMaskedMulIOp : ScalableMaskedIOp<"masked.muli",
+                                             "multiplication">;
+
+def ScalableMaskedMulFOp : ScalableMaskedFOp<"masked.mulf",
+                                             "multiplication">;
+
+def ScalableMaskedSDivIOp : ScalableMaskedIOp<"masked.divi_signed",
+                                              "signed division">;
+
+def ScalableMaskedUDivIOp : ScalableMaskedIOp<"masked.divi_unsigned",
+                                              "unsigned division">;
+
+def ScalableMaskedDivFOp : ScalableMaskedFOp<"masked.divf", "division">;
+
 def UmmlaIntrOp :
   ArmSVE_IntrBinaryOverloadedOp<"ummla">,
-  Arguments<(ins LLVM_AnyVector, LLVM_AnyVector, LLVM_AnyVector)>;
+  Arguments<(ins LLVMScalableVectorType, LLVMScalableVectorType,
+             LLVMScalableVectorType)>;
 
 def SmmlaIntrOp :
   ArmSVE_IntrBinaryOverloadedOp<"smmla">,
-  Arguments<(ins LLVM_AnyVector, LLVM_AnyVector, LLVM_AnyVector)>;
+  Arguments<(ins LLVMScalableVectorType, LLVMScalableVectorType,
+             LLVMScalableVectorType)>;
 
 def SdotIntrOp :
   ArmSVE_IntrBinaryOverloadedOp<"sdot">,
-  Arguments<(ins LLVM_AnyVector, LLVM_AnyVector, LLVM_AnyVector)>;
+  Arguments<(ins LLVMScalableVectorType, LLVMScalableVectorType,
+             LLVMScalableVectorType)>;
 
 def UdotIntrOp :
   ArmSVE_IntrBinaryOverloadedOp<"udot">,
-  Arguments<(ins LLVM_AnyVector, LLVM_AnyVector, LLVM_AnyVector)>;
+  Arguments<(ins LLVMScalableVectorType, LLVMScalableVectorType,
+             LLVMScalableVectorType)>;
+
+def ScalableMaskedAddIIntrOp :
+  ArmSVE_IntrBinaryOverloadedOp<"add">,
+  Arguments<(ins LLVMScalableVectorType, LLVMScalableVectorType,
+             LLVMScalableVectorType)>;
+
+def ScalableMaskedAddFIntrOp :
+  ArmSVE_IntrBinaryOverloadedOp<"fadd">,
+  Arguments<(ins LLVMScalableVectorType, LLVMScalableVectorType,
+             LLVMScalableVectorType)>;
+
+def ScalableMaskedMulIIntrOp :
+  ArmSVE_IntrBinaryOverloadedOp<"mul">,
+  Arguments<(ins LLVMScalableVectorType, LLVMScalableVectorType,
+             LLVMScalableVectorType)>;
+
+def ScalableMaskedMulFIntrOp :
+  ArmSVE_IntrBinaryOverloadedOp<"fmul">,
+  Arguments<(ins LLVMScalableVectorType, LLVMScalableVectorType,
+             LLVMScalableVectorType)>;
+
+def ScalableMaskedSubIIntrOp :
+  ArmSVE_IntrBinaryOverloadedOp<"sub">,
+  Arguments<(ins LLVMScalableVectorType, LLVMScalableVectorType,
+             LLVMScalableVectorType)>;
+
+def ScalableMaskedSubFIntrOp :
+  ArmSVE_IntrBinaryOverloadedOp<"fsub">,
+  Arguments<(ins LLVMScalableVectorType, LLVMScalableVectorType,
+             LLVMScalableVectorType)>;
+
+def ScalableMaskedSDivIIntrOp :
+  ArmSVE_IntrBinaryOverloadedOp<"sdiv">,
+  Arguments<(ins LLVMScalableVectorType, LLVMScalableVectorType,
+             LLVMScalableVectorType)>;
+
+def ScalableMaskedUDivIIntrOp :
+  ArmSVE_IntrBinaryOverloadedOp<"udiv">,
+  Arguments<(ins LLVMScalableVectorType, LLVMScalableVectorType,
+             LLVMScalableVectorType)>;
+
+def ScalableMaskedDivFIntrOp :
+  ArmSVE_IntrBinaryOverloadedOp<"fdiv">,
+  Arguments<(ins LLVMScalableVectorType, LLVMScalableVectorType,
+             LLVMScalableVectorType)>;
 
 def VectorScaleIntrOp:
   ArmSVE_NonSVEIntrUnaryOverloadedOp<"vscale">;

diff  --git a/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp b/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp
index 6091626c011c..b86ba14303f8 100644
--- a/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp
+++ b/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp
@@ -21,6 +21,8 @@
 
 using namespace mlir;
 
+static Type getI1SameShape(Type type);
+
 #define GET_OP_CLASSES
 #include "mlir/Dialect/ArmSVE/ArmSVE.cpp.inc"
 
@@ -59,3 +61,16 @@ void arm_sve::ArmSVEDialect::printType(Type type, DialectAsmPrinter &os) const {
   if (failed(generatedTypePrinter(type, os)))
     llvm_unreachable("unexpected 'arm_sve' type kind");
 }
+
+//===----------------------------------------------------------------------===//
+// ScalableVector versions of general helpers used to derive mask types
+//===----------------------------------------------------------------------===//
+
+// Return an i1 scalable vector with the same shape as `type`, else null.
+static Type getI1SameShape(Type type) {
+  auto i1Type = IntegerType::get(type.getContext(), 1);
+  if (auto sVectorType = type.dyn_cast<arm_sve::ScalableVectorType>())
+    return arm_sve::ScalableVectorType::get(type.getContext(),
+                                            sVectorType.getShape(), i1Type);
+  return nullptr;
+}

diff  --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
index b258f2ad9315..845f407fba3f 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
@@ -83,6 +83,34 @@ using UdotOpLowering = OneToOneConvertToLLVMPattern<UdotOp, UdotIntrOp>;
 using UmmlaOpLowering = OneToOneConvertToLLVMPattern<UmmlaOp, UmmlaIntrOp>;
 using VectorScaleOpLowering =
     OneToOneConvertToLLVMPattern<VectorScaleOp, VectorScaleIntrOp>;
+// Masked arithmetic ops lower 1:1 onto their predicated SVE intrinsics.
+using ScalableMaskedAddIOpLowering =
+    OneToOneConvertToLLVMPattern<ScalableMaskedAddIOp,
+                                 ScalableMaskedAddIIntrOp>;
+using ScalableMaskedAddFOpLowering =
+    OneToOneConvertToLLVMPattern<ScalableMaskedAddFOp,
+                                 ScalableMaskedAddFIntrOp>;
+using ScalableMaskedSubIOpLowering =
+    OneToOneConvertToLLVMPattern<ScalableMaskedSubIOp,
+                                 ScalableMaskedSubIIntrOp>;
+using ScalableMaskedSubFOpLowering =
+    OneToOneConvertToLLVMPattern<ScalableMaskedSubFOp,
+                                 ScalableMaskedSubFIntrOp>;
+using ScalableMaskedMulIOpLowering =
+    OneToOneConvertToLLVMPattern<ScalableMaskedMulIOp,
+                                 ScalableMaskedMulIIntrOp>;
+using ScalableMaskedMulFOpLowering =
+    OneToOneConvertToLLVMPattern<ScalableMaskedMulFOp,
+                                 ScalableMaskedMulFIntrOp>;
+using ScalableMaskedSDivIOpLowering =
+    OneToOneConvertToLLVMPattern<ScalableMaskedSDivIOp,
+                                 ScalableMaskedSDivIIntrOp>;
+using ScalableMaskedUDivIOpLowering =
+    OneToOneConvertToLLVMPattern<ScalableMaskedUDivIOp,
+                                 ScalableMaskedUDivIIntrOp>;
+using ScalableMaskedDivFOpLowering =
+    OneToOneConvertToLLVMPattern<ScalableMaskedDivFOp,
+                                 ScalableMaskedDivFIntrOp>;
 
 static void
 populateBasicSVEArithmeticExportPatterns(LLVMTypeConverter &converter,
@@ -136,16 +164,52 @@ void mlir::populateArmSVELegalizeForLLVMExportPatterns(
                SmmlaOpLowering,
                UdotOpLowering,
                UmmlaOpLowering,
-               VectorScaleOpLowering>(converter);
+               VectorScaleOpLowering,
+               ScalableMaskedAddIOpLowering,
+               ScalableMaskedAddFOpLowering,
+               ScalableMaskedSubIOpLowering,
+               ScalableMaskedSubFOpLowering,
+               ScalableMaskedMulIOpLowering,
+               ScalableMaskedMulFOpLowering,
+               ScalableMaskedSDivIOpLowering,
+               ScalableMaskedUDivIOpLowering,
+               ScalableMaskedDivFOpLowering>(converter);
   // clang-format on
   populateBasicSVEArithmeticExportPatterns(converter, patterns);
 }
 
 void mlir::configureArmSVELegalizeForExportTarget(
     LLVMConversionTarget &target) {
-  target.addLegalOp<SdotIntrOp, SmmlaIntrOp, UdotIntrOp, UmmlaIntrOp,
-                    VectorScaleIntrOp>();
-  target.addIllegalOp<SdotOp, SmmlaOp, UdotOp, UmmlaOp, VectorScaleOp>();
+  // clang-format off
+  target.addLegalOp<SdotIntrOp,
+                    SmmlaIntrOp,
+                    UdotIntrOp,
+                    UmmlaIntrOp,
+                    VectorScaleIntrOp,
+                    ScalableMaskedAddIIntrOp,
+                    ScalableMaskedAddFIntrOp,
+                    ScalableMaskedSubIIntrOp,
+                    ScalableMaskedSubFIntrOp,
+                    ScalableMaskedMulIIntrOp,
+                    ScalableMaskedMulFIntrOp,
+                    ScalableMaskedSDivIIntrOp,
+                    ScalableMaskedUDivIIntrOp,
+                    ScalableMaskedDivFIntrOp>();
+  target.addIllegalOp<SdotOp,
+                      SmmlaOp,
+                      UdotOp,
+                      UmmlaOp,
+                      VectorScaleOp,
+                      ScalableMaskedAddIOp,
+                      ScalableMaskedAddFOp,
+                      ScalableMaskedSubIOp,
+                      ScalableMaskedSubFOp,
+                      ScalableMaskedMulIOp,
+                      ScalableMaskedMulFOp,
+                      ScalableMaskedSDivIOp,
+                      ScalableMaskedUDivIOp,
+                      ScalableMaskedDivFOp>();
+  // clang-format on
   auto hasScalableVectorType = [](TypeRange types) {
     for (Type type : types)
       if (type.isa<arm_sve::ScalableVectorType>())

diff  --git a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
index f81196f4928f..2b2eda0bf32e 100644
--- a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
@@ -55,7 +55,7 @@ func @arm_sve_arithi(%a: !arm_sve.vector<4xi32>,
   %3 = arm_sve.divi_signed %2, %e : !arm_sve.vector<4xi32>
   // CHECK: llvm.udiv {{.*}}: !llvm.vec<? x 4 x i32>
   %4 = arm_sve.divi_unsigned %2, %e : !arm_sve.vector<4xi32>
-  return %3 : !arm_sve.vector<4xi32>
+  return %4 : !arm_sve.vector<4xi32>
 }
 
 func @arm_sve_arithf(%a: !arm_sve.vector<4xf32>,
@@ -74,6 +74,55 @@ func @arm_sve_arithf(%a: !arm_sve.vector<4xf32>,
   return %3 : !arm_sve.vector<4xf32>
 }
 
+// CHECK-LABEL: @arm_sve_arithi_masked
+func @arm_sve_arithi_masked(%a: !arm_sve.vector<4xi32>,
+                            %b: !arm_sve.vector<4xi32>,
+                            %c: !arm_sve.vector<4xi32>,
+                            %d: !arm_sve.vector<4xi32>,
+                            %e: !arm_sve.vector<4xi32>,
+                            %mask: !arm_sve.vector<4xi1>
+                            ) -> !arm_sve.vector<4xi32> {
+  // CHECK: arm_sve.intr.add{{.*}}: (!llvm.vec<? x 4 x i1>, !llvm.vec<? x 4 x i32>, !llvm.vec<? x 4 x i32>) -> !llvm.vec<? x 4 x i32>
+  %0 = arm_sve.masked.addi %mask, %a, %b : !arm_sve.vector<4xi1>,
+                                           !arm_sve.vector<4xi32>
+  // CHECK: arm_sve.intr.sub{{.*}}: (!llvm.vec<? x 4 x i1>, !llvm.vec<? x 4 x i32>, !llvm.vec<? x 4 x i32>) -> !llvm.vec<? x 4 x i32>
+  %1 = arm_sve.masked.subi %mask, %0, %c : !arm_sve.vector<4xi1>,
+                                           !arm_sve.vector<4xi32>
+  // CHECK: arm_sve.intr.mul{{.*}}: (!llvm.vec<? x 4 x i1>, !llvm.vec<? x 4 x i32>, !llvm.vec<? x 4 x i32>) -> !llvm.vec<? x 4 x i32>
+  %2 = arm_sve.masked.muli %mask, %1, %d : !arm_sve.vector<4xi1>,
+                                           !arm_sve.vector<4xi32>
+  // CHECK: arm_sve.intr.sdiv{{.*}}: (!llvm.vec<? x 4 x i1>, !llvm.vec<? x 4 x i32>, !llvm.vec<? x 4 x i32>) -> !llvm.vec<? x 4 x i32>
+  %3 = arm_sve.masked.divi_signed %mask, %2, %e : !arm_sve.vector<4xi1>,
+                                                  !arm_sve.vector<4xi32>
+  // CHECK: arm_sve.intr.udiv{{.*}}: (!llvm.vec<? x 4 x i1>, !llvm.vec<? x 4 x i32>, !llvm.vec<? x 4 x i32>) -> !llvm.vec<? x 4 x i32>
+  %4 = arm_sve.masked.divi_unsigned %mask, %3, %e : !arm_sve.vector<4xi1>,
+                                                    !arm_sve.vector<4xi32>
+  return %4 : !arm_sve.vector<4xi32>
+}
+
+// CHECK-LABEL: @arm_sve_arithf_masked
+func @arm_sve_arithf_masked(%a: !arm_sve.vector<4xf32>,
+                            %b: !arm_sve.vector<4xf32>,
+                            %c: !arm_sve.vector<4xf32>,
+                            %d: !arm_sve.vector<4xf32>,
+                            %e: !arm_sve.vector<4xf32>,
+                            %mask: !arm_sve.vector<4xi1>
+                            ) -> !arm_sve.vector<4xf32> {
+  // CHECK: arm_sve.intr.fadd{{.*}}: (!llvm.vec<? x 4 x i1>, !llvm.vec<? x 4 x f32>, !llvm.vec<? x 4 x f32>) -> !llvm.vec<? x 4 x f32>
+  %0 = arm_sve.masked.addf %mask, %a, %b : !arm_sve.vector<4xi1>,
+                                           !arm_sve.vector<4xf32>
+  // CHECK: arm_sve.intr.fsub{{.*}}: (!llvm.vec<? x 4 x i1>, !llvm.vec<? x 4 x f32>, !llvm.vec<? x 4 x f32>) -> !llvm.vec<? x 4 x f32>
+  %1 = arm_sve.masked.subf %mask, %0, %c : !arm_sve.vector<4xi1>,
+                                           !arm_sve.vector<4xf32>
+  // CHECK: arm_sve.intr.fmul{{.*}}: (!llvm.vec<? x 4 x i1>, !llvm.vec<? x 4 x f32>, !llvm.vec<? x 4 x f32>) -> !llvm.vec<? x 4 x f32>
+  %2 = arm_sve.masked.mulf %mask, %1, %d : !arm_sve.vector<4xi1>,
+                                           !arm_sve.vector<4xf32>
+  // CHECK: arm_sve.intr.fdiv{{.*}}: (!llvm.vec<? x 4 x i1>, !llvm.vec<? x 4 x f32>, !llvm.vec<? x 4 x f32>) -> !llvm.vec<? x 4 x f32>
+  %3 = arm_sve.masked.divf %mask, %2, %e : !arm_sve.vector<4xi1>,
+                                           !arm_sve.vector<4xf32>
+  return %3 : !arm_sve.vector<4xf32>
+}
+
 func @get_vector_scale() -> index {
   // CHECK: arm_sve.vscale
   %0 = arm_sve.vector_scale : index

diff  --git a/mlir/test/Dialect/ArmSVE/roundtrip.mlir b/mlir/test/Dialect/ArmSVE/roundtrip.mlir
index 44cc2fa12217..4666d16f33f2 100644
--- a/mlir/test/Dialect/ArmSVE/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmSVE/roundtrip.mlir
@@ -56,6 +56,53 @@ func @arm_sve_arithf(%a: !arm_sve.vector<4xf32>,
   return %1 : !arm_sve.vector<4xf32>
 }
 
+func @arm_sve_masked_arithi(%a: !arm_sve.vector<4xi32>,
+                            %b: !arm_sve.vector<4xi32>,
+                            %c: !arm_sve.vector<4xi32>,
+                            %d: !arm_sve.vector<4xi32>,
+                            %e: !arm_sve.vector<4xi32>,
+                            %mask: !arm_sve.vector<4xi1>)
+                            -> !arm_sve.vector<4xi32> {
+  // CHECK: arm_sve.masked.muli {{.*}}: !arm_sve.vector<4xi1>, !arm_sve.vector
+  %0 = arm_sve.masked.muli %mask, %a, %b : !arm_sve.vector<4xi1>,
+                                           !arm_sve.vector<4xi32>
+  // CHECK: arm_sve.masked.addi {{.*}}: !arm_sve.vector<4xi1>, !arm_sve.vector
+  %1 = arm_sve.masked.addi %mask, %0, %c : !arm_sve.vector<4xi1>,
+                                           !arm_sve.vector<4xi32>
+  // CHECK: arm_sve.masked.subi {{.*}}: !arm_sve.vector<4xi1>, !arm_sve.vector
+  %2 = arm_sve.masked.subi %mask, %1, %d : !arm_sve.vector<4xi1>,
+                                           !arm_sve.vector<4xi32>
+  // CHECK: arm_sve.masked.divi_signed {{.*}}: !arm_sve.vector<4xi1>, !arm_sve.vector
+  %3 = arm_sve.masked.divi_signed %mask, %2, %e : !arm_sve.vector<4xi1>,
+                                                  !arm_sve.vector<4xi32>
+  // CHECK: arm_sve.masked.divi_unsigned {{.*}}: !arm_sve.vector<4xi1>, !arm_sve.vector
+  %4 = arm_sve.masked.divi_unsigned %mask, %3, %e : !arm_sve.vector<4xi1>,
+                                                    !arm_sve.vector<4xi32>
+  return %4 : !arm_sve.vector<4xi32>
+}
+
+func @arm_sve_masked_arithf(%a: !arm_sve.vector<4xf32>,
+                            %b: !arm_sve.vector<4xf32>,
+                            %c: !arm_sve.vector<4xf32>,
+                            %d: !arm_sve.vector<4xf32>,
+                            %e: !arm_sve.vector<4xf32>,
+                            %mask: !arm_sve.vector<4xi1>)
+                            -> !arm_sve.vector<4xf32> {
+  // CHECK: arm_sve.masked.mulf {{.*}}: !arm_sve.vector<4xi1>, !arm_sve.vector
+  %0 = arm_sve.masked.mulf %mask, %a, %b : !arm_sve.vector<4xi1>,
+                                           !arm_sve.vector<4xf32>
+  // CHECK: arm_sve.masked.addf {{.*}}: !arm_sve.vector<4xi1>, !arm_sve.vector
+  %1 = arm_sve.masked.addf %mask, %0, %c : !arm_sve.vector<4xi1>,
+                                           !arm_sve.vector<4xf32>
+  // CHECK: arm_sve.masked.subf {{.*}}: !arm_sve.vector<4xi1>, !arm_sve.vector
+  %2 = arm_sve.masked.subf %mask, %1, %d : !arm_sve.vector<4xi1>,
+                                           !arm_sve.vector<4xf32>
+  // CHECK: arm_sve.masked.divf {{.*}}: !arm_sve.vector<4xi1>, !arm_sve.vector
+  %3 = arm_sve.masked.divf %mask, %2, %e : !arm_sve.vector<4xi1>,
+                                           !arm_sve.vector<4xf32>
+  return %3 : !arm_sve.vector<4xf32>
+}
+
 func @get_vector_scale() -> index {
   // CHECK: arm_sve.vector_scale : index
   %0 = arm_sve.vector_scale : index

diff  --git a/mlir/test/Target/LLVMIR/arm-sve.mlir b/mlir/test/Target/LLVMIR/arm-sve.mlir
index 71d4b0aee9b4..cf367904f899 100644
--- a/mlir/test/Target/LLVMIR/arm-sve.mlir
+++ b/mlir/test/Target/LLVMIR/arm-sve.mlir
@@ -72,6 +72,73 @@ llvm.func @arm_sve_arithf(%arg0: !llvm.vec<? x 4 x f32>,
   llvm.return %1 : !llvm.vec<? x 4 x f32>
 }
 
+// CHECK-LABEL: define <vscale x 4 x i32> @arm_sve_arithi_masked
+llvm.func @arm_sve_arithi_masked(%arg0: !llvm.vec<? x 4 x i32>,
+                                 %arg1: !llvm.vec<? x 4 x i32>,
+                                 %arg2: !llvm.vec<? x 4 x i32>,
+                                 %arg3: !llvm.vec<? x 4 x i32>,
+                                 %arg4: !llvm.vec<? x 4 x i32>,
+                                 %arg5: !llvm.vec<? x 4 x i1>)
+                                 -> !llvm.vec<? x 4 x i32> {
+  // CHECK: call <vscale x 4 x i32> @llvm.aarch64.sve.add.nxv4i32
+  %0 = "arm_sve.intr.add"(%arg5, %arg0, %arg1) : (!llvm.vec<? x 4 x i1>,
+                                                  !llvm.vec<? x 4 x i32>,
+                                                  !llvm.vec<? x 4 x i32>)
+                                                  -> !llvm.vec<? x 4 x i32>
+  // CHECK: call <vscale x 4 x i32> @llvm.aarch64.sve.sub.nxv4i32
+  %1 = "arm_sve.intr.sub"(%arg5, %0, %arg2) : (!llvm.vec<? x 4 x i1>,
+                                               !llvm.vec<? x 4 x i32>,
+                                               !llvm.vec<? x 4 x i32>)
+                                               -> !llvm.vec<? x 4 x i32>
+  // CHECK: call <vscale x 4 x i32> @llvm.aarch64.sve.mul.nxv4i32
+  %2 = "arm_sve.intr.mul"(%arg5, %1, %arg3) : (!llvm.vec<? x 4 x i1>,
+                                               !llvm.vec<? x 4 x i32>,
+                                               !llvm.vec<? x 4 x i32>)
+                                               -> !llvm.vec<? x 4 x i32>
+  // CHECK: call <vscale x 4 x i32> @llvm.aarch64.sve.sdiv.nxv4i32
+  %3 = "arm_sve.intr.sdiv"(%arg5, %2, %arg4) : (!llvm.vec<? x 4 x i1>,
+                                               !llvm.vec<? x 4 x i32>,
+                                               !llvm.vec<? x 4 x i32>)
+                                               -> !llvm.vec<? x 4 x i32>
+  // CHECK: call <vscale x 4 x i32> @llvm.aarch64.sve.udiv.nxv4i32
+  %4 = "arm_sve.intr.udiv"(%arg5, %3, %arg4) : (!llvm.vec<? x 4 x i1>,
+                                               !llvm.vec<? x 4 x i32>,
+                                               !llvm.vec<? x 4 x i32>)
+                                               -> !llvm.vec<? x 4 x i32>
+  llvm.return %4 : !llvm.vec<? x 4 x i32>
+}
+
+// CHECK-LABEL: define <vscale x 4 x float> @arm_sve_arithf_masked
+llvm.func @arm_sve_arithf_masked(%arg0: !llvm.vec<? x 4 x f32>,
+                                 %arg1: !llvm.vec<? x 4 x f32>,
+                                 %arg2: !llvm.vec<? x 4 x f32>,
+                                 %arg3: !llvm.vec<? x 4 x f32>,
+                                 %arg4: !llvm.vec<? x 4 x f32>,
+                                 %arg5: !llvm.vec<? x 4 x i1>)
+                                 -> !llvm.vec<? x 4 x f32> {
+  // CHECK: call <vscale x 4 x float> @llvm.aarch64.sve.fadd.nxv4f32
+  %0 = "arm_sve.intr.fadd"(%arg5, %arg0, %arg1) : (!llvm.vec<? x 4 x i1>,
+                                                   !llvm.vec<? x 4 x f32>,
+                                                   !llvm.vec<? x 4 x f32>)
+                                                   -> !llvm.vec<? x 4 x f32>
+  // CHECK: call <vscale x 4 x float> @llvm.aarch64.sve.fsub.nxv4f32
+  %1 = "arm_sve.intr.fsub"(%arg5, %0, %arg2) : (!llvm.vec<? x 4 x i1>,
+                                                !llvm.vec<? x 4 x f32>,
+                                                !llvm.vec<? x 4 x f32>)
+                                                -> !llvm.vec<? x 4 x f32>
+  // CHECK: call <vscale x 4 x float> @llvm.aarch64.sve.fmul.nxv4f32
+  %2 = "arm_sve.intr.fmul"(%arg5, %1, %arg3) : (!llvm.vec<? x 4 x i1>,
+                                                !llvm.vec<? x 4 x f32>,
+                                                !llvm.vec<? x 4 x f32>)
+                                                -> !llvm.vec<? x 4 x f32>
+  // CHECK: call <vscale x 4 x float> @llvm.aarch64.sve.fdiv.nxv4f32
+  %3 = "arm_sve.intr.fdiv"(%arg5, %2, %arg4) : (!llvm.vec<? x 4 x i1>,
+                                                !llvm.vec<? x 4 x f32>,
+                                                !llvm.vec<? x 4 x f32>)
+                                                -> !llvm.vec<? x 4 x f32>
+  llvm.return %3 : !llvm.vec<? x 4 x f32>
+}
+
 // CHECK-LABEL: define i64 @get_vector_scale()
 llvm.func @get_vector_scale() -> i64 {
   // CHECK: call i64 @llvm.vscale.i64()


        


More information about the Mlir-commits mailing list