[Mlir-commits] [mlir] [mlir] Add `maxnumf` and `minnumf` to `AtomicRMWKind` (PR #66442)
Daniil Dudkin
llvmlistbot at llvm.org
Thu Sep 14 15:31:57 PDT 2023
https://github.com/unterumarmung created https://github.com/llvm/llvm-project/pull/66442:
This commit adds the `maxnumf` and `minnumf` kinds to `AtomicRMWKind`,
as well as code generation for them.
From 3d680710f6ea0680895024337820a367d4f38fe8 Mon Sep 17 00:00:00 2001
From: Daniil Dudkin <unterumarmung at yandex.ru>
Date: Fri, 15 Sep 2023 01:00:22 +0300
Subject: [PATCH] [mlir] Add `maxnumf` and `minnumf` to `AtomicRMWKind`
This commit adds the `maxnumf` and `minnumf` kinds to `AtomicRMWKind`,
as well as code generation for them.
---
.../include/mlir/Dialect/Arith/IR/ArithBase.td | 4 +++-
mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 4 ++++
.../Dialect/MemRef/Transforms/ExpandOps.cpp | 7 +++++--
mlir/test/Dialect/MemRef/expand-ops.mlir | 18 +++++++++++++++---
4 files changed, 27 insertions(+), 6 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td b/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
index a833e9c8220af5b..133af893e4efa74 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
@@ -82,6 +82,8 @@ def ATOMIC_RMW_KIND_MULF : I64EnumAttrCase<"mulf", 9>;
def ATOMIC_RMW_KIND_MULI : I64EnumAttrCase<"muli", 10>;
def ATOMIC_RMW_KIND_ORI : I64EnumAttrCase<"ori", 11>;
def ATOMIC_RMW_KIND_ANDI : I64EnumAttrCase<"andi", 12>;
+def ATOMIC_RMW_KIND_MAXNUMF : I64EnumAttrCase<"maxnumf", 13>;
+def ATOMIC_RMW_KIND_MINNUMF : I64EnumAttrCase<"minnumf", 14>;
def AtomicRMWKindAttr : I64EnumAttr<
"AtomicRMWKind", "",
@@ -89,7 +91,7 @@ def AtomicRMWKindAttr : I64EnumAttr<
ATOMIC_RMW_KIND_MAXIMUMF, ATOMIC_RMW_KIND_MAXS, ATOMIC_RMW_KIND_MAXU,
ATOMIC_RMW_KIND_MINIMUMF, ATOMIC_RMW_KIND_MINS, ATOMIC_RMW_KIND_MINU,
ATOMIC_RMW_KIND_MULF, ATOMIC_RMW_KIND_MULI, ATOMIC_RMW_KIND_ORI,
- ATOMIC_RMW_KIND_ANDI]> {
+ ATOMIC_RMW_KIND_ANDI, ATOMIC_RMW_KIND_MAXNUMF, ATOMIC_RMW_KIND_MINNUMF]> {
let cppNamespace = "::mlir::arith";
}
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index d39c5b6051122e4..ae8a6ef350ce191 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -2523,6 +2523,10 @@ Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder,
return builder.create<arith::MaximumFOp>(loc, lhs, rhs);
case AtomicRMWKind::minimumf:
return builder.create<arith::MinimumFOp>(loc, lhs, rhs);
+ case AtomicRMWKind::maxnumf:
+ return builder.create<arith::MaxNumFOp>(loc, lhs, rhs);
+ case AtomicRMWKind::minnumf:
+ return builder.create<arith::MinNumFOp>(loc, lhs, rhs);
case AtomicRMWKind::maxs:
return builder.create<arith::MaxSIOp>(loc, lhs, rhs);
case AtomicRMWKind::mins:
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp
index b3beaada2539dbc..faba12f5bf82f89 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp
@@ -20,6 +20,7 @@
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/STLExtras.h"
namespace mlir {
namespace memref {
@@ -126,8 +127,10 @@ struct ExpandOpsPass : public memref::impl::ExpandOpsBase<ExpandOpsPass> {
target.addLegalDialect<arith::ArithDialect, memref::MemRefDialect>();
target.addDynamicallyLegalOp<memref::AtomicRMWOp>(
[](memref::AtomicRMWOp op) {
- return op.getKind() != arith::AtomicRMWKind::maximumf &&
- op.getKind() != arith::AtomicRMWKind::minimumf;
+ constexpr std::array shouldBeExpandedKinds = {
+ arith::AtomicRMWKind::maximumf, arith::AtomicRMWKind::minimumf,
+ arith::AtomicRMWKind::minnumf, arith::AtomicRMWKind::maxnumf};
+ return !llvm::is_contained(shouldBeExpandedKinds, op.getKind());
});
target.addDynamicallyLegalOp<memref::ReshapeOp>([](memref::ReshapeOp op) {
return !cast<MemRefType>(op.getShape().getType()).hasStaticShape();
diff --git a/mlir/test/Dialect/MemRef/expand-ops.mlir b/mlir/test/Dialect/MemRef/expand-ops.mlir
index 6c98cf978505334..f958a92b751a4ab 100644
--- a/mlir/test/Dialect/MemRef/expand-ops.mlir
+++ b/mlir/test/Dialect/MemRef/expand-ops.mlir
@@ -3,9 +3,11 @@
// CHECK-LABEL: func @atomic_rmw_to_generic
// CHECK-SAME: ([[F:%.*]]: memref<10xf32>, [[f:%.*]]: f32, [[i:%.*]]: index)
func.func @atomic_rmw_to_generic(%F: memref<10xf32>, %f: f32, %i: index) -> f32 {
- %x = memref.atomic_rmw maximumf %f, %F[%i] : (f32, memref<10xf32>) -> f32
- %y = memref.atomic_rmw minimumf %f, %F[%i] : (f32, memref<10xf32>) -> f32
- return %x : f32
+ %a = memref.atomic_rmw maximumf %f, %F[%i] : (f32, memref<10xf32>) -> f32
+ %b = memref.atomic_rmw minimumf %f, %F[%i] : (f32, memref<10xf32>) -> f32
+ %c = memref.atomic_rmw maxnumf %f, %F[%i] : (f32, memref<10xf32>) -> f32
+ %d = memref.atomic_rmw minnumf %f, %F[%i] : (f32, memref<10xf32>) -> f32
+ return %a : f32
}
// CHECK: [[RESULT:%.*]] = memref.generic_atomic_rmw %arg0[%arg2] : memref<10xf32> {
// CHECK: ^bb0([[CUR_VAL:%.*]]: f32):
@@ -17,6 +19,16 @@ func.func @atomic_rmw_to_generic(%F: memref<10xf32>, %f: f32, %i: index) -> f32
// CHECK: [[MINIMUM:%.*]] = arith.minimumf [[CUR_VAL]], [[f]] : f32
// CHECK: memref.atomic_yield [[MINIMUM]] : f32
// CHECK: }
+// CHECK: memref.generic_atomic_rmw %arg0[%arg2] : memref<10xf32> {
+// CHECK: ^bb0([[CUR_VAL:%.*]]: f32):
+// CHECK: [[MAXNUM:%.*]] = arith.maxnumf [[CUR_VAL]], [[f]] : f32
+// CHECK: memref.atomic_yield [[MAXNUM]] : f32
+// CHECK: }
+// CHECK: memref.generic_atomic_rmw %arg0[%arg2] : memref<10xf32> {
+// CHECK: ^bb0([[CUR_VAL:%.*]]: f32):
+// CHECK: [[MINNUM:%.*]] = arith.minnumf [[CUR_VAL]], [[f]] : f32
+// CHECK: memref.atomic_yield [[MINNUM]] : f32
+// CHECK: }
// CHECK: return [[RESULT]] : f32
// -----
More information about the Mlir-commits
mailing list