[flang-commits] [flang] 5b5ef2e - [fir] Add fir.save_result op

Valentin Clement via flang-commits flang-commits at lists.llvm.org
Tue Sep 28 02:58:50 PDT 2021


Author: Eric Schweitz
Date: 2021-09-28T11:58:42+02:00
New Revision: 5b5ef2e26558ee9d5da4ba4af69737732da49858

URL: https://github.com/llvm/llvm-project/commit/5b5ef2e26558ee9d5da4ba4af69737732da49858
DIFF: https://github.com/llvm/llvm-project/commit/5b5ef2e26558ee9d5da4ba4af69737732da49858.diff

LOG: [fir] Add fir.save_result op

Add the fir.save_result operation. It is used to save an
array, box, or record function result SSA-value to a memory location.

Reviewed By: jeanPerier

Differential Revision: https://reviews.llvm.org/D110407

Co-authored-by: Jean Perier <jperier at nvidia.com>
Co-authored-by: Valentin Clement <clementval at gmail.com>

Added: 
    

Modified: 
    flang/include/flang/Optimizer/Dialect/FIROps.td
    flang/include/flang/Optimizer/Dialect/FIRTypes.td
    flang/lib/Optimizer/Dialect/FIROps.cpp
    flang/test/Fir/fir-ops.fir
    flang/test/Fir/invalid.fir

Removed: 
    


################################################################################
diff  --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
index 5aedc017740f0..3a3112beb4ef1 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -352,6 +352,55 @@ def fir_LoadOp : fir_OneResultOp<"load"> {
   }];
 }
 
+def fir_SaveResultOp : fir_Op<"save_result", [AttrSizedOperandSegments]> {
+  let summary = [{
+    save an array, box, or record function result SSA-value to a memory location
+  }];
+
+  let description = [{
+    Save the result of a function returning an array, box, or record type value
+    into a memory location given the shape and length parameters of the result.
+
+    Function results of type fir.box, fir.array, or fir.type are abstract
+    values that require storage to be manipulated on the caller side. This
+    operation associates such an abstract result with a storage location. In
+    later lowering of the function interfaces, this storage might be used to
+    pass the result in memory.
+
+    For array results, the shape of the result must be provided. For
+    character arrays and derived types with length parameters, the length
+    parameter values must be provided.
+
+    The fir.save_result associated with a function call must immediately
+    follow the call and be in the same block.
+
+    ```mlir
+      %buffer = fir.alloca !fir.array<?xf32>, %c100
+      %shape = fir.shape %c100 : (index) -> !fir.shape<1>
+      %array_result = fir.call @foo() : () -> !fir.array<?xf32>
+      fir.save_result %array_result to %buffer(%shape) : !fir.array<?xf32>, !fir.ref<!fir.array<?xf32>>, !fir.shape<1>
+      %coor = fir.array_coor %buffer(%shape) %c5 : (!fir.ref<!fir.array<?xf32>>, !fir.shape<1>, index) -> !fir.ref<f32>
+      %fifth_element = fir.load %coor : !fir.ref<f32>
+    ```
+
+    The above fir.save_result allows saving a fir.array function result into
+    a buffer to later access its 5th element.
+
+  }];
+
+  let arguments = (ins ArrayOrBoxOrRecord:$value,
+                   Arg<AnyReferenceLike, "", [MemWrite]>:$memref,
+                   Optional<AnyShapeType>:$shape,
+                   Variadic<AnyIntegerType>:$typeparams);
+
+  let assemblyFormat = [{
+    $value `to` $memref (`(` $shape^ `)`)? (`typeparams` $typeparams^)?
+    attr-dict `:` type(operands)
+  }];
+
+  let verifier = [{ return ::verify(*this); }];
+}
+
 def fir_StoreOp : fir_Op<"store", []> {
   let summary = "store an SSA-value to a memory location";
 

diff  --git a/flang/include/flang/Optimizer/Dialect/FIRTypes.td b/flang/include/flang/Optimizer/Dialect/FIRTypes.td
index b1df67186ae91..b873c72deb67b 100644
--- a/flang/include/flang/Optimizer/Dialect/FIRTypes.td
+++ b/flang/include/flang/Optimizer/Dialect/FIRTypes.td
@@ -551,4 +551,10 @@ def AnyCoordinateType : Type<AnyCoordinateLike.predicate, "coordinate type">;
 def AnyAddressableLike : TypeConstraint<Or<[fir_ReferenceType.predicate,
     FunctionType.predicate]>, "any addressable">;
 
+// Any fir.array (sequence), fir.box, or fir.type (derived/record) value: the
+// abstract function result kinds that fir.save_result can store to memory.
+def ArrayOrBoxOrRecord : TypeConstraint<Or<[fir_SequenceType.predicate,
+    fir_BoxType.predicate, fir_RecordType.predicate]>,
+    "fir.box, fir.array or fir.type">;
+
 #endif // FIR_DIALECT_FIR_TYPES

diff  --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index 05d222a2ff4c9..1b757553ca30b 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -1361,6 +1361,80 @@ static mlir::LogicalResult verify(fir::ResultOp op) {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// SaveResultOp
+//===----------------------------------------------------------------------===//
+
+/// Verify a fir.save_result operation.
+///
+/// The saved value type must equal the element type of the memory
+/// reference. The optional shape and length-parameter operands are then
+/// checked against the value type:
+///   - fir.box values (of known rank and type) take neither;
+///   - fir.array values need a shape whose rank matches the array rank;
+///   - other values must not be given a shape of non-zero rank;
+///   - derived-type elements need one operand per length parameter;
+///   - character elements take at most one length parameter;
+///   - all other elements take no length parameters.
+/// The first failing check emits an error and fails verification.
+static mlir::LogicalResult verify(fir::SaveResultOp op) {
+  auto resultType = op.value().getType();
+  if (resultType != fir::dyn_cast_ptrEleTy(op.memref().getType()))
+    return op.emitOpError("value type must match memory reference type");
+  if (fir::isa_unknown_size_box(resultType))
+    return op.emitOpError("cannot save !fir.box of unknown rank or type");
+
+  if (resultType.isa<fir::BoxType>()) {
+    if (op.shape() || !op.typeparams().empty())
+      return op.emitOpError(
+          "must not have shape or length operands if the value is a fir.box");
+    return mlir::success();
+  }
+
+  // fir.record or fir.array case. An absent shape operand counts as rank 0.
+  unsigned shapeTyRank = 0;
+  if (auto shapeOp = op.shape()) {
+    auto shapeTy = shapeOp.getType();
+    // NOTE: assumes the shape is a fir.shape or a fir.shapeshift; any other
+    // shape type would trip the cast below.
+    if (auto s = shapeTy.dyn_cast<fir::ShapeType>())
+      shapeTyRank = s.getRank();
+    else
+      shapeTyRank = shapeTy.cast<fir::ShapeShiftType>().getRank();
+  }
+
+  // Shape checks; for arrays, the length checks apply to the element type.
+  auto eleTy = resultType;
+  if (auto seqTy = resultType.dyn_cast<fir::SequenceType>()) {
+    if (seqTy.getDimension() != shapeTyRank)
+      return op.emitOpError(
+          "shape operand must be provided and have the value rank "
+          "when the value is a fir.array");
+    eleTy = seqTy.getEleTy();
+  } else if (shapeTyRank != 0) {
+    return op.emitOpError(
+        "shape operand should only be provided if the value is a fir.array");
+  }
+
+  // Length parameter checks.
+  if (auto recTy = eleTy.dyn_cast<fir::RecordType>()) {
+    if (recTy.getNumLenParams() != op.typeparams().size())
+      return op.emitOpError(
+          "length parameters number must match with the value type "
+          "length parameters");
+  } else if (eleTy.isa<fir::CharacterType>()) {
+    if (op.typeparams().size() > 1)
+      return op.emitOpError(
+          "no more than one length parameter must be provided for "
+          "character value");
+  } else if (!op.typeparams().empty()) {
+    return op.emitOpError(
+        "length parameters must not be provided for this value type");
+  }
+
+  return mlir::success();
+}
+
 //===----------------------------------------------------------------------===//
 // SelectOp
 //===----------------------------------------------------------------------===//

diff  --git a/flang/test/Fir/fir-ops.fir b/flang/test/Fir/fir-ops.fir
index dfd97a8b171a9..fcd638cf0ccaa 100644
--- a/flang/test/Fir/fir-ops.fir
+++ b/flang/test/Fir/fir-ops.fir
@@ -671,3 +671,30 @@ func @test_rebox(%arg0: !fir.box<!fir.array<?xf32>>) {
   fir.call @bar_rebox_test(%4) : (!fir.box<!fir.array<?x?xf32>>) -> ()
   return
 }
+
+// CHECK-LABEL: @test_save_result(
+func @test_save_result(%buffer: !fir.ref<!fir.array<?x!fir.char<1,?>>>) {
+  %c100 = constant 100 : index
+  %c50 = constant 50 : index
+  %shape = fir.shape %c100 : (index) -> !fir.shape<1>
+  %res = fir.call @array_func() : () -> !fir.array<?x!fir.char<1,?>>
+  // CHECK: fir.save_result %{{.*}} to %{{.*}}(%{{.*}}) typeparams %{{.*}} : !fir.array<?x!fir.char<1,?>>, !fir.ref<!fir.array<?x!fir.char<1,?>>>, !fir.shape<1>, index
+  fir.save_result %res to %buffer(%shape) typeparams %c50 : !fir.array<?x!fir.char<1,?>>, !fir.ref<!fir.array<?x!fir.char<1,?>>>, !fir.shape<1>, index
+  return
+}
+
+// CHECK-LABEL: @test_save_result_box(
+func @test_save_result_box(%buffer: !fir.ref<!fir.box<!fir.array<?xf32>>>) {
+  %res = fir.call @box_func() : () -> !fir.box<!fir.array<?xf32>>
+  // CHECK: fir.save_result %{{.*}} to %{{.*}} : !fir.box<!fir.array<?xf32>>, !fir.ref<!fir.box<!fir.array<?xf32>>>
+  fir.save_result %res to %buffer : !fir.box<!fir.array<?xf32>>, !fir.ref<!fir.box<!fir.array<?xf32>>>
+  return
+}
+
+// CHECK-LABEL: @test_save_result_record(
+func @test_save_result_record(%buffer: !fir.ref<!fir.type<save_result_t{x:f32}>>) {
+  %res = fir.call @record_func() : () -> !fir.type<save_result_t{x:f32}>
+  // CHECK: fir.save_result %{{.*}} to %{{.*}} : !fir.type<save_result_t{x:f32}>, !fir.ref<!fir.type<save_result_t{x:f32}>>
+  fir.save_result %res to %buffer : !fir.type<save_result_t{x:f32}>, !fir.ref<!fir.type<save_result_t{x:f32}>>
+  return
+}

diff  --git a/flang/test/Fir/invalid.fir b/flang/test/Fir/invalid.fir
index f1e482568674d..dd0229662dedf 100644
--- a/flang/test/Fir/invalid.fir
+++ b/flang/test/Fir/invalid.fir
@@ -417,3 +417,99 @@ fir.global internal @_QEmultiarray : !fir.array<32x32xi32> {
   %2 = fir.insert_on_range %0, %c0_i32, [10 : index, 9 : index] : (!fir.array<32x32xi32>, i32) -> !fir.array<32x32xi32>
   fir.has_value %2 : !fir.array<32x32xi32>
 }
+
+// -----
+
+func @bad_save_result(%buffer : !fir.ref<!fir.array<?xf64>>, %n : index) {
+  %res = fir.call @array_func() : () -> !fir.array<?xf32>
+  %shape = fir.shape %n : (index) -> !fir.shape<1>
+  // expected-error@+1 {{'fir.save_result' op value type must match memory reference type}}
+  fir.save_result %res to %buffer(%shape) : !fir.array<?xf32>, !fir.ref<!fir.array<?xf64>>, !fir.shape<1>
+  return
+}
+
+// -----
+
+func @bad_save_result(%buffer : !fir.ref<!fir.box<!fir.array<*:f32>>>) {
+  %res = fir.call @array_func() : () -> !fir.box<!fir.array<*:f32>>
+  // expected-error@+1 {{'fir.save_result' op cannot save !fir.box of unknown rank or type}}
+  fir.save_result %res to %buffer : !fir.box<!fir.array<*:f32>>, !fir.ref<!fir.box<!fir.array<*:f32>>>
+  return
+}
+
+// -----
+
+func @bad_save_result(%buffer : !fir.ref<f64>) {
+  %res = fir.call @array_func() : () -> f64
+  // expected-error@+1 {{'fir.save_result' op operand #0 must be fir.box, fir.array or fir.type, but got 'f64'}}
+  fir.save_result %res to %buffer : f64, !fir.ref<f64>
+  return
+}
+
+// -----
+
+func @bad_save_result(%buffer : !fir.ref<!fir.box<!fir.array<?xf32>>>, %n : index) {
+  %res = fir.call @array_func() : () -> !fir.box<!fir.array<?xf32>>
+  %shape = fir.shape %n : (index) -> !fir.shape<1>
+  // expected-error@+1 {{'fir.save_result' op must not have shape or length operands if the value is a fir.box}}
+  fir.save_result %res to %buffer(%shape) : !fir.box<!fir.array<?xf32>>, !fir.ref<!fir.box<!fir.array<?xf32>>>, !fir.shape<1>
+  return
+}
+
+// -----
+
+func @bad_save_result(%buffer : !fir.ref<!fir.array<?xf32>>, %n : index) {
+  %res = fir.call @array_func() : () -> !fir.array<?xf32>
+  %shape = fir.shape %n, %n : (index, index) -> !fir.shape<2>
+  // expected-error@+1 {{'fir.save_result' op shape operand must be provided and have the value rank when the value is a fir.array}}
+  fir.save_result %res to %buffer(%shape) : !fir.array<?xf32>, !fir.ref<!fir.array<?xf32>>, !fir.shape<2>
+  return
+}
+
+// -----
+
+func @bad_save_result(%buffer : !fir.ref<!fir.array<?xf32>>) {
+  %res = fir.call @array_func() : () -> !fir.array<?xf32>
+  // expected-error@+1 {{'fir.save_result' op shape operand must be provided and have the value rank when the value is a fir.array}}
+  fir.save_result %res to %buffer : !fir.array<?xf32>, !fir.ref<!fir.array<?xf32>>
+  return
+}
+
+// -----
+
+func @bad_save_result(%buffer : !fir.ref<!fir.type<t{x:f32}>>, %n : index) {
+  %res = fir.call @array_func() : () -> !fir.type<t{x:f32}>
+  %shape = fir.shape %n : (index) -> !fir.shape<1>
+  // expected-error@+1 {{'fir.save_result' op shape operand should only be provided if the value is a fir.array}}
+  fir.save_result %res to %buffer(%shape) : !fir.type<t{x:f32}>, !fir.ref<!fir.type<t{x:f32}>>, !fir.shape<1>
+  return
+}
+
+// -----
+
+func @bad_save_result(%buffer : !fir.ref<!fir.type<t{x:f32}>>, %n : index) {
+  %res = fir.call @array_func() : () -> !fir.type<t{x:f32}>
+  // expected-error@+1 {{'fir.save_result' op length parameters number must match with the value type length parameters}}
+  fir.save_result %res to %buffer typeparams %n : !fir.type<t{x:f32}>, !fir.ref<!fir.type<t{x:f32}>>, index
+  return
+}
+
+// -----
+
+func @bad_save_result(%buffer : !fir.ref<!fir.array<?xf32>>, %n : index) {
+  %res = fir.call @array_func() : () -> !fir.array<?xf32>
+  %shape = fir.shape %n : (index) -> !fir.shape<1>
+  // expected-error@+1 {{'fir.save_result' op length parameters must not be provided for this value type}}
+  fir.save_result %res to %buffer(%shape) typeparams %n : !fir.array<?xf32>, !fir.ref<!fir.array<?xf32>>, !fir.shape<1>, index
+  return
+}
+
+// -----
+
+func @bad_save_result(%buffer : !fir.ref<!fir.array<?x!fir.char<1,?>>>, %n : index) {
+  %res = fir.call @array_func() : () -> !fir.array<?x!fir.char<1,?>>
+  %shape = fir.shape %n : (index) -> !fir.shape<1>
+  // expected-error@+1 {{'fir.save_result' op no more than one length parameter must be provided for character value}}
+  fir.save_result %res to %buffer(%shape) typeparams %n, %n : !fir.array<?x!fir.char<1,?>>, !fir.ref<!fir.array<?x!fir.char<1,?>>>, !fir.shape<1>, index, index
+  return
+}


        


More information about the flang-commits mailing list