[flang-commits] [flang] 67360de - [flang][fir] Add array value operations.
Eric Schweitz via flang-commits
flang-commits at lists.llvm.org
Thu Feb 25 19:17:05 PST 2021
Author: Eric Schweitz
Date: 2021-02-25T19:16:43-08:00
New Revision: 67360decc3d4bda363c2fa2550eb1c2b876c8cf0
URL: https://github.com/llvm/llvm-project/commit/67360decc3d4bda363c2fa2550eb1c2b876c8cf0
DIFF: https://github.com/llvm/llvm-project/commit/67360decc3d4bda363c2fa2550eb1c2b876c8cf0.diff
LOG: [flang][fir] Add array value operations.
We lower expressions with rank > 0 to a set of high-level array operations.
These operations are then analyzed and refined to more primitive
operations in subsequent pass(es).
This patch upstreams these array operations and some other helper ops.
Authors: Eric Schweitz, Rajan Walia, Kiran Chandramohan, et al.
https://github.com/flang-compiler/f18-llvm-project/pull/565
Differential Revision: https://reviews.llvm.org/D97421
Added:
Modified:
flang/include/flang/Optimizer/Dialect/FIROps.td
flang/include/flang/Optimizer/Dialect/FIRType.h
flang/lib/Optimizer/Dialect/FIROps.cpp
flang/lib/Optimizer/Dialect/FIRType.cpp
flang/test/Fir/fir-ops.fir
Removed:
################################################################################
diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
index c7f4353abc9b..32b9d341e89b 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -19,7 +19,6 @@ include "mlir/Interfaces/CallInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/LoopLikeInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
-
include "flang/Optimizer/Dialect/FIRTypes.td"
// Base class for FIR operations.
@@ -1495,9 +1494,263 @@ def fir_BoxTypeDescOp : fir_SimpleOneResultOp<"box_tdesc", [NoSideEffect]> {
let results = (outs fir_TypeDescType);
}
+//===----------------------------------------------------------------------===//
+// Array value operations
+//===----------------------------------------------------------------------===//
+
+def fir_ArrayLoadOp : fir_Op<"array_load", [AttrSizedOperandSegments]> {
+
+  let summary = "Load an array as a value.";
+
+  let description = [{
+    Load an entire array as a single SSA value.
+
+    ```fortran
+      real :: a(o:n,p:m)
+      ...
+      ... = ... a ...
+    ```
+
+    One can use `fir.array_load` to produce an SSA value that captures an
+    immutable value of the entire array `a`, as in the Fortran array expression
+    shown above. Subsequent changes to the memory containing the array do not
+    alter its composite value. This operation lets one load an array as a
+    value while applying a runtime shape, shift, or slice to the memory
+    reference, and its semantics guarantee immutability.
+
+    ```mlir
+      // fir.shape_shift takes (lower bound, extent) pairs; %e1 and %e2 hold
+      // the extents of the two dimensions of 'a' (max(n-o+1,0), max(m-p+1,0)).
+      %s = fir.shape_shift %o, %e1, %p, %e2 : (index, index, index, index) -> !fir.shapeshift<2>
+      // load the entire array 'a'
+      %v = fir.array_load %a(%s) : (!fir.ref<!fir.array<?x?xf32>>, !fir.shapeshift<2>) -> !fir.array<?x?xf32>
+      // a fir.store here into array %a does not change %v
+    ```
+  }];
+
+  let arguments = (ins
+    Arg<AnyRefOrBox, "", [MemRead]>:$memref,
+    Optional<AnyShapeOrShiftType>:$shape,
+    Optional<fir_SliceType>:$slice,
+    Variadic<AnyIntegerType>:$lenParams
+  );
+
+  let results = (outs fir_SequenceType);
+
+  let assemblyFormat = [{
+    $memref (`(`$shape^`)`)? (`[`$slice^`]`)? (`typeparams` $lenParams^)? attr-dict `:` functional-type(operands, results)
+  }];
+
+  let verifier = [{ return ::verify(*this); }];
+
+  let extraClassDeclaration = [{
+    // Extent values of the `shape` operand, as recorded by its defining op.
+    std::vector<mlir::Value> getExtents();
+  }];
+}
+
+def fir_ArrayFetchOp : fir_Op<"array_fetch", [NoSideEffect]> {
+
+  let summary = "Fetch the value of an element of an array value";
+
+  let description = [{
+    Fetch the value of an element in an array value.
+
+    ```fortran
+      real :: a(n,m)
+      ...
+      ... a ...
+      ... a(r,s+1) ...
+    ```
+
+    One can use `fir.array_fetch` to fetch the (implied) value of `a(i,j)` in
+    an array expression as shown above. It can also be used to extract the
+    element `a(r,s+1)` in the second expression.
+
+    ```mlir
+      %s = fir.shape %n, %m : (index, index) -> !fir.shape<2>
+      // load the entire array 'a'
+      %v = fir.array_load %a(%s) : (!fir.ref<!fir.array<?x?xf32>>, !fir.shape<2>) -> !fir.array<?x?xf32>
+      // fetch the value of one of the array value's elements
+      %1 = fir.array_fetch %v, %i, %j : (!fir.array<?x?xf32>, index, index) -> f32
+    ```
+
+    It is only possible to use `array_fetch` on an `array_load` result value.
+  }];
+
+  let arguments = (ins
+    fir_SequenceType:$sequence,
+    Variadic<AnyCoordinateType>:$indices
+  );
+
+  let results = (outs AnyType:$element);
+
+  let assemblyFormat = [{
+    $sequence `,` $indices attr-dict `:` functional-type(operands, results)
+  }];
+
+  let verifier = [{
+    auto arrTy = sequence().getType().cast<fir::SequenceType>();
+    if (indices().size() != arrTy.getDimension())
+      return emitOpError("number of indices != dimension of array");
+    if (element().getType() != arrTy.getEleTy())
+      return emitOpError("return type does not match array");
+    // A block-argument sequence has no defining op. isa<> asserts on a null
+    // pointer, so reject that case with a diagnostic instead of crashing.
+    auto *seqDef = sequence().getDefiningOp();
+    if (!seqDef || !isa<fir::ArrayLoadOp>(seqDef))
+      return emitOpError("argument #0 must be result of fir.array_load");
+    return mlir::success();
+  }];
+}
+
+def fir_ArrayUpdateOp : fir_Op<"array_update", [NoSideEffect]> {
+
+  let summary = "Update the value of an element of an array value";
+
+  let description = [{
+    Produce a new array value that equals the input array value at every
+    position except one: the element selected by the indices, which takes the
+    value supplied to the update.
+
+    ```fortran
+      real :: a(n,m)
+      ...
+      a = ...
+    ```
+
+    In an array assignment such as the one above, `fir.array_update` records
+    the (implied) new value of `a(i,j)`.
+
+    ```mlir
+      %s = fir.shape %n, %m : (index, index) -> !fir.shape<2>
+      // load the entire array 'a'
+      %v = fir.array_load %a(%s) : (!fir.ref<!fir.array<?x?xf32>>, !fir.shape<2>) -> !fir.array<?x?xf32>
+      // update the value of one of the array value's elements
+      // %r_{ij} = %f if (i,j) = (%i,%j), %v_{ij} otherwise
+      %r = fir.array_update %v, %f, %i, %j : (!fir.array<?x?xf32>, f32, index, index) -> !fir.array<?x?xf32>
+      fir.array_merge_store %v, %r to %a : !fir.ref<!fir.array<?x?xf32>>
+    ```
+
+    Conceptually, each update adds a mapping from an index tuple to a new
+    value on top of the previous mappings. These mappings live only in the
+    SSA value; nothing reaches memory until a `fir.array_merge_store` is
+    performed.
+  }];
+
+  let arguments = (ins
+    fir_SequenceType:$sequence,
+    AnyType:$merge,
+    Variadic<AnyCoordinateType>:$indices
+  );
+
+  let results = (outs fir_SequenceType);
+
+  let assemblyFormat = [{
+    $sequence `,` $merge `,` $indices attr-dict `:` functional-type(operands, results)
+  }];
+
+  let verifier = [{
+    auto seqTy = sequence().getType().cast<fir::SequenceType>();
+    // The merged value must be exactly the array's element type.
+    auto eleTy = seqTy.getEleTy();
+    if (merge().getType() != eleTy)
+      return emitOpError("merged value does not have element type");
+    // One index per dimension of the array.
+    auto rank = seqTy.getDimension();
+    if (indices().size() != rank)
+      return emitOpError("number of indices != dimension of array");
+    return mlir::success();
+  }];
+}
+
+def fir_ArrayMergeStoreOp : fir_Op<"array_merge_store", [
+    TypesMatchWith<"type of 'original' matches element type of 'memref'",
+                   "memref", "original",
+                   "fir::dyn_cast_ptrOrBoxEleTy($_self)">,
+    TypesMatchWith<"type of 'sequence' matches element type of 'memref'",
+                   "memref", "sequence",
+                   "fir::dyn_cast_ptrOrBoxEleTy($_self)">]> {
+
+  let summary = "Store merged array value to memory.";
+
+  let description = [{
+    Store a merged array value to memory.
+
+    ```fortran
+      real :: a(n,m)
+      ...
+      a = ...
+    ```
+
+    One can use `fir.array_merge_store` to merge/copy the value of `a` in an
+    array expression as shown above.
+
+    ```mlir
+      %v = fir.array_load %a(%shape) : ...
+      %r = fir.array_update %v, %f, %i, %j : (!fir.array<?x?xf32>, f32, index, index) -> !fir.array<?x?xf32>
+      fir.array_merge_store %v, %r to %a : !fir.ref<!fir.array<?x?xf32>>
+    ```
+
+    This operation merges the original loaded array value, `%v`, with the
+    chained updates, `%r`, and stores the result to the array at address, `%a`.
+  }];
+
+  let arguments = (ins
+    fir_SequenceType:$original,
+    fir_SequenceType:$sequence,
+    Arg<AnyRefOrBox, "", [MemWrite]>:$memref
+  );
+
+  let assemblyFormat = "$original `,` $sequence `to` $memref attr-dict `:` type($memref)";
+
+  let verifier = [{
+    // `original` may be a block argument with no defining op. isa<> asserts
+    // on a null pointer, so report a diagnostic for that case instead.
+    auto *origDef = original().getDefiningOp();
+    if (!origDef || !isa<fir::ArrayLoadOp>(origDef))
+      return emitOpError("operand #0 must be result of a fir.array_load op");
+    return mlir::success();
+  }];
+}
+
+//===----------------------------------------------------------------------===//
// Record and array type operations
+//===----------------------------------------------------------------------===//
+
+def fir_ArrayCoorOp : fir_Op<"array_coor", [NoSideEffect, AttrSizedOperandSegments]> {
+
+  let summary = "Find the coordinate of an element of an array";
+
+  let description = [{
+    Computes the address of a single array element when the array's shape may
+    only be known at run time.
+
+    A single high-level op carries everything that address computation needs:
+    the memory reference, an optional shape (or shift), an optional slice,
+    the element indices, and any `typeparams` values. For example, given
+
+    ```fortran
+      real :: a(n,m)
+      ...
+      ... a(i,j) ...
+    ```
+
+    the address of `a(i,j)` can be obtained with `fir.array_coor`:
+
+    ```mlir
+      %s = fir.shape %n, %m : (index, index) -> !fir.shape<2>
+      %1 = fir.array_coor %a(%s) %i, %j : (!fir.ref<!fir.array<?x?xf32>>, !fir.shape<2>, index, index) -> !fir.ref<f32>
+    ```
+  }];
+
+  let arguments = (ins
+    AnyRefOrBox:$memref,
+    Optional<AnyShapeOrShiftType>:$shape,
+    Optional<fir_SliceType>:$slice,
+    Variadic<AnyCoordinateType>:$indices,
+    Variadic<AnyIntegerType>:$lenParams
+  );
+
+  let results = (outs fir_ReferenceType);
+
+  let assemblyFormat = [{
+    $memref (`(`$shape^`)`)? (`[`$slice^`]`)? $indices (`typeparams` $lenParams^)? attr-dict `:` functional-type(operands, results)
+  }];
+
+  // Rank and shape/shift consistency checks live in FIROps.cpp.
+  let verifier = [{ return ::verify(*this); }];
+}
def fir_CoordinateOp : fir_Op<"coordinate_of", [NoSideEffect]> {
+
let summary = "Finds the coordinate (location) of a value in memory";
let description = [{
@@ -1674,18 +1927,218 @@ def fir_FieldIndexOp : fir_OneResultOp<"field_index", [NoSideEffect]> {
}
}];
- let builders = [
- OpBuilderDAG<(ins "StringRef":$fieldName, "Type":$recTy,
- CArg<"ValueRange", "{}">:$operands),
+ let builders = [OpBuilderDAG<(ins "llvm::StringRef":$fieldName,
+ "mlir::Type":$recTy, CArg<"mlir::ValueRange","{}">:$operands),
[{
- $_state.addAttribute(fieldAttrName(), $_builder.getStringAttr(fieldName));
+ $_state.addAttribute(fieldAttrName(),
+ $_builder.getStringAttr(fieldName));
$_state.addAttribute(typeAttrName(), TypeAttr::get(recTy));
$_state.addOperands(operands);
- }]>];
+ }]
+ >];
let extraClassDeclaration = [{
static constexpr llvm::StringRef fieldAttrName() { return "field_id"; }
static constexpr llvm::StringRef typeAttrName() { return "on_type"; }
+ llvm::StringRef getFieldName() { return field_id(); }
+ }];
+}
+
+def fir_ShapeOp : fir_Op<"shape", [NoSideEffect]> {
+
+  let summary = "generate an abstract shape vector of type `!fir.shape`";
+
+  let description = [{
+    Builds a shape from an ordered list of integer values, one runtime extent
+    per array dimension, listed in the same row-to-column order Fortran uses.
+    The resulting abstract shape is meant to be applied to a reified object,
+    so every extent has to be supplied. Extents must be nonnegative.
+
+    ```mlir
+      %d = fir.shape %row_sz, %col_sz : (index, index) -> !fir.shape<2>
+    ```
+  }];
+
+  let arguments = (ins Variadic<AnyIntegerType>:$extents);
+
+  let results = (outs fir_ShapeType);
+
+  let assemblyFormat = [{
+    operands attr-dict `:` functional-type(operands, results)
+  }];
+
+  let verifier = [{
+    auto shapeTy = getType().dyn_cast<fir::ShapeType>();
+    assert(shapeTy && "must be a shape type");
+    // Exactly one extent operand per dimension of the result type.
+    if (extents().size() != shapeTy.getRank())
+      return emitOpError("shape type rank mismatch");
+    return mlir::success();
+  }];
+
+  let extraClassDeclaration = [{
+    // Copy of the extent operands, in dimension order.
+    std::vector<mlir::Value> getExtents() {
+      auto range = extents();
+      return std::vector<mlir::Value>(range.begin(), range.end());
+    }
+  }];
+}
+
+def fir_ShapeShiftOp : fir_Op<"shape_shift", [NoSideEffect]> {
+
+  let summary = [{
+    generate an abstract shape and shift vector of type `!fir.shapeshift`
+  }];
+
+  let description = [{
+    Takes an ordered list of integer values whose length is a multiple of 2.
+    Each consecutive pair describes one dimension: first its lower bound, then
+    its extent. Dimensions appear in the same row-to-column order as Fortran.
+    Because this abstract shifted shape is applied to a reified object, every
+    value must be supplied. Extents must be nonnegative.
+
+    ```mlir
+      %d = fir.shape_shift %lo, %extent : (index, index) -> !fir.shapeshift<1>
+    ```
+  }];
+
+  let arguments = (ins Variadic<AnyIntegerType>:$pairs);
+
+  let results = (outs fir_ShapeShiftType);
+
+  let assemblyFormat = [{
+    operands attr-dict `:` functional-type(operands, results)
+  }];
+
+  let verifier = [{
+    auto numArgs = pairs().size();
+    // Between 1 and 16 dimensions, two operands each.
+    bool badCount = numArgs < 2 || numArgs > 2 * 16;
+    if (badCount)
+      return emitOpError("incorrect number of args");
+    if (numArgs % 2)
+      return emitOpError("requires a multiple of 2 args");
+    auto shapeTy = getType().dyn_cast<fir::ShapeShiftType>();
+    assert(shapeTy && "must be a shape shift type");
+    if (numArgs != 2 * shapeTy.getRank())
+      return emitOpError("shape type rank mismatch");
+    return mlir::success();
+  }];
+
+  let extraClassDeclaration = [{
+    // Lower bounds: the operands at even positions (0, 2, 4, ...).
+    std::vector<mlir::Value> getOrigins() {
+      std::vector<mlir::Value> lbs;
+      auto vals = pairs();
+      for (unsigned pos = 0, end = vals.size(); pos < end; pos += 2)
+        lbs.push_back(vals[pos]);
+      return lbs;
+    }
+
+    // Extents: the operands at odd positions (1, 3, 5, ...).
+    std::vector<mlir::Value> getExtents() {
+      std::vector<mlir::Value> exts;
+      auto vals = pairs();
+      for (unsigned pos = 1, end = vals.size(); pos < end; pos += 2)
+        exts.push_back(vals[pos]);
+      return exts;
+    }
+  }];
+}
+
+def fir_ShiftOp : fir_Op<"shift", [NoSideEffect]> {
+
+  let summary = "generate an abstract shift vector of type `!fir.shift`";
+
+  let description = [{
+    Builds a shift from an ordered list of integer values, one runtime lower
+    bound per array dimension, listed in Fortran's row-to-column order. The
+    resulting abstract shift is applied to a reified object, so every lower
+    bound has to be supplied.
+
+    ```mlir
+      %d = fir.shift %row_lb, %col_lb : (index, index) -> !fir.shift<2>
+    ```
+  }];
+
+  let arguments = (ins Variadic<AnyIntegerType>:$origins);
+
+  let results = (outs fir_ShiftType);
+
+  let assemblyFormat = [{
+    operands attr-dict `:` functional-type(operands, results)
+  }];
+
+  let verifier = [{
+    auto shiftTy = getType().dyn_cast<fir::ShiftType>();
+    assert(shiftTy && "must be a shift type");
+    // Exactly one lower-bound operand per dimension of the result type.
+    if (origins().size() != shiftTy.getRank())
+      return emitOpError("shift type rank mismatch");
+    return mlir::success();
+  }];
+
+  let extraClassDeclaration = [{
+    // Copy of the lower-bound operands, in dimension order.
+    std::vector<mlir::Value> getOrigins() {
+      auto range = origins();
+      return std::vector<mlir::Value>(range.begin(), range.end());
+    }
+  }];
+}
+
+def fir_SliceOp : fir_Op<"slice", [NoSideEffect, AttrSizedOperandSegments]> {
+
+  let summary = "generate an abstract slice vector of type `!fir.slice`";
+
+  let description = [{
+    Takes an ordered list of integer values whose length is a multiple of 3.
+    Each consecutive triple describes one dimension, in Fortran syntax order:
+    lower bound, upper bound, then stride. Both bounds are inclusive, and
+    dimensions are listed in Fortran's row-to-column order. Because the slice
+    is applied to a reified object, every value must be supplied. Extents must
+    be nonnegative and strides must be nonzero.
+
+    ```mlir
+      %d = fir.slice %lo, %hi, %step : (index, index, index) -> !fir.slice<1>
+    ```
+
+    For generalized slicing of Fortran's dynamic derived types, an optional
+    component path may follow the triples. It narrows the product type of the
+    original array down to the elemental type of the sliced projection.
+
+    ```mlir
+      %fld = fir.field_index component, !fir.type<t{...component:ct...}>
+      %d = fir.slice %lo, %hi, %step path %fld : (index, index, index, !fir.field) -> !fir.slice<1>
+    ```
+  }];
+
+  let arguments = (ins
+    Variadic<AnyCoordinateType>:$triples,
+    Variadic<AnyComponentType>:$fields
+  );
+
+  let results = (outs fir_SliceType);
+
+  let assemblyFormat = [{
+    $triples (`path` $fields^)? attr-dict `:` functional-type(operands, results)
+  }];
+
+  let verifier = [{
+    auto numArgs = triples().size();
+    // Between 1 and 16 dimensions, three operands each.
+    bool badCount = numArgs < 3 || numArgs > 3 * 16;
+    if (badCount)
+      return emitOpError("incorrect number of args for triple");
+    if (numArgs % 3)
+      return emitOpError("requires a multiple of 3 args");
+    auto sliceTy = getType().dyn_cast<fir::SliceType>();
+    assert(sliceTy && "must be a slice type");
+    if (numArgs != 3 * sliceTy.getRank())
+      return emitOpError("slice type rank mismatch");
+    return mlir::success();
+  }];
+
+  let extraClassDeclaration = [{
+    // Rank of the sliced result, computed from this op's triples.
+    unsigned getOutRank() { return getOutputRank(triples()); }
+    static unsigned getOutputRank(mlir::ValueRange triples);
}];
}
diff --git a/flang/include/flang/Optimizer/Dialect/FIRType.h b/flang/include/flang/Optimizer/Dialect/FIRType.h
index 2477b07d1d08..ca0dddd2a2c7 100644
--- a/flang/include/flang/Optimizer/Dialect/FIRType.h
+++ b/flang/include/flang/Optimizer/Dialect/FIRType.h
@@ -10,8 +10,8 @@
//
//===----------------------------------------------------------------------===//
-#ifndef OPTIMIZER_DIALECT_FIRTYPE_H
-#define OPTIMIZER_DIALECT_FIRTYPE_H
+#ifndef FORTRAN_OPTIMIZER_DIALECT_FIRTYPE_H
+#define FORTRAN_OPTIMIZER_DIALECT_FIRTYPE_H
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
@@ -23,7 +23,8 @@
namespace llvm {
class raw_ostream;
class StringRef;
-template <typename> class ArrayRef;
+template <typename>
+class ArrayRef;
class hash_code;
} // namespace llvm
@@ -80,6 +81,10 @@ bool isa_aggregate(mlir::Type t);
/// not a memory reference type, then returns a null `Type`.
mlir::Type dyn_cast_ptrEleTy(mlir::Type t);
+/// Extract the `Type` pointed to from a FIR memory reference or box type.
+/// For `fir::ReferenceType`, `fir::PointerType`, and `fir::HeapType`, this is
+/// the element type. For `fir::BoxType`, it is the boxed element type; when
+/// that element type is itself one `dyn_cast_ptrEleTy` accepts, the type it
+/// points to is returned instead. If `t` is none of these, returns a null
+/// `Type`.
+mlir::Type dyn_cast_ptrOrBoxEleTy(mlir::Type t);
+
/// Is `t` a FIR Real or MLIR Float type?
inline bool isa_real(mlir::Type t) {
return t.isa<fir::RealType>() || t.isa<mlir::FloatType>();
@@ -125,4 +130,4 @@ inline bool singleIndirectionLevel(mlir::Type ty) {
} // namespace fir
-#endif // OPTIMIZER_DIALECT_FIRTYPE_H
+#endif // FORTRAN_OPTIMIZER_DIALECT_FIRTYPE_H
diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index ed053fd2b1a6..80f1a1d83a18 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -5,6 +5,10 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
+//
+// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
+//
+//===----------------------------------------------------------------------===//
#include "flang/Optimizer/Dialect/FIROps.h"
#include "flang/Optimizer/Dialect/FIRAttr.h"
@@ -115,6 +119,90 @@ mlir::Type fir::AllocMemOp::wrapResultType(mlir::Type intype) {
return HeapType::get(intype);
}
+//===----------------------------------------------------------------------===//
+// ArrayCoorOp
+//===----------------------------------------------------------------------===//
+
+/// Verify fir.array_coor.
+///
+/// The memref must designate a `fir::SequenceType`. If a shape operand is
+/// present, its rank must agree with the array (unless the array reports a
+/// dimension count of 0) and with the number of indices, and a fir.shift is
+/// only accepted with a fir.box memref. If a slice operand is present, its
+/// rank must equal the array's dimension count.
+static mlir::LogicalResult verify(fir::ArrayCoorOp op) {
+  auto memrefTy = op.memref().getType();
+  auto seqTy =
+      fir::dyn_cast_ptrOrBoxEleTy(memrefTy).dyn_cast<fir::SequenceType>();
+  if (!seqTy)
+    return op.emitOpError("must be a reference to an array");
+  auto rank = seqTy.getDimension();
+
+  if (auto shapeVal = op.shape()) {
+    auto shapeTy = shapeVal.getType();
+    unsigned shapeRank = 0;
+    if (auto shTy = shapeTy.dyn_cast<fir::ShapeType>()) {
+      shapeRank = shTy.getRank();
+    } else if (auto shShTy = shapeTy.dyn_cast<fir::ShapeShiftType>()) {
+      shapeRank = shShTy.getRank();
+    } else {
+      // The remaining case is a fir.shift, which needs a boxed memref.
+      shapeRank = shapeTy.cast<fir::ShiftType>().getRank();
+      if (!memrefTy.isa<fir::BoxType>())
+        return op.emitOpError("shift can only be provided with fir.box memref");
+    }
+    // An array dimension count of 0 skips the rank comparison
+    // (presumably an unknown-rank array -- confirm).
+    if (rank != 0 && rank != shapeRank)
+      return op.emitOpError("rank of dimension mismatched");
+    if (op.indices().size() != shapeRank)
+      return op.emitOpError("number of indices do not match dim rank");
+  }
+
+  auto sliceVal = op.slice();
+  if (!sliceVal)
+    return mlir::success();
+  auto sliceTy = sliceVal.getType().dyn_cast<fir::SliceType>();
+  if (sliceTy && sliceTy.getRank() != rank)
+    return op.emitOpError("rank of dimension in slice mismatched");
+  return mlir::success();
+}
+
+//===----------------------------------------------------------------------===//
+// ArrayLoadOp
+//===----------------------------------------------------------------------===//
+
+/// Return the extent values of this load's `shape` operand when the op that
+/// defines it records them.
+///
+/// A defining `fir.shape` or `fir.shape_shift` supplies extents. A `fir.shift`
+/// carries only lower bounds (origins), so it yields an empty vector, as do a
+/// missing shape operand and a shape value with no defining op (e.g. a block
+/// argument).
+std::vector<mlir::Value> fir::ArrayLoadOp::getExtents() {
+  auto sh = shape();
+  if (!sh)
+    return {};
+  auto *op = sh.getDefiningOp();
+  if (!op)
+    return {};
+  if (auto shOp = dyn_cast<fir::ShapeOp>(op))
+    return shOp.getExtents();
+  if (auto shShOp = dyn_cast<fir::ShapeShiftOp>(op))
+    return shShOp.getExtents();
+  // The verifier accepts a fir.shift here (with a fir.box memref). It has no
+  // extents to report, so do not cast it to ShapeShiftOp.
+  return {};
+}
+
+/// Verify fir.array_load.
+///
+/// The memref must designate a `fir::SequenceType`. If a shape operand is
+/// present, its rank must agree with the array (unless the array reports a
+/// dimension count of 0), and a fir.shift is only accepted with a fir.box
+/// memref. If a slice operand is present, its rank must equal the array's
+/// dimension count.
+static mlir::LogicalResult verify(fir::ArrayLoadOp op) {
+  auto memrefTy = op.memref().getType();
+  auto seqTy =
+      fir::dyn_cast_ptrOrBoxEleTy(memrefTy).dyn_cast<fir::SequenceType>();
+  if (!seqTy)
+    return op.emitOpError("must be a reference to an array");
+  auto rank = seqTy.getDimension();
+
+  if (auto shapeVal = op.shape()) {
+    auto shapeTy = shapeVal.getType();
+    unsigned shapeRank = 0;
+    if (auto shTy = shapeTy.dyn_cast<fir::ShapeType>()) {
+      shapeRank = shTy.getRank();
+    } else if (auto shShTy = shapeTy.dyn_cast<fir::ShapeShiftType>()) {
+      shapeRank = shShTy.getRank();
+    } else {
+      // The remaining case is a fir.shift, which needs a boxed memref.
+      shapeRank = shapeTy.cast<fir::ShiftType>().getRank();
+      if (!memrefTy.isa<fir::BoxType>())
+        return op.emitOpError("shift can only be provided with fir.box memref");
+    }
+    // An array dimension count of 0 skips the rank comparison
+    // (presumably an unknown-rank array -- confirm).
+    if (rank != 0 && rank != shapeRank)
+      return op.emitOpError("rank of dimension mismatched");
+  }
+
+  auto sliceVal = op.slice();
+  if (!sliceVal)
+    return mlir::success();
+  auto sliceTy = sliceVal.getType().dyn_cast<fir::SliceType>();
+  if (sliceTy && sliceTy.getRank() != rank)
+    return op.emitOpError("rank of dimension in slice mismatched");
+  return mlir::success();
+}
+
//===----------------------------------------------------------------------===//
// BoxAddrOp
//===----------------------------------------------------------------------===//
diff --git a/flang/lib/Optimizer/Dialect/FIRType.cpp b/flang/lib/Optimizer/Dialect/FIRType.cpp
index 4cbdf15f38d6..d3b9a2fdb05b 100644
--- a/flang/lib/Optimizer/Dialect/FIRType.cpp
+++ b/flang/lib/Optimizer/Dialect/FIRType.cpp
@@ -223,6 +223,19 @@ mlir::Type dyn_cast_ptrEleTy(mlir::Type t) {
.Default([](mlir::Type) { return mlir::Type{}; });
}
+/// Strip FIR indirection from `t` to reach the referenced type.
+///
+/// Reference, pointer, and heap types yield their element type. A box yields
+/// its element type, except that when `dyn_cast_ptrEleTy` accepts that
+/// element type, the result of that call is returned instead. Any other type
+/// yields a null `mlir::Type`.
+mlir::Type dyn_cast_ptrOrBoxEleTy(mlir::Type t) {
+  if (auto refTy = t.dyn_cast<fir::ReferenceType>())
+    return refTy.getEleTy();
+  if (auto ptrTy = t.dyn_cast<fir::PointerType>())
+    return ptrTy.getEleTy();
+  if (auto heapTy = t.dyn_cast<fir::HeapType>())
+    return heapTy.getEleTy();
+  if (auto boxTy = t.dyn_cast<fir::BoxType>()) {
+    mlir::Type boxed = boxTy.getEleTy();
+    mlir::Type pointee = fir::dyn_cast_ptrEleTy(boxed);
+    return pointee ? pointee : boxed;
+  }
+  return mlir::Type{};
+}
+
} // namespace fir
namespace {
diff --git a/flang/test/Fir/fir-ops.fir b/flang/test/Fir/fir-ops.fir
index 3e8c81cd8d1b..cbfe31879030 100644
--- a/flang/test/Fir/fir-ops.fir
+++ b/flang/test/Fir/fir-ops.fir
@@ -618,5 +618,17 @@ func @test_misc_ops(%arr1 : !fir.ref<!fir.array<?x?xf32>>, %m : index, %n : inde
// CHECK: [[ARR2:%.*]] = fir.zero_bits !fir.array<10xi32>
%arr2 = fir.zero_bits !fir.array<10xi32>
+
+ // CHECK: [[SHAPE:%.*]] = fir.shape_shift [[INDXM:%.*]], [[INDXN:%.*]], [[INDXO:%.*]], [[INDXP:%.*]] : (index, index, index, index) -> !fir.shapeshift<2>
+ // CHECK: [[AV1:%.*]] = fir.array_load [[ARR1]]([[SHAPE]]) : (!fir.ref<!fir.array<?x?xf32>>, !fir.shapeshift<2>) -> !fir.array<?x?xf32>
+ // CHECK: [[FVAL:%.*]] = fir.array_fetch [[AV1]], [[I10]], [[J20]] : (!fir.array<?x?xf32>, index, index) -> f32
+ // CHECK: [[AV2:%.*]] = fir.array_update [[AV1]], [[FVAL]], [[I10]], [[J20]] : (!fir.array<?x?xf32>, f32, index, index) -> !fir.array<?x?xf32>
+ // CHECK: fir.array_merge_store [[AV1]], [[AV2]] to [[ARR1]] : !fir.ref<!fir.array<?x?xf32>>
+ %s = fir.shape_shift %m, %n, %o, %p : (index, index, index, index) -> !fir.shapeshift<2>
+ %av1 = fir.array_load %arr1(%s) : (!fir.ref<!fir.array<?x?xf32>>, !fir.shapeshift<2>) -> !fir.array<?x?xf32>
+ %f = fir.array_fetch %av1, %i10, %j20 : (!fir.array<?x?xf32>, index, index) -> f32
+ %av2 = fir.array_update %av1, %f, %i10, %j20 : (!fir.array<?x?xf32>, f32, index, index) -> !fir.array<?x?xf32>
+ fir.array_merge_store %av1, %av2 to %arr1 : !fir.ref<!fir.array<?x?xf32>>
+
return
}
More information about the flang-commits
mailing list