[Mlir-commits] [mlir] 2bb208d - [mlir] Don't allow dynamic extent tensor types for ConstShapeOp.

Adrian Kuegel llvmlistbot at llvm.org
Thu Oct 7 01:56:32 PDT 2021


Author: Adrian Kuegel
Date: 2021-10-07T10:56:16+02:00
New Revision: 2bb208ddfd700f0fdd3028f83eecd280a8d6f3b5

URL: https://github.com/llvm/llvm-project/commit/2bb208ddfd700f0fdd3028f83eecd280a8d6f3b5
DIFF: https://github.com/llvm/llvm-project/commit/2bb208ddfd700f0fdd3028f83eecd280a8d6f3b5.diff

LOG: [mlir] Don't allow dynamic extent tensor types for ConstShapeOp.

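ConstShapeOp has a constant shape, so its extent tensor result type can always be static (e.g. tensor<3xindex> rather than tensor<?xindex>).
We still allow it to have ShapeType, though. Accordingly, tensor.cast(const_shape) is now only folded into const_shape when the cast's result type is static.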

Differential Revision: https://reviews.llvm.org/D111139

Added: 
    

Modified: 
    mlir/lib/Dialect/Shape/IR/Shape.cpp
    mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
    mlir/test/Dialect/Shape/canonicalize.mlir
    mlir/test/Dialect/Shape/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 59a979c74a7d5..62b4e022408aa 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -827,14 +827,10 @@ bool mlir::shape::ConstShapeOp::isCompatibleReturnTypes(TypeRange l,
   Type lhs = l.front();
   Type rhs = r.front();
 
-  if (lhs == rhs)
-    return true;
-
   if (lhs.isa<ShapeType>() || rhs.isa<ShapeType>())
     // Shape type is compatible with all other valid return types.
     return true;
-
-  return succeeded(verifyCompatibleShapes(lhs, rhs));
+  return lhs == rhs;
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td b/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
index 8f847b1b28c56..7460dc5f3d33d 100644
--- a/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
+++ b/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
@@ -12,6 +12,10 @@ def HasSingleElement : Constraint<CPred< [{
   $0.size() == 1
 }]>>;
 
+def HasStaticShape : Constraint<CPred< [{
+  $0.getType().dyn_cast<ShapedType>().hasStaticShape()
+}]>>;
+
 // Canonicalization patterns.
 
 def AssumingAllOneOp : Pat<(Shape_AssumingAllOp $args),
@@ -37,4 +41,5 @@ def SizeToIndexToSizeCanonicalization : Pat<
 // Fold tensor.cast(const_shape) to const_shape. This changes the type of
 // const_shape to the destination type of the cast.
 def TensorCastConstShape : Pat <
-  (Tensor_CastOp (Shape_ConstShapeOp $arg)), (Shape_ConstShapeOp $arg)>;
+  (Tensor_CastOp:$res (Shape_ConstShapeOp $arg)), (Shape_ConstShapeOp $arg),
+  [(HasStaticShape $res)]>;

diff  --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index b32b7ac9052cb..b0c2181b5b7ba 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -1,10 +1,10 @@
 // RUN: mlir-opt -split-input-file -allow-unregistered-dialect -canonicalize %s | FileCheck %s
 
 // CHECK-LABEL: func @f
-func @f(%arg0: tensor<2x3x4xf32>) -> tensor<?xindex> {
-  // CHECK: shape.const_shape [2, 3, 4] : tensor<?xindex>
-  %0 = shape.shape_of %arg0 : tensor<2x3x4xf32> -> tensor<?xindex>
-  return %0 : tensor<?xindex>
+func @f(%arg0: tensor<2x3x4xf32>) -> tensor<3xindex> {
+  // CHECK: shape.const_shape [2, 3, 4] : tensor<3xindex>
+  %0 = shape.shape_of %arg0 : tensor<2x3x4xf32> -> tensor<3xindex>
+  return %0 : tensor<3xindex>
 }
 
 // -----
@@ -62,13 +62,13 @@ func @f() -> !shape.shape {
 
 // Basic case including extent tensors.
 // CHECK-LABEL: @broadcast
-func @broadcast() -> tensor<?xindex> {
-  // CHECK: shape.const_shape [7, 2] : tensor<?xindex>
-  %0 = shape.const_shape [1, 2] : tensor<?xindex>
-  %1 = shape.const_shape [7, 1] : tensor<?xindex>
+func @broadcast() -> tensor<2xindex> {
+  // CHECK: shape.const_shape [7, 2] : tensor<2xindex>
+  %0 = shape.const_shape [1, 2] : tensor<2xindex>
+  %1 = shape.const_shape [7, 1] : tensor<2xindex>
   %2 = shape.broadcast %0, %1
-      : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
-  return %2 : tensor<?xindex>
+      : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex>
+  return %2 : tensor<2xindex>
 }
 
 // -----
@@ -77,9 +77,9 @@ func @broadcast() -> tensor<?xindex> {
 // CHECK-LABEL: @broadcast
 func @broadcast() -> !shape.shape {
   // CHECK: shape.const_shape [7, 2] : !shape.shape
-  %0 = shape.const_shape [1, 2] : tensor<?xindex>
-  %1 = shape.const_shape [7, 1] : tensor<?xindex>
-  %2 = shape.broadcast %0, %1 : tensor<?xindex>, tensor<?xindex> -> !shape.shape
+  %0 = shape.const_shape [1, 2] : tensor<2xindex>
+  %1 = shape.const_shape [7, 1] : tensor<2xindex>
+  %2 = shape.broadcast %0, %1 : tensor<2xindex>, tensor<2xindex> -> !shape.shape
   return %2 : !shape.shape
 }
 
@@ -317,9 +317,9 @@ func @nonfoldable_num_elements(%shape : !shape.shape) -> !shape.size {
 // CHECK-LABEL: func @basic
 func @basic() -> index {
   // CHECK: constant 2 : index
-  %0 = shape.const_shape [0, 1, 2] : tensor<?xindex>
+  %0 = shape.const_shape [0, 1, 2] : tensor<3xindex>
   %c2 = constant 2 : index
-  %1 = shape.get_extent %0, %c2 : tensor<?xindex>, index -> index
+  %1 = shape.get_extent %0, %c2 : tensor<3xindex>, index -> index
   return %1 : index
 }
 
@@ -330,9 +330,9 @@ func @basic() -> index {
 func @out_of_bounds() -> index {
   // CHECK: shape.const_shape
   // CHECK: shape.get_extent
-  %0 = shape.const_shape [0, 1, 2] : tensor<?xindex>
+  %0 = shape.const_shape [0, 1, 2] : tensor<3xindex>
   %c3 = constant 3 : index
-  %1 = shape.get_extent %0, %c3 : tensor<?xindex>, index -> index
+  %1 = shape.get_extent %0, %c3 : tensor<3xindex>, index -> index
   return %1 : index
 }
 
@@ -559,12 +559,12 @@ func @f(%arg : !shape.shape) -> !shape.shape {
 
 // any can be replaced with a constant input if it has one.
 // CHECK-LABEL: func @f
-func @f(%arg : tensor<?xindex>) -> tensor<?xindex> {
-  // CHECK-NEXT: %[[CS:.*]] = shape.const_shape [2, 3, 4] : tensor<?xindex>
-  // CHECK-NEXT: return %[[CS]] : tensor<?xindex>
-  %0 = shape.const_shape [2, 3, 4] : tensor<?xindex>
-  %1 = shape.any %0, %arg : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
-  return %1 : tensor<?xindex>
+func @f(%arg : tensor<?xindex>) -> tensor<3xindex> {
+  // CHECK-NEXT: %[[CS:.*]] = shape.const_shape [2, 3, 4] : tensor<3xindex>
+  // CHECK-NEXT: return %[[CS]] : tensor<3xindex>
+  %0 = shape.const_shape [2, 3, 4] : tensor<3xindex>
+  %1 = shape.any %0, %arg : tensor<3xindex>, tensor<?xindex> -> tensor<3xindex>
+  return %1 : tensor<3xindex>
 }
 
 // -----
@@ -837,8 +837,8 @@ func @dont_fold_rank(%shape : !shape.shape) -> !shape.size {
 func @fold_rank() -> index {
   // CHECK: %[[RESULT:.*]] = constant 5 : index
   // CHECK: return %[[RESULT]] : index
-  %shape = shape.const_shape [3, 4, 5, 6, 7] : tensor<?xindex>
-  %rank = shape.rank %shape : tensor<?xindex> -> index
+  %shape = shape.const_shape [3, 4, 5, 6, 7] : tensor<5xindex>
+  %rank = shape.rank %shape : tensor<5xindex> -> index
   return %rank : index
 }
 
@@ -971,9 +971,9 @@ func @shape_eq_fold_1() -> i1 {
   // CHECK: %[[RESULT:.*]] = constant true
   // CHECK: return %[[RESULT]] : i1
   %a = shape.const_shape [1, 2, 3] : !shape.shape
-  %b = shape.const_shape [1, 2, 3] : tensor<?xindex>
-  %c = shape.const_shape [1, 2, 3] : tensor<?xindex>
-  %result = shape.shape_eq %a, %b, %c : !shape.shape, tensor<?xindex>, tensor<?xindex>
+  %b = shape.const_shape [1, 2, 3] : tensor<3xindex>
+  %c = shape.const_shape [1, 2, 3] : tensor<3xindex>
+  %result = shape.shape_eq %a, %b, %c : !shape.shape, tensor<3xindex>, tensor<3xindex>
   return %result : i1
 }
 
@@ -984,10 +984,10 @@ func @shape_eq_fold_1() -> i1 {
 func @shape_eq_fold_0() -> i1 {
   // CHECK: %[[RESULT:.*]] = constant false
   // CHECK: return %[[RESULT]] : i1
-  %a = shape.const_shape [1, 2, 3] : tensor<?xindex>
-  %b = shape.const_shape [4, 5, 6] : tensor<?xindex>
-  %c = shape.const_shape [4, 5, 6] : tensor<?xindex>
-  %result = shape.shape_eq %a, %b, %c : tensor<?xindex>, tensor<?xindex>, tensor<?xindex>
+  %a = shape.const_shape [1, 2, 3] : tensor<3xindex>
+  %b = shape.const_shape [4, 5, 6] : tensor<3xindex>
+  %c = shape.const_shape [4, 5, 6] : tensor<3xindex>
+  %result = shape.shape_eq %a, %b, %c : tensor<3xindex>, tensor<3xindex>, tensor<3xindex>
   return %result : i1
 }
 
@@ -1161,18 +1161,17 @@ func @fold_assuming_all_single_element(%arg: tensor<?xindex>) {
 func @fold_tensor.cast_of_const_shape_returned(%arg: i1) -> tensor<1xindex> {
   // CHECK: shape.const_shape [2] : tensor<1xindex>
   // CHECK-NOT: tensor.cast
-  %0 = shape.const_shape [2] : tensor<?xindex>
-  %1 = tensor.cast %0 : tensor<?xindex> to tensor<1xindex>
+  %0 = shape.const_shape [2] : tensor<1xindex>
+  %1 = tensor.cast %0 : tensor<1xindex> to tensor<1xindex>
   return %1 : tensor<1xindex>
 }
 
 // -----
 
-// Verify that tensor.cast folding uses the correct type
-// CHECK-LABEL: @fold_tensor.cast_of_const_shape_returned_dynamic
-func @fold_tensor.cast_of_const_shape_returned_dynamic(%arg: i1) -> tensor<?xindex> {
-  // CHECK: shape.const_shape [2] : tensor<?xindex>
-  // CHECK-NOT: tensor.cast
+// CHECK-LABEL: @dont_fold_tensor.cast_of_const_shape_returned_dynamic
+func @dont_fold_tensor.cast_of_const_shape_returned_dynamic(%arg: i1) -> tensor<?xindex> {
+  // CHECK: %[[CONST_SHAPE:.*]] = shape.const_shape [2] : tensor<1xindex>
+  // CHECK: tensor.cast %[[CONST_SHAPE]] : tensor<1xindex> to tensor<?xindex>
   %0 = shape.const_shape [2] : tensor<1xindex>
   %1 = tensor.cast %0 : tensor<1xindex> to tensor<?xindex>
   return %1 : tensor<?xindex>

diff  --git a/mlir/test/Dialect/Shape/ops.mlir b/mlir/test/Dialect/Shape/ops.mlir
index e7b501e8e2352..a41e7b5936e13 100644
--- a/mlir/test/Dialect/Shape/ops.mlir
+++ b/mlir/test/Dialect/Shape/ops.mlir
@@ -35,7 +35,6 @@ func @test_shape_num_elements_unknown() {
 
 func @const_shape() {
   %0 = shape.const_shape [1, 2, 3] : !shape.shape
-  %1 = shape.const_shape [4, 5, 6] : tensor<?xindex>
   %2 = shape.const_shape [4, 5, 6] : tensor<3xindex>
   return
 }
@@ -55,11 +54,11 @@ func @test_broadcast_fixed() {
   return
 }
 
-func @test_broadcast_extents() -> tensor<?xindex> {
-  %0 = shape.const_shape [10, 1, 57, 92] : tensor<?xindex>
-  %1 = shape.const_shape [4, 57, 92] : tensor<?xindex>
-  %2 = shape.broadcast %0, %1 : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
-  return %2 : tensor<?xindex>
+func @test_broadcast_extents() -> tensor<4xindex> {
+  %0 = shape.const_shape [10, 1, 57, 92] : tensor<4xindex>
+  %1 = shape.const_shape [4, 57, 92] : tensor<3xindex>
+  %2 = shape.broadcast %0, %1 : tensor<4xindex>, tensor<3xindex> -> tensor<4xindex>
+  return %2 : tensor<4xindex>
 }
 
 func @test_shape_any_fixed() {
@@ -89,7 +88,7 @@ func @test_shape_any_fixed_mismatch() {
 func @test_parse_const_shape() {
   %0 = shape.const_shape [] : !shape.shape
   %1 = shape.const_shape [1, 2, 3] : !shape.shape
-  %2 = shape.const_shape [1, 2, 3] : tensor<?xindex>
+  %2 = shape.const_shape [1, 2, 3] : tensor<3xindex>
   return
 }
 
@@ -222,9 +221,9 @@ func @any() {
   %0 = shape.const_shape [1, 2, 3] : !shape.shape
   %1 = shape.const_shape [4, 5, 6] : !shape.shape
   %2 = "shape.any"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
-  %3 = shape.const_shape [1, 2, 3] : tensor<?xindex>
-  %4 = shape.const_shape [4, 5, 6] : tensor<?xindex>
-  %5 = "shape.any"(%3, %4) : (tensor<?xindex>, tensor<?xindex>) -> tensor<?xindex>
+  %3 = shape.const_shape [1, 2, 3] : tensor<3xindex>
+  %4 = shape.const_shape [4, 5, 6] : tensor<3xindex>
+  %5 = "shape.any"(%3, %4) : (tensor<3xindex>, tensor<3xindex>) -> tensor<3xindex>
   return
 }
 


        


More information about the Mlir-commits mailing list