[Mlir-commits] [mlir] [mlir] [XeGPU] Add XeGPU workgroup to subgroup pass (PR #139477)
Frank Schlimbach
llvmlistbot at llvm.org
Wed May 14 01:00:24 PDT 2025
================
@@ -0,0 +1,387 @@
+//===- XeGPUWgToSg.cpp - XeGPU WorkGroup to Subgroup Pass -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
+
+#include "mlir/Dialect/Index/IR/IndexDialect.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/Support/Debug.h"
+#include <mlir/Dialect/GPU/IR/GPUDialect.h>
+#include <mlir/Dialect/Index/IR/IndexOps.h>
+#include <numeric>
+
+namespace mlir {
+namespace xegpu {
+#define GEN_PASS_DEF_XEGPUWGTOSG
+#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
+} // namespace xegpu
+} // namespace mlir
+
+// Debug-output helpers for this pass. DBGS() starts a line on llvm::dbgs()
+// prefixed with "[xegpu-wg-to-sg]: ", and LDBG(X) streams X followed by a
+// newline, wrapped in LLVM_DEBUG so it is tied to this DEBUG_TYPE.
+#define DEBUG_TYPE "xegpu-wg-to-sg"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
+
+using namespace mlir;
+
+namespace {
+
+// clang-format off
+/// This pattern transforms the CreateNdDescOp to create a subgroup descriptor
+/// from a workgroup descriptor. It replaces the offsets and sizes with
+/// appropriate values for the subgroup.
+/// It uses round-robin assignment to distribute the work to the subgroups.
+/// The following create_nd_tdesc operation:
+/// %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x24xf32>
+/// -> !xegpu.tensor_desc<24x24xf32, #xegpu.layout<sg_layout = [4, 4],
+/// sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+/// is converted to 9 subgroup level operations based on the sg_layout & sg_data:
+/// %tdesc = xegpu.create_nd_tdesc %src[off1, off2] : memref<24x24xf32> ->
+/// !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
+///
+/// The sg_layout and sg_data attributes are dropped after the pass as they are no longer needed.
+///
+/// 24x24 matrix distribution example:
+/// sg_layout = [4, 4], sg_data = [2, 2]
+/// Each 8x8 sub-matrix within the 24x24 matrix is called a distribution unit.
+/// dist_unit_shape[i] = sg_layout[i] * sg_data[i], i.e. [4*2, 4*2] = [8, 8]
+///
+/// +-----+-----+-----+
+/// | 8x8 | 8x8 | 8x8 |  <- 3 units across
+/// +-----+-----+-----+
+/// | 8x8 | 8x8 | 8x8 |  <- 3 units down
+/// +-----+-----+-----+
+/// | 8x8 | 8x8 | 8x8 |
+/// +-----+-----+-----+
+///
+/// Each 8x8 tile is further subdivided among subgroups:
+/// +------------------------+
+/// | 2x2 2x2 2x2 2x2 | <- 4 subgroups across (each handles 2 columns)
+/// | 2x2 2x2 2x2 2x2 | <- 4 subgroups down (each handles 2 rows)
+/// | 2x2 2x2 2x2 2x2 |
+/// | 2x2 2x2 2x2 2x2 |
+/// +------------------------+
+///
+/// Since the 24x24 matrix is divided into 8x8 distribution units, there will be 9
+/// distribution units (3x3) in total. Hence the 9 subgroup level operations.
+// clang-format on
+struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
+ using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
+
+ // Materialize every offset of `op` as an SSA index value.
+ // Static offsets become `arith.constant` index ops created at the op's
+ // location; entries marked dynamic (ShapedType::isDynamic) are filled, in
+ // order, from the op's dynamic offset operands.
+ SmallVector<Value> extractOffsets(ConversionPatternRewriter &rewriter,
+                                   xegpu::CreateNdDescOp op) const {
+   auto constOffsets = op.getStaticOffsets();
+   auto dynOperands = op.getOffsets();
+   auto nextDyn = dynOperands.begin();
+
+   SmallVector<Value> result;
+   for (int64_t constOff : constOffsets) {
+     if (!ShapedType::isDynamic(constOff)) {
+       result.push_back(
+           rewriter.create<arith::ConstantIndexOp>(op.getLoc(), constOff));
+       continue;
+     }
+     // Dynamic slot: consume the next dynamic offset operand.
+     result.push_back(*nextDyn);
+     ++nextDyn;
+   }
+   return result;
+ }
+
+ // Split a linear subgroup id into 2-D (row, col) coordinates:
+ //   row = sgID / sgDimY,   col = sgID % sgDimY
+ // `sgDimX` is not consulted. The row is a plain quotient, not reduced
+ // modulo sgDimX, so it is only in range when sgID < sgDimX * sgDimY.
+ // Only the 2-D case is handled.
+ // TODO: Delinearize for nD
+ SmallVector<Value> delinearizeSubgroupId(ConversionPatternRewriter &rewriter,
+                                          Location loc, Value sgID,
+                                          Value sgDimX, Value sgDimY) const {
+   (void)sgDimX; // not needed for the 2-D split; kept for interface symmetry
+   Value sgRow = rewriter.create<index::DivUOp>(loc, sgID, sgDimY);
+   Value sgCol = rewriter.create<index::RemUOp>(loc, sgID, sgDimY);
+   return {sgRow, sgCol};
+ }
+
+ // Build an index-typed `arith.constant` holding `value` at `loc` and return
+ // its result.
+ Value createConstantIndex(ConversionPatternRewriter &rewriter, Location loc,
+                           int64_t value) const {
+   Value cst = rewriter.create<arith::ConstantIndexOp>(loc, value);
+   return cst;
+ }
+
+ // Calculate the source offsets of one subgroup's tile.
+ //
+ // Let R = localOffset.size() (the rank of the subgroup layout). The last R
+ // entries of `originalOffsets` are each shifted by
+ //   localOffset[i] + distUnitBaseAddr[i]
+ // i.e. the subgroup's position inside a distribution unit plus the base
+ // coordinate of that distribution unit. Leading offsets (when the source
+ // has more dimensions than the layout) are returned unchanged.
+ //
+ // Preconditions (asserted): distUnitBaseAddr.size() == R and
+ // R <= originalOffsets.size(). This avoids the index underflow a hard-coded
+ // "second-to-last dimension" would hit for rank-1 offsets.
+ SmallVector<OpFoldResult>
+ calculateGlobalOffsets(ConversionPatternRewriter &rewriter, Location loc,
+                        const SmallVector<Value> &originalOffsets,
+                        const SmallVector<Value> &localOffset,
+                        const SmallVector<int64_t> &distUnitBaseAddr) const {
+   const size_t rank = localOffset.size();
+   assert(distUnitBaseAddr.size() == rank &&
+          "local offset and distribution unit base must have the same rank");
+   assert(rank <= originalOffsets.size() &&
+          "subgroup layout rank exceeds the number of source offsets");
+   // First source dimension that receives a subgroup-specific offset.
+   const size_t firstDistDim = originalOffsets.size() - rank;
+
+   // Emit the IR stage by stage: all constants, then the in-unit adds, then
+   // the final adds onto the original offsets.
+   SmallVector<Value> baseConsts;
+   baseConsts.reserve(rank);
+   for (int64_t base : distUnitBaseAddr)
+     baseConsts.push_back(createConstantIndex(rewriter, loc, base));
+
+   SmallVector<Value> unitOffsets;
+   unitOffsets.reserve(rank);
+   for (size_t i = 0; i < rank; ++i)
+     unitOffsets.push_back(rewriter.createOrFold<index::AddOp>(
+         loc, localOffset[i], baseConsts[i]));
+
+   SmallVector<OpFoldResult> globalOffsets(originalOffsets.begin(),
+                                           originalOffsets.end());
+   for (size_t i = 0; i < rank; ++i)
+     globalOffsets[firstDistDim + i] = rewriter.createOrFold<index::AddOp>(
+         loc, originalOffsets[firstDistDim + i], unitOffsets[i]);
+
+   return globalOffsets;
+ }
+
+ LogicalResult
+ matchAndRewrite(xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+ MLIRContext *ctx = op.getContext();
+ xegpu::TensorDescType tdescTy = op.getType();
+ auto layout = dyn_cast<xegpu::LayoutAttr>(tdescTy.getLayout());
+ if (!layout)
+ return failure();
+ Type elemTy = tdescTy.getElementType();
+ ArrayRef<int64_t> wgShape = tdescTy.getShape();
+ SmallVector<int64_t> sgShape =
+ llvm::to_vector_of<int64_t>(layout.getSgData().asArrayRef());
+ SmallVector<int64_t> sgLayout =
+ llvm::to_vector_of<int64_t>(layout.getSgLayout().asArrayRef());
+
+ // TODO : Handle order attribute
+ // Get the subgroup ID
+ auto linearSgId = rewriter.create<gpu::SubgroupIdOp>(loc, nullptr);
+
+ // Create constants for layout dimensions
+ SmallVector<Value> sgLayoutDim(sgLayout.size());
+ SmallVector<Value> sgDataDim(sgShape.size());
+
+ for (size_t i = 0; i < sgLayout.size(); i++) {
+ sgLayoutDim[i] = createConstantIndex(rewriter, loc, sgLayout[i]);
+ sgDataDim[i] = createConstantIndex(rewriter, loc, sgShape[i]);
+ }
+
+ // Delinearize the 1D subgroup id into 2d
+ SmallVector<Value> sgIds = delinearizeSubgroupId(
+ rewriter, loc, linearSgId, sgLayoutDim[0], sgLayoutDim[1]);
----------------
fschlimb wrote:
is `layout.getSgLayout().size() > 1` guaranteed?
https://github.com/llvm/llvm-project/pull/139477
More information about the Mlir-commits
mailing list