[Mlir-commits] [mlir] [mlir] [XeGPU] Add XeGPU workgroup to subgroup pass (PR #139477)
Nishant Patel
llvmlistbot at llvm.org
Sun May 11 15:08:40 PDT 2025
https://github.com/nbpatel updated https://github.com/llvm/llvm-project/pull/139477
>From 1ed4cb5b381898728f850da43a10826493fce94b Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Sat, 10 May 2025 17:04:39 +0000
Subject: [PATCH 1/3] Add XeGPUWgToSg pass
---
.../mlir/Dialect/XeGPU/Transforms/Passes.td | 31 +-
.../Dialect/XeGPU/Transforms/Transforms.h | 4 +
.../Dialect/XeGPU/Transforms/CMakeLists.txt | 1 +
.../Dialect/XeGPU/Transforms/XeGPUWgToSg.cpp | 374 ++++++++++++++++++
.../test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir | 65 +++
mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir | 81 ++++
6 files changed, 544 insertions(+), 12 deletions(-)
create mode 100644 mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSg.cpp
create mode 100644 mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
create mode 100644 mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
index 3e81f2d0ed786..bdea88cfd7022 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
@@ -6,7 +6,6 @@
//
//===----------------------------------------------------------------------===//
-
#ifndef MLIR_DIALECT_XEGPU_TRANSFORMS_PASSES_TD
#define MLIR_DIALECT_XEGPU_TRANSFORMS_PASSES_TD
@@ -18,9 +17,7 @@ def XeGPUFoldAliasOps : Pass<"xegpu-fold-alias-ops"> {
The pass folds aliasing ops into XeGPU ops that they operate on the original
source references.
}];
- let dependentDialects = [
- "memref::MemRefDialect", "xegpu::XeGPUDialect"
- ];
+ let dependentDialects = ["memref::MemRefDialect", "xegpu::XeGPUDialect"];
}
def XeGPUSubgroupDistribute : Pass<"xegpu-subgroup-distribute"> {
@@ -28,14 +25,24 @@ def XeGPUSubgroupDistribute : Pass<"xegpu-subgroup-distribute"> {
let description = [{
The pass distributes subgroup level (SIMD) XeGPU ops to work items.
}];
- let dependentDialects = [
- "memref::MemRefDialect", "xegpu::XeGPUDialect", "vector::VectorDialect"
- ];
- let options = [
- Option<"printOnly", "print-analysis-only", "bool",
- /*default=*/"false",
- "Print the result of the subgroup map propagation analysis and exit.">
- ];
+ let dependentDialects = ["memref::MemRefDialect", "xegpu::XeGPUDialect",
+ "vector::VectorDialect"];
+ let options = [Option<
+ "printOnly", "print-analysis-only", "bool",
+ /*default=*/"false",
+ "Print the result of the subgroup map propagation analysis and exit.">];
+}
+
+// Declares the `xegpu-wg-to-sg` pass. The second `Pass` template argument
+// names the operation the pass is scheduled on (`gpu.module`).
+def XeGPUWgToSg : Pass<"xegpu-wg-to-sg", "::mlir::gpu::GPUModuleOp"> {
+ let summary = "Transform WorkGroup level XeGPU code to SubGroup level";
+ let description = [{
+ This transform pass distributes the workgroup level computation to
+ multiple subgroups based on the sg_layout and sg_data attributes.
+ }];
+
+ // Dialects whose operations the pass may create while rewriting.
+ let dependentDialects = ["memref::MemRefDialect", "xegpu::XeGPUDialect",
+ "vector::VectorDialect", "arith::ArithDialect",
+ "gpu::GPUDialect", "index::IndexDialect"];
}
#endif // MLIR_DIALECT_XEGPU_TRANSFORMS_PASSES_TD
diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h b/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
index 3e94021c7a1ea..388ba32e1eebb 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
@@ -9,6 +9,8 @@
#ifndef MLIR_DIALECT_XEGPU_TRANSFORMS_TRANSFORMS_H
#define MLIR_DIALECT_XEGPU_TRANSFORMS_TRANSFORMS_H
+#include "mlir/Transforms/DialectConversion.h"
+
namespace mlir {
class RewritePatternSet;
@@ -18,6 +20,8 @@ namespace xegpu {
void populateXeGPUFoldAliasOpsPatterns(RewritePatternSet &patterns);
/// Appends patterns for XeGPU SIMT distribution into `patterns`.
void populateXeGPUSubgroupDistributePatterns(RewritePatternSet &patterns);
+/// Appends patterns for XeGPU workgroup-to-subgroup distribution.
+void populateXeGPUWgToSgPatterns(RewritePatternSet &patterns);
} // namespace xegpu
} // namespace mlir
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
index 901e02d3c9cf5..b258921cc87fd 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
add_mlir_dialect_library(MLIRXeGPUTransforms
XeGPUFoldAliasOps.cpp
XeGPUSubgroupDistribute.cpp
+ XeGPUWgToSg.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/XeGPU
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSg.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSg.cpp
new file mode 100644
index 0000000000000..7969d37d67f04
--- /dev/null
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSg.cpp
@@ -0,0 +1,374 @@
+//===- XeGPUWgToSg.cpp - XeGPU WorkGroup to Subgroup Pass -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
+
+#include "mlir/Dialect/Index/IR/IndexDialect.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/Support/Debug.h"
+#include <mlir/Dialect/GPU/IR/GPUDialect.h>
+#include <mlir/Dialect/Index/IR/IndexOps.h>
+#include <numeric>
+
+namespace mlir {
+namespace xegpu {
+#define GEN_PASS_DEF_XEGPUWGTOSG
+#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
+} // namespace xegpu
+} // namespace mlir
+
+#define DEBUG_TYPE "xegpu-wg-to-sg"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
+
+using namespace mlir;
+
+namespace {
+
+// clang-format off
+/// This pattern transform the CreateNdDescOp to create a subgroup descriptor
+/// from a workgroup descriptor. It replaces the offsets and sizes with
+/// appropriate values for the subgroup.
+/// It uses round-robin distribution to create the subgroup descriptor.
+
+/// Following create_nd_desc operation:,
+/// %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x24xf32>
+/// -> !xegpu.tensor_desc<24x24xf32, #xegpu.layout<sg_layout = [4, 4],
+/// sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+/// is converted to 9 subgroup level operations based on the sg_layout & sg_data:
+/// %tdesc = xegpu.create_nd_tdesc %src[off1, off2] : memref<24x24xf32> ->
+/// !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
+///
+/// The sg_layout and sg_data are dropped from the layout attribute as they are no longer needed.
+///
+/// 24x24 matrix distribution example:
+/// sg_layout = [4, 4], sg_data = [2, 2]
+/// Each 8x8 matrix within the 24x24 matrix is called a distribution unit.
+/// dist_unit_shape = [8, 8] --> sg_layout[i] * sg_data[i]
+///
+/// +------------------------+
+/// | 8x8 | 8x8 | 8x8 | <- 3 tiles across
+/// |-----+-----+-----|
+/// | 8x8 | 8x8 | 8x8 | <- 3 tiles down
+/// |-----+-----+-----|
+/// | 8x8 | 8x8 | 8x8 |
+/// +------------------------+
+///
+/// Each 8x8 tile is further subdivided among subgroups:
+/// +------------------------+
+/// | 2x2 2x2 2x2 2x2 | <- 4 subgroups across (each handles 2 columns)
+/// | 2x2 2x2 2x2 2x2 | <- 4 subgroups down (each handles 2 rows)
+/// | 2x2 2x2 2x2 2x2 |
+/// | 2x2 2x2 2x2 2x2 |
+/// +------------------------+
+///
+/// Since the 24x24 matrix is divided into 8x8 distribution units, there will be 9
+/// distribution units (3x3) in total. Hence the 9 subgroup level operations.
+/// Each 8x8 matrix within the 24x24 matrix is called a distribution unit.
+// clang-format on
+struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
+  using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
+
+  /// Materializes the mixed (static + dynamic) offsets of `op` as a flat list
+  /// of index Values. Static entries become `arith.constant` index ops;
+  /// dynamic entries are taken, in order, from the op's dynamic offsets.
+  SmallVector<Value> extractOffsets(ConversionPatternRewriter &rewriter,
+                                    xegpu::CreateNdDescOp op) const {
+    llvm::SmallVector<Value> offsets;
+    auto staticOffsets = op.getStaticOffsets();
+    auto dynamicOffsets = op.getOffsets();
+
+    // `j` walks the dynamic operands; it only advances on dynamic entries.
+    for (size_t i = 0, j = 0; i != staticOffsets.size(); i++) {
+      if (ShapedType::isDynamic(staticOffsets[i])) {
+        offsets.push_back(dynamicOffsets[j++]);
+      } else {
+        offsets.push_back(rewriter.create<arith::ConstantIndexOp>(
+            op.getLoc(), staticOffsets[i]));
+      }
+    }
+    return offsets;
+  }
+
+  /// Converts the linear subgroup id into 2D coordinates
+  /// {sgID / sgDimY, sgID % sgDimY}. `sgDimX` is currently unused.
+  // TODO: Delinearize for nD
+  SmallVector<Value> delinearizeSubgroupId(ConversionPatternRewriter &rewriter,
+                                           Location loc, Value sgID,
+                                           Value sgDimX, Value sgDimY) const {
+    return {rewriter.create<index::DivUOp>(loc, sgID, sgDimY),
+            rewriter.create<index::RemUOp>(loc, sgID, sgDimY)};
+  }
+
+  /// Creates an `arith.constant` of index type holding `value`.
+  Value createConstantIndex(ConversionPatternRewriter &rewriter, Location loc,
+                            int64_t value) const {
+    return rewriter.create<arith::ConstantIndexOp>(loc, value);
+  }
+
+  /// Computes the offsets of one subgroup descriptor.
+  ///
+  /// The two innermost entries of `originalOffsets` are replaced by
+  ///   originalOffset + localOffset + distUnitBaseAddr
+  /// for the respective dimension; all outer entries are forwarded unchanged.
+  /// Expects `localOffset` and `distUnitBaseAddr` to hold two entries and
+  /// `originalOffsets` to hold at least two.
+  SmallVector<OpFoldResult>
+  calculateGlobalOffsets(ConversionPatternRewriter &rewriter, Location loc,
+                         const SmallVector<Value> &originalOffsets,
+                         const SmallVector<Value> &localOffset,
+                         const SmallVector<int64_t> &distUnitBaseAddr) const {
+
+    Value constOffsetX =
+        createConstantIndex(rewriter, loc, distUnitBaseAddr[0]);
+    Value constOffsetY =
+        createConstantIndex(rewriter, loc, distUnitBaseAddr[1]);
+
+    // Offset of this subgroup's tile relative to the workgroup tile origin.
+    Value offsetX =
+        rewriter.createOrFold<index::AddOp>(loc, localOffset[0], constOffsetX);
+    Value offsetY =
+        rewriter.createOrFold<index::AddOp>(loc, localOffset[1], constOffsetY);
+
+    // Only the two innermost dimensions are distributed.
+    size_t lastDimIndex = originalOffsets.size() - 1;
+    size_t secondLastDimIndex = lastDimIndex - 1;
+
+    Value globalOffsetX = rewriter.createOrFold<index::AddOp>(
+        loc, originalOffsets[secondLastDimIndex], offsetX);
+    Value globalOffsetY = rewriter.createOrFold<index::AddOp>(
+        loc, originalOffsets[lastDimIndex], offsetY);
+
+    // Start from the original offsets and overwrite the distributed dims.
+    SmallVector<OpFoldResult> globalOffsets(originalOffsets.begin(),
+                                            originalOffsets.end());
+    globalOffsets[secondLastDimIndex] = globalOffsetX;
+    globalOffsets[lastDimIndex] = globalOffsetY;
+
+    return globalOffsets;
+  }
+
+  /// Replaces the workgroup-level descriptor with one subgroup-level
+  /// descriptor per distribution unit of the workgroup shape.
+  LogicalResult
+  matchAndRewrite(xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    MLIRContext *ctx = op.getContext();
+    xegpu::TensorDescType tdescTy = op.getType();
+
+    // The layout attribute may be absent; bail out before touching the IR.
+    auto layout = dyn_cast_or_null<xegpu::LayoutAttr>(tdescTy.getLayout());
+    if (!layout || !layout.getSgLayout() || !layout.getSgData())
+      return rewriter.notifyMatchFailure(
+          op, "expected a layout with sg_layout and sg_data");
+
+    Type elemTy = tdescTy.getElementType();
+    ArrayRef<int64_t> wgShape = tdescTy.getShape();
+    // `to_vector_of` returns a temporary vector, so the result has to be
+    // owned here; an ArrayRef bound to it would dangle after the statement.
+    SmallVector<int64_t> sgShape =
+        llvm::to_vector_of<int64_t>(layout.getSgData().asArrayRef());
+    SmallVector<int64_t> sgLayout =
+        llvm::to_vector_of<int64_t>(layout.getSgLayout().asArrayRef());
+
+    // The helpers below index dimensions 0 and 1 unconditionally.
+    if (sgLayout.size() != 2 || sgShape.size() != 2 || wgShape.size() != 2)
+      return rewriter.notifyMatchFailure(
+          op, "only 2D workgroup to subgroup distribution is supported");
+
+    // Get the subgroup ID
+    auto linearSgId = rewriter.create<gpu::SubgroupIdOp>(loc, nullptr);
+
+    // Create constants for layout dimensions
+    SmallVector<Value> sgLayoutDim(sgLayout.size());
+    SmallVector<Value> sgDataDim(sgShape.size());
+
+    for (size_t i = 0; i < sgLayout.size(); i++) {
+      sgLayoutDim[i] = createConstantIndex(rewriter, loc, sgLayout[i]);
+      sgDataDim[i] = createConstantIndex(rewriter, loc, sgShape[i]);
+    }
+
+    // Delinearize the 1D subgroup id into nd coordinates
+    SmallVector<Value> sgIds = delinearizeSubgroupId(
+        rewriter, loc, linearSgId, sgLayoutDim[0], sgLayoutDim[1]);
+
+    // Calculate distribution unit shape and local offsets for subgroup
+    SmallVector<int64_t> distUnitShape(sgLayout.size());
+    SmallVector<Value> localOffset(sgLayout.size());
+    for (size_t i = 0; i < sgLayout.size(); i++) {
+      distUnitShape[i] = sgLayout[i] * sgShape[i];
+      localOffset[i] =
+          rewriter.createOrFold<index::MulOp>(loc, sgIds[i], sgDataDim[i]);
+    }
+
+    SmallVector<Value> originalOffsets = extractOffsets(rewriter, op);
+
+    // Subgroup descriptors keep the lane-level layout but drop sg_layout and
+    // sg_data.
+    xegpu::TensorDescType newTdescTy =
+        xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(),
+                                   layout.dropSgLayoutAndData());
+    SmallVector<Value> newCreateNdOps;
+    // One descriptor per distribution unit tiling the workgroup shape.
+    for (const SmallVector<int64_t> &distUnitBaseAddr :
+         StaticTileOffsetRange(wgShape, distUnitShape)) {
+      SmallVector<OpFoldResult> globalOffsets = calculateGlobalOffsets(
+          rewriter, loc, originalOffsets, localOffset, distUnitBaseAddr);
+
+      auto newCreateNdOp = rewriter.create<xegpu::CreateNdDescOp>(
+          loc, newTdescTy, op.getSource(), globalOffsets, op.getMixedSizes(),
+          op.getMixedStrides());
+      newCreateNdOps.push_back(newCreateNdOp);
+    }
+
+    rewriter.replaceOpWithMultiple(op, {newCreateNdOps});
+    return success();
+  }
+};
+
+/// This pattern transforms the LoadNdOp to load from a subgroup descriptor
+/// It creates a LoadNdOp op to load the new subgroup src tensor descriptors.
+struct WgToSgLoadNdOp : public OpConversionPattern<xegpu::LoadNdOp> {
+  using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
+
+  /// Emits one LoadNdOp per converted source descriptor. Each new load
+  /// returns a vector with the shape and element type of its descriptor and
+  /// carries over the attributes of the original op.
+  LogicalResult
+  matchAndRewrite(xegpu::LoadNdOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    SmallVector<Value> newLoadOps;
+    for (auto src : adaptor.getTensorDesc()) {
+      // The result is dereferenced unconditionally, so use `cast` (which
+      // asserts on a type mismatch) instead of an unchecked `dyn_cast`.
+      xegpu::TensorDescType tdescTy =
+          cast<xegpu::TensorDescType>(src.getType());
+      ArrayRef<int64_t> srcShape = tdescTy.getShape();
+      VectorType newResTy = VectorType::get(srcShape, tdescTy.getElementType());
+      auto newLoadOp = rewriter.create<xegpu::LoadNdOp>(op.getLoc(), newResTy,
+                                                        src, op->getAttrs());
+      newLoadOps.push_back(newLoadOp);
+    }
+    rewriter.replaceOpWithMultiple(op, {newLoadOps});
+    return mlir::success();
+  }
+};
+
+/// This pattern transforms the StoreNdOp to store to a subgroup descriptor
+/// It creates a StoreNdOp op to store the updated values to the new subgroup
+/// src tensor descriptors.
+struct WgToSgStoreNdOp : public OpConversionPattern<xegpu::StoreNdOp> {
+  using OpConversionPattern<xegpu::StoreNdOp>::OpConversionPattern;
+
+  /// Pairs every converted value with the converted descriptor at the same
+  /// position and emits one StoreNdOp per pair, reusing the cache hints of
+  /// the original op. The original op is erased afterwards.
+  LogicalResult
+  matchAndRewrite(xegpu::StoreNdOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    auto valueAndDescPairs =
+        llvm::zip(adaptor.getValue(), adaptor.getTensorDesc());
+    for (auto [subValue, subDesc] : valueAndDescPairs) {
+      rewriter.create<xegpu::StoreNdOp>(loc, subValue, subDesc,
+                                        op.getL1HintAttr(), op.getL2HintAttr(),
+                                        op.getL3HintAttr());
+    }
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+/// This pattern transforms the UpdateNdOffsetOp to update the offsets of a
+/// subgroup descriptor. It creates an UpdateNdOffsetOp op to update the
+/// offsets of the new subgroup src tensor descriptors.
+struct WgToSgUpdateNdOffsetOp
+    : public OpConversionPattern<xegpu::UpdateNdOffsetOp> {
+  using OpConversionPattern<xegpu::UpdateNdOffsetOp>::OpConversionPattern;
+
+  /// Emits one UpdateNdOffsetOp per converted descriptor. Every new op
+  /// keeps the type of its descriptor and applies the offsets of the
+  /// original op.
+  LogicalResult
+  matchAndRewrite(xegpu::UpdateNdOffsetOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    llvm::SmallVector<Value> updatedDescs;
+    for (auto subDesc : adaptor.getTensorDesc()) {
+      auto updated = rewriter.create<xegpu::UpdateNdOffsetOp>(
+          loc, subDesc.getType(), subDesc, op.getOffsets(),
+          op.getConstOffsets());
+      updatedDescs.push_back(updated);
+    }
+
+    rewriter.replaceOpWithMultiple(op, {updatedDescs});
+    return success();
+  }
+};
+
+/// This pattern transforms the DpasOp to work at subgroup level.
+struct WgToSgDpasOp : public OpConversionPattern<xegpu::DpasOp> {
+  using OpConversionPattern<xegpu::DpasOp>::OpConversionPattern;
+
+  /// Emits one DpasOp for every (lhs, rhs) pair of converted operands.
+  ///
+  /// Fails to match when the result is not 2D, when the op has no "layout"
+  /// attribute of LayoutAttr type, or when an accumulator is present but
+  /// does not supply exactly one value per (lhs, rhs) pair.
+  LogicalResult
+  matchAndRewrite(xegpu::DpasOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    VectorType resultTy = op.getResult().getType();
+    if (resultTy.getRank() != 2)
+      return failure();
+
+    auto originalLayout =
+        llvm::dyn_cast_or_null<xegpu::LayoutAttr>(op->getAttr("layout"));
+    if (!originalLayout)
+      return failure();
+
+    // The accumulator values are consumed one per (lhs, rhs) pair below;
+    // reject a mismatching count up front instead of indexing out of range.
+    if (op.getAcc() && adaptor.getAcc().size() !=
+                           adaptor.getLhs().size() * adaptor.getRhs().size())
+      return rewriter.notifyMatchFailure(
+          op, "accumulator count does not match the number of lhs/rhs pairs");
+
+    SmallVector<Value> newDpasOps;
+    // Running index into the converted accumulator values (lhs-major order).
+    size_t i = 0;
+    for (auto aVec : adaptor.getLhs()) {
+      for (auto bVec : adaptor.getRhs()) {
+
+        llvm::SmallVector<Value> operands({aVec, bVec});
+        Value tmpC;
+        if (op.getAcc()) {
+          tmpC = adaptor.getAcc()[i++];
+          operands.push_back(tmpC);
+        }
+
+        // Result shape is [rows of lhs, columns of rhs].
+        ArrayRef<int64_t> aVecShape =
+            llvm::cast<VectorType>(aVec.getType()).getShape();
+        ArrayRef<int64_t> bVecShape =
+            llvm::cast<VectorType>(bVec.getType()).getShape();
+        VectorType resTy = VectorType::get({aVecShape[0], bVecShape[1]},
+                                           resultTy.getElementType());
+        // The new op keeps the layout without sg_layout and sg_data.
+        tmpC = rewriter.create<xegpu::DpasOp>(
+            loc, resTy, operands,
+            llvm::ArrayRef<NamedAttribute>(
+                {"layout", originalLayout.dropSgLayoutAndData()}));
+        newDpasOps.push_back(tmpC);
+      }
+    }
+    rewriter.replaceOpWithMultiple(op, {newDpasOps});
+    return mlir::success();
+  }
+};
+
+} // namespace
+
+namespace mlir {
+namespace xegpu {
+/// Registers all workgroup-to-subgroup conversion patterns defined in this
+/// file with `patterns`.
+void populateXeGPUWgToSgPatterns(RewritePatternSet &patterns) {
+  MLIRContext *context = patterns.getContext();
+  patterns.add<WgToSgCreateNdOp, WgToSgLoadNdOp, WgToSgStoreNdOp>(context);
+  patterns.add<WgToSgUpdateNdOffsetOp, WgToSgDpasOp>(context);
+}
+} // namespace xegpu
+} // namespace mlir
+
+namespace {
+/// Pass implementation built on the TableGen-generated XeGPUWgToSgBase;
+/// the conversion itself is driven from runOnOperation().
+struct XeGPUWgToSgPass : public xegpu::impl::XeGPUWgToSgBase<XeGPUWgToSgPass> {
+ void runOnOperation() override;
+};
+} // namespace
+
+/// Runs a partial dialect conversion over the anchored operation. XeGPU nd
+/// ops and dpas ops are illegal while their layout still carries sg_layout;
+/// every other operation is treated as legal.
+void XeGPUWgToSgPass::runOnOperation() {
+  MLIRContext *context = &getContext();
+  ConversionTarget target(*context);
+  RewritePatternSet patterns(context);
+
+  // Tensor descriptor type an nd op works on; a null type for any other op.
+  auto tensorDescTypeOf = [](Operation *op) -> xegpu::TensorDescType {
+    if (auto create = dyn_cast<xegpu::CreateNdDescOp>(op))
+      return create.getType();
+    if (auto load = dyn_cast<xegpu::LoadNdOp>(op))
+      return load.getTensorDescType();
+    if (auto store = dyn_cast<xegpu::StoreNdOp>(op))
+      return store.getTensorDescType();
+    if (auto update = dyn_cast<xegpu::UpdateNdOffsetOp>(op))
+      return update.getType();
+    return xegpu::TensorDescType();
+  };
+
+  // A missing layout, or a layout without sg_layout, needs no distribution.
+  auto hasNoSgLayout = [](xegpu::LayoutAttr layout) -> bool {
+    if (!layout)
+      return true;
+    return !layout.getSgLayout();
+  };
+
+  target.addDynamicallyLegalOp<xegpu::CreateNdDescOp, xegpu::LoadNdOp,
+                               xegpu::StoreNdOp, xegpu::UpdateNdOffsetOp>(
+      [=](Operation *op) -> bool {
+        xegpu::TensorDescType descTy = tensorDescTypeOf(op);
+        return hasNoSgLayout(
+            dyn_cast_or_null<xegpu::LayoutAttr>(descTy.getLayout()));
+      });
+
+  target.addDynamicallyLegalOp<xegpu::DpasOp>([=](xegpu::DpasOp op) -> bool {
+    return hasNoSgLayout(
+        dyn_cast_or_null<xegpu::LayoutAttr>(op->getAttr("layout")));
+  });
+
+  target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
+
+  xegpu::populateXeGPUWgToSgPatterns(patterns);
+  LogicalResult status =
+      applyPartialConversion(getOperation(), target, std::move(patterns));
+  if (failed(status))
+    signalPassFailure();
+}
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
new file mode 100644
index 0000000000000..d0f225c3e7304
--- /dev/null
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
@@ -0,0 +1,65 @@
+// RUN: mlir-opt --xegpu-wg-to-sg -split-input-file %s | FileCheck %s
+
+gpu.module @test_round_robin_assignment {
+ // CHECK: test_create_nd_tdesc
+ // CHECK: %[[ARG_0:.*]]: memref<24x32xf32>
+ gpu.func @test_create_nd_tdesc(%src: memref<24x32xf32>) {
+ // CHECK-COUNT-12: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][%{{.*}}, %{{.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
+ %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+ gpu.return
+ }
+
+ // CHECK: test_load_nd_tdesc
+ // CHECK: %[[ARG_0:.*]]: memref<24x32xf32>
+ gpu.func @test_load_nd_tdesc(%src: memref<24x32xf32>) {
+ %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+ // CHECK-COUNT-12: %[[LOAD:.*]] = xegpu.load_nd %{{.*}} : !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>> -> vector<2x2xf32>
+ %load = xegpu.load_nd %tdesc: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>> -> vector<24x32xf32>
+ gpu.return
+ }
+
+ // CHECK: test_store_nd
+ // CHECK: %[[ARG_0:.*]]: memref<24x32xf32>
+ gpu.func @test_store_nd(%src: memref<24x32xf32>) {
+ %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+ // CHECK-COUNT-12: xegpu.store_nd %{{.*}}, %{{.*}} : vector<2x2xf32>, !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
+ %load = xegpu.load_nd %tdesc: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>> -> vector<24x32xf32>
+ xegpu.store_nd %load, %tdesc: vector<24x32xf32>, !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+ gpu.return
+ }
+
+ // CHECK: test_update_nd
+ // CHECK: %[[ARG_0:.*]]: memref<24x32xf32>
+ gpu.func @test_update_nd(%src: memref<24x32xf32>){
+ %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+ // CHECK-COUNT-12: %[[UPDATE:.*]] = xegpu.update_nd_offset %{{.*}}, [0, 16] : !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
+ %update = xegpu.update_nd_offset %tdesc, [0, 16] : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+ gpu.return
+ }
+
+ // CHECK: test_dpas
+ // CHECK: %[[ARG_0:.*]]: memref<24x32xf32>
+ // CHECK: %[[ARG_1:.*]]: memref<32x24xf32>
+ // CHECK: %[[ARG_2:.*]]: memref<24x24xf32>
+ gpu.func @test_dpas(%a: memref<24x32xf32>, %b: memref<32x24xf32>, %c: memref<24x24xf32>) {
+ // CHECK-COUNT-12: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][%{{.*}}, %{{.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
+ // CHECK-COUNT-12: %[[TDESC1:.*]] = xegpu.create_nd_tdesc %[[ARG_1]][%{{.*}}, %{{.*}}] : memref<32x24xf32> -> !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
+ // CHECK-COUNT-9: %[[TDESC2:.*]] = xegpu.create_nd_tdesc %{{.*}}[%{{.*}}, %{{.*}}] : memref<24x24xf32> -> !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
+ // CHECK-COUNT-144: %[[DPAS:.*]] = xegpu.dpas %{{.*}}, %{{.*}} {layout = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>, vector<2x2xf32> -> vector<2x2xf32>
+ %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+ %load_a = xegpu.load_nd %tdesc_a: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>> -> vector<24x32xf32>
+ %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<32x24xf32> -> !xegpu.tensor_desc<32x24xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+ %load_b = xegpu.load_nd %tdesc_b: !xegpu.tensor_desc<32x24xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>> -> vector<32x24xf32>
+ %tdesc_c = xegpu.create_nd_tdesc %c[0, 0] : memref<24x24xf32> -> !xegpu.tensor_desc<24x24xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+ %dpas = xegpu.dpas %load_a, %load_b {layout = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>} : vector<24x32xf32>, vector<32x24xf32> -> vector<24x24xf32>
+ gpu.return
+ }
+}
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
new file mode 100644
index 0000000000000..c4c8881e65597
--- /dev/null
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
@@ -0,0 +1,81 @@
+// RUN: mlir-opt --xegpu-wg-to-sg -split-input-file %s | FileCheck %s
+
+gpu.module @test_1_1_assignment {
+ // CHECK: test_create_nd_tdesc
+ // CHECK: %[[ARG_0:.*]]: memref<24x32xf32>
+ gpu.func @test_create_nd_tdesc(%src: memref<24x32xf32>) {
+ // CHECK: %[[SGID:.*]] = gpu.subgroup_id
+ // CHECK: %[[C12:.*]] = arith.constant 12 : index
+ // CHECK: %[[C4:.*]] = arith.constant 4 : index
+ // CHECK: %[[C8:.*]] = arith.constant 8 : index
+ // CHECK: %[[DIV:.*]] = index.divu %[[SGID]], %[[C4]]
+ // CHECK: %[[REM:.*]] = index.remu %[[SGID]], %[[C4]]
+ // CHECK: %[[MUL1:.*]] = index.mul %[[DIV]], %[[C12]]
+ // CHECK: %[[MUL2:.*]] = index.mul %[[REM]], %[[C8]]
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: %[[ADD1:.*]] = index.add %[[MUL1]], %[[C0]]
+ // CHECK: %[[ADD2:.*]] = index.add %[[MUL2]], %[[C0]]
+ // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][%[[ADD1]], %[[ADD2]]] : memref<24x32xf32> -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
+ // CHECK: gpu.return
+ %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ gpu.return
+ }
+
+ // CHECK: test_load_nd_tdesc
+ // CHECK: %[[ARG_0:.*]]: memref<24x32xf32>
+ gpu.func @test_load_nd_tdesc(%src: memref<24x32xf32>) {
+ // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
+ // CHECK: %[[LOAD:.*]] = xegpu.load_nd %[[TDESC]] : !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>> -> vector<12x8xf32>
+ %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ %load = xegpu.load_nd %tdesc: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>> -> vector<24x32xf32>
+ gpu.return
+ }
+
+ // CHECK: test_store_nd
+ // CHECK: %[[ARG_0:.*]]: memref<24x32xf32>
+ gpu.func @test_store_nd(%src: memref<24x32xf32>) {
+ // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
+ // CHECK: %[[LOAD:.*]] = xegpu.load_nd %[[TDESC]] : !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>> -> vector<12x8xf32>
+ // CHECK: xegpu.store_nd %[[LOAD]], %[[TDESC]] : vector<12x8xf32>, !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
+ %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ %load = xegpu.load_nd %tdesc: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>> -> vector<24x32xf32>
+ xegpu.store_nd %load, %tdesc: vector<24x32xf32>, !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ gpu.return
+}
+
+// CHECK: test_update_nd
+// CHECK: %[[ARG_0:.*]]: memref<24x32xf32>
+gpu.func @test_update_nd(%src: memref<24x32xf32>){
+ // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
+ // CHECK: %[[UPDATE:.*]] = xegpu.update_nd_offset %[[TDESC]], [0, 16] : !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
+ %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ %update = xegpu.update_nd_offset %tdesc, [0, 16] : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ gpu.return
+}
+
+// CHECK: test_dpas
+// CHECK: %[[ARG_0:.*]]: memref<24x32xf32>
+// CHECK: %[[ARG_1:.*]]: memref<32x24xf32>
+gpu.func @test_dpas(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
+ // CHECK: %[[TDESC_A:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
+ // CHECK: %[[LOAD_A:.*]] = xegpu.load_nd %[[TDESC_A]] : !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>> -> vector<12x8xf32>
+ // CHECK: %[[TDESC_B:.*]] = xegpu.create_nd_tdesc %[[ARG_1]][{{%.*}}, {{%.*}}] : memref<32x24xf32> -> !xegpu.tensor_desc<8x12xf32, #xegpu.layout<lane_layout = [8, 2], lane_data = [1, 1]>>
+ // CHECK: %[[LOAD_B:.*]] = xegpu.load_nd %[[TDESC_B]] : !xegpu.tensor_desc<8x12xf32, #xegpu.layout<lane_layout = [8, 2], lane_data = [1, 1]>> -> vector<8x12xf32>
+ // CHECK: %[[DPAS:.*]] = xegpu.dpas %[[LOAD_A]], %[[LOAD_B]] {layout = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<12x8xf32>, vector<8x12xf32> -> vector<12x12xf32>
+ %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ %load_a = xegpu.load_nd %tdesc_a: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>> -> vector<24x32xf32>
+ %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<32x24xf32> -> !xegpu.tensor_desc<32x24xf32, #xegpu.layout<sg_layout = [4, 2], sg_data = [8, 12], lane_layout = [8, 2], lane_data = [1, 1]>>
+ %load_b = xegpu.load_nd %tdesc_b: !xegpu.tensor_desc<32x24xf32, #xegpu.layout<sg_layout = [4, 2], sg_data = [8, 12], lane_layout = [8, 2], lane_data = [1, 1]>> -> vector<32x24xf32>
+ %dpas = xegpu.dpas %load_a, %load_b {layout = #xegpu.layout<sg_layout = [2, 2], sg_data = [12, 12], lane_layout = [2, 2], lane_data = [1, 1]>} : vector<24x32xf32>, vector<32x24xf32> -> vector<24x24xf32>
+ gpu.return
+ }
+}
>From b3bf12f082eb08aa3f82503142140fc686e0e950 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Sun, 11 May 2025 15:49:35 +0000
Subject: [PATCH 2/3] Add prefetch_nd op
---
.../Dialect/XeGPU/Transforms/XeGPUWgToSg.cpp | 52 ++++++++++++-------
.../test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir | 25 ++++-----
mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir | 29 ++++++-----
3 files changed, 60 insertions(+), 46 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSg.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSg.cpp
index 7969d37d67f04..5eabb04e3b858 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSg.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSg.cpp
@@ -34,11 +34,10 @@ using namespace mlir;
namespace {
// clang-format off
-/// This pattern transform the CreateNdDescOp to create a subgroup descriptor
+/// This pattern transforms the CreateNdDescOp to create a subgroup descriptor
/// from a workgroup descriptor. It replaces the offsets and sizes with
/// appropriate values for the subgroup.
-/// It uses round-robin distribution to create the subgroup descriptor.
-
+/// It uses round-robin assignment to distribute the work to the subgroups.
/// Following create_nd_desc operation:,
/// %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x24xf32>
/// -> !xegpu.tensor_desc<24x24xf32, #xegpu.layout<sg_layout = [4, 4],
@@ -47,7 +46,7 @@ namespace {
/// %tdesc = xegpu.create_nd_tdesc %src[off1, off2] : memref<24x24xf32> ->
/// !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
///
-/// The sg_layout and sg_data are dropped from the layout attribute as they are no longer needed.
+/// The sg_layout and sg_data attributes are dropped after the pass as they are no longer needed.
///
/// 24x24 matrix distribution example:
/// sg_layout = [4, 4], sg_data = [2, 2]
@@ -72,7 +71,6 @@ namespace {
///
/// Since the 24x24 matrix is divided into 8x8 distribution units, there will be 9
/// distribution units (3x3) in total. Hence the 9 subgroup level operations.
-/// Each 8x8 matrix within the 24x24 matrix is called a distribution unit.
// clang-format on
struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
@@ -110,7 +108,7 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
return rewriter.create<arith::ConstantIndexOp>(loc, value);
}
- // Calculate global offset for each subgroup
+ // Calculate offset for each subgroup
SmallVector<OpFoldResult>
calculateGlobalOffsets(ConversionPatternRewriter &rewriter, Location loc,
const SmallVector<Value> &originalOffsets,
@@ -122,13 +120,11 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
Value constOffsetY =
createConstantIndex(rewriter, loc, distUnitBaseAddr[1]);
- // Compute offsets within entire tile
Value offsetX =
rewriter.createOrFold<index::AddOp>(loc, localOffset[0], constOffsetX);
Value offsetY =
rewriter.createOrFold<index::AddOp>(loc, localOffset[1], constOffsetY);
- // Add to global offsets
size_t lastDimIndex = originalOffsets.size() - 1;
size_t secondLastDimIndex = lastDimIndex - 1;
@@ -137,7 +133,6 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
Value globalOffsetY = rewriter.createOrFold<index::AddOp>(
loc, originalOffsets[lastDimIndex], offsetY);
- // Create final offset list
SmallVector<OpFoldResult> globalOffsets(originalOffsets.begin(),
originalOffsets.end());
globalOffsets[secondLastDimIndex] = globalOffsetX;
@@ -172,7 +167,7 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
sgDataDim[i] = createConstantIndex(rewriter, loc, sgShape[i]);
}
- // Delinearize the 1D subgroup id into nd coordinates
+ // Delinearize the 1D subgroup id into 2D coordinates
SmallVector<Value> sgIds = delinearizeSubgroupId(
rewriter, loc, linearSgId, sgLayoutDim[0], sgLayoutDim[1]);
@@ -207,8 +202,7 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
}
};
-/// This pattern transforms the LoadNdOp to load from a subgroup descriptor
-/// It creates a LoadNdOp op to load the new subgroup src tensor descriptors.
+/// This pattern transforms the LoadNdOp to load subgroup data.
struct WgToSgLoadNdOp : public OpConversionPattern<xegpu::LoadNdOp> {
using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
LogicalResult
@@ -310,7 +304,22 @@ struct WgToSgDpasOp : public OpConversionPattern<xegpu::DpasOp> {
}
}
rewriter.replaceOpWithMultiple(op, {newDpasOps});
- return mlir::success();
+ return success();
+ }
+};
+
+/// This pattern transforms the PrefetchNdOp to prefetch the subgroup data.
+struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
+ using OpConversionPattern<xegpu::PrefetchNdOp>::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(xegpu::PrefetchNdOp op, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ for (auto src : adaptor.getTensorDesc()) {
+ rewriter.create<xegpu::PrefetchNdOp>(op.getLoc(), TypeRange(), src,
+ op->getAttrs());
+ }
+ rewriter.eraseOp(op);
+ return success();
}
};
@@ -320,7 +329,8 @@ namespace mlir {
namespace xegpu {
void populateXeGPUWgToSgPatterns(RewritePatternSet &patterns) {
patterns.add<WgToSgCreateNdOp, WgToSgLoadNdOp, WgToSgStoreNdOp,
- WgToSgUpdateNdOffsetOp, WgToSgDpasOp>(patterns.getContext());
+ WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp>(
+ patterns.getContext());
}
} // namespace xegpu
} // namespace mlir
@@ -345,6 +355,8 @@ void XeGPUWgToSgPass::runOnOperation() {
return storeOp.getTensorDescType();
if (auto updateOp = dyn_cast<xegpu::UpdateNdOffsetOp>(op))
return updateOp.getType();
+ if (auto prefetchOp = dyn_cast<xegpu::PrefetchNdOp>(op))
+ return prefetchOp.getTensorDescType();
return xegpu::TensorDescType();
};
@@ -353,12 +365,12 @@ void XeGPUWgToSgPass::runOnOperation() {
};
target.addDynamicallyLegalOp<xegpu::CreateNdDescOp, xegpu::LoadNdOp,
- xegpu::StoreNdOp, xegpu::UpdateNdOffsetOp>(
- [=](Operation *op) -> bool {
- auto tdescTy = getTensorDescType(op);
- auto layout = dyn_cast_or_null<xegpu::LayoutAttr>(tdescTy.getLayout());
- return isLegal(layout);
- });
+ xegpu::StoreNdOp, xegpu::UpdateNdOffsetOp,
+ xegpu::PrefetchNdOp>([=](Operation *op) -> bool {
+ auto tdescTy = getTensorDescType(op);
+ auto layout = dyn_cast_or_null<xegpu::LayoutAttr>(tdescTy.getLayout());
+ return isLegal(layout);
+ });
target.addDynamicallyLegalOp<xegpu::DpasOp>([=](xegpu::DpasOp op) -> bool {
auto layout = dyn_cast_or_null<xegpu::LayoutAttr>(op->getAttr("layout"));
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
index d0f225c3e7304..de2c548ec7ebb 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
@@ -42,18 +42,10 @@ gpu.module @test_round_robin_assignment {
// CHECK: %[[ARG_1:.*]]: memref<32x24xf32>
// CHECK: %[[ARG_2:.*]]: memref<24x24xf32>
gpu.func @test_dpas(%a: memref<24x32xf32>, %b: memref<32x24xf32>, %c: memref<24x24xf32>) {
- // CHECK-COUNT-12: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][%{{.*}},
- // %{{.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<2x2xf32,
- // #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>> CHECK-COUNT-12:
- // %[[TDESC1:.*]] = xegpu.create_nd_tdesc %[[ARG_1]][%{{.*}}, %{{.*}}] :
- // memref<32x24xf32> -> !xegpu.tensor_desc<2x2xf32,
- // #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>> CHECK-COUNT-9:
- // %[[TDESC2:.*]] = xegpu.create_nd_tdesc %{{.*}}[%{{.*}}, %{{.*}}] :
- // memref<24x24xf32> -> !xegpu.tensor_desc<2x2xf32,
- // #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>> CHECK-COUNT-144:
- // %[[DPAS:.*]] = xegpu.dpas %{{.*}}, %{{.*}} {layout =
- // #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} :
- // vector<2x2xf32>, vector<2x2xf32> -> vector<2x2xf32>
+ // CHECK-COUNT-12: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][%{{.*}}, %{{.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
+ // CHECK-COUNT-12: %[[TDESC1:.*]] = xegpu.create_nd_tdesc %[[ARG_1]][%{{.*}}, %{{.*}}] : memref<32x24xf32> -> !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
+ // CHECK-COUNT-9: %[[TDESC2:.*]] = xegpu.create_nd_tdesc %{{.*}}[%{{.*}}, %{{.*}}] : memref<24x24xf32> -> !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
+ // CHECK-COUNT-144: %[[DPAS:.*]] = xegpu.dpas %{{.*}}, %{{.*}} {layout = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>, vector<2x2xf32> -> vector<2x2xf32>
%tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
%load_a = xegpu.load_nd %tdesc_a: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>> -> vector<24x32xf32>
%tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<32x24xf32> -> !xegpu.tensor_desc<32x24xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
@@ -62,4 +54,13 @@ gpu.module @test_round_robin_assignment {
%dpas = xegpu.dpas %load_a, %load_b {layout = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>} : vector<24x32xf32>, vector<32x24xf32> -> vector<24x24xf32>
gpu.return
}
+
+ // CHECK: test_prefetch_nd_tdesc
+ // CHECK: %[[ARG_0:.*]]: memref<24x32xf32>
+ gpu.func @test_prefetch_nd_tdesc(%src: memref<24x32xf32>) {
+ // CHECK-COUNT-12: xegpu.prefetch_nd %{{.*}} : !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
+ %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+ xegpu.prefetch_nd %tdesc: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+ gpu.return
+ }
}
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
index c4c8881e65597..1cae2c822d826 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
@@ -57,20 +57,11 @@ gpu.func @test_update_nd(%src: memref<24x32xf32>){
// CHECK: %[[ARG_0:.*]]: memref<24x32xf32>
// CHECK: %[[ARG_1:.*]]: memref<32x24xf32>
gpu.func @test_dpas(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
- // CHECK: %[[TDESC_A:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}},
- // {{%.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<12x8xf32,
- // #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>> CHECK:
- // %[[LOAD_A:.*]] = xegpu.load_nd %[[TDESC_A]] :
- // !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8],
- // lane_data = [1, 1]>> -> vector<12x8xf32> CHECK: %[[TDESC_B:.*]] =
- // xegpu.create_nd_tdesc %[[ARG_1]][{{%.*}}, {{%.*}}] : memref<32x24xf32> ->
- // !xegpu.tensor_desc<8x12xf32, #xegpu.layout<lane_layout = [8, 2],
- // lane_data = [1, 1]>> CHECK: %[[LOAD_B:.*]] = xegpu.load_nd %[[TDESC_B]] :
- // !xegpu.tensor_desc<8x12xf32, #xegpu.layout<lane_layout = [8, 2],
- // lane_data = [1, 1]>> -> vector<8x12xf32> CHECK: %[[DPAS:.*]] = xegpu.dpas
- // %[[LOAD_A]], %[[LOAD_B]] {layout = #xegpu.layout<lane_layout = [2, 2],
- // lane_data = [1, 1]>} : vector<12x8xf32>, vector<8x12xf32> ->
- // vector<12x12xf32>
+ // CHECK: %[[TDESC_A:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
+ // CHECK: %[[LOAD_A:.*]] = xegpu.load_nd %[[TDESC_A]] : !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>> -> vector<12x8xf32>
+ // CHECK: %[[TDESC_B:.*]] = xegpu.create_nd_tdesc %[[ARG_1]][{{%.*}}, {{%.*}}] : memref<32x24xf32> -> !xegpu.tensor_desc<8x12xf32, #xegpu.layout<lane_layout = [8, 2], lane_data = [1, 1]>>
+ // CHECK: %[[LOAD_B:.*]] = xegpu.load_nd %[[TDESC_B]] : !xegpu.tensor_desc<8x12xf32, #xegpu.layout<lane_layout = [8, 2], lane_data = [1, 1]>> -> vector<8x12xf32>
+ // CHECK: %[[DPAS:.*]] = xegpu.dpas %[[LOAD_A]], %[[LOAD_B]] {layout = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<12x8xf32>, vector<8x12xf32> -> vector<12x12xf32>
%tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
%load_a = xegpu.load_nd %tdesc_a: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>> -> vector<24x32xf32>
%tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<32x24xf32> -> !xegpu.tensor_desc<32x24xf32, #xegpu.layout<sg_layout = [4, 2], sg_data = [8, 12], lane_layout = [8, 2], lane_data = [1, 1]>>
@@ -78,4 +69,14 @@ gpu.func @test_dpas(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
%dpas = xegpu.dpas %load_a, %load_b {layout = #xegpu.layout<sg_layout = [2, 2], sg_data = [12, 12], lane_layout = [2, 2], lane_data = [1, 1]>} : vector<24x32xf32>, vector<32x24xf32> -> vector<24x24xf32>
gpu.return
}
+
+ // CHECK: test_prefetch_nd_tdesc
+ // CHECK: %[[ARG_0:.*]]: memref<24x32xf32>
+ gpu.func @test_prefetch_nd_tdesc(%src: memref<24x32xf32>) {
+ // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
+ // CHECK: xegpu.prefetch_nd %[[TDESC]] : !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
+ %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ xegpu.prefetch_nd %tdesc: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ gpu.return
+ }
}
>From 6a8647fa764e710f5aaeb51b46ae2ea398a959a3 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Sun, 11 May 2025 22:06:24 +0000
Subject: [PATCH 3/3] Remove braces for single statement for and if
---
mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSg.cpp | 8 +++-----
1 file changed, 3 insertions(+), 5 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSg.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSg.cpp
index 5eabb04e3b858..836f307ece9e1 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSg.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSg.cpp
@@ -83,12 +83,11 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
auto dynamicOffsets = op.getOffsets();
for (size_t i = 0, j = 0; i != staticOffsets.size(); i++) {
- if (ShapedType::isDynamic(staticOffsets[i])) {
+ if (ShapedType::isDynamic(staticOffsets[i]))
offsets.push_back(dynamicOffsets[j++]);
- } else {
+ else
offsets.push_back(rewriter.create<arith::ConstantIndexOp>(
op.getLoc(), staticOffsets[i]));
- }
}
return offsets;
}
@@ -314,10 +313,9 @@ struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
LogicalResult
matchAndRewrite(xegpu::PrefetchNdOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- for (auto src : adaptor.getTensorDesc()) {
+ for (auto src : adaptor.getTensorDesc())
rewriter.create<xegpu::PrefetchNdOp>(op.getLoc(), TypeRange(), src,
op->getAttrs());
- }
rewriter.eraseOp(op);
return success();
}
More information about the Mlir-commits
mailing list