[Mlir-commits] [mlir] 23326b9 - [mlir][spirv] Fix a few issues in ModuleCombiner

Lei Zhang llvmlistbot at llvm.org
Wed Jul 28 07:36:26 PDT 2021


Author: Lei Zhang
Date: 2021-07-28T10:31:01-04:00
New Revision: 23326b9f1723a398681def87c80e608fa94485f2

URL: https://github.com/llvm/llvm-project/commit/23326b9f1723a398681def87c80e608fa94485f2
DIFF: https://github.com/llvm/llvm-project/commit/23326b9f1723a398681def87c80e608fa94485f2.diff

LOG: [mlir][spirv] Fix a few issues in ModuleCombiner

- Fixed symbol insertion into `symNameToModuleMap`. Insertion
  needs to happen whether symbols are renamed or not.
- Added check for the VCE triple and avoid dropping it.
- Disabled function deduplication. It requires more careful
  rules. Right now it can remove different functions.
- Added tests for symbol rename listener.
- And some other code/comment cleanups.

Reviewed By: ergawy

Differential Revision: https://reviews.llvm.org/D106886

Added: 
    mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/symbol-rename-listener.mlir

Modified: 
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.h
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
    mlir/include/mlir/Dialect/SPIRV/Linking/ModuleCombiner.h
    mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
    mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
    mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp
    mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/basic.mlir
    mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/conflict-resolution.mlir
    mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/deduplication.mlir
    mlir/test/lib/Dialect/SPIRV/TestModuleCombiner.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.h
index 410fc946dd08..c4f5fd1192c4 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.h
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.h
@@ -13,6 +13,7 @@
 #ifndef MLIR_DIALECT_SPIRV_IR_SPIRVOPS_H_
 #define MLIR_DIALECT_SPIRV_IR_SPIRVOPS_H_
 
+#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVOpTraits.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
 #include "mlir/IR/BuiltinOps.h"

diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
index 683c6dab183f..bb4e558cde54 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
@@ -467,8 +467,9 @@ def SPV_ModuleOp : SPV_Op<"module",
   let builders = [
     OpBuilder<(ins CArg<"Optional<StringRef>", "llvm::None">:$name)>,
     OpBuilder<(ins "spirv::AddressingModel":$addressing_model,
-      "spirv::MemoryModel":$memory_model,
-      CArg<"Optional<StringRef>", "llvm::None">:$name)>
+                   "spirv::MemoryModel":$memory_model,
+                   CArg<"Optional<spirv::VerCapExtAttr>", "llvm::None">:$vce_triple,
+                   CArg<"Optional<StringRef>", "llvm::None">:$name)>
   ];
 
   // We need to ensure the block inside the region is properly terminated;

diff  --git a/mlir/include/mlir/Dialect/SPIRV/Linking/ModuleCombiner.h b/mlir/include/mlir/Dialect/SPIRV/Linking/ModuleCombiner.h
index c2eaef037d10..833da17452f1 100644
--- a/mlir/include/mlir/Dialect/SPIRV/Linking/ModuleCombiner.h
+++ b/mlir/include/mlir/Dialect/SPIRV/Linking/ModuleCombiner.h
@@ -22,53 +22,54 @@ class OpBuilder;
 namespace spirv {
 class ModuleOp;
 
-/// To combine a number of MLIR SPIR-V modules, we move all the module-level ops
+/// The listener function to receive symbol renaming events.
+///
+/// `originalModule` is the input spirv::ModuleOp that contains the renamed
+/// symbol. `oldSymbol` and `newSymbol` are the original and renamed symbol.
+/// Note that it's the responsibility of the caller to properly retain the
+/// storage underlying the passed StringRefs if the listener callback outlives
+/// this function call.
+using SymbolRenameListener = function_ref<void(
+    spirv::ModuleOp originalModule, StringRef oldSymbol, StringRef newSymbol)>;
+
+/// Combines a list of SPIR-V `inputModules` into one. Returns the combined
+/// module on success; returns a null module otherwise.
+///
+/// \param inputModules the list of modules to combine. They won't be modified.
+/// \param combinedModuleBuilder an OpBuilder for building the combined module.
+/// \param symRenameListener a listener that gets called every time a symbol in
+///                          one of the input modules is renamed.
+///
+/// To combine multiple SPIR-V modules, we move all the module-level ops
 /// from all the input modules into one big combined module. To that end, the
 /// combination process proceeds in 2 phases:
 ///
-///   (1) resolve conflicts between pairs of ops from different modules
-///   (2) deduplicate equivalent ops/sub-ops in the merged module.
+/// 1. resolve conflicts between pairs of ops from different modules,
+/// 2. deduplicate equivalent ops/sub-ops in the merged module.
 ///
 /// For the conflict resolution phase, the following rules are employed to
 /// resolve such conflicts:
 ///
-///   - If 2 spv.func's have the same symbol name, then rename one of the
+/// - If 2 spv.func's have the same symbol name, then rename one of the
 ///   functions.
-///   - If an spv.func and another op have the same symbol name, then rename the
+/// - If an spv.func and another op have the same symbol name, then rename the
 ///   other symbol.
-///   - If none of the 2 conflicting ops are spv.func, then rename either.
+/// - If none of the 2 conflicting ops are spv.func, then rename either.
 ///
 /// For deduplication, the following 3 cases are taken into consideration:
 ///
-///   - If 2 spv.GlobalVariable's have either the same descriptor set + binding
+/// - If 2 spv.GlobalVariable's have either the same descriptor set + binding
 ///   or the same build_in attribute value, then replace one of them using the
 ///   other.
-///   - If 2 spv.SpecConstant's have the same spec_id attribute value, then
+/// - If 2 spv.SpecConstant's have the same spec_id attribute value, then
 ///   replace one of them using the other.
-///   - If 2 spv.func's are identical replace one of them using the other.
+/// - Deduplicating functions is not supported right now.
 ///
 /// In all cases, the references to the updated symbol (whether renamed or
 /// deduplicated) are also updated to reflect the change.
-///
-/// \param modules the list of modules to combine. Input modules are not
-/// modified.
-/// \param combinedMdouleBuilder an OpBuilder to be used for
-//                               building up the combined module.
-/// \param symbRenameListener a listener that gets called everytime a symbol in
-///                           one of the input modules is renamed. The arguments
-///                           passed to the listener are: the input
-///                           spirv::ModuleOp that contains the renamed symbol,
-///                           a StringRef to the old symbol name, and a
-///                           StringRef to the new symbol name. Note that it is
-///                           the responsibility of the caller to properly
-///                           retain the storage underlying the passed
-///                           StringRefs if the listener callback outlives this
-///                           function call.
-///
-/// \return the combined module.
-OwningOpRef<spirv::ModuleOp>
-combine(MutableArrayRef<ModuleOp> modules, OpBuilder &combinedModuleBuilder,
-        function_ref<void(ModuleOp, StringRef, StringRef)> symbRenameListener);
+OwningOpRef<spirv::ModuleOp> combine(ArrayRef<spirv::ModuleOp> inputModules,
+                                     OpBuilder &combinedModuleBuilder,
+                                     SymbolRenameListener symRenameListener);
 } // namespace spirv
 } // namespace mlir
 

diff  --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
index 5aaec815a83b..79f68e2e476f 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
@@ -310,7 +310,7 @@ LogicalResult GPUModuleConversion::matchAndRewrite(
   // Add a keyword to the module name to avoid symbolic conflict.
   std::string spvModuleName = (kSPIRVModule + moduleOp.getName()).str();
   auto spvModule = rewriter.create<spirv::ModuleOp>(
-      moduleOp.getLoc(), addressingModel, memoryModel.getValue(),
+      moduleOp.getLoc(), addressingModel, memoryModel.getValue(), llvm::None,
       StringRef(spvModuleName));
 
   // Move the region from the module op into the SPIR-V module.

diff  --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 471fa2c1b4a3..c03388d266b5 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -2540,6 +2540,7 @@ void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
 void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
                             spirv::AddressingModel addressingModel,
                             spirv::MemoryModel memoryModel,
+                            Optional<VerCapExtAttr> vceTriple,
                             Optional<StringRef> name) {
   state.addAttribute(
       "addressing_model",
@@ -2548,10 +2549,11 @@ void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
                                          static_cast<int32_t>(memoryModel)));
   OpBuilder::InsertionGuard guard(builder);
   builder.createBlock(state.addRegion());
-  if (name) {
-    state.attributes.append(mlir::SymbolTable::getSymbolAttrName(),
-                            builder.getStringAttr(*name));
-  }
+  if (vceTriple)
+    state.addAttribute(getVCETripleAttrName(), *vceTriple);
+  if (name)
+    state.addAttribute(mlir::SymbolTable::getSymbolAttrName(),
+                       builder.getStringAttr(*name));
 }
 
 static ParseResult parseModuleOp(OpAsmParser &parser, OperationState &state) {

diff  --git a/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp b/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp
index 5fc948a07090..1007603d3cf2 100644
--- a/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp
+++ b/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp
@@ -12,27 +12,33 @@
 
 #include "mlir/Dialect/SPIRV/Linking/ModuleCombiner.h"
 
+#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+#include "mlir/IR/Attributes.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/SymbolTable.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/Hashing.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/StringExtras.h"
+#include "llvm/ADT/StringMap.h"
 
 using namespace mlir;
 
 static constexpr unsigned maxFreeID = 1 << 20;
 
+/// Returns an unused symbol name in `module` for `oldSymName`, formed by
+/// appending increasing numeric suffixes starting after `lastUsedID`.
 static SmallString<64> renameSymbol(StringRef oldSymName, unsigned &lastUsedID,
-                                    spirv::ModuleOp combinedModule) {
+                                    spirv::ModuleOp module) {
   SmallString<64> newSymName(oldSymName);
   newSymName.push_back('_');
 
   while (lastUsedID < maxFreeID) {
     std::string possible = (newSymName + llvm::utostr(++lastUsedID)).str();
 
-    if (!SymbolTable::lookupSymbolIn(combinedModule, possible)) {
+    if (!SymbolTable::lookupSymbolIn(module, possible)) {
       newSymName += llvm::utostr(lastUsedID);
       break;
     }
@@ -41,8 +47,8 @@ static SmallString<64> renameSymbol(StringRef oldSymName, unsigned &lastUsedID,
   return newSymName;
 }
 
-/// Check if a symbol with the same name as op already exists in source. If so,
-/// rename op and update all its references in target.
+/// Checks if a symbol with the same name as `op` already exists in `source`.
+/// If so, renames `op` and updates all its references in `target`.
 static LogicalResult updateSymbolAndAllUses(SymbolOpInterface op,
                                             spirv::ModuleOp target,
                                             spirv::ModuleOp source,
@@ -61,99 +67,67 @@ static LogicalResult updateSymbolAndAllUses(SymbolOpInterface op,
   return success();
 }
 
-template <typename KeyTy, typename SymbolOpTy>
-static SymbolOpTy
-emplaceOrGetReplacementSymbol(KeyTy key, SymbolOpTy symbolOp,
-                              DenseMap<KeyTy, SymbolOpTy> &deduplicationMap) {
-  auto result = deduplicationMap.try_emplace(key, symbolOp);
-
-  if (result.second)
-    return SymbolOpTy();
-
-  return result.first->second;
-}
-
-/// Computes a hash code to represent the argument SymbolOpInterface based on
-/// all the Op's attributes except for the symbol name.
-///
-/// \return the hash code computed from the Op's attributes as described above.
+/// Computes a hash code to represent `symbolOp` based on all its attributes
+/// except for the symbol name.
 ///
 /// Note: We use the operation's name (not the symbol name) as part of the hash
 /// computation. This prevents, for example, mistakenly considering a global
 /// variable and a spec constant as duplicates because their descriptor set +
 /// binding and spec_id, respectively, happen to hash to the same value.
 static llvm::hash_code computeHash(SymbolOpInterface symbolOp) {
-  llvm::hash_code hashCode(0);
-  hashCode = llvm::hash_combine(symbolOp->getName());
-
-  for (auto attr : symbolOp->getAttrs()) {
-    if (attr.first == SymbolTable::getSymbolAttrName())
-      continue;
-    hashCode = llvm::hash_combine(hashCode, attr);
-  }
-
-  return hashCode;
-}
-
-/// Computes a hash code from the argument Block.
-llvm::hash_code computeHash(Block *block) {
-  // TODO: Consider extracting BlockEquivalenceData into a common header and
-  // re-using it here.
-  llvm::hash_code hash(0);
-
-  for (Operation &op : *block) {
-    // TODO: Properly handle operations with regions.
-    if (op.getNumRegions() > 0)
-      return 0;
-
-    hash = llvm::hash_combine(
-        hash, OperationEquivalence::computeHash(
-                  &op, OperationEquivalence::Flags::IgnoreOperands));
-  }
-
-  return hash;
+  auto range =
+      llvm::make_filter_range(symbolOp->getAttrs(), [](NamedAttribute attr) {
+        return attr.first != SymbolTable::getSymbolAttrName();
+      });
+
+  return llvm::hash_combine(
+      symbolOp->getName(),
+      llvm::hash_combine_range(range.begin(), range.end()));
 }
 
 namespace mlir {
 namespace spirv {
 
-// TODO Properly test symbol rename listener mechanism.
-
-OwningOpRef<spirv::ModuleOp>
-combine(llvm::MutableArrayRef<spirv::ModuleOp> modules,
-        OpBuilder &combinedModuleBuilder,
-        llvm::function_ref<void(ModuleOp, StringRef, StringRef)>
-            symRenameListener) {
-  unsigned lastUsedID = 0;
-
-  if (modules.empty())
+OwningOpRef<spirv::ModuleOp> combine(ArrayRef<spirv::ModuleOp> inputModules,
+                                     OpBuilder &combinedModuleBuilder,
+                                     SymbolRenameListener symRenameListener) {
+  if (inputModules.empty())
     return nullptr;
 
-  auto addressingModel = modules[0].addressing_model();
-  auto memoryModel = modules[0].memory_model();
+  spirv::ModuleOp firstModule = inputModules.front();
+  auto addressingModel = firstModule.addressing_model();
+  auto memoryModel = firstModule.memory_model();
+  auto vceTriple = firstModule.vce_triple();
+
+  // First check whether the input modules disagree on addressing model, memory
+  // model, or VCE triple. Return early if so.
+  for (auto module : inputModules) {
+    if (module.addressing_model() != addressingModel ||
+        module.memory_model() != memoryModel ||
+        module.vce_triple() != vceTriple) {
+      module.emitError("input modules differ in addressing model, memory "
+                       "model, and/or VCE triple");
+      return nullptr;
+    }
+  }
 
   auto combinedModule = combinedModuleBuilder.create<spirv::ModuleOp>(
-      modules[0].getLoc(), addressingModel, memoryModel);
+      firstModule.getLoc(), addressingModel, memoryModel, vceTriple);
   combinedModuleBuilder.setInsertionPointToStart(combinedModule.getBody());
 
   // In some cases, a symbol in the (current state of the) combined module is
-  // renamed in order to maintain the conflicting symbol in the input module
+  // renamed in order to enable the conflicting symbol in the input module
   // being merged. For example, if the conflict is between a global variable in
   // the current combined module and a function in the input module, the global
   // variable is renamed. In order to notify listeners of the symbol updates in
   // such cases, we need to keep track of the module from which the renamed
   // symbol in the combined module originated. This map keeps such information.
-  DenseMap<StringRef, spirv::ModuleOp> symNameToModuleMap;
+  llvm::StringMap<spirv::ModuleOp> symNameToModuleMap;
 
-  for (auto module : modules) {
-    if (module.addressing_model() != addressingModel ||
-        module.memory_model() != memoryModel) {
-      module.emitError(
-          "input modules differ in addressing model and/or memory model");
-      return nullptr;
-    }
+  unsigned lastUsedID = 0;
 
-    spirv::ModuleOp moduleClone = module.clone();
+  for (auto inputModule : inputModules) {
+    spirv::ModuleOp moduleClone = inputModule.clone();
 
     // In the combined module, rename all symbols that conflict with symbols
     // from the current input module. This renaming applies to all ops except
@@ -161,65 +135,70 @@ combine(llvm::MutableArrayRef<spirv::ModuleOp> modules,
     // non-spv.func, we rename that symbol instead and maintain the spv.func in
     // the combined module name as it is.
     for (auto &op : *combinedModule.getBody()) {
-      if (auto symbolOp = dyn_cast<SymbolOpInterface>(op)) {
-        StringRef oldSymName = symbolOp.getName();
+      auto symbolOp = dyn_cast<SymbolOpInterface>(op);
+      if (!symbolOp)
+        continue;
 
-        if (!isa<FuncOp>(op) &&
-            failed(updateSymbolAndAllUses(symbolOp, combinedModule, moduleClone,
-                                          lastUsedID)))
-          return nullptr;
+      StringRef oldSymName = symbolOp.getName();
 
-        StringRef newSymName = symbolOp.getName();
+      if (!isa<FuncOp>(op) &&
+          failed(updateSymbolAndAllUses(symbolOp, combinedModule, moduleClone,
+                                        lastUsedID)))
+        return nullptr;
 
-        if (symRenameListener && oldSymName != newSymName) {
-          spirv::ModuleOp originalModule =
-              symNameToModuleMap.lookup(oldSymName);
+      StringRef newSymName = symbolOp.getName();
 
-          if (!originalModule) {
-            module.emitError("unable to find original ModuleOp for symbol ")
-                << oldSymName;
-            return nullptr;
-          }
+      if (symRenameListener && oldSymName != newSymName) {
+        spirv::ModuleOp originalModule = symNameToModuleMap.lookup(oldSymName);
 
-          symRenameListener(originalModule, oldSymName, newSymName);
-
-          // Since the symbol name is updated, there is no need to maintain the
-          // entry that associates the old symbol name with the original module.
-          symNameToModuleMap.erase(oldSymName);
-          // Instead, add a new entry to map the new symbol name to the original
-          // module in case it gets renamed again later.
-          symNameToModuleMap[newSymName] = originalModule;
+        if (!originalModule) {
+          inputModule.emitError(
+              "unable to find original spirv::ModuleOp for symbol ")
+              << oldSymName;
+          return nullptr;
         }
+
+        symRenameListener(originalModule, oldSymName, newSymName);
+
+        // Since the symbol name is updated, there is no need to maintain the
+        // entry that associates the old symbol name with the original module.
+        symNameToModuleMap.erase(oldSymName);
+        // Instead, add a new entry to map the new symbol name to the original
+        // module in case it gets renamed again later.
+        symNameToModuleMap[newSymName] = originalModule;
       }
     }
 
     // In the current input module, rename all symbols that conflict with
     // symbols from the combined module. This includes renaming spv.funcs.
     for (auto &op : *moduleClone.getBody()) {
-      if (auto symbolOp = dyn_cast<SymbolOpInterface>(op)) {
-        StringRef oldSymName = symbolOp.getName();
+      auto symbolOp = dyn_cast<SymbolOpInterface>(op);
+      if (!symbolOp)
+        continue;
 
-        if (failed(updateSymbolAndAllUses(symbolOp, moduleClone, combinedModule,
-                                          lastUsedID)))
-          return nullptr;
+      StringRef oldSymName = symbolOp.getName();
 
-        StringRef newSymName = symbolOp.getName();
+      if (failed(updateSymbolAndAllUses(symbolOp, moduleClone, combinedModule,
+                                        lastUsedID)))
+        return nullptr;
 
-        if (symRenameListener && oldSymName != newSymName) {
-          symRenameListener(module, oldSymName, newSymName);
+      StringRef newSymName = symbolOp.getName();
 
-          // Insert the module associated with the symbol name.
-          auto emplaceResult =
-              symNameToModuleMap.try_emplace(symbolOp.getName(), module);
+      if (symRenameListener) {
+        if (oldSymName != newSymName)
+          symRenameListener(inputModule, oldSymName, newSymName);
 
-          // If an entry with the same symbol name is already present, this must
-          // be a problem with the implementation, specially clean-up of the map
-          // while iterating over the combined module above.
-          if (!emplaceResult.second) {
-            module.emitError("did not expect to find an entry for symbol ")
-                << symbolOp.getName();
-            return nullptr;
-          }
+        // Insert the module associated with the symbol name.
+        auto emplaceResult =
+            symNameToModuleMap.try_emplace(newSymName, inputModule);
+
+        // If an entry with the same symbol name is already present, this must
+        // be a problem with the implementation, especially the clean-up of the
+        // map while iterating over the combined module above.
+        if (!emplaceResult.second) {
+          inputModule.emitError("did not expect to find an entry for symbol ")
+              << symbolOp.getName();
+          return nullptr;
         }
       }
     }
@@ -234,30 +213,26 @@ combine(llvm::MutableArrayRef<spirv::ModuleOp> modules,
   SmallVector<SymbolOpInterface, 0> eraseList;
 
   for (auto &op : *combinedModule.getBody()) {
-    llvm::hash_code hashCode(0);
     SymbolOpInterface symbolOp = dyn_cast<SymbolOpInterface>(op);
-
     if (!symbolOp)
       continue;
 
-    hashCode = computeHash(symbolOp);
-
-    // A 0 hash code means the op is not suitable for deduplication and should
-    // be skipped. An example of this is when a function has ops with regions
-    // which are not properly supported yet.
-    if (!hashCode)
+    // Do not support ops with operands or results.
+    // Global variables, spec constants, and functions won't have
+    // operands/results, but just for safety here.
+    if (op.getNumOperands() != 0 || op.getNumResults() != 0)
       continue;
 
-    if (auto funcOp = dyn_cast<FuncOp>(op))
-      for (auto &blk : funcOp)
-        hashCode = llvm::hash_combine(hashCode, computeHash(&blk));
-
-    SymbolOpInterface replacementSymOp =
-        emplaceOrGetReplacementSymbol(hashCode, symbolOp, hashToSymbolOp);
+    // Deduplicating functions is not supported yet.
+    if (isa<FuncOp>(op))
+      continue;
 
-    if (!replacementSymOp)
+    auto result = hashToSymbolOp.try_emplace(computeHash(symbolOp), symbolOp);
+    if (result.second)
       continue;
 
+    SymbolOpInterface replacementSymOp = result.first->second;
+
     if (failed(SymbolTable::replaceAllSymbolUses(
             symbolOp, replacementSymOp.getName(), combinedModule))) {
       symbolOp.emitError("unable to update all symbol uses for ")

diff  --git a/mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/basic.mlir b/mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/basic.mlir
index af3cd51a67a8..0c9f6cafaf9f 100644
--- a/mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/basic.mlir
+++ b/mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/basic.mlir
@@ -1,9 +1,19 @@
 // RUN: mlir-opt -test-spirv-module-combiner -split-input-file -verify-diagnostics %s | FileCheck %s
 
+// Combine modules without the same symbols
+
 // CHECK:      module {
 // CHECK-NEXT:   spv.module Logical GLSL450 {
 // CHECK-NEXT:     spv.SpecConstant @m1_sc
+// CHECK-NEXT:     spv.GlobalVariable @m1_gv bind(1, 0)
+// CHECK-NEXT:     spv.func @no_op
+// CHECK-NEXT:       spv.Return
+// CHECK-NEXT:     }
+// CHECK-NEXT:     spv.EntryPoint "GLCompute" @no_op
+// CHECK-NEXT:     spv.ExecutionMode @no_op "LocalSize", 32, 1, 1
+
 // CHECK-NEXT:     spv.SpecConstant @m2_sc
+// CHECK-NEXT:     spv.GlobalVariable @m2_gv bind(0, 1)
 // CHECK-NEXT:     spv.func @variable_init_spec_constant
 // CHECK-NEXT:       spv.mlir.referenceof @m2_sc
 // CHECK-NEXT:       spv.Variable init
@@ -15,10 +25,17 @@
 module {
 spv.module Logical GLSL450 {
   spv.SpecConstant @m1_sc = 42.42 : f32
+  spv.GlobalVariable @m1_gv bind(1, 0): !spv.ptr<f32, Input>
+  spv.func @no_op() -> () "None" {
+    spv.Return
+  }
+  spv.EntryPoint "GLCompute" @no_op
+  spv.ExecutionMode @no_op "LocalSize", 32, 1, 1
 }
 
 spv.module Logical GLSL450 {
   spv.SpecConstant @m2_sc = 42 : i32
+  spv.GlobalVariable @m2_gv bind(0, 1): !spv.ptr<f32, Input>
   spv.func @variable_init_spec_constant() -> () "None" {
     %0 = spv.mlir.referenceof @m2_sc : i32
     %1 = spv.Variable init(%0) : !spv.ptr<i32, Function>
@@ -33,7 +50,7 @@ module {
 spv.module Physical64 GLSL450 {
 }
 
-// expected-error @+1 {{input modules differ in addressing model and/or memory model}}
+// expected-error @+1 {{input modules differ in addressing model, memory model, and/or VCE triple}}
 spv.module Logical GLSL450 {
 }
 }
@@ -44,7 +61,19 @@ module {
 spv.module Logical Simple {
 }
 
-// expected-error @+1 {{input modules differ in addressing model and/or memory model}}
+// expected-error @+1 {{input modules differ in addressing model, memory model, and/or VCE triple}}
+spv.module Logical GLSL450 {
+}
+}
+
+// -----
+
+module {
 spv.module Logical GLSL450 {
 }
+
+// expected-error @+1 {{input modules differ in addressing model, memory model, and/or VCE triple}}
+spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]> {
+}
 }
+

diff  --git a/mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/conflict-resolution.mlir b/mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/conflict-resolution.mlir
index 9123bf4478f3..f3995aa28ba1 100644
--- a/mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/conflict-resolution.mlir
+++ b/mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/conflict-resolution.mlir
@@ -215,7 +215,7 @@ spv.module Logical GLSL450 {
   spv.func @foo(%arg0 : i32) -> i32 "None" {
     spv.ReturnValue %arg0 : i32
   }
-  
+
   spv.EntryPoint "GLCompute" @foo
   spv.ExecutionMode @foo "ContractionOff"
 }
@@ -383,7 +383,7 @@ spv.module Logical GLSL450 {
   spv.SpecConstant @foo = -5 : i32
 
   spv.func @bar() -> i32 "None" {
-    %0 = spv.mlir.referenceof @foo : i32 
+    %0 = spv.mlir.referenceof @foo : i32
     spv.ReturnValue %0 : i32
   }
 }

diff  --git a/mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/deduplication.mlir b/mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/deduplication.mlir
index 0bd76113f198..f7aab6e6f40a 100644
--- a/mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/deduplication.mlir
+++ b/mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/deduplication.mlir
@@ -21,7 +21,6 @@
 // CHECK-NEXT:   }
 // CHECK-NEXT: }
 
-module {
 spv.module Logical GLSL450 {
   spv.GlobalVariable @foo bind(1, 0) : !spv.ptr<f32, Input>
 
@@ -42,7 +41,6 @@ spv.module Logical GLSL450 {
     spv.ReturnValue %2 : f32
   }
 }
-}
 
 // -----
 
@@ -62,7 +60,6 @@ spv.module Logical GLSL450 {
 // CHECK-NEXT: }
 // CHECK-NEXT: }
 
-module {
 spv.module Logical GLSL450 {
   spv.GlobalVariable @foo bind(1, 0) : !spv.ptr<i32, Input>
 }
@@ -76,7 +73,6 @@ spv.module Logical GLSL450 {
     spv.ReturnValue %1 : f32
   }
 }
-}
 
 // -----
 
@@ -93,7 +89,6 @@ spv.module Logical GLSL450 {
 // CHECK-NEXT:   }
 // CHECK-NEXT: }
 
-module {
 spv.module Logical GLSL450 {
   spv.GlobalVariable @foo built_in("GlobalInvocationId") : !spv.ptr<vector<3xi32>, Input>
 }
@@ -107,10 +102,11 @@ spv.module Logical GLSL450 {
     spv.ReturnValue %1 : vector<3xi32>
   }
 }
-}
 
 // -----
 
+// Deduplicate 2 spec constants with the same spec ID.
+
 // CHECK:      module {
 // CHECK-NEXT:   spv.module Logical GLSL450 {
 // CHECK-NEXT:     spv.SpecConstant @foo spec_id(5)
@@ -128,7 +124,6 @@ spv.module Logical GLSL450 {
 // CHECK-NEXT:   }
 // CHECK-NEXT: }
 
-module {
 spv.module Logical GLSL450 {
   spv.SpecConstant @foo spec_id(5) = 1. : f32
 
@@ -147,48 +142,82 @@ spv.module Logical GLSL450 {
     spv.ReturnValue %1 : f32
   }
 }
+
+// -----
+
+// Don't deduplicate functions with similar ops but different operands.
+
+//       CHECK: spv.module Logical GLSL450 {
+//  CHECK-NEXT:   spv.func @foo(%[[ARG0:.+]]: f32, %[[ARG1:.+]]: f32, %[[ARG2:.+]]: f32)
+//  CHECK-NEXT:     %[[ADD:.+]] = spv.FAdd %[[ARG0]], %[[ARG1]] : f32
+//  CHECK-NEXT:     %[[MUL:.+]] = spv.FMul %[[ADD]], %[[ARG2]] : f32
+//  CHECK-NEXT:     spv.ReturnValue %[[MUL]] : f32
+//  CHECK-NEXT:   }
+//  CHECK-NEXT:   spv.func @foo_1(%[[ARG0:.+]]: f32, %[[ARG1:.+]]: f32, %[[ARG2:.+]]: f32)
+//  CHECK-NEXT:     %[[ADD:.+]] = spv.FAdd %[[ARG0]], %[[ARG2]] : f32
+//  CHECK-NEXT:     %[[MUL:.+]] = spv.FMul %[[ADD]], %[[ARG1]] : f32
+//  CHECK-NEXT:     spv.ReturnValue %[[MUL]] : f32
+//  CHECK-NEXT:   }
+//  CHECK-NEXT: }
+
+spv.module Logical GLSL450 {
+  spv.func @foo(%a: f32, %b: f32, %c: f32) -> f32 "None" {
+    %add = spv.FAdd %a, %b: f32
+    %mul = spv.FMul %add, %c: f32
+    spv.ReturnValue %mul: f32
+  }
+}
+
+spv.module Logical GLSL450 {
+  spv.func @foo(%a: f32, %b: f32, %c: f32) -> f32 "None" {
+    %add = spv.FAdd %a, %c: f32
+    %mul = spv.FMul %add, %b: f32
+    spv.ReturnValue %mul: f32
+  }
 }
 
 // -----
 
-// CHECK:      module {
-// CHECK-NEXT:   spv.module Logical GLSL450 {
-// CHECK-NEXT:     spv.SpecConstant @bar spec_id(5)
+// TODO: re-enable this test once we have better function deduplication.
 
-// CHECK-NEXT:     spv.func @foo(%arg0: f32)
-// CHECK-NEXT:       spv.ReturnValue
-// CHECK-NEXT:     }
+// XXXXX:      module {
+// XXXXX-NEXT:   spv.module Logical GLSL450 {
+// XXXXX-NEXT:     spv.SpecConstant @bar spec_id(5)
 
-// CHECK-NEXT:     spv.func @foo_different_body(%arg0: f32)
-// CHECK-NEXT:       spv.mlir.referenceof
-// CHECK-NEXT:       spv.ReturnValue
-// CHECK-NEXT:     }
+// XXXXX-NEXT:     spv.func @foo(%arg0: f32)
+// XXXXX-NEXT:       spv.ReturnValue
+// XXXXX-NEXT:     }
 
-// CHECK-NEXT:     spv.func @baz(%arg0: i32)
-// CHECK-NEXT:       spv.ReturnValue
-// CHECK-NEXT:     }
+// XXXXX-NEXT:     spv.func @foo_different_body(%arg0: f32)
+// XXXXX-NEXT:       spv.mlir.referenceof
+// XXXXX-NEXT:       spv.ReturnValue
+// XXXXX-NEXT:     }
 
-// CHECK-NEXT:     spv.func @baz_no_return(%arg0: i32)
-// CHECK-NEXT:       spv.Return
-// CHECK-NEXT:     }
+// XXXXX-NEXT:     spv.func @baz(%arg0: i32)
+// XXXXX-NEXT:       spv.ReturnValue
+// XXXXX-NEXT:     }
 
-// CHECK-NEXT:     spv.func @baz_no_return_different_control
-// CHECK-NEXT:       spv.Return
-// CHECK-NEXT:     }
+// XXXXX-NEXT:     spv.func @baz_no_return(%arg0: i32)
+// XXXXX-NEXT:       spv.Return
+// XXXXX-NEXT:     }
 
-// CHECK-NEXT:     spv.func @baz_no_return_another_control
-// CHECK-NEXT:       spv.Return
-// CHECK-NEXT:     }
+// XXXXX-NEXT:     spv.func @baz_no_return_different_control
+// XXXXX-NEXT:       spv.Return
+// XXXXX-NEXT:     }
 
-// CHECK-NEXT:     spv.func @kernel
-// CHECK-NEXT:       spv.Return
-// CHECK-NEXT:     }
+// XXXXX-NEXT:     spv.func @baz_no_return_another_control
+// XXXXX-NEXT:       spv.Return
+// XXXXX-NEXT:     }
 
-// CHECK-NEXT:     spv.func @kernel_different_attr
-// CHECK-NEXT:       spv.Return
-// CHECK-NEXT:     }
-// CHECK-NEXT:   }
-// CHECK-NEXT:   }
+// XXXXX-NEXT:     spv.func @kernel
+// XXXXX-NEXT:       spv.Return
+// XXXXX-NEXT:     }
+
+// XXXXX-NEXT:     spv.func @kernel_different_attr
+// XXXXX-NEXT:       spv.Return
+// XXXXX-NEXT:     }
+// XXXXX-NEXT:   }
+// XXXXX-NEXT:   }
 
 module {
 spv.module Logical GLSL450 {

diff  --git a/mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/symbol-rename-listener.mlir b/mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/symbol-rename-listener.mlir
new file mode 100644
index 000000000000..a9691bdeb6ff
--- /dev/null
+++ b/mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/symbol-rename-listener.mlir
@@ -0,0 +1,54 @@
+// RUN: mlir-opt -test-spirv-module-combiner -split-input-file -verify-diagnostics %s | FileCheck %s
+
+module {
+spv.module @Module1 Logical GLSL450 {
+  spv.GlobalVariable @foo bind(1, 0) : !spv.ptr<f32, Input>
+  spv.func @bar() -> () "None" {
+    spv.Return
+  }
+  spv.func @baz() -> () "None" {
+    spv.Return
+  }
+
+  spv.SpecConstant @sc = -5 : i32
+}
+
+spv.module @Module2 Logical GLSL450 {
+  spv.func @foo() -> () "None" {
+    spv.Return
+  }
+
+  spv.GlobalVariable @bar bind(1, 0) : !spv.ptr<f32, Input>
+
+  spv.func @baz() -> () "None" {
+    spv.Return
+  }
+
+  spv.SpecConstant @sc = -5 : i32
+}
+
+spv.module @Module3 Logical GLSL450 {
+  spv.func @foo() -> () "None" {
+    spv.Return
+  }
+
+  spv.GlobalVariable @bar bind(1, 0) : !spv.ptr<f32, Input>
+
+  spv.func @baz() -> () "None" {
+    spv.Return
+  }
+
+  spv.SpecConstant @sc = -5 : i32
+}
+}
+
+// CHECK: [Module1] foo -> foo_1
+// CHECK: [Module1] sc -> sc_2
+
+// CHECK: [Module2] bar -> bar_3
+// CHECK: [Module2] baz -> baz_4
+// CHECK: [Module2] sc -> sc_5
+
+// CHECK: [Module3] foo -> foo_6
+// CHECK: [Module3] bar -> bar_7
+// CHECK: [Module3] baz -> baz_8

diff  --git a/mlir/test/lib/Dialect/SPIRV/TestModuleCombiner.cpp b/mlir/test/lib/Dialect/SPIRV/TestModuleCombiner.cpp
index d7675012a77c..12054b1dc454 100644
--- a/mlir/test/lib/Dialect/SPIRV/TestModuleCombiner.cpp
+++ b/mlir/test/lib/Dialect/SPIRV/TestModuleCombiner.cpp
@@ -37,7 +37,14 @@ void TestModuleCombinerPass::runOnOperation() {
   auto modules = llvm::to_vector<4>(getOperation().getOps<spirv::ModuleOp>());
 
   OpBuilder combinedModuleBuilder(modules[0]);
-  combinedModule = spirv::combine(modules, combinedModuleBuilder, nullptr);
+
+  auto listener = [](spirv::ModuleOp originalModule, StringRef oldSymbol,
+                     StringRef newSymbol) {
+    llvm::outs() << "[" << originalModule.getName() << "] " << oldSymbol
+                 << " -> " << newSymbol << "\n";
+  };
+
+  combinedModule = spirv::combine(modules, combinedModuleBuilder, listener);
 
   for (spirv::ModuleOp module : modules)
     module.erase();


        


More information about the Mlir-commits mailing list