[Mlir-commits] [mlir] d009f6e - [mlir] Convert ConstShapeOp to a static tensor type.
Adrian Kuegel
llvmlistbot at llvm.org
Tue Oct 5 03:14:55 PDT 2021
Author: Adrian Kuegel
Date: 2021-10-05T12:14:43+02:00
New Revision: d009f6e51cae7e7a155d083c9170723554f2e776
URL: https://github.com/llvm/llvm-project/commit/d009f6e51cae7e7a155d083c9170723554f2e776
DIFF: https://github.com/llvm/llvm-project/commit/d009f6e51cae7e7a155d083c9170723554f2e776.diff
LOG: [mlir] Convert ConstShapeOp to a static tensor type.
ConstShapeOp's extents are known at compile time, so its lowered result should also have a static tensor type.
Differential Revision: https://reviews.llvm.org/D111127
Added:
Modified:
mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
index f622d5ed89e85..3aef4bbcf3503 100644
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -191,7 +191,7 @@ LogicalResult ConstShapeOpConverter::matchAndRewrite(
Type indexTy = rewriter.getIndexType();
Value tensor =
rewriter.create<tensor::FromElementsOp>(loc, indexTy, extentOperands);
- Type resultTy = RankedTensorType::get({ShapedType::kDynamicSize}, indexTy);
+ Type resultTy = RankedTensorType::get({op.shape().size()}, indexTy);
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultTy, tensor);
return success();
}
diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
index bc551f85d0294..ccc8d56a90c98 100644
--- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
@@ -89,29 +89,29 @@ func @get_extent_from_extent_tensor(%extents : tensor<?xindex>, %idx : index)
// Lower `const_shape` to `tensor.from_elements`.
// CHECK-LABEL: @const_shape
-// CHECK-SAME: () -> tensor<?xindex>
-func @const_shape() -> tensor<?xindex> {
+// CHECK-SAME: () -> tensor<3xindex>
+func @const_shape() -> tensor<3xindex> {
// CHECK: %[[C1:.*]] = constant 1 : index
// CHECK: %[[C2:.*]] = constant 2 : index
// CHECK: %[[C3:.*]] = constant 3 : index
// CHECK: %[[TENSOR3:.*]] = tensor.from_elements %[[C1]], %[[C2]], %[[C3]]
- // CHECK: %[[RESULT:.*]] = tensor.cast %[[TENSOR3]] : tensor<3xindex> to tensor<?xindex>
- // CHECK: return %[[RESULT]] : tensor<?xindex>
- %shape = shape.const_shape [1, 2, 3] : tensor<?xindex>
- return %shape : tensor<?xindex>
+ // CHECK: %[[RESULT:.*]] = tensor.cast %[[TENSOR3]] : tensor<3xindex> to tensor<3xindex>
+ // CHECK: return %[[RESULT]] : tensor<3xindex>
+ %shape = shape.const_shape [1, 2, 3] : tensor<3xindex>
+ return %shape : tensor<3xindex>
}
// -----
// Lower `const_shape` in the case of rank 0.
// CHECK-LABEL: func @const_shape_zero_elements
-// CHECK-SAME: () -> tensor<?xindex>
-func @const_shape_zero_elements() -> tensor<?xindex> {
+// CHECK-SAME: () -> tensor<0xindex>
+func @const_shape_zero_elements() -> tensor<0xindex> {
// CHECK: %[[TENSOR:.*]] = tensor.from_elements : tensor<0xindex>
- // CHECK: %[[RESULT:.*]] = tensor.cast %[[TENSOR]] : tensor<0xindex> to tensor<?xindex>
- // CHECK: return %[[RESULT]] : tensor<?xindex>
- %shape = shape.const_shape [] : tensor<?xindex>
- return %shape : tensor<?xindex>
+ // CHECK: %[[RESULT:.*]] = tensor.cast %[[TENSOR]] : tensor<0xindex> to tensor<0xindex>
+ // CHECK: return %[[RESULT]] : tensor<0xindex>
+ %shape = shape.const_shape [] : tensor<0xindex>
+ return %shape : tensor<0xindex>
}
// -----
More information about the Mlir-commits mailing list