[Mlir-commits] [mlir] [MLIR][XeGPU] Validate DPAS operand types against uArch in XeGPUToXeVM conversion (PR #185081)
Arjun Bhamra
llvmlistbot at llvm.org
Fri Mar 6 11:16:51 PST 2026
https://github.com/abhamra updated https://github.com/llvm/llvm-project/pull/185081
>From cea8e96e5bd900c8bef445de47c481514c64e04c Mon Sep 17 00:00:00 2001
From: Arjun Bhamra <arjun.bhamra25 at gmail.com>
Date: Wed, 4 Mar 2026 20:56:26 -0500
Subject: [PATCH 1/5] update xegpu types, remove verifier check
---
mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td | 8 ++++++--
mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 14 ++++++++------
2 files changed, 14 insertions(+), 8 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index c50bd25df2742..5e10aa2981524 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -21,8 +21,12 @@ def XeGPU_ScalarType: AnyTypeOf<[XeGPU_IntType, XeGPU_FloatType]>;
def XeGPU_PointerType : AnyTypeOf<[UI64, UI32, I64, I32]>;
def XeGPU_BaseAddrType
: AnyTypeOf<[Non0RankedMemRefOf<[XeGPU_ScalarType]>, XeGPU_PointerType]>;
-def XeGPU_DpasOprType: FixedVectorOfRankAndType<[1, 2, 3], [XeGPU_ScalarType]>;
-def XeGPU_DpasResType: FixedVectorOfRankAndType<[1, 2], [XeGPU_ScalarType]>;
+def XeGPU_DpasOprElemType : AnyTypeOf<[AnyI1, I<4>, AnyI8, AnyI16,
+ F4E2M1FN, F8E4M3FN, F8E5M2, F8E8M0FNU, F16, BF16]>;
+def XeGPU_DpasResElemType : AnyTypeOf<[XeGPU_DpasOprElemType, AnyI32, TF32, F32]>;
+
+def XeGPU_DpasOprType : FixedVectorOfRankAndType<[1, 2, 3], [XeGPU_DpasOprElemType]>;
+def XeGPU_DpasResType : FixedVectorOfRankAndType<[1, 2], [XeGPU_DpasResElemType]>;
def XeGPU_OffsetType: FixedVectorOfNonZeroRankOf<[Index]>;
def XeGPU_MaskType: FixedVectorOfNonZeroRankOf<[I1]>;
def XeGPU_ValueType: FixedVectorOfNonZeroRankOf<[XeGPU_ScalarType]>;
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 91ba07a8e0256..81df3159a36b1 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -1045,11 +1045,13 @@ LogicalResult UpdateOffsetOp::verify() {
// XeGPU_DpasOp
//===----------------------------------------------------------------------===//
LogicalResult DpasOp::verify() {
- int64_t lhsRank = getLhsType().getRank();
- int64_t rhsRank = getRhsType().getRank();
+ auto lhsType = getLhsType();
+ auto rhsType = getRhsType();
+ int64_t lhsRank = lhsType.getRank();
+ int64_t rhsRank = rhsType.getRank();
int64_t resRank = getResultType().getRank();
- auto lhsShape = getLhsType().getShape();
- auto rhsShape = getRhsType().getShape();
+ auto lhsShape = lhsType.getShape();
+ auto rhsShape = rhsType.getShape();
auto resShape = getResultType().getShape();
if (getAcc() && getAcc().getType() != getResultType())
@@ -1059,8 +1061,8 @@ LogicalResult DpasOp::verify() {
// It skips the semantic check since lack of architecture information.
// Users need to ensure the correctness.
if (lhsRank == 1 && rhsRank == 1 && resRank == 1) {
- auto numElems = getRhsType().getNumElements();
- auto elemTy = getRhsType().getElementType();
+ auto numElems = rhsType.getNumElements();
+ auto elemTy = rhsType.getElementType();
auto factor = 32 / elemTy.getIntOrFloatBitWidth();
if (numElems % factor != 0)
return emitOpError("Expecting B operand to be a multiple of 32 bits.");
>From 2410439b69b4cf85b1dd303725fa7839f56c5ebd Mon Sep 17 00:00:00 2001
From: Arjun Bhamra <arjun.bhamra25 at gmail.com>
Date: Fri, 6 Mar 2026 13:45:53 -0500
Subject: [PATCH 2/5] dpas fix in lowering via uarch checks and
notifymatcherror
---
.../mlir/Dialect/XeGPU/IR/XeGPUTypes.td | 8 ++---
.../Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp | 34 +++++++++++++++++++
mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 14 ++++----
mlir/test/Conversion/XeGPUToXeVM/dpas.mlir | 2 +-
4 files changed, 43 insertions(+), 15 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index 5e10aa2981524..8286566ddabee 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -21,12 +21,8 @@ def XeGPU_ScalarType: AnyTypeOf<[XeGPU_IntType, XeGPU_FloatType]>;
def XeGPU_PointerType : AnyTypeOf<[UI64, UI32, I64, I32]>;
def XeGPU_BaseAddrType
: AnyTypeOf<[Non0RankedMemRefOf<[XeGPU_ScalarType]>, XeGPU_PointerType]>;
-def XeGPU_DpasOprElemType : AnyTypeOf<[AnyI1, I<4>, AnyI8, AnyI16,
- F4E2M1FN, F8E4M3FN, F8E5M2, F8E8M0FNU, F16, BF16]>;
-def XeGPU_DpasResElemType : AnyTypeOf<[XeGPU_DpasOprElemType, AnyI32, TF32, F32]>;
-
-def XeGPU_DpasOprType : FixedVectorOfRankAndType<[1, 2, 3], [XeGPU_DpasOprElemType]>;
-def XeGPU_DpasResType : FixedVectorOfRankAndType<[1, 2], [XeGPU_DpasResElemType]>;
+def XeGPU_DpasOprType : FixedVectorOfRankAndType<[1, 2, 3], [XeGPU_ScalarType]>;
+def XeGPU_DpasResType : FixedVectorOfRankAndType<[1, 2], [XeGPU_ScalarType]>;
def XeGPU_OffsetType: FixedVectorOfNonZeroRankOf<[Index]>;
def XeGPU_MaskType: FixedVectorOfNonZeroRankOf<[I1]>;
def XeGPU_ValueType: FixedVectorOfNonZeroRankOf<[XeGPU_ScalarType]>;
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 8a06271eadd84..d928e24b98493 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -22,6 +22,7 @@
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
@@ -902,6 +903,39 @@ class DpasToXeVMPattern : public OpConversionPattern<xegpu::DpasOp> {
auto bTy = cast<VectorType>(op.getRhs().getType());
auto resultType = cast<VectorType>(op.getResultType());
+ // get the correct dpasInst by getting info from chip
+ auto chipStr = xegpu::getChipStr(op);
+ if (!chipStr)
+ return rewriter.notifyMatchFailure(op, "cannot determine target chip");
+
+ const auto *uArch= mlir::xegpu::uArch::getUArch(*chipStr);
+ if (!uArch)
+ return rewriter.notifyMatchFailure(op, "unsupported target uArch");
+
+ auto *dpasInst = const_cast<xegpu::uArch::SubgroupMatrixMultiplyAcc*>(
+ llvm::dyn_cast_or_null<xegpu::uArch::SubgroupMatrixMultiplyAcc>(
+ uArch->getInstruction(
+ xegpu::uArch::InstructionKind::SubgroupMatrixMultiplyAcc)));
+ if (!dpasInst)
+ return rewriter.notifyMatchFailure(op, "DPAS not supported by target uArch");
+
+ auto supportedA = dpasInst->getSupportedTypes(*ctxt, xegpu::uArch::MMAOpndKind::MatrixA);
+ auto supportedB = dpasInst->getSupportedTypes(*ctxt, xegpu::uArch::MMAOpndKind::MatrixB);
+ auto supportedD = dpasInst->getSupportedTypes(*ctxt, xegpu::uArch::MMAOpndKind::MatrixD);
+ // NOTE: Supported types for MatrixC and MatrixD are identical
+
+ if (llvm::find(supportedA, aTy.getElementType()) == supportedA.end())
+ return rewriter.notifyMatchFailure(
+ op, "A-matrix element type not supported by target uArch");
+
+ if (llvm::find(supportedB, bTy.getElementType()) == supportedB.end())
+ return rewriter.notifyMatchFailure(
+ op, "B-matrix element type not supported by target uArch");
+
+ if (llvm::find(supportedD, resultType.getElementType()) == supportedD.end())
+ return rewriter.notifyMatchFailure(
+ op, "result/accumulator element type not supported by target uArch");
+
auto encodePrecision = [&](Type type) -> xevm::ElemType {
if (type == rewriter.getBF16Type())
return xevm::ElemType::BF16;
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 81df3159a36b1..91ba07a8e0256 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -1045,13 +1045,11 @@ LogicalResult UpdateOffsetOp::verify() {
// XeGPU_DpasOp
//===----------------------------------------------------------------------===//
LogicalResult DpasOp::verify() {
- auto lhsType = getLhsType();
- auto rhsType = getRhsType();
- int64_t lhsRank = lhsType.getRank();
- int64_t rhsRank = rhsType.getRank();
+ int64_t lhsRank = getLhsType().getRank();
+ int64_t rhsRank = getRhsType().getRank();
int64_t resRank = getResultType().getRank();
- auto lhsShape = lhsType.getShape();
- auto rhsShape = rhsType.getShape();
+ auto lhsShape = getLhsType().getShape();
+ auto rhsShape = getRhsType().getShape();
auto resShape = getResultType().getShape();
if (getAcc() && getAcc().getType() != getResultType())
@@ -1061,8 +1059,8 @@ LogicalResult DpasOp::verify() {
// It skips the semantic check since lack of architecture information.
// Users need to ensure the correctness.
if (lhsRank == 1 && rhsRank == 1 && resRank == 1) {
- auto numElems = rhsType.getNumElements();
- auto elemTy = rhsType.getElementType();
+ auto numElems = getRhsType().getNumElements();
+ auto elemTy = getRhsType().getElementType();
auto factor = 32 / elemTy.getIntOrFloatBitWidth();
if (numElems % factor != 0)
return emitOpError("Expecting B operand to be a multiple of 32 bits.");
diff --git a/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir b/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir
index a9ab0be00722c..7cc59f4cfdcdf 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir
@@ -1,6 +1,6 @@
// RUN: mlir-opt -convert-xegpu-to-xevm %s | FileCheck %s
-gpu.module @test_kernel {
+gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
// CHECK-LABEL: func.func @dpas(
// CHECK-SAME: %[[ARG0:.*]]: vector<8xf16>, %[[ARG1:.*]]: vector<16xf16>, %[[ARG2:.*]]: vector<8xf32>
func.func @dpas(%a_loaded: vector<8xf16>, %b_loaded: vector<16xf16>, %c_loaded: vector<8xf32>) -> vector<8xf32> {
>From ef9b386ecd84a7cabd1012cd96590a1eebfdcaed Mon Sep 17 00:00:00 2001
From: Arjun Bhamra <arjun.bhamra25 at gmail.com>
Date: Fri, 6 Mar 2026 13:56:51 -0500
Subject: [PATCH 3/5] minor indentation mistake
---
mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index 8286566ddabee..c50bd25df2742 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -21,8 +21,8 @@ def XeGPU_ScalarType: AnyTypeOf<[XeGPU_IntType, XeGPU_FloatType]>;
def XeGPU_PointerType : AnyTypeOf<[UI64, UI32, I64, I32]>;
def XeGPU_BaseAddrType
: AnyTypeOf<[Non0RankedMemRefOf<[XeGPU_ScalarType]>, XeGPU_PointerType]>;
-def XeGPU_DpasOprType : FixedVectorOfRankAndType<[1, 2, 3], [XeGPU_ScalarType]>;
-def XeGPU_DpasResType : FixedVectorOfRankAndType<[1, 2], [XeGPU_ScalarType]>;
+def XeGPU_DpasOprType: FixedVectorOfRankAndType<[1, 2, 3], [XeGPU_ScalarType]>;
+def XeGPU_DpasResType: FixedVectorOfRankAndType<[1, 2], [XeGPU_ScalarType]>;
def XeGPU_OffsetType: FixedVectorOfNonZeroRankOf<[Index]>;
def XeGPU_MaskType: FixedVectorOfNonZeroRankOf<[I1]>;
def XeGPU_ValueType: FixedVectorOfNonZeroRankOf<[XeGPU_ScalarType]>;
>From 546a34b09596a52ca822e74beaf52e5b6520ed9d Mon Sep 17 00:00:00 2001
From: Arjun Bhamra <arjun.bhamra25 at gmail.com>
Date: Fri, 6 Mar 2026 14:03:54 -0500
Subject: [PATCH 4/5] failed conversion test
---
.../Conversion/XeGPUToXeVM/failed_conversion.mlir | 14 ++++++++++++++
1 file changed, 14 insertions(+)
create mode 100644 mlir/test/Conversion/XeGPUToXeVM/failed_conversion.mlir
diff --git a/mlir/test/Conversion/XeGPUToXeVM/failed_conversion.mlir b/mlir/test/Conversion/XeGPUToXeVM/failed_conversion.mlir
new file mode 100644
index 0000000000000..c4088c430a4d9
--- /dev/null
+++ b/mlir/test/Conversion/XeGPUToXeVM/failed_conversion.mlir
@@ -0,0 +1,14 @@
+// RUN: mlir-opt --convert-xegpu-to-xevm %s -split-input-file -verify-diagnostics
+
+// Verify that xegpu.dpas with unsupported element types (i16) is rejected
+// during XeGPUToXeVM conversion rather than crashing.
+
+gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
+ func.func @main() {
+ %0 = spirv.Constant dense<0> : vector<4xi16>
+ %1 = spirv.Constant dense<0> : vector<4xi32>
+ // expected-error @+1 {{failed to legalize operation 'xegpu.dpas' that was explicitly marked illegal}}
+ %2 = xegpu.dpas %0, %0, %1 : vector<4xi16>, vector<4xi16>, vector<4xi32> -> vector<4xi32>
+ return
+ }
+}
>From 1e4b0ea652453bd2c8ba74ce1b050c216e3deb8f Mon Sep 17 00:00:00 2001
From: Arjun Bhamra <arjun.bhamra25 at gmail.com>
Date: Fri, 6 Mar 2026 14:16:15 -0500
Subject: [PATCH 5/5] formatter issues
---
.../lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp | 18 +++++++++++-------
1 file changed, 11 insertions(+), 7 deletions(-)
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index d928e24b98493..94430e1068c99 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -908,20 +908,24 @@ class DpasToXeVMPattern : public OpConversionPattern<xegpu::DpasOp> {
if (!chipStr)
return rewriter.notifyMatchFailure(op, "cannot determine target chip");
- const auto *uArch= mlir::xegpu::uArch::getUArch(*chipStr);
+ const auto *uArch = mlir::xegpu::uArch::getUArch(*chipStr);
if (!uArch)
return rewriter.notifyMatchFailure(op, "unsupported target uArch");
- auto *dpasInst = const_cast<xegpu::uArch::SubgroupMatrixMultiplyAcc*>(
+ auto *dpasInst = const_cast<xegpu::uArch::SubgroupMatrixMultiplyAcc *>(
llvm::dyn_cast_or_null<xegpu::uArch::SubgroupMatrixMultiplyAcc>(
uArch->getInstruction(
xegpu::uArch::InstructionKind::SubgroupMatrixMultiplyAcc)));
if (!dpasInst)
- return rewriter.notifyMatchFailure(op, "DPAS not supported by target uArch");
-
- auto supportedA = dpasInst->getSupportedTypes(*ctxt, xegpu::uArch::MMAOpndKind::MatrixA);
- auto supportedB = dpasInst->getSupportedTypes(*ctxt, xegpu::uArch::MMAOpndKind::MatrixB);
- auto supportedD = dpasInst->getSupportedTypes(*ctxt, xegpu::uArch::MMAOpndKind::MatrixD);
+ return rewriter.notifyMatchFailure(op,
+ "DPAS not supported by target uArch");
+
+ auto supportedA =
+ dpasInst->getSupportedTypes(*ctxt, xegpu::uArch::MMAOpndKind::MatrixA);
+ auto supportedB =
+ dpasInst->getSupportedTypes(*ctxt, xegpu::uArch::MMAOpndKind::MatrixB);
+ auto supportedD =
+ dpasInst->getSupportedTypes(*ctxt, xegpu::uArch::MMAOpndKind::MatrixD);
// NOTE: Supported types for MatrixC and MatrixD are identical
if (llvm::find(supportedA, aTy.getElementType()) == supportedA.end())
More information about the Mlir-commits
mailing list