[Mlir-commits] [mlir] 69db592 - [mlir][arith] Disallow zero ranked tensors for select's condition
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed May 31 23:43:48 PDT 2023
Author: Manas
Date: 2023-06-01T12:12:46+05:30
New Revision: 69db592f762ade86508826a7b3c9d5434c4837e2
URL: https://github.com/llvm/llvm-project/commit/69db592f762ade86508826a7b3c9d5434c4837e2
DIFF: https://github.com/llvm/llvm-project/commit/69db592f762ade86508826a7b3c9d5434c4837e2.diff
LOG: [mlir][arith] Disallow zero ranked tensors for select's condition
A zero-ranked tensor (e.g. tensor<i1>) used as the condition of arith.select
crashes the optimizer during bufferization. This patch constrains the
condition to be either a scalar i1 or a value whose shape matches the result's.
Reviewed By: mehdi_amini
Differential Revision: https://reviews.llvm.org/D151270
Added:
Modified:
mlir/docs/Bufferization.md
mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
mlir/include/mlir/IR/OpBase.td
mlir/test/Dialect/Arith/invalid.mlir
mlir/test/IR/invalid-ops.mlir
Removed:
################################################################################
diff --git a/mlir/docs/Bufferization.md b/mlir/docs/Bufferization.md
index ffa5f9e0efd35..f03d7bb877c9c 100644
--- a/mlir/docs/Bufferization.md
+++ b/mlir/docs/Bufferization.md
@@ -103,8 +103,8 @@ overwrite data that is still needed later in the program.
To simplify this problem, One-Shot Bufferize was designed for ops that are in
*destination-passing style*. For every tensor result, such ops have a tensor
-operand, who's buffer could be for storing the result of the op in the absence
-of other conflicts. We call such tensor operands the *destination*.
+operand, whose buffer could be utilized for storing the result of the op in the
+absence of other conflicts. We call such tensor operands the *destination*.
As an example, consider the following op: `%0 = tensor.insert %cst into
%t[%idx] : tensor<?xf32>`
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index 7b7b30e84ce2d..ee11510f89b79 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -1366,6 +1366,7 @@ def Arith_CmpFOp : Arith_CompareOp<"cmpf"> {
def SelectOp : Arith_Op<"select", [Pure,
AllTypesMatch<["true_value", "false_value", "result"]>,
+ ScalarConditionOrMatchingShape<["condition", "result"]>,
DeclareOpInterfaceMethods<InferIntRangeInterface>,
] # ElementwiseMappable.traits> {
let summary = "select operation";
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index a3e34f44e76fc..915cb8d588543 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -2548,6 +2548,12 @@ class ElementCount<string name> :
class ElementType<string name> : StrFunc<"getElementTypeOrSelf($" # name # ")">;
+class AnyPred<list<string> values> :
+ CPred<!if(!lt(!size(values), 1),
+ "false",
+ !foldl("(" # !head(values) # ")", !tail(values), acc, v,
+ acc # " || (" # v # ")"))>;
+
class AllMatchPred<list<string> values> :
CPred<!if(!lt(!size(values), 2),
"true",
@@ -2570,6 +2576,17 @@ class AllMatchSameOperatorTrait<list<string> names, string operator,
list<string> values = names;
}
+class AnyMatchOperatorPred<list<string> names, string operator> :
+ AnyPred<!foreach(n, names, !subst("$_self", "$" # n, operator))>;
+
+class AnyMatchOperatorTrait<list<string> names, string operator,
+ string summary> :
+ PredOpTrait<
+ "any of {" # !interleave(names, ", ") # "} has " # summary,
+ AnyMatchOperatorPred<names, operator>> {
+ list<string> values = names;
+}
+
class AllElementCountsMatch<list<string> names> :
AllMatchSameOperatorTrait<names, ElementCount<"_self">.result,
"element count">;
@@ -2695,4 +2712,16 @@ class TCopVTEtAreSameAt<list<int> indices> : CPred<
"[this](unsigned i) { return getElementTypeOrSelf(this->getOperand(i)); "
"}))">;
+class AnyScalarTypeMatch<list<string> names> :
+ AnyMatchOperatorTrait<names, "$_self.getType().isSignlessInteger(1)",
+ "scalar type">;
+
+class ScalarConditionOrMatchingShape<list<string> names> :
+ PredOpTrait<
+ !head(names) # " is scalar or has matching shape",
+ Or<[AnyScalarTypeMatch<[!head(names)]>.predicate,
+ AllShapesMatch<names>.predicate]>> {
+ list<string> values = names;
+}
+
#endif // OP_BASE
diff --git a/mlir/test/Dialect/Arith/invalid.mlir b/mlir/test/Dialect/Arith/invalid.mlir
index 729c86514b03b..9f131e5afab05 100644
--- a/mlir/test/Dialect/Arith/invalid.mlir
+++ b/mlir/test/Dialect/Arith/invalid.mlir
@@ -753,3 +753,19 @@ func.func @func() {
%x = arith.constant 1 : i32
}
+
+// -----
+
+func.func @disallow_zero_rank_tensor_with_ranked_tensor(%arg0 : tensor<i1>, %arg1 : tensor<2xi64>, %arg2 : tensor<2xi64>) -> tensor<2xi64> {
+ // expected-error @+1 {{'arith.select' op failed to verify that condition is scalar or has matching shape}}
+ %0 = arith.select %arg0, %arg1, %arg2 : tensor<i1>, tensor<2xi64>
+ return %0 : tensor<2xi64>
+}
+
+// -----
+
+func.func @disallow_zero_rank_tensor_with_unranked_tensor(%arg0 : tensor<i1>, %arg1 : tensor<2x?xi64>, %arg2 : tensor<2x?xi64>) -> tensor<2x?xi64> {
+ // expected-error @+1 {{'arith.select' op failed to verify that condition is scalar or has matching shape}}
+ %0 = arith.select %arg0, %arg1, %arg2 : tensor<i1>, tensor<2x?xi64>
+ return %0 : tensor<2x?xi64>
+}
diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir
index 85aae41daac7d..6ca7035022adb 100644
--- a/mlir/test/IR/invalid-ops.mlir
+++ b/mlir/test/IR/invalid-ops.mlir
@@ -82,7 +82,7 @@ func.func @func_with_ops(i1, i32, i64) {
func.func @func_with_ops(vector<12xi1>, vector<42xi32>, vector<42xi32>) {
^bb0(%cond : vector<12xi1>, %t : vector<42xi32>, %f : vector<42xi32>):
- // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type}}
+ // expected-error@+1 {{'arith.select' op failed to verify that condition is scalar or has matching shape}}
%r = "arith.select"(%cond, %t, %f) : (vector<12xi1>, vector<42xi32>, vector<42xi32>) -> vector<42xi32>
}
@@ -90,7 +90,7 @@ func.func @func_with_ops(vector<12xi1>, vector<42xi32>, vector<42xi32>) {
func.func @func_with_ops(tensor<12xi1>, tensor<42xi32>, tensor<42xi32>) {
^bb0(%cond : tensor<12xi1>, %t : tensor<42xi32>, %f : tensor<42xi32>):
- // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type}}
+ // expected-error@+1 {{'arith.select' op failed to verify that condition is scalar or has matching shape}}
%r = "arith.select"(%cond, %t, %f) : (tensor<12xi1>, tensor<42xi32>, tensor<42xi32>) -> tensor<42xi32>
}
More information about the Mlir-commits
mailing list