[Mlir-commits] [mlir] 01d97a3 - [MLIR] Add support to use aligned_alloc to lower AllocOp from std to llvm

Uday Bondhugula llvmlistbot at llvm.org
Wed Apr 8 02:40:53 PDT 2020


Author: Uday Bondhugula
Date: 2020-04-08T15:10:19+05:30
New Revision: 01d97a35493a8a306bfaa3ceb3e6fa49b05dea89

URL: https://github.com/llvm/llvm-project/commit/01d97a35493a8a306bfaa3ceb3e6fa49b05dea89
DIFF: https://github.com/llvm/llvm-project/commit/01d97a35493a8a306bfaa3ceb3e6fa49b05dea89.diff

LOG: [MLIR] Add support to use aligned_alloc to lower AllocOp from std to llvm

Support to recognize and deal with aligned_alloc was recently added to
LLVM's TLI/MemoryBuiltins and its various optimization passes. This
revision adds support for generation of aligned_alloc's when lowering
AllocOp from std to LLVM. Setting 'use-aligned-alloc=1' will lead to
aligned_alloc being used for all heap allocations. An alignment and size
that works with the constraints of aligned_alloc is chosen.

Using aligned_alloc is preferable to "using malloc and adjusting the
allocated pointer to align for indexing" because the pointer arithmetic
required by the latter makes it harder for LLVM passes to perform
analysis, optimization, attribute deduction, and rewrites.

Differential Revision: https://reviews.llvm.org/D77528

Added: 
    

Modified: 
    mlir/include/mlir/Conversion/Passes.td
    mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h
    mlir/lib/Analysis/Utils.cpp
    mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
    mlir/lib/Dialect/StandardOps/IR/Ops.cpp
    mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 078d978b59af..7a24e0f6e72e 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -226,6 +226,8 @@ def ConvertStandardToLLVM : Pass<"convert-std-to-llvm", "ModuleOp"> {
   }];
   let constructor = "mlir::createLowerToLLVMPass()";
   let options = [
+    Option<"useAlignedAlloc", "use-aligned-alloc", "bool", /*default=*/"false",
+           "Use aligned_alloc in place of malloc for heap allocations">,
     Option<"useBarePtrCallConv", "use-bare-ptr-memref-call-conv", "bool",
            /*default=*/"false",
            "Replace FuncOp's MemRef arguments with bare pointers to the MemRef "

diff  --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h
index 72f852e57187..5479f189b73c 100644
--- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h
+++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h
@@ -20,8 +20,9 @@ class OwningRewritePatternList;
 /// Collect a set of patterns to convert memory-related operations from the
 /// Standard dialect to the LLVM dialect, excluding non-memory-related
 /// operations and FuncOp.
-void populateStdToLLVMMemoryConversionPatters(
-    LLVMTypeConverter &converter, OwningRewritePatternList &patterns);
+void populateStdToLLVMMemoryConversionPatterns(
+    LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
+    bool useAlignedAlloc);
 
 /// Collect a set of patterns to convert from the Standard dialect to the LLVM
 /// dialect, excluding the memory-related operations.
@@ -40,13 +41,15 @@ void populateStdToLLVMDefaultFuncOpConversionPattern(
 /// LLVM.
 void populateStdToLLVMConversionPatterns(LLVMTypeConverter &converter,
                                          OwningRewritePatternList &patterns,
-                                         bool emitCWrappers = false);
+                                         bool emitCWrappers = false,
+                                         bool useAlignedAlloc = false);
 
 /// Collect a set of patterns to convert from the Standard dialect to
 /// LLVM using the bare pointer calling convention for MemRef function
 /// arguments.
 void populateStdToLLVMBarePtrConversionPatterns(
-    LLVMTypeConverter &converter, OwningRewritePatternList &patterns);
+    LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
+    bool useAlignedAlloc);
 
 /// Value to pass as bitwidth for the index type when the converter is expected
 /// to derive the bitwidth from the LLVM data layout.
@@ -56,15 +59,18 @@ struct LowerToLLVMOptions {
   bool useBarePtrCallConv = false;
   bool emitCWrappers = false;
   unsigned indexBitwidth = kDeriveIndexBitwidthFromDataLayout;
+  /// Use aligned_alloc for heap allocations.
+  bool useAlignedAlloc = false;
 };
 
 /// Creates a pass to convert the Standard dialect into the LLVMIR dialect.
-/// stdlib malloc/free is used for allocating memrefs allocated with std.alloc,
-/// while LLVM's alloca is used for those allocated with std.alloca.
-std::unique_ptr<OperationPass<ModuleOp>> createLowerToLLVMPass(
-    const LowerToLLVMOptions &options = {
-        /*useBarePtrCallConv=*/false, /*emitCWrappers=*/false,
-        /*indexBitwidth=*/kDeriveIndexBitwidthFromDataLayout});
+/// stdlib malloc/free is used by default for allocating memrefs allocated with
+/// std.alloc, while LLVM's alloca is used for those allocated with std.alloca.
+std::unique_ptr<OperationPass<ModuleOp>>
+createLowerToLLVMPass(const LowerToLLVMOptions &options = {
+                          /*useBarePtrCallConv=*/false, /*emitCWrappers=*/false,
+                          /*indexBitwidth=*/kDeriveIndexBitwidthFromDataLayout,
+                          /*useAlignedAlloc=*/false});
 
 } // namespace mlir
 

diff  --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp
index 67d0138c8808..3d159f24bdf1 100644
--- a/mlir/lib/Analysis/Utils.cpp
+++ b/mlir/lib/Analysis/Utils.cpp
@@ -344,7 +344,6 @@ LogicalResult MemRefRegion::compute(Operation *op, unsigned loopDepth,
   return success();
 }
 
-//  TODO(mlir-team): improve/complete this when we have target data.
 static unsigned getMemRefEltSizeInBytes(MemRefType memRefType) {
   auto elementType = memRefType.getElementType();
 

diff  --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index 92b02704e015..93847d0eb644 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -1248,8 +1248,10 @@ struct AllocLikeOpLowering : public ConvertOpToLLVMPattern<AllocLikeOp> {
   using ConvertOpToLLVMPattern<AllocLikeOp>::typeConverter;
   using ConvertOpToLLVMPattern<AllocLikeOp>::getVoidPtrType;
 
-  explicit AllocLikeOpLowering(LLVMTypeConverter &converter)
-      : ConvertOpToLLVMPattern<AllocLikeOp>(converter) {}
+  explicit AllocLikeOpLowering(LLVMTypeConverter &converter,
+                               bool useAlignedAlloc = false)
+      : ConvertOpToLLVMPattern<AllocLikeOp>(converter),
+        useAlignedAlloc(useAlignedAlloc) {}
 
   LogicalResult match(Operation *op) const override {
     MemRefType memRefType = cast<AllocLikeOp>(op).getType();
@@ -1271,6 +1273,18 @@ struct AllocLikeOpLowering : public ConvertOpToLLVMPattern<AllocLikeOp> {
     return success();
   }
 
+  // Returns bump = (alignment - (input % alignment))% alignment, which is the
+  // increment necessary to align `input` to `alignment` boundary.
+  // TODO: this can be made more efficient by just using a single addition
+  // and two bit shifts: (ptr + align - 1)/align, align is always power of 2.
+  Value createBumpToAlign(Location loc, OpBuilder b, Value input,
+                          Value alignment) const {
+    Value modAlign = b.create<LLVM::URemOp>(loc, input, alignment);
+    Value diff = b.create<LLVM::SubOp>(loc, alignment, modAlign);
+    Value shift = b.create<LLVM::URemOp>(loc, diff, alignment);
+    return shift;
+  }
+
   /// Creates and populates the memref descriptor struct given all its fields.
   /// This method also performs any post allocation alignment needed for heap
   /// allocations when `accessAlignment` is non null. This is used with
@@ -1292,12 +1306,7 @@ struct AllocLikeOpLowering : public ConvertOpToLLVMPattern<AllocLikeOp> {
       // offset = (align - (ptr % align))% align
       Value intVal = rewriter.create<LLVM::PtrToIntOp>(
           loc, this->getIndexType(), allocatedBytePtr);
-      Value ptrModAlign =
-          rewriter.create<LLVM::URemOp>(loc, intVal, accessAlignment);
-      Value subbed =
-          rewriter.create<LLVM::SubOp>(loc, accessAlignment, ptrModAlign);
-      Value offset =
-          rewriter.create<LLVM::URemOp>(loc, subbed, accessAlignment);
+      Value offset = createBumpToAlign(loc, rewriter, intVal, accessAlignment);
       Value aligned = rewriter.create<LLVM::GEPOp>(
           loc, allocatedBytePtr.getType(), allocatedBytePtr, offset);
       alignedBytePtr = rewriter.create<LLVM::BitcastOp>(
@@ -1388,38 +1397,90 @@ struct AllocLikeOpLowering : public ConvertOpToLLVMPattern<AllocLikeOp> {
         memRefType.getMemorySpace());
   }
 
+  /// Returns the memref's element size in bytes.
+  // TODO: there are other places where this is used. Expose publicly?
+  static unsigned getMemRefEltSizeInBytes(MemRefType memRefType) {
+    auto elementType = memRefType.getElementType();
+
+    unsigned sizeInBits;
+    if (elementType.isIntOrFloat()) {
+      sizeInBits = elementType.getIntOrFloatBitWidth();
+    } else {
+      auto vectorType = elementType.cast<VectorType>();
+      sizeInBits =
+          vectorType.getElementTypeBitWidth() * vectorType.getNumElements();
+    }
+    return llvm::divideCeil(sizeInBits, 8);
+  }
+
+  /// Returns the alignment to be used for the allocation call itself.
+  /// aligned_alloc requires the alignment to be a power of two, and the
+  /// allocation size to be a multiple of the alignment.
+  Optional<int64_t> getAllocationAlignment(AllocOp allocOp) const {
+    // No alignment can be used for the 'malloc' call itself.
+    if (!useAlignedAlloc)
+      return None;
+
+    if (allocOp.alignment())
+      return allocOp.alignment().getValue().getSExtValue();
+
+    // Whenever we don't have alignment set, we will use an alignment
+    // consistent with the element type; since the alignment has to be a
+    // power of two, we will bump it to the next power of two if it isn't one.
+    auto eltSizeBytes = getMemRefEltSizeInBytes(allocOp.getType());
+    return std::max(kMinAlignedAllocAlignment,
+                    llvm::PowerOf2Ceil(eltSizeBytes));
+  }
+
+  /// Returns true if the memref size in bytes is known to be a multiple of
+  /// factor.
+  static bool isMemRefSizeMultipleOf(MemRefType type, uint64_t factor) {
+    uint64_t sizeDivisor = getMemRefEltSizeInBytes(type);
+    for (unsigned i = 0, e = type.getRank(); i < e; i++) {
+      if (type.isDynamic(type.getDimSize(i)))
+        continue;
+      sizeDivisor = sizeDivisor * type.getDimSize(i);
+    }
+    return sizeDivisor % factor == 0;
+  }
+
   /// Allocates the underlying buffer using the right call. `allocatedBytePtr`
   /// is set to null for stack allocations. `accessAlignment` is set if
   /// alignment is neeeded post allocation (for eg. in conjunction with malloc).
-  /// TODO(bondhugula): next revision will support std lib func aligned_alloc.
   Value allocateBuffer(Location loc, Value cumulativeSize, Operation *op,
                        MemRefType memRefType, Value one, Value &accessAlignment,
                        Value &allocatedBytePtr,
                        ConversionPatternRewriter &rewriter) const {
     auto elementPtrType = getElementPtrType(memRefType);
 
-    // Whether to use std lib function aligned_alloc that supports alignment.
-    Optional<APInt> allocationAlignment = cast<AllocLikeOp>(op).alignment();
-
     // With alloca, one gets a pointer to the element type right away.
-    bool onStack = isa<AllocaOp>(op);
-    if (onStack) {
+    // For stack allocations.
+    if (auto allocaOp = dyn_cast<AllocaOp>(op)) {
       allocatedBytePtr = nullptr;
       accessAlignment = nullptr;
       return rewriter.create<LLVM::AllocaOp>(
           loc, elementPtrType, cumulativeSize,
-          allocationAlignment ? allocationAlignment.getValue().getSExtValue()
-                              : 0);
+          allocaOp.alignment() ? allocaOp.alignment().getValue().getSExtValue()
+                               : 0);
     }
 
-    // Use malloc. Insert the malloc declaration if it is not already present.
-    auto allocFuncName = "malloc";
+    // Heap allocations.
     AllocOp allocOp = cast<AllocOp>(op);
+
+    Optional<int64_t> allocationAlignment = getAllocationAlignment(allocOp);
+    // Whether to use std lib function aligned_alloc that supports alignment.
+    bool useAlignedAlloc = allocationAlignment.hasValue();
+
+    // Insert the malloc/aligned_alloc declaration if it is not already present.
+    auto allocFuncName = useAlignedAlloc ? "aligned_alloc" : "malloc";
     auto module = allocOp.getParentOfType<ModuleOp>();
     auto allocFunc = module.lookupSymbol<LLVM::LLVMFuncOp>(allocFuncName);
     if (!allocFunc) {
       OpBuilder moduleBuilder(op->getParentOfType<ModuleOp>().getBodyRegion());
       SmallVector<LLVM::LLVMType, 2> callArgTypes = {getIndexType()};
+      // aligned_alloc(size_t alignment, size_t size)
+      if (useAlignedAlloc)
+        callArgTypes.push_back(getIndexType());
       allocFunc = moduleBuilder.create<LLVM::LLVMFuncOp>(
           rewriter.getUnknownLoc(), allocFuncName,
           LLVM::LLVMType::getFunctionTy(getVoidPtrType(), callArgTypes,
@@ -1429,16 +1490,33 @@ struct AllocLikeOpLowering : public ConvertOpToLLVMPattern<AllocLikeOp> {
     // Allocate the underlying buffer and store a pointer to it in the MemRef
     // descriptor.
     SmallVector<Value, 2> callArgs;
-    // Adjust the allocation size to consider alignment.
-    if (allocOp.alignment()) {
-      accessAlignment = createIndexConstant(
-          rewriter, loc, allocOp.alignment().getValue().getSExtValue());
-      cumulativeSize = rewriter.create<LLVM::SubOp>(
-          loc,
-          rewriter.create<LLVM::AddOp>(loc, cumulativeSize, accessAlignment),
-          one);
+    if (useAlignedAlloc) {
+      // Use aligned_alloc.
+      assert(allocationAlignment && "allocation alignment should be present");
+      auto alignedAllocAlignmentValue = rewriter.create<LLVM::ConstantOp>(
+          loc, typeConverter.convertType(rewriter.getIntegerType(64)),
+          rewriter.getI64IntegerAttr(allocationAlignment.getValue()));
+      // aligned_alloc requires size to be a multiple of alignment; we will pad
+      // the size to the next multiple if necessary.
+      if (!isMemRefSizeMultipleOf(memRefType, allocationAlignment.getValue())) {
+        Value bump = createBumpToAlign(loc, rewriter, cumulativeSize,
+                                       alignedAllocAlignmentValue);
+        cumulativeSize =
+            rewriter.create<LLVM::AddOp>(loc, cumulativeSize, bump);
+      }
+      callArgs = {alignedAllocAlignmentValue, cumulativeSize};
+    } else {
+      // Adjust the allocation size to consider alignment.
+      if (allocOp.alignment()) {
+        accessAlignment = createIndexConstant(
+            rewriter, loc, allocOp.alignment().getValue().getSExtValue());
+        cumulativeSize = rewriter.create<LLVM::SubOp>(
+            loc,
+            rewriter.create<LLVM::AddOp>(loc, cumulativeSize, accessAlignment),
+            one);
+      }
+      callArgs.push_back(cumulativeSize);
     }
-    callArgs.push_back(cumulativeSize);
     auto allocFuncSymbol = rewriter.getSymbolRefAttr(allocFunc);
     allocatedBytePtr = rewriter
                            .create<LLVM::CallOp>(loc, getVoidPtrType(),
@@ -1510,11 +1588,20 @@ struct AllocLikeOpLowering : public ConvertOpToLLVMPattern<AllocLikeOp> {
     // Return the final value of the descriptor.
     rewriter.replaceOp(op, {memRefDescriptor});
   }
+
+protected:
+  /// Use aligned_alloc instead of malloc for all heap allocations.
+  bool useAlignedAlloc;
+  /// The minimum alignment to use with aligned_alloc (has to be a power of 2).
+  uint64_t kMinAlignedAllocAlignment = 16UL;
 };
 
 struct AllocOpLowering : public AllocLikeOpLowering<AllocOp> {
-  using Base::Base;
+  explicit AllocOpLowering(LLVMTypeConverter &converter,
+                           bool useAlignedAlloc = false)
+      : AllocLikeOpLowering<AllocOp>(converter, useAlignedAlloc) {}
 };
+
 struct AllocaOpLowering : public AllocLikeOpLowering<AllocaOp> {
   using Base::Base;
 };
@@ -2738,11 +2825,13 @@ void mlir::populateStdToLLVMNonMemoryConversionPatterns(
   // clang-format on
 }
 
-void mlir::populateStdToLLVMMemoryConversionPatters(
-    LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
+void mlir::populateStdToLLVMMemoryConversionPatterns(
+    LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
+    bool useAlignedAlloc) {
   // clang-format off
   patterns.insert<
       AssumeAlignmentOpLowering,
+      DeallocOpLowering,
       DimOpLowering,
       LoadOpLowering,
       MemRefCastOpLowering,
@@ -2750,8 +2839,8 @@ void mlir::populateStdToLLVMMemoryConversionPatters(
       SubViewOpLowering,
       ViewOpLowering>(converter);
   patterns.insert<
-      AllocOpLowering,
-      DeallocOpLowering>(converter);
+      AllocOpLowering
+      >(converter, useAlignedAlloc);
   // clang-format on
 }
 
@@ -2763,11 +2852,12 @@ void mlir::populateStdToLLVMDefaultFuncOpConversionPattern(
 
 void mlir::populateStdToLLVMConversionPatterns(
     LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
-    bool emitCWrappers) {
+    bool emitCWrappers, bool useAlignedAlloc) {
   populateStdToLLVMDefaultFuncOpConversionPattern(converter, patterns,
                                                   emitCWrappers);
   populateStdToLLVMNonMemoryConversionPatterns(converter, patterns);
-  populateStdToLLVMMemoryConversionPatters(converter, patterns);
+  populateStdToLLVMMemoryConversionPatterns(converter, patterns,
+                                            useAlignedAlloc);
 }
 
 static void populateStdToLLVMBarePtrFuncOpConversionPattern(
@@ -2776,10 +2866,12 @@ static void populateStdToLLVMBarePtrFuncOpConversionPattern(
 }
 
 void mlir::populateStdToLLVMBarePtrConversionPatterns(
-    LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
+    LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
+    bool useAlignedAlloc) {
   populateStdToLLVMBarePtrFuncOpConversionPattern(converter, patterns);
   populateStdToLLVMNonMemoryConversionPatterns(converter, patterns);
-  populateStdToLLVMMemoryConversionPatters(converter, patterns);
+  populateStdToLLVMMemoryConversionPatterns(converter, patterns,
+                                            useAlignedAlloc);
 }
 
 // Create an LLVM IR structure type if there is more than one result.
@@ -2850,10 +2942,11 @@ namespace {
 struct LLVMLoweringPass : public ConvertStandardToLLVMBase<LLVMLoweringPass> {
   LLVMLoweringPass() = default;
   LLVMLoweringPass(bool useBarePtrCallConv, bool emitCWrappers,
-                   unsigned indexBitwidth) {
+                   unsigned indexBitwidth, bool useAlignedAlloc) {
     this->useBarePtrCallConv = useBarePtrCallConv;
     this->emitCWrappers = emitCWrappers;
     this->indexBitwidth = indexBitwidth;
+    this->useAlignedAlloc = useAlignedAlloc;
   }
 
   /// Run the dialect converter on the module.
@@ -2876,10 +2969,11 @@ struct LLVMLoweringPass : public ConvertStandardToLLVMBase<LLVMLoweringPass> {
 
     OwningRewritePatternList patterns;
     if (useBarePtrCallConv)
-      populateStdToLLVMBarePtrConversionPatterns(typeConverter, patterns);
+      populateStdToLLVMBarePtrConversionPatterns(typeConverter, patterns,
+                                                 useAlignedAlloc);
     else
       populateStdToLLVMConversionPatterns(typeConverter, patterns,
-                                          emitCWrappers);
+                                          emitCWrappers, useAlignedAlloc);
 
     LLVMConversionTarget target(getContext());
     if (failed(applyPartialConversion(m, target, patterns, &typeConverter)))
@@ -2898,5 +2992,6 @@ mlir::LLVMConversionTarget::LLVMConversionTarget(MLIRContext &ctx)
 std::unique_ptr<OperationPass<ModuleOp>>
 mlir::createLowerToLLVMPass(const LowerToLLVMOptions &options) {
   return std::make_unique<LLVMLoweringPass>(
-      options.useBarePtrCallConv, options.emitCWrappers, options.indexBitwidth);
+      options.useBarePtrCallConv, options.emitCWrappers, options.indexBitwidth,
+      options.useAlignedAlloc);
 }

diff  --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 2e277d4524d6..7c7a8b0a0805 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -1068,7 +1068,7 @@ static LogicalResult verify(DimOp op) {
 OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
   // Constant fold dim when the size along the index referred to is a constant.
   auto opType = memrefOrTensor().getType();
-  int64_t indexSize = -1;
+  int64_t indexSize = ShapedType::kDynamicSize;
   if (auto tensorType = opType.dyn_cast<RankedTensorType>())
     indexSize = tensorType.getShape()[getIndex()];
   else if (auto memrefType = opType.dyn_cast<MemRefType>())

diff  --git a/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir b/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir
index 7cef73ee0868..5069172db926 100644
--- a/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir
@@ -1,4 +1,5 @@
 // RUN: mlir-opt -convert-std-to-llvm %s | FileCheck %s
+// RUN: mlir-opt -convert-std-to-llvm='use-aligned-alloc=1' %s | FileCheck %s --check-prefix=ALIGNED-ALLOC
 
 // CHECK-LABEL: func @check_strided_memref_arguments(
 // CHECK-COUNT-2: !llvm<"float*">
@@ -138,6 +139,52 @@ func @dynamic_dealloc(%arg0: memref<?x?xf32>) {
   return
 }
 
+// CHECK-LABEL: func @stdlib_aligned_alloc({{.*}}) -> !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> {
+// ALIGNED-ALLOC-LABEL: func @stdlib_aligned_alloc({{.*}}) -> !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> {
+func @stdlib_aligned_alloc(%N : index) -> memref<32x18xf32> {
+// ALIGNED-ALLOC-NEXT:  %[[sz1:.*]] = llvm.mlir.constant(32 : index) : !llvm.i64
+// ALIGNED-ALLOC-NEXT:  %[[sz2:.*]] = llvm.mlir.constant(18 : index) : !llvm.i64
+// ALIGNED-ALLOC-NEXT:  %[[num_elems:.*]] = llvm.mul %0, %1 : !llvm.i64
+// ALIGNED-ALLOC-NEXT:  %[[null:.*]] = llvm.mlir.null : !llvm<"float*">
+// ALIGNED-ALLOC-NEXT:  %[[one:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
+// ALIGNED-ALLOC-NEXT:  %[[gep:.*]] = llvm.getelementptr %[[null]][%[[one]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
+// ALIGNED-ALLOC-NEXT:  %[[sizeof:.*]] = llvm.ptrtoint %[[gep]] : !llvm<"float*"> to !llvm.i64
+// ALIGNED-ALLOC-NEXT:  %[[bytes:.*]] = llvm.mul %[[num_elems]], %[[sizeof]] : !llvm.i64
+// ALIGNED-ALLOC-NEXT:  %[[alignment:.*]] = llvm.mlir.constant(32 : i64) : !llvm.i64
+// ALIGNED-ALLOC-NEXT:  %[[allocated:.*]] = llvm.call @aligned_alloc(%[[alignment]], %[[bytes]]) : (!llvm.i64, !llvm.i64) -> !llvm<"i8*">
+// ALIGNED-ALLOC-NEXT:  llvm.bitcast %[[allocated]] : !llvm<"i8*"> to !llvm<"float*">
+  %0 = alloc() {alignment = 32} : memref<32x18xf32>
+  // Do another alloc just to test that we have a unique declaration for
+  // aligned_alloc.
+  // ALIGNED-ALLOC:  llvm.call @aligned_alloc
+  %1 = alloc() {alignment = 64} : memref<4096xf32>
+
+  // Alignment is to element type boundaries (minimum 16 bytes).
+  // ALIGNED-ALLOC:  %[[c32:.*]] = llvm.mlir.constant(32 : i64) : !llvm.i64
+  // ALIGNED-ALLOC-NEXT:  llvm.call @aligned_alloc(%[[c32]]
+  %2 = alloc() : memref<4096xvector<8xf32>>
+  // The minimum alignment is 16 bytes unless explicitly specified.
+  // ALIGNED-ALLOC:  %[[c16:.*]] = llvm.mlir.constant(16 : i64) : !llvm.i64
+  // ALIGNED-ALLOC-NEXT:  llvm.call @aligned_alloc(%[[c16]],
+  %3 = alloc() : memref<4096xvector<2xf32>>
+  // ALIGNED-ALLOC:  %[[c8:.*]] = llvm.mlir.constant(8 : i64) : !llvm.i64
+  // ALIGNED-ALLOC-NEXT:  llvm.call @aligned_alloc(%[[c8]],
+  %4 = alloc() {alignment = 8} : memref<1024xvector<4xf32>>
+  // Bump the memref allocation size if its size is not a multiple of alignment.
+  // ALIGNED-ALLOC:       %[[c32:.*]] = llvm.mlir.constant(32 : i64) : !llvm.i64
+  // ALIGNED-ALLOC-NEXT:  llvm.urem
+  // ALIGNED-ALLOC-NEXT:  llvm.sub
+  // ALIGNED-ALLOC-NEXT:  llvm.urem
+  // ALIGNED-ALLOC-NEXT:  %[[SIZE_ALIGNED:.*]] = llvm.add
+  // ALIGNED-ALLOC-NEXT:  llvm.call @aligned_alloc(%[[c32]], %[[SIZE_ALIGNED]])
+  %5 = alloc() {alignment = 32} : memref<100xf32>
+  // Bump alignment to the next power of two if it isn't.
+  // ALIGNED-ALLOC:  %[[c128:.*]] = llvm.mlir.constant(128 : i64) : !llvm.i64
+  // ALIGNED-ALLOC:  llvm.call @aligned_alloc(%[[c128]]
+  %6 = alloc(%N) : memref<?xvector<18xf32>>
+  return %0 : memref<32x18xf32>
+}
+
 // CHECK-LABEL: func @mixed_load(
 // CHECK-COUNT-2: !llvm<"float*">,
 // CHECK-COUNT-5: {{%[a-zA-Z0-9]*}}: !llvm.i64


        


More information about the Mlir-commits mailing list