[Mlir-commits] [mlir] 69db592 - [mlir][arith] Disallow zero ranked tensors for select's condition

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed May 31 23:43:48 PDT 2023


Author: Manas
Date: 2023-06-01T12:12:46+05:30
New Revision: 69db592f762ade86508826a7b3c9d5434c4837e2

URL: https://github.com/llvm/llvm-project/commit/69db592f762ade86508826a7b3c9d5434c4837e2
DIFF: https://github.com/llvm/llvm-project/commit/69db592f762ade86508826a7b3c9d5434c4837e2.diff

LOG: [mlir][arith] Disallow zero ranked tensors for select's condition

A zero-ranked tensor (e.g. tensor<i1>), when used as arith.select's condition,
crashes the optimizer during bufferization. This patch constrains the
condition to be either a scalar or of the same shape as the result.

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D151270

Added: 
    

Modified: 
    mlir/docs/Bufferization.md
    mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
    mlir/include/mlir/IR/OpBase.td
    mlir/test/Dialect/Arith/invalid.mlir
    mlir/test/IR/invalid-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/docs/Bufferization.md b/mlir/docs/Bufferization.md
index ffa5f9e0efd35..f03d7bb877c9c 100644
--- a/mlir/docs/Bufferization.md
+++ b/mlir/docs/Bufferization.md
@@ -103,8 +103,8 @@ overwrite data that is still needed later in the program.
 
 To simplify this problem, One-Shot Bufferize was designed for ops that are in
 *destination-passing style*. For every tensor result, such ops have a tensor
-operand, who's buffer could be for storing the result of the op in the absence
-of other conflicts. We call such tensor operands the *destination*.
+operand, whose buffer could be utilized for storing the result of the op in the
+absence of other conflicts. We call such tensor operands the *destination*.
 
 As an example, consider the following op: `%0 = tensor.insert %cst into
 %t[%idx] : tensor<?xf32>`

diff  --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index 7b7b30e84ce2d..ee11510f89b79 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -1366,6 +1366,7 @@ def Arith_CmpFOp : Arith_CompareOp<"cmpf"> {
 
 def SelectOp : Arith_Op<"select", [Pure,
     AllTypesMatch<["true_value", "false_value", "result"]>,
+    ScalarConditionOrMatchingShape<["condition", "result"]>,
     DeclareOpInterfaceMethods<InferIntRangeInterface>,
   ] # ElementwiseMappable.traits> {
   let summary = "select operation";

diff  --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index a3e34f44e76fc..915cb8d588543 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -2548,6 +2548,12 @@ class ElementCount<string name> :
 
 class ElementType<string name> : StrFunc<"getElementTypeOrSelf($" # name # ")">;
 
+class AnyPred<list<string> values> :
+  CPred<!if(!lt(!size(values), 1),
+            "false",
+            !foldl("(" # !head(values) # ")", !tail(values), acc, v,
+                   acc # " || (" # v # ")"))>;
+
 class AllMatchPred<list<string> values> :
   CPred<!if(!lt(!size(values), 2),
             "true",
@@ -2570,6 +2576,17 @@ class AllMatchSameOperatorTrait<list<string> names, string operator,
   list<string> values = names;
 }
 
+class AnyMatchOperatorPred<list<string> names, string operator> :
+    AnyPred<!foreach(n, names, !subst("$_self", "$" # n, operator))>;
+
+class AnyMatchOperatorTrait<list<string> names, string operator,
+                            string summary> :
+    PredOpTrait<
+        "any of {" # !interleave(names, ", ") # "} has " # summary,
+        AnyMatchOperatorPred<names, operator>> {
+  list<string> values = names;
+}
+
 class AllElementCountsMatch<list<string> names> :
     AllMatchSameOperatorTrait<names, ElementCount<"_self">.result,
                               "element count">;
@@ -2695,4 +2712,16 @@ class TCopVTEtAreSameAt<list<int> indices> : CPred<
       "[this](unsigned i) { return getElementTypeOrSelf(this->getOperand(i)); "
       "}))">;
 
+class AnyScalarTypeMatch<list<string> names> :
+    AnyMatchOperatorTrait<names, "$_self.getType().isSignlessInteger(1)",
+                          "scalar type">;
+
+class ScalarConditionOrMatchingShape<list<string> names> :
+    PredOpTrait<
+        !head(names) # " is scalar or has matching shape",
+        Or<[AnyScalarTypeMatch<[!head(names)]>.predicate,
+            AllShapesMatch<names>.predicate]>> {
+  list<string> values = names;
+}
+
 #endif // OP_BASE

diff  --git a/mlir/test/Dialect/Arith/invalid.mlir b/mlir/test/Dialect/Arith/invalid.mlir
index 729c86514b03b..9f131e5afab05 100644
--- a/mlir/test/Dialect/Arith/invalid.mlir
+++ b/mlir/test/Dialect/Arith/invalid.mlir
@@ -753,3 +753,19 @@ func.func @func() {
 
   %x = arith.constant 1 : i32
 }
+
+// -----
+
+func.func @disallow_zero_rank_tensor_with_ranked_tensor(%arg0 : tensor<i1>, %arg1 : tensor<2xi64>, %arg2 : tensor<2xi64>) -> tensor<2xi64> {
+  // expected-error @+1 {{'arith.select' op failed to verify that condition is scalar or has matching shape}}
+  %0 = arith.select %arg0, %arg1, %arg2 : tensor<i1>, tensor<2xi64>
+  return %0 : tensor<2xi64>
+}
+
+// -----
+
+func.func @disallow_zero_rank_tensor_with_unranked_tensor(%arg0 : tensor<i1>, %arg1 : tensor<2x?xi64>, %arg2 : tensor<2x?xi64>) -> tensor<2x?xi64> {
+  // expected-error @+1 {{'arith.select' op failed to verify that condition is scalar or has matching shape}}
+  %0 = arith.select %arg0, %arg1, %arg2 : tensor<i1>, tensor<2x?xi64>
+  return %0 : tensor<2x?xi64>
+}

diff  --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir
index 85aae41daac7d..6ca7035022adb 100644
--- a/mlir/test/IR/invalid-ops.mlir
+++ b/mlir/test/IR/invalid-ops.mlir
@@ -82,7 +82,7 @@ func.func @func_with_ops(i1, i32, i64) {
 
 func.func @func_with_ops(vector<12xi1>, vector<42xi32>, vector<42xi32>) {
 ^bb0(%cond : vector<12xi1>, %t : vector<42xi32>, %f : vector<42xi32>):
-  // expected-error at +1 {{all non-scalar operands/results must have the same shape and base type}}
+  // expected-error at +1 {{'arith.select' op failed to verify that condition is scalar or has matching shape}}
   %r = "arith.select"(%cond, %t, %f) : (vector<12xi1>, vector<42xi32>, vector<42xi32>) -> vector<42xi32>
 }
 
@@ -90,7 +90,7 @@ func.func @func_with_ops(vector<12xi1>, vector<42xi32>, vector<42xi32>) {
 
 func.func @func_with_ops(tensor<12xi1>, tensor<42xi32>, tensor<42xi32>) {
 ^bb0(%cond : tensor<12xi1>, %t : tensor<42xi32>, %f : tensor<42xi32>):
-  // expected-error at +1 {{all non-scalar operands/results must have the same shape and base type}}
+  // expected-error at +1 {{'arith.select' op failed to verify that condition is scalar or has matching shape}}
   %r = "arith.select"(%cond, %t, %f) : (tensor<12xi1>, tensor<42xi32>, tensor<42xi32>) -> tensor<42xi32>
 }
 


        


More information about the Mlir-commits mailing list