[flang-commits] [flang] [flang][HLFIR] Optimize MINLOC/MAXLOC for equality masks (PR #186916)
via flang-commits
flang-commits at lists.llvm.org
Mon Mar 23 08:10:27 PDT 2026
https://github.com/anoopkg6 updated https://github.com/llvm/llvm-project/pull/186916
>From c0a9d1369803086be8a5e2c7eb83641efc79f332 Mon Sep 17 00:00:00 2001
From: "anoop.kumar6 at ibm.com" <anoopk at b35lp63.lnxne.boe>
Date: Mon, 16 Mar 2026 15:26:39 +0100
Subject: [PATCH] [flang][HLFIR] Optimize MINLOC/MAXLOC for equality masks
This patch implements `isEqualityMask` to identify when the MASK argument is an equality comparison against an invariant value (e.g., MASK = A == X).
- This allows the SimplifyHLFIRIntrinsics pass to extract the invariant
search target and bypass the creation of a temporary logical mask array
by inlining the equality comparison directly into the reduction loop.
The optimization removes the 'hlfir.apply' of the mask's hlfir.elemental,
leaving that elemental unused so that it gets eliminated in the
bufferize-hlfir pass.
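- Simplifies the reduction state by removing the min/max value tracker:
since every element selected by the mask equals the target, the minimum
(or maximum) over the masked elements is the target value itself.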
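- Implements a "first-hit" locking mechanism: a flag carried through the
reduction loop is cleared at the first match, so only the coordinates of
the first matching element are recorded.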
Test Coverage:
- 1D, 2D, 3D variable/constant equality searches - verified that the optimized path is taken
- Duplicate match handling - verified first-occurrence logic
- No-match cases - verified zero-initialized result
- Different array/non-invariant target - verified safe fallback to the standard lowering
---
.../Transforms/SimplifyHLFIRIntrinsics.cpp | 225 +++++++++++++-
...plify-hlfir-intrinsics-equality-maxloc.fir | 269 +++++++++++++++++
...plify-hlfir-intrinsics-equality-minloc.fir | 274 ++++++++++++++++++
3 files changed, 764 insertions(+), 4 deletions(-)
create mode 100644 flang/test/HLFIR/simplify-hlfir-intrinsics-equality-maxloc.fir
create mode 100644 flang/test/HLFIR/simplify-hlfir-intrinsics-equality-minloc.fir
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
index f47353dc30f64..2d987b6300ab3 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
@@ -32,6 +32,94 @@ namespace hlfir {
#define DEBUG_TYPE "simplify-hlfir-intrinsics"
+namespace {
+/// Look through hlfir.as_expr so that an expression created from a variable
+/// and the variable itself are recognized as the same array.
+mlir::Value stripAsExpr(mlir::Value v) {
+  while (auto asExpr = v.getDefiningOp<hlfir::AsExprOp>())
+    v = asExpr.getVar();
+  return v;
+}
+
+/// If \p value is the element of an array taken at exactly \p indices (the
+/// one-based block arguments of an hlfir.elemental), return that array;
+/// otherwise return a null value.
+///
+/// Only two forms are accepted:
+///  - `hlfir.apply %expr, indices...`
+///  - `fir.load (hlfir.designate %var (indices...))` for a plain element
+///    designator of a variable that has default lower bounds.
+/// Shifted or permuted indices, sections, components, substrings and
+/// conversions of the element are all rejected. The inlined search loads the
+/// search-array element at the loop indices and compares it to the target
+/// as-is, so the mask must read exactly that element, unconverted.
+mlir::Value getIdentityElementBase(mlir::Value value,
+                                   mlir::ValueRange indices) {
+  if (auto apply = value.getDefiningOp<hlfir::ApplyOp>()) {
+    if (!llvm::equal(apply.getIndices(), indices))
+      return {};
+    return apply.getExpr();
+  }
+  if (auto load = value.getDefiningOp<fir::LoadOp>()) {
+    auto des = load.getMemref().getDefiningOp<hlfir::DesignateOp>();
+    if (!des || des.getComponent() || des.getComplexPart() ||
+        !des.getSubstring().empty() ||
+        llvm::is_contained(des.getIsTriplet(), true) ||
+        !llvm::equal(des.getIndices(), indices))
+      return {};
+    // Designator indices are Fortran indices: they coincide with the
+    // one-based elemental indices only when every lower bound is one.
+    if (hlfir::Entity{des.getMemref()}.mayHaveNonDefaultLowerBounds())
+      return {};
+    return des.getMemref();
+  }
+  return {};
+}
+
+/// Check whether \p mask is an element-wise equality comparison of
+/// \p searchArray against a value that is invariant in the mask
+/// (e.g. MASK = A == target), i.e. an hlfir.elemental (possibly behind
+/// as_expr/convert/declare/copy_in wrappers) that yields `arith.cmpi eq`
+/// between the identity element of \p searchArray and a value defined
+/// outside the elemental.
+///
+/// On success returns true and stores the invariant operand in
+/// \p targetVal; on failure returns false and leaves \p targetVal untouched.
+///
+/// NOTE(review): the rewrite enabled by this predicate records the first
+/// match in array element order and builds a rank-one result with one entry
+/// per array dimension. This predicate does not look at BACK or DIM, so the
+/// callers must make sure those arguments are absent (or BACK is false).
+bool isEqualityMask(mlir::Value mask, mlir::Value searchArray,
+                    mlir::Value &targetVal) {
+  if (!mask || !searchArray)
+    return false;
+
+  // Trace back HLFIR/FIR wrappers to get the elemental producer.
+  mlir::Value currentMask = mask;
+  while (mlir::Operation *def = currentMask.getDefiningOp()) {
+    if (!mlir::isa<hlfir::AsExprOp, fir::ConvertOp, hlfir::DeclareOp,
+                   hlfir::CopyInOp>(def))
+      break;
+    currentMask = def->getOperand(0);
+  }
+  auto elemental = currentMask.getDefiningOp<hlfir::ElementalOp>();
+  if (!elemental)
+    return false;
+
+  // Inspect the elemental body to find the yielded boolean.
+  mlir::Region &body = elemental.getRegion();
+  auto yieldOp =
+      mlir::dyn_cast<hlfir::YieldElementOp>(body.front().getTerminator());
+  if (!yieldOp)
+    return false;
+  mlir::Value val = yieldOp.getElementValue();
+  // Ignore the i1 <-> !fir.logical casts around the comparison.
+  while (auto conv = val.getDefiningOp<fir::ConvertOp>())
+    val = conv.getOperand();
+
+  // We currently only optimize integer equality (arith.cmpi eq).
+  auto cmpOp = val.getDefiningOp<mlir::arith::CmpIOp>();
+  if (!cmpOp || cmpOp.getPredicate() != mlir::arith::CmpIPredicate::eq)
+    return false;
+
+  // A value is invariant when it is defined outside the elemental body.
+  // Values defined in regions nested inside the body (block arguments
+  // included) are not invariant.
+  auto isInvariant = [&](mlir::Value v) {
+    return !body.isAncestor(v.getParentRegion());
+  };
+
+  mlir::Value lhs = cmpOp.getLhs();
+  mlir::Value rhs = cmpOp.getRhs();
+  bool lhsInv = isInvariant(lhs);
+  bool rhsInv = isInvariant(rhs);
+  // Exactly one side must be invariant (the target); the other side must
+  // vary with the elemental indices (the array element).
+  if (lhsInv == rhsInv)
+    return false;
+
+  mlir::Value candidateTarget = lhsInv ? lhs : rhs;
+  mlir::Value arraySide = lhsInv ? rhs : lhs;
+
+  // The array side must read the search array at exactly the elemental
+  // position; otherwise inlining the comparison would test other elements.
+  mlir::Value base =
+      getIdentityElementBase(arraySide, elemental.getIndices());
+  if (!base || stripAsExpr(base) != stripAsExpr(searchArray))
+    return false;
+
+  targetVal = candidateTarget;
+  return true;
+}
+} // end anonymous namespace
+
static llvm::cl::opt<bool> forceMatmulAsElemental(
"flang-inline-matmul-as-elemental",
llvm::cl::desc("Expand hlfir.matmul as elemental operation"),
@@ -526,6 +614,15 @@ class MinMaxlocAsElementalConverter : public ReductionAsElementalConverter {
void
checkReductions(const llvm::SmallVectorImpl<mlir::Value> &reductions) const {
+ mlir::Value targetVal;
+ // Check if the mask qualifies for the optimized equality mask search path.
+ if (isEqualityMask(this->getMask(), mlir::cast<T>(this->op).getArray(),
+ targetVal)) {
+ // Expect coordinate indices.
+ assert(reductions.size() == getNumCoors() &&
+ "invalid number of reductions for equality mask MINLOC/MAXLOC");
+ return;
+ }
if (!useIsFirst())
assert(reductions.size() == getNumCoors() + 1 &&
"invalid number of reductions for MINLOC/MAXLOC");
@@ -635,6 +732,51 @@ llvm::SmallVector<mlir::Value>
MinMaxlocAsElementalConverter<T>::reduceOneElement(
const llvm::SmallVectorImpl<mlir::Value> ¤tValue, hlfir::Entity array,
mlir::ValueRange oneBasedIndices) {
+ mlir::Value targetVal;
+ // The mask is an equality comparison (e.g., MASK = A == target) inline the
+ // comparison to find the first occurrence efficiently.
+ if (isEqualityMask(this->getMask(), array, targetVal)) {
+ // Directly load the array element and compare with the targetVal.
+ hlfir::Entity elementValue =
+ hlfir::loadElementAt(loc, builder, array, oneBasedIndices);
+ mlir::Value isMatch = mlir::arith::CmpIOp::create(
+ builder, loc, mlir::arith::CmpIPredicate::eq, (mlir::Value)elementValue,
+ targetVal);
+ // currentValue contains [Coord1, ..., CoordN, FirstHitBool]
+ mlir::Value firstHitBool = currentValue.back();
+ // shouldUpdate is true only if we have a match and we haven't found one
+ // yet.
+ mlir::Value shouldUpdate =
+ mlir::arith::AndIOp::create(builder, loc, isMatch, firstHitBool);
+ // Conditional Update: Only update coordinates if a match is found.
+ auto ifOp = fir::IfOp::create(builder, loc,
+ mlir::ValueRange(currentValue).getTypes(),
+ shouldUpdate, /*withElse=*/true);
+ // If match found and it's the first one, record coordinates.
+ builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+ llvm::SmallVector<mlir::Value> thenResults;
+ unsigned rank = array.getRank();
+ // Get the firstHit flag.
+ for (unsigned i = 0; i < rank; ++i) {
+ mlir::Value loopIdx = builder.createConvert(
+ loc, currentValue[i].getType(), oneBasedIndices[i]);
+ thenResults.emplace_back(loopIdx);
+ }
+
+ // Update the flag: Set to 0 (False) for all future iterations.
+ mlir::Value falseVal =
+ mlir::arith::ConstantIntOp::create(builder, loc, 0, 1);
+ thenResults.emplace_back(falseVal);
+
+ fir::ResultOp::create(builder, loc, thenResults);
+
+ // No match or already found a previous match: maintain the current state.
+ builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
+ fir::ResultOp::create(builder, loc, currentValue);
+
+ builder.setInsertionPointAfter(ifOp);
+ return ifOp.getResults();
+ }
checkReductions(currentValue);
hlfir::Entity elementValue =
hlfir::loadElementAt(loc, builder, array, oneBasedIndices);
@@ -685,6 +827,49 @@ MinMaxlocAsElementalConverter<T>::reduceOneElement(
template <typename T>
hlfir::Entity MinMaxlocAsElementalConverter<T>::genFinalResult(
const llvm::SmallVectorImpl<mlir::Value> &reductionResults) {
+ mlir::Value targetVal;
+ // Finalize results for the equality-mask search.
+ if (isEqualityMask(this->getMask(), mlir::cast<T>(this->op).getArray(),
+ targetVal)) {
+ unsigned rank = getNumCoors();
+ mlir::Type resultElemTy =
+ hlfir::getFortranElementType(this->getResultType());
+ // MINLOC/MAXLOC returns an integer array of shape [rank].
+ // Manually build the HLFIR expression to hold the resulting coordinates.
+ llvm::SmallVector<int64_t> shapeVec{static_cast<int64_t>(rank)};
+ mlir::Type exprTy = hlfir::ExprType::get(builder.getContext(), shapeVec,
+ resultElemTy, false);
+ mlir::Value resRank =
+ builder.createIntegerConstant(loc, builder.getIndexType(), rank);
+ mlir::Value resShape = fir::ShapeOp::create(builder, loc, resRank);
+
+ // Create an elemental operation to map the scalar reduction results
+ // (coordinates) back into a Fortran array result.
+ auto elemental =
+ hlfir::ElementalOp::create(builder, loc, exprTy, resShape,
+ /*mold=*/mlir::Value{},
+ /*typeparams=*/mlir::ValueRange{},
+ /*isUnordered=*/false);
+ {
+ // Fill the elemental body.
+ mlir::OpBuilder::InsertionGuard guard(builder);
+ builder.setInsertionPointToStart(elemental.getBody());
+ // Map the 1-based elemental index, result[i] = reductionResults[i-1].
+ mlir::Value elemIdx = elemental.getIndices()[0];
+ mlir::Value resultVal = reductionResults[0];
+ for (unsigned i = 1; i < rank; ++i) {
+ mlir::Value dimConst =
+ builder.createIntegerConstant(loc, builder.getIndexType(), i + 1);
+ mlir::Value isDimMatch = mlir::arith::CmpIOp::create(
+ builder, loc, mlir::arith::CmpIPredicate::eq, elemIdx, dimConst);
+ // Select specific coordinate matching current elemental dimension.
+ resultVal = mlir::arith::SelectOp::create(
+ builder, loc, isDimMatch, reductionResults[i], resultVal);
+ }
+ hlfir::YieldElementOp::create(builder, loc, resultVal);
+ }
+ return hlfir::Entity{elemental.getResult()};
+ }
// Identification of the final result of MINLOC/MAXLOC:
// * If DIM is absent, the result is rank-one array.
// * If DIM is present:
@@ -1214,9 +1399,39 @@ mlir::LogicalResult ReductionAsElementalConverter::convert() {
extents.push_back(
builder.createConvert(loc, builder.getIndexType(), dimExtent));
- // Initial value for the reduction.
- llvm::SmallVector<mlir::Value, 1> reductionInitValues =
- genReductionInitValues(inputIndices, extents);
+ mlir::Value minMaxMask;
+ if (auto minloc = mlir::dyn_cast<hlfir::MinlocOp>(op)) {
+ minMaxMask = minloc.getMask();
+ } else if (auto maxloc = mlir::dyn_cast<hlfir::MaxlocOp>(op)) {
+ minMaxMask = maxloc.getMask();
+ }
+ mlir::Value targetVal;
+ bool isFixedSearch = false;
+ // Check if the mask allows for a simplified search optimization.
+ if (minMaxMask)
+ isFixedSearch =
+ isEqualityMask(minMaxMask, this->op->getOperand(0), targetVal);
+ llvm::SmallVector<mlir::Value, 1> reductionInitValues;
+ if (isFixedSearch) {
+ // For optimized equality searches, we skip the 'Min/Max value' reduction
+ // and only track coordinate indices and the firstHit flag.
+ unsigned rank = hlfir::Entity{array}.getRank();
+ mlir::Type resElemTy =
+ hlfir::getFortranElementType(this->getResultType());
+ mlir::Value zeroVal = builder.createIntegerConstant(loc, resElemTy, 0);
+
+ // Initialize all coordinates to 0.
+ for (unsigned i = 0; i < rank; ++i) {
+ reductionInitValues.emplace_back(zeroVal);
+ }
+ // First hit flag: [Row, Col, FirstHit=1] (Size: 3)
+ mlir::Type i1Type = builder.getI1Type();
+ mlir::Value firstHitTrue = mlir::arith::ConstantOp::create(
+ builder, loc, i1Type, builder.getBoolAttr(true));
+ reductionInitValues.emplace_back(firstHitTrue);
+ } else {
+ reductionInitValues = genReductionInitValues(inputIndices, extents);
+ }
auto genBody = [&](mlir::Location loc, fir::FirOpBuilder &builder,
mlir::ValueRange oneBasedIndices,
@@ -1238,7 +1453,9 @@ mlir::LogicalResult ReductionAsElementalConverter::convert() {
llvm::transform(reductionValues, std::back_inserter(reductionTypes),
[](mlir::Value v) { return v.getType(); });
fir::IfOp ifOp;
- if (mask) {
+ // Skip standard masking block in case of 'isFixedSearch', as it handles
+ // its own masking logic inside the comparison.
+ if (mask && !isFixedSearch) {
// Make the reduction value update conditional on the value
// of the mask.
if (!maskValue) {
diff --git a/flang/test/HLFIR/simplify-hlfir-intrinsics-equality-maxloc.fir b/flang/test/HLFIR/simplify-hlfir-intrinsics-equality-maxloc.fir
new file mode 100644
index 0000000000000..31925ae41467e
--- /dev/null
+++ b/flang/test/HLFIR/simplify-hlfir-intrinsics-equality-maxloc.fir
@@ -0,0 +1,269 @@
+// RUN: fir-opt %s --simplify-hlfir-intrinsics | FileCheck %s
+
+// Rank 1: Variable: A == %target
+// The equality mask against the function argument %target is inlined: no
+// INT_MIN max-value accumulator is created, the loop carries only the
+// coordinate and a first-hit flag, and the logical mask is never applied.
+func.func @test_maxloc_1d_equality_variable(%arg0: !hlfir.expr<?xi32>, %target: i32) -> !hlfir.expr<1xi32> {
+ %shape = hlfir.shape_of %arg0 : (!hlfir.expr<?xi32>) -> !fir.shape<1>
+ %mask = hlfir.elemental %shape : (!fir.shape<1>) -> !hlfir.expr<?x!fir.logical<4>> {
+ ^bb0(%i: index):
+ %val = hlfir.apply %arg0, %i : (!hlfir.expr<?xi32>, index) -> i32
+ %cmp = arith.cmpi eq, %val, %target : i32
+ %logical = fir.convert %cmp : (i1) -> !fir.logical<4>
+ hlfir.yield_element %logical : !fir.logical<4>
+ }
+ %res = hlfir.maxloc %arg0 mask %mask : (!hlfir.expr<?xi32>, !hlfir.expr<?x!fir.logical<4>>) -> !hlfir.expr<1xi32>
+ return %res : !hlfir.expr<1xi32>
+}
+// CHECK-LABEL: func.func @test_maxloc_1d_equality_variable
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : i32
+// CHECK-DAG: %[[TRUE:.*]] = arith.constant true
+// CHECK-NOT: arith.constant -2147483648 : i32
+// CHECK: fir.do_loop %[[IV:.*]] = {{.*}} iter_args(%[[LOC:.*]] = %[[C0]], %[[FIRST:.*]] = %[[TRUE]]) -> (i32, i1)
+// CHECK-NOT: hlfir.apply {{.*}} !fir.logical<4>
+// CHECK: %[[VAL:.*]] = hlfir.apply %{{.*}}, %[[IV]]
+// CHECK: %[[EQ:.*]] = arith.cmpi eq, %[[VAL]], %{{.*}}
+// CHECK: %[[COND:.*]] = arith.andi %[[EQ]], %[[FIRST]]
+// CHECK: %[[IF_RES:.*]]:2 = fir.if %[[COND]] -> (i32, i1)
+
+// Rank 2: Variable: A == %target
+// Both nested loops carry two zero-initialized coordinates plus the
+// first-hit flag, and the inner loop ANDs the inlined equality test with
+// that flag before updating.
+func.func @test_maxloc_2d_equality_variable(%arg0: !hlfir.expr<?x?xi32>, %target: i32) -> !hlfir.expr<2xi32> {
+ %shape = hlfir.shape_of %arg0 : (!hlfir.expr<?x?xi32>) -> !fir.shape<2>
+ %mask = hlfir.elemental %shape : (!fir.shape<2>) -> !hlfir.expr<?x?x!fir.logical<4>> {
+ ^bb0(%i: index, %j: index):
+ %val = hlfir.apply %arg0, %i, %j : (!hlfir.expr<?x?xi32>, index, index) -> i32
+ %cmp = arith.cmpi eq, %val, %target : i32
+ %logical = fir.convert %cmp : (i1) -> !fir.logical<4>
+ hlfir.yield_element %logical : !fir.logical<4>
+ }
+ %res = hlfir.maxloc %arg0 mask %mask : (!hlfir.expr<?x?xi32>, !hlfir.expr<?x?x!fir.logical<4>>) -> !hlfir.expr<2xi32>
+ return %res : !hlfir.expr<2xi32>
+}
+// CHECK-LABEL: func.func @test_maxloc_2d_equality_variable
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : i32
+// CHECK-DAG: %[[TRUE:.*]] = arith.constant true
+// CHECK: %[[RES_OUTER:.*]]:3 = fir.do_loop %[[IV1:.*]] = {{.*}} iter_args(%[[OUT1:.*]] = %[[C0]], %[[OUT2:.*]] = %[[C0]], %[[OUT3:.*]] = %[[TRUE]]) -> (i32, i32, i1)
+// CHECK: %[[RES_INNER:.*]]:3 = fir.do_loop %[[IV2:.*]] = {{.*}} iter_args(%[[IN1:.*]] = %[[OUT1]], %[[IN2:.*]] = %[[OUT2]], %[[IN3:.*]] = %[[OUT3]]) -> (i32, i32, i1)
+// CHECK-NOT: hlfir.apply {{.*}} !fir.logical<4>
+// CHECK: %[[EQ:.*]] = arith.cmpi eq, {{.*}}
+// CHECK: %[[COND:.*]] = arith.andi %[[EQ]], %[[IN3]]
+// CHECK: %[[IF:.*]]:3 = fir.if %[[COND]] -> (i32, i32, i1)
+
+// Rank 3: Variable: A == %target
+// Three nested loops each carry three zero-initialized coordinates plus the
+// first-hit flag; the innermost loop gates the update on match AND flag.
+func.func @test_maxloc_3d_equality_variable(%arg0: !hlfir.expr<?x?x?xi32>, %target: i32) -> !hlfir.expr<3xi32> {
+ %shape = hlfir.shape_of %arg0 : (!hlfir.expr<?x?x?xi32>) -> !fir.shape<3>
+ %mask = hlfir.elemental %shape : (!fir.shape<3>) -> !hlfir.expr<?x?x?x!fir.logical<4>> {
+ ^bb0(%i: index, %j: index, %k: index):
+ %val = hlfir.apply %arg0, %i, %j, %k : (!hlfir.expr<?x?x?xi32>, index, index, index) -> i32
+ %cmp = arith.cmpi eq, %val, %target : i32
+ %logical = fir.convert %cmp : (i1) -> !fir.logical<4>
+ hlfir.yield_element %logical : !fir.logical<4>
+ }
+ %res = hlfir.maxloc %arg0 mask %mask : (!hlfir.expr<?x?x?xi32>, !hlfir.expr<?x?x?x!fir.logical<4>>) -> !hlfir.expr<3xi32>
+ return %res : !hlfir.expr<3xi32>
+}
+// CHECK-LABEL: func.func @test_maxloc_3d_equality_variable
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : i32
+// CHECK-DAG: %[[TRUE:.*]] = arith.constant true
+// CHECK: %[[OUTER:.*]]:4 = fir.do_loop %[[IV1:.*]] = {{.*}} iter_args(%[[O1:.*]] = %[[C0]], %[[O2:.*]] = %[[C0]], %[[O3:.*]] = %[[C0]], %[[O4:.*]] = %[[TRUE]]) -> (i32, i32, i32, i1)
+// CHECK: %[[MIDDLE:.*]]:4 = fir.do_loop %[[IV2:.*]] = {{.*}} iter_args(%[[M1:.*]] = %[[O1]], %[[M2:.*]] = %[[O2]], %[[M3:.*]] = %[[O3]], %[[M4:.*]] = %[[O4]]) -> (i32, i32, i32, i1)
+// CHECK: %[[INNER:.*]]:4 = fir.do_loop %[[IV3:.*]] = {{.*}} iter_args(%[[I1:.*]] = %[[M1]], %[[I2:.*]] = %[[M2]], %[[I3:.*]] = %[[M3]], %[[I4:.*]] = %[[M4]]) -> (i32, i32, i32, i1)
+// CHECK-NOT: hlfir.apply {{.*}} !fir.logical<4>
+// CHECK: %[[EQ:.*]] = arith.cmpi eq, {{.*}}
+// CHECK: %[[COND:.*]] = arith.andi %[[EQ]], %[[I4]]
+// CHECK: %[[IF:.*]]:4 = fir.if %[[COND]] -> (i32, i32, i32, i1)
+
+// Rank 1: Constant: A == 42
+// Same as the variable case, but the target is an arith.constant defined
+// outside the mask elemental; the inlined comparison must use that constant.
+func.func @test_maxloc_1d_equality_constant(%arg0: !hlfir.expr<?xi32>) -> !hlfir.expr<1xi32> {
+ %c42 = arith.constant 42 : i32
+ %shape = hlfir.shape_of %arg0 : (!hlfir.expr<?xi32>) -> !fir.shape<1>
+ %mask = hlfir.elemental %shape : (!fir.shape<1>) -> !hlfir.expr<?x!fir.logical<4>> {
+ ^bb0(%i: index):
+ %val = hlfir.apply %arg0, %i : (!hlfir.expr<?xi32>, index) -> i32
+ %cmp = arith.cmpi eq, %val, %c42 : i32
+ %logical = fir.convert %cmp : (i1) -> !fir.logical<4>
+ hlfir.yield_element %logical : !fir.logical<4>
+ }
+ %res = hlfir.maxloc %arg0 mask %mask : (!hlfir.expr<?xi32>, !hlfir.expr<?x!fir.logical<4>>) -> !hlfir.expr<1xi32>
+ return %res : !hlfir.expr<1xi32>
+}
+// CHECK-LABEL: func.func @test_maxloc_1d_equality_constant
+// CHECK-DAG: %[[C42:.*]] = arith.constant 42 : i32
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : i32
+// CHECK-DAG: %[[TRUE:.*]] = arith.constant true
+
+// CHECK: %[[RES:.*]]:2 = fir.do_loop %[[IV:.*]] = {{.*}} iter_args(%[[LOC:.*]] = %[[C0]], %[[FIRST:.*]] = %[[TRUE]]) -> (i32, i1)
+// CHECK-NOT: hlfir.apply {{.*}} !fir.logical<4>
+// CHECK: %[[VAL:.*]] = hlfir.apply %{{.*}}, %[[IV]]
+// CHECK: %[[EQ:.*]] = arith.cmpi eq, %[[VAL]], %[[C42]]
+// CHECK: %[[COND:.*]] = arith.andi %[[EQ]], %[[FIRST]]
+// CHECK: %[[IF:.*]]:2 = fir.if %[[COND]] -> (i32, i1)
+
+// Rank 2: Constant: A == 42
+// The element is read with the inner loop index as the first dimension and
+// compared directly against the constant target.
+func.func @test_maxloc_2d_equality_constant(%arg0: !hlfir.expr<?x?xi32>) -> !hlfir.expr<2xi32> {
+ %c42 = arith.constant 42 : i32
+ %shape = hlfir.shape_of %arg0 : (!hlfir.expr<?x?xi32>) -> !fir.shape<2>
+ %mask = hlfir.elemental %shape : (!fir.shape<2>) -> !hlfir.expr<?x?x!fir.logical<4>> {
+ ^bb0(%i: index, %j: index):
+ %val = hlfir.apply %arg0, %i, %j : (!hlfir.expr<?x?xi32>, index, index) -> i32
+ %cmp = arith.cmpi eq, %val, %c42 : i32
+ %logical = fir.convert %cmp : (i1) -> !fir.logical<4>
+ hlfir.yield_element %logical : !fir.logical<4>
+ }
+ %res = hlfir.maxloc %arg0 mask %mask : (!hlfir.expr<?x?xi32>, !hlfir.expr<?x?x!fir.logical<4>>) -> !hlfir.expr<2xi32>
+ return %res : !hlfir.expr<2xi32>
+}
+// CHECK-LABEL: func.func @test_maxloc_2d_equality_constant
+// CHECK-DAG: %[[C42:.*]] = arith.constant 42 : i32
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : i32
+// CHECK-DAG: %[[TRUE:.*]] = arith.constant true
+// CHECK: %[[OUTER:.*]]:3 = fir.do_loop %[[IV1:.*]] = {{.*}} iter_args(%[[O1:.*]] = %[[C0]], %[[O2:.*]] = %[[C0]], %[[O3:.*]] = %[[TRUE]]) -> (i32, i32, i1)
+// CHECK: %[[INNER:.*]]:3 = fir.do_loop %[[IV2:.*]] = {{.*}} iter_args(%[[I1:.*]] = %[[O1]], %[[I2:.*]] = %[[O2]], %[[I3:.*]] = %[[O3]]) -> (i32, i32, i1)
+// CHECK-NOT: hlfir.apply {{.*}} !fir.logical<4>
+// CHECK: %[[VAL:.*]] = hlfir.apply %{{.*}}, %[[IV2]], %[[IV1]]
+// CHECK: %[[EQ:.*]] = arith.cmpi eq, %[[VAL]], %[[C42]]
+// CHECK: %[[COND:.*]] = arith.andi %[[EQ]], %[[I3]]
+
+// Rank 3: Constant: A == 42
+// The innermost loop index addresses the first dimension of the element
+// read, and the inlined comparison uses the constant target.
+func.func @test_maxloc_3d_equality_constant(%arg0: !hlfir.expr<?x?x?xi32>) -> !hlfir.expr<3xi32> {
+ %c42 = arith.constant 42 : i32
+ %shape = hlfir.shape_of %arg0 : (!hlfir.expr<?x?x?xi32>) -> !fir.shape<3>
+ %mask = hlfir.elemental %shape : (!fir.shape<3>) -> !hlfir.expr<?x?x?x!fir.logical<4>> {
+ ^bb0(%i: index, %j: index, %k: index):
+ %val = hlfir.apply %arg0, %i, %j, %k : (!hlfir.expr<?x?x?xi32>, index, index, index) -> i32
+ %cmp = arith.cmpi eq, %val, %c42 : i32
+ %logical = fir.convert %cmp : (i1) -> !fir.logical<4>
+ hlfir.yield_element %logical : !fir.logical<4>
+ }
+ %res = hlfir.maxloc %arg0 mask %mask : (!hlfir.expr<?x?x?xi32>, !hlfir.expr<?x?x?x!fir.logical<4>>) -> !hlfir.expr<3xi32>
+ return %res : !hlfir.expr<3xi32>
+}
+// CHECK-LABEL: func.func @test_maxloc_3d_equality_constant
+// CHECK-DAG: %[[C42:.*]] = arith.constant 42 : i32
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : i32
+// CHECK-DAG: %[[TRUE:.*]] = arith.constant true
+// CHECK: %[[OUT:.*]]:4 = fir.do_loop {{.*}} iter_args(%[[O1:.*]] = %[[C0]], %[[O2:.*]] = %[[C0]], %[[O3:.*]] = %[[C0]], %[[O4:.*]] = %[[TRUE]]) -> (i32, i32, i32, i1)
+// CHECK: %[[MID:.*]]:4 = fir.do_loop {{.*}} iter_args(%[[M1:.*]] = %[[O1]], %[[M2:.*]] = %[[O2]], %[[M3:.*]] = %[[O3]], %[[M4:.*]] = %[[O4]]) -> (i32, i32, i32, i1)
+// CHECK: %[[INN:.*]]:4 = fir.do_loop %[[IV3:.*]] = {{.*}} iter_args(%[[I1:.*]] = %[[M1]], %[[I2:.*]] = %[[M2]], %[[I3:.*]] = %[[M3]], %[[I4:.*]] = %[[M4]]) -> (i32, i32, i32, i1)
+// CHECK-NOT: hlfir.apply {{.*}} !fir.logical<4>
+// CHECK: %[[VAL:.*]] = hlfir.apply %{{.*}}, %[[IV3]], %{{.*}}, %{{.*}}
+// CHECK: %[[EQ:.*]] = arith.cmpi eq, %[[VAL]], %[[C42]]
+// CHECK: %[[COND:.*]] = arith.andi %[[EQ]], %[[I4]]
+
+// No Match: Result must be 0
+// The coordinate starts at 0 and is only overwritten in the then-branch of
+// the first-hit update, so it keeps its zero initial value when no element
+// equals 99. The false constant is captured rather than matched by its
+// printed SSA name, as in test_maxloc_first_match.
+func.func @test_maxloc_no_match(%arg0: !hlfir.expr<?xi32>) -> !hlfir.expr<1xi32> {
+ %c99 = arith.constant 99 : i32
+ %shape = hlfir.shape_of %arg0 : (!hlfir.expr<?xi32>) -> !fir.shape<1>
+ %mask = hlfir.elemental %shape : (!fir.shape<1>) -> !hlfir.expr<?x!fir.logical<4>> {
+ ^bb0(%i: index):
+ %val = hlfir.apply %arg0, %i : (!hlfir.expr<?xi32>, index) -> i32
+ %cmp = arith.cmpi eq, %val, %c99 : i32
+ %logical = fir.convert %cmp : (i1) -> !fir.logical<4>
+ hlfir.yield_element %logical : !fir.logical<4>
+ }
+ %res = hlfir.maxloc %arg0 mask %mask : (!hlfir.expr<?xi32>, !hlfir.expr<?x!fir.logical<4>>) -> !hlfir.expr<1xi32>
+ return %res : !hlfir.expr<1xi32>
+}
+// CHECK-LABEL: func.func @test_maxloc_no_match(
+// CHECK-SAME: %[[ARRAY_NM:.*]]: !hlfir.expr<?xi32>)
+// CHECK-DAG: %[[C99:.*]] = arith.constant 99 : i32
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : i32
+// CHECK-DAG: %[[TRUE:.*]] = arith.constant true
+// CHECK-DAG: %[[FALSE:.*]] = arith.constant false
+// CHECK: fir.do_loop %[[IV:.*]] = {{.*}} iter_args(%[[LOC:.*]] = %[[C0]], %[[FIRST:.*]] = %[[TRUE]]) -> (i32, i1)
+// CHECK-NOT: hlfir.apply {{.*}} !fir.logical<4>
+// CHECK: %[[VAL:.*]] = hlfir.apply %[[ARRAY_NM]], %[[IV]]
+// CHECK: %[[MATCH:.*]] = arith.cmpi eq, %[[VAL]], %[[C99]] : i32
+// CHECK: %[[COND:.*]] = arith.andi %[[MATCH]], %[[FIRST]] : i1
+// CHECK: %[[IF_RES:.*]]:2 = fir.if %[[COND]] -> (i32, i1) {
+// CHECK: %[[CONV:.*]] = fir.convert %[[IV]]
+// CHECK: fir.result %[[CONV]], %[[FALSE]] : i32, i1
+// CHECK: } else {
+// CHECK: fir.result %[[LOC]], %[[FIRST]] : i32, i1
+// CHECK: }
+
+// First Match: Duplicate values
+// The then-branch records the loop index and clears the first-hit flag, so
+// any later match takes the else-branch and keeps the first recorded index.
+func.func @test_maxloc_first_match(%arg0: !hlfir.expr<?xi32>, %target: i32) -> !hlfir.expr<1xi32> {
+ %shape = hlfir.shape_of %arg0 : (!hlfir.expr<?xi32>) -> !fir.shape<1>
+ %mask = hlfir.elemental %shape : (!fir.shape<1>) -> !hlfir.expr<?x!fir.logical<4>> {
+ ^bb0(%i: index):
+ %val = hlfir.apply %arg0, %i : (!hlfir.expr<?xi32>, index) -> i32
+ %cmp = arith.cmpi eq, %val, %target : i32
+ %logical = fir.convert %cmp : (i1) -> !fir.logical<4>
+ hlfir.yield_element %logical : !fir.logical<4>
+ }
+ %res = hlfir.maxloc %arg0 mask %mask : (!hlfir.expr<?xi32>, !hlfir.expr<?x!fir.logical<4>>) -> !hlfir.expr<1xi32>
+ return %res : !hlfir.expr<1xi32>
+}
+// CHECK-LABEL: func.func @test_maxloc_first_match(
+// CHECK-SAME: %[[ARRAY_FM:.*]]: !hlfir.expr<?xi32>, %[[TARGET_FM:.*]]: i32)
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : i32
+// CHECK-DAG: %[[TRUE:.*]] = arith.constant true
+// CHECK-DAG: %[[FALSE:.*]] = arith.constant false
+// Verify loop has only 2 iter_args (Coord, FirstHitFlag)
+// CHECK: fir.do_loop %[[IV:.*]] = {{.*}} iter_args(%[[LOC:.*]] = %[[C0]], %[[FIRST:.*]] = %[[TRUE]]) -> (i32, i1)
+// Verify mask elemental is bypassed
+// CHECK-NOT: hlfir.apply {{.*}} !fir.logical<4>
+// Verify the "Locking" logic: (Match 'and' is_first)
+// CHECK: %[[VAL:.*]] = hlfir.apply %[[ARRAY_FM]], %[[IV]]
+// CHECK: %[[MATCH:.*]] = arith.cmpi eq, %[[VAL]], %[[TARGET_FM]] : i32
+// CHECK: %[[COND:.*]] = arith.andi %[[MATCH]], %[[FIRST]] : i1
+// Verify that once a match is found, we result in %false to lock it
+// CHECK: %[[IF_RES:.*]]:2 = fir.if %[[COND]] -> (i32, i1) {
+// CHECK: %[[CONV:.*]] = fir.convert %[[IV]]
+// CHECK: fir.result %[[CONV]], %[[FALSE]] : i32, i1
+// CHECK: } else {
+// CHECK: fir.result %[[LOC]], %[[FIRST]] : i32, i1
+// CHECK: }
+
+// Negative test: Mask refers to a different array (%arg1) than the search
+// array (%arg0).
+// Expected: the standard MAXLOC lowering, i.e. a max-value accumulator
+// initialized to INT_MIN and the logical mask applied inside the loop.
+func.func @test_maxloc_different_arrays(%arg0: !hlfir.expr<?xi32>, %arg1: !hlfir.expr<?xi32>, %target: i32) -> !hlfir.expr<1xi32> {
+ %shape = hlfir.shape_of %arg0 : (!hlfir.expr<?xi32>) -> !fir.shape<1>
+ %mask = hlfir.elemental %shape : (!fir.shape<1>) -> !hlfir.expr<?x!fir.logical<4>> {
+ ^bb0(%i: index):
+ // Optimization should fail here because %arg1 != %arg0
+ %val_b = hlfir.apply %arg1, %i : (!hlfir.expr<?xi32>, index) -> i32
+ %cmp = arith.cmpi eq, %val_b, %target : i32
+ %logical = fir.convert %cmp : (i1) -> !fir.logical<4>
+ hlfir.yield_element %logical : !fir.logical<4>
+ }
+ %res = hlfir.maxloc %arg0 mask %mask : (!hlfir.expr<?xi32>, !hlfir.expr<?x!fir.logical<4>>) -> !hlfir.expr<1xi32>
+ return %res : !hlfir.expr<1xi32>
+}
+// CHECK-LABEL: func.func @test_maxloc_different_arrays(
+// CHECK-SAME: %[[ARRAY_A:.*]]: !hlfir.expr<?xi32>, %[[ARRAY_B:.*]]: !hlfir.expr<?xi32>, %[[TARGET:.*]]: i32)
+// CHECK: %[[SENTINEL:.*]] = arith.constant -2147483648 : i32
+// Verify the loop uses three iter_args (standard path)
+// CHECK: fir.do_loop %[[IV:.*]] = {{.*}} iter_args(%[[LOC:.*]] = {{.*}}, %[[MAX:.*]] = %[[SENTINEL]], %[[FIRST:.*]] = {{.*}}) -> (i32, i32, i1)
+// Verify the mask is applied (Since we can't inline it safely)
+// CHECK: %[[MASK_VAL:.*]] = hlfir.apply {{.*}} : (!hlfir.expr<?x!fir.logical<4>>, index) -> !fir.logical<4>
+// CHECK: %[[MASK_I1:.*]] = fir.convert %[[MASK_VAL]] : (!fir.logical<4>) -> i1
+// CHECK: fir.if %[[MASK_I1]] -> (i32, i32, i1) {
+// CHECK: %[[VAL_A:.*]] = hlfir.apply %[[ARRAY_A]], %[[IV]]
+// CHECK: %[[CMP:.*]] = arith.cmpi sgt, %[[VAL_A]], %[[MAX]]
+
+// Negative Test: The target value is another array, so it is not invariant.
+// Both comparison operands are defined inside the mask elemental, so the
+// standard lowering (INT_MIN accumulator, mask applied in the loop) is kept.
+func.func @test_maxloc_non_invariant_target(%arg0: !hlfir.expr<?xi32>, %arg1: !hlfir.expr<?xi32>) -> !hlfir.expr<1xi32> {
+ %shape = hlfir.shape_of %arg0 : (!hlfir.expr<?xi32>) -> !fir.shape<1>
+ %mask = hlfir.elemental %shape : (!fir.shape<1>) -> !hlfir.expr<?x!fir.logical<4>> {
+ ^bb0(%i: index):
+ %val_a = hlfir.apply %arg0, %i : (!hlfir.expr<?xi32>, index) -> i32
+ %val_target = hlfir.apply %arg1, %i : (!hlfir.expr<?xi32>, index) -> i32
+ // Optimization should fail here because %val_target is defined inside the elemental
+ %cmp = arith.cmpi eq, %val_a, %val_target : i32
+ %logical = fir.convert %cmp : (i1) -> !fir.logical<4>
+ hlfir.yield_element %logical : !fir.logical<4>
+ }
+ %res = hlfir.maxloc %arg0 mask %mask : (!hlfir.expr<?xi32>, !hlfir.expr<?x!fir.logical<4>>) -> !hlfir.expr<1xi32>
+ return %res : !hlfir.expr<1xi32>
+}
+// CHECK-LABEL: func.func @test_maxloc_non_invariant_target(
+// CHECK-SAME: %[[ARRAY_A:.*]]: !hlfir.expr<?xi32>, %[[ARRAY_B:.*]]: !hlfir.expr<?xi32>)
+// CHECK: %[[SENTINEL:.*]] = arith.constant -2147483648 : i32
+// Verify the loop uses three iter_args (Standard path)
+// CHECK: fir.do_loop %[[IV:.*]] = {{.*}} iter_args(%[[LOC:.*]] = {{.*}}, %[[MAX:.*]] = %[[SENTINEL]], %[[FIRST:.*]] = {{.*}}) -> (i32, i32, i1)
+// Verify the mask is still applied (because we couldn't inline the comparison)
+// CHECK: %[[MASK_BIT:.*]] = hlfir.apply %{{.*}}, %[[IV]] : (!hlfir.expr<?x!fir.logical<4>>, index) -> !fir.logical<4>
+// CHECK: %[[MASK_I1:.*]] = fir.convert %[[MASK_BIT]] : (!fir.logical<4>) -> i1
+// CHECK: fir.if %[[MASK_I1]] -> (i32, i32, i1) {
+// CHECK: %[[VAL:.*]] = hlfir.apply %[[ARRAY_A]], %[[IV]]
+// CHECK: %[[CMP:.*]] = arith.cmpi sgt, %[[VAL]], %[[MAX]]
+
diff --git a/flang/test/HLFIR/simplify-hlfir-intrinsics-equality-minloc.fir b/flang/test/HLFIR/simplify-hlfir-intrinsics-equality-minloc.fir
new file mode 100644
index 0000000000000..0bfa58968a2fe
--- /dev/null
+++ b/flang/test/HLFIR/simplify-hlfir-intrinsics-equality-minloc.fir
@@ -0,0 +1,274 @@
+// RUN: fir-opt %s --simplify-hlfir-intrinsics | FileCheck %s
+
+// Rank 1, variable target: MASK = (A == %target). The mask elemental must be
+// bypassed and the equality test inlined into the reduction loop.
+func.func @test_minloc_1d_equality_variable(%arg0: !hlfir.expr<?xi32>, %target: i32) -> !hlfir.expr<1xi32> {
+ %shape = hlfir.shape_of %arg0 : (!hlfir.expr<?xi32>) -> !fir.shape<1>
+ %mask = hlfir.elemental %shape : (!fir.shape<1>) -> !hlfir.expr<?x!fir.logical<4>> {
+ ^bb0(%i: index):
+ %val = hlfir.apply %arg0, %i : (!hlfir.expr<?xi32>, index) -> i32
+ %cmp = arith.cmpi eq, %val, %target : i32
+ %logical = fir.convert %cmp : (i1) -> !fir.logical<4>
+ hlfir.yield_element %logical : !fir.logical<4>
+ }
+ %res = hlfir.minloc %arg0 mask %mask : (!hlfir.expr<?xi32>, !hlfir.expr<?x!fir.logical<4>>) -> !hlfir.expr<1xi32>
+ return %res : !hlfir.expr<1xi32>
+}
+// CHECK-LABEL: func.func @test_minloc_1d_equality_variable(
+// CHECK-SAME: %[[ARRAY:.*]]: !hlfir.expr<?xi32>, %[[TARGET:.*]]: i32)
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : i32
+// CHECK-DAG: %[[TRUE:.*]] = arith.constant true
+// No running-minimum tracker, so no INT32_MAX sentinel is materialized.
+// CHECK-NOT: arith.constant 2147483647 : i32
+// CHECK: fir.do_loop %[[IV:.*]] = {{.*}} iter_args(%[[LOC:.*]] = %[[C0]], %[[FIRST:.*]] = %[[TRUE]]) -> (i32, i1)
+// CHECK-NOT: hlfir.apply {{.*}} !fir.logical<4>
+// The inlined equality must load from the searched array and compare against
+// the invariant target argument.
+// CHECK: %[[VAL:.*]] = hlfir.apply %[[ARRAY]], %[[IV]]
+// CHECK: %[[EQ:.*]] = arith.cmpi eq, %[[VAL]], %[[TARGET]]
+// CHECK: %[[COND:.*]] = arith.andi %[[EQ]], %[[FIRST]]
+// CHECK: %[[IF_RES:.*]]:2 = fir.if %[[COND]] -> (i32, i1)
+
+// Rank 2, variable target: MASK = (A == %target). The equality test must be
+// inlined into the innermost loop and compare against the invariant target.
+func.func @test_minloc_2d_equality_variable(%arg0: !hlfir.expr<?x?xi32>, %target: i32) -> !hlfir.expr<2xi32> {
+ %shape = hlfir.shape_of %arg0 : (!hlfir.expr<?x?xi32>) -> !fir.shape<2>
+ %mask = hlfir.elemental %shape : (!fir.shape<2>) -> !hlfir.expr<?x?x!fir.logical<4>> {
+ ^bb0(%i: index, %j: index):
+ %val = hlfir.apply %arg0, %i, %j : (!hlfir.expr<?x?xi32>, index, index) -> i32
+ %cmp = arith.cmpi eq, %val, %target : i32
+ %logical = fir.convert %cmp : (i1) -> !fir.logical<4>
+ hlfir.yield_element %logical : !fir.logical<4>
+ }
+ %res = hlfir.minloc %arg0 mask %mask : (!hlfir.expr<?x?xi32>, !hlfir.expr<?x?x!fir.logical<4>>) -> !hlfir.expr<2xi32>
+ return %res : !hlfir.expr<2xi32>
+}
+// CHECK-LABEL: func.func @test_minloc_2d_equality_variable(
+// CHECK-SAME: %{{.*}}: !hlfir.expr<?x?xi32>, %[[TARGET:.*]]: i32)
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : i32
+// CHECK-DAG: %[[TRUE:.*]] = arith.constant true
+// CHECK: %[[RES_OUTER:.*]]:3 = fir.do_loop %[[IV1:.*]] = {{.*}} iter_args(%[[OUT1:.*]] = %[[C0]], %[[OUT2:.*]] = %[[C0]], %[[OUT3:.*]] = %[[TRUE]]) -> (i32, i32, i1)
+// CHECK: %[[RES_INNER:.*]]:3 = fir.do_loop %[[IV2:.*]] = {{.*}} iter_args(%[[IN1:.*]] = %[[OUT1]], %[[IN2:.*]] = %[[OUT2]], %[[IN3:.*]] = %[[OUT3]]) -> (i32, i32, i1)
+// CHECK-NOT: hlfir.apply {{.*}} !fir.logical<4>
+// CHECK: %[[VAL:.*]] = hlfir.apply %{{.*}}, %[[IV2]], %[[IV1]]
+// CHECK: %[[EQ:.*]] = arith.cmpi eq, %[[VAL]], %[[TARGET]]
+// CHECK: %[[COND:.*]] = arith.andi %[[EQ]], %[[IN3]]
+// CHECK: %[[IF:.*]]:3 = fir.if %[[COND]] -> (i32, i32, i1)
+
+// Rank 3, variable target: MASK = (A == %target). The equality test must be
+// inlined into the innermost loop and compare against the invariant target.
+func.func @test_minloc_3d_equality_variable(%arg0: !hlfir.expr<?x?x?xi32>, %target: i32) -> !hlfir.expr<3xi32> {
+ %shape = hlfir.shape_of %arg0 : (!hlfir.expr<?x?x?xi32>) -> !fir.shape<3>
+ %mask = hlfir.elemental %shape : (!fir.shape<3>) -> !hlfir.expr<?x?x?x!fir.logical<4>> {
+ ^bb0(%i: index, %j: index, %k: index):
+ %val = hlfir.apply %arg0, %i, %j, %k : (!hlfir.expr<?x?x?xi32>, index, index, index) -> i32
+ %cmp = arith.cmpi eq, %val, %target : i32
+ %logical = fir.convert %cmp : (i1) -> !fir.logical<4>
+ hlfir.yield_element %logical : !fir.logical<4>
+ }
+ %res = hlfir.minloc %arg0 mask %mask : (!hlfir.expr<?x?x?xi32>, !hlfir.expr<?x?x?x!fir.logical<4>>) -> !hlfir.expr<3xi32>
+ return %res : !hlfir.expr<3xi32>
+}
+// CHECK-LABEL: func.func @test_minloc_3d_equality_variable(
+// CHECK-SAME: %{{.*}}: !hlfir.expr<?x?x?xi32>, %[[TARGET:.*]]: i32)
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : i32
+// CHECK-DAG: %[[TRUE:.*]] = arith.constant true
+// CHECK: %[[OUTER:.*]]:4 = fir.do_loop %[[IV1:.*]] = {{.*}} iter_args(%[[O1:.*]] = %[[C0]], %[[O2:.*]] = %[[C0]], %[[O3:.*]] = %[[C0]], %[[O4:.*]] = %[[TRUE]]) -> (i32, i32, i32, i1)
+// CHECK: %[[MIDDLE:.*]]:4 = fir.do_loop %[[IV2:.*]] = {{.*}} iter_args(%[[M1:.*]] = %[[O1]], %[[M2:.*]] = %[[O2]], %[[M3:.*]] = %[[O3]], %[[M4:.*]] = %[[O4]]) -> (i32, i32, i32, i1)
+// CHECK: %[[INNER:.*]]:4 = fir.do_loop %[[IV3:.*]] = {{.*}} iter_args(%[[I1:.*]] = %[[M1]], %[[I2:.*]] = %[[M2]], %[[I3:.*]] = %[[M3]], %[[I4:.*]] = %[[M4]]) -> (i32, i32, i32, i1)
+// CHECK-NOT: hlfir.apply {{.*}} !fir.logical<4>
+// CHECK: %[[VAL:.*]] = hlfir.apply %{{.*}}, %[[IV3]], %{{.*}}, %{{.*}}
+// CHECK: %[[EQ:.*]] = arith.cmpi eq, %[[VAL]], %[[TARGET]]
+// CHECK: %[[COND:.*]] = arith.andi %[[EQ]], %[[I4]]
+// CHECK: %[[IF:.*]]:4 = fir.if %[[COND]] -> (i32, i32, i32, i1)
+
+// Rank 1, constant target: MASK = (A == 42). The constant is hoisted out of
+// the mask and compared directly inside the reduction loop.
+func.func @test_minloc_1d_equality_constant(%arg0: !hlfir.expr<?xi32>) -> !hlfir.expr<1xi32> {
+ %needle = arith.constant 42 : i32
+ %a_shape = hlfir.shape_of %arg0 : (!hlfir.expr<?xi32>) -> !fir.shape<1>
+ %hits = hlfir.elemental %a_shape : (!fir.shape<1>) -> !hlfir.expr<?x!fir.logical<4>> {
+ ^bb0(%pos: index):
+ %elem = hlfir.apply %arg0, %pos : (!hlfir.expr<?xi32>, index) -> i32
+ %same = arith.cmpi eq, %elem, %needle : i32
+ %same_log = fir.convert %same : (i1) -> !fir.logical<4>
+ hlfir.yield_element %same_log : !fir.logical<4>
+ }
+ %where = hlfir.minloc %arg0 mask %hits : (!hlfir.expr<?xi32>, !hlfir.expr<?x!fir.logical<4>>) -> !hlfir.expr<1xi32>
+ return %where : !hlfir.expr<1xi32>
+}
+// CHECK-LABEL: func.func @test_minloc_1d_equality_constant
+// CHECK-DAG: %[[K42:.*]] = arith.constant 42 : i32
+// CHECK-DAG: %[[ZERO:.*]] = arith.constant 0 : i32
+// CHECK-DAG: %[[T:.*]] = arith.constant true
+// CHECK: %[[LOOP:.*]]:2 = fir.do_loop %[[I:.*]] = {{.*}} iter_args(%[[POS:.*]] = %[[ZERO]], %[[UNSEEN:.*]] = %[[T]]) -> (i32, i1)
+// CHECK-NOT: hlfir.apply {{.*}} !fir.logical<4>
+// CHECK: %[[ELEM:.*]] = hlfir.apply %{{.*}}, %[[I]]
+// CHECK: %[[HIT:.*]] = arith.cmpi eq, %[[ELEM]], %[[K42]]
+// CHECK: %[[TAKE:.*]] = arith.andi %[[HIT]], %[[UNSEEN]]
+// CHECK: %[[SEL:.*]]:2 = fir.if %[[TAKE]] -> (i32, i1)
+
+// Rank 2, constant target: MASK = (A == 42). The constant is compared
+// directly in the innermost loop, and the update is guarded by a fir.if.
+func.func @test_minloc_2d_equality_constant(%arg0: !hlfir.expr<?x?xi32>) -> !hlfir.expr<2xi32> {
+ %c42 = arith.constant 42 : i32
+ %shape = hlfir.shape_of %arg0 : (!hlfir.expr<?x?xi32>) -> !fir.shape<2>
+ %mask = hlfir.elemental %shape : (!fir.shape<2>) -> !hlfir.expr<?x?x!fir.logical<4>> {
+ ^bb0(%i: index, %j: index):
+ %val = hlfir.apply %arg0, %i, %j : (!hlfir.expr<?x?xi32>, index, index) -> i32
+ %cmp = arith.cmpi eq, %val, %c42 : i32
+ %logical = fir.convert %cmp : (i1) -> !fir.logical<4>
+ hlfir.yield_element %logical : !fir.logical<4>
+ }
+ %res = hlfir.minloc %arg0 mask %mask : (!hlfir.expr<?x?xi32>, !hlfir.expr<?x?x!fir.logical<4>>) -> !hlfir.expr<2xi32>
+ return %res : !hlfir.expr<2xi32>
+}
+// CHECK-LABEL: func.func @test_minloc_2d_equality_constant
+// CHECK-DAG: %[[C42:.*]] = arith.constant 42 : i32
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : i32
+// CHECK-DAG: %[[TRUE:.*]] = arith.constant true
+// CHECK: %[[OUTER:.*]]:3 = fir.do_loop %[[IV1:.*]] = {{.*}} iter_args(%[[O1:.*]] = %[[C0]], %[[O2:.*]] = %[[C0]], %[[O3:.*]] = %[[TRUE]]) -> (i32, i32, i1)
+// CHECK: %[[INNER:.*]]:3 = fir.do_loop %[[IV2:.*]] = {{.*}} iter_args(%[[I1:.*]] = %[[O1]], %[[I2:.*]] = %[[O2]], %[[I3:.*]] = %[[O3]]) -> (i32, i32, i1)
+// CHECK-NOT: hlfir.apply {{.*}} !fir.logical<4>
+// CHECK: %[[VAL:.*]] = hlfir.apply %{{.*}}, %[[IV2]], %[[IV1]]
+// CHECK: %[[EQ:.*]] = arith.cmpi eq, %[[VAL]], %[[C42]]
+// CHECK: %[[COND:.*]] = arith.andi %[[EQ]], %[[I3]]
+// The combined match-and-first-hit flag guards the location update.
+// CHECK: %[[IF:.*]]:3 = fir.if %[[COND]] -> (i32, i32, i1)
+
+// Rank 3, constant target: MASK = (A == 42). The constant is compared
+// directly in the innermost loop, and the update is guarded by a fir.if.
+func.func @test_minloc_3d_equality_constant(%arg0: !hlfir.expr<?x?x?xi32>) -> !hlfir.expr<3xi32> {
+ %c42 = arith.constant 42 : i32
+ %shape = hlfir.shape_of %arg0 : (!hlfir.expr<?x?x?xi32>) -> !fir.shape<3>
+ %mask = hlfir.elemental %shape : (!fir.shape<3>) -> !hlfir.expr<?x?x?x!fir.logical<4>> {
+ ^bb0(%i: index, %j: index, %k: index):
+ %val = hlfir.apply %arg0, %i, %j, %k : (!hlfir.expr<?x?x?xi32>, index, index, index) -> i32
+ %cmp = arith.cmpi eq, %val, %c42 : i32
+ %logical = fir.convert %cmp : (i1) -> !fir.logical<4>
+ hlfir.yield_element %logical : !fir.logical<4>
+ }
+ %res = hlfir.minloc %arg0 mask %mask : (!hlfir.expr<?x?x?xi32>, !hlfir.expr<?x?x?x!fir.logical<4>>) -> !hlfir.expr<3xi32>
+ return %res : !hlfir.expr<3xi32>
+}
+// CHECK-LABEL: func.func @test_minloc_3d_equality_constant
+// CHECK-DAG: %[[C42:.*]] = arith.constant 42 : i32
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : i32
+// CHECK-DAG: %[[TRUE:.*]] = arith.constant true
+// CHECK: %[[OUT:.*]]:4 = fir.do_loop {{.*}} iter_args(%[[O1:.*]] = %[[C0]], %[[O2:.*]] = %[[C0]], %[[O3:.*]] = %[[C0]], %[[O4:.*]] = %[[TRUE]]) -> (i32, i32, i32, i1)
+// CHECK: %[[MID:.*]]:4 = fir.do_loop {{.*}} iter_args(%[[M1:.*]] = %[[O1]], %[[M2:.*]] = %[[O2]], %[[M3:.*]] = %[[O3]], %[[M4:.*]] = %[[O4]]) -> (i32, i32, i32, i1)
+// CHECK: %[[INN:.*]]:4 = fir.do_loop %[[IV3:.*]] = {{.*}} iter_args(%[[I1:.*]] = %[[M1]], %[[I2:.*]] = %[[M2]], %[[I3:.*]] = %[[M3]], %[[I4:.*]] = %[[M4]]) -> (i32, i32, i32, i1)
+// CHECK-NOT: hlfir.apply {{.*}} !fir.logical<4>
+// CHECK: %[[VAL:.*]] = hlfir.apply %{{.*}}, %[[IV3]], %{{.*}}, %{{.*}}
+// CHECK: %[[EQ:.*]] = arith.cmpi eq, %[[VAL]], %[[C42]]
+// CHECK: %[[COND:.*]] = arith.andi %[[EQ]], %[[I4]]
+// The combined match-and-first-hit flag guards the location update.
+// CHECK: %[[IF:.*]]:4 = fir.if %[[COND]] -> (i32, i32, i32, i1)
+
+// No match: the location accumulator starts at 0 and is only overwritten on a
+// hit, so a search value that never occurs yields a result of 0.
+func.func @test_minloc_no_match(%arg0: !hlfir.expr<?xi32>) -> !hlfir.expr<1xi32> {
+ %c99 = arith.constant 99 : i32
+ %shape = hlfir.shape_of %arg0 : (!hlfir.expr<?xi32>) -> !fir.shape<1>
+ %mask = hlfir.elemental %shape : (!fir.shape<1>) -> !hlfir.expr<?x!fir.logical<4>> {
+ ^bb0(%i: index):
+ %val = hlfir.apply %arg0, %i : (!hlfir.expr<?xi32>, index) -> i32
+ %cmp = arith.cmpi eq, %val, %c99 : i32
+ %logical = fir.convert %cmp : (i1) -> !fir.logical<4>
+ hlfir.yield_element %logical : !fir.logical<4>
+ }
+ %res = hlfir.minloc %arg0 mask %mask : (!hlfir.expr<?xi32>, !hlfir.expr<?x!fir.logical<4>>) -> !hlfir.expr<1xi32>
+ return %res : !hlfir.expr<1xi32>
+}
+// CHECK-LABEL: func.func @test_minloc_no_match(
+// CHECK-SAME: %[[ARRAY_NM:.*]]: !hlfir.expr<?xi32>)
+// CHECK-DAG: %[[C99:.*]] = arith.constant 99 : i32
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : i32
+// CHECK-DAG: %[[TRUE:.*]] = arith.constant true
+// CHECK-DAG: %[[FALSE:.*]] = arith.constant false
+// The location starts at 0; it is the result whenever no element matches.
+// CHECK: fir.do_loop %[[IV:.*]] = {{.*}} iter_args(%[[LOC:.*]] = %[[C0]], %[[FIRST:.*]] = %[[TRUE]]) -> (i32, i1)
+// CHECK-NOT: hlfir.apply {{.*}} !fir.logical<4>
+// CHECK: %[[VAL:.*]] = hlfir.apply %[[ARRAY_NM]], %[[IV]]
+// CHECK: %[[MATCH:.*]] = arith.cmpi eq, %[[VAL]], %[[C99]] : i32
+// CHECK: %[[COND:.*]] = arith.andi %[[MATCH]], %[[FIRST]] : i1
+// CHECK: %[[IF_RES:.*]]:2 = fir.if %[[COND]] -> (i32, i1) {
+// CHECK: %[[CONV:.*]] = fir.convert %[[IV]]
+// CHECK: fir.result %[[CONV]], %[[FALSE]] : i32, i1
+// CHECK: } else {
+// CHECK: fir.result %[[LOC]], %[[FIRST]] : i32, i1
+// CHECK: }
+
+// First-hit semantics: once a matching element is recorded, the flag carried
+// through the loop drops to false, so later duplicates of the target cannot
+// move the result.
+func.func @test_minloc_first_match(%arg0: !hlfir.expr<?xi32>, %target: i32) -> !hlfir.expr<1xi32> {
+ %extents = hlfir.shape_of %arg0 : (!hlfir.expr<?xi32>) -> !fir.shape<1>
+ %is_target = hlfir.elemental %extents : (!fir.shape<1>) -> !hlfir.expr<?x!fir.logical<4>> {
+ ^bb0(%k: index):
+ %a_k = hlfir.apply %arg0, %k : (!hlfir.expr<?xi32>, index) -> i32
+ %eq_k = arith.cmpi eq, %a_k, %target : i32
+ %eq_k_log = fir.convert %eq_k : (i1) -> !fir.logical<4>
+ hlfir.yield_element %eq_k_log : !fir.logical<4>
+ }
+ %first = hlfir.minloc %arg0 mask %is_target : (!hlfir.expr<?xi32>, !hlfir.expr<?x!fir.logical<4>>) -> !hlfir.expr<1xi32>
+ return %first : !hlfir.expr<1xi32>
+}
+// CHECK-LABEL: func.func @test_minloc_first_match(
+// CHECK-SAME: %[[SRC:.*]]: !hlfir.expr<?xi32>, %[[KEY:.*]]: i32)
+// CHECK-DAG: %[[ZERO:.*]] = arith.constant 0 : i32
+// CHECK-DAG: %[[YES:.*]] = arith.constant true
+// CHECK-DAG: %[[NO:.*]] = arith.constant false
+// Only two loop-carried values: the location and the "still searching" flag.
+// CHECK: fir.do_loop %[[I:.*]] = {{.*}} iter_args(%[[POS:.*]] = %[[ZERO]], %[[SEARCHING:.*]] = %[[YES]]) -> (i32, i1)
+// The logical mask array is never read.
+// CHECK-NOT: hlfir.apply {{.*}} !fir.logical<4>
+// A position is taken only when the element matches and no earlier hit exists.
+// CHECK: %[[ELEM:.*]] = hlfir.apply %[[SRC]], %[[I]]
+// CHECK: %[[HIT:.*]] = arith.cmpi eq, %[[ELEM]], %[[KEY]] : i32
+// CHECK: %[[TAKE:.*]] = arith.andi %[[HIT]], %[[SEARCHING]] : i1
+// On a hit the flag is cleared, locking in the recorded position.
+// CHECK: %[[NEXT:.*]]:2 = fir.if %[[TAKE]] -> (i32, i1) {
+// CHECK: %[[POS1:.*]] = fir.convert %[[I]]
+// CHECK: fir.result %[[POS1]], %[[NO]] : i32, i1
+// CHECK: } else {
+// CHECK: fir.result %[[POS]], %[[SEARCHING]] : i32, i1
+// CHECK: }
+
+// Negative test: the mask compares a different array (%arg1) from the one
+// being searched (%arg0), so the equality-mask fast path must not fire.
+func.func @test_minloc_different_arrays(%arg0: !hlfir.expr<?xi32>, %arg1: !hlfir.expr<?xi32>, %target: i32) -> !hlfir.expr<1xi32> {
+ %search_shape = hlfir.shape_of %arg0 : (!hlfir.expr<?xi32>) -> !fir.shape<1>
+ %other_eq = hlfir.elemental %search_shape : (!fir.shape<1>) -> !hlfir.expr<?x!fir.logical<4>> {
+ ^bb0(%n: index):
+ // The compared array (%arg1) is not the reduction source (%arg0).
+ %other_n = hlfir.apply %arg1, %n : (!hlfir.expr<?xi32>, index) -> i32
+ %match = arith.cmpi eq, %other_n, %target : i32
+ %match_log = fir.convert %match : (i1) -> !fir.logical<4>
+ hlfir.yield_element %match_log : !fir.logical<4>
+ }
+ %loc = hlfir.minloc %arg0 mask %other_eq : (!hlfir.expr<?xi32>, !hlfir.expr<?x!fir.logical<4>>) -> !hlfir.expr<1xi32>
+ return %loc : !hlfir.expr<1xi32>
+}
+// CHECK-LABEL: func.func @test_minloc_different_arrays(
+// CHECK-SAME: %[[SRC:.*]]: !hlfir.expr<?xi32>, %[[OTHER:.*]]: !hlfir.expr<?xi32>, %[[KEY:.*]]: i32)
+// CHECK: %[[INIT_MIN:.*]] = arith.constant 2147483647 : i32
+// Generic lowering keeps three loop-carried values: location, running minimum,
+// and the first-element flag.
+// CHECK: fir.do_loop %[[I:.*]] = {{.*}} iter_args(%[[CUR_LOC:.*]] = {{.*}}, %[[CUR_MIN:.*]] = %[[INIT_MIN]], %[[IS_FIRST:.*]] = {{.*}}) -> (i32, i32, i1)
+// The materialized logical mask is still consulted for every element.
+// CHECK: %[[M:.*]] = hlfir.apply {{.*}} : (!hlfir.expr<?x!fir.logical<4>>, index) -> !fir.logical<4>
+// CHECK: %[[M_I1:.*]] = fir.convert %[[M]] : (!fir.logical<4>) -> i1
+// Selected elements go through the ordinary signed less-than minimum update.
+// CHECK: fir.if %[[M_I1]] -> (i32, i32, i1) {
+// CHECK: %[[ELEM:.*]] = hlfir.apply %[[SRC]], %[[I]]
+// CHECK: %[[LT:.*]] = arith.cmpi slt, %[[ELEM]], %[[CUR_MIN]] : i32
+
+// Negative test: the comparison value is loaded from a second array inside the
+// elemental, so it is not loop-invariant and the fast path must not fire.
+func.func @test_minloc_non_invariant_target(%arg0: !hlfir.expr<?xi32>, %arg1: !hlfir.expr<?xi32>) -> !hlfir.expr<1xi32> {
+ %shape = hlfir.shape_of %arg0 : (!hlfir.expr<?xi32>) -> !fir.shape<1>
+ %mask = hlfir.elemental %shape : (!fir.shape<1>) -> !hlfir.expr<?x!fir.logical<4>> {
+ ^bb0(%i: index):
+ %val_a = hlfir.apply %arg0, %i : (!hlfir.expr<?xi32>, index) -> i32
+ %val_target = hlfir.apply %arg1, %i : (!hlfir.expr<?xi32>, index) -> i32
+ // Optimization should fail here because %val_target is defined inside the elemental.
+ %cmp = arith.cmpi eq, %val_a, %val_target : i32
+ %logical = fir.convert %cmp : (i1) -> !fir.logical<4>
+ hlfir.yield_element %logical : !fir.logical<4>
+ }
+ %res = hlfir.minloc %arg0 mask %mask : (!hlfir.expr<?xi32>, !hlfir.expr<?x!fir.logical<4>>) -> !hlfir.expr<1xi32>
+ return %res : !hlfir.expr<1xi32>
+}
+// CHECK-LABEL: func.func @test_minloc_non_invariant_target(
+// CHECK-SAME: %[[ARRAY_A:.*]]: !hlfir.expr<?xi32>, %[[ARRAY_B:.*]]: !hlfir.expr<?xi32>)
+// CHECK: %[[SENTINEL:.*]] = arith.constant 2147483647 : i32
+// Verify the loop uses three iter_args (standard path: Loc, MinVal, FirstHit).
+// CHECK: fir.do_loop %[[IV:.*]] = {{.*}} iter_args(%[[LOC:.*]] = {{.*}}, %[[MIN_VAL:.*]] = %[[SENTINEL]], %[[FIRST:.*]] = {{.*}}) -> (i32, i32, i1)
+// Verify the mask is still applied (because we couldn't inline the comparison).
+// CHECK: %[[MASK_BIT:.*]] = hlfir.apply %{{.*}}, %[[IV]] : (!hlfir.expr<?x!fir.logical<4>>, index) -> !fir.logical<4>
+// CHECK: %[[MASK_I1:.*]] = fir.convert %[[MASK_BIT]] : (!fir.logical<4>) -> i1
+// CHECK: fir.if %[[MASK_I1]] -> (i32, i32, i1) {
+// CHECK: %[[VAL:.*]] = hlfir.apply %[[ARRAY_A]], %[[IV]]
+// CHECK: %[[CMP:.*]] = arith.cmpi slt, %[[VAL]], %[[MIN_VAL]]
+
More information about the flang-commits
mailing list