[Mlir-commits] [mlir] [mlir][XeGPU] add WgToSg distribution pattern for load_matrix and store_matrix. (PR #154403)
Artem Kroviakov
llvmlistbot at llvm.org
Wed Aug 20 08:00:14 PDT 2025
================
@@ -77,6 +76,72 @@ getSgShapeAndCount(ArrayRef<int64_t> shape, xegpu::LayoutAttr layout) {
return std::make_pair(sgShape, count);
}
+/// Utility helper that generates element-wise addition ops for index
+/// computation.
+///
+/// `lhs` and `rhs` are lists of index values/attributes. When their ranks
+/// differ, the operands are aligned on their trailing (innermost) elements:
+/// the last element of `lhs` is added to the last element of `rhs`, and so
+/// on towards the front. The extra leading elements of the longer operand
+/// are forwarded to the result unchanged.
+///
+/// Returns a vector whose size is the larger of the two input ranks.
+static SmallVector<OpFoldResult> add(ConversionPatternRewriter &rewriter,
+                                     Location loc, ArrayRef<OpFoldResult> lhs,
+                                     ArrayRef<OpFoldResult> rhs) {
+  const size_t lhsRank = lhs.size();
+  const size_t rhsRank = rhs.size();
+
+  // Walk both operands from the back using a distance-from-the-end counter.
+  // Indexing (instead of advancing two reverse iterators in lock-step) avoids
+  // stepping an iterator past rend() once the shorter operand is exhausted,
+  // which previously led to out-of-bounds reads whenever the ranks differed.
+  SmallVector<OpFoldResult> reversedResult;
+  for (size_t fromEnd = 0; fromEnd < lhsRank || fromEnd < rhsRank;
+       ++fromEnd) {
+    if (fromEnd >= lhsRank) {
+      // Only `rhs` still has an element at this position.
+      reversedResult.push_back(rhs[rhsRank - 1 - fromEnd]);
+    } else if (fromEnd >= rhsRank) {
+      // Only `lhs` still has an element at this position.
+      reversedResult.push_back(lhs[lhsRank - 1 - fromEnd]);
+    } else {
+      Value lval = getValueOrCreateConstantIndexOp(
+          rewriter, loc, lhs[lhsRank - 1 - fromEnd]);
+      Value rval = getValueOrCreateConstantIndexOp(
+          rewriter, loc, rhs[rhsRank - 1 - fromEnd]);
+      Value sum = rewriter.createOrFold<index::AddOp>(loc, lval, rval);
+      reversedResult.push_back(sum);
+    }
+  }
+  // Results were collected innermost-first; restore the original order.
+  return llvm::to_vector(llvm::reverse(reversedResult));
+}
+
+// A callback function type used to create new load/store_matrix ops.
+// The callback receives `baseOffsets` as a read-only list and `descOffsets`
+// as a mutable reference to a list of per-op offset vectors; it returns
+// nothing.
+using CreatorFuncType =
+ llvm::function_ref<void(ArrayRef<OpFoldResult> baseOffsets,
+ SmallVector<SmallVector<Value>> &descOffsets)>;
+
+/// Utility helper for distributing logic shared by operations with offsets
+template <typename OpType,
+ typename = std::enable_if_t<llvm::is_one_of<
+ OpType, xegpu::CreateNdDescOp, xegpu::LoadMatrixOp,
+ xegpu::StoreMatrixOp>::value>>
+static LogicalResult
+distributeOp(ConversionPatternRewriter &rewriter,
+ typename OpConversionPattern<OpType>::OneToNOpAdaptor adaptor,
+ OpType op, ArrayRef<int64_t> wgShape, CreatorFuncType callback) {
+ Location loc = op.getLoc();
+ auto layout = op.getLayoutAttr();
+ if (!layout || !layout.isWgLayout())
+ return failure();
+
+ Value sgId = rewriter.create<gpu::SubgroupIdOp>(loc, /*upper_bound=*/nullptr);
+
+ // adjust the linearId if the range specifier is present
+ int64_t startOfRange = -1, endOfRange = -1;
+ bool sgIdRangeSpecified = isSgIdRangeSpecified(op, startOfRange, endOfRange);
+ if (sgIdRangeSpecified) {
----------------
akroviakov wrote:
> I reused this code from upstream
The issue has been known for some time, correct me if I'm wrong @nbpatel
https://github.com/llvm/llvm-project/pull/154403
More information about the Mlir-commits
mailing list