[llvm-branch-commits] [mlir] 1d00508 - [mlir][Shape] Make sure tensor_cast(constant_shape) folding uses the correct type
Benjamin Kramer via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Thu Dec 10 01:54:15 PST 2020
Author: Benjamin Kramer
Date: 2020-12-10T10:49:25+01:00
New Revision: 1d00508c5bf0d43203e11765ce84cdd6cf257856
URL: https://github.com/llvm/llvm-project/commit/1d00508c5bf0d43203e11765ce84cdd6cf257856
DIFF: https://github.com/llvm/llvm-project/commit/1d00508c5bf0d43203e11765ce84cdd6cf257856.diff
LOG: [mlir][Shape] Make sure tensor_cast(constant_shape) folding uses the correct type
This is still subtle, but I think the test cases are sufficient to show
that it works.
Differential Revision: https://reviews.llvm.org/D92927
Added:
Modified:
mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
mlir/test/Dialect/Shape/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td b/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
index 43c670a8582e..4e6d062a232f 100644
--- a/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
+++ b/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
@@ -32,5 +32,7 @@ def SizeToIndexToSizeCanonicalization : Pat<
   (Shape_IndexToSizeOp (Shape_SizeToIndexOp $arg)),
   (replaceWithValue $arg)>;
 
+// Fold tensor_cast(const_shape) to const_shape. This changes the type of
+// const_shape to the destination type of the cast.
 def TensorCastConstShape : Pat <
-  (TensorCastOp (Shape_ConstShapeOp:$c $ty)), (replaceWithValue $c)>;
+  (TensorCastOp (Shape_ConstShapeOp $arg)), (Shape_ConstShapeOp $arg)>;
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index 9cb01da75901..aa43f515f753 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -872,13 +872,24 @@ func @fold_assuming_all_single_element(%arg: tensor<?xindex>) {
 
 // -----
 
-// Fold tensor_cast of a const_shape to const_shape
-// CHECK-LABEL: @fold_tensor_cast_of_const_shape
-func @fold_tensor_cast_of_const_shape(%arg: tensor<?xindex>) {
+// Verify that tensor_cast folding uses the correct type
+// CHECK-LABEL: @fold_tensor_cast_of_const_shape_returned
+func @fold_tensor_cast_of_const_shape_returned(%arg: i1) -> tensor<1xindex> {
+  // CHECK: constant dense<2> : tensor<1xindex>
   // CHECK-NOT: tensor_cast
   %0 = shape.const_shape [2] : tensor<?xindex>
   %1 = tensor_cast %0 : tensor<?xindex> to tensor<1xindex>
-  %2 = shape.cstr_broadcastable %1, %0 : tensor<1xindex>, tensor<?xindex>
-  "consume.witness"(%2) : (!shape.witness) -> ()
-  return
+  return %1 : tensor<1xindex>
+}
+
+// -----
+
+// Verify that tensor_cast folding uses the correct type
+// CHECK-LABEL: @fold_tensor_cast_of_const_shape_returned_dynamic
+func @fold_tensor_cast_of_const_shape_returned_dynamic(%arg: i1) -> tensor<?xindex> {
+  // CHECK: shape.const_shape [2] : tensor<?xindex>
+  // CHECK-NOT: tensor_cast
+  %0 = shape.const_shape [2] : tensor<1xindex>
+  %1 = tensor_cast %0 : tensor<1xindex> to tensor<?xindex>
+  return %1 : tensor<?xindex>
 }