[Mlir-commits] [mlir] [mlir][xegpu] Remove vector contract to dpas size restriction (PR #147470)
Adam Siemieniuk
llvmlistbot at llvm.org
Tue Jul 8 00:28:42 PDT 2025
https://github.com/adam-smnk created https://github.com/llvm/llvm-project/pull/147470
Removes the contraction shape check to allow representing large workgroup-level workloads in preparation for distribution.
>From b054f6533f1ee6867ec29ea29683a4f530f8af5f Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Tue, 8 Jul 2025 09:25:15 +0200
Subject: [PATCH] [mlir][xegpu] Remove vector contract to dpas size restriction
Removes the contraction shape check to allow representing large
workgroup-level workloads in preparation for distribution.
---
.../VectorToXeGPU/VectorToXeGPU.cpp | 7 ---
.../VectorToXeGPU/contract-to-xegpu.mlir | 46 +++++++++++--------
2 files changed, 28 insertions(+), 25 deletions(-)
diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
index 0ec7129a40a66..2e6a16ddbfdaa 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -339,13 +339,6 @@ struct ContractionLowering : public OpRewritePattern<vector::ContractionOp> {
if (!isRowMajorMatmul(contractOp.getIndexingMapsAttr()))
return rewriter.notifyMatchFailure(contractOp, "Invalid indexing maps");
- // TODO: Update shape validation to be target aware.
- auto accShape = accType.getShape();
- int64_t dimN = accShape[1];
- if (dimN != 8 && dimN != 16)
- return rewriter.notifyMatchFailure(contractOp,
- "Invalid operand dimensions");
-
auto dpasOp = rewriter.create<xegpu::DpasOp>(
loc, TypeRange{contractOp.getResultType()}, ValueRange{lhs, rhs, acc});
rewriter.replaceOp(contractOp, dpasOp);
diff --git a/mlir/test/Conversion/VectorToXeGPU/contract-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/contract-to-xegpu.mlir
index 8857ac204adca..38bda39d3aca2 100644
--- a/mlir/test/Conversion/VectorToXeGPU/contract-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/contract-to-xegpu.mlir
@@ -48,6 +48,34 @@ func.func @dpas_gemm_i8(%lhs: vector<8x32xi8>, %rhs: vector<32x16xi8>,
// -----
+// No restriction on vector sizes to allow capturing workgroup-sized operations.
+// The operations can then be progressively resized through distribution down
+// to hardware-compatible sizes.
+
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @dpas_large_dims(%lhs: vector<128x512xf16>, %rhs: vector<512x256xf16>,
+ %acc: vector<128x256xf32>) -> vector<128x256xf32> {
+ %3 = vector.contract
+ {indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>} %lhs, %rhs, %acc
+ : vector<128x512xf16>, vector<512x256xf16> into vector<128x256xf32>
+ return %3 : vector<128x256xf32>
+}
+
+// CHECK-LABEL: @dpas_large_dims(
+// CHECK-SAME: %[[LHS:.+]]: vector<128x512xf16>,
+// CHECK-SAME: %[[RHS:.+]]: vector<512x256xf16>,
+// CHECK-SAME: %[[ACC:.+]]: vector<128x256xf32>
+// CHECK: %[[DPAS:.+]] = xegpu.dpas
+// CHECK-SAME: %[[LHS]], %[[RHS]], %[[ACC]]
+// CHECK-SAME: {{.*}}-> vector<128x256xf32>
+// CHECK: return %[[DPAS]]
+
+// -----
+
// For simplicity, only plain data layouts are currently supported.
// VNNI packing is applied later as a separate lowering step.
@@ -138,21 +166,3 @@ func.func @negative_gemm_transpose_b(%lhs: vector<8x16xf16>, %rhs: vector<16x16x
// CHECK-LABEL: @negative_gemm_transpose_b(
// CHECK: vector.contract
-
-// -----
-
-#map = affine_map<(d0, d1, d2) -> (d0, d2)>
-#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
-#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
-func.func @negative_n_dim_size(%lhs: vector<8x16xf16>, %rhs: vector<16x32xf16>,
- %acc: vector<8x32xf32>) -> vector<8x32xf32> {
- %3 = vector.contract
- {indexing_maps = [#map, #map1, #map2],
- iterator_types = ["parallel", "parallel", "reduction"],
- kind = #vector.kind<add>} %lhs, %rhs, %acc
- : vector<8x16xf16>, vector<16x32xf16> into vector<8x32xf32>
- return %3 : vector<8x32xf32>
-}
-
-// CHECK-LABEL: @negative_n_dim_size(
-// CHECK: vector.contract
More information about the Mlir-commits
mailing list