[Mlir-commits] [mlir] b126ee6 - [mlir][llvm] Add comdat attribute to functions

Tobias Gysi llvmlistbot at llvm.org
Tue Jun 27 00:29:13 PDT 2023


Author: Tobias Gysi
Date: 2023-06-27T07:26:59Z
New Revision: b126ee65fcbb49054e32fd11fdac07279d00f159

URL: https://github.com/llvm/llvm-project/commit/b126ee65fcbb49054e32fd11fdac07279d00f159
DIFF: https://github.com/llvm/llvm-project/commit/b126ee65fcbb49054e32fd11fdac07279d00f159.diff

LOG: [mlir][llvm] Add comdat attribute to functions

This revision adds comdat support to functions. Additionally,
it ensures only comdats that have uses are imported/exported and
only non-empty global comdat operations are created.

Reviewed By: Dinistro

Differential Revision: https://reviews.llvm.org/D153739

Added: 
    mlir/test/Target/LLVMIR/comdat.mlir

Modified: 
    mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
    mlir/include/mlir/Target/LLVMIR/ModuleImport.h
    mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
    mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
    mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
    mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
    mlir/lib/Target/LLVMIR/ModuleImport.cpp
    mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
    mlir/test/Dialect/LLVMIR/comdat.mlir
    mlir/test/Dialect/LLVMIR/func.mlir
    mlir/test/Dialect/LLVMIR/global.mlir
    mlir/test/Dialect/LLVMIR/invalid.mlir
    mlir/test/Target/LLVMIR/Import/comdat.ll
    mlir/test/Target/LLVMIR/llvmir.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index e7aca6da5d68e..62d338b051c64 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -1634,6 +1634,7 @@ def LLVM_LLVMFuncOp : LLVM_Op<"func", [
     DefaultValuedAttr<Linkage, "Linkage::External">:$linkage,
     UnitAttr:$dso_local,
     DefaultValuedAttr<CConv, "CConv::C">:$CConv,
+    OptionalAttr<SymbolRefAttr>:$comdat,
     OptionalAttr<FlatSymbolRefAttr>:$personality,
     OptionalAttr<StrAttr>:$garbageCollector,
     OptionalAttr<ArrayAttr>:$passthrough,
@@ -1655,6 +1656,7 @@ def LLVM_LLVMFuncOp : LLVM_Op<"func", [
       CArg<"Linkage", "Linkage::External">:$linkage,
       CArg<"bool", "false">:$dsoLocal,
       CArg<"CConv", "CConv::C">:$cconv,
+      CArg<"SymbolRefAttr", "{}">:$comdat,
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs,
       CArg<"ArrayRef<DictionaryAttr>", "{}">:$argAttrs,
       CArg<"std::optional<uint64_t>", "{}">:$functionEntryCount)>

diff  --git a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
index 0a97bdfe19fcf..2278ffd8e10dd 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
@@ -310,6 +310,10 @@ class ModuleImport {
   /// operation. Returns success if all conversions succeed and failure
   /// otherwise.
   LogicalResult processAliasScopeMetadata(const llvm::MDNode *node);
+  /// Converts the given LLVM comdat struct to an MLIR comdat selector operation
+  /// and stores a mapping from the struct to the symbol pointing to the
+  /// translated operation.
+  void processComdat(const llvm::Comdat *comdat);
 
   /// Builder pointing at where the next instruction should be generated.
   OpBuilder builder;
@@ -348,6 +352,9 @@ class ModuleImport {
   /// Mapping between LLVM TBAA metadata nodes and symbol references to the LLVM
   /// dialect TBAA operations corresponding to these nodes.
   DenseMap<const llvm::MDNode *, SymbolRefAttr> tbaaMapping;
+  /// Mapping between LLVM comdat structs and symbol references to LLVM dialect
+  /// comdat selector operations corresponding to these structs.
+  DenseMap<const llvm::Comdat *, SymbolRefAttr> comdatMapping;
   /// The stateful type translator (contains named structs).
   LLVM::TypeFromLLVMIRTranslator typeTranslator;
   /// Stateful debug information importer.

diff  --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
index 960416d4913f7..9647c8db1bf3b 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
@@ -46,6 +46,7 @@ class LoopAnnotationTranslation;
 
 class DINodeAttr;
 class LLVMFuncOp;
+class ComdatSelectorOp;
 
 /// Implementation class for module translation. Holds a reference to the module
 /// being translated, and the mappings between the original and the translated
@@ -340,6 +341,10 @@ class ModuleTranslation {
   /// This map is populated on module entry.
   DenseMap<const Operation *, llvm::MDNode *> tbaaMetadataMapping;
 
+  /// Mapping from a comdat selector operation to its LLVM comdat struct.
+  /// This map is populated on module entry.
+  DenseMap<ComdatSelectorOp, llvm::Comdat *> comdatMapping;
+
   /// Stack of user-specified state elements, useful when translating operations
   /// with regions.
   SmallVector<std::unique_ptr<StackFrame>> stack;

diff  --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
index 3ae997d713104..776e6222fdcfe 100644
--- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
+++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
@@ -134,8 +134,8 @@ static void wrapForExternalCallers(OpBuilder &rewriter, Location loc,
       attributes);
   auto wrapperFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
       loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(),
-      wrapperFuncType, LLVM::Linkage::External, /*dsoLocal*/ false,
-      /*cconv*/ LLVM::CConv::C, attributes);
+      wrapperFuncType, LLVM::Linkage::External, /*dsoLocal=*/false,
+      /*cconv=*/LLVM::CConv::C, /*comdat=*/nullptr, attributes);
 
   OpBuilder::InsertionGuard guard(rewriter);
   rewriter.setInsertionPointToStart(wrapperFuncOp.addEntryBlock());
@@ -205,8 +205,8 @@ static void wrapExternalFunction(OpBuilder &builder, Location loc,
   // Create the auxiliary function.
   auto wrapperFunc = builder.create<LLVM::LLVMFuncOp>(
       loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(),
-      wrapperType, LLVM::Linkage::External, /*dsoLocal*/ false,
-      /*cconv*/ LLVM::CConv::C, attributes);
+      wrapperType, LLVM::Linkage::External, /*dsoLocal=*/false,
+      /*cconv=*/LLVM::CConv::C, /*comdat=*/nullptr, attributes);
 
   // The wrapper that we synthetize here should only be visible in this module.
   newFuncOp.setLinkage(LLVM::Linkage::Private);
@@ -445,7 +445,8 @@ struct FuncOpConversionBase : public ConvertOpToLLVMPattern<func::FuncOp> {
     }
     auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
         funcOp.getLoc(), funcOp.getName(), llvmType, linkage,
-        /*dsoLocal*/ false, /*cconv*/ LLVM::CConv::C, attributes);
+        /*dsoLocal=*/false, /*cconv=*/LLVM::CConv::C, /*comdat=*/nullptr,
+        attributes);
     // If the memory attribute was created, add it to the function.
     if (memoryAttr)
       newFuncOp.setMemoryAttr(memoryAttr);

diff  --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index 664d077c58875..38b7248e39725 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -75,8 +75,8 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
     attributes.emplace_back(kernelAttributeName, rewriter.getUnitAttr());
   auto llvmFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
       gpuFuncOp.getLoc(), gpuFuncOp.getName(), funcType,
-      LLVM::Linkage::External, /*dsoLocal*/ false, /*cconv*/ LLVM::CConv::C,
-      attributes);
+      LLVM::Linkage::External, /*dsoLocal=*/false, /*cconv=*/LLVM::CConv::C,
+      /*comdat=*/nullptr, attributes);
 
   {
     // Insert operations that correspond to converted workgroup and private

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 9b91304fadc2d..23c7515e6bfd4 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -1807,9 +1807,11 @@ static LogicalResult verifyComdat(Operation *op,
   return success();
 }
 
-// operation ::= `llvm.mlir.global` linkage? `constant`? `@` identifier
-//               `(` attribute? `)` align? attribute-list? (`:` type)? region?
-// align     ::= `align` `=` UINT64
+// operation ::= `llvm.mlir.global` linkage? visibility?
+//               (`unnamed_addr` | `local_unnamed_addr`)?
+//               `thread_local`? `constant`? `@` identifier
+//               `(` attribute? `)` (`comdat(` symbol-ref-id `)`)?
+//               attribute-list? (`:` type)? region?
 //
 // The type can be omitted for string attributes, in which case it will be
 // inferred from the value of the string as [strlen(value) x i8].
@@ -2103,7 +2105,7 @@ Block *LLVMFuncOp::addEntryBlock() {
 
 void LLVMFuncOp::build(OpBuilder &builder, OperationState &result,
                        StringRef name, Type type, LLVM::Linkage linkage,
-                       bool dsoLocal, CConv cconv,
+                       bool dsoLocal, CConv cconv, SymbolRefAttr comdat,
                        ArrayRef<NamedAttribute> attrs,
                        ArrayRef<DictionaryAttr> argAttrs,
                        std::optional<uint64_t> functionEntryCount) {
@@ -2120,6 +2122,8 @@ void LLVMFuncOp::build(OpBuilder &builder, OperationState &result,
   if (dsoLocal)
     result.addAttribute(getDsoLocalAttrName(result.name),
                         builder.getUnitAttr());
+  if (comdat)
+    result.addAttribute(getComdatAttrName(result.name), comdat);
   if (functionEntryCount)
     result.addAttribute(getFunctionEntryCountAttrName(result.name),
                         builder.getI64IntegerAttr(functionEntryCount.value()));
@@ -2174,8 +2178,9 @@ buildLLVMFunctionType(OpAsmParser &parser, SMLoc loc, ArrayRef<Type> inputs,
 // Parses an LLVM function.
 //
 // operation ::= `llvm.func` linkage? cconv? function-signature
-// function-attributes?
-//               function-body
+//                (`comdat(` symbol-ref-id `)`)?
+//                function-attributes?
+//                function-body
 //
 ParseResult LLVMFuncOp::parse(OpAsmParser &parser, OperationState &result) {
   // Default to external linkage if no keyword is provided.
@@ -2222,6 +2227,16 @@ ParseResult LLVMFuncOp::parse(OpAsmParser &parser, OperationState &result) {
   result.addAttribute(getFunctionTypeAttrName(result.name),
                       TypeAttr::get(type));
 
+  // Parse the optional comdat selector.
+  if (succeeded(parser.parseOptionalKeyword("comdat"))) {
+    SymbolRefAttr comdat;
+    if (parser.parseLParen() || parser.parseAttribute(comdat) ||
+        parser.parseRParen())
+      return failure();
+
+    result.addAttribute(getComdatAttrName(result.name), comdat);
+  }
+
   if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
     return failure();
   function_interface_impl::addArgAndResultAttrs(
@@ -2262,10 +2277,16 @@ void LLVMFuncOp::print(OpAsmPrinter &p) {
 
   function_interface_impl::printFunctionSignature(p, *this, argTypes,
                                                   isVarArg(), resTypes);
+
+  // Print the optional comdat selector.
+  if (auto comdat = getComdat())
+    p << " comdat(" << *comdat << ')';
+
   function_interface_impl::printFunctionAttributes(
       p, *this,
       {getFunctionTypeAttrName(), getArgAttrsAttrName(), getResAttrsAttrName(),
-       getLinkageAttrName(), getCConvAttrName(), getVisibility_AttrName()});
+       getLinkageAttrName(), getCConvAttrName(), getVisibility_AttrName(),
+       getComdatAttrName()});
 
   // Print the body if this is not an external function.
   Region &body = getBody();
@@ -2286,6 +2307,9 @@ LogicalResult LLVMFuncOp::verify() {
                          << stringifyLinkage(LLVM::Linkage::Common)
                          << "' linkage";
 
+  if (failed(verifyComdat(*this, getComdat())))
+    return failure();
+
   if (isExternal()) {
     if (getLinkage() != LLVM::Linkage::External &&
         getLinkage() != LLVM::Linkage::ExternWeak)

diff  --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index ac7280470f874..381560a43874a 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -172,8 +172,10 @@ MetadataOp ModuleImport::getGlobalMetadataOp() {
 
   OpBuilder::InsertionGuard guard(builder);
   builder.setInsertionPointToEnd(mlirModule.getBody());
-  return globalMetadataOp = builder.create<MetadataOp>(
-             mlirModule.getLoc(), getGlobalMetadataOpName());
+  globalMetadataOp = builder.create<MetadataOp>(mlirModule.getLoc(),
+                                                getGlobalMetadataOpName());
+  globalInsertionOp = globalMetadataOp;
+  return globalMetadataOp;
 }
 
 ComdatOp ModuleImport::getGlobalComdatOp() {
@@ -181,9 +183,11 @@ ComdatOp ModuleImport::getGlobalComdatOp() {
     return globalComdatOp;
 
   OpBuilder::InsertionGuard guard(builder);
-  builder.setInsertionPointToStart(mlirModule.getBody());
-  return globalComdatOp = builder.create<ComdatOp>(mlirModule.getLoc(),
-                                                   getGlobalComdatOpName());
+  builder.setInsertionPointToEnd(mlirModule.getBody());
+  globalComdatOp =
+      builder.create<ComdatOp>(mlirModule.getLoc(), getGlobalComdatOpName());
+  globalInsertionOp = globalComdatOp;
+  return globalComdatOp;
 }
 
 LogicalResult ModuleImport::processTBAAMetadata(const llvm::MDNode *node) {
@@ -559,17 +563,29 @@ LogicalResult ModuleImport::convertMetadata() {
   return success();
 }
 
-LogicalResult ModuleImport::convertComdats() {
-  ComdatOp comdat = getGlobalComdatOp();
+void ModuleImport::processComdat(const llvm::Comdat *comdat) {
+  if (comdatMapping.contains(comdat))
+    return;
+
+  ComdatOp comdatOp = getGlobalComdatOp();
   OpBuilder::InsertionGuard guard(builder);
-  builder.setInsertionPointToEnd(&comdat.getBody().back());
-  for (auto &kv : llvmModule->getComdatSymbolTable()) {
-    StringRef name = kv.getKey();
-    llvm::Comdat::SelectionKind selector = kv.getValue().getSelectionKind();
-    builder.create<ComdatSelectorOp>(mlirModule.getLoc(), name,
-                                     convertComdatFromLLVM(selector));
-  }
+  builder.setInsertionPointToEnd(&comdatOp.getBody().back());
+  auto selectorOp = builder.create<ComdatSelectorOp>(
+      mlirModule.getLoc(), comdat->getName(),
+      convertComdatFromLLVM(comdat->getSelectionKind()));
+  auto symbolRef =
+      SymbolRefAttr::get(builder.getContext(), getGlobalComdatOpName(),
+                         FlatSymbolRefAttr::get(selectorOp.getSymNameAttr()));
+  comdatMapping.try_emplace(comdat, symbolRef);
+}
 
+LogicalResult ModuleImport::convertComdats() {
+  for (llvm::GlobalVariable &globalVar : llvmModule->globals())
+    if (globalVar.hasComdat())
+      processComdat(globalVar.getComdat());
+  for (llvm::Function &func : llvmModule->functions())
+    if (func.hasComdat())
+      processComdat(func.getComdat());
   return success();
 }
 
@@ -892,18 +908,8 @@ LogicalResult ModuleImport::convertGlobal(llvm::GlobalVariable *globalVar) {
   globalOp.setVisibility_(
       convertVisibilityFromLLVM(globalVar->getVisibility()));
 
-  if (globalVar->hasComdat()) {
-    llvm::Comdat *llvmComdat = globalVar->getComdat();
-    ComdatOp comdat = getGlobalComdatOp();
-    if (ComdatSelectorOp selector = dyn_cast<ComdatSelectorOp>(
-            comdat.lookupSymbol(llvmComdat->getName()))) {
-      auto symbolRef =
-          SymbolRefAttr::get(builder.getContext(), getGlobalComdatOpName(),
-                             FlatSymbolRefAttr::get(selector.getSymNameAttr()));
-      globalOp.setComdatAttr(symbolRef);
-    } else
-      return failure();
-  }
+  if (globalVar->hasComdat())
+    globalOp.setComdatAttr(comdatMapping.lookup(globalVar->getComdat()));
 
   return success();
 }
@@ -1728,6 +1734,9 @@ LogicalResult ModuleImport::processFunction(llvm::Function *func) {
 
   funcOp.setVisibility_(convertVisibilityFromLLVM(func->getVisibility()));
 
+  if (func->hasComdat())
+    funcOp.setComdatAttr(comdatMapping.lookup(func->getComdat()));
+
   // Handle Function attributes.
   processFunctionAttributes(func, funcOp);
 
@@ -1831,10 +1840,10 @@ mlir::translateLLVMIRToModule(std::unique_ptr<llvm::Module> llvmModule,
     return {};
   if (failed(moduleImport.convertDataLayout()))
     return {};
-  if (failed(moduleImport.convertMetadata()))
-    return {};
   if (failed(moduleImport.convertComdats()))
     return {};
+  if (failed(moduleImport.convertMetadata()))
+    return {};
   if (failed(moduleImport.convertGlobals()))
     return {};
   if (failed(moduleImport.convertFunctions()))

diff  --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index 685b031dd509b..62f881171ba22 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -732,11 +732,9 @@ LogicalResult ModuleTranslation::convertGlobals() {
         addrSpace);
 
     if (std::optional<mlir::SymbolRefAttr> comdat = op.getComdat()) {
-      StringRef name = comdat->getLeafReference().getValue();
-      if (!llvmModule->getComdatSymbolTable().contains(name))
-        return emitError(op.getLoc(), "global references non-existant comdat");
-      llvm::Comdat *llvmComdat = llvmModule->getOrInsertComdat(name);
-      var->setComdat(llvmComdat);
+      auto selectorOp = cast<ComdatSelectorOp>(
+          SymbolTable::lookupNearestSymbolFrom(op, *comdat));
+      var->setComdat(comdatMapping.lookup(selectorOp));
     }
 
     if (op.getUnnamedAddr().has_value())
@@ -1033,6 +1031,13 @@ LogicalResult ModuleTranslation::convertFunctionSignatures() {
 
     // Convert visibility attribute.
     llvmFunc->setVisibility(convertVisibilityToLLVM(function.getVisibility_()));
+
+    // Convert the comdat attribute.
+    if (std::optional<mlir::SymbolRefAttr> comdat = function.getComdat()) {
+      auto selectorOp = cast<ComdatSelectorOp>(
+          SymbolTable::lookupNearestSymbolFrom(function, *comdat));
+      llvmFunc->setComdat(comdatMapping.lookup(selectorOp));
+    }
   }
 
   return success();
@@ -1053,17 +1058,16 @@ LogicalResult ModuleTranslation::convertFunctions() {
 }
 
 LogicalResult ModuleTranslation::convertComdats() {
-  for (ComdatOp comdat : getModuleBody(mlirModule).getOps<ComdatOp>()) {
-    for (ComdatSelectorOp selector : comdat.getOps<ComdatSelectorOp>()) {
-      StringRef name = selector.getName();
+  for (auto comdatOp : getModuleBody(mlirModule).getOps<ComdatOp>()) {
+    for (auto selectorOp : comdatOp.getOps<ComdatSelectorOp>()) {
       llvm::Module *module = getLLVMModule();
-      if (module->getComdatSymbolTable().contains(name)) {
-        return emitError(selector.getLoc())
+      if (module->getComdatSymbolTable().contains(selectorOp.getSymName()))
+        return emitError(selectorOp.getLoc())
               << "comdat selection symbols must be unique even in different "
                  "comdat regions";
-      }
-      llvm::Comdat *comdat = module->getOrInsertComdat(name);
-      comdat->setSelectionKind(convertComdatToLLVM(selector.getComdat()));
+      llvm::Comdat *comdat = module->getOrInsertComdat(selectorOp.getSymName());
+      comdat->setSelectionKind(convertComdatToLLVM(selectorOp.getComdat()));
+      comdatMapping.try_emplace(selectorOp, comdat);
     }
   }
   return success();
@@ -1402,10 +1406,10 @@ mlir::translateModuleToLLVMIR(Operation *module, llvm::LLVMContext &llvmContext,
   LLVM::ensureDistinctSuccessors(module);
 
   ModuleTranslation translator(module, std::move(llvmModule));
-  if (failed(translator.convertFunctionSignatures()))
-    return nullptr;
   if (failed(translator.convertComdats()))
     return nullptr;
+  if (failed(translator.convertFunctionSignatures()))
+    return nullptr;
   if (failed(translator.convertGlobals()))
     return nullptr;
   if (failed(translator.createAccessGroupMetadata()))

diff  --git a/mlir/test/Dialect/LLVMIR/comdat.mlir b/mlir/test/Dialect/LLVMIR/comdat.mlir
index e8eaaa15160a0..234ab0463fd7b 100644
--- a/mlir/test/Dialect/LLVMIR/comdat.mlir
+++ b/mlir/test/Dialect/LLVMIR/comdat.mlir
@@ -13,4 +13,3 @@ llvm.comdat @__llvm_comdat {
   // CHECK: llvm.comdat_selector @samesize_comdat samesize
   llvm.comdat_selector @samesize_comdat samesize
 }
-

diff  --git a/mlir/test/Dialect/LLVMIR/func.mlir b/mlir/test/Dialect/LLVMIR/func.mlir
index eec5bb7855dbb..466f7d1eb70d8 100644
--- a/mlir/test/Dialect/LLVMIR/func.mlir
+++ b/mlir/test/Dialect/LLVMIR/func.mlir
@@ -204,6 +204,16 @@ module {
   llvm.func protected @protected() {
     llvm.return
   }
+
+  // CHECK: llvm.comdat @__llvm_comdat
+  llvm.comdat @__llvm_comdat {
+    // CHECK: llvm.comdat_selector @any any
+    llvm.comdat_selector @any any
+  }
+  // CHECK: @any() comdat(@__llvm_comdat::@any) attributes
+  llvm.func @any() comdat(@__llvm_comdat::@any) attributes { dso_local } {
+    llvm.return
+  }
 }
 
 // -----

diff  --git a/mlir/test/Dialect/LLVMIR/global.mlir b/mlir/test/Dialect/LLVMIR/global.mlir
index c9a28e0a9eb93..e653ec48d5679 100644
--- a/mlir/test/Dialect/LLVMIR/global.mlir
+++ b/mlir/test/Dialect/LLVMIR/global.mlir
@@ -69,8 +69,8 @@ llvm.comdat @__llvm_comdat {
   // CHECK: llvm.comdat_selector @any any
   llvm.comdat_selector @any any
 }
-// CHECK: llvm.mlir.global external @any() comdat(@__llvm_comdat::@any) {addr_space = 0 : i32} : i64
-llvm.mlir.global @any() comdat(@__llvm_comdat::@any) : i64
+// CHECK: llvm.mlir.global external @any() comdat(@__llvm_comdat::@any) {addr_space = 1 : i32} : i64
+llvm.mlir.global @any() comdat(@__llvm_comdat::@any) {addr_space = 1 : i32} : i64
 
 // CHECK-LABEL: references
 func.func @references() {

diff  --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index e6e11923439e0..e8a297c9fde1c 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -1438,15 +1438,22 @@ func.func @invalid_target_ext_constant() {
 // -----
 
 llvm.comdat @__llvm_comdat {
-  // expected-error@+1 {{only comdat selector symbols can appear in a comdat region}}
+  // expected-error@below {{only comdat selector symbols can appear in a comdat region}}
   llvm.return
 }
 
 // -----
 
 llvm.mlir.global @not_comdat(0 : i32) : i32
-// expected-error@+1 {{expected comdat symbol}}
-llvm.mlir.global @invalid_comdat_use(0 : i32) comdat(@not_comdat) : i32
+// expected-error@below {{expected comdat symbol}}
+llvm.mlir.global @invalid_global_comdat(0 : i32) comdat(@not_comdat) : i32
+
+// -----
+
+// expected-error@below {{expected comdat symbol}}
+llvm.func @invalid_func_comdat() comdat(@foo) {
+  llvm.return
+}
 
 // -----
 

diff  --git a/mlir/test/Target/LLVMIR/Import/comdat.ll b/mlir/test/Target/LLVMIR/Import/comdat.ll
index 8ef38eb8728a1..e6ec77a219717 100644
--- a/mlir/test/Target/LLVMIR/Import/comdat.ll
+++ b/mlir/test/Target/LLVMIR/Import/comdat.ll
@@ -1,20 +1,54 @@
-; RUN: mlir-translate -import-llvm %s | FileCheck %s
+; RUN: mlir-translate -import-llvm -split-input-file %s | FileCheck %s
 
-; CHECK: llvm.mlir.global external @foo(42 : i64) comdat(@__llvm_global_comdat::@foo) {addr_space = 0 : i32} : i64
+; CHECK: llvm.comdat @__llvm_global_comdat {
+; CHECK: llvm.comdat_selector @foo any
+$foo = comdat any
+; CHECK: }
+
+; CHECK: llvm.mlir.global external @foo(42 : i64) comdat(@__llvm_global_comdat::@foo)
 @foo = global i64 42, comdat
-; CHECK: llvm.mlir.global external @bar(42 : i64) comdat(@__llvm_global_comdat::@foo) {addr_space = 0 : i32} : i64
+; CHECK: llvm.mlir.global external @bar(42 : i64) comdat(@__llvm_global_comdat::@foo)
 @bar = global i64 42, comdat($foo)
 
+; // -----
+
 ; CHECK: llvm.comdat @__llvm_global_comdat {
-; CHECK-DAG: llvm.comdat_selector @foo any
+; CHECK: llvm.comdat_selector @foo any
 $foo = comdat any
-; CHECK-DAG: llvm.comdat_selector @exact exactmatch
+; CHECK: }
+
+
+; CHECK: llvm.func @foo() comdat(@__llvm_global_comdat::@foo)
+define void @foo() comdat {
+  ret void
+}
+; CHECK: llvm.func @bar() comdat(@__llvm_global_comdat::@foo)
+define void @bar() comdat($foo) {
+  ret void
+}
+
+; // -----
+
+; CHECK: llvm.comdat @__llvm_global_comdat {
+; CHECK: llvm.comdat_selector @exact exactmatch
 $exact = comdat exactmatch
-; CHECK-DAG: llvm.comdat_selector @largest largest
+; CHECK: llvm.comdat_selector @largest largest
 $largest = comdat largest
-; CHECK-DAG: llvm.comdat_selector @nodedup nodeduplicate
+; CHECK: llvm.comdat_selector @nodedup nodeduplicate
 $nodedup = comdat nodeduplicate
-; CHECK-DAG: llvm.comdat_selector @same samesize
+; CHECK: llvm.comdat_selector @same samesize
 $same = comdat samesize
-
 ; CHECK: }
+
+@exact = global i64 42, comdat
+@largest = global i64 42, comdat
+@nodedup = global i64 42, comdat
+@same = global i64 42, comdat
+
+; // -----
+
+; Verify a global comdat operation is only created if there are comdats to import.
+; CHECK-NOT: llvm.comdat
+; CHECK: llvm.mlir.global external @foobar
+; CHECK-NOT: llvm.comdat
+@foobar = global i64 42

diff  --git a/mlir/test/Target/LLVMIR/comdat.mlir b/mlir/test/Target/LLVMIR/comdat.mlir
new file mode 100644
index 0000000000000..c77e433db46bb
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/comdat.mlir
@@ -0,0 +1,30 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+llvm.comdat @__llvm_comdat {
+  // CHECK-DAG: $[[ANY:.*]] = comdat any
+  llvm.comdat_selector @any any
+  // CHECK-DAG: $[[EXACT:.*]] = comdat exactmatch
+  llvm.comdat_selector @exactmatch exactmatch
+  // CHECK-DAG: $[[LARGEST:.*]] = comdat largest
+  llvm.comdat_selector @largest largest
+  // CHECK-DAG: $[[NODEDUP:.*]] = comdat nodeduplicate
+  llvm.comdat_selector @nodeduplicate nodeduplicate
+  // CHECK-DAG: $[[SAME:.*]] = comdat samesize
+  llvm.comdat_selector @samesize samesize
+}
+
+// CHECK: @any = internal constant i64 1, comdat
+llvm.mlir.global internal constant @any(1 : i64) comdat(@__llvm_comdat::@any) : i64
+// CHECK: @any_global = internal constant i64 1, comdat($[[ANY]])
+llvm.mlir.global internal constant @any_global(1 : i64) comdat(@__llvm_comdat::@any) : i64
+// CHECK: @exact_global = internal constant i64 1, comdat($[[EXACT]])
+llvm.mlir.global internal constant @exact_global(1 : i64) comdat(@__llvm_comdat::@exactmatch) : i64
+// CHECK: @largest_global = internal constant i64 1, comdat($[[LARGEST]])
+llvm.mlir.global internal constant @largest_global(1 : i64) comdat(@__llvm_comdat::@largest) : i64
+
+// CHECK: define void @nodeduplicate() comdat
+llvm.func @nodeduplicate() comdat(@__llvm_comdat::@nodeduplicate) { llvm.return }
+// CHECK: define void @nodeduplicate_func() comdat($[[NODEDUP]])
+llvm.func @nodeduplicate_func() comdat(@__llvm_comdat::@nodeduplicate) { llvm.return }
+// CHECK: define void @samesize_func() comdat($[[SAME]])
+llvm.func @samesize_func() comdat(@__llvm_comdat::@samesize) { llvm.return }

diff  --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir
index 6365eabb34165..9b7dc9d28c229 100644
--- a/mlir/test/Target/LLVMIR/llvmir.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir.mlir
@@ -1,21 +1,5 @@
 // RUN: mlir-translate -mlir-to-llvmir -split-input-file %s | FileCheck %s
 
-
-// Comdat sections
-llvm.comdat @__llvm_comdat {
-  // CHECK: $any = comdat any
-  llvm.comdat_selector @any any
-  // CHECK: $exactmatch = comdat exactmatch
-  llvm.comdat_selector @exactmatch exactmatch
-  // CHECK: $largest = comdat largest
-  llvm.comdat_selector @largest largest
-  // CHECK: $nodeduplicate = comdat nodeduplicate
-  llvm.comdat_selector @nodeduplicate nodeduplicate
-  // CHECK: $samesize = comdat samesize
-  llvm.comdat_selector @samesize samesize
-}
-
-
 // CHECK: @global_aligned32 = private global i64 42, align 32
 "llvm.mlir.global"() ({}) {sym_name = "global_aligned32", global_type = i64, value = 42 : i64, linkage = #llvm.linkage<private>, alignment = 32} : () -> ()
 
@@ -184,20 +168,6 @@ llvm.mlir.global thread_local @has_thr_local(42 : i64) : i64
 // CHECK: @sectionvar = internal constant [10 x i8] c"teststring", section ".mysection"
 llvm.mlir.global internal constant @sectionvar("teststring")  {section = ".mysection"}: !llvm.array<10 x i8>
 
-//
-// Comdat attribute.
-//
-// CHECK: @has_any_comdat = internal constant i64 1, comdat($any)
-llvm.mlir.global internal constant @has_any_comdat(1 : i64) comdat(@__llvm_comdat::@any) : i64
-// CHECK: @has_exactmatch_comdat = internal constant i64 1, comdat($exactmatch)
-llvm.mlir.global internal constant @has_exactmatch_comdat(1 : i64) comdat(@__llvm_comdat::@exactmatch) : i64
-// CHECK: @has_largest_comdat = internal constant i64 1, comdat($largest)
-llvm.mlir.global internal constant @has_largest_comdat(1 : i64) comdat(@__llvm_comdat::@largest) : i64
-// CHECK: @has_nodeduplicate_comdat = internal constant i64 1, comdat($nodeduplicate)
-llvm.mlir.global internal constant @has_nodeduplicate_comdat(1 : i64) comdat(@__llvm_comdat::@nodeduplicate) : i64
-// CHECK: @has_samesize_comdat = internal constant i64 1, comdat($samesize)
-llvm.mlir.global internal constant @has_samesize_comdat(1 : i64) comdat(@__llvm_comdat::@samesize) : i64
-
 //
 // Declarations of the allocation functions to be linked against. These are
 // inserted before other functions in the module.


        


More information about the Mlir-commits mailing list