[Mlir-commits] [mlir] [mlir][NFC] Move LLVM::ModuleTranslation::SaveStack to a shared header (PR #144897)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Jun 19 07:05:11 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-openmp

@llvm/pr-subscribers-flang-openmp

Author: Tom Eccles (tblah)

<details>
<summary>Changes</summary>

This is so that we can re-use the same code in Flang.

---
Full diff: https://github.com/llvm/llvm-project/pull/144897.diff


6 Files Affected:

- (added) mlir/include/mlir/Support/StateStack.h (+116) 
- (modified) mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h (+6-66) 
- (modified) mlir/lib/Support/CMakeLists.txt (+1) 
- (added) mlir/lib/Support/StateStack.cpp (+9) 
- (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+2-2) 
- (modified) mlir/lib/Target/LLVMIR/ModuleTranslation.cpp (-2) 


``````````diff
diff --git a/mlir/include/mlir/Support/StateStack.h b/mlir/include/mlir/Support/StateStack.h
new file mode 100644
index 0000000000000..aca2375028246
--- /dev/null
+++ b/mlir/include/mlir/Support/StateStack.h
@@ -0,0 +1,116 @@
+//===- StateStack.h - Utility for storing a stack of state ------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines utilities for storing a stack of generic context.
+// The context can be arbitrary data, possibly including file-scoped types.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_SUPPORT_STACKFRAME_H
+#define MLIR_SUPPORT_STACKFRAME_H
+
+#include "mlir/IR/Visitors.h"
+#include "mlir/Support/TypeID.h"
+#include <memory>
+
+namespace mlir {
+
+/// Non-template polymorphic base of all StateStack frames. It stores the
+/// TypeID of the concrete frame type, which isa/dyn_cast use to identify it.
+class StateStackFrame {
+public:
+  virtual ~StateStackFrame() = default;
+  TypeID getTypeID() const { return typeID; }
+protected:
+  explicit StateStackFrame(TypeID typeID) : typeID(typeID) {}
+private:
+  const TypeID typeID;
+  // NOTE(review): anchors only pin the vtable when defined out-of-line.
+  virtual void anchor() {}
+};
+
+/// CRTP base for concrete StateStack frames; it tags each instance with the
+/// TypeID of `Derived`. Frames hold state shared across recursive IR
+/// conversions: e.g. a user can push a frame onto a StateStack before entering
+/// an operation's region and inspect it while converting the operations nested
+/// inside that region. Define a frame as
+/// `class MyFrame : public StateStackFrameBase<MyFrame>` with fields for the
+/// needed state; isa/dyn_cast are then available on instances of it.
+template <typename Derived>
+class StateStackFrameBase : public StateStackFrame {
+public:
+  explicit StateStackFrameBase()
+      : StateStackFrame(TypeID::get<Derived>()) {}
+};
+
+/// A stack of type-tagged frames; see StateStackFrameBase for defining one.
+class StateStack {
+public:
+  /// Creates a stack frame of type `T` on StateStack. `T` must be derived
+  /// from `StateStackFrameBase<T>` and constructible from the provided
+  /// arguments. Doing this before entering the region of the op being
+  /// translated makes the frame available when translating ops within that
+  /// region.
+  template <typename T, typename... Args>
+  void stackPush(Args &&...args) {
+    static_assert(std::is_base_of<StateStackFrame, T>::value,
+                  "can only push instances of StateStackFrame on StateStack");
+    stack.push_back(std::make_unique<T>(std::forward<Args>(args)...));
+  }
+  /// Pops the most recently pushed frame; the stack must not be empty.
+  void stackPop() { stack.pop_back(); }
+
+  /// Calls `callback` for every frame of type `T`, starting from the top of
+  /// the stack. Returns skip() if `callback` is null, the interrupting result
+  /// if `callback` interrupts the walk, and advance() otherwise.
+  template <typename T>
+  WalkResult stackWalk(llvm::function_ref<WalkResult(T &)> callback) {
+    static_assert(std::is_base_of<StateStackFrame, T>::value,
+                  "expected T derived from StateStackFrame");
+    if (!callback)
+      return WalkResult::skip();
+    for (std::unique_ptr<StateStackFrame> &frame : llvm::reverse(stack)) {
+      if (T *ptr = dyn_cast_or_null<T>(frame.get())) {
+        WalkResult result = callback(*ptr);
+        if (result.wasInterrupted())
+          return result;
+      }
+    }
+    return WalkResult::advance();
+  }
+private:
+  SmallVector<std::unique_ptr<StateStackFrame>> stack;
+};
+
+/// RAII object calling stackPush/stackPop on construction/destruction.
+/// HOST_CLASS is a StateStack or some other class which forwards to one.
+template <typename T, typename HOST_CLASS>
+struct SaveStateStack {
+  template <typename... Args>
+  explicit SaveStateStack(HOST_CLASS &host, Args &&...args) : host(host) {
+    host.template stackPush<T>(std::forward<Args>(args)...);
+  }
+  SaveStateStack(const SaveStateStack &) = delete; // Copies would pop twice.
+  SaveStateStack &operator=(const SaveStateStack &) = delete;
+  ~SaveStateStack() { host.stackPop(); }
+private:
+  HOST_CLASS &host;
+};
+
+} // namespace mlir
+
+namespace llvm {
+/// Lets isa/dyn_cast test a StateStackFrame by comparing its stored TypeID.
+template <typename T> struct isa_impl<T, ::mlir::StateStackFrame> {
+  static bool doit(const ::mlir::StateStackFrame &frame) {
+    return ::mlir::TypeID::get<T>() == frame.getTypeID();
+  }
+};
+} // namespace llvm
+
+#endif // MLIR_SUPPORT_STACKFRAME_H
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
index 0f136c5c46d79..79e8bb6add0da 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
@@ -18,6 +18,7 @@
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/SymbolTable.h"
 #include "mlir/IR/Value.h"
+#include "mlir/Support/StateStack.h"
 #include "mlir/Target/LLVMIR/Export.h"
 #include "mlir/Target/LLVMIR/LLVMTranslationInterface.h"
 #include "mlir/Target/LLVMIR/TypeToLLVM.h"
@@ -271,33 +272,6 @@ class ModuleTranslation {
   /// it if it does not exist.
   llvm::NamedMDNode *getOrInsertNamedModuleMetadata(StringRef name);
 
-  /// Common CRTP base class for ModuleTranslation stack frames.
-  class StackFrame {
-  public:
-    virtual ~StackFrame() = default;
-    TypeID getTypeID() const { return typeID; }
-
-  protected:
-    explicit StackFrame(TypeID typeID) : typeID(typeID) {}
-
-  private:
-    const TypeID typeID;
-    virtual void anchor();
-  };
-
-  /// Concrete CRTP base class for ModuleTranslation stack frames. When
-  /// translating operations with regions, users of ModuleTranslation can store
-  /// state on ModuleTranslation stack before entering the region and inspect
-  /// it when converting operations nested within that region. Users are
-  /// expected to derive this class and put any relevant information into fields
-  /// of the derived class. The usual isa/dyn_cast functionality is available
-  /// for instances of derived classes.
-  template <typename Derived>
-  class StackFrameBase : public StackFrame {
-  public:
-    explicit StackFrameBase() : StackFrame(TypeID::get<Derived>()) {}
-  };
-
   /// Creates a stack frame of type `T` on ModuleTranslation stack. `T` must
   /// be derived from `StackFrameBase<T>` and constructible from the provided
   /// arguments. Doing this before entering the region of the op being
@@ -305,46 +279,22 @@ class ModuleTranslation {
   /// region.
   template <typename T, typename... Args>
   void stackPush(Args &&...args) {
-    static_assert(
-        std::is_base_of<StackFrame, T>::value,
-        "can only push instances of StackFrame on ModuleTranslation stack");
-    stack.push_back(std::make_unique<T>(std::forward<Args>(args)...));
+    stack.stackPush<T>(std::forward<Args>(args)...);
   }
 
   /// Pops the last element from the ModuleTranslation stack.
-  void stackPop() { stack.pop_back(); }
+  void stackPop() { stack.stackPop(); }
 
   /// Calls `callback` for every ModuleTranslation stack frame of type `T`
   /// starting from the top of the stack.
   template <typename T>
   WalkResult stackWalk(llvm::function_ref<WalkResult(T &)> callback) {
-    static_assert(std::is_base_of<StackFrame, T>::value,
-                  "expected T derived from StackFrame");
-    if (!callback)
-      return WalkResult::skip();
-    for (std::unique_ptr<StackFrame> &frame : llvm::reverse(stack)) {
-      if (T *ptr = dyn_cast_or_null<T>(frame.get())) {
-        WalkResult result = callback(*ptr);
-        if (result.wasInterrupted())
-          return result;
-      }
-    }
-    return WalkResult::advance();
+    return stack.stackWalk(callback);
   }
 
   /// RAII object calling stackPush/stackPop on construction/destruction.
   template <typename T>
-  struct SaveStack {
-    template <typename... Args>
-    explicit SaveStack(ModuleTranslation &m, Args &&...args)
-        : moduleTranslation(m) {
-      moduleTranslation.stackPush<T>(std::forward<Args>(args)...);
-    }
-    ~SaveStack() { moduleTranslation.stackPop(); }
-
-  private:
-    ModuleTranslation &moduleTranslation;
-  };
+  using SaveStack = SaveStateStack<T, ModuleTranslation>;
 
   SymbolTableCollection &symbolTable() { return symbolTableCollection; }
 
@@ -468,7 +418,7 @@ class ModuleTranslation {
 
   /// Stack of user-specified state elements, useful when translating operations
   /// with regions.
-  SmallVector<std::unique_ptr<StackFrame>> stack;
+  StateStack stack;
 
   /// A cache for the symbol tables constructed during symbols lookup.
   SymbolTableCollection symbolTableCollection;
@@ -510,14 +460,4 @@ llvm::CallInst *createIntrinsicCall(
 } // namespace LLVM
 } // namespace mlir
 
-namespace llvm {
-template <typename T>
-struct isa_impl<T, ::mlir::LLVM::ModuleTranslation::StackFrame> {
-  static inline bool
-  doit(const ::mlir::LLVM::ModuleTranslation::StackFrame &frame) {
-    return frame.getTypeID() == ::mlir::TypeID::get<T>();
-  }
-};
-} // namespace llvm
-
 #endif // MLIR_TARGET_LLVMIR_MODULETRANSLATION_H
diff --git a/mlir/lib/Support/CMakeLists.txt b/mlir/lib/Support/CMakeLists.txt
index 488decd52ae64..02b6c694a28fd 100644
--- a/mlir/lib/Support/CMakeLists.txt
+++ b/mlir/lib/Support/CMakeLists.txt
@@ -11,6 +11,7 @@ add_mlir_library(MLIRSupport
   FileUtilities.cpp
   InterfaceSupport.cpp
   RawOstreamExtras.cpp
+  StateStack.cpp
   StorageUniquer.cpp
   Timing.cpp
   ToolUtilities.cpp
diff --git a/mlir/lib/Support/StateStack.cpp b/mlir/lib/Support/StateStack.cpp
new file mode 100644
index 0000000000000..ce1417cf3eba7
--- /dev/null
+++ b/mlir/lib/Support/StateStack.cpp
@@ -0,0 +1,9 @@
+//===- StateStack.cpp - Utility for storing a stack of state --------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Support/StateStack.h"
\ No newline at end of file
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 90ce06a0345c0..e29e3d8f820dc 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -71,7 +71,7 @@ convertToScheduleKind(std::optional<omp::ClauseScheduleKind> schedKind) {
 /// ModuleTranslation stack frame for OpenMP operations. This keeps track of the
 /// insertion points for allocas.
 class OpenMPAllocaStackFrame
-    : public LLVM::ModuleTranslation::StackFrameBase<OpenMPAllocaStackFrame> {
+    : public StateStackFrameBase<OpenMPAllocaStackFrame> {
 public:
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpenMPAllocaStackFrame)
 
@@ -84,7 +84,7 @@ class OpenMPAllocaStackFrame
 /// collapsed canonical loop information corresponding to an \c omp.loop_nest
 /// operation.
 class OpenMPLoopInfoStackFrame
-    : public LLVM::ModuleTranslation::StackFrameBase<OpenMPLoopInfoStackFrame> {
+    : public StateStackFrameBase<OpenMPLoopInfoStackFrame> {
 public:
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpenMPLoopInfoStackFrame)
   llvm::CanonicalLoopInfo *loopInfo = nullptr;
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index 3eaa24eb5c95b..e8ce528bd185e 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -2225,8 +2225,6 @@ ModuleTranslation::getOrInsertNamedModuleMetadata(StringRef name) {
   return llvmModule->getOrInsertNamedMetadata(name);
 }
 
-void ModuleTranslation::StackFrame::anchor() {}
-
 static std::unique_ptr<llvm::Module>
 prepareLLVMModule(Operation *m, llvm::LLVMContext &llvmContext,
                   StringRef name) {

``````````

</details>


https://github.com/llvm/llvm-project/pull/144897


More information about the Mlir-commits mailing list