[Mlir-commits] [mlir] 7797824 - [mlir][spirv] Allow disabling control flow structurization (#140561)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Jun 3 07:41:42 PDT 2025
Author: Igor Wodiany
Date: 2025-06-03T15:41:39+01:00
New Revision: 7797824297e17d4c02fbb1cb904c7919f21af47e
URL: https://github.com/llvm/llvm-project/commit/7797824297e17d4c02fbb1cb904c7919f21af47e
DIFF: https://github.com/llvm/llvm-project/commit/7797824297e17d4c02fbb1cb904c7919f21af47e.diff
LOG: [mlir][spirv] Allow disabling control flow structurization (#140561)
Currently some control flow patterns cannot be structurized into
existing SPIR-V MLIR constructs, e.g., conditional early exits (break).
Since support for early exits cannot currently be added
(https://github.com/llvm/llvm-project/pull/138688#pullrequestreview-2830791677),
this patch allows the structurizer to be disabled so that
the control flow is kept unstructurized. By default, the control flow is
structurized.
Added:
Modified:
mlir/include/mlir/Target/SPIRV/Deserialization.h
mlir/lib/Target/SPIRV/Deserialization/Deserialization.cpp
mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
mlir/lib/Target/SPIRV/TranslateRegistration.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Target/SPIRV/Deserialization.h b/mlir/include/mlir/Target/SPIRV/Deserialization.h
index a346a7fd1e5f7..68eb863b4a6f2 100644
--- a/mlir/include/mlir/Target/SPIRV/Deserialization.h
+++ b/mlir/include/mlir/Target/SPIRV/Deserialization.h
@@ -23,12 +23,19 @@ class MLIRContext;
namespace spirv {
class ModuleOp;
+struct DeserializationOptions {
+ // Whether to structurize control flow into `spirv.mlir.selection` and
+ // `spirv.mlir.loop`.
+ bool enableControlFlowStructurization = true;
+};
+
/// Deserializes the given SPIR-V `binary` module and creates a MLIR ModuleOp
/// in the given `context`. Returns the ModuleOp on success; otherwise, reports
/// errors to the error handler registered with `context` and returns a null
/// module.
-OwningOpRef<spirv::ModuleOp> deserialize(ArrayRef<uint32_t> binary,
- MLIRContext *context);
+OwningOpRef<spirv::ModuleOp>
+deserialize(ArrayRef<uint32_t> binary, MLIRContext *context,
+ const DeserializationOptions &options = {});
} // namespace spirv
} // namespace mlir
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserialization.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserialization.cpp
index 7bb8762660599..b82c61cafc8a7 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserialization.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserialization.cpp
@@ -12,9 +12,10 @@
using namespace mlir;
-OwningOpRef<spirv::ModuleOp> spirv::deserialize(ArrayRef<uint32_t> binary,
- MLIRContext *context) {
- Deserializer deserializer(binary, context);
+OwningOpRef<spirv::ModuleOp>
+spirv::deserialize(ArrayRef<uint32_t> binary, MLIRContext *context,
+ const DeserializationOptions &options) {
+ Deserializer deserializer(binary, context, options);
if (failed(deserializer.deserialize()))
return nullptr;
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 7afd6e9b25b77..a21d691ae5142 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -49,9 +49,10 @@ static inline bool isFnEntryBlock(Block *block) {
//===----------------------------------------------------------------------===//
spirv::Deserializer::Deserializer(ArrayRef<uint32_t> binary,
- MLIRContext *context)
+ MLIRContext *context,
+ const spirv::DeserializationOptions &options)
: binary(binary), context(context), unknownLoc(UnknownLoc::get(context)),
- module(createModuleOp()), opBuilder(module->getRegion())
+ module(createModuleOp()), opBuilder(module->getRegion()), options(options)
#ifndef NDEBUG
,
logger(llvm::dbgs())
@@ -2361,6 +2362,16 @@ LogicalResult spirv::Deserializer::splitConditionalBlocks() {
}
LogicalResult spirv::Deserializer::structurizeControlFlow() {
+ if (!options.enableControlFlowStructurization) {
+ LLVM_DEBUG(
+ {
+ logger.startLine()
+ << "//----- [cf] skip structurizing control flow -----//\n";
+ logger.indent();
+ });
+ return success();
+ }
+
LLVM_DEBUG({
logger.startLine()
<< "//----- [cf] start structurizing control flow -----//\n";
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
index bcc78e3e6508d..e4556e7652b17 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
@@ -16,6 +16,7 @@
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/IR/Builders.h"
+#include "mlir/Target/SPIRV/Deserialization.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/StringRef.h"
@@ -121,7 +122,8 @@ class Deserializer {
public:
/// Creates a deserializer for the given SPIR-V `binary` module.
/// The SPIR-V ModuleOp will be created into `context.
- explicit Deserializer(ArrayRef<uint32_t> binary, MLIRContext *context);
+ explicit Deserializer(ArrayRef<uint32_t> binary, MLIRContext *context,
+ const DeserializationOptions &options);
/// Deserializes the remembered SPIR-V binary module.
LogicalResult deserialize();
@@ -622,6 +624,9 @@ class Deserializer {
/// A list of all structs which have unresolved member types.
SmallVector<DeferredStructTypeInfo, 0> deferredStructTypesInfos;
+ /// Deserialization options.
+ DeserializationOptions options;
+
#ifndef NDEBUG
/// A logger used to emit information during the deserialzation process.
llvm::ScopedPrinter logger;
diff --git a/mlir/lib/Target/SPIRV/TranslateRegistration.cpp b/mlir/lib/Target/SPIRV/TranslateRegistration.cpp
index ff34f02d07b73..682fff2784775 100644
--- a/mlir/lib/Target/SPIRV/TranslateRegistration.cpp
+++ b/mlir/lib/Target/SPIRV/TranslateRegistration.cpp
@@ -37,7 +37,8 @@ using namespace mlir;
// Deserializes the SPIR-V binary module stored in the file named as
// `inputFilename` and returns a module containing the SPIR-V module.
static OwningOpRef<Operation *>
-deserializeModule(const llvm::MemoryBuffer *input, MLIRContext *context) {
+deserializeModule(const llvm::MemoryBuffer *input, MLIRContext *context,
+ const spirv::DeserializationOptions &options) {
context->loadDialect<spirv::SPIRVDialect>();
// Make sure the input stream can be treated as a stream of SPIR-V words
@@ -51,17 +52,26 @@ deserializeModule(const llvm::MemoryBuffer *input, MLIRContext *context) {
auto binary = llvm::ArrayRef(reinterpret_cast<const uint32_t *>(start),
size / sizeof(uint32_t));
- return spirv::deserialize(binary, context);
+ return spirv::deserialize(binary, context, options);
}
namespace mlir {
void registerFromSPIRVTranslation() {
+ static llvm::cl::opt<bool> enableControlFlowStructurization(
+ "spirv-structurize-control-flow",
+ llvm::cl::desc(
+ "Enable control flow structurization into `spirv.mlir.selection` and "
+ "`spirv.mlir.loop`. This may need to be disabled to support "
+ "deserialization of early exits (see #138688)"),
+ llvm::cl::init(true));
+
TranslateToMLIRRegistration fromBinary(
"deserialize-spirv", "deserializes the SPIR-V module",
[](llvm::SourceMgr &sourceMgr, MLIRContext *context) {
assert(sourceMgr.getNumBuffers() == 1 && "expected one buffer");
return deserializeModule(
- sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID()), context);
+ sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID()), context,
+ {enableControlFlowStructurization});
});
}
} // namespace mlir
More information about the Mlir-commits
mailing list