[Mlir-commits] [mlir] [MLIR][XeGPU][XeVM] Update single element vector type handling. (PR #178558)

Sang Ik Lee llvmlistbot at llvm.org
Thu Jan 29 11:49:00 PST 2026


https://github.com/silee2 updated https://github.com/llvm/llvm-project/pull/178558

>From c11157c24bee67fd98e0a3182ef530b628799aa1 Mon Sep 17 00:00:00 2001
From: "Lee, Sang Ik" <sang.ik.lee at intel.com>
Date: Thu, 29 Jan 2026 01:29:34 +0000
Subject: [PATCH 1/6] [MLIR][XeGPU][XeVM] Only convert 1D single element vector
 to scalar.

---
 mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 8a06271eadd84..8efbb0702f0d3 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -1190,7 +1190,7 @@ struct ConvertXeGPUToXeVMPass
         return {};
       auto input = inputs.front();
       if (auto vecTy = dyn_cast<VectorType>(input.getType())) {
-        if (vecTy.getNumElements() == 1) {
+        if (vecTy.getRank() == 1 && vecTy.getNumElements() == 1) {
           // If the vector has a single element, return the element type.
           Value cast =
               vector::ExtractOp::create(builder, loc, input, 0).getResult();

>From 185c0ff2cb707b271f57c06942f20cc2ec8c3fd5 Mon Sep 17 00:00:00 2001
From: "Lee, Sang Ik" <sang.ik.lee at intel.com>
Date: Thu, 29 Jan 2026 02:13:41 +0000
Subject: [PATCH 2/6] Handle n-D single element vectors.

---
 mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp | 16 ++++++++++++----
 1 file changed, 12 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 8efbb0702f0d3..b983e421ff83e 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -1180,7 +1180,7 @@ struct ConvertXeGPUToXeVMPass
     };
 
     // Materialization to convert
-    //   - single element 1D vector to scalar
+    //   - single element vector to scalar
     //   - bitcast vector of same rank
     //   - shape vector of different rank but same element type
     auto vectorMaterializationCast = [](OpBuilder &builder, Type type,
@@ -1190,10 +1190,18 @@ struct ConvertXeGPUToXeVMPass
         return {};
       auto input = inputs.front();
       if (auto vecTy = dyn_cast<VectorType>(input.getType())) {
-        if (vecTy.getRank() == 1 && vecTy.getNumElements() == 1) {
+        if (vecTy.getNumElements() == 1) {
           // If the vector has a single element, return the element type.
-          Value cast =
-              vector::ExtractOp::create(builder, loc, input, 0).getResult();
+          auto rank = vecTy.getRank();
+          Value cast;
+          if (rank > 1) {
+            cast = vector::ExtractOp::create(builder, loc, input,
+                                             SmallVector<int64_t>(rank, 0))
+                       .getResult();
+          } else {
+            cast =
+                vector::ExtractOp::create(builder, loc, input, 0).getResult();
+          }
           if (vecTy.getElementType() == builder.getIndexType())
             cast = arith::IndexCastUIOp::create(builder, loc, type, cast)
                        .getResult();

>From bd03a06e81ae9aab6b0d5bade865fcd509dee0f5 Mon Sep 17 00:00:00 2001
From: "Lee, Sang Ik" <sang.ik.lee at intel.com>
Date: Thu, 29 Jan 2026 03:28:25 +0000
Subject: [PATCH 3/6] Align single element vector type conversion rule and
 materialization method.

---
 mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index b983e421ff83e..db21dab150bcf 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -1057,7 +1057,7 @@ struct ConvertXeGPUToXeVMPass
       if (llvm::isa<IndexType>(elemType))
         elemType = IntegerType::get(&getContext(), 64);
       // If the vector is a scalar or has a single element, return the element
-      if (rank < 1 || type.getNumElements() == 1)
+      if (rank == 0 || type.getNumElements() == 1)
         return elemType;
       // Otherwise, convert the vector to a flat vector type.
       int64_t sum = llvm::product_of(type.getShape());
@@ -1190,17 +1190,17 @@ struct ConvertXeGPUToXeVMPass
         return {};
       auto input = inputs.front();
       if (auto vecTy = dyn_cast<VectorType>(input.getType())) {
-        if (vecTy.getNumElements() == 1) {
+        if (type.isIntOrIndexOrFloat()) {
           // If the vector has a single element, return the element type.
           auto rank = vecTy.getRank();
           Value cast;
-          if (rank > 1) {
+          if (rank == 0) {
+            cast =
+                vector::ExtractOp::create(builder, loc, input, {}).getResult();
+          } else {
             cast = vector::ExtractOp::create(builder, loc, input,
                                              SmallVector<int64_t>(rank, 0))
                        .getResult();
-          } else {
-            cast =
-                vector::ExtractOp::create(builder, loc, input, 0).getResult();
           }
           if (vecTy.getElementType() == builder.getIndexType())
             cast = arith::IndexCastUIOp::create(builder, loc, type, cast)

>From b3c35f507f929ee40bc842788cf1d0009c191aab Mon Sep 17 00:00:00 2001
From: "Lee, Sang Ik" <sang.ik.lee at intel.com>
Date: Thu, 29 Jan 2026 03:35:00 +0000
Subject: [PATCH 4/6] Update comments.

---
 mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index db21dab150bcf..a4c987573223f 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -1056,7 +1056,7 @@ struct ConvertXeGPUToXeVMPass
       // If the element type is index, convert it to i64.
       if (llvm::isa<IndexType>(elemType))
         elemType = IntegerType::get(&getContext(), 64);
-      // If the vector is a scalar or has a single element, return the element
+      // If the vector rank is 0 or has a single element, return the element
       if (rank == 0 || type.getNumElements() == 1)
         return elemType;
       // Otherwise, convert the vector to a flat vector type.
@@ -1191,7 +1191,8 @@ struct ConvertXeGPUToXeVMPass
       auto input = inputs.front();
       if (auto vecTy = dyn_cast<VectorType>(input.getType())) {
         if (type.isIntOrIndexOrFloat()) {
-          // If the vector has a single element, return the element type.
+          // If the vector rank is 0 or has a single element,
+          // extract scalar of target type.
           auto rank = vecTy.getRank();
           Value cast;
           if (rank == 0) {
@@ -1233,10 +1234,10 @@ struct ConvertXeGPUToXeVMPass
         return {};
       auto input = inputs.front();
       if (input.getType().isIntOrIndexOrFloat()) {
-        // If the input is a scalar, and the target type is a vector of single
-        // element, create a single element vector by broadcasting.
+        // If the input is a scalar, and the target type is a vector of rank 0
+        // or single element, broadcast scalar to target type.
         if (auto vecTy = dyn_cast<VectorType>(type)) {
-          if (vecTy.getNumElements() == 1) {
+          if (vecTy.getRank() == 0 || vecTy.getNumElements() == 1) {
             return vector::BroadcastOp::create(builder, loc, vecTy, input)
                 .getResult();
           }

>From dc6c997993d584a526b3c1050653616188a726e6 Mon Sep 17 00:00:00 2001
From: "Lee, Sang Ik" <sang.ik.lee at intel.com>
Date: Thu, 29 Jan 2026 18:07:36 +0000
Subject: [PATCH 5/6] Address reviewer comment on element type check.

---
 .../Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp    | 22 +++++++++++++------
 1 file changed, 15 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index a4c987573223f..02b736706c0db 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -1190,7 +1190,9 @@ struct ConvertXeGPUToXeVMPass
         return {};
       auto input = inputs.front();
       if (auto vecTy = dyn_cast<VectorType>(input.getType())) {
-        if (type.isIntOrIndexOrFloat()) {
+        if (type == vecTy.getElementType() ||
+            ((vecTy.getElementType() == builder.getIndexType()) &&
+             type.isInteger())) {
           // If the vector rank is 0 or has a single element,
           // extract scalar of target type.
           auto rank = vecTy.getRank();
@@ -1203,7 +1205,7 @@ struct ConvertXeGPUToXeVMPass
                                              SmallVector<int64_t>(rank, 0))
                        .getResult();
           }
-          if (vecTy.getElementType() == builder.getIndexType())
+          if (type != vecTy.getElementType())
             cast = arith::IndexCastUIOp::create(builder, loc, type, cast)
                        .getResult();
           return cast;
@@ -1233,13 +1235,19 @@ struct ConvertXeGPUToXeVMPass
       if (inputs.size() != 1)
         return {};
       auto input = inputs.front();
-      if (input.getType().isIntOrIndexOrFloat()) {
-        // If the input is a scalar, and the target type is a vector of rank 0
-        // or single element, broadcast scalar to target type.
-        if (auto vecTy = dyn_cast<VectorType>(type)) {
-          if (vecTy.getRank() == 0 || vecTy.getNumElements() == 1) {
+      // If the target type is a vector of rank 0 or single element vector
+      // of element type matching input type, broadcast input to target type.
+      if (auto vecTy = dyn_cast<VectorType>(type)) {
+        if (vecTy.getRank() == 0 || vecTy.getNumElements() == 1) {
+          if (input.getType() == vecTy.getElementType()) {
             return vector::BroadcastOp::create(builder, loc, vecTy, input)
                 .getResult();
+          } else if (vecTy.getElementType() == builder.getIndexType()) {
+            Value cast = arith::IndexCastUIOp::create(
+                             builder, loc, builder.getIndexType(), input)
+                             .getResult();
+            return vector::BroadcastOp::create(builder, loc, vecTy, cast)
+                .getResult();
           }
         }
       }

>From 85199a626765ead20f5a68aeb7c73bfd9752ad26 Mon Sep 17 00:00:00 2001
From: "Lee, Sang Ik" <sang.ik.lee at intel.com>
Date: Thu, 29 Jan 2026 19:48:09 +0000
Subject: [PATCH 6/6] Split materialization cast callback functions and use
 better names.

---
 .../Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp    | 93 ++++++++++++-------
 1 file changed, 61 insertions(+), 32 deletions(-)

diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 02b736706c0db..6df209438447b 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -1085,9 +1085,12 @@ struct ConvertXeGPUToXeVMPass
     // add materialization casts to handle them.
 
     // Materialization to convert memref to i64 or i32 depending on global/SLM
-    auto memrefMaterializationCast = [](OpBuilder &builder, Type type,
-                                        ValueRange inputs,
-                                        Location loc) -> Value {
+    // Applies only to target materialization.
+    // Note: int type to memref materialization is not required as xegpu ops
+    // currently do not produce memrefs as result.
+    auto memrefToIntMaterializationCast = [](OpBuilder &builder, Type type,
+                                             ValueRange inputs,
+                                             Location loc) -> Value {
       if (inputs.size() != 1)
         return {};
       auto input = inputs.front();
@@ -1146,9 +1149,12 @@ struct ConvertXeGPUToXeVMPass
     };
 
     // Materialization to convert ui64 to i64
-    auto ui64MaterializationCast = [](OpBuilder &builder, Type type,
-                                      ValueRange inputs,
-                                      Location loc) -> Value {
+    // Applies only to target materialization.
+    // Note: i64 to ui64 materialization is not required as xegpu ops
+    // currently do not produce ui64 as result.
+    auto ui64ToI64MaterializationCast = [](OpBuilder &builder, Type type,
+                                           ValueRange inputs,
+                                           Location loc) -> Value {
       if (inputs.size() != 1)
         return {};
       auto input = inputs.front();
@@ -1163,9 +1169,12 @@ struct ConvertXeGPUToXeVMPass
     };
 
     // Materialization to convert ui32 to i32
-    auto ui32MaterializationCast = [](OpBuilder &builder, Type type,
-                                      ValueRange inputs,
-                                      Location loc) -> Value {
+    // Applies only to target materialization.
+    // Note: i32 to ui32 materialization is not required as xegpu ops
+    // currently do not produce ui32 as result.
+    auto ui32ToI32MaterializationCast = [](OpBuilder &builder, Type type,
+                                           ValueRange inputs,
+                                           Location loc) -> Value {
       if (inputs.size() != 1)
         return {};
       auto input = inputs.front();
@@ -1180,12 +1189,39 @@ struct ConvertXeGPUToXeVMPass
     };
 
     // Materialization to convert
-    //   - single element vector to scalar
     //   - bitcast vector of same rank
     //   - shape vector of different rank but same element type
-    auto vectorMaterializationCast = [](OpBuilder &builder, Type type,
-                                        ValueRange inputs,
-                                        Location loc) -> Value {
+    // Applies to both source and target materialization.
+    auto vectorToVectorMaterializationCast = [](OpBuilder &builder, Type type,
+                                                ValueRange inputs,
+                                                Location loc) -> Value {
+      if (inputs.size() != 1)
+        return {};
+      auto input = inputs.front();
+      if (auto vecTy = dyn_cast<VectorType>(input.getType())) {
+        if (auto targetVecTy = dyn_cast<VectorType>(type)) {
+          // If the target type is a vector of same rank,
+          //   bitcast to the target type.
+          if (targetVecTy.getRank() == vecTy.getRank())
+            return vector::BitCastOp::create(builder, loc, targetVecTy, input)
+                .getResult();
+          else if (targetVecTy.getElementType() == vecTy.getElementType()) {
+            // If the target type is a vector of different rank but same element
+            // type, reshape to the target type.
+            return vector::ShapeCastOp::create(builder, loc, targetVecTy, input)
+                .getResult();
+          }
+        }
+      }
+      return {};
+    };
+
+    // Materialization to convert
+    //   - single element vector to single element of vector element type
+    // Applies only to target materialization.
+    auto vectorToSingleElementMaterializationCast =
+        [](OpBuilder &builder, Type type, ValueRange inputs,
+           Location loc) -> Value {
       if (inputs.size() != 1)
         return {};
       auto input = inputs.front();
@@ -1209,27 +1245,18 @@ struct ConvertXeGPUToXeVMPass
             cast = arith::IndexCastUIOp::create(builder, loc, type, cast)
                        .getResult();
           return cast;
-        } else if (auto targetVecTy = dyn_cast<VectorType>(type)) {
-          // If the target type is a vector of same rank,
-          //   bitcast to the target type.
-          if (targetVecTy.getRank() == vecTy.getRank())
-            return vector::BitCastOp::create(builder, loc, targetVecTy, input)
-                .getResult();
-          else if (targetVecTy.getElementType() == vecTy.getElementType()) {
-            // If the target type is a vector of different rank but same element
-            // type, reshape to the target type.
-            return vector::ShapeCastOp::create(builder, loc, targetVecTy, input)
-                .getResult();
-          }
         }
       }
       return {};
     };
 
+    // Materialization to convert
+    //   - single element of vector element type to single element vector
     // If result type of original op is single element vector and lowered type
     // is scalar. This materialization cast creates a single element vector by
     // broadcasting the scalar value.
-    auto singleElementVectorMaterializationCast =
+    // Applies only to source materialization.
+    auto singleElementToVectorMaterializationCast =
         [](OpBuilder &builder, Type type, ValueRange inputs,
            Location loc) -> Value {
       if (inputs.size() != 1)
@@ -1254,12 +1281,14 @@ struct ConvertXeGPUToXeVMPass
       return {};
     };
     typeConverter.addSourceMaterialization(
-        singleElementVectorMaterializationCast);
-    typeConverter.addSourceMaterialization(vectorMaterializationCast);
-    typeConverter.addTargetMaterialization(memrefMaterializationCast);
-    typeConverter.addTargetMaterialization(ui32MaterializationCast);
-    typeConverter.addTargetMaterialization(ui64MaterializationCast);
-    typeConverter.addTargetMaterialization(vectorMaterializationCast);
+        singleElementToVectorMaterializationCast);
+    typeConverter.addSourceMaterialization(vectorToVectorMaterializationCast);
+    typeConverter.addTargetMaterialization(memrefToIntMaterializationCast);
+    typeConverter.addTargetMaterialization(ui32ToI32MaterializationCast);
+    typeConverter.addTargetMaterialization(ui64ToI64MaterializationCast);
+    typeConverter.addTargetMaterialization(
+        vectorToSingleElementMaterializationCast);
+    typeConverter.addTargetMaterialization(vectorToVectorMaterializationCast);
     ConversionTarget target(getContext());
     target.addLegalDialect<xevm::XeVMDialect, LLVM::LLVMDialect,
                            vector::VectorDialect, arith::ArithDialect,



More information about the Mlir-commits mailing list