[Mlir-commits] [mlir] a5aa783 - [mlir:Async][NFC] Update Async API to use prefixed accessors
River Riddle
llvmlistbot at llvm.org
Fri Sep 30 15:34:18 PDT 2022
Author: River Riddle
Date: 2022-09-30T15:27:10-07:00
New Revision: a5aa783685c10f1326a6cb0bb93ebab0c5a3e78d
URL: https://github.com/llvm/llvm-project/commit/a5aa783685c10f1326a6cb0bb93ebab0c5a3e78d
DIFF: https://github.com/llvm/llvm-project/commit/a5aa783685c10f1326a6cb0bb93ebab0c5a3e78d.diff
LOG: [mlir:Async][NFC] Update Async API to use prefixed accessors
This doesn't flip the switch for prefix generation yet, that'll be
done in a followup.
Added:
Modified:
mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
mlir/lib/Dialect/Async/IR/Async.cpp
mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCountingOpt.cpp
mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
index 238f558b90d36..97fc875cb7de4 100644
--- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
+++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
@@ -70,7 +70,7 @@ namespace {
/// Async Runtime API function types.
///
/// Because we can't create API function signature for type parametrized
-/// async.value type, we use opaque pointers (!llvm.ptr<i8>) instead. After
+/// async.value type, we use opaque pointers (!llvm.ptr<i8>) instead. After
/// lowering all async data types become opaque pointers at runtime.
struct AsyncAPI {
// All async types are lowered to opaque i8* LLVM pointers at runtime.
@@ -383,7 +383,7 @@ class CoroBeginOpConversion : public OpConversionPattern<CoroBeginOp> {
loc, allocFuncOp, ValueRange{coroAlign, coroSize});
// Begin a coroutine: @llvm.coro.begin.
- auto coroId = CoroBeginOpAdaptor(adaptor.getOperands()).id();
+ auto coroId = CoroBeginOpAdaptor(adaptor.getOperands()).getId();
rewriter.replaceOpWithNewOp<LLVM::CoroBeginOp>(
op, i8Ptr, ValueRange({coroId, coroAlloc.getResult()}));
@@ -439,7 +439,7 @@ class CoroEndOpConversion : public OpConversionPattern<CoroEndOp> {
op->getLoc(), rewriter.getI1Type(), rewriter.getBoolAttr(false));
// Mark the end of a coroutine: @llvm.coro.end.
- auto coroHdl = adaptor.handle();
+ auto coroHdl = adaptor.getHandle();
rewriter.create<LLVM::CoroEndOp>(op->getLoc(), rewriter.getI1Type(),
ValueRange({coroHdl, constFalse}));
rewriter.eraseOp(op);
@@ -516,7 +516,7 @@ class CoroSuspendOpConversion : public OpConversionPattern<CoroSuspendOp> {
loc, rewriter.getI1Type(), rewriter.getBoolAttr(false));
// Suspend a coroutine: @llvm.coro.suspend
- auto coroState = adaptor.state();
+ auto coroState = adaptor.getState();
auto coroSuspend = rewriter.create<LLVM::CoroSuspendOp>(
loc, i8, ValueRange({coroState, constFalse}));
@@ -526,11 +526,11 @@ class CoroSuspendOpConversion : public OpConversionPattern<CoroSuspendOp> {
// or suspend block of the coroutine (see @llvm.coro.suspend return code
// documentation).
llvm::SmallVector<int32_t, 2> caseValues = {0, 1};
- llvm::SmallVector<Block *, 2> caseDest = {op.resumeDest(),
- op.cleanupDest()};
+ llvm::SmallVector<Block *, 2> caseDest = {op.getResumeDest(),
+ op.getCleanupDest()};
rewriter.replaceOpWithNewOp<LLVM::SwitchOp>(
op, rewriter.create<LLVM::SExtOp>(loc, i32, coroSuspend.getResult()),
- /*defaultDestination=*/op.suspendDest(),
+ /*defaultDestination=*/op.getSuspendDest(),
/*defaultOperands=*/ValueRange(),
/*caseValues=*/caseValues,
/*caseDestinations=*/caseDest,
@@ -634,7 +634,7 @@ class RuntimeSetAvailableOpLowering
matchAndRewrite(RuntimeSetAvailableOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
StringRef apiFuncName =
- TypeSwitch<Type, StringRef>(op.operand().getType())
+ TypeSwitch<Type, StringRef>(op.getOperand().getType())
.Case<TokenType>([](Type) { return kEmplaceToken; })
.Case<ValueType>([](Type) { return kEmplaceValue; });
@@ -660,7 +660,7 @@ class RuntimeSetErrorOpLowering
matchAndRewrite(RuntimeSetErrorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
StringRef apiFuncName =
- TypeSwitch<Type, StringRef>(op.operand().getType())
+ TypeSwitch<Type, StringRef>(op.getOperand().getType())
.Case<TokenType>([](Type) { return kSetTokenError; })
.Case<ValueType>([](Type) { return kSetValueError; });
@@ -685,7 +685,7 @@ class RuntimeIsErrorOpLowering : public OpConversionPattern<RuntimeIsErrorOp> {
matchAndRewrite(RuntimeIsErrorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
StringRef apiFuncName =
- TypeSwitch<Type, StringRef>(op.operand().getType())
+ TypeSwitch<Type, StringRef>(op.getOperand().getType())
.Case<TokenType>([](Type) { return kIsTokenError; })
.Case<GroupType>([](Type) { return kIsGroupError; })
.Case<ValueType>([](Type) { return kIsValueError; });
@@ -710,7 +710,7 @@ class RuntimeAwaitOpLowering : public OpConversionPattern<RuntimeAwaitOp> {
matchAndRewrite(RuntimeAwaitOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
StringRef apiFuncName =
- TypeSwitch<Type, StringRef>(op.operand().getType())
+ TypeSwitch<Type, StringRef>(op.getOperand().getType())
.Case<TokenType>([](Type) { return kAwaitToken; })
.Case<ValueType>([](Type) { return kAwaitValue; })
.Case<GroupType>([](Type) { return kAwaitGroup; });
@@ -738,13 +738,13 @@ class RuntimeAwaitAndResumeOpLowering
matchAndRewrite(RuntimeAwaitAndResumeOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
StringRef apiFuncName =
- TypeSwitch<Type, StringRef>(op.operand().getType())
+ TypeSwitch<Type, StringRef>(op.getOperand().getType())
.Case<TokenType>([](Type) { return kAwaitTokenAndExecute; })
.Case<ValueType>([](Type) { return kAwaitValueAndExecute; })
.Case<GroupType>([](Type) { return kAwaitAllAndExecute; });
- Value operand = adaptor.operand();
- Value handle = adaptor.handle();
+ Value operand = adaptor.getOperand();
+ Value handle = adaptor.getHandle();
// A pointer to coroutine resume intrinsic wrapper.
addResumeFunction(op->getParentOfType<ModuleOp>());
@@ -781,7 +781,7 @@ class RuntimeResumeOpLowering : public OpConversionPattern<RuntimeResumeOp> {
op->getLoc(), LLVM::LLVMPointerType::get(resumeFnTy), kResume);
// Call async runtime API to execute a coroutine in the managed thread.
- auto coroHdl = adaptor.handle();
+ auto coroHdl = adaptor.getHandle();
rewriter.replaceOpWithNewOp<func::CallOp>(
op, TypeRange(), kExecute, ValueRange({coroHdl, resumePtr.getRes()}));
@@ -806,12 +806,12 @@ class RuntimeStoreOpLowering : public OpConversionPattern<RuntimeStoreOp> {
// Get a pointer to the async value storage from the runtime.
auto i8Ptr = AsyncAPI::opaquePointerType(rewriter.getContext());
- auto storage = adaptor.storage();
+ auto storage = adaptor.getStorage();
auto storagePtr = rewriter.create<func::CallOp>(loc, kGetValueStorage,
TypeRange(i8Ptr), storage);
// Cast from i8* to the LLVM pointer type.
- auto valueType = op.value().getType();
+ auto valueType = op.getValue().getType();
auto llvmValueType = getTypeConverter()->convertType(valueType);
if (!llvmValueType)
return rewriter.notifyMatchFailure(
@@ -822,7 +822,7 @@ class RuntimeStoreOpLowering : public OpConversionPattern<RuntimeStoreOp> {
storagePtr.getResult(0));
// Store the yielded value into the async value storage.
- auto value = adaptor.value();
+ auto value = adaptor.getValue();
rewriter.create<LLVM::StoreOp>(loc, value, castedStoragePtr.getResult());
// Erase the original runtime store operation.
@@ -849,12 +849,12 @@ class RuntimeLoadOpLowering : public OpConversionPattern<RuntimeLoadOp> {
// Get a pointer to the async value storage from the runtime.
auto i8Ptr = AsyncAPI::opaquePointerType(rewriter.getContext());
- auto storage = adaptor.storage();
+ auto storage = adaptor.getStorage();
auto storagePtr = rewriter.create<func::CallOp>(loc, kGetValueStorage,
TypeRange(i8Ptr), storage);
// Cast from i8* to the LLVM pointer type.
- auto valueType = op.result().getType();
+ auto valueType = op.getResult().getType();
auto llvmValueType = getTypeConverter()->convertType(valueType);
if (!llvmValueType)
return rewriter.notifyMatchFailure(
@@ -886,7 +886,7 @@ class RuntimeAddToGroupOpLowering
matchAndRewrite(RuntimeAddToGroupOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Currently we can only add tokens to the group.
- if (!op.operand().getType().isa<TokenType>())
+ if (!op.getOperand().getType().isa<TokenType>())
return rewriter.notifyMatchFailure(op, "only token type is supported");
// Replace with a runtime API function call.
@@ -941,9 +941,9 @@ class RefCountingOpLowering : public OpConversionPattern<RefCountingOp> {
ConversionPatternRewriter &rewriter) const override {
auto count = rewriter.create<arith::ConstantOp>(
op->getLoc(), rewriter.getI64Type(),
- rewriter.getI64IntegerAttr(op.count()));
+ rewriter.getI64IntegerAttr(op.getCount()));
- auto operand = adaptor.operand();
+ auto operand = adaptor.getOperand();
rewriter.replaceOpWithNewOp<func::CallOp>(op, TypeRange(), apiFunctionName,
ValueRange({operand, count}));
diff --git a/mlir/lib/Dialect/Async/IR/Async.cpp b/mlir/lib/Dialect/Async/IR/Async.cpp
index c31fbc4af87ce..4b5d6a1d78fe1 100644
--- a/mlir/lib/Dialect/Async/IR/Async.cpp
+++ b/mlir/lib/Dialect/Async/IR/Async.cpp
@@ -38,7 +38,7 @@ LogicalResult YieldOp::verify() {
// parent `async.execute` operation.
auto executeOp = (*this)->getParentOfType<ExecuteOp>();
auto types =
- llvm::map_range(executeOp.bodyResults(), [](const OpResult &result) {
+ llvm::map_range(executeOp.getBodyResults(), [](const OpResult &result) {
return result.getType().cast<ValueType>().getValueType();
});
@@ -62,7 +62,7 @@ constexpr char kOperandSegmentSizesAttr[] = "operand_segment_sizes";
OperandRange ExecuteOp::getSuccessorEntryOperands(Optional<unsigned> index) {
assert(index && *index == 0 && "invalid region index");
- return bodyOperands();
+ return getBodyOperands();
}
bool ExecuteOp::areTypesCompatible(Type lhs, Type rhs) {
@@ -80,13 +80,13 @@ void ExecuteOp::getSuccessorRegions(Optional<unsigned> index,
// The `body` region branch back to the parent operation.
if (index) {
assert(*index == 0 && "invalid region index");
- regions.push_back(RegionSuccessor(bodyResults()));
+ regions.push_back(RegionSuccessor(getBodyResults()));
return;
}
// Otherwise the successor is the body region.
regions.push_back(
- RegionSuccessor(&bodyRegion(), bodyRegion().getArguments()));
+ RegionSuccessor(&getBodyRegion(), getBodyRegion().getArguments()));
}
void ExecuteOp::build(OpBuilder &builder, OperationState &result,
@@ -136,17 +136,18 @@ void ExecuteOp::build(OpBuilder &builder, OperationState &result,
void ExecuteOp::print(OpAsmPrinter &p) {
// [%tokens,...]
- if (!dependencies().empty())
- p << " [" << dependencies() << "]";
+ if (!getDependencies().empty())
+ p << " [" << getDependencies() << "]";
// (%value as %unwrapped: !async.value<!arg.type>, ...)
- if (!bodyOperands().empty()) {
+ if (!getBodyOperands().empty()) {
p << " (";
- Block *entry = bodyRegion().empty() ? nullptr : &bodyRegion().front();
- llvm::interleaveComma(bodyOperands(), p, [&, n = 0](Value operand) mutable {
- Value argument = entry ? entry->getArgument(n++) : Value();
- p << operand << " as " << argument << ": " << operand.getType();
- });
+ Block *entry = getBodyRegion().empty() ? nullptr : &getBodyRegion().front();
+ llvm::interleaveComma(
+ getBodyOperands(), p, [&, n = 0](Value operand) mutable {
+ Value argument = entry ? entry->getArgument(n++) : Value();
+ p << operand << " as " << argument << ": " << operand.getType();
+ });
p << ")";
}
@@ -155,7 +156,7 @@ void ExecuteOp::print(OpAsmPrinter &p) {
p.printOptionalAttrDictWithKeyword((*this)->getAttrs(),
{kOperandSegmentSizesAttr});
p << ' ';
- p.printRegion(bodyRegion(), /*printEntryBlockArgs=*/false);
+ p.printRegion(getBodyRegion(), /*printEntryBlockArgs=*/false);
}
ParseResult ExecuteOp::parse(OpAsmParser &parser, OperationState &result) {
@@ -228,12 +229,12 @@ ParseResult ExecuteOp::parse(OpAsmParser &parser, OperationState &result) {
LogicalResult ExecuteOp::verifyRegions() {
// Unwrap async.execute value operands types.
- auto unwrappedTypes = llvm::map_range(bodyOperands(), [](Value operand) {
+ auto unwrappedTypes = llvm::map_range(getBodyOperands(), [](Value operand) {
return operand.getType().cast<ValueType>().getValueType();
});
// Verify that unwrapped argument types matches the body region arguments.
- if (bodyRegion().getArgumentTypes() != unwrappedTypes)
+ if (getBodyRegion().getArgumentTypes() != unwrappedTypes)
return emitOpError("async body region argument types do not match the "
"execute operation arguments types");
@@ -302,7 +303,7 @@ static void printAwaitResultType(OpAsmPrinter &p, Operation *op,
}
LogicalResult AwaitOp::verify() {
- Type argType = operand().getType();
+ Type argType = getOperand().getType();
// Awaiting on a token does not have any results.
if (argType.isa<TokenType>() && !getResultTypes().empty())
diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
index be9eff41d83e1..04977a98e6ba2 100644
--- a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
@@ -555,7 +555,7 @@ createAsyncDispatchFunction(ParallelComputeFunction &computeFunc,
// Create async.execute operation to dispatch half of the block range.
auto execute = b.create<ExecuteOp>(TypeRange(), ValueRange(), ValueRange(),
executeBodyBuilder);
- b.create<AddToGroupOp>(indexTy, execute.token(), group);
+ b.create<AddToGroupOp>(indexTy, execute.getToken(), group);
b.create<scf::YieldOp>(ValueRange({start, midIndex}));
}
@@ -702,7 +702,7 @@ doSequentialDispatch(ImplicitLocOpBuilder &b, PatternRewriter &rewriter,
// Create async.execute operation to launch parallel computate function.
auto execute = b.create<ExecuteOp>(TypeRange(), ValueRange(), ValueRange(),
executeBodyBuilder);
- b.create<AddToGroupOp>(rewriter.getIndexType(), execute.token(), group);
+ b.create<AddToGroupOp>(rewriter.getIndexType(), execute.getToken(), group);
b.create<scf::YieldOp>();
};
diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCountingOpt.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCountingOpt.cpp
index 5a8a9af6da9ce..e4504df93b82d 100644
--- a/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCountingOpt.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCountingOpt.cpp
@@ -112,7 +112,7 @@ LogicalResult AsyncRuntimeRefCountingOptPass::optimizeReferenceCounting(
for (RuntimeAddRefOp addRef : info.addRefs) {
for (RuntimeDropRefOp dropRef : info.dropRefs) {
// `drop_ref` operation after the `add_ref` with matching count.
- if (dropRef.count() != addRef.count() ||
+ if (dropRef.getCount() != addRef.getCount() ||
dropRef->isBeforeInBlock(addRef.getOperation()))
continue;
diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
index 1df732573cce1..b4880c0e3b3f5 100644
--- a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
@@ -74,7 +74,7 @@ struct CoroMachinery {
Value asyncToken; // token representing completion of the async region
llvm::SmallVector<Value, 4> returnValues; // returned async values
- Value coroHandle; // coroutine handle (!async.coro.handle value)
+ Value coroHandle; // coroutine handle (!async.coro.handle value)
Block *entry; // coroutine entry block
Block *setError; // switch completion token and all values to error state
Block *cleanup; // coroutine cleanup block
@@ -110,7 +110,7 @@ struct CoroMachinery {
/// ^entry(<function-arguments>):
/// %token = <async token> : !async.token // create async runtime token
/// %value = <async value> : !async.value<T> // create async value
-/// %id = async.coro.id // create a coroutine id
+/// %id = async.coro.id // create a coroutine id
/// %hdl = async.coro.begin %id // create a coroutine handle
/// cf.br ^preexisting_entry_block
///
@@ -142,18 +142,20 @@ static CoroMachinery setupCoroMachinery(func::FuncOp func) {
// ------------------------------------------------------------------------ //
// Allocate async token/values that we will return from a ramp function.
// ------------------------------------------------------------------------ //
- auto retToken = builder.create<RuntimeCreateOp>(TokenType::get(ctx)).result();
+ auto retToken =
+ builder.create<RuntimeCreateOp>(TokenType::get(ctx)).getResult();
llvm::SmallVector<Value, 4> retValues;
for (auto resType : func.getCallableResults().drop_front())
- retValues.emplace_back(builder.create<RuntimeCreateOp>(resType).result());
+ retValues.emplace_back(
+ builder.create<RuntimeCreateOp>(resType).getResult());
// ------------------------------------------------------------------------ //
// Initialize coroutine: get coroutine id and coroutine handle.
// ------------------------------------------------------------------------ //
auto coroIdOp = builder.create<CoroIdOp>(CoroIdType::get(ctx));
auto coroHdlOp =
- builder.create<CoroBeginOp>(CoroHandleType::get(ctx), coroIdOp.id());
+ builder.create<CoroBeginOp>(CoroHandleType::get(ctx), coroIdOp.getId());
builder.create<cf::BranchOp>(originalEntryBlock);
Block *cleanupBlock = func.addBlock();
@@ -163,7 +165,7 @@ static CoroMachinery setupCoroMachinery(func::FuncOp func) {
// Coroutine cleanup block: deallocate coroutine frame, free the memory.
// ------------------------------------------------------------------------ //
builder.setInsertionPointToStart(cleanupBlock);
- builder.create<CoroFreeOp>(coroIdOp.id(), coroHdlOp.handle());
+ builder.create<CoroFreeOp>(coroIdOp.getId(), coroHdlOp.getHandle());
// Branch into the suspend block.
builder.create<cf::BranchOp>(suspendBlock);
@@ -175,7 +177,7 @@ static CoroMachinery setupCoroMachinery(func::FuncOp func) {
builder.setInsertionPointToStart(suspendBlock);
// Mark the end of a coroutine: async.coro.end
- builder.create<CoroEndOp>(coroHdlOp.handle());
+ builder.create<CoroEndOp>(coroHdlOp.getHandle());
// Return created `async.token` and `async.values` from the suspend block.
// This will be the return value of a coroutine ramp function.
@@ -206,7 +208,7 @@ static CoroMachinery setupCoroMachinery(func::FuncOp func) {
machinery.func = func;
machinery.asyncToken = retToken;
machinery.returnValues = retValues;
- machinery.coroHandle = coroHdlOp.handle();
+ machinery.coroHandle = coroHdlOp.getHandle();
machinery.entry = entryBlock;
machinery.setError = nullptr; // created lazily only if needed
machinery.cleanup = cleanupBlock;
@@ -250,14 +252,14 @@ outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) {
// Make sure that all constants will be inside the outlined async function to
// reduce the number of function arguments.
- cloneConstantsIntoTheRegion(execute.bodyRegion());
+ cloneConstantsIntoTheRegion(execute.getBodyRegion());
// Collect all outlined function inputs.
- SetVector<mlir::Value> functionInputs(execute.dependencies().begin(),
- execute.dependencies().end());
- functionInputs.insert(execute.bodyOperands().begin(),
- execute.bodyOperands().end());
- getUsedValuesDefinedAbove(execute.bodyRegion(), functionInputs);
+ SetVector<mlir::Value> functionInputs(execute.getDependencies().begin(),
+ execute.getDependencies().end());
+ functionInputs.insert(execute.getBodyOperands().begin(),
+ execute.getBodyOperands().end());
+ getUsedValuesDefinedAbove(execute.getBodyRegion(), functionInputs);
// Collect types for the outlined function inputs and outputs.
auto typesRange = llvm::map_range(
@@ -279,8 +281,8 @@ outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) {
// Prepare for coroutine conversion by creating the body of the function.
{
- size_t numDependencies = execute.dependencies().size();
- size_t numOperands = execute.bodyOperands().size();
+ size_t numDependencies = execute.getDependencies().size();
+ size_t numOperands = execute.getBodyOperands().size();
// Await on all dependencies before starting to execute the body region.
for (size_t i = 0; i < numDependencies; ++i)
@@ -290,18 +292,18 @@ outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) {
SmallVector<Value, 4> unwrappedOperands(numOperands);
for (size_t i = 0; i < numOperands; ++i) {
Value operand = func.getArgument(numDependencies + i);
- unwrappedOperands[i] = builder.create<AwaitOp>(loc, operand).result();
+ unwrappedOperands[i] = builder.create<AwaitOp>(loc, operand).getResult();
}
// Map from function inputs defined above the execute op to the function
// arguments.
BlockAndValueMapping valueMapping;
valueMapping.map(functionInputs, func.getArguments());
- valueMapping.map(execute.bodyRegion().getArguments(), unwrappedOperands);
+ valueMapping.map(execute.getBodyRegion().getArguments(), unwrappedOperands);
// Clone all operations from the execute operation body into the outlined
// function body.
- for (Operation &op : execute.bodyRegion().getOps())
+ for (Operation &op : execute.getBodyRegion().getOps())
builder.clone(op, valueMapping);
}
@@ -324,7 +326,7 @@ outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) {
builder.create<RuntimeResumeOp>(coro.coroHandle);
// Add async.coro.suspend as a suspended block terminator.
- builder.create<CoroSuspendOp>(coroSaveOp.state(), coro.suspend,
+ builder.create<CoroSuspendOp>(coroSaveOp.getState(), coro.suspend,
branch.getDest(), coro.cleanup);
branch.erase();
@@ -402,7 +404,7 @@ class AwaitOpLoweringBase : public OpConversionPattern<AwaitType> {
ConversionPatternRewriter &rewriter) const override {
// We can only await on one the `AwaitableType` (for `await` it can be
// a `token` or a `value`, for `await_all` it must be a `group`).
- if (!op.operand().getType().template isa<AwaitableType>())
+ if (!op.getOperand().getType().template isa<AwaitableType>())
return rewriter.notifyMatchFailure(op, "unsupported awaitable type");
// Check if await operation is inside the outlined coroutine function.
@@ -411,7 +413,7 @@ class AwaitOpLoweringBase : public OpConversionPattern<AwaitType> {
const bool isInCoroutine = outlined != outlinedFunctions.end();
Location loc = op->getLoc();
- Value operand = adaptor.operand();
+ Value operand = adaptor.getOperand();
Type i1 = rewriter.getI1Type();
@@ -451,7 +453,7 @@ class AwaitOpLoweringBase : public OpConversionPattern<AwaitType> {
// Add async.coro.suspend as a suspended block terminator.
builder.setInsertionPointToEnd(suspended);
- builder.create<CoroSuspendOp>(coroSaveOp.state(), coro.suspend, resume,
+ builder.create<CoroSuspendOp>(coroSaveOp.getState(), coro.suspend, resume,
coro.cleanup);
// Split the resume block into error checking and continuation.
@@ -653,7 +655,7 @@ static void rewriteCallsiteForCoroutine(func::CallOp oldCall,
unwrappedResults.reserve(newCall->getResults().size() - 1);
for (Value result : newCall.getResults().drop_front())
unwrappedResults.push_back(
- callBuilder.create<AwaitOp>(loc, result).result());
+ callBuilder.create<AwaitOp>(loc, result).getResult());
// Careful, when result of a call is piped into another call this could lead
// to a dangling pointer.
oldCall.replaceAllUsesWith(unwrappedResults);
diff --git a/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp b/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp
index 1b55fab6d134b..ab0aa7cc2ed25 100644
--- a/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp
@@ -169,7 +169,7 @@ async::ExecuteOp addExecuteResults(async::ExecuteOp executeOp,
OpBuilder builder(executeOp);
auto newOp = builder.create<async::ExecuteOp>(
executeOp.getLoc(), TypeRange{resultTypes}.drop_front() /*drop token*/,
- executeOp.dependencies(), executeOp.bodyOperands());
+ executeOp.getDependencies(), executeOp.getBodyOperands());
BlockAndValueMapping mapper;
newOp.getRegion().getBlocks().clear();
executeOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
@@ -189,7 +189,7 @@ struct GpuAsyncRegionPass::DeferWaitCallback {
// ops, add the region's last `gpu.wait` op to the worklist if it is
// synchronous and is the last op with side effects.
void operator()(async::ExecuteOp executeOp) {
- if (!areAllUsersExecuteOrAwait(executeOp.token()))
+ if (!areAllUsersExecuteOrAwait(executeOp.getToken()))
return;
// async.execute's region is currently restricted to one block.
for (auto &op : llvm::reverse(executeOp.getBody()->without_terminator())) {
@@ -210,14 +210,14 @@ struct GpuAsyncRegionPass::DeferWaitCallback {
auto executeOp = waitOp->getParentOfType<async::ExecuteOp>();
// Erase `gpu.wait` and return async dependencies from execute op instead.
- SmallVector<Value, 4> dependencies = waitOp.asyncDependencies();
+ SmallVector<Value, 4> dependencies = waitOp.getAsyncDependencies();
waitOp.erase();
executeOp = addExecuteResults(executeOp, dependencies);
// Add the async dependency to each user of the `async.execute` token.
auto asyncTokens = executeOp.getResults().take_back(dependencies.size());
- SmallVector<Operation *, 4> users(executeOp.token().user_begin(),
- executeOp.token().user_end());
+ SmallVector<Operation *, 4> users(executeOp.getToken().user_begin(),
+ executeOp.getToken().user_end());
for (Operation *user : users)
addAsyncDependencyAfter(asyncTokens, user);
}
@@ -250,7 +250,7 @@ struct GpuAsyncRegionPass::DeferWaitCallback {
builder.setInsertionPointAfter(op);
for (auto asyncToken : asyncTokens)
tokens.push_back(
- builder.create<async::AwaitOp>(loc, asyncToken).result());
+ builder.create<async::AwaitOp>(loc, asyncToken).getResult());
// Set `it` after the inserted async.await ops.
it = builder.getInsertionPoint();
})
@@ -258,7 +258,7 @@ struct GpuAsyncRegionPass::DeferWaitCallback {
// Set `it` to the beginning of the region and add asyncTokens to the
// async.execute operands.
it = executeOp.getBody()->begin();
- executeOp.bodyOperandsMutable().append(asyncTokens);
+ executeOp.getBodyOperandsMutable().append(asyncTokens);
SmallVector<Type, 1> tokenTypes(
asyncTokens.size(), builder.getType<gpu::AsyncTokenType>());
SmallVector<Location, 1> tokenLocs(asyncTokens.size(),
@@ -287,7 +287,7 @@ struct GpuAsyncRegionPass::DeferWaitCallback {
// If the new waitOp is at the end of an async.execute region, add it to the
// worklist. 'operator()(executeOp)' would do the same, but this is faster.
auto executeOp = dyn_cast<async::ExecuteOp>(it->getParentOp());
- if (executeOp && areAllUsersExecuteOrAwait(executeOp.token()) &&
+ if (executeOp && areAllUsersExecuteOrAwait(executeOp.getToken()) &&
!it->getNextNode())
worklist.push_back(waitOp);
}
@@ -300,8 +300,8 @@ struct GpuAsyncRegionPass::DeferWaitCallback {
struct GpuAsyncRegionPass::SingleTokenUseCallback {
void operator()(async::ExecuteOp executeOp) {
// Extract !gpu.async.token results which have multiple uses.
- auto multiUseResults =
- llvm::make_filter_range(executeOp.bodyResults(), [](OpResult result) {
+ auto multiUseResults = llvm::make_filter_range(
+ executeOp.getBodyResults(), [](OpResult result) {
if (result.use_empty() || result.hasOneUse())
return false;
auto valueType = result.getType().dyn_cast<async::ValueType>();
@@ -319,16 +319,16 @@ struct GpuAsyncRegionPass::SingleTokenUseCallback {
});
for (auto index : indices) {
- assert(!executeOp.bodyResults()[index].getUses().empty());
+ assert(!executeOp.getBodyResults()[index].getUses().empty());
// Repeat async.yield token result, one for each use after the first one.
- auto uses = llvm::drop_begin(executeOp.bodyResults()[index].getUses());
+ auto uses = llvm::drop_begin(executeOp.getBodyResults()[index].getUses());
auto count = std::distance(uses.begin(), uses.end());
auto yieldOp = cast<async::YieldOp>(executeOp.getBody()->getTerminator());
SmallVector<Value, 4> operands(count, yieldOp.getOperand(index));
executeOp = addExecuteResults(executeOp, operands);
// Update 'uses' to refer to the new executeOp.
- uses = llvm::drop_begin(executeOp.bodyResults()[index].getUses());
- auto results = executeOp.bodyResults().take_back(count);
+ uses = llvm::drop_begin(executeOp.getBodyResults()[index].getUses());
+ auto results = executeOp.getBodyResults().take_back(count);
for (auto pair : llvm::zip(uses, results))
std::get<0>(pair).set(std::get<1>(pair));
}
More information about the Mlir-commits mailing list