[Mlir-commits] [mlir] [mlir][arith] Further clean up select op definition (PR #93358)
Jakub Kuderski
llvmlistbot at llvm.org
Mon Jun 3 11:26:28 PDT 2024
https://github.com/kuhar updated https://github.com/llvm/llvm-project/pull/93358
>From 778cd5465dce53b883c82958448781508114af51 Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Fri, 24 May 2024 20:19:25 -0400
Subject: [PATCH 1/2] [mlir][arith] Further clean up select op definition
* Improve the condition type requirement description ('scalar' ->
'signless i1'), to match what is actually verified.
* Use the `I1` type predicate instead of `AnyBooleanTypeMatch`.
Related discussion:
https://github.com/llvm/llvm-project/pull/93351#issuecomment-2130453233.
---
mlir/include/mlir/Dialect/Arith/IR/ArithOps.td | 15 ++++-----------
mlir/test/Dialect/Arith/invalid.mlir | 4 ++--
mlir/test/IR/invalid-ops.mlir | 4 ++--
3 files changed, 8 insertions(+), 15 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index ead52332e8eec..3981a46e1f072 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -1540,21 +1540,14 @@ def Arith_CmpFOp : Arith_CompareOp<"cmpf",
// SelectOp
//===----------------------------------------------------------------------===//
-class AnyBooleanTypeMatch<list<string> names> :
- AnyMatchOperatorTrait<names, "$_self.getType().isSignlessInteger(1)",
- "scalar type">;
-
-class ScalarConditionOrMatchingShape<list<string> names> :
+class BooleanConditionOrMatchingShape<list<string> names> :
PredOpTrait<
- !head(names) # " is scalar or has matching shape",
- Or<[AnyBooleanTypeMatch<[!head(names)]>.predicate,
- AllShapesMatch<names>.predicate]>> {
- list<string> values = names;
-}
+ !head(names) # " is signless i1 or has matching shape",
+ Or<[TypeIsPred<!head(names), I1>, AllShapesMatch<names>.predicate]>>;
def SelectOp : Arith_Op<"select", [Pure,
AllTypesMatch<["true_value", "false_value", "result"]>,
- ScalarConditionOrMatchingShape<["condition", "result"]>,
+ BooleanConditionOrMatchingShape<["condition", "result"]>,
DeclareOpInterfaceMethods<InferIntRangeInterface>,
] # ElementwiseMappable.traits> {
let summary = "select operation";
diff --git a/mlir/test/Dialect/Arith/invalid.mlir b/mlir/test/Dialect/Arith/invalid.mlir
index ada849220bb83..815c77a749382 100644
--- a/mlir/test/Dialect/Arith/invalid.mlir
+++ b/mlir/test/Dialect/Arith/invalid.mlir
@@ -792,7 +792,7 @@ func.func @func() {
// -----
func.func @disallow_zero_rank_tensor_with_ranked_tensor(%arg0 : tensor<i1>, %arg1 : tensor<2xi64>, %arg2 : tensor<2xi64>) -> tensor<2xi64> {
- // expected-error @+1 {{'arith.select' op failed to verify that condition is scalar or has matching shape}}
+ // expected-error @+1 {{'arith.select' op failed to verify that condition is signless i1 or has matching shape}}
%0 = arith.select %arg0, %arg1, %arg2 : tensor<i1>, tensor<2xi64>
return %0 : tensor<2xi64>
}
@@ -800,7 +800,7 @@ func.func @disallow_zero_rank_tensor_with_ranked_tensor(%arg0 : tensor<i1>, %arg
// -----
func.func @disallow_zero_rank_tensor_with_unranked_tensor(%arg0 : tensor<i1>, %arg1 : tensor<2x?xi64>, %arg2 : tensor<2x?xi64>) -> tensor<2x?xi64> {
- // expected-error @+1 {{'arith.select' op failed to verify that condition is scalar or has matching shape}}
+ // expected-error @+1 {{'arith.select' op failed to verify that condition is signless i1 or has matching shape}}
%0 = arith.select %arg0, %arg1, %arg2 : tensor<i1>, tensor<2x?xi64>
return %0 : tensor<2x?xi64>
}
diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir
index 6ca7035022adb..6672c8840ffde 100644
--- a/mlir/test/IR/invalid-ops.mlir
+++ b/mlir/test/IR/invalid-ops.mlir
@@ -82,7 +82,7 @@ func.func @func_with_ops(i1, i32, i64) {
func.func @func_with_ops(vector<12xi1>, vector<42xi32>, vector<42xi32>) {
^bb0(%cond : vector<12xi1>, %t : vector<42xi32>, %f : vector<42xi32>):
- // expected-error@+1 {{'arith.select' op failed to verify that condition is scalar or has matching shape}}
+ // expected-error@+1 {{'arith.select' op failed to verify that condition is signless i1 or has matching shape}}
%r = "arith.select"(%cond, %t, %f) : (vector<12xi1>, vector<42xi32>, vector<42xi32>) -> vector<42xi32>
}
@@ -90,7 +90,7 @@ func.func @func_with_ops(vector<12xi1>, vector<42xi32>, vector<42xi32>) {
func.func @func_with_ops(tensor<12xi1>, tensor<42xi32>, tensor<42xi32>) {
^bb0(%cond : tensor<12xi1>, %t : tensor<42xi32>, %f : tensor<42xi32>):
- // expected-error@+1 {{'arith.select' op failed to verify that condition is scalar or has matching shape}}
+ // expected-error@+1 {{'arith.select' op failed to verify that condition is signless i1 or has matching shape}}
%r = "arith.select"(%cond, %t, %f) : (tensor<12xi1>, tensor<42xi32>, tensor<42xi32>) -> tensor<42xi32>
}
>From 52c6b9370804f67df5e0869e4f3f6e03c39368d8 Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Fri, 24 May 2024 20:25:32 -0400
Subject: [PATCH 2/2] Simplify
---
mlir/include/mlir/Dialect/Arith/IR/ArithOps.td | 9 +++++----
1 file changed, 5 insertions(+), 4 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index 3981a46e1f072..989e17844990e 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -1540,14 +1540,15 @@ def Arith_CmpFOp : Arith_CompareOp<"cmpf",
// SelectOp
//===----------------------------------------------------------------------===//
-class BooleanConditionOrMatchingShape<list<string> names> :
+class BooleanConditionOrMatchingShape<string condition, string result> :
PredOpTrait<
- !head(names) # " is signless i1 or has matching shape",
- Or<[TypeIsPred<!head(names), I1>, AllShapesMatch<names>.predicate]>>;
+ condition # " is signless i1 or has matching shape",
+ Or<[TypeIsPred<condition, I1>,
+ AllShapesMatch<[condition, result]>.predicate]>>;
def SelectOp : Arith_Op<"select", [Pure,
AllTypesMatch<["true_value", "false_value", "result"]>,
- BooleanConditionOrMatchingShape<["condition", "result"]>,
+ BooleanConditionOrMatchingShape<"condition", "result">,
DeclareOpInterfaceMethods<InferIntRangeInterface>,
] # ElementwiseMappable.traits> {
let summary = "select operation";
More information about the Mlir-commits
mailing list