[Mlir-commits] [mlir] [mlir][xegpu] Convert Vector contraction to XeGPU (PR #122115)

Adam Siemieniuk llvmlistbot at llvm.org
Tue Jan 14 08:10:42 PST 2025


================
@@ -313,6 +313,91 @@ struct StoreLowering : public OpRewritePattern<vector::StoreOp> {
   }
 };
 
+static LogicalResult validateDpasIndexing(PatternRewriter &rewriter,
+                                          vector::ContractionOp contractOp) {
+  MLIRContext *ctx = contractOp.getContext();
+  SmallVector<AffineMap, 4> maps = contractOp.getIndexingMapsArray();
+
+  // Operand rank defines expected data layout:
+  //   - 2D for standard GEMM
+  //   - 3D for VNNI layout
+  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
+  auto infer = [&](MapList m) { return AffineMap::inferFromExprList(m, ctx); };
+  AffineExpr m, n, k, vnni;
+  bindDims(ctx, m, n, k, vnni);
+
+  if (contractOp.getRhsType().getRank() == 2) {
+    // Require plain GEMM without any transposition.
+    return success(maps == infer({{m, k}, {k, n}, {m, n}}));
+  }
+
+  // Require VNNI layout.
+  return success(maps == infer({{m, k, vnni}, {k, n, vnni}, {m, n}}));
+}
+
+struct ContractionLowering : public OpRewritePattern<vector::ContractionOp> {
+  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
+                                PatternRewriter &rewriter) const override {
+    Location loc = contractOp.getLoc();
+
+    if (contractOp.getKind() != vector::CombiningKind::ADD)
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "Expects add combining kind");
+
+    TypedValue<Type> acc = contractOp.getAcc();
+    VectorType accType = dyn_cast<VectorType>(acc.getType());
+    if (!accType || accType.getRank() != 2)
+      return rewriter.notifyMatchFailure(contractOp, "Expects acc 2D vector");
+    TypedValue<VectorType> lhs = contractOp.getLhs();
+    VectorType lhsType = lhs.getType();
+    int64_t lhsRank = lhsType.getRank();
+    if (!(lhsRank == 2 || lhsRank == 3))
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "Expects lhs 2D or 3D vector");
+    TypedValue<VectorType> rhs = contractOp.getRhs();
+    VectorType rhsType = rhs.getType();
+    int64_t rhsRank = rhsType.getRank();
+    if (!(rhsRank == 2 || rhsRank == 3))
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "Expects rhs 2D or 3D vector");
+    if (lhsRank != rhsRank)
+      return rewriter.notifyMatchFailure(
+          contractOp, "Expects lhs and rhs to be the same rank");
+
+    if (failed(validateDpasIndexing(rewriter, contractOp)))
+      return rewriter.notifyMatchFailure(contractOp, "Invalid indexing maps");
+
+    // 3D shape implies VNNI layout verified by the earlier indexing check.
+    bool isVnni = rhsRank == 3;
+    auto rhsShape = rhsType.getShape();
+    int64_t dimK = isVnni ? rhsShape[0] * rhsShape[2] : rhsShape[0];
+    unsigned elemBitWidth = rhsType.getElementType().getIntOrFloatBitWidth();
+    if (dimK != (8 * 32 / elemBitWidth))
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "Invalid K-dimension size");
+    if (isVnni && rhsShape[2] != (32 / elemBitWidth))
+      return rewriter.notifyMatchFailure(contractOp, "Invalid VNNI factor");
+
+    if (isVnni) {
----------------
adam-smnk wrote:

I don't think I have any use case right now that would require 3D shapes support.
I will constrain it to 2D inputs only for simplicity.

https://github.com/llvm/llvm-project/pull/122115


More information about the Mlir-commits mailing list