[Mlir-commits] [mlir] 2bb208d - [mlir] Don't allow dynamic extent tensor types for ConstShapeOp.
Adrian Kuegel
llvmlistbot at llvm.org
Thu Oct 7 01:56:32 PDT 2021
Author: Adrian Kuegel
Date: 2021-10-07T10:56:16+02:00
New Revision: 2bb208ddfd700f0fdd3028f83eecd280a8d6f3b5
URL: https://github.com/llvm/llvm-project/commit/2bb208ddfd700f0fdd3028f83eecd280a8d6f3b5
DIFF: https://github.com/llvm/llvm-project/commit/2bb208ddfd700f0fdd3028f83eecd280a8d6f3b5.diff
LOG: [mlir] Don't allow dynamic extent tensor types for ConstShapeOp.
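ConstShapeOp has a constant shape, so its result type can always be a statically shaped extent tensor (e.g. tensor<3xindex> rather than tensor<?xindex>).
We still allow it to have ShapeType (!shape.shape), though.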
Differential Revision: https://reviews.llvm.org/D111139
Added:
Modified:
mlir/lib/Dialect/Shape/IR/Shape.cpp
mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
mlir/test/Dialect/Shape/canonicalize.mlir
mlir/test/Dialect/Shape/ops.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 59a979c74a7d5..62b4e022408aa 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -827,14 +827,10 @@ bool mlir::shape::ConstShapeOp::isCompatibleReturnTypes(TypeRange l,
Type lhs = l.front();
Type rhs = r.front();
- if (lhs == rhs)
- return true;
-
if (lhs.isa<ShapeType>() || rhs.isa<ShapeType>())
// Shape type is compatible with all other valid return types.
return true;
-
- return succeeded(verifyCompatibleShapes(lhs, rhs));
+ return lhs == rhs;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td b/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
index 8f847b1b28c56..7460dc5f3d33d 100644
--- a/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
+++ b/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
@@ -12,6 +12,10 @@ def HasSingleElement : Constraint<CPred< [{
$0.size() == 1
}]>>;
+def HasStaticShape : Constraint<CPred< [{
+ $0.getType().dyn_cast<ShapedType>().hasStaticShape()
+}]>>;
+
// Canonicalization patterns.
def AssumingAllOneOp : Pat<(Shape_AssumingAllOp $args),
@@ -37,4 +41,5 @@ def SizeToIndexToSizeCanonicalization : Pat<
// Fold tensor.cast(const_shape) to const_shape. This changes the type of
// const_shape to the destination type of the cast.
def TensorCastConstShape : Pat <
- (Tensor_CastOp (Shape_ConstShapeOp $arg)), (Shape_ConstShapeOp $arg)>;
+ (Tensor_CastOp:$res (Shape_ConstShapeOp $arg)), (Shape_ConstShapeOp $arg),
+ [(HasStaticShape $res)]>;
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index b32b7ac9052cb..b0c2181b5b7ba 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -1,10 +1,10 @@
// RUN: mlir-opt -split-input-file -allow-unregistered-dialect -canonicalize %s | FileCheck %s
// CHECK-LABEL: func @f
-func @f(%arg0: tensor<2x3x4xf32>) -> tensor<?xindex> {
- // CHECK: shape.const_shape [2, 3, 4] : tensor<?xindex>
- %0 = shape.shape_of %arg0 : tensor<2x3x4xf32> -> tensor<?xindex>
- return %0 : tensor<?xindex>
+func @f(%arg0: tensor<2x3x4xf32>) -> tensor<3xindex> {
+ // CHECK: shape.const_shape [2, 3, 4] : tensor<3xindex>
+ %0 = shape.shape_of %arg0 : tensor<2x3x4xf32> -> tensor<3xindex>
+ return %0 : tensor<3xindex>
}
// -----
@@ -62,13 +62,13 @@ func @f() -> !shape.shape {
// Basic case including extent tensors.
// CHECK-LABEL: @broadcast
-func @broadcast() -> tensor<?xindex> {
- // CHECK: shape.const_shape [7, 2] : tensor<?xindex>
- %0 = shape.const_shape [1, 2] : tensor<?xindex>
- %1 = shape.const_shape [7, 1] : tensor<?xindex>
+func @broadcast() -> tensor<2xindex> {
+ // CHECK: shape.const_shape [7, 2] : tensor<2xindex>
+ %0 = shape.const_shape [1, 2] : tensor<2xindex>
+ %1 = shape.const_shape [7, 1] : tensor<2xindex>
%2 = shape.broadcast %0, %1
- : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
- return %2 : tensor<?xindex>
+ : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex>
+ return %2 : tensor<2xindex>
}
// -----
@@ -77,9 +77,9 @@ func @broadcast() -> tensor<?xindex> {
// CHECK-LABEL: @broadcast
func @broadcast() -> !shape.shape {
// CHECK: shape.const_shape [7, 2] : !shape.shape
- %0 = shape.const_shape [1, 2] : tensor<?xindex>
- %1 = shape.const_shape [7, 1] : tensor<?xindex>
- %2 = shape.broadcast %0, %1 : tensor<?xindex>, tensor<?xindex> -> !shape.shape
+ %0 = shape.const_shape [1, 2] : tensor<2xindex>
+ %1 = shape.const_shape [7, 1] : tensor<2xindex>
+ %2 = shape.broadcast %0, %1 : tensor<2xindex>, tensor<2xindex> -> !shape.shape
return %2 : !shape.shape
}
@@ -317,9 +317,9 @@ func @nonfoldable_num_elements(%shape : !shape.shape) -> !shape.size {
// CHECK-LABEL: func @basic
func @basic() -> index {
// CHECK: constant 2 : index
- %0 = shape.const_shape [0, 1, 2] : tensor<?xindex>
+ %0 = shape.const_shape [0, 1, 2] : tensor<3xindex>
%c2 = constant 2 : index
- %1 = shape.get_extent %0, %c2 : tensor<?xindex>, index -> index
+ %1 = shape.get_extent %0, %c2 : tensor<3xindex>, index -> index
return %1 : index
}
@@ -330,9 +330,9 @@ func @basic() -> index {
func @out_of_bounds() -> index {
// CHECK: shape.const_shape
// CHECK: shape.get_extent
- %0 = shape.const_shape [0, 1, 2] : tensor<?xindex>
+ %0 = shape.const_shape [0, 1, 2] : tensor<3xindex>
%c3 = constant 3 : index
- %1 = shape.get_extent %0, %c3 : tensor<?xindex>, index -> index
+ %1 = shape.get_extent %0, %c3 : tensor<3xindex>, index -> index
return %1 : index
}
@@ -559,12 +559,12 @@ func @f(%arg : !shape.shape) -> !shape.shape {
// any can be replaced with a constant input if it has one.
// CHECK-LABEL: func @f
-func @f(%arg : tensor<?xindex>) -> tensor<?xindex> {
- // CHECK-NEXT: %[[CS:.*]] = shape.const_shape [2, 3, 4] : tensor<?xindex>
- // CHECK-NEXT: return %[[CS]] : tensor<?xindex>
- %0 = shape.const_shape [2, 3, 4] : tensor<?xindex>
- %1 = shape.any %0, %arg : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
- return %1 : tensor<?xindex>
+func @f(%arg : tensor<?xindex>) -> tensor<3xindex> {
+ // CHECK-NEXT: %[[CS:.*]] = shape.const_shape [2, 3, 4] : tensor<3xindex>
+ // CHECK-NEXT: return %[[CS]] : tensor<3xindex>
+ %0 = shape.const_shape [2, 3, 4] : tensor<3xindex>
+ %1 = shape.any %0, %arg : tensor<3xindex>, tensor<?xindex> -> tensor<3xindex>
+ return %1 : tensor<3xindex>
}
// -----
@@ -837,8 +837,8 @@ func @dont_fold_rank(%shape : !shape.shape) -> !shape.size {
func @fold_rank() -> index {
// CHECK: %[[RESULT:.*]] = constant 5 : index
// CHECK: return %[[RESULT]] : index
- %shape = shape.const_shape [3, 4, 5, 6, 7] : tensor<?xindex>
- %rank = shape.rank %shape : tensor<?xindex> -> index
+ %shape = shape.const_shape [3, 4, 5, 6, 7] : tensor<5xindex>
+ %rank = shape.rank %shape : tensor<5xindex> -> index
return %rank : index
}
@@ -971,9 +971,9 @@ func @shape_eq_fold_1() -> i1 {
// CHECK: %[[RESULT:.*]] = constant true
// CHECK: return %[[RESULT]] : i1
%a = shape.const_shape [1, 2, 3] : !shape.shape
- %b = shape.const_shape [1, 2, 3] : tensor<?xindex>
- %c = shape.const_shape [1, 2, 3] : tensor<?xindex>
- %result = shape.shape_eq %a, %b, %c : !shape.shape, tensor<?xindex>, tensor<?xindex>
+ %b = shape.const_shape [1, 2, 3] : tensor<3xindex>
+ %c = shape.const_shape [1, 2, 3] : tensor<3xindex>
+ %result = shape.shape_eq %a, %b, %c : !shape.shape, tensor<3xindex>, tensor<3xindex>
return %result : i1
}
@@ -984,10 +984,10 @@ func @shape_eq_fold_1() -> i1 {
func @shape_eq_fold_0() -> i1 {
// CHECK: %[[RESULT:.*]] = constant false
// CHECK: return %[[RESULT]] : i1
- %a = shape.const_shape [1, 2, 3] : tensor<?xindex>
- %b = shape.const_shape [4, 5, 6] : tensor<?xindex>
- %c = shape.const_shape [4, 5, 6] : tensor<?xindex>
- %result = shape.shape_eq %a, %b, %c : tensor<?xindex>, tensor<?xindex>, tensor<?xindex>
+ %a = shape.const_shape [1, 2, 3] : tensor<3xindex>
+ %b = shape.const_shape [4, 5, 6] : tensor<3xindex>
+ %c = shape.const_shape [4, 5, 6] : tensor<3xindex>
+ %result = shape.shape_eq %a, %b, %c : tensor<3xindex>, tensor<3xindex>, tensor<3xindex>
return %result : i1
}
@@ -1161,18 +1161,17 @@ func @fold_assuming_all_single_element(%arg: tensor<?xindex>) {
func @fold_tensor.cast_of_const_shape_returned(%arg: i1) -> tensor<1xindex> {
// CHECK: shape.const_shape [2] : tensor<1xindex>
// CHECK-NOT: tensor.cast
- %0 = shape.const_shape [2] : tensor<?xindex>
- %1 = tensor.cast %0 : tensor<?xindex> to tensor<1xindex>
+ %0 = shape.const_shape [2] : tensor<1xindex>
+ %1 = tensor.cast %0 : tensor<1xindex> to tensor<1xindex>
return %1 : tensor<1xindex>
}
// -----
-// Verify that tensor.cast folding uses the correct type
-// CHECK-LABEL: @fold_tensor.cast_of_const_shape_returned_dynamic
-func @fold_tensor.cast_of_const_shape_returned_dynamic(%arg: i1) -> tensor<?xindex> {
- // CHECK: shape.const_shape [2] : tensor<?xindex>
- // CHECK-NOT: tensor.cast
+// CHECK-LABEL: @dont_fold_tensor.cast_of_const_shape_returned_dynamic
+func @dont_fold_tensor.cast_of_const_shape_returned_dynamic(%arg: i1) -> tensor<?xindex> {
+ // CHECK: %[[CONST_SHAPE:.*]] = shape.const_shape [2] : tensor<1xindex>
+ // CHECK: tensor.cast %[[CONST_SHAPE]] : tensor<1xindex> to tensor<?xindex>
%0 = shape.const_shape [2] : tensor<1xindex>
%1 = tensor.cast %0 : tensor<1xindex> to tensor<?xindex>
return %1 : tensor<?xindex>
diff --git a/mlir/test/Dialect/Shape/ops.mlir b/mlir/test/Dialect/Shape/ops.mlir
index e7b501e8e2352..a41e7b5936e13 100644
--- a/mlir/test/Dialect/Shape/ops.mlir
+++ b/mlir/test/Dialect/Shape/ops.mlir
@@ -35,7 +35,6 @@ func @test_shape_num_elements_unknown() {
func @const_shape() {
%0 = shape.const_shape [1, 2, 3] : !shape.shape
- %1 = shape.const_shape [4, 5, 6] : tensor<?xindex>
%2 = shape.const_shape [4, 5, 6] : tensor<3xindex>
return
}
@@ -55,11 +54,11 @@ func @test_broadcast_fixed() {
return
}
-func @test_broadcast_extents() -> tensor<?xindex> {
- %0 = shape.const_shape [10, 1, 57, 92] : tensor<?xindex>
- %1 = shape.const_shape [4, 57, 92] : tensor<?xindex>
- %2 = shape.broadcast %0, %1 : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
- return %2 : tensor<?xindex>
+func @test_broadcast_extents() -> tensor<4xindex> {
+ %0 = shape.const_shape [10, 1, 57, 92] : tensor<4xindex>
+ %1 = shape.const_shape [4, 57, 92] : tensor<3xindex>
+ %2 = shape.broadcast %0, %1 : tensor<4xindex>, tensor<3xindex> -> tensor<4xindex>
+ return %2 : tensor<4xindex>
}
func @test_shape_any_fixed() {
@@ -89,7 +88,7 @@ func @test_shape_any_fixed_mismatch() {
func @test_parse_const_shape() {
%0 = shape.const_shape [] : !shape.shape
%1 = shape.const_shape [1, 2, 3] : !shape.shape
- %2 = shape.const_shape [1, 2, 3] : tensor<?xindex>
+ %2 = shape.const_shape [1, 2, 3] : tensor<3xindex>
return
}
@@ -222,9 +221,9 @@ func @any() {
%0 = shape.const_shape [1, 2, 3] : !shape.shape
%1 = shape.const_shape [4, 5, 6] : !shape.shape
%2 = "shape.any"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
- %3 = shape.const_shape [1, 2, 3] : tensor<?xindex>
- %4 = shape.const_shape [4, 5, 6] : tensor<?xindex>
- %5 = "shape.any"(%3, %4) : (tensor<?xindex>, tensor<?xindex>) -> tensor<?xindex>
+ %3 = shape.const_shape [1, 2, 3] : tensor<3xindex>
+ %4 = shape.const_shape [4, 5, 6] : tensor<3xindex>
+ %5 = "shape.any"(%3, %4) : (tensor<3xindex>, tensor<3xindex>) -> tensor<3xindex>
return
}
More information about the Mlir-commits mailing list