[Mlir-commits] [mlir] [mlir] Speed up FuncToLLVM: CallOpLowering using SymbolTableCollection (PR #68082)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Oct 3 08:24:37 PDT 2023


https://github.com/tdanyluk updated https://github.com/llvm/llvm-project/pull/68082

>From 63c17005b71a1c26e91260232456090a54f622b1 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Tam=C3=A1s=20Danyluk?= <tdanyluk at google.com>
Date: Tue, 3 Oct 2023 11:07:20 +0200
Subject: [PATCH] [mlir] Speed up FuncToLLVM: CallOpLowering using SymbolTableCollection

CallOpLowering resolved every callee with SymbolTable::lookupNearestSymbolFrom,
which scans the nearest symbol table linearly, so lowering all calls cost
O(#calls * #symbols). ConvertFuncToLLVMPass now prebuilds the symbol tables in
a SymbolTableCollection and hands it to CallOpLowering, which resolves callees
through the cached tables. Because lookups still go through the nearest symbol
table, calls inside nested symbol tables (e.g. nested modules) resolve exactly
as before.

populateFuncToLLVMConversionPatterns keeps the uncached lookup.

In one of our projects, caching these lookups saves 23% of the compilation
time.
---
 mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp | 64 +++++++++++++++----
 1 file changed, 51 insertions(+), 13 deletions(-)

diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
index 7de7f3cb9e36b06..ec8811c976aa6e2 100644
--- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
+++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
@@ -601,19 +601,39 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
   }
 };
 
-struct CallOpLowering : public CallOpInterfaceLowering<func::CallOp> {
-  using Super::Super;
+class CallOpLowering : public CallOpInterfaceLowering<func::CallOp> {
+public:
+  /// `symbolTables` optionally caches the symbol tables of the IR being
+  /// converted; it is not owned and may be null. When it is null, callees are
+  /// resolved with the uncached SymbolTable::lookupNearestSymbolFrom, which
+  /// scans the nearest symbol table linearly.
+  CallOpLowering(const LLVMTypeConverter &typeConverter,
+                 SymbolTableCollection *symbolTables = nullptr,
+                 PatternBenefit benefit = 1)
+      : CallOpInterfaceLowering<func::CallOp>(typeConverter, benefit),
+        symbolTables(symbolTables) {}
 
   LogicalResult
   matchAndRewrite(func::CallOp callOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     bool useBarePtrCallConv = false;
-    if (Operation *callee = SymbolTable::lookupNearestSymbolFrom(
-            callOp, callOp.getCalleeAttr())) {
+    if (Operation *callee = lookupCallee(callOp)) {
       useBarePtrCallConv = shouldUseBarePtrCallConv(callee, getTypeConverter());
     }
     return matchAndRewriteImpl(callOp, adaptor, rewriter, useBarePtrCallConv);
   }
+
+private:
+  /// Resolves the callee of `callOp`, using the cached tables when available.
+  Operation *lookupCallee(func::CallOp callOp) const {
+    if (symbolTables)
+      return symbolTables->lookupNearestSymbolFrom(callOp,
+                                                   callOp.getCalleeAttr());
+    return SymbolTable::lookupNearestSymbolFrom(callOp, callOp.getCalleeAttr());
+  }
+
+  // Not owned; may be null.
+  SymbolTableCollection *symbolTables = nullptr;
 };
 
 struct CallIndirectOpLowering
@@ -728,16 +748,22 @@ void mlir::populateFuncToLLVMFuncOpConversionPattern(
   patterns.add<FuncOpConversion>(converter);
 }
 
+/// Implementation of populateFuncToLLVMConversionPatterns. `symbolTables` is
+/// not owned and may be null; when set, CallOpLowering resolves callees
+/// through it instead of rescanning the nearest symbol table for each call.
+static void populateFuncToLLVMConversionPatternsImpl(
+    LLVMTypeConverter &converter, RewritePatternSet &patterns,
+    SymbolTableCollection *symbolTables) {
+  populateFuncToLLVMFuncOpConversionPattern(converter, patterns);
+  patterns.add<CallIndirectOpLowering>(converter);
+  patterns.add<CallOpLowering>(converter, symbolTables);
+  patterns.add<ConstantOpLowering, ReturnOpLowering>(converter);
+}
+
 void mlir::populateFuncToLLVMConversionPatterns(LLVMTypeConverter &converter,
                                                 RewritePatternSet &patterns) {
-  populateFuncToLLVMFuncOpConversionPattern(converter, patterns);
-  // clang-format off
-  patterns.add<
-      CallIndirectOpLowering,
-      CallOpLowering,
-      ConstantOpLowering,
-      ReturnOpLowering>(converter);
-  // clang-format on
+  populateFuncToLLVMConversionPatternsImpl(converter, patterns,
+                                           /*symbolTables=*/nullptr);
 }
 
 namespace {
@@ -765,6 +791,17 @@ struct ConvertFuncToLLVMPass
 
     const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
 
+    // Build the symbol tables of `m` and of all nested symbol table ops up
+    // front. During the conversion, a func.func and the llvm.func replacing
+    // it can coexist under the same name until the conversion is committed,
+    // so a SymbolTable built lazily at that point would reject the duplicate.
+    // The cached entries keep resolving to the original func.func ops.
+    SymbolTableCollection symbolTables;
+    m->walk([&](Operation *op) {
+      if (op->hasTrait<OpTrait::SymbolTable>())
+        (void)symbolTables.getSymbolTable(op);
+    });
+
     LowerToLLVMOptions options(&getContext(),
                                dataLayoutAnalysis.getAtOrAbove(m));
     options.useBarePtrCallConv = useBarePtrCallConv;
@@ -777,7 +814,8 @@ struct ConvertFuncToLLVMPass
                                      &dataLayoutAnalysis);
 
     RewritePatternSet patterns(&getContext());
-    populateFuncToLLVMConversionPatterns(typeConverter, patterns);
+    populateFuncToLLVMConversionPatternsImpl(typeConverter, patterns,
+                                             &symbolTables);
 
     // TODO: Remove these in favor of their dedicated conversion passes.
     arith::populateArithToLLVMConversionPatterns(typeConverter, patterns);



More information about the Mlir-commits mailing list