[Mlir-commits] [mlir] [mlir][xegpu] Minor fixes in XeGPU subgroup distribution. (PR #147846)
Charitha Saumya
llvmlistbot at llvm.org
Tue Jul 15 09:45:57 PDT 2025
https://github.com/charithaintc updated https://github.com/llvm/llvm-project/pull/147846
From 141c55145b1d766ce94b718f4e311f187c64ba04 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Wed, 9 Jul 2025 22:44:43 +0000
Subject: [PATCH 1/2] save work
---
.../Transforms/XeGPUSubgroupDistribute.cpp | 87 ++++++++++---------
.../Dialect/XeGPU/subgroup-distribute.mlir | 8 +-
2 files changed, 52 insertions(+), 43 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index c072557c2bd22..c87b596b4df43 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -34,6 +34,7 @@
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/SmallVectorExtras.h"
namespace mlir {
namespace xegpu {
@@ -197,9 +198,17 @@ struct MoveFuncBodyToWarpExecuteOnLane0
return isa<gpu::WarpExecuteOnLane0Op>(op);
}))
return failure();
- // Create a new function with the same signature.
+ // Create a new function with the same signature and same attributes.
+ SmallVector<Type> workgroupAttributionsTypes =
+ llvm::map_to_vector(gpuFuncOp.getWorkgroupAttributions(),
+ [](BlockArgument arg) { return arg.getType(); });
+ SmallVector<Type> privateAttributionsTypes =
+ llvm::map_to_vector(gpuFuncOp.getPrivateAttributions(),
+ [](BlockArgument arg) { return arg.getType(); });
auto newGpuFunc = rewriter.create<gpu::GPUFuncOp>(
- gpuFuncOp.getLoc(), gpuFuncOp.getName(), gpuFuncOp.getFunctionType());
+ gpuFuncOp.getLoc(), gpuFuncOp.getName(), gpuFuncOp.getFunctionType(),
+ workgroupAttributionsTypes, privateAttributionsTypes);
+ newGpuFunc->setAttrs(gpuFuncOp->getAttrs());
// Create a WarpExecuteOnLane0Op with same arguments and results as the
// original gpuFuncOp.
rewriter.setInsertionPointToEnd(&newGpuFunc.getFunctionBody().front());
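For readers unfamiliar with the helper that the new include brings in: llvm::map_to_vector (from llvm/ADT/SmallVectorExtras.h) maps a range through a callable and collects the results into a SmallVector in one expression, replacing an explicit reserve-and-push_back loop. A minimal standalone sketch, not part of the patch:

  #include "llvm/ADT/SmallVector.h"
  #include "llvm/ADT/SmallVectorExtras.h"
  #include <cassert>

  int main() {
    llvm::SmallVector<int> values = {1, 2, 3};
    // One expression instead of reserve() plus a push_back loop.
    llvm::SmallVector<int> doubled =
        llvm::map_to_vector(values, [](int v) { return v * 2; });
    assert(doubled.size() == 3 && doubled[2] == 6);
    return 0;
  }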
@@ -265,13 +274,13 @@ struct MoveFuncBodyToWarpExecuteOnLane0
/// ```
struct CreateNdDescDistribution final : public gpu::WarpDistributionPattern {
using gpu::WarpDistributionPattern::WarpDistributionPattern;
- LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
+ LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
OpOperand *operand =
- getWarpResult(subgroupOp, llvm::IsaPred<xegpu::CreateNdDescOp>);
+ getWarpResult(warpOp, llvm::IsaPred<xegpu::CreateNdDescOp>);
if (!operand)
return rewriter.notifyMatchFailure(
- subgroupOp, "warp result is not a xegpu::CreateNdDesc op");
+ warpOp, "warp result is not a xegpu::CreateNdDesc op");
auto descOp = operand->get().getDefiningOp<xegpu::CreateNdDescOp>();
unsigned operandIdx = operand->getOperandNumber();
@@ -288,9 +297,9 @@ struct CreateNdDescDistribution final : public gpu::WarpDistributionPattern {
newYieldValues.push_back(operand);
newYieldTypes.push_back(operand.getType());
}
- rewriter.setInsertionPoint(subgroupOp);
+ rewriter.setInsertionPoint(warpOp);
gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
- rewriter, subgroupOp, /* new yieled values = */ newYieldValues,
+        rewriter, warpOp, /* new yielded values = */ newYieldValues,
/* new yielded types = */ newYieldTypes, newRetIndices);
SmallVector<Value> newDescOperands;
@@ -347,10 +356,10 @@ struct CreateNdDescDistribution final : public gpu::WarpDistributionPattern {
/// ```
struct StoreNdDistribution final : public gpu::WarpDistributionPattern {
using gpu::WarpDistributionPattern::WarpDistributionPattern;
- LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
+ LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
auto yield = cast<gpu::YieldOp>(
- subgroupOp.getBodyRegion().getBlocks().begin()->getTerminator());
+ warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
Operation *lastNode = yield->getPrevNode();
auto storeOp = dyn_cast_or_null<xegpu::StoreNdOp>(lastNode);
if (!storeOp)
@@ -372,7 +381,7 @@ struct StoreNdDistribution final : public gpu::WarpDistributionPattern {
SmallVector<size_t> newRetIndices;
gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
- rewriter, subgroupOp,
+ rewriter, warpOp,
/* new yielded values = */
ValueRange{storeOp.getValue(), storeOp.getTensorDesc()},
/* new yielded types = */
@@ -449,21 +458,22 @@ struct StoreNdDistribution final : public gpu::WarpDistributionPattern {
/// ```
struct LoadNdDistribution final : public gpu::WarpDistributionPattern {
using gpu::WarpDistributionPattern::WarpDistributionPattern;
- LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
+ LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
- OpOperand *operand =
- getWarpResult(subgroupOp, llvm::IsaPred<xegpu::LoadNdOp>);
+ OpOperand *operand = getWarpResult(warpOp, [&](Operation *op) {
+ if (!isa<xegpu::LoadNdOp>(op))
+ return false;
+      // Make sure this load op is the last operation in the warp op body.
+      // This ensures that the load op is not sunk earlier, which would
+      // violate barrier synchronization.
+ auto yield = cast<gpu::YieldOp>(
+ warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+ return yield->getPrevNode() == op;
+ });
+
if (!operand)
return rewriter.notifyMatchFailure(
- subgroupOp, "warp result is not a xegpu::LoadNd op");
- // Make sure the load op is the last operation in the warp op body. This
- // ensure that load op is not sinked earlier violating any barrier
- // synchronizations.
- auto yield = cast<gpu::YieldOp>(
- subgroupOp.getBodyRegion().getBlocks().begin()->getTerminator());
- Operation *lastNode = yield->getPrevNode();
- if (!dyn_cast_or_null<xegpu::LoadNdOp>(lastNode))
- return failure();
+ warpOp, "warp result is not a xegpu::LoadNd op");
auto loadOp = operand->get().getDefiningOp<xegpu::LoadNdOp>();
xegpu::TensorDescType tensorDescTy = loadOp.getTensorDescType();
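The predicate above inlines the "must be the last operation before the terminator" check into the getWarpResult callback. The same check could be factored into a small named helper; a sketch, with a hypothetical name that is not part of the patch:

  // True iff `op` immediately precedes the terminator of the warp op body.
  static bool isLastOpBeforeTerminator(gpu::WarpExecuteOnLane0Op warpOp,
                                       Operation *op) {
    auto yield = cast<gpu::YieldOp>(
        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
    return yield->getPrevNode() == op;
  }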
@@ -474,11 +484,11 @@ struct LoadNdDistribution final : public gpu::WarpDistributionPattern {
unsigned operandIdx = operand->getOperandNumber();
VectorType distributedTypeByWarpOp =
- cast<VectorType>(subgroupOp.getResult(operandIdx).getType());
+ cast<VectorType>(warpOp.getResult(operandIdx).getType());
SmallVector<size_t> newRetIndices;
gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
- rewriter, subgroupOp,
+ rewriter, warpOp,
/* new yielded values = */ loadOp.getTensorDesc(),
/* new yielded types = */ tensorDescTy, newRetIndices);
@@ -548,12 +558,11 @@ struct LoadNdDistribution final : public gpu::WarpDistributionPattern {
/// ```
struct DpasDistribution final : public gpu::WarpDistributionPattern {
using gpu::WarpDistributionPattern::WarpDistributionPattern;
- LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
+ LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
- OpOperand *operand =
- getWarpResult(subgroupOp, llvm::IsaPred<xegpu::DpasOp>);
+ OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<xegpu::DpasOp>);
if (!operand)
- return rewriter.notifyMatchFailure(subgroupOp,
+ return rewriter.notifyMatchFailure(warpOp,
"warp result is not a xegpu::Dpas op");
auto dpasOp = operand->get().getDefiningOp<xegpu::DpasOp>();
@@ -599,7 +608,7 @@ struct DpasDistribution final : public gpu::WarpDistributionPattern {
// Create a new warp op without the dpas.
SmallVector<size_t> newRetIndices;
gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
- rewriter, subgroupOp, newYieldValues, newYieldTypes, newRetIndices);
+ rewriter, warpOp, newYieldValues, newYieldTypes, newRetIndices);
FailureOr<VectorType> expectedDistLhsTyOrFailure =
xegpu::getDistributedVectorType(dpasOp.getLhsType(), layoutA);
@@ -678,13 +687,13 @@ struct DpasDistribution final : public gpu::WarpDistributionPattern {
/// ```
struct UpdateNdOffsetDistribution final : public gpu::WarpDistributionPattern {
using gpu::WarpDistributionPattern::WarpDistributionPattern;
- LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
+ LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
OpOperand *operand =
- getWarpResult(subgroupOp, llvm::IsaPred<xegpu::UpdateNdOffsetOp>);
+ getWarpResult(warpOp, llvm::IsaPred<xegpu::UpdateNdOffsetOp>);
if (!operand)
return rewriter.notifyMatchFailure(
- subgroupOp, "warp result is not a xegpu::UpdateNdOffset op");
+ warpOp, "warp result is not a xegpu::UpdateNdOffset op");
auto updateOp = operand->get().getDefiningOp<xegpu::UpdateNdOffsetOp>();
unsigned operandIdx = operand->getOperandNumber();
// new update op does not have layout attribute.
@@ -703,7 +712,7 @@ struct UpdateNdOffsetDistribution final : public gpu::WarpDistributionPattern {
}
SmallVector<size_t> newRetIndices;
gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
- rewriter, subgroupOp, newYieldValues, newYieldTypes, newRetIndices);
+ rewriter, warpOp, newYieldValues, newYieldTypes, newRetIndices);
rewriter.setInsertionPointAfter(newWarpOp);
SmallVector<Value> newUpdateOperands;
for (size_t i : newRetIndices) {
@@ -758,10 +767,10 @@ struct UpdateNdOffsetDistribution final : public gpu::WarpDistributionPattern {
/// ```
struct PrefetchNdDistribution final : public gpu::WarpDistributionPattern {
using gpu::WarpDistributionPattern::WarpDistributionPattern;
- LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
+ LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
auto yield = cast<gpu::YieldOp>(
- subgroupOp.getBodyRegion().getBlocks().begin()->getTerminator());
+ warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
Operation *lastNode = yield->getPrevNode();
auto prefetchOp = dyn_cast_or_null<xegpu::PrefetchNdOp>(lastNode);
if (!prefetchOp)
@@ -775,7 +784,7 @@ struct PrefetchNdDistribution final : public gpu::WarpDistributionPattern {
SmallVector<Type, 1> newYieldTypes = {prefetchOp.getTensorDescType()};
SmallVector<size_t> newRetIndices;
gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
- rewriter, subgroupOp, newYieldValues, newYieldTypes, newRetIndices);
+ rewriter, warpOp, newYieldValues, newYieldTypes, newRetIndices);
// Create a new prefetch op outside the warp op with the updated tensor
// descriptor type. The source tensor descriptor requires type resolution.
xegpu::TensorDescType newTensorDescTy =
@@ -795,17 +804,17 @@ struct PrefetchNdDistribution final : public gpu::WarpDistributionPattern {
/// region. This will simply move the barrier op outside of the warp op.
struct GpuBarrierDistribution final : public gpu::WarpDistributionPattern {
using gpu::WarpDistributionPattern::WarpDistributionPattern;
- LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
+ LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
auto yield = cast<gpu::YieldOp>(
- subgroupOp.getBodyRegion().getBlocks().begin()->getTerminator());
+ warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
Operation *lastNode = yield->getPrevNode();
// The last node must be a gpu::BarrierOp.
auto barrierOp = dyn_cast_or_null<gpu::BarrierOp>(lastNode);
if (!barrierOp)
return failure();
// Move the barrier op outside of the warp op.
- rewriter.setInsertionPointAfter(subgroupOp);
+ rewriter.setInsertionPointAfter(warpOp);
rewriter.create<gpu::BarrierOp>(
barrierOp.getLoc(), barrierOp->getResultTypes(),
barrierOp->getOperands(), barrierOp->getAttrs());
diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
index 3d91b2269bc4b..0b078727d4fe5 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
@@ -95,10 +95,10 @@ gpu.module @test {
// -----
// CHECK-LABEL: gpu.func @load_dpas_store
// CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: memref<8x16xf16>, %[[ARG1:[0-9a-zA-Z]+]]: memref<16x16xf16>, %[[ARG2:[0-9a-zA-Z]+]]: memref<8x16xf32>) {
-// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG1]][%{{.*}}] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
-// CHECK: %[[T1:.*]] = xegpu.load_nd %[[T0]] <{packed}> : !xegpu.tensor_desc<16x16xf16> -> vector<16xf16>
// CHECK: %[[T2:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%{{.*}}] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
// CHECK: %[[T3:.*]] = xegpu.load_nd %[[T2]] : !xegpu.tensor_desc<8x16xf16> -> vector<8xf16>
+// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG1]][%{{.*}}] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
+// CHECK: %[[T1:.*]] = xegpu.load_nd %[[T0]] <{packed}> : !xegpu.tensor_desc<16x16xf16> -> vector<16xf16>
// CHECK-DAG: %[[T4:.*]] = xegpu.dpas %[[T3]], %[[T1]] : vector<8xf16>, vector<16xf16> -> vector<8xf32>
// CHECK-DAG: %[[T5:.*]] = xegpu.create_nd_tdesc %[[ARG2]][%{{.*}}] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
// CHECK: xegpu.store_nd %[[T4]], %[[T5]] : vector<8xf32>, !xegpu.tensor_desc<8x16xf32>
@@ -120,10 +120,10 @@ gpu.module @test {
// -----
// CHECK-LABEL: gpu.func @load_dpas_postop_store
// CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: memref<8x16xf16>, %[[ARG1:[0-9a-zA-Z]+]]: memref<16x16xf16>, %[[ARG2:[0-9a-zA-Z]+]]: memref<8x16xf32>) {
-// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG1]][%{{.*}}] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
-// CHECK: %[[T1:.*]] = xegpu.load_nd %[[T0]] <{packed}> : !xegpu.tensor_desc<16x16xf16> -> vector<16xf16>
// CHECK: %[[T2:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%{{.*}}] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
// CHECK: %[[T3:.*]] = xegpu.load_nd %[[T2]] : !xegpu.tensor_desc<8x16xf16> -> vector<8xf16>
+// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG1]][%{{.*}}] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
+// CHECK: %[[T1:.*]] = xegpu.load_nd %[[T0]] <{packed}> : !xegpu.tensor_desc<16x16xf16> -> vector<16xf16>
// CHECK-DAG: %[[T4:.*]] = xegpu.dpas %[[T3]], %[[T1]] : vector<8xf16>, vector<16xf16> -> vector<8xf32>
// CHECK: %[[T5:.*]] = vector.shape_cast %[[T4]] : vector<8xf32> to vector<8x1xf32>
// CHECK: %[[T6:.*]] = math.exp %[[T5]] {{{.*}}} : vector<8x1xf32>
From 8589fa20cddc6bf75f06e92052906193c03f702d Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Wed, 9 Jul 2025 23:31:05 +0000
Subject: [PATCH 2/2] save work
---
.../mlir/Dialect/XeGPU/Utils/XeGPUUtils.h | 11 ++++
.../Transforms/XeGPUSubgroupDistribute.cpp | 54 ++++++++-----------
mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp | 27 ++++++++++
3 files changed, 61 insertions(+), 31 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
index 6fea10185402a..488f358ff3802 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
@@ -76,6 +76,17 @@ LayoutAttr getLayoutAttr(const Value value);
/// it will check the operand itself and its defining op.
LayoutAttr getLayoutAttr(const OpOperand &opr);
+/// Removes the LayoutAttr for a given OpOperand or OpResult if it exists.
+template <typename T,
+ typename = std::enable_if_t<std::is_same_v<T, OpOperand> ||
+ std::is_same_v<T, OpResult>>>
+void removeLayoutAttr(const T &operandOrResult);
+
+/// Removes the LayoutAttr from each OpOperand and OpResult of the given
+/// operation, if present. If the operation contains regions, the removal
+/// is applied recursively to the nested operations.
+void removeLayoutAttrs(Operation *op);
+
/// Sets the LayoutAttr for a given OpOperand or OpResult by attaching
/// it to the owner's dictionary attributes
template <typename T,
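A minimal usage sketch of the two utilities declared above (assuming an Operation *op inside a pass; the enable_if constraint admits exactly the two call forms shown):

  // Strip a single temporary layout attribute from one operand or result.
  xegpu::removeLayoutAttr(op->getOpOperand(0));
  xegpu::removeLayoutAttr(op->getOpResult(0));
  // Strip all of them, recursing into any regions of `op`.
  xegpu::removeLayoutAttrs(op);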
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index c87b596b4df43..1d0f0faf914d4 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -136,19 +136,6 @@ static Value resolveDistributedTy(Value orig, T expected,
return orig;
}
-/// Helper function to filter out the temporary layout attributes attached
-/// during the layout assignment process. These are not needed after going to
-/// SIMT.
-static SmallVector<NamedAttribute>
-removeTemporaryLayoutAttributes(ArrayRef<NamedAttribute> attrs) {
- SmallVector<NamedAttribute> newAttrs;
- for (NamedAttribute attr : attrs) {
- if (!isa<xegpu::LayoutAttr>(attr.getValue()))
- newAttrs.push_back(attr);
- }
- return newAttrs;
-}
-
/// Helper function to check if the layout is packed. Layout is packed if it is
/// 2D and lane_data[0] != 1 (data packed from col dimension).
static bool hasPackedLayout(xegpu::LayoutAttr layout) {
@@ -412,9 +399,9 @@ struct StoreNdDistribution final : public gpu::WarpDistributionPattern {
resolveDistributedTy(newWarpOp.getResult(newRetIndices[1]),
distributedTensorDescTy, rewriter));
- rewriter.create<xegpu::StoreNdOp>(
- newWarpOp.getLoc(), TypeRange{}, newStoreOperands,
- removeTemporaryLayoutAttributes(storeOp->getAttrs()));
+ auto newStoreOp = rewriter.create<xegpu::StoreNdOp>(
+ newWarpOp.getLoc(), TypeRange{}, newStoreOperands, storeOp->getAttrs());
+ xegpu::removeLayoutAttrs(newStoreOp);
rewriter.eraseOp(storeOp);
return success();
}
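The same create-with-original-attrs-then-strip idiom recurs in the load, dpas, update-offset, and prefetch patterns below. Condensed, with a hypothetical SomeOp standing in for the concrete op class:

  // 1. Recreate the op outside the warp op, carrying over all original
  //    attributes, including the temporary xegpu layout attributes.
  auto newOp = rewriter.create<SomeOp>(loc, resultTypes, newOperands,
                                       oldOp->getAttrs());
  // 2. Strip the temporary layout attributes from the clone; they are
  //    only meaningful during layout assignment, not after SIMT lowering.
  xegpu::removeLayoutAttrs(newOp);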
@@ -508,7 +495,8 @@ struct LoadNdDistribution final : public gpu::WarpDistributionPattern {
newWarpOp.getLoc(), loadNdDistValueTyOrFailure.value(),
resolveDistributedTy(newWarpOp->getResult(newRetIndices[0]),
distributedTensorDescTy, rewriter),
- removeTemporaryLayoutAttributes(loadOp->getAttrs()));
+ loadOp->getAttrs());
+ xegpu::removeLayoutAttrs(newLoadOp);
// Set the packed attribute if the layout requires it.
newLoadOp.setPacked(hasPackedLayout(layout));
Value distributedVal = newWarpOp.getResult(operandIdx);
@@ -639,14 +627,16 @@ struct DpasDistribution final : public gpu::WarpDistributionPattern {
resolveDistributedTy(newWarpOp.getResult(newRetIndices[i]),
newDpasOperandExpectedTypes[i], rewriter));
}
- Value newDpasOp = rewriter.create<xegpu::DpasOp>(
- newWarpOp->getLoc(), distributedResultTy, newDpasOperands,
- removeTemporaryLayoutAttributes(dpasOp->getAttrs()));
+ auto newDpasOp =
+ rewriter.create<xegpu::DpasOp>(newWarpOp->getLoc(), distributedResultTy,
+ newDpasOperands, dpasOp->getAttrs());
+ xegpu::removeLayoutAttrs(newDpasOp);
Value distributedVal = newWarpOp.getResult(operandIdx);
// Resolve the output type.
- newDpasOp = resolveDistributedTy(
- newDpasOp, distResultTypeByWarpOpOrFailure.value(), rewriter);
- rewriter.replaceAllUsesWith(distributedVal, newDpasOp);
+ Value typeResolved =
+ resolveDistributedTy(newDpasOp.getResult(),
+ distResultTypeByWarpOpOrFailure.value(), rewriter);
+ rewriter.replaceAllUsesWith(distributedVal, typeResolved);
return success();
}
};
@@ -726,14 +716,15 @@ struct UpdateNdOffsetDistribution final : public gpu::WarpDistributionPattern {
}
}
// Create a new update op outside the warp op.
- Value newUpdateOp = rewriter.create<xegpu::UpdateNdOffsetOp>(
+ auto newUpdateOp = rewriter.create<xegpu::UpdateNdOffsetOp>(
newWarpOp.getLoc(), newTensorDescTy, newUpdateOperands,
- removeTemporaryLayoutAttributes(updateOp->getAttrs()));
+ updateOp->getAttrs());
+ xegpu::removeLayoutAttrs(newUpdateOp);
Value distributedVal = newWarpOp.getResult(operandIdx);
// Resolve the distributed type with the original type.
- newUpdateOp =
- resolveDistributedTy(newUpdateOp, distributedVal.getType(), rewriter);
- rewriter.replaceAllUsesWith(distributedVal, newUpdateOp);
+ Value typeResolved = resolveDistributedTy(
+ newUpdateOp.getResult(), distributedVal.getType(), rewriter);
+ rewriter.replaceAllUsesWith(distributedVal, typeResolved);
return success();
}
};
@@ -792,9 +783,10 @@ struct PrefetchNdDistribution final : public gpu::WarpDistributionPattern {
rewriter.setInsertionPointAfter(newWarpOp);
SmallVector<Value> newPrefetchOperands = {resolveDistributedTy(
newWarpOp.getResult(newRetIndices[0]), newTensorDescTy, rewriter)};
- rewriter.create<xegpu::PrefetchNdOp>(
- newWarpOp.getLoc(), TypeRange{}, newPrefetchOperands,
- removeTemporaryLayoutAttributes(prefetchOp->getAttrs()));
+    auto newPrefetchOp = rewriter.create<xegpu::PrefetchNdOp>(
+        newWarpOp.getLoc(), TypeRange{}, newPrefetchOperands,
+        prefetchOp->getAttrs());
+    xegpu::removeLayoutAttrs(newPrefetchOp);
rewriter.eraseOp(prefetchOp);
return success();
}
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index 6b85a66a8bd36..64d58153baa74 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -184,6 +184,33 @@ void xegpu::setLayoutAttrs(Operation *op,
});
}
+template <typename T, typename>
+void xegpu::removeLayoutAttr(const T &operandOrResult) {
+ Operation *owner = operandOrResult.getOwner();
+ std::string name = xegpu::getLayoutName(operandOrResult);
+ if (owner->hasAttrOfType<LayoutAttr>(name))
+ owner->removeAttr(name);
+}
+
+// Explicit instantiation for OpResult
+template void
+xegpu::removeLayoutAttr<mlir::OpResult>(const mlir::OpResult &result);
+
+// Explicit instantiation for OpOperand
+template void
+xegpu::removeLayoutAttr<mlir::OpOperand>(const mlir::OpOperand &operand);
+
+void xegpu::removeLayoutAttrs(Operation *op) {
+ op->walk([&](Operation *nestOp) {
+ for (OpOperand &opr : nestOp->getOpOperands()) {
+ removeLayoutAttr(opr);
+ }
+ for (OpResult result : nestOp->getOpResults()) {
+ removeLayoutAttr(result);
+ }
+ });
+}
+
SmallVector<Value>
xegpu::extractVectorsWithShapeFromValue(OpBuilder &builder, Location loc,
Value value, ArrayRef<int64_t> shape) {
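For context on the two `template void ...` lines above: the template body lives in the .cpp file, so the explicit instantiations are what keep callers of the header-declared template from hitting linker errors; the pair emitted here covers the only two specializations the enable_if constraint permits. A standalone sketch of the general pattern, with hypothetical names:

  // widget.h -- declaration only; the definition is not visible here.
  template <typename T> void reset(T &obj);

  // widget.cpp -- definition plus the only instantiations that get emitted.
  template <typename T> void reset(T &obj) { obj = T(); }
  template void reset<int>(int &);     // explicit instantiation
  template void reset<float>(float &); // explicit instantiation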