[Mlir-commits] [mlir] [mlir][transform] LISH: Add transform op (PR #70630)
Matthias Springer
llvmlistbot at llvm.org
Sun Oct 29 23:53:33 PDT 2023
https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/70630
Add a transform op for loop-invariant subset hoisting. Delete the old transform op from the Linalg dialect.
Depends on #70535, #70617, #70619, #70623, #70628, #70629. Only review the top commit.
>From d5a79135800b91462c8e4563af0e7295f80743f5 Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Mon, 30 Oct 2023 14:32:21 +0900
Subject: [PATCH 1/6] [mlir] `SubsetOpInterface` and
`SubsetExtractionOpInterface`
---
.../Dialect/Bufferization/IR/Bufferization.h | 2 +-
.../Bufferization/IR/BufferizationOps.td | 3 +-
.../SubsetInsertionOpInterfaceImpl.h | 3 +-
.../SubsetInsertionOpInterfaceImpl.h | 3 +-
mlir/include/mlir/InitAllDialects.h | 4 +-
mlir/include/mlir/Interfaces/CMakeLists.txt | 2 +-
.../Interfaces/SubsetInsertionOpInterface.h | 27 --
.../Interfaces/SubsetInsertionOpInterface.td | 155 ----------
.../mlir/Interfaces/SubsetOpInterface.h | 45 +++
.../mlir/Interfaces/SubsetOpInterface.td | 267 ++++++++++++++++++
.../mlir/Interfaces/ValueBoundsOpInterface.h | 36 +--
.../Bufferization/IR/BufferizationOps.cpp | 12 +
.../Dialect/Bufferization/IR/CMakeLists.txt | 2 +-
.../Bufferization/Transforms/CMakeLists.txt | 2 +-
.../Transforms/EmptyTensorElimination.cpp | 2 +-
.../Transforms/OneShotAnalysis.cpp | 2 +-
.../Dialect/Linalg/Transforms/CMakeLists.txt | 2 +-
.../SubsetInsertionOpInterfaceImpl.cpp | 31 +-
.../BufferizableOpInterfaceImpl.cpp | 2 +-
.../Dialect/Tensor/Transforms/CMakeLists.txt | 2 +-
.../SubsetInsertionOpInterfaceImpl.cpp | 137 +++++++--
mlir/lib/Interfaces/CMakeLists.txt | 9 +-
.../Interfaces/SubsetInsertionOpInterface.cpp | 23 --
mlir/lib/Interfaces/SubsetOpInterface.cpp | 58 ++++
.../lib/Interfaces/ValueBoundsOpInterface.cpp | 76 +++++
.../llvm-project-overlay/mlir/BUILD.bazel | 33 +--
.../mlir/python/BUILD.bazel | 2 +-
27 files changed, 654 insertions(+), 288 deletions(-)
delete mode 100644 mlir/include/mlir/Interfaces/SubsetInsertionOpInterface.h
delete mode 100644 mlir/include/mlir/Interfaces/SubsetInsertionOpInterface.td
create mode 100644 mlir/include/mlir/Interfaces/SubsetOpInterface.h
create mode 100644 mlir/include/mlir/Interfaces/SubsetOpInterface.td
delete mode 100644 mlir/lib/Interfaces/SubsetInsertionOpInterface.cpp
create mode 100644 mlir/lib/Interfaces/SubsetOpInterface.cpp
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h b/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h
index c035190f43e3950..e98b5728b38ef81 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h
@@ -15,7 +15,7 @@
#include "mlir/Interfaces/CopyOpInterface.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
-#include "mlir/Interfaces/SubsetInsertionOpInterface.h"
+#include "mlir/Interfaces/SubsetOpInterface.h"
//===----------------------------------------------------------------------===//
// Bufferization Dialect
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index 72a4aa712f49c98..e6b6d052df96a8c 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -15,7 +15,7 @@ include "mlir/Dialect/Bufferization/IR/BufferizationBase.td"
include "mlir/Interfaces/DestinationStyleOpInterface.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
-include "mlir/Interfaces/SubsetInsertionOpInterface.td"
+include "mlir/Interfaces/SubsetOpInterface.td"
include "mlir/Interfaces/CopyOpInterface.td"
class Bufferization_Op<string mnemonic, list<Trait> traits = []>
@@ -220,6 +220,7 @@ def Bufferization_MaterializeInDestinationOp
AllElementTypesMatch<["source", "dest"]>,
BufferizableOpInterface, DestinationStyleOpInterface,
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
+ DeclareOpInterfaceMethods<SubsetOpInterface>,
DeclareOpInterfaceMethods<SubsetInsertionOpInterface,
["getSourceOperand", "getValuesNeededToBuildSubsetExtraction",
"buildSubsetExtraction", "isEquivalentSubset"]>,
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/SubsetInsertionOpInterfaceImpl.h b/mlir/include/mlir/Dialect/Linalg/Transforms/SubsetInsertionOpInterfaceImpl.h
index 023a46df2620109..94b0fb25b506650 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/SubsetInsertionOpInterfaceImpl.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/SubsetInsertionOpInterfaceImpl.h
@@ -13,8 +13,7 @@ namespace mlir {
class DialectRegistry;
namespace linalg {
-void registerSubsetInsertionOpInterfaceExternalModels(
- DialectRegistry &registry);
+void registerSubsetOpInterfaceExternalModels(DialectRegistry &registry);
} // namespace linalg
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.h b/mlir/include/mlir/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.h
index e21b07d8a2705a0..019da189a8c991b 100644
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.h
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.h
@@ -13,8 +13,7 @@ namespace mlir {
class DialectRegistry;
namespace tensor {
-void registerSubsetInsertionOpInterfaceExternalModels(
- DialectRegistry &registry);
+void registerSubsetOpInterfaceExternalModels(DialectRegistry &registry);
} // namespace tensor
} // namespace mlir
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index 00f400aab5d50a0..7c2ffb7408d9afd 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -151,7 +151,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
cf::registerBufferDeallocationOpInterfaceExternalModels(registry);
gpu::registerBufferDeallocationOpInterfaceExternalModels(registry);
linalg::registerBufferizableOpInterfaceExternalModels(registry);
- linalg::registerSubsetInsertionOpInterfaceExternalModels(registry);
+ linalg::registerSubsetOpInterfaceExternalModels(registry);
linalg::registerTilingInterfaceExternalModels(registry);
linalg::registerValueBoundsOpInterfaceExternalModels(registry);
memref::registerAllocationOpInterfaceExternalModels(registry);
@@ -167,7 +167,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
tensor::registerBufferizableOpInterfaceExternalModels(registry);
tensor::registerFindPayloadReplacementOpInterfaceExternalModels(registry);
tensor::registerInferTypeOpInterfaceExternalModels(registry);
- tensor::registerSubsetInsertionOpInterfaceExternalModels(registry);
+ tensor::registerSubsetOpInterfaceExternalModels(registry);
tensor::registerTilingInterfaceExternalModels(registry);
tensor::registerValueBoundsOpInterfaceExternalModels(registry);
vector::registerBufferizableOpInterfaceExternalModels(registry);
diff --git a/mlir/include/mlir/Interfaces/CMakeLists.txt b/mlir/include/mlir/Interfaces/CMakeLists.txt
index 36a04ff0eaeaf4b..d81298bb4daf014 100644
--- a/mlir/include/mlir/Interfaces/CMakeLists.txt
+++ b/mlir/include/mlir/Interfaces/CMakeLists.txt
@@ -12,7 +12,7 @@ add_mlir_interface(ParallelCombiningOpInterface)
add_mlir_interface(RuntimeVerifiableOpInterface)
add_mlir_interface(ShapedOpInterfaces)
add_mlir_interface(SideEffectInterfaces)
-add_mlir_interface(SubsetInsertionOpInterface)
+add_mlir_interface(SubsetOpInterface)
add_mlir_interface(TilingInterface)
add_mlir_interface(ValueBoundsOpInterface)
add_mlir_interface(VectorInterfaces)
diff --git a/mlir/include/mlir/Interfaces/SubsetInsertionOpInterface.h b/mlir/include/mlir/Interfaces/SubsetInsertionOpInterface.h
deleted file mode 100644
index 3a6dfceadcce7c0..000000000000000
--- a/mlir/include/mlir/Interfaces/SubsetInsertionOpInterface.h
+++ /dev/null
@@ -1,27 +0,0 @@
-//===- SubsetInsertionOpInterface.h - Tensor Subsets ------------*- C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_INTERFACES_SUBSETINSERTIONOPINTERFACE_H_
-#define MLIR_INTERFACES_SUBSETINSERTIONOPINTERFACE_H_
-
-#include "mlir/IR/OpDefinition.h"
-
-namespace mlir {
-namespace detail {
-
-/// Return the destination/"init" operand of the op if it implements the
-/// `DestinationStyleOpInterface` and has exactly one "init" operand. Asserts
-/// otherwise.
-OpOperand &defaultGetDestinationOperand(Operation *op);
-
-} // namespace detail
-} // namespace mlir
-
-#include "mlir/Interfaces/SubsetInsertionOpInterface.h.inc"
-
-#endif // MLIR_INTERFACES_SUBSETINSERTIONOPINTERFACE_H_
diff --git a/mlir/include/mlir/Interfaces/SubsetInsertionOpInterface.td b/mlir/include/mlir/Interfaces/SubsetInsertionOpInterface.td
deleted file mode 100644
index ef94a8ae9a60efd..000000000000000
--- a/mlir/include/mlir/Interfaces/SubsetInsertionOpInterface.td
+++ /dev/null
@@ -1,155 +0,0 @@
-//===-- SubsetInsertionOpInterface.td - Tensor Subsets -----*- tablegen -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef SUBSET_INSERTION_OP_INTERFACE
-#define SUBSET_INSERTION_OP_INTERFACE
-
-include "mlir/IR/OpBase.td"
-
-def SubsetInsertionOpInterface : OpInterface<"SubsetInsertionOpInterface"> {
- let description = [{
- This interface can be implemented by ops that insert a source tensor into
- a destination tensor.
-
- The elements in the destination tensor that are overwritten by this
- insertion are called the "subset". How the subset is defined is up to the
- op. E.g., "tensor.insert_slice" defines the subset via a hyperrectangular
- slice. A scatter operation could define the subset via a list of indices.
-
- Ops that deal with tensor subsets come in two flavours:
- - Insertion flavor: Ops that insert a source tensor into a destination
- tensor at the specified subset. Such ops usually return a new destination
- tensor and implement the `DestinationStyleOpInterface`. Insertion ops can
- implement the `SubsetInsertionOpInterface`. Example: "tensor.insert_slice"
- - Extraction flavor: Ops that define a tensor subset. They extract a
- specified subset from a tensor. There is currently no op interface for
- such ops. Example: "tensor.extract_slice"
-
- This interface provides helper methods for efficient bufferization of
- subset-based tensor IR. Tensor subsets can bufferize to buffer "views"/
- "aliases" (in contrast to one or multiple less efficient buffer allocation).
-
- This interface is queried by One-Shot Bufferize to detect cases where a
- seeming read-after-write is not actually a conflict because the respective
- ops are operating on equivalent subsets. More details can be found in the
- documentation of One-Shot Analysis (see `areNonConflictingSubsets`).
-
- Note: This interface currently assumes that a subset op inserts a single
- tensor (source) into a destination tensor at a single subset.
- }];
- let cppNamespace = "::mlir";
- let methods = [
- InterfaceMethod<
- /*desc=*/[{
- Return the source tensor operand.
- }],
- /*retType=*/"::mlir::OpOperand &",
- /*methodName=*/"getSourceOperand",
- /*args=*/(ins)
- >,
- InterfaceMethod<
- /*desc=*/[{
- Return the destination tensor operand.
- }],
- /*retType=*/"::mlir::OpOperand &",
- /*methodName=*/"getDestinationOperand",
- /*args=*/(ins),
- /*methodBody=*/"",
- /*defaultImplementation=*/[{
- return ::mlir::detail::defaultGetDestinationOperand(
- $_op.getOperation());
- }]
- >,
- InterfaceMethod<
- /*desc=*/[{
- Return "true" if this operation inserts into a subset that is
- equivalent to the subset defined by `candidate`.
-
- Two subsets are "equivalent" and "same" if they can bufferize to the
- same buffer views/aliases. If they are "equivalent", the tensor IR
- may be expressed in terms of different SSA values (but they could
- bufferize to MemRef SSA values that can CSE without breaking
- correctness). `equivalenceFn` should return "true" if the two given
- values are equivalent.
-
- Example:
- ```
- // The subset of the SubsetInsertionOpInterface op %1 is equivalent to
- // the subset defined by %2 (but not "same"):
- %0 = arith.select %c, %t, %t : tensor<?xf32>
- %1 = tensor.insert_slice %x into %0[0][5][1]
- : tensor<5xf32> into tensor<?xf32>
- %2 = tensor.extract_slice %t[0][5][1] : tensor<?xf32> to tensor<5xf32>
-
- // The subset of the SubsetInsertionOpInterface op %1 is equivalent to
- // and "same" as the subset defined by %2.
- %1 = tensor.insert_slice %x into %t[0][5][1]
- : tensor<5xf32> into tensor<?xf32>
- %2 = tensor.extract_slice %t[0][5][1] : tensor<?xf32> to tensor<5xf32>
- ```
- }],
- /*retType=*/"bool",
- /*methodName=*/"isEquivalentSubset",
- /*args=*/(ins
- "::mlir::Value":$candidate,
- "::llvm::function_ref<bool(Value, Value)>":$equivalenceFn)
- >,
- InterfaceMethod<
- /*desc=*/[{
- Return the subset of the destination tensor that this operation
- inserts into.
-
- Example:
- ```
- // SubsetOpInterface op:
- %0 = tensor.insert_slice %t0 into %t1[%pos][5][1]
- : tensor<5xf32> into tensor<?xf32>
- // Subset (built by this function):
- %1 = tensor.extract_slice %t1[%pos][5][1]
- : tensor<?xf32> to tensor<5xf32>
- ```
-
- Note: Implementations do not necessarily have to build new IR. They
- may return existing SSA values.
- }],
- /*retType=*/"::mlir::Value",
- /*methodName=*/"buildSubsetExtraction",
- /*args=*/(ins "::mlir::OpBuilder &":$builder, "Location":$loc)
- >,
- InterfaceMethod<
- /*desc=*/[{
- Return all SSA values that are needed (i.e., must be in scope) at the
- insertion of the builder when calling `buildSubsetExtraction`. Users
- of `buildSubsetExtraction` can use this helper method to find a
- suitable insertion point.
-
- Example: The SSA values needed to build the subset in the example of
- `buildSubsetExtraction` are %t1 and %pos.
- }],
- /*retType=*/"::llvm::SmallVector<::mlir::Value>",
- /*methodName=*/"getValuesNeededToBuildSubsetExtraction",
- /*args=*/(ins)
- >,
- ];
-
- let extraClassDeclaration = [{
- /// Return "true" if this operation inserts into the same subset as defined
- /// by `candidate`.
- ///
- /// Note: This function is useful outside of bufferization, where no tensor
- /// equivalence information is available.
- bool isSameSubset(OpResult candidate) {
- auto subsetOp = cast<::mlir::SubsetInsertionOpInterface>(
- getOperation());
- return subsetOp.isEquivalentSubset(
- candidate, [](Value v1, Value v2) { return v1 == v2; });
- }
- }];
-}
-
-#endif // SUBSET_INSERTION_OP_INTERFACE
diff --git a/mlir/include/mlir/Interfaces/SubsetOpInterface.h b/mlir/include/mlir/Interfaces/SubsetOpInterface.h
new file mode 100644
index 000000000000000..049cf2456a9c842
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/SubsetOpInterface.h
@@ -0,0 +1,45 @@
+//===- SubsetOpInterface.h - Tensor Subsets ---------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_SUBSETOPINTERFACE_H_
+#define MLIR_INTERFACES_SUBSETOPINTERFACE_H_
+
+#include "mlir/IR/OpDefinition.h"
+
+namespace mlir {
+class SubsetOpInterface;
+class SubsetExtractionOpInterface;
+class SubsetInsertionOpInterface;
+
+namespace detail {
+
+/// Return the destination/"init" operand of the op if it implements the
+/// `DestinationStyleOpInterface` and has exactly one "init" operand. Asserts
+/// otherwise.
+OpOperand &defaultGetDestinationOperand(Operation *op);
+
+/// Return the updated destination result of the op if it implements the
+/// `DestinationStyleOpInterface`.
+OpResult defaultGetUpdatedDestination(Operation *op);
+
+/// Default implementation of `isEquivalentSubset`.
+bool defaultIsEquivalentSubset(Operation *op, Value candidate,
+ function_ref<bool(Value, Value)> equivalenceFn);
+
+/// Verify `SubsetOpInterface`.
+LogicalResult verifySubsetOpInterface(SubsetOpInterface op);
+
+/// Verify `SubsetExtractionOpInterface`.
+LogicalResult verifySubsetExtractionOpInterface(SubsetExtractionOpInterface op);
+
+} // namespace detail
+} // namespace mlir
+
+#include "mlir/Interfaces/SubsetOpInterface.h.inc"
+
+#endif // MLIR_INTERFACES_SUBSETOPINTERFACE_H_
diff --git a/mlir/include/mlir/Interfaces/SubsetOpInterface.td b/mlir/include/mlir/Interfaces/SubsetOpInterface.td
new file mode 100644
index 000000000000000..07d62b8319c2961
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/SubsetOpInterface.td
@@ -0,0 +1,267 @@
+//===-- SubsetOpInterface.td - Tensor Subsets --------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef SUBSET_OP_INTERFACE
+#define SUBSET_OP_INTERFACE
+
+include "mlir/IR/OpBase.td"
+
+def SubsetOpInterface : OpInterface<"SubsetOpInterface"> {
+ let description = [{
+ This interface can be implemented by ops that operate on tensor subsets. A
+ "subset" is a part of a tensor. This interface describes the subset that
+ an implementing op operates on. Only the specified subset may be accessed by
+ the op.
+
+ Subset ops come in two flavors, and ops that implement the
+ `SubsetOpInterface` must also implement one of the respective interfaces.
+ - Insertion flavor: Ops that insert a source value into a destination
+ tensor at the specified subset. Such ops return an updated destination
+ tensor and usually implement the `DestinationStyleOpInterface`. Insertion
+ ops must implement the `SubsetInsertionOpInterface`.
+ - Extraction flavor: Ops that extract at a subset. Extraction ops must
+ implement the `SubsetExtractionOpInterface`.
+
+ How the subset is specified is up to the implementing op. E.g.:
+ - `tensor.extract_slice/insert_slice` describe the subset as a
+ hyperrectangular slice.
+ - `tensor.gather/scatter` describe the subset as a list of indices. (Not
+ implemented yet.)
+
+ Note: This interface does not expose any interface methods to get a
+ description of the accessed subset. That is because there is currently no
+ efficient way to describe arbitrary subsets. This interface merely provides
+ interface methods to check if two subsets are equivalent or disjoint.
+ }];
+
+ let cppNamespace = "::mlir";
+ let methods = [
+ InterfaceMethod<
+ /*desc=*/[{
+ Return "true" if this op and the given candidate subset op operate on
+ an equivalent subset. Return "false" if the two subsets are disjoint
+ or cannot be proven to be equivalent.
+ }],
+ /*retType=*/"bool",
+ /*methodName=*/"operatesOnEquivalentSubset",
+ /*args=*/(ins
+ "::mlir::SubsetOpInterface":$candidate,
+ "::llvm::function_ref<bool(Value, Value)>":$equivalenceFn)
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Return "true" if this op and the given candidate subset op operate on
+ disjoint subsets. Return "false" if the two subsets are equivalent,
+ overlapping or cannot be proven to be disjoint.
+ }],
+ /*retType=*/"bool",
+ /*methodName=*/"operatesOnDisjointSubset",
+ /*args=*/(ins
+ "::mlir::SubsetOpInterface":$candidate,
+ "::llvm::function_ref<bool(Value, Value)>":$equivalenceFn)
+ >,
+ ];
+
+ let verify = [{
+ return ::mlir::detail::verifySubsetOpInterface(
+ ::mlir::cast<::mlir::SubsetOpInterface>($_op));
+ }];
+}
+
+def SubsetExtractionOpInterface
+ : OpInterface<"SubsetExtractionOpInterface", [SubsetOpInterface]> {
+ let description = [{
+ This interface can be implemented by ops that extract a value from
+ a source tensor at a specified subset. The elements in the source tensor
+ that are read by this extraction are called "subset".
+
+ Extraction ops must have a single result value.
+ }];
+
+ let cppNamespace = "::mlir";
+ let methods = [
+ InterfaceMethod<
+ /*desc=*/[{
+ Return the source tensor operand.
+ }],
+ /*retType=*/"::mlir::OpOperand &",
+ /*methodName=*/"getSourceOperand",
+ /*args=*/(ins)
+ >,
+ ];
+
+ let verify = [{
+ return ::mlir::detail::verifySubsetExtractionOpInterface(
+ ::mlir::cast<::mlir::SubsetExtractionOpInterface>($_op));
+ }];
+
+ let extraClassDeclaration = [{
+ /// Return the single result of this op.
+ ::mlir::Value getResult() {
+ return getOperation()->getResult(0);
+ }
+ }];
+}
+
+def SubsetInsertionOpInterface
+ : OpInterface<"SubsetInsertionOpInterface", [SubsetOpInterface]> {
+ let description = [{
+ This interface can be implemented by ops that insert a source value into
+ a destination tensor at a specified subset. The elements in the destination
+ tensor that are overwritten by this insertion are called "subset". The
+ updated destination tensor is returned.
+
+ This interface provides helper methods for efficient bufferization of
+ subset-based tensor IR. Tensor subsets can bufferize to buffer "views"/
+ "aliases" (instead of one or more less efficient buffer allocations).
+
+ This interface is queried by One-Shot Bufferize to detect cases where a
+ seeming read-after-write is not actually a conflict because the respective
+ ops are operating on equivalent subsets. More details can be found in the
+ documentation of One-Shot Analysis (see `areNonConflictingSubsets`).
+ }];
+
+ let cppNamespace = "::mlir";
+ let methods = [
+ InterfaceMethod<
+ /*desc=*/[{
+ Return the source operand. The source operand can have any type.
+ }],
+ /*retType=*/"::mlir::OpOperand &",
+ /*methodName=*/"getSourceOperand",
+ /*args=*/(ins)
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Return the destination operand. The destination operand must be a
+ tensor.
+
+ This function does not have to be implemented for destination style
+ ops that have exactly one "init".
+ }],
+ /*retType=*/"::mlir::OpOperand &",
+ /*methodName=*/"getDestinationOperand",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return ::mlir::detail::defaultGetDestinationOperand(
+ $_op.getOperation());
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Return the updated destination result.
+
+ This function does not have to be implemented for destination style
+ ops.
+ }],
+ /*retType=*/"::mlir::OpResult",
+ /*methodName=*/"getUpdatedDestination",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return ::mlir::detail::defaultGetUpdatedDestination(
+ $_op.getOperation());
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Return "true" if this operation inserts into a subset that is
+ equivalent to the subset defined by `candidate`.
+
+ Two subsets are "equivalent" and "same" if they can bufferize to the
+ same buffer views/aliases. If they are "equivalent", the tensor IR
+ may be expressed in terms of different SSA values (but they could
+ bufferize to MemRef SSA values that can CSE without breaking
+ correctness). `equivalenceFn` should return "true" if the two given
+ values are equivalent.
+
+ Example:
+ ```
+ // The subset of the SubsetInsertionOpInterface op %1 is equivalent to
+ // the subset defined by %2 (but not "same"):
+ %0 = arith.select %c, %t, %t : tensor<?xf32>
+ %1 = tensor.insert_slice %x into %0[0][5][1]
+ : tensor<5xf32> into tensor<?xf32>
+ %2 = tensor.extract_slice %t[0][5][1] : tensor<?xf32> to tensor<5xf32>
+
+ // The subset of the SubsetInsertionOpInterface op %1 is equivalent to
+ // and "same" as the subset defined by %2.
+ %1 = tensor.insert_slice %x into %t[0][5][1]
+ : tensor<5xf32> into tensor<?xf32>
+ %2 = tensor.extract_slice %t[0][5][1] : tensor<?xf32> to tensor<5xf32>
+ ```
+
+ Note: By default, this function calls
+ `SubsetOpInterface::operatesOnEquivalentSubset`.
+ }],
+ /*retType=*/"bool",
+ /*methodName=*/"isEquivalentSubset",
+ /*args=*/(ins
+ "::mlir::Value":$candidate,
+ "::llvm::function_ref<bool(Value, Value)>":$equivalenceFn),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return ::mlir::detail::defaultIsEquivalentSubset(
+ $_op.getOperation(), candidate, equivalenceFn);
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Return the subset of the destination tensor that this operation
+ inserts into.
+
+ Example:
+ ```
+ // SubsetOpInterface op:
+ %0 = tensor.insert_slice %t0 into %t1[%pos][5][1]
+ : tensor<5xf32> into tensor<?xf32>
+ // Subset (built by this function):
+ %1 = tensor.extract_slice %t1[%pos][5][1]
+ : tensor<?xf32> to tensor<5xf32>
+ ```
+
+ Note: Implementations do not necessarily have to build new IR. They
+ may return existing SSA values.
+ }],
+ /*retType=*/"::mlir::Value",
+ /*methodName=*/"buildSubsetExtraction",
+ /*args=*/(ins "::mlir::OpBuilder &":$builder, "Location":$loc)
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Return all SSA values that are needed (i.e., must be in scope) at the
+ insertion point of the builder when calling `buildSubsetExtraction`.
+ Users of `buildSubsetExtraction` can use this helper method to find a
+ suitable insertion point.
+
+ Example: The SSA values needed to build the subset in the example of
+ `buildSubsetExtraction` are %t1 and %pos.
+ }],
+ /*retType=*/"::llvm::SmallVector<::mlir::Value>",
+ /*methodName=*/"getValuesNeededToBuildSubsetExtraction",
+ /*args=*/(ins)
+ >,
+ ];
+
+ let extraClassDeclaration = [{
+ /// Return "true" if this operation inserts into the same subset as defined
+ /// by `candidate`.
+ ///
+ /// Note: This function is useful outside of bufferization, where no tensor
+ /// equivalence information is available.
+ bool isSameSubset(OpResult candidate) {
+ auto subsetOp = cast<::mlir::SubsetInsertionOpInterface>(
+ getOperation());
+ return subsetOp.isEquivalentSubset(
+ candidate, [](Value v1, Value v2) { return v1 == v2; });
+ }
+ }];
+}
+
+#endif // SUBSET_OP_INTERFACE
diff --git a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
index 8f11c563e0cbd91..8e2986a2d1f05f6 100644
--- a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
+++ b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
@@ -19,6 +19,7 @@
#include <queue>
namespace mlir {
+class OffsetSizeAndStrideOpInterface;
using ValueDimList = SmallVector<std::pair<Value, std::optional<int64_t>>>;
@@ -134,11 +135,11 @@ class ValueBoundsConstraintSet {
std::optional<int64_t> dim, ValueRange independencies,
bool closedUB = false);
- /// Compute a constant bound for the given index-typed value or shape
- /// dimension size.
+ /// Compute a constant bound for the given affine map, where dims and symbols
+ /// are bound to the given operands. The affine map must have exactly one
+ /// result.
///
- /// `dim` must be `nullopt` if and only if `value` is index-typed. This
- /// function traverses the backward slice of the given value in a
+ /// This function traverses the backward slice of the given operands in a
/// worklist-driven manner until `stopCondition` evaluates to "true". The
/// constraint set is populated according to `ValueBoundsOpInterface` for each
/// visited value. (No constraints are added for values for which the stop
@@ -155,26 +156,12 @@ class ValueBoundsConstraintSet {
std::optional<int64_t> dim = std::nullopt,
StopConditionFn stopCondition = nullptr,
bool closedUB = false);
-
- /// Compute a constant bound for the given affine map, where dims and symbols
- /// are bound to the given operands. The affine map must have exactly one
- /// result.
- ///
- /// This function traverses the backward slice of the given operands in a
- /// worklist-driven manner until `stopCondition` evaluates to "true". The
- /// constraint set is populated according to `ValueBoundsOpInterface` for each
- /// visited value. (No constraints are added for values for which the stop
- /// condition evaluates to "true".)
- ///
- /// The stop condition is optional: If none is specified, the backward slice
- /// is traversed in a breadth-first manner until a constant bound could be
- /// computed.
- ///
- /// By default, lower/equal bounds are closed and upper bounds are open. If
- /// `closedUB` is set to "true", upper bounds are also closed.
static FailureOr<int64_t> computeConstantBound(
presburger::BoundType type, AffineMap map, ValueDimList mapOperands,
StopConditionFn stopCondition = nullptr, bool closedUB = false);
+ static FailureOr<int64_t> computeConstantBound(
+ presburger::BoundType type, AffineMap map, ArrayRef<Value> mapOperands,
+ StopConditionFn stopCondition = nullptr, bool closedUB = false);
/// Compute a constant delta between the given two values. Return "failure"
/// if a constant delta could not be determined.
@@ -195,6 +182,13 @@ class ValueBoundsConstraintSet {
std::optional<int64_t> dim1 = std::nullopt,
std::optional<int64_t> dim2 = std::nullopt);
+ /// Return "true" if the given slices are guaranteed to be overlapping.
+ /// Return "false" if the given slices are guaranteed to be non-overlapping.
+ /// Return "failure" if unknown.
+ static FailureOr<bool>
+ areOverlappingSlices(OffsetSizeAndStrideOpInterface slice1,
+ OffsetSizeAndStrideOpInterface slice2);
+
/// Add a bound for the given index-typed value or shaped value. This function
/// returns a builder that adds the bound.
BoundBuilder bound(Value value) { return BoundBuilder(*this, value); }
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 52ff6ceeee85b03..8f19245efdba6c8 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -643,6 +643,18 @@ OpOperand &MaterializeInDestinationOp::getSourceOperand() {
return getOperation()->getOpOperand(0) /*source*/;
}
+bool MaterializeInDestinationOp::operatesOnEquivalentSubset(
+ SubsetOpInterface subsetOp,
+ function_ref<bool(Value, Value)> equivalenceFn) {
+ return false;
+}
+
+bool MaterializeInDestinationOp::operatesOnDisjointSubset(
+ SubsetOpInterface subsetOp,
+ function_ref<bool(Value, Value)> equivalenceFn) {
+ return false;
+}
+
LogicalResult MaterializeInDestinationOp::verify() {
if (!isa<TensorType, BaseMemRefType>(getDest().getType()))
return emitOpError("'dest' must be a tensor or a memref");
diff --git a/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt b/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt
index 385d8dc9364e379..9895db9d93ce0bb 100644
--- a/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt
@@ -23,7 +23,7 @@ add_mlir_dialect_library(MLIRBufferizationDialect
MLIRFunctionInterfaces
MLIRIR
MLIRSparseTensorDialect
- MLIRSubsetInsertionOpInterface
+ MLIRSubsetOpInterface
MLIRTensorDialect
MLIRMemRefDialect
)
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt
index a6876c7c824e0ce..8617c17e7a5e5e5 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt
@@ -36,7 +36,7 @@ add_mlir_dialect_library(MLIRBufferizationTransforms
MLIRTensorDialect
MLIRSCFDialect
MLIRSideEffectInterfaces
- MLIRSubsetInsertionOpInterface
+ MLIRSubsetOpInterface
MLIRTransforms
MLIRViewLikeInterface
MLIRSupport
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
index 6622cfefa76a26f..4a418a05e6ff565 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
@@ -15,7 +15,7 @@
#include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Dominance.h"
-#include "mlir/Interfaces/SubsetInsertionOpInterface.h"
+#include "mlir/Interfaces/SubsetOpInterface.h"
#include "mlir/Pass/Pass.h"
namespace mlir {
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
index 0bbfdba2b6e6ef9..c48402c3742a77b 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
@@ -54,7 +54,7 @@
#include "mlir/IR/Operation.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
-#include "mlir/Interfaces/SubsetInsertionOpInterface.h"
+#include "mlir/Interfaces/SubsetOpInterface.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/SetVector.h"
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index bad246c262979b7..bb90e5ee546d113 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -66,7 +66,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
MLIRSCFTransforms
MLIRSCFUtils
MLIRPass
- MLIRSubsetInsertionOpInterface
+ MLIRSubsetOpInterface
MLIRSparseTensorDialect
MLIRTensorDialect
MLIRTensorTilingInterfaceImpl
diff --git a/mlir/lib/Dialect/Linalg/Transforms/SubsetInsertionOpInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/SubsetInsertionOpInterfaceImpl.cpp
index e0819082102ef66..e9fe93efceec356 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/SubsetInsertionOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/SubsetInsertionOpInterfaceImpl.cpp
@@ -9,12 +9,38 @@
#include "mlir/Dialect/Linalg/Transforms/SubsetInsertionOpInterfaceImpl.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
-#include "mlir/Interfaces/SubsetInsertionOpInterface.h"
+#include "mlir/Interfaces/SubsetOpInterface.h"
using namespace mlir;
using namespace mlir::linalg;
namespace {
+/// SubsetOpInterface external model for linalg.copy.
+///
+/// linalg.copy operates on the entire source tensor. This interface is
+/// currently needed only for One-Shot Bufferize, which only uses
+/// `SubsetInsertionOpInterface`, so neither query below is exercised yet. In
+/// the absence of an analysis, "false" is the conservative answer for both.
+struct LinalgCopyOpSubsetOpInterface
+ : public SubsetOpInterface::ExternalModel<LinalgCopyOpSubsetOpInterface,
+ linalg::CopyOp> {
+ /// Never claims to operate on a subset equivalent to `candidate`'s.
+ bool operatesOnEquivalentSubset(
+ Operation * /*op*/, SubsetOpInterface /*candidate*/,
+ function_ref<bool(Value, Value)> /*equivalenceFn*/) const {
+ return false;
+ }
+
+ /// Never claims to operate on a subset disjoint from `candidate`'s.
+ bool operatesOnDisjointSubset(
+ Operation * /*op*/, SubsetOpInterface /*candidate*/,
+ function_ref<bool(Value, Value)> /*equivalenceFn*/) const {
+ return false;
+ }
+};
+
struct LinalgCopyOpInterface
: public SubsetInsertionOpInterface::ExternalModel<LinalgCopyOpInterface,
linalg::CopyOp> {
@@ -48,9 +74,10 @@ struct LinalgCopyOpInterface
};
} // namespace
-void mlir::linalg::registerSubsetInsertionOpInterfaceExternalModels(
+void mlir::linalg::registerSubsetOpInterfaceExternalModels(
DialectRegistry &registry) {
registry.addExtension(+[](MLIRContext *ctx, linalg::LinalgDialect *dialect) {
+ linalg::CopyOp::attachInterface<LinalgCopyOpSubsetOpInterface>(*ctx);
linalg::CopyOp::attachInterface<LinalgCopyOpInterface>(*ctx);
});
}
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index a95443db88b50b2..2cd57e7324b4dc5 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -1066,5 +1066,5 @@ void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
// Bufferization requires SubsetInsertionOpInterface models. Make sure that
// they are registered.
- tensor::registerSubsetInsertionOpInterfaceExternalModels(registry);
+ tensor::registerSubsetOpInterfaceExternalModels(registry);
}
diff --git a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
index a0c172ac52e4be8..c5fd4e65bbf7028 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
@@ -30,7 +30,7 @@ add_mlir_dialect_library(MLIRTensorTransforms
MLIRMemRefDialect
MLIRPass
MLIRSCFDialect
- MLIRSubsetInsertionOpInterface
+ MLIRSubsetOpInterface
MLIRTensorDialect
MLIRTensorUtils
MLIRTilingInterface
diff --git a/mlir/lib/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.cpp
index dbda9953684f41d..7a1bafd409eea60 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.cpp
@@ -9,17 +9,115 @@
#include "mlir/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
-#include "mlir/Interfaces/SubsetInsertionOpInterface.h"
+#include "mlir/Interfaces/SubsetOpInterface.h"
+#include "mlir/Interfaces/ValueBoundsOpInterface.h"
using namespace mlir;
using namespace mlir::tensor;
namespace {
+/// Returns the tensor operand a subset op operates on: the source operand of
+/// an extraction op, or the destination operand of an insertion op.
+Value getContainerOperand(SubsetOpInterface op) {
+ Operation *subsetOp = op.getOperation();
+ if (auto extractionOp = dyn_cast<SubsetExtractionOpInterface>(subsetOp))
+ return extractionOp.getSourceOperand().get();
+ if (auto insertionOp = dyn_cast<SubsetInsertionOpInterface>(subsetOp))
+ return insertionOp.getDestinationOperand().get();
+ llvm_unreachable("expected SubsetExtraction/InsertionOpInterface");
+}
+
+/// Return "true" if the two ops operate on an equivalent subset.
+/// `equivalenceFn` is used to determine equivalence of tensors. Return "false"
+/// if the two ops operate on non-equivalent subsets, if equivalence cannot be
+/// determined or if `op2` does not implement OffsetSizeAndStrideOpInterface.
+/// `op1` must implement SubsetOpInterface (it is `cast` below).
+template <typename OpTy>
+bool operateOnEquivalentSubsets(
+ OpTy op1, SubsetOpInterface op2,
+ function_ref<bool(Value, Value)> equivalenceFn) {
+ auto offsetsSizesAndStrides2 =
+ dyn_cast<OffsetSizeAndStrideOpInterface>(op2.getOperation());
+ if (!offsetsSizesAndStrides2)
+ return false;
+ // Subsets are equivalent only if offsets, sizes and strides all match...
+ if (!sameOffsetsSizesAndStrides(op1, offsetsSizesAndStrides2,
+ isEqualConstantIntOrValue))
+ return false;
+ // ...and both ops operate on equivalent tensors.
+ return equivalenceFn(
+ getContainerOperand(cast<SubsetOpInterface>(op1.getOperation())),
+ getContainerOperand(op2));
+}
+
+/// Return "true" if the two ops operate on disjoint subsets of equivalent
+/// tensors. `equivalenceFn` is used to determine equivalence of tensors.
+/// Return "false" if the two ops operate on non-disjoint subsets, if
+/// disjointness cannot be determined or if `op2` does not implement
+/// OffsetSizeAndStrideOpInterface. `op1` must implement SubsetOpInterface.
+template <typename OpTy>
+bool operateOnDisjointSubsets(OpTy op1, SubsetOpInterface op2,
+ function_ref<bool(Value, Value)> equivalenceFn) {
+ auto offsetsSizesAndStrides2 =
+ dyn_cast<OffsetSizeAndStrideOpInterface>(op2.getOperation());
+ if (!offsetsSizesAndStrides2)
+ return false;
+ // Bail out unless the slices are provably non-overlapping; a failure means
+ // that overlap could not be decided.
+ FailureOr<bool> overlappingSlices =
+ ValueBoundsConstraintSet::areOverlappingSlices(op1,
+ offsetsSizesAndStrides2);
+ if (failed(overlappingSlices) || *overlappingSlices)
+ return false;
+ // Disjointness is only claimed for slices of equivalent tensors.
+ return equivalenceFn(
+ getContainerOperand(cast<SubsetOpInterface>(op1.getOperation())),
+ getContainerOperand(op2));
+}
+
+/// SubsetOpInterface model for tensor.extract_slice; both queries delegate to
+/// the slice-based helpers operateOnEquivalentSubsets/operateOnDisjointSubsets.
+struct ExtractSliceOpSubsetOpInterface
+ : public SubsetOpInterface::ExternalModel<ExtractSliceOpSubsetOpInterface,
+ tensor::ExtractSliceOp> {
+ bool operatesOnEquivalentSubset(
+ Operation *op, SubsetOpInterface candidate,
+ function_ref<bool(Value, Value)> equivalenceFn) const {
+ return operateOnEquivalentSubsets(cast<tensor::ExtractSliceOp>(op),
+ candidate, equivalenceFn);
+ }
+
+ bool operatesOnDisjointSubset(
+ Operation *op, SubsetOpInterface candidate,
+ function_ref<bool(Value, Value)> equivalenceFn) const {
+ return operateOnDisjointSubsets(cast<tensor::ExtractSliceOp>(op), candidate,
+ equivalenceFn);
+ }
+};
+
+/// SubsetExtractionOpInterface model for tensor.extract_slice: the op extracts
+/// from its `source` operand.
+struct ExtractSliceOpSubsetExtractionOpInterface
+ : public SubsetExtractionOpInterface::ExternalModel<
+ ExtractSliceOpSubsetExtractionOpInterface, tensor::ExtractSliceOp> {
+ OpOperand &getSourceOperand(Operation *op) const {
+ auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
+ return extractSliceOp.getSourceMutable();
+ }
+};
+
template <typename OpTy>
-struct InsertSliceLikeOpInterface
+struct InsertSliceLikeOpSubsetOpInterface
+ : public SubsetOpInterface::ExternalModel<
+ InsertSliceLikeOpSubsetOpInterface<OpTy>, OpTy> {
+ /// Delegates to the shared slice-based equivalence helper.
+ bool operatesOnEquivalentSubset(
+ Operation *op, SubsetOpInterface candidate,
+ function_ref<bool(Value, Value)> equivalenceFn) const {
+ return operateOnEquivalentSubsets(cast<OpTy>(op), candidate, equivalenceFn);
+ }
+
+ /// Delegates to the shared slice-based disjointness helper.
+ bool operatesOnDisjointSubset(
+ Operation *op, SubsetOpInterface candidate,
+ function_ref<bool(Value, Value)> equivalenceFn) const {
+ return operateOnDisjointSubsets(cast<OpTy>(op), candidate, equivalenceFn);
+ }
+};
+
+template <typename OpTy>
+struct InsertSliceLikeOpSubsetInsertionOpInterface
: public SubsetInsertionOpInterface::ExternalModel<
- InsertSliceLikeOpInterface<OpTy>, OpTy> {
+ InsertSliceLikeOpSubsetInsertionOpInterface<OpTy>, OpTy> {
OpOperand &getSourceOperand(Operation *op) const {
return cast<OpTy>(op).getSourceMutable();
}
@@ -28,23 +126,6 @@ struct InsertSliceLikeOpInterface
return cast<OpTy>(op).getDestMutable();
}
- /// Return "true" if `insertSliceOp` inserts into a subset that is equivalent
- /// to the subset defined by `candidate`. `equivalenceFn` is used to determine
- /// equivalence of tensors.
- bool
- isEquivalentSubset(Operation *op, Value candidate,
- function_ref<bool(Value, Value)> equivalenceFn) const {
- auto insertSliceOp = cast<OpTy>(op);
- // Look for a matching tensor.extract_slice op.
- auto extractSliceOp = candidate.getDefiningOp<tensor::ExtractSliceOp>();
- if (!extractSliceOp)
- return false;
- if (!equivalenceFn(extractSliceOp.getSource(), insertSliceOp.getDest()))
- return false;
- return sameOffsetsSizesAndStrides(extractSliceOp, insertSliceOp,
- isEqualConstantIntOrValue);
- }
-
Value buildSubsetExtraction(Operation *op, OpBuilder &builder,
Location loc) const {
auto insertSliceOp = cast<OpTy>(op);
@@ -73,12 +154,22 @@ struct InsertSliceLikeOpInterface
} // namespace
-void mlir::tensor::registerSubsetInsertionOpInterfaceExternalModels(
+void mlir::tensor::registerSubsetOpInterfaceExternalModels(
DialectRegistry &registry) {
registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
- InsertSliceOp::attachInterface<InsertSliceLikeOpInterface<InsertSliceOp>>(
+ // Note: `SubsetExtractionOpInterface` and `SubsetInsertionOpInterface`
+ // require `SubsetOpInterface`.
+ ExtractSliceOp::attachInterface<ExtractSliceOpSubsetOpInterface>(*ctx);
+ ExtractSliceOp::attachInterface<ExtractSliceOpSubsetExtractionOpInterface>(
*ctx);
+ InsertSliceOp::attachInterface<
+ InsertSliceLikeOpSubsetOpInterface<InsertSliceOp>>(*ctx);
+ InsertSliceOp::attachInterface<
+ InsertSliceLikeOpSubsetInsertionOpInterface<InsertSliceOp>>(*ctx);
+ ParallelInsertSliceOp::attachInterface<
+ InsertSliceLikeOpSubsetOpInterface<ParallelInsertSliceOp>>(*ctx);
ParallelInsertSliceOp::attachInterface<
- InsertSliceLikeOpInterface<ParallelInsertSliceOp>>(*ctx);
+ InsertSliceLikeOpSubsetInsertionOpInterface<ParallelInsertSliceOp>>(
+ *ctx);
});
}
diff --git a/mlir/lib/Interfaces/CMakeLists.txt b/mlir/lib/Interfaces/CMakeLists.txt
index f74306206d63f14..2652d261f480ba4 100644
--- a/mlir/lib/Interfaces/CMakeLists.txt
+++ b/mlir/lib/Interfaces/CMakeLists.txt
@@ -16,7 +16,7 @@ set(LLVM_OPTIONAL_SOURCES
RuntimeVerifiableOpInterface.cpp
ShapedOpInterfaces.cpp
SideEffectInterfaces.cpp
- SubsetInsertionOpInterface.cpp
+ SubsetOpInterface.cpp
TilingInterface.cpp
ValueBoundsOpInterface.cpp
VectorInterfaces.cpp
@@ -84,15 +84,15 @@ add_mlir_interface_library(RuntimeVerifiableOpInterface)
add_mlir_interface_library(ShapedOpInterfaces)
add_mlir_interface_library(SideEffectInterfaces)
-add_mlir_library(MLIRSubsetInsertionOpInterface
- SubsetInsertionOpInterface.cpp
+add_mlir_library(MLIRSubsetOpInterface
+ SubsetOpInterface.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Interfaces
DEPENDS
MLIRDestinationStyleOpInterface
- MLIRSubsetInsertionOpInterfaceIncGen
+ MLIRSubsetOpInterfaceIncGen
LINK_LIBS PUBLIC
MLIRDestinationStyleOpInterface
@@ -112,6 +112,7 @@ add_mlir_library(MLIRValueBoundsOpInterface
DEPENDS
MLIRDestinationStyleOpInterface
MLIRValueBoundsOpInterfaceIncGen
+ MLIRViewLikeInterface
LINK_LIBS PUBLIC
MLIRAnalysis
diff --git a/mlir/lib/Interfaces/SubsetInsertionOpInterface.cpp b/mlir/lib/Interfaces/SubsetInsertionOpInterface.cpp
deleted file mode 100644
index b2b092287f96ba6..000000000000000
--- a/mlir/lib/Interfaces/SubsetInsertionOpInterface.cpp
+++ /dev/null
@@ -1,23 +0,0 @@
-//===- SubsetInsertionOpInterface.cpp - Tensor Subsets --------------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Interfaces/SubsetInsertionOpInterface.h"
-#include "mlir/Interfaces/DestinationStyleOpInterface.h"
-
-#include "mlir/Interfaces/SubsetInsertionOpInterface.cpp.inc"
-
-using namespace mlir;
-
-OpOperand &detail::defaultGetDestinationOperand(Operation *op) {
- auto dstOp = dyn_cast<DestinationStyleOpInterface>(op);
- assert(dstOp && "getDestination must be implemented for non-DPS ops");
- assert(
- dstOp.getNumDpsInits() == 1 &&
- "getDestination must be implemented for ops with 0 or more than 1 init");
- return *dstOp.getDpsInitOperand(0);
-}
diff --git a/mlir/lib/Interfaces/SubsetOpInterface.cpp b/mlir/lib/Interfaces/SubsetOpInterface.cpp
new file mode 100644
index 000000000000000..7245ab20c499e20
--- /dev/null
+++ b/mlir/lib/Interfaces/SubsetOpInterface.cpp
@@ -0,0 +1,58 @@
+//===- SubsetOpInterface.cpp - Tensor Subsets -----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Interfaces/SubsetOpInterface.h"
+#include "mlir/Interfaces/DestinationStyleOpInterface.h"
+
+#include "mlir/Interfaces/SubsetOpInterface.cpp.inc"
+
+using namespace mlir;
+
+/// Default implementation of `getDestinationOperand`: the single DPS init
+/// operand of a destination-style op.
+OpOperand &detail::defaultGetDestinationOperand(Operation *op) {
+ auto dpsOp = dyn_cast<DestinationStyleOpInterface>(op);
+ assert(dpsOp && "getDestination must be implemented for non-DPS ops");
+ assert(
+ dpsOp.getNumDpsInits() == 1 &&
+ "getDestination must be implemented for ops with 0 or more than 1 init");
+ OpOperand *initOperand = dpsOp.getDpsInitOperand(0);
+ return *initOperand;
+}
+
+/// Default implementation of `getUpdatedDestination`: the DPS result tied to
+/// the insertion op's destination operand.
+OpResult detail::defaultGetUpdatedDestination(Operation *op) {
+ auto dpsOp = dyn_cast<DestinationStyleOpInterface>(op);
+ assert(dpsOp && "getUpdatedDestination must be implemented for non-DPS ops");
+ OpOperand &destination =
+ cast<SubsetInsertionOpInterface>(op).getDestinationOperand();
+ return dpsOp.getTiedOpResult(&destination);
+}
+
+/// Default implementation of `isEquivalentSubset`: `candidate` must be the
+/// result of a subset extraction op that operates on a subset equivalent to
+/// the one the insertion op `op` inserts into.
+bool detail::defaultIsEquivalentSubset(
+ Operation *op, Value candidate,
+ function_ref<bool(Value, Value)> equivalenceFn) {
+ assert(isa<SubsetInsertionOpInterface>(op) &&
+ "expected SubsetInsertionOpInterface");
+ Operation *candidateOp = candidate.getDefiningOp();
+ if (!candidateOp || !isa<SubsetExtractionOpInterface>(candidateOp))
+ return false;
+ return cast<SubsetOpInterface>(op).operatesOnEquivalentSubset(
+ dyn_cast<SubsetOpInterface>(candidateOp), equivalenceFn);
+}
+
+LogicalResult detail::verifySubsetOpInterface(SubsetOpInterface op) {
+ Operation *operation = op.getOperation();
+ bool isExtraction = isa<SubsetExtractionOpInterface>(operation);
+ bool isInsertion = isa<SubsetInsertionOpInterface>(operation);
+ // Exactly one of the two refining interfaces must be implemented.
+ if (isExtraction != isInsertion)
+ return success();
+ return op->emitOpError(
+ "SubsetOpInterface ops must implement either "
+ "SubsetExtractionOpInterface or SubsetInsertionOpInterface");
+}
+
+LogicalResult
+detail::verifySubsetExtractionOpInterface(SubsetExtractionOpInterface op) {
+ // Extraction ops must produce exactly one result.
+ if (op->getNumResults() == 1)
+ return success();
+ return op->emitOpError(
+ "SubsetExtractionOpInterface ops must have one result");
+}
diff --git a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
index ff941115219f68b..200f7f3b25b731b 100644
--- a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
+++ b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
@@ -11,6 +11,7 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
+#include "mlir/Interfaces/ViewLikeInterface.h"
#include "llvm/ADT/APSInt.h"
#include "llvm/Support/Debug.h"
@@ -484,6 +485,17 @@ FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound(
return failure();
}
+FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound(
+ presburger::BoundType type, AffineMap map, ArrayRef<Value> operands,
+ StopConditionFn stopCondition, bool closedUB) {
+ // Forward to the ValueDimList-based overload. Operands must be index-typed
+ // values, so no dimension (std::nullopt) is attached to any of them.
+ ValueDimList valueDims;
+ for (Value operand : operands) {
+ assert(operand.getType().isIndex() && "expected index type");
+ valueDims.emplace_back(operand, std::nullopt);
+ }
+ return computeConstantBound(type, map, valueDims, stopCondition, closedUB);
+}
+
FailureOr<int64_t>
ValueBoundsConstraintSet::computeConstantDelta(Value value1, Value value2,
std::optional<int64_t> dim1,
@@ -512,6 +524,70 @@ ValueBoundsConstraintSet::areEqual(Value value1, Value value2,
return *delta == 0;
}
+FailureOr<bool> ValueBoundsConstraintSet::areOverlappingSlices(
+ OffsetSizeAndStrideOpInterface slice1,
+ OffsetSizeAndStrideOpInterface slice2) {
+ assert(slice1.getStaticOffsets().size() == slice2.getStaticOffsets().size() &&
+ "expected slices of same rank");
+ assert(slice1.getStaticSizes().size() == slice2.getStaticSizes().size() &&
+ "expected slices of same rank");
+ assert(slice1.getStaticStrides().size() == slice2.getStaticStrides().size() &&
+ "expected slices of same rank");
+
+ // The mixed offsets/sizes/strides and the affine map do not depend on the
+ // dimension, so compute them once instead of once per dimension and case.
+ SmallVector<OpFoldResult> offsets1 = slice1.getMixedOffsets();
+ SmallVector<OpFoldResult> sizes1 = slice1.getMixedSizes();
+ SmallVector<OpFoldResult> strides1 = slice1.getMixedStrides();
+ SmallVector<OpFoldResult> offsets2 = slice2.getMixedOffsets();
+ SmallVector<OpFoldResult> sizes2 = slice2.getMixedSizes();
+ SmallVector<OpFoldResult> strides2 = slice2.getMixedStrides();
+
+ Builder b(slice1.getContext());
+ // s0 + s1 * s2 - s3, i.e., offsetA + sizeA * strideA - offsetB.
+ AffineMap map =
+ AffineMap::get(/*dimCount=*/0, /*symbolCount=*/4,
+ b.getAffineSymbolExpr(0) +
+ b.getAffineSymbolExpr(1) * b.getAffineSymbolExpr(2) -
+ b.getAffineSymbolExpr(3));
+
+ // Returns "true" if offsetA + sizeA * strideA - offsetB is a known constant
+ // <= 0, i.e., the slice described by (offsetA, sizeA, strideA) provably ends
+ // at or before offsetB along one dimension. Sets `foundUnknownBound` when
+ // that expression cannot be computed.
+ bool foundUnknownBound = false;
+ auto endsBefore = [&](OpFoldResult offsetA, OpFoldResult sizeA,
+ OpFoldResult strideA, OpFoldResult offsetB) {
+ SmallVector<OpFoldResult> ofrOperands = {offsetA, sizeA, strideA, offsetB};
+ SmallVector<Value> valueOperands;
+ AffineMap foldedMap =
+ foldAttributesIntoMap(b, map, ofrOperands, valueOperands);
+ FailureOr<int64_t> constBound = computeConstantBound(
+ presburger::BoundType::EQ, foldedMap, valueOperands);
+ foundUnknownBound |= failed(constBound);
+ return succeeded(constBound) && *constBound <= 0;
+ };
+
+ for (int64_t i = 0, e = offsets1.size(); i < e; ++i) {
+ // Case 1: Slices are guaranteed to be non-overlapping if
+ // offset1 + size1 * stride1 <= offset2 (for at least one dimension).
+ if (endsBefore(offsets1[i], sizes1[i], strides1[i], offsets2[i]))
+ return false;
+ // Case 2: Slices are guaranteed to be non-overlapping if
+ // offset2 + size2 * stride2 <= offset1 (for at least one dimension).
+ if (endsBefore(offsets2[i], sizes2[i], strides2[i], offsets1[i]))
+ return false;
+ }
+
+ // If at least one bound could not be computed, we cannot be certain that the
+ // slices are really overlapping.
+ if (foundUnknownBound)
+ return failure();
+
+ // All bounds could be computed and none of the above cases applied.
+ // Therefore, the slices are guaranteed to overlap.
+ return true;
+}
+
ValueBoundsConstraintSet::BoundBuilder &
ValueBoundsConstraintSet::BoundBuilder::operator[](int64_t dim) {
assert(!this->dim.has_value() && "dim was already set");
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index b1bd543ae7fce63..add146bf084d321 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -6933,7 +6933,7 @@ cc_library(
":MemRefDialect",
":Pass",
":SCFDialect",
- ":SubsetInsertionOpInterface",
+ ":SubsetOpInterface",
":TensorDialect",
":TensorPassIncGen",
":TensorUtils",
@@ -10189,9 +10189,9 @@ gentbl_cc_library(
)
td_library(
- name = "SubsetInsertionOpInterfaceTdFiles",
+ name = "SubsetOpInterfaceTdFiles",
srcs = [
- "include/mlir/Interfaces/SubsetInsertionOpInterface.td",
+ "include/mlir/Interfaces/SubsetOpInterface.td",
],
includes = ["include"],
deps = [
@@ -10200,33 +10200,33 @@ td_library(
)
gentbl_cc_library(
- name = "SubsetInsertionOpInterfaceIncGen",
+ name = "SubsetOpInterfaceIncGen",
tbl_outs = [
(
["-gen-op-interface-decls"],
- "include/mlir/Interfaces/SubsetInsertionOpInterface.h.inc",
+ "include/mlir/Interfaces/SubsetOpInterface.h.inc",
),
(
["-gen-op-interface-defs"],
- "include/mlir/Interfaces/SubsetInsertionOpInterface.cpp.inc",
+ "include/mlir/Interfaces/SubsetOpInterface.cpp.inc",
),
],
tblgen = ":mlir-tblgen",
- td_file = "include/mlir/Interfaces/SubsetInsertionOpInterface.td",
+ td_file = "include/mlir/Interfaces/SubsetOpInterface.td",
deps = [
- ":SubsetInsertionOpInterfaceTdFiles",
+ ":SubsetOpInterfaceTdFiles",
],
)
cc_library(
- name = "SubsetInsertionOpInterface",
- srcs = ["lib/Interfaces/SubsetInsertionOpInterface.cpp"],
- hdrs = ["include/mlir/Interfaces/SubsetInsertionOpInterface.h"],
+ name = "SubsetOpInterface",
+ srcs = ["lib/Interfaces/SubsetOpInterface.cpp"],
+ hdrs = ["include/mlir/Interfaces/SubsetOpInterface.h"],
includes = ["include"],
deps = [
":DestinationStyleOpInterface",
":IR",
- ":SubsetInsertionOpInterfaceIncGen",
+ ":SubsetOpInterfaceIncGen",
":Support",
"//llvm:Support",
],
@@ -10470,7 +10470,7 @@ cc_library(
":SCFTransforms",
":SCFUtils",
":SparseTensorDialect",
- ":SubsetInsertionOpInterface",
+ ":SubsetOpInterface",
":Support",
":TensorDialect",
":TensorTilingInterfaceImpl",
@@ -10530,6 +10530,7 @@ cc_library(
":IR",
":Support",
":ValueBoundsOpInterfaceIncGen",
+ ":ViewLikeInterface",
"//llvm:Support",
],
)
@@ -12579,7 +12580,7 @@ gentbl_cc_library(
":BufferizableOpInterfaceTdFiles",
":BufferizationOpsTdFiles",
":DestinationStyleOpInterfaceTdFiles",
- ":SubsetInsertionOpInterfaceTdFiles",
+ ":SubsetOpInterfaceTdFiles",
],
)
@@ -12619,7 +12620,7 @@ cc_library(
":InferTypeOpInterface",
":MemRefDialect",
":SparseTensorDialect",
- ":SubsetInsertionOpInterface",
+ ":SubsetOpInterface",
":Support",
":TensorDialect",
"//llvm:Support",
@@ -12669,7 +12670,7 @@ cc_library(
":Pass",
":SCFDialect",
":SideEffectInterfaces",
- ":SubsetInsertionOpInterface",
+ ":SubsetOpInterface",
":Support",
":TensorDialect",
":Transforms",
diff --git a/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel
index c83e4cc7ada23e3..348ee2beabeb061 100644
--- a/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel
@@ -413,7 +413,7 @@ gentbl_filegroup(
deps = [
":BufferizationOpsPyTdFiles",
"//mlir:DestinationStyleOpInterfaceTdFiles",
- "//mlir:SubsetInsertionOpInterfaceTdFiles",
+ "//mlir:SubsetOpInterfaceTdFiles",
],
)
>From 23ce8911cf57df4ad6cdeb9f2cd89bc7b0c000d5 Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Mon, 30 Oct 2023 14:32:49 +0900
Subject: [PATCH 2/6] [mlir] Loop-invariant subset hoisting
---
.../Transforms/LoopInvariantCodeMotionUtils.h | 39 +++
mlir/include/mlir/Transforms/Passes.h | 2 +
mlir/include/mlir/Transforms/Passes.td | 5 +
.../lib/Interfaces/ValueBoundsOpInterface.cpp | 4 +-
.../Transforms/LoopInvariantCodeMotion.cpp | 20 ++
mlir/lib/Transforms/Utils/CMakeLists.txt | 1 +
.../Utils/LoopInvariantCodeMotionUtils.cpp | 253 +++++++++++++++++-
.../loop-invariant-subset-hoisting.mlir | 237 ++++++++++++++++
.../llvm-project-overlay/mlir/BUILD.bazel | 1 +
9 files changed, 556 insertions(+), 6 deletions(-)
create mode 100644 mlir/test/Transforms/loop-invariant-subset-hoisting.mlir
diff --git a/mlir/include/mlir/Transforms/LoopInvariantCodeMotionUtils.h b/mlir/include/mlir/Transforms/LoopInvariantCodeMotionUtils.h
index c7b816eb28faf5f..579054070f729b0 100644
--- a/mlir/include/mlir/Transforms/LoopInvariantCodeMotionUtils.h
+++ b/mlir/include/mlir/Transforms/LoopInvariantCodeMotionUtils.h
@@ -71,6 +71,45 @@ size_t moveLoopInvariantCode(
/// methods provided by the interface.
size_t moveLoopInvariantCode(LoopLikeOpInterface loopLike);
+/// Hoist loop-invariant tensor subsets (subset extraction and subset insertion
+/// ops) from loop-like ops. Extraction ops are moved before the loop. Insertion
+/// ops are moved after the loop. The loop body operates on newly added region
+/// iter_args (one per extraction-insertion pair). Return the resulting
+/// loop-like op, which may differ from `loopLike` because iter_args are added.
+///
+/// A subset extraction op (`SubsetExtractionOpInterface`) extracts from a
+/// tensor value at a subset. The result of the op may have an arbitrary type,
+/// i.e., not necessarily a tensor type. Example: "tensor.extract_slice".
+///
+/// A subset insertion op (`SubsetInsertionOpInterface`) inserts into a tensor
+/// value ("destination") at a subset. Example: "tensor.insert_slice".
+///
+/// Matching extraction-insertion subset ops can be hoisted from a loop if there
+/// are no other ops within the loop that operate on the same or on an
+/// overlapping subset. In particular, non-subset ops can prevent hoisting
+/// because the analysis does not know what subset they operate on.
+///
+/// Example:
+/// ```
+/// %r = scf.for ... iter_args(%t = %a) -> (tensor<?xf32>) {
+/// %0 = tensor.extract_slice %t[0][5][1] : tensor<?xf32> to tensor<5xf32>
+/// %1 = "test.foo"(%0) : (tensor<5xf32>) -> (tensor<5xf32>)
+/// %2 = tensor.insert_slice %1 into %t[0][5][1]
+/// : tensor<5xf32> into tensor<?xf32>
+/// scf.yield %2 : tensor<?xf32>
+/// }
+/// ```
+/// Is rewritten to:
+/// ```
+/// %0 = tensor.extract_slice %a[0][5][1] : tensor<?xf32> to tensor<5xf32>
+/// %new_loop:2 = scf.for ... iter_args(%t = %a, %h = %0)
+/// -> (tensor<?xf32>, tensor<5xf32>) {
+/// %1 = "test.foo"(%h) : (tensor<5xf32>) -> (tensor<5xf32>)
+/// scf.yield %t, %1 : tensor<?xf32>, tensor<5xf32>
+/// }
+/// %r = tensor.insert_slice %new_loop#1 into %new_loop#0
+/// : tensor<5xf32> into tensor<?xf32>
+/// ```
+LoopLikeOpInterface hoistLoopInvariantSubsets(LoopLikeOpInterface loopLike);
+
} // end namespace mlir
#endif // MLIR_TRANSFORMS_LOOPINVARIANTCODEMOTIONUTILS_H
diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
index 320932bb999561f..ab04478ae076694 100644
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -78,6 +78,8 @@ std::unique_ptr<Pass> createGenerateRuntimeVerificationPass();
/// instructions out of the loop.
std::unique_ptr<Pass> createLoopInvariantCodeMotionPass();
+/// Creates a pass that hoists loop-invariant subset extraction/insertion op
+/// pairs out of loop-like ops.
+std::unique_ptr<Pass> createLoopInvariantSubsetHoistingPass();
+
/// Creates a pass to strip debug information from a function.
std::unique_ptr<Pass> createStripDebugInfoPass();
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index 26d2ff3c30ded57..2d2d54fb8fb5eaa 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -329,6 +329,11 @@ def LoopInvariantCodeMotion : Pass<"loop-invariant-code-motion"> {
let constructor = "mlir::createLoopInvariantCodeMotionPass()";
}
+def LoopInvariantSubsetHoisting : Pass<"loop-invariant-subset-hoisting"> {
+ let summary = "Hoist loop invariant subset ops outside of the loop";
+ let description = [{
+ Hoists loop-invariant pairs of subset extraction and subset insertion ops
+ (e.g., `tensor.extract_slice` / `tensor.insert_slice`) out of loop-like
+ ops. Extraction ops are moved before the loop, insertion ops are moved
+ after the loop, and the loop body operates on newly added region iter_args
+ (one per extraction-insertion pair).
+ }];
+ let constructor = "mlir::createLoopInvariantSubsetHoistingPass()";
+}
+
def Mem2Reg : Pass<"mem2reg"> {
let summary = "Promotes memory slots into values.";
let description = [{
diff --git a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
index 200f7f3b25b731b..f0c37c872e6d31d 100644
--- a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
+++ b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
@@ -86,7 +86,7 @@ AffineExpr ValueBoundsConstraintSet::getExpr(Value value,
return builder.getAffineConstantExpr(shapedType.getDimSize(*dim));
} else {
// Constant index value: return directly.
- if (auto constInt = getConstantIntValue(value))
+ if (auto constInt = ::getConstantIntValue(value))
return builder.getAffineConstantExpr(*constInt);
}
@@ -103,7 +103,7 @@ AffineExpr ValueBoundsConstraintSet::getExpr(Value value,
AffineExpr ValueBoundsConstraintSet::getExpr(OpFoldResult ofr) {
if (Value value = llvm::dyn_cast_if_present<Value>(ofr))
return getExpr(value, /*dim=*/std::nullopt);
- auto constInt = getConstantIntValue(ofr);
+ auto constInt = ::getConstantIntValue(ofr);
assert(constInt.has_value() && "expected Integer constant");
return builder.getAffineConstantExpr(*constInt);
}
diff --git a/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp b/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp
index 854fde09bac796e..e6d8af8f05832d3 100644
--- a/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp
+++ b/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp
@@ -18,6 +18,7 @@
namespace mlir {
#define GEN_PASS_DEF_LOOPINVARIANTCODEMOTION
+#define GEN_PASS_DEF_LOOPINVARIANTSUBSETHOISTING
#include "mlir/Transforms/Passes.h.inc"
} // namespace mlir
@@ -29,6 +30,12 @@ struct LoopInvariantCodeMotion
: public impl::LoopInvariantCodeMotionBase<LoopInvariantCodeMotion> {
void runOnOperation() override;
};
+
+/// Pass that runs `hoistLoopInvariantSubsets` on every loop-like op nested in
+/// the operation it is scheduled on.
+struct LoopInvariantSubsetHoisting
+ : impl::LoopInvariantSubsetHoistingBase<LoopInvariantSubsetHoisting> {
+ void runOnOperation() override;
+};
} // namespace
void LoopInvariantCodeMotion::runOnOperation() {
@@ -39,6 +46,19 @@ void LoopInvariantCodeMotion::runOnOperation() {
[&](LoopLikeOpInterface loopLike) { moveLoopInvariantCode(loopLike); });
}
+void LoopInvariantSubsetHoisting::runOnOperation() {
+ // Loops are visited innermost-first: ops hoisted out of an inner loop are
+ // placed in the enclosing loop, from which they may be hoisted again once
+ // that loop is visited.
+ auto hoistFromLoop = [](LoopLikeOpInterface loopLike) {
+ (void)hoistLoopInvariantSubsets(loopLike);
+ };
+ getOperation()->walk(hoistFromLoop);
+}
+
std::unique_ptr<Pass> mlir::createLoopInvariantCodeMotionPass() {
return std::make_unique<LoopInvariantCodeMotion>();
}
+
+/// Creates an instance of the loop-invariant subset hoisting pass.
+std::unique_ptr<Pass> mlir::createLoopInvariantSubsetHoistingPass() {
+ std::unique_ptr<Pass> pass = std::make_unique<LoopInvariantSubsetHoisting>();
+ return pass;
+}
diff --git a/mlir/lib/Transforms/Utils/CMakeLists.txt b/mlir/lib/Transforms/Utils/CMakeLists.txt
index efc7a5160b2399e..1c608e0634a67e2 100644
--- a/mlir/lib/Transforms/Utils/CMakeLists.txt
+++ b/mlir/lib/Transforms/Utils/CMakeLists.txt
@@ -20,5 +20,6 @@ add_mlir_library(MLIRTransformUtils
MLIRFunctionInterfaces
MLIRLoopLikeInterface
MLIRSideEffectInterfaces
+ MLIRSubsetOpInterface
MLIRRewrite
)
diff --git a/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp b/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
index 080492da6ae4b97..f39f8ec3598f322 100644
--- a/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
@@ -11,9 +11,12 @@
//===----------------------------------------------------------------------===//
#include "mlir/Transforms/LoopInvariantCodeMotionUtils.h"
+
#include "mlir/IR/Operation.h"
+#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Interfaces/SubsetOpInterface.h"
#include "llvm/Support/Debug.h"
#include <queue>
@@ -26,7 +29,7 @@ using namespace mlir;
/// loop (by means of calling definedOutside).
/// - the op has no side-effects.
static bool canBeHoisted(Operation *op,
- function_ref<bool(Value)> definedOutside) {
+ function_ref<bool(OpOperand &)> condition) {
// Do not move terminators.
if (op->hasTrait<OpTrait::IsTerminator>())
return false;
@@ -35,11 +38,11 @@ static bool canBeHoisted(Operation *op,
// defined outside of the loop or in a nested region, but not at the level of
// the loop body.
auto walkFn = [&](Operation *child) {
- for (Value operand : child->getOperands()) {
+ for (OpOperand &operand : child->getOpOperands()) {
// Ignore values defined in a nested region.
- if (op->isAncestor(operand.getParentRegion()->getParentOp()))
+ if (op->isAncestor(operand.get().getParentRegion()->getParentOp()))
continue;
- if (!definedOutside(operand))
+ if (!condition(operand))
return WalkResult::interrupt();
}
return WalkResult::advance();
@@ -47,6 +50,12 @@ static bool canBeHoisted(Operation *op,
return !op->walk(walkFn).wasInterrupted();
}
+/// Convenience overload of `canBeHoisted` for callers that only care about
+/// operand values: an operand qualifies iff `definedOutside` returns true for
+/// the SSA value it refers to.
+static bool canBeHoisted(Operation *op,
+                         function_ref<bool(Value)> definedOutside) {
+  auto operandValueIsOutside = [&](OpOperand &operand) -> bool {
+    Value operandValue = operand.get();
+    return definedOutside(operandValue);
+  };
+  return canBeHoisted(op, operandValueIsOutside);
+}
+
size_t mlir::moveLoopInvariantCode(
ArrayRef<Region *> regions,
function_ref<bool(Value, Region *)> isDefinedOutsideRegion,
@@ -105,3 +114,239 @@ size_t mlir::moveLoopInvariantCode(LoopLikeOpInterface loopLike) {
},
[&](Operation *op, Region *) { loopLike.moveOutOfLoop(op); });
}
+
+namespace {
+/// Helper data structure that keeps track of equivalent/disjoint subset ops.
+///
+/// Extraction and insertion ops are kept in two parallel vectors:
+/// `extractions[i]` and `insertions[i]` form the i-th pair, and either side of
+/// a pair may be null when no equivalent counterpart has been seen. Every op
+/// passed to `insert` is additionally recorded in `allSubsetOps`, which is
+/// used for the disjointness check.
+class MatchingSubsets {
+public:
+ /// Insert a subset op. An op that implements both the extraction and the
+ /// insertion interface is registered on both sides.
+ void insert(SubsetOpInterface op) {
+ allSubsetOps.push_back(op);
+ if (auto extractionOp =
+ dyn_cast<SubsetExtractionOpInterface>(op.getOperation()))
+ insertExtractionOp(extractionOp);
+ if (auto insertionOp =
+ dyn_cast<SubsetInsertionOpInterface>(op.getOperation()))
+ insertInsertionOp(insertionOp);
+ }
+
+ /// Return a range of matching extraction-insertion subset ops. If there is no
+ /// matching extraction/insertion op, the respective value is empty. Ops are
+ /// skipped if there are other subset ops that are not guaranteed to operate
+ /// on disjoint subsets.
+ ///
+ /// Note: `llvm::make_filter_range` evaluates the predicate lazily, i.e., while
+ /// the caller iterates and after any IR changes the caller made for earlier
+ /// pairs.
+ auto getHoistableSubsetOps() {
+ return llvm::make_filter_range(
+ llvm::zip(extractions, insertions), [&](auto pair) {
+ auto [extractionOp, insertionOp] = pair;
+ // Hoist only if the extracted and inserted values have the same type.
+ if (extractionOp && insertionOp &&
+ extractionOp->getResult(0).getType() !=
+ insertionOp.getSourceOperand().get().getType())
+ return false;
+ // Hoist only if there are no conflicting subset ops.
+ return allDisjoint(extractionOp, insertionOp);
+ });
+ }
+
+private:
+ /// Helper function for equivalence of tensor values. Since only insertion
+ /// subset ops (that are also destination style ops) are followed when
+ /// traversing the SSA use-def chain, all tensor values are equivalent.
+ static bool isEquivalent(Value v1, Value v2) { return true; }
+
+ /// Return "true" if the subsets of the given extraction and insertion ops
+ /// are operating disjoint from the subsets that all other known subset ops
+ /// are operating on.
+ bool allDisjoint(SubsetExtractionOpInterface extractionOp,
+ SubsetInsertionOpInterface insertionOp) const {
+ for (SubsetOpInterface other : allSubsetOps) {
+ // The members of the pair are not compared against themselves.
+ if (other == extractionOp || other == insertionOp)
+ continue;
+ if (extractionOp &&
+ !other.operatesOnDisjointSubset(extractionOp, isEquivalent))
+ return false;
+ if (insertionOp &&
+ !other.operatesOnDisjointSubset(insertionOp, isEquivalent))
+ return false;
+ }
+ return true;
+ }
+
+ /// Insert a subset extraction op. If the subset is equivalent to an existing
+ /// subset insertion op, pair them up. (If there is already a paired up subset
+ /// extraction op, overwrite the subset extraction op.)
+ /// Only the first equivalent insertion op (in insertion order) is considered.
+ /// An overwritten extraction op stays in `allSubsetOps`, so `allDisjoint`
+ /// still compares the pair against it.
+ void insertExtractionOp(SubsetExtractionOpInterface extractionOp) {
+ for (auto it : llvm::enumerate(insertions)) {
+ if (!it.value())
+ continue;
+ auto other = cast<SubsetOpInterface>(it.value().getOperation());
+ if (other.operatesOnEquivalentSubset(extractionOp, isEquivalent)) {
+ extractions[it.index()] = extractionOp;
+ return;
+ }
+ }
+ // There is no known equivalent insertion op. Create a new entry.
+ extractions.push_back(extractionOp);
+ insertions.push_back({});
+ }
+
+ /// Insert a subset insertion op. If the subset is equivalent to an existing
+ /// subset extraction op, pair them up. (If there is already a paired up
+ /// subset insertion op, overwrite the subset insertion op.)
+ /// Only the first equivalent extraction op (in insertion order) is
+ /// considered. An overwritten insertion op stays in `allSubsetOps`.
+ void insertInsertionOp(SubsetInsertionOpInterface insertionOp) {
+ for (auto it : llvm::enumerate(extractions)) {
+ if (!it.value())
+ continue;
+ auto other = cast<SubsetOpInterface>(it.value().getOperation());
+ if (other.operatesOnEquivalentSubset(insertionOp, isEquivalent)) {
+ insertions[it.index()] = insertionOp;
+ return;
+ }
+ }
+ // There is no known equivalent extraction op. Create a new entry.
+ extractions.push_back({});
+ insertions.push_back(insertionOp);
+ }
+
+ /// Parallel vectors: entry i of both vectors forms one (possibly half-empty)
+ /// extraction/insertion pair. Both vectors always have the same size.
+ SmallVector<SubsetExtractionOpInterface> extractions;
+ SmallVector<SubsetInsertionOpInterface> insertions;
+ /// Every subset op passed to `insert`.
+ SmallVector<SubsetOpInterface> allSubsetOps;
+};
+} // namespace
+
+/// Return the sole use of `value` if that use belongs to a terminator op.
+/// Return nullptr if `value` has zero or several uses, or if its only user is
+/// not a terminator.
+static OpOperand *getSingleTerminatorUse(Value value) {
+  if (!value.hasOneUse())
+    return nullptr;
+  OpOperand &soleUse = *value.use_begin();
+  bool ownedByTerminator =
+      soleUse.getOwner()->hasTrait<OpTrait::IsTerminator>();
+  return ownedByTerminator ? &soleUse : nullptr;
+}
+
+/// Hoist all subset ops that operate on the given region iter_arg of the given
+/// loop-like op and index into loop-invariant subset locations. Only matching
+/// extraction/insertion pairs are hoisted; every hoisted pair adds one
+/// iter_arg to the loop. Return the newly created loop op (that has extra
+/// iter_args) or the original loop op if nothing was hoisted. (If rebuilding
+/// the loop fails after earlier pairs were already hoisted, the most recently
+/// created loop op is returned.)
+static LoopLikeOpInterface hoistSubsetAtIterArg(LoopLikeOpInterface loopLike,
+ BlockArgument iterArg) {
+ IRRewriter rewriter(loopLike.getContext());
+ assert(iterArg.getOwner()->getParentOp() == loopLike && "invalid iter_arg");
+ // Remember the position of `iterArg`: hoisting rebuilds the loop op, after
+ // which the iter_arg is looked up again by this index.
+ auto it = llvm::find(loopLike.getRegionIterArgs(), iterArg);
+ int64_t iterArgIdx = std::distance(loopLike.getRegionIterArgs().begin(), it);
+ Value value = iterArg;
+ MatchingSubsets subsets;
+
+ // Traverse use-def chain. Subset ops can be hoisted only if all ops along the
+ // use-def chain starting from the region iter_arg are subset extraction or
+ // subset insertion ops. The chain must terminate at the corresponding yield
+ // operand (e.g., no swapping of iter_args).
+ OpOperand *yieldedOperand = nullptr;
+ // Iterate until the single use of the current SSA value is a terminator,
+ // which is expected to be the yielding operation of the loop.
+ while (!(yieldedOperand = getSingleTerminatorUse(value))) {
+ Value nextValue = {};
+
+ for (OpOperand &use : value.getUses()) {
+ // A non-subset user may access any part of the value: give up.
+ auto subsetOp = dyn_cast<SubsetOpInterface>(use.getOwner());
+ if (!subsetOp)
+ return loopLike;
+ subsets.insert(subsetOp);
+
+ if (auto insertionOp =
+ dyn_cast<SubsetInsertionOpInterface>(use.getOwner())) {
+ // The value must be used as a destination. (In case of a source, the
+ // entire tensor would be read, which would prevent any hoisting.)
+ if (&use != &insertionOp.getDestinationOperand())
+ return loopLike;
+ // There must be a single use-def chain from the region iter_arg to the
+ // terminator. I.e., only one insertion op. Branches are not supported.
+ if (nextValue)
+ return loopLike;
+ nextValue = insertionOp.getUpdatedDestination();
+ }
+ }
+
+ // Nothing can be hoisted if the chain does not continue with loop yielding
+ // op or a subset insertion op.
+ if (!nextValue)
+ return loopLike;
+ value = nextValue;
+ }
+
+ // Hoist only if the SSA use-def chain ends in the yielding terminator of the
+ // loop and the yielded operand is the one tied to `iterArg`. (I.e., there is
+ // no swapping yield.)
+ if (loopLike.getTiedLoopYieldedValue(iterArg) != yieldedOperand)
+ return loopLike;
+
+ // Hoist all matching extraction-insertion pairs one-by-one.
+ for (auto it : subsets.getHoistableSubsetOps()) {
+ auto extractionOp = std::get<0>(it);
+ auto insertionOp = std::get<1>(it);
+
+ // Ops cannot be hoisted if they depend on loop-variant values.
+ // Exceptions: the extraction's source and the insertion's source and
+ // destination operands, which are rewired when hoisting below.
+ if (extractionOp) {
+ if (!canBeHoisted(extractionOp, [&](OpOperand &operand) {
+ return loopLike.isDefinedOutsideOfLoop(operand.get()) ||
+ &operand == &extractionOp.getSourceOperand();
+ }))
+ extractionOp = {};
+ }
+ if (insertionOp) {
+ if (!canBeHoisted(insertionOp, [&](OpOperand &operand) {
+ return loopLike.isDefinedOutsideOfLoop(operand.get()) ||
+ &operand == &insertionOp.getSourceOperand() ||
+ &operand == &insertionOp.getDestinationOperand();
+ }))
+ insertionOp = {};
+ }
+
+ // Only hoist extraction-insertion pairs for now. Standalone extractions/
+ // insertions that are loop-invariant could be hoisted, but there may be
+ // easier ways to canonicalize the IR.
+ if (extractionOp && insertionOp) {
+ // Create a new loop with an additional iter_arg.
+ // The extracted value is the init of the new iter_arg, and every
+ // iteration yields the value that the insertion op inserts.
+ NewYieldValuesFn newYieldValuesFn =
+ [&](OpBuilder &b, Location loc,
+ ArrayRef<BlockArgument> innerNewBBArgs) -> SmallVector<Value> {
+ return {insertionOp.getSourceOperand().get()};
+ };
+ FailureOr<LoopLikeOpInterface> newLoop =
+ loopLike.replaceWithAdditionalYields(
+ rewriter, extractionOp.getResult(),
+ /*replaceInitOperandUsesInLoop=*/true, newYieldValuesFn);
+ if (failed(newLoop))
+ return loopLike;
+ loopLike = *newLoop;
+
+ // Hoist the extraction/insertion ops.
+ // The loop op was replaced: re-query the iter_arg on the new loop op.
+ iterArg = loopLike.getRegionIterArgs()[iterArgIdx];
+ OpResult loopResult = loopLike.getTiedLoopResult(iterArg);
+ // The result tied to the iter_arg that was just appended.
+ OpResult newLoopResult = loopLike.getLoopResults()->back();
+ extractionOp->moveBefore(loopLike);
+ insertionOp->moveAfter(loopLike);
+ // Inside the loop, bypass the (now hoisted) insertion op.
+ insertionOp.getUpdatedDestination().replaceAllUsesWith(
+ insertionOp.getDestinationOperand().get());
+ // The hoisted extraction reads from the value the loop is initialized with.
+ extractionOp.getSourceOperand().set(
+ loopLike.getTiedLoopInit(iterArg)->get());
+ // Users of the loop result now see the result of the hoisted insertion.
+ loopResult.replaceAllUsesWith(insertionOp.getUpdatedDestination());
+ // The hoisted insertion writes the final value of the new iter_arg into
+ // the loop result. This must come after the RAUW above; otherwise the
+ // insertion op would be rewired to consume its own result.
+ insertionOp.getSourceOperand().set(newLoopResult);
+ insertionOp.getDestinationOperand().set(loopResult);
+ }
+ }
+
+ return loopLike;
+}
+
+/// Hoist loop-invariant subset extraction/insertion pairs out of `loopLike`,
+/// processing its region iter_args one at a time. Return the resulting loop
+/// op: a newly created op (with additional iter_args) if anything was hoisted,
+/// `loopLike` otherwise.
+LoopLikeOpInterface
+mlir::hoistLoopInvariantSubsets(LoopLikeOpInterface loopLike) {
+  // Note: As subset ops are getting hoisted, the number of region iter_args
+  // increases. This can enable further hoisting opportunities on the new
+  // iter_args. Therefore, the bound is re-evaluated in every iteration, on the
+  // loop op returned by the previous step (hoisting replaces the loop op).
+  // The cast avoids a signed/unsigned comparison (-Wsign-compare): `size()`
+  // returns an unsigned type.
+  for (int64_t i = 0;
+       i < static_cast<int64_t>(loopLike.getRegionIterArgs().size()); ++i) {
+    loopLike = hoistSubsetAtIterArg(loopLike, loopLike.getRegionIterArgs()[i]);
+  }
+  return loopLike;
+}
diff --git a/mlir/test/Transforms/loop-invariant-subset-hoisting.mlir b/mlir/test/Transforms/loop-invariant-subset-hoisting.mlir
new file mode 100644
index 000000000000000..5cded4c99182c14
--- /dev/null
+++ b/mlir/test/Transforms/loop-invariant-subset-hoisting.mlir
@@ -0,0 +1,237 @@
+// RUN: mlir-opt %s -split-input-file -loop-invariant-subset-hoisting | FileCheck %s
+
+// Test: a matching extract_slice/insert_slice pair on a loop-invariant subset
+// ([0, 5)) is hoisted, while a standalone extract_slice on the disjoint subset
+// [9, 14) (no matching insertion) stays inside the loop.
+// CHECK-LABEL: func @hoist_matching_extract_insert(
+// CHECK-SAME: %[[arg:.*]]: tensor<?xf32>
+func.func @hoist_matching_extract_insert(%arg: tensor<?xf32>) -> tensor<?xf32> {
+ %lb = "test.foo"() : () -> (index)
+ %ub = "test.foo"() : () -> (index)
+ %step = "test.foo"() : () -> (index)
+
+ // CHECK: %[[extract:.*]] = tensor.extract_slice %[[arg]]
+ // CHECK: %[[for:.*]]:2 = scf.for {{.*}} iter_args(%[[t:.*]] = %[[arg]], %[[hoisted:.*]] = %[[extract]])
+ %0 = scf.for %iv = %lb to %ub step %step iter_args(%t = %arg) -> (tensor<?xf32>) {
+ // CHECK: tensor.extract_slice %[[t]][9] [5] [1]
+ %standalone = tensor.extract_slice %t[9][5][1] : tensor<?xf32> to tensor<5xf32>
+ "test.foo"(%standalone) : (tensor<5xf32>) -> ()
+
+ %1 = tensor.extract_slice %t[0][5][1] : tensor<?xf32> to tensor<5xf32>
+ // CHECK: %[[foo:.*]] = "test.foo"(%[[hoisted]])
+ %2 = "test.foo"(%1) : (tensor<5xf32>) -> (tensor<5xf32>)
+ %3 = tensor.insert_slice %2 into %t[0][5][1] : tensor<5xf32> into tensor<?xf32>
+ // CHECK: scf.yield %[[t]], %[[foo]]
+ scf.yield %3 : tensor<?xf32>
+ }
+ // CHECK: %[[insert:.*]] = tensor.insert_slice %[[for]]#1 into %[[for]]#0
+
+ // CHECK: return %[[insert]]
+ return %0 : tensor<?xf32>
+}
+
+// -----
+
+// Test: a subset of a subset. The outer extract/insert pair (on the iter_arg)
+// and the inner pair (on the extracted tensor) are both hoisted, which adds
+// two iter_args to the loop.
+// CHECK-LABEL: func @subset_of_subset(
+// CHECK-SAME: %[[arg:.*]]: tensor<?xf32>
+func.func @subset_of_subset(%arg: tensor<?xf32>) -> tensor<?xf32> {
+  %lb = "test.foo"() : () -> (index)
+  %ub = "test.foo"() : () -> (index)
+  %step = "test.foo"() : () -> (index)
+
+  // CHECK: %[[extract1:.*]] = tensor.extract_slice %[[arg]]
+  // CHECK: %[[extract2:.*]] = tensor.extract_slice %[[extract1]]
+  // CHECK: %[[for:.*]]:3 = scf.for {{.*}} iter_args(%[[t:.*]] = %[[arg]], %[[hoisted1:.*]] = %[[extract1]], %[[hoisted2:.*]] = %[[extract2]])
+  %0 = scf.for %iv = %lb to %ub step %step iter_args(%t = %arg) -> (tensor<?xf32>) {
+    %extract1 = tensor.extract_slice %t[0][5][1] : tensor<?xf32> to tensor<5xf32>
+    %extract2 = tensor.extract_slice %extract1[1][2][1] : tensor<5xf32> to tensor<2xf32>
+
+    // CHECK: %[[foo:.*]] = "test.foo"(%[[hoisted2]])
+    %2 = "test.foo"(%extract2) : (tensor<2xf32>) -> (tensor<2xf32>)
+
+    %insert1 = tensor.insert_slice %2 into %extract1[1][2][1] : tensor<2xf32> into tensor<5xf32>
+    %insert2 = tensor.insert_slice %insert1 into %t[0][5][1] : tensor<5xf32> into tensor<?xf32>
+
+    // CHECK: scf.yield %[[t]], %[[hoisted1]], %[[foo]]
+    scf.yield %insert2 : tensor<?xf32>
+  }
+  // CHECK: %[[insert2:.*]] = tensor.insert_slice %[[for]]#2 into %[[for]]#1[1] [2] [1]
+  // CHECK: %[[insert1:.*]] = tensor.insert_slice %[[insert2]] into %[[for]]#0[0] [5] [1]
+
+  // CHECK: return %[[insert1]]
+  return %0 : tensor<?xf32>
+}
+
+// -----
+
+// Test: two disjoint subsets of the same iter_arg ([0, %sz) and
+// [%sz, %sz + 5)) are both hoisted, even though they are inserted in a
+// different order than they are extracted.
+// CHECK-LABEL: func @hoist_matching_chain(
+// CHECK-SAME: %[[arg:.*]]: tensor<?xf32>
+func.func @hoist_matching_chain(%arg: tensor<?xf32>) -> tensor<?xf32> {
+ %lb = "test.foo"() : () -> (index)
+ %ub = "test.foo"() : () -> (index)
+ %step = "test.foo"() : () -> (index)
+ %sz = "test.foo"() : () -> (index)
+
+ // CHECK: %[[extract2:.*]] = tensor.extract_slice %[[arg]][%{{.*}}] [5] [1]
+ // CHECK: %[[extract1:.*]] = tensor.extract_slice %[[arg]][0] [%{{.*}}] [1]
+ // CHECK: %[[for:.*]]:3 = scf.for {{.*}} iter_args(%[[t:.*]] = %[[arg]], %[[hoisted2:.*]] = %[[extract2]], %[[hoisted1:.*]] = %[[extract1]])
+ %0 = scf.for %iv = %lb to %ub step %step iter_args(%t = %arg) -> (tensor<?xf32>) {
+ %1 = tensor.extract_slice %t[0][%sz][1] : tensor<?xf32> to tensor<?xf32>
+ %2 = tensor.extract_slice %t[%sz][5][1] : tensor<?xf32> to tensor<5xf32>
+ // CHECK-DAG: %[[foo1:.*]] = "test.foo"(%[[hoisted1]])
+ // CHECK-DAG: %[[foo2:.*]] = "test.foo"(%[[hoisted2]])
+ %foo1 = "test.foo"(%1) : (tensor<?xf32>) -> (tensor<?xf32>)
+ %foo2 = "test.foo"(%2) : (tensor<5xf32>) -> (tensor<5xf32>)
+ %5 = tensor.insert_slice %foo2 into %t[%sz][5][1] : tensor<5xf32> into tensor<?xf32>
+ %6 = tensor.insert_slice %foo1 into %5[0][%sz][1] : tensor<?xf32> into tensor<?xf32>
+ // CHECK: scf.yield %[[t]], %[[foo2]], %[[foo1]]
+ scf.yield %6 : tensor<?xf32>
+ }
+ // CHECK: %[[insert2:.*]] = tensor.insert_slice %[[for]]#2 into %[[for]]#0[0] [%{{.*}}] [1]
+ // CHECK: %[[insert1:.*]] = tensor.insert_slice %[[for]]#1 into %[[insert2]][%{{.*}}] [5] [1]
+
+ // CHECK: return %[[insert1]]
+ return %0 : tensor<?xf32>
+}
+
+// -----
+
+// Test: [0, %sz1) and [10, 10 + %sz2) may overlap, so neither pair is hoisted.
+// CHECK-LABEL: func @do_not_hoist_overlapping_subsets(
+func.func @do_not_hoist_overlapping_subsets(%arg: tensor<?xf32>) -> tensor<?xf32> {
+  %lb = "test.foo"() : () -> (index)
+  %ub = "test.foo"() : () -> (index)
+  %step = "test.foo"() : () -> (index)
+  %sz1 = "test.foo"() : () -> (index)
+  %sz2 = "test.foo"() : () -> (index)
+
+  // CHECK: scf.for
+  %0 = scf.for %iv = %lb to %ub step %step iter_args(%t = %arg) -> (tensor<?xf32>) {
+    // These two slices are potentially overlapping. Do not hoist.
+    // CHECK: tensor.extract_slice
+    // CHECK: tensor.extract_slice
+    %1 = tensor.extract_slice %t[0][%sz1][1] : tensor<?xf32> to tensor<?xf32>
+    %2 = tensor.extract_slice %t[10][%sz2][1] : tensor<?xf32> to tensor<?xf32>
+    // CHECK: "test.foo"
+    // CHECK: "test.foo"
+    %foo1 = "test.foo"(%1) : (tensor<?xf32>) -> (tensor<?xf32>)
+    %foo2 = "test.foo"(%2) : (tensor<?xf32>) -> (tensor<?xf32>)
+    // Each value is written back to the slice it was read from, so the slice
+    // sizes agree (%sz1 and %sz2, respectively).
+    // CHECK: tensor.insert_slice
+    // CHECK: tensor.insert_slice
+    %5 = tensor.insert_slice %foo1 into %t[0][%sz1][1] : tensor<?xf32> into tensor<?xf32>
+    %6 = tensor.insert_slice %foo2 into %5[10][%sz2][1] : tensor<?xf32> into tensor<?xf32>
+    // CHECK: scf.yield
+    scf.yield %6 : tensor<?xf32>
+  }
+
+  return %0 : tensor<?xf32>
+}
+
+// -----
+
+// Test: each iter_arg is processed independently; one extract/insert pair per
+// iter_arg is hoisted, so the loop ends up with four iter_args. Note: the
+// inserted values are crossed between the iter_args (%foo2 is inserted into
+// %t1, %foo1 into %t2).
+// CHECK-LABEL: func @multiple_yields(
+// CHECK-SAME: %[[arg:.*]]: tensor<?xf32>
+func.func @multiple_yields(%arg: tensor<?xf32>) -> (tensor<?xf32>, tensor<?xf32>) {
+ %lb = "test.foo"() : () -> (index)
+ %ub = "test.foo"() : () -> (index)
+ %step = "test.foo"() : () -> (index)
+
+ // CHECK: %[[extract1:.*]] = tensor.extract_slice
+ // CHECK: %[[extract2:.*]] = tensor.extract_slice
+ // CHECK: scf.for {{.*}} iter_args(%{{.*}} = %[[arg]], %{{.*}} = %[[arg]], %{{.*}} = %[[extract1]], %{{.*}} = %[[extract2]])
+ %0:2 = scf.for %iv = %lb to %ub step %step iter_args(%t1 = %arg, %t2 = %arg)
+ -> (tensor<?xf32>, tensor<?xf32>) {
+ %1 = tensor.extract_slice %t1[0][5][1] : tensor<?xf32> to tensor<5xf32>
+ %2 = tensor.extract_slice %t2[5][5][1] : tensor<?xf32> to tensor<5xf32>
+ // CHECK: "test.foo"
+ // CHECK: "test.foo"
+ %foo1 = "test.foo"(%1) : (tensor<5xf32>) -> (tensor<5xf32>)
+ %foo2 = "test.foo"(%2) : (tensor<5xf32>) -> (tensor<5xf32>)
+ %5 = tensor.insert_slice %foo2 into %t1[0][5][1] : tensor<5xf32> into tensor<?xf32>
+ %6 = tensor.insert_slice %foo1 into %t2[5][5][1] : tensor<5xf32> into tensor<?xf32>
+ // CHECK: scf.yield
+ scf.yield %5, %6 : tensor<?xf32>, tensor<?xf32>
+ }
+ // CHECK: tensor.insert_slice
+ // CHECK: tensor.insert_slice
+
+ return %0#0, %0#1 : tensor<?xf32>, tensor<?xf32>
+}
+
+// -----
+
+// Test: the yielded values are swapped relative to the iter_args they were
+// derived from, so the use-def chain of each iter_arg does not end in its own
+// yield operand and nothing is hoisted.
+// CHECK-LABEL: func @do_not_hoist_swapping_yields(
+func.func @do_not_hoist_swapping_yields(%arg: tensor<?xf32>) -> (tensor<?xf32>, tensor<?xf32>) {
+ %lb = "test.foo"() : () -> (index)
+ %ub = "test.foo"() : () -> (index)
+ %step = "test.foo"() : () -> (index)
+
+ // CHECK: scf.for
+ %0:2 = scf.for %iv = %lb to %ub step %step iter_args(%t1 = %arg, %t2 = %arg)
+ -> (tensor<?xf32>, tensor<?xf32>) {
+ // CHECK: tensor.extract_slice
+ // CHECK: tensor.extract_slice
+ %1 = tensor.extract_slice %t1[0][5][1] : tensor<?xf32> to tensor<5xf32>
+ %2 = tensor.extract_slice %t2[5][5][1] : tensor<?xf32> to tensor<5xf32>
+ // CHECK: "test.foo"
+ // CHECK: "test.foo"
+ %foo1 = "test.foo"(%1) : (tensor<5xf32>) -> (tensor<5xf32>)
+ %foo2 = "test.foo"(%2) : (tensor<5xf32>) -> (tensor<5xf32>)
+ // CHECK: tensor.insert_slice
+ // CHECK: tensor.insert_slice
+ %5 = tensor.insert_slice %foo2 into %t1[0][5][1] : tensor<5xf32> into tensor<?xf32>
+ %6 = tensor.insert_slice %foo1 into %t2[5][5][1] : tensor<5xf32> into tensor<?xf32>
+ // Swapping yields: do not hoist.
+ // CHECK: scf.yield
+ scf.yield %6, %5 : tensor<?xf32>, tensor<?xf32>
+ }
+
+ return %0#0, %0#1 : tensor<?xf32>, tensor<?xf32>
+}
+
+// -----
+
+// Test: a non-subset user of the iter_arg prevents all hoisting on that
+// iter_arg.
+// CHECK-LABEL: func @non_subset_op(
+func.func @non_subset_op(%arg: tensor<?xf32>) -> tensor<?xf32> {
+ %lb = "test.foo"() : () -> (index)
+ %ub = "test.foo"() : () -> (index)
+ %step = "test.foo"() : () -> (index)
+
+ // CHECK: scf.for
+ %0 = scf.for %iv = %lb to %ub step %step iter_args(%t = %arg) -> (tensor<?xf32>) {
+ // If any value along the use-def chain from the region iter_arg to the
+ // terminator is used by a non-subset op, no subset op along that chain can
+ // be hoisted. That is because it is unknown which parts of the value are
+ // accessed by the non-subset op.
+ // CHECK: "test.non_subset_op"
+ "test.non_subset_op"(%t) : (tensor<?xf32>) -> ()
+ // CHECK: tensor.extract_slice
+ %1 = tensor.extract_slice %t[0][5][1] : tensor<?xf32> to tensor<5xf32>
+ // CHECK: "test.foo"
+ %2 = "test.foo"(%1) : (tensor<5xf32>) -> (tensor<5xf32>)
+ // CHECK: tensor.insert_slice
+ %3 = tensor.insert_slice %2 into %t[0][5][1] : tensor<5xf32> into tensor<?xf32>
+ // CHECK: scf.yield
+ scf.yield %3 : tensor<?xf32>
+ }
+
+ return %0 : tensor<?xf32>
+}
+
+// -----
+
+// Test: subset ops whose offsets depend on the induction variable are
+// loop-variant and stay inside the loop.
+// CHECK-LABEL: func @non_loop_invariant_subset_op(
+func.func @non_loop_invariant_subset_op(%arg: tensor<?xf32>) -> tensor<?xf32> {
+ %lb = "test.foo"() : () -> (index)
+ %ub = "test.foo"() : () -> (index)
+ %step = "test.foo"() : () -> (index)
+
+ // CHECK: scf.for
+ %0 = scf.for %iv = %lb to %ub step %step iter_args(%t = %arg) -> (tensor<?xf32>) {
+ // Subset ops that are not loop-invariant cannot be hoisted.
+ // CHECK: tensor.extract_slice
+ %1 = tensor.extract_slice %t[%iv][5][1] : tensor<?xf32> to tensor<5xf32>
+ // CHECK: "test.foo"
+ %2 = "test.foo"(%1) : (tensor<5xf32>) -> (tensor<5xf32>)
+ // CHECK: tensor.insert_slice
+ %3 = tensor.insert_slice %2 into %t[%iv][5][1] : tensor<5xf32> into tensor<?xf32>
+ // CHECK: scf.yield
+ scf.yield %3 : tensor<?xf32>
+ }
+
+ return %0 : tensor<?xf32>
+}
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index add146bf084d321..a9aa1d848fce4f5 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -7035,6 +7035,7 @@ cc_library(
":MemorySlotInterfaces",
":Rewrite",
":SideEffectInterfaces",
+ ":SubsetOpInterface",
":Support",
":TransformsPassIncGen",
":config",
>From 959e9254c112c339993b96c4ff2152bbf1388257 Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Mon, 30 Oct 2023 15:06:41 +0900
Subject: [PATCH 3/6] [mlir][WIP] Bypass analysis for loops
---
.../Utils/LoopInvariantCodeMotionUtils.cpp | 72 ++++++++++++++-----
.../loop-invariant-subset-hoisting.mlir | 35 +++++++++
2 files changed, 91 insertions(+), 16 deletions(-)
diff --git a/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp b/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
index f39f8ec3598f322..bb4e6dc62b9c935 100644
--- a/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
@@ -120,8 +120,10 @@ namespace {
class MatchingSubsets {
public:
/// Insert a subset op.
- void insert(SubsetOpInterface op) {
+ void insert(SubsetOpInterface op, bool collectHoistableOps = true) {
allSubsetOps.push_back(op);
+ if (!collectHoistableOps)
+ return;
if (auto extractionOp =
dyn_cast<SubsetExtractionOpInterface>(op.getOperation()))
insertExtractionOp(extractionOp);
@@ -148,6 +150,15 @@ class MatchingSubsets {
});
}
+ /// Populate subset ops starting from the given region iter_arg. Return
+ /// "failure" if non-subset ops are found along the path to the loop yielding
+ /// op or if there is no single path to the tied yielded operand. If
+ /// `collectHoistableOps` is set to "false", subset ops are gathered
+ /// throughout the traversal, but not enumerated by `getHoistableSubsetOps`.
+ LogicalResult populateSubsetOpsAtIterArg(LoopLikeOpInterface loopLike,
+ BlockArgument iterArg,
+ bool collectHoistableOps = true);
+
private:
/// Helper function for equivalence of tensor values. Since only insertion
/// subset ops (that are also destination style ops) are followed when
@@ -225,18 +236,12 @@ static OpOperand *getSingleTerminatorUse(Value value) {
return nullptr;
}
-/// Hoist all subset ops that operate on the idx-th region iter_arg of the given
-/// loop-like op and index into loop-invariant subset locations. Return the
-/// newly created loop op (that has extra iter_args) or the original loop op if
-/// nothing was hoisted.
-static LoopLikeOpInterface hoistSubsetAtIterArg(LoopLikeOpInterface loopLike,
- BlockArgument iterArg) {
- IRRewriter rewriter(loopLike.getContext());
+LogicalResult
+MatchingSubsets::populateSubsetOpsAtIterArg(LoopLikeOpInterface loopLike,
+ BlockArgument iterArg,
+ bool collectHoistableOps) {
assert(iterArg.getOwner()->getParentOp() == loopLike && "invalid iter_arg");
- auto it = llvm::find(loopLike.getRegionIterArgs(), iterArg);
- int64_t iterArgIdx = std::distance(loopLike.getRegionIterArgs().begin(), it);
Value value = iterArg;
- MatchingSubsets subsets;
// Traverse use-def chain. Subset ops can be hoisted only if all ops along the
// use-def chain starting from the region iter_arg are subset extraction or
@@ -249,21 +254,39 @@ static LoopLikeOpInterface hoistSubsetAtIterArg(LoopLikeOpInterface loopLike,
Value nextValue = {};
for (OpOperand &use : value.getUses()) {
+      if (auto nestedLoop = dyn_cast<LoopLikeOpInterface>(use.getOwner())) {
+        // Subset ops in nested loops are collected to check if there are only
+        // disjoint subset ops, but such subset ops are not subject to hoisting.
+        // To hoist subset ops from nested loops, the hoisting transformation
+        // should be run on the nested loop.
+        auto nestedIterArg = nestedLoop.getTiedLoopRegionIterArg(&use);
+        if (!nestedIterArg)
+          return failure();
+        // There must be a single use-def chain from the region iter_arg to the
+        // terminator, exactly as for insertion ops below. If another user has
+        // already continued the chain, this use starts a branch: bail out.
+        // Without this check, `nextValue` would be overwritten and the chain
+        // through a previously visited insertion op would be silently dropped
+        // (its result's users would never be inspected).
+        if (nextValue)
+          return failure();
+        // Note: `populateSubsetOpsAtIterArg` fails if there is no single SSA
+        // use-def chain starting at `nestedIterArg` and terminating in the
+        // tied, yielding operand.
+        if (failed(populateSubsetOpsAtIterArg(nestedLoop, nestedIterArg,
+                                              /*collectHoistableOps=*/false)))
+          return failure();
+        nextValue = nestedLoop.getTiedLoopResult(&use);
+        continue;
+      }
+
auto subsetOp = dyn_cast<SubsetOpInterface>(use.getOwner());
if (!subsetOp)
- return loopLike;
- subsets.insert(subsetOp);
+ return failure();
+ insert(subsetOp);
if (auto insertionOp =
dyn_cast<SubsetInsertionOpInterface>(use.getOwner())) {
// The value must be used as a destination. (In case of a source, the
// entire tensor would be read, which would prevent any hoisting.)
if (&use != &insertionOp.getDestinationOperand())
- return loopLike;
+ return failure();
// There must be a single use-def chain from the region iter_arg to the
// terminator. I.e., only one insertion op. Branches are not supported.
if (nextValue)
- return loopLike;
+ return failure();
nextValue = insertionOp.getUpdatedDestination();
}
}
@@ -271,7 +294,7 @@ static LoopLikeOpInterface hoistSubsetAtIterArg(LoopLikeOpInterface loopLike,
// Nothing can be hoisted if the chain does not continue with loop yielding
// op or a subset insertion op.
if (!nextValue)
- return loopLike;
+ return failure();
value = nextValue;
}
@@ -279,6 +302,23 @@ static LoopLikeOpInterface hoistSubsetAtIterArg(LoopLikeOpInterface loopLike,
// loop and the yielded value is the `idx`-th operand. (I.e., there is no
// swapping yield.)
if (loopLike.getTiedLoopYieldedValue(iterArg) != yieldedOperand)
+ return failure();
+
+ return success();
+}
+
+/// Hoist all subset ops that operate on the idx-th region iter_arg of the given
+/// loop-like op and index into loop-invariant subset locations. Return the
+/// newly created loop op (that has extra iter_args) or the original loop op if
+/// nothing was hoisted.
+static LoopLikeOpInterface hoistSubsetAtIterArg(LoopLikeOpInterface loopLike,
+ BlockArgument iterArg) {
+ assert(iterArg.getOwner()->getParentOp() == loopLike && "invalid iter_arg");
+ auto it = llvm::find(loopLike.getRegionIterArgs(), iterArg);
+ int64_t iterArgIdx = std::distance(loopLike.getRegionIterArgs().begin(), it);
+ IRRewriter rewriter(loopLike.getContext());
+ MatchingSubsets subsets;
+ if (failed(subsets.populateSubsetOpsAtIterArg(loopLike, iterArg)))
return loopLike;
// Hoist all matching extraction-insertion pairs one-by-one.
diff --git a/mlir/test/Transforms/loop-invariant-subset-hoisting.mlir b/mlir/test/Transforms/loop-invariant-subset-hoisting.mlir
index 5cded4c99182c14..b9161f4e20d1927 100644
--- a/mlir/test/Transforms/loop-invariant-subset-hoisting.mlir
+++ b/mlir/test/Transforms/loop-invariant-subset-hoisting.mlir
@@ -235,3 +235,38 @@ func.func @non_loop_invariant_subset_op(%arg: tensor<?xf32>) -> tensor<?xf32> {
return %0 : tensor<?xf32>
}
+
+// -----
+
+// Test: the inner loop's pair on [5, 10) is hoisted into the outer loop first
+// (loops are processed innermost first). The outer loop then hoists both its
+// own pair on [0, 5) and the pair that came out of the inner loop. Subset ops
+// inside the nested loop take part in the outer loop's disjointness check.
+// CHECK-LABEL: func @nested_hoisting(
+// CHECK-SAME: %[[arg:.*]]: tensor<?xf32>
+func.func @nested_hoisting(%arg: tensor<?xf32>) -> tensor<?xf32> {
+ %lb = "test.foo"() : () -> (index)
+ %ub = "test.foo"() : () -> (index)
+ %step = "test.foo"() : () -> (index)
+
+ // CHECK: %[[extract:.*]] = tensor.extract_slice %[[arg]][0] [5] [1]
+ // CHECK: %[[extract2:.*]] = tensor.extract_slice %[[arg]][5] [5] [1]
+ // CHECK: %[[for:.*]]:3 = scf.for {{.*}} iter_args(%[[t:.*]] = %[[arg]], %[[hoisted:.*]] = %[[extract]], %[[hoisted2:.*]] = %[[extract2]])
+ %0 = scf.for %iv = %lb to %ub step %step iter_args(%t = %arg) -> (tensor<?xf32>) {
+ %1 = tensor.extract_slice %t[0][5][1] : tensor<?xf32> to tensor<5xf32>
+ // CHECK: %[[foo:.*]] = "test.foo"(%[[hoisted]])
+ %2 = "test.foo"(%1) : (tensor<5xf32>) -> (tensor<5xf32>)
+ %3 = tensor.insert_slice %2 into %t[0][5][1] : tensor<5xf32> into tensor<?xf32>
+ // CHECK: %[[for2:.*]]:2 = {{.*}} iter_args(%[[t2:.*]] = %[[t]], %[[hoisted2_nested:.*]] = %[[hoisted2]])
+ %4 = scf.for %iv2 = %lb to %ub step %step iter_args(%t2 = %3) -> (tensor<?xf32>) {
+ %5 = tensor.extract_slice %t2[5][5][1] : tensor<?xf32> to tensor<5xf32>
+ // CHECK: %[[foo2:.*]] = "test.foo"(%[[hoisted2_nested]])
+ %6 = "test.foo"(%5) : (tensor<5xf32>) -> (tensor<5xf32>)
+ %7 = tensor.insert_slice %6 into %t2[5][5][1] : tensor<5xf32> into tensor<?xf32>
+ // CHECK: scf.yield %[[t2]], %[[foo2]]
+ scf.yield %7 : tensor<?xf32>
+ }
+ // CHECK: scf.yield %[[for2]]#0, %[[foo]], %[[for2]]#1
+ scf.yield %4 : tensor<?xf32>
+ }
+ // CHECK: %[[insert:.*]] = tensor.insert_slice %[[for]]#2 into %[[for]]#0[5] [5] [1]
+ // CHECK: %[[insert2:.*]] = tensor.insert_slice %[[for]]#1 into %[[insert]][0] [5] [1]
+ // CHECK: return %[[insert2]]
+ return %0 : tensor<?xf32>
+}
>From 387be39420a1830683d3b6291d95b0beeff5a5b4 Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Mon, 30 Oct 2023 15:17:02 +0900
Subject: [PATCH 4/6] [mlir][Interfaces] `SubsetOpInterface`: Add helpers for
hyperrectangular subsets
---
.../Bufferization/IR/BufferizationOps.td | 3 +-
mlir/include/mlir/IR/OpDefinition.h | 5 +
.../mlir/Interfaces/SubsetOpInterface.h | 16 ++-
.../mlir/Interfaces/SubsetOpInterface.td | 53 +++++++--
.../mlir/Interfaces/ValueBoundsOpInterface.h | 53 ++++++++-
.../SubsetInsertionOpInterfaceImpl.cpp | 82 +------------
mlir/lib/Interfaces/CMakeLists.txt | 2 +
mlir/lib/Interfaces/SubsetOpInterface.cpp | 49 ++++++++
.../lib/Interfaces/ValueBoundsOpInterface.cpp | 109 ++++++++++++++++--
.../loop-invariant-subset-hoisting.mlir | 9 +-
.../llvm-project-overlay/mlir/BUILD.bazel | 1 +
11 files changed, 285 insertions(+), 97 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index e6b6d052df96a8c..9dc6afcaab31c86 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -220,7 +220,8 @@ def Bufferization_MaterializeInDestinationOp
AllElementTypesMatch<["source", "dest"]>,
BufferizableOpInterface, DestinationStyleOpInterface,
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
- DeclareOpInterfaceMethods<SubsetOpInterface>,
+ DeclareOpInterfaceMethods<SubsetOpInterface,
+ ["operatesOnEquivalentSubset", "operatesOnDisjointSubset"]>,
DeclareOpInterfaceMethods<SubsetInsertionOpInterface,
["getSourceOperand", "getValuesNeededToBuildSubsetExtraction",
"buildSubsetExtraction", "isEquivalentSubset"]>,
diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index 8ab37c1d51d6b6c..bd68c27445744e3 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -268,6 +268,11 @@ class OpFoldResult : public PointerUnion<Attribute, Value> {
public:
void dump() const { llvm::errs() << *this << "\n"; }
+ /// Return the MLIRContext of the held Attribute or Value.
+ MLIRContext *getContext() const {
+ return is<Attribute>() ? get<Attribute>().getContext()
+ : get<Value>().getContext();
+ }
};
// Temporarily exit the MLIR namespace to add casting support as later code in
diff --git a/mlir/include/mlir/Interfaces/SubsetOpInterface.h b/mlir/include/mlir/Interfaces/SubsetOpInterface.h
index 049cf2456a9c842..98c33ec65012fca 100644
--- a/mlir/include/mlir/Interfaces/SubsetOpInterface.h
+++ b/mlir/include/mlir/Interfaces/SubsetOpInterface.h
@@ -10,6 +10,7 @@
#define MLIR_INTERFACES_SUBSETOPINTERFACE_H_
#include "mlir/IR/OpDefinition.h"
+#include "mlir/Interfaces/ValueBoundsOpInterface.h"
namespace mlir {
class SubsetOpInterface;
@@ -27,10 +28,23 @@ OpOperand &defaultGetDestinationOperand(Operation *op);
/// `DestinationStyleOpInterface`.
OpResult defaultGetUpdatedDestination(Operation *op);
-/// Default implementation of `isEquivalentSubset`.
+/// Default implementation of `SubsetInsertionOpInterface::isEquivalentSubset`.
bool defaultIsEquivalentSubset(Operation *op, Value candidate,
function_ref<bool(Value, Value)> equivalenceFn);
+/// Default implementation of `SubsetOpInterface::operatesOnEquivalentSubset`.
+bool defaultOperatesOnEquivalentSubset(
+ Operation *op, SubsetOpInterface candidate,
+ function_ref<bool(Value, Value)> equivalenceFn);
+
+/// Default implementation of `SubsetOpInterface::operatesOnDisjointSubset`.
+bool defaultOperatesOnDisjointSubset(
+ Operation *op, SubsetOpInterface candidate,
+ function_ref<bool(Value, Value)> equivalenceFn);
+
+/// Return the container that the given subset op is operating on.
+Value getTensorContainer(Operation *op);
+
/// Verify `SubsetOpInterface`.
LogicalResult verifySubsetOpInterface(SubsetOpInterface op);
diff --git a/mlir/include/mlir/Interfaces/SubsetOpInterface.td b/mlir/include/mlir/Interfaces/SubsetOpInterface.td
index 07d62b8319c2961..a00e398618a0118 100644
--- a/mlir/include/mlir/Interfaces/SubsetOpInterface.td
+++ b/mlir/include/mlir/Interfaces/SubsetOpInterface.td
@@ -32,11 +32,6 @@ def SubsetOpInterface : OpInterface<"SubsetOpInterface"> {
hyperrectangular slice.
- `tensor.gather/scatter` describe the subset as list of indices. (Not
implemented yet.)
-
- Note: This interface does not expose any interface methods to get a
- description of the accessed subset. That is because there is currently no
- efficient way to describe arbitrary subsets. This interface merely provides
- interface methods to check if two subsets are equivalent or disjoint.
}];
let cppNamespace = "::mlir";
@@ -46,24 +41,59 @@ def SubsetOpInterface : OpInterface<"SubsetOpInterface"> {
Return "true" if this op and the given candidate subset op operate on
an equivalent subset. Return "false" is the two subsets are disjoint
or cannot be proven to be equivalent.
+
+ This interface method does not have to be implemented if
+ `getAccessedHyperrectangularSlice` is implemented.
}],
/*retType=*/"bool",
/*methodName=*/"operatesOnEquivalentSubset",
/*args=*/(ins
"::mlir::SubsetOpInterface":$candidate,
- "::llvm::function_ref<bool(Value, Value)>":$equivalenceFn)
+ "::llvm::function_ref<bool(Value, Value)>":$equivalenceFn),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return ::mlir::detail::defaultOperatesOnEquivalentSubset(
+ $_op, candidate, equivalenceFn);
+ }]
>,
InterfaceMethod<
/*desc=*/[{
Return "true" if this op and the given candidate subset op operate on
disjoint subsets. Return "false" is the two subsets are equivalent,
overlapping or cannot be proven to be disjoint.
+
+ This interface method does not have to be implemented if
+ `getAccessedHyperrectangularSlice` is implemented.
}],
/*retType=*/"bool",
/*methodName=*/"operatesOnDisjointSubset",
/*args=*/(ins
"::mlir::SubsetOpInterface":$candidate,
- "::llvm::function_ref<bool(Value, Value)>":$equivalenceFn)
+ "::llvm::function_ref<bool(Value, Value)>":$equivalenceFn),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return ::mlir::detail::defaultOperatesOnDisjointSubset(
+ $_op, candidate, equivalenceFn);
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ If this op operates on a hyperrectangular subset, return a
+ description of the subset in terms of offsets, sizes and strides.
+ Otherwise, return "failure".
+
+ This interface method is a convenience method for the most common case
+ of hyperrectangular subset ops. It is optional. If it is implemented,
+ `operatesOnEquivalentSubset` and `operatesOnDisjointSubset` do not
+ have to be implemented.
+ }],
+ /*retType=*/"::mlir::FailureOr<::mlir::HyperrectangularSlice>",
+ /*methodName=*/"getAccessedHyperrectangularSlice",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return ::mlir::failure();
+ }]
>,
];
@@ -71,6 +101,15 @@ def SubsetOpInterface : OpInterface<"SubsetOpInterface"> {
return ::mlir::detail::verifySubsetOpInterface(
::mlir::cast<::mlir::SubsetOpInterface>($_op));
}];
+
+ let extraClassDeclaration = [{
+ /// Return the container that this operation is operating on. In case of an
+ /// extraction op, the container is the source tensor. In case of an
+ /// insertion op, the container is the destination tensor.
+ Value getTensorContainer() {
+ return ::mlir::detail::getTensorContainer(getOperation());
+ }
+ }];
}
def SubsetExtractionOpInterface
diff --git a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
index 8e2986a2d1f05f6..28dadfb9ecf8688 100644
--- a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
+++ b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
@@ -21,6 +21,31 @@
namespace mlir {
class OffsetSizeAndStrideOpInterface;
+/// A hyperrectangular slice, represented as a list of offsets, sizes and
+/// strides.
+class HyperrectangularSlice {
+public:
+ HyperrectangularSlice(ArrayRef<OpFoldResult> offsets,
+ ArrayRef<OpFoldResult> sizes,
+ ArrayRef<OpFoldResult> strides);
+
+ /// Create a hyperrectangular slice with unit strides.
+ HyperrectangularSlice(ArrayRef<OpFoldResult> offsets,
+ ArrayRef<OpFoldResult> sizes);
+
+ /// Infer a hyperrectangular slice from `OffsetSizeAndStrideOpInterface`.
+ HyperrectangularSlice(OffsetSizeAndStrideOpInterface op);
+
+ ArrayRef<OpFoldResult> getMixedOffsets() const { return mixedOffsets; }
+ ArrayRef<OpFoldResult> getMixedSizes() const { return mixedSizes; }
+ ArrayRef<OpFoldResult> getMixedStrides() const { return mixedStrides; }
+
+private:
+ SmallVector<OpFoldResult> mixedOffsets;
+ SmallVector<OpFoldResult> mixedSizes;
+ SmallVector<OpFoldResult> mixedStrides;
+};
+
using ValueDimList = SmallVector<std::pair<Value, std::optional<int64_t>>>;
/// A helper class to be used with `ValueBoundsOpInterface`. This class stores a
@@ -182,12 +207,34 @@ class ValueBoundsConstraintSet {
std::optional<int64_t> dim1 = std::nullopt,
std::optional<int64_t> dim2 = std::nullopt);
+ /// Compute whether the given values/attributes are equal. Return "failure" if
+ /// equality could not be determined.
+ ///
+ /// `ofr1`/`ofr2` must be of index type.
+ static FailureOr<bool> areEqual(OpFoldResult ofr1, OpFoldResult ofr2);
+
/// Return "true" if the given slices are guaranteed to be overlapping.
/// Return "false" if the given slices are guaranteed to be non-overlapping.
/// Return "failure" if unknown.
- static FailureOr<bool>
- areOverlappingSlices(OffsetSizeAndStrideOpInterface slice1,
- OffsetSizeAndStrideOpInterface slice2);
+ ///
+ /// Slices are overlapping if for all dimensions:
+ /// * offset1 + size1 * stride1 > offset2
+ /// * and offset2 + size2 * stride2 > offset1
+ ///
+ /// Slices are non-overlapping if the above constraint is not satisfied for
+ /// at least one dimension.
+ static FailureOr<bool> areOverlappingSlices(MLIRContext *ctx,
+ HyperrectangularSlice slice1,
+ HyperrectangularSlice slice2);
+
+ /// Return "true" if the given slices are guaranteed to be equivalent.
+ /// Return "false" if the given slices are guaranteed to be non-equivalent.
+ /// Return "failure" if unknown.
+ ///
+ /// Slices are equivalent if their offsets, sizes and strides are equal.
+ static FailureOr<bool> areEquivalentSlices(MLIRContext *ctx,
+ HyperrectangularSlice slice1,
+ HyperrectangularSlice slice2);
/// Add a bound for the given index-typed value or shaped value. This function
/// returns a builder that adds the bound.
diff --git a/mlir/lib/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.cpp
index 7a1bafd409eea60..d50d7c62b789c8c 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.cpp
@@ -17,73 +17,12 @@ using namespace mlir::tensor;
namespace {
-/// Return the tensor that the given subset op operates on.
-Value getContainerOperand(SubsetOpInterface op) {
- if (auto extractionOp =
- dyn_cast<SubsetExtractionOpInterface>(op.getOperation()))
- return extractionOp.getSourceOperand().get();
- if (auto insertionOp =
- dyn_cast<SubsetInsertionOpInterface>(op.getOperation()))
- return insertionOp.getDestinationOperand().get();
- llvm_unreachable("expected SubsetExtraction/InsertionOpInterface");
-}
-
-/// Return "true" if the two ops operate on an equivalent subset.
-/// `equivalenceFn` is used to determine equivalence of tensors. Return "false"
-/// if the two ops operate non-equivalent subsets, if equivalence cannot be
-/// determined or if `op1` is not a subset op.
-template <typename OpTy>
-bool operateOnEquivalentSubsets(
- OpTy op1, SubsetOpInterface op2,
- function_ref<bool(Value, Value)> equivalenceFn) {
- auto offsetsSizesAndStrides2 =
- dyn_cast<OffsetSizeAndStrideOpInterface>(op2.getOperation());
- if (!offsetsSizesAndStrides2)
- return false;
- if (!sameOffsetsSizesAndStrides(op1, offsetsSizesAndStrides2,
- isEqualConstantIntOrValue))
- return false;
- return equivalenceFn(
- getContainerOperand(cast<SubsetOpInterface>(op1.getOperation())),
- getContainerOperand(op2));
-}
-
-/// Return "true" if the two ops operate on a disjoint subsets.
-/// `equivalenceFn` is used to determine equivalence of tensors. Return "false"
-/// if the two ops operate non-disjoint subsets, if disjointness cannot be
-/// determined or if `op1` is not a subset op.
-template <typename OpTy>
-bool operateOnDisjointSubsets(OpTy op1, SubsetOpInterface op2,
- function_ref<bool(Value, Value)> equivalenceFn) {
- auto offsetsSizesAndStrides2 =
- dyn_cast<OffsetSizeAndStrideOpInterface>(op2.getOperation());
- if (!offsetsSizesAndStrides2)
- return false;
- FailureOr<bool> overlappingSlices =
- ValueBoundsConstraintSet::areOverlappingSlices(op1,
- offsetsSizesAndStrides2);
- if (failed(overlappingSlices) || *overlappingSlices)
- return false;
- return equivalenceFn(
- getContainerOperand(cast<SubsetOpInterface>(op1.getOperation())),
- getContainerOperand(op2));
-}
-
struct ExtractSliceOpSubsetOpInterface
: public SubsetOpInterface::ExternalModel<ExtractSliceOpSubsetOpInterface,
tensor::ExtractSliceOp> {
- bool operatesOnEquivalentSubset(
- Operation *op, SubsetOpInterface candidate,
- function_ref<bool(Value, Value)> equivalenceFn) const {
- auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
- return operateOnEquivalentSubsets(extractSliceOp, candidate, equivalenceFn);
- }
-
- bool operatesOnDisjointSubset(
- Operation *op, SubsetOpInterface candidate,
- function_ref<bool(Value, Value)> equivalenceFn) const {
- auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
- return operateOnDisjointSubsets(extractSliceOp, candidate, equivalenceFn);
+ FailureOr<HyperrectangularSlice>
+ getAccessedHyperrectangularSlice(Operation *op) const {
+ return HyperrectangularSlice(cast<OffsetSizeAndStrideOpInterface>(op));
}
};
@@ -99,18 +38,9 @@ template <typename OpTy>
struct InsertSliceLikeOpSubsetOpInterface
: public SubsetOpInterface::ExternalModel<
InsertSliceLikeOpSubsetOpInterface<OpTy>, OpTy> {
- bool operatesOnEquivalentSubset(
- Operation *op, SubsetOpInterface candidate,
- function_ref<bool(Value, Value)> equivalenceFn) const {
- auto insertSliceOp = cast<OpTy>(op);
- return operateOnEquivalentSubsets(insertSliceOp, candidate, equivalenceFn);
- }
-
- bool operatesOnDisjointSubset(
- Operation *op, SubsetOpInterface candidate,
- function_ref<bool(Value, Value)> equivalenceFn) const {
- auto insertSliceOp = cast<OpTy>(op);
- return operateOnDisjointSubsets(insertSliceOp, candidate, equivalenceFn);
+ FailureOr<HyperrectangularSlice>
+ getAccessedHyperrectangularSlice(Operation *op) const {
+ return HyperrectangularSlice(cast<OffsetSizeAndStrideOpInterface>(op));
}
};
diff --git a/mlir/lib/Interfaces/CMakeLists.txt b/mlir/lib/Interfaces/CMakeLists.txt
index 2652d261f480ba4..e7c76e70ed6b5d7 100644
--- a/mlir/lib/Interfaces/CMakeLists.txt
+++ b/mlir/lib/Interfaces/CMakeLists.txt
@@ -93,10 +93,12 @@ add_mlir_library(MLIRSubsetOpInterface
DEPENDS
MLIRDestinationStyleOpInterface
MLIRSubsetOpInterfaceIncGen
+ MLIRValueBoundsOpInterface
LINK_LIBS PUBLIC
MLIRDestinationStyleOpInterface
MLIRIR
+ MLIRValueBoundsOpInterface
)
add_mlir_interface_library(TilingInterface)
diff --git a/mlir/lib/Interfaces/SubsetOpInterface.cpp b/mlir/lib/Interfaces/SubsetOpInterface.cpp
index 7245ab20c499e20..d0bdadf500f6f6c 100644
--- a/mlir/lib/Interfaces/SubsetOpInterface.cpp
+++ b/mlir/lib/Interfaces/SubsetOpInterface.cpp
@@ -8,6 +8,7 @@
#include "mlir/Interfaces/SubsetOpInterface.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
+#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "mlir/Interfaces/SubsetOpInterface.cpp.inc"
@@ -40,6 +41,54 @@ bool detail::defaultIsEquivalentSubset(
candidate.getDefiningOp<SubsetOpInterface>(), equivalenceFn);
}
+bool detail::defaultOperatesOnEquivalentSubset(
+ Operation *op, SubsetOpInterface candidate,
+ function_ref<bool(Value, Value)> equivalenceFn) {
+ auto subsetOp = cast<SubsetOpInterface>(op);
+ FailureOr<HyperrectangularSlice> slice =
+ subsetOp.getAccessedHyperrectangularSlice();
+ assert(succeeded(slice) &&
+ "operatesOnEquivalentSubset must be implemented if "
+ "getAccessedHyperrectangularSlice is not implemented");
+ FailureOr<HyperrectangularSlice> otherSlice =
+ candidate.getAccessedHyperrectangularSlice();
+ if (failed(otherSlice))
+ return false;
+ if (!equivalenceFn(subsetOp.getTensorContainer(),
+ candidate.getTensorContainer()))
+ return false;
+ FailureOr<bool> equivalent = ValueBoundsConstraintSet::areEquivalentSlices(
+ op->getContext(), *slice, *otherSlice);
+ return succeeded(equivalent) && *equivalent;
+}
+
+bool detail::defaultOperatesOnDisjointSubset(
+ Operation *op, SubsetOpInterface candidate,
+ function_ref<bool(Value, Value)> equivalenceFn) {
+ auto subsetOp = cast<SubsetOpInterface>(op);
+ FailureOr<HyperrectangularSlice> slice =
+ subsetOp.getAccessedHyperrectangularSlice();
+ assert(succeeded(slice) &&
+ "operatesOnDisjointSubset must be implemented if "
+ "getAccessedHyperrectangularSlice is not implemented");
+ FailureOr<HyperrectangularSlice> otherSlice =
+ candidate.getAccessedHyperrectangularSlice();
+ if (failed(otherSlice))
+ return false;
+ if (!equivalenceFn(subsetOp.getTensorContainer(),
+ candidate.getTensorContainer()))
+ return false;
+ FailureOr<bool> overlapping = ValueBoundsConstraintSet::areOverlappingSlices(
+ op->getContext(), *slice, *otherSlice);
+ return succeeded(overlapping) && !*overlapping;
+}
+
+Value detail::getTensorContainer(Operation *op) {
+ if (auto insertionOp = dyn_cast<::mlir::SubsetInsertionOpInterface>(op))
+ return insertionOp.getDestinationOperand().get();
+ return cast<::mlir::SubsetExtractionOpInterface>(op).getSourceOperand().get();
+}
+
LogicalResult detail::verifySubsetOpInterface(SubsetOpInterface op) {
if (!(isa<SubsetExtractionOpInterface>(op.getOperation()) ^
isa<SubsetInsertionOpInterface>(op.getOperation())))
diff --git a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
index f0c37c872e6d31d..62ba63402925e01 100644
--- a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
+++ b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
@@ -25,6 +25,32 @@ namespace mlir {
#include "mlir/Interfaces/ValueBoundsOpInterface.cpp.inc"
} // namespace mlir
+HyperrectangularSlice::HyperrectangularSlice(ArrayRef<OpFoldResult> offsets,
+ ArrayRef<OpFoldResult> sizes,
+ ArrayRef<OpFoldResult> strides)
+ : mixedOffsets(offsets), mixedSizes(sizes), mixedStrides(strides) {
+ assert(offsets.size() == sizes.size() &&
+ "expected same number of offsets, sizes, strides");
+ assert(offsets.size() == strides.size() &&
+ "expected same number of offsets, sizes, strides");
+}
+
+HyperrectangularSlice::HyperrectangularSlice(ArrayRef<OpFoldResult> offsets,
+ ArrayRef<OpFoldResult> sizes)
+ : mixedOffsets(offsets), mixedSizes(sizes) {
+ assert(offsets.size() == sizes.size() &&
+ "expected same number of offsets and sizes");
+ // Assume that all strides are 1.
+ if (offsets.empty())
+ return;
+ MLIRContext *ctx = offsets.front().getContext();
+ mixedStrides.append(offsets.size(), Builder(ctx).getIndexAttr(1));
+}
+
+HyperrectangularSlice::HyperrectangularSlice(OffsetSizeAndStrideOpInterface op)
+ : HyperrectangularSlice(op.getMixedOffsets(), op.getMixedSizes(),
+ op.getMixedStrides()) {}
+
/// If ofr is a constant integer or an IntegerAttr, return the integer.
static std::optional<int64_t> getConstantIntValue(OpFoldResult ofr) {
// Case 1: Check for Constant integer.
@@ -524,19 +550,44 @@ ValueBoundsConstraintSet::areEqual(Value value1, Value value2,
return *delta == 0;
}
-FailureOr<bool> ValueBoundsConstraintSet::areOverlappingSlices(
- OffsetSizeAndStrideOpInterface slice1,
- OffsetSizeAndStrideOpInterface slice2) {
- assert(slice1.getStaticOffsets().size() == slice1.getStaticOffsets().size() &&
+FailureOr<bool> ValueBoundsConstraintSet::areEqual(OpFoldResult ofr1,
+ OpFoldResult ofr2) {
+ Builder b(ofr1.getContext());
+ AffineMap map =
+ AffineMap::get(/*dimCount=*/0, /*symbolCount=*/2,
+ b.getAffineSymbolExpr(0) - b.getAffineSymbolExpr(1));
+ SmallVector<OpFoldResult> ofrOperands;
+ ofrOperands.push_back(ofr1);
+ ofrOperands.push_back(ofr2);
+ SmallVector<Value> valueOperands;
+ AffineMap foldedMap =
+ foldAttributesIntoMap(b, map, ofrOperands, valueOperands);
+ ValueDimList valueDims;
+ for (Value v : valueOperands) {
+ assert(v.getType().isIndex() && "expected index type");
+ valueDims.emplace_back(v, std::nullopt);
+ }
+ FailureOr<int64_t> delta =
+ computeConstantBound(presburger::BoundType::EQ, foldedMap, valueDims);
+ if (failed(delta))
+ return failure();
+ return *delta == 0;
+}
+
+FailureOr<bool>
+ValueBoundsConstraintSet::areOverlappingSlices(MLIRContext *ctx,
+ HyperrectangularSlice slice1,
+ HyperrectangularSlice slice2) {
+ assert(slice1.getMixedOffsets().size() == slice2.getMixedOffsets().size() &&
"expected slices of same rank");
- assert(slice1.getStaticSizes().size() == slice1.getStaticSizes().size() &&
+ assert(slice1.getMixedSizes().size() == slice2.getMixedSizes().size() &&
"expected slices of same rank");
- assert(slice1.getStaticStrides().size() == slice1.getStaticStrides().size() &&
+ assert(slice1.getMixedStrides().size() == slice2.getMixedStrides().size() &&
"expected slices of same rank");
- Builder b(slice1.getContext());
+ Builder b(ctx);
bool foundUnknownBound = false;
- for (int64_t i = 0, e = slice1.getStaticOffsets().size(); i < e; ++i) {
+ for (int64_t i = 0, e = slice1.getMixedOffsets().size(); i < e; ++i) {
AffineMap map =
AffineMap::get(/*dimCount=*/0, /*symbolCount=*/4,
b.getAffineSymbolExpr(0) +
@@ -588,6 +639,48 @@ FailureOr<bool> ValueBoundsConstraintSet::areOverlappingSlices(
return true;
}
+FailureOr<bool>
+ValueBoundsConstraintSet::areEquivalentSlices(MLIRContext *ctx,
+ HyperrectangularSlice slice1,
+ HyperrectangularSlice slice2) {
+ assert(slice1.getMixedOffsets().size() == slice2.getMixedOffsets().size() &&
+ "expected slices of same rank");
+ assert(slice1.getMixedSizes().size() == slice2.getMixedSizes().size() &&
+ "expected slices of same rank");
+ assert(slice1.getMixedStrides().size() == slice2.getMixedStrides().size() &&
+ "expected slices of same rank");
+
+ // The two slices are equivalent if all of their offsets, sizes and strides
+ // are equal. If equality cannot be determined for at least one of those
+ // values, equivalence cannot be determined and this function returns
+ // "failure".
+ for (auto [offset1, offset2] :
+ llvm::zip_equal(slice1.getMixedOffsets(), slice2.getMixedOffsets())) {
+ FailureOr<bool> equal = areEqual(offset1, offset2);
+ if (failed(equal))
+ return failure();
+ if (!equal.value())
+ return false;
+ }
+ for (auto [size1, size2] :
+ llvm::zip_equal(slice1.getMixedSizes(), slice2.getMixedSizes())) {
+ FailureOr<bool> equal = areEqual(size1, size2);
+ if (failed(equal))
+ return failure();
+ if (!equal.value())
+ return false;
+ }
+ for (auto [stride1, stride2] :
+ llvm::zip_equal(slice1.getMixedStrides(), slice2.getMixedStrides())) {
+ FailureOr<bool> equal = areEqual(stride1, stride2);
+ if (failed(equal))
+ return failure();
+ if (!equal.value())
+ return false;
+ }
+ return true;
+}
+
ValueBoundsConstraintSet::BoundBuilder &
ValueBoundsConstraintSet::BoundBuilder::operator[](int64_t dim) {
assert(!this->dim.has_value() && "dim was already set");
diff --git a/mlir/test/Transforms/loop-invariant-subset-hoisting.mlir b/mlir/test/Transforms/loop-invariant-subset-hoisting.mlir
index b9161f4e20d1927..bb60eeaba52455c 100644
--- a/mlir/test/Transforms/loop-invariant-subset-hoisting.mlir
+++ b/mlir/test/Transforms/loop-invariant-subset-hoisting.mlir
@@ -7,6 +7,11 @@ func.func @hoist_matching_extract_insert(%arg: tensor<?xf32>) -> tensor<?xf32> {
%ub = "test.foo"() : () -> (index)
%step = "test.foo"() : () -> (index)
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %add = arith.addi %c0, %c1 : index
+ %sub = arith.subi %add, %c1 : index
+
// CHECK: %[[extract:.*]] = tensor.extract_slice %[[arg]]
// CHECK: %[[for:.*]]:2 = scf.for {{.*}} iter_args(%[[t:.*]] = %[[arg]], %[[hoisted:.*]] = %[[extract]])
%0 = scf.for %iv = %lb to %ub step %step iter_args(%t = %arg) -> (tensor<?xf32>) {
@@ -17,7 +22,9 @@ func.func @hoist_matching_extract_insert(%arg: tensor<?xf32>) -> tensor<?xf32> {
%1 = tensor.extract_slice %t[0][5][1] : tensor<?xf32> to tensor<5xf32>
// CHECK: %[[foo:.*]] = "test.foo"(%[[hoisted]])
%2 = "test.foo"(%1) : (tensor<5xf32>) -> (tensor<5xf32>)
- %3 = tensor.insert_slice %2 into %t[0][5][1] : tensor<5xf32> into tensor<?xf32>
+ // Obfuscate the IR by inserting at offset %sub instead of 0; both of them
+ // have the same value.
+ %3 = tensor.insert_slice %2 into %t[%sub][5][1] : tensor<5xf32> into tensor<?xf32>
// CHECK: scf.yield %[[t]], %[[foo]]
scf.yield %3 : tensor<?xf32>
}
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index a9aa1d848fce4f5..4e814428f0cc457 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -10229,6 +10229,7 @@ cc_library(
":IR",
":SubsetOpInterfaceIncGen",
":Support",
+ ":ValueBoundsOpInterface",
"//llvm:Support",
],
)
>From e903f76f1ec4612a1262bc57d1d2c8c4b84d1a9a Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Mon, 30 Oct 2023 15:31:47 +0900
Subject: [PATCH 5/6] [mlir][vector] Implement subset op interface for xfer ops
---
.../Vector/Transforms/SubsetOpInterfaceImpl.h | 20 +
mlir/include/mlir/InitAllDialects.h | 2 +
.../Dialect/Vector/Transforms/CMakeLists.txt | 2 +
.../Transforms/SubsetOpInterfaceImpl.cpp | 82 +++
mlir/test/Dialect/Linalg/hoisting.mlir | 471 ------------------
.../loop-invariant-subset-hoisting.mlir | 318 ++++++++++++
.../llvm-project-overlay/mlir/BUILD.bazel | 1 +
7 files changed, 425 insertions(+), 471 deletions(-)
create mode 100644 mlir/include/mlir/Dialect/Vector/Transforms/SubsetOpInterfaceImpl.h
create mode 100644 mlir/lib/Dialect/Vector/Transforms/SubsetOpInterfaceImpl.cpp
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/SubsetOpInterfaceImpl.h b/mlir/include/mlir/Dialect/Vector/Transforms/SubsetOpInterfaceImpl.h
new file mode 100644
index 000000000000000..74bde485fa17a99
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/SubsetOpInterfaceImpl.h
@@ -0,0 +1,20 @@
+//===- SubsetOpInterfaceImpl.h - Tensor subsets -----------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_VECTOR_SUBSETOPINTERFACEIMPL_H
+#define MLIR_DIALECT_VECTOR_SUBSETOPINTERFACEIMPL_H
+
+namespace mlir {
+class DialectRegistry;
+
+namespace vector {
+void registerSubsetOpInterfaceExternalModels(DialectRegistry &registry);
+} // namespace vector
+} // namespace mlir
+
+#endif // MLIR_DIALECT_VECTOR_SUBSETOPINTERFACEIMPL_H
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index 7c2ffb7408d9afd..621110d130818d3 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -85,6 +85,7 @@
#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h"
+#include "mlir/Dialect/Vector/Transforms/SubsetOpInterfaceImpl.h"
#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
#include "mlir/IR/Dialect.h"
#include "mlir/Interfaces/CastInterfaces.h"
@@ -171,6 +172,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
tensor::registerTilingInterfaceExternalModels(registry);
tensor::registerValueBoundsOpInterfaceExternalModels(registry);
vector::registerBufferizableOpInterfaceExternalModels(registry);
+ vector::registerSubsetOpInterfaceExternalModels(registry);
NVVM::registerNVVMTargetInterfaceExternalModels(registry);
ROCDL::registerROCDLTargetInterfaceExternalModels(registry);
}
diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
index e6e7fb465aec3b9..513340096a5c1fc 100644
--- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -10,6 +10,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
LowerVectorShapeCast.cpp
LowerVectorTransfer.cpp
LowerVectorTranspose.cpp
+ SubsetOpInterfaceImpl.cpp
VectorDistribute.cpp
VectorDropLeadUnitDim.cpp
VectorEmulateNarrowType.cpp
@@ -40,6 +41,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
MLIRMemRefUtils
MLIRSCFDialect
MLIRSideEffectInterfaces
+ MLIRSubsetOpInterface
MLIRTensorDialect
MLIRTransforms
MLIRVectorDialect
diff --git a/mlir/lib/Dialect/Vector/Transforms/SubsetOpInterfaceImpl.cpp b/mlir/lib/Dialect/Vector/Transforms/SubsetOpInterfaceImpl.cpp
new file mode 100644
index 000000000000000..b450d5b78a46663
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/Transforms/SubsetOpInterfaceImpl.cpp
@@ -0,0 +1,82 @@
+//===- SubsetOpInterfaceImpl.cpp - Tensor subsets -------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Vector/Transforms/SubsetOpInterfaceImpl.h"
+
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Interfaces/SubsetOpInterface.h"
+
+using namespace mlir;
+using namespace mlir::vector;
+
+namespace {
+
+template <typename OpTy>
+struct XferOpSubsetOpInterface
+ : public SubsetOpInterface::ExternalModel<XferOpSubsetOpInterface<OpTy>,
+ OpTy> {
+ FailureOr<HyperrectangularSlice>
+ getAccessedHyperrectangularSlice(Operation *op) const {
+ auto xferOp = cast<OpTy>(op);
+ Builder b(xferOp->getContext());
+ SmallVector<OpFoldResult> offsets = llvm::map_to_vector(
+ xferOp.getIndices(), [](Value v) -> OpFoldResult { return v; });
+ SmallVector<OpFoldResult> sizes = llvm::map_to_vector(
+ xferOp.getTransferChunkAccessed(),
+ [&](int64_t sz) -> OpFoldResult { return b.getIndexAttr(sz); });
+ return HyperrectangularSlice(offsets, sizes);
+ }
+};
+
+struct TransferReadOpSubsetExtractionOpInterface
+ : public SubsetExtractionOpInterface::ExternalModel<
+ TransferReadOpSubsetExtractionOpInterface, vector::TransferReadOp> {
+ OpOperand &getSourceOperand(Operation *op) const {
+ return cast<vector::TransferReadOp>(op).getSourceMutable();
+ }
+};
+
+struct TransferWriteOpSubsetInsertionOpInterface
+ : public SubsetInsertionOpInterface::ExternalModel<
+ TransferWriteOpSubsetInsertionOpInterface, vector::TransferWriteOp> {
+ OpOperand &getSourceOperand(Operation *op) const {
+ return cast<vector::TransferWriteOp>(op).getVectorMutable();
+ }
+
+ OpOperand &getDestinationOperand(Operation *op) const {
+ return cast<vector::TransferWriteOp>(op).getSourceMutable();
+ }
+
+ Value buildSubsetExtraction(Operation *op, OpBuilder &builder,
+ Location loc) const {
+ // TODO: Implement when needed.
+ return Value();
+ }
+
+ SmallVector<Value>
+ getValuesNeededToBuildSubsetExtraction(Operation *op) const {
+ // TODO: Implement when needed.
+ return {};
+ }
+};
+
+} // namespace
+
+void mlir::vector::registerSubsetOpInterfaceExternalModels(
+ DialectRegistry &registry) {
+ registry.addExtension(+[](MLIRContext *ctx, vector::VectorDialect *dialect) {
+ TransferReadOp::attachInterface<XferOpSubsetOpInterface<TransferReadOp>>(
+ *ctx);
+ TransferReadOp::attachInterface<TransferReadOpSubsetExtractionOpInterface>(
+ *ctx);
+ TransferWriteOp::attachInterface<XferOpSubsetOpInterface<TransferWriteOp>>(
+ *ctx);
+ TransferWriteOp::attachInterface<TransferWriteOpSubsetInsertionOpInterface>(
+ *ctx);
+ });
+}
diff --git a/mlir/test/Dialect/Linalg/hoisting.mlir b/mlir/test/Dialect/Linalg/hoisting.mlir
index 3623952a08df024..550ffbc7bab678a 100644
--- a/mlir/test/Dialect/Linalg/hoisting.mlir
+++ b/mlir/test/Dialect/Linalg/hoisting.mlir
@@ -224,477 +224,6 @@ module attributes {transform.with_named_sequence} {
// -----
-// CHECK-LABEL: func @hoist_vector_transfer_pairs_tensor
-func.func @hoist_vector_transfer_pairs_tensor(
- %tensor0: tensor<?x?xf32>, %tensor1: tensor<?x?xf32>, %tensor2: tensor<?x?xf32>,
- %tensor3: tensor<?x?xf32>, %tensor4: tensor<?x?xf32>, %tensor5: tensor<?x?xf32>,
- %val: index, %lb : index, %ub : index, %step: index) ->
- (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>,
- tensor<?x?xf32>, tensor<?x?xf32>) {
- %c0 = arith.constant 0 : index
- %cst = arith.constant 0.0 : f32
-
-// CHECK: vector.transfer_read %{{.*}} : tensor<?x?xf32>, vector<1xf32>
-// CHECK: scf.for {{.*}} iter_args({{.*}}) ->
-// CHECK-SAME: (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, vector<1xf32>) {
-// CHECK: vector.transfer_read %{{.*}} : tensor<?x?xf32>, vector<2xf32>
-// CHECK: scf.for {{.*}} iter_args({{.*}}) ->
-// CHECK-SAME: (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, vector<2xf32>, vector<1xf32>) {
-// CHECK: vector.transfer_read %{{.*}} : tensor<?x?xf32>, vector<4xf32>
-// CHECK: "some_crippling_use"(%{{.*}}) : (tensor<?x?xf32>) -> ()
-// CHECK: vector.transfer_read %{{.*}} : tensor<?x?xf32>, vector<5xf32>
-// CHECK: "some_use"(%{{.*}}) : (vector<1xf32>) -> vector<1xf32>
-// CHECK: "some_use"(%{{.*}}) : (vector<2xf32>) -> vector<2xf32>
-// CHECK: "some_use"(%{{.*}}) : (tensor<?x?xf32>) -> vector<3xf32>
-// CHECK: "some_use"(%{{.*}}) : (vector<4xf32>) -> vector<4xf32>
-// CHECK: "some_use"(%{{.*}}) : (vector<5xf32>) -> vector<5xf32>
-// CHECK: vector.transfer_write %{{.*}} : vector<3xf32>, tensor<?x?xf32>
-// CHECK: vector.transfer_write %{{.*}} : vector<4xf32>, tensor<?x?xf32>
-// CHECK: vector.transfer_write %{{.*}} : vector<5xf32>, tensor<?x?xf32>
-// CHECK: "some_crippling_use"(%{{.*}}) : (tensor<?x?xf32>) -> ()
-// CHECK: scf.yield {{.*}} :
-// CHECK-SAME: tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, vector<2xf32>, vector<1xf32>
-// CHECK: }
-// CHECK: vector.transfer_write %{{.*}} : vector<2xf32>, tensor<?x?xf32>
-// CHECK: scf.yield {{.*}} :
-// CHECK-SAME: tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, vector<1xf32>
-// CHECK: }
-// CHECK: vector.transfer_write %{{.*}} : vector<1xf32>, tensor<?x?xf32>
- %0:6 = scf.for %i = %lb to %ub step %step
- iter_args(%arg0 = %tensor0, %arg1 = %tensor1, %arg2 = %tensor2,
- %arg3 = %tensor3, %arg4 = %tensor4, %arg5 = %tensor5)
- -> (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>,
- tensor<?x?xf32>, tensor<?x?xf32>) {
- %1:6 = scf.for %j = %lb to %ub step %step
- iter_args(%arg6 = %arg0, %arg7 = %arg1, %arg8 = %arg2,
- %arg9 = %arg3, %arg10 = %arg4, %arg11 = %arg5)
- -> (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>,
- tensor<?x?xf32>, tensor<?x?xf32>) {
- %r0 = vector.transfer_read %arg7[%c0, %c0], %cst: tensor<?x?xf32>, vector<1xf32>
- %r1 = vector.transfer_read %arg6[%i, %i], %cst: tensor<?x?xf32>, vector<2xf32>
- %r3 = vector.transfer_read %arg9[%c0, %c0], %cst: tensor<?x?xf32>, vector<4xf32>
- "some_crippling_use"(%arg10) : (tensor<?x?xf32>) -> ()
- %r4 = vector.transfer_read %arg10[%c0, %c0], %cst: tensor<?x?xf32>, vector<5xf32>
- %r5 = vector.transfer_read %arg11[%c0, %c0], %cst: tensor<?x?xf32>, vector<6xf32>
- "some_crippling_use"(%arg11) : (tensor<?x?xf32>) -> ()
- %u0 = "some_use"(%r0) : (vector<1xf32>) -> vector<1xf32>
- %u1 = "some_use"(%r1) : (vector<2xf32>) -> vector<2xf32>
- %u2 = "some_use"(%arg8) : (tensor<?x?xf32>) -> vector<3xf32>
- %u3 = "some_use"(%r3) : (vector<4xf32>) -> vector<4xf32>
- %u4 = "some_use"(%r4) : (vector<5xf32>) -> vector<5xf32>
- %u5 = "some_use"(%r5) : (vector<6xf32>) -> vector<6xf32>
- %w1 = vector.transfer_write %u0, %arg7[%c0, %c0] : vector<1xf32>, tensor<?x?xf32>
- %w0 = vector.transfer_write %u1, %arg6[%i, %i] : vector<2xf32>, tensor<?x?xf32>
- %w2 = vector.transfer_write %u2, %arg8[%c0, %c0] : vector<3xf32>, tensor<?x?xf32>
- %w3 = vector.transfer_write %u3, %arg9[%c0, %c0] : vector<4xf32>, tensor<?x?xf32>
- %w4 = vector.transfer_write %u4, %arg10[%c0, %c0] : vector<5xf32>, tensor<?x?xf32>
- %w5 = vector.transfer_write %u5, %arg11[%c0, %c0] : vector<6xf32>, tensor<?x?xf32>
- "some_crippling_use"(%w3) : (tensor<?x?xf32>) -> ()
- scf.yield %w0, %w1, %w2, %w3, %w4, %w5 :
- tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>,
- tensor<?x?xf32>, tensor<?x?xf32>
- }
- scf.yield %1#0, %1#1, %1#2, %1#3, %1#4, %1#5 :
- tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>,
- tensor<?x?xf32>, tensor<?x?xf32>
- }
- return %0#0, %0#1, %0#2, %0#3, %0#4, %0#5 :
- tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>,
- tensor<?x?xf32>, tensor<?x?xf32>
-}
-
-module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["func.func"]} in %arg1
- : (!transform.any_op) -> !transform.any_op
- transform.structured.hoist_redundant_tensor_subsets %0
- : (!transform.any_op) -> ()
- transform.yield
- }
-}
-
-// -----
-
-// CHECK-LABEL: func @hoist_vector_transfer_pairs_disjoint_tensor(
-// CHECK-SAME: %[[TENSOR0:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
-// CHECK-SAME: %[[TENSOR1:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
-// CHECK-SAME: %[[TENSOR2:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
-// CHECK-SAME: %[[TENSOR3:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
-func.func @hoist_vector_transfer_pairs_disjoint_tensor(
- %tensor0: tensor<?x?xf32>, %tensor1: tensor<?x?xf32>,
- %tensor2: tensor<?x?xf32>, %tensor3: tensor<?x?xf32>,
- %val: index, %lb : index, %ub : index, %step: index,
- %random_index : index) ->
- (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>) {
- %c0 = arith.constant 0 : index
- %c1 = arith.constant 1 : index
- %c3 = arith.constant 3 : index
- %cst = arith.constant 0.0 : f32
-
-// CHECK: vector.transfer_read %[[TENSOR2]]{{.*}} : tensor<?x?xf32>, vector<3xf32>
-// CHECK: vector.transfer_read %[[TENSOR2]]{{.*}} : tensor<?x?xf32>, vector<3xf32>
-// CHECK: vector.transfer_read %[[TENSOR3]]{{.*}} : tensor<?x?xf32>, vector<4xf32>
-// CHECK: vector.transfer_read %[[TENSOR3]]{{.*}} : tensor<?x?xf32>, vector<4xf32>
-// CHECK: %[[R:.*]]:6 = scf.for {{.*}} iter_args({{.*}}) ->
-// CHECK-SAME: (tensor<?x?xf32>, tensor<?x?xf32>, vector<3xf32>, vector<3xf32>, vector<4xf32>, vector<4xf32>) {
-// CHECK: scf.for {{.*}} iter_args({{.*}}) ->
-// CHECK-SAME: (tensor<?x?xf32>, tensor<?x?xf32>, vector<3xf32>, vector<3xf32>, vector<4xf32>, vector<4xf32>) {
-// CHECK: vector.transfer_read %[[TENSOR1]]{{.*}} : tensor<?x?xf32>, vector<2xf32>
-// CHECK: vector.transfer_read %[[TENSOR1]]{{.*}} : tensor<?x?xf32>, vector<2xf32>
-// CHECK: "some_use"(%{{.*}}) : (vector<2xf32>) -> vector<2xf32>
-// CHECK: "some_use"(%{{.*}}) : (vector<2xf32>) -> vector<2xf32>
-// CHECK: "some_use"(%{{.*}}) : (vector<3xf32>) -> vector<3xf32>
-// CHECK: "some_use"(%{{.*}}) : (vector<3xf32>) -> vector<3xf32>
-// CHECK: "some_use"(%{{.*}}) : (vector<4xf32>) -> vector<4xf32>
-// CHECK: "some_use"(%{{.*}}) : (vector<4xf32>) -> vector<4xf32>
-// CHECK: "some_use"(%{{.*}}) : (vector<2xf32>) -> vector<2xf32>
-// CHECK: "some_use"(%{{.*}}) : (vector<2xf32>) -> vector<2xf32>
-// CHECK: vector.transfer_write %{{.*}}, %{{.*}}{{.*}} : vector<2xf32>, tensor<?x?xf32>
-// CHECK: vector.transfer_write %{{.*}}, %{{.*}}{{.*}} : vector<2xf32>, tensor<?x?xf32>
-// CHECK: scf.yield {{.*}} :
-// CHECK-SAME: tensor<?x?xf32>, tensor<?x?xf32>, vector<3xf32>, vector<3xf32>, vector<4xf32>, vector<4xf32>
-// CHECK: }
-// CHECK: scf.yield {{.*}} :
-// CHECK-SAME: tensor<?x?xf32>, tensor<?x?xf32>, vector<3xf32>, vector<3xf32>, vector<4xf32>, vector<4xf32>
-// CHECK: }
-// CHECK: %[[TENSOR4:.*]] = vector.transfer_write %[[R]]#5, %[[TENSOR3]]{{.*}} : vector<4xf32>, tensor<?x?xf32>
-// CHECK: vector.transfer_write %[[R]]#4, %[[TENSOR4]]{{.*}} : vector<4xf32>, tensor<?x?xf32>
-// CHECK: %[[TENSOR5:.*]] = vector.transfer_write %[[R]]#3, %[[TENSOR2]]{{.*}} : vector<3xf32>, tensor<?x?xf32>
-// CHECK: vector.transfer_write %[[R]]#2, %[[TENSOR5]]{{.*}} : vector<3xf32>, tensor<?x?xf32>
- %0:4 = scf.for %i = %lb to %ub step %step
- iter_args(%arg0 = %tensor0, %arg1 = %tensor1, %arg2 = %tensor2,
- %arg3 = %tensor3)
- -> (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>) {
- %1:4 = scf.for %j = %lb to %ub step %step
- iter_args(%arg4 = %arg0, %arg5 = %arg1, %arg6 = %arg2,
- %arg7 = %arg3)
- -> (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>) {
- %r00 = vector.transfer_read %arg5[%c0, %c0], %cst: tensor<?x?xf32>, vector<2xf32>
- %r01 = vector.transfer_read %arg5[%c0, %c1], %cst: tensor<?x?xf32>, vector<2xf32>
- %r20 = vector.transfer_read %arg6[%c0, %c0], %cst: tensor<?x?xf32>, vector<3xf32>
- %r21 = vector.transfer_read %arg6[%c0, %c3], %cst: tensor<?x?xf32>, vector<3xf32>
- %r30 = vector.transfer_read %arg7[%c0, %random_index], %cst: tensor<?x?xf32>, vector<4xf32>
- %r31 = vector.transfer_read %arg7[%c1, %random_index], %cst: tensor<?x?xf32>, vector<4xf32>
- %r10 = vector.transfer_read %arg4[%i, %i], %cst: tensor<?x?xf32>, vector<2xf32>
- %r11 = vector.transfer_read %arg4[%random_index, %random_index], %cst: tensor<?x?xf32>, vector<2xf32>
- %u00 = "some_use"(%r00) : (vector<2xf32>) -> vector<2xf32>
- %u01 = "some_use"(%r01) : (vector<2xf32>) -> vector<2xf32>
- %u20 = "some_use"(%r20) : (vector<3xf32>) -> vector<3xf32>
- %u21 = "some_use"(%r21) : (vector<3xf32>) -> vector<3xf32>
- %u30 = "some_use"(%r30) : (vector<4xf32>) -> vector<4xf32>
- %u31 = "some_use"(%r31) : (vector<4xf32>) -> vector<4xf32>
- %u10 = "some_use"(%r10) : (vector<2xf32>) -> vector<2xf32>
- %u11 = "some_use"(%r11) : (vector<2xf32>) -> vector<2xf32>
- %w10 = vector.transfer_write %u00, %arg5[%c0, %c0] : vector<2xf32>, tensor<?x?xf32>
- %w11 = vector.transfer_write %u01, %w10[%c0, %c1] : vector<2xf32>, tensor<?x?xf32>
- %w20 = vector.transfer_write %u20, %arg6[%c0, %c0] : vector<3xf32>, tensor<?x?xf32>
- %w21 = vector.transfer_write %u21, %w20[%c0, %c3] : vector<3xf32>, tensor<?x?xf32>
- %w30 = vector.transfer_write %u30, %arg7[%c0, %random_index] : vector<4xf32>, tensor<?x?xf32>
- %w31 = vector.transfer_write %u31, %w30[%c1, %random_index] : vector<4xf32>, tensor<?x?xf32>
- %w00 = vector.transfer_write %u10, %arg4[%i, %i] : vector<2xf32>, tensor<?x?xf32>
- %w01 = vector.transfer_write %u11, %w00[%random_index, %random_index] : vector<2xf32>, tensor<?x?xf32>
- scf.yield %w01, %w11, %w21, %w31 : tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
- }
- scf.yield %1#0, %1#1, %1#2, %1#3 : tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
- }
- return %0#0, %0#1, %0#2, %0#3 : tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
-}
-
-module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["func.func"]} in %arg1
- : (!transform.any_op) -> !transform.any_op
- transform.structured.hoist_redundant_tensor_subsets %0
- : (!transform.any_op) -> ()
- transform.yield
- }
-}
-
-// -----
-
-// CHECK-LABEL: func @hoist_vector_transfer_pairs_tensor_and_slices
-// CHECK-SAME: %[[TENSOR0:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
-// CHECK-SAME: %[[TENSOR1:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
-// CHECK-SAME: %[[TENSOR2:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
-// CHECK-SAME: %[[TENSOR3:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
-// CHECK-SAME: %[[TENSOR4:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
-// CHECK-SAME: %[[TENSOR5:[a-zA-Z0-9]*]]: tensor<?x?xf32>
-func.func @hoist_vector_transfer_pairs_tensor_and_slices(
- %tensor0: tensor<?x?xf32>, %tensor1: tensor<?x?xf32>, %tensor2: tensor<?x?xf32>,
- %tensor3: tensor<?x?xf32>, %tensor4: tensor<?x?xf32>, %tensor5: tensor<?x?xf32>,
- %val: index, %lb : index, %ub : index, %step: index) ->
- (
- tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>//, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
- ) {
- %c0 = arith.constant 0 : index
- %cst = arith.constant 0.0 : f32
-
- // CHECK: scf.for %[[I:.*]] = {{.*}} iter_args(
- // CHECK-SAME: %[[TENSOR0_ARG:[0-9a-zA-Z]+]] = %[[TENSOR0]],
- // CHECK-SAME: %[[TENSOR1_ARG:[0-9a-zA-Z]+]] = %[[TENSOR1]],
- // CHECK-SAME: %[[TENSOR2_ARG:[0-9a-zA-Z]+]] = %[[TENSOR2]]
- // CHECK-SAME: ) ->
- // CHECK-SAME: (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
- %0:3 = scf.for %i = %lb to %ub step %step
- iter_args(%arg0 = %tensor0, %arg1 = %tensor1, %arg2 = %tensor2)
- -> (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>) {
-
- // Hoisted
- // CHECK: %[[ST0:.*]] = tensor.extract_slice %[[TENSOR0_ARG]][%[[I]], %[[I]]]{{.*}}: tensor<?x?xf32> to tensor<?x?xf32>
- // CHECK: %[[V0:.*]] = vector.transfer_read %[[ST0]]{{.*}} : tensor<?x?xf32>, vector<1xf32>
-
- // CHECK: %[[R:.*]]:3 = scf.for %[[J:.*]] = {{.*}} iter_args(
- // CHECK-SAME: %[[TENSOR1_ARG_L2:[0-9a-zA-Z]+]] = %[[TENSOR1_ARG]]
- // CHECK-SAME: %[[TENSOR2_ARG_L2:[0-9a-zA-Z]+]] = %[[TENSOR2_ARG]]
- // CHECK-SAME: %[[V0_ARG_L2:[0-9a-zA-Z]+]] = %[[V0]]
- // CHECK-SAME: ) ->
- // CHECK-SAME: (tensor<?x?xf32>, tensor<?x?xf32>, vector<1xf32>
- %1:3 = scf.for %j = %lb to %ub step %step
- iter_args(%arg6 = %arg0, %arg7 = %arg1, %arg8 = %arg2)
- -> (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>) {
- // Hoists.
- %st0 = tensor.extract_slice %arg6[%i, %i][%step, %step][1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
- %r0 = vector.transfer_read %st0[%c0, %c0], %cst: tensor<?x?xf32>, vector<1xf32>
-
- // CHECK: %[[ST1:.*]] = tensor.extract_slice %[[TENSOR1_ARG_L2]][%[[J]],{{.*}}: tensor<?x?xf32> to tensor<?x?xf32>
- // CHECK: %[[V1:.*]] = vector.transfer_read %[[ST1]]{{.*}} : tensor<?x?xf32>, vector<2xf32>
- // Does not hoist (slice depends on %j)
- %st1 = tensor.extract_slice %arg7[%j, %c0][%step, %step][1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
- %r1 = vector.transfer_read %st1[%c0, %c0], %cst: tensor<?x?xf32>, vector<2xf32>
-
- // CHECK: %[[ST2:.*]] = tensor.extract_slice %[[TENSOR2_ARG_L2]][%[[I]],{{.*}}: tensor<?x?xf32> to tensor<?x?xf32>
- // CHECK: %[[V2:.*]] = vector.transfer_read %[[ST2]]{{.*}} : tensor<?x?xf32>, vector<3xf32>
- // Does not hoist, 2 slice %arg8.
- %st2 = tensor.extract_slice %arg8[%i, %c0][%step, %step][1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
- %r2 = vector.transfer_read %st2[%c0, %c0], %cst: tensor<?x?xf32>, vector<3xf32>
-
- // CHECK: %[[U0:.*]] = "some_use"(%[[V0_ARG_L2]]) : (vector<1xf32>) -> vector<1xf32>
- // CHECK: %[[U1:.*]] = "some_use"(%[[V1]]) : (vector<2xf32>) -> vector<2xf32>
- // CHECK: %[[U2:.*]] = "some_use"(%[[V2]]) : (vector<3xf32>) -> vector<3xf32>
- %u0 = "some_use"(%r0) : (vector<1xf32>) -> vector<1xf32>
- %u1 = "some_use"(%r1) : (vector<2xf32>) -> vector<2xf32>
- %u2 = "some_use"(%r2) : (vector<3xf32>) -> vector<3xf32>
-
- // Hoists
- %w0 = vector.transfer_write %u0, %st0[%c0, %c0] : vector<1xf32>, tensor<?x?xf32>
-
- // CHECK-DAG: %[[STI1:.*]] = vector.transfer_write %[[U1]], %{{.*}} : vector<2xf32>, tensor<?x?xf32>
- // Does not hoist (associated slice depends on %j).
- %w1 = vector.transfer_write %u1, %st1[%i, %i] : vector<2xf32>, tensor<?x?xf32>
-
- // CHECK-DAG: %[[STI2:.*]] = vector.transfer_write %[[U2]], %{{.*}} : vector<3xf32>, tensor<?x?xf32>
- // Does not hoist, 2 slice / insert_slice for %arg8.
- %w2 = vector.transfer_write %u2, %st2[%c0, %c0] : vector<3xf32>, tensor<?x?xf32>
-
- // Hoists.
- %sti0 = tensor.insert_slice %w0 into %arg6[%i, %i][%step, %step][1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
-
- // CHECK-DAG: tensor.insert_slice %[[STI1]] into %[[TENSOR1_ARG_L2]][%[[J]],{{.*}}: tensor<?x?xf32> into tensor<?x?xf32>
- // Does not hoist (depends on %j).
- %sti1 = tensor.insert_slice %w1 into %arg7[%j, %c0][%step, %step][1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
-
- // CHECK-DAG: tensor.insert_slice %[[STI2]] into %[[TENSOR2_ARG_L2]][%[[I]],{{.*}}: tensor<?x?xf32> into tensor<?x?xf32>
- // Does not hoist, 2 slice / insert_slice for %arg8.
- %sti2 = tensor.insert_slice %w2 into %arg8[%i, %c0][%step, %step][1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
- // Extract with a different stride to make sure we cannot fold this extract with the above insert.
- %st22 = tensor.extract_slice %sti2[%i, %c0][%step, %step][2, 1] : tensor<?x?xf32> to tensor<?x?xf32>
- %sti22 = tensor.insert_slice %st22 into %arg8[%i, %c0][%step, %step][1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
-
- // CHECK: scf.yield {{.*}} : tensor<?x?xf32>, tensor<?x?xf32>, vector<1xf32>
- // CHECK: }
- scf.yield %sti0, %sti1, %sti22:
- tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
- }
-
- // Hoisted
- // CHECK: %[[STI0:.*]] = vector.transfer_write %[[R]]#2, %[[ST0]]{{.*}} : vector<1xf32>, tensor<?x?xf32>
- // CHECK: tensor.insert_slice %[[STI0]] into %[[TENSOR0_ARG]][%[[I]], %[[I]]]{{.*}} : tensor<?x?xf32> into tensor<?x?xf32>
-
- // CHECK: scf.yield {{.*}} : tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
- scf.yield %1#0, %1#1, %1#2 :
- tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
-
- // CHECK: }
- }
- return %0#0, %0#1, %0#2 : tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
-}
-
-module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["func.func"]} in %arg1
- : (!transform.any_op) -> !transform.any_op
- transform.structured.hoist_redundant_tensor_subsets %0
- : (!transform.any_op) -> ()
- transform.yield
- }
-}
-
-// -----
-
-// CHECK-LABEL: func @hoist_vector_transfer_write_pairs_disjoint_tensor(
-// CHECK-SAME: %[[T:.*]]: tensor<?x?xf32>,
-// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
-// CHECK-DAG: %[[R0:.*]] = vector.transfer_read %[[T]][%[[C0]], %[[C0]]], %{{.*}} : tensor<?x?xf32>, vector<2xf32>
-// CHECK-DAG: %[[R1:.*]] = vector.transfer_read %[[T]][%[[C0]], %[[C3]]], %{{.*}} : tensor<?x?xf32>, vector<2xf32>
-// CHECK: %[[F:.*]]:2 = scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[R3:.*]] = %[[R1:.*]], %[[R2:.*]] = %[[R0]]) -> (vector<2xf32>, vector<2xf32>) {
-// CHECK: %[[R4:.*]] = "some_use"(%[[R2]]) : (vector<2xf32>) -> vector<2xf32>
-// CHECK: %[[R5:.*]] = "some_use"(%[[R3]]) : (vector<2xf32>) -> vector<2xf32>
-// CHECK: scf.yield %[[R5]], %[[R4]] : vector<2xf32>, vector<2xf32>
-// CHECK: }
-// CHECK: %[[W0:.*]] = vector.transfer_write %[[F]]#1, %[[T]][%[[C0]], %[[C0]]] : vector<2xf32>, tensor<?x?xf32>
-// CHECK: %[[W1:.*]] = vector.transfer_write %[[F]]#0, %[[W0]][%[[C0]], %[[C3]]] : vector<2xf32>, tensor<?x?xf32>
-// CHECK: return %[[W1]] : tensor<?x?xf32>
-func.func @hoist_vector_transfer_write_pairs_disjoint_tensor(
- %tensor: tensor<?x?xf32>,
- %val: index, %lb : index, %ub : index, %step: index) ->
- (tensor<?x?xf32>) {
- %c0 = arith.constant 0 : index
- %c1 = arith.constant 1 : index
- %c3 = arith.constant 3 : index
- %cst = arith.constant 0.0 : f32
- %1 = scf.for %j = %lb to %ub step %step iter_args(%arg5 = %tensor)
- -> (tensor<?x?xf32>) {
- %r00 = vector.transfer_read %arg5[%c0, %c0], %cst: tensor<?x?xf32>, vector<2xf32>
- %u00 = "some_use"(%r00) : (vector<2xf32>) -> vector<2xf32>
- %w10 = vector.transfer_write %u00, %arg5[%c0, %c0] : vector<2xf32>, tensor<?x?xf32>
-
- // Hoist by properly bypassing the disjoint write %w10.
- %r01 = vector.transfer_read %w10[%c0, %c3], %cst: tensor<?x?xf32>, vector<2xf32>
- %u01 = "some_use"(%r01) : (vector<2xf32>) -> vector<2xf32>
- %w11 = vector.transfer_write %u01, %w10[%c0, %c3] : vector<2xf32>, tensor<?x?xf32>
- scf.yield %w11 : tensor<?x?xf32>
- }
- return %1 : tensor<?x?xf32>
-}
-
-module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["func.func"]} in %arg1
- : (!transform.any_op) -> !transform.any_op
- transform.structured.hoist_redundant_tensor_subsets %0
- : (!transform.any_op) -> ()
- transform.yield
- }
-}
-
-// -----
-
-// CHECK-LABEL: func @hoist_vector_transfer_pairs_tensor_and_slices_static_large_tensor
-// CHECK-SAME: %[[TENSOR0:[a-zA-Z0-9]*]]: tensor<100x100xf32>,
-// CHECK-SAME: %[[TENSOR1:[a-zA-Z0-9]*]]: tensor<200x200xf32>,
-// CHECK-SAME: %[[TENSOR2:[a-zA-Z0-9]*]]: tensor<300x300xf32>
-func.func @hoist_vector_transfer_pairs_tensor_and_slices_static_large_tensor(
- %tensor0: tensor<100x100xf32>, %tensor1: tensor<200x200xf32>, %tensor2: tensor<300x300xf32>,
- %val: index, %lb : index, %ub : index, %step: index) ->
- (
- tensor<100x100xf32>, tensor<200x200xf32>, tensor<300x300xf32>
- ) {
- %c0 = arith.constant 0 : index
- %cst = arith.constant 0.0 : f32
-
- // CHECK: scf.for %[[I:.*]] = {{.*}} iter_args(
- // CHECK-SAME: %[[TENSOR0_ARG:[0-9a-zA-Z]+]] = %[[TENSOR0]],
- // CHECK-SAME: %[[TENSOR1_ARG:[0-9a-zA-Z]+]] = %[[TENSOR1]],
- // CHECK-SAME: %[[TENSOR2_ARG:[0-9a-zA-Z]+]] = %[[TENSOR2]]
- // CHECK-SAME: ) ->
- // CHECK-SAME: (tensor<100x100xf32>, tensor<200x200xf32>, tensor<300x300xf32>
- %0:3 = scf.for %i = %lb to %ub step %step
- iter_args(%arg0 = %tensor0, %arg1 = %tensor1, %arg2 = %tensor2)
- -> (tensor<100x100xf32>, tensor<200x200xf32>, tensor<300x300xf32>) {
-
- // Hoisted
- // CHECK: %[[ST0:.*]] = tensor.extract_slice %[[TENSOR0_ARG]][%[[I]], %[[I]]]{{.*}}: tensor<100x100xf32> to tensor<?x?xf32>
- // CHECK: %[[V0:.*]] = vector.transfer_read %[[ST0]]{{.*}} : tensor<?x?xf32>, vector<1xf32>
-
- // CHECK: %[[R:.*]]:3 = scf.for %[[J:.*]] = {{.*}} iter_args(
- // CHECK-SAME: %[[TENSOR1_ARG_L2:[0-9a-zA-Z]+]] = %[[TENSOR1_ARG]]
- // CHECK-SAME: %[[TENSOR2_ARG_L2:[0-9a-zA-Z]+]] = %[[TENSOR2_ARG]]
- // CHECK-SAME: %[[V0_ARG_L2:[0-9a-zA-Z]+]] = %[[V0]]
- // CHECK-SAME: ) ->
- // CHECK-SAME: (tensor<200x200xf32>, tensor<300x300xf32>, vector<1xf32>
- %1:3 = scf.for %j = %lb to %ub step %step
- iter_args(%arg6 = %arg0, %arg7 = %arg1, %arg8 = %arg2)
- -> (tensor<100x100xf32>, tensor<200x200xf32>, tensor<300x300xf32>) {
- // Hoists.
- %st0 = tensor.extract_slice %arg6[%i, %i][%step, %step][1, 1] : tensor<100x100xf32> to tensor<?x?xf32>
- %r0 = vector.transfer_read %st0[%c0, %c0], %cst: tensor<?x?xf32>, vector<1xf32>
-
- // CHECK: %[[ST1:.*]] = tensor.extract_slice %[[TENSOR1_ARG_L2]][%[[J]],{{.*}}: tensor<200x200xf32> to tensor<?x?xf32>
- // CHECK: %[[V1:.*]] = vector.transfer_read %[[ST1]]{{.*}} : tensor<?x?xf32>, vector<2xf32>
- // Does not hoist (slice depends on %j)
- %st1 = tensor.extract_slice %arg7[%j, %c0][%step, %step][1, 1] : tensor<200x200xf32> to tensor<?x?xf32>
- %r1 = vector.transfer_read %st1[%c0, %c0], %cst: tensor<?x?xf32>, vector<2xf32>
-
- // CHECK: %[[ST2:.*]] = tensor.extract_slice %[[TENSOR2_ARG_L2]][%[[I]],{{.*}}: tensor<300x300xf32> to tensor<?x?xf32>
- // CHECK: %[[V2:.*]] = vector.transfer_read %[[ST2]]{{.*}} : tensor<?x?xf32>, vector<3xf32>
- // Does not hoist, 2 slice %arg8.
- %st2 = tensor.extract_slice %arg8[%i, %c0][%step, %step][1, 1] : tensor<300x300xf32> to tensor<?x?xf32>
- %r2 = vector.transfer_read %st2[%c0, %c0], %cst: tensor<?x?xf32>, vector<3xf32>
-
- // CHECK: %[[U0:.*]] = "some_use"(%[[V0_ARG_L2]]) : (vector<1xf32>) -> vector<1xf32>
- // CHECK: %[[U1:.*]] = "some_use"(%[[V1]]) : (vector<2xf32>) -> vector<2xf32>
- // CHECK: %[[U2:.*]] = "some_use"(%[[V2]]) : (vector<3xf32>) -> vector<3xf32>
- %u0 = "some_use"(%r0) : (vector<1xf32>) -> vector<1xf32>
- %u1 = "some_use"(%r1) : (vector<2xf32>) -> vector<2xf32>
- %u2 = "some_use"(%r2) : (vector<3xf32>) -> vector<3xf32>
-
- // Hoists
- %w0 = vector.transfer_write %u0, %st0[%c0, %c0] : vector<1xf32>, tensor<?x?xf32>
-
- // CHECK-DAG: %[[STI1:.*]] = vector.transfer_write %[[U1]], %{{.*}} : vector<2xf32>, tensor<?x?xf32>
- // Does not hoist (associated slice depends on %j).
- %w1 = vector.transfer_write %u1, %st1[%i, %i] : vector<2xf32>, tensor<?x?xf32>
-
- // CHECK-DAG: %[[STI2:.*]] = vector.transfer_write %[[U2]], %{{.*}} : vector<3xf32>, tensor<?x?xf32>
- // Does not hoist, 2 slice / insert_slice for %arg8.
- %w2 = vector.transfer_write %u2, %st2[%c0, %c0] : vector<3xf32>, tensor<?x?xf32>
-
- // Hoists.
- %sti0 = tensor.insert_slice %w0 into %arg6[%i, %i][%step, %step][1, 1] : tensor<?x?xf32> into tensor<100x100xf32>
-
- // CHECK-DAG: tensor.insert_slice %[[STI1]] into %[[TENSOR1_ARG_L2]][%[[J]],{{.*}}: tensor<?x?xf32> into tensor<200x200xf32>
- // Does not hoist (depends on %j).
- %sti1 = tensor.insert_slice %w1 into %arg7[%j, %c0][%step, %step][1, 1] : tensor<?x?xf32> into tensor<200x200xf32>
-
- // CHECK-DAG: tensor.insert_slice %[[STI2]] into %[[TENSOR2_ARG_L2]][%[[I]],{{.*}}: tensor<?x?xf32> into tensor<300x300xf32>
- // Does not hoist, 2 slice / insert_slice for %arg8.
- %sti2 = tensor.insert_slice %w2 into %arg8[%i, %c0][%step, %step][1, 1] : tensor<?x?xf32> into tensor<300x300xf32>
- // Extract with a different stride to make sure we cannot fold this extract with the above insert.
- %st22 = tensor.extract_slice %sti2[%i, %c0][%step, %step][2, 1] : tensor<300x300xf32> to tensor<?x?xf32>
- %sti22 = tensor.insert_slice %st22 into %arg8[%i, %c0][%step, %step][1, 1] : tensor<?x?xf32> into tensor<300x300xf32>
-
- // CHECK: scf.yield {{.*}} : tensor<200x200xf32>, tensor<300x300xf32>, vector<1xf32>
- // CHECK: }
- scf.yield %sti0, %sti1, %sti22:
- tensor<100x100xf32>, tensor<200x200xf32>, tensor<300x300xf32>
- }
-
- // Hoisted
- // CHECK: %[[STI0:.*]] = vector.transfer_write %[[R]]#2, %[[ST0]]{{.*}} : vector<1xf32>, tensor<?x?xf32>
- // CHECK: tensor.insert_slice %[[STI0]] into %[[TENSOR0_ARG]][%[[I]], %[[I]]]{{.*}} : tensor<?x?xf32> into tensor<100x100xf32>
-
- // CHECK: scf.yield {{.*}} : tensor<100x100xf32>, tensor<200x200xf32>, tensor<300x300xf32>
- scf.yield %1#0, %1#1, %1#2 :
- tensor<100x100xf32>, tensor<200x200xf32>, tensor<300x300xf32>
-
- // CHECK: }
- }
- return %0#0, %0#1, %0#2 : tensor<100x100xf32>, tensor<200x200xf32>, tensor<300x300xf32>
-}
-
-module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["func.func"]} in %arg1
- : (!transform.any_op) -> !transform.any_op
- transform.structured.hoist_redundant_tensor_subsets %0
- : (!transform.any_op) -> ()
- transform.yield
- }
-}
-
-// -----
-
// CHECK-LABEL: func.func @hoist_vector_transfer_read(
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C128:.+]] = arith.constant 128 : index
diff --git a/mlir/test/Transforms/loop-invariant-subset-hoisting.mlir b/mlir/test/Transforms/loop-invariant-subset-hoisting.mlir
index bb60eeaba52455c..3a78287a0dcad2f 100644
--- a/mlir/test/Transforms/loop-invariant-subset-hoisting.mlir
+++ b/mlir/test/Transforms/loop-invariant-subset-hoisting.mlir
@@ -277,3 +277,321 @@ func.func @nested_hoisting(%arg: tensor<?xf32>) -> tensor<?xf32> {
// CHECK: return %[[insert2]]
return %0 : tensor<?xf32>
}
+
+// -----
+
+// CHECK-LABEL: func @hoist_vector_transfer_pairs_tensor
+func.func @hoist_vector_transfer_pairs_tensor(
+ %tensor0: tensor<?x?xf32>, %tensor1: tensor<?x?xf32>, %tensor2: tensor<?x?xf32>,
+ %tensor3: tensor<?x?xf32>, %tensor4: tensor<?x?xf32>, %tensor5: tensor<?x?xf32>,
+ %val: index, %lb : index, %ub : index, %step: index) ->
+ (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>,
+ tensor<?x?xf32>, tensor<?x?xf32>) {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0.0 : f32
+
+// CHECK: vector.transfer_read %{{.*}} : tensor<?x?xf32>, vector<1xf32>
+// CHECK: scf.for {{.*}} iter_args({{.*}}) ->
+// CHECK-SAME: (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, vector<1xf32>) {
+// CHECK: vector.transfer_read %{{.*}} : tensor<?x?xf32>, vector<2xf32>
+// CHECK: scf.for {{.*}} iter_args({{.*}}) ->
+// CHECK-SAME: (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, vector<2xf32>, vector<1xf32>) {
+// CHECK: vector.transfer_read %{{.*}} : tensor<?x?xf32>, vector<4xf32>
+// CHECK: "test.some_crippling_use"(%{{.*}}) : (tensor<?x?xf32>) -> ()
+// CHECK: vector.transfer_read %{{.*}} : tensor<?x?xf32>, vector<5xf32>
+// CHECK: "test.some_use"(%{{.*}}) : (vector<1xf32>) -> vector<1xf32>
+// CHECK: "test.some_use"(%{{.*}}) : (vector<2xf32>) -> vector<2xf32>
+// CHECK: "test.some_use"(%{{.*}}) : (tensor<?x?xf32>) -> vector<3xf32>
+// CHECK: "test.some_use"(%{{.*}}) : (vector<4xf32>) -> vector<4xf32>
+// CHECK: "test.some_use"(%{{.*}}) : (vector<5xf32>) -> vector<5xf32>
+// CHECK: vector.transfer_write %{{.*}} : vector<3xf32>, tensor<?x?xf32>
+// CHECK: vector.transfer_write %{{.*}} : vector<4xf32>, tensor<?x?xf32>
+// CHECK: vector.transfer_write %{{.*}} : vector<5xf32>, tensor<?x?xf32>
+// CHECK: "test.some_crippling_use"(%{{.*}}) : (tensor<?x?xf32>) -> ()
+// CHECK: scf.yield {{.*}} :
+// CHECK-SAME: tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, vector<2xf32>, vector<1xf32>
+// CHECK: }
+// CHECK: vector.transfer_write %{{.*}} : vector<2xf32>, tensor<?x?xf32>
+// CHECK: scf.yield {{.*}} :
+// CHECK-SAME: tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, vector<1xf32>
+// CHECK: }
+// CHECK: vector.transfer_write %{{.*}} : vector<1xf32>, tensor<?x?xf32>
+ %0:6 = scf.for %i = %lb to %ub step %step
+ iter_args(%arg0 = %tensor0, %arg1 = %tensor1, %arg2 = %tensor2,
+ %arg3 = %tensor3, %arg4 = %tensor4, %arg5 = %tensor5)
+ -> (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>,
+ tensor<?x?xf32>, tensor<?x?xf32>) {
+ %1:6 = scf.for %j = %lb to %ub step %step
+ iter_args(%arg6 = %arg0, %arg7 = %arg1, %arg8 = %arg2,
+ %arg9 = %arg3, %arg10 = %arg4, %arg11 = %arg5)
+ -> (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>,
+ tensor<?x?xf32>, tensor<?x?xf32>) {
+ %r0 = vector.transfer_read %arg7[%c0, %c0], %cst: tensor<?x?xf32>, vector<1xf32> // Loop-invariant indices: hoisted out of both loops.
+ %r1 = vector.transfer_read %arg6[%i, %i], %cst: tensor<?x?xf32>, vector<2xf32> // Indices depend on %i: hoisted out of the inner loop only.
+ %r3 = vector.transfer_read %arg9[%c0, %c0], %cst: tensor<?x?xf32>, vector<4xf32>
+ "test.some_crippling_use"(%arg10) : (tensor<?x?xf32>) -> () // Unknown use of %arg10: the vector<5xf32> pair stays in the loop.
+ %r4 = vector.transfer_read %arg10[%c0, %c0], %cst: tensor<?x?xf32>, vector<5xf32>
+ %r5 = vector.transfer_read %arg11[%c0, %c0], %cst: tensor<?x?xf32>, vector<6xf32>
+ "test.some_crippling_use"(%arg11) : (tensor<?x?xf32>) -> ()
+ %u0 = "test.some_use"(%r0) : (vector<1xf32>) -> vector<1xf32>
+ %u1 = "test.some_use"(%r1) : (vector<2xf32>) -> vector<2xf32>
+ %u2 = "test.some_use"(%arg8) : (tensor<?x?xf32>) -> vector<3xf32> // %arg8 is used directly: the vector<3xf32> write stays in the loop.
+ %u3 = "test.some_use"(%r3) : (vector<4xf32>) -> vector<4xf32>
+ %u4 = "test.some_use"(%r4) : (vector<5xf32>) -> vector<5xf32>
+ %u5 = "test.some_use"(%r5) : (vector<6xf32>) -> vector<6xf32>
+ %w1 = vector.transfer_write %u0, %arg7[%c0, %c0] : vector<1xf32>, tensor<?x?xf32>
+ %w0 = vector.transfer_write %u1, %arg6[%i, %i] : vector<2xf32>, tensor<?x?xf32>
+ %w2 = vector.transfer_write %u2, %arg8[%c0, %c0] : vector<3xf32>, tensor<?x?xf32>
+ %w3 = vector.transfer_write %u3, %arg9[%c0, %c0] : vector<4xf32>, tensor<?x?xf32>
+ %w4 = vector.transfer_write %u4, %arg10[%c0, %c0] : vector<5xf32>, tensor<?x?xf32>
+ %w5 = vector.transfer_write %u5, %arg11[%c0, %c0] : vector<6xf32>, tensor<?x?xf32>
+ "test.some_crippling_use"(%w3) : (tensor<?x?xf32>) -> () // Unknown use of %w3: the vector<4xf32> pair stays in the loop.
+ scf.yield %w0, %w1, %w2, %w3, %w4, %w5 :
+ tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>,
+ tensor<?x?xf32>, tensor<?x?xf32>
+ }
+ scf.yield %1#0, %1#1, %1#2, %1#3, %1#4, %1#5 :
+ tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>,
+ tensor<?x?xf32>, tensor<?x?xf32>
+ }
+ return %0#0, %0#1, %0#2, %0#3, %0#4, %0#5 :
+ tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>,
+ tensor<?x?xf32>, tensor<?x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @hoist_vector_transfer_pairs_disjoint_tensor(
+// CHECK-SAME: %[[TENSOR0:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
+// CHECK-SAME: %[[TENSOR1:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
+// CHECK-SAME: %[[TENSOR2:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
+// CHECK-SAME: %[[TENSOR3:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
+func.func @hoist_vector_transfer_pairs_disjoint_tensor(
+ %tensor0: tensor<?x?xf32>, %tensor1: tensor<?x?xf32>,
+ %tensor2: tensor<?x?xf32>, %tensor3: tensor<?x?xf32>,
+ %val: index, %lb : index, %ub : index, %step: index,
+ %random_index : index) ->
+ (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c3 = arith.constant 3 : index
+ %cst = arith.constant 0.0 : f32
+
+// CHECK: vector.transfer_read %[[TENSOR2]]{{.*}} : tensor<?x?xf32>, vector<3xf32>
+// CHECK: vector.transfer_read %[[TENSOR2]]{{.*}} : tensor<?x?xf32>, vector<3xf32>
+// CHECK: vector.transfer_read %[[TENSOR3]]{{.*}} : tensor<?x?xf32>, vector<4xf32>
+// CHECK: vector.transfer_read %[[TENSOR3]]{{.*}} : tensor<?x?xf32>, vector<4xf32>
+// CHECK: %[[R:.*]]:8 = scf.for {{.*}} iter_args({{.*}}) ->
+// CHECK-SAME: (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, vector<3xf32>, vector<3xf32>, vector<4xf32>, vector<4xf32>) {
+// CHECK: scf.for {{.*}} iter_args({{.*}}) ->
+// CHECK-SAME: (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, vector<3xf32>, vector<3xf32>, vector<4xf32>, vector<4xf32>) {
+// CHECK: vector.transfer_read %[[TENSOR1]]{{.*}} : tensor<?x?xf32>, vector<2xf32>
+// CHECK: vector.transfer_read %[[TENSOR1]]{{.*}} : tensor<?x?xf32>, vector<2xf32>
+// CHECK: "test.some_use"(%{{.*}}) : (vector<2xf32>) -> vector<2xf32>
+// CHECK: "test.some_use"(%{{.*}}) : (vector<2xf32>) -> vector<2xf32>
+// CHECK: "test.some_use"(%{{.*}}) : (vector<3xf32>) -> vector<3xf32>
+// CHECK: "test.some_use"(%{{.*}}) : (vector<3xf32>) -> vector<3xf32>
+// CHECK: "test.some_use"(%{{.*}}) : (vector<4xf32>) -> vector<4xf32>
+// CHECK: "test.some_use"(%{{.*}}) : (vector<4xf32>) -> vector<4xf32>
+// CHECK: "test.some_use"(%{{.*}}) : (vector<2xf32>) -> vector<2xf32>
+// CHECK: "test.some_use"(%{{.*}}) : (vector<2xf32>) -> vector<2xf32>
+// CHECK: vector.transfer_write %{{.*}}, %{{.*}}{{.*}} : vector<2xf32>, tensor<?x?xf32>
+// CHECK: vector.transfer_write %{{.*}}, %{{.*}}{{.*}} : vector<2xf32>, tensor<?x?xf32>
+// CHECK: scf.yield {{.*}} :
+// CHECK-SAME: tensor<?x?xf32>, tensor<?x?xf32>, vector<3xf32>, vector<3xf32>, vector<4xf32>, vector<4xf32>
+// CHECK: }
+// CHECK: scf.yield {{.*}} :
+// CHECK-SAME: tensor<?x?xf32>, tensor<?x?xf32>, vector<3xf32>, vector<3xf32>, vector<4xf32>, vector<4xf32>
+// CHECK: }
+// CHECK: %[[TENSOR4:.*]] = vector.transfer_write %[[R]]#7, %[[R]]#3{{.*}} : vector<4xf32>, tensor<?x?xf32>
+// CHECK: vector.transfer_write %[[R]]#6, %[[TENSOR4]]{{.*}} : vector<4xf32>, tensor<?x?xf32>
+// CHECK: %[[TENSOR5:.*]] = vector.transfer_write %[[R]]#5, %[[R]]#2{{.*}} : vector<3xf32>, tensor<?x?xf32>
+// CHECK: vector.transfer_write %[[R]]#4, %[[TENSOR5]]{{.*}} : vector<3xf32>, tensor<?x?xf32>
+ %0:4 = scf.for %i = %lb to %ub step %step
+ iter_args(%arg0 = %tensor0, %arg1 = %tensor1, %arg2 = %tensor2,
+ %arg3 = %tensor3)
+ -> (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>) {
+ %1:4 = scf.for %j = %lb to %ub step %step
+ iter_args(%arg4 = %arg0, %arg5 = %arg1, %arg6 = %arg2,
+ %arg7 = %arg3)
+ -> (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>) {
+ %r00 = vector.transfer_read %arg5[%c0, %c0], %cst: tensor<?x?xf32>, vector<2xf32> // Overlaps %r01 at column 1: not hoisted.
+ %r01 = vector.transfer_read %arg5[%c0, %c1], %cst: tensor<?x?xf32>, vector<2xf32>
+ %r20 = vector.transfer_read %arg6[%c0, %c0], %cst: tensor<?x?xf32>, vector<3xf32> // Disjoint from %r21 (columns 0-2 vs 3-5): hoisted.
+ %r21 = vector.transfer_read %arg6[%c0, %c3], %cst: tensor<?x?xf32>, vector<3xf32>
+ %r30 = vector.transfer_read %arg7[%c0, %random_index], %cst: tensor<?x?xf32>, vector<4xf32> // Disjoint from %r31 (rows 0 vs 1): hoisted.
+ %r31 = vector.transfer_read %arg7[%c1, %random_index], %cst: tensor<?x?xf32>, vector<4xf32>
+ %r10 = vector.transfer_read %arg4[%i, %i], %cst: tensor<?x?xf32>, vector<2xf32> // May overlap %r11 (unknown indices): not hoisted.
+ %r11 = vector.transfer_read %arg4[%random_index, %random_index], %cst: tensor<?x?xf32>, vector<2xf32>
+ %u00 = "test.some_use"(%r00) : (vector<2xf32>) -> vector<2xf32>
+ %u01 = "test.some_use"(%r01) : (vector<2xf32>) -> vector<2xf32>
+ %u20 = "test.some_use"(%r20) : (vector<3xf32>) -> vector<3xf32>
+ %u21 = "test.some_use"(%r21) : (vector<3xf32>) -> vector<3xf32>
+ %u30 = "test.some_use"(%r30) : (vector<4xf32>) -> vector<4xf32>
+ %u31 = "test.some_use"(%r31) : (vector<4xf32>) -> vector<4xf32>
+ %u10 = "test.some_use"(%r10) : (vector<2xf32>) -> vector<2xf32>
+ %u11 = "test.some_use"(%r11) : (vector<2xf32>) -> vector<2xf32>
+ %w10 = vector.transfer_write %u00, %arg5[%c0, %c0] : vector<2xf32>, tensor<?x?xf32>
+ %w11 = vector.transfer_write %u01, %w10[%c0, %c1] : vector<2xf32>, tensor<?x?xf32>
+ %w20 = vector.transfer_write %u20, %arg6[%c0, %c0] : vector<3xf32>, tensor<?x?xf32>
+ %w21 = vector.transfer_write %u21, %w20[%c0, %c3] : vector<3xf32>, tensor<?x?xf32>
+ %w30 = vector.transfer_write %u30, %arg7[%c0, %random_index] : vector<4xf32>, tensor<?x?xf32>
+ %w31 = vector.transfer_write %u31, %w30[%c1, %random_index] : vector<4xf32>, tensor<?x?xf32>
+ %w00 = vector.transfer_write %u10, %arg4[%i, %i] : vector<2xf32>, tensor<?x?xf32>
+ %w01 = vector.transfer_write %u11, %w00[%random_index, %random_index] : vector<2xf32>, tensor<?x?xf32>
+ scf.yield %w01, %w11, %w21, %w31 : tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
+ }
+ scf.yield %1#0, %1#1, %1#2, %1#3 : tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
+ }
+ return %0#0, %0#1, %0#2, %0#3 : tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @hoist_vector_transfer_pairs_tensor_and_slices
+// CHECK-SAME: %[[TENSOR0:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
+// CHECK-SAME: %[[TENSOR1:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
+// CHECK-SAME: %[[TENSOR2:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
+// CHECK-SAME: %[[TENSOR3:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
+// CHECK-SAME: %[[TENSOR4:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
+// CHECK-SAME: %[[TENSOR5:[a-zA-Z0-9]*]]: tensor<?x?xf32>
+func.func @hoist_vector_transfer_pairs_tensor_and_slices(
+ %tensor0: tensor<?x?xf32>, %tensor1: tensor<?x?xf32>, %tensor2: tensor<?x?xf32>,
+ %tensor3: tensor<?x?xf32>, %tensor4: tensor<?x?xf32>, %tensor5: tensor<?x?xf32>,
+ %val: index, %lb : index, %ub : index, %step: index) ->
+ (
+ tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
+ ) {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0.0 : f32
+
+ // CHECK: scf.for %[[I:.*]] = {{.*}} iter_args(
+ // CHECK-SAME: %[[TENSOR0_ARG:[0-9a-zA-Z]+]] = %[[TENSOR0]],
+ // CHECK-SAME: %[[TENSOR1_ARG:[0-9a-zA-Z]+]] = %[[TENSOR1]],
+ // CHECK-SAME: %[[TENSOR2_ARG:[0-9a-zA-Z]+]] = %[[TENSOR2]]
+ // CHECK-SAME: ) ->
+ // CHECK-SAME: (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
+ %0:3 = scf.for %i = %lb to %ub step %step
+ iter_args(%arg0 = %tensor0, %arg1 = %tensor1, %arg2 = %tensor2)
+ -> (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>) {
+
+ // Hoisted
+ // CHECK: %[[ST0:.*]] = tensor.extract_slice %[[TENSOR0_ARG]][%[[I]], %[[I]]]{{.*}}: tensor<?x?xf32> to tensor<?x?xf32>
+ // CHECK: %[[V0:.*]] = vector.transfer_read %[[ST0]]{{.*}} : tensor<?x?xf32>, vector<1xf32>
+
+ // CHECK: %[[R:.*]]:5 = scf.for %[[J:.*]] = {{.*}} iter_args(
+ // CHECK-SAME: %[[TENSOR0_ARG_L2:[0-9a-zA-Z]+]] = %[[TENSOR0_ARG]]
+ // CHECK-SAME: %[[TENSOR1_ARG_L2:[0-9a-zA-Z]+]] = %[[TENSOR1_ARG]]
+ // CHECK-SAME: %[[TENSOR2_ARG_L2:[0-9a-zA-Z]+]] = %[[TENSOR2_ARG]]
+ // CHECK-SAME: %[[ST0_ARG_L2:[0-9a-zA-Z]+]] = %[[ST0]]
+ // CHECK-SAME: %[[V0_ARG_L2:[0-9a-zA-Z]+]] = %[[V0]]
+ // CHECK-SAME: ) ->
+ // CHECK-SAME: (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, vector<1xf32>)
+ %1:3 = scf.for %j = %lb to %ub step %step
+ iter_args(%arg6 = %arg0, %arg7 = %arg1, %arg8 = %arg2)
+ -> (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>) {
+ // Hoists.
+ %st0 = tensor.extract_slice %arg6[%i, %i][%step, %step][1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+ %r0 = vector.transfer_read %st0[%c0, %c0], %cst: tensor<?x?xf32>, vector<1xf32>
+
+ // CHECK: %[[ST1:.*]] = tensor.extract_slice %[[TENSOR1_ARG_L2]][%[[J]],{{.*}}: tensor<?x?xf32> to tensor<?x?xf32>
+ // CHECK: %[[V1:.*]] = vector.transfer_read %[[ST1]]{{.*}} : tensor<?x?xf32>, vector<2xf32>
+ // Does not hoist (slice depends on %j).
+ %st1 = tensor.extract_slice %arg7[%j, %c0][%step, %step][1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+ %r1 = vector.transfer_read %st1[%c0, %c0], %cst: tensor<?x?xf32>, vector<2xf32>
+
+ // CHECK: %[[ST2:.*]] = tensor.extract_slice %[[TENSOR2_ARG_L2]][%[[I]],{{.*}}: tensor<?x?xf32> to tensor<?x?xf32>
+ // CHECK: %[[V2:.*]] = vector.transfer_read %[[ST2]]{{.*}} : tensor<?x?xf32>, vector<3xf32>
+ // Does not hoist: %arg8 is the destination of two insert_slice ops.
+ %st2 = tensor.extract_slice %arg8[%i, %c0][%step, %step][1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+ %r2 = vector.transfer_read %st2[%c0, %c0], %cst: tensor<?x?xf32>, vector<3xf32>
+
+ // CHECK: %[[U0:.*]] = "test.some_use"(%[[V0_ARG_L2]]) : (vector<1xf32>) -> vector<1xf32>
+ // CHECK: %[[U1:.*]] = "test.some_use"(%[[V1]]) : (vector<2xf32>) -> vector<2xf32>
+ // CHECK: %[[U2:.*]] = "test.some_use"(%[[V2]]) : (vector<3xf32>) -> vector<3xf32>
+ %u0 = "test.some_use"(%r0) : (vector<1xf32>) -> vector<1xf32>
+ %u1 = "test.some_use"(%r1) : (vector<2xf32>) -> vector<2xf32>
+ %u2 = "test.some_use"(%r2) : (vector<3xf32>) -> vector<3xf32>
+
+ // Hoists.
+ %w0 = vector.transfer_write %u0, %st0[%c0, %c0] : vector<1xf32>, tensor<?x?xf32>
+
+ // CHECK-DAG: %[[STI1:.*]] = vector.transfer_write %[[U1]], %{{.*}} : vector<2xf32>, tensor<?x?xf32>
+ // Does not hoist (associated slice depends on %j).
+ %w1 = vector.transfer_write %u1, %st1[%i, %i] : vector<2xf32>, tensor<?x?xf32>
+
+ // CHECK-DAG: %[[STI2:.*]] = vector.transfer_write %[[U2]], %{{.*}} : vector<3xf32>, tensor<?x?xf32>
+ // Does not hoist: %arg8 is the destination of two insert_slice ops.
+ %w2 = vector.transfer_write %u2, %st2[%c0, %c0] : vector<3xf32>, tensor<?x?xf32>
+
+ // Hoists.
+ %sti0 = tensor.insert_slice %w0 into %arg6[%i, %i][%step, %step][1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
+
+ // CHECK-DAG: tensor.insert_slice %[[STI1]] into %[[TENSOR1_ARG_L2]][%[[J]],{{.*}}: tensor<?x?xf32> into tensor<?x?xf32>
+ // Does not hoist (depends on %j).
+ %sti1 = tensor.insert_slice %w1 into %arg7[%j, %c0][%step, %step][1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
+
+ // CHECK-DAG: tensor.insert_slice %[[STI2]] into %[[TENSOR2_ARG_L2]][%[[I]],{{.*}}: tensor<?x?xf32> into tensor<?x?xf32>
+ // Does not hoist: %arg8 is the destination of two insert_slice ops (%sti2 and %sti22).
+ %sti2 = tensor.insert_slice %w2 into %arg8[%i, %c0][%step, %step][1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
+ // Extract with a different stride to make sure we cannot fold this extract with the above insert.
+ %st22 = tensor.extract_slice %sti2[%i, %c0][%step, %step][2, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+ %sti22 = tensor.insert_slice %st22 into %arg8[%i, %c0][%step, %step][1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
+
+ // CHECK: scf.yield {{.*}} : tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, vector<1xf32>
+ // CHECK: }
+ scf.yield %sti0, %sti1, %sti22:
+ tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
+ }
+
+ // Hoisted
+ // CHECK: %[[STI0:.*]] = vector.transfer_write %[[R]]#4, %[[R]]#3{{.*}} : vector<1xf32>, tensor<?x?xf32>
+ // CHECK: tensor.insert_slice %[[STI0]] into %[[R]]#0[%[[I]], %[[I]]]{{.*}} : tensor<?x?xf32> into tensor<?x?xf32>
+
+ // CHECK: scf.yield {{.*}} : tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
+ scf.yield %1#0, %1#1, %1#2 :
+ tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
+
+ // CHECK: }
+ }
+ return %0#0, %0#1, %0#2 : tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @hoist_vector_transfer_write_pairs_disjoint_tensor(
+// CHECK-SAME: %[[T:.*]]: tensor<?x?xf32>,
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
+// CHECK-DAG: %[[R0:.*]] = vector.transfer_read %[[T]][%[[C0]], %[[C0]]], %{{.*}} : tensor<?x?xf32>, vector<2xf32>
+// CHECK-DAG: %[[R1:.*]] = vector.transfer_read %[[T]][%[[C0]], %[[C3]]], %{{.*}} : tensor<?x?xf32>, vector<2xf32>
+// CHECK: %[[F:.*]]:3 = scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[TL:.*]] = %[[T]], %[[R2:.*]] = %[[R0]], %[[R3:.*]] = %[[R1]]) -> (tensor<?x?xf32>, vector<2xf32>, vector<2xf32>) {
+// CHECK: %[[R4:.*]] = "test.some_use"(%[[R2]]) : (vector<2xf32>) -> vector<2xf32>
+// CHECK: %[[R5:.*]] = "test.some_use"(%[[R3]]) : (vector<2xf32>) -> vector<2xf32>
+// CHECK: scf.yield %[[TL]], %[[R4]], %[[R5]] : tensor<?x?xf32>, vector<2xf32>, vector<2xf32>
+// CHECK: }
+// CHECK: %[[W0:.*]] = vector.transfer_write %[[F]]#2, %[[F]]#0[%[[C0]], %[[C3]]] : vector<2xf32>, tensor<?x?xf32>
+// CHECK: %[[W1:.*]] = vector.transfer_write %[[F]]#1, %[[W0]][%[[C0]], %[[C0]]] : vector<2xf32>, tensor<?x?xf32>
+// CHECK: return %[[W1]] : tensor<?x?xf32>
+func.func @hoist_vector_transfer_write_pairs_disjoint_tensor(
+ %tensor: tensor<?x?xf32>,
+ %val: index, %lb : index, %ub : index, %step: index) ->
+ (tensor<?x?xf32>) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index // Unused.
+ %c3 = arith.constant 3 : index
+ %cst = arith.constant 0.0 : f32
+ %1 = scf.for %j = %lb to %ub step %step iter_args(%arg5 = %tensor)
+ -> (tensor<?x?xf32>) {
+ %r00 = vector.transfer_read %arg5[%c0, %c0], %cst: tensor<?x?xf32>, vector<2xf32> // Hoisted out of the loop.
+ %u00 = "test.some_use"(%r00) : (vector<2xf32>) -> vector<2xf32>
+ %w10 = vector.transfer_write %u00, %arg5[%c0, %c0] : vector<2xf32>, tensor<?x?xf32>
+
+ // Hoisted: the [%c0, %c3] pair is disjoint from the write %w10, so it can look through %w10.
+ %r01 = vector.transfer_read %w10[%c0, %c3], %cst: tensor<?x?xf32>, vector<2xf32>
+ %u01 = "test.some_use"(%r01) : (vector<2xf32>) -> vector<2xf32>
+ %w11 = vector.transfer_write %u01, %w10[%c0, %c3] : vector<2xf32>, tensor<?x?xf32>
+ scf.yield %w11 : tensor<?x?xf32>
+ }
+ return %1 : tensor<?x?xf32>
+}
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 4e814428f0cc457..daa3e56d7d4f06b 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -4598,6 +4598,7 @@ cc_library(
":Pass",
":SCFDialect",
":SideEffectInterfaces",
+ ":SubsetOpInterface",
":Support",
":TensorDialect",
":Transforms",
>From 62af2908562759413c6955148c6486d06b739ee0 Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Mon, 30 Oct 2023 15:51:20 +0900
Subject: [PATCH 6/6] [mlir][transform] Add transform op for loop-invariant
subset hoisting
---
.../Linalg/TransformOps/LinalgTransformOps.td | 50 ----------------
.../mlir/Dialect/Transform/IR/TransformOps.td | 59 +++++++++++++++++++
.../Transforms/LoopInvariantCodeMotionUtils.h | 4 +-
.../TransformOps/LinalgTransformOps.cpp | 29 ---------
.../lib/Dialect/Transform/IR/TransformOps.cpp | 20 +++++++
.../Transforms/LoopInvariantCodeMotion.cpp | 4 +-
.../Utils/LoopInvariantCodeMotionUtils.cpp | 17 +++---
.../Dialect/Transform/test-interpreter.mlir | 45 ++++++++++++++
8 files changed, 140 insertions(+), 88 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 1ff88d036bc036c..732b6fe95c837d6 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -2210,56 +2210,6 @@ def ConvertConv2DToImg2ColOp : Op<Transform_Dialect,
}];
}
-//===----------------------------------------------------------------------===//
-// HoistRedundantTensorSubsetsOp
-//===----------------------------------------------------------------------===//
-
-def HoistRedundantTensorSubsetsOp :
- Op<Transform_Dialect, "structured.hoist_redundant_tensor_subsets",
- [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
- TransformEachOpTrait,
- TransformOpInterface,
- ReportTrackingListenerFailuresOpTrait]> {
- let description = [{
- Hoists supported tensor subset extract/insert operation pairs out of
- immediately enclosing loop iteratively, if the following conditions
- are true:
- 1. The 2 ops access the same tensor subset.
- 2. All operands are invariant under the enclosing loop.
-
- The supported subset extract/insert operation pairs currently comprise:
- - tensor.extract_slice / tensor.insert_slice
- - vector.transfer_read / vector.transfer_write on tensors
-
- Only scf.for loops are currently supported.
-
- When applied to:
- 1. an scf.for loop, hoist out of this loop only.
- 2. a non-loop op, apply hoisting to all the contained loop ops.
-
- #### Return modes:
-
- The operation always succeeds and returns nothing.
- }];
-
- let arguments = (ins TransformHandleTypeInterface:$target);
- let results = (outs);
-
- let assemblyFormat = [{
- $target
- attr-dict
- `:` functional-type(operands, results)
- }];
-
- let extraClassDeclaration = [{
- ::mlir::DiagnosedSilenceableFailure applyToOne(
- ::mlir::transform::TransformRewriter &rewriter,
- ::mlir::Operation *target,
- ::mlir::transform::ApplyToEachResultList &results,
- ::mlir::transform::TransformState &state);
- }];
-}
-
//===----------------------------------------------------------------------===//
// InsertSliceToCopyOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index b14c89eadb097d9..6d57e104a90285a 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -691,6 +691,65 @@ def GetTypeOp : TransformDialectOp<"get_type",
"functional-type(operands, results)";
}
+def HoistLoopInvariantSubsetsOp
+ : TransformDialectOp<"hoist_loop_invariant_subsets",
+ [TransformOpInterface, TransformEachOpTrait,
+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+ ReportTrackingListenerFailuresOpTrait]> {
+ let summary = "Hoist loop-invariant subset ops";
+ let description = [{
+ This transform hoists loop-invariant subset ops out of loop-like ops. It
+ looks for matching subset extraction/insertion op pairs and hoists them. The
+ loop body operates on a newly introduced region iter_arg.
+
+ Example:
+ ```
+ %r = scf.for ... iter_args(%t = %a) -> (tensor<?xf32>) {
+ %0 = tensor.extract_slice %t[0][5][1] : tensor<?xf32> to tensor<5xf32>
+ %1 = "test.foo"(%0) : (tensor<5xf32>) -> (tensor<5xf32>)
+ %2 = tensor.insert_slice %1 into %t[0][5][1]
+ : tensor<5xf32> into tensor<?xf32>
+ scf.yield %2 : tensor<?xf32>
+ }
+ ```
+ Is transformed to:
+ ```
+ %0 = tensor.extract_slice %a[0][5][1] : tensor<?xf32> to tensor<5xf32>
+ %new_loop:2 = scf.for ... iter_args(%t = %a, %h = %0) -> (tensor<?xf32>, tensor<5xf32>) {
+ %1 = "test.foo"(%h) : (tensor<5xf32>) -> (tensor<5xf32>)
+ scf.yield %t, %1 : tensor<?xf32>, tensor<5xf32>
+ }
+ %r = tensor.insert_slice %new_loop#1 into %new_loop#0[0][5][1]
+ : tensor<5xf32> into tensor<?xf32>
+ ```
+
+ Subset ops are hoisted only if there are no conflicting subset ops. E.g.,
+ if there were a second overlapping extraction in the above example, no ops
+ could be hoisted safely.
+
+ This transform looks for `LoopLikeOpInterface` ops within the targeted op,
+ including the target op itself. It attempts hoisting on all found loop-like
+ ops.
+
+ This transform reads the target handle and modifies the payload. It always succeeds.
+
+ TODO: Make this op more targeted if needed. I.e., apply the transformation
+ only to the targeted `LoopLikeOpInterface` op.
+ }];
+
+ let arguments = (ins TransformHandleTypeInterface:$target);
+ let results = (outs);
+ let assemblyFormat = "$target attr-dict `:` type($target)";
+
+ let extraClassDeclaration = [{
+ ::mlir::DiagnosedSilenceableFailure applyToOne(
+ ::mlir::transform::TransformRewriter &rewriter,
+ ::mlir::Operation *target,
+ ::mlir::transform::ApplyToEachResultList &results,
+ ::mlir::transform::TransformState &state);
+ }];
+}
+
def IncludeOp : TransformDialectOp<"include",
[CallOpInterface,
MatchOpInterface,
diff --git a/mlir/include/mlir/Transforms/LoopInvariantCodeMotionUtils.h b/mlir/include/mlir/Transforms/LoopInvariantCodeMotionUtils.h
index 579054070f729b0..3ceef44d799e893 100644
--- a/mlir/include/mlir/Transforms/LoopInvariantCodeMotionUtils.h
+++ b/mlir/include/mlir/Transforms/LoopInvariantCodeMotionUtils.h
@@ -18,6 +18,7 @@ namespace mlir {
class LoopLikeOpInterface;
class Operation;
class Region;
+class RewriterBase;
class Value;
/// Given a list of regions, perform loop-invariant code motion. An operation is
@@ -108,7 +109,8 @@ size_t moveLoopInvariantCode(LoopLikeOpInterface loopLike);
/// %r = tensor.insert_slice %new_loop#1 into %new_loop#0
/// : tensor<5xf32> into tensor<?xf32>
/// ```
-LoopLikeOpInterface hoistLoopInvariantSubsets(LoopLikeOpInterface loopLike);
+LoopLikeOpInterface hoistLoopInvariantSubsets(RewriterBase &rewriter,
+ LoopLikeOpInterface loopLike);
} // end namespace mlir
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 8508507871d0c6c..79bb708eea67572 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3139,35 +3139,6 @@ DiagnosedSilenceableFailure transform::ConvertConv2DToImg2ColOp::applyToOne(
return DiagnosedSilenceableFailure::success();
}
-//===----------------------------------------------------------------------===//
-// HoistRedundantTensorSubsetsOp
-//===----------------------------------------------------------------------===//
-
-DiagnosedSilenceableFailure
-transform::HoistRedundantTensorSubsetsOp::applyToOne(
- transform::TransformRewriter &rewriter, Operation *target,
- transform::ApplyToEachResultList &results,
- transform::TransformState &state) {
- auto forOp = dyn_cast<scf::ForOp>(target);
- if (forOp) {
- linalg::hoistRedundantSubsetExtractInsert(rewriter, forOp);
- return DiagnosedSilenceableFailure::success();
- }
-
- // TODO: walking in some reverse / inside-out order would be more efficient
- // and would capture more cases.
- target->walk([&](scf::ForOp forOp) {
- hoistRedundantSubsetExtractInsert(rewriter, forOp);
- });
- return DiagnosedSilenceableFailure::success();
-}
-
-void transform::HoistRedundantTensorSubsetsOp::getEffects(
- SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
- transform::onlyReadsHandle(getTarget(), effects);
- transform::modifiesPayload(effects);
-}
-
//===----------------------------------------------------------------------===//
// InsertSliceToCopyOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 8db77b6059dd2e3..ff71518e4358472 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -1395,6 +1395,26 @@ transform::GetTypeOp::apply(transform::TransformRewriter &rewriter,
return DiagnosedSilenceableFailure::success();
}
+//===----------------------------------------------------------------------===//
+// HoistLoopInvariantSubsetsOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure transform::HoistLoopInvariantSubsetsOp::applyToOne(
+ TransformRewriter &rewriter, Operation *target,
+ ApplyToEachResultList &results, TransformState &state) {
+ // Post-order walk: inner loops first, so hoisted ops can be hoisted again.
+ target->walk([&](LoopLikeOpInterface loop) {
+ (void)hoistLoopInvariantSubsets(rewriter, loop);
+ });
+ return DiagnosedSilenceableFailure::success();
+}
+
+void transform::HoistLoopInvariantSubsetsOp::getEffects(
+ SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ transform::onlyReadsHandle(getTarget(), effects);
+ transform::modifiesPayload(effects);
+}
+
//===----------------------------------------------------------------------===//
// IncludeOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp b/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp
index e6d8af8f05832d3..02c3ea1ce9b650c 100644
--- a/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp
+++ b/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp
@@ -12,6 +12,7 @@
#include "mlir/Transforms/Passes.h"
+#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Transforms/LoopInvariantCodeMotionUtils.h"
@@ -47,11 +48,12 @@ void LoopInvariantCodeMotion::runOnOperation() {
}
void LoopInvariantSubsetHoisting::runOnOperation() {
+ IRRewriter rewriter(getOperation()->getContext());
// Walk through all loops in a function in innermost-loop-first order. This
// way, we first hoist from the inner loop, and place the ops in the outer
// loop, which in turn can be further hoisted from.
getOperation()->walk([&](LoopLikeOpInterface loopLike) {
- (void)hoistLoopInvariantSubsets(loopLike);
+ (void)hoistLoopInvariantSubsets(rewriter, loopLike);
});
}
diff --git a/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp b/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
index bb4e6dc62b9c935..4f67cca74088248 100644
--- a/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
@@ -311,12 +311,12 @@ MatchingSubsets::populateSubsetOpsAtIterArg(LoopLikeOpInterface loopLike,
/// loop-like op and index into loop-invariant subset locations. Return the
/// newly created loop op (that has extra iter_args) or the original loop op if
/// nothing was hoisted.
-static LoopLikeOpInterface hoistSubsetAtIterArg(LoopLikeOpInterface loopLike,
+static LoopLikeOpInterface hoistSubsetAtIterArg(RewriterBase &rewriter,
+ LoopLikeOpInterface loopLike,
BlockArgument iterArg) {
assert(iterArg.getOwner()->getParentOp() == loopLike && "invalid iter_arg");
auto it = llvm::find(loopLike.getRegionIterArgs(), iterArg);
int64_t iterArgIdx = std::distance(loopLike.getRegionIterArgs().begin(), it);
- IRRewriter rewriter(loopLike.getContext());
MatchingSubsets subsets;
if (failed(subsets.populateSubsetOpsAtIterArg(loopLike, iterArg)))
return loopLike;
@@ -367,11 +367,12 @@ static LoopLikeOpInterface hoistSubsetAtIterArg(LoopLikeOpInterface loopLike,
OpResult newLoopResult = loopLike.getLoopResults()->back();
extractionOp->moveBefore(loopLike);
insertionOp->moveAfter(loopLike);
- insertionOp.getUpdatedDestination().replaceAllUsesWith(
- insertionOp.getDestinationOperand().get());
+ rewriter.replaceAllUsesWith(insertionOp.getUpdatedDestination(),
+ insertionOp.getDestinationOperand().get());
extractionOp.getSourceOperand().set(
loopLike.getTiedLoopInit(iterArg)->get());
- loopResult.replaceAllUsesWith(insertionOp.getUpdatedDestination());
+ rewriter.replaceAllUsesWith(loopResult,
+ insertionOp.getUpdatedDestination());
insertionOp.getSourceOperand().set(newLoopResult);
insertionOp.getDestinationOperand().set(loopResult);
}
@@ -381,12 +382,14 @@ static LoopLikeOpInterface hoistSubsetAtIterArg(LoopLikeOpInterface loopLike,
}
LoopLikeOpInterface
-mlir::hoistLoopInvariantSubsets(LoopLikeOpInterface loopLike) {
+mlir::hoistLoopInvariantSubsets(RewriterBase &rewriter,
+                                LoopLikeOpInterface loopLike) {
// Note: As subset ops are getting hoisted, the number of region iter_args
// increases. This can enable further hoisting opportunities on the new
// iter_args.
for (int64_t i = 0; i < loopLike.getRegionIterArgs().size(); ++i) {
-    loopLike = hoistSubsetAtIterArg(loopLike, loopLike.getRegionIterArgs()[i]);
+    BlockArgument iterArg = loopLike.getRegionIterArgs()[i];
+    loopLike = hoistSubsetAtIterArg(rewriter, loopLike, iterArg);
}
return loopLike;
}
diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
index 3891c16b4115595..fd0c0fd2117ab63 100644
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -2109,3 +2109,48 @@ transform.sequence failures(propagate) {
transform.yield
}
}
+
+// -----
+
+// CHECK-LABEL: func @test_loop_invariant_subset_hoisting(
+//  CHECK-SAME:     %[[arg:.*]]: tensor<?xf32>
+func.func @test_loop_invariant_subset_hoisting(%arg: tensor<?xf32>) -> tensor<?xf32> {
+  %lb = "test.foo"() : () -> (index)
+  %ub = "test.foo"() : () -> (index)
+  %step = "test.foo"() : () -> (index)
+  // CHECK: %[[extract:.*]] = tensor.extract_slice %[[arg]]
+  // CHECK: %[[for:.*]]:2 = scf.for {{.*}} iter_args(%[[t:.*]] = %[[arg]], %[[hoisted:.*]] = %[[extract]])
+  // expected-remark @below{{new loop op}}
+  %0 = scf.for %iv = %lb to %ub step %step iter_args(%t = %arg) -> (tensor<?xf32>) {
+    %1 = tensor.extract_slice %t[0][5][1] : tensor<?xf32> to tensor<5xf32>
+    // CHECK: %[[foo:.*]] = "test.foo"(%[[hoisted]])
+    %2 = "test.foo"(%1) : (tensor<5xf32>) -> (tensor<5xf32>)
+    // The insertion writes back to the exact subset ([0][5][1]) that the
+    // extraction read from, so both can be hoisted out of the loop.
+    %3 = tensor.insert_slice %2 into %t[0][5][1] : tensor<5xf32> into tensor<?xf32>
+    // CHECK: scf.yield %[[t]], %[[foo]]
+    scf.yield %3 : tensor<?xf32>
+  }
+  // CHECK: %[[insert:.*]] = tensor.insert_slice %[[for]]#1 into %[[for]]#0
+  // CHECK: return %[[insert]]
+  return %0 : tensor<?xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["scf.for"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  %1 = transform.structured.match ops{["tensor.extract_slice"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  %2 = transform.structured.match ops{["tensor.insert_slice"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+
+  transform.hoist_loop_invariant_subsets %0 : !transform.any_op
+  // Hoisting creates a new loop op and moves the slice ops out of it. Make
+  // sure that all handles are still valid and that %0 now refers to the new
+  // loop op; its "new loop op" remark is expected on the scf.for line.
+
+  // expected-remark @below{{1}}
+  test_print_number_of_associated_payload_ir_ops %0 : !transform.any_op
+  test_print_remark_at_operand %0, "new loop op" : !transform.any_op
+  // expected-remark @below{{1}}
+  test_print_number_of_associated_payload_ir_ops %1 : !transform.any_op
+  // expected-remark @below{{1}}
+  test_print_number_of_associated_payload_ir_ops %2 : !transform.any_op
+}
More information about the Mlir-commits
mailing list