[Mlir-commits] [mlir] 11ee125 - [MLIR][LLVM] Realign allocas to avoid dynamic realignment in inliner.

Johannes de Fine Licht llvmlistbot at llvm.org
Wed Apr 19 11:41:52 PDT 2023


Author: Johannes de Fine Licht
Date: 2023-04-19T18:39:54Z
New Revision: 11ee125cc1935d068c1dfe72a2005e0a6fda4100

URL: https://github.com/llvm/llvm-project/commit/11ee125cc1935d068c1dfe72a2005e0a6fda4100
DIFF: https://github.com/llvm/llvm-project/commit/11ee125cc1935d068c1dfe72a2005e0a6fda4100.diff

LOG: [MLIR][LLVM] Realign allocas to avoid dynamic realignment in inliner.

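When the natural stack alignment is not set, or is greater than or equal to
the alignment required by a read-only byval argument whose pointer is defined
by an alloca, avoid the copy by simply realigning the alloca to the required
alignment. The alloca is also realigned when its existing alignment already
exceeds the natural stack alignment, since dynamic stack realignment is
triggered in that case anyway.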

The code to check existing alignment is reorganized a bit to avoid
redundant casts.

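This also fixes a bug where a null DataLayoutOpInterface was passed to the
DataLayout constructor when no parent op implements DataLayoutOpInterface.
The DataLayout is now obtained with DataLayout::closest, which uses the
ModuleOp in this case. In addition, DataLayout::getStackAlignment and
DataLayout::getAllocaMemorySpace no longer dereference a possibly-null scope
to obtain the MLIRContext.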

Reviewed By: gysit

Differential Revision: https://reviews.llvm.org/D148557

Added: 
    

Modified: 
    mlir/lib/Dialect/LLVMIR/IR/LLVMInlining.cpp
    mlir/lib/Interfaces/DataLayoutInterfaces.cpp
    mlir/test/Dialect/LLVMIR/inlining.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMInlining.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMInlining.cpp
index d1d452124cefd..57852ed6e5d9a 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMInlining.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMInlining.cpp
@@ -14,6 +14,7 @@
 #include "LLVMInlining.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/IR/Matchers.h"
+#include "mlir/Interfaces/DataLayoutInterfaces.h"
 #include "mlir/Transforms/InliningUtils.h"
 #include "llvm/Support/Debug.h"
 
@@ -124,14 +125,44 @@ handleInlinedAllocas(Operation *call,
   }
 }
 
+/// If `requestedAlignment` (bytes) exceeds the alignment of `alloca`, realigns
+/// it unless that would newly exceed a known natural stack alignment (bits).
+/// Returns the resulting alignment of `alloca`, whether realigned or not.
+static unsigned tryToEnforceAllocaAlignment(LLVM::AllocaOp alloca,
+                                            unsigned requestedAlignment,
+                                            DataLayout const &dataLayout) {
+  unsigned allocaAlignment = alloca.getAlignment().value_or(1);
+  if (requestedAlignment <= allocaAlignment)
+    // Already sufficiently aligned; no realignment necessary.
+    return allocaAlignment;
+  unsigned naturalStackAlignmentBits = dataLayout.getStackAlignment();
+  // If the natural stack alignment is not specified, the data layout returns
+  // zero. Optimistically allow realignment in this case.
+  if (naturalStackAlignmentBits == 0 ||
+      // If the requested alignment exceeds the natural stack alignment, this
+      // will trigger dynamic stack realignment; prefer a caller-side copy...
+      8 * requestedAlignment <= naturalStackAlignmentBits ||
+      // ...unless the alloca already triggers dynamic stack realignment. Then
+      // we might as well further increase the alignment to avoid a copy.
+      8 * allocaAlignment > naturalStackAlignmentBits) {
+    alloca.setAlignment(requestedAlignment);
+    allocaAlignment = requestedAlignment;
+  }
+  return allocaAlignment;
+}
+
 /// Tries to find and return the alignment of the pointer `value` by looking for
 /// an alignment attribute on the defining allocation op or function argument.
-/// If no such attribute is found, returns 1 (i.e., assume that no alignment is
-/// guaranteed).
-static unsigned getAlignmentOf(Value value) {
+/// If the found alignment is lower than `requestedAlignment`, tries to realign
+/// the pointer, then returns the resulting post-alignment, regardless of
+/// whether it was realigned or not. If no existing alignment attribute is
+/// found, returns 1 (i.e., assume that no alignment is guaranteed).
+static unsigned tryToEnforceAlignment(Value value, unsigned requestedAlignment,
+                                      DataLayout const &dataLayout) {
   if (Operation *definingOp = value.getDefiningOp()) {
     if (auto alloca = dyn_cast<LLVM::AllocaOp>(definingOp))
-      return alloca.getAlignment().value_or(1);
+      return tryToEnforceAllocaAlignment(alloca, requestedAlignment,
+                                         dataLayout);
     if (auto addressOf = dyn_cast<LLVM::AddressOfOp>(definingOp))
       if (auto global = SymbolTable::lookupNearestSymbolFrom<LLVM::GlobalOp>(
               definingOp, addressOf.getGlobalNameAttr()))
@@ -143,8 +174,8 @@ static unsigned getAlignmentOf(Value value) {
   // comes directly from a function argument, so check that this is the case.
   Operation *parentOp = value.getParentBlock()->getParentOp();
   if (auto func = dyn_cast<LLVM::LLVMFuncOp>(parentOp)) {
-    // Use the alignment attribute set for this argument in the parent
-    // function if it has been set.
+    // Use the alignment attribute set for this argument in the parent function
+    // if it has been set.
     auto blockArg = value.cast<BlockArgument>();
     if (Attribute alignAttr = func.getArgAttr(
             blockArg.getArgNumber(), LLVM::LLVMDialect::getAlignAttrName()))
@@ -154,19 +185,19 @@ static unsigned getAlignmentOf(Value value) {
   return 1;
 }
 
-/// Copies the data from a byval pointer argument into newly alloca'ed memory
-/// and returns the value of the alloca.
+/// Introduces a new alloca and copies the memory pointed to by `argument` to
+/// the address of the new alloca, then returns the value of the new alloca.
 static Value handleByValArgumentInit(OpBuilder &builder, Location loc,
                                      Value argument, Type elementType,
                                      unsigned elementTypeSize,
                                      unsigned targetAlignment) {
-  Block *entryBlock = &(*argument.getParentRegion()->begin());
   // Allocate the new value on the stack.
   Value allocaOp;
   {
     // Since this is a static alloca, we can put it directly in the entry block,
     // so they can be absorbed into the prologue/epilogue at code generation.
     OpBuilder::InsertionGuard insertionGuard(builder);
+    Block *entryBlock = &(*argument.getParentRegion()->begin());
     builder.setInsertionPointToStart(entryBlock);
     Value one = builder.create<LLVM::ConstantOp>(loc, builder.getI64Type(),
                                                  builder.getI64IntegerAttr(1));
@@ -183,10 +214,10 @@ static Value handleByValArgumentInit(OpBuilder &builder, Location loc,
 }
 
 /// Handles a function argument marked with the byval attribute by introducing a
-/// memcpy if necessary, either due to the pointee being writeable in the
-/// callee, and/or due to an alignment mismatch. `requestedAlignment` specifies
-/// the alignment set in the "align" argument attribute (or 1 if no align
-/// attribute was set).
+/// memcpy or realigning the defining operation, if required either due to the
+/// pointee being writeable in the callee, and/or due to an alignment mismatch.
+/// `requestedAlignment` specifies the alignment set in the "align" argument
+/// attribute (or 1 if no align attribute was set).
 static Value handleByValArgument(OpBuilder &builder, Operation *callable,
                                  Value argument, Type elementType,
                                  unsigned requestedAlignment) {
@@ -198,11 +229,16 @@ static Value handleByValArgument(OpBuilder &builder, Operation *callable,
                     memoryEffects.getArgMem() != LLVM::ModRefInfo::ModRef &&
                     memoryEffects.getArgMem() != LLVM::ModRefInfo::Mod;
   // Check if there's an alignment mismatch requiring us to copy.
-  DataLayout dataLayout(callable->getParentOfType<DataLayoutOpInterface>());
+  DataLayout dataLayout = DataLayout::closest(callable);
   unsigned minimumAlignment = dataLayout.getTypeABIAlignment(elementType);
-  if (isReadOnly && (requestedAlignment <= minimumAlignment ||
-                     getAlignmentOf(argument) >= requestedAlignment))
-    return argument;
+  if (isReadOnly) {
+    if (requestedAlignment <= minimumAlignment)
+      return argument;
+    unsigned currentAlignment =
+        tryToEnforceAlignment(argument, requestedAlignment, dataLayout);
+    if (currentAlignment >= requestedAlignment)
+      return argument;
+  }
   unsigned targetAlignment = std::max(requestedAlignment, minimumAlignment);
   return handleByValArgumentInit(builder, func.getLoc(), argument, elementType,
                                  dataLayout.getTypeSize(elementType),

diff  --git a/mlir/lib/Interfaces/DataLayoutInterfaces.cpp b/mlir/lib/Interfaces/DataLayoutInterfaces.cpp
index ae8e680f0a068..cc04fa551d5c0 100644
--- a/mlir/lib/Interfaces/DataLayoutInterfaces.cpp
+++ b/mlir/lib/Interfaces/DataLayoutInterfaces.cpp
@@ -483,13 +483,13 @@ unsigned mlir::DataLayout::getTypePreferredAlignment(Type t) const {
 
 mlir::Attribute mlir::DataLayout::getAllocaMemorySpace() const {
   checkValid();
-  MLIRContext *context = scope->getContext();
   if (allocaMemorySpace)
     return *allocaMemorySpace;
   DataLayoutEntryInterface entry;
   if (originalLayout)
     entry = originalLayout.getSpecForIdentifier(
-        originalLayout.getAllocaMemorySpaceIdentifier(context));
+        originalLayout.getAllocaMemorySpaceIdentifier(
+            originalLayout.getContext()));
   if (auto iface = dyn_cast_or_null<DataLayoutOpInterface>(scope))
     allocaMemorySpace = iface.getAllocaMemorySpace(entry);
   else
@@ -499,13 +499,13 @@ mlir::Attribute mlir::DataLayout::getAllocaMemorySpace() const {
 
 unsigned mlir::DataLayout::getStackAlignment() const {
   checkValid();
-  MLIRContext *context = scope->getContext();
   if (stackAlignment)
     return *stackAlignment;
   DataLayoutEntryInterface entry;
   if (originalLayout)
     entry = originalLayout.getSpecForIdentifier(
-        originalLayout.getStackAlignmentIdentifier(context));
+        originalLayout.getStackAlignmentIdentifier(
+            originalLayout.getContext()));
   if (auto iface = dyn_cast_or_null<DataLayoutOpInterface>(scope))
     stackAlignment = iface.getStackAlignment(entry);
   else

diff  --git a/mlir/test/Dialect/LLVMIR/inlining.mlir b/mlir/test/Dialect/LLVMIR/inlining.mlir
index a1f4838dfee98..e6ca34067413c 100644
--- a/mlir/test/Dialect/LLVMIR/inlining.mlir
+++ b/mlir/test/Dialect/LLVMIR/inlining.mlir
@@ -452,16 +452,19 @@ llvm.func @test_byval_input_aligned(%unaligned : !llvm.ptr, %aligned : !llvm.ptr
 
 // -----
 
+llvm.func @func_that_uses_ptr(%ptr : !llvm.ptr)
+// Callee with read-only memory effects and a 16-byte-aligned byval argument.
 llvm.func @aligned_byval_arg(%ptr : !llvm.ptr { llvm.byval = i16, llvm.align = 16 }) attributes {memory = #llvm.memory_effects<other = read, argMem = read, inaccessibleMem = read>} {
+  llvm.call @func_that_uses_ptr(%ptr) : (!llvm.ptr) -> ()
   llvm.return
 }
 
-// CHECK-LABEL: llvm.func @test_byval_unaligned_alloca
-llvm.func @test_byval_unaligned_alloca() {
+// CHECK-LABEL: llvm.func @test_byval_realign_alloca
+llvm.func @test_byval_realign_alloca() {
   %size = llvm.mlir.constant(4 : i64) : i64
-  // CHECK-DAG: %[[SRC:.+]] = llvm.alloca {{.+}}alignment = 1 : i64
-  // CHECK-DAG: %[[DST:.+]] = llvm.alloca {{.+}}alignment = 16 : i64
-  // CHECK: "llvm.intr.memcpy"(%[[DST]], %[[SRC]]
+  // CHECK-NOT: llvm.alloca{{.+}}alignment = 1
+  // CHECK: llvm.alloca {{.+}}alignment = 16 : i64
+  // CHECK-NOT: llvm.intr.memcpy
   %unaligned = llvm.alloca %size x i16 { alignment = 1 } : (i64) -> !llvm.ptr
   llvm.call @aligned_byval_arg(%unaligned) : (!llvm.ptr) -> ()
   llvm.return
@@ -469,19 +472,61 @@ llvm.func @test_byval_unaligned_alloca() {
 
 // -----
 
+module attributes {
+  dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.stack_alignment", 32 : i32>>
+} {
+
+llvm.func @func_that_uses_ptr(%ptr : !llvm.ptr)
+
 llvm.func @aligned_byval_arg(%ptr : !llvm.ptr { llvm.byval = i16, llvm.align = 16 }) attributes {memory = #llvm.memory_effects<other = read, argMem = read, inaccessibleMem = read>} {
+  llvm.call @func_that_uses_ptr(%ptr) : (!llvm.ptr) -> ()
   llvm.return
 }
 
-// CHECK-LABEL: llvm.func @test_byval_aligned_alloca
-llvm.func @test_byval_aligned_alloca() {
-  // CHECK-NOT: memcpy
-  %size = llvm.mlir.constant(1 : i64) : i64
-  %aligned = llvm.alloca %size x i16 { alignment = 16 } : (i64) -> !llvm.ptr
-  llvm.call @aligned_byval_arg(%aligned) : (!llvm.ptr) -> ()
+// CHECK-LABEL: llvm.func @test_exceeds_natural_stack_alignment
+llvm.func @test_exceeds_natural_stack_alignment() {
+  %size = llvm.mlir.constant(4 : i64) : i64
+  // The requested 16-byte alignment exceeds the 32-bit natural stack
+  // alignment, so a copy is preferred over dynamic stack realignment.
+  // CHECK-DAG: %[[SRC:[a-zA-Z0-9_]+]] = llvm.alloca{{.+}}alignment = 2
+  // CHECK-DAG: %[[DST:[a-zA-Z0-9_]+]] = llvm.alloca{{.+}}alignment = 16
+  // CHECK: "llvm.intr.memcpy"(%[[DST]], %[[SRC]]
+  %unaligned = llvm.alloca %size x i16 { alignment = 2 } : (i64) -> !llvm.ptr
+  llvm.call @aligned_byval_arg(%unaligned) : (!llvm.ptr) -> ()
+  llvm.return
+}
+
+}
+
+// -----
+
+module attributes {
+  dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.stack_alignment", 32 : i32>>
+} {
+
+llvm.func @func_that_uses_ptr(%ptr : !llvm.ptr)
+
+llvm.func @aligned_byval_arg(%ptr : !llvm.ptr { llvm.byval = i16, llvm.align = 16 }) attributes {memory = #llvm.memory_effects<other = read, argMem = read, inaccessibleMem = read>} {
+  llvm.call @func_that_uses_ptr(%ptr) : (!llvm.ptr) -> ()
   llvm.return
 }
 
+// CHECK-LABEL: llvm.func @test_alignment_exceeded_anyway
+llvm.func @test_alignment_exceeded_anyway() {
+  %size = llvm.mlir.constant(4 : i64) : i64
+  // The 16-byte target alignment exceeds the 32-bit natural stack alignment,
+  // but the alloca's existing 8-byte (64-bit) alignment exceeds it too, so
+  // the alloca is realigned in place instead of being copied.
+  // CHECK-NOT: llvm.alloca{{.+}}alignment = 8
+  // CHECK: llvm.alloca {{.+}}alignment = 16 : i64
+  // CHECK-NOT: llvm.intr.memcpy
+  %unaligned = llvm.alloca %size x i16 { alignment = 8 } : (i64) -> !llvm.ptr
+  llvm.call @aligned_byval_arg(%unaligned) : (!llvm.ptr) -> ()
+  llvm.return
+}
+
+}
+
 // -----
 
 llvm.mlir.global private @unaligned_global(42 : i64) : i64


        


More information about the Mlir-commits mailing list