[Mlir-commits] [mlir] [NFC][FuncOpToLLVM] refactor and move some utils to headers (PR #68665)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Oct 9 23:38:38 PDT 2023


https://github.com/allatit23 created https://github.com/llvm/llvm-project/pull/68665

* Refactor `convertFuncOpToLLVMFuncOp` to accept a `FunctionOpInterface` instead of a `func::FuncOp`.
* Move `convertFuncOpToLLVMFuncOp` to the corresponding public header, making it available to downstream projects.

>From 15fcfab294578ff6add4ad8e2d1aeecd7f0fba9b Mon Sep 17 00:00:00 2001
From: Allen Zhao <allzhao at nvidia.com>
Date: Mon, 9 Oct 2023 23:20:14 -0700
Subject: [PATCH] [NFC][FuncOpToLLVM] refactor and move some utils to headers

* Refactor `convertFuncOpToLLVMFuncOp` to accept a `FunctionOpInterface` instead of a `func::FuncOp`.
* Move `convertFuncOpToLLVMFuncOp` to the corresponding public header, making it available to downstream projects.
---
 .../Conversion/FuncToLLVM/ConvertFuncToLLVM.h |  15 +
 mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp | 284 +++++++++---------
 2 files changed, 165 insertions(+), 134 deletions(-)

diff --git a/mlir/include/mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h b/mlir/include/mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h
index 21bd191aa9dc8c3..7ce9bb18c93ea43 100644
--- a/mlir/include/mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h
+++ b/mlir/include/mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h
@@ -14,13 +14,28 @@
 #ifndef MLIR_CONVERSION_FUNCTOLLVM_CONVERTFUNCTOLLVM_H
 #define MLIR_CONVERSION_FUNCTOLLVM_CONVERTFUNCTOLLVM_H
 
+#include "mlir/Interfaces/FunctionInterfaces.h"
+#include "mlir/Support/LogicalResult.h"
+
 namespace mlir {
 
+namespace LLVM {
+class LLVMFuncOp;
+} // namespace LLVM
+
+class ConversionPatternRewriter;
 class DialectRegistry;
 class LLVMTypeConverter;
 class RewritePatternSet;
 class SymbolTable;
 
+/// Convert `funcOp` into an LLVMFuncOp, moving its body, using `converter`.
+/// Returns failure if conversion fails, e.g. for non-FunctionType signatures.
+FailureOr<LLVM::LLVMFuncOp>
+convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp,
+                          ConversionPatternRewriter &rewriter,
+                          const LLVMTypeConverter &converter);
+
 /// Collect the default pattern to convert a FuncOp to the LLVM dialect. If
 /// `emitCWrappers` is set, the pattern will also produce functions
 /// that pass memref descriptors by pointer-to-structure in addition to the
diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
index 4aacb47a7fe9cc1..3506f50916132dd 100644
--- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
+++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
@@ -74,7 +74,7 @@ static bool shouldUseBarePtrCallConv(Operation *op,
 
 /// Only retain those attributes that are not constructed by
 /// `LLVMFuncOp::build`.
-static void filterFuncAttributes(func::FuncOp func,
+static void filterFuncAttributes(FunctionOpInterface func,
                                  SmallVectorImpl<NamedAttribute> &result) {
   for (const NamedAttribute &attr : func->getDiscardableAttrs()) {
     if (attr.getName() == linkageAttrName ||
@@ -87,26 +87,26 @@ static void filterFuncAttributes(func::FuncOp func,
 
 /// Propagate argument/results attributes.
 static void propagateArgResAttrs(OpBuilder &builder, bool resultStructType,
-                                 func::FuncOp funcOp,
+                                 FunctionOpInterface funcOp,
                                  LLVM::LLVMFuncOp wrapperFuncOp) {
-  auto argAttrs = funcOp.getArgAttrs();
+  auto argAttrs = funcOp.getAllArgAttrs();
   if (!resultStructType) {
     if (auto resAttrs = funcOp.getAllResultAttrs())
       wrapperFuncOp.setAllResultAttrs(resAttrs);
     if (argAttrs)
-      wrapperFuncOp.setAllArgAttrs(*argAttrs);
+      wrapperFuncOp.setAllArgAttrs(argAttrs);
   } else {
     SmallVector<Attribute> argAttributes;
     // Only modify the argument and result attributes when the result is now
     // an argument.
     if (argAttrs) {
       argAttributes.push_back(builder.getDictionaryAttr({}));
-      argAttributes.append(argAttrs->begin(), argAttrs->end());
+      argAttributes.append(argAttrs.begin(), argAttrs.end());
       wrapperFuncOp.setAllArgAttrs(argAttributes);
     }
   }
-  if (funcOp.getSymVisibilityAttr())
-    wrapperFuncOp.setSymVisibility(funcOp.getSymVisibilityAttr());
+  cast<FunctionOpInterface>(wrapperFuncOp.getOperation())
+      .setVisibility(funcOp.getVisibility());
 }
 
 /// Creates an auxiliary function with pointer-to-memref-descriptor-struct
@@ -119,9 +119,9 @@ static void propagateArgResAttrs(OpBuilder &builder, bool resultStructType,
 /// the extra arguments.
 static void wrapForExternalCallers(OpBuilder &rewriter, Location loc,
                                    const LLVMTypeConverter &typeConverter,
-                                   func::FuncOp funcOp,
+                                   FunctionOpInterface funcOp,
                                    LLVM::LLVMFuncOp newFuncOp) {
-  auto type = funcOp.getFunctionType();
+  auto type = cast<FunctionType>(funcOp.getFunctionType());
   auto [wrapperFuncType, resultStructType] =
       typeConverter.convertFunctionTypeCWrapper(type);
 
@@ -179,12 +179,13 @@ static void wrapForExternalCallers(OpBuilder &rewriter, Location loc,
 /// corresponding to a memref descriptor.
 static void wrapExternalFunction(OpBuilder &builder, Location loc,
                                  const LLVMTypeConverter &typeConverter,
-                                 func::FuncOp funcOp,
+                                 FunctionOpInterface funcOp,
                                  LLVM::LLVMFuncOp newFuncOp) {
   OpBuilder::InsertionGuard guard(builder);
 
   auto [wrapperType, resultStructType] =
-      typeConverter.convertFunctionTypeCWrapper(funcOp.getFunctionType());
+      typeConverter.convertFunctionTypeCWrapper(
+          cast<FunctionType>(funcOp.getFunctionType()));
   // This conversion can only fail if it could not convert one of the argument
   // types. But since it has been applied to a non-wrapper function before, it
   // should have failed earlier and not reach this point at all.
@@ -205,7 +206,7 @@ static void wrapExternalFunction(OpBuilder &builder, Location loc,
   builder.setInsertionPointToStart(newFuncOp.addEntryBlock());
 
   // Get a ValueRange containing arguments.
-  FunctionType type = funcOp.getFunctionType();
+  FunctionType type = cast<FunctionType>(funcOp.getFunctionType());
   SmallVector<Value, 8> args;
   args.reserve(type.getNumInputs());
   ValueRange wrapperArgsRange(newFuncOp.getArguments());
@@ -317,6 +318,140 @@ static void modifyFuncOpToUseBarePtrCallingConv(
   }
 }
 
+FailureOr<LLVM::LLVMFuncOp>
+mlir::convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp,
+                                ConversionPatternRewriter &rewriter,
+                                const LLVMTypeConverter &converter) {
+  // Only functions whose signature is a builtin FunctionType are supported.
+  auto funcTy = dyn_cast<FunctionType>(funcOp.getFunctionType());
+  if (!funcTy)
+    return rewriter.notifyMatchFailure(
+        funcOp, "Only support FunctionOpInterface with FunctionType");
+
+  // Convert the original signature with the given LLVMTypeConverter, honoring
+  // the varargs attribute and the bare-pointer calling convention choice.
+  auto varargsAttr = funcOp->getAttrOfType<BoolAttr>(varargsAttrName);
+  TypeConverter::SignatureConversion result(funcOp.getNumArguments());
+  auto llvmType = converter.convertFunctionSignature(
+      funcTy, varargsAttr && varargsAttr.getValue(),
+      shouldUseBarePtrCallConv(funcOp, &converter), result);
+  if (!llvmType)
+    return rewriter.notifyMatchFailure(funcOp, "signature conversion failed");
+
+  // Default to external linkage (MLIR functions have no linkage of their own);
+  // an explicit LLVM::LinkageAttr under `linkageAttrName` overrides it.
+  LLVM::Linkage linkage = LLVM::Linkage::External;
+  if (funcOp->hasAttr(linkageAttrName)) {
+    auto attr =
+        dyn_cast<mlir::LLVM::LinkageAttr>(funcOp->getAttr(linkageAttrName));
+    if (!attr) {
+      funcOp->emitError() << "Contains " << linkageAttrName
+                          << " attribute not of type LLVM::LinkageAttr";
+      return rewriter.notifyMatchFailure(
+          funcOp, "Contains linkage attribute not of type LLVM::LinkageAttr");
+    }
+    linkage = attr.getLinkage();
+  }
+
+  SmallVector<NamedAttribute, 4> attributes;
+  filterFuncAttributes(funcOp, attributes);
+  auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
+      funcOp.getLoc(), funcOp.getName(), llvmType, linkage,
+      /*dsoLocal=*/false, /*cconv=*/LLVM::CConv::C, /*comdat=*/nullptr,
+      attributes);
+  cast<FunctionOpInterface>(newFuncOp.getOperation())
+      .setVisibility(funcOp.getVisibility());
+
+  // Translate a `readnone` UnitAttr into an all-NoModRef MemoryEffectsAttr.
+  StringRef readnoneAttrName = LLVM::LLVMDialect::getReadnoneAttrName();
+  if (funcOp->hasAttr(readnoneAttrName)) {
+    auto attr = funcOp->getAttrOfType<UnitAttr>(readnoneAttrName);
+    if (!attr) {
+      funcOp->emitError() << "Contains " << readnoneAttrName
+                          << " attribute not of type UnitAttr";
+      return rewriter.notifyMatchFailure(
+          funcOp, "Contains readnone attribute not of type UnitAttr");
+    }
+    auto memoryAttr = LLVM::MemoryEffectsAttr::get(
+        rewriter.getContext(),
+        {LLVM::ModRefInfo::NoModRef, LLVM::ModRefInfo::NoModRef,
+         LLVM::ModRefInfo::NoModRef});
+    newFuncOp.setMemoryAttr(memoryAttr);
+  }
+
+  // Propagate argument/result attributes. Result attributes are copied only
+  // when the function has exactly one result.
+  if (ArrayAttr resAttrDicts = funcOp.getAllResultAttrs()) {
+    assert(!resAttrDicts.empty() && "expected array to be non-empty");
+    if (funcOp.getNumResults() == 1)
+      newFuncOp.setAllResultAttrs(resAttrDicts);
+  }
+  if (ArrayAttr argAttrDicts = funcOp.getAllArgAttrs()) {
+    SmallVector<Attribute> newArgAttrs(
+        cast<LLVM::LLVMFunctionType>(llvmType).getNumParams());
+    for (unsigned i = 0, e = funcOp.getNumArguments(); i < e; ++i) {
+      // Some LLVM IR attributes have a type attached to them. During FuncOp ->
+      // LLVMFuncOp conversion these types may have changed. Account for that
+      // change by converting the ByVal/ByRef/StructRet/InAlloca types too.
+      SmallVector<NamedAttribute, 4> convertedAttrs;
+      auto attrsDict = cast<DictionaryAttr>(argAttrDicts[i]);
+      convertedAttrs.reserve(attrsDict.size());
+      for (const NamedAttribute &attr : attrsDict) {
+        const auto convert = [&](const NamedAttribute &attr) {
+          return TypeAttr::get(converter.convertType(
+              cast<TypeAttr>(attr.getValue()).getValue()));
+        };
+        if (attr.getName().getValue() ==
+            LLVM::LLVMDialect::getByValAttrName()) {
+          convertedAttrs.push_back(rewriter.getNamedAttr(
+              LLVM::LLVMDialect::getByValAttrName(), convert(attr)));
+        } else if (attr.getName().getValue() ==
+                   LLVM::LLVMDialect::getByRefAttrName()) {
+          convertedAttrs.push_back(rewriter.getNamedAttr(
+              LLVM::LLVMDialect::getByRefAttrName(), convert(attr)));
+        } else if (attr.getName().getValue() ==
+                   LLVM::LLVMDialect::getStructRetAttrName()) {
+          convertedAttrs.push_back(rewriter.getNamedAttr(
+              LLVM::LLVMDialect::getStructRetAttrName(), convert(attr)));
+        } else if (attr.getName().getValue() ==
+                   LLVM::LLVMDialect::getInAllocaAttrName()) {
+          convertedAttrs.push_back(rewriter.getNamedAttr(
+              LLVM::LLVMDialect::getInAllocaAttrName(), convert(attr)));
+        } else {
+          convertedAttrs.push_back(attr);
+        }
+      }
+      auto mapping = result.getInputMapping(i);
+      assert(mapping && "unexpected deletion of function argument");
+      // Only attach the new argument attributes if there is a one-to-one
+      // mapping from old to new types. Otherwise, attributes might be
+      // attached to types that they do not support.
+      if (mapping->size == 1) {
+        newArgAttrs[mapping->inputNo] =
+            DictionaryAttr::get(rewriter.getContext(), convertedAttrs);
+        continue;
+      }
+      // TODO: Implement custom handling for types that expand to multiple
+      // function arguments; for now each expanded argument gets an empty dict.
+      for (size_t j = 0; j < mapping->size; ++j)
+        newArgAttrs[mapping->inputNo + j] =
+            DictionaryAttr::get(rewriter.getContext(), {});
+    }
+    if (!newArgAttrs.empty())
+      newFuncOp.setAllArgAttrs(rewriter.getArrayAttr(newArgAttrs));
+  }
+
+  rewriter.inlineRegionBefore(funcOp.getFunctionBody(), newFuncOp.getBody(),
+                              newFuncOp.end());
+  if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), converter,
+                                         &result))) {
+    return rewriter.notifyMatchFailure(funcOp,
+                                       "region types conversion failed");
+  }
+
+  return newFuncOp;
+}
+
 namespace {
 
 struct FuncOpConversionBase : public ConvertOpToLLVMPattern<func::FuncOp> {
@@ -328,128 +463,9 @@ struct FuncOpConversionBase : public ConvertOpToLLVMPattern<func::FuncOp> {
   FailureOr<LLVM::LLVMFuncOp>
   convertFuncOpToLLVMFuncOp(func::FuncOp funcOp,
                             ConversionPatternRewriter &rewriter) const {
-    // Convert the original function arguments. They are converted using the
-    // LLVMTypeConverter provided to this legalization pattern.
-    auto varargsAttr = funcOp->getAttrOfType<BoolAttr>(varargsAttrName);
-    TypeConverter::SignatureConversion result(funcOp.getNumArguments());
-    auto llvmType = getTypeConverter()->convertFunctionSignature(
-        funcOp.getFunctionType(), varargsAttr && varargsAttr.getValue(),
-        shouldUseBarePtrCallConv(funcOp, getTypeConverter()), result);
-    if (!llvmType)
-      return rewriter.notifyMatchFailure(funcOp, "signature conversion failed");
-
-    // Create an LLVM function, use external linkage by default until MLIR
-    // functions have linkage.
-    LLVM::Linkage linkage = LLVM::Linkage::External;
-    if (funcOp->hasAttr(linkageAttrName)) {
-      auto attr =
-          dyn_cast<mlir::LLVM::LinkageAttr>(funcOp->getAttr(linkageAttrName));
-      if (!attr) {
-        funcOp->emitError() << "Contains " << linkageAttrName
-                            << " attribute not of type LLVM::LinkageAttr";
-        return rewriter.notifyMatchFailure(
-            funcOp, "Contains linkage attribute not of type LLVM::LinkageAttr");
-      }
-      linkage = attr.getLinkage();
-    }
-
-    SmallVector<NamedAttribute, 4> attributes;
-    filterFuncAttributes(funcOp, attributes);
-    auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
-        funcOp.getLoc(), funcOp.getName(), llvmType, linkage,
-        /*dsoLocal=*/false, /*cconv=*/LLVM::CConv::C, /*comdat=*/nullptr,
-        attributes);
-    if (funcOp.getSymVisibilityAttr())
-      newFuncOp.setSymVisibility(funcOp.getSymVisibilityAttr());
-
-    // Create a memory effect attribute corresponding to readnone.
-    StringRef readnoneAttrName = LLVM::LLVMDialect::getReadnoneAttrName();
-    if (funcOp->hasAttr(readnoneAttrName)) {
-      auto attr = funcOp->getAttrOfType<UnitAttr>(readnoneAttrName);
-      if (!attr) {
-        funcOp->emitError() << "Contains " << readnoneAttrName
-                            << " attribute not of type UnitAttr";
-        return rewriter.notifyMatchFailure(
-            funcOp, "Contains readnone attribute not of type UnitAttr");
-      }
-      auto memoryAttr = LLVM::MemoryEffectsAttr::get(
-          rewriter.getContext(),
-          {LLVM::ModRefInfo::NoModRef, LLVM::ModRefInfo::NoModRef,
-           LLVM::ModRefInfo::NoModRef});
-      newFuncOp.setMemoryAttr(memoryAttr);
-    }
-
-    // Propagate argument/result attributes to all converted arguments/result
-    // obtained after converting a given original argument/result.
-    if (ArrayAttr resAttrDicts = funcOp.getAllResultAttrs()) {
-      assert(!resAttrDicts.empty() && "expected array to be non-empty");
-      if (funcOp.getNumResults() == 1)
-        newFuncOp.setAllResultAttrs(resAttrDicts);
-    }
-    if (ArrayAttr argAttrDicts = funcOp.getAllArgAttrs()) {
-      SmallVector<Attribute> newArgAttrs(
-          cast<LLVM::LLVMFunctionType>(llvmType).getNumParams());
-      for (unsigned i = 0, e = funcOp.getNumArguments(); i < e; ++i) {
-        // Some LLVM IR attribute have a type attached to them. During FuncOp ->
-        // LLVMFuncOp conversion these types may have changed. Account for that
-        // change by converting attributes' types as well.
-        SmallVector<NamedAttribute, 4> convertedAttrs;
-        auto attrsDict = cast<DictionaryAttr>(argAttrDicts[i]);
-        convertedAttrs.reserve(attrsDict.size());
-        for (const NamedAttribute &attr : attrsDict) {
-          const auto convert = [&](const NamedAttribute &attr) {
-            return TypeAttr::get(getTypeConverter()->convertType(
-                cast<TypeAttr>(attr.getValue()).getValue()));
-          };
-          if (attr.getName().getValue() ==
-              LLVM::LLVMDialect::getByValAttrName()) {
-            convertedAttrs.push_back(rewriter.getNamedAttr(
-                LLVM::LLVMDialect::getByValAttrName(), convert(attr)));
-          } else if (attr.getName().getValue() ==
-                     LLVM::LLVMDialect::getByRefAttrName()) {
-            convertedAttrs.push_back(rewriter.getNamedAttr(
-                LLVM::LLVMDialect::getByRefAttrName(), convert(attr)));
-          } else if (attr.getName().getValue() ==
-                     LLVM::LLVMDialect::getStructRetAttrName()) {
-            convertedAttrs.push_back(rewriter.getNamedAttr(
-                LLVM::LLVMDialect::getStructRetAttrName(), convert(attr)));
-          } else if (attr.getName().getValue() ==
-                     LLVM::LLVMDialect::getInAllocaAttrName()) {
-            convertedAttrs.push_back(rewriter.getNamedAttr(
-                LLVM::LLVMDialect::getInAllocaAttrName(), convert(attr)));
-          } else {
-            convertedAttrs.push_back(attr);
-          }
-        }
-        auto mapping = result.getInputMapping(i);
-        assert(mapping && "unexpected deletion of function argument");
-        // Only attach the new argument attributes if there is a one-to-one
-        // mapping from old to new types. Otherwise, attributes might be
-        // attached to types that they do not support.
-        if (mapping->size == 1) {
-          newArgAttrs[mapping->inputNo] =
-              DictionaryAttr::get(rewriter.getContext(), convertedAttrs);
-          continue;
-        }
-        // TODO: Implement custom handling for types that expand to multiple
-        // function arguments.
-        for (size_t j = 0; j < mapping->size; ++j)
-          newArgAttrs[mapping->inputNo + j] =
-              DictionaryAttr::get(rewriter.getContext(), {});
-      }
-      if (!newArgAttrs.empty())
-        newFuncOp.setAllArgAttrs(rewriter.getArrayAttr(newArgAttrs));
-    }
-
-    rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
-                                newFuncOp.end());
-    if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), *typeConverter,
-                                           &result))) {
-      return rewriter.notifyMatchFailure(funcOp,
-                                         "region types conversion failed");
-    }
-
-    return newFuncOp;
+    return mlir::convertFuncOpToLLVMFuncOp(
+        cast<FunctionOpInterface>(funcOp.getOperation()), rewriter,
+        *getTypeConverter());
   }
 };
 



More information about the Mlir-commits mailing list