[Mlir-commits] [mlir] [mlir][VectorToXeGPU] Fix crash on memref with non-scalar element type (PR #183905)

Mehdi Amini llvmlistbot at llvm.org
Sat Feb 28 04:10:41 PST 2026


https://github.com/joker-eph created https://github.com/llvm/llvm-project/pull/183905

The vector.store and vector.load lowerings in --convert-vector-to-xegpu would crash when the source memref had a non-integer/float element type (e.g. memref<?xvector<4xf32>>).

The crash occurred inside createNdDescriptor() when computing the byte offset for dynamic memrefs: srcTy.getElementTypeBitWidth() internally calls getIntOrFloatBitWidth(), which asserts on non-scalar types such as vector<4xf32>.

Fix by adding a check for the memref's element type in storeLoadPreconditions(). If the element type is not an integer or float, the pattern returns notifyMatchFailure() instead of proceeding and crashing.

Because the check lives in the shared storeLoadPreconditions() helper, the same guard also applies to TransferReadLowering and TransferWriteLowering, which call that helper and can hit the same path.

Fixes #181463

>From afd310caa8b7b4b8ab7236bc3dfae1406dc04dfb Mon Sep 17 00:00:00 2001
From: Mehdi Amini <joker.eph at gmail.com>
Date: Sat, 28 Feb 2026 03:55:44 -0800
Subject: [PATCH] [mlir][VectorToXeGPU] Fix crash on memref with non-scalar
 element type

The vector.store and vector.load lowerings in --convert-vector-to-xegpu
would crash when the source memref had a non-integer/float element type
(e.g. memref<?xvector<4xf32>>).

The crash occurred inside createNdDescriptor() when computing the byte
offset for dynamic memrefs: srcTy.getElementTypeBitWidth() internally
calls getIntOrFloatBitWidth(), which asserts on non-scalar types such as
vector<4xf32>.

Fix by adding a check for the memref's element type in
storeLoadPreconditions(). If the element type is not an integer or float,
the pattern returns notifyMatchFailure() instead of proceeding and crashing.

Because the check lives in the shared storeLoadPreconditions() helper,
the same guard also applies to TransferReadLowering and
TransferWriteLowering, which call that helper and can hit the same path.

Fixes #181463
---
 .../VectorToXeGPU/VectorToXeGPU.cpp           | 23 +++++++++++++++----
 .../VectorToXeGPU/store-to-xegpu.mlir         | 16 +++++++++++++
 2 files changed, 34 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
index c81bb4b455b98..105f2916a26dc 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -48,7 +48,8 @@ static bool isZeroConstant(Value val) {
 }
 
 static LogicalResult storeLoadPreconditions(PatternRewriter &rewriter,
-                                            Operation *op, VectorType vecTy) {
+                                            Operation *op, VectorType vecTy,
+                                            MemRefType memTy) {
   // Validate only vector as the basic vector store and load ops guarantee
   // XeGPU-compatible memref source.
   unsigned vecRank = vecTy.getRank();
@@ -59,6 +60,14 @@ static LogicalResult storeLoadPreconditions(PatternRewriter &rewriter,
     return rewriter.notifyMatchFailure(
         op, "Expected scalar type with known bitwidth");
 
+  // XeGPU requires the memref to have a scalar integer or float element type.
+  // Memrefs with vector element types (e.g. memref<?xvector<4xf32>>) are not
+  // supported because createNdDescriptor computes byte offsets using
+  // getElementTypeBitWidth(), which asserts on non-integer/float types.
+  if (!memTy.getElementType().isIntOrFloat())
+    return rewriter.notifyMatchFailure(
+        op, "Unsupported memref element type: expected integer or float");
+
   return success();
 }
 
@@ -556,7 +565,8 @@ struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
       return lowerToScatteredLoadOp(readOp, rewriter);
 
     // Perform common data transfer checks.
-    if (failed(storeLoadPreconditions(rewriter, readOp, vecTy)))
+    auto readMemTy = cast<MemRefType>(readOp.getShapedType());
+    if (failed(storeLoadPreconditions(rewriter, readOp, vecTy, readMemTy)))
       return failure();
 
     bool isOutOfBounds = readOp.hasOutOfBoundsDim();
@@ -629,7 +639,8 @@ struct TransferWriteLowering
 
     // Perform common data transfer checks.
     VectorType vecTy = writeOp.getVectorType();
-    if (failed(storeLoadPreconditions(rewriter, writeOp, vecTy)))
+    auto writeMemTy = cast<MemRefType>(writeOp.getShapedType());
+    if (failed(storeLoadPreconditions(rewriter, writeOp, vecTy, writeMemTy)))
       return failure();
 
     AffineMap map = writeOp.getPermutationMap();
@@ -735,7 +746,8 @@ struct LoadLowering : public OpRewritePattern<vector::LoadOp> {
     Location loc = loadOp.getLoc();
 
     VectorType vecTy = loadOp.getResult().getType();
-    if (failed(storeLoadPreconditions(rewriter, loadOp, vecTy)))
+    auto memTy = cast<MemRefType>(loadOp.getBase().getType());
+    if (failed(storeLoadPreconditions(rewriter, loadOp, vecTy, memTy)))
       return failure();
 
     // Boundary check is available only for block instructions.
@@ -774,7 +786,8 @@ struct StoreLowering : public OpRewritePattern<vector::StoreOp> {
 
     TypedValue<VectorType> vector = storeOp.getValueToStore();
     VectorType vecTy = vector.getType();
-    if (failed(storeLoadPreconditions(rewriter, storeOp, vecTy)))
+    auto storeMemTy = cast<MemRefType>(storeOp.getBase().getType());
+    if (failed(storeLoadPreconditions(rewriter, storeOp, vecTy, storeMemTy)))
       return failure();
 
     // Boundary check is available only for block instructions.
diff --git a/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir
index 3c11313d05536..8ff2e6ee7d13c 100644
--- a/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir
@@ -115,3 +115,19 @@ func.func @no_store_zero_dim_vector(%vec: vector<f32>,
 
 // CHECK-LABEL: @no_store_zero_dim_vector(
 // CHECK:       vector.store
+
+// -----
+
+// Regression test for https://github.com/llvm/llvm-project/issues/181463:
+// vector.store with a memref whose element type is a vector (e.g.
+// memref<?xvector<4xf32>>) must not crash. The pass used to call
+// getElementTypeBitWidth() on the vector element type which asserts on
+// non-integer/float types; now it bails out gracefully instead.
+
+// CHECK-LABEL: @no_store_vec_element_memref(
+// CHECK:       vector.store
+func.func @no_store_vec_element_memref(%vec: vector<4xf32>,
+    %source: memref<?xvector<4xf32>>, %offset: index) {
+  vector.store %vec, %source[%offset] : memref<?xvector<4xf32>>, vector<4xf32>
+  return
+}



More information about the Mlir-commits mailing list