[Mlir-commits] [mlir] [mlir][xegpu] Convert Vector contraction to XeGPU (PR #122115)

Adam Siemieniuk llvmlistbot at llvm.org
Tue Mar 11 06:25:57 PDT 2025


https://github.com/adam-smnk updated https://github.com/llvm/llvm-project/pull/122115

From 1e00eb8403ef97d59e3f7e6fcc22f1fe868720da Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Tue, 7 Jan 2025 16:54:52 +0100
Subject: [PATCH 1/3] [mlir][xegpu] Convert Vector contraction to XeGPU

Adds a pattern to lower vector.contract to the XeGPU dpas operation.
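
For illustration, a plain row-major GEMM contraction such as the one below
(taken from the added tests) is rewritten into a single dpas op. The lowered
form is a rough sketch of the expected output, not verbatim compiler output:

  #map = affine_map<(d0, d1, d2) -> (d0, d2)>
  #map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
  #map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
  %0 = vector.contract
    {indexing_maps = [#map, #map1, #map2],
     iterator_types = ["parallel", "parallel", "reduction"],
     kind = #vector.kind<add>} %lhs, %rhs, %acc
    : vector<8x16xf16>, vector<16x16xf16> into vector<8x16xf16>

  // Lowers to (roughly):
  %0 = xegpu.dpas %lhs, %rhs, %acc
    : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf16> -> vector<8x16xf16>

For VNNI-packed (3D) operands, the lhs is first flattened back to a 2D shape
with a vector.shape_cast before it is fed into the dpas op.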
---
 .../VectorToXeGPU/VectorToXeGPU.cpp           |  87 +++++-
 .../VectorToXeGPU/contract-to-xegpu.mlir      | 259 ++++++++++++++++++
 2 files changed, 345 insertions(+), 1 deletion(-)
 create mode 100644 mlir/test/Conversion/VectorToXeGPU/contract-to-xegpu.mlir

diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
index 61b55c57240ce..fd96341768199 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -312,6 +312,91 @@ struct StoreLowering : public OpRewritePattern<vector::StoreOp> {
   }
 };
 
+static LogicalResult validateDpasIndexing(PatternRewriter &rewriter,
+                                          vector::ContractionOp contractOp) {
+  MLIRContext *ctx = contractOp.getContext();
+  SmallVector<AffineMap, 4> maps = contractOp.getIndexingMapsArray();
+
+  // Operand rank defines expected data layout:
+  //   - 2D for standard GEMM
+  //   - 3D for VNNI layout
+  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
+  auto infer = [&](MapList m) { return AffineMap::inferFromExprList(m, ctx); };
+  AffineExpr m, n, k, vnni;
+  bindDims(ctx, m, n, k, vnni);
+
+  if (contractOp.getRhsType().getRank() == 2) {
+    // Require plain GEMM without any transposition.
+    return success(maps == infer({{m, k}, {k, n}, {m, n}}));
+  }
+
+  // Require VNNI layout.
+  return success(maps == infer({{m, k, vnni}, {k, n, vnni}, {m, n}}));
+}
+
+struct ContractionLowering : public OpRewritePattern<vector::ContractionOp> {
+  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
+                                PatternRewriter &rewriter) const override {
+    Location loc = contractOp.getLoc();
+
+    if (contractOp.getKind() != vector::CombiningKind::ADD)
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "Expects add combining kind");
+
+    TypedValue<Type> acc = contractOp.getAcc();
+    VectorType accType = dyn_cast<VectorType>(acc.getType());
+    if (!accType || accType.getRank() != 2)
+      return rewriter.notifyMatchFailure(contractOp, "Expects acc 2D vector");
+    TypedValue<VectorType> lhs = contractOp.getLhs();
+    VectorType lhsType = lhs.getType();
+    int64_t lhsRank = lhsType.getRank();
+    if (!(lhsRank == 2 || lhsRank == 3))
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "Expects lhs 2D or 3D vector");
+    TypedValue<VectorType> rhs = contractOp.getRhs();
+    VectorType rhsType = rhs.getType();
+    int64_t rhsRank = rhsType.getRank();
+    if (!(rhsRank == 2 || rhsRank == 3))
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "Expects rhs 2D or 3D vector");
+    if (lhsRank != rhsRank)
+      return rewriter.notifyMatchFailure(
+          contractOp, "Expects lhs and rhs to be the same rank");
+
+    if (failed(validateDpasIndexing(rewriter, contractOp)))
+      return rewriter.notifyMatchFailure(contractOp, "Invalid indexing maps");
+
+    // 3D shape implies VNNI layout verified by the earlier indexing check.
+    bool isVnni = rhsRank == 3;
+    auto rhsShape = rhsType.getShape();
+    int64_t dimK = isVnni ? rhsShape[0] * rhsShape[2] : rhsShape[0];
+    unsigned elemBitWidth = rhsType.getElementType().getIntOrFloatBitWidth();
+    if (dimK != (8 * 32 / elemBitWidth))
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "Invalid K-dimension size");
+    if (isVnni && rhsShape[2] != (32 / elemBitWidth))
+      return rewriter.notifyMatchFailure(contractOp, "Invalid VNNI factor");
+
+    if (isVnni) {
+      // Collapse contract lhs VNNI factor back into K-dim as dpas op expects
+      // flat 2D shape for its lhs operand.
+      auto lhsShape = lhsType.getShape();
+      auto lhsFlatType = VectorType::get(
+          {lhsShape[0], lhsShape[1] * lhsShape[2]}, lhsType.getElementType());
+      lhs = rewriter.create<vector::ShapeCastOp>(loc, lhsFlatType, lhs)
+                .getResult();
+    }
+
+    auto dpasOp = rewriter.create<xegpu::DpasOp>(
+        loc, contractOp.getResultType(), lhs, rhs, acc);
+    rewriter.replaceOp(contractOp, dpasOp);
+
+    return success();
+  }
+};
+
 struct ConvertVectorToXeGPUPass
     : public impl::ConvertVectorToXeGPUBase<ConvertVectorToXeGPUPass> {
   void runOnOperation() override {
@@ -327,5 +412,5 @@ struct ConvertVectorToXeGPUPass
 void mlir::populateVectorToXeGPUConversionPatterns(
     RewritePatternSet &patterns) {
   patterns.add<TransferReadLowering, TransferWriteLowering, LoadLowering,
-               StoreLowering>(patterns.getContext());
+               StoreLowering, ContractionLowering>(patterns.getContext());
 }
diff --git a/mlir/test/Conversion/VectorToXeGPU/contract-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/contract-to-xegpu.mlir
new file mode 100644
index 0000000000000..c470422e5ac76
--- /dev/null
+++ b/mlir/test/Conversion/VectorToXeGPU/contract-to-xegpu.mlir
@@ -0,0 +1,259 @@
+// RUN: mlir-opt %s -convert-vector-to-xegpu -split-input-file | FileCheck %s
+
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @dpas_gemm_f32(%lhs: vector<8x8xf32>, %rhs: vector<8x16xf32>,
+    %acc: vector<8x16xf32>) -> vector<8x16xf32> {
+  %3 = vector.contract
+    {indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>} %lhs, %rhs, %acc
+    : vector<8x8xf32>, vector<8x16xf32> into vector<8x16xf32>
+  return %3 : vector<8x16xf32>
+}
+
+// CHECK-LABEL: @dpas_gemm_f32(
+// CHECK-SAME:  %[[LHS:.+]]: vector<8x8xf32>,
+// CHECK-SAME:  %[[RHS:.+]]: vector<8x16xf32>,
+// CHECK-SAME:  %[[ACC:.+]]: vector<8x16xf32>
+// CHECK:       %[[DPAS:.+]] = xegpu.dpas
+// CHECK-SAME:    %[[LHS]], %[[RHS]], %[[ACC]]
+// CHECK-SAME:    {{.*}}-> vector<8x16xf32>
+// CHECK:       return %[[DPAS]]
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @dpas_gemm_f16(%lhs: vector<8x16xf16>, %rhs: vector<16x16xf16>,
+    %acc: vector<8x16xf16>) -> vector<8x16xf16> {
+  %3 = vector.contract
+    {indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>} %lhs, %rhs, %acc
+    : vector<8x16xf16>, vector<16x16xf16> into vector<8x16xf16>
+  return %3 : vector<8x16xf16>
+}
+
+// CHECK-LABEL: @dpas_gemm_f16(
+// CHECK-SAME:  %[[LHS:.+]]: vector<8x16xf16>,
+// CHECK-SAME:  %[[RHS:.+]]: vector<16x16xf16>,
+// CHECK-SAME:  %[[ACC:.+]]: vector<8x16xf16>
+// CHECK:       %[[DPAS:.+]] = xegpu.dpas
+// CHECK-SAME:    %[[LHS]], %[[RHS]], %[[ACC]]
+// CHECK-SAME:    {{.*}}-> vector<8x16xf16>
+// CHECK:       return %[[DPAS]]
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d2, d1, d3)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
+func.func @dpas_gemm_f16_vnni(%lhs: vector<8x8x2xf16>, %rhs: vector<8x16x2xf16>,
+    %acc: vector<8x16xf16>) -> vector<8x16xf16> {
+  %3 = vector.contract
+    {indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction", "reduction"],
+    kind = #vector.kind<add>} %lhs, %rhs, %acc
+    : vector<8x8x2xf16>, vector<8x16x2xf16> into vector<8x16xf16>
+  return %3 : vector<8x16xf16>
+}
+
+// CHECK-LABEL: @dpas_gemm_f16_vnni(
+// CHECK-SAME:  %[[LHS:.+]]: vector<8x8x2xf16>,
+// CHECK-SAME:  %[[RHS:.+]]: vector<8x16x2xf16>,
+// CHECK-SAME:  %[[ACC:.+]]: vector<8x16xf16>
+// CHECK:       %[[CAST_LHS:.+]] = vector.shape_cast %[[LHS]]
+// CHECK-SAME:    vector<8x8x2xf16> to vector<8x16xf16>
+// CHECK:       %[[DPAS:.+]] = xegpu.dpas
+// CHECK-SAME:    %[[CAST_LHS]], %[[RHS]], %[[ACC]]
+// CHECK-SAME:    {{.*}}-> vector<8x16xf16>
+// CHECK:       return %[[DPAS]]
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @dpas_gemm_mixed_types(%lhs: vector<8x16xi16>, %rhs: vector<16x16xi16>,
+    %acc: vector<8x16xf16>) -> vector<8x16xf16> {
+  %3 = vector.contract
+    {indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>} %lhs, %rhs, %acc
+    : vector<8x16xi16>, vector<16x16xi16> into vector<8x16xf16>
+  return %3 : vector<8x16xf16>
+}
+
+// CHECK-LABEL: @dpas_gemm_mixed_types(
+// CHECK-SAME:  %[[LHS:.+]]: vector<8x16xi16>,
+// CHECK-SAME:  %[[RHS:.+]]: vector<16x16xi16>,
+// CHECK-SAME:  %[[ACC:.+]]: vector<8x16xf16>
+// CHECK:       %[[DPAS:.+]] = xegpu.dpas
+// CHECK-SAME:    %[[LHS]], %[[RHS]], %[[ACC]]
+// CHECK-SAME:    {{.*}}-> vector<8x16xf16>
+// CHECK:       return %[[DPAS]]
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @invalid_combining_type(%lhs: vector<8x16xf16>, %rhs: vector<16x16xf16>,
+    %acc: vector<8x16xf16>) -> vector<8x16xf16> {
+  %3 = vector.contract
+    {indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction"],
+    kind = #vector.kind<mul>} %lhs, %rhs, %acc
+    : vector<8x16xf16>, vector<16x16xf16> into vector<8x16xf16>
+  return %3 : vector<8x16xf16>
+}
+
+// CHECK-LABEL: @invalid_combining_type(
+// CHECK:       vector.contract
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> ()>
+func.func @invalid_accumulator_shape(%lhs: vector<8x16xf16>, %rhs: vector<16x16xf16>,
+    %acc: vector<f16>) -> vector<f16> {
+  %3 = vector.contract
+    {indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["reduction", "reduction", "reduction"],
+    kind = #vector.kind<add>} %lhs, %rhs, %acc
+    : vector<8x16xf16>, vector<16x16xf16> into vector<f16>
+  return %3 : vector<f16>
+}
+
+// CHECK-LABEL: @invalid_accumulator_shape(
+// CHECK:       vector.contract
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>
+#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2, d4)>
+#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d1, d2)>
+func.func @invalid_high_dim_reduction(%lhs: vector<3x8x8x2xf16>, %rhs: vector<3x8x16x2xf16>,
+    %acc: vector<8x16xf16>) -> vector<8x16xf16> {
+  %3 = vector.contract
+    {indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["reduction", "parallel", "parallel", "reduction", "reduction"],
+    kind = #vector.kind<add>} %lhs, %rhs, %acc
+    : vector<3x8x8x2xf16>, vector<3x8x16x2xf16> into vector<8x16xf16>
+  return %3 : vector<8x16xf16>
+}
+
+// CHECK-LABEL: @invalid_high_dim_reduction(
+// CHECK:       vector.contract
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+func.func @invalid_indexing_maps(%lhs: vector<3x8x16xf16>, %rhs: vector<3x16x16xf16>,
+    %acc: vector<8x16xf16>) -> vector<8x16xf16> {
+  %3 = vector.contract
+    {indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["reduction", "parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>} %lhs, %rhs, %acc
+    : vector<3x8x16xf16>, vector<3x16x16xf16> into vector<8x16xf16>
+  return %3 : vector<8x16xf16>
+}
+
+// CHECK-LABEL: @invalid_indexing_maps(
+// CHECK:       vector.contract
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
+func.func @not_vnni_layout(%lhs: vector<8x8x2xf16>, %rhs: vector<16x8x2xf16>,
+    %acc: vector<8x16xf16>) -> vector<8x16xf16> {
+  %3 = vector.contract
+    {indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction", "reduction"],
+    kind = #vector.kind<add>} %lhs, %rhs, %acc
+    : vector<8x8x2xf16>, vector<16x8x2xf16> into vector<8x16xf16>
+  return %3 : vector<8x16xf16>
+}
+
+// CHECK-LABEL: @not_vnni_layout(
+// CHECK:       vector.contract
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d2, d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @invalid_gemm_transpose_a(%lhs: vector<8x8xf32>, %rhs: vector<8x16xf32>,
+    %acc: vector<8x16xf32>) -> vector<8x16xf32> {
+  %3 = vector.contract
+    {indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>} %lhs, %rhs, %acc
+    : vector<8x8xf32>, vector<8x16xf32> into vector<8x16xf32>
+  return %3 : vector<8x16xf32>
+}
+
+// CHECK-LABEL: @invalid_gemm_transpose_a(
+// CHECK:       vector.contract
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @invalid_gemm_transpose_b(%lhs: vector<8x8xf32>, %rhs: vector<16x8xf32>,
+    %acc: vector<8x16xf32>) -> vector<8x16xf32> {
+  %3 = vector.contract
+    {indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>} %lhs, %rhs, %acc
+    : vector<8x8xf32>, vector<16x8xf32> into vector<8x16xf32>
+  return %3 : vector<8x16xf32>
+}
+
+// CHECK-LABEL: @invalid_gemm_transpose_b(
+// CHECK:       vector.contract
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d2, d1, d3)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
+func.func @invalid_k_dim_size(%lhs: vector<8x4x2xf16>, %rhs: vector<4x16x2xf16>,
+    %acc: vector<8x16xf16>) -> vector<8x16xf16> {
+  %3 = vector.contract
+    {indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction", "reduction"],
+    kind = #vector.kind<add>} %lhs, %rhs, %acc
+    : vector<8x4x2xf16>, vector<4x16x2xf16> into vector<8x16xf16>
+  return %3 : vector<8x16xf16>
+}
+
+// CHECK-LABEL: @invalid_k_dim_size(
+// CHECK:       vector.contract
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d2, d1, d3)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
+func.func @invalid_vnni_factor(%lhs: vector<8x4x4xf16>, %rhs: vector<4x16x4xf16>,
+    %acc: vector<8x16xf16>) -> vector<8x16xf16> {
+  %3 = vector.contract
+    {indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction", "reduction"],
+    kind = #vector.kind<add>} %lhs, %rhs, %acc
+    : vector<8x4x4xf16>, vector<4x16x4xf16> into vector<8x16xf16>
+  return %3 : vector<8x16xf16>
+}
+
+// CHECK-LABEL: @invalid_vnni_factor(
+// CHECK:       vector.contract

From 1a1ec4ad2fd37f550adc2718298c453749435cb2 Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Tue, 11 Mar 2025 14:15:52 +0100
Subject: [PATCH 2/3] Simplify validation + update tests
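
Replaces the hand-rolled affine map validation with the upstream
isRowMajorMatmul helper and restricts the pattern to plain 2D operands;
VNNI-packed (3D) contractions are no longer matched and are left to later
XeGPU lowering. Operand shapes are additionally guarded by a hardcoded
M == 8, N == 16, K % 8 == 0 check until validation becomes target aware.
For reference, the kind of VNNI-packed contraction that is now rejected
(taken from the updated tests):

  #map = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
  #map1 = affine_map<(d0, d1, d2, d3) -> (d2, d1, d3)>
  #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
  %0 = vector.contract
    {indexing_maps = [#map, #map1, #map2],
     iterator_types = ["parallel", "parallel", "reduction", "reduction"],
     kind = #vector.kind<add>} %lhs, %rhs, %acc
    : vector<8x8x2xf16>, vector<8x16x2xf16> into vector<8x16xf32>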

---
 .../VectorToXeGPU/VectorToXeGPU.cpp           |  71 ++----
 .../VectorToXeGPU/contract-to-xegpu.mlir      | 205 ++++++------------
 2 files changed, 86 insertions(+), 190 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
index fd96341768199..f1a37d7d8336c 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -14,6 +14,7 @@
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
 #include "mlir/Pass/Pass.h"
@@ -312,28 +313,6 @@ struct StoreLowering : public OpRewritePattern<vector::StoreOp> {
   }
 };
 
-static LogicalResult validateDpasIndexing(PatternRewriter &rewriter,
-                                          vector::ContractionOp contractOp) {
-  MLIRContext *ctx = contractOp.getContext();
-  SmallVector<AffineMap, 4> maps = contractOp.getIndexingMapsArray();
-
-  // Operand rank defines expected data layout:
-  //   - 2D for standard GEMM
-  //   - 3D for VNNI layout
-  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
-  auto infer = [&](MapList m) { return AffineMap::inferFromExprList(m, ctx); };
-  AffineExpr m, n, k, vnni;
-  bindDims(ctx, m, n, k, vnni);
-
-  if (contractOp.getRhsType().getRank() == 2) {
-    // Require plain GEMM without any transposition.
-    return success(maps == infer({{m, k}, {k, n}, {m, n}}));
-  }
-
-  // Require VNNI layout.
-  return success(maps == infer({{m, k, vnni}, {k, n, vnni}, {m, n}}));
-}
-
 struct ContractionLowering : public OpRewritePattern<vector::ContractionOp> {
   using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
 
@@ -349,48 +328,30 @@ struct ContractionLowering : public OpRewritePattern<vector::ContractionOp> {
     VectorType accType = dyn_cast<VectorType>(acc.getType());
     if (!accType || accType.getRank() != 2)
       return rewriter.notifyMatchFailure(contractOp, "Expects acc 2D vector");
+
+    // Accept only plain 2D data layout.
+    // VNNI packing is left to later lowering.
     TypedValue<VectorType> lhs = contractOp.getLhs();
-    VectorType lhsType = lhs.getType();
-    int64_t lhsRank = lhsType.getRank();
-    if (!(lhsRank == 2 || lhsRank == 3))
-      return rewriter.notifyMatchFailure(contractOp,
-                                         "Expects lhs 2D or 3D vector");
     TypedValue<VectorType> rhs = contractOp.getRhs();
-    VectorType rhsType = rhs.getType();
-    int64_t rhsRank = rhsType.getRank();
-    if (!(rhsRank == 2 || rhsRank == 3))
+    if (lhs.getType().getRank() != 2 || rhs.getType().getRank() != 2)
       return rewriter.notifyMatchFailure(contractOp,
-                                         "Expects rhs 2D or 3D vector");
-    if (lhsRank != rhsRank)
-      return rewriter.notifyMatchFailure(
-          contractOp, "Expects lhs and rhs to be the same rank");
+                                         "Expects lhs and rhs 2D vectors");
 
-    if (failed(validateDpasIndexing(rewriter, contractOp)))
+    if (!isRowMajorMatmul(contractOp.getIndexingMapsAttr()))
       return rewriter.notifyMatchFailure(contractOp, "Invalid indexing maps");
 
-    // 3D shape implies VNNI layout verified by the earlier indexing check.
-    bool isVnni = rhsRank == 3;
-    auto rhsShape = rhsType.getShape();
-    int64_t dimK = isVnni ? rhsShape[0] * rhsShape[2] : rhsShape[0];
-    unsigned elemBitWidth = rhsType.getElementType().getIntOrFloatBitWidth();
-    if (dimK != (8 * 32 / elemBitWidth))
+    // TODO: Update shape validation to be target aware.
+    auto rhsShape = rhs.getType().getShape();
+    auto accShape = accType.getShape();
+    int64_t dimM = accShape[0];
+    int64_t dimN = accShape[1];
+    int64_t dimK = rhsShape[0];
+    if (dimM != 8 || dimN != 16 || dimK % 8 != 0)
       return rewriter.notifyMatchFailure(contractOp,
-                                         "Invalid K-dimension size");
-    if (isVnni && rhsShape[2] != (32 / elemBitWidth))
-      return rewriter.notifyMatchFailure(contractOp, "Invalid VNNI factor");
-
-    if (isVnni) {
-      // Collapse contract lhs VNNI factor back into K-dim as dpas op expects
-      // flat 2D shape for its lhs operand.
-      auto lhsShape = lhsType.getShape();
-      auto lhsFlatType = VectorType::get(
-          {lhsShape[0], lhsShape[1] * lhsShape[2]}, lhsType.getElementType());
-      lhs = rewriter.create<vector::ShapeCastOp>(loc, lhsFlatType, lhs)
-                .getResult();
-    }
+                                         "Invalid operand dimensions");
 
     auto dpasOp = rewriter.create<xegpu::DpasOp>(
-        loc, contractOp.getResultType(), lhs, rhs, acc);
+        loc, TypeRange{contractOp.getResultType()}, ValueRange{lhs, rhs, acc});
     rewriter.replaceOp(contractOp, dpasOp);
 
     return success();
diff --git a/mlir/test/Conversion/VectorToXeGPU/contract-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/contract-to-xegpu.mlir
index c470422e5ac76..1e41f8aa8fdd3 100644
--- a/mlir/test/Conversion/VectorToXeGPU/contract-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/contract-to-xegpu.mlir
@@ -3,19 +3,19 @@
 #map = affine_map<(d0, d1, d2) -> (d0, d2)>
 #map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
 #map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
-func.func @dpas_gemm_f32(%lhs: vector<8x8xf32>, %rhs: vector<8x16xf32>,
+func.func @dpas_gemm_f16(%lhs: vector<8x16xf16>, %rhs: vector<16x16xf16>,
     %acc: vector<8x16xf32>) -> vector<8x16xf32> {
   %3 = vector.contract
     {indexing_maps = [#map, #map1, #map2],
     iterator_types = ["parallel", "parallel", "reduction"],
     kind = #vector.kind<add>} %lhs, %rhs, %acc
-    : vector<8x8xf32>, vector<8x16xf32> into vector<8x16xf32>
+    : vector<8x16xf16>, vector<16x16xf16> into vector<8x16xf32>
   return %3 : vector<8x16xf32>
 }
 
-// CHECK-LABEL: @dpas_gemm_f32(
-// CHECK-SAME:  %[[LHS:.+]]: vector<8x8xf32>,
-// CHECK-SAME:  %[[RHS:.+]]: vector<8x16xf32>,
+// CHECK-LABEL: @dpas_gemm_f16(
+// CHECK-SAME:  %[[LHS:.+]]: vector<8x16xf16>,
+// CHECK-SAME:  %[[RHS:.+]]: vector<16x16xf16>,
 // CHECK-SAME:  %[[ACC:.+]]: vector<8x16xf32>
 // CHECK:       %[[DPAS:.+]] = xegpu.dpas
 // CHECK-SAME:    %[[LHS]], %[[RHS]], %[[ACC]]
@@ -27,91 +27,62 @@ func.func @dpas_gemm_f32(%lhs: vector<8x8xf32>, %rhs: vector<8x16xf32>,
 #map = affine_map<(d0, d1, d2) -> (d0, d2)>
 #map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
 #map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
-func.func @dpas_gemm_f16(%lhs: vector<8x16xf16>, %rhs: vector<16x16xf16>,
-    %acc: vector<8x16xf16>) -> vector<8x16xf16> {
+func.func @dpas_gemm_i8(%lhs: vector<8x32xi8>, %rhs: vector<32x16xi8>,
+    %acc: vector<8x16xi32>) -> vector<8x16xi32> {
   %3 = vector.contract
     {indexing_maps = [#map, #map1, #map2],
     iterator_types = ["parallel", "parallel", "reduction"],
     kind = #vector.kind<add>} %lhs, %rhs, %acc
-    : vector<8x16xf16>, vector<16x16xf16> into vector<8x16xf16>
-  return %3 : vector<8x16xf16>
+    : vector<8x32xi8>, vector<32x16xi8> into vector<8x16xi32>
+  return %3 : vector<8x16xi32>
 }
 
-// CHECK-LABEL: @dpas_gemm_f16(
-// CHECK-SAME:  %[[LHS:.+]]: vector<8x16xf16>,
-// CHECK-SAME:  %[[RHS:.+]]: vector<16x16xf16>,
-// CHECK-SAME:  %[[ACC:.+]]: vector<8x16xf16>
+// CHECK-LABEL: @dpas_gemm_i8(
+// CHECK-SAME:  %[[LHS:.+]]: vector<8x32xi8>,
+// CHECK-SAME:  %[[RHS:.+]]: vector<32x16xi8>,
+// CHECK-SAME:  %[[ACC:.+]]: vector<8x16xi32>
 // CHECK:       %[[DPAS:.+]] = xegpu.dpas
 // CHECK-SAME:    %[[LHS]], %[[RHS]], %[[ACC]]
-// CHECK-SAME:    {{.*}}-> vector<8x16xf16>
+// CHECK-SAME:    {{.*}}-> vector<8x16xi32>
 // CHECK:       return %[[DPAS]]
 
 // -----
 
+// VNNI packing is added automatically by a later XeGPU consumer.
+// For simplicity, only plain data layouts are currently supported.
+
 #map = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
 #map1 = affine_map<(d0, d1, d2, d3) -> (d2, d1, d3)>
 #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
-func.func @dpas_gemm_f16_vnni(%lhs: vector<8x8x2xf16>, %rhs: vector<8x16x2xf16>,
-    %acc: vector<8x16xf16>) -> vector<8x16xf16> {
+func.func @negative_vnni_packed(%lhs: vector<8x8x2xf16>, %rhs: vector<8x16x2xf16>,
+    %acc: vector<8x16xf32>) -> vector<8x16xf32> {
   %3 = vector.contract
     {indexing_maps = [#map, #map1, #map2],
     iterator_types = ["parallel", "parallel", "reduction", "reduction"],
     kind = #vector.kind<add>} %lhs, %rhs, %acc
-    : vector<8x8x2xf16>, vector<8x16x2xf16> into vector<8x16xf16>
-  return %3 : vector<8x16xf16>
-}
-
-// CHECK-LABEL: @dpas_gemm_f16_vnni(
-// CHECK-SAME:  %[[LHS:.+]]: vector<8x8x2xf16>,
-// CHECK-SAME:  %[[RHS:.+]]: vector<8x16x2xf16>,
-// CHECK-SAME:  %[[ACC:.+]]: vector<8x16xf16>
-// CHECK:       %[[CAST_LHS:.+]] = vector.shape_cast %[[LHS]]
-// CHECK-SAME:    vector<8x8x2xf16> to vector<8x16xf16>
-// CHECK:       %[[DPAS:.+]] = xegpu.dpas
-// CHECK-SAME:    %[[CAST_LHS]], %[[RHS]], %[[ACC]]
-// CHECK-SAME:    {{.*}}-> vector<8x16xf16>
-// CHECK:       return %[[DPAS]]
-
-// -----
-
-#map = affine_map<(d0, d1, d2) -> (d0, d2)>
-#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
-#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
-func.func @dpas_gemm_mixed_types(%lhs: vector<8x16xi16>, %rhs: vector<16x16xi16>,
-    %acc: vector<8x16xf16>) -> vector<8x16xf16> {
-  %3 = vector.contract
-    {indexing_maps = [#map, #map1, #map2],
-    iterator_types = ["parallel", "parallel", "reduction"],
-    kind = #vector.kind<add>} %lhs, %rhs, %acc
-    : vector<8x16xi16>, vector<16x16xi16> into vector<8x16xf16>
-  return %3 : vector<8x16xf16>
+    : vector<8x8x2xf16>, vector<8x16x2xf16> into vector<8x16xf32>
+  return %3 : vector<8x16xf32>
 }
 
-// CHECK-LABEL: @dpas_gemm_mixed_types(
-// CHECK-SAME:  %[[LHS:.+]]: vector<8x16xi16>,
-// CHECK-SAME:  %[[RHS:.+]]: vector<16x16xi16>,
-// CHECK-SAME:  %[[ACC:.+]]: vector<8x16xf16>
-// CHECK:       %[[DPAS:.+]] = xegpu.dpas
-// CHECK-SAME:    %[[LHS]], %[[RHS]], %[[ACC]]
-// CHECK-SAME:    {{.*}}-> vector<8x16xf16>
-// CHECK:       return %[[DPAS]]
+// CHECK-LABEL: @negative_vnni_packed(
+// CHECK: vector.contract
 
 // -----
 
 #map = affine_map<(d0, d1, d2) -> (d0, d2)>
 #map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
 #map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
-func.func @invalid_combining_type(%lhs: vector<8x16xf16>, %rhs: vector<16x16xf16>,
-    %acc: vector<8x16xf16>) -> vector<8x16xf16> {
+func.func @negative_combining_type(%lhs: vector<8x16xf16>, %rhs: vector<16x16xf16>,
+    %acc: vector<8x16xf32>) -> vector<8x16xf32> {
   %3 = vector.contract
     {indexing_maps = [#map, #map1, #map2],
     iterator_types = ["parallel", "parallel", "reduction"],
     kind = #vector.kind<mul>} %lhs, %rhs, %acc
-    : vector<8x16xf16>, vector<16x16xf16> into vector<8x16xf16>
-  return %3 : vector<8x16xf16>
+    : vector<8x16xf16>, vector<16x16xf16> into vector<8x16xf32>
+  return %3 : vector<8x16xf32>
 }
 
-// CHECK-LABEL: @invalid_combining_type(
+// CHECK-LABEL: @negative_combining_type(
 // CHECK:       vector.contract
 
 // -----
@@ -119,141 +90,105 @@ func.func @invalid_combining_type(%lhs: vector<8x16xf16>, %rhs: vector<16x16xf16
 #map = affine_map<(d0, d1, d2) -> (d0, d2)>
 #map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
 #map2 = affine_map<(d0, d1, d2) -> ()>
-func.func @invalid_accumulator_shape(%lhs: vector<8x16xf16>, %rhs: vector<16x16xf16>,
-    %acc: vector<f16>) -> vector<f16> {
+func.func @negative_accumulator_shape(%lhs: vector<8x16xf16>, %rhs: vector<16x16xf16>,
+    %acc: vector<f32>) -> vector<f32> {
   %3 = vector.contract
     {indexing_maps = [#map, #map1, #map2],
     iterator_types = ["reduction", "reduction", "reduction"],
     kind = #vector.kind<add>} %lhs, %rhs, %acc
-    : vector<8x16xf16>, vector<16x16xf16> into vector<f16>
-  return %3 : vector<f16>
+    : vector<8x16xf16>, vector<16x16xf16> into vector<f32>
+  return %3 : vector<f32>
 }
 
-// CHECK-LABEL: @invalid_accumulator_shape(
+// CHECK-LABEL: @negative_accumulator_shape(
 // CHECK:       vector.contract
 
 // -----
 
-#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>
-#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2, d4)>
-#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d1, d2)>
-func.func @invalid_high_dim_reduction(%lhs: vector<3x8x8x2xf16>, %rhs: vector<3x8x16x2xf16>,
-    %acc: vector<8x16xf16>) -> vector<8x16xf16> {
+#map = affine_map<(d0, d1, d2) -> (d2, d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @negative_gemm_transpose_a(%lhs: vector<16x8xf16>, %rhs: vector<16x16xf16>,
+    %acc: vector<8x16xf32>) -> vector<8x16xf32> {
   %3 = vector.contract
     {indexing_maps = [#map, #map1, #map2],
-    iterator_types = ["reduction", "parallel", "parallel", "reduction", "reduction"],
+    iterator_types = ["parallel", "parallel", "reduction"],
     kind = #vector.kind<add>} %lhs, %rhs, %acc
-    : vector<3x8x8x2xf16>, vector<3x8x16x2xf16> into vector<8x16xf16>
-  return %3 : vector<8x16xf16>
+    : vector<16x8xf16>, vector<16x16xf16> into vector<8x16xf32>
+  return %3 : vector<8x16xf32>
 }
 
-// CHECK-LABEL: @invalid_high_dim_reduction(
+// CHECK-LABEL: @negative_gemm_transpose_a(
 // CHECK:       vector.contract
 
 // -----
 
-#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
-#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
-#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
-func.func @invalid_indexing_maps(%lhs: vector<3x8x16xf16>, %rhs: vector<3x16x16xf16>,
-    %acc: vector<8x16xf16>) -> vector<8x16xf16> {
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @negative_gemm_transpose_b(%lhs: vector<8x16xf16>, %rhs: vector<16x16xf16>,
+    %acc: vector<8x16xf32>) -> vector<8x16xf32> {
   %3 = vector.contract
     {indexing_maps = [#map, #map1, #map2],
-    iterator_types = ["reduction", "parallel", "parallel", "reduction"],
+    iterator_types = ["parallel", "parallel", "reduction"],
     kind = #vector.kind<add>} %lhs, %rhs, %acc
-    : vector<3x8x16xf16>, vector<3x16x16xf16> into vector<8x16xf16>
-  return %3 : vector<8x16xf16>
+    : vector<8x16xf16>, vector<16x16xf16> into vector<8x16xf32>
+  return %3 : vector<8x16xf32>
 }
 
-// CHECK-LABEL: @invalid_indexing_maps(
+// CHECK-LABEL: @negative_gemm_transpose_b(
 // CHECK:       vector.contract
 
 // -----
 
-#map = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
-#map1 = affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>
-#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
-func.func @not_vnni_layout(%lhs: vector<8x8x2xf16>, %rhs: vector<16x8x2xf16>,
-    %acc: vector<8x16xf16>) -> vector<8x16xf16> {
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @negative_m_dim_size(%lhs: vector<16x16xf16>, %rhs: vector<16x16xf16>,
+    %acc: vector<16x16xf32>) -> vector<16x16xf32> {
   %3 = vector.contract
     {indexing_maps = [#map, #map1, #map2],
-    iterator_types = ["parallel", "parallel", "reduction", "reduction"],
+    iterator_types = ["parallel", "parallel", "reduction"],
     kind = #vector.kind<add>} %lhs, %rhs, %acc
-    : vector<8x8x2xf16>, vector<16x8x2xf16> into vector<8x16xf16>
-  return %3 : vector<8x16xf16>
+    : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf32>
+  return %3 : vector<16x16xf32>
 }
 
-// CHECK-LABEL: @not_vnni_layout(
+// CHECK-LABEL: @negative_m_dim_size(
 // CHECK:       vector.contract
 
 // -----
 
-#map = affine_map<(d0, d1, d2) -> (d2, d0)>
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
 #map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
 #map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
-func.func @invalid_gemm_transpose_a(%lhs: vector<8x8xf32>, %rhs: vector<8x16xf32>,
-    %acc: vector<8x16xf32>) -> vector<8x16xf32> {
+func.func @negative_n_dim_size(%lhs: vector<8x16xf16>, %rhs: vector<16x8xf16>,
+    %acc: vector<8x8xf32>) -> vector<8x8xf32> {
   %3 = vector.contract
     {indexing_maps = [#map, #map1, #map2],
     iterator_types = ["parallel", "parallel", "reduction"],
     kind = #vector.kind<add>} %lhs, %rhs, %acc
-    : vector<8x8xf32>, vector<8x16xf32> into vector<8x16xf32>
-  return %3 : vector<8x16xf32>
+    : vector<8x16xf16>, vector<16x8xf16> into vector<8x8xf32>
+  return %3 : vector<8x8xf32>
 }
 
-// CHECK-LABEL: @invalid_gemm_transpose_a(
+// CHECK-LABEL: @negative_n_dim_size(
 // CHECK:       vector.contract
 
 // -----
 
 #map = affine_map<(d0, d1, d2) -> (d0, d2)>
-#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
 #map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
-func.func @invalid_gemm_transpose_b(%lhs: vector<8x8xf32>, %rhs: vector<16x8xf32>,
+func.func @negative_k_dim_size(%lhs: vector<8x12xf16>, %rhs: vector<12x16xf16>,
     %acc: vector<8x16xf32>) -> vector<8x16xf32> {
   %3 = vector.contract
     {indexing_maps = [#map, #map1, #map2],
     iterator_types = ["parallel", "parallel", "reduction"],
     kind = #vector.kind<add>} %lhs, %rhs, %acc
-    : vector<8x8xf32>, vector<16x8xf32> into vector<8x16xf32>
+    : vector<8x12xf16>, vector<12x16xf16> into vector<8x16xf32>
   return %3 : vector<8x16xf32>
 }
 
-// CHECK-LABEL: @invalid_gemm_transpose_b(
-// CHECK:       vector.contract
-
-// -----
-
-#map = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
-#map1 = affine_map<(d0, d1, d2, d3) -> (d2, d1, d3)>
-#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
-func.func @invalid_k_dim_size(%lhs: vector<8x4x2xf16>, %rhs: vector<4x16x2xf16>,
-    %acc: vector<8x16xf16>) -> vector<8x16xf16> {
-  %3 = vector.contract
-    {indexing_maps = [#map, #map1, #map2],
-    iterator_types = ["parallel", "parallel", "reduction", "reduction"],
-    kind = #vector.kind<add>} %lhs, %rhs, %acc
-    : vector<8x4x2xf16>, vector<4x16x2xf16> into vector<8x16xf16>
-  return %3 : vector<8x16xf16>
-}
-
-// CHECK-LABEL: @invalid_k_dim_size(
-// CHECK:       vector.contract
-
-// -----
-
-#map = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
-#map1 = affine_map<(d0, d1, d2, d3) -> (d2, d1, d3)>
-#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
-func.func @invalid_vnni_factor(%lhs: vector<8x4x4xf16>, %rhs: vector<4x16x4xf16>,
-    %acc: vector<8x16xf16>) -> vector<8x16xf16> {
-  %3 = vector.contract
-    {indexing_maps = [#map, #map1, #map2],
-    iterator_types = ["parallel", "parallel", "reduction", "reduction"],
-    kind = #vector.kind<add>} %lhs, %rhs, %acc
-    : vector<8x4x4xf16>, vector<4x16x4xf16> into vector<8x16xf16>
-  return %3 : vector<8x16xf16>
-}
-
-// CHECK-LABEL: @invalid_vnni_factor(
+// CHECK-LABEL: @negative_k_dim_size(
 // CHECK:       vector.contract

From fa9dbc3550283cab535a62f1b8e4d3e0e6f369dd Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Tue, 11 Mar 2025 14:25:42 +0100
Subject: [PATCH 3/3] Rename test case

---
 mlir/test/Conversion/VectorToXeGPU/contract-to-xegpu.mlir | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/test/Conversion/VectorToXeGPU/contract-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/contract-to-xegpu.mlir
index 1e41f8aa8fdd3..98c9fac83f9c7 100644
--- a/mlir/test/Conversion/VectorToXeGPU/contract-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/contract-to-xegpu.mlir
@@ -72,7 +72,7 @@ func.func @negative_vnni_packed(%lhs: vector<8x8x2xf16>, %rhs: vector<8x16x2xf16
 #map = affine_map<(d0, d1, d2) -> (d0, d2)>
 #map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
 #map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
-func.func @negative_combining_type(%lhs: vector<8x16xf16>, %rhs: vector<16x16xf16>,
+func.func @negative_combining_kind(%lhs: vector<8x16xf16>, %rhs: vector<16x16xf16>,
     %acc: vector<8x16xf32>) -> vector<8x16xf32> {
   %3 = vector.contract
     {indexing_maps = [#map, #map1, #map2],
@@ -82,7 +82,7 @@ func.func @negative_combining_type(%lhs: vector<8x16xf16>, %rhs: vector<16x16xf1
   return %3 : vector<8x16xf32>
 }
 
-// CHECK-LABEL: @negative_combining_type(
+// CHECK-LABEL: @negative_combining_kind(
 // CHECK:       vector.contract
 
 // -----


