[Mlir-commits] [mlir] 9e8ccf6 - [mlir][Conversion] `FuncToLLVM`: Simplify bare-pointer handling (#96393)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sun Jun 23 23:38:30 PDT 2024
Author: Matthias Springer
Date: 2024-06-24T08:38:26+02:00
New Revision: 9e8ccf6b6410a598f94d2ce4c29d753b8609c907
URL: https://github.com/llvm/llvm-project/commit/9e8ccf6b6410a598f94d2ce4c29d753b8609c907
DIFF: https://github.com/llvm/llvm-project/commit/9e8ccf6b6410a598f94d2ce4c29d753b8609c907.diff
LOG: [mlir][Conversion] `FuncToLLVM`: Simplify bare-pointer handling (#96393)
Before this commit, the `func.func`/`gpu.func` op lowerings contained a
workaround that was applied when the bare-pointer calling convention is
enabled. This workaround "patched up" the argument materializations for
memref arguments. The bare pointer can instead be converted to a memref
descriptor directly in the argument materialization functions (as the
TODOs in the code base indicate).
This commit effectively reverts to the old implementation
(a664c14001fa2359604527084c91d0864aa131a4) and adds checks to make sure
that bare pointers are used only for function entry block arguments.
Added:
Modified:
mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
index 744236692fbb6..efb80467369a2 100644
--- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
+++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
@@ -268,55 +268,6 @@ static void wrapExternalFunction(OpBuilder &builder, Location loc,
}
}
-/// Modifies the body of the function to construct the `MemRefDescriptor` from
-/// the bare pointer calling convention lowering of `memref` types.
-static void modifyFuncOpToUseBarePtrCallingConv(
- ConversionPatternRewriter &rewriter, Location loc,
- const LLVMTypeConverter &typeConverter, LLVM::LLVMFuncOp funcOp,
- TypeRange oldArgTypes) {
- if (funcOp.getBody().empty())
- return;
-
- // Promote bare pointers from memref arguments to memref descriptors at the
- // beginning of the function so that all the memrefs in the function have a
- // uniform representation.
- Block *entryBlock = &funcOp.getBody().front();
- auto blockArgs = entryBlock->getArguments();
- assert(blockArgs.size() == oldArgTypes.size() &&
- "The number of arguments and types doesn't match");
-
- OpBuilder::InsertionGuard guard(rewriter);
- rewriter.setInsertionPointToStart(entryBlock);
- for (auto it : llvm::zip(blockArgs, oldArgTypes)) {
- BlockArgument arg = std::get<0>(it);
- Type argTy = std::get<1>(it);
-
- // Unranked memrefs are not supported in the bare pointer calling
- // convention. We should have bailed out before in the presence of
- // unranked memrefs.
- assert(!isa<UnrankedMemRefType>(argTy) &&
- "Unranked memref is not supported");
- auto memrefTy = dyn_cast<MemRefType>(argTy);
- if (!memrefTy)
- continue;
-
- // Replace barePtr with a placeholder (undef), promote barePtr to a ranked
- // or unranked memref descriptor and replace placeholder with the last
- // instruction of the memref descriptor.
- // TODO: The placeholder is needed to avoid replacing barePtr uses in the
- // MemRef descriptor instructions. We may want to have a utility in the
- // rewriter to properly handle this use case.
- Location loc = funcOp.getLoc();
- auto placeholder = rewriter.create<LLVM::UndefOp>(
- loc, typeConverter.convertType(memrefTy));
- rewriter.replaceUsesOfBlockArgument(arg, placeholder);
-
- Value desc = MemRefDescriptor::fromStaticShape(rewriter, loc, typeConverter,
- memrefTy, arg);
- rewriter.replaceOp(placeholder, {desc});
- }
-}
-
FailureOr<LLVM::LLVMFuncOp>
mlir::convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp,
ConversionPatternRewriter &rewriter,
@@ -462,10 +413,6 @@ mlir::convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp,
wrapForExternalCallers(rewriter, funcOp->getLoc(), converter, funcOp,
newFuncOp);
}
- } else {
- modifyFuncOpToUseBarePtrCallingConv(
- rewriter, funcOp->getLoc(), converter, newFuncOp,
- llvm::cast<FunctionType>(funcOp.getFunctionType()).getInputs());
}
return newFuncOp;
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index 3e6fcc076fb4d..6053e34f30a41 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -182,35 +182,6 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
&signatureConversion)))
return failure();
- // If bare memref pointers are being used, remap them back to memref
- // descriptors This must be done after signature conversion to get rid of the
- // unrealized casts.
- if (getTypeConverter()->getOptions().useBarePtrCallConv) {
- OpBuilder::InsertionGuard guard(rewriter);
- rewriter.setInsertionPointToStart(&llvmFuncOp.getBody().front());
- for (const auto [idx, argTy] :
- llvm::enumerate(gpuFuncOp.getArgumentTypes())) {
- auto memrefTy = dyn_cast<MemRefType>(argTy);
- if (!memrefTy)
- continue;
- assert(memrefTy.hasStaticShape() &&
- "Bare pointer convertion used with dynamically-shaped memrefs");
- // Use a placeholder when replacing uses of the memref argument to prevent
- // circular replacements.
- auto remapping = signatureConversion.getInputMapping(idx);
- assert(remapping && remapping->size == 1 &&
- "Type converter should produce 1-to-1 mapping for bare memrefs");
- BlockArgument newArg =
- llvmFuncOp.getBody().getArgument(remapping->inputNo);
- auto placeholder = rewriter.create<LLVM::UndefOp>(
- loc, getTypeConverter()->convertType(memrefTy));
- rewriter.replaceUsesOfBlockArgument(newArg, placeholder);
- Value desc = MemRefDescriptor::fromStaticShape(
- rewriter, loc, *getTypeConverter(), memrefTy, newArg);
- rewriter.replaceOp(placeholder, {desc});
- }
- }
-
// Get memref type from function arguments and set the noalias to
// pointer arguments.
for (const auto [idx, argTy] :
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index 3a01795ce3f53..f5620a6a7cd91 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -159,18 +159,30 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
addArgumentMaterialization(
[&](OpBuilder &builder, UnrankedMemRefType resultType, ValueRange inputs,
Location loc) -> std::optional<Value> {
- if (inputs.size() == 1)
+ if (inputs.size() == 1) {
+ // Bare pointers are not supported for unranked memrefs because a
+ // memref descriptor cannot be built just from a bare pointer.
return std::nullopt;
+ }
return UnrankedMemRefDescriptor::pack(builder, loc, *this, resultType,
inputs);
});
addArgumentMaterialization([&](OpBuilder &builder, MemRefType resultType,
ValueRange inputs,
Location loc) -> std::optional<Value> {
- // TODO: bare ptr conversion could be handled here but we would need a way
- // to distinguish between FuncOp and other regions.
- if (inputs.size() == 1)
- return std::nullopt;
+ if (inputs.size() == 1) {
+ // This is a bare pointer. We allow bare pointers only for function entry
+ // blocks.
+ BlockArgument barePtr = dyn_cast<BlockArgument>(inputs.front());
+ if (!barePtr)
+ return std::nullopt;
+ Block *block = barePtr.getOwner();
+ if (!block->isEntryBlock() ||
+ !isa<FunctionOpInterface>(block->getParentOp()))
+ return std::nullopt;
+ return MemRefDescriptor::fromStaticShape(builder, loc, *this, resultType,
+ inputs[0]);
+ }
return MemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
});
// Add generic source and target materializations to handle cases where
More information about the Mlir-commits
mailing list