[Mlir-commits] [mlir] [mlir][IR] Add builtin `TokenType` (PR #195640)
Matthias Springer
llvmlistbot at llvm.org
Fri May 8 04:02:55 PDT 2026
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/195640
>From b578602af7dd5ed9d698ff00c3879141d69c992e Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Mon, 4 May 2026 12:14:41 +0000
Subject: [PATCH] [mlir][IR] Add builtin `TokenTypeInterface`
type instead of type interface
add bytecode
---
mlir/docs/Tokens.md | 104 ++++++++++++++++++
mlir/include/mlir/Dialect/Async/IR/Async.h | 2 +-
.../include/mlir/Dialect/Async/IR/AsyncOps.td | 4 +-
.../include/mlir/IR/BuiltinDialectBytecode.td | 5 +-
mlir/include/mlir/IR/BuiltinTypes.td | 27 +++++
mlir/include/mlir/IR/CommonTypeConstraints.td | 20 +++-
mlir/lib/AsmParser/TokenKinds.def | 1 +
mlir/lib/AsmParser/TypeParser.cpp | 6 +
.../Conversion/AsyncToLLVM/AsyncToLLVM.cpp | 35 +++---
mlir/lib/Dialect/Async/IR/Async.cpp | 10 +-
.../Transforms/AsyncRuntimeRefCounting.cpp | 2 +-
.../Async/Transforms/AsyncToAsyncRuntime.cpp | 14 ++-
mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp | 60 ++++++----
mlir/lib/IR/AsmPrinter.cpp | 1 +
mlir/test/Dialect/ArmSME/invalid.mlir | 4 +-
.../Builtin/Bytecode/builtin_fixed.mlir | 4 +-
.../Builtin/Bytecode/builtin_fixed_0.mlirbc | Bin 4471 -> 4503 bytes
mlir/test/Dialect/Builtin/Bytecode/types.mlir | 10 ++
mlir/test/Dialect/Linalg/invalid.mlir | 4 +-
mlir/test/Dialect/MemRef/invalid.mlir | 4 +-
mlir/test/Dialect/SparseTensor/invalid.mlir | 24 ++--
mlir/test/Dialect/Tensor/invalid.mlir | 2 +-
mlir/test/Dialect/Vector/invalid.mlir | 10 +-
mlir/test/Dialect/traits.mlir | 2 +-
mlir/test/IR/operand.mlir | 6 +-
mlir/test/IR/result.mlir | 6 +-
mlir/test/IR/token-type.mlir | 60 ++++++++++
mlir/test/lib/Dialect/Test/TestOps.td | 30 +++++
mlir/test/mlir-tblgen/predicate.td | 4 +-
mlir/test/mlir-tblgen/types.mlir | 6 +-
30 files changed, 372 insertions(+), 95 deletions(-)
create mode 100644 mlir/docs/Tokens.md
create mode 100644 mlir/test/IR/token-type.mlir
diff --git a/mlir/docs/Tokens.md b/mlir/docs/Tokens.md
new file mode 100644
index 0000000000000..093408c942d0f
--- /dev/null
+++ b/mlir/docs/Tokens.md
@@ -0,0 +1,104 @@
+# Tokens
+
+[TOC]
+
+## Overview
+
+Intuitively, a *token* value is a pointer to an operation (via an OpResult)
+or a pointer to a region (via an entry block argument). A token cannot be
+forwarded: a token def-use chain cannot be obscured by ops with forwarding
+semantics such as `arith.select` or `cf.br`. This allows you to always walk
+back from a use and say "this token came from *that* specific op".
+
+A token is an SSA value that has the builtin token type. The token type is
+parameterless, opaque and prints as `token`. A token carries no runtime data.
+Apart from the structural contract below, tokens are like any other SSA values.
+
+## Design Rationale
+
+The token type lets an operation refer to another operation without introducing
+a new, parallel def-use system for operations: the existing def-use machinery
+for SSA values is reused. Moreover, no changes were needed to the generic op
+syntax, the bytecode infrastructure, or the core C++ APIs around `Operation`.
+
+As with regular use-def chains, a token def-use chain is unidirectional: a
+token use points to the token's definition, but not the other way around.
+Transformations can remove the use of a token without having to touch the
+definition of the token. (Whether such a transformation is correct depends on
+the semantics of the token-producing and token-consuming ops.)
+
+Token-producing and token-consuming ops are subject to standard transformations
+such as CSE, DCE and hoisting. Where such transformations conflict with the
+semantics of a concrete op, common IR design patterns can prevent them. For
+example, terminators and ops with side effects are not CSE'd or DCE'd.
+Region block arguments semantically belong to the enclosing op and are never
+CSE'd, DCE'd or hoisted. Non-speculatability may also prevent hoisting.
+
+## Structural Contract
+
+1. A token must not appear as a forwarded value, e.g.:
+ * a forwarded result/operand of a `CallOpInterface` op,
+ * an argument or result type of a `FunctionOpInterface` op (a token
+ block argument *inside* a function body is fine — what is disallowed
+ is forwarding tokens across the call/return boundary),
+ * a successor operand or successor block argument of a
+ `BranchOpInterface` op,
+ * a forwarded operand to/from any region of a `RegionBranchOpInterface`
+ op (iter-args, region results, yielded values), or
+ * the result of any op that selects or merges values it does not
+ understand (e.g. `arith.select`).
+
+2. As a consequence of (1), given a use of a token SSA value, its definition is
+ guaranteed to be the semantic producer of the token.
+
+3. A token cannot be constant-folded: no constant of token type exists.
+
+These properties mirror what LLVM IR already documents for its own
+[`token` type](https://llvm.org/docs/LangRef.html#token-type).
+
+## ODS Integration
+
+Tokens are excluded from the default `AnyType` predicate, so an op that has
+not opted in cannot accept a token as an arbitrary operand or result. This
+restriction prevents tokens from being accidentally passed as operands with
+forwarding semantics.
+
+Three predicates are provided in `CommonTypeConstraints.td`:
+
+| Predicate | Accepts | Use when … |
+| ------------------ | ------------------------------------ | ----------------------------------------------------------------------|
+| `AnyType` | any non-token type | the default; matches what "any type" meant before tokens were introduced. |
+| `AnyTypeOrToken` | any type, including tokens | the op legitimately accepts arbitrary types (including tokens). |
+| `Token` | only the builtin `TokenType` | the op specifically takes a token operand/result. |
+
+Example:
+
+```tablegen
+def MyConsumeOp : MyDialect_Op<"consume"> {
+ let arguments = (ins Token:$scope, AnyType:$value);
+}
+```
+
+## Examples
+
+### Rejected: tokens in `AnyType` positions
+
+`scf.yield` operands have forwarding semantics. A token cannot be yielded from
+a branch or a loop.
+
+```mlir
+// error: 'scf.if' op result #0 must be variadic of any non-token type,
+// but got 'token'
+%t = scf.if %cond -> token {
+ %a = my.token.produce : token
+ scf.yield %a : token
+} else {
+ %b = my.token.produce : token
+ scf.yield %b : token
+}
+```
+
+`scf.if`'s results are declared with `Variadic<AnyType>` and `scf.yield`'s
+operands likewise use `AnyType`. Because `AnyType` excludes tokens by
+default, yielding (or returning) a token through a `scf.if` (or any other
+op that has not explicitly opted in via `AnyTypeOrToken`) is rejected.
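The structural contract in `Tokens.md` can be sketched as a tiny, self-contained C++ model. This is not MLIR's C++ API. `ToyType`, `ToyValue`, `ToyRegion`, `ToyOp`, `verifyNoForwardedTokens` and `producerOf` are hypothetical names used only for illustration. The per-operand `forwarded` flag is an assumption that stands in for what interfaces such as `BranchOpInterface` or `RegionBranchOpInterface` would describe. The sketch checks rule (1) and shows why rule (2) then holds: a token's SSA definition can be returned directly as its producer.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical toy IR for illustration; not the MLIR C++ API.
enum class ToyType { I32, Token };

struct ToyOp;
struct ToyRegion {
  std::string name;
};

// A value is either an op result or an entry-block argument of a region.
struct ToyValue {
  ToyType type;
  ToyOp *definingOp = nullptr;       // set for op results
  ToyRegion *owningRegion = nullptr; // set for entry-block arguments
};

struct ToyOp {
  std::string name;
  std::vector<ToyValue *> operands;
  // Assumption: forwarded[i] is true if operand i is passed on to another
  // value (successor operand, yielded value, call argument, ...).
  std::vector<bool> forwarded;
};

// Rule (1): a token must not appear in a forwarded position.
// Returns an error message, or an empty string if the op is fine.
std::string verifyNoForwardedTokens(const ToyOp &op) {
  for (std::size_t i = 0; i < op.operands.size(); ++i)
    if (op.forwarded[i] && op.operands[i]->type == ToyType::Token)
      return "'" + op.name + "' op operand #" + std::to_string(i) +
             " forwards a token";
  return "";
}

// Rule (2): because tokens are never forwarded, a token's SSA definition is
// its semantic producer; no walk through branches or selects is needed.
std::string producerOf(const ToyValue &token) {
  assert(token.type == ToyType::Token && "expected a token");
  if (token.definingOp)
    return "op:" + token.definingOp->name;
  return "region:" + token.owningRegion->name;
}
```

In the patch itself, rule (1) is enforced indirectly by excluding tokens from `AnyType` in ODS. The explicit `forwarded` check here only models the intent.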
diff --git a/mlir/include/mlir/Dialect/Async/IR/Async.h b/mlir/include/mlir/Dialect/Async/IR/Async.h
index f16e87e71373a..fc0b086126f52 100644
--- a/mlir/include/mlir/Dialect/Async/IR/Async.h
+++ b/mlir/include/mlir/Dialect/Async/IR/Async.h
@@ -50,7 +50,7 @@ namespace async {
/// Returns true if the type is reference counted at runtime.
inline bool isRefCounted(Type type) {
- return isa<TokenType, ValueType, GroupType>(type);
+ return isa<async::TokenType, ValueType, GroupType>(type);
}
} // namespace async
diff --git a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
index 2cebeac767f29..058f58bda6433 100644
--- a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
+++ b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
@@ -174,7 +174,9 @@ def Async_FuncOp : Async_Op<"func",
unsigned getNumResults() {return getResultTypes().size();}
/// Is the async func stateful
- bool isStateful() { return isa<TokenType>(getFunctionType().getResult(0));}
+ bool isStateful() {
+ return isa<async::TokenType>(getFunctionType().getResult(0));
+ }
//===------------------------------------------------------------------===//
// OpAsmOpInterface Methods
diff --git a/mlir/include/mlir/IR/BuiltinDialectBytecode.td b/mlir/include/mlir/IR/BuiltinDialectBytecode.td
index c97d093c84e51..99bcf70c77564 100644
--- a/mlir/include/mlir/IR/BuiltinDialectBytecode.td
+++ b/mlir/include/mlir/IR/BuiltinDialectBytecode.td
@@ -294,6 +294,8 @@ def UnrankedTensorType : DialectType<(type
Type:$elementType
)>;
+def TokenType : DialectType<(type)>;
+
let cType = "VectorType" in {
def VectorType : DialectType<(type
Array<SignedVarIntList>:$shape,
@@ -371,7 +373,8 @@ def BuiltinDialectTypes : DialectTypes<"Builtin"> {
UnrankedMemRefTypeWithMemSpace,
UnrankedTensorType,
VectorType,
- VectorTypeWithScalableDims
+ VectorTypeWithScalableDims,
+ TokenType
];
}
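`TokenType` is appended at the end of `BuiltinDialectTypes` rather than inserted in the middle of the list. One plausible reason, which is an assumption and is not stated in the patch, is that builtin bytecode encodes each type by its index in this list. Appending would then leave the codes of all existing types unchanged, so previously written bytecode keeps decoding the same way. The sketch below models only that assumption. `TypeList`, `typeCode` and `typeFromCode` are hypothetical, and the lists used with them are abbreviated tails of the real one, so the concrete indices are illustrative.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <string>
#include <vector>

// Hypothetical model: a type's bytecode code is its index in the list.
using TypeList = std::vector<std::string>;

// Encode: look up the position of a type name in the list.
std::optional<std::size_t> typeCode(const TypeList &types,
                                    const std::string &name) {
  for (std::size_t i = 0; i < types.size(); ++i)
    if (types[i] == name)
      return i;
  return std::nullopt;
}

// Decode: map a code back to a type name, if this list knows it.
std::optional<std::string> typeFromCode(const TypeList &types,
                                        std::size_t code) {
  if (code >= types.size())
    return std::nullopt; // e.g. a reader whose list predates the new type
  return types[code];
}
```

Under this model, inserting `TokenType` before `VectorType` would shift `VectorType`'s code, while appending it changes nothing that already exists.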
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index 20c41c5f79729..33e2a48e26386 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -1075,6 +1075,33 @@ def Builtin_None : Builtin_Type<"None", "none"> {
}];
}
+//===----------------------------------------------------------------------===//
+// TokenType
+//===----------------------------------------------------------------------===//
+
+def Builtin_Token : Builtin_Type<"Token", "token"> {
+ let summary = "Token type for static def-use links";
+ let description = [{
+ Intuitively, a *token* value is a pointer to an operation (via an OpResult)
+ or a pointer to a region (via an entry block argument).
+
+ More precisely, a token is an SSA value whose purpose is to encode a
+ static def-use relationship between operations or regions. It carries
+ no runtime data and is not allowed to flow through "regular"
+ value-forwarding constructs. A token's provenance cannot be obscured
+ through value forwarding.
+
+ See the "Tokens" design note in the documentation for details on the
+ structural contract.
+
+ Syntax:
+
+ ```
+ token-type ::= `token`
+ ```
+ }];
+}
+
//===----------------------------------------------------------------------===//
// OpaqueType
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index 57caaae08462f..86066bcda73e6 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -165,8 +165,24 @@ class SameBuildabilityAs<Type type, code builder> {
code builderCall = !if(!empty(type.builderCall), "", builder);
}
-// Any type at all.
-def AnyType : Type<CPred<"true">, "any type">;
+// Whether a type is the builtin `TokenType`.
+def IsTokenTypePred : CPred<"::llvm::isa<::mlir::TokenType>($_self)">;
+
+// Any non-token type. Tokens are excluded by default to prevent ops that
+// accept arbitrary types from accidentally accepting tokens as operands /
+// results, since a token must not be value-forwarded. Ops that legitimately
+// want to accept any type, including tokens, should use `AnyTypeOrToken`
+// instead.
+def AnyType : Type<Neg<IsTokenTypePred>, "any non-token type">;
+
+// Any type at all, including tokens. Used by ops that explicitly opt in to
+// accepting tokens (e.g. ops in interfaces such as `CallOpInterface`,
+// `BranchOpInterface`, etc. that legitimately handle arbitrary types).
+def AnyTypeOrToken : Type<CPred<"true">, "any type">;
+
+// The builtin token type.
+def Token : Type<IsTokenTypePred, "token", "::mlir::TokenType">,
+ BuildableType<"$_builder.getType<::mlir::TokenType>()">;
// None type
def NoneType : Type<CPred<"::llvm::isa<::mlir::NoneType>($_self)">, "none type",
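The three constraints can be modeled as plain C++ predicates paired with the summary string that ODS places in diagnostics. This is a hypothetical sketch, not generated ODS code; `ToyType`, `Constraint` and `verifyOperands` are invented names. It shows why the expected-error strings in the updated tests now read "any non-token type". It also shows that a token operand is rejected by `AnyType` but accepted by `AnyTypeOrToken` and `Token`.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <string>
#include <vector>

// Hypothetical stand-in for an MLIR type: just its printed form.
struct ToyType {
  std::string printed;
  bool isToken() const { return printed == "token"; }
};

// A constraint is a predicate plus the summary used in error messages.
struct Constraint {
  std::function<bool(const ToyType &)> pred;
  std::string summary;
};

// Mirror `Neg<IsTokenTypePred>`, `CPred<"true">` and `IsTokenTypePred`.
const Constraint AnyType{[](const ToyType &t) { return !t.isToken(); },
                         "any non-token type"};
const Constraint AnyTypeOrToken{[](const ToyType &) { return true; },
                                "any type"};
const Constraint Token{[](const ToyType &t) { return t.isToken(); }, "token"};

// Checks every operand against one constraint, roughly like a verifier for
// `Variadic<C>`. Returns an empty string on success.
std::string verifyOperands(const std::string &opName,
                           const std::vector<ToyType> &operands,
                           const Constraint &c) {
  for (std::size_t i = 0; i < operands.size(); ++i)
    if (!c.pred(operands[i]))
      return "'" + opName + "' op operand #" + std::to_string(i) +
             " must be variadic of " + c.summary + ", but got '" +
             operands[i].printed + "'";
  return "";
}
```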
diff --git a/mlir/lib/AsmParser/TokenKinds.def b/mlir/lib/AsmParser/TokenKinds.def
index fe7c53753e156..f5e5c25832a30 100644
--- a/mlir/lib/AsmParser/TokenKinds.def
+++ b/mlir/lib/AsmParser/TokenKinds.def
@@ -127,6 +127,7 @@ TOK_KEYWORD(symbol)
TOK_KEYWORD(tensor)
TOK_KEYWORD(tf32)
TOK_KEYWORD(to)
+TOK_KEYWORD(token)
TOK_KEYWORD(true)
TOK_KEYWORD(tuple)
TOK_KEYWORD(type)
diff --git a/mlir/lib/AsmParser/TypeParser.cpp b/mlir/lib/AsmParser/TypeParser.cpp
index a461ebed967a8..2cdec14d65fa6 100644
--- a/mlir/lib/AsmParser/TypeParser.cpp
+++ b/mlir/lib/AsmParser/TypeParser.cpp
@@ -58,6 +58,7 @@ OptionalParseResult Parser::parseOptionalType(Type &type) {
case Token::kw_f128:
case Token::kw_index:
case Token::kw_none:
+ case Token::kw_token:
case Token::exclamation_identifier:
return failure(!(type = parseType()));
@@ -371,6 +372,11 @@ Type Parser::parseNonFunctionType() {
consumeToken(Token::kw_none);
return builder.getNoneType();
+ // token-type
+ case Token::kw_token:
+ consumeToken(Token::kw_token);
+ return builder.getType<TokenType>();
+
// extended type
case Token::exclamation_identifier:
return parseExtendedType();
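With `token` registered via `TOK_KEYWORD`, `parseOptionalType` now treats it as the start of a type, and `parseNonFunctionType` builds the builtin `TokenType` from it, the same way `none` is handled. Together with the `AsmPrinter` case added in this patch, the type round-trips as `token`. The following is a heavily simplified, hypothetical model of that keyword dispatch. `BuiltinKind`, `parseSimpleBuiltinType` and `printBuiltinType` are invented for this sketch and cover only three keywords.

```cpp
#include <cassert>
#include <optional>
#include <string>

// Hypothetical subset of parameterless builtin types.
enum class BuiltinKind { Index, None, Token };

// Keyword -> builtin type, in the spirit of the `case Token::kw_*` arms.
std::optional<BuiltinKind> parseSimpleBuiltinType(const std::string &kw) {
  if (kw == "index") return BuiltinKind::Index;
  if (kw == "none") return BuiltinKind::None;
  if (kw == "token") return BuiltinKind::Token;
  return std::nullopt; // not a simple builtin keyword
}

// Printer side: each kind prints as exactly the keyword it is parsed from.
std::string printBuiltinType(BuiltinKind kind) {
  switch (kind) {
  case BuiltinKind::Index: return "index";
  case BuiltinKind::None: return "none";
  case BuiltinKind::Token: return "token";
  }
  return ""; // unreachable for valid enum values
}
```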
diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
index 29e6552231f9c..7844c9dda877c 100644
--- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
+++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
@@ -89,7 +89,7 @@ struct AsyncAPI {
}
static FunctionType createTokenFunctionType(MLIRContext *ctx) {
- return FunctionType::get(ctx, {}, {TokenType::get(ctx)});
+ return FunctionType::get(ctx, {}, {async::TokenType::get(ctx)});
}
static FunctionType createValueFunctionType(MLIRContext *ctx) {
@@ -109,7 +109,7 @@ struct AsyncAPI {
}
static FunctionType emplaceTokenFunctionType(MLIRContext *ctx) {
- return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
+ return FunctionType::get(ctx, {async::TokenType::get(ctx)}, {});
}
static FunctionType emplaceValueFunctionType(MLIRContext *ctx) {
@@ -118,7 +118,7 @@ struct AsyncAPI {
}
static FunctionType setTokenErrorFunctionType(MLIRContext *ctx) {
- return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
+ return FunctionType::get(ctx, {async::TokenType::get(ctx)}, {});
}
static FunctionType setValueErrorFunctionType(MLIRContext *ctx) {
@@ -128,7 +128,7 @@ struct AsyncAPI {
static FunctionType isTokenErrorFunctionType(MLIRContext *ctx) {
auto i1 = IntegerType::get(ctx, 1);
- return FunctionType::get(ctx, {TokenType::get(ctx)}, {i1});
+ return FunctionType::get(ctx, {async::TokenType::get(ctx)}, {i1});
}
static FunctionType isValueErrorFunctionType(MLIRContext *ctx) {
@@ -143,7 +143,7 @@ struct AsyncAPI {
}
static FunctionType awaitTokenFunctionType(MLIRContext *ctx) {
- return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
+ return FunctionType::get(ctx, {async::TokenType::get(ctx)}, {});
}
static FunctionType awaitValueFunctionType(MLIRContext *ctx) {
@@ -162,13 +162,14 @@ struct AsyncAPI {
static FunctionType addTokenToGroupFunctionType(MLIRContext *ctx) {
auto i64 = IntegerType::get(ctx, 64);
- return FunctionType::get(ctx, {TokenType::get(ctx), GroupType::get(ctx)},
- {i64});
+ return FunctionType::get(
+ ctx, {async::TokenType::get(ctx), GroupType::get(ctx)}, {i64});
}
static FunctionType awaitTokenAndExecuteFunctionType(MLIRContext *ctx) {
auto ptrType = opaquePointerType(ctx);
- return FunctionType::get(ctx, {TokenType::get(ctx), ptrType, ptrType}, {});
+ return FunctionType::get(
+ ctx, {async::TokenType::get(ctx), ptrType, ptrType}, {});
}
static FunctionType awaitValueAndExecuteFunctionType(MLIRContext *ctx) {
@@ -291,7 +292,7 @@ class AsyncRuntimeTypeConverter : public TypeConverter {
}
static std::optional<Type> convertAsyncTypes(Type type) {
- if (isa<TokenType, GroupType, ValueType>(type))
+ if (isa<async::TokenType, GroupType, ValueType>(type))
return AsyncAPI::opaquePointerType(type.getContext());
if (isa<CoroIdType, CoroStateType>(type))
@@ -583,7 +584,7 @@ class RuntimeCreateOpLowering : public ConvertOpToLLVMPattern<RuntimeCreateOp> {
Type resultType = op->getResultTypes()[0];
// Tokens creation maps to a simple function call.
- if (isa<TokenType>(resultType)) {
+ if (isa<async::TokenType>(resultType)) {
rewriter.replaceOpWithNewOp<func::CallOp>(
op, kCreateToken, converter->convertType(resultType));
return success();
@@ -659,7 +660,7 @@ class RuntimeSetAvailableOpLowering
ConversionPatternRewriter &rewriter) const override {
StringRef apiFuncName =
TypeSwitch<Type, StringRef>(op.getOperand().getType())
- .Case<TokenType>([](Type) { return kEmplaceToken; })
+ .Case<async::TokenType>([](Type) { return kEmplaceToken; })
.Case<ValueType>([](Type) { return kEmplaceValue; });
rewriter.replaceOpWithNewOp<func::CallOp>(op, apiFuncName, TypeRange(),
@@ -685,7 +686,7 @@ class RuntimeSetErrorOpLowering
ConversionPatternRewriter &rewriter) const override {
StringRef apiFuncName =
TypeSwitch<Type, StringRef>(op.getOperand().getType())
- .Case<TokenType>([](Type) { return kSetTokenError; })
+ .Case<async::TokenType>([](Type) { return kSetTokenError; })
.Case<ValueType>([](Type) { return kSetValueError; });
rewriter.replaceOpWithNewOp<func::CallOp>(op, apiFuncName, TypeRange(),
@@ -710,7 +711,7 @@ class RuntimeIsErrorOpLowering : public OpConversionPattern<RuntimeIsErrorOp> {
ConversionPatternRewriter &rewriter) const override {
StringRef apiFuncName =
TypeSwitch<Type, StringRef>(op.getOperand().getType())
- .Case<TokenType>([](Type) { return kIsTokenError; })
+ .Case<async::TokenType>([](Type) { return kIsTokenError; })
.Case<GroupType>([](Type) { return kIsGroupError; })
.Case<ValueType>([](Type) { return kIsValueError; });
@@ -735,7 +736,7 @@ class RuntimeAwaitOpLowering : public OpConversionPattern<RuntimeAwaitOp> {
ConversionPatternRewriter &rewriter) const override {
StringRef apiFuncName =
TypeSwitch<Type, StringRef>(op.getOperand().getType())
- .Case<TokenType>([](Type) { return kAwaitToken; })
+ .Case<async::TokenType>([](Type) { return kAwaitToken; })
.Case<ValueType>([](Type) { return kAwaitValue; })
.Case<GroupType>([](Type) { return kAwaitGroup; });
@@ -763,7 +764,7 @@ class RuntimeAwaitAndResumeOpLowering
ConversionPatternRewriter &rewriter) const override {
StringRef apiFuncName =
TypeSwitch<Type, StringRef>(op.getOperand().getType())
- .Case<TokenType>([](Type) { return kAwaitTokenAndExecute; })
+ .Case<async::TokenType>([](Type) { return kAwaitTokenAndExecute; })
.Case<ValueType>([](Type) { return kAwaitValueAndExecute; })
.Case<GroupType>([](Type) { return kAwaitAllAndExecute; });
@@ -906,7 +907,7 @@ class RuntimeAddToGroupOpLowering
matchAndRewrite(RuntimeAddToGroupOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Currently we can only add tokens to the group.
- if (!isa<TokenType>(op.getOperand().getType()))
+ if (!isa<async::TokenType>(op.getOperand().getType()))
return rewriter.notifyMatchFailure(op, "only token type is supported");
// Replace with a runtime API function call.
@@ -1151,7 +1152,7 @@ class ConvertYieldOpTypes : public OpConversionPattern<async::YieldOp> {
void mlir::populateAsyncStructuralTypeConversionsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target) {
- typeConverter.addConversion([&](TokenType type) { return type; });
+ typeConverter.addConversion([&](async::TokenType type) { return type; });
typeConverter.addConversion([&](ValueType type) {
Type converted = typeConverter.convertType(type.getValueType());
return converted ? ValueType::get(converted) : converted;
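The mechanical `TokenType` to `async::TokenType` renames follow from the new builtin type being named `mlir::TokenType`. In a translation unit that brings both `mlir` and `mlir::async` into scope with using-directives, an unqualified `TokenType` becomes ambiguous. The minimal sketch below reproduces that situation with hypothetical stand-in structs; the `kind()` members are invented. It compiles only because every use is qualified.

```cpp
#include <cassert>
#include <string>

// Hypothetical stand-ins for the two types that now share a name.
namespace mlir {
struct TokenType {
  static std::string kind() { return "builtin token"; }
};
namespace async {
struct TokenType {
  static std::string kind() { return "async token"; }
};
} // namespace async
} // namespace mlir

using namespace mlir;
using namespace mlir::async;

// An unqualified `TokenType` here would be ambiguous between
// mlir::TokenType and mlir::async::TokenType and would not compile.
// `async::` is found through `using namespace mlir` and picks one.
std::string asyncTokenKind() { return async::TokenType::kind(); }
std::string builtinTokenKind() { return mlir::TokenType::kind(); }
```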
diff --git a/mlir/lib/Dialect/Async/IR/Async.cpp b/mlir/lib/Dialect/Async/IR/Async.cpp
index 71be1d275280e..1713da07da60d 100644
--- a/mlir/lib/Dialect/Async/IR/Async.cpp
+++ b/mlir/lib/Dialect/Async/IR/Async.cpp
@@ -84,7 +84,7 @@ void ExecuteOp::build(OpBuilder &builder, OperationState &result,
// First result is always a token, and then `resultTypes` wrapped into
// `async.value`.
- result.addTypes({TokenType::get(result.getContext())});
+ result.addTypes({async::TokenType::get(result.getContext())});
for (Type type : resultTypes)
result.addTypes(ValueType::get(type));
@@ -139,7 +139,7 @@ ParseResult ExecuteOp::parse(OpAsmParser &parser, OperationState &result) {
// Sizes of parsed variadic operands, will be updated below after parsing.
int32_t numDependencies = 0;
- auto tokenTy = TokenType::get(ctx);
+ auto tokenTy = async::TokenType::get(ctx);
// Parse dependency tokens.
if (succeeded(parser.parseOptionalLSquare())) {
@@ -280,7 +280,7 @@ LogicalResult AwaitOp::verify() {
Type argType = getOperand().getType();
// Awaiting on a token does not have any results.
- if (llvm::isa<TokenType>(argType) && !getResultTypes().empty())
+ if (llvm::isa<async::TokenType>(argType) && !getResultTypes().empty())
return emitOpError("awaiting on a token must have empty result");
// Awaiting on a value unwraps the async value type.
@@ -345,12 +345,12 @@ LogicalResult FuncOp::verify() {
for (unsigned i = 0, e = resultTypes.size(); i != e; ++i) {
auto type = resultTypes[i];
- if (!llvm::isa<TokenType>(type) && !llvm::isa<ValueType>(type))
+ if (!llvm::isa<async::TokenType>(type) && !llvm::isa<ValueType>(type))
return emitOpError() << "result type must be async value type or async "
"token type, but got "
<< type;
// We only allow AsyncToken appear as the first return value
- if (llvm::isa<TokenType>(type) && i != 0) {
+ if (llvm::isa<async::TokenType>(type) && i != 0) {
return emitOpError()
<< " results' (optional) async token type is expected "
"to appear as the 1st return value, but got "
diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp
index 91e37dd9ac36e..2a726f3fd2999 100644
--- a/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp
@@ -526,7 +526,7 @@ void AsyncRuntimePolicyBasedRefCountingPass::initializeDefaultPolicy() {
Operation *op = operand.getOwner();
Type type = operand.get().getType();
- bool isToken = isa<TokenType>(type);
+ bool isToken = isa<async::TokenType>(type);
bool isGroup = isa<GroupType>(type);
bool isValue = isa<ValueType>(type);
diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
index 0c5bcfe631c6c..8b4aaf56fdf76 100644
--- a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
@@ -180,13 +180,14 @@ static CoroMachinery setupCoroMachinery(func::FuncOp func) {
// Allocate async token/values that we will return from a ramp function.
// ------------------------------------------------------------------------ //
- // We treat TokenType as state update marker to represent side-effects of
- // async computations
- bool isStateful = isa<TokenType>(func.getResultTypes().front());
+ // We treat async::TokenType as state update marker to represent
+ // side-effects of async computations
+ bool isStateful = isa<async::TokenType>(func.getResultTypes().front());
std::optional<Value> retToken;
if (isStateful)
- retToken.emplace(RuntimeCreateOp::create(builder, TokenType::get(ctx)));
+ retToken.emplace(
+ RuntimeCreateOp::create(builder, async::TokenType::get(ctx)));
llvm::SmallVector<Value, 4> retValues;
ArrayRef<Type> resValueTypes =
@@ -655,8 +656,9 @@ class AwaitOpLoweringBase : public OpConversionPattern<AwaitType> {
};
/// Lowering for `async.await` with a token operand.
-class AwaitTokenOpLowering : public AwaitOpLoweringBase<AwaitOp, TokenType> {
- using Base = AwaitOpLoweringBase<AwaitOp, TokenType>;
+class AwaitTokenOpLowering
+ : public AwaitOpLoweringBase<AwaitOp, async::TokenType> {
+ using Base = AwaitOpLoweringBase<AwaitOp, async::TokenType>;
public:
using Base::Base;
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
index 705d07d3e6c42..87df1e2df6102 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
@@ -238,13 +238,41 @@ Type LLVMStructType::parse(AsmParser &parser) {
}
/// Parses a type appearing inside another LLVM dialect-compatible type. This
-/// will try to parse any type in full form (including types with the `!llvm`
-/// prefix), and on failure fall back to parsing the short-hand version of the
-/// LLVM dialect types without the `!llvm` prefix.
+/// will first try to parse the LLVM dialect's short-hand keyword form (e.g.
+/// `token`, `void`, `ptr`, ...) and, failing that, fall back to parsing any
+/// MLIR type in full form (including types with the `!llvm` prefix).
+///
+/// Trying the short-hand form first matters because some LLVM short-hand
+/// keywords (notably `token`) collide with builtin type keywords whose
+/// semantics differ from the LLVM dialect's. Inside an `!llvm.<...>` type, the
+/// LLVM-specific meaning must always win.
static Type dispatchParse(AsmParser &parser, bool allowAny = true) {
SMLoc keyLoc = parser.getCurrentLocation();
+ MLIRContext *ctx = parser.getContext();
+
+ // Try parsing the LLVM dialect's short-hand keyword form first.
+ StringRef key;
+ if (succeeded(parser.parseOptionalKeyword(
+ &key, {"void", "ppc_fp128", "token", "label", "metadata", "func",
+ "ptr", "array", "struct", "target", "x86_amx"}))) {
+ // `parseOptionalKeyword` already restricted `key` to one of the cases
+ // below, so the `Default` is unreachable.
+ return StringSwitch<function_ref<Type()>>(key)
+ .Case("void", [&] { return LLVMVoidType::get(ctx); })
+ .Case("ppc_fp128", [&] { return LLVMPPCFP128Type::get(ctx); })
+ .Case("token", [&] { return LLVMTokenType::get(ctx); })
+ .Case("label", [&] { return LLVMLabelType::get(ctx); })
+ .Case("metadata", [&] { return LLVMMetadataType::get(ctx); })
+ .Case("func", [&] { return LLVMFunctionType::parse(parser); })
+ .Case("ptr", [&] { return LLVMPointerType::parse(parser); })
+ .Case("array", [&] { return LLVMArrayType::parse(parser); })
+ .Case("struct", [&] { return LLVMStructType::parse(parser); })
+ .Case("target", [&] { return LLVMTargetExtType::parse(parser); })
+ .Case("x86_amx", [&] { return LLVMX86AMXType::get(ctx); })
+ .Default([] { return Type(); })();
+ }
- // Try parsing any MLIR type.
+ // Otherwise, try parsing any MLIR type (only when allowed).
Type type;
OptionalParseResult result = parser.parseOptionalType(type);
if (result.has_value()) {
@@ -257,28 +285,12 @@ static Type dispatchParse(AsmParser &parser, bool allowAny = true) {
return type;
}
- // If no type found, fallback to the shorthand form.
- StringRef key;
+ // Neither a known LLVM short-hand keyword nor a parseable MLIR type.
+ // Re-run `parseKeyword` to produce a useful error message.
if (failed(parser.parseKeyword(&key)))
return Type();
-
- MLIRContext *ctx = parser.getContext();
- return StringSwitch<function_ref<Type()>>(key)
- .Case("void", [&] { return LLVMVoidType::get(ctx); })
- .Case("ppc_fp128", [&] { return LLVMPPCFP128Type::get(ctx); })
- .Case("token", [&] { return LLVMTokenType::get(ctx); })
- .Case("label", [&] { return LLVMLabelType::get(ctx); })
- .Case("metadata", [&] { return LLVMMetadataType::get(ctx); })
- .Case("func", [&] { return LLVMFunctionType::parse(parser); })
- .Case("ptr", [&] { return LLVMPointerType::parse(parser); })
- .Case("array", [&] { return LLVMArrayType::parse(parser); })
- .Case("struct", [&] { return LLVMStructType::parse(parser); })
- .Case("target", [&] { return LLVMTargetExtType::parse(parser); })
- .Case("x86_amx", [&] { return LLVMX86AMXType::get(ctx); })
- .Default([&] {
- parser.emitError(keyLoc) << "unknown LLVM type: " << key;
- return Type();
- })();
+ parser.emitError(keyLoc) << "unknown LLVM type: " << key;
+ return Type();
}
/// Helper to use in parse lists.
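Because `token` is now also a builtin type keyword, the order in which `dispatchParse` tries its two strategies decides what a nested `token` means. The model below contrasts trying the generic MLIR parser first with trying the LLVM short-hand keywords first, which is what the patch now does. Only the latter yields the LLVM token type for a nested `token`. All names in it (`Parsed`, `parseGenericType`, `parseLLVMShorthand` and the two dispatch functions) are hypothetical; this is not the `AsmParser` API, and `allowAny` is not modeled.

```cpp
#include <cassert>
#include <optional>
#include <string>

// Hypothetical result of parsing one nested type inside an !llvm type.
enum class Parsed { BuiltinToken, BuiltinI32, LLVMToken, LLVMVoid, Error };

// Stand-in for `parseOptionalType`, which now also accepts `token`.
std::optional<Parsed> parseGenericType(const std::string &kw) {
  if (kw == "token") return Parsed::BuiltinToken;
  if (kw == "i32") return Parsed::BuiltinI32;
  return std::nullopt;
}

// Stand-in for the LLVM dialect's short-hand keyword table.
std::optional<Parsed> parseLLVMShorthand(const std::string &kw) {
  if (kw == "token") return Parsed::LLVMToken;
  if (kw == "void") return Parsed::LLVMVoid;
  return std::nullopt;
}

// Previous order: generic MLIR type first, short-hand as a fallback.
Parsed dispatchGenericFirst(const std::string &kw) {
  if (auto t = parseGenericType(kw)) return *t;
  if (auto t = parseLLVMShorthand(kw)) return *t;
  return Parsed::Error; // "unknown LLVM type"
}

// New order: short-hand first, so LLVM-specific keywords always win.
Parsed dispatchShorthandFirst(const std::string &kw) {
  if (auto t = parseLLVMShorthand(kw)) return *t;
  if (auto t = parseGenericType(kw)) return *t;
  return Parsed::Error; // "unknown LLVM type"
}
```

In this model, only keywords known to both parsers (here just `token`) resolve differently under the two orders.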
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index ec270db189081..ca5c2d2a88ee5 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -2907,6 +2907,7 @@ void AsmPrinter::Impl::printTypeImpl(Type type) {
os << '>';
})
.Case<NoneType>([&](Type) { os << "none"; })
+ .Case<TokenType>([&](Type) { os << "token"; })
.Case([&](GraphType graphTy) {
os << '(';
interleaveComma(graphTy.getInputs(), [&](Type ty) { printType(ty); });
diff --git a/mlir/test/Dialect/ArmSME/invalid.mlir b/mlir/test/Dialect/ArmSME/invalid.mlir
index 8c5a098a0c785..f00945e18cc1f 100644
--- a/mlir/test/Dialect/ArmSME/invalid.mlir
+++ b/mlir/test/Dialect/ArmSME/invalid.mlir
@@ -132,7 +132,7 @@ func.func @arm_sme_tile_load__pad_but_no_mask(%src : memref<?x?xf64>, %pad : f64
func.func @arm_sme_tile_load__bad_memref_rank(%src : memref<?xf64>, %pad : f64) {
%c0 = arith.constant 0 : index
- // expected-error@+1 {{op operand #0 must be 2D memref of any type values, but got 'memref<?xf64>'}}
+ // expected-error@+1 {{op operand #0 must be 2D memref of any non-token type values, but got 'memref<?xf64>'}}
%tile = arm_sme.tile_load %src[%c0], %pad, : memref<?xf64>, vector<[2]x[2]xf64>
return
}
@@ -186,7 +186,7 @@ func.func @arm_sme_tile_store__bad_mask_type(%tile : vector<[16]x[16]xi8>, %mask
func.func @arm_sme_tile_store__bad_memref_rank(%tile : vector<[16]x[16]xi8>, %dest : memref<?xi8>) {
%c0 = arith.constant 0 : index
- // expected-error@+1 {{op operand #1 must be 2D memref of any type values, but got 'memref<?xi8>'}}
+ // expected-error@+1 {{op operand #1 must be 2D memref of any non-token type values, but got 'memref<?xi8>'}}
arm_sme.tile_store %tile, %dest[%c0] : memref<?xi8>, vector<[16]x[16]xi8>
return
}
diff --git a/mlir/test/Dialect/Builtin/Bytecode/builtin_fixed.mlir b/mlir/test/Dialect/Builtin/Bytecode/builtin_fixed.mlir
index 638689406ab17..31061a6fa8653 100644
--- a/mlir/test/Dialect/Builtin/Bytecode/builtin_fixed.mlir
+++ b/mlir/test/Dialect/Builtin/Bytecode/builtin_fixed.mlir
@@ -282,6 +282,7 @@ module @TestBasicTypes attributes {
// CHECK-DAG: bytecode.ui64 = ui64
// CHECK-DAG: bytecode.index = index
// CHECK-DAG: bytecode.none = none
+ // CHECK-DAG: bytecode.token = token
bytecode.i1 = i1,
bytecode.i8 = i8,
bytecode.i32 = i32,
@@ -289,7 +290,8 @@ module @TestBasicTypes attributes {
bytecode.si32 = si32,
bytecode.ui64 = ui64,
bytecode.index = index,
- bytecode.none = none
+ bytecode.none = none,
+ bytecode.token = token
} {} loc(unknown)
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Builtin/Bytecode/builtin_fixed_0.mlirbc b/mlir/test/Dialect/Builtin/Bytecode/builtin_fixed_0.mlirbc
index e2f58a42751a1cddc0dc2845d06735faafdc8b32..76a764de40eced87c6cc3d98b1aabf8153649989 100644
GIT binary patch
delta 551
zcmX}gOK1~89LI5Ie)FGMT_#JrG=aosX;OEI!Kw%zJXE5jzNk{fiUuuJ8j)ZXDIPqu
zUhKhxv>mYG!DA1CUIHEzdhEeeMaB9;MMOo#`ob0w9|OW3zWhG?7ITyNu@)6lD2}mk
zeECgGy6*Vn@=~wpMPABdH50NRqr6PUbt^$7Wyi}0c{k<zK2YZK^*%NR=9@1fz)EzO
z1ZB5wq}$r3yHFd@iB0IjcI*V*(<OSKH`7C1rbnRtIEX{2f}ZFt^h|G~=TOl}oPu9C
z4ZVY2>TYv_(j%|zeGca_feW|<y^CJ!9(tp9(>r<(T){QaN4<wWfxH_iOhfP0`{*;&
zE!@F<%;70sfWGPqeb)!*hwh_A&<D)p8x}x6bwB;mhv_%Kg%tF+ZGaKypk|R2X^|Dw
zB6)O#E#nE^%qhvEqeyegnpWA*ie^}LuwipmwzA8nA=gkXgQnVqnjKk>HYF4oV?C@F
z`PhS45qlI->?sstUq><a3|7T{i`B9JU~TM#Soh!3V+=W7T5^_3O^rswbUW>s&V^>k
tX=g$GqSGm9=y+{lus)kO!;4kP)Rb+`CKXd%wyJH^He*|`B^~=a>>r_ZaOMC2
delta 579
zcmX}gU5HF^9L90~|8t&mW@q}%X=a$I8D`p<WlWRxR(nBKLqihUmD$#`w1kFo;bLyo
zPRUDA+I4E(xp1M}xG=RBgpiT<koOR>LRQFZR+dxhTzq>zJ;yWS>A`7MoJnJsN)Amw
zuMQ18FSZ2ocGiyBs^ZL6skog;1Onw$Ip&U?No4}DNF)Ljb)&vT4Wbqr)8BPdPv}V<
z(lFsi5Me|SM-pkw0xw7lFUewF=2fwHO@!BF32#V<w<N+luoj>Li_nQ>SOMOXrMxc%
zK9uEr1a_l@jo1V}ks_bTO1^-V-HF|Z<n};T@uim4&R)iAuibwi4&pG5;uvHN-^e<?
zlP<pJ2jC=5fgfc(KZEvJ<VGMHq=#Q%T|omkaTgEp82l=|{3d1okUstd-(nP>Fb0mx
zX8w|X{sty63I1u_s`wZ58$ly%L=EbxI=fAo+OLC}VV&K9sAk;Bnf=ZYvt3o4bLI>+
z>^wE+v64Wj;)r#-){a)ou%MLcQfra%yd5(=4<PIL1ah9wAn*Ai=6HUExt^Pt@7XjK
z{<mo>g`_S_IiL$6r_a}hfpps0?`x}{_H~-|X0WRBB~w|$ZVk7lTZg5rl3Uen#BJ2g
J^tr!J{R0}mc9Q@A
diff --git a/mlir/test/Dialect/Builtin/Bytecode/types.mlir b/mlir/test/Dialect/Builtin/Bytecode/types.mlir
index bcfbf64c833dd..f3d15813b3978 100644
--- a/mlir/test/Dialect/Builtin/Bytecode/types.mlir
+++ b/mlir/test/Dialect/Builtin/Bytecode/types.mlir
@@ -145,3 +145,13 @@ module @TestVector attributes {
bytecode.test = vector<8x8x128xi8>,
bytecode.test1 = vector<8x[8]xf32>
} {}
+
+//===----------------------------------------------------------------------===//
+// TokenType
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @TestToken
+module @TestToken attributes {
+ // CHECK: bytecode.test = token
+ bytecode.test = token
+} {}
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 06f3fcb41190b..a446cfcc4eec1 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -415,7 +415,7 @@ func.func @illegal_fill_memref_with_tensor_return
func.func @illegal_fill_tensor_with_memref_return
(%arg0 : tensor<?x?xf32>, %arg1 : f32) -> memref<?x?xf32>
{
- // expected-error @+1 {{result #0 must be variadic of ranked tensor of any type values, but got 'memref<?x?xf32>'}}
+ // expected-error @+1 {{result #0 must be variadic of ranked tensor of any non-token type values, but got 'memref<?x?xf32>'}}
%0 = linalg.fill ins(%arg1 : f32) outs(%arg0 : tensor<?x?xf32>) -> memref<?x?xf32>
return %0 : memref<?x?xf32>
}
@@ -468,7 +468,7 @@ func.func @invalid_scalar_input_matmul(%arg0: f32, %arg1: memref<3x4xf32>, %arg2
// -----
func.func @invalid_scalar_output_matmul(%arg0: memref<2x3xf32>, %arg1: memref<3x4xf32>, %arg2: f32) {
- // expected-error @+1 {{'linalg.matmul' op operand #2 must be variadic of shaped of any type values, but got 'f32'}}
+ // expected-error @+1 {{'linalg.matmul' op operand #2 must be variadic of shaped of any non-token type values, but got 'f32'}}
linalg.matmul ins(%arg0, %arg1 : memref<2x3xf32>, memref<3x4xf32>)
outs(%arg2 : f32)
return
diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index 2f061a1bb773e..ecffd683a98c2 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -1037,7 +1037,7 @@ func.func @test_alloc_memref_map_rank_mismatch() {
// -----
func.func @rank(%0: f32) {
- // expected-error@+1 {{'memref.rank' op operand #0 must be ranked or unranked memref of any type values}}
+ // expected-error@+1 {{'memref.rank' op operand #0 must be ranked or unranked memref of any non-token type values}}
"memref.rank"(%0): (f32)->index
return
}
@@ -1172,7 +1172,7 @@ func.func @memref_realloc_type(%src : memref<256xf32>) -> memref<?xi32>{
// Asking the dimension of a 0-D shape doesn't make sense.
func.func @dim_0_ranked(%arg : memref<f32>, %arg1 : index) {
- memref.dim %arg, %arg1 : memref<f32> // expected-error {{'memref.dim' op operand #0 must be unranked.memref of any type values or non-0-ranked.memref of any type values, but got 'memref<f32>'}}
+ memref.dim %arg, %arg1 : memref<f32> // expected-error {{'memref.dim' op operand #0 must be unranked.memref of any non-token type values or non-0-ranked.memref of any non-token type values, but got 'memref<f32>'}}
return
}
diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
index ae706b9b148a6..d14229b011f11 100644
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -1,7 +1,7 @@
// RUN: mlir-opt %s -split-input-file -verify-diagnostics
func.func @invalid_new_dense(%arg0: !llvm.ptr) -> tensor<32xf32> {
- // expected-error@+1 {{'sparse_tensor.new' op result #0 must be sparse tensor of any type values, but got 'tensor<32xf32>'}}
+ // expected-error@+1 {{'sparse_tensor.new' op result #0 must be sparse tensor of any non-token type values, but got 'tensor<32xf32>'}}
%0 = sparse_tensor.new %arg0 : !llvm.ptr to tensor<32xf32>
return %0 : tensor<32xf32>
}
@@ -96,7 +96,7 @@ func.func @invalid_unpack_mis_position(%sp: tensor<2x100xf64, #CSR>, %values: te
// -----
func.func @invalid_positions_dense(%arg0: tensor<128xf64>) -> memref<?xindex> {
- // expected-error@+1 {{'sparse_tensor.positions' op operand #0 must be sparse tensor of any type values, but got 'tensor<128xf64>'}}
+ // expected-error@+1 {{'sparse_tensor.positions' op operand #0 must be sparse tensor of any non-token type values, but got 'tensor<128xf64>'}}
%0 = sparse_tensor.positions %arg0 { level = 0 : index } : tensor<128xf64> to memref<?xindex>
return %0 : memref<?xindex>
}
@@ -104,7 +104,7 @@ func.func @invalid_positions_dense(%arg0: tensor<128xf64>) -> memref<?xindex> {
// -----
func.func @invalid_positions_unranked(%arg0: tensor<*xf64>) -> memref<?xindex> {
- // expected-error@+1 {{'sparse_tensor.positions' op operand #0 must be sparse tensor of any type values, but got 'tensor<*xf64>'}}
+ // expected-error@+1 {{'sparse_tensor.positions' op operand #0 must be sparse tensor of any non-token type values, but got 'tensor<*xf64>'}}
%0 = "sparse_tensor.positions"(%arg0) { level = 0 : index } : (tensor<*xf64>) -> (memref<?xindex>)
return %0 : memref<?xindex>
}
@@ -132,7 +132,7 @@ func.func @positions_oob(%arg0: tensor<128xf64, #SparseVector>) -> memref<?xinde
// -----
func.func @invalid_indices_dense(%arg0: tensor<10x10xi32>) -> memref<?xindex> {
- // expected-error@+1 {{'sparse_tensor.coordinates' op operand #0 must be sparse tensor of any type values, but got 'tensor<10x10xi32>'}}
+ // expected-error@+1 {{'sparse_tensor.coordinates' op operand #0 must be sparse tensor of any non-token type values, but got 'tensor<10x10xi32>'}}
%0 = sparse_tensor.coordinates %arg0 { level = 1 : index } : tensor<10x10xi32> to memref<?xindex>
return %0 : memref<?xindex>
}
@@ -140,7 +140,7 @@ func.func @invalid_indices_dense(%arg0: tensor<10x10xi32>) -> memref<?xindex> {
// -----
func.func @invalid_indices_unranked(%arg0: tensor<*xf64>) -> memref<?xindex> {
- // expected-error@+1 {{'sparse_tensor.coordinates' op operand #0 must be sparse tensor of any type values, but got 'tensor<*xf64>'}}
+ // expected-error@+1 {{'sparse_tensor.coordinates' op operand #0 must be sparse tensor of any non-token type values, but got 'tensor<*xf64>'}}
%0 = "sparse_tensor.coordinates"(%arg0) { level = 0 : index } : (tensor<*xf64>) -> (memref<?xindex>)
return %0 : memref<?xindex>
}
@@ -168,7 +168,7 @@ func.func @indices_oob(%arg0: tensor<128xf64, #SparseVector>) -> memref<?xindex>
// -----
func.func @invalid_values_dense(%arg0: tensor<1024xf32>) -> memref<?xf32> {
- // expected-error@+1 {{'sparse_tensor.values' op operand #0 must be sparse tensor of any type values, but got 'tensor<1024xf32>'}}
+ // expected-error@+1 {{'sparse_tensor.values' op operand #0 must be sparse tensor of any non-token type values, but got 'tensor<1024xf32>'}}
%0 = sparse_tensor.values %arg0 : tensor<1024xf32> to memref<?xf32>
return %0 : memref<?xf32>
}
@@ -186,7 +186,7 @@ func.func @indices_buffer_noncoo(%arg0: tensor<128xf64, #SparseVector>) -> memre
// -----
func.func @indices_buffer_dense(%arg0: tensor<1024xf32>) -> memref<?xindex> {
- // expected-error@+1 {{must be sparse tensor of any type values}}
+ // expected-error@+1 {{must be sparse tensor of any non-token type values}}
%0 = sparse_tensor.coordinates_buffer %arg0 : tensor<1024xf32> to memref<?xindex>
return %0 : memref<?xindex>
}
@@ -283,7 +283,7 @@ func.func @sparse_get_md(%arg0: !sparse_tensor.storage_specifier<#COO>) -> index
// -----
func.func @sparse_unannotated_load(%arg0: tensor<16x32xf64>) -> tensor<16x32xf64> {
- // expected-error@+1 {{'sparse_tensor.load' op operand #0 must be sparse tensor of any type values, but got 'tensor<16x32xf64>'}}
+ // expected-error@+1 {{'sparse_tensor.load' op operand #0 must be sparse tensor of any non-token type values, but got 'tensor<16x32xf64>'}}
%0 = sparse_tensor.load %arg0 : tensor<16x32xf64>
return %0 : tensor<16x32xf64>
}
@@ -308,7 +308,7 @@ func.func @sparse_push_back_n(%arg0: index, %arg1: memref<?xf32>, %arg2: f32) ->
// -----
func.func @sparse_unannotated_expansion(%arg0: tensor<128xf64>) {
- // expected-error@+1 {{'sparse_tensor.expand' op operand #0 must be sparse tensor of any type values, but got 'tensor<128xf64>'}}
+ // expected-error@+1 {{'sparse_tensor.expand' op operand #0 must be sparse tensor of any non-token type values, but got 'tensor<128xf64>'}}
%values, %filled, %added, %count = sparse_tensor.expand %arg0
: tensor<128xf64> to memref<?xf64>, memref<?xi1>, memref<?xindex>
return
@@ -322,7 +322,7 @@ func.func @sparse_unannotated_compression(%arg0: memref<?xf64>,
%arg3: index,
%arg4: tensor<8x8xf64>,
%arg5: index) {
- // expected-error@+1 {{'sparse_tensor.compress' op operand #4 must be sparse tensor of any type values, but got 'tensor<8x8xf64>'}}
+ // expected-error@+1 {{'sparse_tensor.compress' op operand #4 must be sparse tensor of any non-token type values, but got 'tensor<8x8xf64>'}}
sparse_tensor.compress %arg0, %arg1, %arg2, %arg3 into %arg4[%arg5]
: memref<?xf64>, memref<?xi1>, memref<?xindex>, tensor<8x8xf64>
return
@@ -375,7 +375,7 @@ func.func @sparse_convert_dim_mismatch(%arg0: tensor<10x?xf32>) -> tensor<10x10x
// -----
func.func @invalid_out_dense(%arg0: tensor<10xf64>, %arg1: !llvm.ptr) {
- // expected-error@+1 {{'sparse_tensor.out' op operand #0 must be sparse tensor of any type values, but got 'tensor<10xf64>'}}
+ // expected-error@+1 {{'sparse_tensor.out' op operand #0 must be sparse tensor of any non-token type values, but got 'tensor<10xf64>'}}
sparse_tensor.out %arg0, %arg1 : tensor<10xf64>, !llvm.ptr
return
}
@@ -1022,7 +1022,7 @@ func.func @sparse_reinterpret_map(%t0 : tensor<6x12xi32, #BSR>) -> tensor<3x4x2x
#CSR = #sparse_tensor.encoding<{map = (d0, d1) -> (d0 : compressed, d1 : compressed)}>
func.func @sparse_print(%arg0: tensor<10x10xf64>) {
- // expected-error@+1 {{'sparse_tensor.print' op operand #0 must be sparse tensor of any type values}}
+ // expected-error@+1 {{'sparse_tensor.print' op operand #0 must be sparse tensor of any non-token type values}}
sparse_tensor.print %arg0 : tensor<10x10xf64>
return
}
diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir
index 6ee2f9911663f..a526d7ed61722 100644
--- a/mlir/test/Dialect/Tensor/invalid.mlir
+++ b/mlir/test/Dialect/Tensor/invalid.mlir
@@ -404,7 +404,7 @@ func.func @illegal_collapsing_reshape_mixed_tensor_2(%arg0 : tensor<?x4x5xf32>)
// -----
func.func @rank(%0: f32) {
- // expected-error@+1 {{'tensor.rank' op operand #0 must be tensor of any type values}}
+ // expected-error@+1 {{'tensor.rank' op operand #0 must be tensor of any non-token type values}}
"tensor.rank"(%0): (f32)->index
return
}
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index f90312c915334..36c697c78d93d 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -106,7 +106,7 @@ func.func @shuffle_index_out_of_range(%arg0: vector<2xf32>, %arg1: vector<2xf32>
// -----
func.func @shuffle_scalable_vec(%arg0: vector<[2]xf32>, %arg1: vector<[2]xf32>) {
- // expected-error@+1 {{'vector.shuffle' op operand #0 must be fixed-length vector of any type values}}
+ // expected-error@+1 {{'vector.shuffle' op operand #0 must be fixed-length vector of any non-token type values}}
%1 = vector.shuffle %arg0, %arg1 [0, 1, 2, 3] : vector<[2]xf32>, vector<[2]xf32>
}
@@ -1460,7 +1460,7 @@ func.func @maskedstore_memref_mismatch(%base: memref<?xf32>, %mask: vector<16xi1
func.func @gather_from_vector(%base: vector<16xf32>, %indices: vector<16xi32>,
%mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
%c0 = arith.constant 0 : index
- // expected-error@+1 {{'vector.gather' op operand #0 must be Tensor or MemRef of any type values, but got 'vector<16xf32>'}}
+ // expected-error@+1 {{'vector.gather' op operand #0 must be Tensor or MemRef of any non-token type values, but got 'vector<16xf32>'}}
%0 = vector.gather %base[%c0][%indices], %mask, %pass_thru
: vector<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
}
@@ -1557,7 +1557,7 @@ func.func @gather_tensor_alignment(%base: tensor<16xf32>, %indices: vector<16xi3
func.func @scatter_to_vector(%base: vector<16xf32>, %indices: vector<16xi32>,
%mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
%c0 = arith.constant 0 : index
- // expected-error@+1 {{'vector.scatter' op operand #0 must be Tensor or MemRef of any type values, but got 'vector<16xf32>'}}
+ // expected-error@+1 {{'vector.scatter' op operand #0 must be Tensor or MemRef of any non-token type values, but got 'vector<16xf32>'}}
vector.scatter %base[%c0][%indices], %mask, %pass_thru
: vector<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
}
@@ -1943,7 +1943,7 @@ func.func @invalid_outerproduct1(%src : memref<?xf32>, %lhs : vector<[4]x[4]xf32
// -----
func.func @deinterleave_zero_dim_fail(%vec : vector<f32>) {
- // expected-error @+1 {{'vector.deinterleave' op operand #0 must be vector of any type values, but got 'vector<f32>}}
+ // expected-error @+1 {{'vector.deinterleave' op operand #0 must be vector of any non-token type values, but got 'vector<f32>}}
%0, %1 = vector.deinterleave %vec : vector<f32> -> vector<f32>
return
}
@@ -2032,7 +2032,7 @@ func.func @from_elements_wrong_operand_type(%a: f32, %b: i32) {
// -----
func.func @invalid_from_elements_scalable(%a: f32, %b: i32) {
- // expected-error @+1 {{'dest' must be fixed-length vector of any type values, but got 'vector<[2]xf32>'}}
+ // expected-error @+1 {{'dest' must be fixed-length vector of any non-token type values, but got 'vector<[2]xf32>'}}
vector.from_elements %a, %b : vector<[2]xf32>
return
}
diff --git a/mlir/test/Dialect/traits.mlir b/mlir/test/Dialect/traits.mlir
index 4d583435adeee..ae48cadbf370f 100644
--- a/mlir/test/Dialect/traits.mlir
+++ b/mlir/test/Dialect/traits.mlir
@@ -58,7 +58,7 @@ func.func @broadcast_tensor_tensor_tensor(tensor<8x1x?x1xi32>, tensor<7x1x5xi32>
// Check incompatible vector and tensor result type
func.func @broadcast_scalar_vector_vector(tensor<4xf32>, tensor<4xf32>) -> vector<4xf32> {
^bb0(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>):
- // expected-error @+1 {{op result #0 must be tensor of any type values, but got 'vector<4xf32>'}}
+ // expected-error @+1 {{op result #0 must be tensor of any non-token type values, but got 'vector<4xf32>'}}
%0 = "test.broadcastable"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> vector<4xf32>
return %0 : vector<4xf32>
}
diff --git a/mlir/test/IR/operand.mlir b/mlir/test/IR/operand.mlir
index 507e37c775c0b..1ac12dc4b9556 100644
--- a/mlir/test/IR/operand.mlir
+++ b/mlir/test/IR/operand.mlir
@@ -13,7 +13,7 @@ func.func @correct_variadic_operand(%arg0: tensor<f32>, %arg1: f32) {
// -----
func.func @error_in_first_variadic_operand(%arg0: tensor<f32>, %arg1: f32) {
- // expected-error @+1 {{operand #1 must be variadic of tensor of any type}}
+ // expected-error @+1 {{operand #1 must be variadic of tensor of any non-token type}}
"test.mixed_normal_variadic_operand"(%arg0, %arg1, %arg0, %arg0, %arg0) : (tensor<f32>, f32, tensor<f32>, tensor<f32>, tensor<f32>) -> ()
return
}
@@ -21,7 +21,7 @@ func.func @error_in_first_variadic_operand(%arg0: tensor<f32>, %arg1: f32) {
// -----
func.func @error_in_normal_operand(%arg0: tensor<f32>, %arg1: f32) {
- // expected-error @+1 {{operand #2 must be tensor of any type}}
+ // expected-error @+1 {{operand #2 must be tensor of any non-token type}}
"test.mixed_normal_variadic_operand"(%arg0, %arg0, %arg1, %arg0, %arg0) : (tensor<f32>, tensor<f32>, f32, tensor<f32>, tensor<f32>) -> ()
return
}
@@ -29,7 +29,7 @@ func.func @error_in_normal_operand(%arg0: tensor<f32>, %arg1: f32) {
// -----
func.func @error_in_second_variadic_operand(%arg0: tensor<f32>, %arg1: f32) {
- // expected-error @+1 {{operand #3 must be variadic of tensor of any type}}
+ // expected-error @+1 {{operand #3 must be variadic of tensor of any non-token type}}
"test.mixed_normal_variadic_operand"(%arg0, %arg0, %arg0, %arg1, %arg0) : (tensor<f32>, tensor<f32>, tensor<f32>, f32, tensor<f32>) -> ()
return
}
diff --git a/mlir/test/IR/result.mlir b/mlir/test/IR/result.mlir
index 1e4eb3bede4c5..cdeae4202f0ff 100644
--- a/mlir/test/IR/result.mlir
+++ b/mlir/test/IR/result.mlir
@@ -13,7 +13,7 @@ func.func @correct_variadic_result() -> tensor<f32> {
// -----
func.func @error_in_first_variadic_result() -> tensor<f32> {
- // expected-error @+1 {{result #1 must be variadic of tensor of any type}}
+ // expected-error @+1 {{result #1 must be variadic of tensor of any non-token type}}
%0:5 = "test.mixed_normal_variadic_result"() : () -> (tensor<f32>, f32, tensor<f32>, tensor<f32>, tensor<f32>)
return %0#4 : tensor<f32>
}
@@ -21,7 +21,7 @@ func.func @error_in_first_variadic_result() -> tensor<f32> {
// -----
func.func @error_in_normal_result() -> tensor<f32> {
- // expected-error @+1 {{result #2 must be tensor of any type}}
+ // expected-error @+1 {{result #2 must be tensor of any non-token type}}
%0:5 = "test.mixed_normal_variadic_result"() : () -> (tensor<f32>, tensor<f32>, f32, tensor<f32>, tensor<f32>)
return %0#4 : tensor<f32>
}
@@ -29,7 +29,7 @@ func.func @error_in_normal_result() -> tensor<f32> {
// -----
func.func @error_in_second_variadic_result() -> tensor<f32> {
- // expected-error @+1 {{result #3 must be variadic of tensor of any type}}
+ // expected-error @+1 {{result #3 must be variadic of tensor of any non-token type}}
%0:5 = "test.mixed_normal_variadic_result"() : () -> (tensor<f32>, tensor<f32>, tensor<f32>, f32, tensor<f32>)
return %0#4 : tensor<f32>
}
diff --git a/mlir/test/IR/token-type.mlir b/mlir/test/IR/token-type.mlir
new file mode 100644
index 0000000000000..990a7917effe1
--- /dev/null
+++ b/mlir/test/IR/token-type.mlir
@@ -0,0 +1,60 @@
+// RUN: mlir-opt %s -verify-diagnostics -split-input-file | FileCheck %s
+
+// Tests for the builtin `token` type and the
+// `Token` / `AnyType` / `AnyTypeOrToken` ODS predicates.
+//
+// The default `AnyType` predicate excludes tokens, while `AnyTypeOrToken` and
+// `Token` accept them.
+
+// CHECK-LABEL: @token_produce_consume
+func.func @token_produce_consume() {
+ // CHECK: %[[T:.*]] = test.token.produce
+ %t = test.token.produce
+ // CHECK: test.token.consume %[[T]]
+ test.token.consume %t
+ // CHECK: test.token.any_or_token %[[T]] : token
+ test.token.any_or_token %t : token
+ return
+}
+
+// -----
+
+// `AnyTypeOrToken` also accepts non-token types.
+// CHECK-LABEL: @any_or_token_with_non_token
+func.func @any_or_token_with_non_token(%arg0: i32) {
+ // CHECK: test.token.any_or_token %{{.*}} : i32
+ test.token.any_or_token %arg0 : i32
+ return
+}
+
+// -----
+
+// `AnyType` accepts arbitrary non-token types.
+// CHECK-LABEL: @any_type_with_non_token
+func.func @any_type_with_non_token(%arg0: i32) {
+ // CHECK: test.token.any_type %{{.*}} : i32
+ test.token.any_type %arg0 : i32
+ return
+}
+
+// -----
+
+// `AnyType` rejects tokens by default.
+func.func @any_type_rejects_token() {
+ %t = test.token.produce
+ // expected-error @below {{operand #0 must be any non-token type}}
+ test.token.any_type %t : token
+ return
+}
+
+// -----
+
+// `Token` rejects non-token types. The op's operand type is fixed to the
+// builtin `token` (it's a `BuildableType`), so passing a non-token SSA value
+// fails at parse time with an SSA type mismatch.
+// expected-note @below {{prior use here}}
+func.func @token_rejects_non_token(%arg0: i32) {
+ // expected-error @below {{use of value '%arg0' expects different type than prior uses: 'token' vs 'i32'}}
+ test.token.consume %arg0
+ return
+}
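The acceptance rules exercised by `token-type.mlir` can be summarized with a small standalone C++ model. This is a sketch under stated assumptions, not MLIR's API: `ToyType`, `acceptsToken`, `acceptsAnyType`, and `acceptsAnyTypeOrToken` are hypothetical names, and types are modeled by their printed spelling. In MLIR itself these are ODS type constraints (`Token`, `AnyType`, `AnyTypeOrToken`) whose generated checks use `::llvm::isa<::mlir::TokenType>(...)`.

```cpp
#include <cassert>
#include <string>

// Hypothetical toy type (not mlir::Type): identified only by its printed form,
// e.g. "token" or "i32".
struct ToyType {
  std::string spelling;
  bool isToken() const { return spelling == "token"; }
};

// Models ODS `Token`: only the builtin token type is accepted.
bool acceptsToken(const ToyType &t) { return t.isToken(); }

// Models the new default `AnyType`: every type except tokens.
bool acceptsAnyType(const ToyType &t) { return !t.isToken(); }

// Models the opt-in `AnyTypeOrToken`: every type, tokens included.
bool acceptsAnyTypeOrToken(const ToyType &) { return true; }
```

Checking `{"token"}` and `{"i32"}` against each function gives the outcomes the test file expects: a token is accepted by `Token` and `AnyTypeOrToken` and rejected by `AnyType`, while `i32` is accepted by `AnyType` and `AnyTypeOrToken` and rejected by `Token`. The model ignores *where* the rejection happens; in the real test, the `Token` case already fails at parse time because the operand type is fixed.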
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 56db6837b870c..eda75edbb69cd 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -110,6 +110,36 @@ def SignlessLikeVariadic : TEST_Op<"signless_like_variadic"> {
let arguments = (ins Variadic<SignlessIntegerLike>:$x);
}
+//===----------------------------------------------------------------------===//
+// Test Token Type
+//===----------------------------------------------------------------------===//
+
+// Produce a builtin `token` value.
+def TestTokenProduceOp : TEST_Op<"token.produce"> {
+ let results = (outs Token:$token);
+ let assemblyFormat = "attr-dict";
+}
+
+// Consume a builtin `token` value (token-only operand).
+def TestTokenConsumeOp : TEST_Op<"token.consume"> {
+ let arguments = (ins Token:$token);
+ let assemblyFormat = "$token attr-dict";
+}
+
+// Op that accepts any type, including a token. Uses the `AnyTypeOrToken`
+// opt-in predicate.
+def TestTokenAnyTypeOrTokenOp : TEST_Op<"token.any_or_token"> {
+ let arguments = (ins AnyTypeOrToken:$value);
+ let assemblyFormat = "$value attr-dict `:` type($value)";
+}
+
+// Op that uses the default `AnyType` predicate. Tokens are excluded by
+// default and should be rejected by the verifier when passed here.
+def TestTokenAnyTypeOp : TEST_Op<"token.any_type"> {
+ let arguments = (ins AnyType:$value);
+ let assemblyFormat = "$value attr-dict `:` type($value)";
+}
+
//===----------------------------------------------------------------------===//
// Test Symbols
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/mlir-tblgen/predicate.td b/mlir/test/mlir-tblgen/predicate.td
index 41e041f171213..ae436885b421f 100644
--- a/mlir/test/mlir-tblgen/predicate.td
+++ b/mlir/test/mlir-tblgen/predicate.td
@@ -27,9 +27,9 @@ def OpA : NS_Op<"op_for_CPred_containing_multiple_same_placeholder", []> {
// CHECK-NOT. << " must be 32-bit integer or floating-point type, but got " << type;
// CHECK: static ::llvm::LogicalResult [[$TENSOR_CONSTRAINT:__mlir_ods_local_type_constraint.*]](
-// CHECK: if (!(((::llvm::isa<::mlir::TensorType>(type))) && ([](::mlir::Type elementType) { return (true); }(::llvm::cast<::mlir::ShapedType>(type).getElementType())))) {
+// CHECK: if (!(((::llvm::isa<::mlir::TensorType>(type))) && ([](::mlir::Type elementType) { return !((::llvm::isa<::mlir::TokenType>(elementType))); }(::llvm::cast<::mlir::ShapedType>(type).getElementType())))) {
// CHECK-NEXT: return op->emitOpError(valueKind) << " #" << valueIndex
-// CHECK-NEXT: << " must be tensor of any type values, but got " << type;
+// CHECK-NEXT: << " must be tensor of any non-token type values, but got " << type;
// CHECK: static ::llvm::LogicalResult [[$TENSOR_INTEGER_FLOAT_CONSTRAINT:__mlir_ods_local_type_constraint.*]](
// CHECK: if (!(((::llvm::isa<::mlir::TensorType>(type))) && ([](::mlir::Type elementType) { return ((elementType.isF32())) || ((elementType.isSignlessInteger(32))); }(::llvm::cast<::mlir::ShapedType>(type).getElementType())))) {
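The updated CHECK line shows where the change surfaces for container constraints: the element-type lambda generated for `AnyType` used to `return (true);` and now returns `!isa<TokenType>(elementType)`. That is why every "... of any type values" diagnostic in this patch becomes "... of any non-token type values". A minimal C++ sketch of that composition follows. `ToyType`, `ElementConstraint`, `kAnyNonToken`, `isTensorOf`, and `tensorOfSummary` are hypothetical stand-ins (for `::mlir::Type`, the element constraint, the generated tensor check, and the diagnostic summary); type kinds are modeled as plain strings.

```cpp
#include <cassert>
#include <functional>
#include <string>

// Hypothetical toy type (not mlir::Type): `kind` is a container name such as
// "tensor" or "vector"; `element` is the element type name, e.g. "f32" or "token".
struct ToyType {
  std::string kind;
  std::string element;
};

// An element predicate plus the summary text used in diagnostics.
struct ElementConstraint {
  std::function<bool(const std::string &)> pred;
  std::string summary;
};

// Stand-in for the new `AnyType` used as an element constraint: reject tokens.
const ElementConstraint kAnyNonToken{
    [](const std::string &elementType) { return elementType != "token"; },
    "any non-token type"};

// Same shape as the generated check:
//   isa<TensorType>(type) && elementPred(cast<ShapedType>(type).getElementType())
bool isTensorOf(const ToyType &type, const ElementConstraint &elem) {
  return type.kind == "tensor" && elem.pred(type.element);
}

// Builds the "must be ..." fragment following the pattern of the messages above.
std::string tensorOfSummary(const ElementConstraint &elem) {
  return "tensor of " + elem.summary + " values";
}
```

With this model, `{"tensor", "f32"}` passes, while `{"tensor", "token"}` and `{"vector", "f32"}` fail, and `tensorOfSummary(kAnyNonToken)` yields the new message text. Because the element summary is spliced into the container summary, changing `AnyType` in one place updates every derived diagnostic, which accounts for the many one-line test updates in this patch.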
diff --git a/mlir/test/mlir-tblgen/types.mlir b/mlir/test/mlir-tblgen/types.mlir
index c2acce0903bf4..30aea48e3e369 100644
--- a/mlir/test/mlir-tblgen/types.mlir
+++ b/mlir/test/mlir-tblgen/types.mlir
@@ -204,7 +204,7 @@ func.func @ranked_tensor_success(%arg0: tensor<i8>, %arg1: tensor<1xi32>, %arg2:
// -----
func.func @ranked_tensor_success(%arg0: tensor<*xf32>) {
- // expected-error @+1 {{must be ranked tensor of any type values}}
+ // expected-error @+1 {{must be ranked tensor of any non-token type values}}
"test.ranked_tensor_op"(%arg0) : (tensor<*xf32>) -> ()
return
}
@@ -212,7 +212,7 @@ func.func @ranked_tensor_success(%arg0: tensor<*xf32>) {
// -----
func.func @ranked_tensor_success(%arg0: vector<2xf32>) {
- // expected-error @+1 {{must be ranked tensor of any type values}}
+ // expected-error @+1 {{must be ranked tensor of any non-token type values}}
"test.ranked_tensor_op"(%arg0) : (vector<2xf32>) -> ()
return
}
@@ -510,7 +510,7 @@ func.func @does_not_have_i32(%arg0: tensor<1x2xi32>, %arg1: none) {
// -----
func.func @does_not_have_static_memref(%arg0: memref<?xi32>) {
- // expected-error@+1 {{'test.takes_static_memref' op operand #0 must be statically shaped memref of any type values}}
+ // expected-error@+1 {{'test.takes_static_memref' op operand #0 must be statically shaped memref of any non-token type values}}
"test.takes_static_memref"(%arg0) : (memref<?xi32>) -> ()
}