[Mlir-commits] [mlir] [mlir][amx] Prevent crash on invalid tile element type (PR #155587)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Aug 27 03:12:56 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-amx

Author: Adam Siemieniuk (adam-smnk)

<details>
<summary>Changes</summary>

Fixes the AMX tile type parser to prevent crashes on an invalid element type.

---
Full diff: https://github.com/llvm/llvm-project/pull/155587.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/AMX/IR/AMXDialect.cpp (+3-1) 
- (modified) mlir/test/Dialect/AMX/invalid.mlir (+20-4) 


``````````diff
diff --git a/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
index 6f3110cdf00ef..68990ef0dc0c3 100644
--- a/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
+++ b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
@@ -271,7 +271,9 @@ Type amx::TileType::parse(AsmParser &parser) {
   if (parser.parseGreater())
     return nullptr;
 
-  return TileType::get(shape, elementType);
+  return TileType::getChecked(
+      [&] { return parser.emitError(parser.getNameLoc()); }, shape,
+      elementType);
 }
 
 void amx::TileType::print(AsmPrinter &os) const {
diff --git a/mlir/test/Dialect/AMX/invalid.mlir b/mlir/test/Dialect/AMX/invalid.mlir
index a401770240d0a..5de9b3f82a868 100644
--- a/mlir/test/Dialect/AMX/invalid.mlir
+++ b/mlir/test/Dialect/AMX/invalid.mlir
@@ -16,6 +16,22 @@ func.func @tile_col_width() {
 
 // -----
 
+func.func @tile_element_type() {
+  // expected-error at +1 {{failed to verify 'elementType'}}
+  %0 = amx.tile_zero : !amx.tile<8x8xi16>
+  return
+}
+
+// -----
+
+func.func @tile_rank() {
+  // expected-error at +1 {{'amx.tile_zero' op result #0 must be tile of}}
+  %0 = amx.tile_zero : !amx.tile<32xi8>
+  return
+}
+
+// -----
+
 func.func @tile_col_4_byte_multiple() {
   // expected-error at +1 {{'amx.tile_zero' op bad column width: 5}}
   %0 = amx.tile_zero : !amx.tile<16x5xi8>
@@ -24,7 +40,7 @@ func.func @tile_col_4_byte_multiple() {
 
 // -----
 
-func.func @load_base_tilesize(%arg0: memref<?x?xf32>) {
+func.func @load_base_tile_size(%arg0: memref<?x?xf32>) {
   %0 = arith.constant 0 : index
   // expected-error at +1 {{'amx.tile_load' op bad column width: 68}}
   %1 = amx.tile_load %arg0[%0, %0] : memref<?x?xf32> into !amx.tile<16x17xf32>
@@ -33,7 +49,7 @@ func.func @load_base_tilesize(%arg0: memref<?x?xf32>) {
 
 // -----
 
-func.func @store_base_tilesize(%arg0: memref<?x?xf32>, %arg1: !amx.tile<16x17xf32>) {
+func.func @store_base_tile_size(%arg0: memref<?x?xf32>, %arg1: !amx.tile<16x17xf32>) {
   %0 = arith.constant 0 : index
   // expected-error at +1 {{'amx.tile_store' op bad column width: 68}}
   amx.tile_store %arg0[%0, %0], %arg1 : memref<?x?xf32>, !amx.tile<16x17xf32>
@@ -42,7 +58,7 @@ func.func @store_base_tilesize(%arg0: memref<?x?xf32>, %arg1: !amx.tile<16x17xf3
 
 // -----
 
-func.func @load_base_indexsize(%arg0: memref<?x?xf32>) {
+func.func @load_base_index_size(%arg0: memref<?x?xf32>) {
   %0 = arith.constant 0 : index
   // expected-error at +1 {{'amx.tile_load' op requires 2 indices}}
   %1 = amx.tile_load %arg0[%0] : memref<?x?xf32> into !amx.tile<16x16xf32>
@@ -51,7 +67,7 @@ func.func @load_base_indexsize(%arg0: memref<?x?xf32>) {
 
 // -----
 
-func.func @store_base_indexsize(%arg0: memref<?x?xf32>, %arg1: !amx.tile<16x16xf32>) {
+func.func @store_base_index_size(%arg0: memref<?x?xf32>, %arg1: !amx.tile<16x16xf32>) {
   %0 = arith.constant 0 : index
   // expected-error at +1 {{'amx.tile_store' op requires 2 indices}}
   amx.tile_store %arg0[%0], %arg1 : memref<?x?xf32>, !amx.tile<16x16xf32>

``````````

</details>


https://github.com/llvm/llvm-project/pull/155587


More information about the Mlir-commits mailing list