[Mlir-commits] [mlir] [Arith][MemRef] add AtomicRMWKind::xori to enum (PR #151701)

Scott Manley llvmlistbot at llvm.org
Mon Aug 11 05:21:41 PDT 2025


https://github.com/rscottmanley updated https://github.com/llvm/llvm-project/pull/151701

>From cd2b21834916ff0cf465ec4f84d16265bae32393 Mon Sep 17 00:00:00 2001
From: Scott Manley <scmanley at nvidia.com>
Date: Fri, 1 Aug 2025 06:19:33 -0700
Subject: [PATCH] [Arith][MemRef] add AtomicRMWKind::xori to enum

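Add the missing xori case to the AtomicRMWKind enum in arith. Also add
support for xori to memref.atomic_rmw so the change can be tested.

The enum cases are re-sorted alphabetically, which renumbers their
integer values, and getNeutralElement now maps arith.xori to
AtomicRMWKind::xori instead of AtomicRMWKind::ori.

This does NOT add it for all users of the enum (e.g. Affine, Vector).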
---
 .../mlir/Dialect/Arith/IR/ArithBase.td        | 36 ++++++++++---------
 .../Conversion/MemRefToLLVM/MemRefToLLVM.cpp  |  2 ++
 mlir/lib/Dialect/Arith/IR/ArithOps.cpp        |  5 ++-
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp      |  1 +
 .../MemRefToLLVM/memref-to-llvm.mlir          |  4 ++-
 5 files changed, 29 insertions(+), 19 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td b/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
index 19a2ade2e95a0..e51c20498746b 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
@@ -76,27 +76,29 @@ def Arith_CmpIPredicateAttr : I64EnumAttr<
 
 def ATOMIC_RMW_KIND_ADDF     : I64EnumAttrCase<"addf", 0>;
 def ATOMIC_RMW_KIND_ADDI     : I64EnumAttrCase<"addi", 1>;
-def ATOMIC_RMW_KIND_ASSIGN   : I64EnumAttrCase<"assign", 2>;
-def ATOMIC_RMW_KIND_MAXIMUMF : I64EnumAttrCase<"maximumf", 3>;
-def ATOMIC_RMW_KIND_MAXS     : I64EnumAttrCase<"maxs", 4>;
-def ATOMIC_RMW_KIND_MAXU     : I64EnumAttrCase<"maxu", 5>;
-def ATOMIC_RMW_KIND_MINIMUMF : I64EnumAttrCase<"minimumf", 6>;
-def ATOMIC_RMW_KIND_MINS     : I64EnumAttrCase<"mins", 7>;
-def ATOMIC_RMW_KIND_MINU     : I64EnumAttrCase<"minu", 8>;
-def ATOMIC_RMW_KIND_MULF     : I64EnumAttrCase<"mulf", 9>;
-def ATOMIC_RMW_KIND_MULI     : I64EnumAttrCase<"muli", 10>;
-def ATOMIC_RMW_KIND_ORI      : I64EnumAttrCase<"ori", 11>;
-def ATOMIC_RMW_KIND_ANDI     : I64EnumAttrCase<"andi", 12>;
-def ATOMIC_RMW_KIND_MAXNUMF  : I64EnumAttrCase<"maxnumf", 13>;
-def ATOMIC_RMW_KIND_MINNUMF  : I64EnumAttrCase<"minnumf", 14>;
+def ATOMIC_RMW_KIND_ANDI     : I64EnumAttrCase<"andi", 2>;
+def ATOMIC_RMW_KIND_ASSIGN   : I64EnumAttrCase<"assign", 3>;
+def ATOMIC_RMW_KIND_MAXIMUMF : I64EnumAttrCase<"maximumf", 4>;
+def ATOMIC_RMW_KIND_MAXNUMF  : I64EnumAttrCase<"maxnumf", 5>;
+def ATOMIC_RMW_KIND_MAXS     : I64EnumAttrCase<"maxs", 6>;
+def ATOMIC_RMW_KIND_MAXU     : I64EnumAttrCase<"maxu", 7>;
+def ATOMIC_RMW_KIND_MINIMUMF : I64EnumAttrCase<"minimumf", 8>;
+def ATOMIC_RMW_KIND_MINNUMF  : I64EnumAttrCase<"minnumf", 9>;
+def ATOMIC_RMW_KIND_MINS     : I64EnumAttrCase<"mins", 10>;
+def ATOMIC_RMW_KIND_MINU     : I64EnumAttrCase<"minu", 11>;
+def ATOMIC_RMW_KIND_MULF     : I64EnumAttrCase<"mulf", 12>;
+def ATOMIC_RMW_KIND_MULI     : I64EnumAttrCase<"muli", 13>;
+def ATOMIC_RMW_KIND_ORI      : I64EnumAttrCase<"ori", 14>;
+def ATOMIC_RMW_KIND_XORI     : I64EnumAttrCase<"xori", 15>;
 
 def AtomicRMWKindAttr : I64EnumAttr<
     "AtomicRMWKind", "",
-    [ATOMIC_RMW_KIND_ADDF, ATOMIC_RMW_KIND_ADDI, ATOMIC_RMW_KIND_ASSIGN,
-     ATOMIC_RMW_KIND_MAXIMUMF, ATOMIC_RMW_KIND_MAXS, ATOMIC_RMW_KIND_MAXU,
-     ATOMIC_RMW_KIND_MINIMUMF, ATOMIC_RMW_KIND_MINS, ATOMIC_RMW_KIND_MINU,
+    [ATOMIC_RMW_KIND_ADDF, ATOMIC_RMW_KIND_ADDI, ATOMIC_RMW_KIND_ANDI, 
+     ATOMIC_RMW_KIND_ASSIGN, ATOMIC_RMW_KIND_MAXIMUMF, ATOMIC_RMW_KIND_MAXNUMF,
+     ATOMIC_RMW_KIND_MAXS, ATOMIC_RMW_KIND_MAXU, ATOMIC_RMW_KIND_MINIMUMF, 
+     ATOMIC_RMW_KIND_MINNUMF, ATOMIC_RMW_KIND_MINS, ATOMIC_RMW_KIND_MINU,
      ATOMIC_RMW_KIND_MULF, ATOMIC_RMW_KIND_MULI, ATOMIC_RMW_KIND_ORI,
-     ATOMIC_RMW_KIND_ANDI, ATOMIC_RMW_KIND_MAXNUMF, ATOMIC_RMW_KIND_MINNUMF]> {
+     ATOMIC_RMW_KIND_XORI]> {
   let cppNamespace = "::mlir::arith";
 }
 
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index dc2035b0700d0..a45fde4b85ab8 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -1871,6 +1871,8 @@ matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) {
     return LLVM::AtomicBinOp::umin;
   case arith::AtomicRMWKind::ori:
     return LLVM::AtomicBinOp::_or;
+  case arith::AtomicRMWKind::xori:
+    return LLVM::AtomicBinOp::_xor;
   case arith::AtomicRMWKind::andi:
     return LLVM::AtomicBinOp::_and;
   default:
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 488c3c369afed..7d4d818ee448b 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -2678,6 +2678,7 @@ TypedAttr mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
   case AtomicRMWKind::addi:
   case AtomicRMWKind::maxu:
   case AtomicRMWKind::ori:
+  case AtomicRMWKind::xori:
     return builder.getZeroAttr(resultType);
   case AtomicRMWKind::andi:
     return builder.getIntegerAttr(
@@ -2736,7 +2737,7 @@ std::optional<TypedAttr> mlir::arith::getNeutralElement(Operation *op) {
           // Integer operations.
           .Case([](arith::AddIOp op) { return AtomicRMWKind::addi; })
           .Case([](arith::OrIOp op) { return AtomicRMWKind::ori; })
-          .Case([](arith::XOrIOp op) { return AtomicRMWKind::ori; })
+          .Case([](arith::XOrIOp op) { return AtomicRMWKind::xori; })
           .Case([](arith::AndIOp op) { return AtomicRMWKind::andi; })
           .Case([](arith::MaxUIOp op) { return AtomicRMWKind::maxu; })
           .Case([](arith::MinUIOp op) { return AtomicRMWKind::minu; })
@@ -2806,6 +2807,8 @@ Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder,
     return arith::OrIOp::create(builder, loc, lhs, rhs);
   case AtomicRMWKind::andi:
     return arith::AndIOp::create(builder, loc, lhs, rhs);
+  case AtomicRMWKind::xori:
+    return arith::XOrIOp::create(builder, loc, lhs, rhs);
   // TODO: Add remaining reduction operations.
   default:
     (void)emitOptionalError(loc, "Reduction operation type not supported");
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 74b968c27a62a..b59d73d1291c8 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -3558,6 +3558,7 @@ LogicalResult AtomicRMWOp::verify() {
   case arith::AtomicRMWKind::minu:
   case arith::AtomicRMWKind::muli:
   case arith::AtomicRMWKind::ori:
+  case arith::AtomicRMWKind::xori:
   case arith::AtomicRMWKind::andi:
     if (!llvm::isa<IntegerType>(getValue().getType()))
       return emitOpError() << "with kind '"
diff --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
index 51d56389dac9e..12a46daffcb9e 100644
--- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
@@ -464,7 +464,9 @@ func.func @atomic_rmw(%I : memref<10xi32>, %ival : i32, %F : memref<10xf32>, %fv
   // CHECK: llvm.atomicrmw _or %{{.*}}, %{{.*}} acq_rel
   memref.atomic_rmw andi %ival, %I[%i] : (i32, memref<10xi32>) -> i32
   // CHECK: llvm.atomicrmw _and %{{.*}}, %{{.*}} acq_rel
-  // CHECK-INTERFACE-COUNT-13: llvm.atomicrmw
+  memref.atomic_rmw xori %ival, %I[%i] : (i32, memref<10xi32>) -> i32
+  // CHECK: llvm.atomicrmw _xor %{{.*}}, %{{.*}} acq_rel
+  // CHECK-INTERFACE-COUNT-14: llvm.atomicrmw
   return
 }
 



More information about the Mlir-commits mailing list