[Mlir-commits] [mlir] [MLIR][XeGPU][XeVM] Update single element vector type handling. (PR #178558)
Sang Ik Lee
llvmlistbot at llvm.org
Thu Jan 29 11:49:00 PST 2026
https://github.com/silee2 updated https://github.com/llvm/llvm-project/pull/178558
>From c11157c24bee67fd98e0a3182ef530b628799aa1 Mon Sep 17 00:00:00 2001
From: "Lee, Sang Ik" <sang.ik.lee at intel.com>
Date: Thu, 29 Jan 2026 01:29:34 +0000
Subject: [PATCH 1/6] [MLIR][XeGPU][XeVM] Only convert 1D single element vector
to scalar.
---
mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 8a06271eadd84..8efbb0702f0d3 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -1190,7 +1190,7 @@ struct ConvertXeGPUToXeVMPass
return {};
auto input = inputs.front();
if (auto vecTy = dyn_cast<VectorType>(input.getType())) {
- if (vecTy.getNumElements() == 1) {
+ if (vecTy.getRank() == 1 && vecTy.getNumElements() == 1) {
// If the vector has a single element, return the element type.
Value cast =
vector::ExtractOp::create(builder, loc, input, 0).getResult();
>From 185c0ff2cb707b271f57c06942f20cc2ec8c3fd5 Mon Sep 17 00:00:00 2001
From: "Lee, Sang Ik" <sang.ik.lee at intel.com>
Date: Thu, 29 Jan 2026 02:13:41 +0000
Subject: [PATCH 2/6] Handle n-D single element vectors.
---
mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp | 16 ++++++++++++----
1 file changed, 12 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 8efbb0702f0d3..b983e421ff83e 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -1180,7 +1180,7 @@ struct ConvertXeGPUToXeVMPass
};
// Materialization to convert
- // - single element 1D vector to scalar
+ // - single element vector to scalar
// - bitcast vector of same rank
// - shape vector of different rank but same element type
auto vectorMaterializationCast = [](OpBuilder &builder, Type type,
@@ -1190,10 +1190,18 @@ struct ConvertXeGPUToXeVMPass
return {};
auto input = inputs.front();
if (auto vecTy = dyn_cast<VectorType>(input.getType())) {
- if (vecTy.getRank() == 1 && vecTy.getNumElements() == 1) {
+ if (vecTy.getNumElements() == 1) {
// If the vector has a single element, return the element type.
- Value cast =
- vector::ExtractOp::create(builder, loc, input, 0).getResult();
+ auto rank = vecTy.getRank();
+ Value cast;
+ if (rank > 1) {
+ cast = vector::ExtractOp::create(builder, loc, input,
+ SmallVector<int64_t>(rank, 0))
+ .getResult();
+ } else {
+ cast =
+ vector::ExtractOp::create(builder, loc, input, 0).getResult();
+ }
if (vecTy.getElementType() == builder.getIndexType())
cast = arith::IndexCastUIOp::create(builder, loc, type, cast)
.getResult();
>From bd03a06e81ae9aab6b0d5bade865fcd509dee0f5 Mon Sep 17 00:00:00 2001
From: "Lee, Sang Ik" <sang.ik.lee at intel.com>
Date: Thu, 29 Jan 2026 03:28:25 +0000
Subject: [PATCH 3/6] Align single element vector type conversion rule and
materialization method.
---
mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp | 12 ++++++------
1 file changed, 6 insertions(+), 6 deletions(-)
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index b983e421ff83e..db21dab150bcf 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -1057,7 +1057,7 @@ struct ConvertXeGPUToXeVMPass
if (llvm::isa<IndexType>(elemType))
elemType = IntegerType::get(&getContext(), 64);
// If the vector is a scalar or has a single element, return the element
- if (rank < 1 || type.getNumElements() == 1)
+ if (rank == 0 || type.getNumElements() == 1)
return elemType;
// Otherwise, convert the vector to a flat vector type.
int64_t sum = llvm::product_of(type.getShape());
@@ -1190,17 +1190,17 @@ struct ConvertXeGPUToXeVMPass
return {};
auto input = inputs.front();
if (auto vecTy = dyn_cast<VectorType>(input.getType())) {
- if (vecTy.getNumElements() == 1) {
+ if (type.isIntOrIndexOrFloat()) {
// If the vector has a single element, return the element type.
auto rank = vecTy.getRank();
Value cast;
- if (rank > 1) {
+ if (rank == 0) {
+ cast =
+ vector::ExtractOp::create(builder, loc, input, {}).getResult();
+ } else {
cast = vector::ExtractOp::create(builder, loc, input,
SmallVector<int64_t>(rank, 0))
.getResult();
- } else {
- cast =
- vector::ExtractOp::create(builder, loc, input, 0).getResult();
}
if (vecTy.getElementType() == builder.getIndexType())
cast = arith::IndexCastUIOp::create(builder, loc, type, cast)
>From b3c35f507f929ee40bc842788cf1d0009c191aab Mon Sep 17 00:00:00 2001
From: "Lee, Sang Ik" <sang.ik.lee at intel.com>
Date: Thu, 29 Jan 2026 03:35:00 +0000
Subject: [PATCH 4/6] Update comments.
---
mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp | 11 ++++++-----
1 file changed, 6 insertions(+), 5 deletions(-)
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index db21dab150bcf..a4c987573223f 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -1056,7 +1056,7 @@ struct ConvertXeGPUToXeVMPass
// If the element type is index, convert it to i64.
if (llvm::isa<IndexType>(elemType))
elemType = IntegerType::get(&getContext(), 64);
- // If the vector is a scalar or has a single element, return the element
+ // If the vector has rank 0 or a single element, return its element type.
if (rank == 0 || type.getNumElements() == 1)
return elemType;
// Otherwise, convert the vector to a flat vector type.
@@ -1191,7 +1191,8 @@ struct ConvertXeGPUToXeVMPass
auto input = inputs.front();
if (auto vecTy = dyn_cast<VectorType>(input.getType())) {
if (type.isIntOrIndexOrFloat()) {
- // If the vector has a single element, return the element type.
+ // If the vector rank is 0 or has a single element,
+ // extract scalar of target type.
auto rank = vecTy.getRank();
Value cast;
if (rank == 0) {
@@ -1233,10 +1234,10 @@ struct ConvertXeGPUToXeVMPass
return {};
auto input = inputs.front();
if (input.getType().isIntOrIndexOrFloat()) {
- // If the input is a scalar, and the target type is a vector of single
- // element, create a single element vector by broadcasting.
+ // If the input is a scalar, and the target type is a vector of rank 0
+ // or single element, broadcast scalar to target type.
if (auto vecTy = dyn_cast<VectorType>(type)) {
- if (vecTy.getNumElements() == 1) {
+ if (vecTy.getRank() == 0 || vecTy.getNumElements() == 1) {
return vector::BroadcastOp::create(builder, loc, vecTy, input)
.getResult();
}
>From dc6c997993d584a526b3c1050653616188a726e6 Mon Sep 17 00:00:00 2001
From: "Lee, Sang Ik" <sang.ik.lee at intel.com>
Date: Thu, 29 Jan 2026 18:07:36 +0000
Subject: [PATCH 5/6] Address reviewer comment on element type check.
---
.../Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp | 22 +++++++++++++------
1 file changed, 15 insertions(+), 7 deletions(-)
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index a4c987573223f..02b736706c0db 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -1190,7 +1190,9 @@ struct ConvertXeGPUToXeVMPass
return {};
auto input = inputs.front();
if (auto vecTy = dyn_cast<VectorType>(input.getType())) {
- if (type.isIntOrIndexOrFloat()) {
+ if (type == vecTy.getElementType() ||
+ ((vecTy.getElementType() == builder.getIndexType()) &&
+ type.isInteger())) {
// If the vector rank is 0 or has a single element,
// extract scalar of target type.
auto rank = vecTy.getRank();
@@ -1203,7 +1205,7 @@ struct ConvertXeGPUToXeVMPass
SmallVector<int64_t>(rank, 0))
.getResult();
}
- if (vecTy.getElementType() == builder.getIndexType())
+ if (type != vecTy.getElementType())
cast = arith::IndexCastUIOp::create(builder, loc, type, cast)
.getResult();
return cast;
@@ -1233,13 +1235,19 @@ struct ConvertXeGPUToXeVMPass
if (inputs.size() != 1)
return {};
auto input = inputs.front();
- if (input.getType().isIntOrIndexOrFloat()) {
- // If the input is a scalar, and the target type is a vector of rank 0
- // or single element, broadcast scalar to target type.
- if (auto vecTy = dyn_cast<VectorType>(type)) {
- if (vecTy.getRank() == 0 || vecTy.getNumElements() == 1) {
+ // If the target is a rank-0 or single element vector whose element type
+ // matches the input (or is index, after index_castui), broadcast to it.
+ if (auto vecTy = dyn_cast<VectorType>(type)) {
+ if (vecTy.getRank() == 0 || vecTy.getNumElements() == 1) {
+ if (input.getType() == vecTy.getElementType()) {
return vector::BroadcastOp::create(builder, loc, vecTy, input)
.getResult();
+ } else if (vecTy.getElementType() == builder.getIndexType()) {
+ Value cast = arith::IndexCastUIOp::create(
+ builder, loc, builder.getIndexType(), input)
+ .getResult();
+ return vector::BroadcastOp::create(builder, loc, vecTy, cast)
+ .getResult();
}
}
}
>From 85199a626765ead20f5a68aeb7c73bfd9752ad26 Mon Sep 17 00:00:00 2001
From: "Lee, Sang Ik" <sang.ik.lee at intel.com>
Date: Thu, 29 Jan 2026 19:48:09 +0000
Subject: [PATCH 6/6] Split materialization cast callback functions and use
better names.
---
.../Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp | 93 ++++++++++++-------
1 file changed, 61 insertions(+), 32 deletions(-)
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 02b736706c0db..6df209438447b 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -1085,9 +1085,12 @@ struct ConvertXeGPUToXeVMPass
// add materialization casts to handle them.
// Materialization to convert memref to i64 or i32 depending on global/SLM
- auto memrefMaterializationCast = [](OpBuilder &builder, Type type,
- ValueRange inputs,
- Location loc) -> Value {
+ // Applies only to target materialization.
+ // Note: int type to memref materialization is not required as xegpu ops
+ // currently do not produce memrefs as result.
+ auto memrefToIntMaterializationCast = [](OpBuilder &builder, Type type,
+ ValueRange inputs,
+ Location loc) -> Value {
if (inputs.size() != 1)
return {};
auto input = inputs.front();
@@ -1146,9 +1149,12 @@ struct ConvertXeGPUToXeVMPass
};
// Materialization to convert ui64 to i64
- auto ui64MaterializationCast = [](OpBuilder &builder, Type type,
- ValueRange inputs,
- Location loc) -> Value {
+ // Applies only to target materialization.
+ // Note: i64 to ui64 materialization is not required as xegpu ops
+ // currently do not produce ui64 as result.
+ auto ui64ToI64MaterializationCast = [](OpBuilder &builder, Type type,
+ ValueRange inputs,
+ Location loc) -> Value {
if (inputs.size() != 1)
return {};
auto input = inputs.front();
@@ -1163,9 +1169,12 @@ struct ConvertXeGPUToXeVMPass
};
// Materialization to convert ui32 to i32
- auto ui32MaterializationCast = [](OpBuilder &builder, Type type,
- ValueRange inputs,
- Location loc) -> Value {
+ // Applies only to target materialization.
+ // Note: i32 to ui32 materialization is not required as xegpu ops
+ // currently do not produce ui32 as result.
+ auto ui32ToI32MaterializationCast = [](OpBuilder &builder, Type type,
+ ValueRange inputs,
+ Location loc) -> Value {
if (inputs.size() != 1)
return {};
auto input = inputs.front();
@@ -1180,12 +1189,39 @@ struct ConvertXeGPUToXeVMPass
};
// Materialization to convert
- // - single element vector to scalar
// - bitcast vector of same rank
// - shape vector of different rank but same element type
- auto vectorMaterializationCast = [](OpBuilder &builder, Type type,
- ValueRange inputs,
- Location loc) -> Value {
+ // Applies to both source and target materialization.
+ auto vectorToVectorMaterializationCast = [](OpBuilder &builder, Type type,
+ ValueRange inputs,
+ Location loc) -> Value {
+ if (inputs.size() != 1)
+ return {};
+ auto input = inputs.front();
+ if (auto vecTy = dyn_cast<VectorType>(input.getType())) {
+ if (auto targetVecTy = dyn_cast<VectorType>(type)) {
+ // If the target type is a vector of same rank,
+ // bitcast to the target type.
+ if (targetVecTy.getRank() == vecTy.getRank())
+ return vector::BitCastOp::create(builder, loc, targetVecTy, input)
+ .getResult();
+ else if (targetVecTy.getElementType() == vecTy.getElementType()) {
+ // If the target type is a vector of different rank but same element
+ // type, reshape to the target type.
+ return vector::ShapeCastOp::create(builder, loc, targetVecTy, input)
+ .getResult();
+ }
+ }
+ }
+ return {};
+ };
+
+ // Materialization to convert
+ // - single element vector to a scalar of the vector's element type
+ // Applies only to target materialization.
+ auto vectorToSingleElementMaterializationCast =
+ [](OpBuilder &builder, Type type, ValueRange inputs,
+ Location loc) -> Value {
if (inputs.size() != 1)
return {};
auto input = inputs.front();
@@ -1209,27 +1245,18 @@ struct ConvertXeGPUToXeVMPass
cast = arith::IndexCastUIOp::create(builder, loc, type, cast)
.getResult();
return cast;
- } else if (auto targetVecTy = dyn_cast<VectorType>(type)) {
- // If the target type is a vector of same rank,
- // bitcast to the target type.
- if (targetVecTy.getRank() == vecTy.getRank())
- return vector::BitCastOp::create(builder, loc, targetVecTy, input)
- .getResult();
- else if (targetVecTy.getElementType() == vecTy.getElementType()) {
- // If the target type is a vector of different rank but same element
- // type, reshape to the target type.
- return vector::ShapeCastOp::create(builder, loc, targetVecTy, input)
- .getResult();
- }
}
}
return {};
};
+ // Materialization to convert
+ // - scalar of the vector's element type to a single element vector
// If result type of original op is single element vector and lowered type
// is scalar. This materialization cast creates a single element vector by
// broadcasting the scalar value.
- auto singleElementVectorMaterializationCast =
+ // Applies only to source materialization.
+ auto singleElementToVectorMaterializationCast =
[](OpBuilder &builder, Type type, ValueRange inputs,
Location loc) -> Value {
if (inputs.size() != 1)
@@ -1254,12 +1281,14 @@ struct ConvertXeGPUToXeVMPass
return {};
};
typeConverter.addSourceMaterialization(
- singleElementVectorMaterializationCast);
- typeConverter.addSourceMaterialization(vectorMaterializationCast);
- typeConverter.addTargetMaterialization(memrefMaterializationCast);
- typeConverter.addTargetMaterialization(ui32MaterializationCast);
- typeConverter.addTargetMaterialization(ui64MaterializationCast);
- typeConverter.addTargetMaterialization(vectorMaterializationCast);
+ singleElementToVectorMaterializationCast);
+ typeConverter.addSourceMaterialization(vectorToVectorMaterializationCast);
+ typeConverter.addTargetMaterialization(memrefToIntMaterializationCast);
+ typeConverter.addTargetMaterialization(ui32ToI32MaterializationCast);
+ typeConverter.addTargetMaterialization(ui64ToI64MaterializationCast);
+ typeConverter.addTargetMaterialization(
+ vectorToSingleElementMaterializationCast);
+ typeConverter.addTargetMaterialization(vectorToVectorMaterializationCast);
ConversionTarget target(getContext());
target.addLegalDialect<xevm::XeVMDialect, LLVM::LLVMDialect,
vector::VectorDialect, arith::ArithDialect,
More information about the Mlir-commits
mailing list