[Mlir-commits] [mlir] 9f13b93 - [mlir][memref] Add realloc op.
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Sep 21 08:04:09 PDT 2022
Author: bixia1
Date: 2022-09-21T08:04:00-07:00
New Revision: 9f13b9346b7c159fb40168ed0bb5dab6b0652836
URL: https://github.com/llvm/llvm-project/commit/9f13b9346b7c159fb40168ed0bb5dab6b0652836
DIFF: https://github.com/llvm/llvm-project/commit/9f13b9346b7c159fb40168ed0bb5dab6b0652836.diff
LOG: [mlir][memref] Add realloc op.
Add memref.realloc and canonicalization of the op. Add conversion patterns for
lowering the op to LLVM using unaligned alloc or aligned alloc based on the
conversion option.
Add filecheck tests for parsing and converting the op. Add an integration test.
Reviewed By: ftynse
Differential Revision: https://reviews.llvm.org/D133424
Added:
mlir/test/Integration/Dialect/Vector/CPU/test-realloc.mlir
Modified:
mlir/include/mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h
mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir
mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir
mlir/test/Dialect/MemRef/canonicalize.mlir
mlir/test/Dialect/MemRef/invalid.mlir
mlir/test/Dialect/MemRef/ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h b/mlir/include/mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h
index a612ecce5bc31..25cf034ea1e16 100644
--- a/mlir/include/mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h
+++ b/mlir/include/mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h
@@ -13,23 +13,100 @@
namespace mlir {
-/// Lowering for AllocOp and AllocaOp.
-struct AllocLikeOpLLVMLowering : public ConvertToLLVMPattern {
+/// Lowering for memory allocation ops.
+struct AllocationOpLLVMLowering : public ConvertToLLVMPattern {
using ConvertToLLVMPattern::createIndexConstant;
using ConvertToLLVMPattern::getIndexType;
using ConvertToLLVMPattern::getVoidPtrType;
- explicit AllocLikeOpLLVMLowering(StringRef opName,
- LLVMTypeConverter &converter)
+ explicit AllocationOpLLVMLowering(StringRef opName,
+ LLVMTypeConverter &converter)
: ConvertToLLVMPattern(opName, &converter.getContext(), converter) {}
protected:
- // Returns 'input' aligned up to 'alignment'. Computes
- // bumped = input + alignement - 1
- // aligned = bumped - bumped % alignment
+ /// Computes the aligned value for 'input' as follows:
+ /// bumped = input + alignment - 1
+ /// aligned = bumped - bumped % alignment
static Value createAligned(ConversionPatternRewriter &rewriter, Location loc,
Value input, Value alignment);
+ static MemRefType getMemRefResultType(Operation *op) {
+ return op->getResult(0).getType().cast<MemRefType>();
+ }
+
+ /// Computes the alignment for the given memory allocation op.
+ template <typename OpType>
+ Value getAlignment(ConversionPatternRewriter &rewriter, Location loc,
+ OpType op) const {
+ MemRefType memRefType = op.getType();
+ Value alignment;
+ if (auto alignmentAttr = op.getAlignment()) {
+ alignment = createIndexConstant(rewriter, loc, *alignmentAttr);
+ } else if (!memRefType.getElementType().isSignlessIntOrIndexOrFloat()) {
+ // In the case where no alignment is specified, we may want to override
+ // `malloc's` behavior. `malloc` typically aligns at the size of the
+ // biggest scalar on a target HW. For non-scalars, use the natural
+ // alignment of the LLVM type given by the LLVM DataLayout.
+ alignment = getSizeInBytes(loc, memRefType.getElementType(), rewriter);
+ }
+ return alignment;
+ }
+
+ /// Computes the alignment for aligned_alloc used to allocate the buffer for
+ /// the memory allocation op.
+ ///
+ /// Aligned_alloc requires the alignment to be a power of two, and the
+ /// allocation size to be a multiple of the alignment.
+ template <typename OpType>
+ int64_t alignedAllocationGetAlignment(ConversionPatternRewriter &rewriter,
+ Location loc, OpType op,
+ const DataLayout *defaultLayout) const {
+ if (Optional<uint64_t> alignment = op.getAlignment())
+ return *alignment;
+
+ // Whenever we don't have alignment set, we will use an alignment
+ // consistent with the element type; since the alignment has to be a
+ // power of two, we will bump to the next power of two if it isn't.
+ unsigned eltSizeBytes =
+ getMemRefEltSizeInBytes(op.getType(), op, defaultLayout);
+ return std::max(kMinAlignedAllocAlignment,
+ llvm::PowerOf2Ceil(eltSizeBytes));
+ }
+
+ /// Allocates a memory buffer using an allocation method that doesn't
+ /// guarantee alignment. Returns the pointer and its aligned value.
+ std::tuple<Value, Value>
+ allocateBufferManuallyAlign(ConversionPatternRewriter &rewriter, Location loc,
+ Value sizeBytes, Operation *op,
+ Value alignment) const;
+
+ /// Allocates a memory buffer using an aligned allocation method.
+ Value allocateBufferAutoAlign(ConversionPatternRewriter &rewriter,
+ Location loc, Value sizeBytes, Operation *op,
+ const DataLayout *defaultLayout,
+ int64_t alignment) const;
+
+private:
+ /// Computes the byte size for the MemRef element type.
+ unsigned getMemRefEltSizeInBytes(MemRefType memRefType, Operation *op,
+ const DataLayout *defaultLayout) const;
+
+ /// Returns true if the memref size in bytes is known to be a multiple of
+ /// factor.
+ bool isMemRefSizeMultipleOf(MemRefType type, uint64_t factor, Operation *op,
+ const DataLayout *defaultLayout) const;
+
+ /// The minimum alignment to use with aligned_alloc (has to be a power of 2).
+ static constexpr uint64_t kMinAlignedAllocAlignment = 16UL;
+};
+
+/// Lowering for AllocOp and AllocaOp.
+struct AllocLikeOpLLVMLowering : public AllocationOpLLVMLowering {
+ explicit AllocLikeOpLLVMLowering(StringRef opName,
+ LLVMTypeConverter &converter)
+ : AllocationOpLLVMLowering(opName, converter) {}
+
+protected:
/// Allocates the underlying buffer. Returns the allocated pointer and the
/// aligned pointer.
virtual std::tuple<Value, Value>
@@ -37,10 +114,6 @@ struct AllocLikeOpLLVMLowering : public ConvertToLLVMPattern {
Value sizeBytes, Operation *op) const = 0;
private:
- static MemRefType getMemRefResultType(Operation *op) {
- return op->getResult(0).getType().cast<MemRefType>();
- }
-
// An `alloc` is converted into a definition of a memref descriptor value and
// a call to `malloc` to allocate the underlying data buffer. The memref
// descriptor is of the LLVM structure type where:
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 2ea19bbb0216c..d223fb5eef998 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -178,6 +178,99 @@ def MemRef_AllocOp : AllocLikeOp<"alloc", DefaultResource, []> {
let hasVerifier = 1;
}
+//===----------------------------------------------------------------------===//
+// ReallocOp
+//===----------------------------------------------------------------------===//
+
+
+def MemRef_ReallocOp : MemRef_Op<"realloc"> {
+ let summary = "memory reallocation operation";
+ let description = [{
+ The `realloc` operation changes the size of a memory region. The memory
+ region is specified by a 1D source memref and the size of the new memory
+ region is specified by a 1D result memref type and an optional dynamic Value
+ of `Index` type. The source and the result memref must be in the same memory
+ space and have the same element type.
+
+ The operation may move the memory region to a new location. In this case,
+ the content of the memory block is preserved up to the lesser of the new
+ and old sizes. If the new size is larger, the value of the extended memory
+ is undefined. This is consistent with the ISO C realloc.
+
+ The operation returns an SSA value for the memref.
+
+ Example:
+
+ ```mlir
+ %0 = memref.realloc %src : memref<64xf32> to memref<124xf32>
+ ```
+
+ The source memref may have a dynamic shape, in which case, the compiler will
+ generate code to extract its size from the runtime data structure for the
+ memref.
+
+ ```mlir
+ %1 = memref.realloc %src : memref<?xf32> to memref<124xf32>
+ ```
+
+ If the result memref has a dynamic shape, a result dimension operand is
+ needed to specify its dynamic dimension. In the example below, the SSA value
+ '%d' specifies the unknown dimension of the result memref.
+
+ ```mlir
+ %2 = memref.realloc %src(%d) : memref<?xf32> to memref<?xf32>
+ ```
+
+ An optional `alignment` attribute may be specified to ensure that the
+ region of memory that will be indexed is aligned at the specified byte
+ boundary. This is consistent with the fact that memref.alloc supports such
+ an optional alignment attribute. Note that in the ISO C standard, neither
+ malloc nor realloc supports alignment; there is aligned_alloc, but no
+ aligned_realloc.
+
+ ```mlir
+ %3 = memref.realloc %src {alignment = 8} : memref<64xf32> to memref<124xf32>
+ ```
+
+ Referencing the memref through the old SSA value after realloc is undefined
+ behavior.
+
+ ```mlir
+ %new = memref.realloc %old : memref<64xf32> to memref<124xf32>
+ %4 = memref.load %new[%index] // ok
+ %5 = memref.load %old[%index] // undefined behavior
+ ```
+ }];
+
+ let arguments = (ins MemRefRankOf<[AnyType], [1]>:$source,
+ Optional<Index>:$dynamicResultSize,
+ ConfinedAttr<OptionalAttr<I64Attr>,
+ [IntMinValue<0>]>:$alignment);
+
+ let results = (outs MemRefRankOf<[AnyType], [1]>);
+
+ let builders = [
+ OpBuilder<(ins "MemRefType":$resultType,
+ "Value":$source,
+ CArg<"Value", "Value()">:$dynamicResultSize), [{
+ return build($_builder, $_state, resultType, source, dynamicResultSize,
+ IntegerAttr());
+ }]>];
+
+ let extraClassDeclaration = [{
+ /// The result of a realloc is always a memref.
+ MemRefType getType() { return getResult().getType().cast<MemRefType>(); }
+ }];
+
+ let assemblyFormat = [{
+ $source (`(` $dynamicResultSize^ `)`)? attr-dict
+ `:` type($source) `to` type(results)
+ }];
+
+ let hasCanonicalizer = 1;
+ let hasVerifier = 1;
+}
+
//===----------------------------------------------------------------------===//
// AllocaOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp b/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
index d2778b785be02..4a5be48707097 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
@@ -7,11 +7,40 @@
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h"
+#include "mlir/Analysis/DataLayoutAnalysis.h"
+#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
using namespace mlir;
-Value AllocLikeOpLLVMLowering::createAligned(
+namespace {
+// TODO: Fix the LLVM utilities for looking up functions to take Operation*
+// with SymbolTable trait instead of ModuleOp and make similar change here. This
+// allows call sites to use getParentWithTrait<OpTrait::SymbolTable> instead
+// of getParentOfType<ModuleOp> to pass down the operation.
+LLVM::LLVMFuncOp getNotalignedAllocFn(LLVMTypeConverter *typeConverter,
+ ModuleOp module, Type indexType) {
+ bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
+
+ if (useGenericFn)
+ return LLVM::lookupOrCreateGenericAllocFn(module, indexType);
+
+ return LLVM::lookupOrCreateMallocFn(module, indexType);
+}
+
+LLVM::LLVMFuncOp getAlignedAllocFn(LLVMTypeConverter *typeConverter,
+ ModuleOp module, Type indexType) {
+ bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
+
+ if (useGenericFn)
+ return LLVM::lookupOrCreateGenericAlignedAllocFn(module, indexType);
+
+ return LLVM::lookupOrCreateAlignedAllocFn(module, indexType);
+}
+
+} // end namespace
+
+Value AllocationOpLLVMLowering::createAligned(
ConversionPatternRewriter &rewriter, Location loc, Value input,
Value alignment) {
Value one = createIndexAttrConstant(rewriter, loc, alignment.getType(), 1);
@@ -21,6 +50,88 @@ Value AllocLikeOpLLVMLowering::createAligned(
return rewriter.create<LLVM::SubOp>(loc, bumped, mod);
}
+std::tuple<Value, Value> AllocationOpLLVMLowering::allocateBufferManuallyAlign(
+ ConversionPatternRewriter &rewriter, Location loc, Value sizeBytes,
+ Operation *op, Value alignment) const {
+ if (alignment) {
+ // Adjust the allocation size to consider alignment.
+ sizeBytes = rewriter.create<LLVM::AddOp>(loc, sizeBytes, alignment);
+ }
+
+ MemRefType memRefType = getMemRefResultType(op);
+ // Allocate the underlying buffer.
+ Type elementPtrType = this->getElementPtrType(memRefType);
+ LLVM::LLVMFuncOp allocFuncOp = getNotalignedAllocFn(
+ getTypeConverter(), op->getParentOfType<ModuleOp>(), getIndexType());
+ auto results = rewriter.create<LLVM::CallOp>(loc, allocFuncOp, sizeBytes);
+ Value allocatedPtr = rewriter.create<LLVM::BitcastOp>(loc, elementPtrType,
+ results.getResult());
+
+ Value alignedPtr = allocatedPtr;
+ if (alignment) {
+ // Compute the aligned pointer.
+ Value allocatedInt =
+ rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), allocatedPtr);
+ Value alignmentInt = createAligned(rewriter, loc, allocatedInt, alignment);
+ alignedPtr =
+ rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, alignmentInt);
+ }
+
+ return std::make_tuple(allocatedPtr, alignedPtr);
+}
+
+unsigned AllocationOpLLVMLowering::getMemRefEltSizeInBytes(
+ MemRefType memRefType, Operation *op,
+ const DataLayout *defaultLayout) const {
+ const DataLayout *layout = defaultLayout;
+ if (const DataLayoutAnalysis *analysis =
+ getTypeConverter()->getDataLayoutAnalysis()) {
+ layout = &analysis->getAbove(op);
+ }
+ Type elementType = memRefType.getElementType();
+ if (auto memRefElementType = elementType.dyn_cast<MemRefType>())
+ return getTypeConverter()->getMemRefDescriptorSize(memRefElementType,
+ *layout);
+ if (auto memRefElementType = elementType.dyn_cast<UnrankedMemRefType>())
+ return getTypeConverter()->getUnrankedMemRefDescriptorSize(
+ memRefElementType, *layout);
+ return layout->getTypeSize(elementType);
+}
+
+bool AllocationOpLLVMLowering::isMemRefSizeMultipleOf(
+ MemRefType type, uint64_t factor, Operation *op,
+ const DataLayout *defaultLayout) const {
+ uint64_t sizeDivisor = getMemRefEltSizeInBytes(type, op, defaultLayout);
+ for (unsigned i = 0, e = type.getRank(); i < e; i++) {
+ if (ShapedType::isDynamic(type.getDimSize(i)))
+ continue;
+ sizeDivisor = sizeDivisor * type.getDimSize(i);
+ }
+ return sizeDivisor % factor == 0;
+}
+
+Value AllocationOpLLVMLowering::allocateBufferAutoAlign(
+ ConversionPatternRewriter &rewriter, Location loc, Value sizeBytes,
+ Operation *op, const DataLayout *defaultLayout, int64_t alignment) const {
+ Value allocAlignment = createIndexConstant(rewriter, loc, alignment);
+
+ MemRefType memRefType = getMemRefResultType(op);
+ // Function aligned_alloc requires size to be a multiple of alignment; we pad
+ // the size to the next multiple if necessary.
+ if (!isMemRefSizeMultipleOf(memRefType, alignment, op, defaultLayout))
+ sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment);
+
+ Type elementPtrType = this->getElementPtrType(memRefType);
+ LLVM::LLVMFuncOp allocFuncOp = getAlignedAllocFn(
+ getTypeConverter(), op->getParentOfType<ModuleOp>(), getIndexType());
+ auto results = rewriter.create<LLVM::CallOp>(
+ loc, allocFuncOp, ValueRange({allocAlignment, sizeBytes}));
+ Value allocatedPtr = rewriter.create<LLVM::BitcastOp>(loc, elementPtrType,
+ results.getResult());
+
+ return allocatedPtr;
+}
+
LogicalResult AllocLikeOpLLVMLowering::matchAndRewrite(
Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 65f533ba9fa42..241f62eb1d59a 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -36,63 +36,25 @@ bool isStaticStrideOrOffset(int64_t strideOrOffset) {
return !ShapedType::isDynamicStrideOrOffset(strideOrOffset);
}
+LLVM::LLVMFuncOp getFreeFn(LLVMTypeConverter *typeConverter, ModuleOp module) {
+ bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
+
+ if (useGenericFn)
+ return LLVM::lookupOrCreateGenericFreeFn(module);
+
+ return LLVM::lookupOrCreateFreeFn(module);
+}
+
struct AllocOpLowering : public AllocLikeOpLLVMLowering {
AllocOpLowering(LLVMTypeConverter &converter)
: AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
converter) {}
-
- LLVM::LLVMFuncOp getAllocFn(ModuleOp module) const {
- bool useGenericFn = getTypeConverter()->getOptions().useGenericFunctions;
-
- if (useGenericFn)
- return LLVM::lookupOrCreateGenericAllocFn(module, getIndexType());
-
- return LLVM::lookupOrCreateMallocFn(module, getIndexType());
- }
-
std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
Location loc, Value sizeBytes,
Operation *op) const override {
- // Heap allocations.
- memref::AllocOp allocOp = cast<memref::AllocOp>(op);
- MemRefType memRefType = allocOp.getType();
-
- Value alignment;
- if (auto alignmentAttr = allocOp.getAlignment()) {
- alignment = createIndexConstant(rewriter, loc, *alignmentAttr);
- } else if (!memRefType.getElementType().isSignlessIntOrIndexOrFloat()) {
- // In the case where no alignment is specified, we may want to override
- // `malloc's` behavior. `malloc` typically aligns at the size of the
- // biggest scalar on a target HW. For non-scalars, use the natural
- // alignment of the LLVM type given by the LLVM DataLayout.
- alignment = getSizeInBytes(loc, memRefType.getElementType(), rewriter);
- }
-
- if (alignment) {
- // Adjust the allocation size to consider alignment.
- sizeBytes = rewriter.create<LLVM::AddOp>(loc, sizeBytes, alignment);
- }
-
- // Allocate the underlying buffer and store a pointer to it in the MemRef
- // descriptor.
- Type elementPtrType = this->getElementPtrType(memRefType);
- auto allocFuncOp = getAllocFn(allocOp->getParentOfType<ModuleOp>());
- auto results = rewriter.create<LLVM::CallOp>(loc, allocFuncOp, sizeBytes);
- Value allocatedPtr = rewriter.create<LLVM::BitcastOp>(loc, elementPtrType,
- results.getResult());
-
- Value alignedPtr = allocatedPtr;
- if (alignment) {
- // Compute the aligned type pointer.
- Value allocatedInt =
- rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), allocatedPtr);
- Value alignmentInt =
- createAligned(rewriter, loc, allocatedInt, alignment);
- alignedPtr =
- rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, alignmentInt);
- }
-
- return std::make_tuple(allocatedPtr, alignedPtr);
+ return allocateBufferManuallyAlign(
+ rewriter, loc, sizeBytes, op,
+ getAlignment(rewriter, loc, cast<memref::AllocOp>(op)));
}
};
@@ -100,90 +62,17 @@ struct AlignedAllocOpLowering : public AllocLikeOpLLVMLowering {
AlignedAllocOpLowering(LLVMTypeConverter &converter)
: AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
converter) {}
-
- /// Returns the memref's element size in bytes using the data layout active at
- /// `op`.
- // TODO: there are other places where this is used. Expose publicly?
- unsigned getMemRefEltSizeInBytes(MemRefType memRefType, Operation *op) const {
- const DataLayout *layout = &defaultLayout;
- if (const DataLayoutAnalysis *analysis =
- getTypeConverter()->getDataLayoutAnalysis()) {
- layout = &analysis->getAbove(op);
- }
- Type elementType = memRefType.getElementType();
- if (auto memRefElementType = elementType.dyn_cast<MemRefType>())
- return getTypeConverter()->getMemRefDescriptorSize(memRefElementType,
- *layout);
- if (auto memRefElementType = elementType.dyn_cast<UnrankedMemRefType>())
- return getTypeConverter()->getUnrankedMemRefDescriptorSize(
- memRefElementType, *layout);
- return layout->getTypeSize(elementType);
- }
-
- /// Returns true if the memref size in bytes is known to be a multiple of
- /// factor assuming the data layout active at `op`.
- bool isMemRefSizeMultipleOf(MemRefType type, uint64_t factor,
- Operation *op) const {
- uint64_t sizeDivisor = getMemRefEltSizeInBytes(type, op);
- for (unsigned i = 0, e = type.getRank(); i < e; i++) {
- if (ShapedType::isDynamic(type.getDimSize(i)))
- continue;
- sizeDivisor = sizeDivisor * type.getDimSize(i);
- }
- return sizeDivisor % factor == 0;
- }
-
- /// Returns the alignment to be used for the allocation call itself.
- /// aligned_alloc requires the allocation size to be a power of two, and the
- /// allocation size to be a multiple of alignment,
- int64_t getAllocationAlignment(memref::AllocOp allocOp) const {
- if (Optional<uint64_t> alignment = allocOp.getAlignment())
- return *alignment;
-
- // Whenever we don't have alignment set, we will use an alignment
- // consistent with the element type; since the allocation size has to be a
- // power of two, we will bump to the next power of two if it already isn't.
- auto eltSizeBytes = getMemRefEltSizeInBytes(allocOp.getType(), allocOp);
- return std::max(kMinAlignedAllocAlignment,
- llvm::PowerOf2Ceil(eltSizeBytes));
- }
-
- LLVM::LLVMFuncOp getAllocFn(ModuleOp module) const {
- bool useGenericFn = getTypeConverter()->getOptions().useGenericFunctions;
-
- if (useGenericFn)
- return LLVM::lookupOrCreateGenericAlignedAllocFn(module, getIndexType());
-
- return LLVM::lookupOrCreateAlignedAllocFn(module, getIndexType());
- }
-
std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
Location loc, Value sizeBytes,
Operation *op) const override {
- // Heap allocations.
- memref::AllocOp allocOp = cast<memref::AllocOp>(op);
- MemRefType memRefType = allocOp.getType();
- int64_t alignment = getAllocationAlignment(allocOp);
- Value allocAlignment = createIndexConstant(rewriter, loc, alignment);
-
- // aligned_alloc requires size to be a multiple of alignment; we will pad
- // the size to the next multiple if necessary.
- if (!isMemRefSizeMultipleOf(memRefType, alignment, op))
- sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment);
-
- Type elementPtrType = this->getElementPtrType(memRefType);
- auto allocFuncOp = getAllocFn(allocOp->getParentOfType<ModuleOp>());
- auto results = rewriter.create<LLVM::CallOp>(
- loc, allocFuncOp, ValueRange({allocAlignment, sizeBytes}));
- Value allocatedPtr = rewriter.create<LLVM::BitcastOp>(loc, elementPtrType,
- results.getResult());
-
- return std::make_tuple(allocatedPtr, allocatedPtr);
+ Value ptr = allocateBufferAutoAlign(
+ rewriter, loc, sizeBytes, op, &defaultLayout,
+ alignedAllocationGetAlignment(rewriter, loc, cast<memref::AllocOp>(op),
+ &defaultLayout));
+ return std::make_tuple(ptr, ptr);
}
- /// The minimum alignment to use with aligned_alloc (has to be a power of 2).
- static constexpr uint64_t kMinAlignedAllocAlignment = 16UL;
-
+private:
/// Default layout to use in absence of the corresponding analysis.
DataLayout defaultLayout;
};
@@ -212,6 +101,160 @@ struct AllocaOpLowering : public AllocLikeOpLLVMLowering {
}
};
+/// The base class for lowering realloc op, to support the implementation of
+/// realloc via allocation methods that may or may not support alignment.
+/// A derived class should provide an implementation of allocateBuffer using
+/// the underlying allocation methods.
+struct ReallocOpLoweringBase : public AllocationOpLLVMLowering {
+ using OpAdaptor = typename memref::ReallocOp::Adaptor;
+
+ ReallocOpLoweringBase(LLVMTypeConverter &converter)
+ : AllocationOpLLVMLowering(memref::ReallocOp::getOperationName(),
+ converter) {}
+
+ /// Allocates the new buffer. Returns the allocated pointer and the
+ /// aligned pointer.
+ virtual std::tuple<Value, Value>
+ allocateBuffer(ConversionPatternRewriter &rewriter, Location loc,
+ Value sizeBytes, memref::ReallocOp op) const = 0;
+
+ LogicalResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const final {
+ return matchAndRewrite(cast<memref::ReallocOp>(op),
+ OpAdaptor(operands, op->getAttrDictionary()),
+ rewriter);
+ }
+
+ // A `realloc` is converted as follows:
+ // If new_size > old_size
+ // 1. allocates a new buffer
+ // 2. copies the content of the old buffer to the new buffer
+ // 3. releases the old buffer
+ // 4. updates the buffer pointers in the memref descriptor
+ // Update the size in the memref descriptor
+ // Alignment request is handled by allocating `alignment` more bytes than
+ // requested and shifting the aligned pointer relative to the allocated
+ // memory.
+ LogicalResult matchAndRewrite(memref::ReallocOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const {
+ OpBuilder::InsertionGuard guard(rewriter);
+ Location loc = op.getLoc();
+
+ auto computeNumElements =
+ [&](MemRefType type, function_ref<Value()> getDynamicSize) -> Value {
+ // Compute number of elements.
+ int64_t size = type.getShape()[0];
+ Value numElements = ((size == ShapedType::kDynamicSize)
+ ? getDynamicSize()
+ : createIndexConstant(rewriter, loc, size));
+ Type indexType = getIndexType();
+ if (numElements.getType() != indexType)
+ numElements = typeConverter->materializeTargetConversion(
+ rewriter, loc, indexType, numElements);
+ return numElements;
+ };
+
+ MemRefDescriptor desc(adaptor.getSource());
+ Value oldDesc = desc;
+
+ // Split the block right before the current op into two blocks.
+ Block *currentBlock = rewriter.getInsertionBlock();
+ Block *block =
+ rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
+ // Add a block argument by creating an empty block with the argument type
+ // and then merging the block into the empty block.
+ Block *endBlock = rewriter.createBlock(
+ block->getParent(), Region::iterator(block), oldDesc.getType(), loc);
+ rewriter.mergeBlocks(block, endBlock, {});
+ // Add a new block for the true branch of the conditional statement we will
+ // add.
+ Block *trueBlock = rewriter.createBlock(
+ currentBlock->getParent(), std::next(Region::iterator(currentBlock)));
+
+ rewriter.setInsertionPointToEnd(currentBlock);
+ Value src = op.getSource();
+ auto srcType = src.getType().dyn_cast<MemRefType>();
+ Value srcNumElements = computeNumElements(
+ srcType, [&]() -> Value { return desc.size(rewriter, loc, 0); });
+ auto dstType = op.getType().cast<MemRefType>();
+ Value dstNumElements = computeNumElements(
+ dstType, [&]() -> Value { return op.getDynamicResultSize(); });
+ Value cond = rewriter.create<LLVM::ICmpOp>(
+ loc, IntegerType::get(rewriter.getContext(), 1),
+ LLVM::ICmpPredicate::ugt, dstNumElements, srcNumElements);
+ rewriter.create<LLVM::CondBrOp>(loc, cond, trueBlock, ArrayRef<Value>(),
+ endBlock, ValueRange{oldDesc});
+
+ rewriter.setInsertionPointToStart(trueBlock);
+ Value sizeInBytes = getSizeInBytes(loc, dstType.getElementType(), rewriter);
+ // Compute total byte size.
+ auto dstByteSize =
+ rewriter.create<LLVM::MulOp>(loc, dstNumElements, sizeInBytes);
+ // Allocate a new buffer.
+ auto [dstRawPtr, dstAlignedPtr] =
+ allocateBuffer(rewriter, loc, dstByteSize, op);
+ // Copy the data from the old buffer to the new buffer.
+ Value srcAlignedPtr = desc.alignedPtr(rewriter, loc);
+ Value isVolatile =
+ rewriter.create<LLVM::ConstantOp>(loc, rewriter.getBoolAttr(false));
+ auto toVoidPtr = [&](Value ptr) -> Value {
+ return rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr);
+ };
+ rewriter.create<LLVM::MemcpyOp>(loc, toVoidPtr(dstAlignedPtr),
+ toVoidPtr(srcAlignedPtr), dstByteSize,
+ isVolatile);
+ // Deallocate the old buffer.
+ LLVM::LLVMFuncOp freeFunc =
+ getFreeFn(getTypeConverter(), op->getParentOfType<ModuleOp>());
+ rewriter.create<LLVM::CallOp>(loc, freeFunc,
+ toVoidPtr(desc.allocatedPtr(rewriter, loc)));
+ // Replace the old buffer addresses in the MemRefDescriptor with the new
+ // buffer addresses.
+ desc.setAllocatedPtr(rewriter, loc, dstRawPtr);
+ desc.setAlignedPtr(rewriter, loc, dstAlignedPtr);
+ rewriter.create<LLVM::BrOp>(loc, Value(desc), endBlock);
+
+ rewriter.setInsertionPoint(op);
+ // Update the memref size.
+ MemRefDescriptor newDesc(endBlock->getArgument(0));
+ newDesc.setSize(rewriter, loc, 0, dstNumElements);
+ rewriter.replaceOp(op, {newDesc});
+ return success();
+ }
+
+private:
+ using ConvertToLLVMPattern::matchAndRewrite;
+};
+
+struct ReallocOpLowering : public ReallocOpLoweringBase {
+ ReallocOpLowering(LLVMTypeConverter &converter)
+ : ReallocOpLoweringBase(converter) {}
+ std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
+ Location loc, Value sizeBytes,
+ memref::ReallocOp op) const override {
+ return allocateBufferManuallyAlign(rewriter, loc, sizeBytes, op,
+ getAlignment(rewriter, loc, op));
+ }
+};
+
+struct AlignedReallocOpLowering : public ReallocOpLoweringBase {
+ AlignedReallocOpLowering(LLVMTypeConverter &converter)
+ : ReallocOpLoweringBase(converter) {}
+ std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
+ Location loc, Value sizeBytes,
+ memref::ReallocOp op) const override {
+ Value ptr = allocateBufferAutoAlign(
+ rewriter, loc, sizeBytes, op, &defaultLayout,
+ alignedAllocationGetAlignment(rewriter, loc, op, &defaultLayout));
+ return std::make_tuple(ptr, ptr);
+ }
+
+private:
+ /// Default layout to use in absence of the corresponding analysis.
+ DataLayout defaultLayout;
+};
+
struct AllocaScopeOpLowering
: public ConvertOpToLLVMPattern<memref::AllocaScopeOp> {
using ConvertOpToLLVMPattern<memref::AllocaScopeOp>::ConvertOpToLLVMPattern;
@@ -316,20 +359,12 @@ struct DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
explicit DeallocOpLowering(LLVMTypeConverter &converter)
: ConvertOpToLLVMPattern<memref::DeallocOp>(converter) {}
- LLVM::LLVMFuncOp getFreeFn(ModuleOp module) const {
- bool useGenericFn = getTypeConverter()->getOptions().useGenericFunctions;
-
- if (useGenericFn)
- return LLVM::lookupOrCreateGenericFreeFn(module);
-
- return LLVM::lookupOrCreateFreeFn(module);
- }
-
LogicalResult
matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Insert the `free` declaration if it is not already present.
- auto freeFunc = getFreeFn(op->getParentOfType<ModuleOp>());
+ LLVM::LLVMFuncOp freeFunc =
+ getFreeFn(getTypeConverter(), op->getParentOfType<ModuleOp>());
MemRefDescriptor memref(adaptor.getMemref());
Value casted = rewriter.create<LLVM::BitcastOp>(
op.getLoc(), getVoidPtrType(),
@@ -2060,9 +2095,11 @@ void mlir::populateMemRefToLLVMConversionPatterns(LLVMTypeConverter &converter,
// clang-format on
auto allocLowering = converter.getOptions().allocLowering;
if (allocLowering == LowerToLLVMOptions::AllocLowering::AlignedAlloc)
- patterns.add<AlignedAllocOpLowering, DeallocOpLowering>(converter);
+ patterns.add<AlignedAllocOpLowering, AlignedReallocOpLowering,
+ DeallocOpLowering>(converter);
else if (allocLowering == LowerToLLVMOptions::AllocLowering::Malloc)
- patterns.add<AllocOpLowering, DeallocOpLowering>(converter);
+ patterns.add<AllocOpLowering, ReallocOpLowering, DeallocOpLowering>(
+ converter);
}
namespace {
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 227352b28de5a..aa302bd957e0a 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -246,6 +246,52 @@ void AllocaOp::getCanonicalizationPatterns(RewritePatternSet &results,
context);
}
+//===----------------------------------------------------------------------===//
+// ReallocOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult ReallocOp::verify() {
+ auto sourceType = getOperand(0).getType().cast<MemRefType>();
+ MemRefType resultType = getType();
+
+ // The source memref should have identity layout (or none).
+ if (!sourceType.getLayout().isIdentity())
+ return emitError("unsupported layout for source memref type ")
+ << sourceType;
+
+ // The result memref should have identity layout (or none).
+ if (!resultType.getLayout().isIdentity())
+ return emitError("unsupported layout for result memref type ")
+ << resultType;
+
+ // The source memref and the result memref should be in the same memory space.
+ if (sourceType.getMemorySpace() != resultType.getMemorySpace())
+ return emitError("different memory spaces specified for source memref "
+ "type ")
+ << sourceType << " and result memref type " << resultType;
+
+ // The source memref and the result memref should have the same element type.
+ if (sourceType.getElementType() != resultType.getElementType())
+ return emitError("different element types specified for source memref "
+ "type ")
+ << sourceType << " and result memref type " << resultType;
+
+ // Verify that we have the dynamic dimension operand when it is needed.
+ if (resultType.getNumDynamicDims() && !getDynamicResultSize())
+ return emitError("missing dimension operand for result type ")
+ << resultType;
+ if (!resultType.getNumDynamicDims() && getDynamicResultSize())
+ return emitError("unnecessary dimension operand for result type ")
+ << resultType;
+
+ return success();
+}
+
+void ReallocOp::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {
+ results.add<SimplifyDeadAlloc<ReallocOp>>(context);
+}
+
//===----------------------------------------------------------------------===//
// AllocaScopeOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir b/mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir
index 22bec31f9d91e..3cd8fb3334640 100644
--- a/mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir
@@ -626,3 +626,125 @@ func.func @ranked_unranked() {
return
}
+// -----
+
+// CHECK-LABEL: func.func @realloc_dynamic(
+// CHECK-SAME: %[[arg0:.*]]: memref<?xf32>,
+// CHECK-SAME: %[[arg1:.*]]: index) -> memref<?xf32> {
+func.func @realloc_dynamic(%in: memref<?xf32>, %d: index) -> memref<?xf32>{
+// CHECK: %[[descriptor:.*]] = builtin.unrealized_conversion_cast %[[arg0]]
+// CHECK: %[[drc_dim:.*]] = llvm.extractvalue %[[descriptor]][3, 0]
+// CHECK: %[[dst_dim:.*]] = builtin.unrealized_conversion_cast %[[arg1]] : index to i64
+// CHECK: %[[cond:.*]] = llvm.icmp "ugt" %[[dst_dim]], %[[drc_dim]] : i64
+// CHECK: llvm.cond_br %[[cond]], ^bb1, ^bb2(%[[descriptor]]
+// CHECK: ^bb1:
+// CHECK: %[[dst_null:.*]] = llvm.mlir.null : !llvm.ptr<f32>
+// CHECK: %[[dst_gep:.*]] = llvm.getelementptr %[[dst_null]][1]
+// CHECK: %[[dst_es:.*]] = llvm.ptrtoint %[[dst_gep]] : !llvm.ptr<f32> to i64
+// CHECK: %[[dst_size:.*]] = llvm.mul %[[dst_dim]], %[[dst_es]]
+// CHECK: %[[new_buffer_raw:.*]] = llvm.call @malloc(%[[dst_size]])
+// CHECK: %[[new_buffer:.*]] = llvm.bitcast %[[new_buffer_raw]] : !llvm.ptr<i8> to !llvm.ptr<f32>
+// CHECK: %[[old_buffer_aligned:.*]] = llvm.extractvalue %[[descriptor]][1]
+// CHECK: %[[volatile:.*]] = llvm.mlir.constant(false) : i1
+// CHECK-DAG: %[[new_buffer_void:.*]] = llvm.bitcast %[[new_buffer]] : !llvm.ptr<f32> to !llvm.ptr<i8>
+// CHECK-DAG: %[[old_buffer_void:.*]] = llvm.bitcast %[[old_buffer_aligned]] : !llvm.ptr<f32> to !llvm.ptr<i8>
+// CHECK: "llvm.intr.memcpy"(%[[new_buffer_void]], %[[old_buffer_void]], %[[dst_size]], %[[volatile]])
+// CHECK: %[[old_buffer_unaligned:.*]] = llvm.extractvalue %[[descriptor]][0]
+// CHECK: %[[old_buffer_unaligned_void:.*]] = llvm.bitcast %[[old_buffer_unaligned]] : !llvm.ptr<f32> to !llvm.ptr<i8>
+// CHECK: llvm.call @free(%[[old_buffer_unaligned_void]])
+// CHECK: %[[descriptor_update1:.*]] = llvm.insertvalue %[[new_buffer]], %[[descriptor]][0]
+// CHECK: %[[descriptor_update2:.*]] = llvm.insertvalue %[[new_buffer]], %[[descriptor_update1]][1]
+// CHECK: llvm.br ^bb2(%[[descriptor_update2]]
+// CHECK: ^bb2(%[[descriptor_update3:.*]]: !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>):
+// CHECK: %[[descriptor_update4:.*]] = llvm.insertvalue %[[dst_dim]], %[[descriptor_update3]][3, 0]
+// CHECK: %[[descriptor_update5:.*]] = builtin.unrealized_conversion_cast %[[descriptor_update4]]
+// CHECK: return %[[descriptor_update5]] : memref<?xf32>
+
+ %out = memref.realloc %in(%d) : memref<?xf32> to memref<?xf32>
+ return %out : memref<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @realloc_dynamic_alignment(
+// CHECK-SAME: %[[arg0:.*]]: memref<?xf32>,
+// CHECK-SAME: %[[arg1:.*]]: index) -> memref<?xf32> {
+// ALIGNED-ALLOC-LABEL: func.func @realloc_dynamic_alignment(
+// ALIGNED-ALLOC-SAME: %[[arg0:.*]]: memref<?xf32>,
+// ALIGNED-ALLOC-SAME: %[[arg1:.*]]: index) -> memref<?xf32> {
+func.func @realloc_dynamic_alignment(%in: memref<?xf32>, %d: index) -> memref<?xf32>{
+// CHECK: %[[descriptor:.*]] = builtin.unrealized_conversion_cast %[[arg0]]
+// CHECK: %[[drc_dim:.*]] = llvm.extractvalue %[[descriptor]][3, 0]
+// CHECK: %[[dst_dim:.*]] = builtin.unrealized_conversion_cast %[[arg1]] : index to i64
+// CHECK: %[[cond:.*]] = llvm.icmp "ugt" %[[dst_dim]], %[[drc_dim]] : i64
+// CHECK: llvm.cond_br %[[cond]], ^bb1, ^bb2(%[[descriptor]]
+// CHECK: ^bb1:
+// CHECK: %[[dst_null:.*]] = llvm.mlir.null : !llvm.ptr<f32>
+// CHECK: %[[dst_gep:.*]] = llvm.getelementptr %[[dst_null]][1]
+// CHECK: %[[dst_es:.*]] = llvm.ptrtoint %[[dst_gep]] : !llvm.ptr<f32> to i64
+// CHECK: %[[dst_size:.*]] = llvm.mul %[[dst_dim]], %[[dst_es]]
+// CHECK: %[[alignment:.*]] = llvm.mlir.constant(8 : index) : i64
+// CHECK: %[[adjust_dst_size:.*]] = llvm.add %[[dst_size]], %[[alignment]]
+// CHECK: %[[new_buffer_raw:.*]] = llvm.call @malloc(%[[adjust_dst_size]])
+// CHECK: %[[new_buffer_unaligned:.*]] = llvm.bitcast %[[new_buffer_raw]] : !llvm.ptr<i8> to !llvm.ptr<f32>
+// CHECK: %[[new_buffer_int:.*]] = llvm.ptrtoint %[[new_buffer_unaligned]] : !llvm.ptr<f32>
+// CHECK: %[[const_1:.*]] = llvm.mlir.constant(1 : index) : i64
+// CHECK: %[[alignment_m1:.*]] = llvm.sub %[[alignment]], %[[const_1]]
+// CHECK: %[[ptr_alignment_m1:.*]] = llvm.add %[[new_buffer_int]], %[[alignment_m1]]
+// CHECK: %[[padding:.*]] = llvm.urem %[[ptr_alignment_m1]], %[[alignment]]
+// CHECK: %[[new_buffer_aligned_int:.*]] = llvm.sub %[[ptr_alignment_m1]], %[[padding]]
+// CHECK: %[[new_buffer_aligned:.*]] = llvm.inttoptr %[[new_buffer_aligned_int]] : i64 to !llvm.ptr<f32>
+// CHECK: %[[old_buffer_aligned:.*]] = llvm.extractvalue %[[descriptor]][1]
+// CHECK: %[[volatile:.*]] = llvm.mlir.constant(false) : i1
+// CHECK-DAG: %[[new_buffer_void:.*]] = llvm.bitcast %[[new_buffer_aligned]] : !llvm.ptr<f32> to !llvm.ptr<i8>
+// CHECK-DAG: %[[old_buffer_void:.*]] = llvm.bitcast %[[old_buffer_aligned]] : !llvm.ptr<f32> to !llvm.ptr<i8>
+// CHECK: "llvm.intr.memcpy"(%[[new_buffer_void]], %[[old_buffer_void]], %[[dst_size]], %[[volatile]])
+// CHECK: %[[old_buffer_unaligned:.*]] = llvm.extractvalue %[[descriptor]][0]
+// CHECK: %[[old_buffer_unaligned_void:.*]] = llvm.bitcast %[[old_buffer_unaligned]] : !llvm.ptr<f32> to !llvm.ptr<i8>
+// CHECK: llvm.call @free(%[[old_buffer_unaligned_void]])
+// CHECK: %[[descriptor_update1:.*]] = llvm.insertvalue %[[new_buffer_unaligned]], %[[descriptor]][0]
+// CHECK: %[[descriptor_update2:.*]] = llvm.insertvalue %[[new_buffer_aligned]], %[[descriptor_update1]][1]
+// CHECK: llvm.br ^bb2(%[[descriptor_update2]]
+// CHECK: ^bb2(%[[descriptor_update3:.*]]: !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>):
+// CHECK: %[[descriptor_update4:.*]] = llvm.insertvalue %[[dst_dim]], %[[descriptor_update3]][3, 0]
+// CHECK: %[[descriptor_update5:.*]] = builtin.unrealized_conversion_cast %[[descriptor_update4]]
+// CHECK: return %[[descriptor_update5]] : memref<?xf32>
+
+// ALIGNED-ALLOC: %[[descriptor:.*]] = builtin.unrealized_conversion_cast %[[arg0]]
+// ALIGNED-ALLOC: %[[drc_dim:.*]] = llvm.extractvalue %[[descriptor]][3, 0]
+// ALIGNED-ALLOC: %[[dst_dim:.*]] = builtin.unrealized_conversion_cast %[[arg1]] : index to i64
+// ALIGNED-ALLOC: %[[cond:.*]] = llvm.icmp "ugt" %[[dst_dim]], %[[drc_dim]] : i64
+// ALIGNED-ALLOC: llvm.cond_br %[[cond]], ^bb1, ^bb2(%[[descriptor]]
+// ALIGNED-ALLOC: ^bb1:
+// ALIGNED-ALLOC: %[[dst_null:.*]] = llvm.mlir.null : !llvm.ptr<f32>
+// ALIGNED-ALLOC: %[[dst_gep:.*]] = llvm.getelementptr %[[dst_null]][1]
+// ALIGNED-ALLOC: %[[dst_es:.*]] = llvm.ptrtoint %[[dst_gep]] : !llvm.ptr<f32> to i64
+// ALIGNED-ALLOC: %[[dst_size:.*]] = llvm.mul %[[dst_dim]], %[[dst_es]]
+// ALIGNED-ALLOC-DAG: %[[alignment:.*]] = llvm.mlir.constant(8 : index) : i64
+// ALIGNED-ALLOC-DAG: %[[const_1:.*]] = llvm.mlir.constant(1 : index) : i64
+// ALIGNED-ALLOC: %[[alignment_m1:.*]] = llvm.sub %[[alignment]], %[[const_1]]
+// ALIGNED-ALLOC: %[[size_alignment_m1:.*]] = llvm.add %[[dst_size]], %[[alignment_m1]]
+// ALIGNED-ALLOC: %[[padding:.*]] = llvm.urem %[[size_alignment_m1]], %[[alignment]]
+// ALIGNED-ALLOC: %[[adjust_dst_size:.*]] = llvm.sub %[[size_alignment_m1]], %[[padding]]
+// ALIGNED-ALLOC: %[[new_buffer_raw:.*]] = llvm.call @aligned_alloc(%[[alignment]], %[[adjust_dst_size]])
+// ALIGNED-ALLOC: %[[new_buffer_aligned:.*]] = llvm.bitcast %[[new_buffer_raw]] : !llvm.ptr<i8> to !llvm.ptr<f32>
+// ALIGNED-ALLOC: %[[old_buffer_aligned:.*]] = llvm.extractvalue %[[descriptor]][1]
+// ALIGNED-ALLOC: %[[volatile:.*]] = llvm.mlir.constant(false) : i1
+// ALIGNED-ALLOC-DAG: %[[new_buffer_void:.*]] = llvm.bitcast %[[new_buffer_aligned]] : !llvm.ptr<f32> to !llvm.ptr<i8>
+// ALIGNED-ALLOC-DAG: %[[old_buffer_void:.*]] = llvm.bitcast %[[old_buffer_aligned]] : !llvm.ptr<f32> to !llvm.ptr<i8>
+// ALIGNED-ALLOC: "llvm.intr.memcpy"(%[[new_buffer_void]], %[[old_buffer_void]], %[[dst_size]], %[[volatile]])
+// ALIGNED-ALLOC: %[[old_buffer_unaligned:.*]] = llvm.extractvalue %[[descriptor]][0]
+// ALIGNED-ALLOC: %[[old_buffer_unaligned_void:.*]] = llvm.bitcast %[[old_buffer_unaligned]] : !llvm.ptr<f32> to !llvm.ptr<i8>
+// ALIGNED-ALLOC: llvm.call @free(%[[old_buffer_unaligned_void]])
+// ALIGNED-ALLOC: %[[descriptor_update1:.*]] = llvm.insertvalue %[[new_buffer_aligned]], %[[descriptor]][0]
+// ALIGNED-ALLOC: %[[descriptor_update2:.*]] = llvm.insertvalue %[[new_buffer_aligned]], %[[descriptor_update1]][1]
+// ALIGNED-ALLOC: llvm.br ^bb2(%[[descriptor_update2]] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>)
+// ALIGNED-ALLOC: ^bb2(%[[descriptor_update3:.*]]: !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>):
+// ALIGNED-ALLOC: %[[descriptor_update4:.*]] = llvm.insertvalue %[[dst_dim]], %[[descriptor_update3]][3, 0]
+// ALIGNED-ALLOC: %[[descriptor_update5:.*]] = builtin.unrealized_conversion_cast %[[descriptor_update4]]
+// ALIGNED-ALLOC: return %[[descriptor_update5]] : memref<?xf32>
+
+ %out = memref.realloc %in(%d) {alignment = 8} : memref<?xf32> to memref<?xf32>
+ return %out : memref<?xf32>
+}
+
diff --git a/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir b/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir
index 2520294ed03dc..cabc84f57847b 100644
--- a/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir
@@ -338,3 +338,87 @@ func.func @memref.reshape_index(%arg0: memref<?x?xi32>, %shape: memref<1xindex>)
%1 = memref.reshape %arg0(%shape) : (memref<?x?xi32>, memref<1xindex>) -> memref<?xi32>
return %1 : memref<?xi32>
}
+
+// -----
+
+// CHECK-LABEL: func.func @realloc_static(
+// CHECK-SAME: %[[arg0:.*]]: memref<2xi32>) -> memref<4xi32> {
+func.func @realloc_static(%in: memref<2xi32>) -> memref<4xi32>{
+// CHECK: %[[descriptor:.*]] = builtin.unrealized_conversion_cast %[[arg0]] : memref<2xi32> to !llvm.struct<(ptr<i32>, ptr<i32>, i64, array<1 x i64>, array<1 x i64>)>
+// CHECK: %[[src_dim:.*]] = llvm.mlir.constant(2 : index) : i64
+// CHECK: %[[dst_dim:.*]] = llvm.mlir.constant(4 : index) : i64
+// CHECK: %[[cond:.*]] = llvm.icmp "ugt" %[[dst_dim]], %[[src_dim]]
+// CHECK: llvm.cond_br %[[cond]], ^bb1, ^bb2(%[[descriptor]]
+// CHECK: ^bb1:
+// CHECK: %[[dst_null:.*]] = llvm.mlir.null : !llvm.ptr<i32>
+// CHECK: %[[dst_gep:.*]] = llvm.getelementptr %[[dst_null]][1]
+// CHECK: %[[dst_es:.*]] = llvm.ptrtoint %[[dst_gep]] : !llvm.ptr<i32> to i64
+// CHECK: %[[dst_size:.*]] = llvm.mul %[[dst_dim]], %[[dst_es]]
+// CHECK: %[[new_buffer_raw:.*]] = llvm.call @malloc(%[[dst_size]])
+// CHECK: %[[new_buffer:.*]] = llvm.bitcast %[[new_buffer_raw]] : !llvm.ptr<i8> to !llvm.ptr<i32>
+// CHECK: %[[old_buffer_aligned:.*]] = llvm.extractvalue %[[descriptor]][1]
+// CHECK: %[[volatile:.*]] = llvm.mlir.constant(false) : i1
+// CHECK-DAG: %[[new_buffer_void:.*]] = llvm.bitcast %[[new_buffer]] : !llvm.ptr<i32> to !llvm.ptr<i8>
+// CHECK-DAG: %[[old_buffer_void:.*]] = llvm.bitcast %[[old_buffer_aligned]] : !llvm.ptr<i32> to !llvm.ptr<i8>
+// CHECK: "llvm.intr.memcpy"(%[[new_buffer_void]], %[[old_buffer_void]], %[[dst_size]], %[[volatile]])
+// CHECK: %[[old_buffer_unaligned:.*]] = llvm.extractvalue %[[descriptor]][0]
+// CHECK: %[[old_buffer_unaligned_void:.*]] = llvm.bitcast %[[old_buffer_unaligned]] : !llvm.ptr<i32> to !llvm.ptr<i8>
+// CHECK: llvm.call @free(%[[old_buffer_unaligned_void]])
+// CHECK: %[[descriptor_update1:.*]] = llvm.insertvalue %[[new_buffer]], %[[descriptor]][0]
+// CHECK: %[[descriptor_update2:.*]] = llvm.insertvalue %[[new_buffer]], %[[descriptor_update1]][1]
+// CHECK: llvm.br ^bb2(%[[descriptor_update2]] : !llvm.struct<(ptr<i32>, ptr<i32>, i64, array<1 x i64>, array<1 x i64>)>)
+// CHECK: ^bb2(%[[descriptor_update3:.*]]: !llvm.struct<(ptr<i32>, ptr<i32>, i64, array<1 x i64>, array<1 x i64>)>):
+// CHECK: %[[descriptor_update4:.*]] = llvm.insertvalue %[[dst_dim]], %[[descriptor_update3]][3, 0]
+// CHECK: %[[descriptor_update5:.*]] = builtin.unrealized_conversion_cast %[[descriptor_update4]]
+// CHECK: return %[[descriptor_update5]] : memref<4xi32>
+
+ %out = memref.realloc %in : memref<2xi32> to memref<4xi32>
+ return %out : memref<4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @realloc_static_alignment(
+// CHECK-SAME: %[[arg0:.*]]: memref<2xf32>) -> memref<4xf32> {
+func.func @realloc_static_alignment(%in: memref<2xf32>) -> memref<4xf32>{
+// CHECK: %[[descriptor:.*]] = builtin.unrealized_conversion_cast %[[arg0]] : memref<2xf32> to !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+// CHECK: %[[src_dim:.*]] = llvm.mlir.constant(2 : index) : i64
+// CHECK: %[[dst_dim:.*]] = llvm.mlir.constant(4 : index) : i64
+// CHECK: %[[cond:.*]] = llvm.icmp "ugt" %[[dst_dim]], %[[src_dim]] : i64
+// CHECK: llvm.cond_br %[[cond]], ^bb1, ^bb2(%[[descriptor]]
+// CHECK: ^bb1:
+// CHECK: %[[dst_null:.*]] = llvm.mlir.null : !llvm.ptr<f32>
+// CHECK: %[[dst_gep:.*]] = llvm.getelementptr %[[dst_null]][1]
+// CHECK: %[[dst_es:.*]] = llvm.ptrtoint %[[dst_gep]] : !llvm.ptr<f32> to i64
+// CHECK: %[[dst_size:.*]] = llvm.mul %[[dst_dim]], %[[dst_es]]
+// CHECK: %[[alignment:.*]] = llvm.mlir.constant(8 : index) : i64
+// CHECK: %[[adjust_dst_size:.*]] = llvm.add %[[dst_size]], %[[alignment]]
+// CHECK: %[[new_buffer_raw:.*]] = llvm.call @malloc(%[[adjust_dst_size]])
+// CHECK: %[[new_buffer_unaligned:.*]] = llvm.bitcast %[[new_buffer_raw]] : !llvm.ptr<i8> to !llvm.ptr<f32>
+// CHECK: %[[new_buffer_int:.*]] = llvm.ptrtoint %[[new_buffer_unaligned]] : !llvm.ptr<f32>
+// CHECK: %[[const_1:.*]] = llvm.mlir.constant(1 : index) : i64
+// CHECK: %[[alignment_m1:.*]] = llvm.sub %[[alignment]], %[[const_1]]
+// CHECK: %[[ptr_alignment_m1:.*]] = llvm.add %[[new_buffer_int]], %[[alignment_m1]]
+// CHECK: %[[padding:.*]] = llvm.urem %[[ptr_alignment_m1]], %[[alignment]]
+// CHECK: %[[new_buffer_aligned_int:.*]] = llvm.sub %[[ptr_alignment_m1]], %[[padding]]
+// CHECK: %[[new_buffer_aligned:.*]] = llvm.inttoptr %[[new_buffer_aligned_int]] : i64 to !llvm.ptr<f32>
+// CHECK: %[[old_buffer_aligned:.*]] = llvm.extractvalue %[[descriptor]][1]
+// CHECK: %[[volatile:.*]] = llvm.mlir.constant(false) : i1
+// CHECK-DAG: %[[new_buffer_void:.*]] = llvm.bitcast %[[new_buffer_aligned]] : !llvm.ptr<f32> to !llvm.ptr<i8>
+// CHECK-DAG: %[[old_buffer_void:.*]] = llvm.bitcast %[[old_buffer_aligned]] : !llvm.ptr<f32> to !llvm.ptr<i8>
+// CHECK: "llvm.intr.memcpy"(%[[new_buffer_void]], %[[old_buffer_void]], %[[dst_size]], %[[volatile]])
+// CHECK: %[[old_buffer_unaligned:.*]] = llvm.extractvalue %[[descriptor]][0]
+// CHECK: %[[old_buffer_unaligned_void:.*]] = llvm.bitcast %[[old_buffer_unaligned]] : !llvm.ptr<f32> to !llvm.ptr<i8>
+// CHECK: llvm.call @free(%[[old_buffer_unaligned_void]])
+// CHECK: %[[descriptor_update1:.*]] = llvm.insertvalue %[[new_buffer_unaligned]], %[[descriptor]][0]
+// CHECK: %[[descriptor_update2:.*]] = llvm.insertvalue %[[new_buffer_aligned]], %[[descriptor_update1]][1]
+// CHECK: llvm.br ^bb2(%[[descriptor_update2]]
+// CHECK: ^bb2(%[[descriptor_update3:.*]]: !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>):
+// CHECK: %[[descriptor_update4:.*]] = llvm.insertvalue %[[dst_dim]], %[[descriptor_update3]][3, 0]
+// CHECK: %[[descriptor_update5:.*]] = builtin.unrealized_conversion_cast %[[descriptor_update4]]
+// CHECK: return %[[descriptor_update5]] : memref<4xf32>
+
+
+ %out = memref.realloc %in {alignment = 8} : memref<2xf32> to memref<4xf32>
+ return %out : memref<4xf32>
+}
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 3835cb3221c93..fc52e9c7b750d 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -824,3 +824,16 @@ func.func @canonicalize_rank_reduced_subview(%arg0 : memref<8x?xf32>,
// CHECK-SAME: %[[ARG1:.+]]: index
// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG0]][0, 0] [1, %[[ARG1]]] [1, 1]
// CHECK-SAME: memref<8x?xf32> to memref<?xf32, strided<[1], offset: ?>>
+
+// ----
+
+// CHECK-LABEL: func @memref_realloc_dead
+// CHECK-SAME: %[[SRC:[0-9a-z]+]]: memref<2xf32>
+// CHECK-NOT: memref.realloc
+// CHECK: return %[[SRC]]
+func.func @memref_realloc_dead(%src : memref<2xf32>, %v : f32) -> memref<2xf32>{
+ %0 = memref.realloc %src : memref<2xf32> to memref<4xf32>
+ %i2 = arith.constant 2 : index
+ memref.store %v, %0[%i2] : memref<4xf32>
+ return %src : memref<2xf32>
+}
diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index 6dd5439cd056f..344f22cb7d2eb 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -992,3 +992,37 @@ func.func @atomic_yield_type_mismatch(%I: memref<10xf32>, %i : index) {
}
return
}
+
+// -----
+
+#map0 = affine_map<(d0) -> (d0 floordiv 8, d0 mod 8)>
+func.func @memref_realloc_layout(%src : memref<256xf32, #map0>) -> memref<?xf32>{
+ // expected-error @+1 {{unsupported layout}}
+ %0 = memref.realloc %src : memref<256xf32, #map0> to memref<?xf32>
+ return %0 : memref<?xf32>
+}
+
+// -----
+
+func.func @memref_realloc_sizes_1(%src : memref<2xf32>) -> memref<?xf32>{
+ // expected-error @+1 {{missing dimension operand}}
+ %0 = memref.realloc %src : memref<2xf32> to memref<?xf32>
+ return %0 : memref<?xf32>
+}
+
+// -----
+
+func.func @memref_realloc_sizes_2(%src : memref<?xf32>, %d : index)
+ -> memref<4xf32>{
+ // expected-error @+1 {{unnecessary dimension operand}}
+ %0 = memref.realloc %src(%d) : memref<?xf32> to memref<4xf32>
+ return %0 : memref<4xf32>
+}
+
+// -----
+
+func.func @memref_realloc_type(%src : memref<256xf32>) -> memref<?xi32>{
+ // expected-error @+1 {{different element types}}
+ %0 = memref.realloc %src : memref<256xf32> to memref<?xi32>
+ return %0 : memref<?xi32>
+}
diff --git a/mlir/test/Dialect/MemRef/ops.mlir b/mlir/test/Dialect/MemRef/ops.mlir
index 7d469df556454..ca64a6ac2a8ea 100644
--- a/mlir/test/Dialect/MemRef/ops.mlir
+++ b/mlir/test/Dialect/MemRef/ops.mlir
@@ -347,3 +347,30 @@ func.func @extract_strided_metadata(%memref : memref<10x?xf32>)
return %m2: memref<?x?xf32, strided<[?, ?], offset: ?>>
}
+
+// -----
+
+// CHECK-LABEL: func @memref_realloc_ss
+func.func @memref_realloc_ss(%src : memref<2xf32>) -> memref<4xf32>{
+ %0 = memref.realloc %src : memref<2xf32> to memref<4xf32>
+ return %0 : memref<4xf32>
+}
+
+// CHECK-LABEL: func @memref_realloc_sd
+func.func @memref_realloc_sd(%src : memref<2xf32>, %d : index) -> memref<?xf32>{
+ %0 = memref.realloc %src(%d) : memref<2xf32> to memref<?xf32>
+ return %0 : memref<?xf32>
+}
+
+// CHECK-LABEL: func @memref_realloc_ds
+func.func @memref_realloc_ds(%src : memref<?xf32>) -> memref<4xf32>{
+ %0 = memref.realloc %src: memref<?xf32> to memref<4xf32>
+ return %0 : memref<4xf32>
+}
+
+// CHECK-LABEL: func @memref_realloc_dd
+func.func @memref_realloc_dd(%src : memref<?xf32>, %d: index)
+ -> memref<?xf32>{
+ %0 = memref.realloc %src(%d) : memref<?xf32> to memref<?xf32>
+ return %0 : memref<?xf32>
+}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-realloc.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-realloc.mlir
new file mode 100644
index 0000000000000..518b1554505f6
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-realloc.mlir
@@ -0,0 +1,57 @@
+// RUN: mlir-opt %s -convert-scf-to-cf -convert-vector-to-llvm -convert-memref-to-llvm -convert-func-to-llvm -reconcile-unrealized-casts |\
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
+// RUN: -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext
+// RUN: mlir-opt %s -convert-scf-to-cf -convert-vector-to-llvm -convert-memref-to-llvm='use-aligned-alloc=1' -convert-func-to-llvm -arith-expand -reconcile-unrealized-casts |\
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
+// RUN: -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | FileCheck %s
+
+func.func @entry() {
+ // Set up memory.
+ %c0 = arith.constant 0: index
+ %c1 = arith.constant 1: index
+ %c8 = arith.constant 8: index
+ %A = memref.alloc() : memref<8xf32>
+ scf.for %i = %c0 to %c8 step %c1 {
+ %i32 = arith.index_cast %i : index to i32
+ %fi = arith.sitofp %i32 : i32 to f32
+ memref.store %fi, %A[%i] : memref<8xf32>
+ }
+
+ %d0 = arith.constant -1.0 : f32
+ %Av = vector.transfer_read %A[%c0], %d0: memref<8xf32>, vector<8xf32>
+ vector.print %Av : vector<8xf32>
+ // CHECK: ( 0, 1, 2, 3, 4, 5, 6, 7 )
+
+ // Realloc with static sizes.
+ %B = memref.realloc %A : memref<8xf32> to memref<10xf32>
+
+ %c10 = arith.constant 10: index
+ scf.for %i = %c8 to %c10 step %c1 {
+ %i32 = arith.index_cast %i : index to i32
+ %fi = arith.sitofp %i32 : i32 to f32
+ memref.store %fi, %B[%i] : memref<10xf32>
+ }
+
+ %Bv = vector.transfer_read %B[%c0], %d0: memref<10xf32>, vector<10xf32>
+ vector.print %Bv : vector<10xf32>
+ // CHECK: ( 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 )
+
+ // Realloc with dynamic sizes.
+ %Bd = memref.cast %B : memref<10xf32> to memref<?xf32>
+ %c13 = arith.constant 13: index
+ %Cd = memref.realloc %Bd(%c13) : memref<?xf32> to memref<?xf32>
+ %C = memref.cast %Cd : memref<?xf32> to memref<13xf32>
+
+ scf.for %i = %c10 to %c13 step %c1 {
+ %i32 = arith.index_cast %i : index to i32
+ %fi = arith.sitofp %i32 : i32 to f32
+ memref.store %fi, %C[%i] : memref<13xf32>
+ }
+
+ %Cv = vector.transfer_read %C[%c0], %d0: memref<13xf32>, vector<13xf32>
+ vector.print %Cv : vector<13xf32>
+ // CHECK: ( 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12 )
+
+ memref.dealloc %C : memref<13xf32>
+ return
+}
More information about the Mlir-commits
mailing list