[Mlir-commits] [mlir] a6a583d - [MLIR] Move AtomicRMW into MemRef dialect and enum into Arith

William S. Moses llvmlistbot at llvm.org
Thu Dec 30 11:31:38 PST 2021


Author: William S. Moses
Date: 2021-12-30T14:31:33-05:00
New Revision: a6a583dae40485cacfac56811e6d9131bac6ca74

URL: https://github.com/llvm/llvm-project/commit/a6a583dae40485cacfac56811e6d9131bac6ca74
DIFF: https://github.com/llvm/llvm-project/commit/a6a583dae40485cacfac56811e6d9131bac6ca74.diff

LOG: [MLIR] Move AtomicRMW into MemRef dialect and enum into Arith

Per the discussion in https://reviews.llvm.org/D116345, it makes sense
to move AtomicRMWOp out of the standard dialect. This became more pressing with
the need to add a folder that handles memref::cast. The only dialect
that can host such a folder is the memref dialect (keeping the op in the standard
dialect or moving it to the arithmetic dialect would require those dialects to
depend on the memref dialect, which breaks linking).

Because the AtomicRMWKind enum is used throughout (e.g. for affine and vector reductions), it has been moved to the Arith dialect.

Reviewed By: Mogball

Differential Revision: https://reviews.llvm.org/D116392

Added: 
    

Modified: 
    mlir/include/mlir/Analysis/AffineAnalysis.h
    mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
    mlir/include/mlir/Dialect/Arithmetic/IR/Arithmetic.h
    mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticBase.td
    mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
    mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
    mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
    mlir/include/mlir/Dialect/Vector/VectorOps.h
    mlir/lib/Analysis/AffineAnalysis.cpp
    mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
    mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
    mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
    mlir/lib/Dialect/Affine/IR/AffineOps.cpp
    mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
    mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
    mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
    mlir/lib/Dialect/StandardOps/IR/Ops.cpp
    mlir/lib/Dialect/StandardOps/Transforms/ExpandOps.cpp
    mlir/lib/Dialect/Vector/VectorOps.cpp
    mlir/lib/Transforms/Utils/LoopUtils.cpp
    mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
    mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
    mlir/test/Dialect/MemRef/canonicalize.mlir
    mlir/test/Dialect/MemRef/invalid.mlir
    mlir/test/Dialect/MemRef/ops.mlir
    mlir/test/Dialect/Standard/expand-ops.mlir
    mlir/test/IR/core-ops.mlir
    mlir/test/IR/invalid-ops.mlir

Removed: 
    mlir/include/mlir/Dialect/StandardOps/IR/StandardOpsBase.td


################################################################################
diff  --git a/mlir/include/mlir/Analysis/AffineAnalysis.h b/mlir/include/mlir/Analysis/AffineAnalysis.h
index 120a5be4596ac..fa793a9e17f84 100644
--- a/mlir/include/mlir/Analysis/AffineAnalysis.h
+++ b/mlir/include/mlir/Analysis/AffineAnalysis.h
@@ -15,6 +15,7 @@
 #ifndef MLIR_ANALYSIS_AFFINE_ANALYSIS_H
 #define MLIR_ANALYSIS_AFFINE_ANALYSIS_H
 
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/IR/Value.h"
 #include "llvm/ADT/Optional.h"
@@ -32,7 +33,7 @@ class Operation;
 /// A description of a (parallelizable) reduction in an affine loop.
 struct LoopReduction {
   /// Reduction kind.
-  AtomicRMWKind kind;
+  arith::AtomicRMWKind kind;
 
   /// Position of the iteration argument that acts as accumulator.
   unsigned iterArgPosition;

diff  --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
index c04115476d35e..53b58fa23d342 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -13,7 +13,7 @@
 #ifndef AFFINE_OPS
 #define AFFINE_OPS
 
-include "mlir/Dialect/StandardOps/IR/StandardOpsBase.td"
+include "mlir/Dialect/Arithmetic/IR/ArithmeticBase.td"
 include "mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.td"
 include "mlir/Interfaces/ControlFlowInterfaces.td"
 include "mlir/Interfaces/LoopLikeInterface.td"
@@ -691,9 +691,9 @@ def AffineParallelOp : Affine_Op<"parallel",
 
   let builders = [
     OpBuilder<(ins "TypeRange":$resultTypes,
-      "ArrayRef<AtomicRMWKind>":$reductions, "ArrayRef<int64_t>":$ranges)>,
+      "ArrayRef<arith::AtomicRMWKind>":$reductions, "ArrayRef<int64_t>":$ranges)>,
     OpBuilder<(ins "TypeRange":$resultTypes,
-      "ArrayRef<AtomicRMWKind>":$reductions, "ArrayRef<AffineMap>":$lbMaps,
+      "ArrayRef<arith::AtomicRMWKind>":$reductions, "ArrayRef<AffineMap>":$lbMaps,
       "ValueRange":$lbArgs, "ArrayRef<AffineMap>":$ubMaps, "ValueRange":$ubArgs,
       "ArrayRef<int64_t>":$steps)>
   ];

diff  --git a/mlir/include/mlir/Dialect/Arithmetic/IR/Arithmetic.h b/mlir/include/mlir/Dialect/Arithmetic/IR/Arithmetic.h
index 4fa592d003a0f..31d6239388454 100644
--- a/mlir/include/mlir/Dialect/Arithmetic/IR/Arithmetic.h
+++ b/mlir/include/mlir/Dialect/Arithmetic/IR/Arithmetic.h
@@ -109,6 +109,18 @@ bool applyCmpPredicate(arith::CmpIPredicate predicate, const APInt &lhs,
 bool applyCmpPredicate(arith::CmpFPredicate predicate, const APFloat &lhs,
                        const APFloat &rhs);
 
+/// Returns the identity value attribute for the reduction described by `kind`.
+Attribute getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
+                               OpBuilder &builder, Location loc);
+
+/// Returns the identity value for the reduction described by `kind`.
+Value getIdentityValue(AtomicRMWKind kind, Type resultType, OpBuilder &builder,
+                       Location loc);
+
+/// Returns the value obtained by applying the binary reduction described by
+/// `kind` to `lhs` and `rhs`.
+Value getReductionOp(AtomicRMWKind kind, OpBuilder &builder, Location loc,
+                     Value lhs, Value rhs);
 } // namespace arith
 } // namespace mlir
 

diff  --git a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticBase.td b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticBase.td
index 87439da956407..704edbb587bd3 100644
--- a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticBase.td
+++ b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticBase.td
@@ -68,4 +68,28 @@ def Arith_CmpIPredicateAttr : I64EnumAttr<
   let cppNamespace = "::mlir::arith";
 }
 
+def ATOMIC_RMW_KIND_ADDF    : I64EnumAttrCase<"addf", 0>;
+def ATOMIC_RMW_KIND_ADDI    : I64EnumAttrCase<"addi", 1>;
+def ATOMIC_RMW_KIND_ASSIGN  : I64EnumAttrCase<"assign", 2>;
+def ATOMIC_RMW_KIND_MAXF    : I64EnumAttrCase<"maxf", 3>;
+def ATOMIC_RMW_KIND_MAXS    : I64EnumAttrCase<"maxs", 4>;
+def ATOMIC_RMW_KIND_MAXU    : I64EnumAttrCase<"maxu", 5>;
+def ATOMIC_RMW_KIND_MINF    : I64EnumAttrCase<"minf", 6>;
+def ATOMIC_RMW_KIND_MINS    : I64EnumAttrCase<"mins", 7>;
+def ATOMIC_RMW_KIND_MINU    : I64EnumAttrCase<"minu", 8>;
+def ATOMIC_RMW_KIND_MULF    : I64EnumAttrCase<"mulf", 9>;
+def ATOMIC_RMW_KIND_MULI    : I64EnumAttrCase<"muli", 10>;
+def ATOMIC_RMW_KIND_ORI     : I64EnumAttrCase<"ori", 11>;
+def ATOMIC_RMW_KIND_ANDI    : I64EnumAttrCase<"andi", 12>;
+// Kinds for memref.atomic_rmw and affine/vector reductions (stored as i64).
+def AtomicRMWKindAttr : I64EnumAttr<
+    "AtomicRMWKind", "",
+    [ATOMIC_RMW_KIND_ADDF, ATOMIC_RMW_KIND_ADDI, ATOMIC_RMW_KIND_ASSIGN,
+     ATOMIC_RMW_KIND_MAXF, ATOMIC_RMW_KIND_MAXS, ATOMIC_RMW_KIND_MAXU,
+     ATOMIC_RMW_KIND_MINF, ATOMIC_RMW_KIND_MINS, ATOMIC_RMW_KIND_MINU,
+     ATOMIC_RMW_KIND_MULF, ATOMIC_RMW_KIND_MULI, ATOMIC_RMW_KIND_ORI,
+     ATOMIC_RMW_KIND_ANDI]> {
+  let cppNamespace = "::mlir::arith";
+}
+
 #endif // ARITHMETIC_BASE

diff  --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index e529a50dae935..884dc0f5b0510 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -11,6 +11,7 @@
 
 include "mlir/Interfaces/ControlFlowInterfaces.td"
 include "mlir/Dialect/MemRef/IR/MemRefBase.td"
+include "mlir/Dialect/Arithmetic/IR/ArithmeticBase.td"
 include "mlir/IR/OpBase.td"
 include "mlir/Interfaces/CastInterfaces.td"
 include "mlir/Interfaces/CopyOpInterface.td"
@@ -1673,4 +1674,51 @@ def MemRef_ViewOp : MemRef_Op<"view", [
   let hasCanonicalizer = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// AtomicRMWOp
+//===----------------------------------------------------------------------===//
+
+def AtomicRMWOp : MemRef_Op<"atomic_rmw", [
+      AllTypesMatch<["value", "result"]>,
+      TypesMatchWith<"value type matches element type of memref",
+                     "memref", "value",
+                     "$_self.cast<MemRefType>().getElementType()">
+    ]> {
+  let summary = "atomic read-modify-write operation";
+  let description = [{
+    The `atomic_rmw` operation provides a way to perform a read-modify-write
+    sequence that is free from data races. The kind enumeration specifies the
+    modification to perform. The value operand represents the new value to be
+    applied during the modification. The memref operand represents the buffer
+    that the read and write will be performed against, as accessed by the
+    specified indices. The arity of the indices is the rank of the memref. The
+    result represents the latest value that was stored.
+
+    Example:
+
+    ```mlir
+    %x = memref.atomic_rmw "addf" %value, %I[%i] : (f32, memref<10xf32>) -> f32
+    ```
+  }];
+
+  let arguments = (ins
+      AtomicRMWKindAttr:$kind,
+      AnyTypeOf<[AnySignlessInteger, AnyFloat]>:$value,
+      MemRefOf<[AnySignlessInteger, AnyFloat]>:$memref,
+      Variadic<Index>:$indices);
+  let results = (outs AnyTypeOf<[AnySignlessInteger, AnyFloat]>:$result);
+
+  let assemblyFormat = [{
+    $kind $value `,` $memref `[` $indices `]` attr-dict `:` `(` type($value) `,`
+    type($memref) `)` `->` type($result)
+  }];
+
+  let extraClassDeclaration = [{
+    MemRefType getMemRefType() {
+      return memref().getType().cast<MemRefType>();
+    }
+  }];
+  let hasFolder = 1;
+}
+
 #endif // MEMREF_OPS

diff  --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
index fe9c3c6e26f1b..b309488aa4f53 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
@@ -42,31 +42,4 @@ class PatternRewriter;
 
 #include "mlir/Dialect/StandardOps/IR/OpsDialect.h.inc"
 
-namespace mlir {
-
-/// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
-/// comparison predicates.
-bool applyCmpPredicate(arith::CmpIPredicate predicate, const APInt &lhs,
-                       const APInt &rhs);
-
-/// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point
-/// comparison predicates.
-bool applyCmpPredicate(arith::CmpFPredicate predicate, const APFloat &lhs,
-                       const APFloat &rhs);
-
-/// Returns the identity value attribute associated with an AtomicRMWKind op.
-Attribute getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
-                               OpBuilder &builder, Location loc);
-
-/// Returns the identity value associated with an AtomicRMWKind op.
-Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder,
-                       Location loc);
-
-/// Returns the value obtained by applying the reduction operation kind
-/// associated with a binary AtomicRMWKind op to `lhs` and `rhs`.
-Value getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc,
-                     Value lhs, Value rhs);
-
-} // namespace mlir
-
 #endif // MLIR_DIALECT_IR_STANDARDOPS_IR_OPS_H

diff  --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index 2e50971db9e7a..794f0157ef144 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -13,7 +13,6 @@
 #ifndef STANDARD_OPS
 #define STANDARD_OPS
 
-include "mlir/Dialect/StandardOps/IR/StandardOpsBase.td"
 include "mlir/IR/OpAsmInterface.td"
 include "mlir/IR/SymbolInterfaces.td"
 include "mlir/Interfaces/CallInterfaces.td"
@@ -179,52 +178,6 @@ def AssertOp : Std_Op<"assert"> {
   let hasCanonicalizeMethod = 1;
 }
 
-//===----------------------------------------------------------------------===//
-// AtomicRMWOp
-//===----------------------------------------------------------------------===//
-
-def AtomicRMWOp : Std_Op<"atomic_rmw", [
-      AllTypesMatch<["value", "result"]>,
-      TypesMatchWith<"value type matches element type of memref",
-                     "memref", "value",
-                     "$_self.cast<MemRefType>().getElementType()">
-    ]> {
-  let summary = "atomic read-modify-write operation";
-  let description = [{
-    The `atomic_rmw` operation provides a way to perform a read-modify-write
-    sequence that is free from data races. The kind enumeration specifies the
-    modification to perform. The value operand represents the new value to be
-    applied during the modification. The memref operand represents the buffer
-    that the read and write will be performed against, as accessed by the
-    specified indices. The arity of the indices is the rank of the memref. The
-    result represents the latest value that was stored.
-
-    Example:
-
-    ```mlir
-    %x = atomic_rmw "addf" %value, %I[%i] : (f32, memref<10xf32>) -> f32
-    ```
-  }];
-
-  let arguments = (ins
-      AtomicRMWKindAttr:$kind,
-      AnyTypeOf<[AnySignlessInteger, AnyFloat]>:$value,
-      MemRefOf<[AnySignlessInteger, AnyFloat]>:$memref,
-      Variadic<Index>:$indices);
-  let results = (outs AnyTypeOf<[AnySignlessInteger, AnyFloat]>:$result);
-
-  let assemblyFormat = [{
-    $kind $value `,` $memref `[` $indices `]` attr-dict `:` `(` type($value) `,`
-    type($memref) `)` `->` type($result)
-  }];
-
-  let extraClassDeclaration = [{
-    MemRefType getMemRefType() {
-      return getMemref().getType().cast<MemRefType>();
-    }
-  }];
-}
-
 def GenericAtomicRMWOp : Std_Op<"generic_atomic_rmw", [
       SingleBlockImplicitTerminator<"AtomicYieldOp">,
       TypesMatchWith<"result type matches element type of memref",

diff  --git a/mlir/include/mlir/Dialect/StandardOps/IR/StandardOpsBase.td b/mlir/include/mlir/Dialect/StandardOps/IR/StandardOpsBase.td
deleted file mode 100644
index 3016a197df0d0..0000000000000
--- a/mlir/include/mlir/Dialect/StandardOps/IR/StandardOpsBase.td
+++ /dev/null
@@ -1,42 +0,0 @@
-//===- StandardOpsBase.td - Standard ops definitions -------*- tablegen -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// Defines base support for standard operations.
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef STANDARD_OPS_BASE
-#define STANDARD_OPS_BASE
-
-include "mlir/IR/OpBase.td"
-
-def ATOMIC_RMW_KIND_ADDF    : I64EnumAttrCase<"addf", 0>;
-def ATOMIC_RMW_KIND_ADDI    : I64EnumAttrCase<"addi", 1>;
-def ATOMIC_RMW_KIND_ASSIGN  : I64EnumAttrCase<"assign", 2>;
-def ATOMIC_RMW_KIND_MAXF    : I64EnumAttrCase<"maxf", 3>;
-def ATOMIC_RMW_KIND_MAXS    : I64EnumAttrCase<"maxs", 4>;
-def ATOMIC_RMW_KIND_MAXU    : I64EnumAttrCase<"maxu", 5>;
-def ATOMIC_RMW_KIND_MINF    : I64EnumAttrCase<"minf", 6>;
-def ATOMIC_RMW_KIND_MINS    : I64EnumAttrCase<"mins", 7>;
-def ATOMIC_RMW_KIND_MINU    : I64EnumAttrCase<"minu", 8>;
-def ATOMIC_RMW_KIND_MULF    : I64EnumAttrCase<"mulf", 9>;
-def ATOMIC_RMW_KIND_MULI    : I64EnumAttrCase<"muli", 10>;
-def ATOMIC_RMW_KIND_ORI     : I64EnumAttrCase<"ori", 11>;
-def ATOMIC_RMW_KIND_ANDI    : I64EnumAttrCase<"andi", 12>;
-
-def AtomicRMWKindAttr : I64EnumAttr<
-    "AtomicRMWKind", "",
-    [ATOMIC_RMW_KIND_ADDF, ATOMIC_RMW_KIND_ADDI, ATOMIC_RMW_KIND_ASSIGN,
-     ATOMIC_RMW_KIND_MAXF, ATOMIC_RMW_KIND_MAXS, ATOMIC_RMW_KIND_MAXU,
-     ATOMIC_RMW_KIND_MINF, ATOMIC_RMW_KIND_MINS, ATOMIC_RMW_KIND_MINU,
-     ATOMIC_RMW_KIND_MULF, ATOMIC_RMW_KIND_MULI, ATOMIC_RMW_KIND_ORI,
-     ATOMIC_RMW_KIND_ANDI]> {
-  let cppNamespace = "::mlir";
-}
-
-#endif // STANDARD_OPS_BASE

diff  --git a/mlir/include/mlir/Dialect/Vector/VectorOps.h b/mlir/include/mlir/Dialect/Vector/VectorOps.h
index 14bd03968fcf6..816ec204acfe7 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.h
@@ -13,6 +13,7 @@
 #ifndef MLIR_DIALECT_VECTOR_VECTOROPS_H
 #define MLIR_DIALECT_VECTOR_VECTOROPS_H
 
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Attributes.h"
@@ -145,8 +146,8 @@ ArrayAttr getVectorSubscriptAttr(Builder &b, ArrayRef<int64_t> values);
 
 /// Returns the value obtained by reducing the vector into a scalar using the
 /// operation kind associated with a binary AtomicRMWKind op.
-Value getVectorReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc,
-                           Value vector);
+Value getVectorReductionOp(arith::AtomicRMWKind op, OpBuilder &builder,
+                           Location loc, Value vector);
 
 /// Return true if the last dimension of the MemRefType has unit stride. Also
 /// return true for memrefs with no strides.

diff  --git a/mlir/lib/Analysis/AffineAnalysis.cpp b/mlir/lib/Analysis/AffineAnalysis.cpp
index 79a367e337133..c8022e0465483 100644
--- a/mlir/lib/Analysis/AffineAnalysis.cpp
+++ b/mlir/lib/Analysis/AffineAnalysis.cpp
@@ -40,7 +40,7 @@ using llvm::dbgs;
 /// reduction kind suitable for use in affine parallel loop builder. If the
 /// reduction is not supported, returns null.
 static Value getSupportedReduction(AffineForOp forOp, unsigned pos,
-                                   AtomicRMWKind &kind) {
+                                   arith::AtomicRMWKind &kind) {
   SmallVector<Operation *> combinerOps;
   Value reducedVal =
       matchReduction(forOp.getRegionIterArgs(), pos, combinerOps);
@@ -52,21 +52,21 @@ static Value getSupportedReduction(AffineForOp forOp, unsigned pos,
     return nullptr;
 
   Operation *combinerOp = combinerOps.back();
-  Optional<AtomicRMWKind> maybeKind =
-      TypeSwitch<Operation *, Optional<AtomicRMWKind>>(combinerOp)
-          .Case([](arith::AddFOp) { return AtomicRMWKind::addf; })
-          .Case([](arith::MulFOp) { return AtomicRMWKind::mulf; })
-          .Case([](arith::AddIOp) { return AtomicRMWKind::addi; })
-          .Case([](arith::AndIOp) { return AtomicRMWKind::andi; })
-          .Case([](arith::OrIOp) { return AtomicRMWKind::ori; })
-          .Case([](arith::MulIOp) { return AtomicRMWKind::muli; })
-          .Case([](arith::MinFOp) { return AtomicRMWKind::minf; })
-          .Case([](arith::MaxFOp) { return AtomicRMWKind::maxf; })
-          .Case([](arith::MinSIOp) { return AtomicRMWKind::mins; })
-          .Case([](arith::MaxSIOp) { return AtomicRMWKind::maxs; })
-          .Case([](arith::MinUIOp) { return AtomicRMWKind::minu; })
-          .Case([](arith::MaxUIOp) { return AtomicRMWKind::maxu; })
-          .Default([](Operation *) -> Optional<AtomicRMWKind> {
+  Optional<arith::AtomicRMWKind> maybeKind =
+      TypeSwitch<Operation *, Optional<arith::AtomicRMWKind>>(combinerOp)
+          .Case([](arith::AddFOp) { return arith::AtomicRMWKind::addf; })
+          .Case([](arith::MulFOp) { return arith::AtomicRMWKind::mulf; })
+          .Case([](arith::AddIOp) { return arith::AtomicRMWKind::addi; })
+          .Case([](arith::AndIOp) { return arith::AtomicRMWKind::andi; })
+          .Case([](arith::OrIOp) { return arith::AtomicRMWKind::ori; })
+          .Case([](arith::MulIOp) { return arith::AtomicRMWKind::muli; })
+          .Case([](arith::MinFOp) { return arith::AtomicRMWKind::minf; })
+          .Case([](arith::MaxFOp) { return arith::AtomicRMWKind::maxf; })
+          .Case([](arith::MinSIOp) { return arith::AtomicRMWKind::mins; })
+          .Case([](arith::MaxSIOp) { return arith::AtomicRMWKind::maxs; })
+          .Case([](arith::MinUIOp) { return arith::AtomicRMWKind::minu; })
+          .Case([](arith::MaxUIOp) { return arith::AtomicRMWKind::maxu; })
+          .Default([](Operation *) -> Optional<arith::AtomicRMWKind> {
             // TODO: AtomicRMW supports other kinds of reductions this is
             // currently not detecting, add those when the need arises.
             return llvm::None;
@@ -86,7 +86,7 @@ void mlir::getSupportedReductions(
     return;
   supportedReductions.reserve(numIterArgs);
   for (unsigned i = 0; i < numIterArgs; ++i) {
-    AtomicRMWKind kind;
+    arith::AtomicRMWKind kind;
     if (Value value = getSupportedReduction(forOp, i, kind))
       supportedReductions.emplace_back(LoopReduction{kind, i, value});
   }

diff  --git a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
index 8b48549b38052..bc2f5917160e7 100644
--- a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
+++ b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
@@ -430,13 +430,14 @@ class AffineParallelLowering : public OpRewritePattern<AffineParallelOp> {
       // initialization of the result values.
       Attribute reduction = std::get<0>(pair);
       Type resultType = std::get<1>(pair);
-      Optional<AtomicRMWKind> reductionOp = symbolizeAtomicRMWKind(
-          static_cast<uint64_t>(reduction.cast<IntegerAttr>().getInt()));
+      Optional<arith::AtomicRMWKind> reductionOp =
+          arith::symbolizeAtomicRMWKind(
+              static_cast<uint64_t>(reduction.cast<IntegerAttr>().getInt()));
       assert(reductionOp.hasValue() &&
              "Reduction operation cannot be of None Type");
-      AtomicRMWKind reductionOpValue = reductionOp.getValue();
+      arith::AtomicRMWKind reductionOpValue = reductionOp.getValue();
       identityVals.push_back(
-          getIdentityValue(reductionOpValue, resultType, rewriter, loc));
+          arith::getIdentityValue(reductionOpValue, resultType, rewriter, loc));
     }
     parOp = rewriter.create<scf::ParallelOp>(
         loc, lowerBoundTuple, upperBoundTuple, steps, identityVals,
@@ -450,16 +451,17 @@ class AffineParallelLowering : public OpRewritePattern<AffineParallelOp> {
            "Unequal number of reductions and operands.");
     for (unsigned i = 0, end = reductions.size(); i < end; i++) {
       // For each of the reduction operations get the respective mlir::Value.
-      Optional<AtomicRMWKind> reductionOp =
-          symbolizeAtomicRMWKind(reductions[i].cast<IntegerAttr>().getInt());
+      Optional<arith::AtomicRMWKind> reductionOp =
+          arith::symbolizeAtomicRMWKind(
+              reductions[i].cast<IntegerAttr>().getInt());
       assert(reductionOp.hasValue() &&
              "Reduction Operation cannot be of None Type");
-      AtomicRMWKind reductionOpValue = reductionOp.getValue();
+      arith::AtomicRMWKind reductionOpValue = reductionOp.getValue();
       rewriter.setInsertionPoint(&parOp.getBody()->back());
       auto reduceOp = rewriter.create<scf::ReduceOp>(
           loc, affineParOpTerminator->getOperand(i));
       rewriter.setInsertionPointToEnd(&reduceOp.getReductionOperator().front());
-      Value reductionResult = getReductionOp(
+      Value reductionResult = arith::getReductionOp(
           reductionOpValue, rewriter, loc,
           reduceOp.getReductionOperator().front().getArgument(0),
           reduceOp.getReductionOperator().front().getArgument(1));

diff  --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 28981dd87ecc9..b1f7d0452ee13 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -1553,6 +1553,62 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
   }
 };
 
+//===----------------------------------------------------------------------===//
+// AtomicRMWOpLowering
+//===----------------------------------------------------------------------===//
+
+/// Try to match the kind of a memref.atomic_rmw to determine whether to use a
+/// lowering to llvm.atomicrmw or fallback to llvm.cmpxchg.
+static Optional<LLVM::AtomicBinOp>
+matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) {
+  switch (atomicOp.kind()) {
+  case arith::AtomicRMWKind::addf:
+    return LLVM::AtomicBinOp::fadd;
+  case arith::AtomicRMWKind::addi:
+    return LLVM::AtomicBinOp::add;
+  case arith::AtomicRMWKind::assign:
+    return LLVM::AtomicBinOp::xchg;
+  case arith::AtomicRMWKind::maxs:
+    return LLVM::AtomicBinOp::max;
+  case arith::AtomicRMWKind::maxu:
+    return LLVM::AtomicBinOp::umax;
+  case arith::AtomicRMWKind::mins:
+    return LLVM::AtomicBinOp::min;
+  case arith::AtomicRMWKind::minu:
+    return LLVM::AtomicBinOp::umin;
+  case arith::AtomicRMWKind::ori:
+    return LLVM::AtomicBinOp::_or;
+  case arith::AtomicRMWKind::andi:
+    return LLVM::AtomicBinOp::_and;
+  default:
+    return llvm::None;
+  }
+  llvm_unreachable("Invalid AtomicRMWKind");
+}
+
+struct AtomicRMWOpLowering : public LoadStoreOpLowering<memref::AtomicRMWOp> {
+  using Base::Base;
+
+  LogicalResult
+  matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (failed(match(atomicOp)))
+      return failure();
+    auto maybeKind = matchSimpleAtomicOp(atomicOp);
+    if (!maybeKind)
+      return failure();
+    auto resultType = adaptor.value().getType();
+    auto memRefType = atomicOp.getMemRefType();
+    auto dataPtr =
+        getStridedElementPtr(atomicOp.getLoc(), memRefType, adaptor.memref(),
+                             adaptor.indices(), rewriter);
+    rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(
+        atomicOp, resultType, *maybeKind, dataPtr, adaptor.value(),
+        LLVM::AtomicOrdering::acq_rel);
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::populateMemRefToLLVMConversionPatterns(LLVMTypeConverter &converter,
@@ -1561,6 +1617,7 @@ void mlir::populateMemRefToLLVMConversionPatterns(LLVMTypeConverter &converter,
   patterns.add<
       AllocaOpLowering,
       AllocaScopeOpLowering,
+      AtomicRMWOpLowering,
       AssumeAlignmentOpLowering,
       DimOpLowering,
       GlobalMemrefOpLowering,

diff  --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index da429dd8af119..feaa140cc710e 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -772,61 +772,6 @@ struct SplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
   }
 };
 
-} // namespace
-
-/// Try to match the kind of a std.atomic_rmw to determine whether to use a
-/// lowering to llvm.atomicrmw or fallback to llvm.cmpxchg.
-static Optional<LLVM::AtomicBinOp> matchSimpleAtomicOp(AtomicRMWOp atomicOp) {
-  switch (atomicOp.getKind()) {
-  case AtomicRMWKind::addf:
-    return LLVM::AtomicBinOp::fadd;
-  case AtomicRMWKind::addi:
-    return LLVM::AtomicBinOp::add;
-  case AtomicRMWKind::assign:
-    return LLVM::AtomicBinOp::xchg;
-  case AtomicRMWKind::maxs:
-    return LLVM::AtomicBinOp::max;
-  case AtomicRMWKind::maxu:
-    return LLVM::AtomicBinOp::umax;
-  case AtomicRMWKind::mins:
-    return LLVM::AtomicBinOp::min;
-  case AtomicRMWKind::minu:
-    return LLVM::AtomicBinOp::umin;
-  case AtomicRMWKind::ori:
-    return LLVM::AtomicBinOp::_or;
-  case AtomicRMWKind::andi:
-    return LLVM::AtomicBinOp::_and;
-  default:
-    return llvm::None;
-  }
-  llvm_unreachable("Invalid AtomicRMWKind");
-}
-
-namespace {
-
-struct AtomicRMWOpLowering : public LoadStoreOpLowering<AtomicRMWOp> {
-  using Base::Base;
-
-  LogicalResult
-  matchAndRewrite(AtomicRMWOp atomicOp, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    if (failed(match(atomicOp)))
-      return failure();
-    auto maybeKind = matchSimpleAtomicOp(atomicOp);
-    if (!maybeKind)
-      return failure();
-    auto resultType = adaptor.getValue().getType();
-    auto memRefType = atomicOp.getMemRefType();
-    auto dataPtr =
-        getStridedElementPtr(atomicOp.getLoc(), memRefType, adaptor.getMemref(),
-                             adaptor.getIndices(), rewriter);
-    rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(
-        atomicOp, resultType, *maybeKind, dataPtr, adaptor.getValue(),
-        LLVM::AtomicOrdering::acq_rel);
-    return success();
-  }
-};
-
 /// Wrap a llvm.cmpxchg operation in a while loop so that the operation can be
 /// retried until it succeeds in atomically storing a new value into memory.
 ///
@@ -962,7 +907,6 @@ void mlir::populateStdToLLVMConversionPatterns(LLVMTypeConverter &converter,
   // clang-format off
   patterns.add<
       AssertOpLowering,
-      AtomicRMWOpLowering,
       BranchOpLowering,
       CallIndirectOpLowering,
       CallOpLowering,

diff  --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 071838dcf2be4..c3c1b51294801 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -2801,7 +2801,7 @@ LogicalResult AffinePrefetchOp::fold(ArrayRef<Attribute> cstOperands,
 
 void AffineParallelOp::build(OpBuilder &builder, OperationState &result,
                              TypeRange resultTypes,
-                             ArrayRef<AtomicRMWKind> reductions,
+                             ArrayRef<arith::AtomicRMWKind> reductions,
                              ArrayRef<int64_t> ranges) {
   SmallVector<AffineMap> lbs(ranges.size(), builder.getConstantAffineMap(0));
   auto ubs = llvm::to_vector<4>(llvm::map_range(ranges, [&](int64_t value) {
@@ -2814,7 +2814,7 @@ void AffineParallelOp::build(OpBuilder &builder, OperationState &result,
 
 void AffineParallelOp::build(OpBuilder &builder, OperationState &result,
                              TypeRange resultTypes,
-                             ArrayRef<AtomicRMWKind> reductions,
+                             ArrayRef<arith::AtomicRMWKind> reductions,
                              ArrayRef<AffineMap> lbMaps, ValueRange lbArgs,
                              ArrayRef<AffineMap> ubMaps, ValueRange ubArgs,
                              ArrayRef<int64_t> steps) {
@@ -2843,7 +2843,7 @@ void AffineParallelOp::build(OpBuilder &builder, OperationState &result,
 
   // Convert the reductions to integer attributes.
   SmallVector<Attribute, 4> reductionAttrs;
-  for (AtomicRMWKind reduction : reductions)
+  for (arith::AtomicRMWKind reduction : reductions)
     reductionAttrs.push_back(
         builder.getI64IntegerAttr(static_cast<int64_t>(reduction)));
   result.addAttribute(getReductionsAttrName(),
@@ -3050,7 +3050,7 @@ static LogicalResult verify(AffineParallelOp op) {
   // Verify reduction  ops are all valid
   for (Attribute attr : op.reductions()) {
     auto intAttr = attr.dyn_cast<IntegerAttr>();
-    if (!intAttr || !symbolizeAtomicRMWKind(intAttr.getInt()))
+    if (!intAttr || !arith::symbolizeAtomicRMWKind(intAttr.getInt()))
       return op.emitOpError("invalid reduction attribute");
   }
 
@@ -3150,9 +3150,9 @@ static void print(OpAsmPrinter &p, AffineParallelOp op) {
   if (op.getNumResults()) {
     p << " reduce (";
     llvm::interleaveComma(op.reductions(), p, [&](auto &attr) {
-      AtomicRMWKind sym =
-          *symbolizeAtomicRMWKind(attr.template cast<IntegerAttr>().getInt());
-      p << "\"" << stringifyAtomicRMWKind(sym) << "\"";
+      arith::AtomicRMWKind sym = *arith::symbolizeAtomicRMWKind(
+          attr.template cast<IntegerAttr>().getInt());
+      p << "\"" << arith::stringifyAtomicRMWKind(sym) << "\"";
     });
     p << ") -> (" << op.getResultTypes() << ")";
   }
@@ -3374,8 +3374,8 @@ static ParseResult parseAffineParallelOp(OpAsmParser &parser,
       if (parser.parseAttribute(attrVal, builder.getNoneType(), "reduce",
                                 attrStorage))
         return failure();
-      llvm::Optional<AtomicRMWKind> reduction =
-          symbolizeAtomicRMWKind(attrVal.getValue());
+      llvm::Optional<arith::AtomicRMWKind> reduction =
+          arith::symbolizeAtomicRMWKind(attrVal.getValue());
       if (!reduction)
         return parser.emitError(loc, "invalid reduction value: ") << attrVal;
       reductions.push_back(builder.getI64IntegerAttr(

diff  --git a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
index 9d59b89ea1a2c..7ecc6750bcca4 100644
--- a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
@@ -971,7 +971,7 @@ static arith::ConstantOp vectorizeConstant(arith::ConstantOp constOp,
 /// Creates a constant vector filled with the neutral elements of the given
 /// reduction. The scalar type of vector elements will be taken from
 /// `oldOperand`.
-static arith::ConstantOp createInitialVector(AtomicRMWKind reductionKind,
+static arith::ConstantOp createInitialVector(arith::AtomicRMWKind reductionKind,
                                              Value oldOperand,
                                              VectorizationState &state) {
   Type scalarTy = oldOperand.getType();
@@ -1245,8 +1245,8 @@ static Operation *vectorizeAffineStore(AffineStoreOp storeOp,
 
 /// Returns true if `value` is a constant equal to the neutral element of the
 /// given vectorizable reduction.
-static bool isNeutralElementConst(AtomicRMWKind reductionKind, Value value,
-                                  VectorizationState &state) {
+static bool isNeutralElementConst(arith::AtomicRMWKind reductionKind,
+                                  Value value, VectorizationState &state) {
   Type scalarTy = value.getType();
   if (!VectorType::isValidElementType(scalarTy))
     return false;
@@ -1361,7 +1361,8 @@ static Operation *vectorizeAffineForOp(AffineForOp forOp,
       Value origInit = forOp.getOperand(forOp.getNumControlOperands() + i);
       Value finalRes = reducedRes;
       if (!isNeutralElementConst(reductions[i].kind, origInit, state))
-        finalRes = getReductionOp(reductions[i].kind, state.builder,
+        finalRes =
+            arith::getReductionOp(reductions[i].kind, state.builder,
                                   reducedRes.getLoc(), reducedRes, origInit);
       state.registerLoopResultScalarReplacement(forOp.getResult(i), finalRes);
     }

diff  --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
index f0ce1b7a4d704..048e4d89186cd 100644
--- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
+++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/CommonFolders.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/OpImplementation.h"
@@ -1208,6 +1209,101 @@ OpFoldResult arith::CmpFOp::fold(ArrayRef<Attribute> operands) {
   return BoolAttr::get(getContext(), val);
 }
 
+//===----------------------------------------------------------------------===//
+// Atomic Enum
+//===----------------------------------------------------------------------===//
+
+/// Returns the identity value attribute associated with an AtomicRMWKind op.
+Attribute mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
+                                            OpBuilder &builder, Location loc) {
+  switch (kind) {
+  case AtomicRMWKind::maxf:
+    return builder.getFloatAttr(
+        resultType,
+        APFloat::getInf(resultType.cast<FloatType>().getFloatSemantics(),
+                        /*Negative=*/true));
+  case AtomicRMWKind::addf:
+  case AtomicRMWKind::addi:
+  case AtomicRMWKind::maxu:
+  case AtomicRMWKind::ori:
+    return builder.getZeroAttr(resultType);
+  case AtomicRMWKind::andi:
+    return builder.getIntegerAttr(
+        resultType,
+        APInt::getAllOnes(resultType.cast<IntegerType>().getWidth()));
+  case AtomicRMWKind::maxs:
+    return builder.getIntegerAttr(
+        resultType,
+        APInt::getSignedMinValue(resultType.cast<IntegerType>().getWidth()));
+  case AtomicRMWKind::minf:
+    return builder.getFloatAttr(
+        resultType,
+        APFloat::getInf(resultType.cast<FloatType>().getFloatSemantics(),
+                        /*Negative=*/false));
+  case AtomicRMWKind::mins:
+    return builder.getIntegerAttr(
+        resultType,
+        APInt::getSignedMaxValue(resultType.cast<IntegerType>().getWidth()));
+  case AtomicRMWKind::minu:
+    return builder.getIntegerAttr(
+        resultType,
+        APInt::getMaxValue(resultType.cast<IntegerType>().getWidth()));
+  case AtomicRMWKind::muli:
+    return builder.getIntegerAttr(resultType, 1);
+  case AtomicRMWKind::mulf:
+    return builder.getFloatAttr(resultType, 1);
+  // TODO: Add remaining reduction operations.
+  default:
+    (void)emitOptionalError(loc, "Reduction operation type not supported");
+    break;
+  }
+  return nullptr;
+}
+
+/// Returns the identity value associated with an AtomicRMWKind op.
+Value mlir::arith::getIdentityValue(AtomicRMWKind op, Type resultType,
+                                    OpBuilder &builder, Location loc) {
+  Attribute attr = getIdentityValueAttr(op, resultType, builder, loc);
+  return builder.create<arith::ConstantOp>(loc, attr);
+}
+
+/// Return the value obtained by applying the reduction operation kind
+/// associated with a binary AtomicRMWKind op to `lhs` and `rhs`.
+Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder,
+                                  Location loc, Value lhs, Value rhs) {
+  switch (op) {
+  case AtomicRMWKind::addf:
+    return builder.create<arith::AddFOp>(loc, lhs, rhs);
+  case AtomicRMWKind::addi:
+    return builder.create<arith::AddIOp>(loc, lhs, rhs);
+  case AtomicRMWKind::mulf:
+    return builder.create<arith::MulFOp>(loc, lhs, rhs);
+  case AtomicRMWKind::muli:
+    return builder.create<arith::MulIOp>(loc, lhs, rhs);
+  case AtomicRMWKind::maxf:
+    return builder.create<arith::MaxFOp>(loc, lhs, rhs);
+  case AtomicRMWKind::minf:
+    return builder.create<arith::MinFOp>(loc, lhs, rhs);
+  case AtomicRMWKind::maxs:
+    return builder.create<arith::MaxSIOp>(loc, lhs, rhs);
+  case AtomicRMWKind::mins:
+    return builder.create<arith::MinSIOp>(loc, lhs, rhs);
+  case AtomicRMWKind::maxu:
+    return builder.create<arith::MaxUIOp>(loc, lhs, rhs);
+  case AtomicRMWKind::minu:
+    return builder.create<arith::MinUIOp>(loc, lhs, rhs);
+  case AtomicRMWKind::ori:
+    return builder.create<arith::OrIOp>(loc, lhs, rhs);
+  case AtomicRMWKind::andi:
+    return builder.create<arith::AndIOp>(loc, lhs, rhs);
+  // TODO: Add remaining reduction operations.
+  default:
+    (void)emitOptionalError(loc, "Reduction operation type not supported");
+    break;
+  }
+  return nullptr;
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index ab7e8305ab5b8..45ba726d5bf99 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2286,6 +2286,50 @@ void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context);
 }
 
+//===----------------------------------------------------------------------===//
+// AtomicRMWOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verify(AtomicRMWOp op) {
+  if (op.getMemRefType().getRank() != op.getNumOperands() - 2)
+    return op.emitOpError(
+        "expects the number of subscripts to be equal to memref rank");
+  switch (op.kind()) {
+  case arith::AtomicRMWKind::addf:
+  case arith::AtomicRMWKind::maxf:
+  case arith::AtomicRMWKind::minf:
+  case arith::AtomicRMWKind::mulf:
+    if (!op.value().getType().isa<FloatType>())
+      return op.emitOpError()
+             << "with kind '" << arith::stringifyAtomicRMWKind(op.kind())
+             << "' expects a floating-point type";
+    break;
+  case arith::AtomicRMWKind::addi:
+  case arith::AtomicRMWKind::maxs:
+  case arith::AtomicRMWKind::maxu:
+  case arith::AtomicRMWKind::mins:
+  case arith::AtomicRMWKind::minu:
+  case arith::AtomicRMWKind::muli:
+  case arith::AtomicRMWKind::ori:
+  case arith::AtomicRMWKind::andi:
+    if (!op.value().getType().isa<IntegerType>())
+      return op.emitOpError()
+             << "with kind '" << arith::stringifyAtomicRMWKind(op.kind())
+             << "' expects an integer type";
+    break;
+  default:
+    break;
+  }
+  return success();
+}
+
+OpFoldResult AtomicRMWOp::fold(ArrayRef<Attribute> operands) {
+  /// atomicrmw(memrefcast) -> atomicrmw
+  if (succeeded(foldMemRefCast(*this, value())))
+    return getResult();
+  return OpFoldResult();
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 02d54472baf50..a74b46c034c42 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -131,134 +131,6 @@ LogicalResult AssertOp::canonicalize(AssertOp op, PatternRewriter &rewriter) {
   return failure();
 }
 
-//===----------------------------------------------------------------------===//
-// AtomicRMWOp
-//===----------------------------------------------------------------------===//
-
-static LogicalResult verify(AtomicRMWOp op) {
-  if (op.getMemRefType().getRank() != op.getNumOperands() - 2)
-    return op.emitOpError(
-        "expects the number of subscripts to be equal to memref rank");
-  switch (op.getKind()) {
-  case AtomicRMWKind::addf:
-  case AtomicRMWKind::maxf:
-  case AtomicRMWKind::minf:
-  case AtomicRMWKind::mulf:
-    if (!op.getValue().getType().isa<FloatType>())
-      return op.emitOpError()
-             << "with kind '" << stringifyAtomicRMWKind(op.getKind())
-             << "' expects a floating-point type";
-    break;
-  case AtomicRMWKind::addi:
-  case AtomicRMWKind::maxs:
-  case AtomicRMWKind::maxu:
-  case AtomicRMWKind::mins:
-  case AtomicRMWKind::minu:
-  case AtomicRMWKind::muli:
-  case AtomicRMWKind::ori:
-  case AtomicRMWKind::andi:
-    if (!op.getValue().getType().isa<IntegerType>())
-      return op.emitOpError()
-             << "with kind '" << stringifyAtomicRMWKind(op.getKind())
-             << "' expects an integer type";
-    break;
-  default:
-    break;
-  }
-  return success();
-}
-
-/// Returns the identity value attribute associated with an AtomicRMWKind op.
-Attribute mlir::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
-                                     OpBuilder &builder, Location loc) {
-  switch (kind) {
-  case AtomicRMWKind::maxf:
-    return builder.getFloatAttr(
-        resultType,
-        APFloat::getInf(resultType.cast<FloatType>().getFloatSemantics(),
-                        /*Negative=*/true));
-  case AtomicRMWKind::addf:
-  case AtomicRMWKind::addi:
-  case AtomicRMWKind::maxu:
-  case AtomicRMWKind::ori:
-    return builder.getZeroAttr(resultType);
-  case AtomicRMWKind::andi:
-    return builder.getIntegerAttr(
-        resultType,
-        APInt::getAllOnes(resultType.cast<IntegerType>().getWidth()));
-  case AtomicRMWKind::maxs:
-    return builder.getIntegerAttr(
-        resultType,
-        APInt::getSignedMinValue(resultType.cast<IntegerType>().getWidth()));
-  case AtomicRMWKind::minf:
-    return builder.getFloatAttr(
-        resultType,
-        APFloat::getInf(resultType.cast<FloatType>().getFloatSemantics(),
-                        /*Negative=*/false));
-  case AtomicRMWKind::mins:
-    return builder.getIntegerAttr(
-        resultType,
-        APInt::getSignedMaxValue(resultType.cast<IntegerType>().getWidth()));
-  case AtomicRMWKind::minu:
-    return builder.getIntegerAttr(
-        resultType,
-        APInt::getMaxValue(resultType.cast<IntegerType>().getWidth()));
-  case AtomicRMWKind::muli:
-    return builder.getIntegerAttr(resultType, 1);
-  case AtomicRMWKind::mulf:
-    return builder.getFloatAttr(resultType, 1);
-  // TODO: Add remaining reduction operations.
-  default:
-    (void)emitOptionalError(loc, "Reduction operation type not supported");
-    break;
-  }
-  return nullptr;
-}
-
-/// Returns the identity value associated with an AtomicRMWKind op.
-Value mlir::getIdentityValue(AtomicRMWKind op, Type resultType,
-                             OpBuilder &builder, Location loc) {
-  Attribute attr = getIdentityValueAttr(op, resultType, builder, loc);
-  return builder.create<arith::ConstantOp>(loc, attr);
-}
-
-/// Return the value obtained by applying the reduction operation kind
-/// associated with a binary AtomicRMWKind op to `lhs` and `rhs`.
-Value mlir::getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc,
-                           Value lhs, Value rhs) {
-  switch (op) {
-  case AtomicRMWKind::addf:
-    return builder.create<arith::AddFOp>(loc, lhs, rhs);
-  case AtomicRMWKind::addi:
-    return builder.create<arith::AddIOp>(loc, lhs, rhs);
-  case AtomicRMWKind::mulf:
-    return builder.create<arith::MulFOp>(loc, lhs, rhs);
-  case AtomicRMWKind::muli:
-    return builder.create<arith::MulIOp>(loc, lhs, rhs);
-  case AtomicRMWKind::maxf:
-    return builder.create<arith::MaxFOp>(loc, lhs, rhs);
-  case AtomicRMWKind::minf:
-    return builder.create<arith::MinFOp>(loc, lhs, rhs);
-  case AtomicRMWKind::maxs:
-    return builder.create<arith::MaxSIOp>(loc, lhs, rhs);
-  case AtomicRMWKind::mins:
-    return builder.create<arith::MinSIOp>(loc, lhs, rhs);
-  case AtomicRMWKind::maxu:
-    return builder.create<arith::MaxUIOp>(loc, lhs, rhs);
-  case AtomicRMWKind::minu:
-    return builder.create<arith::MinUIOp>(loc, lhs, rhs);
-  case AtomicRMWKind::ori:
-    return builder.create<arith::OrIOp>(loc, lhs, rhs);
-  case AtomicRMWKind::andi:
-    return builder.create<arith::AndIOp>(loc, lhs, rhs);
-  // TODO: Add remaining reduction operations.
-  default:
-    (void)emitOptionalError(loc, "Reduction operation type not supported");
-    break;
-  }
-  return nullptr;
-}
-
 //===----------------------------------------------------------------------===//
 // GenericAtomicRMWOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/StandardOps/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/StandardOps/Transforms/ExpandOps.cpp
index 71a1a55903c65..a62f6a076f936 100644
--- a/mlir/lib/Dialect/StandardOps/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/StandardOps/Transforms/ExpandOps.cpp
@@ -40,18 +40,18 @@ namespace {
 ///   %new_value = select %cmp, %current, %fval : f32
 ///   atomic_yield %new_value : f32
 /// }
-struct AtomicRMWOpConverter : public OpRewritePattern<AtomicRMWOp> {
+struct AtomicRMWOpConverter : public OpRewritePattern<memref::AtomicRMWOp> {
 public:
   using OpRewritePattern::OpRewritePattern;
 
-  LogicalResult matchAndRewrite(AtomicRMWOp op,
+  LogicalResult matchAndRewrite(memref::AtomicRMWOp op,
                                 PatternRewriter &rewriter) const final {
     arith::CmpFPredicate predicate;
-    switch (op.getKind()) {
-    case AtomicRMWKind::maxf:
+    switch (op.kind()) {
+    case arith::AtomicRMWKind::maxf:
       predicate = arith::CmpFPredicate::OGT;
       break;
-    case AtomicRMWKind::minf:
+    case arith::AtomicRMWKind::minf:
       predicate = arith::CmpFPredicate::OLT;
       break;
     default:
@@ -59,13 +59,13 @@ struct AtomicRMWOpConverter : public OpRewritePattern<AtomicRMWOp> {
     }
 
     auto loc = op.getLoc();
-    auto genericOp = rewriter.create<GenericAtomicRMWOp>(loc, op.getMemref(),
-                                                         op.getIndices());
+    auto genericOp =
+        rewriter.create<GenericAtomicRMWOp>(loc, op.memref(), op.indices());
     OpBuilder bodyBuilder =
         OpBuilder::atBlockEnd(genericOp.getBody(), rewriter.getListener());
 
     Value lhs = genericOp.getCurrentValue();
-    Value rhs = op.getValue();
+    Value rhs = op.value();
     Value cmp = bodyBuilder.create<arith::CmpFOp>(loc, predicate, lhs, rhs);
     Value select = bodyBuilder.create<SelectOp>(loc, cmp, lhs, rhs);
     bodyBuilder.create<AtomicYieldOp>(loc, select);
@@ -130,10 +130,11 @@ struct StdExpandOpsPass : public StdExpandOpsBase<StdExpandOpsPass> {
 
     target.addLegalDialect<arith::ArithmeticDialect, memref::MemRefDialect,
                            StandardOpsDialect>();
-    target.addDynamicallyLegalOp<AtomicRMWOp>([](AtomicRMWOp op) {
-      return op.getKind() != AtomicRMWKind::maxf &&
-             op.getKind() != AtomicRMWKind::minf;
-    });
+    target.addDynamicallyLegalOp<memref::AtomicRMWOp>(
+        [](memref::AtomicRMWOp op) {
+          return op.kind() != arith::AtomicRMWKind::maxf &&
+                 op.kind() != arith::AtomicRMWKind::minf;
+        });
     target.addDynamicallyLegalOp<memref::ReshapeOp>([](memref::ReshapeOp op) {
       return !op.shape().getType().cast<MemRefType>().hasStaticShape();
     });

diff  --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 3a65d0e93dfdb..60c0aac1a4bea 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -359,41 +359,42 @@ static void print(OpAsmPrinter &p, ReductionOp op) {
   p << " : " << op.vector().getType() << " into " << op.dest().getType();
 }
 
-Value mlir::vector::getVectorReductionOp(AtomicRMWKind op, OpBuilder &builder,
-                                         Location loc, Value vector) {
+Value mlir::vector::getVectorReductionOp(arith::AtomicRMWKind op,
+                                         OpBuilder &builder, Location loc,
+                                         Value vector) {
   Type scalarType = vector.getType().cast<ShapedType>().getElementType();
   switch (op) {
-  case AtomicRMWKind::addf:
-  case AtomicRMWKind::addi:
+  case arith::AtomicRMWKind::addf:
+  case arith::AtomicRMWKind::addi:
     return builder.create<vector::ReductionOp>(vector.getLoc(), scalarType,
                                                builder.getStringAttr("add"),
                                                vector, ValueRange{});
-  case AtomicRMWKind::mulf:
-  case AtomicRMWKind::muli:
+  case arith::AtomicRMWKind::mulf:
+  case arith::AtomicRMWKind::muli:
     return builder.create<vector::ReductionOp>(vector.getLoc(), scalarType,
                                                builder.getStringAttr("mul"),
                                                vector, ValueRange{});
-  case AtomicRMWKind::minf:
+  case arith::AtomicRMWKind::minf:
     return builder.create<vector::ReductionOp>(vector.getLoc(), scalarType,
                                                builder.getStringAttr("minf"),
                                                vector, ValueRange{});
-  case AtomicRMWKind::mins:
+  case arith::AtomicRMWKind::mins:
     return builder.create<vector::ReductionOp>(vector.getLoc(), scalarType,
                                                builder.getStringAttr("minsi"),
                                                vector, ValueRange{});
-  case AtomicRMWKind::minu:
+  case arith::AtomicRMWKind::minu:
     return builder.create<vector::ReductionOp>(vector.getLoc(), scalarType,
                                                builder.getStringAttr("minui"),
                                                vector, ValueRange{});
-  case AtomicRMWKind::maxf:
+  case arith::AtomicRMWKind::maxf:
     return builder.create<vector::ReductionOp>(vector.getLoc(), scalarType,
                                                builder.getStringAttr("maxf"),
                                                vector, ValueRange{});
-  case AtomicRMWKind::maxs:
+  case arith::AtomicRMWKind::maxs:
     return builder.create<vector::ReductionOp>(vector.getLoc(), scalarType,
                                                builder.getStringAttr("maxsi"),
                                                vector, ValueRange{});
-  case AtomicRMWKind::maxu:
+  case arith::AtomicRMWKind::maxu:
     return builder.create<vector::ReductionOp>(vector.getLoc(), scalarType,
                                                builder.getStringAttr("maxui"),
                                                vector, ValueRange{});

diff  --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp
index fbb79b5af3f86..91d4a7cd5d195 100644
--- a/mlir/lib/Transforms/Utils/LoopUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp
@@ -1551,7 +1551,7 @@ LogicalResult mlir::loopUnrollJamByFactor(AffineForOp forOp,
       for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
         rhs = forOp.getResult(i * oldNumResults + pos);
         // Create ops based on reduction type.
-        lhs = getReductionOp(reduction.kind, builder, loc, lhs, rhs);
+        lhs = arith::getReductionOp(reduction.kind, builder, loc, lhs, rhs);
         if (!lhs)
           return failure();
         Operation *op = lhs.getDefiningOp();

diff  --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
index 5682c853964c8..70ba47d2d176b 100644
--- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
@@ -859,3 +859,28 @@ func @rank_of_ranked(%ranked: memref<?xi32>) {
 }
 // CHECK: llvm.mlir.constant(1 : index) : i64
 // CHECK32: llvm.mlir.constant(1 : index) : i32
+
+// -----
+
+// CHECK-LABEL: func @atomic_rmw
+func @atomic_rmw(%I : memref<10xi32>, %ival : i32, %F : memref<10xf32>, %fval : f32, %i : index) {
+  memref.atomic_rmw assign %fval, %F[%i] : (f32, memref<10xf32>) -> f32
+  // CHECK: llvm.atomicrmw xchg %{{.*}}, %{{.*}} acq_rel
+  memref.atomic_rmw addi %ival, %I[%i] : (i32, memref<10xi32>) -> i32
+  // CHECK: llvm.atomicrmw add %{{.*}}, %{{.*}} acq_rel
+  memref.atomic_rmw maxs %ival, %I[%i] : (i32, memref<10xi32>) -> i32
+  // CHECK: llvm.atomicrmw max %{{.*}}, %{{.*}} acq_rel
+  memref.atomic_rmw mins %ival, %I[%i] : (i32, memref<10xi32>) -> i32
+  // CHECK: llvm.atomicrmw min %{{.*}}, %{{.*}} acq_rel
+  memref.atomic_rmw maxu %ival, %I[%i] : (i32, memref<10xi32>) -> i32
+  // CHECK: llvm.atomicrmw umax %{{.*}}, %{{.*}} acq_rel
+  memref.atomic_rmw minu %ival, %I[%i] : (i32, memref<10xi32>) -> i32
+  // CHECK: llvm.atomicrmw umin %{{.*}}, %{{.*}} acq_rel
+  memref.atomic_rmw addf %fval, %F[%i] : (f32, memref<10xf32>) -> f32
+  // CHECK: llvm.atomicrmw fadd %{{.*}}, %{{.*}} acq_rel
+  memref.atomic_rmw ori %ival, %I[%i] : (i32, memref<10xi32>) -> i32
+  // CHECK: llvm.atomicrmw _or %{{.*}}, %{{.*}} acq_rel
+  memref.atomic_rmw andi %ival, %I[%i] : (i32, memref<10xi32>) -> i32
+  // CHECK: llvm.atomicrmw _and %{{.*}}, %{{.*}} acq_rel
+  return
+}

diff  --git a/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir b/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
index c3282e1903d6f..0dc6bf10dc5ea 100644
--- a/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
@@ -486,31 +486,6 @@ func @splat(%a: vector<4xf32>, %b: f32) -> vector<4xf32> {
 
 // -----
 
-// CHECK-LABEL: func @atomic_rmw
-func @atomic_rmw(%I : memref<10xi32>, %ival : i32, %F : memref<10xf32>, %fval : f32, %i : index) {
-  atomic_rmw assign %fval, %F[%i] : (f32, memref<10xf32>) -> f32
-  // CHECK: llvm.atomicrmw xchg %{{.*}}, %{{.*}} acq_rel
-  atomic_rmw addi %ival, %I[%i] : (i32, memref<10xi32>) -> i32
-  // CHECK: llvm.atomicrmw add %{{.*}}, %{{.*}} acq_rel
-  atomic_rmw maxs %ival, %I[%i] : (i32, memref<10xi32>) -> i32
-  // CHECK: llvm.atomicrmw max %{{.*}}, %{{.*}} acq_rel
-  atomic_rmw mins %ival, %I[%i] : (i32, memref<10xi32>) -> i32
-  // CHECK: llvm.atomicrmw min %{{.*}}, %{{.*}} acq_rel
-  atomic_rmw maxu %ival, %I[%i] : (i32, memref<10xi32>) -> i32
-  // CHECK: llvm.atomicrmw umax %{{.*}}, %{{.*}} acq_rel
-  atomic_rmw minu %ival, %I[%i] : (i32, memref<10xi32>) -> i32
-  // CHECK: llvm.atomicrmw umin %{{.*}}, %{{.*}} acq_rel
-  atomic_rmw addf %fval, %F[%i] : (f32, memref<10xf32>) -> f32
-  // CHECK: llvm.atomicrmw fadd %{{.*}}, %{{.*}} acq_rel
-  atomic_rmw ori %ival, %I[%i] : (i32, memref<10xi32>) -> i32
-  // CHECK: llvm.atomicrmw _or %{{.*}}, %{{.*}} acq_rel
-  atomic_rmw andi %ival, %I[%i] : (i32, memref<10xi32>) -> i32
-  // CHECK: llvm.atomicrmw _and %{{.*}}, %{{.*}} acq_rel
-  return
-}
-
-// -----
-
 // CHECK-LABEL: func @generic_atomic_rmw
 func @generic_atomic_rmw(%I : memref<10xi32>, %i : index) -> i32 {
   %x = generic_atomic_rmw %I[%i] : memref<10xi32> {

diff  --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 39f9847f4c9e2..2e81705049f54 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -499,3 +499,14 @@ func @no_fold_dynamic_no_op_subview(%arg0 : memref<?x?xf32>) -> memref<?x?xf32,
 // CHECK-LABEL: func @no_fold_dynamic_no_op_subview(
 //       CHECK:   %[[SUBVIEW:.+]] = memref.subview
 //       CHECK:    return %[[SUBVIEW]]
+
+// -----
+
+func @atomicrmw_cast_fold(%arg0 : f32, %arg1 : memref<4xf32>, %c : index) {
+  %v = memref.cast %arg1 : memref<4xf32> to memref<?xf32>
+  %a = memref.atomic_rmw addf %arg0, %v[%c] : (f32, memref<?xf32>) -> f32
+  return
+}
+
+// CHECK-LABEL: func @atomicrmw_cast_fold
+// CHECK-NEXT: memref.atomic_rmw addf %arg0, %arg1[%arg2] : (f32, memref<4xf32>) -> f32

diff  --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index 5cf32703c9ebb..90f851959748c 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -848,3 +848,27 @@ func @illegal_num_offsets(%arg0 : memref<?x?x?xf32>, %arg1 : index, %arg2 : inde
   // expected-error@+1 {{expected 3 offset values}}
   %0 = memref.subview %arg0[0, 0] [%arg1, %arg2] [1, 1] : memref<?x?x?xf32> to memref<?x?x?xf32, #map>
 }
+
+// -----
+
+func @atomic_rmw_idxs_rank_mismatch(%I: memref<16x10xf32>, %i : index, %val : f32) {
+  // expected-error@+1 {{expects the number of subscripts to be equal to memref rank}}
+  %x = memref.atomic_rmw addf %val, %I[%i] : (f32, memref<16x10xf32>) -> f32
+  return
+}
+
+// -----
+
+func @atomic_rmw_expects_float(%I: memref<16x10xi32>, %i : index, %val : i32) {
+  // expected-error@+1 {{expects a floating-point type}}
+  %x = memref.atomic_rmw addf %val, %I[%i, %i] : (i32, memref<16x10xi32>) -> i32
+  return
+}
+
+// -----
+
+func @atomic_rmw_expects_int(%I: memref<16x10xf32>, %i : index, %val : f32) {
+  // expected-error@+1 {{expects an integer type}}
+  %x = memref.atomic_rmw addi %val, %I[%i, %i] : (f32, memref<16x10xf32>) -> f32
+  return
+}

diff  --git a/mlir/test/Dialect/MemRef/ops.mlir b/mlir/test/Dialect/MemRef/ops.mlir
index 963c817af3981..71b6038a2f9d2 100644
--- a/mlir/test/Dialect/MemRef/ops.mlir
+++ b/mlir/test/Dialect/MemRef/ops.mlir
@@ -227,3 +227,13 @@ func @rank(%t : memref<4x4x?xf32>) {
   %1 = memref.rank %t : memref<4x4x?xf32>
   return
 }
+
+// ------
+
+// CHECK-LABEL: func @atomic_rmw
+// CHECK-SAME: ([[BUF:%.*]]: memref<10xf32>, [[VAL:%.*]]: f32, [[I:%.*]]: index)
+func @atomic_rmw(%I: memref<10xf32>, %val: f32, %i : index) {
+  %x = memref.atomic_rmw addf %val, %I[%i] : (f32, memref<10xf32>) -> f32
+  // CHECK: memref.atomic_rmw addf [[VAL]], [[BUF]]{{\[}}[[I]]]
+  return
+}

diff  --git a/mlir/test/Dialect/Standard/expand-ops.mlir b/mlir/test/Dialect/Standard/expand-ops.mlir
index 45659aee0763d..cb650ffd11bd5 100644
--- a/mlir/test/Dialect/Standard/expand-ops.mlir
+++ b/mlir/test/Dialect/Standard/expand-ops.mlir
@@ -3,7 +3,7 @@
 // CHECK-LABEL: func @atomic_rmw_to_generic
 // CHECK-SAME: ([[F:%.*]]: memref<10xf32>, [[f:%.*]]: f32, [[i:%.*]]: index)
 func @atomic_rmw_to_generic(%F: memref<10xf32>, %f: f32, %i: index) -> f32 {
-  %x = atomic_rmw maxf %f, %F[%i] : (f32, memref<10xf32>) -> f32
+  %x = memref.atomic_rmw maxf %f, %F[%i] : (f32, memref<10xf32>) -> f32
   return %x : f32
 }
 // CHECK: %0 = generic_atomic_rmw %arg0[%arg2] : memref<10xf32> {
@@ -18,7 +18,7 @@ func @atomic_rmw_to_generic(%F: memref<10xf32>, %f: f32, %i: index) -> f32 {
 
 // CHECK-LABEL: func @atomic_rmw_no_conversion
 func @atomic_rmw_no_conversion(%F: memref<10xf32>, %f: f32, %i: index) -> f32 {
-  %x = atomic_rmw addf %f, %F[%i] : (f32, memref<10xf32>) -> f32
+  %x = memref.atomic_rmw addf %f, %F[%i] : (f32, memref<10xf32>) -> f32
   return %x : f32
 }
 // CHECK-NOT: generic_atomic_rmw

diff  --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir
index b83f530eeacc6..351e8a6b39c1b 100644
--- a/mlir/test/IR/core-ops.mlir
+++ b/mlir/test/IR/core-ops.mlir
@@ -325,14 +325,6 @@ func @unranked_tensor_load_store(%0 : memref<*xi32>, %1 : tensor<*xi32>) {
   return
 }
 
-// CHECK-LABEL: func @atomic_rmw
-// CHECK-SAME: ([[BUF:%.*]]: memref<10xf32>, [[VAL:%.*]]: f32, [[I:%.*]]: index)
-func @atomic_rmw(%I: memref<10xf32>, %val: f32, %i : index) {
-  %x = atomic_rmw addf %val, %I[%i] : (f32, memref<10xf32>) -> f32
-  // CHECK: atomic_rmw addf [[VAL]], [[BUF]]{{\[}}[[I]]]
-  return
-}
-
 // CHECK-LABEL: func @generic_atomic_rmw
 // CHECK-SAME: ([[BUF:%.*]]: memref<1x2xf32>, [[I:%.*]]: index, [[J:%.*]]: index)
 func @generic_atomic_rmw(%I: memref<1x2xf32>, %i : index, %j : index) {

diff  --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir
index 49f29f09bf492..2aae390af4d02 100644
--- a/mlir/test/IR/invalid-ops.mlir
+++ b/mlir/test/IR/invalid-ops.mlir
@@ -130,30 +130,6 @@ func @invalid_splat(%v : f32) { // expected-note {{prior use here}}
 
 // -----
 
-func @atomic_rmw_idxs_rank_mismatch(%I: memref<16x10xf32>, %i : index, %val : f32) {
-  // expected-error@+1 {{expects the number of subscripts to be equal to memref rank}}
-  %x = atomic_rmw addf %val, %I[%i] : (f32, memref<16x10xf32>) -> f32
-  return
-}
-
-// -----
-
-func @atomic_rmw_expects_float(%I: memref<16x10xi32>, %i : index, %val : i32) {
-  // expected-error@+1 {{expects a floating-point type}}
-  %x = atomic_rmw addf %val, %I[%i, %i] : (i32, memref<16x10xi32>) -> i32
-  return
-}
-
-// -----
-
-func @atomic_rmw_expects_int(%I: memref<16x10xf32>, %i : index, %val : f32) {
-  // expected-error@+1 {{expects an integer type}}
-  %x = atomic_rmw addi %val, %I[%i, %i] : (f32, memref<16x10xf32>) -> f32
-  return
-}
-
-// -----
-
 func @generic_atomic_rmw_wrong_arg_num(%I: memref<10xf32>, %i : index) {
   // expected-error@+1 {{expected single number of entry block arguments}}
   %x = generic_atomic_rmw %I[%i] : memref<10xf32> {


        


More information about the Mlir-commits mailing list