[Mlir-commits] [mlir] [mlir] Add pass to add comdat to all linkonce functions (PR #65270)

David Truby llvmlistbot at llvm.org
Tue Sep 12 06:29:19 PDT 2023


https://github.com/DavidTruby updated https://github.com/llvm/llvm-project/pull/65270:

>From a28350de1a2682a94819526e9c0b10be299aece4 Mon Sep 17 00:00:00 2001
From: David Truby <david at truby.dev>
Date: Mon, 4 Sep 2023 16:06:24 +0100
Subject: [PATCH 1/6] [mlir] Add option to add comdat to all linkonce functions

This adds an option to the FuncToLLVM pass to add an Any comdat to each linkonce
and linkonce_odr function when lowering. These comdats are necessary on Windows
to allow the default system linker to link binaries containing these functions.
---
 .../Conversion/LLVMCommon/LoweringOptions.h   |  1 +
 mlir/include/mlir/Conversion/Passes.td        |  7 +++--
 .../mlir/Dialect/LLVMIR/LLVMAttrDefs.td       |  2 +-
 mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp | 27 +++++++++++++++++++
 .../FuncToLLVM/add-linkonce-comdat.mlir       | 17 ++++++++++++
 5 files changed, 51 insertions(+), 3 deletions(-)
 create mode 100644 mlir/test/Conversion/FuncToLLVM/add-linkonce-comdat.mlir

diff --git a/mlir/include/mlir/Conversion/LLVMCommon/LoweringOptions.h b/mlir/include/mlir/Conversion/LLVMCommon/LoweringOptions.h
index cc4e17e9527f01e..677fdbb3e973df2 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/LoweringOptions.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/LoweringOptions.h
@@ -34,6 +34,7 @@ class LowerToLLVMOptions {
 
   bool useBarePtrCallConv = false;
   bool useOpaquePointers = true;
+  bool addComdatToLinkonceFuncs = false;
 
   enum class AllocLowering {
     /// Use malloc for heap allocations.
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index eaf016bde69e3be..6bcffa32a7010be 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -381,6 +381,9 @@ def ConvertFuncToLLVMPass : Pass<"convert-func-to-llvm", "ModuleOp"> {
     Option<"useOpaquePointers", "use-opaque-pointers", "bool",
                        /*default=*/"true", "Generate LLVM IR using opaque pointers "
                        "instead of typed pointers">,
+    Option<"addComdatToLinkonceFuncs", "add-comdat-to-linkonce-funcs", "bool",
+                       /*default=*/"false", "Add an any comdat selector to Linkonce "
+                       "functions">,
   ];
 }
 
@@ -760,12 +763,12 @@ def ConvertMemRefToSPIRV : Pass<"convert-memref-to-spirv"> {
 def ConvertNVVMToLLVMPass : Pass<"convert-nvvm-to-llvm"> {
   let summary = "Convert NVVM dialect to LLVM dialect";
   let description = [{
-    This pass generates inline assembly for the NVVM ops which is not 
+    This pass generates inline assembly for the NVVM ops which is not
     implemented in LLVM core.
   }];
   let dependentDialects = [
     "NVVM::NVVMDialect",
-  ];  
+  ];
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
index 2443b23e42e43ce..d4a9e911f36b145 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
@@ -35,7 +35,7 @@ def CConvAttr : LLVM_Attr<"CConv", "cconv"> {
 
 def ComdatAttr : LLVM_Attr<"Comdat", "comdat"> {
   let parameters = (ins "comdat::Comdat":$comdat);
-  let assemblyFormat = "$comdat";
+  let assemblyFormat = "`<` $comdat `>`";
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
index 1db463c0ab7163b..580b4796fbe55f9 100644
--- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
+++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
@@ -328,6 +328,25 @@ static void modifyFuncOpToUseBarePtrCallingConv(
 
 namespace {
 
+template <typename Op>
+void addComdat(Op &op, mlir::PatternRewriter &rewriter,
+               mlir::ModuleOp &module) {
+  const char *comdatName = "__llvm_comdat";
+  mlir::LLVM::ComdatOp comdatOp =
+      module.lookupSymbol<mlir::LLVM::ComdatOp>(comdatName);
+  if (!comdatOp) {
+    comdatOp =
+        rewriter.create<mlir::LLVM::ComdatOp>(module.getLoc(), comdatName);
+  }
+  mlir::OpBuilder::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPointToEnd(&comdatOp.getBody().back());
+  auto selectorOp = rewriter.create<mlir::LLVM::ComdatSelectorOp>(
+      comdatOp.getLoc(), op.getSymName(), mlir::LLVM::comdat::Comdat::Any);
+  op.setComdatAttr(mlir::SymbolRefAttr::get(
+      rewriter.getContext(), comdatName,
+      mlir::FlatSymbolRefAttr::get(selectorOp.getSymNameAttr())));
+}
+
 struct FuncOpConversionBase : public ConvertOpToLLVMPattern<func::FuncOp> {
 protected:
   using ConvertOpToLLVMPattern<func::FuncOp>::ConvertOpToLLVMPattern;
@@ -461,6 +480,13 @@ struct FuncOpConversionBase : public ConvertOpToLLVMPattern<func::FuncOp> {
                                          "region types conversion failed");
     }
 
+    if (getTypeConverter()->getOptions().addComdatToLinkonceFuncs &&
+        (newFuncOp.getLinkage() == LLVM::Linkage::Linkonce ||
+         newFuncOp.getLinkage() == LLVM::Linkage::LinkonceODR)) {
+      auto module = newFuncOp->getParentOfType<mlir::ModuleOp>();
+      addComdat(newFuncOp, rewriter, module);
+    }
+
     return newFuncOp;
   }
 };
@@ -764,6 +790,7 @@ struct ConvertFuncToLLVMPass
       options.overrideIndexBitwidth(indexBitwidth);
     options.dataLayout = llvm::DataLayout(this->dataLayout);
     options.useOpaquePointers = useOpaquePointers;
+    options.addComdatToLinkonceFuncs = addComdatToLinkonceFuncs;
 
     LLVMTypeConverter typeConverter(&getContext(), options,
                                     &dataLayoutAnalysis);
diff --git a/mlir/test/Conversion/FuncToLLVM/add-linkonce-comdat.mlir b/mlir/test/Conversion/FuncToLLVM/add-linkonce-comdat.mlir
new file mode 100644
index 000000000000000..95cd8b6dd79de65
--- /dev/null
+++ b/mlir/test/Conversion/FuncToLLVM/add-linkonce-comdat.mlir
@@ -0,0 +1,17 @@
+// RUN: mlir-opt -convert-func-to-llvm='add-comdat-to-linkonce-funcs=1' -split-input-file -verify-diagnostics %s | FileCheck %s
+
+// CHECK-DAG: llvm.func linkonce @linkonce() comdat(@__llvm_comdat::@linkonce)
+func.func @linkonce() attributes {llvm.linkage = #llvm.linkage<linkonce>} {
+  return
+}
+
+// CHECK-DAG: llvm.comdat @__llvm_comdat {
+// CHECK: llvm.comdat_selector @linkonce any
+// CHECK: llvm.comdat_selector @linkonce_odr any
+// CHECK: }
+
+
+// CHECK-DAG: llvm.func linkonce_odr @linkonce_odr() comdat(@__llvm_comdat::@linkonce_odr)
+func.func @linkonce_odr() attributes {llvm.linkage = #llvm.linkage<linkonce_odr>} {
+  return
+}

>From 430df6dbf427232af6f85ad8350aedf7cfd328c8 Mon Sep 17 00:00:00 2001
From: David Truby <david at truby.dev>
Date: Wed, 6 Sep 2023 17:38:52 +0100
Subject: [PATCH 2/6] Switch to adding a pass instead of an option for
 FuncToLLVM

---
 .../Conversion/LLVMCommon/LoweringOptions.h   |  1 -
 mlir/include/mlir/Conversion/Passes.td        |  3 -
 .../mlir/Dialect/LLVMIR/LLVMAttrDefs.td       |  2 +-
 .../Dialect/LLVMIR/Transforms/AddComdats.h    | 29 +++++++
 .../mlir/Dialect/LLVMIR/Transforms/Passes.h   |  1 +
 .../mlir/Dialect/LLVMIR/Transforms/Passes.td  |  5 ++
 mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp | 27 -------
 .../Dialect/LLVMIR/Transforms/AddComdats.cpp  | 78 +++++++++++++++++++
 .../Dialect/LLVMIR/Transforms/CMakeLists.txt  |  1 +
 .../FuncToLLVM/add-linkonce-comdat.mlir       | 24 +++---
 10 files changed, 127 insertions(+), 44 deletions(-)
 create mode 100644 mlir/include/mlir/Dialect/LLVMIR/Transforms/AddComdats.h
 create mode 100644 mlir/lib/Dialect/LLVMIR/Transforms/AddComdats.cpp

diff --git a/mlir/include/mlir/Conversion/LLVMCommon/LoweringOptions.h b/mlir/include/mlir/Conversion/LLVMCommon/LoweringOptions.h
index 677fdbb3e973df2..cc4e17e9527f01e 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/LoweringOptions.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/LoweringOptions.h
@@ -34,7 +34,6 @@ class LowerToLLVMOptions {
 
   bool useBarePtrCallConv = false;
   bool useOpaquePointers = true;
-  bool addComdatToLinkonceFuncs = false;
 
   enum class AllocLowering {
     /// Use malloc for heap allocations.
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 6bcffa32a7010be..8094ed29b4bdc9e 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -381,9 +381,6 @@ def ConvertFuncToLLVMPass : Pass<"convert-func-to-llvm", "ModuleOp"> {
     Option<"useOpaquePointers", "use-opaque-pointers", "bool",
                        /*default=*/"true", "Generate LLVM IR using opaque pointers "
                        "instead of typed pointers">,
-    Option<"addComdatToLinkonceFuncs", "add-comdat-to-linkonce-funcs", "bool",
-                       /*default=*/"false", "Add an any comdat selector to Linkonce "
-                       "functions">,
   ];
 }
 
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
index d4a9e911f36b145..2443b23e42e43ce 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
@@ -35,7 +35,7 @@ def CConvAttr : LLVM_Attr<"CConv", "cconv"> {
 
 def ComdatAttr : LLVM_Attr<"Comdat", "comdat"> {
   let parameters = (ins "comdat::Comdat":$comdat);
-  let assemblyFormat = "`<` $comdat `>`";
+  let assemblyFormat = "$comdat";
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/LLVMIR/Transforms/AddComdats.h b/mlir/include/mlir/Dialect/LLVMIR/Transforms/AddComdats.h
new file mode 100644
index 000000000000000..4fee1a64a788105
--- /dev/null
+++ b/mlir/include/mlir/Dialect/LLVMIR/Transforms/AddComdats.h
@@ -0,0 +1,29 @@
+//===- AddComdats.h - Prepare for translation to LLVM IR -*- C++ -*--------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_LLVMIR_TRANSFORMS_ADDCOMDATS_H
+#define MLIR_DIALECT_LLVMIR_TRANSFORMS_ADDCOMDATS_H
+
+#include <memory>
+
+namespace mlir {
+
+class Pass;
+
+namespace LLVM {
+
+#define GEN_PASS_DECL_LLVMADDCOMDATS
+#include "mlir/Dialect/LLVMIR/Transforms/Passes.h.inc"
+
+/// Creates a pass that adds Comdats to any functions with linkonce or
+/// linkonce_odr linkage
+std::unique_ptr<Pass> createAddComdatsPass();
+} // namespace LLVM
+} // namespace mlir
+
+#endif // MLIR_DIALECT_LLVMIR_TRANSFORMS_ADDCOMDATS_H
diff --git a/mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.h b/mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.h
index 7e61bd2419d6509..13e10b29c0743ca 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.h
@@ -9,6 +9,7 @@
 #ifndef MLIR_DIALECT_LLVMIR_TRANSFORMS_PASSES_H
 #define MLIR_DIALECT_LLVMIR_TRANSFORMS_PASSES_H
 
+#include "mlir/Dialect/LLVMIR/Transforms/AddComdats.h"
 #include "mlir/Dialect/LLVMIR/Transforms/LegalizeForExport.h"
 #include "mlir/Dialect/LLVMIR/Transforms/OptimizeForNVVM.h"
 #include "mlir/Dialect/LLVMIR/Transforms/RequestCWrappers.h"
diff --git a/mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.td b/mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.td
index b7dfc8656fc1fd5..82f3bf2265a1001 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.td
@@ -11,6 +11,11 @@
 
 include "mlir/Pass/PassBase.td"
 
+def LLVMAddComdats : Pass<"llvm-add-comdats", "::mlir::ModuleOp"> {
+  let summary = "Add comdats to linkonce and linkonce_odr functions";
+  let constructor = "::mlir::LLVM::createAddComdatsPass()";
+}
+
 def LLVMLegalizeForExport : Pass<"llvm-legalize-for-export"> {
   let summary = "Legalize LLVM dialect to be convertible to LLVM IR";
   let constructor = "::mlir::LLVM::createLegalizeForExportPass()";
diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
index 580b4796fbe55f9..1db463c0ab7163b 100644
--- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
+++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
@@ -328,25 +328,6 @@ static void modifyFuncOpToUseBarePtrCallingConv(
 
 namespace {
 
-template <typename Op>
-void addComdat(Op &op, mlir::PatternRewriter &rewriter,
-               mlir::ModuleOp &module) {
-  const char *comdatName = "__llvm_comdat";
-  mlir::LLVM::ComdatOp comdatOp =
-      module.lookupSymbol<mlir::LLVM::ComdatOp>(comdatName);
-  if (!comdatOp) {
-    comdatOp =
-        rewriter.create<mlir::LLVM::ComdatOp>(module.getLoc(), comdatName);
-  }
-  mlir::OpBuilder::InsertionGuard guard(rewriter);
-  rewriter.setInsertionPointToEnd(&comdatOp.getBody().back());
-  auto selectorOp = rewriter.create<mlir::LLVM::ComdatSelectorOp>(
-      comdatOp.getLoc(), op.getSymName(), mlir::LLVM::comdat::Comdat::Any);
-  op.setComdatAttr(mlir::SymbolRefAttr::get(
-      rewriter.getContext(), comdatName,
-      mlir::FlatSymbolRefAttr::get(selectorOp.getSymNameAttr())));
-}
-
 struct FuncOpConversionBase : public ConvertOpToLLVMPattern<func::FuncOp> {
 protected:
   using ConvertOpToLLVMPattern<func::FuncOp>::ConvertOpToLLVMPattern;
@@ -480,13 +461,6 @@ struct FuncOpConversionBase : public ConvertOpToLLVMPattern<func::FuncOp> {
                                          "region types conversion failed");
     }
 
-    if (getTypeConverter()->getOptions().addComdatToLinkonceFuncs &&
-        (newFuncOp.getLinkage() == LLVM::Linkage::Linkonce ||
-         newFuncOp.getLinkage() == LLVM::Linkage::LinkonceODR)) {
-      auto module = newFuncOp->getParentOfType<mlir::ModuleOp>();
-      addComdat(newFuncOp, rewriter, module);
-    }
-
     return newFuncOp;
   }
 };
@@ -790,7 +764,6 @@ struct ConvertFuncToLLVMPass
       options.overrideIndexBitwidth(indexBitwidth);
     options.dataLayout = llvm::DataLayout(this->dataLayout);
     options.useOpaquePointers = useOpaquePointers;
-    options.addComdatToLinkonceFuncs = addComdatToLinkonceFuncs;
 
     LLVMTypeConverter typeConverter(&getContext(), options,
                                     &dataLayoutAnalysis);
diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/AddComdats.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/AddComdats.cpp
new file mode 100644
index 000000000000000..5a38af0172214e2
--- /dev/null
+++ b/mlir/lib/Dialect/LLVMIR/Transforms/AddComdats.cpp
@@ -0,0 +1,78 @@
+//===- AddComdats.cpp - Prepare for translation to LLVM IR ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/LLVMIR/Transforms/AddComdats.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+namespace LLVM {
+#define GEN_PASS_DEF_LLVMADDCOMDATS
+#include "mlir/Dialect/LLVMIR/Transforms/Passes.h.inc"
+} // namespace LLVM
+} // namespace mlir
+
+using namespace mlir;
+
+static void addComdat(LLVM::LLVMFuncOp &op, PatternRewriter &rewriter,
+                      ModuleOp &module) {
+  const char *comdatName = "__llvm_comdat";
+  mlir::LLVM::ComdatOp comdatOp =
+      module.lookupSymbol<mlir::LLVM::ComdatOp>(comdatName);
+  if (!comdatOp) {
+    PatternRewriter::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPointToStart(module.getBody());
+    comdatOp =
+        rewriter.create<mlir::LLVM::ComdatOp>(module.getLoc(), comdatName);
+  }
+
+  PatternRewriter::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPointToStart(&comdatOp.getBody().back());
+  auto selectorOp = rewriter.create<mlir::LLVM::ComdatSelectorOp>(
+      comdatOp.getLoc(), op.getSymName(), mlir::LLVM::comdat::Comdat::Any);
+  rewriter.updateRootInPlace(op, [&]() {
+    op.setComdatAttr(mlir::SymbolRefAttr::get(
+        rewriter.getContext(), comdatName,
+        mlir::FlatSymbolRefAttr::get(selectorOp.getSymNameAttr())));
+  });
+}
+
+struct AddComdat : public OpRewritePattern<LLVM::LLVMFuncOp> {
+  using OpRewritePattern<LLVM::LLVMFuncOp>::OpRewritePattern;
+
+private:
+  LogicalResult matchAndRewrite(LLVM::LLVMFuncOp op,
+                                PatternRewriter &rewriter) const override {
+    if (op.getComdat() || (op.getLinkage() != LLVM::Linkage::Linkonce &&
+                           op.getLinkage() != LLVM::Linkage::LinkonceODR))
+      return failure();
+
+    auto mod = op->getParentOfType<mlir::ModuleOp>();
+    addComdat(op, rewriter, mod);
+    return success();
+  }
+};
+
+namespace {
+struct AddComdatsPass : public LLVM::impl::LLVMAddComdatsBase<AddComdatsPass> {
+  void runOnOperation() override {
+    MLIRContext *ctx = &getContext();
+    RewritePatternSet patterns(ctx);
+    patterns.add<AddComdat>(ctx);
+    if (failed(
+            applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
+      return signalPassFailure();
+  }
+};
+} // namespace
+
+std::unique_ptr<Pass> LLVM::createAddComdatsPass() {
+  return std::make_unique<AddComdatsPass>();
+}
diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/CMakeLists.txt b/mlir/lib/Dialect/LLVMIR/Transforms/CMakeLists.txt
index fac33b29a511c81..47a2a251bf3e8b2 100644
--- a/mlir/lib/Dialect/LLVMIR/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/LLVMIR/Transforms/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_mlir_dialect_library(MLIRLLVMIRTransforms
+  AddComdats.cpp
   DIScopeForLLVMFuncOp.cpp
   LegalizeForExport.cpp
   OptimizeForNVVM.cpp
diff --git a/mlir/test/Conversion/FuncToLLVM/add-linkonce-comdat.mlir b/mlir/test/Conversion/FuncToLLVM/add-linkonce-comdat.mlir
index 95cd8b6dd79de65..01ebd6650496a87 100644
--- a/mlir/test/Conversion/FuncToLLVM/add-linkonce-comdat.mlir
+++ b/mlir/test/Conversion/FuncToLLVM/add-linkonce-comdat.mlir
@@ -1,17 +1,17 @@
-// RUN: mlir-opt -convert-func-to-llvm='add-comdat-to-linkonce-funcs=1' -split-input-file -verify-diagnostics %s | FileCheck %s
+// RUN: mlir-opt -llvm-add-comdats -verify-diagnostics %s | FileCheck %s
 
-// CHECK-DAG: llvm.func linkonce @linkonce() comdat(@__llvm_comdat::@linkonce)
-func.func @linkonce() attributes {llvm.linkage = #llvm.linkage<linkonce>} {
-  return
-}
-
-// CHECK-DAG: llvm.comdat @__llvm_comdat {
-// CHECK: llvm.comdat_selector @linkonce any
-// CHECK: llvm.comdat_selector @linkonce_odr any
+// CHECK: llvm.comdat @__llvm_comdat {
+// CHECK-DAG: llvm.comdat_selector @linkonce any
+// CHECK-DAG: llvm.comdat_selector @linkonce_odr any
 // CHECK: }
 
+// CHECK: llvm.func linkonce @linkonce() comdat(@__llvm_comdat::@linkonce)
+llvm.func linkonce @linkonce() {
+  llvm.return
+}
 
-// CHECK-DAG: llvm.func linkonce_odr @linkonce_odr() comdat(@__llvm_comdat::@linkonce_odr)
-func.func @linkonce_odr() attributes {llvm.linkage = #llvm.linkage<linkonce_odr>} {
-  return
+// CHECK: llvm.func linkonce_odr @linkonce_odr() comdat(@__llvm_comdat::@linkonce_odr)
+llvm.func linkonce_odr @linkonce_odr() {
+  llvm.return
 }
+

>From d482ac24a21299798d2cbc1e9e2655cf6a27a636 Mon Sep 17 00:00:00 2001
From: David Truby <david at truby.dev>
Date: Fri, 8 Sep 2023 13:41:38 +0100
Subject: [PATCH 3/6] Use simple traversal instead of pattern rewriter.

---
 .../Dialect/LLVMIR/Transforms/AddComdats.h    |  3 -
 .../mlir/Dialect/LLVMIR/Transforms/Passes.td  | 10 +++-
 .../Dialect/LLVMIR/Transforms/AddComdats.cpp  | 60 ++++++-------------
 .../LLVMIR}/add-linkonce-comdat.mlir          |  0
 4 files changed, 28 insertions(+), 45 deletions(-)
 rename mlir/test/{Conversion/FuncToLLVM => Dialect/LLVMIR}/add-linkonce-comdat.mlir (100%)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/Transforms/AddComdats.h b/mlir/include/mlir/Dialect/LLVMIR/Transforms/AddComdats.h
index 4fee1a64a788105..1edeb0928391665 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/Transforms/AddComdats.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/Transforms/AddComdats.h
@@ -20,9 +20,6 @@ namespace LLVM {
 #define GEN_PASS_DECL_LLVMADDCOMDATS
 #include "mlir/Dialect/LLVMIR/Transforms/Passes.h.inc"
 
-/// Creates a pass that adds Comdats to any functions with linkonce or
-/// linkonce_odr linkage
-std::unique_ptr<Pass> createAddComdatsPass();
 } // namespace LLVM
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.td b/mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.td
index 82f3bf2265a1001..6ebbd08acfc431d 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.td
@@ -13,7 +13,15 @@ include "mlir/Pass/PassBase.td"
 
 def LLVMAddComdats : Pass<"llvm-add-comdats", "::mlir::ModuleOp"> {
   let summary = "Add comdats to linkonce and linkonce_odr functions";
-  let constructor = "::mlir::LLVM::createAddComdatsPass()";
+  let description = [{
+    Add an any COMDAT to every linkonce and linkonce_odr function.
+    This is necessary on Windows to link these functions as the system
+    linker won't link weak symbols without a COMDAT. It also provides better
+    behavior than standard weak symbols on ELF-based platforms.
+    This pass will still add COMDATs on platforms that do not support them,
+    for example macOS, so it should only be run when the target platform
+    supports COMDATs.
+  }];
 }
 
 def LLVMLegalizeForExport : Pass<"llvm-legalize-for-export"> {
diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/AddComdats.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/AddComdats.cpp
index 5a38af0172214e2..b5d208d157bd7a4 100644
--- a/mlir/lib/Dialect/LLVMIR/Transforms/AddComdats.cpp
+++ b/mlir/lib/Dialect/LLVMIR/Transforms/AddComdats.cpp
@@ -1,4 +1,4 @@
-//===- AddComdats.cpp - Prepare for translation to LLVM IR ----------------===//
+//===- AddComdats.cpp - Add comdats to linkonce functions -----------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -8,9 +8,7 @@
 
 #include "mlir/Dialect/LLVMIR/Transforms/AddComdats.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
-#include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 namespace mlir {
 namespace LLVM {
@@ -21,58 +19,38 @@ namespace LLVM {
 
 using namespace mlir;
 
-static void addComdat(LLVM::LLVMFuncOp &op, PatternRewriter &rewriter,
+static void addComdat(LLVM::LLVMFuncOp &op, OpBuilder &builder,
                       ModuleOp &module) {
   const char *comdatName = "__llvm_comdat";
   mlir::LLVM::ComdatOp comdatOp =
       module.lookupSymbol<mlir::LLVM::ComdatOp>(comdatName);
   if (!comdatOp) {
-    PatternRewriter::InsertionGuard guard(rewriter);
-    rewriter.setInsertionPointToStart(module.getBody());
+    PatternRewriter::InsertionGuard guard(builder);
+    builder.setInsertionPointToStart(module.getBody());
     comdatOp =
-        rewriter.create<mlir::LLVM::ComdatOp>(module.getLoc(), comdatName);
+        builder.create<mlir::LLVM::ComdatOp>(module.getLoc(), comdatName);
   }
 
-  PatternRewriter::InsertionGuard guard(rewriter);
-  rewriter.setInsertionPointToStart(&comdatOp.getBody().back());
-  auto selectorOp = rewriter.create<mlir::LLVM::ComdatSelectorOp>(
+  PatternRewriter::InsertionGuard guard(builder);
+  builder.setInsertionPointToStart(&comdatOp.getBody().back());
+  auto selectorOp = builder.create<mlir::LLVM::ComdatSelectorOp>(
       comdatOp.getLoc(), op.getSymName(), mlir::LLVM::comdat::Comdat::Any);
-  rewriter.updateRootInPlace(op, [&]() {
-    op.setComdatAttr(mlir::SymbolRefAttr::get(
-        rewriter.getContext(), comdatName,
-        mlir::FlatSymbolRefAttr::get(selectorOp.getSymNameAttr())));
-  });
+  op.setComdatAttr(mlir::SymbolRefAttr::get(
+      builder.getContext(), comdatName,
+      mlir::FlatSymbolRefAttr::get(selectorOp.getSymNameAttr())));
 }
 
-struct AddComdat : public OpRewritePattern<LLVM::LLVMFuncOp> {
-  using OpRewritePattern<LLVM::LLVMFuncOp>::OpRewritePattern;
-
-private:
-  LogicalResult matchAndRewrite(LLVM::LLVMFuncOp op,
-                                PatternRewriter &rewriter) const override {
-    if (op.getComdat() || (op.getLinkage() != LLVM::Linkage::Linkonce &&
-                           op.getLinkage() != LLVM::Linkage::LinkonceODR))
-      return failure();
-
-    auto mod = op->getParentOfType<mlir::ModuleOp>();
-    addComdat(op, rewriter, mod);
-    return success();
-  }
-};
-
 namespace {
 struct AddComdatsPass : public LLVM::impl::LLVMAddComdatsBase<AddComdatsPass> {
   void runOnOperation() override {
-    MLIRContext *ctx = &getContext();
-    RewritePatternSet patterns(ctx);
-    patterns.add<AddComdat>(ctx);
-    if (failed(
-            applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
-      return signalPassFailure();
+    OpBuilder builder{&getContext()};
+    ModuleOp mod = getOperation();
+    mod.walk([&](LLVM::LLVMFuncOp op) {
+      if (op.getLinkage() == LLVM::Linkage::Linkonce ||
+          op.getLinkage() == LLVM::Linkage::LinkonceODR) {
+        addComdat(op, builder, mod);
+      }
+    });
   }
 };
 } // namespace
-
-std::unique_ptr<Pass> LLVM::createAddComdatsPass() {
-  return std::make_unique<AddComdatsPass>();
-}
diff --git a/mlir/test/Conversion/FuncToLLVM/add-linkonce-comdat.mlir b/mlir/test/Dialect/LLVMIR/add-linkonce-comdat.mlir
similarity index 100%
rename from mlir/test/Conversion/FuncToLLVM/add-linkonce-comdat.mlir
rename to mlir/test/Dialect/LLVMIR/add-linkonce-comdat.mlir

>From 675715d9129337362542dfdfcc8486dfb2e1d0ba Mon Sep 17 00:00:00 2001
From: David Truby <david at truby.dev>
Date: Fri, 8 Sep 2023 14:34:49 +0100
Subject: [PATCH 4/6] Remove unrelated change to Conversion/Passes.td

---
 mlir/include/mlir/Conversion/Passes.td | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 8094ed29b4bdc9e..eaf016bde69e3be 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -760,12 +760,12 @@ def ConvertMemRefToSPIRV : Pass<"convert-memref-to-spirv"> {
 def ConvertNVVMToLLVMPass : Pass<"convert-nvvm-to-llvm"> {
   let summary = "Convert NVVM dialect to LLVM dialect";
   let description = [{
-    This pass generates inline assembly for the NVVM ops which is not
+    This pass generates inline assembly for the NVVM ops which is not 
     implemented in LLVM core.
   }];
   let dependentDialects = [
     "NVVM::NVVMDialect",
-  ];
+  ];  
 }
 
 //===----------------------------------------------------------------------===//

>From fcae58c3e8469cc9bdeaca89e40d1f291d22e029 Mon Sep 17 00:00:00 2001
From: David Truby <david at truby.dev>
Date: Mon, 11 Sep 2023 15:20:08 +0100
Subject: [PATCH 5/6] Build SymbolTable once and keep it updated during pass

---
 .../mlir/Dialect/LLVMIR/Transforms/AddComdats.h      |  2 +-
 mlir/lib/Dialect/LLVMIR/Transforms/AddComdats.cpp    | 12 +++++++-----
 2 files changed, 8 insertions(+), 6 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/Transforms/AddComdats.h b/mlir/include/mlir/Dialect/LLVMIR/Transforms/AddComdats.h
index 1edeb0928391665..a7bc1a1d286deda 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/Transforms/AddComdats.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/Transforms/AddComdats.h
@@ -1,4 +1,4 @@
-//===- AddComdats.h - Prepare for translation to LLVM IR -*- C++ -*--------===//
+//===- AddComdats.h - Add comdats to linkonce functions -*- C++ -*---------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/AddComdats.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/AddComdats.cpp
index b5d208d157bd7a4..99dcedaf725bf65 100644
--- a/mlir/lib/Dialect/LLVMIR/Transforms/AddComdats.cpp
+++ b/mlir/lib/Dialect/LLVMIR/Transforms/AddComdats.cpp
@@ -20,15 +20,16 @@ namespace LLVM {
 using namespace mlir;
 
 static void addComdat(LLVM::LLVMFuncOp &op, OpBuilder &builder,
-                      ModuleOp &module) {
+                      SymbolTable &symbolTable, ModuleOp &module) {
   const char *comdatName = "__llvm_comdat";
   mlir::LLVM::ComdatOp comdatOp =
-      module.lookupSymbol<mlir::LLVM::ComdatOp>(comdatName);
+      symbolTable.lookup<mlir::LLVM::ComdatOp>(comdatName);
   if (!comdatOp) {
     PatternRewriter::InsertionGuard guard(builder);
     builder.setInsertionPointToStart(module.getBody());
     comdatOp =
         builder.create<mlir::LLVM::ComdatOp>(module.getLoc(), comdatName);
+    symbolTable.insert(comdatOp);
   }
 
   PatternRewriter::InsertionGuard guard(builder);
@@ -45,12 +46,13 @@ struct AddComdatsPass : public LLVM::impl::LLVMAddComdatsBase<AddComdatsPass> {
   void runOnOperation() override {
     OpBuilder builder{&getContext()};
     ModuleOp mod = getOperation();
-    mod.walk([&](LLVM::LLVMFuncOp op) {
+    SymbolTable symbolTable{mod};
+    for (auto op : mod.getBody()->getOps<LLVM::LLVMFuncOp>()) {
       if (op.getLinkage() == LLVM::Linkage::Linkonce ||
           op.getLinkage() == LLVM::Linkage::LinkonceODR) {
-        addComdat(op, builder, mod);
+        addComdat(op, builder, symbolTable, mod);
       }
-    });
+    }
   }
 };
 } // namespace

>From 361e8f66a2a7492aa1db30a1eddbe7bfdf8982d4 Mon Sep 17 00:00:00 2001
From: David Truby <David.Truby at arm.com>
Date: Tue, 12 Sep 2023 13:28:45 +0000
Subject: [PATCH 6/6] Make symbol table creation lazy

---
 mlir/lib/Dialect/LLVMIR/Transforms/AddComdats.cpp | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/AddComdats.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/AddComdats.cpp
index 99dcedaf725bf65..6fbb0d24826d001 100644
--- a/mlir/lib/Dialect/LLVMIR/Transforms/AddComdats.cpp
+++ b/mlir/lib/Dialect/LLVMIR/Transforms/AddComdats.cpp
@@ -46,11 +46,17 @@ struct AddComdatsPass : public LLVM::impl::LLVMAddComdatsBase<AddComdatsPass> {
   void runOnOperation() override {
     OpBuilder builder{&getContext()};
     ModuleOp mod = getOperation();
-    SymbolTable symbolTable{mod};
+
+    std::unique_ptr<SymbolTable> symbolTable;
+    auto getSymTab = [&]() -> SymbolTable & {
+      if (!symbolTable)
+        symbolTable = std::make_unique<SymbolTable>(mod);
+      return *symbolTable;
+    };
     for (auto op : mod.getBody()->getOps<LLVM::LLVMFuncOp>()) {
       if (op.getLinkage() == LLVM::Linkage::Linkonce ||
           op.getLinkage() == LLVM::Linkage::LinkonceODR) {
-        addComdat(op, builder, symbolTable, mod);
+        addComdat(op, builder, getSymTab(), mod);
       }
     }
   }



More information about the Mlir-commits mailing list