[Mlir-commits] [mlir] [mlir][xegpu] Remove vector contract to dpas size restriction (PR #147470)

Adam Siemieniuk llvmlistbot at llvm.org
Tue Jul 8 00:28:42 PDT 2025


https://github.com/adam-smnk created https://github.com/llvm/llvm-project/pull/147470

Removes the contraction shape check to allow representing large workgroup-level workloads in preparation for distribution.

>From b054f6533f1ee6867ec29ea29683a4f530f8af5f Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Tue, 8 Jul 2025 09:25:15 +0200
Subject: [PATCH] [mlir][xegpu] Remove vector contract to dpas size restriction

Removes the contraction shape check to allow representing large
workgroup-level workloads in preparation for distribution.
---
 .../VectorToXeGPU/VectorToXeGPU.cpp           |  7 ---
 .../VectorToXeGPU/contract-to-xegpu.mlir      | 46 +++++++++++--------
 2 files changed, 28 insertions(+), 25 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
index 0ec7129a40a66..2e6a16ddbfdaa 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -339,13 +339,6 @@ struct ContractionLowering : public OpRewritePattern<vector::ContractionOp> {
     if (!isRowMajorMatmul(contractOp.getIndexingMapsAttr()))
       return rewriter.notifyMatchFailure(contractOp, "Invalid indexing maps");
 
-    // TODO: Update shape validation to be target aware.
-    auto accShape = accType.getShape();
-    int64_t dimN = accShape[1];
-    if (dimN != 8 && dimN != 16)
-      return rewriter.notifyMatchFailure(contractOp,
-                                         "Invalid operand dimensions");
-
     auto dpasOp = rewriter.create<xegpu::DpasOp>(
         loc, TypeRange{contractOp.getResultType()}, ValueRange{lhs, rhs, acc});
     rewriter.replaceOp(contractOp, dpasOp);
diff --git a/mlir/test/Conversion/VectorToXeGPU/contract-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/contract-to-xegpu.mlir
index 8857ac204adca..38bda39d3aca2 100644
--- a/mlir/test/Conversion/VectorToXeGPU/contract-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/contract-to-xegpu.mlir
@@ -48,6 +48,34 @@ func.func @dpas_gemm_i8(%lhs: vector<8x32xi8>, %rhs: vector<32x16xi8>,
 
 // -----
 
+// No restriction on vector sizes to allow capturing workgroup-sized operations.
+// The operations can then be progressively resized through distribution down
+// to hardware-compatible sizes.
+
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @dpas_large_dims(%lhs: vector<128x512xf16>, %rhs: vector<512x256xf16>,
+    %acc: vector<128x256xf32>) -> vector<128x256xf32> {
+  %3 = vector.contract
+    {indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>} %lhs, %rhs, %acc
+    : vector<128x512xf16>, vector<512x256xf16> into vector<128x256xf32>
+  return %3 : vector<128x256xf32>
+}
+
+// CHECK-LABEL: @dpas_large_dims(
+// CHECK-SAME:  %[[LHS:.+]]: vector<128x512xf16>,
+// CHECK-SAME:  %[[RHS:.+]]: vector<512x256xf16>,
+// CHECK-SAME:  %[[ACC:.+]]: vector<128x256xf32>
+// CHECK:       %[[DPAS:.+]] = xegpu.dpas
+// CHECK-SAME:    %[[LHS]], %[[RHS]], %[[ACC]]
+// CHECK-SAME:    {{.*}}-> vector<128x256xf32>
+// CHECK:       return %[[DPAS]]
+
+// -----
+
 // For simplicity, only plain data layouts are currently supported.
 // VNNI packing is applied later as a separate lowering step.
 
@@ -138,21 +166,3 @@ func.func @negative_gemm_transpose_b(%lhs: vector<8x16xf16>, %rhs: vector<16x16x
 
 // CHECK-LABEL: @negative_gemm_transpose_b(
 // CHECK:       vector.contract
-
-// -----
-
-#map = affine_map<(d0, d1, d2) -> (d0, d2)>
-#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
-#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
-func.func @negative_n_dim_size(%lhs: vector<8x16xf16>, %rhs: vector<16x32xf16>,
-    %acc: vector<8x32xf32>) -> vector<8x32xf32> {
-  %3 = vector.contract
-    {indexing_maps = [#map, #map1, #map2],
-    iterator_types = ["parallel", "parallel", "reduction"],
-    kind = #vector.kind<add>} %lhs, %rhs, %acc
-    : vector<8x16xf16>, vector<16x32xf16> into vector<8x32xf32>
-  return %3 : vector<8x32xf32>
-}
-
-// CHECK-LABEL: @negative_n_dim_size(
-// CHECK:       vector.contract



More information about the Mlir-commits mailing list