[Mlir-commits] [mlir] fe210a1 - [MLIR] Add std.atomic_rmw op

Frank Laub llvmlistbot@llvm.org
Mon Feb 24 16:54:32 PST 2020


Author: Frank Laub
Date: 2020-02-24T16:54:21-08:00
New Revision: fe210a1ff2e90093e210bcbcc1184308903c7bdb

URL: https://github.com/llvm/llvm-project/commit/fe210a1ff2e90093e210bcbcc1184308903c7bdb
DIFF: https://github.com/llvm/llvm-project/commit/fe210a1ff2e90093e210bcbcc1184308903c7bdb.diff

LOG: [MLIR] Add std.atomic_rmw op

Summary:
The RFC for this op is here: https://llvm.discourse.group/t/rfc-add-std-atomic-rmw-op/489

The std.atomic_rmw op provides a way to perform read-modify-write
sequences that are free from data races. It is intended to be used in the
lowering of an upcoming affine.atomic_rmw op, which can be used for
reductions.

A lowering to LLVM is provided with two paths, sketched below:
- Simple patterns: llvm.atomicrmw
- Everything else: llvm.cmpxchg
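
For illustration, a minimal sketch of both paths, using hypothetical SSA
names (the forms follow the op's assembly format and the lowering tests
in this patch):

```mlir
// Simple pattern: maps directly onto a single llvm.atomicrmw.
%old = atomic_rmw "addi" %ival, %I[%i] : (i32, memref<10xi32>) -> i32

// "maxf" has no llvm.atomicrmw equivalent, so it lowers to a loop that
// retries an llvm.cmpxchg until the store succeeds.
%max = atomic_rmw "maxf" %fval, %F[%i] : (f32, memref<10xf32>) -> f32
```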

Differential Revision: https://reviews.llvm.org/D74401

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
    mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
    mlir/lib/Dialect/StandardOps/IR/Ops.cpp
    mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
    mlir/test/IR/core-ops.mlir
    mlir/test/IR/invalid-ops.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index 83c781a19b18..85870010f0e2 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -218,6 +218,69 @@ def AndOp : IntArithmeticOp<"and", [Commutative]> {
   let hasFolder = 1;
 }
 
+def ATOMIC_RMW_KIND_ADDF    : I64EnumAttrCase<"addf", 0>;
+def ATOMIC_RMW_KIND_ADDI    : I64EnumAttrCase<"addi", 1>;
+def ATOMIC_RMW_KIND_ASSIGN  : I64EnumAttrCase<"assign", 2>;
+def ATOMIC_RMW_KIND_MAXF    : I64EnumAttrCase<"maxf", 3>;
+def ATOMIC_RMW_KIND_MAXS    : I64EnumAttrCase<"maxs", 4>;
+def ATOMIC_RMW_KIND_MAXU    : I64EnumAttrCase<"maxu", 5>;
+def ATOMIC_RMW_KIND_MINF    : I64EnumAttrCase<"minf", 6>;
+def ATOMIC_RMW_KIND_MINS    : I64EnumAttrCase<"mins", 7>;
+def ATOMIC_RMW_KIND_MINU    : I64EnumAttrCase<"minu", 8>;
+def ATOMIC_RMW_KIND_MULF    : I64EnumAttrCase<"mulf", 9>;
+def ATOMIC_RMW_KIND_MULI    : I64EnumAttrCase<"muli", 10>;
+
+def AtomicRMWKindAttr : I64EnumAttr<
+    "AtomicRMWKind", "",
+    [ATOMIC_RMW_KIND_ADDF, ATOMIC_RMW_KIND_ADDI, ATOMIC_RMW_KIND_ASSIGN,
+     ATOMIC_RMW_KIND_MAXF, ATOMIC_RMW_KIND_MAXS, ATOMIC_RMW_KIND_MAXU,
+     ATOMIC_RMW_KIND_MINF, ATOMIC_RMW_KIND_MINS, ATOMIC_RMW_KIND_MINU,
+     ATOMIC_RMW_KIND_MULF, ATOMIC_RMW_KIND_MULI]> {
+  let cppNamespace = "::mlir";
+}
+
+def AtomicRMWOp : Std_Op<"atomic_rmw", [
+      AllTypesMatch<["value", "result"]>,
+      TypesMatchWith<"value type matches element type of memref",
+                     "memref", "value",
+                     "$_self.cast<MemRefType>().getElementType()">
+    ]> {
+  let summary = "atomic read-modify-write operation";
+  let description = [{
+    The "atomic_rmw" operation provides a way to perform a read-modify-write
+    sequence that is free from data races. The kind enumeration specifies the
+    modification to perform. The value operand represents the new value to be
+    applied during the modification. The memref operand represents the buffer
+    that the read and write will be performed against, as accessed by the
+    specified indices. The arity of the indices is the rank of the memref. The
+    result represents the latest value that was stored.
+
+    Example:
+
+    ```mlir
+      %x = atomic_rmw "addf" %value, %I[%i] : (f32, memref<10xf32>) -> f32
+    ```
+  }];
+
+  let arguments = (ins
+      AtomicRMWKindAttr:$kind,
+      AnyTypeOf<[AnySignlessInteger, AnyFloat]>:$value,
+      MemRefOf<[AnySignlessInteger, AnyFloat]>:$memref,
+      Variadic<Index>:$indices);
+  let results = (outs AnyTypeOf<[AnySignlessInteger, AnyFloat]>:$result);
+
+  let assemblyFormat = [{
+    $kind $value `,` $memref `[` $indices `]` attr-dict `:` `(` type($value) `,`
+    type($memref) `)` `->` type($result)
+  }];
+
+  let extraClassDeclaration = [{
+    MemRefType getMemRefType() {
+      return memref().getType().cast<MemRefType>();
+    }
+  }];
+}
+
 def BranchOp : Std_Op<"br", [Terminator]> {
   let summary = "branch operation";
   let description = [{

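A side note on the assembly format declared above: the arity of the
indices equals the rank of the memref, so a rank-2 memref takes two
subscripts. A hypothetical example:

```mlir
%x = atomic_rmw "addf" %val, %M[%i, %j] : (f32, memref<16x10xf32>) -> f32
```
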
diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
index 061d4f9bd095..b5b415e7705d 100644
--- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
@@ -1143,7 +1143,8 @@ struct OneToOneLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
   }
 };
 
-template <typename SourceOp, unsigned OpCount> struct OpCountValidator {
+template <typename SourceOp, unsigned OpCount>
+struct OpCountValidator {
   static_assert(
       std::is_base_of<
           typename OpTrait::NOperands<OpCount>::template Impl<SourceOp>,
@@ -1151,12 +1152,14 @@ template <typename SourceOp, unsigned OpCount> struct OpCountValidator {
       "wrong operand count");
 };
 
-template <typename SourceOp> struct OpCountValidator<SourceOp, 1> {
+template <typename SourceOp>
+struct OpCountValidator<SourceOp, 1> {
   static_assert(std::is_base_of<OpTrait::OneOperand<SourceOp>, SourceOp>::value,
                 "expected a single operand");
 };
 
-template <typename SourceOp, unsigned OpCount> void ValidateOpCount() {
+template <typename SourceOp, unsigned OpCount>
+void ValidateOpCount() {
   OpCountValidator<SourceOp, OpCount>();
 }
 
@@ -1524,11 +1527,10 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
       if (strides[index] == MemRefType::getDynamicStrideOrOffset())
         // Identity layout map is enforced in the match function, so we compute:
         //   `runningStride *= sizes[index + 1]`
-        runningStride =
-            runningStride
-                ? rewriter.create<LLVM::MulOp>(loc, runningStride,
-                                               sizes[index + 1])
-                : createIndexConstant(rewriter, loc, 1);
+        runningStride = runningStride
+                            ? rewriter.create<LLVM::MulOp>(loc, runningStride,
+                                                           sizes[index + 1])
+                            : createIndexConstant(rewriter, loc, 1);
       else
         runningStride = createIndexConstant(rewriter, loc, strides[index]);
       strideValues[index] = runningStride;
@@ -2537,6 +2539,170 @@ struct AssumeAlignmentOpLowering
 
 } // namespace
 
+/// Try to match the kind of a std.atomic_rmw to determine whether to use a
+/// lowering to llvm.atomicrmw or fall back to llvm.cmpxchg.
+static Optional<LLVM::AtomicBinOp> matchSimpleAtomicOp(AtomicRMWOp atomicOp) {
+  switch (atomicOp.kind()) {
+  case AtomicRMWKind::addf:
+    return LLVM::AtomicBinOp::fadd;
+  case AtomicRMWKind::addi:
+    return LLVM::AtomicBinOp::add;
+  case AtomicRMWKind::assign:
+    return LLVM::AtomicBinOp::xchg;
+  case AtomicRMWKind::maxs:
+    return LLVM::AtomicBinOp::max;
+  case AtomicRMWKind::maxu:
+    return LLVM::AtomicBinOp::umax;
+  case AtomicRMWKind::mins:
+    return LLVM::AtomicBinOp::min;
+  case AtomicRMWKind::minu:
+    return LLVM::AtomicBinOp::umin;
+  default:
+    return llvm::None;
+  }
+  llvm_unreachable("Invalid AtomicRMWKind");
+}
+
+namespace {
+
+struct AtomicRMWOpLowering : public LoadStoreOpLowering<AtomicRMWOp> {
+  using Base::Base;
+
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto atomicOp = cast<AtomicRMWOp>(op);
+    auto maybeKind = matchSimpleAtomicOp(atomicOp);
+    if (!maybeKind)
+      return matchFailure();
+    OperandAdaptor<AtomicRMWOp> adaptor(operands);
+    auto resultType = adaptor.value().getType();
+    auto memRefType = atomicOp.getMemRefType();
+    auto dataPtr = getDataPtr(op->getLoc(), memRefType, adaptor.memref(),
+                              adaptor.indices(), rewriter, getModule());
+    rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(
+        op, resultType, *maybeKind, dataPtr, adaptor.value(),
+        LLVM::AtomicOrdering::acq_rel);
+    return matchSuccess();
+  }
+};
+
+/// Wrap a llvm.cmpxchg operation in a while loop so that the operation can be
+/// retried until it succeeds in atomically storing a new value into memory.
+///
+///      +---------------------------------+
+///      |   <code before the AtomicRMWOp> |
+///      |   <compute initial %loaded>     |
+///      |   br loop(%loaded)              |
+///      +---------------------------------+
+///             |
+///  -------|   |
+///  |      v   v
+///  |   +--------------------------------+
+///  |   | loop(%loaded):                 |
+///  |   |   <body contents>              |
+///  |   |   %pair = cmpxchg              |
+///  |   |   %ok = %pair[0]               |
+///  |   |   %new = %pair[1]              |
+///  |   |   cond_br %ok, end, loop(%new) |
+///  |   +--------------------------------+
+///  |          |        |
+///  |-----------        |
+///                      v
+///      +--------------------------------+
+///      | end:                           |
+///      |   <code after the AtomicRMWOp> |
+///      +--------------------------------+
+///
+struct AtomicCmpXchgOpLowering : public LoadStoreOpLowering<AtomicRMWOp> {
+  using Base::Base;
+
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto atomicOp = cast<AtomicRMWOp>(op);
+    auto maybeKind = matchSimpleAtomicOp(atomicOp);
+    if (maybeKind)
+      return matchFailure();
+
+    LLVM::FCmpPredicate predicate;
+    switch (atomicOp.kind()) {
+    case AtomicRMWKind::maxf:
+      predicate = LLVM::FCmpPredicate::ogt;
+      break;
+    case AtomicRMWKind::minf:
+      predicate = LLVM::FCmpPredicate::olt;
+      break;
+    default:
+      return matchFailure();
+    }
+
+    OperandAdaptor<AtomicRMWOp> adaptor(operands);
+    auto loc = op->getLoc();
+    auto valueType = adaptor.value().getType().cast<LLVM::LLVMType>();
+
+    // Split the block into initial, loop, and ending parts.
+    auto *initBlock = rewriter.getInsertionBlock();
+    auto initPosition = rewriter.getInsertionPoint();
+    auto *loopBlock = rewriter.splitBlock(initBlock, initPosition);
+    auto loopArgument = loopBlock->addArgument(valueType);
+    auto loopPosition = rewriter.getInsertionPoint();
+    auto *endBlock = rewriter.splitBlock(loopBlock, loopPosition);
+
+    // Compute the loaded value and branch to the loop block.
+    rewriter.setInsertionPointToEnd(initBlock);
+    auto memRefType = atomicOp.getMemRefType();
+    auto dataPtr = getDataPtr(loc, memRefType, adaptor.memref(),
+                              adaptor.indices(), rewriter, getModule());
+    auto init = rewriter.create<LLVM::LoadOp>(loc, dataPtr);
+    std::array<Value, 1> brRegionOperands{init};
+    std::array<ValueRange, 1> brOperands{brRegionOperands};
+    rewriter.create<LLVM::BrOp>(loc, ArrayRef<Value>{}, loopBlock, brOperands);
+
+    // Prepare the body of the loop block.
+    rewriter.setInsertionPointToStart(loopBlock);
+    auto predicateI64 =
+        rewriter.getI64IntegerAttr(static_cast<int64_t>(predicate));
+    auto boolType = LLVM::LLVMType::getInt1Ty(&getDialect());
+    auto lhs = loopArgument;
+    auto rhs = adaptor.value();
+    auto cmp =
+        rewriter.create<LLVM::FCmpOp>(loc, boolType, predicateI64, lhs, rhs);
+    auto select = rewriter.create<LLVM::SelectOp>(loc, cmp, lhs, rhs);
+
+    // Prepare the epilog of the loop block.
+    rewriter.setInsertionPointToEnd(loopBlock);
+    // Append the cmpxchg op to the end of the loop block.
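+    // acq_rel on success makes the exchange act as both an acquire of the
+    // observed value and a release of the stored one; on failure nothing is
+    // stored, so a relaxed (monotonic) re-load ordering suffices.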
+    auto successOrdering = LLVM::AtomicOrdering::acq_rel;
+    auto failureOrdering = LLVM::AtomicOrdering::monotonic;
+    auto pairType = LLVM::LLVMType::getStructTy(valueType, boolType);
+    auto cmpxchg = rewriter.create<LLVM::AtomicCmpXchgOp>(
+        loc, pairType, dataPtr, loopArgument, select, successOrdering,
+        failureOrdering);
+    // Extract the %new_loaded and %ok values from the pair.
+    auto newLoaded = rewriter.create<LLVM::ExtractValueOp>(
+        loc, valueType, cmpxchg, rewriter.getI64ArrayAttr({0}));
+    auto ok = rewriter.create<LLVM::ExtractValueOp>(
+        loc, boolType, cmpxchg, rewriter.getI64ArrayAttr({1}));
+
+    // Conditionally branch to the end or back to the loop depending on %ok.
+    std::array<Value, 1> condBrProperOperands{ok};
+    std::array<Block *, 2> condBrDestinations{endBlock, loopBlock};
+    std::array<Value, 1> condBrRegionOperands{newLoaded};
+    std::array<ValueRange, 2> condBrOperands{ArrayRef<Value>{},
+                                             condBrRegionOperands};
+    rewriter.create<LLVM::CondBrOp>(loc, condBrProperOperands,
+                                    condBrDestinations, condBrOperands);
+
+    // The 'result' of the atomic_rmw op is the newly loaded value.
+    rewriter.replaceOp(op, {newLoaded});
+
+    return matchSuccess();
+  }
+};
+
+} // namespace
+
 static void ensureDistinctSuccessors(Block &bb) {
   auto *terminator = bb.getTerminator();
 
@@ -2594,6 +2760,8 @@ void mlir::populateStdToLLVMNonMemoryConversionPatterns(
       AddFOpLowering,
       AddIOpLowering,
       AndOpLowering,
+      AtomicCmpXchgOpLowering,
+      AtomicRMWOpLowering,
       BranchOpLowering,
       CallIndirectOpLowering,
       CallOpLowering,

diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 5c5fcfc47c11..65bd714d7881 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -135,7 +135,8 @@ static void printStandardCastOp(Operation *op, OpAsmPrinter &p) {
 }
 
 /// A custom cast operation verifier.
-template <typename T> static LogicalResult verifyCastOp(T op) {
+template <typename T>
+static LogicalResult verifyCastOp(T op) {
   auto opType = op.getOperand().getType();
   auto resType = op.getType();
   if (!T::areCastCompatible(opType, resType))
@@ -2614,6 +2615,41 @@ bool FPTruncOp::areCastCompatible(Type a, Type b) {
   return false;
 }
 
+//===----------------------------------------------------------------------===//
+// AtomicRMWOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verify(AtomicRMWOp op) {
+  if (op.getMemRefType().getRank() != op.getNumOperands() - 2)
+    return op.emitOpError(
+        "expects the number of subscripts to be equal to memref rank");
+  switch (op.kind()) {
+  case AtomicRMWKind::addf:
+  case AtomicRMWKind::maxf:
+  case AtomicRMWKind::minf:
+  case AtomicRMWKind::mulf:
+    if (!op.value().getType().isa<FloatType>())
+      return op.emitOpError()
+             << "with kind '" << stringifyAtomicRMWKind(op.kind())
+             << "' expects a floating-point type";
+    break;
+  case AtomicRMWKind::addi:
+  case AtomicRMWKind::maxs:
+  case AtomicRMWKind::maxu:
+  case AtomicRMWKind::mins:
+  case AtomicRMWKind::minu:
+  case AtomicRMWKind::muli:
+    if (!op.value().getType().isa<IntegerType>())
+      return op.emitOpError()
+             << "with kind '" << stringifyAtomicRMWKind(op.kind())
+             << "' expects an integer type";
+    break;
+  default:
+    break;
+  }
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//

diff --git a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
index 8839514937e0..27c249372b15 100644
--- a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
@@ -858,6 +858,46 @@ module {
 
 // -----
 
+// CHECK-LABEL: func @atomic_rmw
+func @atomic_rmw(%I : memref<10xi32>, %ival : i32, %F : memref<10xf32>, %fval : f32, %i : index) {
+  atomic_rmw "assign" %fval, %F[%i] : (f32, memref<10xf32>) -> f32
+  // CHECK: llvm.atomicrmw xchg %{{.*}}, %{{.*}} acq_rel
+  atomic_rmw "addi" %ival, %I[%i] : (i32, memref<10xi32>) -> i32
+  // CHECK: llvm.atomicrmw add %{{.*}}, %{{.*}} acq_rel
+  atomic_rmw "maxs" %ival, %I[%i] : (i32, memref<10xi32>) -> i32
+  // CHECK: llvm.atomicrmw max %{{.*}}, %{{.*}} acq_rel
+  atomic_rmw "mins" %ival, %I[%i] : (i32, memref<10xi32>) -> i32
+  // CHECK: llvm.atomicrmw min %{{.*}}, %{{.*}} acq_rel
+  atomic_rmw "maxu" %ival, %I[%i] : (i32, memref<10xi32>) -> i32
+  // CHECK: llvm.atomicrmw umax %{{.*}}, %{{.*}} acq_rel
+  atomic_rmw "minu" %ival, %I[%i] : (i32, memref<10xi32>) -> i32
+  // CHECK: llvm.atomicrmw umin %{{.*}}, %{{.*}} acq_rel
+  atomic_rmw "addf" %fval, %F[%i] : (f32, memref<10xf32>) -> f32
+  // CHECK: llvm.atomicrmw fadd %{{.*}}, %{{.*}} acq_rel
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @cmpxchg
+func @cmpxchg(%F : memref<10xf32>, %fval : f32, %i : index) -> f32 {
+  %x = atomic_rmw "maxf" %fval, %F[%i] : (f32, memref<10xf32>) -> f32
+  // CHECK: %[[init:.*]] = llvm.load %{{.*}} : !llvm<"float*">
+  // CHECK-NEXT: llvm.br ^bb1(%[[init]] : !llvm.float)
+  // CHECK-NEXT: ^bb1(%[[loaded:.*]]: !llvm.float):
+  // CHECK-NEXT: %[[cmp:.*]] = llvm.fcmp "ogt" %[[loaded]], %{{.*}} : !llvm.float
+  // CHECK-NEXT: %[[max:.*]] = llvm.select %[[cmp]], %[[loaded]], %{{.*}} : !llvm.i1, !llvm.float
+  // CHECK-NEXT: %[[pair:.*]] = llvm.cmpxchg %{{.*}}, %[[loaded]], %[[max]] acq_rel monotonic : !llvm.float
+  // CHECK-NEXT: %[[new:.*]] = llvm.extractvalue %[[pair]][0] : !llvm<"{ float, i1 }">
+  // CHECK-NEXT: %[[ok:.*]] = llvm.extractvalue %[[pair]][1] : !llvm<"{ float, i1 }">
+  // CHECK-NEXT: llvm.cond_br %[[ok]], ^bb2, ^bb1(%[[new]] : !llvm.float)
+  // CHECK-NEXT: ^bb2:
+  return %x : f32
+  // CHECK-NEXT: llvm.return %[[new]]
+}
+
+// -----
+
 // CHECK-LABEL: func @assume_alignment
 func @assume_alignment(%0 : memref<4x4xf16>) {
   // CHECK: %[[PTR:.*]] = llvm.extractvalue %[[MEMREF:.*]][1] : !llvm<"{ half*, half*, i64, [2 x i64], [2 x i64] }">

diff --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir
index 382b1602df0d..c07931f01f8c 100644
--- a/mlir/test/IR/core-ops.mlir
+++ b/mlir/test/IR/core-ops.mlir
@@ -741,6 +741,13 @@ func @tensor_load_store(%0 : memref<4x4xi32>) {
   return
 }
 
+// CHECK-LABEL: func @atomic_rmw
+func @atomic_rmw(%I: memref<10xf32>, %val: f32, %i : index) {
+  // CHECK: %{{.*}} = atomic_rmw "addf" %{{.*}}, %{{.*}}[%{{.*}}]
+  %x = atomic_rmw "addf" %val, %I[%i] : (f32, memref<10xf32>) -> f32
+  return
+}
+
 // CHECK-LABEL: func @assume_alignment
 // CHECK-SAME: %[[MEMREF:.*]]: memref<4x4xf16>
 func @assume_alignment(%0: memref<4x4xf16>) {

diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir
index 20f90c76e3d1..7cc0331bd484 100644
--- a/mlir/test/IR/invalid-ops.mlir
+++ b/mlir/test/IR/invalid-ops.mlir
@@ -1039,6 +1039,30 @@ func @invalid_memref_cast() {
 
 // -----
 
+func @atomic_rmw_idxs_rank_mismatch(%I: memref<16x10xf32>, %i : index, %val : f32) {
+  // expected-error @+1 {{expects the number of subscripts to be equal to memref rank}}
+  %x = atomic_rmw "addf" %val, %I[%i] : (f32, memref<16x10xf32>) -> f32
+  return
+}
+
+// -----
+
+func @atomic_rmw_expects_float(%I: memref<16x10xi32>, %i : index, %val : i32) {
+  // expected-error @+1 {{expects a floating-point type}}
+  %x = atomic_rmw "addf" %val, %I[%i, %i] : (i32, memref<16x10xi32>) -> i32
+  return
+}
+
+// -----
+
+func @atomic_rmw_expects_int(%I: memref<16x10xf32>, %i : index, %val : f32) {
+  // expected-error @+1 {{expects an integer type}}
+  %x = atomic_rmw "addi" %val, %I[%i, %i] : (f32, memref<16x10xf32>) -> f32
+  return
+}
+
+// -----
+
 // alignment is not power of 2.
 func @assume_alignment(%0: memref<4x4xf16>) {
  // expected-error @+1 {{alignment must be power of 2}}


        

