[Mlir-commits] [mlir] f76e40d - [mlir] MemRefToLLVM: Save / restore stack when lowering memref.copy

Alex Zinenko llvmlistbot at llvm.org
Thu Oct 13 01:13:16 PDT 2022


Author: Andi Drebes
Date: 2022-10-13T10:13:04+02:00
New Revision: f76e40d1a4d7d95d8ceb1fad7be15bdfe96725a0

URL: https://github.com/llvm/llvm-project/commit/f76e40d1a4d7d95d8ceb1fad7be15bdfe96725a0
DIFF: https://github.com/llvm/llvm-project/commit/f76e40d1a4d7d95d8ceb1fad7be15bdfe96725a0.diff

LOG: [mlir] MemRefToLLVM: Save / restore stack when lowering memref.copy

The MemRef to LLVM conversion pass emits `llvm.alloca` operations to promote MemRef descriptors to the stack when lowering `memref.copy` operations for operands that do not have a contiguous layout in memory. The original stack position is never restored after these allocations. When the copy operation is embedded in a loop with a high trip count, the stack therefore grows with every iteration until it becomes too large, ultimately resulting in a segmentation fault.
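To make the failure mode concrete, here is a toy C++ model (not MLIR or LLVM code). The hypothetical `SimStack` class is a fixed-capacity bump allocator standing in for the hardware stack, and each loop iteration "promotes" a source and a target descriptor by bumping it, mirroring the allocas emitted per `memref.copy`. The 16-byte descriptor size is an assumption for a 64-bit `{i64 rank, ptr descriptor}` struct; the byte counts are illustrative, not the exact sizes the lowering allocates.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical model of a call stack as a fixed-capacity bump allocator.
// It only tracks the current depth in bytes; no memory is handed out.
class SimStack {
 public:
  explicit SimStack(std::size_t capacity) : capacity_(capacity) {}

  // Models llvm.alloca: bumps the stack pointer by `bytes`. Returns false
  // instead of bumping when the capacity would be exceeded (the real
  // program would crash with a segmentation fault at this point).
  bool tryAllocate(std::size_t bytes) {
    if (depth_ + bytes > capacity_) return false;
    depth_ += bytes;
    return true;
  }

  std::size_t depth() const { return depth_; }

 private:
  std::size_t capacity_;
  std::size_t depth_ = 0;
};

// Assumed size of one promoted unranked descriptor on a 64-bit target.
constexpr std::size_t kDescriptorBytes = 16;

// Mimics the loop body produced by the unpatched lowering: each copy
// promotes a source and a target descriptor and never releases them.
// Returns the number of iterations that completed before the stack ran out.
std::size_t runLoopWithoutRestore(SimStack &stack, std::size_t iterations) {
  for (std::size_t i = 0; i < iterations; ++i) {
    if (!stack.tryAllocate(kDescriptorBytes) ||  // source descriptor
        !stack.tryAllocate(kDescriptorBytes))    // target descriptor
      return i;
    // ... the call to memrefCopy would go here ...
  }
  return iterations;
}
```

With a 4 KiB capacity, the loop stops after 128 iterations (128 × 32 bytes = 4096 bytes), far short of a 100000-iteration trip count. A real stack is larger, but it is still finite, while the loop's demand grows linearly with the trip count.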

Below is a minimal example illustrating the issue:

```
#map = affine_map<(d0, d1) -> (d0 * 64 + d1 + 1056)>

module {
  func.func @main() {
    %arg0 = memref.alloc() : memref<32x64xi64>
    %arg1 = memref.alloc() : memref<16x32xi64>
    %lb = arith.constant 0 : index
    %ub = arith.constant 100000 : index
    %step = arith.constant 1 : index
    %slice = memref.subview %arg0[16,32][16,32][1,1] :
       memref<32x64xi64> to memref<16x32xi64, #map>

    scf.for %i = %lb to %ub step %step {
       memref.copy %slice, %arg1 :
         memref<16x32xi64, #map> to memref<16x32xi64>
    }

    return
  }
}
```

When the code above is run, e.g., with mlir-cpu-runner, execution crashes with a segmentation fault:

```
$ mlir-opt \
    --convert-scf-to-cf \
    --convert-memref-to-llvm \
    --convert-func-to-llvm \
    --convert-cf-to-llvm \
    --reconcile-unrealized-casts <file> | \
  mlir-cpu-runner \
    -e main -entry-point-result=void \
    --shared-libs=$PWD/build/lib/libmlir_c_runner_utils.so
[...]
Segmentation fault
```

With this patch, the lowering of `memref.copy` in the MemRefToLLVM pass emits a matching pair of `llvm.intr.stacksave` and `llvm.intr.stackrestore` operations around the promotion of the memory descriptors and the subsequent call to `memrefCopy`, so that the stack is restored to its original position after the call.
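The effect of bracketing the promotion and the call with `llvm.intr.stacksave` / `llvm.intr.stackrestore` can be seen in a toy C++ model (not MLIR or LLVM code). In the hypothetical `SimStack` class below, `save()` returns the current depth as an opaque token and `restore()` releases everything allocated since that token was taken, which is the contract of the two intrinsics. The 16-byte descriptor size is again an assumption chosen for illustration.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical model of a call stack as a fixed-capacity bump allocator.
// It tracks only the current and peak depth in bytes.
class SimStack {
 public:
  explicit SimStack(std::size_t capacity) : capacity_(capacity) {}

  // Models llvm.alloca: bumps the stack pointer by `bytes`, or returns
  // false if that would exceed the capacity.
  bool tryAllocate(std::size_t bytes) {
    if (depth_ + bytes > capacity_) return false;
    depth_ += bytes;
    if (depth_ > peak_) peak_ = depth_;
    return true;
  }

  // Models llvm.intr.stacksave: the current depth serves as the token.
  std::size_t save() const { return depth_; }

  // Models llvm.intr.stackrestore: releases everything allocated
  // since the matching save().
  void restore(std::size_t token) {
    assert(token <= depth_ && "token does not come from an enclosing save()");
    depth_ = token;
  }

  std::size_t depth() const { return depth_; }
  std::size_t peak() const { return peak_; }

 private:
  std::size_t capacity_;
  std::size_t depth_ = 0;
  std::size_t peak_ = 0;
};

// Assumed size of one promoted unranked descriptor on a 64-bit target.
constexpr std::size_t kDescriptorBytes = 16;

// Mimics the loop body produced by the patched lowering.
// Returns the number of iterations that completed.
std::size_t runLoopWithRestore(SimStack &stack, std::size_t iterations) {
  for (std::size_t i = 0; i < iterations; ++i) {
    std::size_t token = stack.save();  // llvm.intr.stacksave
    bool ok = stack.tryAllocate(kDescriptorBytes) &&  // source descriptor
              stack.tryAllocate(kDescriptorBytes);    // target descriptor
    if (!ok) {
      stack.restore(token);
      return i;
    }
    // ... the call to memrefCopy would go here; the descriptors must
    // stay on the stack until it returns ...
    stack.restore(token);  // llvm.intr.stackrestore
  }
  return iterations;
}
```

Because the token is taken before the promotion and restored only after the call, the peak depth stays at the two descriptors of a single iteration (32 bytes in this model) regardless of the trip count. This mirrors the placement in the patch, where the `StackRestoreOp` is created after the `LLVM::CallOp` to `memrefCopy`: restoring any earlier would release the descriptors while the callee may still read them.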

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D135756

Modified: 
    mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
    mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir


################################################################################
diff  --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 0f1f644347e4b..141b250ada181 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -961,6 +961,10 @@ struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
                                             ValueRange{rank, voidPtr});
     };
 
+    // Save stack position before promoting descriptors
+    auto stackSaveOp =
+        rewriter.create<LLVM::StackSaveOp>(loc, getVoidPtrType());
+
     Value unrankedSource = srcType.hasRank()
                                ? makeUnranked(adaptor.getSource(), srcType)
                                : adaptor.getSource();
@@ -990,6 +994,10 @@ struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
         op->getParentOfType<ModuleOp>(), getIndexType(), sourcePtr.getType());
     rewriter.create<LLVM::CallOp>(loc, copyFn,
                                   ValueRange{elemSize, sourcePtr, targetPtr});
+
+    // Restore stack used for descriptors
+    rewriter.create<LLVM::StackRestoreOp>(loc, stackSaveOp);
+
     rewriter.eraseOp(op);
 
     return success();

diff  --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
index c66cd5824ca16..1e4b89f92da5a 100644
--- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
@@ -1138,6 +1138,7 @@ func.func @memref_copy_unranked() {
   // CHECK: [[UNDEF:%.*]] = llvm.mlir.undef : !llvm.struct<(i64, ptr<i8>)>
   // CHECK: [[INSERT:%.*]] = llvm.insertvalue [[RANK]], [[UNDEF]][0] : !llvm.struct<(i64, ptr<i8>)>
   // CHECK: [[INSERT2:%.*]] = llvm.insertvalue [[BITCAST]], [[INSERT]][1] : !llvm.struct<(i64, ptr<i8>)>
+  // CHECK: [[STACKSAVE:%.*]] = llvm.intr.stacksave : !llvm.ptr<i8>
   // CHECK: [[RANK2:%.*]] = llvm.mlir.constant(1 : index) : i64
   // CHECK: [[ALLOCA2:%.*]] = llvm.alloca [[RANK2]] x !llvm.struct<(i64, ptr<i8>)> : (i64) -> !llvm.ptr<struct<(i64, ptr<i8>)>>
   // CHECK: llvm.store {{%.*}}, [[ALLOCA2]] : !llvm.ptr<struct<(i64, ptr<i8>)>>
@@ -1145,6 +1146,7 @@ func.func @memref_copy_unranked() {
   // CHECK: llvm.store [[INSERT2]], [[ALLOCA3]] : !llvm.ptr<struct<(i64, ptr<i8>)>>
   // CHECK: [[SIZE:%.*]] = llvm.mlir.constant(1 : index) : i64
   // CHECK: llvm.call @memrefCopy([[SIZE]], [[ALLOCA2]], [[ALLOCA3]]) : (i64, !llvm.ptr<struct<(i64, ptr<i8>)>>, !llvm.ptr<struct<(i64, ptr<i8>)>>) -> ()
+  // CHECK: llvm.intr.stackrestore [[STACKSAVE]]
   return
 }
 


        

