[Mlir-commits] [mlir] 94ef248 - Revert "[MLIR] Canonicalize `shape.assuming` op to yield only inner values"
Frederik Gossen
llvmlistbot at llvm.org
Tue Mar 23 08:06:47 PDT 2021
Author: Frederik Gossen
Date: 2021-03-23T16:05:55+01:00
New Revision: 94ef248d7b76939cc3caacb83b7e168b82a74764
URL: https://github.com/llvm/llvm-project/commit/94ef248d7b76939cc3caacb83b7e168b82a74764
DIFF: https://github.com/llvm/llvm-project/commit/94ef248d7b76939cc3caacb83b7e168b82a74764.diff
LOG: Revert "[MLIR] Canonicalize `shape.assuming` op to yield only inner values"
This reverts commit 5f8acd4fd233cdce5892958df56ed1f000f75f9e.
Added:
Modified:
mlir/lib/Dialect/Shape/IR/Shape.cpp
mlir/test/Dialect/Shape/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 0feac8793e111..d2a10a9f5dccb 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -11,7 +11,6 @@
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Traits.h"
-#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
@@ -269,72 +268,12 @@ struct AssumingWithTrue : public OpRewritePattern<AssumingOp> {
return success();
}
};
-
-// Results of an assuming op that are defined outside its body are available
-// indepentently of the assuming op. There is no need to yield such values. This
-// canonicalization replaces such results with their definition.
-struct AssumingBypassIndependentResult : public OpRewritePattern<AssumingOp> {
- using OpRewritePattern<AssumingOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(AssumingOp op,
- PatternRewriter &rewriter) const override {
- Block *body = op.getBody();
- auto yieldOp = llvm::dyn_cast<AssumingYieldOp>(body->getTerminator());
- if (!yieldOp)
- return failure();
-
- // See if there is at least one result that can bypass the assuming op.
- auto isDefinedInBody = [&](Value val) {
- Operation *def = val.getDefiningOp();
- return def && op->isAncestor(def);
- };
- if (llvm::all_of(yieldOp.operands(), isDefinedInBody))
- return failure();
-
- SmallVector<Value, 2> replacementValues;
- auto newAssumingOp = rewriter.create<shape::AssumingOp>(
- op.getLoc(), op.witness(), [&](OpBuilder &b, Location loc) {
- // Copy body.
- BlockAndValueMapping mapping;
- for (auto &nested : body->without_terminator())
- b.clone(nested, mapping);
-
- // Collect new yielded values.
- SmallVector<Value, 2> mappedResults;
- for (auto result : yieldOp.getOperands()) {
- if (isDefinedInBody(result)) {
- // This value is a result of the assuming op. We can obtain the
- // replacement value only after the new op is fully constructed.
- mappedResults.push_back(mapping.lookup(result));
- replacementValues.push_back(nullptr);
- } else {
- // When defined outside of the assuming block, we can use it
- // direclty. There is no need to yield the value from within the
- // block.
- replacementValues.push_back(result);
- }
- }
- return mappedResults;
- });
-
- // Use the assuming op's results for the missing replacement values, which
- // could not bypass the op.
- auto src = newAssumingOp.getResults().begin();
- for (auto &dst : replacementValues) {
- if (dst)
- continue;
- dst = *src++;
- }
-
- rewriter.replaceOp(op, replacementValues);
- return success();
- }
-};
} // namespace
void AssumingOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
- patterns.add<AssumingBypassIndependentResult, AssumingWithTrue>(context);
+ // If taking a passing witness, inline region.
+ patterns.add<AssumingWithTrue>(context);
}
// See RegionBranchOpInterface in Interfaces/ControlFlowInterfaces.td
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index 3c4a2829a793e..39f17e9d253f6 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -1144,28 +1144,3 @@ func @broadcast_on_single_operand(%a : tensor<3xindex>) {
"use"(%0) : (tensor<?xindex>) -> ()
return
}
-
-// -----
-
-// CHECK-LABEL: @bypass_assmunig
-// CHECK-SAME: (%[[ARG:.*]]: tensor<2x3xf32>)
-func @bypass_assmunig(%arg : tensor<2x3xf32>)
- -> (tensor<2x3xf32>, tensor<2x3xf32>, tensor<2x3xf32>) {
- // CHECK: %[[OUTER:.*]] = "some.tensor"
- // CHECK: %[[WITNESS:.*]] = "some.witness"
- // CHECK: %[[YIELDED:.*]] = shape.assuming %[[WITNESS]] -> (tensor<2x3xf32>) {
- // CHECK: %[[INNER:.*]] = "some.tensor"
- // CHECK: shape.assuming_yield %[[INNER]] : tensor<2x3xf32>
- // CHECK: }
- // CHECK: return %[[YIELDED]], %[[OUTER]], %[[ARG]]
- %outer = "some.tensor"() : () -> tensor<2x3xf32>
- %witness = "some.witness"() : () -> !shape.witness
- %results:3 = shape.assuming %witness
- -> (tensor<2x3xf32>, tensor<2x3xf32>, tensor<2x3xf32>) {
- %inner = "some.tensor"() : () -> tensor<2x3xf32>
- shape.assuming_yield %inner, %outer, %arg
- : tensor<2x3xf32>, tensor<2x3xf32>, tensor<2x3xf32>
- }
- return %results#0, %results#1, %results#2
- : tensor<2x3xf32>, tensor<2x3xf32>, tensor<2x3xf32>
-}
More information about the Mlir-commits mailing list