[Mlir-commits] [mlir] 85e4e9d - [mlir][arith] Further clean up select op definition (#93358)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Jun 3 19:29:43 PDT 2024


Author: Jakub Kuderski
Date: 2024-06-03T22:29:39-04:00
New Revision: 85e4e9d2150d62be578065cc22a37c2c7613ce88

URL: https://github.com/llvm/llvm-project/commit/85e4e9d2150d62be578065cc22a37c2c7613ce88
DIFF: https://github.com/llvm/llvm-project/commit/85e4e9d2150d62be578065cc22a37c2c7613ce88.diff

LOG: [mlir][arith] Further clean up select op definition (#93358)

* Improve the condition type requirement description ('scalar' ->
'signless i1') to match what is actually verified.
* Use the `I1` type predicate instead of `AnyBooleanTypeMatch`.

Related discussion:
https://github.com/llvm/llvm-project/pull/93351#issuecomment-2130453233.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
    mlir/test/Dialect/Arith/invalid.mlir
    mlir/test/IR/invalid-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index 81ed0f924a2e2..06fbdb7f2c4cb 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -1553,21 +1553,15 @@ def Arith_CmpFOp : Arith_CompareOp<"cmpf",
 // SelectOp
 //===----------------------------------------------------------------------===//
 
-class AnyBooleanTypeMatch<list<string> names> :
-    AnyMatchOperatorTrait<names, "$_self.getType().isSignlessInteger(1)",
-                          "scalar type">;
-
-class ScalarConditionOrMatchingShape<list<string> names> :
+class BooleanConditionOrMatchingShape<string condition, string result> :
     PredOpTrait<
-        !head(names) # " is scalar or has matching shape",
-        Or<[AnyBooleanTypeMatch<[!head(names)]>.predicate,
-            AllShapesMatch<names>.predicate]>> {
-  list<string> values = names;
-}
+      condition # " is signless i1 or has matching shape",
+      Or<[TypeIsPred<condition, I1>,
+          AllShapesMatch<[condition, result]>.predicate]>>;
 
 def SelectOp : Arith_Op<"select", [Pure,
     AllTypesMatch<["true_value", "false_value", "result"]>,
-    ScalarConditionOrMatchingShape<["condition", "result"]>,
+    BooleanConditionOrMatchingShape<"condition", "result">,
     DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRangesFromOptional"]>,
   ] # ElementwiseMappable.traits> {
   let summary = "select operation";

diff  --git a/mlir/test/Dialect/Arith/invalid.mlir b/mlir/test/Dialect/Arith/invalid.mlir
index 652aa738ad392..088da475e8eb4 100644
--- a/mlir/test/Dialect/Arith/invalid.mlir
+++ b/mlir/test/Dialect/Arith/invalid.mlir
@@ -832,7 +832,7 @@ func.func @func() {
 // -----
 
 func.func @disallow_zero_rank_tensor_with_ranked_tensor(%arg0 : tensor<i1>, %arg1 : tensor<2xi64>, %arg2 : tensor<2xi64>) -> tensor<2xi64> {
-  // expected-error @+1 {{'arith.select' op failed to verify that condition is scalar or has matching shape}}
+  // expected-error @+1 {{'arith.select' op failed to verify that condition is signless i1 or has matching shape}}
   %0 = arith.select %arg0, %arg1, %arg2 : tensor<i1>, tensor<2xi64>
   return %0 : tensor<2xi64>
 }
@@ -840,7 +840,7 @@ func.func @disallow_zero_rank_tensor_with_ranked_tensor(%arg0 : tensor<i1>, %arg
 // -----
 
 func.func @disallow_zero_rank_tensor_with_unranked_tensor(%arg0 : tensor<i1>, %arg1 : tensor<2x?xi64>, %arg2 : tensor<2x?xi64>) -> tensor<2x?xi64> {
-  // expected-error @+1 {{'arith.select' op failed to verify that condition is scalar or has matching shape}}
+  // expected-error @+1 {{'arith.select' op failed to verify that condition is signless i1 or has matching shape}}
   %0 = arith.select %arg0, %arg1, %arg2 : tensor<i1>, tensor<2x?xi64>
   return %0 : tensor<2x?xi64>
 }

diff  --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir
index 6ca7035022adb..6672c8840ffde 100644
--- a/mlir/test/IR/invalid-ops.mlir
+++ b/mlir/test/IR/invalid-ops.mlir
@@ -82,7 +82,7 @@ func.func @func_with_ops(i1, i32, i64) {
 
 func.func @func_with_ops(vector<12xi1>, vector<42xi32>, vector<42xi32>) {
 ^bb0(%cond : vector<12xi1>, %t : vector<42xi32>, %f : vector<42xi32>):
-  // expected-error@+1 {{'arith.select' op failed to verify that condition is scalar or has matching shape}}
+  // expected-error@+1 {{'arith.select' op failed to verify that condition is signless i1 or has matching shape}}
   %r = "arith.select"(%cond, %t, %f) : (vector<12xi1>, vector<42xi32>, vector<42xi32>) -> vector<42xi32>
 }
 
@@ -90,7 +90,7 @@ func.func @func_with_ops(vector<12xi1>, vector<42xi32>, vector<42xi32>) {
 
 func.func @func_with_ops(tensor<12xi1>, tensor<42xi32>, tensor<42xi32>) {
 ^bb0(%cond : tensor<12xi1>, %t : tensor<42xi32>, %f : tensor<42xi32>):
-  // expected-error@+1 {{'arith.select' op failed to verify that condition is scalar or has matching shape}}
+  // expected-error@+1 {{'arith.select' op failed to verify that condition is signless i1 or has matching shape}}
   %r = "arith.select"(%cond, %t, %f) : (tensor<12xi1>, tensor<42xi32>, tensor<42xi32>) -> tensor<42xi32>
 }
 


        


More information about the Mlir-commits mailing list