[flang-commits] [flang] 80dcc90 - [NFC] Restructured SimplifyIntrinsicsPass::getOrCreateFunction.

Slava Zakharin via flang-commits flang-commits at lists.llvm.org
Wed Aug 10 09:42:49 PDT 2022


Author: Slava Zakharin
Date: 2022-08-10T09:40:57-07:00
New Revision: 80dcc907a8a2de9c0e24d1d40625c54c828d508a

URL: https://github.com/llvm/llvm-project/commit/80dcc907a8a2de9c0e24d1d40625c54c828d508a
DIFF: https://github.com/llvm/llvm-project/commit/80dcc907a8a2de9c0e24d1d40625c54c828d508a.diff

LOG: [NFC] Restructured SimplifyIntrinsicsPass::getOrCreateFunction.

I would like to add DOT_PRODUCT support to this pass, so this restructuring
is the first step toward allowing some code reuse inside getOrCreateFunction().

Differential Revision: https://reviews.llvm.org/D131530

Added: 
    

Modified: 
    flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp

Removed: 
    


################################################################################
diff  --git a/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp b/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp
index cc30694ff4769..ff8f4cff18ec6 100644
--- a/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp
+++ b/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp
@@ -40,35 +40,41 @@ namespace {
 
 class SimplifyIntrinsicsPass
     : public fir::SimplifyIntrinsicsBase<SimplifyIntrinsicsPass> {
+  using FunctionTypeGeneratorTy =
+      std::function<mlir::FunctionType(fir::FirOpBuilder &)>;
+  using FunctionBodyGeneratorTy =
+      std::function<void(fir::FirOpBuilder &, mlir::func::FuncOp &)>;
+
 public:
-  mlir::func::FuncOp getOrCreateFunction(const mlir::Location &loc,
-                                         fir::FirOpBuilder &builder,
-                                         const mlir::Type &type,
-                                         const mlir::StringRef &basename);
+  /// Generate a new function implementing a simplified version
+  /// of a Fortran runtime function defined by \p basename name.
+  /// \p typeGenerator is a callback that generates the new function's type.
+  /// \p bodyGenerator is a callback that generates the new function's body.
+  /// The new function is created in the \p builder's Module.
+  mlir::func::FuncOp getOrCreateFunction(fir::FirOpBuilder &builder,
+                                         const mlir::StringRef &basename,
+                                         FunctionTypeGeneratorTy typeGenerator,
+                                         FunctionBodyGeneratorTy bodyGenerator);
   void runOnOperation() override;
 };
 
 } // namespace
 
-mlir::func::FuncOp SimplifyIntrinsicsPass::getOrCreateFunction(
-    const mlir::Location &loc, fir::FirOpBuilder &builder,
-    const mlir::Type &type, const mlir::StringRef &baseName) {
-  // In future, the idea is that instead of building the function inside
-  // this function, this does the base creation, and calls a callback
-  // function (e.g. a lambda function) that fills in the actual content.
-  // For now, check that it's the ONLY the SUM runtime call.
-  assert(baseName.startswith("_FortranASum"));
-
-  std::string replacementName = mlir::Twine{baseName, "_simplified"}.str();
-  mlir::ModuleOp module = builder.getModule();
-  // If we already have a function, just return it.
-  mlir::func::FuncOp newFunc =
-      fir::FirOpBuilder::getNamedFunction(module, replacementName);
-  if (newFunc)
-    return newFunc;
+/// Generate function type for the simplified version of FortranASum
+/// operating on the given \p elementType.
+static mlir::FunctionType genFortranASumType(fir::FirOpBuilder &builder,
+                                             const mlir::Type &elementType) {
+  mlir::Type boxType = fir::BoxType::get(builder.getNoneType());
+  return mlir::FunctionType::get(builder.getContext(), {boxType},
+                                 {elementType});
+}
 
-  // Need to build the function!
-  // Basic idea:
+/// Generate function body of the simplified version of FortranASum
+/// with signature provided by \p funcOp. The caller is responsible
+/// for saving/restoring the original insertion point of \p builder.
+/// \p funcOp is expected to be empty on entry to this function.
+static void genFortranASumBody(fir::FirOpBuilder &builder,
+                               mlir::func::FuncOp &funcOp) {
   // function FortranASum<T>_simplified(arr)
   //   T, dimension(:) :: arr
   //   T sum = 0
@@ -78,35 +84,25 @@ mlir::func::FuncOp SimplifyIntrinsicsPass::getOrCreateFunction(
   //   end do
   //   FortranASum<T>_simplified = sum
   // end function FortranASum<T>_simplified
-  mlir::Type boxType = fir::BoxType::get(builder.getNoneType());
-  mlir::FunctionType fType =
-      mlir::FunctionType::get(builder.getContext(), {boxType}, {type});
-  newFunc =
-      fir::FirOpBuilder::createFunction(loc, module, replacementName, fType);
-  auto inlineLinkage = mlir::LLVM::linkage::Linkage::LinkonceODR;
-  auto linkage =
-      mlir::LLVM::LinkageAttr::get(builder.getContext(), inlineLinkage);
-  newFunc->setAttr("llvm.linkage", linkage);
-
-  // Save the position of the original call.
-  mlir::OpBuilder::InsertPoint insertPt = builder.saveInsertionPoint();
-  builder.setInsertionPointToEnd(newFunc.addEntryBlock());
+  auto loc = mlir::UnknownLoc::get(builder.getContext());
+  mlir::Type elementType = funcOp.getResultTypes()[0];
+  builder.setInsertionPointToEnd(funcOp.addEntryBlock());
 
   mlir::IndexType idxTy = builder.getIndexType();
 
-  mlir::Value zero = type.isa<mlir::FloatType>()
-                         ? builder.createRealConstant(loc, type, 0.0)
-                         : builder.createIntegerConstant(loc, type, 0);
-  mlir::Value sum = builder.create<fir::AllocaOp>(loc, type);
+  mlir::Value zero = elementType.isa<mlir::FloatType>()
+                         ? builder.createRealConstant(loc, elementType, 0.0)
+                         : builder.createIntegerConstant(loc, elementType, 0);
+  mlir::Value sum = builder.create<fir::AllocaOp>(loc, elementType);
   builder.create<fir::StoreOp>(loc, zero, sum);
 
-  mlir::Block::BlockArgListType args = newFunc.front().getArguments();
+  mlir::Block::BlockArgListType args = funcOp.front().getArguments();
   mlir::Value arg = args[0];
 
   mlir::Value zeroIdx = builder.createIntegerConstant(loc, idxTy, 0);
 
   fir::SequenceType::Shape flatShape = {fir::SequenceType::getUnknownExtent()};
-  mlir::Type arrTy = fir::SequenceType::get(flatShape, type);
+  mlir::Type arrTy = fir::SequenceType::get(flatShape, elementType);
   mlir::Type boxArrTy = fir::BoxType::get(arrTy);
   mlir::Value array = builder.create<fir::ConvertOp>(loc, boxArrTy, arg);
   auto dims =
@@ -123,7 +119,7 @@ mlir::func::FuncOp SimplifyIntrinsicsPass::getOrCreateFunction(
   mlir::OpBuilder::InsertPoint loopEndPt = builder.saveInsertionPoint();
   builder.setInsertionPointToStart(loop.getBody());
 
-  mlir::Type eleRefTy = builder.getRefType(type);
+  mlir::Type eleRefTy = builder.getRefType(elementType);
   mlir::Value index = loop.getInductionVar();
   mlir::Value addr =
       builder.create<fir::CoordinateOp>(loc, eleRefTy, array, index);
@@ -131,9 +127,9 @@ mlir::func::FuncOp SimplifyIntrinsicsPass::getOrCreateFunction(
   mlir::Value sumVal = builder.create<fir::LoadOp>(loc, sum);
 
   mlir::Value res;
-  if (type.isa<mlir::FloatType>())
+  if (elementType.isa<mlir::FloatType>())
     res = builder.create<mlir::arith::AddFOp>(loc, elem, sumVal);
-  else if (type.isa<mlir::IntegerType>())
+  else if (elementType.isa<mlir::IntegerType>())
     res = builder.create<mlir::arith::AddIOp>(loc, elem, sumVal);
   else
     TODO(loc, "Unsupported type");
@@ -144,6 +140,44 @@ mlir::func::FuncOp SimplifyIntrinsicsPass::getOrCreateFunction(
 
   mlir::Value resultVal = builder.create<fir::LoadOp>(loc, sum);
   builder.create<mlir::func::ReturnOp>(loc, resultVal);
+}
+
+mlir::func::FuncOp SimplifyIntrinsicsPass::getOrCreateFunction(
+    fir::FirOpBuilder &builder, const mlir::StringRef &baseName,
+    FunctionTypeGeneratorTy typeGenerator,
+    FunctionBodyGeneratorTy bodyGenerator) {
+  // WARNING: if the function generated here changes its signature
+  //          or behavior (the body code), we should probably embed some
+  //          versioning information into its name, otherwise libraries
+  //          statically linked with older versions of Flang may stop
+  //          working with object files created with newer Flang.
+  //          We can also avoid this by using internal linkage, but
+  //          this may increase the size of final executable/shared library.
+  std::string replacementName = mlir::Twine{baseName, "_simplified"}.str();
+  mlir::ModuleOp module = builder.getModule();
+  // If we already have a function, just return it.
+  mlir::func::FuncOp newFunc =
+      fir::FirOpBuilder::getNamedFunction(module, replacementName);
+  mlir::FunctionType fType = typeGenerator(builder);
+  if (newFunc) {
+    assert(newFunc.getFunctionType() == fType &&
+           "type mismatch for simplified function");
+    return newFunc;
+  }
+
+  // Need to build the function!
+  auto loc = mlir::UnknownLoc::get(builder.getContext());
+  newFunc =
+      fir::FirOpBuilder::createFunction(loc, module, replacementName, fType);
+  auto inlineLinkage = mlir::LLVM::linkage::Linkage::LinkonceODR;
+  auto linkage =
+      mlir::LLVM::LinkageAttr::get(builder.getContext(), inlineLinkage);
+  newFunc->setAttr("llvm.linkage", linkage);
+
+  // Save the position of the original call.
+  mlir::OpBuilder::InsertPoint insertPt = builder.saveInsertionPoint();
+
+  bodyGenerator(builder, newFunc);
 
   // Now back to where we were adding code earlier...
   builder.restoreInsertionPoint(insertPt);
@@ -218,8 +252,11 @@ void SimplifyIntrinsicsPass::runOnOperation() {
             } else {
               return;
             }
-            mlir::func::FuncOp newFunc =
-                getOrCreateFunction(loc, builder, type, funcName);
+            auto typeGenerator = [&type](fir::FirOpBuilder &builder) {
+              return genFortranASumType(builder, type);
+            };
+            mlir::func::FuncOp newFunc = getOrCreateFunction(
+                builder, funcName, typeGenerator, genFortranASumBody);
             auto newCall = builder.create<fir::CallOp>(
                 loc, newFunc, mlir::ValueRange{args[0]});
             call->replaceAllUsesWith(newCall.getResults());


        


More information about the flang-commits mailing list