[Mlir-commits] [mlir] NFC: resolve TODO in LLVM dialect conversions (PR #91497)

Christopher Bate llvmlistbot at llvm.org
Wed May 8 09:18:52 PDT 2024


https://github.com/christopherbate created https://github.com/llvm/llvm-project/pull/91497

Relaxes the restriction that certain public utility functions apply only
to the builtin ModuleOp; they now accept any operation that has the
SymbolTable trait.


>From a020bb9c5492ee26ffdc26e97ddee8e68d7172cc Mon Sep 17 00:00:00 2001
From: Christopher Bate <cbate at nvidia.com>
Date: Wed, 8 May 2024 09:49:38 -0600
Subject: [PATCH] NFC: resolve TODO in LLVM dialect conversions

Relaxes the restriction that certain public utility functions apply only
to the builtin ModuleOp; they now accept any operation that has the
SymbolTable trait.
---
 .../mlir/Dialect/LLVMIR/FunctionCallUtils.h   | 38 +++++++--------
 .../MemRefToLLVM/AllocLikeConversion.cpp      | 16 +++----
 .../Dialect/LLVMIR/IR/FunctionCallUtils.cpp   | 46 ++++++++++---------
 3 files changed, 51 insertions(+), 49 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
index 123ce36cb0a79..852490cf7428f 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
@@ -33,36 +33,36 @@ class LLVMFuncOp;
 /// external C function calls. The list of functions provided here must be
 /// implemented separately (e.g. as part of a support runtime library or as part
 /// of the libc).
-LLVM::LLVMFuncOp lookupOrCreatePrintI64Fn(ModuleOp moduleOp);
-LLVM::LLVMFuncOp lookupOrCreatePrintU64Fn(ModuleOp moduleOp);
-LLVM::LLVMFuncOp lookupOrCreatePrintF16Fn(ModuleOp moduleOp);
-LLVM::LLVMFuncOp lookupOrCreatePrintBF16Fn(ModuleOp moduleOp);
-LLVM::LLVMFuncOp lookupOrCreatePrintF32Fn(ModuleOp moduleOp);
-LLVM::LLVMFuncOp lookupOrCreatePrintF64Fn(ModuleOp moduleOp);
+LLVM::LLVMFuncOp lookupOrCreatePrintI64Fn(Operation *moduleOp);
+LLVM::LLVMFuncOp lookupOrCreatePrintU64Fn(Operation *moduleOp);
+LLVM::LLVMFuncOp lookupOrCreatePrintF16Fn(Operation *moduleOp);
+LLVM::LLVMFuncOp lookupOrCreatePrintBF16Fn(Operation *moduleOp);
+LLVM::LLVMFuncOp lookupOrCreatePrintF32Fn(Operation *moduleOp);
+LLVM::LLVMFuncOp lookupOrCreatePrintF64Fn(Operation *moduleOp);
 /// Declares a function to print a C-string.
 /// If a custom runtime function is defined via `runtimeFunctionName`, it must
 /// have the signature void(char const*). The default function is `printString`.
 LLVM::LLVMFuncOp
-lookupOrCreatePrintStringFn(ModuleOp moduleOp,
+lookupOrCreatePrintStringFn(Operation *moduleOp,
                             std::optional<StringRef> runtimeFunctionName = {});
-LLVM::LLVMFuncOp lookupOrCreatePrintOpenFn(ModuleOp moduleOp);
-LLVM::LLVMFuncOp lookupOrCreatePrintCloseFn(ModuleOp moduleOp);
-LLVM::LLVMFuncOp lookupOrCreatePrintCommaFn(ModuleOp moduleOp);
-LLVM::LLVMFuncOp lookupOrCreatePrintNewlineFn(ModuleOp moduleOp);
-LLVM::LLVMFuncOp lookupOrCreateMallocFn(ModuleOp moduleOp, Type indexType);
-LLVM::LLVMFuncOp lookupOrCreateAlignedAllocFn(ModuleOp moduleOp,
+LLVM::LLVMFuncOp lookupOrCreatePrintOpenFn(Operation *moduleOp);
+LLVM::LLVMFuncOp lookupOrCreatePrintCloseFn(Operation *moduleOp);
+LLVM::LLVMFuncOp lookupOrCreatePrintCommaFn(Operation *moduleOp);
+LLVM::LLVMFuncOp lookupOrCreatePrintNewlineFn(Operation *moduleOp);
+LLVM::LLVMFuncOp lookupOrCreateMallocFn(Operation *moduleOp, Type indexType);
+LLVM::LLVMFuncOp lookupOrCreateAlignedAllocFn(Operation *moduleOp,
                                               Type indexType);
-LLVM::LLVMFuncOp lookupOrCreateFreeFn(ModuleOp moduleOp);
-LLVM::LLVMFuncOp lookupOrCreateGenericAllocFn(ModuleOp moduleOp,
+LLVM::LLVMFuncOp lookupOrCreateFreeFn(Operation *moduleOp);
+LLVM::LLVMFuncOp lookupOrCreateGenericAllocFn(Operation *moduleOp,
                                               Type indexType);
-LLVM::LLVMFuncOp lookupOrCreateGenericAlignedAllocFn(ModuleOp moduleOp,
+LLVM::LLVMFuncOp lookupOrCreateGenericAlignedAllocFn(Operation *moduleOp,
                                                      Type indexType);
-LLVM::LLVMFuncOp lookupOrCreateGenericFreeFn(ModuleOp moduleOp);
-LLVM::LLVMFuncOp lookupOrCreateMemRefCopyFn(ModuleOp moduleOp, Type indexType,
+LLVM::LLVMFuncOp lookupOrCreateGenericFreeFn(Operation *moduleOp);
+LLVM::LLVMFuncOp lookupOrCreateMemRefCopyFn(Operation *moduleOp, Type indexType,
                                             Type unrankedDescriptorType);
 
 /// Create a FuncOp with signature `resultType`(`paramTypes`)` and name `name`.
-LLVM::LLVMFuncOp lookupOrCreateFn(ModuleOp moduleOp, StringRef name,
+LLVM::LLVMFuncOp lookupOrCreateFn(Operation *moduleOp, StringRef name,
                                   ArrayRef<Type> paramTypes = {},
                                   Type resultType = {}, bool isVarArg = false);
 
diff --git a/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp b/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
index b29abc94ce400..e48ca5180b706 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
@@ -10,18 +10,14 @@
 #include "mlir/Analysis/DataLayoutAnalysis.h"
 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/IR/SymbolTable.h"
 
 using namespace mlir;
 
 namespace {
-// TODO: Fix the LLVM utilities for looking up functions to take Operation*
-// with SymbolTable trait instead of ModuleOp and make similar change here. This
-// allows call sites to use getParentWithTrait<OpTrait::SymbolTable> instead
-// of getParentOfType<ModuleOp> to pass down the operation.
 LLVM::LLVMFuncOp getNotalignedAllocFn(const LLVMTypeConverter *typeConverter,
-                                      ModuleOp module, Type indexType) {
+                                      Operation *module, Type indexType) {
   bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
-
   if (useGenericFn)
     return LLVM::lookupOrCreateGenericAllocFn(module, indexType);
 
@@ -29,7 +25,7 @@ LLVM::LLVMFuncOp getNotalignedAllocFn(const LLVMTypeConverter *typeConverter,
 }
 
 LLVM::LLVMFuncOp getAlignedAllocFn(const LLVMTypeConverter *typeConverter,
-                                   ModuleOp module, Type indexType) {
+                                   Operation *module, Type indexType) {
   bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
 
   if (useGenericFn)
@@ -79,7 +75,8 @@ std::tuple<Value, Value> AllocationOpLLVMLowering::allocateBufferManuallyAlign(
   // Allocate the underlying buffer.
   Type elementPtrType = this->getElementPtrType(memRefType);
   LLVM::LLVMFuncOp allocFuncOp = getNotalignedAllocFn(
-      getTypeConverter(), op->getParentOfType<ModuleOp>(), getIndexType());
+      getTypeConverter(), op->getParentWithTrait<OpTrait::SymbolTable>(),
+      getIndexType());
   auto results = rewriter.create<LLVM::CallOp>(loc, allocFuncOp, sizeBytes);
 
   Value allocatedPtr =
@@ -144,7 +141,8 @@ Value AllocationOpLLVMLowering::allocateBufferAutoAlign(
 
   Type elementPtrType = this->getElementPtrType(memRefType);
   LLVM::LLVMFuncOp allocFuncOp = getAlignedAllocFn(
-      getTypeConverter(), op->getParentOfType<ModuleOp>(), getIndexType());
+      getTypeConverter(), op->getParentWithTrait<OpTrait::SymbolTable>(),
+      getIndexType());
   auto results = rewriter.create<LLVM::CallOp>(
       loc, allocFuncOp, ValueRange({allocAlignment, sizeBytes}));
 
diff --git a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
index 0004c2e3403e5..88421a16ccf9f 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
@@ -45,49 +45,53 @@ static constexpr llvm::StringRef kGenericFree = "_mlir_memref_to_llvm_free";
 static constexpr llvm::StringRef kMemRefCopy = "memrefCopy";
 
 /// Generic print function lookupOrCreate helper.
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFn(ModuleOp moduleOp, StringRef name,
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFn(Operation *moduleOp,
+                                              StringRef name,
                                               ArrayRef<Type> paramTypes,
                                               Type resultType, bool isVarArg) {
-  auto func = moduleOp.lookupSymbol<LLVM::LLVMFuncOp>(name);
+  assert(moduleOp->hasTrait<OpTrait::SymbolTable>() &&
+         "expected SymbolTable operation");
+  auto func = llvm::dyn_cast_or_null<LLVM::LLVMFuncOp>(
+      SymbolTable::lookupSymbolIn(moduleOp, name));
   if (func)
     return func;
-  OpBuilder b(moduleOp.getBodyRegion());
+  OpBuilder b(moduleOp->getRegion(0));
   return b.create<LLVM::LLVMFuncOp>(
       moduleOp->getLoc(), name,
       LLVM::LLVMFunctionType::get(resultType, paramTypes, isVarArg));
 }
 
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintI64Fn(ModuleOp moduleOp) {
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintI64Fn(Operation *moduleOp) {
   return lookupOrCreateFn(moduleOp, kPrintI64,
                           IntegerType::get(moduleOp->getContext(), 64),
                           LLVM::LLVMVoidType::get(moduleOp->getContext()));
 }
 
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintU64Fn(ModuleOp moduleOp) {
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintU64Fn(Operation *moduleOp) {
   return lookupOrCreateFn(moduleOp, kPrintU64,
                           IntegerType::get(moduleOp->getContext(), 64),
                           LLVM::LLVMVoidType::get(moduleOp->getContext()));
 }
 
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF16Fn(ModuleOp moduleOp) {
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF16Fn(Operation *moduleOp) {
   return lookupOrCreateFn(moduleOp, kPrintF16,
                           IntegerType::get(moduleOp->getContext(), 16), // bits!
                           LLVM::LLVMVoidType::get(moduleOp->getContext()));
 }
 
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintBF16Fn(ModuleOp moduleOp) {
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintBF16Fn(Operation *moduleOp) {
   return lookupOrCreateFn(moduleOp, kPrintBF16,
                           IntegerType::get(moduleOp->getContext(), 16), // bits!
                           LLVM::LLVMVoidType::get(moduleOp->getContext()));
 }
 
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF32Fn(ModuleOp moduleOp) {
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF32Fn(Operation *moduleOp) {
   return lookupOrCreateFn(moduleOp, kPrintF32,
                           Float32Type::get(moduleOp->getContext()),
                           LLVM::LLVMVoidType::get(moduleOp->getContext()));
 }
 
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF64Fn(ModuleOp moduleOp) {
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF64Fn(Operation *moduleOp) {
   return lookupOrCreateFn(moduleOp, kPrintF64,
                           Float64Type::get(moduleOp->getContext()),
                           LLVM::LLVMVoidType::get(moduleOp->getContext()));
@@ -103,72 +107,72 @@ static LLVM::LLVMPointerType getVoidPtr(MLIRContext *context) {
 }
 
 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintStringFn(
-    ModuleOp moduleOp, std::optional<StringRef> runtimeFunctionName) {
+    Operation *moduleOp, std::optional<StringRef> runtimeFunctionName) {
   return lookupOrCreateFn(moduleOp, runtimeFunctionName.value_or(kPrintString),
                           getCharPtr(moduleOp->getContext()),
                           LLVM::LLVMVoidType::get(moduleOp->getContext()));
 }
 
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintOpenFn(ModuleOp moduleOp) {
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintOpenFn(Operation *moduleOp) {
   return lookupOrCreateFn(moduleOp, kPrintOpen, {},
                           LLVM::LLVMVoidType::get(moduleOp->getContext()));
 }
 
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCloseFn(ModuleOp moduleOp) {
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCloseFn(Operation *moduleOp) {
   return lookupOrCreateFn(moduleOp, kPrintClose, {},
                           LLVM::LLVMVoidType::get(moduleOp->getContext()));
 }
 
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCommaFn(ModuleOp moduleOp) {
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCommaFn(Operation *moduleOp) {
   return lookupOrCreateFn(moduleOp, kPrintComma, {},
                           LLVM::LLVMVoidType::get(moduleOp->getContext()));
 }
 
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintNewlineFn(ModuleOp moduleOp) {
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintNewlineFn(Operation *moduleOp) {
   return lookupOrCreateFn(moduleOp, kPrintNewline, {},
                           LLVM::LLVMVoidType::get(moduleOp->getContext()));
 }
 
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateMallocFn(ModuleOp moduleOp,
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateMallocFn(Operation *moduleOp,
                                                     Type indexType) {
   return LLVM::lookupOrCreateFn(moduleOp, kMalloc, indexType,
                                 getVoidPtr(moduleOp->getContext()));
 }
 
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateAlignedAllocFn(ModuleOp moduleOp,
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateAlignedAllocFn(Operation *moduleOp,
                                                           Type indexType) {
   return LLVM::lookupOrCreateFn(moduleOp, kAlignedAlloc, {indexType, indexType},
                                 getVoidPtr(moduleOp->getContext()));
 }
 
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFreeFn(ModuleOp moduleOp) {
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFreeFn(Operation *moduleOp) {
   return LLVM::lookupOrCreateFn(
       moduleOp, kFree, getVoidPtr(moduleOp->getContext()),
       LLVM::LLVMVoidType::get(moduleOp->getContext()));
 }
 
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateGenericAllocFn(ModuleOp moduleOp,
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateGenericAllocFn(Operation *moduleOp,
                                                           Type indexType) {
   return LLVM::lookupOrCreateFn(moduleOp, kGenericAlloc, indexType,
                                 getVoidPtr(moduleOp->getContext()));
 }
 
 LLVM::LLVMFuncOp
-mlir::LLVM::lookupOrCreateGenericAlignedAllocFn(ModuleOp moduleOp,
+mlir::LLVM::lookupOrCreateGenericAlignedAllocFn(Operation *moduleOp,
                                                 Type indexType) {
   return LLVM::lookupOrCreateFn(moduleOp, kGenericAlignedAlloc,
                                 {indexType, indexType},
                                 getVoidPtr(moduleOp->getContext()));
 }
 
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateGenericFreeFn(ModuleOp moduleOp) {
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateGenericFreeFn(Operation *moduleOp) {
   return LLVM::lookupOrCreateFn(
       moduleOp, kGenericFree, getVoidPtr(moduleOp->getContext()),
       LLVM::LLVMVoidType::get(moduleOp->getContext()));
 }
 
 LLVM::LLVMFuncOp
-mlir::LLVM::lookupOrCreateMemRefCopyFn(ModuleOp moduleOp, Type indexType,
+mlir::LLVM::lookupOrCreateMemRefCopyFn(Operation *moduleOp, Type indexType,
                                        Type unrankedDescriptorType) {
   return LLVM::lookupOrCreateFn(
       moduleOp, kMemRefCopy,



More information about the Mlir-commits mailing list