[flang-commits] [flang] 5b5ef2e - [fir] Add fir.save_result op
Valentin Clement via flang-commits
flang-commits at lists.llvm.org
Tue Sep 28 02:58:50 PDT 2021
Author: Eric Schweitz
Date: 2021-09-28T11:58:42+02:00
New Revision: 5b5ef2e26558ee9d5da4ba4af69737732da49858
URL: https://github.com/llvm/llvm-project/commit/5b5ef2e26558ee9d5da4ba4af69737732da49858
DIFF: https://github.com/llvm/llvm-project/commit/5b5ef2e26558ee9d5da4ba4af69737732da49858.diff
LOG: [fir] Add fir.save_result op
Add the fir.save_result operation. It is used to save an
array, box, or record function result SSA-value to a memory location.
Reviewed By: jeanPerier
Differential Revision: https://reviews.llvm.org/D110407
Co-authored-by: Jean Perier <jperier at nvidia.com>
Co-authored-by: Valentin Clement <clementval at gmail.com>
Added:
Modified:
flang/include/flang/Optimizer/Dialect/FIROps.td
flang/include/flang/Optimizer/Dialect/FIRTypes.td
flang/lib/Optimizer/Dialect/FIROps.cpp
flang/test/Fir/fir-ops.fir
flang/test/Fir/invalid.fir
Removed:
################################################################################
diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
index 5aedc017740f0..3a3112beb4ef1 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -352,6 +352,55 @@ def fir_LoadOp : fir_OneResultOp<"load"> {
}];
}
+def fir_SaveResultOp : fir_Op<"save_result", [AttrSizedOperandSegments]> {
+  let summary = [{
+    save an array, box, or record function result SSA-value to a memory location
+  }];
+
+  let description = [{
+    Save the result of a function returning an array, box, or record type value
+    into a memory location given the shape and length parameters of the result.
+
+    Function results of type fir.box, fir.array, or fir.type (record) are
+    abstract values that require a storage to be manipulated on the caller
+    side. This operation allows associating such an abstract result with a
+    storage. In later lowering of the function interfaces, this storage might
+    be used to pass the result in memory.
+
+    For array results, the shape of the result must be provided. For
+    character arrays and derived types with length parameters, the length
+    parameter values must be provided.
+
+    The fir.save_result associated with a function call must immediately
+    follow the call and be in the same block.
+
+    ```mlir
+      %buffer = fir.alloca !fir.array<?xf32>, %c100
+      %shape = fir.shape %c100 : (index) -> !fir.shape<1>
+      %array_result = fir.call @foo() : () -> !fir.array<?xf32>
+      fir.save_result %array_result to %buffer(%shape) : !fir.array<?xf32>, !fir.ref<!fir.array<?xf32>>, !fir.shape<1>
+      %coor = fir.array_coor %buffer(%shape) %c5 : (!fir.ref<!fir.array<?xf32>>, !fir.shape<1>, index) -> !fir.ref<f32>
+      %fifth_element = fir.load %coor : !fir.ref<f32>
+    ```
+
+    The above fir.save_result allows saving a fir.array function result into
+    a buffer to later access its 5th element.
+  }];
+
+  // $value:      the function result being saved (fir.box, fir.array or
+  //              fir.type).
+  // $memref:     the memory reference written to.
+  // $shape:      optional shape, for array values.
+  // $typeparams: length type parameters (character length or derived type
+  //              length parameters).
+  let arguments = (ins ArrayOrBoxOrRecord:$value,
+                   Arg<AnyReferenceLike, "", [MemWrite]>:$memref,
+                   Optional<AnyShapeType>:$shape,
+                   Variadic<AnyIntegerType>:$typeparams);
+
+  let assemblyFormat = [{
+    $value `to` $memref (`(` $shape^ `)`)? (`typeparams` $typeparams^)?
+    attr-dict `:` type(operands)
+  }];
+
+  let verifier = [{ return ::verify(*this); }];
+}
+
def fir_StoreOp : fir_Op<"store", []> {
let summary = "store an SSA-value to a memory location";
diff --git a/flang/include/flang/Optimizer/Dialect/FIRTypes.td b/flang/include/flang/Optimizer/Dialect/FIRTypes.td
index b1df67186ae91..b873c72deb67b 100644
--- a/flang/include/flang/Optimizer/Dialect/FIRTypes.td
+++ b/flang/include/flang/Optimizer/Dialect/FIRTypes.td
@@ -551,4 +551,9 @@ def AnyCoordinateType : Type<AnyCoordinateLike.predicate, "coordinate type">;
def AnyAddressableLike : TypeConstraint<Or<[fir_ReferenceType.predicate,
FunctionType.predicate]>, "any addressable">;
+// Type constraint satisfied by a fir.array (SequenceType), a fir.box
+// (BoxType), or a fir.type (RecordType).
+def ArrayOrBoxOrRecord : TypeConstraint<
+    Or<[fir_SequenceType.predicate,
+        fir_BoxType.predicate,
+        fir_RecordType.predicate]>,
+    "fir.box, fir.array or fir.type">;
+
+
#endif // FIR_DIALECT_FIR_TYPES
diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index 05d222a2ff4c9..1b757553ca30b 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -1361,6 +1361,63 @@ static mlir::LogicalResult verify(fir::ResultOp op) {
return success();
}
+//===----------------------------------------------------------------------===//
+// SaveResultOp
+//===----------------------------------------------------------------------===//
+
+/// Verify a fir.save_result operation.
+///
+/// - The saved value type must be identical to the element type of the memory
+///   reference, and a fir.box of unknown rank or type cannot be saved.
+/// - A fir.box value must not carry a shape or type parameters.
+/// - A fir.array value needs a shape whose rank equals the array dimension;
+///   any other value must not have a non-zero-rank shape.
+/// - Type parameters must match the number of record length parameters for a
+///   fir.type element, be at most one for a character element, and be absent
+///   for any other element type.
+/// Every violated rule emits an error and makes verification fail.
+static mlir::LogicalResult verify(fir::SaveResultOp op) {
+  auto resultType = op.value().getType();
+  if (resultType != fir::dyn_cast_ptrEleTy(op.memref().getType()))
+    return op.emitOpError("value type must match memory reference type");
+  if (fir::isa_unknown_size_box(resultType))
+    return op.emitOpError("cannot save !fir.box of unknown rank or type");
+
+  if (resultType.isa<fir::BoxType>()) {
+    if (op.shape() || !op.typeparams().empty())
+      return op.emitOpError(
+          "must not have shape or length operands if the value is a fir.box");
+    return mlir::success();
+  }
+
+  // fir.record or fir.array case.
+  unsigned shapeTyRank = 0;
+  if (auto shapeOp = op.shape()) {
+    auto shapeTy = shapeOp.getType();
+    if (auto s = shapeTy.dyn_cast<fir::ShapeType>())
+      shapeTyRank = s.getRank();
+    else
+      shapeTyRank = shapeTy.cast<fir::ShapeShiftType>().getRank();
+  }
+
+  // Each failing check below must return the diagnostic: emitting it without
+  // returning would report the error yet still let the op verify.
+  auto eleTy = resultType;
+  if (auto seqTy = resultType.dyn_cast<fir::SequenceType>()) {
+    if (seqTy.getDimension() != shapeTyRank)
+      return op.emitOpError(
+          "shape operand must be provided and have the value rank "
+          "when the value is a fir.array");
+    eleTy = seqTy.getEleTy();
+  } else if (shapeTyRank != 0) {
+    return op.emitOpError(
+        "shape operand should only be provided if the value is a fir.array");
+  }
+
+  if (auto recTy = eleTy.dyn_cast<fir::RecordType>()) {
+    if (recTy.getNumLenParams() != op.typeparams().size())
+      return op.emitOpError(
+          "length parameters number must match with the value type "
+          "length parameters");
+  } else if (eleTy.isa<fir::CharacterType>()) {
+    if (op.typeparams().size() > 1)
+      return op.emitOpError(
+          "no more than one length parameter must be provided for "
+          "character value");
+  } else if (!op.typeparams().empty()) {
+    return op.emitOpError(
+        "length parameters must not be provided for this value type");
+  }
+
+  return mlir::success();
+}
+
//===----------------------------------------------------------------------===//
// SelectOp
//===----------------------------------------------------------------------===//
diff --git a/flang/test/Fir/fir-ops.fir b/flang/test/Fir/fir-ops.fir
index dfd97a8b171a9..fcd638cf0ccaa 100644
--- a/flang/test/Fir/fir-ops.fir
+++ b/flang/test/Fir/fir-ops.fir
@@ -671,3 +671,14 @@ func @test_rebox(%arg0: !fir.box<!fir.array<?xf32>>) {
fir.call @bar_rebox_test(%4) : (!fir.box<!fir.array<?x?xf32>>) -> ()
return
}
+
+// Round-trip a fir.save_result storing a character-array call result, with
+// both a shape operand and a length type parameter.
+// CHECK-LABEL: @test_save_result(
+func @test_save_result(%buffer: !fir.ref<!fir.array<?x!fir.char<1,?>>>) {
+  %len = constant 50 : index
+  %extent = constant 100 : index
+  %res_shape = fir.shape %extent : (index) -> !fir.shape<1>
+  %call_res = fir.call @array_func() : () -> !fir.array<?x!fir.char<1,?>>
+  // CHECK: fir.save_result %{{.*}} to %{{.*}}(%{{.*}}) typeparams %{{.*}} : !fir.array<?x!fir.char<1,?>>, !fir.ref<!fir.array<?x!fir.char<1,?>>>, !fir.shape<1>, index
+  fir.save_result %call_res to %buffer(%res_shape) typeparams %len : !fir.array<?x!fir.char<1,?>>, !fir.ref<!fir.array<?x!fir.char<1,?>>>, !fir.shape<1>, index
+  return
+}
diff --git a/flang/test/Fir/invalid.fir b/flang/test/Fir/invalid.fir
index f1e482568674d..dd0229662dedf 100644
--- a/flang/test/Fir/invalid.fir
+++ b/flang/test/Fir/invalid.fir
@@ -417,3 +417,80 @@ fir.global internal @_QEmultiarray : !fir.array<32x32xi32> {
%2 = fir.insert_on_range %0, %c0_i32, [10 : index, 9 : index] : (!fir.array<32x32xi32>, i32) -> !fir.array<32x32xi32>
fir.has_value %2 : !fir.array<32x32xi32>
}
+
+// -----
+
+// The saved value type must equal the memory reference element type.
+func @bad_save_result(%buffer : !fir.ref<!fir.array<?xf64>>, %n : index) {
+  %res = fir.call @array_func() : () -> !fir.array<?xf32>
+  %shape = fir.shape %n : (index) -> !fir.shape<1>
+  // expected-error@+1 {{'fir.save_result' op value type must match memory reference type}}
+  fir.save_result %res to %buffer(%shape) : !fir.array<?xf32>, !fir.ref<!fir.array<?xf64>>, !fir.shape<1>
+  return
+}
+
+// -----
+
+// A fir.box of unknown rank cannot be saved.
+func @bad_save_result(%buffer : !fir.ref<!fir.box<!fir.array<*:f32>>>) {
+  %res = fir.call @array_func() : () -> !fir.box<!fir.array<*:f32>>
+  // expected-error@+1 {{'fir.save_result' op cannot save !fir.box of unknown rank or type}}
+  fir.save_result %res to %buffer : !fir.box<!fir.array<*:f32>>, !fir.ref<!fir.box<!fir.array<*:f32>>>
+  return
+}
+
+// -----
+
+// Scalar values are rejected by the ArrayOrBoxOrRecord operand constraint.
+func @bad_save_result(%buffer : !fir.ref<f64>) {
+  %res = fir.call @array_func() : () -> f64
+  // expected-error@+1 {{'fir.save_result' op operand #0 must be fir.box, fir.array or fir.type, but got 'f64'}}
+  fir.save_result %res to %buffer : f64, !fir.ref<f64>
+  return
+}
+
+// -----
+
+// A fir.box value must not be given a shape operand.
+func @bad_save_result(%buffer : !fir.ref<!fir.box<!fir.array<?xf32>>>, %n : index) {
+  %res = fir.call @array_func() : () -> !fir.box<!fir.array<?xf32>>
+  %shape = fir.shape %n : (index) -> !fir.shape<1>
+  // expected-error@+1 {{'fir.save_result' op must not have shape or length operands if the value is a fir.box}}
+  fir.save_result %res to %buffer(%shape) : !fir.box<!fir.array<?xf32>>, !fir.ref<!fir.box<!fir.array<?xf32>>>, !fir.shape<1>
+  return
+}
+
+// -----
+
+// The shape rank must match the rank of the saved fir.array.
+func @bad_save_result(%buffer : !fir.ref<!fir.array<?xf32>>, %n : index) {
+  %res = fir.call @array_func() : () -> !fir.array<?xf32>
+  %shape = fir.shape %n, %n : (index, index) -> !fir.shape<2>
+  // expected-error@+1 {{'fir.save_result' op shape operand must be provided and have the value rank when the value is a fir.array}}
+  fir.save_result %res to %buffer(%shape) : !fir.array<?xf32>, !fir.ref<!fir.array<?xf32>>, !fir.shape<2>
+  return
+}
+
+// -----
+
+// A non-array (record) value must not be given a shape operand.
+func @bad_save_result(%buffer : !fir.ref<!fir.type<t{x:f32}>>, %n : index) {
+  %res = fir.call @array_func() : () -> !fir.type<t{x:f32}>
+  %shape = fir.shape %n : (index) -> !fir.shape<1>
+  // expected-error@+1 {{'fir.save_result' op shape operand should only be provided if the value is a fir.array}}
+  fir.save_result %res to %buffer(%shape) : !fir.type<t{x:f32}>, !fir.ref<!fir.type<t{x:f32}>>, !fir.shape<1>
+  return
+}
+
+// -----
+
+// The number of type parameters must match the record's length parameters.
+func @bad_save_result(%buffer : !fir.ref<!fir.type<t{x:f32}>>, %n : index) {
+  %res = fir.call @array_func() : () -> !fir.type<t{x:f32}>
+  // expected-error@+1 {{'fir.save_result' op length parameters number must match with the value type length parameters}}
+  fir.save_result %res to %buffer typeparams %n : !fir.type<t{x:f32}>, !fir.ref<!fir.type<t{x:f32}>>, index
+  return
+}
+
+// -----
+
+// Type parameters are not allowed for an array of non-character intrinsics.
+func @bad_save_result(%buffer : !fir.ref<!fir.array<?xf32>>, %n : index) {
+  %res = fir.call @array_func() : () -> !fir.array<?xf32>
+  %shape = fir.shape %n : (index) -> !fir.shape<1>
+  // expected-error@+1 {{'fir.save_result' op length parameters must not be provided for this value type}}
+  fir.save_result %res to %buffer(%shape) typeparams %n : !fir.array<?xf32>, !fir.ref<!fir.array<?xf32>>, !fir.shape<1>, index
+  return
+}
More information about the flang-commits
mailing list