[mlir] [llvm] [mlir][transform] LISH: Add transform op (PR #70630)
Matthias Springer via llvm-commits
llvm-commits at lists.llvm.org
Sat Nov 4 19:34:01 PDT 2023
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/70630
>From bb6f8153f6b8cb6dee66541c84fc53eb1aa46784 Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Sun, 5 Nov 2023 11:19:40 +0900
Subject: [PATCH] [mlir][transform] Add transform op for loop-invariant subset
hoisting
---
.../Linalg/TransformOps/LinalgTransformOps.td | 50 ------------
.../mlir/Dialect/Transform/CMakeLists.txt | 1 +
.../Transform/LoopExtension/CMakeLists.txt | 6 ++
.../Transform/LoopExtension/LoopExtension.h | 16 ++++
.../LoopExtension/LoopExtensionOps.h | 23 ++++++
.../LoopExtension/LoopExtensionOps.td | 76 ++++++++++++++++++
.../Transform/PDLExtension/PDLExtensionOps.td | 2 +-
mlir/include/mlir/InitAllExtensions.h | 2 +
.../Transforms/LoopInvariantCodeMotionUtils.h | 4 +-
.../TransformOps/LinalgTransformOps.cpp | 29 -------
mlir/lib/Dialect/Transform/CMakeLists.txt | 1 +
.../Transform/LoopExtension/CMakeLists.txt | 13 ++++
.../Transform/LoopExtension/LoopExtension.cpp | 34 ++++++++
.../LoopExtension/LoopExtensionOps.cpp | 36 +++++++++
.../Transforms/LoopInvariantCodeMotion.cpp | 4 +-
.../Utils/LoopInvariantCodeMotionUtils.cpp | 17 ++--
.../Transform/test-loop-transforms.mlir | 78 +++++++++++++++++++
.../test/lib/Dialect/Transform/CMakeLists.txt | 1 +
.../llvm-project-overlay/mlir/BUILD.bazel | 48 ++++++++++++
19 files changed, 352 insertions(+), 89 deletions(-)
create mode 100644 mlir/include/mlir/Dialect/Transform/LoopExtension/CMakeLists.txt
create mode 100644 mlir/include/mlir/Dialect/Transform/LoopExtension/LoopExtension.h
create mode 100644 mlir/include/mlir/Dialect/Transform/LoopExtension/LoopExtensionOps.h
create mode 100644 mlir/include/mlir/Dialect/Transform/LoopExtension/LoopExtensionOps.td
create mode 100644 mlir/lib/Dialect/Transform/LoopExtension/CMakeLists.txt
create mode 100644 mlir/lib/Dialect/Transform/LoopExtension/LoopExtension.cpp
create mode 100644 mlir/lib/Dialect/Transform/LoopExtension/LoopExtensionOps.cpp
create mode 100644 mlir/test/Dialect/Transform/test-loop-transforms.mlir
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 9e3f79e64bb1d79..e60c3f364604527 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -2247,56 +2247,6 @@ def ConvertConv2DToImg2ColOp : Op<Transform_Dialect,
}];
}
-//===----------------------------------------------------------------------===//
-// HoistRedundantTensorSubsetsOp
-//===----------------------------------------------------------------------===//
-
-def HoistRedundantTensorSubsetsOp :
- Op<Transform_Dialect, "structured.hoist_redundant_tensor_subsets",
- [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
- TransformEachOpTrait,
- TransformOpInterface,
- ReportTrackingListenerFailuresOpTrait]> {
- let description = [{
- Hoists supported tensor subset extract/insert operation pairs out of
- immediately enclosing loop iteratively, if the following conditions
- are true:
- 1. The 2 ops access the same tensor subset.
- 2. All operands are invariant under the enclosing loop.
-
- The supported subset extract/insert operation pairs currently comprise:
- - tensor.extract_slice / tensor.insert_slice
- - vector.transfer_read / vector.transfer_write on tensors
-
- Only scf.for loops are currently supported.
-
- When applied to:
- 1. an scf.for loop, hoist out of this loop only.
- 2. a non-loop op, apply hoisting to all the contained loop ops.
-
- #### Return modes:
-
- The operation always succeeds and returns nothing.
- }];
-
- let arguments = (ins TransformHandleTypeInterface:$target);
- let results = (outs);
-
- let assemblyFormat = [{
- $target
- attr-dict
- `:` functional-type(operands, results)
- }];
-
- let extraClassDeclaration = [{
- ::mlir::DiagnosedSilenceableFailure applyToOne(
- ::mlir::transform::TransformRewriter &rewriter,
- ::mlir::Operation *target,
- ::mlir::transform::ApplyToEachResultList &results,
- ::mlir::transform::TransformState &state);
- }];
-}
-
//===----------------------------------------------------------------------===//
// InsertSliceToCopyOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Transform/CMakeLists.txt b/mlir/include/mlir/Dialect/Transform/CMakeLists.txt
index d9fbaee802398fb..d6c5c975c2e93c1 100644
--- a/mlir/include/mlir/Dialect/Transform/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Transform/CMakeLists.txt
@@ -1,3 +1,4 @@
add_subdirectory(IR)
+add_subdirectory(LoopExtension)
add_subdirectory(PDLExtension)
add_subdirectory(Transforms)
diff --git a/mlir/include/mlir/Dialect/Transform/LoopExtension/CMakeLists.txt b/mlir/include/mlir/Dialect/Transform/LoopExtension/CMakeLists.txt
new file mode 100644
index 000000000000000..8f5e510ad39a39e
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Transform/LoopExtension/CMakeLists.txt
@@ -0,0 +1,6 @@
+set(LLVM_TARGET_DEFINITIONS LoopExtensionOps.td)
+mlir_tablegen(LoopExtensionOps.h.inc -gen-op-decls)
+mlir_tablegen(LoopExtensionOps.cpp.inc -gen-op-defs)
+add_public_tablegen_target(MLIRTransformDialectLoopExtensionOpsIncGen)
+
+add_mlir_doc(LoopExtensionOps LoopExtensionOps Dialects/ -gen-op-doc)
diff --git a/mlir/include/mlir/Dialect/Transform/LoopExtension/LoopExtension.h b/mlir/include/mlir/Dialect/Transform/LoopExtension/LoopExtension.h
new file mode 100644
index 000000000000000..7a8ed2075ef12e3
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Transform/LoopExtension/LoopExtension.h
@@ -0,0 +1,23 @@
+//===- LoopExtension.h - Loop extension for Transform dialect ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_TRANSFORM_LOOPEXTENSION_LOOPEXTENSION_H
+#define MLIR_DIALECT_TRANSFORM_LOOPEXTENSION_LOOPEXTENSION_H
+
+namespace mlir {
+class DialectRegistry;
+
+namespace transform {
+/// Registers the loop extension of the Transform dialect in the given registry.
+/// The extension contributes the ops defined in LoopExtensionOps.td (e.g.,
+/// `transform.loop.hoist_loop_invariant_subsets`).
+void registerLoopExtension(DialectRegistry &dialectRegistry);
+} // namespace transform
+} // namespace mlir
+
+#endif // MLIR_DIALECT_TRANSFORM_LOOPEXTENSION_LOOPEXTENSION_H
diff --git a/mlir/include/mlir/Dialect/Transform/LoopExtension/LoopExtensionOps.h b/mlir/include/mlir/Dialect/Transform/LoopExtension/LoopExtensionOps.h
new file mode 100644
index 000000000000000..68cc0699d081a0d
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Transform/LoopExtension/LoopExtensionOps.h
@@ -0,0 +1,23 @@
+//===- LoopExtensionOps.h - Loop ext. for Transform dialect -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_TRANSFORM_LOOPEXTENSION_LOOPEXTENSIONOPS_H
+#define MLIR_DIALECT_TRANSFORM_LOOPEXTENSION_LOOPEXTENSIONOPS_H
+
+#include "mlir/Bytecode/BytecodeOpInterface.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/Interfaces/LoopLikeInterface.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/Transform/LoopExtension/LoopExtensionOps.h.inc"
+
+#endif // MLIR_DIALECT_TRANSFORM_LOOPEXTENSION_LOOPEXTENSIONOPS_H
diff --git a/mlir/include/mlir/Dialect/Transform/LoopExtension/LoopExtensionOps.td b/mlir/include/mlir/Dialect/Transform/LoopExtension/LoopExtensionOps.td
new file mode 100644
index 000000000000000..78a8c6ad489a9af
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Transform/LoopExtension/LoopExtensionOps.td
@@ -0,0 +1,77 @@
+//===- LoopExtensionOps.td - Transform dialect operations --*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_TRANSFORM_LOOPEXTENSION_LOOPEXTENSIONOPS
+#define MLIR_DIALECT_TRANSFORM_LOOPEXTENSION_LOOPEXTENSIONOPS
+
+include "mlir/Dialect/Transform/IR/TransformDialect.td"
+include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+
+def HoistLoopInvariantSubsetsOp
+    : TransformDialectOp<"loop.hoist_loop_invariant_subsets",
+        [TransformOpInterface, TransformEachOpTrait,
+         DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+         ReportTrackingListenerFailuresOpTrait]> {
+  let summary = "Hoist loop invariant subset ops";
+  let description = [{
+    This transform hoists loop-invariant subset ops out of the targeted
+    loop-like op. It looks for matching subset extraction/insertion op pairs and
+    hoists them. The loop body operates on a newly introduced region iter_arg.
+
+    Subset ops are hoisted only from the targeted op. If subset ops should be
+    hoisted from an entire loop nest, this transformation must be applied to
+    each loop-like op of the loop nest, starting with the innermost loop and
+    ending with the outermost loop.
+
+    Example:
+    ```
+    %r = scf.for ... iter_args(%t = %a) -> (tensor<?xf32>) {
+      %0 = tensor.extract_slice %t[0][5][1] : tensor<?xf32> to tensor<5xf32>
+      %1 = "test.foo"(%0) : (tensor<5xf32>) -> (tensor<5xf32>)
+      %2 = tensor.insert_slice %1 into %t[0][5][1]
+          : tensor<5xf32> into tensor<?xf32>
+      scf.yield %2 : tensor<?xf32>
+    }
+    ```
+    Is transformed to:
+    ```
+    %0 = tensor.extract_slice %a[0][5][1] : tensor<?xf32> to tensor<5xf32>
+    %new_loop:2 = scf.for ... iter_args(%t = %a, %h = %0)
+        -> (tensor<?xf32>, tensor<5xf32>) {
+      %1 = "test.foo"(%h) : (tensor<5xf32>) -> (tensor<5xf32>)
+      scf.yield %t, %1 : tensor<?xf32>, tensor<5xf32>
+    }
+    %r = tensor.insert_slice %new_loop#1 into %new_loop#0
+        : tensor<5xf32> into tensor<?xf32>
+    ```
+
+    Subset ops are hoisted only if there are no conflicting subset ops. E.g.,
+    if there were a second overlapping extraction in the above example, no ops
+    could be hoisted safely.
+
+    This transform reads the target handle and modifies the payload. This
+    transform does not invalidate any handles, but loop-like ops are replaced
+    with new loop-like ops when a subset op is hoisted. The transform rewriter
+    updates all handles accordingly.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target);
+  let results = (outs);
+  let assemblyFormat = "$target attr-dict `:` type($target)";
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::transform::TransformRewriter &rewriter,
+        ::mlir::LoopLikeOpInterface loopLikeOp,
+        ::mlir::transform::ApplyToEachResultList &results,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
+#endif // MLIR_DIALECT_TRANSFORM_LOOPEXTENSION_LOOPEXTENSIONOPS
diff --git a/mlir/include/mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.td b/mlir/include/mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.td
index 16107b3d0869f1a..206a799690aa59c 100644
--- a/mlir/include/mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.td
+++ b/mlir/include/mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.td
@@ -1,4 +1,4 @@
-//===- TransformOps.td - Transform dialect operations ------*- tablegen -*-===//
+//===- PDLExtensionOps.td - Transform dialect operations ---*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
diff --git a/mlir/include/mlir/InitAllExtensions.h b/mlir/include/mlir/InitAllExtensions.h
index 8e2ad3a2e34f60e..c04ce850fb96f41 100644
--- a/mlir/include/mlir/InitAllExtensions.h
+++ b/mlir/include/mlir/InitAllExtensions.h
@@ -34,6 +34,7 @@
#include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h"
#include "mlir/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.h"
#include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h"
+#include "mlir/Dialect/Transform/LoopExtension/LoopExtension.h"
#include "mlir/Dialect/Transform/PDLExtension/PDLExtension.h"
#include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h"
#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h"
@@ -74,6 +75,7 @@ inline void registerAllExtensions(DialectRegistry &registry) {
scf::registerTransformDialectExtension(registry);
sparse_tensor::registerTransformDialectExtension(registry);
tensor::registerTransformDialectExtension(registry);
+ transform::registerLoopExtension(registry);
transform::registerPDLExtension(registry);
vector::registerTransformDialectExtension(registry);
diff --git a/mlir/include/mlir/Transforms/LoopInvariantCodeMotionUtils.h b/mlir/include/mlir/Transforms/LoopInvariantCodeMotionUtils.h
index 579054070f729b0..3ceef44d799e893 100644
--- a/mlir/include/mlir/Transforms/LoopInvariantCodeMotionUtils.h
+++ b/mlir/include/mlir/Transforms/LoopInvariantCodeMotionUtils.h
@@ -18,6 +18,7 @@ namespace mlir {
class LoopLikeOpInterface;
class Operation;
class Region;
+class RewriterBase;
class Value;
/// Given a list of regions, perform loop-invariant code motion. An operation is
@@ -108,7 +109,8 @@ size_t moveLoopInvariantCode(LoopLikeOpInterface loopLike);
/// %r = tensor.insert_slice %new_loop#1 into %new_loop#0
/// : tensor<5xf32> into tensor<?xf32>
/// ```
-LoopLikeOpInterface hoistLoopInvariantSubsets(LoopLikeOpInterface loopLike);
+LoopLikeOpInterface hoistLoopInvariantSubsets(RewriterBase &rewriter,
+ LoopLikeOpInterface loopLike);
} // end namespace mlir
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 87be3bb85b6e788..fd8a1657db3ae5d 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3163,35 +3163,6 @@ DiagnosedSilenceableFailure transform::ConvertConv2DToImg2ColOp::applyToOne(
return DiagnosedSilenceableFailure::success();
}
-//===----------------------------------------------------------------------===//
-// HoistRedundantTensorSubsetsOp
-//===----------------------------------------------------------------------===//
-
-DiagnosedSilenceableFailure
-transform::HoistRedundantTensorSubsetsOp::applyToOne(
- transform::TransformRewriter &rewriter, Operation *target,
- transform::ApplyToEachResultList &results,
- transform::TransformState &state) {
- auto forOp = dyn_cast<scf::ForOp>(target);
- if (forOp) {
- linalg::hoistRedundantSubsetExtractInsert(rewriter, forOp);
- return DiagnosedSilenceableFailure::success();
- }
-
- // TODO: walking in some reverse / inside-out order would be more efficient
- // and would capture more cases.
- target->walk([&](scf::ForOp forOp) {
- hoistRedundantSubsetExtractInsert(rewriter, forOp);
- });
- return DiagnosedSilenceableFailure::success();
-}
-
-void transform::HoistRedundantTensorSubsetsOp::getEffects(
- SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
- transform::onlyReadsHandle(getTarget(), effects);
- transform::modifiesPayload(effects);
-}
-
//===----------------------------------------------------------------------===//
// InsertSliceToCopyOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Transform/CMakeLists.txt b/mlir/lib/Dialect/Transform/CMakeLists.txt
index 9e144eba25710dd..6898d81df7ca63f 100644
--- a/mlir/lib/Dialect/Transform/CMakeLists.txt
+++ b/mlir/lib/Dialect/Transform/CMakeLists.txt
@@ -1,4 +1,5 @@
add_subdirectory(IR)
+add_subdirectory(LoopExtension)
add_subdirectory(PDLExtension)
add_subdirectory(Transforms)
add_subdirectory(Utils)
diff --git a/mlir/lib/Dialect/Transform/LoopExtension/CMakeLists.txt b/mlir/lib/Dialect/Transform/LoopExtension/CMakeLists.txt
new file mode 100644
index 000000000000000..9e1abdd1ca17b56
--- /dev/null
+++ b/mlir/lib/Dialect/Transform/LoopExtension/CMakeLists.txt
@@ -0,0 +1,13 @@
+add_mlir_dialect_library(MLIRTransformLoopExtension
+ LoopExtension.cpp
+ LoopExtensionOps.cpp
+
+ DEPENDS
+ MLIRTransformDialectLoopExtensionOpsIncGen
+
+ LINK_LIBS PUBLIC
+ MLIRIR
+ MLIRLoopLikeInterface
+ MLIRTransformDialect
+ MLIRTransforms
+)
diff --git a/mlir/lib/Dialect/Transform/LoopExtension/LoopExtension.cpp b/mlir/lib/Dialect/Transform/LoopExtension/LoopExtension.cpp
new file mode 100644
index 000000000000000..b33288fd7b991fd
--- /dev/null
+++ b/mlir/lib/Dialect/Transform/LoopExtension/LoopExtension.cpp
@@ -0,0 +1,34 @@
+//===- LoopExtension.cpp - Loop extension for the Transform dialect -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Transform/LoopExtension/LoopExtension.h"
+
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/LoopExtension/LoopExtensionOps.h"
+#include "mlir/IR/DialectRegistry.h"
+
+using namespace mlir;
+
+namespace {
+/// Transform dialect extension bundling the "core" transform ops that target
+/// loop-like ops, i.e., every op generated from LoopExtensionOps.td.
+struct LoopExtension
+    : public transform::TransformDialectExtension<LoopExtension> {
+  /// Registers each op from the generated op list with this extension.
+  void init() {
+    registerTransformOps<
+#define GET_OP_LIST
+#include "mlir/Dialect/Transform/LoopExtension/LoopExtensionOps.cpp.inc"
+        >();
+  }
+};
+} // namespace
+
+void mlir::transform::registerLoopExtension(DialectRegistry &dialectRegistry) {
+  dialectRegistry.addExtensions<LoopExtension>();
+}
diff --git a/mlir/lib/Dialect/Transform/LoopExtension/LoopExtensionOps.cpp b/mlir/lib/Dialect/Transform/LoopExtension/LoopExtensionOps.cpp
new file mode 100644
index 000000000000000..c992fd15946f36f
--- /dev/null
+++ b/mlir/lib/Dialect/Transform/LoopExtension/LoopExtensionOps.cpp
@@ -0,0 +1,42 @@
+//===- LoopExtensionOps.cpp - Loop extension for the Transform dialect ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Transform/LoopExtension/LoopExtensionOps.h"
+
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/LoopInvariantCodeMotionUtils.h"
+
+using namespace mlir;
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/Transform/LoopExtension/LoopExtensionOps.cpp.inc"
+
+//===----------------------------------------------------------------------===//
+// HoistLoopInvariantSubsetsOp
+//===----------------------------------------------------------------------===//
+
+/// Hoists loop-invariant subset extraction/insertion pairs out of `loopLikeOp`.
+/// `rewriter` is forwarded to the hoisting utility so that replacing the loop
+/// op with a new one updates the associated handles. Always succeeds.
+DiagnosedSilenceableFailure transform::HoistLoopInvariantSubsetsOp::applyToOne(
+    transform::TransformRewriter &rewriter, LoopLikeOpInterface loopLikeOp,
+    transform::ApplyToEachResultList &results,
+    transform::TransformState &state) {
+  // This op has no results, so the (possibly new) loop op returned by the
+  // utility is intentionally discarded.
+  (void)hoistLoopInvariantSubsets(rewriter, loopLikeOp);
+  return DiagnosedSilenceableFailure::success();
+}
+
+/// The target handle is only read (it is not consumed); the payload is modified.
+void transform::HoistLoopInvariantSubsetsOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  transform::onlyReadsHandle(getTarget(), effects);
+  transform::modifiesPayload(effects);
+}
diff --git a/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp b/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp
index e6d8af8f05832d3..02c3ea1ce9b650c 100644
--- a/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp
+++ b/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp
@@ -12,6 +12,7 @@
#include "mlir/Transforms/Passes.h"
+#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Transforms/LoopInvariantCodeMotionUtils.h"
@@ -47,11 +48,12 @@ void LoopInvariantCodeMotion::runOnOperation() {
}
void LoopInvariantSubsetHoisting::runOnOperation() {
+ // hoistLoopInvariantSubsets takes an explicit rewriter so that callers such
+ // as the transform dialect op can track replaced loops. This pass needs no
+ // such tracking, so one IRRewriter (built from the context alone) is shared
+ // by all hoisting steps.
+ IRRewriter rewriter(getOperation()->getContext());
// Walk through all loops in a function in innermost-loop-first order. This
// way, we first hoist from the inner loop, and place the ops in the outer
// loop, which in turn can be further hoisted from.
getOperation()->walk([&](LoopLikeOpInterface loopLike) {
- (void)hoistLoopInvariantSubsets(loopLike);
+ // The returned loop op (possibly a replacement of `loopLike`) is not needed.
+ (void)hoistLoopInvariantSubsets(rewriter, loopLike);
});
}
diff --git a/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp b/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
index 53bdb7aafe41a0c..8f97fd3d9ddf84e 100644
--- a/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
@@ -311,12 +311,12 @@ MatchingSubsets::populateSubsetOpsAtIterArg(LoopLikeOpInterface loopLike,
/// loop-like op and index into loop-invariant subset locations. Return the
/// newly created loop op (that has extra iter_args) or the original loop op if
/// nothing was hoisted.
-static LoopLikeOpInterface hoistSubsetAtIterArg(LoopLikeOpInterface loopLike,
+static LoopLikeOpInterface hoistSubsetAtIterArg(RewriterBase &rewriter,
+ LoopLikeOpInterface loopLike,
BlockArgument iterArg) {
assert(iterArg.getOwner()->getParentOp() == loopLike && "invalid iter_arg");
auto it = llvm::find(loopLike.getRegionIterArgs(), iterArg);
int64_t iterArgIdx = std::distance(loopLike.getRegionIterArgs().begin(), it);
- IRRewriter rewriter(loopLike.getContext());
MatchingSubsets subsets;
if (failed(subsets.populateSubsetOpsAtIterArg(loopLike, iterArg)))
return loopLike;
@@ -367,11 +367,12 @@ static LoopLikeOpInterface hoistSubsetAtIterArg(LoopLikeOpInterface loopLike,
OpResult newLoopResult = loopLike.getLoopResults()->back();
extractionOp->moveBefore(loopLike);
insertionOp->moveAfter(loopLike);
- insertionOp.getUpdatedDestination().replaceAllUsesWith(
- insertionOp.getDestinationOperand().get());
+ rewriter.replaceAllUsesWith(insertionOp.getUpdatedDestination(),
+ insertionOp.getDestinationOperand().get());
extractionOp.getSourceOperand().set(
loopLike.getTiedLoopInit(iterArg)->get());
- loopResult.replaceAllUsesWith(insertionOp.getUpdatedDestination());
+ rewriter.replaceAllUsesWith(loopResult,
+ insertionOp.getUpdatedDestination());
insertionOp.getSourceOperand().set(newLoopResult);
insertionOp.getDestinationOperand().set(loopResult);
}
@@ -381,13 +382,15 @@ static LoopLikeOpInterface hoistSubsetAtIterArg(LoopLikeOpInterface loopLike,
}
LoopLikeOpInterface
-mlir::hoistLoopInvariantSubsets(LoopLikeOpInterface loopLike) {
+mlir::hoistLoopInvariantSubsets(RewriterBase &rewriter,
+ LoopLikeOpInterface loopLike) {
// Note: As subset ops are getting hoisted, the number of region iter_args
// increases. This can enable further hoisting opportunities on the new
// iter_args.
for (int64_t i = 0;
i < static_cast<int64_t>(loopLike.getRegionIterArgs().size()); ++i) {
- loopLike = hoistSubsetAtIterArg(loopLike, loopLike.getRegionIterArgs()[i]);
+ // `loopLike` may now refer to a newly created loop op with extra iter_args,
+ // so the iter_args (and their count) are re-queried on every iteration.
+ loopLike = hoistSubsetAtIterArg(rewriter, loopLike,
+ loopLike.getRegionIterArgs()[i]);
}
return loopLike;
}
diff --git a/mlir/test/Dialect/Transform/test-loop-transforms.mlir b/mlir/test/Dialect/Transform/test-loop-transforms.mlir
new file mode 100644
index 000000000000000..425962757f720b7
--- /dev/null
+++ b/mlir/test/Dialect/Transform/test-loop-transforms.mlir
@@ -0,0 +1,76 @@
+// RUN: mlir-opt %s --transform-interpreter --split-input-file \
+// RUN:     --verify-diagnostics | FileCheck %s
+
+// UNSUPPORTED: target=aarch64-pc-windows-msvc
+
+// CHECK-LABEL: func @test_loop_invariant_subset_hoisting(
+//  CHECK-SAME:     %[[arg:.*]]: tensor<?xf32>
+func.func @test_loop_invariant_subset_hoisting(%arg: tensor<?xf32>) -> tensor<?xf32> {
+  %lb = "test.foo"() : () -> (index)
+  %ub = "test.foo"() : () -> (index)
+  %step = "test.foo"() : () -> (index)
+  // CHECK: %[[extract:.*]] = tensor.extract_slice %[[arg]]
+  // CHECK: %[[for:.*]]:2 = scf.for {{.*}} iter_args(%[[t:.*]] = %[[arg]], %[[hoisted:.*]] = %[[extract]])
+  // expected-remark @below{{new loop op}}
+  %0 = scf.for %iv = %lb to %ub step %step iter_args(%t = %arg) -> (tensor<?xf32>) {
+    %1 = tensor.extract_slice %t[0][5][1] : tensor<?xf32> to tensor<5xf32>
+    // CHECK: %[[foo:.*]] = "test.foo"(%[[hoisted]])
+    %2 = "test.foo"(%1) : (tensor<5xf32>) -> (tensor<5xf32>)
+    %3 = tensor.insert_slice %2 into %t[0][5][1] : tensor<5xf32> into tensor<?xf32>
+    // CHECK: scf.yield %[[t]], %[[foo]]
+    scf.yield %3 : tensor<?xf32>
+  }
+  // CHECK: %[[insert:.*]] = tensor.insert_slice %[[for]]#1 into %[[for]]#0
+  // CHECK: return %[[insert]]
+  return %0 : tensor<?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["scf.for"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.structured.match ops{["tensor.extract_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %2 = transform.structured.match ops{["tensor.insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+
+    transform.loop.hoist_loop_invariant_subsets %0 : !transform.any_op
+    // Make sure that the handles are still valid (and were updated in case of
+    // the loop).
+
+    // expected-remark @below{{1}}
+    transform.test_print_number_of_associated_payload_ir_ops %0 : !transform.any_op
+    transform.test_print_remark_at_operand %0, "new loop op" : !transform.any_op
+    // expected-remark @below{{1}}
+    transform.test_print_number_of_associated_payload_ir_ops %1 : !transform.any_op
+    // expected-remark @below{{1}}
+    transform.test_print_number_of_associated_payload_ir_ops %2 : !transform.any_op
+
+    transform.yield
+  }
+}
+
+// -----
+
+// Checks that transform ops from LoopExtensionOps and SCFTransformOps can be
+// used together.
+
+// CHECK-LABEL: func @test_mixed_loop_extension_scf_transform(
+func.func @test_mixed_loop_extension_scf_transform(%arg: tensor<?xf32>) -> tensor<?xf32> {
+  %lb = "test.foo"() : () -> (index)
+  %ub = "test.foo"() : () -> (index)
+  %step = "test.foo"() : () -> (index)
+  // CHECK: scf.for
+  // CHECK: scf.for
+  %0 = scf.for %iv = %lb to %ub step %step iter_args(%t = %arg) -> (tensor<?xf32>) {
+    %1 = "test.foo"(%t) : (tensor<?xf32>) -> (tensor<?xf32>)
+    scf.yield %1 : tensor<?xf32>
+  }
+  return %0 : tensor<?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["scf.for"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.loop.hoist_loop_invariant_subsets %0 : !transform.any_op
+    transform.loop.unroll %0 { factor = 4 } : !transform.any_op
+    transform.yield
+  }
+}
diff --git a/mlir/test/lib/Dialect/Transform/CMakeLists.txt b/mlir/test/lib/Dialect/Transform/CMakeLists.txt
index c7e83d3a7128bcb..436f892a27232b2 100644
--- a/mlir/test/lib/Dialect/Transform/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Transform/CMakeLists.txt
@@ -21,5 +21,6 @@ add_mlir_library(MLIRTestTransformDialect
MLIRPDLDialect
MLIRTransformDialect
MLIRTransformDialectTransforms
+ MLIRTransformLoopExtension
MLIRTransformPDLExtension
)
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 2cadd4e0d2911a6..99aa78bb3d3d33b 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -4416,6 +4416,7 @@ cc_library(
":SCFTransformOps",
":SparseTensorTransformOps",
":TensorTransformOps",
+ ":TransformLoopExtension",
":TransformPDLExtension",
":UBToLLVM",
":VectorTransformOps",
@@ -8677,6 +8678,7 @@ cc_library(
":TosaToLinalg",
":TransformDialect",
":TransformDialectTransforms",
+ ":TransformLoopExtension",
":TransformPDLExtension",
":Transforms",
":TransformsPassIncGen",
@@ -11401,6 +11403,52 @@ cc_library(
],
)
+td_library(
+ name = "TransformLoopExtensionTdFiles",
+ srcs = glob(["include/mlir/Dialect/Transform/LoopExtension/*.td"]),
+ deps = [
+ ":TransformDialectTdFiles",
+ ],
+)
+
+gentbl_cc_library(
+ name = "TransformLoopExtensionOpsIncGen",
+ tbl_outs = [
+ (
+ [
+ "-gen-op-decls",
+ ],
+ "include/mlir/Dialect/Transform/LoopExtension/LoopExtensionOps.h.inc",
+ ),
+ (
+ [
+ "-gen-op-defs",
+ ],
+ "include/mlir/Dialect/Transform/LoopExtension/LoopExtensionOps.cpp.inc",
+ ),
+ ],
+ tblgen = ":mlir-tblgen",
+ td_file = "include/mlir/Dialect/Transform/LoopExtension/LoopExtensionOps.td",
+ deps = [":TransformLoopExtensionTdFiles"],
+)
+
+cc_library(
+ name = "TransformLoopExtension",
+ srcs = glob(["lib/Dialect/Transform/LoopExtension/*.cpp"]),
+ hdrs = glob(["include/mlir/Dialect/Transform/LoopExtension/*.h"]),
+ deps = [
+ ":IR",
+ ":LoopLikeInterface",
+ ":Rewrite",
+ ":SideEffectInterfaces",
+ ":Support",
+ ":TransformDialect",
+ ":TransformLoopExtensionOpsIncGen",
+ ":Transforms",
+ "//llvm:Support",
+ ],
+)
+
td_library(
name = "TransformDialectTransformsTdFiles",
srcs = glob(["include/mlir/Dialect/Transform/Transforms/*.td"]),
More information about the llvm-commits
mailing list