[flang-commits] [flang] 3191e8e - [flang] Lower binary and unary elemental array operations

Jean Perier via flang-commits flang-commits at lists.llvm.org
Thu Dec 15 04:16:42 PST 2022


Author: Jean Perier
Date: 2022-12-15T13:15:12+01:00
New Revision: 3191e8e19f1a7007ddd0e55cee60a51a058c99f5

URL: https://github.com/llvm/llvm-project/commit/3191e8e19f1a7007ddd0e55cee60a51a058c99f5
DIFF: https://github.com/llvm/llvm-project/commit/3191e8e19f1a7007ddd0e55cee60a51a058c99f5.diff

LOG: [flang] Lower binary and unary elemental array operations

Lower binary and unary elemental operations with an array argument
using hlfir.elemental, hlfir.yield_element, and hlfir.apply.

Concat implementation, which is a binary operation, is moved to a
BinaryOp struct so that it can leverage this new code.

This patch implements the "not yet implemented: character array
expression temp with dynamic length" TODO of the current lowering
by splitting the result length computation from the result value
computation. That way, the result length computation can be done
before lowering the operation to an hlfir.elemental, and the length
of the hlfir.elemental is known and storage for it can later be
allocated.

It adds a DesignateOp builder to make "dumb" indexing (without triplets,
components, substrings or derived type component refs) easier, since indexing
needs to be generated for array variables in elemental expressions (in
the added hlfir::getElementAt helper).

Differential Revision: https://reviews.llvm.org/D140040

Added: 
    flang/test/Lower/HLFIR/elemental-array-ops.f90

Modified: 
    flang/include/flang/Optimizer/Builder/HLFIRTools.h
    flang/include/flang/Optimizer/HLFIR/HLFIROpBase.td
    flang/include/flang/Optimizer/HLFIR/HLFIROps.td
    flang/lib/Lower/ConvertExprToHLFIR.cpp
    flang/lib/Optimizer/Builder/HLFIRTools.cpp
    flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp

Removed: 
    


################################################################################
diff  --git a/flang/include/flang/Optimizer/Builder/HLFIRTools.h b/flang/include/flang/Optimizer/Builder/HLFIRTools.h
index 8f64075f37053..f1492b2a78eee 100644
--- a/flang/include/flang/Optimizer/Builder/HLFIRTools.h
+++ b/flang/include/flang/Optimizer/Builder/HLFIRTools.h
@@ -24,6 +24,7 @@ class FirOpBuilder;
 namespace hlfir {
 
 class AssociateOp;
+class ElementalOp;
 
 /// Is this an SSA value type for the value of a Fortran expression?
 inline bool isFortranValueType(mlir::Type type) {
@@ -70,6 +71,9 @@ class Entity : public mlir::Value {
   bool isValue() const { return isFortranValue(*this); }
   bool isVariable() const { return !isValue(); }
   bool isMutableBox() const { return hlfir::isBoxAddressType(getType()); }
+  /// True if the entity is a box address or a box value (see
+  /// hlfir::isBoxAddressOrValueType).
+  bool isBoxAddressOrValue() const {
+    mlir::Type entityType = getType();
+    return hlfir::isBoxAddressOrValueType(entityType);
+  }
   bool isArray() const {
     mlir::Type type = fir::unwrapPassByRefType(fir::unwrapRefType(getType()));
     if (type.isa<fir::SequenceType>())
@@ -80,6 +84,12 @@ class Entity : public mlir::Value {
   }
   bool isScalar() const { return !isArray(); }
 
+  /// Is the entity polymorphic? For hlfir.expr values this is the polymorphic
+  /// flag of the expression type, otherwise fir::isPolymorphicType decides.
+  bool isPolymorphic() const {
+    mlir::Type entityType = getType();
+    auto exprType = entityType.dyn_cast<hlfir::ExprType>();
+    return exprType ? exprType.isPolymorphic()
+                    : fir::isPolymorphicType(entityType);
+  }
+
   mlir::Type getFortranElementType() const {
     return hlfir::getFortranElementType(getType());
   }
@@ -94,6 +104,20 @@ class Entity : public mlir::Value {
     return getFortranElementType().isa<fir::CharacterType>();
   }
 
+  /// Conservatively tell whether the entity may have lower bounds other than
+  /// one. Scalars and entities that are not boxes answer false. Pointers,
+  /// allocatables, and boxes whose defining variable or shape operand is not
+  /// available answer true. Otherwise, the lower bounds are non default if the
+  /// variable shape operand is a fir.shift or a fir.shape_shift.
+  bool hasNonDefaultLowerBounds() const {
+    if (isScalar() || !isBoxAddressOrValue())
+      return false;
+    if (isMutableBox())
+      return true;
+    fir::FortranVariableOpInterface varIface = getIfVariableInterface();
+    mlir::Value shape = varIface ? varIface.getShape() : mlir::Value{};
+    if (!shape)
+      return true;
+    mlir::Type shapeTy = shape.getType();
+    return shapeTy.isa<fir::ShiftType>() || shapeTy.isa<fir::ShapeShiftType>();
+  }
+
   fir::FortranVariableOpInterface getIfVariableInterface() const {
     return this->getDefiningOp<fir::FortranVariableOpInterface>();
   }
@@ -176,8 +200,9 @@ mlir::Value genVariableBoxChar(mlir::Location loc, fir::FirOpBuilder &builder,
                                hlfir::Entity var);
 
 /// If the entity is a variable, load its value (dereference pointers and
-/// allocatables if needed). Do nothing if the entity os already a variable or
-/// if it is not a scalar entity of numerical or logical type.
+/// allocatables if needed). Do nothing if the entity is already a value.
+/// Pointers and allocatables are always dereferenced, but the value is only
+/// loaded for scalar entities of numerical or logical type.
 Entity loadTrivialScalar(mlir::Location loc, fir::FirOpBuilder &builder,
                          Entity entity);
 
@@ -187,10 +212,19 @@ hlfir::Entity derefPointersAndAllocatables(mlir::Location loc,
                                            fir::FirOpBuilder &builder,
                                            Entity entity);
 
+/// Get element entity(oneBasedIndices) if entity is an array, or return entity
+/// if it is a scalar. The indices are one based: array expressions are indexed
+/// with hlfir.apply, and array variables with hlfir.designate after shifting
+/// the indices according to the variable lower bounds when needed.
+hlfir::Entity getElementAt(mlir::Location loc, fir::FirOpBuilder &builder,
+                           Entity entity, mlir::ValueRange oneBasedIndices);
 /// Compute the lower and upper bounds of an entity.
 llvm::SmallVector<std::pair<mlir::Value, mlir::Value>>
 genBounds(mlir::Location loc, fir::FirOpBuilder &builder, Entity entity);
 
+/// Compute a fir.shape<> (extents only, no lower bounds) for an array entity.
+mlir::Value genShape(mlir::Location loc, fir::FirOpBuilder &builder,
+                     Entity entity);
+
 /// Read length parameters into result if this entity has any.
 void genLengthParameters(mlir::Location loc, fir::FirOpBuilder &builder,
                          Entity entity,
@@ -204,6 +238,21 @@ std::pair<mlir::Value, mlir::Value> genVariableFirBaseShapeAndParams(
     mlir::Location loc, fir::FirOpBuilder &builder, Entity entity,
     llvm::SmallVectorImpl<mlir::Value> &typeParams);
 
+/// Get the variable type for an element of an array type entity. Returns the
+/// input entity type if it is scalar. The entity must be a variable (this is
+/// asserted).
+mlir::Type getVariableElementType(hlfir::Entity variable);
+
+/// Callback generating the value of one element of an hlfir.elemental from
+/// the (one based) indices of that element.
+using ElementalKernelGenerator = std::function<hlfir::Entity(
+    mlir::Location, fir::FirOpBuilder &, mlir::ValueRange)>;
+/// Generate an hlfir.elemental whose body yields, for each iteration, the
+/// element value produced by the genKernel callback.
+hlfir::ElementalOp genElementalOp(mlir::Location loc,
+                                  fir::FirOpBuilder &builder,
+                                  mlir::Type elementType, mlir::Value shape,
+                                  mlir::ValueRange typeParams,
+                                  const ElementalKernelGenerator &genKernel);
+
 } // namespace hlfir
 
 #endif // FORTRAN_OPTIMIZER_BUILDER_HLFIRTOOLS_H

diff  --git a/flang/include/flang/Optimizer/HLFIR/HLFIROpBase.td b/flang/include/flang/Optimizer/HLFIR/HLFIROpBase.td
index 557bf5cacc14b..d17a4cbf5e1b6 100644
--- a/flang/include/flang/Optimizer/HLFIR/HLFIROpBase.td
+++ b/flang/include/flang/Optimizer/HLFIR/HLFIROpBase.td
@@ -74,6 +74,9 @@ def hlfir_ExprType : TypeDef<hlfir_Dialect, "Expr"> {
       return hlfir::ExprType::get(eleTy.getContext(), Shape{}, eleTy,
                 isPolymorphic());
     }
+    /// Extent used in an hlfir.expr type shape for a dimension whose extent
+    /// is not known at compile time.
+    static constexpr int64_t getUnknownExtent() {
+      return mlir::ShapedType::kDynamic;
+    }
   }];
 
   let hasCustomAssemblyFormat = 1;

diff  --git a/flang/include/flang/Optimizer/HLFIR/HLFIROps.td b/flang/include/flang/Optimizer/HLFIR/HLFIROps.td
index 7d4a024ed488e..93e3383e27124 100644
--- a/flang/include/flang/Optimizer/HLFIR/HLFIROps.td
+++ b/flang/include/flang/Optimizer/HLFIR/HLFIROps.td
@@ -194,6 +194,7 @@ def hlfir_DesignateOp : hlfir_Op<"designate", [AttrSizedOperandSegments,
   let extraClassDeclaration = [{
     using Triplet = std::tuple<mlir::Value, mlir::Value, mlir::Value>;
     using Subscript = std::variant<mlir::Value, Triplet>;
+    /// List of designator subscripts (scalar indices or triplets).
+    using Subscripts = llvm::SmallVector<Subscript, 8>;
   }];
 
   let builders = [
@@ -203,7 +204,13 @@ def hlfir_DesignateOp : hlfir_Op<"designate", [AttrSizedOperandSegments,
       CArg<"mlir::ValueRange", "{}">:$substring,
       CArg<"llvm::Optional<bool>", "{}">:$complex_part,
       CArg<"mlir::Value", "{}">:$shape, CArg<"mlir::ValueRange", "{}">:$typeparams,
-      CArg<"fir::FortranVariableFlagsAttr", "{}">:$fortran_attrs)>];
+      CArg<"fir::FortranVariableFlagsAttr", "{}">:$fortran_attrs)>,
+
+    OpBuilder<(ins "mlir::Type":$result_type, "mlir::Value":$memref,
+      "mlir::ValueRange":$indices,
+      CArg<"mlir::ValueRange", "{}">:$typeparams,
+      CArg<"fir::FortranVariableFlagsAttr", "{}">:$fortran_attrs)>
+    ];
 
   let hasVerifier = 1;
 }

diff  --git a/flang/lib/Lower/ConvertExprToHLFIR.cpp b/flang/lib/Lower/ConvertExprToHLFIR.cpp
index 454f224f2e5c1..b26c499cb5172 100644
--- a/flang/lib/Lower/ConvertExprToHLFIR.cpp
+++ b/flang/lib/Lower/ConvertExprToHLFIR.cpp
@@ -72,7 +72,7 @@ class HlfirDesignatorBuilder {
  /// become the operands of an hlfir.designate.
  struct PartInfo {
    fir::FortranVariableOpInterface base;
-    llvm::SmallVector<hlfir::DesignateOp::Subscript, 8> subscripts;
+    hlfir::DesignateOp::Subscripts subscripts;
     mlir::Value resultShape;
     llvm::SmallVector<mlir::Value> typeParams;
   };
@@ -319,14 +319,6 @@ struct BinaryOp<
                                          fir::FirOpBuilder &builder,
                                          const Op &op, hlfir::Entity lhs,
                                          hlfir::Entity rhs) {
-    // evaluate::Extremum is only created by the front-end when building
-    // compiler generated expressions (like when folding LEN() or shape/bounds
-    // inquiries). MIN and MAX are represented as evaluate::ProcedureRef and are
-    // not going through here. So far the frontend does not generate character
-    // Extremum so there is no way to test it.
-    if constexpr (TC == Fortran::common::TypeCategory::Character) {
-      fir::emitFatalError(loc, "Fortran::evaluate::Extremum are unexpected");
-    }
     llvm::SmallVector<mlir::Value, 2> args{lhs, rhs};
     fir::ExtendedValue res = op.ordering == Fortran::evaluate::Ordering::Greater
                                  ? Fortran::lower::genMax(builder, loc, args)
@@ -335,6 +327,28 @@ struct BinaryOp<
   }
 };
 
+// Character evaluate::Extremum is not expected here: the front-end only
+// creates evaluate::Extremum when building compiler generated expressions
+// (like when folding LEN() or shape/bounds inquiries), while MIN and MAX are
+// represented as evaluate::ProcedureRef and do not go through BinaryOp. So
+// far the front-end does not generate character Extremum, so there is no way
+// to test it. Both entry points below abort with a fatal error.
+template <int KIND>
+struct BinaryOp<Fortran::evaluate::Extremum<
+    Fortran::evaluate::Type<Fortran::common::TypeCategory::Character, KIND>>> {
+  using Op = Fortran::evaluate::Extremum<
+      Fortran::evaluate::Type<Fortran::common::TypeCategory::Character, KIND>>;
+  static hlfir::EntityWithAttributes gen(mlir::Location loc,
+                                         fir::FirOpBuilder &, const Op &,
+                                         hlfir::Entity, hlfir::Entity) {
+    fir::emitFatalError(loc, "Fortran::evaluate::Extremum are unexpected");
+  }
+  static void genResultTypeParams(mlir::Location loc, fir::FirOpBuilder &,
+                                  hlfir::Entity, hlfir::Entity,
+                                  llvm::SmallVectorImpl<mlir::Value> &) {
+    fir::emitFatalError(loc, "Fortran::evaluate::Extremum are unexpected");
+  }
+};
+
 /// Convert parser's INTEGER relational operators to MLIR.
 static mlir::arith::CmpIPredicate
 translateRelational(Fortran::common::RelationalOperator rop) {
@@ -501,6 +515,42 @@ struct BinaryOp<Fortran::evaluate::SetLength<KIND>> {
                                          hlfir::Entity, hlfir::Entity) {
     TODO(loc, "SetLength lowering to HLFIR");
   }
+  /// The result length is the length operand (rhs).
+  /// NOTE(review): rhs is pushed as-is, not converted to index type like the
+  /// Concat result length; confirm that consumers accept its integer type.
+  static void
+  genResultTypeParams(mlir::Location loc, fir::FirOpBuilder &builder,
+                      hlfir::Entity lhs, hlfir::Entity rhs,
+                      llvm::SmallVectorImpl<mlir::Value> &resultTypeParams) {
+    resultTypeParams.push_back(rhs);
+  }
+};
+
+/// Lowering of character concatenation. This BinaryOp is stateful:
+/// genResultTypeParams computes the result length (len(lhs) + len(rhs), in
+/// index type) and caches it, so that gen can create the hlfir.concat.
+/// genResultTypeParams must therefore be called before gen (gen asserts it).
+template <int KIND>
+struct BinaryOp<Fortran::evaluate::Concat<KIND>> {
+  using Op = Fortran::evaluate::Concat<KIND>;
+
+  /// Create the hlfir.concat of lhs and rhs using the cached result length.
+  hlfir::EntityWithAttributes gen(mlir::Location loc,
+                                  fir::FirOpBuilder &builder, const Op &,
+                                  hlfir::Entity lhs, hlfir::Entity rhs) {
+    assert(len && "genResultTypeParams must have been called");
+    auto concatOp =
+        builder.create<hlfir::ConcatOp>(loc, mlir::ValueRange{lhs, rhs}, len);
+    mlir::Value result = concatOp.getResult();
+    return hlfir::EntityWithAttributes{result};
+  }
+
+  /// Compute len(lhs) + len(rhs) in index type, cache it for gen, and append
+  /// it to resultTypeParams.
+  void
+  genResultTypeParams(mlir::Location loc, fir::FirOpBuilder &builder,
+                      hlfir::Entity lhs, hlfir::Entity rhs,
+                      llvm::SmallVectorImpl<mlir::Value> &resultTypeParams) {
+    llvm::SmallVector<mlir::Value> operandLengths;
+    hlfir::genLengthParameters(loc, builder, lhs, operandLengths);
+    hlfir::genLengthParameters(loc, builder, rhs, operandLengths);
+    assert(operandLengths.size() == 2 && "lacks rhs or lhs length");
+    mlir::Type indexTy = builder.getIndexType();
+    auto toIndex = [&](mlir::Value length) -> mlir::Value {
+      return builder.createConvert(loc, indexTy, length);
+    };
+    mlir::Value lhsLen = toIndex(operandLengths[0]);
+    mlir::Value rhsLen = toIndex(operandLengths[1]);
+    len = builder.create<mlir::arith::AddIOp>(loc, lhsLen, rhsLen);
+    resultTypeParams.push_back(len);
+  }
+
+private:
+  /// Result length computed by genResultTypeParams.
+  mlir::Value len{};
 };
 
 //===--------------------------------------------------------------------===//
@@ -590,6 +640,13 @@ struct UnaryOp<Fortran::evaluate::Parentheses<T>> {
     return hlfir::EntityWithAttributes{
         builder.create<hlfir::NoReassocOp>(loc, lhs.getType(), lhs)};
   }
+
+  /// The result has the same length parameters as the operand.
+  static void
+  genResultTypeParams(mlir::Location loc, fir::FirOpBuilder &builder,
+                      hlfir::Entity lhs,
+                      llvm::SmallVectorImpl<mlir::Value> &resultTypeParams) {
+    hlfir::genLengthParameters(loc, builder, lhs, resultTypeParams);
+  }
 };
 
 template <Fortran::common::TypeCategory TC1, int KIND,
@@ -610,6 +667,13 @@ struct UnaryOp<
     mlir::Value res = builder.convertWithSemantics(loc, type, lhs);
     return hlfir::EntityWithAttributes{res};
   }
+
+  /// The result length parameters are taken from the operand.
+  static void
+  genResultTypeParams(mlir::Location loc, fir::FirOpBuilder &builder,
+                      hlfir::Entity lhs,
+                      llvm::SmallVectorImpl<mlir::Value> &resultTypeParams) {
+    hlfir::genLengthParameters(loc, builder, lhs, resultTypeParams);
+  }
 };
 
 /// Lower Expr to HLFIR.
@@ -695,10 +759,37 @@ class HlfirBuilder {
   gen(const Fortran::evaluate::Operation<D, R, O> &op) {
     auto &builder = getBuilder();
     mlir::Location loc = getLoc();
-    if (op.Rank() != 0)
-      TODO(loc, "elemental operations in HLFIR");
+    const int rank = op.Rank();
+    UnaryOp<D> unaryOp;
     auto left = hlfir::loadTrivialScalar(loc, builder, gen(op.left()));
-    return UnaryOp<D>::gen(loc, builder, op.derived(), left);
+    llvm::SmallVector<mlir::Value, 1> typeParams;
+    if constexpr (R::category == Fortran::common::TypeCategory::Character) {
+      unaryOp.genResultTypeParams(loc, builder, left, typeParams);
+    }
+    if (rank == 0)
+      return unaryOp.gen(loc, builder, op.derived(), left);
+
+    // Elemental expression.
+    mlir::Type elementType;
+    if constexpr (R::category == Fortran::common::TypeCategory::Derived) {
+      elementType = Fortran::lower::translateDerivedTypeToFIRType(
+          getConverter(), op.derived().GetType().GetDerivedTypeSpec());
+    } else {
+      elementType =
+          Fortran::lower::getFIRType(builder.getContext(), R::category, R::kind,
+                                     /*params=*/std::nullopt);
+    }
+    mlir::Value shape = hlfir::genShape(loc, builder, left);
+    auto genKernel = [&op, &left, &unaryOp](
+                         mlir::Location l, fir::FirOpBuilder &b,
+                         mlir::ValueRange oneBasedIndices) -> hlfir::Entity {
+      auto leftElement = hlfir::getElementAt(l, b, left, oneBasedIndices);
+      auto leftVal = hlfir::loadTrivialScalar(l, b, leftElement);
+      return unaryOp.gen(l, b, op.derived(), leftVal);
+    };
+    // TODO: deal with hlfir.elemental result destruction.
+    return hlfir::EntityWithAttributes{hlfir::genElementalOp(
+        loc, builder, elementType, shape, typeParams, genKernel)};
   }
 
   template <typename D, typename R, typename LO, typename RO>
@@ -706,30 +797,41 @@ class HlfirBuilder {
   gen(const Fortran::evaluate::Operation<D, R, LO, RO> &op) {
     auto &builder = getBuilder();
     mlir::Location loc = getLoc();
-    if (op.Rank() != 0)
-      TODO(loc, "elemental operations in HLFIR");
+    const int rank = op.Rank();
+    BinaryOp<D> binaryOp;
     auto left = hlfir::loadTrivialScalar(loc, builder, gen(op.left()));
     auto right = hlfir::loadTrivialScalar(loc, builder, gen(op.right()));
-    return BinaryOp<D>::gen(loc, builder, op.derived(), left, right);
-  }
-
-  template <int KIND>
-  hlfir::EntityWithAttributes gen(const Fortran::evaluate::Concat<KIND> &op) {
-    auto lhs = gen(op.left());
-    auto rhs = gen(op.right());
-    llvm::SmallVector<mlir::Value> lengths;
-    auto &builder = getBuilder();
-    mlir::Location loc = getLoc();
-    hlfir::genLengthParameters(loc, builder, lhs, lengths);
-    hlfir::genLengthParameters(loc, builder, rhs, lengths);
-    assert(lengths.size() == 2 && "lacks rhs or lhs length");
-    mlir::Type idxType = builder.getIndexType();
-    mlir::Value lhsLen = builder.createConvert(loc, idxType, lengths[0]);
-    mlir::Value rhsLen = builder.createConvert(loc, idxType, lengths[1]);
-    mlir::Value len = builder.create<mlir::arith::AddIOp>(loc, lhsLen, rhsLen);
-    auto concat =
-        builder.create<hlfir::ConcatOp>(loc, mlir::ValueRange{lhs, rhs}, len);
-    return hlfir::EntityWithAttributes{concat.getResult()};
+    llvm::SmallVector<mlir::Value, 1> typeParams;
+    if constexpr (R::category == Fortran::common::TypeCategory::Character) {
+      binaryOp.genResultTypeParams(loc, builder, left, right, typeParams);
+    }
+    if (rank == 0)
+      return binaryOp.gen(loc, builder, op.derived(), left, right);
+
+    // Elemental expression.
+    mlir::Type elementType =
+        Fortran::lower::getFIRType(builder.getContext(), R::category, R::kind,
+                                   /*params=*/std::nullopt);
+    // TODO: "merge" shape, get cst shape from front-end if possible.
+    mlir::Value shape;
+    if (left.isArray()) {
+      shape = hlfir::genShape(loc, builder, left);
+    } else {
+      assert(right.isArray() && "must have at least one array operand");
+      shape = hlfir::genShape(loc, builder, right);
+    }
+    auto genKernel = [&op, &left, &right, &binaryOp](
+                         mlir::Location l, fir::FirOpBuilder &b,
+                         mlir::ValueRange oneBasedIndices) -> hlfir::Entity {
+      auto leftElement = hlfir::getElementAt(l, b, left, oneBasedIndices);
+      auto rightElement = hlfir::getElementAt(l, b, right, oneBasedIndices);
+      auto leftVal = hlfir::loadTrivialScalar(l, b, leftElement);
+      auto rightVal = hlfir::loadTrivialScalar(l, b, rightElement);
+      return binaryOp.gen(l, b, op.derived(), leftVal, rightVal);
+    };
+    // TODO: deal with hlfir.elemental result destruction.
+    return hlfir::EntityWithAttributes{hlfir::genElementalOp(
+        loc, builder, elementType, shape, typeParams, genKernel)};
   }
 
   hlfir::EntityWithAttributes

diff  --git a/flang/lib/Optimizer/Builder/HLFIRTools.cpp b/flang/lib/Optimizer/Builder/HLFIRTools.cpp
index 73d45d37c6cf8..28791b546fad4 100644
--- a/flang/lib/Optimizer/Builder/HLFIRTools.cpp
+++ b/flang/lib/Optimizer/Builder/HLFIRTools.cpp
@@ -183,7 +183,7 @@ hlfir::AssociateOp hlfir::genAssociateExpr(mlir::Location loc,
   assert(value.isValue() && "must not be a variable");
   mlir::Value shape{};
   if (value.isArray())
-    TODO(loc, "associating array expressions");
+    shape = genShape(loc, builder, value);
 
   mlir::Value source = value;
   // Lowered scalar expression values for numerical and logical may have a
@@ -244,14 +244,63 @@ mlir::Value hlfir::genVariableBoxChar(mlir::Location loc,
+/// Dereference pointers and allocatables, then load the value of scalar
+/// variables of trivial type (fir::isa_trivial). Other entities are returned
+/// after the dereference.
 hlfir::Entity hlfir::loadTrivialScalar(mlir::Location loc,
                                        fir::FirOpBuilder &builder,
                                        Entity entity) {
+  // Dereference first, for all entities: pointers and allocatables that are
+  // not trivial scalars are returned dereferenced, trivial ones are loaded.
+  entity = derefPointersAndAllocatables(loc, builder, entity);
   if (entity.isVariable() && entity.isScalar() &&
       fir::isa_trivial(entity.getFortranElementType())) {
-    entity = derefPointersAndAllocatables(loc, builder, entity);
     return Entity{builder.create<fir::LoadOp>(loc, entity)};
   }
   return entity;
 }
 
+/// Return the explicit lower bounds of an entity that may have non default
+/// lower bounds, or std::nullopt when all its lower bounds are known to be
+/// one. Only variables with a FortranVariableOpInterface are supported so far.
+static std::optional<llvm::SmallVector<mlir::Value>>
+getNonDefaultLowerBounds(mlir::Location loc, fir::FirOpBuilder &builder,
+                         hlfir::Entity entity) {
+  if (!entity.hasNonDefaultLowerBounds())
+    return std::nullopt;
+  if (auto varIface = entity.getIfVariableInterface()) {
+    llvm::SmallVector<mlir::Value> lbounds = getExplicitLbounds(varIface);
+    if (!lbounds.empty())
+      return lbounds;
+  }
+  TODO(loc, "get non default lower bounds without FortranVariableInterface");
+}
+
+/// Return entity(oneBasedIndices) if entity is an array, or entity itself if
+/// it is a scalar. Array expressions are indexed with hlfir.apply. Array
+/// variables are indexed with hlfir.designate, whose indices are relative to
+/// the variable lower bounds, so the one based indices are shifted by
+/// (lb - 1) when the variable may have non default lower bounds.
+hlfir::Entity hlfir::getElementAt(mlir::Location loc,
+                                  fir::FirOpBuilder &builder, Entity entity,
+                                  mlir::ValueRange oneBasedIndices) {
+  if (entity.isScalar())
+    return entity;
+  llvm::SmallVector<mlir::Value> lenParams;
+  genLengthParameters(loc, builder, entity, lenParams);
+  if (entity.getType().isa<hlfir::ExprType>())
+    return hlfir::Entity{builder.create<hlfir::ApplyOp>(
+        loc, entity, oneBasedIndices, lenParams)};
+  // Build hlfir.designate. The lower bounds may need to be added to
+  // the oneBasedIndices since hlfir.designate expects indices
+  // based on the array operand lower bounds.
+  mlir::Type resultType = hlfir::getVariableElementType(entity);
+  hlfir::DesignateOp designate;
+  if (auto lbounds = getNonDefaultLowerBounds(loc, builder, entity)) {
+    // llvm::zip silently stops at the end of the shortest range: make sure
+    // that no index is dropped from the designator.
+    assert(lbounds->size() == oneBasedIndices.size() &&
+           "expected one lower bound per index");
+    llvm::SmallVector<mlir::Value> indices;
+    indices.reserve(oneBasedIndices.size());
+    mlir::Type idxTy = builder.getIndexType();
+    mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
+    for (auto [oneBased, lb] : llvm::zip(oneBasedIndices, *lbounds)) {
+      // index = oneBased + (lb - 1)
+      auto lbIdx = builder.createConvert(loc, idxTy, lb);
+      auto oneBasedIdx = builder.createConvert(loc, idxTy, oneBased);
+      auto shift = builder.create<mlir::arith::SubIOp>(loc, lbIdx, one);
+      mlir::Value index =
+          builder.create<mlir::arith::AddIOp>(loc, oneBasedIdx, shift);
+      indices.push_back(index);
+    }
+    designate = builder.create<hlfir::DesignateOp>(loc, resultType, entity,
+                                                   indices, lenParams);
+  } else {
+    designate = builder.create<hlfir::DesignateOp>(loc, resultType, entity,
+                                                   oneBasedIndices, lenParams);
+  }
+  return mlir::cast<fir::FortranVariableOpInterface>(designate.getOperation());
+}
+
 static mlir::Value genUBound(mlir::Location loc, fir::FirOpBuilder &builder,
                              mlir::Value lb, mlir::Value extent,
                              mlir::Value one) {
@@ -285,6 +334,45 @@ hlfir::genBounds(mlir::Location loc, fir::FirOpBuilder &builder,
   return result;
 }
 
+/// Look through hlfir.no_reassoc and hlfir.as_expr operations and return the
+/// entity they were (transitively) created from.
+static hlfir::Entity followEntitySource(hlfir::Entity entity) {
+  for (;;) {
+    if (auto noReassoc = entity.getDefiningOp<hlfir::NoReassocOp>())
+      entity = hlfir::Entity{noReassoc.getVal()};
+    else if (auto asExpr = entity.getDefiningOp<hlfir::AsExprOp>())
+      entity = hlfir::Entity{asExpr.getVar()};
+    else
+      return entity;
+  }
+}
+
+/// Compute a fir.shape<> (extents only, no lower bounds) for an array entity.
+/// Pointers and allocatables are dereferenced first. Other entities are looked
+/// through hlfir.no_reassoc and hlfir.as_expr, so that the shape can be taken
+/// from the operation that created the entity.
+mlir::Value hlfir::genShape(mlir::Location loc, fir::FirOpBuilder &builder,
+                            hlfir::Entity entity) {
+  assert(entity.isArray() && "entity must be an array");
+  if (entity.isMutableBox())
+    entity = hlfir::derefPointersAndAllocatables(loc, builder, entity);
+  else
+    entity = followEntitySource(entity);
+
+  if (auto varIface = entity.getIfVariableInterface()) {
+    if (auto shape = varIface.getShape()) {
+      if (shape.getType().isa<fir::ShapeType>())
+        return shape;
+      // Drop the lower bounds: only the extents are part of a fir.shape<>.
+      if (shape.getType().isa<fir::ShapeShiftType>())
+        if (auto s = shape.getDefiningOp<fir::ShapeShiftOp>())
+          return builder.create<fir::ShapeOp>(loc, s.getExtents());
+    }
+    // The variable has an interface, but its shape operand (if any) does not
+    // directly provide the extents (e.g. a fir.shift).
+    TODO(loc, "get shape from HLFIR variable without fir.shape or "
+              "fir.shape_shift extents");
+  } else if (entity.getType().isa<hlfir::ExprType>()) {
+    if (auto elemental = entity.getDefiningOp<hlfir::ElementalOp>())
+      return elemental.getShape();
+    TODO(loc, "get shape from HLFIR expr without producer holding the shape");
+  }
+  TODO(loc, "get shape from HLFIR variable without interface");
+}
+
 void hlfir::genLengthParameters(mlir::Location loc, fir::FirOpBuilder &builder,
                                 Entity entity,
                                 llvm::SmallVectorImpl<mlir::Value> &result) {
@@ -304,6 +392,12 @@ void hlfir::genLengthParameters(mlir::Location loc, fir::FirOpBuilder &builder,
       hlfir::genLengthParameters(loc, builder, hlfir::Entity{asExpr.getVar()},
                                  result);
       return;
+    } else if (auto elemental = expr.getDefiningOp<hlfir::ElementalOp>()) {
+      result.append(elemental.getTypeparams().begin(),
+                    elemental.getTypeparams().end());
+      return;
+    } else if (auto apply = expr.getDefiningOp<hlfir::ApplyOp>()) {
+      result.append(apply.getTypeparams().begin(), apply.getTypeparams().end());
+      // Without this return, the type parameters that were just appended
+      // would be followed by the "inquire type parameters" TODO below.
+      return;
     }
     TODO(loc, "inquire type parameters of hlfir.expr");
   }
@@ -340,3 +434,53 @@ hlfir::Entity hlfir::derefPointersAndAllocatables(mlir::Location loc,
     return hlfir::Entity{builder.create<fir::LoadOp>(loc, entity).getResult()};
   return entity;
 }
+
+/// Variable type of an element of an array variable: fir.class for
+/// polymorphic entities, fir.boxchar for characters with dynamic length,
+/// fir.box for records with type parameters, and fir.ref otherwise. Scalar
+/// variables keep their own type.
+mlir::Type hlfir::getVariableElementType(hlfir::Entity variable) {
+  assert(variable.isVariable() && "entity must be a variable");
+  if (variable.isScalar())
+    return variable.getType();
+  mlir::Type eleTy = variable.getFortranElementType();
+  if (variable.isPolymorphic())
+    return fir::ClassType::get(eleTy);
+  auto charType = eleTy.dyn_cast<fir::CharacterType>();
+  if (charType && charType.hasDynamicLen())
+    return fir::BoxCharType::get(charType.getContext(), charType.getFKind());
+  if (!charType && fir::isRecordWithTypeParameters(eleTy))
+    return fir::BoxType::get(eleTy);
+  return fir::ReferenceType::get(eleTy);
+}
+
+/// Build the hlfir.expr type of an array with the given element type and
+/// fir.shape<>. Extents that are constant operands of a fir.shape producer
+/// are put in the type; the other extents are unknown.
+static hlfir::ExprType getArrayExprType(mlir::Type elementType,
+                                        mlir::Value shape, bool isPolymorphic) {
+  auto shapeType = shape.getType().cast<fir::ShapeType>();
+  unsigned rank = shapeType.getRank();
+  hlfir::ExprType::Shape typeShape(rank, hlfir::ExprType::getUnknownExtent());
+  if (auto shapeOp = shape.getDefiningOp<fir::ShapeOp>()) {
+    unsigned dim = 0;
+    for (mlir::Value extent : shapeOp.getExtents()) {
+      if (auto cstExtent = fir::factory::getIntIfConstant(extent))
+        typeShape[dim] = *cstExtent;
+      ++dim;
+    }
+  }
+  return hlfir::ExprType::get(elementType.getContext(), typeShape, elementType,
+                              isPolymorphic);
+}
+
+/// Create an hlfir.elemental of the given shape and type parameters, fill its
+/// body with the element value computed by genKernel followed by an
+/// hlfir.yield_element, and restore the builder insertion point on exit.
+hlfir::ElementalOp
+hlfir::genElementalOp(mlir::Location loc, fir::FirOpBuilder &builder,
+                      mlir::Type elementType, mlir::Value shape,
+                      mlir::ValueRange typeParams,
+                      const ElementalKernelGenerator &genKernel) {
+  mlir::Type exprType =
+      getArrayExprType(elementType, shape, /*isPolymorphic=*/false);
+  auto elementalOp =
+      builder.create<hlfir::ElementalOp>(loc, exprType, shape, typeParams);
+  // Generate the body inside the elemental region; the guard puts the
+  // insertion point back after the elemental when leaving this function.
+  mlir::OpBuilder::InsertionGuard guard(builder);
+  builder.setInsertionPointToStart(elementalOp.getBody());
+  mlir::Value element = genKernel(loc, builder, elementalOp.getIndices());
+  // Numerical and logical scalars may be lowered to another type than the
+  // Fortran expression type (e.g i1 instead of fir.logical). Array expression
+  // values are typed according to their Fortran type: cast if needed.
+  if (fir::isa_trivial(element.getType()))
+    element = builder.createConvert(loc, elementType, element);
+  builder.create<hlfir::YieldElementOp>(loc, element);
+  return elementalOp;
+}

diff  --git a/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp b/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
index 0406d209d4c14..e537e7cfe3a35 100644
--- a/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
+++ b/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
@@ -116,6 +116,22 @@ void hlfir::DesignateOp::build(
         fortran_attrs);
 }
 
+/// Build an hlfir.designate that only indexes memref with scalar indices: no
+/// component, no triplets, no substring, no complex part and no result shape.
+void hlfir::DesignateOp::build(mlir::OpBuilder &builder,
+                               mlir::OperationState &result,
+                               mlir::Type result_type, mlir::Value memref,
+                               mlir::ValueRange indices,
+                               mlir::ValueRange typeparams,
+                               fir::FortranVariableFlagsAttr fortran_attrs) {
+  // Every index is a scalar index: none of them is a triplet.
+  auto isTripletAttr = mlir::DenseBoolArrayAttr::get(
+      builder.getContext(), llvm::SmallVector<bool>(indices.size(), false));
+  build(builder, result, result_type, memref,
+        /*componentAttr=*/mlir::StringAttr{}, /*component_shape=*/mlir::Value{},
+        indices, isTripletAttr, /*substring*/ mlir::ValueRange{},
+        /*complexPartAttr=*/mlir::BoolAttr{}, /*shape=*/mlir::Value{},
+        typeparams, fortran_attrs);
+}
+
 static mlir::ParseResult parseDesignatorIndices(
     mlir::OpAsmParser &parser,
     llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &indices,

diff  --git a/flang/test/Lower/HLFIR/elemental-array-ops.f90 b/flang/test/Lower/HLFIR/elemental-array-ops.f90
new file mode 100644
index 0000000000000..a9bde0ec577de
--- /dev/null
+++ b/flang/test/Lower/HLFIR/elemental-array-ops.f90
@@ -0,0 +1,128 @@
+! Test lowering of elemental intrinsic operations with array arguments to HLFIR
+! RUN: bbc -emit-fir -hlfir -o - %s 2>&1 | FileCheck %s
+
+subroutine binary(x, y) ! array + array: lowered to a single hlfir.elemental
+  integer :: x(100), y(100)
+  x = x+y
+end subroutine
+! CHECK-LABEL: func.func @_QPbinary(
+! CHECK:  %[[VAL_4:.*]]:2 = hlfir.declare %{{.*}}(%[[VAL_3:[^)]*]]) {{.*}}x
+! CHECK:  %[[VAL_7:.*]]:2 = hlfir.declare %{{.*}}(%[[VAL_6:[^)]*]]) {{.*}}y
+! CHECK:  %[[VAL_8:.*]] = hlfir.elemental %[[VAL_3]] : (!fir.shape<1>) -> !hlfir.expr<100xi32> {
+! CHECK:  ^bb0(%[[VAL_9:.*]]: index):
+! CHECK:    %[[VAL_10:.*]] = hlfir.designate %[[VAL_4]]#0 (%[[VAL_9]])  : (!fir.ref<!fir.array<100xi32>>, index) -> !fir.ref<i32>
+! CHECK:    %[[VAL_11:.*]] = hlfir.designate %[[VAL_7]]#0 (%[[VAL_9]])  : (!fir.ref<!fir.array<100xi32>>, index) -> !fir.ref<i32>
+! CHECK:    %[[VAL_12:.*]] = fir.load %[[VAL_10]] : !fir.ref<i32>
+! CHECK:    %[[VAL_13:.*]] = fir.load %[[VAL_11]] : !fir.ref<i32>
+! CHECK:    %[[VAL_14:.*]] = arith.addi %[[VAL_12]], %[[VAL_13]] : i32
+! CHECK:    hlfir.yield_element %[[VAL_14]] : i32
+! CHECK:  }
+
+subroutine binary_with_scalar_and_array(x, y) ! scalar y is loaded once, before the elemental
+  integer :: x(100), y
+  x = x+y
+end subroutine
+! CHECK-LABEL: func.func @_QPbinary_with_scalar_and_array(
+! CHECK:  %[[VAL_4:.*]]:2 = hlfir.declare %{{.*}}(%[[VAL_3:[^)]*]]) {{.*}}x
+! CHECK:  %[[VAL_5:.*]]:2 = hlfir.declare %{{.*}} {{.*}}y
+! CHECK:  %[[VAL_6:.*]] = fir.load %[[VAL_5]]#0 : !fir.ref<i32>
+! CHECK:  %[[VAL_7:.*]] = hlfir.elemental %[[VAL_3]] : (!fir.shape<1>) -> !hlfir.expr<100xi32> {
+! CHECK:  ^bb0(%[[VAL_8:.*]]: index):
+! CHECK:    %[[VAL_9:.*]] = hlfir.designate %[[VAL_4]]#0 (%[[VAL_8]])  : (!fir.ref<!fir.array<100xi32>>, index) -> !fir.ref<i32>
+! CHECK:    %[[VAL_10:.*]] = fir.load %[[VAL_9]] : !fir.ref<i32>
+! CHECK:    %[[VAL_11:.*]] = arith.addi %[[VAL_10]], %[[VAL_6]] : i32
+! CHECK:    hlfir.yield_element %[[VAL_11]] : i32
+! CHECK:  }
+
+subroutine char_binary(x, y) ! concat: result length len(x)+len(y) computed before the elemental
+  character(*) :: x(100), y(100)
+  call test_char(x//y)
+end subroutine
+! CHECK-LABEL: func.func @_QPchar_binary(
+! CHECK:  %[[VAL_6:.*]]:2 = hlfir.declare %{{.*}}(%[[VAL_5:.*]]) typeparams %[[VAL_2:.*]]#1 {{.*}}x
+! CHECK:  %[[VAL_11:.*]]:2 = hlfir.declare %{{.*}}(%[[VAL_10:.*]]) typeparams %[[VAL_7:.*]]#1 {{.*}}y
+! CHECK:  %[[VAL_12:.*]] = arith.addi %[[VAL_2]]#1, %[[VAL_7]]#1 : index
+! CHECK:  %[[VAL_13:.*]] = hlfir.elemental %[[VAL_5]] typeparams %[[VAL_12]] : (!fir.shape<1>, index) -> !hlfir.expr<100x!fir.char<1,?>> {
+! CHECK:  ^bb0(%[[VAL_14:.*]]: index):
+! CHECK:    %[[VAL_15:.*]] = hlfir.designate %[[VAL_6]]#0 (%[[VAL_14]])  typeparams %[[VAL_2]]#1 : (!fir.box<!fir.array<100x!fir.char<1,?>>>, index, index) -> !fir.boxchar<1>
+! CHECK:    %[[VAL_16:.*]] = hlfir.designate %[[VAL_11]]#0 (%[[VAL_14]])  typeparams %[[VAL_7]]#1 : (!fir.box<!fir.array<100x!fir.char<1,?>>>, index, index) -> !fir.boxchar<1>
+! CHECK:    %[[VAL_17:.*]] = hlfir.concat %[[VAL_15]], %[[VAL_16]] len %[[VAL_12]] : (!fir.boxchar<1>, !fir.boxchar<1>, index) -> !hlfir.expr<!fir.char<1,?>>
+! CHECK:    hlfir.yield_element %[[VAL_17]] : !hlfir.expr<!fir.char<1,?>>
+! CHECK:  }
+
+subroutine unary(x, n) ! .not. on a logical array whose extent n is only known at run time
+  integer :: n
+  logical :: x(n)
+  x = .not.x
+end subroutine
+! CHECK-LABEL: func.func @_QPunary(
+! CHECK:  %[[VAL_10:.*]]:2 = hlfir.declare %{{.*}}(%[[VAL_9:[^)]*]]) {{.*}}x
+! CHECK:  %[[VAL_11:.*]] = hlfir.elemental %[[VAL_9]] : (!fir.shape<1>) -> !hlfir.expr<?x!fir.logical<4>> {
+! CHECK:  ^bb0(%[[VAL_12:.*]]: index):
+! CHECK:    %[[VAL_13:.*]] = hlfir.designate %[[VAL_10]]#0 (%[[VAL_12]])  : (!fir.box<!fir.array<?x!fir.logical<4>>>, index) -> !fir.ref<!fir.logical<4>>
+! CHECK:    %[[VAL_14:.*]] = fir.load %[[VAL_13]] : !fir.ref<!fir.logical<4>>
+! CHECK:    %[[VAL_15:.*]] = arith.constant true
+! CHECK:    %[[VAL_16:.*]] = fir.convert %[[VAL_14]] : (!fir.logical<4>) -> i1
+! CHECK:    %[[VAL_17:.*]] = arith.xori %[[VAL_16]], %[[VAL_15]] : i1
+! CHECK:    %[[VAL_18:.*]] = fir.convert %[[VAL_17]] : (i1) -> !fir.logical<4>
+! CHECK:    hlfir.yield_element %[[VAL_18]] : !fir.logical<4>
+! CHECK:  }
+
+subroutine char_unary(x) ! parenthesized character array: each element becomes an hlfir.as_expr
+  character(10) :: x(20)
+  call test_char_2((x))
+end subroutine
+! CHECK-LABEL: func.func @_QPchar_unary(
+! CHECK:  %[[VAL_6:.*]]:2 = hlfir.declare %{{.*}}(%[[VAL_5:.*]]) typeparams %[[VAL_2:[^ ]*]] {{.*}}x
+! CHECK:  %[[VAL_7:.*]] = hlfir.elemental %[[VAL_5]] typeparams %[[VAL_2]] : (!fir.shape<1>, index) -> !hlfir.expr<20x!fir.char<1,?>> {
+! CHECK:  ^bb0(%[[VAL_8:.*]]: index):
+! CHECK:    %[[VAL_9:.*]] = hlfir.designate %[[VAL_6]]#0 (%[[VAL_8]])  typeparams %[[VAL_2]] : (!fir.ref<!fir.array<20x!fir.char<1,10>>>, index, index) -> !fir.ref<!fir.char<1,10>>
+! CHECK:    %[[VAL_10:.*]] = hlfir.as_expr %[[VAL_9]] : (!fir.ref<!fir.char<1,10>>) -> !hlfir.expr<!fir.char<1,10>>
+! CHECK:    hlfir.yield_element %[[VAL_10]] : !hlfir.expr<!fir.char<1,10>>
+! CHECK:  }
+
+subroutine chained_elemental(x, y, z) ! x+y+z: the second elemental reads the first via hlfir.apply
+  integer :: x(100), y(100), z(100)
+  x = x+y+z
+end subroutine
+! CHECK-LABEL: func.func @_QPchained_elemental(
+! CHECK:  %[[VAL_5:.*]]:2 = hlfir.declare %{{.*}}(%[[VAL_4:[^)]*]]) {{.*}}x
+! CHECK:  %[[VAL_8:.*]]:2 = hlfir.declare %{{.*}}(%[[VAL_7:[^)]*]]) {{.*}}y
+! CHECK:  %[[VAL_11:.*]]:2 = hlfir.declare %{{.*}}(%[[VAL_10:[^)]*]]) {{.*}}z
+! CHECK:  %[[VAL_12:.*]] = hlfir.elemental %[[VAL_4]] : (!fir.shape<1>) -> !hlfir.expr<100xi32> {
+! CHECK:  ^bb0(%[[VAL_13:.*]]: index):
+! CHECK:    %[[VAL_14:.*]] = hlfir.designate %[[VAL_5]]#0 (%[[VAL_13]])  : (!fir.ref<!fir.array<100xi32>>, index) -> !fir.ref<i32>
+! CHECK:    %[[VAL_15:.*]] = hlfir.designate %[[VAL_8]]#0 (%[[VAL_13]])  : (!fir.ref<!fir.array<100xi32>>, index) -> !fir.ref<i32>
+! CHECK:    %[[VAL_16:.*]] = fir.load %[[VAL_14]] : !fir.ref<i32>
+! CHECK:    %[[VAL_17:.*]] = fir.load %[[VAL_15]] : !fir.ref<i32>
+! CHECK:    %[[VAL_18:.*]] = arith.addi %[[VAL_16]], %[[VAL_17]] : i32
+! CHECK:    hlfir.yield_element %[[VAL_18]] : i32
+! CHECK:  }
+! CHECK:  %[[VAL_19:.*]] = hlfir.elemental %[[VAL_4]] : (!fir.shape<1>) -> !hlfir.expr<100xi32> {
+! CHECK:  ^bb0(%[[VAL_20:.*]]: index):
+! CHECK:    %[[VAL_21:.*]] = hlfir.apply %[[VAL_22:.*]], %[[VAL_20]] : (!hlfir.expr<100xi32>, index) -> i32
+! CHECK:    %[[VAL_23:.*]] = hlfir.designate %[[VAL_11]]#0 (%[[VAL_20]])  : (!fir.ref<!fir.array<100xi32>>, index) -> !fir.ref<i32>
+! CHECK:    %[[VAL_24:.*]] = fir.load %[[VAL_23]] : !fir.ref<i32>
+! CHECK:    %[[VAL_25:.*]] = arith.addi %[[VAL_21]], %[[VAL_24]] : i32
+! CHECK:    hlfir.yield_element %[[VAL_25]] : i32
+! CHECK:  }
+
+subroutine lower_bounds(x) ! non-default lower bound: elemental index is shifted by (lb - 1)
+  integer :: x(2:101)
+  call test((x))
+end subroutine
+! CHECK-LABEL: func.func @_QPlower_bounds(
+! CHECK:  %[[VAL_1:.*]] = arith.constant 2 : index
+! CHECK:  %[[VAL_2:.*]] = arith.constant 100 : index
+! CHECK:  %[[VAL_4:.*]]:2 = hlfir.declare %{{.*}}(%[[VAL_3:[^)]*]]) {{.*}}x
+! CHECK:  %[[VAL_5:.*]] = fir.shape %[[VAL_2]] : (index) -> !fir.shape<1>
+! CHECK:  %[[VAL_6:.*]] = hlfir.elemental %[[VAL_5]] : (!fir.shape<1>) -> !hlfir.expr<100xi32> {
+! CHECK:  ^bb0(%[[VAL_7:.*]]: index):
+! CHECK:    %[[VAL_8:.*]] = arith.constant 1 : index
+! CHECK:    %[[VAL_9:.*]] = arith.subi %[[VAL_1]], %[[VAL_8]] : index
+! CHECK:    %[[VAL_10:.*]] = arith.addi %[[VAL_7]], %[[VAL_9]] : index
+! CHECK:    %[[VAL_11:.*]] = hlfir.designate %[[VAL_4]]#0 (%[[VAL_10]])  : (!fir.box<!fir.array<100xi32>>, index) -> !fir.ref<i32>
+! CHECK:    %[[VAL_12:.*]] = fir.load %[[VAL_11]] : !fir.ref<i32>
+! CHECK:    %[[VAL_13:.*]] = hlfir.no_reassoc %[[VAL_12]] : i32
+! CHECK:    hlfir.yield_element %[[VAL_13]] : i32
+! CHECK:  }


        


More information about the flang-commits mailing list