[Mlir-commits] [mlir] 5c9c4ad - Add the inline interface to the shape dialect
Feng Liu
llvmlistbot at llvm.org
Fri Aug 7 23:30:04 PDT 2020
Author: Feng Liu
Date: 2020-08-07T23:29:43-07:00
New Revision: 5c9c4ade9d1269e83fdf8e5d8f62b376a76da2b0
URL: https://github.com/llvm/llvm-project/commit/5c9c4ade9d1269e83fdf8e5d8f62b376a76da2b0
DIFF: https://github.com/llvm/llvm-project/commit/5c9c4ade9d1269e83fdf8e5d8f62b376a76da2b0.diff
LOG: Add the inline interface to the shape dialect
This patch also fixes a minor issue: the shape.rank canonicalization
asserted an index result, but shape.rank may also return !shape.size.
The dialect doc has such an example for shape.rank.
Differential Revision: https://reviews.llvm.org/D85556
Added:
Modified:
mlir/lib/Dialect/Shape/IR/Shape.cpp
mlir/test/Dialect/Shape/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 47c592e51a40..7cf02149fa25 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -14,6 +14,7 @@
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
+#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/Support/raw_ostream.h"
@@ -59,6 +60,32 @@ static LogicalResult verifyShapeOrExtentTensorOp(Operation *op) {
return success();
}
+//===----------------------------------------------------------------------===//
+// InlinerInterface
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Inliner hooks for the shape dialect; both hooks below accept unconditionally.
+struct ShapeInlinerInterface : public DialectInlinerInterface {
+ using DialectInlinerInterface::DialectInlinerInterface;
+
+ // Region hook: always returns true, so any region 'src' may be inlined into
+ // a region 'dest' that is attached to an operation of this dialect.
+ bool isLegalToInline(Region *dest, Region *src,
+ BlockAndValueMapping &) const final {
+ return true;
+ }
+
+ // Operation hook: 'op' is an operation registered to this dialect. Always
+ // returns true, so such an op may be inlined into any region 'dest'; the
+ // destination region is not inspected.
+ bool isLegalToInline(Operation *op, Region *dest,
+ BlockAndValueMapping &) const final {
+ return true;
+ }
+};
+} // namespace
+
void ShapeDialect::initialize() {
addOperations<
#define GET_OP_LIST
@@ -66,6 +93,7 @@ void ShapeDialect::initialize() {
>();
addTypes<ComponentType, ElementType, ShapeType, SizeType, ValueShapeType,
WitnessType>();
+ addInterfaces<ShapeInlinerInterface>();
// Allow unknown operations during prototyping and testing. As the dialect is
// still evolving it makes it simple to start with an unregistered ops and
// try different variants before actually defining the op.
@@ -640,11 +668,14 @@ struct RankShapeOfCanonicalizationPattern
shapeOfOp.arg().getType().dyn_cast<RankedTensorType>();
if (!rankedTensorType)
return failure();
- assert(op.getType().isa<IndexType>() &&
- "expected `rank(shape_of( ... )]` based on a shaped argument to "
- "yield an index type");
int64_t rank = rankedTensorType.getRank();
- rewriter.replaceOpWithNewOp<ConstantIndexOp>(op.getOperation(), rank);
+ if (op.getType().isa<IndexType>()) {
+ rewriter.replaceOpWithNewOp<ConstantIndexOp>(op.getOperation(), rank);
+ } else if (op.getType().isa<shape::SizeType>()) {
+ rewriter.replaceOpWithNewOp<shape::ConstSizeOp>(op.getOperation(), rank);
+ } else {
+ return failure();
+ }
return success();
}
};
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index 21c5a68c3adc..8a38b42c5a71 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -624,6 +624,18 @@ func @canonicalize_rank(%arg : tensor<1x2x?xf32>) -> index {
// -----
+// Canonicalize `rank` to a `!shape.size` constant when shape is from a ranked tensor.
+// CHECK-LABEL: @canonicalize_rank_size
+func @canonicalize_rank_size(%arg : tensor<1x2x?xf32>) -> !shape.size {
+ // CHECK: %[[RESULT:.*]] = shape.const_size 3
+ // CHECK: return %[[RESULT]] : !shape.size
+ %shape = shape.shape_of %arg : tensor<1x2x?xf32> -> !shape.shape
+ %rank = shape.rank %shape : !shape.shape -> !shape.size
+ return %rank : !shape.size
+}
+
+// -----
+
// Do not canonicalize `rank` when shape is derived from unranked tensor.
// CHECK-LABEL: @dont_canonicalize_rank
// CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>) -> index
More information about the Mlir-commits
mailing list