[Mlir-commits] [mlir] [MLIR][XeGPU] Validate DPAS operand types against uArch in XeGPUToXeVM conversion (PR #185081)
Arjun Bhamra
llvmlistbot at llvm.org
Wed Mar 11 17:54:44 PDT 2026
https://github.com/abhamra updated https://github.com/llvm/llvm-project/pull/185081
>From cea8e96e5bd900c8bef445de47c481514c64e04c Mon Sep 17 00:00:00 2001
From: Arjun Bhamra <arjun.bhamra25 at gmail.com>
Date: Wed, 4 Mar 2026 20:56:26 -0500
Subject: [PATCH 1/8] update xegpu types, remove verifier check
---
mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td | 8 ++++++--
mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 14 ++++++++------
2 files changed, 14 insertions(+), 8 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index c50bd25df2742..5e10aa2981524 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -21,8 +21,12 @@ def XeGPU_ScalarType: AnyTypeOf<[XeGPU_IntType, XeGPU_FloatType]>;
def XeGPU_PointerType : AnyTypeOf<[UI64, UI32, I64, I32]>;
def XeGPU_BaseAddrType
: AnyTypeOf<[Non0RankedMemRefOf<[XeGPU_ScalarType]>, XeGPU_PointerType]>;
-def XeGPU_DpasOprType: FixedVectorOfRankAndType<[1, 2, 3], [XeGPU_ScalarType]>;
-def XeGPU_DpasResType: FixedVectorOfRankAndType<[1, 2], [XeGPU_ScalarType]>;
+def XeGPU_DpasOprElemType : AnyTypeOf<[AnyI1, I<4>, AnyI8, AnyI16,
+ F4E2M1FN, F8E4M3FN, F8E5M2, F8E8M0FNU, F16, BF16]>;
+def XeGPU_DpasResElemType : AnyTypeOf<[XeGPU_DpasOprElemType, AnyI32, TF32, F32]>;
+
+def XeGPU_DpasOprType : FixedVectorOfRankAndType<[1, 2, 3], [XeGPU_DpasOprElemType]>;
+def XeGPU_DpasResType : FixedVectorOfRankAndType<[1, 2], [XeGPU_DpasResElemType]>;
def XeGPU_OffsetType: FixedVectorOfNonZeroRankOf<[Index]>;
def XeGPU_MaskType: FixedVectorOfNonZeroRankOf<[I1]>;
def XeGPU_ValueType: FixedVectorOfNonZeroRankOf<[XeGPU_ScalarType]>;
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 91ba07a8e0256..81df3159a36b1 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -1045,11 +1045,13 @@ LogicalResult UpdateOffsetOp::verify() {
// XeGPU_DpasOp
//===----------------------------------------------------------------------===//
LogicalResult DpasOp::verify() {
- int64_t lhsRank = getLhsType().getRank();
- int64_t rhsRank = getRhsType().getRank();
+ auto lhsType = getLhsType();
+ auto rhsType = getRhsType();
+ int64_t lhsRank = lhsType.getRank();
+ int64_t rhsRank = rhsType.getRank();
int64_t resRank = getResultType().getRank();
- auto lhsShape = getLhsType().getShape();
- auto rhsShape = getRhsType().getShape();
+ auto lhsShape = lhsType.getShape();
+ auto rhsShape = rhsType.getShape();
auto resShape = getResultType().getShape();
if (getAcc() && getAcc().getType() != getResultType())
@@ -1059,8 +1061,8 @@ LogicalResult DpasOp::verify() {
// It skips the semantic check since lack of architecture information.
// Users need to ensure the correctness.
if (lhsRank == 1 && rhsRank == 1 && resRank == 1) {
- auto numElems = getRhsType().getNumElements();
- auto elemTy = getRhsType().getElementType();
+ auto numElems = rhsType.getNumElements();
+ auto elemTy = rhsType.getElementType();
auto factor = 32 / elemTy.getIntOrFloatBitWidth();
if (numElems % factor != 0)
return emitOpError("Expecting B operand to be a multiple of 32 bits.");
>From 2410439b69b4cf85b1dd303725fa7839f56c5ebd Mon Sep 17 00:00:00 2001
From: Arjun Bhamra <arjun.bhamra25 at gmail.com>
Date: Fri, 6 Mar 2026 13:45:53 -0500
Subject: [PATCH 2/8] dpas fix in lowering via uarch checks and
notifymatcherror
---
.../mlir/Dialect/XeGPU/IR/XeGPUTypes.td | 8 ++---
.../Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp | 34 +++++++++++++++++++
mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 14 ++++----
mlir/test/Conversion/XeGPUToXeVM/dpas.mlir | 2 +-
4 files changed, 43 insertions(+), 15 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index 5e10aa2981524..8286566ddabee 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -21,12 +21,8 @@ def XeGPU_ScalarType: AnyTypeOf<[XeGPU_IntType, XeGPU_FloatType]>;
def XeGPU_PointerType : AnyTypeOf<[UI64, UI32, I64, I32]>;
def XeGPU_BaseAddrType
: AnyTypeOf<[Non0RankedMemRefOf<[XeGPU_ScalarType]>, XeGPU_PointerType]>;
-def XeGPU_DpasOprElemType : AnyTypeOf<[AnyI1, I<4>, AnyI8, AnyI16,
- F4E2M1FN, F8E4M3FN, F8E5M2, F8E8M0FNU, F16, BF16]>;
-def XeGPU_DpasResElemType : AnyTypeOf<[XeGPU_DpasOprElemType, AnyI32, TF32, F32]>;
-
-def XeGPU_DpasOprType : FixedVectorOfRankAndType<[1, 2, 3], [XeGPU_DpasOprElemType]>;
-def XeGPU_DpasResType : FixedVectorOfRankAndType<[1, 2], [XeGPU_DpasResElemType]>;
+def XeGPU_DpasOprType : FixedVectorOfRankAndType<[1, 2, 3], [XeGPU_ScalarType]>;
+def XeGPU_DpasResType : FixedVectorOfRankAndType<[1, 2], [XeGPU_ScalarType]>;
def XeGPU_OffsetType: FixedVectorOfNonZeroRankOf<[Index]>;
def XeGPU_MaskType: FixedVectorOfNonZeroRankOf<[I1]>;
def XeGPU_ValueType: FixedVectorOfNonZeroRankOf<[XeGPU_ScalarType]>;
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 8a06271eadd84..d928e24b98493 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -22,6 +22,7 @@
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
@@ -902,6 +903,39 @@ class DpasToXeVMPattern : public OpConversionPattern<xegpu::DpasOp> {
auto bTy = cast<VectorType>(op.getRhs().getType());
auto resultType = cast<VectorType>(op.getResultType());
+ // get the correct dpasInst by getting info from chip
+ auto chipStr = xegpu::getChipStr(op);
+ if (!chipStr)
+ return rewriter.notifyMatchFailure(op, "cannot determine target chip");
+
+ const auto *uArch= mlir::xegpu::uArch::getUArch(*chipStr);
+ if (!uArch)
+ return rewriter.notifyMatchFailure(op, "unsupported target uArch");
+
+ auto *dpasInst = const_cast<xegpu::uArch::SubgroupMatrixMultiplyAcc*>(
+ llvm::dyn_cast_or_null<xegpu::uArch::SubgroupMatrixMultiplyAcc>(
+ uArch->getInstruction(
+ xegpu::uArch::InstructionKind::SubgroupMatrixMultiplyAcc)));
+ if (!dpasInst)
+ return rewriter.notifyMatchFailure(op, "DPAS not supported by target uArch");
+
+ auto supportedA = dpasInst->getSupportedTypes(*ctxt, xegpu::uArch::MMAOpndKind::MatrixA);
+ auto supportedB = dpasInst->getSupportedTypes(*ctxt, xegpu::uArch::MMAOpndKind::MatrixB);
+ auto supportedD = dpasInst->getSupportedTypes(*ctxt, xegpu::uArch::MMAOpndKind::MatrixD);
+ // NOTE: Supported types for MatrixC and MatrixD are identical
+
+ if (llvm::find(supportedA, aTy.getElementType()) == supportedA.end())
+ return rewriter.notifyMatchFailure(
+ op, "A-matrix element type not supported by target uArch");
+
+ if (llvm::find(supportedB, bTy.getElementType()) == supportedB.end())
+ return rewriter.notifyMatchFailure(
+ op, "B-matrix element type not supported by target uArch");
+
+ if (llvm::find(supportedD, resultType.getElementType()) == supportedD.end())
+ return rewriter.notifyMatchFailure(
+ op, "result/accumulator element type not supported by target uArch");
+
auto encodePrecision = [&](Type type) -> xevm::ElemType {
if (type == rewriter.getBF16Type())
return xevm::ElemType::BF16;
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 81df3159a36b1..91ba07a8e0256 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -1045,13 +1045,11 @@ LogicalResult UpdateOffsetOp::verify() {
// XeGPU_DpasOp
//===----------------------------------------------------------------------===//
LogicalResult DpasOp::verify() {
- auto lhsType = getLhsType();
- auto rhsType = getRhsType();
- int64_t lhsRank = lhsType.getRank();
- int64_t rhsRank = rhsType.getRank();
+ int64_t lhsRank = getLhsType().getRank();
+ int64_t rhsRank = getRhsType().getRank();
int64_t resRank = getResultType().getRank();
- auto lhsShape = lhsType.getShape();
- auto rhsShape = rhsType.getShape();
+ auto lhsShape = getLhsType().getShape();
+ auto rhsShape = getRhsType().getShape();
auto resShape = getResultType().getShape();
if (getAcc() && getAcc().getType() != getResultType())
@@ -1061,8 +1059,8 @@ LogicalResult DpasOp::verify() {
// It skips the semantic check since lack of architecture information.
// Users need to ensure the correctness.
if (lhsRank == 1 && rhsRank == 1 && resRank == 1) {
- auto numElems = rhsType.getNumElements();
- auto elemTy = rhsType.getElementType();
+ auto numElems = getRhsType().getNumElements();
+ auto elemTy = getRhsType().getElementType();
auto factor = 32 / elemTy.getIntOrFloatBitWidth();
if (numElems % factor != 0)
return emitOpError("Expecting B operand to be a multiple of 32 bits.");
diff --git a/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir b/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir
index a9ab0be00722c..7cc59f4cfdcdf 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir
@@ -1,6 +1,6 @@
// RUN: mlir-opt -convert-xegpu-to-xevm %s | FileCheck %s
-gpu.module @test_kernel {
+gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
// CHECK-LABEL: func.func @dpas(
// CHECK-SAME: %[[ARG0:.*]]: vector<8xf16>, %[[ARG1:.*]]: vector<16xf16>, %[[ARG2:.*]]: vector<8xf32>
func.func @dpas(%a_loaded: vector<8xf16>, %b_loaded: vector<16xf16>, %c_loaded: vector<8xf32>) -> vector<8xf32> {
>From ef9b386ecd84a7cabd1012cd96590a1eebfdcaed Mon Sep 17 00:00:00 2001
From: Arjun Bhamra <arjun.bhamra25 at gmail.com>
Date: Fri, 6 Mar 2026 13:56:51 -0500
Subject: [PATCH 3/8] minor indentation mistake
---
mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index 8286566ddabee..c50bd25df2742 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -21,8 +21,8 @@ def XeGPU_ScalarType: AnyTypeOf<[XeGPU_IntType, XeGPU_FloatType]>;
def XeGPU_PointerType : AnyTypeOf<[UI64, UI32, I64, I32]>;
def XeGPU_BaseAddrType
: AnyTypeOf<[Non0RankedMemRefOf<[XeGPU_ScalarType]>, XeGPU_PointerType]>;
-def XeGPU_DpasOprType : FixedVectorOfRankAndType<[1, 2, 3], [XeGPU_ScalarType]>;
-def XeGPU_DpasResType : FixedVectorOfRankAndType<[1, 2], [XeGPU_ScalarType]>;
+def XeGPU_DpasOprType: FixedVectorOfRankAndType<[1, 2, 3], [XeGPU_ScalarType]>;
+def XeGPU_DpasResType: FixedVectorOfRankAndType<[1, 2], [XeGPU_ScalarType]>;
def XeGPU_OffsetType: FixedVectorOfNonZeroRankOf<[Index]>;
def XeGPU_MaskType: FixedVectorOfNonZeroRankOf<[I1]>;
def XeGPU_ValueType: FixedVectorOfNonZeroRankOf<[XeGPU_ScalarType]>;
>From 546a34b09596a52ca822e74beaf52e5b6520ed9d Mon Sep 17 00:00:00 2001
From: Arjun Bhamra <arjun.bhamra25 at gmail.com>
Date: Fri, 6 Mar 2026 14:03:54 -0500
Subject: [PATCH 4/8] failed conversion test
---
.../Conversion/XeGPUToXeVM/failed_conversion.mlir | 14 ++++++++++++++
1 file changed, 14 insertions(+)
create mode 100644 mlir/test/Conversion/XeGPUToXeVM/failed_conversion.mlir
diff --git a/mlir/test/Conversion/XeGPUToXeVM/failed_conversion.mlir b/mlir/test/Conversion/XeGPUToXeVM/failed_conversion.mlir
new file mode 100644
index 0000000000000..c4088c430a4d9
--- /dev/null
+++ b/mlir/test/Conversion/XeGPUToXeVM/failed_conversion.mlir
@@ -0,0 +1,14 @@
+// RUN: mlir-opt --convert-xegpu-to-xevm %s -split-input-file -verify-diagnostics
+
+// Verify that xegpu.dpas with unsupported element types (i16) is rejected
+// during XeGPUToXeVM conversion rather than crashing.
+
+gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
+ func.func @main() {
+ %0 = spirv.Constant dense<0> : vector<4xi16>
+ %1 = spirv.Constant dense<0> : vector<4xi32>
+ // expected-error@+1 {{failed to legalize operation 'xegpu.dpas' that was explicitly marked illegal}}
+ %2 = xegpu.dpas %0, %0, %1 : vector<4xi16>, vector<4xi16>, vector<4xi32> -> vector<4xi32>
+ return
+ }
+}
>From 1e4b0ea652453bd2c8ba74ce1b050c216e3deb8f Mon Sep 17 00:00:00 2001
From: Arjun Bhamra <arjun.bhamra25 at gmail.com>
Date: Fri, 6 Mar 2026 14:16:15 -0500
Subject: [PATCH 5/8] formatter issues
---
.../lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp | 18 +++++++++++-------
1 file changed, 11 insertions(+), 7 deletions(-)
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index d928e24b98493..94430e1068c99 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -908,20 +908,24 @@ class DpasToXeVMPattern : public OpConversionPattern<xegpu::DpasOp> {
if (!chipStr)
return rewriter.notifyMatchFailure(op, "cannot determine target chip");
- const auto *uArch= mlir::xegpu::uArch::getUArch(*chipStr);
+ const auto *uArch = mlir::xegpu::uArch::getUArch(*chipStr);
if (!uArch)
return rewriter.notifyMatchFailure(op, "unsupported target uArch");
- auto *dpasInst = const_cast<xegpu::uArch::SubgroupMatrixMultiplyAcc*>(
+ auto *dpasInst = const_cast<xegpu::uArch::SubgroupMatrixMultiplyAcc *>(
llvm::dyn_cast_or_null<xegpu::uArch::SubgroupMatrixMultiplyAcc>(
uArch->getInstruction(
xegpu::uArch::InstructionKind::SubgroupMatrixMultiplyAcc)));
if (!dpasInst)
- return rewriter.notifyMatchFailure(op, "DPAS not supported by target uArch");
-
- auto supportedA = dpasInst->getSupportedTypes(*ctxt, xegpu::uArch::MMAOpndKind::MatrixA);
- auto supportedB = dpasInst->getSupportedTypes(*ctxt, xegpu::uArch::MMAOpndKind::MatrixB);
- auto supportedD = dpasInst->getSupportedTypes(*ctxt, xegpu::uArch::MMAOpndKind::MatrixD);
+ return rewriter.notifyMatchFailure(op,
+ "DPAS not supported by target uArch");
+
+ auto supportedA =
+ dpasInst->getSupportedTypes(*ctxt, xegpu::uArch::MMAOpndKind::MatrixA);
+ auto supportedB =
+ dpasInst->getSupportedTypes(*ctxt, xegpu::uArch::MMAOpndKind::MatrixB);
+ auto supportedD =
+ dpasInst->getSupportedTypes(*ctxt, xegpu::uArch::MMAOpndKind::MatrixD);
// NOTE: Supported types for MatrixC and MatrixD are identical
if (llvm::find(supportedA, aTy.getElementType()) == supportedA.end())
>From 5b5778c6bf017c8e6316a14509fbd19fe07a56e4 Mon Sep 17 00:00:00 2001
From: Arjun Bhamra <arjun.bhamra25 at gmail.com>
Date: Fri, 6 Mar 2026 14:20:07 -0500
Subject: [PATCH 6/8] fix formatting again
---
.../Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp | 126 ++++++------------
1 file changed, 40 insertions(+), 86 deletions(-)
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 2d6f9c927484c..795bc0ad84a4a 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -22,8 +22,8 @@
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
-#include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
+#include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/STLExtras.h"
@@ -921,11 +921,11 @@ class DpasToXeVMPattern : public OpConversionPattern<xegpu::DpasOp> {
"DPAS not supported by target uArch");
auto supportedA =
- dpasInst->getSupportedTypes(*ctxt, xegpu::uArch::MMAOpndKind::MatrixA);
+ dpasInst->getSupportedTypes(*ctxt, xegpu::uArch::MMAOpndKind::MatrixA);
auto supportedB =
- dpasInst->getSupportedTypes(*ctxt, xegpu::uArch::MMAOpndKind::MatrixB);
+ dpasInst->getSupportedTypes(*ctxt, xegpu::uArch::MMAOpndKind::MatrixB);
auto supportedD =
- dpasInst->getSupportedTypes(*ctxt, xegpu::uArch::MMAOpndKind::MatrixD);
+ dpasInst->getSupportedTypes(*ctxt, xegpu::uArch::MMAOpndKind::MatrixD);
// NOTE: Supported types for MatrixC and MatrixD are identical
if (llvm::find(supportedA, aTy.getElementType()) == supportedA.end())
@@ -1094,8 +1094,8 @@ struct ConvertXeGPUToXeVMPass
// If the element type is index, convert it to i64.
if (llvm::isa<IndexType>(elemType))
elemType = IntegerType::get(&getContext(), 64);
- // If the vector rank is 0 or has a single element, return the element
- if (rank == 0 || type.getNumElements() == 1)
+ // If the vector is a scalar or has a single element, return the element
+ if (rank < 1 || type.getNumElements() == 1)
return elemType;
// Otherwise, convert the vector to a flat vector type.
int64_t sum = llvm::product_of(type.getShape());
@@ -1123,12 +1123,9 @@ struct ConvertXeGPUToXeVMPass
// add materialization casts to handle them.
// Materialization to convert memref to i64 or i32 depending on global/SLM
- // Applies only to target materialization.
- // Note: int type to memref materialization is not required as xegpu ops
- // currently do not produce memrefs as result.
- auto memrefToIntMaterializationCast = [](OpBuilder &builder, Type type,
- ValueRange inputs,
- Location loc) -> Value {
+ auto memrefMaterializationCast = [](OpBuilder &builder, Type type,
+ ValueRange inputs,
+ Location loc) -> Value {
if (inputs.size() != 1)
return {};
auto input = inputs.front();
@@ -1187,12 +1184,9 @@ struct ConvertXeGPUToXeVMPass
};
// Materialization to convert ui64 to i64
- // Applies only to target materialization.
- // Note: i64 to ui64 materialization is not required as xegpu ops
- // currently do not produce ui64 as result.
- auto ui64ToI64MaterializationCast = [](OpBuilder &builder, Type type,
- ValueRange inputs,
- Location loc) -> Value {
+ auto ui64MaterializationCast = [](OpBuilder &builder, Type type,
+ ValueRange inputs,
+ Location loc) -> Value {
if (inputs.size() != 1)
return {};
auto input = inputs.front();
@@ -1207,12 +1201,9 @@ struct ConvertXeGPUToXeVMPass
};
// Materialization to convert ui32 to i32
- // Applies only to target materialization.
- // Note: i32 to ui32 materialization is not required as xegpu ops
- // currently do not produce ui32 as result.
- auto ui32ToI32MaterializationCast = [](OpBuilder &builder, Type type,
- ValueRange inputs,
- Location loc) -> Value {
+ auto ui32MaterializationCast = [](OpBuilder &builder, Type type,
+ ValueRange inputs,
+ Location loc) -> Value {
if (inputs.size() != 1)
return {};
auto input = inputs.front();
@@ -1227,17 +1218,25 @@ struct ConvertXeGPUToXeVMPass
};
// Materialization to convert
+ // - single element 1D vector to scalar
// - bitcast vector of same rank
// - shape vector of different rank but same element type
- // Applies to both source and target materialization.
- auto vectorToVectorMaterializationCast = [](OpBuilder &builder, Type type,
- ValueRange inputs,
- Location loc) -> Value {
+ auto vectorMaterializationCast = [](OpBuilder &builder, Type type,
+ ValueRange inputs,
+ Location loc) -> Value {
if (inputs.size() != 1)
return {};
auto input = inputs.front();
if (auto vecTy = dyn_cast<VectorType>(input.getType())) {
- if (auto targetVecTy = dyn_cast<VectorType>(type)) {
+ if (vecTy.getNumElements() == 1) {
+ // If the vector has a single element, return the element type.
+ Value cast =
+ vector::ExtractOp::create(builder, loc, input, 0).getResult();
+ if (vecTy.getElementType() == builder.getIndexType())
+ cast = arith::IndexCastUIOp::create(builder, loc, type, cast)
+ .getResult();
+ return cast;
+ } else if (auto targetVecTy = dyn_cast<VectorType>(type)) {
// If the target type is a vector of same rank,
// bitcast to the target type.
if (targetVecTy.getRank() == vecTy.getRank())
@@ -1254,79 +1253,34 @@ struct ConvertXeGPUToXeVMPass
return {};
};
- // Materialization to convert
- // - single element vector to single element of vector element type
- // Applies only to target materialization.
- auto vectorToSingleElementMaterializationCast =
- [](OpBuilder &builder, Type type, ValueRange inputs,
- Location loc) -> Value {
- if (inputs.size() != 1)
- return {};
- auto input = inputs.front();
- if (auto vecTy = dyn_cast<VectorType>(input.getType())) {
- if (type == vecTy.getElementType() ||
- ((vecTy.getElementType() == builder.getIndexType()) &&
- type.isInteger())) {
- // If the vector rank is 0 or has a single element,
- // extract scalar of target type.
- auto rank = vecTy.getRank();
- Value cast;
- if (rank == 0) {
- cast =
- vector::ExtractOp::create(builder, loc, input, {}).getResult();
- } else {
- cast = vector::ExtractOp::create(builder, loc, input,
- SmallVector<int64_t>(rank, 0))
- .getResult();
- }
- if (type != vecTy.getElementType())
- cast = arith::IndexCastUIOp::create(builder, loc, type, cast)
- .getResult();
- return cast;
- }
- }
- return {};
- };
-
- // Materialization to convert
- // - single element of vector element type to single element vector
// If result type of original op is single element vector and lowered type
// is scalar. This materialization cast creates a single element vector by
// broadcasting the scalar value.
- // Applies only to source materialization.
- auto singleElementToVectorMaterializationCast =
+ auto singleElementVectorMaterializationCast =
[](OpBuilder &builder, Type type, ValueRange inputs,
Location loc) -> Value {
if (inputs.size() != 1)
return {};
auto input = inputs.front();
- // If the target type is a vector of rank 0 or single element vector
- // of element type matching input type, broadcast input to target type.
- if (auto vecTy = dyn_cast<VectorType>(type)) {
- if (vecTy.getRank() == 0 || vecTy.getNumElements() == 1) {
- if (input.getType() == vecTy.getElementType()) {
+ if (input.getType().isIntOrIndexOrFloat()) {
+ // If the input is a scalar, and the target type is a vector of single
+ // element, create a single element vector by broadcasting.
+ if (auto vecTy = dyn_cast<VectorType>(type)) {
+ if (vecTy.getNumElements() == 1) {
return vector::BroadcastOp::create(builder, loc, vecTy, input)
.getResult();
- } else if (vecTy.getElementType() == builder.getIndexType()) {
- Value cast = arith::IndexCastUIOp::create(
- builder, loc, builder.getIndexType(), input)
- .getResult();
- return vector::BroadcastOp::create(builder, loc, vecTy, cast)
- .getResult();
}
}
}
return {};
};
typeConverter.addSourceMaterialization(
- singleElementToVectorMaterializationCast);
- typeConverter.addSourceMaterialization(vectorToVectorMaterializationCast);
- typeConverter.addTargetMaterialization(memrefToIntMaterializationCast);
- typeConverter.addTargetMaterialization(ui32ToI32MaterializationCast);
- typeConverter.addTargetMaterialization(ui64ToI64MaterializationCast);
- typeConverter.addTargetMaterialization(
- vectorToSingleElementMaterializationCast);
- typeConverter.addTargetMaterialization(vectorToVectorMaterializationCast);
+ singleElementVectorMaterializationCast);
+ typeConverter.addSourceMaterialization(vectorMaterializationCast);
+ typeConverter.addTargetMaterialization(memrefMaterializationCast);
+ typeConverter.addTargetMaterialization(ui32MaterializationCast);
+ typeConverter.addTargetMaterialization(ui64MaterializationCast);
+ typeConverter.addTargetMaterialization(vectorMaterializationCast);
ConversionTarget target(getContext());
target.addLegalDialect<xevm::XeVMDialect, LLVM::LLVMDialect,
vector::VectorDialect, arith::ArithDialect,
>From 32f2f30d2a930a675a5f496dc677128201d376cd Mon Sep 17 00:00:00 2001
From: Arjun Bhamra <arjun.bhamra25 at gmail.com>
Date: Tue, 10 Mar 2026 11:17:48 -0400
Subject: [PATCH 7/8] reverted conversion changes, added changes to verifier
only
---
.../Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp | 156 +++++++++---------
mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 23 ++-
mlir/test/Conversion/XeGPUToXeVM/dpas.mlir | 2 +-
.../XeGPUToXeVM/failed_conversion.mlir | 14 --
mlir/test/Dialect/XeGPU/invalid.mlir | 8 +
5 files changed, 107 insertions(+), 96 deletions(-)
delete mode 100644 mlir/test/Conversion/XeGPUToXeVM/failed_conversion.mlir
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 795bc0ad84a4a..6df209438447b 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -23,7 +23,6 @@
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
-#include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/STLExtras.h"
@@ -903,43 +902,6 @@ class DpasToXeVMPattern : public OpConversionPattern<xegpu::DpasOp> {
auto bTy = cast<VectorType>(op.getRhs().getType());
auto resultType = cast<VectorType>(op.getResultType());
- // get the correct dpasInst by getting info from chip
- auto chipStr = xegpu::getChipStr(op);
- if (!chipStr)
- return rewriter.notifyMatchFailure(op, "cannot determine target chip");
-
- const auto *uArch = mlir::xegpu::uArch::getUArch(*chipStr);
- if (!uArch)
- return rewriter.notifyMatchFailure(op, "unsupported target uArch");
-
- auto *dpasInst = const_cast<xegpu::uArch::SubgroupMatrixMultiplyAcc *>(
- llvm::dyn_cast_or_null<xegpu::uArch::SubgroupMatrixMultiplyAcc>(
- uArch->getInstruction(
- xegpu::uArch::InstructionKind::SubgroupMatrixMultiplyAcc)));
- if (!dpasInst)
- return rewriter.notifyMatchFailure(op,
- "DPAS not supported by target uArch");
-
- auto supportedA =
- dpasInst->getSupportedTypes(*ctxt, xegpu::uArch::MMAOpndKind::MatrixA);
- auto supportedB =
- dpasInst->getSupportedTypes(*ctxt, xegpu::uArch::MMAOpndKind::MatrixB);
- auto supportedD =
- dpasInst->getSupportedTypes(*ctxt, xegpu::uArch::MMAOpndKind::MatrixD);
- // NOTE: Supported types for MatrixC and MatrixD are identical
-
- if (llvm::find(supportedA, aTy.getElementType()) == supportedA.end())
- return rewriter.notifyMatchFailure(
- op, "A-matrix element type not supported by target uArch");
-
- if (llvm::find(supportedB, bTy.getElementType()) == supportedB.end())
- return rewriter.notifyMatchFailure(
- op, "B-matrix element type not supported by target uArch");
-
- if (llvm::find(supportedD, resultType.getElementType()) == supportedD.end())
- return rewriter.notifyMatchFailure(
- op, "result/accumulator element type not supported by target uArch");
-
auto encodePrecision = [&](Type type) -> xevm::ElemType {
if (type == rewriter.getBF16Type())
return xevm::ElemType::BF16;
@@ -1094,8 +1056,8 @@ struct ConvertXeGPUToXeVMPass
// If the element type is index, convert it to i64.
if (llvm::isa<IndexType>(elemType))
elemType = IntegerType::get(&getContext(), 64);
- // If the vector is a scalar or has a single element, return the element
- if (rank < 1 || type.getNumElements() == 1)
+ // If the vector rank is 0 or has a single element, return the element
+ if (rank == 0 || type.getNumElements() == 1)
return elemType;
// Otherwise, convert the vector to a flat vector type.
int64_t sum = llvm::product_of(type.getShape());
@@ -1123,9 +1085,12 @@ struct ConvertXeGPUToXeVMPass
// add materialization casts to handle them.
// Materialization to convert memref to i64 or i32 depending on global/SLM
- auto memrefMaterializationCast = [](OpBuilder &builder, Type type,
- ValueRange inputs,
- Location loc) -> Value {
+ // Applies only to target materialization.
+ // Note: int type to memref materialization is not required as xegpu ops
+ // currently do not produce memrefs as result.
+ auto memrefToIntMaterializationCast = [](OpBuilder &builder, Type type,
+ ValueRange inputs,
+ Location loc) -> Value {
if (inputs.size() != 1)
return {};
auto input = inputs.front();
@@ -1184,9 +1149,12 @@ struct ConvertXeGPUToXeVMPass
};
// Materialization to convert ui64 to i64
- auto ui64MaterializationCast = [](OpBuilder &builder, Type type,
- ValueRange inputs,
- Location loc) -> Value {
+ // Applies only to target materialization.
+ // Note: i64 to ui64 materialization is not required as xegpu ops
+ // currently do not produce ui64 as result.
+ auto ui64ToI64MaterializationCast = [](OpBuilder &builder, Type type,
+ ValueRange inputs,
+ Location loc) -> Value {
if (inputs.size() != 1)
return {};
auto input = inputs.front();
@@ -1201,9 +1169,12 @@ struct ConvertXeGPUToXeVMPass
};
// Materialization to convert ui32 to i32
- auto ui32MaterializationCast = [](OpBuilder &builder, Type type,
- ValueRange inputs,
- Location loc) -> Value {
+ // Applies only to target materialization.
+ // Note: i32 to ui32 materialization is not required as xegpu ops
+ // currently do not produce ui32 as result.
+ auto ui32ToI32MaterializationCast = [](OpBuilder &builder, Type type,
+ ValueRange inputs,
+ Location loc) -> Value {
if (inputs.size() != 1)
return {};
auto input = inputs.front();
@@ -1218,25 +1189,17 @@ struct ConvertXeGPUToXeVMPass
};
// Materialization to convert
- // - single element 1D vector to scalar
// - bitcast vector of same rank
// - shape vector of different rank but same element type
- auto vectorMaterializationCast = [](OpBuilder &builder, Type type,
- ValueRange inputs,
- Location loc) -> Value {
+ // Applies to both source and target materialization.
+ auto vectorToVectorMaterializationCast = [](OpBuilder &builder, Type type,
+ ValueRange inputs,
+ Location loc) -> Value {
if (inputs.size() != 1)
return {};
auto input = inputs.front();
if (auto vecTy = dyn_cast<VectorType>(input.getType())) {
- if (vecTy.getNumElements() == 1) {
- // If the vector has a single element, return the element type.
- Value cast =
- vector::ExtractOp::create(builder, loc, input, 0).getResult();
- if (vecTy.getElementType() == builder.getIndexType())
- cast = arith::IndexCastUIOp::create(builder, loc, type, cast)
- .getResult();
- return cast;
- } else if (auto targetVecTy = dyn_cast<VectorType>(type)) {
+ if (auto targetVecTy = dyn_cast<VectorType>(type)) {
// If the target type is a vector of same rank,
// bitcast to the target type.
if (targetVecTy.getRank() == vecTy.getRank())
@@ -1253,34 +1216,79 @@ struct ConvertXeGPUToXeVMPass
return {};
};
+ // Materialization to convert
+ // - single element vector to single element of vector element type
+ // Applies only to target materialization.
+ auto vectorToSingleElementMaterializationCast =
+ [](OpBuilder &builder, Type type, ValueRange inputs,
+ Location loc) -> Value {
+ if (inputs.size() != 1)
+ return {};
+ auto input = inputs.front();
+ if (auto vecTy = dyn_cast<VectorType>(input.getType())) {
+ if (type == vecTy.getElementType() ||
+ ((vecTy.getElementType() == builder.getIndexType()) &&
+ type.isInteger())) {
+ // If the vector rank is 0 or has a single element,
+ // extract scalar of target type.
+ auto rank = vecTy.getRank();
+ Value cast;
+ if (rank == 0) {
+ cast =
+ vector::ExtractOp::create(builder, loc, input, {}).getResult();
+ } else {
+ cast = vector::ExtractOp::create(builder, loc, input,
+ SmallVector<int64_t>(rank, 0))
+ .getResult();
+ }
+ if (type != vecTy.getElementType())
+ cast = arith::IndexCastUIOp::create(builder, loc, type, cast)
+ .getResult();
+ return cast;
+ }
+ }
+ return {};
+ };
+
+ // Materialization to convert
+ // - single element of vector element type to single element vector
// If result type of original op is single element vector and lowered type
// is scalar. This materialization cast creates a single element vector by
// broadcasting the scalar value.
- auto singleElementVectorMaterializationCast =
+ // Applies only to source materialization.
+ auto singleElementToVectorMaterializationCast =
[](OpBuilder &builder, Type type, ValueRange inputs,
Location loc) -> Value {
if (inputs.size() != 1)
return {};
auto input = inputs.front();
- if (input.getType().isIntOrIndexOrFloat()) {
- // If the input is a scalar, and the target type is a vector of single
- // element, create a single element vector by broadcasting.
- if (auto vecTy = dyn_cast<VectorType>(type)) {
- if (vecTy.getNumElements() == 1) {
+ // If the target type is a vector of rank 0 or single element vector
+ // of element type matching input type, broadcast input to target type.
+ if (auto vecTy = dyn_cast<VectorType>(type)) {
+ if (vecTy.getRank() == 0 || vecTy.getNumElements() == 1) {
+ if (input.getType() == vecTy.getElementType()) {
return vector::BroadcastOp::create(builder, loc, vecTy, input)
.getResult();
+ } else if (vecTy.getElementType() == builder.getIndexType()) {
+ Value cast = arith::IndexCastUIOp::create(
+ builder, loc, builder.getIndexType(), input)
+ .getResult();
+ return vector::BroadcastOp::create(builder, loc, vecTy, cast)
+ .getResult();
}
}
}
return {};
};
typeConverter.addSourceMaterialization(
- singleElementVectorMaterializationCast);
- typeConverter.addSourceMaterialization(vectorMaterializationCast);
- typeConverter.addTargetMaterialization(memrefMaterializationCast);
- typeConverter.addTargetMaterialization(ui32MaterializationCast);
- typeConverter.addTargetMaterialization(ui64MaterializationCast);
- typeConverter.addTargetMaterialization(vectorMaterializationCast);
+ singleElementToVectorMaterializationCast);
+ typeConverter.addSourceMaterialization(vectorToVectorMaterializationCast);
+ typeConverter.addTargetMaterialization(memrefToIntMaterializationCast);
+ typeConverter.addTargetMaterialization(ui32ToI32MaterializationCast);
+ typeConverter.addTargetMaterialization(ui64ToI64MaterializationCast);
+ typeConverter.addTargetMaterialization(
+ vectorToSingleElementMaterializationCast);
+ typeConverter.addTargetMaterialization(vectorToVectorMaterializationCast);
ConversionTarget target(getContext());
target.addLegalDialect<xevm::XeVMDialect, LLVM::LLVMDialect,
vector::VectorDialect, arith::ArithDialect,
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index e470d1f820f79..4973b677f65f8 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -1045,22 +1045,32 @@ LogicalResult UpdateOffsetOp::verify() {
// XeGPU_DpasOp
//===----------------------------------------------------------------------===//
LogicalResult DpasOp::verify() {
- int64_t lhsRank = getLhsType().getRank();
- int64_t rhsRank = getRhsType().getRank();
+ auto lhsType = getLhsType();
+ auto rhsType = getRhsType();
+ int64_t lhsRank = lhsType.getRank();
+ int64_t rhsRank = rhsType.getRank();
int64_t resRank = getResultType().getRank();
- auto lhsShape = getLhsType().getShape();
- auto rhsShape = getRhsType().getShape();
+ auto lhsShape = lhsType.getShape();
+ auto rhsShape = rhsType.getShape();
auto resShape = getResultType().getShape();
if (getAcc() && getAcc().getType() != getResultType())
return emitOpError("Expecting the acc type to be the same as result.");
+ auto lhsElemTy = lhsType.getElementType();
+ auto rhsElemTy = rhsType.getElementType();
+ if (lhsElemTy.getIntOrFloatBitWidth() >= 32 ||
+ rhsElemTy.getIntOrFloatBitWidth() >= 32) {
+ return emitOpError(
+ "Expecting lhs and rhs element types to be at most 32 bits.");
+ }
+
// SIMT code: the size of the B operand has to be a multiple of 32 bits.
// It skips the semantic check since lack of architecture information.
// Users need to ensure the correctness.
if (lhsRank == 1 && rhsRank == 1 && resRank == 1) {
- auto numElems = getRhsType().getNumElements();
- auto elemTy = getRhsType().getElementType();
+ auto numElems = rhsType.getNumElements();
+ auto elemTy = rhsType.getElementType();
auto factor = 32 / elemTy.getIntOrFloatBitWidth();
if (numElems % factor != 0)
return emitOpError("Expecting B operand to be a multiple of 32 bits.");
@@ -1082,7 +1092,6 @@ LogicalResult DpasOp::verify() {
return success();
}
-
//===----------------------------------------------------------------------===//
// XeGPU_ConvertLayoutOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir b/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir
index 7cc59f4cfdcdf..a9ab0be00722c 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir
@@ -1,6 +1,6 @@
// RUN: mlir-opt -convert-xegpu-to-xevm %s | FileCheck %s
-gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
+gpu.module @test_kernel {
// CHECK-LABEL: func.func @dpas(
// CHECK-SAME: %[[ARG0:.*]]: vector<8xf16>, %[[ARG1:.*]]: vector<16xf16>, %[[ARG2:.*]]: vector<8xf32>
func.func @dpas(%a_loaded: vector<8xf16>, %b_loaded: vector<16xf16>, %c_loaded: vector<8xf32>) -> vector<8xf32> {
diff --git a/mlir/test/Conversion/XeGPUToXeVM/failed_conversion.mlir b/mlir/test/Conversion/XeGPUToXeVM/failed_conversion.mlir
deleted file mode 100644
index c4088c430a4d9..0000000000000
--- a/mlir/test/Conversion/XeGPUToXeVM/failed_conversion.mlir
+++ /dev/null
@@ -1,14 +0,0 @@
-// RUN: mlir-opt --convert-xegpu-to-xevm %s -split-input-file -verify-diagnostics
-
-// Verify that xegpu.dpas with unsupported element types (i16) is rejected
-// during XeGPUToXeVM conversion rather than crashing.
-
-gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
- func.func @main() {
- %0 = spirv.Constant dense<0> : vector<4xi16>
- %1 = spirv.Constant dense<0> : vector<4xi32>
- // expected-error@+1 {{failed to legalize operation 'xegpu.dpas' that was explicitly marked illegal}}
- %2 = xegpu.dpas %0, %0, %1 : vector<4xi16>, vector<4xi16>, vector<4xi32> -> vector<4xi32>
- return
- }
-}
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index 53d497e4c2087..34083f72f4e67 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -601,6 +601,14 @@ func.func @dpas_5(%a : vector<8x16xf16>, %b: vector<8x8x2xf16>) {
return
}
+// -----
+func.func @dpas_6() {
+ %0 = arith.constant dense<0> : vector<4xi32>
+ // expected-error@+1 {{'xegpu.dpas' op Expecting lhs and rhs element types to be at most 32 bits.}}
+ %1 = xegpu.dpas %0, %0 : vector<4xi32>, vector<4xi32> -> vector<4xi32>
+ return
+}
+
// -----
func.func @dpas_simt_1(%a : vector<8xf16>, %b: vector<15xf16>) {
// expected-error at +1 {{Expecting B operand to be a multiple of 32 bits}}
>From 252d8a75a073d1ece7bdf07feab29b0b6b360a5a Mon Sep 17 00:00:00 2001
From: Arjun Bhamra <arjun.bhamra25 at gmail.com>
Date: Wed, 11 Mar 2026 20:54:08 -0400
Subject: [PATCH 8/8] re-added failed conversion check, uArch check, etc.
---
.../Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp | 35 +++++++++++++++++++
.../XeGPUToXeVM/failed_conversion.mlir | 14 ++++++++
2 files changed, 49 insertions(+)
create mode 100644 mlir/test/Conversion/XeGPUToXeVM/failed_conversion.mlir
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 6df209438447b..6dc9d0a55a308 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -23,6 +23,7 @@
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
+#include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/STLExtras.h"
@@ -902,6 +903,40 @@ class DpasToXeVMPattern : public OpConversionPattern<xegpu::DpasOp> {
auto bTy = cast<VectorType>(op.getRhs().getType());
auto resultType = cast<VectorType>(op.getResultType());
+ // get the correct dpasInst by getting info from chip
+ auto chipStr = xegpu::getChipStr(op);
+ if (!chipStr)
+ return rewriter.notifyMatchFailure(op, "cannot determine target chip");
+
+ const auto *uArch = mlir::xegpu::uArch::getUArch(*chipStr);
+ if (!uArch)
+ return rewriter.notifyMatchFailure(op, "unsupported target uArch");
+
+ auto *dpasInst = const_cast<xegpu::uArch::SubgroupMatrixMultiplyAcc *>(
+ llvm::dyn_cast_or_null<xegpu::uArch::SubgroupMatrixMultiplyAcc>(
+ uArch->getInstruction(
+ xegpu::uArch::InstructionKind::SubgroupMatrixMultiplyAcc)));
+ if (!dpasInst)
+ return rewriter.notifyMatchFailure(op,
+ "DPAS not supported by target uArch");
+
+ auto checkSupportedTypes = [&](VectorType vecTy,
+ xegpu::uArch::MMAOpndKind kind) -> bool {
+ auto supported = dpasInst->getSupportedTypes(*ctxt, kind);
+ return llvm::find(supported, vecTy.getElementType()) != supported.end();
+ };
+
+ if (!checkSupportedTypes(aTy, xegpu::uArch::MMAOpndKind::MatrixA))
+ return rewriter.notifyMatchFailure(
+ op, "A-matrix element type not supported by target uArch");
+ if (!checkSupportedTypes(bTy, xegpu::uArch::MMAOpndKind::MatrixB))
+ return rewriter.notifyMatchFailure(
+ op, "B-matrix element type not supported by target uArch");
+ // NOTE: Supported types for MatrixC and MatrixD are identical
+ if (!checkSupportedTypes(resultType, xegpu::uArch::MMAOpndKind::MatrixD))
+ return rewriter.notifyMatchFailure(
+ op, "result/accumulator element type not supported by target uArch");
+
auto encodePrecision = [&](Type type) -> xevm::ElemType {
if (type == rewriter.getBF16Type())
return xevm::ElemType::BF16;
diff --git a/mlir/test/Conversion/XeGPUToXeVM/failed_conversion.mlir b/mlir/test/Conversion/XeGPUToXeVM/failed_conversion.mlir
new file mode 100644
index 0000000000000..95211dcff250c
--- /dev/null
+++ b/mlir/test/Conversion/XeGPUToXeVM/failed_conversion.mlir
@@ -0,0 +1,14 @@
+// RUN: mlir-opt --convert-xegpu-to-xevm %s -split-input-file -verify-diagnostics
+
+// Verify that xegpu.dpas with unsupported element types (i16) is rejected
+// during XeGPUToXeVM conversion rather than crashing.
+
+gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
+ func.func @main() {
+ %0 = arith.constant dense<0> : vector<4xi16>
+ %1 = arith.constant dense<0> : vector<4xi32>
+ // expected-error@+1 {{failed to legalize operation 'xegpu.dpas' that was explicitly marked illegal}}
+ %2 = xegpu.dpas %0, %0, %1 : vector<4xi16>, vector<4xi16>, vector<4xi32> -> vector<4xi32>
+ return
+ }
+}
More information about the Mlir-commits
mailing list