[Mlir-commits] [mlir] 871beba - [MLIR] Add AtomicRMWRegionOp.
Alexander Belyaev
llvmlistbot at llvm.org
Mon Apr 20 08:13:58 PDT 2020
Author: Alexander Belyaev
Date: 2020-04-20T17:13:28+02:00
New Revision: 871beba234a83a2a02da9dedbd59b91a1bfbd7af
URL: https://github.com/llvm/llvm-project/commit/871beba234a83a2a02da9dedbd59b91a1bfbd7af
DIFF: https://github.com/llvm/llvm-project/commit/871beba234a83a2a02da9dedbd59b91a1bfbd7af.diff
LOG: [MLIR] Add AtomicRMWRegionOp.
https://llvm.discourse.group/t/rfc-add-std-atomic-rmw-op/489
Differential Revision: https://reviews.llvm.org/D78352
Added:
Modified:
mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
mlir/lib/Dialect/StandardOps/IR/Ops.cpp
mlir/test/IR/core-ops.mlir
mlir/test/IR/invalid-ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index 00e50fc6d660..45eaa464fa8d 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -484,6 +484,77 @@ def AtomicRMWOp : Std_Op<"atomic_rmw", [
}];
}
+// Atomic read-modify-write on one memref element whose modification is given
+// by a region. The region's single block is implicitly terminated by
+// AtomicYieldOp, and the result type must equal the memref element type.
+def GenericAtomicRMWOp : Std_Op<"generic_atomic_rmw", [
+      SingleBlockImplicitTerminator<"AtomicYieldOp">,
+      TypesMatchWith<"result type matches element type of memref",
+                     "memref", "result",
+                     "$_self.cast<MemRefType>().getElementType()">
+    ]> {
+  let summary = "atomic read-modify-write operation with a region";
+  let description = [{
+    The `generic_atomic_rmw` operation provides a way to perform a
+    read-modify-write sequence that is free from data races. The memref operand
+    represents the buffer that the read and write will be performed against, as
+    accessed by the specified indices. The arity of the indices is the rank of
+    the memref. The result represents the latest value that was stored. The
+    region contains the code for the modification itself. The entry block has
+    a single argument that represents the value stored in `memref[indices]`
+    before the write is performed.
+
+    Example:
+
+    ```mlir
+    %x = generic_atomic_rmw %I[%i] : memref<10xf32> {
+      ^bb0(%current_value : f32):
+        %c1 = constant 1.0 : f32
+        %inc = addf %c1, %current_value : f32
+        atomic_yield %inc : f32
+    }
+    ```
+  }];
+
+  let arguments = (ins
+      MemRefOf<[AnySignlessInteger, AnyFloat]>:$memref,
+      Variadic<Index>:$indices);
+
+  let results = (outs
+      AnyTypeOf<[AnySignlessInteger, AnyFloat]>:$result);
+
+  let regions = (region AnyRegion:$body);
+
+  // Only the custom builder below is generated; it takes the memref and the
+  // index values.
+  let skipDefaultBuilders = 1;
+  let builders = [
+    OpBuilder<"Builder *builder, OperationState &result, "
+              "Value memref, ValueRange ivs">
+  ];
+
+  let extraClassDeclaration = [{
+    // Returns a builder whose insertion point is the end of the first block
+    // of the body region. Asserts that the region is not empty.
+    OpBuilder getBodyBuilder() {
+      assert(!body().empty() && "Unexpected empty 'body' region.");
+      Block &block = body().front();
+      return OpBuilder(&block, block.end());
+    }
+    // The value stored in memref[ivs].
+    Value getCurrentValue() {
+      return body().front().getArgument(0);
+    }
+  }];
+}
+
+// Terminator for the body of GenericAtomicRMWOp (the HasParent trait restricts
+// where it may appear). It takes a single operand of any type, named `result`,
+// and is printed/parsed through the declarative assembly format below.
+def AtomicYieldOp : Std_Op<"atomic_yield", [
+ HasParent<"GenericAtomicRMWOp">,
+ NoSideEffect,
+ Terminator
+ ]> {
+ let summary = "yield operation for GenericAtomicRMWOp";
+ let description = [{
+ "atomic_yield" yields an SSA value from a GenericAtomicRMWOp region.
+ }];
+
+ let arguments = (ins AnyType:$result);
+ let assemblyFormat = "$result attr-dict `:` type($result)";
+}
+
//===----------------------------------------------------------------------===//
// BranchOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 452dade61e32..2a97a9415058 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -480,6 +480,77 @@ static LogicalResult verify(AtomicRMWOp op) {
return success();
}
+//===----------------------------------------------------------------------===//
+// GenericAtomicRMWOp
+//===----------------------------------------------------------------------===//
+
+/// Populates `result` for a GenericAtomicRMWOp that accesses `memref[ivs]`.
+///
+/// The operands are always added. Only when `memref` has a MemRefType are the
+/// result type (the element type) and the body region added; the region then
+/// holds one block with a single argument of the element type.
+void GenericAtomicRMWOp::build(Builder *builder, OperationState &result,
+                               Value memref, ValueRange ivs) {
+  result.addOperands(memref);
+  result.addOperands(ivs);
+
+  auto bufferType = memref.getType().dyn_cast<MemRefType>();
+  if (!bufferType)
+    return;
+
+  Type valueType = bufferType.getElementType();
+  result.addTypes(valueType);
+
+  // The region takes ownership of the freshly allocated entry block.
+  Block *entryBlock = new Block();
+  result.addRegion()->push_back(entryBlock);
+  entryBlock->addArgument(valueType);
+}
+
+/// Checks that the first block of the body has exactly one argument and that
+/// this argument's type equals the type of the op's result.
+static LogicalResult verify(GenericAtomicRMWOp op) {
+  Block &entry = op.body().front();
+  if (entry.getNumArguments() != 1)
+    return op.emitOpError("expected single number of entry block arguments");
+
+  Type argumentType = entry.getArgument(0).getType();
+  Type resultType = op.getResult().getType();
+  if (resultType == argumentType)
+    return success();
+  return op.emitOpError(
+      "expected block argument of the same type result type");
+}
+
+/// Parses the custom assembly form of GenericAtomicRMWOp:
+///
+///   generic_atomic_rmw %memref[%indices...] : memref-type region attr-dict?
+///
+/// The memref operand is resolved against the parsed type, every index against
+/// `index`, and the op's single result type is the element type of the parsed
+/// memref type. Returns failure (with a diagnostic) on malformed input.
+static ParseResult parseGenericAtomicRMWOp(OpAsmParser &parser,
+                                           OperationState &result) {
+  OpAsmParser::OperandType memref;
+  Type memrefType;
+  SmallVector<OpAsmParser::OperandType, 4> ivs;
+
+  Type indexType = parser.getBuilder().getIndexType();
+  if (parser.parseOperand(memref) ||
+      parser.parseOperandList(ivs, OpAsmParser::Delimiter::Square) ||
+      parser.parseColonType(memrefType) ||
+      parser.resolveOperand(memref, memrefType, result.operands) ||
+      parser.resolveOperands(ivs, indexType, result.operands))
+    return failure();
+
+  // The type is arbitrary user input: report a non-memref type as a parse
+  // error instead of tripping the assertion inside cast<MemRefType>().
+  auto parsedMemRefType = memrefType.dyn_cast<MemRefType>();
+  if (!parsedMemRefType)
+    return parser.emitError(parser.getNameLoc())
+           << "expected memref type, but got " << memrefType;
+
+  // The printer emits the optional attribute dictionary after the region, so
+  // it has to be accepted here for the op to round-trip.
+  Region *body = result.addRegion();
+  if (parser.parseRegion(*body, llvm::None, llvm::None) ||
+      parser.parseOptionalAttrDict(result.attributes))
+    return failure();
+  result.types.push_back(parsedMemRefType.getElementType());
+  return success();
+}
+
+/// Prints the custom assembly form: the operation name, `memref[indices]`,
+/// ` : ` and the memref type, then the body region and finally the optional
+/// attribute dictionary.
+static void print(OpAsmPrinter &p, GenericAtomicRMWOp op) {
+  p << op.getOperationName() << ' ';
+  p << op.memref() << "[" << op.indices() << "] : ";
+  p << op.memref().getType();
+  p.printRegion(op.body());
+  p.printOptionalAttrDict(op.getAttrs());
+}
+
+//===----------------------------------------------------------------------===//
+// AtomicYieldOp
+//===----------------------------------------------------------------------===//
+
+/// Checks that the type of the yielded value equals the first result type of
+/// the enclosing operation; emits an op error naming both types otherwise.
+static LogicalResult verify(AtomicYieldOp op) {
+  Type expectedType = op.getParentOp()->getResultTypes().front();
+  Type yieldedType = op.result().getType();
+  if (expectedType == yieldedType)
+    return success();
+  return op.emitOpError() << "types mismatch between yield op: " << yieldedType
+                          << " and its parent: " << expectedType;
+}
+
//===----------------------------------------------------------------------===//
// BranchOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir
index b50ffa65f179..d19f3445655a 100644
--- a/mlir/test/IR/core-ops.mlir
+++ b/mlir/test/IR/core-ops.mlir
@@ -751,9 +751,23 @@ func @tensor_load_store(%0 : memref<4x4xi32>) {
}
+// The CHECK-SAME line captures the function arguments as BUF, VAL and I; the
+// CHECK line inside the body requires atomic_rmw to use exactly those values.
// CHECK-LABEL: func @atomic_rmw
+// CHECK-SAME: ([[BUF:%.*]]: memref<10xf32>, [[VAL:%.*]]: f32, [[I:%.*]]: index)
func @atomic_rmw(%I: memref<10xf32>, %val: f32, %i : index) {
- // CHECK: %{{.*}} = atomic_rmw "addf" %{{.*}}, %{{.*}}[%{{.*}}]
%x = atomic_rmw "addf" %val, %I[%i] : (f32, memref<10xf32>) -> f32
+ // CHECK: atomic_rmw "addf" [[VAL]], [[BUF]]{{\[}}[[I]]]
+ return
+}
+
+// generic_atomic_rmw on a rank-2 memref with two indices. The CHECK-NEXT line
+// expects the op directly after the function header, using the captured
+// buffer (BUF) and index (I, J) arguments.
+// CHECK-LABEL: func @generic_atomic_rmw
+// CHECK-SAME: ([[BUF:%.*]]: memref<1x2xf32>, [[I:%.*]]: index, [[J:%.*]]: index)
+func @generic_atomic_rmw(%I: memref<1x2xf32>, %i : index, %j : index) {
+ %x = generic_atomic_rmw %I[%i, %j] : memref<1x2xf32> {
+ // CHECK-NEXT: generic_atomic_rmw [[BUF]]{{\[}}[[I]], [[J]]] : memref
+ ^bb0(%old_value : f32):
+ %c1 = constant 1.0 : f32
+ %out = addf %c1, %old_value : f32
+ atomic_yield %out : f32
+ }
return
}
diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir
index c555776d9d35..17eaded116e6 100644
--- a/mlir/test/IR/invalid-ops.mlir
+++ b/mlir/test/IR/invalid-ops.mlir
@@ -1143,6 +1143,54 @@ func @atomic_rmw_expects_int(%I: memref<16x10xf32>, %i : index, %val : f32) {
// -----
+// The body's entry block declares two arguments instead of one.
+func @generic_atomic_rmw_wrong_arg_num(%I: memref<10xf32>, %i : index) {
+  // expected-error@+1 {{expected single number of entry block arguments}}
+  %x = generic_atomic_rmw %I[%i] : memref<10xf32> {
+    ^bb0(%arg0 : f32, %arg1 : f32):
+      %c1 = constant 1.0 : f32
+      atomic_yield %c1 : f32
+  }
+  return
+}
+
+// -----
+
+// The entry block argument is i32 while the memref element type is f32.
+func @generic_atomic_rmw_wrong_arg_type(%I: memref<10xf32>, %i : index) {
+  // expected-error@+1 {{expected block argument of the same type result type}}
+  %x = generic_atomic_rmw %I[%i] : memref<10xf32> {
+    ^bb0(%old_value : i32):
+      %c1 = constant 1.0 : f32
+      atomic_yield %c1 : f32
+  }
+  return
+}
+
+// -----
+
+// Generic-form op whose declared result type (i32) differs from the memref
+// element type (f32).
+func @generic_atomic_rmw_result_type_mismatch(%I: memref<10xf32>, %i : index) {
+  // expected-error@+1 {{failed to verify that result type matches element type of memref}}
+  %0 = "std.generic_atomic_rmw"(%I, %i) ( {
+    ^bb0(%old_value: f32):
+      %c1 = constant 1.0 : f32
+      atomic_yield %c1 : f32
+    }) : (memref<10xf32>, index) -> i32
+  return
+}
+
+// -----
+
+// atomic_yield returns an i32 while the enclosing op's result is f32. The
+// diagnostic is expected on the atomic_yield line, four lines below the
+// directive.
+func @atomic_yield_type_mismatch(%I: memref<10xf32>, %i : index) {
+  // expected-error@+4 {{op types mismatch between yield op: 'i32' and its parent: 'f32'}}
+  %x = generic_atomic_rmw %I[%i] : memref<10xf32> {
+    ^bb0(%old_value : f32):
+      %c1 = constant 1 : i32
+      atomic_yield %c1 : i32
+  }
+  return
+}
+
+// -----
+
// alignment is not power of 2.
func @assume_alignment(%0: memref<4x4xf16>) {
// expected-error@+1 {{alignment must be power of 2}}
More information about the Mlir-commits
mailing list