[flang-commits] [flang] 575a6f8 - [flang] add ExtendedValue type helpers and factory::genZeroValue

Jean Perier via flang-commits flang-commits at lists.llvm.org
Thu Feb 3 01:14:45 PST 2022


Author: Jean Perier
Date: 2022-02-03T10:13:54+01:00
New Revision: 575a6f819bc60fc2423e492f6c133404740445db

URL: https://github.com/llvm/llvm-project/commit/575a6f819bc60fc2423e492f6c133404740445db
DIFF: https://github.com/llvm/llvm-project/commit/575a6f819bc60fc2423e492f6c133404740445db.diff

LOG: [flang] add ExtendedValue type helpers and factory::genZeroValue

Add some helpers to get the base type and element type of
fir::ExtendedValue and to test if a fir::ExtendedValue is
a derived type with length parameters.

Add a new helper, factory::createZeroValue, to generate a zero scalar value
for all the numerical types, and false for logicals.

These helpers are used only in lowering for now, so add unit tests.

Differential Revision: https://reviews.llvm.org/D118795

Added: 
    

Modified: 
    flang/include/flang/Optimizer/Builder/BoxValue.h
    flang/include/flang/Optimizer/Builder/FIRBuilder.h
    flang/lib/Optimizer/Builder/FIRBuilder.cpp
    flang/unittests/Optimizer/Builder/FIRBuilderTest.cpp

Removed: 
    


################################################################################
diff  --git a/flang/include/flang/Optimizer/Builder/BoxValue.h b/flang/include/flang/Optimizer/Builder/BoxValue.h
index ef84dfd30f1a6..a1ff8be8df703 100644
--- a/flang/include/flang/Optimizer/Builder/BoxValue.h
+++ b/flang/include/flang/Optimizer/Builder/BoxValue.h
@@ -467,6 +467,30 @@ inline bool isUnboxedValue(const ExtendedValue &exv) {
       [](const fir::UnboxedValue &box) { return box ? true : false; },
       [](const auto &) { return false; });
 }
+
+/// Returns the base type of \p exv: its type with any memory reference or box
+/// wrapper removed. A sequence type, if present, is kept.
+inline mlir::Type getBaseTypeOf(const ExtendedValue &exv) {
+  if (const auto *mutableBox = exv.getBoxOf<fir::MutableBoxValue>())
+    return mutableBox->getBaseTy();
+  if (const auto *box = exv.getBoxOf<fir::BoxValue>())
+    return box->getBaseTy();
+  // Any other kind: strip the reference type, if any, of the base's type.
+  return fir::unwrapRefType(fir::getBase(exv).getType());
+}
+
+/// Returns the scalar type of \p exv: its base type minus any sequence type.
+inline mlir::Type getElementTypeOf(const ExtendedValue &exv) {
+  mlir::Type baseTy = getBaseTypeOf(exv);
+  return fir::unwrapSequenceType(baseTy);
+}
+
+/// Is the element type of \p exv a derived type with length parameters?
+inline bool isDerivedWithLengthParameters(const ExtendedValue &exv) {
+  auto recordTy = getElementTypeOf(exv).dyn_cast<fir::RecordType>();
+  return recordTy ? recordTy.getNumLenParams() != 0 : false;
+}
+
 } // namespace fir
 
 #endif // FORTRAN_OPTIMIZER_BUILDER_BOXVALUE_H

diff  --git a/flang/include/flang/Optimizer/Builder/FIRBuilder.h b/flang/include/flang/Optimizer/Builder/FIRBuilder.h
index a4de869ec0a0b..3ed4917616f46 100644
--- a/flang/include/flang/Optimizer/Builder/FIRBuilder.h
+++ b/flang/include/flang/Optimizer/Builder/FIRBuilder.h
@@ -416,6 +416,11 @@ mlir::Value locationToLineNo(fir::FirOpBuilder &, mlir::Location, mlir::Type);
 /// flang/include/flang/Runtime/ragged.h.
 mlir::TupleType getRaggedArrayHeaderType(fir::FirOpBuilder &builder);
 
+/// Create the zero value of the given numerical or logical \p type (`false`
+/// for logical and i1 types). Other types trigger a fatal internal error.
+mlir::Value createZeroValue(fir::FirOpBuilder &builder, mlir::Location loc,
+                            mlir::Type type);
+
 } // namespace fir::factory
 
 #endif // FORTRAN_OPTIMIZER_BUILDER_FIRBUILDER_H

diff  --git a/flang/lib/Optimizer/Builder/FIRBuilder.cpp b/flang/lib/Optimizer/Builder/FIRBuilder.cpp
index ac56e32c4e00f..bcc67a9c1e7d2 100644
--- a/flang/lib/Optimizer/Builder/FIRBuilder.cpp
+++ b/flang/lib/Optimizer/Builder/FIRBuilder.cpp
@@ -604,3 +604,22 @@ fir::factory::getRaggedArrayHeaderType(fir::FirOpBuilder &builder) {
   auto shTy = fir::HeapType::get(extTy);
   return mlir::TupleType::get(builder.getContext(), {i64Ty, buffTy, shTy});
 }
+
+mlir::Value fir::factory::createZeroValue(fir::FirOpBuilder &builder,
+                                          mlir::Location loc, mlir::Type type) {
+  // Logical and i1 types: convert the i1 constant `false` to \p type.
+  if (type.isa<fir::LogicalType>() || type == builder.getIntegerType(1))
+    return builder.createConvert(loc, type, builder.createBool(loc, false));
+  if (fir::isa_integer(type))
+    return builder.createIntegerConstant(loc, type, 0);
+  if (fir::isa_real(type))
+    return builder.createRealZeroConstant(loc, type);
+  if (!fir::isa_complex(type))
+    fir::emitFatalError(loc, "internal: trying to generate zero value of non "
+                             "numeric or logical type");
+  // Complex: the real and imaginary parts share the same real zero value.
+  fir::factory::Complex helper(builder, loc);
+  mlir::Value realZero =
+      builder.createRealZeroConstant(loc, helper.getComplexPartType(type));
+  return helper.createComplex(type, realZero, realZero);
+}

diff  --git a/flang/unittests/Optimizer/Builder/FIRBuilderTest.cpp b/flang/unittests/Optimizer/Builder/FIRBuilderTest.cpp
index 7613d9145944a..bb16db9720cf9 100644
--- a/flang/unittests/Optimizer/Builder/FIRBuilderTest.cpp
+++ b/flang/unittests/Optimizer/Builder/FIRBuilderTest.cpp
@@ -414,3 +414,114 @@ TEST_F(FIRBuilderTest, getExtents) {
   auto readExtents = fir::factory::getExtents(builder, loc, ex);
   EXPECT_EQ(2u, readExtents.size());
 }
+
+TEST_F(FIRBuilderTest, createZeroValue) {
+  auto builder = getBuilder();
+  auto loc = builder.getUnknownLoc();
+  // Integer zero: an arith.constant holding 0.
+  mlir::Type i64Ty = mlir::IntegerType::get(builder.getContext(), 64);
+  mlir::Value zeroInt = fir::factory::createZeroValue(builder, loc, i64Ty);
+  EXPECT_EQ(zeroInt.getType(), i64Ty);
+  auto cst =
+      mlir::dyn_cast_or_null<mlir::arith::ConstantOp>(zeroInt.getDefiningOp());
+  ASSERT_TRUE(cst);
+  auto intAttr = cst.getValue().dyn_cast<mlir::IntegerAttr>();
+  EXPECT_TRUE(intAttr && intAttr.getInt() == 0);
+  // Real zero: an arith.constant holding 0.0.
+  mlir::Type f32Ty = mlir::FloatType::getF32(builder.getContext());
+  mlir::Value zeroFloat = fir::factory::createZeroValue(builder, loc, f32Ty);
+  EXPECT_EQ(zeroFloat.getType(), f32Ty);
+  auto cst2 = mlir::dyn_cast_or_null<mlir::arith::ConstantOp>(
+      zeroFloat.getDefiningOp());
+  ASSERT_TRUE(cst2);
+  auto floatAttr = cst2.getValue().dyn_cast<mlir::FloatAttr>();
+  EXPECT_TRUE(floatAttr && floatAttr.getValueAsDouble() == 0.);
+  // i1 zero: an arith.constant holding false (0).
+  mlir::Type boolTy = mlir::IntegerType::get(builder.getContext(), 1);
+  mlir::Value falseBool = fir::factory::createZeroValue(builder, loc, boolTy);
+  EXPECT_EQ(falseBool.getType(), boolTy);
+  auto cst3 = mlir::dyn_cast_or_null<mlir::arith::ConstantOp>(
+      falseBool.getDefiningOp());
+  ASSERT_TRUE(cst3);
+  auto intAttr2 = cst3.getValue().dyn_cast<mlir::IntegerAttr>();
+  EXPECT_TRUE(intAttr2 && intAttr2.getInt() == 0);
+}
+
+TEST_F(FIRBuilderTest, getBaseTypeOf) {
+  auto builder = getBuilder();
+  auto loc = builder.getUnknownLoc();
+  // Make scalar and array ExtendedValues in unboxed, box and mutable-box form.
+  auto makeExv = [&](mlir::Type elementType, mlir::Type arrayType)
+      -> std::tuple<llvm::SmallVector<fir::ExtendedValue, 4>,
+          llvm::SmallVector<fir::ExtendedValue, 4>> {
+    auto ptrTyArray = fir::PointerType::get(arrayType);
+    auto ptrTyScalar = fir::PointerType::get(elementType);
+    auto ptrBoxTyArray = fir::BoxType::get(ptrTyArray);
+    auto ptrBoxTyScalar = fir::BoxType::get(ptrTyScalar);
+    auto boxRefTyArray = fir::ReferenceType::get(ptrBoxTyArray);
+    auto boxRefTyScalar = fir::ReferenceType::get(ptrBoxTyScalar);
+    auto boxTyArray = fir::BoxType::get(arrayType);
+    auto boxTyScalar = fir::BoxType::get(elementType);
+
+    auto ptrValArray = builder.create<fir::UndefOp>(loc, ptrTyArray);
+    auto ptrValScalar = builder.create<fir::UndefOp>(loc, ptrTyScalar);
+    auto boxRefValArray = builder.create<fir::UndefOp>(loc, boxRefTyArray);
+    auto boxRefValScalar = builder.create<fir::UndefOp>(loc, boxRefTyScalar);
+    auto boxValArray = builder.create<fir::UndefOp>(loc, boxTyArray);
+    auto boxValScalar = builder.create<fir::UndefOp>(loc, boxTyScalar);
+
+    llvm::SmallVector<fir::ExtendedValue, 4> scalars;
+    scalars.emplace_back(fir::UnboxedValue(ptrValScalar));
+    scalars.emplace_back(fir::BoxValue(boxValScalar));
+    scalars.emplace_back(
+        fir::MutableBoxValue(boxRefValScalar, mlir::ValueRange(), {}));
+
+    llvm::SmallVector<fir::ExtendedValue, 4> arrays;
+    auto extent = builder.create<fir::UndefOp>(loc, builder.getIndexType());
+    llvm::SmallVector<mlir::Value> extents(
+        arrayType.cast<fir::SequenceType>().getDimension(),
+        extent.getResult());
+    arrays.emplace_back(fir::ArrayBoxValue(ptrValArray, extents));
+    arrays.emplace_back(fir::BoxValue(boxValArray));
+    arrays.emplace_back(
+        fir::MutableBoxValue(boxRefValArray, mlir::ValueRange(), {}));
+    return {scalars, arrays};
+  };
+  // f32 entities: an intrinsic type never has length parameters.
+  auto f32Ty = mlir::FloatType::getF32(builder.getContext());
+  mlir::Type f32SeqTy = builder.getVarLenSeqTy(f32Ty);
+  auto [f32Scalars, f32Arrays] = makeExv(f32Ty, f32SeqTy);
+  for (const auto &scalar : f32Scalars) {
+    EXPECT_EQ(fir::getBaseTypeOf(scalar), f32Ty);
+    EXPECT_EQ(fir::getElementTypeOf(scalar), f32Ty);
+    EXPECT_FALSE(fir::isDerivedWithLengthParameters(scalar));
+  }
+  for (const auto &array : f32Arrays) {
+    EXPECT_EQ(fir::getBaseTypeOf(array), f32SeqTy);
+    EXPECT_EQ(fir::getElementTypeOf(array), f32Ty);
+    EXPECT_FALSE(fir::isDerivedWithLengthParameters(array));
+  }
+  // Derived type with one length parameter, as scalar and as array.
+  auto derivedWithLengthTy =
+      fir::RecordType::get(builder.getContext(), "derived_test");
+
+  llvm::SmallVector<std::pair<std::string, mlir::Type>> parameters;
+  llvm::SmallVector<std::pair<std::string, mlir::Type>> components;
+  parameters.emplace_back("p1", builder.getI64Type());
+  components.emplace_back("c1", f32Ty);
+  derivedWithLengthTy.finalize(parameters, components);
+  mlir::Type derivedWithLengthSeqTy =
+      builder.getVarLenSeqTy(derivedWithLengthTy);
+  auto [derivedWithLengthScalars, derivedWithLengthArrays] =
+      makeExv(derivedWithLengthTy, derivedWithLengthSeqTy);
+  for (const auto &scalar : derivedWithLengthScalars) {
+    EXPECT_EQ(fir::getBaseTypeOf(scalar), derivedWithLengthTy);
+    EXPECT_EQ(fir::getElementTypeOf(scalar), derivedWithLengthTy);
+    EXPECT_TRUE(fir::isDerivedWithLengthParameters(scalar));
+  }
+  for (const auto &array : derivedWithLengthArrays) {
+    EXPECT_EQ(fir::getBaseTypeOf(array), derivedWithLengthSeqTy);
+    EXPECT_EQ(fir::getElementTypeOf(array), derivedWithLengthTy);
+    EXPECT_TRUE(fir::isDerivedWithLengthParameters(array));
+  }
+}


        


More information about the flang-commits mailing list