[Mlir-commits] [mlir] [Arith][MemRef] add AtomicRMWKind::xori to enum (PR #151701)
Scott Manley
llvmlistbot at llvm.org
Mon Aug 11 05:21:41 PDT 2025
https://github.com/rscottmanley updated https://github.com/llvm/llvm-project/pull/151701
>From cd2b21834916ff0cf465ec4f84d16265bae32393 Mon Sep 17 00:00:00 2001
From: Scott Manley <scmanley at nvidia.com>
Date: Fri, 1 Aug 2025 06:19:33 -0700
Subject: [PATCH] [Arith][MemRef] add AtomicRMWKind::xori to enum
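Add the missing `xori` case to the AtomicRMWKind enum in arith, and handle it
in getIdentityValueAttr and getReductionOp. getNeutralElement now maps
arith.xori to the new `xori` kind (it previously mapped it to `ori`).
Also add support for `xori` to memref.atomic_rmw (verifier and LLVM
lowering to `llvm.atomicrmw _xor`) so the change can be tested.
Note that the existing enum cases are reordered alphabetically, which
changes the integer values assigned to most existing cases.
This does NOT add it for all users of the enum (e.g. Affine, Vector).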
---
.../mlir/Dialect/Arith/IR/ArithBase.td | 36 ++++++++++---------
.../Conversion/MemRefToLLVM/MemRefToLLVM.cpp | 2 ++
mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 5 ++-
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 1 +
.../MemRefToLLVM/memref-to-llvm.mlir | 4 ++-
5 files changed, 29 insertions(+), 19 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td b/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
index 19a2ade2e95a0..e51c20498746b 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
@@ -76,27 +76,29 @@ def Arith_CmpIPredicateAttr : I64EnumAttr<
def ATOMIC_RMW_KIND_ADDF : I64EnumAttrCase<"addf", 0>;
def ATOMIC_RMW_KIND_ADDI : I64EnumAttrCase<"addi", 1>;
-def ATOMIC_RMW_KIND_ASSIGN : I64EnumAttrCase<"assign", 2>;
-def ATOMIC_RMW_KIND_MAXIMUMF : I64EnumAttrCase<"maximumf", 3>;
-def ATOMIC_RMW_KIND_MAXS : I64EnumAttrCase<"maxs", 4>;
-def ATOMIC_RMW_KIND_MAXU : I64EnumAttrCase<"maxu", 5>;
-def ATOMIC_RMW_KIND_MINIMUMF : I64EnumAttrCase<"minimumf", 6>;
-def ATOMIC_RMW_KIND_MINS : I64EnumAttrCase<"mins", 7>;
-def ATOMIC_RMW_KIND_MINU : I64EnumAttrCase<"minu", 8>;
-def ATOMIC_RMW_KIND_MULF : I64EnumAttrCase<"mulf", 9>;
-def ATOMIC_RMW_KIND_MULI : I64EnumAttrCase<"muli", 10>;
-def ATOMIC_RMW_KIND_ORI : I64EnumAttrCase<"ori", 11>;
-def ATOMIC_RMW_KIND_ANDI : I64EnumAttrCase<"andi", 12>;
-def ATOMIC_RMW_KIND_MAXNUMF : I64EnumAttrCase<"maxnumf", 13>;
-def ATOMIC_RMW_KIND_MINNUMF : I64EnumAttrCase<"minnumf", 14>;
+def ATOMIC_RMW_KIND_ANDI : I64EnumAttrCase<"andi", 2>;
+def ATOMIC_RMW_KIND_ASSIGN : I64EnumAttrCase<"assign", 3>;
+def ATOMIC_RMW_KIND_MAXIMUMF : I64EnumAttrCase<"maximumf", 4>;
+def ATOMIC_RMW_KIND_MAXNUMF : I64EnumAttrCase<"maxnumf", 5>;
+def ATOMIC_RMW_KIND_MAXS : I64EnumAttrCase<"maxs", 6>;
+def ATOMIC_RMW_KIND_MAXU : I64EnumAttrCase<"maxu", 7>;
+def ATOMIC_RMW_KIND_MINIMUMF : I64EnumAttrCase<"minimumf", 8>;
+def ATOMIC_RMW_KIND_MINNUMF : I64EnumAttrCase<"minnumf", 9>;
+def ATOMIC_RMW_KIND_MINS : I64EnumAttrCase<"mins", 10>;
+def ATOMIC_RMW_KIND_MINU : I64EnumAttrCase<"minu", 11>;
+def ATOMIC_RMW_KIND_MULF : I64EnumAttrCase<"mulf", 12>;
+def ATOMIC_RMW_KIND_MULI : I64EnumAttrCase<"muli", 13>;
+def ATOMIC_RMW_KIND_ORI : I64EnumAttrCase<"ori", 14>;
+def ATOMIC_RMW_KIND_XORI : I64EnumAttrCase<"xori", 15>;
def AtomicRMWKindAttr : I64EnumAttr<
"AtomicRMWKind", "",
- [ATOMIC_RMW_KIND_ADDF, ATOMIC_RMW_KIND_ADDI, ATOMIC_RMW_KIND_ASSIGN,
- ATOMIC_RMW_KIND_MAXIMUMF, ATOMIC_RMW_KIND_MAXS, ATOMIC_RMW_KIND_MAXU,
- ATOMIC_RMW_KIND_MINIMUMF, ATOMIC_RMW_KIND_MINS, ATOMIC_RMW_KIND_MINU,
+ [ATOMIC_RMW_KIND_ADDF, ATOMIC_RMW_KIND_ADDI, ATOMIC_RMW_KIND_ANDI,
+ ATOMIC_RMW_KIND_ASSIGN, ATOMIC_RMW_KIND_MAXIMUMF, ATOMIC_RMW_KIND_MAXNUMF,
+ ATOMIC_RMW_KIND_MAXS, ATOMIC_RMW_KIND_MAXU, ATOMIC_RMW_KIND_MINIMUMF,
+ ATOMIC_RMW_KIND_MINNUMF, ATOMIC_RMW_KIND_MINS, ATOMIC_RMW_KIND_MINU,
ATOMIC_RMW_KIND_MULF, ATOMIC_RMW_KIND_MULI, ATOMIC_RMW_KIND_ORI,
- ATOMIC_RMW_KIND_ANDI, ATOMIC_RMW_KIND_MAXNUMF, ATOMIC_RMW_KIND_MINNUMF]> {
+ ATOMIC_RMW_KIND_XORI]> {
let cppNamespace = "::mlir::arith";
}
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index dc2035b0700d0..a45fde4b85ab8 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -1871,6 +1871,8 @@ matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) {
return LLVM::AtomicBinOp::umin;
case arith::AtomicRMWKind::ori:
return LLVM::AtomicBinOp::_or;
+ case arith::AtomicRMWKind::xori:
+ return LLVM::AtomicBinOp::_xor;
case arith::AtomicRMWKind::andi:
return LLVM::AtomicBinOp::_and;
default:
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 488c3c369afed..7d4d818ee448b 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -2678,6 +2678,7 @@ TypedAttr mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
case AtomicRMWKind::addi:
case AtomicRMWKind::maxu:
case AtomicRMWKind::ori:
+ case AtomicRMWKind::xori:
return builder.getZeroAttr(resultType);
case AtomicRMWKind::andi:
return builder.getIntegerAttr(
@@ -2736,7 +2737,7 @@ std::optional<TypedAttr> mlir::arith::getNeutralElement(Operation *op) {
// Integer operations.
.Case([](arith::AddIOp op) { return AtomicRMWKind::addi; })
.Case([](arith::OrIOp op) { return AtomicRMWKind::ori; })
- .Case([](arith::XOrIOp op) { return AtomicRMWKind::ori; })
+ .Case([](arith::XOrIOp op) { return AtomicRMWKind::xori; })
.Case([](arith::AndIOp op) { return AtomicRMWKind::andi; })
.Case([](arith::MaxUIOp op) { return AtomicRMWKind::maxu; })
.Case([](arith::MinUIOp op) { return AtomicRMWKind::minu; })
@@ -2806,6 +2807,8 @@ Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder,
return arith::OrIOp::create(builder, loc, lhs, rhs);
case AtomicRMWKind::andi:
return arith::AndIOp::create(builder, loc, lhs, rhs);
+ case AtomicRMWKind::xori:
+ return arith::XOrIOp::create(builder, loc, lhs, rhs);
// TODO: Add remaining reduction operations.
default:
(void)emitOptionalError(loc, "Reduction operation type not supported");
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 74b968c27a62a..b59d73d1291c8 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -3558,6 +3558,7 @@ LogicalResult AtomicRMWOp::verify() {
case arith::AtomicRMWKind::minu:
case arith::AtomicRMWKind::muli:
case arith::AtomicRMWKind::ori:
+ case arith::AtomicRMWKind::xori:
case arith::AtomicRMWKind::andi:
if (!llvm::isa<IntegerType>(getValue().getType()))
return emitOpError() << "with kind '"
diff --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
index 51d56389dac9e..12a46daffcb9e 100644
--- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
@@ -464,7 +464,9 @@ func.func @atomic_rmw(%I : memref<10xi32>, %ival : i32, %F : memref<10xf32>, %fv
// CHECK: llvm.atomicrmw _or %{{.*}}, %{{.*}} acq_rel
memref.atomic_rmw andi %ival, %I[%i] : (i32, memref<10xi32>) -> i32
// CHECK: llvm.atomicrmw _and %{{.*}}, %{{.*}} acq_rel
- // CHECK-INTERFACE-COUNT-13: llvm.atomicrmw
+ memref.atomic_rmw xori %ival, %I[%i] : (i32, memref<10xi32>) -> i32
+ // CHECK: llvm.atomicrmw _xor %{{.*}}, %{{.*}} acq_rel
+ // CHECK-INTERFACE-COUNT-14: llvm.atomicrmw
return
}
More information about the Mlir-commits
mailing list