[Mlir-commits] [mlir] [mlir][IR] Add builtin `TokenType` (PR #195640)
Matthias Springer
llvmlistbot at llvm.org
Mon May 25 17:31:26 PDT 2026
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/195640
>From 3db869216970c06ad49b0850b0e507577ddd7d75 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Tue, 26 May 2026 00:03:41 +0000
Subject: [PATCH 01/15] [mlir][async] Lazily create the coroutine
destroy-cleanup block
`setupCoroMachinery` previously emitted a `cleanupForDestroy` block
unconditionally, alongside the normal `cleanup` block. That block is
only ever used as the "destroy" successor of an `async.coro.suspend`,
so for coroutines that never suspend (e.g. an `async.func` whose body
contains no `async.await`) it ended up unreachable in the lowered CFG.
Make `cleanupForDestroy` mirror the existing `setError` pattern and
materialize it lazily via a new `setupCleanupForDestroyBlock` helper,
called only from the two places (`outlineExecuteOp` and the
`async.await` lowering) that actually wire it up. Store the coroutine
id on `CoroMachinery` so the helper can rebuild the block contents
without keeping the original `async.coro.id` op around.
Assisted-by: Opus 4.7
---
.../Async/Transforms/AsyncToAsyncRuntime.cpp | 48 +++++++++++++------
.../Dialect/Async/async-to-async-runtime.mlir | 33 +++++++++++++
2 files changed, 66 insertions(+), 15 deletions(-)
diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
index 0c5bcfe631c6c..6ed50671fb3b3 100644
--- a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
@@ -90,6 +90,7 @@ struct CoroMachinery {
std::optional<Value> asyncToken; // returned completion token
llvm::SmallVector<Value, 4> returnValues; // returned async values
+ Value coroId; // coroutine id (!async.coro.id value)
Value coroHandle; // coroutine handle (!async.coro.getHandle value)
Block *entry; // coroutine entry block
std::optional<Block *> setError; // set returned values to error state
@@ -115,7 +116,12 @@ struct CoroMachinery {
// If there is resume-specific cleanup logic, it can go into the Cleanup
// block but not the destroy block. Otherwise, it can fail block dominance
// check.
- Block *cleanupForDestroy;
+ //
+ // This block is created lazily by `setupCleanupForDestroyBlock` only when a
+ // suspension point needs a destroy successor, so that functions without any
+ // coroutine suspends (e.g. an `async.func` body with no `await`) don't end
+ // up with dead code.
+ std::optional<Block *> cleanupForDestroy;
Block *suspend; // coroutine suspension block
};
} // namespace
@@ -204,21 +210,16 @@ static CoroMachinery setupCoroMachinery(func::FuncOp func) {
cf::BranchOp::create(builder, originalEntryBlock);
Block *cleanupBlock = func.addBlock();
- Block *cleanupBlockForDestroy = func.addBlock();
Block *suspendBlock = func.addBlock();
// ------------------------------------------------------------------------ //
- // Coroutine cleanup blocks: deallocate coroutine frame, free the memory.
+ // Coroutine cleanup block: deallocate coroutine frame, free the memory.
// ------------------------------------------------------------------------ //
- auto buildCleanupBlock = [&](Block *cb) {
- builder.setInsertionPointToStart(cb);
- CoroFreeOp::create(builder, coroIdOp.getId(), coroHdlOp.getHandle());
-
- // Branch into the suspend block.
- cf::BranchOp::create(builder, suspendBlock);
- };
- buildCleanupBlock(cleanupBlock);
- buildCleanupBlock(cleanupBlockForDestroy);
+ // The matching "destroy" cleanup block is materialized lazily by
+ // `setupCleanupForDestroyBlock` only when a suspend point needs it.
+ builder.setInsertionPointToStart(cleanupBlock);
+ CoroFreeOp::create(builder, coroIdOp.getId(), coroHdlOp.getHandle());
+ cf::BranchOp::create(builder, suspendBlock);
// ------------------------------------------------------------------------ //
// Coroutine suspend block: mark the end of a coroutine and return allocated
@@ -249,11 +250,12 @@ static CoroMachinery setupCoroMachinery(func::FuncOp func) {
machinery.func = func;
machinery.asyncToken = retToken;
machinery.returnValues = retValues;
+ machinery.coroId = coroIdOp.getId();
machinery.coroHandle = coroHdlOp.getHandle();
machinery.entry = entryBlock;
machinery.setError = std::nullopt; // created lazily only if needed
machinery.cleanup = cleanupBlock;
- machinery.cleanupForDestroy = cleanupBlockForDestroy;
+ machinery.cleanupForDestroy = std::nullopt; // created lazily only if needed
machinery.suspend = suspendBlock;
return machinery;
}
@@ -283,6 +285,20 @@ static Block *setupSetErrorBlock(CoroMachinery &coro) {
return *coro.setError;
}
+// Returns the `cleanupForDestroy` block, creating it on first use right before
+// `coro.suspend`: it frees the coroutine frame and branches to `coro.suspend`.
+// Lazy creation avoids an unreachable block in coroutines that never suspend.
+static Block *setupCleanupForDestroyBlock(ImplicitLocOpBuilder &builder,
+ CoroMachinery &coro) {
+ if (coro.cleanupForDestroy)
+ return *coro.cleanupForDestroy;
+ OpBuilder::InsertionGuard guard(builder);
+ coro.cleanupForDestroy = builder.createBlock(coro.suspend);
+ CoroFreeOp::create(builder, coro.coroId, coro.coroHandle);
+ cf::BranchOp::create(builder, coro.suspend);
+ return *coro.cleanupForDestroy;
+}
+
//===----------------------------------------------------------------------===//
// async.execute op outlining to the coroutine functions.
//===----------------------------------------------------------------------===//
@@ -373,8 +389,9 @@ outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) {
RuntimeResumeOp::create(builder, coro.coroHandle);
// Add async.coro.suspend as a suspended block terminator.
+ Block *destroy = setupCleanupForDestroyBlock(builder, coro);
CoroSuspendOp::create(builder, coroSaveOp.getState(), coro.suspend,
- branch.getDest(), coro.cleanupForDestroy);
+ branch.getDest(), destroy);
branch.erase();
}
@@ -614,9 +631,10 @@ class AwaitOpLoweringBase : public OpConversionPattern<AwaitType> {
Block *resume = rewriter.splitBlock(suspended, Block::iterator(op));
// Add async.coro.suspend as a suspended block terminator.
+ Block *destroy = setupCleanupForDestroyBlock(builder, coro);
builder.setInsertionPointToEnd(suspended);
CoroSuspendOp::create(builder, coroSaveOp.getState(), coro.suspend,
- resume, coro.cleanupForDestroy);
+ resume, destroy);
// Split the resume block into error checking and continuation.
Block *continuation = rewriter.splitBlock(resume, Block::iterator(op));
diff --git a/mlir/test/Dialect/Async/async-to-async-runtime.mlir b/mlir/test/Dialect/Async/async-to-async-runtime.mlir
index 36583b2b94a3c..c7734aa044519 100644
--- a/mlir/test/Dialect/Async/async-to-async-runtime.mlir
+++ b/mlir/test/Dialect/Async/async-to-async-runtime.mlir
@@ -498,3 +498,36 @@ async.func @execute_in_async_func(%arg0: f32, %arg1: memref<1xf32>)
// CHECK-SAME: ) -> !async.token
// CHECK: %[[CST:.*]] = arith.constant 0 : index
// CHECK: memref.store %[[VALUE]], %[[MEMREF]][%[[CST]]]
+
+// -----
+
+// An async.func with no suspension points must not leave behind an
+// unreachable `cleanupForDestroy` block. The block is created lazily and is
+// only needed as the destroy successor of an `async.coro.suspend`, so an
+// empty body that never suspends should only have a single cleanup block.
+
+// CHECK-LABEL: @async_func_empty
+async.func @async_func_empty() -> !async.token {
+ return
+}
+// CHECK: %[[TOKEN:.*]] = async.runtime.create : !async.token
+// CHECK: %[[ID:.*]] = async.coro.id
+// CHECK: %[[HDL:.*]] = async.coro.begin
+// CHECK: cf.br ^[[ORIGIN_ENTRY:.*]]
+
+// CHECK: ^[[ORIGIN_ENTRY]]:
+// CHECK-NEXT: async.runtime.set_available %[[TOKEN]]
+// CHECK-NEXT: cf.br ^[[CLEANUP:.*]]
+
// CHECK: ^[[CLEANUP]]:
// CHECK-NEXT: async.coro.free %[[ID]], %[[HDL]]
// CHECK-NEXT: cf.br ^[[SUSPEND:.*]]

// There must be exactly one async.coro.free op in the lowered function. A
// destroy-cleanup block would be inserted right before the suspend block, so
// the negative check has to sit between the two labels (a CHECK-NOT after the
// final `return` only scans the rest of the output and would never see it).
// CHECK-NOT: async.coro.free
// CHECK: ^[[SUSPEND]]:
// CHECK-NEXT: async.coro.end %[[HDL]]
// CHECK-NEXT: return %[[TOKEN]]
>From dcd3cab086ef7e336c344113be6c7baf6b551002 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Mon, 4 May 2026 12:14:41 +0000
Subject: [PATCH 02/15] [mlir][IR] Add builtin `TokenType`
Add a builtin `TokenType` type (instead of a type interface) together with
its bytecode encoding.
---
mlir/docs/Tokens.md | 104 ++++++++++++++++++
mlir/include/mlir/Dialect/Async/IR/Async.h | 2 +-
.../include/mlir/Dialect/Async/IR/AsyncOps.td | 4 +-
.../include/mlir/IR/BuiltinDialectBytecode.td | 5 +-
mlir/include/mlir/IR/BuiltinTypes.td | 20 ++++
mlir/include/mlir/IR/CommonTypeConstraints.td | 20 +++-
mlir/lib/AsmParser/TokenKinds.def | 1 +
mlir/lib/AsmParser/TypeParser.cpp | 6 +
.../Conversion/AsyncToLLVM/AsyncToLLVM.cpp | 35 +++---
mlir/lib/Dialect/Async/IR/Async.cpp | 10 +-
.../Transforms/AsyncRuntimeRefCounting.cpp | 2 +-
.../Async/Transforms/AsyncToAsyncRuntime.cpp | 14 ++-
mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp | 60 ++++++----
mlir/lib/IR/AsmPrinter.cpp | 1 +
mlir/test/Dialect/ArmSME/invalid.mlir | 4 +-
.../Builtin/Bytecode/builtin_fixed.mlir | 4 +-
.../Builtin/Bytecode/builtin_fixed_0.mlirbc | Bin 4471 -> 4503 bytes
mlir/test/Dialect/Builtin/Bytecode/types.mlir | 10 ++
mlir/test/Dialect/Linalg/invalid.mlir | 4 +-
mlir/test/Dialect/MemRef/invalid.mlir | 4 +-
mlir/test/Dialect/SparseTensor/invalid.mlir | 24 ++--
mlir/test/Dialect/Tensor/invalid.mlir | 2 +-
mlir/test/Dialect/Vector/invalid.mlir | 10 +-
mlir/test/Dialect/traits.mlir | 2 +-
mlir/test/IR/operand.mlir | 6 +-
mlir/test/IR/result.mlir | 6 +-
mlir/test/IR/token-type.mlir | 60 ++++++++++
mlir/test/lib/Dialect/Test/TestOps.td | 30 +++++
mlir/test/mlir-tblgen/predicate.td | 4 +-
mlir/test/mlir-tblgen/types.mlir | 6 +-
30 files changed, 365 insertions(+), 95 deletions(-)
create mode 100644 mlir/docs/Tokens.md
create mode 100644 mlir/test/IR/token-type.mlir
diff --git a/mlir/docs/Tokens.md b/mlir/docs/Tokens.md
new file mode 100644
index 0000000000000..51d4005f4723c
--- /dev/null
+++ b/mlir/docs/Tokens.md
@@ -0,0 +1,104 @@
+# Tokens
+
+[TOC]
+
+## Overview
+
+Intuitively, a *token* value is a pointer to an operation (via an OpResult)
+or a pointer to a region (via an entry block argument). A token cannot be
+forwarded: a token def-use chain cannot be obscured by ops with forwarding
+semantics such as `arith.select` or `cf.br`. This allows you to always walk
+back from a use and say "this token came from *that* specific op".
+
+A token is an SSA value that has the builtin token type. The token type is
+parameterless, opaque and prints as `token`. A token carries no runtime data.
+Apart from the structural contract below, tokens are like any other SSA values.
+
+## Design Rationale
+
+The token type allows operations to refer to another operation without a new
+parallel def-use system for operations. The existing def-use machinery for SSA
+values can be reused. Moreover, no changes were needed for the generic op
+syntax, the bytecode infrastructure and core C++ APIs around `Operation`.
+
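As with regular SSA values, the dependency between a token use and its
definition is one-directional: the use depends on the token's definition, but
the definition does not semantically depend on its uses (even though the uses
can still be enumerated through the value's use list).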
+Transformations can remove the use of a token without having to touch or
+inspect the definition of the token. (Whether such a transformation is correct
+depends on the semantics of the token-producing and token-consuming ops.)
+
+Token-producing and token-consuming ops are subject to standard transformations
+such as CSE, DCE and hoisting. If such transformations are not desirable due to
+op semantics, common IR design patterns can be employed. To give a few examples:
+Terminators or ops with side effects are not CSE'd or DCE'd. Region block
+arguments semantically belong to the enclosing op and are never CSE'd, DCE'd or
+hoisted. Non-speculatability may also prevent hoisting.
+
+## Structural Contract
+
+1. A token must not appear as a forwarded value, e.g.:
+ * a forwarded result/operand of a `CallOpInterface` op,
* an argument or result type of a `FunctionOpInterface` op (tokens may
still be created inside a function body, e.g. as op results or as
non-forwarded entry block arguments of a nested op's region — what is
disallowed is forwarding tokens across the call/return boundary),
+ * a successor operand or successor block argument of a
+ `BranchOpInterface` op,
+ * a forwarded operand to/from any region of a `RegionBranchOpInterface`
+ op (iter-args, region results, yielded values), or
+ * the result of any op that selects or merges values it does not
+ understand (e.g. `arith.select`).
+
+2. As a consequence of (1), given a use of a token SSA value, its definition is
+ guaranteed to be the semantic producer of the token.
+
+3. A token cannot constant-fold. No constant of token type exists.
+
+These properties mirror what LLVM IR already documents for its own
+[`token` type](https://llvm.org/docs/LangRef.html#token-type).
+
+## ODS Integration
+
+Tokens are excluded from the default `AnyType` predicate, so an op that has
+not opted in cannot accept a token as an arbitrary operand or result. This
+restriction prevents tokens from being accidentally passed as operands with
+forwarding semantics.
+
+Three predicates are provided in `CommonTypeConstraints.td`:
+
+| Predicate | Accepts | Use when … |
+| ------------------ | ------------------------------------ | ----------------------------------------------------------------------|
+| `AnyType` | any non-token type | the default; matches the historical meaning of "any type" pre-tokens. |
+| `AnyTypeOrToken` | any type, including tokens | the op legitimately accepts arbitrary types (including tokens). |
+| `Token` | only the builtin `TokenType` | the op specifically takes a token operand/result. |
+
+Example:
+
+```tablegen
+def MyConsumeOp : MyDialect_Op<"consume"> {
+ let arguments = (ins Token:$scope, AnyType:$value);
+}
+```
+
+## Examples
+
+### Rejected: tokens in `AnyType` positions
+
+`scf.yield` operands have forwarding semantics. A token cannot be yielded from
+a branch or a loop.
+
+```mlir
+// error: 'scf.if' op result #0 must be variadic of any non-token type,
+// but got 'token'
+%t = scf.if %cond -> token {
+ %a = my.token.produce : token
+ scf.yield %a : token
+} else {
+ %b = my.token.produce : token
+ scf.yield %b : token
+}
+```
+
`scf.if`'s results are declared with `Variadic<AnyType>` and `scf.yield`'s
operands likewise use `AnyType`. Because `AnyType` excludes tokens by
default, yielding a token out of an `scf.if` (or passing a token through any
other op position that has not explicitly opted in via `AnyTypeOrToken`) is
rejected.
diff --git a/mlir/include/mlir/Dialect/Async/IR/Async.h b/mlir/include/mlir/Dialect/Async/IR/Async.h
index f16e87e71373a..fc0b086126f52 100644
--- a/mlir/include/mlir/Dialect/Async/IR/Async.h
+++ b/mlir/include/mlir/Dialect/Async/IR/Async.h
@@ -50,7 +50,7 @@ namespace async {
/// Returns true if the type is reference counted at runtime.
inline bool isRefCounted(Type type) {
- return isa<TokenType, ValueType, GroupType>(type);
+ return isa<async::TokenType, ValueType, GroupType>(type);
}
} // namespace async
diff --git a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
index 2cebeac767f29..058f58bda6433 100644
--- a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
+++ b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
@@ -174,7 +174,9 @@ def Async_FuncOp : Async_Op<"func",
unsigned getNumResults() {return getResultTypes().size();}
/// Is the async func stateful
- bool isStateful() { return isa<TokenType>(getFunctionType().getResult(0));}
+ bool isStateful() {
+ return isa<async::TokenType>(getFunctionType().getResult(0));
+ }
//===------------------------------------------------------------------===//
// OpAsmOpInterface Methods
diff --git a/mlir/include/mlir/IR/BuiltinDialectBytecode.td b/mlir/include/mlir/IR/BuiltinDialectBytecode.td
index 207b99164d0d6..213ba260cc4fd 100644
--- a/mlir/include/mlir/IR/BuiltinDialectBytecode.td
+++ b/mlir/include/mlir/IR/BuiltinDialectBytecode.td
@@ -324,6 +324,8 @@ def UnrankedTensorType : DialectType<(type
Type:$elementType
)>;
+def TokenType : DialectType<(type)>;
+
let cType = "VectorType" in {
def VectorType : DialectType<(type
Array<SignedVarIntList>:$shape,
@@ -413,7 +415,8 @@ def BuiltinDialectTypes : DialectTypes<"Builtin"> {
Float4E2M1FNType,
Float6E2M3FNType,
Float6E3M2FNType,
- Float8E8M0FNUType
+ Float8E8M0FNUType,
+ TokenType
];
}
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index 20c41c5f79729..3d36ffa5802c3 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -1237,6 +1237,26 @@ def Builtin_RankedTensor : Builtin_Type<"RankedTensor", "tensor", [
let genVerifyDecl = 1;
}
+//===----------------------------------------------------------------------===//
+// TokenType
+//===----------------------------------------------------------------------===//
+
+def Builtin_Token : Builtin_Type<"Token", "token"> {
+ let summary = "Token type";
+ let description = [{
+ Syntax:
+
+ ```
+ token-type ::= `token`
+ ```
+
+ A use of a token SSA value is a pointer to an operation (in case of an
+ OpResult) or a pointer to a region (in case of an entry block argument).
+ A token carries no runtime data and cannot be forwarded. Tokens are
+ excluded from the `AnyType` type constraint.
+ }];
+}
+
//===----------------------------------------------------------------------===//
// TupleType
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index 57caaae08462f..86066bcda73e6 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -165,8 +165,24 @@ class SameBuildabilityAs<Type type, code builder> {
code builderCall = !if(!empty(type.builderCall), "", builder);
}
-// Any type at all.
-def AnyType : Type<CPred<"true">, "any type">;
+// Whether a type is the builtin `TokenType`.
+def IsTokenTypePred : CPred<"::llvm::isa<::mlir::TokenType>($_self)">;
+
+// Any non-token type. Tokens are excluded by default to prevent ops that
+// accept arbitrary types from accidentally accepting tokens as operands /
+// results, since a token must not be value-forwarded. Ops that legitimately
+// want to accept any type, including tokens, should use `AnyTypeOrToken`
+// instead.
+def AnyType : Type<Neg<IsTokenTypePred>, "any non-token type">;
+
+// Any type at all, including tokens. Only for ops that explicitly opt in to
+// accepting tokens; avoid it for forwarded values (call/successor operands,
+// region-forwarded values), which must never be tokens (see docs/Tokens.md).
+def AnyTypeOrToken : Type<CPred<"true">, "any type">;
+
+// The builtin token type.
+def Token : Type<IsTokenTypePred, "token", "::mlir::TokenType">,
+ BuildableType<"$_builder.getType<::mlir::TokenType>()">;
// None type
def NoneType : Type<CPred<"::llvm::isa<::mlir::NoneType>($_self)">, "none type",
diff --git a/mlir/lib/AsmParser/TokenKinds.def b/mlir/lib/AsmParser/TokenKinds.def
index fe7c53753e156..f5e5c25832a30 100644
--- a/mlir/lib/AsmParser/TokenKinds.def
+++ b/mlir/lib/AsmParser/TokenKinds.def
@@ -127,6 +127,7 @@ TOK_KEYWORD(symbol)
TOK_KEYWORD(tensor)
TOK_KEYWORD(tf32)
TOK_KEYWORD(to)
+TOK_KEYWORD(token)
TOK_KEYWORD(true)
TOK_KEYWORD(tuple)
TOK_KEYWORD(type)
diff --git a/mlir/lib/AsmParser/TypeParser.cpp b/mlir/lib/AsmParser/TypeParser.cpp
index a461ebed967a8..2cdec14d65fa6 100644
--- a/mlir/lib/AsmParser/TypeParser.cpp
+++ b/mlir/lib/AsmParser/TypeParser.cpp
@@ -58,6 +58,7 @@ OptionalParseResult Parser::parseOptionalType(Type &type) {
case Token::kw_f128:
case Token::kw_index:
case Token::kw_none:
+ case Token::kw_token:
case Token::exclamation_identifier:
return failure(!(type = parseType()));
@@ -371,6 +372,11 @@ Type Parser::parseNonFunctionType() {
consumeToken(Token::kw_none);
return builder.getNoneType();
+ // token-type
+ case Token::kw_token:
+ consumeToken(Token::kw_token);
+ return builder.getType<TokenType>();
+
// extended type
case Token::exclamation_identifier:
return parseExtendedType();
diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
index 29e6552231f9c..7844c9dda877c 100644
--- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
+++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
@@ -89,7 +89,7 @@ struct AsyncAPI {
}
static FunctionType createTokenFunctionType(MLIRContext *ctx) {
- return FunctionType::get(ctx, {}, {TokenType::get(ctx)});
+ return FunctionType::get(ctx, {}, {async::TokenType::get(ctx)});
}
static FunctionType createValueFunctionType(MLIRContext *ctx) {
@@ -109,7 +109,7 @@ struct AsyncAPI {
}
static FunctionType emplaceTokenFunctionType(MLIRContext *ctx) {
- return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
+ return FunctionType::get(ctx, {async::TokenType::get(ctx)}, {});
}
static FunctionType emplaceValueFunctionType(MLIRContext *ctx) {
@@ -118,7 +118,7 @@ struct AsyncAPI {
}
static FunctionType setTokenErrorFunctionType(MLIRContext *ctx) {
- return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
+ return FunctionType::get(ctx, {async::TokenType::get(ctx)}, {});
}
static FunctionType setValueErrorFunctionType(MLIRContext *ctx) {
@@ -128,7 +128,7 @@ struct AsyncAPI {
static FunctionType isTokenErrorFunctionType(MLIRContext *ctx) {
auto i1 = IntegerType::get(ctx, 1);
- return FunctionType::get(ctx, {TokenType::get(ctx)}, {i1});
+ return FunctionType::get(ctx, {async::TokenType::get(ctx)}, {i1});
}
static FunctionType isValueErrorFunctionType(MLIRContext *ctx) {
@@ -143,7 +143,7 @@ struct AsyncAPI {
}
static FunctionType awaitTokenFunctionType(MLIRContext *ctx) {
- return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
+ return FunctionType::get(ctx, {async::TokenType::get(ctx)}, {});
}
static FunctionType awaitValueFunctionType(MLIRContext *ctx) {
@@ -162,13 +162,14 @@ struct AsyncAPI {
static FunctionType addTokenToGroupFunctionType(MLIRContext *ctx) {
auto i64 = IntegerType::get(ctx, 64);
- return FunctionType::get(ctx, {TokenType::get(ctx), GroupType::get(ctx)},
- {i64});
+ return FunctionType::get(
+ ctx, {async::TokenType::get(ctx), GroupType::get(ctx)}, {i64});
}
static FunctionType awaitTokenAndExecuteFunctionType(MLIRContext *ctx) {
auto ptrType = opaquePointerType(ctx);
- return FunctionType::get(ctx, {TokenType::get(ctx), ptrType, ptrType}, {});
+ return FunctionType::get(
+ ctx, {async::TokenType::get(ctx), ptrType, ptrType}, {});
}
static FunctionType awaitValueAndExecuteFunctionType(MLIRContext *ctx) {
@@ -291,7 +292,7 @@ class AsyncRuntimeTypeConverter : public TypeConverter {
}
static std::optional<Type> convertAsyncTypes(Type type) {
- if (isa<TokenType, GroupType, ValueType>(type))
+ if (isa<async::TokenType, GroupType, ValueType>(type))
return AsyncAPI::opaquePointerType(type.getContext());
if (isa<CoroIdType, CoroStateType>(type))
@@ -583,7 +584,7 @@ class RuntimeCreateOpLowering : public ConvertOpToLLVMPattern<RuntimeCreateOp> {
Type resultType = op->getResultTypes()[0];
// Tokens creation maps to a simple function call.
- if (isa<TokenType>(resultType)) {
+ if (isa<async::TokenType>(resultType)) {
rewriter.replaceOpWithNewOp<func::CallOp>(
op, kCreateToken, converter->convertType(resultType));
return success();
@@ -659,7 +660,7 @@ class RuntimeSetAvailableOpLowering
ConversionPatternRewriter &rewriter) const override {
StringRef apiFuncName =
TypeSwitch<Type, StringRef>(op.getOperand().getType())
- .Case<TokenType>([](Type) { return kEmplaceToken; })
+ .Case<async::TokenType>([](Type) { return kEmplaceToken; })
.Case<ValueType>([](Type) { return kEmplaceValue; });
rewriter.replaceOpWithNewOp<func::CallOp>(op, apiFuncName, TypeRange(),
@@ -685,7 +686,7 @@ class RuntimeSetErrorOpLowering
ConversionPatternRewriter &rewriter) const override {
StringRef apiFuncName =
TypeSwitch<Type, StringRef>(op.getOperand().getType())
- .Case<TokenType>([](Type) { return kSetTokenError; })
+ .Case<async::TokenType>([](Type) { return kSetTokenError; })
.Case<ValueType>([](Type) { return kSetValueError; });
rewriter.replaceOpWithNewOp<func::CallOp>(op, apiFuncName, TypeRange(),
@@ -710,7 +711,7 @@ class RuntimeIsErrorOpLowering : public OpConversionPattern<RuntimeIsErrorOp> {
ConversionPatternRewriter &rewriter) const override {
StringRef apiFuncName =
TypeSwitch<Type, StringRef>(op.getOperand().getType())
- .Case<TokenType>([](Type) { return kIsTokenError; })
+ .Case<async::TokenType>([](Type) { return kIsTokenError; })
.Case<GroupType>([](Type) { return kIsGroupError; })
.Case<ValueType>([](Type) { return kIsValueError; });
@@ -735,7 +736,7 @@ class RuntimeAwaitOpLowering : public OpConversionPattern<RuntimeAwaitOp> {
ConversionPatternRewriter &rewriter) const override {
StringRef apiFuncName =
TypeSwitch<Type, StringRef>(op.getOperand().getType())
- .Case<TokenType>([](Type) { return kAwaitToken; })
+ .Case<async::TokenType>([](Type) { return kAwaitToken; })
.Case<ValueType>([](Type) { return kAwaitValue; })
.Case<GroupType>([](Type) { return kAwaitGroup; });
@@ -763,7 +764,7 @@ class RuntimeAwaitAndResumeOpLowering
ConversionPatternRewriter &rewriter) const override {
StringRef apiFuncName =
TypeSwitch<Type, StringRef>(op.getOperand().getType())
- .Case<TokenType>([](Type) { return kAwaitTokenAndExecute; })
+ .Case<async::TokenType>([](Type) { return kAwaitTokenAndExecute; })
.Case<ValueType>([](Type) { return kAwaitValueAndExecute; })
.Case<GroupType>([](Type) { return kAwaitAllAndExecute; });
@@ -906,7 +907,7 @@ class RuntimeAddToGroupOpLowering
matchAndRewrite(RuntimeAddToGroupOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Currently we can only add tokens to the group.
- if (!isa<TokenType>(op.getOperand().getType()))
+ if (!isa<async::TokenType>(op.getOperand().getType()))
return rewriter.notifyMatchFailure(op, "only token type is supported");
// Replace with a runtime API function call.
@@ -1151,7 +1152,7 @@ class ConvertYieldOpTypes : public OpConversionPattern<async::YieldOp> {
void mlir::populateAsyncStructuralTypeConversionsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target) {
- typeConverter.addConversion([&](TokenType type) { return type; });
+ typeConverter.addConversion([&](async::TokenType type) { return type; });
typeConverter.addConversion([&](ValueType type) {
Type converted = typeConverter.convertType(type.getValueType());
return converted ? ValueType::get(converted) : converted;
diff --git a/mlir/lib/Dialect/Async/IR/Async.cpp b/mlir/lib/Dialect/Async/IR/Async.cpp
index 71be1d275280e..1713da07da60d 100644
--- a/mlir/lib/Dialect/Async/IR/Async.cpp
+++ b/mlir/lib/Dialect/Async/IR/Async.cpp
@@ -84,7 +84,7 @@ void ExecuteOp::build(OpBuilder &builder, OperationState &result,
// First result is always a token, and then `resultTypes` wrapped into
// `async.value`.
- result.addTypes({TokenType::get(result.getContext())});
+ result.addTypes({async::TokenType::get(result.getContext())});
for (Type type : resultTypes)
result.addTypes(ValueType::get(type));
@@ -139,7 +139,7 @@ ParseResult ExecuteOp::parse(OpAsmParser &parser, OperationState &result) {
// Sizes of parsed variadic operands, will be updated below after parsing.
int32_t numDependencies = 0;
- auto tokenTy = TokenType::get(ctx);
+ auto tokenTy = async::TokenType::get(ctx);
// Parse dependency tokens.
if (succeeded(parser.parseOptionalLSquare())) {
@@ -280,7 +280,7 @@ LogicalResult AwaitOp::verify() {
Type argType = getOperand().getType();
// Awaiting on a token does not have any results.
- if (llvm::isa<TokenType>(argType) && !getResultTypes().empty())
+ if (llvm::isa<async::TokenType>(argType) && !getResultTypes().empty())
return emitOpError("awaiting on a token must have empty result");
// Awaiting on a value unwraps the async value type.
@@ -345,12 +345,12 @@ LogicalResult FuncOp::verify() {
for (unsigned i = 0, e = resultTypes.size(); i != e; ++i) {
auto type = resultTypes[i];
- if (!llvm::isa<TokenType>(type) && !llvm::isa<ValueType>(type))
+ if (!llvm::isa<async::TokenType>(type) && !llvm::isa<ValueType>(type))
return emitOpError() << "result type must be async value type or async "
"token type, but got "
<< type;
// We only allow AsyncToken appear as the first return value
- if (llvm::isa<TokenType>(type) && i != 0) {
+ if (llvm::isa<async::TokenType>(type) && i != 0) {
return emitOpError()
<< " results' (optional) async token type is expected "
"to appear as the 1st return value, but got "
diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp
index 91e37dd9ac36e..2a726f3fd2999 100644
--- a/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp
@@ -526,7 +526,7 @@ void AsyncRuntimePolicyBasedRefCountingPass::initializeDefaultPolicy() {
Operation *op = operand.getOwner();
Type type = operand.get().getType();
- bool isToken = isa<TokenType>(type);
+ bool isToken = isa<async::TokenType>(type);
bool isGroup = isa<GroupType>(type);
bool isValue = isa<ValueType>(type);
diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
index 6ed50671fb3b3..92c285ee4df30 100644
--- a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
@@ -186,13 +186,14 @@ static CoroMachinery setupCoroMachinery(func::FuncOp func) {
// Allocate async token/values that we will return from a ramp function.
// ------------------------------------------------------------------------ //
- // We treat TokenType as state update marker to represent side-effects of
- // async computations
- bool isStateful = isa<TokenType>(func.getResultTypes().front());
+ // We treat async::TokenType as state update marker to represent
+ // side-effects of async computations
+ bool isStateful = isa<async::TokenType>(func.getResultTypes().front());
std::optional<Value> retToken;
if (isStateful)
- retToken.emplace(RuntimeCreateOp::create(builder, TokenType::get(ctx)));
+ retToken.emplace(
+ RuntimeCreateOp::create(builder, async::TokenType::get(ctx)));
llvm::SmallVector<Value, 4> retValues;
ArrayRef<Type> resValueTypes =
@@ -673,8 +674,9 @@ class AwaitOpLoweringBase : public OpConversionPattern<AwaitType> {
};
/// Lowering for `async.await` with a token operand.
-class AwaitTokenOpLowering : public AwaitOpLoweringBase<AwaitOp, TokenType> {
- using Base = AwaitOpLoweringBase<AwaitOp, TokenType>;
+class AwaitTokenOpLowering
+ : public AwaitOpLoweringBase<AwaitOp, async::TokenType> {
+ using Base = AwaitOpLoweringBase<AwaitOp, async::TokenType>;
public:
using Base::Base;
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
index 705d07d3e6c42..87df1e2df6102 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
@@ -238,13 +238,41 @@ Type LLVMStructType::parse(AsmParser &parser) {
}
/// Parses a type appearing inside another LLVM dialect-compatible type. This
-/// will try to parse any type in full form (including types with the `!llvm`
-/// prefix), and on failure fall back to parsing the short-hand version of the
-/// LLVM dialect types without the `!llvm` prefix.
+/// will first try to parse the LLVM dialect's short-hand keyword form (e.g.
+/// `token`, `void`, `ptr`, ...) and, failing that, fall back to parsing any
+/// MLIR type in full form (including types with the `!llvm` prefix).
+///
+/// Trying the short-hand form first matters because some LLVM short-hand
+/// keywords (notably `token`) collide with builtin type keywords whose
+/// semantics differ from the LLVM dialect's. Inside an `!llvm.<...>` type, the
+/// LLVM-specific meaning must always win.
static Type dispatchParse(AsmParser &parser, bool allowAny = true) {
SMLoc keyLoc = parser.getCurrentLocation();
+ MLIRContext *ctx = parser.getContext();
+
+ // Try parsing the LLVM dialect's short-hand keyword form first.
+ StringRef key;
+ if (succeeded(parser.parseOptionalKeyword(
+ &key, {"void", "ppc_fp128", "token", "label", "metadata", "func",
+ "ptr", "array", "struct", "target", "x86_amx"}))) {
+ // `parseOptionalKeyword` already restricted `key` to one of the cases
+ // below, so the `Default` is unreachable.
+ return StringSwitch<function_ref<Type()>>(key)
+ .Case("void", [&] { return LLVMVoidType::get(ctx); })
+ .Case("ppc_fp128", [&] { return LLVMPPCFP128Type::get(ctx); })
+ .Case("token", [&] { return LLVMTokenType::get(ctx); })
+ .Case("label", [&] { return LLVMLabelType::get(ctx); })
+ .Case("metadata", [&] { return LLVMMetadataType::get(ctx); })
+ .Case("func", [&] { return LLVMFunctionType::parse(parser); })
+ .Case("ptr", [&] { return LLVMPointerType::parse(parser); })
+ .Case("array", [&] { return LLVMArrayType::parse(parser); })
+ .Case("struct", [&] { return LLVMStructType::parse(parser); })
+ .Case("target", [&] { return LLVMTargetExtType::parse(parser); })
+ .Case("x86_amx", [&] { return LLVMX86AMXType::get(ctx); })
+ .Default([] { return Type(); })();
+ }
- // Try parsing any MLIR type.
+ // Otherwise, try parsing any MLIR type (only when allowed).
Type type;
OptionalParseResult result = parser.parseOptionalType(type);
if (result.has_value()) {
@@ -257,28 +285,12 @@ static Type dispatchParse(AsmParser &parser, bool allowAny = true) {
return type;
}
- // If no type found, fallback to the shorthand form.
- StringRef key;
+ // Neither a known LLVM short-hand keyword nor a parseable MLIR type. Parse
+ // the keyword (if any) so that the error message can name it.
if (failed(parser.parseKeyword(&key)))
return Type();
-
- MLIRContext *ctx = parser.getContext();
- return StringSwitch<function_ref<Type()>>(key)
- .Case("void", [&] { return LLVMVoidType::get(ctx); })
- .Case("ppc_fp128", [&] { return LLVMPPCFP128Type::get(ctx); })
- .Case("token", [&] { return LLVMTokenType::get(ctx); })
- .Case("label", [&] { return LLVMLabelType::get(ctx); })
- .Case("metadata", [&] { return LLVMMetadataType::get(ctx); })
- .Case("func", [&] { return LLVMFunctionType::parse(parser); })
- .Case("ptr", [&] { return LLVMPointerType::parse(parser); })
- .Case("array", [&] { return LLVMArrayType::parse(parser); })
- .Case("struct", [&] { return LLVMStructType::parse(parser); })
- .Case("target", [&] { return LLVMTargetExtType::parse(parser); })
- .Case("x86_amx", [&] { return LLVMX86AMXType::get(ctx); })
- .Default([&] {
- parser.emitError(keyLoc) << "unknown LLVM type: " << key;
- return Type();
- })();
+ parser.emitError(keyLoc) << "unknown LLVM type: " << key;
+ return Type();
}
/// Helper to use in parse lists.
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index ec270db189081..ca5c2d2a88ee5 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -2907,6 +2907,7 @@ void AsmPrinter::Impl::printTypeImpl(Type type) {
os << '>';
})
.Case<NoneType>([&](Type) { os << "none"; })
+ .Case<TokenType>([&](Type) { os << "token"; })
.Case([&](GraphType graphTy) {
os << '(';
interleaveComma(graphTy.getInputs(), [&](Type ty) { printType(ty); });
diff --git a/mlir/test/Dialect/ArmSME/invalid.mlir b/mlir/test/Dialect/ArmSME/invalid.mlir
index 8c5a098a0c785..f00945e18cc1f 100644
--- a/mlir/test/Dialect/ArmSME/invalid.mlir
+++ b/mlir/test/Dialect/ArmSME/invalid.mlir
@@ -132,7 +132,7 @@ func.func @arm_sme_tile_load__pad_but_no_mask(%src : memref<?x?xf64>, %pad : f64
func.func @arm_sme_tile_load__bad_memref_rank(%src : memref<?xf64>, %pad : f64) {
%c0 = arith.constant 0 : index
- // expected-error at +1 {{op operand #0 must be 2D memref of any type values, but got 'memref<?xf64>'}}
+ // expected-error at +1 {{op operand #0 must be 2D memref of any non-token type values, but got 'memref<?xf64>'}}
%tile = arm_sme.tile_load %src[%c0], %pad, : memref<?xf64>, vector<[2]x[2]xf64>
return
}
@@ -186,7 +186,7 @@ func.func @arm_sme_tile_store__bad_mask_type(%tile : vector<[16]x[16]xi8>, %mask
func.func @arm_sme_tile_store__bad_memref_rank(%tile : vector<[16]x[16]xi8>, %dest : memref<?xi8>) {
%c0 = arith.constant 0 : index
- // expected-error at +1 {{op operand #1 must be 2D memref of any type values, but got 'memref<?xi8>'}}
+ // expected-error at +1 {{op operand #1 must be 2D memref of any non-token type values, but got 'memref<?xi8>'}}
arm_sme.tile_store %tile, %dest[%c0] : memref<?xi8>, vector<[16]x[16]xi8>
return
}
diff --git a/mlir/test/Dialect/Builtin/Bytecode/builtin_fixed.mlir b/mlir/test/Dialect/Builtin/Bytecode/builtin_fixed.mlir
index 638689406ab17..31061a6fa8653 100644
--- a/mlir/test/Dialect/Builtin/Bytecode/builtin_fixed.mlir
+++ b/mlir/test/Dialect/Builtin/Bytecode/builtin_fixed.mlir
@@ -282,6 +282,7 @@ module @TestBasicTypes attributes {
// CHECK-DAG: bytecode.ui64 = ui64
// CHECK-DAG: bytecode.index = index
// CHECK-DAG: bytecode.none = none
+ // CHECK-DAG: bytecode.token = token
bytecode.i1 = i1,
bytecode.i8 = i8,
bytecode.i32 = i32,
@@ -289,7 +290,8 @@ module @TestBasicTypes attributes {
bytecode.si32 = si32,
bytecode.ui64 = ui64,
bytecode.index = index,
- bytecode.none = none
+ bytecode.none = none,
+ bytecode.token = token
} {} loc(unknown)
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Builtin/Bytecode/builtin_fixed_0.mlirbc b/mlir/test/Dialect/Builtin/Bytecode/builtin_fixed_0.mlirbc
index e2f58a42751a1cddc0dc2845d06735faafdc8b32..76a764de40eced87c6cc3d98b1aabf8153649989 100644
GIT binary patch
delta 551
zcmX}gOK1~89LI5Ie)FGMT_#JrG=aosX;OEI!Kw%zJXE5jzNk{fiUuuJ8j)ZXDIPqu
zUhKhxv>mYG!DA1CUIHEzdhEeeMaB9;MMOo#`ob0w9|OW3zWhG?7ITyNu@)6lD2}mk
zeECgGy6*Vn@=~wpMPABdH50NRqr6PUbt^$7Wyi}0c{k<zK2YZK^*%NR=9 at 1fz)EzO
z1ZB5wq}$r3yHFd at iB0IjcI*V*(<OSKH`7C1rbnRtIEX{2f}ZFt^h|G~=TOl}oPu9C
z4ZVY2>TYv_(j%|zeGca_feW|<y^CJ!9(tp9(>r<(T){QaN4<wWfxH_iOhfP0`{*;&
zE!@F<%;70sfWGPqeb)!*hwh_A&<D)p8x}x6bwB;mhv_%Kg%tF+ZGaKypk|R2X^|Dw
zB6)O#E#nE^%qhvEqeyegnpWA*ie^}LuwipmwzA8nA=gkXgQnVqnjKk>HYF4oV?C at F
z`PhS45qlI->?sstUq><a3|7T{i`B9JU~TM#Soh!3V+=W7T5^_3O^rswbUW>s&V^>k
tX=g$GqSGm9=y+{lus)kO!;4kP)Rb+`CKXd%wyJH^He*|`B^~=a>>r_ZaOMC2
delta 579
zcmX}gU5HF^9L90~|8t&mW at q}%X=a$I8D`p<WlWRxR(nBKLqihUmD$#`w1kFo;bLyo
zPRUDA+I4E(xp1M}xG=RBgpiT<koOR>LRQFZR+dxhTzq>zJ;yWS>A`7MoJnJsN)Amw
zuMQ18FSZ2ocGiyBs^ZL6skog;1Onw$Ip&U?No4}DNF)Ljb)&vT4Wbqr)8BPdPv}V<
z(lFsi5Me|SM-pkw0xw7lFUewF=2fwHO@!BF32#V<w<N+luoj>Li_nQ>SOMOXrMxc%
zK9uEr1a_l at jo1V}ks_bTO1^-V-HF|Z<n};T at uim4&R)iAuibwi4&pG5;uvHN-^e<?
zlP<pJ2jC=5fgfc(KZEvJ<VGMHq=#Q%T|omkaTgEp82l=|{3d1okUstd-(nP>Fb0mx
zX8w|X{sty63I1u_s`wZ58$ly%L=EbxI=fAo+OLC}VV&K9sAk;Bnf=ZYvt3o4bLI>+
z>^wE+v64Wj;)r#-){a)ou%MLcQfra%yd5(=4<PIL1ah9wAn*Ai=6HUExt^Pt at 7XjK
z{<mo>g`_S_IiL$6r_a}hfpps0?`x}{_H~-|X0WRBB~w|$ZVk7lTZg5rl3Uen#BJ2g
J^tr!J{R0}mc9Q at A
diff --git a/mlir/test/Dialect/Builtin/Bytecode/types.mlir b/mlir/test/Dialect/Builtin/Bytecode/types.mlir
index 5e421e2bf75bf..91d4512998b99 100644
--- a/mlir/test/Dialect/Builtin/Bytecode/types.mlir
+++ b/mlir/test/Dialect/Builtin/Bytecode/types.mlir
@@ -169,3 +169,13 @@ module @TestVector attributes {
bytecode.test = vector<8x8x128xi8>,
bytecode.test1 = vector<8x[8]xf32>
} {}
+
+//===----------------------------------------------------------------------===//
+// TokenType
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @TestToken
+module @TestToken attributes {
+ // CHECK: bytecode.test = token
+ bytecode.test = token
+} {}
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 06f3fcb41190b..a446cfcc4eec1 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -415,7 +415,7 @@ func.func @illegal_fill_memref_with_tensor_return
func.func @illegal_fill_tensor_with_memref_return
(%arg0 : tensor<?x?xf32>, %arg1 : f32) -> memref<?x?xf32>
{
- // expected-error @+1 {{result #0 must be variadic of ranked tensor of any type values, but got 'memref<?x?xf32>'}}
+ // expected-error @+1 {{result #0 must be variadic of ranked tensor of any non-token type values, but got 'memref<?x?xf32>'}}
%0 = linalg.fill ins(%arg1 : f32) outs(%arg0 : tensor<?x?xf32>) -> memref<?x?xf32>
return %0 : memref<?x?xf32>
}
@@ -468,7 +468,7 @@ func.func @invalid_scalar_input_matmul(%arg0: f32, %arg1: memref<3x4xf32>, %arg2
// -----
func.func @invalid_scalar_output_matmul(%arg0: memref<2x3xf32>, %arg1: memref<3x4xf32>, %arg2: f32) {
- // expected-error @+1 {{'linalg.matmul' op operand #2 must be variadic of shaped of any type values, but got 'f32'}}
+ // expected-error @+1 {{'linalg.matmul' op operand #2 must be variadic of shaped of any non-token type values, but got 'f32'}}
linalg.matmul ins(%arg0, %arg1 : memref<2x3xf32>, memref<3x4xf32>)
outs(%arg2 : f32)
return
diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index 2f061a1bb773e..ecffd683a98c2 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -1037,7 +1037,7 @@ func.func @test_alloc_memref_map_rank_mismatch() {
// -----
func.func @rank(%0: f32) {
- // expected-error at +1 {{'memref.rank' op operand #0 must be ranked or unranked memref of any type values}}
+ // expected-error at +1 {{'memref.rank' op operand #0 must be ranked or unranked memref of any non-token type values}}
"memref.rank"(%0): (f32)->index
return
}
@@ -1172,7 +1172,7 @@ func.func @memref_realloc_type(%src : memref<256xf32>) -> memref<?xi32>{
// Asking the dimension of a 0-D shape doesn't make sense.
func.func @dim_0_ranked(%arg : memref<f32>, %arg1 : index) {
- memref.dim %arg, %arg1 : memref<f32> // expected-error {{'memref.dim' op operand #0 must be unranked.memref of any type values or non-0-ranked.memref of any type values, but got 'memref<f32>'}}
+ memref.dim %arg, %arg1 : memref<f32> // expected-error {{'memref.dim' op operand #0 must be unranked.memref of any non-token type values or non-0-ranked.memref of any non-token type values, but got 'memref<f32>'}}
return
}
diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
index ae706b9b148a6..d14229b011f11 100644
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -1,7 +1,7 @@
// RUN: mlir-opt %s -split-input-file -verify-diagnostics
func.func @invalid_new_dense(%arg0: !llvm.ptr) -> tensor<32xf32> {
- // expected-error at +1 {{'sparse_tensor.new' op result #0 must be sparse tensor of any type values, but got 'tensor<32xf32>'}}
+ // expected-error at +1 {{'sparse_tensor.new' op result #0 must be sparse tensor of any non-token type values, but got 'tensor<32xf32>'}}
%0 = sparse_tensor.new %arg0 : !llvm.ptr to tensor<32xf32>
return %0 : tensor<32xf32>
}
@@ -96,7 +96,7 @@ func.func @invalid_unpack_mis_position(%sp: tensor<2x100xf64, #CSR>, %values: te
// -----
func.func @invalid_positions_dense(%arg0: tensor<128xf64>) -> memref<?xindex> {
- // expected-error at +1 {{'sparse_tensor.positions' op operand #0 must be sparse tensor of any type values, but got 'tensor<128xf64>'}}
+ // expected-error at +1 {{'sparse_tensor.positions' op operand #0 must be sparse tensor of any non-token type values, but got 'tensor<128xf64>'}}
%0 = sparse_tensor.positions %arg0 { level = 0 : index } : tensor<128xf64> to memref<?xindex>
return %0 : memref<?xindex>
}
@@ -104,7 +104,7 @@ func.func @invalid_positions_dense(%arg0: tensor<128xf64>) -> memref<?xindex> {
// -----
func.func @invalid_positions_unranked(%arg0: tensor<*xf64>) -> memref<?xindex> {
- // expected-error at +1 {{'sparse_tensor.positions' op operand #0 must be sparse tensor of any type values, but got 'tensor<*xf64>'}}
+ // expected-error at +1 {{'sparse_tensor.positions' op operand #0 must be sparse tensor of any non-token type values, but got 'tensor<*xf64>'}}
%0 = "sparse_tensor.positions"(%arg0) { level = 0 : index } : (tensor<*xf64>) -> (memref<?xindex>)
return %0 : memref<?xindex>
}
@@ -132,7 +132,7 @@ func.func @positions_oob(%arg0: tensor<128xf64, #SparseVector>) -> memref<?xinde
// -----
func.func @invalid_indices_dense(%arg0: tensor<10x10xi32>) -> memref<?xindex> {
- // expected-error at +1 {{'sparse_tensor.coordinates' op operand #0 must be sparse tensor of any type values, but got 'tensor<10x10xi32>'}}
+ // expected-error at +1 {{'sparse_tensor.coordinates' op operand #0 must be sparse tensor of any non-token type values, but got 'tensor<10x10xi32>'}}
%0 = sparse_tensor.coordinates %arg0 { level = 1 : index } : tensor<10x10xi32> to memref<?xindex>
return %0 : memref<?xindex>
}
@@ -140,7 +140,7 @@ func.func @invalid_indices_dense(%arg0: tensor<10x10xi32>) -> memref<?xindex> {
// -----
func.func @invalid_indices_unranked(%arg0: tensor<*xf64>) -> memref<?xindex> {
- // expected-error at +1 {{'sparse_tensor.coordinates' op operand #0 must be sparse tensor of any type values, but got 'tensor<*xf64>'}}
+ // expected-error at +1 {{'sparse_tensor.coordinates' op operand #0 must be sparse tensor of any non-token type values, but got 'tensor<*xf64>'}}
%0 = "sparse_tensor.coordinates"(%arg0) { level = 0 : index } : (tensor<*xf64>) -> (memref<?xindex>)
return %0 : memref<?xindex>
}
@@ -168,7 +168,7 @@ func.func @indices_oob(%arg0: tensor<128xf64, #SparseVector>) -> memref<?xindex>
// -----
func.func @invalid_values_dense(%arg0: tensor<1024xf32>) -> memref<?xf32> {
- // expected-error at +1 {{'sparse_tensor.values' op operand #0 must be sparse tensor of any type values, but got 'tensor<1024xf32>'}}
+ // expected-error at +1 {{'sparse_tensor.values' op operand #0 must be sparse tensor of any non-token type values, but got 'tensor<1024xf32>'}}
%0 = sparse_tensor.values %arg0 : tensor<1024xf32> to memref<?xf32>
return %0 : memref<?xf32>
}
@@ -186,7 +186,7 @@ func.func @indices_buffer_noncoo(%arg0: tensor<128xf64, #SparseVector>) -> memre
// -----
func.func @indices_buffer_dense(%arg0: tensor<1024xf32>) -> memref<?xindex> {
- // expected-error at +1 {{must be sparse tensor of any type values}}
+ // expected-error at +1 {{must be sparse tensor of any non-token type values}}
%0 = sparse_tensor.coordinates_buffer %arg0 : tensor<1024xf32> to memref<?xindex>
return %0 : memref<?xindex>
}
@@ -283,7 +283,7 @@ func.func @sparse_get_md(%arg0: !sparse_tensor.storage_specifier<#COO>) -> index
// -----
func.func @sparse_unannotated_load(%arg0: tensor<16x32xf64>) -> tensor<16x32xf64> {
- // expected-error at +1 {{'sparse_tensor.load' op operand #0 must be sparse tensor of any type values, but got 'tensor<16x32xf64>'}}
+ // expected-error at +1 {{'sparse_tensor.load' op operand #0 must be sparse tensor of any non-token type values, but got 'tensor<16x32xf64>'}}
%0 = sparse_tensor.load %arg0 : tensor<16x32xf64>
return %0 : tensor<16x32xf64>
}
@@ -308,7 +308,7 @@ func.func @sparse_push_back_n(%arg0: index, %arg1: memref<?xf32>, %arg2: f32) ->
// -----
func.func @sparse_unannotated_expansion(%arg0: tensor<128xf64>) {
- // expected-error at +1 {{'sparse_tensor.expand' op operand #0 must be sparse tensor of any type values, but got 'tensor<128xf64>'}}
+ // expected-error at +1 {{'sparse_tensor.expand' op operand #0 must be sparse tensor of any non-token type values, but got 'tensor<128xf64>'}}
%values, %filled, %added, %count = sparse_tensor.expand %arg0
: tensor<128xf64> to memref<?xf64>, memref<?xi1>, memref<?xindex>
return
@@ -322,7 +322,7 @@ func.func @sparse_unannotated_compression(%arg0: memref<?xf64>,
%arg3: index,
%arg4: tensor<8x8xf64>,
%arg5: index) {
- // expected-error at +1 {{'sparse_tensor.compress' op operand #4 must be sparse tensor of any type values, but got 'tensor<8x8xf64>'}}
+ // expected-error at +1 {{'sparse_tensor.compress' op operand #4 must be sparse tensor of any non-token type values, but got 'tensor<8x8xf64>'}}
sparse_tensor.compress %arg0, %arg1, %arg2, %arg3 into %arg4[%arg5]
: memref<?xf64>, memref<?xi1>, memref<?xindex>, tensor<8x8xf64>
return
@@ -375,7 +375,7 @@ func.func @sparse_convert_dim_mismatch(%arg0: tensor<10x?xf32>) -> tensor<10x10x
// -----
func.func @invalid_out_dense(%arg0: tensor<10xf64>, %arg1: !llvm.ptr) {
- // expected-error at +1 {{'sparse_tensor.out' op operand #0 must be sparse tensor of any type values, but got 'tensor<10xf64>'}}
+ // expected-error at +1 {{'sparse_tensor.out' op operand #0 must be sparse tensor of any non-token type values, but got 'tensor<10xf64>'}}
sparse_tensor.out %arg0, %arg1 : tensor<10xf64>, !llvm.ptr
return
}
@@ -1022,7 +1022,7 @@ func.func @sparse_reinterpret_map(%t0 : tensor<6x12xi32, #BSR>) -> tensor<3x4x2x
#CSR = #sparse_tensor.encoding<{map = (d0, d1) -> (d0 : compressed, d1 : compressed)}>
func.func @sparse_print(%arg0: tensor<10x10xf64>) {
- // expected-error at +1 {{'sparse_tensor.print' op operand #0 must be sparse tensor of any type values}}
+ // expected-error at +1 {{'sparse_tensor.print' op operand #0 must be sparse tensor of any non-token type values}}
sparse_tensor.print %arg0 : tensor<10x10xf64>
return
}
diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir
index 6ee2f9911663f..a526d7ed61722 100644
--- a/mlir/test/Dialect/Tensor/invalid.mlir
+++ b/mlir/test/Dialect/Tensor/invalid.mlir
@@ -404,7 +404,7 @@ func.func @illegal_collapsing_reshape_mixed_tensor_2(%arg0 : tensor<?x4x5xf32>)
// -----
func.func @rank(%0: f32) {
- // expected-error at +1 {{'tensor.rank' op operand #0 must be tensor of any type values}}
+ // expected-error at +1 {{'tensor.rank' op operand #0 must be tensor of any non-token type values}}
"tensor.rank"(%0): (f32)->index
return
}
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 662e4e8b5b561..2fed3002596a3 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -106,7 +106,7 @@ func.func @shuffle_index_out_of_range(%arg0: vector<2xf32>, %arg1: vector<2xf32>
// -----
func.func @shuffle_scalable_vec(%arg0: vector<[2]xf32>, %arg1: vector<[2]xf32>) {
- // expected-error at +1 {{'vector.shuffle' op operand #0 must be fixed-length vector of any type values}}
+ // expected-error at +1 {{'vector.shuffle' op operand #0 must be fixed-length vector of any non-token type values}}
%1 = vector.shuffle %arg0, %arg1 [0, 1, 2, 3] : vector<[2]xf32>, vector<[2]xf32>
}
@@ -1460,7 +1460,7 @@ func.func @maskedstore_memref_mismatch(%base: memref<?xf32>, %mask: vector<16xi1
func.func @gather_from_vector(%base: vector<16xf32>, %indices: vector<16xi32>,
%mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
%c0 = arith.constant 0 : index
- // expected-error at +1 {{'vector.gather' op operand #0 must be Tensor or MemRef of any type values, but got 'vector<16xf32>'}}
+ // expected-error at +1 {{'vector.gather' op operand #0 must be Tensor or MemRef of any non-token type values, but got 'vector<16xf32>'}}
%0 = vector.gather %base[%c0][%indices], %mask, %pass_thru
: vector<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
}
@@ -1557,7 +1557,7 @@ func.func @gather_tensor_alignment(%base: tensor<16xf32>, %indices: vector<16xi3
func.func @scatter_to_vector(%base: vector<16xf32>, %indices: vector<16xi32>,
%mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
%c0 = arith.constant 0 : index
- // expected-error at +1 {{'vector.scatter' op operand #0 must be Tensor or MemRef of any type values, but got 'vector<16xf32>'}}
+ // expected-error at +1 {{'vector.scatter' op operand #0 must be Tensor or MemRef of any non-token type values, but got 'vector<16xf32>'}}
vector.scatter %base[%c0][%indices], %mask, %pass_thru
: vector<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
}
@@ -1959,7 +1959,7 @@ func.func @invalid_outerproduct1(%src : memref<?xf32>, %lhs : vector<[4]x[4]xf32
// -----
func.func @deinterleave_zero_dim_fail(%vec : vector<f32>) {
- // expected-error @+1 {{'vector.deinterleave' op operand #0 must be vector of any type values, but got 'vector<f32>}}
+ // expected-error @+1 {{'vector.deinterleave' op operand #0 must be vector of any non-token type values, but got 'vector<f32>}}
%0, %1 = vector.deinterleave %vec : vector<f32> -> vector<f32>
return
}
@@ -2048,7 +2048,7 @@ func.func @from_elements_wrong_operand_type(%a: f32, %b: i32) {
// -----
func.func @invalid_from_elements_scalable(%a: f32, %b: i32) {
- // expected-error @+1 {{'dest' must be fixed-length vector of any type values, but got 'vector<[2]xf32>'}}
+ // expected-error @+1 {{'dest' must be fixed-length vector of any non-token type values, but got 'vector<[2]xf32>'}}
vector.from_elements %a, %b : vector<[2]xf32>
return
}
diff --git a/mlir/test/Dialect/traits.mlir b/mlir/test/Dialect/traits.mlir
index 4d583435adeee..ae48cadbf370f 100644
--- a/mlir/test/Dialect/traits.mlir
+++ b/mlir/test/Dialect/traits.mlir
@@ -58,7 +58,7 @@ func.func @broadcast_tensor_tensor_tensor(tensor<8x1x?x1xi32>, tensor<7x1x5xi32>
// Check incompatible vector and tensor result type
func.func @broadcast_scalar_vector_vector(tensor<4xf32>, tensor<4xf32>) -> vector<4xf32> {
^bb0(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>):
- // expected-error @+1 {{op result #0 must be tensor of any type values, but got 'vector<4xf32>'}}
+ // expected-error @+1 {{op result #0 must be tensor of any non-token type values, but got 'vector<4xf32>'}}
%0 = "test.broadcastable"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> vector<4xf32>
return %0 : vector<4xf32>
}
diff --git a/mlir/test/IR/operand.mlir b/mlir/test/IR/operand.mlir
index 507e37c775c0b..1ac12dc4b9556 100644
--- a/mlir/test/IR/operand.mlir
+++ b/mlir/test/IR/operand.mlir
@@ -13,7 +13,7 @@ func.func @correct_variadic_operand(%arg0: tensor<f32>, %arg1: f32) {
// -----
func.func @error_in_first_variadic_operand(%arg0: tensor<f32>, %arg1: f32) {
- // expected-error @+1 {{operand #1 must be variadic of tensor of any type}}
+ // expected-error @+1 {{operand #1 must be variadic of tensor of any non-token type}}
"test.mixed_normal_variadic_operand"(%arg0, %arg1, %arg0, %arg0, %arg0) : (tensor<f32>, f32, tensor<f32>, tensor<f32>, tensor<f32>) -> ()
return
}
@@ -21,7 +21,7 @@ func.func @error_in_first_variadic_operand(%arg0: tensor<f32>, %arg1: f32) {
// -----
func.func @error_in_normal_operand(%arg0: tensor<f32>, %arg1: f32) {
- // expected-error @+1 {{operand #2 must be tensor of any type}}
+ // expected-error @+1 {{operand #2 must be tensor of any non-token type}}
"test.mixed_normal_variadic_operand"(%arg0, %arg0, %arg1, %arg0, %arg0) : (tensor<f32>, tensor<f32>, f32, tensor<f32>, tensor<f32>) -> ()
return
}
@@ -29,7 +29,7 @@ func.func @error_in_normal_operand(%arg0: tensor<f32>, %arg1: f32) {
// -----
func.func @error_in_second_variadic_operand(%arg0: tensor<f32>, %arg1: f32) {
- // expected-error @+1 {{operand #3 must be variadic of tensor of any type}}
+ // expected-error @+1 {{operand #3 must be variadic of tensor of any non-token type}}
"test.mixed_normal_variadic_operand"(%arg0, %arg0, %arg0, %arg1, %arg0) : (tensor<f32>, tensor<f32>, tensor<f32>, f32, tensor<f32>) -> ()
return
}
diff --git a/mlir/test/IR/result.mlir b/mlir/test/IR/result.mlir
index 1e4eb3bede4c5..cdeae4202f0ff 100644
--- a/mlir/test/IR/result.mlir
+++ b/mlir/test/IR/result.mlir
@@ -13,7 +13,7 @@ func.func @correct_variadic_result() -> tensor<f32> {
// -----
func.func @error_in_first_variadic_result() -> tensor<f32> {
- // expected-error @+1 {{result #1 must be variadic of tensor of any type}}
+ // expected-error @+1 {{result #1 must be variadic of tensor of any non-token type}}
%0:5 = "test.mixed_normal_variadic_result"() : () -> (tensor<f32>, f32, tensor<f32>, tensor<f32>, tensor<f32>)
return %0#4 : tensor<f32>
}
@@ -21,7 +21,7 @@ func.func @error_in_first_variadic_result() -> tensor<f32> {
// -----
func.func @error_in_normal_result() -> tensor<f32> {
- // expected-error @+1 {{result #2 must be tensor of any type}}
+ // expected-error @+1 {{result #2 must be tensor of any non-token type}}
%0:5 = "test.mixed_normal_variadic_result"() : () -> (tensor<f32>, tensor<f32>, f32, tensor<f32>, tensor<f32>)
return %0#4 : tensor<f32>
}
@@ -29,7 +29,7 @@ func.func @error_in_normal_result() -> tensor<f32> {
// -----
func.func @error_in_second_variadic_result() -> tensor<f32> {
- // expected-error @+1 {{result #3 must be variadic of tensor of any type}}
+ // expected-error @+1 {{result #3 must be variadic of tensor of any non-token type}}
%0:5 = "test.mixed_normal_variadic_result"() : () -> (tensor<f32>, tensor<f32>, tensor<f32>, f32, tensor<f32>)
return %0#4 : tensor<f32>
}
diff --git a/mlir/test/IR/token-type.mlir b/mlir/test/IR/token-type.mlir
new file mode 100644
index 0000000000000..990a7917effe1
--- /dev/null
+++ b/mlir/test/IR/token-type.mlir
@@ -0,0 +1,60 @@
+// RUN: mlir-opt %s -verify-diagnostics -split-input-file | FileCheck %s
+
+// Tests for the builtin `token` type and the
+// `Token` / `AnyType` / `AnyTypeOrToken` ODS predicates.
+//
+// The default `AnyType` predicate excludes tokens, while `AnyTypeOrToken` and
+// `Token` accept them.
+
+// CHECK-LABEL: @token_produce_consume
+func.func @token_produce_consume() {
+ // CHECK: %[[T:.*]] = test.token.produce
+ %t = test.token.produce
+ // CHECK: test.token.consume %[[T]]
+ test.token.consume %t
+ // CHECK: test.token.any_or_token %[[T]] : token
+ test.token.any_or_token %t : token
+ return
+}
+
+// -----
+
+// `AnyTypeOrToken` also accepts non-token types.
+// CHECK-LABEL: @any_or_token_with_non_token
+func.func @any_or_token_with_non_token(%arg0: i32) {
+ // CHECK: test.token.any_or_token %{{.*}} : i32
+ test.token.any_or_token %arg0 : i32
+ return
+}
+
+// -----
+
+// `AnyType` accepts arbitrary non-token types.
+// CHECK-LABEL: @any_type_with_non_token
+func.func @any_type_with_non_token(%arg0: i32) {
+ // CHECK: test.token.any_type %{{.*}} : i32
+ test.token.any_type %arg0 : i32
+ return
+}
+
+// -----
+
+// `AnyType` rejects tokens by default.
+func.func @any_type_rejects_token() {
+ %t = test.token.produce
+ // expected-error @below {{operand #0 must be any non-token type}}
+ test.token.any_type %t : token
+ return
+}
+
+// -----
+
+// `Token` rejects non-token types. The op's operand type is fixed to the
+// builtin `token` (it's a `BuildableType`), so passing a non-token SSA value
+// fails at parse time with an SSA type mismatch.
+// expected-note @below {{prior use here}}
+func.func @token_rejects_non_token(%arg0: i32) {
+ // expected-error @below {{use of value '%arg0' expects different type than prior uses: 'token' vs 'i32'}}
+ test.token.consume %arg0
+ return
+}
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 56db6837b870c..eda75edbb69cd 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -110,6 +110,36 @@ def SignlessLikeVariadic : TEST_Op<"signless_like_variadic"> {
let arguments = (ins Variadic<SignlessIntegerLike>:$x);
}
+//===----------------------------------------------------------------------===//
+// Test Token Type
+//===----------------------------------------------------------------------===//
+
+// Produces a value of the builtin `token` type.
+def TestTokenProduceOp : TEST_Op<"token.produce"> {
+ let results = (outs Token:$token);
+ let assemblyFormat = "attr-dict";
+}
+
+// Consumes a value of the builtin `token` type (token-only operand).
+def TestTokenConsumeOp : TEST_Op<"token.consume"> {
+ let arguments = (ins Token:$token);
+ let assemblyFormat = "$token attr-dict";
+}
+
+// Accepts a value of any type, including the builtin `token` type, via the
+// opt-in `AnyTypeOrToken` predicate.
+def TestTokenAnyTypeOrTokenOp : TEST_Op<"token.any_or_token"> {
+ let arguments = (ins AnyTypeOrToken:$value);
+ let assemblyFormat = "$value attr-dict `:` type($value)";
+}
+
+// Uses the default `AnyType` predicate, which excludes tokens: the verifier
+// must reject a token value passed here.
+def TestTokenAnyTypeOp : TEST_Op<"token.any_type"> {
+ let arguments = (ins AnyType:$value);
+ let assemblyFormat = "$value attr-dict `:` type($value)";
+}
+
//===----------------------------------------------------------------------===//
// Test Symbols
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/mlir-tblgen/predicate.td b/mlir/test/mlir-tblgen/predicate.td
index 41e041f171213..ae436885b421f 100644
--- a/mlir/test/mlir-tblgen/predicate.td
+++ b/mlir/test/mlir-tblgen/predicate.td
@@ -27,9 +27,9 @@ def OpA : NS_Op<"op_for_CPred_containing_multiple_same_placeholder", []> {
// CHECK-NOT. << " must be 32-bit integer or floating-point type, but got " << type;
// CHECK: static ::llvm::LogicalResult [[$TENSOR_CONSTRAINT:__mlir_ods_local_type_constraint.*]](
-// CHECK: if (!(((::llvm::isa<::mlir::TensorType>(type))) && ([](::mlir::Type elementType) { return (true); }(::llvm::cast<::mlir::ShapedType>(type).getElementType())))) {
+// CHECK: if (!(((::llvm::isa<::mlir::TensorType>(type))) && ([](::mlir::Type elementType) { return !((::llvm::isa<::mlir::TokenType>(elementType))); }(::llvm::cast<::mlir::ShapedType>(type).getElementType())))) {
// CHECK-NEXT: return op->emitOpError(valueKind) << " #" << valueIndex
-// CHECK-NEXT: << " must be tensor of any type values, but got " << type;
+// CHECK-NEXT: << " must be tensor of any non-token type values, but got " << type;
// CHECK: static ::llvm::LogicalResult [[$TENSOR_INTEGER_FLOAT_CONSTRAINT:__mlir_ods_local_type_constraint.*]](
// CHECK: if (!(((::llvm::isa<::mlir::TensorType>(type))) && ([](::mlir::Type elementType) { return ((elementType.isF32())) || ((elementType.isSignlessInteger(32))); }(::llvm::cast<::mlir::ShapedType>(type).getElementType())))) {
diff --git a/mlir/test/mlir-tblgen/types.mlir b/mlir/test/mlir-tblgen/types.mlir
index c2acce0903bf4..30aea48e3e369 100644
--- a/mlir/test/mlir-tblgen/types.mlir
+++ b/mlir/test/mlir-tblgen/types.mlir
@@ -204,7 +204,7 @@ func.func @ranked_tensor_success(%arg0: tensor<i8>, %arg1: tensor<1xi32>, %arg2:
// -----
func.func @ranked_tensor_success(%arg0: tensor<*xf32>) {
- // expected-error @+1 {{must be ranked tensor of any type values}}
+ // expected-error @+1 {{must be ranked tensor of any non-token type values}}
"test.ranked_tensor_op"(%arg0) : (tensor<*xf32>) -> ()
return
}
@@ -212,7 +212,7 @@ func.func @ranked_tensor_success(%arg0: tensor<*xf32>) {
// -----
func.func @ranked_tensor_success(%arg0: vector<2xf32>) {
- // expected-error @+1 {{must be ranked tensor of any type values}}
+ // expected-error @+1 {{must be ranked tensor of any non-token type values}}
"test.ranked_tensor_op"(%arg0) : (vector<2xf32>) -> ()
return
}
@@ -510,7 +510,7 @@ func.func @does_not_have_i32(%arg0: tensor<1x2xi32>, %arg1: none) {
// -----
func.func @does_not_have_static_memref(%arg0: memref<?xi32>) {
- // expected-error at +1 {{'test.takes_static_memref' op operand #0 must be statically shaped memref of any type values}}
+ // expected-error at +1 {{'test.takes_static_memref' op operand #0 must be statically shaped memref of any non-token type values}}
"test.takes_static_memref"(%arg0) : (memref<?xi32>) -> ()
}
>From 01a0f8890bd232f23b6b58dcbb91159289d1a621 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Tue, 12 May 2026 11:27:35 +0000
Subject: [PATCH 03/15] address comments
---
mlir/docs/LangRef.md | 3 +-
mlir/docs/Tokens.md | 38 +++++++++----------
mlir/include/mlir/IR/CommonTypeConstraints.td | 9 +----
mlir/test/IR/token-type.mlir | 19 +---------
mlir/test/lib/Dialect/Test/TestOps.td | 7 ----
5 files changed, 22 insertions(+), 54 deletions(-)
diff --git a/mlir/docs/LangRef.md b/mlir/docs/LangRef.md
index 5e53df83997e2..185321ed5cfc6 100644
--- a/mlir/docs/LangRef.md
+++ b/mlir/docs/LangRef.md
@@ -738,7 +738,8 @@ dialect types.
The [builtin dialect](Dialects/Builtin.md) defines a set of types that are
directly usable by any other dialect in MLIR. These types cover a range from
-primitive integer and floating-point types, function types, and more.
+primitive integer and floating-point types, function types,
+[tokens](Tokens.md), and more.
## Properties
diff --git a/mlir/docs/Tokens.md b/mlir/docs/Tokens.md
index 51d4005f4723c..b33cece47fd06 100644
--- a/mlir/docs/Tokens.md
+++ b/mlir/docs/Tokens.md
@@ -10,32 +10,29 @@ forwarded: a token def-use chain cannot be obscured by ops with forwarding
semantics such as `arith.select` or `cf.br`. This allows you to always walk
back from a use and say "this token came from *that* specific op".
-A token is an SSA value that has the builtin token type. The token type is
-parameterless, opaque and prints as `token`. A token carries no runtime data.
+A token is an SSA value that has the builtin token type. The token type is
+parameterless, opaque and carries no runtime data. A token prints as `token`.
Apart from the structural contract below, tokens are like any other SSA values.
## Design Rationale
The token type allows operations to refer to another operation without a new
-parallel def-use system for operations. The existing def-use machinery for SSA
-values can be reused. Moreover, no changes were needed for the generic op
-syntax, the bytecode infrastructure and core C++ APIs around `Operation`.
+parallel def-use system for operations. It reuses the existing def-use
+machinery for SSA values. It requires no changes to the generic op syntax,
+the bytecode infrastructure or the core C++ APIs around `Operation`.
As with regular use-def chains, a token def-use chain is unidirectional. A
-token use points to the token's definition. (But not the other way around.)
+token use points to the token's definition and not the other way around.
Transformations can remove the use of a token without having to touch or
-inspect the definition of the token. (Whether such a transformation is correct
-depends on the semantics of the token-producing and token-consuming ops.)
-
-Token-producing and token-consuming ops are subject to standard transformations
-such as CSE, DCE and hoisting. If such transformations are not desirable due to
-op semantics, common IR design patterns can be employed. To give a few examples:
-Terminators or ops with side effects are not CSE'd or DCE'd. Region block
-arguments semantically belong to the enclosing op and are never CSE'd, DCE'd or
-hoisted. Non-speculatability may also prevent hoisting.
+inspect the definition of the token.
## Structural Contract
+A token use cannot be substituted with another token value: the use of a token
+points directly to a specific producer. Generic transformations must not alter
+or break this link. New uses of a token can be introduced safely. As a
+consequence:
+
1. A token must not appear as a forwarded value, e.g.:
* a forwarded result/operand of a `CallOpInterface` op,
* an argument or result type of a `FunctionOpInterface` op (a token
@@ -48,8 +45,8 @@ hoisted. Non-speculatability may also prevent hoisting.
* the result of any op that selects or merges values it does not
understand (e.g. `arith.select`).
-2. As a consequence of (1), given a use of a token SSA value, its definition is
- guaranteed to be the semantic producer of the token.
+2. Given a use of a token SSA value, its definition is guaranteed to be the
+ semantic producer of the token.
3. A token cannot constant-fold. No constant of token type exists.
@@ -68,7 +65,6 @@ Three predicates are provided in `CommonTypeConstraints.td`:
| Predicate | Accepts | Use when … |
| ------------------ | ------------------------------------ | ----------------------------------------------------------------------|
| `AnyType` | any non-token type | the default; matches the historical meaning of "any type" pre-tokens. |
-| `AnyTypeOrToken` | any type, including tokens | the op legitimately accepts arbitrary types (including tokens). |
| `Token` | only the builtin `TokenType` | the op specifically takes a token operand/result. |
Example:
@@ -99,6 +95,6 @@ a branch or a loop.
```
`scf.if`'s results are declared with `Variadic<AnyType>` and `scf.yield`'s
-operands likewise use `AnyType`. Because `AnyType` excludes tokens by
-default, yielding (or returning) a token through a `scf.if` (or any other
-op that has not explicitly opted in via `AnyTypeOrToken`) is rejected.
+operands likewise use `AnyType`. Because `AnyType` excludes tokens, yielding
+(or returning) a token through `scf.if` (or any other op that has not
+explicitly opted in) is rejected.
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index 86066bcda73e6..d8615a4730c32 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -170,16 +170,9 @@ def IsTokenTypePred : CPred<"::llvm::isa<::mlir::TokenType>($_self)">;
// Any non-token type. Tokens are excluded by default to prevent ops that
// accept arbitrary types from accidentally accepting tokens as operands /
-// results, since a token must not be value-forwarded. Ops that legitimately
-// want to accept any type, including tokens, should use `AnyTypeOrToken`
-// instead.
+// results, since a token must not be value-forwarded.
def AnyType : Type<Neg<IsTokenTypePred>, "any non-token type">;
-// Any type at all, including tokens. Used by ops that explicitly opt in to
-// accepting tokens (e.g. ops in interfaces such as `CallOpInterface`,
-// `BranchOpInterface`, etc. that legitimately handle arbitrary types).
-def AnyTypeOrToken : Type<CPred<"true">, "any type">;
-
// The builtin token type.
def Token : Type<IsTokenTypePred, "token", "::mlir::TokenType">,
BuildableType<"$_builder.getType<::mlir::TokenType>()">;
diff --git a/mlir/test/IR/token-type.mlir b/mlir/test/IR/token-type.mlir
index 990a7917effe1..0e46daed248fb 100644
--- a/mlir/test/IR/token-type.mlir
+++ b/mlir/test/IR/token-type.mlir
@@ -1,10 +1,7 @@
// RUN: mlir-opt %s -verify-diagnostics -split-input-file | FileCheck %s
-// Tests for the builtin `token` type and the
-// `Token` / `AnyType` / `AnyTypeOrToken` ODS predicates.
-//
-// The default `AnyType` predicate excludes tokens, while `AnyTypeOrToken` and
-// `Token` accept them.
+// Tests for the builtin `token` type and the `Token`, `AnyType` ODS predicates.
+// The default `AnyType` predicate excludes tokens.
// CHECK-LABEL: @token_produce_consume
func.func @token_produce_consume() {
@@ -12,18 +9,6 @@ func.func @token_produce_consume() {
%t = test.token.produce
// CHECK: test.token.consume %[[T]]
test.token.consume %t
- // CHECK: test.token.any_or_token %[[T]] : token
- test.token.any_or_token %t : token
- return
-}
-
-// -----
-
-// `AnyTypeOrToken` also accepts non-token types.
-// CHECK-LABEL: @any_or_token_with_non_token
-func.func @any_or_token_with_non_token(%arg0: i32) {
- // CHECK: test.token.any_or_token %{{.*}} : i32
- test.token.any_or_token %arg0 : i32
return
}
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index eda75edbb69cd..21dd4727a49ed 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -126,13 +126,6 @@ def TestTokenConsumeOp : TEST_Op<"token.consume"> {
let assemblyFormat = "$token attr-dict";
}
-// Op that accepts any type, including a token. Uses the `AnyTypeOrToken`
-// opt-in predicate.
-def TestTokenAnyTypeOrTokenOp : TEST_Op<"token.any_or_token"> {
- let arguments = (ins AnyTypeOrToken:$value);
- let assemblyFormat = "$value attr-dict `:` type($value)";
-}
-
// Op that uses the default `AnyType` predicate. Tokens are excluded by
// default and should be rejected by the verifier when passed here.
def TestTokenAnyTypeOp : TEST_Op<"token.any_type"> {
>From 50fded248cbd8d4d7fadd1f1f52c8f4430f4e5f4 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Wed, 13 May 2026 11:39:51 +0200
Subject: [PATCH 04/15] Update mlir/docs/Tokens.md
Co-authored-by: Mehdi Amini <joker.eph at gmail.com>
---
mlir/docs/Tokens.md | 2 ++
1 file changed, 2 insertions(+)
diff --git a/mlir/docs/Tokens.md b/mlir/docs/Tokens.md
index b33cece47fd06..121b1f2dfc8b8 100644
--- a/mlir/docs/Tokens.md
+++ b/mlir/docs/Tokens.md
@@ -50,6 +50,8 @@ consequence:
3. A token cannot constant-fold. No constant of token type exists.
+4. Consuming a token is not a side effect: token users follow the usual `isTriviallyDead()` rules.
+
These properties mirror what LLVM IR already documents for its own
[`token` type](https://llvm.org/docs/LangRef.html#token-type).
>From 6cca0377549265293a30130b342b14f38a37273f Mon Sep 17 00:00:00 2001
From: Mehdi Amini <joker.eph at gmail.com>
Date: Wed, 13 May 2026 02:00:52 -0700
Subject: [PATCH 05/15] [mlir][IR] Require token producer and consumer traits
Add marker traits for operations that intentionally produce or consume the
builtin token type. The verifier now rejects token results without
TokenProducerTrait, token operands without TokenConsumerTrait, token entry
block arguments whose parent op lacks TokenProducerTrait, and token block
arguments outside entry blocks.
Extend the Test dialect token ops to cover valid opt-in cases and each
verifier rejection path.
Assisted-by: Codex
---
mlir/docs/Tokens.md | 21 ++++--
mlir/docs/Traits/_index.md | 14 ++++
mlir/include/mlir/IR/BuiltinTypes.td | 4 +-
mlir/include/mlir/IR/OpBase.td | 4 +
mlir/include/mlir/IR/OpDefinition.h | 12 +++
mlir/lib/IR/Verifier.cpp | 96 +++++++++++++++++++++++-
mlir/test/IR/token-type.mlir | 102 +++++++++++++++++++++++++-
mlir/test/lib/Dialect/Test/TestOps.td | 31 +++++++-
8 files changed, 268 insertions(+), 16 deletions(-)
diff --git a/mlir/docs/Tokens.md b/mlir/docs/Tokens.md
index 121b1f2dfc8b8..9be9aaf8ec9de 100644
--- a/mlir/docs/Tokens.md
+++ b/mlir/docs/Tokens.md
@@ -30,14 +30,13 @@ inspect the definition of the token.
A token use cannot be substituted with another token value: the use of a token
points directly to a specific producer. Generic transformations must not alter
-or break this link. New uses of a token can be introduced safely. As a
-consequence:
+or break this link. New uses of a token can be introduced safely. Operations
+must opt in to producing or consuming tokens with `TokenProducerTrait` and
+`TokenConsumerTrait`. As a consequence:
1. A token must not appear as a forwarded value, e.g.:
* a forwarded result/operand of a `CallOpInterface` op,
- * an argument or result type of a `FunctionOpInterface` op (a token
- block argument *inside* a function body is fine — what is disallowed
- is forwarding tokens across the call/return boundary),
+ * an argument or result type of a `FunctionOpInterface` op,
* a successor operand or successor block argument of a
`BranchOpInterface` op,
* a forwarded operand to/from any region of a `RegionBranchOpInterface`
@@ -62,7 +61,7 @@ not opted in cannot accept a token as an arbitrary operand or result. This
restriction prevents tokens from being accidentally passed as operands with
forwarding semantics.
-Three predicates are provided in `CommonTypeConstraints.td`:
+Two predicates are provided in `CommonTypeConstraints.td`:
| Predicate | Accepts | Use when … |
| ------------------ | ------------------------------------ | ----------------------------------------------------------------------|
@@ -72,11 +71,19 @@ Three predicates are provided in `CommonTypeConstraints.td`:
Example:
```tablegen
-def MyConsumeOp : MyDialect_Op<"consume"> {
+def MyProduceOp : MyDialect_Op<"produce", [TokenProducerTrait]> {
+ let results = (outs Token:$token);
+}
+
+def MyConsumeOp : MyDialect_Op<"consume", [TokenConsumerTrait]> {
let arguments = (ins Token:$scope, AnyType:$value);
}
```
+Region entry block arguments of `token` type are also token producers and
+require the parent operation to define `TokenProducerTrait`. Token block
+arguments in non-entry blocks are rejected.
+
## Examples
### Rejected: tokens in `AnyType` positions
diff --git a/mlir/docs/Traits/_index.md b/mlir/docs/Traits/_index.md
index 866716b9f5193..7d01f8517477a 100644
--- a/mlir/docs/Traits/_index.md
+++ b/mlir/docs/Traits/_index.md
@@ -349,3 +349,17 @@ This trait removes the requirement on regions held by an operation to have
[terminator operations](../LangRef.md/#control-flow-and-ssacfg-regions) at the end of a block.
This requires that these regions have a single block. An example of operation
using this trait is the top-level `ModuleOp`.
+
+### TokenProducerTrait
+
+* `OpTrait::TokenProducerTrait` -- `TokenProducerTrait`
+
+This trait marks operations that are allowed to produce builtin `token` values
+as operation results or as region entry block arguments.
+
+### TokenConsumerTrait
+
+* `OpTrait::TokenConsumerTrait` -- `TokenConsumerTrait`
+
+This trait marks operations that are allowed to consume builtin `token` values
+as operands.
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index 3d36ffa5802c3..40ccaefa6de3f 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -1253,7 +1253,9 @@ def Builtin_Token : Builtin_Type<"Token", "token"> {
A use of a token SSA value is a pointer to an operation (in case of an
OpResult) or a pointer to a region (in case of an entry block argument).
A token carries no runtime data and cannot be forwarded. Tokens are
- excluded from the `AnyType` type constraint.
+ excluded from the `AnyType` type constraint. Operations must define
+ `TokenProducerTrait` to produce token results or token region entry block
+ arguments, and must define `TokenConsumerTrait` to consume token operands.
}];
}
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 1e34959d0d557..0d0669e90c3f7 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -98,6 +98,10 @@ def SameOperandsAndResultElementType :
NativeOpTrait<"SameOperandsAndResultElementType">;
// Op is a terminator.
def Terminator : NativeOpTrait<"IsTerminator">;
+// Op produces builtin token values.
+def TokenProducerTrait : NativeOpTrait<"TokenProducerTrait">;
+// Op consumes builtin token values.
+def TokenConsumerTrait : NativeOpTrait<"TokenConsumerTrait">;
// Op can be safely normalized in the presence of MemRefs with
// non-identity maps.
def MemRefsNormalizable : NativeOpTrait<"MemRefsNormalizable">;
diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index c1fba10e06a90..c85aefc19eb81 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -778,6 +778,18 @@ class IsTerminator : public TraitBase<ConcreteType, IsTerminator> {
}
};
+/// This trait marks operations that are allowed to produce builtin token
+/// values.
+template <typename ConcreteType>
+class TokenProducerTrait : public TraitBase<ConcreteType, TokenProducerTrait> {
+};
+
+/// This trait marks operations that are allowed to consume builtin token
+/// values.
+template <typename ConcreteType>
+class TokenConsumerTrait : public TraitBase<ConcreteType, TokenConsumerTrait> {
+};
+
/// This class provides verification for ops that are known to have zero
/// successors.
template <typename ConcreteType>
diff --git a/mlir/lib/IR/Verifier.cpp b/mlir/lib/IR/Verifier.cpp
index 33da2cd867f42..af11e5251f279 100644
--- a/mlir/lib/IR/Verifier.cpp
+++ b/mlir/lib/IR/Verifier.cpp
@@ -26,6 +26,7 @@
#include "mlir/IR/Verifier.h"
#include "mlir/IR/Attributes.h"
+#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/Operation.h"
@@ -42,8 +43,8 @@ class OperationVerifier {
public:
/// If `verifyRecursively` is true, then this will also recursively verify
/// nested operations.
- explicit OperationVerifier(bool verifyRecursively)
- : verifyRecursively(verifyRecursively) {}
+ OperationVerifier(MLIRContext *ctx, bool verifyRecursively)
+ : tokenType(TokenType::get(ctx)), verifyRecursively(verifyRecursively) {}
/// Verify the given operation.
LogicalResult verifyOpAndDominance(Operation &op);
@@ -58,6 +59,12 @@ class OperationVerifier {
/// upon exit from the subtree, i.e. when we visit a node for the second time.
LogicalResult verifyOnEntrance(Block &block);
LogicalResult verifyOnEntrance(Operation &op);
+ LogicalResult
+ verifyTokenValue(Operation &producer, Value value,
+ function_ref<InFlightDiagnostic()> emitProducerError);
+ LogicalResult verifyTokenValues(Operation &op);
+ LogicalResult verifyTokenBlockArgument(Block &block, BlockArgument arg,
+ unsigned idx);
LogicalResult verifyOnExit(Block &block);
LogicalResult verifyOnExit(Operation &op);
@@ -70,6 +77,9 @@ class OperationVerifier {
LogicalResult verifyDominanceOfContainedRegions(Operation &op,
DominanceInfo &domInfo);
+ /// The cached instance of the builtin token type.
+ TokenType tokenType;
+
/// A flag indicating if this verifier should recursively verify nested
/// operations.
bool verifyRecursively;
@@ -109,6 +119,81 @@ static bool mayBeValidWithoutTerminator(Block *block) {
return !op || op->mightHaveTrait<OpTrait::NoTerminator>();
}
+LogicalResult OperationVerifier::verifyTokenValue(
+ Operation &producer, Value value,
+ function_ref<InFlightDiagnostic()> emitProducerError) {
+ if (value.getType() != tokenType)
+ return success();
+
+ if (!producer.hasTrait<OpTrait::TokenProducerTrait>())
+ return emitProducerError();
+
+ for (OpOperand &use : value.getUses()) {
+ Operation *user = use.getOwner();
+ if (user->hasTrait<OpTrait::TokenConsumerTrait>())
+ continue;
+
+ return user->emitOpError()
+ << "consumes token operand #" << use.getOperandNumber()
+ << " but does not define the TokenConsumerTrait";
+ }
+
+ return success();
+}
+
+LogicalResult OperationVerifier::verifyTokenValues(Operation &op) {
+ for (auto resultIt : llvm::enumerate(op.getResults())) {
+ unsigned idx = resultIt.index();
+ OpResult result = resultIt.value();
+ if (failed(verifyTokenValue(op, result, [&]() {
+ return op.emitOpError()
+ << "produces token result #" << idx
+ << " but does not define the TokenProducerTrait";
+ })))
+ return failure();
+ }
+
+ for (Region ®ion : op.getRegions()) {
+ if (region.empty())
+ continue;
+
+ Block &entryBlock = region.front();
+ for (auto argIt : llvm::enumerate(entryBlock.getArguments())) {
+ unsigned idx = argIt.index();
+ BlockArgument arg = argIt.value();
+ if (failed(verifyTokenValue(op, arg, [&]() {
+ return emitError(arg.getLoc(), "token entry block argument #")
+ << idx << " requires the parent operation to define the "
+ << "TokenProducerTrait";
+ })))
+ return failure();
+ }
+ }
+
+ return success();
+}
+
+LogicalResult OperationVerifier::verifyTokenBlockArgument(Block &block,
+ BlockArgument arg,
+ unsigned idx) {
+ if (arg.getType() != tokenType)
+ return success();
+
+ Region *parentRegion = block.getParent();
+ if (!parentRegion || !block.isEntryBlock())
+ return emitError(arg.getLoc(), "token block argument #")
+ << idx << " is only allowed in a region entry block";
+
+ Operation *parentOp = parentRegion->getParentOp();
+ if (!parentOp || !parentOp->hasTrait<OpTrait::TokenProducerTrait>())
+ return emitError(arg.getLoc(), "token entry block argument #")
+ << idx
+ << " requires the parent operation to define the "
+ "TokenProducerTrait";
+
+ return success();
+}
+
LogicalResult OperationVerifier::verifyOnEntrance(Block &block) {
// Get the parent op and context for cross-context checks. Both are available
// whenever the block lives inside a region that has a parent operation.
@@ -133,6 +218,8 @@ LogicalResult OperationVerifier::verifyOnEntrance(Block &block) {
<< " type from a different MLIRContext than its "
"parent operation";
}
+ if (failed(verifyTokenBlockArgument(block, arg, idx)))
+ return failure();
}
// Verify that this block has a terminator.
@@ -232,6 +319,9 @@ LogicalResult OperationVerifier::verifyOnEntrance(Operation &op) {
if (registeredInfo && failed(registeredInfo->verifyInvariants(&op)))
return failure();
+ if (failed(verifyTokenValues(op)))
+ return failure();
+
unsigned numRegions = op.getNumRegions();
if (!numRegions)
return success();
@@ -478,6 +568,6 @@ OperationVerifier::verifyDominanceOfContainedRegions(Operation &op,
//===----------------------------------------------------------------------===//
LogicalResult mlir::verify(Operation *op, bool verifyRecursively) {
- OperationVerifier verifier(verifyRecursively);
+ OperationVerifier verifier(op->getContext(), verifyRecursively);
return verifier.verifyOpAndDominance(*op);
}
diff --git a/mlir/test/IR/token-type.mlir b/mlir/test/IR/token-type.mlir
index 0e46daed248fb..10347ee20dd40 100644
--- a/mlir/test/IR/token-type.mlir
+++ b/mlir/test/IR/token-type.mlir
@@ -1,7 +1,8 @@
// RUN: mlir-opt %s -verify-diagnostics -split-input-file | FileCheck %s
-// Tests for the builtin `token` type and the `Token`, `AnyType` ODS predicates.
-// The default `AnyType` predicate excludes tokens.
+// Tests for the builtin `token` type, the token producer/consumer operation
+// traits, and the `Token`, `AnyType` ODS predicates. The default `AnyType`
+// predicate excludes tokens.
// CHECK-LABEL: @token_produce_consume
func.func @token_produce_consume() {
@@ -14,6 +15,21 @@ func.func @token_produce_consume() {
// -----
+// Region entry block arguments may produce tokens when the parent op opts in.
+// CHECK-LABEL: @token_region_entry_block_arg
+func.func @token_region_entry_block_arg() {
+ // CHECK: "test.token.region"
+ "test.token.region"() ({
+ ^bb0(%arg0: token):
+ // CHECK: test.token.consume
+ test.token.consume %arg0
+ "test.finish"() : () -> ()
+ }) : () -> ()
+ return
+}
+
+// -----
+
// `AnyType` accepts arbitrary non-token types.
// CHECK-LABEL: @any_type_with_non_token
func.func @any_type_with_non_token(%arg0: i32) {
@@ -34,6 +50,88 @@ func.func @any_type_rejects_token() {
// -----
+// Token-producing ops must explicitly define the TokenProducerTrait.
+func.func @token_result_requires_producer_trait() {
+ // expected-error @below {{'test.token.produce_without_trait' op produces token result #0 but does not define the TokenProducerTrait}}
+ %t = test.token.produce_without_trait
+ return
+}
+
+// -----
+
+// Token-consuming ops must explicitly define the TokenConsumerTrait.
+func.func @token_operand_requires_consumer_trait() {
+ %t = test.token.produce
+ // expected-error @below {{'test.token.consume_without_trait' op consumes token operand #0 but does not define the TokenConsumerTrait}}
+ test.token.consume_without_trait %t
+ return
+}
+
+// -----
+
+// Token entry block arguments require the parent op to define the
+// TokenProducerTrait.
+func.func @token_entry_block_arg_requires_parent_producer_trait() {
+ "test.token.region_without_trait"() ({
+ // expected-error @below {{token entry block argument #0 requires the parent operation to define the TokenProducerTrait}}
+ ^bb0(%arg0: token):
+ test.token.consume %arg0
+ "test.finish"() : () -> ()
+ }) : () -> ()
+ return
+}
+
+// -----
+
+// Token entry block arguments require the parent op to define the
+// TokenProducerTrait even when the argument has no uses.
+func.func @token_entry_block_arg_requires_parent_producer_trait_without_uses() {
+ "test.token.region_without_trait"() ({
+ // expected-error @below {{token entry block argument #0 requires the parent operation to define the TokenProducerTrait}}
+ ^bb0(%arg0: token):
+ "test.finish"() : () -> ()
+ }) : () -> ()
+ return
+}
+
+// -----
+
+// Token entry block arguments still require consumers to define the
+// TokenConsumerTrait.
+func.func @token_entry_block_arg_use_requires_consumer_trait() {
+ "test.token.region"() ({
+ ^bb0(%arg0: token):
+ // expected-error @below {{'test.token.consume_without_trait' op consumes token operand #0 but does not define the TokenConsumerTrait}}
+ test.token.consume_without_trait %arg0
+ "test.finish"() : () -> ()
+ }) : () -> ()
+ return
+}
+
+// -----
+
+// Tokens cannot be non-entry block arguments.
+func.func @token_non_entry_block_arg_is_rejected() {
+ "test.token.region"() ({
+ "test.finish"() : () -> ()
+ // expected-error @below {{token block argument #0 is only allowed in a region entry block}}
+ ^bb1(%arg0: token):
+ "test.finish"() : () -> ()
+ }) : () -> ()
+ return
+}
+
+// -----
+
+// `func.func` lacks the TokenProducerTrait, so token arguments are rejected.
+// expected-error @below {{token entry block argument #0 requires the parent operation to define the TokenProducerTrait}}
+func.func @token_region_arg(%arg0: token) {
+ test.token.consume %arg0
+ return
+}
+
+// -----
+
// `Token` rejects non-token types. The op's operand type is fixed to the
// builtin `token` (it's a `BuildableType`), so passing a non-token SSA value
// fails at parse time with an SSA type mismatch.
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 21dd4727a49ed..3f649a97c7ff0 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -115,24 +115,49 @@ def SignlessLikeVariadic : TEST_Op<"signless_like_variadic"> {
//===----------------------------------------------------------------------===//
// Produce a builtin `!token` value.
-def TestTokenProduceOp : TEST_Op<"token.produce"> {
+def TestTokenProduceOp : TEST_Op<"token.produce", [TokenProducerTrait]> {
+ let results = (outs Token:$token);
+ let assemblyFormat = "attr-dict";
+}
+
+// Produce a builtin `!token` value without the required producer trait.
+def TestTokenProduceWithoutTraitOp
+ : TEST_Op<"token.produce_without_trait"> {
let results = (outs Token:$token);
let assemblyFormat = "attr-dict";
}
// Consume a builtin `!token` value (token-only operand).
-def TestTokenConsumeOp : TEST_Op<"token.consume"> {
+def TestTokenConsumeOp : TEST_Op<"token.consume", [TokenConsumerTrait]> {
+ let arguments = (ins Token:$token);
+ let assemblyFormat = "$token attr-dict";
+}
+
+// Consume a builtin `!token` value without the required consumer trait.
+def TestTokenConsumeWithoutTraitOp
+ : TEST_Op<"token.consume_without_trait"> {
let arguments = (ins Token:$token);
let assemblyFormat = "$token attr-dict";
}
// Op that uses the default `AnyType` predicate. Tokens are excluded by
// default and should be rejected by the verifier when passed here.
-def TestTokenAnyTypeOp : TEST_Op<"token.any_type"> {
+def TestTokenAnyTypeOp : TEST_Op<"token.any_type", [TokenConsumerTrait]> {
let arguments = (ins AnyType:$value);
let assemblyFormat = "$value attr-dict `:` type($value)";
}
+// Op whose region entry blocks may produce builtin `!token` values.
+def TestTokenRegionOp : TEST_Op<"token.region", [TokenProducerTrait]> {
+ let regions = (region AnyRegion:$body);
+}
+
+// Op whose regions do not opt in to producing builtin `!token` values.
+def TestTokenRegionWithoutTraitOp
+ : TEST_Op<"token.region_without_trait"> {
+ let regions = (region AnyRegion:$body);
+}
+
//===----------------------------------------------------------------------===//
// Test Symbols
//===----------------------------------------------------------------------===//
>From 183f61abf1170fce50be703fb474c65d84a6a060 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Thu, 14 May 2026 16:40:54 +0000
Subject: [PATCH 06/15] remove LLVM token type
---
mlir/docs/Dialects/LLVM.md | 2 +-
.../mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td | 25 ++++----
.../include/mlir/Dialect/LLVMIR/LLVMOpBase.td | 6 --
mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td | 13 ++--
mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h | 1 -
mlir/include/mlir/IR/BuiltinOps.td | 6 +-
mlir/include/mlir/IR/CommonTypeConstraints.td | 5 ++
.../Conversion/AsyncToLLVM/AsyncToLLVM.cpp | 4 +-
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 9 ++-
mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp | 60 +++++++------------
mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp | 22 +++----
mlir/lib/IR/Verifier.cpp | 6 +-
mlir/lib/Target/LLVMIR/TypeFromLLVM.cpp | 2 +-
mlir/lib/Target/LLVMIR/TypeToLLVM.cpp | 4 +-
.../AsyncToLLVM/convert-coro-to-llvm.mlir | 2 +-
mlir/test/Dialect/LLVMIR/types.mlir | 4 +-
mlir/test/Target/LLVMIR/Import/intrinsic.ll | 14 ++---
.../test/Target/LLVMIR/llvmir-intrinsics.mlir | 18 +++---
18 files changed, 94 insertions(+), 109 deletions(-)
diff --git a/mlir/docs/Dialects/LLVM.md b/mlir/docs/Dialects/LLVM.md
index 4b5d518ca4eab..87c27cc12b7ed 100644
--- a/mlir/docs/Dialects/LLVM.md
+++ b/mlir/docs/Dialects/LLVM.md
@@ -242,7 +242,7 @@ LLVM dialect:
- `!llvm.ppc_fp128` (`LLVMPPCFP128Type`) - 128-bit floating-point value (two
64 bits).
-- `!llvm.token` (`LLVMTokenType`) - a non-inspectable value associated with an
+- `token` (builtin `TokenType`) - a non-inspectable value associated with an
operation.
- `!llvm.metadata` (`LLVMMetadataType`) - LLVM IR metadata, to be used only if
the metadata cannot be represented as structured MLIR attributes.
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
index 688bc19cbf18a..62941eeac2fa9 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
@@ -698,7 +698,7 @@ def LLVM_ThreadlocalAddressOp : LLVM_OneResultIntrOp<"threadlocal.address", [],
// Coroutine intrinsics.
//
-def LLVM_CoroIdOp : LLVM_IntrOp<"coro.id", [], [], [], 1> {
+def LLVM_CoroIdOp : LLVM_IntrOp<"coro.id", [], [], [TokenProducerTrait], 1> {
let arguments = (ins I32:$align,
LLVM_AnyPointer:$promise,
LLVM_AnyPointer:$coroaddr,
@@ -707,8 +707,9 @@ def LLVM_CoroIdOp : LLVM_IntrOp<"coro.id", [], [], [], 1> {
" attr-dict `:` functional-type(operands, results)";
}
-def LLVM_CoroBeginOp : LLVM_IntrOp<"coro.begin", [], [], [], 1> {
- let arguments = (ins LLVM_TokenType:$token,
+def LLVM_CoroBeginOp
+ : LLVM_IntrOp<"coro.begin", [], [], [TokenConsumerTrait], 1> {
+ let arguments = (ins Token:$token,
LLVM_AnyPointer:$mem);
let assemblyFormat = "$token `,` $mem attr-dict `:` functional-type(operands, results)";
}
@@ -721,26 +722,30 @@ def LLVM_CoroAlignOp : LLVM_IntrOp<"coro.align", [0], [], [], 1> {
let assemblyFormat = "attr-dict `:` type($res)";
}
-def LLVM_CoroSaveOp : LLVM_IntrOp<"coro.save", [], [], [], 1> {
+def LLVM_CoroSaveOp
+ : LLVM_IntrOp<"coro.save", [], [], [TokenProducerTrait], 1> {
let arguments = (ins LLVM_AnyPointer:$handle);
let assemblyFormat = "$handle attr-dict `:` functional-type(operands, results)";
}
-def LLVM_CoroSuspendOp : LLVM_IntrOp<"coro.suspend", [], [], [], 1> {
- let arguments = (ins LLVM_TokenType:$save,
+def LLVM_CoroSuspendOp
+ : LLVM_IntrOp<"coro.suspend", [], [], [TokenConsumerTrait], 1> {
+ let arguments = (ins Token:$save,
I1:$final);
let assemblyFormat = "$save `,` $final attr-dict `:` type($res)";
}
-def LLVM_CoroEndOp : LLVM_IntrOp<"coro.end", [], [], [], 1> {
+def LLVM_CoroEndOp
+ : LLVM_IntrOp<"coro.end", [], [], [TokenConsumerTrait], 1> {
let arguments = (ins LLVM_AnyPointer:$handle,
I1:$unwind,
- LLVM_TokenType:$retvals);
+ Token:$retvals);
let assemblyFormat = "$handle `,` $unwind `,` $retvals attr-dict `:` functional-type(operands, results)";
}
-def LLVM_CoroFreeOp : LLVM_IntrOp<"coro.free", [], [], [], 1> {
- let arguments = (ins LLVM_TokenType:$id,
+def LLVM_CoroFreeOp
+ : LLVM_IntrOp<"coro.free", [], [], [TokenConsumerTrait], 1> {
+ let arguments = (ins Token:$id,
LLVM_AnyPointer:$handle);
let assemblyFormat = "$id `,` $handle attr-dict `:` functional-type(operands, results)";
}
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
index bd59319c79ad3..5e5b9e9faa070 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
@@ -29,12 +29,6 @@ def LLVM_Type : DialectType<LLVM_Dialect,
CPred<"::mlir::LLVM::isCompatibleOuterType($_self)">,
"LLVM dialect-compatible type">;
-// Type constraint accepting LLVM token type.
-def LLVM_TokenType : Type<
- CPred<"::llvm::isa<::mlir::LLVM::LLVMTokenType>($_self)">,
- "LLVM token type">,
- BuildableType<"::mlir::LLVM::LLVMTokenType::get($_builder.getContext())">;
-
// Type constraint accepting LLVM primitive types, i.e. all types except void
// and function.
def LLVM_PrimitiveType : Type<
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index d7c8cf236f0da..4401d5ea4f5f0 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -2135,18 +2135,17 @@ def LLVM_LLVMFuncOp : LLVM_Op<"func", [
}
def LLVM_NoneTokenOp
- : LLVM_Op<"mlir.none", [Pure]> {
+ : LLVM_Op<"mlir.none", [Pure, TokenProducerTrait]> {
let summary = "Defines a value containing an empty token to LLVM type.";
let description = [{
- Unlike LLVM IR, MLIR does not have first-class token values. They must be
- explicitly created as SSA values using `llvm.mlir.none`. This operation has
- no operands or attributes, and returns a none token value of a wrapped LLVM IR
- pointer type.
+ MLIR does not have a way to spell the LLVM IR `none` token literal. This
+ operation produces a builtin `token` SSA value that lowers to
+ `llvm::ConstantTokenNone` in LLVM IR.
Examples:
```mlir
- %0 = llvm.mlir.none : !llvm.token
+ %0 = llvm.mlir.none : token
```
}];
@@ -2154,7 +2153,7 @@ def LLVM_NoneTokenOp
$res = llvm::ConstantTokenNone::get(builder.getContext());
}];
- let results = (outs LLVM_TokenType:$res);
+ let results = (outs Token:$res);
let builders = [LLVM_OneResultOpBuilder];
let assemblyFormat = "attr-dict `:` type($res)";
}
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
index a1506497dc85c..a54f83660ca9c 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
@@ -67,7 +67,6 @@ namespace LLVM {
}
DEFINE_TRIVIAL_LLVM_TYPE(LLVMVoidType, "llvm.void");
-DEFINE_TRIVIAL_LLVM_TYPE(LLVMTokenType, "llvm.token");
DEFINE_TRIVIAL_LLVM_TYPE(LLVMLabelType, "llvm.label");
DEFINE_TRIVIAL_LLVM_TYPE(LLVMMetadataType, "llvm.metadata");
diff --git a/mlir/include/mlir/IR/BuiltinOps.td b/mlir/include/mlir/IR/BuiltinOps.td
index cdc09afe0b67e..3e9d2578fba12 100644
--- a/mlir/include/mlir/IR/BuiltinOps.td
+++ b/mlir/include/mlir/IR/BuiltinOps.td
@@ -100,7 +100,7 @@ def ModuleOp : Builtin_Op<"module", [
//===----------------------------------------------------------------------===//
def UnrealizedConversionCastOp : Builtin_Op<"unrealized_conversion_cast", [
- Pure
+ Pure, TokenProducerTrait, TokenConsumerTrait
]> {
let summary = "An unrealized conversion from one set of types to another";
let description = [{
@@ -136,8 +136,8 @@ def UnrealizedConversionCastOp : Builtin_Op<"unrealized_conversion_cast", [
```
}];
- let arguments = (ins Variadic<AnyType>:$inputs);
- let results = (outs Variadic<AnyType>:$outputs);
+ let arguments = (ins Variadic<AnyTypeOrToken>:$inputs);
+ let results = (outs Variadic<AnyTypeOrToken>:$outputs);
let assemblyFormat = [{
($inputs^ `:` type($inputs))? `to` type($outputs) attr-dict
}];
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index d8615a4730c32..8436cf5582c4d 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -173,6 +173,11 @@ def IsTokenTypePred : CPred<"::llvm::isa<::mlir::TokenType>($_self)">;
// results, since a token must not be value-forwarded.
def AnyType : Type<Neg<IsTokenTypePred>, "any non-token type">;
+// Any type, including the builtin `TokenType`. Use this as an explicit
+// opt-in for ops that legitimately need to handle arbitrary types,
+// including tokens (e.g. `builtin.unrealized_conversion_cast`).
+def AnyTypeOrToken : Type<CPred<"true">, "any type (including token)">;
+
// The builtin token type.
def Token : Type<IsTokenTypePred, "token", "::mlir::TokenType">,
BuildableType<"$_builder.getType<::mlir::TokenType>()">;
diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
index 7844c9dda877c..46e53e71d35f5 100644
--- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
+++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
@@ -78,8 +78,8 @@ struct AsyncAPI {
return LLVM::LLVMPointerType::get(ctx);
}
- static LLVM::LLVMTokenType tokenType(MLIRContext *ctx) {
- return LLVM::LLVMTokenType::get(ctx);
+ static mlir::TokenType tokenType(MLIRContext *ctx) {
+ return mlir::TokenType::get(ctx);
}
static FunctionType addOrDropRefFunctionType(MLIRContext *ctx) {
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 63bd9f8a3d625..aac100198509a 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -2605,8 +2605,8 @@ static bool isZeroAttribute(Attribute value) {
LogicalResult GlobalOp::verify() {
bool validType = isCompatibleOuterType(getType())
- ? !llvm::isa<LLVMVoidType, LLVMTokenType,
- LLVMMetadataType, LLVMLabelType>(getType())
+ ? !llvm::isa<LLVMVoidType, TokenType, LLVMMetadataType,
+ LLVMLabelType>(getType())
: llvm::isa<PointerElementTypeInterface>(getType());
if (!validType)
return emitOpError(
@@ -2826,8 +2826,8 @@ ParseResult AliasOp::parse(OpAsmParser &parser, OperationState &result) {
LogicalResult AliasOp::verify() {
bool validType = isCompatibleOuterType(getType())
- ? !llvm::isa<LLVMVoidType, LLVMTokenType,
- LLVMMetadataType, LLVMLabelType>(getType())
+ ? !llvm::isa<LLVMVoidType, TokenType, LLVMMetadataType,
+ LLVMLabelType>(getType())
: llvm::isa<PointerElementTypeInterface>(getType());
if (!validType)
return emitOpError(
@@ -4462,7 +4462,6 @@ void LLVMDialect::initialize() {
// clang-format off
addTypes<LLVMVoidType,
- LLVMTokenType,
LLVMLabelType,
LLVMMetadataType>();
// clang-format on
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
index 87df1e2df6102..c498f0746169d 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
@@ -36,7 +36,6 @@ static StringRef getTypeKeyword(Type type) {
return TypeSwitch<Type, StringRef>(type)
.Case<LLVMVoidType>([&](Type) { return "void"; })
.Case<LLVMPPCFP128Type>([&](Type) { return "ppc_fp128"; })
- .Case<LLVMTokenType>([&](Type) { return "token"; })
.Case<LLVMLabelType>([&](Type) { return "label"; })
.Case<LLVMMetadataType>([&](Type) { return "metadata"; })
.Case<LLVMFunctionType>([&](Type) { return "func"; })
@@ -238,41 +237,13 @@ Type LLVMStructType::parse(AsmParser &parser) {
}
/// Parses a type appearing inside another LLVM dialect-compatible type. This
-/// will first try to parse the LLVM dialect's short-hand keyword form (e.g.
-/// `token`, `void`, `ptr`, ...) and, failing that, fall back to parsing any
-/// MLIR type in full form (including types with the `!llvm` prefix).
-///
-/// Trying the short-hand form first matters because some LLVM short-hand
-/// keywords (notably `token`) collide with builtin type keywords whose
-/// semantics differ from the LLVM dialect's. Inside an `!llvm.<...>` type, the
-/// LLVM-specific meaning must always win.
+/// will try to parse any type in full form (including types with the `!llvm`
+/// prefix), and on failure fall back to parsing the short-hand version of the
+/// LLVM dialect types without the `!llvm` prefix.
static Type dispatchParse(AsmParser &parser, bool allowAny = true) {
SMLoc keyLoc = parser.getCurrentLocation();
- MLIRContext *ctx = parser.getContext();
-
- // Try parsing the LLVM dialect's short-hand keyword form first.
- StringRef key;
- if (succeeded(parser.parseOptionalKeyword(
- &key, {"void", "ppc_fp128", "token", "label", "metadata", "func",
- "ptr", "array", "struct", "target", "x86_amx"}))) {
- // `parseOptionalKeyword` already restricted `key` to one of the cases
- // below, so the `Default` is unreachable.
- return StringSwitch<function_ref<Type()>>(key)
- .Case("void", [&] { return LLVMVoidType::get(ctx); })
- .Case("ppc_fp128", [&] { return LLVMPPCFP128Type::get(ctx); })
- .Case("token", [&] { return LLVMTokenType::get(ctx); })
- .Case("label", [&] { return LLVMLabelType::get(ctx); })
- .Case("metadata", [&] { return LLVMMetadataType::get(ctx); })
- .Case("func", [&] { return LLVMFunctionType::parse(parser); })
- .Case("ptr", [&] { return LLVMPointerType::parse(parser); })
- .Case("array", [&] { return LLVMArrayType::parse(parser); })
- .Case("struct", [&] { return LLVMStructType::parse(parser); })
- .Case("target", [&] { return LLVMTargetExtType::parse(parser); })
- .Case("x86_amx", [&] { return LLVMX86AMXType::get(ctx); })
- .Default([] { return Type(); })();
- }
- // Otherwise, try parsing any MLIR type (only when allowed).
+ // Try parsing any MLIR type.
Type type;
OptionalParseResult result = parser.parseOptionalType(type);
if (result.has_value()) {
@@ -285,12 +256,27 @@ static Type dispatchParse(AsmParser &parser, bool allowAny = true) {
return type;
}
- // Neither a known LLVM short-hand keyword nor a parseable MLIR type.
- // Re-run `parseKeyword` to produce a useful error message.
+ // If no type was found, fall back to the shorthand form.
+ StringRef key;
if (failed(parser.parseKeyword(&key)))
return Type();
- parser.emitError(keyLoc) << "unknown LLVM type: " << key;
- return Type();
+
+ MLIRContext *ctx = parser.getContext();
+ return StringSwitch<function_ref<Type()>>(key)
+ .Case("void", [&] { return LLVMVoidType::get(ctx); })
+ .Case("ppc_fp128", [&] { return LLVMPPCFP128Type::get(ctx); })
+ .Case("label", [&] { return LLVMLabelType::get(ctx); })
+ .Case("metadata", [&] { return LLVMMetadataType::get(ctx); })
+ .Case("func", [&] { return LLVMFunctionType::parse(parser); })
+ .Case("ptr", [&] { return LLVMPointerType::parse(parser); })
+ .Case("array", [&] { return LLVMArrayType::parse(parser); })
+ .Case("struct", [&] { return LLVMStructType::parse(parser); })
+ .Case("target", [&] { return LLVMTargetExtType::parse(parser); })
+ .Case("x86_amx", [&] { return LLVMX86AMXType::get(ctx); })
+ .Default([&] {
+ parser.emitError(keyLoc) << "unknown LLVM type: " << key;
+ return Type();
+ })();
}
/// Helper to use in parse lists.
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
index 2b3ba1b8b5a35..2c29b38c26f08 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
@@ -150,7 +150,7 @@ generatedTypeParser(AsmParser &parser, StringRef *mnemonic, Type &value);
bool LLVMArrayType::isValidElementType(Type type) {
return !llvm::isa<LLVMVoidType, LLVMLabelType, LLVMMetadataType,
- LLVMFunctionType, LLVMTokenType>(type);
+ LLVMFunctionType, TokenType>(type);
}
LLVMArrayType LLVMArrayType::get(Type elementType, uint64_t numElements) {
@@ -435,7 +435,7 @@ LogicalResult LLVMPointerType::verifyEntries(DataLayoutEntryListRef entries,
bool LLVMStructType::isValidElementType(Type type) {
return !llvm::isa<LLVMVoidType, LLVMLabelType, LLVMMetadataType,
- LLVMFunctionType, LLVMTokenType>(type);
+ LLVMFunctionType, TokenType>(type);
}
LLVMStructType LLVMStructType::getIdentified(MLIRContext *context,
@@ -743,10 +743,10 @@ bool mlir::LLVM::isCompatibleOuterType(Type type) {
LLVMPPCFP128Type,
LLVMPointerType,
LLVMStructType,
- LLVMTokenType,
LLVMTargetExtType,
LLVMVoidType,
- LLVMX86AMXType
+ LLVMX86AMXType,
+ TokenType
>(type)) {
// clang-format on
return true;
@@ -803,9 +803,9 @@ static bool isCompatibleImpl(Type type, DenseSet<Type> &compatibleTypes) {
LLVMLabelType,
LLVMMetadataType,
LLVMPPCFP128Type,
- LLVMTokenType,
LLVMVoidType,
- LLVMX86AMXType
+ LLVMX86AMXType,
+ TokenType
>([](Type) { return true; })
// clang-format on
.Case<PtrLikeTypeInterface>(
@@ -917,11 +917,11 @@ llvm::TypeSize mlir::LLVM::getPrimitiveTypeSizeInBits(Type type) {
elementSize.isScalable());
})
.Default([](Type ty) {
- assert((llvm::isa<LLVMVoidType, LLVMLabelType, LLVMMetadataType,
- LLVMTokenType, LLVMStructType, LLVMArrayType,
- LLVMPointerType, LLVMFunctionType, LLVMTargetExtType>(
- ty)) &&
- "unexpected missing support for primitive type");
+ assert(
+ (llvm::isa<LLVMVoidType, LLVMLabelType, LLVMMetadataType, TokenType,
+ LLVMStructType, LLVMArrayType, LLVMPointerType,
+ LLVMFunctionType, LLVMTargetExtType>(ty)) &&
+ "unexpected missing support for primitive type");
return llvm::TypeSize::getFixed(0);
});
}
diff --git a/mlir/lib/IR/Verifier.cpp b/mlir/lib/IR/Verifier.cpp
index af11e5251f279..12ef7388f8de5 100644
--- a/mlir/lib/IR/Verifier.cpp
+++ b/mlir/lib/IR/Verifier.cpp
@@ -125,12 +125,12 @@ LogicalResult OperationVerifier::verifyTokenValue(
if (value.getType() != tokenType)
return success();
- if (!producer.hasTrait<OpTrait::TokenProducerTrait>())
+ if (!producer.mightHaveTrait<OpTrait::TokenProducerTrait>())
return emitProducerError();
for (OpOperand &use : value.getUses()) {
Operation *user = use.getOwner();
- if (user->hasTrait<OpTrait::TokenConsumerTrait>())
+ if (user->mightHaveTrait<OpTrait::TokenConsumerTrait>())
continue;
return user->emitOpError()
@@ -185,7 +185,7 @@ LogicalResult OperationVerifier::verifyTokenBlockArgument(Block &block,
<< idx << " is only allowed in a region entry block";
Operation *parentOp = parentRegion->getParentOp();
- if (!parentOp || !parentOp->hasTrait<OpTrait::TokenProducerTrait>())
+ if (!parentOp || !parentOp->mightHaveTrait<OpTrait::TokenProducerTrait>())
return emitError(arg.getLoc(), "token entry block argument #")
<< idx
<< " requires the parent operation to define the "
diff --git a/mlir/lib/Target/LLVMIR/TypeFromLLVM.cpp b/mlir/lib/Target/LLVMIR/TypeFromLLVM.cpp
index 5d9345d707a44..018f66802fc47 100644
--- a/mlir/lib/Target/LLVMIR/TypeFromLLVM.cpp
+++ b/mlir/lib/Target/LLVMIR/TypeFromLLVM.cpp
@@ -73,7 +73,7 @@ class TypeFromLLVMIRTranslatorImpl {
if (type->isMetadataTy())
return LLVM::LLVMMetadataType::get(&context);
if (type->isTokenTy())
- return LLVM::LLVMTokenType::get(&context);
+ return TokenType::get(&context);
llvm_unreachable("not a primitive type");
}
diff --git a/mlir/lib/Target/LLVMIR/TypeToLLVM.cpp b/mlir/lib/Target/LLVMIR/TypeToLLVM.cpp
index 807a94c61f0c8..61997c691e35d 100644
--- a/mlir/lib/Target/LLVMIR/TypeToLLVM.cpp
+++ b/mlir/lib/Target/LLVMIR/TypeToLLVM.cpp
@@ -58,9 +58,7 @@ class TypeToLLVMIRTranslatorImpl {
.Case([this](LLVM::LLVMPPCFP128Type) {
return llvm::Type::getPPC_FP128Ty(context);
})
- .Case([this](LLVM::LLVMTokenType) {
- return llvm::Type::getTokenTy(context);
- })
+ .Case([this](TokenType) { return llvm::Type::getTokenTy(context); })
.Case([this](LLVM::LLVMLabelType) {
return llvm::Type::getLabelTy(context);
})
diff --git a/mlir/test/Conversion/AsyncToLLVM/convert-coro-to-llvm.mlir b/mlir/test/Conversion/AsyncToLLVM/convert-coro-to-llvm.mlir
index a398bc5710a86..7874aca75dc4c 100644
--- a/mlir/test/Conversion/AsyncToLLVM/convert-coro-to-llvm.mlir
+++ b/mlir/test/Conversion/AsyncToLLVM/convert-coro-to-llvm.mlir
@@ -4,7 +4,7 @@
func.func @coro_id() {
// CHECK: %0 = llvm.mlir.constant(0 : i32) : i32
// CHECK: %1 = llvm.mlir.zero : !llvm.ptr
- // CHECK: %2 = llvm.intr.coro.id %0, %1, %1, %1 : (i32, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> !llvm.token
+ // CHECK: %2 = llvm.intr.coro.id %0, %1, %1, %1 : (i32, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> token
%0 = async.coro.id
return
}
diff --git a/mlir/test/Dialect/LLVMIR/types.mlir b/mlir/test/Dialect/LLVMIR/types.mlir
index b87c3dd6f2d7a..69582546cb429 100644
--- a/mlir/test/Dialect/LLVMIR/types.mlir
+++ b/mlir/test/Dialect/LLVMIR/types.mlir
@@ -6,12 +6,12 @@ func.func @primitive() {
"some.op"() : () -> !llvm.void
// CHECK: !llvm.ppc_fp128
"some.op"() : () -> !llvm.ppc_fp128
- // CHECK: !llvm.token
- "some.op"() : () -> !llvm.token
// CHECK: !llvm.label
"some.op"() : () -> !llvm.label
// CHECK: !llvm.metadata
"some.op"() : () -> !llvm.metadata
+ // CHECK: token
+ "some.op"() : () -> token
return
}
diff --git a/mlir/test/Target/LLVMIR/Import/intrinsic.ll b/mlir/test/Target/LLVMIR/Import/intrinsic.ll
index f79d09aa3d633..959a04fff6dca 100644
--- a/mlir/test/Target/LLVMIR/Import/intrinsic.ll
+++ b/mlir/test/Target/LLVMIR/Import/intrinsic.ll
@@ -793,16 +793,16 @@ define void @threadlocal_test() {
; CHECK-LABEL: llvm.func @coro_id
define void @coro_id() {
%a = alloca [16 x i8]
- ; CHECK: llvm.intr.coro.id %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : (i32, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> !llvm.token
+ ; CHECK: llvm.intr.coro.id %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : (i32, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> token
%3 = call token @llvm.coro.id(i32 0, ptr %a, ptr null, ptr null)
ret void
}
; CHECK-LABEL: llvm.func @coro_begin
define void @coro_begin(ptr %0) {
- ; CHECK: llvm.intr.coro.id %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : (i32, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> !llvm.token
+ ; CHECK: llvm.intr.coro.id %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : (i32, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> token
%3 = call token @llvm.coro.id(i32 0, ptr null, ptr null, ptr null)
- ; CHECK: llvm.intr.coro.begin %{{.*}}, %{{.*}} : (!llvm.token, !llvm.ptr) -> !llvm.ptr
+ ; CHECK: llvm.intr.coro.begin %{{.*}}, %{{.*}} : (token, !llvm.ptr) -> !llvm.ptr
%4 = call ptr @llvm.coro.begin(token %3, ptr %0)
ret void
}
@@ -826,14 +826,14 @@ define void @coro_align() {
; CHECK-LABEL: llvm.func @coro_save
define void @coro_save(ptr %0) {
- ; CHECK: llvm.intr.coro.save %{{.*}} : (!llvm.ptr) -> !llvm.token
+ ; CHECK: llvm.intr.coro.save %{{.*}} : (!llvm.ptr) -> token
%2 = call token @llvm.coro.save(ptr %0)
ret void
}
; CHECK-LABEL: llvm.func @coro_suspend
define void @coro_suspend(i1 %0) {
- ; CHECK: llvm.intr.coro.id %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : (i32, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> !llvm.token
+ ; CHECK: llvm.intr.coro.id %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : (i32, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> token
%4 = call token @llvm.coro.id(i32 0, ptr null, ptr null, ptr null)
; CHECK: llvm.intr.coro.suspend %{{.*}}, %{{.*}} : i8
%5 = call i8 @llvm.coro.suspend(token %4, i1 %0)
@@ -849,9 +849,9 @@ define void @coro_end(ptr %0, i1 %1) {
; CHECK-LABEL: llvm.func @coro_free
define void @coro_free(ptr %0) {
- ; CHECK: llvm.intr.coro.id %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : (i32, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> !llvm.token
+ ; CHECK: llvm.intr.coro.id %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : (i32, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> token
%3 = call token @llvm.coro.id(i32 0, ptr null, ptr null, ptr null)
- ; CHECK: llvm.intr.coro.free %{{.*}}, %{{.*}} : (!llvm.token, !llvm.ptr) -> !llvm.ptr
+ ; CHECK: llvm.intr.coro.free %{{.*}}, %{{.*}} : (token, !llvm.ptr) -> !llvm.ptr
%4 = call ptr @llvm.coro.free(token %3, ptr %0)
ret void
}
diff --git a/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir b/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir
index 11882a0a1d4c6..5865e046aa5ac 100644
--- a/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir
@@ -816,7 +816,7 @@ llvm.func @coro_id() {
%a = llvm.alloca %c x i8 : (i64) -> !llvm.ptr
// CHECK: call token @llvm.coro.id
%null = llvm.mlir.zero : !llvm.ptr
- llvm.intr.coro.id %zero, %a, %null, %null : (i32, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> !llvm.token
+ llvm.intr.coro.id %zero, %a, %null, %null : (i32, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> token
llvm.return
}
@@ -824,9 +824,9 @@ llvm.func @coro_id() {
llvm.func @coro_begin(%arg0: !llvm.ptr) {
%zero = llvm.mlir.constant(0 : i32) : i32
%null = llvm.mlir.zero : !llvm.ptr
- %token = llvm.intr.coro.id %zero, %null, %null, %null : (i32, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> !llvm.token
+ %token = llvm.intr.coro.id %zero, %null, %null, %null : (i32, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> token
// CHECK: call ptr @llvm.coro.begin
- llvm.intr.coro.begin %token, %arg0 : (!llvm.token, !llvm.ptr) -> !llvm.ptr
+ llvm.intr.coro.begin %token, %arg0 : (token, !llvm.ptr) -> !llvm.ptr
llvm.return
}
@@ -851,7 +851,7 @@ llvm.func @coro_align() {
// CHECK-LABEL: @coro_save
llvm.func @coro_save(%arg0: !llvm.ptr) {
// CHECK: call token @llvm.coro.save
- %0 = llvm.intr.coro.save %arg0 : (!llvm.ptr) -> !llvm.token
+ %0 = llvm.intr.coro.save %arg0 : (!llvm.ptr) -> token
llvm.return
}
@@ -859,7 +859,7 @@ llvm.func @coro_save(%arg0: !llvm.ptr) {
llvm.func @coro_suspend(%arg0 : i1) {
%zero = llvm.mlir.constant(0 : i32) : i32
%null = llvm.mlir.zero : !llvm.ptr
- %token = llvm.intr.coro.id %zero, %null, %null, %null : (i32, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> !llvm.token
+ %token = llvm.intr.coro.id %zero, %null, %null, %null : (i32, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> token
// CHECK: call i8 @llvm.coro.suspend
%0 = llvm.intr.coro.suspend %token, %arg0 : i8
llvm.return
@@ -867,9 +867,9 @@ llvm.func @coro_suspend(%arg0 : i1) {
// CHECK-LABEL: @coro_end
llvm.func @coro_end(%arg0: !llvm.ptr, %arg1 : i1) {
- %none = llvm.mlir.none : !llvm.token
+ %none = llvm.mlir.none : token
// CHECK: call void @llvm.coro.end
- llvm.intr.coro.end %arg0, %arg1, %none : (!llvm.ptr, i1, !llvm.token) -> !llvm.void
+ llvm.intr.coro.end %arg0, %arg1, %none : (!llvm.ptr, i1, token) -> !llvm.void
llvm.return
}
@@ -877,9 +877,9 @@ llvm.func @coro_end(%arg0: !llvm.ptr, %arg1 : i1) {
llvm.func @coro_free(%arg0 : !llvm.ptr) {
%zero = llvm.mlir.constant(0 : i32) : i32
%null = llvm.mlir.zero : !llvm.ptr
- %token = llvm.intr.coro.id %zero, %null, %null, %null : (i32, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> !llvm.token
+ %token = llvm.intr.coro.id %zero, %null, %null, %null : (i32, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> token
// CHECK: call ptr @llvm.coro.free
- %0 = llvm.intr.coro.free %token, %arg0 : (!llvm.token, !llvm.ptr) -> !llvm.ptr
+ %0 = llvm.intr.coro.free %token, %arg0 : (token, !llvm.ptr) -> !llvm.ptr
llvm.return
}
>From 4e47ed1636c60ecb1e760151526509d89086b7fe Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Thu, 14 May 2026 17:39:56 +0000
Subject: [PATCH 07/15] rewrite design contract
---
mlir/docs/LangRef.md | 11 +++++++++--
mlir/docs/Tokens.md | 22 +++++++++++++---------
2 files changed, 22 insertions(+), 11 deletions(-)
diff --git a/mlir/docs/LangRef.md b/mlir/docs/LangRef.md
index 185321ed5cfc6..4b94162e738f6 100644
--- a/mlir/docs/LangRef.md
+++ b/mlir/docs/LangRef.md
@@ -738,8 +738,15 @@ dialect types.
The [builtin dialect](Dialects/Builtin.md) defines a set of types that are
directly usable by any other dialect in MLIR. These types cover a range from
-primitive integer and floating-point types, function types,
-[tokens](Tokens.md), and more.
+primitive integer and floating-point types, function types, and more.
+
+A *token* is an SSA value of the builtin parameterless, opaque `token` type.
+It carries no runtime data. Intuitively, a token is a pointer to an operation
+(via its result) or to a region (via an entry block argument). Unlike regular
+SSA values, a token cannot be forwarded: its def-use chain cannot be obscured
+by ops with forwarding semantics such as `arith.select` or `cf.br`, so you can
+always walk back from any use of a token to *the* specific operation or region
+that produced it. See [Tokens](Tokens.md) for more details.
## Properties
diff --git a/mlir/docs/Tokens.md b/mlir/docs/Tokens.md
index 9be9aaf8ec9de..b87e14e10b8e3 100644
--- a/mlir/docs/Tokens.md
+++ b/mlir/docs/Tokens.md
@@ -28,11 +28,11 @@ inspect the definition of the token.
## Structural Contract
-A token use cannot be substituted with another token value: the use of a token
-points directly to a specific producer. Generic transformations must not alter
-or break this link. New uses of a token can be introduced safely. Operations
-must opt in to producing or consuming tokens with `TokenProducerTrait` and
-`TokenConsumerTrait`. As a consequence:
+Given a use of a token SSA value, its definition is guaranteed to be the
+semantic producer of the token. Generic transformations must preserve this
+invariant: they may not introduce a forwarding step between a use and its
+producer, nor retarget a use to a producer with different semantics. New
+uses of a token can be introduced safely. As a consequence:
1. A token must not appear as a forwarded value, e.g.:
* a forwarded result/operand of a `CallOpInterface` op,
@@ -44,16 +44,20 @@ must opt in to producing or consuming tokens with `TokenProducerTrait` and
* the result of any op that selects or merges values it does not
understand (e.g. `arith.select`).
-2. Given a use of a token SSA value, its definition is guaranteed to be the
- semantic producer of the token.
+2. A token cannot constant-fold. No constant of token type exists.
-3. A token cannot constant-fold. No constant of token type exists.
+3. The presence of tokens has no effect on standard transformations such as
+ CSE, DCE or hoisting.
-4. Use of a token is side-effect free: a token user follows the usual `isTriviallyDead()` rules.
+4. Use of a token is side-effect free: a token user follows the usual
+ `isTriviallyDead()` rules.
These properties mirror what LLVM IR already documents for its own
[`token` type](https://llvm.org/docs/LangRef.html#token-type).
+Operations must opt in to producing or consuming tokens with
+`TokenProducerTrait` and `TokenConsumerTrait`.
+
## ODS Integration
Tokens are excluded from the default `AnyType` predicate, so an op that has
>From 8c7c6a909bc0d1da2a04439d05bfdd930b2b94fd Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Fri, 15 May 2026 08:14:33 +0000
Subject: [PATCH 08/15] address comments
---
mlir/docs/Dialects/LLVM.md | 3 +--
mlir/lib/IR/Verifier.cpp | 14 +++++---------
2 files changed, 6 insertions(+), 11 deletions(-)
diff --git a/mlir/docs/Dialects/LLVM.md b/mlir/docs/Dialects/LLVM.md
index 87c27cc12b7ed..f8419d726fc33 100644
--- a/mlir/docs/Dialects/LLVM.md
+++ b/mlir/docs/Dialects/LLVM.md
@@ -224,6 +224,7 @@ dialect-compatible types_. The following types are compatible:
(`FloatType`).
- 1D vectors of signless integers or floating point types - `vector<NxT>`
(`VectorType`).
+- Tokens (`TokenType`) - non-inspectable values associated with an operation.
Note that only a subset of types that can be represented by a given class is
compatible. For example, signed and unsigned integers are not compatible. LLVM
@@ -242,8 +243,6 @@ LLVM dialect:
- `!llvm.ppc_fp128` (`LLVMPPCFP128Type`) - 128-bit floating-point value (two
64 bits).
-- `!builtin.token` (`TokenType`) - a non-inspectable value associated with an
- operation.
- `!llvm.metadata` (`LLVMMetadataType`) - LLVM IR metadata, to be used only if
the metadata cannot be represented as structured MLIR attributes.
- `!llvm.void` (`LLVMVoidType`) - does not represent any value; can only
diff --git a/mlir/lib/IR/Verifier.cpp b/mlir/lib/IR/Verifier.cpp
index 12ef7388f8de5..9f9d2168f24df 100644
--- a/mlir/lib/IR/Verifier.cpp
+++ b/mlir/lib/IR/Verifier.cpp
@@ -179,18 +179,14 @@ LogicalResult OperationVerifier::verifyTokenBlockArgument(Block &block,
if (arg.getType() != tokenType)
return success();
- Region *parentRegion = block.getParent();
- if (!parentRegion || !block.isEntryBlock())
+ // The producer-trait check on the parent op (and the token consumer check
+ // on the uses) is performed by `verifyTokenValues` when it iterates the
+ // entry block arguments of an op's regions. Here we only enforce that
+ // tokens are not used as non-entry block arguments.
+ if (!block.getParent() || !block.isEntryBlock())
return emitError(arg.getLoc(), "token block argument #")
<< idx << " is only allowed in a region entry block";
- Operation *parentOp = parentRegion->getParentOp();
- if (!parentOp || !parentOp->mightHaveTrait<OpTrait::TokenProducerTrait>())
- return emitError(arg.getLoc(), "token entry block argument #")
- << idx
- << " requires the parent operation to define the "
- "TokenProducerTrait";
-
return success();
}
>From 8882b47c1b3db7f5382bb3fbdd450117522a0ca2 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Fri, 22 May 2026 11:04:33 +0000
Subject: [PATCH 09/15] address comments
---
mlir/docs/Tokens.md | 5 ++---
1 file changed, 2 insertions(+), 3 deletions(-)
diff --git a/mlir/docs/Tokens.md b/mlir/docs/Tokens.md
index b87e14e10b8e3..3776632f12cf4 100644
--- a/mlir/docs/Tokens.md
+++ b/mlir/docs/Tokens.md
@@ -108,6 +108,5 @@ a branch or a loop.
```
`scf.if`'s results are declared with `Variadic<AnyType>` and `scf.yield`'s
-operands likewise use `AnyType`. Because `AnyType` excludes tokens, yielding
-(or returning) a token through `scf.if` (or any other op that has not
-explicitly opted in) is rejected.
+operands likewise use `AnyType`. Because `AnyType` excludes tokens, both
+`scf.if` and `scf.yield` fail verification.
>From 95d4e8d8129371e42d5874e4e90ebee4e650c6e7 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Fri, 22 May 2026 12:11:30 +0000
Subject: [PATCH 10/15] regenerate bytecode
---
.../Builtin/Bytecode/builtin_fixed_0.mlirbc | Bin 4503 -> 4503 bytes
1 file changed, 0 insertions(+), 0 deletions(-)
diff --git a/mlir/test/Dialect/Builtin/Bytecode/builtin_fixed_0.mlirbc b/mlir/test/Dialect/Builtin/Bytecode/builtin_fixed_0.mlirbc
index 76a764de40eced87c6cc3d98b1aabf8153649989..f08188cf1b8c0551ef912c5252a455b9a18c2c86 100644
GIT binary patch
delta 14
VcmbQPJY9K17aODV=596~ZU80^1XKV3
delta 14
VcmbQPJY9K17aODY=596~ZU7~J1Uvu$
>From 972be7e624a9ce33db97797ca56692868e709750 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Fri, 22 May 2026 12:14:40 +0000
Subject: [PATCH 11/15] call out IsolatedFromAbove restriction
---
mlir/docs/Tokens.md | 4 ++++
1 file changed, 4 insertions(+)
diff --git a/mlir/docs/Tokens.md b/mlir/docs/Tokens.md
index 3776632f12cf4..5b6c1655c2cf6 100644
--- a/mlir/docs/Tokens.md
+++ b/mlir/docs/Tokens.md
@@ -58,6 +58,10 @@ These properties mirror what LLVM IR already documents for its own
Operations must opt in to producing or consuming tokens with
`TokenProducerTrait` and `TokenConsumerTrait`.
+Note: Because tokens are SSA values, they cannot cross `IsolatedFromAbove`
+region boundaries. Symbols should be used instead when a token-like
+dependency must cross such a boundary.
+
## ODS Integration
Tokens are excluded from the default `AnyType` predicate, so an op that has
>From 69f9322c99e1d083854746424853733f1cec6c44 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Sun, 24 May 2026 14:33:16 +0000
Subject: [PATCH 12/15] move structural contract to LangRef
---
mlir/docs/LangRef.md | 33 +++++++++++++++++++-----
mlir/docs/Tokens.md | 60 +++++++++++++++-----------------------------
2 files changed, 47 insertions(+), 46 deletions(-)
diff --git a/mlir/docs/LangRef.md b/mlir/docs/LangRef.md
index 4b94162e738f6..b3103d7e65a6b 100644
--- a/mlir/docs/LangRef.md
+++ b/mlir/docs/LangRef.md
@@ -740,13 +740,34 @@ The [builtin dialect](Dialects/Builtin.md) defines a set of types that are
directly usable by any other dialect in MLIR. These types cover a range from
primitive integer and floating-point types, function types, and more.
+### Token Type
+
A *token* is an SSA value of the builtin parameterless, opaque `token` type.
-It carries no runtime data. Intuitively, a token is a pointer to an operation
-(via its result) or to a region (via an entry block argument). Unlike regular
-SSA values, a token cannot be forwarded: its def-use chain cannot be obscured
-by ops with forwarding semantics such as `arith.select` or `cf.br`, so you can
-always walk back from any use of a token to *the* specific operation or region
-that produced it. See [Tokens](Tokens.md) for more details.
+It carries no runtime data. Given a use of a token SSA value, its definition
+is guaranteed to be the semantic producer of the token. Generic transformations
+must preserve this invariant: they may not introduce a forwarding step between
+a use and its producer, nor retarget a use to a producer with different
+semantics. New uses of a token can be introduced safely. As a consequence:
+
+1. A token must not appear as a forwarded value. E.g., it cannot be used as a
+ successor operand of a `BranchOpInterface` op.
+2. A token cannot constant-fold. No constant of token type exists.
+3. The presence of tokens has no effect on standard transformations such as
+ CSE, DCE or hoisting.
+4. Use of a token is side-effect free: a token user follows the usual
+ `isTriviallyDead()` rules.
+
+These properties mirror what LLVM IR already documents for its own
+[`token` type](https://llvm.org/docs/LangRef.html#token-type).
+
+Operations must opt in to producing or consuming tokens with
+`TokenProducerTrait` and `TokenConsumerTrait`.
+
+Note: Because tokens are SSA values, they cannot cross `IsolatedFromAbove`
+region boundaries. Symbols should be used instead when a token-like
+dependency must cross such a boundary.
+
+See [Tokens](Tokens.md) for details on ODS integration and examples.
## Properties
diff --git a/mlir/docs/Tokens.md b/mlir/docs/Tokens.md
index 5b6c1655c2cf6..9cde2e88484ae 100644
--- a/mlir/docs/Tokens.md
+++ b/mlir/docs/Tokens.md
@@ -8,11 +8,13 @@ Intuitively, a *token* value is a pointer to an operation (via an OpResult)
or a pointer to a region (via an entry block argument). A token cannot be
forwarded: a token def-use chain cannot be obscured by ops with forwarding
semantics such as `arith.select` or `cf.br`. This allows you to always walk
-back from a use and say "this token came from *that* specific op".
+back from a use and say "this token came from *that* specific op". The exact
+structural contract is specified in the
+[LangRef section on tokens](LangRef.md#token-type).
A token is an SSA value that has the builtin token type. The token type is
-parameterless, opaque and carries no runtime data. A token prints as `token`.
-Apart from the structural contract below, tokens are like any other SSA values.
+parameterless, opaque and carries no runtime data. Apart from the structural
+contract specified in the LangRef, tokens are like any other SSA values.
## Design Rationale
@@ -26,42 +28,6 @@ token use points to the token's definition and not the other way around.
Transformations can remove the use of a token without having to touch or
inspect the definition of the token.
-## Structural Contract
-
-Given a use of a token SSA value, its definition is guaranteed to be the
-semantic producer of the token. Generic transformations must preserve this
-invariant: they may not introduce a forwarding step between a use and its
-producer, nor retarget a use to a producer with different semantics. New
-uses of a token can be introduced safely. As a consequence:
-
-1. A token must not appear as a forwarded value, e.g.:
- * a forwarded result/operand of a `CallOpInterface` op,
- * an argument or result type of a `FunctionOpInterface` op,
- * a successor operand or successor block argument of a
- `BranchOpInterface` op,
- * a forwarded operand to/from any region of a `RegionBranchOpInterface`
- op (iter-args, region results, yielded values), or
- * the result of any op that selects or merges values it does not
- understand (e.g. `arith.select`).
-
-2. A token cannot constant-fold. No constant of token type exists.
-
-3. The presence of tokens has no effect on standard transformations such as
- CSE, DCE or hoisting.
-
-4. Use of a token is side-effect free: a token user follows the usual
- `isTriviallyDead()` rules.
-
-These properties mirror what LLVM IR already documents for its own
-[`token` type](https://llvm.org/docs/LangRef.html#token-type).
-
-Operations must opt in to producing or consuming tokens with
-`TokenProducerTrait` and `TokenConsumerTrait`.
-
-Note: Because tokens are SSA values, they cannot cross `IsolatedFromAbove`
-region boundaries. Symbols should be used instead when a token-like
-dependency must cross such a boundary.
-
## ODS Integration
Tokens are excluded from the default `AnyType` predicate, so an op that has
@@ -94,7 +60,21 @@ arguments in non-entry blocks are rejected.
## Examples
-### Rejected: tokens in `AnyType` positions
+### Non-forwarding Semantics
+
+The [LangRef](LangRef.md#token-type) requires that a token never appear as a
+forwarded value. For example, a token cannot be used as any of the following:
+
+* a forwarded result or operand of a `CallOpInterface` op;
+* an argument or result type of a `FunctionOpInterface` op;
+* a successor operand of a `BranchOpInterface` op;
+* a block argument of a non-entry block;
+* a forwarded operand to or from any region of a `RegionBranchOpInterface`
+ op (iter-args, region results, or yielded values); or
+* the result of any op that selects or merges values it does not understand
+ (e.g. `arith.select`).
+
+### ODS-based Verification: Tokens Rejected in `AnyType` Positions
`scf.yield` operands have forwarding semantics. A token cannot be yielded from
a branch or a loop.
>From 3ad6af9879c2abb647a3f13226683cfe84c358c0 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Mon, 25 May 2026 13:26:48 +0000
Subject: [PATCH 13/15] address comments: symbols / IsolatedFromAbove
---
mlir/docs/LangRef.md | 3 +--
mlir/docs/Tokens.md | 7 ++++++-
2 files changed, 7 insertions(+), 3 deletions(-)
diff --git a/mlir/docs/LangRef.md b/mlir/docs/LangRef.md
index b3103d7e65a6b..0e6fb006da48b 100644
--- a/mlir/docs/LangRef.md
+++ b/mlir/docs/LangRef.md
@@ -764,8 +764,7 @@ Operations must opt in to producing or consuming tokens with
`TokenProducerTrait` and `TokenConsumerTrait`.
Note: Because tokens are SSA values, they cannot cross `IsolatedFromAbove`
-region boundaries. Symbols should be used instead when a token-like
-dependency must cross such a boundary.
+region boundaries.
See [Tokens](Tokens.md) for details on ODS integration and examples.
diff --git a/mlir/docs/Tokens.md b/mlir/docs/Tokens.md
index 9cde2e88484ae..27cb38168f716 100644
--- a/mlir/docs/Tokens.md
+++ b/mlir/docs/Tokens.md
@@ -23,11 +23,16 @@ parallel def-use system for operations. It reuses the existing def-use
machinery for SSA. It introduces no changes to the generic op syntax, the
bytecode infrastructure or core C++ APIs around `Operation`.
-As with regular use-def chains, a token def-use chain is unidirectional. A
+As with regular def-use chains, a token def-use chain is unidirectional. A
token use points to the token's definition and not the other way around.
Transformations can remove the use of a token without having to touch or
inspect the definition of the token.
+Because tokens are SSA values, they cannot cross `IsolatedFromAbove` region
+boundaries. This is intentional: it allows passes to process isolated regions
+concurrently without racing on def-use chains. When a token-like dependency
+must cross such a boundary, use a symbol instead.
+
## ODS Integration
Tokens are excluded from the default `AnyType` predicate, so an op that has
>From 955ba236d62bf0a6afbe501048794005115a4908 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Mon, 25 May 2026 17:20:16 -0700
Subject: [PATCH 14/15] Update mlir/docs/Tokens.md
Co-authored-by: Mehdi Amini <joker.eph at gmail.com>
---
mlir/docs/Tokens.md | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/mlir/docs/Tokens.md b/mlir/docs/Tokens.md
index 27cb38168f716..df353f55edd1f 100644
--- a/mlir/docs/Tokens.md
+++ b/mlir/docs/Tokens.md
@@ -31,7 +31,8 @@ inspect the definition of the token.
Because tokens are SSA values, they cannot cross `IsolatedFromAbove` region
boundaries. This is intentional: it allows passes to process isolated regions
concurrently without racing on def-use chains. When a token-like dependency
-must cross such a boundary, use a symbol instead.
+must cross such a boundary, another mechanism must be used (e.g. a symbolic
+reference using an attribute).
## ODS Integration
>From fc3e151927fcb724bd523e689b0ee859d2bb7bae Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Tue, 26 May 2026 00:27:58 +0000
Subject: [PATCH 15/15] drop unrealized_conversion_cast change
---
mlir/include/mlir/IR/BuiltinOps.td | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/mlir/include/mlir/IR/BuiltinOps.td b/mlir/include/mlir/IR/BuiltinOps.td
index 3e9d2578fba12..cdc09afe0b67e 100644
--- a/mlir/include/mlir/IR/BuiltinOps.td
+++ b/mlir/include/mlir/IR/BuiltinOps.td
@@ -100,7 +100,7 @@ def ModuleOp : Builtin_Op<"module", [
//===----------------------------------------------------------------------===//
def UnrealizedConversionCastOp : Builtin_Op<"unrealized_conversion_cast", [
- Pure, TokenProducerTrait, TokenConsumerTrait
+ Pure
]> {
let summary = "An unrealized conversion from one set of types to another";
let description = [{
@@ -136,8 +136,8 @@ def UnrealizedConversionCastOp : Builtin_Op<"unrealized_conversion_cast", [
```
}];
- let arguments = (ins Variadic<AnyTypeOrToken>:$inputs);
- let results = (outs Variadic<AnyTypeOrToken>:$outputs);
+ let arguments = (ins Variadic<AnyType>:$inputs);
+ let results = (outs Variadic<AnyType>:$outputs);
let assemblyFormat = [{
($inputs^ `:` type($inputs))? `to` type($outputs) attr-dict
}];
More information about the Mlir-commits
mailing list