[Mlir-commits] [mlir] [mlir][memref] Move `AllocLikeConversion.h` helpers into `MemRefToLLVM.cpp` (PR #136424)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sat Apr 19 04:26:26 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Matthias Springer (matthias-springer)
<details>
<summary>Changes</summary>
This commit moves code around: the helper functions/classes from `AllocLikeConversion.h` and `AllocLikeConversion.cpp` are moved into `MemRefToLLVM.cpp`. This simplifies the code a bit: fewer templatized functions, fewer function calls.
This commit also moves checks in `matchAndRewrite` to the beginning of the functions, so that patterns bail out before they start modifying any IR. This is in preparation for the One-Shot Dialect Conversion refactoring, which will disallow pattern rollbacks.
---
Patch is 35.56 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/136424.diff
4 Files Affected:
- (removed) mlir/include/mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h (-153)
- (removed) mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp (-195)
- (modified) mlir/lib/Conversion/MemRefToLLVM/CMakeLists.txt (-1)
- (modified) mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp (+301-56)
``````````diff
diff --git a/mlir/include/mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h b/mlir/include/mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h
deleted file mode 100644
index 8bf04219c759a..0000000000000
--- a/mlir/include/mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h
+++ /dev/null
@@ -1,153 +0,0 @@
-//===- AllocLikeConversion.h - Convert allocation ops to LLVM ---*- C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_CONVERSION_MEMREFTOLLVM_ALLOCLIKECONVERSION_H
-#define MLIR_CONVERSION_MEMREFTOLLVM_ALLOCLIKECONVERSION_H
-
-#include "mlir/Conversion/LLVMCommon/Pattern.h"
-
-namespace mlir {
-
-/// Lowering for memory allocation ops.
-struct AllocationOpLLVMLowering : public ConvertToLLVMPattern {
- using ConvertToLLVMPattern::createIndexAttrConstant;
- using ConvertToLLVMPattern::getIndexType;
- using ConvertToLLVMPattern::getVoidPtrType;
-
- explicit AllocationOpLLVMLowering(StringRef opName,
- const LLVMTypeConverter &converter,
- PatternBenefit benefit = 1)
- : ConvertToLLVMPattern(opName, &converter.getContext(), converter,
- benefit) {}
-
-protected:
- /// Computes the aligned value for 'input' as follows:
- /// bumped = input + alignement - 1
- /// aligned = bumped - bumped % alignment
- static Value createAligned(ConversionPatternRewriter &rewriter, Location loc,
- Value input, Value alignment);
-
- static MemRefType getMemRefResultType(Operation *op) {
- return cast<MemRefType>(op->getResult(0).getType());
- }
-
- /// Computes the alignment for the given memory allocation op.
- template <typename OpType>
- Value getAlignment(ConversionPatternRewriter &rewriter, Location loc,
- OpType op) const {
- MemRefType memRefType = op.getType();
- Value alignment;
- if (auto alignmentAttr = op.getAlignment()) {
- Type indexType = getIndexType();
- alignment =
- createIndexAttrConstant(rewriter, loc, indexType, *alignmentAttr);
- } else if (!memRefType.getElementType().isSignlessIntOrIndexOrFloat()) {
- // In the case where no alignment is specified, we may want to override
- // `malloc's` behavior. `malloc` typically aligns at the size of the
- // biggest scalar on a target HW. For non-scalars, use the natural
- // alignment of the LLVM type given by the LLVM DataLayout.
- alignment = getSizeInBytes(loc, memRefType.getElementType(), rewriter);
- }
- return alignment;
- }
-
- /// Computes the alignment for aligned_alloc used to allocate the buffer for
- /// the memory allocation op.
- ///
- /// Aligned_alloc requires the allocation size to be a power of two, and the
- /// allocation size to be a multiple of the alignment.
- template <typename OpType>
- int64_t alignedAllocationGetAlignment(ConversionPatternRewriter &rewriter,
- Location loc, OpType op,
- const DataLayout *defaultLayout) const {
- if (std::optional<uint64_t> alignment = op.getAlignment())
- return *alignment;
-
- // Whenever we don't have alignment set, we will use an alignment
- // consistent with the element type; since the allocation size has to be a
- // power of two, we will bump to the next power of two if it isn't.
- unsigned eltSizeBytes =
- getMemRefEltSizeInBytes(op.getType(), op, defaultLayout);
- return std::max(kMinAlignedAllocAlignment,
- llvm::PowerOf2Ceil(eltSizeBytes));
- }
-
- /// Allocates a memory buffer using an allocation method that doesn't
- /// guarantee alignment. Returns the pointer and its aligned value.
- std::tuple<Value, Value>
- allocateBufferManuallyAlign(ConversionPatternRewriter &rewriter, Location loc,
- Value sizeBytes, Operation *op,
- Value alignment) const;
-
- /// Allocates a memory buffer using an aligned allocation method.
- Value allocateBufferAutoAlign(ConversionPatternRewriter &rewriter,
- Location loc, Value sizeBytes, Operation *op,
- const DataLayout *defaultLayout,
- int64_t alignment) const;
-
-private:
- /// Computes the byte size for the MemRef element type.
- unsigned getMemRefEltSizeInBytes(MemRefType memRefType, Operation *op,
- const DataLayout *defaultLayout) const;
-
- /// Returns true if the memref size in bytes is known to be a multiple of
- /// factor.
- bool isMemRefSizeMultipleOf(MemRefType type, uint64_t factor, Operation *op,
- const DataLayout *defaultLayout) const;
-
- /// The minimum alignment to use with aligned_alloc (has to be a power of 2).
- static constexpr uint64_t kMinAlignedAllocAlignment = 16UL;
-};
-
-/// Lowering for AllocOp and AllocaOp.
-struct AllocLikeOpLLVMLowering : public AllocationOpLLVMLowering {
- explicit AllocLikeOpLLVMLowering(StringRef opName,
- const LLVMTypeConverter &converter,
- PatternBenefit benefit = 1)
- : AllocationOpLLVMLowering(opName, converter, benefit) {}
-
-protected:
- /// Allocates the underlying buffer. Returns the allocated pointer and the
- /// aligned pointer.
- virtual std::tuple<Value, Value>
- allocateBuffer(ConversionPatternRewriter &rewriter, Location loc, Value size,
- Operation *op) const = 0;
-
- /// Sets the flag 'requiresNumElements', specifying the Op requires the number
- /// of elements instead of the size in bytes.
- void setRequiresNumElements();
-
-private:
- // An `alloc` is converted into a definition of a memref descriptor value and
- // a call to `malloc` to allocate the underlying data buffer. The memref
- // descriptor is of the LLVM structure type where:
- // 1. the first element is a pointer to the allocated (typed) data buffer,
- // 2. the second element is a pointer to the (typed) payload, aligned to the
- // specified alignment,
- // 3. the remaining elements serve to store all the sizes and strides of the
- // memref using LLVM-converted `index` type.
- //
- // Alignment is performed by allocating `alignment` more bytes than
- // requested and shifting the aligned pointer relative to the allocated
- // memory. Note: `alignment - <minimum malloc alignment>` would actually be
- // sufficient. If alignment is unspecified, the two pointers are equal.
-
- // An `alloca` is converted into a definition of a memref descriptor value and
- // an llvm.alloca to allocate the underlying data buffer.
- LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const override;
-
- // Flag for specifying the Op requires the number of elements instead of the
- // size in bytes.
- bool requiresNumElements = false;
-};
-
-} // namespace mlir
-
-#endif // MLIR_CONVERSION_MEMREFTOLLVM_ALLOCLIKECONVERSION_H
diff --git a/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp b/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
deleted file mode 100644
index bad209a4ddecf..0000000000000
--- a/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
+++ /dev/null
@@ -1,195 +0,0 @@
-//===- AllocLikeConversion.cpp - LLVM conversion for alloc operations -----===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h"
-#include "mlir/Analysis/DataLayoutAnalysis.h"
-#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
-#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
-#include "mlir/IR/SymbolTable.h"
-
-using namespace mlir;
-
-static FailureOr<LLVM::LLVMFuncOp>
-getNotalignedAllocFn(const LLVMTypeConverter *typeConverter, Operation *module,
- Type indexType) {
- bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
- if (useGenericFn)
- return LLVM::lookupOrCreateGenericAllocFn(module, indexType);
-
- return LLVM::lookupOrCreateMallocFn(module, indexType);
-}
-
-static FailureOr<LLVM::LLVMFuncOp>
-getAlignedAllocFn(const LLVMTypeConverter *typeConverter, Operation *module,
- Type indexType) {
- bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
-
- if (useGenericFn)
- return LLVM::lookupOrCreateGenericAlignedAllocFn(module, indexType);
-
- return LLVM::lookupOrCreateAlignedAllocFn(module, indexType);
-}
-
-Value AllocationOpLLVMLowering::createAligned(
- ConversionPatternRewriter &rewriter, Location loc, Value input,
- Value alignment) {
- Value one = createIndexAttrConstant(rewriter, loc, alignment.getType(), 1);
- Value bump = rewriter.create<LLVM::SubOp>(loc, alignment, one);
- Value bumped = rewriter.create<LLVM::AddOp>(loc, input, bump);
- Value mod = rewriter.create<LLVM::URemOp>(loc, bumped, alignment);
- return rewriter.create<LLVM::SubOp>(loc, bumped, mod);
-}
-
-static Value castAllocFuncResult(ConversionPatternRewriter &rewriter,
- Location loc, Value allocatedPtr,
- MemRefType memRefType, Type elementPtrType,
- const LLVMTypeConverter &typeConverter) {
- auto allocatedPtrTy = cast<LLVM::LLVMPointerType>(allocatedPtr.getType());
- FailureOr<unsigned> maybeMemrefAddrSpace =
- typeConverter.getMemRefAddressSpace(memRefType);
- if (failed(maybeMemrefAddrSpace))
- return Value();
- unsigned memrefAddrSpace = *maybeMemrefAddrSpace;
- if (allocatedPtrTy.getAddressSpace() != memrefAddrSpace)
- allocatedPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
- loc, LLVM::LLVMPointerType::get(rewriter.getContext(), memrefAddrSpace),
- allocatedPtr);
- return allocatedPtr;
-}
-
-std::tuple<Value, Value> AllocationOpLLVMLowering::allocateBufferManuallyAlign(
- ConversionPatternRewriter &rewriter, Location loc, Value sizeBytes,
- Operation *op, Value alignment) const {
- if (alignment) {
- // Adjust the allocation size to consider alignment.
- sizeBytes = rewriter.create<LLVM::AddOp>(loc, sizeBytes, alignment);
- }
-
- MemRefType memRefType = getMemRefResultType(op);
- // Allocate the underlying buffer.
- Type elementPtrType = this->getElementPtrType(memRefType);
- assert(elementPtrType && "could not compute element ptr type");
- FailureOr<LLVM::LLVMFuncOp> allocFuncOp = getNotalignedAllocFn(
- getTypeConverter(), op->getParentWithTrait<OpTrait::SymbolTable>(),
- getIndexType());
- if (failed(allocFuncOp))
- return std::make_tuple(Value(), Value());
- auto results =
- rewriter.create<LLVM::CallOp>(loc, allocFuncOp.value(), sizeBytes);
-
- Value allocatedPtr =
- castAllocFuncResult(rewriter, loc, results.getResult(), memRefType,
- elementPtrType, *getTypeConverter());
- if (!allocatedPtr)
- return std::make_tuple(Value(), Value());
- Value alignedPtr = allocatedPtr;
- if (alignment) {
- // Compute the aligned pointer.
- Value allocatedInt =
- rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), allocatedPtr);
- Value alignmentInt = createAligned(rewriter, loc, allocatedInt, alignment);
- alignedPtr =
- rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, alignmentInt);
- }
-
- return std::make_tuple(allocatedPtr, alignedPtr);
-}
-
-unsigned AllocationOpLLVMLowering::getMemRefEltSizeInBytes(
- MemRefType memRefType, Operation *op,
- const DataLayout *defaultLayout) const {
- const DataLayout *layout = defaultLayout;
- if (const DataLayoutAnalysis *analysis =
- getTypeConverter()->getDataLayoutAnalysis()) {
- layout = &analysis->getAbove(op);
- }
- Type elementType = memRefType.getElementType();
- if (auto memRefElementType = dyn_cast<MemRefType>(elementType))
- return getTypeConverter()->getMemRefDescriptorSize(memRefElementType,
- *layout);
- if (auto memRefElementType = dyn_cast<UnrankedMemRefType>(elementType))
- return getTypeConverter()->getUnrankedMemRefDescriptorSize(
- memRefElementType, *layout);
- return layout->getTypeSize(elementType);
-}
-
-bool AllocationOpLLVMLowering::isMemRefSizeMultipleOf(
- MemRefType type, uint64_t factor, Operation *op,
- const DataLayout *defaultLayout) const {
- uint64_t sizeDivisor = getMemRefEltSizeInBytes(type, op, defaultLayout);
- for (unsigned i = 0, e = type.getRank(); i < e; i++) {
- if (type.isDynamicDim(i))
- continue;
- sizeDivisor = sizeDivisor * type.getDimSize(i);
- }
- return sizeDivisor % factor == 0;
-}
-
-Value AllocationOpLLVMLowering::allocateBufferAutoAlign(
- ConversionPatternRewriter &rewriter, Location loc, Value sizeBytes,
- Operation *op, const DataLayout *defaultLayout, int64_t alignment) const {
- Value allocAlignment =
- createIndexAttrConstant(rewriter, loc, getIndexType(), alignment);
-
- MemRefType memRefType = getMemRefResultType(op);
- // Function aligned_alloc requires size to be a multiple of alignment; we pad
- // the size to the next multiple if necessary.
- if (!isMemRefSizeMultipleOf(memRefType, alignment, op, defaultLayout))
- sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment);
-
- Type elementPtrType = this->getElementPtrType(memRefType);
- FailureOr<LLVM::LLVMFuncOp> allocFuncOp = getAlignedAllocFn(
- getTypeConverter(), op->getParentWithTrait<OpTrait::SymbolTable>(),
- getIndexType());
- if (failed(allocFuncOp))
- return Value();
- auto results = rewriter.create<LLVM::CallOp>(
- loc, allocFuncOp.value(), ValueRange({allocAlignment, sizeBytes}));
-
- return castAllocFuncResult(rewriter, loc, results.getResult(), memRefType,
- elementPtrType, *getTypeConverter());
-}
-
-void AllocLikeOpLLVMLowering::setRequiresNumElements() {
- requiresNumElements = true;
-}
-
-LogicalResult AllocLikeOpLLVMLowering::matchAndRewrite(
- Operation *op, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const {
- MemRefType memRefType = getMemRefResultType(op);
- if (!isConvertibleAndHasIdentityMaps(memRefType))
- return rewriter.notifyMatchFailure(op, "incompatible memref type");
- auto loc = op->getLoc();
-
- // Get actual sizes of the memref as values: static sizes are constant
- // values and dynamic sizes are passed to 'alloc' as operands. In case of
- // zero-dimensional memref, assume a scalar (size 1).
- SmallVector<Value, 4> sizes;
- SmallVector<Value, 4> strides;
- Value size;
-
- this->getMemRefDescriptorSizes(loc, memRefType, operands, rewriter, sizes,
- strides, size, !requiresNumElements);
-
- // Allocate the underlying buffer.
- auto [allocatedPtr, alignedPtr] =
- this->allocateBuffer(rewriter, loc, size, op);
-
- if (!allocatedPtr || !alignedPtr)
- return rewriter.notifyMatchFailure(loc,
- "underlying buffer allocation failed");
-
- // Create the MemRef descriptor.
- auto memRefDescriptor = this->createMemRefDescriptor(
- loc, memRefType, allocatedPtr, alignedPtr, sizes, strides, rewriter);
-
- // Return the final value of the descriptor.
- rewriter.replaceOp(op, {memRefDescriptor});
- return success();
-}
diff --git a/mlir/lib/Conversion/MemRefToLLVM/CMakeLists.txt b/mlir/lib/Conversion/MemRefToLLVM/CMakeLists.txt
index f0d95f5ada290..9da4b23d42f41 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/CMakeLists.txt
+++ b/mlir/lib/Conversion/MemRefToLLVM/CMakeLists.txt
@@ -1,5 +1,4 @@
add_mlir_conversion_library(MLIRMemRefToLLVM
- AllocLikeConversion.cpp
MemRefToLLVM.cpp
ADDITIONAL_HEADER_DIRS
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index cb4317ef1bcec..91a95bcdef465 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -13,7 +13,6 @@
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
-#include "mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
@@ -52,33 +51,247 @@ getFreeFn(const LLVMTypeConverter *typeConverter, ModuleOp module) {
return LLVM::lookupOrCreateFreeFn(module);
}
-struct AllocOpLowering : public AllocLikeOpLLVMLowering {
- AllocOpLowering(const LLVMTypeConverter &converter)
- : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
- converter) {}
- std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
- Location loc, Value sizeBytes,
- Operation *op) const override {
- return allocateBufferManuallyAlign(
- rewriter, loc, sizeBytes, op,
- getAlignment(rewriter, loc, cast<memref::AllocOp>(op)));
+static FailureOr<LLVM::LLVMFuncOp>
+getNotalignedAllocFn(const LLVMTypeConverter *typeConverter, Operation *module,
+ Type indexType) {
+ bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
+ if (useGenericFn)
+ return LLVM::lookupOrCreateGenericAllocFn(module, indexType);
+
+ return LLVM::lookupOrCreateMallocFn(module, indexType);
+}
+
+static FailureOr<LLVM::LLVMFuncOp>
+getAlignedAllocFn(const LLVMTypeConverter *typeConverter, Operation *module,
+ Type indexType) {
+ bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
+
+ if (useGenericFn)
+ return LLVM::lookupOrCreateGenericAlignedAllocFn(module, indexType);
+
+ return LLVM::lookupOrCreateAlignedAllocFn(module, indexType);
+}
+
+/// Computes the aligned value for 'input' as follows:
+/// bumped = input + alignment - 1
+/// aligned = bumped - bumped % alignment
+static Value createAligned(ConversionPatternRewriter &rewriter, Location loc,
+ Value input, Value alignment) {
+ Value one = rewriter.create<LLVM::ConstantOp>(loc, alignment.getType(),
+ rewriter.getIndexAttr(1));
+ Value bump = rewriter.create<LLVM::SubOp>(loc, alignment, one);
+ Value bumped = rewriter.create<LLVM::AddOp>(loc, input, bump);
+ Value mod = rewriter.create<LLVM::URemOp>(loc, bumped, alignment);
+ return rewriter.create<LLVM::SubOp>(loc, bumped, mod);
+}
+
+/// Computes the byte size for the MemRef element type.
+static unsigned getMemRefEltSizeInBytes(const LLVMTypeConverter *typeConverter,
+ MemRefType memRefType, Operation *op,
+ const DataLayout *defaultLayout) {
+ const DataLayout *layout = defaultLayout;
+ if (const DataLayoutAnalysis *analysis =
+ typeConverter->getDataLayoutAnalysis()) {
+ layout = &analysis->getAbove(op);
+ }
+ Type elementType = memRefType.getElementType();
+ if (auto memRefElementType = dyn_cast<MemRefType>(elementType))
+ return typeConverter->getMemRefDescriptorSize(memRefElementType, *layout);
+ if (auto memRefElementType = dyn_cast<UnrankedMemRefType>(elementType))
+ return typeConverter->getUnrankedMemRefDescriptorSize(memRefElementType,
+ *layout);
+ return layout->getTypeSize(elementType);
+}
+
+static Value castAllocFuncResult(ConversionPatternRewriter &rewriter,
+ Location loc, Value allocatedPtr,
+ MemRefType memRefType, Type elementPtrType,
+ const LLVMTypeConverter &typeConverter) {
+ auto allocatedPtrTy = ...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/136424
More information about the Mlir-commits mailing list