[Mlir-commits] [mlir] [mlir][LLVM] `FuncToLLVM`: Add 1:N type conversion support (PR #153823)

Matthias Springer llvmlistbot at llvm.org
Fri Aug 15 08:53:40 PDT 2025


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/153823

From b79aea37e442d5a8c24241f1cbeacaaafea83e12 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Thu, 14 Aug 2025 13:45:48 +0000
Subject: [PATCH] [mlir][LLVM] FuncToLLVM: Add 1:N support

---
 .../Conversion/LLVMCommon/TypeConverter.h     |  24 ++--
 mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp | 104 +++++++++++-------
 .../GPUCommon/GPUToLLVMConversion.cpp         |  12 +-
 .../Conversion/LLVMCommon/TypeConverter.cpp   | 100 ++++++++---------
 .../Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp    |  28 +++--
 .../MemRefToLLVM/type-conversion.mlir         |  97 ++++++++++++++--
 mlir/test/lib/Dialect/LLVM/TestPatterns.cpp   |  30 +++++
 7 files changed, 268 insertions(+), 127 deletions(-)

diff --git a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
index 38b5e492a8ed8..a38b3283416e0 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
@@ -74,8 +74,13 @@ class LLVMTypeConverter : public TypeConverter {
   /// LLVM-compatible type. In particular, if more than one value is returned,
   /// create an LLVM dialect structure type with elements that correspond to
   /// each of the types converted with `convertCallingConventionType`.
-  Type packFunctionResults(TypeRange types,
-                           bool useBarePointerCallConv = false) const;
+  ///
+  /// Populate the converted (unpacked) types into `groupedTypes`, if provided.
+  /// `groupedTypes` contains one nested vector per input type. With a 1:N
+  /// conversion, a nested vector may contain zero or more than one type.
+  Type packFunctionResults(
+      TypeRange types, bool useBarePointerCallConv = false,
+      SmallVector<SmallVector<Type>> *groupedTypes = nullptr) const;
 
   /// Convert a non-empty list of types of values produced by an operation into
   /// an LLVM-compatible type. In particular, if more than one value is
@@ -88,15 +93,9 @@ class LLVMTypeConverter : public TypeConverter {
   /// UnrankedMemRefType, are converted following the specific rules for the
   /// calling convention. Calling convention independent types are converted
   /// following the default LLVM type conversions.
-  Type convertCallingConventionType(Type type,
-                                    bool useBarePointerCallConv = false) const;
-
-  /// Promote the bare pointers in 'values' that resulted from memrefs to
-  /// descriptors. 'stdTypes' holds the types of 'values' before the conversion
-  /// to the LLVM-IR dialect (i.e., MemRefType, or any other builtin type).
-  void promoteBarePtrsToDescriptors(ConversionPatternRewriter &rewriter,
-                                    Location loc, ArrayRef<Type> stdTypes,
-                                    SmallVectorImpl<Value> &values) const;
+  LogicalResult
+  convertCallingConventionType(Type type, SmallVectorImpl<Type> &result,
+                               bool useBarePointerCallConv = false) const;
 
   /// Returns the MLIR context.
   MLIRContext &getContext() const;
@@ -111,7 +110,8 @@ class LLVMTypeConverter : public TypeConverter {
   /// of the platform-specific C/C++ ABI lowering related to struct argument
   /// passing.
   SmallVector<Value, 4> promoteOperands(Location loc, ValueRange opOperands,
-                                        ValueRange operands, OpBuilder &builder,
+                                        ArrayRef<ValueRange> operands,
+                                        OpBuilder &builder,
                                         bool useBarePtrCallConv = false) const;
 
   /// Promote the LLVM struct representation of one MemRef descriptor to stack
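
A minimal sketch (not part of the patch) of how a caller might use the new
optional `groupedTypes` out-parameter of `packFunctionResults` to recover the
per-result grouping; the helper name `packAndCount` is illustrative:

    #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
    #include "llvm/ADT/SmallVector.h"

    using namespace mlir;

    // Pack a list of result types and inspect how they were grouped.
    static LogicalResult packAndCount(const LLVMTypeConverter &converter,
                                      TypeRange types) {
      SmallVector<SmallVector<Type>> groupedTypes;
      Type packed = converter.packFunctionResults(
          types, /*useBarePointerCallConv=*/false, &groupedTypes);
      if (!packed)
        return failure();
      // groupedTypes has one entry per original type; entry i lists the
      // converted types that original type i expands to (possibly zero or
      // more than one entry under a 1:N conversion).
      unsigned numUnpacked = 0;
      for (const SmallVector<Type> &group : groupedTypes)
        numUnpacked += group.size();
      (void)numUnpacked; // e.g. the number of scalar values a call will yield
      return success();
    }
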
diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
index a4a6ae250640f..95981138a7253 100644
--- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
+++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
@@ -527,19 +527,19 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
   using ConvertOpToLLVMPattern<CallOpType>::ConvertOpToLLVMPattern;
   using Super = CallOpInterfaceLowering<CallOpType>;
   using Base = ConvertOpToLLVMPattern<CallOpType>;
+  using Adaptor = typename ConvertOpToLLVMPattern<CallOpType>::OneToNOpAdaptor;
 
-  LogicalResult matchAndRewriteImpl(CallOpType callOp,
-                                    typename CallOpType::Adaptor adaptor,
+  LogicalResult matchAndRewriteImpl(CallOpType callOp, Adaptor adaptor,
                                     ConversionPatternRewriter &rewriter,
                                     bool useBarePtrCallConv = false) const {
     // Pack the result types into a struct.
     Type packedResult = nullptr;
+    SmallVector<SmallVector<Type>> groupedResultTypes;
     unsigned numResults = callOp.getNumResults();
     auto resultTypes = llvm::to_vector<4>(callOp.getResultTypes());
-
     if (numResults != 0) {
       if (!(packedResult = this->getTypeConverter()->packFunctionResults(
-                resultTypes, useBarePtrCallConv)))
+                resultTypes, useBarePtrCallConv, &groupedResultTypes)))
         return failure();
     }
 
@@ -565,34 +565,61 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
         static_cast<int32_t>(promoted.size()), 0};
     newOp.getProperties().op_bundle_sizes = rewriter.getDenseI32ArrayAttr({});
 
-    SmallVector<Value, 4> results;
-    if (numResults < 2) {
-      // If < 2 results, packing did not do anything and we can just return.
-      results.append(newOp.result_begin(), newOp.result_end());
-    } else {
-      // Otherwise, it had been converted to an operation producing a structure.
-      // Extract individual results from the structure and return them as list.
-      results.reserve(numResults);
-      for (unsigned i = 0; i < numResults; ++i) {
-        results.push_back(LLVM::ExtractValueOp::create(
-            rewriter, callOp.getLoc(), newOp->getResult(0), i));
+    // Helper function that extracts an individual result from the return value
+    // of the new call op. llvm.call ops support only 0 or 1 result. In case of
+    // 2 or more results, the results are packed into a structure.
+    auto getUnpackedResult = [&](unsigned i) -> Value {
+      assert(packedResult && "converted op has no results");
+      if (!isa<LLVM::LLVMStructType>(packedResult)) {
+        assert(i == 0 && "out of bounds: converted op has only one result");
+        return newOp->getResult(0);
+      }
+      // Results have been converted to a structure. Extract individual results
+      // from the structure.
+      return LLVM::ExtractValueOp::create(rewriter, callOp.getLoc(),
+                                          newOp->getResult(0), i);
+    };
+
+    // Group the results into a vector of vectors, such that it is clear which
+    // original op result is replaced with which range of values. (In case of a
+    // 1:N conversion, there can be multiple replacements for a single result.)
+    SmallVector<SmallVector<Value>> results;
+    results.reserve(numResults);
+    unsigned counter = 0;
+    for (unsigned i = 0; i < numResults; ++i) {
+      SmallVector<Value> &group = results.emplace_back();
+      for (unsigned j = 0, e = groupedResultTypes[i].size(); j < e; ++j) {
+        group.push_back(getUnpackedResult(counter++));
       }
     }
 
-    if (useBarePtrCallConv) {
-      // For the bare-ptr calling convention, promote memref results to
-      // descriptors.
-      assert(results.size() == resultTypes.size() &&
-             "The number of arguments and types doesn't match");
-      this->getTypeConverter()->promoteBarePtrsToDescriptors(
-          rewriter, callOp.getLoc(), resultTypes, results);
-    } else if (failed(this->copyUnrankedDescriptors(rewriter, callOp.getLoc(),
-                                                    resultTypes, results,
-                                                    /*toDynamic=*/false))) {
-      return failure();
+    // Special handling for MemRef types.
+    for (unsigned i = 0; i < numResults; ++i) {
+      Type origType = resultTypes[i];
+      auto memrefType = dyn_cast<MemRefType>(origType);
+      auto unrankedMemrefType = dyn_cast<UnrankedMemRefType>(origType);
+      if (useBarePtrCallConv && memrefType) {
+        // For the bare-ptr calling convention, promote memref results to
+        // descriptors.
+        assert(results[i].size() == 1 && "expected one converted result");
+        results[i].front() = MemRefDescriptor::fromStaticShape(
+            rewriter, callOp.getLoc(), *this->getTypeConverter(), memrefType,
+            results[i].front());
+      }
+      if (unrankedMemrefType) {
+        assert(!useBarePtrCallConv && "unranked memref is not supported in the "
+                                      "bare-ptr calling convention");
+        assert(results[i].size() == 1 && "expected one converted result");
+        Value desc = this->copyUnrankedDescriptor(
+            rewriter, callOp.getLoc(), unrankedMemrefType, results[i].front(),
+            /*toDynamic=*/false);
+        if (!desc)
+          return failure();
+        results[i].front() = desc;
+      }
     }
 
-    rewriter.replaceOp(callOp, results);
+    rewriter.replaceOpWithMultiple(callOp, results);
     return success();
   }
 };
@@ -606,7 +633,7 @@ class CallOpLowering : public CallOpInterfaceLowering<func::CallOp> {
         symbolTables(symbolTables) {}
 
   LogicalResult
-  matchAndRewrite(func::CallOp callOp, OpAdaptor adaptor,
+  matchAndRewrite(func::CallOp callOp, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     bool useBarePtrCallConv = false;
     if (getTypeConverter()->getOptions().useBarePtrCallConv) {
@@ -636,7 +663,7 @@ struct CallIndirectOpLowering
   using Super::Super;
 
   LogicalResult
-  matchAndRewrite(func::CallIndirectOp callIndirectOp, OpAdaptor adaptor,
+  matchAndRewrite(func::CallIndirectOp callIndirectOp, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     return matchAndRewriteImpl(callIndirectOp, adaptor, rewriter);
   }
@@ -679,47 +706,50 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern<func::ReturnOp> {
   using ConvertOpToLLVMPattern<func::ReturnOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
-  matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
+  matchAndRewrite(func::ReturnOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
-    unsigned numArguments = op.getNumOperands();
     SmallVector<Value, 4> updatedOperands;
 
     auto funcOp = op->getParentOfType<LLVM::LLVMFuncOp>();
     bool useBarePtrCallConv =
         shouldUseBarePtrCallConv(funcOp, this->getTypeConverter());
 
-    for (auto [oldOperand, newOperand] :
+    for (auto [oldOperand, newOperands] :
          llvm::zip_equal(op->getOperands(), adaptor.getOperands())) {
       Type oldTy = oldOperand.getType();
       if (auto memRefType = dyn_cast<MemRefType>(oldTy)) {
+        assert(newOperands.size() == 1 && "expected one converted operand");
         if (useBarePtrCallConv &&
             getTypeConverter()->canConvertToBarePtr(memRefType)) {
           // For the bare-ptr calling convention, extract the aligned pointer to
           // be returned from the memref descriptor.
-          MemRefDescriptor memrefDesc(newOperand);
+          MemRefDescriptor memrefDesc(newOperands.front());
           updatedOperands.push_back(memrefDesc.allocatedPtr(rewriter, loc));
           continue;
         }
       } else if (auto unrankedMemRefType =
                      dyn_cast<UnrankedMemRefType>(oldTy)) {
+        assert(newOperands.size() == 1 && "expected one converted operand");
         if (useBarePtrCallConv) {
           // Unranked memref is not supported in the bare pointer calling
           // convention.
           return failure();
         }
-        Value updatedDesc = copyUnrankedDescriptor(
-            rewriter, loc, unrankedMemRefType, newOperand, /*toDynamic=*/true);
+        Value updatedDesc =
+            copyUnrankedDescriptor(rewriter, loc, unrankedMemRefType,
+                                   newOperands.front(), /*toDynamic=*/true);
         if (!updatedDesc)
           return failure();
         updatedOperands.push_back(updatedDesc);
         continue;
       }
-      updatedOperands.push_back(newOperand);
+
+      llvm::append_range(updatedOperands, newOperands);
     }
 
     // If ReturnOp has 0 or 1 operand, create it and return immediately.
-    if (numArguments <= 1) {
+    if (updatedOperands.size() <= 1) {
       rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(
           op, TypeRange(), updatedOperands, op->getAttrs());
       return success();
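
The grouping loop in the call lowering above can be read as the following
standalone helper (an illustrative sketch, not code from the patch): given the
flat list of values extracted from the llvm.call result and the per-result
type groups produced by `packFunctionResults`, it builds the nested
replacement list that `replaceOpWithMultiple` expects.

    #include "mlir/IR/Types.h"
    #include "mlir/IR/Value.h"
    #include "llvm/ADT/ArrayRef.h"
    #include "llvm/ADT/SmallVector.h"

    using namespace mlir;

    static SmallVector<SmallVector<Value>>
    groupResults(ArrayRef<Value> flatResults,
                 ArrayRef<SmallVector<Type>> groupedResultTypes) {
      SmallVector<SmallVector<Value>> groups;
      groups.reserve(groupedResultTypes.size());
      unsigned counter = 0;
      for (const SmallVector<Type> &types : groupedResultTypes) {
        // One group per original op result; a group may hold 0, 1, or N values.
        SmallVector<Value> &group = groups.emplace_back();
        for (size_t i = 0, e = types.size(); i < e; ++i)
          group.push_back(flatResults[counter++]);
      }
      return groups;
    }
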
diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
index 3cfbd898e49e2..a3ec644dcc068 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -719,8 +719,10 @@ LogicalResult ConvertHostRegisterOpToGpuRuntimeCallPattern::matchAndRewrite(
   auto elementType = cast<UnrankedMemRefType>(memRefType).getElementType();
   auto elementSize = getSizeInBytes(loc, elementType, rewriter);
 
+  SmallVector<ValueRange> adaptorOperands = llvm::map_to_vector(
+      adaptor.getOperands(), [](Value v) { return ValueRange(v); });
   auto arguments = getTypeConverter()->promoteOperands(
-      loc, op->getOperands(), adaptor.getOperands(), rewriter);
+      loc, op->getOperands(), adaptorOperands, rewriter);
   arguments.push_back(elementSize);
   hostRegisterCallBuilder.create(loc, rewriter, arguments);
 
@@ -741,8 +743,10 @@ LogicalResult ConvertHostUnregisterOpToGpuRuntimeCallPattern::matchAndRewrite(
   auto elementType = cast<UnrankedMemRefType>(memRefType).getElementType();
   auto elementSize = getSizeInBytes(loc, elementType, rewriter);
 
+  SmallVector<ValueRange> adaptorOperands = llvm::map_to_vector(
+      adaptor.getOperands(), [](Value v) { return ValueRange(v); });
   auto arguments = getTypeConverter()->promoteOperands(
-      loc, op->getOperands(), adaptor.getOperands(), rewriter);
+      loc, op->getOperands(), adaptorOperands, rewriter);
   arguments.push_back(elementSize);
   hostUnregisterCallBuilder.create(loc, rewriter, arguments);
 
@@ -973,8 +977,10 @@ LogicalResult LegalizeLaunchFuncOpPattern::matchAndRewrite(
   // Note: If `useBarePtrCallConv` is set in the type converter's options,
   // the value of `kernelBarePtrCallConv` will be ignored.
   OperandRange origArguments = launchOp.getKernelOperands();
+  SmallVector<ValueRange> adaptorOperands = llvm::map_to_vector(
+      adaptor.getKernelOperands(), [](Value v) { return ValueRange(v); });
   SmallVector<Value, 8> llvmArguments = getTypeConverter()->promoteOperands(
-      loc, origArguments, adaptor.getKernelOperands(), rewriter,
+      loc, origArguments, adaptorOperands, rewriter,
       /*useBarePtrCallConv=*/kernelBarePtrCallConv);
   SmallVector<Value, 8> llvmArgumentsWithSizes;
 
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index 1a9bf569086da..621900e40f77d 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -365,6 +365,7 @@ Type LLVMTypeConverter::convertFunctionSignatureImpl(
   useBarePtrCallConv = useBarePtrCallConv || options.useBarePtrCallConv;
   auto funcArgConverter = useBarePtrCallConv ? barePtrFuncArgTypeConverter
                                              : structFuncArgTypeConverter;
+
   // Convert argument types one by one and check for errors.
   for (auto [idx, type] : llvm::enumerate(funcTy.getInputs())) {
     SmallVector<Type, 8> converted;
@@ -658,27 +659,19 @@ FailureOr<Type> LLVMTypeConverter::convertVectorType(VectorType type) const {
 /// UnrankedMemRefType, are converted following the specific rules for the
 /// calling convention. Calling convention independent types are converted
 /// following the default LLVM type conversions.
-Type LLVMTypeConverter::convertCallingConventionType(
-    Type type, bool useBarePtrCallConv) const {
-  if (useBarePtrCallConv)
-    if (auto memrefTy = dyn_cast<BaseMemRefType>(type))
-      return convertMemRefToBarePtr(memrefTy);
-
-  return convertType(type);
-}
+LogicalResult LLVMTypeConverter::convertCallingConventionType(
+    Type type, SmallVectorImpl<Type> &result, bool useBarePtrCallConv) const {
+  if (useBarePtrCallConv) {
+    if (auto memrefTy = dyn_cast<BaseMemRefType>(type)) {
+      Type converted = convertMemRefToBarePtr(memrefTy);
+      if (!converted)
+        return failure();
+      result.push_back(converted);
+      return success();
+    }
+  }
 
-/// Promote the bare pointers in 'values' that resulted from memrefs to
-/// descriptors. 'stdTypes' holds they types of 'values' before the conversion
-/// to the LLVM-IR dialect (i.e., MemRefType, or any other builtin type).
-void LLVMTypeConverter::promoteBarePtrsToDescriptors(
-    ConversionPatternRewriter &rewriter, Location loc, ArrayRef<Type> stdTypes,
-    SmallVectorImpl<Value> &values) const {
-  assert(stdTypes.size() == values.size() &&
-         "The number of types and values doesn't match");
-  for (unsigned i = 0, end = values.size(); i < end; ++i)
-    if (auto memrefTy = dyn_cast<MemRefType>(stdTypes[i]))
-      values[i] = MemRefDescriptor::fromStaticShape(rewriter, loc, *this,
-                                                    memrefTy, values[i]);
+  return convertType(type, result);
 }
 
 /// Convert a non-empty list of types of values produced by an operation into an
@@ -706,23 +699,32 @@ Type LLVMTypeConverter::packOperationResults(TypeRange types) const {
 /// LLVM-compatible type. In particular, if more than one value is returned,
 /// create an LLVM dialect structure type with elements that correspond to each
 /// of the types converted with `convertCallingConventionType`.
-Type LLVMTypeConverter::packFunctionResults(TypeRange types,
-                                            bool useBarePtrCallConv) const {
+Type LLVMTypeConverter::packFunctionResults(
+    TypeRange types, bool useBarePtrCallConv,
+    SmallVector<SmallVector<Type>> *groupedTypes) const {
   assert(!types.empty() && "expected non-empty list of type");
+  assert((!groupedTypes || groupedTypes->empty()) &&
+         "expected groupedTypes to be empty");
 
   useBarePtrCallConv |= options.useBarePtrCallConv;
-  if (types.size() == 1)
-    return convertCallingConventionType(types.front(), useBarePtrCallConv);
-
   SmallVector<Type> resultTypes;
   resultTypes.reserve(types.size());
+  size_t sizeBefore = 0;
   for (auto t : types) {
-    auto converted = convertCallingConventionType(t, useBarePtrCallConv);
-    if (!converted || !LLVM::isCompatibleType(converted))
+    if (failed(
+            convertCallingConventionType(t, resultTypes, useBarePtrCallConv)))
       return {};
-    resultTypes.push_back(converted);
+    if (groupedTypes) {
+      SmallVector<Type> &group = groupedTypes->emplace_back();
+      llvm::append_range(group, ArrayRef(resultTypes).drop_front(sizeBefore));
+    }
+    sizeBefore = resultTypes.size();
   }
 
+  if (resultTypes.size() == 1)
+    return resultTypes.front();
+  if (resultTypes.empty())
+    return {};
   return LLVM::LLVMStructType::getLiteral(&getContext(), resultTypes);
 }
 
@@ -740,40 +742,40 @@ Value LLVMTypeConverter::promoteOneMemRefDescriptor(Location loc, Value operand,
   return allocated;
 }
 
-SmallVector<Value, 4>
-LLVMTypeConverter::promoteOperands(Location loc, ValueRange opOperands,
-                                   ValueRange operands, OpBuilder &builder,
-                                   bool useBarePtrCallConv) const {
+SmallVector<Value, 4> LLVMTypeConverter::promoteOperands(
+    Location loc, ValueRange opOperands, ArrayRef<ValueRange> operands,
+    OpBuilder &builder, bool useBarePtrCallConv) const {
   SmallVector<Value, 4> promotedOperands;
   promotedOperands.reserve(operands.size());
   useBarePtrCallConv |= options.useBarePtrCallConv;
-  for (auto it : llvm::zip(opOperands, operands)) {
-    auto operand = std::get<0>(it);
-    auto llvmOperand = std::get<1>(it);
-
+  for (auto [operand, llvmOperand] : llvm::zip_equal(opOperands, operands)) {
     if (useBarePtrCallConv) {
       // For the bare-ptr calling convention, we only have to extract the
       // aligned pointer of a memref.
       if (isa<MemRefType>(operand.getType())) {
-        MemRefDescriptor desc(llvmOperand);
-        llvmOperand = desc.alignedPtr(builder, loc);
+        assert(llvmOperand.size() == 1 && "Expected a single operand");
+        MemRefDescriptor desc(llvmOperand.front());
+        promotedOperands.push_back(desc.alignedPtr(builder, loc));
+        continue;
       } else if (isa<UnrankedMemRefType>(operand.getType())) {
         llvm_unreachable("Unranked memrefs are not supported");
       }
     } else {
       if (isa<UnrankedMemRefType>(operand.getType())) {
-        UnrankedMemRefDescriptor::unpack(builder, loc, llvmOperand,
+        assert(llvmOperand.size() == 1 && "Expected a single operand");
+        UnrankedMemRefDescriptor::unpack(builder, loc, llvmOperand.front(),
                                          promotedOperands);
         continue;
       }
       if (auto memrefType = dyn_cast<MemRefType>(operand.getType())) {
-        MemRefDescriptor::unpack(builder, loc, llvmOperand, memrefType,
+        assert(llvmOperand.size() == 1 && "Expected a single operand");
+        MemRefDescriptor::unpack(builder, loc, llvmOperand.front(), memrefType,
                                  promotedOperands);
         continue;
       }
     }
 
-    promotedOperands.push_back(llvmOperand);
+    llvm::append_range(promotedOperands, llvmOperand);
   }
   return promotedOperands;
 }
@@ -802,11 +804,7 @@ mlir::structFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type,
     result.append(converted.begin(), converted.end());
     return success();
   }
-  auto converted = converter.convertType(type);
-  if (!converted)
-    return failure();
-  result.push_back(converted);
-  return success();
+  return converter.convertType(type, result);
 }
 
 /// Callback to convert function argument types. It converts MemRef function
@@ -814,11 +812,7 @@ mlir::structFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type,
 LogicalResult
 mlir::barePtrFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type,
                                   SmallVectorImpl<Type> &result) {
-  auto llvmTy = converter.convertCallingConventionType(
-      type, /*useBarePointerCallConv=*/true);
-  if (!llvmTy)
-    return failure();
-
-  result.push_back(llvmTy);
-  return success();
+  return converter.convertCallingConventionType(
+      type, result,
+      /*useBarePointerCallConv=*/true);
 }
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index f7f5381799529..46856203672e6 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -1106,12 +1106,10 @@ struct NVGPUGenerateWarpgroupDescriptorLowering
     // // [0,14)   start_address
     dsc = insertBit(dsc, basePtr14bit, startBaseAddrBit);
 
-    LDBG() << "Generating warpgroup.descriptor: "
-           << "leading_off:" << leadDimVal << "\t"
-           << "stride_off :" << strideDimVal << "\t"
-           << "base_offset:" << offsetVal << "\t"
-           << "layout_type:" << swizzle << " ("
-           << nvgpu::stringifyTensorMapSwizzleKind(swizzleKind)
+    LDBG() << "Generating warpgroup.descriptor: " << "leading_off:"
+           << leadDimVal << "\t" << "stride_off :" << strideDimVal << "\t"
+           << "base_offset:" << offsetVal << "\t" << "layout_type:" << swizzle
+           << " (" << nvgpu::stringifyTensorMapSwizzleKind(swizzleKind)
            << ")\n start_addr :  " << baseAddr;
 
     rewriter.replaceOp(op, dsc);
@@ -1181,8 +1179,10 @@ struct NVGPUTmaCreateDescriptorOpLowering
 
     Value tensorElementType =
         elementTypeAsLLVMConstant(b, op.getTensor().getType().getElementType());
+    SmallVector<ValueRange> adaptorOperands = llvm::map_to_vector(
+        adaptor.getOperands(), [](Value v) { return ValueRange(v); });
     auto promotedOperands = getTypeConverter()->promoteOperands(
-        b.getLoc(), op->getOperands(), adaptor.getOperands(), b);
+        b.getLoc(), op->getOperands(), adaptorOperands, b);
 
     Value boxArrayPtr = LLVM::AllocaOp::create(
         b, llvmPointerType, llvmInt64Type, makeI64Const(b, 5));
@@ -1401,14 +1401,12 @@ struct NVGPUWarpgroupMmaOpLowering
     /// This function generates a WgmmaMmaAsyncOp using provided GMMA matrix
     /// descriptors and arranges them based on induction variables: i, j, and k.
     Value generateWgmma(int i, int j, int k, Value matrixC) {
-      LDBG() << "\t wgmma."
-             << "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK << "(A["
-             << (iterationM * wgmmaM) << ":" << (iterationM * wgmmaM) + wgmmaM
-             << "][" << (iterationK * wgmmaK) << ":"
-             << (iterationK * wgmmaK + wgmmaK) << "] * "
-             << " B[" << (iterationK * wgmmaK) << ":"
-             << (iterationK * wgmmaK + wgmmaK) << "][" << 0 << ":" << wgmmaN
-             << "])";
+      LDBG() << "\t wgmma." << "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK
+             << "(A[" << (iterationM * wgmmaM) << ":"
+             << (iterationM * wgmmaM) + wgmmaM << "][" << (iterationK * wgmmaK)
+             << ":" << (iterationK * wgmmaK + wgmmaK) << "] * " << " B["
+             << (iterationK * wgmmaK) << ":" << (iterationK * wgmmaK + wgmmaK)
+             << "][" << 0 << ":" << wgmmaN << "])";
 
       Value descriptorA = iterateDescriptorA(adaptor.getDescriptorA(), i, j, k);
       Value descriptorB = iterateDescriptorB(adaptor.getDescriptorB(), i, j, k);
diff --git a/mlir/test/Conversion/MemRefToLLVM/type-conversion.mlir b/mlir/test/Conversion/MemRefToLLVM/type-conversion.mlir
index 0288aa11313c7..c1751f282b002 100644
--- a/mlir/test/Conversion/MemRefToLLVM/type-conversion.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/type-conversion.mlir
@@ -1,12 +1,13 @@
-// RUN: mlir-opt %s -test-llvm-legalize-patterns -split-input-file
+// RUN: mlir-opt %s -test-llvm-legalize-patterns -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-llvm-legalize-patterns="allow-pattern-rollback=0" -split-input-file | FileCheck %s
 
 // Test the argument materializer for ranked MemRef types.
 
 //   CHECK-LABEL: func @construct_ranked_memref_descriptor(
-//         CHECK:   llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+//         CHECK:   llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK-COUNT-7:   llvm.insertvalue
 //         CHECK:   builtin.unrealized_conversion_cast %{{.*}} : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> to memref<5x4xf32>
-func.func @construct_ranked_memref_descriptor(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: i64, %arg3: i64, %arg4: i64, %arg5: i64, %arg6: i64) {
+func.func @construct_ranked_memref_descriptor(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: i64, %arg3: i64, %arg4: i64, %arg5: i64, %arg6: i64) attributes {is_legal} {
   %0 = "test.direct_replacement"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6) : (!llvm.ptr, !llvm.ptr, i64, i64, i64, i64, i64) -> (memref<5x4xf32>)
   "test.legal_op"(%0) : (memref<5x4xf32>) -> ()
   return
@@ -21,7 +22,7 @@ func.func @construct_ranked_memref_descriptor(%arg0: !llvm.ptr, %arg1: !llvm.ptr
 // CHECK-LABEL: func @invalid_ranked_memref_descriptor(
 //       CHECK:   %[[cast:.*]] = builtin.unrealized_conversion_cast %{{.*}} : i1 to memref<5x4xf32>
 //       CHECK:   "test.legal_op"(%[[cast]])
-func.func @invalid_ranked_memref_descriptor(%arg0: i1) {
+func.func @invalid_ranked_memref_descriptor(%arg0: i1) attributes {is_legal} {
   %0 = "test.direct_replacement"(%arg0) : (i1) -> (memref<5x4xf32>)
   "test.legal_op"(%0) : (memref<5x4xf32>) -> ()
   return
@@ -32,10 +33,10 @@ func.func @invalid_ranked_memref_descriptor(%arg0: i1) {
 // Test the argument materializer for unranked MemRef types.
 
 //   CHECK-LABEL: func @construct_unranked_memref_descriptor(
-//         CHECK:   llvm.mlir.undef : !llvm.struct<(i64, ptr)>
+//         CHECK:   llvm.mlir.poison : !llvm.struct<(i64, ptr)>
 // CHECK-COUNT-2:   llvm.insertvalue
 //         CHECK:   builtin.unrealized_conversion_cast %{{.*}} : !llvm.struct<(i64, ptr)> to memref<*xf32>
-func.func @construct_unranked_memref_descriptor(%arg0: i64, %arg1: !llvm.ptr) {
+func.func @construct_unranked_memref_descriptor(%arg0: i64, %arg1: !llvm.ptr) attributes {is_legal} {
   %0 = "test.direct_replacement"(%arg0, %arg1) : (i64, !llvm.ptr) -> (memref<*xf32>)
   "test.legal_op"(%0) : (memref<*xf32>) -> ()
   return
@@ -50,8 +51,90 @@ func.func @construct_unranked_memref_descriptor(%arg0: i64, %arg1: !llvm.ptr) {
 // CHECK-LABEL: func @invalid_unranked_memref_descriptor(
 //       CHECK:   %[[cast:.*]] = builtin.unrealized_conversion_cast %{{.*}} : i1 to memref<*xf32>
 //       CHECK:   "test.legal_op"(%[[cast]])
-func.func @invalid_unranked_memref_descriptor(%arg0: i1) {
+func.func @invalid_unranked_memref_descriptor(%arg0: i1) attributes {is_legal} {
   %0 = "test.direct_replacement"(%arg0) : (i1) -> (memref<*xf32>)
   "test.legal_op"(%0) : (memref<*xf32>) -> ()
   return
 }
+
+// -----
+
+// CHECK-LABEL: llvm.func @simple_func_conversion(
+//  CHECK-SAME:     %[[arg0:.*]]: i64) -> i64
+//       CHECK:   llvm.return %[[arg0]] : i64
+func.func @simple_func_conversion(%arg0: i64) -> i64 {
+  return %arg0 : i64
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @one_to_n_argument_conversion(
+//  CHECK-SAME:     %[[arg0:.*]]: i18, %[[arg1:.*]]: i18)
+//       CHECK:   %[[cast:.*]] = builtin.unrealized_conversion_cast %[[arg0]], %[[arg1]] : i18, i18 to i17
+//       CHECK:   "test.legal_op"(%[[cast]]) : (i17) -> ()
+func.func @one_to_n_argument_conversion(%arg0: i17) {
+  "test.legal_op"(%arg0) : (i17) -> ()
+  return
+}
+
+// CHECK: llvm.func @caller(%[[arg0:.*]]: i18, %[[arg1:.*]]: i18)
+// CHECK:   llvm.call @one_to_n_argument_conversion(%[[arg0]], %[[arg1]]) : (i18, i18) -> ()
+func.func @caller(%arg0: i17) {
+  func.call @one_to_n_argument_conversion(%arg0) : (i17) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @one_to_n_return_conversion(
+//  CHECK-SAME:     %[[arg0:.*]]: i18, %[[arg1:.*]]: i18) -> !llvm.struct<(i18, i18)>
+//       CHECK:   %[[p1:.*]] = llvm.mlir.poison : !llvm.struct<(i18, i18)>
+//       CHECK:   %[[p2:.*]] = llvm.insertvalue %[[arg0]], %[[p1]][0] : !llvm.struct<(i18, i18)>
+//       CHECK:   %[[p3:.*]] = llvm.insertvalue %[[arg1]], %[[p2]][1] : !llvm.struct<(i18, i18)>
+//       CHECK:   llvm.return %[[p3]]
+func.func @one_to_n_return_conversion(%arg0: i17) -> i17 {
+  return %arg0 : i17
+}
+
+// CHECK: llvm.func @caller(%[[arg0:.*]]: i18, %[[arg1:.*]]: i18)
+// CHECK:   %[[res:.*]] = llvm.call @one_to_n_return_conversion(%[[arg0]], %[[arg1]]) : (i18, i18) -> !llvm.struct<(i18, i18)>
+// CHECK:   %[[e0:.*]] = llvm.extractvalue %[[res]][0] : !llvm.struct<(i18, i18)>
+// CHECK:   %[[e1:.*]] = llvm.extractvalue %[[res]][1] : !llvm.struct<(i18, i18)>
+// CHECK:   %[[i0:.*]] = llvm.mlir.poison : !llvm.struct<(i18, i18)>
+// CHECK:   %[[i1:.*]] = llvm.insertvalue %[[e0]], %[[i0]][0] : !llvm.struct<(i18, i18)>
+// CHECK:   %[[i2:.*]] = llvm.insertvalue %[[e1]], %[[i1]][1] : !llvm.struct<(i18, i18)>
+// CHECK:   llvm.return %[[i2]]
+func.func @caller(%arg0: i17) -> (i17) {
+  %res = func.call @one_to_n_return_conversion(%arg0) : (i17) -> (i17)
+  return %res : i17
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @multi_return(
+//  CHECK-SAME:     %[[arg0:.*]]: i18, %[[arg1:.*]]: i18, %[[arg2:.*]]: i1) -> !llvm.struct<(i18, i18, i1)>
+//       CHECK:   %[[p1:.*]] = llvm.mlir.poison : !llvm.struct<(i18, i18, i1)>
+//       CHECK:   %[[p2:.*]] = llvm.insertvalue %[[arg0]], %[[p1]][0] : !llvm.struct<(i18, i18, i1)>
+//       CHECK:   %[[p3:.*]] = llvm.insertvalue %[[arg1]], %[[p2]][1] : !llvm.struct<(i18, i18, i1)>
+//       CHECK:   %[[p4:.*]] = llvm.insertvalue %[[arg2]], %[[p3]][2] : !llvm.struct<(i18, i18, i1)>
+//       CHECK:   llvm.return %[[p4]]
+func.func @multi_return(%arg0: i17, %arg1: i1) -> (i17, i1) {
+  return %arg0, %arg1 : i17, i1
+}
+
+// CHECK: llvm.func @caller(%[[arg0:.*]]: i1, %[[arg1:.*]]: i18, %[[arg2:.*]]: i18)
+// CHECK:   %[[res:.*]] = llvm.call @multi_return(%[[arg1]], %[[arg2]], %[[arg0]]) : (i18, i18, i1) -> !llvm.struct<(i18, i18, i1)>
+// CHECK:   %[[e0:.*]] = llvm.extractvalue %[[res]][0] : !llvm.struct<(i18, i18, i1)>
+// CHECK:   %[[e1:.*]] = llvm.extractvalue %[[res]][1] : !llvm.struct<(i18, i18, i1)>
+// CHECK:   %[[e2:.*]] = llvm.extractvalue %[[res]][2] : !llvm.struct<(i18, i18, i1)>
+// CHECK:   %[[i0:.*]] = llvm.mlir.poison : !llvm.struct<(i18, i18, i1, i18, i18)>
+// CHECK:   %[[i1:.*]] = llvm.insertvalue %[[e0]], %[[i0]][0]
+// CHECK:   %[[i2:.*]] = llvm.insertvalue %[[e1]], %[[i1]][1]
+// CHECK:   %[[i3:.*]] = llvm.insertvalue %[[e2]], %[[i2]][2]
+// CHECK:   %[[i4:.*]] = llvm.insertvalue %[[e0]], %[[i3]][3]
+// CHECK:   %[[i5:.*]] = llvm.insertvalue %[[e1]], %[[i4]][4]
+// CHECK:   llvm.return %[[i5]]
+func.func @caller(%arg0: i1, %arg1: i17) -> (i17, i1, i17) {
+  %res:2 = func.call @multi_return(%arg1, %arg0) : (i17, i1) -> (i17, i1)
+  return %res#0, %res#1, %res#0 : i17, i1, i17
+}
diff --git a/mlir/test/lib/Dialect/LLVM/TestPatterns.cpp b/mlir/test/lib/Dialect/LLVM/TestPatterns.cpp
index ab02866970b1d..fe9aa0f2a9902 100644
--- a/mlir/test/lib/Dialect/LLVM/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/LLVM/TestPatterns.cpp
@@ -6,7 +6,9 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
 #include "mlir/Pass/Pass.h"
@@ -34,6 +36,10 @@ struct TestLLVMLegalizePatternsPass
     : public PassWrapper<TestLLVMLegalizePatternsPass, OperationPass<>> {
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLLVMLegalizePatternsPass)
 
+  TestLLVMLegalizePatternsPass() = default;
+  TestLLVMLegalizePatternsPass(const TestLLVMLegalizePatternsPass &other)
+      : PassWrapper(other) {}
+
   StringRef getArgument() const final { return "test-llvm-legalize-patterns"; }
   StringRef getDescription() const final {
     return "Run LLVM dialect legalization patterns";
@@ -45,22 +51,46 @@ struct TestLLVMLegalizePatternsPass
 
   void runOnOperation() override {
     MLIRContext *ctx = &getContext();
+
+    // Set up type converter.
     LLVMTypeConverter converter(ctx);
+    converter.addConversion(
+        [&](IntegerType type, SmallVectorImpl<Type> &result) {
+          if (type.isInteger(17)) {
+            // Convert i17 -> (i18, i18).
+            result.append(2, Builder(ctx).getIntegerType(18));
+            return success();
+          }
+
+          result.push_back(type);
+          return success();
+        });
+
+    // Populate patterns.
     mlir::RewritePatternSet patterns(ctx);
     patterns.add<TestDirectReplacementOp>(ctx, converter);
+    populateFuncToLLVMConversionPatterns(converter, patterns);
 
     // Define the conversion target used for the test.
     ConversionTarget target(*ctx);
     target.addLegalOp(OperationName("test.legal_op", ctx));
+    target.addLegalDialect<LLVM::LLVMDialect>();
+    target.addDynamicallyLegalOp<func::FuncOp>(
+        [&](func::FuncOp funcOp) { return funcOp->hasAttr("is_legal"); });
 
     // Handle a partial conversion.
     DenseSet<Operation *> unlegalizedOps;
     ConversionConfig config;
     config.unlegalizedOps = &unlegalizedOps;
+    config.allowPatternRollback = allowPatternRollback;
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns), config)))
       getOperation()->emitError() << "applyPartialConversion failed";
   }
+
+  Option<bool> allowPatternRollback{*this, "allow-pattern-rollback",
+                                    llvm::cl::desc("Allow pattern rollback"),
+                                    llvm::cl::init(true)};
 };
 } // namespace
 



More information about the Mlir-commits mailing list