[flang-commits] [flang] [Flang] Maxloc elemental intrinsic lowering. (PR #79469)

David Green via flang-commits flang-commits at lists.llvm.org
Thu Jan 25 08:55:09 PST 2024


https://github.com/davemgreen created https://github.com/llvm/llvm-project/pull/79469

This extends #74828 to handle maxloc as well, keeping the minloc and maxloc lowerings symmetric.

>From d578694454fd649ad032a9f57a4a4efba1b6c2a5 Mon Sep 17 00:00:00 2001
From: David Green <david.green at arm.com>
Date: Thu, 25 Jan 2024 16:05:40 +0000
Subject: [PATCH] [Flang] Maxloc elemental intrinsic lowering.

This extends #74828 to handle maxloc as well, keeping the minloc and maxloc
lowerings symmetric.
---
 flang/include/flang/Optimizer/Support/Utils.h | 135 ++++++++++++++++-
 .../Transforms/OptimizedBufferization.cpp     |  64 ++++----
 .../Transforms/SimplifyIntrinsics.cpp         | 128 ----------------
 flang/test/HLFIR/maxloc-elemental.fir         | 140 ++++++++++++++++++
 4 files changed, 306 insertions(+), 161 deletions(-)
 create mode 100644 flang/test/HLFIR/maxloc-elemental.fir

diff --git a/flang/include/flang/Optimizer/Support/Utils.h b/flang/include/flang/Optimizer/Support/Utils.h
index e31121260acdae7..b50f297a7d31410 100644
--- a/flang/include/flang/Optimizer/Support/Utils.h
+++ b/flang/include/flang/Optimizer/Support/Utils.h
@@ -18,6 +18,7 @@
 #include "flang/Optimizer/Builder/Todo.h"
 #include "flang/Optimizer/Dialect/FIROps.h"
 #include "flang/Optimizer/Dialect/FIRType.h"
+#include "flang/Optimizer/HLFIR/HLFIRDialect.h"
 #include "flang/Optimizer/Support/FatalError.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -144,13 +145,133 @@ using AddrGeneratorTy = llvm::function_ref<mlir::Value(
     mlir::Value)>;
 
 // Produces a loop nest for a Minloc intrinsic.
-void genMinMaxlocReductionLoop(fir::FirOpBuilder &builder, mlir::Value array,
-                               InitValGeneratorTy initVal,
-                               MinlocBodyOpGeneratorTy genBody,
-                               fir::AddrGeneratorTy getAddrFn, unsigned rank,
-                               mlir::Type elementType, mlir::Location loc,
-                               mlir::Type maskElemType, mlir::Value resultArr,
-                               bool maskMayBeLogicalScalar);
+inline void genMinMaxlocReductionLoop(
+    fir::FirOpBuilder &builder, mlir::Value array,
+    fir::InitValGeneratorTy initVal, fir::MinlocBodyOpGeneratorTy genBody,
+    fir::AddrGeneratorTy getAddrFn, unsigned rank, mlir::Type elementType,
+    mlir::Location loc, mlir::Type maskElemType, mlir::Value resultArr,
+    bool maskMayBeLogicalScalar) {
+  mlir::IndexType idxTy = builder.getIndexType();
+  // NOTE(review): maskElemType is not used anywhere in this body.
+  mlir::Value zeroIdx = builder.createIntegerConstant(loc, idxTy, 0);
+  // Convert the input to a box of a rank-`rank` array with unknown extents.
+  fir::SequenceType::Shape flatShape(rank,
+                                     fir::SequenceType::getUnknownExtent());
+  mlir::Type arrTy = fir::SequenceType::get(flatShape, elementType);
+  mlir::Type boxArrTy = fir::BoxType::get(arrTy);
+  array = builder.create<fir::ConvertOp>(loc, boxArrTy, array);
+  // flagRef starts at 0; genBody presumably sets it to 1 on a true mask.
+  mlir::Type resultElemType = hlfir::getFortranElementType(resultArr.getType());
+  mlir::Value flagSet = builder.createIntegerConstant(loc, resultElemType, 1);
+  mlir::Value zero = builder.createIntegerConstant(loc, resultElemType, 0);
+  mlir::Value flagRef = builder.createTemporary(loc, resultElemType);
+  builder.create<fir::StoreOp>(loc, zero, flagRef);
+
+  mlir::Value init = initVal(builder, loc, elementType);
+  llvm::SmallVector<mlir::Value, Fortran::common::maxRank> bounds;
+
+  assert(rank > 0 && "rank cannot be zero");
+  mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
+
+  // Compute all the upper bounds before the loop nest. This is not
+  // strictly necessary for performance, since the dims are loop invariant
+  // and LICM should be able to hoist them. Note that the nest is not
+  // store-free: genBody may store to flagRef and to the result array.
+  for (unsigned i = 0; i < rank; ++i) {
+    mlir::Value dimIdx = builder.createIntegerConstant(loc, idxTy, i);
+    auto dims =
+        builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy, array, dimIdx);
+    mlir::Value len = dims.getResult(1);
+    // Indices are 0-based here, so the inclusive upper bound is len-1.
+    mlir::Value loopCount = builder.create<mlir::arith::SubIOp>(loc, len, one);
+    bounds.push_back(loopCount);
+  }
+  // Create a loop nest consisting of fir.do_loop operations.
+  // Collect the loops' induction variables into indices array,
+  // which will be used in the innermost loop to load the input
+  // array's element.
+  // The loops are generated such that the innermost loop processes
+  // the 0 dimension.
+  llvm::SmallVector<mlir::Value, Fortran::common::maxRank> indices;
+  for (unsigned i = rank; 0 < i; --i) {
+    mlir::Value step = one;
+    mlir::Value loopCount = bounds[i - 1];
+    auto loop =
+        builder.create<fir::DoLoopOp>(loc, zeroIdx, loopCount, step, false,
+                                      /*finalCountValue=*/false, init);
+    init = loop.getRegionIterArgs()[0];
+    indices.push_back(loop.getInductionVar());
+    // Set insertion point to the loop body so that the next loop
+    // is inserted inside the current one.
+    builder.setInsertionPointToStart(loop.getBody());
+  }
+
+  // Reverse the indices such that they are ordered as:
+  //   <dim-0-idx, dim-1-idx, ...>
+  std::reverse(indices.begin(), indices.end());
+  mlir::Value reductionVal =
+      genBody(builder, loc, elementType, array, flagRef, init, indices);
+
+  // Unwind the loop nest and insert ResultOp on each level
+  // to return the updated value of the reduction to the enclosing
+  // loops.
+  for (unsigned i = 0; i < rank; ++i) {
+    auto result = builder.create<fir::ResultOp>(loc, reductionVal);
+    // Proceed to the outer loop.
+    auto loop = mlir::cast<fir::DoLoopOp>(result->getParentOp());
+    reductionVal = loop.getResult(0);
+    // Set insertion point after the loop operation that we have
+    // just processed.
+    builder.setInsertionPointAfter(loop.getOperation());
+  }
+  // End of loop nest. The insertion point is after the outermost loop.
+  if (maskMayBeLogicalScalar) {
+    if (fir::IfOp ifOp =
+            mlir::dyn_cast<fir::IfOp>(builder.getBlock()->getParentOp())) {
+      builder.create<fir::ResultOp>(loc, reductionVal);
+      builder.setInsertionPointAfter(ifOp);
+      // Re-create flagSet: the original is defined inside the ifOp region.
+      flagSet = builder.createIntegerConstant(loc, resultElemType, 1);
+      reductionVal = ifOp.getResult(0);
+    }
+  }
+
+  // Check for the case where the reduction value still equals the initial
+  // value (e.g. all max values for minloc, all min values for maxloc). The
+  // flag is 0 if the mask was never true and 1 if it was true at some point;
+  // this avoids matching cases where no element was accessed (mask=.FALSE.).
+  mlir::Value flagValue =
+      builder.create<fir::LoadOp>(loc, resultElemType, flagRef);
+  mlir::Value flagCmp = builder.create<mlir::arith::CmpIOp>(
+      loc, mlir::arith::CmpIPredicate::eq, flagValue, flagSet);
+  fir::IfOp ifMaskTrueOp =
+      builder.create<fir::IfOp>(loc, flagCmp, /*withElseRegion=*/false);
+  builder.setInsertionPointToStart(&ifMaskTrueOp.getThenRegion().front());
+
+  mlir::Value testInit = initVal(builder, loc, elementType);
+  fir::IfOp ifMinSetOp;
+  if (elementType.isa<mlir::FloatType>()) {
+    mlir::Value cmp = builder.create<mlir::arith::CmpFOp>(
+        loc, mlir::arith::CmpFPredicate::OEQ, testInit, reductionVal);
+    ifMinSetOp = builder.create<fir::IfOp>(loc, cmp,
+                                           /*withElseRegion*/ false);
+  } else {
+    mlir::Value cmp = builder.create<mlir::arith::CmpIOp>(
+        loc, mlir::arith::CmpIPredicate::eq, testInit, reductionVal);
+    ifMinSetOp = builder.create<fir::IfOp>(loc, cmp,
+                                           /*withElseRegion*/ false);
+  }
+  builder.setInsertionPointToStart(&ifMinSetOp.getThenRegion().front());
+
+  // Store 1s (instead of 0s) into every element of the output array.
+  for (unsigned int i = 0; i < rank; ++i) {
+    mlir::Value index = builder.createIntegerConstant(loc, idxTy, i);
+    mlir::Value resultElemAddr =
+        getAddrFn(builder, loc, resultElemType, resultArr, index);
+    builder.create<fir::StoreOp>(loc, flagSet, resultElemAddr);
+  }
+  builder.setInsertionPointAfter(ifMaskTrueOp);
+}
 
 } // namespace fir
 
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp b/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
index 3f4ec4f3bccc80f..0915f25d49375de 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
@@ -812,38 +812,41 @@ class ReductionElementalConversion : public mlir::OpRewritePattern<Op> {
 // inlined elemental.
 //  %e = hlfir.elemental %shape ({ ... })
 //  %m = hlfir.minloc %array mask %e
-class MinMaxlocElementalConversion
-    : public mlir::OpRewritePattern<hlfir::MinlocOp> {
+template <typename Op>
+class MinMaxlocElementalConversion : public mlir::OpRewritePattern<Op> {
 public:
-  using mlir::OpRewritePattern<hlfir::MinlocOp>::OpRewritePattern;
+  using mlir::OpRewritePattern<Op>::OpRewritePattern;
 
   mlir::LogicalResult
-  matchAndRewrite(hlfir::MinlocOp minloc,
-                  mlir::PatternRewriter &rewriter) const override {
-    if (!minloc.getMask() || minloc.getDim() || minloc.getBack())
-      return rewriter.notifyMatchFailure(minloc, "Did not find valid minloc");
+  matchAndRewrite(Op mloc, mlir::PatternRewriter &rewriter) const override {
+    if (!mloc.getMask() || mloc.getDim() || mloc.getBack())
+      return rewriter.notifyMatchFailure(mloc,
+                                         "Did not find valid minloc/maxloc");
 
-    auto elemental = minloc.getMask().getDefiningOp<hlfir::ElementalOp>();
+    constexpr bool isMax = std::is_same_v<Op, hlfir::MaxlocOp>;
+
+    auto elemental =
+        mloc.getMask().template getDefiningOp<hlfir::ElementalOp>();
     if (!elemental || hlfir::elementalOpMustProduceTemp(elemental))
-      return rewriter.notifyMatchFailure(minloc, "Did not find elemental");
+      return rewriter.notifyMatchFailure(mloc, "Did not find elemental");
 
-    mlir::Value array = minloc.getArray();
+    mlir::Value array = mloc.getArray();
 
-    unsigned rank = mlir::cast<hlfir::ExprType>(minloc.getType()).getShape()[0];
+    unsigned rank = mlir::cast<hlfir::ExprType>(mloc.getType()).getShape()[0];
     mlir::Type arrayType = array.getType();
     if (!arrayType.isa<fir::BoxType>())
       return rewriter.notifyMatchFailure(
-          minloc, "Currently requires a boxed type input");
+          mloc, "Currently requires a boxed type input");
     mlir::Type elementType = hlfir::getFortranElementType(arrayType);
     if (!fir::isa_trivial(elementType))
       return rewriter.notifyMatchFailure(
-          minloc, "Character arrays are currently not handled");
+          mloc, "Character arrays are currently not handled");
 
-    mlir::Location loc = minloc.getLoc();
-    fir::FirOpBuilder builder{rewriter, minloc.getOperation()};
+    mlir::Location loc = mloc.getLoc();
+    fir::FirOpBuilder builder{rewriter, mloc.getOperation()};
     mlir::Value resultArr = builder.createTemporary(
         loc, fir::SequenceType::get(
-                 rank, hlfir::getFortranElementType(minloc.getType())));
+                 rank, hlfir::getFortranElementType(mloc.getType())));
 
     auto init = [](fir::FirOpBuilder builder, mlir::Location loc,
                    mlir::Type elementType) {
@@ -851,11 +854,13 @@ class MinMaxlocElementalConversion
         const llvm::fltSemantics &sem = ty.getFloatSemantics();
         return builder.createRealConstant(
             loc, elementType,
-            llvm::APFloat::getLargest(sem, /*Negative=*/false));
+            llvm::APFloat::getLargest(sem, /*Negative=*/!isMax));
       }
       unsigned bits = elementType.getIntOrFloatBitWidth();
-      int64_t maxInt = llvm::APInt::getSignedMaxValue(bits).getSExtValue();
-      return builder.createIntegerConstant(loc, elementType, maxInt);
+      int64_t limitInt =
+          isMax ? llvm::APInt::getSignedMinValue(bits).getSExtValue()
+                : llvm::APInt::getSignedMaxValue(bits).getSExtValue();
+      return builder.createIntegerConstant(loc, elementType, limitInt);
     };
 
     auto genBodyOp =
@@ -899,10 +904,16 @@ class MinMaxlocElementalConversion
       mlir::Value cmp;
       if (elementType.isa<mlir::FloatType>()) {
         cmp = builder.create<mlir::arith::CmpFOp>(
-            loc, mlir::arith::CmpFPredicate::OLT, elem, reduction);
+            loc,
+            isMax ? mlir::arith::CmpFPredicate::OGT
+                  : mlir::arith::CmpFPredicate::OLT,
+            elem, reduction);
       } else if (elementType.isa<mlir::IntegerType>()) {
         cmp = builder.create<mlir::arith::CmpIOp>(
-            loc, mlir::arith::CmpIPredicate::slt, elem, reduction);
+            loc,
+            isMax ? mlir::arith::CmpIPredicate::sgt
+                  : mlir::arith::CmpIPredicate::slt,
+            elem, reduction);
       } else {
         llvm_unreachable("unsupported type");
       }
@@ -975,15 +986,15 @@ class MinMaxlocElementalConversion
     // AsExpr for the temporary resultArr.
     llvm::SmallVector<hlfir::DestroyOp> destroys;
     llvm::SmallVector<hlfir::AssignOp> assigns;
-    for (auto user : minloc->getUsers()) {
+    for (auto user : mloc->getUsers()) {
       if (auto destroy = mlir::dyn_cast<hlfir::DestroyOp>(user))
         destroys.push_back(destroy);
       else if (auto assign = mlir::dyn_cast<hlfir::AssignOp>(user))
         assigns.push_back(assign);
     }
 
-    // Check if the minloc was the only user of the elemental (apart from a
-    // destroy), and remove it if so.
+    // Check if the minloc/maxloc was the only user of the elemental (apart from
+    // a destroy), and remove it if so.
     mlir::Operation::user_range elemUsers = elemental->getUsers();
     hlfir::DestroyOp elemDestroy;
     if (std::distance(elemUsers.begin(), elemUsers.end()) == 2) {
@@ -996,7 +1007,7 @@ class MinMaxlocElementalConversion
       rewriter.eraseOp(d);
     for (auto a : assigns)
       a.setOperand(0, resultArr);
-    rewriter.replaceOp(minloc, asExpr);
+    rewriter.replaceOp(mloc, asExpr);
     if (elemDestroy) {
       rewriter.eraseOp(elemDestroy);
       rewriter.eraseOp(elemental);
@@ -1030,7 +1041,8 @@ class OptimizedBufferizationPass
     patterns.insert<ReductionElementalConversion<hlfir::CountOp>>(context);
     patterns.insert<ReductionElementalConversion<hlfir::AnyOp>>(context);
     patterns.insert<ReductionElementalConversion<hlfir::AllOp>>(context);
-    patterns.insert<MinMaxlocElementalConversion>(context);
+    patterns.insert<MinMaxlocElementalConversion<hlfir::MinlocOp>>(context);
+    patterns.insert<MinMaxlocElementalConversion<hlfir::MaxlocOp>>(context);
 
     if (mlir::failed(mlir::applyPatternsAndFoldGreedily(
             func, std::move(patterns), config))) {
diff --git a/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp b/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp
index 2301e7146f14101..b415463075d68f9 100644
--- a/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp
+++ b/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp
@@ -353,134 +353,6 @@ genReductionLoop(fir::FirOpBuilder &builder, mlir::func::FuncOp &funcOp,
   builder.create<mlir::func::ReturnOp>(loc, results[resultIndex]);
 }
 
-void fir::genMinMaxlocReductionLoop(
-    fir::FirOpBuilder &builder, mlir::Value array,
-    fir::InitValGeneratorTy initVal, fir::MinlocBodyOpGeneratorTy genBody,
-    fir::AddrGeneratorTy getAddrFn, unsigned rank, mlir::Type elementType,
-    mlir::Location loc, mlir::Type maskElemType, mlir::Value resultArr,
-    bool maskMayBeLogicalScalar) {
-  mlir::IndexType idxTy = builder.getIndexType();
-
-  mlir::Value zeroIdx = builder.createIntegerConstant(loc, idxTy, 0);
-
-  fir::SequenceType::Shape flatShape(rank,
-                                     fir::SequenceType::getUnknownExtent());
-  mlir::Type arrTy = fir::SequenceType::get(flatShape, elementType);
-  mlir::Type boxArrTy = fir::BoxType::get(arrTy);
-  array = builder.create<fir::ConvertOp>(loc, boxArrTy, array);
-
-  mlir::Type resultElemType = hlfir::getFortranElementType(resultArr.getType());
-  mlir::Value flagSet = builder.createIntegerConstant(loc, resultElemType, 1);
-  mlir::Value zero = builder.createIntegerConstant(loc, resultElemType, 0);
-  mlir::Value flagRef = builder.createTemporary(loc, resultElemType);
-  builder.create<fir::StoreOp>(loc, zero, flagRef);
-
-  mlir::Value init = initVal(builder, loc, elementType);
-  llvm::SmallVector<mlir::Value, Fortran::common::maxRank> bounds;
-
-  assert(rank > 0 && "rank cannot be zero");
-  mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
-
-  // Compute all the upper bounds before the loop nest.
-  // It is not strictly necessary for performance, since the loop nest
-  // does not have any store operations and any LICM optimization
-  // should be able to optimize the redundancy.
-  for (unsigned i = 0; i < rank; ++i) {
-    mlir::Value dimIdx = builder.createIntegerConstant(loc, idxTy, i);
-    auto dims =
-        builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy, array, dimIdx);
-    mlir::Value len = dims.getResult(1);
-    // We use C indexing here, so len-1 as loopcount
-    mlir::Value loopCount = builder.create<mlir::arith::SubIOp>(loc, len, one);
-    bounds.push_back(loopCount);
-  }
-  // Create a loop nest consisting of OP operations.
-  // Collect the loops' induction variables into indices array,
-  // which will be used in the innermost loop to load the input
-  // array's element.
-  // The loops are generated such that the innermost loop processes
-  // the 0 dimension.
-  llvm::SmallVector<mlir::Value, Fortran::common::maxRank> indices;
-  for (unsigned i = rank; 0 < i; --i) {
-    mlir::Value step = one;
-    mlir::Value loopCount = bounds[i - 1];
-    auto loop =
-        builder.create<fir::DoLoopOp>(loc, zeroIdx, loopCount, step, false,
-                                      /*finalCountValue=*/false, init);
-    init = loop.getRegionIterArgs()[0];
-    indices.push_back(loop.getInductionVar());
-    // Set insertion point to the loop body so that the next loop
-    // is inserted inside the current one.
-    builder.setInsertionPointToStart(loop.getBody());
-  }
-
-  // Reverse the indices such that they are ordered as:
-  //   <dim-0-idx, dim-1-idx, ...>
-  std::reverse(indices.begin(), indices.end());
-  mlir::Value reductionVal =
-      genBody(builder, loc, elementType, array, flagRef, init, indices);
-
-  // Unwind the loop nest and insert ResultOp on each level
-  // to return the updated value of the reduction to the enclosing
-  // loops.
-  for (unsigned i = 0; i < rank; ++i) {
-    auto result = builder.create<fir::ResultOp>(loc, reductionVal);
-    // Proceed to the outer loop.
-    auto loop = mlir::cast<fir::DoLoopOp>(result->getParentOp());
-    reductionVal = loop.getResult(0);
-    // Set insertion point after the loop operation that we have
-    // just processed.
-    builder.setInsertionPointAfter(loop.getOperation());
-  }
-  // End of loop nest. The insertion point is after the outermost loop.
-  if (maskMayBeLogicalScalar) {
-    if (fir::IfOp ifOp =
-            mlir::dyn_cast<fir::IfOp>(builder.getBlock()->getParentOp())) {
-      builder.create<fir::ResultOp>(loc, reductionVal);
-      builder.setInsertionPointAfter(ifOp);
-      // Redefine flagSet to escape scope of ifOp
-      flagSet = builder.createIntegerConstant(loc, resultElemType, 1);
-      reductionVal = ifOp.getResult(0);
-    }
-  }
-
-  // Check for case where array was full of max values.
-  // flag will be 0 if mask was never true, 1 if mask was true as some point,
-  // this is needed to avoid catching cases where we didn't access any elements
-  // e.g. mask=.FALSE.
-  mlir::Value flagValue =
-      builder.create<fir::LoadOp>(loc, resultElemType, flagRef);
-  mlir::Value flagCmp = builder.create<mlir::arith::CmpIOp>(
-      loc, mlir::arith::CmpIPredicate::eq, flagValue, flagSet);
-  fir::IfOp ifMaskTrueOp =
-      builder.create<fir::IfOp>(loc, flagCmp, /*withElseRegion=*/false);
-  builder.setInsertionPointToStart(&ifMaskTrueOp.getThenRegion().front());
-
-  mlir::Value testInit = initVal(builder, loc, elementType);
-  fir::IfOp ifMinSetOp;
-  if (elementType.isa<mlir::FloatType>()) {
-    mlir::Value cmp = builder.create<mlir::arith::CmpFOp>(
-        loc, mlir::arith::CmpFPredicate::OEQ, testInit, reductionVal);
-    ifMinSetOp = builder.create<fir::IfOp>(loc, cmp,
-                                           /*withElseRegion*/ false);
-  } else {
-    mlir::Value cmp = builder.create<mlir::arith::CmpIOp>(
-        loc, mlir::arith::CmpIPredicate::eq, testInit, reductionVal);
-    ifMinSetOp = builder.create<fir::IfOp>(loc, cmp,
-                                           /*withElseRegion*/ false);
-  }
-  builder.setInsertionPointToStart(&ifMinSetOp.getThenRegion().front());
-
-  // Load output array with 1s instead of 0s
-  for (unsigned int i = 0; i < rank; ++i) {
-    mlir::Value index = builder.createIntegerConstant(loc, idxTy, i);
-    mlir::Value resultElemAddr =
-        getAddrFn(builder, loc, resultElemType, resultArr, index);
-    builder.create<fir::StoreOp>(loc, flagSet, resultElemAddr);
-  }
-  builder.setInsertionPointAfter(ifMaskTrueOp);
-}
-
 static llvm::SmallVector<mlir::Value> nopLoopCond(fir::FirOpBuilder &builder,
                                                   mlir::Location loc,
                                                   mlir::Value reductionVal) {
diff --git a/flang/test/HLFIR/maxloc-elemental.fir b/flang/test/HLFIR/maxloc-elemental.fir
new file mode 100644
index 000000000000000..67cd9ee4bb75a7a
--- /dev/null
+++ b/flang/test/HLFIR/maxloc-elemental.fir
@@ -0,0 +1,140 @@
+// RUN: fir-opt %s -opt-bufferization | FileCheck %s
+
+func.func @_QPtest(%arg0: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "array"}, %arg1: !fir.ref<i32> {fir.bindc_name = "val"}, %arg2: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "m"}) {
+  %c0 = arith.constant 0 : index
+  %0:2 = hlfir.declare %arg0 {uniq_name = "_QFtestEarray"} : (!fir.box<!fir.array<?xi32>>) -> (!fir.box<!fir.array<?xi32>>, !fir.box<!fir.array<?xi32>>)
+  %1:2 = hlfir.declare %arg2 {uniq_name = "_QFtestEm"} : (!fir.box<!fir.array<?xi32>>) -> (!fir.box<!fir.array<?xi32>>, !fir.box<!fir.array<?xi32>>)
+  %2:2 = hlfir.declare %arg1 {uniq_name = "_QFtestEval"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+  %3 = fir.load %2#0 : !fir.ref<i32>
+  %4:3 = fir.box_dims %0#0, %c0 : (!fir.box<!fir.array<?xi32>>, index) -> (index, index, index)
+  %5 = fir.shape %4#1 : (index) -> !fir.shape<1>
+  %6 = hlfir.elemental %5 unordered : (!fir.shape<1>) -> !hlfir.expr<?x!fir.logical<4>> {
+  ^bb0(%arg3: index):
+    %8 = hlfir.designate %0#0 (%arg3)  : (!fir.box<!fir.array<?xi32>>, index) -> !fir.ref<i32>
+    %9 = fir.load %8 : !fir.ref<i32>
+    %10 = arith.cmpi sge, %9, %3 : i32
+    %11 = fir.convert %10 : (i1) -> !fir.logical<4>
+    hlfir.yield_element %11 : !fir.logical<4>
+  }
+  %7 = hlfir.maxloc %0#0 mask %6 {fastmath = #arith.fastmath<contract>} : (!fir.box<!fir.array<?xi32>>, !hlfir.expr<?x!fir.logical<4>>) -> !hlfir.expr<1xi32>
+  hlfir.assign %7 to %1#0 : !hlfir.expr<1xi32>, !fir.box<!fir.array<?xi32>>
+  hlfir.destroy %7 : !hlfir.expr<1xi32>
+  hlfir.destroy %6 : !hlfir.expr<?x!fir.logical<4>>
+  return
+}
+// CHECK-LABEL: func.func @_QPtest(%arg0: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "array"}, %arg1: !fir.ref<i32> {fir.bindc_name = "val"}, %arg2: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "m"}) {
+// CHECK-NEXT:    %c-2147483648_i32 = arith.constant -2147483648 : i32
+// CHECK-NEXT:    %c1_i32 = arith.constant 1 : i32
+// CHECK-NEXT:    %c0 = arith.constant 0 : index
+// CHECK-NEXT:    %c1 = arith.constant 1 : index
+// CHECK-NEXT:    %c0_i32 = arith.constant 0 : i32
+// CHECK-NEXT:    %[[V0:.*]] = fir.alloca i32
+// CHECK-NEXT:    %[[RES:.*]] = fir.alloca !fir.array<1xi32>
+// CHECK-NEXT:    %[[V1:.*]]:2 = hlfir.declare %arg0 {uniq_name = "_QFtestEarray"} : (!fir.box<!fir.array<?xi32>>) -> (!fir.box<!fir.array<?xi32>>, !fir.box<!fir.array<?xi32>>)
+// CHECK-NEXT:    %[[V2:.*]]:2 = hlfir.declare %arg2 {uniq_name = "_QFtestEm"} : (!fir.box<!fir.array<?xi32>>) -> (!fir.box<!fir.array<?xi32>>, !fir.box<!fir.array<?xi32>>)
+// CHECK-NEXT:    %[[V3:.*]]:2 = hlfir.declare %arg1 {uniq_name = "_QFtestEval"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+// CHECK-NEXT:    %[[V4:.*]] = fir.load %[[V3]]#0 : !fir.ref<i32>
+// CHECK-NEXT:    %[[V8:.*]] = hlfir.designate %[[RES]] (%c1) : (!fir.ref<!fir.array<1xi32>>, index) -> !fir.ref<i32>
+// CHECK-NEXT:    fir.store %c0_i32 to %[[V8]] : !fir.ref<i32>
+// CHECK-NEXT:    fir.store %c0_i32 to %[[V0]] : !fir.ref<i32>
+// CHECK-NEXT:    %[[V9:.*]]:3 = fir.box_dims %[[V1]]#0, %c0 : (!fir.box<!fir.array<?xi32>>, index) -> (index, index, index)
+// CHECK-NEXT:    %[[V10:.*]] = arith.subi %[[V9]]#1, %c1 : index
+// CHECK-NEXT:    %[[V11:.*]] = fir.do_loop %arg3 = %c0 to %[[V10]] step %c1 iter_args(%arg4 = %c-2147483648_i32) -> (i32) {
+// CHECK-NEXT:      %[[V14:.*]] = arith.addi %arg3, %c1 : index
+// CHECK-NEXT:      %[[V15:.*]] = hlfir.designate %[[V1]]#0 (%[[V14]])  : (!fir.box<!fir.array<?xi32>>, index) -> !fir.ref<i32>
+// CHECK-NEXT:      %[[V16:.*]] = fir.load %[[V15]] : !fir.ref<i32>
+// CHECK-NEXT:      %[[V17:.*]] = arith.cmpi sge, %[[V16]], %[[V4]] : i32
+// CHECK-NEXT:      %[[V18:.*]] = fir.if %[[V17]] -> (i32) {
+// CHECK-NEXT:        fir.store %c1_i32 to %[[V0]] : !fir.ref<i32>
+// CHECK-NEXT:        %[[DIMS:.*]]:3 = fir.box_dims %[[V1]]#0, %c0 : (!fir.box<!fir.array<?xi32>>, index) -> (index, index, index)
+// CHECK-NEXT:        %[[SUB:.*]] = arith.subi %[[DIMS]]#0, %c1 : index
+// CHECK-NEXT:        %[[ADD:.*]] = arith.addi %[[V14]], %[[SUB]] : index
+// CHECK-NEXT:        %[[V19:.*]] = hlfir.designate %[[V1]]#0 (%[[ADD]]) : (!fir.box<!fir.array<?xi32>>, index) -> !fir.ref<i32>
+// CHECK-NEXT:        %[[V20:.*]] = fir.load %[[V19]] : !fir.ref<i32>
+// CHECK-NEXT:        %[[V21:.*]] = arith.cmpi sgt, %[[V20]], %arg4 : i32
+// CHECK-NEXT:        %[[V22:.*]] = fir.if %[[V21]] -> (i32) {
+// CHECK-NEXT:          %[[V23:.*]] = hlfir.designate %[[RES]] (%c1) : (!fir.ref<!fir.array<1xi32>>, index) -> !fir.ref<i32>
+// CHECK-NEXT:          %[[V24:.*]] = fir.convert %[[V14]] : (index) -> i32
+// CHECK-NEXT:          fir.store %[[V24]] to %[[V23]] : !fir.ref<i32>
+// CHECK-NEXT:          fir.result %[[V20]] : i32
+// CHECK-NEXT:        } else {
+// CHECK-NEXT:          fir.result %arg4 : i32
+// CHECK-NEXT:        }
+// CHECK-NEXT:        fir.result %[[V22]] : i32
+// CHECK-NEXT:      } else {
+// CHECK-NEXT:        fir.result %arg4 : i32
+// CHECK-NEXT:      }
+// CHECK-NEXT:      fir.result %[[V18]] : i32
+// CHECK-NEXT:    }
+// CHECK-NEXT:    %[[V12:.*]] = fir.load %[[V0]] : !fir.ref<i32>
+// CHECK-NEXT:    %[[V13:.*]] = arith.cmpi eq, %[[V12]], %c1_i32 : i32
+// CHECK-NEXT:    fir.if %[[V13]] {
+// CHECK-NEXT:      %[[V14:.*]] = arith.cmpi eq, %[[V11]], %c-2147483648_i32 : i32
+// CHECK-NEXT:      fir.if %[[V14]] {
+// CHECK-NEXT:        %[[V15:.*]] = hlfir.designate %[[RES]] (%c1) : (!fir.ref<!fir.array<1xi32>>, index) -> !fir.ref<i32>
+// CHECK-NEXT:        fir.store %c1_i32 to %[[V15]] : !fir.ref<i32>
+// CHECK-NEXT:      }
+// CHECK-NEXT:    }
+// CHECK-NEXT:    %[[BD:.*]]:3 = fir.box_dims %[[V2]]#0, %c0 : (!fir.box<!fir.array<?xi32>>, index) -> (index, index, index)
+// CHECK-NEXT:    fir.do_loop %arg3 = %c1 to %[[BD]]#1 step %c1 unordered {
+// CHECK-NEXT:      %[[V13:.*]] = hlfir.designate %[[RES]] (%arg3)  : (!fir.ref<!fir.array<1xi32>>, index) -> !fir.ref<i32>
+// CHECK-NEXT:      %[[V14:.*]] = fir.load %[[V13]] : !fir.ref<i32>
+// CHECK-NEXT:      %[[V15:.*]] = hlfir.designate %[[V2]]#0 (%arg3)  : (!fir.box<!fir.array<?xi32>>, index) -> !fir.ref<i32>
+// CHECK-NEXT:      hlfir.assign %[[V14]] to %[[V15]] : i32, !fir.ref<i32>
+// CHECK-NEXT:    }
+// CHECK-NEXT:    return
+// CHECK-NEXT:  }
+
+
+
+func.func @_QPtest_float(%arg0: !fir.box<!fir.array<?xf32>> {fir.bindc_name = "array"}, %arg1: !fir.ref<f32> {fir.bindc_name = "val"}, %arg2: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "m"}) {
+  %c0 = arith.constant 0 : index
+  %0:2 = hlfir.declare %arg0 {uniq_name = "_QFtestEarray"} : (!fir.box<!fir.array<?xf32>>) -> (!fir.box<!fir.array<?xf32>>, !fir.box<!fir.array<?xf32>>)
+  %1:2 = hlfir.declare %arg2 {uniq_name = "_QFtestEm"} : (!fir.box<!fir.array<?xi32>>) -> (!fir.box<!fir.array<?xi32>>, !fir.box<!fir.array<?xi32>>)
+  %2:2 = hlfir.declare %arg1 {uniq_name = "_QFtestEval"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
+  %3 = fir.load %2#0 : !fir.ref<f32>
+  %4:3 = fir.box_dims %0#0, %c0 : (!fir.box<!fir.array<?xf32>>, index) -> (index, index, index)
+  %5 = fir.shape %4#1 : (index) -> !fir.shape<1>
+  %6 = hlfir.elemental %5 unordered : (!fir.shape<1>) -> !hlfir.expr<?x!fir.logical<4>> {
+  ^bb0(%arg3: index):
+    %8 = hlfir.designate %0#0 (%arg3)  : (!fir.box<!fir.array<?xf32>>, index) -> !fir.ref<f32>
+    %9 = fir.load %8 : !fir.ref<f32>
+    %10 = arith.cmpf oge, %9, %3 : f32
+    %11 = fir.convert %10 : (i1) -> !fir.logical<4>
+    hlfir.yield_element %11 : !fir.logical<4>
+  }
+  %7 = hlfir.maxloc %0#0 mask %6 {fastmath = #arith.fastmath<contract>} : (!fir.box<!fir.array<?xf32>>, !hlfir.expr<?x!fir.logical<4>>) -> !hlfir.expr<1xi32>
+  hlfir.assign %7 to %1#0 : !hlfir.expr<1xi32>, !fir.box<!fir.array<?xi32>>
+  hlfir.destroy %7 : !hlfir.expr<1xi32>
+  hlfir.destroy %6 : !hlfir.expr<?x!fir.logical<4>>
+  return
+}
+// CHECK-LABEL: _QPtest_float
+// CHECK:        %[[V11:.*]] = fir.do_loop %arg3 = %c0 to %[[V10:.*]] step %c1 iter_args(%arg4 = %cst) -> (f32) {
+// CHECK-NEXT:     %[[V14:.*]] = arith.addi %arg3, %c1 : index
+// CHECK-NEXT:     %[[V15:.*]] = hlfir.designate %[[V1:.*]]#0 (%[[V14]])  : (!fir.box<!fir.array<?xf32>>, index) -> !fir.ref<f32>
+// CHECK-NEXT:     %[[V16:.*]] = fir.load %[[V15]] : !fir.ref<f32>
+// CHECK-NEXT:     %[[V17:.*]] = arith.cmpf oge, %[[V16]], %[[V4:.*]] : f32
+// CHECK-NEXT:     %[[V18:.*]] = fir.if %[[V17]] -> (f32) {
+// CHECK-NEXT:       fir.store %c1_i32 to %[[V0:.*]] : !fir.ref<i32>
+// CHECK-NEXT:       %[[DIMS:.*]]:3 = fir.box_dims %[[V1]]#0, %c0 : (!fir.box<!fir.array<?xf32>>, index) -> (index, index, index)
+// CHECK-NEXT:       %[[SUB:.*]] = arith.subi %[[DIMS]]#0, %c1 : index
+// CHECK-NEXT:       %[[ADD:.*]] = arith.addi %[[V14]], %[[SUB]] : index
+// CHECK-NEXT:       %[[V19:.*]] = hlfir.designate %[[V1]]#0 (%[[ADD]]) : (!fir.box<!fir.array<?xf32>>, index) -> !fir.ref<f32>
+// CHECK-NEXT:       %[[V20:.*]] = fir.load %[[V19]] : !fir.ref<f32>
+// CHECK-NEXT:       %[[V21:.*]] = arith.cmpf ogt, %[[V20]], %arg4 fastmath<contract> : f32
+// CHECK-NEXT:       %[[V22:.*]] = fir.if %[[V21]] -> (f32) {
+// CHECK-NEXT:         %[[V23:.*]] = hlfir.designate %{{.*}} (%c1) : (!fir.ref<!fir.array<1xi32>>, index) -> !fir.ref<i32>
+// CHECK-NEXT:         %[[V24:.*]] = fir.convert %[[V14]] : (index) -> i32
+// CHECK-NEXT:         fir.store %[[V24]] to %[[V23]] : !fir.ref<i32>
+// CHECK-NEXT:         fir.result %[[V20]] : f32
+// CHECK-NEXT:       } else {
+// CHECK-NEXT:         fir.result %arg4 : f32
+// CHECK-NEXT:       }
+// CHECK-NEXT:       fir.result %[[V22]] : f32
+// CHECK-NEXT:     } else {
+// CHECK-NEXT:       fir.result %arg4 : f32
+// CHECK-NEXT:     }
+// CHECK-NEXT:     fir.result %[[V18]] : f32
+// CHECK-NEXT:   }
+



More information about the flang-commits mailing list