[Mlir-commits] [mlir] [MLIR][XeGPU] Matrix load/store subgroup distribution (PR #165008)

Artem Kroviakov llvmlistbot at llvm.org
Tue Oct 28 09:12:59 PDT 2025


================
@@ -906,6 +907,123 @@ struct StoreDistribution final : public gpu::WarpDistributionPattern {
   }
 };
 
+template <class MatrixOp>
+struct MatrixOpDistribution final : public gpu::WarpDistributionPattern {
+  using gpu::WarpDistributionPattern::WarpDistributionPattern;
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    gpu::YieldOp yield = warpOp.getTerminator();
+    Operation *lastNode = yield->getPrevNode();
+    auto matrixOp = dyn_cast_or_null<MatrixOp>(lastNode);
+    if (!matrixOp)
+      return failure();
+    constexpr bool isLoad{std::is_same_v<MatrixOp, xegpu::LoadMatrixOp>};
+    int operandIdx{-1};
+
+    VectorType sgPayloadTy;
+    VectorType warpResultTy;
+    if constexpr (isLoad) {
+      OpOperand *producedByLastLoad = getWarpResult(warpOp, [&](Operation *op) {
+        return isa<xegpu::LoadMatrixOp>(op) && matrixOp == op;
+      });
+      if (!producedByLastLoad)
+        return rewriter.notifyMatchFailure(
+            warpOp, "The last op is not xegpu::LoadMatrixOp");
+      operandIdx = producedByLastLoad->getOperandNumber();
+      sgPayloadTy = dyn_cast<VectorType>(matrixOp.getResult().getType());
+      warpResultTy = cast<VectorType>(warpOp.getResult(operandIdx).getType());
+    } else {
+      sgPayloadTy = dyn_cast<VectorType>(matrixOp.getData().getType());
+    }
+    if (!sgPayloadTy)
+      return rewriter.notifyMatchFailure(
+          matrixOp, "the matrix op payload must be a vector type");
+
+    auto loc = matrixOp.getLoc();
+    auto offsets = matrixOp.getMixedOffsets();
+    if (offsets.empty())
+      return rewriter.notifyMatchFailure(matrixOp,
+                                         "the load op must have offsets");
+    SmallVector<Value> offsetsAsValues =
+        vector::getAsValues(rewriter, matrixOp.getLoc(), offsets);
+
+    auto layout = matrixOp.getLayoutAttr();
+    if (!layout)
+      return rewriter.notifyMatchFailure(
+          matrixOp, "the matrix operation lacks layout attribute");
+
+    FailureOr<VectorType> distPayloadByWarpOpOrFailure =
+        getDistVecTypeBasedOnLaneLayout(layout, sgPayloadTy);
+    if (failed(distPayloadByWarpOpOrFailure))
+      return rewriter.notifyMatchFailure(
+          matrixOp,
+          "The matrix op payload has no layouts, using defaults instead.");
----------------
akroviakov wrote:

Removed this part

https://github.com/llvm/llvm-project/pull/165008


More information about the Mlir-commits mailing list