[Mlir-commits] [mlir] 21caba5 - [MLIR] Lower GenericAtomicRMWOp to llvm.cmpxchg.
Alexander Belyaev
llvmlistbot at llvm.org
Thu Apr 23 00:30:16 PDT 2020
Author: Alexander Belyaev
Date: 2020-04-23T09:29:34+02:00
New Revision: 21caba599e6ce806abc492b7ed1653a1aed8b63c
URL: https://github.com/llvm/llvm-project/commit/21caba599e6ce806abc492b7ed1653a1aed8b63c
DIFF: https://github.com/llvm/llvm-project/commit/21caba599e6ce806abc492b7ed1653a1aed8b63c.diff
LOG: [MLIR] Lower GenericAtomicRMWOp to llvm.cmpxchg.
Summary:
The lowering is essentially a copy of the existing AtomicRMWOp -> llvm.cmpxchg
lowering pattern, adapted to clone the op's body region into the retry loop.
Differential Revision: https://reviews.llvm.org/D78647
Added:
Modified:
mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index e5a1af544e09..6c5f7c270f7a 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -539,6 +539,9 @@ def GenericAtomicRMWOp : Std_Op<"generic_atomic_rmw", [
Value getCurrentValue() {
return body().front().getArgument(0);
}
+ /// Returns the type of the memref operand the atomic op updates.
+ MemRefType getMemRefType() {
+ return memref().getType().cast<MemRefType>();
+ }
}];
}
diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index 3c5a92584425..74488cc9ef09 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -17,6 +17,7 @@
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
+#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
@@ -2746,6 +2747,104 @@ struct AtomicCmpXchgOpLowering : public LoadStoreOpLowering<AtomicRMWOp> {
}
};
+/// Wrap an llvm.cmpxchg operation in a while loop so that the operation can be
+/// retried until it succeeds in atomically storing a new value into memory.
+///
+///   +----------------------------------------+
+///   |   <code before the GenericAtomicRMWOp> |
+///   |   <compute initial %loaded>            |
+///   |   br loop(%loaded)                     |
+///   +----------------------------------------+
+///          |
+///  -----|  |
+///  |    v  v
+///  |  +--------------------------------+
+///  |  |   loop(%loaded):               |
+///  |  |   <body contents>              |
+///  |  |   %pair = cmpxchg              |
+///  |  |   %new = %pair[0]              |
+///  |  |   %ok = %pair[1]               |
+///  |  |   cond_br %ok, end, loop(%new) |
+///  |  +--------------------------------+
+///  |         |        |
+///  |---------         |
+///                     v
+///   +----------------------------------------+
+///   |   end:                                 |
+///   |   <code after the GenericAtomicRMWOp>  |
+///   +----------------------------------------+
+///
+struct GenericAtomicRMWOpLowering
+ : public LoadStoreOpLowering<GenericAtomicRMWOp> {
+ using Base::Base;
+
+ LogicalResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto atomicOp = cast<GenericAtomicRMWOp>(op);
+
+ auto loc = op->getLoc();
+ OperandAdaptor<GenericAtomicRMWOp> adaptor(operands);
+ // The converted result type determines both the width of the cmpxchg
+ // and the type of the loop block argument carrying the loaded value.
+ LLVM::LLVMType valueType =
+ typeConverter.convertType(atomicOp.getResult().getType())
+ .cast<LLVM::LLVMType>();
+
+ // Split the block into initial, loop, and ending parts.
+ auto *initBlock = rewriter.getInsertionBlock();
+ auto initPosition = rewriter.getInsertionPoint();
+ auto *loopBlock = rewriter.splitBlock(initBlock, initPosition);
+ // %loaded: the value observed in memory at the start of each retry.
+ auto loopArgument = loopBlock->addArgument(valueType);
+ auto loopPosition = rewriter.getInsertionPoint();
+ auto *endBlock = rewriter.splitBlock(loopBlock, loopPosition);
+
+ // Compute the loaded value and branch to the loop block.
+ rewriter.setInsertionPointToEnd(initBlock);
+ // NOTE(review): this could reuse atomicOp.getMemRefType(), the helper
+ // added to Ops.td by this same change, instead of casting inline.
+ auto memRefType = atomicOp.memref().getType().cast<MemRefType>();
+ auto dataPtr = getDataPtr(loc, memRefType, adaptor.memref(),
+ adaptor.indices(), rewriter, getModule());
+ Value init = rewriter.create<LLVM::LoadOp>(loc, dataPtr);
+ rewriter.create<LLVM::BrOp>(loc, init, loopBlock);
+
+ // Prepare the body of the loop block.
+ rewriter.setInsertionPointToStart(loopBlock);
+ // i1 type for the cmpxchg success flag.
+ auto boolType = LLVM::LLVMType::getInt1Ty(&getDialect());
+
+ // Clone the GenericAtomicRMWOp region and extract the result.
+ BlockAndValueMapping mapping;
+ // The region's block argument (the "current" value) is replaced by the
+ // loop block argument so each retry recomputes from the fresh value.
+ mapping.map(atomicOp.getCurrentValue(), loopArgument);
+ Block &entryBlock = atomicOp.body().front();
+ for (auto &nestedOp : entryBlock.without_terminator()) {
+ Operation *clone = rewriter.clone(nestedOp, mapping);
+ mapping.map(nestedOp.getResults(), clone->getResults());
+ }
+ // The terminator's (atomic_yield's) operand, remapped through the
+ // cloned ops, is the value the cmpxchg attempts to store.
+ Value result = mapping.lookup(entryBlock.getTerminator()->getOperand(0));
+
+ // Prepare the epilog of the loop block.
+ rewriter.setInsertionPointToEnd(loopBlock);
+ // Append the cmpxchg op to the end of the loop block.
+ // Orderings mirror the AtomicRMWOp -> llvm.cmpxchg lowering above.
+ auto successOrdering = LLVM::AtomicOrdering::acq_rel;
+ auto failureOrdering = LLVM::AtomicOrdering::monotonic;
+ // cmpxchg yields a {value, i1} pair: the value read from memory and a
+ // flag indicating whether the exchange succeeded.
+ auto pairType = LLVM::LLVMType::getStructTy(valueType, boolType);
+ auto cmpxchg = rewriter.create<LLVM::AtomicCmpXchgOp>(
+ loc, pairType, dataPtr, loopArgument, result, successOrdering,
+ failureOrdering);
+ // Extract the %new_loaded and %ok values from the pair.
+ Value newLoaded = rewriter.create<LLVM::ExtractValueOp>(
+ loc, valueType, cmpxchg, rewriter.getI64ArrayAttr({0}));
+ Value ok = rewriter.create<LLVM::ExtractValueOp>(
+ loc, boolType, cmpxchg, rewriter.getI64ArrayAttr({1}));
+
+ // Conditionally branch to the end or back to the loop depending on %ok.
+ rewriter.create<LLVM::CondBrOp>(loc, ok, endBlock, ArrayRef<Value>(),
+ loopBlock, newLoaded);
+
+ // The result of the generic_atomic_rmw op is the value the cmpxchg
+ // loaded on the final (successful) iteration.
+ rewriter.replaceOp(op, {newLoaded});
+
+ return success();
+ }
+};
+
} // namespace
/// Collect a set of patterns to convert from the Standard dialect to LLVM.
@@ -2775,6 +2874,7 @@ void mlir::populateStdToLLVMNonMemoryConversionPatterns(
DivFOpLowering,
ExpOpLowering,
Exp2OpLowering,
+ GenericAtomicRMWOpLowering,
LogOpLowering,
Log10OpLowering,
Log2OpLowering,
diff --git a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
index 9c8d47db05b8..2b6038dfe776 100644
--- a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
@@ -1029,6 +1029,30 @@ func @cmpxchg(%F : memref<10xf32>, %fval : f32, %i : index) -> f32 {
// -----
+// CHECK-LABEL: func @generic_atomic_rmw
+// CHECK32-LABEL: func @generic_atomic_rmw
+func @generic_atomic_rmw(%I : memref<10xf32>, %i : index) -> f32 {
+ %x = generic_atomic_rmw %I[%i] : memref<10xf32> {
+ ^bb0(%old_value : f32):
+ %c1 = constant 1.0 : f32
+ atomic_yield %c1 : f32
+ }
+ // CHECK: [[init:%.*]] = llvm.load %{{.*}} : !llvm<"float*">
+ // CHECK-NEXT: llvm.br ^bb1([[init]] : !llvm.float)
+ // CHECK-NEXT: ^bb1([[loaded:%.*]]: !llvm.float):
+ // CHECK-NEXT: [[c1:%.*]] = llvm.mlir.constant(1.000000e+00 : f32)
+ // CHECK-NEXT: [[pair:%.*]] = llvm.cmpxchg %{{.*}}, [[loaded]], [[c1]]
+ // CHECK-SAME: acq_rel monotonic : !llvm.float
+ // CHECK-NEXT: [[new:%.*]] = llvm.extractvalue [[pair]][0]
+ // CHECK-NEXT: [[ok:%.*]] = llvm.extractvalue [[pair]][1]
+ // CHECK-NEXT: llvm.cond_br [[ok]], ^bb2, ^bb1([[new]] : !llvm.float)
+ // CHECK-NEXT: ^bb2:
+ return %x : f32
+ // CHECK-NEXT: llvm.return [[new]]
+}
+
+// -----
+
// CHECK-LABEL: func @assume_alignment
func @assume_alignment(%0 : memref<4x4xf16>) {
// CHECK: %[[PTR:.*]] = llvm.extractvalue %[[MEMREF:.*]][1] : !llvm<"{ half*, half*, i64, [2 x i64], [2 x i64] }">
More information about the Mlir-commits
mailing list