[Mlir-commits] [mlir] [mlir][xegpu] Convert Vector contraction to XeGPU (PR #122115)
Adam Siemieniuk
llvmlistbot at llvm.org
Fri Jan 10 09:37:59 PST 2025
================
@@ -313,6 +313,91 @@ struct StoreLowering : public OpRewritePattern<vector::StoreOp> {
}
};
+/// Verifies that the indexing maps of `contractOp` describe a plain,
+/// non-transposed matrix multiplication that can be mapped onto DPAS.
+///
+/// The rank of the RHS operand selects which layout is expected:
+///   - rank 2: standard GEMM      (m, k)       x (k, n)       -> (m, n)
+///   - otherwise (the caller only admits rank 3): VNNI layout
+///                                (m, k, vnni) x (k, n, vnni) -> (m, n)
+///
+/// Returns success only if the op's maps equal the expected maps exactly.
+/// NOTE(review): `rewriter` is not used by this check.
+static LogicalResult validateDpasIndexing(PatternRewriter &rewriter,
+                                          vector::ContractionOp contractOp) {
+  MLIRContext *ctx = contractOp.getContext();
+
+  // The explicit nested-ArrayRef parameter type selects the
+  // ArrayRef<ArrayRef<AffineExpr>> overload of inferFromExprList for the
+  // braced expression lists below.
+  using ExprLists = ArrayRef<ArrayRef<AffineExpr>>;
+  auto buildMaps = [ctx](ExprLists exprs) {
+    return AffineMap::inferFromExprList(exprs, ctx);
+  };
+
+  AffineExpr dimM, dimN, dimK, dimVnni;
+  bindDims(ctx, dimM, dimN, dimK, dimVnni);
+
+  SmallVector<AffineMap, 4> actualMaps = contractOp.getIndexingMapsArray();
+  bool isPlainGemm = contractOp.getRhsType().getRank() == 2;
+
+  // Plain GEMM: no transposition of any operand is accepted.
+  // VNNI: the innermost dimension of both inputs is the packing factor.
+  bool layoutMatches =
+      isPlainGemm
+          ? actualMaps ==
+                buildMaps({{dimM, dimK}, {dimK, dimN}, {dimM, dimN}})
+          : actualMaps == buildMaps({{dimM, dimK, dimVnni},
+                                     {dimK, dimN, dimVnni},
+                                     {dimM, dimN}});
+  return success(layoutMatches);
+}
+
+struct ContractionLowering : public OpRewritePattern<vector::ContractionOp> {
+ using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
+ PatternRewriter &rewriter) const override {
+ Location loc = contractOp.getLoc();
+
+ if (contractOp.getKind() != vector::CombiningKind::ADD)
+ return rewriter.notifyMatchFailure(contractOp,
+ "Expects add combining kind");
+
+ TypedValue<Type> acc = contractOp.getAcc();
+ VectorType accType = dyn_cast<VectorType>(acc.getType());
+ if (!accType || accType.getRank() != 2)
+ return rewriter.notifyMatchFailure(contractOp, "Expects acc 2D vector");
+ TypedValue<VectorType> lhs = contractOp.getLhs();
+ VectorType lhsType = lhs.getType();
+ int64_t lhsRank = lhsType.getRank();
+ if (!(lhsRank == 2 || lhsRank == 3))
+ return rewriter.notifyMatchFailure(contractOp,
+ "Expects lhs 2D or 3D vector");
+ TypedValue<VectorType> rhs = contractOp.getRhs();
+ VectorType rhsType = rhs.getType();
+ int64_t rhsRank = rhsType.getRank();
+ if (!(rhsRank == 2 || rhsRank == 3))
+ return rewriter.notifyMatchFailure(contractOp,
+ "Expects rhs 2D or 3D vector");
+ if (lhsRank != rhsRank)
+ return rewriter.notifyMatchFailure(
+ contractOp, "Expects lhs and rhs to be the same rank");
+
+ if (failed(validateDpasIndexing(rewriter, contractOp)))
+ return rewriter.notifyMatchFailure(contractOp, "Invalid indexing maps");
+
+ // 3D shape implies VNNI layout verified by the earlier indexing check.
+ bool isVnni = rhsRank == 3;
+ auto rhsShape = rhsType.getShape();
+ int64_t dimK = isVnni ? rhsShape[0] * rhsShape[2] : rhsShape[0];
+ unsigned elemBitWidth = rhsType.getElementType().getIntOrFloatBitWidth();
+ if (dimK != (8 * 32 / elemBitWidth))
+ return rewriter.notifyMatchFailure(contractOp,
+ "Invalid K-dimension size");
+ if (isVnni && rhsShape[2] != (32 / elemBitWidth))
+ return rewriter.notifyMatchFailure(contractOp, "Invalid VNNI factor");
+
+ if (isVnni) {
----------------
adam-smnk wrote:
The motivation is that in order to have correct affine maps for `vector.contract` I need matrix A in VNNI 3D format as well.
Ideally, the load is still performed in 2D and is followed by a no-op expanding cast, like:
```mlir
%A = vector.load <8x16>
%A_VNNI = vector.shape_cast %A <8x16> to <8x8x2>
%res = vector.contract %A_VNNI, %B_VNNI
```
Since `xegpu.dpas` requires a 2D lhs, I add another no-op collapsing cast here:
```mlir
%A_FLAT = vector.shape_cast %A_VNNI <8x8x2> to <8x16>
%res = xegpu.dpas %A_FLAT, %B_VNNI
```
In the end, these should have no performance impact and are present only to propagate shapes correctly.
https://github.com/llvm/llvm-project/pull/122115
More information about the Mlir-commits
mailing list