[flang-commits] [flang] 34d6c18 - [fir] Update fir.array_update, fir.array_fetch and fir.array_merge_store
Valentin Clement via flang-commits
flang-commits at lists.llvm.org
Thu Sep 30 01:14:07 PDT 2021
Author: Eric Schweitz
Date: 2021-09-30T09:56:50+02:00
New Revision: 34d6c1822eebe2c69c850e14d48e986a5b68cbd6
URL: https://github.com/llvm/llvm-project/commit/34d6c1822eebe2c69c850e14d48e986a5b68cbd6
DIFF: https://github.com/llvm/llvm-project/commit/34d6c1822eebe2c69c850e14d48e986a5b68cbd6.diff
LOG: [fir] Update fir.array_update, fir.array_fetch and fir.array_merge_store
Add typeparams to the fir.array_update, fir.array_fetch and
fir.array_merge_store operations. Add an optional slice operand to the
fir.array_merge_store op.
Move the verifiers to FIROps.cpp.
Reviewed By: kiranchandramohan
Differential Revision: https://reviews.llvm.org/D110701
Co-authored-by: Valentin Clement <clementval at gmail.com>
Added:
flang/include/flang/Optimizer/Support/Utils.h
Modified:
flang/include/flang/Optimizer/Dialect/FIROps.td
flang/include/flang/Optimizer/Dialect/FIRType.h
flang/lib/Optimizer/Dialect/FIROps.cpp
flang/test/Fir/fir-ops.fir
flang/test/Fir/invalid.fir
Removed:
################################################################################
diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
index 09c81d3883467..470a1ff4df4c3 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -1658,7 +1658,8 @@ def fir_ArrayLoadOp : fir_Op<"array_load", [AttrSizedOperandSegments]> {
}];
}
-def fir_ArrayFetchOp : fir_Op<"array_fetch", [NoSideEffect]> {
+def fir_ArrayFetchOp : fir_Op<"array_fetch", [AttrSizedOperandSegments,
+ NoSideEffect]> {
let summary = "Fetch the value of an element of an array value";
@@ -1689,28 +1690,22 @@ def fir_ArrayFetchOp : fir_Op<"array_fetch", [NoSideEffect]> {
let arguments = (ins
fir_SequenceType:$sequence,
- Variadic<AnyCoordinateType>:$indices
+ Variadic<AnyCoordinateType>:$indices,
+ Variadic<AnyIntegerType>:$typeparams
);
let results = (outs AnyType:$element);
let assemblyFormat = [{
- $sequence `,` $indices attr-dict `:` functional-type(operands, results)
+ $sequence `,` $indices (`typeparams` $typeparams^)? attr-dict `:`
+ functional-type(operands, results)
}];
- let verifier = [{
- auto arrTy = sequence().getType().cast<fir::SequenceType>();
- if (indices().size() != arrTy.getDimension())
- return emitOpError("number of indices != dimension of array");
- if (element().getType() != arrTy.getEleTy())
- return emitOpError("return type does not match array");
- if (!isa<fir::ArrayLoadOp>(sequence().getDefiningOp()))
- return emitOpError("argument #0 must be result of fir.array_load");
- return mlir::success();
- }];
+ let verifier = "return ::verify(*this);";
}
-def fir_ArrayUpdateOp : fir_Op<"array_update", [NoSideEffect]> {
+def fir_ArrayUpdateOp : fir_Op<"array_update", [AttrSizedOperandSegments,
+ NoSideEffect]> {
let summary = "Update the value of an element of an array value";
@@ -1747,32 +1742,22 @@ def fir_ArrayUpdateOp : fir_Op<"array_update", [NoSideEffect]> {
let arguments = (ins
fir_SequenceType:$sequence,
AnyType:$merge,
- Variadic<AnyCoordinateType>:$indices
+ Variadic<AnyCoordinateType>:$indices,
+ Variadic<AnyIntegerType>:$typeparams
);
let results = (outs fir_SequenceType);
let assemblyFormat = [{
- $sequence `,` $merge `,` $indices attr-dict `:` functional-type(operands, results)
+ $sequence `,` $merge `,` $indices (`typeparams` $typeparams^)? attr-dict
+ `:` functional-type(operands, results)
}];
- let verifier = [{
- auto arrTy = sequence().getType().cast<fir::SequenceType>();
- if (merge().getType() != arrTy.getEleTy())
- return emitOpError("merged value does not have element type");
- if (indices().size() != arrTy.getDimension())
- return emitOpError("number of indices != dimension of array");
- return mlir::success();
- }];
+ let verifier = "return ::verify(*this);";
}
-def fir_ArrayMergeStoreOp : fir_Op<"array_merge_store", [
- TypesMatchWith<"type of 'original' matches element type of 'memref'",
- "memref", "original",
- "fir::dyn_cast_ptrOrBoxEleTy($_self)">,
- TypesMatchWith<"type of 'sequence' matches element type of 'memref'",
- "memref", "sequence",
- "fir::dyn_cast_ptrOrBoxEleTy($_self)">]> {
+def fir_ArrayMergeStoreOp : fir_Op<"array_merge_store",
+ [AttrSizedOperandSegments]> {
let summary = "Store merged array value to memory.";
@@ -1801,16 +1786,17 @@ def fir_ArrayMergeStoreOp : fir_Op<"array_merge_store", [
let arguments = (ins
fir_SequenceType:$original,
fir_SequenceType:$sequence,
- Arg<AnyRefOrBox, "", [MemWrite]>:$memref
+ Arg<AnyRefOrBox, "", [MemWrite]>:$memref,
+ Optional<fir_SliceType>:$slice,
+ Variadic<AnyIntegerType>:$typeparams
);
- let assemblyFormat = "$original `,` $sequence `to` $memref attr-dict `:` type($memref)";
-
- let verifier = [{
- if (!isa<ArrayLoadOp>(original().getDefiningOp()))
- return emitOpError("operand #0 must be result of a fir.array_load op");
- return mlir::success();
+ let assemblyFormat = [{
+ $original `,` $sequence `to` $memref (`[` $slice^ `]`)? (`typeparams`
+ $typeparams^)? attr-dict `:` type(operands)
}];
+
+ let verifier = "return ::verify(*this);";
}
//===----------------------------------------------------------------------===//
diff --git a/flang/include/flang/Optimizer/Dialect/FIRType.h b/flang/include/flang/Optimizer/Dialect/FIRType.h
index 21e592167f847..82c4a43c06843 100644
--- a/flang/include/flang/Optimizer/Dialect/FIRType.h
+++ b/flang/include/flang/Optimizer/Dialect/FIRType.h
@@ -33,6 +33,7 @@ class DialectAsmParser;
class DialectAsmPrinter;
class ComplexType;
class FloatType;
+class ValueRange;
} // namespace mlir
namespace fir {
@@ -122,6 +123,9 @@ inline bool isa_complex(mlir::Type t) {
return t.isa<fir::ComplexType>() || t.isa<mlir::ComplexType>();
}
+/// Returns true when `t` is a fir::CharacterType, whatever its LEN is
+/// (contrast with isa_char_string, which also requires LEN != 1).
+inline bool isa_char(mlir::Type t) {
+  return t.isa<fir::CharacterType>();
+}
+
/// Is `t` a CHARACTER type with a LEN other than 1?
inline bool isa_char_string(mlir::Type t) {
if (auto ct = t.dyn_cast_or_null<fir::CharacterType>())
@@ -134,6 +138,13 @@ inline bool isa_char_string(mlir::Type t) {
/// of unknown rank or type.
bool isa_unknown_size_box(mlir::Type t);
+/// Peel one array level off `t`: a fir::SequenceType yields its element type,
+/// any other type is handed back unchanged.
+inline mlir::Type unwrapSequenceType(mlir::Type t) {
+  auto seqTy = t.dyn_cast<fir::SequenceType>();
+  return seqTy ? seqTy.getEleTy() : t;
+}
+
#ifndef NDEBUG
// !fir.ptr<X> and !fir.heap<X> where X is !fir.ptr, !fir.heap, or !fir.ref
// is undefined and disallowed.
@@ -142,6 +153,11 @@ inline bool singleIndirectionLevel(mlir::Type ty) {
}
#endif
+/// Apply the components specified by `path` to `rootTy` to determine the type
+/// of the resulting component element. `rootTy` should be an aggregate type.
+/// Each step consumes path values according to the current type: a record
+/// takes one fir.field_index (or integer constant), an array takes one
+/// integer-typed value per dimension, a tuple takes one integer constant, and
+/// a complex takes one integer-typed value.
+/// Returns the type reached once `path` is exhausted, or null on error (an
+/// ill-formed step or a type that cannot be indexed).
+mlir::Type applyPathToType(mlir::Type rootTy, mlir::ValueRange path);
+
} // namespace fir
#endif // FORTRAN_OPTIMIZER_DIALECT_FIRTYPE_H
diff --git a/flang/include/flang/Optimizer/Support/Utils.h b/flang/include/flang/Optimizer/Support/Utils.h
new file mode 100644
index 0000000000000..edb14db370a3d
--- /dev/null
+++ b/flang/include/flang/Optimizer/Support/Utils.h
@@ -0,0 +1,26 @@
+//===-- Optimizer/Support/Utils.h -------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef FORTRAN_OPTIMIZER_SUPPORT_UTILS_H
+#define FORTRAN_OPTIMIZER_SUPPORT_UTILS_H
+
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/BuiltinAttributes.h"
+
+namespace fir {
+/// Extract the value held by the constant `cop`, sign-extended to 64 bits.
+/// `cop` is expected to carry an mlir::IntegerAttr; the cast<> below checks
+/// this in assertion-enabled builds.
+inline std::int64_t toInt(mlir::ConstantOp cop) {
+  auto intAttr = cop.getValue().cast<mlir::IntegerAttr>();
+  auto apValue = intAttr.getValue();
+  return apValue.getSExtValue();
+}
+} // namespace fir
+
+#endif // FORTRAN_OPTIMIZER_SUPPORT_UTILS_H
diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index 11ff61e69f350..88101585f63d2 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -14,6 +14,7 @@
#include "flang/Optimizer/Dialect/FIRAttr.h"
#include "flang/Optimizer/Dialect/FIROpsSupport.h"
#include "flang/Optimizer/Dialect/FIRType.h"
+#include "flang/Optimizer/Support/Utils.h"
#include "mlir/Dialect/CommonFolders.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BuiltinOps.h"
@@ -153,6 +154,19 @@ static mlir::LogicalResult verify(fir::ArrayCoorOp op) {
// ArrayLoadOp
//===----------------------------------------------------------------------===//
+/// Normalize a value type before comparing it with an array element type:
+/// a fir::ReferenceType whose pointee is a CHARACTER, derived, or
+/// fir::SequenceType is replaced by that pointee. Every other type, including
+/// references to any other kind of type, is returned unchanged.
+static mlir::Type adjustedElementType(mlir::Type t) {
+  auto refTy = t.dyn_cast<fir::ReferenceType>();
+  if (!refTy)
+    return t;
+  mlir::Type pointee = refTy.getEleTy();
+  const bool unwrap = fir::isa_char(pointee) || fir::isa_derived(pointee) ||
+                      pointee.isa<fir::SequenceType>();
+  return unwrap ? pointee : t;
+}
+
std::vector<mlir::Value> fir::ArrayLoadOp::getExtents() {
if (auto sh = shape())
if (auto *op = sh.getDefiningOp()) {
@@ -195,6 +209,90 @@ static mlir::LogicalResult verify(fir::ArrayLoadOp op) {
return mlir::success();
}
+//===----------------------------------------------------------------------===//
+// ArrayMergeStoreOp
+//===----------------------------------------------------------------------===//
+
+/// Verify fir.array_merge_store.
+///
+/// `original` must be produced by a fir.array_load.
+/// - Without a slice, `original` and `sequence` must both have exactly the
+///   element type of `memref`.
+/// - With a fir.slice that has field components (an intra-object merge), the
+///   memref must reference an array. The element types of `original` and
+///   `sequence` (after unwrapping the sequence level) must equal the
+///   component type selected by those fields.
+/// - A slice without field components, or one not visibly defined by a
+///   fir.slice, is accepted without further type checks.
+static mlir::LogicalResult verify(fir::ArrayMergeStoreOp op) {
+  // `original` may be a block argument with no defining op; never hand a null
+  // pointer to isa<>.
+  auto *originalOp = op.original().getDefiningOp();
+  if (!originalOp || !isa<fir::ArrayLoadOp>(originalOp))
+    return op.emitOpError("operand #0 must be result of a fir.array_load op");
+  if (auto sl = op.slice()) {
+    // Only a slice produced by fir.slice can be inspected. dyn_cast_or_null
+    // avoids asserting on block arguments or other slice-typed producers.
+    auto sliceOp = mlir::dyn_cast_or_null<fir::SliceOp>(sl.getDefiningOp());
+    if (sliceOp && !sliceOp.fields().empty()) {
+      // This is an intra-object merge, where the slice is projecting the
+      // subfields that are to be overwritten by the merge operation.
+      auto eleTy = fir::dyn_cast_ptrOrBoxEleTy(op.memref().getType());
+      auto seqTy = eleTy.dyn_cast_or_null<fir::SequenceType>();
+      if (!seqTy)
+        return op.emitOpError("referenced type is not an array");
+      auto projTy = fir::applyPathToType(seqTy.getEleTy(), sliceOp.fields());
+      if (fir::unwrapSequenceType(op.original().getType()) != projTy)
+        return op.emitOpError(
+            "type of origin does not match sliced memref type");
+      if (fir::unwrapSequenceType(op.sequence().getType()) != projTy)
+        return op.emitOpError(
+            "type of sequence does not match sliced memref type");
+    }
+    return mlir::success();
+  }
+  auto eleTy = fir::dyn_cast_ptrOrBoxEleTy(op.memref().getType());
+  if (op.original().getType() != eleTy)
+    return op.emitOpError("type of origin does not match memref element type");
+  if (op.sequence().getType() != eleTy)
+    return op.emitOpError(
+        "type of sequence does not match memref element type");
+  return mlir::success();
+}
+
+//===----------------------------------------------------------------------===//
+// ArrayFetchOp
+//===----------------------------------------------------------------------===//
+
+/// Shared by the array_fetch and array_update verifiers. Apply `op`'s indices
+/// as a path into its sequence type and return the resulting subobject type,
+/// or a null type if the indices do not type check. The helper is file-local,
+/// like the other helpers in this file.
+template <typename A>
+static mlir::Type validArraySubobject(A op) {
+  return fir::applyPathToType(op.sequence().getType(), op.indices());
+}
+
+/// Verify fir.array_fetch.
+///
+/// There must be at least one index per array dimension; any extra indices
+/// continue into the element type. With exactly one index per dimension, the
+/// adjusted result type must equal the array element type. In every case,
+/// applying the indices to the sequence type must yield the adjusted result
+/// type. `sequence` must also be produced by a fir.array_load.
+static mlir::LogicalResult verify(fir::ArrayFetchOp op) {
+  auto arrTy = op.sequence().getType().cast<fir::SequenceType>();
+  auto indSize = op.indices().size();
+  if (indSize < arrTy.getDimension())
+    return op.emitOpError("number of indices != dimension of array");
+  auto resultTy = ::adjustedElementType(op.element().getType());
+  if (indSize == arrTy.getDimension() && resultTy != arrTy.getEleTy())
+    return op.emitOpError("return type does not match array");
+  auto ty = validArraySubobject(op);
+  if (!ty || ty != resultTy)
+    return op.emitOpError("return type and/or indices do not type check");
+  // `sequence` may be a block argument with no defining op; never hand a null
+  // pointer to isa<>.
+  auto *seqOp = op.sequence().getDefiningOp();
+  if (!seqOp || !isa<fir::ArrayLoadOp>(seqOp))
+    return op.emitOpError("argument #0 must be result of fir.array_load");
+  return mlir::success();
+}
+
+//===----------------------------------------------------------------------===//
+// ArrayUpdateOp
+//===----------------------------------------------------------------------===//
+
+/// Verify fir.array_update.
+///
+/// The indices must cover every array dimension. When their count equals the
+/// rank, the adjusted merge value type must be the array element type. In all
+/// cases it must equal the type reached by applying the indices to the
+/// sequence type.
+static mlir::LogicalResult verify(fir::ArrayUpdateOp op) {
+  auto seqTy = op.sequence().getType().cast<fir::SequenceType>();
+  const auto rank = seqTy.getDimension();
+  const auto numIndices = op.indices().size();
+  if (numIndices < rank)
+    return op.emitOpError("number of indices != dimension of array");
+  const mlir::Type mergeTy = ::adjustedElementType(op.merge().getType());
+  if (numIndices == rank && mergeTy != seqTy.getEleTy())
+    return op.emitOpError("merged value does not have element type");
+  const mlir::Type subobjectTy = validArraySubobject(op);
+  if (!subobjectTy || subobjectTy != mergeTy)
+    return op.emitOpError("merged value and/or indices do not type check");
+  return mlir::success();
+}
+
//===----------------------------------------------------------------------===//
// BoxAddrOp
//===----------------------------------------------------------------------===//
@@ -2197,6 +2295,47 @@ bool fir::valueHasFirAttribute(mlir::Value value,
return false;
}
+/// Walk `path` through the aggregate type `eleTy` and return the type of the
+/// selected component. Each step consumes path values according to the
+/// current type:
+///   - fir::RecordType: one fir.field_index, or an integer constant that is a
+///     valid field position;
+///   - fir::SequenceType: one integer-typed value per dimension;
+///   - mlir::TupleType: one integer constant that is a valid member position;
+///   - fir::ComplexType / mlir::ComplexType: one integer-typed value.
+/// Returns the type reached when `path` is exhausted, or a null type when a
+/// step is ill-formed or the current type cannot be indexed. Verifiers call
+/// this, so malformed constant positions must produce null instead of
+/// tripping the asserting accessors (fir::toInt, RecordType/TupleType
+/// getType).
+mlir::Type fir::applyPathToType(mlir::Type eleTy, mlir::ValueRange path) {
+  // Position held by `op` when it is an mlir::ConstantOp with an integer
+  // attribute whose value lies in [0, bound); -1 otherwise.
+  auto constantPosition = [](mlir::Operation *op,
+                             std::size_t bound) -> std::int64_t {
+    if (auto cst = mlir::dyn_cast<mlir::ConstantOp>(op))
+      if (cst.getValue().isa<mlir::IntegerAttr>()) {
+        const std::int64_t pos = fir::toInt(cst);
+        if (pos >= 0 && static_cast<std::size_t>(pos) < bound)
+          return pos;
+      }
+    return -1;
+  };
+  for (auto i = path.begin(), end = path.end(); eleTy && i < end;) {
+    eleTy = llvm::TypeSwitch<mlir::Type, mlir::Type>(eleTy)
+                .Case<fir::RecordType>([&](fir::RecordType ty) {
+                  if (auto *op = (*i++).getDefiningOp()) {
+                    if (auto off = mlir::dyn_cast<fir::FieldIndexOp>(op))
+                      return ty.getType(off.getFieldName());
+                    const auto pos =
+                        constantPosition(op, ty.getTypeList().size());
+                    if (pos >= 0)
+                      return ty.getType(static_cast<unsigned>(pos));
+                  }
+                  return mlir::Type{};
+                })
+                .Case<fir::SequenceType>([&](fir::SequenceType ty) {
+                  bool valid = true;
+                  const auto rank = ty.getDimension();
+                  for (std::remove_const_t<decltype(rank)> ii = 0;
+                       valid && ii < rank; ++ii)
+                    valid = i < end && fir::isa_integer((*i++).getType());
+                  return valid ? ty.getEleTy() : mlir::Type{};
+                })
+                .Case<mlir::TupleType>([&](mlir::TupleType ty) {
+                  if (auto *op = (*i++).getDefiningOp()) {
+                    const auto pos = constantPosition(op, ty.size());
+                    if (pos >= 0)
+                      return ty.getType(static_cast<std::size_t>(pos));
+                  }
+                  return mlir::Type{};
+                })
+                .Case<fir::ComplexType>([&](fir::ComplexType ty) {
+                  if (fir::isa_integer((*i++).getType()))
+                    return ty.getElementType();
+                  return mlir::Type{};
+                })
+                .Case<mlir::ComplexType>([&](mlir::ComplexType ty) {
+                  if (fir::isa_integer((*i++).getType()))
+                    return ty.getElementType();
+                  return mlir::Type{};
+                })
+                .Default([&](const auto &) { return mlir::Type{}; });
+  }
+  return eleTy;
+}
+
// Tablegen operators
#define GET_OP_CLASSES
diff --git a/flang/test/Fir/fir-ops.fir b/flang/test/Fir/fir-ops.fir
index fcd638cf0ccaa..4f57c6a52ed6b 100644
--- a/flang/test/Fir/fir-ops.fir
+++ b/flang/test/Fir/fir-ops.fir
@@ -631,12 +631,12 @@ func @test_misc_ops(%arr1 : !fir.ref<!fir.array<?x?xf32>>, %m : index, %n : inde
// CHECK: [[AV1:%.*]] = fir.array_load [[ARR1]]([[SHAPE]]) : (!fir.ref<!fir.array<?x?xf32>>, !fir.shapeshift<2>) -> !fir.array<?x?xf32>
// CHECK: [[FVAL:%.*]] = fir.array_fetch [[AV1]], [[I10]], [[J20]] : (!fir.array<?x?xf32>, index, index) -> f32
// CHECK: [[AV2:%.*]] = fir.array_update [[AV1]], [[FVAL]], [[I10]], [[J20]] : (!fir.array<?x?xf32>, f32, index, index) -> !fir.array<?x?xf32>
- // CHECK: fir.array_merge_store [[AV1]], [[AV2]] to [[ARR1]] : !fir.ref<!fir.array<?x?xf32>>
+ // CHECK: fir.array_merge_store [[AV1]], [[AV2]] to [[ARR1]] : !fir.array<?x?xf32>, !fir.array<?x?xf32>, !fir.ref<!fir.array<?x?xf32>>
%s = fir.shape_shift %m, %n, %o, %p : (index, index, index, index) -> !fir.shapeshift<2>
%av1 = fir.array_load %arr1(%s) : (!fir.ref<!fir.array<?x?xf32>>, !fir.shapeshift<2>) -> !fir.array<?x?xf32>
%f = fir.array_fetch %av1, %i10, %j20 : (!fir.array<?x?xf32>, index, index) -> f32
%av2 = fir.array_update %av1, %f, %i10, %j20 : (!fir.array<?x?xf32>, f32, index, index) -> !fir.array<?x?xf32>
- fir.array_merge_store %av1, %av2 to %arr1 : !fir.ref<!fir.array<?x?xf32>>
+ fir.array_merge_store %av1, %av2 to %arr1 : !fir.array<?x?xf32>, !fir.array<?x?xf32>, !fir.ref<!fir.array<?x?xf32>>
return
}
diff --git a/flang/test/Fir/invalid.fir b/flang/test/Fir/invalid.fir
index dd0229662dedf..4e74de90ed0e0 100644
--- a/flang/test/Fir/invalid.fir
+++ b/flang/test/Fir/invalid.fir
@@ -494,3 +494,57 @@ func @bad_save_result(%buffer : !fir.ref<!fir.array<?xf32>>, %n :index) {
fir.save_result %res to %buffer(%shape) typeparams %n : !fir.array<?xf32>, !fir.ref<!fir.array<?xf32>>, !fir.shape<1>, index
return
}
+
+// -----
+
+// fir.array_fetch must supply at least one index per array dimension.
+func @test_misc_ops(%arr1 : !fir.ref<!fir.array<?x?xf32>>, %m : index, %n : index, %o : index, %p : index) {
+  %s = fir.shape_shift %m, %n, %o, %p : (index, index, index, index) -> !fir.shapeshift<2>
+  %av1 = fir.array_load %arr1(%s) : (!fir.ref<!fir.array<?x?xf32>>, !fir.shapeshift<2>) -> !fir.array<?x?xf32>
+  // expected-error@+1 {{'fir.array_fetch' op number of indices != dimension of array}}
+  %f = fir.array_fetch %av1, %m : (!fir.array<?x?xf32>, index) -> f32
+  return
+}
+
+// -----
+
+// With one index per dimension, fir.array_fetch must return the element type.
+func @test_misc_ops(%arr1 : !fir.ref<!fir.array<?x?xf32>>, %m : index, %n : index, %o : index, %p : index) {
+  %s = fir.shape_shift %m, %n, %o, %p : (index, index, index, index) -> !fir.shapeshift<2>
+  %av1 = fir.array_load %arr1(%s) : (!fir.ref<!fir.array<?x?xf32>>, !fir.shapeshift<2>) -> !fir.array<?x?xf32>
+  // expected-error@+1 {{'fir.array_fetch' op return type does not match array}}
+  %f = fir.array_fetch %av1, %m, %n : (!fir.array<?x?xf32>, index, index) -> i32
+  return
+}
+
+// -----
+
+// fir.array_update must supply at least one index per array dimension.
+func @test_misc_ops(%arr1 : !fir.ref<!fir.array<?x?xf32>>, %m : index, %n : index, %o : index, %p : index) {
+  %s = fir.shape_shift %m, %n, %o, %p : (index, index, index, index) -> !fir.shapeshift<2>
+  %av1 = fir.array_load %arr1(%s) : (!fir.ref<!fir.array<?x?xf32>>, !fir.shapeshift<2>) -> !fir.array<?x?xf32>
+  %f = fir.array_fetch %av1, %m, %n : (!fir.array<?x?xf32>, index, index) -> f32
+  // expected-error@+1 {{'fir.array_update' op number of indices != dimension of array}}
+  %av2 = fir.array_update %av1, %f, %m : (!fir.array<?x?xf32>, f32, index) -> !fir.array<?x?xf32>
+  return
+}
+
+// -----
+
+// The value merged by fir.array_update must have the array element type.
+func @test_misc_ops(%arr1 : !fir.ref<!fir.array<?x?xf32>>, %m : index, %n : index, %o : index, %p : index) {
+  %s = fir.shape_shift %m, %n, %o, %p : (index, index, index, index) -> !fir.shapeshift<2>
+  %av1 = fir.array_load %arr1(%s) : (!fir.ref<!fir.array<?x?xf32>>, !fir.shapeshift<2>) -> !fir.array<?x?xf32>
+  %c0 = constant 0 : i32
+  // expected-error@+1 {{'fir.array_update' op merged value does not have element type}}
+  %av2 = fir.array_update %av1, %c0, %m, %n : (!fir.array<?x?xf32>, i32, index, index) -> !fir.array<?x?xf32>
+  return
+}
+
+// -----
+
+// The original array given to fir.array_merge_store must come from fir.array_load.
+func @test_misc_ops(%arr1 : !fir.ref<!fir.array<?x?xf32>>, %m : index, %n : index, %o : index, %p : index) {
+  %s = fir.shape_shift %m, %n, %o, %p : (index, index, index, index) -> !fir.shapeshift<2>
+  %av1 = fir.array_load %arr1(%s) : (!fir.ref<!fir.array<?x?xf32>>, !fir.shapeshift<2>) -> !fir.array<?x?xf32>
+  %f = fir.array_fetch %av1, %m, %n : (!fir.array<?x?xf32>, index, index) -> f32
+  %av2 = fir.array_update %av1, %f, %m, %n : (!fir.array<?x?xf32>, f32, index, index) -> !fir.array<?x?xf32>
+  // expected-error@+1 {{'fir.array_merge_store' op operand #0 must be result of a fir.array_load op}}
+  fir.array_merge_store %av2, %av2 to %arr1 : !fir.array<?x?xf32>, !fir.array<?x?xf32>, !fir.ref<!fir.array<?x?xf32>>
+  return
+}
More information about the flang-commits mailing list