[Mlir-commits] [mlir] [mlir][ArmSVE] Add `arm_sve.psel` operation (PR #95764)

Benjamin Maxwell llvmlistbot at llvm.org
Mon Jun 17 10:31:07 PDT 2024


https://github.com/MacDue updated https://github.com/llvm/llvm-project/pull/95764

>From 47771ebab1540af40f00cff93095f5e52ee22e93 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Mon, 17 Jun 2024 10:22:23 +0000
Subject: [PATCH 1/3] [mlir][ArmSVE] Add `arm_sve.psel` operation

This adds a new operation for the SME/SVE2.1 psel instruction. This allows
selecting a predicate based on a bit within another predicate,
essentially allowing for 2-D predication. Informally, the semantics are:

```mlir
%pd = arm_sve.psel %p1, %p2[%index] : vector<[4]xi1>, vector<[8]xi1>
```

=>

```
if p2[index % num_elements(p2)] == 1:
  pd = p1 : type(p1)
else:
  pd = all-false : type(p1)
```
---
 mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td | 45 +++++++++++++++++++
 .../Transforms/LegalizeForLLVMExport.cpp      | 26 ++++++++++-
 mlir/test/Dialect/ArmSVE/invalid.mlir         |  8 ++++
 .../Dialect/ArmSVE/legalize-for-llvm.mlir     | 32 +++++++++++++
 mlir/test/Dialect/ArmSVE/roundtrip.mlir       | 29 ++++++++++++
 mlir/test/Target/LLVMIR/arm-sve.mlir          | 19 ++++++++
 6 files changed, 158 insertions(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
index aea55830c6607..5b98b21720ada 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
+++ b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
@@ -442,6 +442,43 @@ def ZipX4Op  : ArmSVE_Op<"zip.x4", [
   }];
 }
 
+def PselOp : ArmSVE_Op<"psel", [
+  Pure,
+  AllTypesMatch<["p1", "result"]>,
+]> {
+  let summary = "Predicate select";
+
+  let description = [{
+    This operation returns the input predicate `p1` or an all-false predicate
+    based on the bit at `p2[index]`. Informally the semantics are:
+    ```
+    if p2[index % num_elements(p2)] == 1:
+      return p1 : type(p1)
+    return all-false : type(p1)
+    ```
+
+    Example:
+    ```mlir
+    // Note: p1 and p2 can have different sizes.
+    %pd = arm_sve.psel %p1, %p2[%index] : vector<[4]xi1>, vector<[8]xi1>
+    ```
+
+    Note: This requires SME or SVE2 (`+sme` or `+sve2` in LLVM target features).
+  }];
+
+  let arguments = (ins SVEPredicate:$p1, SVEPredicate:$p2, Index:$index);
+  let results = (outs SVEPredicate:$result);
+
+  let builders = [
+    OpBuilder<(ins "Value":$p1, "Value":$p2, "Value":$index), [{
+      build($_builder, $_state, p1.getType(), p1, p2, index);
+  }]>];
+
+  let assemblyFormat = [{
+    $p1 `,` $p2 `[` $index `]` attr-dict `:` type($p1) `,` type($p2)
+  }];
+}
+
 def ScalableMaskedAddIOp : ScalableMaskedIOp<"masked.addi", "addition",
                                              [Commutative]>;
 
@@ -552,6 +589,14 @@ def ZipX4IntrOp : ArmSVE_IntrOp<"zip.x4",
                    Arg<AnyScalableVector, "v3">:$v3,
                    Arg<AnyScalableVector, "v3">:$v4)>;
 
+// Note: This intrinsic requires SME or SVE2.
+def PselIntrOp : ArmSVE_IntrOp<"psel",
+  /*traits=*/[Pure, TypeIs<"res", SVBool>],
+  /*overloadedOperands=*/[1]>,
+  Arguments<(ins Arg<SVBool, "p1">:$p1,
+                 Arg<SVEPredicate, "p2">:$p2,
+                 Arg<I32, "index">:$index)>;
+
 def WhileLTIntrOp :
   ArmSVE_IntrOp<"whilelt",
     [TypeIs<"res", SVEPredicate>, Pure],
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
index ed4f4cc7f0718..10f39a0855f5f 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
@@ -140,6 +140,28 @@ using ConvertFromSvboolOpLowering =
 using ZipX2OpLowering = OneToOneConvertToLLVMPattern<ZipX2Op, ZipX2IntrOp>;
 using ZipX4OpLowering = OneToOneConvertToLLVMPattern<ZipX4Op, ZipX4IntrOp>;
 
+/// Lower `arm_sve.psel` to LLVM intrinsics. This is almost a 1-to-1 conversion,
+/// but the first input (P1) and the result need conversion to/from svbool.
+struct PselOpLowering : public ConvertOpToLLVMPattern<PselOp> {
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(PselOp pselOp, PselOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto svboolType = VectorType::get(16, rewriter.getI1Type(), true);
+    auto loc = pselOp.getLoc();
+    auto svboolP1 = rewriter.create<ConvertToSvboolIntrOp>(loc, svboolType,
+                                                           adaptor.getP1());
+    auto indexI32 = rewriter.create<arith::IndexCastOp>(
+        loc, rewriter.getI32Type(), pselOp.getIndex());
+    auto pselIntr = rewriter.create<PselIntrOp>(loc, svboolType, svboolP1,
+                                                pselOp.getP2(), indexI32);
+    rewriter.replaceOpWithNewOp<ConvertFromSvboolIntrOp>(
+        pselOp, adaptor.getP1().getType(), pselIntr);
+    return success();
+  }
+};
+
 /// Converts `vector.create_mask` ops that match the size of an SVE predicate
 /// to the `whilelt` intrinsic. This produces more canonical codegen than the
 /// generic LLVM lowering, see https://github.com/llvm/llvm-project/issues/81840
@@ -202,7 +224,8 @@ void mlir::populateArmSVELegalizeForLLVMExportPatterns(
                ConvertToSvboolOpLowering,
                ConvertFromSvboolOpLowering,
                ZipX2OpLowering,
-               ZipX4OpLowering>(converter);
+               ZipX4OpLowering,
+               PselOpLowering>(converter);
   // Add vector.create_mask conversion with a high benefit as it produces much
   // nicer code than the generic lowering.
   patterns.add<CreateMaskOpLowering>(converter, /*benefit=*/4096);
@@ -229,6 +252,7 @@ void mlir::configureArmSVELegalizeForExportTarget(
                     ConvertFromSvboolIntrOp,
                     ZipX2IntrOp,
                     ZipX4IntrOp,
+                    PselIntrOp,
                     WhileLTIntrOp>();
   target.addIllegalOp<SdotOp,
                       SmmlaOp,
diff --git a/mlir/test/Dialect/ArmSVE/invalid.mlir b/mlir/test/Dialect/ArmSVE/invalid.mlir
index 1258d3532c049..27b19f4321f8d 100644
--- a/mlir/test/Dialect/ArmSVE/invalid.mlir
+++ b/mlir/test/Dialect/ArmSVE/invalid.mlir
@@ -64,3 +64,11 @@ func.func @arm_sve_zip_x4_bad_vector_type(%a : vector<[5]xf64>) {
   arm_sve.zip.x4 %a, %a, %a, %a : vector<[5]xf64>
   return
 }
+
+// -----
+
+func.func @arm_sve_psel_bad_vector_type(%a : vector<[7]xi1>, %index: index) {
+  // expected-error at +1 {{op operand #0 must be  of ranks 1scalable vector of 1-bit signless integer values of length 16/8/4/2/1, but got 'vector<[7]xi1>'}}
+  arm_sve.psel %a, %a[%index] : vector<[7]xi1>, vector<[7]xi1>
+  return
+}
diff --git a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
index 3fc5e6e9fcc96..ef792fcf988ce 100644
--- a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
@@ -239,3 +239,35 @@ func.func @arm_sve_unsupported_create_masks(%index: index) -> (vector<[1]xi1>, v
   %2 = vector.create_mask %index : vector<[32]xi1>
   return %0, %1, %2 : vector<[1]xi1>, vector<[7]xi1>, vector<[32]xi1>
 }
+
+// -----
+
+// CHECK-LABEL: @arm_sve_psel_matching_predicate_types(
+// CHECK-SAME:                                         %[[P0:[a-z0-9]+]]: vector<[4]xi1>,
+// CHECK-SAME:                                         %[[P1:[a-z0-9]+]]: vector<[4]xi1>,
+// CHECK-SAME:                                         %[[INDEX:[a-z0-9]+]]: i64
+func.func @arm_sve_psel_matching_predicate_types(%a: vector<[4]xi1>, %b: vector<[4]xi1>, %index: index) -> vector<[4]xi1>
+{
+  //  CHECK-DAG: %[[INDEX_I32:.*]] = llvm.trunc %[[INDEX]] : i64 to i32
+  //  CHECK-DAG: %[[P0_IN:.*]] = "arm_sve.intr.convert.to.svbool"(%[[P0]]) : (vector<[4]xi1>) -> vector<[16]xi1>
+  // CHECK-NEXT: %[[PSEL:.*]] = "arm_sve.intr.psel"(%[[P0_IN]], %[[P1]], %[[INDEX_I32]]) : (vector<[16]xi1>, vector<[4]xi1>, i32) -> vector<[16]xi1>
+  // CHECK-NEXT: %[[RES:.*]] = "arm_sve.intr.convert.from.svbool"(%[[PSEL]]) : (vector<[16]xi1>) -> vector<[4]xi1>
+  %0 = arm_sve.psel %a, %b[%index] : vector<[4]xi1>, vector<[4]xi1>
+  return %0 : vector<[4]xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sve_psel_mixed_predicate_types(
+// CHECK-SAME:                                      %[[P0:[a-z0-9]+]]: vector<[8]xi1>,
+// CHECK-SAME:                                      %[[P1:[a-z0-9]+]]: vector<[16]xi1>,
+// CHECK-SAME:                                      %[[INDEX:[a-z0-9]+]]: i64
+func.func @arm_sve_psel_mixed_predicate_types(%a: vector<[8]xi1>, %b: vector<[16]xi1>, %index: index) -> vector<[8]xi1>
+{
+  //  CHECK-DAG: %[[INDEX_I32:.*]] = llvm.trunc %[[INDEX]] : i64 to i32
+  //  CHECK-DAG: %[[P0_IN:.*]] = "arm_sve.intr.convert.to.svbool"(%[[P0]]) : (vector<[8]xi1>) -> vector<[16]xi1>
+  // CHECK-NEXT: %[[PSEL:.*]] = "arm_sve.intr.psel"(%[[P0_IN]], %[[P1]], %[[INDEX_I32]]) : (vector<[16]xi1>, vector<[16]xi1>, i32) -> vector<[16]xi1>
+  // CHECK-NEXT: %[[RES:.*]] = "arm_sve.intr.convert.from.svbool"(%[[PSEL]]) : (vector<[16]xi1>) -> vector<[8]xi1>
+  %0 = arm_sve.psel %a, %b[%index] : vector<[8]xi1>, vector<[16]xi1>
+  return %0 : vector<[8]xi1>
+}
diff --git a/mlir/test/Dialect/ArmSVE/roundtrip.mlir b/mlir/test/Dialect/ArmSVE/roundtrip.mlir
index f7b79aa2f275c..0f0c5a8575772 100644
--- a/mlir/test/Dialect/ArmSVE/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmSVE/roundtrip.mlir
@@ -225,3 +225,32 @@ func.func @arm_sve_zip_x4(
   %a8, %b8, %c8, %d8 = arm_sve.zip.x4 %v8, %v8, %v8, %v8 : vector<[16]xi8>
   return
 }
+
+// -----
+
+func.func @arm_sve_psel(
+  %p0: vector<[2]xi1>,
+  %p1: vector<[4]xi1>,
+  %p2: vector<[8]xi1>,
+  %p3: vector<[16]xi1>,
+  %index: index
+) {
+  // CHECK: arm_sve.psel %{{.*}}, %{{.*}}[%{{.*}}] : vector<[2]xi1>, vector<[2]xi1>
+  %0 = arm_sve.psel %p0, %p0[%index] : vector<[2]xi1>, vector<[2]xi1>
+  // CHECK: arm_sve.psel %{{.*}}, %{{.*}}[%{{.*}}] : vector<[4]xi1>, vector<[4]xi1>
+  %1 = arm_sve.psel %p1, %p1[%index] : vector<[4]xi1>, vector<[4]xi1>
+  // CHECK: arm_sve.psel %{{.*}}, %{{.*}}[%{{.*}}] : vector<[8]xi1>, vector<[8]xi1>
+  %2 = arm_sve.psel %p2, %p2[%index] : vector<[8]xi1>, vector<[8]xi1>
+  // CHECK: arm_sve.psel %{{.*}}, %{{.*}}[%{{.*}}] : vector<[16]xi1>, vector<[16]xi1>
+  %3 = arm_sve.psel %p3, %p3[%index] : vector<[16]xi1>, vector<[16]xi1>
+  /// Some mixed predicate type examples:
+  // CHECK: arm_sve.psel %{{.*}}, %{{.*}}[%{{.*}}] : vector<[2]xi1>, vector<[4]xi1>
+  %4 = arm_sve.psel %p0, %p1[%index] : vector<[2]xi1>, vector<[4]xi1>
+  // CHECK: arm_sve.psel %{{.*}}, %{{.*}}[%{{.*}}] : vector<[4]xi1>, vector<[8]xi1>
+  %5 = arm_sve.psel %p1, %p2[%index] : vector<[4]xi1>, vector<[8]xi1>
+  // CHECK: arm_sve.psel %{{.*}}, %{{.*}}[%{{.*}}] : vector<[8]xi1>, vector<[16]xi1>
+  %6 = arm_sve.psel %p2, %p3[%index] : vector<[8]xi1>, vector<[16]xi1>
+  // CHECK: arm_sve.psel %{{.*}}, %{{.*}}[%{{.*}}] : vector<[16]xi1>, vector<[2]xi1>
+  %7 = arm_sve.psel %p3, %p0[%index] : vector<[16]xi1>, vector<[2]xi1>
+  return
+}
diff --git a/mlir/test/Target/LLVMIR/arm-sve.mlir b/mlir/test/Target/LLVMIR/arm-sve.mlir
index 34413d46b440e..ed5a1fc7ba2e4 100644
--- a/mlir/test/Target/LLVMIR/arm-sve.mlir
+++ b/mlir/test/Target/LLVMIR/arm-sve.mlir
@@ -371,3 +371,22 @@ llvm.func @arm_sve_whilelt(%base: i64, %n: i64) {
   %4 = "arm_sve.intr.whilelt"(%base, %n) : (i64, i64) -> vector<[16]xi1>
   llvm.return
 }
+
+// CHECK-LABEL: arm_sve_psel(
+// CHECK-SAME:               <vscale x 16 x i1> %[[PN:[0-9]+]],
+// CHECK-SAME:               <vscale x 2 x i1> %[[P1:[0-9]+]],
+// CHECK-SAME:               <vscale x 4 x i1> %[[P2:[0-9]+]],
+// CHECK-SAME:               <vscale x 8 x i1> %[[P3:[0-9]+]],
+// CHECK-SAME:               <vscale x 16 x i1> %[[P4:[0-9]+]],
+// CHECK-SAME:               i32 %[[INDEX:[0-9]+]])
+llvm.func @arm_sve_psel(%pn: vector<[16]xi1>, %p1: vector<[2]xi1>, %p2: vector<[4]xi1>, %p3: vector<[8]xi1>, %p4: vector<[16]xi1>, %index: i32) {
+  // CHECK: call <vscale x 16 x i1> @llvm.aarch64.sve.psel.nxv2i1(<vscale x 16 x i1> %[[PN]], <vscale x 2 x i1> %[[P1]], i32 %[[INDEX]])
+  "arm_sve.intr.psel"(%pn, %p1, %index) : (vector<[16]xi1>, vector<[2]xi1>, i32) -> vector<[16]xi1>
+  // CHECK: call <vscale x 16 x i1> @llvm.aarch64.sve.psel.nxv4i1(<vscale x 16 x i1> %[[PN]], <vscale x 4 x i1> %[[P2]], i32 %[[INDEX]])
+  "arm_sve.intr.psel"(%pn, %p2, %index) : (vector<[16]xi1>, vector<[4]xi1>, i32) -> vector<[16]xi1>
+  // CHECK: call <vscale x 16 x i1> @llvm.aarch64.sve.psel.nxv8i1(<vscale x 16 x i1> %[[PN]], <vscale x 8 x i1> %[[P3]], i32 %[[INDEX]])
+  "arm_sve.intr.psel"(%pn, %p3, %index) : (vector<[16]xi1>, vector<[8]xi1>, i32) -> vector<[16]xi1>
+  // CHECK: call <vscale x 16 x i1> @llvm.aarch64.sve.psel.nxv16i1(<vscale x 16 x i1> %[[PN]], <vscale x 16 x i1> %[[P4]], i32 %[[INDEX]])
+  "arm_sve.intr.psel"(%pn, %p4, %index) : (vector<[16]xi1>, vector<[16]xi1>, i32) -> vector<[16]xi1>
+  llvm.return
+}

>From 8e5eb0746851a7c9c2e8f291707b6be777aff082 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Mon, 17 Jun 2024 13:04:54 +0100
Subject: [PATCH 2/3] Apply suggestions from code review

Co-authored-by: Cullen Rhodes <cullen.rhodes at arm.com>
---
 mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
index 5b98b21720ada..7ee2b85b039b3 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
+++ b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
@@ -463,7 +463,7 @@ def PselOp : ArmSVE_Op<"psel", [
     %pd = arm_sve.psel %p1, %p2[%index] : vector<[4]xi1>, vector<[8]xi1>
     ```
 
-    Note: This requires SME or SVE2 (`+sme` or `+sve2` in LLVM target features).
+    Note: This requires SME or SVE2.1 (`+sme` or `+sve2p1` in LLVM target features).
   }];
 
   let arguments = (ins SVEPredicate:$p1, SVEPredicate:$p2, Index:$index);
@@ -589,7 +589,7 @@ def ZipX4IntrOp : ArmSVE_IntrOp<"zip.x4",
                    Arg<AnyScalableVector, "v3">:$v3,
                    Arg<AnyScalableVector, "v3">:$v4)>;
 
-// Note: This intrinsic requires SME or SVE2.
+// Note: This intrinsic requires SME or SVE2.1.
 def PselIntrOp : ArmSVE_IntrOp<"psel",
   /*traits=*/[Pure, TypeIs<"res", SVBool>],
   /*overloadedOperands=*/[1]>,

>From adc481a7c7277f03816e50efac285b464b78c2f2 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Mon, 17 Jun 2024 17:24:02 +0000
Subject: [PATCH 3/3] Fixups

---
 mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td   | 12 +++++++++---
 mlir/test/Dialect/ArmSVE/invalid.mlir           |  2 +-
 mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir |  8 ++++----
 3 files changed, 14 insertions(+), 8 deletions(-)

diff --git a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
index 7ee2b85b039b3..d7e8b22fbd2d3 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
+++ b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
@@ -37,10 +37,16 @@ def ArmSVE_Dialect : Dialect {
 //===----------------------------------------------------------------------===//
 
 def SVBool : ScalableVectorOfRankAndLengthAndType<
-  [1], [16], [I1]>;
+  [1], [16], [I1]>
+{
+  let summary = "vector<[16]xi1>";
+}
 
 def SVEPredicate : ScalableVectorOfRankAndLengthAndType<
-  [1], [16, 8, 4, 2, 1], [I1]>;
+  [1], [16, 8, 4, 2, 1], [I1]>
+{
+  let summary = "vector<[1]xi1>, vector<[2]xi1>, vector<[4]xi1>, vector<[8]xi1>, or vector<[16]xi1>";
+}
 
 // Generalizations of SVBool and SVEPredicate to ranks >= 1.
 // These are masks with a single trailing scalable dimension.
@@ -450,7 +456,7 @@ def PselOp : ArmSVE_Op<"psel", [
 
   let description = [{
     This operation returns the input predicate `p1` or an all-false predicate
-    based on the bit at `p2[index]`. Informally the semantics are:
+    based on the bit at `p2[index]`. Informally, the semantics are:
     ```
     if p2[index % num_elements(p2)] == 1:
       return p1 : type(p1)
diff --git a/mlir/test/Dialect/ArmSVE/invalid.mlir b/mlir/test/Dialect/ArmSVE/invalid.mlir
index 27b19f4321f8d..a021d4393107c 100644
--- a/mlir/test/Dialect/ArmSVE/invalid.mlir
+++ b/mlir/test/Dialect/ArmSVE/invalid.mlir
@@ -68,7 +68,7 @@ func.func @arm_sve_zip_x4_bad_vector_type(%a : vector<[5]xf64>) {
 // -----
 
 func.func @arm_sve_psel_bad_vector_type(%a : vector<[7]xi1>, %index: index) {
-  // expected-error at +1 {{op operand #0 must be  of ranks 1scalable vector of 1-bit signless integer values of length 16/8/4/2/1, but got 'vector<[7]xi1>'}}
+  // expected-error at +1 {{op operand #0 must be vector<[1]xi1>, vector<[2]xi1>, vector<[4]xi1>, vector<[8]xi1>, or vector<[16]xi1>, but got 'vector<[7]xi1>'}}
   arm_sve.psel %a, %a[%index] : vector<[7]xi1>, vector<[7]xi1>
   return
 }
diff --git a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
index ef792fcf988ce..31d5376c32d9c 100644
--- a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
@@ -246,13 +246,13 @@ func.func @arm_sve_unsupported_create_masks(%index: index) -> (vector<[1]xi1>, v
 // CHECK-SAME:                                         %[[P0:[a-z0-9]+]]: vector<[4]xi1>,
 // CHECK-SAME:                                         %[[P1:[a-z0-9]+]]: vector<[4]xi1>,
 // CHECK-SAME:                                         %[[INDEX:[a-z0-9]+]]: i64
-func.func @arm_sve_psel_matching_predicate_types(%a: vector<[4]xi1>, %b: vector<[4]xi1>, %index: index) -> vector<[4]xi1>
+func.func @arm_sve_psel_matching_predicate_types(%p0: vector<[4]xi1>, %p1: vector<[4]xi1>, %index: index) -> vector<[4]xi1>
 {
   //  CHECK-DAG: %[[INDEX_I32:.*]] = llvm.trunc %[[INDEX]] : i64 to i32
   //  CHECK-DAG: %[[P0_IN:.*]] = "arm_sve.intr.convert.to.svbool"(%[[P0]]) : (vector<[4]xi1>) -> vector<[16]xi1>
   // CHECK-NEXT: %[[PSEL:.*]] = "arm_sve.intr.psel"(%[[P0_IN]], %[[P1]], %[[INDEX_I32]]) : (vector<[16]xi1>, vector<[4]xi1>, i32) -> vector<[16]xi1>
   // CHECK-NEXT: %[[RES:.*]] = "arm_sve.intr.convert.from.svbool"(%[[PSEL]]) : (vector<[16]xi1>) -> vector<[4]xi1>
-  %0 = arm_sve.psel %a, %b[%index] : vector<[4]xi1>, vector<[4]xi1>
+  %0 = arm_sve.psel %p0, %p1[%index] : vector<[4]xi1>, vector<[4]xi1>
   return %0 : vector<[4]xi1>
 }
 
@@ -262,12 +262,12 @@ func.func @arm_sve_psel_matching_predicate_types(%a: vector<[4]xi1>, %b: vector<
 // CHECK-SAME:                                      %[[P0:[a-z0-9]+]]: vector<[8]xi1>,
 // CHECK-SAME:                                      %[[P1:[a-z0-9]+]]: vector<[16]xi1>,
 // CHECK-SAME:                                      %[[INDEX:[a-z0-9]+]]: i64
-func.func @arm_sve_psel_mixed_predicate_types(%a: vector<[8]xi1>, %b: vector<[16]xi1>, %index: index) -> vector<[8]xi1>
+func.func @arm_sve_psel_mixed_predicate_types(%p0: vector<[8]xi1>, %p1: vector<[16]xi1>, %index: index) -> vector<[8]xi1>
 {
   //  CHECK-DAG: %[[INDEX_I32:.*]] = llvm.trunc %[[INDEX]] : i64 to i32
   //  CHECK-DAG: %[[P0_IN:.*]] = "arm_sve.intr.convert.to.svbool"(%[[P0]]) : (vector<[8]xi1>) -> vector<[16]xi1>
   // CHECK-NEXT: %[[PSEL:.*]] = "arm_sve.intr.psel"(%[[P0_IN]], %[[P1]], %[[INDEX_I32]]) : (vector<[16]xi1>, vector<[16]xi1>, i32) -> vector<[16]xi1>
   // CHECK-NEXT: %[[RES:.*]] = "arm_sve.intr.convert.from.svbool"(%[[PSEL]]) : (vector<[16]xi1>) -> vector<[8]xi1>
-  %0 = arm_sve.psel %a, %b[%index] : vector<[8]xi1>, vector<[16]xi1>
+  %0 = arm_sve.psel %p0, %p1[%index] : vector<[8]xi1>, vector<[16]xi1>
   return %0 : vector<[8]xi1>
 }



More information about the Mlir-commits mailing list