[Mlir-commits] [mlir] 3e037f8 - [MLIR][Shape] Derive more concrete type for `shape.shape_of`

Frederik Gossen <llvmlistbot at llvm.org>
Wed Apr 28 01:51:11 PDT 2021


Author: Frederik Gossen
Date: 2021-04-28T10:50:53+02:00
New Revision: 3e037f8f0e26acab8cc784ea4c7d05da79f7c22e

URL: https://github.com/llvm/llvm-project/commit/3e037f8f0e26acab8cc784ea4c7d05da79f7c22e
DIFF: https://github.com/llvm/llvm-project/commit/3e037f8f0e26acab8cc784ea4c7d05da79f7c22e.diff

LOG: [MLIR][Shape] Derive more concrete type for `shape.shape_of`

Also create all extent tensor constants with the `const_shape` op.

Differential Revision: https://reviews.llvm.org/D99197
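
For illustration, a minimal sketch of the builder's new behavior, assuming a
hypothetical operand `%arg` of type `tensor<4x?xf32>`:

    // Before: the inferred result type was always tensor<?xindex>.
    %0 = shape.shape_of %arg : tensor<4x?xf32> -> tensor<?xindex>
    // After: a ranked operand yields an extent tensor of matching rank.
    %0 = shape.shape_of %arg : tensor<4x?xf32> -> tensor<2xindex>
    // Unranked shaped operands still yield tensor<?xindex>; non-shaped
    // operands (e.g. !shape.value_shape) still yield !shape.shape.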

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Shape/IR/Shape.h
    mlir/lib/Dialect/Shape/IR/Shape.cpp
    mlir/test/Dialect/Shape/canonicalize.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Shape/IR/Shape.h b/mlir/include/mlir/Dialect/Shape/IR/Shape.h
index 47e61a8c47689..570719eff64d5 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/Shape.h
+++ b/mlir/include/mlir/Dialect/Shape/IR/Shape.h
@@ -31,6 +31,9 @@ namespace shape {
 /// Alias type for extent tensors.
 RankedTensorType getExtentTensorType(MLIRContext *ctx);
 
+// Check if a type is an extent tensor, e.g., tensor<?xindex>.
+bool isExtentTensorType(Type);
+
 // Given an input shape Value, try to obtain the shape's values.
 LogicalResult getShapeVec(Value input, SmallVectorImpl<int64_t> &shapeValues);
 

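The new predicate generalizes the previous exact comparison against
`tensor<?xindex>`: it accepts any 1-D ranked tensor of `index` type,
including statically shaped ones. As a sketch, some representative cases
(derived from the definition added in Shape.cpp below):

    tensor<?xindex>   // extent tensor: 1-D index tensor, dynamic extent
    tensor<3xindex>   // extent tensor: 1-D index tensor, static extent
    tensor<?x?xindex> // not an extent tensor: rank is 2, not 1
    tensor<3xi64>     // not an extent tensor: element type is not index
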
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 96618d24747fd..47fd322ba47ce 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -31,6 +31,11 @@ RankedTensorType shape::getExtentTensorType(MLIRContext *ctx) {
   return RankedTensorType::get({ShapedType::kDynamicSize}, IndexType::get(ctx));
 }
 
+bool shape::isExtentTensorType(Type type) {
+  auto ranked = type.dyn_cast<RankedTensorType>();
+  return ranked && ranked.getRank() == 1 && ranked.getElementType().isIndex();
+}
+
 LogicalResult shape::getShapeVec(Value input,
                                  SmallVectorImpl<int64_t> &shapeValues) {
   if (auto inputOp = input.getDefiningOp<ShapeOfOp>()) {
@@ -123,8 +128,7 @@ void ShapeDialect::initialize() {
 Operation *ShapeDialect::materializeConstant(OpBuilder &builder,
                                              Attribute value, Type type,
                                              Location loc) {
-  if (type.isa<ShapeType>() ||
-      type == getExtentTensorType(builder.getContext()))
+  if (type.isa<ShapeType>() || isExtentTensorType(type))
     return builder.create<ConstShapeOp>(loc, type,
                                         value.cast<DenseIntElementsAttr>());
   if (type.isa<SizeType>())
@@ -1148,10 +1152,15 @@ OpFoldResult ShapeOfOp::fold(ArrayRef<Attribute>) {
 }
 
 void ShapeOfOp::build(OpBuilder &builder, OperationState &result, Value arg) {
-  Type type = arg.getType().isa<ShapedType>()
-                  ? (Type)getExtentTensorType(builder.getContext())
-                  : (Type)builder.getType<ShapeType>();
-  return ShapeOfOp::build(builder, result, type, arg);
+  if (auto shapedTy = arg.getType().dyn_cast<ShapedType>()) {
+    int64_t rank =
+        shapedTy.hasRank() ? shapedTy.getRank() : ShapedType::kDynamicSize;
+    Type indexTy = builder.getIndexType();
+    Type extentTensorTy = RankedTensorType::get({rank}, indexTy);
+    return ShapeOfOp::build(builder, result, extentTensorTy, arg);
+  }
+  Type shapeTy = builder.getType<ShapeType>();
+  return ShapeOfOp::build(builder, result, shapeTy, arg);
 }
 
 namespace {

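Together with the materializeConstant change above, constant folding can now
produce a `shape.const_shape` for any extent tensor type, not only
`tensor<?xindex>`. A sketch of the resulting IR (values chosen for
illustration); previously such constants surfaced as standard `constant` ops,
as the updated CHECK lines in the tests below show:

    %cst = shape.const_shape [1, 2, 3] : tensor<3xindex>
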
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index 6dc48b1e732aa..3876e9ca34fc4 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -138,7 +138,7 @@ func @all_but_one_empty(%arg0 : !shape.shape) -> !shape.shape {
 // CHECK-LABEL: @partial_folding
 // CHECK-SAME:  (%[[ARG:.*]]: !shape.shape)
 func @partial_folding(%arg0 : !shape.shape) -> !shape.shape {
-  // CHECK: %[[CST_SHAPE:.*]] = constant dense<[1, 2, 3]> : tensor<3xindex>
+  // CHECK: %[[CST_SHAPE:.*]] = shape.const_shape [1, 2, 3] : tensor<3xindex>
   // CHECK: %[[RESULT:.*]] = shape.broadcast %[[ARG]], %[[CST_SHAPE]] : !shape.shape, tensor<3xindex> -> !shape.shape
   // CHECK: return %[[RESULT]]
   %0 = shape.const_shape [2, 1] : !shape.shape
@@ -188,7 +188,7 @@ func @f() -> !shape.shape {
 // Basic case.
 // CHECK-LABEL: func @f
 func @f() -> tensor<2xindex> {
-  // CHECK: constant dense<[0, 1]> : tensor<2xindex>
+  // CHECK: shape.const_shape [0, 1] : tensor<2xindex>
   %cs = shape.const_shape [0, 1] : !shape.shape
   %0 = shape.to_extent_tensor %cs : !shape.shape -> tensor<2xindex>
   return %0 : tensor<2xindex>
@@ -1146,7 +1146,7 @@ func @fold_assuming_all_single_element(%arg: tensor<?xindex>) {
 // Verify that tensor.cast folding uses the correct type
 // CHECK-LABEL: @fold_tensor.cast_of_const_shape_returned
 func @fold_tensor.cast_of_const_shape_returned(%arg: i1) -> tensor<1xindex> {
-  // CHECK: constant dense<2> : tensor<1xindex>
+  // CHECK: shape.const_shape [2] : tensor<1xindex>
   // CHECK-NOT: tensor.cast
   %0 = shape.const_shape [2] : tensor<?xindex>
   %1 = tensor.cast %0 : tensor<?xindex> to tensor<1xindex>
@@ -1325,14 +1325,13 @@ func @min_same_arg(%a: !shape.shape) -> !shape.shape {
   // CHECK: return %[[SHAPE]]
   return %1 : !shape.shape
 }
-
 // -----
 
 // CHECK-LABEL: @cstr_broadcastable_folding
 func @cstr_broadcastable_folding(%arg : tensor<?x4xf32>) {
   // CHECK: const_witness true
   %0 = shape.shape_of %arg : tensor<?x4xf32> -> tensor<2xindex>
-  %1 = constant dense<[4]> : tensor<1xindex>
+  %1 = shape.const_shape [4] : tensor<1xindex>
   %2 = shape.cstr_broadcastable %0, %1: tensor<2xindex>, tensor<1xindex>
   "use"(%2) : (!shape.witness) -> ()
 }
