[Mlir-commits] [mlir] [mlir][memref] Use dedicated ops in `AtomicRMWOpConverter` (PR #66437)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Sep 14 14:30:30 PDT 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-core
            
<details>
<summary>Changes</summary>
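This patch refactors the `AtomicRMWOpConverter` class to use
the dedicated operations from the Arith dialect (`arith.maximumf` /
`arith.minimumf`) instead of the `cmpf` + `select` pattern.
The per-kind `switch` that returned `failure()` for other kinds has been
removed, so the pattern itself no longer restricts which kinds it rewrites.
A test for the `minimumf` kind of `atomic_rmw` has also been added.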

--
Full diff: https://github.com/llvm/llvm-project/pull/66437.diff

2 Files Affected:

- (modified) mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp (+10-21) 
- (modified) mlir/test/Dialect/MemRef/expand-ops.mlir (+10-5) 


<pre>
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp
index 7c3ca19b789c750..b3beaada2539dbc 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp
@@ -33,18 +33,18 @@ using namespace mlir;
 namespace {
 
 /// Converts `atomic_rmw` that cannot be lowered to a simple atomic op with
-/// AtomicRMWOpLowering pattern, e.g. with &quot;minf&quot; or &quot;maxf&quot; attributes, to
-/// `memref.generic_atomic_rmw` with the expanded code.
+/// AtomicRMWOpLowering pattern, such as minimum and maximum operations for
+/// floating-point numbers, to `memref.generic_atomic_rmw` with the expanded
+/// code.
 ///
-/// %x = atomic_rmw &quot;maximumf&quot; %fval, %F[%i] : (f32, memref&lt;10xf32&gt;) -&gt; f32
+/// %x = atomic_rmw maximumf %fval, %F[%i] : (f32, memref&lt;10xf32&gt;) -&gt; f32
 ///
 /// will be lowered to
 ///
 /// %x = memref.generic_atomic_rmw %F[%i] : memref&lt;10xf32&gt; {
 /// ^bb0(%current: f32):
-///   %cmp = arith.cmpf &quot;ogt&quot;, %current, %fval : f32
-///   %new_value = select %cmp, %current, %fval : f32
-///   memref.atomic_yield %new_value : f32
+///   %1 = arith.maximumf %current, %fval : f32
+///   memref.atomic_yield %1 : f32
 /// }
 struct AtomicRMWOpConverter : public OpRewritePattern&lt;memref::AtomicRMWOp&gt; {
 public:
@@ -52,18 +52,6 @@ struct AtomicRMWOpConverter : public OpRewritePattern&lt;memref::AtomicRMWOp&gt; {
 
   LogicalResult matchAndRewrite(memref::AtomicRMWOp op,
                                 PatternRewriter &amp;rewriter) const final {
-    arith::CmpFPredicate predicate;
-    switch (op.getKind()) {
-    case arith::AtomicRMWKind::maximumf:
-      predicate = arith::CmpFPredicate::OGT;
-      break;
-    case arith::AtomicRMWKind::minimumf:
-      predicate = arith::CmpFPredicate::OLT;
-      break;
-    default:
-      return failure();
-    }
-
     auto loc = op.getLoc();
     auto genericOp = rewriter.create&lt;memref::GenericAtomicRMWOp&gt;(
         loc, op.getMemref(), op.getIndices());
@@ -72,9 +60,10 @@ struct AtomicRMWOpConverter : public OpRewritePattern&lt;memref::AtomicRMWOp&gt; {
 
     Value lhs = genericOp.getCurrentValue();
     Value rhs = op.getValue();
-    Value cmp = bodyBuilder.create&lt;arith::CmpFOp&gt;(loc, predicate, lhs, rhs);
-    Value select = bodyBuilder.create&lt;arith::SelectOp&gt;(loc, cmp, lhs, rhs);
-    bodyBuilder.create&lt;memref::AtomicYieldOp&gt;(loc, select);
+
+    Value arithOp =
+        mlir::arith::getReductionOp(op.getKind(), bodyBuilder, loc, lhs, rhs);
+    bodyBuilder.create&lt;memref::AtomicYieldOp&gt;(loc, arithOp);
 
     rewriter.replaceOp(op, genericOp.getResult());
     return success();
diff --git a/mlir/test/Dialect/MemRef/expand-ops.mlir b/mlir/test/Dialect/MemRef/expand-ops.mlir
index 3234b35e99dcdfe..6c98cf978505334 100644
--- a/mlir/test/Dialect/MemRef/expand-ops.mlir
+++ b/mlir/test/Dialect/MemRef/expand-ops.mlir
@@ -4,15 +4,20 @@
 // CHECK-SAME: ([[F:%.*]]: memref&lt;10xf32&gt;, [[f:%.*]]: f32, [[i:%.*]]: index)
 func.func @atomic_rmw_to_generic(%F: memref&lt;10xf32&gt;, %f: f32, %i: index) -&gt; f32 {
   %x = memref.atomic_rmw maximumf %f, %F[%i] : (f32, memref&lt;10xf32&gt;) -&gt; f32
+  %y = memref.atomic_rmw minimumf %f, %F[%i] : (f32, memref&lt;10xf32&gt;) -&gt; f32
   return %x : f32
 }
-// CHECK: %0 = memref.generic_atomic_rmw %arg0[%arg2] : memref&lt;10xf32&gt; {
+// CHECK: [[RESULT:%.*]] = memref.generic_atomic_rmw %arg0[%arg2] : memref&lt;10xf32&gt; {
 // CHECK: ^bb0([[CUR_VAL:%.*]]: f32):
-// CHECK:   [[CMP:%.*]] = arith.cmpf ogt, [[CUR_VAL]], [[f]] : f32
-// CHECK:   [[SELECT:%.*]] = arith.select [[CMP]], [[CUR_VAL]], [[f]] : f32
-// CHECK:   memref.atomic_yield [[SELECT]] : f32
+// CHECK:   [[MAXIMUM:%.*]] = arith.maximumf [[CUR_VAL]], [[f]] : f32
+// CHECK:   memref.atomic_yield [[MAXIMUM]] : f32
 // CHECK: }
-// CHECK: return %0 : f32
+// CHECK: memref.generic_atomic_rmw %arg0[%arg2] : memref&lt;10xf32&gt; {
+// CHECK: ^bb0([[CUR_VAL:%.*]]: f32):
+// CHECK:   [[MINIMUM:%.*]] = arith.minimumf [[CUR_VAL]], [[f]] : f32
+// CHECK:   memref.atomic_yield [[MINIMUM]] : f32
+// CHECK: }
+// CHECK: return [[RESULT]] : f32
 
 // -----
 
</pre>
</details>


https://github.com/llvm/llvm-project/pull/66437


More information about the Mlir-commits mailing list