[llvm-branch-commits] [mlir] [mlir][Conversion] `FuncToLLVM`: Simplify bare-pointer handling (PR #96393)

Matthias Springer via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Sat Jun 22 06:56:32 PDT 2024


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/96393

>From 2d838580bf8c17ea7a17d73415b3c64c1775b37d Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Sat, 22 Jun 2024 14:54:21 +0200
Subject: [PATCH] [mlir][Conversion] `FuncToLLVM`: Simplify bare-pointer
 handling

Before this commit, the `func.func`/`gpu.func` op lowering contained a workaround for the case where the bare-pointer calling convention is enabled: it "patched up" the argument materializations for memref arguments. This can instead be done directly in the argument materialization functions, as the TODOs in the code base indicate.
---
 mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp | 53 -------------------
 .../Conversion/GPUCommon/GPUOpsLowering.cpp   | 29 ----------
 .../Conversion/LLVMCommon/TypeConverter.cpp   | 22 ++++++--
 3 files changed, 17 insertions(+), 87 deletions(-)

diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
index 744236692fbb6..efb80467369a2 100644
--- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
+++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
@@ -268,55 +268,6 @@ static void wrapExternalFunction(OpBuilder &builder, Location loc,
   }
 }
 
-/// Modifies the body of the function to construct the `MemRefDescriptor` from
-/// the bare pointer calling convention lowering of `memref` types.
-static void modifyFuncOpToUseBarePtrCallingConv(
-    ConversionPatternRewriter &rewriter, Location loc,
-    const LLVMTypeConverter &typeConverter, LLVM::LLVMFuncOp funcOp,
-    TypeRange oldArgTypes) {
-  if (funcOp.getBody().empty())
-    return;
-
-  // Promote bare pointers from memref arguments to memref descriptors at the
-  // beginning of the function so that all the memrefs in the function have a
-  // uniform representation.
-  Block *entryBlock = &funcOp.getBody().front();
-  auto blockArgs = entryBlock->getArguments();
-  assert(blockArgs.size() == oldArgTypes.size() &&
-         "The number of arguments and types doesn't match");
-
-  OpBuilder::InsertionGuard guard(rewriter);
-  rewriter.setInsertionPointToStart(entryBlock);
-  for (auto it : llvm::zip(blockArgs, oldArgTypes)) {
-    BlockArgument arg = std::get<0>(it);
-    Type argTy = std::get<1>(it);
-
-    // Unranked memrefs are not supported in the bare pointer calling
-    // convention. We should have bailed out before in the presence of
-    // unranked memrefs.
-    assert(!isa<UnrankedMemRefType>(argTy) &&
-           "Unranked memref is not supported");
-    auto memrefTy = dyn_cast<MemRefType>(argTy);
-    if (!memrefTy)
-      continue;
-
-    // Replace barePtr with a placeholder (undef), promote barePtr to a ranked
-    // or unranked memref descriptor and replace placeholder with the last
-    // instruction of the memref descriptor.
-    // TODO: The placeholder is needed to avoid replacing barePtr uses in the
-    // MemRef descriptor instructions. We may want to have a utility in the
-    // rewriter to properly handle this use case.
-    Location loc = funcOp.getLoc();
-    auto placeholder = rewriter.create<LLVM::UndefOp>(
-        loc, typeConverter.convertType(memrefTy));
-    rewriter.replaceUsesOfBlockArgument(arg, placeholder);
-
-    Value desc = MemRefDescriptor::fromStaticShape(rewriter, loc, typeConverter,
-                                                   memrefTy, arg);
-    rewriter.replaceOp(placeholder, {desc});
-  }
-}
-
 FailureOr<LLVM::LLVMFuncOp>
 mlir::convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp,
                                 ConversionPatternRewriter &rewriter,
@@ -462,10 +413,6 @@ mlir::convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp,
         wrapForExternalCallers(rewriter, funcOp->getLoc(), converter, funcOp,
                                newFuncOp);
     }
-  } else {
-    modifyFuncOpToUseBarePtrCallingConv(
-        rewriter, funcOp->getLoc(), converter, newFuncOp,
-        llvm::cast<FunctionType>(funcOp.getFunctionType()).getInputs());
   }
 
   return newFuncOp;
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index 3e6fcc076fb4d..6053e34f30a41 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -182,35 +182,6 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
                                          &signatureConversion)))
     return failure();
 
-  // If bare memref pointers are being used, remap them back to memref
-  // descriptors This must be done after signature conversion to get rid of the
-  // unrealized casts.
-  if (getTypeConverter()->getOptions().useBarePtrCallConv) {
-    OpBuilder::InsertionGuard guard(rewriter);
-    rewriter.setInsertionPointToStart(&llvmFuncOp.getBody().front());
-    for (const auto [idx, argTy] :
-         llvm::enumerate(gpuFuncOp.getArgumentTypes())) {
-      auto memrefTy = dyn_cast<MemRefType>(argTy);
-      if (!memrefTy)
-        continue;
-      assert(memrefTy.hasStaticShape() &&
-             "Bare pointer convertion used with dynamically-shaped memrefs");
-      // Use a placeholder when replacing uses of the memref argument to prevent
-      // circular replacements.
-      auto remapping = signatureConversion.getInputMapping(idx);
-      assert(remapping && remapping->size == 1 &&
-             "Type converter should produce 1-to-1 mapping for bare memrefs");
-      BlockArgument newArg =
-          llvmFuncOp.getBody().getArgument(remapping->inputNo);
-      auto placeholder = rewriter.create<LLVM::UndefOp>(
-          loc, getTypeConverter()->convertType(memrefTy));
-      rewriter.replaceUsesOfBlockArgument(newArg, placeholder);
-      Value desc = MemRefDescriptor::fromStaticShape(
-          rewriter, loc, *getTypeConverter(), memrefTy, newArg);
-      rewriter.replaceOp(placeholder, {desc});
-    }
-  }
-
   // Get memref type from function arguments and set the noalias to
   // pointer arguments.
   for (const auto [idx, argTy] :
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index 3a01795ce3f53..f5620a6a7cd91 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -159,18 +159,30 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
   addArgumentMaterialization(
       [&](OpBuilder &builder, UnrankedMemRefType resultType, ValueRange inputs,
           Location loc) -> std::optional<Value> {
-        if (inputs.size() == 1)
+        if (inputs.size() == 1) {
+          // Bare pointers are not supported for unranked memrefs because a
+          // memref descriptor cannot be built just from a bare pointer.
           return std::nullopt;
+        }
         return UnrankedMemRefDescriptor::pack(builder, loc, *this, resultType,
                                               inputs);
       });
   addArgumentMaterialization([&](OpBuilder &builder, MemRefType resultType,
                                  ValueRange inputs,
                                  Location loc) -> std::optional<Value> {
-    // TODO: bare ptr conversion could be handled here but we would need a way
-    // to distinguish between FuncOp and other regions.
-    if (inputs.size() == 1)
-      return std::nullopt;
+    if (inputs.size() == 1) {
+      // This is a bare pointer. We allow bare pointers only for function entry
+      // blocks.
+      BlockArgument barePtr = dyn_cast<BlockArgument>(inputs.front());
+      if (!barePtr)
+        return std::nullopt;
+      Block *block = barePtr.getOwner();
+      if (!block->isEntryBlock() ||
+          !isa<FunctionOpInterface>(block->getParentOp()))
+        return std::nullopt;
+      return MemRefDescriptor::fromStaticShape(builder, loc, *this, resultType,
+                                               inputs[0]);
+    }
     return MemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
   });
   // Add generic source and target materializations to handle cases where



More information about the llvm-branch-commits mailing list