[Mlir-commits] [mlir] 23a6564 - [MLIR][Shape] Allow `shape.rank` to operate on extent tensors
Frederik Gossen
llvmlistbot at llvm.org
Fri Jul 24 03:44:02 PDT 2020
Author: Frederik Gossen
Date: 2020-07-24T10:43:39Z
New Revision: 23a65648c0cd412becbef9b914e2e33f6114ba87
URL: https://github.com/llvm/llvm-project/commit/23a65648c0cd412becbef9b914e2e33f6114ba87
DIFF: https://github.com/llvm/llvm-project/commit/23a65648c0cd412becbef9b914e2e33f6114ba87.diff
LOG: [MLIR][Shape] Allow `shape.rank` to operate on extent tensors
Differential Revision: https://reviews.llvm.org/D84429
Added:
Modified:
mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td
mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
mlir/lib/Dialect/Shape/IR/Shape.cpp
mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
mlir/test/Dialect/Shape/canonicalize.mlir
mlir/test/Dialect/Shape/invalid.mlir
mlir/test/Dialect/Shape/ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td
index 3cbebe723921..4c00c1f3c8f7 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td
@@ -122,4 +122,6 @@ def Shape_ShapeOrExtentTensorType : AnyTypeOf<[Shape_ShapeType,
Shape_ExtentTensorType],
"shape or extent tensor">;
+def Shape_SizeOrIndexType : AnyTypeOf<[Shape_SizeType, Index], "size or index">;
+
#endif // SHAPE_BASE_TD
diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index 70f8d75748f4..014b72cd1339 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -201,12 +201,13 @@ def Shape_RankOp : Shape_Op<"rank", [NoSideEffect]> {
}];
let arguments = (ins Shape_ShapeOrExtentTensorType:$shape);
- let results = (outs Shape_SizeType:$rank);
+ let results = (outs Shape_SizeOrIndexType:$rank);
- let assemblyFormat = "$shape `:` type($shape) attr-dict";
+ let assemblyFormat = "$shape `:` type($shape) `->` type($rank) attr-dict";
let hasFolder = 1;
let hasCanonicalizer = 1;
+ let verifier = [{ return ::verify(*this); }];
}
def Shape_ToExtentTensorOp : Shape_Op<"to_extent_tensor", [NoSideEffect]> {
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index c5f11a9a95d3..a7a9cb97e76b 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -8,6 +8,7 @@
#include "mlir/Dialect/Shape/IR/Shape.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Traits.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
@@ -52,6 +53,8 @@ Operation *ShapeDialect::materializeConstant(OpBuilder &builder,
return builder.create<ConstSizeOp>(loc, type, value.cast<IntegerAttr>());
if (type.isa<WitnessType>())
return builder.create<ConstWitnessOp>(loc, type, value.cast<BoolAttr>());
+ if (type.isa<IndexType>())
+ return builder.create<ConstantOp>(loc, type, value);
return nullptr;
}
@@ -563,7 +566,17 @@ void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape,
// RankOp
//===----------------------------------------------------------------------===//
-OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
+static LogicalResult verify(shape::RankOp op) {
+ Type argTy = op.shape().getType();
+ Type resultTy = op.rank().getType();
+ if (argTy.isa<ShapeType>() && !resultTy.isa<SizeType>())
+ return op.emitOpError()
+ << "if operand is of type `shape` then the result must be of type "
+ "`size` to propagate potential errors";
+ return success();
+}
+
+OpFoldResult shape::RankOp::fold(ArrayRef<Attribute> operands) {
auto shape = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
if (!shape)
return {};
@@ -587,10 +600,11 @@ OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
/// %rank = shape.const_size 3
namespace {
-struct RankShapeOfCanonicalizationPattern : public OpRewritePattern<RankOp> {
- using OpRewritePattern<RankOp>::OpRewritePattern;
+struct RankShapeOfCanonicalizationPattern
+ : public OpRewritePattern<shape::RankOp> {
+ using OpRewritePattern<shape::RankOp>::OpRewritePattern;
- LogicalResult matchAndRewrite(RankOp op,
+ LogicalResult matchAndRewrite(shape::RankOp op,
PatternRewriter &rewriter) const override {
auto shapeOfOp = op.shape().getDefiningOp<ShapeOfOp>();
if (!shapeOfOp)
@@ -599,15 +613,18 @@ struct RankShapeOfCanonicalizationPattern : public OpRewritePattern<RankOp> {
shapeOfOp.arg().getType().dyn_cast<RankedTensorType>();
if (!rankedTensorType)
return failure();
+ assert(op.getType().isa<IndexType>() &&
+ "expected `rank(shape_of( ... )]` based on a shaped argument to "
+ "yield an index type");
int64_t rank = rankedTensorType.getRank();
- rewriter.replaceOpWithNewOp<ConstSizeOp>(op.getOperation(), rank);
+ rewriter.replaceOpWithNewOp<ConstantIndexOp>(op.getOperation(), rank);
return success();
}
};
} // namespace
-void RankOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns,
- MLIRContext *context) {
+void shape::RankOp::getCanonicalizationPatterns(
+ OwningRewritePatternList &patterns, MLIRContext *context) {
patterns.insert<RankShapeOfCanonicalizationPattern>(context);
}
diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
index 67fb7cdb7910..934a28a6ed44 100644
--- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
@@ -122,12 +122,12 @@ func @shape_of_dyn(%arg : tensor<1x5x?xf32>) {
// Convert `rank` to `dim` of the first dimension.
// CHECK-LABEL: @rank
// CHECK-SAME: (%[[SHAPE:.*]]: tensor<?xindex>) -> index
-func @rank(%shape : !shape.shape) -> !shape.size {
- // CHECK-DAG: %[[C0:.*]] = constant 0 : index
- // CHECK-DAG: %[[RESULT:.*]] = dim %[[SHAPE]], %[[C0]]
- // CHECK-DAG: return %[[RESULT]] : index
- %rank = shape.rank %shape : !shape.shape
- return %rank : !shape.size
+func @rank(%shape : tensor<?xindex>) -> index {
+ // CHECK: %[[C0:.*]] = constant 0 : index
+ // CHECK: %[[RESULT:.*]] = dim %[[SHAPE]], %[[C0]]
+ // CHECK: return %[[RESULT]] : index
+ %rank = shape.rank %shape : tensor<?xindex> -> index
+ return %rank : index
}
// -----
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index e2874e09cc8a..e5b77a870a85 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -496,10 +496,10 @@ func @broadcastable_on_extent_tensors(%arg : tensor<?xindex>) {
// Fold `rank` based on constant shape.
// CHECK-LABEL: @fold_rank
func @fold_rank() -> !shape.size {
- // CHECK-DAG: %[[RESULT:.*]] = shape.const_size 5
- // CHECK-DAG: return %[[RESULT]] : !shape.size
+ // CHECK: %[[RESULT:.*]] = shape.const_size 5
+ // CHECK: return %[[RESULT]] : !shape.size
%shape = shape.const_shape [3, 4, 5, 6, 7] : !shape.shape
- %rank = shape.rank %shape : !shape.shape
+ %rank = shape.rank %shape : !shape.shape -> !shape.size
return %rank : !shape.size
}
@@ -509,38 +509,64 @@ func @fold_rank() -> !shape.size {
// CHECK-LABEL: @dont_fold_rank
// CHECK-SAME: (%[[SHAPE:.*]]: !shape.shape) -> !shape.size
func @dont_fold_rank(%shape : !shape.shape) -> !shape.size {
- // CHECK-DAG: %[[RESULT:.*]] = shape.rank %[[SHAPE]]
- // CHECK-DAG: return %[[RESULT]] : !shape.size
- %rank = shape.rank %shape : !shape.shape
+ // CHECK: %[[RESULT:.*]] = shape.rank %[[SHAPE]]
+ // CHECK: return %[[RESULT]] : !shape.size
+ %rank = shape.rank %shape : !shape.shape -> !shape.size
return %rank : !shape.size
}
// -----
+// Fold `rank` based on constant extent tensor.
+// CHECK-LABEL: @fold_rank
+func @fold_rank() -> index {
+ // CHECK: %[[RESULT:.*]] = constant 5 : index
+ // CHECK: return %[[RESULT]] : index
+ %shape = shape.const_shape [3, 4, 5, 6, 7] : tensor<?xindex>
+ %rank = shape.rank %shape : tensor<?xindex> -> index
+ return %rank : index
+}
+
+// -----
+
+// Do not fold `rank` for non-constant extent tensors.
+// CHECK-LABEL: @dont_fold_rank
+// CHECK-SAME: (%[[SHAPE:.*]]: tensor<?xindex>) -> index
+func @dont_fold_rank(%shape : tensor<?xindex>) -> index {
+ // CHECK: %[[RESULT:.*]] = shape.rank %[[SHAPE]] : tensor<?xindex> -> index
+ // CHECK: return %[[RESULT]] : index
+ %rank = shape.rank %shape : tensor<?xindex> -> index
+ return %rank : index
+}
+
+// -----
+
// Canonicalize `rank` when shape is derived from ranked tensor.
// CHECK-LABEL: @canonicalize_rank
-func @canonicalize_rank(%arg : tensor<1x2x?xf32>) -> !shape.size {
- // CHECK-DAG: %[[RESULT:.*]] = shape.const_size 3
- // CHECK-DAG: return %[[RESULT]] : !shape.size
+func @canonicalize_rank(%arg : tensor<1x2x?xf32>) -> index {
+ // CHECK: %[[RESULT:.*]] = constant 3 : index
+ // CHECK: return %[[RESULT]] : index
%shape = shape.shape_of %arg : tensor<1x2x?xf32> -> tensor<?xindex>
- %rank = shape.rank %shape : tensor<?xindex>
- return %rank : !shape.size
+ %rank = shape.rank %shape : tensor<?xindex> -> index
+ return %rank : index
}
// -----
// Do not canonicalize `rank` when shape is derived from unranked tensor.
// CHECK-LABEL: @dont_canonicalize_rank
-// CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>) -> !shape.size
-func @dont_canonicalize_rank(%arg : tensor<*xf32>) -> !shape.size {
- // CHECK-DAG: %[[SHAPE:.*]] = shape.shape_of %[[ARG]] : tensor<*xf32> -> tensor<?xindex>
- // CHECK-DAG: %[[SIZE:.*]] = shape.rank %[[SHAPE]]
- // CHECK-DAG: return %[[SIZE]] : !shape.size
+// CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>) -> index
+func @dont_canonicalize_rank(%arg : tensor<*xf32>) -> index {
+ // CHECK: %[[SHAPE:.*]] = shape.shape_of %[[ARG]] : tensor<*xf32> -> tensor<?xindex>
+ // CHECK: %[[SIZE:.*]] = shape.rank %[[SHAPE]]
+ // CHECK: return %[[SIZE]] : index
%shape = shape.shape_of %arg : tensor<*xf32> -> tensor<?xindex>
- %rank = shape.rank %shape : tensor<?xindex>
- return %rank : !shape.size
+ %rank = shape.rank %shape : tensor<?xindex> -> index
+ return %rank : index
}
+// -----
+
// Canonicalize redundant conversion from `index` to `size` and back.
// CHECK-LABEL: @index_to_size_to_index
// CHECK-SAME: (%[[IDX:.*]]: index) -> index
diff --git a/mlir/test/Dialect/Shape/invalid.mlir b/mlir/test/Dialect/Shape/invalid.mlir
index 23d0daf79378..ae25ba90c360 100644
--- a/mlir/test/Dialect/Shape/invalid.mlir
+++ b/mlir/test/Dialect/Shape/invalid.mlir
@@ -95,3 +95,10 @@ func @shape_of(%value_arg : !shape.value_shape,
%1 = shape.shape_of %shaped_arg : tensor<?x3x4xf32> -> !shape.shape
}
+// -----
+
+func @rank(%arg : !shape.shape) {
+ // expected-error at +1 {{if operand is of type `shape` then the result must be of type `size` to propagate potential errors}}
+ %0 = shape.rank %arg : !shape.shape -> index
+}
+
diff --git a/mlir/test/Dialect/Shape/ops.mlir b/mlir/test/Dialect/Shape/ops.mlir
index d0275aaf692e..f023c02c510c 100644
--- a/mlir/test/Dialect/Shape/ops.mlir
+++ b/mlir/test/Dialect/Shape/ops.mlir
@@ -137,13 +137,13 @@ func @test_from_extent_tensor(%arg: tensor<?xindex>) -> !shape.shape {
}
func @rank(%shape : !shape.shape) -> !shape.size {
- %rank = shape.rank %shape : !shape.shape
+ %rank = shape.rank %shape : !shape.shape -> !shape.size
return %rank : !shape.size
}
-func @rank_on_extent_tensor(%shape : tensor<?xindex>) -> !shape.size {
- %rank = shape.rank %shape : tensor<?xindex>
- return %rank : !shape.size
+func @rank_on_extent_tensor(%shape : tensor<?xindex>) -> index {
+ %rank = shape.rank %shape : tensor<?xindex> -> index
+ return %rank : index
}
More information about the Mlir-commits mailing list