[flang-commits] [flang] 206b853 - [flang] add hlfir.all intrinsic
Jacob Crawley via flang-commits
flang-commits at lists.llvm.org
Tue May 30 07:48:59 PDT 2023
Author: Jacob Crawley
Date: 2023-05-30T14:46:06Z
New Revision: 206b8538a6df53d5245b7524d83501e027c52418
URL: https://github.com/llvm/llvm-project/commit/206b8538a6df53d5245b7524d83501e027c52418
DIFF: https://github.com/llvm/llvm-project/commit/206b8538a6df53d5245b7524d83501e027c52418.diff
LOG: [flang] add hlfir.all intrinsic
Adds a new HLFIR operation for the ALL intrinsic according to the
design set out in flang/docs/HighLevel.md
Differential Revision: https://reviews.llvm.org/D151090
Added:
flang/test/HLFIR/all.fir
Modified:
flang/include/flang/Optimizer/HLFIR/HLFIROps.td
flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
flang/test/HLFIR/invalid.fir
Removed:
################################################################################
diff --git a/flang/include/flang/Optimizer/HLFIR/HLFIROps.td b/flang/include/flang/Optimizer/HLFIR/HLFIROps.td
index 15b92385a7720..142a70c639127 100644
--- a/flang/include/flang/Optimizer/HLFIR/HLFIROps.td
+++ b/flang/include/flang/Optimizer/HLFIR/HLFIROps.td
@@ -317,6 +317,27 @@ def hlfir_ConcatOp : hlfir_Op<"concat", []> {
let hasVerifier = 1;
}
+def hlfir_AllOp : hlfir_Op<"all", []> {
+ let summary = "ALL transformational intrinsic";
+ let description = [{
+ Takes a logical array MASK as argument and returns true if all elements of MASK are true.
+ If DIM is present and MASK has rank n > 1, the reduction is applied along dimension DIM and the result is a logical array of rank n-1.
+ }];
+
+ let arguments = (ins
+ AnyFortranLogicalArrayObject:$mask,
+ Optional<AnyIntegerType>:$dim
+ );
+
+ let results = (outs AnyFortranValue);
+
+ let assemblyFormat = [{
+ $mask (`dim` $dim^)? attr-dict `:` functional-type(operands, results)
+ }];
+
+ let hasVerifier = 1;
+}
+
def hlfir_AnyOp : hlfir_Op<"any", []> {
let summary = "ANY transformational intrinsic";
let description = [{
diff --git a/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp b/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
index 4547c4247241e..adf8b72993e4c 100644
--- a/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
+++ b/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
@@ -442,16 +442,19 @@ mlir::LogicalResult hlfir::ParentComponentOp::verify() {
}
//===----------------------------------------------------------------------===//
-// AnyOp
+// LogicalReductionOp
//===----------------------------------------------------------------------===//
-mlir::LogicalResult hlfir::AnyOp::verify() {
- mlir::Operation *op = getOperation();
+template <typename LogicalReductionOp>
+static mlir::LogicalResult
+verifyLogicalReductionOp(LogicalReductionOp reductionOp) {
+ mlir::Operation *op = reductionOp->getOperation();
auto results = op->getResultTypes();
assert(results.size() == 1);
- mlir::Value mask = getMask();
- mlir::Value dim = getDim();
+ mlir::Value mask = reductionOp->getMask();
+ mlir::Value dim = reductionOp->getDim();
+
fir::SequenceType maskTy =
hlfir::getFortranElementOrSequenceType(mask.getType())
.cast<fir::SequenceType>();
@@ -462,7 +465,7 @@ mlir::LogicalResult hlfir::AnyOp::verify() {
if (mlir::isa<fir::LogicalType>(resultType)) {
// Result is of the same type as MASK
if (resultType != logicalTy)
- return emitOpError(
+ return reductionOp->emitOpError(
"result must have the same element type as MASK argument");
} else if (auto resultExpr =
@@ -470,25 +473,42 @@ mlir::LogicalResult hlfir::AnyOp::verify() {
// Result should only be in hlfir.expr form if it is an array
if (maskShape.size() > 1 && dim != nullptr) {
if (!resultExpr.isArray())
- return emitOpError("result must be an array");
+ return reductionOp->emitOpError("result must be an array");
if (resultExpr.getEleTy() != logicalTy)
- return emitOpError(
+ return reductionOp->emitOpError(
"result must have the same element type as MASK argument");
llvm::ArrayRef<int64_t> resultShape = resultExpr.getShape();
// Result has rank n-1
if (resultShape.size() != (maskShape.size() - 1))
- return emitOpError("result rank must be one less than MASK");
+ return reductionOp->emitOpError(
+ "result rank must be one less than MASK");
} else {
- return emitOpError("result must be of logical type");
+ return reductionOp->emitOpError("result must be of logical type");
}
} else {
- return emitOpError("result must be of logical type");
+ return reductionOp->emitOpError("result must be of logical type");
}
return mlir::success();
}
+//===----------------------------------------------------------------------===//
+// AllOp
+//===----------------------------------------------------------------------===//
+
+mlir::LogicalResult hlfir::AllOp::verify() {
+ return verifyLogicalReductionOp(this);
+}
+
+//===----------------------------------------------------------------------===//
+// AnyOp
+//===----------------------------------------------------------------------===//
+
+mlir::LogicalResult hlfir::AnyOp::verify() {
+ return verifyLogicalReductionOp(this);
+}
+
//===----------------------------------------------------------------------===//
// ConcatOp
//===----------------------------------------------------------------------===//
@@ -537,11 +557,12 @@ void hlfir::ConcatOp::build(mlir::OpBuilder &builder,
}
//===----------------------------------------------------------------------===//
-// ReductionOp
+// NumericalReductionOp
//===----------------------------------------------------------------------===//
-template <typename ReductionOp>
-static mlir::LogicalResult verifyReductionOp(ReductionOp reductionOp) {
+template <typename NumericalReductionOp>
+static mlir::LogicalResult
+verifyNumericalReductionOp(NumericalReductionOp reductionOp) {
mlir::Operation *op = reductionOp->getOperation();
auto results = op->getResultTypes();
@@ -619,7 +640,7 @@ static mlir::LogicalResult verifyReductionOp(ReductionOp reductionOp) {
//===----------------------------------------------------------------------===//
mlir::LogicalResult hlfir::ProductOp::verify() {
- return verifyReductionOp<hlfir::ProductOp *>(this);
+ return verifyNumericalReductionOp(this);
}
//===----------------------------------------------------------------------===//
@@ -645,7 +666,7 @@ void hlfir::SetLengthOp::build(mlir::OpBuilder &builder,
//===----------------------------------------------------------------------===//
mlir::LogicalResult hlfir::SumOp::verify() {
- return verifyReductionOp<hlfir::SumOp *>(this);
+ return verifyNumericalReductionOp(this);
}
//===----------------------------------------------------------------------===//
diff --git a/flang/test/HLFIR/all.fir b/flang/test/HLFIR/all.fir
new file mode 100644
index 0000000000000..00ce1b3a5fbae
--- /dev/null
+++ b/flang/test/HLFIR/all.fir
@@ -0,0 +1,113 @@
+// Test hlfir.all operation parse, verify (no errors), and unparse
+
+// RUN: fir-opt %s | fir-opt | FileCheck %s
+
+// mask is an expression of known shape
+func.func @all0(%arg0: !hlfir.expr<2x!fir.logical<4>>) {
+ %all = hlfir.all %arg0 : (!hlfir.expr<2x!fir.logical<4>>) -> !fir.logical<4>
+ return
+}
+// CHECK: func.func @all0(%[[ARRAY:.*]]: !hlfir.expr<2x!fir.logical<4>>) {
+// CHECK-NEXT: %[[ALL:.*]] = hlfir.all %[[ARRAY]] : (!hlfir.expr<2x!fir.logical<4>>) -> !fir.logical<4>
+// CHECK-NEXT: return
+// CHECK-NEXT: }
+
+// mask is an expression whose extent is unknown at compile time (?)
+func.func @all1(%arg0: !hlfir.expr<?x!fir.logical<4>>) {
+ %all = hlfir.all %arg0 : (!hlfir.expr<?x!fir.logical<4>>) -> !fir.logical<4>
+ return
+}
+// CHECK: func.func @all1(%[[ARRAY:.*]]: !hlfir.expr<?x!fir.logical<4>>) {
+// CHECK-NEXT: %[[ALL:.*]] = hlfir.all %[[ARRAY]] : (!hlfir.expr<?x!fir.logical<4>>) -> !fir.logical<4>
+// CHECK-NEXT: return
+// CHECK-NEXT: }
+
+// mask is a boxed array of known shape
+func.func @all2(%arg0: !fir.box<!fir.array<2x!fir.logical<4>>>) {
+ %all = hlfir.all %arg0 : (!fir.box<!fir.array<2x!fir.logical<4>>>) -> !fir.logical<4>
+ return
+}
+// CHECK: func.func @all2(%[[ARRAY:.*]]: !fir.box<!fir.array<2x!fir.logical<4>>>) {
+// CHECK-NEXT: %[[ALL:.*]] = hlfir.all %[[ARRAY]] : (!fir.box<!fir.array<2x!fir.logical<4>>>) -> !fir.logical<4>
+// CHECK-NEXT: return
+// CHECK-NEXT: }
+
+// mask is a boxed array whose extent is unknown at compile time (assumed shape)
+func.func @all3(%arg0: !fir.box<!fir.array<?x!fir.logical<4>>>){
+ %all = hlfir.all %arg0 : (!fir.box<!fir.array<?x!fir.logical<4>>>) -> !fir.logical<4>
+ return
+}
+// CHECK: func.func @all3(%[[ARRAY:.*]]: !fir.box<!fir.array<?x!fir.logical<4>>>) {
+// CHECK-NEXT: %[[ALL:.*]] = hlfir.all %[[ARRAY]] : (!fir.box<!fir.array<?x!fir.logical<4>>>) -> !fir.logical<4>
+// CHECK-NEXT: return
+// CHECK-NEXT: }
+
+// mask is a 2-dimensional array without dim: the result is a scalar logical
+func.func @all4(%arg0: !fir.box<!fir.array<?x?x!fir.logical<4>>>){
+ %all = hlfir.all %arg0 : (!fir.box<!fir.array<?x?x!fir.logical<4>>>) -> !fir.logical<4>
+ return
+}
+// CHECK: func.func @all4(%[[ARRAY:.*]]: !fir.box<!fir.array<?x?x!fir.logical<4>>>) {
+// CHECK-NEXT: %[[ALL:.*]] = hlfir.all %[[ARRAY]] : (!fir.box<!fir.array<?x?x!fir.logical<4>>>) -> !fir.logical<4>
+// CHECK-NEXT: return
+// CHECK-NEXT: }
+
+// rank-1 mask with a dim argument: the result is still a scalar logical
+func.func @all5(%arg0: !fir.box<!fir.array<?x!fir.logical<4>>>, %arg1: i32) {
+ %all = hlfir.all %arg0 dim %arg1 : (!fir.box<!fir.array<?x!fir.logical<4>>>, i32) -> !fir.logical<4>
+ return
+}
+// CHECK: func.func @all5(%[[ARRAY:.*]]: !fir.box<!fir.array<?x!fir.logical<4>>>, %[[DIM:.*]]: i32) {
+// CHECK-NEXT: %[[ALL:.*]] = hlfir.all %[[ARRAY]] dim %[[DIM]] : (!fir.box<!fir.array<?x!fir.logical<4>>>, i32) -> !fir.logical<4>
+// CHECK-NEXT: return
+// CHECK-NEXT: }
+
+// hlfir.all with a dim argument of index type rather than a sized integer
+func.func @all6(%arg0: !fir.box<!fir.array<?x!fir.logical<4>>>, %arg1: index) {
+ %all = hlfir.all %arg0 dim %arg1 : (!fir.box<!fir.array<?x!fir.logical<4>>>, index) ->!fir.logical<4>
+ return
+}
+// CHECK: func.func @all6(%[[ARRAY:.*]]: !fir.box<!fir.array<?x!fir.logical<4>>>, %[[DIM:.*]]: index) {
+// CHECK-NEXT: %[[ALL:.*]] = hlfir.all %[[ARRAY]] dim %[[DIM]] : (!fir.box<!fir.array<?x!fir.logical<4>>>, index) -> !fir.logical<4>
+// CHECK-NEXT: return
+// CHECK-NEXT: }
+
+// 2-dimensional mask with dim: the result is a rank-1 expression of unknown extent
+func.func @all7(%arg0: !fir.box<!fir.array<?x?x!fir.logical<4>>>, %arg1: i32) {
+ %all = hlfir.all %arg0 dim %arg1 : (!fir.box<!fir.array<?x?x!fir.logical<4>>>, i32) -> !hlfir.expr<?x!fir.logical<4>>
+ return
+}
+// CHECK: func.func @all7(%[[ARRAY:.*]]: !fir.box<!fir.array<?x?x!fir.logical<4>>>, %[[DIM:.*]]: i32) {
+// CHECK-NEXT: %[[ALL:.*]] = hlfir.all %[[ARRAY]] dim %[[DIM]] : (!fir.box<!fir.array<?x?x!fir.logical<4>>>, i32) -> !hlfir.expr<?x!fir.logical<4>>
+// CHECK-NEXT: return
+// CHECK-NEXT: }
+
+// 2x2 mask with dim: the result is a rank-1 expression of known shape
+func.func @all8(%arg0: !fir.box<!fir.array<2x2x!fir.logical<4>>>, %arg1: i32) {
+ %all = hlfir.all %arg0 dim %arg1 : (!fir.box<!fir.array<2x2x!fir.logical<4>>>, i32) -> !hlfir.expr<2x!fir.logical<4>>
+ return
+}
+// CHECK: func.func @all8(%[[ARRAY:.*]]: !fir.box<!fir.array<2x2x!fir.logical<4>>>, %[[DIM:.*]]: i32) {
+// CHECK-NEXT: %[[ALL:.*]] = hlfir.all %[[ARRAY]] dim %[[DIM]] : (!fir.box<!fir.array<2x2x!fir.logical<4>>>, i32) -> !hlfir.expr<2x!fir.logical<4>>
+// CHECK-NEXT: return
+// CHECK-NEXT: }
+
+// hlfir.all with mask argument of ref<array<>> type
+func.func @all9(%arg0: !fir.ref<!fir.array<?x!fir.logical<4>>>) {
+ %all = hlfir.all %arg0 : (!fir.ref<!fir.array<?x!fir.logical<4>>>) -> !fir.logical<4>
+ return
+}
+// CHECK: func.func @all9(%[[ARRAY:.*]]: !fir.ref<!fir.array<?x!fir.logical<4>>>) {
+// CHECK-NEXT: %[[ALL:.*]] = hlfir.all %[[ARRAY]] : (!fir.ref<!fir.array<?x!fir.logical<4>>>) -> !fir.logical<4>
+// CHECK-NEXT: return
+// CHECK-NEXT: }
+
+// hlfir.all with a fir.logical<8> mask element type and result
+func.func @all10(%arg0: !fir.box<!fir.array<?x!fir.logical<8>>>) {
+ %all = hlfir.all %arg0 : (!fir.box<!fir.array<?x!fir.logical<8>>>) -> !fir.logical<8>
+ return
+}
+// CHECK: func.func @all10(%[[ARRAY:.*]]: !fir.box<!fir.array<?x!fir.logical<8>>>) {
+// CHECK-NEXT: %[[ALL:.*]] = hlfir.all %[[ARRAY]] : (!fir.box<!fir.array<?x!fir.logical<8>>>) -> !fir.logical<8>
+// CHECK-NEXT: return
+// CHECK-NEXT: }
\ No newline at end of file
diff --git a/flang/test/HLFIR/invalid.fir b/flang/test/HLFIR/invalid.fir
index e1c95c1046dc4..8dc5679346bc1 100644
--- a/flang/test/HLFIR/invalid.fir
+++ b/flang/test/HLFIR/invalid.fir
@@ -332,6 +332,42 @@ func.func @bad_any6(%arg0: !hlfir.expr<?x!fir.logical<4>>) {
%0 = hlfir.any %arg0 : (!hlfir.expr<?x!fir.logical<4>>) -> !hlfir.expr<!fir.logical<4>>
}
+// -----
+func.func @bad_all1(%arg0: !hlfir.expr<?x!fir.logical<4>>) {
+ // expected-error@+1 {{'hlfir.all' op result must have the same element type as MASK argument}}
+ %0 = hlfir.all %arg0 : (!hlfir.expr<?x!fir.logical<4>>) -> !fir.logical<8>
+}
+
+// -----
+func.func @bad_all2(%arg0: !hlfir.expr<?x?x!fir.logical<4>>, %arg1: i32) {
+ // expected-error@+1 {{'hlfir.all' op result must have the same element type as MASK argument}}
+ %0 = hlfir.all %arg0 dim %arg1 : (!hlfir.expr<?x?x!fir.logical<4>>, i32) -> !hlfir.expr<?x!fir.logical<8>>
+}
+
+// -----
+func.func @bad_all3(%arg0: !hlfir.expr<?x?x!fir.logical<4>>, %arg1: i32){
+ // expected-error@+1 {{'hlfir.all' op result rank must be one less than MASK}}
+ %0 = hlfir.all %arg0 dim %arg1 : (!hlfir.expr<?x?x!fir.logical<4>>, i32) -> !hlfir.expr<?x?x!fir.logical<4>>
+}
+
+// -----
+func.func @bad_all4(%arg0: !hlfir.expr<?x?x!fir.logical<4>>, %arg1: i32) {
+ // expected-error@+1 {{'hlfir.all' op result must be an array}}
+ %0 = hlfir.all %arg0 dim %arg1 : (!hlfir.expr<?x?x!fir.logical<4>>, i32) -> !hlfir.expr<!fir.logical<4>>
+}
+
+// -----
+func.func @bad_all5(%arg0: !hlfir.expr<?x!fir.logical<4>>) {
+ // expected-error@+1 {{'hlfir.all' op result must be of logical type}}
+ %0 = hlfir.all %arg0 : (!hlfir.expr<?x!fir.logical<4>>) -> i32
+}
+
+// -----
+func.func @bad_all6(%arg0: !hlfir.expr<?x!fir.logical<4>>) {
+ // expected-error@+1 {{'hlfir.all' op result must be of logical type}}
+ %0 = hlfir.all %arg0 : (!hlfir.expr<?x!fir.logical<4>>) -> !hlfir.expr<!fir.logical<4>>
+}
+
// -----
func.func @bad_product1(%arg0: !hlfir.expr<?xi32>, %arg1: i32, %arg2: !fir.box<!fir.logical<4>>) {
// expected-error at +1 {{'hlfir.product' op result must have the same element type as ARRAY argument}}
More information about the flang-commits
mailing list