[Mlir-commits] [mlir] bbd5b1d - [mlir][VectorToXeGPU] Fix crash on memref with non-scalar element type (#183905)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Mar 3 03:33:09 PST 2026
Author: Mehdi Amini
Date: 2026-03-03T11:33:03Z
New Revision: bbd5b1d3bd073c239b36359932c3049b0d5c83bd
URL: https://github.com/llvm/llvm-project/commit/bbd5b1d3bd073c239b36359932c3049b0d5c83bd
DIFF: https://github.com/llvm/llvm-project/commit/bbd5b1d3bd073c239b36359932c3049b0d5c83bd.diff
LOG: [mlir][VectorToXeGPU] Fix crash on memref with non-scalar element type (#183905)
The vector.store and vector.load lowerings in --convert-vector-to-xegpu
would crash when the source memref had an element type that is neither an
integer nor a float (e.g. memref<?xvector<4xf32>>).
The crash occurred inside createNdDescriptor() when computing the byte
offset for dynamic memrefs: srcTy.getElementTypeBitWidth() internally
calls getIntOrFloatBitWidth(), which asserts on any type that is not an
integer or float, such as vector<4xf32>.
Fix by adding a check for the memref's element type in
storeLoadPreconditions(). If the element type is neither an integer nor a
float, the pattern now returns notifyMatchFailure() instead of proceeding
and crashing.
The same guard is applied to TransferReadLowering and
TransferWriteLowering, which share the same helper and can hit the same
path.
Fixes #181463
Added:
Modified:
mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
index c81bb4b455b98..eb45e323ef849 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -48,7 +48,8 @@ static bool isZeroConstant(Value val) {
}
static LogicalResult storeLoadPreconditions(PatternRewriter &rewriter,
- Operation *op, VectorType vecTy) {
+ Operation *op, VectorType vecTy,
+ MemRefType memTy) {
// Validate only vector as the basic vector store and load ops guarantee
// XeGPU-compatible memref source.
unsigned vecRank = vecTy.getRank();
@@ -59,6 +60,14 @@ static LogicalResult storeLoadPreconditions(PatternRewriter &rewriter,
return rewriter.notifyMatchFailure(
op, "Expected scalar type with known bitwidth");
+ // XeGPU requires the memref to have a scalar integer or float element type.
+ // Memrefs with vector element types (e.g. memref<?xvector<4xf32>>) are not
+ // supported because createNdDescriptor computes byte offsets using
+ // getElementTypeBitWidth(), which asserts on non-integer/float types.
+ if (!memTy.getElementType().isIntOrFloat())
+ return rewriter.notifyMatchFailure(
+ op, "Unsupported memref element type: expected integer or float");
+
return success();
}
@@ -556,7 +565,8 @@ struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
return lowerToScatteredLoadOp(readOp, rewriter);
// Perform common data transfer checks.
- if (failed(storeLoadPreconditions(rewriter, readOp, vecTy)))
+ auto readMemTy = cast<MemRefType>(readOp.getShapedType());
+ if (failed(storeLoadPreconditions(rewriter, readOp, vecTy, readMemTy)))
return failure();
bool isOutOfBounds = readOp.hasOutOfBoundsDim();
@@ -629,7 +639,8 @@ struct TransferWriteLowering
// Perform common data transfer checks.
VectorType vecTy = writeOp.getVectorType();
- if (failed(storeLoadPreconditions(rewriter, writeOp, vecTy)))
+ auto writeMemTy = cast<MemRefType>(writeOp.getShapedType());
+ if (failed(storeLoadPreconditions(rewriter, writeOp, vecTy, writeMemTy)))
return failure();
AffineMap map = writeOp.getPermutationMap();
@@ -735,7 +746,8 @@ struct LoadLowering : public OpRewritePattern<vector::LoadOp> {
Location loc = loadOp.getLoc();
VectorType vecTy = loadOp.getResult().getType();
- if (failed(storeLoadPreconditions(rewriter, loadOp, vecTy)))
+ MemRefType memTy = loadOp.getBase().getType();
+ if (failed(storeLoadPreconditions(rewriter, loadOp, vecTy, memTy)))
return failure();
// Boundary check is available only for block instructions.
@@ -774,7 +786,8 @@ struct StoreLowering : public OpRewritePattern<vector::StoreOp> {
TypedValue<VectorType> vector = storeOp.getValueToStore();
VectorType vecTy = vector.getType();
- if (failed(storeLoadPreconditions(rewriter, storeOp, vecTy)))
+ MemRefType memTy = storeOp.getBase().getType();
+ if (failed(storeLoadPreconditions(rewriter, storeOp, vecTy, memTy)))
return failure();
// Boundary check is available only for block instructions.
diff --git a/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir
index 3c11313d05536..8ff2e6ee7d13c 100644
--- a/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir
@@ -115,3 +115,19 @@ func.func @no_store_zero_dim_vector(%vec: vector<f32>,
// CHECK-LABEL: @no_store_zero_dim_vector(
// CHECK: vector.store
+
+// -----
+
+// Regression test for https://github.com/llvm/llvm-project/issues/181463:
+// vector.store with a memref whose element type is a vector (e.g.
+// memref<?xvector<4xf32>>) must not crash. The pass used to call
+// getElementTypeBitWidth() on the vector element type which asserts on
+// non-integer/float types; now it bails out gracefully instead.
+
+// CHECK-LABEL: @no_store_vec_element_memref(
+// CHECK: vector.store
+func.func @no_store_vec_element_memref(%vec: vector<4xf32>,
+ %source: memref<?xvector<4xf32>>, %offset: index) {
+ vector.store %vec, %source[%offset] : memref<?xvector<4xf32>>, vector<4xf32>
+ return
+}
More information about the Mlir-commits
mailing list