[Mlir-commits] [mlir] 6f4a528 - [mlir][memref] Use dedicated ops in `AtomicRMWOpConverter` (#66437)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Sep 14 14:52:39 PDT 2023


Author: Daniil Dudkin
Date: 2023-09-15T00:52:35+03:00
New Revision: 6f4a5286981dd57826751834f7a434be1a463339

URL: https://github.com/llvm/llvm-project/commit/6f4a5286981dd57826751834f7a434be1a463339
DIFF: https://github.com/llvm/llvm-project/commit/6f4a5286981dd57826751834f7a434be1a463339.diff

LOG: [mlir][memref] Use dedicated ops in `AtomicRMWOpConverter` (#66437)

This patch refactors the `AtomicRMWOpConverter` class to use
the dedicated operations from the Arith dialect instead of the
`cmpf` + `select` pattern.
A test for the `minimumf` kind of `atomic_rmw` has also been added.

Added: 
    

Modified: 
    mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp
    mlir/test/Dialect/MemRef/expand-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp
index 7c3ca19b789c750..b3beaada2539dbc 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp
@@ -33,18 +33,18 @@ using namespace mlir;
 namespace {
 
 /// Converts `atomic_rmw` that cannot be lowered to a simple atomic op with
-/// AtomicRMWOpLowering pattern, e.g. with "minf" or "maxf" attributes, to
-/// `memref.generic_atomic_rmw` with the expanded code.
+/// AtomicRMWOpLowering pattern, such as minimum and maximum operations for
+/// floating-point numbers, to `memref.generic_atomic_rmw` with the expanded
+/// code.
 ///
-/// %x = atomic_rmw "maximumf" %fval, %F[%i] : (f32, memref<10xf32>) -> f32
+/// %x = atomic_rmw maximumf %fval, %F[%i] : (f32, memref<10xf32>) -> f32
 ///
 /// will be lowered to
 ///
 /// %x = memref.generic_atomic_rmw %F[%i] : memref<10xf32> {
 /// ^bb0(%current: f32):
-///   %cmp = arith.cmpf "ogt", %current, %fval : f32
-///   %new_value = select %cmp, %current, %fval : f32
-///   memref.atomic_yield %new_value : f32
+///   %1 = arith.maximumf %current, %fval : f32
+///   memref.atomic_yield %1 : f32
 /// }
 struct AtomicRMWOpConverter : public OpRewritePattern<memref::AtomicRMWOp> {
 public:
@@ -52,18 +52,6 @@ struct AtomicRMWOpConverter : public OpRewritePattern<memref::AtomicRMWOp> {
 
   LogicalResult matchAndRewrite(memref::AtomicRMWOp op,
                                 PatternRewriter &rewriter) const final {
-    arith::CmpFPredicate predicate;
-    switch (op.getKind()) {
-    case arith::AtomicRMWKind::maximumf:
-      predicate = arith::CmpFPredicate::OGT;
-      break;
-    case arith::AtomicRMWKind::minimumf:
-      predicate = arith::CmpFPredicate::OLT;
-      break;
-    default:
-      return failure();
-    }
-
     auto loc = op.getLoc();
     auto genericOp = rewriter.create<memref::GenericAtomicRMWOp>(
         loc, op.getMemref(), op.getIndices());
@@ -72,9 +60,10 @@ struct AtomicRMWOpConverter : public OpRewritePattern<memref::AtomicRMWOp> {
 
     Value lhs = genericOp.getCurrentValue();
     Value rhs = op.getValue();
-    Value cmp = bodyBuilder.create<arith::CmpFOp>(loc, predicate, lhs, rhs);
-    Value select = bodyBuilder.create<arith::SelectOp>(loc, cmp, lhs, rhs);
-    bodyBuilder.create<memref::AtomicYieldOp>(loc, select);
+
+    Value arithOp =
+        mlir::arith::getReductionOp(op.getKind(), bodyBuilder, loc, lhs, rhs);
+    bodyBuilder.create<memref::AtomicYieldOp>(loc, arithOp);
 
     rewriter.replaceOp(op, genericOp.getResult());
     return success();

diff  --git a/mlir/test/Dialect/MemRef/expand-ops.mlir b/mlir/test/Dialect/MemRef/expand-ops.mlir
index 3234b35e99dcdfe..6c98cf978505334 100644
--- a/mlir/test/Dialect/MemRef/expand-ops.mlir
+++ b/mlir/test/Dialect/MemRef/expand-ops.mlir
@@ -4,15 +4,20 @@
 // CHECK-SAME: ([[F:%.*]]: memref<10xf32>, [[f:%.*]]: f32, [[i:%.*]]: index)
 func.func @atomic_rmw_to_generic(%F: memref<10xf32>, %f: f32, %i: index) -> f32 {
   %x = memref.atomic_rmw maximumf %f, %F[%i] : (f32, memref<10xf32>) -> f32
+  %y = memref.atomic_rmw minimumf %f, %F[%i] : (f32, memref<10xf32>) -> f32
   return %x : f32
 }
-// CHECK: %0 = memref.generic_atomic_rmw %arg0[%arg2] : memref<10xf32> {
+// CHECK: [[RESULT:%.*]] = memref.generic_atomic_rmw %arg0[%arg2] : memref<10xf32> {
 // CHECK: ^bb0([[CUR_VAL:%.*]]: f32):
-// CHECK:   [[CMP:%.*]] = arith.cmpf ogt, [[CUR_VAL]], [[f]] : f32
-// CHECK:   [[SELECT:%.*]] = arith.select [[CMP]], [[CUR_VAL]], [[f]] : f32
-// CHECK:   memref.atomic_yield [[SELECT]] : f32
+// CHECK:   [[MAXIMUM:%.*]] = arith.maximumf [[CUR_VAL]], [[f]] : f32
+// CHECK:   memref.atomic_yield [[MAXIMUM]] : f32
 // CHECK: }
-// CHECK: return %0 : f32
+// CHECK: memref.generic_atomic_rmw %arg0[%arg2] : memref<10xf32> {
+// CHECK: ^bb0([[CUR_VAL:%.*]]: f32):
+// CHECK:   [[MINIMUM:%.*]] = arith.minimumf [[CUR_VAL]], [[f]] : f32
+// CHECK:   memref.atomic_yield [[MINIMUM]] : f32
+// CHECK: }
+// CHECK: return [[RESULT]] : f32
 
 // -----
 


        


More information about the Mlir-commits mailing list