[Mlir-commits] [mlir] [mlir][xegpu] Convert Vector contraction to XeGPU (PR #122115)

Charitha Saumya llvmlistbot at llvm.org
Tue Mar 11 19:11:58 PDT 2025


================
@@ -312,6 +313,48 @@ struct StoreLowering : public OpRewritePattern<vector::StoreOp> {
   }
 };
 
+struct ContractionLowering : public OpRewritePattern<vector::ContractionOp> {
+  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
+                                PatternRewriter &rewriter) const override {
+    Location loc = contractOp.getLoc();
+
+    if (contractOp.getKind() != vector::CombiningKind::ADD)
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "Expects add combining kind");
+
+    TypedValue<Type> acc = contractOp.getAcc();
+    VectorType accType = dyn_cast<VectorType>(acc.getType());
+    if (!accType || accType.getRank() != 2)
+      return rewriter.notifyMatchFailure(contractOp, "Expects acc 2D vector");
+
+    // Accept only plain 2D data layout.
+    // VNNI packing is applied to DPAS as a separate lowering step.
+    TypedValue<VectorType> lhs = contractOp.getLhs();
+    TypedValue<VectorType> rhs = contractOp.getRhs();
+    if (lhs.getType().getRank() != 2 || rhs.getType().getRank() != 2)
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "Expects lhs and rhs 2D vectors");
+
+    if (!isRowMajorMatmul(contractOp.getIndexingMapsAttr()))
+      return rewriter.notifyMatchFailure(contractOp, "Invalid indexing maps");
+
+    // TODO: Update shape validation to be target aware.
+    auto accShape = accType.getShape();
+    int64_t dimN = accShape[1];
+    if (dimN != 8 && dimN != 16)
----------------
charithaintc wrote:

Consider defining a set of allowed N dimensions (currently 8 and 16) rather than hard-coding them in the condition.

https://github.com/llvm/llvm-project/pull/122115


More information about the Mlir-commits mailing list