[Mlir-commits] [mlir] 4521b11 - [mlir][Linalg] Reimplement hoisting on tensors as a subset-based transformation
Nicolas Vasilache
llvmlistbot at llvm.org
Mon Feb 27 08:26:07 PST 2023
Author: Nicolas Vasilache
Date: 2023-02-27T08:15:44-08:00
New Revision: 4521b113978d9ddaaae038e3cdd9d8902e2392f9
URL: https://github.com/llvm/llvm-project/commit/4521b113978d9ddaaae038e3cdd9d8902e2392f9
DIFF: https://github.com/llvm/llvm-project/commit/4521b113978d9ddaaae038e3cdd9d8902e2392f9.diff
LOG: [mlir][Linalg] Reimplement hoisting on tensors as a subset-based transformation
This revision significantly rewrites hoisting on tensors.
Previously, `vector.transfer_read/write` and `tensor.extract/insert_slice` were
clumped together when looking for candidate pairs. This significantly increased
the complexity of the logic and prevented it from applying independently to
`tensor.extract/insert_slice`.
The new implementation decouples the two cases and starts to cast the problem
as generic subset extract/insert matching, which will be future-proof as
other such operation pairs are introduced.
Lastly, the implementation makes the distinction clear between
`vector.transfer_read/write`, for which we allow bypassing disjoint subsets,
and `tensor.extract/insert_slice`, for which we do not yet allow it.
This can be extended in the future and unified once we have subset disjunction implemented more generally.
The algorithm is also rewritten to be less of a fixed point with interspersed canonicalizations.
As a consequence, the test explicitly adds a canonicalization to clean up the IR and verify we end up in the same state.
That extra canonicalization revealed that one of the uses in one of the tests was dead, so the affected test is fixed accordingly.
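For reference, the new `transform.structured.hoist_redundant_tensor_subsets` op is exercised as follows in the updated test (usage sketch taken from mlir/test/Dialect/Linalg/hoisting.mlir; applied to a func.func handle it processes all contained scf.for loops, applied to an scf.for handle it processes only that loop):

```
transform.sequence failures(propagate) {
^bb1(%arg1: !pdl.operation):
  %0 = transform.structured.match ops{["func.func"]} in %arg1
    : (!pdl.operation) -> !pdl.operation
  transform.structured.hoist_redundant_tensor_subsets %0
    : (!pdl.operation) -> !pdl.operation
}
```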
Differential Revision: https://reviews.llvm.org/D144656
Added:
mlir/lib/Dialect/Linalg/Transforms/SubsetHoisting.cpp
Modified:
mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h
mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
mlir/lib/Dialect/Utils/StaticValueUtils.cpp
mlir/test/Dialect/Linalg/hoisting.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 41c5daf6744d0..4aacd68e3bc97 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -1725,6 +1725,10 @@ def HoistRedundantVectorTransfersOp :
dominated by the transfer_write (i.e. no aliasing between the write and
the read across the loop)
+ WARNING: This hoisting does not model parallelism and is generally incorrect
+ when used on distributed loops with memref semantics!
+ TODO: obsolete and should be retired.
+
#### Return modes:
The operation always succeeds and returns a handle to the transformed
@@ -1823,4 +1827,51 @@ def ConvertConv2DToImg2ColOp : Op<Transform_Dialect,
}];
}
+//===----------------------------------------------------------------------===//
+// HoistRedundantTensorSubsetsOp
+//===----------------------------------------------------------------------===//
+
+def HoistRedundantTensorSubsetsOp :
+ Op<Transform_Dialect, "structured.hoist_redundant_tensor_subsets",
+ [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+ TransformEachOpTrait, TransformOpInterface]> {
+ let description = [{
+ Hoists supported tensor subset extract/insert operation pairs out of the
+ immediately enclosing loop, iteratively, if the following conditions
+ are true:
+ 1. The 2 ops access the same tensor subset.
+ 2. All operands are invariant under the enclosing loop.
+
+ The supported subset extract/insert operation pairs currently comprise:
+ - tensor.extract_slice / tensor.insert_slice
+ - vector.transfer_read / vector.transfer_write on tensors
+
+ Only scf.for loops are currently supported.
+
+ When applied to:
+ 1. an scf.for loop, hoist out of this loop only.
+ 2. a non-loop op, apply hoisting to all the contained loop ops.
+
+ #### Return modes:
+
+ The operation always succeeds and returns a handle to the transformed
+ function op.
+ }];
+
+ let arguments = (ins TransformHandleTypeInterface:$target);
+ let results = (outs TransformHandleTypeInterface:$transformed);
+
+ let assemblyFormat = "$target attr-dict `:` functional-type(operands, results) ";
+
+ let builders = [
+ OpBuilder<(ins "Value":$target)>,
+ ];
+ let extraClassDeclaration = [{
+ ::mlir::DiagnosedSilenceableFailure applyToOne(
+ ::mlir::Operation *target,
+ ::mlir::transform::ApplyToEachResultList &results,
+ ::mlir::transform::TransformState &state);
+ }];
+}
+
#endif // LINALG_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h
index 355106ddd9175..24cb754d65e16 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h
@@ -10,9 +10,13 @@
#define MLIR_DIALECT_LINALG_TRANSFORMS_HOISTING_H_
namespace mlir {
+class RewriterBase;
namespace func {
class FuncOp;
} // namespace func
+namespace scf {
+class ForOp;
+} // namespace scf
namespace linalg {
@@ -28,11 +32,112 @@ namespace linalg {
/// function on the candidate loop above which to hoist. Hoisting the transfers
/// results in scf::ForOp yielding the value that originally transited through
/// memory.
-// TODO: generalize on a per-need basis.
+///
+/// WARNING: This hoisting does not model parallelism and is generally incorrect
+/// when used on distributed loops with memref semantics!
void hoistRedundantVectorTransfers(func::FuncOp func);
-/// Same behavior as `hoistRedundantVectorTransfers` but works on tensors
-/// instead of buffers.
+/// Greedily hoist redundant subset extract/insert operations on tensors outside
+/// of `forOp`. The logic follows:
+/// 1. Look for a write walking back from the `forOp` yield.
+/// 2. Check the uses of the matching block argument and look for a matching
+/// read (i.e. extract_slice or transfer_read) with matching indices.
+/// 3. In the case of a transfer_write, we can bypass other non-conflicting
+/// operations and find more hoisting opportunities.
+/// 4. Hoist the read/write pair and update the tensor SSA links.
+///
+/// Return the unmodified `forOp` if no hoisting occurred.
+/// Return a new scf::ForOp if hoisting on tensors occurred.
+///
+/// After this transformation the returned scf::ForOp may have unused arguments
+/// that can be removed by application of canonicalization patterns.
+///
+/// Example:
+/// ========
+/// IR Resembling:
+///
+/// ```
+/// %0 = scf.for %i = %l to %u step %s iter_args(%a0 = %t0)->(tensor<10xf32>) {
+/// %1 = scf.for %j = %l to %u step %s iter_args(%a6 = %a0)->(tensor<10xf32>) {
+/// %e = tensor.extract_slice %a6[%i][%sz][1]: tensor<10xf32> to tensor<?xf32>
+/// %r = vector.transfer_read %e[%c0], %cst: tensor<?xf32>, vector<4xf32>
+/// %u = "some_use"(%r) : (vector<4xf32>) -> vector<4xf32>
+/// %w = vector.transfer_write %u, %e[%c0] : vector<4xf32>, tensor<?xf32>
+/// %st = tensor.insert_slice %w into %a6[%i][%sz][1]
+/// : tensor<?xf32> into tensor<10xf32>
+/// scf.yield %st: tensor<10xf32>
+/// }
+/// scf.yield %1: tensor<10xf32>
+/// }
+/// ```
+///
+/// Progressively hoists to:
+///
+/// ```
+/// %0 = scf.for %i = %l to %u step %s iter_args(%a0 = %t0) -> (tensor<10xf32>){
+/// %e = tensor.extract_slice %a0[%i][%sz][1]: tensor<10xf32> to tensor<?xf32>
+/// %1:2 = scf.for %j = %l to %u step %s iter_args(%a6 = %a0, %a7 = %e)
+/// -> (tensor<10xf32>, tensor<?xf32>) {
+/// %r = vector.transfer_read %a7[%c0], %cst: tensor<?xf32>, vector<4xf32>
+/// %u = "some_use"(%r) : (vector<4xf32>) -> vector<4xf32>
+/// %w = vector.transfer_write %u, %a7[%c0] : vector<4xf32>, tensor<?xf32>
+/// scf.yield %a6, %w: tensor<10xf32>, tensor<?xf32>
+/// }
+/// %st = tensor.insert_slice %1#1 into %1#0[%i][%sz][1]
+/// : tensor<?xf32> into tensor<10xf32>
+/// scf.yield %st: tensor<10xf32>
+/// }
+/// ```
+///
+/// and
+///
+/// ```
+/// %0 = scf.for %i = %l to %u step %s iter_args(%a0 = %t0) -> (tensor<10xf32>){
+/// %e = tensor.extract_slice %a0[%i][%sz][1]: tensor<10xf32> to tensor<?xf32>
+/// %r = vector.transfer_read %e[%c0], %cst: tensor<?xf32>, vector<4xf32>
+/// %1:3 = scf.for %j = %l to %u step %s iter_args(%a6 = %a0, %a7 = %e, %a8 = %r)
+/// -> (tensor<10xf32>, tensor<?xf32>, vector<4xf32>) {
+/// %u = "some_use"(%r) : (vector<4xf32>) -> vector<4xf32>
+/// scf.yield %a6, %a7, %u: tensor<10xf32>, tensor<?xf32>, vector<4xf32>
+/// }
+/// %w = vector.transfer_write %1#2, %1#1[%c0] : vector<4xf32>, tensor<?xf32>
+/// %st = tensor.insert_slice %w into %1#0[%i][%sz][1]
+/// : tensor<?xf32> into tensor<10xf32>
+/// scf.yield %st: tensor<10xf32>
+/// }
+/// ```
+///
+/// It can then canonicalize to:
+///
+/// ```
+/// %0 = scf.for %i = %l to %u step %s iter_args(%a0 = %t0) -> (tensor<10xf32>){
+/// %e = tensor.extract_slice %a0[%i][%sz][1]: tensor<10xf32> to tensor<?xf32>
+/// %r = vector.transfer_read %e[%c0], %cst: tensor<?xf32>, vector<4xf32>
+/// %1 = scf.for %j = %l to %u step %s iter_args(%a7 = %r)
+/// -> (vector<4xf32>) {
+/// %u = "some_use"(%a7) : (vector<4xf32>) -> vector<4xf32>
+/// scf.yield %u: vector<4xf32>
+/// }
+/// %w = vector.transfer_write %1, %e[%c0] : vector<4xf32>, tensor<?xf32>
+/// %st = tensor.insert_slice %w into %a0[%i][%sz][1]
+/// : tensor<?xf32> into tensor<10xf32>
+/// scf.yield %st: tensor<10xf32>
+/// }
+/// ```
+///
+// TODO: This should be further generalized along a few different axes:
+// - Other loops than scf.ForOp that operate on tensors (both sequential and
+// parallel loops).
+// - Other subset extract/insert pairs than tensor.extract/insert_slice and
+// vector.transfer_read/write.
+// - More general areSubsetDisjoint analysis/interface to work across all
+// subset op types and allow bypassing non-WAW-conflicting operations in
+// more cases.
+scf::ForOp hoistRedundantSubsetExtractInsert(RewriterBase &rewriter,
+ scf::ForOp forOp);
+
+/// Call into `hoistRedundantSubsetExtractInsert` without a RewriterBase.
+// TODO: obsolete and should be retired
void hoistRedundantVectorTransfersOnTensor(func::FuncOp func);
} // namespace linalg
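As a usage note (not part of the patch itself), the new C++ entry point declared above is meant to be driven per-loop with an externally owned rewriter. A minimal sketch, assuming a func::FuncOp payload and mirroring the thin wrapper kept in Hoisting.cpp below; the helper name is illustrative:

```
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

// Hoist redundant subset extract/insert pairs out of every scf.for nested
// under `func`. Each call returns the (possibly new) loop, or the unmodified
// loop if nothing was hoisted; the result is ignored here.
static void hoistAllTensorSubsets(func::FuncOp func) {
  IRRewriter rewriter(func->getContext());
  func.walk([&](scf::ForOp forOp) {
    (void)linalg::hoistRedundantSubsetExtractInsert(rewriter, forOp);
  });
}
```

A follow-up scf.for canonicalization removes the iter_args left unused by the hoisting, as noted in the documentation above.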
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 9652d7de5f7cd..80c0ba5e754a9 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -30,7 +30,17 @@ class Tensor_OpWithOffsetSizesAndStrides<string mnemonic,
list<Trait> traits = []>
: Tensor_Op<mnemonic, traits> {
code extraBaseClassDeclaration = [{
- /// Returns the dynamic sizes for this subview operation if specified.
+ /// Return the type of the base tensor operand.
+ ::mlir::RankedTensorType getSourceType() {
+ return getSource().getType().cast<RankedTensorType>();
+ }
+
+ /// Return the type of the result tensor.
+ ::mlir::RankedTensorType getResultType() {
+ return getResult().getType().cast<RankedTensorType>();
+ }
+
+ /// Return the dynamic sizes for this subview operation if specified.
::mlir::Operation::operand_range getDynamicSizes() { return getSizes(); }
/// Return the list of Range (i.e. offset, size, stride). Each
@@ -105,7 +115,7 @@ def Tensor_DimOp : Tensor_Op<"dim", [
%c0 = arith.constant 0 : index
%x = tensor.dim %A, %c0 : tensor<4x?xf32>
- // Returns the dynamic dimension of %A.
+ // Return the dynamic dimension of %A.
%c1 = arith.constant 1 : index
%y = tensor.dim %A, %c1 : memref<4x?xf32>
@@ -361,14 +371,10 @@ def Tensor_ExtractSliceOp : Tensor_OpWithOffsetSizesAndStrides<"extract_slice",
];
let extraClassDeclaration = extraBaseClassDeclaration # [{
- /// Returns the type of the base tensor operand.
- RankedTensorType getSourceType() {
- return getSource().getType().cast<RankedTensorType>();
- }
-
/// The result of an extract_slice is always a tensor.
+ // TODO: deprecate
RankedTensorType getType() {
- return getResult().getType().cast<RankedTensorType>();
+ return getResultType();
}
/// Compute the rank-reduction mask that can be applied to map the source
@@ -834,25 +840,21 @@ def Tensor_InsertSliceOp : Tensor_OpWithOffsetSizesAndStrides<"insert_slice", [
];
let extraClassDeclaration = extraBaseClassDeclaration # [{
- /// Returns the type of the base tensor operand.
- RankedTensorType getSourceType() {
- return getSource().getType().cast<RankedTensorType>();
- }
-
/// The result of a insert_slice is always a tensor.
+ // TODO: Deprecate this method.
RankedTensorType getType() {
- return getResult().getType().cast<RankedTensorType>();
+ return getResultType();
}
/// The `dest` type is the same as the result type.
RankedTensorType getDestType() {
- return getType();
+ return getResultType();
}
/// Return the expected rank of each of the`static_offsets`, `static_sizes`
/// and `static_strides` attributes.
std::array<unsigned, 3> getArrayAttrMaxRanks() {
- unsigned rank = getType().getRank();
+ unsigned rank = getResultType().getRank();
return {rank, rank, rank};
}
diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
index c37d35134dce1..100699c7f7fd8 100644
--- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
@@ -82,6 +82,8 @@ bool isConstantIntValue(OpFoldResult ofr, int64_t value);
/// that come from the fact there is no IndexAttr and that IndexType have no
/// bitwidth.
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2);
+bool isEqualConstantIntOrValueArray(ArrayRef<OpFoldResult> ofrs1,
+ ArrayRef<OpFoldResult> ofrs2);
/// Helper function to convert a vector of `OpFoldResult`s into a vector of
/// `Value`s. For each `OpFoldResult` in `valueOrAttrVec` return the fold
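The new `isEqualConstantIntOrValueArray` helper lets subset matching compare whole lists of mixed static/dynamic offsets, sizes and strides at once. A hypothetical check (illustrative only, not code from this patch) comparing the subsets of an extract_slice/insert_slice candidate pair could look like:

```
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"

using namespace mlir;

// Hypothetical helper: return true if `read` and `write` address the exact
// same tensor subset, i.e. equal offsets, sizes and strides, whether they are
// static attributes or dynamic SSA values.
static bool addressSameSubset(tensor::ExtractSliceOp read,
                              tensor::InsertSliceOp write) {
  return isEqualConstantIntOrValueArray(read.getMixedOffsets(),
                                        write.getMixedOffsets()) &&
         isEqualConstantIntOrValueArray(read.getMixedSizes(),
                                        write.getMixedSizes()) &&
         isEqualConstantIntOrValueArray(read.getMixedStrides(),
                                        write.getMixedStrides());
}
```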
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 5a8f9816aefd1..6baf392f95082 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3097,8 +3097,10 @@ DiagnosedSilenceableFailure
transform::HoistRedundantVectorTransfersOp::applyToOne(
func::FuncOp target, transform::ApplyToEachResultList &results,
transform::TransformState &state) {
+ // WARNING: This hoisting does not model parallelism and is generally
+ // incorrect when used on distributed loops with memref semantics!
+ // TODO: obsolete and should be retired.
linalg::hoistRedundantVectorTransfers(target);
- linalg::hoistRedundantVectorTransfersOnTensor(target);
results.push_back(target);
return DiagnosedSilenceableFailure::success();
}
@@ -3136,6 +3138,32 @@ DiagnosedSilenceableFailure transform::ConvertConv2DToImg2ColOp::applyToOne(
return DiagnosedSilenceableFailure::success();
}
+//===----------------------------------------------------------------------===//
+// HoistRedundantTensorSubsetsOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+transform::HoistRedundantTensorSubsetsOp::applyToOne(
+ Operation *target, transform::ApplyToEachResultList &results,
+ transform::TransformState &state) {
+ IRRewriter rewriter(target->getContext());
+ auto forOp = dyn_cast<scf::ForOp>(target);
+ if (forOp) {
+ scf::ForOp newForOp =
+ linalg::hoistRedundantSubsetExtractInsert(rewriter, forOp);
+ results.push_back(newForOp);
+ return DiagnosedSilenceableFailure::success();
+ }
+
+ // TODO: walking in some reverse / inside-out order would be more efficient
+ // and would capture more cases.
+ target->walk([&](scf::ForOp forOp) {
+ hoistRedundantSubsetExtractInsert(rewriter, forOp);
+ });
+ results.push_back(target);
+ return DiagnosedSilenceableFailure::success();
+}
+
//===----------------------------------------------------------------------===//
// Transform op registration
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index adcc87f42dab2..8ad28f9be2abc 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -25,6 +25,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
Promotion.cpp
Split.cpp
SplitReduction.cpp
+ SubsetHoisting.cpp
SwapExtractSliceWithFillPatterns.cpp
Tiling.cpp
TilingInterfaceImpl.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
index f51b4ffe99996..9bab4ffb4c99b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
@@ -43,374 +43,13 @@ using llvm::dbgs;
using namespace mlir;
using namespace mlir::linalg;
-namespace {
-/// Represents a unit of hoistable TransferWriteOp. This may comprise other
-/// instructions that need to be hoisted too.
-struct HoistableWrite {
- vector::TransferWriteOp transferWriteOp;
- tensor::InsertSliceOp insertSliceOp;
-};
-/// Represents a unit of hoistable TransferReadOp. This may comprise other
-/// instructions that need to be hoisted too.
-struct HoistableRead {
- vector::TransferReadOp transferReadOp;
- tensor::ExtractSliceOp extractSliceOp;
-};
-} // namespace
-
-/// Return true if op1 and op2 are the same constant or the same SSA value.
-static bool isEqualOffsetSizeOrStride(OpFoldResult op1, OpFoldResult op2) {
- auto getConstantIntValue = [](OpFoldResult ofr) -> std::optional<int64_t> {
- Attribute attr = ofr.dyn_cast<Attribute>();
- // Note: isa+cast-like pattern allows writing the condition below as 1 line.
- if (!attr && ofr.get<Value>().getDefiningOp<arith::ConstantOp>())
- attr = ofr.get<Value>().getDefiningOp<arith::ConstantOp>().getValue();
- if (auto intAttr = attr.dyn_cast_or_null<IntegerAttr>())
- return intAttr.getValue().getSExtValue();
- return std::nullopt;
- };
- auto cst1 = getConstantIntValue(op1), cst2 = getConstantIntValue(op2);
- if (cst1 && cst2 && *cst1 == *cst2)
- return true;
- auto v1 = op1.dyn_cast<Value>(), v2 = op2.dyn_cast<Value>();
- return v1 && v2 && v1 == v2;
-}
-
-/// Return true is all offsets, sizes and strides are equal.
-static bool sameOffsetsSizesAndStrides(tensor::ExtractSliceOp s,
- tensor::InsertSliceOp si) {
- if (s.getStaticOffsets().size() != si.getStaticOffsets().size())
- return false;
- if (s.getStaticSizes().size() != si.getStaticSizes().size())
- return false;
- if (s.getStaticStrides().size() != si.getStaticStrides().size())
- return false;
- for (auto it : llvm::zip(s.getMixedOffsets(), si.getMixedOffsets()))
- if (!isEqualOffsetSizeOrStride(std::get<0>(it), std::get<1>(it)))
- return false;
- for (auto it : llvm::zip(s.getMixedSizes(), si.getMixedSizes()))
- if (!isEqualOffsetSizeOrStride(std::get<0>(it), std::get<1>(it)))
- return false;
- for (auto it : llvm::zip(s.getMixedStrides(), si.getMixedStrides()))
- if (!isEqualOffsetSizeOrStride(std::get<0>(it), std::get<1>(it)))
- return false;
- return true;
-}
-
-/// Look for a HoistableRead, in the given tensor uses, accessing the same
-/// offset as the HoistableWrite.
-static HoistableRead findMatchingTransferRead(HoistableWrite write,
- Value srcTensor) {
- assert(write.transferWriteOp &&
- "expected hoistable write to have a .transfer_write");
-
- LLVM_DEBUG(DBGS() << "findMatchingTransferRead for: "
- << *write.transferWriteOp.getOperation() << "\n");
- if (write.insertSliceOp)
- LLVM_DEBUG(DBGS() << "findMatchingTransferRead inserSliceOp: "
- << *write.insertSliceOp.getOperation() << "\n");
- SmallVector<Operation *> users(srcTensor.getUsers().begin(),
- srcTensor.getUsers().end());
- while (!users.empty()) {
- Operation *user = users.pop_back_val();
- LLVM_DEBUG(DBGS() << "findMatchingTransferRead inspect user: " << *user
- << "\n");
-
- // If HoistableWrite involves a InsertSliceOp, we need to find a
- // matching ExtractSliceOp.
- tensor::ExtractSliceOp sliceOp;
- Operation *maybeTransferReadUser = user;
- if (write.insertSliceOp) {
- sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
- if (!sliceOp || sliceOp.getResult().getType() !=
- write.insertSliceOp.getSource().getType())
- continue;
-
- LLVM_DEBUG(DBGS() << "check whether sameOffsetsSizesAndStrides: "
- << *sliceOp << " vs " << *write.insertSliceOp << "\n");
- if (!sameOffsetsSizesAndStrides(sliceOp, write.insertSliceOp))
- continue;
-
- LLVM_DEBUG(DBGS() << "sameOffsetsSizesAndStrides: SUCCESS\n");
- // If we got here, sliceOp is hoistable iff it has exactly 2 uses:
- // 1. the transfer_write we want to hoist.
- // 2. a matching transfer_read.
- // Anything else, we skip.
- bool skip = false;
- Operation *otherUser = nullptr;
- for (Operation *u : sliceOp->getUsers()) {
- if (u == write.transferWriteOp)
- continue;
- if (otherUser) {
- skip = true;
- break;
- }
- otherUser = u;
- }
- if (skip || !otherUser)
- continue;
- maybeTransferReadUser = otherUser;
- }
-
- LLVM_DEBUG(DBGS() << "maybeTransferReadUser: " << *maybeTransferReadUser
- << "\n");
- auto read = dyn_cast<vector::TransferReadOp>(maybeTransferReadUser);
- if (read && read.getIndices() == write.transferWriteOp.getIndices() &&
- read.getVectorType() == write.transferWriteOp.getVectorType())
- return HoistableRead{read, sliceOp};
-
- if (isa<vector::TransferWriteOp>(user)) {
- // If we find a write with disjoint indices recurse through its uses.
- if (vector::isDisjointTransferIndices(
- cast<VectorTransferOpInterface>(user),
- cast<VectorTransferOpInterface>(
- write.transferWriteOp.getOperation()))) {
- users.append(user->getUsers().begin(), user->getUsers().end());
- }
- }
- }
- return HoistableRead();
-}
-
-/// Check if the chunk of data inserted by the HoistableWrite are read by any
-/// other op than the HoistableRead candidate.
-static bool tensorChunkAccessedByUnknownOp(HoistableWrite write,
- HoistableRead candidateRead,
- BlockArgument tensorArg) {
- // Make sure none of the other uses read the part of the tensor modified
- // by the transfer_write.
- llvm::SmallVector<Value::use_range, 1> uses;
- uses.push_back(tensorArg.getUses());
- while (!uses.empty()) {
- for (OpOperand &use : uses.pop_back_val()) {
- Operation *user = use.getOwner();
- // Skip the candidate use, only inspect the "other" uses.
- if (user == candidateRead.transferReadOp ||
- user == candidateRead.extractSliceOp ||
- user == write.transferWriteOp || user == write.insertSliceOp)
- continue;
- // Consider all transitive uses through a extract_slice / insert_slice.
- // TODO: atm we just bail because a stronger analysis is needed for these
- // cases.
- if (isa<tensor::ExtractSliceOp, tensor::InsertSliceOp>(user))
- return true;
- // Consider all transitive uses through a vector.transfer_write.
- if (auto writeUser = dyn_cast<vector::TransferWriteOp>(user)) {
- uses.push_back(writeUser->getResult(0).getUses());
- continue;
- }
- // Consider all nested uses through an scf::ForOp. We may have
- // pass-through tensor arguments left from previous level of
- // hoisting.
- if (auto forUser = dyn_cast<scf::ForOp>(user)) {
- Value arg = forUser.getLoopBody().getArgument(
- use.getOperandNumber() - forUser.getNumControlOperands() +
- /*iv value*/ 1);
- uses.push_back(arg.getUses());
- continue;
- }
- // Follow the use yield as long as it doesn't escape the original
- // region.
- scf::YieldOp yieldUser = dyn_cast<scf::YieldOp>(user);
- if (yieldUser && write.transferWriteOp->getParentOp()->isAncestor(
- yieldUser->getParentOp())) {
- Value ret = yieldUser->getParentOp()->getResult(use.getOperandNumber());
- uses.push_back(ret.getUses());
- continue;
- }
- auto read = dyn_cast<vector::TransferReadOp>(user);
- if (!read || !vector::isDisjointTransferIndices(
- cast<VectorTransferOpInterface>(read.getOperation()),
- cast<VectorTransferOpInterface>(
- write.transferWriteOp.getOperation()))) {
- return true;
- }
- }
- }
- return false;
-}
-
-/// Return the `forOp`-invariant HoistableWrite that produces `yieldOperand`.
-/// Return the null HoistableWrite() if it is not comprised of a
-/// vector.transfer_write + optional insert_slice or if any of the indexings
-/// is `forOp`-dependent.
-static HoistableWrite
-getLoopInvariantTransferWriteOpDefining(scf::ForOp forOp,
- OpOperand &yieldOperand) {
- Value v = yieldOperand.get();
- if (auto write = v.getDefiningOp<vector::TransferWriteOp>()) {
- // Indexing must not depend on `forOp`.
- for (Value operand : write.getIndices())
- if (!forOp.isDefinedOutsideOfLoop(operand))
- return HoistableWrite();
-
- return HoistableWrite{write, nullptr};
- }
-
- if (auto insertSliceOp = v.getDefiningOp<tensor::InsertSliceOp>()) {
- // Inserted slice must come from vector.transfer_write.
- auto write =
- insertSliceOp.getSource().getDefiningOp<vector::TransferWriteOp>();
- if (!write)
- return HoistableWrite();
-
- // Tensor inserted into must be a BBArg at position matching yieldOperand's.
- auto bbArg = insertSliceOp.getDest().dyn_cast<BlockArgument>();
- if (!bbArg || bbArg.getOwner()->getParentOp() != forOp ||
- bbArg.getArgNumber() != /*num iv=*/1 + yieldOperand.getOperandNumber())
- return HoistableWrite();
-
- // Indexing inserted into must not depend on `forOp`.
- for (Value operand : insertSliceOp->getOperands().drop_front(
- tensor::InsertSliceOp::getOffsetSizeAndStrideStartOperandIndex()))
- if (!forOp.isDefinedOutsideOfLoop(operand))
- return HoistableWrite();
-
- return HoistableWrite{write, insertSliceOp};
- }
-
- return HoistableWrite();
-}
-
-/// Mechanical hoisting of a matching HoistableRead / HoistableWrite pair.
-static void hoistReadWrite(HoistableRead read, HoistableWrite write,
- BlockArgument tensorBBArg) {
- scf::ForOp forOp = cast<scf::ForOp>(tensorBBArg.getOwner()->getParentOp());
- assert(read.transferReadOp && write.transferWriteOp &&
- "expected transfer_read and transfer_write ops to be set");
- assert(((read.extractSliceOp && write.insertSliceOp) ||
- (!read.extractSliceOp && !write.insertSliceOp)) &&
- "expected matching extract_slice / insert_slice");
- LLVM_DEBUG(DBGS() << "In forOp:\n"
- << *forOp.getOperation()
- << "\nHoist: " << *read.transferReadOp.getOperation()
- << "\nHoist: " << *write.transferWriteOp.getOperation()
- << "\nInvolving: " << tensorBBArg << "\n");
-
- // If a read slice is present, hoist it.
- if (read.extractSliceOp)
- forOp.moveOutOfLoop(read.extractSliceOp);
-
- // Hoist the transfer_read op.
- forOp.moveOutOfLoop(read.transferReadOp);
-
- // TODO: don't hardcode /*numIvs=*/1.
- assert(tensorBBArg.getArgNumber() >= /*numIvs=*/1);
- unsigned initArgNumber = tensorBBArg.getArgNumber() - /*numIvs=*/1;
-
- // Update the source tensor.
- if (read.extractSliceOp)
- read.extractSliceOp.getSourceMutable().assign(
- forOp.getInitArgs()[initArgNumber]);
- else
- read.transferReadOp.getSourceMutable().assign(
- forOp.getInitArgs()[initArgNumber]);
-
- // Hoist write after.
- if (write.insertSliceOp)
- write.insertSliceOp->moveAfter(forOp);
- write.transferWriteOp->moveAfter(forOp);
-
- // Update the yield.
- auto yieldOp = cast<scf::YieldOp>(forOp.getRegion().front().getTerminator());
- if (write.insertSliceOp)
- yieldOp->setOperand(initArgNumber, write.insertSliceOp.getDest());
- else
- yieldOp->setOperand(initArgNumber, write.transferWriteOp.getSource());
-
- // Rewrite `loop` with additional new yields.
- OpBuilder b(read.transferReadOp);
- NewYieldValueFn yieldFn = [&](OpBuilder &b, Location loc,
- ArrayRef<BlockArgument> newBBArgs) {
- return SmallVector<Value>{write.transferWriteOp.getVector()};
- };
- auto newForOp = replaceLoopWithNewYields(
- b, forOp, read.transferReadOp.getVector(), yieldFn);
-
- // Transfer write has been hoisted, need to update the vector and tensor
- // source. Replace the result of the loop to use the new tensor created
- // outside the loop.
- // Depending on whether a insert_slice is present or not, it carries the
- // update on the tensor operands.
- if (write.insertSliceOp) {
- newForOp.getResult(initArgNumber)
- .replaceAllUsesWith(write.insertSliceOp.getResult());
- write.transferWriteOp.getSourceMutable().assign(
- read.extractSliceOp.getResult());
- write.insertSliceOp.getDestMutable().assign(
- read.extractSliceOp.getSource());
- } else {
- newForOp.getResult(initArgNumber)
- .replaceAllUsesWith(write.transferWriteOp.getResult());
- write.transferWriteOp.getSourceMutable().assign(
- newForOp.getResult(initArgNumber));
- }
-
- // Always update with the newly yield tensor and vector.
- write.transferWriteOp.getVectorMutable().assign(newForOp.getResults().back());
-}
-
-// To hoist transfer op on tensor the logic can be significantly simplified
-// compared to the case on buffer. The transformation follows this logic:
-// 1. Look for transfer_write with a single use from ForOp yield
-// 2. Check the uses of the matching block argument and look for a transfer_read
-// with the same indices.
-// 3. Check that all the other uses of the tensor argument are either disjoint
-// tensor_read or transfer_write. For transfer_write uses recurse to make sure
-// the new tensor has the same restrictions on its uses.
-// 4. Hoist the tensor_read/tensor_write and update the tensor SSA links.
-// After this transformation the scf.forOp may have unused arguments that can be
-// remove by the canonicalization pass.
void mlir::linalg::hoistRedundantVectorTransfersOnTensor(func::FuncOp func) {
- bool changed = true;
- while (changed) {
- changed = false;
- func.walk([&](scf::ForOp forOp) {
- Operation *yield = forOp.getBody()->getTerminator();
- for (const auto &it : llvm::enumerate(forOp.getRegionIterArgs())) {
- OpOperand &ret = yield->getOpOperand(it.index());
- HoistableWrite write =
- getLoopInvariantTransferWriteOpDefining(forOp, ret);
- if (!write.transferWriteOp || !write.transferWriteOp->hasOneUse())
- continue;
- LLVM_DEBUG(dbgs() << "\n";
- DBGS() << "Candidate write for hoisting: "
- << *write.transferWriteOp.getOperation() << "\n");
- if (write.insertSliceOp)
- LLVM_DEBUG(DBGS() << "Candidate insert_slice for hoisting: "
- << *write.insertSliceOp.getOperation() << "\n");
- if (llvm::any_of(write.transferWriteOp.getIndices(),
- [&forOp](Value index) {
- return !forOp.isDefinedOutsideOfLoop(index);
- }))
- continue;
- // Find a read with the same type and indices.
- HoistableRead matchingRead =
- findMatchingTransferRead(write, it.value());
- // Make sure none of the other uses read the part of the tensor modified
- // by the transfer_write.
- if (!matchingRead.transferReadOp ||
- tensorChunkAccessedByUnknownOp(write, matchingRead, it.value()))
- continue;
-
- LLVM_DEBUG(DBGS() << "Start hoisting\n");
- hoistReadWrite(matchingRead, write, it.value());
- changed = true;
- forOp.erase();
-
- // Need to interrupt and restart: erasing the loop messes up the walk.
- return WalkResult::interrupt();
- }
- return WalkResult::advance();
- });
- // Apply canonicalization so the newForOp + yield folds immediately, thus
- // cleaning up the IR and potentially enabling more hoisting.
- if (changed) {
- RewritePatternSet patterns(func->getContext());
- scf::ForOp::getCanonicalizationPatterns(patterns, func->getContext());
- (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
- }
- }
+ IRRewriter rewriter(func->getContext());
+ // TODO: walking in some reverse / inside-out order would be more efficient
+ // and would capture more cases.
+ func.walk([&](scf::ForOp forOp) {
+ hoistRedundantSubsetExtractInsert(rewriter, forOp);
+ });
}
void mlir::linalg::hoistRedundantVectorTransfers(func::FuncOp func) {
diff --git a/mlir/lib/Dialect/Linalg/Transforms/SubsetHoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/SubsetHoisting.cpp
new file mode 100644
index 0000000000000..c0355a14d366b
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/SubsetHoisting.cpp
@@ -0,0 +1,553 @@
+//===- SubsetHoisting.cpp - Linalg hoisting transformations----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements functions concerned with hoisting invariant subset
+// operations in the context of Linalg transformations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Utils/Utils.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/LoopInvariantCodeMotionUtils.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/ErrorHandling.h"
+
+#define DEBUG_TYPE "subset-hoisting"
+
+#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+/// Return true if the location of the subset defined by the op is invariant of
+/// the loop iteration.
+static bool
+isSubsetLocationLoopInvariant(scf::ForOp forOp,
+ vector::TransferWriteOp transferWriteOp) {
+ for (Value operand : transferWriteOp.getIndices())
+ if (!forOp.isDefinedOutsideOfLoop(operand))
+ return false;
+ return true;
+}
+
+/// Return true if the location of the subset defined by the op is invariant of
+/// the loop iteration.
+static bool isSubsetLocationLoopInvariant(scf::ForOp forOp,
+ tensor::InsertSliceOp insertSliceOp) {
+ for (Value operand : insertSliceOp->getOperands().drop_front(
+ tensor::InsertSliceOp::getOffsetSizeAndStrideStartOperandIndex()))
+ if (!forOp.isDefinedOutsideOfLoop(operand))
+ return false;
+ return true;
+}
+
+/// Given an `srcTensor` that is a block argument belonging to a loop, greedily
+/// look for the first read that can be hoisted out of the loop (i.e. that
+/// satisfies the following conditions):
+/// - The read is of type `tensor.extract_slice`.
+/// - The read is one of the uses of `srcTensor`.
+/// - The read is to the same subset that `tensor.insert_slice` writes.
+// TODO: Unify implementations once the "bypassing behavior" is the same.
+static FailureOr<tensor::ExtractSliceOp>
+findHoistableMatchingExtractSlice(RewriterBase &rewriter,
+ tensor::InsertSliceOp insertSliceOp,
+ BlockArgument srcTensor) {
+ assert(srcTensor.getType().isa<RankedTensorType>() && "not a ranked tensor");
+
+ auto forOp = cast<scf::ForOp>(srcTensor.getOwner()->getParentOp());
+
+ LLVM_DEBUG(DBGS() << "--find matching read for: " << insertSliceOp << "\n";
+ DBGS() << "--amongst users of: " << srcTensor << "\n");
+
+ SmallVector<Operation *> users(srcTensor.getUsers());
+ if (forOp.isDefinedOutsideOfLoop(insertSliceOp.getDest()))
+ llvm::append_range(users, insertSliceOp.getDest().getUsers());
+
+ for (Operation *user : users) {
+ LLVM_DEBUG(DBGS() << "----inspect user: " << *user << "\n");
+ auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
+ // Skip ops other than extract_slice with an exactly matching tensor
+ // subset.
+ if (extractSliceOp) {
+ auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
+ if (extractSliceOp.getResultType() != insertSliceOp.getSourceType() ||
+ !extractSliceOp.isSameAs(insertSliceOp, isSame)) {
+ LLVM_DEBUG(DBGS() << "------not a matching extract_slice\n";
+ DBGS() << *user << " vs " << *insertSliceOp << "\n");
+ continue;
+ }
+
+ // Skip extract_slice whose source is defined within the loop: we need to
+ // hoist that definition first otherwise dominance violations trigger.
+ if (!extractSliceOp.getSource().isa<BlockArgument>() &&
+ !forOp.isDefinedOutsideOfLoop(extractSliceOp.getSource())) {
+ LLVM_DEBUG(DBGS() << "------transfer_read vector is loop-dependent\n");
+ continue;
+ }
+ return extractSliceOp;
+ }
+
+ // TODO: Look through disjoint subsets, similar to vector.transfer_write
+ // and unify implementations.
+ }
+
+ LLVM_DEBUG(DBGS() << "----no matching extract_slice");
+ return failure();
+}
+
+/// Given an `srcTensor` that is a block argument belonging to a loop, greedily
+/// look for the first read that can be hoisted out of the loop (i.e. that
+/// satisfies the following conditions):
+/// - The read is of type `vector.transfer_read`.
+/// - The read is one of the uses of `srcTensor`.
+/// - The read is to the same subset that `vector.transfer_write` writes.
+// TODO: Unify implementations once the "bypassing behavior" is the same.
+static FailureOr<vector::TransferReadOp>
+findHoistableMatchingTransferRead(RewriterBase &rewriter,
+ vector::TransferWriteOp transferWriteOp,
+ BlockArgument srcTensor) {
+ if (!srcTensor.getType().isa<RankedTensorType>())
+ return failure();
+
+ auto forOp = cast<scf::ForOp>(srcTensor.getOwner()->getParentOp());
+
+ LLVM_DEBUG(DBGS() << "--find matching read for: " << transferWriteOp << "\n";
+ DBGS() << "--amongst users of: " << srcTensor << "\n";);
+
+ // vector.transfer_write is a bit peculiar: we look through dependencies
+ // to disjoint tensor subsets. This requires a while loop.
+ // TODO: Look through disjoint subsets for tensor.insert_slice and unify
+ // implementations.
+ SmallVector<Operation *> users(srcTensor.getUsers());
+ // TODO: transferWriteOp.getSource is actually the destination tensor!!
+ if (forOp.isDefinedOutsideOfLoop(transferWriteOp.getSource()))
+ llvm::append_range(users, transferWriteOp.getSource().getUsers());
+ while (!users.empty()) {
+ Operation *user = users.pop_back_val();
+ LLVM_DEBUG(DBGS() << "----inspect user: " << *user << "\n");
+ auto read = dyn_cast<vector::TransferReadOp>(user);
+ if (read) {
+ // Skip ops other than transfer_read with an exact matching subset.
+ if (read.getIndices() != transferWriteOp.getIndices() ||
+ read.getVectorType() != transferWriteOp.getVectorType()) {
+ LLVM_DEBUG(DBGS() << "------not a transfer_read that matches the "
+ "transfer_write: "
+ << *user << "\n\t(vs " << *transferWriteOp << ")\n");
+ continue;
+ }
+
+ // The transfer_read source may be defined within the loop: we still
+ // traverse it by virtue of bypassing disjoint subset operations rooted at
+ // a bbArg and feeding a matching yield.
+ if (!read.getSource().isa<BlockArgument>() &&
+ !forOp.isDefinedOutsideOfLoop(read.getSource())) {
+ LLVM_DEBUG(DBGS() << "------transfer_read vector appears loop "
+ "dependent but will be tested for disjointness as "
+ "part of the bypass analysis\n");
+ }
+ LLVM_DEBUG(DBGS() << "------found match\n");
+ return read;
+ }
+
+ // As an optimization, we look further through dependencies to disjoint
+ // tensor subsets. This creates more opportunities to find a matching read.
+ if (isa<vector::TransferWriteOp>(user)) {
+ // If we find a write with disjoint indices append all its uses.
+ // TODO: Generalize areSubsetsDisjoint and allow other bypass than
+ // just vector.transfer_write - vector.transfer_write.
+ if (vector::isDisjointTransferIndices(
+ cast<VectorTransferOpInterface>(user),
+ cast<VectorTransferOpInterface>(
+ transferWriteOp.getOperation()))) {
+ LLVM_DEBUG(DBGS() << "----follow through disjoint write\n");
+ users.append(user->getUsers().begin(), user->getUsers().end());
+ } else {
+ LLVM_DEBUG(DBGS() << "----skip non-disjoint write\n");
+ }
+ }
+ }
+
+ LLVM_DEBUG(DBGS() << "--no matching transfer_read\n");
+ return rewriter.notifyMatchFailure(transferWriteOp,
+ "no matching transfer_read");
+}
+
+/// Return the `vector.transfer_write` that produces `yieldOperand`, if:
+/// - The write operates on tensors.
+/// - All indices are defined outside of the loop.
+/// Return failure otherwise.
+///
+/// This is a sufficient condition to hoist the `vector.transfer_write`; other
+/// operands can always be yielded by the loop where needed.
+// TODO: generalize beyond scf::ForOp.
+// TODO: Unify implementations once the "bypassing behavior" is the same.
+static FailureOr<vector::TransferWriteOp>
+getLoopInvariantTransferWriteDefining(RewriterBase &rewriter, scf::ForOp forOp,
+ BlockArgument bbArg,
+ OpOperand &yieldOperand) {
+ assert(bbArg.getArgNumber() ==
+ forOp.getNumInductionVars() + yieldOperand.getOperandNumber() &&
+ "bbArg and yieldOperand must match");
+ assert(isa<scf::YieldOp>(yieldOperand.getOwner()) && "must be an scf.yield");
+
+ Value v = yieldOperand.get();
+ auto transferWriteOp = v.getDefiningOp<vector::TransferWriteOp>();
+ if (!transferWriteOp)
+ return rewriter.notifyMatchFailure(v.getLoc(), "not a transfer_write");
+
+ if (transferWriteOp->getNumResults() == 0) {
+ return rewriter.notifyMatchFailure(v.getLoc(),
+ "unsupported transfer_write on buffers");
+ }
+
+ // We do not explicitly check that the destination is a BBarg that matches the
+ // yield operand as this would prevent us from bypassing other non-conflicting
+ // writes.
+
+ // Indexing must not depend on `forOp`.
+ if (!isSubsetLocationLoopInvariant(forOp, transferWriteOp))
+ return rewriter.notifyMatchFailure(
+ v.getLoc(), "transfer_write indexing is loop-dependent");
+
+ return transferWriteOp;
+}
+
+/// Return the `tensor.insert_slice` that produces `yieldOperand`, if:
+/// 1. Its destination tensor is a block argument of the `forOp`.
+/// 2. The unique use of its result is a yield with operand number matching
+/// the block argument.
+/// 3. All indices are defined outside of the loop.
+/// Return failure otherwise.
+///
+/// This is a sufficient condition to hoist the `tensor.insert_slice`; other
+/// operands can always be yielded by the loop where needed.
+/// Note: 1. + 2. ensure that the yield / iter_args cycle results in proper
+/// semantics (i.e. no ping-pong between iter_args across iterations).
+// TODO: generalize beyond scf::ForOp.
+// TODO: Unify implementations once the "bypassing behavior" is the same.
+static FailureOr<tensor::InsertSliceOp>
+getLoopInvariantInsertSliceDefining(RewriterBase &rewriter, scf::ForOp forOp,
+ BlockArgument bbArg,
+ OpOperand &yieldOperand) {
+ assert(bbArg.getArgNumber() ==
+ forOp.getNumInductionVars() + yieldOperand.getOperandNumber() &&
+ "bbArg and yieldOperand must match");
+ assert(isa<scf::YieldOp>(yieldOperand.getOwner()) && "must be an scf.yield");
+
+ Value v = yieldOperand.get();
+ auto insertSliceOp = v.getDefiningOp<tensor::InsertSliceOp>();
+ if (!insertSliceOp)
+ return rewriter.notifyMatchFailure(v.getLoc(), "not an insert_slice");
+
+ // Tensor inserted into must be a BBArg at position matching yield operand.
+ // TODO: In the future we should not perform this check if we want to bypass
+ // other non-conflicting writes.
+ if (bbArg != insertSliceOp.getDest())
+ return rewriter.notifyMatchFailure(v.getLoc(), "not a matching bbarg");
+
+ // Indexing inserted into must not depend on `forOp`.
+ if (!isSubsetLocationLoopInvariant(forOp, insertSliceOp))
+ return rewriter.notifyMatchFailure(
+ v.getLoc(), "insert_slice indexing is loop-dependent");
+
+ return insertSliceOp;
+}
+
+/// Check if the chunk of data inserted by `writeOp` is read by any op other
+/// than `candidateReadOp`. Such a conflicting operation prevents hoisting;
+/// return it, or nullptr if none is found.
+// TODO: Generalize subset disjunction analysis/interface.
+// TODO: Support more subset op types.
+static Operation *isTensorChunkAccessedByUnknownOp(Operation *writeOp,
+ Operation *candidateReadOp,
+ BlockArgument tensorArg) {
+ // Make sure none of the other uses read the part of the tensor modified
+ // by the transfer_write.
+ llvm::SmallVector<Value::use_range, 1> uses;
+ uses.push_back(tensorArg.getUses());
+ while (!uses.empty()) {
+ for (OpOperand &use : uses.pop_back_val()) {
+ Operation *user = use.getOwner();
+ // Skip the candidate use, only inspect the "other" uses.
+ if (user == candidateReadOp || user == writeOp)
+ continue;
+
+ // TODO: Consider all transitive uses through
+ // extract_slice/insert_slice. Atm we just bail because a stronger
+ // analysis is needed for these cases.
+ if (isa<tensor::ExtractSliceOp, tensor::InsertSliceOp>(user))
+ return user;
+
+ // Consider all transitive uses through a vector.transfer_write.
+ if (isa<vector::TransferWriteOp>(writeOp)) {
+ if (auto writeUser = dyn_cast<vector::TransferWriteOp>(user)) {
+ uses.push_back(writeUser->getResult(0).getUses());
+ continue;
+ }
+ }
+
+ // Consider all nested uses through an scf::ForOp. We may have
+ // pass-through tensor arguments left from previous level of
+ // hoisting.
+ if (auto forUser = dyn_cast<scf::ForOp>(user)) {
+ Value arg = forUser.getLoopBody().getArgument(
+ use.getOperandNumber() - forUser.getNumControlOperands() +
+ /*iv value*/ 1);
+ uses.push_back(arg.getUses());
+ continue;
+ }
+
+ // Follow the use yield, only if it doesn't escape the original region.
+ scf::YieldOp yieldUser = dyn_cast<scf::YieldOp>(user);
+ if (yieldUser &&
+ writeOp->getParentOp()->isAncestor(yieldUser->getParentOp())) {
+ Value ret = yieldUser->getParentOp()->getResult(use.getOperandNumber());
+ uses.push_back(ret.getUses());
+ continue;
+ }
+
+ // If the write is a vector::TransferWriteOp, it may have been bypassed
+ // and we need to check subset disjunction
+ if (isa<vector::TransferWriteOp>(writeOp)) {
+ auto read = dyn_cast<vector::TransferReadOp>(user);
+ if (!read || !vector::isDisjointTransferIndices(
+ cast<VectorTransferOpInterface>(read.getOperation()),
+ cast<VectorTransferOpInterface>(writeOp))) {
+ return user;
+ }
+ }
+ }
+ }
+ return nullptr;
+}
+
+/// Mechanical hoisting of a matching read / write pair.
+/// Return the newly created scf::ForOp with an extra yield.
+// TODO: Unify implementations once the "bypassing behavior" is the same.
+static scf::ForOp hoistTransferReadWrite(
+ RewriterBase &rewriter, vector::TransferReadOp transferReadOp,
+ vector::TransferWriteOp transferWriteOp, BlockArgument tensorBBArg) {
+ scf::ForOp forOp = cast<scf::ForOp>(tensorBBArg.getOwner()->getParentOp());
+ LLVM_DEBUG(DBGS() << "--Start hoisting\n";
+ DBGS() << "--Hoist read : " << transferReadOp << "\n";
+ DBGS() << "--Hoist write: " << transferWriteOp << "\n";
+ DBGS() << "--Involving : " << tensorBBArg << "\n");
+
+ // TODO: don't hardcode /*numIvs=*/1.
+ assert(tensorBBArg.getArgNumber() >= /*numIvs=*/1);
+ int64_t initArgNumber = tensorBBArg.getArgNumber() - /*numIvs=*/1;
+
+ // 1. Hoist the read op. Thanks to our previous checks we know this will not
+ // trigger dominance violations once BBArgs are updated.
+ // TODO: should the rewriter ever want to track this move ?
+ transferReadOp->moveBefore(forOp);
+ if (!forOp.isDefinedOutsideOfLoop(transferReadOp.getSource())) {
+ rewriter.startRootUpdate(transferReadOp);
+ transferReadOp.getSourceMutable().assign(
+ forOp.getInitArgs()[initArgNumber]);
+ rewriter.finalizeRootUpdate(transferReadOp);
+ }
+
+ // 2. Rewrite `loop` with an additional yield. This is the quantity that is
+ // computed iteratively but whose storage has become loop-invariant.
+ NewYieldValueFn yieldFn = [&](OpBuilder &b, Location loc,
+ ArrayRef<BlockArgument> newBBArgs) {
+ return SmallVector<Value>{transferWriteOp.getVector()};
+ };
+ auto newForOp = replaceLoopWithNewYields(
+ rewriter, forOp, {transferReadOp.getVector()}, yieldFn);
+ rewriter.eraseOp(forOp);
+
+ // 3. Update the yield. Invariant: initArgNumber is the destination tensor.
+ auto yieldOp =
+ cast<scf::YieldOp>(newForOp.getRegion().front().getTerminator());
+ // TODO: transferWriteOp.getSource is actually the destination tensor!!
+ rewriter.startRootUpdate(yieldOp);
+ yieldOp->setOperand(initArgNumber, transferWriteOp.getSource());
+ rewriter.finalizeRootUpdate(yieldOp);
+
+ // 4. Hoist write after and make uses of newForOp.getResult(initArgNumber)
+ // flow through it.
+ // TODO: should the rewriter ever want to track this move ?
+ transferWriteOp->moveAfter(newForOp);
+ rewriter.startRootUpdate(transferWriteOp);
+ transferWriteOp.getVectorMutable().assign(newForOp.getResults().back());
+ // TODO: transferWriteOp.getSource is actually the destination tensor!!
+ transferWriteOp.getSourceMutable().assign(newForOp.getResult(initArgNumber));
+ rewriter.finalizeRootUpdate(transferWriteOp);
+ rewriter.replaceAllUsesExcept(newForOp.getResult(initArgNumber),
+ transferWriteOp.getResult(), transferWriteOp);
+ return newForOp;
+}
+
+/// Mechanical hoisting of a matching read / write pair.
+/// Return the newly created scf::ForOp with an extra yield.
+// TODO: Unify implementations once the "bypassing behavior" is the same.
+static scf::ForOp hoistExtractInsertSlice(RewriterBase &rewriter,
+ tensor::ExtractSliceOp extractSliceOp,
+ tensor::InsertSliceOp insertSliceOp,
+ BlockArgument tensorBBArg) {
+ scf::ForOp forOp = cast<scf::ForOp>(tensorBBArg.getOwner()->getParentOp());
+ LLVM_DEBUG(DBGS() << "--Start hoisting\n";
+ DBGS() << "--Hoist read : " << extractSliceOp << "\n";
+ DBGS() << "--Hoist write: " << insertSliceOp << "\n";
+ DBGS() << "--Involving : " << tensorBBArg << "\n");
+
+ // TODO: don't hardcode /*numIvs=*/1.
+ assert(tensorBBArg.getArgNumber() >= /*numIvs=*/1);
+ int64_t initArgNumber = tensorBBArg.getArgNumber() - /*numIvs=*/1;
+
+ // 1. Hoist the read op. Thanks to our previous checks we know this will not
+ // trigger dominance violations once BBArgs are updated.
+ // TODO: should the rewriter ever want to track this move ?
+ extractSliceOp->moveBefore(forOp);
+ if (!forOp.isDefinedOutsideOfLoop(extractSliceOp.getSource())) {
+ assert(extractSliceOp.getSource() == tensorBBArg &&
+ "extractSlice source not defined above must be the tracked bbArg");
+ rewriter.startRootUpdate(extractSliceOp);
+ extractSliceOp.getSourceMutable().assign(
+ forOp.getInitArgs()[initArgNumber]);
+ rewriter.finalizeRootUpdate(extractSliceOp);
+ }
+
+ // 2. Rewrite `loop` with an additional yield. This is the quantity that is
+ // computed iteratively but whose storage has become loop-invariant.
+ NewYieldValueFn yieldFn = [&](OpBuilder &b, Location loc,
+ ArrayRef<BlockArgument> newBBArgs) {
+ return SmallVector<Value>{insertSliceOp.getSource()};
+ };
+ auto newForOp = replaceLoopWithNewYields(rewriter, forOp,
+ extractSliceOp.getResult(), yieldFn);
+ rewriter.eraseOp(forOp);
+
+ // 3. Update the yield. Invariant: initArgNumber is the destination tensor.
+ auto yieldOp =
+ cast<scf::YieldOp>(newForOp.getRegion().front().getTerminator());
+ // TODO: should the rewriter ever want to track this ?
+ rewriter.startRootUpdate(yieldOp);
+ yieldOp->setOperand(initArgNumber, insertSliceOp.getDest());
+ rewriter.finalizeRootUpdate(yieldOp);
+
+ // 4. Hoist write after and make uses of newForOp.getResult(initArgNumber)
+ // flow through it.
+ // TODO: should the rewriter ever want to track this move ?
+ insertSliceOp->moveAfter(newForOp);
+ rewriter.startRootUpdate(insertSliceOp);
+ insertSliceOp.getSourceMutable().assign(newForOp.getResults().back());
+ insertSliceOp.getDestMutable().assign(newForOp.getResult(initArgNumber));
+ rewriter.finalizeRootUpdate(insertSliceOp);
+ rewriter.replaceAllUsesExcept(newForOp.getResult(initArgNumber),
+ insertSliceOp.getResult(), insertSliceOp);
+ return newForOp;
+}
+
+/// Greedily hoist redundant subset extract/insert operations on tensors
+/// outside `forOp`.
+/// Return the unmodified `forOp` if no hoisting occurred.
+/// Return a new scf::ForOp if hoisting on tensors occurred.
+scf::ForOp
+mlir::linalg::hoistRedundantSubsetExtractInsert(RewriterBase &rewriter,
+ scf::ForOp forOp) {
+ LLVM_DEBUG(DBGS() << "Enter hoistRedundantSubsetExtractInsert scf.for\n");
+ Operation *yield = forOp.getBody()->getTerminator();
+
+ LLVM_DEBUG(DBGS() << "\n"; DBGS() << "Consider " << forOp << "\n");
+
+ scf::ForOp newForOp = forOp;
+ do {
+ forOp = newForOp;
+ for (const auto &it : llvm::enumerate(forOp.getRegionIterArgs())) {
+ LLVM_DEBUG(DBGS() << "Consider " << it.value() << "\n");
+
+ // 1. Find a loop invariant subset write yielding `ret` that we can
+ // consider for hoisting.
+ // TODO: TypeSwitch when we add more cases.
+ OpOperand &ret = yield->getOpOperand(it.index());
+ FailureOr<vector::TransferWriteOp> transferWriteOp =
+ getLoopInvariantTransferWriteDefining(rewriter, forOp, it.value(),
+ ret);
+ FailureOr<tensor::InsertSliceOp> insertSliceOp =
+ getLoopInvariantInsertSliceDefining(rewriter, forOp, it.value(), ret);
+ if (failed(transferWriteOp) && failed(insertSliceOp)) {
+ LLVM_DEBUG(DBGS() << "no loop invariant write defining iter_args "
+ << it.value() << "\n");
+ continue;
+ }
+
+ Operation *writeOp = succeeded(transferWriteOp)
+ ? transferWriteOp->getOperation()
+ : insertSliceOp->getOperation();
+
+ // 2. Only accept writes with a single use (i.e. the yield).
+ if (!writeOp->hasOneUse()) {
+ LLVM_DEBUG(DBGS() << "write with more than 1 use " << *writeOp << "\n");
+ continue;
+ }
+
+ LLVM_DEBUG(DBGS() << "Write to hoist: " << *writeOp << "\n");
+
+ // 3. Find a matching read that can also be hoisted.
+ Operation *matchingReadOp = nullptr;
+ // TODO: TypeSwitch.
+ if (succeeded(transferWriteOp)) {
+ auto maybeTransferRead = findHoistableMatchingTransferRead(
+ rewriter, *transferWriteOp, it.value());
+ if (succeeded(maybeTransferRead))
+ matchingReadOp = maybeTransferRead->getOperation();
+ } else if (succeeded(insertSliceOp)) {
+ auto maybeExtractSlice = findHoistableMatchingExtractSlice(
+ rewriter, *insertSliceOp, it.value());
+ if (succeeded(maybeExtractSlice))
+ matchingReadOp = maybeExtractSlice->getOperation();
+ } else {
+ llvm_unreachable("unexpected case");
+ }
+ if (!matchingReadOp) {
+ LLVM_DEBUG(DBGS() << "No matching read\n");
+ continue;
+ }
+
+ // 4. Make sure no other use reads the part of the modified tensor.
+ // This is necessary to guard against hazards when non-conflicting subset
+ // ops are bypassed.
+ Operation *maybeUnknownOp =
+ isTensorChunkAccessedByUnknownOp(writeOp, matchingReadOp, it.value());
+ if (maybeUnknownOp) {
+ LLVM_DEBUG(DBGS() << "Tensor chunk accessed by unknown op, skip: "
+ << *maybeUnknownOp << "\n");
+ continue;
+ }
+
+ // 5. Perform the actual mechanical hoisting.
+ // TODO: TypeSwitch.
+ LLVM_DEBUG(DBGS() << "Read to hoist: " << *matchingReadOp << "\n");
+ if (succeeded(transferWriteOp)) {
+ newForOp = hoistTransferReadWrite(
+ rewriter, cast<vector::TransferReadOp>(matchingReadOp),
+ *transferWriteOp, it.value());
+ } else if (succeeded(insertSliceOp)) {
+ newForOp = hoistExtractInsertSlice(
+ rewriter, cast<tensor::ExtractSliceOp>(matchingReadOp),
+ *insertSliceOp, it.value());
+ } else {
+ llvm_unreachable("unexpected case");
+ }
+ break;
+ }
+ } while (forOp != newForOp);
+
+ return newForOp;
+}
diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
index 294dc810507b4..45ea541660fbd 100644
--- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
@@ -136,6 +136,16 @@ bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2) {
return v1 && v1 == v2;
}
+bool isEqualConstantIntOrValueArray(ArrayRef<OpFoldResult> ofrs1,
+ ArrayRef<OpFoldResult> ofrs2) {
+ if (ofrs1.size() != ofrs2.size())
+ return false;
+ for (auto [ofr1, ofr2] : llvm::zip_equal(ofrs1, ofrs2))
+ if (!isEqualConstantIntOrValue(ofr1, ofr2))
+ return false;
+ return true;
+}
+
/// Helper function to convert a vector of `OpFoldResult`s into a vector of
/// `Value`s. For each `OpFoldResult` in `valueOrAttrVec` return the fold result
/// if it casts to a `Value` or create an index-type constant if it casts to
diff --git a/mlir/test/Dialect/Linalg/hoisting.mlir b/mlir/test/Dialect/Linalg/hoisting.mlir
index 8830a4f427212..aeecb8cf95f89 100644
--- a/mlir/test/Dialect/Linalg/hoisting.mlir
+++ b/mlir/test/Dialect/Linalg/hoisting.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -test-transform-dialect-interpreter --split-input-file --allow-unregistered-dialect %s | FileCheck %s
+// RUN: mlir-opt -test-transform-dialect-interpreter -canonicalize --split-input-file --allow-unregistered-dialect %s | FileCheck %s
// CHECK-LABEL: func @hoist_vector_transfer_pairs(
// CHECK-SAME: %[[MEMREF0:[a-zA-Z0-9]*]]: memref<?x?xf32>,
@@ -29,7 +29,7 @@ func.func @hoist_vector_transfer_pairs(
// CHECK: vector.transfer_read %{{.*}} : memref<?x?xf32>, vector<5xf32>
// CHECK: "some_use"(%{{.*}}) : (vector<1xf32>) -> vector<1xf32>
// CHECK: "some_use"(%{{.*}}) : (vector<2xf32>) -> vector<2xf32>
-// CHECK: "some_use"(%[[MEMREF2]]) : (memref<?x?xf32>) -> vector<3xf32>
+// CHECK: "some_use"(%[[MEMREF2]], %{{.*}}) : (memref<?x?xf32>, vector<3xf32>) -> vector<3xf32>
// CHECK: "some_use"(%{{.*}}) : (vector<4xf32>) -> vector<4xf32>
// CHECK: "some_use"(%{{.*}}) : (vector<5xf32>) -> vector<5xf32>
// CHECK: vector.transfer_write %{{.*}} : vector<3xf32>, memref<?x?xf32>
@@ -56,7 +56,7 @@ func.func @hoist_vector_transfer_pairs(
"some_crippling_use"(%memref5) : (memref<?x?xf32>) -> ()
%u0 = "some_use"(%r0) : (vector<1xf32>) -> vector<1xf32>
%u1 = "some_use"(%r1) : (vector<2xf32>) -> vector<2xf32>
- %u2 = "some_use"(%memref2) : (memref<?x?xf32>) -> vector<3xf32>
+ %u2 = "some_use"(%memref2, %r2) : (memref<?x?xf32>, vector<3xf32>) -> vector<3xf32>
%u3 = "some_use"(%r3) : (vector<4xf32>) -> vector<4xf32>
%u4 = "some_use"(%r4) : (vector<5xf32>) -> vector<5xf32>
%u5 = "some_use"(%r5) : (vector<6xf32>) -> vector<6xf32>
@@ -173,6 +173,51 @@ transform.sequence failures(propagate) {
// -----
+// CHECK-LABEL: func @hoist_vector_transfer_pairs_in_affine_loops(
+// CHECK-SAME: %[[MEMREF0:[a-zA-Z0-9]+]]: memref<64x64xi32>,
+// CHECK-SAME: %[[MEMREF1:[a-zA-Z0-9]+]]: memref<64x64xi32>,
+// CHECK-SAME: %[[MEMREF2:[a-zA-Z0-9]+]]: memref<64x64xi32>) {
+// CHECK: %[[C0:.*]] = arith.constant 0 : i32
+// CHECK: affine.for %[[I:.*]] = 0 to 64 {
+// CHECK: affine.for %[[J:.*]] = 0 to 64 step 16 {
+// CHECK: %[[R0:.*]] = vector.transfer_read %[[MEMREF2]][%[[I]], %[[J]]], %[[C0]] : memref<64x64xi32>, vector<16xi32>
+// CHECK: %[[R:.*]] = affine.for %[[K:.*]] = 0 to 64 iter_args(%[[ACC:.*]] = %[[R0]]) -> (vector<16xi32>) {
+// CHECK: %[[AV:.*]] = vector.transfer_read %[[MEMREF0]][%[[I]], %[[K]]], %[[C0]] {{.*}}: memref<64x64xi32>, vector<16xi32>
+// CHECK: %[[BV:.*]] = vector.transfer_read %[[MEMREF1]][%[[K]], %[[J]]], %[[C0]] {{.*}}: memref<64x64xi32>, vector<16xi32>
+// CHECK: %[[T0:.*]] = arith.muli %[[AV]], %[[BV]] : vector<16xi32>
+// CHECK: %[[T1:.*]] = arith.addi %[[ACC]], %[[T0]] : vector<16xi32>
+// CHECK: affine.yield %[[T1]] : vector<16xi32>
+// CHECK: }
+// CHECK: vector.transfer_write %[[R]], %[[MEMREF2]][%[[I]], %[[J]]] : vector<16xi32>, memref<64x64xi32>
+// CHECK: }
+// CHECK: }
+func.func @hoist_vector_transfer_pairs_in_affine_loops(%memref0: memref<64x64xi32>, %memref1: memref<64x64xi32>, %memref2: memref<64x64xi32>) {
+ %c0_i32 = arith.constant 0 : i32
+ affine.for %arg3 = 0 to 64 {
+ affine.for %arg4 = 0 to 64 step 16 {
+ affine.for %arg5 = 0 to 64 {
+ %0 = vector.transfer_read %memref0[%arg3, %arg5], %c0_i32 {permutation_map = affine_map<(d0, d1) -> (0)>} : memref<64x64xi32>, vector<16xi32>
+ %1 = vector.transfer_read %memref1[%arg5, %arg4], %c0_i32 : memref<64x64xi32>, vector<16xi32>
+ %2 = vector.transfer_read %memref2[%arg3, %arg4], %c0_i32 : memref<64x64xi32>, vector<16xi32>
+ %3 = arith.muli %0, %1 : vector<16xi32>
+ %4 = arith.addi %2, %3 : vector<16xi32>
+ vector.transfer_write %4, %memref2[%arg3, %arg4] : vector<16xi32>, memref<64x64xi32>
+ }
+ }
+ }
+ return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["func.func"]} in %arg1
+ : (!pdl.operation) -> !pdl.operation
+ transform.structured.hoist_redundant_vector_transfers %0
+ : (!pdl.operation) -> !pdl.operation
+}
+
+// -----
+
// CHECK-LABEL: func @hoist_vector_transfer_pairs_tensor
func.func @hoist_vector_transfer_pairs_tensor(
%tensor0: tensor<?x?xf32>, %tensor1: tensor<?x?xf32>, %tensor2: tensor<?x?xf32>,
@@ -256,7 +301,7 @@ transform.sequence failures(propagate) {
^bb1(%arg1: !pdl.operation):
%0 = transform.structured.match ops{["func.func"]} in %arg1
: (!pdl.operation) -> !pdl.operation
- transform.structured.hoist_redundant_vector_transfers %0
+ transform.structured.hoist_redundant_tensor_subsets %0
: (!pdl.operation) -> !pdl.operation
}
@@ -351,7 +396,7 @@ transform.sequence failures(propagate) {
^bb1(%arg1: !pdl.operation):
%0 = transform.structured.match ops{["func.func"]} in %arg1
: (!pdl.operation) -> !pdl.operation
- transform.structured.hoist_redundant_vector_transfers %0
+ transform.structured.hoist_redundant_tensor_subsets %0
: (!pdl.operation) -> !pdl.operation
}
@@ -468,7 +513,7 @@ transform.sequence failures(propagate) {
^bb1(%arg1: !pdl.operation):
%0 = transform.structured.match ops{["func.func"]} in %arg1
: (!pdl.operation) -> !pdl.operation
- transform.structured.hoist_redundant_vector_transfers %0
+ transform.structured.hoist_redundant_tensor_subsets %0
: (!pdl.operation) -> !pdl.operation
}
@@ -501,6 +546,8 @@ func.func @hoist_vector_transfer_write_pairs_disjoint_tensor(
%r00 = vector.transfer_read %arg5[%c0, %c0], %cst: tensor<?x?xf32>, vector<2xf32>
%u00 = "some_use"(%r00) : (vector<2xf32>) -> vector<2xf32>
%w10 = vector.transfer_write %u00, %arg5[%c0, %c0] : vector<2xf32>, tensor<?x?xf32>
+
+ // Hoists by bypassing the disjoint write %w10 just above.
%r01 = vector.transfer_read %w10[%c0, %c3], %cst: tensor<?x?xf32>, vector<2xf32>
%u01 = "some_use"(%r01) : (vector<2xf32>) -> vector<2xf32>
%w11 = vector.transfer_write %u01, %w10[%c0, %c3] : vector<2xf32>, tensor<?x?xf32>
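The bypass in this hunk depends on the two transfers being provably disjoint: %w10 writes 2 elements starting at column %c0 (covering [0, 2)), while %r01 reads 2 elements starting at column %c3 (covering [3, 5)), so hoisting the %r01/%w11 pair past %w10 is safe. How the pass establishes disjointness is not shown in this hunk; the snippet below is only an illustrative standalone interval check under the assumption of static offsets and lengths.

#include <cassert>
#include <cstdint>

// Half-open 1-D slice [offset, offset + length).
struct Slice1D {
  int64_t offset;
  int64_t length;
};

// Two half-open intervals are disjoint iff one ends at or before the other starts.
static bool areDisjoint(Slice1D a, Slice1D b) {
  return a.offset + a.length <= b.offset || b.offset + b.length <= a.offset;
}

int main() {
  Slice1D write{/*offset=*/0, /*length=*/2}; // %w10 writes columns [0, 2)
  Slice1D read{/*offset=*/3, /*length=*/2};  // %r01 reads columns [3, 5)
  assert(areDisjoint(write, read)); // safe to bypass the write when hoisting
  Slice1D overlapping{/*offset=*/1, /*length=*/2}; // columns [1, 3) would conflict
  assert(!areDisjoint(write, overlapping));
  return 0;
}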
@@ -513,51 +560,119 @@ transform.sequence failures(propagate) {
^bb1(%arg1: !pdl.operation):
%0 = transform.structured.match ops{["func.func"]} in %arg1
: (!pdl.operation) -> !pdl.operation
- transform.structured.hoist_redundant_vector_transfers %0
+ transform.structured.hoist_redundant_tensor_subsets %0
: (!pdl.operation) -> !pdl.operation
}
// -----
-// CHECK-LABEL: func @hoist_vector_transfer_pairs_in_affine_loops(
-// CHECK-SAME: %[[MEMREF0:[a-zA-Z0-9]+]]: memref<64x64xi32>,
-// CHECK-SAME: %[[MEMREF1:[a-zA-Z0-9]+]]: memref<64x64xi32>,
-// CHECK-SAME: %[[MEMREF2:[a-zA-Z0-9]+]]: memref<64x64xi32>) {
-// CHECK: %[[C0:.*]] = arith.constant 0 : i32
-// CHECK: affine.for %[[I:.*]] = 0 to 64 {
-// CHECK: affine.for %[[J:.*]] = 0 to 64 step 16 {
-// CHECK: %[[R0:.*]] = vector.transfer_read %[[MEMREF2]][%[[I]], %[[J]]], %[[C0]] : memref<64x64xi32>, vector<16xi32>
-// CHECK: %[[R:.*]] = affine.for %[[K:.*]] = 0 to 64 iter_args(%[[ACC:.*]] = %[[R0]]) -> (vector<16xi32>) {
-// CHECK: %[[AV:.*]] = vector.transfer_read %[[MEMREF0]][%[[I]], %[[K]]], %[[C0]] {{.*}}: memref<64x64xi32>, vector<16xi32>
-// CHECK: %[[BV:.*]] = vector.transfer_read %[[MEMREF1]][%[[K]], %[[J]]], %[[C0]] {{.*}}: memref<64x64xi32>, vector<16xi32>
-// CHECK: %[[T0:.*]] = arith.muli %[[AV]], %[[BV]] : vector<16xi32>
-// CHECK: %[[T1:.*]] = arith.addi %[[ACC]], %[[T0]] : vector<16xi32>
-// CHECK: affine.yield %[[T1]] : vector<16xi32>
-// CHECK: }
-// CHECK: vector.transfer_write %[[R]], %[[MEMREF2]][%[[I]], %[[J]]] : vector<16xi32>, memref<64x64xi32>
-// CHECK: }
-// CHECK: }
-func.func @hoist_vector_transfer_pairs_in_affine_loops(%memref0: memref<64x64xi32>, %memref1: memref<64x64xi32>, %memref2: memref<64x64xi32>) {
- %c0_i32 = arith.constant 0 : i32
- affine.for %arg3 = 0 to 64 {
- affine.for %arg4 = 0 to 64 step 16 {
- affine.for %arg5 = 0 to 64 {
- %0 = vector.transfer_read %memref0[%arg3, %arg5], %c0_i32 {permutation_map = affine_map<(d0, d1) -> (0)>} : memref<64x64xi32>, vector<16xi32>
- %1 = vector.transfer_read %memref1[%arg5, %arg4], %c0_i32 : memref<64x64xi32>, vector<16xi32>
- %2 = vector.transfer_read %memref2[%arg3, %arg4], %c0_i32 : memref<64x64xi32>, vector<16xi32>
- %3 = arith.muli %0, %1 : vector<16xi32>
- %4 = arith.addi %2, %3 : vector<16xi32>
- vector.transfer_write %4, %memref2[%arg3, %arg4] : vector<16xi32>, memref<64x64xi32>
- }
+// CHECK-LABEL: func @hoist_vector_transfer_pairs_tensor_and_slices_static_large_tensor
+// CHECK-SAME: %[[TENSOR0:[a-zA-Z0-9]*]]: tensor<100x100xf32>,
+// CHECK-SAME: %[[TENSOR1:[a-zA-Z0-9]*]]: tensor<200x200xf32>,
+// CHECK-SAME: %[[TENSOR2:[a-zA-Z0-9]*]]: tensor<300x300xf32>
+func.func @hoist_vector_transfer_pairs_tensor_and_slices_static_large_tensor(
+ %tensor0: tensor<100x100xf32>, %tensor1: tensor<200x200xf32>, %tensor2: tensor<300x300xf32>,
+ %val: index, %lb : index, %ub : index, %step: index) ->
+ (
+ tensor<100x100xf32>, tensor<200x200xf32>, tensor<300x300xf32>
+ ) {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0.0 : f32
+
+ // CHECK: scf.for %[[I:.*]] = {{.*}} iter_args(
+ // CHECK-SAME: %[[TENSOR0_ARG:[0-9a-zA-Z]+]] = %[[TENSOR0]],
+ // CHECK-SAME: %[[TENSOR1_ARG:[0-9a-zA-Z]+]] = %[[TENSOR1]],
+ // CHECK-SAME: %[[TENSOR2_ARG:[0-9a-zA-Z]+]] = %[[TENSOR2]]
+ // CHECK-SAME: ) ->
+ // CHECK-SAME: (tensor<100x100xf32>, tensor<200x200xf32>, tensor<300x300xf32>
+ %0:3 = scf.for %i = %lb to %ub step %step
+ iter_args(%arg0 = %tensor0, %arg1 = %tensor1, %arg2 = %tensor2)
+ -> (tensor<100x100xf32>, tensor<200x200xf32>, tensor<300x300xf32>) {
+
+ // Hoisted
+ // CHECK: %[[ST0:.*]] = tensor.extract_slice %[[TENSOR0_ARG]][%[[I]], %[[I]]]{{.*}}: tensor<100x100xf32> to tensor<?x?xf32>
+ // CHECK: %[[V0:.*]] = vector.transfer_read %[[ST0]]{{.*}} : tensor<?x?xf32>, vector<1xf32>
+
+ // CHECK: %[[R:.*]]:3 = scf.for %[[J:.*]] = {{.*}} iter_args(
+ // CHECK-SAME: %[[TENSOR1_ARG_L2:[0-9a-zA-Z]+]] = %[[TENSOR1_ARG]]
+ // CHECK-SAME: %[[TENSOR2_ARG_L2:[0-9a-zA-Z]+]] = %[[TENSOR2_ARG]]
+ // CHECK-SAME: %[[V0_ARG_L2:[0-9a-zA-Z]+]] = %[[V0]]
+ // CHECK-SAME: ) ->
+ // CHECK-SAME: (tensor<200x200xf32>, tensor<300x300xf32>, vector<1xf32>
+ %1:3 = scf.for %j = %lb to %ub step %step
+ iter_args(%arg6 = %arg0, %arg7 = %arg1, %arg8 = %arg2)
+ -> (tensor<100x100xf32>, tensor<200x200xf32>, tensor<300x300xf32>) {
+ // Hoists.
+ %st0 = tensor.extract_slice %arg6[%i, %i][%step, %step][1, 1] : tensor<100x100xf32> to tensor<?x?xf32>
+ %r0 = vector.transfer_read %st0[%c0, %c0], %cst: tensor<?x?xf32>, vector<1xf32>
+
+ // CHECK: %[[ST1:.*]] = tensor.extract_slice %[[TENSOR1_ARG_L2]][%[[J]],{{.*}}: tensor<200x200xf32> to tensor<?x?xf32>
+ // CHECK: %[[V1:.*]] = vector.transfer_read %[[ST1]]{{.*}} : tensor<?x?xf32>, vector<2xf32>
+ // Does not hoist (slice depends on %j)
+ %st1 = tensor.extract_slice %arg7[%j, %c0][%step, %step][1, 1] : tensor<200x200xf32> to tensor<?x?xf32>
+ %r1 = vector.transfer_read %st1[%c0, %c0], %cst: tensor<?x?xf32>, vector<2xf32>
+
+ // CHECK: %[[ST2:.*]] = tensor.extract_slice %[[TENSOR2_ARG_L2]][%[[I]],{{.*}}: tensor<300x300xf32> to tensor<?x?xf32>
+ // CHECK: %[[V2:.*]] = vector.transfer_read %[[ST2]]{{.*}} : tensor<?x?xf32>, vector<3xf32>
+ // Does not hoist: %arg8 has 2 extract_slice / insert_slice pairs.
+ %st2 = tensor.extract_slice %arg8[%i, %c0][%step, %step][1, 1] : tensor<300x300xf32> to tensor<?x?xf32>
+ %r2 = vector.transfer_read %st2[%c0, %c0], %cst: tensor<?x?xf32>, vector<3xf32>
+
+ // CHECK: %[[U0:.*]] = "some_use"(%[[V0_ARG_L2]]) : (vector<1xf32>) -> vector<1xf32>
+ // CHECK: %[[U1:.*]] = "some_use"(%[[V1]]) : (vector<2xf32>) -> vector<2xf32>
+ // CHECK: %[[U2:.*]] = "some_use"(%[[V2]]) : (vector<3xf32>) -> vector<3xf32>
+ %u0 = "some_use"(%r0) : (vector<1xf32>) -> vector<1xf32>
+ %u1 = "some_use"(%r1) : (vector<2xf32>) -> vector<2xf32>
+ %u2 = "some_use"(%r2) : (vector<3xf32>) -> vector<3xf32>
+
+ // Hoists
+ %w0 = vector.transfer_write %u0, %st0[%c0, %c0] : vector<1xf32>, tensor<?x?xf32>
+
+ // CHECK-DAG: %[[STI1:.*]] = vector.transfer_write %[[U1]], %{{.*}} : vector<2xf32>, tensor<?x?xf32>
+ // Does not hoist (associated slice depends on %j).
+ %w1 = vector.transfer_write %u1, %st1[%i, %i] : vector<2xf32>, tensor<?x?xf32>
+
+ // CHECK-DAG: %[[STI2:.*]] = vector.transfer_write %[[U2]], %{{.*}} : vector<3xf32>, tensor<?x?xf32>
+ // Does not hoist: %arg8 has 2 slice / insert_slice pairs.
+ %w2 = vector.transfer_write %u2, %st2[%c0, %c0] : vector<3xf32>, tensor<?x?xf32>
+
+ // Hoists.
+ %sti0 = tensor.insert_slice %w0 into %arg6[%i, %i][%step, %step][1, 1] : tensor<?x?xf32> into tensor<100x100xf32>
+
+ // CHECK-DAG: tensor.insert_slice %[[STI1]] into %[[TENSOR1_ARG_L2]][%[[J]],{{.*}}: tensor<?x?xf32> into tensor<200x200xf32>
+ // Does not hoist (depends on %j).
+ %sti1 = tensor.insert_slice %w1 into %arg7[%j, %c0][%step, %step][1, 1] : tensor<?x?xf32> into tensor<200x200xf32>
+
+ // CHECK-DAG: tensor.insert_slice %[[STI2]] into %[[TENSOR2_ARG_L2]][%[[I]],{{.*}}: tensor<?x?xf32> into tensor<300x300xf32>
+ // Does not hoist: %arg8 has 2 slice / insert_slice pairs.
+ %sti2 = tensor.insert_slice %w2 into %arg8[%i, %c0][%step, %step][1, 1] : tensor<?x?xf32> into tensor<300x300xf32>
+ // Extract with a different stride to make sure we cannot fold this extract with the above insert.
+ %st22 = tensor.extract_slice %sti2[%i, %c0][%step, %step][2, 1] : tensor<300x300xf32> to tensor<?x?xf32>
+ %sti22 = tensor.insert_slice %st22 into %arg8[%i, %c0][%step, %step][1, 1] : tensor<?x?xf32> into tensor<300x300xf32>
+
+ // CHECK: scf.yield {{.*}} : tensor<200x200xf32>, tensor<300x300xf32>, vector<1xf32>
+ // CHECK: }
+ scf.yield %sti0, %sti1, %sti22:
+ tensor<100x100xf32>, tensor<200x200xf32>, tensor<300x300xf32>
}
+
+ // Hoisted
+ // CHECK: %[[STI0:.*]] = vector.transfer_write %[[R]]#2, %[[ST0]]{{.*}} : vector<1xf32>, tensor<?x?xf32>
+ // CHECK: tensor.insert_slice %[[STI0]] into %[[TENSOR0_ARG]][%[[I]], %[[I]]]{{.*}} : tensor<?x?xf32> into tensor<100x100xf32>
+
+ // CHECK: scf.yield {{.*}} : tensor<100x100xf32>, tensor<200x200xf32>, tensor<300x300xf32>
+ scf.yield %1#0, %1#1, %1#2 :
+ tensor<100x100xf32>, tensor<200x200xf32>, tensor<300x300xf32>
+
+ // CHECK: }
}
- return
+ return %0#0, %0#1, %0#2 : tensor<100x100xf32>, tensor<200x200xf32>, tensor<300x300xf32>
}
transform.sequence failures(propagate) {
^bb1(%arg1: !pdl.operation):
%0 = transform.structured.match ops{["func.func"]} in %arg1
: (!pdl.operation) -> !pdl.operation
- transform.structured.hoist_redundant_vector_transfers %0
+ transform.structured.hoist_redundant_tensor_subsets %0
: (!pdl.operation) -> !pdl.operation
}