[Mlir-commits] [mlir] [mlir] [XeGPU] Add XeGPU workgroup to subgroup pass (PR #139477)
Chao Chen
llvmlistbot at llvm.org
Mon May 19 14:03:13 PDT 2025
================
@@ -0,0 +1,366 @@
+//===- XeGPUWgToSgDistribute.cpp - XeGPU Workgroup to Subgroup Pass -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
+
+#include "mlir/Dialect/Affine/Utils.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/Index/IR/IndexDialect.h"
+#include "mlir/Dialect/Index/IR/IndexOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+namespace xegpu {
+#define GEN_PASS_DEF_XEGPUWGTOSGDISTRIBUTE
+#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
+} // namespace xegpu
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+
+/// This pattern transforms the CreateNdDescOp to create a subgroup descriptor
+/// from a workgroup descriptor. It replaces the offsets and sizes with
+/// appropriate values for the subgroup.
+/// It uses round-robin assignment to distribute the work to the subgroups.
+/// For example, the following create_nd_tdesc operation:
+/// %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x24xf32>
+/// -> !xegpu.tensor_desc<24x24xf32, #xegpu.layout<sg_layout = [4, 4],
+/// sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+/// is converted to 9 subgroup level operations based on the sg_layout &
+/// sg_data:
+/// %tdesc = xegpu.create_nd_tdesc %src[off1, off2] : memref<24x24xf32> ->
+/// !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2],
+/// lane_data = [1, 1]>>
+///
+/// The sg_layout and sg_data attributes are dropped after the pass as they are
+/// no longer needed.
+///
+/// 24x24 matrix distribution example:
+/// sg_layout = [4, 4], sg_data = [2, 2]
+/// Each 8x8 matrix within the 24x24 matrix is called a distribution unit.
+/// dist_unit_shape = [8, 8] --> sg_layout[i] * sg_data[i]
+///
+/// +------------------------+
+/// | 8x8 | 8x8 | 8x8 | <- 3 tiles across
+/// |-----+-----+-----|
+/// | 8x8 | 8x8 | 8x8 | <- 3 tiles down
+/// |-----+-----+-----|
+/// | 8x8 | 8x8 | 8x8 |
+/// +------------------------+
+///
+/// Each 8x8 tile is further subdivided among subgroups:
+/// +------------------------+
+/// | 2x2 2x2 2x2 2x2 | <- 4 subgroups across (each handles 2 columns)
+/// | 2x2 2x2 2x2 2x2 | <- 4 subgroups down (each handles 2 rows)
+/// | 2x2 2x2 2x2 2x2 |
+/// | 2x2 2x2 2x2 2x2 |
+/// +------------------------+
+///
+/// Since the 24x24 matrix is divided into 8x8 distribution units, there will be
+/// 9 distribution units (3x3) in total. Hence the 9 subgroup level operations.
+
+/// The pass currently keeps the entire distribution logic in the
+/// WgToSgCreateNdOp pattern; all the other ops just follow it.
+/// TODO: Decouple the distribution logic from WgToSgCreateNdOp for all the
+/// ops in the pass.
+struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
+ using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
+
+  /// Computes the global offsets of one subgroup tile inside one
+  /// distribution unit.
+  ///
+  /// Only the trailing `rank` dimensions are adjusted, where
+  /// rank == localOffset.size(). For each i in [0, rank), with
+  /// d = originalOffsets.size() - rank + i:
+  ///   result[d] = originalOffsets[d] + (localOffset[i] + distUnitBaseAddr[i])
+  /// Any leading dimensions of `originalOffsets` are returned unchanged.
+  ///
+  /// \param originalOffsets  offsets of the workgroup-level descriptor; each
+  ///                         entry may be static (an IntegerAttr) or dynamic
+  ///                         (a Value).
+  /// \param localOffset      this subgroup's offset within a distribution
+  ///                         unit, one Value per distributed dimension.
+  /// \param distUnitBaseAddr static base coordinate of the distribution unit
+  ///                         being emitted; same length as `localOffset`.
+  /// \returns the adjusted offsets. Adjusted entries are always Values; the
+  ///          untouched leading entries keep their original form.
+  SmallVector<OpFoldResult>
+  calculateGlobalOffsets(ConversionPatternRewriter &rewriter, Location loc,
+                         ArrayRef<OpFoldResult> originalOffsets,
+                         ArrayRef<Value> localOffset,
+                         ArrayRef<int64_t> distUnitBaseAddr) const {
+    assert(localOffset.size() == distUnitBaseAddr.size() &&
+           "localOffset and distUnitBaseAddr must have the same rank");
+    // The distributed dimensions are the trailing ones. Guard against a
+    // size_t underflow in the leading-dimension computation below.
+    assert(originalOffsets.size() >= localOffset.size() &&
+           "originalOffsets must cover at least the distributed rank");
+
+    // Materialize a (possibly static) offset as an index-typed SSA value.
+    // dyn_cast_if_present tolerates a null OpFoldResult, so a null entry
+    // still reaches the llvm_unreachable below rather than a cast assertion.
+    auto getValueFromOpFoldResult = [&](OpFoldResult ofr) -> Value {
+      if (auto val = llvm::dyn_cast_if_present<Value>(ofr))
+        return val;
+      if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr)) {
+        int64_t staticOffset = cast<IntegerAttr>(attr).getInt();
+        return rewriter.create<arith::ConstantIndexOp>(loc, staticOffset);
+      }
+      llvm_unreachable("Unsupported OpFoldResult kind");
+    };
+
+    SmallVector<OpFoldResult> globalOffsets(originalOffsets.begin(),
+                                            originalOffsets.end());
+    const size_t rank = localOffset.size();
+    const size_t leadingDims = originalOffsets.size() - rank;
+    for (size_t i = 0; i < rank; ++i) {
+      const size_t dimIdx = leadingDims + i;
+      Value constOffset =
+          rewriter.create<arith::ConstantIndexOp>(loc, distUnitBaseAddr[i]);
+      Value offset =
+          rewriter.createOrFold<index::AddOp>(loc, localOffset[i], constOffset);
+      Value origOffset = getValueFromOpFoldResult(originalOffsets[dimIdx]);
+      globalOffsets[dimIdx] =
+          rewriter.createOrFold<index::AddOp>(loc, origOffset, offset);
+    }
+
+    return globalOffsets;
+  }
+
+ LogicalResult
+ matchAndRewrite(xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+ MLIRContext *ctx = op.getContext();
+ xegpu::TensorDescType tdescTy = op.getType();
+ auto layout = dyn_cast<xegpu::LayoutAttr>(tdescTy.getLayout());
+ if (!layout)
+ return failure();
+ Type elemTy = tdescTy.getElementType();
+ ArrayRef<int64_t> wgShape = tdescTy.getShape();
+ SmallVector<int64_t> sgShape =
----------------
chencha3 wrote:
It will crash here if sgData or sgLayout is not available. As discussed, sgLayout is required at the workgroup level, but sgData is optional. If sgData is missing (e.g., for elementwise ops), the tile shape is wgShape / sgLayout, and distUnit is simply wgShape.
https://github.com/llvm/llvm-project/pull/139477
More information about the Mlir-commits
mailing list