[Mlir-commits] [mlir] [mlir][LLVM] FuncToLLVM: Add 1:N support (PR #153823)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Aug 15 08:43:49 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-gpu
Author: Matthias Springer (matthias-springer)
<details>
<summary>Changes</summary>
Add support for 1:N type conversions to the `FuncToLLVM` lowering patterns. This commit does not change the lowering of any types (such as `MemRefType`). It just sets up the infrastructure, such that 1:N type conversions can be used during `FuncToLLVM`.
---
Patch is 33.90 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/153823.diff
7 Files Affected:
- (modified) mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h (+12-12)
- (modified) mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp (+66-37)
- (modified) mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp (+9-3)
- (modified) mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp (+47-53)
- (modified) mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp (+13-15)
- (modified) mlir/test/Conversion/MemRefToLLVM/type-conversion.mlir (+89-7)
- (modified) mlir/test/lib/Dialect/LLVM/TestPatterns.cpp (+21)
``````````diff
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
index 38b5e492a8ed8..a38b3283416e0 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
@@ -74,8 +74,13 @@ class LLVMTypeConverter : public TypeConverter {
/// LLVM-compatible type. In particular, if more than one value is returned,
/// create an LLVM dialect structure type with elements that correspond to
/// each of the types converted with `convertCallingConventionType`.
- Type packFunctionResults(TypeRange types,
- bool useBarePointerCallConv = false) const;
+ ///
+ /// Populate the converted (unpacked) types into `groupedTypes`, if provided.
+ /// `groupedTypes` contains one nested vector per input type. In case of a 1:N
+ /// conversion, a nested vector may contain 0 or more than 1 converted type.
+ Type packFunctionResults(
+ TypeRange types, bool useBarePointerCallConv = false,
+ SmallVector<SmallVector<Type>> *groupedTypes = nullptr) const;
/// Convert a non-empty list of types of values produced by an operation into
/// an LLVM-compatible type. In particular, if more than one value is
@@ -88,15 +93,9 @@ class LLVMTypeConverter : public TypeConverter {
/// UnrankedMemRefType, are converted following the specific rules for the
/// calling convention. Calling convention independent types are converted
/// following the default LLVM type conversions.
- Type convertCallingConventionType(Type type,
- bool useBarePointerCallConv = false) const;
-
- /// Promote the bare pointers in 'values' that resulted from memrefs to
- /// descriptors. 'stdTypes' holds the types of 'values' before the conversion
- /// to the LLVM-IR dialect (i.e., MemRefType, or any other builtin type).
- void promoteBarePtrsToDescriptors(ConversionPatternRewriter &rewriter,
- Location loc, ArrayRef<Type> stdTypes,
- SmallVectorImpl<Value> &values) const;
+ LogicalResult
+ convertCallingConventionType(Type type, SmallVectorImpl<Type> &result,
+ bool useBarePointerCallConv = false) const;
/// Returns the MLIR context.
MLIRContext &getContext() const;
@@ -111,7 +110,8 @@ class LLVMTypeConverter : public TypeConverter {
/// of the platform-specific C/C++ ABI lowering related to struct argument
/// passing.
SmallVector<Value, 4> promoteOperands(Location loc, ValueRange opOperands,
- ValueRange operands, OpBuilder &builder,
+ ArrayRef<ValueRange> operands,
+ OpBuilder &builder,
bool useBarePtrCallConv = false) const;
/// Promote the LLVM struct representation of one MemRef descriptor to stack
diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
index a4a6ae250640f..9ff96c88d9d6f 100644
--- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
+++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
@@ -527,19 +527,19 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
using ConvertOpToLLVMPattern<CallOpType>::ConvertOpToLLVMPattern;
using Super = CallOpInterfaceLowering<CallOpType>;
using Base = ConvertOpToLLVMPattern<CallOpType>;
+ using Adaptor = typename ConvertOpToLLVMPattern<CallOpType>::OneToNOpAdaptor;
- LogicalResult matchAndRewriteImpl(CallOpType callOp,
- typename CallOpType::Adaptor adaptor,
+ LogicalResult matchAndRewriteImpl(CallOpType callOp, Adaptor adaptor,
ConversionPatternRewriter &rewriter,
bool useBarePtrCallConv = false) const {
// Pack the result types into a struct.
Type packedResult = nullptr;
+ SmallVector<SmallVector<Type>> groupedResultTypes;
unsigned numResults = callOp.getNumResults();
auto resultTypes = llvm::to_vector<4>(callOp.getResultTypes());
-
if (numResults != 0) {
if (!(packedResult = this->getTypeConverter()->packFunctionResults(
- resultTypes, useBarePtrCallConv)))
+ resultTypes, useBarePtrCallConv, &groupedResultTypes)))
return failure();
}
@@ -565,34 +565,60 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
static_cast<int32_t>(promoted.size()), 0};
newOp.getProperties().op_bundle_sizes = rewriter.getDenseI32ArrayAttr({});
- SmallVector<Value, 4> results;
- if (numResults < 2) {
- // If < 2 results, packing did not do anything and we can just return.
- results.append(newOp.result_begin(), newOp.result_end());
- } else {
- // Otherwise, it had been converted to an operation producing a structure.
- // Extract individual results from the structure and return them as list.
- results.reserve(numResults);
- for (unsigned i = 0; i < numResults; ++i) {
- results.push_back(LLVM::ExtractValueOp::create(
- rewriter, callOp.getLoc(), newOp->getResult(0), i));
+ // Helper function that extracts an individual result from the return value
+ // of the new call op. llvm.call ops support only 0 or 1 result. In case of
+ // 2 or more results, the results are packed into a structure.
+ auto getUnpackedResult = [&](unsigned i) -> Value {
+ assert(packedResult && "convert op has no results");
+ if (!isa<LLVM::LLVMStructType>(packedResult)) {
+ assert(i == 0 && "out of bounds: converted op has only one result");
+ return newOp->getResult(0);
+ }
+ // Results have been converted to a structure. Extract individual results
+ // from the structure.
+ return LLVM::ExtractValueOp::create(rewriter, callOp.getLoc(),
+ newOp->getResult(0), i);
+ };
+
+ // Group the results into a vector of vectors, such that it is clear which
+ // original op result is replaced with which range of values. (In case of a
+ // 1:N conversion, there can be multiple replacements for a single result.)
+ SmallVector<SmallVector<Value>> results;
+ results.reserve(numResults);
+ unsigned counter = 0;
+ for (unsigned i = 0; i < numResults; ++i) {
+ SmallVector<Value> &group = results.emplace_back();
+ for (unsigned j = 0, e = groupedResultTypes[i].size(); j < e; ++j) {
+ group.push_back(getUnpackedResult(counter++));
}
}
- if (useBarePtrCallConv) {
- // For the bare-ptr calling convention, promote memref results to
- // descriptors.
- assert(results.size() == resultTypes.size() &&
- "The number of arguments and types doesn't match");
- this->getTypeConverter()->promoteBarePtrsToDescriptors(
- rewriter, callOp.getLoc(), resultTypes, results);
- } else if (failed(this->copyUnrankedDescriptors(rewriter, callOp.getLoc(),
- resultTypes, results,
- /*toDynamic=*/false))) {
- return failure();
+ for (unsigned i = 0; i < numResults; ++i) {
+ Type origType = resultTypes[i];
+ auto memrefType = dyn_cast<MemRefType>(origType);
+ auto unrankedMemrefType = dyn_cast<UnrankedMemRefType>(origType);
+ if (useBarePtrCallConv && memrefType) {
+ // For the bare-ptr calling convention, promote memref results to
+ // descriptors.
+ assert(results[i].size() == 1 && "expected one converted result");
+ results[i].front() = MemRefDescriptor::fromStaticShape(
+ rewriter, callOp.getLoc(), *this->getTypeConverter(), memrefType,
+ results[i].front());
+ }
+ if (unrankedMemrefType) {
+ assert(!useBarePtrCallConv && "unranked memref is not supported in the "
+ "bare-ptr calling convention");
+ assert(results[i].size() == 1 && "expected one converted result");
+ Value desc = this->copyUnrankedDescriptor(
+ rewriter, callOp.getLoc(), unrankedMemrefType, results[i].front(),
+ /*toDynamic=*/false);
+ if (!desc)
+ return failure();
+ results[i].front() = desc;
+ }
}
- rewriter.replaceOp(callOp, results);
+ rewriter.replaceOpWithMultiple(callOp, results);
return success();
}
};
@@ -606,7 +632,7 @@ class CallOpLowering : public CallOpInterfaceLowering<func::CallOp> {
symbolTables(symbolTables) {}
LogicalResult
- matchAndRewrite(func::CallOp callOp, OpAdaptor adaptor,
+ matchAndRewrite(func::CallOp callOp, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
bool useBarePtrCallConv = false;
if (getTypeConverter()->getOptions().useBarePtrCallConv) {
@@ -636,7 +662,7 @@ struct CallIndirectOpLowering
using Super::Super;
LogicalResult
- matchAndRewrite(func::CallIndirectOp callIndirectOp, OpAdaptor adaptor,
+ matchAndRewrite(func::CallIndirectOp callIndirectOp, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
return matchAndRewriteImpl(callIndirectOp, adaptor, rewriter);
}
@@ -679,47 +705,50 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern<func::ReturnOp> {
using ConvertOpToLLVMPattern<func::ReturnOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
+ matchAndRewrite(func::ReturnOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
- unsigned numArguments = op.getNumOperands();
SmallVector<Value, 4> updatedOperands;
auto funcOp = op->getParentOfType<LLVM::LLVMFuncOp>();
bool useBarePtrCallConv =
shouldUseBarePtrCallConv(funcOp, this->getTypeConverter());
- for (auto [oldOperand, newOperand] :
+ for (auto [oldOperand, newOperands] :
llvm::zip_equal(op->getOperands(), adaptor.getOperands())) {
Type oldTy = oldOperand.getType();
if (auto memRefType = dyn_cast<MemRefType>(oldTy)) {
+ assert(newOperands.size() == 1 && "expected one converted result");
if (useBarePtrCallConv &&
getTypeConverter()->canConvertToBarePtr(memRefType)) {
// For the bare-ptr calling convention, extract the aligned pointer to
// be returned from the memref descriptor.
- MemRefDescriptor memrefDesc(newOperand);
+ MemRefDescriptor memrefDesc(newOperands.front());
updatedOperands.push_back(memrefDesc.allocatedPtr(rewriter, loc));
continue;
}
} else if (auto unrankedMemRefType =
dyn_cast<UnrankedMemRefType>(oldTy)) {
+ assert(newOperands.size() == 1 && "expected one converted result");
if (useBarePtrCallConv) {
// Unranked memref is not supported in the bare pointer calling
// convention.
return failure();
}
- Value updatedDesc = copyUnrankedDescriptor(
- rewriter, loc, unrankedMemRefType, newOperand, /*toDynamic=*/true);
+ Value updatedDesc =
+ copyUnrankedDescriptor(rewriter, loc, unrankedMemRefType,
+ newOperands.front(), /*toDynamic=*/true);
if (!updatedDesc)
return failure();
updatedOperands.push_back(updatedDesc);
continue;
}
- updatedOperands.push_back(newOperand);
+
+ llvm::append_range(updatedOperands, newOperands);
}
// If ReturnOp has 0 or 1 operand, create it and return immediately.
- if (numArguments <= 1) {
+ if (updatedOperands.size() <= 1) {
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(
op, TypeRange(), updatedOperands, op->getAttrs());
return success();
diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
index 3cfbd898e49e2..a3ec644dcc068 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -719,8 +719,10 @@ LogicalResult ConvertHostRegisterOpToGpuRuntimeCallPattern::matchAndRewrite(
auto elementType = cast<UnrankedMemRefType>(memRefType).getElementType();
auto elementSize = getSizeInBytes(loc, elementType, rewriter);
+ SmallVector<ValueRange> adaptorOperands = llvm::map_to_vector(
+ adaptor.getOperands(), [](Value v) { return ValueRange(v); });
auto arguments = getTypeConverter()->promoteOperands(
- loc, op->getOperands(), adaptor.getOperands(), rewriter);
+ loc, op->getOperands(), adaptorOperands, rewriter);
arguments.push_back(elementSize);
hostRegisterCallBuilder.create(loc, rewriter, arguments);
@@ -741,8 +743,10 @@ LogicalResult ConvertHostUnregisterOpToGpuRuntimeCallPattern::matchAndRewrite(
auto elementType = cast<UnrankedMemRefType>(memRefType).getElementType();
auto elementSize = getSizeInBytes(loc, elementType, rewriter);
+ SmallVector<ValueRange> adaptorOperands = llvm::map_to_vector(
+ adaptor.getOperands(), [](Value v) { return ValueRange(v); });
auto arguments = getTypeConverter()->promoteOperands(
- loc, op->getOperands(), adaptor.getOperands(), rewriter);
+ loc, op->getOperands(), adaptorOperands, rewriter);
arguments.push_back(elementSize);
hostUnregisterCallBuilder.create(loc, rewriter, arguments);
@@ -973,8 +977,10 @@ LogicalResult LegalizeLaunchFuncOpPattern::matchAndRewrite(
// Note: If `useBarePtrCallConv` is set in the type converter's options,
// the value of `kernelBarePtrCallConv` will be ignored.
OperandRange origArguments = launchOp.getKernelOperands();
+ SmallVector<ValueRange> adaptorOperands = llvm::map_to_vector(
+ adaptor.getKernelOperands(), [](Value v) { return ValueRange(v); });
SmallVector<Value, 8> llvmArguments = getTypeConverter()->promoteOperands(
- loc, origArguments, adaptor.getKernelOperands(), rewriter,
+ loc, origArguments, adaptorOperands, rewriter,
/*useBarePtrCallConv=*/kernelBarePtrCallConv);
SmallVector<Value, 8> llvmArgumentsWithSizes;
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index 1a9bf569086da..621900e40f77d 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -365,6 +365,7 @@ Type LLVMTypeConverter::convertFunctionSignatureImpl(
useBarePtrCallConv = useBarePtrCallConv || options.useBarePtrCallConv;
auto funcArgConverter = useBarePtrCallConv ? barePtrFuncArgTypeConverter
: structFuncArgTypeConverter;
+
// Convert argument types one by one and check for errors.
for (auto [idx, type] : llvm::enumerate(funcTy.getInputs())) {
SmallVector<Type, 8> converted;
@@ -658,27 +659,19 @@ FailureOr<Type> LLVMTypeConverter::convertVectorType(VectorType type) const {
/// UnrankedMemRefType, are converted following the specific rules for the
/// calling convention. Calling convention independent types are converted
/// following the default LLVM type conversions.
-Type LLVMTypeConverter::convertCallingConventionType(
- Type type, bool useBarePtrCallConv) const {
- if (useBarePtrCallConv)
- if (auto memrefTy = dyn_cast<BaseMemRefType>(type))
- return convertMemRefToBarePtr(memrefTy);
-
- return convertType(type);
-}
+LogicalResult LLVMTypeConverter::convertCallingConventionType(
+ Type type, SmallVectorImpl<Type> &result, bool useBarePtrCallConv) const {
+ if (useBarePtrCallConv) {
+ if (auto memrefTy = dyn_cast<BaseMemRefType>(type)) {
+ Type converted = convertMemRefToBarePtr(memrefTy);
+ if (!converted)
+ return failure();
+ result.push_back(converted);
+ return success();
+ }
+ }
-/// Promote the bare pointers in 'values' that resulted from memrefs to
-/// descriptors. 'stdTypes' holds they types of 'values' before the conversion
-/// to the LLVM-IR dialect (i.e., MemRefType, or any other builtin type).
-void LLVMTypeConverter::promoteBarePtrsToDescriptors(
- ConversionPatternRewriter &rewriter, Location loc, ArrayRef<Type> stdTypes,
- SmallVectorImpl<Value> &values) const {
- assert(stdTypes.size() == values.size() &&
- "The number of types and values doesn't match");
- for (unsigned i = 0, end = values.size(); i < end; ++i)
- if (auto memrefTy = dyn_cast<MemRefType>(stdTypes[i]))
- values[i] = MemRefDescriptor::fromStaticShape(rewriter, loc, *this,
- memrefTy, values[i]);
+ return convertType(type, result);
}
/// Convert a non-empty list of types of values produced by an operation into an
@@ -706,23 +699,32 @@ Type LLVMTypeConverter::packOperationResults(TypeRange types) const {
/// LLVM-compatible type. In particular, if more than one value is returned,
/// create an LLVM dialect structure type with elements that correspond to each
/// of the types converted with `convertCallingConventionType`.
-Type LLVMTypeConverter::packFunctionResults(TypeRange types,
- bool useBarePtrCallConv) const {
+Type LLVMTypeConverter::packFunctionResults(
+ TypeRange types, bool useBarePtrCallConv,
+ SmallVector<SmallVector<Type>> *groupedTypes) const {
assert(!types.empty() && "expected non-empty list of type");
+ assert((!groupedTypes || groupedTypes->empty()) &&
+ "expected groupedTypes to be empty");
useBarePtrCallConv |= options.useBarePtrCallConv;
- if (types.size() == 1)
- return convertCallingConventionType(types.front(), useBarePtrCallConv);
-
SmallVector<Type> resultTypes;
resultTypes.reserve(types.size());
+ size_t sizeBefore = 0;
for (auto t : types) {
- auto converted = convertCallingConventionType(t, useBarePtrCallConv);
- if (!converted || !LLVM::isCompatibleType(converted))
+ if (failed(
+ convertCallingConventionType(t, resultTypes, useBarePtrCallConv)))
return {};
- resultTypes.push_back(converted);
+ if (groupedTypes) {
+ SmallVector<Type> &group = groupedTypes->emplace_back();
+ llvm::append_range(group, ArrayRef(resultTypes).drop_front(sizeBefore));
+ }
+ sizeBefore = resultTypes.size();
}
+ if (resultTypes.size() == 1)
+ return resultTypes.front();
+ if (resultTypes.empty())
+ return {};
return LLVM::LLVMStructType::getLiteral(&getContext(), resultTypes);
}
@@ -740,40 +742,40 @@ Value LLVMTypeConverter::promoteOneMemRefDescriptor(Location loc, Value operand,
return allocated;
}
-SmallVector<Value, 4>
-LLVMTypeConverter::promoteOperands(Location loc, ValueRange opOperands,
- ValueRange operands, OpBuilder &builder,
- bool useBarePtrCallConv) const {
+SmallVector<Value, 4> LLVMTypeConverter::promoteOperands(
+ Location loc, ValueRange opOperands, ArrayRef<ValueRange> operands,
+ OpBuilder &builder, bool useBarePtrCallConv) const {
SmallVector<Value, 4> promotedOperands;
promotedOperands.reserve(operands.size());
useBarePtrCallConv |= options.useBarePtrCallConv;
- for (auto it : llvm::zip(opOperands, operands)) {
- auto operand = std::get<0>(it);
- auto llvmOperand = std::get<1>(it);
-
+ for (auto [operand, llvmOperand] : llvm::zip_equal(opOperands, operands)) {
if (useBarePtrCallConv) {
// For the bare-ptr calling convention, we only have to extract the
// aligned pointer of a memref.
if (isa<MemRefType>(operand.getType())) {
- MemRefDescriptor desc(llvmOperand);
- llvmOperand = desc.alignedPtr(builder, loc);
+ assert(llvmOperand.size() == 1 && "Expected a single operand");
+ MemRefDescriptor desc(llvmOperand.front());
+ promotedOperands.push_back(desc.alignedPtr(builder, loc));
+ continue;
} else if (isa<UnrankedMemRefType>(operand.getType())) {
llvm_unreachable("Unranked memrefs are not supported");
}
} else {
if (isa<UnrankedMemRefType>(operand.getType())) {
- U...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/153823
More information about the Mlir-commits
mailing list