[Mlir-commits] [mlir] d4e4d5d - [MLIR][Shape] Allow for `shape_of` to return extent tensors
Frederik Gossen
llvmlistbot@llvm.org
Fri Jul 24 01:41:08 PDT 2020
Author: Frederik Gossen
Date: 2020-07-24T08:40:40Z
New Revision: d4e4d5d78044a7e81df1343cf064dd8c9472b70c
URL: https://github.com/llvm/llvm-project/commit/d4e4d5d78044a7e81df1343cf064dd8c9472b70c
DIFF: https://github.com/llvm/llvm-project/commit/d4e4d5d78044a7e81df1343cf064dd8c9472b70c.diff
LOG: [MLIR][Shape] Allow for `shape_of` to return extent tensors
The operation `shape.shape_of` now returns an extent tensor `tensor<?xindex>` in
cases where no error is possible. All consuming operations will eventually
accept both shapes and extent tensors.
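To illustrate the new assembly format (a minimal sketch based on the tests in
this patch; `%tensor` and `%value` are hypothetical SSA values):

  // A shaped operand now yields an extent tensor.
  %0 = shape.shape_of %tensor : tensor<1x2x3xf32> -> tensor<?xindex>
  // A `!shape.value_shape` operand still yields a `!shape.shape` so that
  // potential error shapes can be propagated.
  %1 = shape.shape_of %value : !shape.value_shape -> !shape.shape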
Differential Revision: https://reviews.llvm.org/D84160
Added:
Modified:
mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td
mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
mlir/lib/Dialect/Shape/IR/Shape.cpp
mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir
mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
mlir/test/Dialect/Shape/canonicalize.mlir
mlir/test/Dialect/Shape/invalid.mlir
mlir/test/Dialect/Shape/ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td
index ef20b5a9813d..3cbebe723921 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td
@@ -104,11 +104,20 @@ def Shape_ValueShapeType : DialectType<ShapeDialect,
}];
}
+def Shape_ExtentTensorType :
+ 1DTensorOf<[Index]>,
+ BuildableType<"::mlir::RankedTensorType::get({ShapedType::kDynamicSize}, "
+ "$_builder.getType<::mlir::IndexType>())"> {
+ let typeDescription = [{
+ The extent tensor is a tensor of rank one with arbitrarily many index
+ elements. Like `!shape.shape`, it is used to represent shapes with the
+    difference that it is guaranteed to be error-free.
+ }];
+}
+
def Shape_ShapeOrSizeType : AnyTypeOf<[Shape_SizeType, Shape_ShapeType],
"shape or size">;
-def Shape_ExtentTensorType : 1DTensorOf<[Index]>;
-
def Shape_ShapeOrExtentTensorType : AnyTypeOf<[Shape_ShapeType,
Shape_ExtentTensorType],
"shape or extent tensor">;
diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index 2302c5110f65..70f8d75748f4 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -391,11 +391,17 @@ def Shape_ReduceOp : Shape_Op<"reduce",
def Shape_ShapeOfOp : Shape_Op<"shape_of", [NoSideEffect]> {
let summary = "Returns shape of a value or shaped type operand";
+ let description = [{
+ The operation takes a value or a shaped operand as an argument and it
+ returns a shape or extent tensor.
+ }];
+
let arguments = (ins AnyTypeOf<[AnyShaped, Shape_ValueShapeType]>:$arg);
- let results = (outs Shape_ShapeType:$result);
+ let results = (outs Shape_ShapeOrExtentTensorType:$result);
- let assemblyFormat = "$arg `:` type($arg) attr-dict";
+ let assemblyFormat = "$arg `:` type($arg) `->` type($result) attr-dict";
+ let verifier = [{ return ::verify(*this); }];
let hasFolder = 1;
}
diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
index f82019989e70..ae3874d0cb4d 100644
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -86,9 +86,11 @@ class ShapeOfOpConversion : public OpConversionPattern<ShapeOfOp> {
}
}
- // Materialize shape as ranked tensor.
- rewriter.replaceOpWithNewOp<TensorFromElementsOp>(op.getOperation(),
- dimValues);
+ // Materialize extent tensor.
+ Value staticExtentTensor =
+ rewriter.create<TensorFromElementsOp>(loc, dimValues);
+ rewriter.replaceOpWithNewOp<TensorCastOp>(op, staticExtentTensor,
+ op.getType());
return success();
}
};
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 42b8b34c7e09..c5f11a9a95d3 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -23,9 +23,8 @@ namespace {
#include "ShapeCanonicalization.inc"
}
-static RankedTensorType getExtentTensorType(OpBuilder &builder) {
- return RankedTensorType::get({ShapedType::kDynamicSize},
- builder.getIndexType());
+static RankedTensorType getExtentTensorType(MLIRContext *ctx) {
+ return RankedTensorType::get({ShapedType::kDynamicSize}, IndexType::get(ctx));
}
ShapeDialect::ShapeDialect(MLIRContext *context)
@@ -45,7 +44,8 @@ ShapeDialect::ShapeDialect(MLIRContext *context)
Operation *ShapeDialect::materializeConstant(OpBuilder &builder,
Attribute value, Type type,
Location loc) {
- if (type.isa<ShapeType>() || type == getExtentTensorType(builder))
+ if (type.isa<ShapeType>() ||
+ type == getExtentTensorType(builder.getContext()))
return builder.create<ConstShapeOp>(loc, type,
value.cast<DenseIntElementsAttr>());
if (type.isa<SizeType>())
@@ -641,6 +641,23 @@ OpFoldResult ShapeOfOp::fold(ArrayRef<Attribute>) {
return builder.getIndexTensorAttr(type.getShape());
}
+static LogicalResult verify(ShapeOfOp op) {
+ Type argTy = op.arg().getType();
+ Type resultTy = op.result().getType();
+ if (argTy.isa<ValueShapeType>()) {
+ if (!resultTy.isa<ShapeType>())
+ return op.emitOpError()
+ << "if operand is of type `value_shape` then the result must be "
+ "of type `shape` to propagate potential error shapes";
+ } else {
+ assert(argTy.isa<ShapedType>());
+ if (resultTy != getExtentTensorType(op.getContext()))
+ return op.emitOpError() << "if operand is a shaped type then the result "
+ "must be an extent tensor";
+ }
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// SizeToIndexOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir b/mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir
index 2e5a45c4cc11..441b2e92cc3d 100644
--- a/mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir
+++ b/mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir
@@ -39,7 +39,7 @@ func @shape_of_unranked(%arg : tensor<*xf32>) {
// CHECK: }
// CHECK-DAG: %[[SHAPE_INT:.*]] = tensor_load %[[SHAPE_MEM]] : memref<?xi64>
// CHECK-DAG: %[[SHAPE:.*]] = index_cast %[[SHAPE_INT]] : tensor<?xi64> to tensor<?xindex>
- %shape = shape.shape_of %arg : tensor<*xf32>
+ %shape = shape.shape_of %arg : tensor<*xf32> -> tensor<?xindex>
return
}
diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
index f50b6530d9d7..67fb7cdb7910 100644
--- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
@@ -95,8 +95,9 @@ func @shape_of_stat(%arg : tensor<1x2x3xf32>) {
// CHECK-DAG: %[[C1:.*]] = constant 1 : index
// CHECK-DAG: %[[C2:.*]] = constant 2 : index
// CHECK-DAG: %[[C3:.*]] = constant 3 : index
- // CHECK-DAG: %[[SHAPE:.*]] = tensor_from_elements(%[[C1]], %[[C2]], %[[C3]]) : tensor<3xindex>
- %shape = shape.shape_of %arg : tensor<1x2x3xf32>
+ // CHECK-DAG: %[[SHAPE_UNCASTED:.*]] = tensor_from_elements(%[[C1]], %[[C2]], %[[C3]]) : tensor<3xindex>
+ // CHECK-DAG: %[[SHAPE:.*]] = tensor_cast %[[SHAPE_UNCASTED]] : tensor<3xindex> to tensor<?xindex>
+ %shape = shape.shape_of %arg : tensor<1x2x3xf32> -> tensor<?xindex>
return
}
@@ -110,8 +111,9 @@ func @shape_of_dyn(%arg : tensor<1x5x?xf32>) {
// CHECK-DAG: %[[C5:.*]] = constant 5 : index
// CHECK-DAG: %[[C2:.*]] = constant 2 : index
// CHECK-DAG: %[[DYN_DIM:.*]] = dim %[[ARG]], %[[C2]] : tensor<1x5x?xf32>
- // CHECK-DAG: %[[SHAPE:.*]] = tensor_from_elements(%[[C1]], %[[C5]], %[[DYN_DIM]]) : tensor<3xindex>
- %shape = shape.shape_of %arg : tensor<1x5x?xf32>
+ // CHECK-DAG: %[[SHAPE_UNCASTED:.*]] = tensor_from_elements(%[[C1]], %[[C5]], %[[DYN_DIM]]) : tensor<3xindex>
+ // CHECK-DAG: %[[SHAPE:.*]] = tensor_cast %[[SHAPE_UNCASTED]] : tensor<3xindex> to tensor<?xindex>
+ %shape = shape.shape_of %arg : tensor<1x5x?xf32> -> tensor<?xindex>
return
}
@@ -138,8 +140,8 @@ func @get_extent_shape_of(%arg : tensor<2x3xf32>, %idx : !shape.size)
-> !shape.size {
// CHECK: %[[RESULT:.*]] = dim %[[ARG]], %[[IDX]] : tensor<2x3xf32>
// CHECK: return %[[RESULT]] : index
- %shape = shape.shape_of %arg : tensor<2x3xf32>
- %result = shape.get_extent %shape, %idx : !shape.shape
+ %shape = shape.shape_of %arg : tensor<2x3xf32> -> tensor<?xindex>
+ %result = shape.get_extent %shape, %idx : tensor<?xindex>
return %result : !shape.size
}
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index 9e691b88b016..e2874e09cc8a 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -1,10 +1,10 @@
-// RUN: mlir-opt -split-input-file -allow-unregistered-dialect -canonicalize <%s | FileCheck %s
+// RUN: mlir-opt -split-input-file -allow-unregistered-dialect -canonicalize %s | FileCheck %s
// CHECK-LABEL: func @f
-func @f(%arg0: tensor<2x3x4xf32>) -> !shape.shape {
- // CHECK: shape.const_shape [2, 3, 4] : !shape.shape
- %0 = "shape.shape_of"(%arg0) : (tensor<2x3x4xf32>) -> !shape.shape
- return %0 : !shape.shape
+func @f(%arg0: tensor<2x3x4xf32>) -> tensor<?xindex> {
+ // CHECK: shape.const_shape [2, 3, 4] : tensor<?xindex>
+ %0 = shape.shape_of %arg0 : tensor<2x3x4xf32> -> tensor<?xindex>
+ return %0 : tensor<?xindex>
}
// -----
@@ -522,8 +522,8 @@ func @dont_fold_rank(%shape : !shape.shape) -> !shape.size {
func @canonicalize_rank(%arg : tensor<1x2x?xf32>) -> !shape.size {
// CHECK-DAG: %[[RESULT:.*]] = shape.const_size 3
// CHECK-DAG: return %[[RESULT]] : !shape.size
- %shape = shape.shape_of %arg : tensor<1x2x?xf32>
- %rank = shape.rank %shape : !shape.shape
+ %shape = shape.shape_of %arg : tensor<1x2x?xf32> -> tensor<?xindex>
+ %rank = shape.rank %shape : tensor<?xindex>
return %rank : !shape.size
}
@@ -533,11 +533,11 @@ func @canonicalize_rank(%arg : tensor<1x2x?xf32>) -> !shape.size {
// CHECK-LABEL: @dont_canonicalize_rank
// CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>) -> !shape.size
func @dont_canonicalize_rank(%arg : tensor<*xf32>) -> !shape.size {
- // CHECK-DAG: %[[SHAPE:.*]] = shape.shape_of %[[ARG]] : tensor<*xf32>
+ // CHECK-DAG: %[[SHAPE:.*]] = shape.shape_of %[[ARG]] : tensor<*xf32> -> tensor<?xindex>
// CHECK-DAG: %[[SIZE:.*]] = shape.rank %[[SHAPE]]
// CHECK-DAG: return %[[SIZE]] : !shape.size
- %shape = shape.shape_of %arg : tensor<*xf32>
- %rank = shape.rank %shape : !shape.shape
+ %shape = shape.shape_of %arg : tensor<*xf32> -> tensor<?xindex>
+ %rank = shape.rank %shape : tensor<?xindex>
return %rank : !shape.size
}
@@ -572,8 +572,8 @@ func @cstr_broadcastable_scalar(%arg0 : tensor<?xf32>) {
// CHECK-NEXT: consume.witness
// CHECK-NEXT: return
%0 = shape.const_shape [] : !shape.shape
- %1 = shape.shape_of %arg0 : tensor<?xf32>
- %2 = shape.cstr_broadcastable %0, %1 : !shape.shape, !shape.shape
+ %1 = shape.shape_of %arg0 : tensor<?xf32> -> tensor<?xindex>
+ %2 = shape.cstr_broadcastable %0, %1 : !shape.shape, tensor<?xindex>
"consume.witness"(%2) : (!shape.witness) -> ()
return
}
@@ -588,9 +588,9 @@ func @cstr_broadcastable_unknown(%arg0 : tensor<?xf32>, %arg1 : tensor<?xf32>) {
// CHECK-NEXT: shape.cstr_broadcastable
// CHECK-NEXT: consume.witness
// CHECK-NEXT: return
- %0 = shape.shape_of %arg0 : tensor<?xf32>
- %1 = shape.shape_of %arg1 : tensor<?xf32>
- %2 = shape.cstr_broadcastable %0, %1 : !shape.shape, !shape.shape
+ %0 = shape.shape_of %arg0 : tensor<?xf32> -> tensor<?xindex>
+ %1 = shape.shape_of %arg1 : tensor<?xf32> -> tensor<?xindex>
+ %2 = shape.cstr_broadcastable %0, %1 : tensor<?xindex>, tensor<?xindex>
"consume.witness"(%2) : (!shape.witness) -> ()
return
}
@@ -603,9 +603,9 @@ func @cstr_broadcastable_scalar_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<i
// CHECK-NEXT: shape.const_witness true
// CHECK-NEXT: consume.witness
// CHECK-NEXT: return
- %0 = shape.shape_of %arg1 : tensor<index>
- %1 = shape.shape_of %arg0 : tensor<*xf32>
- %2 = shape.cstr_broadcastable %0, %1 : !shape.shape, !shape.shape
+ %0 = shape.shape_of %arg1 : tensor<index> -> tensor<?xindex>
+ %1 = shape.shape_of %arg0 : tensor<*xf32> -> tensor<?xindex>
+ %2 = shape.cstr_broadcastable %0, %1 : tensor<?xindex>, tensor<?xindex>
"consume.witness"(%2) : (!shape.witness) -> ()
return
}
diff --git a/mlir/test/Dialect/Shape/invalid.mlir b/mlir/test/Dialect/Shape/invalid.mlir
index 3aca3677c143..23d0daf79378 100644
--- a/mlir/test/Dialect/Shape/invalid.mlir
+++ b/mlir/test/Dialect/Shape/invalid.mlir
@@ -78,3 +78,20 @@ func @assuming_all_op_too_few_operands() {
%w0 = shape.assuming_all
return
}
+
+// -----
+
+func @shape_of(%value_arg : !shape.value_shape,
+ %shaped_arg : tensor<?x3x4xf32>) {
+  // expected-error @+1 {{if operand is of type `value_shape` then the result must be of type `shape` to propagate potential error shapes}}
+ %0 = shape.shape_of %value_arg : !shape.value_shape -> tensor<?xindex>
+}
+
+// -----
+
+func @shape_of(%value_arg : !shape.value_shape,
+ %shaped_arg : tensor<?x3x4xf32>) {
+  // expected-error @+1 {{if operand is a shaped type then the result must be an extent tensor}}
+ %1 = shape.shape_of %shaped_arg : tensor<?x3x4xf32> -> !shape.shape
+}
+
diff --git a/mlir/test/Dialect/Shape/ops.mlir b/mlir/test/Dialect/Shape/ops.mlir
index 66b5834ff653..d0275aaf692e 100644
--- a/mlir/test/Dialect/Shape/ops.mlir
+++ b/mlir/test/Dialect/Shape/ops.mlir
@@ -85,9 +85,9 @@ func @test_parse_const_shape() {
return
}
-func @test_shape_of(%arg0: tensor<?xf32>) -> !shape.shape {
- %0 = shape.shape_of %arg0 : tensor<?xf32>
- return %0 : !shape.shape
+func @test_shape_of(%arg0: tensor<?xf32>) -> tensor<?xindex> {
+ %0 = shape.shape_of %arg0 : tensor<?xf32> -> tensor<?xindex>
+ return %0 : tensor<?xindex>
}
func @test_constraints() {