[Mlir-commits] [mlir] [mlir] Speed up FuncToLLVM using a SymbolTable (PR #68082)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Oct 5 01:23:39 PDT 2023


https://github.com/tdanyluk updated https://github.com/llvm/llvm-project/pull/68082

>From d40486d947401cbd9c46c69b3da42944188b2cfb Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Tam=C3=A1s=20Danyluk?= <tdanyluk at google.com>
Date: Tue, 3 Oct 2023 11:07:20 +0200
Subject: [PATCH] [mlir] Speed up FuncToLLVM using a SymbolTable

Callee lookups during call lowering now go through the hash map of a
SymbolTable instead of a linear scan over the module. On one of our
projects this saves 23% of the compilation time.
---
 .../Conversion/FuncToLLVM/ConvertFuncToLLVM.h | 13 ++++-
 mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp | 56 ++++++++++++++-----
 2 files changed, 52 insertions(+), 17 deletions(-)

diff --git a/mlir/include/mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h b/mlir/include/mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h
index 142f77c976ffc3b..c9aa331d28e0e2b 100644
--- a/mlir/include/mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h
+++ b/mlir/include/mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h
@@ -19,6 +19,7 @@ namespace mlir {
 class DialectRegistry;
 class LLVMTypeConverter;
 class RewritePatternSet;
+class SymbolTable;
 
 /// Collect the default pattern to convert a FuncOp to the LLVM dialect. If
 /// `emitCWrappers` is set, the pattern will also produce functions
@@ -31,8 +32,16 @@ void populateFuncToLLVMFuncOpConversionPattern(LLVMTypeConverter &converter,
 /// conversion patterns capture the LLVMTypeConverter and the LowerToLLVMOptions
 /// by reference meaning the references have to remain alive during the entire
 /// pattern lifetime.
-void populateFuncToLLVMConversionPatterns(LLVMTypeConverter &converter,
-                                          RewritePatternSet &patterns);
+///
+/// The optional `symbolTable` speeds up the callee lookups done when lowering
+/// calls. Provide it only if the patterns are applied to a single module and,
+/// while the patterns are in use, no symbol tracked by the table is removed
+/// and no new symbol is added. With a table, the cumulative lookup cost is
+/// O(calls); without one (the default, nullptr), every call performs a linear
+/// lookup, for O(calls * functions) in total.
+void populateFuncToLLVMConversionPatterns(
+    LLVMTypeConverter &converter, RewritePatternSet &patterns,
+    const SymbolTable *symbolTable = nullptr);
 
 void registerConvertFuncToLLVMInterface(DialectRegistry &registry);
 
diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
index 7de7f3cb9e36b06..d52f01880282e1a 100644
--- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
+++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
@@ -33,6 +33,7 @@
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/IRMapping.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/SymbolTable.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Support/MathExtras.h"
@@ -48,6 +49,7 @@
 #include "llvm/Support/FormatVariadic.h"
 #include <algorithm>
 #include <functional>
+#include <optional>
 
 namespace mlir {
 #define GEN_PASS_DEF_CONVERTFUNCTOLLVMPASS
@@ -601,19 +603,38 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
   }
 };
 
-struct CallOpLowering : public CallOpInterfaceLowering<func::CallOp> {
-  using Super::Super;
+class CallOpLowering : public CallOpInterfaceLowering<func::CallOp> {
+public:
+  CallOpLowering(const LLVMTypeConverter &typeConverter,
+                 // May be nullptr; then every callee lookup is linear.
+                 const SymbolTable *symbolTable, PatternBenefit benefit = 1)
+      : CallOpInterfaceLowering<func::CallOp>(typeConverter, benefit),
+        symbolTable(symbolTable) {}

   LogicalResult
   matchAndRewrite(func::CallOp callOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     bool useBarePtrCallConv = false;
-    if (Operation *callee = SymbolTable::lookupNearestSymbolFrom(
-            callOp, callOp.getCalleeAttr())) {
-      useBarePtrCallConv = shouldUseBarePtrCallConv(callee, getTypeConverter());
+    if (getTypeConverter()->getOptions().useBarePtrCallConv) {
+      useBarePtrCallConv = true;
+    } else {
+      // The table only indexes its own op's symbols; calls nested in another
+      // symbol table (e.g. an inner module) must use the linear lookup.
+      bool useTable = symbolTable != nullptr &&
+                      SymbolTable::getNearestSymbolTable(callOp) ==
+                          symbolTable->getOp();
+      Operation *callee =
+          useTable ? symbolTable->lookup(callOp.getCalleeAttr().getValue())
+                   : SymbolTable::lookupNearestSymbolFrom(
+                         callOp, callOp.getCalleeAttr());
+      useBarePtrCallConv =
+          callee != nullptr && callee->hasAttr(barePtrAttrName);
     }
     return matchAndRewriteImpl(callOp, adaptor, rewriter, useBarePtrCallConv);
   }
+
+private:
+  const SymbolTable *symbolTable = nullptr;
 };
 
 struct CallIndirectOpLowering
@@ -728,16 +749,14 @@ void mlir::populateFuncToLLVMFuncOpConversionPattern(
   patterns.add<FuncOpConversion>(converter);
 }
 
-void mlir::populateFuncToLLVMConversionPatterns(LLVMTypeConverter &converter,
-                                                RewritePatternSet &patterns) {
+void mlir::populateFuncToLLVMConversionPatterns(
+    LLVMTypeConverter &converter, RewritePatternSet &patterns,
+    const SymbolTable *symbolTable) {
   populateFuncToLLVMFuncOpConversionPattern(converter, patterns);
-  // clang-format off
-  patterns.add<
-      CallIndirectOpLowering,
-      CallOpLowering,
-      ConstantOpLowering,
-      ReturnOpLowering>(converter);
-  // clang-format on
+  patterns.add<CallIndirectOpLowering, ConstantOpLowering, ReturnOpLowering>(
+      converter);
+  // CallOpLowering is the only pattern that consults `symbolTable`.
+  patterns.add<CallOpLowering>(converter, symbolTable);
 }
 
 namespace {
@@ -776,8 +795,15 @@ struct ConvertFuncToLLVMPass
     LLVMTypeConverter typeConverter(&getContext(), options,
                                     &dataLayoutAnalysis);
 
+    std::optional<SymbolTable> optSymbolTable = std::nullopt;
+    const SymbolTable *symbolTable = nullptr;
+    if (!options.useBarePtrCallConv) {
+      optSymbolTable.emplace(m);
+      symbolTable = &optSymbolTable.value();
+    }
+
     RewritePatternSet patterns(&getContext());
-    populateFuncToLLVMConversionPatterns(typeConverter, patterns);
+    populateFuncToLLVMConversionPatterns(typeConverter, patterns, symbolTable);
 
     // TODO: Remove these in favor of their dedicated conversion passes.
     arith::populateArithToLLVMConversionPatterns(typeConverter, patterns);



More information about the Mlir-commits mailing list