[flang-commits] [flang] 93fea7d - [flang][hlfir] Support mold operand for hlfir.elemental.

Slava Zakharin via flang-commits flang-commits at lists.llvm.org
Tue Aug 8 09:59:03 PDT 2023


Author: Slava Zakharin
Date: 2023-08-08T09:58:48-07:00
New Revision: 93fea7dd11477983a6903187fa6fec65be1ffe1b

URL: https://github.com/llvm/llvm-project/commit/93fea7dd11477983a6903187fa6fec65be1ffe1b
DIFF: https://github.com/llvm/llvm-project/commit/93fea7dd11477983a6903187fa6fec65be1ffe1b.diff

LOG: [flang][hlfir] Support mold operand for hlfir.elemental.

To properly create a temporary array for a polymorphic result
of hlfir.elemental, we need to keep the mold as its operand.
This patch adds only the basic support.

Reviewed By: clementval, tblah

Differential Revision: https://reviews.llvm.org/D157315

Added: 
    

Modified: 
    flang/include/flang/Optimizer/Builder/HLFIRTools.h
    flang/include/flang/Optimizer/HLFIR/HLFIRDialect.h
    flang/include/flang/Optimizer/HLFIR/HLFIROpBase.td
    flang/include/flang/Optimizer/HLFIR/HLFIROps.td
    flang/lib/Lower/ConvertArrayConstructor.cpp
    flang/lib/Optimizer/Builder/HLFIRTools.cpp
    flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp
    flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
    flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
    flang/test/HLFIR/elemental.fir
    flang/test/HLFIR/invalid.fir

Removed: 
    


################################################################################
diff --git a/flang/include/flang/Optimizer/Builder/HLFIRTools.h b/flang/include/flang/Optimizer/Builder/HLFIRTools.h
index 9d631dba9412cf..6d73ebc3a7e1d9 100644
--- a/flang/include/flang/Optimizer/Builder/HLFIRTools.h
+++ b/flang/include/flang/Optimizer/Builder/HLFIRTools.h
@@ -374,14 +374,15 @@ using ElementalKernelGenerator = std::function<hlfir::Entity(
     mlir::Location, fir::FirOpBuilder &, mlir::ValueRange)>;
 /// Generate an hlfir.elementalOp given call back to generate the element
 /// value at for each iteration.
-/// If exprType is specified, this will be the return type of the elemental op
-hlfir::ElementalOp genElementalOp(mlir::Location loc,
-                                  fir::FirOpBuilder &builder,
-                                  mlir::Type elementType, mlir::Value shape,
-                                  mlir::ValueRange typeParams,
-                                  const ElementalKernelGenerator &genKernel,
-                                  bool isUnordered = false,
-                                  mlir::Type exprType = mlir::Type{});
+/// If exprType is specified, this will be the return type of the elemental op.
+/// If exprType is not specified, the resulting expression type is computed
+/// from the given \p elementType and \p shape, and the type is polymorphic
+/// if \p polymorphicMold is present.
+hlfir::ElementalOp genElementalOp(
+    mlir::Location loc, fir::FirOpBuilder &builder, mlir::Type elementType,
+    mlir::Value shape, mlir::ValueRange typeParams,
+    const ElementalKernelGenerator &genKernel, bool isUnordered = false,
+    mlir::Value polymorphicMold = {}, mlir::Type exprType = mlir::Type{});
 
 /// Structure to describe a loop nest.
 struct LoopNest {

diff --git a/flang/include/flang/Optimizer/HLFIR/HLFIRDialect.h b/flang/include/flang/Optimizer/HLFIR/HLFIRDialect.h
index d080286f0e0929..b76063fb7c5353 100644
--- a/flang/include/flang/Optimizer/HLFIR/HLFIRDialect.h
+++ b/flang/include/flang/Optimizer/HLFIR/HLFIRDialect.h
@@ -87,6 +87,7 @@ bool isPassByRefOrIntegerType(mlir::Type);
 bool isI1Type(mlir::Type);
 // scalar i1 or logical, or sequence of logical (via (boxed?) array or expr)
 bool isMaskArgument(mlir::Type);
+bool isPolymorphicObject(mlir::Type);
 
 /// If an expression's extents are known at compile time, generate a fir.shape
 /// for this expression. Otherwise return {}

diff --git a/flang/include/flang/Optimizer/HLFIR/HLFIROpBase.td b/flang/include/flang/Optimizer/HLFIR/HLFIROpBase.td
index 324689d22d4cbb..018e187ed46e69 100644
--- a/flang/include/flang/Optimizer/HLFIR/HLFIROpBase.td
+++ b/flang/include/flang/Optimizer/HLFIR/HLFIROpBase.td
@@ -149,6 +149,11 @@ def IsFortranLogicalArrayPred
 def AnyFortranLogicalArrayObject : Type<IsFortranLogicalArrayPred,
     "any array-like object containing logicals">;
 
+def IsPolymorphicObjectPred
+        : CPred<"::hlfir::isPolymorphicObject($_self)">;
+def AnyPolymorphicObject : Type<IsPolymorphicObjectPred,
+    "any polymorphic object">;
+
 def hlfir_CharExtremumPredicateAttr : I32EnumAttr<
     "CharExtremumPredicate", "",
     [

diff --git a/flang/include/flang/Optimizer/HLFIR/HLFIROps.td b/flang/include/flang/Optimizer/HLFIR/HLFIROps.td
index d5114ec3de9b7d..24c2dad497fd2c 100644
--- a/flang/include/flang/Optimizer/HLFIR/HLFIROps.td
+++ b/flang/include/flang/Optimizer/HLFIR/HLFIROps.td
@@ -740,7 +740,7 @@ def hlfir_ElementalOpInterface : OpInterface<"ElementalOpInterface"> {
   let cppNamespace = "hlfir";
 }
 
-def hlfir_ElementalOp : hlfir_Op<"elemental", [RecursiveMemoryEffects, hlfir_ElementalOpInterface]> {
+def hlfir_ElementalOp : hlfir_Op<"elemental", [RecursiveMemoryEffects, hlfir_ElementalOpInterface, AttrSizedOperandSegments]> {
   let summary = "elemental expression";
   let description = [{
     Represent an elemental expression as a function of the indices.
@@ -753,6 +753,12 @@ def hlfir_ElementalOp : hlfir_Op<"elemental", [RecursiveMemoryEffects, hlfir_Ele
     The shape and typeparams operands represent the extents and type
     parameters of the resulting array value.
 
+    The optional mold is an entity carrying the information about
+    the dynamic type of the polymorphic result. Note that the shape
+    of the mold does not necessarily match the shape of the result,
+    for example, the result of `merge(poly_scalar1, poly_scalar2, mask_array)`
+    will have the shape of `mask_array` and the dynamic type of `poly_scalar*`.
+
     The unordered attribute can be set to allow out of order processing
     of the indices. This is safe only if the operations in the body
     of the elemental do not have side effects.
@@ -775,6 +781,7 @@ def hlfir_ElementalOp : hlfir_Op<"elemental", [RecursiveMemoryEffects, hlfir_Ele
 
   let arguments = (ins
     AnyShapeType:$shape,
+    Optional<AnyPolymorphicObject>:$mold,
     Variadic<AnyIntegerType>:$typeparams,
     OptionalAttr<UnitAttr>:$unordered
   );
@@ -783,7 +790,8 @@ def hlfir_ElementalOp : hlfir_Op<"elemental", [RecursiveMemoryEffects, hlfir_Ele
   let regions = (region SizedRegion<1>:$region);
 
   let assemblyFormat = [{
-    $shape (`typeparams` $typeparams^)? (`unordered` $unordered^)?
+    $shape (`mold` $mold^)? (`typeparams` $typeparams^)?
+    (`unordered` $unordered^)?
     attr-dict `:` functional-type(operands, results)
     $region
     }];
@@ -808,10 +816,12 @@ def hlfir_ElementalOp : hlfir_Op<"elemental", [RecursiveMemoryEffects, hlfir_Ele
   let skipDefaultBuilders = 1;
   let builders = [
     OpBuilder<(ins "mlir::Type":$result_type, "mlir::Value":$shape,
+      CArg<"mlir::Value", "{}">:$mold,
       CArg<"mlir::ValueRange", "{}">:$typeparams,
       CArg<"bool", "false">:$isUnordered)>
   ];
 
+  let hasVerifier = 1;
 }
 
 def hlfir_YieldElementOp : hlfir_Op<"yield_element", [Terminator, HasParent<"ElementalOp">, Pure]> {

diff --git a/flang/lib/Lower/ConvertArrayConstructor.cpp b/flang/lib/Lower/ConvertArrayConstructor.cpp
index 2ef500ecf22dba..24aa9beba6bf48 100644
--- a/flang/lib/Lower/ConvertArrayConstructor.cpp
+++ b/flang/lib/Lower/ConvertArrayConstructor.cpp
@@ -214,9 +214,9 @@ class AsElementalStrategy : public StrategyBase {
     assert(!elementalOp && "expected only one implied-do");
     mlir::Value one =
         builder.createIntegerConstant(loc, builder.getIndexType(), 1);
-    elementalOp =
-        builder.create<hlfir::ElementalOp>(loc, exprType, shape, lengthParams,
-                                           /*isUnordered=*/true);
+    elementalOp = builder.create<hlfir::ElementalOp>(
+        loc, exprType, shape,
+        /*mold=*/nullptr, lengthParams, /*isUnordered=*/true);
     builder.setInsertionPointToStart(elementalOp.getBody());
     // implied-do-index = lower+((i-1)*stride)
     mlir::Value diff = builder.create<mlir::arith::SubIOp>(

diff --git a/flang/lib/Optimizer/Builder/HLFIRTools.cpp b/flang/lib/Optimizer/Builder/HLFIRTools.cpp
index 39346094911649..dd62aa0e370122 100644
--- a/flang/lib/Optimizer/Builder/HLFIRTools.cpp
+++ b/flang/lib/Optimizer/Builder/HLFIRTools.cpp
@@ -737,16 +737,15 @@ static hlfir::ExprType getArrayExprType(mlir::Type elementType,
                               isPolymorphic);
 }
 
-hlfir::ElementalOp
-hlfir::genElementalOp(mlir::Location loc, fir::FirOpBuilder &builder,
-                      mlir::Type elementType, mlir::Value shape,
-                      mlir::ValueRange typeParams,
-                      const ElementalKernelGenerator &genKernel,
-                      bool isUnordered, mlir::Type exprType) {
+hlfir::ElementalOp hlfir::genElementalOp(
+    mlir::Location loc, fir::FirOpBuilder &builder, mlir::Type elementType,
+    mlir::Value shape, mlir::ValueRange typeParams,
+    const ElementalKernelGenerator &genKernel, bool isUnordered,
+    mlir::Value polymorphicMold, mlir::Type exprType) {
   if (!exprType)
-    exprType = getArrayExprType(elementType, shape, false);
+    exprType = getArrayExprType(elementType, shape, !!polymorphicMold);
   auto elementalOp = builder.create<hlfir::ElementalOp>(
-      loc, exprType, shape, typeParams, isUnordered);
+      loc, exprType, shape, polymorphicMold, typeParams, isUnordered);
   auto insertPt = builder.saveInsertionPoint();
   builder.setInsertionPointToStart(elementalOp.getBody());
   mlir::Value elementResult = genKernel(loc, builder, elementalOp.getIndices());

diff --git a/flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp b/flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp
index 1f4f62f29e3dbd..7ca6108a31acbb 100644
--- a/flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp
+++ b/flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp
@@ -181,6 +181,13 @@ bool hlfir::isMaskArgument(mlir::Type type) {
   return mlir::isa<fir::LogicalType>(elementType) || isI1Type(elementType);
 }
 
+bool hlfir::isPolymorphicObject(mlir::Type type) {
+  if (auto exprType = mlir::dyn_cast<hlfir::ExprType>(type))
+    return exprType.isPolymorphic();
+
+  return fir::isPolymorphicType(type);
+}
+
 mlir::Value hlfir::genExprShape(mlir::OpBuilder &builder,
                                 const mlir::Location &loc,
                                 const hlfir::ExprType &expr) {

diff --git a/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp b/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
index a74d6f94f4df14..0b4b9c1588efaf 100644
--- a/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
+++ b/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
@@ -1036,10 +1036,17 @@ void hlfir::AsExprOp::build(mlir::OpBuilder &builder,
 void hlfir::ElementalOp::build(mlir::OpBuilder &builder,
                                mlir::OperationState &odsState,
                                mlir::Type resultType, mlir::Value shape,
-                               mlir::ValueRange typeparams, bool isUnordered) {
+                               mlir::Value mold, mlir::ValueRange typeparams,
+                               bool isUnordered) {
   odsState.addOperands(shape);
+  if (mold)
+    odsState.addOperands(mold);
   odsState.addOperands(typeparams);
   odsState.addTypes(resultType);
+  odsState.addAttribute(
+      getOperandSegmentSizesAttrName(odsState.name),
+      builder.getDenseI32ArrayAttr({/*shape=*/1, (mold ? 1 : 0),
+                                    static_cast<int32_t>(typeparams.size())}));
   if (isUnordered)
     odsState.addAttribute(getUnorderedAttrName(odsState.name),
                           isUnordered ? builder.getUnitAttr() : nullptr);
@@ -1057,6 +1064,16 @@ mlir::Value hlfir::ElementalOp::getElementEntity() {
   return mlir::cast<hlfir::YieldElementOp>(getBody()->back()).getElementValue();
 }
 
+mlir::LogicalResult hlfir::ElementalOp::verify() {
+  mlir::Value mold = getMold();
+  hlfir::ExprType resultType = mlir::cast<hlfir::ExprType>(getType());
+  if (!!mold != resultType.isPolymorphic())
+    return emitOpError("result must be polymorphic when mold is present "
+                       "and vice versa");
+
+  return mlir::success();
+}
+
 //===----------------------------------------------------------------------===//
 // ApplyOp
 //===----------------------------------------------------------------------===//

diff --git a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
index 6206deee411c34..5f065056bac00c 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
@@ -58,7 +58,8 @@ class TransposeAsElementalConversion
     };
     hlfir::ElementalOp elementalOp = hlfir::genElementalOp(
         loc, builder, elementType, resultShape, typeParams, genKernel,
-        /*isUnordered=*/true, transpose.getResult().getType());
+        /*isUnordered=*/true, /*polymorphicMold=*/nullptr,
+        transpose.getResult().getType());
 
     // it wouldn't be safe to replace block arguments with a different
     // hlfir.expr type. Types can differ due to differing amounts of shape

diff --git a/flang/test/HLFIR/elemental.fir b/flang/test/HLFIR/elemental.fir
index d4cef6705b1768..174c39b99b3721 100644
--- a/flang/test/HLFIR/elemental.fir
+++ b/flang/test/HLFIR/elemental.fir
@@ -99,3 +99,45 @@ func.func @unordered() {
 // CHECK:           }
 // CHECK:           return
 // CHECK:         }
+
+func.func @polymorphic_mold_var(%arg0: !fir.class<!fir.array<?x!fir.type<_QMtypesTt>>>, %shape : index) {
+  %3 = fir.shape %shape : (index) -> !fir.shape<1>
+  %4 = hlfir.elemental %3 mold %arg0 unordered : (!fir.shape<1>, !fir.class<!fir.array<?x!fir.type<_QMtypesTt>>>) -> !hlfir.expr<?x!fir.type<_QMtypesTt>?> {
+  ^bb0(%arg2: index):
+    %6 = fir.undefined !hlfir.expr<!fir.type<_QMtypesTt>?>
+    hlfir.yield_element %6 : !hlfir.expr<!fir.type<_QMtypesTt>?>
+  }
+  return
+}
+// CHECK-LABEL:   func.func @polymorphic_mold_var(
+// CHECK-SAME:        %[[VAL_0:.*]]: !fir.class<!fir.array<?x!fir.type<_QMtypesTt>>>,                   %[[VAL_1:.*]]: index) {
+// CHECK:           %[[VAL_2:.*]] = fir.shape %[[VAL_1]] : (index) -> !fir.shape<1>
+// CHECK:           %[[VAL_3:.*]] = hlfir.elemental %[[VAL_2]] mold %[[VAL_0]] unordered : (!fir.shape<1>, !fir.class<!fir.array<?x!fir.type<_QMtypesTt>>>) -> !hlfir.expr<?x!fir.type<_QMtypesTt>?> {
+// CHECK:           ^bb0(%[[VAL_4:.*]]: index):
+// CHECK:             %[[VAL_5:.*]] = fir.undefined !hlfir.expr<!fir.type<_QMtypesTt>?>
+// CHECK:             hlfir.yield_element %[[VAL_5]] : !hlfir.expr<!fir.type<_QMtypesTt>?>
+// CHECK:           }
+// CHECK:           return
+// CHECK:         }
+
+func.func @polymorphic_mold_expr(%shape : index) {
+  %3 = fir.shape %shape : (index) -> !fir.shape<1>
+  %mold = fir.undefined !hlfir.expr<?x!fir.type<_QMtypesTt>?>
+  %4 = hlfir.elemental %3 mold %mold unordered : (!fir.shape<1>, !hlfir.expr<?x!fir.type<_QMtypesTt>?>) -> !hlfir.expr<?x!fir.type<_QMtypesTt>?> {
+  ^bb0(%arg2: index):
+    %6 = fir.undefined !hlfir.expr<!fir.type<_QMtypesTt>?>
+    hlfir.yield_element %6 : !hlfir.expr<!fir.type<_QMtypesTt>?>
+  }
+  return
+}
+// CHECK-LABEL:   func.func @polymorphic_mold_expr(
+// CHECK-SAME:        %[[VAL_0:.*]]: index) {
+// CHECK:           %[[VAL_1:.*]] = fir.shape %[[VAL_0]] : (index) -> !fir.shape<1>
+// CHECK:           %[[VAL_2:.*]] = fir.undefined !hlfir.expr<?x!fir.type<_QMtypesTt>?>
+// CHECK:           %[[VAL_3:.*]] = hlfir.elemental %[[VAL_1]] mold %[[VAL_2]] unordered : (!fir.shape<1>, !hlfir.expr<?x!fir.type<_QMtypesTt>?>) -> !hlfir.expr<?x!fir.type<_QMtypesTt>?> {
+// CHECK:           ^bb0(%[[VAL_4:.*]]: index):
+// CHECK:             %[[VAL_5:.*]] = fir.undefined !hlfir.expr<!fir.type<_QMtypesTt>?>
+// CHECK:             hlfir.yield_element %[[VAL_5]] : !hlfir.expr<!fir.type<_QMtypesTt>?>
+// CHECK:           }
+// CHECK:           return
+// CHECK:         }

diff --git a/flang/test/HLFIR/invalid.fir b/flang/test/HLFIR/invalid.fir
index bbfe543a3427eb..db16cb29c62388 100644
--- a/flang/test/HLFIR/invalid.fir
+++ b/flang/test/HLFIR/invalid.fir
@@ -961,3 +961,51 @@ func.func @bad_get_length_3(%arg0: !hlfir.expr<!fir.boxchar<1>>) {
   %1 = hlfir.get_length %arg0 : (!hlfir.expr<!fir.boxchar<1>>) -> index
   return
 }
+
+// -----
+func.func @elemental_poly_1(%arg0: !fir.box<!fir.array<?x!fir.type<_QMtypesTt>>>, %shape : index) {
+  %3 = fir.shape %shape : (index) -> !fir.shape<1>
+  // expected-error@+1 {{'hlfir.elemental' op operand #1 must be any polymorphic object, but got '!fir.box<!fir.array<?x!fir.type<_QMtypesTt>>>'}}
+  %4 = hlfir.elemental %3 mold %arg0 unordered : (!fir.shape<1>, !fir.box<!fir.array<?x!fir.type<_QMtypesTt>>>) -> !hlfir.expr<?x!fir.type<_QMtypesTt>?> {
+  ^bb0(%arg2: index):
+    %6 = fir.undefined !hlfir.expr<!fir.type<_QMtypesTt>?>
+    hlfir.yield_element %6 : !hlfir.expr<!fir.type<_QMtypesTt>?>
+  }
+  return
+}
+
+// -----
+func.func @elemental_poly_2(%arg0: !hlfir.expr<?x!fir.type<_QMtypesTt>>, %shape : index) {
+  %3 = fir.shape %shape : (index) -> !fir.shape<1>
+  // expected-error@+1 {{'hlfir.elemental' op operand #1 must be any polymorphic object, but got '!hlfir.expr<?x!fir.type<_QMtypesTt>>'}}
+  %4 = hlfir.elemental %3 mold %arg0 unordered : (!fir.shape<1>, !hlfir.expr<?x!fir.type<_QMtypesTt>>) -> !hlfir.expr<?x!fir.type<_QMtypesTt>?> {
+  ^bb0(%arg2: index):
+    %6 = fir.undefined !hlfir.expr<!fir.type<_QMtypesTt>?>
+    hlfir.yield_element %6 : !hlfir.expr<!fir.type<_QMtypesTt>?>
+  }
+  return
+}
+
+// -----
+func.func @elemental_poly_3(%arg0: !hlfir.expr<?x!fir.type<_QMtypesTt>?>, %shape : index) {
+  %3 = fir.shape %shape : (index) -> !fir.shape<1>
+// expected-error@+1 {{'hlfir.elemental' op result must be polymorphic when mold is present and vice versa}}
+  %4 = hlfir.elemental %3 mold %arg0 unordered : (!fir.shape<1>, !hlfir.expr<?x!fir.type<_QMtypesTt>?>) -> !hlfir.expr<?x!fir.type<_QMtypesTt>> {
+  ^bb0(%arg2: index):
+    %6 = fir.undefined !hlfir.expr<!fir.type<_QMtypesTt>>
+    hlfir.yield_element %6 : !hlfir.expr<!fir.type<_QMtypesTt>>
+  }
+  return
+}
+
+// -----
+func.func @elemental_poly_4(%shape : index) {
+  %3 = fir.shape %shape : (index) -> !fir.shape<1>
+// expected-error@+1 {{'hlfir.elemental' op result must be polymorphic when mold is present and vice versa}}
+  %4 = hlfir.elemental %3 unordered : (!fir.shape<1>) -> !hlfir.expr<?x!fir.type<_QMtypesTt>?> {
+  ^bb0(%arg2: index):
+    %6 = fir.undefined !hlfir.expr<!fir.type<_QMtypesTt>?>
+    hlfir.yield_element %6 : !hlfir.expr<!fir.type<_QMtypesTt>?>
+  }
+  return
+}


        


More information about the flang-commits mailing list