[flang-commits] [flang] [Flang] Minloc elemental intrinsic lowering (PR #74828)

Slava Zakharin via flang-commits flang-commits at lists.llvm.org
Mon Dec 18 13:37:32 PST 2023


================
@@ -659,6 +677,194 @@ mlir::LogicalResult VariableAssignBufferization::matchAndRewrite(
   return mlir::success();
 }
 
+// Look for assign(minloc(mask=elemental)) and generate the minloc loop with
+// inlined elemental and no extra temporaries.
+//  %e = hlfir.elemental %shape ({ ... })
+//  %m = hlfir.minloc %array mask %e
+//  hlfir.assign %m to %result
+//  hlfir.destroy %m
+class AssignMinMaxlocElementalConversion
+    : public mlir::OpRewritePattern<hlfir::AssignOp> {
+public:
+  using mlir::OpRewritePattern<hlfir::AssignOp>::OpRewritePattern;
+
+  mlir::LogicalResult
+  matchAndRewrite(hlfir::AssignOp assign,
+                  mlir::PatternRewriter &rewriter) const override {
+    auto minloc = assign.getOperand(0).getDefiningOp<hlfir::MinlocOp>();
+    if (!minloc || !minloc.getMask() || minloc.getDim() || minloc.getBack())
+      return rewriter.notifyMatchFailure(assign,
+                                         "Did not find minloc with kind");
+
+    auto elemental = minloc.getMask().getDefiningOp<hlfir::ElementalOp>();
+    if (!elemental || hlfir::elementalOpMustProduceTemp(elemental))
+      return rewriter.notifyMatchFailure(assign, "Did not find elemental");
+
+    mlir::Operation::user_range users = minloc->getUsers();
+    if (std::distance(users.begin(), users.end()) != 2)
+      return rewriter.notifyMatchFailure(assign, "Did not find minloc users");
+    auto destroy = mlir::dyn_cast<hlfir::DestroyOp>(
+        *users.begin() == minloc ? *++users.begin() : *users.begin());
+    if (!destroy)
+      return rewriter.notifyMatchFailure(assign, "Did not find destroy");
+
+    if (!checkForElementalEffectsBetween(elemental, assign, minloc.getArray(),
----------------
vzakhari wrote:

Please make sure that memory effects are handled correctly in the following case. Here the assignment target `x(5:5)` aliases the `minloc` array argument `x`:
```
! Reproducer: MINLOC with a MASK expression whose result is assigned back
! into a section of the ARRAY argument itself.
function test(x, mask)
  integer :: x(:)
  integer :: mask(:)
  print *, x(5:5)
  ! The left-hand side x(5:5) is a section of x, and x is also the ARRAY
  ! argument of MINLOC. The MASK (mask.ge.5) becomes an elemental expression.
  ! An inlined minloc loop therefore reads x, so stores into the assignment
  ! target must not be interleaved with those reads.
  x(5:5) = minloc(x,mask=mask.ge.5)
  ! NOTE(review): the function result `test` is never defined in this body.
  ! A subroutine may express this reproducer more cleanly.
end function test
```

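Compiling this with `-O2` produces the following IR just before the optimized bufferization pass. Note that the assignment target `%8` is a section designated from `%3#0`, and `%3#0` is also the array operand of `hlfir.minloc`: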
```
    %3:2 = hlfir.declare %arg0 {uniq_name = "_QFtestEx"} : (!fir.box<!fir.array<?xi32>>) -> (!fir.box<!fir.array<?xi32>>, !fir.box<!fir.array<?xi32>>)
...
    %8 = hlfir.designate %3#0 (%c5:%c5:%c1)  shape %7 : (!fir.box<!fir.array<?xi32>>, index, index, index, !fir.shape<1>) -> !fir.box<!fir.array<1xi32>>
...
    %14 = hlfir.elemental %13 unordered : (!fir.shape<1>) -> !hlfir.expr<?x!fir.logical<4>> {
    ^bb0(%arg2: index):
      %17 = hlfir.designate %0#0 (%arg2)  : (!fir.box<!fir.array<?xi32>>, index) -> !fir.ref<i32>
      %18 = fir.load %17 : !fir.ref<i32>
      %19 = arith.cmpi sge, %18, %c5_i32 : i32
      %20 = fir.convert %19 : (i1) -> !fir.logical<4>
      hlfir.yield_element %20 : !fir.logical<4>
    }
    %15 = hlfir.minloc %3#0 mask %14 {fastmath = #arith.fastmath<contract>} : (!fir.box<!fir.array<?xi32>>, !hlfir.expr<?x!fir.logical<4>>) -> !hlfir.expr<1xi32>
    hlfir.assign %15 to %8 : !hlfir.expr<1xi32>, !fir.box<!fir.array<1xi32>>
```

https://github.com/llvm/llvm-project/pull/74828


More information about the flang-commits mailing list