[Mlir-commits] [mlir] [MLIR][XeGPU] Validate DPAS operand types against uArch in XeGPUToXeVM conversion (PR #185081)

Arjun Bhamra llvmlistbot at llvm.org
Fri Mar 6 11:16:51 PST 2026


https://github.com/abhamra updated https://github.com/llvm/llvm-project/pull/185081

>From cea8e96e5bd900c8bef445de47c481514c64e04c Mon Sep 17 00:00:00 2001
From: Arjun Bhamra <arjun.bhamra25 at gmail.com>
Date: Wed, 4 Mar 2026 20:56:26 -0500
Subject: [PATCH 1/5] Restrict XeGPU DPAS element types and refactor DpasOp verifier

---
 mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td |  8 ++++++--
 mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp           | 14 ++++++++------
 2 files changed, 14 insertions(+), 8 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index c50bd25df2742..5e10aa2981524 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -21,8 +21,12 @@ def XeGPU_ScalarType: AnyTypeOf<[XeGPU_IntType, XeGPU_FloatType]>;
 def XeGPU_PointerType : AnyTypeOf<[UI64, UI32, I64, I32]>;
 def XeGPU_BaseAddrType
     : AnyTypeOf<[Non0RankedMemRefOf<[XeGPU_ScalarType]>, XeGPU_PointerType]>;
-def XeGPU_DpasOprType: FixedVectorOfRankAndType<[1, 2, 3], [XeGPU_ScalarType]>;
-def XeGPU_DpasResType: FixedVectorOfRankAndType<[1, 2], [XeGPU_ScalarType]>;
+def XeGPU_DpasOprElemType : AnyTypeOf<[AnyI1, I<4>, AnyI8, AnyI16,
+                                       F4E2M1FN, F8E4M3FN, F8E5M2, F8E8M0FNU, F16, BF16]>;
+def XeGPU_DpasResElemType : AnyTypeOf<[XeGPU_DpasOprElemType, AnyI32, TF32, F32]>;
+
+def XeGPU_DpasOprType : FixedVectorOfRankAndType<[1, 2, 3], [XeGPU_DpasOprElemType]>;
+def XeGPU_DpasResType : FixedVectorOfRankAndType<[1, 2], [XeGPU_DpasResElemType]>;
 def XeGPU_OffsetType: FixedVectorOfNonZeroRankOf<[Index]>;
 def XeGPU_MaskType: FixedVectorOfNonZeroRankOf<[I1]>;
 def XeGPU_ValueType: FixedVectorOfNonZeroRankOf<[XeGPU_ScalarType]>;
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 91ba07a8e0256..81df3159a36b1 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -1045,11 +1045,13 @@ LogicalResult UpdateOffsetOp::verify() {
 // XeGPU_DpasOp
 //===----------------------------------------------------------------------===//
 LogicalResult DpasOp::verify() {
-  int64_t lhsRank = getLhsType().getRank();
-  int64_t rhsRank = getRhsType().getRank();
+  auto lhsType = getLhsType();
+  auto rhsType = getRhsType();
+  int64_t lhsRank = lhsType.getRank();
+  int64_t rhsRank = rhsType.getRank();
   int64_t resRank = getResultType().getRank();
-  auto lhsShape = getLhsType().getShape();
-  auto rhsShape = getRhsType().getShape();
+  auto lhsShape = lhsType.getShape();
+  auto rhsShape = rhsType.getShape();
   auto resShape = getResultType().getShape();
 
   if (getAcc() && getAcc().getType() != getResultType())
@@ -1059,8 +1061,8 @@ LogicalResult DpasOp::verify() {
   // It skips the semantic check since lack of architecture information.
   // Users need to ensure the correctness.
   if (lhsRank == 1 && rhsRank == 1 && resRank == 1) {
-    auto numElems = getRhsType().getNumElements();
-    auto elemTy = getRhsType().getElementType();
+    auto numElems = rhsType.getNumElements();
+    auto elemTy = rhsType.getElementType();
     auto factor = 32 / elemTy.getIntOrFloatBitWidth();
     if (numElems % factor != 0)
       return emitOpError("Expecting B operand to be a multiple of 32 bits.");

>From 2410439b69b4cf85b1dd303725fa7839f56c5ebd Mon Sep 17 00:00:00 2001
From: Arjun Bhamra <arjun.bhamra25 at gmail.com>
Date: Fri, 6 Mar 2026 13:45:53 -0500
Subject: [PATCH 2/5] Validate DPAS element types against uArch in lowering, using notifyMatchFailure

---
 .../mlir/Dialect/XeGPU/IR/XeGPUTypes.td       |  8 ++---
 .../Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp    | 34 +++++++++++++++++++
 mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp        | 14 ++++----
 mlir/test/Conversion/XeGPUToXeVM/dpas.mlir    |  2 +-
 4 files changed, 43 insertions(+), 15 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index 5e10aa2981524..8286566ddabee 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -21,12 +21,8 @@ def XeGPU_ScalarType: AnyTypeOf<[XeGPU_IntType, XeGPU_FloatType]>;
 def XeGPU_PointerType : AnyTypeOf<[UI64, UI32, I64, I32]>;
 def XeGPU_BaseAddrType
     : AnyTypeOf<[Non0RankedMemRefOf<[XeGPU_ScalarType]>, XeGPU_PointerType]>;
-def XeGPU_DpasOprElemType : AnyTypeOf<[AnyI1, I<4>, AnyI8, AnyI16,
-                                       F4E2M1FN, F8E4M3FN, F8E5M2, F8E8M0FNU, F16, BF16]>;
-def XeGPU_DpasResElemType : AnyTypeOf<[XeGPU_DpasOprElemType, AnyI32, TF32, F32]>;
-
-def XeGPU_DpasOprType : FixedVectorOfRankAndType<[1, 2, 3], [XeGPU_DpasOprElemType]>;
-def XeGPU_DpasResType : FixedVectorOfRankAndType<[1, 2], [XeGPU_DpasResElemType]>;
+def XeGPU_DpasOprType : FixedVectorOfRankAndType<[1, 2, 3], [XeGPU_ScalarType]>;
+def XeGPU_DpasResType : FixedVectorOfRankAndType<[1, 2], [XeGPU_ScalarType]>;
 def XeGPU_OffsetType: FixedVectorOfNonZeroRankOf<[Index]>;
 def XeGPU_MaskType: FixedVectorOfNonZeroRankOf<[I1]>;
 def XeGPU_ValueType: FixedVectorOfNonZeroRankOf<[XeGPU_ScalarType]>;
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 8a06271eadd84..d928e24b98493 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -22,6 +22,7 @@
 #include "mlir/Dialect/SCF/Transforms/Patterns.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h"
 #include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Support/LLVM.h"
@@ -902,6 +903,39 @@ class DpasToXeVMPattern : public OpConversionPattern<xegpu::DpasOp> {
     auto bTy = cast<VectorType>(op.getRhs().getType());
     auto resultType = cast<VectorType>(op.getResultType());
 
+    // get the correct dpasInst by getting info from chip
+    auto chipStr = xegpu::getChipStr(op);
+    if (!chipStr)
+      return rewriter.notifyMatchFailure(op, "cannot determine target chip");
+
+    const auto *uArch= mlir::xegpu::uArch::getUArch(*chipStr);
+    if (!uArch)
+      return rewriter.notifyMatchFailure(op, "unsupported target uArch");
+
+    auto *dpasInst = const_cast<xegpu::uArch::SubgroupMatrixMultiplyAcc*>(
+        llvm::dyn_cast_or_null<xegpu::uArch::SubgroupMatrixMultiplyAcc>(
+            uArch->getInstruction(
+                xegpu::uArch::InstructionKind::SubgroupMatrixMultiplyAcc)));
+    if (!dpasInst)
+      return rewriter.notifyMatchFailure(op, "DPAS not supported by target uArch");
+
+    auto supportedA = dpasInst->getSupportedTypes(*ctxt, xegpu::uArch::MMAOpndKind::MatrixA);
+    auto supportedB = dpasInst->getSupportedTypes(*ctxt, xegpu::uArch::MMAOpndKind::MatrixB);
+    auto supportedD = dpasInst->getSupportedTypes(*ctxt, xegpu::uArch::MMAOpndKind::MatrixD);
+    // NOTE: Supported types for MatrixC and MatrixD are identical
+
+    if (llvm::find(supportedA, aTy.getElementType()) == supportedA.end())
+      return rewriter.notifyMatchFailure(
+          op, "A-matrix element type not supported by target uArch");
+
+    if (llvm::find(supportedB, bTy.getElementType()) == supportedB.end())
+      return rewriter.notifyMatchFailure(
+          op, "B-matrix element type not supported by target uArch");
+
+    if (llvm::find(supportedD, resultType.getElementType()) == supportedD.end())
+      return rewriter.notifyMatchFailure(
+          op, "result/accumulator element type not supported by target uArch");
+
     auto encodePrecision = [&](Type type) -> xevm::ElemType {
       if (type == rewriter.getBF16Type())
         return xevm::ElemType::BF16;
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 81df3159a36b1..91ba07a8e0256 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -1045,13 +1045,11 @@ LogicalResult UpdateOffsetOp::verify() {
 // XeGPU_DpasOp
 //===----------------------------------------------------------------------===//
 LogicalResult DpasOp::verify() {
-  auto lhsType = getLhsType();
-  auto rhsType = getRhsType();
-  int64_t lhsRank = lhsType.getRank();
-  int64_t rhsRank = rhsType.getRank();
+  int64_t lhsRank = getLhsType().getRank();
+  int64_t rhsRank = getRhsType().getRank();
   int64_t resRank = getResultType().getRank();
-  auto lhsShape = lhsType.getShape();
-  auto rhsShape = rhsType.getShape();
+  auto lhsShape = getLhsType().getShape();
+  auto rhsShape = getRhsType().getShape();
   auto resShape = getResultType().getShape();
 
   if (getAcc() && getAcc().getType() != getResultType())
@@ -1061,8 +1059,8 @@ LogicalResult DpasOp::verify() {
   // It skips the semantic check since lack of architecture information.
   // Users need to ensure the correctness.
   if (lhsRank == 1 && rhsRank == 1 && resRank == 1) {
-    auto numElems = rhsType.getNumElements();
-    auto elemTy = rhsType.getElementType();
+    auto numElems = getRhsType().getNumElements();
+    auto elemTy = getRhsType().getElementType();
     auto factor = 32 / elemTy.getIntOrFloatBitWidth();
     if (numElems % factor != 0)
       return emitOpError("Expecting B operand to be a multiple of 32 bits.");
diff --git a/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir b/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir
index a9ab0be00722c..7cc59f4cfdcdf 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir
@@ -1,6 +1,6 @@
 // RUN: mlir-opt -convert-xegpu-to-xevm %s | FileCheck %s
 
-gpu.module @test_kernel {
+gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
     // CHECK-LABEL: func.func @dpas(
     // CHECK-SAME: %[[ARG0:.*]]: vector<8xf16>, %[[ARG1:.*]]: vector<16xf16>, %[[ARG2:.*]]: vector<8xf32>
     func.func @dpas(%a_loaded: vector<8xf16>, %b_loaded: vector<16xf16>, %c_loaded: vector<8xf32>) -> vector<8xf32> {

>From ef9b386ecd84a7cabd1012cd96590a1eebfdcaed Mon Sep 17 00:00:00 2001
From: Arjun Bhamra <arjun.bhamra25 at gmail.com>
Date: Fri, 6 Mar 2026 13:56:51 -0500
Subject: [PATCH 3/5] minor indentation mistake

---
 mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index 8286566ddabee..c50bd25df2742 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -21,8 +21,8 @@ def XeGPU_ScalarType: AnyTypeOf<[XeGPU_IntType, XeGPU_FloatType]>;
 def XeGPU_PointerType : AnyTypeOf<[UI64, UI32, I64, I32]>;
 def XeGPU_BaseAddrType
     : AnyTypeOf<[Non0RankedMemRefOf<[XeGPU_ScalarType]>, XeGPU_PointerType]>;
-def XeGPU_DpasOprType : FixedVectorOfRankAndType<[1, 2, 3], [XeGPU_ScalarType]>;
-def XeGPU_DpasResType : FixedVectorOfRankAndType<[1, 2], [XeGPU_ScalarType]>;
+def XeGPU_DpasOprType: FixedVectorOfRankAndType<[1, 2, 3], [XeGPU_ScalarType]>;
+def XeGPU_DpasResType: FixedVectorOfRankAndType<[1, 2], [XeGPU_ScalarType]>;
 def XeGPU_OffsetType: FixedVectorOfNonZeroRankOf<[Index]>;
 def XeGPU_MaskType: FixedVectorOfNonZeroRankOf<[I1]>;
 def XeGPU_ValueType: FixedVectorOfNonZeroRankOf<[XeGPU_ScalarType]>;

>From 546a34b09596a52ca822e74beaf52e5b6520ed9d Mon Sep 17 00:00:00 2001
From: Arjun Bhamra <arjun.bhamra25 at gmail.com>
Date: Fri, 6 Mar 2026 14:03:54 -0500
Subject: [PATCH 4/5] Add test for failed DPAS conversion with unsupported element types

---
 .../Conversion/XeGPUToXeVM/failed_conversion.mlir  | 14 ++++++++++++++
 1 file changed, 14 insertions(+)
 create mode 100644 mlir/test/Conversion/XeGPUToXeVM/failed_conversion.mlir

diff --git a/mlir/test/Conversion/XeGPUToXeVM/failed_conversion.mlir b/mlir/test/Conversion/XeGPUToXeVM/failed_conversion.mlir
new file mode 100644
index 0000000000000..c4088c430a4d9
--- /dev/null
+++ b/mlir/test/Conversion/XeGPUToXeVM/failed_conversion.mlir
@@ -0,0 +1,14 @@
+// RUN: mlir-opt --convert-xegpu-to-xevm %s -split-input-file -verify-diagnostics
+
+// Verify that xegpu.dpas with unsupported element types (i16) is rejected
+// during XeGPUToXeVM conversion rather than crashing.
+
+gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
+  func.func @main() {
+    %0 = spirv.Constant dense<0> : vector<4xi16>
+    %1 = spirv.Constant dense<0> : vector<4xi32>
+    // expected-error at +1 {{failed to legalize operation 'xegpu.dpas' that was explicitly marked illegal}}
+    %2 = xegpu.dpas %0, %0, %1 : vector<4xi16>, vector<4xi16>, vector<4xi32> -> vector<4xi32>
+    return
+  }
+}

>From 1e4b0ea652453bd2c8ba74ce1b050c216e3deb8f Mon Sep 17 00:00:00 2001
From: Arjun Bhamra <arjun.bhamra25 at gmail.com>
Date: Fri, 6 Mar 2026 14:16:15 -0500
Subject: [PATCH 5/5] Fix clang-format issues in DpasToXeVMPattern

---
 .../lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp | 18 +++++++++++-------
 1 file changed, 11 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index d928e24b98493..94430e1068c99 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -908,20 +908,24 @@ class DpasToXeVMPattern : public OpConversionPattern<xegpu::DpasOp> {
     if (!chipStr)
       return rewriter.notifyMatchFailure(op, "cannot determine target chip");
 
-    const auto *uArch= mlir::xegpu::uArch::getUArch(*chipStr);
+    const auto *uArch = mlir::xegpu::uArch::getUArch(*chipStr);
     if (!uArch)
       return rewriter.notifyMatchFailure(op, "unsupported target uArch");
 
-    auto *dpasInst = const_cast<xegpu::uArch::SubgroupMatrixMultiplyAcc*>(
+    auto *dpasInst = const_cast<xegpu::uArch::SubgroupMatrixMultiplyAcc *>(
         llvm::dyn_cast_or_null<xegpu::uArch::SubgroupMatrixMultiplyAcc>(
             uArch->getInstruction(
                 xegpu::uArch::InstructionKind::SubgroupMatrixMultiplyAcc)));
     if (!dpasInst)
-      return rewriter.notifyMatchFailure(op, "DPAS not supported by target uArch");
-
-    auto supportedA = dpasInst->getSupportedTypes(*ctxt, xegpu::uArch::MMAOpndKind::MatrixA);
-    auto supportedB = dpasInst->getSupportedTypes(*ctxt, xegpu::uArch::MMAOpndKind::MatrixB);
-    auto supportedD = dpasInst->getSupportedTypes(*ctxt, xegpu::uArch::MMAOpndKind::MatrixD);
+      return rewriter.notifyMatchFailure(op,
+                                         "DPAS not supported by target uArch");
+
+    auto supportedA =
+      dpasInst->getSupportedTypes(*ctxt, xegpu::uArch::MMAOpndKind::MatrixA);
+    auto supportedB =
+      dpasInst->getSupportedTypes(*ctxt, xegpu::uArch::MMAOpndKind::MatrixB);
+    auto supportedD =
+      dpasInst->getSupportedTypes(*ctxt, xegpu::uArch::MMAOpndKind::MatrixD);
     // NOTE: Supported types for MatrixC and MatrixD are identical
 
     if (llvm::find(supportedA, aTy.getElementType()) == supportedA.end())



More information about the Mlir-commits mailing list