[Mlir-commits] [mlir] eaf4913 - [MLIR][Shape] Realize `shape` to `std` lowering with declarative patterns
Frederik Gossen
llvmlistbot at llvm.org
Thu Jun 18 00:54:15 PDT 2020
Author: Frederik Gossen
Date: 2020-06-18T07:53:44Z
New Revision: eaf49130a9bae3f83df6244ae90319f455b1571b
URL: https://github.com/llvm/llvm-project/commit/eaf49130a9bae3f83df6244ae90319f455b1571b
DIFF: https://github.com/llvm/llvm-project/commit/eaf49130a9bae3f83df6244ae90319f455b1571b.diff
LOG: [MLIR][Shape] Realize `shape` to `std` lowering with declarative patterns
Set up declarative rewrite rules to lower the `shape` dialect to the `std`
dialect, starting with two example rules for `from_extent_tensor` and
`to_extent_tensor`.
Differential Revision: https://reviews.llvm.org/D82022
Added:
mlir/lib/Conversion/ShapeToStandard/ShapeToStandardPatterns.td
Modified:
mlir/lib/Conversion/ShapeToStandard/CMakeLists.txt
mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/ShapeToStandard/CMakeLists.txt b/mlir/lib/Conversion/ShapeToStandard/CMakeLists.txt
index 8750c331859e..4d97af2f235c 100644
--- a/mlir/lib/Conversion/ShapeToStandard/CMakeLists.txt
+++ b/mlir/lib/Conversion/ShapeToStandard/CMakeLists.txt
@@ -1,3 +1,7 @@
+set(LLVM_TARGET_DEFINITIONS ShapeToStandardPatterns.td)
+mlir_tablegen(ShapeToStandardPatterns.inc -gen-rewriters)
+add_public_tablegen_target(ShapeToStandardPatternsIncGen)
+
add_mlir_conversion_library(MLIRShapeToStandard
ShapeToStandard.cpp
@@ -6,6 +10,7 @@ add_mlir_conversion_library(MLIRShapeToStandard
DEPENDS
MLIRConversionPassIncGen
+ ShapeToStandardPatternsIncGen
LINK_COMPONENTS
Core
diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
index a6e5a3783c1c..6cce898112bf 100644
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -19,6 +19,9 @@ using namespace mlir::shape;
namespace {
+/// Generated conversion patterns.
+#include "ShapeToStandardPatterns.inc"
+
/// Conversion patterns.
template <typename SrcOpTy, typename DstOpTy>
class BinaryOpConversion : public OpConversionPattern<SrcOpTy> {
@@ -35,20 +38,6 @@ class BinaryOpConversion : public OpConversionPattern<SrcOpTy> {
}
};
-class FromExtentTensorOpConversion
- : public OpConversionPattern<FromExtentTensorOp> {
-public:
- using OpConversionPattern<FromExtentTensorOp>::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(FromExtentTensorOp op, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const override {
- FromExtentTensorOp::Adaptor transformed(operands);
- rewriter.replaceOp(op.getOperation(), transformed.input());
- return success();
- }
-};
-
class IndexToSizeOpConversion : public OpConversionPattern<IndexToSizeOp> {
public:
using OpConversionPattern<IndexToSizeOp>::OpConversionPattern;
@@ -75,20 +64,6 @@ class SizeToIndexOpConversion : public OpConversionPattern<SizeToIndexOp> {
}
};
-class ToExtentTensorOpConversion
- : public OpConversionPattern<ToExtentTensorOp> {
-public:
- using OpConversionPattern<ToExtentTensorOp>::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(ToExtentTensorOp op, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const override {
- ToExtentTensorOp::Adaptor transformed(operands);
- rewriter.replaceOp(op.getOperation(), transformed.input());
- return success();
- }
-};
-
class ConstSizeOpConverter : public OpConversionPattern<ConstSizeOp> {
public:
using OpConversionPattern<ConstSizeOp>::OpConversionPattern;
@@ -122,6 +97,7 @@ class ShapeTypeConverter : public TypeConverter {
/// Conversion pass.
class ConvertShapeToStandardPass
: public ConvertShapeToStandardBase<ConvertShapeToStandardPass> {
+
void runOnOperation() override {
// Setup type conversion.
MLIRContext &ctx = getContext();
@@ -151,15 +127,14 @@ class ConvertShapeToStandardPass
void mlir::populateShapeToStandardConversionPatterns(
OwningRewritePatternList &patterns, MLIRContext *ctx) {
+ populateWithGenerated(ctx, &patterns);
// clang-format off
patterns.insert<
BinaryOpConversion<AddOp, AddIOp>,
BinaryOpConversion<MulOp, MulIOp>,
ConstSizeOpConverter,
- FromExtentTensorOpConversion,
IndexToSizeOpConversion,
- SizeToIndexOpConversion,
- ToExtentTensorOpConversion>(ctx);
+ SizeToIndexOpConversion>(ctx);
// clang-format on
}
diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandardPatterns.td b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandardPatterns.td
new file mode 100644
index 000000000000..3ad54215d8ed
--- /dev/null
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandardPatterns.td
@@ -0,0 +1,12 @@
+include "mlir/Dialect/Shape/IR/ShapeOps.td"
+include "mlir/Dialect/StandardOps/IR/Ops.td"
+
+// Convert `from_extent_tensor` and `to_extent_tensor` to no-ops as shapes will
+// be represented as extent tensors.
+def FromExtentTensorOpConversion : Pat<
+ (Shape_FromExtentTensorOp $input),
+ (replaceWithValue $input)>;
+def ToExtentTensorOpConversion : Pat<
+ (Shape_ToExtentTensorOp $input),
+ (replaceWithValue $input)>;
+
More information about the Mlir-commits mailing list