[Mlir-commits] [mlir] 3324598 - [mlir] Add a pass to remove all shape.cstr_ and assuming_ ops.

Tres Popp llvmlistbot at llvm.org
Thu Jun 18 04:32:33 PDT 2020


Author: Tres Popp
Date: 2020-06-18T13:31:30+02:00
New Revision: 3324598844a28850527b8abc8b83579ad7ab94a2

URL: https://github.com/llvm/llvm-project/commit/3324598844a28850527b8abc8b83579ad7ab94a2
DIFF: https://github.com/llvm/llvm-project/commit/3324598844a28850527b8abc8b83579ad7ab94a2.diff

LOG: [mlir] Add a pass to remove all shape.cstr_ and assuming_ ops.

Summary:
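This provides a utility to remove unsupported constraints, e.g. for
pipelines that receive these ops but cannot lower them because they do not
support assertions. The pass replaces shape.cstr_broadcastable and
shape.cstr_eq with passing witnesses; the dependent assuming ops can then
be removed by canonicalization.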

Differential Revision: https://reviews.llvm.org/D81560

Added: 
    mlir/lib/Dialect/Shape/Transforms/RemoveShapeConstraints.cpp
    mlir/test/Dialect/Shape/remove-shape-constraints.mlir

Modified: 
    mlir/include/mlir/Dialect/Shape/IR/Shape.h
    mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
    mlir/include/mlir/Dialect/Shape/Transforms/Passes.h
    mlir/include/mlir/Dialect/Shape/Transforms/Passes.td
    mlir/lib/Dialect/Shape/IR/Shape.cpp
    mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Shape/IR/Shape.h b/mlir/include/mlir/Dialect/Shape/IR/Shape.h
index 994975834d5d..b78b7304592d 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/Shape.h
+++ b/mlir/include/mlir/Dialect/Shape/IR/Shape.h
@@ -22,6 +22,8 @@
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 
 namespace mlir {
+class PatternRewriter;
+
 namespace shape {
 
 namespace ShapeTypes {

diff  --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index a6f579cec505..49714c96b1cd 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -375,7 +375,7 @@ def Shape_SizeToIndexOp : Shape_Op<"size_to_index", [NoSideEffect]> {
   let hasFolder = 1;
 }
 
-def Shape_YieldOp : Shape_Op<"yield", 
+def Shape_YieldOp : Shape_Op<"yield",
     [HasParent<"ReduceOp">,
      NoSideEffect,
      ReturnLike,
@@ -533,6 +533,14 @@ def Shape_AssumingOp : Shape_Op<"assuming",
   let printer = [{ return ::print(p, *this); }];
   let parser = [{ return ::parse$cppClass(parser, result); }];
 
+  let extraClassDeclaration = [{
+    // Inline the region into the region containing the AssumingOp and delete
+    // the AssumingOp.
+    //
+    // This does no checks on the inputs to the AssumingOp.
+    static void inlineRegionIntoParent(AssumingOp &op, PatternRewriter &rewriter);
+  }];
+
   let hasCanonicalizer = 1;
 }
 

diff  --git a/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h b/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h
index 7e6065341608..e8d2167916d0 100644
--- a/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h
@@ -18,6 +18,7 @@
 
 namespace mlir {
 
+class FunctionPass;
 class MLIRContext;
 class OwningRewritePatternList;
 class Pass;
@@ -30,6 +31,17 @@ std::unique_ptr<Pass> createShapeToShapeLowering();
 /// Collects a set of patterns to rewrite ops within the Shape dialect.
 void populateShapeRewritePatterns(MLIRContext *context,
                                   OwningRewritePatternList &patterns);
+
+/// Collects patterns that replace every shape.cstr_broadcastable and
+/// shape.cstr_eq op with a passing `shape.const_witness true`. This is meant
+/// to leave ShapeConstraint-related ops and values without effect so that
+/// they can be removed, e.g. by canonicalization and dead code elimination.
+/// Other cstr_ ops are not rewritten. createRemoveShapeConstraintsPass()
+/// returns a function pass that applies these patterns.
+void populateRemoveShapeConstraintsPatterns(OwningRewritePatternList &patterns,
+                                            MLIRContext *ctx);
+std::unique_ptr<FunctionPass> createRemoveShapeConstraintsPass();
+
 } // end namespace mlir
 
 #endif // MLIR_DIALECT_SHAPE_TRANSFORMS_PASSES_H_

diff  --git a/mlir/include/mlir/Dialect/Shape/Transforms/Passes.td b/mlir/include/mlir/Dialect/Shape/Transforms/Passes.td
index 46dc4dc37160..022bd3773ce2 100644
--- a/mlir/include/mlir/Dialect/Shape/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Shape/Transforms/Passes.td
@@ -11,6 +11,16 @@

 include "mlir/Pass/PassBase.td"

+def RemoveShapeConstraints : FunctionPass<"remove-shape-constraints"> {
+  let summary = "Replace all cstr_ ops with a true witness";
+  let description = [{
+    Replaces every `shape.cstr_broadcastable` and `shape.cstr_eq` op with a
+    `shape.const_witness true`, i.e. assumes these constraints hold without
+    checking them. Dependent `shape.assuming` ops can then be inlined by
+    canonicalization.
+  }];
+  let constructor = "mlir::createRemoveShapeConstraintsPass()";
+}
+
 def ShapeToShapeLowering : FunctionPass<"shape-to-shape-lowering"> {
   let summary = "Legalize Shape dialect to be convertible to Standard";
   let constructor = "mlir::createShapeToShapeLowering()";

diff  --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 4a876e16bcec..664c0cb05b80 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -168,22 +168,7 @@ struct AssumingWithTrue : public OpRewritePattern<AssumingOp> {
     if (!witness || !witness.passingAttr())
       return failure();
 
-    auto *blockBeforeAssuming = rewriter.getInsertionBlock();
-    auto *assumingBlock = op.getBody();
-    auto initPosition = rewriter.getInsertionPoint();
-    auto *blockAfterAssuming =
-        rewriter.splitBlock(blockBeforeAssuming, initPosition);
-
-    // Remove the AssumingOp and AssumingYieldOp.
-    auto &yieldOp = assumingBlock->back();
-    rewriter.inlineRegionBefore(op.doRegion(), blockAfterAssuming);
-    rewriter.replaceOp(op, yieldOp.getOperands());
-    rewriter.eraseOp(&yieldOp);
-
-    // Merge blocks together as there was no branching behavior from the
-    // AssumingOp.
-    rewriter.mergeBlocks(assumingBlock, blockBeforeAssuming);
-    rewriter.mergeBlocks(blockAfterAssuming, blockBeforeAssuming);
+    AssumingOp::inlineRegionIntoParent(op, rewriter);
     return success();
   }
 };
@@ -191,10 +176,30 @@ struct AssumingWithTrue : public OpRewritePattern<AssumingOp> {
 
 void AssumingOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns,
                                              MLIRContext *context) {
-  // If taking a passing witness, inline region
+  // If taking a passing witness, inline region.
   patterns.insert<AssumingWithTrue>(context);
 }
 
+void AssumingOp::inlineRegionIntoParent(AssumingOp &op,
+                                        PatternRewriter &rewriter) {
+  // Split at `op` itself: callers need not set the insertion point to `op`.
+  auto *blockBeforeAssuming = op.getOperation()->getBlock();
+  auto *assumingBlock = op.getBody();
+  auto *blockAfterAssuming = rewriter.splitBlock(
+      blockBeforeAssuming, Block::iterator(op.getOperation()));
+
+  // Remove the AssumingOp and AssumingYieldOp.
+  auto &yieldOp = assumingBlock->back();
+  rewriter.inlineRegionBefore(op.doRegion(), blockAfterAssuming);
+  rewriter.replaceOp(op, yieldOp.getOperands());
+  rewriter.eraseOp(&yieldOp);
+
+  // Merge blocks together as there was no branching behavior from the
+  // AssumingOp.
+  rewriter.mergeBlocks(assumingBlock, blockBeforeAssuming);
+  rewriter.mergeBlocks(blockAfterAssuming, blockBeforeAssuming);
+}
+
 //===----------------------------------------------------------------------===//
 // AssumingAllOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt
index 6f812b60658e..987f9c544b33 100644
--- a/mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_mlir_dialect_library(MLIRShapeOpsTransforms
+  RemoveShapeConstraints.cpp
   ShapeToShapeLowering.cpp
 
   ADDITIONAL_HEADER_DIRS

diff  --git a/mlir/lib/Dialect/Shape/Transforms/RemoveShapeConstraints.cpp b/mlir/lib/Dialect/Shape/Transforms/RemoveShapeConstraints.cpp
new file mode 100644
index 000000000000..641b4bc38e43
--- /dev/null
+++ b/mlir/lib/Dialect/Shape/Transforms/RemoveShapeConstraints.cpp
@@ -0,0 +1,55 @@
+//===-- RemoveShapeConstraints.cpp - Remove Shape Cstr and Assuming Ops ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Dialect/Shape/IR/Shape.h"
+#include "mlir/Dialect/Shape/Transforms/Passes.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+
+namespace {
+/// Replaces a constraint op of type `OpTy` with `shape.const_witness true`,
+/// i.e. assumes the constraint holds without checking it.
+template <typename OpTy>
+class RemoveCstrOp : public OpRewritePattern<OpTy> {
+public:
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(OpTy op,
+                                PatternRewriter &rewriter) const override {
+    rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op.getOperation(), true);
+    return success();
+  }
+};
+
+/// Function pass that applies populateRemoveShapeConstraintsPatterns greedily.
+class RemoveShapeConstraintsPass
+    : public RemoveShapeConstraintsBase<RemoveShapeConstraintsPass> {
+
+  void runOnFunction() override {
+    MLIRContext &ctx = getContext();
+
+    OwningRewritePatternList patterns;
+    populateRemoveShapeConstraintsPatterns(patterns, &ctx);
+
+    applyPatternsAndFoldGreedily(getFunction(), patterns);
+  }
+};
+
+} // namespace
+
+void mlir::populateRemoveShapeConstraintsPatterns(
+    OwningRewritePatternList &patterns, MLIRContext *ctx) {
+  patterns.insert<RemoveCstrOp<shape::CstrBroadcastableOp>,
+                  RemoveCstrOp<shape::CstrEqOp>>(ctx);
+}
+
+std::unique_ptr<FunctionPass> mlir::createRemoveShapeConstraintsPass() {
+  return std::make_unique<RemoveShapeConstraintsPass>();
+}

diff  --git a/mlir/test/Dialect/Shape/remove-shape-constraints.mlir b/mlir/test/Dialect/Shape/remove-shape-constraints.mlir
new file mode 100644
index 000000000000..69887c6994f4
--- /dev/null
+++ b/mlir/test/Dialect/Shape/remove-shape-constraints.mlir
@@ -0,0 +1,56 @@
+// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -remove-shape-constraints -canonicalize <%s | FileCheck %s --dump-input=fail --check-prefixes=CANON,CHECK-BOTH
+// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -remove-shape-constraints <%s | FileCheck %s --dump-input=fail --check-prefixes=REPLACE,CHECK-BOTH
+
+// -----
+// Check that cstr_broadcastable is removed.
+//
+// CHECK-BOTH: func @f
+func @f(%arg0 : !shape.shape, %arg1 : !shape.shape) -> index {
+  // REPLACE-NEXT: %[[WITNESS:.+]] = shape.const_witness true
+  // REPLACE-NOT: shape.cstr_broadcastable
+  // REPLACE: shape.assuming %[[WITNESS]]
+  // CANON-NEXT: test.source
+  // CANON-NEXT: return
+  %0 = shape.cstr_broadcastable %arg0, %arg1
+  %1 = shape.assuming %0 -> index {
+    %2 = "test.source"() : () -> (index)
+    shape.assuming_yield %2 : index
+  }
+  return %1 : index
+}
+
+// -----
+// Check that cstr_eq is removed.
+//
+// CHECK-BOTH: func @f
+func @f(%arg0 : !shape.shape, %arg1 : !shape.shape) -> index {
+  // REPLACE-NEXT: %[[WITNESS:.+]] = shape.const_witness true
+  // REPLACE-NOT: shape.cstr_eq
+  // REPLACE: shape.assuming %[[WITNESS]]
+  // CANON-NEXT: test.source
+  // CANON-NEXT: return
+  %0 = shape.cstr_eq %arg0, %arg1
+  %1 = shape.assuming %0 -> index {
+    %2 = "test.source"() : () -> (index)
+    shape.assuming_yield %2 : index
+  }
+  return %1 : index
+}
+
+// -----
+// Check that all constraints are removed when several are present, including
+// ones that only feed an assuming_all.
+//
+// CHECK-BOTH: func @f
+func @f(%arg0 : !shape.shape, %arg1 : !shape.shape) -> index {
+  // CANON-NEXT: test.source
+  // CANON-NEXT: return
+  %0 = shape.cstr_broadcastable %arg0, %arg1
+  %1 = shape.cstr_eq %arg0, %arg1
+  %2 = shape.assuming_all %0, %1
+  %3 = shape.assuming %0 -> index {
+    %4 = "test.source"() : () -> (index)
+    shape.assuming_yield %4 : index
+  }
+  return %3 : index
+}


        


More information about the Mlir-commits mailing list