[Mlir-commits] [mlir] [mlir][LLVM] `FuncToLLVM`: Add 1:N type conversion support (PR #153823)
Matthias Springer
llvmlistbot at llvm.org
Fri Aug 15 08:53:40 PDT 2025
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/153823
From b79aea37e442d5a8c24241f1cbeacaaafea83e12 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Thu, 14 Aug 2025 13:45:48 +0000
Subject: [PATCH] [mlir][LLVM] FuncToLLVM: Add 1:N support
---
.../Conversion/LLVMCommon/TypeConverter.h | 24 ++--
mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp | 104 +++++++++++-------
.../GPUCommon/GPUToLLVMConversion.cpp | 12 +-
.../Conversion/LLVMCommon/TypeConverter.cpp | 100 ++++++++---------
.../Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp | 28 +++--
.../MemRefToLLVM/type-conversion.mlir | 97 ++++++++++++++--
mlir/test/lib/Dialect/LLVM/TestPatterns.cpp | 30 +++++
7 files changed, 268 insertions(+), 127 deletions(-)
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
index 38b5e492a8ed8..a38b3283416e0 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
@@ -74,8 +74,13 @@ class LLVMTypeConverter : public TypeConverter {
/// LLVM-compatible type. In particular, if more than one value is returned,
/// create an LLVM dialect structure type with elements that correspond to
/// each of the types converted with `convertCallingConventionType`.
- Type packFunctionResults(TypeRange types,
- bool useBarePointerCallConv = false) const;
+ ///
+ /// Populate the converted (unpacked) types into `groupedTypes`, if provided.
+ /// `groupedTypes` contains one nested vector per input type. In case of a 1:N
+ /// conversion, a nested vector may contain zero or more than one converted type.
+ Type packFunctionResults(
+ TypeRange types, bool useBarePointerCallConv = false,
+ SmallVector<SmallVector<Type>> *groupedTypes = nullptr) const;
/// Convert a non-empty list of types of values produced by an operation into
/// an LLVM-compatible type. In particular, if more than one value is
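
For illustration, a minimal sketch of the new grouping contract, assuming the i17 -> (i18, i18) conversion registered in the test pass further down (the converter and type variables are illustrative, not part of the patch):

  SmallVector<SmallVector<Type>> groups;
  Type packed = typeConverter.packFunctionResults(
      TypeRange{i17Type, f32Type}, /*useBarePointerCallConv=*/false, &groups);
  // packed == !llvm.struct<(i18, i18, f32)>
  // groups == { {i18, i18},   // 1:2 conversion of the i17 type
  //             {f32} }       // 1:1 conversion of the f32 type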
@@ -88,15 +93,9 @@ class LLVMTypeConverter : public TypeConverter {
/// UnrankedMemRefType, are converted following the specific rules for the
/// calling convention. Calling convention independent types are converted
/// following the default LLVM type conversions.
- Type convertCallingConventionType(Type type,
- bool useBarePointerCallConv = false) const;
-
- /// Promote the bare pointers in 'values' that resulted from memrefs to
- /// descriptors. 'stdTypes' holds the types of 'values' before the conversion
- /// to the LLVM-IR dialect (i.e., MemRefType, or any other builtin type).
- void promoteBarePtrsToDescriptors(ConversionPatternRewriter &rewriter,
- Location loc, ArrayRef<Type> stdTypes,
- SmallVectorImpl<Value> &values) const;
+ LogicalResult
+ convertCallingConventionType(Type type, SmallVectorImpl<Type> &result,
+ bool useBarePointerCallConv = false) const;
/// Returns the MLIR context.
MLIRContext &getContext() const;
@@ -111,7 +110,8 @@ class LLVMTypeConverter : public TypeConverter {
/// of the platform-specific C/C++ ABI lowering related to struct argument
/// passing.
SmallVector<Value, 4> promoteOperands(Location loc, ValueRange opOperands,
- ValueRange operands, OpBuilder &builder,
+ ArrayRef<ValueRange> operands,
+ OpBuilder &builder,
bool useBarePtrCallConv = false) const;
/// Promote the LLVM struct representation of one MemRef descriptor to stack
diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
index a4a6ae250640f..95981138a7253 100644
--- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
+++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
@@ -527,19 +527,19 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
using ConvertOpToLLVMPattern<CallOpType>::ConvertOpToLLVMPattern;
using Super = CallOpInterfaceLowering<CallOpType>;
using Base = ConvertOpToLLVMPattern<CallOpType>;
+ using Adaptor = typename ConvertOpToLLVMPattern<CallOpType>::OneToNOpAdaptor;
- LogicalResult matchAndRewriteImpl(CallOpType callOp,
- typename CallOpType::Adaptor adaptor,
+ LogicalResult matchAndRewriteImpl(CallOpType callOp, Adaptor adaptor,
ConversionPatternRewriter &rewriter,
bool useBarePtrCallConv = false) const {
// Pack the result types into a struct.
Type packedResult = nullptr;
+ SmallVector<SmallVector<Type>> groupedResultTypes;
unsigned numResults = callOp.getNumResults();
auto resultTypes = llvm::to_vector<4>(callOp.getResultTypes());
-
if (numResults != 0) {
if (!(packedResult = this->getTypeConverter()->packFunctionResults(
- resultTypes, useBarePtrCallConv)))
+ resultTypes, useBarePtrCallConv, &groupedResultTypes)))
return failure();
}
@@ -565,34 +565,61 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
static_cast<int32_t>(promoted.size()), 0};
newOp.getProperties().op_bundle_sizes = rewriter.getDenseI32ArrayAttr({});
- SmallVector<Value, 4> results;
- if (numResults < 2) {
- // If < 2 results, packing did not do anything and we can just return.
- results.append(newOp.result_begin(), newOp.result_end());
- } else {
- // Otherwise, it had been converted to an operation producing a structure.
- // Extract individual results from the structure and return them as list.
- results.reserve(numResults);
- for (unsigned i = 0; i < numResults; ++i) {
- results.push_back(LLVM::ExtractValueOp::create(
- rewriter, callOp.getLoc(), newOp->getResult(0), i));
+ // Helper function that extracts an individual result from the return value
+ // of the new call op. llvm.call ops support only 0 or 1 result. In case of
+ // 2 or more results, the results are packed into a structure.
+ auto getUnpackedResult = [&](unsigned i) -> Value {
+ assert(packedResult && "convert op has no results");
+ if (!isa<LLVM::LLVMStructType>(packedResult)) {
+ assert(i == 0 && "out of bounds: converted op has only one result");
+ return newOp->getResult(0);
+ }
+ // Results have been converted to a structure. Extract individual results
+ // from the structure.
+ return LLVM::ExtractValueOp::create(rewriter, callOp.getLoc(),
+ newOp->getResult(0), i);
+ };
+
+ // Group the results into a vector of vectors, such that it is clear which
+ // original op result is replaced with which range of values. (In case of a
+ // 1:N conversion, there can be multiple replacements for a single result.)
+ SmallVector<SmallVector<Value>> results;
+ results.reserve(numResults);
+ unsigned counter = 0;
+ for (unsigned i = 0; i < numResults; ++i) {
+ SmallVector<Value> &group = results.emplace_back();
+ for (unsigned j = 0, e = groupedResultTypes[i].size(); j < e; ++j) {
+ group.push_back(getUnpackedResult(counter++));
}
}
- if (useBarePtrCallConv) {
- // For the bare-ptr calling convention, promote memref results to
- // descriptors.
- assert(results.size() == resultTypes.size() &&
- "The number of arguments and types doesn't match");
- this->getTypeConverter()->promoteBarePtrsToDescriptors(
- rewriter, callOp.getLoc(), resultTypes, results);
- } else if (failed(this->copyUnrankedDescriptors(rewriter, callOp.getLoc(),
- resultTypes, results,
- /*toDynamic=*/false))) {
- return failure();
+ // Special handling for MemRef types.
+ for (unsigned i = 0; i < numResults; ++i) {
+ Type origType = resultTypes[i];
+ auto memrefType = dyn_cast<MemRefType>(origType);
+ auto unrankedMemrefType = dyn_cast<UnrankedMemRefType>(origType);
+ if (useBarePtrCallConv && memrefType) {
+ // For the bare-ptr calling convention, promote memref results to
+ // descriptors.
+ assert(results[i].size() == 1 && "expected one converted result");
+ results[i].front() = MemRefDescriptor::fromStaticShape(
+ rewriter, callOp.getLoc(), *this->getTypeConverter(), memrefType,
+ results[i].front());
+ }
+ if (unrankedMemrefType) {
+ assert(!useBarePtrCallConv && "unranked memref is not supported in the "
+ "bare-ptr calling convention");
+ assert(results[i].size() == 1 && "expected one converted result");
+ Value desc = this->copyUnrankedDescriptor(
+ rewriter, callOp.getLoc(), unrankedMemrefType, results[i].front(),
+ /*toDynamic=*/false);
+ if (!desc)
+ return failure();
+ results[i].front() = desc;
+ }
}
- rewriter.replaceOp(callOp, results);
+ rewriter.replaceOpWithMultiple(callOp, results);
return success();
}
};
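
As a concrete illustration of the regrouping (not part of the patch, again assuming the i17 -> (i18, i18) test conversion): for a callee returning (i17, f32), the new llvm.call yields !llvm.struct<(i18, i18, f32)>, and the loop above regroups the extracted values per original result:

  // groupedResultTypes == { {i18, i18}, {f32} }
  // results            == { {extractvalue[0], extractvalue[1]},  // old i17 result
  //                         {extractvalue[2]} };                 // old f32 result
  rewriter.replaceOpWithMultiple(callOp, results);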
@@ -606,7 +633,7 @@ class CallOpLowering : public CallOpInterfaceLowering<func::CallOp> {
symbolTables(symbolTables) {}
LogicalResult
- matchAndRewrite(func::CallOp callOp, OpAdaptor adaptor,
+ matchAndRewrite(func::CallOp callOp, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
bool useBarePtrCallConv = false;
if (getTypeConverter()->getOptions().useBarePtrCallConv) {
@@ -636,7 +663,7 @@ struct CallIndirectOpLowering
using Super::Super;
LogicalResult
- matchAndRewrite(func::CallIndirectOp callIndirectOp, OpAdaptor adaptor,
+ matchAndRewrite(func::CallIndirectOp callIndirectOp, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
return matchAndRewriteImpl(callIndirectOp, adaptor, rewriter);
}
@@ -679,47 +706,50 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern<func::ReturnOp> {
using ConvertOpToLLVMPattern<func::ReturnOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
+ matchAndRewrite(func::ReturnOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
- unsigned numArguments = op.getNumOperands();
SmallVector<Value, 4> updatedOperands;
auto funcOp = op->getParentOfType<LLVM::LLVMFuncOp>();
bool useBarePtrCallConv =
shouldUseBarePtrCallConv(funcOp, this->getTypeConverter());
- for (auto [oldOperand, newOperand] :
+ for (auto [oldOperand, newOperands] :
llvm::zip_equal(op->getOperands(), adaptor.getOperands())) {
Type oldTy = oldOperand.getType();
if (auto memRefType = dyn_cast<MemRefType>(oldTy)) {
+ assert(newOperands.size() == 1 && "expected one converted operand");
if (useBarePtrCallConv &&
getTypeConverter()->canConvertToBarePtr(memRefType)) {
// For the bare-ptr calling convention, extract the aligned pointer to
// be returned from the memref descriptor.
- MemRefDescriptor memrefDesc(newOperand);
+ MemRefDescriptor memrefDesc(newOperands.front());
updatedOperands.push_back(memrefDesc.allocatedPtr(rewriter, loc));
continue;
}
} else if (auto unrankedMemRefType =
dyn_cast<UnrankedMemRefType>(oldTy)) {
+ assert(newOperands.size() == 1 && "expected one converted operand");
if (useBarePtrCallConv) {
// Unranked memref is not supported in the bare pointer calling
// convention.
return failure();
}
- Value updatedDesc = copyUnrankedDescriptor(
- rewriter, loc, unrankedMemRefType, newOperand, /*toDynamic=*/true);
+ Value updatedDesc =
+ copyUnrankedDescriptor(rewriter, loc, unrankedMemRefType,
+ newOperands.front(), /*toDynamic=*/true);
if (!updatedDesc)
return failure();
updatedOperands.push_back(updatedDesc);
continue;
}
- updatedOperands.push_back(newOperand);
+
+ llvm::append_range(updatedOperands, newOperands);
}
// If ReturnOp has 0 or 1 operand, create it and return immediately.
- if (numArguments <= 1) {
+ if (updatedOperands.size() <= 1) {
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(
op, TypeRange(), updatedOperands, op->getAttrs());
return success();
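
A short sketch of the new ReturnOp handling under the same i17 -> (i18, i18) test conversion (values and names are illustrative): for `return %x : i17`, the adaptor provides one ValueRange holding two i18 values, both are forwarded, and the packing logic below this hunk wraps them in a struct:

  // adaptor.getOperands() == { [%a : i18, %b : i18] }
  // updatedOperands       == { %a, %b }   -> more than one operand, so pack:
  //   %s0 = llvm.mlir.poison : !llvm.struct<(i18, i18)>
  //   %s1 = llvm.insertvalue %a, %s0[0] : !llvm.struct<(i18, i18)>
  //   %s2 = llvm.insertvalue %b, %s1[1] : !llvm.struct<(i18, i18)>
  //   llvm.return %s2 : !llvm.struct<(i18, i18)>
  // (cf. the one_to_n_return_conversion test below)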
diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
index 3cfbd898e49e2..a3ec644dcc068 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -719,8 +719,10 @@ LogicalResult ConvertHostRegisterOpToGpuRuntimeCallPattern::matchAndRewrite(
auto elementType = cast<UnrankedMemRefType>(memRefType).getElementType();
auto elementSize = getSizeInBytes(loc, elementType, rewriter);
+ SmallVector<ValueRange> adaptorOperands = llvm::map_to_vector(
+ adaptor.getOperands(), [](Value v) { return ValueRange(v); });
auto arguments = getTypeConverter()->promoteOperands(
- loc, op->getOperands(), adaptor.getOperands(), rewriter);
+ loc, op->getOperands(), adaptorOperands, rewriter);
arguments.push_back(elementSize);
hostRegisterCallBuilder.create(loc, rewriter, arguments);
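
These GPU patterns still use 1:1 adaptors, so each converted Value is wrapped into a single-element ValueRange to satisfy the new promoteOperands signature. The idiom, repeated in the hunks below, is roughly:

  // 1:1 adaptor: one replacement value per original operand.
  SmallVector<ValueRange> adaptorOperands = llvm::map_to_vector(
      adaptor.getOperands(), [](Value v) { return ValueRange(v); });
  // A pattern that migrates to a OneToNOpAdaptor should be able to pass
  // adaptor.getOperands() directly, since that already provides one ValueRange
  // per original operand.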
@@ -741,8 +743,10 @@ LogicalResult ConvertHostUnregisterOpToGpuRuntimeCallPattern::matchAndRewrite(
auto elementType = cast<UnrankedMemRefType>(memRefType).getElementType();
auto elementSize = getSizeInBytes(loc, elementType, rewriter);
+ SmallVector<ValueRange> adaptorOperands = llvm::map_to_vector(
+ adaptor.getOperands(), [](Value v) { return ValueRange(v); });
auto arguments = getTypeConverter()->promoteOperands(
- loc, op->getOperands(), adaptor.getOperands(), rewriter);
+ loc, op->getOperands(), adaptorOperands, rewriter);
arguments.push_back(elementSize);
hostUnregisterCallBuilder.create(loc, rewriter, arguments);
@@ -973,8 +977,10 @@ LogicalResult LegalizeLaunchFuncOpPattern::matchAndRewrite(
// Note: If `useBarePtrCallConv` is set in the type converter's options,
// the value of `kernelBarePtrCallConv` will be ignored.
OperandRange origArguments = launchOp.getKernelOperands();
+ SmallVector<ValueRange> adaptorOperands = llvm::map_to_vector(
+ adaptor.getKernelOperands(), [](Value v) { return ValueRange(v); });
SmallVector<Value, 8> llvmArguments = getTypeConverter()->promoteOperands(
- loc, origArguments, adaptor.getKernelOperands(), rewriter,
+ loc, origArguments, adaptorOperands, rewriter,
/*useBarePtrCallConv=*/kernelBarePtrCallConv);
SmallVector<Value, 8> llvmArgumentsWithSizes;
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index 1a9bf569086da..621900e40f77d 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -365,6 +365,7 @@ Type LLVMTypeConverter::convertFunctionSignatureImpl(
useBarePtrCallConv = useBarePtrCallConv || options.useBarePtrCallConv;
auto funcArgConverter = useBarePtrCallConv ? barePtrFuncArgTypeConverter
: structFuncArgTypeConverter;
+
// Convert argument types one by one and check for errors.
for (auto [idx, type] : llvm::enumerate(funcTy.getInputs())) {
SmallVector<Type, 8> converted;
@@ -658,27 +659,19 @@ FailureOr<Type> LLVMTypeConverter::convertVectorType(VectorType type) const {
/// UnrankedMemRefType, are converted following the specific rules for the
/// calling convention. Calling convention independent types are converted
/// following the default LLVM type conversions.
-Type LLVMTypeConverter::convertCallingConventionType(
- Type type, bool useBarePtrCallConv) const {
- if (useBarePtrCallConv)
- if (auto memrefTy = dyn_cast<BaseMemRefType>(type))
- return convertMemRefToBarePtr(memrefTy);
-
- return convertType(type);
-}
+LogicalResult LLVMTypeConverter::convertCallingConventionType(
+ Type type, SmallVectorImpl<Type> &result, bool useBarePtrCallConv) const {
+ if (useBarePtrCallConv) {
+ if (auto memrefTy = dyn_cast<BaseMemRefType>(type)) {
+ Type converted = convertMemRefToBarePtr(memrefTy);
+ if (!converted)
+ return failure();
+ result.push_back(converted);
+ return success();
+ }
+ }
-/// Promote the bare pointers in 'values' that resulted from memrefs to
-/// descriptors. 'stdTypes' holds they types of 'values' before the conversion
-/// to the LLVM-IR dialect (i.e., MemRefType, or any other builtin type).
-void LLVMTypeConverter::promoteBarePtrsToDescriptors(
- ConversionPatternRewriter &rewriter, Location loc, ArrayRef<Type> stdTypes,
- SmallVectorImpl<Value> &values) const {
- assert(stdTypes.size() == values.size() &&
- "The number of types and values doesn't match");
- for (unsigned i = 0, end = values.size(); i < end; ++i)
- if (auto memrefTy = dyn_cast<MemRefType>(stdTypes[i]))
- values[i] = MemRefDescriptor::fromStaticShape(rewriter, loc, *this,
- memrefTy, values[i]);
+ return convertType(type, result);
}
/// Convert a non-empty list of types of values produced by an operation into an
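
A minimal usage sketch of the new interface (illustrative; `memrefTy` is assumed to be a ranked memref type with an identity layout):

  SmallVector<Type> converted;
  if (failed(typeConverter.convertCallingConventionType(
          memrefTy, converted, /*useBarePointerCallConv=*/true)))
    return failure();
  // converted == { !llvm.ptr } under the bare-pointer convention. With the
  // default convention the memref descriptor struct is appended instead, and
  // a 1:N conversion of a non-memref type may append several types.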
@@ -706,23 +699,32 @@ Type LLVMTypeConverter::packOperationResults(TypeRange types) const {
/// LLVM-compatible type. In particular, if more than one value is returned,
/// create an LLVM dialect structure type with elements that correspond to each
/// of the types converted with `convertCallingConventionType`.
-Type LLVMTypeConverter::packFunctionResults(TypeRange types,
- bool useBarePtrCallConv) const {
+Type LLVMTypeConverter::packFunctionResults(
+ TypeRange types, bool useBarePtrCallConv,
+ SmallVector<SmallVector<Type>> *groupedTypes) const {
assert(!types.empty() && "expected non-empty list of type");
+ assert((!groupedTypes || groupedTypes->empty()) &&
+ "expected groupedTypes to be empty");
useBarePtrCallConv |= options.useBarePtrCallConv;
- if (types.size() == 1)
- return convertCallingConventionType(types.front(), useBarePtrCallConv);
-
SmallVector<Type> resultTypes;
resultTypes.reserve(types.size());
+ size_t sizeBefore = 0;
for (auto t : types) {
- auto converted = convertCallingConventionType(t, useBarePtrCallConv);
- if (!converted || !LLVM::isCompatibleType(converted))
+ if (failed(
+ convertCallingConventionType(t, resultTypes, useBarePtrCallConv)))
return {};
- resultTypes.push_back(converted);
+ if (groupedTypes) {
+ SmallVector<Type> &group = groupedTypes->emplace_back();
+ llvm::append_range(group, ArrayRef(resultTypes).drop_front(sizeBefore));
+ }
+ sizeBefore = resultTypes.size();
}
+ if (resultTypes.size() == 1)
+ return resultTypes.front();
+ if (resultTypes.empty())
+ return {};
return LLVM::LLVMStructType::getLiteral(&getContext(), resultTypes);
}
@@ -740,40 +742,40 @@ Value LLVMTypeConverter::promoteOneMemRefDescriptor(Location loc, Value operand,
return allocated;
}
-SmallVector<Value, 4>
-LLVMTypeConverter::promoteOperands(Location loc, ValueRange opOperands,
- ValueRange operands, OpBuilder &builder,
- bool useBarePtrCallConv) const {
+SmallVector<Value, 4> LLVMTypeConverter::promoteOperands(
+ Location loc, ValueRange opOperands, ArrayRef<ValueRange> operands,
+ OpBuilder &builder, bool useBarePtrCallConv) const {
SmallVector<Value, 4> promotedOperands;
promotedOperands.reserve(operands.size());
useBarePtrCallConv |= options.useBarePtrCallConv;
- for (auto it : llvm::zip(opOperands, operands)) {
- auto operand = std::get<0>(it);
- auto llvmOperand = std::get<1>(it);
-
+ for (auto [operand, llvmOperand] : llvm::zip_equal(opOperands, operands)) {
if (useBarePtrCallConv) {
// For the bare-ptr calling convention, we only have to extract the
// aligned pointer of a memref.
if (isa<MemRefType>(operand.getType())) {
- MemRefDescriptor desc(llvmOperand);
- llvmOperand = desc.alignedPtr(builder, loc);
+ assert(llvmOperand.size() == 1 && "Expected a single operand");
+ MemRefDescriptor desc(llvmOperand.front());
+ promotedOperands.push_back(desc.alignedPtr(builder, loc));
+ continue;
} else if (isa<UnrankedMemRefType>(operand.getType())) {
llvm_unreachable("Unranked memrefs are not supported");
}
} else {
if (isa<UnrankedMemRefType>(operand.getType())) {
- UnrankedMemRefDescriptor::unpack(builder, loc, llvmOperand,
+ assert(llvmOperand.size() == 1 && "Expected a single operand");
+ UnrankedMemRefDescriptor::unpack(builder, loc, llvmOperand.front(),
promotedOperands);
continue;
}
if (auto memrefType = dyn_cast<MemRefType>(operand.getType())) {
- MemRefDescriptor::unpack(builder, loc, llvmOperand, memrefType,
+ assert(llvmOperand.size() == 1 && "Expected a single operand");
+ MemRefDescriptor::unpack(builder, loc, llvmOperand.front(), memrefType,
promotedOperands);
continue;
}
}
- promotedOperands.push_back(llvmOperand);
+ llvm::append_range(promotedOperands, llvmOperand);
}
return promotedOperands;
}
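
For illustration of what the updated loop produces (not part of the patch): a ranked memref operand is still exploded into its descriptor fields, while a 1:N-converted operand now contributes all of its replacement values:

  // operand %m : memref<4xf32>  ->  llvmOperand is one descriptor value;
  //   MemRefDescriptor::unpack appends allocated ptr, aligned ptr, offset,
  //   size[0] and stride[0] to promotedOperands.
  // operand %x : i17 (converted 1:2 to i18, i18)  ->  llvmOperand holds two
  //   values; llvm::append_range forwards both of them unchanged.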
@@ -802,11 +804,7 @@ mlir::structFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type,
result.append(converted.begin(), converted.end());
return success();
}
- auto converted = converter.convertType(type);
- if (!converted)
- return failure();
- result.push_back(converted);
- return success();
+ return converter.convertType(type, result);
}
/// Callback to convert function argument types. It converts MemRef function
@@ -814,11 +812,7 @@ mlir::structFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type,
LogicalResult
mlir::barePtrFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type,
SmallVectorImpl<Type> &result) {
- auto llvmTy = converter.convertCallingConventionType(
- type, /*useBarePointerCallConv=*/true);
- if (!llvmTy)
- return failure();
-
- result.push_back(llvmTy);
- return success();
+ return converter.convertCallingConventionType(
+ type, result,
+ /*useBarePointerCallConv=*/true);
}
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index f7f5381799529..46856203672e6 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -1106,12 +1106,10 @@ struct NVGPUGenerateWarpgroupDescriptorLowering
// // [0,14) start_address
dsc = insertBit(dsc, basePtr14bit, startBaseAddrBit);
- LDBG() << "Generating warpgroup.descriptor: "
- << "leading_off:" << leadDimVal << "\t"
- << "stride_off :" << strideDimVal << "\t"
- << "base_offset:" << offsetVal << "\t"
- << "layout_type:" << swizzle << " ("
- << nvgpu::stringifyTensorMapSwizzleKind(swizzleKind)
+ LDBG() << "Generating warpgroup.descriptor: " << "leading_off:"
+ << leadDimVal << "\t" << "stride_off :" << strideDimVal << "\t"
+ << "base_offset:" << offsetVal << "\t" << "layout_type:" << swizzle
+ << " (" << nvgpu::stringifyTensorMapSwizzleKind(swizzleKind)
<< ")\n start_addr : " << baseAddr;
rewriter.replaceOp(op, dsc);
@@ -1181,8 +1179,10 @@ struct NVGPUTmaCreateDescriptorOpLowering
Value tensorElementType =
elementTypeAsLLVMConstant(b, op.getTensor().getType().getElementType());
+ SmallVector<ValueRange> adaptorOperands = llvm::map_to_vector(
+ adaptor.getOperands(), [](Value v) { return ValueRange(v); });
auto promotedOperands = getTypeConverter()->promoteOperands(
- b.getLoc(), op->getOperands(), adaptor.getOperands(), b);
+ b.getLoc(), op->getOperands(), adaptorOperands, b);
Value boxArrayPtr = LLVM::AllocaOp::create(
b, llvmPointerType, llvmInt64Type, makeI64Const(b, 5));
@@ -1401,14 +1401,12 @@ struct NVGPUWarpgroupMmaOpLowering
/// This function generates a WgmmaMmaAsyncOp using provided GMMA matrix
/// descriptors and arranges them based on induction variables: i, j, and k.
Value generateWgmma(int i, int j, int k, Value matrixC) {
- LDBG() << "\t wgmma."
- << "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK << "(A["
- << (iterationM * wgmmaM) << ":" << (iterationM * wgmmaM) + wgmmaM
- << "][" << (iterationK * wgmmaK) << ":"
- << (iterationK * wgmmaK + wgmmaK) << "] * "
- << " B[" << (iterationK * wgmmaK) << ":"
- << (iterationK * wgmmaK + wgmmaK) << "][" << 0 << ":" << wgmmaN
- << "])";
+ LDBG() << "\t wgmma." << "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK
+ << "(A[" << (iterationM * wgmmaM) << ":"
+ << (iterationM * wgmmaM) + wgmmaM << "][" << (iterationK * wgmmaK)
+ << ":" << (iterationK * wgmmaK + wgmmaK) << "] * " << " B["
+ << (iterationK * wgmmaK) << ":" << (iterationK * wgmmaK + wgmmaK)
+ << "][" << 0 << ":" << wgmmaN << "])";
Value descriptorA = iterateDescriptorA(adaptor.getDescriptorA(), i, j, k);
Value descriptorB = iterateDescriptorB(adaptor.getDescriptorB(), i, j, k);
diff --git a/mlir/test/Conversion/MemRefToLLVM/type-conversion.mlir b/mlir/test/Conversion/MemRefToLLVM/type-conversion.mlir
index 0288aa11313c7..c1751f282b002 100644
--- a/mlir/test/Conversion/MemRefToLLVM/type-conversion.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/type-conversion.mlir
@@ -1,12 +1,13 @@
-// RUN: mlir-opt %s -test-llvm-legalize-patterns -split-input-file
+// RUN: mlir-opt %s -test-llvm-legalize-patterns -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-llvm-legalize-patterns="allow-pattern-rollback=0" -split-input-file | FileCheck %s
// Test the argument materializer for ranked MemRef types.
// CHECK-LABEL: func @construct_ranked_memref_descriptor(
-// CHECK: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK-COUNT-7: llvm.insertvalue
// CHECK: builtin.unrealized_conversion_cast %{{.*}} : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> to memref<5x4xf32>
-func.func @construct_ranked_memref_descriptor(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: i64, %arg3: i64, %arg4: i64, %arg5: i64, %arg6: i64) {
+func.func @construct_ranked_memref_descriptor(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: i64, %arg3: i64, %arg4: i64, %arg5: i64, %arg6: i64) attributes {is_legal} {
%0 = "test.direct_replacement"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6) : (!llvm.ptr, !llvm.ptr, i64, i64, i64, i64, i64) -> (memref<5x4xf32>)
"test.legal_op"(%0) : (memref<5x4xf32>) -> ()
return
@@ -21,7 +22,7 @@ func.func @construct_ranked_memref_descriptor(%arg0: !llvm.ptr, %arg1: !llvm.ptr
// CHECK-LABEL: func @invalid_ranked_memref_descriptor(
// CHECK: %[[cast:.*]] = builtin.unrealized_conversion_cast %{{.*}} : i1 to memref<5x4xf32>
// CHECK: "test.legal_op"(%[[cast]])
-func.func @invalid_ranked_memref_descriptor(%arg0: i1) {
+func.func @invalid_ranked_memref_descriptor(%arg0: i1) attributes {is_legal} {
%0 = "test.direct_replacement"(%arg0) : (i1) -> (memref<5x4xf32>)
"test.legal_op"(%0) : (memref<5x4xf32>) -> ()
return
@@ -32,10 +33,10 @@ func.func @invalid_ranked_memref_descriptor(%arg0: i1) {
// Test the argument materializer for unranked MemRef types.
// CHECK-LABEL: func @construct_unranked_memref_descriptor(
-// CHECK: llvm.mlir.undef : !llvm.struct<(i64, ptr)>
+// CHECK: llvm.mlir.poison : !llvm.struct<(i64, ptr)>
// CHECK-COUNT-2: llvm.insertvalue
// CHECK: builtin.unrealized_conversion_cast %{{.*}} : !llvm.struct<(i64, ptr)> to memref<*xf32>
-func.func @construct_unranked_memref_descriptor(%arg0: i64, %arg1: !llvm.ptr) {
+func.func @construct_unranked_memref_descriptor(%arg0: i64, %arg1: !llvm.ptr) attributes {is_legal} {
%0 = "test.direct_replacement"(%arg0, %arg1) : (i64, !llvm.ptr) -> (memref<*xf32>)
"test.legal_op"(%0) : (memref<*xf32>) -> ()
return
@@ -50,8 +51,90 @@ func.func @construct_unranked_memref_descriptor(%arg0: i64, %arg1: !llvm.ptr) {
// CHECK-LABEL: func @invalid_unranked_memref_descriptor(
// CHECK: %[[cast:.*]] = builtin.unrealized_conversion_cast %{{.*}} : i1 to memref<*xf32>
// CHECK: "test.legal_op"(%[[cast]])
-func.func @invalid_unranked_memref_descriptor(%arg0: i1) {
+func.func @invalid_unranked_memref_descriptor(%arg0: i1) attributes {is_legal} {
%0 = "test.direct_replacement"(%arg0) : (i1) -> (memref<*xf32>)
"test.legal_op"(%0) : (memref<*xf32>) -> ()
return
}
+
+// -----
+
+// CHECK-LABEL: llvm.func @simple_func_conversion(
+// CHECK-SAME: %[[arg0:.*]]: i64) -> i64
+// CHECK: llvm.return %[[arg0]] : i64
+func.func @simple_func_conversion(%arg0: i64) -> i64 {
+ return %arg0 : i64
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @one_to_n_argument_conversion(
+// CHECK-SAME: %[[arg0:.*]]: i18, %[[arg1:.*]]: i18)
+// CHECK: %[[cast:.*]] = builtin.unrealized_conversion_cast %[[arg0]], %[[arg1]] : i18, i18 to i17
+// CHECK: "test.legal_op"(%[[cast]]) : (i17) -> ()
+func.func @one_to_n_argument_conversion(%arg0: i17) {
+ "test.legal_op"(%arg0) : (i17) -> ()
+ return
+}
+
+// CHECK: llvm.func @caller(%[[arg0:.*]]: i18, %[[arg1:.*]]: i18)
+// CHECK: llvm.call @one_to_n_argument_conversion(%[[arg0]], %[[arg1]]) : (i18, i18) -> ()
+func.func @caller(%arg0: i17) {
+ func.call @one_to_n_argument_conversion(%arg0) : (i17) -> ()
+ return
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @one_to_n_return_conversion(
+// CHECK-SAME: %[[arg0:.*]]: i18, %[[arg1:.*]]: i18) -> !llvm.struct<(i18, i18)>
+// CHECK: %[[p1:.*]] = llvm.mlir.poison : !llvm.struct<(i18, i18)>
+// CHECK: %[[p2:.*]] = llvm.insertvalue %[[arg0]], %[[p1]][0] : !llvm.struct<(i18, i18)>
+// CHECK: %[[p3:.*]] = llvm.insertvalue %[[arg1]], %[[p2]][1] : !llvm.struct<(i18, i18)>
+// CHECK: llvm.return %[[p3]]
+func.func @one_to_n_return_conversion(%arg0: i17) -> i17 {
+ return %arg0 : i17
+}
+
+// CHECK: llvm.func @caller(%[[arg0:.*]]: i18, %[[arg1:.*]]: i18)
+// CHECK: %[[res:.*]] = llvm.call @one_to_n_return_conversion(%[[arg0]], %[[arg1]]) : (i18, i18) -> !llvm.struct<(i18, i18)>
+// CHECK: %[[e0:.*]] = llvm.extractvalue %[[res]][0] : !llvm.struct<(i18, i18)>
+// CHECK: %[[e1:.*]] = llvm.extractvalue %[[res]][1] : !llvm.struct<(i18, i18)>
+// CHECK: %[[i0:.*]] = llvm.mlir.poison : !llvm.struct<(i18, i18)>
+// CHECK: %[[i1:.*]] = llvm.insertvalue %[[e0]], %[[i0]][0] : !llvm.struct<(i18, i18)>
+// CHECK: %[[i2:.*]] = llvm.insertvalue %[[e1]], %[[i1]][1] : !llvm.struct<(i18, i18)>
+// CHECK: llvm.return %[[i2]]
+func.func @caller(%arg0: i17) -> (i17) {
+ %res = func.call @one_to_n_return_conversion(%arg0) : (i17) -> (i17)
+ return %res : i17
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @multi_return(
+// CHECK-SAME: %[[arg0:.*]]: i18, %[[arg1:.*]]: i18, %[[arg2:.*]]: i1) -> !llvm.struct<(i18, i18, i1)>
+// CHECK: %[[p1:.*]] = llvm.mlir.poison : !llvm.struct<(i18, i18, i1)>
+// CHECK: %[[p2:.*]] = llvm.insertvalue %[[arg0]], %[[p1]][0] : !llvm.struct<(i18, i18, i1)>
+// CHECK: %[[p3:.*]] = llvm.insertvalue %[[arg1]], %[[p2]][1] : !llvm.struct<(i18, i18, i1)>
+// CHECK: %[[p4:.*]] = llvm.insertvalue %[[arg2]], %[[p3]][2] : !llvm.struct<(i18, i18, i1)>
+// CHECK: llvm.return %[[p4]]
+func.func @multi_return(%arg0: i17, %arg1: i1) -> (i17, i1) {
+ return %arg0, %arg1 : i17, i1
+}
+
+// CHECK: llvm.func @caller(%[[arg0:.*]]: i1, %[[arg1:.*]]: i18, %[[arg2:.*]]: i18)
+// CHECK: %[[res:.*]] = llvm.call @multi_return(%[[arg1]], %[[arg2]], %[[arg0]]) : (i18, i18, i1) -> !llvm.struct<(i18, i18, i1)>
+// CHECK: %[[e0:.*]] = llvm.extractvalue %[[res]][0] : !llvm.struct<(i18, i18, i1)>
+// CHECK: %[[e1:.*]] = llvm.extractvalue %[[res]][1] : !llvm.struct<(i18, i18, i1)>
+// CHECK: %[[e2:.*]] = llvm.extractvalue %[[res]][2] : !llvm.struct<(i18, i18, i1)>
+// CHECK: %[[i0:.*]] = llvm.mlir.poison : !llvm.struct<(i18, i18, i1, i18, i18)>
+// CHECK: %[[i1:.*]] = llvm.insertvalue %[[e0]], %[[i0]][0]
+// CHECK: %[[i2:.*]] = llvm.insertvalue %[[e1]], %[[i1]][1]
+// CHECK: %[[i3:.*]] = llvm.insertvalue %[[e2]], %[[i2]][2]
+// CHECK: %[[i4:.*]] = llvm.insertvalue %[[e0]], %[[i3]][3]
+// CHECK: %[[i5:.*]] = llvm.insertvalue %[[e1]], %[[i4]][4]
+// CHECK: llvm.return %[[i5]]
+func.func @caller(%arg0: i1, %arg1: i17) -> (i17, i1, i17) {
+ %res:2 = func.call @multi_return(%arg1, %arg0) : (i17, i1) -> (i17, i1)
+ return %res#0, %res#1, %res#0 : i17, i1, i17
+}
diff --git a/mlir/test/lib/Dialect/LLVM/TestPatterns.cpp b/mlir/test/lib/Dialect/LLVM/TestPatterns.cpp
index ab02866970b1d..fe9aa0f2a9902 100644
--- a/mlir/test/lib/Dialect/LLVM/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/LLVM/TestPatterns.cpp
@@ -6,7 +6,9 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Pass/Pass.h"
@@ -34,6 +36,10 @@ struct TestLLVMLegalizePatternsPass
: public PassWrapper<TestLLVMLegalizePatternsPass, OperationPass<>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLLVMLegalizePatternsPass)
+ TestLLVMLegalizePatternsPass() = default;
+ TestLLVMLegalizePatternsPass(const TestLLVMLegalizePatternsPass &other)
+ : PassWrapper(other) {}
+
StringRef getArgument() const final { return "test-llvm-legalize-patterns"; }
StringRef getDescription() const final {
return "Run LLVM dialect legalization patterns";
@@ -45,22 +51,46 @@ struct TestLLVMLegalizePatternsPass
void runOnOperation() override {
MLIRContext *ctx = &getContext();
+
+ // Set up type converter.
LLVMTypeConverter converter(ctx);
+ converter.addConversion(
+ [&](IntegerType type, SmallVectorImpl<Type> &result) {
+ if (type.isInteger(17)) {
+ // Convert i17 -> (i18, i18).
+ result.append(2, Builder(ctx).getIntegerType(18));
+ return success();
+ }
+
+ result.push_back(type);
+ return success();
+ });
+
+ // Populate patterns.
mlir::RewritePatternSet patterns(ctx);
patterns.add<TestDirectReplacementOp>(ctx, converter);
+ populateFuncToLLVMConversionPatterns(converter, patterns);
// Define the conversion target used for the test.
ConversionTarget target(*ctx);
target.addLegalOp(OperationName("test.legal_op", ctx));
+ target.addLegalDialect<LLVM::LLVMDialect>();
+ target.addDynamicallyLegalOp<func::FuncOp>(
+ [&](func::FuncOp funcOp) { return funcOp->hasAttr("is_legal"); });
// Handle a partial conversion.
DenseSet<Operation *> unlegalizedOps;
ConversionConfig config;
config.unlegalizedOps = &unlegalizedOps;
+ config.allowPatternRollback = allowPatternRollback;
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns), config)))
getOperation()->emitError() << "applyPartialConversion failed";
}
+
+ Option<bool> allowPatternRollback{*this, "allow-pattern-rollback",
+ llvm::cl::desc("Allow pattern rollback"),
+ llvm::cl::init(true)};
};
} // namespace