[Mlir-commits] [mlir] e7833c2 - [mlir] Use splitBlock instead of createBlock in GenericAtomicRMWLowering.

Alexander Belyaev llvmlistbot at llvm.org
Mon Mar 13 10:14:25 PDT 2023


Author: Alexander Belyaev
Date: 2023-03-13T18:14:04+01:00
New Revision: e7833c20d835d0f358acf7708a72bc23b1d87973

URL: https://github.com/llvm/llvm-project/commit/e7833c20d835d0f358acf7708a72bc23b1d87973
DIFF: https://github.com/llvm/llvm-project/commit/e7833c20d835d0f358acf7708a72bc23b1d87973.diff

LOG: [mlir] Use splitBlock instead of createBlock in GenericAtomicRMWLowering.

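When a generic_atomic_rmw op was nested inside memref.alloca_scope, the lowering pattern failed. The pattern now splits the existing block with splitBlock, which moves the trailing operations into the new blocks. Previously it created fresh blocks with createBlock, then cloned those operations into them and erased the originals.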

Differential Revision: https://reviews.llvm.org/D145901

Added: 
    

Modified: 
    mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
    mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 7da7b66ab1f35..2af5a2522566d 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -580,15 +580,11 @@ struct GenericAtomicRMWOpLowering
 
     // Split the block into initial, loop, and ending parts.
     auto *initBlock = rewriter.getInsertionBlock();
-    auto *loopBlock = rewriter.createBlock(
-        initBlock->getParent(), std::next(Region::iterator(initBlock)),
-        valueType, loc);
-    auto *endBlock = rewriter.createBlock(
-        loopBlock->getParent(), std::next(Region::iterator(loopBlock)));
+    auto *loopBlock = rewriter.splitBlock(initBlock, Block::iterator(atomicOp));
+    loopBlock->addArgument(valueType, loc);
 
-    // Operations range to be moved to `endBlock`.
-    auto opsToMoveStart = atomicOp->getIterator();
-    auto opsToMoveEnd = initBlock->back().getIterator();
+    auto *endBlock =
+        rewriter.splitBlock(loopBlock, Block::iterator(atomicOp)++);
 
     // Compute the loaded value and branch to the loop block.
     rewriter.setInsertionPointToEnd(initBlock);
@@ -628,30 +624,12 @@ struct GenericAtomicRMWOpLowering
                                     loopBlock, newLoaded);
 
     rewriter.setInsertionPointToEnd(endBlock);
-    moveOpsRange(atomicOp.getResult(), newLoaded, std::next(opsToMoveStart),
-                 std::next(opsToMoveEnd), rewriter);
 
     // The 'result' of the atomic_rmw op is the newly loaded value.
     rewriter.replaceOp(atomicOp, {newLoaded});
 
     return success();
   }
-
-private:
-  // Clones a segment of ops [start, end) and erases the original.
-  void moveOpsRange(ValueRange oldResult, ValueRange newResult,
-                    Block::iterator start, Block::iterator end,
-                    ConversionPatternRewriter &rewriter) const {
-    IRMapping mapping;
-    mapping.map(oldResult, newResult);
-    SmallVector<Operation *, 2> opsToErase;
-    for (auto it = start; it != end; ++it) {
-      rewriter.clone(*it, mapping);
-      opsToErase.push_back(&*it);
-    }
-    for (auto *it : opsToErase)
-      rewriter.eraseOp(it);
-  }
 };
 
 /// Returns the LLVM type of the global variable given the memref type `type`.

diff  --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
index f6dc44cf4571f..4b4b3836a0075 100644
--- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
@@ -362,16 +362,47 @@ func.func @generic_atomic_rmw(%I : memref<10xi32>, %i : index) {
     ^bb0(%old_value : i32):
       memref.atomic_yield %old_value : i32
   }
-  // CHECK: [[init:%.*]] = llvm.load %{{.*}} : !llvm.ptr -> i32
-  // CHECK-NEXT: llvm.br ^bb1([[init]] : i32)
-  // CHECK-NEXT: ^bb1([[loaded:%.*]]: i32):
-  // CHECK-NEXT: [[pair:%.*]] = llvm.cmpxchg %{{.*}}, [[loaded]], [[loaded]]
-  // CHECK-SAME:                    acq_rel monotonic : !llvm.ptr, i32
-  // CHECK-NEXT: [[new:%.*]] = llvm.extractvalue [[pair]][0]
-  // CHECK-NEXT: [[ok:%.*]] = llvm.extractvalue [[pair]][1]
-  // CHECK-NEXT: llvm.cond_br [[ok]], ^bb2, ^bb1([[new]] : i32)
   llvm.return
 }
+// CHECK:        %[[INIT:.*]] = llvm.load %{{.*}} : !llvm.ptr -> i32
+// CHECK-NEXT:   llvm.br ^bb1(%[[INIT]] : i32)
+// CHECK-NEXT: ^bb1(%[[LOADED:.*]]: i32):
+// CHECK-NEXT:   %[[PAIR:.*]] = llvm.cmpxchg %{{.*}}, %[[LOADED]], %[[LOADED]]
+// CHECK-SAME:                      acq_rel monotonic : !llvm.ptr, i32
+// CHECK-NEXT:   %[[NEW:.*]] = llvm.extractvalue %[[PAIR]][0]
+// CHECK-NEXT:   %[[OK:.*]] = llvm.extractvalue %[[PAIR]][1]
+// CHECK-NEXT:   llvm.cond_br %[[OK]], ^bb2, ^bb1(%[[NEW]] : i32)
+
+// -----
+
+// CHECK-LABEL: func @generic_atomic_rmw_in_alloca_scope
+func.func @generic_atomic_rmw_in_alloca_scope(){
+  %c1 = arith.constant 1 : index
+  %alloc = memref.alloc() : memref<2x3xi32>
+  memref.alloca_scope  {
+    %0 = memref.generic_atomic_rmw %alloc[%c1, %c1] : memref<2x3xi32> {
+    ^bb0(%arg0: i32):
+      memref.atomic_yield %arg0 : i32
+    }
+  }
+  return
+}
+// CHECK:        %[[STACK_SAVE:.*]] = llvm.intr.stacksave : !llvm.ptr
+// CHECK-NEXT:   llvm.br ^bb1
+// CHECK:      ^bb1:
+// CHECK:        %[[INIT:.*]] = llvm.load %[[BUF:.*]] : !llvm.ptr -> i32
+// CHECK-NEXT:   llvm.br ^bb2(%[[INIT]] : i32)
+// CHECK-NEXT: ^bb2(%[[LOADED:.*]]: i32):
+// CHECK-NEXT:   %[[PAIR:.*]] = llvm.cmpxchg %[[BUF]], %[[LOADED]], %[[LOADED]]
+// CHECK-SAME:     acq_rel monotonic : !llvm.ptr, i32
+// CHECK-NEXT:   %[[NEW:.*]] = llvm.extractvalue %[[PAIR]][0]
+// CHECK-NEXT:   %[[OK:.*]] = llvm.extractvalue %[[PAIR]][1]
+// CHECK-NEXT:   llvm.cond_br %[[OK]], ^bb3, ^bb2(%[[NEW]] : i32)
+// CHECK-NEXT: ^bb3:
+// CHECK-NEXT:   llvm.intr.stackrestore %[[STACK_SAVE]] : !llvm.ptr
+// CHECK-NEXT:   llvm.br ^bb4
+// CHECK-NEXT: ^bb4:
+// CHECK-NEXT:   return
 
 // -----
 


        


More information about the Mlir-commits mailing list