[flang-commits] [flang] 575a6f8 - [flang] add ExtendedValue type helpers and factory::genZeroValue
Jean Perier via flang-commits
flang-commits at lists.llvm.org
Thu Feb 3 01:14:45 PST 2022
Author: Jean Perier
Date: 2022-02-03T10:13:54+01:00
New Revision: 575a6f819bc60fc2423e492f6c133404740445db
URL: https://github.com/llvm/llvm-project/commit/575a6f819bc60fc2423e492f6c133404740445db
DIFF: https://github.com/llvm/llvm-project/commit/575a6f819bc60fc2423e492f6c133404740445db.diff
LOG: [flang] add ExtendedValue type helpers and factory::genZeroValue
Add some helpers to get the base type and element type of
fir::ExtendedValue and to test if a fir::ExtendedValue is
a derived type with length parameters.
Add a new helper factory::createZeroValue (called genZeroValue in the subject
line) to generate a zero scalar value for all the numerical types and false for logicals.
These helpers are used only in lowering for now, so add unit tests.
Differential Revision: https://reviews.llvm.org/D118795
Added:
Modified:
flang/include/flang/Optimizer/Builder/BoxValue.h
flang/include/flang/Optimizer/Builder/FIRBuilder.h
flang/lib/Optimizer/Builder/FIRBuilder.cpp
flang/unittests/Optimizer/Builder/FIRBuilderTest.cpp
Removed:
################################################################################
diff --git a/flang/include/flang/Optimizer/Builder/BoxValue.h b/flang/include/flang/Optimizer/Builder/BoxValue.h
index ef84dfd30f1a6..a1ff8be8df703 100644
--- a/flang/include/flang/Optimizer/Builder/BoxValue.h
+++ b/flang/include/flang/Optimizer/Builder/BoxValue.h
@@ -467,6 +467,30 @@ inline bool isUnboxedValue(const ExtendedValue &exv) {
[](const fir::UnboxedValue &box) { return box ? true : false; },
[](const auto &) { return false; });
}
+
+/// Get the base type of \p exv, i.e. the type of \p exv with any memory or
+/// box type stripped. A sequence type, if present, is preserved.
+inline mlir::Type getBaseTypeOf(const ExtendedValue &exv) {
+  return exv.match(
+      // Boxed values already know their base type.
+      [](const fir::MutableBoxValue &mutableBox) -> mlir::Type {
+        return mutableBox.getBaseTy();
+      },
+      [](const fir::BoxValue &boxValue) -> mlir::Type {
+        return boxValue.getBaseTy();
+      },
+      // Every other alternative: take the type of the base value and remove
+      // the reference type wrapping it.
+      [&exv](const auto &) -> mlir::Type {
+        mlir::Type baseValueType = fir::getBase(exv).getType();
+        return fir::unwrapRefType(baseValueType);
+      });
+}
+
+/// Get the scalar (element) type of \p exv: its base type with the sequence
+/// type, if any, removed as well. No reference, box, or sequence type is left
+/// in the result.
+inline mlir::Type getElementTypeOf(const ExtendedValue &exv) {
+  mlir::Type baseType = getBaseTypeOf(exv);
+  return fir::unwrapSequenceType(baseType);
+}
+
+/// Tell whether the element type of \p exv is a derived type (fir::RecordType)
+/// that has at least one length parameter.
+inline bool isDerivedWithLengthParameters(const ExtendedValue &exv) {
+  if (auto recordTy = getElementTypeOf(exv).dyn_cast<fir::RecordType>())
+    return recordTy.getNumLenParams() != 0;
+  // Not a derived type at all.
+  return false;
+}
+
} // namespace fir
#endif // FORTRAN_OPTIMIZER_BUILDER_BOXVALUE_H
diff --git a/flang/include/flang/Optimizer/Builder/FIRBuilder.h b/flang/include/flang/Optimizer/Builder/FIRBuilder.h
index a4de869ec0a0b..3ed4917616f46 100644
--- a/flang/include/flang/Optimizer/Builder/FIRBuilder.h
+++ b/flang/include/flang/Optimizer/Builder/FIRBuilder.h
@@ -416,6 +416,11 @@ mlir::Value locationToLineNo(fir::FirOpBuilder &, mlir::Location, mlir::Type);
/// flang/include/flang/Runtime/ragged.h.
mlir::TupleType getRaggedArrayHeaderType(fir::FirOpBuilder &builder);
+/// Create the zero value of the given numerical or logical \p type (`false`
+/// for logical types).
+mlir::Value createZeroValue(fir::FirOpBuilder &builder, mlir::Location loc,
+ mlir::Type type);
+
} // namespace fir::factory
#endif // FORTRAN_OPTIMIZER_BUILDER_FIRBUILDER_H
diff --git a/flang/lib/Optimizer/Builder/FIRBuilder.cpp b/flang/lib/Optimizer/Builder/FIRBuilder.cpp
index ac56e32c4e00f..bcc67a9c1e7d2 100644
--- a/flang/lib/Optimizer/Builder/FIRBuilder.cpp
+++ b/flang/lib/Optimizer/Builder/FIRBuilder.cpp
@@ -604,3 +604,22 @@ fir::factory::getRaggedArrayHeaderType(fir::FirOpBuilder &builder) {
auto shTy = fir::HeapType::get(extTy);
return mlir::TupleType::get(builder.getContext(), {i64Ty, buffTy, shTy});
}
+
+/// Generate the scalar zero of \p type at \p loc: `false` for logical types
+/// and i1, 0 for integers, a real zero for reals, and a complex built from two
+/// real zeros for complex types. Any other type is reported as a fatal
+/// internal error.
+mlir::Value fir::factory::createZeroValue(fir::FirOpBuilder &builder,
+                                          mlir::Location loc, mlir::Type type) {
+  // Logical types and the one-bit integer are both handled as booleans; this
+  // must be tested before the generic integer case.
+  const mlir::Type boolType = builder.getIntegerType(1);
+  const bool isBooleanLike = type == boolType || type.isa<fir::LogicalType>();
+  if (isBooleanLike) {
+    mlir::Value falseValue = builder.createBool(loc, false);
+    return builder.createConvert(loc, type, falseValue);
+  }
+
+  if (fir::isa_integer(type))
+    return builder.createIntegerConstant(loc, type, 0);
+
+  if (fir::isa_real(type))
+    return builder.createRealZeroConstant(loc, type);
+
+  if (fir::isa_complex(type)) {
+    // Use the same real zero for the real and the imaginary parts.
+    fir::factory::Complex helper(builder, loc);
+    mlir::Type realPartType = helper.getComplexPartType(type);
+    mlir::Value realZero = builder.createRealZeroConstant(loc, realPartType);
+    return helper.createComplex(type, realZero, realZero);
+  }
+
+  fir::emitFatalError(loc, "internal: trying to generate zero value of non "
+                           "numeric or logical type");
+}
diff --git a/flang/unittests/Optimizer/Builder/FIRBuilderTest.cpp b/flang/unittests/Optimizer/Builder/FIRBuilderTest.cpp
index 7613d9145944a..bb16db9720cf9 100644
--- a/flang/unittests/Optimizer/Builder/FIRBuilderTest.cpp
+++ b/flang/unittests/Optimizer/Builder/FIRBuilderTest.cpp
@@ -414,3 +414,114 @@ TEST_F(FIRBuilderTest, getExtents) {
auto readExtents = fir::factory::getExtents(builder, loc, ex);
EXPECT_EQ(2u, readExtents.size());
}
+
+// Check that fir::factory::createZeroValue yields a constant of the requested
+// type holding zero for an integer, a real, and an i1 (boolean) type.
+TEST_F(FIRBuilderTest, createZeroValue) {
+  auto builder = getBuilder();
+  auto loc = builder.getUnknownLoc();
+
+  // Integer case.
+  mlir::Type i64Ty = mlir::IntegerType::get(builder.getContext(), 64);
+  mlir::Value zeroInt = fir::factory::createZeroValue(builder, loc, i64Ty);
+  EXPECT_TRUE(zeroInt.getType() == i64Ty);
+  auto intCst =
+      mlir::dyn_cast_or_null<mlir::arith::ConstantOp>(zeroInt.getDefiningOp());
+  // ASSERT rather than EXPECT: the op is dereferenced on the next line, so the
+  // test must stop here if the value is not defined by a constant op.
+  ASSERT_TRUE(intCst);
+  auto intAttr = intCst.getValue().dyn_cast<mlir::IntegerAttr>();
+  EXPECT_TRUE(intAttr && intAttr.getInt() == 0);
+
+  // Real case.
+  mlir::Type f32Ty = mlir::FloatType::getF32(builder.getContext());
+  mlir::Value zeroFloat = fir::factory::createZeroValue(builder, loc, f32Ty);
+  EXPECT_TRUE(zeroFloat.getType() == f32Ty);
+  auto floatCst = mlir::dyn_cast_or_null<mlir::arith::ConstantOp>(
+      zeroFloat.getDefiningOp());
+  ASSERT_TRUE(floatCst);
+  auto floatAttr = floatCst.getValue().dyn_cast<mlir::FloatAttr>();
+  EXPECT_TRUE(floatAttr && floatAttr.getValueAsDouble() == 0.);
+
+  // Boolean (i1) case.
+  mlir::Type boolTy = mlir::IntegerType::get(builder.getContext(), 1);
+  mlir::Value falseBool = fir::factory::createZeroValue(builder, loc, boolTy);
+  EXPECT_TRUE(falseBool.getType() == boolTy);
+  auto boolCst = mlir::dyn_cast_or_null<mlir::arith::ConstantOp>(
+      falseBool.getDefiningOp());
+  ASSERT_TRUE(boolCst);
+  // Inspect the boolean constant itself (not the integer constant from the
+  // first case) so that its value is really verified.
+  auto boolAttr = boolCst.getValue().dyn_cast<mlir::IntegerAttr>();
+  EXPECT_TRUE(boolAttr && boolAttr.getInt() == 0);
+}
+
+// Check fir::getBaseTypeOf, fir::getElementTypeOf and
+// fir::isDerivedWithLengthParameters on scalar and array extended values, for
+// an intrinsic element type (f32) and for a derived type with a length
+// parameter.
+TEST_F(FIRBuilderTest, getBaseTypeOf) {
+  auto builder = getBuilder();
+  auto loc = builder.getUnknownLoc();
+
+  // Build two lists of extended values from undefined values: one of scalars
+  // of \p elementType and one of arrays of \p arrayType. Each list holds an
+  // address-based value, a fir::BoxValue and a fir::MutableBoxValue.
+  // \p arrayType must be a fir::SequenceType.
+  auto makeExv = [&](mlir::Type elementType, mlir::Type arrayType)
+      -> std::tuple<llvm::SmallVector<fir::ExtendedValue, 4>,
+          llvm::SmallVector<fir::ExtendedValue, 4>> {
+    auto ptrTyArray = fir::PointerType::get(arrayType);
+    auto ptrTyScalar = fir::PointerType::get(elementType);
+    auto ptrBoxTyArray = fir::BoxType::get(ptrTyArray);
+    auto ptrBoxTyScalar = fir::BoxType::get(ptrTyScalar);
+    auto boxRefTyArray = fir::ReferenceType::get(ptrBoxTyArray);
+    auto boxRefTyScalar = fir::ReferenceType::get(ptrBoxTyScalar);
+    auto boxTyArray = fir::BoxType::get(arrayType);
+    auto boxTyScalar = fir::BoxType::get(elementType);
+
+    auto ptrValArray = builder.create<fir::UndefOp>(loc, ptrTyArray);
+    auto ptrValScalar = builder.create<fir::UndefOp>(loc, ptrTyScalar);
+    auto boxRefValArray = builder.create<fir::UndefOp>(loc, boxRefTyArray);
+    auto boxRefValScalar = builder.create<fir::UndefOp>(loc, boxRefTyScalar);
+    auto boxValArray = builder.create<fir::UndefOp>(loc, boxTyArray);
+    auto boxValScalar = builder.create<fir::UndefOp>(loc, boxTyScalar);
+
+    llvm::SmallVector<fir::ExtendedValue, 4> scalars;
+    scalars.emplace_back(fir::UnboxedValue(ptrValScalar));
+    scalars.emplace_back(fir::BoxValue(boxValScalar));
+    scalars.emplace_back(
+        fir::MutableBoxValue(boxRefValScalar, mlir::ValueRange(), {}));
+
+    llvm::SmallVector<fir::ExtendedValue, 4> arrays;
+    auto extent = builder.create<fir::UndefOp>(loc, builder.getIndexType());
+    // One (undefined) extent per dimension of the array type. Use cast<>, not
+    // dyn_cast<>, since the result is used unconditionally.
+    llvm::SmallVector<mlir::Value> extents(
+        arrayType.cast<fir::SequenceType>().getDimension(),
+        extent.getResult());
+    arrays.emplace_back(fir::ArrayBoxValue(ptrValArray, extents));
+    arrays.emplace_back(fir::BoxValue(boxValArray));
+    arrays.emplace_back(
+        fir::MutableBoxValue(boxRefValArray, mlir::ValueRange(), {}));
+    return {scalars, arrays};
+  };
+
+  // Intrinsic element type: never a derived type with length parameters.
+  auto f32Ty = mlir::FloatType::getF32(builder.getContext());
+  mlir::Type f32SeqTy = builder.getVarLenSeqTy(f32Ty);
+  auto [f32Scalars, f32Arrays] = makeExv(f32Ty, f32SeqTy);
+  for (const auto &scalar : f32Scalars) {
+    EXPECT_EQ(fir::getBaseTypeOf(scalar), f32Ty);
+    EXPECT_EQ(fir::getElementTypeOf(scalar), f32Ty);
+    EXPECT_FALSE(fir::isDerivedWithLengthParameters(scalar));
+  }
+  for (const auto &array : f32Arrays) {
+    EXPECT_EQ(fir::getBaseTypeOf(array), f32SeqTy);
+    EXPECT_EQ(fir::getElementTypeOf(array), f32Ty);
+    EXPECT_FALSE(fir::isDerivedWithLengthParameters(array));
+  }
+
+  // Derived type with one length parameter ("p1") and one component ("c1").
+  auto derivedWithLengthTy =
+      fir::RecordType::get(builder.getContext(), "derived_test");
+
+  llvm::SmallVector<std::pair<std::string, mlir::Type>> parameters;
+  llvm::SmallVector<std::pair<std::string, mlir::Type>> components;
+  parameters.emplace_back("p1", builder.getI64Type());
+  components.emplace_back("c1", f32Ty);
+  derivedWithLengthTy.finalize(parameters, components);
+  mlir::Type derivedWithLengthSeqTy =
+      builder.getVarLenSeqTy(derivedWithLengthTy);
+  auto [derivedWithLengthScalars, derivedWithLengthArrays] =
+      makeExv(derivedWithLengthTy, derivedWithLengthSeqTy);
+  for (const auto &scalar : derivedWithLengthScalars) {
+    EXPECT_EQ(fir::getBaseTypeOf(scalar), derivedWithLengthTy);
+    EXPECT_EQ(fir::getElementTypeOf(scalar), derivedWithLengthTy);
+    EXPECT_TRUE(fir::isDerivedWithLengthParameters(scalar));
+  }
+  for (const auto &array : derivedWithLengthArrays) {
+    EXPECT_EQ(fir::getBaseTypeOf(array), derivedWithLengthSeqTy);
+    EXPECT_EQ(fir::getElementTypeOf(array), derivedWithLengthTy);
+    EXPECT_TRUE(fir::isDerivedWithLengthParameters(array));
+  }
+}
More information about the flang-commits
mailing list