[Mlir-commits] [mlir] [mlir][amx] Prevent crash on invalid tile element type (PR #155587)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Aug 27 03:12:57 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Adam Siemieniuk (adam-smnk)
<details>
<summary>Changes</summary>
Fixes the AMX tile type parser so that an invalid element type produces a diagnostic instead of a crash.
---
Full diff: https://github.com/llvm/llvm-project/pull/155587.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/AMX/IR/AMXDialect.cpp (+3-1)
- (modified) mlir/test/Dialect/AMX/invalid.mlir (+20-4)
``````````diff
diff --git a/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
index 6f3110cdf00ef..68990ef0dc0c3 100644
--- a/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
+++ b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
@@ -271,7 +271,9 @@ Type amx::TileType::parse(AsmParser &parser) {
if (parser.parseGreater())
return nullptr;
- return TileType::get(shape, elementType);
+ return TileType::getChecked(
+ [&] { return parser.emitError(parser.getNameLoc()); }, shape,
+ elementType);
}
void amx::TileType::print(AsmPrinter &os) const {
diff --git a/mlir/test/Dialect/AMX/invalid.mlir b/mlir/test/Dialect/AMX/invalid.mlir
index a401770240d0a..5de9b3f82a868 100644
--- a/mlir/test/Dialect/AMX/invalid.mlir
+++ b/mlir/test/Dialect/AMX/invalid.mlir
@@ -16,6 +16,22 @@ func.func @tile_col_width() {
// -----
+func.func @tile_element_type() {
+ // expected-error at +1 {{failed to verify 'elementType'}}
+ %0 = amx.tile_zero : !amx.tile<8x8xi16>
+ return
+}
+
+// -----
+
+func.func @tile_rank() {
+ // expected-error at +1 {{'amx.tile_zero' op result #0 must be tile of}}
+ %0 = amx.tile_zero : !amx.tile<32xi8>
+ return
+}
+
+// -----
+
func.func @tile_col_4_byte_multiple() {
// expected-error at +1 {{'amx.tile_zero' op bad column width: 5}}
%0 = amx.tile_zero : !amx.tile<16x5xi8>
@@ -24,7 +40,7 @@ func.func @tile_col_4_byte_multiple() {
// -----
-func.func @load_base_tilesize(%arg0: memref<?x?xf32>) {
+func.func @load_base_tile_size(%arg0: memref<?x?xf32>) {
%0 = arith.constant 0 : index
// expected-error at +1 {{'amx.tile_load' op bad column width: 68}}
%1 = amx.tile_load %arg0[%0, %0] : memref<?x?xf32> into !amx.tile<16x17xf32>
@@ -33,7 +49,7 @@ func.func @load_base_tilesize(%arg0: memref<?x?xf32>) {
// -----
-func.func @store_base_tilesize(%arg0: memref<?x?xf32>, %arg1: !amx.tile<16x17xf32>) {
+func.func @store_base_tile_size(%arg0: memref<?x?xf32>, %arg1: !amx.tile<16x17xf32>) {
%0 = arith.constant 0 : index
// expected-error at +1 {{'amx.tile_store' op bad column width: 68}}
amx.tile_store %arg0[%0, %0], %arg1 : memref<?x?xf32>, !amx.tile<16x17xf32>
@@ -42,7 +58,7 @@ func.func @store_base_tilesize(%arg0: memref<?x?xf32>, %arg1: !amx.tile<16x17xf3
// -----
-func.func @load_base_indexsize(%arg0: memref<?x?xf32>) {
+func.func @load_base_index_size(%arg0: memref<?x?xf32>) {
%0 = arith.constant 0 : index
// expected-error at +1 {{'amx.tile_load' op requires 2 indices}}
%1 = amx.tile_load %arg0[%0] : memref<?x?xf32> into !amx.tile<16x16xf32>
@@ -51,7 +67,7 @@ func.func @load_base_indexsize(%arg0: memref<?x?xf32>) {
// -----
-func.func @store_base_indexsize(%arg0: memref<?x?xf32>, %arg1: !amx.tile<16x16xf32>) {
+func.func @store_base_index_size(%arg0: memref<?x?xf32>, %arg1: !amx.tile<16x16xf32>) {
%0 = arith.constant 0 : index
// expected-error at +1 {{'amx.tile_store' op requires 2 indices}}
amx.tile_store %arg0[%0], %arg1 : memref<?x?xf32>, !amx.tile<16x16xf32>
``````````
</details>
https://github.com/llvm/llvm-project/pull/155587
More information about the Mlir-commits mailing list