[Mlir-commits] [mlir] [mlir][xegpu] Minor fixes in XeGPU subgroup distribution. (PR #147846)

Charitha Saumya llvmlistbot at llvm.org
Tue Jul 15 09:45:57 PDT 2025


https://github.com/charithaintc updated https://github.com/llvm/llvm-project/pull/147846

>From 141c55145b1d766ce94b718f4e311f187c64ba04 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Wed, 9 Jul 2025 22:44:43 +0000
Subject: [PATCH 1/2] save work

---
 .../Transforms/XeGPUSubgroupDistribute.cpp    | 87 ++++++++++---------
 .../Dialect/XeGPU/subgroup-distribute.mlir    |  8 +-
 2 files changed, 52 insertions(+), 43 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index c072557c2bd22..c87b596b4df43 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -34,6 +34,7 @@
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/SmallVectorExtras.h"
 
 namespace mlir {
 namespace xegpu {
@@ -197,9 +198,17 @@ struct MoveFuncBodyToWarpExecuteOnLane0
           return isa<gpu::WarpExecuteOnLane0Op>(op);
         }))
       return failure();
-    // Create a new function with the same signature.
+    // Create a new function with the same signature and same attributes.
+    SmallVector<Type> workgroupAttributionsTypes =
+        llvm::map_to_vector(gpuFuncOp.getWorkgroupAttributions(),
+                            [](BlockArgument arg) { return arg.getType(); });
+    SmallVector<Type> privateAttributionsTypes =
+        llvm::map_to_vector(gpuFuncOp.getPrivateAttributions(),
+                            [](BlockArgument arg) { return arg.getType(); });
     auto newGpuFunc = rewriter.create<gpu::GPUFuncOp>(
-        gpuFuncOp.getLoc(), gpuFuncOp.getName(), gpuFuncOp.getFunctionType());
+        gpuFuncOp.getLoc(), gpuFuncOp.getName(), gpuFuncOp.getFunctionType(),
+        workgroupAttributionsTypes, privateAttributionsTypes);
+    newGpuFunc->setAttrs(gpuFuncOp->getAttrs());
     // Create a WarpExecuteOnLane0Op with same arguments and results as the
     // original gpuFuncOp.
     rewriter.setInsertionPointToEnd(&newGpuFunc.getFunctionBody().front());
@@ -265,13 +274,13 @@ struct MoveFuncBodyToWarpExecuteOnLane0
 /// ```
 struct CreateNdDescDistribution final : public gpu::WarpDistributionPattern {
   using gpu::WarpDistributionPattern::WarpDistributionPattern;
-  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
     OpOperand *operand =
-        getWarpResult(subgroupOp, llvm::IsaPred<xegpu::CreateNdDescOp>);
+        getWarpResult(warpOp, llvm::IsaPred<xegpu::CreateNdDescOp>);
     if (!operand)
       return rewriter.notifyMatchFailure(
-          subgroupOp, "warp result is not a xegpu::CreateNdDesc op");
+          warpOp, "warp result is not a xegpu::CreateNdDesc op");
     auto descOp = operand->get().getDefiningOp<xegpu::CreateNdDescOp>();
     unsigned operandIdx = operand->getOperandNumber();
 
@@ -288,9 +297,9 @@ struct CreateNdDescDistribution final : public gpu::WarpDistributionPattern {
       newYieldValues.push_back(operand);
       newYieldTypes.push_back(operand.getType());
     }
-    rewriter.setInsertionPoint(subgroupOp);
+    rewriter.setInsertionPoint(warpOp);
     gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-        rewriter, subgroupOp, /* new yieled values = */ newYieldValues,
+        rewriter, warpOp, /* new yieled values = */ newYieldValues,
         /* new yielded types = */ newYieldTypes, newRetIndices);
 
     SmallVector<Value> newDescOperands;
@@ -347,10 +356,10 @@ struct CreateNdDescDistribution final : public gpu::WarpDistributionPattern {
 /// ```
 struct StoreNdDistribution final : public gpu::WarpDistributionPattern {
   using gpu::WarpDistributionPattern::WarpDistributionPattern;
-  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
     auto yield = cast<gpu::YieldOp>(
-        subgroupOp.getBodyRegion().getBlocks().begin()->getTerminator());
+        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
     Operation *lastNode = yield->getPrevNode();
     auto storeOp = dyn_cast_or_null<xegpu::StoreNdOp>(lastNode);
     if (!storeOp)
@@ -372,7 +381,7 @@ struct StoreNdDistribution final : public gpu::WarpDistributionPattern {
 
     SmallVector<size_t> newRetIndices;
     gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-        rewriter, subgroupOp,
+        rewriter, warpOp,
         /* new yielded values = */
         ValueRange{storeOp.getValue(), storeOp.getTensorDesc()},
         /* new yielded types = */
@@ -449,21 +458,22 @@ struct StoreNdDistribution final : public gpu::WarpDistributionPattern {
 /// ```
 struct LoadNdDistribution final : public gpu::WarpDistributionPattern {
   using gpu::WarpDistributionPattern::WarpDistributionPattern;
-  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
-    OpOperand *operand =
-        getWarpResult(subgroupOp, llvm::IsaPred<xegpu::LoadNdOp>);
+    OpOperand *operand = getWarpResult(warpOp, [&](Operation *op) {
+      if (!isa<xegpu::LoadNdOp>(op))
+        return false;
+      // Make sure the same load op is the last operation in the warp op body.
+      // This ensures the load op is not sunk earlier, violating any barrier
+      // synchronizations.
+      auto yield = cast<gpu::YieldOp>(
+          warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+      return yield->getPrevNode() == op;
+    });
+
     if (!operand)
       return rewriter.notifyMatchFailure(
-          subgroupOp, "warp result is not a xegpu::LoadNd op");
-    // Make sure the load op is the last operation in the warp op body. This
-    // ensure that load op is not sinked earlier violating any barrier
-    // synchronizations.
-    auto yield = cast<gpu::YieldOp>(
-        subgroupOp.getBodyRegion().getBlocks().begin()->getTerminator());
-    Operation *lastNode = yield->getPrevNode();
-    if (!dyn_cast_or_null<xegpu::LoadNdOp>(lastNode))
-      return failure();
+          warpOp, "warp result is not a xegpu::LoadNd op");
 
     auto loadOp = operand->get().getDefiningOp<xegpu::LoadNdOp>();
     xegpu::TensorDescType tensorDescTy = loadOp.getTensorDescType();
@@ -474,11 +484,11 @@ struct LoadNdDistribution final : public gpu::WarpDistributionPattern {
 
     unsigned operandIdx = operand->getOperandNumber();
     VectorType distributedTypeByWarpOp =
-        cast<VectorType>(subgroupOp.getResult(operandIdx).getType());
+        cast<VectorType>(warpOp.getResult(operandIdx).getType());
 
     SmallVector<size_t> newRetIndices;
     gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-        rewriter, subgroupOp,
+        rewriter, warpOp,
         /* new yielded values = */ loadOp.getTensorDesc(),
         /* new yielded types = */ tensorDescTy, newRetIndices);
 
@@ -548,12 +558,11 @@ struct LoadNdDistribution final : public gpu::WarpDistributionPattern {
 /// ```
 struct DpasDistribution final : public gpu::WarpDistributionPattern {
   using gpu::WarpDistributionPattern::WarpDistributionPattern;
-  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
-    OpOperand *operand =
-        getWarpResult(subgroupOp, llvm::IsaPred<xegpu::DpasOp>);
+    OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<xegpu::DpasOp>);
     if (!operand)
-      return rewriter.notifyMatchFailure(subgroupOp,
+      return rewriter.notifyMatchFailure(warpOp,
                                          "warp result is not a xegpu::Dpas op");
 
     auto dpasOp = operand->get().getDefiningOp<xegpu::DpasOp>();
@@ -599,7 +608,7 @@ struct DpasDistribution final : public gpu::WarpDistributionPattern {
     // Create a new warp op without the dpas.
     SmallVector<size_t> newRetIndices;
     gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-        rewriter, subgroupOp, newYieldValues, newYieldTypes, newRetIndices);
+        rewriter, warpOp, newYieldValues, newYieldTypes, newRetIndices);
 
     FailureOr<VectorType> expectedDistLhsTyOrFailure =
         xegpu::getDistributedVectorType(dpasOp.getLhsType(), layoutA);
@@ -678,13 +687,13 @@ struct DpasDistribution final : public gpu::WarpDistributionPattern {
 /// ```
 struct UpdateNdOffsetDistribution final : public gpu::WarpDistributionPattern {
   using gpu::WarpDistributionPattern::WarpDistributionPattern;
-  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
     OpOperand *operand =
-        getWarpResult(subgroupOp, llvm::IsaPred<xegpu::UpdateNdOffsetOp>);
+        getWarpResult(warpOp, llvm::IsaPred<xegpu::UpdateNdOffsetOp>);
     if (!operand)
       return rewriter.notifyMatchFailure(
-          subgroupOp, "warp result is not a xegpu::UpdateNdOffset op");
+          warpOp, "warp result is not a xegpu::UpdateNdOffset op");
     auto updateOp = operand->get().getDefiningOp<xegpu::UpdateNdOffsetOp>();
     unsigned operandIdx = operand->getOperandNumber();
     // new update op does not have layout attribute.
@@ -703,7 +712,7 @@ struct UpdateNdOffsetDistribution final : public gpu::WarpDistributionPattern {
     }
     SmallVector<size_t> newRetIndices;
     gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-        rewriter, subgroupOp, newYieldValues, newYieldTypes, newRetIndices);
+        rewriter, warpOp, newYieldValues, newYieldTypes, newRetIndices);
     rewriter.setInsertionPointAfter(newWarpOp);
     SmallVector<Value> newUpdateOperands;
     for (size_t i : newRetIndices) {
@@ -758,10 +767,10 @@ struct UpdateNdOffsetDistribution final : public gpu::WarpDistributionPattern {
 /// ```
 struct PrefetchNdDistribution final : public gpu::WarpDistributionPattern {
   using gpu::WarpDistributionPattern::WarpDistributionPattern;
-  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
     auto yield = cast<gpu::YieldOp>(
-        subgroupOp.getBodyRegion().getBlocks().begin()->getTerminator());
+        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
     Operation *lastNode = yield->getPrevNode();
     auto prefetchOp = dyn_cast_or_null<xegpu::PrefetchNdOp>(lastNode);
     if (!prefetchOp)
@@ -775,7 +784,7 @@ struct PrefetchNdDistribution final : public gpu::WarpDistributionPattern {
     SmallVector<Type, 1> newYieldTypes = {prefetchOp.getTensorDescType()};
     SmallVector<size_t> newRetIndices;
     gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-        rewriter, subgroupOp, newYieldValues, newYieldTypes, newRetIndices);
+        rewriter, warpOp, newYieldValues, newYieldTypes, newRetIndices);
     // Create a new prefetch op outside the warp op with updated tensor
     // descriptor type. Source tensor descriptor require type resolution.
     xegpu::TensorDescType newTensorDescTy =
@@ -795,17 +804,17 @@ struct PrefetchNdDistribution final : public gpu::WarpDistributionPattern {
 /// region. This will simply move the barrier op outside of the warp op.
 struct GpuBarrierDistribution final : public gpu::WarpDistributionPattern {
   using gpu::WarpDistributionPattern::WarpDistributionPattern;
-  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
     auto yield = cast<gpu::YieldOp>(
-        subgroupOp.getBodyRegion().getBlocks().begin()->getTerminator());
+        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
     Operation *lastNode = yield->getPrevNode();
     // The last node must be a gpu::BarrierOp.
     auto barrierOp = dyn_cast_or_null<gpu::BarrierOp>(lastNode);
     if (!barrierOp)
       return failure();
     // Move the barrier op outside of the warp op.
-    rewriter.setInsertionPointAfter(subgroupOp);
+    rewriter.setInsertionPointAfter(warpOp);
     rewriter.create<gpu::BarrierOp>(
         barrierOp.getLoc(), barrierOp->getResultTypes(),
         barrierOp->getOperands(), barrierOp->getAttrs());
diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
index 3d91b2269bc4b..0b078727d4fe5 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
@@ -95,10 +95,10 @@ gpu.module @test {
 // -----
 // CHECK-LABEL: gpu.func @load_dpas_store
 // CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: memref<8x16xf16>, %[[ARG1:[0-9a-zA-Z]+]]: memref<16x16xf16>, %[[ARG2:[0-9a-zA-Z]+]]: memref<8x16xf32>) {
-// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG1]][%{{.*}}] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
-// CHECK: %[[T1:.*]] = xegpu.load_nd %[[T0]] <{packed}> : !xegpu.tensor_desc<16x16xf16> -> vector<16xf16>
 // CHECK: %[[T2:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%{{.*}}] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
 // CHECK: %[[T3:.*]] = xegpu.load_nd %[[T2]]  : !xegpu.tensor_desc<8x16xf16> -> vector<8xf16>
+// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG1]][%{{.*}}] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
+// CHECK: %[[T1:.*]] = xegpu.load_nd %[[T0]] <{packed}> : !xegpu.tensor_desc<16x16xf16> -> vector<16xf16>
 // CHECK-DAG: %[[T4:.*]] = xegpu.dpas %[[T3]], %[[T1]] : vector<8xf16>, vector<16xf16> -> vector<8xf32>
 // CHECK-DAG: %[[T5:.*]] = xegpu.create_nd_tdesc %[[ARG2]][%{{.*}}] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
 // CHECK: xegpu.store_nd %[[T4]], %[[T5]]  : vector<8xf32>, !xegpu.tensor_desc<8x16xf32>
@@ -120,10 +120,10 @@ gpu.module @test {
 // -----
 // CHECK-LABEL: gpu.func @load_dpas_postop_store
 // CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: memref<8x16xf16>, %[[ARG1:[0-9a-zA-Z]+]]: memref<16x16xf16>, %[[ARG2:[0-9a-zA-Z]+]]: memref<8x16xf32>) {
-// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG1]][%{{.*}}] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
-// CHECK: %[[T1:.*]] = xegpu.load_nd %[[T0]] <{packed}> : !xegpu.tensor_desc<16x16xf16> -> vector<16xf16>
 // CHECK: %[[T2:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%{{.*}}] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
 // CHECK: %[[T3:.*]] = xegpu.load_nd %[[T2]]  : !xegpu.tensor_desc<8x16xf16> -> vector<8xf16>
+// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG1]][%{{.*}}] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
+// CHECK: %[[T1:.*]] = xegpu.load_nd %[[T0]] <{packed}> : !xegpu.tensor_desc<16x16xf16> -> vector<16xf16>
 // CHECK-DAG: %[[T4:.*]] = xegpu.dpas %[[T3]], %[[T1]] : vector<8xf16>, vector<16xf16> -> vector<8xf32>
 // CHECK: %[[T5:.*]] = vector.shape_cast %[[T4]] : vector<8xf32> to vector<8x1xf32>
 // CHECK: %[[T6:.*]] = math.exp %[[T5]] {{{.*}}} : vector<8x1xf32>

>From 8589fa20cddc6bf75f06e92052906193c03f702d Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Wed, 9 Jul 2025 23:31:05 +0000
Subject: [PATCH 2/2] save work

---
 .../mlir/Dialect/XeGPU/Utils/XeGPUUtils.h     | 11 ++++
 .../Transforms/XeGPUSubgroupDistribute.cpp    | 54 ++++++++-----------
 mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp   | 27 ++++++++++
 3 files changed, 61 insertions(+), 31 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
index 6fea10185402a..488f358ff3802 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
@@ -76,6 +76,17 @@ LayoutAttr getLayoutAttr(const Value value);
 /// it will check the operand itself and its defining op.
 LayoutAttr getLayoutAttr(const OpOperand &opr);
 
+/// Removes the LayoutAttr for a given OpOperand or OpResult if it exists.
+template <typename T,
+          typename = std::enable_if_t<std::is_same_v<T, OpOperand> ||
+                                      std::is_same_v<T, OpResult>>>
+void removeLayoutAttr(const T &operandOrResult);
+
+/// Removes the LayoutAttr for each OpOperand and OpResult of the given
+/// operation if they exist. If the operation contains regions, it is also
+/// applied recursively to the contained operations.
+void removeLayoutAttrs(Operation *op);
+
 /// Sets the LayoutAttr for a given OpOperand or OpResult by attaching
 /// it to the owner's dictionary attributes
 template <typename T,
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index c87b596b4df43..1d0f0faf914d4 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -136,19 +136,6 @@ static Value resolveDistributedTy(Value orig, T expected,
   return orig;
 }
 
-/// Helper function to filter out the temporary layout attributes attached
-/// during the layout assignment process. These are not needed after going to
-/// SIMT.
-static SmallVector<NamedAttribute>
-removeTemporaryLayoutAttributes(ArrayRef<NamedAttribute> attrs) {
-  SmallVector<NamedAttribute> newAttrs;
-  for (NamedAttribute attr : attrs) {
-    if (!isa<xegpu::LayoutAttr>(attr.getValue()))
-      newAttrs.push_back(attr);
-  }
-  return newAttrs;
-}
-
 /// Helper function to check if the layout is packed. Layout is packed if it is
 /// 2D and lane_data[0] != 1 (data packed from col dimension).
 static bool hasPackedLayout(xegpu::LayoutAttr layout) {
@@ -412,9 +399,9 @@ struct StoreNdDistribution final : public gpu::WarpDistributionPattern {
         resolveDistributedTy(newWarpOp.getResult(newRetIndices[1]),
                              distributedTensorDescTy, rewriter));
 
-    rewriter.create<xegpu::StoreNdOp>(
-        newWarpOp.getLoc(), TypeRange{}, newStoreOperands,
-        removeTemporaryLayoutAttributes(storeOp->getAttrs()));
+    auto newStoreOp = rewriter.create<xegpu::StoreNdOp>(
+        newWarpOp.getLoc(), TypeRange{}, newStoreOperands, storeOp->getAttrs());
+    xegpu::removeLayoutAttrs(newStoreOp);
     rewriter.eraseOp(storeOp);
     return success();
   }
@@ -508,7 +495,8 @@ struct LoadNdDistribution final : public gpu::WarpDistributionPattern {
         newWarpOp.getLoc(), loadNdDistValueTyOrFailure.value(),
         resolveDistributedTy(newWarpOp->getResult(newRetIndices[0]),
                              distributedTensorDescTy, rewriter),
-        removeTemporaryLayoutAttributes(loadOp->getAttrs()));
+        loadOp->getAttrs());
+    xegpu::removeLayoutAttrs(newLoadOp);
     // Set the packed attribute if the layout requires it.
     newLoadOp.setPacked(hasPackedLayout(layout));
     Value distributedVal = newWarpOp.getResult(operandIdx);
@@ -639,14 +627,16 @@ struct DpasDistribution final : public gpu::WarpDistributionPattern {
           resolveDistributedTy(newWarpOp.getResult(newRetIndices[i]),
                                newDpasOperandExpectedTypes[i], rewriter));
     }
-    Value newDpasOp = rewriter.create<xegpu::DpasOp>(
-        newWarpOp->getLoc(), distributedResultTy, newDpasOperands,
-        removeTemporaryLayoutAttributes(dpasOp->getAttrs()));
+    auto newDpasOp =
+        rewriter.create<xegpu::DpasOp>(newWarpOp->getLoc(), distributedResultTy,
+                                       newDpasOperands, dpasOp->getAttrs());
+    xegpu::removeLayoutAttrs(newDpasOp);
     Value distributedVal = newWarpOp.getResult(operandIdx);
     // Resolve the output type.
-    newDpasOp = resolveDistributedTy(
-        newDpasOp, distResultTypeByWarpOpOrFailure.value(), rewriter);
-    rewriter.replaceAllUsesWith(distributedVal, newDpasOp);
+    Value typeResolved =
+        resolveDistributedTy(newDpasOp.getResult(),
+                             distResultTypeByWarpOpOrFailure.value(), rewriter);
+    rewriter.replaceAllUsesWith(distributedVal, typeResolved);
     return success();
   }
 };
@@ -726,14 +716,15 @@ struct UpdateNdOffsetDistribution final : public gpu::WarpDistributionPattern {
       }
     }
     // Create a new update op outside the warp op.
-    Value newUpdateOp = rewriter.create<xegpu::UpdateNdOffsetOp>(
+    auto newUpdateOp = rewriter.create<xegpu::UpdateNdOffsetOp>(
         newWarpOp.getLoc(), newTensorDescTy, newUpdateOperands,
-        removeTemporaryLayoutAttributes(updateOp->getAttrs()));
+        updateOp->getAttrs());
+    xegpu::removeLayoutAttrs(newUpdateOp);
     Value distributedVal = newWarpOp.getResult(operandIdx);
     // Resolve the distributed type with the original type.
-    newUpdateOp =
-        resolveDistributedTy(newUpdateOp, distributedVal.getType(), rewriter);
-    rewriter.replaceAllUsesWith(distributedVal, newUpdateOp);
+    Value typeResolved = resolveDistributedTy(
+        newUpdateOp.getResult(), distributedVal.getType(), rewriter);
+    rewriter.replaceAllUsesWith(distributedVal, typeResolved);
     return success();
   }
 };
@@ -792,9 +783,10 @@ struct PrefetchNdDistribution final : public gpu::WarpDistributionPattern {
     rewriter.setInsertionPointAfter(newWarpOp);
     SmallVector<Value> newPrefetchOperands = {resolveDistributedTy(
         newWarpOp.getResult(newRetIndices[0]), newTensorDescTy, rewriter)};
-    rewriter.create<xegpu::PrefetchNdOp>(
-        newWarpOp.getLoc(), TypeRange{}, newPrefetchOperands,
-        removeTemporaryLayoutAttributes(prefetchOp->getAttrs()));
+    auto newPrefetchOp = rewriter.create<xegpu::PrefetchNdOp>(
+        newWarpOp.getLoc(), TypeRange{}, newPrefetchOperands,
+        prefetchOp->getAttrs());
+    xegpu::removeLayoutAttrs(newPrefetchOp);
     rewriter.eraseOp(prefetchOp);
     return success();
   }
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index 6b85a66a8bd36..64d58153baa74 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -184,6 +184,33 @@ void xegpu::setLayoutAttrs(Operation *op,
   });
 }
 
+template <typename T, typename>
+void xegpu::removeLayoutAttr(const T &operandOrResult) {
+  Operation *owner = operandOrResult.getOwner();
+  std::string name = xegpu::getLayoutName(operandOrResult);
+  if (owner->hasAttrOfType<LayoutAttr>(name))
+    owner->removeAttr(name);
+}
+
+// Explicit instantiation for OpResult
+template void
+xegpu::removeLayoutAttr<mlir::OpResult>(const mlir::OpResult &result);
+
+// Explicit instantiation for OpOperand
+template void
+xegpu::removeLayoutAttr<mlir::OpOperand>(const mlir::OpOperand &operand);
+
+void xegpu::removeLayoutAttrs(Operation *op) {
+  op->walk([&](Operation *nestOp) {
+    for (OpOperand &opr : nestOp->getOpOperands()) {
+      removeLayoutAttr(opr);
+    }
+    for (OpResult result : nestOp->getOpResults()) {
+      removeLayoutAttr(result);
+    }
+  });
+}
+
 SmallVector<Value>
 xegpu::extractVectorsWithShapeFromValue(OpBuilder &builder, Location loc,
                                         Value value, ArrayRef<int64_t> shape) {



More information about the Mlir-commits mailing list