[Mlir-commits] [mlir] [mlir][xegpu] Convert Vector contraction to XeGPU (PR #122115)
Charitha Saumya
llvmlistbot at llvm.org
Tue Mar 11 19:11:58 PDT 2025
================
@@ -312,6 +313,48 @@ struct StoreLowering : public OpRewritePattern<vector::StoreOp> {
}
};
+struct ContractionLowering : public OpRewritePattern<vector::ContractionOp> {
+ using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
+ PatternRewriter &rewriter) const override {
+ Location loc = contractOp.getLoc();
+
+ if (contractOp.getKind() != vector::CombiningKind::ADD)
+ return rewriter.notifyMatchFailure(contractOp,
+ "Expects add combining kind");
+
+ TypedValue<Type> acc = contractOp.getAcc();
+ VectorType accType = dyn_cast<VectorType>(acc.getType());
+ if (!accType || accType.getRank() != 2)
+ return rewriter.notifyMatchFailure(contractOp, "Expects acc 2D vector");
+
+ // Accept only plain 2D data layout.
+ // VNNI packing is applied to DPAS as a separate lowering step.
+ TypedValue<VectorType> lhs = contractOp.getLhs();
+ TypedValue<VectorType> rhs = contractOp.getRhs();
+ if (lhs.getType().getRank() != 2 || rhs.getType().getRank() != 2)
+ return rewriter.notifyMatchFailure(contractOp,
+ "Expects lhs and rhs 2D vectors");
+
+ if (!isRowMajorMatmul(contractOp.getIndexingMapsAttr()))
+ return rewriter.notifyMatchFailure(contractOp, "Invalid indexing maps");
+
+ // TODO: Update shape validation to be target aware.
+ auto accShape = accType.getShape();
+ int64_t dimN = accShape[1];
+ if (dimN != 8 && dimN != 16)
----------------
charithaintc wrote:
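Consider defining a set of allowed N dimensions (e.g. a constant list containing 8 and 16) and checking membership against it, rather than hard-coding the values in the condition.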
https://github.com/llvm/llvm-project/pull/122115
More information about the Mlir-commits mailing list