[llvm-branch-commits] [mlir] 1d00508 - [mlir][Shape] Make sure tensor_cast(constant_shape) folding uses the correct type

Benjamin Kramer via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Thu Dec 10 01:54:15 PST 2020


Author: Benjamin Kramer
Date: 2020-12-10T10:49:25+01:00
New Revision: 1d00508c5bf0d43203e11765ce84cdd6cf257856

URL: https://github.com/llvm/llvm-project/commit/1d00508c5bf0d43203e11765ce84cdd6cf257856
DIFF: https://github.com/llvm/llvm-project/commit/1d00508c5bf0d43203e11765ce84cdd6cf257856.diff

LOG: [mlir][Shape] Make sure tensor_cast(constant_shape) folding uses the correct type

This is still subtle, but I think the test cases are sufficient to show
that it works.

Differential Revision: https://reviews.llvm.org/D92927

Added: 
    

Modified: 
    mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
    mlir/test/Dialect/Shape/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td b/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
index 43c670a8582e..4e6d062a232f 100644
--- a/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
+++ b/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
@@ -32,5 +32,7 @@ def SizeToIndexToSizeCanonicalization : Pat<
   (Shape_IndexToSizeOp (Shape_SizeToIndexOp $arg)),
   (replaceWithValue $arg)>;
 
+// Fold tensor_cast(const_shape) to const_shape. This changes the type of
+// const_shape to the destination type of the cast.
 def TensorCastConstShape : Pat <
-  (TensorCastOp (Shape_ConstShapeOp:$c $ty)), (replaceWithValue $c)>;
+  (TensorCastOp (Shape_ConstShapeOp $arg)), (Shape_ConstShapeOp $arg)>;

diff  --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index 9cb01da75901..aa43f515f753 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -872,13 +872,24 @@ func @fold_assuming_all_single_element(%arg: tensor<?xindex>) {
 
 // -----
 
-// Fold tensor_cast of a const_shape to const_shape
-// CHECK-LABEL: @fold_tensor_cast_of_const_shape
-func @fold_tensor_cast_of_const_shape(%arg: tensor<?xindex>) {
+// Verify that tensor_cast folding uses the correct type
+// CHECK-LABEL: @fold_tensor_cast_of_const_shape_returned
+func @fold_tensor_cast_of_const_shape_returned(%arg: i1) -> tensor<1xindex> {
+  // CHECK: constant dense<2> : tensor<1xindex>
   // CHECK-NOT: tensor_cast
   %0 = shape.const_shape [2] : tensor<?xindex>
   %1 = tensor_cast %0 : tensor<?xindex> to tensor<1xindex>
-  %2 = shape.cstr_broadcastable %1, %0 : tensor<1xindex>, tensor<?xindex>
-  "consume.witness"(%2) : (!shape.witness) -> ()
-  return
+  return %1 : tensor<1xindex>
+}
+
+// -----
+
+// Verify that tensor_cast folding uses the correct type
+// CHECK-LABEL: @fold_tensor_cast_of_const_shape_returned_dynamic
+func @fold_tensor_cast_of_const_shape_returned_dynamic(%arg: i1) -> tensor<?xindex> {
+  // CHECK: shape.const_shape [2] : tensor<?xindex>
+  // CHECK-NOT: tensor_cast
+  %0 = shape.const_shape [2] : tensor<1xindex>
+  %1 = tensor_cast %0 : tensor<1xindex> to tensor<?xindex>
+  return %1 : tensor<?xindex>
 }
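
The subtlety the log message mentions is which type the folded constant ends up with. The old rewrite, (replaceWithValue $c), handed the cast's uses the const_shape value itself, so they saw its original type. The new pattern builds a fresh const_shape from the same extents, which, per the comment added in ShapeCanonicalization.td, takes the destination type of the cast. Below is a toy Python model of that difference. It is a sketch under stated assumptions, not MLIR code: `TensorType`, `ConstShape`, `cast_compatible`, `fold_old` and `fold_new` are hypothetical names, and only rank-1 index tensors are modeled.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


# Hypothetical toy model, not MLIR's API: a rank-1 `tensor<...xindex>` type whose
# single dimension is a static size, or None for a dynamic '?'.
@dataclass(frozen=True)
class TensorType:
    dim: Optional[int]

    def __str__(self) -> str:
        size = "?" if self.dim is None else str(self.dim)
        return f"tensor<{size}xindex>"


# Hypothetical toy stand-in for `shape.const_shape [extents] : type`.
@dataclass(frozen=True)
class ConstShape:
    extents: Tuple[int, ...]
    type: TensorType


def cast_compatible(src: TensorType, dst: TensorType) -> bool:
    """Toy rank-1 cast check: sizes are equal or at least one is dynamic."""
    return src.dim is None or dst.dim is None or src.dim == dst.dim


def fold_old(c: ConstShape, dst: TensorType) -> ConstShape:
    """Models `(replaceWithValue $c)`: uses of the cast are rewired to the
    original constant, which keeps its *source* type."""
    if not cast_compatible(c.type, dst):
        raise ValueError(f"invalid cast from {c.type} to {dst}")
    return c


def fold_new(c: ConstShape, dst: TensorType) -> ConstShape:
    """Models `(Shape_ConstShapeOp $arg)`: a fresh constant with the same
    extents, typed with the cast's *destination* type."""
    if not cast_compatible(c.type, dst):
        raise ValueError(f"invalid cast from {c.type} to {dst}")
    return ConstShape(c.extents, dst)


# First test: shape.const_shape [2] : tensor<?xindex> cast to tensor<1xindex>.
dynamic_src = ConstShape((2,), TensorType(None))
static_dst = TensorType(1)
print(fold_old(dynamic_src, static_dst).type)  # prints tensor<?xindex>, not the cast's type
print(fold_new(dynamic_src, static_dst).type)  # prints tensor<1xindex>

# Second test: shape.const_shape [2] : tensor<1xindex> cast to tensor<?xindex>.
static_src = ConstShape((2,), TensorType(1))
print(fold_new(static_src, TensorType(None)).type)  # prints tensor<?xindex>
```

The model stops at the retyped const_shape. In the first test the statically typed result is additionally expected to appear as `constant dense<2> : tensor<1xindex>`, which this toy does not cover.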

