[Mlir-commits] [mlir] [mlir][memref] Move `AllocLikeConversion.h` helpers into `MemRefToLLVM.cpp` (PR #136424)
Matthias Springer
llvmlistbot at llvm.org
Sun Apr 20 03:26:14 PDT 2025
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/136424
>From 38b538448c13e8fe83a8275cd2fb3a692e07881c Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Sat, 19 Apr 2025 12:37:23 +0200
Subject: [PATCH] tmp
---
.../MemRefToLLVM/AllocLikeConversion.h | 153 --------
.../MemRefToLLVM/AllocLikeConversion.cpp | 195 ----------
.../Conversion/MemRefToLLVM/CMakeLists.txt | 1 -
.../Conversion/MemRefToLLVM/MemRefToLLVM.cpp | 357 +++++++++++++++---
4 files changed, 301 insertions(+), 405 deletions(-)
delete mode 100644 mlir/include/mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h
delete mode 100644 mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
diff --git a/mlir/include/mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h b/mlir/include/mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h
deleted file mode 100644
index 8bf04219c759a..0000000000000
--- a/mlir/include/mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h
+++ /dev/null
@@ -1,153 +0,0 @@
-//===- AllocLikeConversion.h - Convert allocation ops to LLVM ---*- C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_CONVERSION_MEMREFTOLLVM_ALLOCLIKECONVERSION_H
-#define MLIR_CONVERSION_MEMREFTOLLVM_ALLOCLIKECONVERSION_H
-
-#include "mlir/Conversion/LLVMCommon/Pattern.h"
-
-namespace mlir {
-
-/// Lowering for memory allocation ops.
-struct AllocationOpLLVMLowering : public ConvertToLLVMPattern {
- using ConvertToLLVMPattern::createIndexAttrConstant;
- using ConvertToLLVMPattern::getIndexType;
- using ConvertToLLVMPattern::getVoidPtrType;
-
- explicit AllocationOpLLVMLowering(StringRef opName,
- const LLVMTypeConverter &converter,
- PatternBenefit benefit = 1)
- : ConvertToLLVMPattern(opName, &converter.getContext(), converter,
- benefit) {}
-
-protected:
- /// Computes the aligned value for 'input' as follows:
- /// bumped = input + alignement - 1
- /// aligned = bumped - bumped % alignment
- static Value createAligned(ConversionPatternRewriter &rewriter, Location loc,
- Value input, Value alignment);
-
- static MemRefType getMemRefResultType(Operation *op) {
- return cast<MemRefType>(op->getResult(0).getType());
- }
-
- /// Computes the alignment for the given memory allocation op.
- template <typename OpType>
- Value getAlignment(ConversionPatternRewriter &rewriter, Location loc,
- OpType op) const {
- MemRefType memRefType = op.getType();
- Value alignment;
- if (auto alignmentAttr = op.getAlignment()) {
- Type indexType = getIndexType();
- alignment =
- createIndexAttrConstant(rewriter, loc, indexType, *alignmentAttr);
- } else if (!memRefType.getElementType().isSignlessIntOrIndexOrFloat()) {
- // In the case where no alignment is specified, we may want to override
- // `malloc's` behavior. `malloc` typically aligns at the size of the
- // biggest scalar on a target HW. For non-scalars, use the natural
- // alignment of the LLVM type given by the LLVM DataLayout.
- alignment = getSizeInBytes(loc, memRefType.getElementType(), rewriter);
- }
- return alignment;
- }
-
- /// Computes the alignment for aligned_alloc used to allocate the buffer for
- /// the memory allocation op.
- ///
- /// Aligned_alloc requires the allocation size to be a power of two, and the
- /// allocation size to be a multiple of the alignment.
- template <typename OpType>
- int64_t alignedAllocationGetAlignment(ConversionPatternRewriter &rewriter,
- Location loc, OpType op,
- const DataLayout *defaultLayout) const {
- if (std::optional<uint64_t> alignment = op.getAlignment())
- return *alignment;
-
- // Whenever we don't have alignment set, we will use an alignment
- // consistent with the element type; since the allocation size has to be a
- // power of two, we will bump to the next power of two if it isn't.
- unsigned eltSizeBytes =
- getMemRefEltSizeInBytes(op.getType(), op, defaultLayout);
- return std::max(kMinAlignedAllocAlignment,
- llvm::PowerOf2Ceil(eltSizeBytes));
- }
-
- /// Allocates a memory buffer using an allocation method that doesn't
- /// guarantee alignment. Returns the pointer and its aligned value.
- std::tuple<Value, Value>
- allocateBufferManuallyAlign(ConversionPatternRewriter &rewriter, Location loc,
- Value sizeBytes, Operation *op,
- Value alignment) const;
-
- /// Allocates a memory buffer using an aligned allocation method.
- Value allocateBufferAutoAlign(ConversionPatternRewriter &rewriter,
- Location loc, Value sizeBytes, Operation *op,
- const DataLayout *defaultLayout,
- int64_t alignment) const;
-
-private:
- /// Computes the byte size for the MemRef element type.
- unsigned getMemRefEltSizeInBytes(MemRefType memRefType, Operation *op,
- const DataLayout *defaultLayout) const;
-
- /// Returns true if the memref size in bytes is known to be a multiple of
- /// factor.
- bool isMemRefSizeMultipleOf(MemRefType type, uint64_t factor, Operation *op,
- const DataLayout *defaultLayout) const;
-
- /// The minimum alignment to use with aligned_alloc (has to be a power of 2).
- static constexpr uint64_t kMinAlignedAllocAlignment = 16UL;
-};
-
-/// Lowering for AllocOp and AllocaOp.
-struct AllocLikeOpLLVMLowering : public AllocationOpLLVMLowering {
- explicit AllocLikeOpLLVMLowering(StringRef opName,
- const LLVMTypeConverter &converter,
- PatternBenefit benefit = 1)
- : AllocationOpLLVMLowering(opName, converter, benefit) {}
-
-protected:
- /// Allocates the underlying buffer. Returns the allocated pointer and the
- /// aligned pointer.
- virtual std::tuple<Value, Value>
- allocateBuffer(ConversionPatternRewriter &rewriter, Location loc, Value size,
- Operation *op) const = 0;
-
- /// Sets the flag 'requiresNumElements', specifying the Op requires the number
- /// of elements instead of the size in bytes.
- void setRequiresNumElements();
-
-private:
- // An `alloc` is converted into a definition of a memref descriptor value and
- // a call to `malloc` to allocate the underlying data buffer. The memref
- // descriptor is of the LLVM structure type where:
- // 1. the first element is a pointer to the allocated (typed) data buffer,
- // 2. the second element is a pointer to the (typed) payload, aligned to the
- // specified alignment,
- // 3. the remaining elements serve to store all the sizes and strides of the
- // memref using LLVM-converted `index` type.
- //
- // Alignment is performed by allocating `alignment` more bytes than
- // requested and shifting the aligned pointer relative to the allocated
- // memory. Note: `alignment - <minimum malloc alignment>` would actually be
- // sufficient. If alignment is unspecified, the two pointers are equal.
-
- // An `alloca` is converted into a definition of a memref descriptor value and
- // an llvm.alloca to allocate the underlying data buffer.
- LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const override;
-
- // Flag for specifying the Op requires the number of elements instead of the
- // size in bytes.
- bool requiresNumElements = false;
-};
-
-} // namespace mlir
-
-#endif // MLIR_CONVERSION_MEMREFTOLLVM_ALLOCLIKECONVERSION_H
diff --git a/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp b/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
deleted file mode 100644
index e9b79983696aa..0000000000000
--- a/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
+++ /dev/null
@@ -1,195 +0,0 @@
-//===- AllocLikeConversion.cpp - LLVM conversion for alloc operations -----===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h"
-#include "mlir/Analysis/DataLayoutAnalysis.h"
-#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
-#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
-#include "mlir/IR/SymbolTable.h"
-
-using namespace mlir;
-
-static FailureOr<LLVM::LLVMFuncOp>
-getNotalignedAllocFn(OpBuilder &b, const LLVMTypeConverter *typeConverter,
- Operation *module, Type indexType) {
- bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
- if (useGenericFn)
- return LLVM::lookupOrCreateGenericAllocFn(b, module, indexType);
-
- return LLVM::lookupOrCreateMallocFn(b, module, indexType);
-}
-
-static FailureOr<LLVM::LLVMFuncOp>
-getAlignedAllocFn(OpBuilder &b, const LLVMTypeConverter *typeConverter,
- Operation *module, Type indexType) {
- bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
-
- if (useGenericFn)
- return LLVM::lookupOrCreateGenericAlignedAllocFn(b, module, indexType);
-
- return LLVM::lookupOrCreateAlignedAllocFn(b, module, indexType);
-}
-
-Value AllocationOpLLVMLowering::createAligned(
- ConversionPatternRewriter &rewriter, Location loc, Value input,
- Value alignment) {
- Value one = createIndexAttrConstant(rewriter, loc, alignment.getType(), 1);
- Value bump = rewriter.create<LLVM::SubOp>(loc, alignment, one);
- Value bumped = rewriter.create<LLVM::AddOp>(loc, input, bump);
- Value mod = rewriter.create<LLVM::URemOp>(loc, bumped, alignment);
- return rewriter.create<LLVM::SubOp>(loc, bumped, mod);
-}
-
-static Value castAllocFuncResult(ConversionPatternRewriter &rewriter,
- Location loc, Value allocatedPtr,
- MemRefType memRefType, Type elementPtrType,
- const LLVMTypeConverter &typeConverter) {
- auto allocatedPtrTy = cast<LLVM::LLVMPointerType>(allocatedPtr.getType());
- FailureOr<unsigned> maybeMemrefAddrSpace =
- typeConverter.getMemRefAddressSpace(memRefType);
- if (failed(maybeMemrefAddrSpace))
- return Value();
- unsigned memrefAddrSpace = *maybeMemrefAddrSpace;
- if (allocatedPtrTy.getAddressSpace() != memrefAddrSpace)
- allocatedPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
- loc, LLVM::LLVMPointerType::get(rewriter.getContext(), memrefAddrSpace),
- allocatedPtr);
- return allocatedPtr;
-}
-
-std::tuple<Value, Value> AllocationOpLLVMLowering::allocateBufferManuallyAlign(
- ConversionPatternRewriter &rewriter, Location loc, Value sizeBytes,
- Operation *op, Value alignment) const {
- if (alignment) {
- // Adjust the allocation size to consider alignment.
- sizeBytes = rewriter.create<LLVM::AddOp>(loc, sizeBytes, alignment);
- }
-
- MemRefType memRefType = getMemRefResultType(op);
- // Allocate the underlying buffer.
- Type elementPtrType = this->getElementPtrType(memRefType);
- assert(elementPtrType && "could not compute element ptr type");
- FailureOr<LLVM::LLVMFuncOp> allocFuncOp = getNotalignedAllocFn(
- rewriter, getTypeConverter(),
- op->getParentWithTrait<OpTrait::SymbolTable>(), getIndexType());
- if (failed(allocFuncOp))
- return std::make_tuple(Value(), Value());
- auto results =
- rewriter.create<LLVM::CallOp>(loc, allocFuncOp.value(), sizeBytes);
-
- Value allocatedPtr =
- castAllocFuncResult(rewriter, loc, results.getResult(), memRefType,
- elementPtrType, *getTypeConverter());
- if (!allocatedPtr)
- return std::make_tuple(Value(), Value());
- Value alignedPtr = allocatedPtr;
- if (alignment) {
- // Compute the aligned pointer.
- Value allocatedInt =
- rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), allocatedPtr);
- Value alignmentInt = createAligned(rewriter, loc, allocatedInt, alignment);
- alignedPtr =
- rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, alignmentInt);
- }
-
- return std::make_tuple(allocatedPtr, alignedPtr);
-}
-
-unsigned AllocationOpLLVMLowering::getMemRefEltSizeInBytes(
- MemRefType memRefType, Operation *op,
- const DataLayout *defaultLayout) const {
- const DataLayout *layout = defaultLayout;
- if (const DataLayoutAnalysis *analysis =
- getTypeConverter()->getDataLayoutAnalysis()) {
- layout = &analysis->getAbove(op);
- }
- Type elementType = memRefType.getElementType();
- if (auto memRefElementType = dyn_cast<MemRefType>(elementType))
- return getTypeConverter()->getMemRefDescriptorSize(memRefElementType,
- *layout);
- if (auto memRefElementType = dyn_cast<UnrankedMemRefType>(elementType))
- return getTypeConverter()->getUnrankedMemRefDescriptorSize(
- memRefElementType, *layout);
- return layout->getTypeSize(elementType);
-}
-
-bool AllocationOpLLVMLowering::isMemRefSizeMultipleOf(
- MemRefType type, uint64_t factor, Operation *op,
- const DataLayout *defaultLayout) const {
- uint64_t sizeDivisor = getMemRefEltSizeInBytes(type, op, defaultLayout);
- for (unsigned i = 0, e = type.getRank(); i < e; i++) {
- if (type.isDynamicDim(i))
- continue;
- sizeDivisor = sizeDivisor * type.getDimSize(i);
- }
- return sizeDivisor % factor == 0;
-}
-
-Value AllocationOpLLVMLowering::allocateBufferAutoAlign(
- ConversionPatternRewriter &rewriter, Location loc, Value sizeBytes,
- Operation *op, const DataLayout *defaultLayout, int64_t alignment) const {
- Value allocAlignment =
- createIndexAttrConstant(rewriter, loc, getIndexType(), alignment);
-
- MemRefType memRefType = getMemRefResultType(op);
- // Function aligned_alloc requires size to be a multiple of alignment; we pad
- // the size to the next multiple if necessary.
- if (!isMemRefSizeMultipleOf(memRefType, alignment, op, defaultLayout))
- sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment);
-
- Type elementPtrType = this->getElementPtrType(memRefType);
- FailureOr<LLVM::LLVMFuncOp> allocFuncOp = getAlignedAllocFn(
- rewriter, getTypeConverter(),
- op->getParentWithTrait<OpTrait::SymbolTable>(), getIndexType());
- if (failed(allocFuncOp))
- return Value();
- auto results = rewriter.create<LLVM::CallOp>(
- loc, allocFuncOp.value(), ValueRange({allocAlignment, sizeBytes}));
-
- return castAllocFuncResult(rewriter, loc, results.getResult(), memRefType,
- elementPtrType, *getTypeConverter());
-}
-
-void AllocLikeOpLLVMLowering::setRequiresNumElements() {
- requiresNumElements = true;
-}
-
-LogicalResult AllocLikeOpLLVMLowering::matchAndRewrite(
- Operation *op, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const {
- MemRefType memRefType = getMemRefResultType(op);
- if (!isConvertibleAndHasIdentityMaps(memRefType))
- return rewriter.notifyMatchFailure(op, "incompatible memref type");
- auto loc = op->getLoc();
-
- // Get actual sizes of the memref as values: static sizes are constant
- // values and dynamic sizes are passed to 'alloc' as operands. In case of
- // zero-dimensional memref, assume a scalar (size 1).
- SmallVector<Value, 4> sizes;
- SmallVector<Value, 4> strides;
- Value size;
-
- this->getMemRefDescriptorSizes(loc, memRefType, operands, rewriter, sizes,
- strides, size, !requiresNumElements);
-
- // Allocate the underlying buffer.
- auto [allocatedPtr, alignedPtr] =
- this->allocateBuffer(rewriter, loc, size, op);
-
- if (!allocatedPtr || !alignedPtr)
- return rewriter.notifyMatchFailure(loc,
- "underlying buffer allocation failed");
-
- // Create the MemRef descriptor.
- auto memRefDescriptor = this->createMemRefDescriptor(
- loc, memRefType, allocatedPtr, alignedPtr, sizes, strides, rewriter);
-
- // Return the final value of the descriptor.
- rewriter.replaceOp(op, {memRefDescriptor});
- return success();
-}
diff --git a/mlir/lib/Conversion/MemRefToLLVM/CMakeLists.txt b/mlir/lib/Conversion/MemRefToLLVM/CMakeLists.txt
index f0d95f5ada290..9da4b23d42f41 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/CMakeLists.txt
+++ b/mlir/lib/Conversion/MemRefToLLVM/CMakeLists.txt
@@ -1,5 +1,4 @@
add_mlir_conversion_library(MLIRMemRefToLLVM
- AllocLikeConversion.cpp
MemRefToLLVM.cpp
ADDITIONAL_HEADER_DIRS
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 9c219d8a3d8cb..c8b2c0bdc6c20 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -13,7 +13,6 @@
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
-#include "mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
@@ -53,33 +52,247 @@ getFreeFn(OpBuilder &b, const LLVMTypeConverter *typeConverter,
return LLVM::lookupOrCreateFreeFn(b, module);
}
-struct AllocOpLowering : public AllocLikeOpLLVMLowering {
- AllocOpLowering(const LLVMTypeConverter &converter)
- : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
- converter) {}
- std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
- Location loc, Value sizeBytes,
- Operation *op) const override {
- return allocateBufferManuallyAlign(
- rewriter, loc, sizeBytes, op,
- getAlignment(rewriter, loc, cast<memref::AllocOp>(op)));
+static FailureOr<LLVM::LLVMFuncOp>
+getNotalignedAllocFn(OpBuilder &b, const LLVMTypeConverter *typeConverter,
+ Operation *module, Type indexType) {
+ bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
+ if (useGenericFn)
+ return LLVM::lookupOrCreateGenericAllocFn(b, module, indexType);
+
+ return LLVM::lookupOrCreateMallocFn(b, module, indexType);
+}
+
+static FailureOr<LLVM::LLVMFuncOp>
+getAlignedAllocFn(OpBuilder &b, const LLVMTypeConverter *typeConverter,
+ Operation *module, Type indexType) {
+ bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
+
+ if (useGenericFn)
+ return LLVM::lookupOrCreateGenericAlignedAllocFn(b, module, indexType);
+
+ return LLVM::lookupOrCreateAlignedAllocFn(b, module, indexType);
+}
+
+/// Computes the aligned value for 'input' as follows:
+///   bumped = input + alignment - 1
+///   aligned = bumped - bumped % alignment
+static Value createAligned(ConversionPatternRewriter &rewriter, Location loc,
+ Value input, Value alignment) {
+ Value one = rewriter.create<LLVM::ConstantOp>(loc, alignment.getType(),
+ rewriter.getIndexAttr(1));
+ Value bump = rewriter.create<LLVM::SubOp>(loc, alignment, one);
+ Value bumped = rewriter.create<LLVM::AddOp>(loc, input, bump);
+ Value mod = rewriter.create<LLVM::URemOp>(loc, bumped, alignment);
+ return rewriter.create<LLVM::SubOp>(loc, bumped, mod);
+}
+
+/// Computes the byte size for the MemRef element type.
+static unsigned getMemRefEltSizeInBytes(const LLVMTypeConverter *typeConverter,
+ MemRefType memRefType, Operation *op,
+ const DataLayout *defaultLayout) {
+ const DataLayout *layout = defaultLayout;
+ if (const DataLayoutAnalysis *analysis =
+ typeConverter->getDataLayoutAnalysis()) {
+ layout = &analysis->getAbove(op);
+ }
+ Type elementType = memRefType.getElementType();
+ if (auto memRefElementType = dyn_cast<MemRefType>(elementType))
+ return typeConverter->getMemRefDescriptorSize(memRefElementType, *layout);
+ if (auto memRefElementType = dyn_cast<UnrankedMemRefType>(elementType))
+ return typeConverter->getUnrankedMemRefDescriptorSize(memRefElementType,
+ *layout);
+ return layout->getTypeSize(elementType);
+}
+
+static Value castAllocFuncResult(ConversionPatternRewriter &rewriter,
+ Location loc, Value allocatedPtr,
+ MemRefType memRefType, Type elementPtrType,
+ const LLVMTypeConverter &typeConverter) {
+ auto allocatedPtrTy = cast<LLVM::LLVMPointerType>(allocatedPtr.getType());
+ FailureOr<unsigned> maybeMemrefAddrSpace =
+ typeConverter.getMemRefAddressSpace(memRefType);
+ assert(succeeded(maybeMemrefAddrSpace) && "unsupported address space");
+ unsigned memrefAddrSpace = *maybeMemrefAddrSpace;
+ if (allocatedPtrTy.getAddressSpace() != memrefAddrSpace)
+ allocatedPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
+ loc, LLVM::LLVMPointerType::get(rewriter.getContext(), memrefAddrSpace),
+ allocatedPtr);
+ return allocatedPtr;
+}
+
+struct AllocOpLowering : public ConvertOpToLLVMPattern<memref::AllocOp> {
+ using ConvertOpToLLVMPattern<memref::AllocOp>::ConvertOpToLLVMPattern;
+
+ LogicalResult
+ matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto loc = op.getLoc();
+ MemRefType memRefType = op.getType();
+ if (!isConvertibleAndHasIdentityMaps(memRefType))
+ return rewriter.notifyMatchFailure(op, "incompatible memref type");
+
+ // Get or insert alloc function into the module.
+ FailureOr<LLVM::LLVMFuncOp> allocFuncOp = getNotalignedAllocFn(
+ rewriter, getTypeConverter(),
+ op->getParentWithTrait<OpTrait::SymbolTable>(), getIndexType());
+ if (failed(allocFuncOp))
+ return failure();
+
+ // Get actual sizes of the memref as values: static sizes are constant
+ // values and dynamic sizes are passed to 'alloc' as operands. In case of
+ // zero-dimensional memref, assume a scalar (size 1).
+ SmallVector<Value, 4> sizes;
+ SmallVector<Value, 4> strides;
+ Value sizeBytes;
+
+ this->getMemRefDescriptorSizes(loc, memRefType, adaptor.getOperands(),
+ rewriter, sizes, strides, sizeBytes, true);
+
+ Value alignment = getAlignment(rewriter, loc, op);
+ if (alignment) {
+ // Adjust the allocation size to consider alignment.
+ sizeBytes = rewriter.create<LLVM::AddOp>(loc, sizeBytes, alignment);
+ }
+
+ // Allocate the underlying buffer.
+ Type elementPtrType = this->getElementPtrType(memRefType);
+ assert(elementPtrType && "could not compute element ptr type");
+ auto results =
+ rewriter.create<LLVM::CallOp>(loc, allocFuncOp.value(), sizeBytes);
+
+ Value allocatedPtr =
+ castAllocFuncResult(rewriter, loc, results.getResult(), memRefType,
+ elementPtrType, *getTypeConverter());
+ Value alignedPtr = allocatedPtr;
+ if (alignment) {
+ // Compute the aligned pointer.
+ Value allocatedInt =
+ rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), allocatedPtr);
+ Value alignmentInt =
+ createAligned(rewriter, loc, allocatedInt, alignment);
+ alignedPtr =
+ rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, alignmentInt);
+ }
+
+ // Create the MemRef descriptor.
+ auto memRefDescriptor = this->createMemRefDescriptor(
+ loc, memRefType, allocatedPtr, alignedPtr, sizes, strides, rewriter);
+
+ // Return the final value of the descriptor.
+ rewriter.replaceOp(op, {memRefDescriptor});
+ return success();
+ }
+
+ /// Computes the alignment for the given memory allocation op.
+ template <typename OpType>
+ Value getAlignment(ConversionPatternRewriter &rewriter, Location loc,
+ OpType op) const {
+ MemRefType memRefType = op.getType();
+ Value alignment;
+ if (auto alignmentAttr = op.getAlignment()) {
+ Type indexType = getIndexType();
+ alignment =
+ createIndexAttrConstant(rewriter, loc, indexType, *alignmentAttr);
+ } else if (!memRefType.getElementType().isSignlessIntOrIndexOrFloat()) {
+ // In the case where no alignment is specified, we may want to override
+ // `malloc's` behavior. `malloc` typically aligns at the size of the
+ // biggest scalar on a target HW. For non-scalars, use the natural
+ // alignment of the LLVM type given by the LLVM DataLayout.
+ alignment = getSizeInBytes(loc, memRefType.getElementType(), rewriter);
+ }
+ return alignment;
}
};
-struct AlignedAllocOpLowering : public AllocLikeOpLLVMLowering {
- AlignedAllocOpLowering(const LLVMTypeConverter &converter)
- : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
- converter) {}
- std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
- Location loc, Value sizeBytes,
- Operation *op) const override {
- Value ptr = allocateBufferAutoAlign(
- rewriter, loc, sizeBytes, op, &defaultLayout,
- alignedAllocationGetAlignment(rewriter, loc, cast<memref::AllocOp>(op),
- &defaultLayout));
- if (!ptr)
- return std::make_tuple(Value(), Value());
- return std::make_tuple(ptr, ptr);
+struct AlignedAllocOpLowering : public ConvertOpToLLVMPattern<memref::AllocOp> {
+ using ConvertOpToLLVMPattern<memref::AllocOp>::ConvertOpToLLVMPattern;
+
+ LogicalResult
+ matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto loc = op.getLoc();
+ MemRefType memRefType = op.getType();
+ if (!isConvertibleAndHasIdentityMaps(memRefType))
+ return rewriter.notifyMatchFailure(op, "incompatible memref type");
+
+ // Get or insert alloc function into module.
+ FailureOr<LLVM::LLVMFuncOp> allocFuncOp = getAlignedAllocFn(
+ rewriter, getTypeConverter(),
+ op->getParentWithTrait<OpTrait::SymbolTable>(), getIndexType());
+ if (failed(allocFuncOp))
+ return failure();
+
+ // Get actual sizes of the memref as values: static sizes are constant
+ // values and dynamic sizes are passed to 'alloc' as operands. In case of
+ // zero-dimensional memref, assume a scalar (size 1).
+ SmallVector<Value, 4> sizes;
+ SmallVector<Value, 4> strides;
+ Value sizeBytes;
+
+ this->getMemRefDescriptorSizes(loc, memRefType, adaptor.getOperands(),
+ rewriter, sizes, strides, sizeBytes, !false);
+
+ int64_t alignment = alignedAllocationGetAlignment(op, &defaultLayout);
+
+ Value allocAlignment =
+ createIndexAttrConstant(rewriter, loc, getIndexType(), alignment);
+
+ // Function aligned_alloc requires size to be a multiple of alignment; we
+ // pad the size to the next multiple if necessary.
+ if (!isMemRefSizeMultipleOf(memRefType, alignment, op, &defaultLayout))
+ sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment);
+
+ Type elementPtrType = this->getElementPtrType(memRefType);
+ auto results = rewriter.create<LLVM::CallOp>(
+ loc, allocFuncOp.value(), ValueRange({allocAlignment, sizeBytes}));
+
+ Value ptr =
+ castAllocFuncResult(rewriter, loc, results.getResult(), memRefType,
+ elementPtrType, *getTypeConverter());
+
+ // Create the MemRef descriptor.
+ auto memRefDescriptor = this->createMemRefDescriptor(
+ loc, memRefType, ptr, ptr, sizes, strides, rewriter);
+
+ // Return the final value of the descriptor.
+ rewriter.replaceOp(op, {memRefDescriptor});
+ return success();
+ }
+
+ /// The minimum alignment to use with aligned_alloc (has to be a power of 2).
+ static constexpr uint64_t kMinAlignedAllocAlignment = 16UL;
+
+ /// Computes the alignment for aligned_alloc used to allocate the buffer for
+ /// the memory allocation op.
+ ///
+ /// Aligned_alloc requires the allocation size to be a power of two, and the
+ /// allocation size to be a multiple of the alignment.
+ int64_t alignedAllocationGetAlignment(memref::AllocOp op,
+ const DataLayout *defaultLayout) const {
+ if (std::optional<uint64_t> alignment = op.getAlignment())
+ return *alignment;
+
+ // Whenever we don't have alignment set, we will use an alignment
+ // consistent with the element type; since the allocation size has to be a
+ // power of two, we will bump to the next power of two if it isn't.
+ unsigned eltSizeBytes = getMemRefEltSizeInBytes(
+ getTypeConverter(), op.getType(), op, defaultLayout);
+ return std::max(kMinAlignedAllocAlignment,
+ llvm::PowerOf2Ceil(eltSizeBytes));
+ }
+
+ /// Returns true if the memref size in bytes is known to be a multiple of
+ /// factor.
+ bool isMemRefSizeMultipleOf(MemRefType type, uint64_t factor, Operation *op,
+ const DataLayout *defaultLayout) const {
+ uint64_t sizeDivisor =
+ getMemRefEltSizeInBytes(getTypeConverter(), type, op, defaultLayout);
+ for (unsigned i = 0, e = type.getRank(); i < e; i++) {
+ if (type.isDynamicDim(i))
+ continue;
+ sizeDivisor = sizeDivisor * type.getDimSize(i);
+ }
+ return sizeDivisor % factor == 0;
}
private:
@@ -87,38 +300,52 @@ struct AlignedAllocOpLowering : public AllocLikeOpLLVMLowering {
DataLayout defaultLayout;
};
-struct AllocaOpLowering : public AllocLikeOpLLVMLowering {
- AllocaOpLowering(const LLVMTypeConverter &converter)
- : AllocLikeOpLLVMLowering(memref::AllocaOp::getOperationName(),
- converter) {
- setRequiresNumElements();
- }
+struct AllocaOpLowering : public ConvertOpToLLVMPattern<memref::AllocaOp> {
+ using ConvertOpToLLVMPattern<memref::AllocaOp>::ConvertOpToLLVMPattern;
/// Allocates the underlying buffer using the right call. `allocatedBytePtr`
/// is set to null for stack allocations. `accessAlignment` is set if
/// alignment is needed post allocation (for eg. in conjunction with malloc).
- std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
- Location loc, Value size,
- Operation *op) const override {
+ LogicalResult
+ matchAndRewrite(memref::AllocaOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto loc = op.getLoc();
+ MemRefType memRefType = op.getType();
+ if (!isConvertibleAndHasIdentityMaps(memRefType))
+ return rewriter.notifyMatchFailure(op, "incompatible memref type");
+
+ // Get actual sizes of the memref as values: static sizes are constant
+ // values and dynamic sizes are passed to 'alloc' as operands. In case of
+ // zero-dimensional memref, assume a scalar (size 1).
+ SmallVector<Value, 4> sizes;
+ SmallVector<Value, 4> strides;
+ Value size;
+
+ this->getMemRefDescriptorSizes(loc, memRefType, adaptor.getOperands(),
+ rewriter, sizes, strides, size, !true);
// With alloca, one gets a pointer to the element type right away.
// For stack allocations.
- auto allocaOp = cast<memref::AllocaOp>(op);
auto elementType =
- typeConverter->convertType(allocaOp.getType().getElementType());
+ typeConverter->convertType(op.getType().getElementType());
FailureOr<unsigned> maybeAddressSpace =
- getTypeConverter()->getMemRefAddressSpace(allocaOp.getType());
- if (failed(maybeAddressSpace))
- return std::make_tuple(Value(), Value());
+ getTypeConverter()->getMemRefAddressSpace(op.getType());
+ assert(succeeded(maybeAddressSpace) && "unsupported address space");
unsigned addrSpace = *maybeAddressSpace;
auto elementPtrType =
LLVM::LLVMPointerType::get(rewriter.getContext(), addrSpace);
- auto allocatedElementPtr =
- rewriter.create<LLVM::AllocaOp>(loc, elementPtrType, elementType, size,
- allocaOp.getAlignment().value_or(0));
+ auto allocatedElementPtr = rewriter.create<LLVM::AllocaOp>(
+ loc, elementPtrType, elementType, size, op.getAlignment().value_or(0));
- return std::make_tuple(allocatedElementPtr, allocatedElementPtr);
+ // Create the MemRef descriptor.
+ auto memRefDescriptor = this->createMemRefDescriptor(
+ loc, memRefType, allocatedElementPtr, allocatedElementPtr, sizes,
+ strides, rewriter);
+
+ // Return the final value of the descriptor.
+ rewriter.replaceOp(op, {memRefDescriptor});
+ return success();
}
};
@@ -527,31 +754,43 @@ struct GlobalMemrefOpLowering
/// GetGlobalMemrefOp is lowered into a Memref descriptor with the pointer to
/// the first element stashed into the descriptor. This reuses
/// `AllocLikeOpLowering` to reuse the Memref descriptor construction.
-struct GetGlobalMemrefOpLowering : public AllocLikeOpLLVMLowering {
- GetGlobalMemrefOpLowering(const LLVMTypeConverter &converter)
- : AllocLikeOpLLVMLowering(memref::GetGlobalOp::getOperationName(),
- converter) {}
+struct GetGlobalMemrefOpLowering
+ : public ConvertOpToLLVMPattern<memref::GetGlobalOp> {
+ using ConvertOpToLLVMPattern<memref::GetGlobalOp>::ConvertOpToLLVMPattern;
/// Buffer "allocation" for memref.get_global op is getting the address of
/// the global variable referenced.
- std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
- Location loc, Value sizeBytes,
- Operation *op) const override {
- auto getGlobalOp = cast<memref::GetGlobalOp>(op);
- MemRefType type = cast<MemRefType>(getGlobalOp.getResult().getType());
+ LogicalResult
+ matchAndRewrite(memref::GetGlobalOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto loc = op.getLoc();
+ MemRefType memRefType = op.getType();
+ if (!isConvertibleAndHasIdentityMaps(memRefType))
+ return rewriter.notifyMatchFailure(op, "incompatible memref type");
+
+ // Get actual sizes of the memref as values: static sizes are constant
+ // values and dynamic sizes are passed to 'alloc' as operands. In case of
+ // zero-dimensional memref, assume a scalar (size 1).
+ SmallVector<Value, 4> sizes;
+ SmallVector<Value, 4> strides;
+ Value sizeBytes;
+
+ this->getMemRefDescriptorSizes(loc, memRefType, adaptor.getOperands(),
+ rewriter, sizes, strides, sizeBytes, !false);
+
+ MemRefType type = cast<MemRefType>(op.getResult().getType());
// This is called after a type conversion, which would have failed if this
// call fails.
FailureOr<unsigned> maybeAddressSpace =
getTypeConverter()->getMemRefAddressSpace(type);
- if (failed(maybeAddressSpace))
- return std::make_tuple(Value(), Value());
+ assert(succeeded(maybeAddressSpace) && "unsupported address space");
unsigned memSpace = *maybeAddressSpace;
Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), memSpace);
auto addressOf =
- rewriter.create<LLVM::AddressOfOp>(loc, ptrTy, getGlobalOp.getName());
+ rewriter.create<LLVM::AddressOfOp>(loc, ptrTy, op.getName());
// Get the address of the first element in the array by creating a GEP with
// the address of the GV as the base, and (rank + 1) number of 0 indices.
@@ -570,7 +809,13 @@ struct GetGlobalMemrefOpLowering : public AllocLikeOpLLVMLowering {
// Both allocated and aligned pointers are same. We could potentially stash
// a nullptr for the allocated pointer since we do not expect any dealloc.
- return std::make_tuple(deadBeefPtr, gep);
+ // Create the MemRef descriptor.
+ auto memRefDescriptor = this->createMemRefDescriptor(
+ loc, memRefType, deadBeefPtr, gep, sizes, strides, rewriter);
+
+ // Return the final value of the descriptor.
+ rewriter.replaceOp(op, {memRefDescriptor});
+ return success();
}
};
More information about the Mlir-commits
mailing list