[Mlir-commits] [mlir] d009f6e - [mlir] Convert ConstShapeOp to a static tensor type.

Adrian Kuegel llvmlistbot at llvm.org
Tue Oct 5 03:14:55 PDT 2021


Author: Adrian Kuegel
Date: 2021-10-05T12:14:43+02:00
New Revision: d009f6e51cae7e7a155d083c9170723554f2e776

URL: https://github.com/llvm/llvm-project/commit/d009f6e51cae7e7a155d083c9170723554f2e776
DIFF: https://github.com/llvm/llvm-project/commit/d009f6e51cae7e7a155d083c9170723554f2e776.diff

LOG: [mlir] Convert ConstShapeOp to a static tensor type.

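ConstShapeOp knows its shape at compile time, so its lowering should also produce a statically shaped tensor type (e.g. `tensor<3xindex>`) instead of the dynamic `tensor<?xindex>`.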

Differential Revision: https://reviews.llvm.org/D111127

Added: 
    

Modified: 
    mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
    mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
index f622d5ed89e85..3aef4bbcf3503 100644
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -191,7 +191,7 @@ LogicalResult ConstShapeOpConverter::matchAndRewrite(
   Type indexTy = rewriter.getIndexType();
   Value tensor =
       rewriter.create<tensor::FromElementsOp>(loc, indexTy, extentOperands);
-  Type resultTy = RankedTensorType::get({ShapedType::kDynamicSize}, indexTy);
+  Type resultTy = RankedTensorType::get({op.shape().size()}, indexTy);
   rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultTy, tensor);
   return success();
 }

diff  --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
index bc551f85d0294..ccc8d56a90c98 100644
--- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
@@ -89,29 +89,29 @@ func @get_extent_from_extent_tensor(%extents : tensor<?xindex>, %idx : index)
 
 // Lower `const_shape` to `tensor.from_elements`.
 // CHECK-LABEL: @const_shape
-// CHECK-SAME: () -> tensor<?xindex>
-func @const_shape() -> tensor<?xindex> {
+// CHECK-SAME: () -> tensor<3xindex>
+func @const_shape() -> tensor<3xindex> {
   // CHECK: %[[C1:.*]] = constant 1 : index
   // CHECK: %[[C2:.*]] = constant 2 : index
   // CHECK: %[[C3:.*]] = constant 3 : index
   // CHECK: %[[TENSOR3:.*]] = tensor.from_elements %[[C1]], %[[C2]], %[[C3]]
-  // CHECK: %[[RESULT:.*]] = tensor.cast %[[TENSOR3]] : tensor<3xindex> to tensor<?xindex>
-  // CHECK: return %[[RESULT]] : tensor<?xindex>
-  %shape = shape.const_shape [1, 2, 3] : tensor<?xindex>
-  return %shape : tensor<?xindex>
+  // CHECK: %[[RESULT:.*]] = tensor.cast %[[TENSOR3]] : tensor<3xindex> to tensor<3xindex>
+  // CHECK: return %[[RESULT]] : tensor<3xindex>
+  %shape = shape.const_shape [1, 2, 3] : tensor<3xindex>
+  return %shape : tensor<3xindex>
 }
 
 // -----
 
 // Lower `const_shape` in the case of rank 0.
 // CHECK-LABEL: func @const_shape_zero_elements
-// CHECK-SAME: () -> tensor<?xindex>
-func @const_shape_zero_elements() -> tensor<?xindex> {
+// CHECK-SAME: () -> tensor<0xindex>
+func @const_shape_zero_elements() -> tensor<0xindex> {
   // CHECK: %[[TENSOR:.*]] = tensor.from_elements : tensor<0xindex>
-  // CHECK: %[[RESULT:.*]] = tensor.cast %[[TENSOR]] : tensor<0xindex> to tensor<?xindex>
-  // CHECK: return %[[RESULT]] : tensor<?xindex>
-  %shape = shape.const_shape [] : tensor<?xindex>
-  return %shape : tensor<?xindex>
+  // CHECK: %[[RESULT:.*]] = tensor.cast %[[TENSOR]] : tensor<0xindex> to tensor<0xindex>
+  // CHECK: return %[[RESULT]] : tensor<0xindex>
+  %shape = shape.const_shape [] : tensor<0xindex>
+  return %shape : tensor<0xindex>
 }
 
 // -----


        


More information about the Mlir-commits mailing list