[Mlir-commits] [mlir] 90a8260 - [MLIR][SPIRV] Start module combiner.
Lei Zhang
llvmlistbot at llvm.org
Fri Oct 30 14:01:51 PDT 2020
Author: ergawy
Date: 2020-10-30T16:55:43-04:00
New Revision: 90a8260cb46dbf15b6c2325979273da6d15e9aee
URL: https://github.com/llvm/llvm-project/commit/90a8260cb46dbf15b6c2325979273da6d15e9aee
DIFF: https://github.com/llvm/llvm-project/commit/90a8260cb46dbf15b6c2325979273da6d15e9aee.diff
LOG: [MLIR][SPIRV] Start module combiner.
This commit adds a new library that merges/combines a number of spv
modules into a combined one. The library has a single entry point:
combine(...).
To combine a number of MLIR spv modules, we move all the module-level ops
from all the input modules into one big combined module. To that end, the
combination process can proceed in 2 phases:
(1) resolving conflicts between pairs of ops from different modules
(2) deduplicating equivalent ops/sub-ops in the merged module (TODO).
This patch implements only the first phase.
Reviewed By: antiagainst
Differential Revision: https://reviews.llvm.org/D90477
Added:
mlir/include/mlir/Dialect/SPIRV/ModuleCombiner.h
mlir/lib/Dialect/SPIRV/Linking/CMakeLists.txt
mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/CMakeLists.txt
mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp
mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/basic.mlir
mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/conflict_resolution.mlir
mlir/test/lib/Dialect/SPIRV/TestModuleCombiner.cpp
Modified:
mlir/lib/Dialect/SPIRV/CMakeLists.txt
mlir/test/lib/Dialect/SPIRV/CMakeLists.txt
mlir/tools/mlir-opt/mlir-opt.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SPIRV/ModuleCombiner.h b/mlir/include/mlir/Dialect/SPIRV/ModuleCombiner.h
new file mode 100644
index 000000000000..b7ecd57d103d
--- /dev/null
+++ b/mlir/include/mlir/Dialect/SPIRV/ModuleCombiner.h
@@ -0,0 +1,69 @@
+//===- ModuleCombiner.h - MLIR SPIR-V Module Combiner -----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the entry point to the SPIR-V module combiner library.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SPIRV_MODULECOMBINER_H_
+#define MLIR_DIALECT_SPIRV_MODULECOMBINER_H_
+
+#include "mlir/Dialect/SPIRV/SPIRVModule.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/SmallVector.h"
+
+namespace mlir {
+class OpBuilder;
+
+namespace spirv {
+class ModuleOp;
+
+/// To combine a number of MLIR SPIR-V modules, we move all the module-level ops
+/// from all the input modules into one big combined module. To that end, the
+/// combination process proceeds in 2 phases:
+///
+///   (1) resolve conflicts between pairs of ops from different modules
+///   (2) deduplicate equivalent ops/sub-ops in the merged module. (TODO)
+///
+/// For the conflict resolution phase, the following rules are employed to
+/// resolve such conflicts:
+///
+///   - If 2 spv.func's have the same symbol name, then rename one of the
+///     functions.
+///   - If an spv.func and another op have the same symbol name, then rename the
+///     other symbol.
+///   - If none of the 2 conflicting ops are spv.func, then rename either.
+///
+/// In all cases, the references to the updated symbol are also updated to
+/// reflect the change.
+///
+/// All input modules must use the same addressing model and memory model;
+/// otherwise an error is emitted on the first mismatching module and nullptr
+/// is returned.
+///
+/// \param modules the list of modules to combine. Input modules are not
+///                modified.
+/// \param combinedModuleBuilder an OpBuilder used to create the combined
+///                module at its current insertion point. On success, the
+///                builder's insertion point is left inside the combined
+///                module's body.
+/// \param symRenameListener an optional (may be null) listener that gets
+///                called every time a symbol in one of the input modules is
+///                renamed. The arguments passed to the listener are: the
+///                input spirv::ModuleOp that contains the renamed symbol, a
+///                StringRef to the old symbol name, and a StringRef to the
+///                new symbol name. Note that it is the responsibility of the
+///                caller to properly retain the storage underlying the passed
+///                StringRefs if the listener callback outlives this function
+///                call.
+///
+/// \return the combined module, or nullptr if `modules` is empty or the
+///         combination failed. NOTE: the combined module is created through
+///         `combinedModuleBuilder`, so when the builder has an insertion block
+///         the returned owning reference aliases an op already owned by that
+///         block; callers in that situation should release() it rather than
+///         let it erase the op.
+OwningSPIRVModuleRef
+combine(llvm::MutableArrayRef<ModuleOp> modules,
+        OpBuilder &combinedModuleBuilder,
+        llvm::function_ref<void(ModuleOp, StringRef, StringRef)>
+            symRenameListener);
+} // namespace spirv
+} // namespace mlir
+
+#endif // MLIR_DIALECT_SPIRV_MODULECOMBINER_H_
diff --git a/mlir/lib/Dialect/SPIRV/CMakeLists.txt b/mlir/lib/Dialect/SPIRV/CMakeLists.txt
index 10f06fdb8861..f37182121fed 100644
--- a/mlir/lib/Dialect/SPIRV/CMakeLists.txt
+++ b/mlir/lib/Dialect/SPIRV/CMakeLists.txt
@@ -34,5 +34,6 @@ add_mlir_dialect_library(MLIRSPIRV
MLIRTransforms
)
+add_subdirectory(Linking)
add_subdirectory(Serialization)
add_subdirectory(Transforms)
diff --git a/mlir/lib/Dialect/SPIRV/Linking/CMakeLists.txt b/mlir/lib/Dialect/SPIRV/Linking/CMakeLists.txt
new file mode 100644
index 000000000000..4cc016812701
--- /dev/null
+++ b/mlir/lib/Dialect/SPIRV/Linking/CMakeLists.txt
@@ -0,0 +1 @@
+# SPIR-V module combiner library (MLIRSPIRVModuleCombiner).
+add_subdirectory(ModuleCombiner)
diff --git a/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/CMakeLists.txt b/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/CMakeLists.txt
new file mode 100644
index 000000000000..69af5a69ce8a
--- /dev/null
+++ b/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/CMakeLists.txt
@@ -0,0 +1,11 @@
+# Library that merges/combines a number of spv.module ops into a single one.
+# Public entry point: spirv::combine() in mlir/Dialect/SPIRV/ModuleCombiner.h.
+add_mlir_dialect_library(MLIRSPIRVModuleCombiner
+ ModuleCombiner.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SPIRV
+
+ LINK_LIBS PUBLIC
+ MLIRIR
+ MLIRSPIRV
+ MLIRSupport
+ )
diff --git a/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp b/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp
new file mode 100644
index 000000000000..7687ab27e753
--- /dev/null
+++ b/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp
@@ -0,0 +1,181 @@
+//===- ModuleCombiner.cpp - MLIR SPIR-V Module Combiner ---------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the SPIR-V module combiner library.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SPIRV/ModuleCombiner.h"
+
+#include "mlir/Dialect/SPIRV/SPIRVOps.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/SymbolTable.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/StringExtras.h"
+
+using namespace mlir;
+
+/// Upper bound on the numeric suffixes tried when generating a fresh symbol
+/// name.
+static constexpr unsigned maxFreeID = 1 << 20;
+
+/// Finds a fresh name for `oldSymName` of the form `<oldSymName>_<N>`, where N
+/// is the next value of the shared counter `lastUsedID` for which the name is
+/// unused in both `target` and `source`. Both modules must be checked because
+/// their ops are merged into the same combined module: a name that is free in
+/// only one of them would produce a duplicate symbol after merging.
+///
+/// On success, stores the name in `newSymName`. Returns failure if the suffix
+/// space up to `maxFreeID` is exhausted.
+static LogicalResult renameSymbol(StringRef oldSymName, unsigned &lastUsedID,
+                                  spirv::ModuleOp target,
+                                  spirv::ModuleOp source,
+                                  SmallString<64> &newSymName) {
+  SmallString<64> prefix(oldSymName);
+  prefix.push_back('_');
+
+  while (lastUsedID < maxFreeID) {
+    std::string candidate = (prefix + llvm::utostr(++lastUsedID)).str();
+
+    if (!SymbolTable::lookupSymbolIn(target, candidate) &&
+        !SymbolTable::lookupSymbolIn(source, candidate)) {
+      newSymName = candidate;
+      return success();
+    }
+  }
+
+  return failure();
+}
+
+/// Checks whether a symbol with the same name as `op` already exists in
+/// `source`. If so, renames `op` (which lives in `target`) to a name that is
+/// free in both modules and updates all its references in `target`.
+static LogicalResult updateSymbolAndAllUses(SymbolOpInterface op,
+                                            spirv::ModuleOp target,
+                                            spirv::ModuleOp source,
+                                            unsigned &lastUsedID) {
+  if (!SymbolTable::lookupSymbolIn(source, op.getName()))
+    return success();
+
+  StringRef oldSymName = op.getName();
+  SmallString<64> newSymName;
+  if (failed(renameSymbol(oldSymName, lastUsedID, target, source, newSymName)))
+    return op.emitError("unable to find a free symbol name to rename ")
+           << oldSymName;
+
+  if (failed(SymbolTable::replaceAllSymbolUses(op, newSymName, target)))
+    return op.emitError("unable to update all symbol uses for ")
+           << oldSymName << " to " << newSymName;
+
+  SymbolTable::setSymbolName(op, newSymName);
+  return success();
+}
+
+namespace mlir {
+namespace spirv {
+
+// TODO Properly test symbol rename listener mechanism.
+
+/// Combines `modules` into a single spv.module created through
+/// `combinedModuleBuilder`. See ModuleCombiner.h for the conflict resolution
+/// rules.
+///
+/// Returns nullptr if `modules` is empty, if the input modules disagree on the
+/// addressing/memory model (an error is emitted on the offending module), or
+/// if conflict resolution fails.
+OwningSPIRVModuleRef
+combine(llvm::MutableArrayRef<spirv::ModuleOp> modules,
+        OpBuilder &combinedModuleBuilder,
+        llvm::function_ref<void(ModuleOp, StringRef, StringRef)>
+            symRenameListener) {
+  // Shared suffix counter for generated `<name>_<N>` symbols; it only grows,
+  // across all input modules.
+  unsigned lastUsedID = 0;
+
+  if (modules.empty())
+    return nullptr;
+
+  auto addressingModel = modules[0].addressing_model();
+  auto memoryModel = modules[0].memory_model();
+
+  // Validate every input module up front so that an incompatible module is
+  // reported before the combined module is created and inserted into the IR.
+  for (auto module : modules) {
+    if (module.addressing_model() != addressingModel ||
+        module.memory_model() != memoryModel) {
+      module.emitError(
+          "input modules differ in addressing model and/or memory model");
+      return nullptr;
+    }
+  }
+
+  auto combinedModule = combinedModuleBuilder.create<spirv::ModuleOp>(
+      modules[0].getLoc(), addressingModel, memoryModel);
+  // Subsequent ops inserted through the builder go into the combined module's
+  // body.
+  combinedModuleBuilder.setInsertionPointToStart(&*combinedModule.getBody());
+
+  // In some cases, a symbol in the (current state of the) combined module is
+  // renamed in order to maintain the conflicting symbol in the input module
+  // being merged. For example, if the conflict is between a global variable in
+  // the current combined module and a function in the input module, the global
+  // variable is renamed. In order to notify listeners of the symbol updates in
+  // such cases, we need to keep track of the module from which each symbol in
+  // the combined module originated. This map keeps such information for every
+  // symbol moved into the combined module (only populated when a listener is
+  // provided).
+  DenseMap<StringRef, spirv::ModuleOp> symNameToModuleMap;
+
+  for (auto module : modules) {
+    // Operate on a clone so that the input modules are not modified. The clone
+    // is held by an owning reference so that it is destroyed at the end of
+    // each iteration (and on early return) instead of leaking.
+    OwningSPIRVModuleRef moduleCloneRef(module.clone());
+    spirv::ModuleOp moduleClone = moduleCloneRef.get();
+
+    // In the combined module, rename all symbols that conflict with symbols
+    // from the current input module. This renaming applies to all ops except
+    // for spv.funcs. This way, if the conflicting op in the input module is
+    // non-spv.func, we rename that symbol instead and maintain the spv.func in
+    // the combined module name as it is.
+    for (auto &op : combinedModule.getBlock().without_terminator()) {
+      if (auto symbolOp = dyn_cast<SymbolOpInterface>(op)) {
+        StringRef oldSymName = symbolOp.getName();
+
+        if (!isa<FuncOp>(op) &&
+            failed(updateSymbolAndAllUses(symbolOp, combinedModule,
+                                          moduleClone, lastUsedID)))
+          return nullptr;
+
+        StringRef newSymName = symbolOp.getName();
+
+        if (symRenameListener && oldSymName != newSymName) {
+          spirv::ModuleOp originalModule =
+              symNameToModuleMap.lookup(oldSymName);
+
+          if (!originalModule) {
+            module.emitError("unable to find original ModuleOp for symbol ")
+                << oldSymName;
+            return nullptr;
+          }
+
+          symRenameListener(originalModule, oldSymName, newSymName);
+
+          // Since the symbol name is updated, there is no need to maintain the
+          // entry that associates the old symbol name with the original
+          // module.
+          symNameToModuleMap.erase(oldSymName);
+          // Instead, add a new entry to map the new symbol name to the
+          // original module in case it gets renamed again later.
+          symNameToModuleMap[newSymName] = originalModule;
+        }
+      }
+    }
+
+    // In the current input module, rename all symbols that conflict with
+    // symbols from the combined module. This includes renaming spv.funcs.
+    for (auto &op : moduleClone.getBlock().without_terminator()) {
+      if (auto symbolOp = dyn_cast<SymbolOpInterface>(op)) {
+        StringRef oldSymName = symbolOp.getName();
+
+        if (failed(updateSymbolAndAllUses(symbolOp, moduleClone,
+                                          combinedModule, lastUsedID)))
+          return nullptr;
+
+        StringRef newSymName = symbolOp.getName();
+
+        if (symRenameListener) {
+          if (oldSymName != newSymName)
+            symRenameListener(module, oldSymName, newSymName);
+
+          // Record every symbol that is about to be moved into the combined
+          // module, renamed or not, so that a later rename of it inside the
+          // combined module can be attributed to its original input module.
+          auto emplaceResult =
+              symNameToModuleMap.try_emplace(newSymName, module);
+
+          // If an entry with the same symbol name is already present, this
+          // must be a problem with the implementation, especially clean-up of
+          // the map while iterating over the combined module above.
+          if (!emplaceResult.second) {
+            module.emitError("did not expect to find an entry for symbol ")
+                << newSymName;
+            return nullptr;
+          }
+        }
+      }
+    }
+
+    // Clone all the module's ops to the combined module.
+    for (auto &op : moduleClone.getBlock().without_terminator())
+      combinedModuleBuilder.insert(op.clone());
+  }
+
+  return combinedModule;
+}
+
+} // namespace spirv
+} // namespace mlir
diff --git a/mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/basic.mlir b/mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/basic.mlir
new file mode 100644
index 000000000000..07fd41e4fe86
--- /dev/null
+++ b/mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/basic.mlir
@@ -0,0 +1,50 @@
+// RUN: mlir-opt -test-spirv-module-combiner -split-input-file -verify-diagnostics %s | FileCheck %s
+
+// CHECK: module {
+// CHECK-NEXT: spv.module Logical GLSL450 {
+// CHECK-NEXT: spv.specConstant @m1_sc
+// CHECK-NEXT: spv.specConstant @m2_sc
+// CHECK-NEXT: spv.func @variable_init_spec_constant
+// CHECK-NEXT: spv._reference_of @m2_sc
+// CHECK-NEXT: spv.Variable init
+// CHECK-NEXT: spv.Return
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+
+module {
+spv.module Logical GLSL450 {
+ spv.specConstant @m1_sc = 42.42 : f32
+}
+
+spv.module Logical GLSL450 {
+ spv.specConstant @m2_sc = 42 : i32
+ spv.func @variable_init_spec_constant() -> () "None" {
+ %0 = spv._reference_of @m2_sc : i32
+ %1 = spv.Variable init(%0) : !spv.ptr<i32, Function>
+ spv.Return
+ }
+}
+}
+
+// -----
+
+module {
+spv.module Physical64 GLSL450 {
+}
+
+// expected-error @+1 {{input modules differ in addressing model and/or memory model}}
+spv.module Logical GLSL450 {
+}
+}
+
+// -----
+
+module {
+spv.module Logical Simple {
+}
+
+// expected-error @+1 {{input modules differ in addressing model and/or memory model}}
+spv.module Logical GLSL450 {
+}
+}
diff --git a/mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/conflict_resolution.mlir b/mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/conflict_resolution.mlir
new file mode 100644
index 000000000000..f5535c483171
--- /dev/null
+++ b/mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/conflict_resolution.mlir
@@ -0,0 +1,682 @@
+// RUN: mlir-opt -test-spirv-module-combiner -split-input-file -verify-diagnostics %s | FileCheck %s
+
+// Test basic renaming of conflicting funcOps.
+
+// CHECK: module {
+// CHECK-NEXT: spv.module Logical GLSL450 {
+// CHECK-NEXT: spv.func @foo
+// CHECK-NEXT: spv.ReturnValue
+// CHECK-NEXT: }
+
+// CHECK-NEXT: spv.func @foo_1
+// CHECK-NEXT: spv.ReturnValue
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+
+module {
+spv.module Logical GLSL450 {
+ spv.func @foo(%arg0 : i32) -> i32 "None" {
+ spv.ReturnValue %arg0 : i32
+ }
+}
+
+spv.module Logical GLSL450 {
+ spv.func @foo(%arg0 : f32) -> f32 "None" {
+ spv.ReturnValue %arg0 : f32
+ }
+}
+}
+
+// -----
+
+// Test basic renaming of conflicting funcOps across 3 modules.
+
+// CHECK: module {
+// CHECK-NEXT: spv.module Logical GLSL450 {
+// CHECK-NEXT: spv.func @foo
+// CHECK-NEXT: spv.ReturnValue
+// CHECK-NEXT: }
+
+// CHECK-NEXT: spv.func @foo_1
+// CHECK-NEXT: spv.ReturnValue
+// CHECK-NEXT: }
+
+// CHECK-NEXT: spv.func @foo_2
+// CHECK-NEXT: spv.ReturnValue
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+
+module {
+spv.module Logical GLSL450 {
+ spv.func @foo(%arg0 : i32) -> i32 "None" {
+ spv.ReturnValue %arg0 : i32
+ }
+}
+
+spv.module Logical GLSL450 {
+ spv.func @foo(%arg0 : f32) -> f32 "None" {
+ spv.ReturnValue %arg0 : f32
+ }
+}
+
+spv.module Logical GLSL450 {
+ spv.func @foo(%arg0 : i32) -> i32 "None" {
+ spv.ReturnValue %arg0 : i32
+ }
+}
+}
+
+// -----
+
+// Test properly updating references to a renamed funcOp.
+
+// CHECK: module {
+// CHECK-NEXT: spv.module Logical GLSL450 {
+// CHECK-NEXT: spv.func @foo
+// CHECK-NEXT: spv.ReturnValue
+// CHECK-NEXT: }
+
+// CHECK-NEXT: spv.func @foo_1
+// CHECK-NEXT: spv.ReturnValue
+// CHECK-NEXT: }
+
+// CHECK-NEXT: spv.func @bar
+// CHECK-NEXT: spv.FunctionCall @foo_1
+// CHECK-NEXT: spv.ReturnValue
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+
+module {
+spv.module Logical GLSL450 {
+ spv.func @foo(%arg0 : i32) -> i32 "None" {
+ spv.ReturnValue %arg0 : i32
+ }
+}
+
+spv.module Logical GLSL450 {
+ spv.func @foo(%arg0 : f32) -> f32 "None" {
+ spv.ReturnValue %arg0 : f32
+ }
+
+ spv.func @bar(%arg0 : f32) -> f32 "None" {
+ %0 = spv.FunctionCall @foo(%arg0) : (f32) -> (f32)
+ spv.ReturnValue %0 : f32
+ }
+}
+}
+
+// -----
+
+// Test properly updating references to a renamed funcOp if the functionCallOp
+// precedes the callee funcOp definition.
+
+// CHECK: module {
+// CHECK-NEXT: spv.module Logical GLSL450 {
+// CHECK-NEXT: spv.func @foo
+// CHECK-NEXT: spv.ReturnValue
+// CHECK-NEXT: }
+
+// CHECK-NEXT: spv.func @bar
+// CHECK-NEXT: spv.FunctionCall @foo_1
+// CHECK-NEXT: spv.ReturnValue
+// CHECK-NEXT: }
+
+// CHECK-NEXT: spv.func @foo_1
+// CHECK-NEXT: spv.ReturnValue
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+
+module {
+spv.module Logical GLSL450 {
+ spv.func @foo(%arg0 : i32) -> i32 "None" {
+ spv.ReturnValue %arg0 : i32
+ }
+}
+
+spv.module Logical GLSL450 {
+ spv.func @bar(%arg0 : f32) -> f32 "None" {
+ %0 = spv.FunctionCall @foo(%arg0) : (f32) -> (f32)
+ spv.ReturnValue %0 : f32
+ }
+
+ spv.func @foo(%arg0 : f32) -> f32 "None" {
+ spv.ReturnValue %arg0 : f32
+ }
+}
+}
+
+// -----
+
+// Test properly updating entryPointOp and executionModeOp attached to renamed
+// funcOp.
+
+// CHECK: module {
+// CHECK-NEXT: spv.module Logical GLSL450 {
+// CHECK-NEXT: spv.func @foo
+// CHECK-NEXT: spv.ReturnValue
+// CHECK-NEXT: }
+
+// CHECK-NEXT: spv.func @foo_1
+// CHECK-NEXT: spv.ReturnValue
+// CHECK-NEXT: }
+
+// CHECK-NEXT: spv.EntryPoint "GLCompute" @foo_1
+// CHECK-NEXT: spv.ExecutionMode @foo_1 "ContractionOff"
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+
+module {
+spv.module Logical GLSL450 {
+ spv.func @foo(%arg0 : i32) -> i32 "None" {
+ spv.ReturnValue %arg0 : i32
+ }
+}
+
+spv.module Logical GLSL450 {
+ spv.func @foo(%arg0 : f32) -> f32 "None" {
+ spv.ReturnValue %arg0 : f32
+ }
+
+ spv.EntryPoint "GLCompute" @foo
+ spv.ExecutionMode @foo "ContractionOff"
+}
+}
+
+// -----
+
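+// Test properly updating entryPointOp and executionModeOp when both modules
+// attach them to conflicting funcOps.
+
+// CHECK: module {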
+// CHECK-NEXT: spv.module Logical GLSL450 {
+// CHECK-NEXT: spv.func @foo
+// CHECK-NEXT: spv.ReturnValue
+// CHECK-NEXT: }
+
+// CHECK-NEXT: spv.EntryPoint "GLCompute" @foo
+// CHECK-NEXT: spv.ExecutionMode @foo "ContractionOff"
+
+// CHECK-NEXT: spv.func @foo_1
+// CHECK-NEXT: spv.ReturnValue
+// CHECK-NEXT: }
+
+// CHECK-NEXT: spv.EntryPoint "GLCompute" @foo_1
+// CHECK-NEXT: spv.ExecutionMode @foo_1 "ContractionOff"
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+
+module {
+spv.module Logical GLSL450 {
+ spv.func @foo(%arg0 : i32) -> i32 "None" {
+ spv.ReturnValue %arg0 : i32
+ }
+
+ spv.EntryPoint "GLCompute" @foo
+ spv.ExecutionMode @foo "ContractionOff"
+}
+
+spv.module Logical GLSL450 {
+ spv.func @foo(%arg0 : f32) -> f32 "None" {
+ spv.ReturnValue %arg0 : f32
+ }
+
+ spv.EntryPoint "GLCompute" @foo
+ spv.ExecutionMode @foo "ContractionOff"
+}
+}
+
+// -----
+
+// Resolve conflicting funcOp and globalVariableOp.
+
+// CHECK: module {
+// CHECK-NEXT: spv.module Logical GLSL450 {
+// CHECK-NEXT: spv.func @foo
+// CHECK-NEXT: spv.ReturnValue
+// CHECK-NEXT: }
+
+// CHECK-NEXT: spv.globalVariable @foo_1
+// CHECK-NEXT: }
+
+module {
+spv.module Logical GLSL450 {
+ spv.func @foo(%arg0 : i32) -> i32 "None" {
+ spv.ReturnValue %arg0 : i32
+ }
+}
+
+spv.module Logical GLSL450 {
+ spv.globalVariable @foo bind(1, 0) : !spv.ptr<f32, Input>
+}
+}
+
+// -----
+
+// Resolve conflicting funcOp and globalVariableOp and update the global variable's
+// references.
+
+// CHECK: module {
+// CHECK-NEXT: spv.module Logical GLSL450 {
+// CHECK-NEXT: spv.func @foo
+// CHECK-NEXT: spv.ReturnValue
+// CHECK-NEXT: }
+
+// CHECK-NEXT: spv.globalVariable @foo_1
+// CHECK-NEXT: spv.func @bar
+// CHECK-NEXT: spv._address_of @foo_1
+// CHECK-NEXT: spv.Load
+// CHECK-NEXT: spv.ReturnValue
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+
+module {
+spv.module Logical GLSL450 {
+ spv.func @foo(%arg0 : i32) -> i32 "None" {
+ spv.ReturnValue %arg0 : i32
+ }
+}
+
+spv.module Logical GLSL450 {
+ spv.globalVariable @foo bind(1, 0) : !spv.ptr<f32, Input>
+
+ spv.func @bar() -> f32 "None" {
+ %0 = spv._address_of @foo : !spv.ptr<f32, Input>
+ %1 = spv.Load "Input" %0 : f32
+ spv.ReturnValue %1 : f32
+ }
+}
+}
+
+// -----
+
+// Resolve conflicting globalVariableOp and funcOp and update the global variable's
+// references.
+
+// CHECK: module {
+// CHECK-NEXT: spv.module Logical GLSL450 {
+// CHECK-NEXT: spv.globalVariable @foo_1
+// CHECK-NEXT: spv.func @bar
+// CHECK-NEXT: spv._address_of @foo_1
+// CHECK-NEXT: spv.Load
+// CHECK-NEXT: spv.ReturnValue
+// CHECK-NEXT: }
+
+// CHECK-NEXT: spv.func @foo
+// CHECK-NEXT: spv.ReturnValue
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+
+module {
+spv.module Logical GLSL450 {
+ spv.globalVariable @foo bind(1, 0) : !spv.ptr<f32, Input>
+
+ spv.func @bar() -> f32 "None" {
+ %0 = spv._address_of @foo : !spv.ptr<f32, Input>
+ %1 = spv.Load "Input" %0 : f32
+ spv.ReturnValue %1 : f32
+ }
+}
+
+spv.module Logical GLSL450 {
+ spv.func @foo(%arg0 : i32) -> i32 "None" {
+ spv.ReturnValue %arg0 : i32
+ }
+}
+}
+
+// -----
+
+// Resolve conflicting funcOp and specConstantOp.
+
+// CHECK: module {
+// CHECK-NEXT: spv.module Logical GLSL450 {
+// CHECK-NEXT: spv.func @foo
+// CHECK-NEXT: spv.ReturnValue
+// CHECK-NEXT: }
+
+// CHECK-NEXT: spv.specConstant @foo_1
+// CHECK-NEXT: }
+
+module {
+spv.module Logical GLSL450 {
+ spv.func @foo(%arg0 : i32) -> i32 "None" {
+ spv.ReturnValue %arg0 : i32
+ }
+}
+
+spv.module Logical GLSL450 {
+ spv.specConstant @foo = -5 : i32
+}
+}
+
+// -----
+
+// Resolve conflicting funcOp and specConstantOp and update the spec constant's
+// references.
+
+// CHECK: module {
+// CHECK-NEXT: spv.module Logical GLSL450 {
+// CHECK-NEXT: spv.func @foo
+// CHECK-NEXT: spv.ReturnValue
+// CHECK-NEXT: }
+
+// CHECK-NEXT: spv.specConstant @foo_1
+// CHECK-NEXT: spv.func @bar
+// CHECK-NEXT: spv._reference_of @foo_1
+// CHECK-NEXT: spv.ReturnValue
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+
+module {
+spv.module Logical GLSL450 {
+ spv.func @foo(%arg0 : i32) -> i32 "None" {
+ spv.ReturnValue %arg0 : i32
+ }
+}
+
+spv.module Logical GLSL450 {
+ spv.specConstant @foo = -5 : i32
+
+ spv.func @bar() -> i32 "None" {
+ %0 = spv._reference_of @foo : i32
+ spv.ReturnValue %0 : i32
+ }
+}
+}
+
+// -----
+
+// Resolve conflicting specConstantOp and funcOp and update the spec constant's
+// references.
+
+// CHECK: module {
+// CHECK-NEXT: spv.module Logical GLSL450 {
+// CHECK-NEXT: spv.specConstant @foo_1
+// CHECK-NEXT: spv.func @bar
+// CHECK-NEXT: spv._reference_of @foo_1
+// CHECK-NEXT: spv.ReturnValue
+// CHECK-NEXT: }
+
+// CHECK-NEXT: spv.func @foo
+// CHECK-NEXT: spv.ReturnValue
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+
+module {
+spv.module Logical GLSL450 {
+ spv.specConstant @foo = -5 : i32
+
+ spv.func @bar() -> i32 "None" {
+ %0 = spv._reference_of @foo : i32
+ spv.ReturnValue %0 : i32
+ }
+}
+
+spv.module Logical GLSL450 {
+ spv.func @foo(%arg0 : i32) -> i32 "None" {
+ spv.ReturnValue %arg0 : i32
+ }
+}
+}
+
+// -----
+
+// Resolve conflicting funcOp and specConstantCompositeOp.
+
+// CHECK: module {
+// CHECK-NEXT: spv.module Logical GLSL450 {
+// CHECK-NEXT: spv.func @foo
+// CHECK-NEXT: spv.ReturnValue
+// CHECK-NEXT: }
+
+// CHECK-NEXT: spv.specConstant @bar
+// CHECK-NEXT: spv.specConstantComposite @foo_1 (@bar, @bar)
+// CHECK-NEXT: }
+
+module {
+spv.module Logical GLSL450 {
+ spv.func @foo(%arg0 : i32) -> i32 "None" {
+ spv.ReturnValue %arg0 : i32
+ }
+}
+
+spv.module Logical GLSL450 {
+ spv.specConstant @bar = -5 : i32
+ spv.specConstantComposite @foo (@bar, @bar) : !spv.array<2 x i32>
+}
+}
+
+// -----
+
+// Resolve conflicting funcOp and specConstantCompositeOp and update the spec
+// constant's references.
+
+// CHECK: module {
+// CHECK-NEXT: spv.module Logical GLSL450 {
+// CHECK-NEXT: spv.func @foo
+// CHECK-NEXT: spv.ReturnValue
+// CHECK-NEXT: }
+
+// CHECK-NEXT: spv.specConstant @bar
+// CHECK-NEXT: spv.specConstantComposite @foo_1 (@bar, @bar)
+// CHECK-NEXT: spv.func @baz
+// CHECK-NEXT: spv._reference_of @foo_1
+// CHECK-NEXT: spv.CompositeExtract
+// CHECK-NEXT: spv.ReturnValue
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+
+module {
+spv.module Logical GLSL450 {
+ spv.func @foo(%arg0 : i32) -> i32 "None" {
+ spv.ReturnValue %arg0 : i32
+ }
+}
+
+spv.module Logical GLSL450 {
+ spv.specConstant @bar = -5 : i32
+ spv.specConstantComposite @foo (@bar, @bar) : !spv.array<2 x i32>
+
+ spv.func @baz() -> i32 "None" {
+ %0 = spv._reference_of @foo : !spv.array<2 x i32>
+ %1 = spv.CompositeExtract %0[0 : i32] : !spv.array<2 x i32>
+ spv.ReturnValue %1 : i32
+ }
+}
+}
+
+// -----
+
+// Resolve conflicting specConstantCompositeOp and funcOp and update the spec
+// constant's references.
+
+// CHECK: module {
+// CHECK-NEXT: spv.module Logical GLSL450 {
+// CHECK-NEXT: spv.specConstant @bar
+// CHECK-NEXT: spv.specConstantComposite @foo_1 (@bar, @bar)
+// CHECK-NEXT: spv.func @baz
+// CHECK-NEXT: spv._reference_of @foo_1
+// CHECK-NEXT: spv.CompositeExtract
+// CHECK-NEXT: spv.ReturnValue
+// CHECK-NEXT: }
+
+// CHECK-NEXT: spv.func @foo
+// CHECK-NEXT: spv.ReturnValue
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+
+module {
+spv.module Logical GLSL450 {
+ spv.specConstant @bar = -5 : i32
+ spv.specConstantComposite @foo (@bar, @bar) : !spv.array<2 x i32>
+
+ spv.func @baz() -> i32 "None" {
+ %0 = spv._reference_of @foo : !spv.array<2 x i32>
+ %1 = spv.CompositeExtract %0[0 : i32] : !spv.array<2 x i32>
+ spv.ReturnValue %1 : i32
+ }
+}
+
+spv.module Logical GLSL450 {
+ spv.func @foo(%arg0 : i32) -> i32 "None" {
+ spv.ReturnValue %arg0 : i32
+ }
+}
+}
+
+// -----
+
+// Resolve conflicting spec constants and funcOps and update the spec constant's
+// references.
+
+// CHECK: module {
+// CHECK-NEXT: spv.module Logical GLSL450 {
+// CHECK-NEXT: spv.specConstant @bar_1
+// CHECK-NEXT: spv.specConstantComposite @foo_2 (@bar_1, @bar_1)
+// CHECK-NEXT: spv.func @baz
+// CHECK-NEXT: spv._reference_of @foo_2
+// CHECK-NEXT: spv.CompositeExtract
+// CHECK-NEXT: spv.ReturnValue
+// CHECK-NEXT: }
+
+// CHECK-NEXT: spv.func @foo
+// CHECK-NEXT: spv.ReturnValue
+// CHECK-NEXT: }
+
+// CHECK-NEXT: spv.func @bar
+// CHECK-NEXT: spv.ReturnValue
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+
+module {
+spv.module Logical GLSL450 {
+ spv.specConstant @bar = -5 : i32
+ spv.specConstantComposite @foo (@bar, @bar) : !spv.array<2 x i32>
+
+ spv.func @baz() -> i32 "None" {
+ %0 = spv._reference_of @foo : !spv.array<2 x i32>
+ %1 = spv.CompositeExtract %0[0 : i32] : !spv.array<2 x i32>
+ spv.ReturnValue %1 : i32
+ }
+}
+
+spv.module Logical GLSL450 {
+ spv.func @foo(%arg0 : i32) -> i32 "None" {
+ spv.ReturnValue %arg0 : i32
+ }
+
+ spv.func @bar(%arg0 : f32) -> f32 "None" {
+ spv.ReturnValue %arg0 : f32
+ }
+}
+}
+
+// -----
+
+// Resolve conflicting globalVariableOps.
+
+// CHECK: module {
+// CHECK-NEXT: spv.module Logical GLSL450 {
+// CHECK-NEXT: spv.globalVariable @foo_1
+
+// CHECK-NEXT: spv.globalVariable @foo
+// CHECK-NEXT: }
+
+module {
+spv.module Logical GLSL450 {
+ spv.globalVariable @foo bind(1, 0) : !spv.ptr<f32, Input>
+}
+
+spv.module Logical GLSL450 {
+ spv.globalVariable @foo bind(1, 0) : !spv.ptr<f32, Input>
+}
+}
+
+// -----
+
+// Resolve conflicting globalVariableOp and specConstantOp.
+
+// CHECK: module {
+// CHECK-NEXT: spv.module Logical GLSL450 {
+// CHECK-NEXT: spv.globalVariable @foo_1
+
+// CHECK-NEXT: spv.specConstant @foo
+// CHECK-NEXT: }
+
+module {
+spv.module Logical GLSL450 {
+ spv.globalVariable @foo bind(1, 0) : !spv.ptr<f32, Input>
+}
+
+spv.module Logical GLSL450 {
+ spv.specConstant @foo = -5 : i32
+}
+}
+
+// -----
+
+// Resolve conflicting specConstantOp and globalVariableOp.
+
+// CHECK: module {
+// CHECK-NEXT: spv.module Logical GLSL450 {
+// CHECK-NEXT: spv.specConstant @foo_1
+
+// CHECK-NEXT: spv.globalVariable @foo
+// CHECK-NEXT: }
+
+module {
+spv.module Logical GLSL450 {
+ spv.specConstant @foo = -5 : i32
+}
+
+spv.module Logical GLSL450 {
+ spv.globalVariable @foo bind(1, 0) : !spv.ptr<f32, Input>
+}
+}
+
+// -----
+
+// Resolve conflicting globalVariableOp and specConstantCompositeOp.
+
+// CHECK: module {
+// CHECK-NEXT: spv.module Logical GLSL450 {
+// CHECK-NEXT: spv.globalVariable @foo_1
+
+// CHECK-NEXT: spv.specConstant @bar
+// CHECK-NEXT: spv.specConstantComposite @foo (@bar, @bar)
+// CHECK-NEXT: }
+
+module {
+spv.module Logical GLSL450 {
+ spv.globalVariable @foo bind(1, 0) : !spv.ptr<f32, Input>
+}
+
+spv.module Logical GLSL450 {
+ spv.specConstant @bar = -5 : i32
+ spv.specConstantComposite @foo (@bar, @bar) : !spv.array<2 x i32>
+}
+}
+
+// -----
+
+// Resolve conflicting specConstantCompositeOp and globalVariableOp.
+
+// CHECK: module {
+// CHECK-NEXT: spv.module Logical GLSL450 {
+// CHECK-NEXT: spv.specConstant @bar
+// CHECK-NEXT: spv.specConstantComposite @foo_1 (@bar, @bar)
+
+// CHECK-NEXT: spv.globalVariable @foo
+// CHECK-NEXT: }
+
+module {
+spv.module Logical GLSL450 {
+ spv.specConstant @bar = -5 : i32
+ spv.specConstantComposite @foo (@bar, @bar) : !spv.array<2 x i32>
+}
+
+spv.module Logical GLSL450 {
+ spv.globalVariable @foo bind(1, 0) : !spv.ptr<f32, Input>
+}
+}
diff --git a/mlir/test/lib/Dialect/SPIRV/CMakeLists.txt b/mlir/test/lib/Dialect/SPIRV/CMakeLists.txt
index 204a63337730..6c74d2f26357 100644
--- a/mlir/test/lib/Dialect/SPIRV/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/SPIRV/CMakeLists.txt
@@ -2,6 +2,7 @@
add_mlir_library(MLIRSPIRVTestPasses
TestAvailability.cpp
TestEntryPointAbi.cpp
+ TestModuleCombiner.cpp
EXCLUDE_FROM_LIBMLIR
@@ -14,5 +15,6 @@ add_mlir_library(MLIRSPIRVTestPasses
MLIRIR
MLIRPass
MLIRSPIRV
+ MLIRSPIRVModuleCombiner
MLIRSupport
)
diff --git a/mlir/test/lib/Dialect/SPIRV/TestModuleCombiner.cpp b/mlir/test/lib/Dialect/SPIRV/TestModuleCombiner.cpp
new file mode 100644
index 000000000000..b321954c87f3
--- /dev/null
+++ b/mlir/test/lib/Dialect/SPIRV/TestModuleCombiner.cpp
@@ -0,0 +1,48 @@
+//===- TestModuleCombiner.cpp - Pass to test SPIR-V module combiner lib ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SPIRV/ModuleCombiner.h"
+
+#include "mlir/Dialect/SPIRV/SPIRVOps.h"
+#include "mlir/Dialect/SPIRV/SPIRVTypes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Module.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+namespace {
+class TestModuleCombinerPass
+ : public PassWrapper<TestModuleCombinerPass,
+ OperationPass<mlir::ModuleOp>> {
+public:
+ TestModuleCombinerPass() = default;
+ TestModuleCombinerPass(const TestModuleCombinerPass &) {}
+ void runOnOperation() override;
+
+private:
+ mlir::spirv::OwningSPIRVModuleRef combinedModule;
+};
+} // namespace
+
+/// Combines all spv.module ops directly inside the top-level module into a
+/// single spv.module (inserted before the first input) and erases the inputs.
+void TestModuleCombinerPass::runOnOperation() {
+  auto modules = llvm::to_vector<4>(getOperation().getOps<spirv::ModuleOp>());
+
+  // Nothing to combine; also avoids indexing modules[0] on empty input.
+  if (modules.empty())
+    return;
+
+  OpBuilder combinedModuleBuilder(modules[0]);
+  combinedModule = spirv::combine(modules, combinedModuleBuilder, nullptr);
+
+  for (spirv::ModuleOp module : modules)
+    module.erase();
+
+  // The combined module was created through the builder, so it is already
+  // owned by the enclosing mlir::ModuleOp. Drop the owning reference so the
+  // op is not erased a second time when this pass is destroyed.
+  combinedModule.release();
+}
+
+namespace mlir {
+/// Registers the `test-spirv-module-combiner` pass so that mlir-opt can run it.
+void registerTestSpirvModuleCombinerPass() {
+  PassRegistration<TestModuleCombinerPass> combinerPassRegistration{
+      "test-spirv-module-combiner", "Tests SPIR-V module combiner library"};
+}
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 196bda69dbaf..b5506a5a34a0 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -79,6 +79,7 @@ void registerTestPrintNestingPass();
void registerTestRecursiveTypesPass();
void registerTestReducer();
void registerTestSpirvEntryPointABIPass();
+void registerTestSpirvModuleCombinerPass();
void registerTestSCFUtilsPass();
void registerTestTraitsPass();
void registerTestVectorConversions();
@@ -140,6 +141,7 @@ void registerTestPasses() {
registerTestReducer();
registerTestGpuParallelLoopMappingPass();
registerTestSpirvEntryPointABIPass();
+ registerTestSpirvModuleCombinerPass();
registerTestSCFUtilsPass();
registerTestTraitsPass();
registerTestVectorConversions();
More information about the Mlir-commits
mailing list