[flang-commits] [flang] 206b853 - [flang] add hlfir.all intrinsic

Jacob Crawley via flang-commits flang-commits at lists.llvm.org
Tue May 30 07:48:59 PDT 2023


Author: Jacob Crawley
Date: 2023-05-30T14:46:06Z
New Revision: 206b8538a6df53d5245b7524d83501e027c52418

URL: https://github.com/llvm/llvm-project/commit/206b8538a6df53d5245b7524d83501e027c52418
DIFF: https://github.com/llvm/llvm-project/commit/206b8538a6df53d5245b7524d83501e027c52418.diff

LOG: [flang] add hlfir.all intrinsic

Adds a new HLFIR operation for the ALL intrinsic according to the
design set out in flang/docs/HighLevel.md

Differential Revision: https://reviews.llvm.org/D151090

Added: 
    flang/test/HLFIR/all.fir

Modified: 
    flang/include/flang/Optimizer/HLFIR/HLFIROps.td
    flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
    flang/test/HLFIR/invalid.fir

Removed: 
    


################################################################################
diff  --git a/flang/include/flang/Optimizer/HLFIR/HLFIROps.td b/flang/include/flang/Optimizer/HLFIR/HLFIROps.td
index 15b92385a7720..142a70c639127 100644
--- a/flang/include/flang/Optimizer/HLFIR/HLFIROps.td
+++ b/flang/include/flang/Optimizer/HLFIR/HLFIROps.td
@@ -317,6 +317,27 @@ def hlfir_ConcatOp : hlfir_Op<"concat", []> {
   let hasVerifier = 1;
 }
 
+def hlfir_AllOp : hlfir_Op<"all", []> {
+  let summary = "ALL transformational intrinsic";
+  let description = [{
+    Takes a logical array MASK as argument, optionally along a particular dimension,
+    and returns true if all elements of MASK are true.
+  }];
+
+  let arguments = (ins
+    AnyFortranLogicalArrayObject:$mask,
+    Optional<AnyIntegerType>:$dim
+  );
+
+  let results = (outs AnyFortranValue);
+
+  let assemblyFormat = [{
+    $mask  (`dim` $dim^)?  attr-dict `:` functional-type(operands, results)
+  }];
+
+  let hasVerifier = 1;
+}
+
 def hlfir_AnyOp : hlfir_Op<"any", []> {
   let summary = "ANY transformational intrinsic";
   let description = [{

diff  --git a/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp b/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
index 4547c4247241e..adf8b72993e4c 100644
--- a/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
+++ b/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
@@ -442,16 +442,19 @@ mlir::LogicalResult hlfir::ParentComponentOp::verify() {
 }
 
 //===----------------------------------------------------------------------===//
-// AnyOp
+// LogicalReductionOp
 //===----------------------------------------------------------------------===//
-mlir::LogicalResult hlfir::AnyOp::verify() {
-  mlir::Operation *op = getOperation();
+template <typename LogicalReductionOp>
+static mlir::LogicalResult
+verifyLogicalReductionOp(LogicalReductionOp reductionOp) {
+  mlir::Operation *op = reductionOp->getOperation();
 
   auto results = op->getResultTypes();
   assert(results.size() == 1);
 
-  mlir::Value mask = getMask();
-  mlir::Value dim = getDim();
+  mlir::Value mask = reductionOp->getMask();
+  mlir::Value dim = reductionOp->getDim();
+
   fir::SequenceType maskTy =
       hlfir::getFortranElementOrSequenceType(mask.getType())
           .cast<fir::SequenceType>();
@@ -462,7 +465,7 @@ mlir::LogicalResult hlfir::AnyOp::verify() {
   if (mlir::isa<fir::LogicalType>(resultType)) {
     // Result is of the same type as MASK
     if (resultType != logicalTy)
-      return emitOpError(
+      return reductionOp->emitOpError(
           "result must have the same element type as MASK argument");
 
   } else if (auto resultExpr =
@@ -470,25 +473,42 @@ mlir::LogicalResult hlfir::AnyOp::verify() {
     // Result should only be in hlfir.expr form if it is an array
     if (maskShape.size() > 1 && dim != nullptr) {
       if (!resultExpr.isArray())
-        return emitOpError("result must be an array");
+        return reductionOp->emitOpError("result must be an array");
 
       if (resultExpr.getEleTy() != logicalTy)
-        return emitOpError(
+        return reductionOp->emitOpError(
             "result must have the same element type as MASK argument");
 
       llvm::ArrayRef<int64_t> resultShape = resultExpr.getShape();
       // Result has rank n-1
       if (resultShape.size() != (maskShape.size() - 1))
-        return emitOpError("result rank must be one less than MASK");
+        return reductionOp->emitOpError(
+            "result rank must be one less than MASK");
     } else {
-      return emitOpError("result must be of logical type");
+      return reductionOp->emitOpError("result must be of logical type");
     }
   } else {
-    return emitOpError("result must be of logical type");
+    return reductionOp->emitOpError("result must be of logical type");
   }
   return mlir::success();
 }
 
+//===----------------------------------------------------------------------===//
+// AllOp
+//===----------------------------------------------------------------------===//
+
+mlir::LogicalResult hlfir::AllOp::verify() {
+  return verifyLogicalReductionOp<hlfir::AllOp *>(this);
+}
+
+//===----------------------------------------------------------------------===//
+// AnyOp
+//===----------------------------------------------------------------------===//
+
+mlir::LogicalResult hlfir::AnyOp::verify() {
+  return verifyLogicalReductionOp<hlfir::AnyOp *>(this);
+}
+
 //===----------------------------------------------------------------------===//
 // ConcatOp
 //===----------------------------------------------------------------------===//
@@ -537,11 +557,12 @@ void hlfir::ConcatOp::build(mlir::OpBuilder &builder,
 }
 
 //===----------------------------------------------------------------------===//
-// ReductionOp
+// NumericalReductionOp
 //===----------------------------------------------------------------------===//
 
-template <typename ReductionOp>
-static mlir::LogicalResult verifyReductionOp(ReductionOp reductionOp) {
+template <typename NumericalReductionOp>
+static mlir::LogicalResult
+verifyNumericalReductionOp(NumericalReductionOp reductionOp) {
   mlir::Operation *op = reductionOp->getOperation();
 
   auto results = op->getResultTypes();
@@ -619,7 +640,7 @@ static mlir::LogicalResult verifyReductionOp(ReductionOp reductionOp) {
 //===----------------------------------------------------------------------===//
 
 mlir::LogicalResult hlfir::ProductOp::verify() {
-  return verifyReductionOp<hlfir::ProductOp *>(this);
+  return verifyNumericalReductionOp<hlfir::ProductOp *>(this);
 }
 
 //===----------------------------------------------------------------------===//
@@ -645,7 +666,7 @@ void hlfir::SetLengthOp::build(mlir::OpBuilder &builder,
 //===----------------------------------------------------------------------===//
 
 mlir::LogicalResult hlfir::SumOp::verify() {
-  return verifyReductionOp<hlfir::SumOp *>(this);
+  return verifyNumericalReductionOp<hlfir::SumOp *>(this);
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/flang/test/HLFIR/all.fir b/flang/test/HLFIR/all.fir
new file mode 100644
index 0000000000000..00ce1b3a5fbae
--- /dev/null
+++ b/flang/test/HLFIR/all.fir
@@ -0,0 +1,113 @@
+// Test hlfir.all operation parse, verify (no errors), and unparse
+
+// RUN: fir-opt %s | fir-opt | FileCheck %s
+
+// mask is an expression of known shape
+func.func @all0(%arg0: !hlfir.expr<2x!fir.logical<4>>) {
+  %all = hlfir.all %arg0 : (!hlfir.expr<2x!fir.logical<4>>) -> !fir.logical<4>
+  return
+}
+// CHECK:      func.func @all0(%[[ARRAY:.*]]: !hlfir.expr<2x!fir.logical<4>>) {
+// CHECK-NEXT:   %[[ALL:.*]] = hlfir.all %[[ARRAY]] : (!hlfir.expr<2x!fir.logical<4>>) -> !fir.logical<4>
+// CHECK-NEXT:   return
+// CHECK-NEXT: }
+
+// mask is an expression of assumed shape
+func.func @all1(%arg0: !hlfir.expr<?x!fir.logical<4>>) {
+  %all = hlfir.all %arg0 : (!hlfir.expr<?x!fir.logical<4>>) -> !fir.logical<4>
+  return
+}
+// CHECK:      func.func @all1(%[[ARRAY:.*]]: !hlfir.expr<?x!fir.logical<4>>) {
+// CHECK-NEXT:   %[[ALL:.*]] = hlfir.all %[[ARRAY]] : (!hlfir.expr<?x!fir.logical<4>>) -> !fir.logical<4>
+// CHECK-NEXT:   return
+// CHECK-NEXT: }
+
+// mask is a boxed array
+func.func @all2(%arg0: !fir.box<!fir.array<2x!fir.logical<4>>>) {
+  %all = hlfir.all %arg0 : (!fir.box<!fir.array<2x!fir.logical<4>>>) -> !fir.logical<4>
+  return
+}
+// CHECK:      func.func @all2(%[[ARRAY:.*]]: !fir.box<!fir.array<2x!fir.logical<4>>>) {
+// CHECK-NEXT:   %[[ALL:.*]] = hlfir.all %[[ARRAY]] : (!fir.box<!fir.array<2x!fir.logical<4>>>) -> !fir.logical<4>
+// CHECK-NEXT:   return
+// CHECK-NEXT: }
+
+// mask is an assumed shape boxed array
+func.func @all3(%arg0: !fir.box<!fir.array<?x!fir.logical<4>>>){
+  %all = hlfir.all %arg0 : (!fir.box<!fir.array<?x!fir.logical<4>>>) -> !fir.logical<4>
+  return
+}
+// CHECK:      func.func @all3(%[[ARRAY:.*]]: !fir.box<!fir.array<?x!fir.logical<4>>>) {
+// CHECK-NEXT:   %[[ALL:.*]] = hlfir.all %[[ARRAY]] : (!fir.box<!fir.array<?x!fir.logical<4>>>) -> !fir.logical<4>
+// CHECK-NEXT:   return
+// CHECK-NEXT: }
+
+// mask is a 2-dimensional array
+func.func @all4(%arg0: !fir.box<!fir.array<?x?x!fir.logical<4>>>){
+  %all = hlfir.all %arg0 : (!fir.box<!fir.array<?x?x!fir.logical<4>>>) -> !fir.logical<4>
+  return
+}
+// CHECK:      func.func @all4(%[[ARRAY:.*]]: !fir.box<!fir.array<?x?x!fir.logical<4>>>) {
+// CHECK-NEXT:   %[[ALL:.*]] = hlfir.all %[[ARRAY]] : (!fir.box<!fir.array<?x?x!fir.logical<4>>>) -> !fir.logical<4>
+// CHECK-NEXT:   return
+// CHECK-NEXT: }
+
+// mask and dim argument
+func.func @all5(%arg0: !fir.box<!fir.array<?x!fir.logical<4>>>, %arg1: i32) {
+  %all = hlfir.all %arg0 dim %arg1 : (!fir.box<!fir.array<?x!fir.logical<4>>>, i32) -> !fir.logical<4>
+  return
+}
+// CHECK:      func.func @all5(%[[ARRAY:.*]]: !fir.box<!fir.array<?x!fir.logical<4>>>, %[[DIM:.*]]: i32) {
+// CHECK-NEXT:   %[[ALL:.*]] = hlfir.all %[[ARRAY]] dim %[[DIM]] : (!fir.box<!fir.array<?x!fir.logical<4>>>, i32) -> !fir.logical<4>
+// CHECK-NEXT:   return
+// CHECK-NEXT: }
+
+// hlfir.all with dim argument with an unusual type
+func.func @all6(%arg0: !fir.box<!fir.array<?x!fir.logical<4>>>, %arg1: index) {
+  %all = hlfir.all %arg0 dim %arg1 : (!fir.box<!fir.array<?x!fir.logical<4>>>, index) ->!fir.logical<4>
+  return
+}
+// CHECK:      func.func @all6(%[[ARRAY:.*]]: !fir.box<!fir.array<?x!fir.logical<4>>>, %[[DIM:.*]]: index) {
+// CHECK-NEXT:   %[[ALL:.*]] = hlfir.all %[[ARRAY]] dim %[[DIM]] : (!fir.box<!fir.array<?x!fir.logical<4>>>, index) -> !fir.logical<4>
+// CHECK-NEXT:   return
+// CHECK-NEXT: }
+
+// mask is a 2 dimensional array with dim
+func.func @all7(%arg0: !fir.box<!fir.array<?x?x!fir.logical<4>>>, %arg1: i32) {
+  %all = hlfir.all %arg0 dim %arg1 : (!fir.box<!fir.array<?x?x!fir.logical<4>>>, i32) -> !hlfir.expr<?x!fir.logical<4>>
+  return
+}
+// CHECK:      func.func @all7(%[[ARRAY:.*]]: !fir.box<!fir.array<?x?x!fir.logical<4>>>, %[[DIM:.*]]: i32) {
+// CHECK-NEXT:   %[[ALL:.*]] = hlfir.all %[[ARRAY]] dim %[[DIM]] : (!fir.box<!fir.array<?x?x!fir.logical<4>>>, i32) -> !hlfir.expr<?x!fir.logical<4>>
+// CHECK-NEXT:   return
+// CHECK-NEXT: }
+
+// known shape expr return
+func.func @all8(%arg0: !fir.box<!fir.array<2x2x!fir.logical<4>>>, %arg1: i32) {
+  %all = hlfir.all %arg0 dim %arg1 : (!fir.box<!fir.array<2x2x!fir.logical<4>>>, i32) -> !hlfir.expr<2x!fir.logical<4>>
+  return
+}
+// CHECK:      func.func @all8(%[[ARRAY:.*]]: !fir.box<!fir.array<2x2x!fir.logical<4>>>, %[[DIM:.*]]: i32) {
+// CHECK-NEXT:   %[[ALL:.*]] = hlfir.all %[[ARRAY]] dim %[[DIM]] : (!fir.box<!fir.array<2x2x!fir.logical<4>>>, i32) -> !hlfir.expr<2x!fir.logical<4>>
+// CHECK-NEXT:   return
+// CHECK-NEXT: }
+
+// hlfir.all with mask argument of ref<array<>> type
+func.func @all9(%arg0: !fir.ref<!fir.array<?x!fir.logical<4>>>) {
+  %all = hlfir.all %arg0 : (!fir.ref<!fir.array<?x!fir.logical<4>>>) -> !fir.logical<4>
+  return
+}
+// CHECK:      func.func @all9(%[[ARRAY:.*]]: !fir.ref<!fir.array<?x!fir.logical<4>>>) {
+// CHECK-NEXT:   %[[ALL:.*]] = hlfir.all %[[ARRAY]] : (!fir.ref<!fir.array<?x!fir.logical<4>>>) -> !fir.logical<4>
+// CHECK-NEXT:   return
+// CHECK-NEXT: }
+
+// hlfir.all with fir.logical<8> type
+func.func @all10(%arg0: !fir.box<!fir.array<?x!fir.logical<8>>>) {
+  %all = hlfir.all %arg0 : (!fir.box<!fir.array<?x!fir.logical<8>>>) -> !fir.logical<8>
+  return
+}
+// CHECK:      func.func @all10(%[[ARRAY:.*]]: !fir.box<!fir.array<?x!fir.logical<8>>>) {
+// CHECK-NEXT:   %[[ALL:.*]] = hlfir.all %[[ARRAY]] : (!fir.box<!fir.array<?x!fir.logical<8>>>) -> !fir.logical<8>
+// CHECK-NEXT:   return
+// CHECK-NEXT: }
\ No newline at end of file

diff  --git a/flang/test/HLFIR/invalid.fir b/flang/test/HLFIR/invalid.fir
index e1c95c1046dc4..8dc5679346bc1 100644
--- a/flang/test/HLFIR/invalid.fir
+++ b/flang/test/HLFIR/invalid.fir
@@ -332,6 +332,42 @@ func.func @bad_any6(%arg0: !hlfir.expr<?x!fir.logical<4>>) {
   %0 = hlfir.any %arg0 : (!hlfir.expr<?x!fir.logical<4>>) -> !hlfir.expr<!fir.logical<4>>
 }
 
+// -----
+func.func @bad_all1(%arg0: !hlfir.expr<?x!fir.logical<4>>) {
+  // expected-error@+1 {{'hlfir.all' op result must have the same element type as MASK argument}}
+  %0 = hlfir.all %arg0 : (!hlfir.expr<?x!fir.logical<4>>) -> !fir.logical<8>
+}
+
+// -----
+func.func @bad_all2(%arg0: !hlfir.expr<?x?x!fir.logical<4>>, %arg1: i32) {
+  // expected-error@+1 {{'hlfir.all' op result must have the same element type as MASK argument}}
+  %0 = hlfir.all %arg0 dim %arg1 : (!hlfir.expr<?x?x!fir.logical<4>>, i32) -> !hlfir.expr<?x!fir.logical<8>>
+}
+
+// -----
+func.func @bad_all3(%arg0: !hlfir.expr<?x?x!fir.logical<4>>, %arg1: i32){
+  // expected-error@+1 {{'hlfir.all' op result rank must be one less than MASK}}
+  %0 = hlfir.all %arg0 dim %arg1 : (!hlfir.expr<?x?x!fir.logical<4>>, i32) -> !hlfir.expr<?x?x!fir.logical<4>>
+}
+
+// -----
+func.func @bad_all4(%arg0: !hlfir.expr<?x?x!fir.logical<4>>, %arg1: i32) {
+  // expected-error@+1 {{'hlfir.all' op result must be an array}}
+  %0 = hlfir.all %arg0 dim %arg1 : (!hlfir.expr<?x?x!fir.logical<4>>, i32) -> !hlfir.expr<!fir.logical<4>>
+}
+
+// -----
+func.func @bad_all5(%arg0: !hlfir.expr<?x!fir.logical<4>>) {
+  // expected-error@+1 {{'hlfir.all' op result must be of logical type}}
+  %0 = hlfir.all %arg0 : (!hlfir.expr<?x!fir.logical<4>>) -> i32
+}
+
+// -----
+func.func @bad_all6(%arg0: !hlfir.expr<?x!fir.logical<4>>) {
+  // expected-error@+1 {{'hlfir.all' op result must be of logical type}}
+  %0 = hlfir.all %arg0 : (!hlfir.expr<?x!fir.logical<4>>) -> !hlfir.expr<!fir.logical<4>>
+}
+
 // -----
 func.func @bad_product1(%arg0: !hlfir.expr<?xi32>, %arg1: i32, %arg2: !fir.box<!fir.logical<4>>) {
   // expected-error@+1 {{'hlfir.product' op result must have the same element type as ARRAY argument}}


        


More information about the flang-commits mailing list