[Mlir-commits] [mlir] [mlir][xegpu] Convert Vector contraction to XeGPU (PR #122115)
Adam Siemieniuk
llvmlistbot at llvm.org
Tue Mar 11 06:25:57 PDT 2025
https://github.com/adam-smnk updated https://github.com/llvm/llvm-project/pull/122115
From 1e00eb8403ef97d59e3f7e6fcc22f1fe868720da Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Tue, 7 Jan 2025 16:54:52 +0100
Subject: [PATCH 1/3] [mlir][xegpu] Convert Vector contraction to XeGPU
Adds a pattern to lower vector.contract to the XeGPU dpas operation.
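For illustration, a minimal before/after sketch of the rewrite on a plain f16 GEMM (shapes mirror the dpas_gemm_f16 test added below; the printed form of xegpu.dpas is abbreviated from the FileCheck patterns):

  #map = affine_map<(d0, d1, d2) -> (d0, d2)>   // lhs: (m, k)
  #map1 = affine_map<(d0, d1, d2) -> (d2, d1)>  // rhs: (k, n)
  #map2 = affine_map<(d0, d1, d2) -> (d0, d1)>  // acc: (m, n)

  // Before: a canonical row-major GEMM expressed as a contraction.
  %0 = vector.contract
    {indexing_maps = [#map, #map1, #map2],
     iterator_types = ["parallel", "parallel", "reduction"],
     kind = #vector.kind<add>} %lhs, %rhs, %acc
    : vector<8x16xf16>, vector<16x16xf16> into vector<8x16xf16>

  // After: a single DPAS (matrix multiply-accumulate) operation.
  %0 = xegpu.dpas %lhs, %rhs, %acc
    : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf16> -> vector<8x16xf16>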
---
.../VectorToXeGPU/VectorToXeGPU.cpp | 87 +++++-
.../VectorToXeGPU/contract-to-xegpu.mlir | 259 ++++++++++++++++++
2 files changed, 345 insertions(+), 1 deletion(-)
create mode 100644 mlir/test/Conversion/VectorToXeGPU/contract-to-xegpu.mlir
diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
index 61b55c57240ce..fd96341768199 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -312,6 +312,91 @@ struct StoreLowering : public OpRewritePattern<vector::StoreOp> {
}
};
+static LogicalResult validateDpasIndexing(PatternRewriter &rewriter,
+ vector::ContractionOp contractOp) {
+ MLIRContext *ctx = contractOp.getContext();
+ SmallVector<AffineMap, 4> maps = contractOp.getIndexingMapsArray();
+
+ // Operand rank defines expected data layout:
+ // - 2D for standard GEMM
+ // - 3D for VNNI layout
+ using MapList = ArrayRef<ArrayRef<AffineExpr>>;
+ auto infer = [&](MapList m) { return AffineMap::inferFromExprList(m, ctx); };
+ AffineExpr m, n, k, vnni;
+ bindDims(ctx, m, n, k, vnni);
+
+ if (contractOp.getRhsType().getRank() == 2) {
+ // Require plain GEMM without any transposition.
+ return success(maps == infer({{m, k}, {k, n}, {m, n}}));
+ }
+
+ // Require VNNI layout.
+ return success(maps == infer({{m, k, vnni}, {k, n, vnni}, {m, n}}));
+}
+
+struct ContractionLowering : public OpRewritePattern<vector::ContractionOp> {
+ using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
+ PatternRewriter &rewriter) const override {
+ Location loc = contractOp.getLoc();
+
+ if (contractOp.getKind() != vector::CombiningKind::ADD)
+ return rewriter.notifyMatchFailure(contractOp,
+ "Expects add combining kind");
+
+ TypedValue<Type> acc = contractOp.getAcc();
+ VectorType accType = dyn_cast<VectorType>(acc.getType());
+ if (!accType || accType.getRank() != 2)
+ return rewriter.notifyMatchFailure(contractOp, "Expects acc 2D vector");
+ TypedValue<VectorType> lhs = contractOp.getLhs();
+ VectorType lhsType = lhs.getType();
+ int64_t lhsRank = lhsType.getRank();
+ if (!(lhsRank == 2 || lhsRank == 3))
+ return rewriter.notifyMatchFailure(contractOp,
+ "Expects lhs 2D or 3D vector");
+ TypedValue<VectorType> rhs = contractOp.getRhs();
+ VectorType rhsType = rhs.getType();
+ int64_t rhsRank = rhsType.getRank();
+ if (!(rhsRank == 2 || rhsRank == 3))
+ return rewriter.notifyMatchFailure(contractOp,
+ "Expects rhs 2D or 3D vector");
+ if (lhsRank != rhsRank)
+ return rewriter.notifyMatchFailure(
+ contractOp, "Expects lhs and rhs to be the same rank");
+
+ if (failed(validateDpasIndexing(rewriter, contractOp)))
+ return rewriter.notifyMatchFailure(contractOp, "Invalid indexing maps");
+
+ // 3D shape implies VNNI layout verified by the earlier indexing check.
+ bool isVnni = rhsRank == 3;
+ auto rhsShape = rhsType.getShape();
+ int64_t dimK = isVnni ? rhsShape[0] * rhsShape[2] : rhsShape[0];
+ unsigned elemBitWidth = rhsType.getElementType().getIntOrFloatBitWidth();
+ if (dimK != (8 * 32 / elemBitWidth))
+ return rewriter.notifyMatchFailure(contractOp,
+ "Invalid K-dimension size");
+ if (isVnni && rhsShape[2] != (32 / elemBitWidth))
+ return rewriter.notifyMatchFailure(contractOp, "Invalid VNNI factor");
+
+ if (isVnni) {
+ // Collapse contract lhs VNNI factor back into K-dim as dpas op expects
+ // flat 2D shape for its lhs operand.
+ auto lhsShape = lhsType.getShape();
+ auto lhsFlatType = VectorType::get(
+ {lhsShape[0], lhsShape[1] * lhsShape[2]}, lhsType.getElementType());
+ lhs = rewriter.create<vector::ShapeCastOp>(loc, lhsFlatType, lhs)
+ .getResult();
+ }
+
+ auto dpasOp = rewriter.create<xegpu::DpasOp>(
+ loc, contractOp.getResultType(), lhs, rhs, acc);
+ rewriter.replaceOp(contractOp, dpasOp);
+
+ return success();
+ }
+};
+
struct ConvertVectorToXeGPUPass
: public impl::ConvertVectorToXeGPUBase<ConvertVectorToXeGPUPass> {
void runOnOperation() override {
@@ -327,5 +412,5 @@ struct ConvertVectorToXeGPUPass
void mlir::populateVectorToXeGPUConversionPatterns(
RewritePatternSet &patterns) {
patterns.add<TransferReadLowering, TransferWriteLowering, LoadLowering,
- StoreLowering>(patterns.getContext());
+ StoreLowering, ContractionLowering>(patterns.getContext());
}
diff --git a/mlir/test/Conversion/VectorToXeGPU/contract-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/contract-to-xegpu.mlir
new file mode 100644
index 0000000000000..c470422e5ac76
--- /dev/null
+++ b/mlir/test/Conversion/VectorToXeGPU/contract-to-xegpu.mlir
@@ -0,0 +1,259 @@
+// RUN: mlir-opt %s -convert-vector-to-xegpu -split-input-file | FileCheck %s
+
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @dpas_gemm_f32(%lhs: vector<8x8xf32>, %rhs: vector<8x16xf32>,
+ %acc: vector<8x16xf32>) -> vector<8x16xf32> {
+ %3 = vector.contract
+ {indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>} %lhs, %rhs, %acc
+ : vector<8x8xf32>, vector<8x16xf32> into vector<8x16xf32>
+ return %3 : vector<8x16xf32>
+}
+
+// CHECK-LABEL: @dpas_gemm_f32(
+// CHECK-SAME: %[[LHS:.+]]: vector<8x8xf32>,
+// CHECK-SAME: %[[RHS:.+]]: vector<8x16xf32>,
+// CHECK-SAME: %[[ACC:.+]]: vector<8x16xf32>
+// CHECK: %[[DPAS:.+]] = xegpu.dpas
+// CHECK-SAME: %[[LHS]], %[[RHS]], %[[ACC]]
+// CHECK-SAME: {{.*}}-> vector<8x16xf32>
+// CHECK: return %[[DPAS]]
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @dpas_gemm_f16(%lhs: vector<8x16xf16>, %rhs: vector<16x16xf16>,
+ %acc: vector<8x16xf16>) -> vector<8x16xf16> {
+ %3 = vector.contract
+ {indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>} %lhs, %rhs, %acc
+ : vector<8x16xf16>, vector<16x16xf16> into vector<8x16xf16>
+ return %3 : vector<8x16xf16>
+}
+
+// CHECK-LABEL: @dpas_gemm_f16(
+// CHECK-SAME: %[[LHS:.+]]: vector<8x16xf16>,
+// CHECK-SAME: %[[RHS:.+]]: vector<16x16xf16>,
+// CHECK-SAME: %[[ACC:.+]]: vector<8x16xf16>
+// CHECK: %[[DPAS:.+]] = xegpu.dpas
+// CHECK-SAME: %[[LHS]], %[[RHS]], %[[ACC]]
+// CHECK-SAME: {{.*}}-> vector<8x16xf16>
+// CHECK: return %[[DPAS]]
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d2, d1, d3)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
+func.func @dpas_gemm_f16_vnni(%lhs: vector<8x8x2xf16>, %rhs: vector<8x16x2xf16>,
+ %acc: vector<8x16xf16>) -> vector<8x16xf16> {
+ %3 = vector.contract
+ {indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction", "reduction"],
+ kind = #vector.kind<add>} %lhs, %rhs, %acc
+ : vector<8x8x2xf16>, vector<8x16x2xf16> into vector<8x16xf16>
+ return %3 : vector<8x16xf16>
+}
+
+// CHECK-LABEL: @dpas_gemm_f16_vnni(
+// CHECK-SAME: %[[LHS:.+]]: vector<8x8x2xf16>,
+// CHECK-SAME: %[[RHS:.+]]: vector<8x16x2xf16>,
+// CHECK-SAME: %[[ACC:.+]]: vector<8x16xf16>
+// CHECK: %[[CAST_LHS:.+]] = vector.shape_cast %[[LHS]]
+// CHECK-SAME: vector<8x8x2xf16> to vector<8x16xf16>
+// CHECK: %[[DPAS:.+]] = xegpu.dpas
+// CHECK-SAME: %[[CAST_LHS]], %[[RHS]], %[[ACC]]
+// CHECK-SAME: {{.*}}-> vector<8x16xf16>
+// CHECK: return %[[DPAS]]
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @dpas_gemm_mixed_types(%lhs: vector<8x16xi16>, %rhs: vector<16x16xi16>,
+ %acc: vector<8x16xf16>) -> vector<8x16xf16> {
+ %3 = vector.contract
+ {indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>} %lhs, %rhs, %acc
+ : vector<8x16xi16>, vector<16x16xi16> into vector<8x16xf16>
+ return %3 : vector<8x16xf16>
+}
+
+// CHECK-LABEL: @dpas_gemm_mixed_types(
+// CHECK-SAME: %[[LHS:.+]]: vector<8x16xi16>,
+// CHECK-SAME: %[[RHS:.+]]: vector<16x16xi16>,
+// CHECK-SAME: %[[ACC:.+]]: vector<8x16xf16>
+// CHECK: %[[DPAS:.+]] = xegpu.dpas
+// CHECK-SAME: %[[LHS]], %[[RHS]], %[[ACC]]
+// CHECK-SAME: {{.*}}-> vector<8x16xf16>
+// CHECK: return %[[DPAS]]
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @invalid_combining_type(%lhs: vector<8x16xf16>, %rhs: vector<16x16xf16>,
+ %acc: vector<8x16xf16>) -> vector<8x16xf16> {
+ %3 = vector.contract
+ {indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<mul>} %lhs, %rhs, %acc
+ : vector<8x16xf16>, vector<16x16xf16> into vector<8x16xf16>
+ return %3 : vector<8x16xf16>
+}
+
+// CHECK-LABEL: @invalid_combining_type(
+// CHECK: vector.contract
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> ()>
+func.func @invalid_accumulator_shape(%lhs: vector<8x16xf16>, %rhs: vector<16x16xf16>,
+ %acc: vector<f16>) -> vector<f16> {
+ %3 = vector.contract
+ {indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["reduction", "reduction", "reduction"],
+ kind = #vector.kind<add>} %lhs, %rhs, %acc
+ : vector<8x16xf16>, vector<16x16xf16> into vector<f16>
+ return %3 : vector<f16>
+}
+
+// CHECK-LABEL: @invalid_accumulator_shape(
+// CHECK: vector.contract
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>
+#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2, d4)>
+#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d1, d2)>
+func.func @invalid_high_dim_reduction(%lhs: vector<3x8x8x2xf16>, %rhs: vector<3x8x16x2xf16>,
+ %acc: vector<8x16xf16>) -> vector<8x16xf16> {
+ %3 = vector.contract
+ {indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["reduction", "parallel", "parallel", "reduction", "reduction"],
+ kind = #vector.kind<add>} %lhs, %rhs, %acc
+ : vector<3x8x8x2xf16>, vector<3x8x16x2xf16> into vector<8x16xf16>
+ return %3 : vector<8x16xf16>
+}
+
+// CHECK-LABEL: @invalid_high_dim_reduction(
+// CHECK: vector.contract
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+func.func @invalid_indexing_maps(%lhs: vector<3x8x16xf16>, %rhs: vector<3x16x16xf16>,
+ %acc: vector<8x16xf16>) -> vector<8x16xf16> {
+ %3 = vector.contract
+ {indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["reduction", "parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>} %lhs, %rhs, %acc
+ : vector<3x8x16xf16>, vector<3x16x16xf16> into vector<8x16xf16>
+ return %3 : vector<8x16xf16>
+}
+
+// CHECK-LABEL: @invalid_indexing_maps(
+// CHECK: vector.contract
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
+func.func @not_vnni_layout(%lhs: vector<8x8x2xf16>, %rhs: vector<16x8x2xf16>,
+ %acc: vector<8x16xf16>) -> vector<8x16xf16> {
+ %3 = vector.contract
+ {indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction", "reduction"],
+ kind = #vector.kind<add>} %lhs, %rhs, %acc
+ : vector<8x8x2xf16>, vector<16x8x2xf16> into vector<8x16xf16>
+ return %3 : vector<8x16xf16>
+}
+
+// CHECK-LABEL: @not_vnni_layout(
+// CHECK: vector.contract
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d2, d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @invalid_gemm_transpose_a(%lhs: vector<8x8xf32>, %rhs: vector<8x16xf32>,
+ %acc: vector<8x16xf32>) -> vector<8x16xf32> {
+ %3 = vector.contract
+ {indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>} %lhs, %rhs, %acc
+ : vector<8x8xf32>, vector<8x16xf32> into vector<8x16xf32>
+ return %3 : vector<8x16xf32>
+}
+
+// CHECK-LABEL: @invalid_gemm_transpose_a(
+// CHECK: vector.contract
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @invalid_gemm_transpose_b(%lhs: vector<8x8xf32>, %rhs: vector<16x8xf32>,
+ %acc: vector<8x16xf32>) -> vector<8x16xf32> {
+ %3 = vector.contract
+ {indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>} %lhs, %rhs, %acc
+ : vector<8x8xf32>, vector<16x8xf32> into vector<8x16xf32>
+ return %3 : vector<8x16xf32>
+}
+
+// CHECK-LABEL: @invalid_gemm_transpose_b(
+// CHECK: vector.contract
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d2, d1, d3)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
+func.func @invalid_k_dim_size(%lhs: vector<8x4x2xf16>, %rhs: vector<4x16x2xf16>,
+ %acc: vector<8x16xf16>) -> vector<8x16xf16> {
+ %3 = vector.contract
+ {indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction", "reduction"],
+ kind = #vector.kind<add>} %lhs, %rhs, %acc
+ : vector<8x4x2xf16>, vector<4x16x2xf16> into vector<8x16xf16>
+ return %3 : vector<8x16xf16>
+}
+
+// CHECK-LABEL: @invalid_k_dim_size(
+// CHECK: vector.contract
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d2, d1, d3)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
+func.func @invalid_vnni_factor(%lhs: vector<8x4x4xf16>, %rhs: vector<4x16x4xf16>,
+ %acc: vector<8x16xf16>) -> vector<8x16xf16> {
+ %3 = vector.contract
+ {indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction", "reduction"],
+ kind = #vector.kind<add>} %lhs, %rhs, %acc
+ : vector<8x4x4xf16>, vector<4x16x4xf16> into vector<8x16xf16>
+ return %3 : vector<8x16xf16>
+}
+
+// CHECK-LABEL: @invalid_vnni_factor(
+// CHECK: vector.contract
From 1a1ec4ad2fd37f550adc2718298c453749435cb2 Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Tue, 11 Mar 2025 14:15:52 +0100
Subject: [PATCH 2/3] Simplify validation + update tests
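For context, the rewritten matcher accepts only plain 2D row-major GEMMs (via the upstream isRowMajorMatmul helper) with DPAS-friendly shapes, currently M = 8, N = 16, and K a multiple of 8; pre-packed VNNI operands now bail out and are left to later XeGPU lowering. A minimal accepted example, mirroring the dpas_gemm_i8 test below:

  #map = affine_map<(d0, d1, d2) -> (d0, d2)>
  #map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
  #map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
  %0 = vector.contract
    {indexing_maps = [#map, #map1, #map2],
     iterator_types = ["parallel", "parallel", "reduction"],
     kind = #vector.kind<add>} %lhs, %rhs, %acc
    : vector<8x32xi8>, vector<32x16xi8> into vector<8x16xi32>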
---
.../VectorToXeGPU/VectorToXeGPU.cpp | 71 ++----
.../VectorToXeGPU/contract-to-xegpu.mlir | 205 ++++++------------
2 files changed, 86 insertions(+), 190 deletions(-)
diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
index fd96341768199..f1a37d7d8336c 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -14,6 +14,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Pass/Pass.h"
@@ -312,28 +313,6 @@ struct StoreLowering : public OpRewritePattern<vector::StoreOp> {
}
};
-static LogicalResult validateDpasIndexing(PatternRewriter &rewriter,
- vector::ContractionOp contractOp) {
- MLIRContext *ctx = contractOp.getContext();
- SmallVector<AffineMap, 4> maps = contractOp.getIndexingMapsArray();
-
- // Operand rank defines expected data layout:
- // - 2D for standard GEMM
- // - 3D for VNNI layout
- using MapList = ArrayRef<ArrayRef<AffineExpr>>;
- auto infer = [&](MapList m) { return AffineMap::inferFromExprList(m, ctx); };
- AffineExpr m, n, k, vnni;
- bindDims(ctx, m, n, k, vnni);
-
- if (contractOp.getRhsType().getRank() == 2) {
- // Require plain GEMM without any transposition.
- return success(maps == infer({{m, k}, {k, n}, {m, n}}));
- }
-
- // Require VNNI layout.
- return success(maps == infer({{m, k, vnni}, {k, n, vnni}, {m, n}}));
-}
-
struct ContractionLowering : public OpRewritePattern<vector::ContractionOp> {
using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
@@ -349,48 +328,30 @@ struct ContractionLowering : public OpRewritePattern<vector::ContractionOp> {
VectorType accType = dyn_cast<VectorType>(acc.getType());
if (!accType || accType.getRank() != 2)
return rewriter.notifyMatchFailure(contractOp, "Expects acc 2D vector");
+
+ // Accept only plain 2D data layout.
+ // VNNI packing is left to later lowering.
TypedValue<VectorType> lhs = contractOp.getLhs();
- VectorType lhsType = lhs.getType();
- int64_t lhsRank = lhsType.getRank();
- if (!(lhsRank == 2 || lhsRank == 3))
- return rewriter.notifyMatchFailure(contractOp,
- "Expects lhs 2D or 3D vector");
TypedValue<VectorType> rhs = contractOp.getRhs();
- VectorType rhsType = rhs.getType();
- int64_t rhsRank = rhsType.getRank();
- if (!(rhsRank == 2 || rhsRank == 3))
+ if (lhs.getType().getRank() != 2 || rhs.getType().getRank() != 2)
return rewriter.notifyMatchFailure(contractOp,
- "Expects rhs 2D or 3D vector");
- if (lhsRank != rhsRank)
- return rewriter.notifyMatchFailure(
- contractOp, "Expects lhs and rhs to be the same rank");
+ "Expects lhs and rhs 2D vectors");
- if (failed(validateDpasIndexing(rewriter, contractOp)))
+ if (!isRowMajorMatmul(contractOp.getIndexingMapsAttr()))
return rewriter.notifyMatchFailure(contractOp, "Invalid indexing maps");
- // 3D shape implies VNNI layout verified by the earlier indexing check.
- bool isVnni = rhsRank == 3;
- auto rhsShape = rhsType.getShape();
- int64_t dimK = isVnni ? rhsShape[0] * rhsShape[2] : rhsShape[0];
- unsigned elemBitWidth = rhsType.getElementType().getIntOrFloatBitWidth();
- if (dimK != (8 * 32 / elemBitWidth))
+ // TODO: Update shape validation to be target aware.
+ auto rhsShape = rhs.getType().getShape();
+ auto accShape = accType.getShape();
+ int64_t dimM = accShape[0];
+ int64_t dimN = accShape[1];
+ int64_t dimK = rhsShape[0];
+ if (dimM != 8 || dimN != 16 || dimK % 8 != 0)
return rewriter.notifyMatchFailure(contractOp,
- "Invalid K-dimension size");
- if (isVnni && rhsShape[2] != (32 / elemBitWidth))
- return rewriter.notifyMatchFailure(contractOp, "Invalid VNNI factor");
-
- if (isVnni) {
- // Collapse contract lhs VNNI factor back into K-dim as dpas op expects
- // flat 2D shape for its lhs operand.
- auto lhsShape = lhsType.getShape();
- auto lhsFlatType = VectorType::get(
- {lhsShape[0], lhsShape[1] * lhsShape[2]}, lhsType.getElementType());
- lhs = rewriter.create<vector::ShapeCastOp>(loc, lhsFlatType, lhs)
- .getResult();
- }
+ "Invalid operand dimensions");
auto dpasOp = rewriter.create<xegpu::DpasOp>(
- loc, contractOp.getResultType(), lhs, rhs, acc);
+ loc, TypeRange{contractOp.getResultType()}, ValueRange{lhs, rhs, acc});
rewriter.replaceOp(contractOp, dpasOp);
return success();
diff --git a/mlir/test/Conversion/VectorToXeGPU/contract-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/contract-to-xegpu.mlir
index c470422e5ac76..1e41f8aa8fdd3 100644
--- a/mlir/test/Conversion/VectorToXeGPU/contract-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/contract-to-xegpu.mlir
@@ -3,19 +3,19 @@
#map = affine_map<(d0, d1, d2) -> (d0, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
-func.func @dpas_gemm_f32(%lhs: vector<8x8xf32>, %rhs: vector<8x16xf32>,
+func.func @dpas_gemm_f16(%lhs: vector<8x16xf16>, %rhs: vector<16x16xf16>,
%acc: vector<8x16xf32>) -> vector<8x16xf32> {
%3 = vector.contract
{indexing_maps = [#map, #map1, #map2],
iterator_types = ["parallel", "parallel", "reduction"],
kind = #vector.kind<add>} %lhs, %rhs, %acc
- : vector<8x8xf32>, vector<8x16xf32> into vector<8x16xf32>
+ : vector<8x16xf16>, vector<16x16xf16> into vector<8x16xf32>
return %3 : vector<8x16xf32>
}
-// CHECK-LABEL: @dpas_gemm_f32(
-// CHECK-SAME: %[[LHS:.+]]: vector<8x8xf32>,
-// CHECK-SAME: %[[RHS:.+]]: vector<8x16xf32>,
+// CHECK-LABEL: @dpas_gemm_f16(
+// CHECK-SAME: %[[LHS:.+]]: vector<8x16xf16>,
+// CHECK-SAME: %[[RHS:.+]]: vector<16x16xf16>,
// CHECK-SAME: %[[ACC:.+]]: vector<8x16xf32>
// CHECK: %[[DPAS:.+]] = xegpu.dpas
// CHECK-SAME: %[[LHS]], %[[RHS]], %[[ACC]]
@@ -27,91 +27,62 @@ func.func @dpas_gemm_f32(%lhs: vector<8x8xf32>, %rhs: vector<8x16xf32>,
#map = affine_map<(d0, d1, d2) -> (d0, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
-func.func @dpas_gemm_f16(%lhs: vector<8x16xf16>, %rhs: vector<16x16xf16>,
- %acc: vector<8x16xf16>) -> vector<8x16xf16> {
+func.func @dpas_gemm_i8(%lhs: vector<8x32xi8>, %rhs: vector<32x16xi8>,
+ %acc: vector<8x16xi32>) -> vector<8x16xi32> {
%3 = vector.contract
{indexing_maps = [#map, #map1, #map2],
iterator_types = ["parallel", "parallel", "reduction"],
kind = #vector.kind<add>} %lhs, %rhs, %acc
- : vector<8x16xf16>, vector<16x16xf16> into vector<8x16xf16>
- return %3 : vector<8x16xf16>
+ : vector<8x32xi8>, vector<32x16xi8> into vector<8x16xi32>
+ return %3 : vector<8x16xi32>
}
-// CHECK-LABEL: @dpas_gemm_f16(
-// CHECK-SAME: %[[LHS:.+]]: vector<8x16xf16>,
-// CHECK-SAME: %[[RHS:.+]]: vector<16x16xf16>,
-// CHECK-SAME: %[[ACC:.+]]: vector<8x16xf16>
+// CHECK-LABEL: @dpas_gemm_i8(
+// CHECK-SAME: %[[LHS:.+]]: vector<8x32xi8>,
+// CHECK-SAME: %[[RHS:.+]]: vector<32x16xi8>,
+// CHECK-SAME: %[[ACC:.+]]: vector<8x16xi32>
// CHECK: %[[DPAS:.+]] = xegpu.dpas
// CHECK-SAME: %[[LHS]], %[[RHS]], %[[ACC]]
-// CHECK-SAME: {{.*}}-> vector<8x16xf16>
+// CHECK-SAME: {{.*}}-> vector<8x16xi32>
// CHECK: return %[[DPAS]]
// -----
+// VNNI packing is added automatically by a later XeGPU consumer.
+// For simplicity, only plain data layouts are currently supported.
+
#map = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
#map1 = affine_map<(d0, d1, d2, d3) -> (d2, d1, d3)>
#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
-func.func @dpas_gemm_f16_vnni(%lhs: vector<8x8x2xf16>, %rhs: vector<8x16x2xf16>,
- %acc: vector<8x16xf16>) -> vector<8x16xf16> {
+func.func @negative_vnni_packed(%lhs: vector<8x8x2xf16>, %rhs: vector<8x16x2xf16>,
+ %acc: vector<8x16xf32>) -> vector<8x16xf32> {
%3 = vector.contract
{indexing_maps = [#map, #map1, #map2],
iterator_types = ["parallel", "parallel", "reduction", "reduction"],
kind = #vector.kind<add>} %lhs, %rhs, %acc
- : vector<8x8x2xf16>, vector<8x16x2xf16> into vector<8x16xf16>
- return %3 : vector<8x16xf16>
-}
-
-// CHECK-LABEL: @dpas_gemm_f16_vnni(
-// CHECK-SAME: %[[LHS:.+]]: vector<8x8x2xf16>,
-// CHECK-SAME: %[[RHS:.+]]: vector<8x16x2xf16>,
-// CHECK-SAME: %[[ACC:.+]]: vector<8x16xf16>
-// CHECK: %[[CAST_LHS:.+]] = vector.shape_cast %[[LHS]]
-// CHECK-SAME: vector<8x8x2xf16> to vector<8x16xf16>
-// CHECK: %[[DPAS:.+]] = xegpu.dpas
-// CHECK-SAME: %[[CAST_LHS]], %[[RHS]], %[[ACC]]
-// CHECK-SAME: {{.*}}-> vector<8x16xf16>
-// CHECK: return %[[DPAS]]
-
-// -----
-
-#map = affine_map<(d0, d1, d2) -> (d0, d2)>
-#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
-#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
-func.func @dpas_gemm_mixed_types(%lhs: vector<8x16xi16>, %rhs: vector<16x16xi16>,
- %acc: vector<8x16xf16>) -> vector<8x16xf16> {
- %3 = vector.contract
- {indexing_maps = [#map, #map1, #map2],
- iterator_types = ["parallel", "parallel", "reduction"],
- kind = #vector.kind<add>} %lhs, %rhs, %acc
- : vector<8x16xi16>, vector<16x16xi16> into vector<8x16xf16>
- return %3 : vector<8x16xf16>
+ : vector<8x8x2xf16>, vector<8x16x2xf16> into vector<8x16xf32>
+ return %3 : vector<8x16xf32>
}
-// CHECK-LABEL: @dpas_gemm_mixed_types(
-// CHECK-SAME: %[[LHS:.+]]: vector<8x16xi16>,
-// CHECK-SAME: %[[RHS:.+]]: vector<16x16xi16>,
-// CHECK-SAME: %[[ACC:.+]]: vector<8x16xf16>
-// CHECK: %[[DPAS:.+]] = xegpu.dpas
-// CHECK-SAME: %[[LHS]], %[[RHS]], %[[ACC]]
-// CHECK-SAME: {{.*}}-> vector<8x16xf16>
-// CHECK: return %[[DPAS]]
+// CHECK-LABEL: @negative_vnni_packed(
+// CHECK: vector.contract
// -----
#map = affine_map<(d0, d1, d2) -> (d0, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
-func.func @invalid_combining_type(%lhs: vector<8x16xf16>, %rhs: vector<16x16xf16>,
- %acc: vector<8x16xf16>) -> vector<8x16xf16> {
+func.func @negative_combining_type(%lhs: vector<8x16xf16>, %rhs: vector<16x16xf16>,
+ %acc: vector<8x16xf32>) -> vector<8x16xf32> {
%3 = vector.contract
{indexing_maps = [#map, #map1, #map2],
iterator_types = ["parallel", "parallel", "reduction"],
kind = #vector.kind<mul>} %lhs, %rhs, %acc
- : vector<8x16xf16>, vector<16x16xf16> into vector<8x16xf16>
- return %3 : vector<8x16xf16>
+ : vector<8x16xf16>, vector<16x16xf16> into vector<8x16xf32>
+ return %3 : vector<8x16xf32>
}
-// CHECK-LABEL: @invalid_combining_type(
+// CHECK-LABEL: @negative_combining_type(
// CHECK: vector.contract
// -----
@@ -119,141 +90,105 @@ func.func @invalid_combining_type(%lhs: vector<8x16xf16>, %rhs: vector<16x16xf16
#map = affine_map<(d0, d1, d2) -> (d0, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map2 = affine_map<(d0, d1, d2) -> ()>
-func.func @invalid_accumulator_shape(%lhs: vector<8x16xf16>, %rhs: vector<16x16xf16>,
- %acc: vector<f16>) -> vector<f16> {
+func.func @negative_accumulator_shape(%lhs: vector<8x16xf16>, %rhs: vector<16x16xf16>,
+ %acc: vector<f32>) -> vector<f32> {
%3 = vector.contract
{indexing_maps = [#map, #map1, #map2],
iterator_types = ["reduction", "reduction", "reduction"],
kind = #vector.kind<add>} %lhs, %rhs, %acc
- : vector<8x16xf16>, vector<16x16xf16> into vector<f16>
- return %3 : vector<f16>
+ : vector<8x16xf16>, vector<16x16xf16> into vector<f32>
+ return %3 : vector<f32>
}
-// CHECK-LABEL: @invalid_accumulator_shape(
+// CHECK-LABEL: @negative_accumulator_shape(
// CHECK: vector.contract
// -----
-#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>
-#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2, d4)>
-#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d1, d2)>
-func.func @invalid_high_dim_reduction(%lhs: vector<3x8x8x2xf16>, %rhs: vector<3x8x16x2xf16>,
- %acc: vector<8x16xf16>) -> vector<8x16xf16> {
+#map = affine_map<(d0, d1, d2) -> (d2, d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @negative_gemm_transpose_a(%lhs: vector<16x8xf16>, %rhs: vector<16x16xf16>,
+ %acc: vector<8x16xf32>) -> vector<8x16xf32> {
%3 = vector.contract
{indexing_maps = [#map, #map1, #map2],
- iterator_types = ["reduction", "parallel", "parallel", "reduction", "reduction"],
+ iterator_types = ["parallel", "parallel", "reduction"],
kind = #vector.kind<add>} %lhs, %rhs, %acc
- : vector<3x8x8x2xf16>, vector<3x8x16x2xf16> into vector<8x16xf16>
- return %3 : vector<8x16xf16>
+ : vector<16x8xf16>, vector<16x16xf16> into vector<8x16xf32>
+ return %3 : vector<8x16xf32>
}
-// CHECK-LABEL: @invalid_high_dim_reduction(
+// CHECK-LABEL: @negative_gemm_transpose_a(
// CHECK: vector.contract
// -----
-#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
-#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
-#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
-func.func @invalid_indexing_maps(%lhs: vector<3x8x16xf16>, %rhs: vector<3x16x16xf16>,
- %acc: vector<8x16xf16>) -> vector<8x16xf16> {
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @negative_gemm_transpose_b(%lhs: vector<8x16xf16>, %rhs: vector<16x16xf16>,
+ %acc: vector<8x16xf32>) -> vector<8x16xf32> {
%3 = vector.contract
{indexing_maps = [#map, #map1, #map2],
- iterator_types = ["reduction", "parallel", "parallel", "reduction"],
+ iterator_types = ["parallel", "parallel", "reduction"],
kind = #vector.kind<add>} %lhs, %rhs, %acc
- : vector<3x8x16xf16>, vector<3x16x16xf16> into vector<8x16xf16>
- return %3 : vector<8x16xf16>
+ : vector<8x16xf16>, vector<16x16xf16> into vector<8x16xf32>
+ return %3 : vector<8x16xf32>
}
-// CHECK-LABEL: @invalid_indexing_maps(
+// CHECK-LABEL: @negative_gemm_transpose_b(
// CHECK: vector.contract
// -----
-#map = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
-#map1 = affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>
-#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
-func.func @not_vnni_layout(%lhs: vector<8x8x2xf16>, %rhs: vector<16x8x2xf16>,
- %acc: vector<8x16xf16>) -> vector<8x16xf16> {
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @negative_m_dim_size(%lhs: vector<16x16xf16>, %rhs: vector<16x16xf16>,
+ %acc: vector<16x16xf32>) -> vector<16x16xf32> {
%3 = vector.contract
{indexing_maps = [#map, #map1, #map2],
- iterator_types = ["parallel", "parallel", "reduction", "reduction"],
+ iterator_types = ["parallel", "parallel", "reduction"],
kind = #vector.kind<add>} %lhs, %rhs, %acc
- : vector<8x8x2xf16>, vector<16x8x2xf16> into vector<8x16xf16>
- return %3 : vector<8x16xf16>
+ : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf32>
+ return %3 : vector<16x16xf32>
}
-// CHECK-LABEL: @not_vnni_layout(
+// CHECK-LABEL: @negative_m_dim_size(
// CHECK: vector.contract
// -----
-#map = affine_map<(d0, d1, d2) -> (d2, d0)>
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
-func.func @invalid_gemm_transpose_a(%lhs: vector<8x8xf32>, %rhs: vector<8x16xf32>,
- %acc: vector<8x16xf32>) -> vector<8x16xf32> {
+func.func @negative_n_dim_size(%lhs: vector<8x16xf16>, %rhs: vector<16x8xf16>,
+ %acc: vector<8x8xf32>) -> vector<8x8xf32> {
%3 = vector.contract
{indexing_maps = [#map, #map1, #map2],
iterator_types = ["parallel", "parallel", "reduction"],
kind = #vector.kind<add>} %lhs, %rhs, %acc
- : vector<8x8xf32>, vector<8x16xf32> into vector<8x16xf32>
- return %3 : vector<8x16xf32>
+ : vector<8x16xf16>, vector<16x8xf16> into vector<8x8xf32>
+ return %3 : vector<8x8xf32>
}
-// CHECK-LABEL: @invalid_gemm_transpose_a(
+// CHECK-LABEL: @negative_n_dim_size(
// CHECK: vector.contract
// -----
#map = affine_map<(d0, d1, d2) -> (d0, d2)>
-#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
-func.func @invalid_gemm_transpose_b(%lhs: vector<8x8xf32>, %rhs: vector<16x8xf32>,
+func.func @negative_k_dim_size(%lhs: vector<8x12xf16>, %rhs: vector<12x16xf16>,
%acc: vector<8x16xf32>) -> vector<8x16xf32> {
%3 = vector.contract
{indexing_maps = [#map, #map1, #map2],
iterator_types = ["parallel", "parallel", "reduction"],
kind = #vector.kind<add>} %lhs, %rhs, %acc
- : vector<8x8xf32>, vector<16x8xf32> into vector<8x16xf32>
+ : vector<8x12xf16>, vector<12x16xf16> into vector<8x16xf32>
return %3 : vector<8x16xf32>
}
-// CHECK-LABEL: @invalid_gemm_transpose_b(
-// CHECK: vector.contract
-
-// -----
-
-#map = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
-#map1 = affine_map<(d0, d1, d2, d3) -> (d2, d1, d3)>
-#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
-func.func @invalid_k_dim_size(%lhs: vector<8x4x2xf16>, %rhs: vector<4x16x2xf16>,
- %acc: vector<8x16xf16>) -> vector<8x16xf16> {
- %3 = vector.contract
- {indexing_maps = [#map, #map1, #map2],
- iterator_types = ["parallel", "parallel", "reduction", "reduction"],
- kind = #vector.kind<add>} %lhs, %rhs, %acc
- : vector<8x4x2xf16>, vector<4x16x2xf16> into vector<8x16xf16>
- return %3 : vector<8x16xf16>
-}
-
-// CHECK-LABEL: @invalid_k_dim_size(
-// CHECK: vector.contract
-
-// -----
-
-#map = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
-#map1 = affine_map<(d0, d1, d2, d3) -> (d2, d1, d3)>
-#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
-func.func @invalid_vnni_factor(%lhs: vector<8x4x4xf16>, %rhs: vector<4x16x4xf16>,
- %acc: vector<8x16xf16>) -> vector<8x16xf16> {
- %3 = vector.contract
- {indexing_maps = [#map, #map1, #map2],
- iterator_types = ["parallel", "parallel", "reduction", "reduction"],
- kind = #vector.kind<add>} %lhs, %rhs, %acc
- : vector<8x4x4xf16>, vector<4x16x4xf16> into vector<8x16xf16>
- return %3 : vector<8x16xf16>
-}
-
-// CHECK-LABEL: @invalid_vnni_factor(
+// CHECK-LABEL: @negative_k_dim_size(
// CHECK: vector.contract
From fa9dbc3550283cab535a62f1b8e4d3e0e6f369dd Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Tue, 11 Mar 2025 14:25:42 +0100
Subject: [PATCH 3/3] Rename test case
---
mlir/test/Conversion/VectorToXeGPU/contract-to-xegpu.mlir | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/mlir/test/Conversion/VectorToXeGPU/contract-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/contract-to-xegpu.mlir
index 1e41f8aa8fdd3..98c9fac83f9c7 100644
--- a/mlir/test/Conversion/VectorToXeGPU/contract-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/contract-to-xegpu.mlir
@@ -72,7 +72,7 @@ func.func @negative_vnni_packed(%lhs: vector<8x8x2xf16>, %rhs: vector<8x16x2xf16
#map = affine_map<(d0, d1, d2) -> (d0, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
-func.func @negative_combining_type(%lhs: vector<8x16xf16>, %rhs: vector<16x16xf16>,
+func.func @negative_combining_kind(%lhs: vector<8x16xf16>, %rhs: vector<16x16xf16>,
%acc: vector<8x16xf32>) -> vector<8x16xf32> {
%3 = vector.contract
{indexing_maps = [#map, #map1, #map2],
@@ -82,7 +82,7 @@ func.func @negative_combining_type(%lhs: vector<8x16xf16>, %rhs: vector<16x16xf1
return %3 : vector<8x16xf32>
}
-// CHECK-LABEL: @negative_combining_type(
+// CHECK-LABEL: @negative_combining_kind(
// CHECK: vector.contract
// -----