[Mlir-commits] [mlir] 21caba5 - [MLIR] Lower GenericAtomicRMWOp to llvm.cmpxchg.

Alexander Belyaev llvmlistbot at llvm.org
Thu Apr 23 00:30:16 PDT 2020


Author: Alexander Belyaev
Date: 2020-04-23T09:29:34+02:00
New Revision: 21caba599e6ce806abc492b7ed1653a1aed8b63c

URL: https://github.com/llvm/llvm-project/commit/21caba599e6ce806abc492b7ed1653a1aed8b63c
DIFF: https://github.com/llvm/llvm-project/commit/21caba599e6ce806abc492b7ed1653a1aed8b63c.diff

LOG: [MLIR] Lower GenericAtomicRMWOp to llvm.cmpxchg.

Summary:
The lowering is largely a copy of the AtomicRMWOp -> llvm.cmpxchg
pattern.

Differential Revision: https://reviews.llvm.org/D78647

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
    mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
    mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index e5a1af544e09..6c5f7c270f7a 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -539,6 +539,9 @@ def GenericAtomicRMWOp : Std_Op<"generic_atomic_rmw", [
     Value getCurrentValue() {
       return body().front().getArgument(0);
     }
+    MemRefType getMemRefType() {
+      return memref().getType().cast<MemRefType>();
+    }
   }];
 }
 

diff  --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index 3c5a92584425..74488cc9ef09 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/IR/Attributes.h"
+#include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/Module.h"
@@ -2746,6 +2747,104 @@ struct AtomicCmpXchgOpLowering : public LoadStoreOpLowering<AtomicRMWOp> {
   }
 };
 
+/// Lowers GenericAtomicRMWOp by wrapping an llvm.cmpxchg in a loop that is
+/// retried until it succeeds in atomically storing the new value into memory.
+///
+///      +---------------------------------+
+///      |   <code before the AtomicRMWOp> |
+///      |   <compute initial %loaded>     |
+///      |   br loop(%loaded)              |
+///      +---------------------------------+
+///             |
+///  -------|   |
+///  |      v   v
+///  |   +--------------------------------+
+///  |   | loop(%loaded):                 |
+///  |   |   <body contents>              |
+///  |   |   %pair = cmpxchg              |
+///  |   |   %new = %pair[0]              |
+///  |   |   %ok = %pair[1]               |
+///  |   |   cond_br %ok, end, loop(%new) |
+///  |   +--------------------------------+
+///  |          |        |
+///  |-----------        |
+///                      v
+///      +--------------------------------+
+///      | end:                           |
+///      |   <code after the AtomicRMWOp> |
+///      +--------------------------------+
+///
+struct GenericAtomicRMWOpLowering
+    : public LoadStoreOpLowering<GenericAtomicRMWOp> {
+  using Base::Base;
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto atomicOp = cast<GenericAtomicRMWOp>(op);
+
+    auto loc = op->getLoc();
+    OperandAdaptor<GenericAtomicRMWOp> adaptor(operands);
+    LLVM::LLVMType valueType =
+        typeConverter.convertType(atomicOp.getResult().getType())
+            .cast<LLVM::LLVMType>();
+
+    // Split the block into initial, loop, and ending parts. Both splits are
+    // made directly before `op`, so `loopBlock` starts out empty and
+    // `endBlock` begins with `op` followed by the code that came after it.
+    auto *initBlock = op->getBlock();
+    auto *loopBlock = rewriter.splitBlock(initBlock, Block::iterator(op));
+    auto loopArgument = loopBlock->addArgument(valueType);
+    auto *endBlock = rewriter.splitBlock(loopBlock, Block::iterator(op));
+
+    // Compute the loaded value and branch to the loop block.
+    rewriter.setInsertionPointToEnd(initBlock);
+    auto memRefType = atomicOp.getMemRefType();
+    auto dataPtr = getDataPtr(loc, memRefType, adaptor.memref(),
+                              adaptor.indices(), rewriter, getModule());
+    Value init = rewriter.create<LLVM::LoadOp>(loc, dataPtr);
+    rewriter.create<LLVM::BrOp>(loc, init, loopBlock);
+
+    // Prepare the body of the loop block.
+    rewriter.setInsertionPointToStart(loopBlock);
+    auto boolType = LLVM::LLVMType::getInt1Ty(&getDialect());
+
+    // Clone the body, binding its block argument to the loop argument.
+    BlockAndValueMapping mapping;
+    mapping.map(atomicOp.getCurrentValue(), loopArgument);
+    Block &entryBlock = atomicOp.body().front();
+    for (auto &nestedOp : entryBlock.without_terminator()) {
+      Operation *clone = rewriter.clone(nestedOp, mapping);
+      mapping.map(nestedOp.getResults(), clone->getResults());
+    }
+    Value result = mapping.lookup(entryBlock.getTerminator()->getOperand(0));
+
+    // Prepare the epilog of the loop block.
+    rewriter.setInsertionPointToEnd(loopBlock);
+    // Append the cmpxchg op to the end of the loop block.
+    auto successOrdering = LLVM::AtomicOrdering::acq_rel;
+    auto failureOrdering = LLVM::AtomicOrdering::monotonic;
+    auto pairType = LLVM::LLVMType::getStructTy(valueType, boolType);
+    auto cmpxchg = rewriter.create<LLVM::AtomicCmpXchgOp>(
+        loc, pairType, dataPtr, loopArgument, result, successOrdering,
+        failureOrdering);
+    // Element 0 of the pair is the value read from memory; element 1 is %ok.
+    Value newLoaded = rewriter.create<LLVM::ExtractValueOp>(
+        loc, valueType, cmpxchg, rewriter.getI64ArrayAttr({0}));
+    Value ok = rewriter.create<LLVM::ExtractValueOp>(
+        loc, boolType, cmpxchg, rewriter.getI64ArrayAttr({1}));
+
+    // Conditionally branch to the end or back to the loop depending on %ok.
+    rewriter.create<LLVM::CondBrOp>(loc, ok, endBlock, ArrayRef<Value>(),
+                                    loopBlock, newLoaded);
+
+    // Replace the op's result with the value read by the final cmpxchg.
+    rewriter.replaceOp(op, {newLoaded});
+
+    return success();
+  }
+};
+
 } // namespace
 
 /// Collect a set of patterns to convert from the Standard dialect to LLVM.
@@ -2775,6 +2874,7 @@ void mlir::populateStdToLLVMNonMemoryConversionPatterns(
       DivFOpLowering,
       ExpOpLowering,
       Exp2OpLowering,
+      GenericAtomicRMWOpLowering,
       LogOpLowering,
       Log10OpLowering,
       Log2OpLowering,

diff  --git a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
index 9c8d47db05b8..2b6038dfe776 100644
--- a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
@@ -1029,6 +1029,30 @@ func @cmpxchg(%F : memref<10xf32>, %fval : f32, %i : index) -> f32 {
 
 // -----
 
+// CHECK-LABEL: func @generic_atomic_rmw
+// CHECK32-LABEL: func @generic_atomic_rmw
+func @generic_atomic_rmw(%I : memref<10xf32>, %i : index) -> f32 {
+  %x = generic_atomic_rmw %I[%i] : memref<10xf32> {
+    ^bb0(%old_value : f32):  // %old_value is unused: the body always yields 1.0.
+      %c1 = constant 1.0 : f32
+      atomic_yield %c1 : f32
+  }
+  // CHECK: [[init:%.*]] = llvm.load %{{.*}} : !llvm<"float*">
+  // CHECK-NEXT: llvm.br ^bb1([[init]] : !llvm.float)
+  // CHECK-NEXT: ^bb1([[loaded:%.*]]: !llvm.float):
+  // CHECK-NEXT: [[c1:%.*]] = llvm.mlir.constant(1.000000e+00 : f32)
+  // CHECK-NEXT: [[pair:%.*]] = llvm.cmpxchg %{{.*}}, [[loaded]], [[c1]]
+  // CHECK-SAME:                    acq_rel monotonic : !llvm.float
+  // CHECK-NEXT: [[new:%.*]] = llvm.extractvalue [[pair]][0]
+  // CHECK-NEXT: [[ok:%.*]] = llvm.extractvalue [[pair]][1]
+  // CHECK-NEXT: llvm.cond_br [[ok]], ^bb2, ^bb1([[new]] : !llvm.float)
+  // CHECK-NEXT: ^bb2:
+  return %x : f32  // Lowered to the value extracted from the final cmpxchg.
+  // CHECK-NEXT: llvm.return [[new]]
+}
+
+// -----
+
 // CHECK-LABEL: func @assume_alignment
 func @assume_alignment(%0 : memref<4x4xf16>) {
   // CHECK: %[[PTR:.*]] = llvm.extractvalue %[[MEMREF:.*]][1] : !llvm<"{ half*, half*, i64, [2 x i64], [2 x i64] }">


        


More information about the Mlir-commits mailing list