[flang-commits] [flang] 80dcc90 - [NFC] Restructured SimplifyIntrinsicsPass::getOrCreateFunction.
Slava Zakharin via flang-commits
flang-commits at lists.llvm.org
Wed Aug 10 09:42:49 PDT 2022
Author: Slava Zakharin
Date: 2022-08-10T09:40:57-07:00
New Revision: 80dcc907a8a2de9c0e24d1d40625c54c828d508a
URL: https://github.com/llvm/llvm-project/commit/80dcc907a8a2de9c0e24d1d40625c54c828d508a
DIFF: https://github.com/llvm/llvm-project/commit/80dcc907a8a2de9c0e24d1d40625c54c828d508a.diff
LOG: [NFC] Restructured SimplifyIntrinsicsPass::getOrCreateFunction.
I would like to add DOT_PRODUCT support in this pass, so this restructuring
is the first step to allow some code reuse inside getOrCreateFunction().
Differential Revision: https://reviews.llvm.org/D131530
Added:
Modified:
flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp
Removed:
################################################################################
diff --git a/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp b/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp
index cc30694ff4769..ff8f4cff18ec6 100644
--- a/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp
+++ b/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp
@@ -40,35 +40,41 @@ namespace {
class SimplifyIntrinsicsPass
: public fir::SimplifyIntrinsicsBase<SimplifyIntrinsicsPass> {
+ using FunctionTypeGeneratorTy =
+ std::function<mlir::FunctionType(fir::FirOpBuilder &)>;
+ using FunctionBodyGeneratorTy =
+ std::function<void(fir::FirOpBuilder &, mlir::func::FuncOp &)>;
+
public:
- mlir::func::FuncOp getOrCreateFunction(const mlir::Location &loc,
- fir::FirOpBuilder &builder,
- const mlir::Type &type,
- const mlir::StringRef &basename);
+ /// Return the simplified version of the Fortran runtime function named
+ /// \p basename, creating it in \p builder's module if it does not exist.
+ /// \p typeGenerator is a callback that generates the function's type.
+ /// \p bodyGenerator fills in the body; it is only called on creation.
+ /// The simplified function is named \p basename + "_simplified".
+ mlir::func::FuncOp getOrCreateFunction(fir::FirOpBuilder &builder,
+ const mlir::StringRef &basename,
+ FunctionTypeGeneratorTy typeGenerator,
+ FunctionBodyGeneratorTy bodyGenerator);
void runOnOperation() override;
};
} // namespace
-mlir::func::FuncOp SimplifyIntrinsicsPass::getOrCreateFunction(
- const mlir::Location &loc, fir::FirOpBuilder &builder,
- const mlir::Type &type, const mlir::StringRef &baseName) {
- // In future, the idea is that instead of building the function inside
- // this function, this does the base creation, and calls a callback
- // function (e.g. a lambda function) that fills in the actual content.
- // For now, check that it's the ONLY the SUM runtime call.
- assert(baseName.startswith("_FortranASum"));
-
- std::string replacementName = mlir::Twine{baseName, "_simplified"}.str();
- mlir::ModuleOp module = builder.getModule();
- // If we already have a function, just return it.
- mlir::func::FuncOp newFunc =
- fir::FirOpBuilder::getNamedFunction(module, replacementName);
- if (newFunc)
- return newFunc;
+/// Generate the type of the simplified FortranASum for \p elementType:
+/// one fir::BoxType (of NoneType) argument and one \p elementType result.
+static mlir::FunctionType genFortranASumType(fir::FirOpBuilder &builder,
+ const mlir::Type &elementType) {
+ mlir::Type boxType = fir::BoxType::get(builder.getNoneType());
+ return mlir::FunctionType::get(builder.getContext(), {boxType},
+ {elementType});
+}
- // Need to build the function!
- // Basic idea:
+/// Generate function body of the simplified version of FortranASum
+/// with signature provided by \p funcOp. The caller is responsible
+/// for saving/restoring the original insertion point of \p builder.
+/// \p funcOp is expected to be empty on entry to this function.
+static void genFortranASumBody(fir::FirOpBuilder &builder,
+ mlir::func::FuncOp &funcOp) {
// function FortranASum<T>_simplified(arr)
// T, dimension(:) :: arr
// T sum = 0
@@ -78,35 +84,25 @@ mlir::func::FuncOp SimplifyIntrinsicsPass::getOrCreateFunction(
// end do
// FortranASum<T>_simplified = sum
// end function FortranASum<T>_simplified
- mlir::Type boxType = fir::BoxType::get(builder.getNoneType());
- mlir::FunctionType fType =
- mlir::FunctionType::get(builder.getContext(), {boxType}, {type});
- newFunc =
- fir::FirOpBuilder::createFunction(loc, module, replacementName, fType);
- auto inlineLinkage = mlir::LLVM::linkage::Linkage::LinkonceODR;
- auto linkage =
- mlir::LLVM::LinkageAttr::get(builder.getContext(), inlineLinkage);
- newFunc->setAttr("llvm.linkage", linkage);
-
- // Save the position of the original call.
- mlir::OpBuilder::InsertPoint insertPt = builder.saveInsertionPoint();
- builder.setInsertionPointToEnd(newFunc.addEntryBlock());
+ auto loc = mlir::UnknownLoc::get(builder.getContext());
+ mlir::Type elementType = funcOp.getResultTypes()[0];
+ builder.setInsertionPointToEnd(funcOp.addEntryBlock());
mlir::IndexType idxTy = builder.getIndexType();
- mlir::Value zero = type.isa<mlir::FloatType>()
- ? builder.createRealConstant(loc, type, 0.0)
- : builder.createIntegerConstant(loc, type, 0);
- mlir::Value sum = builder.create<fir::AllocaOp>(loc, type);
+ mlir::Value zero = elementType.isa<mlir::FloatType>()
+ ? builder.createRealConstant(loc, elementType, 0.0)
+ : builder.createIntegerConstant(loc, elementType, 0);
+ mlir::Value sum = builder.create<fir::AllocaOp>(loc, elementType);
builder.create<fir::StoreOp>(loc, zero, sum);
- mlir::Block::BlockArgListType args = newFunc.front().getArguments();
+ mlir::Block::BlockArgListType args = funcOp.front().getArguments();
mlir::Value arg = args[0];
mlir::Value zeroIdx = builder.createIntegerConstant(loc, idxTy, 0);
fir::SequenceType::Shape flatShape = {fir::SequenceType::getUnknownExtent()};
- mlir::Type arrTy = fir::SequenceType::get(flatShape, type);
+ mlir::Type arrTy = fir::SequenceType::get(flatShape, elementType);
mlir::Type boxArrTy = fir::BoxType::get(arrTy);
mlir::Value array = builder.create<fir::ConvertOp>(loc, boxArrTy, arg);
auto dims =
@@ -123,7 +119,7 @@ mlir::func::FuncOp SimplifyIntrinsicsPass::getOrCreateFunction(
mlir::OpBuilder::InsertPoint loopEndPt = builder.saveInsertionPoint();
builder.setInsertionPointToStart(loop.getBody());
- mlir::Type eleRefTy = builder.getRefType(type);
+ mlir::Type eleRefTy = builder.getRefType(elementType);
mlir::Value index = loop.getInductionVar();
mlir::Value addr =
builder.create<fir::CoordinateOp>(loc, eleRefTy, array, index);
@@ -131,9 +127,9 @@ mlir::func::FuncOp SimplifyIntrinsicsPass::getOrCreateFunction(
mlir::Value sumVal = builder.create<fir::LoadOp>(loc, sum);
mlir::Value res;
- if (type.isa<mlir::FloatType>())
+ if (elementType.isa<mlir::FloatType>())
res = builder.create<mlir::arith::AddFOp>(loc, elem, sumVal);
- else if (type.isa<mlir::IntegerType>())
+ else if (elementType.isa<mlir::IntegerType>())
res = builder.create<mlir::arith::AddIOp>(loc, elem, sumVal);
else
TODO(loc, "Unsupported type");
@@ -144,6 +140,44 @@ mlir::func::FuncOp SimplifyIntrinsicsPass::getOrCreateFunction(
mlir::Value resultVal = builder.create<fir::LoadOp>(loc, sum);
builder.create<mlir::func::ReturnOp>(loc, resultVal);
+}
+
+mlir::func::FuncOp SimplifyIntrinsicsPass::getOrCreateFunction(
+ fir::FirOpBuilder &builder, const mlir::StringRef &baseName,
+ FunctionTypeGeneratorTy typeGenerator,
+ FunctionBodyGeneratorTy bodyGenerator) {
+ // WARNING: if the function generated here changes its signature
+ // or behavior (the body code), we should probably embed some
+ // versioning information into its name, otherwise libraries
+ // statically linked with older versions of Flang may stop
+ // working with object files created with newer Flang.
+ // We can also avoid this by using internal linkage, but
+ // this may increase the size of final executable/shared library.
+ std::string replacementName = mlir::Twine{baseName, "_simplified"}.str();
+ mlir::ModuleOp module = builder.getModule();
+ // If we already have a function, just return it.
+ mlir::func::FuncOp newFunc =
+ fir::FirOpBuilder::getNamedFunction(module, replacementName);
+ mlir::FunctionType fType = typeGenerator(builder);
+ if (newFunc) {
+ assert(newFunc.getFunctionType() == fType &&
+ "type mismatch for simplified function");
+ return newFunc;
+ }
+
+ // Need to build the function!
+ auto loc = mlir::UnknownLoc::get(builder.getContext());
+ newFunc =
+ fir::FirOpBuilder::createFunction(loc, module, replacementName, fType);
+ auto inlineLinkage = mlir::LLVM::linkage::Linkage::LinkonceODR;
+ auto linkage =
+ mlir::LLVM::LinkageAttr::get(builder.getContext(), inlineLinkage);
+ newFunc->setAttr("llvm.linkage", linkage);
+
+ // Save the position of the original call.
+ mlir::OpBuilder::InsertPoint insertPt = builder.saveInsertionPoint();
+
+ bodyGenerator(builder, newFunc);
// Now back to where we were adding code earlier...
builder.restoreInsertionPoint(insertPt);
@@ -218,8 +252,11 @@ void SimplifyIntrinsicsPass::runOnOperation() {
} else {
return;
}
- mlir::func::FuncOp newFunc =
- getOrCreateFunction(loc, builder, type, funcName);
+ auto typeGenerator = [&type](fir::FirOpBuilder &builder) {
+ return genFortranASumType(builder, type);
+ };
+ mlir::func::FuncOp newFunc = getOrCreateFunction(
+ builder, funcName, typeGenerator, genFortranASumBody);
auto newCall = builder.create<fir::CallOp>(
loc, newFunc, mlir::ValueRange{args[0]});
call->replaceAllUsesWith(newCall.getResults());
More information about the flang-commits
mailing list