[Mlir-commits] [mlir] 6953cf6 - [mlir][Linalg] Add a hoistRedundantVectorTransfers helper function
Nicolas Vasilache
llvmlistbot at llvm.org
Fri Jun 5 03:54:46 PDT 2020
Author: Nicolas Vasilache
Date: 2020-06-05T06:50:24-04:00
New Revision: 6953cf65024395508c464dc78c90b158b3241a3a
URL: https://github.com/llvm/llvm-project/commit/6953cf65024395508c464dc78c90b158b3241a3a
DIFF: https://github.com/llvm/llvm-project/commit/6953cf65024395508c464dc78c90b158b3241a3a.diff
LOG: [mlir][Linalg] Add a hoistRedundantVectorTransfers helper function
This revision adds a helper function to hoist vector.transfer_read /
vector.transfer_write pairs out of immediately enclosing scf::ForOp
iteratively, if the following conditions are true:
1. The 2 ops access the same memref with the same indices.
2. All operands are invariant under the enclosing scf::ForOp.
3. No uses of the memref either dominate the transfer_read or are
dominated by the transfer_write (i.e. no aliasing between the write and
the read across the loop)
To improve hoisting opportunities, call the `moveLoopInvariantCode` helper
function on the candidate loop above which to hoist. Hoisting the transfers
results in scf::ForOp yielding the value that originally transited through
memory.
This revision additionally exposes `moveLoopInvariantCode` as a helper in
LoopUtils.h and updates SliceAnalysis so that forward slices also follow the
users of scf::ForOp results, which allows hoisting across multiple scf::ForOps.
Differential Revision: https://reviews.llvm.org/D81199
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h
mlir/include/mlir/Transforms/LoopUtils.h
mlir/lib/Analysis/SliceAnalysis.cpp
mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
mlir/lib/Transforms/LoopInvariantCodeMotion.cpp
mlir/test/Dialect/Linalg/hoisting.mlir
mlir/test/lib/Transforms/TestLinalgHoisting.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h
index 283e68e8a76b..32693555ff40 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h
@@ -16,11 +16,25 @@ namespace linalg {
/// Hoist alloc/dealloc pairs and alloca op out of immediately enclosing
/// scf::ForOp if both conditions are true:
-/// 1. all operands are defined outside the loop.
-/// 2. all uses are ViewLikeOp or DeallocOp.
+/// 1. All operands are defined outside the loop.
+/// 2. All uses are ViewLikeOp or DeallocOp.
// TODO: generalize on a per-need basis.
void hoistViewAllocOps(FuncOp func);
+/// Hoist vector.transfer_read/vector.transfer_write pairs out of immediately
+/// enclosing scf::ForOp iteratively, if the following conditions are true:
+/// 1. The two ops access the same memref with the same indices.
+/// 2. All operands are invariant under the enclosing scf::ForOp.
+/// 3. No uses of the memref either dominate the transfer_read or are
+/// dominated by the transfer_write (i.e. no aliasing between the write and
+/// the read across the loop)
+/// To improve hoisting opportunities, call the `moveLoopInvariantCode` helper
+/// function on the candidate loop above which to hoist. Hoisting the transfers
+/// results in scf::ForOp yielding the value that originally transited through
+/// memory.
+// TODO: generalize on a per-need basis.
+void hoistRedundantVectorTransfers(FuncOp func);
+
} // namespace linalg
} // namespace mlir
diff --git a/mlir/include/mlir/Transforms/LoopUtils.h b/mlir/include/mlir/Transforms/LoopUtils.h
index b3f23aaa4397..5a0d46f5ba57 100644
--- a/mlir/include/mlir/Transforms/LoopUtils.h
+++ b/mlir/include/mlir/Transforms/LoopUtils.h
@@ -22,10 +22,11 @@
namespace mlir {
class AffineForOp;
class FuncOp;
+class LoopLikeOpInterface;
+struct MemRefRegion;
class OpBuilder;
class Value;
class ValueRange;
-struct MemRefRegion;
namespace scf {
class ForOp;
@@ -294,6 +295,9 @@ LogicalResult
separateFullTiles(MutableArrayRef<AffineForOp> nest,
SmallVectorImpl<AffineForOp> *fullTileNest = nullptr);
+/// Move loop invariant code out of `looplike`.
+LogicalResult moveLoopInvariantCode(LoopLikeOpInterface looplike);
+
} // end namespace mlir
#endif // MLIR_TRANSFORMS_LOOP_UTILS_H
diff --git a/mlir/lib/Analysis/SliceAnalysis.cpp b/mlir/lib/Analysis/SliceAnalysis.cpp
index 5b630be5e9a5..e0c828fb55c1 100644
--- a/mlir/lib/Analysis/SliceAnalysis.cpp
+++ b/mlir/lib/Analysis/SliceAnalysis.cpp
@@ -41,20 +41,24 @@ static void getForwardSliceImpl(Operation *op,
}
if (auto forOp = dyn_cast<AffineForOp>(op)) {
- for (auto *ownerInst : forOp.getInductionVar().getUsers())
- if (forwardSlice->count(ownerInst) == 0)
- getForwardSliceImpl(ownerInst, forwardSlice, filter);
+ for (auto *ownerOp : forOp.getInductionVar().getUsers())
+ if (forwardSlice->count(ownerOp) == 0)
+ getForwardSliceImpl(ownerOp, forwardSlice, filter);
} else if (auto forOp = dyn_cast<scf::ForOp>(op)) {
- for (auto *ownerInst : forOp.getInductionVar().getUsers())
- if (forwardSlice->count(ownerInst) == 0)
- getForwardSliceImpl(ownerInst, forwardSlice, filter);
+ for (auto *ownerOp : forOp.getInductionVar().getUsers())
+ if (forwardSlice->count(ownerOp) == 0)
+ getForwardSliceImpl(ownerOp, forwardSlice, filter);
+ for (auto result : forOp.getResults())
+ for (auto *ownerOp : result.getUsers())
+ if (forwardSlice->count(ownerOp) == 0)
+ getForwardSliceImpl(ownerOp, forwardSlice, filter);
} else {
assert(op->getNumRegions() == 0 && "unexpected generic op with regions");
assert(op->getNumResults() <= 1 && "unexpected multiple results");
if (op->getNumResults() > 0) {
- for (auto *ownerInst : op->getResult(0).getUsers())
- if (forwardSlice->count(ownerInst) == 0)
- getForwardSliceImpl(ownerInst, forwardSlice, filter);
+ for (auto *ownerOp : op->getResult(0).getUsers())
+ if (forwardSlice->count(ownerOp) == 0)
+ getForwardSliceImpl(ownerOp, forwardSlice, filter);
}
}
@@ -139,15 +143,15 @@ SetVector<Operation *> mlir::getSlice(Operation *op,
SetVector<Operation *> backwardSlice;
SetVector<Operation *> forwardSlice;
while (currentIndex != slice.size()) {
- auto *currentInst = (slice)[currentIndex];
- // Compute and insert the backwardSlice starting from currentInst.
+ auto *currentOp = (slice)[currentIndex];
+ // Compute and insert the backwardSlice starting from currentOp.
backwardSlice.clear();
- getBackwardSlice(currentInst, &backwardSlice, backwardFilter);
+ getBackwardSlice(currentOp, &backwardSlice, backwardFilter);
slice.insert(backwardSlice.begin(), backwardSlice.end());
- // Compute and insert the forwardSlice starting from currentInst.
+ // Compute and insert the forwardSlice starting from currentOp.
forwardSlice.clear();
- getForwardSlice(currentInst, &forwardSlice, forwardFilter);
+ getForwardSlice(currentOp, &forwardSlice, forwardFilter);
slice.insert(forwardSlice.begin(), forwardSlice.end());
++currentIndex;
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
index 737e69021bc8..eeabdeb815dd 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
@@ -12,10 +12,15 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
+#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/SCF/Utils.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/IR/Dominance.h"
#include "mlir/IR/Function.h"
+#include "mlir/Transforms/LoopUtils.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Debug.h"
@@ -75,3 +80,96 @@ void mlir::linalg::hoistViewAllocOps(FuncOp func) {
});
}
}
+
+void mlir::linalg::hoistRedundantVectorTransfers(FuncOp func) {
+ bool changed = true;
+ while (changed) {
+ changed = false;
+
+ func.walk([&](vector::TransferReadOp transferRead) {
+ LLVM_DEBUG(DBGS() << "Candidate for hoisting: "
+ << *transferRead.getOperation() << "\n");
+ auto loop = dyn_cast<scf::ForOp>(transferRead.getParentOp());
+ LLVM_DEBUG(DBGS() << "Parent op: " << *transferRead.getParentOp()
+ << "\n");
+ if (!loop)
+ return WalkResult::advance();
+
+ if (failed(moveLoopInvariantCode(
+ cast<LoopLikeOpInterface>(loop.getOperation()))))
+ llvm_unreachable(
+ "Unexpected failure to move invariant code out of loop");
+
+ LLVM_DEBUG(DBGS() << "Candidate read: " << *transferRead.getOperation()
+ << "\n");
+
+ llvm::SetVector<Operation *> forwardSlice;
+ getForwardSlice(transferRead, &forwardSlice);
+
+ // Scan the forwardSlice of `transferRead` in reverse for TransferWriteOps on
+ // the same memref; there is no break, so the last match visited is the one kept.
+ vector::TransferWriteOp transferWrite;
+ for (auto *sliceOp : llvm::reverse(forwardSlice)) {
+ auto candidateWrite = dyn_cast<vector::TransferWriteOp>(sliceOp);
+ if (!candidateWrite || candidateWrite.memref() != transferRead.memref())
+ continue;
+ transferWrite = candidateWrite;
+ }
+
+ // All operands of the TransferRead must be defined outside of the loop.
+ for (auto operand : transferRead.getOperands())
+ if (!loop.isDefinedOutsideOfLoop(operand))
+ return WalkResult::advance();
+
+ // Only hoist transfer_read / transfer_write pairs for now.
+ if (!transferWrite)
+ return WalkResult::advance();
+
+ LLVM_DEBUG(DBGS() << "Candidate: " << *transferWrite.getOperation()
+ << "\n");
+
+ // Approximate aliasing by checking that:
+ // 1. indices are the same,
+ // 2. no other use either dominates the transfer_read or is dominated
+ // by the transfer_write (i.e. no aliasing between the write and the read
+ // across the loop).
+ if (transferRead.indices() != transferWrite.indices())
+ return WalkResult::advance();
+
+ // TODO: may want to memoize this information for performance but it
+ // likely gets invalidated often.
+ DominanceInfo dom(loop);
+ if (!dom.properlyDominates(transferRead.getOperation(), transferWrite))
+ return WalkResult::advance();
+ for (auto &use : transferRead.memref().getUses())
+ if (dom.properlyDominates(use.getOwner(),
+ transferRead.getOperation()) ||
+ dom.properlyDominates(transferWrite, use.getOwner()))
+ return WalkResult::advance();
+
+ // Hoist the read before the loop.
+ if (failed(loop.moveOutOfLoop({transferRead})))
+ llvm_unreachable(
+ "Unexpected failure to move transfer read out of loop");
+
+ // Hoist the write after the loop.
+ transferWrite.getOperation()->moveAfter(loop);
+
+ // Rewrite `loop` with new yields by cloning and erase the original loop.
+ OpBuilder b(transferRead);
+ auto newForOp = cloneWithNewYields(b, loop, transferRead.vector(),
+ transferWrite.vector());
+
+ // Transfer write has been hoisted, need to update the written value to
+ // the value yielded by the newForOp.
+ transferWrite.vector().replaceAllUsesWith(
+ newForOp.getResults().take_back()[0]);
+
+ changed = true;
+ loop.erase();
+ // Need to interrupt and restart because erasing the loop messes up the
+ // walk.
+ return WalkResult::interrupt();
+ });
+ }
+}
diff --git a/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp b/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp
index 5fee7d62ee10..0a78cb4526f6 100644
--- a/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp
+++ b/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp
@@ -17,6 +17,7 @@
#include "mlir/IR/Function.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Transforms/LoopUtils.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
@@ -73,7 +74,7 @@ static bool canBeHoisted(Operation *op,
return true;
}
-static LogicalResult moveLoopInvariantCode(LoopLikeOpInterface looplike) {
+LogicalResult mlir::moveLoopInvariantCode(LoopLikeOpInterface looplike) {
auto &loopBody = looplike.getLoopBody();
// We use two collections here as we need to preserve the order for insertion
diff --git a/mlir/test/Dialect/Linalg/hoisting.mlir b/mlir/test/Dialect/Linalg/hoisting.mlir
index db686655d64d..54ab1dcf07c7 100644
--- a/mlir/test/Dialect/Linalg/hoisting.mlir
+++ b/mlir/test/Dialect/Linalg/hoisting.mlir
@@ -1,12 +1,13 @@
-// RUN: mlir-opt %s -test-linalg-hoisting=test-hoist-view-allocs | FileCheck %s
+// RUN: mlir-opt %s -test-linalg-hoisting=test-hoist-view-allocs -allow-unregistered-dialect | FileCheck %s
+// RUN: mlir-opt %s -test-linalg-hoisting=test-hoist-redundant-transfers -allow-unregistered-dialect | FileCheck %s --check-prefix=VECTOR_TRANSFERS
-// CHECK-LABEL: func @hoist(
+// CHECK-LABEL: func @hoist_allocs(
// CHECK-SAME: %[[VAL:[a-zA-Z0-9]*]]: index,
// CHECK-SAME: %[[LB:[a-zA-Z0-9]*]]: index,
// CHECK-SAME: %[[UB:[a-zA-Z0-9]*]]: index,
// CHECK-SAME: %[[STEP:[a-zA-Z0-9]*]]: index,
// CHECK-SAME: %[[CMP:[a-zA-Z0-9]*]]: i1
-func @hoist(%val: index, %lb : index, %ub : index, %step: index, %cmp: i1) {
+func @hoist_allocs(%val: index, %lb : index, %ub : index, %step: index, %cmp: i1) {
// CHECK-DAG: alloca(%[[VAL]]) : memref<?xi8>
// CHECK-DAG: %[[A0:.*]] = alloc(%[[VAL]]) : memref<?xi8>
// CHECK: scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
@@ -80,3 +81,69 @@ func @hoist(%val: index, %lb : index, %ub : index, %step: index, %cmp: i1) {
// CHECK: dealloc %[[A0]] : memref<?xi8>
return
}
+
+// VECTOR_TRANSFERS-LABEL: func @hoist_vector_transfer_pairs(
+// VECTOR_TRANSFERS-SAME: %[[MEMREF0:[a-zA-Z0-9]*]]: memref<?x?xf32>,
+// VECTOR_TRANSFERS-SAME: %[[MEMREF1:[a-zA-Z0-9]*]]: memref<?x?xf32>,
+// VECTOR_TRANSFERS-SAME: %[[MEMREF2:[a-zA-Z0-9]*]]: memref<?x?xf32>,
+// VECTOR_TRANSFERS-SAME: %[[MEMREF3:[a-zA-Z0-9]*]]: memref<?x?xf32>,
+// VECTOR_TRANSFERS-SAME: %[[MEMREF4:[a-zA-Z0-9]*]]: memref<?x?xf32>,
+// VECTOR_TRANSFERS-SAME: %[[MEMREF5:[a-zA-Z0-9]*]]: memref<?x?xf32>,
+// VECTOR_TRANSFERS-SAME: %[[VAL:[a-zA-Z0-9]*]]: index,
+// VECTOR_TRANSFERS-SAME: %[[LB:[a-zA-Z0-9]*]]: index,
+// VECTOR_TRANSFERS-SAME: %[[UB:[a-zA-Z0-9]*]]: index,
+// VECTOR_TRANSFERS-SAME: %[[STEP:[a-zA-Z0-9]*]]: index,
+// VECTOR_TRANSFERS-SAME: %[[CMP:[a-zA-Z0-9]*]]: i1
+func @hoist_vector_transfer_pairs(
+ %memref0: memref<?x?xf32>, %memref1: memref<?x?xf32>, %memref2: memref<?x?xf32>,
+ %memref3: memref<?x?xf32>, %memref4: memref<?x?xf32>, %memref5: memref<?x?xf32>,
+ %val: index, %lb : index, %ub : index, %step: index, %cmp: i1) {
+ %c0 = constant 0 : index
+ %cst = constant 0.0 : f32
+
+// VECTOR_TRANSFERS: vector.transfer_read %{{.*}} : memref<?x?xf32>, vector<1xf32>
+// VECTOR_TRANSFERS: scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args({{.*}}) -> (vector<1xf32>) {
+// VECTOR_TRANSFERS: vector.transfer_read %{{.*}} : memref<?x?xf32>, vector<2xf32>
+// VECTOR_TRANSFERS: scf.for %[[J:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args({{.*}}) -> (vector<1xf32>, vector<2xf32>) {
+// VECTOR_TRANSFERS: vector.transfer_read %{{.*}} : memref<?x?xf32>, vector<3xf32>
+// VECTOR_TRANSFERS: vector.transfer_read %{{.*}} : memref<?x?xf32>, vector<4xf32>
+// VECTOR_TRANSFERS: "some_crippling_use"(%[[MEMREF4]]) : (memref<?x?xf32>) -> ()
+// VECTOR_TRANSFERS: vector.transfer_read %{{.*}} : memref<?x?xf32>, vector<5xf32>
+// VECTOR_TRANSFERS: "some_use"(%{{.*}}) : (vector<1xf32>) -> vector<1xf32>
+// VECTOR_TRANSFERS: "some_use"(%{{.*}}) : (vector<2xf32>) -> vector<2xf32>
+// VECTOR_TRANSFERS: "some_use"(%[[MEMREF2]]) : (memref<?x?xf32>) -> vector<3xf32>
+// VECTOR_TRANSFERS: "some_use"(%{{.*}}) : (vector<4xf32>) -> vector<4xf32>
+// VECTOR_TRANSFERS: "some_use"(%{{.*}}) : (vector<5xf32>) -> vector<5xf32>
+// VECTOR_TRANSFERS: vector.transfer_write %{{.*}} : vector<3xf32>, memref<?x?xf32>
+// VECTOR_TRANSFERS: vector.transfer_write %{{.*}} : vector<4xf32>, memref<?x?xf32>
+// VECTOR_TRANSFERS: vector.transfer_write %{{.*}} : vector<5xf32>, memref<?x?xf32>
+// VECTOR_TRANSFERS: "some_crippling_use"(%[[MEMREF3]]) : (memref<?x?xf32>) -> ()
+// VECTOR_TRANSFERS: scf.yield {{.*}} : vector<1xf32>, vector<2xf32>
+// VECTOR_TRANSFERS: }
+// VECTOR_TRANSFERS: vector.transfer_write %{{.*}} : vector<2xf32>, memref<?x?xf32>
+// VECTOR_TRANSFERS: scf.yield {{.*}} : vector<1xf32>
+// VECTOR_TRANSFERS: }
+// VECTOR_TRANSFERS: vector.transfer_write %{{.*}} : vector<1xf32>, memref<?x?xf32>
+ scf.for %i = %lb to %ub step %step {
+ scf.for %j = %lb to %ub step %step {
+ %r0 = vector.transfer_read %memref1[%c0, %c0], %cst: memref<?x?xf32>, vector<1xf32>
+ %r1 = vector.transfer_read %memref0[%i, %i], %cst: memref<?x?xf32>, vector<2xf32>
+ %r2 = vector.transfer_read %memref2[%c0, %c0], %cst: memref<?x?xf32>, vector<3xf32>
+ %r3 = vector.transfer_read %memref3[%c0, %c0], %cst: memref<?x?xf32>, vector<4xf32>
+ "some_crippling_use"(%memref4) : (memref<?x?xf32>) -> ()
+ %r4 = vector.transfer_read %memref4[%c0, %c0], %cst: memref<?x?xf32>, vector<5xf32>
+ %u0 = "some_use"(%r0) : (vector<1xf32>) -> vector<1xf32>
+ %u1 = "some_use"(%r1) : (vector<2xf32>) -> vector<2xf32>
+ %u2 = "some_use"(%memref2) : (memref<?x?xf32>) -> vector<3xf32>
+ %u3 = "some_use"(%r3) : (vector<4xf32>) -> vector<4xf32>
+ %u4 = "some_use"(%r4) : (vector<5xf32>) -> vector<5xf32>
+ vector.transfer_write %u0, %memref1[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
+ vector.transfer_write %u1, %memref0[%i, %i] : vector<2xf32>, memref<?x?xf32>
+ vector.transfer_write %u2, %memref2[%c0, %c0] : vector<3xf32>, memref<?x?xf32>
+ vector.transfer_write %u3, %memref3[%c0, %c0] : vector<4xf32>, memref<?x?xf32>
+ vector.transfer_write %u4, %memref4[%c0, %c0] : vector<5xf32>, memref<?x?xf32>
+ "some_crippling_use"(%memref3) : (memref<?x?xf32>) -> ()
+ }
+ }
+ return
+}
diff --git a/mlir/test/lib/Transforms/TestLinalgHoisting.cpp b/mlir/test/lib/Transforms/TestLinalgHoisting.cpp
index 2fb5c091aed4..d1e478fec3bc 100644
--- a/mlir/test/lib/Transforms/TestLinalgHoisting.cpp
+++ b/mlir/test/lib/Transforms/TestLinalgHoisting.cpp
@@ -29,6 +29,10 @@ struct TestLinalgHoisting
*this, "test-hoist-view-allocs",
llvm::cl::desc("Test hoisting alloc used by view"),
llvm::cl::init(false)};
+ Option<bool> testHoistRedundantTransfers{
+ *this, "test-hoist-redundant-transfers",
+ llvm::cl::desc("Test hoisting transfer_read/transfer_write pairs"),
+ llvm::cl::init(false)};
};
} // end anonymous namespace
@@ -37,6 +41,10 @@ void TestLinalgHoisting::runOnFunction() {
hoistViewAllocOps(getFunction());
return;
}
+ if (testHoistRedundantTransfers) {
+ hoistRedundantVectorTransfers(getFunction());
+ return;
+ }
}
namespace mlir {
More information about the Mlir-commits
mailing list