[Mlir-commits] [mlir] [mlir][memref] Use dedicated ops in `AtomicRMWOpConverter` (PR #66437)

Daniil Dudkin llvmlistbot at llvm.org
Thu Sep 14 14:29:25 PDT 2023


https://github.com/unterumarmung created https://github.com/llvm/llvm-project/pull/66437:

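This patch refactors the `AtomicRMWOpConverter` class to use
the dedicated operations from the Arith dialect (`arith.maximumf`,
`arith.minimumf`) instead of the `cmpf` + `select` pattern.
Unlike an ordered compare followed by `select`, these ops propagate
NaN operands and treat -0.0 as less than +0.0, which matches the
semantics of the `maximumf`/`minimumf` kinds of `atomic_rmw`.
It also adds a test for the `minimumf` kind of `atomic_rmw`.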


From b331d17321e4f7195a931c65f3832f440ef156e3 Mon Sep 17 00:00:00 2001
From: Daniil Dudkin <unterumarmung at yandex.ru>
Date: Fri, 15 Sep 2023 00:28:32 +0300
Subject: [PATCH] [mlir][memref] Use dedicated ops in `AtomicRMWOpConverter`

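This patch refactors the `AtomicRMWOpConverter` class to use
the dedicated operations from the Arith dialect (`arith.maximumf`,
`arith.minimumf`) instead of the `cmpf` + `select` pattern.
Unlike an ordered compare followed by `select`, these ops propagate
NaN operands and treat -0.0 as less than +0.0, which matches the
semantics of the `maximumf`/`minimumf` kinds of `atomic_rmw`.
It also adds a test for the `minimumf` kind of `atomic_rmw`.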
---
 .../Dialect/MemRef/Transforms/ExpandOps.cpp   | 31 ++++++-------------
 mlir/test/Dialect/MemRef/expand-ops.mlir      | 15 ++++++---
 2 files changed, 20 insertions(+), 26 deletions(-)

diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp
index 7c3ca19b789c750..b3beaada2539dbc 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp
@@ -33,18 +33,18 @@ using namespace mlir;
 namespace {
 
 /// Converts `atomic_rmw` that cannot be lowered to a simple atomic op with
-/// AtomicRMWOpLowering pattern, e.g. with "minf" or "maxf" attributes, to
-/// `memref.generic_atomic_rmw` with the expanded code.
+/// AtomicRMWOpLowering pattern, such as minimum and maximum operations for
+/// floating-point numbers, to `memref.generic_atomic_rmw` with the expanded
+/// code.
 ///
-/// %x = atomic_rmw "maximumf" %fval, %F[%i] : (f32, memref<10xf32>) -> f32
+/// %x = atomic_rmw maximumf %fval, %F[%i] : (f32, memref<10xf32>) -> f32
 ///
 /// will be lowered to
 ///
 /// %x = memref.generic_atomic_rmw %F[%i] : memref<10xf32> {
 /// ^bb0(%current: f32):
-///   %cmp = arith.cmpf "ogt", %current, %fval : f32
-///   %new_value = select %cmp, %current, %fval : f32
-///   memref.atomic_yield %new_value : f32
+///   %1 = arith.maximumf %current, %fval : f32
+///   memref.atomic_yield %1 : f32
 /// }
 struct AtomicRMWOpConverter : public OpRewritePattern<memref::AtomicRMWOp> {
 public:
@@ -52,18 +52,6 @@ struct AtomicRMWOpConverter : public OpRewritePattern<memref::AtomicRMWOp> {
 
   LogicalResult matchAndRewrite(memref::AtomicRMWOp op,
                                 PatternRewriter &rewriter) const final {
-    arith::CmpFPredicate predicate;
-    switch (op.getKind()) {
-    case arith::AtomicRMWKind::maximumf:
-      predicate = arith::CmpFPredicate::OGT;
-      break;
-    case arith::AtomicRMWKind::minimumf:
-      predicate = arith::CmpFPredicate::OLT;
-      break;
-    default:
-      return failure();
-    }
-
     auto loc = op.getLoc();
     auto genericOp = rewriter.create<memref::GenericAtomicRMWOp>(
         loc, op.getMemref(), op.getIndices());
@@ -72,9 +60,10 @@ struct AtomicRMWOpConverter : public OpRewritePattern<memref::AtomicRMWOp> {
 
     Value lhs = genericOp.getCurrentValue();
     Value rhs = op.getValue();
-    Value cmp = bodyBuilder.create<arith::CmpFOp>(loc, predicate, lhs, rhs);
-    Value select = bodyBuilder.create<arith::SelectOp>(loc, cmp, lhs, rhs);
-    bodyBuilder.create<memref::AtomicYieldOp>(loc, select);
+
+    Value arithOp =
+        mlir::arith::getReductionOp(op.getKind(), bodyBuilder, loc, lhs, rhs);
+    bodyBuilder.create<memref::AtomicYieldOp>(loc, arithOp);
 
     rewriter.replaceOp(op, genericOp.getResult());
     return success();
diff --git a/mlir/test/Dialect/MemRef/expand-ops.mlir b/mlir/test/Dialect/MemRef/expand-ops.mlir
index 3234b35e99dcdfe..6c98cf978505334 100644
--- a/mlir/test/Dialect/MemRef/expand-ops.mlir
+++ b/mlir/test/Dialect/MemRef/expand-ops.mlir
@@ -4,15 +4,20 @@
 // CHECK-SAME: ([[F:%.*]]: memref<10xf32>, [[f:%.*]]: f32, [[i:%.*]]: index)
 func.func @atomic_rmw_to_generic(%F: memref<10xf32>, %f: f32, %i: index) -> f32 {
   %x = memref.atomic_rmw maximumf %f, %F[%i] : (f32, memref<10xf32>) -> f32
+  %y = memref.atomic_rmw minimumf %f, %F[%i] : (f32, memref<10xf32>) -> f32
   return %x : f32
 }
-// CHECK: %0 = memref.generic_atomic_rmw %arg0[%arg2] : memref<10xf32> {
+// CHECK: [[RESULT:%.*]] = memref.generic_atomic_rmw %arg0[%arg2] : memref<10xf32> {
 // CHECK: ^bb0([[CUR_VAL:%.*]]: f32):
-// CHECK:   [[CMP:%.*]] = arith.cmpf ogt, [[CUR_VAL]], [[f]] : f32
-// CHECK:   [[SELECT:%.*]] = arith.select [[CMP]], [[CUR_VAL]], [[f]] : f32
-// CHECK:   memref.atomic_yield [[SELECT]] : f32
+// CHECK:   [[MAXIMUM:%.*]] = arith.maximumf [[CUR_VAL]], [[f]] : f32
+// CHECK:   memref.atomic_yield [[MAXIMUM]] : f32
 // CHECK: }
-// CHECK: return %0 : f32
+// CHECK: memref.generic_atomic_rmw %arg0[%arg2] : memref<10xf32> {
+// CHECK: ^bb0([[CUR_VAL:%.*]]: f32):
+// CHECK:   [[MINIMUM:%.*]] = arith.minimumf [[CUR_VAL]], [[f]] : f32
+// CHECK:   memref.atomic_yield [[MINIMUM]] : f32
+// CHECK: }
+// CHECK: return [[RESULT]] : f32
 
 // -----
 



More information about the Mlir-commits mailing list