[Mlir-commits] [mlir] [mlir][memref] Move `AllocLikeConversion.h` helpers into `MemRefToLLVM.cpp` (PR #136424)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat Apr 19 04:26:26 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

<details>
<summary>Changes</summary>

This commit moves code around: the helper functions/classes are moved into `MemRefToLLVM.cpp`. This simplifies the code a bit: fewer templated functions and fewer function calls.

This commit also moves the checks in `matchAndRewrite` to the beginning of the function, so that patterns bail out before modifying any IR. This is in preparation for the One-Shot Dialect Conversion refactoring, which will disallow pattern rollbacks.


---

The patch is 35.56 KiB and has been truncated to 20.00 KiB below. Full version: https://github.com/llvm/llvm-project/pull/136424.diff


4 Files Affected:

- (removed) mlir/include/mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h (-153) 
- (removed) mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp (-195) 
- (modified) mlir/lib/Conversion/MemRefToLLVM/CMakeLists.txt (-1) 
- (modified) mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp (+301-56) 


``````````diff
diff --git a/mlir/include/mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h b/mlir/include/mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h
deleted file mode 100644
index 8bf04219c759a..0000000000000
--- a/mlir/include/mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h
+++ /dev/null
@@ -1,153 +0,0 @@
-//===- AllocLikeConversion.h - Convert allocation ops to LLVM ---*- C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_CONVERSION_MEMREFTOLLVM_ALLOCLIKECONVERSION_H
-#define MLIR_CONVERSION_MEMREFTOLLVM_ALLOCLIKECONVERSION_H
-
-#include "mlir/Conversion/LLVMCommon/Pattern.h"
-
-namespace mlir {
-
-/// Lowering for memory allocation ops.
-struct AllocationOpLLVMLowering : public ConvertToLLVMPattern {
-  using ConvertToLLVMPattern::createIndexAttrConstant;
-  using ConvertToLLVMPattern::getIndexType;
-  using ConvertToLLVMPattern::getVoidPtrType;
-
-  explicit AllocationOpLLVMLowering(StringRef opName,
-                                    const LLVMTypeConverter &converter,
-                                    PatternBenefit benefit = 1)
-      : ConvertToLLVMPattern(opName, &converter.getContext(), converter,
-                             benefit) {}
-
-protected:
-  /// Computes the aligned value for 'input' as follows:
-  ///   bumped = input + alignement - 1
-  ///   aligned = bumped - bumped % alignment
-  static Value createAligned(ConversionPatternRewriter &rewriter, Location loc,
-                             Value input, Value alignment);
-
-  static MemRefType getMemRefResultType(Operation *op) {
-    return cast<MemRefType>(op->getResult(0).getType());
-  }
-
-  /// Computes the alignment for the given memory allocation op.
-  template <typename OpType>
-  Value getAlignment(ConversionPatternRewriter &rewriter, Location loc,
-                     OpType op) const {
-    MemRefType memRefType = op.getType();
-    Value alignment;
-    if (auto alignmentAttr = op.getAlignment()) {
-      Type indexType = getIndexType();
-      alignment =
-          createIndexAttrConstant(rewriter, loc, indexType, *alignmentAttr);
-    } else if (!memRefType.getElementType().isSignlessIntOrIndexOrFloat()) {
-      // In the case where no alignment is specified, we may want to override
-      // `malloc's` behavior. `malloc` typically aligns at the size of the
-      // biggest scalar on a target HW. For non-scalars, use the natural
-      // alignment of the LLVM type given by the LLVM DataLayout.
-      alignment = getSizeInBytes(loc, memRefType.getElementType(), rewriter);
-    }
-    return alignment;
-  }
-
-  /// Computes the alignment for aligned_alloc used to allocate the buffer for
-  /// the memory allocation op.
-  ///
-  /// Aligned_alloc requires the allocation size to be a power of two, and the
-  /// allocation size to be a multiple of the alignment.
-  template <typename OpType>
-  int64_t alignedAllocationGetAlignment(ConversionPatternRewriter &rewriter,
-                                        Location loc, OpType op,
-                                        const DataLayout *defaultLayout) const {
-    if (std::optional<uint64_t> alignment = op.getAlignment())
-      return *alignment;
-
-    // Whenever we don't have alignment set, we will use an alignment
-    // consistent with the element type; since the allocation size has to be a
-    // power of two, we will bump to the next power of two if it isn't.
-    unsigned eltSizeBytes =
-        getMemRefEltSizeInBytes(op.getType(), op, defaultLayout);
-    return std::max(kMinAlignedAllocAlignment,
-                    llvm::PowerOf2Ceil(eltSizeBytes));
-  }
-
-  /// Allocates a memory buffer using an allocation method that doesn't
-  /// guarantee alignment. Returns the pointer and its aligned value.
-  std::tuple<Value, Value>
-  allocateBufferManuallyAlign(ConversionPatternRewriter &rewriter, Location loc,
-                              Value sizeBytes, Operation *op,
-                              Value alignment) const;
-
-  /// Allocates a memory buffer using an aligned allocation method.
-  Value allocateBufferAutoAlign(ConversionPatternRewriter &rewriter,
-                                Location loc, Value sizeBytes, Operation *op,
-                                const DataLayout *defaultLayout,
-                                int64_t alignment) const;
-
-private:
-  /// Computes the byte size for the MemRef element type.
-  unsigned getMemRefEltSizeInBytes(MemRefType memRefType, Operation *op,
-                                   const DataLayout *defaultLayout) const;
-
-  /// Returns true if the memref size in bytes is known to be a multiple of
-  /// factor.
-  bool isMemRefSizeMultipleOf(MemRefType type, uint64_t factor, Operation *op,
-                              const DataLayout *defaultLayout) const;
-
-  /// The minimum alignment to use with aligned_alloc (has to be a power of 2).
-  static constexpr uint64_t kMinAlignedAllocAlignment = 16UL;
-};
-
-/// Lowering for AllocOp and AllocaOp.
-struct AllocLikeOpLLVMLowering : public AllocationOpLLVMLowering {
-  explicit AllocLikeOpLLVMLowering(StringRef opName,
-                                   const LLVMTypeConverter &converter,
-                                   PatternBenefit benefit = 1)
-      : AllocationOpLLVMLowering(opName, converter, benefit) {}
-
-protected:
-  /// Allocates the underlying buffer. Returns the allocated pointer and the
-  /// aligned pointer.
-  virtual std::tuple<Value, Value>
-  allocateBuffer(ConversionPatternRewriter &rewriter, Location loc, Value size,
-                 Operation *op) const = 0;
-
-  /// Sets the flag 'requiresNumElements', specifying the Op requires the number
-  /// of elements instead of the size in bytes.
-  void setRequiresNumElements();
-
-private:
-  // An `alloc` is converted into a definition of a memref descriptor value and
-  // a call to `malloc` to allocate the underlying data buffer.  The memref
-  // descriptor is of the LLVM structure type where:
-  //   1. the first element is a pointer to the allocated (typed) data buffer,
-  //   2. the second element is a pointer to the (typed) payload, aligned to the
-  //      specified alignment,
-  //   3. the remaining elements serve to store all the sizes and strides of the
-  //      memref using LLVM-converted `index` type.
-  //
-  // Alignment is performed by allocating `alignment` more bytes than
-  // requested and shifting the aligned pointer relative to the allocated
-  // memory. Note: `alignment - <minimum malloc alignment>` would actually be
-  // sufficient. If alignment is unspecified, the two pointers are equal.
-
-  // An `alloca` is converted into a definition of a memref descriptor value and
-  // an llvm.alloca to allocate the underlying data buffer.
-  LogicalResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const override;
-
-  // Flag for specifying the Op requires the number of elements instead of the
-  // size in bytes.
-  bool requiresNumElements = false;
-};
-
-} // namespace mlir
-
-#endif // MLIR_CONVERSION_MEMREFTOLLVM_ALLOCLIKECONVERSION_H
diff --git a/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp b/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
deleted file mode 100644
index bad209a4ddecf..0000000000000
--- a/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
+++ /dev/null
@@ -1,195 +0,0 @@
-//===- AllocLikeConversion.cpp - LLVM conversion for alloc operations -----===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h"
-#include "mlir/Analysis/DataLayoutAnalysis.h"
-#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
-#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
-#include "mlir/IR/SymbolTable.h"
-
-using namespace mlir;
-
-static FailureOr<LLVM::LLVMFuncOp>
-getNotalignedAllocFn(const LLVMTypeConverter *typeConverter, Operation *module,
-                     Type indexType) {
-  bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
-  if (useGenericFn)
-    return LLVM::lookupOrCreateGenericAllocFn(module, indexType);
-
-  return LLVM::lookupOrCreateMallocFn(module, indexType);
-}
-
-static FailureOr<LLVM::LLVMFuncOp>
-getAlignedAllocFn(const LLVMTypeConverter *typeConverter, Operation *module,
-                  Type indexType) {
-  bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
-
-  if (useGenericFn)
-    return LLVM::lookupOrCreateGenericAlignedAllocFn(module, indexType);
-
-  return LLVM::lookupOrCreateAlignedAllocFn(module, indexType);
-}
-
-Value AllocationOpLLVMLowering::createAligned(
-    ConversionPatternRewriter &rewriter, Location loc, Value input,
-    Value alignment) {
-  Value one = createIndexAttrConstant(rewriter, loc, alignment.getType(), 1);
-  Value bump = rewriter.create<LLVM::SubOp>(loc, alignment, one);
-  Value bumped = rewriter.create<LLVM::AddOp>(loc, input, bump);
-  Value mod = rewriter.create<LLVM::URemOp>(loc, bumped, alignment);
-  return rewriter.create<LLVM::SubOp>(loc, bumped, mod);
-}
-
-static Value castAllocFuncResult(ConversionPatternRewriter &rewriter,
-                                 Location loc, Value allocatedPtr,
-                                 MemRefType memRefType, Type elementPtrType,
-                                 const LLVMTypeConverter &typeConverter) {
-  auto allocatedPtrTy = cast<LLVM::LLVMPointerType>(allocatedPtr.getType());
-  FailureOr<unsigned> maybeMemrefAddrSpace =
-      typeConverter.getMemRefAddressSpace(memRefType);
-  if (failed(maybeMemrefAddrSpace))
-    return Value();
-  unsigned memrefAddrSpace = *maybeMemrefAddrSpace;
-  if (allocatedPtrTy.getAddressSpace() != memrefAddrSpace)
-    allocatedPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
-        loc, LLVM::LLVMPointerType::get(rewriter.getContext(), memrefAddrSpace),
-        allocatedPtr);
-  return allocatedPtr;
-}
-
-std::tuple<Value, Value> AllocationOpLLVMLowering::allocateBufferManuallyAlign(
-    ConversionPatternRewriter &rewriter, Location loc, Value sizeBytes,
-    Operation *op, Value alignment) const {
-  if (alignment) {
-    // Adjust the allocation size to consider alignment.
-    sizeBytes = rewriter.create<LLVM::AddOp>(loc, sizeBytes, alignment);
-  }
-
-  MemRefType memRefType = getMemRefResultType(op);
-  // Allocate the underlying buffer.
-  Type elementPtrType = this->getElementPtrType(memRefType);
-  assert(elementPtrType && "could not compute element ptr type");
-  FailureOr<LLVM::LLVMFuncOp> allocFuncOp = getNotalignedAllocFn(
-      getTypeConverter(), op->getParentWithTrait<OpTrait::SymbolTable>(),
-      getIndexType());
-  if (failed(allocFuncOp))
-    return std::make_tuple(Value(), Value());
-  auto results =
-      rewriter.create<LLVM::CallOp>(loc, allocFuncOp.value(), sizeBytes);
-
-  Value allocatedPtr =
-      castAllocFuncResult(rewriter, loc, results.getResult(), memRefType,
-                          elementPtrType, *getTypeConverter());
-  if (!allocatedPtr)
-    return std::make_tuple(Value(), Value());
-  Value alignedPtr = allocatedPtr;
-  if (alignment) {
-    // Compute the aligned pointer.
-    Value allocatedInt =
-        rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), allocatedPtr);
-    Value alignmentInt = createAligned(rewriter, loc, allocatedInt, alignment);
-    alignedPtr =
-        rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, alignmentInt);
-  }
-
-  return std::make_tuple(allocatedPtr, alignedPtr);
-}
-
-unsigned AllocationOpLLVMLowering::getMemRefEltSizeInBytes(
-    MemRefType memRefType, Operation *op,
-    const DataLayout *defaultLayout) const {
-  const DataLayout *layout = defaultLayout;
-  if (const DataLayoutAnalysis *analysis =
-          getTypeConverter()->getDataLayoutAnalysis()) {
-    layout = &analysis->getAbove(op);
-  }
-  Type elementType = memRefType.getElementType();
-  if (auto memRefElementType = dyn_cast<MemRefType>(elementType))
-    return getTypeConverter()->getMemRefDescriptorSize(memRefElementType,
-                                                       *layout);
-  if (auto memRefElementType = dyn_cast<UnrankedMemRefType>(elementType))
-    return getTypeConverter()->getUnrankedMemRefDescriptorSize(
-        memRefElementType, *layout);
-  return layout->getTypeSize(elementType);
-}
-
-bool AllocationOpLLVMLowering::isMemRefSizeMultipleOf(
-    MemRefType type, uint64_t factor, Operation *op,
-    const DataLayout *defaultLayout) const {
-  uint64_t sizeDivisor = getMemRefEltSizeInBytes(type, op, defaultLayout);
-  for (unsigned i = 0, e = type.getRank(); i < e; i++) {
-    if (type.isDynamicDim(i))
-      continue;
-    sizeDivisor = sizeDivisor * type.getDimSize(i);
-  }
-  return sizeDivisor % factor == 0;
-}
-
-Value AllocationOpLLVMLowering::allocateBufferAutoAlign(
-    ConversionPatternRewriter &rewriter, Location loc, Value sizeBytes,
-    Operation *op, const DataLayout *defaultLayout, int64_t alignment) const {
-  Value allocAlignment =
-      createIndexAttrConstant(rewriter, loc, getIndexType(), alignment);
-
-  MemRefType memRefType = getMemRefResultType(op);
-  // Function aligned_alloc requires size to be a multiple of alignment; we pad
-  // the size to the next multiple if necessary.
-  if (!isMemRefSizeMultipleOf(memRefType, alignment, op, defaultLayout))
-    sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment);
-
-  Type elementPtrType = this->getElementPtrType(memRefType);
-  FailureOr<LLVM::LLVMFuncOp> allocFuncOp = getAlignedAllocFn(
-      getTypeConverter(), op->getParentWithTrait<OpTrait::SymbolTable>(),
-      getIndexType());
-  if (failed(allocFuncOp))
-    return Value();
-  auto results = rewriter.create<LLVM::CallOp>(
-      loc, allocFuncOp.value(), ValueRange({allocAlignment, sizeBytes}));
-
-  return castAllocFuncResult(rewriter, loc, results.getResult(), memRefType,
-                             elementPtrType, *getTypeConverter());
-}
-
-void AllocLikeOpLLVMLowering::setRequiresNumElements() {
-  requiresNumElements = true;
-}
-
-LogicalResult AllocLikeOpLLVMLowering::matchAndRewrite(
-    Operation *op, ArrayRef<Value> operands,
-    ConversionPatternRewriter &rewriter) const {
-  MemRefType memRefType = getMemRefResultType(op);
-  if (!isConvertibleAndHasIdentityMaps(memRefType))
-    return rewriter.notifyMatchFailure(op, "incompatible memref type");
-  auto loc = op->getLoc();
-
-  // Get actual sizes of the memref as values: static sizes are constant
-  // values and dynamic sizes are passed to 'alloc' as operands.  In case of
-  // zero-dimensional memref, assume a scalar (size 1).
-  SmallVector<Value, 4> sizes;
-  SmallVector<Value, 4> strides;
-  Value size;
-
-  this->getMemRefDescriptorSizes(loc, memRefType, operands, rewriter, sizes,
-                                 strides, size, !requiresNumElements);
-
-  // Allocate the underlying buffer.
-  auto [allocatedPtr, alignedPtr] =
-      this->allocateBuffer(rewriter, loc, size, op);
-
-  if (!allocatedPtr || !alignedPtr)
-    return rewriter.notifyMatchFailure(loc,
-                                       "underlying buffer allocation failed");
-
-  // Create the MemRef descriptor.
-  auto memRefDescriptor = this->createMemRefDescriptor(
-      loc, memRefType, allocatedPtr, alignedPtr, sizes, strides, rewriter);
-
-  // Return the final value of the descriptor.
-  rewriter.replaceOp(op, {memRefDescriptor});
-  return success();
-}
diff --git a/mlir/lib/Conversion/MemRefToLLVM/CMakeLists.txt b/mlir/lib/Conversion/MemRefToLLVM/CMakeLists.txt
index f0d95f5ada290..9da4b23d42f41 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/CMakeLists.txt
+++ b/mlir/lib/Conversion/MemRefToLLVM/CMakeLists.txt
@@ -1,5 +1,4 @@
 add_mlir_conversion_library(MLIRMemRefToLLVM
-  AllocLikeConversion.cpp
   MemRefToLLVM.cpp
 
   ADDITIONAL_HEADER_DIRS
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index cb4317ef1bcec..91a95bcdef465 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -13,7 +13,6 @@
 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
 #include "mlir/Conversion/LLVMCommon/Pattern.h"
 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
-#include "mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
@@ -52,33 +51,247 @@ getFreeFn(const LLVMTypeConverter *typeConverter, ModuleOp module) {
   return LLVM::lookupOrCreateFreeFn(module);
 }
 
-struct AllocOpLowering : public AllocLikeOpLLVMLowering {
-  AllocOpLowering(const LLVMTypeConverter &converter)
-      : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
-                                converter) {}
-  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
-                                          Location loc, Value sizeBytes,
-                                          Operation *op) const override {
-    return allocateBufferManuallyAlign(
-        rewriter, loc, sizeBytes, op,
-        getAlignment(rewriter, loc, cast<memref::AllocOp>(op)));
+static FailureOr<LLVM::LLVMFuncOp>
+getNotalignedAllocFn(const LLVMTypeConverter *typeConverter, Operation *module,
+                     Type indexType) {
+  bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
+  if (useGenericFn)
+    return LLVM::lookupOrCreateGenericAllocFn(module, indexType);
+
+  return LLVM::lookupOrCreateMallocFn(module, indexType);
+}
+
+static FailureOr<LLVM::LLVMFuncOp>
+getAlignedAllocFn(const LLVMTypeConverter *typeConverter, Operation *module,
+                  Type indexType) {
+  bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
+
+  if (useGenericFn)
+    return LLVM::lookupOrCreateGenericAlignedAllocFn(module, indexType);
+
+  return LLVM::lookupOrCreateAlignedAllocFn(module, indexType);
+}
+
+/// Computes the aligned value for 'input' as follows:
+///   bumped = input + alignement - 1
+///   aligned = bumped - bumped % alignment
+static Value createAligned(ConversionPatternRewriter &rewriter, Location loc,
+                           Value input, Value alignment) {
+  Value one = rewriter.create<LLVM::ConstantOp>(loc, alignment.getType(),
+                                                rewriter.getIndexAttr(1));
+  Value bump = rewriter.create<LLVM::SubOp>(loc, alignment, one);
+  Value bumped = rewriter.create<LLVM::AddOp>(loc, input, bump);
+  Value mod = rewriter.create<LLVM::URemOp>(loc, bumped, alignment);
+  return rewriter.create<LLVM::SubOp>(loc, bumped, mod);
+}
+
+/// Computes the byte size for the MemRef element type.
+static unsigned getMemRefEltSizeInBytes(const LLVMTypeConverter *typeConverter,
+                                        MemRefType memRefType, Operation *op,
+                                        const DataLayout *defaultLayout) {
+  const DataLayout *layout = defaultLayout;
+  if (const DataLayoutAnalysis *analysis =
+          typeConverter->getDataLayoutAnalysis()) {
+    layout = &analysis->getAbove(op);
+  }
+  Type elementType = memRefType.getElementType();
+  if (auto memRefElementType = dyn_cast<MemRefType>(elementType))
+    return typeConverter->getMemRefDescriptorSize(memRefElementType, *layout);
+  if (auto memRefElementType = dyn_cast<UnrankedMemRefType>(elementType))
+    return typeConverter->getUnrankedMemRefDescriptorSize(memRefElementType,
+                                                          *layout);
+  return layout->getTypeSize(elementType);
+}
+
+static Value castAllocFuncResult(ConversionPatternRewriter &rewriter,
+                                 Location loc, Value allocatedPtr,
+                                 MemRefType memRefType, Type elementPtrType,
+                                 const LLVMTypeConverter &typeConverter) {
+  auto allocatedPtrTy = ...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/136424


More information about the Mlir-commits mailing list