[Mlir-commits] [mlir] [MLIR][XeGPU][XeVM] Update single element vector type handling. (PR #178558)
Sang Ik Lee
llvmlistbot at llvm.org
Thu Jan 29 10:29:32 PST 2026
https://github.com/silee2 updated https://github.com/llvm/llvm-project/pull/178558
>From c11157c24bee67fd98e0a3182ef530b628799aa1 Mon Sep 17 00:00:00 2001
From: "Lee, Sang Ik" <sang.ik.lee at intel.com>
Date: Thu, 29 Jan 2026 01:29:34 +0000
Subject: [PATCH 1/5] [MLIR][XeGPU][XeVM] Only convert 1D single element vector to scalar.
---
mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 8a06271eadd84..8efbb0702f0d3 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -1190,7 +1190,7 @@ struct ConvertXeGPUToXeVMPass
return {};
auto input = inputs.front();
if (auto vecTy = dyn_cast<VectorType>(input.getType())) {
- if (vecTy.getNumElements() == 1) {
+ if (vecTy.getRank() == 1 && vecTy.getNumElements() == 1) {
// If the vector has a single element, return the element type.
Value cast =
vector::ExtractOp::create(builder, loc, input, 0).getResult();
>From 185c0ff2cb707b271f57c06942f20cc2ec8c3fd5 Mon Sep 17 00:00:00 2001
From: "Lee, Sang Ik" <sang.ik.lee at intel.com>
Date: Thu, 29 Jan 2026 02:13:41 +0000
Subject: [PATCH 2/5] Handle n-D single element vectors.
---
mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp | 16 ++++++++++++----
1 file changed, 12 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 8efbb0702f0d3..b983e421ff83e 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -1180,7 +1180,7 @@ struct ConvertXeGPUToXeVMPass
};
// Materialization to convert
- // - single element 1D vector to scalar
+ // - single element vector to scalar
// - bitcast vector of same rank
// - shape vector of different rank but same element type
auto vectorMaterializationCast = [](OpBuilder &builder, Type type,
@@ -1190,10 +1190,18 @@ struct ConvertXeGPUToXeVMPass
return {};
auto input = inputs.front();
if (auto vecTy = dyn_cast<VectorType>(input.getType())) {
- if (vecTy.getRank() == 1 && vecTy.getNumElements() == 1) {
+ if (vecTy.getNumElements() == 1) {
// If the vector has a single element, return the element type.
- Value cast =
- vector::ExtractOp::create(builder, loc, input, 0).getResult();
+ auto rank = vecTy.getRank();
+ Value cast;
+ if (rank > 1) {
+ cast = vector::ExtractOp::create(builder, loc, input,
+ SmallVector<int64_t>(rank, 0))
+ .getResult();
+ } else {
+ cast =
+ vector::ExtractOp::create(builder, loc, input, 0).getResult();
+ }
if (vecTy.getElementType() == builder.getIndexType())
cast = arith::IndexCastUIOp::create(builder, loc, type, cast)
.getResult();
>From bd03a06e81ae9aab6b0d5bade865fcd509dee0f5 Mon Sep 17 00:00:00 2001
From: "Lee, Sang Ik" <sang.ik.lee at intel.com>
Date: Thu, 29 Jan 2026 03:28:25 +0000
Subject: [PATCH 3/5] Align single element vector type conversion rule and materialization method.
---
mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp | 12 ++++++------
1 file changed, 6 insertions(+), 6 deletions(-)
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index b983e421ff83e..db21dab150bcf 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -1057,7 +1057,7 @@ struct ConvertXeGPUToXeVMPass
if (llvm::isa<IndexType>(elemType))
elemType = IntegerType::get(&getContext(), 64);
// If the vector is a scalar or has a single element, return the element
- if (rank < 1 || type.getNumElements() == 1)
+ if (rank == 0 || type.getNumElements() == 1)
return elemType;
// Otherwise, convert the vector to a flat vector type.
int64_t sum = llvm::product_of(type.getShape());
@@ -1190,17 +1190,17 @@ struct ConvertXeGPUToXeVMPass
return {};
auto input = inputs.front();
if (auto vecTy = dyn_cast<VectorType>(input.getType())) {
- if (vecTy.getNumElements() == 1) {
+ if (type.isIntOrIndexOrFloat()) {
// If the vector has a single element, return the element type.
auto rank = vecTy.getRank();
Value cast;
- if (rank > 1) {
+ if (rank == 0) {
+ cast =
+ vector::ExtractOp::create(builder, loc, input, {}).getResult();
+ } else {
cast = vector::ExtractOp::create(builder, loc, input,
SmallVector<int64_t>(rank, 0))
.getResult();
- } else {
- cast =
- vector::ExtractOp::create(builder, loc, input, 0).getResult();
}
if (vecTy.getElementType() == builder.getIndexType())
cast = arith::IndexCastUIOp::create(builder, loc, type, cast)
>From b3c35f507f929ee40bc842788cf1d0009c191aab Mon Sep 17 00:00:00 2001
From: "Lee, Sang Ik" <sang.ik.lee at intel.com>
Date: Thu, 29 Jan 2026 03:35:00 +0000
Subject: [PATCH 4/5] Update comments.
---
mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp | 11 ++++++-----
1 file changed, 6 insertions(+), 5 deletions(-)
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index db21dab150bcf..a4c987573223f 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -1056,7 +1056,7 @@ struct ConvertXeGPUToXeVMPass
// If the element type is index, convert it to i64.
if (llvm::isa<IndexType>(elemType))
elemType = IntegerType::get(&getContext(), 64);
- // If the vector is a scalar or has a single element, return the element
+ // If the vector rank is 0 or has a single element, return the element
if (rank == 0 || type.getNumElements() == 1)
return elemType;
// Otherwise, convert the vector to a flat vector type.
@@ -1191,7 +1191,8 @@ struct ConvertXeGPUToXeVMPass
auto input = inputs.front();
if (auto vecTy = dyn_cast<VectorType>(input.getType())) {
if (type.isIntOrIndexOrFloat()) {
- // If the vector has a single element, return the element type.
+ // If the vector rank is 0 or has a single element,
+ // extract scalar of target type.
auto rank = vecTy.getRank();
Value cast;
if (rank == 0) {
@@ -1233,10 +1234,10 @@ struct ConvertXeGPUToXeVMPass
return {};
auto input = inputs.front();
if (input.getType().isIntOrIndexOrFloat()) {
- // If the input is a scalar, and the target type is a vector of single
- // element, create a single element vector by broadcasting.
+ // If the input is a scalar, and the target type is a vector of rank 0
+ // or single element, broadcast scalar to target type.
if (auto vecTy = dyn_cast<VectorType>(type)) {
- if (vecTy.getNumElements() == 1) {
+ if (vecTy.getRank() == 0 || vecTy.getNumElements() == 1) {
return vector::BroadcastOp::create(builder, loc, vecTy, input)
.getResult();
}
>From dc6c997993d584a526b3c1050653616188a726e6 Mon Sep 17 00:00:00 2001
From: "Lee, Sang Ik" <sang.ik.lee at intel.com>
Date: Thu, 29 Jan 2026 18:07:36 +0000
Subject: [PATCH 5/5] Address reviewer comment on element type check.
---
.../Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp | 22 +++++++++++++------
1 file changed, 15 insertions(+), 7 deletions(-)
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index a4c987573223f..02b736706c0db 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -1190,7 +1190,9 @@ struct ConvertXeGPUToXeVMPass
return {};
auto input = inputs.front();
if (auto vecTy = dyn_cast<VectorType>(input.getType())) {
- if (type.isIntOrIndexOrFloat()) {
+ if (type == vecTy.getElementType() ||
+ ((vecTy.getElementType() == builder.getIndexType()) &&
+ type.isInteger())) {
// If the vector rank is 0 or has a single element,
// extract scalar of target type.
auto rank = vecTy.getRank();
@@ -1203,7 +1205,7 @@ struct ConvertXeGPUToXeVMPass
SmallVector<int64_t>(rank, 0))
.getResult();
}
- if (vecTy.getElementType() == builder.getIndexType())
+ if (type != vecTy.getElementType())
cast = arith::IndexCastUIOp::create(builder, loc, type, cast)
.getResult();
return cast;
@@ -1233,13 +1235,19 @@ struct ConvertXeGPUToXeVMPass
if (inputs.size() != 1)
return {};
auto input = inputs.front();
- if (input.getType().isIntOrIndexOrFloat()) {
- // If the input is a scalar, and the target type is a vector of rank 0
- // or single element, broadcast scalar to target type.
- if (auto vecTy = dyn_cast<VectorType>(type)) {
- if (vecTy.getRank() == 0 || vecTy.getNumElements() == 1) {
+ // If the target type is a vector of rank 0 or single element vector
+ // of element type matching input type, broadcast input to target type.
+ if (auto vecTy = dyn_cast<VectorType>(type)) {
+ if (vecTy.getRank() == 0 || vecTy.getNumElements() == 1) {
+ if (input.getType() == vecTy.getElementType()) {
return vector::BroadcastOp::create(builder, loc, vecTy, input)
.getResult();
+ } else if (vecTy.getElementType() == builder.getIndexType()) {
+ Value cast = arith::IndexCastUIOp::create(
+ builder, loc, builder.getIndexType(), input)
+ .getResult();
+ return vector::BroadcastOp::create(builder, loc, vecTy, cast)
+ .getResult();
}
}
}
More information about the Mlir-commits mailing list