[Mlir-commits] [mlir] [mlir][xegpu] Minor fixes in XeGPU subgroup distribution. (PR #147846)
Charitha Saumya
llvmlistbot at llvm.org
Wed Jul 9 15:49:14 PDT 2025
https://github.com/charithaintc created https://github.com/llvm/llvm-project/pull/147846
This PR addresses the following issues:
1. Add the missing attributes when creating a new GPU funcOp in the `MoveFuncBodyToWarpExecuteOnLane0` pattern (see the first sketch below).
2. Fix a bug in LoadNd distribution so that a LoadOp is distributed only when it is the last op in the warpOp region; this preserves memory-op ordering during distribution (see the second sketch below).
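
To illustrate the first fix, here is a minimal MLIR sketch (the kernel, its attribute, and its attributions are hypothetical) of what the pattern must reproduce: the `kernel` unit attribute, any other function attributes, and the workgroup/private attributions all have to be carried over to the rebuilt `gpu.func`, whose body is then wrapped in a `gpu.warp_execute_on_lane_0` region.

  gpu.module @test {
    // Input: the attributions and attributes live on the function itself
    // and must survive the rewrite.
    gpu.func @kernel(%arg0: memref<8x16xf16>)
        workgroup(%wg: memref<32xf16, #gpu.address_space<workgroup>>)
        private(%priv: memref<1xf16, #gpu.address_space<private>>)
        kernel attributes {known_block_size = array<i32: 16, 1, 1>} {
      gpu.return
    }
  }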
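
For the second fix, the sketch below (illustrative types and values) shows the situation the new check guards against: the pattern now only matches a LoadNdOp that is the op immediately preceding the yield. Distributing a load that still has ops after it inside the region (e.g. a `gpu.barrier`) would move the load out of the warp op and past those ops, breaking the ordering the barrier enforces.

  %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<8x1xf32>) {
    xegpu.store_nd %v, %td0 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
    gpu.barrier
    // This load is eligible for distribution only because it is the last op
    // before the yield; nothing remains after it inside the region, so
    // moving it out of the warp op cannot reorder it with the barrier.
    %ld = xegpu.load_nd %td1 : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
    gpu.yield %ld : vector<8x16xf32>
  }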
From 141c55145b1d766ce94b718f4e311f187c64ba04 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Wed, 9 Jul 2025 22:44:43 +0000
Subject: [PATCH] save work
---
.../Transforms/XeGPUSubgroupDistribute.cpp | 87 ++++++++++---------
.../Dialect/XeGPU/subgroup-distribute.mlir | 8 +-
2 files changed, 52 insertions(+), 43 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index c072557c2bd22..c87b596b4df43 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -34,6 +34,7 @@
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/SmallVectorExtras.h"
namespace mlir {
namespace xegpu {
@@ -197,9 +198,17 @@ struct MoveFuncBodyToWarpExecuteOnLane0
return isa<gpu::WarpExecuteOnLane0Op>(op);
}))
return failure();
- // Create a new function with the same signature.
+ // Create a new function with the same signature and same attributes.
+ SmallVector<Type> workgroupAttributionsTypes =
+ llvm::map_to_vector(gpuFuncOp.getWorkgroupAttributions(),
+ [](BlockArgument arg) { return arg.getType(); });
+ SmallVector<Type> privateAttributionsTypes =
+ llvm::map_to_vector(gpuFuncOp.getPrivateAttributions(),
+ [](BlockArgument arg) { return arg.getType(); });
auto newGpuFunc = rewriter.create<gpu::GPUFuncOp>(
- gpuFuncOp.getLoc(), gpuFuncOp.getName(), gpuFuncOp.getFunctionType());
+ gpuFuncOp.getLoc(), gpuFuncOp.getName(), gpuFuncOp.getFunctionType(),
+ workgroupAttributionsTypes, privateAttributionsTypes);
+ newGpuFunc->setAttrs(gpuFuncOp->getAttrs());
// Create a WarpExecuteOnLane0Op with same arguments and results as the
// original gpuFuncOp.
rewriter.setInsertionPointToEnd(&newGpuFunc.getFunctionBody().front());
@@ -265,13 +274,13 @@ struct MoveFuncBodyToWarpExecuteOnLane0
/// ```
struct CreateNdDescDistribution final : public gpu::WarpDistributionPattern {
using gpu::WarpDistributionPattern::WarpDistributionPattern;
- LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
+ LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
OpOperand *operand =
- getWarpResult(subgroupOp, llvm::IsaPred<xegpu::CreateNdDescOp>);
+ getWarpResult(warpOp, llvm::IsaPred<xegpu::CreateNdDescOp>);
if (!operand)
return rewriter.notifyMatchFailure(
- subgroupOp, "warp result is not a xegpu::CreateNdDesc op");
+ warpOp, "warp result is not a xegpu::CreateNdDesc op");
auto descOp = operand->get().getDefiningOp<xegpu::CreateNdDescOp>();
unsigned operandIdx = operand->getOperandNumber();
@@ -288,9 +297,9 @@ struct CreateNdDescDistribution final : public gpu::WarpDistributionPattern {
newYieldValues.push_back(operand);
newYieldTypes.push_back(operand.getType());
}
- rewriter.setInsertionPoint(subgroupOp);
+ rewriter.setInsertionPoint(warpOp);
gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
- rewriter, subgroupOp, /* new yieled values = */ newYieldValues,
+ rewriter, warpOp, /* new yielded values = */ newYieldValues,
/* new yielded types = */ newYieldTypes, newRetIndices);
SmallVector<Value> newDescOperands;
@@ -347,10 +356,10 @@ struct CreateNdDescDistribution final : public gpu::WarpDistributionPattern {
/// ```
struct StoreNdDistribution final : public gpu::WarpDistributionPattern {
using gpu::WarpDistributionPattern::WarpDistributionPattern;
- LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
+ LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
auto yield = cast<gpu::YieldOp>(
- subgroupOp.getBodyRegion().getBlocks().begin()->getTerminator());
+ warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
Operation *lastNode = yield->getPrevNode();
auto storeOp = dyn_cast_or_null<xegpu::StoreNdOp>(lastNode);
if (!storeOp)
@@ -372,7 +381,7 @@ struct StoreNdDistribution final : public gpu::WarpDistributionPattern {
SmallVector<size_t> newRetIndices;
gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
- rewriter, subgroupOp,
+ rewriter, warpOp,
/* new yielded values = */
ValueRange{storeOp.getValue(), storeOp.getTensorDesc()},
/* new yielded types = */
@@ -449,21 +458,22 @@ struct StoreNdDistribution final : public gpu::WarpDistributionPattern {
/// ```
struct LoadNdDistribution final : public gpu::WarpDistributionPattern {
using gpu::WarpDistributionPattern::WarpDistributionPattern;
- LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
+ LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
- OpOperand *operand =
- getWarpResult(subgroupOp, llvm::IsaPred<xegpu::LoadNdOp>);
+ OpOperand *operand = getWarpResult(warpOp, [&](Operation *op) {
+ if (!isa<xegpu::LoadNdOp>(op))
+ return false;
+ // Make sure this load op is the last operation in the warp op body.
+ // This ensures that the load op is not sunk earlier, violating any
+ // barrier synchronization.
+ auto yield = cast<gpu::YieldOp>(
+ warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+ return yield->getPrevNode() == op;
+ });
+
if (!operand)
return rewriter.notifyMatchFailure(
- subgroupOp, "warp result is not a xegpu::LoadNd op");
- // Make sure the load op is the last operation in the warp op body. This
- // ensure that load op is not sinked earlier violating any barrier
- // synchronizations.
- auto yield = cast<gpu::YieldOp>(
- subgroupOp.getBodyRegion().getBlocks().begin()->getTerminator());
- Operation *lastNode = yield->getPrevNode();
- if (!dyn_cast_or_null<xegpu::LoadNdOp>(lastNode))
- return failure();
+ warpOp, "warp result is not a xegpu::LoadNd op");
auto loadOp = operand->get().getDefiningOp<xegpu::LoadNdOp>();
xegpu::TensorDescType tensorDescTy = loadOp.getTensorDescType();
@@ -474,11 +484,11 @@ struct LoadNdDistribution final : public gpu::WarpDistributionPattern {
unsigned operandIdx = operand->getOperandNumber();
VectorType distributedTypeByWarpOp =
- cast<VectorType>(subgroupOp.getResult(operandIdx).getType());
+ cast<VectorType>(warpOp.getResult(operandIdx).getType());
SmallVector<size_t> newRetIndices;
gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
- rewriter, subgroupOp,
+ rewriter, warpOp,
/* new yielded values = */ loadOp.getTensorDesc(),
/* new yielded types = */ tensorDescTy, newRetIndices);
@@ -548,12 +558,11 @@ struct LoadNdDistribution final : public gpu::WarpDistributionPattern {
/// ```
struct DpasDistribution final : public gpu::WarpDistributionPattern {
using gpu::WarpDistributionPattern::WarpDistributionPattern;
- LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
+ LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
- OpOperand *operand =
- getWarpResult(subgroupOp, llvm::IsaPred<xegpu::DpasOp>);
+ OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<xegpu::DpasOp>);
if (!operand)
- return rewriter.notifyMatchFailure(subgroupOp,
+ return rewriter.notifyMatchFailure(warpOp,
"warp result is not a xegpu::Dpas op");
auto dpasOp = operand->get().getDefiningOp<xegpu::DpasOp>();
@@ -599,7 +608,7 @@ struct DpasDistribution final : public gpu::WarpDistributionPattern {
// Create a new warp op without the dpas.
SmallVector<size_t> newRetIndices;
gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
- rewriter, subgroupOp, newYieldValues, newYieldTypes, newRetIndices);
+ rewriter, warpOp, newYieldValues, newYieldTypes, newRetIndices);
FailureOr<VectorType> expectedDistLhsTyOrFailure =
xegpu::getDistributedVectorType(dpasOp.getLhsType(), layoutA);
@@ -678,13 +687,13 @@ struct DpasDistribution final : public gpu::WarpDistributionPattern {
/// ```
struct UpdateNdOffsetDistribution final : public gpu::WarpDistributionPattern {
using gpu::WarpDistributionPattern::WarpDistributionPattern;
- LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
+ LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
OpOperand *operand =
- getWarpResult(subgroupOp, llvm::IsaPred<xegpu::UpdateNdOffsetOp>);
+ getWarpResult(warpOp, llvm::IsaPred<xegpu::UpdateNdOffsetOp>);
if (!operand)
return rewriter.notifyMatchFailure(
- subgroupOp, "warp result is not a xegpu::UpdateNdOffset op");
+ warpOp, "warp result is not a xegpu::UpdateNdOffset op");
auto updateOp = operand->get().getDefiningOp<xegpu::UpdateNdOffsetOp>();
unsigned operandIdx = operand->getOperandNumber();
// new update op does not have layout attribute.
@@ -703,7 +712,7 @@ struct UpdateNdOffsetDistribution final : public gpu::WarpDistributionPattern {
}
SmallVector<size_t> newRetIndices;
gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
- rewriter, subgroupOp, newYieldValues, newYieldTypes, newRetIndices);
+ rewriter, warpOp, newYieldValues, newYieldTypes, newRetIndices);
rewriter.setInsertionPointAfter(newWarpOp);
SmallVector<Value> newUpdateOperands;
for (size_t i : newRetIndices) {
@@ -758,10 +767,10 @@ struct UpdateNdOffsetDistribution final : public gpu::WarpDistributionPattern {
/// ```
struct PrefetchNdDistribution final : public gpu::WarpDistributionPattern {
using gpu::WarpDistributionPattern::WarpDistributionPattern;
- LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
+ LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
auto yield = cast<gpu::YieldOp>(
- subgroupOp.getBodyRegion().getBlocks().begin()->getTerminator());
+ warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
Operation *lastNode = yield->getPrevNode();
auto prefetchOp = dyn_cast_or_null<xegpu::PrefetchNdOp>(lastNode);
if (!prefetchOp)
@@ -775,7 +784,7 @@ struct PrefetchNdDistribution final : public gpu::WarpDistributionPattern {
SmallVector<Type, 1> newYieldTypes = {prefetchOp.getTensorDescType()};
SmallVector<size_t> newRetIndices;
gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
- rewriter, subgroupOp, newYieldValues, newYieldTypes, newRetIndices);
+ rewriter, warpOp, newYieldValues, newYieldTypes, newRetIndices);
// Create a new prefetch op outside the warp op with updated tensor
// descriptor type. Source tensor descriptor require type resolution.
xegpu::TensorDescType newTensorDescTy =
@@ -795,17 +804,17 @@ struct PrefetchNdDistribution final : public gpu::WarpDistributionPattern {
/// region. This will simply move the barrier op outside of the warp op.
struct GpuBarrierDistribution final : public gpu::WarpDistributionPattern {
using gpu::WarpDistributionPattern::WarpDistributionPattern;
- LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
+ LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
auto yield = cast<gpu::YieldOp>(
- subgroupOp.getBodyRegion().getBlocks().begin()->getTerminator());
+ warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
Operation *lastNode = yield->getPrevNode();
// The last node must be a gpu::BarrierOp.
auto barrierOp = dyn_cast_or_null<gpu::BarrierOp>(lastNode);
if (!barrierOp)
return failure();
// Move the barrier op outside of the warp op.
- rewriter.setInsertionPointAfter(subgroupOp);
+ rewriter.setInsertionPointAfter(warpOp);
rewriter.create<gpu::BarrierOp>(
barrierOp.getLoc(), barrierOp->getResultTypes(),
barrierOp->getOperands(), barrierOp->getAttrs());
diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
index 3d91b2269bc4b..0b078727d4fe5 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
@@ -95,10 +95,10 @@ gpu.module @test {
// -----
// CHECK-LABEL: gpu.func @load_dpas_store
// CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: memref<8x16xf16>, %[[ARG1:[0-9a-zA-Z]+]]: memref<16x16xf16>, %[[ARG2:[0-9a-zA-Z]+]]: memref<8x16xf32>) {
-// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG1]][%{{.*}}] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
-// CHECK: %[[T1:.*]] = xegpu.load_nd %[[T0]] <{packed}> : !xegpu.tensor_desc<16x16xf16> -> vector<16xf16>
// CHECK: %[[T2:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%{{.*}}] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
// CHECK: %[[T3:.*]] = xegpu.load_nd %[[T2]] : !xegpu.tensor_desc<8x16xf16> -> vector<8xf16>
+// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG1]][%{{.*}}] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
+// CHECK: %[[T1:.*]] = xegpu.load_nd %[[T0]] <{packed}> : !xegpu.tensor_desc<16x16xf16> -> vector<16xf16>
// CHECK-DAG: %[[T4:.*]] = xegpu.dpas %[[T3]], %[[T1]] : vector<8xf16>, vector<16xf16> -> vector<8xf32>
// CHECK-DAG: %[[T5:.*]] = xegpu.create_nd_tdesc %[[ARG2]][%{{.*}}] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
// CHECK: xegpu.store_nd %[[T4]], %[[T5]] : vector<8xf32>, !xegpu.tensor_desc<8x16xf32>
@@ -120,10 +120,10 @@ gpu.module @test {
// -----
// CHECK-LABEL: gpu.func @load_dpas_postop_store
// CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: memref<8x16xf16>, %[[ARG1:[0-9a-zA-Z]+]]: memref<16x16xf16>, %[[ARG2:[0-9a-zA-Z]+]]: memref<8x16xf32>) {
-// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG1]][%{{.*}}] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
-// CHECK: %[[T1:.*]] = xegpu.load_nd %[[T0]] <{packed}> : !xegpu.tensor_desc<16x16xf16> -> vector<16xf16>
// CHECK: %[[T2:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%{{.*}}] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
// CHECK: %[[T3:.*]] = xegpu.load_nd %[[T2]] : !xegpu.tensor_desc<8x16xf16> -> vector<8xf16>
+// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG1]][%{{.*}}] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
+// CHECK: %[[T1:.*]] = xegpu.load_nd %[[T0]] <{packed}> : !xegpu.tensor_desc<16x16xf16> -> vector<16xf16>
// CHECK-DAG: %[[T4:.*]] = xegpu.dpas %[[T3]], %[[T1]] : vector<8xf16>, vector<16xf16> -> vector<8xf32>
// CHECK: %[[T5:.*]] = vector.shape_cast %[[T4]] : vector<8xf32> to vector<8x1xf32>
// CHECK: %[[T6:.*]] = math.exp %[[T5]] {{{.*}}} : vector<8x1xf32>