[Mlir-commits] [mlir] [mlir][memref] Move `AllocLikeConversion.h` helpers into `MemRefToLLVM.cpp` (PR #136424)

Matthias Springer llvmlistbot at llvm.org
Sat Apr 19 04:25:53 PDT 2025


https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/136424

This commit moves code around: the helper functions/classes from `AllocLikeConversion.h`/`AllocLikeConversion.cpp` are moved into `MemRefToLLVM.cpp`. This simplifies the code a bit: fewer templated functions and fewer function calls.

This commit also moves the checks in `matchAndRewrite` to the beginning of the functions, so that patterns bail out before they start modifying any IR. This is in preparation for the One-Shot Dialect Conversion refactoring, which will disallow pattern rollbacks.


>From 4753aa484b16026063ea466e881fde266aff7cbc Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Sat, 19 Apr 2025 12:37:23 +0200
Subject: [PATCH] [mlir][memref] Move `AllocLikeConversion.h` helpers into `MemRefToLLVM.cpp`

---
 .../MemRefToLLVM/AllocLikeConversion.h        | 153 --------
 .../MemRefToLLVM/AllocLikeConversion.cpp      | 195 ----------
 .../Conversion/MemRefToLLVM/CMakeLists.txt    |   1 -
 .../Conversion/MemRefToLLVM/MemRefToLLVM.cpp  | 357 +++++++++++++++---
 4 files changed, 301 insertions(+), 405 deletions(-)
 delete mode 100644 mlir/include/mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h
 delete mode 100644 mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp

diff --git a/mlir/include/mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h b/mlir/include/mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h
deleted file mode 100644
index 8bf04219c759a..0000000000000
--- a/mlir/include/mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h
+++ /dev/null
@@ -1,153 +0,0 @@
-//===- AllocLikeConversion.h - Convert allocation ops to LLVM ---*- C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_CONVERSION_MEMREFTOLLVM_ALLOCLIKECONVERSION_H
-#define MLIR_CONVERSION_MEMREFTOLLVM_ALLOCLIKECONVERSION_H
-
-#include "mlir/Conversion/LLVMCommon/Pattern.h"
-
-namespace mlir {
-
-/// Lowering for memory allocation ops.
-struct AllocationOpLLVMLowering : public ConvertToLLVMPattern {
-  using ConvertToLLVMPattern::createIndexAttrConstant;
-  using ConvertToLLVMPattern::getIndexType;
-  using ConvertToLLVMPattern::getVoidPtrType;
-
-  explicit AllocationOpLLVMLowering(StringRef opName,
-                                    const LLVMTypeConverter &converter,
-                                    PatternBenefit benefit = 1)
-      : ConvertToLLVMPattern(opName, &converter.getContext(), converter,
-                             benefit) {}
-
-protected:
-  /// Computes the aligned value for 'input' as follows:
-  ///   bumped = input + alignement - 1
-  ///   aligned = bumped - bumped % alignment
-  static Value createAligned(ConversionPatternRewriter &rewriter, Location loc,
-                             Value input, Value alignment);
-
-  static MemRefType getMemRefResultType(Operation *op) {
-    return cast<MemRefType>(op->getResult(0).getType());
-  }
-
-  /// Computes the alignment for the given memory allocation op.
-  template <typename OpType>
-  Value getAlignment(ConversionPatternRewriter &rewriter, Location loc,
-                     OpType op) const {
-    MemRefType memRefType = op.getType();
-    Value alignment;
-    if (auto alignmentAttr = op.getAlignment()) {
-      Type indexType = getIndexType();
-      alignment =
-          createIndexAttrConstant(rewriter, loc, indexType, *alignmentAttr);
-    } else if (!memRefType.getElementType().isSignlessIntOrIndexOrFloat()) {
-      // In the case where no alignment is specified, we may want to override
-      // `malloc's` behavior. `malloc` typically aligns at the size of the
-      // biggest scalar on a target HW. For non-scalars, use the natural
-      // alignment of the LLVM type given by the LLVM DataLayout.
-      alignment = getSizeInBytes(loc, memRefType.getElementType(), rewriter);
-    }
-    return alignment;
-  }
-
-  /// Computes the alignment for aligned_alloc used to allocate the buffer for
-  /// the memory allocation op.
-  ///
-  /// Aligned_alloc requires the allocation size to be a power of two, and the
-  /// allocation size to be a multiple of the alignment.
-  template <typename OpType>
-  int64_t alignedAllocationGetAlignment(ConversionPatternRewriter &rewriter,
-                                        Location loc, OpType op,
-                                        const DataLayout *defaultLayout) const {
-    if (std::optional<uint64_t> alignment = op.getAlignment())
-      return *alignment;
-
-    // Whenever we don't have alignment set, we will use an alignment
-    // consistent with the element type; since the allocation size has to be a
-    // power of two, we will bump to the next power of two if it isn't.
-    unsigned eltSizeBytes =
-        getMemRefEltSizeInBytes(op.getType(), op, defaultLayout);
-    return std::max(kMinAlignedAllocAlignment,
-                    llvm::PowerOf2Ceil(eltSizeBytes));
-  }
-
-  /// Allocates a memory buffer using an allocation method that doesn't
-  /// guarantee alignment. Returns the pointer and its aligned value.
-  std::tuple<Value, Value>
-  allocateBufferManuallyAlign(ConversionPatternRewriter &rewriter, Location loc,
-                              Value sizeBytes, Operation *op,
-                              Value alignment) const;
-
-  /// Allocates a memory buffer using an aligned allocation method.
-  Value allocateBufferAutoAlign(ConversionPatternRewriter &rewriter,
-                                Location loc, Value sizeBytes, Operation *op,
-                                const DataLayout *defaultLayout,
-                                int64_t alignment) const;
-
-private:
-  /// Computes the byte size for the MemRef element type.
-  unsigned getMemRefEltSizeInBytes(MemRefType memRefType, Operation *op,
-                                   const DataLayout *defaultLayout) const;
-
-  /// Returns true if the memref size in bytes is known to be a multiple of
-  /// factor.
-  bool isMemRefSizeMultipleOf(MemRefType type, uint64_t factor, Operation *op,
-                              const DataLayout *defaultLayout) const;
-
-  /// The minimum alignment to use with aligned_alloc (has to be a power of 2).
-  static constexpr uint64_t kMinAlignedAllocAlignment = 16UL;
-};
-
-/// Lowering for AllocOp and AllocaOp.
-struct AllocLikeOpLLVMLowering : public AllocationOpLLVMLowering {
-  explicit AllocLikeOpLLVMLowering(StringRef opName,
-                                   const LLVMTypeConverter &converter,
-                                   PatternBenefit benefit = 1)
-      : AllocationOpLLVMLowering(opName, converter, benefit) {}
-
-protected:
-  /// Allocates the underlying buffer. Returns the allocated pointer and the
-  /// aligned pointer.
-  virtual std::tuple<Value, Value>
-  allocateBuffer(ConversionPatternRewriter &rewriter, Location loc, Value size,
-                 Operation *op) const = 0;
-
-  /// Sets the flag 'requiresNumElements', specifying the Op requires the number
-  /// of elements instead of the size in bytes.
-  void setRequiresNumElements();
-
-private:
-  // An `alloc` is converted into a definition of a memref descriptor value and
-  // a call to `malloc` to allocate the underlying data buffer.  The memref
-  // descriptor is of the LLVM structure type where:
-  //   1. the first element is a pointer to the allocated (typed) data buffer,
-  //   2. the second element is a pointer to the (typed) payload, aligned to the
-  //      specified alignment,
-  //   3. the remaining elements serve to store all the sizes and strides of the
-  //      memref using LLVM-converted `index` type.
-  //
-  // Alignment is performed by allocating `alignment` more bytes than
-  // requested and shifting the aligned pointer relative to the allocated
-  // memory. Note: `alignment - <minimum malloc alignment>` would actually be
-  // sufficient. If alignment is unspecified, the two pointers are equal.
-
-  // An `alloca` is converted into a definition of a memref descriptor value and
-  // an llvm.alloca to allocate the underlying data buffer.
-  LogicalResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const override;
-
-  // Flag for specifying the Op requires the number of elements instead of the
-  // size in bytes.
-  bool requiresNumElements = false;
-};
-
-} // namespace mlir
-
-#endif // MLIR_CONVERSION_MEMREFTOLLVM_ALLOCLIKECONVERSION_H
diff --git a/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp b/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
deleted file mode 100644
index bad209a4ddecf..0000000000000
--- a/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
+++ /dev/null
@@ -1,195 +0,0 @@
-//===- AllocLikeConversion.cpp - LLVM conversion for alloc operations -----===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h"
-#include "mlir/Analysis/DataLayoutAnalysis.h"
-#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
-#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
-#include "mlir/IR/SymbolTable.h"
-
-using namespace mlir;
-
-static FailureOr<LLVM::LLVMFuncOp>
-getNotalignedAllocFn(const LLVMTypeConverter *typeConverter, Operation *module,
-                     Type indexType) {
-  bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
-  if (useGenericFn)
-    return LLVM::lookupOrCreateGenericAllocFn(module, indexType);
-
-  return LLVM::lookupOrCreateMallocFn(module, indexType);
-}
-
-static FailureOr<LLVM::LLVMFuncOp>
-getAlignedAllocFn(const LLVMTypeConverter *typeConverter, Operation *module,
-                  Type indexType) {
-  bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
-
-  if (useGenericFn)
-    return LLVM::lookupOrCreateGenericAlignedAllocFn(module, indexType);
-
-  return LLVM::lookupOrCreateAlignedAllocFn(module, indexType);
-}
-
-Value AllocationOpLLVMLowering::createAligned(
-    ConversionPatternRewriter &rewriter, Location loc, Value input,
-    Value alignment) {
-  Value one = createIndexAttrConstant(rewriter, loc, alignment.getType(), 1);
-  Value bump = rewriter.create<LLVM::SubOp>(loc, alignment, one);
-  Value bumped = rewriter.create<LLVM::AddOp>(loc, input, bump);
-  Value mod = rewriter.create<LLVM::URemOp>(loc, bumped, alignment);
-  return rewriter.create<LLVM::SubOp>(loc, bumped, mod);
-}
-
-static Value castAllocFuncResult(ConversionPatternRewriter &rewriter,
-                                 Location loc, Value allocatedPtr,
-                                 MemRefType memRefType, Type elementPtrType,
-                                 const LLVMTypeConverter &typeConverter) {
-  auto allocatedPtrTy = cast<LLVM::LLVMPointerType>(allocatedPtr.getType());
-  FailureOr<unsigned> maybeMemrefAddrSpace =
-      typeConverter.getMemRefAddressSpace(memRefType);
-  if (failed(maybeMemrefAddrSpace))
-    return Value();
-  unsigned memrefAddrSpace = *maybeMemrefAddrSpace;
-  if (allocatedPtrTy.getAddressSpace() != memrefAddrSpace)
-    allocatedPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
-        loc, LLVM::LLVMPointerType::get(rewriter.getContext(), memrefAddrSpace),
-        allocatedPtr);
-  return allocatedPtr;
-}
-
-std::tuple<Value, Value> AllocationOpLLVMLowering::allocateBufferManuallyAlign(
-    ConversionPatternRewriter &rewriter, Location loc, Value sizeBytes,
-    Operation *op, Value alignment) const {
-  if (alignment) {
-    // Adjust the allocation size to consider alignment.
-    sizeBytes = rewriter.create<LLVM::AddOp>(loc, sizeBytes, alignment);
-  }
-
-  MemRefType memRefType = getMemRefResultType(op);
-  // Allocate the underlying buffer.
-  Type elementPtrType = this->getElementPtrType(memRefType);
-  assert(elementPtrType && "could not compute element ptr type");
-  FailureOr<LLVM::LLVMFuncOp> allocFuncOp = getNotalignedAllocFn(
-      getTypeConverter(), op->getParentWithTrait<OpTrait::SymbolTable>(),
-      getIndexType());
-  if (failed(allocFuncOp))
-    return std::make_tuple(Value(), Value());
-  auto results =
-      rewriter.create<LLVM::CallOp>(loc, allocFuncOp.value(), sizeBytes);
-
-  Value allocatedPtr =
-      castAllocFuncResult(rewriter, loc, results.getResult(), memRefType,
-                          elementPtrType, *getTypeConverter());
-  if (!allocatedPtr)
-    return std::make_tuple(Value(), Value());
-  Value alignedPtr = allocatedPtr;
-  if (alignment) {
-    // Compute the aligned pointer.
-    Value allocatedInt =
-        rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), allocatedPtr);
-    Value alignmentInt = createAligned(rewriter, loc, allocatedInt, alignment);
-    alignedPtr =
-        rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, alignmentInt);
-  }
-
-  return std::make_tuple(allocatedPtr, alignedPtr);
-}
-
-unsigned AllocationOpLLVMLowering::getMemRefEltSizeInBytes(
-    MemRefType memRefType, Operation *op,
-    const DataLayout *defaultLayout) const {
-  const DataLayout *layout = defaultLayout;
-  if (const DataLayoutAnalysis *analysis =
-          getTypeConverter()->getDataLayoutAnalysis()) {
-    layout = &analysis->getAbove(op);
-  }
-  Type elementType = memRefType.getElementType();
-  if (auto memRefElementType = dyn_cast<MemRefType>(elementType))
-    return getTypeConverter()->getMemRefDescriptorSize(memRefElementType,
-                                                       *layout);
-  if (auto memRefElementType = dyn_cast<UnrankedMemRefType>(elementType))
-    return getTypeConverter()->getUnrankedMemRefDescriptorSize(
-        memRefElementType, *layout);
-  return layout->getTypeSize(elementType);
-}
-
-bool AllocationOpLLVMLowering::isMemRefSizeMultipleOf(
-    MemRefType type, uint64_t factor, Operation *op,
-    const DataLayout *defaultLayout) const {
-  uint64_t sizeDivisor = getMemRefEltSizeInBytes(type, op, defaultLayout);
-  for (unsigned i = 0, e = type.getRank(); i < e; i++) {
-    if (type.isDynamicDim(i))
-      continue;
-    sizeDivisor = sizeDivisor * type.getDimSize(i);
-  }
-  return sizeDivisor % factor == 0;
-}
-
-Value AllocationOpLLVMLowering::allocateBufferAutoAlign(
-    ConversionPatternRewriter &rewriter, Location loc, Value sizeBytes,
-    Operation *op, const DataLayout *defaultLayout, int64_t alignment) const {
-  Value allocAlignment =
-      createIndexAttrConstant(rewriter, loc, getIndexType(), alignment);
-
-  MemRefType memRefType = getMemRefResultType(op);
-  // Function aligned_alloc requires size to be a multiple of alignment; we pad
-  // the size to the next multiple if necessary.
-  if (!isMemRefSizeMultipleOf(memRefType, alignment, op, defaultLayout))
-    sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment);
-
-  Type elementPtrType = this->getElementPtrType(memRefType);
-  FailureOr<LLVM::LLVMFuncOp> allocFuncOp = getAlignedAllocFn(
-      getTypeConverter(), op->getParentWithTrait<OpTrait::SymbolTable>(),
-      getIndexType());
-  if (failed(allocFuncOp))
-    return Value();
-  auto results = rewriter.create<LLVM::CallOp>(
-      loc, allocFuncOp.value(), ValueRange({allocAlignment, sizeBytes}));
-
-  return castAllocFuncResult(rewriter, loc, results.getResult(), memRefType,
-                             elementPtrType, *getTypeConverter());
-}
-
-void AllocLikeOpLLVMLowering::setRequiresNumElements() {
-  requiresNumElements = true;
-}
-
-LogicalResult AllocLikeOpLLVMLowering::matchAndRewrite(
-    Operation *op, ArrayRef<Value> operands,
-    ConversionPatternRewriter &rewriter) const {
-  MemRefType memRefType = getMemRefResultType(op);
-  if (!isConvertibleAndHasIdentityMaps(memRefType))
-    return rewriter.notifyMatchFailure(op, "incompatible memref type");
-  auto loc = op->getLoc();
-
-  // Get actual sizes of the memref as values: static sizes are constant
-  // values and dynamic sizes are passed to 'alloc' as operands.  In case of
-  // zero-dimensional memref, assume a scalar (size 1).
-  SmallVector<Value, 4> sizes;
-  SmallVector<Value, 4> strides;
-  Value size;
-
-  this->getMemRefDescriptorSizes(loc, memRefType, operands, rewriter, sizes,
-                                 strides, size, !requiresNumElements);
-
-  // Allocate the underlying buffer.
-  auto [allocatedPtr, alignedPtr] =
-      this->allocateBuffer(rewriter, loc, size, op);
-
-  if (!allocatedPtr || !alignedPtr)
-    return rewriter.notifyMatchFailure(loc,
-                                       "underlying buffer allocation failed");
-
-  // Create the MemRef descriptor.
-  auto memRefDescriptor = this->createMemRefDescriptor(
-      loc, memRefType, allocatedPtr, alignedPtr, sizes, strides, rewriter);
-
-  // Return the final value of the descriptor.
-  rewriter.replaceOp(op, {memRefDescriptor});
-  return success();
-}
diff --git a/mlir/lib/Conversion/MemRefToLLVM/CMakeLists.txt b/mlir/lib/Conversion/MemRefToLLVM/CMakeLists.txt
index f0d95f5ada290..9da4b23d42f41 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/CMakeLists.txt
+++ b/mlir/lib/Conversion/MemRefToLLVM/CMakeLists.txt
@@ -1,5 +1,4 @@
 add_mlir_conversion_library(MLIRMemRefToLLVM
-  AllocLikeConversion.cpp
   MemRefToLLVM.cpp
 
   ADDITIONAL_HEADER_DIRS
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index cb4317ef1bcec..91a95bcdef465 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -13,7 +13,6 @@
 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
 #include "mlir/Conversion/LLVMCommon/Pattern.h"
 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
-#include "mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
@@ -52,33 +51,247 @@ getFreeFn(const LLVMTypeConverter *typeConverter, ModuleOp module) {
   return LLVM::lookupOrCreateFreeFn(module);
 }
 
-struct AllocOpLowering : public AllocLikeOpLLVMLowering {
-  AllocOpLowering(const LLVMTypeConverter &converter)
-      : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
-                                converter) {}
-  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
-                                          Location loc, Value sizeBytes,
-                                          Operation *op) const override {
-    return allocateBufferManuallyAlign(
-        rewriter, loc, sizeBytes, op,
-        getAlignment(rewriter, loc, cast<memref::AllocOp>(op)));
+/// Looks up (or declares) the size-only allocation function: the generic one
+/// when `useGenericFunctions` is set, and `malloc` otherwise.
+static FailureOr<LLVM::LLVMFuncOp>
+getNotalignedAllocFn(const LLVMTypeConverter *typeConverter, Operation *module,
+                     Type indexType) {
+  if (!typeConverter->getOptions().useGenericFunctions)
+    return LLVM::lookupOrCreateMallocFn(module, indexType);
+  return LLVM::lookupOrCreateGenericAllocFn(module, indexType);
+}
+
+/// Looks up (or declares) the aligned allocation function, which is called
+/// with an alignment and a size: the generic one when `useGenericFunctions`
+/// is set, and `aligned_alloc` otherwise.
+static FailureOr<LLVM::LLVMFuncOp>
+getAlignedAllocFn(const LLVMTypeConverter *typeConverter, Operation *module,
+                  Type indexType) {
+  if (!typeConverter->getOptions().useGenericFunctions)
+    return LLVM::lookupOrCreateAlignedAllocFn(module, indexType);
+  return LLVM::lookupOrCreateGenericAlignedAllocFn(module, indexType);
+}
+
+/// Rounds 'input' up to a multiple of 'alignment' (which must be non-zero):
+///   bumped = input + alignment - 1
+///   aligned = bumped - bumped % alignment
+static Value createAligned(ConversionPatternRewriter &rewriter, Location loc,
+                           Value input, Value alignment) {
+  Value one = rewriter.create<LLVM::ConstantOp>(loc, alignment.getType(),
+                                                rewriter.getIndexAttr(1));
+  Value bump = rewriter.create<LLVM::SubOp>(loc, alignment, one);
+  Value bumped = rewriter.create<LLVM::AddOp>(loc, input, bump);
+  Value mod = rewriter.create<LLVM::URemOp>(loc, bumped, alignment);
+  return rewriter.create<LLVM::SubOp>(loc, bumped, mod);
+}
+
+/// Returns the size in bytes of one element of `memRefType`; an element that
+/// is itself a ranked or unranked memref is sized as its descriptor.
+static unsigned getMemRefEltSizeInBytes(const LLVMTypeConverter *typeConverter,
+                                        MemRefType memRefType, Operation *op,
+                                        const DataLayout *defaultLayout) {
+  // Prefer the layout the analysis reports above `op`; else use the default.
+  const DataLayoutAnalysis *analysis = typeConverter->getDataLayoutAnalysis();
+  const DataLayout &layout =
+      analysis ? analysis->getAbove(op) : *defaultLayout;
+  Type elementType = memRefType.getElementType();
+  if (auto rankedEltType = dyn_cast<MemRefType>(elementType))
+    return typeConverter->getMemRefDescriptorSize(rankedEltType, layout);
+  if (auto unrankedEltType = dyn_cast<UnrankedMemRefType>(elementType))
+    return typeConverter->getUnrankedMemRefDescriptorSize(unrankedEltType,
+                                                          layout);
+  return layout.getTypeSize(elementType);
+}
+
+/// Casts `allocatedPtr` into the address space of `memRefType` if needed.
+static Value castAllocFuncResult(ConversionPatternRewriter &rewriter,
+                                 Location loc, Value allocatedPtr,
+                                 MemRefType memRefType, Type elementPtrType,
+                                 const LLVMTypeConverter &typeConverter) {
+  auto ptrType = cast<LLVM::LLVMPointerType>(allocatedPtr.getType());
+  FailureOr<unsigned> addrSpace =
+      typeConverter.getMemRefAddressSpace(memRefType);
+  assert(succeeded(addrSpace) && "unsupported address space");
+  if (ptrType.getAddressSpace() == *addrSpace)
+    return allocatedPtr;
+  return rewriter.create<LLVM::AddrSpaceCastOp>(
+      loc, LLVM::LLVMPointerType::get(rewriter.getContext(), *addrSpace),
+      allocatedPtr);
+}
+
+/// Lowers `memref.alloc` to a call to `malloc` (or the generic allocation
+/// function), over-allocating and aligning the pointer manually if needed.
+struct AllocOpLowering : public ConvertOpToLLVMPattern<memref::AllocOp> {
+  using ConvertOpToLLVMPattern<memref::AllocOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    MemRefType memRefType = op.getType();
+    if (!isConvertibleAndHasIdentityMaps(memRefType))
+      return rewriter.notifyMatchFailure(op, "incompatible memref type");
+
+    // Get or insert alloc function into the module.
+    FailureOr<LLVM::LLVMFuncOp> allocFuncOp = getNotalignedAllocFn(
+        getTypeConverter(), op->getParentWithTrait<OpTrait::SymbolTable>(),
+        getIndexType());
+    if (failed(allocFuncOp))
+      return rewriter.notifyMatchFailure(op, "failed to get alloc function");
+
+    // Get actual sizes of the memref as values: static sizes are constant
+    // values and dynamic sizes are passed to 'alloc' as operands.  In case of
+    // zero-dimensional memref, assume a scalar (size 1).
+    SmallVector<Value, 4> sizes;
+    SmallVector<Value, 4> strides;
+    Value sizeBytes;
+
+    getMemRefDescriptorSizes(loc, memRefType, adaptor.getOperands(), rewriter,
+                             sizes, strides, sizeBytes, /*sizeInBytes=*/true);
+
+    Value alignment = getAlignment(rewriter, loc, op);
+    if (alignment) {
+      // Adjust the allocation size to consider alignment.
+      sizeBytes = rewriter.create<LLVM::AddOp>(loc, sizeBytes, alignment);
+    }
+
+    // Allocate the underlying buffer.
+    Type elementPtrType = getElementPtrType(memRefType);
+    assert(elementPtrType && "could not compute element ptr type");
+    auto results =
+        rewriter.create<LLVM::CallOp>(loc, allocFuncOp.value(), sizeBytes);
+
+    Value allocatedPtr =
+        castAllocFuncResult(rewriter, loc, results.getResult(), memRefType,
+                            elementPtrType, *getTypeConverter());
+    Value alignedPtr = allocatedPtr;
+    if (alignment) {
+      // Compute the aligned pointer.
+      Value allocatedInt =
+          rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), allocatedPtr);
+      Value alignedInt = createAligned(rewriter, loc, allocatedInt, alignment);
+      alignedPtr =
+          rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, alignedInt);
+    }
+
+    // Create the MemRef descriptor.
+    auto memRefDescriptor = createMemRefDescriptor(
+        loc, memRefType, allocatedPtr, alignedPtr, sizes, strides, rewriter);
+
+    // Return the final value of the descriptor.
+    rewriter.replaceOp(op, {memRefDescriptor});
+    return success();
+  }
+
+  /// Computes the alignment for the given memory allocation op.
+  Value getAlignment(ConversionPatternRewriter &rewriter, Location loc,
+                     memref::AllocOp op) const {
+    MemRefType memRefType = op.getType();
+    // A null result means no manual alignment is applied by the caller.
+    Value alignment;
+    if (auto alignmentAttr = op.getAlignment()) {
+      alignment = createIndexAttrConstant(rewriter, loc, getIndexType(),
+                                          *alignmentAttr);
+    } else if (!memRefType.getElementType().isSignlessIntOrIndexOrFloat()) {
+      // In the case where no alignment is specified, we may want to override
+      // `malloc's` behavior. `malloc` typically aligns at the size of the
+      // biggest scalar on a target HW. For non-scalars, use the natural
+      // alignment of the LLVM type given by the LLVM DataLayout.
+      alignment = getSizeInBytes(loc, memRefType.getElementType(), rewriter);
+    }
+    return alignment;
   }
 };
 
-struct AlignedAllocOpLowering : public AllocLikeOpLLVMLowering {
-  AlignedAllocOpLowering(const LLVMTypeConverter &converter)
-      : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
-                                converter) {}
-  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
-                                          Location loc, Value sizeBytes,
-                                          Operation *op) const override {
-    Value ptr = allocateBufferAutoAlign(
-        rewriter, loc, sizeBytes, op, &defaultLayout,
-        alignedAllocationGetAlignment(rewriter, loc, cast<memref::AllocOp>(op),
-                                      &defaultLayout));
-    if (!ptr)
-      return std::make_tuple(Value(), Value());
-    return std::make_tuple(ptr, ptr);
+/// Lowers `memref.alloc` to a call to `aligned_alloc` (or the generic aligned
+/// allocation function); the size is padded to a multiple of the alignment.
+struct AlignedAllocOpLowering : public ConvertOpToLLVMPattern<memref::AllocOp> {
+  using ConvertOpToLLVMPattern<memref::AllocOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    MemRefType memRefType = op.getType();
+    if (!isConvertibleAndHasIdentityMaps(memRefType))
+      return rewriter.notifyMatchFailure(op, "incompatible memref type");
+
+    // Get or insert alloc function into module.
+    FailureOr<LLVM::LLVMFuncOp> allocFuncOp = getAlignedAllocFn(
+        getTypeConverter(), op->getParentWithTrait<OpTrait::SymbolTable>(),
+        getIndexType());
+    if (failed(allocFuncOp))
+      return rewriter.notifyMatchFailure(op, "failed to get alloc function");
+
+    // Get actual sizes of the memref as values: static sizes are constant
+    // values and dynamic sizes are passed to 'alloc' as operands.  In case of
+    // zero-dimensional memref, assume a scalar (size 1).
+    SmallVector<Value, 4> sizes;
+    SmallVector<Value, 4> strides;
+    Value sizeBytes;
+
+    getMemRefDescriptorSizes(loc, memRefType, adaptor.getOperands(), rewriter,
+                             sizes, strides, sizeBytes, /*sizeInBytes=*/true);
+
+    int64_t alignment = alignedAllocationGetAlignment(op, &defaultLayout);
+
+    Value allocAlignment =
+        createIndexAttrConstant(rewriter, loc, getIndexType(), alignment);
+
+    // Function aligned_alloc requires size to be a multiple of alignment; we
+    // pad the size to the next multiple if necessary.
+    if (!isMemRefSizeMultipleOf(memRefType, alignment, op, &defaultLayout))
+      sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment);
+
+    Type elementPtrType = getElementPtrType(memRefType);
+    auto results = rewriter.create<LLVM::CallOp>(
+        loc, allocFuncOp.value(), ValueRange({allocAlignment, sizeBytes}));
+
+    Value ptr =
+        castAllocFuncResult(rewriter, loc, results.getResult(), memRefType,
+                            elementPtrType, *getTypeConverter());
+
+    // Create the MemRef descriptor (allocated ptr == aligned ptr).
+    auto memRefDescriptor = createMemRefDescriptor(
+        loc, memRefType, ptr, ptr, sizes, strides, rewriter);
+
+    // Return the final value of the descriptor.
+    rewriter.replaceOp(op, {memRefDescriptor});
+    return success();
+  }
+
+  /// The minimum alignment to use with aligned_alloc (has to be a power of 2).
+  static constexpr uint64_t kMinAlignedAllocAlignment = 16UL;
+
+  /// Computes the alignment for aligned_alloc used to allocate the buffer for
+  /// the memory allocation op.
+  ///
+  /// Aligned_alloc requires the alignment to be a power of two, and the
+  /// allocation size to be a multiple of the alignment.
+  int64_t alignedAllocationGetAlignment(memref::AllocOp op,
+                                        const DataLayout *defaultLayout) const {
+    if (std::optional<uint64_t> alignment = op.getAlignment())
+      return *alignment;
+
+    // Whenever we don't have alignment set, we will use an alignment
+    // consistent with the element type; since the alignment has to be a
+    // power of two, we will bump to the next power of two if it isn't.
+    unsigned eltSizeBytes = getMemRefEltSizeInBytes(
+        getTypeConverter(), op.getType(), op, defaultLayout);
+    return std::max(kMinAlignedAllocAlignment,
+                    llvm::PowerOf2Ceil(eltSizeBytes));
+  }
+
+  /// Returns true if the memref size in bytes is known to be a multiple of
+  /// `factor`. Dynamic dimensions are ignored.
+  bool isMemRefSizeMultipleOf(MemRefType type, uint64_t factor, Operation *op,
+                              const DataLayout *defaultLayout) const {
+    uint64_t sizeDivisor =
+        getMemRefEltSizeInBytes(getTypeConverter(), type, op, defaultLayout);
+    for (int64_t dimSize : type.getShape())
+      if (!ShapedType::isDynamic(dimSize))
+        sizeDivisor *= dimSize;
+    return sizeDivisor % factor == 0;
   }
 
 private:
@@ -86,38 +299,52 @@ struct AlignedAllocOpLowering : public AllocLikeOpLLVMLowering {
   DataLayout defaultLayout;
 };
 
-struct AllocaOpLowering : public AllocLikeOpLLVMLowering {
-  AllocaOpLowering(const LLVMTypeConverter &converter)
-      : AllocLikeOpLLVMLowering(memref::AllocaOp::getOperationName(),
-                                converter) {
-    setRequiresNumElements();
-  }
+struct AllocaOpLowering : public ConvertOpToLLVMPattern<memref::AllocaOp> {
+  using ConvertOpToLLVMPattern<memref::AllocaOp>::ConvertOpToLLVMPattern;
 
-  /// Allocates the underlying buffer using the right call. `allocatedBytePtr`
-  /// is set to null for stack allocations. `accessAlignment` is set if
-  /// alignment is needed post allocation (for eg. in conjunction with malloc).
-  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
-                                          Location loc, Value size,
-                                          Operation *op) const override {
+  /// Lowers `memref.alloca` to an `llvm.alloca` of the converted element type
+  /// and wraps the resulting pointer, used as both the allocated and the
+  /// aligned pointer, in a MemRef descriptor.
+  LogicalResult
+  matchAndRewrite(memref::AllocaOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    MemRefType memRefType = op.getType();
+    if (!isConvertibleAndHasIdentityMaps(memRefType))
+      return rewriter.notifyMatchFailure(op, "incompatible memref type");
+
+    // Get actual sizes of the memref as values: static sizes are constant
+    // values and dynamic sizes are passed to 'alloca' as operands. In case of
+    // zero-dimensional memref, assume a scalar (size 1).
+    SmallVector<Value, 4> sizes;
+    SmallVector<Value, 4> strides;
+    Value size;
+
+    getMemRefDescriptorSizes(loc, memRefType, adaptor.getOperands(), rewriter,
+                             sizes, strides, size, /*sizeInBytes=*/false);
 
     // With alloca, one gets a pointer to the element type right away.
     // For stack allocations.
-    auto allocaOp = cast<memref::AllocaOp>(op);
     auto elementType =
-        typeConverter->convertType(allocaOp.getType().getElementType());
+        typeConverter->convertType(memRefType.getElementType());
     FailureOr<unsigned> maybeAddressSpace =
-        getTypeConverter()->getMemRefAddressSpace(allocaOp.getType());
-    if (failed(maybeAddressSpace))
-      return std::make_tuple(Value(), Value());
+        getTypeConverter()->getMemRefAddressSpace(memRefType);
+    assert(succeeded(maybeAddressSpace) && "unsupported address space");
     unsigned addrSpace = *maybeAddressSpace;
     auto elementPtrType =
         LLVM::LLVMPointerType::get(rewriter.getContext(), addrSpace);
 
-    auto allocatedElementPtr =
-        rewriter.create<LLVM::AllocaOp>(loc, elementPtrType, elementType, size,
-                                        allocaOp.getAlignment().value_or(0));
+    auto allocatedElementPtr = rewriter.create<LLVM::AllocaOp>(
+        loc, elementPtrType, elementType, size, op.getAlignment().value_or(0));
 
-    return std::make_tuple(allocatedElementPtr, allocatedElementPtr);
+    // Create the MemRef descriptor.
+    auto memRefDescriptor = createMemRefDescriptor(
+        loc, memRefType, allocatedElementPtr, allocatedElementPtr, sizes,
+        strides, rewriter);
+
+    // Return the final value of the descriptor.
+    rewriter.replaceOp(op, {memRefDescriptor});
+    return success();
   }
 };
 
@@ -526,31 +753,43 @@ struct GlobalMemrefOpLowering
 /// GetGlobalMemrefOp is lowered into a Memref descriptor with the pointer to
 /// the first element stashed into the descriptor. This reuses
 /// `AllocLikeOpLowering` to reuse the Memref descriptor construction.
-struct GetGlobalMemrefOpLowering : public AllocLikeOpLLVMLowering {
-  GetGlobalMemrefOpLowering(const LLVMTypeConverter &converter)
-      : AllocLikeOpLLVMLowering(memref::GetGlobalOp::getOperationName(),
-                                converter) {}
+struct GetGlobalMemrefOpLowering
+    : public ConvertOpToLLVMPattern<memref::GetGlobalOp> {
+  using ConvertOpToLLVMPattern<memref::GetGlobalOp>::ConvertOpToLLVMPattern;
 
   /// Buffer "allocation" for memref.get_global op is getting the address of
   /// the global variable referenced.
-  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
-                                          Location loc, Value sizeBytes,
-                                          Operation *op) const override {
-    auto getGlobalOp = cast<memref::GetGlobalOp>(op);
-    MemRefType type = cast<MemRefType>(getGlobalOp.getResult().getType());
+  LogicalResult
+  matchAndRewrite(memref::GetGlobalOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    MemRefType memRefType = op.getType();
+    if (!isConvertibleAndHasIdentityMaps(memRefType))
+      return rewriter.notifyMatchFailure(op, "incompatible memref type");
+
+    // Get actual sizes of the memref as values: static sizes are constant
+    // values and dynamic sizes are passed to 'alloc' as operands.  In case of
+    // zero-dimensional memref, assume a scalar (size 1).
+    SmallVector<Value, 4> sizes;
+    SmallVector<Value, 4> strides;
+    Value sizeBytes;
+
+    this->getMemRefDescriptorSizes(loc, memRefType, adaptor.getOperands(),
+                                   rewriter, sizes, strides, sizeBytes, !false);
+
+    MemRefType type = cast<MemRefType>(op.getResult().getType());
 
     // This is called after a type conversion, which would have failed if this
     // call fails.
     FailureOr<unsigned> maybeAddressSpace =
         getTypeConverter()->getMemRefAddressSpace(type);
-    if (failed(maybeAddressSpace))
-      return std::make_tuple(Value(), Value());
+    assert(succeeded(maybeAddressSpace) && "unsupported address space");
     unsigned memSpace = *maybeAddressSpace;
 
     Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
     auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), memSpace);
     auto addressOf =
-        rewriter.create<LLVM::AddressOfOp>(loc, ptrTy, getGlobalOp.getName());
+        rewriter.create<LLVM::AddressOfOp>(loc, ptrTy, op.getName());
 
     // Get the address of the first element in the array by creating a GEP with
     // the address of the GV as the base, and (rank + 1) number of 0 indices.
@@ -569,7 +808,13 @@ struct GetGlobalMemrefOpLowering : public AllocLikeOpLLVMLowering {
 
     // Both allocated and aligned pointers are same. We could potentially stash
     // a nullptr for the allocated pointer since we do not expect any dealloc.
-    return std::make_tuple(deadBeefPtr, gep);
+    // Create the MemRef descriptor.
+    auto memRefDescriptor = this->createMemRefDescriptor(
+        loc, memRefType, deadBeefPtr, gep, sizes, strides, rewriter);
+
+    // Return the final value of the descriptor.
+    rewriter.replaceOp(op, {memRefDescriptor});
+    return success();
   }
 };
 



More information about the Mlir-commits mailing list