[flang-commits] [flang] a1f9bd3 - [Flang] Add a factory class for creating Complex Ops

Kiran Chandramohan via flang-commits flang-commits at lists.llvm.org
Thu Nov 18 08:56:09 PST 2021


Author: Kiran Chandramohan
Date: 2021-11-18T16:55:35Z
New Revision: a1f9bd32c57649b9f7695fa6564f3f92f09a1785

URL: https://github.com/llvm/llvm-project/commit/a1f9bd32c57649b9f7695fa6564f3f92f09a1785
DIFF: https://github.com/llvm/llvm-project/commit/a1f9bd32c57649b9f7695fa6564f3f92f09a1785.diff

LOG: [Flang] Add a factory class for creating Complex Ops

Use the factory class in the FIRBuilder.
Add unit tests for the factory class function and the convert function
of the Complex class.

Reviewed By: clementval, rovka

Differential Revision: https://reviews.llvm.org/D114125

Co-authored-by: Jean Perier <jperier at nvidia.com>
Co-authored-by: Eric Schweitz <eschweitz at nvidia.com>

Added: 
    flang/include/flang/Optimizer/Builder/Complex.h
    flang/lib/Optimizer/Builder/Complex.cpp
    flang/unittests/Optimizer/Builder/ComplexTest.cpp

Modified: 
    flang/include/flang/Optimizer/Builder/FIRBuilder.h
    flang/lib/Optimizer/Builder/CMakeLists.txt
    flang/lib/Optimizer/Builder/FIRBuilder.cpp
    flang/unittests/Optimizer/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/flang/include/flang/Optimizer/Builder/Complex.h b/flang/include/flang/Optimizer/Builder/Complex.h
new file mode 100644
index 0000000000000..3286e41ee039e
--- /dev/null
+++ b/flang/include/flang/Optimizer/Builder/Complex.h
@@ -0,0 +1,89 @@
+//===-- Complex.h -- lowering of complex values -----------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef FORTRAN_OPTIMIZER_BUILDER_COMPLEX_H
+#define FORTRAN_OPTIMIZER_BUILDER_COMPLEX_H
+
+#include "flang/Optimizer/Builder/FIRBuilder.h"
+
+namespace fir::factory {
+
+/// Helper to facilitate lowering of COMPLEX manipulations in FIR.
+class Complex {
+public:
+  explicit Complex(FirOpBuilder &builder, mlir::Location loc)
+      : builder(builder), loc(loc) {}
+  Complex(const Complex &) = delete;
+
+  // The values of part enum members are meaningful for
+  // InsertValueOp and ExtractValueOp so they are explicit.
+  enum class Part { Real = 0, Imag = 1 };
+
+  /// Get the type of the real/imaginary parts. Do not create MLIR operations.
+  mlir::Type getComplexPartType(mlir::Value cplx) const;
+  mlir::Type getComplexPartType(mlir::Type complexType) const;
+
+  /// Create a complex value of kind \p kind. This creates MLIR operations.
+  mlir::Value createComplex(fir::KindTy kind, mlir::Value real,
+                            mlir::Value imag);
+
+  /// Create a value of type \p complexType from the \p real and \p imag parts.
+  mlir::Value createComplex(mlir::Type complexType, mlir::Value real,
+                            mlir::Value imag);
+
+  /// Returns the real part of \p cplx, or the imaginary one if \p isImagPart.
+  mlir::Value extractComplexPart(mlir::Value cplx, bool isImagPart) {
+    return isImagPart ? extract<Part::Imag>(cplx) : extract<Part::Real>(cplx);
+  }
+
+  /// Returns (Real, Imag) pair of \p cplx
+  std::pair<mlir::Value, mlir::Value> extractParts(mlir::Value cplx) {
+    return {extract<Part::Real>(cplx), extract<Part::Imag>(cplx)};
+  }
+
+  mlir::Value insertComplexPart(mlir::Value cplx, mlir::Value part,
+                                bool isImagPart) {
+    return isImagPart ? insert<Part::Imag>(cplx, part)
+                      : insert<Part::Real>(cplx, part);
+  }
+
+protected:
+  template <Part partId>
+  mlir::Value extract(mlir::Value cplx) {
+    return builder.create<fir::ExtractValueOp>(
+        loc, getComplexPartType(cplx), cplx,
+        builder.getArrayAttr({builder.getIntegerAttr(
+            builder.getIndexType(), static_cast<int>(partId))}));
+  }
+
+  template <Part partId>
+  mlir::Value insert(mlir::Value cplx, mlir::Value part) {
+    return builder.create<fir::InsertValueOp>(
+        loc, cplx.getType(), cplx, part,
+        builder.getArrayAttr({builder.getIntegerAttr(
+            builder.getIndexType(), static_cast<int>(partId))}));
+  }
+
+  template <Part partId>
+  mlir::Value createPartId() {
+    return builder.createIntegerConstant(loc, builder.getIndexType(),
+                                         static_cast<int>(partId));
+  }
+
+private:
+  FirOpBuilder &builder;
+  mlir::Location loc;
+};
+
+} // namespace fir::factory
+
+#endif // FORTRAN_OPTIMIZER_BUILDER_COMPLEX_H

diff  --git a/flang/include/flang/Optimizer/Builder/FIRBuilder.h b/flang/include/flang/Optimizer/Builder/FIRBuilder.h
index d6ed9fd881091..c48a29b1ebdf8 100644
--- a/flang/include/flang/Optimizer/Builder/FIRBuilder.h
+++ b/flang/include/flang/Optimizer/Builder/FIRBuilder.h
@@ -57,6 +57,15 @@ class FirOpBuilder : public mlir::OpBuilder {
   /// Get a reference to the kind map.
   const fir::KindMapping &getKindMap() { return kindMap; }
 
+  /// The LHS and RHS are not always in agreement in terms of
+  /// type. In some cases, the disagreement is between COMPLEX and other scalar
+  /// types. In that case, the conversion must insert/extract out of a COMPLEX
+  /// value to have the proper semantics and be strongly typed. For example,
+  /// when converting an integer/real to a complex, the real part is filled
+  /// with the converted integer/real and the imaginary part is zero.
+  mlir::Value convertWithSemantics(mlir::Location loc, mlir::Type toTy,
+                                   mlir::Value val);
+
   /// Get the entry block of the current Function
   mlir::Block *getEntryBlock() { return &getFunction().front(); }
 

diff  --git a/flang/lib/Optimizer/Builder/CMakeLists.txt b/flang/lib/Optimizer/Builder/CMakeLists.txt
index 0a4b2dcacf083..13c10128905b1 100644
--- a/flang/lib/Optimizer/Builder/CMakeLists.txt
+++ b/flang/lib/Optimizer/Builder/CMakeLists.txt
@@ -3,6 +3,7 @@ get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
 add_flang_library(FIRBuilder
   BoxValue.cpp
   Character.cpp
+  Complex.cpp
   DoLoopHelper.cpp
   FIRBuilder.cpp
   MutableBox.cpp

diff  --git a/flang/lib/Optimizer/Builder/Complex.cpp b/flang/lib/Optimizer/Builder/Complex.cpp
new file mode 100644
index 0000000000000..e97cb30678089
--- /dev/null
+++ b/flang/lib/Optimizer/Builder/Complex.cpp
@@ -0,0 +1,36 @@
+//===-- Complex.cpp -------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Optimizer/Builder/Complex.h"
+
+//===----------------------------------------------------------------------===//
+// Complex Factory implementation
+//===----------------------------------------------------------------------===//
+
+mlir::Type
+fir::factory::Complex::getComplexPartType(mlir::Type complexType) const {
+  return builder.getRealType(complexType.cast<fir::ComplexType>().getFKind());
+}
+
+mlir::Type fir::factory::Complex::getComplexPartType(mlir::Value cplx) const {
+  return getComplexPartType(cplx.getType());
+}
+
+mlir::Value fir::factory::Complex::createComplex(fir::KindTy kind,
+                                                 mlir::Value real,
+                                                 mlir::Value imag) {
+  auto complexTy = fir::ComplexType::get(builder.getContext(), kind);
+  return createComplex(complexTy, real, imag);
+}
+
+mlir::Value fir::factory::Complex::createComplex(mlir::Type cplxTy,
+                                                 mlir::Value real,
+                                                 mlir::Value imag) {
+  mlir::Value und = builder.create<fir::UndefOp>(loc, cplxTy);
+  return insert<Part::Imag>(insert<Part::Real>(und, real), imag);
+}

diff  --git a/flang/lib/Optimizer/Builder/FIRBuilder.cpp b/flang/lib/Optimizer/Builder/FIRBuilder.cpp
index 435088cefbb5f..9b0ea245f4a4f 100644
--- a/flang/lib/Optimizer/Builder/FIRBuilder.cpp
+++ b/flang/lib/Optimizer/Builder/FIRBuilder.cpp
@@ -9,6 +9,7 @@
 #include "flang/Optimizer/Builder/FIRBuilder.h"
 #include "flang/Optimizer/Builder/BoxValue.h"
 #include "flang/Optimizer/Builder/Character.h"
+#include "flang/Optimizer/Builder/Complex.h"
 #include "flang/Optimizer/Builder/MutableBox.h"
 #include "flang/Optimizer/Dialect/FIROpsSupport.h"
 #include "flang/Optimizer/Support/FatalError.h"
@@ -257,6 +258,33 @@ fir::GlobalOp fir::FirOpBuilder::createGlobal(
   return glob;
 }
 
+mlir::Value fir::FirOpBuilder::convertWithSemantics(mlir::Location loc,
+                                                    mlir::Type toTy,
+                                                    mlir::Value val) {
+  assert(toTy && "store location must be typed");
+  auto fromTy = val.getType();
+  if (fromTy == toTy)
+    return val;
+  fir::factory::Complex helper{*this, loc};
+  if ((fir::isa_real(fromTy) || fir::isa_integer(fromTy)) &&
+      fir::isa_complex(toTy)) {
+    // real part is the converted value; the imaginary part is a zero constant
+    auto eleTy = helper.getComplexPartType(toTy);
+    auto cast = createConvert(loc, eleTy, val);
+    llvm::APFloat zero{
+        kindMap.getFloatSemantics(toTy.cast<fir::ComplexType>().getFKind()), 0};
+    auto imag = createRealConstant(loc, eleTy, zero);
+    return helper.createComplex(toTy, cast, imag);
+  }
+  if (fir::isa_complex(fromTy) &&
+      (fir::isa_integer(toTy) || fir::isa_real(toTy))) {
+    // extract the real part only (imaginary part is dropped), then convert it
+    auto rp = helper.extractComplexPart(val, /*isImagPart=*/false);
+    return createConvert(loc, toTy, rp);
+  }
+  return createConvert(loc, toTy, val);
+}
+
 mlir::Value fir::FirOpBuilder::createConvert(mlir::Location loc,
                                              mlir::Type toTy, mlir::Value val) {
   if (val.getType() != toTy) {

diff  --git a/flang/unittests/Optimizer/Builder/ComplexTest.cpp b/flang/unittests/Optimizer/Builder/ComplexTest.cpp
new file mode 100644
index 0000000000000..54e335e7c031b
--- /dev/null
+++ b/flang/unittests/Optimizer/Builder/ComplexTest.cpp
@@ -0,0 +1,100 @@
+//===- ComplexTest.cpp -- Complex factory unit tests ----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Optimizer/Builder/Complex.h"
+#include "gtest/gtest.h"
+#include "flang/Optimizer/Builder/FIRBuilder.h"
+#include "flang/Optimizer/Support/InitFIR.h"
+#include "flang/Optimizer/Support/KindMapping.h"
+
+struct ComplexTest : public testing::Test {
+public:
+  void SetUp() override {
+    mlir::OpBuilder builder(&context);
+    auto loc = builder.getUnknownLoc();
+
+    // Set up a Module with a dummy function operation inside, and set the
+    // insertion point of the local builder to the function's entry block.
+    mlir::ModuleOp mod = builder.create<mlir::ModuleOp>(loc);
+    mlir::FuncOp func = mlir::FuncOp::create(
+        loc, "func1", builder.getFunctionType(llvm::None, llvm::None));
+    auto *entryBlock = func.addEntryBlock();
+    mod.push_back(func);
+    builder.setInsertionPointToStart(entryBlock);
+
+    fir::support::loadDialects(context);
+    kindMap = std::make_unique<fir::KindMapping>(&context);
+    firBuilder = std::make_unique<fir::FirOpBuilder>(mod, *kindMap);
+    helper = std::make_unique<fir::factory::Complex>(*firBuilder, loc);
+
+    // Init commonly used types (f32 real, kind-4 complex, 32-bit integer)
+    realTy1 = mlir::FloatType::getF32(&context);
+    complexTy1 = fir::ComplexType::get(&context, 4);
+    integerTy1 = mlir::IntegerType::get(&context, 32);
+
+    // Create commonly used reals
+    rOne = firBuilder->createRealConstant(loc, realTy1, 1u);
+    rTwo = firBuilder->createRealConstant(loc, realTy1, 2u);
+    rThree = firBuilder->createRealConstant(loc, realTy1, 3u);
+    rFour = firBuilder->createRealConstant(loc, realTy1, 4u);
+  }
+
+  mlir::MLIRContext context;
+  std::unique_ptr<fir::KindMapping> kindMap;
+  std::unique_ptr<fir::FirOpBuilder> firBuilder;
+  std::unique_ptr<fir::factory::Complex> helper;
+
+  // Commonly used real/complex/integer types
+  mlir::FloatType realTy1;
+  fir::ComplexType complexTy1;
+  mlir::IntegerType integerTy1;
+
+  // Commonly used real numbers
+  mlir::Value rOne;
+  mlir::Value rTwo;
+  mlir::Value rThree;
+  mlir::Value rFour;
+};
+
+TEST_F(ComplexTest, verifyTypes) {
+  mlir::Value cVal1 = helper->createComplex(complexTy1, rOne, rTwo);
+  mlir::Value cVal2 = helper->createComplex(4, rOne, rTwo);
+  EXPECT_TRUE(fir::isa_complex(cVal1.getType()));
+  EXPECT_TRUE(fir::isa_complex(cVal2.getType()));
+  EXPECT_TRUE(fir::isa_real(helper->getComplexPartType(cVal1)));
+  EXPECT_TRUE(fir::isa_real(helper->getComplexPartType(cVal2)));
+
+  mlir::Value real1 = helper->extractComplexPart(cVal1, /*isImagPart=*/false);
+  mlir::Value imag1 = helper->extractComplexPart(cVal1, /*isImagPart=*/true);
+  mlir::Value real2 = helper->extractComplexPart(cVal2, /*isImagPart=*/false);
+  mlir::Value imag2 = helper->extractComplexPart(cVal2, /*isImagPart=*/true);
+  EXPECT_EQ(realTy1, real1.getType());
+  EXPECT_EQ(realTy1, imag1.getType());
+  EXPECT_EQ(realTy1, real2.getType());
+  EXPECT_EQ(realTy1, imag2.getType());
+
+  mlir::Value cVal3 =
+      helper->insertComplexPart(cVal1, rThree, /*isImagPart=*/false);
+  mlir::Value cVal4 =
+      helper->insertComplexPart(cVal3, rFour, /*isImagPart=*/true);
+  EXPECT_TRUE(fir::isa_complex(cVal4.getType()));
+  EXPECT_TRUE(fir::isa_real(helper->getComplexPartType(cVal4)));
+}
+
+TEST_F(ComplexTest, verifyConvertWithSemantics) {
+  auto loc = firBuilder->getUnknownLoc();
+  rOne = firBuilder->createRealConstant(loc, realTy1, 1u);
+  // Convert real to complex
+  mlir::Value v1 = firBuilder->convertWithSemantics(loc, complexTy1, rOne);
+  EXPECT_TRUE(fir::isa_complex(v1.getType()));
+
+  // Convert complex to integer
+  mlir::Value v2 = firBuilder->convertWithSemantics(loc, integerTy1, v1);
+  EXPECT_TRUE(v2.getType().isa<mlir::IntegerType>());
+  EXPECT_TRUE(mlir::dyn_cast<fir::ConvertOp>(v2.getDefiningOp()));
+}

diff  --git a/flang/unittests/Optimizer/CMakeLists.txt b/flang/unittests/Optimizer/CMakeLists.txt
index 8d280e936d244..e54aae6554de8 100644
--- a/flang/unittests/Optimizer/CMakeLists.txt
+++ b/flang/unittests/Optimizer/CMakeLists.txt
@@ -10,6 +10,7 @@ set(LIBS
 
 add_flang_unittest(FlangOptimizerTests
   Builder/CharacterTest.cpp
+  Builder/ComplexTest.cpp
   Builder/DoLoopHelperTest.cpp
   Builder/FIRBuilderTest.cpp
   FIRContextTest.cpp


        


More information about the flang-commits mailing list