[Mlir-commits] [mlir] [MLIR][XeGPU][XeVM] Update single element vector type handling. (PR #178558)

Sang Ik Lee llvmlistbot at llvm.org
Thu Jan 29 10:29:32 PST 2026


https://github.com/silee2 updated https://github.com/llvm/llvm-project/pull/178558

>From c11157c24bee67fd98e0a3182ef530b628799aa1 Mon Sep 17 00:00:00 2001
From: "Lee, Sang Ik" <sang.ik.lee at intel.com>
Date: Thu, 29 Jan 2026 01:29:34 +0000
Subject: [PATCH 1/5] [MLIR][XeGPU][XeVM] Only convert 1D single element vector
 to scalar.

---
 mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 8a06271eadd84..8efbb0702f0d3 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -1190,7 +1190,7 @@ struct ConvertXeGPUToXeVMPass
         return {};
       auto input = inputs.front();
       if (auto vecTy = dyn_cast<VectorType>(input.getType())) {
-        if (vecTy.getNumElements() == 1) {
+        if (vecTy.getRank() == 1 && vecTy.getNumElements() == 1) {
           // If the vector has a single element, return the element type.
           Value cast =
               vector::ExtractOp::create(builder, loc, input, 0).getResult();

>From 185c0ff2cb707b271f57c06942f20cc2ec8c3fd5 Mon Sep 17 00:00:00 2001
From: "Lee, Sang Ik" <sang.ik.lee at intel.com>
Date: Thu, 29 Jan 2026 02:13:41 +0000
Subject: [PATCH 2/5] Handle n-D single element vectors.

---
 mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp | 16 ++++++++++++----
 1 file changed, 12 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 8efbb0702f0d3..b983e421ff83e 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -1180,7 +1180,7 @@ struct ConvertXeGPUToXeVMPass
     };
 
     // Materialization to convert
-    //   - single element 1D vector to scalar
+    //   - single element vector to scalar
     //   - bitcast vector of same rank
     //   - shape vector of different rank but same element type
     auto vectorMaterializationCast = [](OpBuilder &builder, Type type,
@@ -1190,10 +1190,18 @@ struct ConvertXeGPUToXeVMPass
         return {};
       auto input = inputs.front();
       if (auto vecTy = dyn_cast<VectorType>(input.getType())) {
-        if (vecTy.getRank() == 1 && vecTy.getNumElements() == 1) {
+        if (vecTy.getNumElements() == 1) {
           // If the vector has a single element, return the element type.
-          Value cast =
-              vector::ExtractOp::create(builder, loc, input, 0).getResult();
+          auto rank = vecTy.getRank();
+          Value cast;
+          if (rank > 1) {
+            cast = vector::ExtractOp::create(builder, loc, input,
+                                             SmallVector<int64_t>(rank, 0))
+                       .getResult();
+          } else {
+            cast =
+                vector::ExtractOp::create(builder, loc, input, 0).getResult();
+          }
           if (vecTy.getElementType() == builder.getIndexType())
             cast = arith::IndexCastUIOp::create(builder, loc, type, cast)
                        .getResult();

>From bd03a06e81ae9aab6b0d5bade865fcd509dee0f5 Mon Sep 17 00:00:00 2001
From: "Lee, Sang Ik" <sang.ik.lee at intel.com>
Date: Thu, 29 Jan 2026 03:28:25 +0000
Subject: [PATCH 3/5] Align single element vector type conversion rule and
 materialization method.

---
 mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index b983e421ff83e..db21dab150bcf 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -1057,7 +1057,7 @@ struct ConvertXeGPUToXeVMPass
       if (llvm::isa<IndexType>(elemType))
         elemType = IntegerType::get(&getContext(), 64);
       // If the vector is a scalar or has a single element, return the element
-      if (rank < 1 || type.getNumElements() == 1)
+      if (rank == 0 || type.getNumElements() == 1)
         return elemType;
       // Otherwise, convert the vector to a flat vector type.
       int64_t sum = llvm::product_of(type.getShape());
@@ -1190,17 +1190,17 @@ struct ConvertXeGPUToXeVMPass
         return {};
       auto input = inputs.front();
       if (auto vecTy = dyn_cast<VectorType>(input.getType())) {
-        if (vecTy.getNumElements() == 1) {
+        if (type.isIntOrIndexOrFloat()) {
           // If the vector has a single element, return the element type.
           auto rank = vecTy.getRank();
           Value cast;
-          if (rank > 1) {
+          if (rank == 0) {
+            cast =
+                vector::ExtractOp::create(builder, loc, input, {}).getResult();
+          } else {
             cast = vector::ExtractOp::create(builder, loc, input,
                                              SmallVector<int64_t>(rank, 0))
                        .getResult();
-          } else {
-            cast =
-                vector::ExtractOp::create(builder, loc, input, 0).getResult();
           }
           if (vecTy.getElementType() == builder.getIndexType())
             cast = arith::IndexCastUIOp::create(builder, loc, type, cast)

>From b3c35f507f929ee40bc842788cf1d0009c191aab Mon Sep 17 00:00:00 2001
From: "Lee, Sang Ik" <sang.ik.lee at intel.com>
Date: Thu, 29 Jan 2026 03:35:00 +0000
Subject: [PATCH 4/5] Update comments.

---
 mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index db21dab150bcf..a4c987573223f 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -1056,7 +1056,7 @@ struct ConvertXeGPUToXeVMPass
       // If the element type is index, convert it to i64.
       if (llvm::isa<IndexType>(elemType))
         elemType = IntegerType::get(&getContext(), 64);
-      // If the vector is a scalar or has a single element, return the element
+      // If the vector rank is 0 or has a single element, return the element
       if (rank == 0 || type.getNumElements() == 1)
         return elemType;
       // Otherwise, convert the vector to a flat vector type.
@@ -1191,7 +1191,8 @@ struct ConvertXeGPUToXeVMPass
       auto input = inputs.front();
       if (auto vecTy = dyn_cast<VectorType>(input.getType())) {
         if (type.isIntOrIndexOrFloat()) {
-          // If the vector has a single element, return the element type.
+          // If the vector rank is 0 or has a single element,
+          // extract scalar of target type.
           auto rank = vecTy.getRank();
           Value cast;
           if (rank == 0) {
@@ -1233,10 +1234,10 @@ struct ConvertXeGPUToXeVMPass
         return {};
       auto input = inputs.front();
       if (input.getType().isIntOrIndexOrFloat()) {
-        // If the input is a scalar, and the target type is a vector of single
-        // element, create a single element vector by broadcasting.
+        // If the input is a scalar, and the target type is a vector of rank 0
+        // or single element, broadcast scalar to target type.
         if (auto vecTy = dyn_cast<VectorType>(type)) {
-          if (vecTy.getNumElements() == 1) {
+          if (vecTy.getRank() == 0 || vecTy.getNumElements() == 1) {
             return vector::BroadcastOp::create(builder, loc, vecTy, input)
                 .getResult();
           }

>From dc6c997993d584a526b3c1050653616188a726e6 Mon Sep 17 00:00:00 2001
From: "Lee, Sang Ik" <sang.ik.lee at intel.com>
Date: Thu, 29 Jan 2026 18:07:36 +0000
Subject: [PATCH 5/5] Address reviewer comment on element type check.

---
 .../Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp    | 22 +++++++++++++------
 1 file changed, 15 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index a4c987573223f..02b736706c0db 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -1190,7 +1190,9 @@ struct ConvertXeGPUToXeVMPass
         return {};
       auto input = inputs.front();
       if (auto vecTy = dyn_cast<VectorType>(input.getType())) {
-        if (type.isIntOrIndexOrFloat()) {
+        if (type == vecTy.getElementType() ||
+            ((vecTy.getElementType() == builder.getIndexType()) &&
+             type.isInteger())) {
           // If the vector rank is 0 or has a single element,
           // extract scalar of target type.
           auto rank = vecTy.getRank();
@@ -1203,7 +1205,7 @@ struct ConvertXeGPUToXeVMPass
                                              SmallVector<int64_t>(rank, 0))
                        .getResult();
           }
-          if (vecTy.getElementType() == builder.getIndexType())
+          if (type != vecTy.getElementType())
             cast = arith::IndexCastUIOp::create(builder, loc, type, cast)
                        .getResult();
           return cast;
@@ -1233,13 +1235,19 @@ struct ConvertXeGPUToXeVMPass
       if (inputs.size() != 1)
         return {};
       auto input = inputs.front();
-      if (input.getType().isIntOrIndexOrFloat()) {
-        // If the input is a scalar, and the target type is a vector of rank 0
-        // or single element, broadcast scalar to target type.
-        if (auto vecTy = dyn_cast<VectorType>(type)) {
-          if (vecTy.getRank() == 0 || vecTy.getNumElements() == 1) {
+      // If the target is a rank-0 or single-element vector whose element type
+      // matches the input type (or is index), broadcast input to target type.
+      if (auto vecTy = dyn_cast<VectorType>(type)) {
+        if (vecTy.getRank() == 0 || vecTy.getNumElements() == 1) {
+          if (input.getType() == vecTy.getElementType()) {
             return vector::BroadcastOp::create(builder, loc, vecTy, input)
                 .getResult();
+          } else if (vecTy.getElementType() == builder.getIndexType()) {
+            Value cast = arith::IndexCastUIOp::create(
+                             builder, loc, builder.getIndexType(), input)
+                             .getResult();
+            return vector::BroadcastOp::create(builder, loc, vecTy, cast)
+                .getResult();
           }
         }
       }



More information about the Mlir-commits mailing list