[Mlir-commits] [mlir] [MLIR][XeGPU] Validate DPAS operand types against uArch in XeGPUToXeVM conversion (PR #185081)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Mar 6 11:11:39 PST 2026


llvmbot wrote:



@llvm/pr-subscribers-mlir-gpu

Author: Arjun Bhamra (abhamra)

<details>
<summary>Changes</summary>

Lowering `DpasOp` to the XeVM dialect crashes in `llvm_unreachable` when an operand has an unsupported element type (for example i16, or i32 as an input operand). The crash can occur in both `encodePrecision` and `getNumOperandsPerDword`.
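
The crash comes from lowering helpers that have no case for some element types and fall through to `llvm_unreachable`. Below is a minimal standalone C++ sketch of a recoverable alternative. It uses hypothetical `ElemKind`/`Precision` enums in place of MLIR types and `xevm::ElemType`, and the encoder returns `std::optional` so that an unsupported type becomes a checkable result instead of an abort. This is illustrative only. The PR leaves `encodePrecision` unchanged and instead rejects unsupported types before that helper runs.

```cpp
#include <cassert>
#include <optional>

// Hypothetical stand-ins for MLIR element types (not the MLIR API).
enum class ElemKind { BF16, F16, I8, I16, F32, I32 };

// Hypothetical stand-in for xevm::ElemType.
enum class Precision { BF16, F16, S8, F32, S32 };

// Maps an element type to its encoding, or std::nullopt when none exists.
// In this sketch, "no encoding" is a value the caller can check, rather
// than an unreachable path that aborts the compiler.
std::optional<Precision> encodePrecision(ElemKind kind) {
  switch (kind) {
  case ElemKind::BF16:
    return Precision::BF16;
  case ElemKind::F16:
    return Precision::F16;
  case ElemKind::I8:
    return Precision::S8;
  case ElemKind::F32:
    return Precision::F32;
  case ElemKind::I32:
    return Precision::S32;
  default:
    return std::nullopt; // e.g. I16: no encoding in this sketch
  }
}
```

A caller would test the result and report a match failure on `std::nullopt`.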

Per https://github.com/llvm/llvm-project/issues/180107#issuecomment-4009160113, we handle this in `matchAndRewrite`. The pattern retrieves the uArch instance for the target chip and fetches its registered `SubgroupMatrixMultiplyAcc` instruction. It then queries `getSupportedTypes` for the A, B, and D (result) operands and checks the element types of `aTy`, `bTy`, and `resultType` against those lists. On a mismatch it returns `notifyMatchFailure`, so the pattern fails gracefully instead of crashing.
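
The lookup-then-validate flow can be modeled outside MLIR. The sketch below is a hypothetical standalone C++ version. `UArch`, `DpasInfo`, `makeExampleRegistry`, and `validateDpas` are invented names, and the supported-type lists are illustrative rather than the real Xe2 tables. Only the failure messages and the order of the checks are taken from the diff.

```cpp
#include <algorithm>
#include <cassert>
#include <map>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-ins for MLIR element types and DPAS operand roles.
enum class ElemKind { BF16, F16, I8, I16, F32, I32 };
enum class OperandKind { MatrixA, MatrixB, MatrixD };

// Hypothetical model of a registered DPAS instruction: one supported-type
// list per operand role.
struct DpasInfo {
  std::map<OperandKind, std::vector<ElemKind>> supported;

  const std::vector<ElemKind> &getSupportedTypes(OperandKind kind) const {
    return supported.at(kind);
  }
};

// Hypothetical uArch model: may or may not register a DPAS instruction.
struct UArch {
  std::optional<DpasInfo> dpas;
};

// Illustrative registry; these type lists are made up for the example.
std::map<std::string, UArch> makeExampleRegistry() {
  DpasInfo dpas;
  dpas.supported[OperandKind::MatrixA] = {ElemKind::BF16, ElemKind::F16,
                                          ElemKind::I8};
  dpas.supported[OperandKind::MatrixB] = {ElemKind::BF16, ElemKind::F16,
                                          ElemKind::I8};
  dpas.supported[OperandKind::MatrixD] = {ElemKind::BF16, ElemKind::F16,
                                          ElemKind::F32, ElemKind::I32};
  std::map<std::string, UArch> registry;
  registry["pvc"] = UArch{dpas};
  registry["no-dpas-chip"] = UArch{std::nullopt};
  return registry;
}

// Returns std::nullopt if the op may be lowered, otherwise the failure
// message, checking in the same order as the PR's matchAndRewrite.
std::optional<std::string>
validateDpas(const std::map<std::string, UArch> &registry,
             const std::optional<std::string> &chip, ElemKind a, ElemKind b,
             ElemKind result) {
  if (!chip)
    return "cannot determine target chip";
  auto it = registry.find(*chip);
  if (it == registry.end())
    return "unsupported target uArch";
  if (!it->second.dpas)
    return "DPAS not supported by target uArch";
  const DpasInfo &dpas = *it->second.dpas;
  auto supports = [&](OperandKind kind, ElemKind type) {
    const auto &types = dpas.getSupportedTypes(kind);
    return std::find(types.begin(), types.end(), type) != types.end();
  };
  if (!supports(OperandKind::MatrixA, a))
    return "A-matrix element type not supported by target uArch";
  if (!supports(OperandKind::MatrixB, b))
    return "B-matrix element type not supported by target uArch";
  if (!supports(OperandKind::MatrixD, result))
    return "result/accumulator element type not supported by target uArch";
  return std::nullopt;
}
```

The function returns a message instead of aborting, which mirrors `notifyMatchFailure`: the caller learns why the op was rejected and can carry on.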

We add a failed-conversion test based on a simplified version of the reproducer from #180107, although I'm not sure it is as good as a direct XeGPU type-validation test.
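
The test expects the generic legalization error rather than the `notifyMatchFailure` text. As I understand it, match-failure reasons typically go only to the debug log (`-debug`). The conversion driver reports `failed to legalize operation '...' that was explicitly marked illegal` when no pattern converts an op that the target marks illegal. Below is a toy C++ model of that behavior. `MatchResult`, `Pattern`, and `legalize` are hypothetical names, and this is not MLIR's actual driver.

```cpp
#include <cassert>
#include <functional>
#include <optional>
#include <string>
#include <vector>

// A pattern either converts the op or declines with a reason.
struct MatchResult {
  bool success;
  std::string reason; // only meaningful when success is false
};

using Pattern = std::function<MatchResult(const std::string &opName)>;

// Toy driver for an op marked illegal: tries each pattern in turn.
// Per-pattern failure reasons are appended to a debug log only; if nothing
// converts the op, the user-facing diagnostic is a generic legalization error.
std::optional<std::string> legalize(const std::string &opName,
                                    const std::vector<Pattern> &patterns,
                                    std::vector<std::string> &debugLog) {
  for (const Pattern &pattern : patterns) {
    MatchResult result = pattern(opName);
    if (result.success)
      return std::nullopt; // converted, so no diagnostic
    debugLog.push_back(result.reason);
  }
  return "failed to legalize operation '" + opName +
         "' that was explicitly marked illegal";
}
```

This is why `expected-error` matches the driver's message, while the more specific reason ("A-matrix element type not supported ...") only shows up when debug output is enabled.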

**Note:** Since `getChipStr` requires an `XeVMTargetAttr`, I added a `#xevm.target<chip = "pvc">` target to the `gpu.module` in the `dpas.mlir` conversion test.
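
Without a target attribute, the new pattern bails out with "cannot determine target chip", which is why `dpas.mlir` needed a target. Below is a toy C++ model of that behavior. It assumes `getChipStr` returns the chip of the first XeVM target attribute on the enclosing `gpu.module`, and `std::nullopt` otherwise. `TargetAttr`, `GpuModule`, and `getChipStrModel` are hypothetical names.

```cpp
#include <cassert>
#include <optional>
#include <string>
#include <vector>

// Hypothetical model of a target attribute attached to a gpu.module.
struct TargetAttr {
  std::string dialect; // e.g. "xevm"
  std::string chip;    // e.g. "pvc"
};

// Hypothetical model of a gpu.module; targets is empty when none are listed.
struct GpuModule {
  std::vector<TargetAttr> targets;
};

// Assumed getChipStr behavior: the chip of the first XeVM target,
// or std::nullopt when the module has no such target.
std::optional<std::string> getChipStrModel(const GpuModule &module) {
  for (const TargetAttr &target : module.targets)
    if (target.dialect == "xevm")
      return target.chip;
  return std::nullopt;
}
```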

Closes #180107.

cc: @Jianhui-Li @adam-smnk

---
Full diff: https://github.com/llvm/llvm-project/pull/185081.diff


3 Files Affected:

- (modified) mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp (+34) 
- (modified) mlir/test/Conversion/XeGPUToXeVM/dpas.mlir (+1-1) 
- (added) mlir/test/Conversion/XeGPUToXeVM/failed_conversion.mlir (+14) 


``````````diff
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 6df209438447b..5aa78f5b5b9f2 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -22,6 +22,7 @@
 #include "mlir/Dialect/SCF/Transforms/Patterns.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h"
 #include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Support/LLVM.h"
@@ -902,6 +903,39 @@ class DpasToXeVMPattern : public OpConversionPattern<xegpu::DpasOp> {
     auto bTy = cast<VectorType>(op.getRhs().getType());
     auto resultType = cast<VectorType>(op.getResultType());
 
+    // Look up the DPAS instruction registered for the target chip's uArch.
+    auto chipStr = xegpu::getChipStr(op);
+    if (!chipStr)
+      return rewriter.notifyMatchFailure(op, "cannot determine target chip");
+
+    const auto *uArch = mlir::xegpu::uArch::getUArch(*chipStr);
+    if (!uArch)
+      return rewriter.notifyMatchFailure(op, "unsupported target uArch");
+
+    auto *dpasInst = const_cast<xegpu::uArch::SubgroupMatrixMultiplyAcc*>(
+        llvm::dyn_cast_or_null<xegpu::uArch::SubgroupMatrixMultiplyAcc>(
+            uArch->getInstruction(
+                xegpu::uArch::InstructionKind::SubgroupMatrixMultiplyAcc)));
+    if (!dpasInst)
+      return rewriter.notifyMatchFailure(op, "DPAS not supported by target uArch");
+
+    auto supportedA = dpasInst->getSupportedTypes(*ctxt, xegpu::uArch::MMAOpndKind::MatrixA);
+    auto supportedB = dpasInst->getSupportedTypes(*ctxt, xegpu::uArch::MMAOpndKind::MatrixB);
+    auto supportedD = dpasInst->getSupportedTypes(*ctxt, xegpu::uArch::MMAOpndKind::MatrixD);
+    // NOTE: Supported types for MatrixC and MatrixD are identical
+
+    if (!llvm::is_contained(supportedA, aTy.getElementType()))
+      return rewriter.notifyMatchFailure(
+          op, "A-matrix element type not supported by target uArch");
+
+    if (!llvm::is_contained(supportedB, bTy.getElementType()))
+      return rewriter.notifyMatchFailure(
+          op, "B-matrix element type not supported by target uArch");
+
+    if (!llvm::is_contained(supportedD, resultType.getElementType()))
+      return rewriter.notifyMatchFailure(
+          op, "result/accumulator element type not supported by target uArch");
+
     auto encodePrecision = [&](Type type) -> xevm::ElemType {
       if (type == rewriter.getBF16Type())
         return xevm::ElemType::BF16;
diff --git a/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir b/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir
index a9ab0be00722c..7cc59f4cfdcdf 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir
@@ -1,6 +1,6 @@
 // RUN: mlir-opt -convert-xegpu-to-xevm %s | FileCheck %s
 
-gpu.module @test_kernel {
+gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
     // CHECK-LABEL: func.func @dpas(
     // CHECK-SAME: %[[ARG0:.*]]: vector<8xf16>, %[[ARG1:.*]]: vector<16xf16>, %[[ARG2:.*]]: vector<8xf32>
     func.func @dpas(%a_loaded: vector<8xf16>, %b_loaded: vector<16xf16>, %c_loaded: vector<8xf32>) -> vector<8xf32> {
diff --git a/mlir/test/Conversion/XeGPUToXeVM/failed_conversion.mlir b/mlir/test/Conversion/XeGPUToXeVM/failed_conversion.mlir
new file mode 100644
index 0000000000000..c4088c430a4d9
--- /dev/null
+++ b/mlir/test/Conversion/XeGPUToXeVM/failed_conversion.mlir
@@ -0,0 +1,14 @@
+// RUN: mlir-opt --convert-xegpu-to-xevm %s -split-input-file -verify-diagnostics
+
+// Verify that xegpu.dpas with unsupported element types (i16) is rejected
+// during XeGPUToXeVM conversion rather than crashing.
+
+gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
+  func.func @main() {
+    %0 = spirv.Constant dense<0> : vector<4xi16>
+    %1 = spirv.Constant dense<0> : vector<4xi32>
+    // expected-error @+1 {{failed to legalize operation 'xegpu.dpas' that was explicitly marked illegal}}
+    %2 = xegpu.dpas %0, %0, %1 : vector<4xi16>, vector<4xi16>, vector<4xi32> -> vector<4xi32>
+    return
+  }
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/185081

