[Mlir-commits] [mlir] [mlir][ArmSVE] Add `arm_sve.psel` operation (PR #95764)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Jun 17 03:37:46 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-sve

Author: Benjamin Maxwell (MacDue)

<details>
<summary>Changes</summary>

This adds a new operation for the SME/SVE2 `psel` instruction, which selects a predicate based on a bit within another predicate, essentially allowing for 2-D predication. Informally, the semantics are:

```mlir
%pd = arm_sve.psel %p1, %p2[%index] : vector<[4]xi1>, vector<[8]xi1>
```

=>

```
if p2[index % num_elements(p2)] == 1:
  pd = p1 : type(p1)
else:
  pd = all-false : type(p1)
```

---
Full diff: https://github.com/llvm/llvm-project/pull/95764.diff


6 Files Affected:

- (modified) mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td (+45) 
- (modified) mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp (+25-1) 
- (modified) mlir/test/Dialect/ArmSVE/invalid.mlir (+8) 
- (modified) mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir (+32) 
- (modified) mlir/test/Dialect/ArmSVE/roundtrip.mlir (+29) 
- (modified) mlir/test/Target/LLVMIR/arm-sve.mlir (+19) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
index aea55830c6607..5b98b21720ada 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
+++ b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
@@ -442,6 +442,43 @@ def ZipX4Op  : ArmSVE_Op<"zip.x4", [
   }];
 }
 
+def PselOp : ArmSVE_Op<"psel", [
+  Pure,
+  AllTypesMatch<["p1", "result"]>,
+]> {
+  let summary = "Predicate select";
+
+  let description = [{
+    This operation returns the input predicate `p1` or an all-false predicate
+    based on the bit at `p2[index % num_elements(p2)]`. In pseudocode:
+    ```
+    if p2[index % num_elements(p2)] == 1:
+      return p1 : type(p1)
+    return all-false : type(p1)
+    ```
+
+    Example:
+    ```mlir
+    // Note: p1 and p2 can have different sizes.
+    %pd = arm_sve.psel %p1, %p2[%index] : vector<[4]xi1>, vector<[8]xi1>
+    ```
+
+    Note: This requires SME or SVE2 (`+sme` or `+sve2` in LLVM target features).
+  }];
+
+  let arguments = (ins SVEPredicate:$p1, SVEPredicate:$p2, Index:$index);
+  let results = (outs SVEPredicate:$result);
+
+  let builders = [
+    OpBuilder<(ins "Value":$p1, "Value":$p2, "Value":$index), [{
+      build($_builder, $_state, p1.getType(), p1, p2, index);
+  }]>];
+
+  let assemblyFormat = [{
+    $p1 `,` $p2 `[` $index `]` attr-dict `:` type($p1) `,` type($p2)
+  }];
+}
+
 def ScalableMaskedAddIOp : ScalableMaskedIOp<"masked.addi", "addition",
                                              [Commutative]>;
 
@@ -552,6 +589,14 @@ def ZipX4IntrOp : ArmSVE_IntrOp<"zip.x4",
                    Arg<AnyScalableVector, "v3">:$v3,
                    Arg<AnyScalableVector, "v3">:$v4)>;
 
+// Note: Requires SME or SVE2. Overloaded on the type of `p2` (operand 1).
+def PselIntrOp : ArmSVE_IntrOp<"psel",
+  /*traits=*/[Pure, TypeIs<"res", SVBool>],
+  /*overloadedOperands=*/[1]>,
+  Arguments<(ins Arg<SVBool, "p1">:$p1,
+                 Arg<SVEPredicate, "p2">:$p2,
+                 Arg<I32, "index">:$index)>;
+
 def WhileLTIntrOp :
   ArmSVE_IntrOp<"whilelt",
     [TypeIs<"res", SVEPredicate>, Pure],
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
index ed4f4cc7f0718..10f39a0855f5f 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
@@ -140,6 +140,28 @@ using ConvertFromSvboolOpLowering =
 using ZipX2OpLowering = OneToOneConvertToLLVMPattern<ZipX2Op, ZipX2IntrOp>;
 using ZipX4OpLowering = OneToOneConvertToLLVMPattern<ZipX4Op, ZipX4IntrOp>;
 
+/// Lower `arm_sve.psel` to `arm_sve.intr.psel`, converting P1 and the result
+/// to/from svbool. `index` is cast to i32 from the original index-typed value.
+struct PselOpLowering : public ConvertOpToLLVMPattern<PselOp> {
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(PselOp pselOp, PselOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto svboolType = VectorType::get(16, rewriter.getI1Type(), true);
+    auto loc = pselOp.getLoc();
+    auto svboolP1 = rewriter.create<ConvertToSvboolIntrOp>(loc, svboolType,
+                                                           adaptor.getP1());
+    auto indexI32 = rewriter.create<arith::IndexCastOp>(
+        loc, rewriter.getI32Type(), pselOp.getIndex());
+    auto pselIntr = rewriter.create<PselIntrOp>(loc, svboolType, svboolP1,
+                                                adaptor.getP2(), indexI32);
+    rewriter.replaceOpWithNewOp<ConvertFromSvboolIntrOp>(
+        pselOp, adaptor.getP1().getType(), pselIntr);
+    return success();
+  }
+};
+
 /// Converts `vector.create_mask` ops that match the size of an SVE predicate
 /// to the `whilelt` intrinsic. This produces more canonical codegen than the
 /// generic LLVM lowering, see https://github.com/llvm/llvm-project/issues/81840
@@ -202,7 +224,8 @@ void mlir::populateArmSVELegalizeForLLVMExportPatterns(
                ConvertToSvboolOpLowering,
                ConvertFromSvboolOpLowering,
                ZipX2OpLowering,
-               ZipX4OpLowering>(converter);
+               ZipX4OpLowering,
+               PselOpLowering>(converter);
   // Add vector.create_mask conversion with a high benefit as it produces much
   // nicer code than the generic lowering.
   patterns.add<CreateMaskOpLowering>(converter, /*benefit=*/4096);
@@ -229,6 +252,7 @@ void mlir::configureArmSVELegalizeForExportTarget(
                     ConvertFromSvboolIntrOp,
                     ZipX2IntrOp,
                     ZipX4IntrOp,
+                    PselIntrOp,
                     WhileLTIntrOp>();
   target.addIllegalOp<SdotOp,
                       SmmlaOp,
diff --git a/mlir/test/Dialect/ArmSVE/invalid.mlir b/mlir/test/Dialect/ArmSVE/invalid.mlir
index 1258d3532c049..27b19f4321f8d 100644
--- a/mlir/test/Dialect/ArmSVE/invalid.mlir
+++ b/mlir/test/Dialect/ArmSVE/invalid.mlir
@@ -64,3 +64,11 @@ func.func @arm_sve_zip_x4_bad_vector_type(%a : vector<[5]xf64>) {
   arm_sve.zip.x4 %a, %a, %a, %a : vector<[5]xf64>
   return
 }
+
+// -----
+
+func.func @arm_sve_psel_bad_vector_type(%a : vector<[7]xi1>, %index: index) {
+  // expected-error@+1 {{op operand #0 must be  of ranks 1scalable vector of 1-bit signless integer values of length 16/8/4/2/1, but got 'vector<[7]xi1>'}}
+  arm_sve.psel %a, %a[%index] : vector<[7]xi1>, vector<[7]xi1>
+  return
+}
diff --git a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
index 3fc5e6e9fcc96..ef792fcf988ce 100644
--- a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
@@ -239,3 +239,35 @@ func.func @arm_sve_unsupported_create_masks(%index: index) -> (vector<[1]xi1>, v
   %2 = vector.create_mask %index : vector<[32]xi1>
   return %0, %1, %2 : vector<[1]xi1>, vector<[7]xi1>, vector<[32]xi1>
 }
+
+// -----
+
+// CHECK-LABEL: @arm_sve_psel_matching_predicate_types(
+// CHECK-SAME:                                         %[[P0:[a-z0-9]+]]: vector<[4]xi1>,
+// CHECK-SAME:                                         %[[P1:[a-z0-9]+]]: vector<[4]xi1>,
+// CHECK-SAME:                                         %[[INDEX:[a-z0-9]+]]: i64
+func.func @arm_sve_psel_matching_predicate_types(%a: vector<[4]xi1>, %b: vector<[4]xi1>, %index: index) -> vector<[4]xi1>
+{  // p1 and p2 share a type; only p1 and the result go through svbool.
+  //  CHECK-DAG: %[[INDEX_I32:.*]] = llvm.trunc %[[INDEX]] : i64 to i32
+  //  CHECK-DAG: %[[P0_IN:.*]] = "arm_sve.intr.convert.to.svbool"(%[[P0]]) : (vector<[4]xi1>) -> vector<[16]xi1>
+  // CHECK-NEXT: %[[PSEL:.*]] = "arm_sve.intr.psel"(%[[P0_IN]], %[[P1]], %[[INDEX_I32]]) : (vector<[16]xi1>, vector<[4]xi1>, i32) -> vector<[16]xi1>
+  // CHECK-NEXT: %[[RES:.*]] = "arm_sve.intr.convert.from.svbool"(%[[PSEL]]) : (vector<[16]xi1>) -> vector<[4]xi1>
+  %0 = arm_sve.psel %a, %b[%index] : vector<[4]xi1>, vector<[4]xi1>
+  return %0 : vector<[4]xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sve_psel_mixed_predicate_types(
+// CHECK-SAME:                                      %[[P0:[a-z0-9]+]]: vector<[8]xi1>,
+// CHECK-SAME:                                      %[[P1:[a-z0-9]+]]: vector<[16]xi1>,
+// CHECK-SAME:                                      %[[INDEX:[a-z0-9]+]]: i64
+func.func @arm_sve_psel_mixed_predicate_types(%a: vector<[8]xi1>, %b: vector<[16]xi1>, %index: index) -> vector<[8]xi1>
+{  // p2 is passed to the intrinsic unchanged; p1 and the result use svbool.
+  //  CHECK-DAG: %[[INDEX_I32:.*]] = llvm.trunc %[[INDEX]] : i64 to i32
+  //  CHECK-DAG: %[[P0_IN:.*]] = "arm_sve.intr.convert.to.svbool"(%[[P0]]) : (vector<[8]xi1>) -> vector<[16]xi1>
+  // CHECK-NEXT: %[[PSEL:.*]] = "arm_sve.intr.psel"(%[[P0_IN]], %[[P1]], %[[INDEX_I32]]) : (vector<[16]xi1>, vector<[16]xi1>, i32) -> vector<[16]xi1>
+  // CHECK-NEXT: %[[RES:.*]] = "arm_sve.intr.convert.from.svbool"(%[[PSEL]]) : (vector<[16]xi1>) -> vector<[8]xi1>
+  %0 = arm_sve.psel %a, %b[%index] : vector<[8]xi1>, vector<[16]xi1>
+  return %0 : vector<[8]xi1>
+}
diff --git a/mlir/test/Dialect/ArmSVE/roundtrip.mlir b/mlir/test/Dialect/ArmSVE/roundtrip.mlir
index f7b79aa2f275c..0f0c5a8575772 100644
--- a/mlir/test/Dialect/ArmSVE/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmSVE/roundtrip.mlir
@@ -225,3 +225,32 @@ func.func @arm_sve_zip_x4(
   %a8, %b8, %c8, %d8 = arm_sve.zip.x4 %v8, %v8, %v8, %v8 : vector<[16]xi8>
   return
 }
+
+// -----
+
+func.func @arm_sve_psel(
+  %p0: vector<[2]xi1>,
+  %p1: vector<[4]xi1>,
+  %p2: vector<[8]xi1>,
+  %p3: vector<[16]xi1>,
+  %index: index
+) {
+  // CHECK: arm_sve.psel %{{.*}}, %{{.*}}[%{{.*}}] : vector<[2]xi1>, vector<[2]xi1>
+  %0 = arm_sve.psel %p0, %p0[%index] : vector<[2]xi1>, vector<[2]xi1>
+  // CHECK: arm_sve.psel %{{.*}}, %{{.*}}[%{{.*}}] : vector<[4]xi1>, vector<[4]xi1>
+  %1 = arm_sve.psel %p1, %p1[%index] : vector<[4]xi1>, vector<[4]xi1>
+  // CHECK: arm_sve.psel %{{.*}}, %{{.*}}[%{{.*}}] : vector<[8]xi1>, vector<[8]xi1>
+  %2 = arm_sve.psel %p2, %p2[%index] : vector<[8]xi1>, vector<[8]xi1>
+  // CHECK: arm_sve.psel %{{.*}}, %{{.*}}[%{{.*}}] : vector<[16]xi1>, vector<[16]xi1>
+  %3 = arm_sve.psel %p3, %p3[%index] : vector<[16]xi1>, vector<[16]xi1>
+  // Mixed cases: p1 and p2 may have different predicate element counts.
+  // CHECK: arm_sve.psel %{{.*}}, %{{.*}}[%{{.*}}] : vector<[2]xi1>, vector<[4]xi1>
+  %4 = arm_sve.psel %p0, %p1[%index] : vector<[2]xi1>, vector<[4]xi1>
+  // CHECK: arm_sve.psel %{{.*}}, %{{.*}}[%{{.*}}] : vector<[4]xi1>, vector<[8]xi1>
+  %5 = arm_sve.psel %p1, %p2[%index] : vector<[4]xi1>, vector<[8]xi1>
+  // CHECK: arm_sve.psel %{{.*}}, %{{.*}}[%{{.*}}] : vector<[8]xi1>, vector<[16]xi1>
+  %6 = arm_sve.psel %p2, %p3[%index] : vector<[8]xi1>, vector<[16]xi1>
+  // CHECK: arm_sve.psel %{{.*}}, %{{.*}}[%{{.*}}] : vector<[16]xi1>, vector<[2]xi1>
+  %7 = arm_sve.psel %p3, %p0[%index] : vector<[16]xi1>, vector<[2]xi1>
+  return
+}
diff --git a/mlir/test/Target/LLVMIR/arm-sve.mlir b/mlir/test/Target/LLVMIR/arm-sve.mlir
index 34413d46b440e..ed5a1fc7ba2e4 100644
--- a/mlir/test/Target/LLVMIR/arm-sve.mlir
+++ b/mlir/test/Target/LLVMIR/arm-sve.mlir
@@ -371,3 +371,22 @@ llvm.func @arm_sve_whilelt(%base: i64, %n: i64) {
   %4 = "arm_sve.intr.whilelt"(%base, %n) : (i64, i64) -> vector<[16]xi1>
   llvm.return
 }
+
+// CHECK-LABEL: arm_sve_psel(
+// CHECK-SAME:               <vscale x 16 x i1> %[[PN:[0-9]+]],
+// CHECK-SAME:               <vscale x 2 x i1> %[[P1:[0-9]+]],
+// CHECK-SAME:               <vscale x 4 x i1> %[[P2:[0-9]+]],
+// CHECK-SAME:               <vscale x 8 x i1> %[[P3:[0-9]+]],
+// CHECK-SAME:               <vscale x 16 x i1> %[[P4:[0-9]+]],
+// CHECK-SAME:               i32 %[[INDEX:[0-9]+]])
+llvm.func @arm_sve_psel(%pn: vector<[16]xi1>, %p1: vector<[2]xi1>, %p2: vector<[4]xi1>, %p3: vector<[8]xi1>, %p4: vector<[16]xi1>, %index: i32) {
+  // CHECK: call <vscale x 16 x i1> @llvm.aarch64.sve.psel.nxv2i1(<vscale x 16 x i1> %[[PN]], <vscale x 2 x i1> %[[P1]], i32 %[[INDEX]])
+  "arm_sve.intr.psel"(%pn, %p1, %index) : (vector<[16]xi1>, vector<[2]xi1>, i32) -> vector<[16]xi1>
+  // CHECK: call <vscale x 16 x i1> @llvm.aarch64.sve.psel.nxv4i1(<vscale x 16 x i1> %[[PN]], <vscale x 4 x i1> %[[P2]], i32 %[[INDEX]])
+  "arm_sve.intr.psel"(%pn, %p2, %index) : (vector<[16]xi1>, vector<[4]xi1>, i32) -> vector<[16]xi1>
+  // CHECK: call <vscale x 16 x i1> @llvm.aarch64.sve.psel.nxv8i1(<vscale x 16 x i1> %[[PN]], <vscale x 8 x i1> %[[P3]], i32 %[[INDEX]])
+  "arm_sve.intr.psel"(%pn, %p3, %index) : (vector<[16]xi1>, vector<[8]xi1>, i32) -> vector<[16]xi1>
+  // CHECK: call <vscale x 16 x i1> @llvm.aarch64.sve.psel.nxv16i1(<vscale x 16 x i1> %[[PN]], <vscale x 16 x i1> %[[P4]], i32 %[[INDEX]])
+  "arm_sve.intr.psel"(%pn, %p4, %index) : (vector<[16]xi1>, vector<[16]xi1>, i32) -> vector<[16]xi1>
+  llvm.return  // The psel results are deliberately unused; only the emitted calls matter.
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/95764


More information about the Mlir-commits mailing list