[Mlir-commits] [mlir] [llvm] [mlir][transform] LISH: Add transform op (PR #70630)

Matthias Springer llvmlistbot at llvm.org
Tue Oct 31 20:26:27 PDT 2023


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/70630

>From f7150d279146f3b14b4df3a9fca49d67cc2f67a9 Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Wed, 1 Nov 2023 12:23:24 +0900
Subject: [PATCH] [mlir][transform] Add transform op for loop-invariant subset
 hoisting

---
 .../Linalg/TransformOps/LinalgTransformOps.td | 50 ------------
 .../mlir/Dialect/Transform/CMakeLists.txt     |  1 +
 .../Transform/LoopExtension/CMakeLists.txt    |  6 ++
 .../Transform/LoopExtension/LoopExtension.h   | 16 ++++
 .../LoopExtension/LoopExtensionOps.h          | 23 ++++++
 .../LoopExtension/LoopExtensionOps.td         | 76 +++++++++++++++++++
 .../Transform/PDLExtension/PDLExtensionOps.td |  2 +-
 mlir/include/mlir/InitAllExtensions.h         |  2 +
 .../Transforms/LoopInvariantCodeMotionUtils.h |  4 +-
 .../TransformOps/LinalgTransformOps.cpp       | 29 -------
 mlir/lib/Dialect/Transform/CMakeLists.txt     |  1 +
 .../Transform/LoopExtension/CMakeLists.txt    | 13 ++++
 .../Transform/LoopExtension/LoopExtension.cpp | 34 +++++++++
 .../LoopExtension/LoopExtensionOps.cpp        | 36 +++++++++
 .../Transforms/LoopInvariantCodeMotion.cpp    |  4 +-
 .../Utils/LoopInvariantCodeMotionUtils.cpp    | 17 +++--
 .../Transform/test-loop-transforms.mlir       | 50 ++++++++++++
 .../test/lib/Dialect/Transform/CMakeLists.txt |  1 +
 .../llvm-project-overlay/mlir/BUILD.bazel     | 48 ++++++++++++
 19 files changed, 324 insertions(+), 89 deletions(-)
 create mode 100644 mlir/include/mlir/Dialect/Transform/LoopExtension/CMakeLists.txt
 create mode 100644 mlir/include/mlir/Dialect/Transform/LoopExtension/LoopExtension.h
 create mode 100644 mlir/include/mlir/Dialect/Transform/LoopExtension/LoopExtensionOps.h
 create mode 100644 mlir/include/mlir/Dialect/Transform/LoopExtension/LoopExtensionOps.td
 create mode 100644 mlir/lib/Dialect/Transform/LoopExtension/CMakeLists.txt
 create mode 100644 mlir/lib/Dialect/Transform/LoopExtension/LoopExtension.cpp
 create mode 100644 mlir/lib/Dialect/Transform/LoopExtension/LoopExtensionOps.cpp
 create mode 100644 mlir/test/Dialect/Transform/test-loop-transforms.mlir

diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 9e3f79e64bb1d79..e60c3f364604527 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -2247,56 +2247,6 @@ def ConvertConv2DToImg2ColOp : Op<Transform_Dialect,
   }];
 }
 
-//===----------------------------------------------------------------------===//
-// HoistRedundantTensorSubsetsOp
-//===----------------------------------------------------------------------===//
-
-def HoistRedundantTensorSubsetsOp :
-  Op<Transform_Dialect, "structured.hoist_redundant_tensor_subsets",
-    [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
-     TransformEachOpTrait,
-     TransformOpInterface,
-     ReportTrackingListenerFailuresOpTrait]> {
-  let description = [{
-    Hoists supported tensor subset extract/insert operation pairs out of
-    immediately enclosing loop iteratively, if the following conditions
-    are true:
-       1. The 2 ops access the same tensor subset.
-       2. All operands are invariant under the enclosing loop.
-
-    The supported subset extract/insert operation pairs currently comprise:
-       - tensor.extract_slice / tensor.insert_slice
-       - vector.transfer_read / vector.transfer_write on tensors
-
-    Only scf.for loops are currently supported.
-
-    When applied to:
-       1. an scf.for loop, hoist out of this loop only.
-       2. a non-loop op, apply hoisting to all the contained loop ops.
-
-    #### Return modes:
-
-    The operation always succeeds and returns nothing.
-  }];
-
-  let arguments = (ins TransformHandleTypeInterface:$target);
-  let results = (outs);
-
-  let assemblyFormat = [{
-    $target
-    attr-dict
-    `:` functional-type(operands, results)
-  }];
-
-  let extraClassDeclaration = [{
-    ::mlir::DiagnosedSilenceableFailure applyToOne(
-        ::mlir::transform::TransformRewriter &rewriter,
-        ::mlir::Operation *target,
-        ::mlir::transform::ApplyToEachResultList &results,
-        ::mlir::transform::TransformState &state);
-  }];
-}
-
 //===----------------------------------------------------------------------===//
 // InsertSliceToCopyOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Transform/CMakeLists.txt b/mlir/include/mlir/Dialect/Transform/CMakeLists.txt
index d9fbaee802398fb..d6c5c975c2e93c1 100644
--- a/mlir/include/mlir/Dialect/Transform/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Transform/CMakeLists.txt
@@ -1,3 +1,4 @@
 add_subdirectory(IR)
+add_subdirectory(LoopExtension)
 add_subdirectory(PDLExtension)
 add_subdirectory(Transforms)
diff --git a/mlir/include/mlir/Dialect/Transform/LoopExtension/CMakeLists.txt b/mlir/include/mlir/Dialect/Transform/LoopExtension/CMakeLists.txt
new file mode 100644
index 000000000000000..8f5e510ad39a39e
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Transform/LoopExtension/CMakeLists.txt
@@ -0,0 +1,6 @@
+set(LLVM_TARGET_DEFINITIONS LoopExtensionOps.td)
+mlir_tablegen(LoopExtensionOps.h.inc -gen-op-decls)
+mlir_tablegen(LoopExtensionOps.cpp.inc -gen-op-defs)
+add_public_tablegen_target(MLIRTransformDialectLoopExtensionOpsIncGen)
+
+add_mlir_doc(LoopExtensionOps LoopExtensionOps Dialects/ -gen-op-doc)
diff --git a/mlir/include/mlir/Dialect/Transform/LoopExtension/LoopExtension.h b/mlir/include/mlir/Dialect/Transform/LoopExtension/LoopExtension.h
new file mode 100644
index 000000000000000..7a8ed2075ef12e3
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Transform/LoopExtension/LoopExtension.h
@@ -0,0 +1,16 @@
+//===- LoopExtension.h - Loop extension for Transform dialect ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+namespace mlir {
+class DialectRegistry;
+
+namespace transform {
+/// Registers the loop extension of the Transform dialect in the given registry.
+void registerLoopExtension(DialectRegistry &dialectRegistry);
+} // namespace transform
+} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/Transform/LoopExtension/LoopExtensionOps.h b/mlir/include/mlir/Dialect/Transform/LoopExtension/LoopExtensionOps.h
new file mode 100644
index 000000000000000..68cc0699d081a0d
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Transform/LoopExtension/LoopExtensionOps.h
@@ -0,0 +1,23 @@
+//===- LoopExtensionOps.h - Loop ext. for Transform dialect -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_TRANSFORM_LOOPEXTENSION_LOOPEXTENSIONOPS_H
+#define MLIR_DIALECT_TRANSFORM_LOOPEXTENSION_LOOPEXTENSIONOPS_H
+
+#include "mlir/Bytecode/BytecodeOpInterface.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/Interfaces/LoopLikeInterface.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/Transform/LoopExtension/LoopExtensionOps.h.inc"
+
+#endif // MLIR_DIALECT_TRANSFORM_LOOPEXTENSION_LOOPEXTENSIONOPS_H
diff --git a/mlir/include/mlir/Dialect/Transform/LoopExtension/LoopExtensionOps.td b/mlir/include/mlir/Dialect/Transform/LoopExtension/LoopExtensionOps.td
new file mode 100644
index 000000000000000..78a8c6ad489a9af
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Transform/LoopExtension/LoopExtensionOps.td
@@ -0,0 +1,76 @@
+//===- LoopExtensionOps.td - Transform dialect operations --*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_TRANSFORM_LOOPEXTENSION_LOOPEXTENSIONOPS
+#define MLIR_DIALECT_TRANSFORM_LOOPEXTENSION_LOOPEXTENSIONOPS
+
+include "mlir/Dialect/Transform/IR/TransformDialect.td"
+include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+
+def HoistLoopInvariantSubsetsOp
+    : TransformDialectOp<"loop.hoist_loop_invariant_subsets",
+        [TransformOpInterface, TransformEachOpTrait,
+         DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+         ReportTrackingListenerFailuresOpTrait]> {
+  let summary = "Hoist loop invariant subset ops";
+  let description = [{
+    This transform hoists loop-invariant subset ops out of the targeted
+    loop-like op. It looks for matching subset extraction/insertion op pairs and
+    hoists them. The loop body operates on a newly introduced region iter_arg.
+
+    Subset ops are hoisted only from the targeted op. If subset ops should be
+    hoisted from an entire loop nest, this transformation must be applied to
+    each loop-like op of the loop nest, starting with the innermost loop and
+    ending with the outermost loop.
+
+    Example:
+    ```
+    %r = scf.for ... iter_args(%t = %a) -> (tensor<?xf32>) {
+      %0 = tensor.extract_slice %t[0][5][1] : tensor<?xf32> to tensor<5xf32>
+      %1 = "test.foo"(%0) : (tensor<5xf32>) -> (tensor<5xf32>)
+      %2 = tensor.insert_slice %1 into %t[0][5][1]
+          : tensor<5xf32> into tensor<?xf32>
+      scf.yield %2 : tensor<?xf32>
+    }
+    ```
+    Is transformed to:
+    ```
+    %0 = tensor.extract_slice %a[0][5][1] : tensor<?xf32> to tensor<5xf32>
+    %new_loop:2 = scf.for ... iter_args(%t = %a, %h = %0) -> (tensor<?xf32>, tensor<5xf32>) {
+      %1 = "test.foo"(%h) : (tensor<5xf32>) -> (tensor<5xf32>)
+      scf.yield %t, %1 : tensor<?xf32>, tensor<5xf32>
+    }
+    %r = tensor.insert_slice %new_loop#1 into %new_loop#0[0][5][1]
+        : tensor<5xf32> into tensor<?xf32>
+    ```
+
+    Subset ops are hoisted only if there are no conflicting subset ops. E.g.,
+    if there were a second overlapping extraction in the above example, no ops
+    could be hoisted safely.
+
+    This transform reads the target handle and modifies the payload. This
+    transform does not invalidate any handles, but loop-like ops are replaced
+    with new loop-like ops when a subset op is hoisted. The transform rewriter
+    updates all handles accordingly.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target);
+  let results = (outs);
+  let assemblyFormat = "$target attr-dict `:` type($target)";
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+      ::mlir::transform::TransformRewriter &rewriter,
+      ::mlir::LoopLikeOpInterface loopLikeOp,
+      ::mlir::transform::ApplyToEachResultList &results,
+      ::mlir::transform::TransformState &state);
+  }];
+}
+
+#endif // MLIR_DIALECT_TRANSFORM_LOOPEXTENSION_LOOPEXTENSIONOPS
diff --git a/mlir/include/mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.td b/mlir/include/mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.td
index 16107b3d0869f1a..206a799690aa59c 100644
--- a/mlir/include/mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.td
+++ b/mlir/include/mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.td
@@ -1,4 +1,4 @@
-//===- TransformOps.td - Transform dialect operations ------*- tablegen -*-===//
+//===- PDLExtensionOps.td - Transform dialect operations ---*- tablegen -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
diff --git a/mlir/include/mlir/InitAllExtensions.h b/mlir/include/mlir/InitAllExtensions.h
index 8e2ad3a2e34f60e..c04ce850fb96f41 100644
--- a/mlir/include/mlir/InitAllExtensions.h
+++ b/mlir/include/mlir/InitAllExtensions.h
@@ -34,6 +34,7 @@
 #include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h"
 #include "mlir/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.h"
 #include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h"
+#include "mlir/Dialect/Transform/LoopExtension/LoopExtension.h"
 #include "mlir/Dialect/Transform/PDLExtension/PDLExtension.h"
 #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h"
 #include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h"
@@ -74,6 +75,7 @@ inline void registerAllExtensions(DialectRegistry &registry) {
   scf::registerTransformDialectExtension(registry);
   sparse_tensor::registerTransformDialectExtension(registry);
   tensor::registerTransformDialectExtension(registry);
+  transform::registerLoopExtension(registry);
   transform::registerPDLExtension(registry);
   vector::registerTransformDialectExtension(registry);
 
diff --git a/mlir/include/mlir/Transforms/LoopInvariantCodeMotionUtils.h b/mlir/include/mlir/Transforms/LoopInvariantCodeMotionUtils.h
index 579054070f729b0..3ceef44d799e893 100644
--- a/mlir/include/mlir/Transforms/LoopInvariantCodeMotionUtils.h
+++ b/mlir/include/mlir/Transforms/LoopInvariantCodeMotionUtils.h
@@ -18,6 +18,7 @@ namespace mlir {
 class LoopLikeOpInterface;
 class Operation;
 class Region;
+class RewriterBase;
 class Value;
 
 /// Given a list of regions, perform loop-invariant code motion. An operation is
@@ -108,7 +109,8 @@ size_t moveLoopInvariantCode(LoopLikeOpInterface loopLike);
 /// %r = tensor.insert_slice %new_loop#1 into %new_loop#0
 ///     : tensor<5xf32> into tensor<?xf32>
 /// ```
-LoopLikeOpInterface hoistLoopInvariantSubsets(LoopLikeOpInterface loopLike);
+LoopLikeOpInterface hoistLoopInvariantSubsets(RewriterBase &rewriter,
+                                              LoopLikeOpInterface loopLike);
 
 } // end namespace mlir
 
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 87be3bb85b6e788..fd8a1657db3ae5d 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3163,35 +3163,6 @@ DiagnosedSilenceableFailure transform::ConvertConv2DToImg2ColOp::applyToOne(
   return DiagnosedSilenceableFailure::success();
 }
 
-//===----------------------------------------------------------------------===//
-// HoistRedundantTensorSubsetsOp
-//===----------------------------------------------------------------------===//
-
-DiagnosedSilenceableFailure
-transform::HoistRedundantTensorSubsetsOp::applyToOne(
-    transform::TransformRewriter &rewriter, Operation *target,
-    transform::ApplyToEachResultList &results,
-    transform::TransformState &state) {
-  auto forOp = dyn_cast<scf::ForOp>(target);
-  if (forOp) {
-    linalg::hoistRedundantSubsetExtractInsert(rewriter, forOp);
-    return DiagnosedSilenceableFailure::success();
-  }
-
-  // TODO: walking in some reverse / inside-out order would be more efficient
-  // and would capture more cases.
-  target->walk([&](scf::ForOp forOp) {
-    hoistRedundantSubsetExtractInsert(rewriter, forOp);
-  });
-  return DiagnosedSilenceableFailure::success();
-}
-
-void transform::HoistRedundantTensorSubsetsOp::getEffects(
-    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
-  transform::onlyReadsHandle(getTarget(), effects);
-  transform::modifiesPayload(effects);
-}
-
 //===----------------------------------------------------------------------===//
 // InsertSliceToCopyOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Transform/CMakeLists.txt b/mlir/lib/Dialect/Transform/CMakeLists.txt
index 9e144eba25710dd..6898d81df7ca63f 100644
--- a/mlir/lib/Dialect/Transform/CMakeLists.txt
+++ b/mlir/lib/Dialect/Transform/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_subdirectory(IR)
+add_subdirectory(LoopExtension)
 add_subdirectory(PDLExtension)
 add_subdirectory(Transforms)
 add_subdirectory(Utils)
diff --git a/mlir/lib/Dialect/Transform/LoopExtension/CMakeLists.txt b/mlir/lib/Dialect/Transform/LoopExtension/CMakeLists.txt
new file mode 100644
index 000000000000000..9e1abdd1ca17b56
--- /dev/null
+++ b/mlir/lib/Dialect/Transform/LoopExtension/CMakeLists.txt
@@ -0,0 +1,13 @@
+add_mlir_dialect_library(MLIRTransformLoopExtension
+  LoopExtension.cpp
+  LoopExtensionOps.cpp
+
+  DEPENDS
+  MLIRTransformDialectLoopExtensionOpsIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRLoopLikeInterface
+  MLIRTransformDialect
+  MLIRTransforms
+)
diff --git a/mlir/lib/Dialect/Transform/LoopExtension/LoopExtension.cpp b/mlir/lib/Dialect/Transform/LoopExtension/LoopExtension.cpp
new file mode 100644
index 000000000000000..b33288fd7b991fd
--- /dev/null
+++ b/mlir/lib/Dialect/Transform/LoopExtension/LoopExtension.cpp
@@ -0,0 +1,34 @@
+//===- LoopExtension.cpp - Loop extension for the Transform dialect -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Transform/LoopExtension/LoopExtension.h"
+
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/LoopExtension/LoopExtensionOps.h"
+#include "mlir/IR/DialectRegistry.h"
+
+using namespace mlir;
+
+namespace {
+/// Loop extension of the Transform dialect. This provides "core" transform
+/// operations for loop-like ops.
+class LoopExtension
+    : public transform::TransformDialectExtension<LoopExtension> {
+public:
+  void init() {
+    registerTransformOps<
+#define GET_OP_LIST
+#include "mlir/Dialect/Transform/LoopExtension/LoopExtensionOps.cpp.inc"
+        >();
+  }
+};
+} // namespace
+
+void mlir::transform::registerLoopExtension(DialectRegistry &dialectRegistry) {
+  dialectRegistry.addExtensions<LoopExtension>();
+}
diff --git a/mlir/lib/Dialect/Transform/LoopExtension/LoopExtensionOps.cpp b/mlir/lib/Dialect/Transform/LoopExtension/LoopExtensionOps.cpp
new file mode 100644
index 000000000000000..867a8da92fe404e
--- /dev/null
+++ b/mlir/lib/Dialect/Transform/LoopExtension/LoopExtensionOps.cpp
@@ -0,0 +1,36 @@
+//===- LoopExtensionOps.cpp - Loop extension for the Transform dialect ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Transform/LoopExtension/LoopExtensionOps.h"
+
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/LoopInvariantCodeMotionUtils.h"
+
+using namespace mlir;
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/Transform/LoopExtension/LoopExtensionOps.cpp.inc"
+
+//===----------------------------------------------------------------------===//
+// HoistLoopInvariantSubsetsOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure transform::HoistLoopInvariantSubsetsOp::applyToOne(
+    transform::TransformRewriter &rewriter, LoopLikeOpInterface loopLikeOp,
+    transform::ApplyToEachResultList &results,
+    transform::TransformState &state) {
+  (void)hoistLoopInvariantSubsets(rewriter, loopLikeOp);
+  return DiagnosedSilenceableFailure::success();
+}
+
+void transform::HoistLoopInvariantSubsetsOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  transform::onlyReadsHandle(getTarget(), effects);
+  transform::modifiesPayload(effects);
+}
diff --git a/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp b/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp
index e6d8af8f05832d3..02c3ea1ce9b650c 100644
--- a/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp
+++ b/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp
@@ -12,6 +12,7 @@
 
 #include "mlir/Transforms/Passes.h"
 
+#include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/LoopLikeInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Transforms/LoopInvariantCodeMotionUtils.h"
@@ -47,11 +48,12 @@ void LoopInvariantCodeMotion::runOnOperation() {
 }
 
 void LoopInvariantSubsetHoisting::runOnOperation() {
+  IRRewriter rewriter(getOperation()->getContext());
   // Walk through all loops in a function in innermost-loop-first order. This
   // way, we first hoist from the inner loop, and place the ops in the outer
   // loop, which in turn can be further hoisted from.
   getOperation()->walk([&](LoopLikeOpInterface loopLike) {
-    (void)hoistLoopInvariantSubsets(loopLike);
+    (void)hoistLoopInvariantSubsets(rewriter, loopLike);
   });
 }
 
diff --git a/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp b/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
index 53bdb7aafe41a0c..8f97fd3d9ddf84e 100644
--- a/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
@@ -311,12 +311,12 @@ MatchingSubsets::populateSubsetOpsAtIterArg(LoopLikeOpInterface loopLike,
 /// loop-like op and index into loop-invariant subset locations. Return the
 /// newly created loop op (that has extra iter_args) or the original loop op if
 /// nothing was hoisted.
-static LoopLikeOpInterface hoistSubsetAtIterArg(LoopLikeOpInterface loopLike,
+static LoopLikeOpInterface hoistSubsetAtIterArg(RewriterBase &rewriter,
+                                                LoopLikeOpInterface loopLike,
                                                 BlockArgument iterArg) {
   assert(iterArg.getOwner()->getParentOp() == loopLike && "invalid iter_arg");
   auto it = llvm::find(loopLike.getRegionIterArgs(), iterArg);
   int64_t iterArgIdx = std::distance(loopLike.getRegionIterArgs().begin(), it);
-  IRRewriter rewriter(loopLike.getContext());
   MatchingSubsets subsets;
   if (failed(subsets.populateSubsetOpsAtIterArg(loopLike, iterArg)))
     return loopLike;
@@ -367,11 +367,12 @@ static LoopLikeOpInterface hoistSubsetAtIterArg(LoopLikeOpInterface loopLike,
       OpResult newLoopResult = loopLike.getLoopResults()->back();
       extractionOp->moveBefore(loopLike);
       insertionOp->moveAfter(loopLike);
-      insertionOp.getUpdatedDestination().replaceAllUsesWith(
-          insertionOp.getDestinationOperand().get());
+      rewriter.replaceAllUsesWith(insertionOp.getUpdatedDestination(),
+                                  insertionOp.getDestinationOperand().get());
       extractionOp.getSourceOperand().set(
           loopLike.getTiedLoopInit(iterArg)->get());
-      loopResult.replaceAllUsesWith(insertionOp.getUpdatedDestination());
+      rewriter.replaceAllUsesWith(loopResult,
+                                  insertionOp.getUpdatedDestination());
       insertionOp.getSourceOperand().set(newLoopResult);
       insertionOp.getDestinationOperand().set(loopResult);
     }
@@ -381,13 +382,15 @@ static LoopLikeOpInterface hoistSubsetAtIterArg(LoopLikeOpInterface loopLike,
 }
 
 LoopLikeOpInterface
-mlir::hoistLoopInvariantSubsets(LoopLikeOpInterface loopLike) {
+mlir::hoistLoopInvariantSubsets(RewriterBase &rewriter,
+                                LoopLikeOpInterface loopLike) {
   // Note: As subset ops are getting hoisted, the number of region iter_args
   // increases. This can enable further hoisting opportunities on the new
   // iter_args.
   for (int64_t i = 0;
        i < static_cast<int64_t>(loopLike.getRegionIterArgs().size()); ++i) {
-    loopLike = hoistSubsetAtIterArg(loopLike, loopLike.getRegionIterArgs()[i]);
+    loopLike = hoistSubsetAtIterArg(rewriter, loopLike,
+                                    loopLike.getRegionIterArgs()[i]);
   }
   return loopLike;
 }
diff --git a/mlir/test/Dialect/Transform/test-loop-transforms.mlir b/mlir/test/Dialect/Transform/test-loop-transforms.mlir
new file mode 100644
index 000000000000000..c100a8ecf702589
--- /dev/null
+++ b/mlir/test/Dialect/Transform/test-loop-transforms.mlir
@@ -0,0 +1,50 @@
+// RUN: mlir-opt %s --transform-interpreter --split-input-file \
+// RUN:     --verify-diagnostics | FileCheck %s
+
+// UNSUPPORTED: target=aarch64-pc-windows-msvc
+
+// CHECK-LABEL: func @test_loop_invariant_subset_hoisting(
+//  CHECK-SAME:     %[[arg:.*]]: tensor<?xf32>
+func.func @test_loop_invariant_subset_hoisting(%arg: tensor<?xf32>) -> tensor<?xf32> {
+  %lb = "test.foo"() : () -> (index)
+  %ub = "test.foo"() : () -> (index)
+  %step = "test.foo"() : () -> (index)
+  // CHECK: %[[extract:.*]] = tensor.extract_slice %[[arg]]
+  // CHECK: %[[for:.*]]:2 = scf.for {{.*}} iter_args(%[[t:.*]] = %[[arg]], %[[hoisted:.*]] = %[[extract]])
+  // expected-remark @below{{new loop op}}
+  %0 = scf.for %iv = %lb to %ub step %step iter_args(%t = %arg) -> (tensor<?xf32>) {
+    %1 = tensor.extract_slice %t[0][5][1] : tensor<?xf32> to tensor<5xf32>
+    // CHECK: %[[foo:.*]] = "test.foo"(%[[hoisted]])
+    %2 = "test.foo"(%1) : (tensor<5xf32>) -> (tensor<5xf32>)
+    // Insert into the same subset ([0][5][1]) that was extracted above, so that
+    // the extract_slice/insert_slice pair matches and can be hoisted.
+    %3 = tensor.insert_slice %2 into %t[0][5][1] : tensor<5xf32> into tensor<?xf32>
+    // CHECK: scf.yield %[[t]], %[[foo]]
+    scf.yield %3 : tensor<?xf32>
+  }
+  // CHECK: %[[insert:.*]] = tensor.insert_slice %[[for]]#1 into %[[for]]#0
+  // CHECK: return %[[insert]]
+  return %0 : tensor<?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["scf.for"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.structured.match ops{["tensor.extract_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %2 = transform.structured.match ops{["tensor.insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+
+    transform.loop.hoist_loop_invariant_subsets %0 : !transform.any_op
+    // Make sure that the handles are still valid (and were updated in case of
+    // the loop).
+
+    // expected-remark @below{{1}}
+    transform.test_print_number_of_associated_payload_ir_ops %0 : !transform.any_op
+    transform.test_print_remark_at_operand %0, "new loop op" : !transform.any_op
+    // expected-remark @below{{1}}
+    transform.test_print_number_of_associated_payload_ir_ops %1 : !transform.any_op
+    // expected-remark @below{{1}}
+    transform.test_print_number_of_associated_payload_ir_ops %2 : !transform.any_op
+
+    transform.yield
+  }
+}
diff --git a/mlir/test/lib/Dialect/Transform/CMakeLists.txt b/mlir/test/lib/Dialect/Transform/CMakeLists.txt
index c7e83d3a7128bcb..436f892a27232b2 100644
--- a/mlir/test/lib/Dialect/Transform/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Transform/CMakeLists.txt
@@ -21,5 +21,6 @@ add_mlir_library(MLIRTestTransformDialect
   MLIRPDLDialect
   MLIRTransformDialect
   MLIRTransformDialectTransforms
+  MLIRTransformLoopExtension
   MLIRTransformPDLExtension
 )
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 2cadd4e0d2911a6..99aa78bb3d3d33b 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -4416,6 +4416,7 @@ cc_library(
         ":SCFTransformOps",
         ":SparseTensorTransformOps",
         ":TensorTransformOps",
+        ":TransformLoopExtension",
         ":TransformPDLExtension",
         ":UBToLLVM",
         ":VectorTransformOps",
@@ -8677,6 +8678,7 @@ cc_library(
         ":TosaToLinalg",
         ":TransformDialect",
         ":TransformDialectTransforms",
+        ":TransformLoopExtension",
         ":TransformPDLExtension",
         ":Transforms",
         ":TransformsPassIncGen",
@@ -11401,6 +11403,52 @@ cc_library(
     ],
 )
 
+td_library(
+    name = "TransformLoopExtensionTdFiles",
+    srcs = glob(["include/mlir/Dialect/Transform/LoopExtension/*.td"]),
+    deps = [
+        ":TransformDialectTdFiles",
+    ],
+)
+
+gentbl_cc_library(
+    name = "TransformLoopExtensionOpsIncGen",
+    tbl_outs = [
+        (
+            [
+                "-gen-op-decls",
+            ],
+            "include/mlir/Dialect/Transform/LoopExtension/LoopExtensionOps.h.inc",
+        ),
+        (
+            [
+                "-gen-op-defs",
+            ],
+            "include/mlir/Dialect/Transform/LoopExtension/LoopExtensionOps.cpp.inc",
+        ),
+    ],
+    tblgen = ":mlir-tblgen",
+    td_file = "include/mlir/Dialect/Transform/LoopExtension/LoopExtensionOps.td",
+    deps = [":TransformLoopExtensionTdFiles"],
+)
+
+cc_library(
+    name = "TransformLoopExtension",
+    srcs = glob(["lib/Dialect/Transform/LoopExtension/*.cpp"]),
+    hdrs = glob(["include/mlir/Dialect/Transform/LoopExtension/*.h"]),
+    deps = [
+        ":IR",
+        ":LoopLikeInterface",
+        ":Rewrite",
+        ":SideEffectInterfaces",
+        ":Support",
+        ":TransformDialect",
+        ":TransformLoopExtensionOpsIncGen",
+        ":Transforms",
+        "//llvm:Support",
+    ],
+)
+
 td_library(
     name = "TransformDialectTransformsTdFiles",
     srcs = glob(["include/mlir/Dialect/Transform/Transforms/*.td"]),



More information about the Mlir-commits mailing list