[Mlir-commits] [mlir] 23326b9 - [mlir][spirv] Fix a few issues in ModuleCombiner
Lei Zhang
llvmlistbot at llvm.org
Wed Jul 28 07:36:26 PDT 2021
Author: Lei Zhang
Date: 2021-07-28T10:31:01-04:00
New Revision: 23326b9f1723a398681def87c80e608fa94485f2
URL: https://github.com/llvm/llvm-project/commit/23326b9f1723a398681def87c80e608fa94485f2
DIFF: https://github.com/llvm/llvm-project/commit/23326b9f1723a398681def87c80e608fa94485f2.diff
LOG: [mlir][spirv] Fix a few issues in ModuleCombiner
- Fixed symbol insertion into `symNameToModuleMap`. Insertion
needs to happen whether symbols are renamed or not.
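- Added a check that all input modules have the same VCE triple, and
carried the triple over to the combined module instead of dropping it.
- Disabled function deduplication. It requires more careful
rules: right now it can wrongly merge functions that use the same ops
but with different operands.
- Added tests for the symbol rename listener.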
- And some other code/comment cleanups.
Reviewed By: ergawy
Differential Revision: https://reviews.llvm.org/D106886
Added:
mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/symbol-rename-listener.mlir
Modified:
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.h
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
mlir/include/mlir/Dialect/SPIRV/Linking/ModuleCombiner.h
mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp
mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/basic.mlir
mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/conflict-resolution.mlir
mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/deduplication.mlir
mlir/test/lib/Dialect/SPIRV/TestModuleCombiner.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.h
index 410fc946dd08..c4f5fd1192c4 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.h
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.h
@@ -13,6 +13,7 @@
#ifndef MLIR_DIALECT_SPIRV_IR_SPIRVOPS_H_
#define MLIR_DIALECT_SPIRV_IR_SPIRVOPS_H_
+#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOpTraits.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/IR/BuiltinOps.h"
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
index 683c6dab183f..bb4e558cde54 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
@@ -467,8 +467,9 @@ def SPV_ModuleOp : SPV_Op<"module",
let builders = [
OpBuilder<(ins CArg<"Optional<StringRef>", "llvm::None">:$name)>,
OpBuilder<(ins "spirv::AddressingModel":$addressing_model,
- "spirv::MemoryModel":$memory_model,
- CArg<"Optional<StringRef>", "llvm::None">:$name)>
+ "spirv::MemoryModel":$memory_model,
+ CArg<"Optional<spirv::VerCapExtAttr>", "llvm::None">:$vce_triple,
+ CArg<"Optional<StringRef>", "llvm::None">:$name)>
];
// We need to ensure the block inside the region is properly terminated;
diff --git a/mlir/include/mlir/Dialect/SPIRV/Linking/ModuleCombiner.h b/mlir/include/mlir/Dialect/SPIRV/Linking/ModuleCombiner.h
index c2eaef037d10..833da17452f1 100644
--- a/mlir/include/mlir/Dialect/SPIRV/Linking/ModuleCombiner.h
+++ b/mlir/include/mlir/Dialect/SPIRV/Linking/ModuleCombiner.h
@@ -22,53 +22,54 @@ class OpBuilder;
namespace spirv {
class ModuleOp;
-/// To combine a number of MLIR SPIR-V modules, we move all the module-level ops
+/// The listener function to receive symbol renaming events.
+///
+/// `originalModule` is the input spirv::ModuleOp that contains the renamed
+/// symbol. `oldSymbol` and `newSymbol` are its names before/after renaming.
+/// Note that it's the responsibility of the caller to properly retain the
+/// storage underlying the passed StringRefs if the listener callback outlives
+/// this function call.
+using SymbolRenameListener = function_ref<void(
+ spirv::ModuleOp originalModule, StringRef oldSymbol, StringRef newSymbol)>;
+
+/// Combines a list of SPIR-V `inputModules` into one. Returns the combined
+/// module on success; returns a null module otherwise.
+///
+/// \param inputModules the list of modules to combine. They won't be modified.
+/// \param combinedModuleBuilder an OpBuilder for building the combined module.
+/// \param symRenameListener a listener that gets called every time a symbol in
+/// one of the input modules is renamed.
+///
+/// To combine multiple SPIR-V modules, we move all the module-level ops
/// from all the input modules into one big combined module. To that end, the
/// combination process proceeds in 2 phases:
///
-/// (1) resolve conflicts between pairs of ops from different modules
-/// (2) deduplicate equivalent ops/sub-ops in the merged module.
+/// 1. resolve conflicts between pairs of ops from different modules,
+/// 2. deduplicate equivalent ops/sub-ops in the merged module.
///
/// For the conflict resolution phase, the following rules are employed to
/// resolve such conflicts:
///
-/// - If 2 spv.func's have the same symbol name, then rename one of the
+/// - If 2 spv.func's have the same symbol name, then rename one of the
/// functions.
-/// - If an spv.func and another op have the same symbol name, then rename the
+/// - If an spv.func and another op have the same symbol name, then rename the
/// other symbol.
-/// - If none of the 2 conflicting ops are spv.func, then rename either.
+/// - If none of the 2 conflicting ops are spv.func, then rename either.
///
/// For deduplication, the following 3 cases are taken into consideration:
///
-/// - If 2 spv.GlobalVariable's have either the same descriptor set + binding
+/// - If 2 spv.GlobalVariable's have either the same descriptor set + binding
/// or the same build_in attribute value, then replace one of them using the
/// other.
-/// - If 2 spv.SpecConstant's have the same spec_id attribute value, then
+/// - If 2 spv.SpecConstant's have the same spec_id attribute value, then
/// replace one of them using the other.
-/// - If 2 spv.func's are identical replace one of them using the other.
+/// - Function deduplication is not supported right now.
///
/// In all cases, the references to the updated symbol (whether renamed or
/// deduplicated) are also updated to reflect the change.
-///
-/// \param modules the list of modules to combine. Input modules are not
-/// modified.
-/// \param combinedMdouleBuilder an OpBuilder to be used for
-// building up the combined module.
-/// \param symbRenameListener a listener that gets called everytime a symbol in
-/// one of the input modules is renamed. The arguments
-/// passed to the listener are: the input
-/// spirv::ModuleOp that contains the renamed symbol,
-/// a StringRef to the old symbol name, and a
-/// StringRef to the new symbol name. Note that it is
-/// the responsibility of the caller to properly
-/// retain the storage underlying the passed
-/// StringRefs if the listener callback outlives this
-/// function call.
-///
-/// \return the combined module.
-OwningOpRef<spirv::ModuleOp>
-combine(MutableArrayRef<ModuleOp> modules, OpBuilder &combinedModuleBuilder,
- function_ref<void(ModuleOp, StringRef, StringRef)> symbRenameListener);
+OwningOpRef<spirv::ModuleOp> combine(ArrayRef<spirv::ModuleOp> inputModules,
+ OpBuilder &combinedModuleBuilder,
+ SymbolRenameListener symRenameListener);
} // namespace spirv
} // namespace mlir
diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
index 5aaec815a83b..79f68e2e476f 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
@@ -310,7 +310,7 @@ LogicalResult GPUModuleConversion::matchAndRewrite(
// Add a keyword to the module name to avoid symbolic conflict.
std::string spvModuleName = (kSPIRVModule + moduleOp.getName()).str();
auto spvModule = rewriter.create<spirv::ModuleOp>(
- moduleOp.getLoc(), addressingModel, memoryModel.getValue(),
+ moduleOp.getLoc(), addressingModel, memoryModel.getValue(), llvm::None,
StringRef(spvModuleName));
// Move the region from the module op into the SPIR-V module.
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 471fa2c1b4a3..c03388d266b5 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -2540,6 +2540,7 @@ void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
spirv::AddressingModel addressingModel,
spirv::MemoryModel memoryModel,
+ Optional<VerCapExtAttr> vceTriple,
Optional<StringRef> name) {
state.addAttribute(
"addressing_model",
@@ -2548,10 +2549,11 @@ void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
static_cast<int32_t>(memoryModel)));
OpBuilder::InsertionGuard guard(builder);
builder.createBlock(state.addRegion());
- if (name) {
- state.attributes.append(mlir::SymbolTable::getSymbolAttrName(),
- builder.getStringAttr(*name));
- }
+ if (vceTriple)
+ state.addAttribute(getVCETripleAttrName(), *vceTriple);
+ if (name)
+ state.addAttribute(mlir::SymbolTable::getSymbolAttrName(),
+ builder.getStringAttr(*name));
}
static ParseResult parseModuleOp(OpAsmParser &parser, OperationState &state) {
diff --git a/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp b/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp
index 5fc948a07090..1007603d3cf2 100644
--- a/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp
+++ b/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp
@@ -12,27 +12,33 @@
#include "mlir/Dialect/SPIRV/Linking/ModuleCombiner.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/SymbolTable.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/Hashing.h"
+#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringExtras.h"
+#include "llvm/ADT/StringMap.h"
using namespace mlir;
static constexpr unsigned maxFreeID = 1 << 20;
+/// Returns an unused symbol name in `module` for `oldSymName` by appending `_`
+/// and an increasing numeric suffix, starting after `lastUsedID`.
static SmallString<64> renameSymbol(StringRef oldSymName, unsigned &lastUsedID,
- spirv::ModuleOp combinedModule) {
+ spirv::ModuleOp module) {
SmallString<64> newSymName(oldSymName);
newSymName.push_back('_');
while (lastUsedID < maxFreeID) {
std::string possible = (newSymName + llvm::utostr(++lastUsedID)).str();
- if (!SymbolTable::lookupSymbolIn(combinedModule, possible)) {
+ if (!SymbolTable::lookupSymbolIn(module, possible)) {
newSymName += llvm::utostr(lastUsedID);
break;
}
@@ -41,8 +47,8 @@ static SmallString<64> renameSymbol(StringRef oldSymName, unsigned &lastUsedID,
return newSymName;
}
-/// Check if a symbol with the same name as op already exists in source. If so,
-/// rename op and update all its references in target.
+/// Checks if a symbol with the same name as `op` already exists in `source`.
+/// If so, renames `op` and updates all its references in `target`.
static LogicalResult updateSymbolAndAllUses(SymbolOpInterface op,
spirv::ModuleOp target,
spirv::ModuleOp source,
@@ -61,99 +67,67 @@ static LogicalResult updateSymbolAndAllUses(SymbolOpInterface op,
return success();
}
-template <typename KeyTy, typename SymbolOpTy>
-static SymbolOpTy
-emplaceOrGetReplacementSymbol(KeyTy key, SymbolOpTy symbolOp,
- DenseMap<KeyTy, SymbolOpTy> &deduplicationMap) {
- auto result = deduplicationMap.try_emplace(key, symbolOp);
-
- if (result.second)
- return SymbolOpTy();
-
- return result.first->second;
-}
-
-/// Computes a hash code to represent the argument SymbolOpInterface based on
-/// all the Op's attributes except for the symbol name.
-///
-/// \return the hash code computed from the Op's attributes as described above.
+/// Computes a hash code to represent `symbolOp` based on all its attributes
+/// except for the symbol name.
///
/// Note: We use the operation's name (not the symbol name) as part of the hash
/// computation. This prevents, for example, mistakenly considering a global
/// variable and a spec constant as duplicates because their descriptor set +
/// binding and spec_id, respectively, happen to hash to the same value.
static llvm::hash_code computeHash(SymbolOpInterface symbolOp) {
- llvm::hash_code hashCode(0);
- hashCode = llvm::hash_combine(symbolOp->getName());
-
- for (auto attr : symbolOp->getAttrs()) {
- if (attr.first == SymbolTable::getSymbolAttrName())
- continue;
- hashCode = llvm::hash_combine(hashCode, attr);
- }
-
- return hashCode;
-}
-
-/// Computes a hash code from the argument Block.
-llvm::hash_code computeHash(Block *block) {
- // TODO: Consider extracting BlockEquivalenceData into a common header and
- // re-using it here.
- llvm::hash_code hash(0);
-
- for (Operation &op : *block) {
- // TODO: Properly handle operations with regions.
- if (op.getNumRegions() > 0)
- return 0;
-
- hash = llvm::hash_combine(
- hash, OperationEquivalence::computeHash(
- &op, OperationEquivalence::Flags::IgnoreOperands));
- }
-
- return hash;
+ auto range =
+ llvm::make_filter_range(symbolOp->getAttrs(), [](NamedAttribute attr) {
+ return attr.first != SymbolTable::getSymbolAttrName();
+ });
+
+ return llvm::hash_combine(
+ symbolOp->getName(),
+ llvm::hash_combine_range(range.begin(), range.end()));
}
namespace mlir {
namespace spirv {
-// TODO Properly test symbol rename listener mechanism.
-
-OwningOpRef<spirv::ModuleOp>
-combine(llvm::MutableArrayRef<spirv::ModuleOp> modules,
- OpBuilder &combinedModuleBuilder,
- llvm::function_ref<void(ModuleOp, StringRef, StringRef)>
- symRenameListener) {
- unsigned lastUsedID = 0;
-
- if (modules.empty())
+OwningOpRef<spirv::ModuleOp> combine(ArrayRef<spirv::ModuleOp> inputModules,
+ OpBuilder &combinedModuleBuilder,
+ SymbolRenameListener symRenameListener) {
+ if (inputModules.empty())
return nullptr;
- auto addressingModel = modules[0].addressing_model();
- auto memoryModel = modules[0].memory_model();
+ spirv::ModuleOp firstModule = inputModules.front();
+ auto addressingModel = firstModule.addressing_model();
+ auto memoryModel = firstModule.memory_model();
+ auto vceTriple = firstModule.vce_triple();
+
+ // First check whether the input modules conflict in addressing model, memory
+ // model, or VCE triple. Return early if so.
+ for (auto module : inputModules) {
+ if (module.addressing_model() != addressingModel ||
+ module.memory_model() != memoryModel ||
+ module.vce_triple() != vceTriple) {
+ module.emitError("input modules differ in addressing model, memory "
+ "model, and/or VCE triple");
+ return nullptr;
+ }
+ }
auto combinedModule = combinedModuleBuilder.create<spirv::ModuleOp>(
- modules[0].getLoc(), addressingModel, memoryModel);
+ firstModule.getLoc(), addressingModel, memoryModel, vceTriple);
combinedModuleBuilder.setInsertionPointToStart(combinedModule.getBody());
// In some cases, a symbol in the (current state of the) combined module is
- // renamed in order to maintain the conflicting symbol in the input module
+ // renamed in order to enable the conflicting symbol in the input module
// being merged. For example, if the conflict is between a global variable in
// the current combined module and a function in the input module, the global
// variable is renamed. In order to notify listeners of the symbol updates in
// such cases, we need to keep track of the module from which the renamed
// symbol in the combined module originated. This map keeps such information.
- DenseMap<StringRef, spirv::ModuleOp> symNameToModuleMap;
+ llvm::StringMap<spirv::ModuleOp> symNameToModuleMap;
- for (auto module : modules) {
- if (module.addressing_model() != addressingModel ||
- module.memory_model() != memoryModel) {
- module.emitError(
- "input modules differ in addressing model and/or memory model");
- return nullptr;
- }
+ unsigned lastUsedID = 0;
- spirv::ModuleOp moduleClone = module.clone();
+ for (auto inputModule : inputModules) {
+ spirv::ModuleOp moduleClone = inputModule.clone();
// In the combined module, rename all symbols that conflict with symbols
// from the current input module. This renaming applies to all ops except
@@ -161,65 +135,70 @@ combine(llvm::MutableArrayRef<spirv::ModuleOp> modules,
// non-spv.func, we rename that symbol instead and maintain the spv.func in
// the combined module name as it is.
for (auto &op : *combinedModule.getBody()) {
- if (auto symbolOp = dyn_cast<SymbolOpInterface>(op)) {
- StringRef oldSymName = symbolOp.getName();
+ auto symbolOp = dyn_cast<SymbolOpInterface>(op);
+ if (!symbolOp)
+ continue;
- if (!isa<FuncOp>(op) &&
- failed(updateSymbolAndAllUses(symbolOp, combinedModule, moduleClone,
- lastUsedID)))
- return nullptr;
+ StringRef oldSymName = symbolOp.getName();
- StringRef newSymName = symbolOp.getName();
+ if (!isa<FuncOp>(op) &&
+ failed(updateSymbolAndAllUses(symbolOp, combinedModule, moduleClone,
+ lastUsedID)))
+ return nullptr;
- if (symRenameListener && oldSymName != newSymName) {
- spirv::ModuleOp originalModule =
- symNameToModuleMap.lookup(oldSymName);
+ StringRef newSymName = symbolOp.getName();
- if (!originalModule) {
- module.emitError("unable to find original ModuleOp for symbol ")
- << oldSymName;
- return nullptr;
- }
+ if (symRenameListener && oldSymName != newSymName) {
+ spirv::ModuleOp originalModule = symNameToModuleMap.lookup(oldSymName);
- symRenameListener(originalModule, oldSymName, newSymName);
-
- // Since the symbol name is updated, there is no need to maintain the
- // entry that associates the old symbol name with the original module.
- symNameToModuleMap.erase(oldSymName);
- // Instead, add a new entry to map the new symbol name to the original
- // module in case it gets renamed again later.
- symNameToModuleMap[newSymName] = originalModule;
+ if (!originalModule) {
+ inputModule.emitError(
+ "unable to find original spirv::ModuleOp for symbol ")
+ << oldSymName;
+ return nullptr;
}
+
+ symRenameListener(originalModule, oldSymName, newSymName);
+
+ // Since the symbol name is updated, there is no need to maintain the
+ // entry that associates the old symbol name with the original module.
+ symNameToModuleMap.erase(oldSymName);
+ // Instead, add a new entry to map the new symbol name to the original
+ // module in case it gets renamed again later.
+ symNameToModuleMap[newSymName] = originalModule;
}
}
// In the current input module, rename all symbols that conflict with
// symbols from the combined module. This includes renaming spv.funcs.
for (auto &op : *moduleClone.getBody()) {
- if (auto symbolOp = dyn_cast<SymbolOpInterface>(op)) {
- StringRef oldSymName = symbolOp.getName();
+ auto symbolOp = dyn_cast<SymbolOpInterface>(op);
+ if (!symbolOp)
+ continue;
- if (failed(updateSymbolAndAllUses(symbolOp, moduleClone, combinedModule,
- lastUsedID)))
- return nullptr;
+ StringRef oldSymName = symbolOp.getName();
- StringRef newSymName = symbolOp.getName();
+ if (failed(updateSymbolAndAllUses(symbolOp, moduleClone, combinedModule,
+ lastUsedID)))
+ return nullptr;
- if (symRenameListener && oldSymName != newSymName) {
- symRenameListener(module, oldSymName, newSymName);
+ StringRef newSymName = symbolOp.getName();
- // Insert the module associated with the symbol name.
- auto emplaceResult =
- symNameToModuleMap.try_emplace(symbolOp.getName(), module);
+ if (symRenameListener) {
+ if (oldSymName != newSymName)
+ symRenameListener(inputModule, oldSymName, newSymName);
- // If an entry with the same symbol name is already present, this must
- // be a problem with the implementation, specially clean-up of the map
- // while iterating over the combined module above.
- if (!emplaceResult.second) {
- module.emitError("did not expect to find an entry for symbol ")
- << symbolOp.getName();
- return nullptr;
- }
+ // Insert the module associated with the symbol name.
+ auto emplaceResult =
+ symNameToModuleMap.try_emplace(newSymName, inputModule);
+
+ // If an entry with the same symbol name is already present, this must
+ // be a problem with the implementation, specially clean-up of the map
+ // while iterating over the combined module above.
+ if (!emplaceResult.second) {
+ inputModule.emitError("did not expect to find an entry for symbol ")
+ << symbolOp.getName();
+ return nullptr;
}
}
}
@@ -234,30 +213,26 @@ combine(llvm::MutableArrayRef<spirv::ModuleOp> modules,
SmallVector<SymbolOpInterface, 0> eraseList;
for (auto &op : *combinedModule.getBody()) {
- llvm::hash_code hashCode(0);
SymbolOpInterface symbolOp = dyn_cast<SymbolOpInterface>(op);
-
if (!symbolOp)
continue;
- hashCode = computeHash(symbolOp);
-
- // A 0 hash code means the op is not suitable for deduplication and should
- // be skipped. An example of this is when a function has ops with regions
- // which are not properly supported yet.
- if (!hashCode)
+ // Skip ops with operands or results; they are not supported.
+ // Global variables, spec constants, and functions have no
+ // operands/results, but check anyway for safety.
+ if (op.getNumOperands() != 0 || op.getNumResults() != 0)
continue;
- if (auto funcOp = dyn_cast<FuncOp>(op))
- for (auto &blk : funcOp)
- hashCode = llvm::hash_combine(hashCode, computeHash(&blk));
-
- SymbolOpInterface replacementSymOp =
- emplaceOrGetReplacementSymbol(hashCode, symbolOp, hashToSymbolOp);
+ // Function deduplication is not supported yet.
+ if (isa<FuncOp>(op))
+ continue;
- if (!replacementSymOp)
+ auto result = hashToSymbolOp.try_emplace(computeHash(symbolOp), symbolOp);
+ if (result.second)
continue;
+ SymbolOpInterface replacementSymOp = result.first->second;
+
if (failed(SymbolTable::replaceAllSymbolUses(
symbolOp, replacementSymOp.getName(), combinedModule))) {
symbolOp.emitError("unable to update all symbol uses for ")
diff --git a/mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/basic.mlir b/mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/basic.mlir
index af3cd51a67a8..0c9f6cafaf9f 100644
--- a/mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/basic.mlir
+++ b/mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/basic.mlir
@@ -1,9 +1,19 @@
// RUN: mlir-opt -test-spirv-module-combiner -split-input-file -verify-diagnostics %s | FileCheck %s
+// Combine modules that share no symbol names.
+
// CHECK: module {
// CHECK-NEXT: spv.module Logical GLSL450 {
// CHECK-NEXT: spv.SpecConstant @m1_sc
+// CHECK-NEXT: spv.GlobalVariable @m1_gv bind(1, 0)
+// CHECK-NEXT: spv.func @no_op
+// CHECK-NEXT: spv.Return
+// CHECK-NEXT: }
+// CHECK-NEXT: spv.EntryPoint "GLCompute" @no_op
+// CHECK-NEXT: spv.ExecutionMode @no_op "LocalSize", 32, 1, 1
+
// CHECK-NEXT: spv.SpecConstant @m2_sc
+// CHECK-NEXT: spv.GlobalVariable @m2_gv bind(0, 1)
// CHECK-NEXT: spv.func @variable_init_spec_constant
// CHECK-NEXT: spv.mlir.referenceof @m2_sc
// CHECK-NEXT: spv.Variable init
@@ -15,10 +25,17 @@
module {
spv.module Logical GLSL450 {
spv.SpecConstant @m1_sc = 42.42 : f32
+ spv.GlobalVariable @m1_gv bind(1, 0): !spv.ptr<f32, Input>
+ spv.func @no_op() -> () "None" {
+ spv.Return
+ }
+ spv.EntryPoint "GLCompute" @no_op
+ spv.ExecutionMode @no_op "LocalSize", 32, 1, 1
}
spv.module Logical GLSL450 {
spv.SpecConstant @m2_sc = 42 : i32
+ spv.GlobalVariable @m2_gv bind(0, 1): !spv.ptr<f32, Input>
spv.func @variable_init_spec_constant() -> () "None" {
%0 = spv.mlir.referenceof @m2_sc : i32
%1 = spv.Variable init(%0) : !spv.ptr<i32, Function>
@@ -33,7 +50,7 @@ module {
spv.module Physical64 GLSL450 {
}
-// expected-error @+1 {{input modules differ in addressing model and/or memory model}}
+// expected-error @+1 {{input modules differ in addressing model, memory model, and/or VCE triple}}
spv.module Logical GLSL450 {
}
}
@@ -44,7 +61,19 @@ module {
spv.module Logical Simple {
}
-// expected-error @+1 {{input modules differ in addressing model and/or memory model}}
+// expected-error @+1 {{input modules differ in addressing model, memory model, and/or VCE triple}}
+spv.module Logical GLSL450 {
+}
+}
+
+// -----
+
+module {
spv.module Logical GLSL450 {
}
+
+// expected-error @+1 {{input modules differ in addressing model, memory model, and/or VCE triple}}
+spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]> {
+}
}
+
diff --git a/mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/conflict-resolution.mlir b/mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/conflict-resolution.mlir
index 9123bf4478f3..f3995aa28ba1 100644
--- a/mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/conflict-resolution.mlir
+++ b/mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/conflict-resolution.mlir
@@ -215,7 +215,7 @@ spv.module Logical GLSL450 {
spv.func @foo(%arg0 : i32) -> i32 "None" {
spv.ReturnValue %arg0 : i32
}
-
+
spv.EntryPoint "GLCompute" @foo
spv.ExecutionMode @foo "ContractionOff"
}
@@ -383,7 +383,7 @@ spv.module Logical GLSL450 {
spv.SpecConstant @foo = -5 : i32
spv.func @bar() -> i32 "None" {
- %0 = spv.mlir.referenceof @foo : i32
+ %0 = spv.mlir.referenceof @foo : i32
spv.ReturnValue %0 : i32
}
}
diff --git a/mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/deduplication.mlir b/mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/deduplication.mlir
index 0bd76113f198..f7aab6e6f40a 100644
--- a/mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/deduplication.mlir
+++ b/mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/deduplication.mlir
@@ -21,7 +21,6 @@
// CHECK-NEXT: }
// CHECK-NEXT: }
-module {
spv.module Logical GLSL450 {
spv.GlobalVariable @foo bind(1, 0) : !spv.ptr<f32, Input>
@@ -42,7 +41,6 @@ spv.module Logical GLSL450 {
spv.ReturnValue %2 : f32
}
}
-}
// -----
@@ -62,7 +60,6 @@ spv.module Logical GLSL450 {
// CHECK-NEXT: }
// CHECK-NEXT: }
-module {
spv.module Logical GLSL450 {
spv.GlobalVariable @foo bind(1, 0) : !spv.ptr<i32, Input>
}
@@ -76,7 +73,6 @@ spv.module Logical GLSL450 {
spv.ReturnValue %1 : f32
}
}
-}
// -----
@@ -93,7 +89,6 @@ spv.module Logical GLSL450 {
// CHECK-NEXT: }
// CHECK-NEXT: }
-module {
spv.module Logical GLSL450 {
spv.GlobalVariable @foo built_in("GlobalInvocationId") : !spv.ptr<vector<3xi32>, Input>
}
@@ -107,10 +102,11 @@ spv.module Logical GLSL450 {
spv.ReturnValue %1 : vector<3xi32>
}
}
-}
// -----
+// Deduplicate 2 spec constants with the same spec ID.
+
// CHECK: module {
// CHECK-NEXT: spv.module Logical GLSL450 {
// CHECK-NEXT: spv.SpecConstant @foo spec_id(5)
@@ -128,7 +124,6 @@ spv.module Logical GLSL450 {
// CHECK-NEXT: }
// CHECK-NEXT: }
-module {
spv.module Logical GLSL450 {
spv.SpecConstant @foo spec_id(5) = 1. : f32
@@ -147,48 +142,82 @@ spv.module Logical GLSL450 {
spv.ReturnValue %1 : f32
}
}
+
+// -----
+
+// Don't deduplicate functions with similar ops but different operands.
+
+// CHECK: spv.module Logical GLSL450 {
+// CHECK-NEXT: spv.func @foo(%[[ARG0:.+]]: f32, %[[ARG1:.+]]: f32, %[[ARG2:.+]]: f32)
+// CHECK-NEXT: %[[ADD:.+]] = spv.FAdd %[[ARG0]], %[[ARG1]] : f32
+// CHECK-NEXT: %[[MUL:.+]] = spv.FMul %[[ADD]], %[[ARG2]] : f32
+// CHECK-NEXT: spv.ReturnValue %[[MUL]] : f32
+// CHECK-NEXT: }
+// CHECK-NEXT: spv.func @foo_1(%[[ARG0:.+]]: f32, %[[ARG1:.+]]: f32, %[[ARG2:.+]]: f32)
+// CHECK-NEXT: %[[ADD:.+]] = spv.FAdd %[[ARG0]], %[[ARG2]] : f32
+// CHECK-NEXT: %[[MUL:.+]] = spv.FMul %[[ADD]], %[[ARG1]] : f32
+// CHECK-NEXT: spv.ReturnValue %[[MUL]] : f32
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+
+spv.module Logical GLSL450 {
+ spv.func @foo(%a: f32, %b: f32, %c: f32) -> f32 "None" {
+ %add = spv.FAdd %a, %b: f32
+ %mul = spv.FMul %add, %c: f32
+ spv.ReturnValue %mul: f32
+ }
+}
+
+spv.module Logical GLSL450 {
+ spv.func @foo(%a: f32, %b: f32, %c: f32) -> f32 "None" {
+ %add = spv.FAdd %a, %c: f32
+ %mul = spv.FMul %add, %b: f32
+ spv.ReturnValue %mul: f32
+ }
}
// -----
-// CHECK: module {
-// CHECK-NEXT: spv.module Logical GLSL450 {
-// CHECK-NEXT: spv.SpecConstant @bar spec_id(5)
+// TODO: re-enable this test once we have better function deduplication.
-// CHECK-NEXT: spv.func @foo(%arg0: f32)
-// CHECK-NEXT: spv.ReturnValue
-// CHECK-NEXT: }
+// XXXXX: module {
+// XXXXX-NEXT: spv.module Logical GLSL450 {
+// XXXXX-NEXT: spv.SpecConstant @bar spec_id(5)
-// CHECK-NEXT: spv.func @foo_different_body(%arg0: f32)
-// CHECK-NEXT: spv.mlir.referenceof
-// CHECK-NEXT: spv.ReturnValue
-// CHECK-NEXT: }
+// XXXXX-NEXT: spv.func @foo(%arg0: f32)
+// XXXXX-NEXT: spv.ReturnValue
+// XXXXX-NEXT: }
-// CHECK-NEXT: spv.func @baz(%arg0: i32)
-// CHECK-NEXT: spv.ReturnValue
-// CHECK-NEXT: }
+// XXXXX-NEXT: spv.func @foo_different_body(%arg0: f32)
+// XXXXX-NEXT: spv.mlir.referenceof
+// XXXXX-NEXT: spv.ReturnValue
+// XXXXX-NEXT: }
-// CHECK-NEXT: spv.func @baz_no_return(%arg0: i32)
-// CHECK-NEXT: spv.Return
-// CHECK-NEXT: }
+// XXXXX-NEXT: spv.func @baz(%arg0: i32)
+// XXXXX-NEXT: spv.ReturnValue
+// XXXXX-NEXT: }
-// CHECK-NEXT: spv.func @baz_no_return_different_control
-// CHECK-NEXT: spv.Return
-// CHECK-NEXT: }
+// XXXXX-NEXT: spv.func @baz_no_return(%arg0: i32)
+// XXXXX-NEXT: spv.Return
+// XXXXX-NEXT: }
-// CHECK-NEXT: spv.func @baz_no_return_another_control
-// CHECK-NEXT: spv.Return
-// CHECK-NEXT: }
+// XXXXX-NEXT: spv.func @baz_no_return_different_control
+// XXXXX-NEXT: spv.Return
+// XXXXX-NEXT: }
-// CHECK-NEXT: spv.func @kernel
-// CHECK-NEXT: spv.Return
-// CHECK-NEXT: }
+// XXXXX-NEXT: spv.func @baz_no_return_another_control
+// XXXXX-NEXT: spv.Return
+// XXXXX-NEXT: }
-// CHECK-NEXT: spv.func @kernel_different_attr
-// CHECK-NEXT: spv.Return
-// CHECK-NEXT: }
-// CHECK-NEXT: }
-// CHECK-NEXT: }
+// XXXXX-NEXT: spv.func @kernel
+// XXXXX-NEXT: spv.Return
+// XXXXX-NEXT: }
+
+// XXXXX-NEXT: spv.func @kernel_different_attr
+// XXXXX-NEXT: spv.Return
+// XXXXX-NEXT: }
+// XXXXX-NEXT: }
+// XXXXX-NEXT: }
module {
spv.module Logical GLSL450 {
diff --git a/mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/symbol-rename-listener.mlir b/mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/symbol-rename-listener.mlir
new file mode 100644
index 000000000000..a9691bdeb6ff
--- /dev/null
+++ b/mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/symbol-rename-listener.mlir
@@ -0,0 +1,54 @@
+// RUN: mlir-opt -test-spirv-module-combiner -split-input-file -verify-diagnostics %s | FileCheck %s
+
+module {
+spv.module @Module1 Logical GLSL450 {
+ spv.GlobalVariable @foo bind(1, 0) : !spv.ptr<f32, Input>
+ spv.func @bar() -> () "None" {
+ spv.Return
+ }
+ spv.func @baz() -> () "None" {
+ spv.Return
+ }
+
+ spv.SpecConstant @sc = -5 : i32
+}
+
+spv.module @Module2 Logical GLSL450 {
+ spv.func @foo() -> () "None" {
+ spv.Return
+ }
+
+ spv.GlobalVariable @bar bind(1, 0) : !spv.ptr<f32, Input>
+
+ spv.func @baz() -> () "None" {
+ spv.Return
+ }
+
+ spv.SpecConstant @sc = -5 : i32
+}
+
+spv.module @Module3 Logical GLSL450 {
+ spv.func @foo() -> () "None" {
+ spv.Return
+ }
+
+ spv.GlobalVariable @bar bind(1, 0) : !spv.ptr<f32, Input>
+
+ spv.func @baz() -> () "None" {
+ spv.Return
+ }
+
+ spv.SpecConstant @sc = -5 : i32
+}
+}
+
+// CHECK: [Module1] foo -> foo_1
+// CHECK: [Module1] sc -> sc_2
+
+// CHECK: [Module2] bar -> bar_3
+// CHECK: [Module2] baz -> baz_4
+// CHECK: [Module2] sc -> sc_5
+
+// CHECK: [Module3] foo -> foo_6
+// CHECK: [Module3] bar -> bar_7
+// CHECK: [Module3] baz -> baz_8
diff --git a/mlir/test/lib/Dialect/SPIRV/TestModuleCombiner.cpp b/mlir/test/lib/Dialect/SPIRV/TestModuleCombiner.cpp
index d7675012a77c..12054b1dc454 100644
--- a/mlir/test/lib/Dialect/SPIRV/TestModuleCombiner.cpp
+++ b/mlir/test/lib/Dialect/SPIRV/TestModuleCombiner.cpp
@@ -37,7 +37,14 @@ void TestModuleCombinerPass::runOnOperation() {
auto modules = llvm::to_vector<4>(getOperation().getOps<spirv::ModuleOp>());
OpBuilder combinedModuleBuilder(modules[0]);
- combinedModule = spirv::combine(modules, combinedModuleBuilder, nullptr);
+
+ auto listener = [](spirv::ModuleOp originalModule, StringRef oldSymbol,
+ StringRef newSymbol) {
+ llvm::outs() << "[" << originalModule.getName() << "] " << oldSymbol
+ << " -> " << newSymbol << "\n";
+ };
+
+ combinedModule = spirv::combine(modules, combinedModuleBuilder, listener);
for (spirv::ModuleOp module : modules)
module.erase();
More information about the Mlir-commits
mailing list