[Mlir-commits] [mlir] [mlir][LLVM] FuncToLLVM: Add 1:N support (PR #153823)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Aug 15 08:43:49 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-gpu

Author: Matthias Springer (matthias-springer)

<details>
<summary>Changes</summary>

Add support for 1:N type conversions to the `FuncToLLVM` lowering patterns. This commit does not change the lowering of any types (such as `MemRefType`). It only sets up the infrastructure so that 1:N type conversions can be used during `FuncToLLVM`.


---

Patch is 33.90 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/153823.diff


7 Files Affected:

- (modified) mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h (+12-12) 
- (modified) mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp (+66-37) 
- (modified) mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp (+9-3) 
- (modified) mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp (+47-53) 
- (modified) mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp (+13-15) 
- (modified) mlir/test/Conversion/MemRefToLLVM/type-conversion.mlir (+89-7) 
- (modified) mlir/test/lib/Dialect/LLVM/TestPatterns.cpp (+21) 


``````````diff
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
index 38b5e492a8ed8..a38b3283416e0 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
@@ -74,8 +74,13 @@ class LLVMTypeConverter : public TypeConverter {
   /// LLVM-compatible type. In particular, if more than one value is returned,
   /// create an LLVM dialect structure type with elements that correspond to
   /// each of the types converted with `convertCallingConventionType`.
-  Type packFunctionResults(TypeRange types,
-                           bool useBarePointerCallConv = false) const;
+  ///
+  /// Populate the converted (unpacked) types into `groupedTypes`, if provided.
+  /// `groupedTypes` contains one nested vector per input type. In case of a 1:N
+  /// conversion, a nested vector may contain 0 or more than 1 converted type.
+  Type packFunctionResults(
+      TypeRange types, bool useBarePointerCallConv = false,
+      SmallVector<SmallVector<Type>> *groupedTypes = nullptr) const;
 
   /// Convert a non-empty list of types of values produced by an operation into
   /// an LLVM-compatible type. In particular, if more than one value is
@@ -88,15 +93,9 @@ class LLVMTypeConverter : public TypeConverter {
   /// UnrankedMemRefType, are converted following the specific rules for the
   /// calling convention. Calling convention independent types are converted
   /// following the default LLVM type conversions.
-  Type convertCallingConventionType(Type type,
-                                    bool useBarePointerCallConv = false) const;
-
-  /// Promote the bare pointers in 'values' that resulted from memrefs to
-  /// descriptors. 'stdTypes' holds the types of 'values' before the conversion
-  /// to the LLVM-IR dialect (i.e., MemRefType, or any other builtin type).
-  void promoteBarePtrsToDescriptors(ConversionPatternRewriter &rewriter,
-                                    Location loc, ArrayRef<Type> stdTypes,
-                                    SmallVectorImpl<Value> &values) const;
+  LogicalResult
+  convertCallingConventionType(Type type, SmallVectorImpl<Type> &result,
+                               bool useBarePointerCallConv = false) const;
 
   /// Returns the MLIR context.
   MLIRContext &getContext() const;
@@ -111,7 +110,8 @@ class LLVMTypeConverter : public TypeConverter {
   /// of the platform-specific C/C++ ABI lowering related to struct argument
   /// passing.
   SmallVector<Value, 4> promoteOperands(Location loc, ValueRange opOperands,
-                                        ValueRange operands, OpBuilder &builder,
+                                        ArrayRef<ValueRange> operands,
+                                        OpBuilder &builder,
                                         bool useBarePtrCallConv = false) const;
 
   /// Promote the LLVM struct representation of one MemRef descriptor to stack
diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
index a4a6ae250640f..9ff96c88d9d6f 100644
--- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
+++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
@@ -527,19 +527,19 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
   using ConvertOpToLLVMPattern<CallOpType>::ConvertOpToLLVMPattern;
   using Super = CallOpInterfaceLowering<CallOpType>;
   using Base = ConvertOpToLLVMPattern<CallOpType>;
+  using Adaptor = typename ConvertOpToLLVMPattern<CallOpType>::OneToNOpAdaptor;
 
-  LogicalResult matchAndRewriteImpl(CallOpType callOp,
-                                    typename CallOpType::Adaptor adaptor,
+  LogicalResult matchAndRewriteImpl(CallOpType callOp, Adaptor adaptor,
                                     ConversionPatternRewriter &rewriter,
                                     bool useBarePtrCallConv = false) const {
     // Pack the result types into a struct.
     Type packedResult = nullptr;
+    SmallVector<SmallVector<Type>> groupedResultTypes;
     unsigned numResults = callOp.getNumResults();
     auto resultTypes = llvm::to_vector<4>(callOp.getResultTypes());
-
     if (numResults != 0) {
       if (!(packedResult = this->getTypeConverter()->packFunctionResults(
-                resultTypes, useBarePtrCallConv)))
+                resultTypes, useBarePtrCallConv, &groupedResultTypes)))
         return failure();
     }
 
@@ -565,34 +565,60 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
         static_cast<int32_t>(promoted.size()), 0};
     newOp.getProperties().op_bundle_sizes = rewriter.getDenseI32ArrayAttr({});
 
-    SmallVector<Value, 4> results;
-    if (numResults < 2) {
-      // If < 2 results, packing did not do anything and we can just return.
-      results.append(newOp.result_begin(), newOp.result_end());
-    } else {
-      // Otherwise, it had been converted to an operation producing a structure.
-      // Extract individual results from the structure and return them as list.
-      results.reserve(numResults);
-      for (unsigned i = 0; i < numResults; ++i) {
-        results.push_back(LLVM::ExtractValueOp::create(
-            rewriter, callOp.getLoc(), newOp->getResult(0), i));
+    // Helper function that extracts an individual result from the return value
+    // of the new call op. llvm.call ops support only 0 or 1 result. In case of
+    // 2 or more results, the results are packed into a structure.
+    auto getUnpackedResult = [&](unsigned i) -> Value {
+      assert(packedResult && "convert op has no results");
+      if (!isa<LLVM::LLVMStructType>(packedResult)) {
+        assert(i == 0 && "out of bounds: converted op has only one result");
+        return newOp->getResult(0);
+      }
+      // Results have been converted to a structure. Extract individual results
+      // from the structure.
+      return LLVM::ExtractValueOp::create(rewriter, callOp.getLoc(),
+                                          newOp->getResult(0), i);
+    };
+
+    // Group the results into a vector of vectors, such that it is clear which
+    // original op result is replaced with which range of values. (In case of a
+    // 1:N conversion, there can be multiple replacements for a single result.)
+    SmallVector<SmallVector<Value>> results;
+    results.reserve(numResults);
+    unsigned counter = 0;
+    for (unsigned i = 0; i < numResults; ++i) {
+      SmallVector<Value> &group = results.emplace_back();
+      for (unsigned j = 0, e = groupedResultTypes[i].size(); j < e; ++j) {
+        group.push_back(getUnpackedResult(counter++));
       }
     }
 
-    if (useBarePtrCallConv) {
-      // For the bare-ptr calling convention, promote memref results to
-      // descriptors.
-      assert(results.size() == resultTypes.size() &&
-             "The number of arguments and types doesn't match");
-      this->getTypeConverter()->promoteBarePtrsToDescriptors(
-          rewriter, callOp.getLoc(), resultTypes, results);
-    } else if (failed(this->copyUnrankedDescriptors(rewriter, callOp.getLoc(),
-                                                    resultTypes, results,
-                                                    /*toDynamic=*/false))) {
-      return failure();
+    for (unsigned i = 0; i < numResults; ++i) {
+      Type origType = resultTypes[i];
+      auto memrefType = dyn_cast<MemRefType>(origType);
+      auto unrankedMemrefType = dyn_cast<UnrankedMemRefType>(origType);
+      if (useBarePtrCallConv && memrefType) {
+        // For the bare-ptr calling convention, promote memref results to
+        // descriptors.
+        assert(results[i].size() == 1 && "expected one converted result");
+        results[i].front() = MemRefDescriptor::fromStaticShape(
+            rewriter, callOp.getLoc(), *this->getTypeConverter(), memrefType,
+            results[i].front());
+      }
+      if (unrankedMemrefType) {
+        assert(!useBarePtrCallConv && "unranked memref is not supported in the "
+                                      "bare-ptr calling convention");
+        assert(results[i].size() == 1 && "expected one converted result");
+        Value desc = this->copyUnrankedDescriptor(
+            rewriter, callOp.getLoc(), unrankedMemrefType, results[i].front(),
+            /*toDynamic=*/false);
+        if (!desc)
+          return failure();
+        results[i].front() = desc;
+      }
     }
 
-    rewriter.replaceOp(callOp, results);
+    rewriter.replaceOpWithMultiple(callOp, results);
     return success();
   }
 };
@@ -606,7 +632,7 @@ class CallOpLowering : public CallOpInterfaceLowering<func::CallOp> {
         symbolTables(symbolTables) {}
 
   LogicalResult
-  matchAndRewrite(func::CallOp callOp, OpAdaptor adaptor,
+  matchAndRewrite(func::CallOp callOp, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     bool useBarePtrCallConv = false;
     if (getTypeConverter()->getOptions().useBarePtrCallConv) {
@@ -636,7 +662,7 @@ struct CallIndirectOpLowering
   using Super::Super;
 
   LogicalResult
-  matchAndRewrite(func::CallIndirectOp callIndirectOp, OpAdaptor adaptor,
+  matchAndRewrite(func::CallIndirectOp callIndirectOp, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     return matchAndRewriteImpl(callIndirectOp, adaptor, rewriter);
   }
@@ -679,47 +705,50 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern<func::ReturnOp> {
   using ConvertOpToLLVMPattern<func::ReturnOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
-  matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
+  matchAndRewrite(func::ReturnOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
-    unsigned numArguments = op.getNumOperands();
     SmallVector<Value, 4> updatedOperands;
 
     auto funcOp = op->getParentOfType<LLVM::LLVMFuncOp>();
     bool useBarePtrCallConv =
         shouldUseBarePtrCallConv(funcOp, this->getTypeConverter());
 
-    for (auto [oldOperand, newOperand] :
+    for (auto [oldOperand, newOperands] :
          llvm::zip_equal(op->getOperands(), adaptor.getOperands())) {
       Type oldTy = oldOperand.getType();
       if (auto memRefType = dyn_cast<MemRefType>(oldTy)) {
+        assert(newOperands.size() == 1 && "expected one converted result");
         if (useBarePtrCallConv &&
             getTypeConverter()->canConvertToBarePtr(memRefType)) {
           // For the bare-ptr calling convention, extract the aligned pointer to
           // be returned from the memref descriptor.
-          MemRefDescriptor memrefDesc(newOperand);
+          MemRefDescriptor memrefDesc(newOperands.front());
           updatedOperands.push_back(memrefDesc.allocatedPtr(rewriter, loc));
           continue;
         }
       } else if (auto unrankedMemRefType =
                      dyn_cast<UnrankedMemRefType>(oldTy)) {
+        assert(newOperands.size() == 1 && "expected one converted result");
         if (useBarePtrCallConv) {
           // Unranked memref is not supported in the bare pointer calling
           // convention.
           return failure();
         }
-        Value updatedDesc = copyUnrankedDescriptor(
-            rewriter, loc, unrankedMemRefType, newOperand, /*toDynamic=*/true);
+        Value updatedDesc =
+            copyUnrankedDescriptor(rewriter, loc, unrankedMemRefType,
+                                   newOperands.front(), /*toDynamic=*/true);
         if (!updatedDesc)
           return failure();
         updatedOperands.push_back(updatedDesc);
         continue;
       }
-      updatedOperands.push_back(newOperand);
+
+      llvm::append_range(updatedOperands, newOperands);
     }
 
     // If ReturnOp has 0 or 1 operand, create it and return immediately.
-    if (numArguments <= 1) {
+    if (updatedOperands.size() <= 1) {
       rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(
           op, TypeRange(), updatedOperands, op->getAttrs());
       return success();
diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
index 3cfbd898e49e2..a3ec644dcc068 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -719,8 +719,10 @@ LogicalResult ConvertHostRegisterOpToGpuRuntimeCallPattern::matchAndRewrite(
   auto elementType = cast<UnrankedMemRefType>(memRefType).getElementType();
   auto elementSize = getSizeInBytes(loc, elementType, rewriter);
 
+  SmallVector<ValueRange> adaptorOperands = llvm::map_to_vector(
+      adaptor.getOperands(), [](Value v) { return ValueRange(v); });
   auto arguments = getTypeConverter()->promoteOperands(
-      loc, op->getOperands(), adaptor.getOperands(), rewriter);
+      loc, op->getOperands(), adaptorOperands, rewriter);
   arguments.push_back(elementSize);
   hostRegisterCallBuilder.create(loc, rewriter, arguments);
 
@@ -741,8 +743,10 @@ LogicalResult ConvertHostUnregisterOpToGpuRuntimeCallPattern::matchAndRewrite(
   auto elementType = cast<UnrankedMemRefType>(memRefType).getElementType();
   auto elementSize = getSizeInBytes(loc, elementType, rewriter);
 
+  SmallVector<ValueRange> adaptorOperands = llvm::map_to_vector(
+      adaptor.getOperands(), [](Value v) { return ValueRange(v); });
   auto arguments = getTypeConverter()->promoteOperands(
-      loc, op->getOperands(), adaptor.getOperands(), rewriter);
+      loc, op->getOperands(), adaptorOperands, rewriter);
   arguments.push_back(elementSize);
   hostUnregisterCallBuilder.create(loc, rewriter, arguments);
 
@@ -973,8 +977,10 @@ LogicalResult LegalizeLaunchFuncOpPattern::matchAndRewrite(
   // Note: If `useBarePtrCallConv` is set in the type converter's options,
   // the value of `kernelBarePtrCallConv` will be ignored.
   OperandRange origArguments = launchOp.getKernelOperands();
+  SmallVector<ValueRange> adaptorOperands = llvm::map_to_vector(
+      adaptor.getKernelOperands(), [](Value v) { return ValueRange(v); });
   SmallVector<Value, 8> llvmArguments = getTypeConverter()->promoteOperands(
-      loc, origArguments, adaptor.getKernelOperands(), rewriter,
+      loc, origArguments, adaptorOperands, rewriter,
       /*useBarePtrCallConv=*/kernelBarePtrCallConv);
   SmallVector<Value, 8> llvmArgumentsWithSizes;
 
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index 1a9bf569086da..621900e40f77d 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -365,6 +365,7 @@ Type LLVMTypeConverter::convertFunctionSignatureImpl(
   useBarePtrCallConv = useBarePtrCallConv || options.useBarePtrCallConv;
   auto funcArgConverter = useBarePtrCallConv ? barePtrFuncArgTypeConverter
                                              : structFuncArgTypeConverter;
+
   // Convert argument types one by one and check for errors.
   for (auto [idx, type] : llvm::enumerate(funcTy.getInputs())) {
     SmallVector<Type, 8> converted;
@@ -658,27 +659,19 @@ FailureOr<Type> LLVMTypeConverter::convertVectorType(VectorType type) const {
 /// UnrankedMemRefType, are converted following the specific rules for the
 /// calling convention. Calling convention independent types are converted
 /// following the default LLVM type conversions.
-Type LLVMTypeConverter::convertCallingConventionType(
-    Type type, bool useBarePtrCallConv) const {
-  if (useBarePtrCallConv)
-    if (auto memrefTy = dyn_cast<BaseMemRefType>(type))
-      return convertMemRefToBarePtr(memrefTy);
-
-  return convertType(type);
-}
+LogicalResult LLVMTypeConverter::convertCallingConventionType(
+    Type type, SmallVectorImpl<Type> &result, bool useBarePtrCallConv) const {
+  if (useBarePtrCallConv) {
+    if (auto memrefTy = dyn_cast<BaseMemRefType>(type)) {
+      Type converted = convertMemRefToBarePtr(memrefTy);
+      if (!converted)
+        return failure();
+      result.push_back(converted);
+      return success();
+    }
+  }
 
-/// Promote the bare pointers in 'values' that resulted from memrefs to
-/// descriptors. 'stdTypes' holds they types of 'values' before the conversion
-/// to the LLVM-IR dialect (i.e., MemRefType, or any other builtin type).
-void LLVMTypeConverter::promoteBarePtrsToDescriptors(
-    ConversionPatternRewriter &rewriter, Location loc, ArrayRef<Type> stdTypes,
-    SmallVectorImpl<Value> &values) const {
-  assert(stdTypes.size() == values.size() &&
-         "The number of types and values doesn't match");
-  for (unsigned i = 0, end = values.size(); i < end; ++i)
-    if (auto memrefTy = dyn_cast<MemRefType>(stdTypes[i]))
-      values[i] = MemRefDescriptor::fromStaticShape(rewriter, loc, *this,
-                                                    memrefTy, values[i]);
+  return convertType(type, result);
 }
 
 /// Convert a non-empty list of types of values produced by an operation into an
@@ -706,23 +699,32 @@ Type LLVMTypeConverter::packOperationResults(TypeRange types) const {
 /// LLVM-compatible type. In particular, if more than one value is returned,
 /// create an LLVM dialect structure type with elements that correspond to each
 /// of the types converted with `convertCallingConventionType`.
-Type LLVMTypeConverter::packFunctionResults(TypeRange types,
-                                            bool useBarePtrCallConv) const {
+Type LLVMTypeConverter::packFunctionResults(
+    TypeRange types, bool useBarePtrCallConv,
+    SmallVector<SmallVector<Type>> *groupedTypes) const {
   assert(!types.empty() && "expected non-empty list of type");
+  assert((!groupedTypes || groupedTypes->empty()) &&
+         "expected groupedTypes to be empty");
 
   useBarePtrCallConv |= options.useBarePtrCallConv;
-  if (types.size() == 1)
-    return convertCallingConventionType(types.front(), useBarePtrCallConv);
-
   SmallVector<Type> resultTypes;
   resultTypes.reserve(types.size());
+  size_t sizeBefore = 0;
   for (auto t : types) {
-    auto converted = convertCallingConventionType(t, useBarePtrCallConv);
-    if (!converted || !LLVM::isCompatibleType(converted))
+    if (failed(
+            convertCallingConventionType(t, resultTypes, useBarePtrCallConv)))
       return {};
-    resultTypes.push_back(converted);
+    if (groupedTypes) {
+      SmallVector<Type> &group = groupedTypes->emplace_back();
+      llvm::append_range(group, ArrayRef(resultTypes).drop_front(sizeBefore));
+    }
+    sizeBefore = resultTypes.size();
   }
 
+  if (resultTypes.size() == 1)
+    return resultTypes.front();
+  if (resultTypes.empty())
+    return {};
   return LLVM::LLVMStructType::getLiteral(&getContext(), resultTypes);
 }
 
@@ -740,40 +742,40 @@ Value LLVMTypeConverter::promoteOneMemRefDescriptor(Location loc, Value operand,
   return allocated;
 }
 
-SmallVector<Value, 4>
-LLVMTypeConverter::promoteOperands(Location loc, ValueRange opOperands,
-                                   ValueRange operands, OpBuilder &builder,
-                                   bool useBarePtrCallConv) const {
+SmallVector<Value, 4> LLVMTypeConverter::promoteOperands(
+    Location loc, ValueRange opOperands, ArrayRef<ValueRange> operands,
+    OpBuilder &builder, bool useBarePtrCallConv) const {
   SmallVector<Value, 4> promotedOperands;
   promotedOperands.reserve(operands.size());
   useBarePtrCallConv |= options.useBarePtrCallConv;
-  for (auto it : llvm::zip(opOperands, operands)) {
-    auto operand = std::get<0>(it);
-    auto llvmOperand = std::get<1>(it);
-
+  for (auto [operand, llvmOperand] : llvm::zip_equal(opOperands, operands)) {
     if (useBarePtrCallConv) {
       // For the bare-ptr calling convention, we only have to extract the
       // aligned pointer of a memref.
       if (isa<MemRefType>(operand.getType())) {
-        MemRefDescriptor desc(llvmOperand);
-        llvmOperand = desc.alignedPtr(builder, loc);
+        assert(llvmOperand.size() == 1 && "Expected a single operand");
+        MemRefDescriptor desc(llvmOperand.front());
+        promotedOperands.push_back(desc.alignedPtr(builder, loc));
+        continue;
       } else if (isa<UnrankedMemRefType>(operand.getType())) {
         llvm_unreachable("Unranked memrefs are not supported");
       }
     } else {
       if (isa<UnrankedMemRefType>(operand.getType())) {
-        U...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/153823


More information about the Mlir-commits mailing list