[Mlir-commits] [mlir] [MLIR][XeGPU] Add distribution pattern for xegpu.load & store for sg to wi pass (PR #181917)

Charitha Saumya llvmlistbot at llvm.org
Fri Feb 20 13:25:10 PST 2026


================
@@ -395,6 +395,78 @@ struct SgToWiPrefetchNd : public OpConversionPattern<xegpu::PrefetchNdOp> {
   }
 };
 
+/// Distributes a subgroup-level LoadGather (xegpu.load) op to workitem-level.
+struct SgToWiLoadGather : public OpConversionPattern<xegpu::LoadGatherOp> {
+  using OpConversionPattern<xegpu::LoadGatherOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(xegpu::LoadGatherOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    xegpu::DistributeLayoutAttr layout = op.getAnchorLayout();
+    if (!layout)
+      return failure();
+
+    VectorType resultTy = op.getValueType();
+    if (!resultTy)
+      return failure();
+
+    // Check that leading dimensions are unit.
+    int chunkSize = op.getChunkSize().value_or(1);
+    int effectiveVecRank = (chunkSize == 1) ? 1 : 2;
+    for (int i = 0; i < resultTy.getRank() - effectiveVecRank; i++) {
+      if (resultTy.getShape()[i] != 1)
+        return rewriter.notifyMatchFailure(
+            op, "Only unit dimensions allowed for the leading "
+                "dimensions of the load vector!");
+    }
+
+    auto expectedWiResultTyOrFailure =
+        xegpu::getDistVecTypeBasedOnLaneLayout(layout, resultTy);
+    if (failed(expectedWiResultTyOrFailure))
+      return rewriter.notifyMatchFailure(
+          op,
+          "unable to compute expected workitem vector type from lane layout");
+
+    VectorType expectedWiResultTy = expectedWiResultTyOrFailure.value();
+    VectorType supportedWiResultTy =
+        VectorType::get({expectedWiResultTy.getNumElements()},
+                        expectedWiResultTy.getElementType());
+
+    // Flatten offsets and mask to 1D to match the 1D result type.
+    Value offsets = adaptor.getOffsets();
+    if (auto offsetsTy = dyn_cast<VectorType>(offsets.getType())) {
+      VectorType offsetsTy1D = VectorType::get({offsetsTy.getNumElements()},
+                                               offsetsTy.getElementType());
+      if (offsetsTy != offsetsTy1D)
+        offsets = vector::ShapeCastOp::create(rewriter, op.getLoc(),
+                                              offsetsTy1D, offsets)
+                      .getResult();
+    }
+    Value mask = adaptor.getMask();
+    if (auto maskTy = dyn_cast<VectorType>(mask.getType())) {
+      VectorType maskTy1D =
+          VectorType::get({maskTy.getNumElements()}, maskTy.getElementType());
+      if (maskTy != maskTy1D)
+        mask =
+            vector::ShapeCastOp::create(rewriter, op.getLoc(), maskTy1D, mask)
+                .getResult();
+    }
----------------
charithaintc wrote:

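Consider using a lambda to factor out the duplicated flatten-to-1D logic (the `vector.shape_cast` sequence) applied to both the offsets and the mask.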

https://github.com/llvm/llvm-project/pull/181917


More information about the Mlir-commits mailing list