[Mlir-commits] [mlir] c5c6588 - [MLIR][XeGPU] Add distribution pattern for xegpu load & store matrix from sg to wi (#183179)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Mar 5 09:56:56 PST 2026


Author: Nishant Patel
Date: 2026-03-05T09:56:51-08:00
New Revision: c5c6588c115347cdd000cb8a650b19dca0381196

URL: https://github.com/llvm/llvm-project/commit/c5c6588c115347cdd000cb8a650b19dca0381196
DIFF: https://github.com/llvm/llvm-project/commit/c5c6588c115347cdd000cb8a650b19dca0381196.diff

LOG: [MLIR][XeGPU] Add distribution pattern for xegpu load & store matrix from sg to wi (#183179)

This PR adds distribution patterns for the xegpu.load_matrix and
xegpu.store_matrix ops to the new sg-to-wi pass.

Added: 
    

Modified: 
    mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
    mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
index cffd80f8fcf92..5cd766ed2813e 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
@@ -5,6 +5,7 @@
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
 //===----------------------------------------------------------------------===//
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/Index/IR/IndexDialect.h"
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
@@ -597,6 +598,137 @@ struct SgToWiMultiDimReduction
   }
 };
 
+/// Helper to compute distributed coordinates for matrix ops.
+/// When not using subgroup_block_io, each workitem computes its own
+/// coordinates based on the layout and lane ID.
+///
+/// Creates a `gpu.lane_id` op, asks `layout` for the coordinates owned by that
+/// lane inside a payload of shape `payloadShape`, and adds them
+/// (right-aligned) to `origOffsets`. Returns the resulting per-lane offsets as
+/// SSA values. Returns an empty vector (which callers treat as failure) when
+/// the layout cannot compute coordinates, or when it yields anything other
+/// than exactly one coordinate set for the lane.
+static SmallVector<Value> computeDistributedCoordsForMatrixOp(
+    ConversionPatternRewriter &rewriter, Location loc,
+    xegpu::DistributeLayoutAttr layout, ArrayRef<int64_t> payloadShape,
+    ValueRange origOffsets) {
+  Value laneId = gpu::LaneIdOp::create(rewriter, loc, rewriter.getIndexType(),
+                                       /*upperBound=*/mlir::IntegerAttr());
+  auto maybeCoords =
+      layout.computeDistributedCoords(rewriter, loc, laneId, payloadShape);
+  if (failed(maybeCoords))
+    return {};
+  // The matrix op built by the caller is addressed by a single coordinate, so
+  // it can only replace the subgroup op when the lane owns exactly one
+  // coordinate set. Whether that holds depends on the input IR (payload shape
+  // vs. lane_layout * lane_data), so it must be a runtime check, not an
+  // assert: an assert vanishes in release builds and would silently drop all
+  // but the first set, producing wrongly addressed accesses.
+  if (maybeCoords->size() != 1)
+    return {};
+  SmallVector<OpFoldResult> ofrVec = xegpu::addWithRightAligned(
+      rewriter, loc, getAsOpFoldResult(maybeCoords->front()),
+      getAsOpFoldResult(origOffsets));
+  return llvm::map_to_vector(ofrVec, llvm::CastTo<Value>);
+}
+
+/// Rewrites a subgroup-level xegpu.load_matrix into its workitem-level form.
+///
+/// The result vector type is shrunk to the per-lane slice derived from the
+/// op's lane layout. Without `subgroup_block_io`, each lane's offsets are
+/// recomputed from its lane ID and added to the original offsets; with
+/// `subgroup_block_io`, the original offsets are reused unchanged. The new op
+/// takes every offset as an SSA value and carries no layout attribute.
+struct SgToWiLoadMatrix : public OpConversionPattern<xegpu::LoadMatrixOp> {
+  using OpConversionPattern<xegpu::LoadMatrixOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(xegpu::LoadMatrixOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Ops without a layout are left for other patterns.
+    auto sgLayout = op.getLayoutAttr();
+    if (!sgLayout)
+      return failure();
+
+    auto sgVecTy = dyn_cast<VectorType>(op.getResult().getType());
+    if (!sgVecTy)
+      return rewriter.notifyMatchFailure(
+          op, "the matrix op payload must be a vector type");
+
+    Location loc = op.getLoc();
+    auto mixedOffsets = op.getMixedOffsets();
+    if (mixedOffsets.empty())
+      return rewriter.notifyMatchFailure(op, "the load op must have offsets");
+
+    FailureOr<VectorType> wiVecTy =
+        getDistVecTypeBasedOnLaneLayout(sgLayout, sgVecTy);
+    if (failed(wiVecTy))
+      return rewriter.notifyMatchFailure(
+          op, "Failed to distribute matrix op payload based on layout.");
+
+    // Materialize every offset, static or dynamic, as an SSA value.
+    SmallVector<Value> sgOffsets =
+        vector::getAsValues(rewriter, loc, mixedOffsets);
+    SmallVector<Value> wiOffsets;
+    if (op.getSubgroupBlockIoAttr()) {
+      // Block I/O keeps the subgroup offsets as they are.
+      wiOffsets = sgOffsets;
+    } else {
+      wiOffsets = computeDistributedCoordsForMatrixOp(
+          rewriter, loc, sgLayout, sgVecTy.getShape(), sgOffsets);
+      if (wiOffsets.empty())
+        return rewriter.notifyMatchFailure(
+            op, "Failed to compute distributed coordinates.");
+    }
+
+    // All offsets are operands now, so every static slot is marked dynamic.
+    DenseI64ArrayAttr dynConstOffsets = rewriter.getDenseI64ArrayAttr(
+        SmallVector<int64_t>(op.getConstOffsets().size(),
+                             ShapedType::kDynamic));
+
+    auto wiLoad = xegpu::LoadMatrixOp::create(
+        rewriter, loc, *wiVecTy, adaptor.getMemDesc(), ValueRange(wiOffsets),
+        dynConstOffsets, op.getSubgroupBlockIoAttr(),
+        xegpu::DistributeLayoutAttr{});
+    rewriter.replaceOp(op, wiLoad.getResult());
+    return success();
+  }
+};
+
+/// Rewrites a subgroup-level xegpu.store_matrix into its workitem-level form.
+///
+/// The stored data is cast to the per-lane slice type derived from the op's
+/// lane layout. Without `subgroup_block_io`, each lane's offsets are
+/// recomputed from its lane ID and added to the original offsets; with
+/// `subgroup_block_io`, the original offsets are reused unchanged. The new op
+/// takes every offset as an SSA value and carries no layout attribute; the
+/// original op is erased.
+struct SgToWiStoreMatrix : public OpConversionPattern<xegpu::StoreMatrixOp> {
+  using OpConversionPattern<xegpu::StoreMatrixOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(xegpu::StoreMatrixOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Ops without a layout are left for other patterns.
+    auto sgLayout = op.getLayoutAttr();
+    if (!sgLayout)
+      return failure();
+
+    auto sgVecTy = dyn_cast<VectorType>(op.getData().getType());
+    if (!sgVecTy)
+      return rewriter.notifyMatchFailure(
+          op, "the matrix op payload must be a vector type");
+
+    Location loc = op.getLoc();
+    auto mixedOffsets = op.getMixedOffsets();
+    if (mixedOffsets.empty())
+      return rewriter.notifyMatchFailure(op, "the store op must have offsets");
+
+    FailureOr<VectorType> wiVecTy =
+        getDistVecTypeBasedOnLaneLayout(sgLayout, sgVecTy);
+    if (failed(wiVecTy))
+      return rewriter.notifyMatchFailure(
+          op, "Failed to distribute matrix op payload based on layout.");
+
+    // Materialize every offset, static or dynamic, as an SSA value.
+    SmallVector<Value> sgOffsets =
+        vector::getAsValues(rewriter, loc, mixedOffsets);
+    SmallVector<Value> wiOffsets;
+    if (op.getSubgroupBlockIoAttr()) {
+      // Block I/O keeps the subgroup offsets as they are.
+      wiOffsets = sgOffsets;
+    } else {
+      wiOffsets = computeDistributedCoordsForMatrixOp(
+          rewriter, loc, sgLayout, sgVecTy.getShape(), sgOffsets);
+      if (wiOffsets.empty())
+        return rewriter.notifyMatchFailure(
+            op, "Failed to compute distributed coordinates.");
+    }
+
+    // All offsets are operands now, so every static slot is marked dynamic.
+    DenseI64ArrayAttr dynConstOffsets = rewriter.getDenseI64ArrayAttr(
+        SmallVector<int64_t>(op.getConstOffsets().size(),
+                             ShapedType::kDynamic));
+
+    // Cast the adapted data operand to the per-lane vector type.
+    Value wiData = castValueTo(
+        rewriter, cast<TypedValue<VectorType>>(adaptor.getData()), *wiVecTy);
+    xegpu::StoreMatrixOp::create(rewriter, loc, TypeRange{}, wiData,
+                                 adaptor.getMemDesc(), ValueRange(wiOffsets),
+                                 dynConstOffsets, op.getSubgroupBlockIoAttr(),
+                                 xegpu::DistributeLayoutAttr{});
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
 /// Distributes a subgroup-level StoreScatter (xegpu.store) op to
 /// workitem-level.
 ///
@@ -901,5 +1033,6 @@ void xegpu::populateXeGPUSgToWiDistributeTypeConversionAndLegality(
   patterns.add<SgToWiCreateNdDesc, SgToWiLoadNd, SgToWiStoreNd, SgToWiDpas,
                SgToWiElementWise, SgToWiArithConstant, SgToWiPrefetchNd,
                SgToWiLoadGather, SgToWiStoreScatter, SgToWiVectorReduction,
-               SgToWiMultiDimReduction>(typeConverter, patterns.getContext());
+               SgToWiMultiDimReduction, SgToWiLoadMatrix, SgToWiStoreMatrix>(
+      typeConverter, patterns.getContext());
 }

diff  --git a/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir b/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
index bc36554f5c266..d7b4883760c05 100644
--- a/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
+++ b/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
@@ -461,3 +461,67 @@ gpu.func @vector_multi_reduction_dim0_distributed_dim1_reduction(%laneid: index)
   gpu.return
 }
 }
+
+// -----
+// load_matrix and store_matrix with coordinate computation (offsets [0,0])
+gpu.module @xevm_module {
+// CHECK-LABEL: gpu.func @load_store_matrix_1
+// CHECK-DAG: %[[LANE_ID1:.*]] = gpu.lane_id
+// CHECK-DAG: %[[R1:.*]] = arith.remui %[[LANE_ID1]], %{{.*}} : index
+// CHECK-DAG: %[[D1:.*]] = arith.divui %[[LANE_ID1]], %{{.*}} : index
+// CHECK-DAG: %[[R2:.*]] = arith.remui %[[D1]], %{{.*}} : index
+// CHECK-DAG: %[[ROW:.*]] = arith.remui %[[R2]], %{{.*}} : index
+// CHECK-DAG: %[[COL:.*]] = arith.remui %[[R1]], %{{.*}} : index
+// CHECK: %[[MAT:.*]] = xegpu.load_matrix %arg0[%[[ROW]], %[[COL]]] : !xegpu.mem_desc<32x32xf32>, index, index -> vector<1x1xf32>
+// CHECK: %[[LANE_ID2:.*]] = gpu.lane_id
+// CHECK: xegpu.store_matrix %[[MAT]], %arg0[%{{.*}}, %{{.*}}] : vector<1x1xf32>, !xegpu.mem_desc<32x32xf32>, index, index
+// lane_layout = [2, 8] with lane_data = [1, 1] spreads the 2x8 payload over
+// 16 lanes, one element each, so both ops shrink to vector<1x1xf32>. The row
+// and column come straight from the lane id (the offsets are zero, so no add
+// is emitted), and each of the two ops materializes its own lane id.
+gpu.func @load_store_matrix_1(%arg0: !xegpu.mem_desc<32x32xf32>) {
+  %c0 = arith.constant 0 : index
+  %1 = xegpu.load_matrix %arg0[%c0, %c0] <{layout = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}> : !xegpu.mem_desc<32x32xf32>, index, index -> vector<2x8xf32>
+  xegpu.store_matrix %1, %arg0[%c0, %c0] <{layout = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}> : vector<2x8xf32>, !xegpu.mem_desc<32x32xf32>, index, index
+  gpu.return
+}
+}
+
+// -----
+// load_matrix and store_matrix with non-zero offsets [0,1]
+gpu.module @xevm_module {
+// CHECK-LABEL: gpu.func @load_store_matrix_2
+// CHECK-DAG: %[[LANE_ID1:.*]] = gpu.lane_id
+// CHECK-DAG: %[[R1:.*]] = arith.remui %[[LANE_ID1]], %{{.*}} : index
+// CHECK-DAG: %[[D1:.*]] = arith.divui %[[LANE_ID1]], %{{.*}} : index
+// CHECK-DAG: %[[R2:.*]] = arith.remui %[[D1]], %{{.*}} : index
+// CHECK-DAG: %[[MUL:.*]] = arith.muli %[[R2]], %{{.*}} : index
+// CHECK-DAG: %[[ROW:.*]] = arith.remui %[[MUL]], %{{.*}} : index
+// CHECK-DAG: %[[R3:.*]] = arith.remui %[[R1]], %{{.*}} : index
+// CHECK-DAG: %[[ADD:.*]] = arith.addi %[[R3]], %{{.*}} : index
+// CHECK: %[[MAT:.*]] = xegpu.load_matrix %arg0[%[[ROW]], %[[ADD]]] : !xegpu.mem_desc<32x32xf32>, index, index -> vector<2x1xf32>
+// CHECK: %[[LANE_ID2:.*]] = gpu.lane_id
+// CHECK: xegpu.store_matrix %[[MAT]], %arg0[%{{.*}}, %{{.*}}] : vector<2x1xf32>, !xegpu.mem_desc<32x32xf32>, index, index
+// lane_layout = [4, 4] with lane_data = [2, 1] gives each of the 16 lanes a
+// 2x1 slice of the 8x4 payload. The lane-derived row is scaled (presumably by
+// lane_data[0] = 2) before the final remui, and the non-zero column offset 1
+// shows up as an arith.addi on the lane-derived column.
+gpu.func @load_store_matrix_2(%arg0: !xegpu.mem_desc<32x32xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %1 = xegpu.load_matrix %arg0[%c0, %c1] <{layout = #xegpu.layout<lane_layout = [4, 4], lane_data = [2, 1]>}> : !xegpu.mem_desc<32x32xf32>, index, index -> vector<8x4xf32>
+  xegpu.store_matrix %1, %arg0[%c0, %c1] <{layout = #xegpu.layout<lane_layout = [4, 4], lane_data = [2, 1]>}> : vector<8x4xf32>, !xegpu.mem_desc<32x32xf32>, index, index
+  gpu.return
+}
+}
+
+// -----
+// load_matrix and store_matrix with subgroup_block_io (no coordinate computation)
+gpu.module @xevm_module {
+// CHECK-LABEL: gpu.func @load_store_matrix_3
+// CHECK: %[[MAT:.*]] = xegpu.load_matrix %arg0[%{{.*}}, %{{.*}}] <{subgroup_block_io}>:
+// CHECK-SAME: !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<block = [16, 1], stride = [1, 32]>>, index, index -> vector<1x2xf32>
+// CHECK: xegpu.store_matrix %[[MAT]], %arg0[%{{.*}}, %{{.*}}] <{subgroup_block_io}>:
+// CHECK-SAME: vector<1x2xf32>, !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<block = [16, 1], stride = [1, 32]>>, index, index
+// With subgroup_block_io the patterns skip the per-lane coordinate
+// computation and reuse the original offsets; only the payload type is
+// distributed: lane_layout = [16, 1] turns vector<16x2xf32> into
+// vector<1x2xf32>. The subgroup_block_io attribute is kept on both new ops
+// while the layout attribute is dropped.
+gpu.func @load_store_matrix_3(%arg0: !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [1, 32], block = [16, 1]>>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %1 = xegpu.load_matrix %arg0[%c0, %c1] <{subgroup_block_io, layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}> :
+    !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [1, 32], block = [16, 1]>>, index, index -> vector<16x2xf32>
+  xegpu.store_matrix %1, %arg0[%c0, %c1] <{subgroup_block_io, layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}> :
+    vector<16x2xf32>, !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [1, 32], block = [16, 1]>>, index, index
+  gpu.return
+}
+}


        


More information about the Mlir-commits mailing list