[Mlir-commits] [mlir] [mlir][memref] Use dedicated ops in `AtomicRMWOpConverter` (PR #66437)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Sep 14 14:30:30 PDT 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-core
<details>
<summary>Changes</summary>
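This patch refactors the `AtomicRMWOpConverter` class to use
the dedicated operations from the Arith dialect instead of the
`cmpf` + `select` pattern. As a side effect, the expanded code now
follows the NaN and signed-zero semantics of `arith.maximumf` /
`arith.minimumf`, which the `cmpf` + `select` expansion did not fully match.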
A test for the `minimumf` kind of `atomic_rmw` has also been added.
--
Full diff: https://github.com/llvm/llvm-project/pull/66437.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp (+10-21)
- (modified) mlir/test/Dialect/MemRef/expand-ops.mlir (+10-5)
<pre>
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp
index 7c3ca19b789c750..b3beaada2539dbc 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp
@@ -33,18 +33,18 @@ using namespace mlir;
namespace {
/// Converts `atomic_rmw` that cannot be lowered to a simple atomic op with
-/// AtomicRMWOpLowering pattern, e.g. with "minf" or "maxf" attributes, to
-/// `memref.generic_atomic_rmw` with the expanded code.
+/// AtomicRMWOpLowering pattern, such as minimum and maximum operations for
+/// floating-point numbers, to `memref.generic_atomic_rmw` with the expanded
+/// code.
///
-/// %x = atomic_rmw "maximumf" %fval, %F[%i] : (f32, memref<10xf32>) -> f32
+/// %x = atomic_rmw maximumf %fval, %F[%i] : (f32, memref<10xf32>) -> f32
///
/// will be lowered to
///
/// %x = memref.generic_atomic_rmw %F[%i] : memref<10xf32> {
/// ^bb0(%current: f32):
-/// %cmp = arith.cmpf "ogt", %current, %fval : f32
-/// %new_value = select %cmp, %current, %fval : f32
-/// memref.atomic_yield %new_value : f32
+/// %1 = arith.maximumf %current, %fval : f32
+/// memref.atomic_yield %1 : f32
/// }
struct AtomicRMWOpConverter : public OpRewritePattern<memref::AtomicRMWOp> {
public:
@@ -52,18 +52,6 @@ struct AtomicRMWOpConverter : public OpRewritePattern<memref::AtomicRMWOp> {
LogicalResult matchAndRewrite(memref::AtomicRMWOp op,
PatternRewriter &rewriter) const final {
- arith::CmpFPredicate predicate;
- switch (op.getKind()) {
- case arith::AtomicRMWKind::maximumf:
- predicate = arith::CmpFPredicate::OGT;
- break;
- case arith::AtomicRMWKind::minimumf:
- predicate = arith::CmpFPredicate::OLT;
- break;
- default:
- return failure();
- }
-
auto loc = op.getLoc();
auto genericOp = rewriter.create<memref::GenericAtomicRMWOp>(
loc, op.getMemref(), op.getIndices());
@@ -72,9 +60,10 @@ struct AtomicRMWOpConverter : public OpRewritePattern<memref::AtomicRMWOp> {
Value lhs = genericOp.getCurrentValue();
Value rhs = op.getValue();
- Value cmp = bodyBuilder.create<arith::CmpFOp>(loc, predicate, lhs, rhs);
- Value select = bodyBuilder.create<arith::SelectOp>(loc, cmp, lhs, rhs);
- bodyBuilder.create<memref::AtomicYieldOp>(loc, select);
+
+ Value arithOp =
+ mlir::arith::getReductionOp(op.getKind(), bodyBuilder, loc, lhs, rhs);
+ bodyBuilder.create<memref::AtomicYieldOp>(loc, arithOp);
rewriter.replaceOp(op, genericOp.getResult());
return success();
diff --git a/mlir/test/Dialect/MemRef/expand-ops.mlir b/mlir/test/Dialect/MemRef/expand-ops.mlir
index 3234b35e99dcdfe..6c98cf978505334 100644
--- a/mlir/test/Dialect/MemRef/expand-ops.mlir
+++ b/mlir/test/Dialect/MemRef/expand-ops.mlir
@@ -4,15 +4,20 @@
// CHECK-SAME: ([[F:%.*]]: memref<10xf32>, [[f:%.*]]: f32, [[i:%.*]]: index)
func.func @atomic_rmw_to_generic(%F: memref<10xf32>, %f: f32, %i: index) -> f32 {
%x = memref.atomic_rmw maximumf %f, %F[%i] : (f32, memref<10xf32>) -> f32
+ %y = memref.atomic_rmw minimumf %f, %F[%i] : (f32, memref<10xf32>) -> f32
return %x : f32
}
-// CHECK: %0 = memref.generic_atomic_rmw %arg0[%arg2] : memref<10xf32> {
+// CHECK: [[RESULT:%.*]] = memref.generic_atomic_rmw %arg0[%arg2] : memref<10xf32> {
// CHECK: ^bb0([[CUR_VAL:%.*]]: f32):
-// CHECK: [[CMP:%.*]] = arith.cmpf ogt, [[CUR_VAL]], [[f]] : f32
-// CHECK: [[SELECT:%.*]] = arith.select [[CMP]], [[CUR_VAL]], [[f]] : f32
-// CHECK: memref.atomic_yield [[SELECT]] : f32
+// CHECK: [[MAXIMUM:%.*]] = arith.maximumf [[CUR_VAL]], [[f]] : f32
+// CHECK: memref.atomic_yield [[MAXIMUM]] : f32
// CHECK: }
-// CHECK: return %0 : f32
+// CHECK: memref.generic_atomic_rmw %arg0[%arg2] : memref<10xf32> {
+// CHECK: ^bb0([[CUR_VAL:%.*]]: f32):
+// CHECK: [[MINIMUM:%.*]] = arith.minimumf [[CUR_VAL]], [[f]] : f32
+// CHECK: memref.atomic_yield [[MINIMUM]] : f32
+// CHECK: }
+// CHECK: return [[RESULT]] : f32
// -----
</pre>
</details>
https://github.com/llvm/llvm-project/pull/66437
More information about the Mlir-commits mailing list