[Mlir-commits] [mlir] 162f757 - [mlir][LLVM] Add an attribute to control use of bare-pointer calling convention.

Mahesh Ravishankar llvmlistbot at llvm.org
Thu Apr 6 09:20:14 PDT 2023


Author: Mahesh Ravishankar
Date: 2023-04-06T16:19:56Z
New Revision: 162f7572067d7d2d70202f5ff42532adf6f75517

URL: https://github.com/llvm/llvm-project/commit/162f7572067d7d2d70202f5ff42532adf6f75517
DIFF: https://github.com/llvm/llvm-project/commit/162f7572067d7d2d70202f5ff42532adf6f75517.diff

LOG: [mlir][LLVM] Add an attribute to control use of bare-pointer calling convention.

Currently, the use of the bare-pointer calling convention is controlled
globally through an option on the `LLVMTypeConverter`. To allow more
fine-grained control, use an attribute on a function to drive the
calling convention used for that function.
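
For illustration, a function opts into the bare-pointer convention by
carrying the new `llvm.bareptr` unit attribute, mirroring the added
tests (the function name below is only a placeholder):

    // The llvm.bareptr unit attribute requests the bare-pointer calling
    // convention for this function only.
    func.func @example(%arg0: memref<4x3xf32>) -> memref<4x3xf32>
        attributes { llvm.bareptr } {
      return %arg0 : memref<4x3xf32>
    }

After FuncToLLVM conversion, such a function lowers to an `llvm.func`
that takes and returns a bare `!llvm.ptr` rather than a full memref
descriptor.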

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D147494

Added: 
    

Modified: 
    mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
    mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
    mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
    mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
    mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
    mlir/test/Conversion/FuncToLLVM/calling-convention.mlir
    mlir/test/Conversion/FuncToLLVM/func-memref-return.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
index b13b88d6773a8..600575139dbe5 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
@@ -53,20 +53,23 @@ class LLVMTypeConverter : public TypeConverter {
   /// one and results are packed into a wrapped LLVM IR structure type. `result`
   /// is populated with argument mapping.
   Type convertFunctionSignature(FunctionType funcTy, bool isVariadic,
+                                bool useBarePtrCallConv,
                                 SignatureConversion &result);
 
   /// Convert a non-empty list of types to be returned from a function into a
   /// supported LLVM IR type.  In particular, if more than one value is
   /// returned, create an LLVM IR structure type with elements that correspond
   /// to each of the MLIR types converted with `convertType`.
-  Type packFunctionResults(TypeRange types);
+  Type packFunctionResults(TypeRange types,
+                           bool useBarePointerCallConv = false);
 
   /// Convert a type in the context of the default or bare pointer calling
   /// convention. Calling convention sensitive types, such as MemRefType and
   /// UnrankedMemRefType, are converted following the specific rules for the
   /// calling convention. Calling convention independent types are converted
   /// following the default LLVM type conversions.
-  Type convertCallingConventionType(Type type);
+  Type convertCallingConventionType(Type type,
+                                    bool useBarePointerCallConv = false);
 
   /// Promote the bare pointers in 'values' that resulted from memrefs to
   /// descriptors. 'stdTypes' holds the types of 'values' before the conversion
@@ -95,8 +98,8 @@ class LLVMTypeConverter : public TypeConverter {
   /// of the platform-specific C/C++ ABI lowering related to struct argument
   /// passing.
   SmallVector<Value, 4> promoteOperands(Location loc, ValueRange opOperands,
-                                        ValueRange operands,
-                                        OpBuilder &builder);
+                                        ValueRange operands, OpBuilder &builder,
+                                        bool useBarePtrCallConv = false);
 
   /// Promote the LLVM struct representation of one MemRef descriptor to stack
   /// and use pointer to struct to avoid the complexity of the platform-specific

diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
index 7200b2b3ea9af..86394aa969bb3 100644
--- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
+++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
@@ -58,6 +58,14 @@ using namespace mlir;
 
 static constexpr StringRef varargsAttrName = "func.varargs";
 static constexpr StringRef linkageAttrName = "llvm.linkage";
+static constexpr StringRef barePtrAttrName = "llvm.bareptr";
+
+/// Return `true` if the `op` should use bare pointer calling convention.
+static bool shouldUseBarePtrCallConv(Operation *op,
+                                     LLVMTypeConverter *typeConverter) {
+  return (op && op->hasAttr(barePtrAttrName)) ||
+         typeConverter->getOptions().useBarePtrCallConv;
+}
 
 /// Only retain those attributes that are not constructed by
 /// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out argument
@@ -267,6 +275,55 @@ static void wrapExternalFunction(OpBuilder &builder, Location loc,
   }
 }
 
+/// Modifies the body of the function to construct the `MemRefDescriptor` from
+/// the bare pointer calling convention lowering of `memref` types.
+static void modifyFuncOpToUseBarePtrCallingConv(
+    ConversionPatternRewriter &rewriter, Location loc,
+    LLVMTypeConverter &typeConverter, LLVM::LLVMFuncOp funcOp,
+    TypeRange oldArgTypes) {
+  if (funcOp.getBody().empty())
+    return;
+
+  // Promote bare pointers from memref arguments to memref descriptors at the
+  // beginning of the function so that all the memrefs in the function have a
+  // uniform representation.
+  Block *entryBlock = &funcOp.getBody().front();
+  auto blockArgs = entryBlock->getArguments();
+  assert(blockArgs.size() == oldArgTypes.size() &&
+         "The number of arguments and types doesn't match");
+
+  OpBuilder::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPointToStart(entryBlock);
+  for (auto it : llvm::zip(blockArgs, oldArgTypes)) {
+    BlockArgument arg = std::get<0>(it);
+    Type argTy = std::get<1>(it);
+
+    // Unranked memrefs are not supported in the bare pointer calling
+    // convention. We should have bailed out before in the presence of
+    // unranked memrefs.
+    assert(!argTy.isa<UnrankedMemRefType>() &&
+           "Unranked memref is not supported");
+    auto memrefTy = argTy.dyn_cast<MemRefType>();
+    if (!memrefTy)
+      continue;
+
+    // Replace barePtr with a placeholder (undef), promote barePtr to a ranked
+    // or unranked memref descriptor and replace placeholder with the last
+    // instruction of the memref descriptor.
+    // TODO: The placeholder is needed to avoid replacing barePtr uses in the
+    // MemRef descriptor instructions. We may want to have a utility in the
+    // rewriter to properly handle this use case.
+    Location loc = funcOp.getLoc();
+    auto placeholder = rewriter.create<LLVM::UndefOp>(
+        loc, typeConverter.convertType(memrefTy));
+    rewriter.replaceUsesOfBlockArgument(arg, placeholder);
+
+    Value desc = MemRefDescriptor::fromStaticShape(rewriter, loc, typeConverter,
+                                                   memrefTy, arg);
+    rewriter.replaceOp(placeholder, {desc});
+  }
+}
+
 namespace {
 
 struct FuncOpConversionBase : public ConvertOpToLLVMPattern<func::FuncOp> {
@@ -284,7 +341,7 @@ struct FuncOpConversionBase : public ConvertOpToLLVMPattern<func::FuncOp> {
     TypeConverter::SignatureConversion result(funcOp.getNumArguments());
     auto llvmType = getTypeConverter()->convertFunctionSignature(
         funcOp.getFunctionType(), varargsAttr && varargsAttr.getValue(),
-        result);
+        shouldUseBarePtrCallConv(funcOp, getTypeConverter()), result);
     if (!llvmType)
       return nullptr;
 
@@ -415,89 +472,24 @@ struct FuncOpConversion : public FuncOpConversionBase {
     if (!newFuncOp)
       return failure();
 
-    if (funcOp->getAttrOfType<UnitAttr>(
-            LLVM::LLVMDialect::getEmitCWrapperAttrName())) {
-      if (newFuncOp.isVarArg())
-        return funcOp->emitError("C interface for variadic functions is not "
-                                 "supported yet.");
-
-      if (newFuncOp.isExternal())
-        wrapExternalFunction(rewriter, funcOp.getLoc(), *getTypeConverter(),
-                             funcOp, newFuncOp);
-      else
-        wrapForExternalCallers(rewriter, funcOp.getLoc(), *getTypeConverter(),
-                               funcOp, newFuncOp);
-    }
-
-    rewriter.eraseOp(funcOp);
-    return success();
-  }
-};
-
-/// FuncOp legalization pattern that converts MemRef arguments to bare pointers
-/// to the MemRef element type. This will impact the calling convention and ABI.
-struct BarePtrFuncOpConversion : public FuncOpConversionBase {
-  using FuncOpConversionBase::FuncOpConversionBase;
-
-  LogicalResult
-  matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-
-    // TODO: bare ptr conversion could be handled by argument materialization
-    // and most of the code below would go away. But to do this, we would need a
-    // way to distinguish between FuncOp and other regions in the
-    // addArgumentMaterialization hook.
+    if (!shouldUseBarePtrCallConv(funcOp, this->getTypeConverter())) {
+      if (funcOp->getAttrOfType<UnitAttr>(
+              LLVM::LLVMDialect::getEmitCWrapperAttrName())) {
+        if (newFuncOp.isVarArg())
+          return funcOp->emitError("C interface for variadic functions is not "
+                                   "supported yet.");
 
-    // Store the type of memref-typed arguments before the conversion so that we
-    // can promote them to MemRef descriptor at the beginning of the function.
-    SmallVector<Type, 8> oldArgTypes =
-        llvm::to_vector<8>(funcOp.getFunctionType().getInputs());
-
-    auto newFuncOp = convertFuncOpToLLVMFuncOp(funcOp, rewriter);
-    if (!newFuncOp)
-      return failure();
-    if (newFuncOp.getBody().empty()) {
-      rewriter.eraseOp(funcOp);
-      return success();
-    }
-
-    // Promote bare pointers from memref arguments to memref descriptors at the
-    // beginning of the function so that all the memrefs in the function have a
-    // uniform representation.
-    Block *entryBlock = &newFuncOp.getBody().front();
-    auto blockArgs = entryBlock->getArguments();
-    assert(blockArgs.size() == oldArgTypes.size() &&
-           "The number of arguments and types doesn't match");
-
-    OpBuilder::InsertionGuard guard(rewriter);
-    rewriter.setInsertionPointToStart(entryBlock);
-    for (auto it : llvm::zip(blockArgs, oldArgTypes)) {
-      BlockArgument arg = std::get<0>(it);
-      Type argTy = std::get<1>(it);
-
-      // Unranked memrefs are not supported in the bare pointer calling
-      // convention. We should have bailed out before in the presence of
-      // unranked memrefs.
-      assert(!argTy.isa<UnrankedMemRefType>() &&
-             "Unranked memref is not supported");
-      auto memrefTy = argTy.dyn_cast<MemRefType>();
-      if (!memrefTy)
-        continue;
-
-      // Replace barePtr with a placeholder (undef), promote barePtr to a ranked
-      // or unranked memref descriptor and replace placeholder with the last
-      // instruction of the memref descriptor.
-      // TODO: The placeholder is needed to avoid replacing barePtr uses in the
-      // MemRef descriptor instructions. We may want to have a utility in the
-      // rewriter to properly handle this use case.
-      Location loc = funcOp.getLoc();
-      auto placeholder = rewriter.create<LLVM::UndefOp>(
-          loc, getTypeConverter()->convertType(memrefTy));
-      rewriter.replaceUsesOfBlockArgument(arg, placeholder);
-
-      Value desc = MemRefDescriptor::fromStaticShape(
-          rewriter, loc, *getTypeConverter(), memrefTy, arg);
-      rewriter.replaceOp(placeholder, {desc});
+        if (newFuncOp.isExternal())
+          wrapExternalFunction(rewriter, funcOp.getLoc(), *getTypeConverter(),
+                               funcOp, newFuncOp);
+        else
+          wrapForExternalCallers(rewriter, funcOp.getLoc(), *getTypeConverter(),
+                                 funcOp, newFuncOp);
+      }
+    } else {
+      modifyFuncOpToUseBarePtrCallingConv(rewriter, funcOp.getLoc(),
+                                          *getTypeConverter(), newFuncOp,
+                                          funcOp.getFunctionType().getInputs());
     }
 
     rewriter.eraseOp(funcOp);
@@ -535,23 +527,24 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
   using Super = CallOpInterfaceLowering<CallOpType>;
   using Base = ConvertOpToLLVMPattern<CallOpType>;
 
-  LogicalResult
-  matchAndRewrite(CallOpType callOp, typename CallOpType::Adaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
+  LogicalResult matchAndRewriteImpl(CallOpType callOp,
+                                    typename CallOpType::Adaptor adaptor,
+                                    ConversionPatternRewriter &rewriter,
+                                    bool useBarePtrCallConv = false) const {
     // Pack the result types into a struct.
     Type packedResult = nullptr;
     unsigned numResults = callOp.getNumResults();
     auto resultTypes = llvm::to_vector<4>(callOp.getResultTypes());
 
     if (numResults != 0) {
-      if (!(packedResult =
-                this->getTypeConverter()->packFunctionResults(resultTypes)))
+      if (!(packedResult = this->getTypeConverter()->packFunctionResults(
+                resultTypes, useBarePtrCallConv)))
         return failure();
     }
 
     auto promoted = this->getTypeConverter()->promoteOperands(
         callOp.getLoc(), /*opOperands=*/callOp->getOperands(),
-        adaptor.getOperands(), rewriter);
+        adaptor.getOperands(), rewriter, useBarePtrCallConv);
     auto newOp = rewriter.create<LLVM::CallOp>(
         callOp.getLoc(), packedResult ? TypeRange(packedResult) : TypeRange(),
         promoted, callOp->getAttrs());
@@ -570,7 +563,7 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
       }
     }
 
-    if (this->getTypeConverter()->getOptions().useBarePtrCallConv) {
+    if (useBarePtrCallConv) {
       // For the bare-ptr calling convention, promote memref results to
       // descriptors.
       assert(results.size() == resultTypes.size() &&
@@ -590,11 +583,28 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
 
 struct CallOpLowering : public CallOpInterfaceLowering<func::CallOp> {
   using Super::Super;
+
+  LogicalResult
+  matchAndRewrite(func::CallOp callOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    bool useBarePtrCallConv = false;
+    if (Operation *callee = SymbolTable::lookupNearestSymbolFrom(
+            callOp, callOp.getCalleeAttr())) {
+      useBarePtrCallConv = shouldUseBarePtrCallConv(callee, getTypeConverter());
+    }
+    return matchAndRewriteImpl(callOp, adaptor, rewriter, useBarePtrCallConv);
+  }
 };
 
 struct CallIndirectOpLowering
     : public CallOpInterfaceLowering<func::CallIndirectOp> {
   using Super::Super;
+
+  LogicalResult
+  matchAndRewrite(func::CallIndirectOp callIndirectOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    return matchAndRewriteImpl(callIndirectOp, adaptor, rewriter);
+  }
 };
 
 struct UnrealizedConversionCastOpLowering
@@ -640,7 +650,10 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern<func::ReturnOp> {
     unsigned numArguments = op.getNumOperands();
     SmallVector<Value, 4> updatedOperands;
 
-    if (getTypeConverter()->getOptions().useBarePtrCallConv) {
+    auto funcOp = op->getParentOfType<LLVM::LLVMFuncOp>();
+    bool useBarePtrCallConv =
+        shouldUseBarePtrCallConv(funcOp, this->getTypeConverter());
+    if (useBarePtrCallConv) {
       // For the bare-ptr calling convention, extract the aligned pointer to
       // be returned from the memref descriptor.
       for (auto it : llvm::zip(op->getOperands(), adaptor.getOperands())) {
@@ -649,7 +662,7 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern<func::ReturnOp> {
         if (oldTy.isa<MemRefType>() && getTypeConverter()->canConvertToBarePtr(
                                            oldTy.cast<BaseMemRefType>())) {
           MemRefDescriptor memrefDesc(newOperand);
-          newOperand = memrefDesc.alignedPtr(rewriter, loc);
+          newOperand = memrefDesc.allocatedPtr(rewriter, loc);
         } else if (oldTy.isa<UnrankedMemRefType>()) {
           // Unranked memref is not supported in the bare pointer calling
           // convention.
@@ -673,8 +686,8 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern<func::ReturnOp> {
 
     // Otherwise, we need to pack the arguments into an LLVM struct type before
     // returning.
-    auto packedType =
-        getTypeConverter()->packFunctionResults(op.getOperandTypes());
+    auto packedType = getTypeConverter()->packFunctionResults(
+        op.getOperandTypes(), useBarePtrCallConv);
     if (!packedType) {
       return rewriter.notifyMatchFailure(op, "could not convert result types");
     }
@@ -692,10 +705,7 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern<func::ReturnOp> {
 
 void mlir::populateFuncToLLVMFuncOpConversionPattern(
     LLVMTypeConverter &converter, RewritePatternSet &patterns) {
-  if (converter.getOptions().useBarePtrCallConv)
-    patterns.add<BarePtrFuncOpConversion>(converter);
-  else
-    patterns.add<FuncOpConversion>(converter);
+  patterns.add<FuncOpConversion>(converter);
 }
 
 void mlir::populateFuncToLLVMConversionPatterns(LLVMTypeConverter &converter,

diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index ec0d240040d1e..82c73b5f4dd2e 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -47,7 +47,8 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
   TypeConverter::SignatureConversion signatureConversion(
       gpuFuncOp.front().getNumArguments());
   Type funcType = getTypeConverter()->convertFunctionSignature(
-      gpuFuncOp.getFunctionType(), /*isVariadic=*/false, signatureConversion);
+      gpuFuncOp.getFunctionType(), /*isVariadic=*/false,
+      getTypeConverter()->getOptions().useBarePtrCallConv, signatureConversion);
 
   // Create the new function operation. Only copy those attributes that are
   // not specific to function modeling.

diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index e24be1dfdf6b9..833ea36ecf7bd 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -209,8 +209,8 @@ Type LLVMTypeConverter::convertComplexType(ComplexType type) {
 // pointer-to-function types.
 Type LLVMTypeConverter::convertFunctionType(FunctionType type) {
   SignatureConversion conversion(type.getNumInputs());
-  Type converted =
-      convertFunctionSignature(type, /*isVariadic=*/false, conversion);
+  Type converted = convertFunctionSignature(
+      type, /*isVariadic=*/false, options.useBarePtrCallConv, conversion);
   if (!converted)
     return {};
   return getPointerType(converted);
@@ -221,12 +221,12 @@ Type LLVMTypeConverter::convertFunctionType(FunctionType type) {
 // Function has one VoidType result.  If MLIR Function has more than one result,
 // they are packed into an LLVM StructType in their order of appearance.
 Type LLVMTypeConverter::convertFunctionSignature(
-    FunctionType funcTy, bool isVariadic,
+    FunctionType funcTy, bool isVariadic, bool useBarePtrCallConv,
     LLVMTypeConverter::SignatureConversion &result) {
   // Select the argument converter depending on the calling convention.
-  auto funcArgConverter = options.useBarePtrCallConv
-                              ? barePtrFuncArgTypeConverter
-                              : structFuncArgTypeConverter;
+  useBarePtrCallConv = useBarePtrCallConv || options.useBarePtrCallConv;
+  auto funcArgConverter = useBarePtrCallConv ? barePtrFuncArgTypeConverter
+                                             : structFuncArgTypeConverter;
   // Convert argument types one by one and check for errors.
   for (auto [idx, type] : llvm::enumerate(funcTy.getInputs())) {
     SmallVector<Type, 8> converted;
@@ -238,9 +238,10 @@ Type LLVMTypeConverter::convertFunctionSignature(
   // If function does not return anything, create the void result type,
   // if it returns one element, convert it, otherwise pack the result types into
   // a struct.
-  Type resultType = funcTy.getNumResults() == 0
-                        ? LLVM::LLVMVoidType::get(&getContext())
-                        : packFunctionResults(funcTy.getResults());
+  Type resultType =
+      funcTy.getNumResults() == 0
+          ? LLVM::LLVMVoidType::get(&getContext())
+          : packFunctionResults(funcTy.getResults(), useBarePtrCallConv);
   if (!resultType)
     return {};
   return LLVM::LLVMFunctionType::get(resultType, result.getConvertedTypes(),
@@ -472,8 +473,9 @@ Type LLVMTypeConverter::convertVectorType(VectorType type) {
 /// UnrankedMemRefType, are converted following the specific rules for the
 /// calling convention. Calling convention independent types are converted
 /// following the default LLVM type conversions.
-Type LLVMTypeConverter::convertCallingConventionType(Type type) {
-  if (options.useBarePtrCallConv)
+Type LLVMTypeConverter::convertCallingConventionType(Type type,
+                                                     bool useBarePtrCallConv) {
+  if (useBarePtrCallConv)
     if (auto memrefTy = type.dyn_cast<BaseMemRefType>())
       return convertMemRefToBarePtr(memrefTy);
 
@@ -498,16 +500,18 @@ void LLVMTypeConverter::promoteBarePtrsToDescriptors(
 /// supported LLVM IR type.  In particular, if more than one value is returned,
 /// create an LLVM IR structure type with elements that correspond to each of
 /// the MLIR types converted with `convertType`.
-Type LLVMTypeConverter::packFunctionResults(TypeRange types) {
+Type LLVMTypeConverter::packFunctionResults(TypeRange types,
+                                            bool useBarePtrCallConv) {
   assert(!types.empty() && "expected non-empty list of type");
 
+  useBarePtrCallConv |= options.useBarePtrCallConv;
   if (types.size() == 1)
-    return convertCallingConventionType(types.front());
+    return convertCallingConventionType(types.front(), useBarePtrCallConv);
 
   SmallVector<Type, 8> resultTypes;
   resultTypes.reserve(types.size());
   for (auto t : types) {
-    auto converted = convertCallingConventionType(t);
+    auto converted = convertCallingConventionType(t, useBarePtrCallConv);
     if (!converted || !LLVM::isCompatibleType(converted))
       return {};
     resultTypes.push_back(converted);
@@ -530,17 +534,18 @@ Value LLVMTypeConverter::promoteOneMemRefDescriptor(Location loc, Value operand,
   return allocated;
 }
 
-SmallVector<Value, 4> LLVMTypeConverter::promoteOperands(Location loc,
-                                                         ValueRange opOperands,
-                                                         ValueRange operands,
-                                                         OpBuilder &builder) {
+SmallVector<Value, 4>
+LLVMTypeConverter::promoteOperands(Location loc, ValueRange opOperands,
+                                   ValueRange operands, OpBuilder &builder,
+                                   bool useBarePtrCallConv) {
   SmallVector<Value, 4> promotedOperands;
   promotedOperands.reserve(operands.size());
+  useBarePtrCallConv |= options.useBarePtrCallConv;
   for (auto it : llvm::zip(opOperands, operands)) {
     auto operand = std::get<0>(it);
     auto llvmOperand = std::get<1>(it);
 
-    if (options.useBarePtrCallConv) {
+    if (useBarePtrCallConv) {
       // For the bare-ptr calling convention, we only have to extract the
       // aligned pointer of a memref.
       if (auto memrefType = operand.getType().dyn_cast<MemRefType>()) {
@@ -603,7 +608,8 @@ LogicalResult mlir::structFuncArgTypeConverter(LLVMTypeConverter &converter,
 LogicalResult mlir::barePtrFuncArgTypeConverter(LLVMTypeConverter &converter,
                                                 Type type,
                                                 SmallVectorImpl<Type> &result) {
-  auto llvmTy = converter.convertCallingConventionType(type);
+  auto llvmTy =
+      converter.convertCallingConventionType(type, /*useBarePtrCallConv=*/true);
   if (!llvmTy)
     return failure();
 

diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
index 2cdce91806068..b93894757daa5 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
@@ -338,7 +338,8 @@ class AddressOfPattern : public SPIRVToLLVMConversion<spirv::AddressOfOp> {
     auto dstType = typeConverter.convertType(op.getPointer().getType());
     if (!dstType)
       return failure();
-    rewriter.replaceOpWithNewOp<LLVM::AddressOfOp>(op, dstType, op.getVariable());
+    rewriter.replaceOpWithNewOp<LLVM::AddressOfOp>(op, dstType,
+                                                   op.getVariable());
     return success();
   }
 };
@@ -582,7 +583,8 @@ class CompositeExtractPattern
     }
 
     rewriter.replaceOpWithNewOp<LLVM::ExtractValueOp>(
-        op, adaptor.getComposite(), LLVM::convertArrayToIndices(op.getIndices()));
+        op, adaptor.getComposite(),
+        LLVM::convertArrayToIndices(op.getIndices()));
     return success();
   }
 };
@@ -1146,7 +1148,8 @@ class SelectionPattern : public SPIRVToLLVMConversion<spirv::SelectionOp> {
     Block *falseBlock = condBrOp.getFalseBlock();
     rewriter.setInsertionPointToEnd(currentBlock);
     rewriter.create<LLVM::CondBrOp>(loc, condBrOp.getCondition(), trueBlock,
-                                    condBrOp.getTrueTargetOperands(), falseBlock,
+                                    condBrOp.getTrueTargetOperands(),
+                                    falseBlock,
                                     condBrOp.getFalseTargetOperands());
 
     rewriter.inlineRegionBefore(op.getBody(), continueBlock);
@@ -1329,7 +1332,8 @@ class FuncConversionPattern : public SPIRVToLLVMConversion<spirv::FuncOp> {
     TypeConverter::SignatureConversion signatureConverter(
         funcType.getNumInputs());
     auto llvmType = typeConverter.convertFunctionSignature(
-        funcType, /*isVariadic=*/false, signatureConverter);
+        funcType, /*isVariadic=*/false, /*useBarePtrCallConv=*/false,
+        signatureConverter);
     if (!llvmType)
       return failure();
 

diff --git a/mlir/test/Conversion/FuncToLLVM/calling-convention.mlir b/mlir/test/Conversion/FuncToLLVM/calling-convention.mlir
index daa824d84ba74..b1c065e0f1f8d 100644
--- a/mlir/test/Conversion/FuncToLLVM/calling-convention.mlir
+++ b/mlir/test/Conversion/FuncToLLVM/calling-convention.mlir
@@ -242,3 +242,67 @@ func.func @return_two_var_memref(%arg0: memref<4x3xf32>) -> (memref<*xf32>, memr
 // CHECK-LABEL: @_mlir_ciface_return_two_var_memref
 // CHECK-SAME: (%{{.*}}: !llvm.ptr,
 // CHECK-SAME: %{{.*}}: !llvm.ptr)
+
+// CHECK-LABEL: llvm.func @bare_ptr_calling_conv(
+// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr
+// CHECK-SAME: -> !llvm.ptr
+func.func @bare_ptr_calling_conv(%arg0: memref<4x3xf32>, %arg1 : index, %arg2 : index, %arg3 : f32)
+     -> (memref<4x3xf32>) attributes { llvm.bareptr } {
+  // CHECK: %[[UNDEF_DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+  // CHECK: %[[INSERT_ALLOCPTR:.*]] = llvm.insertvalue %[[ARG0]], %[[UNDEF_DESC]][0]
+  // CHECK: %[[INSERT_ALIGNEDPTR:.*]] = llvm.insertvalue %[[ARG0]], %[[INSERT_ALLOCPTR]][1]
+  // CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64
+  // CHECK: %[[INSERT_OFFSET:.*]] = llvm.insertvalue %[[C0]], %[[INSERT_ALIGNEDPTR]][2]
+  // CHECK: %[[C4:.*]] = llvm.mlir.constant(4 : index) : i64
+  // CHECK: %[[INSERT_DIM0:.*]] = llvm.insertvalue %[[C4]], %[[INSERT_OFFSET]][3, 0]
+  // CHECK: %[[C3:.*]] = llvm.mlir.constant(3 : index) : i64
+  // CHECK: %[[INSERT_STRIDE0:.*]] = llvm.insertvalue %[[C3]], %[[INSERT_DIM0]][4, 0]
+  // CHECK: %[[C3:.*]] = llvm.mlir.constant(3 : index) : i64
+  // CHECK: %[[INSERT_DIM1:.*]] = llvm.insertvalue %[[C3]], %[[INSERT_STRIDE0]][3, 1]
+  // CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : index) : i64
+  // CHECK: %[[INSERT_STRIDE1:.*]] = llvm.insertvalue %[[C1]], %[[INSERT_DIM1]][4, 1]
+
+  // CHECK: %[[ALIGNEDPTR:.*]] = llvm.extractvalue %[[INSERT_STRIDE1]][1]
+  // CHECK: %[[STOREPTR:.*]] = llvm.getelementptr %[[ALIGNEDPTR]]
+  // CHECK: llvm.store %{{.*}}, %[[STOREPTR]]
+  memref.store %arg3, %arg0[%arg1, %arg2] : memref<4x3xf32>
+
+  // CHECK: llvm.return %[[ARG0]]
+  return %arg0 : memref<4x3xf32>
+}
+
+// CHECK-LABEL: llvm.func @bare_ptr_calling_conv_multiresult(
+// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr
+// CHECK-SAME: -> !llvm.struct<(f32, ptr)>
+func.func @bare_ptr_calling_conv_multiresult(%arg0: memref<4x3xf32>, %arg1 : index, %arg2 : index, %arg3 : f32)
+     -> (f32, memref<4x3xf32>) attributes { llvm.bareptr } {
+  // CHECK: %[[UNDEF_DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+  // CHECK: %[[INSERT_ALLOCPTR:.*]] = llvm.insertvalue %[[ARG0]], %[[UNDEF_DESC]][0]
+  // CHECK: %[[INSERT_ALIGNEDPTR:.*]] = llvm.insertvalue %[[ARG0]], %[[INSERT_ALLOCPTR]][1]
+  // CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64
+  // CHECK: %[[INSERT_OFFSET:.*]] = llvm.insertvalue %[[C0]], %[[INSERT_ALIGNEDPTR]][2]
+  // CHECK: %[[C4:.*]] = llvm.mlir.constant(4 : index) : i64
+  // CHECK: %[[INSERT_DIM0:.*]] = llvm.insertvalue %[[C4]], %[[INSERT_OFFSET]][3, 0]
+  // CHECK: %[[C3:.*]] = llvm.mlir.constant(3 : index) : i64
+  // CHECK: %[[INSERT_STRIDE0:.*]] = llvm.insertvalue %[[C3]], %[[INSERT_DIM0]][4, 0]
+  // CHECK: %[[C3:.*]] = llvm.mlir.constant(3 : index) : i64
+  // CHECK: %[[INSERT_DIM1:.*]] = llvm.insertvalue %[[C3]], %[[INSERT_STRIDE0]][3, 1]
+  // CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : index) : i64
+  // CHECK: %[[INSERT_STRIDE1:.*]] = llvm.insertvalue %[[C1]], %[[INSERT_DIM1]][4, 1]
+
+  // CHECK: %[[ALIGNEDPTR:.*]] = llvm.extractvalue %[[INSERT_STRIDE1]][1]
+  // CHECK: %[[STOREPTR:.*]] = llvm.getelementptr %[[ALIGNEDPTR]]
+  // CHECK: llvm.store %{{.*}}, %[[STOREPTR]]
+  memref.store %arg3, %arg0[%arg1, %arg2] : memref<4x3xf32>
+
+  // CHECK: %[[ALIGNEDPTR0:.*]] = llvm.extractvalue %[[INSERT_STRIDE1]][1]
+  // CHECK: %[[LOADPTR:.*]] = llvm.getelementptr %[[ALIGNEDPTR0]]
+  // CHECK: %[[RETURN0:.*]] = llvm.load %[[LOADPTR]]
+  %0 = memref.load %arg0[%arg1, %arg2] : memref<4x3xf32>
+
+  // CHECK: %[[RETURN_DESC:.*]] = llvm.mlir.undef : !llvm.struct<(f32, ptr)>
+  // CHECK: %[[INSERT_RETURN0:.*]] = llvm.insertvalue %[[RETURN0]], %[[RETURN_DESC]][0]
+  // CHECK: %[[INSERT_RETURN1:.*]] = llvm.insertvalue %[[ARG0]], %[[INSERT_RETURN0]][1]
+  // CHECK: llvm.return %[[INSERT_RETURN1]]
+  return %0, %arg0 : f32, memref<4x3xf32>
+}

diff --git a/mlir/test/Conversion/FuncToLLVM/func-memref-return.mlir b/mlir/test/Conversion/FuncToLLVM/func-memref-return.mlir
index 8663ce8cbbf2f..956c298123db2 100644
--- a/mlir/test/Conversion/FuncToLLVM/func-memref-return.mlir
+++ b/mlir/test/Conversion/FuncToLLVM/func-memref-return.mlir
@@ -27,7 +27,7 @@ func.func @check_static_return(%static : memref<32x18xf32>) -> memref<32x18xf32>
 // BAREPTR-NEXT: %[[ins3:.*]] = llvm.insertvalue %[[val3]], %[[ins2]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // BAREPTR-NEXT: %[[val4:.*]] = llvm.mlir.constant(1 : index) : i64
 // BAREPTR-NEXT: %[[ins4:.*]] = llvm.insertvalue %[[val4]], %[[ins3]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
-// BAREPTR-NEXT: %[[base1:.*]] = llvm.extractvalue %[[ins4]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// BAREPTR-NEXT: %[[base1:.*]] = llvm.extractvalue %[[ins4]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // BAREPTR-NEXT: llvm.return %[[base1]] : !llvm.ptr
   return %static : memref<32x18xf32>
 }
@@ -56,7 +56,7 @@ func.func @check_static_return_with_offset(%static : memref<32x18xf32, strided<[
 // BAREPTR-NEXT: %[[ins3:.*]] = llvm.insertvalue %[[val3]], %[[ins2]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // BAREPTR-NEXT: %[[val4:.*]] = llvm.mlir.constant(1 : index) : i64
 // BAREPTR-NEXT: %[[ins4:.*]] = llvm.insertvalue %[[val4]], %[[ins3]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
-// BAREPTR-NEXT: %[[base1:.*]] = llvm.extractvalue %[[ins4]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// BAREPTR-NEXT: %[[base1:.*]] = llvm.extractvalue %[[ins4]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // BAREPTR-NEXT: llvm.return %[[base1]] : !llvm.ptr
   return %static : memref<32x18xf32, strided<[22,1], offset: 7>>
 }
@@ -82,7 +82,7 @@ func.func @check_memref_func_call(%in : memref<10xi8>) -> memref<20xi8> {
   // BAREPTR-NEXT:    %[[c1:.*]] = llvm.mlir.constant(1 : index) : i64
   // BAREPTR-NEXT:    %[[outDesc:.*]] = llvm.insertvalue %[[c1]], %[[desc6]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
   %res = call @foo(%in) : (memref<10xi8>) -> (memref<20xi8>)
-  // BAREPTR-NEXT:    %[[res:.*]] = llvm.extractvalue %[[outDesc]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+  // BAREPTR-NEXT:    %[[res:.*]] = llvm.extractvalue %[[outDesc]][0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
   // BAREPTR-NEXT:    llvm.return %[[res]] : !llvm.ptr
   return %res : memref<20xi8>
 }


        

