[Mlir-commits] [mlir] [mlir][XeGPU] add WgToSg distribution pattern for load_matrix and store_matrix. (PR #154403)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Aug 19 11:48:33 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Chao Chen (chencha3)

<details>
<summary>Changes</summary>

As described by the title. 

---

Patch is 33.42 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/154403.diff


9 Files Affected:

- (modified) mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h (+2) 
- (modified) mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td (+41-4) 
- (modified) mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td (+8-4) 
- (modified) mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp (+9-6) 
- (modified) mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp (+2-2) 
- (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp (+180-76) 
- (modified) mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir (+6-2) 
- (modified) mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir (+74-2) 
- (modified) mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp (+2-2) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h
index 3592da4c46364..1d152f0c9ca9a 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h
@@ -11,6 +11,7 @@
 
 #include "mlir/Bytecode/BytecodeOpInterface.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dialect.h"
@@ -23,6 +24,7 @@
 namespace mlir {
 namespace xegpu {
 class TensorDescType;
+class DistributeLayoutAttrInterface;
 class LayoutAttr;
 class SliceAttr;
 } // namespace xegpu
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index a94987885c9e0..de86141ad006a 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -175,22 +175,31 @@ def XeGPU_FenceScopeAttr:
     let assemblyFormat = "$value";
 }
 
-def LayoutTrait: AttrInterface<"LayoutTrait"> {
+def DistributeLayoutAttrInterface: AttrInterface<"DistributeLayoutAttrInterface"> {
   let cppNamespace = "::mlir::xegpu";
   let description = [{
     Common trait for all XeGPU layouts.
   }];
 
   let methods = [
+    InterfaceMethod<"Check the availability of workgroup level layouts",
+                    "bool",
+                    "isWgLayout">,
     InterfaceMethod<"Get the rank of attribute",
                     "int64_t",
                     "getRank">,
+    InterfaceMethod<"Get the num of effective subgroups",
+                    "int64_t",
+                    "getNumSubgroups">,
     InterfaceMethod<"Get the SgLayout field of the attribute as integer array",
                     "std::optional<SmallVector<int64_t>>",
                     "getSgLayoutAsInt">,
     InterfaceMethod<"Get the SgData field of the attribute as integer array",
                     "std::optional<SmallVector<int64_t>>",
                     "getSgDataAsInt">,
+    InterfaceMethod<"Derive a new layout by dropping sgLayout and sgData",
+                    "xegpu::DistributeLayoutAttrInterface",
+                    "dropSgLayoutAndData">,
     InterfaceMethod<[{Delinearizes a linear subgroup ID into its multidimensional
                       indices based on the effective subgroup layout.}],
                     "FailureOr<SmallVector<Value>>",
@@ -206,7 +215,7 @@ def LayoutTrait: AttrInterface<"LayoutTrait"> {
   ];
 }
 
-def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [LayoutTrait]> {
+def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttrInterface]> {
   let summary = [{
     Describes the data distribution to subgroups and work-items for a tensor
     specified by the tensor descriptor.
@@ -346,6 +355,13 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [LayoutTrait]> {
       return 0;
     }
 
+    int64_t getNumSubgroups() {
+      std::optional<SmallVector<int64_t>> sgLayout = getSgLayoutAsInt();
+      if (sgLayout.has_value())
+        return computeProduct(*sgLayout);
+      return 0;
+    }
+
     LayoutAttr dropSgLayoutAndData() {
       // avoid every field of the attribute is nullptr, which may lead to segment fault
       if (!getInstData() && !getLaneLayout())
@@ -393,7 +409,7 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [LayoutTrait]> {
 }
 
 
-def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [LayoutTrait]> {
+def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttrInterface]> {
   let summary = [{Describes the data distribution and sharing among subgroups or work-items.}];
 
   let description = [{
@@ -420,7 +436,7 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [LayoutTrait]> {
   }];
 
   let parameters = (ins
-    "xegpu::LayoutTrait": $parent,
+    "xegpu::DistributeLayoutAttrInterface": $parent,
     "DenseI64ArrayAttr": $dims
   );
 
@@ -450,6 +466,13 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [LayoutTrait]> {
       return parent.isSgLayout();
     }
 
+    int64_t getNumSubgroups() {
+      std::optional<SmallVector<int64_t>> sgLayout = getSgLayoutAsInt();
+      if (sgLayout.has_value())
+        return computeProduct(*sgLayout);
+      return 0;
+    }
+
     /// Returns the SgLayout of the attribute, computed by applying
     /// the slice dimensions to the underlying LayoutAttr.
     std::optional<SmallVector<int64_t>> getSgLayoutAsInt() const {
@@ -474,6 +497,20 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [LayoutTrait]> {
       return std::nullopt;
     }
 
+    SliceAttr dropSgLayoutAndData() {
+      SliceAttr attr = flatten();
+      auto parent = dyn_cast<LayoutAttr>(attr.getParent());
+      parent = parent.dropSgLayoutAndData();
+      return SliceAttr::get(getContext(), parent, attr.getDims());
+    }
+
+    SliceAttr dropInstData() {
+      SliceAttr attr = flatten();
+      auto parent = dyn_cast<LayoutAttr>(attr.getParent());
+      parent = parent.dropInstData();
+      return SliceAttr::get(getContext(), parent, attr.getDims());
+    }
+
     /// flatten a nested SliceAttr, e.g., for 2-level nested SliceAttr
     /// #xegpu.slice<#xegpu.slice<#xegpu.layout<sg_layout = [4, 8, 12]>, dims = [0]>, dims = [0]>
     /// it will coalese two slice operations and return a simplified SliceAttr
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index eb54d6887681d..3ba9eaa4a66da 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -232,6 +232,10 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
       return static_cast<unsigned>(MemorySpace::Global);
     }
 
+    xegpu::DistributeLayoutAttrInterface getLayoutAttr() {
+      return dyn_cast_if_present<xegpu::DistributeLayoutAttrInterface>(getType().getLayout());
+    }
+
   }];
 }
 
@@ -1150,7 +1154,7 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>,
   let arguments = (ins XeGPU_MemDesc:$mem_desc,
     Variadic<Index>: $offsets,
     DenseI64ArrayAttr: $const_offsets,
-    OptionalAttr<LayoutTrait>:$layout
+    OptionalAttr<DistributeLayoutAttrInterface>:$layout
   );
   let results = (outs XeGPU_ValueType:$res);
   let assemblyFormat = [{
@@ -1175,7 +1179,7 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>,
 
   let builders = [
     OpBuilder<(ins "Type":$res, "TypedValue<MemDescType>": $mem_desc,
-                    "llvm::ArrayRef<OpFoldResult>": $offsets, "LayoutTrait": $layout)>,
+                    "llvm::ArrayRef<OpFoldResult>": $offsets, "DistributeLayoutAttrInterface": $layout)>,
   ];
   let extraClassDeclaration = [{
     SmallVector<OpFoldResult> getMixedOffsets() {
@@ -1194,7 +1198,7 @@ def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>,
     XeGPU_MemDesc:$mem_desc,
     Variadic<Index>: $offsets,
     DenseI64ArrayAttr: $const_offsets,
-    OptionalAttr<LayoutTrait>:$layout
+    OptionalAttr<DistributeLayoutAttrInterface>:$layout
   );
   let assemblyFormat = [{ $data `,` $mem_desc `` custom<DynamicIndexList>($offsets, $const_offsets)
                           prop-dict attr-dict `` `:` type(operands)}];
@@ -1213,7 +1217,7 @@ def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>,
   }];
   let builders = [
     OpBuilder<(ins "Value" : $data, "TypedValue<MemDescType>": $mem_desc,
-                   "llvm::ArrayRef<OpFoldResult>": $offsets, "LayoutTrait": $layout)>,
+                   "llvm::ArrayRef<OpFoldResult>": $offsets, "DistributeLayoutAttrInterface": $layout)>,
   ];
   let extraClassDeclaration = [{
     SmallVector<OpFoldResult> getMixedOffsets() {
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 8ea8cb1f45972..de118b7faea4d 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -290,8 +290,9 @@ LayoutAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc,
   return affine::delinearizeIndex(builder, loc, linearId, dims);
 }
 
-/// Implements LayoutTrait::getOffsets to generate instructions for
-/// computing multi-dimensional offsets when distributed by LayoutAttr.
+/// Implements DistributeLayoutAttrInterface::getOffsets to generate
+/// instructions for computing multi-dimensional offsets when distributed by
+/// LayoutAttr.
 FailureOr<SmallVector<SmallVector<Value>>>
 LayoutAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
                        ArrayRef<int64_t> shape) {
@@ -322,7 +323,8 @@ LayoutAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
 //===----------------------------------------------------------------------===//
 LogicalResult
 SliceAttr::verify(llvm::function_ref<InFlightDiagnostic()> emitError,
-                  xegpu::LayoutTrait parent, DenseI64ArrayAttr dims) {
+                  xegpu::DistributeLayoutAttrInterface parent,
+                  DenseI64ArrayAttr dims) {
   if (!parent || !dims)
     return emitError() << "expected parent layout and dims attribute";
 
@@ -340,7 +342,7 @@ SliceAttr::verify(llvm::function_ref<InFlightDiagnostic()> emitError,
 }
 
 SliceAttr SliceAttr::flatten() const {
-  xegpu::LayoutTrait parent = getParent();
+  xegpu::DistributeLayoutAttrInterface parent = getParent();
   SmallVector<DenseI64ArrayAttr> slicedDims({getDims()});
 
   while (auto sliceAttr = dyn_cast<xegpu::SliceAttr>(parent)) {
@@ -375,8 +377,9 @@ SliceAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc,
   return parent.delinearizeSubgroupId(builder, loc, linearId);
 }
 
-/// Implements LayoutTrait::getOffsets to generate instructions for
-/// computing multi-dimensional offsets when distributed by SliceAttr.
+/// Implements DistributeLayoutAttrInterface::getOffsets to generate
+/// instructions for computing multi-dimensional offsets when distributed by
+/// SliceAttr.
 FailureOr<SmallVector<SmallVector<Value>>>
 SliceAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
                       ArrayRef<int64_t> shape) {
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 906c71d8b8dad..0e22af900daf1 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -984,7 +984,7 @@ void ConvertLayoutOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
 void LoadMatrixOp::build(OpBuilder &builder, OperationState &state, Type res,
                          TypedValue<MemDescType> memDesc,
                          llvm::ArrayRef<OpFoldResult> offsets,
-                         LayoutTrait layout) {
+                         DistributeLayoutAttrInterface layout) {
   llvm::SmallVector<Value> dynamicOffsets;
   llvm::SmallVector<int64_t> staticOffsets;
   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
@@ -1014,7 +1014,7 @@ LogicalResult LoadMatrixOp::verify() {
 void StoreMatrixOp::build(OpBuilder &builder, OperationState &state, Value data,
                           TypedValue<MemDescType> memDesc,
                           llvm::ArrayRef<OpFoldResult> offsets,
-                          LayoutTrait layout) {
+                          DistributeLayoutAttrInterface layout) {
   llvm::SmallVector<Value> dynamicOffsets;
   llvm::SmallVector<int64_t> staticOffsets;
   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 8f1208e77ca5d..ca1209e776d0e 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -55,17 +55,16 @@ static bool isSgIdRangeSpecified(Operation *op, int64_t &startOfRange,
 }
 
 static std::pair<SmallVector<int64_t>, int>
-getSgShapeAndCount(ArrayRef<int64_t> shape, xegpu::LayoutAttr layout) {
+getSgShapeAndCount(ArrayRef<int64_t> shape,
+                   xegpu::DistributeLayoutAttrInterface layout) {
   int count = 1;
   SmallVector<int64_t> sgShape(shape);
-
   if (layout && layout.isWgLayout()) {
-    DenseI32ArrayAttr sgLayoutAttr = layout.getSgLayout();
-    auto sgLayout = llvm::to_vector_of<int64_t>(sgLayoutAttr.asArrayRef());
-    if (DenseI32ArrayAttr sgDataAttr = layout.getSgData())
-      sgShape = llvm::to_vector_of<int64_t>(sgDataAttr.asArrayRef());
-    else
-      sgShape = computeShapeRatio(shape, sgLayout).value_or(sgShape);
+    SmallVector<int64_t> sgLayout = layout.getSgLayoutAsInt().value();
+    if (auto maybeSgData = layout.getSgDataAsInt())
+      sgShape = *maybeSgData;
+    else if (auto maybeDerivedSgData = computeShapeRatio(shape, sgLayout))
+      sgShape = *maybeDerivedSgData;
     SmallVector<int64_t> distUnit = computeElementwiseMul(sgLayout, sgShape);
     // Clamp distUnit to the original shape to handle cases where data is
     // shared among subgroups, which may cause distUnit to exceed the original
@@ -77,6 +76,72 @@ getSgShapeAndCount(ArrayRef<int64_t> shape, xegpu::LayoutAttr layout) {
   return std::make_pair(sgShape, count);
 }
 
+// A utility helper to generate elementwise addition ops for index computing.
+// lhs and rhs are vectors of OpFoldResult. If their ranks don't match, they
+// are aligned at the trailing end; unmatched leading entries pass through.
+static SmallVector<OpFoldResult> add(ConversionPatternRewriter &rewriter,
+                                     Location loc, ArrayRef<OpFoldResult> lhs,
+                                     ArrayRef<OpFoldResult> rhs) {
+  SmallVector<OpFoldResult> reversedResult;
+  auto l = lhs.rbegin();
+  auto r = rhs.rbegin();
+  for (; l != lhs.rend() || r != rhs.rend(); ++l, ++r) {
+    if (l == lhs.rend()) {
+      reversedResult.push_back(*r);
+    } else if (r == rhs.rend()) {
+      reversedResult.push_back(*l);
+    } else {
+      auto lval = getValueOrCreateConstantIndexOp(rewriter, loc, *l);
+      auto rval = getValueOrCreateConstantIndexOp(rewriter, loc, *r);
+      auto add = rewriter.createOrFold<index::AddOp>(loc, lval, rval);
+      reversedResult.push_back(add);
+    }
+  }
+  return llvm::to_vector(llvm::reverse(reversedResult));
+}
+
+// A callback function type used to create new load/store_matrix ops
+using CreatorFuncType =
+    llvm::function_ref<void(ArrayRef<OpFoldResult> baseOffsets,
+                            SmallVector<SmallVector<Value>> &descOffsets)>;
+
+/// Utility helper for distributing logic shared by operations with offsets
+template <typename OpType,
+          typename = std::enable_if_t<llvm::is_one_of<
+              OpType, xegpu::CreateNdDescOp, xegpu::LoadMatrixOp,
+              xegpu::StoreMatrixOp>::value>>
+static LogicalResult
+distributeOp(ConversionPatternRewriter &rewriter,
+             typename OpConversionPattern<OpType>::OneToNOpAdaptor adaptor,
+             OpType op, ArrayRef<int64_t> wgShape, CreatorFuncType callback) {
+  Location loc = op.getLoc();
+  auto layout = op.getLayoutAttr();
+  if (!layout || !layout.isWgLayout())
+    return failure();
+
+  Value sgId = rewriter.create<gpu::SubgroupIdOp>(loc, /*upper_bound=*/nullptr);
+
+  // adjust the linearId if the range specifier is present
+  int64_t startOfRange = -1, endOfRange = -1;
+  bool sgIdRangeSpecified = isSgIdRangeSpecified(op, startOfRange, endOfRange);
+  if (sgIdRangeSpecified) {
+    if (layout.getNumSubgroups() != endOfRange - startOfRange)
+      return rewriter.notifyMatchFailure(
+          op, "sg_layout size must match the sg_id_range");
+    Value startOfRangeVal =
+        rewriter.create<arith::ConstantIndexOp>(loc, startOfRange);
+    sgId = rewriter.create<index::SubOp>(loc, sgId, startOfRangeVal);
+  }
+
+  auto maybeMdescOffsets = layout.getOffsets(rewriter, loc, sgId, wgShape);
+  if (failed(maybeMdescOffsets))
+    return failure();
+
+  SmallVector<OpFoldResult> wgOffsets = op.getMixedOffsets();
+  callback(wgOffsets, *maybeMdescOffsets);
+  return success();
+}
+
 /// This pattern transforms the CreateNdDescOp to create a subgroup descriptor
 /// from a workgroup descriptor. It replaces the offsets and sizes with
 /// appropriate values for the subgroup.
@@ -137,71 +202,35 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
     Location loc = op.getLoc();
     MLIRContext *ctx = op.getContext();
     xegpu::TensorDescType tdescTy = op.getType();
-    auto layout = dyn_cast<xegpu::LayoutAttr>(tdescTy.getLayout());
-    if (!layout)
-      return failure();
-    Type elemTy = tdescTy.getElementType();
     ArrayRef<int64_t> wgShape = tdescTy.getShape();
-    // sgLayout must be present for workgroup-level distribution.
-    SmallVector<int64_t> sgLayout;
-    if (auto sgLayoutAttr = layout.getSgLayout())
-      sgLayout = llvm::to_vector_of<int64_t>(sgLayoutAttr.asArrayRef());
-    else
-      return rewriter.notifyMatchFailure(
-          op, "sgLayout attribute is required in layout");
-
-    // Get the subgroup ID
-    Value linearSgId =
-        gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
-
-    int64_t startOfRange = -1, endOfRange = -1;
-    bool sgIdRangeSpecified =
-        isSgIdRangeSpecified(op, startOfRange, endOfRange);
-
-    if (sgIdRangeSpecified) {
-      int64_t sgCount = endOfRange - startOfRange;
-      if (computeProduct(sgLayout) != sgCount)
-        return rewriter.notifyMatchFailure(
-            op, "sg_layout size must match the sg_id_range");
-      // Subtract startOfRange from the original subgroup id to get
-      // the adjusted sg id
-      Value startOfRangeVal =
-          arith::ConstantIndexOp::create(rewriter, loc, startOfRange);
-      linearSgId =
-          rewriter.createOrFold<index::SubOp>(loc, linearSgId, startOfRangeVal);
-    }
-
-    auto maybeTdescOffsets =
-        layout.getOffsets(rewriter, loc, linearSgId, wgShape);
-    if (failed(maybeTdescOffsets))
-      return failure();
-
-    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
-    xegpu::TensorDescType newTdescTy =
-        xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(),
-                                   layout.dropSgLayoutAndData());
+    Type elemTy = tdescTy.getElementType();
 
-    SmallVector<Value> newCreateNdOps;
-    SmallVector<OpFoldResult> origOffsets = op.getMixedOffsets();
-
-    for (auto tdescOffsets : *maybeTdescOffsets) {
-      SmallVector<OpFoldResult> sgOffsets;
-      size_t rank = tdescOffsets.size();
-      for (size_t i = 0; i < rank; i++) {
-        size_t idx = origOffsets.size() - rank + i;
-        Value add = rewriter.createOrFold<index::AddOp>(
-            loc, tdescOffsets[i],
-            getValueOrCreateConstantIndexOp(rewriter, loc, origOffsets[idx]));
-        sgOffsets.push_back(add);
+    // the callback function for creating new CreateNdOps,
+    // the baseOffsets is the original offsets of the op, and
+    // descOffsets is the relative offsets to the mem_desc accessed
+    // by each subgroup op.
+    auto callback = [&](ArrayRef<OpFoldResult> baseOffsets,
+                        SmallVector<SmallVector<Value>> descOffsets) {
+      xegpu::DistributeLayoutAttrInterface layout = op.getLayoutAttr();
+      SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
+      auto newTdescTy = xegpu::TensorDescType::get(
+          ctx, sgShape, elemTy, tdescTy.getEncoding(),
+          layout.dropSgLayoutAndData());
+
+      SmallVector<Value> newOps;
+      for (auto offsets : descOffsets) {
+        SmallVector<OpFoldResult> sgOffsets =
+            add(rewriter, loc, baseOffsets, getAsOpFoldResult(offsets));
+        auto newOp = xegpu::CreateNdDescOp::create(
+            rewriter, loc, newTdescTy, op.getSource(), sgOffsets,
+            op.getMixedSizes(), op.getMixedStrides());
+
+        newOps.push_back(newOp);
       }
+      rewriter.replaceOpWithMultiple(op, {newOps});
+    };
 
-      auto newOp = xegpu::CreateNdDescOp::create(
-          rewriter, loc, newTdescTy, op.getSource(), sgOffsets,
-          op.getMixedSizes(), op.getMixedStrides());
-      newCreateNdOps.push_back(newOp);
-    }
-    rewriter.replaceOpWithMultiple(op, {newCreateNdOps});
-    return success();
+    return distributeOp(rewriter, adaptor, op, wgShape, callback);
   }
 };
 
@@ -723,8 +752,8 @@...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/154403


More information about the Mlir-commits mailing list