[Mlir-commits] [mlir] 75a3f32 - [IR] Add an ImplicitLocOpBuilder helper class for building IR with the same loc.

Chris Lattner llvmlistbot at llvm.org
Tue Dec 22 14:47:43 PST 2020


Author: Chris Lattner
Date: 2020-12-22T14:47:33-08:00
New Revision: 75a3f326c3d874853031d8bedd1d00127c835103

URL: https://github.com/llvm/llvm-project/commit/75a3f326c3d874853031d8bedd1d00127c835103
DIFF: https://github.com/llvm/llvm-project/commit/75a3f326c3d874853031d8bedd1d00127c835103.diff

LOG: [IR] Add an ImplicitLocOpBuilder helper class for building IR with the same loc.

One common situation is to create a lot of IR at a well known location,
e.g. when doing a big rewrite from one dialect to another where you're expanding
ops out into lots of other ops.

For these sorts of situations, it is annoying to pass the location into
every create call.  As we discussed in a few threads on the forum, a way to help
with this is to produce a new sort of builder that holds a location and provides
it to each of the create<> calls automatically.

This patch implements an ImplicitLocOpBuilder class that does this.  We've had
good experience with this in the CIRCT project, and it makes sense to upstream to
MLIR.

I picked a random pass to adopt it to show the impact, but I don't think there is
any particular need to force adopt it in the codebase.

Differential Revision: https://reviews.llvm.org/D93717

Added: 
    mlir/include/mlir/IR/ImplicitLocOpBuilder.h

Modified: 
    mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/ImplicitLocOpBuilder.h b/mlir/include/mlir/IR/ImplicitLocOpBuilder.h
new file mode 100644
index 000000000000..2dc7c34f4e85
--- /dev/null
+++ b/mlir/include/mlir/IR/ImplicitLocOpBuilder.h
@@ -0,0 +1,123 @@
+//===- ImplicitLocOpBuilder.h - Convenience OpBuilder -----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Helper class to create ops with a modally set location.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_IMPLICITLOCOPBUILDER_H
+#define MLIR_IR_IMPLICITLOCOPBUILDER_H
+
+#include "mlir/IR/Builders.h"
+
+namespace mlir {
+
+/// ImplicitLocOpBuilder maintains a 'current location', allowing use of the
+/// create<> method without specifying the location.  It is otherwise the same
+/// as OpBuilder.
+class ImplicitLocOpBuilder : public mlir::OpBuilder {
+public:
+  /// Create an ImplicitLocOpBuilder using the insertion point and listener from
+  /// an existing OpBuilder.
+  ImplicitLocOpBuilder(Location loc, const OpBuilder &builder)
+      : OpBuilder(builder), curLoc(loc) {}
+
+  /// OpBuilder has a bunch of convenience constructors - we support them all
+  /// with the additional Location.
+  template <typename T>
+  ImplicitLocOpBuilder(Location loc, T &&operand, Listener *listener = nullptr)
+      : OpBuilder(std::forward<T>(operand), listener), curLoc(loc) {}
+
+  ImplicitLocOpBuilder(Location loc, Block *block, Block::iterator insertPoint,
+                       Listener *listener = nullptr)
+      : OpBuilder(block, insertPoint, listener), curLoc(loc) {}
+
+  /// Create a builder and set the insertion point to before the first operation
+  /// in the block but still inside the block.
+  static ImplicitLocOpBuilder atBlockBegin(Location loc, Block *block,
+                                           Listener *listener = nullptr) {
+    return ImplicitLocOpBuilder(loc, block, block->begin(), listener);
+  }
+
+  /// Create a builder and set the insertion point to after the last operation
+  /// in the block but still inside the block.
+  static ImplicitLocOpBuilder atBlockEnd(Location loc, Block *block,
+                                         Listener *listener = nullptr) {
+    return ImplicitLocOpBuilder(loc, block, block->end(), listener);
+  }
+
+  /// Create a builder and set the insertion point to before the block
+  /// terminator.
+  static ImplicitLocOpBuilder atBlockTerminator(Location loc, Block *block,
+                                                Listener *listener = nullptr) {
+    auto *terminator = block->getTerminator();
+    assert(terminator != nullptr && "the block has no terminator");
+    return ImplicitLocOpBuilder(loc, block, Block::iterator(terminator),
+                                listener);
+  }
+
+  /// Accessors for the implied location.
+  Location getLoc() const { return curLoc; }
+  void setLoc(Location loc) { curLoc = loc; }
+
+  // We allow clients to use the explicit-loc version of create as well.
+  using OpBuilder::create;
+  using OpBuilder::createOrFold;
+
+  /// Create an operation of specific op type at the current insertion point and
+  /// location.
+  template <typename OpTy, typename... Args>
+  OpTy create(Args &&... args) {
+    return OpBuilder::create<OpTy>(curLoc, std::forward<Args>(args)...);
+  }
+
+  /// Create an operation of specific op type at the current insertion point,
+  /// and immediately try to fold it. This function populates 'results' with
+  /// the results after folding the operation.
+  template <typename OpTy, typename... Args>
+  void createOrFold(llvm::SmallVectorImpl<Value> &results, Args &&... args) {
+    OpBuilder::createOrFold<OpTy>(results, curLoc, std::forward<Args>(args)...);
+  }
+
+  /// Overload to create or fold a single result operation.
+  template <typename OpTy, typename... Args>
+  typename std::enable_if<OpTy::template hasTrait<mlir::OpTrait::OneResult>(),
+                          Value>::type
+  createOrFold(Args &&... args) {
+    return OpBuilder::createOrFold<OpTy>(curLoc, std::forward<Args>(args)...);
+  }
+
+  /// Overload to create or fold a zero result operation.
+  template <typename OpTy, typename... Args>
+  typename std::enable_if<OpTy::template hasTrait<mlir::OpTrait::ZeroResult>(),
+                          OpTy>::type
+  createOrFold(Args &&... args) {
+    return OpBuilder::createOrFold<OpTy>(curLoc, std::forward<Args>(args)...);
+  }
+
+  /// This builder can also be used to emit diagnostics to the current location.
+  mlir::InFlightDiagnostic
+  emitError(const llvm::Twine &message = llvm::Twine()) {
+    return mlir::emitError(curLoc, message);
+  }
+  mlir::InFlightDiagnostic
+  emitWarning(const llvm::Twine &message = llvm::Twine()) {
+    return mlir::emitWarning(curLoc, message);
+  }
+  mlir::InFlightDiagnostic
+  emitRemark(const llvm::Twine &message = llvm::Twine()) {
+    return mlir::emitRemark(curLoc, message);
+  }
+
+private:
+  Location curLoc;
+};
+
+} // namespace mlir
+
+#endif // MLIR_IR_IMPLICITLOCOPBUILDER_H
\ No newline at end of file

diff  --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
index 65545d8ab2de..2415924557db 100644
--- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
+++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
@@ -13,7 +13,7 @@
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/IR/BlockAndValueMapping.h"
-#include "mlir/IR/Builders.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/DialectConversion.h"
@@ -112,12 +112,13 @@ struct AsyncAPI {
 
 // Adds Async Runtime C API declarations to the module.
 static void addAsyncRuntimeApiDeclarations(ModuleOp module) {
-  auto builder = OpBuilder::atBlockTerminator(module.getBody());
+  auto builder = ImplicitLocOpBuilder::atBlockTerminator(module.getLoc(),
+                                                         module.getBody());
 
   auto addFuncDecl = [&](StringRef name, FunctionType type) {
     if (module.lookupSymbol(name))
       return;
-    builder.create<FuncOp>(module.getLoc(), name, type).setPrivate();
+    builder.create<FuncOp>(name, type).setPrivate();
   };
 
   MLIRContext *ctx = module.getContext();
@@ -149,13 +150,13 @@ static constexpr const char *kCoroFree = "llvm.coro.free";
 static constexpr const char *kCoroResume = "llvm.coro.resume";
 
 /// Adds an LLVM function declaration to a module.
-static void addLLVMFuncDecl(ModuleOp module, OpBuilder &builder, StringRef name,
-                            LLVM::LLVMType ret,
+static void addLLVMFuncDecl(ModuleOp module, ImplicitLocOpBuilder &builder,
+                            StringRef name, LLVM::LLVMType ret,
                             ArrayRef<LLVM::LLVMType> params) {
   if (module.lookupSymbol(name))
     return;
   LLVM::LLVMType type = LLVM::LLVMType::getFunctionTy(ret, params, false);
-  builder.create<LLVM::LLVMFuncOp>(module.getLoc(), name, type);
+  builder.create<LLVM::LLVMFuncOp>(name, type);
 }
 
 /// Adds coroutine intrinsics declarations to the module.
@@ -163,7 +164,8 @@ static void addCoroutineIntrinsicsDeclarations(ModuleOp module) {
   using namespace mlir::LLVM;
 
   MLIRContext *ctx = module.getContext();
-  OpBuilder builder(module.getBody()->getTerminator());
+  ImplicitLocOpBuilder builder(module.getLoc(),
+                               module.getBody()->getTerminator());
 
   auto token = LLVMTokenType::get(ctx);
   auto voidTy = LLVMType::getVoidTy(ctx);
@@ -196,7 +198,8 @@ static void addCRuntimeDeclarations(ModuleOp module) {
   using namespace mlir::LLVM;
 
   MLIRContext *ctx = module.getContext();
-  OpBuilder builder(module.getBody()->getTerminator());
+  ImplicitLocOpBuilder builder(module.getLoc(),
+                               module.getBody()->getTerminator());
 
   auto voidTy = LLVMType::getVoidTy(ctx);
   auto i64 = LLVMType::getInt64Ty(ctx);
@@ -232,13 +235,13 @@ static void addResumeFunction(ModuleOp module) {
   resumeOp.setPrivate();
 
   auto *block = resumeOp.addEntryBlock();
-  OpBuilder blockBuilder = OpBuilder::atBlockEnd(block);
+  auto blockBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, block);
 
-  blockBuilder.create<LLVM::CallOp>(loc, TypeRange(),
+  blockBuilder.create<LLVM::CallOp>(TypeRange(),
                                     blockBuilder.getSymbolRefAttr(kCoroResume),
                                     resumeOp.getArgument(0));
 
-  blockBuilder.create<LLVM::ReturnOp>(loc, ValueRange());
+  blockBuilder.create<LLVM::ReturnOp>(ValueRange());
 }
 
 //===----------------------------------------------------------------------===//
@@ -302,13 +305,12 @@ static CoroMachinery setupCoroMachinery(FuncOp func) {
   Block *entryBlock = func.addEntryBlock();
   Location loc = func.getBody().getLoc();
 
-  OpBuilder builder = OpBuilder::atBlockBegin(entryBlock);
+  auto builder = ImplicitLocOpBuilder::atBlockBegin(loc, entryBlock);
 
   // ------------------------------------------------------------------------ //
   // Allocate async tokens/values that we will return from a ramp function.
   // ------------------------------------------------------------------------ //
-  auto createToken =
-      builder.create<CallOp>(loc, kCreateToken, TokenType::get(ctx));
+  auto createToken = builder.create<CallOp>(kCreateToken, TokenType::get(ctx));
 
   // ------------------------------------------------------------------------ //
   // Initialize coroutine: allocate frame, get coroutine handle.
@@ -316,28 +318,28 @@ static CoroMachinery setupCoroMachinery(FuncOp func) {
 
   // Constants for initializing coroutine frame.
   auto constZero =
-      builder.create<LLVM::ConstantOp>(loc, i32, builder.getI32IntegerAttr(0));
+      builder.create<LLVM::ConstantOp>(i32, builder.getI32IntegerAttr(0));
   auto constFalse =
-      builder.create<LLVM::ConstantOp>(loc, i1, builder.getBoolAttr(false));
-  auto nullPtr = builder.create<LLVM::NullOp>(loc, i8Ptr);
+      builder.create<LLVM::ConstantOp>(i1, builder.getBoolAttr(false));
+  auto nullPtr = builder.create<LLVM::NullOp>(i8Ptr);
 
   // Get coroutine id: @llvm.coro.id
   auto coroId = builder.create<LLVM::CallOp>(
-      loc, token, builder.getSymbolRefAttr(kCoroId),
+      token, builder.getSymbolRefAttr(kCoroId),
       ValueRange({constZero, nullPtr, nullPtr, nullPtr}));
 
   // Get coroutine frame size: @llvm.coro.size.i64
   auto coroSize = builder.create<LLVM::CallOp>(
-      loc, i64, builder.getSymbolRefAttr(kCoroSizeI64), ValueRange());
+      i64, builder.getSymbolRefAttr(kCoroSizeI64), ValueRange());
 
   // Allocate memory for coroutine frame.
-  auto coroAlloc = builder.create<LLVM::CallOp>(
-      loc, i8Ptr, builder.getSymbolRefAttr(kMalloc),
-      ValueRange(coroSize.getResult(0)));
+  auto coroAlloc =
+      builder.create<LLVM::CallOp>(i8Ptr, builder.getSymbolRefAttr(kMalloc),
+                                   ValueRange(coroSize.getResult(0)));
 
   // Begin a coroutine: @llvm.coro.begin
   auto coroHdl = builder.create<LLVM::CallOp>(
-      loc, i8Ptr, builder.getSymbolRefAttr(kCoroBegin),
+      i8Ptr, builder.getSymbolRefAttr(kCoroBegin),
       ValueRange({coroId.getResult(0), coroAlloc.getResult(0)}));
 
   Block *cleanupBlock = func.addBlock();
@@ -350,15 +352,14 @@ static CoroMachinery setupCoroMachinery(FuncOp func) {
 
   // Get a pointer to the coroutine frame memory: @llvm.coro.free.
   auto coroMem = builder.create<LLVM::CallOp>(
-      loc, i8Ptr, builder.getSymbolRefAttr(kCoroFree),
+      i8Ptr, builder.getSymbolRefAttr(kCoroFree),
       ValueRange({coroId.getResult(0), coroHdl.getResult(0)}));
 
   // Free the memory.
-  builder.create<LLVM::CallOp>(loc, TypeRange(),
-                               builder.getSymbolRefAttr(kFree),
+  builder.create<LLVM::CallOp>(TypeRange(), builder.getSymbolRefAttr(kFree),
                                ValueRange(coroMem.getResult(0)));
   // Branch into the suspend block.
-  builder.create<BranchOp>(loc, suspendBlock);
+  builder.create<BranchOp>(suspendBlock);
 
   // ------------------------------------------------------------------------ //
   // Coroutine suspend block: mark the end of a coroutine and return allocated
@@ -367,17 +368,17 @@ static CoroMachinery setupCoroMachinery(FuncOp func) {
   builder.setInsertionPointToStart(suspendBlock);
 
   // Mark the end of a coroutine: @llvm.coro.end.
-  builder.create<LLVM::CallOp>(loc, i1, builder.getSymbolRefAttr(kCoroEnd),
+  builder.create<LLVM::CallOp>(i1, builder.getSymbolRefAttr(kCoroEnd),
                                ValueRange({coroHdl.getResult(0), constFalse}));
 
   // Return created `async.token` from the suspend block. This will be the
   // return value of a coroutine ramp function.
-  builder.create<ReturnOp>(loc, createToken.getResult(0));
+  builder.create<ReturnOp>(createToken.getResult(0));
 
   // Branch from the entry block to the cleanup block to create a valid CFG.
   builder.setInsertionPointToEnd(entryBlock);
 
-  builder.create<BranchOp>(loc, cleanupBlock);
+  builder.create<BranchOp>(cleanupBlock);
 
   // `async.await` op lowering will create resume blocks for async
   // continuations, and will conditionally branch to cleanup or suspend blocks.
@@ -471,8 +472,6 @@ outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) {
   MLIRContext *ctx = module.getContext();
   Location loc = execute.getLoc();
 
-  OpBuilder moduleBuilder(module.getBody()->getTerminator());
-
   // Collect all outlined function inputs.
   llvm::SetVector<mlir::Value> functionInputs(execute.dependencies().begin(),
                                               execute.dependencies().end());
@@ -484,13 +483,13 @@ outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) {
   SmallVector<Type, 4> inputTypes(typesRange.begin(), typesRange.end());
   auto outputTypes = execute.getResultTypes();
 
-  auto funcType = moduleBuilder.getFunctionType(inputTypes, outputTypes);
+  auto funcType = FunctionType::get(ctx, inputTypes, outputTypes);
   auto funcAttrs = ArrayRef<NamedAttribute>();
 
   // TODO: Derive outlined function name from the parent FuncOp (support
   // multiple nested async.execute operations).
   FuncOp func = FuncOp::create(loc, kAsyncFnPrefix, funcType, funcAttrs);
-  symbolTable.insert(func, moduleBuilder.getInsertionPoint());
+  symbolTable.insert(func, Block::iterator(module.getBody()->getTerminator()));
 
   SymbolTable::setSymbolVisibility(func, SymbolTable::Visibility::Private);
 
@@ -502,21 +501,21 @@ outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) {
   // Async execute API (execution will be resumed in a thread managed by the
   // async runtime).
   Block *entryBlock = &func.getBlocks().front();
-  OpBuilder builder = OpBuilder::atBlockTerminator(entryBlock);
+  auto builder = ImplicitLocOpBuilder::atBlockTerminator(loc, entryBlock);
 
   // A pointer to coroutine resume intrinsic wrapper.
   auto resumeFnTy = AsyncAPI::resumeFunctionType(ctx);
   auto resumePtr = builder.create<LLVM::AddressOfOp>(
-      loc, LLVM::LLVMPointerType::get(resumeFnTy), kResume);
+      LLVM::LLVMPointerType::get(resumeFnTy), kResume);
 
   // Save the coroutine state: @llvm.coro.save
   auto coroSave = builder.create<LLVM::CallOp>(
-      loc, LLVM::LLVMTokenType::get(ctx), builder.getSymbolRefAttr(kCoroSave),
+      LLVM::LLVMTokenType::get(ctx), builder.getSymbolRefAttr(kCoroSave),
       ValueRange({coro.coroHandle}));
 
   // Call async runtime API to execute a coroutine in the managed thread.
   SmallVector<Value, 2> executeArgs = {coro.coroHandle, resumePtr.res()};
-  builder.create<CallOp>(loc, TypeRange(), kExecute, executeArgs);
+  builder.create<CallOp>(TypeRange(), kExecute, executeArgs);
 
   // Split the entry block before the terminator.
   auto *terminatorOp = entryBlock->getTerminator();
@@ -528,7 +527,7 @@ outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) {
   // Await on all dependencies before starting to execute the body region.
   builder.setInsertionPointToStart(resume);
   for (size_t i = 0; i < execute.dependencies().size(); ++i)
-    builder.create<AwaitOp>(loc, func.getArgument(i));
+    builder.create<AwaitOp>(func.getArgument(i));
 
   // Map from function inputs defined above the execute op to the function
   // arguments.
@@ -540,17 +539,16 @@ outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) {
   // to async runtime to emplace the result token.
   for (Operation &op : execute.body().getOps()) {
     if (isa<async::YieldOp>(op)) {
-      builder.create<CallOp>(loc, kEmplaceToken, TypeRange(), coro.asyncToken);
+      builder.create<CallOp>(kEmplaceToken, TypeRange(), coro.asyncToken);
       continue;
     }
     builder.clone(op, valueMapping);
   }
 
   // Replace the original `async.execute` with a call to outlined function.
-  OpBuilder callBuilder(execute);
-  auto callOutlinedFunc =
-      callBuilder.create<CallOp>(loc, func.getName(), execute.getResultTypes(),
-                                 functionInputs.getArrayRef());
+  ImplicitLocOpBuilder callBuilder(loc, execute);
+  auto callOutlinedFunc = callBuilder.create<CallOp>(
+      func.getName(), execute.getResultTypes(), functionInputs.getArrayRef());
   execute.replaceAllUsesWith(callOutlinedFunc.getResults());
   execute.erase();
 
@@ -744,24 +742,24 @@ class AwaitOpLoweringBase : public ConversionPattern {
     if (isInCoroutine) {
       const CoroMachinery &coro = outlined->getSecond();
 
-      OpBuilder builder(op, rewriter.getListener());
+      ImplicitLocOpBuilder builder(loc, op, rewriter.getListener());
       MLIRContext *ctx = op->getContext();
 
       // A pointer to coroutine resume intrinsic wrapper.
       auto resumeFnTy = AsyncAPI::resumeFunctionType(ctx);
       auto resumePtr = builder.create<LLVM::AddressOfOp>(
-          loc, LLVM::LLVMPointerType::get(resumeFnTy), kResume);
+          LLVM::LLVMPointerType::get(resumeFnTy), kResume);
 
       // Save the coroutine state: @llvm.coro.save
       auto coroSave = builder.create<LLVM::CallOp>(
-          loc, LLVM::LLVMTokenType::get(ctx),
-          builder.getSymbolRefAttr(kCoroSave), ValueRange(coro.coroHandle));
+          LLVM::LLVMTokenType::get(ctx), builder.getSymbolRefAttr(kCoroSave),
+          ValueRange(coro.coroHandle));
 
       // Call async runtime API to resume a coroutine in the managed thread when
       // the async await argument becomes ready.
       SmallVector<Value, 3> awaitAndExecuteArgs = {operands[0], coro.coroHandle,
                                                    resumePtr.res()};
-      builder.create<CallOp>(loc, TypeRange(), coroAwaitFuncName,
+      builder.create<CallOp>(TypeRange(), coroAwaitFuncName,
                              awaitAndExecuteArgs);
 
       Block *suspended = op->getBlock();


        


More information about the Mlir-commits mailing list