[Mlir-commits] [mlir] eaf4913 - [MLIR][Shape] Realize `shape` to `std` lowering with declarative patterns

Frederik Gossen llvmlistbot at llvm.org
Thu Jun 18 00:54:15 PDT 2020


Author: Frederik Gossen
Date: 2020-06-18T07:53:44Z
New Revision: eaf49130a9bae3f83df6244ae90319f455b1571b

URL: https://github.com/llvm/llvm-project/commit/eaf49130a9bae3f83df6244ae90319f455b1571b
DIFF: https://github.com/llvm/llvm-project/commit/eaf49130a9bae3f83df6244ae90319f455b1571b.diff

LOG: [MLIR][Shape] Realize `shape` to `std` lowering with declarative patterns

Set up declarative rewrite rules to lower the `shape` dialect to the `std`
dialect, with two example rules for `from_extent_tensor` and `to_extent_tensor`.

Differential Revision: https://reviews.llvm.org/D82022

Added: 
    mlir/lib/Conversion/ShapeToStandard/ShapeToStandardPatterns.td

Modified: 
    mlir/lib/Conversion/ShapeToStandard/CMakeLists.txt
    mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/ShapeToStandard/CMakeLists.txt b/mlir/lib/Conversion/ShapeToStandard/CMakeLists.txt
index 8750c331859e..4d97af2f235c 100644
--- a/mlir/lib/Conversion/ShapeToStandard/CMakeLists.txt
+++ b/mlir/lib/Conversion/ShapeToStandard/CMakeLists.txt
@@ -1,3 +1,7 @@
+set(LLVM_TARGET_DEFINITIONS ShapeToStandardPatterns.td)
+mlir_tablegen(ShapeToStandardPatterns.inc -gen-rewriters)
+add_public_tablegen_target(ShapeToStandardPatternsIncGen)
+
 add_mlir_conversion_library(MLIRShapeToStandard
   ShapeToStandard.cpp
 
@@ -6,6 +10,7 @@ add_mlir_conversion_library(MLIRShapeToStandard
 
   DEPENDS
   MLIRConversionPassIncGen
+  ShapeToStandardPatternsIncGen
 
   LINK_COMPONENTS
   Core

diff  --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
index a6e5a3783c1c..6cce898112bf 100644
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -19,6 +19,9 @@ using namespace mlir::shape;
 
 namespace {
 
+/// Generated conversion patterns.
+#include "ShapeToStandardPatterns.inc"
+
 /// Conversion patterns.
 template <typename SrcOpTy, typename DstOpTy>
 class BinaryOpConversion : public OpConversionPattern<SrcOpTy> {
@@ -35,20 +38,6 @@ class BinaryOpConversion : public OpConversionPattern<SrcOpTy> {
   }
 };
 
-class FromExtentTensorOpConversion
-    : public OpConversionPattern<FromExtentTensorOp> {
-public:
-  using OpConversionPattern<FromExtentTensorOp>::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(FromExtentTensorOp op, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const override {
-    FromExtentTensorOp::Adaptor transformed(operands);
-    rewriter.replaceOp(op.getOperation(), transformed.input());
-    return success();
-  }
-};
-
 class IndexToSizeOpConversion : public OpConversionPattern<IndexToSizeOp> {
 public:
   using OpConversionPattern<IndexToSizeOp>::OpConversionPattern;
@@ -75,20 +64,6 @@ class SizeToIndexOpConversion : public OpConversionPattern<SizeToIndexOp> {
   }
 };
 
-class ToExtentTensorOpConversion
-    : public OpConversionPattern<ToExtentTensorOp> {
-public:
-  using OpConversionPattern<ToExtentTensorOp>::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(ToExtentTensorOp op, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const override {
-    ToExtentTensorOp::Adaptor transformed(operands);
-    rewriter.replaceOp(op.getOperation(), transformed.input());
-    return success();
-  }
-};
-
 class ConstSizeOpConverter : public OpConversionPattern<ConstSizeOp> {
 public:
   using OpConversionPattern<ConstSizeOp>::OpConversionPattern;
@@ -122,6 +97,7 @@ class ShapeTypeConverter : public TypeConverter {
 /// Conversion pass.
 class ConvertShapeToStandardPass
     : public ConvertShapeToStandardBase<ConvertShapeToStandardPass> {
+
   void runOnOperation() override {
     // Setup type conversion.
     MLIRContext &ctx = getContext();
@@ -151,15 +127,14 @@ class ConvertShapeToStandardPass
 
 void mlir::populateShapeToStandardConversionPatterns(
     OwningRewritePatternList &patterns, MLIRContext *ctx) {
+  populateWithGenerated(ctx, &patterns);
   // clang-format off
   patterns.insert<
       BinaryOpConversion<AddOp, AddIOp>,
       BinaryOpConversion<MulOp, MulIOp>,
       ConstSizeOpConverter,
-      FromExtentTensorOpConversion,
       IndexToSizeOpConversion,
-      SizeToIndexOpConversion,
-      ToExtentTensorOpConversion>(ctx);
+      SizeToIndexOpConversion>(ctx);
   // clang-format on
 }
 

diff  --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandardPatterns.td b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandardPatterns.td
new file mode 100644
index 000000000000..3ad54215d8ed
--- /dev/null
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandardPatterns.td
@@ -0,0 +1,12 @@
+include "mlir/Dialect/Shape/IR/ShapeOps.td"
+include "mlir/Dialect/StandardOps/IR/Ops.td"
+
+// Convert `from_extent_tensor` and `to_extent_tensor` to no-ops as shapes will
+// be represented as extent tensors.
+def FromExtentTensorOpConversion : Pat<
+    (Shape_FromExtentTensorOp $input),
+    (replaceWithValue $input)>;
+def ToExtentTensorOpConversion : Pat<
+    (Shape_ToExtentTensorOp $input),
+    (replaceWithValue $input)>;
+


        


More information about the Mlir-commits mailing list