[Mlir-commits] [mlir] [mlir][arith] Rename `AtomicRMWKind`'s `maxf` → `maximumf`, `minf` → `minimumf` (PR #66135)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Sep 12 13:04:13 PDT 2023
llvmbot wrote:
@llvm/pr-subscribers-mlir-arith
<details>
<summary>Changes</summary>
This patch is part of a larger initiative aimed at fixing floating-point `max` and `min` operations in MLIR: https://discourse.llvm.org/t/rfc-fix-floating-point-max-and-min-operations-in-mlir/72671.
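This commit renames the `maxf` and `minf` enumerators of `AtomicRMWKind`
to `maximumf` and `minimumf`, respectively, so that they match the names of the `arith.maximumf`/`arith.minimumf` operations they correspond to and reflect the goals of the RFC.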
--
Full diff: https://github.com/llvm/llvm-project/pull/66135.diff
11 Files Affected:
- (modified) mlir/include/mlir/Dialect/Arith/IR/ArithBase.td (+15-15)
- (modified) mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp (+2-2)
- (modified) mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp (+4-2)
- (modified) mlir/lib/Dialect/Arith/IR/ArithOps.cpp (+6-6)
- (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+3-3)
- (modified) mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp (+2-2)
- (modified) mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp (+5-5)
- (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+2-2)
- (modified) mlir/test/Dialect/Affine/invalid.mlir (+1-1)
- (modified) mlir/test/Dialect/Affine/ops.mlir (+2-2)
- (modified) mlir/test/Dialect/MemRef/expand-ops.mlir (+1-1)
<pre>
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td b/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
index 78fd7bdf012f8a8..a833e9c8220af5b 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
@@ -69,25 +69,25 @@ def Arith_CmpIPredicateAttr : I64EnumAttr<
let cppNamespace = "::mlir::arith";
}
-def ATOMIC_RMW_KIND_ADDF : I64EnumAttrCase<"addf", 0>;
-def ATOMIC_RMW_KIND_ADDI : I64EnumAttrCase<"addi", 1>;
-def ATOMIC_RMW_KIND_ASSIGN : I64EnumAttrCase<"assign", 2>;
-def ATOMIC_RMW_KIND_MAXF : I64EnumAttrCase<"maxf", 3>;
-def ATOMIC_RMW_KIND_MAXS : I64EnumAttrCase<"maxs", 4>;
-def ATOMIC_RMW_KIND_MAXU : I64EnumAttrCase<"maxu", 5>;
-def ATOMIC_RMW_KIND_MINF : I64EnumAttrCase<"minf", 6>;
-def ATOMIC_RMW_KIND_MINS : I64EnumAttrCase<"mins", 7>;
-def ATOMIC_RMW_KIND_MINU : I64EnumAttrCase<"minu", 8>;
-def ATOMIC_RMW_KIND_MULF : I64EnumAttrCase<"mulf", 9>;
-def ATOMIC_RMW_KIND_MULI : I64EnumAttrCase<"muli", 10>;
-def ATOMIC_RMW_KIND_ORI : I64EnumAttrCase<"ori", 11>;
-def ATOMIC_RMW_KIND_ANDI : I64EnumAttrCase<"andi", 12>;
+def ATOMIC_RMW_KIND_ADDF : I64EnumAttrCase<"addf", 0>;
+def ATOMIC_RMW_KIND_ADDI : I64EnumAttrCase<"addi", 1>;
+def ATOMIC_RMW_KIND_ASSIGN : I64EnumAttrCase<"assign", 2>;
+def ATOMIC_RMW_KIND_MAXIMUMF : I64EnumAttrCase<"maximumf", 3>;
+def ATOMIC_RMW_KIND_MAXS : I64EnumAttrCase<"maxs", 4>;
+def ATOMIC_RMW_KIND_MAXU : I64EnumAttrCase<"maxu", 5>;
+def ATOMIC_RMW_KIND_MINIMUMF : I64EnumAttrCase<"minimumf", 6>;
+def ATOMIC_RMW_KIND_MINS : I64EnumAttrCase<"mins", 7>;
+def ATOMIC_RMW_KIND_MINU : I64EnumAttrCase<"minu", 8>;
+def ATOMIC_RMW_KIND_MULF : I64EnumAttrCase<"mulf", 9>;
+def ATOMIC_RMW_KIND_MULI : I64EnumAttrCase<"muli", 10>;
+def ATOMIC_RMW_KIND_ORI : I64EnumAttrCase<"ori", 11>;
+def ATOMIC_RMW_KIND_ANDI : I64EnumAttrCase<"andi", 12>;
def AtomicRMWKindAttr : I64EnumAttr<
"AtomicRMWKind", "",
[ATOMIC_RMW_KIND_ADDF, ATOMIC_RMW_KIND_ADDI, ATOMIC_RMW_KIND_ASSIGN,
- ATOMIC_RMW_KIND_MAXF, ATOMIC_RMW_KIND_MAXS, ATOMIC_RMW_KIND_MAXU,
- ATOMIC_RMW_KIND_MINF, ATOMIC_RMW_KIND_MINS, ATOMIC_RMW_KIND_MINU,
+ ATOMIC_RMW_KIND_MAXIMUMF, ATOMIC_RMW_KIND_MAXS, ATOMIC_RMW_KIND_MAXU,
+ ATOMIC_RMW_KIND_MINIMUMF, ATOMIC_RMW_KIND_MINS, ATOMIC_RMW_KIND_MINU,
ATOMIC_RMW_KIND_MULF, ATOMIC_RMW_KIND_MULI, ATOMIC_RMW_KIND_ORI,
ATOMIC_RMW_KIND_ANDI]> {
let cppNamespace = "::mlir::arith";
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 516a6b8ed88e61e..97faefe2cd4d631 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -1594,13 +1594,13 @@ matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) {
return LLVM::AtomicBinOp::add;
case arith::AtomicRMWKind::assign:
return LLVM::AtomicBinOp::xchg;
- case arith::AtomicRMWKind::maxf:
+ case arith::AtomicRMWKind::maximumf:
return LLVM::AtomicBinOp::fmax;
case arith::AtomicRMWKind::maxs:
return LLVM::AtomicBinOp::max;
case arith::AtomicRMWKind::maxu:
return LLVM::AtomicBinOp::umax;
- case arith::AtomicRMWKind::minf:
+ case arith::AtomicRMWKind::minimumf:
return LLVM::AtomicBinOp::fmin;
case arith::AtomicRMWKind::mins:
return LLVM::AtomicBinOp::min;
diff --git a/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp b/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp
index ab1dfbdb419b891..1ba0bc8b6bfbe5e 100644
--- a/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp
+++ b/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp
@@ -60,8 +60,10 @@ static Value getSupportedReduction(AffineForOp forOp, unsigned pos,
.Case([](arith::AndIOp) { return arith::AtomicRMWKind::andi; })
.Case([](arith::OrIOp) { return arith::AtomicRMWKind::ori; })
.Case([](arith::MulIOp) { return arith::AtomicRMWKind::muli; })
- .Case([](arith::MinimumFOp) { return arith::AtomicRMWKind::minf; })
- .Case([](arith::MaximumFOp) { return arith::AtomicRMWKind::maxf; })
+ .Case(
+ [](arith::MinimumFOp) { return arith::AtomicRMWKind::minimumf; })
+ .Case(
+ [](arith::MaximumFOp) { return arith::AtomicRMWKind::maximumf; })
.Case([](arith::MinSIOp) { return arith::AtomicRMWKind::mins; })
.Case([](arith::MaxSIOp) { return arith::AtomicRMWKind::maxs; })
.Case([](arith::MinUIOp) { return arith::AtomicRMWKind::minu; })
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 77bf8a438d6db84..1e34ac598860f52 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -2369,7 +2369,7 @@ TypedAttr mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
OpBuilder &builder, Location loc,
bool useOnlyFiniteValue) {
switch (kind) {
- case AtomicRMWKind::maxf: {
+ case AtomicRMWKind::maximumf: {
const llvm::fltSemantics &semantic =
llvm::cast<FloatType>(resultType).getFloatSemantics();
APFloat identity = useOnlyFiniteValue
@@ -2390,7 +2390,7 @@ TypedAttr mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
return builder.getIntegerAttr(
resultType, APInt::getSignedMinValue(
llvm::cast<IntegerType>(resultType).getWidth()));
- case AtomicRMWKind::minf: {
+ case AtomicRMWKind::minimumf: {
const llvm::fltSemantics &semantic =
llvm::cast<FloatType>(resultType).getFloatSemantics();
APFloat identity = useOnlyFiniteValue
@@ -2426,8 +2426,8 @@ std::optional<TypedAttr> mlir::arith::getNeutralElement(Operation *op) {
// Floating-point operations.
.Case([](arith::AddFOp op) { return AtomicRMWKind::addf; })
.Case([](arith::MulFOp op) { return AtomicRMWKind::mulf; })
- .Case([](arith::MaximumFOp op) { return AtomicRMWKind::maxf; })
- .Case([](arith::MinimumFOp op) { return AtomicRMWKind::minf; })
+ .Case([](arith::MaximumFOp op) { return AtomicRMWKind::maximumf; })
+ .Case([](arith::MinimumFOp op) { return AtomicRMWKind::minimumf; })
// Integer operations.
.Case([](arith::AddIOp op) { return AtomicRMWKind::addi; })
.Case([](arith::OrIOp op) { return AtomicRMWKind::ori; })
@@ -2482,9 +2482,9 @@ Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder,
return builder.create<arith::MulFOp>(loc, lhs, rhs);
case AtomicRMWKind::muli:
return builder.create<arith::MulIOp>(loc, lhs, rhs);
- case AtomicRMWKind::maxf:
+ case AtomicRMWKind::maximumf:
return builder.create<arith::MaximumFOp>(loc, lhs, rhs);
- case AtomicRMWKind::minf:
+ case AtomicRMWKind::minimumf:
return builder.create<arith::MinimumFOp>(loc, lhs, rhs);
case AtomicRMWKind::maxs:
return builder.create<arith::MaxSIOp>(loc, lhs, rhs);
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index d26e68cb47ac1e0..f87aa4559e10afe 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2549,9 +2549,9 @@ FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
dims.erase(dims.begin() + reductionDim);
// Step 1: Compute max along dim.
Value outputReduce = b.create<tensor::EmptyOp>(loc, dims, elementType);
- Value neutralForMaxF =
- arith::getIdentityValue(arith::AtomicRMWKind::maxf, elementType, b, loc,
- /*useOnlyFiniteValue=*/true);
+ Value neutralForMaxF = arith::getIdentityValue(arith::AtomicRMWKind::maximumf,
+ elementType, b, loc,
+ /*useOnlyFiniteValue=*/true);
Value neutralForMaxFInit =
b.create<linalg::FillOp>(loc, Value{neutralForMaxF}, outputReduce)
.result();
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 42da47a5381e789..215a8f5e7d18be0 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -3402,8 +3402,8 @@ LogicalResult AtomicRMWOp::verify() {
"expects the number of subscripts to be equal to memref rank");
switch (getKind()) {
case arith::AtomicRMWKind::addf:
- case arith::AtomicRMWKind::maxf:
- case arith::AtomicRMWKind::minf:
+ case arith::AtomicRMWKind::maximumf:
+ case arith::AtomicRMWKind::minimumf:
case arith::AtomicRMWKind::mulf:
if (!llvm::isa<FloatType>(getValue().getType()))
return emitOpError() << "with kind '"
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp
index 8a276ebbff6a921..7c3ca19b789c750 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp
@@ -36,7 +36,7 @@ namespace {
/// AtomicRMWOpLowering pattern, e.g. with "minf" or "maxf" attributes, to
/// `memref.generic_atomic_rmw` with the expanded code.
///
-/// %x = atomic_rmw "maxf" %fval, %F[%i] : (f32, memref<10xf32>) -> f32
+/// %x = atomic_rmw "maximumf" %fval, %F[%i] : (f32, memref<10xf32>) -> f32
///
/// will be lowered to
///
@@ -54,10 +54,10 @@ struct AtomicRMWOpConverter : public OpRewritePattern<memref::AtomicRMWOp> {
PatternRewriter &rewriter) const final {
arith::CmpFPredicate predicate;
switch (op.getKind()) {
- case arith::AtomicRMWKind::maxf:
+ case arith::AtomicRMWKind::maximumf:
predicate = arith::CmpFPredicate::OGT;
break;
- case arith::AtomicRMWKind::minf:
+ case arith::AtomicRMWKind::minimumf:
predicate = arith::CmpFPredicate::OLT;
break;
default:
@@ -137,8 +137,8 @@ struct ExpandOpsPass : public memref::impl::ExpandOpsBase<ExpandOpsPass> {
target.addLegalDialect<arith::ArithDialect, memref::MemRefDialect>();
target.addDynamicallyLegalOp<memref::AtomicRMWOp>(
[](memref::AtomicRMWOp op) {
- return op.getKind() != arith::AtomicRMWKind::maxf &&
- op.getKind() != arith::AtomicRMWKind::minf;
+ return op.getKind() != arith::AtomicRMWKind::maximumf &&
+ op.getKind() != arith::AtomicRMWKind::minimumf;
});
target.addDynamicallyLegalOp<memref::ReshapeOp>([](memref::ReshapeOp op) {
return !cast<MemRefType>(op.getShape().getType()).hasStaticShape();
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 9422936bf21e357..11aa76798bcaae1 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -493,7 +493,7 @@ Value mlir::vector::getVectorReductionOp(arith::AtomicRMWKind op,
case arith::AtomicRMWKind::muli:
return builder.create<vector::ReductionOp>(vector.getLoc(),
CombiningKind::MUL, vector);
- case arith::AtomicRMWKind::minf:
+ case arith::AtomicRMWKind::minimumf:
return builder.create<vector::ReductionOp>(vector.getLoc(),
CombiningKind::MINF, vector);
case arith::AtomicRMWKind::mins:
@@ -502,7 +502,7 @@ Value mlir::vector::getVectorReductionOp(arith::AtomicRMWKind op,
case arith::AtomicRMWKind::minu:
return builder.create<vector::ReductionOp>(vector.getLoc(),
CombiningKind::MINUI, vector);
- case arith::AtomicRMWKind::maxf:
+ case arith::AtomicRMWKind::maximumf:
return builder.create<vector::ReductionOp>(vector.getLoc(),
CombiningKind::MAXF, vector);
case arith::AtomicRMWKind::maxs:
diff --git a/mlir/test/Dialect/Affine/invalid.mlir b/mlir/test/Dialect/Affine/invalid.mlir
index fd7d7df9d8735f5..1dc3451ed7db87c 100644
--- a/mlir/test/Dialect/Affine/invalid.mlir
+++ b/mlir/test/Dialect/Affine/invalid.mlir
@@ -287,7 +287,7 @@ func.func @affine_parallel(%arg0 : index, %arg1 : index, %arg2 : index) {
func.func @affine_parallel(%arg0 : index, %arg1 : index, %arg2 : index) {
%0 = memref.alloc() : memref<100x100xi32>
- %1 = affine.parallel (%i, %j) = (0, 0) to (100, 100) step (10, 10) reduce ("minf") -> (f32) {
+ %1 = affine.parallel (%i, %j) = (0, 0) to (100, 100) step (10, 10) reduce ("minimumf") -> (f32) {
%2 = affine.load %0[%i, %j] : memref<100x100xi32>
// expected-error@+1 {{types mismatch between yield op and its parent}}
affine.yield %2 : i32
diff --git a/mlir/test/Dialect/Affine/ops.mlir b/mlir/test/Dialect/Affine/ops.mlir
index f55d59a3e64707b..1063f2a7ecba489 100644
--- a/mlir/test/Dialect/Affine/ops.mlir
+++ b/mlir/test/Dialect/Affine/ops.mlir
@@ -158,8 +158,8 @@ func.func @valid_symbol_affine_scope(%n : index, %A : memref<?xf32>) {
func.func @parallel(%A : memref<100x100xf32>, %N : index) {
// CHECK: affine.parallel (%[[I0:.*]], %[[J0:.*]]) = (0, 0) to (symbol(%[[N]]), 100) step (10, 10)
affine.parallel (%i0, %j0) = (0, 0) to (symbol(%N), 100) step (10, 10) {
- // CHECK: affine.parallel (%{{.*}}, %{{.*}}) = (%[[I0]], %[[J0]]) to (%[[I0]] + 10, %[[J0]] + 10) reduce ("minf", "maxf") -> (f32, f32)
- %0:2 = affine.parallel (%i1, %j1) = (%i0, %j0) to (%i0 + 10, %j0 + 10) reduce ("minf", "maxf") -> (f32, f32) {
+ // CHECK: affine.parallel (%{{.*}}, %{{.*}}) = (%[[I0]], %[[J0]]) to (%[[I0]] + 10, %[[J0]] + 10) reduce ("minimumf", "maximumf") -> (f32, f32)
+ %0:2 = affine.parallel (%i1, %j1) = (%i0, %j0) to (%i0 + 10, %j0 + 10) reduce ("minimumf", "maximumf") -> (f32, f32) {
%2 = affine.load %A[%i0 + %i0, %j0 + %j1] : memref<100x100xf32>
affine.yield %2, %2 : f32, f32
}
diff --git a/mlir/test/Dialect/MemRef/expand-ops.mlir b/mlir/test/Dialect/MemRef/expand-ops.mlir
index a0d8e52d6e7e275..3234b35e99dcdfe 100644
--- a/mlir/test/Dialect/MemRef/expand-ops.mlir
+++ b/mlir/test/Dialect/MemRef/expand-ops.mlir
@@ -3,7 +3,7 @@
// CHECK-LABEL: func @atomic_rmw_to_generic
// CHECK-SAME: ([[F:%.*]]: memref<10xf32>, [[f:%.*]]: f32, [[i:%.*]]: index)
func.func @atomic_rmw_to_generic(%F: memref<10xf32>, %f: f32, %i: index) -> f32 {
- %x = memref.atomic_rmw maxf %f, %F[%i] : (f32, memref<10xf32>) -> f32
+ %x = memref.atomic_rmw maximumf %f, %F[%i] : (f32, memref<10xf32>) -> f32
return %x : f32
}
// CHECK: %0 = memref.generic_atomic_rmw %arg0[%arg2] : memref<10xf32> {
</pre>
</details>
https://github.com/llvm/llvm-project/pull/66135
More information about the Mlir-commits mailing list