[Mlir-commits] [mlir] d4e4d5d - [MLIR][Shape] Allow for `shape_of` to return extent tensors

Frederik Gossen llvmlistbot at llvm.org
Fri Jul 24 01:41:08 PDT 2020


Author: Frederik Gossen
Date: 2020-07-24T08:40:40Z
New Revision: d4e4d5d78044a7e81df1343cf064dd8c9472b70c

URL: https://github.com/llvm/llvm-project/commit/d4e4d5d78044a7e81df1343cf064dd8c9472b70c
DIFF: https://github.com/llvm/llvm-project/commit/d4e4d5d78044a7e81df1343cf064dd8c9472b70c.diff

LOG: [MLIR][Shape] Allow for `shape_of` to return extent tensors

The operation `shape.shape_of` now returns an extent tensor `tensor<?xindex>` in
cases where no errors are possible. All consuming operations will eventually
accept both shapes and extent tensors.

Differential Revision: https://reviews.llvm.org/D84160

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td
    mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
    mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
    mlir/lib/Dialect/Shape/IR/Shape.cpp
    mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir
    mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
    mlir/test/Dialect/Shape/canonicalize.mlir
    mlir/test/Dialect/Shape/invalid.mlir
    mlir/test/Dialect/Shape/ops.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td
index ef20b5a9813d..3cbebe723921 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td
@@ -104,11 +104,20 @@ def Shape_ValueShapeType : DialectType<ShapeDialect,
   }];
 }
 
+def Shape_ExtentTensorType :
+    1DTensorOf<[Index]>,
+    BuildableType<"::mlir::RankedTensorType::get({ShapedType::kDynamicSize}, "
+                  "$_builder.getType<::mlir::IndexType>())"> {
+  let typeDescription = [{
+    The extent tensor is a tensor of rank one with arbitrarily many index
+    elements. Like `!shape.shape`, it is used to represent shapes with the
+    difference that it is guaranteed to be error-free.
+  }];
+}
+
 def Shape_ShapeOrSizeType : AnyTypeOf<[Shape_SizeType, Shape_ShapeType],
   "shape or size">;
 
-def Shape_ExtentTensorType : 1DTensorOf<[Index]>;
-
 def Shape_ShapeOrExtentTensorType : AnyTypeOf<[Shape_ShapeType,
                                                Shape_ExtentTensorType],
                                               "shape or extent tensor">;

diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index 2302c5110f65..70f8d75748f4 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -391,11 +391,17 @@ def Shape_ReduceOp : Shape_Op<"reduce",
 def Shape_ShapeOfOp : Shape_Op<"shape_of", [NoSideEffect]> {
   let summary = "Returns shape of a value or shaped type operand";
 
+  let description = [{
+    The operation takes a value or a shaped operand as an argument and it
+    returns a shape or extent tensor.
+  }];
+
   let arguments = (ins AnyTypeOf<[AnyShaped, Shape_ValueShapeType]>:$arg);
-  let results = (outs Shape_ShapeType:$result);
+  let results = (outs Shape_ShapeOrExtentTensorType:$result);
 
-  let assemblyFormat = "$arg `:` type($arg) attr-dict";
+  let assemblyFormat = "$arg `:` type($arg) `->` type($result) attr-dict";
 
+  let verifier = [{ return ::verify(*this); }];
   let hasFolder = 1;
 }
 

diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
index f82019989e70..ae3874d0cb4d 100644
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -86,9 +86,11 @@ class ShapeOfOpConversion : public OpConversionPattern<ShapeOfOp> {
       }
     }
 
-    // Materialize shape as ranked tensor.
-    rewriter.replaceOpWithNewOp<TensorFromElementsOp>(op.getOperation(),
-                                                      dimValues);
+    // Materialize extent tensor.
+    Value staticExtentTensor =
+        rewriter.create<TensorFromElementsOp>(loc, dimValues);
+    rewriter.replaceOpWithNewOp<TensorCastOp>(op, staticExtentTensor,
+                                              op.getType());
     return success();
   }
 };

diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 42b8b34c7e09..c5f11a9a95d3 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -23,9 +23,8 @@ namespace {
 #include "ShapeCanonicalization.inc"
 }
 
-static RankedTensorType getExtentTensorType(OpBuilder &builder) {
-  return RankedTensorType::get({ShapedType::kDynamicSize},
-                               builder.getIndexType());
+static RankedTensorType getExtentTensorType(MLIRContext *ctx) {
+  return RankedTensorType::get({ShapedType::kDynamicSize}, IndexType::get(ctx));
 }
 
 ShapeDialect::ShapeDialect(MLIRContext *context)
@@ -45,7 +44,8 @@ ShapeDialect::ShapeDialect(MLIRContext *context)
 Operation *ShapeDialect::materializeConstant(OpBuilder &builder,
                                              Attribute value, Type type,
                                              Location loc) {
-  if (type.isa<ShapeType>() || type == getExtentTensorType(builder))
+  if (type.isa<ShapeType>() ||
+      type == getExtentTensorType(builder.getContext()))
     return builder.create<ConstShapeOp>(loc, type,
                                         value.cast<DenseIntElementsAttr>());
   if (type.isa<SizeType>())
@@ -641,6 +641,23 @@ OpFoldResult ShapeOfOp::fold(ArrayRef<Attribute>) {
   return builder.getIndexTensorAttr(type.getShape());
 }
 
+static LogicalResult verify(ShapeOfOp op) {
+  Type argTy = op.arg().getType();
+  Type resultTy = op.result().getType();
+  if (argTy.isa<ValueShapeType>()) {
+    if (!resultTy.isa<ShapeType>())
+      return op.emitOpError()
+             << "if operand is of type `value_shape` then the result must be "
+                "of type `shape` to propagate potential error shapes";
+  } else {
+    assert(argTy.isa<ShapedType>());
+    if (resultTy != getExtentTensorType(op.getContext()))
+      return op.emitOpError() << "if operand is a shaped type then the result "
+                                 "must be an extent tensor";
+  }
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // SizeToIndexOp
 //===----------------------------------------------------------------------===//

diff --git a/mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir b/mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir
index 2e5a45c4cc11..441b2e92cc3d 100644
--- a/mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir
+++ b/mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir
@@ -39,7 +39,7 @@ func @shape_of_unranked(%arg : tensor<*xf32>) {
   // CHECK:     }
   // CHECK-DAG: %[[SHAPE_INT:.*]] = tensor_load %[[SHAPE_MEM]] : memref<?xi64>
   // CHECK-DAG: %[[SHAPE:.*]] = index_cast %[[SHAPE_INT]] : tensor<?xi64> to tensor<?xindex>
-  %shape = shape.shape_of %arg : tensor<*xf32>
+  %shape = shape.shape_of %arg : tensor<*xf32> -> tensor<?xindex>
   return
 }
 

diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
index f50b6530d9d7..67fb7cdb7910 100644
--- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
@@ -95,8 +95,9 @@ func @shape_of_stat(%arg : tensor<1x2x3xf32>) {
   // CHECK-DAG: %[[C1:.*]] = constant 1 : index
   // CHECK-DAG: %[[C2:.*]] = constant 2 : index
   // CHECK-DAG: %[[C3:.*]] = constant 3 : index
-  // CHECK-DAG: %[[SHAPE:.*]] = tensor_from_elements(%[[C1]], %[[C2]], %[[C3]]) : tensor<3xindex>
-  %shape = shape.shape_of %arg : tensor<1x2x3xf32>
+  // CHECK-DAG: %[[SHAPE_UNCASTED:.*]] = tensor_from_elements(%[[C1]], %[[C2]], %[[C3]]) : tensor<3xindex>
+  // CHECK-DAG: %[[SHAPE:.*]] = tensor_cast %[[SHAPE_UNCASTED]] : tensor<3xindex> to tensor<?xindex>
+  %shape = shape.shape_of %arg : tensor<1x2x3xf32> -> tensor<?xindex>
   return
 }
 
@@ -110,8 +111,9 @@ func @shape_of_dyn(%arg : tensor<1x5x?xf32>) {
   // CHECK-DAG: %[[C5:.*]] = constant 5 : index
   // CHECK-DAG: %[[C2:.*]] = constant 2 : index
   // CHECK-DAG: %[[DYN_DIM:.*]] = dim %[[ARG]], %[[C2]] : tensor<1x5x?xf32>
-  // CHECK-DAG: %[[SHAPE:.*]] = tensor_from_elements(%[[C1]], %[[C5]], %[[DYN_DIM]]) : tensor<3xindex>
-  %shape = shape.shape_of %arg : tensor<1x5x?xf32>
+  // CHECK-DAG: %[[SHAPE_UNCASTED:.*]] = tensor_from_elements(%[[C1]], %[[C5]], %[[DYN_DIM]]) : tensor<3xindex>
+  // CHECK-DAG: %[[SHAPE:.*]] = tensor_cast %[[SHAPE_UNCASTED]] : tensor<3xindex> to tensor<?xindex>
+  %shape = shape.shape_of %arg : tensor<1x5x?xf32> -> tensor<?xindex>
   return
 }
 
@@ -138,8 +140,8 @@ func @get_extent_shape_of(%arg : tensor<2x3xf32>, %idx : !shape.size)
     -> !shape.size {
   // CHECK: %[[RESULT:.*]] = dim %[[ARG]], %[[IDX]] : tensor<2x3xf32>
   // CHECK: return %[[RESULT]] : index
-  %shape = shape.shape_of %arg : tensor<2x3xf32>
-  %result = shape.get_extent %shape, %idx : !shape.shape
+  %shape = shape.shape_of %arg : tensor<2x3xf32> -> tensor<?xindex>
+  %result = shape.get_extent %shape, %idx : tensor<?xindex>
   return %result : !shape.size
 }
 

diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index 9e691b88b016..e2874e09cc8a 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -1,10 +1,10 @@
-// RUN: mlir-opt -split-input-file -allow-unregistered-dialect -canonicalize <%s | FileCheck %s
+// RUN: mlir-opt -split-input-file -allow-unregistered-dialect -canonicalize %s | FileCheck %s
 
 // CHECK-LABEL: func @f
-func @f(%arg0: tensor<2x3x4xf32>) -> !shape.shape {
-  // CHECK: shape.const_shape [2, 3, 4] : !shape.shape
-  %0 = "shape.shape_of"(%arg0) : (tensor<2x3x4xf32>) -> !shape.shape
-  return %0 : !shape.shape
+func @f(%arg0: tensor<2x3x4xf32>) -> tensor<?xindex> {
+  // CHECK: shape.const_shape [2, 3, 4] : tensor<?xindex>
+  %0 = shape.shape_of %arg0 : tensor<2x3x4xf32> -> tensor<?xindex>
+  return %0 : tensor<?xindex>
 }
 
 // -----
@@ -522,8 +522,8 @@ func @dont_fold_rank(%shape : !shape.shape) -> !shape.size {
 func @canonicalize_rank(%arg : tensor<1x2x?xf32>) -> !shape.size {
   // CHECK-DAG: %[[RESULT:.*]] = shape.const_size 3
   // CHECK-DAG: return %[[RESULT]] : !shape.size
-  %shape = shape.shape_of %arg : tensor<1x2x?xf32>
-  %rank = shape.rank %shape : !shape.shape
+  %shape = shape.shape_of %arg : tensor<1x2x?xf32> -> tensor<?xindex>
+  %rank = shape.rank %shape : tensor<?xindex>
   return %rank : !shape.size
 }
 
@@ -533,11 +533,11 @@ func @canonicalize_rank(%arg : tensor<1x2x?xf32>) -> !shape.size {
 // CHECK-LABEL: @dont_canonicalize_rank
 // CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>) -> !shape.size
 func @dont_canonicalize_rank(%arg : tensor<*xf32>) -> !shape.size {
-  // CHECK-DAG: %[[SHAPE:.*]] = shape.shape_of %[[ARG]] : tensor<*xf32>
+  // CHECK-DAG: %[[SHAPE:.*]] = shape.shape_of %[[ARG]] : tensor<*xf32> -> tensor<?xindex>
   // CHECK-DAG: %[[SIZE:.*]] = shape.rank %[[SHAPE]]
   // CHECK-DAG: return %[[SIZE]] : !shape.size
-  %shape = shape.shape_of %arg : tensor<*xf32>
-  %rank = shape.rank %shape : !shape.shape
+  %shape = shape.shape_of %arg : tensor<*xf32> -> tensor<?xindex>
+  %rank = shape.rank %shape : tensor<?xindex>
   return %rank : !shape.size
 }
 
@@ -572,8 +572,8 @@ func @cstr_broadcastable_scalar(%arg0 : tensor<?xf32>) {
   // CHECK-NEXT: consume.witness
   // CHECK-NEXT: return
   %0 = shape.const_shape [] : !shape.shape
-  %1 = shape.shape_of %arg0 : tensor<?xf32>
-  %2 = shape.cstr_broadcastable %0, %1 : !shape.shape, !shape.shape
+  %1 = shape.shape_of %arg0 : tensor<?xf32> -> tensor<?xindex>
+  %2 = shape.cstr_broadcastable %0, %1 : !shape.shape, tensor<?xindex>
   "consume.witness"(%2) : (!shape.witness) -> ()
   return
 }
@@ -588,9 +588,9 @@ func @cstr_broadcastable_unknown(%arg0 : tensor<?xf32>, %arg1 : tensor<?xf32>) {
   // CHECK-NEXT: shape.cstr_broadcastable
   // CHECK-NEXT: consume.witness
   // CHECK-NEXT: return
-  %0 = shape.shape_of %arg0 : tensor<?xf32>
-  %1 = shape.shape_of %arg1 : tensor<?xf32>
-  %2 = shape.cstr_broadcastable %0, %1 : !shape.shape, !shape.shape
+  %0 = shape.shape_of %arg0 : tensor<?xf32> -> tensor<?xindex>
+  %1 = shape.shape_of %arg1 : tensor<?xf32> -> tensor<?xindex>
+  %2 = shape.cstr_broadcastable %0, %1 : tensor<?xindex>, tensor<?xindex>
   "consume.witness"(%2) : (!shape.witness) -> ()
   return
 }
@@ -603,9 +603,9 @@ func @cstr_broadcastable_scalar_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<i
   // CHECK-NEXT: shape.const_witness true
   // CHECK-NEXT: consume.witness
   // CHECK-NEXT: return
-  %0 = shape.shape_of %arg1 : tensor<index>
-  %1 = shape.shape_of %arg0 : tensor<*xf32>
-  %2 = shape.cstr_broadcastable %0, %1 : !shape.shape, !shape.shape
+  %0 = shape.shape_of %arg1 : tensor<index> -> tensor<?xindex>
+  %1 = shape.shape_of %arg0 : tensor<*xf32> -> tensor<?xindex>
+  %2 = shape.cstr_broadcastable %0, %1 : tensor<?xindex>, tensor<?xindex>
   "consume.witness"(%2) : (!shape.witness) -> ()
   return
 }

diff --git a/mlir/test/Dialect/Shape/invalid.mlir b/mlir/test/Dialect/Shape/invalid.mlir
index 3aca3677c143..23d0daf79378 100644
--- a/mlir/test/Dialect/Shape/invalid.mlir
+++ b/mlir/test/Dialect/Shape/invalid.mlir
@@ -78,3 +78,20 @@ func @assuming_all_op_too_few_operands() {
   %w0 = shape.assuming_all
   return
 }
+
+// -----
+
+func @shape_of(%value_arg : !shape.value_shape,
+               %shaped_arg : tensor<?x3x4xf32>) {
+  // expected-error @+1 {{if operand is of type `value_shape` then the result must be of type `shape` to propagate potential error shapes}}
+  %0 = shape.shape_of %value_arg : !shape.value_shape -> tensor<?xindex>
+}
+
+// -----
+
+func @shape_of(%value_arg : !shape.value_shape,
+               %shaped_arg : tensor<?x3x4xf32>) {
+  // expected-error @+1 {{if operand is a shaped type then the result must be an extent tensor}}
+  %1 = shape.shape_of %shaped_arg : tensor<?x3x4xf32> -> !shape.shape
+}
+

diff --git a/mlir/test/Dialect/Shape/ops.mlir b/mlir/test/Dialect/Shape/ops.mlir
index 66b5834ff653..d0275aaf692e 100644
--- a/mlir/test/Dialect/Shape/ops.mlir
+++ b/mlir/test/Dialect/Shape/ops.mlir
@@ -85,9 +85,9 @@ func @test_parse_const_shape() {
   return
 }
 
-func @test_shape_of(%arg0: tensor<?xf32>) -> !shape.shape {
-  %0 = shape.shape_of %arg0 : tensor<?xf32>
-  return %0 : !shape.shape
+func @test_shape_of(%arg0: tensor<?xf32>) -> tensor<?xindex> {
+  %0 = shape.shape_of %arg0 : tensor<?xf32> -> tensor<?xindex>
+  return %0 : tensor<?xindex>
 }
 
 func @test_constraints() {


        


More information about the Mlir-commits mailing list