[flang-commits] [flang] 3ad2606 - [flang] add hlfir.sum operation

Tom Eccles via flang-commits flang-commits at lists.llvm.org
Mon Feb 13 02:52:14 PST 2023


Author: Tom Eccles
Date: 2023-02-13T10:50:11Z
New Revision: 3ad26060e4bceb2cf9f4959d659cbb29d88344cf

URL: https://github.com/llvm/llvm-project/commit/3ad26060e4bceb2cf9f4959d659cbb29d88344cf
DIFF: https://github.com/llvm/llvm-project/commit/3ad26060e4bceb2cf9f4959d659cbb29d88344cf.diff

LOG: [flang] add hlfir.sum operation

Add an HLFIR operation for the SUM transformational intrinsic, according
to the design set out in flang/docs/HighLevelFIR.md.

I decided to make hlfir.sum very lenient about the form of its
arguments. This allows the sum intrinsic to be lowered to only this HLFIR
operation, without needing several operations to convert and box
arguments. Having only one operation generated for the intrinsic
invocation should make optimisation passes on HLFIR simpler.

Differential Revision: https://reviews.llvm.org/D142897

Added: 
    flang/test/HLFIR/sum.fir

Modified: 
    flang/include/flang/Optimizer/Builder/HLFIRTools.h
    flang/include/flang/Optimizer/HLFIR/HLFIRDialect.h
    flang/include/flang/Optimizer/HLFIR/HLFIROpBase.td
    flang/include/flang/Optimizer/HLFIR/HLFIROps.h
    flang/include/flang/Optimizer/HLFIR/HLFIROps.td
    flang/lib/Optimizer/Builder/HLFIRTools.cpp
    flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp
    flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
    flang/test/HLFIR/invalid.fir

Removed: 
    


################################################################################
diff  --git a/flang/include/flang/Optimizer/Builder/HLFIRTools.h b/flang/include/flang/Optimizer/Builder/HLFIRTools.h
index e610478c34994..389f5ce1c3f25 100644
--- a/flang/include/flang/Optimizer/Builder/HLFIRTools.h
+++ b/flang/include/flang/Optimizer/Builder/HLFIRTools.h
@@ -373,6 +373,7 @@ convertToAddress(mlir::Location loc, fir::FirOpBuilder &builder,
 std::pair<fir::ExtendedValue, std::optional<hlfir::CleanupFunction>>
 convertToBox(mlir::Location loc, fir::FirOpBuilder &builder,
              const hlfir::Entity &entity, mlir::Type targetType);
+
 } // namespace hlfir
 
 #endif // FORTRAN_OPTIMIZER_BUILDER_HLFIRTOOLS_H

diff  --git a/flang/include/flang/Optimizer/HLFIR/HLFIRDialect.h b/flang/include/flang/Optimizer/HLFIR/HLFIRDialect.h
index 8aace6b8ffb06..6a9acb443f9d8 100644
--- a/flang/include/flang/Optimizer/HLFIR/HLFIRDialect.h
+++ b/flang/include/flang/Optimizer/HLFIR/HLFIRDialect.h
@@ -69,6 +69,21 @@ inline bool isBoxAddressOrValueType(mlir::Type type) {
   return fir::unwrapRefType(type).isa<fir::BaseBoxType>();
 }
 
+/// Is this a scalar integer, real, or complex type (as recognized by
+/// fir::isa_integer, fir::isa_real, or fir::isa_complex)?
+bool isFortranScalarNumericalType(mlir::Type);
+/// Is this an array-like type (other than a box address) whose element type
+/// satisfies isFortranScalarNumericalType?
+bool isFortranNumericalArrayObject(mlir::Type);
+/// Is this an integer type once any pass-by-reference wrapper is removed
+/// (see fir::unwrapPassByRefType)?
+bool isPassByRefOrIntegerType(mlir::Type);
+/// Is this a builtin integer type with a width of one bit (i1)?
+bool isI1Type(mlir::Type);
+/// Is this a valid MASK argument: a scalar i1 or logical, or a sequence of
+/// logicals (via (boxed?) array or expr)? Box addresses are rejected.
+bool isMaskArgument(mlir::Type);
+
 } // namespace hlfir
 
 #endif // FORTRAN_OPTIMIZER_HLFIR_HLFIRDIALECT_H

diff  --git a/flang/include/flang/Optimizer/HLFIR/HLFIROpBase.td b/flang/include/flang/Optimizer/HLFIR/HLFIROpBase.td
index d17a4cbf5e1b6..23ad2eda36732 100644
--- a/flang/include/flang/Optimizer/HLFIR/HLFIROpBase.td
+++ b/flang/include/flang/Optimizer/HLFIR/HLFIROpBase.td
@@ -107,4 +107,25 @@ def IsFortranScalarCharacterExprPred
 def AnyScalarCharacterExpr : Type<IsFortranScalarCharacterExprPred,
     "any character scalar expression type">;
 
+// Array-like object with a numerical element type
+// (see hlfir::isFortranNumericalArrayObject).
+def IsFortranNumericalArrayObjectPred
+        : CPred<"::hlfir::isFortranNumericalArrayObject($_self)">;
+def AnyFortranNumericalArrayObject : Type<IsFortranNumericalArrayObjectPred,
+    "any array-like object containing a numerical type">;
+
+// Integer type, either by value or behind a pass-by-reference type
+// (see hlfir::isPassByRefOrIntegerType).
+def IsPassByRefOrIntegerTypePred
+        : CPred<"::hlfir::isPassByRefOrIntegerType($_self)">;
+def AnyPassByRefOrIntegerType : Type<IsPassByRefOrIntegerTypePred,
+    "an integer type either by value or by reference">;
+
+// MASK argument: a scalar i1 or logical, or an array-like object of logicals
+// (see hlfir::isMaskArgument).
+def IsMaskArgumentPred
+        : CPred<"::hlfir::isMaskArgument($_self)">;
+def AnyFortranLogicalOrI1ArrayObject : Type<IsMaskArgumentPred,
+    "a scalar i1 or logical or an array-like object containing logicals">;
+
 #endif // FORTRAN_DIALECT_HLFIR_OP_BASE

diff  --git a/flang/include/flang/Optimizer/HLFIR/HLFIROps.h b/flang/include/flang/Optimizer/HLFIR/HLFIROps.h
index 33530d12e6173..e0e718346c115 100644
--- a/flang/include/flang/Optimizer/HLFIR/HLFIROps.h
+++ b/flang/include/flang/Optimizer/HLFIR/HLFIROps.h
@@ -13,6 +13,7 @@
 #include "flang/Optimizer/Dialect/FIRType.h"
 #include "flang/Optimizer/Dialect/FortranVariableInterface.h"
 #include "flang/Optimizer/HLFIR/HLFIRDialect.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 

diff  --git a/flang/include/flang/Optimizer/HLFIR/HLFIROps.td b/flang/include/flang/Optimizer/HLFIR/HLFIROps.td
index 3228685122834..1321eab11041e 100644
--- a/flang/include/flang/Optimizer/HLFIR/HLFIROps.td
+++ b/flang/include/flang/Optimizer/HLFIR/HLFIROps.td
@@ -18,6 +18,8 @@ include "flang/Optimizer/HLFIR/HLFIROpBase.td"
 include "flang/Optimizer/Dialect/FIRTypes.td"
 include "flang/Optimizer/Dialect/FIRAttr.td"
 include "flang/Optimizer/Dialect/FortranVariableInterface.td"
+include "mlir/Dialect/Arith/IR/ArithBase.td"
+include "mlir/Dialect/Arith/IR/ArithOpsInterfaces.td"
 include "mlir/IR/BuiltinAttributes.td"
 
 // Base class for FIR operations.
@@ -256,6 +258,42 @@ def hlfir_SetLengthOp : hlfir_Op<"set_length", []> {
   let builders = [OpBuilder<(ins "mlir::Value":$string,"mlir::Value":$len)>];
 }
 
+def hlfir_SumOp : hlfir_Op<"sum", [AttrSizedOperandSegments,
+    DeclareOpInterfaceMethods<ArithFastMathInterface>]> {
+  let summary = "SUM transformational intrinsic";
+  let description = [{
+    Sums the elements of an array, optionally along a particular dimension,
+    optionally if a mask is true.
+
+    ARRAY is an array-like object (for instance an hlfir.expr or a boxed
+    array) with a numerical element type. DIM, when present, is an integer
+    or index value. MASK, when present, is a scalar (logical or i1) or an
+    array-like object of logicals that should be conformable with ARRAY.
+
+    The result has the element type of ARRAY. As for the Fortran SUM
+    intrinsic, it is a scalar when DIM is absent, and an array of rank one
+    less than ARRAY when DIM is present and ARRAY has rank greater than one.
+
+    The fastmath attribute holds arith fast-math flags; it defaults to none.
+  }];
+
+  let arguments = (ins
+    AnyFortranNumericalArrayObject:$array,
+    Optional<AnyIntegerType>:$dim,
+    Optional<AnyFortranLogicalOrI1ArrayObject>:$mask,
+    DefaultValuedAttr<Arith_FastMathAttr,
+                      "::mlir::arith::FastMathFlags::none">:$fastmath
+  );
+
+  let results = (outs hlfir_ExprType);
+
+  let assemblyFormat = [{
+    $array (`dim` $dim^)? (`mask` $mask^)? attr-dict `:` functional-type(operands, results)
+  }];
+
+  let hasVerifier = 1;
+}
+
 def hlfir_AssociateOp : hlfir_Op<"associate", [AttrSizedOperandSegments,
     DeclareOpInterfaceMethods<fir_FortranVariableOpInterface>]> {
   let summary = "Create a variable from an expression value";

diff  --git a/flang/lib/Optimizer/Builder/HLFIRTools.cpp b/flang/lib/Optimizer/Builder/HLFIRTools.cpp
index 41cc800ac182c..072fb5c0fc27e 100644
--- a/flang/lib/Optimizer/Builder/HLFIRTools.cpp
+++ b/flang/lib/Optimizer/Builder/HLFIRTools.cpp
@@ -17,6 +17,8 @@
 #include "flang/Optimizer/Builder/Todo.h"
 #include "flang/Optimizer/HLFIR/HLFIROps.h"
 #include "mlir/IR/IRMapping.h"
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/TypeSwitch.h"
 #include <optional>
 
 // Return explicit extents. If the base is a fir.box, this won't read it to

diff  --git a/flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp b/flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp
index fbb7d77e25ca4..f23be5de3be14 100644
--- a/flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp
+++ b/flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp
@@ -11,6 +11,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "flang/Optimizer/HLFIR/HLFIRDialect.h"
+#include "flang/Optimizer/Dialect/FIRType.h"
 #include "flang/Optimizer/HLFIR/HLFIROps.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypes.h"
@@ -99,3 +100,43 @@ bool hlfir::isFortranScalarCharacterExprType(mlir::Type type) {
            exprType.getElementType().isa<fir::CharacterType>();
   return false;
 }
+
+bool hlfir::isFortranScalarNumericalType(mlir::Type type) {
+  // INTEGER, REAL and COMPLEX are the numerical type categories.
+  if (fir::isa_integer(type) || fir::isa_real(type))
+    return true;
+  return fir::isa_complex(type);
+}
+
+bool hlfir::isFortranNumericalArrayObject(mlir::Type type) {
+  // Box addresses are rejected outright.
+  if (isBoxAddressType(type))
+    return false;
+  auto seqTy =
+      getFortranElementOrSequenceType(type).dyn_cast<fir::SequenceType>();
+  return seqTy && isFortranScalarNumericalType(seqTy.getEleTy());
+}
+
+bool hlfir::isPassByRefOrIntegerType(mlir::Type type) {
+  // Strip any pass-by-reference wrapper, then require an integer type.
+  return fir::isa_integer(fir::unwrapPassByRefType(type));
+}
+
+bool hlfir::isI1Type(mlir::Type type) {
+  auto intTy = type.dyn_cast<mlir::IntegerType>();
+  return intTy && intTy.getWidth() == 1;
+}
+
+bool hlfir::isMaskArgument(mlir::Type type) {
+  if (isBoxAddressType(type))
+    return false;
+
+  mlir::Type baseTy = fir::unwrapPassByRefType(fir::unwrapRefType(type));
+  mlir::Type eleTy = getFortranElementType(baseTy);
+  // LOGICAL elements are accepted for scalar and array-like masks alike.
+  if (mlir::isa<fir::LogicalType>(eleTy))
+    return true;
+  // A plain i1 is only accepted when the mask is not array-like, i.e. when
+  // taking the element type left the unwrapped type unchanged.
+  return baseTy == eleTy && isI1Type(eleTy);
+}

diff  --git a/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp b/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
index 9bf5601ce6523..feea448b3d4c9 100644
--- a/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
+++ b/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
@@ -11,7 +11,10 @@
 //===----------------------------------------------------------------------===//
 
 #include "flang/Optimizer/HLFIR/HLFIROps.h"
+#include "flang/Optimizer/Builder/HLFIRTools.h"
 #include "flang/Optimizer/Dialect/FIROpsSupport.h"
+#include "flang/Optimizer/Dialect/FIRType.h"
+#include "flang/Optimizer/HLFIR/HLFIRDialect.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/DialectImplementation.h"
@@ -420,6 +423,79 @@ void hlfir::SetLengthOp::build(mlir::OpBuilder &builder,
   build(builder, result, resultType, string, len);
 }
 
+//===----------------------------------------------------------------------===//
+// SumOp
+//===----------------------------------------------------------------------===//
+
+/// Verify hlfir.sum: warn when an array-like MASK is not conformable with
+/// ARRAY; require the result to have ARRAY's element type; require DIM and
+/// rank one less than ARRAY for an array result.
+mlir::LogicalResult hlfir::SumOp::verify() {
+  mlir::Operation *op = getOperation();
+
+  auto results = op->getResultTypes();
+  assert(results.size() == 1);
+
+  mlir::Value array = getArray();
+  mlir::Value dim = getDim();
+  mlir::Value mask = getMask();
+
+  // The ODS constraint on ARRAY guarantees a sequence type here.
+  fir::SequenceType arrayTy =
+      hlfir::getFortranElementOrSequenceType(array.getType())
+          .cast<fir::SequenceType>();
+  mlir::Type numTy = arrayTy.getEleTy();
+  llvm::ArrayRef<int64_t> arrayShape = arrayTy.getShape();
+  hlfir::ExprType resultTy = results[0].cast<hlfir::ExprType>();
+
+  if (mask) {
+    fir::SequenceType maskSeq =
+        hlfir::getFortranElementOrSequenceType(mask.getType())
+            .dyn_cast<fir::SequenceType>();
+    llvm::ArrayRef<int64_t> maskShape;
+
+    if (maskSeq)
+      maskShape = maskSeq.getShape();
+
+    // A scalar MASK is conformable with any ARRAY; an array MASK must have
+    // the same rank and, where both extents are known, the same extents.
+    // NOTE(review): returning the warning diagnostic still converts to a
+    // failed verification; confirm whether an error is intended instead.
+    if (!maskShape.empty()) {
+      if (maskShape.size() != arrayShape.size())
+        return emitWarning("MASK must be conformable to ARRAY");
+      static_assert(fir::SequenceType::getUnknownExtent() ==
+                    hlfir::ExprType::getUnknownExtent());
+      constexpr int64_t unknownExtent = fir::SequenceType::getUnknownExtent();
+      for (std::size_t i = 0; i < arrayShape.size(); ++i) {
+        int64_t arrayExtent = arrayShape[i];
+        int64_t maskExtent = maskShape[i];
+        if ((arrayExtent != maskExtent) && (arrayExtent != unknownExtent) &&
+            (maskExtent != unknownExtent))
+          return emitWarning("MASK must be conformable to ARRAY");
+      }
+    }
+  }
+
+  // Scalar or array, the result has the same element type as ARRAY.
+  if (resultTy.getElementType() != numTy)
+    return emitOpError(
+        "result must have the same element type as ARRAY argument");
+
+  if (resultTy.isArray()) {
+    // Without DIM, SUM reduces the whole ARRAY to a scalar.
+    if (!dim)
+      return emitOpError("result is an array but DIM is not provided");
+
+    // Reducing along DIM removes exactly one dimension.
+    llvm::ArrayRef<int64_t> resultShape = resultTy.getShape();
+    if (resultShape.size() != (arrayShape.size() - 1))
+      return emitOpError("result rank must be one less than ARRAY");
+  }
+
+  return mlir::success();
+}
+
 //===----------------------------------------------------------------------===//
 // AssociateOp
 //===----------------------------------------------------------------------===//

diff  --git a/flang/test/HLFIR/invalid.fir b/flang/test/HLFIR/invalid.fir
index a7a3150a1534f..dd801634c6f7d 100644
--- a/flang/test/HLFIR/invalid.fir
+++ b/flang/test/HLFIR/invalid.fir
@@ -295,3 +295,33 @@ func.func @bad_concat_4(%arg0: !fir.ref<!fir.char<1,30>>) {
   %0 = hlfir.concat %arg0 len %c30 : (!fir.ref<!fir.char<1,30>>, index) -> (!hlfir.expr<!fir.char<1,30>>)
   return
 }
+
+// -----
+func.func @bad_sum1(%arg0: !hlfir.expr<?xi32>, %arg1: i32, %arg2: !fir.box<!fir.logical<4>>) {
+  // expected-error@+1 {{'hlfir.sum' op result must have the same element type as ARRAY argument}}
+  %0 = hlfir.sum %arg0 dim %arg1 mask %arg2 : (!hlfir.expr<?xi32>, i32, !fir.box<!fir.logical<4>>) -> !hlfir.expr<f32>
+}
+
+// -----
+func.func @bad_sum2(%arg0: !hlfir.expr<?xi32>, %arg1: i32, %arg2: !fir.box<!fir.array<?x?x?x?x?x!fir.logical<4>>>) {
+  // expected-warning@+1 {{MASK must be conformable to ARRAY}}
+  %0 = hlfir.sum %arg0 dim %arg1 mask %arg2 : (!hlfir.expr<?xi32>, i32, !fir.box<!fir.array<?x?x?x?x?x!fir.logical<4>>>) -> !hlfir.expr<i32>
+}
+
+// -----
+func.func @bad_sum3(%arg0: !hlfir.expr<?x5x?xi32>, %arg1: i32, %arg2: !fir.box<!fir.array<2x6x?x!fir.logical<4>>>) {
+  // expected-warning@+1 {{MASK must be conformable to ARRAY}}
+  %0 = hlfir.sum %arg0 dim %arg1 mask %arg2 : (!hlfir.expr<?x5x?xi32>, i32, !fir.box<!fir.array<2x6x?x!fir.logical<4>>>) -> !hlfir.expr<i32>
+}
+
+// -----
+func.func @bad_sum4(%arg0: !hlfir.expr<?xi32>, %arg1: i32, %arg2: !fir.box<!fir.logical<4>>) {
+  // expected-error@+1 {{'hlfir.sum' op result rank must be one less than ARRAY}}
+  %0 = hlfir.sum %arg0 dim %arg1 mask %arg2 : (!hlfir.expr<?xi32>, i32, !fir.box<!fir.logical<4>>) -> !hlfir.expr<?x?xi32>
+}
+
+// -----
+func.func @bad_sum5(%arg0: !hlfir.expr<?x?xi32>) {
+  // expected-error@+1 {{'hlfir.sum' op result is an array but DIM is not provided}}
+  %0 = hlfir.sum %arg0 : (!hlfir.expr<?x?xi32>) -> !hlfir.expr<?xi32>
+}

diff  --git a/flang/test/HLFIR/sum.fir b/flang/test/HLFIR/sum.fir
new file mode 100644
index 0000000000000..45388c33091e4
--- /dev/null
+++ b/flang/test/HLFIR/sum.fir
@@ -0,0 +1,239 @@
+// Test hlfir.sum operation parse, verify (no errors), and unparse
+
+// RUN: fir-opt %s | fir-opt | FileCheck %s
+
+// ARRAY is an hlfir.expr of known shape; DIM and a scalar boxed MASK are given
+func.func @sum0(%arg0: !hlfir.expr<42xi32>) {
+  %mask = fir.alloca !fir.logical<4>
+  %c_1 = arith.constant 1 : index
+  %true = arith.constant true
+  %true_logical = fir.convert %true : (i1) -> !fir.logical<4>
+  fir.store %true_logical to %mask : !fir.ref<!fir.logical<4>>
+  %mask_box = fir.embox %mask : (!fir.ref<!fir.logical<4>>) -> !fir.box<!fir.logical<4>>
+  %0 = hlfir.sum %arg0 dim %c_1 mask %mask_box : (!hlfir.expr<42xi32>, index, !fir.box<!fir.logical<4>>) -> !hlfir.expr<i32>
+  return
+}
+// CHECK:      func.func @sum0(%[[ARRAY:.*]]: !hlfir.expr<42xi32>) {
+// CHECK-NEXT:   %[[MASK:.*]] = fir.alloca !fir.logical<4>
+// CHECK-NEXT:   %[[C1:.*]] = arith.constant 1 : index
+// CHECK-NEXT:   %[[TRUE:.*]] = arith.constant true
+// CHECK-NEXT:   %[[LOGICAL:.*]] = fir.convert %[[TRUE]] : (i1) -> !fir.logical<4>
+// CHECK-NEXT:   fir.store %[[LOGICAL]] to %[[MASK]] : !fir.ref<!fir.logical<4>>
+// CHECK-NEXT:   %[[BOX:.*]] = fir.embox %[[MASK]] : (!fir.ref<!fir.logical<4>>) -> !fir.box<!fir.logical<4>>
+// CHECK-NEXT:   hlfir.sum %[[ARRAY]] dim %[[C1]] mask %[[BOX]] : (!hlfir.expr<42xi32>, index, !fir.box<!fir.logical<4>>) -> !hlfir.expr<i32>
+// CHECK-NEXT:   return
+// CHECK-NEXT: }
+
+// ARRAY is an hlfir.expr with an unknown extent
+func.func @sum1(%arg0: !hlfir.expr<?xi32>) {
+  %mask = fir.alloca !fir.logical<4>
+  %c_1 = arith.constant 1 : index
+  %true = arith.constant true
+  %true_logical = fir.convert %true : (i1) -> !fir.logical<4>
+  fir.store %true_logical to %mask : !fir.ref<!fir.logical<4>>
+  %mask_box = fir.embox %mask : (!fir.ref<!fir.logical<4>>) -> !fir.box<!fir.logical<4>>
+  %0 = hlfir.sum %arg0 dim %c_1 mask %mask_box : (!hlfir.expr<?xi32>, index, !fir.box<!fir.logical<4>>) -> !hlfir.expr<i32>
+  return
+}
+// CHECK:      func.func @sum1(%[[ARRAY:.*]]: !hlfir.expr<?xi32>) {
+// CHECK-NEXT:   %[[MASK:.*]] = fir.alloca !fir.logical<4>
+// CHECK-NEXT:   %[[C1:.*]] = arith.constant 1 : index
+// CHECK-NEXT:   %[[TRUE:.*]] = arith.constant true
+// CHECK-NEXT:   %[[LOGICAL:.*]] = fir.convert %[[TRUE]] : (i1) -> !fir.logical<4>
+// CHECK-NEXT:   fir.store %[[LOGICAL]] to %[[MASK]] : !fir.ref<!fir.logical<4>>
+// CHECK-NEXT:   %[[BOX:.*]] = fir.embox %[[MASK]] : (!fir.ref<!fir.logical<4>>) -> !fir.box<!fir.logical<4>>
+// CHECK-NEXT:   hlfir.sum %[[ARRAY]] dim %[[C1]] mask %[[BOX]] : (!hlfir.expr<?xi32>, index, !fir.box<!fir.logical<4>>) -> !hlfir.expr<i32>
+// CHECK-NEXT:   return
+// CHECK-NEXT: }
+
+// ARRAY is a boxed array of known shape
+func.func @sum2(%arg0: !fir.box<!fir.array<42xi32>>) {
+  %mask = fir.alloca !fir.logical<4>
+  %c_1 = arith.constant 1 : index
+  %true = arith.constant true
+  %true_logical = fir.convert %true : (i1) -> !fir.logical<4>
+  fir.store %true_logical to %mask : !fir.ref<!fir.logical<4>>
+  %mask_box = fir.embox %mask : (!fir.ref<!fir.logical<4>>) -> !fir.box<!fir.logical<4>>
+  %0 = hlfir.sum %arg0 dim %c_1 mask %mask_box : (!fir.box<!fir.array<42xi32>>, index, !fir.box<!fir.logical<4>>) -> !hlfir.expr<i32>
+  return
+}
+// CHECK:      func.func @sum2(%[[ARRAY:.*]]: !fir.box<!fir.array<42xi32>>) {
+// CHECK-NEXT:   %[[MASK:.*]] = fir.alloca !fir.logical<4>
+// CHECK-NEXT:   %[[C1:.*]] = arith.constant 1 : index
+// CHECK-NEXT:   %[[TRUE:.*]] = arith.constant true
+// CHECK-NEXT:   %[[LOGICAL:.*]] = fir.convert %[[TRUE]] : (i1) -> !fir.logical<4>
+// CHECK-NEXT:   fir.store %[[LOGICAL]] to %[[MASK]] : !fir.ref<!fir.logical<4>>
+// CHECK-NEXT:   %[[BOX:.*]] = fir.embox %[[MASK]] : (!fir.ref<!fir.logical<4>>) -> !fir.box<!fir.logical<4>>
+// CHECK-NEXT:   hlfir.sum %[[ARRAY]] dim %[[C1]] mask %[[BOX]] : (!fir.box<!fir.array<42xi32>>, index, !fir.box<!fir.logical<4>>) -> !hlfir.expr<i32>
+// CHECK-NEXT:   return
+// CHECK-NEXT: }
+
+// ARRAY is a boxed array with an unknown extent
+func.func @sum3(%arg0: !fir.box<!fir.array<?xi32>>) {
+  %mask = fir.alloca !fir.logical<4>
+  %c_1 = arith.constant 1 : index
+  %true = arith.constant true
+  %true_logical = fir.convert %true : (i1) -> !fir.logical<4>
+  fir.store %true_logical to %mask : !fir.ref<!fir.logical<4>>
+  %mask_box = fir.embox %mask : (!fir.ref<!fir.logical<4>>) -> !fir.box<!fir.logical<4>>
+  %0 = hlfir.sum %arg0 dim %c_1 mask %mask_box : (!fir.box<!fir.array<?xi32>>, index, !fir.box<!fir.logical<4>>) -> !hlfir.expr<i32>
+  return
+}
+// CHECK:      func.func @sum3(%[[ARRAY:.*]]: !fir.box<!fir.array<?xi32>>) {
+// CHECK-NEXT:   %[[MASK:.*]] = fir.alloca !fir.logical<4>
+// CHECK-NEXT:   %[[C1:.*]] = arith.constant 1 : index
+// CHECK-NEXT:   %[[TRUE:.*]] = arith.constant true
+// CHECK-NEXT:   %[[LOGICAL:.*]] = fir.convert %[[TRUE]] : (i1) -> !fir.logical<4>
+// CHECK-NEXT:   fir.store %[[LOGICAL]] to %[[MASK]] : !fir.ref<!fir.logical<4>>
+// CHECK-NEXT:   %[[BOX:.*]] = fir.embox %[[MASK]] : (!fir.ref<!fir.logical<4>>) -> !fir.box<!fir.logical<4>>
+// CHECK-NEXT:   hlfir.sum %[[ARRAY]] dim %[[C1]] mask %[[BOX]] : (!fir.box<!fir.array<?xi32>>, index, !fir.box<!fir.logical<4>>) -> !hlfir.expr<i32>
+// CHECK-NEXT:   return
+// CHECK-NEXT: }
+
+// MASK is an hlfir.expr of logicals with a known shape
+func.func @sum4(%arg0: !fir.box<!fir.array<?xi32>>, %arg1: !hlfir.expr<42x!fir.logical<4>>) {
+  %c_1 = arith.constant 1 : index
+  %0 = hlfir.sum %arg0 dim %c_1 mask %arg1 : (!fir.box<!fir.array<?xi32>>, index, !hlfir.expr<42x!fir.logical<4>>) -> !hlfir.expr<i32>
+  return
+}
+// CHECK:      func.func @sum4(%[[ARRAY:.*]]: !fir.box<!fir.array<?xi32>>, %[[MASK:.*]]: !hlfir.expr<42x!fir.logical<4>>) {
+// CHECK-NEXT:   %[[C1:.*]] = arith.constant 1 : index
+// CHECK-NEXT:   hlfir.sum %[[ARRAY]] dim %[[C1]] mask %[[MASK]] : (!fir.box<!fir.array<?xi32>>, index, !hlfir.expr<42x!fir.logical<4>>) -> !hlfir.expr<i32>
+// CHECK-NEXT:   return
+// CHECK-NEXT: }
+
+// MASK is an hlfir.expr of logicals with an unknown extent
+func.func @sum5(%arg0: !fir.box<!fir.array<?xi32>>, %arg1: !hlfir.expr<?x!fir.logical<4>>) {
+  %c_1 = arith.constant 1 : index
+  %0 = hlfir.sum %arg0 dim %c_1 mask %arg1 : (!fir.box<!fir.array<?xi32>>, index, !hlfir.expr<?x!fir.logical<4>>) -> !hlfir.expr<i32>
+  return
+}
+// CHECK:      func.func @sum5(%[[ARRAY:.*]]: !fir.box<!fir.array<?xi32>>, %[[MASK:.*]]: !hlfir.expr<?x!fir.logical<4>>) {
+// CHECK-NEXT:   %[[C1:.*]] = arith.constant 1 : index
+// CHECK-NEXT:   hlfir.sum %[[ARRAY]] dim %[[C1]] mask %[[MASK]] : (!fir.box<!fir.array<?xi32>>, index, !hlfir.expr<?x!fir.logical<4>>) -> !hlfir.expr<i32>
+// CHECK-NEXT:   return
+// CHECK-NEXT: }
+
+// MASK is a boxed array of logicals with a known shape
+func.func @sum6(%arg0: !fir.box<!fir.array<?xi32>>, %arg1: !fir.box<!fir.array<42x!fir.logical<4>>>) {
+  %c_1 = arith.constant 1 : index
+  %0 = hlfir.sum %arg0 dim %c_1 mask %arg1 : (!fir.box<!fir.array<?xi32>>, index, !fir.box<!fir.array<42x!fir.logical<4>>>) -> !hlfir.expr<i32>
+  return
+}
+// CHECK:      func.func @sum6(%[[ARRAY:.*]]: !fir.box<!fir.array<?xi32>>, %[[MASK:.*]]: !fir.box<!fir.array<42x!fir.logical<4>>>) {
+// CHECK-NEXT:   %[[C1:.*]] = arith.constant 1 : index
+// CHECK-NEXT:   hlfir.sum %[[ARRAY]] dim %[[C1]] mask %[[MASK]] : (!fir.box<!fir.array<?xi32>>, index, !fir.box<!fir.array<42x!fir.logical<4>>>) -> !hlfir.expr<i32>
+// CHECK-NEXT:   return
+// CHECK-NEXT: }
+
+// MASK is a boxed array of logicals with an unknown extent
+func.func @sum7(%arg0: !fir.box<!fir.array<?xi32>>, %arg1: !fir.box<!fir.array<?x!fir.logical<4>>>) {
+  %c_1 = arith.constant 1 : index
+  %0 = hlfir.sum %arg0 dim %c_1 mask %arg1 : (!fir.box<!fir.array<?xi32>>, index, !fir.box<!fir.array<?x!fir.logical<4>>>) -> !hlfir.expr<i32>
+  return
+}
+// CHECK:      func.func @sum7(%[[ARRAY:.*]]: !fir.box<!fir.array<?xi32>>, %[[MASK:.*]]: !fir.box<!fir.array<?x!fir.logical<4>>>) {
+// CHECK-NEXT:   %[[C1:.*]] = arith.constant 1 : index
+// CHECK-NEXT:   hlfir.sum %[[ARRAY]] dim %[[C1]] mask %[[MASK]] : (!fir.box<!fir.array<?xi32>>, index, !fir.box<!fir.array<?x!fir.logical<4>>>) -> !hlfir.expr<i32>
+// CHECK-NEXT:   return
+// CHECK-NEXT: }
+
+// rank-2 boxed ARRAY of known shape with DIM: the result is a rank-1 expr of known shape
+func.func @sum8(%arg0: !fir.box<!fir.array<2x2xi32>>, %arg1: i32) {
+  %mask = fir.alloca !fir.logical<4>
+  %true = arith.constant true
+  %true_logical = fir.convert %true : (i1) -> !fir.logical<4>
+  fir.store %true_logical to %mask : !fir.ref<!fir.logical<4>>
+  %mask_box = fir.embox %mask : (!fir.ref<!fir.logical<4>>) -> !fir.box<!fir.logical<4>>
+  %0 = hlfir.sum %arg0 dim %arg1 mask %mask_box : (!fir.box<!fir.array<2x2xi32>>, i32, !fir.box<!fir.logical<4>>) -> !hlfir.expr<2xi32>
+  return
+}
+// CHECK:      func.func @sum8(%[[ARRAY:.*]]: !fir.box<!fir.array<2x2xi32>>, %[[DIM:.*]]: i32) {
+// CHECK-NEXT:   %[[MASK:.*]] = fir.alloca !fir.logical<4>
+// CHECK-NEXT:   %[[TRUE:.*]] = arith.constant true
+// CHECK-NEXT:   %[[LOGICAL:.*]] = fir.convert %[[TRUE]] : (i1) -> !fir.logical<4>
+// CHECK-NEXT:   fir.store %[[LOGICAL]] to %[[MASK]] : !fir.ref<!fir.logical<4>>
+// CHECK-NEXT:   %[[BOX:.*]] = fir.embox %[[MASK]] : (!fir.ref<!fir.logical<4>>) -> !fir.box<!fir.logical<4>>
+// CHECK-NEXT:   hlfir.sum %[[ARRAY]] dim %[[DIM]] mask %[[BOX]] : (!fir.box<!fir.array<2x2xi32>>, i32, !fir.box<!fir.logical<4>>) -> !hlfir.expr<2xi32>
+// CHECK-NEXT:   return
+// CHECK-NEXT: }
+
+// rank-2 boxed ARRAY of unknown shape with DIM: the result is a rank-1 expr with an unknown extent
+func.func @sum9(%arg0: !fir.box<!fir.array<?x?xi32>>, %arg1: i32) {
+  %mask = fir.alloca !fir.logical<4>
+  %true = arith.constant true
+  %true_logical = fir.convert %true : (i1) -> !fir.logical<4>
+  fir.store %true_logical to %mask : !fir.ref<!fir.logical<4>>
+  %mask_box = fir.embox %mask : (!fir.ref<!fir.logical<4>>) -> !fir.box<!fir.logical<4>>
+  %0 = hlfir.sum %arg0 dim %arg1 mask %mask_box : (!fir.box<!fir.array<?x?xi32>>, i32, !fir.box<!fir.logical<4>>) -> !hlfir.expr<?xi32>
+  return
+}
+// CHECK:      func.func @sum9(%[[ARRAY:.*]]: !fir.box<!fir.array<?x?xi32>>, %[[DIM:.*]]: i32) {
+// CHECK-NEXT:   %[[MASK:.*]] = fir.alloca !fir.logical<4>
+// CHECK-NEXT:   %[[TRUE:.*]] = arith.constant true
+// CHECK-NEXT:   %[[LOGICAL:.*]] = fir.convert %[[TRUE]] : (i1) -> !fir.logical<4>
+// CHECK-NEXT:   fir.store %[[LOGICAL]] to %[[MASK]] : !fir.ref<!fir.logical<4>>
+// CHECK-NEXT:   %[[BOX:.*]] = fir.embox %[[MASK]] : (!fir.ref<!fir.logical<4>>) -> !fir.box<!fir.logical<4>>
+// CHECK-NEXT:   hlfir.sum %[[ARRAY]] dim %[[DIM]] mask %[[BOX]] : (!fir.box<!fir.array<?x?xi32>>, i32, !fir.box<!fir.logical<4>>) -> !hlfir.expr<?xi32>
+// CHECK-NEXT:   return
+// CHECK-NEXT: }
+
+// hlfir.sum with only an ARRAY argument: the result is a scalar
+func.func @sum10(%arg0: !fir.box<!fir.array<?x?xi32>>) {
+  %sum = hlfir.sum %arg0 : (!fir.box<!fir.array<?x?xi32>>) -> !hlfir.expr<i32>
+  return
+}
+// CHECK:      func.func @sum10(%[[ARRAY:.*]]: !fir.box<!fir.array<?x?xi32>>
+// CHECK-NEXT:   %[[SUM:.*]] = hlfir.sum %[[ARRAY]] : (!fir.box<!fir.array<?x?xi32>>) -> !hlfir.expr<i32>
+// CHECK-NEXT:   return
+// CHECK-NEXT: }
+
+// hlfir.sum with ARRAY and an i32 DIM argument: the result is a rank-1 expr
+func.func @sum11(%arg0: !fir.box<!fir.array<?x?xi32>>, %arg1: i32) {
+  %sum = hlfir.sum %arg0 dim %arg1 : (!fir.box<!fir.array<?x?xi32>>, i32) -> !hlfir.expr<?xi32>
+  return
+}
+// CHECK:      func.func @sum11(%[[ARRAY:.*]]: !fir.box<!fir.array<?x?xi32>>, %[[DIM:.*]]: i32
+// CHECK-NEXT:   %[[SUM:.*]] = hlfir.sum %[[ARRAY]] dim %[[DIM]] : (!fir.box<!fir.array<?x?xi32>>, i32) -> !hlfir.expr<?xi32>
+// CHECK-NEXT:   return
+// CHECK-NEXT: }
+
+// hlfir.sum with ARRAY and a scalar !fir.logical MASK passed by value, no DIM
+func.func @sum12(%arg0: !fir.box<!fir.array<?xi32>>, %arg1: !fir.logical<4>) {
+  %sum = hlfir.sum %arg0 mask %arg1 : (!fir.box<!fir.array<?xi32>>, !fir.logical<4>) -> !hlfir.expr<i32>
+  return
+}
+// CHECK:      func.func @sum12(%[[ARRAY:.*]]: !fir.box<!fir.array<?xi32>>, %[[MASK:.*]]: !fir.logical<4>
+// CHECK-NEXT:   %[[SUM:.*]] = hlfir.sum %[[ARRAY]] mask %[[MASK]] : (!fir.box<!fir.array<?xi32>>, !fir.logical<4>) -> !hlfir.expr<i32>
+// CHECK-NEXT:   return
+// CHECK-NEXT: }
+
+// hlfir.sum with a DIM argument of index type rather than a plain integer
+func.func @sum13(%arg0: !fir.box<!fir.array<?x?xi32>>, %arg1: index) {
+  %sum = hlfir.sum %arg0 dim %arg1 : (!fir.box<!fir.array<?x?xi32>>, index) -> !hlfir.expr<?xi32>
+  return
+}
+// CHECK:      func.func @sum13(%[[ARRAY:.*]]: !fir.box<!fir.array<?x?xi32>>, %[[DIM:.*]]: index
+// CHECK-NEXT:   %[[SUM:.*]] = hlfir.sum %[[ARRAY]] dim %[[DIM]] : (!fir.box<!fir.array<?x?xi32>>, index) -> !hlfir.expr<?xi32>
+// CHECK-NEXT:   return
+// CHECK-NEXT: }
+
+// hlfir.sum with a scalar i1 MASK argument instead of a !fir.logical
+func.func @sum14(%arg0: !fir.box<!fir.array<?xi32>>, %arg1: i1) {
+  %sum = hlfir.sum %arg0 mask %arg1 : (!fir.box<!fir.array<?xi32>>, i1) -> !hlfir.expr<i32>
+  return
+}
+// CHECK:      func.func @sum14(%[[ARRAY:.*]]: !fir.box<!fir.array<?xi32>>, %[[MASK:.*]]: i1
+// CHECK-NEXT:   %[[SUM:.*]] = hlfir.sum %[[ARRAY]] mask %[[MASK]] : (!fir.box<!fir.array<?xi32>>, i1) -> !hlfir.expr<i32>
+// CHECK-NEXT:   return
+// CHECK-NEXT: }
+
+// hlfir.sum with a MASK passed as a reference to an array of logicals
+func.func @sum15(%arg0: !fir.box<!fir.array<?xi32>>, %arg1: !fir.ref<!fir.array<?x!fir.logical<4>>>) {
+  %sum = hlfir.sum %arg0 mask %arg1 : (!fir.box<!fir.array<?xi32>>, !fir.ref<!fir.array<?x!fir.logical<4>>>) -> !hlfir.expr<i32>
+  return
+}
+// CHECK:      func.func @sum15(%[[ARRAY:.*]]: !fir.box<!fir.array<?xi32>>, %[[MASK:.*]]: !fir.ref<!fir.array<?x!fir.logical<4>>>
+// CHECK-NEXT:   %[[SUM:.*]] = hlfir.sum %[[ARRAY]] mask %[[MASK]] : (!fir.box<!fir.array<?xi32>>, !fir.ref<!fir.array<?x!fir.logical<4>>>) -> !hlfir.expr<i32>
+// CHECK-NEXT:   return
+// CHECK-NEXT: }


        


More information about the flang-commits mailing list