[Mlir-commits] [mlir] a664c14 - [mlir][LLVM] Revert bareptr calling convention handling as an argument materialization.

Mehdi Amini llvmlistbot at llvm.org
Wed Jul 21 15:07:04 PDT 2021


Author: Nicolas Vasilache
Date: 2021-07-21T22:06:50Z
New Revision: a664c14001fa2359604527084c91d0864aa131a4

URL: https://github.com/llvm/llvm-project/commit/a664c14001fa2359604527084c91d0864aa131a4
DIFF: https://github.com/llvm/llvm-project/commit/a664c14001fa2359604527084c91d0864aa131a4.diff

LOG: [mlir][LLVM] Revert bareptr calling convention handling as an argument materialization.

Type conversion and argument materialization are context-free: no information is available about which op / branch is currently being converted.
As a consequence, the bare-pointer convention cannot be handled as an argument materialization: it would apply irrespective of the parent op.
This does not typecheck for non-FuncOp ops, and we observed cases where a memref descriptor would be inserted in place of the pointer inside another memref descriptor.

For now, the proper behavior is to revert to a dedicated BarePtrFuncOpConversion implementation and drop the blanket argument materialization logic.

This reverts the relevant piece of the conversion to LLVM to its state before https://reviews.llvm.org/D105880, and adds a test and documentation so that whoever attempts this again in the future can avoid the same mistake.

Reviewed By: arpith-jacob

Differential Revision: https://reviews.llvm.org/D106495

Added: 
    

Modified: 
    mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
    mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
    mlir/test/Conversion/StandardToLLVM/func-memref.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index 623bf963f3fb..ca1882899edf 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -58,16 +58,8 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
   addArgumentMaterialization([&](OpBuilder &builder, MemRefType resultType,
                                  ValueRange inputs,
                                  Location loc) -> Optional<Value> {
-    // Explicit "this" is necessary here because otherwise "options" resolves to
-    // the argument of the parent function (constructor), which is a reference
-    // and not a copy. This can lead to UB when the lambda is actually called.
-    if (this->options.useBarePtrCallConv) {
-      if (!resultType.hasStaticShape())
-        return llvm::None;
-      Value v = MemRefDescriptor::fromStaticShape(builder, loc, *this,
-                                                  resultType, inputs[0]);
-      return v;
-    }
+    // TODO: bare ptr conversion could be handled here but we would need a way
+    // to distinguish between FuncOp and other regions.
     if (inputs.size() == 1)
       return llvm::None;
     return MemRefDescriptor::pack(builder, loc, *this, resultType, inputs);

diff  --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index 884cc76e170e..4079e0932a02 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -309,9 +309,62 @@ struct BarePtrFuncOpConversion : public FuncOpConversionBase {
   LogicalResult
   matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
+
+    // TODO: bare ptr conversion could be handled by argument materialization
+    // and most of the code below would go away. But to do this, we would need a
+    // way to distinguish between FuncOp and other regions in the
+    // addArgumentMaterialization hook.
+
+    // Store the type of memref-typed arguments before the conversion so that we
+    // can promote them to MemRef descriptor at the beginning of the function.
+    SmallVector<Type, 8> oldArgTypes =
+        llvm::to_vector<8>(funcOp.getType().getInputs());
+
     auto newFuncOp = convertFuncOpToLLVMFuncOp(funcOp, rewriter);
     if (!newFuncOp)
       return failure();
+    if (newFuncOp.getBody().empty()) {
+      rewriter.eraseOp(funcOp);
+      return success();
+    }
+
+    // Promote bare pointers from memref arguments to memref descriptors at the
+    // beginning of the function so that all the memrefs in the function have a
+    // uniform representation.
+    Block *entryBlock = &newFuncOp.getBody().front();
+    auto blockArgs = entryBlock->getArguments();
+    assert(blockArgs.size() == oldArgTypes.size() &&
+           "The number of arguments and types doesn't match");
+
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPointToStart(entryBlock);
+    for (auto it : llvm::zip(blockArgs, oldArgTypes)) {
+      BlockArgument arg = std::get<0>(it);
+      Type argTy = std::get<1>(it);
+
+      // Unranked memrefs are not supported in the bare pointer calling
+      // convention. We should have bailed out before in the presence of
+      // unranked memrefs.
+      assert(!argTy.isa<UnrankedMemRefType>() &&
+             "Unranked memref is not supported");
+      auto memrefTy = argTy.dyn_cast<MemRefType>();
+      if (!memrefTy)
+        continue;
+
+      // Replace barePtr with a placeholder (undef), promote barePtr to a ranked
+      // or unranked memref descriptor and replace placeholder with the last
+      // instruction of the memref descriptor.
+      // TODO: The placeholder is needed to avoid replacing barePtr uses in the
+      // MemRef descriptor instructions. We may want to have a utility in the
+      // rewriter to properly handle this use case.
+      Location loc = funcOp.getLoc();
+      auto placeholder = rewriter.create<LLVM::UndefOp>(loc, memrefTy);
+      rewriter.replaceUsesOfBlockArgument(arg, placeholder);
+
+      Value desc = MemRefDescriptor::fromStaticShape(
+          rewriter, loc, *getTypeConverter(), memrefTy, arg);
+      rewriter.replaceOp(placeholder, {desc});
+    }
 
     rewriter.eraseOp(funcOp);
     return success();
@@ -330,7 +383,8 @@ using DivFOpLowering = VectorConvertToLLVMPattern<DivFOp, LLVM::FDivOp>;
 using FPExtOpLowering = VectorConvertToLLVMPattern<FPExtOp, LLVM::FPExtOp>;
 using FPToSIOpLowering = VectorConvertToLLVMPattern<FPToSIOp, LLVM::FPToSIOp>;
 using FPToUIOpLowering = VectorConvertToLLVMPattern<FPToUIOp, LLVM::FPToUIOp>;
-using FPTruncOpLowering = VectorConvertToLLVMPattern<FPTruncOp, LLVM::FPTruncOp>;
+using FPTruncOpLowering =
+    VectorConvertToLLVMPattern<FPTruncOp, LLVM::FPTruncOp>;
 using FloorFOpLowering = VectorConvertToLLVMPattern<FloorFOp, LLVM::FFloorOp>;
 using FmaFOpLowering = VectorConvertToLLVMPattern<FmaFOp, LLVM::FMAOp>;
 using MulFOpLowering = VectorConvertToLLVMPattern<MulFOp, LLVM::FMulOp>;
@@ -352,7 +406,8 @@ using SignedShiftRightOpLowering =
     OneToOneConvertToLLVMPattern<SignedShiftRightOp, LLVM::AShrOp>;
 using SubFOpLowering = VectorConvertToLLVMPattern<SubFOp, LLVM::FSubOp>;
 using SubIOpLowering = VectorConvertToLLVMPattern<SubIOp, LLVM::SubOp>;
-using TruncateIOpLowering = VectorConvertToLLVMPattern<TruncateIOp, LLVM::TruncOp>;
+using TruncateIOpLowering =
+    VectorConvertToLLVMPattern<TruncateIOp, LLVM::TruncOp>;
 using UIToFPOpLowering = VectorConvertToLLVMPattern<UIToFPOp, LLVM::UIToFPOp>;
 using UnsignedDivIOpLowering =
     VectorConvertToLLVMPattern<UnsignedDivIOp, LLVM::UDivOp>;
@@ -1196,4 +1251,3 @@ mlir::createLowerToLLVMPass(const LowerToLLVMOptions &options) {
       options.useBarePtrCallConv, options.emitCWrappers,
       options.getIndexBitwidth(), useAlignedAlloc, options.dataLayout);
 }
-

diff  --git a/mlir/test/Conversion/StandardToLLVM/func-memref.mlir b/mlir/test/Conversion/StandardToLLVM/func-memref.mlir
index b96c260d5a0f..fefffd683be5 100644
--- a/mlir/test/Conversion/StandardToLLVM/func-memref.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/func-memref.mlir
@@ -182,3 +182,28 @@ func @check_scalar_func_call(%in : f32) {
   %res = call @goo(%in) : (f32) -> (f32)
   return
 }
+
+// -----
+
+!base_type = type memref<64xi32, 201>
+
+// CHECK-LABEL: func @loop_carried
+// BAREPTR-LABEL: func @loop_carried
+func @loop_carried(%arg0 : index, %arg1 : index, %arg2 : index, %base0 : !base_type, %base1 : !base_type) -> (!base_type, !base_type) {
+  // This test checks that in the BAREPTR case, the branch arguments only forward the descriptor.
+  // This test was lowered from a simple scf.for that swaps 2 memref iter_args.
+  //      BAREPTR: llvm.br ^bb1(%{{.*}}, %{{.*}}, %{{.*}} : i64, !llvm.struct<(ptr<i32, 201>, ptr<i32, 201>, i64, array<1 x i64>, array<1 x i64>)>, !llvm.struct<(ptr<i32, 201>, ptr<i32, 201>, i64, array<1 x i64>, array<1 x i64>)>)
+  br ^bb1(%arg0, %base0, %base1 : index, memref<64xi32, 201>, memref<64xi32, 201>)
+
+  // BAREPTR-NEXT: ^bb1
+  // BAREPTR-NEXT:   llvm.icmp
+  // BAREPTR-NEXT:   llvm.cond_br %{{.*}}, ^bb2, ^bb3
+  ^bb1(%0: index, %1: memref<64xi32, 201>, %2: memref<64xi32, 201>):  // 2 preds: ^bb0, ^bb2
+    %3 = cmpi slt, %0, %arg1 : index
+    cond_br %3, ^bb2, ^bb3
+  ^bb2:  // pred: ^bb1
+    %4 = addi %0, %arg2 : index
+    br ^bb1(%4, %2, %1 : index, memref<64xi32, 201>, memref<64xi32, 201>)
+  ^bb3:  // pred: ^bb1
+    return %1, %2 : memref<64xi32, 201>, memref<64xi32, 201>
+}
