[Mlir-commits] [mlir] [MLIR][XeGPU] Support vector.contract transpose_a/transpose_b via 'vector-to-gpu' patterns (PR #182885)
Dmitry Chigarev
llvmlistbot at llvm.org
Sun Mar 1 13:20:49 PST 2026
https://github.com/dchigarev updated https://github.com/llvm/llvm-project/pull/182885
>From a5b55995ea465040319cd048e8f4201ea6e043a0 Mon Sep 17 00:00:00 2001
From: dchigarev <dmitry.chigarev at intel.com>
Date: Fri, 20 Feb 2026 16:41:54 +0000
Subject: [PATCH 1/2] [MLIR][XeGPU] Support vector.contract
transpose_a/transpose_b via 'vector-to-gpu' patterns
Signed-off-by: dchigarev <dmitry.chigarev at intel.com>
---
.../VectorToXeGPU/VectorToXeGPU.cpp | 6 ++
.../VectorToXeGPU/contract-to-xegpu.mlir | 86 +++++++++++--------
2 files changed, 56 insertions(+), 36 deletions(-)
diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
index c81bb4b455b98..1025632f4f14c 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -11,6 +11,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/VectorToXeGPU/VectorToXeGPU.h"
+#include "mlir/Conversion/VectorToGPU/VectorToGPU.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
@@ -844,6 +845,11 @@ struct ContractionLowering : public OpRewritePattern<vector::ContractionOp> {
struct ConvertVectorToXeGPUPass
: public impl::ConvertVectorToXeGPUBase<ConvertVectorToXeGPUPass> {
void runOnOperation() override {
+ RewritePatternSet prep(&getContext());
+ populatePrepareVectorToMMAPatterns(prep);
+ if (failed(applyPatternsGreedily(getOperation(), std::move(prep))))
+ return signalPassFailure();
+
RewritePatternSet patterns(&getContext());
populateVectorToXeGPUConversionPatterns(patterns);
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
diff --git a/mlir/test/Conversion/VectorToXeGPU/contract-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/contract-to-xegpu.mlir
index 38bda39d3aca2..292e4ff882000 100644
--- a/mlir/test/Conversion/VectorToXeGPU/contract-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/contract-to-xegpu.mlir
@@ -76,6 +76,56 @@ func.func @dpas_large_dims(%lhs: vector<128x512xf16>, %rhs: vector<512x256xf16>,
// -----
+#map = affine_map<(d0, d1, d2) -> (d2, d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @gemm_transpose_a(%lhs: vector<16x8xf16>, %rhs: vector<16x16xf16>,
+ %acc: vector<8x16xf32>) -> vector<8x16xf32> {
+ %3 = vector.contract
+ {indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>} %lhs, %rhs, %acc
+ : vector<16x8xf16>, vector<16x16xf16> into vector<8x16xf32>
+ return %3 : vector<8x16xf32>
+}
+
+// CHECK-LABEL: @gemm_transpose_a(
+// CHECK-SAME: %[[LHS:.+]]: vector<16x8xf16>,
+// CHECK-SAME: %[[RHS:.+]]: vector<16x16xf16>,
+// CHECK-SAME: %[[ACC:.+]]: vector<8x16xf32>
+// CHECK: %[[LHS_TRANSPOSED:.+]] = vector.transpose %[[LHS]], [1, 0] : vector<16x8xf16> to vector<8x16xf16>
+// CHECK: %[[DPAS:.+]] = xegpu.dpas
+// CHECK-SAME: %[[LHS_TRANSPOSED]], %[[RHS]], %[[ACC]]
+// CHECK-SAME: {{.*}}-> vector<8x16xf32>
+// CHECK: return %[[DPAS]]
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @gemm_transpose_b(%lhs: vector<8x16xf16>, %rhs: vector<16x16xf16>,
+ %acc: vector<8x16xf32>) -> vector<8x16xf32> {
+ %3 = vector.contract
+ {indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>} %lhs, %rhs, %acc
+ : vector<8x16xf16>, vector<16x16xf16> into vector<8x16xf32>
+ return %3 : vector<8x16xf32>
+}
+
+// CHECK-LABEL: @gemm_transpose_b(
+// CHECK-SAME: %[[LHS:.+]]: vector<8x16xf16>,
+// CHECK-SAME: %[[RHS:.+]]: vector<16x16xf16>,
+// CHECK-SAME: %[[ACC:.+]]: vector<8x16xf32>
+// CHECK: %[[RHS_TRANSPOSED:.+]] = vector.transpose %[[RHS]], [1, 0] : vector<16x16xf16> to vector<16x16xf16>
+// CHECK: %[[DPAS:.+]] = xegpu.dpas
+// CHECK-SAME: %[[LHS]], %[[RHS_TRANSPOSED]], %[[ACC]]
+// CHECK-SAME: {{.*}}-> vector<8x16xf32>
+// CHECK: return %[[DPAS]]
+
+// -----
+
// For simplicity, only plain data layouts are currently supported.
// VNNI packing is applied later as a separate lowering step.
@@ -130,39 +180,3 @@ func.func @negative_accumulator_shape(%lhs: vector<8x16xf16>, %rhs: vector<16x16
// CHECK-LABEL: @negative_accumulator_shape(
// CHECK: vector.contract
-
-// -----
-
-#map = affine_map<(d0, d1, d2) -> (d2, d0)>
-#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
-#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
-func.func @negative_gemm_transpose_a(%lhs: vector<16x8xf16>, %rhs: vector<16x16xf16>,
- %acc: vector<8x16xf32>) -> vector<8x16xf32> {
- %3 = vector.contract
- {indexing_maps = [#map, #map1, #map2],
- iterator_types = ["parallel", "parallel", "reduction"],
- kind = #vector.kind<add>} %lhs, %rhs, %acc
- : vector<16x8xf16>, vector<16x16xf16> into vector<8x16xf32>
- return %3 : vector<8x16xf32>
-}
-
-// CHECK-LABEL: @negative_gemm_transpose_a(
-// CHECK: vector.contract
-
-// -----
-
-#map = affine_map<(d0, d1, d2) -> (d0, d2)>
-#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
-#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
-func.func @negative_gemm_transpose_b(%lhs: vector<8x16xf16>, %rhs: vector<16x16xf16>,
- %acc: vector<8x16xf32>) -> vector<8x16xf32> {
- %3 = vector.contract
- {indexing_maps = [#map, #map1, #map2],
- iterator_types = ["parallel", "parallel", "reduction"],
- kind = #vector.kind<add>} %lhs, %rhs, %acc
- : vector<8x16xf16>, vector<16x16xf16> into vector<8x16xf32>
- return %3 : vector<8x16xf32>
-}
-
-// CHECK-LABEL: @negative_gemm_transpose_b(
-// CHECK: vector.contract
>From 78a803820f83ccb53b986c68a3d31b0155bac2d7 Mon Sep 17 00:00:00 2001
From: dchigarev <dmitry.chigarev at intel.com>
Date: Sun, 1 Mar 2026 21:20:35 +0000
Subject: [PATCH 2/2] apply new patterns within the existing pattern-set
Signed-off-by: dchigarev <dmitry.chigarev at intel.com>
---
mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp | 6 +-----
1 file changed, 1 insertion(+), 5 deletions(-)
diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
index 1025632f4f14c..2bece58a119b3 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -845,13 +845,9 @@ struct ContractionLowering : public OpRewritePattern<vector::ContractionOp> {
struct ConvertVectorToXeGPUPass
: public impl::ConvertVectorToXeGPUBase<ConvertVectorToXeGPUPass> {
void runOnOperation() override {
- RewritePatternSet prep(&getContext());
- populatePrepareVectorToMMAPatterns(prep);
- if (failed(applyPatternsGreedily(getOperation(), std::move(prep))))
- return signalPassFailure();
-
RewritePatternSet patterns(&getContext());
populateVectorToXeGPUConversionPatterns(patterns);
+ populatePrepareVectorToMMAPatterns(patterns);
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
return signalPassFailure();
}
More information about the Mlir-commits
mailing list