[Mlir-commits] [mlir] [mlir] [XeGPU] Add XeGPU workgroup to subgroup pass (PR #139477)
Adam Siemieniuk
llvmlistbot at llvm.org
Tue May 20 04:57:54 PDT 2025
================
@@ -0,0 +1,388 @@
+//===- XeGPUWgToSgDistribute.cpp - XeGPU Workgroup to Subgroup Pass -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
+
+#include "mlir/Dialect/Affine/Utils.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/Index/IR/IndexDialect.h"
+#include "mlir/Dialect/Index/IR/IndexOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+namespace xegpu {
+#define GEN_PASS_DEF_XEGPUWGTOSGDISTRIBUTE
+#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
+} // namespace xegpu
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+
+/// This pattern transforms the CreateNdDescOp to create a subgroup descriptor
+/// from a workgroup descriptor. It replaces the offsets and sizes with
+/// appropriate values for the subgroup.
+/// It uses round-robin assignment to distribute the work to the subgroups.
+/// The following create_nd_tdesc operation:
+/// %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x24xf32>
+/// -> !xegpu.tensor_desc<24x24xf32, #xegpu.layout<sg_layout = [4, 4],
+/// sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+/// is converted to 9 subgroup level operations based on the sg_layout &
+/// sg_data:
+/// %tdesc = xegpu.create_nd_tdesc %src[off1, off2] : memref<24x24xf32> ->
+/// !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2],
+/// lane_data = [1, 1]>>
+///
+/// The sg_layout and sg_data attributes are dropped after the pass as they are
+/// no longer needed.
+///
+/// 24x24 matrix distribution example:
+/// sg_layout = [4, 4], sg_data = [2, 2]
+/// Each 8x8 matrix within the 24x24 matrix is called a distribution unit.
+/// dist_unit_shape = [8, 8] --> sg_layout[i] * sg_data[i]
+///
+/// +------------------------+
+/// | 8x8 | 8x8 | 8x8 | <- 3 tiles across
+/// |-----+-----+-----|
+/// | 8x8 | 8x8 | 8x8 | <- 3 tiles down
+/// |-----+-----+-----|
+/// | 8x8 | 8x8 | 8x8 |
+/// +------------------------+
+///
+/// Each 8x8 tile is further subdivided among subgroups:
+/// +------------------------+
+/// | 2x2 2x2 2x2 2x2 | <- 4 subgroups across (each handles 2 columns)
+/// | 2x2 2x2 2x2 2x2 | <- 4 subgroups down (each handles 2 rows)
+/// | 2x2 2x2 2x2 2x2 |
+/// | 2x2 2x2 2x2 2x2 |
+/// +------------------------+
+///
+/// Since the 24x24 matrix is divided into 8x8 distribution units, there will be
+/// 9 distribution units (3x3) in total. Hence the 9 subgroup level operations.
+
+/// The pass currently keeps the entire distribution logic in the
+/// WgToSgCreateNdOp pattern, and all the other ops just follow it.
+/// TODO: Decouple the distribution logic from WgToSgCreateNdOp for all the
+/// ops in the pass.
+struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
+ using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
+
+  /// Calculate the global offsets of one subgroup-level descriptor.
+  ///
+  /// For each of the trailing `localOffset.size()` dimensions the result is
+  ///   originalOffsets[dim] +
+  ///       ((localOffset[i] + distUnitBaseAddr[i]) urem distUnitShape[i])
+  /// built from arith.constant index values plus index.add / index.remu ops
+  /// (the add/remu ops are created with createOrFold, so they may fold away).
+  /// Leading entries of `originalOffsets` that are not covered by
+  /// `localOffset` are returned unchanged.
+  ///
+  /// \param rewriter         Rewriter used to create the new ops.
+  /// \param loc              Location attached to every created op.
+  /// \param originalOffsets  Offsets to start from; must have at least
+  ///                         `localOffset.size()` entries.
+  /// \param localOffset      Per-dimension offset values, one per distributed
+  ///                         dimension.
+  /// \param distUnitBaseAddr Per-dimension constant added to `localOffset`.
+  /// \param distUnitShape    Per-dimension modulus applied to the sum above.
+  /// \returns A copy of `originalOffsets` whose trailing entries are replaced
+  ///          by the computed values.
+  SmallVector<OpFoldResult>
+  calculateGlobalOffsets(ConversionPatternRewriter &rewriter, Location loc,
+                         const SmallVector<OpFoldResult> &originalOffsets,
+                         const SmallVector<Value> &localOffset,
+                         const SmallVector<int64_t> &distUnitBaseAddr,
+                         const SmallVector<int64_t> &distUnitShape) const {
+    assert(localOffset.size() == distUnitBaseAddr.size() &&
+           "localOffset and distUnitBaseAddr must have the same rank");
+    // distUnitShape is indexed with the same induction variable below.
+    assert(distUnitShape.size() >= localOffset.size() &&
+           "distUnitShape must cover every dimension of localOffset");
+    // Guards the unsigned subtraction used to compute dimIdx below.
+    assert(originalOffsets.size() >= localOffset.size() &&
+           "originalOffsets must have at least the rank of localOffset");
+
+    // Materialize an OpFoldResult as a Value: pass SSA values through and
+    // turn integer attributes into index constants.
+    auto getValueFromOpFoldResult = [&](OpFoldResult ofr) -> Value {
+      if (auto val = llvm::dyn_cast_if_present<Value>(ofr))
+        return val;
+      if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr)) {
+        int64_t staticOffset = cast<IntegerAttr>(attr).getInt();
+        return rewriter.create<arith::ConstantIndexOp>(loc, staticOffset);
+      }
+      llvm_unreachable("Unsupported OpFoldResult kind");
+    };
+
+    SmallVector<OpFoldResult> globalOffsets(originalOffsets.begin(),
+                                            originalOffsets.end());
+    const size_t rank = localOffset.size();
+    // Only the innermost `rank` offsets are rewritten.
+    const size_t firstDim = originalOffsets.size() - rank;
+    for (size_t i = 0; i < rank; ++i) {
+      const size_t dimIdx = firstDim + i;
+      Value constOffset =
+          rewriter.create<arith::ConstantIndexOp>(loc, distUnitBaseAddr[i]);
+      Value offset =
+          rewriter.createOrFold<index::AddOp>(loc, localOffset[i], constOffset);
+      Value modValue =
+          rewriter.create<arith::ConstantIndexOp>(loc, distUnitShape[i]);
+      Value offsetMod =
+          rewriter.createOrFold<index::RemUOp>(loc, offset, modValue);
+      Value origOffset = getValueFromOpFoldResult(originalOffsets[dimIdx]);
+      Value globalOffset =
+          rewriter.createOrFold<index::AddOp>(loc, origOffset, offsetMod);
+      globalOffsets[dimIdx] = globalOffset;
+    }
+
+    return globalOffsets;
+  }
+
+ LogicalResult
+ matchAndRewrite(xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+ MLIRContext *ctx = op.getContext();
+ xegpu::TensorDescType tdescTy = op.getType();
+ auto layout = dyn_cast<xegpu::LayoutAttr>(tdescTy.getLayout());
+ if (!layout)
+ return failure();
+ Type elemTy = tdescTy.getElementType();
+ ArrayRef<int64_t> wgShape = tdescTy.getShape();
+ SmallVector<int64_t> sgLayout;
+ if (auto sgLayoutAttr = layout.getSgLayout()) {
----------------
adam-smnk wrote:
Maybe this could be checked earlier with `isWgLayout`?
I think other ops could also benefit from extra matcher checks.
https://github.com/llvm/llvm-project/pull/139477
More information about the Mlir-commits
mailing list