[Mlir-commits] [mlir] [MLIR][XeGPU] Validate DPAS operand types against uArch in XeGPUToXeVM conversion (PR #185081)
llvmlistbot at llvm.org
Fri Mar 6 11:11:39 PST 2026
llvmbot wrote:
@llvm/pr-subscribers-mlir-gpu
Author: Arjun Bhamra (abhamra)
<details>
<summary>Changes</summary>
Lowering `DpasOp` to the XeVM dialect crashes with `llvm_unreachable` when the op uses unsupported element types (such as i16, or i32 as an operand). The crash is reachable through both `encodePrecision` and `getNumOperandsPerDword`.
Per https://github.com/llvm/llvm-project/issues/180107#issuecomment-4009160113, we handle this in `matchAndRewrite` by retrieving the uArch instance and fetching the registered `SubgroupMatrixMultiplyAcc` instruction. We then query `getSupportedTypes` for each operand kind and check the element types of `aTy`, `bTy`, and `resultType` against the results, reporting any mismatch through `notifyMatchFailure` so that the conversion fails gracefully instead of crashing.
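The validate-before-lowering idea described above can be sketched outside of MLIR as a plain table lookup. This is only an illustration, not the code in the patch: `ElemType`, `OperandKind`, `SupportedTypes`, `exampleTarget`, `isSupported`, and `validateDpasTypes` are hypothetical stand-ins, and the type lists in `exampleTarget` are made up for the example rather than taken from any real uArch description.

```cpp
#include <algorithm>
#include <map>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-ins for MLIR element types and uArch operand kinds.
enum class ElemType { F16, BF16, F32, I8, I16, I32 };
enum class OperandKind { MatrixA, MatrixB, MatrixD };

// Hypothetical per-target table of accepted element types per operand.
using SupportedTypes = std::map<OperandKind, std::vector<ElemType>>;

// Illustrative lists only (an assumption); a real target would supply its own.
inline SupportedTypes exampleTarget() {
  return {
      {OperandKind::MatrixA, {ElemType::F16, ElemType::BF16, ElemType::I8}},
      {OperandKind::MatrixB, {ElemType::F16, ElemType::BF16, ElemType::I8}},
      {OperandKind::MatrixD, {ElemType::F32, ElemType::I32}},
  };
}

// True if `type` is listed for operand `kind` in `table`.
inline bool isSupported(const SupportedTypes &table, OperandKind kind,
                        ElemType type) {
  auto it = table.find(kind);
  if (it == table.end())
    return false;
  const std::vector<ElemType> &types = it->second;
  return std::find(types.begin(), types.end(), type) != types.end();
}

// Returns a failure message for the first rejected operand, or std::nullopt
// when all three element types are accepted. The caller can report the
// message and bail out before any code that assumes a supported type runs.
inline std::optional<std::string>
validateDpasTypes(const SupportedTypes &table, ElemType a, ElemType b,
                  ElemType result) {
  if (!isSupported(table, OperandKind::MatrixA, a))
    return "A-matrix element type not supported by target uArch";
  if (!isSupported(table, OperandKind::MatrixB, b))
    return "B-matrix element type not supported by target uArch";
  if (!isSupported(table, OperandKind::MatrixD, result))
    return "result/accumulator element type not supported by target uArch";
  return std::nullopt;
}
```

Checking all three operands up front means the later type-encoding helpers only ever see types the target accepts, so their unreachable default branches stay unreachable.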
We add a failed-conversion test based on a simplified version of the reproducer in #180107, although I'm not sure this is as good as a direct XeGPU type validation test.
**Note:** Since `getChipStr` requires an `XeVMTargetAttr`, I modified the `dpas.mlir` conversion test to handle that case.
Closes #180107.
cc: @Jianhui-Li @adam-smnk
---
Full diff: https://github.com/llvm/llvm-project/pull/185081.diff
3 Files Affected:
- (modified) mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp (+34)
- (modified) mlir/test/Conversion/XeGPUToXeVM/dpas.mlir (+1-1)
- (added) mlir/test/Conversion/XeGPUToXeVM/failed_conversion.mlir (+14)
``````````diff
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 6df209438447b..5aa78f5b5b9f2 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -22,6 +22,7 @@
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
@@ -902,6 +903,39 @@ class DpasToXeVMPattern : public OpConversionPattern<xegpu::DpasOp> {
auto bTy = cast<VectorType>(op.getRhs().getType());
auto resultType = cast<VectorType>(op.getResultType());
+ // get the correct dpasInst by getting info from chip
+ auto chipStr = xegpu::getChipStr(op);
+ if (!chipStr)
+ return rewriter.notifyMatchFailure(op, "cannot determine target chip");
+
+ const auto *uArch= mlir::xegpu::uArch::getUArch(*chipStr);
+ if (!uArch)
+ return rewriter.notifyMatchFailure(op, "unsupported target uArch");
+
+ auto *dpasInst = const_cast<xegpu::uArch::SubgroupMatrixMultiplyAcc*>(
+ llvm::dyn_cast_or_null<xegpu::uArch::SubgroupMatrixMultiplyAcc>(
+ uArch->getInstruction(
+ xegpu::uArch::InstructionKind::SubgroupMatrixMultiplyAcc)));
+ if (!dpasInst)
+ return rewriter.notifyMatchFailure(op, "DPAS not supported by target uArch");
+
+ auto supportedA = dpasInst->getSupportedTypes(*ctxt, xegpu::uArch::MMAOpndKind::MatrixA);
+ auto supportedB = dpasInst->getSupportedTypes(*ctxt, xegpu::uArch::MMAOpndKind::MatrixB);
+ auto supportedD = dpasInst->getSupportedTypes(*ctxt, xegpu::uArch::MMAOpndKind::MatrixD);
+ // NOTE: Supported types for MatrixC and MatrixD are identical
+
+ if (llvm::find(supportedA, aTy.getElementType()) == supportedA.end())
+ return rewriter.notifyMatchFailure(
+ op, "A-matrix element type not supported by target uArch");
+
+ if (llvm::find(supportedB, bTy.getElementType()) == supportedB.end())
+ return rewriter.notifyMatchFailure(
+ op, "B-matrix element type not supported by target uArch");
+
+ if (llvm::find(supportedD, resultType.getElementType()) == supportedD.end())
+ return rewriter.notifyMatchFailure(
+ op, "result/accumulator element type not supported by target uArch");
+
auto encodePrecision = [&](Type type) -> xevm::ElemType {
if (type == rewriter.getBF16Type())
return xevm::ElemType::BF16;
diff --git a/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir b/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir
index a9ab0be00722c..7cc59f4cfdcdf 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir
@@ -1,6 +1,6 @@
// RUN: mlir-opt -convert-xegpu-to-xevm %s | FileCheck %s
-gpu.module @test_kernel {
+gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
// CHECK-LABEL: func.func @dpas(
// CHECK-SAME: %[[ARG0:.*]]: vector<8xf16>, %[[ARG1:.*]]: vector<16xf16>, %[[ARG2:.*]]: vector<8xf32>
func.func @dpas(%a_loaded: vector<8xf16>, %b_loaded: vector<16xf16>, %c_loaded: vector<8xf32>) -> vector<8xf32> {
diff --git a/mlir/test/Conversion/XeGPUToXeVM/failed_conversion.mlir b/mlir/test/Conversion/XeGPUToXeVM/failed_conversion.mlir
new file mode 100644
index 0000000000000..c4088c430a4d9
--- /dev/null
+++ b/mlir/test/Conversion/XeGPUToXeVM/failed_conversion.mlir
@@ -0,0 +1,14 @@
+// RUN: mlir-opt --convert-xegpu-to-xevm %s -split-input-file -verify-diagnostics
+
+// Verify that xegpu.dpas with unsupported element types (i16) is rejected
+// during XeGPUToXeVM conversion rather than crashing.
+
+gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
+ func.func @main() {
+ %0 = spirv.Constant dense<0> : vector<4xi16>
+ %1 = spirv.Constant dense<0> : vector<4xi32>
+ // expected-error@+1 {{failed to legalize operation 'xegpu.dpas' that was explicitly marked illegal}}
+ %2 = xegpu.dpas %0, %0, %1 : vector<4xi16>, vector<4xi16>, vector<4xi32> -> vector<4xi32>
+ return
+ }
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/185081