[Mlir-commits] [mlir] ed499dd - [MLIR] Fix operation clone
William S. Moses
llvmlistbot at llvm.org
Fri Apr 15 10:09:17 PDT 2022
Author: William S. Moses
Date: 2022-04-15T13:09:13-04:00
New Revision: ed499ddcdaa6a47d21d98c12f29a94386cd4643a
URL: https://github.com/llvm/llvm-project/commit/ed499ddcdaa6a47d21d98c12f29a94386cd4643a
DIFF: https://github.com/llvm/llvm-project/commit/ed499ddcdaa6a47d21d98c12f29a94386cd4643a.diff
LOG: [MLIR] Fix operation clone
Operation clone is currently faulty.
Suppose you have a block like the following:
```
(%x0 : i32) {
%x1 = f(%x0)
return %x1
}
```
The test case is that we want to "unroll" this block so that it computes `f(f(x0))` instead of just `f(x0)`. We do so by appending a copy of the body to the end of the block and replacing the uses of the block argument in the copied operations with the value returned by the original body.
This is implemented as follows:
1) map the block arguments to the returned values (`map[x0] = x1`).
2) clone the body
Now for this small example, this works as intended and we get the following.
```
(%x0 : i32) {
%x1 = f(%x0)
%x2 = f(%x1)
return %x2
}
```
This is because the current logic to clone `x1 = f(x0)` first looks up the operands in the map (finding that `x0` maps to `x1` from the initialization), and then maps the result to the cloned result (`map[x1] = x2`).
However, this fails if `x0` is not an operand of the op but is instead used inside its region, as below.
```
(%x0 : i32) {
%x1 = f() {
yield %x0
}
return %x1
}
```
This is because cloning an op currently first looks up the operands (there are none), maps the result (`map[%x1] = %x2`), and only then clones the regions. This results in the following, which is clearly illegal (`%x2` is used inside the region of the very op that defines it):
```
(%x0 : i32) {
%x1 = f() {
yield %x0
}
%x2 = f() {
yield %x2
}
return %x2
}
```
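Diving deeper, this is partly due to the ordering of these steps (changing that ordering is how this PR fixes the bug) and partly due to how region cloning works: it first clones the operations using the mapping, and then it remaps all of their operands. With the ordering above, the map already contains both `x0 -> x1` and `x1 -> x2` by the time the region is cloned, so the `%x0` inside the region is first mapped to `%x1` and then remapped to `%x2`, which produces the incorrect output shown. This PR therefore maps an op's results only after its regions have been cloned.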
Reviewed By: ftynse
Differential Revision: https://reviews.llvm.org/D122531
Added:
mlir/test/IR/test-clone.mlir
mlir/test/lib/IR/TestClone.cpp
Modified:
mlir/include/mlir/IR/Operation.h
mlir/lib/IR/Operation.cpp
mlir/test/lib/IR/CMakeLists.txt
mlir/tools/mlir-opt/mlir-opt.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h
index cd33ceb43fd6a..316c52f086bbc 100644
--- a/mlir/include/mlir/IR/Operation.h
+++ b/mlir/include/mlir/IR/Operation.h
@@ -85,7 +85,10 @@ class alignas(8) Operation final
/// original one, but they will be left empty.
/// Operands are remapped using `mapper` (if present), and `mapper` is updated
/// to contain the results.
- Operation *cloneWithoutRegions(BlockAndValueMapping &mapper);
+ /// The `mapResults` argument specifies whether the results of the operation
+ /// should also be mapped.
+ Operation *cloneWithoutRegions(BlockAndValueMapping &mapper,
+ bool mapResults = true);
/// Create a partial copy of this operation without traversing into attached
/// regions. The new operation will have the same number of regions as the
diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
index 19fb180676faa..d7b8125add6e3 100644
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -526,7 +526,8 @@ InFlightDiagnostic Operation::emitOpError(const Twine &message) {
/// Create a deep copy of this operation but keep the operation regions empty.
/// Operands are remapped using `mapper` (if present), and `mapper` is updated
/// to contain the results.
-Operation *Operation::cloneWithoutRegions(BlockAndValueMapping &mapper) {
+Operation *Operation::cloneWithoutRegions(BlockAndValueMapping &mapper,
+ bool mapResults) {
SmallVector<Value, 8> operands;
SmallVector<Block *, 2> successors;
@@ -545,8 +546,10 @@ Operation *Operation::cloneWithoutRegions(BlockAndValueMapping &mapper) {
successors, getNumRegions());
// Remember the mapping of any results.
- for (unsigned i = 0, e = getNumResults(); i != e; ++i)
- mapper.map(getResult(i), newOp->getResult(i));
+ if (mapResults) {
+ for (unsigned i = 0, e = getNumResults(); i != e; ++i)
+ mapper.map(getResult(i), newOp->getResult(i));
+ }
return newOp;
}
@@ -562,12 +565,15 @@ Operation *Operation::cloneWithoutRegions() {
/// sub-operations to the corresponding operation that is copied, and adds
/// those mappings to the map.
Operation *Operation::clone(BlockAndValueMapping &mapper) {
- auto *newOp = cloneWithoutRegions(mapper);
+ // Defer mapping the results: cloning the regions below remaps operands
+ // through `mapper` after the nested ops are cloned, so an early
+ // `result -> new result` entry could be chained onto a caller-provided
+ // mapping whose target is one of this op's results, making the clone's
+ // region use the clone's own result.
+ auto *newOp = cloneWithoutRegions(mapper, /*mapResults=*/false);
// Clone the regions.
for (unsigned i = 0; i != numRegions; ++i)
getRegion(i).cloneInto(&newOp->getRegion(i), mapper);
+ // The regions are fully cloned; it is now safe to map the results.
+ for (unsigned i = 0, e = getNumResults(); i != e; ++i)
+ mapper.map(getResult(i), newOp->getResult(i));
+
return newOp;
}
diff --git a/mlir/test/IR/test-clone.mlir b/mlir/test/IR/test-clone.mlir
new file mode 100644
index 0000000000000..4028d7bf0c4b3
--- /dev/null
+++ b/mlir/test/IR/test-clone.mlir
@@ -0,0 +1,23 @@
+// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline="func.func(test-clone)" -split-input-file | FileCheck %s
+
+// Regression test: the use of %arg1 inside the region of "test.use" must be
+// remapped to the result of the original op, not to the result of the clone.
+
+module {
+  func @fixpoint(%arg1 : i32) -> i32 {
+    %r = "test.use"(%arg1) ({
+      "test.yield"(%arg1) : (i32) -> ()
+    }) : (i32) -> i32
+    return %r : i32
+  }
+}
+
+// CHECK: func @fixpoint(%[[arg0:.+]]: i32) -> i32 {
+// CHECK-NEXT: %[[i0:.+]] = "test.use"(%[[arg0]]) ({
+// CHECK-NEXT: "test.yield"(%[[arg0]]) : (i32) -> ()
+// CHECK-NEXT: }) : (i32) -> i32
+// CHECK-NEXT: %[[i1:.+]] = "test.use"(%[[i0]]) ({
+// CHECK-NEXT: "test.yield"(%[[i0]]) : (i32) -> ()
+// CHECK-NEXT: }) : (i32) -> i32
+// CHECK-NEXT: return %[[i1]] : i32
+// CHECK-NEXT: }
diff --git a/mlir/test/lib/IR/CMakeLists.txt b/mlir/test/lib/IR/CMakeLists.txt
index a195817f4fe9f..a8b54d4d2d4f5 100644
--- a/mlir/test/lib/IR/CMakeLists.txt
+++ b/mlir/test/lib/IR/CMakeLists.txt
@@ -1,6 +1,7 @@
# Exclude tests from libMLIR.so
add_mlir_library(MLIRTestIR
TestBuiltinAttributeInterfaces.cpp
+ TestClone.cpp
TestDiagnostics.cpp
TestDominance.cpp
TestFunc.cpp
diff --git a/mlir/test/lib/IR/TestClone.cpp b/mlir/test/lib/IR/TestClone.cpp
new file mode 100644
index 0000000000000..76166552ffab2
--- /dev/null
+++ b/mlir/test/lib/IR/TestClone.cpp
@@ -0,0 +1,76 @@
+//===- TestClone.cpp - Pass to test operation cloning ---------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "TestDialect.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+namespace {
+
+/// This is a test pass which clones the body of a function. Specifically,
+/// a function computing f(x) is rewritten to compute f(f(x)): every operation
+/// of the single entry block (terminator included) is cloned at the end of the
+/// block, with each block argument remapped to the value the original body
+/// returns in its place, and the original terminator is then erased.
+struct ClonePass
+    : public PassWrapper<ClonePass, InterfacePass<FunctionOpInterface>> {
+  StringRef getArgument() const final { return "test-clone"; }
+  StringRef getDescription() const final { return "Test clone of op"; }
+  void runOnOperation() override {
+    FunctionOpInterface op = getOperation();
+
+    // Limit testing to ops with only one region.
+    if (op->getNumRegions() != 1)
+      return;
+
+    Region &region = op->getRegion(0);
+    if (!region.hasOneBlock())
+      return;
+
+    Block &regionEntry = region.front();
+    Operation *terminator = regionEntry.getTerminator();
+
+    // Only handle functions whose returned values match the block arguments
+    // one-to-one (count here, types below), so that the returned values can
+    // stand in for the arguments of the cloned body.
+    if (terminator->getNumOperands() != regionEntry.getNumArguments())
+      return;
+
+    // Map each block argument to the value returned in its place; cloned
+    // operations that used the argument will use that value instead.
+    BlockAndValueMapping map;
+    for (auto tup :
+         llvm::zip(terminator->getOperands(), regionEntry.getArguments())) {
+      if (std::get<0>(tup).getType() != std::get<1>(tup).getType())
+        return;
+      map.map(std::get<1>(tup), std::get<0>(tup));
+    }
+
+    // Snapshot the operations first: clones are appended to this same block,
+    // so iterating the block while cloning would also visit the new ops.
+    SmallVector<Operation *> toClone;
+    for (Operation &inst : regionEntry)
+      toClone.push_back(&inst);
+
+    OpBuilder builder(op->getContext());
+    builder.setInsertionPointToEnd(&regionEntry);
+    for (Operation *inst : toClone)
+      builder.clone(*inst, map);
+
+    // The cloned terminator now ends the block; drop the original one.
+    terminator->erase();
+  }
+};
+} // namespace
+
+namespace mlir {
+/// Registers the `test-clone` pass implemented by ClonePass.
+void registerCloneTestPasses() { PassRegistration<ClonePass>(); }
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 5e946de7a42e3..6a82811ee85ab 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -30,6 +30,7 @@ using namespace mlir;
// Defined in the test directory, no public header.
namespace mlir {
void registerConvertToTargetEnvPass();
+void registerCloneTestPasses();
void registerPassManagerTestPass();
void registerPrintSpirvAvailabilityPass();
void registerShapeFunctionTestPasses();
@@ -119,6 +120,7 @@ void registerTestTransformDialectExtension(DialectRegistry &);
#ifdef MLIR_INCLUDE_TESTS
void registerTestPasses() {
+ registerCloneTestPasses();
registerConvertToTargetEnvPass();
registerPassManagerTestPass();
registerPrintSpirvAvailabilityPass();
More information about the Mlir-commits
mailing list