[Mlir-commits] [mlir] NFC: resolve TODO in LLVM dialect conversions (PR #91497)
Christopher Bate
llvmlistbot at llvm.org
Thu May 30 11:47:39 PDT 2024
https://github.com/christopherbate updated https://github.com/llvm/llvm-project/pull/91497
>From d7b01139663b95e77d364a577dd79a61c87c2de1 Mon Sep 17 00:00:00 2001
From: Christopher Bate <cbate at nvidia.com>
Date: Wed, 8 May 2024 09:49:38 -0600
Subject: [PATCH] NFC: resolve TODO in LLVM dialect conversions
Relaxes the restriction that certain public utility functions apply only
to the builtin ModuleOp: they now take an `Operation *` that has the
`SymbolTable` trait.
---
.../mlir/Dialect/LLVMIR/FunctionCallUtils.h | 38 +++++++--------
.../MemRefToLLVM/AllocLikeConversion.cpp | 16 +++----
.../Dialect/LLVMIR/IR/FunctionCallUtils.cpp | 46 ++++++++++---------
3 files changed, 51 insertions(+), 49 deletions(-)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
index 123ce36cb0a79..852490cf7428f 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
@@ -33,36 +33,36 @@ class LLVMFuncOp;
/// external C function calls. The list of functions provided here must be
/// implemented separately (e.g. as part of a support runtime library or as part
/// of the libc).
-LLVM::LLVMFuncOp lookupOrCreatePrintI64Fn(ModuleOp moduleOp);
-LLVM::LLVMFuncOp lookupOrCreatePrintU64Fn(ModuleOp moduleOp);
-LLVM::LLVMFuncOp lookupOrCreatePrintF16Fn(ModuleOp moduleOp);
-LLVM::LLVMFuncOp lookupOrCreatePrintBF16Fn(ModuleOp moduleOp);
-LLVM::LLVMFuncOp lookupOrCreatePrintF32Fn(ModuleOp moduleOp);
-LLVM::LLVMFuncOp lookupOrCreatePrintF64Fn(ModuleOp moduleOp);
+LLVM::LLVMFuncOp lookupOrCreatePrintI64Fn(Operation *moduleOp);
+LLVM::LLVMFuncOp lookupOrCreatePrintU64Fn(Operation *moduleOp);
+LLVM::LLVMFuncOp lookupOrCreatePrintF16Fn(Operation *moduleOp);
+LLVM::LLVMFuncOp lookupOrCreatePrintBF16Fn(Operation *moduleOp);
+LLVM::LLVMFuncOp lookupOrCreatePrintF32Fn(Operation *moduleOp);
+LLVM::LLVMFuncOp lookupOrCreatePrintF64Fn(Operation *moduleOp);
/// Declares a function to print a C-string.
/// If a custom runtime function is defined via `runtimeFunctionName`, it must
/// have the signature void(char const*). The default function is `printString`.
LLVM::LLVMFuncOp
-lookupOrCreatePrintStringFn(ModuleOp moduleOp,
+lookupOrCreatePrintStringFn(Operation *moduleOp,
std::optional<StringRef> runtimeFunctionName = {});
-LLVM::LLVMFuncOp lookupOrCreatePrintOpenFn(ModuleOp moduleOp);
-LLVM::LLVMFuncOp lookupOrCreatePrintCloseFn(ModuleOp moduleOp);
-LLVM::LLVMFuncOp lookupOrCreatePrintCommaFn(ModuleOp moduleOp);
-LLVM::LLVMFuncOp lookupOrCreatePrintNewlineFn(ModuleOp moduleOp);
-LLVM::LLVMFuncOp lookupOrCreateMallocFn(ModuleOp moduleOp, Type indexType);
-LLVM::LLVMFuncOp lookupOrCreateAlignedAllocFn(ModuleOp moduleOp,
+LLVM::LLVMFuncOp lookupOrCreatePrintOpenFn(Operation *moduleOp);
+LLVM::LLVMFuncOp lookupOrCreatePrintCloseFn(Operation *moduleOp);
+LLVM::LLVMFuncOp lookupOrCreatePrintCommaFn(Operation *moduleOp);
+LLVM::LLVMFuncOp lookupOrCreatePrintNewlineFn(Operation *moduleOp);
+LLVM::LLVMFuncOp lookupOrCreateMallocFn(Operation *moduleOp, Type indexType);
+LLVM::LLVMFuncOp lookupOrCreateAlignedAllocFn(Operation *moduleOp,
Type indexType);
-LLVM::LLVMFuncOp lookupOrCreateFreeFn(ModuleOp moduleOp);
-LLVM::LLVMFuncOp lookupOrCreateGenericAllocFn(ModuleOp moduleOp,
+LLVM::LLVMFuncOp lookupOrCreateFreeFn(Operation *moduleOp);
+LLVM::LLVMFuncOp lookupOrCreateGenericAllocFn(Operation *moduleOp,
Type indexType);
-LLVM::LLVMFuncOp lookupOrCreateGenericAlignedAllocFn(ModuleOp moduleOp,
+LLVM::LLVMFuncOp lookupOrCreateGenericAlignedAllocFn(Operation *moduleOp,
Type indexType);
-LLVM::LLVMFuncOp lookupOrCreateGenericFreeFn(ModuleOp moduleOp);
-LLVM::LLVMFuncOp lookupOrCreateMemRefCopyFn(ModuleOp moduleOp, Type indexType,
+LLVM::LLVMFuncOp lookupOrCreateGenericFreeFn(Operation *moduleOp);
+LLVM::LLVMFuncOp lookupOrCreateMemRefCopyFn(Operation *moduleOp, Type indexType,
Type unrankedDescriptorType);
/// Create a FuncOp with signature `resultType`(`paramTypes`)` and name `name`.
-LLVM::LLVMFuncOp lookupOrCreateFn(ModuleOp moduleOp, StringRef name,
+LLVM::LLVMFuncOp lookupOrCreateFn(Operation *moduleOp, StringRef name,
ArrayRef<Type> paramTypes = {},
Type resultType = {}, bool isVarArg = false);
diff --git a/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp b/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
index b29abc94ce400..e48ca5180b706 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
@@ -10,18 +10,14 @@
#include "mlir/Analysis/DataLayoutAnalysis.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/IR/SymbolTable.h"
using namespace mlir;
namespace {
-// TODO: Fix the LLVM utilities for looking up functions to take Operation*
-// with SymbolTable trait instead of ModuleOp and make similar change here. This
-// allows call sites to use getParentWithTrait<OpTrait::SymbolTable> instead
-// of getParentOfType<ModuleOp> to pass down the operation.
LLVM::LLVMFuncOp getNotalignedAllocFn(const LLVMTypeConverter *typeConverter,
- ModuleOp module, Type indexType) {
+ Operation *module, Type indexType) {
bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
-
if (useGenericFn)
return LLVM::lookupOrCreateGenericAllocFn(module, indexType);
@@ -29,7 +25,7 @@ LLVM::LLVMFuncOp getNotalignedAllocFn(const LLVMTypeConverter *typeConverter,
}
LLVM::LLVMFuncOp getAlignedAllocFn(const LLVMTypeConverter *typeConverter,
- ModuleOp module, Type indexType) {
+ Operation *module, Type indexType) {
bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
if (useGenericFn)
@@ -79,7 +75,8 @@ std::tuple<Value, Value> AllocationOpLLVMLowering::allocateBufferManuallyAlign(
// Allocate the underlying buffer.
Type elementPtrType = this->getElementPtrType(memRefType);
LLVM::LLVMFuncOp allocFuncOp = getNotalignedAllocFn(
- getTypeConverter(), op->getParentOfType<ModuleOp>(), getIndexType());
+ getTypeConverter(), op->getParentWithTrait<OpTrait::SymbolTable>(),
+ getIndexType());
auto results = rewriter.create<LLVM::CallOp>(loc, allocFuncOp, sizeBytes);
Value allocatedPtr =
@@ -144,7 +141,8 @@ Value AllocationOpLLVMLowering::allocateBufferAutoAlign(
Type elementPtrType = this->getElementPtrType(memRefType);
LLVM::LLVMFuncOp allocFuncOp = getAlignedAllocFn(
- getTypeConverter(), op->getParentOfType<ModuleOp>(), getIndexType());
+ getTypeConverter(), op->getParentWithTrait<OpTrait::SymbolTable>(),
+ getIndexType());
auto results = rewriter.create<LLVM::CallOp>(
loc, allocFuncOp, ValueRange({allocAlignment, sizeBytes}));
diff --git a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
index 0004c2e3403e5..88421a16ccf9f 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
@@ -45,49 +45,53 @@ static constexpr llvm::StringRef kGenericFree = "_mlir_memref_to_llvm_free";
static constexpr llvm::StringRef kMemRefCopy = "memrefCopy";
/// Generic print function lookupOrCreate helper.
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFn(ModuleOp moduleOp, StringRef name,
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFn(Operation *moduleOp,
+ StringRef name,
ArrayRef<Type> paramTypes,
Type resultType, bool isVarArg) {
- auto func = moduleOp.lookupSymbol<LLVM::LLVMFuncOp>(name);
+ assert(moduleOp->hasTrait<OpTrait::SymbolTable>() &&
+ "expected SymbolTable operation");
+ auto func = llvm::dyn_cast_or_null<LLVM::LLVMFuncOp>(
+ SymbolTable::lookupSymbolIn(moduleOp, name));
if (func)
return func;
- OpBuilder b(moduleOp.getBodyRegion());
+ OpBuilder b(moduleOp->getRegion(0));
return b.create<LLVM::LLVMFuncOp>(
moduleOp->getLoc(), name,
LLVM::LLVMFunctionType::get(resultType, paramTypes, isVarArg));
}
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintI64Fn(ModuleOp moduleOp) {
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintI64Fn(Operation *moduleOp) {
return lookupOrCreateFn(moduleOp, kPrintI64,
IntegerType::get(moduleOp->getContext(), 64),
LLVM::LLVMVoidType::get(moduleOp->getContext()));
}
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintU64Fn(ModuleOp moduleOp) {
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintU64Fn(Operation *moduleOp) {
return lookupOrCreateFn(moduleOp, kPrintU64,
IntegerType::get(moduleOp->getContext(), 64),
LLVM::LLVMVoidType::get(moduleOp->getContext()));
}
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF16Fn(ModuleOp moduleOp) {
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF16Fn(Operation *moduleOp) {
return lookupOrCreateFn(moduleOp, kPrintF16,
IntegerType::get(moduleOp->getContext(), 16), // bits!
LLVM::LLVMVoidType::get(moduleOp->getContext()));
}
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintBF16Fn(ModuleOp moduleOp) {
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintBF16Fn(Operation *moduleOp) {
return lookupOrCreateFn(moduleOp, kPrintBF16,
IntegerType::get(moduleOp->getContext(), 16), // bits!
LLVM::LLVMVoidType::get(moduleOp->getContext()));
}
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF32Fn(ModuleOp moduleOp) {
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF32Fn(Operation *moduleOp) {
return lookupOrCreateFn(moduleOp, kPrintF32,
Float32Type::get(moduleOp->getContext()),
LLVM::LLVMVoidType::get(moduleOp->getContext()));
}
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF64Fn(ModuleOp moduleOp) {
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF64Fn(Operation *moduleOp) {
return lookupOrCreateFn(moduleOp, kPrintF64,
Float64Type::get(moduleOp->getContext()),
LLVM::LLVMVoidType::get(moduleOp->getContext()));
@@ -103,72 +107,72 @@ static LLVM::LLVMPointerType getVoidPtr(MLIRContext *context) {
}
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintStringFn(
- ModuleOp moduleOp, std::optional<StringRef> runtimeFunctionName) {
+ Operation *moduleOp, std::optional<StringRef> runtimeFunctionName) {
return lookupOrCreateFn(moduleOp, runtimeFunctionName.value_or(kPrintString),
getCharPtr(moduleOp->getContext()),
LLVM::LLVMVoidType::get(moduleOp->getContext()));
}
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintOpenFn(ModuleOp moduleOp) {
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintOpenFn(Operation *moduleOp) {
return lookupOrCreateFn(moduleOp, kPrintOpen, {},
LLVM::LLVMVoidType::get(moduleOp->getContext()));
}
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCloseFn(ModuleOp moduleOp) {
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCloseFn(Operation *moduleOp) {
return lookupOrCreateFn(moduleOp, kPrintClose, {},
LLVM::LLVMVoidType::get(moduleOp->getContext()));
}
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCommaFn(ModuleOp moduleOp) {
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCommaFn(Operation *moduleOp) {
return lookupOrCreateFn(moduleOp, kPrintComma, {},
LLVM::LLVMVoidType::get(moduleOp->getContext()));
}
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintNewlineFn(ModuleOp moduleOp) {
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintNewlineFn(Operation *moduleOp) {
return lookupOrCreateFn(moduleOp, kPrintNewline, {},
LLVM::LLVMVoidType::get(moduleOp->getContext()));
}
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateMallocFn(ModuleOp moduleOp,
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateMallocFn(Operation *moduleOp,
Type indexType) {
return LLVM::lookupOrCreateFn(moduleOp, kMalloc, indexType,
getVoidPtr(moduleOp->getContext()));
}
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateAlignedAllocFn(ModuleOp moduleOp,
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateAlignedAllocFn(Operation *moduleOp,
Type indexType) {
return LLVM::lookupOrCreateFn(moduleOp, kAlignedAlloc, {indexType, indexType},
getVoidPtr(moduleOp->getContext()));
}
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFreeFn(ModuleOp moduleOp) {
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFreeFn(Operation *moduleOp) {
return LLVM::lookupOrCreateFn(
moduleOp, kFree, getVoidPtr(moduleOp->getContext()),
LLVM::LLVMVoidType::get(moduleOp->getContext()));
}
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateGenericAllocFn(ModuleOp moduleOp,
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateGenericAllocFn(Operation *moduleOp,
Type indexType) {
return LLVM::lookupOrCreateFn(moduleOp, kGenericAlloc, indexType,
getVoidPtr(moduleOp->getContext()));
}
LLVM::LLVMFuncOp
-mlir::LLVM::lookupOrCreateGenericAlignedAllocFn(ModuleOp moduleOp,
+mlir::LLVM::lookupOrCreateGenericAlignedAllocFn(Operation *moduleOp,
Type indexType) {
return LLVM::lookupOrCreateFn(moduleOp, kGenericAlignedAlloc,
{indexType, indexType},
getVoidPtr(moduleOp->getContext()));
}
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateGenericFreeFn(ModuleOp moduleOp) {
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateGenericFreeFn(Operation *moduleOp) {
return LLVM::lookupOrCreateFn(
moduleOp, kGenericFree, getVoidPtr(moduleOp->getContext()),
LLVM::LLVMVoidType::get(moduleOp->getContext()));
}
LLVM::LLVMFuncOp
-mlir::LLVM::lookupOrCreateMemRefCopyFn(ModuleOp moduleOp, Type indexType,
+mlir::LLVM::lookupOrCreateMemRefCopyFn(Operation *moduleOp, Type indexType,
Type unrankedDescriptorType) {
return LLVM::lookupOrCreateFn(
moduleOp, kMemRefCopy,
More information about the Mlir-commits
mailing list