[Mlir-commits] [mlir] a664c14 - [mlir][LLVM] Revert bareptr calling convention handling as an argument materialization.
Mehdi Amini
llvmlistbot at llvm.org
Wed Jul 21 15:07:04 PDT 2021
Author: Nicolas Vasilache
Date: 2021-07-21T22:06:50Z
New Revision: a664c14001fa2359604527084c91d0864aa131a4
URL: https://github.com/llvm/llvm-project/commit/a664c14001fa2359604527084c91d0864aa131a4
DIFF: https://github.com/llvm/llvm-project/commit/a664c14001fa2359604527084c91d0864aa131a4.diff
LOG: [mlir][LLVM] Revert bareptr calling convention handling as an argument materialization.
Type conversion and argument materialization are context-free: there is no available information on which op / branch is currently being converted.
As a consequence, the bare pointer calling convention cannot be handled as an argument materialization: it would apply irrespective of the parent op.
This does not typecheck for regions of ops other than FuncOp (e.g. block arguments of branch targets), and we would see cases where a memref descriptor is inserted in place of the pointer inside another memref descriptor.
For now, the proper fix is to revert to a dedicated BarePtrFuncOpConversion implementation and drop the blanket argument materialization logic.
This reverts the relevant piece of the conversion to LLVM to what it was before https://reviews.llvm.org/D105880, and adds a test and documentation so that whoever attempts this again in the future does not repeat the mistake.
Reviewed By: arpith-jacob
Differential Revision: https://reviews.llvm.org/D106495
Added:
Modified:
mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
mlir/test/Conversion/StandardToLLVM/func-memref.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index 623bf963f3fb..ca1882899edf 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -58,16 +58,8 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
addArgumentMaterialization([&](OpBuilder &builder, MemRefType resultType,
ValueRange inputs,
Location loc) -> Optional<Value> {
- // Explicit "this" is necessary here because otherwise "options" resolves to
- // the argument of the parent function (constructor), which is a reference
- // and not a copy. This can lead to UB when the lambda is actually called.
- if (this->options.useBarePtrCallConv) {
- if (!resultType.hasStaticShape())
- return llvm::None;
- Value v = MemRefDescriptor::fromStaticShape(builder, loc, *this,
- resultType, inputs[0]);
- return v;
- }
+ // TODO: bare ptr conversion could be handled here but we would need a way
+ // to distinguish between FuncOp and other regions.
if (inputs.size() == 1)
return llvm::None;
return MemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index 884cc76e170e..4079e0932a02 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -309,9 +309,62 @@ struct BarePtrFuncOpConversion : public FuncOpConversionBase {
LogicalResult
matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
+
+ // TODO: bare ptr conversion could be handled by argument materialization
+ // and most of the code below would go away. But to do this, we would need a
+ // way to distinguish between FuncOp and other regions in the
+ // addArgumentMaterialization hook.
+
+ // Store the type of memref-typed arguments before the conversion so that we
+ // can promote them to MemRef descriptor at the beginning of the function.
+ SmallVector<Type, 8> oldArgTypes =
+ llvm::to_vector<8>(funcOp.getType().getInputs());
+
auto newFuncOp = convertFuncOpToLLVMFuncOp(funcOp, rewriter);
if (!newFuncOp)
return failure();
+ if (newFuncOp.getBody().empty()) {
+ rewriter.eraseOp(funcOp);
+ return success();
+ }
+
+ // Promote bare pointers from memref arguments to memref descriptors at the
+ // beginning of the function so that all the memrefs in the function have a
+ // uniform representation.
+ Block *entryBlock = &newFuncOp.getBody().front();
+ auto blockArgs = entryBlock->getArguments();
+ assert(blockArgs.size() == oldArgTypes.size() &&
+ "The number of arguments and types doesn't match");
+
+ OpBuilder::InsertionGuard guard(rewriter);
+ rewriter.setInsertionPointToStart(entryBlock);
+ for (auto it : llvm::zip(blockArgs, oldArgTypes)) {
+ BlockArgument arg = std::get<0>(it);
+ Type argTy = std::get<1>(it);
+
+ // Unranked memrefs are not supported in the bare pointer calling
+ // convention. We should have bailed out before in the presence of
+ // unranked memrefs.
+ assert(!argTy.isa<UnrankedMemRefType>() &&
+ "Unranked memref is not supported");
+ auto memrefTy = argTy.dyn_cast<MemRefType>();
+ if (!memrefTy)
+ continue;
+
+ // Replace barePtr with a placeholder (undef), promote barePtr to a ranked
+ // or unranked memref descriptor and replace placeholder with the last
+ // instruction of the memref descriptor.
+ // TODO: The placeholder is needed to avoid replacing barePtr uses in the
+ // MemRef descriptor instructions. We may want to have a utility in the
+ // rewriter to properly handle this use case.
+ Location loc = funcOp.getLoc();
+ auto placeholder = rewriter.create<LLVM::UndefOp>(loc, memrefTy);
+ rewriter.replaceUsesOfBlockArgument(arg, placeholder);
+
+ Value desc = MemRefDescriptor::fromStaticShape(
+ rewriter, loc, *getTypeConverter(), memrefTy, arg);
+ rewriter.replaceOp(placeholder, {desc});
+ }
rewriter.eraseOp(funcOp);
return success();
@@ -330,7 +383,8 @@ using DivFOpLowering = VectorConvertToLLVMPattern<DivFOp, LLVM::FDivOp>;
using FPExtOpLowering = VectorConvertToLLVMPattern<FPExtOp, LLVM::FPExtOp>;
using FPToSIOpLowering = VectorConvertToLLVMPattern<FPToSIOp, LLVM::FPToSIOp>;
using FPToUIOpLowering = VectorConvertToLLVMPattern<FPToUIOp, LLVM::FPToUIOp>;
-using FPTruncOpLowering = VectorConvertToLLVMPattern<FPTruncOp, LLVM::FPTruncOp>;
+using FPTruncOpLowering =
+ VectorConvertToLLVMPattern<FPTruncOp, LLVM::FPTruncOp>;
using FloorFOpLowering = VectorConvertToLLVMPattern<FloorFOp, LLVM::FFloorOp>;
using FmaFOpLowering = VectorConvertToLLVMPattern<FmaFOp, LLVM::FMAOp>;
using MulFOpLowering = VectorConvertToLLVMPattern<MulFOp, LLVM::FMulOp>;
@@ -352,7 +406,8 @@ using SignedShiftRightOpLowering =
OneToOneConvertToLLVMPattern<SignedShiftRightOp, LLVM::AShrOp>;
using SubFOpLowering = VectorConvertToLLVMPattern<SubFOp, LLVM::FSubOp>;
using SubIOpLowering = VectorConvertToLLVMPattern<SubIOp, LLVM::SubOp>;
-using TruncateIOpLowering = VectorConvertToLLVMPattern<TruncateIOp, LLVM::TruncOp>;
+using TruncateIOpLowering =
+ VectorConvertToLLVMPattern<TruncateIOp, LLVM::TruncOp>;
using UIToFPOpLowering = VectorConvertToLLVMPattern<UIToFPOp, LLVM::UIToFPOp>;
using UnsignedDivIOpLowering =
VectorConvertToLLVMPattern<UnsignedDivIOp, LLVM::UDivOp>;
@@ -1196,4 +1251,3 @@ mlir::createLowerToLLVMPass(const LowerToLLVMOptions &options) {
options.useBarePtrCallConv, options.emitCWrappers,
options.getIndexBitwidth(), useAlignedAlloc, options.dataLayout);
}
-
diff --git a/mlir/test/Conversion/StandardToLLVM/func-memref.mlir b/mlir/test/Conversion/StandardToLLVM/func-memref.mlir
index b96c260d5a0f..fefffd683be5 100644
--- a/mlir/test/Conversion/StandardToLLVM/func-memref.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/func-memref.mlir
@@ -182,3 +182,28 @@ func @check_scalar_func_call(%in : f32) {
%res = call @goo(%in) : (f32) -> (f32)
return
}
+
+// -----
+
+!base_type = type memref<64xi32, 201>
+
+// CHECK-LABEL: func @loop_carried
+// BAREPTR-LABEL: func @loop_carried
+func @loop_carried(%arg0 : index, %arg1 : index, %arg2 : index, %base0 : !base_type, %base1 : !base_type) -> (!base_type, !base_type) {
+ // This test checks that in the BAREPTR case, the branch arguments only forward the descriptor.
+ // This test was lowered from a simple scf.for that swaps 2 memref iter_args.
+ // BAREPTR: llvm.br ^bb1(%{{.*}}, %{{.*}}, %{{.*}} : i64, !llvm.struct<(ptr<i32, 201>, ptr<i32, 201>, i64, array<1 x i64>, array<1 x i64>)>, !llvm.struct<(ptr<i32, 201>, ptr<i32, 201>, i64, array<1 x i64>, array<1 x i64>)>)
+ br ^bb1(%arg0, %base0, %base1 : index, memref<64xi32, 201>, memref<64xi32, 201>)
+
+ // BAREPTR-NEXT: ^bb1
+ // BAREPTR-NEXT: llvm.icmp
+ // BAREPTR-NEXT: llvm.cond_br %{{.*}}, ^bb2, ^bb3
+ ^bb1(%0: index, %1: memref<64xi32, 201>, %2: memref<64xi32, 201>): // 2 preds: ^bb0, ^bb2
+ %3 = cmpi slt, %0, %arg1 : index
+ cond_br %3, ^bb2, ^bb3
+ ^bb2: // pred: ^bb1
+ %4 = addi %0, %arg2 : index
+ br ^bb1(%4, %2, %1 : index, memref<64xi32, 201>, memref<64xi32, 201>)
+ ^bb3: // pred: ^bb1
+ return %1, %2 : memref<64xi32, 201>, memref<64xi32, 201>
+}
More information about the Mlir-commits
mailing list