[Mlir-commits] [mlir] [mlir][XeGPU] add WgToSg distribution pattern for load_matrix and store_matrix. (PR #154403)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Aug 19 11:48:33 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Chao Chen (chencha3)
<details>
<summary>Changes</summary>
As described by the title.
---
Patch is 33.42 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/154403.diff
9 Files Affected:
- (modified) mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h (+2)
- (modified) mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td (+41-4)
- (modified) mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td (+8-4)
- (modified) mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp (+9-6)
- (modified) mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp (+2-2)
- (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp (+180-76)
- (modified) mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir (+6-2)
- (modified) mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir (+74-2)
- (modified) mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp (+2-2)
``````````diff
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h
index 3592da4c46364..1d152f0c9ca9a 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h
@@ -11,6 +11,7 @@
#include "mlir/Bytecode/BytecodeOpInterface.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
@@ -23,6 +24,7 @@
namespace mlir {
namespace xegpu {
class TensorDescType;
+class DistributeLayoutAttrInterface;
class LayoutAttr;
class SliceAttr;
} // namespace xegpu
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index a94987885c9e0..de86141ad006a 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -175,22 +175,31 @@ def XeGPU_FenceScopeAttr:
let assemblyFormat = "$value";
}
-def LayoutTrait: AttrInterface<"LayoutTrait"> {
+def DistributeLayoutAttrInterface: AttrInterface<"DistributeLayoutAttrInterface"> {
let cppNamespace = "::mlir::xegpu";
let description = [{
Common trait for all XeGPU layouts.
}];
let methods = [
+ InterfaceMethod<"Check the availability of workgroup level layouts",
+ "bool",
+ "isWgLayout">,
InterfaceMethod<"Get the rank of attribute",
"int64_t",
"getRank">,
+ InterfaceMethod<"Get the num of effective subgroups",
+ "int64_t",
+ "getNumSubgroups">,
InterfaceMethod<"Get the SgLayout field of the attribute as integer array",
"std::optional<SmallVector<int64_t>>",
"getSgLayoutAsInt">,
InterfaceMethod<"Get the SgData field of the attribute as integer array",
"std::optional<SmallVector<int64_t>>",
"getSgDataAsInt">,
+ InterfaceMethod<"Derive a new layout by dropping sgLayout and sgData",
+ "xegpu::DistributeLayoutAttrInterface",
+ "dropSgLayoutAndData">,
InterfaceMethod<[{Delinearizes a linear subgroup ID into its multidimensional
indices based on the effective subgroup layout.}],
"FailureOr<SmallVector<Value>>",
@@ -206,7 +215,7 @@ def LayoutTrait: AttrInterface<"LayoutTrait"> {
];
}
-def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [LayoutTrait]> {
+def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttrInterface]> {
let summary = [{
Describes the data distribution to subgroups and work-items for a tensor
specified by the tensor descriptor.
@@ -346,6 +355,13 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [LayoutTrait]> {
return 0;
}
+    /// Returns the total number of subgroups described by sg_layout, i.e. the
+    /// product of its dimensions, or 0 when the attribute has no sg_layout.
+    int64_t getNumSubgroups() {
+      if (std::optional<SmallVector<int64_t>> layout = getSgLayoutAsInt())
+        return computeProduct(*layout);
+      return 0;
+    }
+
LayoutAttr dropSgLayoutAndData() {
// avoid every field of the attribute is nullptr, which may lead to segment fault
if (!getInstData() && !getLaneLayout())
@@ -393,7 +409,7 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [LayoutTrait]> {
}
-def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [LayoutTrait]> {
+def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttrInterface]> {
let summary = [{Describes the data distribution and sharing among subgroups or work-items.}];
let description = [{
@@ -420,7 +436,7 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [LayoutTrait]> {
}];
let parameters = (ins
- "xegpu::LayoutTrait": $parent,
+ "xegpu::DistributeLayoutAttrInterface": $parent,
"DenseI64ArrayAttr": $dims
);
@@ -450,6 +466,13 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [LayoutTrait]> {
return parent.isSgLayout();
}
+    /// Returns the number of subgroups in the effective (sliced) sg_layout,
+    /// i.e. the product of the dimensions reported by getSgLayoutAsInt(), or 0
+    /// when no sg_layout is available.
+    int64_t getNumSubgroups() {
+      std::optional<SmallVector<int64_t>> effectiveLayout = getSgLayoutAsInt();
+      return effectiveLayout ? computeProduct(*effectiveLayout) : 0;
+    }
+
/// Returns the SgLayout of the attribute, computed by applying
/// the slice dimensions to the underlying LayoutAttr.
std::optional<SmallVector<int64_t>> getSgLayoutAsInt() const {
@@ -474,6 +497,20 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [LayoutTrait]> {
return std::nullopt;
}
+    /// Derives a new SliceAttr whose (flattened) parent LayoutAttr has its
+    /// sg_layout and sg_data fields removed. The sliced dims are kept as-is.
+    SliceAttr dropSgLayoutAndData() {
+      SliceAttr attr = flatten();
+      // flatten() collapses nested slices, so the parent is expected to be a
+      // LayoutAttr. Use cast<> so a violated expectation asserts instead of
+      // dereferencing a null dyn_cast result.
+      auto parent = cast<LayoutAttr>(attr.getParent());
+      LayoutAttr newParent = parent.dropSgLayoutAndData();
+      // NOTE(review): LayoutAttr::dropSgLayoutAndData appears able to return a
+      // null attribute when no fields would remain. Propagate that instead of
+      // building a SliceAttr around a null parent.
+      if (!newParent)
+        return SliceAttr();
+      return SliceAttr::get(getContext(), newParent, attr.getDims());
+    }
+
+    /// Derives a new SliceAttr whose (flattened) parent LayoutAttr has its
+    /// inst_data field removed. The sliced dims are kept as-is.
+    SliceAttr dropInstData() {
+      SliceAttr attr = flatten();
+      // flatten() collapses nested slices, so the parent is expected to be a
+      // LayoutAttr. Use cast<> so a violated expectation asserts instead of
+      // dereferencing a null dyn_cast result.
+      auto parent = cast<LayoutAttr>(attr.getParent());
+      LayoutAttr newParent = parent.dropInstData();
+      // NOTE(review): guard against the parent-level drop yielding a null
+      // attribute; do not wrap a null parent in a SliceAttr.
+      if (!newParent)
+        return SliceAttr();
+      return SliceAttr::get(getContext(), newParent, attr.getDims());
+    }
+
/// flatten a nested SliceAttr, e.g., for 2-level nested SliceAttr
/// #xegpu.slice<#xegpu.slice<#xegpu.layout<sg_layout = [4, 8, 12]>, dims = [0]>, dims = [0]>
/// it will coalese two slice operations and return a simplified SliceAttr
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index eb54d6887681d..3ba9eaa4a66da 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -232,6 +232,10 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
return static_cast<unsigned>(MemorySpace::Global);
}
+    /// Returns the layout of the result tensor descriptor type viewed as a
+    /// DistributeLayoutAttrInterface. Yields a null interface when the type
+    /// carries no layout or the layout does not implement the interface.
+    xegpu::DistributeLayoutAttrInterface getLayoutAttr() {
+      auto descLayout = getType().getLayout();
+      return dyn_cast_if_present<xegpu::DistributeLayoutAttrInterface>(descLayout);
+    }
+
}];
}
@@ -1150,7 +1154,7 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>,
let arguments = (ins XeGPU_MemDesc:$mem_desc,
Variadic<Index>: $offsets,
DenseI64ArrayAttr: $const_offsets,
- OptionalAttr<LayoutTrait>:$layout
+ OptionalAttr<DistributeLayoutAttrInterface>:$layout
);
let results = (outs XeGPU_ValueType:$res);
let assemblyFormat = [{
@@ -1175,7 +1179,7 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>,
let builders = [
OpBuilder<(ins "Type":$res, "TypedValue<MemDescType>": $mem_desc,
- "llvm::ArrayRef<OpFoldResult>": $offsets, "LayoutTrait": $layout)>,
+ "llvm::ArrayRef<OpFoldResult>": $offsets, "DistributeLayoutAttrInterface": $layout)>,
];
let extraClassDeclaration = [{
SmallVector<OpFoldResult> getMixedOffsets() {
@@ -1194,7 +1198,7 @@ def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>,
XeGPU_MemDesc:$mem_desc,
Variadic<Index>: $offsets,
DenseI64ArrayAttr: $const_offsets,
- OptionalAttr<LayoutTrait>:$layout
+ OptionalAttr<DistributeLayoutAttrInterface>:$layout
);
let assemblyFormat = [{ $data `,` $mem_desc `` custom<DynamicIndexList>($offsets, $const_offsets)
prop-dict attr-dict `` `:` type(operands)}];
@@ -1213,7 +1217,7 @@ def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>,
}];
let builders = [
OpBuilder<(ins "Value" : $data, "TypedValue<MemDescType>": $mem_desc,
- "llvm::ArrayRef<OpFoldResult>": $offsets, "LayoutTrait": $layout)>,
+ "llvm::ArrayRef<OpFoldResult>": $offsets, "DistributeLayoutAttrInterface": $layout)>,
];
let extraClassDeclaration = [{
SmallVector<OpFoldResult> getMixedOffsets() {
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 8ea8cb1f45972..de118b7faea4d 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -290,8 +290,9 @@ LayoutAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc,
return affine::delinearizeIndex(builder, loc, linearId, dims);
}
-/// Implements LayoutTrait::getOffsets to generate instructions for
-/// computing multi-dimensional offsets when distributed by LayoutAttr.
+/// Implements DistributeLayoutAttrInterface::getOffsets to generate
+/// instructions for computing multi-dimensional offsets when distributed by
+/// LayoutAttr.
FailureOr<SmallVector<SmallVector<Value>>>
LayoutAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
ArrayRef<int64_t> shape) {
@@ -322,7 +323,8 @@ LayoutAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
//===----------------------------------------------------------------------===//
LogicalResult
SliceAttr::verify(llvm::function_ref<InFlightDiagnostic()> emitError,
- xegpu::LayoutTrait parent, DenseI64ArrayAttr dims) {
+ xegpu::DistributeLayoutAttrInterface parent,
+ DenseI64ArrayAttr dims) {
if (!parent || !dims)
return emitError() << "expected parent layout and dims attribute";
@@ -340,7 +342,7 @@ SliceAttr::verify(llvm::function_ref<InFlightDiagnostic()> emitError,
}
SliceAttr SliceAttr::flatten() const {
- xegpu::LayoutTrait parent = getParent();
+ xegpu::DistributeLayoutAttrInterface parent = getParent();
SmallVector<DenseI64ArrayAttr> slicedDims({getDims()});
while (auto sliceAttr = dyn_cast<xegpu::SliceAttr>(parent)) {
@@ -375,8 +377,9 @@ SliceAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc,
return parent.delinearizeSubgroupId(builder, loc, linearId);
}
-/// Implements LayoutTrait::getOffsets to generate instructions for
-/// computing multi-dimensional offsets when distributed by SliceAttr.
+/// Implements DistributeLayoutAttrInterface::getOffsets to generate
+/// instructions for computing multi-dimensional offsets when distributed by
+/// SliceAttr.
FailureOr<SmallVector<SmallVector<Value>>>
SliceAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
ArrayRef<int64_t> shape) {
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 906c71d8b8dad..0e22af900daf1 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -984,7 +984,7 @@ void ConvertLayoutOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
void LoadMatrixOp::build(OpBuilder &builder, OperationState &state, Type res,
TypedValue<MemDescType> memDesc,
llvm::ArrayRef<OpFoldResult> offsets,
- LayoutTrait layout) {
+ DistributeLayoutAttrInterface layout) {
llvm::SmallVector<Value> dynamicOffsets;
llvm::SmallVector<int64_t> staticOffsets;
dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
@@ -1014,7 +1014,7 @@ LogicalResult LoadMatrixOp::verify() {
void StoreMatrixOp::build(OpBuilder &builder, OperationState &state, Value data,
TypedValue<MemDescType> memDesc,
llvm::ArrayRef<OpFoldResult> offsets,
- LayoutTrait layout) {
+ DistributeLayoutAttrInterface layout) {
llvm::SmallVector<Value> dynamicOffsets;
llvm::SmallVector<int64_t> staticOffsets;
dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 8f1208e77ca5d..ca1209e776d0e 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -55,17 +55,16 @@ static bool isSgIdRangeSpecified(Operation *op, int64_t &startOfRange,
}
static std::pair<SmallVector<int64_t>, int>
-getSgShapeAndCount(ArrayRef<int64_t> shape, xegpu::LayoutAttr layout) {
+getSgShapeAndCount(ArrayRef<int64_t> shape,
+ xegpu::DistributeLayoutAttrInterface layout) {
int count = 1;
SmallVector<int64_t> sgShape(shape);
-
if (layout && layout.isWgLayout()) {
- DenseI32ArrayAttr sgLayoutAttr = layout.getSgLayout();
- auto sgLayout = llvm::to_vector_of<int64_t>(sgLayoutAttr.asArrayRef());
- if (DenseI32ArrayAttr sgDataAttr = layout.getSgData())
- sgShape = llvm::to_vector_of<int64_t>(sgDataAttr.asArrayRef());
- else
- sgShape = computeShapeRatio(shape, sgLayout).value_or(sgShape);
+ SmallVector<int64_t> sgLayout = layout.getSgLayoutAsInt().value();
+ if (auto maybeSgData = layout.getSgDataAsInt())
+ sgShape = *maybeSgData;
+ else if (auto maybeDerivedSgData = computeShapeRatio(shape, sgLayout))
+ sgShape = *maybeDerivedSgData;
SmallVector<int64_t> distUnit = computeElementwiseMul(sgLayout, sgShape);
// Clamp distUnit to the original shape to handle cases where data is
// shared among subgroups, which may cause distUnit to exceed the original
@@ -77,6 +76,72 @@ getSgShapeAndCount(ArrayRef<int64_t> shape, xegpu::LayoutAttr layout) {
return std::make_pair(sgShape, count);
}
+// A utility helper that generates elementwise addition ops for index
+// computations. lhs and rhs are lists of OpFoldResults. When their ranks
+// differ, they are aligned at the trailing (innermost) dimension. The extra
+// leading entries of the longer list are passed through unchanged, and only
+// overlapping entries are added (via index.add, folded when possible).
+static SmallVector<OpFoldResult> add(ConversionPatternRewriter &rewriter,
+                                     Location loc, ArrayRef<OpFoldResult> lhs,
+                                     ArrayRef<OpFoldResult> rhs) {
+  SmallVector<OpFoldResult> reversedResult;
+  auto l = lhs.rbegin(), lEnd = lhs.rend();
+  auto r = rhs.rbegin(), rEnd = rhs.rend();
+  // Walk both lists from the innermost dimension outwards. Each iterator is
+  // advanced only while it is still in range. Stepping a reverse iterator past
+  // rend(), as an unconditional `++l, ++r` does once the shorter list is
+  // exhausted, is undefined behavior.
+  while (l != lEnd || r != rEnd) {
+    if (l == lEnd) {
+      reversedResult.push_back(*r++);
+    } else if (r == rEnd) {
+      reversedResult.push_back(*l++);
+    } else {
+      Value lval = getValueOrCreateConstantIndexOp(rewriter, loc, *l++);
+      Value rval = getValueOrCreateConstantIndexOp(rewriter, loc, *r++);
+      reversedResult.push_back(
+          rewriter.createOrFold<index::AddOp>(loc, lval, rval));
+    }
+  }
+  return llvm::to_vector(llvm::reverse(reversedResult));
+}
+
+// Callback type used by distributeOp to materialize the per-subgroup ops.
+// The first argument holds the original (workgroup-level) offsets of the op.
+// The second holds one list of subgroup-relative offsets per op to create.
+using CreatorFuncType = llvm::function_ref<void(
+    ArrayRef<OpFoldResult>, SmallVector<SmallVector<Value>> &)>;
+
+/// Shared workgroup-to-subgroup distribution driver for ops that carry offsets
+/// and a layout (create_nd_tdesc, load_matrix, store_matrix).
+///
+/// It first computes the subgroup id, rebased to the start of the sg_id_range
+/// when one is specified. It then asks the layout for the per-subgroup offsets
+/// into a region of shape `wgShape`. Finally it passes the op's original
+/// offsets and those per-subgroup offsets to `callback`, which creates the
+/// subgroup-level ops and replaces `op`.
+///
+/// Returns failure when the op has no workgroup-level layout, when sg_layout
+/// does not match the sg_id_range, or when the offsets cannot be computed.
+/// `adaptor` is accepted for symmetry with the calling patterns and is not
+/// read here.
+template <typename OpType,
+          typename = std::enable_if_t<llvm::is_one_of<
+              OpType, xegpu::CreateNdDescOp, xegpu::LoadMatrixOp,
+              xegpu::StoreMatrixOp>::value>>
+static LogicalResult
+distributeOp(ConversionPatternRewriter &rewriter,
+             typename OpConversionPattern<OpType>::OneToNOpAdaptor adaptor,
+             OpType op, ArrayRef<int64_t> wgShape, CreatorFuncType callback) {
+  Location loc = op.getLoc();
+  auto layout = op.getLayoutAttr();
+  if (!layout || !layout.isWgLayout())
+    return rewriter.notifyMatchFailure(
+        op, "expected a workgroup-level layout with sg_layout");
+
+  // Validate the optional sg_id_range before creating any IR, so that a match
+  // failure leaves the IR untouched.
+  int64_t startOfRange = -1, endOfRange = -1;
+  bool sgIdRangeSpecified = isSgIdRangeSpecified(op, startOfRange, endOfRange);
+  if (sgIdRangeSpecified &&
+      layout.getNumSubgroups() != endOfRange - startOfRange)
+    return rewriter.notifyMatchFailure(
+        op, "sg_layout size must match the sg_id_range");
+
+  Value sgId =
+      gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
+
+  // Rebase the subgroup id so the first subgroup of the range maps to 0.
+  if (sgIdRangeSpecified) {
+    Value startOfRangeVal =
+        arith::ConstantIndexOp::create(rewriter, loc, startOfRange);
+    sgId = index::SubOp::create(rewriter, loc, sgId, startOfRangeVal);
+  }
+
+  auto maybeDescOffsets = layout.getOffsets(rewriter, loc, sgId, wgShape);
+  if (failed(maybeDescOffsets))
+    return rewriter.notifyMatchFailure(
+        op, "failed to compute subgroup offsets from the layout");
+
+  SmallVector<OpFoldResult> wgOffsets = op.getMixedOffsets();
+  callback(wgOffsets, *maybeDescOffsets);
+  return success();
+}
+
/// This pattern transforms the CreateNdDescOp to create a subgroup descriptor
/// from a workgroup descriptor. It replaces the offsets and sizes with
/// appropriate values for the subgroup.
@@ -137,71 +202,35 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
Location loc = op.getLoc();
MLIRContext *ctx = op.getContext();
xegpu::TensorDescType tdescTy = op.getType();
- auto layout = dyn_cast<xegpu::LayoutAttr>(tdescTy.getLayout());
- if (!layout)
- return failure();
- Type elemTy = tdescTy.getElementType();
ArrayRef<int64_t> wgShape = tdescTy.getShape();
- // sgLayout must be present for workgroup-level distribution.
- SmallVector<int64_t> sgLayout;
- if (auto sgLayoutAttr = layout.getSgLayout())
- sgLayout = llvm::to_vector_of<int64_t>(sgLayoutAttr.asArrayRef());
- else
- return rewriter.notifyMatchFailure(
- op, "sgLayout attribute is required in layout");
-
- // Get the subgroup ID
- Value linearSgId =
- gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
-
- int64_t startOfRange = -1, endOfRange = -1;
- bool sgIdRangeSpecified =
- isSgIdRangeSpecified(op, startOfRange, endOfRange);
-
- if (sgIdRangeSpecified) {
- int64_t sgCount = endOfRange - startOfRange;
- if (computeProduct(sgLayout) != sgCount)
- return rewriter.notifyMatchFailure(
- op, "sg_layout size must match the sg_id_range");
- // Subtract startOfRange from the original subgroup id to get
- // the adjusted sg id
- Value startOfRangeVal =
- arith::ConstantIndexOp::create(rewriter, loc, startOfRange);
- linearSgId =
- rewriter.createOrFold<index::SubOp>(loc, linearSgId, startOfRangeVal);
- }
-
- auto maybeTdescOffsets =
- layout.getOffsets(rewriter, loc, linearSgId, wgShape);
- if (failed(maybeTdescOffsets))
- return failure();
-
- SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
- xegpu::TensorDescType newTdescTy =
- xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(),
- layout.dropSgLayoutAndData());
+ Type elemTy = tdescTy.getElementType();
- SmallVector<Value> newCreateNdOps;
- SmallVector<OpFoldResult> origOffsets = op.getMixedOffsets();
-
- for (auto tdescOffsets : *maybeTdescOffsets) {
- SmallVector<OpFoldResult> sgOffsets;
- size_t rank = tdescOffsets.size();
- for (size_t i = 0; i < rank; i++) {
- size_t idx = origOffsets.size() - rank + i;
- Value add = rewriter.createOrFold<index::AddOp>(
- loc, tdescOffsets[i],
- getValueOrCreateConstantIndexOp(rewriter, loc, origOffsets[idx]));
- sgOffsets.push_back(add);
+ // the call back function for creating new CreateNdOps,
+ // the baseOffsets is the origial offsets of the op, and
+ // descOffsets is the relative offsets to the mem_desc accessed
+ // by each subgroup op.
+ auto callback = [&](ArrayRef<OpFoldResult> baseOffsets,
+ SmallVector<SmallVector<Value>> descOffsets) {
+ xegpu::DistributeLayoutAttrInterface layout = op.getLayoutAttr();
+ SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
+ auto newTdescTy = xegpu::TensorDescType::get(
+ ctx, sgShape, elemTy, tdescTy.getEncoding(),
+ layout.dropSgLayoutAndData());
+
+ SmallVector<Value> newOps;
+ for (auto offsets : descOffsets) {
+ SmallVector<OpFoldResult> sgOffsets =
+ add(rewriter, loc, baseOffsets, getAsOpFoldResult(offsets));
+ auto newOp = xegpu::CreateNdDescOp::create(
+ rewriter, loc, newTdescTy, op.getSource(), sgOffsets,
+ op.getMixedSizes(), op.getMixedStrides());
+
+ newOps.push_back(newOp);
}
+ rewriter.replaceOpWithMultiple(op, {newOps});
+ };
- auto newOp = xegpu::CreateNdDescOp::create(
- rewriter, loc, newTdescTy, op.getSource(), sgOffsets,
- op.getMixedSizes(), op.getMixedStrides());
- newCreateNdOps.push_back(newOp);
- }
- rewriter.replaceOpWithMultiple(op, {newCreateNdOps});
- return success();
+ return distributeOp(rewriter, adaptor, op, wgShape, callback);
}
};
@@ -723,8 +752,8 @@...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/154403
More information about the Mlir-commits
mailing list