[Mlir-commits] [mlir] 3e037f8 - [MLIR][Shape] Derive more concrete type for `shape.shape_of`
Frederik Gossen
llvmlistbot at llvm.org
Wed Apr 28 01:51:11 PDT 2021
Author: Frederik Gossen
Date: 2021-04-28T10:50:53+02:00
New Revision: 3e037f8f0e26acab8cc784ea4c7d05da79f7c22e
URL: https://github.com/llvm/llvm-project/commit/3e037f8f0e26acab8cc784ea4c7d05da79f7c22e
DIFF: https://github.com/llvm/llvm-project/commit/3e037f8f0e26acab8cc784ea4c7d05da79f7c22e.diff
LOG: [MLIR][Shape] Derive more concrete type for `shape.shape_of`
Also create all extent tensor constants with const_shape op.
Differential Revision: https://reviews.llvm.org/D99197
Added:
Modified:
mlir/include/mlir/Dialect/Shape/IR/Shape.h
mlir/lib/Dialect/Shape/IR/Shape.cpp
mlir/test/Dialect/Shape/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Shape/IR/Shape.h b/mlir/include/mlir/Dialect/Shape/IR/Shape.h
index 47e61a8c47689..570719eff64d5 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/Shape.h
+++ b/mlir/include/mlir/Dialect/Shape/IR/Shape.h
@@ -31,6 +31,9 @@ namespace shape {
/// Alias type for extent tensors.
RankedTensorType getExtentTensorType(MLIRContext *ctx);
+// Check if a type is an extent tensor, e.g., tensor<?xindex>.
+bool isExtentTensorType(Type);
+
// Given an input shape Value, try to obtain the shape's values.
LogicalResult getShapeVec(Value input, SmallVectorImpl<int64_t> &shapeValues);
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 96618d24747fd..47fd322ba47ce 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -31,6 +31,11 @@ RankedTensorType shape::getExtentTensorType(MLIRContext *ctx) {
return RankedTensorType::get({ShapedType::kDynamicSize}, IndexType::get(ctx));
}
+bool shape::isExtentTensorType(Type type) {
+ auto ranked = type.dyn_cast<RankedTensorType>();
+ return ranked && ranked.getRank() == 1 && ranked.getElementType().isIndex();
+}
+
LogicalResult shape::getShapeVec(Value input,
SmallVectorImpl<int64_t> &shapeValues) {
if (auto inputOp = input.getDefiningOp<ShapeOfOp>()) {
@@ -123,8 +128,7 @@ void ShapeDialect::initialize() {
Operation *ShapeDialect::materializeConstant(OpBuilder &builder,
Attribute value, Type type,
Location loc) {
- if (type.isa<ShapeType>() ||
- type == getExtentTensorType(builder.getContext()))
+ if (type.isa<ShapeType>() || isExtentTensorType(type))
return builder.create<ConstShapeOp>(loc, type,
value.cast<DenseIntElementsAttr>());
if (type.isa<SizeType>())
@@ -1148,10 +1152,15 @@ OpFoldResult ShapeOfOp::fold(ArrayRef<Attribute>) {
}
void ShapeOfOp::build(OpBuilder &builder, OperationState &result, Value arg) {
- Type type = arg.getType().isa<ShapedType>()
- ? (Type)getExtentTensorType(builder.getContext())
- : (Type)builder.getType<ShapeType>();
- return ShapeOfOp::build(builder, result, type, arg);
+ if (auto shapedTy = arg.getType().dyn_cast<ShapedType>()) {
+ int64_t rank =
+ shapedTy.hasRank() ? shapedTy.getRank() : ShapedType::kDynamicSize;
+ Type indexTy = builder.getIndexType();
+ Type extentTensorTy = RankedTensorType::get({rank}, indexTy);
+ return ShapeOfOp::build(builder, result, extentTensorTy, arg);
+ }
+ Type shapeTy = builder.getType<ShapeType>();
+ return ShapeOfOp::build(builder, result, shapeTy, arg);
}
namespace {
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index 6dc48b1e732aa..3876e9ca34fc4 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -138,7 +138,7 @@ func @all_but_one_empty(%arg0 : !shape.shape) -> !shape.shape {
// CHECK-LABEL: @partial_folding
// CHECK-SAME: (%[[ARG:.*]]: !shape.shape)
func @partial_folding(%arg0 : !shape.shape) -> !shape.shape {
- // CHECK: %[[CST_SHAPE:.*]] = constant dense<[1, 2, 3]> : tensor<3xindex>
+ // CHECK: %[[CST_SHAPE:.*]] = shape.const_shape [1, 2, 3] : tensor<3xindex>
// CHECK: %[[RESULT:.*]] = shape.broadcast %[[ARG]], %[[CST_SHAPE]] : !shape.shape, tensor<3xindex> -> !shape.shape
// CHECK: return %[[RESULT]]
%0 = shape.const_shape [2, 1] : !shape.shape
@@ -188,7 +188,7 @@ func @f() -> !shape.shape {
// Basic case.
// CHECK-LABEL: func @f
func @f() -> tensor<2xindex> {
- // CHECK: constant dense<[0, 1]> : tensor<2xindex>
+ // CHECK: shape.const_shape [0, 1] : tensor<2xindex>
%cs = shape.const_shape [0, 1] : !shape.shape
%0 = shape.to_extent_tensor %cs : !shape.shape -> tensor<2xindex>
return %0 : tensor<2xindex>
@@ -1146,7 +1146,7 @@ func @fold_assuming_all_single_element(%arg: tensor<?xindex>) {
// Verify that tensor.cast folding uses the correct type
// CHECK-LABEL: @fold_tensor.cast_of_const_shape_returned
func @fold_tensor.cast_of_const_shape_returned(%arg: i1) -> tensor<1xindex> {
- // CHECK: constant dense<2> : tensor<1xindex>
+ // CHECK: shape.const_shape [2] : tensor<1xindex>
// CHECK-NOT: tensor.cast
%0 = shape.const_shape [2] : tensor<?xindex>
%1 = tensor.cast %0 : tensor<?xindex> to tensor<1xindex>
@@ -1325,14 +1325,13 @@ func @min_same_arg(%a: !shape.shape) -> !shape.shape {
// CHECK: return %[[SHAPE]]
return %1 : !shape.shape
}
-
// ----
// CHECK-LABEL: @cstr_broadcastable_folding
func @cstr_broadcastable_folding(%arg : tensor<?x4xf32>) {
// CHECK: const_witness true
%0 = shape.shape_of %arg : tensor<?x4xf32> -> tensor<2xindex>
- %1 = constant dense<[4]> : tensor<1xindex>
+ %1 = shape.const_shape [4] : tensor<1xindex>
%2 = shape.cstr_broadcastable %0, %1: tensor<2xindex>, tensor<1xindex>
"use"(%2) : (!shape.witness) -> ()
}
More information about the Mlir-commits mailing list