[Mlir-commits] [mlir] 871beba - [MLIR] Add AtomicRMWRegionOp.

Alexander Belyaev llvmlistbot at llvm.org
Mon Apr 20 08:13:58 PDT 2020


Author: Alexander Belyaev
Date: 2020-04-20T17:13:28+02:00
New Revision: 871beba234a83a2a02da9dedbd59b91a1bfbd7af

URL: https://github.com/llvm/llvm-project/commit/871beba234a83a2a02da9dedbd59b91a1bfbd7af
DIFF: https://github.com/llvm/llvm-project/commit/871beba234a83a2a02da9dedbd59b91a1bfbd7af.diff

LOG: [MLIR] Add AtomicRMWRegionOp.

https://llvm.discourse.group/t/rfc-add-std-atomic-rmw-op/489

Differential Revision: https://reviews.llvm.org/D78352

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
    mlir/lib/Dialect/StandardOps/IR/Ops.cpp
    mlir/test/IR/core-ops.mlir
    mlir/test/IR/invalid-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index 00e50fc6d660..45eaa464fa8d 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -484,6 +484,77 @@ def AtomicRMWOp : Std_Op<"atomic_rmw", [
   }];
 }
 
+def GenericAtomicRMWOp : Std_Op<"generic_atomic_rmw", [
+      SingleBlockImplicitTerminator<"AtomicYieldOp">,
+      TypesMatchWith<"result type matches element type of memref",
+                     "memref", "result",
+                     "$_self.cast<MemRefType>().getElementType()">
+    ]> {
+  let summary = "atomic read-modify-write operation with a region";
+  let description = [{
+    The `generic_atomic_rmw` operation provides a way to perform a
+    read-modify-write sequence that is free from data races. The memref operand
+    represents the buffer that the read and write will be performed against, as
+    accessed by the specified indices. The arity of the indices is the rank of
+    the memref. The result represents the latest value that was stored. The
+    region contains the code for the modification itself. The entry block has a
+    single argument that represents the value stored in `memref[indices]`
+    before the write is performed. The region is terminated by an
+    `atomic_yield` whose operand must have the element type of the memref.
+
+    Example:
+
+    ```mlir
+    %x = generic_atomic_rmw %I[%i] : memref<10xf32> {
+      ^bb0(%current_value : f32):
+        %c1 = constant 1.0 : f32
+        %inc = addf %c1, %current_value : f32
+        atomic_yield %inc : f32
+    }
+    ```
+  }];
+
+  let arguments = (ins
+      MemRefOf<[AnySignlessInteger, AnyFloat]>:$memref,
+      Variadic<Index>:$indices);
+
+  let results = (outs
+      AnyTypeOf<[AnySignlessInteger, AnyFloat]>:$result);
+
+  let regions = (region AnyRegion:$body);
+
+  let skipDefaultBuilders = 1;
+  let builders = [
+    OpBuilder<"Builder *builder, OperationState &result, "
+              "Value memref, ValueRange ivs">
+  ];
+
+  let extraClassDeclaration = [{
+    // Returns a builder whose insertion point is the end of the body's entry
+    // block.
+    OpBuilder getBodyBuilder() {
+      assert(!body().empty() && "Unexpected empty 'body' region.");
+      Block &block = body().front();
+      return OpBuilder(&block, block.end());
+    }
+    // The entry block argument, i.e. the value stored in memref[ivs] before
+    // the region performs its modification. Asserts, like getBodyBuilder(),
+    // that the region has a block instead of dereferencing an empty list.
+    Value getCurrentValue() {
+      assert(!body().empty() && "Unexpected empty 'body' region.");
+      return body().front().getArgument(0);
+    }
+  }];
+}
+
+def AtomicYieldOp : Std_Op<"atomic_yield",
+    [HasParent<"GenericAtomicRMWOp">, NoSideEffect, Terminator]> {
+  let summary = "yield operation for GenericAtomicRMWOp";
+  let description = [{
+    `atomic_yield` terminates the body region of a `generic_atomic_rmw`
+    operation and yields the SSA value computed by that region. The yielded
+    value must have the same type as the result of the parent
+    `generic_atomic_rmw`.
+  }];
+
+  let arguments = (ins AnyType:$result);
+  let assemblyFormat = "$result attr-dict `:` type($result)";
+}
+
 //===----------------------------------------------------------------------===//
 // BranchOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 452dade61e32..2a97a9415058 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -480,6 +480,77 @@ static LogicalResult verify(AtomicRMWOp op) {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// GenericAtomicRMWOp
+//===----------------------------------------------------------------------===//
+
+/// Builds a generic_atomic_rmw on `memref` at indices `ivs`. For a memref
+/// operand, the result type is its element type and the body region gets one
+/// empty block with a single argument of that element type.
+void GenericAtomicRMWOp::build(Builder *builder, OperationState &result,
+                               Value memref, ValueRange ivs) {
+  // Operands: the buffer first, then the access indices.
+  result.addOperands(memref);
+  result.addOperands(ivs);
+
+  // A non-memref operand leaves the state without a result type or region.
+  auto bufferType = memref.getType().dyn_cast<MemRefType>();
+  if (!bufferType)
+    return;
+
+  Type valueType = bufferType.getElementType();
+  result.addTypes(valueType);
+
+  // The block argument carries the value currently stored at memref[ivs].
+  Block *entryBlock = new Block();
+  entryBlock->addArgument(valueType);
+  result.addRegion()->push_back(entryBlock);
+}
+
+/// Verifies that the body region has an entry block with exactly one argument
+/// whose type equals the op's result type.
+static LogicalResult verify(GenericAtomicRMWOp op) {
+  // The region is declared as AnyRegion, which imposes no constraint on its
+  // block count; calling front() on an empty region would be undefined.
+  Region &body = op.body();
+  if (body.empty())
+    return op.emitOpError("expected a non-empty body region");
+
+  Block &block = body.front();
+  if (block.getNumArguments() != 1)
+    return op.emitOpError("expected single number of entry block arguments");
+
+  if (op.getResult().getType() != block.getArgument(0).getType())
+    return op.emitOpError(
+        "expected block argument of the same type result type");
+  return success();
+}
+
+/// Parses the custom form
+///   generic_atomic_rmw %memref[%ivs] : memref-type { region } attr-dict
+/// Indices resolve to `index`; the result type is the memref's element type.
+static ParseResult parseGenericAtomicRMWOp(OpAsmParser &parser,
+                                           OperationState &result) {
+  OpAsmParser::OperandType memref;
+  Type memrefType;
+  SmallVector<OpAsmParser::OperandType, 4> ivs;
+  llvm::SMLoc typeLoc;
+
+  Type indexType = parser.getBuilder().getIndexType();
+  if (parser.parseOperand(memref) ||
+      parser.parseOperandList(ivs, OpAsmParser::Delimiter::Square) ||
+      parser.parseColon() || parser.getCurrentLocation(&typeLoc) ||
+      parser.parseType(memrefType))
+    return failure();
+
+  // Report a diagnostic for non-memref types instead of asserting in cast<>.
+  auto memrefTy = memrefType.dyn_cast<MemRefType>();
+  if (!memrefTy)
+    return parser.emitError(typeLoc, "expected memref type, but got ")
+           << memrefType;
+
+  if (parser.resolveOperand(memref, memrefType, result.operands) ||
+      parser.resolveOperands(ivs, indexType, result.operands))
+    return failure();
+
+  // The printer emits the attribute dictionary after the region, so it is
+  // parsed there as well; otherwise printed attributes would not round-trip.
+  Region *body = result.addRegion();
+  if (parser.parseRegion(*body, llvm::None, llvm::None) ||
+      parser.parseOptionalAttrDict(result.attributes))
+    return failure();
+  result.types.push_back(memrefTy.getElementType());
+  return success();
+}
+
+/// Prints the custom form: op name, the memref with its bracketed indices,
+/// `:` and the memref type, then the body region, then any attributes.
+static void print(OpAsmPrinter &p, GenericAtomicRMWOp op) {
+  Value buffer = op.memref();
+  p << op.getOperationName() << ' ' << buffer << "[" << op.indices() << "] : "
+    << buffer.getType();
+  p.printRegion(op.body());
+  p.printOptionalAttrDict(op.getAttrs());
+}
+
+//===----------------------------------------------------------------------===//
+// AtomicYieldOp
+//===----------------------------------------------------------------------===//
+
+/// Verifies that the yielded value has the type of the parent op's result.
+/// The HasParent trait restricts the parent to GenericAtomicRMWOp, which
+/// declares exactly one result.
+static LogicalResult verify(AtomicYieldOp op) {
+  Type yieldedType = op.result().getType();
+  Type expectedType = op.getParentOp()->getResultTypes().front();
+  if (yieldedType == expectedType)
+    return success();
+  return op.emitOpError() << "types mismatch between yield op: " << yieldedType
+                          << " and its parent: " << expectedType;
+}
+
 //===----------------------------------------------------------------------===//
 // BranchOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir
index b50ffa65f179..d19f3445655a 100644
--- a/mlir/test/IR/core-ops.mlir
+++ b/mlir/test/IR/core-ops.mlir
@@ -751,9 +751,23 @@ func @tensor_load_store(%0 : memref<4x4xi32>) {
 }
 
 // CHECK-LABEL: func @atomic_rmw
+// CHECK-SAME: ([[BUF:%.*]]: memref<10xf32>, [[VAL:%.*]]: f32, [[I:%.*]]: index)
 func @atomic_rmw(%I: memref<10xf32>, %val: f32, %i : index) {
-  // CHECK: %{{.*}} = atomic_rmw "addf" %{{.*}}, %{{.*}}[%{{.*}}]
   %x = atomic_rmw "addf" %val, %I[%i] : (f32, memref<10xf32>) -> f32
+  // The printed op must keep the operand order: value, buffer, index. The
+  // regex {{\[}} matches a literal '[' so that it is not read as the opening
+  // of a pattern variable.
+  // CHECK: atomic_rmw "addf" [[VAL]], [[BUF]]{{\[}}[[I]]]
+  return
+}
+
+// CHECK-LABEL: func @generic_atomic_rmw
+// CHECK-SAME: ([[BUF:%.*]]: memref<1x2xf32>, [[I:%.*]]: index, [[J:%.*]]: index)
+func @generic_atomic_rmw(%I: memref<1x2xf32>, %i : index, %j : index) {
+  // Round-trips the op header and its region: the entry block argument, the
+  // body ops and the atomic_yield terminator.
+  %x = generic_atomic_rmw %I[%i, %j] : memref<1x2xf32> {
+  // CHECK-NEXT: generic_atomic_rmw [[BUF]]{{\[}}[[I]], [[J]]] : memref
+    ^bb0(%old_value : f32):
+    // CHECK-NEXT: ^bb0([[OLD_VALUE:%.*]]: f32):
+      %c1 = constant 1.0 : f32
+      // CHECK-NEXT: [[C1:%.*]] = constant
+      %out = addf %c1, %old_value : f32
+      // CHECK-NEXT: [[OUT:%.*]] = addf [[C1]], [[OLD_VALUE]] : f32
+      atomic_yield %out : f32
+      // CHECK-NEXT: atomic_yield [[OUT]] : f32
+  }
   return
 }
 

diff  --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir
index c555776d9d35..17eaded116e6 100644
--- a/mlir/test/IR/invalid-ops.mlir
+++ b/mlir/test/IR/invalid-ops.mlir
@@ -1143,6 +1143,54 @@ func @atomic_rmw_expects_int(%I: memref<16x10xf32>, %i : index, %val : f32) {
 
 // -----
 
+func @generic_atomic_rmw_wrong_arg_num(%I: memref<10xf32>, %i : index) {
+  // expected-error@+1 {{expected single number of entry block arguments}}
+  %x = generic_atomic_rmw %I[%i] : memref<10xf32> {
+    ^bb0(%arg0 : f32, %arg1 : f32):
+      %c1 = constant 1.0 : f32
+      atomic_yield %c1 : f32
+  }
+  return
+}
+
+// -----
+
+func @generic_atomic_rmw_wrong_arg_type(%I: memref<10xf32>, %i : index) {
+  // expected-error@+1 {{expected block argument of the same type result type}}
+  %x = generic_atomic_rmw %I[%i] : memref<10xf32> {
+    ^bb0(%old_value : i32):
+      %c1 = constant 1.0 : f32
+      atomic_yield %c1 : f32
+  }
+  return
+}
+
+// -----
+
+func @generic_atomic_rmw_result_type_mismatch(%I: memref<10xf32>, %i : index) {
+  // expected-error@+1 {{failed to verify that result type matches element type of memref}}
+  %0 = "std.generic_atomic_rmw"(%I, %i) ( {
+    ^bb0(%old_value: f32):
+      %c1 = constant 1.0 : f32
+      atomic_yield %c1 : f32
+    }) : (memref<10xf32>, index) -> i32
+  return
+}
+
+// -----
+
+func @atomic_yield_type_mismatch(%I: memref<10xf32>, %i : index) {
+  // expected-error@+4 {{op types mismatch between yield op: 'i32' and its parent: 'f32'}}
+  %x = generic_atomic_rmw %I[%i] : memref<10xf32> {
+    ^bb0(%old_value : f32):
+      %c1 = constant 1 : i32
+      atomic_yield %c1 : i32
+  }
+  return
+}
+
+// -----
+
 // alignment is not power of 2.
 func @assume_alignment(%0: memref<4x4xf16>) {
   // expected-error at +1 {{alignment must be power of 2}}


        


More information about the Mlir-commits mailing list