[Mlir-commits] [mlir] [mlir][xegpu] Minor fixes in XeGPU subgroup distribution. (PR #147846)

Charitha Saumya llvmlistbot at llvm.org
Wed Jul 9 15:49:14 PDT 2025


https://github.com/charithaintc created https://github.com/llvm/llvm-project/pull/147846

This PR addresses the following issues.

1. Add the missing attributes (including workgroup and private attributions) when creating the new GPU funcOp in the `MoveFuncBodyToWarpExecuteOnLane0` pattern.
2. Fix a bug in LoadNd distribution: ensure that the LoadNd op being distributed is the last op in the warpOp region (needed to preserve memory-op ordering during distribution).

>From 141c55145b1d766ce94b718f4e311f187c64ba04 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Wed, 9 Jul 2025 22:44:43 +0000
Subject: [PATCH] save work

---
 .../Transforms/XeGPUSubgroupDistribute.cpp    | 87 ++++++++++---------
 .../Dialect/XeGPU/subgroup-distribute.mlir    |  8 +-
 2 files changed, 52 insertions(+), 43 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index c072557c2bd22..c87b596b4df43 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -34,6 +34,7 @@
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/SmallVectorExtras.h"
 
 namespace mlir {
 namespace xegpu {
@@ -197,9 +198,17 @@ struct MoveFuncBodyToWarpExecuteOnLane0
           return isa<gpu::WarpExecuteOnLane0Op>(op);
         }))
       return failure();
-    // Create a new function with the same signature.
+    // Create a new function with the same signature and same attributes.
+    SmallVector<Type> workgroupAttributionsTypes =
+        llvm::map_to_vector(gpuFuncOp.getWorkgroupAttributions(),
+                            [](BlockArgument arg) { return arg.getType(); });
+    SmallVector<Type> privateAttributionsTypes =
+        llvm::map_to_vector(gpuFuncOp.getPrivateAttributions(),
+                            [](BlockArgument arg) { return arg.getType(); });
     auto newGpuFunc = rewriter.create<gpu::GPUFuncOp>(
-        gpuFuncOp.getLoc(), gpuFuncOp.getName(), gpuFuncOp.getFunctionType());
+        gpuFuncOp.getLoc(), gpuFuncOp.getName(), gpuFuncOp.getFunctionType(),
+        workgroupAttributionsTypes, privateAttributionsTypes);
+    newGpuFunc->setAttrs(gpuFuncOp->getAttrs());
     // Create a WarpExecuteOnLane0Op with same arguments and results as the
     // original gpuFuncOp.
     rewriter.setInsertionPointToEnd(&newGpuFunc.getFunctionBody().front());
@@ -265,13 +274,13 @@ struct MoveFuncBodyToWarpExecuteOnLane0
 /// ```
 struct CreateNdDescDistribution final : public gpu::WarpDistributionPattern {
   using gpu::WarpDistributionPattern::WarpDistributionPattern;
-  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
     OpOperand *operand =
-        getWarpResult(subgroupOp, llvm::IsaPred<xegpu::CreateNdDescOp>);
+        getWarpResult(warpOp, llvm::IsaPred<xegpu::CreateNdDescOp>);
     if (!operand)
       return rewriter.notifyMatchFailure(
-          subgroupOp, "warp result is not a xegpu::CreateNdDesc op");
+          warpOp, "warp result is not a xegpu::CreateNdDesc op");
     auto descOp = operand->get().getDefiningOp<xegpu::CreateNdDescOp>();
     unsigned operandIdx = operand->getOperandNumber();
 
@@ -288,9 +297,9 @@ struct CreateNdDescDistribution final : public gpu::WarpDistributionPattern {
       newYieldValues.push_back(operand);
       newYieldTypes.push_back(operand.getType());
     }
-    rewriter.setInsertionPoint(subgroupOp);
+    rewriter.setInsertionPoint(warpOp);
     gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-        rewriter, subgroupOp, /* new yieled values = */ newYieldValues,
+        rewriter, warpOp, /* new yieled values = */ newYieldValues,
         /* new yielded types = */ newYieldTypes, newRetIndices);
 
     SmallVector<Value> newDescOperands;
@@ -347,10 +356,10 @@ struct CreateNdDescDistribution final : public gpu::WarpDistributionPattern {
 /// ```
 struct StoreNdDistribution final : public gpu::WarpDistributionPattern {
   using gpu::WarpDistributionPattern::WarpDistributionPattern;
-  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
     auto yield = cast<gpu::YieldOp>(
-        subgroupOp.getBodyRegion().getBlocks().begin()->getTerminator());
+        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
     Operation *lastNode = yield->getPrevNode();
     auto storeOp = dyn_cast_or_null<xegpu::StoreNdOp>(lastNode);
     if (!storeOp)
@@ -372,7 +381,7 @@ struct StoreNdDistribution final : public gpu::WarpDistributionPattern {
 
     SmallVector<size_t> newRetIndices;
     gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-        rewriter, subgroupOp,
+        rewriter, warpOp,
         /* new yielded values = */
         ValueRange{storeOp.getValue(), storeOp.getTensorDesc()},
         /* new yielded types = */
@@ -449,21 +458,22 @@ struct StoreNdDistribution final : public gpu::WarpDistributionPattern {
 /// ```
 struct LoadNdDistribution final : public gpu::WarpDistributionPattern {
   using gpu::WarpDistributionPattern::WarpDistributionPattern;
-  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
-    OpOperand *operand =
-        getWarpResult(subgroupOp, llvm::IsaPred<xegpu::LoadNdOp>);
+    OpOperand *operand = getWarpResult(warpOp, [&](Operation *op) {
+      if (!isa<xegpu::LoadNdOp>(op))
+        return false;
+      // Make sure the same load op is the last operation in the warp op body.
+      // This ensures that the load op is not sunk earlier, violating any
+      // barrier synchronizations.
+      auto yield = cast<gpu::YieldOp>(
+          warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+      return yield->getPrevNode() == op;
+    });
+
     if (!operand)
       return rewriter.notifyMatchFailure(
-          subgroupOp, "warp result is not a xegpu::LoadNd op");
-    // Make sure the load op is the last operation in the warp op body. This
-    // ensure that load op is not sinked earlier violating any barrier
-    // synchronizations.
-    auto yield = cast<gpu::YieldOp>(
-        subgroupOp.getBodyRegion().getBlocks().begin()->getTerminator());
-    Operation *lastNode = yield->getPrevNode();
-    if (!dyn_cast_or_null<xegpu::LoadNdOp>(lastNode))
-      return failure();
+          warpOp, "warp result is not a xegpu::LoadNd op");
 
     auto loadOp = operand->get().getDefiningOp<xegpu::LoadNdOp>();
     xegpu::TensorDescType tensorDescTy = loadOp.getTensorDescType();
@@ -474,11 +484,11 @@ struct LoadNdDistribution final : public gpu::WarpDistributionPattern {
 
     unsigned operandIdx = operand->getOperandNumber();
     VectorType distributedTypeByWarpOp =
-        cast<VectorType>(subgroupOp.getResult(operandIdx).getType());
+        cast<VectorType>(warpOp.getResult(operandIdx).getType());
 
     SmallVector<size_t> newRetIndices;
     gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-        rewriter, subgroupOp,
+        rewriter, warpOp,
         /* new yielded values = */ loadOp.getTensorDesc(),
         /* new yielded types = */ tensorDescTy, newRetIndices);
 
@@ -548,12 +558,11 @@ struct LoadNdDistribution final : public gpu::WarpDistributionPattern {
 /// ```
 struct DpasDistribution final : public gpu::WarpDistributionPattern {
   using gpu::WarpDistributionPattern::WarpDistributionPattern;
-  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
-    OpOperand *operand =
-        getWarpResult(subgroupOp, llvm::IsaPred<xegpu::DpasOp>);
+    OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<xegpu::DpasOp>);
     if (!operand)
-      return rewriter.notifyMatchFailure(subgroupOp,
+      return rewriter.notifyMatchFailure(warpOp,
                                          "warp result is not a xegpu::Dpas op");
 
     auto dpasOp = operand->get().getDefiningOp<xegpu::DpasOp>();
@@ -599,7 +608,7 @@ struct DpasDistribution final : public gpu::WarpDistributionPattern {
     // Create a new warp op without the dpas.
     SmallVector<size_t> newRetIndices;
     gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-        rewriter, subgroupOp, newYieldValues, newYieldTypes, newRetIndices);
+        rewriter, warpOp, newYieldValues, newYieldTypes, newRetIndices);
 
     FailureOr<VectorType> expectedDistLhsTyOrFailure =
         xegpu::getDistributedVectorType(dpasOp.getLhsType(), layoutA);
@@ -678,13 +687,13 @@ struct DpasDistribution final : public gpu::WarpDistributionPattern {
 /// ```
 struct UpdateNdOffsetDistribution final : public gpu::WarpDistributionPattern {
   using gpu::WarpDistributionPattern::WarpDistributionPattern;
-  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
     OpOperand *operand =
-        getWarpResult(subgroupOp, llvm::IsaPred<xegpu::UpdateNdOffsetOp>);
+        getWarpResult(warpOp, llvm::IsaPred<xegpu::UpdateNdOffsetOp>);
     if (!operand)
       return rewriter.notifyMatchFailure(
-          subgroupOp, "warp result is not a xegpu::UpdateNdOffset op");
+          warpOp, "warp result is not a xegpu::UpdateNdOffset op");
     auto updateOp = operand->get().getDefiningOp<xegpu::UpdateNdOffsetOp>();
     unsigned operandIdx = operand->getOperandNumber();
     // new update op does not have layout attribute.
@@ -703,7 +712,7 @@ struct UpdateNdOffsetDistribution final : public gpu::WarpDistributionPattern {
     }
     SmallVector<size_t> newRetIndices;
     gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-        rewriter, subgroupOp, newYieldValues, newYieldTypes, newRetIndices);
+        rewriter, warpOp, newYieldValues, newYieldTypes, newRetIndices);
     rewriter.setInsertionPointAfter(newWarpOp);
     SmallVector<Value> newUpdateOperands;
     for (size_t i : newRetIndices) {
@@ -758,10 +767,10 @@ struct UpdateNdOffsetDistribution final : public gpu::WarpDistributionPattern {
 /// ```
 struct PrefetchNdDistribution final : public gpu::WarpDistributionPattern {
   using gpu::WarpDistributionPattern::WarpDistributionPattern;
-  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
     auto yield = cast<gpu::YieldOp>(
-        subgroupOp.getBodyRegion().getBlocks().begin()->getTerminator());
+        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
     Operation *lastNode = yield->getPrevNode();
     auto prefetchOp = dyn_cast_or_null<xegpu::PrefetchNdOp>(lastNode);
     if (!prefetchOp)
@@ -775,7 +784,7 @@ struct PrefetchNdDistribution final : public gpu::WarpDistributionPattern {
     SmallVector<Type, 1> newYieldTypes = {prefetchOp.getTensorDescType()};
     SmallVector<size_t> newRetIndices;
     gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-        rewriter, subgroupOp, newYieldValues, newYieldTypes, newRetIndices);
+        rewriter, warpOp, newYieldValues, newYieldTypes, newRetIndices);
     // Create a new prefetch op outside the warp op with updated tensor
     // descriptor type. Source tensor descriptor require type resolution.
     xegpu::TensorDescType newTensorDescTy =
@@ -795,17 +804,17 @@ struct PrefetchNdDistribution final : public gpu::WarpDistributionPattern {
 /// region. This will simply move the barrier op outside of the warp op.
 struct GpuBarrierDistribution final : public gpu::WarpDistributionPattern {
   using gpu::WarpDistributionPattern::WarpDistributionPattern;
-  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
     auto yield = cast<gpu::YieldOp>(
-        subgroupOp.getBodyRegion().getBlocks().begin()->getTerminator());
+        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
     Operation *lastNode = yield->getPrevNode();
     // The last node must be a gpu::BarrierOp.
     auto barrierOp = dyn_cast_or_null<gpu::BarrierOp>(lastNode);
     if (!barrierOp)
       return failure();
     // Move the barrier op outside of the warp op.
-    rewriter.setInsertionPointAfter(subgroupOp);
+    rewriter.setInsertionPointAfter(warpOp);
     rewriter.create<gpu::BarrierOp>(
         barrierOp.getLoc(), barrierOp->getResultTypes(),
         barrierOp->getOperands(), barrierOp->getAttrs());
diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
index 3d91b2269bc4b..0b078727d4fe5 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
@@ -95,10 +95,10 @@ gpu.module @test {
 // -----
 // CHECK-LABEL: gpu.func @load_dpas_store
 // CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: memref<8x16xf16>, %[[ARG1:[0-9a-zA-Z]+]]: memref<16x16xf16>, %[[ARG2:[0-9a-zA-Z]+]]: memref<8x16xf32>) {
-// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG1]][%{{.*}}] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
-// CHECK: %[[T1:.*]] = xegpu.load_nd %[[T0]] <{packed}> : !xegpu.tensor_desc<16x16xf16> -> vector<16xf16>
 // CHECK: %[[T2:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%{{.*}}] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
 // CHECK: %[[T3:.*]] = xegpu.load_nd %[[T2]]  : !xegpu.tensor_desc<8x16xf16> -> vector<8xf16>
+// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG1]][%{{.*}}] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
+// CHECK: %[[T1:.*]] = xegpu.load_nd %[[T0]] <{packed}> : !xegpu.tensor_desc<16x16xf16> -> vector<16xf16>
 // CHECK-DAG: %[[T4:.*]] = xegpu.dpas %[[T3]], %[[T1]] : vector<8xf16>, vector<16xf16> -> vector<8xf32>
 // CHECK-DAG: %[[T5:.*]] = xegpu.create_nd_tdesc %[[ARG2]][%{{.*}}] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
 // CHECK: xegpu.store_nd %[[T4]], %[[T5]]  : vector<8xf32>, !xegpu.tensor_desc<8x16xf32>
@@ -120,10 +120,10 @@ gpu.module @test {
 // -----
 // CHECK-LABEL: gpu.func @load_dpas_postop_store
 // CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: memref<8x16xf16>, %[[ARG1:[0-9a-zA-Z]+]]: memref<16x16xf16>, %[[ARG2:[0-9a-zA-Z]+]]: memref<8x16xf32>) {
-// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG1]][%{{.*}}] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
-// CHECK: %[[T1:.*]] = xegpu.load_nd %[[T0]] <{packed}> : !xegpu.tensor_desc<16x16xf16> -> vector<16xf16>
 // CHECK: %[[T2:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%{{.*}}] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
 // CHECK: %[[T3:.*]] = xegpu.load_nd %[[T2]]  : !xegpu.tensor_desc<8x16xf16> -> vector<8xf16>
+// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG1]][%{{.*}}] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
+// CHECK: %[[T1:.*]] = xegpu.load_nd %[[T0]] <{packed}> : !xegpu.tensor_desc<16x16xf16> -> vector<16xf16>
 // CHECK-DAG: %[[T4:.*]] = xegpu.dpas %[[T3]], %[[T1]] : vector<8xf16>, vector<16xf16> -> vector<8xf32>
 // CHECK: %[[T5:.*]] = vector.shape_cast %[[T4]] : vector<8xf32> to vector<8x1xf32>
 // CHECK: %[[T6:.*]] = math.exp %[[T5]] {{{.*}}} : vector<8x1xf32>



More information about the Mlir-commits mailing list