[Mlir-commits] [mlir] da291ba - [mlir] Add hoisting of transfer ops in affine loops
Javier Setoain
llvmlistbot at llvm.org
Wed Dec 7 12:10:21 PST 2022
Author: Javier Setoain
Date: 2022-12-07T20:08:07Z
New Revision: da291bab81200c93dffaa809a894168b7dedffd8
URL: https://github.com/llvm/llvm-project/commit/da291bab81200c93dffaa809a894168b7dedffd8
DIFF: https://github.com/llvm/llvm-project/commit/da291bab81200c93dffaa809a894168b7dedffd8.diff
LOG: [mlir] Add hoisting of transfer ops in affine loops
The only way to do this with the current hoisting strategy is to first lower Affine to SCF, but that prevents running further Affine passes afterwards.
Differential Revision: https://reviews.llvm.org/D137600
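For illustration, here is a minimal sketch of the kind of transformation this enables. The function name, shapes, and constants are made up for the example and are not taken from the patch: a vector.transfer_read/vector.transfer_write pair whose indices are invariant in an affine.for loop is moved out of that loop, with the vector value threaded through iter_args and affine.yield. The new test added below exercises the same pattern on a matmul-like kernel.

// Before hoisting: the transfer pair on %mem[%c0] does not depend on %i.
func.func @loop_invariant_pair(%mem: memref<16xf32>, %v: vector<4xf32>) {
  %c0 = arith.constant 0 : index
  %f0 = arith.constant 0.0 : f32
  affine.for %i = 0 to 8 {
    %r = vector.transfer_read %mem[%c0], %f0 : memref<16xf32>, vector<4xf32>
    %s = arith.addf %r, %v : vector<4xf32>
    vector.transfer_write %s, %mem[%c0] : vector<4xf32>, memref<16xf32>
  }
  return
}

// After hoisting: the read moves above the loop, the write below it, and the
// running value travels through the loop's iter_args.
func.func @loop_invariant_pair(%mem: memref<16xf32>, %v: vector<4xf32>) {
  %c0 = arith.constant 0 : index
  %f0 = arith.constant 0.0 : f32
  %init = vector.transfer_read %mem[%c0], %f0 : memref<16xf32>, vector<4xf32>
  %res = affine.for %i = 0 to 8 iter_args(%acc = %init) -> (vector<4xf32>) {
    %s = arith.addf %acc, %v : vector<4xf32>
    affine.yield %s : vector<4xf32>
  }
  vector.transfer_write %res, %mem[%c0] : vector<4xf32>, memref<16xf32>
  return
}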
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
mlir/test/Dialect/Linalg/hoisting.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
index 79d3e5524dc4..f51b4ffe9999 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
@@ -14,7 +14,9 @@
#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
+#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
@@ -29,6 +31,7 @@
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/LoopInvariantCodeMotionUtils.h"
#include "llvm/ADT/StringRef.h"
+#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
using llvm::dbgs;
@@ -425,10 +428,10 @@ void mlir::linalg::hoistRedundantVectorTransfers(func::FuncOp func) {
LLVM_DEBUG(DBGS() << "Candidate for hoisting: "
<< *transferRead.getOperation() << "\n");
- auto loop = dyn_cast<scf::ForOp>(transferRead->getParentOp());
+ auto loop = dyn_cast<LoopLikeOpInterface>(transferRead->getParentOp());
LLVM_DEBUG(DBGS() << "Parent op: " << *transferRead->getParentOp()
<< "\n");
- if (!loop)
+ if (!isa_and_nonnull<scf::ForOp, AffineForOp>(loop))
return WalkResult::advance();
LLVM_DEBUG(DBGS() << "Candidate read: " << *transferRead.getOperation()
@@ -513,18 +516,43 @@ void mlir::linalg::hoistRedundantVectorTransfers(func::FuncOp func) {
ArrayRef<BlockArgument> newBBArgs) {
return SmallVector<Value>{transferWrite.getVector()};
};
- auto newForOp =
- replaceLoopWithNewYields(b, loop, transferRead.getVector(), yieldFn);
// Transfer write has been hoisted, need to update the written vector by
// the value yielded by the newForOp.
- transferWrite.getVectorMutable().assign(newForOp.getResults().back());
-
- changed = true;
- loop.erase();
- // Need to interrupt and restart because erasing the loop messes up the
- // walk.
- return WalkResult::interrupt();
+ return TypeSwitch<Operation *, WalkResult>(loop)
+ .Case<scf::ForOp>([&](scf::ForOp scfForOp) {
+ auto newForOp = replaceLoopWithNewYields(
+ b, scfForOp, transferRead.getVector(), yieldFn);
+ transferWrite.getVectorMutable().assign(
+ newForOp.getResults().back());
+ changed = true;
+ loop.erase();
+ // Need to interrupt and restart because erasing the loop messes up
+ // the walk.
+ return WalkResult::interrupt();
+ })
+ .Case<AffineForOp>([&](AffineForOp affineForOp) {
+ auto newForOp = replaceForOpWithNewYields(
+ b, affineForOp, transferRead.getVector(),
+ SmallVector<Value>{transferWrite.getVector()},
+ transferWrite.getVector());
+ // Replace all uses of the `transferRead` with the corresponding
+ // basic block argument.
+ transferRead.getVector().replaceUsesWithIf(
+ newForOp.getLoopBody().getArguments().back(),
+ [&](OpOperand &use) {
+ Operation *user = use.getOwner();
+ return newForOp->isProperAncestor(user);
+ });
+ transferWrite.getVectorMutable().assign(
+ newForOp.getResults().back());
+ changed = true;
+ loop.erase();
+ // Need to interrupt and restart because erasing the loop messes up
+ // the walk.
+ return WalkResult::interrupt();
+ })
+ .Default([](Operation *) { return WalkResult::interrupt(); });
});
}
}
diff --git a/mlir/test/Dialect/Linalg/hoisting.mlir b/mlir/test/Dialect/Linalg/hoisting.mlir
index 2b783d144b7b..eac8ddcedf55 100644
--- a/mlir/test/Dialect/Linalg/hoisting.mlir
+++ b/mlir/test/Dialect/Linalg/hoisting.mlir
@@ -469,3 +469,39 @@ func.func @hoist_vector_transfer_write_pairs_disjoint_tensor(
return %1 : tensor<?x?xf32>
}
+// -----
+
+// CHECK-LABEL: func @hoist_vector_transfer_pairs_in_affine_loops(
+// CHECK-SAME: %[[MEMREF0:[a-zA-Z0-9]+]]: memref<64x64xi32>,
+// CHECK-SAME: %[[MEMREF1:[a-zA-Z0-9]+]]: memref<64x64xi32>,
+// CHECK-SAME: %[[MEMREF2:[a-zA-Z0-9]+]]: memref<64x64xi32>) {
+// CHECK: %[[C0:.*]] = arith.constant 0 : i32
+// CHECK: affine.for %[[I:.*]] = 0 to 64 {
+// CHECK: affine.for %[[J:.*]] = 0 to 64 step 16 {
+// CHECK: %[[R0:.*]] = vector.transfer_read %[[MEMREF2]][%[[I]], %[[J]]], %[[C0]] : memref<64x64xi32>, vector<16xi32>
+// CHECK: %[[R:.*]] = affine.for %[[K:.*]] = 0 to 64 iter_args(%[[ACC:.*]] = %[[R0]]) -> (vector<16xi32>) {
+// CHECK: %[[AV:.*]] = vector.transfer_read %[[MEMREF0]][%[[I]], %[[K]]], %[[C0]] {{.*}}: memref<64x64xi32>, vector<16xi32>
+// CHECK: %[[BV:.*]] = vector.transfer_read %[[MEMREF1]][%[[K]], %[[J]]], %[[C0]] {{.*}}: memref<64x64xi32>, vector<16xi32>
+// CHECK: %[[T0:.*]] = arith.muli %[[AV]], %[[BV]] : vector<16xi32>
+// CHECK: %[[T1:.*]] = arith.addi %[[ACC]], %[[T0]] : vector<16xi32>
+// CHECK: affine.yield %[[T1]] : vector<16xi32>
+// CHECK: }
+// CHECK: vector.transfer_write %[[R]], %[[MEMREF2]][%[[I]], %[[J]]] : vector<16xi32>, memref<64x64xi32>
+// CHECK: }
+// CHECK: }
+func.func @hoist_vector_transfer_pairs_in_affine_loops(%memref0: memref<64x64xi32>, %memref1: memref<64x64xi32>, %memref2: memref<64x64xi32>) {
+ %c0_i32 = arith.constant 0 : i32
+ affine.for %arg3 = 0 to 64 {
+ affine.for %arg4 = 0 to 64 step 16 {
+ affine.for %arg5 = 0 to 64 {
+ %0 = vector.transfer_read %memref0[%arg3, %arg5], %c0_i32 {permutation_map = affine_map<(d0, d1) -> (0)>} : memref<64x64xi32>, vector<16xi32>
+ %1 = vector.transfer_read %memref1[%arg5, %arg4], %c0_i32 : memref<64x64xi32>, vector<16xi32>
+ %2 = vector.transfer_read %memref2[%arg3, %arg4], %c0_i32 : memref<64x64xi32>, vector<16xi32>
+ %3 = arith.muli %0, %1 : vector<16xi32>
+ %4 = arith.addi %2, %3 : vector<16xi32>
+ vector.transfer_write %4, %memref2[%arg3, %arg4] : vector<16xi32>, memref<64x64xi32>
+ }
+ }
+ }
+ return
+}