[Mlir-commits] [mlir] [mlir] Add `maxnumf` and `minnumf` to `AtomicRMWKind` (PR #66442)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Sep 14 15:32:59 PDT 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir
            
<details>
<summary>Changes</summary>
This commit adds the `maxnumf` and `minnumf` kinds to `AtomicRMWKind`,
together with code generation for them (in `arith::getReductionOp` and the MemRef `expand-ops` pass).

--
Full diff: https://github.com/llvm/llvm-project/pull/66442.diff

4 Files Affected:

- (modified) mlir/include/mlir/Dialect/Arith/IR/ArithBase.td (+3-1) 
- (modified) mlir/lib/Dialect/Arith/IR/ArithOps.cpp (+4) 
- (modified) mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp (+5-2) 
- (modified) mlir/test/Dialect/MemRef/expand-ops.mlir (+15-3) 


<pre>
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td b/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
index a833e9c8220af5b..133af893e4efa74 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
@@ -82,6 +82,8 @@ def ATOMIC_RMW_KIND_MULF     : I64EnumAttrCase&lt;&quot;mulf&quot;, 9&gt;;
 def ATOMIC_RMW_KIND_MULI     : I64EnumAttrCase&lt;&quot;muli&quot;, 10&gt;;
 def ATOMIC_RMW_KIND_ORI      : I64EnumAttrCase&lt;&quot;ori&quot;, 11&gt;;
 def ATOMIC_RMW_KIND_ANDI     : I64EnumAttrCase&lt;&quot;andi&quot;, 12&gt;;
+def ATOMIC_RMW_KIND_MAXNUMF  : I64EnumAttrCase&lt;&quot;maxnumf&quot;, 13&gt;;
+def ATOMIC_RMW_KIND_MINNUMF  : I64EnumAttrCase&lt;&quot;minnumf&quot;, 14&gt;;
 
 def AtomicRMWKindAttr : I64EnumAttr&lt;
     &quot;AtomicRMWKind&quot;, &quot;&quot;,
@@ -89,7 +91,7 @@ def AtomicRMWKindAttr : I64EnumAttr&lt;
      ATOMIC_RMW_KIND_MAXIMUMF, ATOMIC_RMW_KIND_MAXS, ATOMIC_RMW_KIND_MAXU,
      ATOMIC_RMW_KIND_MINIMUMF, ATOMIC_RMW_KIND_MINS, ATOMIC_RMW_KIND_MINU,
      ATOMIC_RMW_KIND_MULF, ATOMIC_RMW_KIND_MULI, ATOMIC_RMW_KIND_ORI,
-     ATOMIC_RMW_KIND_ANDI]&gt; {
+     ATOMIC_RMW_KIND_ANDI, ATOMIC_RMW_KIND_MAXNUMF, ATOMIC_RMW_KIND_MINNUMF]&gt; {
   let cppNamespace = &quot;::mlir::arith&quot;;
 }
 
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index d39c5b6051122e4..ae8a6ef350ce191 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -2523,6 +2523,10 @@ Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &amp;builder,
     return builder.create&lt;arith::MaximumFOp&gt;(loc, lhs, rhs);
   case AtomicRMWKind::minimumf:
     return builder.create&lt;arith::MinimumFOp&gt;(loc, lhs, rhs);
+   case AtomicRMWKind::maxnumf:
+    return builder.create&lt;arith::MaxNumFOp&gt;(loc, lhs, rhs);
+  case AtomicRMWKind::minnumf:
+    return builder.create&lt;arith::MinNumFOp&gt;(loc, lhs, rhs);
   case AtomicRMWKind::maxs:
     return builder.create&lt;arith::MaxSIOp&gt;(loc, lhs, rhs);
   case AtomicRMWKind::mins:
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp
index b3beaada2539dbc..faba12f5bf82f89 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp
@@ -20,6 +20,7 @@
 #include &quot;mlir/Dialect/MemRef/Transforms/Transforms.h&quot;
 #include &quot;mlir/IR/TypeUtilities.h&quot;
 #include &quot;mlir/Transforms/DialectConversion.h&quot;
+#include &quot;llvm/ADT/STLExtras.h&quot;
 
 namespace mlir {
 namespace memref {
@@ -126,8 +127,10 @@ struct ExpandOpsPass : public memref::impl::ExpandOpsBase&lt;ExpandOpsPass&gt; {
     target.addLegalDialect&lt;arith::ArithDialect, memref::MemRefDialect&gt;();
     target.addDynamicallyLegalOp&lt;memref::AtomicRMWOp&gt;(
         [](memref::AtomicRMWOp op) {
-          return op.getKind() != arith::AtomicRMWKind::maximumf &amp;&amp;
-                 op.getKind() != arith::AtomicRMWKind::minimumf;
+          constexpr std::array shouldBeExpandedKinds = {
+              arith::AtomicRMWKind::maximumf, arith::AtomicRMWKind::minimumf,
+              arith::AtomicRMWKind::minnumf, arith::AtomicRMWKind::maxnumf};
+          return !llvm::is_contained(shouldBeExpandedKinds, op.getKind());
         });
     target.addDynamicallyLegalOp&lt;memref::ReshapeOp&gt;([](memref::ReshapeOp op) {
       return !cast&lt;MemRefType&gt;(op.getShape().getType()).hasStaticShape();
diff --git a/mlir/test/Dialect/MemRef/expand-ops.mlir b/mlir/test/Dialect/MemRef/expand-ops.mlir
index 6c98cf978505334..f958a92b751a4ab 100644
--- a/mlir/test/Dialect/MemRef/expand-ops.mlir
+++ b/mlir/test/Dialect/MemRef/expand-ops.mlir
@@ -3,9 +3,11 @@
 // CHECK-LABEL: func @atomic_rmw_to_generic
 // CHECK-SAME: ([[F:%.*]]: memref&lt;10xf32&gt;, [[f:%.*]]: f32, [[i:%.*]]: index)
 func.func @atomic_rmw_to_generic(%F: memref&lt;10xf32&gt;, %f: f32, %i: index) -&gt; f32 {
-  %x = memref.atomic_rmw maximumf %f, %F[%i] : (f32, memref&lt;10xf32&gt;) -&gt; f32
-  %y = memref.atomic_rmw minimumf %f, %F[%i] : (f32, memref&lt;10xf32&gt;) -&gt; f32
-  return %x : f32
+  %a = memref.atomic_rmw maximumf %f, %F[%i] : (f32, memref&lt;10xf32&gt;) -&gt; f32
+  %b = memref.atomic_rmw minimumf %f, %F[%i] : (f32, memref&lt;10xf32&gt;) -&gt; f32
+  %c = memref.atomic_rmw maxnumf %f, %F[%i] : (f32, memref&lt;10xf32&gt;) -&gt; f32
+  %d = memref.atomic_rmw minnumf %f, %F[%i] : (f32, memref&lt;10xf32&gt;) -&gt; f32
+  return %a : f32
 }
 // CHECK: [[RESULT:%.*]] = memref.generic_atomic_rmw %arg0[%arg2] : memref&lt;10xf32&gt; {
 // CHECK: ^bb0([[CUR_VAL:%.*]]: f32):
@@ -17,6 +19,16 @@ func.func @atomic_rmw_to_generic(%F: memref&lt;10xf32&gt;, %f: f32, %i: index) -&gt; f32
 // CHECK:   [[MINIMUM:%.*]] = arith.minimumf [[CUR_VAL]], [[f]] : f32
 // CHECK:   memref.atomic_yield [[MINIMUM]] : f32
 // CHECK: }
+// CHECK: memref.generic_atomic_rmw %arg0[%arg2] : memref&lt;10xf32&gt; {
+// CHECK: ^bb0([[CUR_VAL:%.*]]: f32):
+// CHECK:   [[MAXNUM:%.*]] = arith.maxnumf [[CUR_VAL]], [[f]] : f32
+// CHECK:   memref.atomic_yield [[MAXNUM]] : f32
+// CHECK: }
+// CHECK: memref.generic_atomic_rmw %arg0[%arg2] : memref&lt;10xf32&gt; {
+// CHECK: ^bb0([[CUR_VAL:%.*]]: f32):
+// CHECK:   [[MINNUM:%.*]] = arith.minnumf [[CUR_VAL]], [[f]] : f32
+// CHECK:   memref.atomic_yield [[MINNUM]] : f32
+// CHECK: }
 // CHECK: return [[RESULT]] : f32
 
 // -----
</pre>
</details>


https://github.com/llvm/llvm-project/pull/66442


More information about the Mlir-commits mailing list