[Mlir-commits] [mlir] [mlir] Speed up FuncToLLVM: CallOpLowering using SymbolTableCollection (PR #68082)
llvmlistbot at llvm.org
Tue Oct 3 08:24:37 PDT 2023
https://github.com/tdanyluk updated https://github.com/llvm/llvm-project/pull/68082
From 63c17005b71a1c26e91260232456090a54f622b1 Mon Sep 17 00:00:00 2001
From: Tamás Danyluk <tdanyluk at google.com>
Date: Tue, 3 Oct 2023 11:07:20 +0200
Subject: [PATCH] [mlir] Speed up FuncToLLVM by using a hashmap
We have a project where this saves 23% of the compilation time.
---
mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp | 72 +++++++++++++++----
1 file changed, 59 insertions(+), 13 deletions(-)
diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
index 7de7f3cb9e36b06..ec8811c976aa6e2 100644
--- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
+++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
@@ -48,6 +48,7 @@
#include "llvm/Support/FormatVariadic.h"
#include <algorithm>
#include <functional>
+#include <unordered_map>
namespace mlir {
#define GEN_PASS_DEF_CONVERTFUNCTOLLVMPASS
@@ -601,19 +602,38 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
}
};
-struct CallOpLowering : public CallOpInterfaceLowering<func::CallOp> {
- using Super::Super;
+class CallOpLowering : public CallOpInterfaceLowering<func::CallOp> {
+public:
+ CallOpLowering(
+ const LLVMTypeConverter &typeConverter,
+ // Can be nullptr.
+ const std::unordered_map<std::string, bool> *hasBarePtrAttribute,
+ PatternBenefit benefit = 1)
+ : CallOpInterfaceLowering<func::CallOp>(typeConverter, benefit),
+ barePtrCallConvForcedByOptions(
+ typeConverter.getOptions().useBarePtrCallConv),
+ hasBarePtrAttribute(hasBarePtrAttribute) {}
LogicalResult
matchAndRewrite(func::CallOp callOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
bool useBarePtrCallConv = false;
- if (Operation *callee = SymbolTable::lookupNearestSymbolFrom(
+ if (barePtrCallConvForcedByOptions) {
+ useBarePtrCallConv = true;
+ } else if (hasBarePtrAttribute != nullptr) {
+ useBarePtrCallConv =
+ hasBarePtrAttribute->at(callOp.getCalleeAttr().getValue().str());
+ } else if ( // Warning: This is a linear lookup.
+ Operation *callee = SymbolTable::lookupNearestSymbolFrom(
callOp, callOp.getCalleeAttr())) {
- useBarePtrCallConv = shouldUseBarePtrCallConv(callee, getTypeConverter());
+ useBarePtrCallConv = callee->hasAttr(barePtrAttrName);
}
return matchAndRewriteImpl(callOp, adaptor, rewriter, useBarePtrCallConv);
}
+
+private:
+ bool barePtrCallConvForcedByOptions = false;
+ const std::unordered_map<std::string, bool> *hasBarePtrAttribute = nullptr;
};
struct CallIndirectOpLowering
@@ -728,16 +748,23 @@ void mlir::populateFuncToLLVMFuncOpConversionPattern(
patterns.add<FuncOpConversion>(converter);
}
+namespace {
+void populateFuncToLLVMConversionPatternsInternal(
+ LLVMTypeConverter &converter, RewritePatternSet &patterns,
+ // Can be nullptr.
+ const std::unordered_map<std::string, bool> *hasBarePtrAttribute) {
+ populateFuncToLLVMFuncOpConversionPattern(converter, patterns);
+ patterns.add<CallIndirectOpLowering>(converter);
+ patterns.add<CallOpLowering>(converter, hasBarePtrAttribute);
+ patterns.add<ConstantOpLowering>(converter);
+ patterns.add<ReturnOpLowering>(converter);
+}
+} // namespace
+
void mlir::populateFuncToLLVMConversionPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns) {
- populateFuncToLLVMFuncOpConversionPattern(converter, patterns);
- // clang-format off
- patterns.add<
- CallIndirectOpLowering,
- CallOpLowering,
- ConstantOpLowering,
- ReturnOpLowering>(converter);
- // clang-format on
+ populateFuncToLLVMConversionPatternsInternal(converter, patterns,
+ /*hasBarePtrAttribute=*/nullptr);
}
namespace {
@@ -765,6 +792,24 @@ struct ConvertFuncToLLVMPass
const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
+ std::unordered_map<std::string, bool> hasBarePtrAttribute;
+ for (Region &region : m->getRegions()) {
+ for (Block &block : region.getBlocks()) {
+ for (Operation &op : block.getOperations()) {
+ if (isa<func::FuncOp>(op)) {
+ auto funcOp = cast<func::FuncOp>(op);
+ bool inserted = hasBarePtrAttribute
+ .insert({funcOp.getSymName().str(),
+ funcOp->hasAttr(barePtrAttrName)})
+ .second;
+ (void)inserted;
+ assert(inserted &&
+ "expected module to contain uniquely named funcOps");
+ }
+ }
+ }
+ }
+
LowerToLLVMOptions options(&getContext(),
dataLayoutAnalysis.getAtOrAbove(m));
options.useBarePtrCallConv = useBarePtrCallConv;
@@ -777,7 +822,8 @@ struct ConvertFuncToLLVMPass
&dataLayoutAnalysis);
RewritePatternSet patterns(&getContext());
- populateFuncToLLVMConversionPatterns(typeConverter, patterns);
+ populateFuncToLLVMConversionPatternsInternal(typeConverter, patterns,
+ &hasBarePtrAttribute);
// TODO: Remove these in favor of their dedicated conversion passes.
arith::populateArithToLLVMConversionPatterns(typeConverter, patterns);
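
(Not part of the patch: a minimal standalone sketch of the caching idea for readers skimming the diff. It precomputes a callee-name -> bare-ptr flag map once per module, so each func.call query becomes a hash lookup instead of a symbol-table walk per call. Here `kBarePtrAttr` is a placeholder for the attribute name the pass checks, and `llvm::StringMap` is used purely for illustration where the patch uses `std::unordered_map<std::string, bool>`.)

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "llvm/ADT/StringMap.h"

// Build the callee-name -> bare-ptr flag map once per module.
static llvm::StringMap<bool> collectBarePtrFlags(mlir::ModuleOp module,
                                                 llvm::StringRef kBarePtrAttr) {
  llvm::StringMap<bool> flags;
  for (auto funcOp : module.getOps<mlir::func::FuncOp>())
    flags[funcOp.getSymName()] = funcOp->hasAttr(kBarePtrAttr);
  return flags;
}

// Answer a per-call query from the precomputed map; this sketch simply
// defaults to false when the callee is not a func.func in the module.
static bool calleeUsesBarePtr(const llvm::StringMap<bool> &flags,
                              mlir::func::CallOp callOp) {
  auto it = flags.find(callOp.getCallee());
  return it != flags.end() && it->second;
}

As in the patch, this relies on func.func symbol names being unique within the module, which the assert in the pass also checks.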