[Mlir-commits] [mlir] [mlir][xegpu] Convert Vector contraction to XeGPU (PR #122115)

llvmlistbot at llvm.org
Wed Jan 8 06:21:42 PST 2025


llvmbot wrote:


@llvm/pr-subscribers-mlir-gpu

@llvm/pr-subscribers-mlir

Author: Adam Siemieniuk (adam-smnk)

Changes:

Adds a pattern to lower `vector.contract` to the XeGPU `dpas` operation.
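
For illustration, a plain GEMM contraction like the first new test case is rewritten into a single `xegpu.dpas` op. This is a sketch based on the test CHECK lines; the exact printed form of `xegpu.dpas` may differ slightly:

```mlir
// Input: vector.contract with plain (m, k) x (k, n) -> (m, n) indexing maps.
#map = affine_map<(d0, d1, d2) -> (d0, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
%0 = vector.contract
  {indexing_maps = [#map, #map1, #map2],
   iterator_types = ["parallel", "parallel", "reduction"],
   kind = #vector.kind<add>} %lhs, %rhs, %acc
  : vector<8x8xf32>, vector<8x16xf32> into vector<8x16xf32>

// Output: a single DPAS op.
%0 = xegpu.dpas %lhs, %rhs, %acc
  : vector<8x8xf32>, vector<8x16xf32>, vector<8x16xf32> -> vector<8x16xf32>
```

For the VNNI (3D) case, the lhs is first flattened back to a 2D shape with a `vector.shape_cast` before feeding the DPAS op.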

---
Full diff: https://github.com/llvm/llvm-project/pull/122115.diff


2 Files Affected:

- (modified) mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp (+86-1) 
- (added) mlir/test/Conversion/VectorToXeGPU/contract-to-xegpu.mlir (+259) 


``````````diff
diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
index 8041bdf7da19b3..1859f8cc0421e8 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -313,6 +313,91 @@ struct StoreLowering : public OpRewritePattern<vector::StoreOp> {
   }
 };
 
+static LogicalResult validateDpasIndexing(PatternRewriter &rewriter,
+                                          vector::ContractionOp contractOp) {
+  MLIRContext *ctx = contractOp.getContext();
+  SmallVector<AffineMap, 4> maps = contractOp.getIndexingMapsArray();
+
+  // Operand rank defines expected data layout:
+  //   - 2D for standard GEMM
+  //   - 3D for VNNI layout
+  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
+  auto infer = [&](MapList m) { return AffineMap::inferFromExprList(m, ctx); };
+  AffineExpr m, n, k, vnni;
+  bindDims(ctx, m, n, k, vnni);
+
+  if (contractOp.getRhsType().getRank() == 2) {
+    // Require plain GEMM without any transposition.
+    return success(maps == infer({{m, k}, {k, n}, {m, n}}));
+  }
+
+  // Require VNNI layout.
+  return success(maps == infer({{m, k, vnni}, {k, n, vnni}, {m, n}}));
+}
+
+struct ContractionLowering : public OpRewritePattern<vector::ContractionOp> {
+  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
+                                PatternRewriter &rewriter) const override {
+    Location loc = contractOp.getLoc();
+
+    if (contractOp.getKind() != vector::CombiningKind::ADD)
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "Expects add combining kind");
+
+    TypedValue<Type> acc = contractOp.getAcc();
+    VectorType accType = dyn_cast<VectorType>(acc.getType());
+    if (!accType || accType.getRank() != 2)
+      return rewriter.notifyMatchFailure(contractOp, "Expects acc 2D vector");
+    TypedValue<VectorType> lhs = contractOp.getLhs();
+    VectorType lhsType = lhs.getType();
+    int64_t lhsRank = lhsType.getRank();
+    if (!(lhsRank == 2 || lhsRank == 3))
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "Expects lhs 2D or 3D vector");
+    TypedValue<VectorType> rhs = contractOp.getRhs();
+    VectorType rhsType = rhs.getType();
+    int64_t rhsRank = rhsType.getRank();
+    if (!(rhsRank == 2 || rhsRank == 3))
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "Expects rhs 2D or 3D vector");
+    if (lhsRank != rhsRank)
+      return rewriter.notifyMatchFailure(
+          contractOp, "Expects lhs and rhs to be the same rank");
+
+    if (failed(validateDpasIndexing(rewriter, contractOp)))
+      return rewriter.notifyMatchFailure(contractOp, "Invalid indexing maps");
+
+    // 3D shape implies VNNI layout verified by the earlier indexing check.
+    bool isVnni = rhsRank == 3;
+    auto rhsShape = rhsType.getShape();
+    int64_t dimK = isVnni ? rhsShape[0] * rhsShape[2] : rhsShape[0];
+    unsigned elemBitWidth = rhsType.getElementType().getIntOrFloatBitWidth();
+    if (dimK != (8 * 32 / elemBitWidth))
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "Invalid K-dimension size");
+    if (isVnni && rhsShape[2] != (32 / elemBitWidth))
+      return rewriter.notifyMatchFailure(contractOp, "Invalid VNNI factor");
+
+    if (isVnni) {
+      // Collapse contract lhs VNNI factor back into K-dim as dpas op expects
+      // flat 2D shape for its lhs operand.
+      auto lhsShape = lhsType.getShape();
+      auto lhsFlatType = VectorType::get(
+          {lhsShape[0], lhsShape[1] * lhsShape[2]}, lhsType.getElementType());
+      lhs = rewriter.create<vector::ShapeCastOp>(loc, lhsFlatType, lhs)
+                .getResult();
+    }
+
+    auto dpasOp = rewriter.create<xegpu::DpasOp>(
+        loc, contractOp.getResultType(), lhs, rhs, acc);
+    rewriter.replaceOp(contractOp, dpasOp);
+
+    return success();
+  }
+};
+
 struct ConvertVectorToXeGPUPass
     : public impl::ConvertVectorToXeGPUBase<ConvertVectorToXeGPUPass> {
   void runOnOperation() override {
@@ -328,7 +413,7 @@ struct ConvertVectorToXeGPUPass
 void mlir::populateVectorToXeGPUConversionPatterns(
     RewritePatternSet &patterns) {
   patterns.add<TransferReadLowering, TransferWriteLowering, LoadLowering,
-               StoreLowering>(patterns.getContext());
+               StoreLowering, ContractionLowering>(patterns.getContext());
 }
 
 std::unique_ptr<Pass> mlir::createConvertVectorToXeGPUPass() {
diff --git a/mlir/test/Conversion/VectorToXeGPU/contract-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/contract-to-xegpu.mlir
new file mode 100644
index 00000000000000..c470422e5ac763
--- /dev/null
+++ b/mlir/test/Conversion/VectorToXeGPU/contract-to-xegpu.mlir
@@ -0,0 +1,259 @@
+// RUN: mlir-opt %s -convert-vector-to-xegpu -split-input-file | FileCheck %s
+
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @dpas_gemm_f32(%lhs: vector<8x8xf32>, %rhs: vector<8x16xf32>,
+    %acc: vector<8x16xf32>) -> vector<8x16xf32> {
+  %3 = vector.contract
+    {indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>} %lhs, %rhs, %acc
+    : vector<8x8xf32>, vector<8x16xf32> into vector<8x16xf32>
+  return %3 : vector<8x16xf32>
+}
+
+// CHECK-LABEL: @dpas_gemm_f32(
+// CHECK-SAME:  %[[LHS:.+]]: vector<8x8xf32>,
+// CHECK-SAME:  %[[RHS:.+]]: vector<8x16xf32>,
+// CHECK-SAME:  %[[ACC:.+]]: vector<8x16xf32>
+// CHECK:       %[[DPAS:.+]] = xegpu.dpas
+// CHECK-SAME:    %[[LHS]], %[[RHS]], %[[ACC]]
+// CHECK-SAME:    {{.*}}-> vector<8x16xf32>
+// CHECK:       return %[[DPAS]]
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @dpas_gemm_f16(%lhs: vector<8x16xf16>, %rhs: vector<16x16xf16>,
+    %acc: vector<8x16xf16>) -> vector<8x16xf16> {
+  %3 = vector.contract
+    {indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>} %lhs, %rhs, %acc
+    : vector<8x16xf16>, vector<16x16xf16> into vector<8x16xf16>
+  return %3 : vector<8x16xf16>
+}
+
+// CHECK-LABEL: @dpas_gemm_f16(
+// CHECK-SAME:  %[[LHS:.+]]: vector<8x16xf16>,
+// CHECK-SAME:  %[[RHS:.+]]: vector<16x16xf16>,
+// CHECK-SAME:  %[[ACC:.+]]: vector<8x16xf16>
+// CHECK:       %[[DPAS:.+]] = xegpu.dpas
+// CHECK-SAME:    %[[LHS]], %[[RHS]], %[[ACC]]
+// CHECK-SAME:    {{.*}}-> vector<8x16xf16>
+// CHECK:       return %[[DPAS]]
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d2, d1, d3)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
+func.func @dpas_gemm_f16_vnni(%lhs: vector<8x8x2xf16>, %rhs: vector<8x16x2xf16>,
+    %acc: vector<8x16xf16>) -> vector<8x16xf16> {
+  %3 = vector.contract
+    {indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction", "reduction"],
+    kind = #vector.kind<add>} %lhs, %rhs, %acc
+    : vector<8x8x2xf16>, vector<8x16x2xf16> into vector<8x16xf16>
+  return %3 : vector<8x16xf16>
+}
+
+// CHECK-LABEL: @dpas_gemm_f16_vnni(
+// CHECK-SAME:  %[[LHS:.+]]: vector<8x8x2xf16>,
+// CHECK-SAME:  %[[RHS:.+]]: vector<8x16x2xf16>,
+// CHECK-SAME:  %[[ACC:.+]]: vector<8x16xf16>
+// CHECK:       %[[CAST_LHS:.+]] = vector.shape_cast %[[LHS]]
+// CHECK-SAME:    vector<8x8x2xf16> to vector<8x16xf16>
+// CHECK:       %[[DPAS:.+]] = xegpu.dpas
+// CHECK-SAME:    %[[CAST_LHS]], %[[RHS]], %[[ACC]]
+// CHECK-SAME:    {{.*}}-> vector<8x16xf16>
+// CHECK:       return %[[DPAS]]
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @dpas_gemm_mixed_types(%lhs: vector<8x16xi16>, %rhs: vector<16x16xi16>,
+    %acc: vector<8x16xf16>) -> vector<8x16xf16> {
+  %3 = vector.contract
+    {indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>} %lhs, %rhs, %acc
+    : vector<8x16xi16>, vector<16x16xi16> into vector<8x16xf16>
+  return %3 : vector<8x16xf16>
+}
+
+// CHECK-LABEL: @dpas_gemm_mixed_types(
+// CHECK-SAME:  %[[LHS:.+]]: vector<8x16xi16>,
+// CHECK-SAME:  %[[RHS:.+]]: vector<16x16xi16>,
+// CHECK-SAME:  %[[ACC:.+]]: vector<8x16xf16>
+// CHECK:       %[[DPAS:.+]] = xegpu.dpas
+// CHECK-SAME:    %[[LHS]], %[[RHS]], %[[ACC]]
+// CHECK-SAME:    {{.*}}-> vector<8x16xf16>
+// CHECK:       return %[[DPAS]]
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @invalid_combining_type(%lhs: vector<8x16xf16>, %rhs: vector<16x16xf16>,
+    %acc: vector<8x16xf16>) -> vector<8x16xf16> {
+  %3 = vector.contract
+    {indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction"],
+    kind = #vector.kind<mul>} %lhs, %rhs, %acc
+    : vector<8x16xf16>, vector<16x16xf16> into vector<8x16xf16>
+  return %3 : vector<8x16xf16>
+}
+
+// CHECK-LABEL: @invalid_combining_type(
+// CHECK:       vector.contract
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> ()>
+func.func @invalid_accumulator_shape(%lhs: vector<8x16xf16>, %rhs: vector<16x16xf16>,
+    %acc: vector<f16>) -> vector<f16> {
+  %3 = vector.contract
+    {indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["reduction", "reduction", "reduction"],
+    kind = #vector.kind<add>} %lhs, %rhs, %acc
+    : vector<8x16xf16>, vector<16x16xf16> into vector<f16>
+  return %3 : vector<f16>
+}
+
+// CHECK-LABEL: @invalid_accumulator_shape(
+// CHECK:       vector.contract
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>
+#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2, d4)>
+#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d1, d2)>
+func.func @invalid_high_dim_reduction(%lhs: vector<3x8x8x2xf16>, %rhs: vector<3x8x16x2xf16>,
+    %acc: vector<8x16xf16>) -> vector<8x16xf16> {
+  %3 = vector.contract
+    {indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["reduction", "parallel", "parallel", "reduction", "reduction"],
+    kind = #vector.kind<add>} %lhs, %rhs, %acc
+    : vector<3x8x8x2xf16>, vector<3x8x16x2xf16> into vector<8x16xf16>
+  return %3 : vector<8x16xf16>
+}
+
+// CHECK-LABEL: @invalid_high_dim_reduction(
+// CHECK:       vector.contract
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+func.func @invalid_indexing_maps(%lhs: vector<3x8x16xf16>, %rhs: vector<3x16x16xf16>,
+    %acc: vector<8x16xf16>) -> vector<8x16xf16> {
+  %3 = vector.contract
+    {indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["reduction", "parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>} %lhs, %rhs, %acc
+    : vector<3x8x16xf16>, vector<3x16x16xf16> into vector<8x16xf16>
+  return %3 : vector<8x16xf16>
+}
+
+// CHECK-LABEL: @invalid_indexing_maps(
+// CHECK:       vector.contract
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
+func.func @not_vnni_layout(%lhs: vector<8x8x2xf16>, %rhs: vector<16x8x2xf16>,
+    %acc: vector<8x16xf16>) -> vector<8x16xf16> {
+  %3 = vector.contract
+    {indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction", "reduction"],
+    kind = #vector.kind<add>} %lhs, %rhs, %acc
+    : vector<8x8x2xf16>, vector<16x8x2xf16> into vector<8x16xf16>
+  return %3 : vector<8x16xf16>
+}
+
+// CHECK-LABEL: @not_vnni_layout(
+// CHECK:       vector.contract
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d2, d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @invalid_gemm_transpose_a(%lhs: vector<8x8xf32>, %rhs: vector<8x16xf32>,
+    %acc: vector<8x16xf32>) -> vector<8x16xf32> {
+  %3 = vector.contract
+    {indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>} %lhs, %rhs, %acc
+    : vector<8x8xf32>, vector<8x16xf32> into vector<8x16xf32>
+  return %3 : vector<8x16xf32>
+}
+
+// CHECK-LABEL: @invalid_gemm_transpose_a(
+// CHECK:       vector.contract
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @invalid_gemm_transpose_b(%lhs: vector<8x8xf32>, %rhs: vector<16x8xf32>,
+    %acc: vector<8x16xf32>) -> vector<8x16xf32> {
+  %3 = vector.contract
+    {indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>} %lhs, %rhs, %acc
+    : vector<8x8xf32>, vector<16x8xf32> into vector<8x16xf32>
+  return %3 : vector<8x16xf32>
+}
+
+// CHECK-LABEL: @invalid_gemm_transpose_b(
+// CHECK:       vector.contract
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d2, d1, d3)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
+func.func @invalid_k_dim_size(%lhs: vector<8x4x2xf16>, %rhs: vector<4x16x2xf16>,
+    %acc: vector<8x16xf16>) -> vector<8x16xf16> {
+  %3 = vector.contract
+    {indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction", "reduction"],
+    kind = #vector.kind<add>} %lhs, %rhs, %acc
+    : vector<8x4x2xf16>, vector<4x16x2xf16> into vector<8x16xf16>
+  return %3 : vector<8x16xf16>
+}
+
+// CHECK-LABEL: @invalid_k_dim_size(
+// CHECK:       vector.contract
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d2, d1, d3)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
+func.func @invalid_vnni_factor(%lhs: vector<8x4x4xf16>, %rhs: vector<4x16x4xf16>,
+    %acc: vector<8x16xf16>) -> vector<8x16xf16> {
+  %3 = vector.contract
+    {indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction", "reduction"],
+    kind = #vector.kind<add>} %lhs, %rhs, %acc
+    : vector<8x4x4xf16>, vector<4x16x4xf16> into vector<8x16xf16>
+  return %3 : vector<8x16xf16>
+}
+
+// CHECK-LABEL: @invalid_vnni_factor(
+// CHECK:       vector.contract

``````````
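
A note on the shape checks in the pattern: `dimK` must match the DPAS tile depth `8 * 32 / elemBitWidth`, and the innermost VNNI dimension must equal `32 / elemBitWidth`. A quick worked check, with values derived from those formulas (the i8 row is an extrapolation from the formula, not covered by the tests):

```mlir
// dimK = 8 * 32 / elemBitWidth, vnniFactor = 32 / elemBitWidth
//   f32 (32-bit): dimK = 8,  no VNNI packing; rhs is vector<8x16xf32>
//   f16 (16-bit): dimK = 16, vnniFactor = 2;  rhs is vector<8x16x2xf16>
//                 (K = 16 split as 8 outer elements x 2 packed)
//   i8  ( 8-bit): dimK = 32, vnniFactor = 4  (extrapolated)
```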



https://github.com/llvm/llvm-project/pull/122115

