[Mlir-commits] [mlir] [mlir][NFC] Move LLVM::ModuleTranslation::SaveStack to a shared header (PR #144897)
Tom Eccles
llvmlistbot at llvm.org
Thu Jun 19 07:04:42 PDT 2025
https://github.com/tblah created https://github.com/llvm/llvm-project/pull/144897
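This moves the state-stack utility out of LLVM::ModuleTranslation into a shared header (mlir/Support/StateStack.h) so that the same code can be reused in Flang.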
>From 7ac1750c6ad9f780fba0154713d872dc8a665280 Mon Sep 17 00:00:00 2001
From: Tom Eccles <tom.eccles at arm.com>
Date: Wed, 18 Jun 2025 20:42:02 +0000
Subject: [PATCH] [mlir][NFC] Move LLVM::ModuleTranslation::SaveStack to a
shared header
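This moves the state-stack utility out of LLVM::ModuleTranslation into a shared header (mlir/Support/StateStack.h) so that the same code can be reused in Flang.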
---
mlir/include/mlir/Support/StateStack.h | 116 ++++++++++++++++++
.../mlir/Target/LLVMIR/ModuleTranslation.h | 72 +----------
mlir/lib/Support/CMakeLists.txt | 1 +
mlir/lib/Support/StateStack.cpp | 9 ++
.../OpenMP/OpenMPToLLVMIRTranslation.cpp | 4 +-
mlir/lib/Target/LLVMIR/ModuleTranslation.cpp | 2 -
6 files changed, 134 insertions(+), 70 deletions(-)
create mode 100644 mlir/include/mlir/Support/StateStack.h
create mode 100644 mlir/lib/Support/StateStack.cpp
diff --git a/mlir/include/mlir/Support/StateStack.h b/mlir/include/mlir/Support/StateStack.h
new file mode 100644
index 0000000000000..aca2375028246
--- /dev/null
+++ b/mlir/include/mlir/Support/StateStack.h
@@ -0,0 +1,116 @@
+//===- StateStack.h - Utility for storing a stack of state ------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines utilities for storing a stack of generic context.
+// The context can be arbitrary data, possibly including file-scoped types.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_SUPPORT_STATESTACK_H
+#define MLIR_SUPPORT_STATESTACK_H
+
+#include "mlir/IR/Visitors.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/TypeID.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/STLFunctionalExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include <cassert>
+#include <memory>
+#include <type_traits>
+#include <utility>
+
+namespace mlir {
+
+/// Common non-template base class for StateStack frames. It records the TypeID
+/// of the concrete frame type so that isa/dyn_cast work on frames stored in a
+/// StateStack (see the llvm::isa_impl specialization at the end of this file).
+/// Frame types should derive from StateStackFrameBase<Derived> rather than
+/// from this class directly.
+class StateStackFrame {
+public:
+  virtual ~StateStackFrame() = default;
+
+  /// Returns the TypeID this frame was constructed with, i.e.
+  /// `TypeID::get<Derived>()` for frames deriving from
+  /// `StateStackFrameBase<Derived>`.
+  TypeID getTypeID() const { return typeID; }
+
+protected:
+  explicit StateStackFrame(TypeID typeID) : typeID(typeID) {}
+
+private:
+  const TypeID typeID;
+
+  // NOTE(review): a vtable anchor only pins the vtable/RTTI to a single
+  // translation unit when it is defined out of line. Defined inline like this
+  // it has no effect; consider declaring it here and defining
+  // `void StateStackFrame::anchor() {}` in lib/Support/StateStack.cpp.
+  virtual void anchor() {}
+};
+
+/// Concrete CRTP base class for StateStack frames. This is used for keeping a
+/// stack of common state useful for recursive IR conversions. For example, when
+/// translating operations with regions, users of StateStack can store state on
+/// StateStack before entering the region and inspect it when converting
+/// operations nested within that region. Users are expected to derive from this
+/// class and put any relevant information into fields of the derived class. The
+/// usual isa/dyn_cast functionality is available for instances of derived
+/// classes.
+template <typename Derived>
+class StateStackFrameBase : public StateStackFrame {
+public:
+  explicit StateStackFrameBase() : StateStackFrame(TypeID::get<Derived>()) {}
+};
+
+/// Owns a stack of StateStackFrame objects. Frames are added with stackPush,
+/// removed with stackPop, and inspected from the top of the stack downwards
+/// with stackWalk.
+class StateStack {
+public:
+  /// Creates a stack frame of type `T` on StateStack. `T` must be derived from
+  /// `StateStackFrameBase<T>` and constructible from the provided arguments.
+  /// Doing this before entering the region of the op being translated makes
+  /// the frame available when translating ops within that region.
+  template <typename T, typename... Args>
+  void stackPush(Args &&...args) {
+    static_assert(std::is_base_of<StateStackFrame, T>::value,
+                  "can only push instances of StateStackFrame on StateStack");
+    stack.push_back(std::make_unique<T>(std::forward<Args>(args)...));
+  }
+
+  /// Pops the last element from the StateStack. The stack must not be empty.
+  void stackPop() {
+    assert(!stack.empty() && "cannot pop from an empty StateStack");
+    stack.pop_back();
+  }
+
+  /// Calls `callback` for every StateStack frame of type `T`, starting from
+  /// the top of the stack. Returns the callback's result as soon as it
+  /// interrupts the walk, WalkResult::skip() if `callback` is null, and
+  /// WalkResult::advance() otherwise.
+  template <typename T>
+  WalkResult stackWalk(llvm::function_ref<WalkResult(T &)> callback) {
+    static_assert(std::is_base_of<StateStackFrame, T>::value,
+                  "expected T derived from StateStackFrame");
+    if (!callback)
+      return WalkResult::skip();
+    for (std::unique_ptr<StateStackFrame> &frame : llvm::reverse(stack)) {
+      if (T *ptr = dyn_cast_or_null<T>(frame.get())) {
+        WalkResult result = callback(*ptr);
+        if (result.wasInterrupted())
+          return result;
+      }
+    }
+    return WalkResult::advance();
+  }
+
+private:
+  SmallVector<std::unique_ptr<StateStackFrame>> stack;
+};
+
+/// RAII object calling stackPush/stackPop on construction/destruction.
+/// HOST_CLASS could be a StateStack or some other class which forwards calls to
+/// one.
+template <typename T, typename HOST_CLASS>
+struct SaveStateStack {
+  template <typename... Args>
+  explicit SaveStateStack(HOST_CLASS &host, Args &&...args) : host(host) {
+    host.template stackPush<T>(std::forward<Args>(args)...);
+  }
+  ~SaveStateStack() { host.stackPop(); }
+
+  // Each object pops the host stack once on destruction, so a copy would pop
+  // a frame it never pushed. Copying (and therefore moving) is disabled.
+  SaveStateStack(const SaveStateStack &) = delete;
+  SaveStateStack &operator=(const SaveStateStack &) = delete;
+
+private:
+  HOST_CLASS &host;
+};
+
+} // namespace mlir
+
+namespace llvm {
+/// Enables isa/cast/dyn_cast from StateStackFrame to a frame type `T` by
+/// comparing the frame's recorded TypeID with `TypeID::get<T>()`. This is an
+/// exact-type match: a frame whose type merely derives from `T` is not
+/// considered an instance of `T`.
+template <typename T>
+struct isa_impl<T, ::mlir::StateStackFrame> {
+  static inline bool doit(const ::mlir::StateStackFrame &frame) {
+    return frame.getTypeID() == ::mlir::TypeID::get<T>();
+  }
+};
+} // namespace llvm
+
+#endif // MLIR_SUPPORT_STATESTACK_H
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
index 0f136c5c46d79..79e8bb6add0da 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
@@ -18,6 +18,7 @@
#include "mlir/IR/Operation.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/Value.h"
+#include "mlir/Support/StateStack.h"
#include "mlir/Target/LLVMIR/Export.h"
#include "mlir/Target/LLVMIR/LLVMTranslationInterface.h"
#include "mlir/Target/LLVMIR/TypeToLLVM.h"
@@ -271,33 +272,6 @@ class ModuleTranslation {
/// it if it does not exist.
llvm::NamedMDNode *getOrInsertNamedModuleMetadata(StringRef name);
- /// Common CRTP base class for ModuleTranslation stack frames.
- class StackFrame {
- public:
- virtual ~StackFrame() = default;
- TypeID getTypeID() const { return typeID; }
-
- protected:
- explicit StackFrame(TypeID typeID) : typeID(typeID) {}
-
- private:
- const TypeID typeID;
- virtual void anchor();
- };
-
- /// Concrete CRTP base class for ModuleTranslation stack frames. When
- /// translating operations with regions, users of ModuleTranslation can store
- /// state on ModuleTranslation stack before entering the region and inspect
- /// it when converting operations nested within that region. Users are
- /// expected to derive this class and put any relevant information into fields
- /// of the derived class. The usual isa/dyn_cast functionality is available
- /// for instances of derived classes.
- template <typename Derived>
- class StackFrameBase : public StackFrame {
- public:
- explicit StackFrameBase() : StackFrame(TypeID::get<Derived>()) {}
- };
-
/// Creates a stack frame of type `T` on ModuleTranslation stack. `T` must
/// be derived from `StackFrameBase<T>` and constructible from the provided
/// arguments. Doing this before entering the region of the op being
@@ -305,46 +279,22 @@ class ModuleTranslation {
/// region.
template <typename T, typename... Args>
void stackPush(Args &&...args) {
- static_assert(
- std::is_base_of<StackFrame, T>::value,
- "can only push instances of StackFrame on ModuleTranslation stack");
- stack.push_back(std::make_unique<T>(std::forward<Args>(args)...));
+ stack.stackPush<T>(std::forward<Args>(args)...);
}
/// Pops the last element from the ModuleTranslation stack.
- void stackPop() { stack.pop_back(); }
+ void stackPop() { stack.stackPop(); }
/// Calls `callback` for every ModuleTranslation stack frame of type `T`
/// starting from the top of the stack.
template <typename T>
WalkResult stackWalk(llvm::function_ref<WalkResult(T &)> callback) {
- static_assert(std::is_base_of<StackFrame, T>::value,
- "expected T derived from StackFrame");
- if (!callback)
- return WalkResult::skip();
- for (std::unique_ptr<StackFrame> &frame : llvm::reverse(stack)) {
- if (T *ptr = dyn_cast_or_null<T>(frame.get())) {
- WalkResult result = callback(*ptr);
- if (result.wasInterrupted())
- return result;
- }
- }
- return WalkResult::advance();
+ return stack.stackWalk(callback);
}
/// RAII object calling stackPush/stackPop on construction/destruction.
template <typename T>
- struct SaveStack {
- template <typename... Args>
- explicit SaveStack(ModuleTranslation &m, Args &&...args)
- : moduleTranslation(m) {
- moduleTranslation.stackPush<T>(std::forward<Args>(args)...);
- }
- ~SaveStack() { moduleTranslation.stackPop(); }
-
- private:
- ModuleTranslation &moduleTranslation;
- };
+ using SaveStack = SaveStateStack<T, ModuleTranslation>;
SymbolTableCollection &symbolTable() { return symbolTableCollection; }
@@ -468,7 +418,7 @@ class ModuleTranslation {
/// Stack of user-specified state elements, useful when translating operations
/// with regions.
- SmallVector<std::unique_ptr<StackFrame>> stack;
+ StateStack stack;
/// A cache for the symbol tables constructed during symbols lookup.
SymbolTableCollection symbolTableCollection;
@@ -510,14 +460,4 @@ llvm::CallInst *createIntrinsicCall(
} // namespace LLVM
} // namespace mlir
-namespace llvm {
-template <typename T>
-struct isa_impl<T, ::mlir::LLVM::ModuleTranslation::StackFrame> {
- static inline bool
- doit(const ::mlir::LLVM::ModuleTranslation::StackFrame &frame) {
- return frame.getTypeID() == ::mlir::TypeID::get<T>();
- }
-};
-} // namespace llvm
-
#endif // MLIR_TARGET_LLVMIR_MODULETRANSLATION_H
diff --git a/mlir/lib/Support/CMakeLists.txt b/mlir/lib/Support/CMakeLists.txt
index 488decd52ae64..02b6c694a28fd 100644
--- a/mlir/lib/Support/CMakeLists.txt
+++ b/mlir/lib/Support/CMakeLists.txt
@@ -11,6 +11,7 @@ add_mlir_library(MLIRSupport
FileUtilities.cpp
InterfaceSupport.cpp
RawOstreamExtras.cpp
+ StateStack.cpp
StorageUniquer.cpp
Timing.cpp
ToolUtilities.cpp
diff --git a/mlir/lib/Support/StateStack.cpp b/mlir/lib/Support/StateStack.cpp
new file mode 100644
index 0000000000000..ce1417cf3eba7
--- /dev/null
+++ b/mlir/lib/Support/StateStack.cpp
@@ -0,0 +1,9 @@
+//===- StateStack.cpp - Utility for storing a stack of state --------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Support/StateStack.h"
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 90ce06a0345c0..e29e3d8f820dc 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -71,7 +71,7 @@ convertToScheduleKind(std::optional<omp::ClauseScheduleKind> schedKind) {
/// ModuleTranslation stack frame for OpenMP operations. This keeps track of the
/// insertion points for allocas.
class OpenMPAllocaStackFrame
- : public LLVM::ModuleTranslation::StackFrameBase<OpenMPAllocaStackFrame> {
+ : public StateStackFrameBase<OpenMPAllocaStackFrame> {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpenMPAllocaStackFrame)
@@ -84,7 +84,7 @@ class OpenMPAllocaStackFrame
/// collapsed canonical loop information corresponding to an \c omp.loop_nest
/// operation.
class OpenMPLoopInfoStackFrame
- : public LLVM::ModuleTranslation::StackFrameBase<OpenMPLoopInfoStackFrame> {
+ : public StateStackFrameBase<OpenMPLoopInfoStackFrame> {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpenMPLoopInfoStackFrame)
llvm::CanonicalLoopInfo *loopInfo = nullptr;
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index 3eaa24eb5c95b..e8ce528bd185e 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -2225,8 +2225,6 @@ ModuleTranslation::getOrInsertNamedModuleMetadata(StringRef name) {
return llvmModule->getOrInsertNamedMetadata(name);
}
-void ModuleTranslation::StackFrame::anchor() {}
-
static std::unique_ptr<llvm::Module>
prepareLLVMModule(Operation *m, llvm::LLVMContext &llvmContext,
StringRef name) {
More information about the Mlir-commits
mailing list