[Mlir-commits] [mlir] 665371d - [mlir] Split alloc-like op LLVM lowerings into base and separate derived classes.
Christian Sigg
llvmlistbot at llvm.org
Mon Oct 5 08:36:13 PDT 2020
Author: Christian Sigg
Date: 2020-10-05T17:36:01+02:00
New Revision: 665371d0b29910d7fba618a707d6b732e2037ee2
URL: https://github.com/llvm/llvm-project/commit/665371d0b29910d7fba618a707d6b732e2037ee2
DIFF: https://github.com/llvm/llvm-project/commit/665371d0b29910d7fba618a707d6b732e2037ee2.diff
LOG: [mlir] Split alloc-like op LLVM lowerings into base and separate derived classes.
The previous code did the lowering to alloca, malloc, and aligned_alloc
in a single class with different code paths that are difficult to follow.
This change moves the common code into a base class and adds a separate
derived class per lowering target that contains the target-specific logic.
Reviewed By: ftynse
Differential Revision: https://reviews.llvm.org/D88696
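To make the new structure concrete, below is a simplified, standalone C++ sketch of the split, not the patch code itself. The Value, Location, and Rewriter placeholders stand in for the real MLIR classes (mlir::Value, mlir::Location, ConversionPatternRewriter), and the method bodies only hint at what each derived lowering emits; the class and hook names match the ones introduced in the diff below.

#include <tuple>

struct Value {};     // stand-in for mlir::Value
struct Location {};  // stand-in for mlir::Location
struct Rewriter {};  // stand-in for ConversionPatternRewriter

// Base class: computes the allocation size and builds the memref descriptor;
// only the actual buffer allocation is deferred to a virtual hook.
struct AllocLikeOpLowering {
  virtual ~AllocLikeOpLowering() = default;

  void rewrite(Rewriter &rewriter, Location loc, Value cumulativeSize) const {
    Value allocatedPtr, alignedPtr;
    std::tie(allocatedPtr, alignedPtr) =
        allocateBuffer(rewriter, loc, cumulativeSize);
    // ... the base class then fills the memref descriptor from
    // allocatedPtr and alignedPtr ...
  }

protected:
  // Each lowering target supplies its own allocation strategy and returns
  // the allocated pointer and the aligned pointer.
  virtual std::tuple<Value, Value>
  allocateBuffer(Rewriter &rewriter, Location loc,
                 Value cumulativeSize) const = 0;
};

// malloc-based lowering: over-allocate, then align the pointer manually.
struct AllocOpLowering : AllocLikeOpLowering {
  std::tuple<Value, Value>
  allocateBuffer(Rewriter &, Location, Value) const override {
    Value allocated;  // would hold the result of the generated malloc call
    Value aligned;    // would hold that pointer bumped up to the alignment
    return std::make_tuple(allocated, aligned);
  }
};

// aligned_alloc-based lowering: the returned pointer is already aligned.
struct AlignedAllocOpLowering : AllocLikeOpLowering {
  std::tuple<Value, Value>
  allocateBuffer(Rewriter &, Location, Value) const override {
    Value allocated;  // would hold the result of the generated aligned_alloc
    return std::make_tuple(allocated, allocated);
  }
};

// alloca-based lowering: stack allocation, no separate aligned pointer.
struct AllocaOpLowering : AllocLikeOpLowering {
  std::tuple<Value, Value>
  allocateBuffer(Rewriter &, Location, Value) const override {
    Value allocated;  // would hold the result of the generated llvm.alloca
    return std::make_tuple(allocated, allocated);
  }
};

In the patch itself, populateStdToLLVMMemoryConversionPatterns registers AlignedAllocOpLowering when the converter's useAlignedAlloc option is set and AllocOpLowering otherwise.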
Added:
Modified:
mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir
mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
index d98a0ff6efb3..645f4cd26581 100644
--- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
+++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
@@ -412,6 +412,7 @@ class ConvertToLLVMPattern : public ConversionPattern {
LLVMTypeConverter &typeConverter,
PatternBenefit benefit = 1);
+protected:
/// Returns the LLVM dialect.
LLVM::LLVMDialect &getDialect() const;
@@ -419,6 +420,10 @@ class ConvertToLLVMPattern : public ConversionPattern {
/// defined by the used type converter.
LLVM::LLVMType getIndexType() const;
+ /// Gets the MLIR type wrapping the LLVM integer type whose bit width
+ /// corresponds to that of an LLVM pointer type.
+ LLVM::LLVMType getIntPtrType(unsigned addressSpace = 0) const;
+
/// Gets the MLIR type wrapping the LLVM void type.
LLVM::LLVMType getVoidType() const;
@@ -470,6 +475,15 @@ class ConvertToLLVMPattern : public ConversionPattern {
ArrayRef<Value> shape,
ConversionPatternRewriter &rewriter) const;
+ /// Creates and populates the memref descriptor struct given all its fields.
+ /// 'strides' can be either dynamic (kDynamicStrideOrOffset) or static, but
+ /// not a mix of the two.
+ MemRefDescriptor
+ createMemRefDescriptor(Location loc, MemRefType memRefType,
+ Value allocatedPtr, Value alignedPtr, uint64_t offset,
+ ArrayRef<int64_t> strides, ArrayRef<Value> sizes,
+ ConversionPatternRewriter &rewriter) const;
+
protected:
/// Reference to the type converter, with potential extensions.
LLVMTypeConverter &typeConverter;
diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index 731eab0c28df..75d07f35d226 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -872,6 +872,13 @@ LLVM::LLVMType ConvertToLLVMPattern::getIndexType() const {
return typeConverter.getIndexType();
}
+LLVM::LLVMType
+ConvertToLLVMPattern::getIntPtrType(unsigned addressSpace) const {
+ return LLVM::LLVMType::getIntNTy(
+ &typeConverter.getContext(),
+ typeConverter.getPointerBitwidth(addressSpace));
+}
+
LLVM::LLVMType ConvertToLLVMPattern::getVoidType() const {
return LLVM::LLVMType::getVoidTy(&typeConverter.getContext());
}
@@ -911,12 +918,12 @@ Value ConvertToLLVMPattern::getStridedElementPtr(
Value base = memRefDescriptor.alignedPtr(rewriter, loc);
Value offsetValue = offset == MemRefType::getDynamicStrideOrOffset()
? memRefDescriptor.offset(rewriter, loc)
- : this->createIndexConstant(rewriter, loc, offset);
+ : createIndexConstant(rewriter, loc, offset);
for (int i = 0, e = indices.size(); i < e; ++i) {
Value stride = strides[i] == MemRefType::getDynamicStrideOrOffset()
? memRefDescriptor.stride(rewriter, loc, i)
- : this->createIndexConstant(rewriter, loc, strides[i]);
+ : createIndexConstant(rewriter, loc, strides[i]);
Value additionalOffset =
rewriter.create<LLVM::MulOp>(loc, indices[i], stride);
offsetValue =
@@ -973,19 +980,69 @@ Value ConvertToLLVMPattern::getSizeInBytes(
}
Value ConvertToLLVMPattern::getCumulativeSizeInBytes(
- Location loc, Type elementType, ArrayRef<Value> sizes,
+ Location loc, Type elementType, ArrayRef<Value> shape,
ConversionPatternRewriter &rewriter) const {
// Compute the total number of memref elements.
Value cumulativeSizeInBytes =
- sizes.empty() ? createIndexConstant(rewriter, loc, 1) : sizes.front();
- for (unsigned i = 1, e = sizes.size(); i < e; ++i)
+ shape.empty() ? createIndexConstant(rewriter, loc, 1) : shape.front();
+ for (unsigned i = 1, e = shape.size(); i < e; ++i)
cumulativeSizeInBytes = rewriter.create<LLVM::MulOp>(
- loc, getIndexType(), ArrayRef<Value>{cumulativeSizeInBytes, sizes[i]});
+ loc, getIndexType(), ArrayRef<Value>{cumulativeSizeInBytes, shape[i]});
auto elementSize = this->getSizeInBytes(loc, elementType, rewriter);
return rewriter.create<LLVM::MulOp>(
loc, getIndexType(), ArrayRef<Value>{cumulativeSizeInBytes, elementSize});
}
+/// Creates and populates the memref descriptor struct given all its fields.
+MemRefDescriptor ConvertToLLVMPattern::createMemRefDescriptor(
+ Location loc, MemRefType memRefType, Value allocatedPtr, Value alignedPtr,
+ uint64_t offset, ArrayRef<int64_t> strides, ArrayRef<Value> sizes,
+ ConversionPatternRewriter &rewriter) const {
+ auto structType = typeConverter.convertType(memRefType);
+ auto memRefDescriptor = MemRefDescriptor::undef(rewriter, loc, structType);
+
+ // Field 1: Allocated pointer, used for malloc/free.
+ memRefDescriptor.setAllocatedPtr(rewriter, loc, allocatedPtr);
+
+ // Field 2: Actual aligned pointer to payload.
+ memRefDescriptor.setAlignedPtr(rewriter, loc, alignedPtr);
+
+ // Field 3: Offset in aligned pointer.
+ memRefDescriptor.setOffset(rewriter, loc,
+ createIndexConstant(rewriter, loc, offset));
+
+ if (memRefType.getRank() == 0)
+ // No size/stride descriptor in memref, return the descriptor value.
+ return memRefDescriptor;
+
+ // Fields 4 and 5: sizes and strides of the strided MemRef.
+ // Store all sizes in the descriptor. Only dynamic sizes are passed in as
+ // operands to AllocOp.
+ Value runningStride = nullptr;
+ // Iterate strides in reverse order, compute runningStride and strideValues.
+ auto nStrides = strides.size();
+ SmallVector<Value, 4> strideValues(nStrides, nullptr);
+ for (unsigned i = 0; i < nStrides; ++i) {
+ int64_t index = nStrides - 1 - i;
+ if (strides[index] == MemRefType::getDynamicStrideOrOffset())
+ // Identity layout map is enforced in the match function, so we compute:
+ // `runningStride *= sizes[index + 1]`
+ runningStride = runningStride ? rewriter.create<LLVM::MulOp>(
+ loc, runningStride, sizes[index + 1])
+ : createIndexConstant(rewriter, loc, 1);
+ else
+ runningStride = createIndexConstant(rewriter, loc, strides[index]);
+ strideValues[index] = runningStride;
+ }
+ // Fill size and stride descriptors in memref.
+ for (auto indexedSize : llvm::enumerate(sizes)) {
+ int64_t index = indexedSize.index();
+ memRefDescriptor.setSize(rewriter, loc, index, indexedSize.value());
+ memRefDescriptor.setStride(rewriter, loc, index, strideValues[index]);
+ }
+ return memRefDescriptor;
+}
+
/// Only retain those attributes that are not constructed by
/// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out argument
/// attributes.
@@ -1710,251 +1767,84 @@ static bool isSupportedMemRefType(MemRefType type) {
}
/// Lowering for AllocOp and AllocaOp.
-template <typename AllocLikeOp>
-struct AllocLikeOpLowering : public ConvertOpToLLVMPattern<AllocLikeOp> {
- using ConvertOpToLLVMPattern<AllocLikeOp>::createIndexConstant;
- using ConvertOpToLLVMPattern<AllocLikeOp>::getIndexType;
- using ConvertOpToLLVMPattern<AllocLikeOp>::typeConverter;
- using ConvertOpToLLVMPattern<AllocLikeOp>::getVoidPtrType;
+struct AllocLikeOpLowering : public ConvertToLLVMPattern {
+ using ConvertToLLVMPattern::createIndexConstant;
+ using ConvertToLLVMPattern::getIndexType;
+ using ConvertToLLVMPattern::getVoidPtrType;
+ using ConvertToLLVMPattern::typeConverter;
+
+ explicit AllocLikeOpLowering(StringRef opName, LLVMTypeConverter &converter)
+ : ConvertToLLVMPattern(opName, &converter.getContext(), converter) {}
+
+protected:
+ // Returns 'input' aligned up to 'alignment'. Computes
+ // bumped = input + alignment - 1
+ // aligned = bumped - bumped % alignment
+ static Value createAligned(ConversionPatternRewriter &rewriter, Location loc,
+ Value input, Value alignment) {
+ Value one = createIndexAttrConstant(rewriter, loc, alignment.getType(), 1);
+ Value bump = rewriter.create<LLVM::SubOp>(loc, alignment, one);
+ Value bumped = rewriter.create<LLVM::AddOp>(loc, input, bump);
+ Value mod = rewriter.create<LLVM::URemOp>(loc, bumped, alignment);
+ return rewriter.create<LLVM::SubOp>(loc, bumped, mod);
+ }
+
+ // Creates a call to an allocation function with params and casts the
+ // resulting void pointer to ptrType.
+ Value createAllocCall(Location loc, StringRef name, Type ptrType,
+ ArrayRef<Value> params, ModuleOp module,
+ ConversionPatternRewriter &rewriter) const {
+ SmallVector<LLVM::LLVMType, 2> paramTypes;
+ auto allocFuncOp = module.lookupSymbol<LLVM::LLVMFuncOp>(name);
+ if (!allocFuncOp) {
+ for (Value param : params)
+ paramTypes.push_back(param.getType().cast<LLVM::LLVMType>());
+ auto allocFuncType =
+ LLVM::LLVMType::getFunctionTy(getVoidPtrType(), paramTypes,
+ /*isVarArg=*/false);
+ OpBuilder::InsertionGuard guard(rewriter);
+ rewriter.setInsertionPointToStart(module.getBody());
+ allocFuncOp = rewriter.create<LLVM::LLVMFuncOp>(rewriter.getUnknownLoc(),
+ name, allocFuncType);
+ }
+ auto allocFuncSymbol = rewriter.getSymbolRefAttr(allocFuncOp);
+ auto allocatedPtr = rewriter
+ .create<LLVM::CallOp>(loc, getVoidPtrType(),
+ allocFuncSymbol, params)
+ .getResult(0);
+ return rewriter.create<LLVM::BitcastOp>(loc, ptrType, allocatedPtr);
+ }
- explicit AllocLikeOpLowering(LLVMTypeConverter &converter)
- : ConvertOpToLLVMPattern<AllocLikeOp>(converter) {}
+ /// Allocates the underlying buffer. Returns the allocated pointer and the
+ /// aligned pointer.
+ virtual std::tuple<Value, Value>
+ allocateBuffer(ConversionPatternRewriter &rewriter, Location loc,
+ Value cumulativeSize, Operation *op) const = 0;
+
+private:
+ static MemRefType getMemRefResultType(Operation *op) {
+ return op->getResult(0).getType().cast<MemRefType>();
+ }
LogicalResult match(Operation *op) const override {
- MemRefType memRefType = cast<AllocLikeOp>(op).getType();
+ MemRefType memRefType = getMemRefResultType(op);
if (isSupportedMemRefType(memRefType))
return success();
int64_t offset;
SmallVector<int64_t, 4> strides;
- auto successStrides = getStridesAndOffset(memRefType, strides, offset);
- if (failed(successStrides))
+ if (failed(getStridesAndOffset(memRefType, strides, offset)))
return failure();
// Dynamic strides are ok if they can be deduced from dynamic sizes (which
- // is guaranteed when succeeded(successStrides)). Dynamic offset however can
- // never be alloc'ed.
+ // is guaranteed when getStridesAndOffset succeeds). Dynamic offset, however,
+ // can never be alloc'ed.
if (offset == MemRefType::getDynamicStrideOrOffset())
return failure();
return success();
}
- // Returns bump = (alignment - (input % alignment))% alignment, which is the
- // increment necessary to align `input` to `alignment` boundary.
- // TODO: this can be made more efficient by just using a single addition
- // and two bit shifts: (ptr + align - 1)/align, align is always power of 2.
- Value createBumpToAlign(Location loc, OpBuilder b, Value input,
- Value alignment) const {
- Value modAlign = b.create<LLVM::URemOp>(loc, input, alignment);
- Value diff = b.create<LLVM::SubOp>(loc, alignment, modAlign);
- Value shift = b.create<LLVM::URemOp>(loc, diff, alignment);
- return shift;
- }
-
- /// Creates and populates the memref descriptor struct given all its fields.
- /// This method also performs any post allocation alignment needed for heap
- /// allocations when `accessAlignment` is non null. This is used with
- /// allocators that do not support alignment.
- MemRefDescriptor createMemRefDescriptor(
- Location loc, ConversionPatternRewriter &rewriter, MemRefType memRefType,
- Value allocatedTypePtr, Value allocatedBytePtr, Value accessAlignment,
- uint64_t offset, ArrayRef<int64_t> strides, ArrayRef<Value> sizes) const {
- auto elementPtrType = this->getElementPtrType(memRefType);
- auto structType = typeConverter.convertType(memRefType);
- auto memRefDescriptor = MemRefDescriptor::undef(rewriter, loc, structType);
-
- // Field 1: Allocated pointer, used for malloc/free.
- memRefDescriptor.setAllocatedPtr(rewriter, loc, allocatedTypePtr);
-
- // Field 2: Actual aligned pointer to payload.
- Value alignedBytePtr = allocatedTypePtr;
- if (accessAlignment) {
- // offset = (align - (ptr % align))% align
- Value intVal = rewriter.create<LLVM::PtrToIntOp>(
- loc, this->getIndexType(), allocatedBytePtr);
- Value offset = createBumpToAlign(loc, rewriter, intVal, accessAlignment);
- Value aligned = rewriter.create<LLVM::GEPOp>(
- loc, allocatedBytePtr.getType(), allocatedBytePtr, offset);
- alignedBytePtr = rewriter.create<LLVM::BitcastOp>(
- loc, elementPtrType, ArrayRef<Value>(aligned));
- }
- memRefDescriptor.setAlignedPtr(rewriter, loc, alignedBytePtr);
-
- // Field 3: Offset in aligned pointer.
- memRefDescriptor.setOffset(rewriter, loc,
- createIndexConstant(rewriter, loc, offset));
-
- if (memRefType.getRank() == 0)
- // No size/stride descriptor in memref, return the descriptor value.
- return memRefDescriptor;
-
- // Fields 4 and 5: sizes and strides of the strided MemRef.
- // Store all sizes in the descriptor. Only dynamic sizes are passed in as
- // operands to AllocOp.
- Value runningStride = nullptr;
- // Iterate strides in reverse order, compute runningStride and strideValues.
- auto nStrides = strides.size();
- SmallVector<Value, 4> strideValues(nStrides, nullptr);
- for (unsigned i = 0; i < nStrides; ++i) {
- int64_t index = nStrides - 1 - i;
- if (strides[index] == MemRefType::getDynamicStrideOrOffset())
- // Identity layout map is enforced in the match function, so we compute:
- // `runningStride *= sizes[index + 1]`
- runningStride = runningStride
- ? rewriter.create<LLVM::MulOp>(loc, runningStride,
- sizes[index + 1])
- : createIndexConstant(rewriter, loc, 1);
- else
- runningStride = createIndexConstant(rewriter, loc, strides[index]);
- strideValues[index] = runningStride;
- }
- // Fill size and stride descriptors in memref.
- for (auto indexedSize : llvm::enumerate(sizes)) {
- int64_t index = indexedSize.index();
- memRefDescriptor.setSize(rewriter, loc, index, indexedSize.value());
- memRefDescriptor.setStride(rewriter, loc, index, strideValues[index]);
- }
- return memRefDescriptor;
- }
-
- /// Returns the memref's element size in bytes.
- // TODO: there are other places where this is used. Expose publicly?
- static unsigned getMemRefEltSizeInBytes(MemRefType memRefType) {
- auto elementType = memRefType.getElementType();
-
- unsigned sizeInBits;
- if (elementType.isIntOrFloat()) {
- sizeInBits = elementType.getIntOrFloatBitWidth();
- } else {
- auto vectorType = elementType.cast<VectorType>();
- sizeInBits =
- vectorType.getElementTypeBitWidth() * vectorType.getNumElements();
- }
- return llvm::divideCeil(sizeInBits, 8);
- }
-
- /// Returns the alignment to be used for the allocation call itself.
- /// aligned_alloc requires the allocation size to be a power of two, and the
- /// allocation size to be a multiple of alignment,
- Optional<int64_t> getAllocationAlignment(AllocOp allocOp) const {
- // No alignment can be used for the 'malloc' call itself.
- if (!typeConverter.getOptions().useAlignedAlloc)
- return None;
-
- if (Optional<uint64_t> alignment = allocOp.alignment())
- return *alignment;
-
- // Whenever we don't have alignment set, we will use an alignment
- // consistent with the element type; since the allocation size has to be a
- // power of two, we will bump to the next power of two if it already isn't.
- auto eltSizeBytes = getMemRefEltSizeInBytes(allocOp.getType());
- return std::max(kMinAlignedAllocAlignment,
- llvm::PowerOf2Ceil(eltSizeBytes));
- }
-
- /// Returns true if the memref size in bytes is known to be a multiple of
- /// factor.
- static bool isMemRefSizeMultipleOf(MemRefType type, uint64_t factor) {
- uint64_t sizeDivisor = getMemRefEltSizeInBytes(type);
- for (unsigned i = 0, e = type.getRank(); i < e; i++) {
- if (type.isDynamic(type.getDimSize(i)))
- continue;
- sizeDivisor = sizeDivisor * type.getDimSize(i);
- }
- return sizeDivisor % factor == 0;
- }
-
- /// Allocates the underlying buffer using the right call. `allocatedBytePtr`
- /// is set to null for stack allocations. `accessAlignment` is set if
- /// alignment is needed post allocation (for eg. in conjunction with malloc).
- Value allocateBuffer(Location loc, Value cumulativeSize, Operation *op,
- MemRefType memRefType, Value one, Value &accessAlignment,
- Value &allocatedBytePtr,
- ConversionPatternRewriter &rewriter) const {
- auto elementPtrType = this->getElementPtrType(memRefType);
-
- // With alloca, one gets a pointer to the element type right away.
- // For stack allocations.
- if (auto allocaOp = dyn_cast<AllocaOp>(op)) {
- allocatedBytePtr = nullptr;
- accessAlignment = nullptr;
- return rewriter.create<LLVM::AllocaOp>(
- loc, elementPtrType, cumulativeSize,
- allocaOp.alignment() ? *allocaOp.alignment() : 0);
- }
-
- // Heap allocations.
- AllocOp allocOp = cast<AllocOp>(op);
-
- Optional<int64_t> allocationAlignment = getAllocationAlignment(allocOp);
- // Whether to use std lib function aligned_alloc that supports alignment.
- bool useAlignedAlloc = allocationAlignment.hasValue();
-
- // Insert the malloc/aligned_alloc declaration if it is not already present.
- const auto *allocFuncName = useAlignedAlloc ? "aligned_alloc" : "malloc";
- auto module = allocOp.getParentOfType<ModuleOp>();
- auto allocFunc = module.lookupSymbol<LLVM::LLVMFuncOp>(allocFuncName);
- if (!allocFunc) {
- OpBuilder::InsertionGuard guard(rewriter);
- rewriter.setInsertionPointToStart(
- op->getParentOfType<ModuleOp>().getBody());
- SmallVector<LLVM::LLVMType, 2> callArgTypes = {getIndexType()};
- // aligned_alloc(size_t alignment, size_t size)
- if (useAlignedAlloc)
- callArgTypes.push_back(getIndexType());
- allocFunc = rewriter.create<LLVM::LLVMFuncOp>(
- rewriter.getUnknownLoc(), allocFuncName,
- LLVM::LLVMType::getFunctionTy(getVoidPtrType(), callArgTypes,
- /*isVarArg=*/false));
- }
-
- // Allocate the underlying buffer and store a pointer to it in the MemRef
- // descriptor.
- SmallVector<Value, 2> callArgs;
- if (useAlignedAlloc) {
- // Use aligned_alloc.
- assert(allocationAlignment && "allocation alignment should be present");
- auto alignedAllocAlignmentValue = rewriter.create<LLVM::ConstantOp>(
- loc, typeConverter.convertType(rewriter.getIntegerType(64)),
- rewriter.getI64IntegerAttr(allocationAlignment.getValue()));
- // aligned_alloc requires size to be a multiple of alignment; we will pad
- // the size to the next multiple if necessary.
- if (!isMemRefSizeMultipleOf(memRefType, allocationAlignment.getValue())) {
- Value bump = createBumpToAlign(loc, rewriter, cumulativeSize,
- alignedAllocAlignmentValue);
- cumulativeSize =
- rewriter.create<LLVM::AddOp>(loc, cumulativeSize, bump);
- }
- callArgs = {alignedAllocAlignmentValue, cumulativeSize};
- } else {
- // Adjust the allocation size to consider alignment.
- if (Optional<uint64_t> alignment = allocOp.alignment()) {
- accessAlignment = createIndexConstant(rewriter, loc, *alignment);
- } else if (!memRefType.getElementType().isSignlessIntOrIndexOrFloat()) {
- // In the case where no alignment is specified, we may want to override
- // `malloc's` behavior. `malloc` typically aligns at the size of the
- // biggest scalar on a target HW. For non-scalars, use the natural
- // alignment of the LLVM type given by the LLVM DataLayout.
- accessAlignment =
- this->getSizeInBytes(loc, memRefType.getElementType(), rewriter);
- }
- if (accessAlignment)
- cumulativeSize =
- rewriter.create<LLVM::AddOp>(loc, cumulativeSize, accessAlignment);
- callArgs.push_back(cumulativeSize);
- }
- auto allocFuncSymbol = rewriter.getSymbolRefAttr(allocFunc);
- allocatedBytePtr = rewriter
- .create<LLVM::CallOp>(loc, getVoidPtrType(),
- allocFuncSymbol, callArgs)
- .getResult(0);
- // For heap allocations, the allocated pointer is a cast of the byte pointer
- // to the type pointer.
- return rewriter.create<LLVM::BitcastOp>(loc, elementPtrType,
- allocatedBytePtr);
- }
-
// An `alloc` is converted into a definition of a memref descriptor value and
// a call to `malloc` to allocate the underlying data buffer. The memref
// descriptor is of the LLVM structure type where:
@@ -1964,15 +1854,16 @@ struct AllocLikeOpLowering : public ConvertOpToLLVMPattern<AllocLikeOp> {
// 3. the remaining elements serve to store all the sizes and strides of the
// memref using LLVM-converted `index` type.
//
- // Alignment is performed by allocating `alignment - 1` more bytes than
+ // Alignment is performed by allocating `alignment` more bytes than
// requested and shifting the aligned pointer relative to the allocated
- // memory. If alignment is unspecified, the two pointers are equal.
+ // memory. Note: `alignment - <minimum malloc alignment>` would actually be
+ // sufficient. If alignment is unspecified, the two pointers are equal.
// An `alloca` is converted into a definition of a memref descriptor value and
// an llvm.alloca to allocate the underlying data buffer.
void rewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- MemRefType memRefType = cast<AllocLikeOp>(op).getType();
+ MemRefType memRefType = getMemRefResultType(op);
auto loc = op->getLoc();
// Get actual sizes of the memref as values: static sizes are constant
@@ -1983,17 +1874,12 @@ struct AllocLikeOpLowering : public ConvertOpToLLVMPattern<AllocLikeOp> {
Value cumulativeSize = this->getCumulativeSizeInBytes(
loc, memRefType.getElementType(), sizes, rewriter);
+
// Allocate the underlying buffer.
- // Value holding the alignment that has to be performed post allocation
- // (in conjunction with allocators that do not support alignment, eg.
- // malloc); nullptr if no such adjustment needs to be performed.
- Value accessAlignment;
- // Byte pointer to the allocated buffer.
- Value allocatedBytePtr;
- Value allocatedTypePtr =
- allocateBuffer(loc, cumulativeSize, op, memRefType,
- createIndexConstant(rewriter, loc, 1), accessAlignment,
- allocatedBytePtr, rewriter);
+ Value allocatedPtr;
+ Value alignedPtr;
+ std::tie(allocatedPtr, alignedPtr) =
+ this->allocateBuffer(rewriter, loc, cumulativeSize, op);
int64_t offset;
SmallVector<int64_t, 4> strides;
@@ -2010,25 +1896,163 @@ struct AllocLikeOpLowering : public ConvertOpToLLVMPattern<AllocLikeOp> {
"unexpected number of strides");
// Create the MemRef descriptor.
- auto memRefDescriptor = createMemRefDescriptor(
- loc, rewriter, memRefType, allocatedTypePtr, allocatedBytePtr,
- accessAlignment, offset, strides, sizes);
+ auto memRefDescriptor =
+ this->createMemRefDescriptor(loc, memRefType, allocatedPtr, alignedPtr,
+ offset, strides, sizes, rewriter);
// Return the final value of the descriptor.
rewriter.replaceOp(op, {memRefDescriptor});
}
+};
-protected:
- /// The minimum alignment to use with aligned_alloc (has to be a power of 2).
- uint64_t kMinAlignedAllocAlignment = 16UL;
+struct AllocOpLowering : public AllocLikeOpLowering {
+ AllocOpLowering(LLVMTypeConverter &converter)
+ : AllocLikeOpLowering(AllocOp::getOperationName(), converter) {}
+
+ std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
+ Location loc, Value cumulativeSize,
+ Operation *op) const override {
+ // Heap allocations.
+ AllocOp allocOp = cast<AllocOp>(op);
+ MemRefType memRefType = allocOp.getType();
+
+ Value alignment;
+ if (auto alignmentAttr = allocOp.alignment()) {
+ alignment = createIndexConstant(rewriter, loc, *alignmentAttr);
+ } else if (!memRefType.getElementType().isSignlessIntOrIndexOrFloat()) {
+ // In the case where no alignment is specified, we may want to override
+ // `malloc's` behavior. `malloc` typically aligns at the size of the
+ // biggest scalar on a target HW. For non-scalars, use the natural
+ // alignment of the LLVM type given by the LLVM DataLayout.
+ alignment = getSizeInBytes(loc, memRefType.getElementType(), rewriter);
+ }
+
+ if (alignment) {
+ // Adjust the allocation size to consider alignment.
+ cumulativeSize =
+ rewriter.create<LLVM::AddOp>(loc, cumulativeSize, alignment);
+ }
+
+ // Allocate the underlying buffer and store a pointer to it in the MemRef
+ // descriptor.
+ Type elementPtrType = this->getElementPtrType(memRefType);
+ Value allocatedPtr =
+ createAllocCall(loc, "malloc", elementPtrType, {cumulativeSize},
+ allocOp.getParentOfType<ModuleOp>(), rewriter);
+
+ Value alignedPtr = allocatedPtr;
+ if (alignment) {
+ auto intPtrType = getIntPtrType(memRefType.getMemorySpace());
+ // Compute the aligned type pointer.
+ Value allocatedInt =
+ rewriter.create<LLVM::PtrToIntOp>(loc, intPtrType, allocatedPtr);
+ Value alignmentInt =
+ createAligned(rewriter, loc, allocatedInt, alignment);
+ alignedPtr =
+ rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, alignmentInt);
+ }
+
+ return std::make_tuple(allocatedPtr, alignedPtr);
+ }
};
-struct AllocOpLowering : public AllocLikeOpLowering<AllocOp> {
- explicit AllocOpLowering(LLVMTypeConverter &converter)
- : AllocLikeOpLowering<AllocOp>(converter) {}
+struct AlignedAllocOpLowering : public AllocLikeOpLowering {
+ AlignedAllocOpLowering(LLVMTypeConverter &converter)
+ : AllocLikeOpLowering(AllocOp::getOperationName(), converter) {}
+
+ /// Returns the memref's element size in bytes.
+ // TODO: there are other places where this is used. Expose publicly?
+ static unsigned getMemRefEltSizeInBytes(MemRefType memRefType) {
+ auto elementType = memRefType.getElementType();
+
+ unsigned sizeInBits;
+ if (elementType.isIntOrFloat()) {
+ sizeInBits = elementType.getIntOrFloatBitWidth();
+ } else {
+ auto vectorType = elementType.cast<VectorType>();
+ sizeInBits =
+ vectorType.getElementTypeBitWidth() * vectorType.getNumElements();
+ }
+ return llvm::divideCeil(sizeInBits, 8);
+ }
+
+ /// Returns true if the memref size in bytes is known to be a multiple of
+ /// factor.
+ static bool isMemRefSizeMultipleOf(MemRefType type, uint64_t factor) {
+ uint64_t sizeDivisor = getMemRefEltSizeInBytes(type);
+ for (unsigned i = 0, e = type.getRank(); i < e; i++) {
+ if (type.isDynamic(type.getDimSize(i)))
+ continue;
+ sizeDivisor = sizeDivisor * type.getDimSize(i);
+ }
+ return sizeDivisor % factor == 0;
+ }
+
+ /// Returns the alignment to be used for the allocation call itself.
+ /// aligned_alloc requires the alignment to be a power of two and the
+ /// allocation size to be a multiple of the alignment.
+ int64_t getAllocationAlignment(AllocOp allocOp) const {
+ if (Optional<uint64_t> alignment = allocOp.alignment())
+ return *alignment;
+
+ // Whenever we don't have alignment set, we will use an alignment
+ // consistent with the element type; since the allocation size has to be a
+ // power of two, we will bump to the next power of two if it already isn't.
+ auto eltSizeBytes = getMemRefEltSizeInBytes(allocOp.getType());
+ return std::max(kMinAlignedAllocAlignment,
+ llvm::PowerOf2Ceil(eltSizeBytes));
+ }
+
+ std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
+ Location loc, Value cumulativeSize,
+ Operation *op) const override {
+ // Heap allocations.
+ AllocOp allocOp = cast<AllocOp>(op);
+ MemRefType memRefType = allocOp.getType();
+ int64_t alignment = getAllocationAlignment(allocOp);
+ Value allocAlignment = createIndexConstant(rewriter, loc, alignment);
+
+ // aligned_alloc requires size to be a multiple of alignment; we will pad
+ // the size to the next multiple if necessary.
+ if (!isMemRefSizeMultipleOf(memRefType, alignment))
+ cumulativeSize =
+ createAligned(rewriter, loc, cumulativeSize, allocAlignment);
+
+ Type elementPtrType = this->getElementPtrType(memRefType);
+ Value allocatedPtr = createAllocCall(
+ loc, "aligned_alloc", elementPtrType, {allocAlignment, cumulativeSize},
+ allocOp.getParentOfType<ModuleOp>(), rewriter);
+
+ return std::make_tuple(allocatedPtr, allocatedPtr);
+ }
+
+ /// The minimum alignment to use with aligned_alloc (has to be a power of 2).
+ static constexpr uint64_t kMinAlignedAllocAlignment = 16UL;
};
-using AllocaOpLowering = AllocLikeOpLowering<AllocaOp>;
+struct AllocaOpLowering : public AllocLikeOpLowering {
+ AllocaOpLowering(LLVMTypeConverter &converter)
+ : AllocLikeOpLowering(AllocaOp::getOperationName(), converter) {}
+
+ /// Allocates the underlying buffer with an llvm.alloca. Stack allocations
+ /// need no separate alignment step, so the same pointer is returned as both
+ /// the allocated and the aligned pointer.
+ std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
+ Location loc, Value cumulativeSize,
+ Operation *op) const override {
+
+ // With alloca, one gets a pointer to the element type right away.
+ // For stack allocations.
+ auto allocaOp = cast<AllocaOp>(op);
+ auto elementPtrType = this->getElementPtrType(allocaOp.getType());
+
+ auto allocatedElementPtr = rewriter.create<LLVM::AllocaOp>(
+ loc, elementPtrType, cumulativeSize,
+ allocaOp.alignment() ? *allocaOp.alignment() : 0);
+
+ return std::make_tuple(allocatedElementPtr, allocatedElementPtr);
+ }
+};
/// Copies the shaped descriptor part to (if `toDynamic` is set) or from
/// (otherwise) the dynamically allocated memory for any operands that were
@@ -3200,12 +3224,13 @@ struct AssumeAlignmentOpLowering
// This relies on LLVM's CSE optimization (potentially after SROA), since
// after CSE all memref.alignedPtr instances get de-duplicated into the same
// pointer SSA value.
- Value zero =
- createIndexAttrConstant(rewriter, op->getLoc(), getIndexType(), 0);
- Value mask = createIndexAttrConstant(rewriter, op->getLoc(), getIndexType(),
+ auto intPtrType =
+ getIntPtrType(memRefDescriptor.getElementPtrType().getAddressSpace());
+ Value zero = createIndexAttrConstant(rewriter, op->getLoc(), intPtrType, 0);
+ Value mask = createIndexAttrConstant(rewriter, op->getLoc(), intPtrType,
alignment - 1);
Value ptrValue =
- rewriter.create<LLVM::PtrToIntOp>(op->getLoc(), getIndexType(), ptr);
+ rewriter.create<LLVM::PtrToIntOp>(op->getLoc(), intPtrType, ptr);
rewriter.create<LLVM::AssumeOp>(
op->getLoc(),
rewriter.create<LLVM::ICmpOp>(
@@ -3477,9 +3502,12 @@ void mlir::populateStdToLLVMMemoryConversionPatterns(
StoreOpLowering,
SubViewOpLowering,
TransposeOpLowering,
- ViewOpLowering,
- AllocOpLowering>(converter);
+ ViewOpLowering>(converter);
// clang-format on
+ if (converter.getOptions().useAlignedAlloc)
+ patterns.insert<AlignedAllocOpLowering>(converter);
+ else
+ patterns.insert<AllocOpLowering>(converter);
}
void mlir::populateStdToLLVMFuncOpConversionPattern(
diff --git a/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir b/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir
index 4e3edd4c7c15..8e7b22574432 100644
--- a/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir
@@ -36,7 +36,6 @@ func @mixed_alloc(%arg0: index, %arg1: index) -> memref<?x42x?xf32> {
// CHECK-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[one]]] : (!llvm.ptr<float>, !llvm.i64) -> !llvm.ptr<float>
// CHECK-NEXT: %[[sizeof:.*]] = llvm.ptrtoint %[[gep]] : !llvm.ptr<float> to !llvm.i64
// CHECK-NEXT: %[[sz_bytes:.*]] = llvm.mul %[[sz]], %[[sizeof]] : !llvm.i64
-// CHECK-NEXT: %[[one_1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK-NEXT: llvm.call @malloc(%[[sz_bytes]]) : (!llvm.i64) -> !llvm.ptr<i8>
// CHECK-NEXT: llvm.bitcast %{{.*}} : !llvm.ptr<i8> to !llvm.ptr<float>
// CHECK-NEXT: llvm.mlir.undef : !llvm.struct<(ptr<float>, ptr<float>, i64, array<3 x i64>, array<3 x i64>)>
@@ -77,7 +76,6 @@ func @dynamic_alloc(%arg0: index, %arg1: index) -> memref<?x?xf32> {
// CHECK-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[one]]] : (!llvm.ptr<float>, !llvm.i64) -> !llvm.ptr<float>
// CHECK-NEXT: %[[sizeof:.*]] = llvm.ptrtoint %[[gep]] : !llvm.ptr<float> to !llvm.i64
// CHECK-NEXT: %[[sz_bytes:.*]] = llvm.mul %[[sz]], %[[sizeof]] : !llvm.i64
-// CHECK-NEXT: %[[one_1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK-NEXT: llvm.call @malloc(%[[sz_bytes]]) : (!llvm.i64) -> !llvm.ptr<i8>
// CHECK-NEXT: llvm.bitcast %{{.*}} : !llvm.ptr<i8> to !llvm.ptr<float>
// CHECK-NEXT: llvm.mlir.undef : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
@@ -107,7 +105,6 @@ func @dynamic_alloca(%arg0: index, %arg1: index) -> memref<?x?xf32> {
// CHECK-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[one]]] : (!llvm.ptr<float>, !llvm.i64) -> !llvm.ptr<float>
// CHECK-NEXT: %[[sizeof:.*]] = llvm.ptrtoint %[[gep]] : !llvm.ptr<float> to !llvm.i64
// CHECK-NEXT: %[[sz_bytes:.*]] = llvm.mul %[[num_elems]], %[[sizeof]] : !llvm.i64
-// CHECK-NEXT: %[[one_1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK-NEXT: %[[allocated:.*]] = llvm.alloca %[[sz_bytes]] x !llvm.float : (!llvm.i64) -> !llvm.ptr<float>
// CHECK-NEXT: llvm.mlir.undef : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK-NEXT: llvm.insertvalue %[[allocated]], %{{.*}}[0] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
@@ -153,8 +150,7 @@ func @stdlib_aligned_alloc(%N : index) -> memref<32x18xf32> {
// ALIGNED-ALLOC-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[one]]] : (!llvm.ptr<float>, !llvm.i64) -> !llvm.ptr<float>
// ALIGNED-ALLOC-NEXT: %[[sizeof:.*]] = llvm.ptrtoint %[[gep]] : !llvm.ptr<float> to !llvm.i64
// ALIGNED-ALLOC-NEXT: %[[bytes:.*]] = llvm.mul %[[num_elems]], %[[sizeof]] : !llvm.i64
-// ALIGNED-ALLOC-NEXT: %[[one_1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
-// ALIGNED-ALLOC-NEXT: %[[alignment:.*]] = llvm.mlir.constant(32 : i64) : !llvm.i64
+// ALIGNED-ALLOC-NEXT: %[[alignment:.*]] = llvm.mlir.constant(32 : index) : !llvm.i64
// ALIGNED-ALLOC-NEXT: %[[allocated:.*]] = llvm.call @aligned_alloc(%[[alignment]], %[[bytes]]) : (!llvm.i64, !llvm.i64) -> !llvm.ptr<i8>
// ALIGNED-ALLOC-NEXT: llvm.bitcast %[[allocated]] : !llvm.ptr<i8> to !llvm.ptr<float>
%0 = alloc() {alignment = 32} : memref<32x18xf32>
@@ -164,26 +160,27 @@ func @stdlib_aligned_alloc(%N : index) -> memref<32x18xf32> {
%1 = alloc() {alignment = 64} : memref<4096xf32>
// Alignment is to element type boundaries (minimum 16 bytes).
- // ALIGNED-ALLOC: %[[c32:.*]] = llvm.mlir.constant(32 : i64) : !llvm.i64
+ // ALIGNED-ALLOC: %[[c32:.*]] = llvm.mlir.constant(32 : index) : !llvm.i64
// ALIGNED-ALLOC-NEXT: llvm.call @aligned_alloc(%[[c32]]
%2 = alloc() : memref<4096xvector<8xf32>>
// The minimum alignment is 16 bytes unless explicitly specified.
- // ALIGNED-ALLOC: %[[c16:.*]] = llvm.mlir.constant(16 : i64) : !llvm.i64
+ // ALIGNED-ALLOC: %[[c16:.*]] = llvm.mlir.constant(16 : index) : !llvm.i64
// ALIGNED-ALLOC-NEXT: llvm.call @aligned_alloc(%[[c16]],
%3 = alloc() : memref<4096xvector<2xf32>>
- // ALIGNED-ALLOC: %[[c8:.*]] = llvm.mlir.constant(8 : i64) : !llvm.i64
+ // ALIGNED-ALLOC: %[[c8:.*]] = llvm.mlir.constant(8 : index) : !llvm.i64
// ALIGNED-ALLOC-NEXT: llvm.call @aligned_alloc(%[[c8]],
%4 = alloc() {alignment = 8} : memref<1024xvector<4xf32>>
// Bump the memref allocation size if its size is not a multiple of alignment.
- // ALIGNED-ALLOC: %[[c32:.*]] = llvm.mlir.constant(32 : i64) : !llvm.i64
- // ALIGNED-ALLOC-NEXT: llvm.urem
+ // ALIGNED-ALLOC: %[[c32:.*]] = llvm.mlir.constant(32 : index) : !llvm.i64
+ // ALIGNED-ALLOC-NEXT: llvm.mlir.constant(1 : index) : !llvm.i64
// ALIGNED-ALLOC-NEXT: llvm.sub
+ // ALIGNED-ALLOC-NEXT: llvm.add
// ALIGNED-ALLOC-NEXT: llvm.urem
- // ALIGNED-ALLOC-NEXT: %[[SIZE_ALIGNED:.*]] = llvm.add
+ // ALIGNED-ALLOC-NEXT: %[[SIZE_ALIGNED:.*]] = llvm.sub
// ALIGNED-ALLOC-NEXT: llvm.call @aligned_alloc(%[[c32]], %[[SIZE_ALIGNED]])
%5 = alloc() {alignment = 32} : memref<100xf32>
// Bump alignment to the next power of two if it isn't.
- // ALIGNED-ALLOC: %[[c128:.*]] = llvm.mlir.constant(128 : i64) : !llvm.i64
+ // ALIGNED-ALLOC: %[[c128:.*]] = llvm.mlir.constant(128 : index) : !llvm.i64
// ALIGNED-ALLOC: llvm.call @aligned_alloc(%[[c128]]
%6 = alloc(%N) : memref<?xvector<18xf32>>
return %0 : memref<32x18xf32>
diff --git a/mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir b/mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir
index b93446f00d2e..d9d93b7823b8 100644
--- a/mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir
@@ -76,7 +76,6 @@ func @zero_d_alloc() -> memref<f32> {
// CHECK-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[one]]] : (!llvm.ptr<float>, !llvm.i64) -> !llvm.ptr<float>
// CHECK-NEXT: %[[sizeof:.*]] = llvm.ptrtoint %[[gep]] : !llvm.ptr<float> to !llvm.i64
// CHECK-NEXT: llvm.mul %{{.*}}, %[[sizeof]] : !llvm.i64
-// CHECK-NEXT: %[[one_1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK-NEXT: llvm.call @malloc(%{{.*}}) : (!llvm.i64) -> !llvm.ptr<i8>
// CHECK-NEXT: %[[ptr:.*]] = llvm.bitcast %{{.*}} : !llvm.ptr<i8> to !llvm.ptr<float>
// CHECK-NEXT: llvm.mlir.undef : !llvm.struct<(ptr<float>, ptr<float>, i64)>
@@ -91,7 +90,6 @@ func @zero_d_alloc() -> memref<f32> {
// BAREPTR-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[one]]] : (!llvm.ptr<float>, !llvm.i64) -> !llvm.ptr<float>
// BAREPTR-NEXT: %[[sizeof:.*]] = llvm.ptrtoint %[[gep]] : !llvm.ptr<float> to !llvm.i64
// BAREPTR-NEXT: llvm.mul %{{.*}}, %[[sizeof]] : !llvm.i64
-// BAREPTR-NEXT: %[[one_1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
// BAREPTR-NEXT: llvm.call @malloc(%{{.*}}) : (!llvm.i64) -> !llvm.ptr<i8>
// BAREPTR-NEXT: %[[ptr:.*]] = llvm.bitcast %{{.*}} : !llvm.ptr<i8> to !llvm.ptr<float>
// BAREPTR-NEXT: llvm.mlir.undef : !llvm.struct<(ptr<float>, ptr<float>, i64)>
@@ -130,19 +128,19 @@ func @aligned_1d_alloc() -> memref<42xf32> {
// CHECK-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[one]]] : (!llvm.ptr<float>, !llvm.i64) -> !llvm.ptr<float>
// CHECK-NEXT: %[[sizeof:.*]] = llvm.ptrtoint %[[gep]] : !llvm.ptr<float> to !llvm.i64
// CHECK-NEXT: llvm.mul %{{.*}}, %[[sizeof]] : !llvm.i64
-// CHECK-NEXT: %[[one_1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK-NEXT: %[[alignment:.*]] = llvm.mlir.constant(8 : index) : !llvm.i64
// CHECK-NEXT: %[[allocsize:.*]] = llvm.add {{.*}}, %[[alignment]] : !llvm.i64
// CHECK-NEXT: %[[allocated:.*]] = llvm.call @malloc(%[[allocsize]]) : (!llvm.i64) -> !llvm.ptr<i8>
// CHECK-NEXT: %[[ptr:.*]] = llvm.bitcast %{{.*}} : !llvm.ptr<i8> to !llvm.ptr<float>
+// CHECK-NEXT: %[[allocatedAsInt:.*]] = llvm.ptrtoint %[[ptr]] : !llvm.ptr<float> to !llvm.i64
+// CHECK-NEXT: %[[one_1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
+// CHECK-NEXT: %[[bump:.*]] = llvm.sub %[[alignment]], %[[one_1]] : !llvm.i64
+// CHECK-NEXT: %[[bumped:.*]] = llvm.add %[[allocatedAsInt]], %[[bump]] : !llvm.i64
+// CHECK-NEXT: %[[mod:.*]] = llvm.urem %[[bumped]], %[[alignment]] : !llvm.i64
+// CHECK-NEXT: %[[aligned:.*]] = llvm.sub %[[bumped]], %[[mod]] : !llvm.i64
+// CHECK-NEXT: %[[alignedBitCast:.*]] = llvm.inttoptr %[[aligned]] : !llvm.i64 to !llvm.ptr<float>
// CHECK-NEXT: llvm.mlir.undef : !llvm.struct<(ptr<float>, ptr<float>, i64, array<1 x i64>, array<1 x i64>)>
// CHECK-NEXT: llvm.insertvalue %[[ptr]], %{{.*}}[0] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<1 x i64>, array<1 x i64>)>
-// CHECK-NEXT: %[[allocatedAsInt:.*]] = llvm.ptrtoint %[[allocated]] : !llvm.ptr<i8> to !llvm.i64
-// CHECK-NEXT: %[[alignAdj1:.*]] = llvm.urem %[[allocatedAsInt]], %[[alignment]] : !llvm.i64
-// CHECK-NEXT: %[[alignAdj2:.*]] = llvm.sub %[[alignment]], %[[alignAdj1]] : !llvm.i64
-// CHECK-NEXT: %[[alignAdj3:.*]] = llvm.urem %[[alignAdj2]], %[[alignment]] : !llvm.i64
-// CHECK-NEXT: %[[aligned:.*]] = llvm.getelementptr %[[allocated]][%[[alignAdj3]]] : (!llvm.ptr<i8>, !llvm.i64) -> !llvm.ptr<i8>
-// CHECK-NEXT: %[[alignedBitCast:.*]] = llvm.bitcast %[[aligned]] : !llvm.ptr<i8> to !llvm.ptr<float>
// CHECK-NEXT: llvm.insertvalue %[[alignedBitCast]], %{{.*}}[1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<1 x i64>, array<1 x i64>)>
// CHECK-NEXT: %[[c0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
// CHECK-NEXT: llvm.insertvalue %[[c0]], %{{.*}}[2] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<1 x i64>, array<1 x i64>)>
@@ -153,19 +151,19 @@ func @aligned_1d_alloc() -> memref<42xf32> {
// BAREPTR-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[one]]] : (!llvm.ptr<float>, !llvm.i64) -> !llvm.ptr<float>
// BAREPTR-NEXT: %[[sizeof:.*]] = llvm.ptrtoint %[[gep]] : !llvm.ptr<float> to !llvm.i64
// BAREPTR-NEXT: llvm.mul %{{.*}}, %[[sizeof]] : !llvm.i64
-// BAREPTR-NEXT: %[[one_1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
// BAREPTR-NEXT: %[[alignment:.*]] = llvm.mlir.constant(8 : index) : !llvm.i64
// BAREPTR-NEXT: %[[allocsize:.*]] = llvm.add {{.*}}, %[[alignment]] : !llvm.i64
// BAREPTR-NEXT: %[[allocated:.*]] = llvm.call @malloc(%[[allocsize]]) : (!llvm.i64) -> !llvm.ptr<i8>
// BAREPTR-NEXT: %[[ptr:.*]] = llvm.bitcast %{{.*}} : !llvm.ptr<i8> to !llvm.ptr<float>
+// BAREPTR-NEXT: %[[allocatedAsInt:.*]] = llvm.ptrtoint %[[ptr]] : !llvm.ptr<float> to !llvm.i64
+// BAREPTR-NEXT: %[[one_2:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
+// BAREPTR-NEXT: %[[bump:.*]] = llvm.sub %[[alignment]], %[[one_2]] : !llvm.i64
+// BAREPTR-NEXT: %[[bumped:.*]] = llvm.add %[[allocatedAsInt]], %[[bump]] : !llvm.i64
+// BAREPTR-NEXT: %[[mod:.*]] = llvm.urem %[[bumped]], %[[alignment]] : !llvm.i64
+// BAREPTR-NEXT: %[[aligned:.*]] = llvm.sub %[[bumped]], %[[mod]] : !llvm.i64
+// BAREPTR-NEXT: %[[alignedBitCast:.*]] = llvm.inttoptr %[[aligned]] : !llvm.i64 to !llvm.ptr<float>
// BAREPTR-NEXT: llvm.mlir.undef : !llvm.struct<(ptr<float>, ptr<float>, i64, array<1 x i64>, array<1 x i64>)>
// BAREPTR-NEXT: llvm.insertvalue %[[ptr]], %{{.*}}[0] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<1 x i64>, array<1 x i64>)>
-// BAREPTR-NEXT: %[[allocatedAsInt:.*]] = llvm.ptrtoint %[[allocated]] : !llvm.ptr<i8> to !llvm.i64
-// BAREPTR-NEXT: %[[alignAdj1:.*]] = llvm.urem %[[allocatedAsInt]], %[[alignment]] : !llvm.i64
-// BAREPTR-NEXT: %[[alignAdj2:.*]] = llvm.sub %[[alignment]], %[[alignAdj1]] : !llvm.i64
-// BAREPTR-NEXT: %[[alignAdj3:.*]] = llvm.urem %[[alignAdj2]], %[[alignment]] : !llvm.i64
-// BAREPTR-NEXT: %[[aligned:.*]] = llvm.getelementptr %[[allocated]][%[[alignAdj3]]] : (!llvm.ptr<i8>, !llvm.i64) -> !llvm.ptr<i8>
-// BAREPTR-NEXT: %[[alignedBitCast:.*]] = llvm.bitcast %[[aligned]] : !llvm.ptr<i8> to !llvm.ptr<float>
// BAREPTR-NEXT: llvm.insertvalue %[[alignedBitCast]], %{{.*}}[1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<1 x i64>, array<1 x i64>)>
// BAREPTR-NEXT: %[[c0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
// BAREPTR-NEXT: llvm.insertvalue %[[c0]], %{{.*}}[2] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<1 x i64>, array<1 x i64>)>
@@ -186,7 +184,6 @@ func @static_alloc() -> memref<32x18xf32> {
// CHECK-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[one]]] : (!llvm.ptr<float>, !llvm.i64) -> !llvm.ptr<float>
// CHECK-NEXT: %[[sizeof:.*]] = llvm.ptrtoint %[[gep]] : !llvm.ptr<float> to !llvm.i64
// CHECK-NEXT: %[[bytes:.*]] = llvm.mul %[[num_elems]], %[[sizeof]] : !llvm.i64
-// CHECK-NEXT: %[[one_1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK-NEXT: %[[allocated:.*]] = llvm.call @malloc(%[[bytes]]) : (!llvm.i64) -> !llvm.ptr<i8>
// CHECK-NEXT: llvm.bitcast %[[allocated]] : !llvm.ptr<i8> to !llvm.ptr<float>
@@ -198,7 +195,6 @@ func @static_alloc() -> memref<32x18xf32> {
// BAREPTR-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[one]]] : (!llvm.ptr<float>, !llvm.i64) -> !llvm.ptr<float>
// BAREPTR-NEXT: %[[sizeof:.*]] = llvm.ptrtoint %[[gep]] : !llvm.ptr<float> to !llvm.i64
// BAREPTR-NEXT: %[[bytes:.*]] = llvm.mul %[[num_elems]], %[[sizeof]] : !llvm.i64
-// BAREPTR-NEXT: %[[one_1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
// BAREPTR-NEXT: %[[allocated:.*]] = llvm.call @malloc(%[[bytes]]) : (!llvm.i64) -> !llvm.ptr<i8>
// BAREPTR-NEXT: llvm.bitcast %[[allocated]] : !llvm.ptr<i8> to !llvm.ptr<float>
%0 = alloc() : memref<32x18xf32>
@@ -217,7 +213,6 @@ func @static_alloca() -> memref<32x18xf32> {
// CHECK-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[one]]] : (!llvm.ptr<float>, !llvm.i64) -> !llvm.ptr<float>
// CHECK-NEXT: %[[sizeof:.*]] = llvm.ptrtoint %[[gep]] : !llvm.ptr<float> to !llvm.i64
// CHECK-NEXT: %[[bytes:.*]] = llvm.mul %[[num_elems]], %[[sizeof]] : !llvm.i64
-// CHECK-NEXT: %[[one_1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK-NEXT: %[[allocated:.*]] = llvm.alloca %[[bytes]] x !llvm.float : (!llvm.i64) -> !llvm.ptr<float>
%0 = alloca() : memref<32x18xf32>
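The updated CHECK lines above reflect the align-up sequence (sub, add, urem, sub) that the new createAligned helper emits. As a plain-arithmetic sanity check, here is a minimal standalone C++ sketch of the same formula; the helper name alignUp and the sample values are illustrative only and not part of the patch.

#include <cassert>
#include <cstdint>
#include <iostream>

// Same align-up formula as AllocLikeOpLowering::createAligned:
//   bumped  = input + alignment - 1
//   aligned = bumped - bumped % alignment
// Works for any non-zero alignment, not just powers of two.
static std::uint64_t alignUp(std::uint64_t input, std::uint64_t alignment) {
  std::uint64_t bumped = input + alignment - 1;
  return bumped - bumped % alignment;
}

int main() {
  // memref<100xf32> is 400 bytes; padded to the next multiple of 32 for
  // aligned_alloc, as in the convert-dynamic-memref-ops.mlir test above.
  assert(alignUp(400, 32) == 416);
  // Already-aligned sizes are left unchanged.
  assert(alignUp(128, 32) == 128);
  std::cout << alignUp(400, 32) << "\n";
  return 0;
}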