[flang-commits] [flang] [Flang] Minloc elemental intrinsic lowering (PR #74828)

via flang-commits flang-commits at lists.llvm.org
Fri Dec 8 03:26:13 PST 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-flang-fir-hlfir

Author: David Green (davemgreen)

<details>
<summary>Changes</summary>

Currently the lowering of a minloc intrinsic with a mask will look something like
```
  %e = hlfir.elemental %shape ({
    ...
  })
  %m = hlfir.minloc %array mask %e
  hlfir.assign %m to %result
  hlfir.destroy %m
```
The elemental will be expanded into a temporary+loop, the minloc into a
FortranAMinloc call (which hopefully gets simplified to a specialized call that
can be inlined at the call site), and the assign might get expanded to a
FortranAAssign. The assign we could inline too, but it would be better to
generate the entire construct as a single loop if we can - one that performs the
minloc calculation with the mask elemental computed inline and assigns directly
to the output array.

This patch attempts to do that, adding an HLFIR version of the expansion code
from SimplifyIntrinsics that turns an assign+minloc+elemental into a single
combined loop nest. It attempts to reuse the methods in genMinlocReductionLoop
for constructing the loop with a modified loop body. The declaration for the
function is currently in Optimizer/Support/Utils.h, but there might be a better
place for it.

It is currently added as part of the OptimizedBufferizationPass. I originally
had it as part of the SimplifyHLFIRIntrinsics pass, but there were already some
methods doing similar things in OptimizedBufferization. It just needs to happen
before the elementals are expanded. I think I would like to do a similar thing
for maxloc and any/all/count too if this looks OK. I will rebase over #74436
once that goes in.

---

Patch is 148.67 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/74828.diff


15 Files Affected:

- (modified) flang/include/flang/Optimizer/HLFIR/HLFIROps.td (+26) 
- (modified) flang/include/flang/Optimizer/Support/Utils.h (+16) 
- (modified) flang/lib/Lower/HlfirIntrinsics.cpp (+65) 
- (modified) flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp (+75) 
- (modified) flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp (+31-7) 
- (modified) flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp (+282-88) 
- (modified) flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp (+92-96) 
- (modified) flang/test/HLFIR/invalid.fir (+68) 
- (modified) flang/test/HLFIR/memory-effects.fir (+15) 
- (added) flang/test/HLFIR/minloc-elemental.fir (+327) 
- (added) flang/test/HLFIR/minloc-lowering.fir (+329) 
- (added) flang/test/HLFIR/minloc.fir (+272) 
- (added) flang/test/Lower/HLFIR/minloc.f90 (+370) 
- (modified) flang/test/Lower/HLFIR/transformational.f90 (+9-19) 
- (modified) flang/test/Transforms/simplifyintrinsics.fir (+3-2) 


``````````diff
diff --git a/flang/include/flang/Optimizer/HLFIR/HLFIROps.td b/flang/include/flang/Optimizer/HLFIR/HLFIROps.td
index f4933785a8caa6..1f5bc42c43e65c 100644
--- a/flang/include/flang/Optimizer/HLFIR/HLFIROps.td
+++ b/flang/include/flang/Optimizer/HLFIR/HLFIROps.td
@@ -458,6 +458,32 @@ def hlfir_MinvalOp : hlfir_Op<"minval", [AttrSizedOperandSegments,
   let hasVerifier = 1;
 }
 
+def hlfir_MinlocOp : hlfir_Op<"minloc", [AttrSizedOperandSegments,
+    DeclareOpInterfaceMethods<ArithFastMathInterface>,
+    DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
+  let summary = "MINLOC transformational intrinsic";
+  let description = [{
+    Minlocs of an array.
+  }];
+
+  let arguments = (ins
+    AnyFortranArrayObject:$array,
+    Optional<AnyIntegerType>:$dim,
+    Optional<AnyFortranLogicalOrI1ArrayObject>:$mask,
+    Optional<Type<AnyLogicalLike.predicate>>:$back,
+    DefaultValuedAttr<Arith_FastMathAttr,
+                      "::mlir::arith::FastMathFlags::none">:$fastmath
+  );
+
+  let results = (outs AnyFortranValue);
+
+  let assemblyFormat = [{
+    $array (`dim` $dim^)? (`mask` $mask^)? (`back` $back^)?  attr-dict `:` functional-type(operands, results)
+  }];
+
+  let hasVerifier = 1;
+}
+
 def hlfir_ProductOp : hlfir_Op<"product", [AttrSizedOperandSegments,
     DeclareOpInterfaceMethods<ArithFastMathInterface>,
     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
diff --git a/flang/include/flang/Optimizer/Support/Utils.h b/flang/include/flang/Optimizer/Support/Utils.h
index 34c8e79173bcd4..93caa8b23d320c 100644
--- a/flang/include/flang/Optimizer/Support/Utils.h
+++ b/flang/include/flang/Optimizer/Support/Utils.h
@@ -133,6 +133,22 @@ inline void intrinsicTypeTODO(fir::FirOpBuilder &builder, mlir::Type type,
            fir::numericMlirTypeToFortran(builder, type, loc, intrinsicName) +
            " in " + intrinsicName);
 }
+
+using MinlocBodyOpGeneratorTy = llvm::function_ref<mlir::Value(
+    fir::FirOpBuilder &, mlir::Location, const mlir::Type &, mlir::Value,
+    mlir::Value, mlir::Value, const llvm::SmallVectorImpl<mlir::Value> &)>;
+using InitValGeneratorTy = llvm::function_ref<mlir::Value(
+    fir::FirOpBuilder &, mlir::Location, const mlir::Type &)>;
+
+// Produces a loop nest for a Minloc intrinsic.
+void genMinlocReductionLoop(fir::FirOpBuilder &builder, mlir::Value array,
+                            InitValGeneratorTy initVal,
+                            MinlocBodyOpGeneratorTy genBody, unsigned rank,
+                            mlir::Type elementType, mlir::Location loc,
+                            mlir::Type maskElemType, mlir::Value resultArr,
+                            bool maskMayBeLogicalScalar);
+
+
 } // namespace fir
 
 #endif // FORTRAN_OPTIMIZER_SUPPORT_UTILS_H
diff --git a/flang/lib/Lower/HlfirIntrinsics.cpp b/flang/lib/Lower/HlfirIntrinsics.cpp
index 9f764b61425226..6e5ba92bee86a7 100644
--- a/flang/lib/Lower/HlfirIntrinsics.cpp
+++ b/flang/lib/Lower/HlfirIntrinsics.cpp
@@ -93,6 +93,19 @@ using HlfirMinvalLowering = HlfirReductionIntrinsic<hlfir::MinvalOp, true>;
 using HlfirAnyLowering = HlfirReductionIntrinsic<hlfir::AnyOp, false>;
 using HlfirAllLowering = HlfirReductionIntrinsic<hlfir::AllOp, false>;
 
+template <typename OP>
+class HlfirMinMaxLocIntrinsic : public HlfirTransformationalIntrinsic {
+public:
+  using HlfirTransformationalIntrinsic::HlfirTransformationalIntrinsic;
+
+protected:
+  mlir::Value
+  lowerImpl(const Fortran::lower::PreparedActualArguments &loweredActuals,
+            const fir::IntrinsicArgumentLoweringRules *argLowering,
+            mlir::Type stmtResultType) override;
+};
+using HlfirMinlocLowering = HlfirMinMaxLocIntrinsic<hlfir::MinlocOp>;
+
 template <typename OP>
 class HlfirProductIntrinsic : public HlfirTransformationalIntrinsic {
 public:
@@ -180,6 +193,31 @@ mlir::Value HlfirTransformationalIntrinsic::loadBoxAddress(
   return boxOrAbsent;
 }
 
+static mlir::Value loadOptionalValue(
+    mlir::Location loc, fir::FirOpBuilder &builder,
+    const std::optional<Fortran::lower::PreparedActualArgument> &arg,
+    hlfir::Entity actual) {
+  if (!arg->handleDynamicOptional())
+    return hlfir::loadTrivialScalar(loc, builder, actual);
+
+  mlir::Value isPresent = arg->getIsPresent();
+  mlir::Type eleType = hlfir::getFortranElementType(actual.getType());
+  return builder
+      .genIfOp(loc, {eleType}, isPresent,
+               /*withElseRegion=*/true)
+      .genThen([&]() {
+        assert(actual.isScalar() && fir::isa_trivial(eleType) &&
+               "must be a numerical or logical scalar");
+        hlfir::Entity val = hlfir::loadTrivialScalar(loc, builder, actual);
+        builder.create<fir::ResultOp>(loc, val);
+      })
+      .genElse([&]() {
+        mlir::Value zero = fir::factory::createZeroValue(builder, loc, eleType);
+        builder.create<fir::ResultOp>(loc, zero);
+      })
+      .getResults()[0];
+}
+
 llvm::SmallVector<mlir::Value> HlfirTransformationalIntrinsic::getOperandVector(
     const Fortran::lower::PreparedActualArguments &loweredActuals,
     const fir::IntrinsicArgumentLoweringRules *argLowering) {
@@ -206,6 +244,9 @@ llvm::SmallVector<mlir::Value> HlfirTransformationalIntrinsic::getOperandVector(
       else if (!argRules.handleDynamicOptional &&
                argRules.lowerAs != fir::LowerIntrinsicArgAs::Inquired)
         valArg = hlfir::derefPointersAndAllocatables(loc, builder, actual);
+      else if (argRules.handleDynamicOptional &&
+               argRules.lowerAs == fir::LowerIntrinsicArgAs::Value)
+        valArg = loadOptionalValue(loc, builder, arg, actual);
       else if (argRules.handleDynamicOptional)
         TODO(loc, "hlfir transformational intrinsic dynamically optional "
                   "argument without box lowering");
@@ -260,6 +301,27 @@ mlir::Value HlfirReductionIntrinsic<OP, HAS_MASK>::lowerImpl(
   return op;
 }
 
+template <typename OP>
+mlir::Value HlfirMinMaxLocIntrinsic<OP>::lowerImpl(
+    const Fortran::lower::PreparedActualArguments &loweredActuals,
+    const fir::IntrinsicArgumentLoweringRules *argLowering,
+    mlir::Type stmtResultType) {
+  auto operands = getOperandVector(loweredActuals, argLowering);
+  mlir::Value array = operands[0];
+  mlir::Value dim = operands[1];
+  mlir::Value mask = operands[2];
+  mlir::Value back = operands[4];
+  // dim, mask and back can be NULL if these arguments are not given.
+  if (dim)
+    dim = hlfir::loadTrivialScalar(loc, builder, hlfir::Entity{dim});
+  if (back)
+    back = hlfir::loadTrivialScalar(loc, builder, hlfir::Entity{back});
+
+  mlir::Type resultTy = computeResultType(array, stmtResultType);
+
+  return createOp<OP>(resultTy, array, dim, mask, back);
+}
+
 template <typename OP>
 mlir::Value HlfirProductIntrinsic<OP>::lowerImpl(
     const Fortran::lower::PreparedActualArguments &loweredActuals,
@@ -364,6 +426,9 @@ std::optional<hlfir::EntityWithAttributes> Fortran::lower::lowerHlfirIntrinsic(
   if (name == "minval")
     return HlfirMinvalLowering{builder, loc}.lower(loweredActuals, argLowering,
                                                    stmtResultType);
+  if (name == "minloc")
+    return HlfirMinlocLowering{builder, loc}.lower(loweredActuals, argLowering,
+                                                   stmtResultType);
   if (mlir::isa<fir::CharacterType>(stmtResultType)) {
     if (name == "min")
       return HlfirCharExtremumLowering{builder, loc,
diff --git a/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp b/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
index a276e5fc65dd59..94a2213306bfd5 100644
--- a/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
+++ b/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
@@ -870,6 +870,81 @@ void hlfir::MinvalOp::getEffects(
   getIntrinsicEffects(getOperation(), effects);
 }
 
+//===----------------------------------------------------------------------===//
+// MinlocOp
+//===----------------------------------------------------------------------===//
+
+mlir::LogicalResult hlfir::MinlocOp::verify() {
+  mlir::Operation *op = getOperation();
+
+  auto results = op->getResultTypes();
+  assert(results.size() == 1);
+  mlir::Value array = getArray();
+  mlir::Value dim = getDim();
+  mlir::Value mask = getMask();
+
+  fir::SequenceType arrayTy =
+      hlfir::getFortranElementOrSequenceType(array.getType())
+          .cast<fir::SequenceType>();
+  llvm::ArrayRef<int64_t> arrayShape = arrayTy.getShape();
+  // Only an array MASK is shape-checked; unknown extents are not compared.
+  if (mask) {
+    fir::SequenceType maskSeq =
+        hlfir::getFortranElementOrSequenceType(mask.getType())
+            .dyn_cast<fir::SequenceType>();
+    llvm::ArrayRef<int64_t> maskShape;
+
+    if (maskSeq)
+      maskShape = maskSeq.getShape();
+
+    if (!maskShape.empty()) {
+      if (maskShape.size() != arrayShape.size())
+        return emitWarning("MASK must be conformable to ARRAY");
+      static_assert(fir::SequenceType::getUnknownExtent() ==
+                    hlfir::ExprType::getUnknownExtent());
+      constexpr int64_t unknownExtent = fir::SequenceType::getUnknownExtent();
+      for (std::size_t i = 0; i < arrayShape.size(); ++i) {
+        int64_t arrayExtent = arrayShape[i];
+        int64_t maskExtent = maskShape[i];
+        if ((arrayExtent != maskExtent) && (arrayExtent != unknownExtent) &&
+            (maskExtent != unknownExtent))
+          return emitWarning("MASK must be conformable to ARRAY");
+      }
+    }
+  }
+  // With DIM on a rank-1 ARRAY the result is a scalar, otherwise an array.
+  mlir::Type resultType = results[0];
+  if (dim && arrayShape.size() == 1) {
+    if (!fir::isa_integer(resultType))
+      return emitOpError("result must be scalar integer");
+  } else if (auto resultExpr =
+                 mlir::dyn_cast_or_null<hlfir::ExprType>(resultType)) {
+    if (!resultExpr.isArray())
+      return emitOpError("result must be an array");
+
+    if (!fir::isa_integer(resultExpr.getEleTy()))
+      return emitOpError("result must have integer elements");
+
+    llvm::ArrayRef<int64_t> resultShape = resultExpr.getShape();
+    // With dim the result has rank n-1
+    if (dim && resultShape.size() != (arrayShape.size() - 1))
+      return emitOpError("result rank must be one less than ARRAY");
+    // Without dim the result has rank 1
+    if (!dim && resultShape.size() != 1)
+      return emitOpError("result rank must be 1");
+  } else {
+    return emitOpError("result must be of numerical expr type");
+  }
+  return mlir::success();
+}
+
+void hlfir::MinlocOp::getEffects(
+    llvm::SmallVectorImpl<
+        mlir::SideEffects::EffectInstance<mlir::MemoryEffects::Effect>>
+        &effects) {
+  getIntrinsicEffects(getOperation(), effects);
+}
+
 //===----------------------------------------------------------------------===//
 // SetLengthOp
 //===----------------------------------------------------------------------===//
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp b/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp
index f2628fcb970bc4..bfebe26fe1d532 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp
@@ -201,6 +201,23 @@ class HlfirReductionIntrinsicConversion : public HlfirIntrinsicConversion<OP> {
     return lowerArguments(operation, inArgs, rewriter, argLowering);
   };
 
+  auto buildMinMaxLocArgs(OP operation, mlir::Type i32, mlir::Type logicalType,
+                          mlir::PatternRewriter &rewriter, std::string opName,
+                          fir::FirOpBuilder &builder) const {
+    llvm::SmallVector<IntrinsicArgument, 5> inArgs;
+    inArgs.push_back({operation.getArray(), operation.getArray().getType()});
+    inArgs.push_back({operation.getDim(), i32});
+    inArgs.push_back({operation.getMask(), logicalType});
+    mlir::Type resultEleTy = hlfir::getFortranElementType(operation.getType());
+    unsigned width = resultEleTy.cast<mlir::IntegerType>().getWidth();
+    mlir::Value kind =
+        builder.createIntegerConstant(operation->getLoc(), i32, width / 8);
+    inArgs.push_back({kind, i32});
+    inArgs.push_back({operation.getBack(), i32});
+    auto *argLowering = fir::getIntrinsicArgumentLowering(opName);
+    return lowerArguments(operation, inArgs, rewriter, argLowering);
+  };
+
   auto buildLogicalArgs(OP operation, mlir::Type i32, mlir::Type logicalType,
                         mlir::PatternRewriter &rewriter,
                         std::string opName) const {
@@ -224,6 +241,8 @@ class HlfirReductionIntrinsicConversion : public HlfirIntrinsicConversion<OP> {
       opName = "maxval";
     } else if constexpr (std::is_same_v<OP, hlfir::MinvalOp>) {
       opName = "minval";
+    } else if constexpr (std::is_same_v<OP, hlfir::MinlocOp>) {
+      opName = "minloc";
     } else if constexpr (std::is_same_v<OP, hlfir::AnyOp>) {
       opName = "any";
     } else if constexpr (std::is_same_v<OP, hlfir::AllOp>) {
@@ -246,6 +265,9 @@ class HlfirReductionIntrinsicConversion : public HlfirIntrinsicConversion<OP> {
                   std::is_same_v<OP, hlfir::MaxvalOp> ||
                   std::is_same_v<OP, hlfir::MinvalOp>) {
       args = buildNumericalArgs(operation, i32, logicalType, rewriter, opName);
+    } else if constexpr (std::is_same_v<OP, hlfir::MinlocOp>) {
+      args = buildMinMaxLocArgs(operation, i32, logicalType, rewriter, opName,
+                                builder);
     } else {
       args = buildLogicalArgs(operation, i32, logicalType, rewriter, opName);
     }
@@ -269,6 +291,8 @@ using MaxvalOpConversion = HlfirReductionIntrinsicConversion<hlfir::MaxvalOp>;
 
 using MinvalOpConversion = HlfirReductionIntrinsicConversion<hlfir::MinvalOp>;
 
+using MinlocOpConversion = HlfirReductionIntrinsicConversion<hlfir::MinlocOp>;
+
 using AnyOpConversion = HlfirReductionIntrinsicConversion<hlfir::AnyOp>;
 
 using AllOpConversion = HlfirReductionIntrinsicConversion<hlfir::AllOp>;
@@ -441,12 +465,12 @@ class LowerHLFIRIntrinsics
     mlir::ModuleOp module = this->getOperation();
     mlir::MLIRContext *context = &getContext();
     mlir::RewritePatternSet patterns(context);
-    patterns
-        .insert<MatmulOpConversion, MatmulTransposeOpConversion,
-                AllOpConversion, AnyOpConversion, SumOpConversion,
-                ProductOpConversion, TransposeOpConversion, CountOpConversion,
-                DotProductOpConversion, MaxvalOpConversion, MinvalOpConversion>(
-            context);
+    patterns.insert<MatmulOpConversion, MatmulTransposeOpConversion,
+                    AllOpConversion, AnyOpConversion, SumOpConversion,
+                    ProductOpConversion, TransposeOpConversion,
+                    CountOpConversion, DotProductOpConversion,
+                    MaxvalOpConversion, MinvalOpConversion, MinlocOpConversion>(
+        context);
     mlir::ConversionTarget target(*context);
     target.addLegalDialect<mlir::BuiltinDialect, mlir::arith::ArithDialect,
                            mlir::func::FuncDialect, fir::FIROpsDialect,
@@ -454,7 +478,7 @@ class LowerHLFIRIntrinsics
     target.addIllegalOp<hlfir::MatmulOp, hlfir::MatmulTransposeOp, hlfir::SumOp,
                         hlfir::ProductOp, hlfir::TransposeOp, hlfir::AnyOp,
                         hlfir::AllOp, hlfir::DotProductOp, hlfir::CountOp,
-                        hlfir::MaxvalOp, hlfir::MinvalOp>();
+                        hlfir::MaxvalOp, hlfir::MinvalOp, hlfir::MinlocOp>();
     target.markUnknownOpDynamicallyLegal(
         [](mlir::Operation *) { return true; });
     if (mlir::failed(
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp b/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
index 7abfa20493c736..218ddd2a6a7b7e 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
@@ -20,6 +20,7 @@
 #include "flang/Optimizer/HLFIR/HLFIRDialect.h"
 #include "flang/Optimizer/HLFIR/HLFIROps.h"
 #include "flang/Optimizer/HLFIR/Passes.h"
+#include "flang/Optimizer/Support/Utils.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/IR/Dominance.h"
 #include "mlir/IR/PatternMatch.h"
@@ -98,7 +99,8 @@ class ElementalAssignBufferization
 /// the same block. If any operations with unknown effects are found,
 /// std::nullopt is returned
 static std::optional<mlir::SmallVector<mlir::MemoryEffects::EffectInstance>>
-getEffectsBetween(mlir::Operation *start, mlir::Operation *end) {
+getEffectsBetween(mlir::Operation *start, mlir::Operation *end,
+                  mlir::Operation *ignoring) {
   mlir::SmallVector<mlir::MemoryEffects::EffectInstance> ret;
   if (start == end)
     return ret;
@@ -108,6 +110,10 @@ getEffectsBetween(mlir::Operation *start, mlir::Operation *end) {
 
   mlir::Operation *nextOp = start;
   while (nextOp && nextOp != end) {
+    if (nextOp == ignoring) {
+      nextOp = nextOp->getNextNode();
+      continue;
+    }
     std::optional<mlir::SmallVector<mlir::MemoryEffects::EffectInstance>>
         effects = mlir::getEffectsRecursively(nextOp);
     if (!effects)
@@ -293,80 +299,10 @@ static bool areIdenticalOrDisjointSlices(mlir::Value ref1, mlir::Value ref2) {
   return false;
 }
 
-std::optional<ElementalAssignBufferization::MatchInfo>
-ElementalAssignBufferization::findMatch(hlfir::ElementalOp elemental) {
-  mlir::Operation::user_range users = elemental->getUsers();
-  // the only uses of the elemental should be the assignment and the destroy
-  if (std::distance(users.begin(), users.end()) != 2) {
-    LLVM_DEBUG(llvm::dbgs() << "Too many uses of the elemental\n");
-    return std::nullopt;
-  }
-
-  // If the ElementalOp must produce a temporary (e.g. for
-  // finalization purposes), then we cannot inline it.
-  if (hlfir::elementalOpMustProduceTemp(elemental)) {
-    LLVM_DEBUG(llvm::dbgs() << "ElementalOp must produce a temp\n");
-    return std::nullopt;
-  }
-
-  MatchInfo match;
-  for (mlir::Operation *user : users)
-    mlir::TypeSwitch<mlir::Operation *, void>(user)
-        .Case([&](hlfir::AssignOp op) { match.assign = op; })
-        .Case([&](hlfir::DestroyOp op) { match.destroy = op; });
-
-  if (!match.assign || !match.destroy) {
-    LLVM_DEBUG(llvm::dbgs() << "Couldn't find assign or destroy\n");
-    return std::nullopt;
-  }
-
-  // the array is what the elemental is assigned into
-  // TODO: this could be extended to also allow hlfir.expr by first bufferizing
-  // the incoming expression
-  match.array = match.assign.getLhs();
-  mlir::Type arrayType = mlir::dyn_cast<fir::SequenceType>(
-      fir::unwrapPassByRefType(match.array.getType()));
-  if (!arrayType)
-    return std::nullopt;
-
-  // require that the array elements are trivial
-  // TODO: this is just to make the pass easier to think about. Not an inherent
-  // limitation
-  mlir::Type eleTy = hlfir::getFortranElementType(arrayType);
-  if (!fir::isa_trivial(eleTy))
-    return std::nullopt;
-
-  // the array must have the same shape as the elemental. CSE should have
-  // deduplicated the fir.shape operations where they are provably the same
-  // so we just have to check for the same ssa value
-  // TODO: add more ways of getting the shape of the array
-  mlir::Value arrayShape;
-  if (match.array.getDefiningOp())
-    arrayShape =
-        mlir::TypeSwitch<mlir::Operation *, mlir::Value>(
-            match.array.getDefiningOp())
-            .Case([](hlfir::DesignateOp designate) {
-              return designate.getShape();
-            })
-            .Case([](hlfir::DeclareOp declare) { return declare.getShape(); })
-            .Default([](mlir::Operation *) { return mlir::Value{}; });
-  if (!arrayShape) {
-    LLVM_DEBUG(llvm::dbgs() << "Can't get shape of " << match.array << " at "
-                            << elemental->getLoc() << "\n");
-    return std::nullopt;
-  }
-  if (arrayShape != elemental.getShape()) {
-    // f2018 10.2.1.2 (3) requires the lhs and rhs of an assignment to be
-    // conformable unless the lhs is an allocatable array. In HLFIR we can
-    // see this from the presence or absence of the realloc attribute on
-    // hlfir.assign. If it is not a realloc assignment, we can trust that
-    // the shapes do conform
-    if (match.assign.getRealloc())
-      return std::nullopt;
-  }
-
-  // the transformation wants to apply the elemental in a do-loop at the
-  // hlfir.assign, check there are no effects which make this unsafe
+static bool checkForElementalEffectsBetween(hlfir::ElementalOp elemental,
+                            ...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/74828


More information about the flang-commits mailing list