[Mlir-commits] [mlir] [mlir][IR] Add builtin `TokenTypeInterface` (PR #195640)

Matthias Springer llvmlistbot at llvm.org
Tue May 5 07:48:05 PDT 2026


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/195640

>From 39761ab7c7efdc44d80e65b6c1ced02ef01de999 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Mon, 4 May 2026 12:14:41 +0000
Subject: [PATCH 1/2] [mlir][IR] Add builtin `TokenTypeInterface`

---
 mlir/docs/Dialects/Builtin.md                 | 10 +++
 mlir/docs/Tokens.md                           | 89 +++++++++++++++++++
 mlir/include/mlir/IR/BuiltinTypeInterfaces.td | 20 +++++
 mlir/include/mlir/IR/CommonTypeConstraints.td | 21 ++++-
 mlir/test/Dialect/ArmSME/invalid.mlir         |  4 +-
 mlir/test/Dialect/Linalg/invalid.mlir         |  4 +-
 mlir/test/Dialect/MemRef/invalid.mlir         |  4 +-
 mlir/test/Dialect/SparseTensor/invalid.mlir   | 24 ++---
 mlir/test/Dialect/Tensor/invalid.mlir         |  2 +-
 mlir/test/Dialect/Vector/invalid.mlir         | 10 +--
 mlir/test/Dialect/traits.mlir                 |  2 +-
 mlir/test/IR/operand.mlir                     |  6 +-
 mlir/test/IR/result.mlir                      |  6 +-
 mlir/test/IR/token-type-interface.mlir        | 59 ++++++++++++
 mlir/test/lib/Dialect/Test/TestOps.td         | 32 +++++++
 mlir/test/lib/Dialect/Test/TestTypeDefs.td    |  7 ++
 mlir/test/mlir-tblgen/predicate.td            |  4 +-
 mlir/test/mlir-tblgen/types.mlir              |  6 +-
 18 files changed, 272 insertions(+), 38 deletions(-)
 create mode 100644 mlir/docs/Tokens.md
 create mode 100644 mlir/test/IR/token-type-interface.mlir

diff --git a/mlir/docs/Dialects/Builtin.md b/mlir/docs/Dialects/Builtin.md
index 0a9b7ae8919b5..818d6ded486db 100644
--- a/mlir/docs/Dialects/Builtin.md
+++ b/mlir/docs/Dialects/Builtin.md
@@ -65,3 +65,13 @@ marked using one DistinctAttribute instance per alias group.
 ## Type Interfaces
 
 [include "Dialects/BuiltinTypeInterfaces.md"]
+
+## Token Types
+
+A *token type* is any type that implements the builtin `TokenTypeInterface`.
+Tokens are SSA values that exist purely to encode a *static* def–use
+relationship between operations or regions; they carry no runtime data and
+must not be value-forwarded.
+
+See the [Tokens design note](../Tokens.md) for the structural contract,
+ODS predicates (`AnyType` / `AnyTypeOrToken` / `Token`), and examples.
diff --git a/mlir/docs/Tokens.md b/mlir/docs/Tokens.md
new file mode 100644
index 0000000000000..81bcc1203b6aa
--- /dev/null
+++ b/mlir/docs/Tokens.md
@@ -0,0 +1,89 @@
+# Tokens
+
+[TOC]
+
+## Overview
+
+Intuitively, a *token* value is a pointer to an operation (via an OpResult)
+or a pointer to a region (via an entry block argument).
+
+More precisely, a token is an SSA value whose purpose is to encode a
+**static** def–use relationship between operations or regions. It carries
+no runtime data and must not flow through "regular" value-forwarding
+constructs, so its provenance can always be traced back to the operation
+or region that defines it.
+
+In MLIR, "token" is not a single concrete builtin type. Instead, any type
+that implements the builtin `TokenTypeInterface` is treated as a token by
+the framework. Dialects can define their own dialect-specific token types.
+
+## `TokenTypeInterface`
+
+`TokenTypeInterface` is a parameterless, methodless marker type interface.
+
+A type opts in by attaching the interface in TableGen:
+
+```tablegen
+def MyDialect_Token : TypeDef<MyDialect, "MyToken", [TokenTypeInterface]> {
+  let mnemonic = "token";
+}
+```
+
+## Structural Contract
+
+A token value is, by construction:
+
+1. **Not value-forwarded.** A token must not appear in any position where
+   a value is forwarded, for example as:
+    * a forwarded result/operand of a `CallOpInterface` op,
+    * an argument or result type of a `FunctionOpInterface` op (a token
+      block argument *inside* a function body is fine — what is disallowed
+      is forwarding tokens across the call/return boundary),
+    * a successor operand or successor block argument of a
+      `BranchOpInterface` op,
+    * a forwarded operand to/from any region of a `RegionBranchOpInterface`
+      op (iter-args, region results, yielded values), or
+    * the result of any op that selects or merges values it does not
+      understand (e.g. `arith.select`).
+
+2. **Statically resolvable.** Walking the def–use chain from any token use
+   reaches a producing op without crossing a forwarding boundary.
+
+3. **Not constant-foldable.** No constant of token type exists.
+
+These properties mirror what LLVM IR already documents for its own
+[`token` type](https://llvm.org/docs/LangRef.html#token-type).
+
+## ODS Integration
+
+Tokens are excluded from the default `AnyType` predicate, so an op that has
+not opted in cannot accept a token as an arbitrary operand or result.
+Three predicates are provided in `CommonTypeConstraints.td`:
+
+| Predicate        | Accepts                                      | Use when …                                                            |
+| ---------------- | -------------------------------------------- | --------------------------------------------------------------------- |
+| `AnyType`        | any non-token type                           | the default; matches the historical meaning of "any type" pre-tokens. |
+| `AnyTypeOrToken` | any type, including tokens                   | the op legitimately accepts arbitrary types (including tokens).       |
+| `Token`          | only types implementing `TokenTypeInterface` | the op specifically takes a token operand/result.                     |
+
+
+## Examples
+
+### Rejected: tokens in `AnyType` positions
+
+```mlir
+// error: 'scf.if' op result #0 must be variadic of any non-token type,
+//        but got '!my.token'
+%t = scf.if %cond -> !my.token {
+  %a = my.token.produce : !my.token
+  scf.yield %a : !my.token
+} else {
+  %b = my.token.produce : !my.token
+  scf.yield %b : !my.token
+}
+```
+
+`scf.if`'s results are declared with `Variadic<AnyType>`, and `scf.yield`'s
+operands likewise use `AnyType`. Because `AnyType` excludes tokens by
+default, yielding (or returning) a token through an `scf.if` (or any other
+op that has not explicitly opted in via `AnyTypeOrToken`) is rejected.
diff --git a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
index 93c8c0694b467..a50d73e21a69a 100644
--- a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
+++ b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
@@ -236,6 +236,26 @@ def PtrLikeTypeInterface : TypeInterface<"PtrLikeTypeInterface"> {
   ];
 }
 
+//===----------------------------------------------------------------------===//
+// TokenTypeInterface
+//===----------------------------------------------------------------------===//
+
+def TokenTypeInterface : TypeInterface<"TokenTypeInterface"> {
+  let cppNamespace = "::mlir";
+  let description = [{
+    Intuitively, a *token* value is a pointer to an operation (via an OpResult)
+    or a pointer to a region (via an entry block argument).
+
+    More precisely, a token is an SSA value whose purpose is to encode a
+    static def–use relationship between operations or regions. It carries
+    no runtime data and must not flow through "regular" value-forwarding
+    constructs, so its provenance can always be traced back to the operation
+    or region that defines it.
+
+    This interface is a marker. It has no interface methods.
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // ShapedType
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index 57caaae08462f..7898937ca01f5 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -165,8 +165,25 @@ class SameBuildabilityAs<Type type, code builder> {
   code builderCall = !if(!empty(type.builderCall), "", builder);
 }
 
-// Any type at all.
-def AnyType : Type<CPred<"true">, "any type">;
+// Whether a type is a token (i.e. implements TokenTypeInterface).
+def IsTokenTypePred
+    : CPred<"::llvm::isa<::mlir::TokenTypeInterface>($_self)">;
+
+// Any non-token type. Tokens are excluded by default to prevent ops that
+// accept arbitrary types from accidentally accepting tokens as operands /
+// results, since a token must not be value-forwarded. Ops that legitimately
+// want to accept any type, including tokens, should use `AnyTypeOrToken`
+// instead.
+def AnyType : Type<Neg<IsTokenTypePred>, "any non-token type">;
+
+// Any type at all, including tokens. Used by ops that explicitly opt in to
+// accepting tokens (e.g. ops in interfaces such as `CallOpInterface`,
+// `BranchOpInterface`, etc. that legitimately handle arbitrary types).
+def AnyTypeOrToken : Type<CPred<"true">, "any type">;
+
+// A token type (any type implementing `TokenTypeInterface`).
+def Token : Type<IsTokenTypePred, "token",
+                 "::mlir::TokenTypeInterface">;
 
 // None type
 def NoneType : Type<CPred<"::llvm::isa<::mlir::NoneType>($_self)">, "none type",
diff --git a/mlir/test/Dialect/ArmSME/invalid.mlir b/mlir/test/Dialect/ArmSME/invalid.mlir
index 8c5a098a0c785..f00945e18cc1f 100644
--- a/mlir/test/Dialect/ArmSME/invalid.mlir
+++ b/mlir/test/Dialect/ArmSME/invalid.mlir
@@ -132,7 +132,7 @@ func.func @arm_sme_tile_load__pad_but_no_mask(%src : memref<?x?xf64>, %pad : f64
 
 func.func @arm_sme_tile_load__bad_memref_rank(%src : memref<?xf64>, %pad : f64) {
   %c0 = arith.constant 0 : index
-  // expected-error@+1 {{op operand #0 must be 2D memref of any type values, but got 'memref<?xf64>'}}
+  // expected-error@+1 {{op operand #0 must be 2D memref of any non-token type values, but got 'memref<?xf64>'}}
   %tile = arm_sme.tile_load %src[%c0], %pad, : memref<?xf64>, vector<[2]x[2]xf64>
   return
 }
@@ -186,7 +186,7 @@ func.func @arm_sme_tile_store__bad_mask_type(%tile : vector<[16]x[16]xi8>, %mask
 
 func.func @arm_sme_tile_store__bad_memref_rank(%tile : vector<[16]x[16]xi8>, %dest : memref<?xi8>) {
   %c0 = arith.constant 0 : index
-  // expected-error@+1 {{op operand #1 must be 2D memref of any type values, but got 'memref<?xi8>'}}
+  // expected-error@+1 {{op operand #1 must be 2D memref of any non-token type values, but got 'memref<?xi8>'}}
   arm_sme.tile_store %tile, %dest[%c0] : memref<?xi8>, vector<[16]x[16]xi8>
   return
 }
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 06f3fcb41190b..a446cfcc4eec1 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -415,7 +415,7 @@ func.func @illegal_fill_memref_with_tensor_return
 func.func @illegal_fill_tensor_with_memref_return
   (%arg0 : tensor<?x?xf32>, %arg1 : f32) -> memref<?x?xf32>
 {
-  // expected-error @+1 {{result #0 must be variadic of ranked tensor of any type values, but got 'memref<?x?xf32>'}}
+  // expected-error @+1 {{result #0 must be variadic of ranked tensor of any non-token type values, but got 'memref<?x?xf32>'}}
   %0 = linalg.fill ins(%arg1 : f32) outs(%arg0 : tensor<?x?xf32>) -> memref<?x?xf32>
   return %0 : memref<?x?xf32>
 }
@@ -468,7 +468,7 @@ func.func @invalid_scalar_input_matmul(%arg0: f32, %arg1: memref<3x4xf32>, %arg2
 // -----
 
 func.func @invalid_scalar_output_matmul(%arg0: memref<2x3xf32>, %arg1: memref<3x4xf32>, %arg2: f32) {
-  // expected-error @+1 {{'linalg.matmul' op operand #2 must be variadic of shaped of any type values, but got 'f32'}}
+  // expected-error @+1 {{'linalg.matmul' op operand #2 must be variadic of shaped of any non-token type values, but got 'f32'}}
   linalg.matmul ins(%arg0, %arg1 : memref<2x3xf32>, memref<3x4xf32>)
                 outs(%arg2 : f32)
   return
diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index 2f061a1bb773e..ecffd683a98c2 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -1037,7 +1037,7 @@ func.func @test_alloc_memref_map_rank_mismatch() {
 // -----
 
 func.func @rank(%0: f32) {
-  // expected-error@+1 {{'memref.rank' op operand #0 must be ranked or unranked memref of any type values}}
+  // expected-error@+1 {{'memref.rank' op operand #0 must be ranked or unranked memref of any non-token type values}}
   "memref.rank"(%0): (f32)->index
   return
 }
@@ -1172,7 +1172,7 @@ func.func @memref_realloc_type(%src : memref<256xf32>) -> memref<?xi32>{
 
 // Asking the dimension of a 0-D shape doesn't make sense.
 func.func @dim_0_ranked(%arg : memref<f32>, %arg1 : index) {
-  memref.dim %arg, %arg1 : memref<f32> // expected-error {{'memref.dim' op operand #0 must be unranked.memref of any type values or non-0-ranked.memref of any type values, but got 'memref<f32>'}}
+  memref.dim %arg, %arg1 : memref<f32> // expected-error {{'memref.dim' op operand #0 must be unranked.memref of any non-token type values or non-0-ranked.memref of any non-token type values, but got 'memref<f32>'}}
   return
 }
 
diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
index ae706b9b148a6..d14229b011f11 100644
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -1,7 +1,7 @@
 // RUN: mlir-opt %s -split-input-file -verify-diagnostics
 
 func.func @invalid_new_dense(%arg0: !llvm.ptr) -> tensor<32xf32> {
-  // expected-error@+1 {{'sparse_tensor.new' op result #0 must be sparse tensor of any type values, but got 'tensor<32xf32>'}}
+  // expected-error@+1 {{'sparse_tensor.new' op result #0 must be sparse tensor of any non-token type values, but got 'tensor<32xf32>'}}
   %0 = sparse_tensor.new %arg0 : !llvm.ptr to tensor<32xf32>
   return %0 : tensor<32xf32>
 }
@@ -96,7 +96,7 @@ func.func @invalid_unpack_mis_position(%sp: tensor<2x100xf64, #CSR>, %values: te
 // -----
 
 func.func @invalid_positions_dense(%arg0: tensor<128xf64>) -> memref<?xindex> {
-  // expected-error@+1 {{'sparse_tensor.positions' op operand #0 must be sparse tensor of any type values, but got 'tensor<128xf64>'}}
+  // expected-error@+1 {{'sparse_tensor.positions' op operand #0 must be sparse tensor of any non-token type values, but got 'tensor<128xf64>'}}
   %0 = sparse_tensor.positions %arg0 { level = 0 : index } : tensor<128xf64> to memref<?xindex>
   return %0 : memref<?xindex>
 }
@@ -104,7 +104,7 @@ func.func @invalid_positions_dense(%arg0: tensor<128xf64>) -> memref<?xindex> {
 // -----
 
 func.func @invalid_positions_unranked(%arg0: tensor<*xf64>) -> memref<?xindex> {
-  // expected-error@+1 {{'sparse_tensor.positions' op operand #0 must be sparse tensor of any type values, but got 'tensor<*xf64>'}}
+  // expected-error@+1 {{'sparse_tensor.positions' op operand #0 must be sparse tensor of any non-token type values, but got 'tensor<*xf64>'}}
   %0 = "sparse_tensor.positions"(%arg0) { level = 0 : index } : (tensor<*xf64>) -> (memref<?xindex>)
   return %0 : memref<?xindex>
 }
@@ -132,7 +132,7 @@ func.func @positions_oob(%arg0: tensor<128xf64, #SparseVector>) -> memref<?xinde
 // -----
 
 func.func @invalid_indices_dense(%arg0: tensor<10x10xi32>) -> memref<?xindex> {
-  // expected-error@+1 {{'sparse_tensor.coordinates' op operand #0 must be sparse tensor of any type values, but got 'tensor<10x10xi32>'}}
+  // expected-error@+1 {{'sparse_tensor.coordinates' op operand #0 must be sparse tensor of any non-token type values, but got 'tensor<10x10xi32>'}}
   %0 = sparse_tensor.coordinates %arg0 { level = 1 : index } : tensor<10x10xi32> to memref<?xindex>
   return %0 : memref<?xindex>
 }
@@ -140,7 +140,7 @@ func.func @invalid_indices_dense(%arg0: tensor<10x10xi32>) -> memref<?xindex> {
 // -----
 
 func.func @invalid_indices_unranked(%arg0: tensor<*xf64>) -> memref<?xindex> {
-  // expected-error@+1 {{'sparse_tensor.coordinates' op operand #0 must be sparse tensor of any type values, but got 'tensor<*xf64>'}}
+  // expected-error@+1 {{'sparse_tensor.coordinates' op operand #0 must be sparse tensor of any non-token type values, but got 'tensor<*xf64>'}}
   %0 = "sparse_tensor.coordinates"(%arg0) { level = 0 : index } : (tensor<*xf64>) -> (memref<?xindex>)
   return %0 : memref<?xindex>
 }
@@ -168,7 +168,7 @@ func.func @indices_oob(%arg0: tensor<128xf64, #SparseVector>) -> memref<?xindex>
 // -----
 
 func.func @invalid_values_dense(%arg0: tensor<1024xf32>) -> memref<?xf32> {
-  // expected-error@+1 {{'sparse_tensor.values' op operand #0 must be sparse tensor of any type values, but got 'tensor<1024xf32>'}}
+  // expected-error@+1 {{'sparse_tensor.values' op operand #0 must be sparse tensor of any non-token type values, but got 'tensor<1024xf32>'}}
   %0 = sparse_tensor.values %arg0 : tensor<1024xf32> to memref<?xf32>
   return %0 : memref<?xf32>
 }
@@ -186,7 +186,7 @@ func.func @indices_buffer_noncoo(%arg0: tensor<128xf64, #SparseVector>) -> memre
 // -----
 
 func.func @indices_buffer_dense(%arg0: tensor<1024xf32>) -> memref<?xindex> {
-  // expected-error@+1 {{must be sparse tensor of any type values}}
+  // expected-error@+1 {{must be sparse tensor of any non-token type values}}
   %0 = sparse_tensor.coordinates_buffer %arg0 : tensor<1024xf32> to memref<?xindex>
   return %0 : memref<?xindex>
 }
@@ -283,7 +283,7 @@ func.func @sparse_get_md(%arg0: !sparse_tensor.storage_specifier<#COO>) -> index
 // -----
 
 func.func @sparse_unannotated_load(%arg0: tensor<16x32xf64>) -> tensor<16x32xf64> {
-  // expected-error@+1 {{'sparse_tensor.load' op operand #0 must be sparse tensor of any type values, but got 'tensor<16x32xf64>'}}
+  // expected-error@+1 {{'sparse_tensor.load' op operand #0 must be sparse tensor of any non-token type values, but got 'tensor<16x32xf64>'}}
   %0 = sparse_tensor.load %arg0 : tensor<16x32xf64>
   return %0 : tensor<16x32xf64>
 }
@@ -308,7 +308,7 @@ func.func @sparse_push_back_n(%arg0: index, %arg1: memref<?xf32>, %arg2: f32) ->
 // -----
 
 func.func @sparse_unannotated_expansion(%arg0: tensor<128xf64>) {
-  // expected-error@+1 {{'sparse_tensor.expand' op operand #0 must be sparse tensor of any type values, but got 'tensor<128xf64>'}}
+  // expected-error@+1 {{'sparse_tensor.expand' op operand #0 must be sparse tensor of any non-token type values, but got 'tensor<128xf64>'}}
   %values, %filled, %added, %count = sparse_tensor.expand %arg0
     : tensor<128xf64> to memref<?xf64>, memref<?xi1>, memref<?xindex>
   return
@@ -322,7 +322,7 @@ func.func @sparse_unannotated_compression(%arg0: memref<?xf64>,
                                           %arg3: index,
                                           %arg4: tensor<8x8xf64>,
                                           %arg5: index) {
-  // expected-error@+1 {{'sparse_tensor.compress' op operand #4 must be sparse tensor of any type values, but got 'tensor<8x8xf64>'}}
+  // expected-error@+1 {{'sparse_tensor.compress' op operand #4 must be sparse tensor of any non-token type values, but got 'tensor<8x8xf64>'}}
   sparse_tensor.compress %arg0, %arg1, %arg2, %arg3 into %arg4[%arg5]
     : memref<?xf64>, memref<?xi1>, memref<?xindex>, tensor<8x8xf64>
   return
@@ -375,7 +375,7 @@ func.func @sparse_convert_dim_mismatch(%arg0: tensor<10x?xf32>) -> tensor<10x10x
 // -----
 
 func.func @invalid_out_dense(%arg0: tensor<10xf64>, %arg1: !llvm.ptr) {
-  // expected-error@+1 {{'sparse_tensor.out' op operand #0 must be sparse tensor of any type values, but got 'tensor<10xf64>'}}
+  // expected-error@+1 {{'sparse_tensor.out' op operand #0 must be sparse tensor of any non-token type values, but got 'tensor<10xf64>'}}
   sparse_tensor.out %arg0, %arg1 : tensor<10xf64>, !llvm.ptr
   return
 }
@@ -1022,7 +1022,7 @@ func.func @sparse_reinterpret_map(%t0 : tensor<6x12xi32, #BSR>) -> tensor<3x4x2x
 #CSR = #sparse_tensor.encoding<{map = (d0, d1) -> (d0 : compressed, d1 : compressed)}>
 
 func.func @sparse_print(%arg0: tensor<10x10xf64>) {
-  // expected-error@+1 {{'sparse_tensor.print' op operand #0 must be sparse tensor of any type values}}
+  // expected-error@+1 {{'sparse_tensor.print' op operand #0 must be sparse tensor of any non-token type values}}
   sparse_tensor.print %arg0 : tensor<10x10xf64>
   return
 }
diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir
index 6ee2f9911663f..a526d7ed61722 100644
--- a/mlir/test/Dialect/Tensor/invalid.mlir
+++ b/mlir/test/Dialect/Tensor/invalid.mlir
@@ -404,7 +404,7 @@ func.func @illegal_collapsing_reshape_mixed_tensor_2(%arg0 : tensor<?x4x5xf32>)
 // -----
 
 func.func @rank(%0: f32) {
-  // expected-error@+1 {{'tensor.rank' op operand #0 must be tensor of any type values}}
+  // expected-error@+1 {{'tensor.rank' op operand #0 must be tensor of any non-token type values}}
   "tensor.rank"(%0): (f32)->index
   return
 }
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index f90312c915334..36c697c78d93d 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -106,7 +106,7 @@ func.func @shuffle_index_out_of_range(%arg0: vector<2xf32>, %arg1: vector<2xf32>
 // -----
 
 func.func @shuffle_scalable_vec(%arg0: vector<[2]xf32>, %arg1: vector<[2]xf32>) {
-  // expected-error@+1 {{'vector.shuffle' op operand #0 must be fixed-length vector of any type values}}
+  // expected-error@+1 {{'vector.shuffle' op operand #0 must be fixed-length vector of any non-token type values}}
   %1 = vector.shuffle %arg0, %arg1 [0, 1, 2, 3] : vector<[2]xf32>, vector<[2]xf32>
 }
 
@@ -1460,7 +1460,7 @@ func.func @maskedstore_memref_mismatch(%base: memref<?xf32>, %mask: vector<16xi1
 func.func @gather_from_vector(%base: vector<16xf32>, %indices: vector<16xi32>,
                                 %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
   %c0 = arith.constant 0 : index
-  // expected-error@+1 {{'vector.gather' op operand #0 must be Tensor or MemRef of any type values, but got 'vector<16xf32>'}}
+  // expected-error@+1 {{'vector.gather' op operand #0 must be Tensor or MemRef of any non-token type values, but got 'vector<16xf32>'}}
   %0 = vector.gather %base[%c0][%indices], %mask, %pass_thru
     : vector<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
 }
@@ -1557,7 +1557,7 @@ func.func @gather_tensor_alignment(%base: tensor<16xf32>, %indices: vector<16xi3
 func.func @scatter_to_vector(%base: vector<16xf32>, %indices: vector<16xi32>,
                              %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
   %c0 = arith.constant 0 : index
-  // expected-error@+1 {{'vector.scatter' op operand #0 must be Tensor or MemRef of any type values, but got 'vector<16xf32>'}}
+  // expected-error@+1 {{'vector.scatter' op operand #0 must be Tensor or MemRef of any non-token type values, but got 'vector<16xf32>'}}
   vector.scatter %base[%c0][%indices], %mask, %pass_thru
     : vector<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
 }
@@ -1943,7 +1943,7 @@ func.func @invalid_outerproduct1(%src : memref<?xf32>, %lhs : vector<[4]x[4]xf32
 // -----
 
 func.func @deinterleave_zero_dim_fail(%vec : vector<f32>) {
-  // expected-error @+1 {{'vector.deinterleave' op operand #0 must be vector of any type values, but got 'vector<f32>}}
+  // expected-error @+1 {{'vector.deinterleave' op operand #0 must be vector of any non-token type values, but got 'vector<f32>}}
   %0, %1 = vector.deinterleave %vec : vector<f32> -> vector<f32>
   return
 }
@@ -2032,7 +2032,7 @@ func.func @from_elements_wrong_operand_type(%a: f32, %b: i32) {
 // -----
 
 func.func @invalid_from_elements_scalable(%a: f32, %b: i32) {
-  // expected-error @+1 {{'dest' must be fixed-length vector of any type values, but got 'vector<[2]xf32>'}}
+  // expected-error @+1 {{'dest' must be fixed-length vector of any non-token type values, but got 'vector<[2]xf32>'}}
   vector.from_elements %a, %b : vector<[2]xf32>
   return
 }
diff --git a/mlir/test/Dialect/traits.mlir b/mlir/test/Dialect/traits.mlir
index 4d583435adeee..ae48cadbf370f 100644
--- a/mlir/test/Dialect/traits.mlir
+++ b/mlir/test/Dialect/traits.mlir
@@ -58,7 +58,7 @@ func.func @broadcast_tensor_tensor_tensor(tensor<8x1x?x1xi32>, tensor<7x1x5xi32>
 // Check incompatible vector and tensor result type
 func.func @broadcast_scalar_vector_vector(tensor<4xf32>, tensor<4xf32>) -> vector<4xf32> {
 ^bb0(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>):
-  // expected-error @+1 {{op result #0 must be tensor of any type values, but got 'vector<4xf32>'}}
+  // expected-error @+1 {{op result #0 must be tensor of any non-token type values, but got 'vector<4xf32>'}}
   %0 = "test.broadcastable"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> vector<4xf32>
   return %0 : vector<4xf32>
 }
diff --git a/mlir/test/IR/operand.mlir b/mlir/test/IR/operand.mlir
index 507e37c775c0b..1ac12dc4b9556 100644
--- a/mlir/test/IR/operand.mlir
+++ b/mlir/test/IR/operand.mlir
@@ -13,7 +13,7 @@ func.func @correct_variadic_operand(%arg0: tensor<f32>, %arg1: f32) {
 // -----
 
 func.func @error_in_first_variadic_operand(%arg0: tensor<f32>, %arg1: f32) {
-  // expected-error @+1 {{operand #1 must be variadic of tensor of any type}}
+  // expected-error @+1 {{operand #1 must be variadic of tensor of any non-token type}}
   "test.mixed_normal_variadic_operand"(%arg0, %arg1, %arg0, %arg0, %arg0) : (tensor<f32>, f32, tensor<f32>, tensor<f32>, tensor<f32>) -> ()
   return
 }
@@ -21,7 +21,7 @@ func.func @error_in_first_variadic_operand(%arg0: tensor<f32>, %arg1: f32) {
 // -----
 
 func.func @error_in_normal_operand(%arg0: tensor<f32>, %arg1: f32) {
-  // expected-error @+1 {{operand #2 must be tensor of any type}}
+  // expected-error @+1 {{operand #2 must be tensor of any non-token type}}
   "test.mixed_normal_variadic_operand"(%arg0, %arg0, %arg1, %arg0, %arg0) : (tensor<f32>, tensor<f32>, f32, tensor<f32>, tensor<f32>) -> ()
   return
 }
@@ -29,7 +29,7 @@ func.func @error_in_normal_operand(%arg0: tensor<f32>, %arg1: f32) {
 // -----
 
 func.func @error_in_second_variadic_operand(%arg0: tensor<f32>, %arg1: f32) {
-  // expected-error @+1 {{operand #3 must be variadic of tensor of any type}}
+  // expected-error @+1 {{operand #3 must be variadic of tensor of any non-token type}}
   "test.mixed_normal_variadic_operand"(%arg0, %arg0, %arg0, %arg1, %arg0) : (tensor<f32>, tensor<f32>, tensor<f32>, f32, tensor<f32>) -> ()
   return
 }
diff --git a/mlir/test/IR/result.mlir b/mlir/test/IR/result.mlir
index 1e4eb3bede4c5..cdeae4202f0ff 100644
--- a/mlir/test/IR/result.mlir
+++ b/mlir/test/IR/result.mlir
@@ -13,7 +13,7 @@ func.func @correct_variadic_result() -> tensor<f32> {
 // -----
 
 func.func @error_in_first_variadic_result() -> tensor<f32> {
-  // expected-error @+1 {{result #1 must be variadic of tensor of any type}}
+  // expected-error @+1 {{result #1 must be variadic of tensor of any non-token type}}
   %0:5 = "test.mixed_normal_variadic_result"() : () -> (tensor<f32>, f32, tensor<f32>, tensor<f32>, tensor<f32>)
   return %0#4 : tensor<f32>
 }
@@ -21,7 +21,7 @@ func.func @error_in_first_variadic_result() -> tensor<f32> {
 // -----
 
 func.func @error_in_normal_result() -> tensor<f32> {
-  // expected-error @+1 {{result #2 must be tensor of any type}}
+  // expected-error @+1 {{result #2 must be tensor of any non-token type}}
   %0:5 = "test.mixed_normal_variadic_result"() : () -> (tensor<f32>, tensor<f32>, f32, tensor<f32>, tensor<f32>)
   return %0#4 : tensor<f32>
 }
@@ -29,7 +29,7 @@ func.func @error_in_normal_result() -> tensor<f32> {
 // -----
 
 func.func @error_in_second_variadic_result() -> tensor<f32> {
-  // expected-error @+1 {{result #3 must be variadic of tensor of any type}}
+  // expected-error @+1 {{result #3 must be variadic of tensor of any non-token type}}
   %0:5 = "test.mixed_normal_variadic_result"() : () -> (tensor<f32>, tensor<f32>, tensor<f32>, f32, tensor<f32>)
   return %0#4 : tensor<f32>
 }
diff --git a/mlir/test/IR/token-type-interface.mlir b/mlir/test/IR/token-type-interface.mlir
new file mode 100644
index 0000000000000..c0632bdcf616b
--- /dev/null
+++ b/mlir/test/IR/token-type-interface.mlir
@@ -0,0 +1,59 @@
+// RUN: mlir-opt %s -verify-diagnostics -split-input-file | FileCheck %s
+
+// Tests for the builtin `TokenTypeInterface` and the
+// `Token` / `AnyType` / `AnyTypeOrToken` ODS predicates.
+//
+// `!test.test_token` is a test-dialect type that implements
+// `TokenTypeInterface`. The default `AnyType` predicate excludes tokens, while
+// `AnyTypeOrToken` and `Token` accept them.
+
+// CHECK-LABEL: @token_produce_consume
+func.func @token_produce_consume() {
+  // CHECK: %[[T:.*]] = test.token.produce : !test.test_token
+  %t = test.token.produce : !test.test_token
+  // CHECK: test.token.consume %[[T]] : !test.test_token
+  test.token.consume %t : !test.test_token
+  // CHECK: test.token.any_or_token %[[T]] : !test.test_token
+  test.token.any_or_token %t : !test.test_token
+  return
+}
+
+// -----
+
+// `AnyTypeOrToken` also accepts non-token types.
+// CHECK-LABEL: @any_or_token_with_non_token
+func.func @any_or_token_with_non_token(%arg0: i32) {
+  // CHECK: test.token.any_or_token %{{.*}} : i32
+  test.token.any_or_token %arg0 : i32
+  return
+}
+
+// -----
+
+// `AnyType` accepts arbitrary non-token types.
+// CHECK-LABEL: @any_type_with_non_token
+func.func @any_type_with_non_token(%arg0: i32) {
+  // CHECK: test.token.any_type %{{.*}} : i32
+  test.token.any_type %arg0 : i32
+  return
+}
+
+// -----
+
+// `AnyType` rejects tokens by default.
+func.func @any_type_rejects_token() {
+  %t = test.token.produce : !test.test_token
+  // expected-error @below {{operand #0 must be any non-token type}}
+  test.token.any_type %t : !test.test_token
+  return
+}
+
+// -----
+
+// `Token` rejects non-token types. The operand's cppType is
+// `TokenTypeInterface`, so type resolution fails at parse time.
+func.func @token_rejects_non_token(%arg0: i32) {
+  // expected-error @below {{invalid kind of type specified}}
+  test.token.consume %arg0 : i32
+  return
+}
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 348ff5d7f4ea0..529fc4b860ad4 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -110,6 +110,38 @@ def SignlessLikeVariadic : TEST_Op<"signless_like_variadic"> {
   let arguments = (ins Variadic<SignlessIntegerLike>:$x);
 }
 
+//===----------------------------------------------------------------------===//
+// Test Token Type
+//===----------------------------------------------------------------------===//
+
+// Produce a token. Demonstrates a type that implements the builtin
+// `TokenTypeInterface`.
+def TestTokenProduceOp : TEST_Op<"token.produce"> {
+  let results = (outs TestTokenType:$token);
+  let assemblyFormat = "attr-dict `:` type($token)";
+}
+
+// Consume a token (token-only operand). Uses the `Token` ODS predicate which
+// only accepts types implementing `TokenTypeInterface`.
+def TestTokenConsumeOp : TEST_Op<"token.consume"> {
+  let arguments = (ins Token:$token);
+  let assemblyFormat = "$token attr-dict `:` type($token)";
+}
+
+// Op that accepts any type, including a token. Uses the `AnyTypeOrToken`
+// opt-in predicate.
+def TestTokenAnyTypeOrTokenOp : TEST_Op<"token.any_or_token"> {
+  let arguments = (ins AnyTypeOrToken:$value);
+  let assemblyFormat = "$value attr-dict `:` type($value)";
+}
+
+// Op that uses the default `AnyType` predicate. Tokens are excluded by
+// default and should be rejected by the verifier when passed here.
+def TestTokenAnyTypeOp : TEST_Op<"token.any_type"> {
+  let arguments = (ins AnyType:$value);
+  let assemblyFormat = "$value attr-dict `:` type($value)";
+}
+
 //===----------------------------------------------------------------------===//
 // Test Symbols
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
index 08600ce713a17..5df0eab829a03 100644
--- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
@@ -173,6 +173,13 @@ def TestMemRefElementType : Test_Type<"TestMemRefElementType",
   let mnemonic = "memref_element";
 }
 
+// A test token type implementing the builtin `TokenTypeInterface`. Used to
+// exercise the default exclusion of tokens from `AnyType` and the explicit
+// `Token` / `AnyTypeOrToken` opt-ins.
+def TestTokenType : Test_Type<"TestToken", [TokenTypeInterface]> {
+  let mnemonic = "test_token";
+}
+
 def TestTypeTrait : NativeTypeTrait<"TestTypeTrait">;
 
 // The definition of a singleton type that has a trait.
diff --git a/mlir/test/mlir-tblgen/predicate.td b/mlir/test/mlir-tblgen/predicate.td
index 41e041f171213..07a5e6f2f261c 100644
--- a/mlir/test/mlir-tblgen/predicate.td
+++ b/mlir/test/mlir-tblgen/predicate.td
@@ -27,9 +27,9 @@ def OpA : NS_Op<"op_for_CPred_containing_multiple_same_placeholder", []> {
 // CHECK-NOT.        << " must be 32-bit integer or floating-point type, but got " << type;
 
 // CHECK: static ::llvm::LogicalResult [[$TENSOR_CONSTRAINT:__mlir_ods_local_type_constraint.*]](
-// CHECK:       if (!(((::llvm::isa<::mlir::TensorType>(type))) && ([](::mlir::Type elementType) { return (true); }(::llvm::cast<::mlir::ShapedType>(type).getElementType())))) {
+// CHECK:       if (!(((::llvm::isa<::mlir::TensorType>(type))) && ([](::mlir::Type elementType) { return !((::llvm::isa<::mlir::TokenTypeInterface>(elementType))); }(::llvm::cast<::mlir::ShapedType>(type).getElementType())))) {
 // CHECK-NEXT:    return op->emitOpError(valueKind) << " #" << valueIndex
-// CHECK-NEXT:        << " must be tensor of any type values, but got " << type;
+// CHECK-NEXT:        << " must be tensor of any non-token type values, but got " << type;
 
 // CHECK: static ::llvm::LogicalResult [[$TENSOR_INTEGER_FLOAT_CONSTRAINT:__mlir_ods_local_type_constraint.*]](
 // CHECK:       if (!(((::llvm::isa<::mlir::TensorType>(type))) && ([](::mlir::Type elementType) { return ((elementType.isF32())) || ((elementType.isSignlessInteger(32))); }(::llvm::cast<::mlir::ShapedType>(type).getElementType())))) {
diff --git a/mlir/test/mlir-tblgen/types.mlir b/mlir/test/mlir-tblgen/types.mlir
index c2acce0903bf4..30aea48e3e369 100644
--- a/mlir/test/mlir-tblgen/types.mlir
+++ b/mlir/test/mlir-tblgen/types.mlir
@@ -204,7 +204,7 @@ func.func @ranked_tensor_success(%arg0: tensor<i8>, %arg1: tensor<1xi32>, %arg2:
 // -----
 
 func.func @ranked_tensor_success(%arg0: tensor<*xf32>) {
-  // expected-error @+1 {{must be ranked tensor of any type values}}
+  // expected-error @+1 {{must be ranked tensor of any non-token type values}}
   "test.ranked_tensor_op"(%arg0) : (tensor<*xf32>) -> ()
   return
 }
@@ -212,7 +212,7 @@ func.func @ranked_tensor_success(%arg0: tensor<*xf32>) {
 // -----
 
 func.func @ranked_tensor_success(%arg0: vector<2xf32>) {
-  // expected-error @+1 {{must be ranked tensor of any type values}}
+  // expected-error @+1 {{must be ranked tensor of any non-token type values}}
   "test.ranked_tensor_op"(%arg0) : (vector<2xf32>) -> ()
   return
 }
@@ -510,7 +510,7 @@ func.func @does_not_have_i32(%arg0: tensor<1x2xi32>, %arg1: none) {
 // -----
 
 func.func @does_not_have_static_memref(%arg0: memref<?xi32>) {
-  // expected-error@+1 {{'test.takes_static_memref' op operand #0 must be statically shaped memref of any type values}}
+  // expected-error@+1 {{'test.takes_static_memref' op operand #0 must be statically shaped memref of any non-token type values}}
   "test.takes_static_memref"(%arg0) : (memref<?xi32>) -> ()
 }
 

>From 4b85bedee3cf5ddcb4f78b6b4792702dac12f9ef Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Tue, 5 May 2026 14:47:30 +0000
Subject: [PATCH 2/2] type instead of type interface

---
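Note on the `LLVMTypeSyntax.cpp` change in this patch: once `token` is a
builtin type keyword, the nested LLVM type parser has to try the LLVM
short-hand keywords before falling back to generic MLIR type parsing, or
`token` inside an `!llvm.<...>` type would resolve to the builtin type. The
following is only a toy model of that ordering. All names in it (`Kind`,
`parseGeneric`, `parseLLVMShorthand`, `parseNested`) are hypothetical and are
not MLIR APIs, and it leaves out the `allowAny` handling of the real
`dispatchParse`.

```cpp
#include <string>

// Toy model only: none of these names exist in MLIR.
enum class Kind { LLVMToken, LLVMVoid, BuiltinToken, BuiltinI32, Unknown };

// Stand-in for the generic MLIR type parser, in which `token` now names the
// builtin token type.
Kind parseGeneric(const std::string &kw) {
  if (kw == "token")
    return Kind::BuiltinToken;
  if (kw == "i32")
    return Kind::BuiltinI32;
  return Kind::Unknown;
}

// Stand-in for the LLVM dialect's short-hand keyword table.
Kind parseLLVMShorthand(const std::string &kw) {
  if (kw == "token")
    return Kind::LLVMToken;
  if (kw == "void")
    return Kind::LLVMVoid;
  return Kind::Unknown;
}

// Nested LLVM type parsing: try the short-hand keywords first and fall back
// to the generic parser only if none of them matched.
Kind parseNested(const std::string &kw) {
  Kind kind = parseLLVMShorthand(kw);
  return kind != Kind::Unknown ? kind : parseGeneric(kw);
}
```

In this model `parseNested("token")` yields `Kind::LLVMToken`, whereas
`parseGeneric("token")` yields `Kind::BuiltinToken`. Keywords that only the
generic parser knows, such as `i32`, still go through the fallback.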
 mlir/docs/Dialects/Builtin.md                 |  9 ++-
 mlir/docs/Tokens.md                           | 38 +++++-------
 mlir/include/mlir/Dialect/Async/IR/Async.h    |  2 +-
 .../include/mlir/Dialect/Async/IR/AsyncOps.td |  4 +-
 .../include/mlir/IR/BuiltinDialectBytecode.td |  5 +-
 mlir/include/mlir/IR/BuiltinTypeInterfaces.td | 20 -------
 mlir/include/mlir/IR/BuiltinTypes.td          | 27 +++++++++
 mlir/include/mlir/IR/CommonTypeConstraints.td | 11 ++--
 mlir/lib/AsmParser/TokenKinds.def             |  1 +
 mlir/lib/AsmParser/TypeParser.cpp             |  6 ++
 .../Conversion/AsyncToLLVM/AsyncToLLVM.cpp    | 35 +++++------
 mlir/lib/Dialect/Async/IR/Async.cpp           | 10 ++--
 .../Transforms/AsyncRuntimeRefCounting.cpp    |  2 +-
 .../Async/Transforms/AsyncToAsyncRuntime.cpp  | 14 +++--
 mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp | 60 +++++++++++--------
 mlir/lib/IR/AsmPrinter.cpp                    |  1 +
 ...en-type-interface.mlir => token-type.mlir} | 33 +++++-----
 mlir/test/lib/Dialect/Test/TestOps.td         | 12 ++--
 mlir/test/lib/Dialect/Test/TestTypeDefs.td    |  7 ---
 mlir/test/mlir-tblgen/predicate.td            |  2 +-
 20 files changed, 159 insertions(+), 140 deletions(-)
 rename mlir/test/IR/{token-type-interface.mlir => token-type.mlir} (52%)

diff --git a/mlir/docs/Dialects/Builtin.md b/mlir/docs/Dialects/Builtin.md
index 818d6ded486db..8083d1f88b757 100644
--- a/mlir/docs/Dialects/Builtin.md
+++ b/mlir/docs/Dialects/Builtin.md
@@ -66,12 +66,11 @@ marked using one DistinctAttribute instance per alias group.
 
 [include "Dialects/BuiltinTypeInterfaces.md"]
 
-## Token Types
+## Token Type
 
-A *token type* is any type that implements the builtin `TokenTypeInterface`.
-Tokens are SSA values that exist purely to encode a *static* def–use
-relationship between operations or regions; they carry no runtime data and
-must not be value-forwarded.
+The builtin `token` type is an SSA value type that exists purely to encode
+a *static* def–use relationship between operations or regions. Tokens carry
+no runtime data and must not be value-forwarded.
 
 See the [Tokens design note](../Tokens.md) for the structural contract,
 ODS predicates (`AnyType` / `AnyTypeOrToken` / `Token`), and examples.
diff --git a/mlir/docs/Tokens.md b/mlir/docs/Tokens.md
index 81bcc1203b6aa..6ff25a3b8925d 100644
--- a/mlir/docs/Tokens.md
+++ b/mlir/docs/Tokens.md
@@ -13,21 +13,8 @@ no runtime data and is not allowed to flow through "regular"
 value-forwarding constructs. A token's provenance cannot be obscured through
 value forwarding.
 
-In MLIR, "token" is not a single concrete builtin type. Instead, any type
-that implements the builtin `TokenTypeInterface` is treated as a token by
-the framework. Dialects can define their own dialect-specific token types.
-
-## `TokenTypeInterface`
-
-`TokenTypeInterface` is a parameterless, methodless marker type interface.
-
-A type opts in by attaching the interface in TableGen:
-
-```tablegen
-def MyDialect_Token : TypeDef<MyDialect, "MyToken", [TokenTypeInterface]> {
-  let mnemonic = "token";
-}
-```
+The token type is a builtin type. It is parameterless, opaque, and prints
+as `token`.
 
 ## Structural Contract
 
@@ -64,8 +51,15 @@ Three predicates are provided in `CommonTypeConstraints.td`:
 | ------------------ | ------------------------------------ | ----------------------------------------------------------------------|
 | `AnyType`          | any non-token type                   | the default; matches the historical meaning of "any type" pre-tokens. |
 | `AnyTypeOrToken`   | any type, including tokens           | the op legitimately accepts arbitrary types (including tokens).       |
-| `Token`            | only types implementing `TokenTypeInterface` | the op specifically takes a token operand/result.             |
+| `Token`            | only the builtin `TokenType`         | the op specifically takes a token operand/result.                     |
+
+Example:
 
+```tablegen
+def MyConsumeOp : MyDialect_Op<"consume"> {
+  let arguments = (ins Token:$scope, AnyType:$value);
+}
+```
 
 ## Examples
 
@@ -73,13 +67,13 @@ Three predicates are provided in `CommonTypeConstraints.td`:
 
 ```mlir
 // error: 'scf.if' op result #0 must be variadic of any non-token type,
-//        but got '!my.token'
-%t = scf.if %cond -> !my.token {
-  %a = my.token.produce : !my.token
-  scf.yield %a : !my.token
+//        but got 'token'
+%t = scf.if %cond -> token {
+  %a = my.token.produce : token
+  scf.yield %a : token
 } else {
-  %b = my.token.produce : !my.token
-  scf.yield %b : !my.token
+  %b = my.token.produce : token
+  scf.yield %b : token
 }
 ```
 
diff --git a/mlir/include/mlir/Dialect/Async/IR/Async.h b/mlir/include/mlir/Dialect/Async/IR/Async.h
index f16e87e71373a..fc0b086126f52 100644
--- a/mlir/include/mlir/Dialect/Async/IR/Async.h
+++ b/mlir/include/mlir/Dialect/Async/IR/Async.h
@@ -50,7 +50,7 @@ namespace async {
 
 /// Returns true if the type is reference counted at runtime.
 inline bool isRefCounted(Type type) {
-  return isa<TokenType, ValueType, GroupType>(type);
+  return isa<async::TokenType, ValueType, GroupType>(type);
 }
 
 } // namespace async
diff --git a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
index 2cebeac767f29..058f58bda6433 100644
--- a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
+++ b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
@@ -174,7 +174,9 @@ def Async_FuncOp : Async_Op<"func",
     unsigned getNumResults() {return getResultTypes().size();}
 
     /// Is the async func stateful
-    bool isStateful() { return isa<TokenType>(getFunctionType().getResult(0));}
+    bool isStateful() {
+      return isa<async::TokenType>(getFunctionType().getResult(0));
+    }
 
     //===------------------------------------------------------------------===//
     // OpAsmOpInterface Methods
diff --git a/mlir/include/mlir/IR/BuiltinDialectBytecode.td b/mlir/include/mlir/IR/BuiltinDialectBytecode.td
index c97d093c84e51..99bcf70c77564 100644
--- a/mlir/include/mlir/IR/BuiltinDialectBytecode.td
+++ b/mlir/include/mlir/IR/BuiltinDialectBytecode.td
@@ -294,6 +294,8 @@ def UnrankedTensorType : DialectType<(type
   Type:$elementType
 )>;
 
+def TokenType : DialectType<(type)>;
+
 let cType = "VectorType" in {
 def VectorType : DialectType<(type
   Array<SignedVarIntList>:$shape,
@@ -371,7 +373,8 @@ def BuiltinDialectTypes : DialectTypes<"Builtin"> {
     UnrankedMemRefTypeWithMemSpace,
     UnrankedTensorType,
     VectorType,
-    VectorTypeWithScalableDims
+    VectorTypeWithScalableDims,
+    TokenType
   ];
 }
 
diff --git a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
index a50d73e21a69a..93c8c0694b467 100644
--- a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
+++ b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
@@ -236,26 +236,6 @@ def PtrLikeTypeInterface : TypeInterface<"PtrLikeTypeInterface"> {
   ];
 }
 
-//===----------------------------------------------------------------------===//
-// TokenTypeInterface
-//===----------------------------------------------------------------------===//
-
-def TokenTypeInterface : TypeInterface<"TokenTypeInterface"> {
-  let cppNamespace = "::mlir";
-  let description = [{
-    Intuitively, a *token* value is a pointer to an operation (via an OpResult)
-    or a pointer to a region (via an entry block argument).
-
-    More precisely, a token is an SSA value whose purpose is to encode a
-    static def–use relationship between operations or regions. It carries
-    no runtime data and is not allowed to flow through "regular"
-    value-forwarding constructs. A token's provenance cannot be obscured through
-    value forwarding.
-
-    This interface is a marker. It has no interface methods.
-  }];
-}
-
 //===----------------------------------------------------------------------===//
 // ShapedType
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index 20c41c5f79729..33e2a48e26386 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -1075,6 +1075,33 @@ def Builtin_None : Builtin_Type<"None", "none"> {
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// TokenType
+//===----------------------------------------------------------------------===//
+
+def Builtin_Token : Builtin_Type<"Token", "token"> {
+  let summary = "Token type for static def-use links";
+  let description = [{
+    Intuitively, a *token* value is a pointer to an operation (via an OpResult)
+    or a pointer to a region (via an entry block argument).
+
+    More precisely, a token is an SSA value whose purpose is to encode a
+    static def-use relationship between operations or regions. It carries
+    no runtime data and is not allowed to flow through "regular"
+    value-forwarding constructs. A token's provenance cannot be obscured
+    through value forwarding.
+
+    See the "Tokens" design note in the documentation for details on the
+    structural contract.
+
+    Syntax:
+
+    ```
+    token-type ::= `token`
+    ```
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // OpaqueType
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index 7898937ca01f5..86066bcda73e6 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -165,9 +165,8 @@ class SameBuildabilityAs<Type type, code builder> {
   code builderCall = !if(!empty(type.builderCall), "", builder);
 }
 
-// Whether a type is a token (i.e. implements TokenTypeInterface).
-def IsTokenTypePred
-    : CPred<"::llvm::isa<::mlir::TokenTypeInterface>($_self)">;
+// Whether a type is the builtin `TokenType`.
+def IsTokenTypePred : CPred<"::llvm::isa<::mlir::TokenType>($_self)">;
 
 // Any non-token type. Tokens are excluded by default to prevent ops that
 // accept arbitrary types from accidentally accepting tokens as operands /
@@ -181,9 +180,9 @@ def AnyType : Type<Neg<IsTokenTypePred>, "any non-token type">;
 // `BranchOpInterface`, etc. that legitimately handle arbitrary types).
 def AnyTypeOrToken : Type<CPred<"true">, "any type">;
 
-// A token type (any type implementing `TokenTypeInterface`).
-def Token : Type<IsTokenTypePred, "token",
-                 "::mlir::TokenTypeInterface">;
+// The builtin token type.
+def Token : Type<IsTokenTypePred, "token", "::mlir::TokenType">,
+            BuildableType<"$_builder.getType<::mlir::TokenType>()">;
 
 // None type
 def NoneType : Type<CPred<"::llvm::isa<::mlir::NoneType>($_self)">, "none type",
diff --git a/mlir/lib/AsmParser/TokenKinds.def b/mlir/lib/AsmParser/TokenKinds.def
index fe7c53753e156..f5e5c25832a30 100644
--- a/mlir/lib/AsmParser/TokenKinds.def
+++ b/mlir/lib/AsmParser/TokenKinds.def
@@ -127,6 +127,7 @@ TOK_KEYWORD(symbol)
 TOK_KEYWORD(tensor)
 TOK_KEYWORD(tf32)
 TOK_KEYWORD(to)
+TOK_KEYWORD(token)
 TOK_KEYWORD(true)
 TOK_KEYWORD(tuple)
 TOK_KEYWORD(type)
diff --git a/mlir/lib/AsmParser/TypeParser.cpp b/mlir/lib/AsmParser/TypeParser.cpp
index a461ebed967a8..2cdec14d65fa6 100644
--- a/mlir/lib/AsmParser/TypeParser.cpp
+++ b/mlir/lib/AsmParser/TypeParser.cpp
@@ -58,6 +58,7 @@ OptionalParseResult Parser::parseOptionalType(Type &type) {
   case Token::kw_f128:
   case Token::kw_index:
   case Token::kw_none:
+  case Token::kw_token:
   case Token::exclamation_identifier:
     return failure(!(type = parseType()));
 
@@ -371,6 +372,11 @@ Type Parser::parseNonFunctionType() {
     consumeToken(Token::kw_none);
     return builder.getNoneType();
 
+  // token-type
+  case Token::kw_token:
+    consumeToken(Token::kw_token);
+    return builder.getType<TokenType>();
+
   // extended type
   case Token::exclamation_identifier:
     return parseExtendedType();
diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
index 29e6552231f9c..7844c9dda877c 100644
--- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
+++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
@@ -89,7 +89,7 @@ struct AsyncAPI {
   }
 
   static FunctionType createTokenFunctionType(MLIRContext *ctx) {
-    return FunctionType::get(ctx, {}, {TokenType::get(ctx)});
+    return FunctionType::get(ctx, {}, {async::TokenType::get(ctx)});
   }
 
   static FunctionType createValueFunctionType(MLIRContext *ctx) {
@@ -109,7 +109,7 @@ struct AsyncAPI {
   }
 
   static FunctionType emplaceTokenFunctionType(MLIRContext *ctx) {
-    return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
+    return FunctionType::get(ctx, {async::TokenType::get(ctx)}, {});
   }
 
   static FunctionType emplaceValueFunctionType(MLIRContext *ctx) {
@@ -118,7 +118,7 @@ struct AsyncAPI {
   }
 
   static FunctionType setTokenErrorFunctionType(MLIRContext *ctx) {
-    return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
+    return FunctionType::get(ctx, {async::TokenType::get(ctx)}, {});
   }
 
   static FunctionType setValueErrorFunctionType(MLIRContext *ctx) {
@@ -128,7 +128,7 @@ struct AsyncAPI {
 
   static FunctionType isTokenErrorFunctionType(MLIRContext *ctx) {
     auto i1 = IntegerType::get(ctx, 1);
-    return FunctionType::get(ctx, {TokenType::get(ctx)}, {i1});
+    return FunctionType::get(ctx, {async::TokenType::get(ctx)}, {i1});
   }
 
   static FunctionType isValueErrorFunctionType(MLIRContext *ctx) {
@@ -143,7 +143,7 @@ struct AsyncAPI {
   }
 
   static FunctionType awaitTokenFunctionType(MLIRContext *ctx) {
-    return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
+    return FunctionType::get(ctx, {async::TokenType::get(ctx)}, {});
   }
 
   static FunctionType awaitValueFunctionType(MLIRContext *ctx) {
@@ -162,13 +162,14 @@ struct AsyncAPI {
 
   static FunctionType addTokenToGroupFunctionType(MLIRContext *ctx) {
     auto i64 = IntegerType::get(ctx, 64);
-    return FunctionType::get(ctx, {TokenType::get(ctx), GroupType::get(ctx)},
-                             {i64});
+    return FunctionType::get(
+        ctx, {async::TokenType::get(ctx), GroupType::get(ctx)}, {i64});
   }
 
   static FunctionType awaitTokenAndExecuteFunctionType(MLIRContext *ctx) {
     auto ptrType = opaquePointerType(ctx);
-    return FunctionType::get(ctx, {TokenType::get(ctx), ptrType, ptrType}, {});
+    return FunctionType::get(
+        ctx, {async::TokenType::get(ctx), ptrType, ptrType}, {});
   }
 
   static FunctionType awaitValueAndExecuteFunctionType(MLIRContext *ctx) {
@@ -291,7 +292,7 @@ class AsyncRuntimeTypeConverter : public TypeConverter {
   }
 
   static std::optional<Type> convertAsyncTypes(Type type) {
-    if (isa<TokenType, GroupType, ValueType>(type))
+    if (isa<async::TokenType, GroupType, ValueType>(type))
       return AsyncAPI::opaquePointerType(type.getContext());
 
     if (isa<CoroIdType, CoroStateType>(type))
@@ -583,7 +584,7 @@ class RuntimeCreateOpLowering : public ConvertOpToLLVMPattern<RuntimeCreateOp> {
     Type resultType = op->getResultTypes()[0];
 
     // Tokens creation maps to a simple function call.
-    if (isa<TokenType>(resultType)) {
+    if (isa<async::TokenType>(resultType)) {
       rewriter.replaceOpWithNewOp<func::CallOp>(
           op, kCreateToken, converter->convertType(resultType));
       return success();
@@ -659,7 +660,7 @@ class RuntimeSetAvailableOpLowering
                   ConversionPatternRewriter &rewriter) const override {
     StringRef apiFuncName =
         TypeSwitch<Type, StringRef>(op.getOperand().getType())
-            .Case<TokenType>([](Type) { return kEmplaceToken; })
+            .Case<async::TokenType>([](Type) { return kEmplaceToken; })
             .Case<ValueType>([](Type) { return kEmplaceValue; });
 
     rewriter.replaceOpWithNewOp<func::CallOp>(op, apiFuncName, TypeRange(),
@@ -685,7 +686,7 @@ class RuntimeSetErrorOpLowering
                   ConversionPatternRewriter &rewriter) const override {
     StringRef apiFuncName =
         TypeSwitch<Type, StringRef>(op.getOperand().getType())
-            .Case<TokenType>([](Type) { return kSetTokenError; })
+            .Case<async::TokenType>([](Type) { return kSetTokenError; })
             .Case<ValueType>([](Type) { return kSetValueError; });
 
     rewriter.replaceOpWithNewOp<func::CallOp>(op, apiFuncName, TypeRange(),
@@ -710,7 +711,7 @@ class RuntimeIsErrorOpLowering : public OpConversionPattern<RuntimeIsErrorOp> {
                   ConversionPatternRewriter &rewriter) const override {
     StringRef apiFuncName =
         TypeSwitch<Type, StringRef>(op.getOperand().getType())
-            .Case<TokenType>([](Type) { return kIsTokenError; })
+            .Case<async::TokenType>([](Type) { return kIsTokenError; })
             .Case<GroupType>([](Type) { return kIsGroupError; })
             .Case<ValueType>([](Type) { return kIsValueError; });
 
@@ -735,7 +736,7 @@ class RuntimeAwaitOpLowering : public OpConversionPattern<RuntimeAwaitOp> {
                   ConversionPatternRewriter &rewriter) const override {
     StringRef apiFuncName =
         TypeSwitch<Type, StringRef>(op.getOperand().getType())
-            .Case<TokenType>([](Type) { return kAwaitToken; })
+            .Case<async::TokenType>([](Type) { return kAwaitToken; })
             .Case<ValueType>([](Type) { return kAwaitValue; })
             .Case<GroupType>([](Type) { return kAwaitGroup; });
 
@@ -763,7 +764,7 @@ class RuntimeAwaitAndResumeOpLowering
                   ConversionPatternRewriter &rewriter) const override {
     StringRef apiFuncName =
         TypeSwitch<Type, StringRef>(op.getOperand().getType())
-            .Case<TokenType>([](Type) { return kAwaitTokenAndExecute; })
+            .Case<async::TokenType>([](Type) { return kAwaitTokenAndExecute; })
             .Case<ValueType>([](Type) { return kAwaitValueAndExecute; })
             .Case<GroupType>([](Type) { return kAwaitAllAndExecute; });
 
@@ -906,7 +907,7 @@ class RuntimeAddToGroupOpLowering
   matchAndRewrite(RuntimeAddToGroupOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     // Currently we can only add tokens to the group.
-    if (!isa<TokenType>(op.getOperand().getType()))
+    if (!isa<async::TokenType>(op.getOperand().getType()))
       return rewriter.notifyMatchFailure(op, "only token type is supported");
 
     // Replace with a runtime API function call.
@@ -1151,7 +1152,7 @@ class ConvertYieldOpTypes : public OpConversionPattern<async::YieldOp> {
 void mlir::populateAsyncStructuralTypeConversionsAndLegality(
     TypeConverter &typeConverter, RewritePatternSet &patterns,
     ConversionTarget &target) {
-  typeConverter.addConversion([&](TokenType type) { return type; });
+  typeConverter.addConversion([&](async::TokenType type) { return type; });
   typeConverter.addConversion([&](ValueType type) {
     Type converted = typeConverter.convertType(type.getValueType());
     return converted ? ValueType::get(converted) : converted;
diff --git a/mlir/lib/Dialect/Async/IR/Async.cpp b/mlir/lib/Dialect/Async/IR/Async.cpp
index 71be1d275280e..1713da07da60d 100644
--- a/mlir/lib/Dialect/Async/IR/Async.cpp
+++ b/mlir/lib/Dialect/Async/IR/Async.cpp
@@ -84,7 +84,7 @@ void ExecuteOp::build(OpBuilder &builder, OperationState &result,
 
   // First result is always a token, and then `resultTypes` wrapped into
   // `async.value`.
-  result.addTypes({TokenType::get(result.getContext())});
+  result.addTypes({async::TokenType::get(result.getContext())});
   for (Type type : resultTypes)
     result.addTypes(ValueType::get(type));
 
@@ -139,7 +139,7 @@ ParseResult ExecuteOp::parse(OpAsmParser &parser, OperationState &result) {
   // Sizes of parsed variadic operands, will be updated below after parsing.
   int32_t numDependencies = 0;
 
-  auto tokenTy = TokenType::get(ctx);
+  auto tokenTy = async::TokenType::get(ctx);
 
   // Parse dependency tokens.
   if (succeeded(parser.parseOptionalLSquare())) {
@@ -280,7 +280,7 @@ LogicalResult AwaitOp::verify() {
   Type argType = getOperand().getType();
 
   // Awaiting on a token does not have any results.
-  if (llvm::isa<TokenType>(argType) && !getResultTypes().empty())
+  if (llvm::isa<async::TokenType>(argType) && !getResultTypes().empty())
     return emitOpError("awaiting on a token must have empty result");
 
   // Awaiting on a value unwraps the async value type.
@@ -345,12 +345,12 @@ LogicalResult FuncOp::verify() {
 
   for (unsigned i = 0, e = resultTypes.size(); i != e; ++i) {
     auto type = resultTypes[i];
-    if (!llvm::isa<TokenType>(type) && !llvm::isa<ValueType>(type))
+    if (!llvm::isa<async::TokenType>(type) && !llvm::isa<ValueType>(type))
       return emitOpError() << "result type must be async value type or async "
                               "token type, but got "
                            << type;
     // We only allow AsyncToken appear as the first return value
-    if (llvm::isa<TokenType>(type) && i != 0) {
+    if (llvm::isa<async::TokenType>(type) && i != 0) {
       return emitOpError()
              << " results' (optional) async token type is expected "
                 "to appear as the 1st return value, but got "
diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp
index 91e37dd9ac36e..2a726f3fd2999 100644
--- a/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp
@@ -526,7 +526,7 @@ void AsyncRuntimePolicyBasedRefCountingPass::initializeDefaultPolicy() {
     Operation *op = operand.getOwner();
     Type type = operand.get().getType();
 
-    bool isToken = isa<TokenType>(type);
+    bool isToken = isa<async::TokenType>(type);
     bool isGroup = isa<GroupType>(type);
     bool isValue = isa<ValueType>(type);
 
diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
index 0c5bcfe631c6c..8b4aaf56fdf76 100644
--- a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
@@ -180,13 +180,14 @@ static CoroMachinery setupCoroMachinery(func::FuncOp func) {
   // Allocate async token/values that we will return from a ramp function.
   // ------------------------------------------------------------------------ //
 
-  // We treat TokenType as state update marker to represent side-effects of
-  // async computations
-  bool isStateful = isa<TokenType>(func.getResultTypes().front());
+  // We treat async::TokenType as a state update marker to represent
+  // side-effects of async computations.
+  bool isStateful = isa<async::TokenType>(func.getResultTypes().front());
 
   std::optional<Value> retToken;
   if (isStateful)
-    retToken.emplace(RuntimeCreateOp::create(builder, TokenType::get(ctx)));
+    retToken.emplace(
+        RuntimeCreateOp::create(builder, async::TokenType::get(ctx)));
 
   llvm::SmallVector<Value, 4> retValues;
   ArrayRef<Type> resValueTypes =
@@ -655,8 +656,9 @@ class AwaitOpLoweringBase : public OpConversionPattern<AwaitType> {
 };
 
 /// Lowering for `async.await` with a token operand.
-class AwaitTokenOpLowering : public AwaitOpLoweringBase<AwaitOp, TokenType> {
-  using Base = AwaitOpLoweringBase<AwaitOp, TokenType>;
+class AwaitTokenOpLowering
+    : public AwaitOpLoweringBase<AwaitOp, async::TokenType> {
+  using Base = AwaitOpLoweringBase<AwaitOp, async::TokenType>;
 
 public:
   using Base::Base;
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
index 705d07d3e6c42..87df1e2df6102 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
@@ -238,13 +238,41 @@ Type LLVMStructType::parse(AsmParser &parser) {
 }
 
 /// Parses a type appearing inside another LLVM dialect-compatible type. This
-/// will try to parse any type in full form (including types with the `!llvm`
-/// prefix), and on failure fall back to parsing the short-hand version of the
-/// LLVM dialect types without the `!llvm` prefix.
+/// will first try to parse the LLVM dialect's short-hand keyword form (e.g.
+/// `token`, `void`, `ptr`, ...) and, failing that, fall back to parsing any
+/// MLIR type in full form (including types with the `!llvm` prefix).
+///
+/// Trying the short-hand form first matters because some LLVM short-hand
+/// keywords (notably `token`) collide with builtin type keywords whose
+/// semantics differ from the LLVM dialect's. Inside an `!llvm.<...>` type, the
+/// LLVM-specific meaning must always win.
 static Type dispatchParse(AsmParser &parser, bool allowAny = true) {
   SMLoc keyLoc = parser.getCurrentLocation();
+  MLIRContext *ctx = parser.getContext();
+
+  // Try parsing the LLVM dialect's short-hand keyword form first.
+  StringRef key;
+  if (succeeded(parser.parseOptionalKeyword(
+          &key, {"void", "ppc_fp128", "token", "label", "metadata", "func",
+                 "ptr", "array", "struct", "target", "x86_amx"}))) {
+    // `parseOptionalKeyword` already restricted `key` to one of the cases
+    // below, so the `Default` is unreachable.
+    return StringSwitch<function_ref<Type()>>(key)
+        .Case("void", [&] { return LLVMVoidType::get(ctx); })
+        .Case("ppc_fp128", [&] { return LLVMPPCFP128Type::get(ctx); })
+        .Case("token", [&] { return LLVMTokenType::get(ctx); })
+        .Case("label", [&] { return LLVMLabelType::get(ctx); })
+        .Case("metadata", [&] { return LLVMMetadataType::get(ctx); })
+        .Case("func", [&] { return LLVMFunctionType::parse(parser); })
+        .Case("ptr", [&] { return LLVMPointerType::parse(parser); })
+        .Case("array", [&] { return LLVMArrayType::parse(parser); })
+        .Case("struct", [&] { return LLVMStructType::parse(parser); })
+        .Case("target", [&] { return LLVMTargetExtType::parse(parser); })
+        .Case("x86_amx", [&] { return LLVMX86AMXType::get(ctx); })
+        .Default([] { return Type(); })();
+  }
 
-  // Try parsing any MLIR type.
+  // Otherwise, try parsing any MLIR type (only when allowed).
   Type type;
   OptionalParseResult result = parser.parseOptionalType(type);
   if (result.has_value()) {
@@ -257,28 +285,12 @@ static Type dispatchParse(AsmParser &parser, bool allowAny = true) {
     return type;
   }
 
-  // If no type found, fallback to the shorthand form.
-  StringRef key;
+  // Neither a known LLVM short-hand keyword nor a parseable MLIR type.
+  // Re-run `parseKeyword` to produce a useful error message.
   if (failed(parser.parseKeyword(&key)))
     return Type();
-
-  MLIRContext *ctx = parser.getContext();
-  return StringSwitch<function_ref<Type()>>(key)
-      .Case("void", [&] { return LLVMVoidType::get(ctx); })
-      .Case("ppc_fp128", [&] { return LLVMPPCFP128Type::get(ctx); })
-      .Case("token", [&] { return LLVMTokenType::get(ctx); })
-      .Case("label", [&] { return LLVMLabelType::get(ctx); })
-      .Case("metadata", [&] { return LLVMMetadataType::get(ctx); })
-      .Case("func", [&] { return LLVMFunctionType::parse(parser); })
-      .Case("ptr", [&] { return LLVMPointerType::parse(parser); })
-      .Case("array", [&] { return LLVMArrayType::parse(parser); })
-      .Case("struct", [&] { return LLVMStructType::parse(parser); })
-      .Case("target", [&] { return LLVMTargetExtType::parse(parser); })
-      .Case("x86_amx", [&] { return LLVMX86AMXType::get(ctx); })
-      .Default([&] {
-        parser.emitError(keyLoc) << "unknown LLVM type: " << key;
-        return Type();
-      })();
+  parser.emitError(keyLoc) << "unknown LLVM type: " << key;
+  return Type();
 }
 
 /// Helper to use in parse lists.
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 75008d6cc2591..83dbdacf60c90 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -2907,6 +2907,7 @@ void AsmPrinter::Impl::printTypeImpl(Type type) {
         os << '>';
       })
       .Case<NoneType>([&](Type) { os << "none"; })
+      .Case<TokenType>([&](Type) { os << "token"; })
       .Case([&](GraphType graphTy) {
         os << '(';
         interleaveComma(graphTy.getInputs(), [&](Type ty) { printType(ty); });
diff --git a/mlir/test/IR/token-type-interface.mlir b/mlir/test/IR/token-type.mlir
similarity index 52%
rename from mlir/test/IR/token-type-interface.mlir
rename to mlir/test/IR/token-type.mlir
index c0632bdcf616b..990a7917effe1 100644
--- a/mlir/test/IR/token-type-interface.mlir
+++ b/mlir/test/IR/token-type.mlir
@@ -1,20 +1,19 @@
 // RUN: mlir-opt %s -verify-diagnostics -split-input-file | FileCheck %s
 
-// Tests for the builtin `TokenTypeInterface` and the
+// Tests for the builtin `token` type and the
 // `Token` / `AnyType` / `AnyTypeOrToken` ODS predicates.
 //
-// `!test.test_token` is a test-dialect type that implements
-// `TokenTypeInterface`. The default `AnyType` predicate excludes tokens, while
-// `AnyTypeOrToken` and `Token` accept them.
+// The default `AnyType` predicate excludes tokens, while `AnyTypeOrToken` and
+// `Token` accept them.
 
 // CHECK-LABEL: @token_produce_consume
 func.func @token_produce_consume() {
-  // CHECK: %[[T:.*]] = test.token.produce : !test.test_token
-  %t = test.token.produce : !test.test_token
-  // CHECK: test.token.consume %[[T]] : !test.test_token
-  test.token.consume %t : !test.test_token
-  // CHECK: test.token.any_or_token %[[T]] : !test.test_token
-  test.token.any_or_token %t : !test.test_token
+  // CHECK: %[[T:.*]] = test.token.produce
+  %t = test.token.produce
+  // CHECK: test.token.consume %[[T]]
+  test.token.consume %t
+  // CHECK: test.token.any_or_token %[[T]] : token
+  test.token.any_or_token %t : token
   return
 }
 
@@ -42,18 +41,20 @@ func.func @any_type_with_non_token(%arg0: i32) {
 
 // `AnyType` rejects tokens by default.
 func.func @any_type_rejects_token() {
-  %t = test.token.produce : !test.test_token
+  %t = test.token.produce
   // expected-error @below {{operand #0 must be any non-token type}}
-  test.token.any_type %t : !test.test_token
+  test.token.any_type %t : token
   return
 }
 
 // -----
 
-// `Token` rejects non-token types. The operand's cppType is
-// `TokenTypeInterface`, so type resolution fails at parse time.
+// `Token` rejects non-token types. The op's operand type is fixed to the
+// builtin `token` (it's a `BuildableType`), so passing a non-token SSA value
+// fails at parse time with an SSA type mismatch.
+// expected-note @below {{prior use here}}
 func.func @token_rejects_non_token(%arg0: i32) {
-  // expected-error @below {{invalid kind of type specified}}
-  test.token.consume %arg0 : i32
+  // expected-error @below {{use of value '%arg0' expects different type than prior uses: 'token' vs 'i32'}}
+  test.token.consume %arg0
   return
 }
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 529fc4b860ad4..9a425d50b3f74 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -114,18 +114,16 @@ def SignlessLikeVariadic : TEST_Op<"signless_like_variadic"> {
 // Test Token Type
 //===----------------------------------------------------------------------===//
 
-// Produce a token. Demonstrates a type that implements the builtin
-// `TokenTypeInterface`.
+// Produce a builtin `!token` value.
 def TestTokenProduceOp : TEST_Op<"token.produce"> {
-  let results = (outs TestTokenType:$token);
-  let assemblyFormat = "attr-dict `:` type($token)";
+  let results = (outs Token:$token);
+  let assemblyFormat = "attr-dict";
 }
 
-// Consume a token (token-only operand). Uses the `Token` ODS predicate which
-// only accepts types implementing `TokenTypeInterface`.
+// Consume a builtin `!token` value (token-only operand).
 def TestTokenConsumeOp : TEST_Op<"token.consume"> {
   let arguments = (ins Token:$token);
-  let assemblyFormat = "$token attr-dict `:` type($token)";
+  let assemblyFormat = "$token attr-dict";
 }
 
 // Op that accepts any type, including a token. Uses the `AnyTypeOrToken`
diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
index 5df0eab829a03..08600ce713a17 100644
--- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
@@ -173,13 +173,6 @@ def TestMemRefElementType : Test_Type<"TestMemRefElementType",
   let mnemonic = "memref_element";
 }
 
-// A test token type implementing the builtin `TokenTypeInterface`. Used to
-// exercise the default exclusion of tokens from `AnyType` and the explicit
-// `Token` / `AnyTypeOrToken` opt-ins.
-def TestTokenType : Test_Type<"TestToken", [TokenTypeInterface]> {
-  let mnemonic = "test_token";
-}
-
 def TestTypeTrait : NativeTypeTrait<"TestTypeTrait">;
 
 // The definition of a singleton type that has a trait.
diff --git a/mlir/test/mlir-tblgen/predicate.td b/mlir/test/mlir-tblgen/predicate.td
index 07a5e6f2f261c..ae436885b421f 100644
--- a/mlir/test/mlir-tblgen/predicate.td
+++ b/mlir/test/mlir-tblgen/predicate.td
@@ -27,7 +27,7 @@ def OpA : NS_Op<"op_for_CPred_containing_multiple_same_placeholder", []> {
 // CHECK-NOT.        << " must be 32-bit integer or floating-point type, but got " << type;
 
 // CHECK: static ::llvm::LogicalResult [[$TENSOR_CONSTRAINT:__mlir_ods_local_type_constraint.*]](
-// CHECK:       if (!(((::llvm::isa<::mlir::TensorType>(type))) && ([](::mlir::Type elementType) { return !((::llvm::isa<::mlir::TokenTypeInterface>(elementType))); }(::llvm::cast<::mlir::ShapedType>(type).getElementType())))) {
+// CHECK:       if (!(((::llvm::isa<::mlir::TensorType>(type))) && ([](::mlir::Type elementType) { return !((::llvm::isa<::mlir::TokenType>(elementType))); }(::llvm::cast<::mlir::ShapedType>(type).getElementType())))) {
 // CHECK-NEXT:    return op->emitOpError(valueKind) << " #" << valueIndex
 // CHECK-NEXT:        << " must be tensor of any non-token type values, but got " << type;
 


