[Mlir-commits] [mlir] [MLIR][XeGPU] Validate DPAS operand types against uArch in XeGPUToXeVM conversion (PR #185081)

Arjun Bhamra llvmlistbot at llvm.org
Wed Mar 11 17:54:44 PDT 2026


https://github.com/abhamra updated https://github.com/llvm/llvm-project/pull/185081

>From cea8e96e5bd900c8bef445de47c481514c64e04c Mon Sep 17 00:00:00 2001
From: Arjun Bhamra <arjun.bhamra25 at gmail.com>
Date: Wed, 4 Mar 2026 20:56:26 -0500
Subject: [PATCH 1/8] Update XeGPU DPAS types and remove verifier check

---
 mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td |  8 ++++++--
 mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp           | 14 ++++++++------
 2 files changed, 14 insertions(+), 8 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index c50bd25df2742..5e10aa2981524 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -21,8 +21,12 @@ def XeGPU_ScalarType: AnyTypeOf<[XeGPU_IntType, XeGPU_FloatType]>;
 def XeGPU_PointerType : AnyTypeOf<[UI64, UI32, I64, I32]>;
 def XeGPU_BaseAddrType
     : AnyTypeOf<[Non0RankedMemRefOf<[XeGPU_ScalarType]>, XeGPU_PointerType]>;
-def XeGPU_DpasOprType: FixedVectorOfRankAndType<[1, 2, 3], [XeGPU_ScalarType]>;
-def XeGPU_DpasResType: FixedVectorOfRankAndType<[1, 2], [XeGPU_ScalarType]>;
+def XeGPU_DpasOprElemType : AnyTypeOf<[AnyI1, I<4>, AnyI8, AnyI16,
+                                       F4E2M1FN, F8E4M3FN, F8E5M2, F8E8M0FNU, F16, BF16]>;
+def XeGPU_DpasResElemType : AnyTypeOf<[XeGPU_DpasOprElemType, AnyI32, TF32, F32]>;
+
+def XeGPU_DpasOprType : FixedVectorOfRankAndType<[1, 2, 3], [XeGPU_DpasOprElemType]>;
+def XeGPU_DpasResType : FixedVectorOfRankAndType<[1, 2], [XeGPU_DpasResElemType]>;
 def XeGPU_OffsetType: FixedVectorOfNonZeroRankOf<[Index]>;
 def XeGPU_MaskType: FixedVectorOfNonZeroRankOf<[I1]>;
 def XeGPU_ValueType: FixedVectorOfNonZeroRankOf<[XeGPU_ScalarType]>;
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 91ba07a8e0256..81df3159a36b1 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -1045,11 +1045,13 @@ LogicalResult UpdateOffsetOp::verify() {
 // XeGPU_DpasOp
 //===----------------------------------------------------------------------===//
 LogicalResult DpasOp::verify() {
-  int64_t lhsRank = getLhsType().getRank();
-  int64_t rhsRank = getRhsType().getRank();
+  auto lhsType = getLhsType();
+  auto rhsType = getRhsType();
+  int64_t lhsRank = lhsType.getRank();
+  int64_t rhsRank = rhsType.getRank();
   int64_t resRank = getResultType().getRank();
-  auto lhsShape = getLhsType().getShape();
-  auto rhsShape = getRhsType().getShape();
+  auto lhsShape = lhsType.getShape();
+  auto rhsShape = rhsType.getShape();
   auto resShape = getResultType().getShape();
 
   if (getAcc() && getAcc().getType() != getResultType())
@@ -1059,8 +1061,8 @@ LogicalResult DpasOp::verify() {
   // It skips the semantic check since lack of architecture information.
   // Users need to ensure the correctness.
   if (lhsRank == 1 && rhsRank == 1 && resRank == 1) {
-    auto numElems = getRhsType().getNumElements();
-    auto elemTy = getRhsType().getElementType();
+    auto numElems = rhsType.getNumElements();
+    auto elemTy = rhsType.getElementType();
     auto factor = 32 / elemTy.getIntOrFloatBitWidth();
     if (numElems % factor != 0)
       return emitOpError("Expecting B operand to be a multiple of 32 bits.");

>From 2410439b69b4cf85b1dd303725fa7839f56c5ebd Mon Sep 17 00:00:00 2001
From: Arjun Bhamra <arjun.bhamra25 at gmail.com>
Date: Fri, 6 Mar 2026 13:45:53 -0500
Subject: [PATCH 2/8] Fix DPAS lowering via uArch type checks and
 notifyMatchFailure

---
 .../mlir/Dialect/XeGPU/IR/XeGPUTypes.td       |  8 ++---
 .../Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp    | 34 +++++++++++++++++++
 mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp        | 14 ++++----
 mlir/test/Conversion/XeGPUToXeVM/dpas.mlir    |  2 +-
 4 files changed, 43 insertions(+), 15 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index 5e10aa2981524..8286566ddabee 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -21,12 +21,8 @@ def XeGPU_ScalarType: AnyTypeOf<[XeGPU_IntType, XeGPU_FloatType]>;
 def XeGPU_PointerType : AnyTypeOf<[UI64, UI32, I64, I32]>;
 def XeGPU_BaseAddrType
     : AnyTypeOf<[Non0RankedMemRefOf<[XeGPU_ScalarType]>, XeGPU_PointerType]>;
-def XeGPU_DpasOprElemType : AnyTypeOf<[AnyI1, I<4>, AnyI8, AnyI16,
-                                       F4E2M1FN, F8E4M3FN, F8E5M2, F8E8M0FNU, F16, BF16]>;
-def XeGPU_DpasResElemType : AnyTypeOf<[XeGPU_DpasOprElemType, AnyI32, TF32, F32]>;
-
-def XeGPU_DpasOprType : FixedVectorOfRankAndType<[1, 2, 3], [XeGPU_DpasOprElemType]>;
-def XeGPU_DpasResType : FixedVectorOfRankAndType<[1, 2], [XeGPU_DpasResElemType]>;
+def XeGPU_DpasOprType : FixedVectorOfRankAndType<[1, 2, 3], [XeGPU_ScalarType]>;
+def XeGPU_DpasResType : FixedVectorOfRankAndType<[1, 2], [XeGPU_ScalarType]>;
 def XeGPU_OffsetType: FixedVectorOfNonZeroRankOf<[Index]>;
 def XeGPU_MaskType: FixedVectorOfNonZeroRankOf<[I1]>;
 def XeGPU_ValueType: FixedVectorOfNonZeroRankOf<[XeGPU_ScalarType]>;
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 8a06271eadd84..d928e24b98493 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -22,6 +22,7 @@
 #include "mlir/Dialect/SCF/Transforms/Patterns.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h"
 #include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Support/LLVM.h"
@@ -902,6 +903,39 @@ class DpasToXeVMPattern : public OpConversionPattern<xegpu::DpasOp> {
     auto bTy = cast<VectorType>(op.getRhs().getType());
     auto resultType = cast<VectorType>(op.getResultType());
 
+    // get the correct dpasInst by getting info from chip
+    auto chipStr = xegpu::getChipStr(op);
+    if (!chipStr)
+      return rewriter.notifyMatchFailure(op, "cannot determine target chip");
+
+    const auto *uArch= mlir::xegpu::uArch::getUArch(*chipStr);
+    if (!uArch)
+      return rewriter.notifyMatchFailure(op, "unsupported target uArch");
+
+    auto *dpasInst = const_cast<xegpu::uArch::SubgroupMatrixMultiplyAcc*>(
+        llvm::dyn_cast_or_null<xegpu::uArch::SubgroupMatrixMultiplyAcc>(
+            uArch->getInstruction(
+                xegpu::uArch::InstructionKind::SubgroupMatrixMultiplyAcc)));
+    if (!dpasInst)
+      return rewriter.notifyMatchFailure(op, "DPAS not supported by target uArch");
+
+    auto supportedA = dpasInst->getSupportedTypes(*ctxt, xegpu::uArch::MMAOpndKind::MatrixA);
+    auto supportedB = dpasInst->getSupportedTypes(*ctxt, xegpu::uArch::MMAOpndKind::MatrixB);
+    auto supportedD = dpasInst->getSupportedTypes(*ctxt, xegpu::uArch::MMAOpndKind::MatrixD);
+    // NOTE: Supported types for MatrixC and MatrixD are identical
+
+    if (llvm::find(supportedA, aTy.getElementType()) == supportedA.end())
+      return rewriter.notifyMatchFailure(
+          op, "A-matrix element type not supported by target uArch");
+
+    if (llvm::find(supportedB, bTy.getElementType()) == supportedB.end())
+      return rewriter.notifyMatchFailure(
+          op, "B-matrix element type not supported by target uArch");
+
+    if (llvm::find(supportedD, resultType.getElementType()) == supportedD.end())
+      return rewriter.notifyMatchFailure(
+          op, "result/accumulator element type not supported by target uArch");
+
     auto encodePrecision = [&](Type type) -> xevm::ElemType {
       if (type == rewriter.getBF16Type())
         return xevm::ElemType::BF16;
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 81df3159a36b1..91ba07a8e0256 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -1045,13 +1045,11 @@ LogicalResult UpdateOffsetOp::verify() {
 // XeGPU_DpasOp
 //===----------------------------------------------------------------------===//
 LogicalResult DpasOp::verify() {
-  auto lhsType = getLhsType();
-  auto rhsType = getRhsType();
-  int64_t lhsRank = lhsType.getRank();
-  int64_t rhsRank = rhsType.getRank();
+  int64_t lhsRank = getLhsType().getRank();
+  int64_t rhsRank = getRhsType().getRank();
   int64_t resRank = getResultType().getRank();
-  auto lhsShape = lhsType.getShape();
-  auto rhsShape = rhsType.getShape();
+  auto lhsShape = getLhsType().getShape();
+  auto rhsShape = getRhsType().getShape();
   auto resShape = getResultType().getShape();
 
   if (getAcc() && getAcc().getType() != getResultType())
@@ -1061,8 +1059,8 @@ LogicalResult DpasOp::verify() {
   // It skips the semantic check since lack of architecture information.
   // Users need to ensure the correctness.
   if (lhsRank == 1 && rhsRank == 1 && resRank == 1) {
-    auto numElems = rhsType.getNumElements();
-    auto elemTy = rhsType.getElementType();
+    auto numElems = getRhsType().getNumElements();
+    auto elemTy = getRhsType().getElementType();
     auto factor = 32 / elemTy.getIntOrFloatBitWidth();
     if (numElems % factor != 0)
       return emitOpError("Expecting B operand to be a multiple of 32 bits.");
diff --git a/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir b/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir
index a9ab0be00722c..7cc59f4cfdcdf 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir
@@ -1,6 +1,6 @@
 // RUN: mlir-opt -convert-xegpu-to-xevm %s | FileCheck %s
 
-gpu.module @test_kernel {
+gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
     // CHECK-LABEL: func.func @dpas(
     // CHECK-SAME: %[[ARG0:.*]]: vector<8xf16>, %[[ARG1:.*]]: vector<16xf16>, %[[ARG2:.*]]: vector<8xf32>
     func.func @dpas(%a_loaded: vector<8xf16>, %b_loaded: vector<16xf16>, %c_loaded: vector<8xf32>) -> vector<8xf32> {

>From ef9b386ecd84a7cabd1012cd96590a1eebfdcaed Mon Sep 17 00:00:00 2001
From: Arjun Bhamra <arjun.bhamra25 at gmail.com>
Date: Fri, 6 Mar 2026 13:56:51 -0500
Subject: [PATCH 3/8] Fix minor spacing mistake

---
 mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index 8286566ddabee..c50bd25df2742 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -21,8 +21,8 @@ def XeGPU_ScalarType: AnyTypeOf<[XeGPU_IntType, XeGPU_FloatType]>;
 def XeGPU_PointerType : AnyTypeOf<[UI64, UI32, I64, I32]>;
 def XeGPU_BaseAddrType
     : AnyTypeOf<[Non0RankedMemRefOf<[XeGPU_ScalarType]>, XeGPU_PointerType]>;
-def XeGPU_DpasOprType : FixedVectorOfRankAndType<[1, 2, 3], [XeGPU_ScalarType]>;
-def XeGPU_DpasResType : FixedVectorOfRankAndType<[1, 2], [XeGPU_ScalarType]>;
+def XeGPU_DpasOprType: FixedVectorOfRankAndType<[1, 2, 3], [XeGPU_ScalarType]>;
+def XeGPU_DpasResType: FixedVectorOfRankAndType<[1, 2], [XeGPU_ScalarType]>;
 def XeGPU_OffsetType: FixedVectorOfNonZeroRankOf<[Index]>;
 def XeGPU_MaskType: FixedVectorOfNonZeroRankOf<[I1]>;
 def XeGPU_ValueType: FixedVectorOfNonZeroRankOf<[XeGPU_ScalarType]>;

>From 546a34b09596a52ca822e74beaf52e5b6520ed9d Mon Sep 17 00:00:00 2001
From: Arjun Bhamra <arjun.bhamra25 at gmail.com>
Date: Fri, 6 Mar 2026 14:03:54 -0500
Subject: [PATCH 4/8] Add failed-conversion test

---
 .../Conversion/XeGPUToXeVM/failed_conversion.mlir  | 14 ++++++++++++++
 1 file changed, 14 insertions(+)
 create mode 100644 mlir/test/Conversion/XeGPUToXeVM/failed_conversion.mlir

diff --git a/mlir/test/Conversion/XeGPUToXeVM/failed_conversion.mlir b/mlir/test/Conversion/XeGPUToXeVM/failed_conversion.mlir
new file mode 100644
index 0000000000000..c4088c430a4d9
--- /dev/null
+++ b/mlir/test/Conversion/XeGPUToXeVM/failed_conversion.mlir
@@ -0,0 +1,14 @@
+// RUN: mlir-opt --convert-xegpu-to-xevm %s -split-input-file -verify-diagnostics
+
+// Verify that xegpu.dpas with unsupported element types (i16) is rejected
+// during XeGPUToXeVM conversion rather than crashing.
+
+gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
+  func.func @main() {
+    %0 = spirv.Constant dense<0> : vector<4xi16>
+    %1 = spirv.Constant dense<0> : vector<4xi32>
+    // expected-error@+1 {{failed to legalize operation 'xegpu.dpas' that was explicitly marked illegal}}
+    %2 = xegpu.dpas %0, %0, %1 : vector<4xi16>, vector<4xi16>, vector<4xi32> -> vector<4xi32>
+    return
+  }
+}

>From 1e4b0ea652453bd2c8ba74ce1b050c216e3deb8f Mon Sep 17 00:00:00 2001
From: Arjun Bhamra <arjun.bhamra25 at gmail.com>
Date: Fri, 6 Mar 2026 14:16:15 -0500
Subject: [PATCH 5/8] Fix formatter issues

---
 .../lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp | 18 +++++++++++-------
 1 file changed, 11 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index d928e24b98493..94430e1068c99 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -908,20 +908,24 @@ class DpasToXeVMPattern : public OpConversionPattern<xegpu::DpasOp> {
     if (!chipStr)
       return rewriter.notifyMatchFailure(op, "cannot determine target chip");
 
-    const auto *uArch= mlir::xegpu::uArch::getUArch(*chipStr);
+    const auto *uArch = mlir::xegpu::uArch::getUArch(*chipStr);
     if (!uArch)
       return rewriter.notifyMatchFailure(op, "unsupported target uArch");
 
-    auto *dpasInst = const_cast<xegpu::uArch::SubgroupMatrixMultiplyAcc*>(
+    auto *dpasInst = const_cast<xegpu::uArch::SubgroupMatrixMultiplyAcc *>(
         llvm::dyn_cast_or_null<xegpu::uArch::SubgroupMatrixMultiplyAcc>(
             uArch->getInstruction(
                 xegpu::uArch::InstructionKind::SubgroupMatrixMultiplyAcc)));
     if (!dpasInst)
-      return rewriter.notifyMatchFailure(op, "DPAS not supported by target uArch");
-
-    auto supportedA = dpasInst->getSupportedTypes(*ctxt, xegpu::uArch::MMAOpndKind::MatrixA);
-    auto supportedB = dpasInst->getSupportedTypes(*ctxt, xegpu::uArch::MMAOpndKind::MatrixB);
-    auto supportedD = dpasInst->getSupportedTypes(*ctxt, xegpu::uArch::MMAOpndKind::MatrixD);
+      return rewriter.notifyMatchFailure(op,
+                                         "DPAS not supported by target uArch");
+
+    auto supportedA =
+      dpasInst->getSupportedTypes(*ctxt, xegpu::uArch::MMAOpndKind::MatrixA);
+    auto supportedB =
+      dpasInst->getSupportedTypes(*ctxt, xegpu::uArch::MMAOpndKind::MatrixB);
+    auto supportedD =
+      dpasInst->getSupportedTypes(*ctxt, xegpu::uArch::MMAOpndKind::MatrixD);
     // NOTE: Supported types for MatrixC and MatrixD are identical
 
     if (llvm::find(supportedA, aTy.getElementType()) == supportedA.end())

>From 5b5778c6bf017c8e6316a14509fbd19fe07a56e4 Mon Sep 17 00:00:00 2001
From: Arjun Bhamra <arjun.bhamra25 at gmail.com>
Date: Fri, 6 Mar 2026 14:20:07 -0500
Subject: [PATCH 6/8] Fix formatting again

---
 .../Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp    | 126 ++++++------------
 1 file changed, 40 insertions(+), 86 deletions(-)

diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 2d6f9c927484c..795bc0ad84a4a 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -22,8 +22,8 @@
 #include "mlir/Dialect/SCF/Transforms/Patterns.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
-#include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h"
 #include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
+#include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/STLExtras.h"
@@ -921,11 +921,11 @@ class DpasToXeVMPattern : public OpConversionPattern<xegpu::DpasOp> {
                                          "DPAS not supported by target uArch");
 
     auto supportedA =
-      dpasInst->getSupportedTypes(*ctxt, xegpu::uArch::MMAOpndKind::MatrixA);
+        dpasInst->getSupportedTypes(*ctxt, xegpu::uArch::MMAOpndKind::MatrixA);
     auto supportedB =
-      dpasInst->getSupportedTypes(*ctxt, xegpu::uArch::MMAOpndKind::MatrixB);
+        dpasInst->getSupportedTypes(*ctxt, xegpu::uArch::MMAOpndKind::MatrixB);
     auto supportedD =
-      dpasInst->getSupportedTypes(*ctxt, xegpu::uArch::MMAOpndKind::MatrixD);
+        dpasInst->getSupportedTypes(*ctxt, xegpu::uArch::MMAOpndKind::MatrixD);
     // NOTE: Supported types for MatrixC and MatrixD are identical
 
     if (llvm::find(supportedA, aTy.getElementType()) == supportedA.end())
@@ -1094,8 +1094,8 @@ struct ConvertXeGPUToXeVMPass
       // If the element type is index, convert it to i64.
       if (llvm::isa<IndexType>(elemType))
         elemType = IntegerType::get(&getContext(), 64);
-      // If the vector rank is 0 or has a single element, return the element
-      if (rank == 0 || type.getNumElements() == 1)
+      // If the vector is a scalar or has a single element, return the element
+      if (rank < 1 || type.getNumElements() == 1)
         return elemType;
       // Otherwise, convert the vector to a flat vector type.
       int64_t sum = llvm::product_of(type.getShape());
@@ -1123,12 +1123,9 @@ struct ConvertXeGPUToXeVMPass
     // add materialization casts to handle them.
 
     // Materialization to convert memref to i64 or i32 depending on global/SLM
-    // Applies only to target materialization.
-    // Note: int type to memref materialization is not required as xegpu ops
-    // currently do not produce memrefs as result.
-    auto memrefToIntMaterializationCast = [](OpBuilder &builder, Type type,
-                                             ValueRange inputs,
-                                             Location loc) -> Value {
+    auto memrefMaterializationCast = [](OpBuilder &builder, Type type,
+                                        ValueRange inputs,
+                                        Location loc) -> Value {
       if (inputs.size() != 1)
         return {};
       auto input = inputs.front();
@@ -1187,12 +1184,9 @@ struct ConvertXeGPUToXeVMPass
     };
 
     // Materialization to convert ui64 to i64
-    // Applies only to target materialization.
-    // Note: i64 to ui64 materialization is not required as xegpu ops
-    // currently do not produce ui64 as result.
-    auto ui64ToI64MaterializationCast = [](OpBuilder &builder, Type type,
-                                           ValueRange inputs,
-                                           Location loc) -> Value {
+    auto ui64MaterializationCast = [](OpBuilder &builder, Type type,
+                                      ValueRange inputs,
+                                      Location loc) -> Value {
       if (inputs.size() != 1)
         return {};
       auto input = inputs.front();
@@ -1207,12 +1201,9 @@ struct ConvertXeGPUToXeVMPass
     };
 
     // Materialization to convert ui32 to i32
-    // Applies only to target materialization.
-    // Note: i32 to ui32 materialization is not required as xegpu ops
-    // currently do not produce ui32 as result.
-    auto ui32ToI32MaterializationCast = [](OpBuilder &builder, Type type,
-                                           ValueRange inputs,
-                                           Location loc) -> Value {
+    auto ui32MaterializationCast = [](OpBuilder &builder, Type type,
+                                      ValueRange inputs,
+                                      Location loc) -> Value {
       if (inputs.size() != 1)
         return {};
       auto input = inputs.front();
@@ -1227,17 +1218,25 @@ struct ConvertXeGPUToXeVMPass
     };
 
     // Materialization to convert
+    //   - single element 1D vector to scalar
     //   - bitcast vector of same rank
     //   - shape vector of different rank but same element type
-    // Applies to both source and target materialization.
-    auto vectorToVectorMaterializationCast = [](OpBuilder &builder, Type type,
-                                                ValueRange inputs,
-                                                Location loc) -> Value {
+    auto vectorMaterializationCast = [](OpBuilder &builder, Type type,
+                                        ValueRange inputs,
+                                        Location loc) -> Value {
       if (inputs.size() != 1)
         return {};
       auto input = inputs.front();
       if (auto vecTy = dyn_cast<VectorType>(input.getType())) {
-        if (auto targetVecTy = dyn_cast<VectorType>(type)) {
+        if (vecTy.getNumElements() == 1) {
+          // If the vector has a single element, return the element type.
+          Value cast =
+              vector::ExtractOp::create(builder, loc, input, 0).getResult();
+          if (vecTy.getElementType() == builder.getIndexType())
+            cast = arith::IndexCastUIOp::create(builder, loc, type, cast)
+                       .getResult();
+          return cast;
+        } else if (auto targetVecTy = dyn_cast<VectorType>(type)) {
           // If the target type is a vector of same rank,
           //   bitcast to the target type.
           if (targetVecTy.getRank() == vecTy.getRank())
@@ -1254,79 +1253,34 @@ struct ConvertXeGPUToXeVMPass
       return {};
     };
 
-    // Materialization to convert
-    //   - single element vector to single element of vector element type
-    // Applies only to target materialization.
-    auto vectorToSingleElementMaterializationCast =
-        [](OpBuilder &builder, Type type, ValueRange inputs,
-           Location loc) -> Value {
-      if (inputs.size() != 1)
-        return {};
-      auto input = inputs.front();
-      if (auto vecTy = dyn_cast<VectorType>(input.getType())) {
-        if (type == vecTy.getElementType() ||
-            ((vecTy.getElementType() == builder.getIndexType()) &&
-             type.isInteger())) {
-          // If the vector rank is 0 or has a single element,
-          // extract scalar of target type.
-          auto rank = vecTy.getRank();
-          Value cast;
-          if (rank == 0) {
-            cast =
-                vector::ExtractOp::create(builder, loc, input, {}).getResult();
-          } else {
-            cast = vector::ExtractOp::create(builder, loc, input,
-                                             SmallVector<int64_t>(rank, 0))
-                       .getResult();
-          }
-          if (type != vecTy.getElementType())
-            cast = arith::IndexCastUIOp::create(builder, loc, type, cast)
-                       .getResult();
-          return cast;
-        }
-      }
-      return {};
-    };
-
-    // Materialization to convert
-    //   - single element of vector element type to single element vector
     // If result type of original op is single element vector and lowered type
     // is scalar. This materialization cast creates a single element vector by
     // broadcasting the scalar value.
-    // Applies only to source materialization.
-    auto singleElementToVectorMaterializationCast =
+    auto singleElementVectorMaterializationCast =
         [](OpBuilder &builder, Type type, ValueRange inputs,
            Location loc) -> Value {
       if (inputs.size() != 1)
         return {};
       auto input = inputs.front();
-      // If the target type is a vector of rank 0 or single element vector
-      // of element type matching input type, broadcast input to target type.
-      if (auto vecTy = dyn_cast<VectorType>(type)) {
-        if (vecTy.getRank() == 0 || vecTy.getNumElements() == 1) {
-          if (input.getType() == vecTy.getElementType()) {
+      if (input.getType().isIntOrIndexOrFloat()) {
+        // If the input is a scalar, and the target type is a vector of single
+        // element, create a single element vector by broadcasting.
+        if (auto vecTy = dyn_cast<VectorType>(type)) {
+          if (vecTy.getNumElements() == 1) {
             return vector::BroadcastOp::create(builder, loc, vecTy, input)
                 .getResult();
-          } else if (vecTy.getElementType() == builder.getIndexType()) {
-            Value cast = arith::IndexCastUIOp::create(
-                             builder, loc, builder.getIndexType(), input)
-                             .getResult();
-            return vector::BroadcastOp::create(builder, loc, vecTy, cast)
-                .getResult();
           }
         }
       }
       return {};
     };
     typeConverter.addSourceMaterialization(
-        singleElementToVectorMaterializationCast);
-    typeConverter.addSourceMaterialization(vectorToVectorMaterializationCast);
-    typeConverter.addTargetMaterialization(memrefToIntMaterializationCast);
-    typeConverter.addTargetMaterialization(ui32ToI32MaterializationCast);
-    typeConverter.addTargetMaterialization(ui64ToI64MaterializationCast);
-    typeConverter.addTargetMaterialization(
-        vectorToSingleElementMaterializationCast);
-    typeConverter.addTargetMaterialization(vectorToVectorMaterializationCast);
+        singleElementVectorMaterializationCast);
+    typeConverter.addSourceMaterialization(vectorMaterializationCast);
+    typeConverter.addTargetMaterialization(memrefMaterializationCast);
+    typeConverter.addTargetMaterialization(ui32MaterializationCast);
+    typeConverter.addTargetMaterialization(ui64MaterializationCast);
+    typeConverter.addTargetMaterialization(vectorMaterializationCast);
     ConversionTarget target(getContext());
     target.addLegalDialect<xevm::XeVMDialect, LLVM::LLVMDialect,
                            vector::VectorDialect, arith::ArithDialect,

>From 32f2f30d2a930a675a5f496dc677128201d376cd Mon Sep 17 00:00:00 2001
From: Arjun Bhamra <arjun.bhamra25 at gmail.com>
Date: Tue, 10 Mar 2026 11:17:48 -0400
Subject: [PATCH 7/8] Revert conversion changes; add changes to verifier
 only

---
 .../Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp    | 156 +++++++++---------
 mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp        |  23 ++-
 mlir/test/Conversion/XeGPUToXeVM/dpas.mlir    |   2 +-
 .../XeGPUToXeVM/failed_conversion.mlir        |  14 --
 mlir/test/Dialect/XeGPU/invalid.mlir          |   8 +
 5 files changed, 107 insertions(+), 96 deletions(-)
 delete mode 100644 mlir/test/Conversion/XeGPUToXeVM/failed_conversion.mlir

diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 795bc0ad84a4a..6df209438447b 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -23,7 +23,6 @@
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
 #include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
-#include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/STLExtras.h"
@@ -903,43 +902,6 @@ class DpasToXeVMPattern : public OpConversionPattern<xegpu::DpasOp> {
     auto bTy = cast<VectorType>(op.getRhs().getType());
     auto resultType = cast<VectorType>(op.getResultType());
 
-    // get the correct dpasInst by getting info from chip
-    auto chipStr = xegpu::getChipStr(op);
-    if (!chipStr)
-      return rewriter.notifyMatchFailure(op, "cannot determine target chip");
-
-    const auto *uArch = mlir::xegpu::uArch::getUArch(*chipStr);
-    if (!uArch)
-      return rewriter.notifyMatchFailure(op, "unsupported target uArch");
-
-    auto *dpasInst = const_cast<xegpu::uArch::SubgroupMatrixMultiplyAcc *>(
-        llvm::dyn_cast_or_null<xegpu::uArch::SubgroupMatrixMultiplyAcc>(
-            uArch->getInstruction(
-                xegpu::uArch::InstructionKind::SubgroupMatrixMultiplyAcc)));
-    if (!dpasInst)
-      return rewriter.notifyMatchFailure(op,
-                                         "DPAS not supported by target uArch");
-
-    auto supportedA =
-        dpasInst->getSupportedTypes(*ctxt, xegpu::uArch::MMAOpndKind::MatrixA);
-    auto supportedB =
-        dpasInst->getSupportedTypes(*ctxt, xegpu::uArch::MMAOpndKind::MatrixB);
-    auto supportedD =
-        dpasInst->getSupportedTypes(*ctxt, xegpu::uArch::MMAOpndKind::MatrixD);
-    // NOTE: Supported types for MatrixC and MatrixD are identical
-
-    if (llvm::find(supportedA, aTy.getElementType()) == supportedA.end())
-      return rewriter.notifyMatchFailure(
-          op, "A-matrix element type not supported by target uArch");
-
-    if (llvm::find(supportedB, bTy.getElementType()) == supportedB.end())
-      return rewriter.notifyMatchFailure(
-          op, "B-matrix element type not supported by target uArch");
-
-    if (llvm::find(supportedD, resultType.getElementType()) == supportedD.end())
-      return rewriter.notifyMatchFailure(
-          op, "result/accumulator element type not supported by target uArch");
-
     auto encodePrecision = [&](Type type) -> xevm::ElemType {
       if (type == rewriter.getBF16Type())
         return xevm::ElemType::BF16;
@@ -1094,8 +1056,8 @@ struct ConvertXeGPUToXeVMPass
       // If the element type is index, convert it to i64.
       if (llvm::isa<IndexType>(elemType))
         elemType = IntegerType::get(&getContext(), 64);
-      // If the vector is a scalar or has a single element, return the element
-      if (rank < 1 || type.getNumElements() == 1)
+      // If the vector rank is 0 or has a single element, return the element
+      if (rank == 0 || type.getNumElements() == 1)
         return elemType;
       // Otherwise, convert the vector to a flat vector type.
       int64_t sum = llvm::product_of(type.getShape());
@@ -1123,9 +1085,12 @@ struct ConvertXeGPUToXeVMPass
     // add materialization casts to handle them.
 
     // Materialization to convert memref to i64 or i32 depending on global/SLM
-    auto memrefMaterializationCast = [](OpBuilder &builder, Type type,
-                                        ValueRange inputs,
-                                        Location loc) -> Value {
+    // Applies only to target materialization.
+    // Note: int type to memref materialization is not required as xegpu ops
+    // currently do not produce memrefs as result.
+    auto memrefToIntMaterializationCast = [](OpBuilder &builder, Type type,
+                                             ValueRange inputs,
+                                             Location loc) -> Value {
       if (inputs.size() != 1)
         return {};
       auto input = inputs.front();
@@ -1184,9 +1149,12 @@ struct ConvertXeGPUToXeVMPass
     };
 
     // Materialization to convert ui64 to i64
-    auto ui64MaterializationCast = [](OpBuilder &builder, Type type,
-                                      ValueRange inputs,
-                                      Location loc) -> Value {
+    // Applies only to target materialization.
+    // Note: i64 to ui64 materialization is not required as xegpu ops
+    // currently do not produce ui64 as result.
+    auto ui64ToI64MaterializationCast = [](OpBuilder &builder, Type type,
+                                           ValueRange inputs,
+                                           Location loc) -> Value {
       if (inputs.size() != 1)
         return {};
       auto input = inputs.front();
@@ -1201,9 +1169,12 @@ struct ConvertXeGPUToXeVMPass
     };
 
     // Materialization to convert ui32 to i32
-    auto ui32MaterializationCast = [](OpBuilder &builder, Type type,
-                                      ValueRange inputs,
-                                      Location loc) -> Value {
+    // Applies only to target materialization.
+    // Note: i32 to ui32 materialization is not required as xegpu ops
+    // currently do not produce ui32 as result.
+    auto ui32ToI32MaterializationCast = [](OpBuilder &builder, Type type,
+                                           ValueRange inputs,
+                                           Location loc) -> Value {
       if (inputs.size() != 1)
         return {};
       auto input = inputs.front();
@@ -1218,25 +1189,17 @@ struct ConvertXeGPUToXeVMPass
     };
 
     // Materialization to convert
-    //   - single element 1D vector to scalar
     //   - bitcast vector of same rank
     //   - shape vector of different rank but same element type
-    auto vectorMaterializationCast = [](OpBuilder &builder, Type type,
-                                        ValueRange inputs,
-                                        Location loc) -> Value {
+    // Applies to both source and target materialization.
+    auto vectorToVectorMaterializationCast = [](OpBuilder &builder, Type type,
+                                                ValueRange inputs,
+                                                Location loc) -> Value {
       if (inputs.size() != 1)
         return {};
       auto input = inputs.front();
       if (auto vecTy = dyn_cast<VectorType>(input.getType())) {
-        if (vecTy.getNumElements() == 1) {
-          // If the vector has a single element, return the element type.
-          Value cast =
-              vector::ExtractOp::create(builder, loc, input, 0).getResult();
-          if (vecTy.getElementType() == builder.getIndexType())
-            cast = arith::IndexCastUIOp::create(builder, loc, type, cast)
-                       .getResult();
-          return cast;
-        } else if (auto targetVecTy = dyn_cast<VectorType>(type)) {
+        if (auto targetVecTy = dyn_cast<VectorType>(type)) {
           // If the target type is a vector of same rank,
           //   bitcast to the target type.
           if (targetVecTy.getRank() == vecTy.getRank())
@@ -1253,34 +1216,79 @@ struct ConvertXeGPUToXeVMPass
       return {};
     };
 
+    // Materialization to convert
+    //   - single element vector (rank 0, or any rank with exactly one
+    //     element) to a scalar of the vector element type, or to a signless
+    //     integer type when the vector element type is index.
+    // Applies only to target materialization.
+    auto vectorToSingleElementMaterializationCast =
+        [](OpBuilder &builder, Type type, ValueRange inputs,
+           Location loc) -> Value {
+      if (inputs.size() != 1)
+        return {};
+      auto input = inputs.front();
+      auto vecTy = dyn_cast<VectorType>(input.getType());
+      // Extracting a scalar from a multi-element vector would silently drop
+      // data, so only single-element vectors (incl. rank 0) are handled.
+      if (!vecTy || vecTy.getNumElements() != 1)
+        return {};
+      Type elemTy = vecTy.getElementType();
+      bool indexToInt = elemTy == builder.getIndexType() &&
+                        type.isSignlessInteger();
+      if (type != elemTy && !indexToInt)
+        return {};
+      // Extract at position [0, ..., 0] (empty for rank 0). Passing an
+      // explicit ArrayRef<int64_t> rather than a braced `{}` avoids any
+      // overload ambiguity with the scalar-position vector.extract builders.
+      SmallVector<int64_t> position(vecTy.getRank(), 0);
+      Value cast =
+          vector::ExtractOp::create(builder, loc, input, position).getResult();
+      // arith.index_castui only accepts signless integers besides index.
+      if (indexToInt)
+        cast = arith::IndexCastUIOp::create(builder, loc, type, cast)
+                   .getResult();
+      return cast;
+    };
+
+    // Materialization to convert
+    //   - single element of vector element type to single element vector
     // If result type of original op is single element vector and lowered type
     // is scalar. This materialization cast creates a single element vector by
     // broadcasting the scalar value.
-    auto singleElementVectorMaterializationCast =
+    // Applies only to source materialization.
+    auto singleElementToVectorMaterializationCast =
         [](OpBuilder &builder, Type type, ValueRange inputs,
            Location loc) -> Value {
       if (inputs.size() != 1)
         return {};
       auto input = inputs.front();
-      if (input.getType().isIntOrIndexOrFloat()) {
-        // If the input is a scalar, and the target type is a vector of single
-        // element, create a single element vector by broadcasting.
-        if (auto vecTy = dyn_cast<VectorType>(type)) {
-          if (vecTy.getNumElements() == 1) {
+      // Broadcast into a rank-0/single-element vector target. A signless
+      // integer input is cast to index first when the element type is index.
+      if (auto vecTy = dyn_cast<VectorType>(type)) {
+        if (vecTy.getRank() == 0 || vecTy.getNumElements() == 1) {
+          if (input.getType() == vecTy.getElementType()) {
             return vector::BroadcastOp::create(builder, loc, vecTy, input)
                 .getResult();
+          } else if (vecTy.getElementType() == builder.getIndexType() &&
+                     input.getType().isSignlessInteger()) {
+            Value cast = arith::IndexCastUIOp::create(
+                builder, loc, builder.getIndexType(), input);
+            return vector::BroadcastOp::create(builder, loc, vecTy, cast)
+                .getResult();
           }
         }
       }
       return {};
     };
     typeConverter.addSourceMaterialization(
-        singleElementVectorMaterializationCast);
-    typeConverter.addSourceMaterialization(vectorMaterializationCast);
-    typeConverter.addTargetMaterialization(memrefMaterializationCast);
-    typeConverter.addTargetMaterialization(ui32MaterializationCast);
-    typeConverter.addTargetMaterialization(ui64MaterializationCast);
-    typeConverter.addTargetMaterialization(vectorMaterializationCast);
+        singleElementToVectorMaterializationCast);
+    typeConverter.addSourceMaterialization(vectorToVectorMaterializationCast);
+    typeConverter.addTargetMaterialization(memrefToIntMaterializationCast);
+    typeConverter.addTargetMaterialization(ui32ToI32MaterializationCast);
+    typeConverter.addTargetMaterialization(ui64ToI64MaterializationCast);
+    typeConverter.addTargetMaterialization(
+        vectorToSingleElementMaterializationCast);
+    typeConverter.addTargetMaterialization(vectorToVectorMaterializationCast);
     ConversionTarget target(getContext());
     target.addLegalDialect<xevm::XeVMDialect, LLVM::LLVMDialect,
                            vector::VectorDialect, arith::ArithDialect,
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index e470d1f820f79..4973b677f65f8 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -1045,22 +1045,32 @@ LogicalResult UpdateOffsetOp::verify() {
 // XeGPU_DpasOp
 //===----------------------------------------------------------------------===//
 LogicalResult DpasOp::verify() {
-  int64_t lhsRank = getLhsType().getRank();
-  int64_t rhsRank = getRhsType().getRank();
+  auto lhsType = getLhsType();
+  auto rhsType = getRhsType();
+  int64_t lhsRank = lhsType.getRank();
+  int64_t rhsRank = rhsType.getRank();
   int64_t resRank = getResultType().getRank();
-  auto lhsShape = getLhsType().getShape();
-  auto rhsShape = getRhsType().getShape();
+  auto lhsShape = lhsType.getShape();
+  auto rhsShape = rhsType.getShape();
   auto resShape = getResultType().getShape();
 
   if (getAcc() && getAcc().getType() != getResultType())
     return emitOpError("Expecting the acc type to be the same as result.");
 
+  auto lhsElemTy = lhsType.getElementType();
+  auto rhsElemTy = rhsType.getElementType();
+  if (lhsElemTy.getIntOrFloatBitWidth() >= 32 ||
+      rhsElemTy.getIntOrFloatBitWidth() >= 32) {
+    return emitOpError(
+        "Expecting lhs and rhs element types to be less than 32 bits.");
+  }
+
   // SIMT code: the size of the B operand has to be a multiple of 32 bits.
   // It skips the semantic check since lack of architecture information.
   // Users need to ensure the correctness.
   if (lhsRank == 1 && rhsRank == 1 && resRank == 1) {
-    auto numElems = getRhsType().getNumElements();
-    auto elemTy = getRhsType().getElementType();
+    auto numElems = rhsType.getNumElements();
+    auto elemTy = rhsType.getElementType();
     auto factor = 32 / elemTy.getIntOrFloatBitWidth();
     if (numElems % factor != 0)
       return emitOpError("Expecting B operand to be a multiple of 32 bits.");
@@ -1082,7 +1092,6 @@ LogicalResult DpasOp::verify() {
 
   return success();
 }
-
 //===----------------------------------------------------------------------===//
 // XeGPU_ConvertLayoutOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir b/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir
index 7cc59f4cfdcdf..a9ab0be00722c 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir
@@ -1,6 +1,6 @@
 // RUN: mlir-opt -convert-xegpu-to-xevm %s | FileCheck %s
 
 gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
     // CHECK-LABEL: func.func @dpas(
     // CHECK-SAME: %[[ARG0:.*]]: vector<8xf16>, %[[ARG1:.*]]: vector<16xf16>, %[[ARG2:.*]]: vector<8xf32>
     func.func @dpas(%a_loaded: vector<8xf16>, %b_loaded: vector<16xf16>, %c_loaded: vector<8xf32>) -> vector<8xf32> {
diff --git a/mlir/test/Conversion/XeGPUToXeVM/failed_conversion.mlir b/mlir/test/Conversion/XeGPUToXeVM/failed_conversion.mlir
deleted file mode 100644
index c4088c430a4d9..0000000000000
--- a/mlir/test/Conversion/XeGPUToXeVM/failed_conversion.mlir
+++ /dev/null
@@ -1,14 +0,0 @@
-// RUN: mlir-opt --convert-xegpu-to-xevm %s -split-input-file -verify-diagnostics
-
-// Verify that xegpu.dpas with unsupported element types (i16) is rejected
-// during XeGPUToXeVM conversion rather than crashing.
-
-gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
-  func.func @main() {
-    %0 = spirv.Constant dense<0> : vector<4xi16>
-    %1 = spirv.Constant dense<0> : vector<4xi32>
-    // expected-error@+1 {{failed to legalize operation 'xegpu.dpas' that was explicitly marked illegal}}
-    %2 = xegpu.dpas %0, %0, %1 : vector<4xi16>, vector<4xi16>, vector<4xi32> -> vector<4xi32>
-    return
-  }
-}
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index 53d497e4c2087..34083f72f4e67 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -601,6 +601,14 @@ func.func @dpas_5(%a : vector<8x16xf16>, %b: vector<8x8x2xf16>) {
   return
 }
 
+// -----
+func.func @dpas_6() {
+  %0 = arith.constant dense<0> : vector<4xi32>
+  // expected-error@+1 {{'xegpu.dpas' op Expecting lhs and rhs element types to be less than 32 bits.}}
+  %1 = xegpu.dpas %0, %0 : vector<4xi32>, vector<4xi32> -> vector<4xi32>
+  return
+}
+
 // -----
 func.func @dpas_simt_1(%a : vector<8xf16>, %b: vector<15xf16>) {
   // expected-error@+1 {{Expecting B operand to be a multiple of 32 bits}}

>From 252d8a75a073d1ece7bdf07feab29b0b6b360a5a Mon Sep 17 00:00:00 2001
From: Arjun Bhamra <arjun.bhamra25 at gmail.com>
Date: Wed, 11 Mar 2026 20:54:08 -0400
Subject: [PATCH 8/8] re-added failed conversion check, uArch check, etc.

---
 .../Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp    | 35 +++++++++++++++++++
 .../XeGPUToXeVM/failed_conversion.mlir        | 14 ++++++++
 2 files changed, 49 insertions(+)
 create mode 100644 mlir/test/Conversion/XeGPUToXeVM/failed_conversion.mlir

diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 6df209438447b..6dc9d0a55a308 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -23,6 +23,7 @@
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
 #include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
+#include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/STLExtras.h"
@@ -902,6 +903,40 @@ class DpasToXeVMPattern : public OpConversionPattern<xegpu::DpasOp> {
     auto bTy = cast<VectorType>(op.getRhs().getType());
     auto resultType = cast<VectorType>(op.getResultType());
 
+    // Look up the target uArch's DPAS instruction to validate element types.
+    auto chipStr = xegpu::getChipStr(op);
+    if (!chipStr)
+      return rewriter.notifyMatchFailure(op, "cannot determine target chip");
+
+    const auto *uArch = mlir::xegpu::uArch::getUArch(*chipStr);
+    if (!uArch)
+      return rewriter.notifyMatchFailure(op, "unsupported target uArch");
+
+    auto *dpasInst = const_cast<xegpu::uArch::SubgroupMatrixMultiplyAcc *>(
+        llvm::dyn_cast_or_null<xegpu::uArch::SubgroupMatrixMultiplyAcc>(
+            uArch->getInstruction(
+                xegpu::uArch::InstructionKind::SubgroupMatrixMultiplyAcc)));
+    if (!dpasInst)
+      return rewriter.notifyMatchFailure(op,
+                                         "DPAS not supported by target uArch");
+
+    // Whether the uArch lists vecTy's element type as valid for `kind`.
+    auto checkSupportedTypes = [&](VectorType vecTy,
+                                   xegpu::uArch::MMAOpndKind kind) -> bool {
+      auto supported = dpasInst->getSupportedTypes(*ctxt, kind);
+      return llvm::is_contained(supported, vecTy.getElementType());
+    };
+    if (!checkSupportedTypes(aTy, xegpu::uArch::MMAOpndKind::MatrixA))
+      return rewriter.notifyMatchFailure(
+          op, "A-matrix element type not supported by target uArch");
+    if (!checkSupportedTypes(bTy, xegpu::uArch::MMAOpndKind::MatrixB))
+      return rewriter.notifyMatchFailure(
+          op, "B-matrix element type not supported by target uArch");
+    // MatrixC and MatrixD share supported types, and acc type == result type.
+    if (!checkSupportedTypes(resultType, xegpu::uArch::MMAOpndKind::MatrixD))
+      return rewriter.notifyMatchFailure(
+          op, "result/accumulator element type not supported by target uArch");
+
     auto encodePrecision = [&](Type type) -> xevm::ElemType {
       if (type == rewriter.getBF16Type())
         return xevm::ElemType::BF16;
diff --git a/mlir/test/Conversion/XeGPUToXeVM/failed_conversion.mlir b/mlir/test/Conversion/XeGPUToXeVM/failed_conversion.mlir
new file mode 100644
index 0000000000000..95211dcff250c
--- /dev/null
+++ b/mlir/test/Conversion/XeGPUToXeVM/failed_conversion.mlir
@@ -0,0 +1,14 @@
+// RUN: mlir-opt --convert-xegpu-to-xevm %s -split-input-file -verify-diagnostics
+
+// Verify that xegpu.dpas with unsupported element types (i16) is rejected
+// during XeGPUToXeVM conversion rather than crashing.
+
+gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
+  func.func @main() {
+    %0 = arith.constant dense<0> : vector<4xi16>
+    %1 = arith.constant dense<0> : vector<4xi32>
+    // expected-error@+1 {{failed to legalize operation 'xegpu.dpas' that was explicitly marked illegal}}
+    %2 = xegpu.dpas %0, %0, %1 : vector<4xi16>, vector<4xi16>, vector<4xi32> -> vector<4xi32>
+    return
+  }
+}



More information about the Mlir-commits mailing list