[Mlir-commits] [mlir] [mlir][linalg][NFC] Remove linalg subset hoisting (PR #70636)
Matthias Springer
llvmlistbot at llvm.org
Sat Nov 4 19:47:38 PDT 2023
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/70636
From ba3c30c92fea0bea154ab164c9f06fa4c7c462c0 Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Sun, 5 Nov 2023 11:43:39 +0900
Subject: [PATCH] [mlir][linalg] Remove subset hoisting on tensors
---
.../mlir/Dialect/Linalg/Transforms/Hoisting.h | 103 ----
.../Dialect/Linalg/Transforms/CMakeLists.txt | 1 -
.../Linalg/Transforms/HoistPadding.cpp | 5 +-
.../Dialect/Linalg/Transforms/Hoisting.cpp | 9 -
.../Linalg/Transforms/SubsetHoisting.cpp | 553 ------------------
5 files changed, 3 insertions(+), 668 deletions(-)
delete mode 100644 mlir/lib/Dialect/Linalg/Transforms/SubsetHoisting.cpp
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h
index d4444c3f869e5cc..921c3c3e8c7db69 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h
@@ -45,109 +45,6 @@ namespace linalg {
/// when used on distributed loops with memref semantics!
void hoistRedundantVectorTransfers(func::FuncOp func);
-/// Greedily hoist redundant subset extract/insert operations on tensors outside
-/// of `forOp`. The logic follows:
-/// 1. Look for a write walking back from the `forOp` yield.
-/// 2. Check the uses of the matching block argument and look for a matching
-/// read (i.e. extract_slice of transfer_read) with matching indices.
-/// 3. In the case of a transfer_write, we can bypass other non-conflicting
-/// operations and find more hoisting opportunities.
-/// 4. Hoist the read/write pair and update the tensor SSA links.
-///
-/// Return the unmodified `forOp` if no hoisting occured.
-/// Return a new scf::ForOp if hoisting on tensors occured.
-///
-/// After this transformation the returned scf::ForOp may have unused arguments
-/// that can be removed by application of canonicalization patterns.
-///
-/// Example:
-/// ========
-/// IR Resembling:
-///
-/// ```
-/// %0 = scf.for %i = %l to %u step %s iter_args(%a0 = %t0)->(tensor<10xf32>) {
-/// %1 = scf.for %j = %l to %u step %s iter_args(%a6 = %a0)->(tensor<10xf32>) {
-/// %e = tensor.extract_slice %a6[%i][%sz][1]: tensor<10xf32> to tensor<?xf32>
-/// %r = vector.transfer_read %e[%c0], %cst: tensor<?xf32>, vector<4xf32>
-/// %u = "some_use"(%r) : (vector<4xf32>) -> vector<4xf32>
-/// %w = vector.transfer_write %u, %e[%c0] : vector<4xf32>, tensor<?xf32>
-/// %st = tensor.insert_slice %w into %a6[%i][%sz][1]
-/// : tensor<?xf32> into tensor<10xf32>
-/// scf.yield %st: tensor<10xf32>
-/// }
-/// scf.yield %1: tensor<10xf32>
-/// }
-/// ```
-///
-/// Progressively hoists to:
-///
-/// ```
-/// %0 = scf.for %i = %l to %u step %s iter_args(%a0 = %t0) -> (tensor<10xf32>){
-/// %e = tensor.extract_slice %a0[%i][%sz][1]: tensor<10xf32> to tensor<?xf32>
-/// %1:2 = scf.for %j = %l to %u step %s iter_args(%a6 = a0, %a7 = %e)
-/// -> (tensor<10xf32>, tensor<?xf32>) {
-/// %r = vector.transfer_read %a7[%c0], %cst: tensor<?xf32>, vector<4xf32>
-/// %u = "some_use"(%r) : (vector<4xf32>) -> vector<4xf32>
-/// %w = vector.transfer_write %u, %a7[%c0] : vector<4xf32>, tensor<?xf32>
-/// scf.yield %a6, %w: tensor<10xf32>, tensor<?xf32>
-/// }
-/// %st = tensor.insert_slice %1#1 into %1#0[%i][%sz][1]
-/// : tensor<?xf32> into tensor<10xf32>
-/// scf.yield %1: tensor<10xf32>
-/// }
-/// ```
-///
-/// and
-///
-/// ```
-/// %0 = scf.for %i = %l to %u step %s iter_args(%a0 = %t0) -> (tensor<10xf32>){
-/// %e = tensor.extract_slice %a0[%i][%sz][1]: tensor<10xf32> to tensor<?xf32>
-/// %r = vector.transfer_read %a7[%c0], %cst: tensor<?xf32>, vector<4xf32>
-/// %1:3 = scf.for %j = %l to %u step %s iter_args(%a6 = a0, %a7 = %e, %a7 = r)
-/// -> (tensor<10xf32>, tensor<?xf32>, vector<4xf32>) {
-/// %u = "some_use"(%r) : (vector<4xf32>) -> vector<4xf32>
-/// scf.yield %a6, %a7, %u: tensor<10xf32>, tensor<?xf32>, vector<4xf32>
-/// }
-/// %w = vector.transfer_write %1#2, %1#1[%c0] : vector<4xf32>, tensor<?xf32>
-/// %st = tensor.insert_slice %w into %1#0[%i][%sz][1]
-/// : tensor<?xf32> into tensor<10xf32>
-/// scf.yield %1: tensor<10xf32>
-/// }
-/// ```
-///
-/// It can then canonicalize to:
-///
-/// ```
-/// %0 = scf.for %i = %l to %u step %s iter_args(%a0 = %t0) -> (tensor<10xf32>){
-/// %e = tensor.extract_slice %a0[%i][%sz][1]: tensor<10xf32> to tensor<?xf32>
-/// %r = vector.transfer_read %a7[%c0], %cst: tensor<?xf32>, vector<4xf32>
-/// %1 = scf.for %j = %l to %u step %s iter_args(%a7 = r)
-/// -> (tensor<10xf32>, tensor<?xf32>, vector<4xf32>) {
-/// %u = "some_use"(%r) : (vector<4xf32>) -> vector<4xf32>
-/// scf.yield %u: vector<4xf32>
-/// }
-/// %w = vector.transfer_write %1, %e[%c0] : vector<4xf32>, tensor<?xf32>
-/// %st = tensor.insert_slice %w into %a0[%i][%sz][1]
-/// : tensor<?xf32> into tensor<10xf32>
-/// scf.yield %1: tensor<10xf32>
-/// }
-/// ```
-///
-// TODO: This should be further generalized along a few different axes:
-// - Other loops than scf.ForOp that operate on tensors (both sequential and
-// parallel loops).
-// - Other subset extract/insert pairs than tensor.extract/insert_slice and
-// vector.transfer_read/write.
-// - More general areSubsetDisjoint analysis/interface to work across all
-// subset op types and allow bypassing non-WAW-conflicting operations in
-// more cases.
-scf::ForOp hoistRedundantSubsetExtractInsert(RewriterBase &rewriter,
- scf::ForOp forOp);
-
-/// Call into `hoistRedundantSubsetInsertExtract` without a RewriterBase.
-// TODO: obsolete and should be retired
-void hoistRedundantVectorTransfersOnTensor(func::FuncOp func);
-
} // namespace linalg
} // namespace mlir
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 7af1148bb93d5a0..2f7b556bb24604e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -27,7 +27,6 @@ add_mlir_dialect_library(MLIRLinalgTransforms
Specialize.cpp
Split.cpp
SplitReduction.cpp
- SubsetHoisting.cpp
SubsetInsertionOpInterfaceImpl.cpp
SwapExtractSliceWithFillPatterns.cpp
Tiling.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
index 866f51b0e92bbde..805c9d4ed3b79ff 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
@@ -25,6 +25,7 @@
#include "mlir/IR/Dominance.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
+#include "mlir/Transforms/LoopInvariantCodeMotionUtils.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/Support/Debug.h"
@@ -292,8 +293,8 @@ void HoistPaddingAnalysis::enableHoistPadding(RewriterBase &rewriter) {
// enclosing loop, try to apply hoisting on this outermost loop.
// TODO: we may want finer-grained hoisting of only that particular `sliceOp`.
if (!outermostEnclosingForOp.isDefinedOutsideOfLoop(sliceOp.getSource())) {
- outermostEnclosingForOp =
- hoistRedundantSubsetExtractInsert(rewriter, outermostEnclosingForOp);
+ outermostEnclosingForOp = cast<scf::ForOp>(
+ hoistLoopInvariantSubsets(rewriter, outermostEnclosingForOp));
}
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
index cbb2c507de69f9e..80ce97ee3437a5f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
@@ -43,15 +43,6 @@ using llvm::dbgs;
using namespace mlir;
using namespace mlir::linalg;
-void mlir::linalg::hoistRedundantVectorTransfersOnTensor(func::FuncOp func) {
- IRRewriter rewriter(func->getContext());
- // TODO: walking in some reverse / inside-out order would be more efficient
- // and would capture more cases.
- func.walk([&](scf::ForOp forOp) {
- hoistRedundantSubsetExtractInsert(rewriter, forOp);
- });
-}
-
static bool noAliasingUseInLoop(vector::TransferReadOp transferRead,
LoopLikeOpInterface loop) {
Value source = transferRead.getSource();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/SubsetHoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/SubsetHoisting.cpp
deleted file mode 100644
index 91e0d139ec5c2f0..000000000000000
--- a/mlir/lib/Dialect/Linalg/Transforms/SubsetHoisting.cpp
+++ /dev/null
@@ -1,553 +0,0 @@
-//===- SubsetHoisting.cpp - Linalg hoisting transformations----------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file implements functions concerned with hoisting invariant subset
-// operations in the context of Linalg transformations.
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
-#include "mlir/Dialect/SCF/IR/SCF.h"
-#include "mlir/Dialect/SCF/Utils/Utils.h"
-#include "mlir/Dialect/Tensor/IR/Tensor.h"
-#include "mlir/Dialect/Utils/StaticValueUtils.h"
-#include "mlir/Dialect/Vector/IR/VectorOps.h"
-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "mlir/Transforms/LoopInvariantCodeMotionUtils.h"
-#include "llvm/Support/Debug.h"
-#include "llvm/Support/ErrorHandling.h"
-
-#define DEBUG_TYPE "subset-hoisting"
-
-#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
-
-using namespace mlir;
-using namespace mlir::linalg;
-
-/// Return true if the location of the subset defined by the op is invariant of
-/// the loop iteration.
-static bool
-isSubsetLocationLoopInvariant(scf::ForOp forOp,
- vector::TransferWriteOp transferWriteOp) {
- for (Value operand : transferWriteOp.getIndices())
- if (!forOp.isDefinedOutsideOfLoop(operand))
- return false;
- return true;
-}
-
-/// Return true if the location of the subset defined by the op is invariant of
-/// the loop iteration.
-static bool isSubsetLocationLoopInvariant(scf::ForOp forOp,
- tensor::InsertSliceOp insertSliceOp) {
- for (Value operand : insertSliceOp->getOperands().drop_front(
- tensor::InsertSliceOp::getOffsetSizeAndStrideStartOperandIndex()))
- if (!forOp.isDefinedOutsideOfLoop(operand))
- return false;
- return true;
-}
-
-/// Given an `srcTensor` that is a block argument belong to a loop.
-/// Greedily look for the first read that can be hoisted out of the loop (i.e.
-/// that satisfied the conditions):
-/// - The read is of type `tensor.extract_slice`.
-/// - The read is one of the uses of `srcTensor`.
-/// - The read is to the same subset that `tensor.insert_slice` writes.
-// TODO: Unify implementations once the "bypassing behavior" is the same.
-static FailureOr<tensor::ExtractSliceOp>
-findHoistableMatchingExtractSlice(RewriterBase &rewriter,
- tensor::InsertSliceOp insertSliceOp,
- BlockArgument srcTensor) {
- assert(isa<RankedTensorType>(srcTensor.getType()) && "not a ranked tensor");
-
- auto forOp = cast<scf::ForOp>(srcTensor.getOwner()->getParentOp());
-
- LLVM_DEBUG(DBGS() << "--find matching read for: " << insertSliceOp << "\n";
- DBGS() << "--amongst users of: " << srcTensor << "\n");
-
- SmallVector<Operation *> users(srcTensor.getUsers());
- if (forOp.isDefinedOutsideOfLoop(insertSliceOp.getDest()))
- llvm::append_range(users, insertSliceOp.getDest().getUsers());
-
- for (Operation *user : users) {
- LLVM_DEBUG(DBGS() << "----inspect user: " << *user << "\n");
- auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
- // Skip ops other than extract_slice with an exact matching of their tensor
- // subset.
- if (extractSliceOp) {
- auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
- if (extractSliceOp.getResultType() != insertSliceOp.getSourceType() ||
- !extractSliceOp.isSameAs(insertSliceOp, isSame)) {
- LLVM_DEBUG(DBGS() << "------not a matching extract_slice\n";
- DBGS() << *user << " vs " << *insertSliceOp << "\n");
- continue;
- }
-
- // Skip insert_slice whose vector is defined within the loop: we need to
- // hoist that definition first otherwise dominance violations trigger.
- if (!isa<BlockArgument>(extractSliceOp.getSource()) &&
- !forOp.isDefinedOutsideOfLoop(extractSliceOp.getSource())) {
- LLVM_DEBUG(DBGS() << "------transfer_read vector is loop-dependent\n");
- continue;
- }
- return extractSliceOp;
- }
-
- // TODO: Look through disjoint subsets, similar to vector.transfer_write
- // and unify implementations.
- }
-
- LLVM_DEBUG(DBGS() << "----no matching extract_slice");
- return failure();
-}
-
-/// Given an `srcTensor` that is a block argument belong to a loop.
-/// Greedily look for the first read that can be hoisted out of the loop (i.e.
-/// that satisfied the conditions):
-/// - The read is of type `tensor.transfer_read`.
-/// - The read is one of the uses of `srcTensor`.
-/// - The read is to the same subset that `tensor.transfer_write` writes.
-// TODO: Unify implementations once the "bypassing behavior" is the same.
-static FailureOr<vector::TransferReadOp>
-findHoistableMatchingTransferRead(RewriterBase &rewriter,
- vector::TransferWriteOp transferWriteOp,
- BlockArgument srcTensor) {
- if (!isa<RankedTensorType>(srcTensor.getType()))
- return failure();
-
- auto forOp = cast<scf::ForOp>(srcTensor.getOwner()->getParentOp());
-
- LLVM_DEBUG(DBGS() << "--find matching read for: " << transferWriteOp << "\n";
- DBGS() << "--amongst users of: " << srcTensor << "\n";);
-
- // vector.transfer_write is a bit peculiar: we look through dependencies
- // to disjoint tensor subsets. This requires a while loop.
- // TODO: Look through disjoint subsets for tensor.insert_slice and unify
- // implementations.
- SmallVector<Operation *> users(srcTensor.getUsers());
- // TODO: transferWriteOp.getSource is actually the destination tensor!!
- if (forOp.isDefinedOutsideOfLoop(transferWriteOp.getSource()))
- llvm::append_range(users, transferWriteOp.getSource().getUsers());
- while (!users.empty()) {
- Operation *user = users.pop_back_val();
- LLVM_DEBUG(DBGS() << "----inspect user: " << *user << "\n");
- auto read = dyn_cast<vector::TransferReadOp>(user);
- if (read) {
- // Skip ops other than transfer_read with an exact matching subset.
- if (read.getIndices() != transferWriteOp.getIndices() ||
- read.getVectorType() != transferWriteOp.getVectorType()) {
- LLVM_DEBUG(DBGS() << "------not a transfer_read that matches the "
- "transfer_write: "
- << *user << "\n\t(vs " << *transferWriteOp << ")\n");
- continue;
- }
-
- // transfer_read may be of a vector that is defined within the loop: we
- // traverse it by virtue of bypassing disjoint subset operations rooted at
- // a bbArg and yielding a matching yield.
- if (!isa<BlockArgument>(read.getSource()) &&
- !forOp.isDefinedOutsideOfLoop(read.getSource())) {
- LLVM_DEBUG(DBGS() << "------transfer_read vector appears loop "
- "dependent but will be tested for disjointness as "
- "part of the bypass analysis\n");
- }
- LLVM_DEBUG(DBGS() << "------found match\n");
- return read;
- }
-
- // As an optimization, we look further through dependencies to disjoint
- // tensor subsets. This creates more opportunities to find a matching read.
- if (isa<vector::TransferWriteOp>(user)) {
- // If we find a write with disjoint indices append all its uses.
- // TODO: Generalize areSubsetsDisjoint and allow other bypass than
- // just vector.transfer_write - vector.transfer_write.
- if (vector::isDisjointTransferIndices(
- cast<VectorTransferOpInterface>(user),
- cast<VectorTransferOpInterface>(
- transferWriteOp.getOperation()))) {
- LLVM_DEBUG(DBGS() << "----follow through disjoint write\n");
- users.append(user->getUsers().begin(), user->getUsers().end());
- } else {
- LLVM_DEBUG(DBGS() << "----skip non-disjoint write\n");
- }
- }
- }
-
- LLVM_DEBUG(DBGS() << "--no matching transfer_read\n");
- return rewriter.notifyMatchFailure(transferWriteOp,
- "no matching transfer_read");
-}
-
-/// Return the `vector.transfer_write` that produces `yieldOperand`, if:
-/// - The write operates on tensors.
-/// - All indices are defined outside of the loop.
-/// Return failure otherwise.
-///
-/// This is sufficient condition to hoist the `vector.transfer_write`; other
-/// operands can always be yielded by the loop where needed.
-// TODO: generalize beyond scf::ForOp.
-// TODO: Unify implementations once the "bypassing behavior" is the same.
-static FailureOr<vector::TransferWriteOp>
-getLoopInvariantTransferWriteDefining(RewriterBase &rewriter, scf::ForOp forOp,
- BlockArgument bbArg,
- OpOperand &yieldOperand) {
- assert(bbArg.getArgNumber() ==
- forOp.getNumInductionVars() + yieldOperand.getOperandNumber() &&
- "bbArg and yieldOperand must match");
- assert(isa<scf::YieldOp>(yieldOperand.getOwner()) && "must be an scf.yield");
-
- Value v = yieldOperand.get();
- auto transferWriteOp = v.getDefiningOp<vector::TransferWriteOp>();
- if (!transferWriteOp)
- return rewriter.notifyMatchFailure(v.getLoc(), "not a transfer_write");
-
- if (transferWriteOp->getNumResults() == 0) {
- return rewriter.notifyMatchFailure(v.getLoc(),
- "unsupported transfer_write on buffers");
- }
-
- // We do not explicitly check that the destination is a BBarg that matches the
- // yield operand as this would prevent us from bypassing other non-conflicting
- // writes.
-
- // Indexing must not depend on `forOp`.
- if (!isSubsetLocationLoopInvariant(forOp, transferWriteOp))
- return rewriter.notifyMatchFailure(
- v.getLoc(), "transfer_write indexing is loop-dependent");
-
- return transferWriteOp;
-}
-
-/// Return the `tensor.insert_slice` that produces `yieldOperand`, if:
-/// 1. Its destination tensor is a block argument of the `forOp`.
-/// 2. The unique use of its result is a yield with operand number matching
-/// the block argument.
-/// 3. All indices are defined outside of the loop.
-/// Return failure otherwise.
-///
-/// This is sufficient condition to hoist the `tensor.insert_slice`; other
-/// operands can always be yielded by the loop where needed.
-/// Note: 1. + 2. ensure that the yield / iter_args cycle results in proper
-/// semantics (i.e. no ping-ping between iter_args across iterations).
-// TODO: generalize beyond scf::ForOp.
-// TODO: Unify implementations once the "bypassing behavior" is the same.
-static FailureOr<tensor::InsertSliceOp>
-getLoopInvariantInsertSliceDefining(RewriterBase &rewriter, scf::ForOp forOp,
- BlockArgument bbArg,
- OpOperand &yieldOperand) {
- assert(bbArg.getArgNumber() ==
- forOp.getNumInductionVars() + yieldOperand.getOperandNumber() &&
- "bbArg and yieldOperand must match");
- assert(isa<scf::YieldOp>(yieldOperand.getOwner()) && "must be an scf.yield");
-
- Value v = yieldOperand.get();
- auto insertSliceOp = v.getDefiningOp<tensor::InsertSliceOp>();
- if (!insertSliceOp)
- return rewriter.notifyMatchFailure(v.getLoc(), "not an insert_slice");
-
- // Tensor inserted into must be a BBArg at position matching yield operand.
- // TODO: In the future we should not perform this check if we want to bypass
- // other non-conflicting writes.
- if (bbArg != insertSliceOp.getDest())
- return rewriter.notifyMatchFailure(v.getLoc(), "not a matching bbarg");
-
- // Indexing inserted into must not depend on `forOp`.
- if (!isSubsetLocationLoopInvariant(forOp, insertSliceOp))
- return rewriter.notifyMatchFailure(
- v.getLoc(), "insert_slice indexing is loop-dependent");
-
- return insertSliceOp;
-}
-
-/// Check if the chunk of data inserted by the `writeOp` is read by any other
-/// op than the candidateReadOp. This conflicting operation prevents hoisting,
-/// return it or nullptr if none is found.
-// TODO: Generalize subset disjunction analysis/interface.
-// TODO: Support more subset op types.
-static Operation *isTensorChunkAccessedByUnknownOp(Operation *writeOp,
- Operation *candidateReadOp,
- BlockArgument tensorArg) {
- // Make sure none of the other uses read the part of the tensor modified
- // by the transfer_write.
- llvm::SmallVector<Value::use_range, 1> uses;
- uses.push_back(tensorArg.getUses());
- while (!uses.empty()) {
- for (OpOperand &use : uses.pop_back_val()) {
- Operation *user = use.getOwner();
- // Skip the candidate use, only inspect the "other" uses.
- if (user == candidateReadOp || user == writeOp)
- continue;
-
- // TODO: Consider all transitive uses through
- // extract_slice/insert_slice. Atm we just bail because a stronger
- // analysis is needed for these cases.
- if (isa<tensor::ExtractSliceOp, tensor::InsertSliceOp>(user))
- return user;
-
- // Consider all transitive uses through a vector.transfer_write.
- if (isa<vector::TransferWriteOp>(writeOp)) {
- if (auto writeUser = dyn_cast<vector::TransferWriteOp>(user)) {
- uses.push_back(writeUser->getResult(0).getUses());
- continue;
- }
- }
-
- // Consider all nested uses through an scf::ForOp. We may have
- // pass-through tensor arguments left from previous level of
- // hoisting.
- if (auto forUser = dyn_cast<scf::ForOp>(user)) {
- Value arg = forUser.getBody()->getArgument(
- use.getOperandNumber() - forUser.getNumControlOperands() +
- /*iv value*/ 1);
- uses.push_back(arg.getUses());
- continue;
- }
-
- // Follow the use yield, only if it doesn't escape the original region.
- scf::YieldOp yieldUser = dyn_cast<scf::YieldOp>(user);
- if (yieldUser &&
- writeOp->getParentOp()->isAncestor(yieldUser->getParentOp())) {
- Value ret = yieldUser->getParentOp()->getResult(use.getOperandNumber());
- uses.push_back(ret.getUses());
- continue;
- }
-
- // If the write is a vector::TransferWriteOp, it may have been bypassed
- // and we need to check subset disjunction
- if (isa<vector::TransferWriteOp>(writeOp)) {
- auto read = dyn_cast<vector::TransferReadOp>(user);
- if (!read || !vector::isDisjointTransferIndices(
- cast<VectorTransferOpInterface>(read.getOperation()),
- cast<VectorTransferOpInterface>(writeOp))) {
- return user;
- }
- }
- }
- }
- return nullptr;
-}
-
-/// Mechanical hoisting of a matching read / write pair.
-/// Return the newly created scf::ForOp with an extra yields.
-// TODO: Unify implementations once the "bypassing behavior" is the same.
-static scf::ForOp hoistTransferReadWrite(
- RewriterBase &rewriter, vector::TransferReadOp transferReadOp,
- vector::TransferWriteOp transferWriteOp, BlockArgument tensorBBArg) {
- scf::ForOp forOp = cast<scf::ForOp>(tensorBBArg.getOwner()->getParentOp());
- LLVM_DEBUG(DBGS() << "--Start hoisting\n";
- DBGS() << "--Hoist read : " << transferReadOp << "\n";
- DBGS() << "--Hoist write: " << transferWriteOp << "\n";
- DBGS() << "--Involving : " << tensorBBArg << "\n");
-
- // TODO: don't hardcode /*numIvs=*/1.
- assert(tensorBBArg.getArgNumber() >= /*numIvs=*/1);
- int64_t initArgNumber = tensorBBArg.getArgNumber() - /*numIvs=*/1;
-
- // 1. Hoist the read op. Thanks to our previous checks we know this will not
- // trigger dominance violations once BBArgs are updated.
- // TODO: should the rewriter ever want to track this move ?
- transferReadOp->moveBefore(forOp);
- if (!forOp.isDefinedOutsideOfLoop(transferReadOp.getSource())) {
- rewriter.startRootUpdate(transferReadOp);
- transferReadOp.getSourceMutable().assign(
- forOp.getInitArgs()[initArgNumber]);
- rewriter.finalizeRootUpdate(transferReadOp);
- }
-
- // 2. Rewrite `loop` with an additional yield. This is the quantity that is
- // computed iteratively but whose storage has become loop-invariant.
- NewYieldValuesFn yieldFn = [&](OpBuilder &b, Location loc,
- ArrayRef<BlockArgument> newBBArgs) {
- return SmallVector<Value>{transferWriteOp.getVector()};
- };
- auto newForOp = cast<scf::ForOp>(*forOp.replaceWithAdditionalYields(
- rewriter, {transferReadOp.getVector()},
- /*replaceInitOperandUsesInLoop=*/true, yieldFn));
-
- // 3. Update the yield. Invariant: initArgNumber is the destination tensor.
- auto yieldOp =
- cast<scf::YieldOp>(newForOp.getRegion().front().getTerminator());
- // TODO: transferWriteOp.getSource is actually the destination tensor!!
- rewriter.startRootUpdate(yieldOp);
- yieldOp->setOperand(initArgNumber, transferWriteOp.getSource());
- rewriter.finalizeRootUpdate(yieldOp);
-
- // 4. Hoist write after and make uses of newForOp.getResult(initArgNumber)
- // flow through it.
- // TODO: should the rewriter ever want to track this move ?
- transferWriteOp->moveAfter(newForOp);
- rewriter.startRootUpdate(transferWriteOp);
- transferWriteOp.getVectorMutable().assign(newForOp.getResults().back());
- // TODO: transferWriteOp.getSource is actually the destination tensor!!
- transferWriteOp.getSourceMutable().assign(newForOp.getResult(initArgNumber));
- rewriter.finalizeRootUpdate(transferWriteOp);
- rewriter.replaceAllUsesExcept(newForOp.getResult(initArgNumber),
- transferWriteOp.getResult(), transferWriteOp);
- return newForOp;
-}
-
-/// Mechanical hoisting of a matching read / write pair.
-/// Return the newly created scf::ForOp with an extra yields.
-// TODO: Unify implementations once the "bypassing behavior" is the same.
-static scf::ForOp hoistExtractInsertSlice(RewriterBase &rewriter,
- tensor::ExtractSliceOp extractSliceOp,
- tensor::InsertSliceOp insertSliceOp,
- BlockArgument tensorBBArg) {
- scf::ForOp forOp = cast<scf::ForOp>(tensorBBArg.getOwner()->getParentOp());
- LLVM_DEBUG(DBGS() << "--Start hoisting\n";
- DBGS() << "--Hoist read : " << extractSliceOp << "\n";
- DBGS() << "--Hoist write: " << insertSliceOp << "\n";
- DBGS() << "--Involving : " << tensorBBArg << "\n");
-
- // TODO: don't hardcode /*numIvs=*/1.
- assert(tensorBBArg.getArgNumber() >= /*numIvs=*/1);
- int64_t initArgNumber = tensorBBArg.getArgNumber() - /*numIvs=*/1;
-
- // 1. Hoist the read op. Thanks to our previous checks we know this will not
- // trigger dominance violations once BBArgs are updated.
- // TODO: should the rewriter ever want to track this move ?
- extractSliceOp->moveBefore(forOp);
- if (!forOp.isDefinedOutsideOfLoop(extractSliceOp.getSource())) {
- assert(extractSliceOp.getSource() == tensorBBArg &&
- "extractSlice source not defined above must be the tracked bbArg");
- rewriter.startRootUpdate(extractSliceOp);
- extractSliceOp.getSourceMutable().assign(
- forOp.getInitArgs()[initArgNumber]);
- rewriter.finalizeRootUpdate(extractSliceOp);
- }
-
- // 2. Rewrite `loop` with an additional yield. This is the quantity that is
- // computed iteratively but whose storage has become loop-invariant.
- NewYieldValuesFn yieldFn = [&](OpBuilder &b, Location loc,
- ArrayRef<BlockArgument> newBBArgs) {
- return SmallVector<Value>{insertSliceOp.getSource()};
- };
- auto newForOp = cast<scf::ForOp>(*forOp.replaceWithAdditionalYields(
- rewriter, extractSliceOp.getResult(),
- /*replaceInitOperandUsesInLoop=*/true, yieldFn));
-
- // 3. Update the yield. Invariant: initArgNumber is the destination tensor.
- auto yieldOp =
- cast<scf::YieldOp>(newForOp.getRegion().front().getTerminator());
- // TODO: should the rewriter ever want to track this ?
- rewriter.startRootUpdate(yieldOp);
- yieldOp->setOperand(initArgNumber, insertSliceOp.getDest());
- rewriter.finalizeRootUpdate(yieldOp);
-
- // 4. Hoist write after and make uses of newForOp.getResult(initArgNumber)
- // flow through it.
- // TODO: should the rewriter ever want to track this move ?
- insertSliceOp->moveAfter(newForOp);
- rewriter.startRootUpdate(insertSliceOp);
- insertSliceOp.getSourceMutable().assign(newForOp.getResults().back());
- insertSliceOp.getDestMutable().assign(newForOp.getResult(initArgNumber));
- rewriter.finalizeRootUpdate(insertSliceOp);
- rewriter.replaceAllUsesExcept(newForOp.getResult(initArgNumber),
- insertSliceOp.getResult(), insertSliceOp);
- return newForOp;
-}
-
-/// Greedily hoist redundant subset extract/insert operations on tensors
-/// outside `forOp`.
-/// Return the unmodified `forOp` if no hoisting occurred.
-/// Return a new scf::ForOp if hoisting on tensors occurred.
-scf::ForOp
-mlir::linalg::hoistRedundantSubsetExtractInsert(RewriterBase &rewriter,
- scf::ForOp forOp) {
- LLVM_DEBUG(DBGS() << "Enter hoistRedundantSubsetExtractInsert scf.for\n");
- Operation *yield = forOp.getBody()->getTerminator();
-
- LLVM_DEBUG(DBGS() << "\n"; DBGS() << "Consider " << forOp << "\n");
-
- scf::ForOp newForOp = forOp;
- do {
- forOp = newForOp;
- for (const auto &it : llvm::enumerate(forOp.getRegionIterArgs())) {
- LLVM_DEBUG(DBGS() << "Consider " << it.value() << "\n");
-
- // 1. Find a loop invariant subset write yielding `ret` that we can
- // consider for hoisting.
- // TODO: TypeSwitch when we add more cases.
- OpOperand &ret = yield->getOpOperand(it.index());
- FailureOr<vector::TransferWriteOp> transferWriteOp =
- getLoopInvariantTransferWriteDefining(rewriter, forOp, it.value(),
- ret);
- FailureOr<tensor::InsertSliceOp> insertSliceOp =
- getLoopInvariantInsertSliceDefining(rewriter, forOp, it.value(), ret);
- if (failed(transferWriteOp) && failed(insertSliceOp)) {
- LLVM_DEBUG(DBGS() << "no loop invariant write defining iter_args "
- << it.value() << "\n");
- continue;
- }
-
- Operation *writeOp = succeeded(transferWriteOp)
- ? transferWriteOp->getOperation()
- : insertSliceOp->getOperation();
-
- // 2. Only accept writes with a single use (i.e. the yield).
- if (!writeOp->hasOneUse()) {
- LLVM_DEBUG(DBGS() << "write with more than 1 use " << *writeOp << "\n");
- continue;
- }
-
- LLVM_DEBUG(DBGS() << "Write to hoist: " << *writeOp << "\n");
-
- // 3. Find a matching read that can also be hoisted.
- Operation *matchingReadOp = nullptr;
- // TODO: TypeSwitch.
- if (succeeded(transferWriteOp)) {
- auto maybeTransferRead = findHoistableMatchingTransferRead(
- rewriter, *transferWriteOp, it.value());
- if (succeeded(maybeTransferRead))
- matchingReadOp = maybeTransferRead->getOperation();
- } else if (succeeded(insertSliceOp)) {
- auto maybeExtractSlice = findHoistableMatchingExtractSlice(
- rewriter, *insertSliceOp, it.value());
- if (succeeded(maybeExtractSlice))
- matchingReadOp = maybeExtractSlice->getOperation();
- } else {
- llvm_unreachable("unexpected case");
- }
- if (!matchingReadOp) {
- LLVM_DEBUG(DBGS() << "No matching read\n");
- continue;
- }
-
- // 4. Make sure no other use reads the part of the modified tensor.
- // This is necessary to guard against hazards when non-conflicting subset
- // ops are bypassed.
- Operation *maybeUnknownOp =
- isTensorChunkAccessedByUnknownOp(writeOp, matchingReadOp, it.value());
- if (maybeUnknownOp) {
- LLVM_DEBUG(DBGS() << "Tensor chunk accessed by unknown op, skip: "
- << *maybeUnknownOp << "\n");
- continue;
- }
-
- // 5. Perform the actual mechanical hoisting.
- // TODO: TypeSwitch.
- LLVM_DEBUG(DBGS() << "Read to hoist: " << *matchingReadOp << "\n");
- if (succeeded(transferWriteOp)) {
- newForOp = hoistTransferReadWrite(
- rewriter, cast<vector::TransferReadOp>(matchingReadOp),
- *transferWriteOp, it.value());
- } else if (succeeded(insertSliceOp)) {
- newForOp = hoistExtractInsertSlice(
- rewriter, cast<tensor::ExtractSliceOp>(matchingReadOp),
- *insertSliceOp, it.value());
- } else {
- llvm_unreachable("unexpected case");
- }
- break;
- }
- } while (forOp != newForOp);
-
- return newForOp;
-}
More information about the Mlir-commits mailing list