[Mlir-commits] [mlir] a4f67f3 - [mlir][amx] Prevent crash on invalid tile element type (#155587)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Aug 27 06:59:48 PDT 2025
Author: Adam Siemieniuk
Date: 2025-08-27T15:59:45+02:00
New Revision: a4f67f3684d9eef3bf721509364239af9e3c4ec4
URL: https://github.com/llvm/llvm-project/commit/a4f67f3684d9eef3bf721509364239af9e3c4ec4
DIFF: https://github.com/llvm/llvm-project/commit/a4f67f3684d9eef3bf721509364239af9e3c4ec4.diff
LOG: [mlir][amx] Prevent crash on invalid tile element type (#155587)
Fixes the AMX tile type parser to prevent crashes on an invalid element type.
Added:
Modified:
mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
mlir/test/Dialect/AMX/invalid.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
index 6f3110cdf00ef..68990ef0dc0c3 100644
--- a/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
+++ b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
@@ -271,7 +271,9 @@ Type amx::TileType::parse(AsmParser &parser) {
if (parser.parseGreater())
return nullptr;
- return TileType::get(shape, elementType);
+ return TileType::getChecked(
+ [&] { return parser.emitError(parser.getNameLoc()); }, shape,
+ elementType);
}
void amx::TileType::print(AsmPrinter &os) const {
diff --git a/mlir/test/Dialect/AMX/invalid.mlir b/mlir/test/Dialect/AMX/invalid.mlir
index a401770240d0a..5de9b3f82a868 100644
--- a/mlir/test/Dialect/AMX/invalid.mlir
+++ b/mlir/test/Dialect/AMX/invalid.mlir
@@ -16,6 +16,22 @@ func.func @tile_col_width() {
// -----
+func.func @tile_element_type() {
+ // expected-error@+1 {{failed to verify 'elementType'}}
+ %0 = amx.tile_zero : !amx.tile<8x8xi16>
+ return
+}
+
+// -----
+
+func.func @tile_rank() {
+ // expected-error@+1 {{'amx.tile_zero' op result #0 must be tile of}}
+ %0 = amx.tile_zero : !amx.tile<32xi8>
+ return
+}
+
+// -----
+
func.func @tile_col_4_byte_multiple() {
// expected-error@+1 {{'amx.tile_zero' op bad column width: 5}}
%0 = amx.tile_zero : !amx.tile<16x5xi8>
@@ -24,7 +40,7 @@ func.func @tile_col_4_byte_multiple() {
// -----
-func.func @load_base_tilesize(%arg0: memref<?x?xf32>) {
+func.func @load_base_tile_size(%arg0: memref<?x?xf32>) {
%0 = arith.constant 0 : index
// expected-error@+1 {{'amx.tile_load' op bad column width: 68}}
%1 = amx.tile_load %arg0[%0, %0] : memref<?x?xf32> into !amx.tile<16x17xf32>
@@ -33,7 +49,7 @@ func.func @load_base_tilesize(%arg0: memref<?x?xf32>) {
// -----
-func.func @store_base_tilesize(%arg0: memref<?x?xf32>, %arg1: !amx.tile<16x17xf32>) {
+func.func @store_base_tile_size(%arg0: memref<?x?xf32>, %arg1: !amx.tile<16x17xf32>) {
%0 = arith.constant 0 : index
// expected-error@+1 {{'amx.tile_store' op bad column width: 68}}
amx.tile_store %arg0[%0, %0], %arg1 : memref<?x?xf32>, !amx.tile<16x17xf32>
@@ -42,7 +58,7 @@ func.func @store_base_tilesize(%arg0: memref<?x?xf32>, %arg1: !amx.tile<16x17xf3
// -----
-func.func @load_base_indexsize(%arg0: memref<?x?xf32>) {
+func.func @load_base_index_size(%arg0: memref<?x?xf32>) {
%0 = arith.constant 0 : index
// expected-error@+1 {{'amx.tile_load' op requires 2 indices}}
%1 = amx.tile_load %arg0[%0] : memref<?x?xf32> into !amx.tile<16x16xf32>
@@ -51,7 +67,7 @@ func.func @load_base_indexsize(%arg0: memref<?x?xf32>) {
// -----
-func.func @store_base_indexsize(%arg0: memref<?x?xf32>, %arg1: !amx.tile<16x16xf32>) {
+func.func @store_base_index_size(%arg0: memref<?x?xf32>, %arg1: !amx.tile<16x16xf32>) {
%0 = arith.constant 0 : index
// expected-error@+1 {{'amx.tile_store' op requires 2 indices}}
amx.tile_store %arg0[%0], %arg1 : memref<?x?xf32>, !amx.tile<16x16xf32>
More information about the Mlir-commits
mailing list