[Mlir-commits] [mlir] [MLIR][XeGPU] Add distribution pattern for xegpu.load & store for sg to wi pass (PR #181917)

Jianhui Li llvmlistbot at llvm.org
Thu Feb 26 11:52:36 PST 2026


================
@@ -522,6 +587,73 @@ struct LowerVectorMultiReductionPattern
   }
 };
 
+/// Distributes a subgroup-level StoreScatter (xegpu.store) op to
+/// workitem-level.
+struct SgToWiStoreScatter : public OpConversionPattern<xegpu::StoreScatterOp> {
+  using OpConversionPattern<xegpu::StoreScatterOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(xegpu::StoreScatterOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    xegpu::DistributeLayoutAttr layout = op.getAnchorLayout();
+    if (!layout)
+      return failure();
+
+    VectorType valueTy = op.getValueType();
+    if (!valueTy)
+      return failure();
+
+    // Check that all leading dimensions are unit dimensions.
+    int chunkSize = op.getChunkSize().value_or(1);
+    int effectiveVecRank = (chunkSize == 1) ? 1 : 2;
+    ArrayRef<int64_t> shape = valueTy.getShape();
+    if (llvm::any_of(shape.take_front(valueTy.getRank() - effectiveVecRank),
+                     [](int64_t d) { return d != 1; }))
+      return rewriter.notifyMatchFailure(
+          op, "Only unit dimensions allowed for the leading "
+              "dimensions of the store vector!");
+
+    auto expectedWiValueTyOrFailure =
+        xegpu::getDistVecTypeBasedOnLaneLayout(layout, valueTy);
+    if (failed(expectedWiValueTyOrFailure))
+      return rewriter.notifyMatchFailure(
+          op,
+          "unable to compute expected workitem vector type from lane layout");
+
+    VectorType expectedWiValueTy = expectedWiValueTyOrFailure.value();
+    VectorType supportedWiValueTy =
----------------
Jianhui-Li wrote:

Not sure about the "expected" and "supported" prefixes. Why not just distValueTy and distValueTy1D?

Something like:
    auto distValueTyOrFailure = xegpu::getDistVecTypeBasedOnLaneLayout(layout, origValueTy);
    if (failed(distValueTyOrFailure))
      return failure();
    VectorType distValueTy = distValueTyOrFailure.value();
    VectorType distValueTy1D = VectorType::get({distValueTy.getNumElements()}, distValueTy.getElementType());
    ...
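To make the naming concrete, here is a dependency-free C++ sketch of the shape arithmetic the pattern performs: the leading-unit-dimension check, the per-lane distribution, and the flattening to 1D. It assumes, hypothetically, that getDistVecTypeBasedOnLaneLayout divides each dimension evenly by the matching lane_layout entry and that the lane layout has the same rank as the value; the helpers hasUnitLeadingDims, distributeShape, and flattenTo1D are illustrative names, not XeGPU or MLIR APIs.

```cpp
#include <cstddef>
#include <cstdint>
#include <functional>
#include <numeric>
#include <optional>
#include <vector>

// Hypothetical stand-in for a VectorType shape: just a list of dims.
using Shape = std::vector<int64_t>;

// Mirrors the leading-dimension check in SgToWiStoreScatter: with
// chunkSize == 1 only the innermost dim may be non-unit, otherwise the
// innermost two may be. All dims before those must be 1.
bool hasUnitLeadingDims(const Shape &shape, int chunkSize) {
  std::size_t effectiveVecRank = (chunkSize == 1) ? 1 : 2;
  for (std::size_t i = 0; i + effectiveVecRank < shape.size(); ++i)
    if (shape[i] != 1)
      return false;
  return true;
}

// Assumed analogue of getDistVecTypeBasedOnLaneLayout ("distValueTy"):
// each dim is split evenly across the lanes along that dim. Returns
// std::nullopt when the ranks differ or a dim does not divide evenly.
std::optional<Shape> distributeShape(const Shape &shape,
                                     const Shape &laneLayout) {
  if (shape.size() != laneLayout.size())
    return std::nullopt;
  Shape dist(shape.size());
  for (std::size_t i = 0; i < shape.size(); ++i) {
    if (laneLayout[i] <= 0 || shape[i] % laneLayout[i] != 0)
      return std::nullopt;
    dist[i] = shape[i] / laneLayout[i];
  }
  return dist;
}

// The "distValueTy1D" idea: collapse the distributed shape into a single
// dimension holding the same number of elements.
Shape flattenTo1D(const Shape &shape) {
  int64_t numElements = std::accumulate(shape.begin(), shape.end(),
                                        int64_t{1}, std::multiplies<int64_t>());
  return {numElements};
}
```

For example, a 1x16x8 store value with chunk size 8 and an assumed lane layout of [1, 16, 1] passes the leading-dimension check, distributes to 1x1x8 per work item, and flattens to an 8-element 1D vector. Under this reading, distValueTy and distValueTy1D name the two intermediate shapes directly, which is the point of the suggested rename.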

https://github.com/llvm/llvm-project/pull/181917

