[Mlir-commits] [mlir] [mlir] [XeGPU] Add XeGPU workgroup to subgroup pass (PR #139477)

Chao Chen llvmlistbot at llvm.org
Wed May 14 14:16:07 PDT 2025


================
@@ -0,0 +1,387 @@
+//===- XeGPUWgToSg.cpp - XeGPU WorkGroup to Subgroup Pass -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
+
+#include "mlir/Dialect/Index/IR/IndexDialect.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/Support/Debug.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/Index/IR/IndexOps.h"
+#include <numeric>
+
+namespace mlir {
+namespace xegpu {
+#define GEN_PASS_DEF_XEGPUWGTOSG
+#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
+} // namespace xegpu
+} // namespace mlir
+
+#define DEBUG_TYPE "xegpu-wg-to-sg"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
+
+using namespace mlir;
+
+namespace {
+
+// clang-format off
+/// This pattern transforms the CreateNdDescOp to create a subgroup descriptor
+/// from a workgroup descriptor. It replaces the offsets and sizes with
+/// appropriate values for the subgroup.
+/// It uses round-robin assignment to distribute the work to the subgroups.
+/// The following create_nd_tdesc operation:
+///    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x24xf32>
+///       -> !xegpu.tensor_desc<24x24xf32, #xegpu.layout<sg_layout = [4, 4],
+///           sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+/// is converted to 9 subgroup-level operations based on the sg_layout and sg_data:
+///    %tdesc = xegpu.create_nd_tdesc %src[off1, off2] : memref<24x24xf32> ->
+///           !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
+///
+/// The sg_layout and sg_data attributes are dropped after the pass as they are no longer needed.
+///
+/// 24x24 matrix distribution example:
+/// sg_layout = [4, 4], sg_data = [2, 2]
+/// Each 8x8 matrix within the 24x24 matrix is called a distribution unit.
+/// dist_unit_shape = [8, 8], i.e. dist_unit_shape[i] = sg_layout[i] * sg_data[i]
+///
+/// +------------------------+
+/// | 8x8 | 8x8 | 8x8 |      <- 3 tiles across
+/// |-----+-----+-----|
+/// | 8x8 | 8x8 | 8x8 |      <- 3 tiles down
+/// |-----+-----+-----|
+/// | 8x8 | 8x8 | 8x8 |
+/// +------------------------+
+///
+/// Each 8x8 tile is further subdivided among subgroups:
+/// +------------------------+
+/// | 2x2 2x2 2x2 2x2 |  <- 4 subgroups across (each handles 2 columns)
+/// | 2x2 2x2 2x2 2x2 |  <- 4 subgroups down (each handles 2 rows)
+/// | 2x2 2x2 2x2 2x2 |
+/// | 2x2 2x2 2x2 2x2 |
+/// +------------------------+
+///
+/// Since the 24x24 matrix is divided into 8x8 distribution units, there are 9
+/// distribution units (3x3) in total. Hence the 9 subgroup-level operations.
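+///
+/// For example (illustrative numbers only): the subgroup with delinearized
+/// id (1, 2) computes a local offset of sg_id * sg_data = (2, 4); for the
+/// distribution unit whose base address is (8, 0) it therefore creates a
+/// 2x2 descriptor at global offset (0, 0) + (8, 0) + (2, 4) = (10, 4).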
+// clang-format on
+struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
+  using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
+
+  // Helper to extract mixed offsets into a Value array
+  SmallVector<Value> extractOffsets(ConversionPatternRewriter &rewriter,
+                                    xegpu::CreateNdDescOp op) const {
+    llvm::SmallVector<Value> offsets;
+    auto staticOffsets = op.getStaticOffsets();
+    auto dynamicOffsets = op.getOffsets();
+
+    for (size_t i = 0, j = 0; i != staticOffsets.size(); i++) {
+      if (ShapedType::isDynamic(staticOffsets[i]))
+        offsets.push_back(dynamicOffsets[j++]);
+      else
+        offsets.push_back(rewriter.create<arith::ConstantIndexOp>(
+            op.getLoc(), staticOffsets[i]));
+    }
+    return offsets;
+  }
+
+  // Convert a linear subgroup ID to 2D coordinates (row-major:
+  // row = sgID / sgDimY, col = sgID % sgDimY).
+  // TODO: Delinearize for nD
+  SmallVector<Value> delinearizeSubgroupId(ConversionPatternRewriter &rewriter,
+                                           Location loc, Value sgID,
+                                           Value sgDimX, Value sgDimY) const {
+    return {rewriter.create<index::DivUOp>(loc, sgID, sgDimY),
+            rewriter.create<index::RemUOp>(loc, sgID, sgDimY)};
+  }
+
+  // Create a constant index value
+  Value createConstantIndex(ConversionPatternRewriter &rewriter, Location loc,
+                            int64_t value) const {
+    return rewriter.create<arith::ConstantIndexOp>(loc, value);
+  }
+
+  // Calculate the global offsets for a subgroup: the original offsets plus
+  // the distribution unit base address plus the subgroup-local offset.
+  SmallVector<OpFoldResult>
+  calculateGlobalOffsets(ConversionPatternRewriter &rewriter, Location loc,
+                         const SmallVector<Value> &originalOffsets,
+                         const SmallVector<Value> &localOffset,
+                         const SmallVector<int64_t> &distUnitBaseAddr) const {
+
+    Value constOffsetX =
+        createConstantIndex(rewriter, loc, distUnitBaseAddr[0]);
+    Value constOffsetY =
+        createConstantIndex(rewriter, loc, distUnitBaseAddr[1]);
+
+    Value offsetX =
+        rewriter.createOrFold<index::AddOp>(loc, localOffset[0], constOffsetX);
+    Value offsetY =
+        rewriter.createOrFold<index::AddOp>(loc, localOffset[1], constOffsetY);
+
+    size_t lastDimIndex = originalOffsets.size() - 1;
+    size_t secondLastDimIndex = lastDimIndex - 1;
+
+    Value globalOffsetX = rewriter.createOrFold<index::AddOp>(
+        loc, originalOffsets[secondLastDimIndex], offsetX);
+    Value globalOffsetY = rewriter.createOrFold<index::AddOp>(
+        loc, originalOffsets[lastDimIndex], offsetY);
+
+    SmallVector<OpFoldResult> globalOffsets(originalOffsets.begin(),
+                                            originalOffsets.end());
+    globalOffsets[secondLastDimIndex] = globalOffsetX;
+    globalOffsets[lastDimIndex] = globalOffsetY;
+
+    return globalOffsets;
+  }
+
+  LogicalResult
+  matchAndRewrite(xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    MLIRContext *ctx = op.getContext();
+    xegpu::TensorDescType tdescTy = op.getType();
+    auto layout = dyn_cast<xegpu::LayoutAttr>(tdescTy.getLayout());
+    if (!layout)
+      return failure();
+    Type elemTy = tdescTy.getElementType();
+    ArrayRef<int64_t> wgShape = tdescTy.getShape();
+    SmallVector<int64_t> sgShape =
+        llvm::to_vector_of<int64_t>(layout.getSgData().asArrayRef());
+    SmallVector<int64_t> sgLayout =
+        llvm::to_vector_of<int64_t>(layout.getSgLayout().asArrayRef());
+
+    // TODO: Handle order attribute
+    // Get the subgroup ID
+    auto linearSgId = rewriter.create<gpu::SubgroupIdOp>(loc, nullptr);
+
+    // Create constants for layout dimensions
+    SmallVector<Value> sgLayoutDim(sgLayout.size());
+    SmallVector<Value> sgDataDim(sgShape.size());
+
+    for (size_t i = 0; i < sgLayout.size(); i++) {
+      sgLayoutDim[i] = createConstantIndex(rewriter, loc, sgLayout[i]);
+      sgDataDim[i] = createConstantIndex(rewriter, loc, sgShape[i]);
+    }
+
+    // Delinearize the 1D subgroup ID into 2D
+    SmallVector<Value> sgIds = delinearizeSubgroupId(
+        rewriter, loc, linearSgId, sgLayoutDim[0], sgLayoutDim[1]);
+
+    // Calculate the distribution unit shape and the subgroup-local offsets
+    SmallVector<int64_t> distUnitShape(sgLayout.size());
+    SmallVector<Value> localOffset(sgLayout.size());
+    for (size_t i = 0; i < sgLayout.size(); i++) {
+      distUnitShape[i] = sgLayout[i] * sgShape[i];
+      localOffset[i] =
+          rewriter.createOrFold<index::MulOp>(loc, sgIds[i], sgDataDim[i]);
+    }
+
+    SmallVector<Value> originalOffsets = extractOffsets(rewriter, op);
+
+    xegpu::TensorDescType newTdescTy =
+        xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(),
+                                   layout.dropSgLayoutAndData());
+    SmallVector<Value> newCreateNdOps;
+    for (SmallVector<int64_t> distUnitBaseAddr :
+         StaticTileOffsetRange(wgShape, distUnitShape)) {
+      SmallVector<OpFoldResult> globalOffsets = calculateGlobalOffsets(
+          rewriter, loc, originalOffsets, localOffset, distUnitBaseAddr);
+
+      auto newCreateNdOp = rewriter.create<xegpu::CreateNdDescOp>(
+          loc, newTdescTy, op.getSource(), globalOffsets, op.getMixedSizes(),
+          op.getMixedStrides());
+      newCreateNdOps.push_back(newCreateNdOp);
+    }
+
+    rewriter.replaceOpWithMultiple(op, {newCreateNdOps});
+    return success();
+  }
+};
+
+/// This pattern transforms the LoadNdOp to load subgroup data.
+struct WgToSgLoadNdOp : public OpConversionPattern<xegpu::LoadNdOp> {
+  using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(xegpu::LoadNdOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    SmallVector<Value> newLoadOps;
+    for (auto src : adaptor.getTensorDesc()) {
+      xegpu::TensorDescType tdescTy =
+          dyn_cast<xegpu::TensorDescType>(src.getType());
+      ArrayRef<int64_t> srcShape = tdescTy.getShape();
+      VectorType newResTy = VectorType::get(srcShape, tdescTy.getElementType());
+      auto newLoadOp = rewriter.create<xegpu::LoadNdOp>(op.getLoc(), newResTy,
+                                                        src, op->getAttrs());
+      newLoadOps.push_back(newLoadOp);
+    }
+    rewriter.replaceOpWithMultiple(op, {newLoadOps});
+    return mlir::success();
+  }
+};
+
+/// This pattern transforms the StoreNdOp to store to a subgroup descriptor.
+/// It creates a StoreNdOp to store the updated values to the new subgroup
+/// src tensor descriptors.
+struct WgToSgStoreNdOp : public OpConversionPattern<xegpu::StoreNdOp> {
+  using OpConversionPattern<xegpu::StoreNdOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(xegpu::StoreNdOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    for (auto [v, t] : llvm::zip(adaptor.getValue(), adaptor.getTensorDesc()))
+      rewriter.create<xegpu::StoreNdOp>(op.getLoc(), v, t, op.getL1HintAttr(),
+                                        op.getL2HintAttr(), op.getL3HintAttr());
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+/// This pattern transforms the UpdateNdOffsetOp to update the offsets of a
+/// subgroup descriptor. It creates an UpdateNdOffsetOp to update the
+/// offsets of the new subgroup src tensor descriptors.
+struct WgToSgUpdateNdOffsetOp
+    : public OpConversionPattern<xegpu::UpdateNdOffsetOp> {
+  using OpConversionPattern<xegpu::UpdateNdOffsetOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(xegpu::UpdateNdOffsetOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    llvm::SmallVector<Value> newUpdateTileOffsetOps;
+    for (auto tDesc : adaptor.getTensorDesc()) {
+      auto newUpdateTileOffsetOp = rewriter.create<xegpu::UpdateNdOffsetOp>(
+          op.getLoc(), tDesc.getType(), tDesc, op.getOffsets(),
+          op.getConstOffsets());
+      newUpdateTileOffsetOps.push_back(newUpdateTileOffsetOp);
+    }
+
+    rewriter.replaceOpWithMultiple(op, {newUpdateTileOffsetOps});
+    return success();
+  }
+};
+
+/// This pattern transforms the DpasOp to work at subgroup level.
+struct WgToSgDpasOp : public OpConversionPattern<xegpu::DpasOp> {
+  using OpConversionPattern<xegpu::DpasOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(xegpu::DpasOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    VectorType resultTy = op.getResult().getType();
+    if (resultTy.getRank() != 2)
+      return failure();
+
+    auto originalLayout =
+        llvm::dyn_cast_or_null<xegpu::LayoutAttr>(op->getAttr("layout"));
+    if (!originalLayout)
+      return failure();
+
+    SmallVector<Value> newDpasOps;
+    size_t i = 0;
+    for (auto aVec : adaptor.getLhs()) {
+      for (auto bVec : adaptor.getRhs()) {
+
+        llvm::SmallVector<Value> operands({aVec, bVec});
+        Value tmpC;
+        if (op.getAcc()) {
+          tmpC = adaptor.getAcc()[i++];
----------------
chencha3 wrote:

I feel the logic around C is not correct for the oneToN cases; could you double-check? (size_t i = 0 may need to be moved inside the first loop.)
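
For illustration, a standalone sketch (not the patch's code) of the difference the two placements make; whether the running index or the per-row reset matches how adaptor.getAcc() is distributed for the 1:N mapping is exactly the open question:

#include <cstdio>

// Illustration of the two index placements only; numA/numB stand in for the
// number of distributed LHS/RHS values and acc[] for adaptor.getAcc().
int main() {
  const int numA = 2, numB = 3;

  // Index declared before both loops (as in the patch): acc is consumed in
  // row-major order, one element per (a, b) pair, numA * numB in total.
  int i = 0;
  for (int a = 0; a < numA; ++a)
    for (int b = 0; b < numB; ++b)
      std::printf("running index: a=%d b=%d -> acc[%d]\n", a, b, i++);

  // Index reset inside the first loop (the suggestion above): only
  // acc[0..numB-1] is ever read, and the same elements are reused for every a.
  for (int a = 0; a < numA; ++a) {
    int j = 0;
    for (int b = 0; b < numB; ++b)
      std::printf("reset per row: a=%d b=%d -> acc[%d]\n", a, b, j++);
  }
  return 0;
}

Either way, how many accumulator values the adaptor actually provides in the one-to-N case determines which indexing is valid.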

https://github.com/llvm/llvm-project/pull/139477

