[llvm-branch-commits] [mlir] [mlir][Conversion] `FuncToLLVM`: Simplify bare-pointer handling (PR #96393)

Matthias Springer via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Sat Jun 22 06:55:08 PDT 2024


https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/96393

Before this commit, the `func.func`/`gpu.func` op lowerings contained a workaround for the case where the bare-pointer calling convention is enabled: after creating the lowered function, they "patched up" the uses of memref arguments by rebuilding a memref descriptor from each bare pointer. This can instead be done directly in the argument materialization functions (as the TODOs in the code base suggest).

This commit effectively restores the old implementation (a664c14001fa2359604527084c91d0864aa131a4) and adds checks to ensure that bare pointers are accepted only for function entry block arguments.


>From f65911a2b08c538d24a9b2044123390ceae5b4b5 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Sat, 22 Jun 2024 14:54:21 +0200
Subject: [PATCH] [mlir][Conversion] `FuncToLLVM`: Simplify bare-pointer
 handling

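Before this commit, the `func.func`/`gpu.func` op lowerings contained a workaround for the case where the bare-pointer calling convention was enabled: after creating the lowered function, they "patched up" the uses of memref arguments by rebuilding a memref descriptor from each bare pointer. This can instead be done directly in the argument materialization functions (as the TODOs in the code base suggest). This effectively restores the old implementation (a664c14001fa2359604527084c91d0864aa131a4), with additional checks ensuring that bare pointers are accepted only for function entry block arguments.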
---
 mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp |  53 --------
 .../Conversion/GPUCommon/GPUOpsLowering.cpp   | 128 +++++++-----------
 .../Conversion/LLVMCommon/TypeConverter.cpp   |  22 ++-
 3 files changed, 66 insertions(+), 137 deletions(-)

diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
index 744236692fbb6..efb80467369a2 100644
--- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
+++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
@@ -268,55 +268,6 @@ static void wrapExternalFunction(OpBuilder &builder, Location loc,
   }
 }
 
-/// Modifies the body of the function to construct the `MemRefDescriptor` from
-/// the bare pointer calling convention lowering of `memref` types.
-static void modifyFuncOpToUseBarePtrCallingConv(
-    ConversionPatternRewriter &rewriter, Location loc,
-    const LLVMTypeConverter &typeConverter, LLVM::LLVMFuncOp funcOp,
-    TypeRange oldArgTypes) {
-  if (funcOp.getBody().empty())
-    return;
-
-  // Promote bare pointers from memref arguments to memref descriptors at the
-  // beginning of the function so that all the memrefs in the function have a
-  // uniform representation.
-  Block *entryBlock = &funcOp.getBody().front();
-  auto blockArgs = entryBlock->getArguments();
-  assert(blockArgs.size() == oldArgTypes.size() &&
-         "The number of arguments and types doesn't match");
-
-  OpBuilder::InsertionGuard guard(rewriter);
-  rewriter.setInsertionPointToStart(entryBlock);
-  for (auto it : llvm::zip(blockArgs, oldArgTypes)) {
-    BlockArgument arg = std::get<0>(it);
-    Type argTy = std::get<1>(it);
-
-    // Unranked memrefs are not supported in the bare pointer calling
-    // convention. We should have bailed out before in the presence of
-    // unranked memrefs.
-    assert(!isa<UnrankedMemRefType>(argTy) &&
-           "Unranked memref is not supported");
-    auto memrefTy = dyn_cast<MemRefType>(argTy);
-    if (!memrefTy)
-      continue;
-
-    // Replace barePtr with a placeholder (undef), promote barePtr to a ranked
-    // or unranked memref descriptor and replace placeholder with the last
-    // instruction of the memref descriptor.
-    // TODO: The placeholder is needed to avoid replacing barePtr uses in the
-    // MemRef descriptor instructions. We may want to have a utility in the
-    // rewriter to properly handle this use case.
-    Location loc = funcOp.getLoc();
-    auto placeholder = rewriter.create<LLVM::UndefOp>(
-        loc, typeConverter.convertType(memrefTy));
-    rewriter.replaceUsesOfBlockArgument(arg, placeholder);
-
-    Value desc = MemRefDescriptor::fromStaticShape(rewriter, loc, typeConverter,
-                                                   memrefTy, arg);
-    rewriter.replaceOp(placeholder, {desc});
-  }
-}
-
 FailureOr<LLVM::LLVMFuncOp>
 mlir::convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp,
                                 ConversionPatternRewriter &rewriter,
@@ -462,10 +413,6 @@ mlir::convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp,
         wrapForExternalCallers(rewriter, funcOp->getLoc(), converter, funcOp,
                                newFuncOp);
     }
-  } else {
-    modifyFuncOpToUseBarePtrCallingConv(
-        rewriter, funcOp->getLoc(), converter, newFuncOp,
-        llvm::cast<FunctionType>(funcOp.getFunctionType()).getInputs());
   }
 
   return newFuncOp;
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index 7ea05b7e7f6c1..6053e34f30a41 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -182,35 +182,6 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
                                          &signatureConversion)))
     return failure();
 
-  // If bare memref pointers are being used, remap them back to memref
-  // descriptors This must be done after signature conversion to get rid of the
-  // unrealized casts.
-  if (getTypeConverter()->getOptions().useBarePtrCallConv) {
-    OpBuilder::InsertionGuard guard(rewriter);
-    rewriter.setInsertionPointToStart(&llvmFuncOp.getBody().front());
-    for (const auto [idx, argTy] :
-         llvm::enumerate(gpuFuncOp.getArgumentTypes())) {
-      auto memrefTy = dyn_cast<MemRefType>(argTy);
-      if (!memrefTy)
-        continue;
-      assert(memrefTy.hasStaticShape() &&
-             "Bare pointer convertion used with dynamically-shaped memrefs");
-      // Use a placeholder when replacing uses of the memref argument to prevent
-      // circular replacements.
-      auto remapping = signatureConversion.getInputMapping(idx);
-      assert(remapping && remapping->size == 1 &&
-             "Type converter should produce 1-to-1 mapping for bare memrefs");
-      BlockArgument newArg =
-          llvmFuncOp.getBody().getArgument(remapping->inputNo);
-      auto placeholder = rewriter.create<LLVM::UndefOp>(
-          loc, getTypeConverter()->convertType(memrefTy));
-      rewriter.replaceUsesOfBlockArgument(newArg, placeholder);
-      Value desc = MemRefDescriptor::fromStaticShape(
-          rewriter, loc, *getTypeConverter(), memrefTy, newArg);
-      rewriter.replaceOp(placeholder, {desc});
-    }
-  }
-
   // Get memref type from function arguments and set the noalias to
   // pointer arguments.
   for (const auto [idx, argTy] :
@@ -684,62 +655,61 @@ LogicalResult GPUDynamicSharedMemoryOpLowering::matchAndRewrite(
   return success();
 }
 
-  LogicalResult
-  GPUReturnOpLowering::matchAndRewrite(gpu::ReturnOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const {
-    Location loc = op.getLoc();
-    unsigned numArguments = op.getNumOperands();
-    SmallVector<Value, 4> updatedOperands;
-
-    bool useBarePtrCallConv =
-        getTypeConverter()->getOptions().useBarePtrCallConv;
-    if (useBarePtrCallConv) {
-      // For the bare-ptr calling convention, extract the aligned pointer to
-      // be returned from the memref descriptor.
-      for (auto it : llvm::zip(op->getOperands(), adaptor.getOperands())) {
-        Type oldTy = std::get<0>(it).getType();
-        Value newOperand = std::get<1>(it);
-        if (isa<MemRefType>(oldTy) && getTypeConverter()->canConvertToBarePtr(
-                                          cast<BaseMemRefType>(oldTy))) {
-          MemRefDescriptor memrefDesc(newOperand);
-          newOperand = memrefDesc.allocatedPtr(rewriter, loc);
-        } else if (isa<UnrankedMemRefType>(oldTy)) {
-          // Unranked memref is not supported in the bare pointer calling
-          // convention.
-          return failure();
-        }
-        updatedOperands.push_back(newOperand);
+LogicalResult GPUReturnOpLowering::matchAndRewrite(
+    gpu::ReturnOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  Location loc = op.getLoc();
+  unsigned numArguments = op.getNumOperands();
+  SmallVector<Value, 4> updatedOperands;
+
+  bool useBarePtrCallConv = getTypeConverter()->getOptions().useBarePtrCallConv;
+  if (useBarePtrCallConv) {
+    // For the bare-ptr calling convention, return the descriptor's allocated
+    // pointer. TODO(review): confirm this should not be the aligned pointer.
+    for (auto it : llvm::zip(op->getOperands(), adaptor.getOperands())) {
+      Type oldTy = std::get<0>(it).getType();
+      Value newOperand = std::get<1>(it);
+      if (isa<MemRefType>(oldTy) && getTypeConverter()->canConvertToBarePtr(
+                                        cast<BaseMemRefType>(oldTy))) {
+        MemRefDescriptor memrefDesc(newOperand);
+        newOperand = memrefDesc.allocatedPtr(rewriter, loc);
+      } else if (isa<UnrankedMemRefType>(oldTy)) {
+        // Unranked memref is not supported in the bare pointer calling
+        // convention.
+        return failure();
       }
-    } else {
-      updatedOperands = llvm::to_vector<4>(adaptor.getOperands());
-      (void)copyUnrankedDescriptors(rewriter, loc, op.getOperands().getTypes(),
-                                    updatedOperands,
-                                    /*toDynamic=*/true);
+      updatedOperands.push_back(newOperand);
     }
+  } else {
+    updatedOperands = llvm::to_vector<4>(adaptor.getOperands());
+    (void)copyUnrankedDescriptors(rewriter, loc, op.getOperands().getTypes(),
+                                  updatedOperands,
+                                  /*toDynamic=*/true);
+  }
 
-    // If ReturnOp has 0 or 1 operand, create it and return immediately.
-    if (numArguments <= 1) {
-      rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(
-          op, TypeRange(), updatedOperands, op->getAttrs());
-      return success();
-    }
+  // If ReturnOp has 0 or 1 operand, create it and return immediately.
+  if (numArguments <= 1) {
+    rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(
+        op, TypeRange(), updatedOperands, op->getAttrs());
+    return success();
+  }
 
-    // Otherwise, we need to pack the arguments into an LLVM struct type before
-    // returning.
-    auto packedType = getTypeConverter()->packFunctionResults(
-        op.getOperandTypes(), useBarePtrCallConv);
-    if (!packedType) {
-      return rewriter.notifyMatchFailure(op, "could not convert result types");
-    }
+  // Otherwise, we need to pack the arguments into an LLVM struct type before
+  // returning.
+  auto packedType = getTypeConverter()->packFunctionResults(
+      op.getOperandTypes(), useBarePtrCallConv);
+  if (!packedType) {
+    return rewriter.notifyMatchFailure(op, "could not convert result types");
+  }
 
-    Value packed = rewriter.create<LLVM::UndefOp>(loc, packedType);
-    for (auto [idx, operand] : llvm::enumerate(updatedOperands)) {
-      packed = rewriter.create<LLVM::InsertValueOp>(loc, packed, operand, idx);
-    }
-    rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, TypeRange(), packed,
-                                                op->getAttrs());
-    return success();
+  Value packed = rewriter.create<LLVM::UndefOp>(loc, packedType);
+  for (auto [idx, operand] : llvm::enumerate(updatedOperands)) {
+    packed = rewriter.create<LLVM::InsertValueOp>(loc, packed, operand, idx);
   }
+  rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, TypeRange(), packed,
+                                              op->getAttrs());
+  return success();
+}
 
 void mlir::populateGpuMemorySpaceAttributeConversions(
     TypeConverter &typeConverter, const MemorySpaceMapping &mapping) {
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index 3a01795ce3f53..f5620a6a7cd91 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -159,18 +159,30 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
   addArgumentMaterialization(
       [&](OpBuilder &builder, UnrankedMemRefType resultType, ValueRange inputs,
           Location loc) -> std::optional<Value> {
-        if (inputs.size() == 1)
+        if (inputs.size() == 1) {
+          // Bare pointers are not supported for unranked memrefs because a
+          // memref descriptor cannot be built just from a bare pointer.
           return std::nullopt;
+        }
         return UnrankedMemRefDescriptor::pack(builder, loc, *this, resultType,
                                               inputs);
       });
   addArgumentMaterialization([&](OpBuilder &builder, MemRefType resultType,
                                  ValueRange inputs,
                                  Location loc) -> std::optional<Value> {
-    // TODO: bare ptr conversion could be handled here but we would need a way
-    // to distinguish between FuncOp and other regions.
-    if (inputs.size() == 1)
-      return std::nullopt;
+    if (inputs.size() == 1) {
+      // This is a bare pointer. We allow bare pointers only for function entry
+      // blocks.
+      BlockArgument barePtr = dyn_cast<BlockArgument>(inputs.front());
+      if (!barePtr)
+        return std::nullopt;
+      Block *block = barePtr.getOwner();
+      if (!block->isEntryBlock() ||
+          !isa<FunctionOpInterface>(block->getParentOp()))
+        return std::nullopt;
+      return MemRefDescriptor::fromStaticShape(builder, loc, *this, resultType,
+                                               inputs[0]);
+    }
     return MemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
   });
   // Add generic source and target materializations to handle cases where



More information about the llvm-branch-commits mailing list