[Mlir-commits] [mlir] [MLIR][XeGPU] Add uniform values distribution pattern (PR #176737)

Artem Kroviakov llvmlistbot at llvm.org
Mon Jan 26 03:26:52 PST 2026


https://github.com/akroviakov updated https://github.com/llvm/llvm-project/pull/176737

>From 9ab4768e074b5da16c17f681ea9beec9c9577356 Mon Sep 17 00:00:00 2001
From: Artem Kroviakov <artem.kroviakov at intel.com>
Date: Sun, 18 Jan 2026 16:25:43 +0000
Subject: [PATCH 1/3] [MLIR][XeGPU] Add uniform values "distribution" pattern

---
 .../Vector/Transforms/VectorDistribute.cpp    |  15 ++-
 .../Transforms/XeGPUSubgroupDistribute.cpp    | 101 ++++++++++++++++--
 mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp   |   4 +-
 .../XeGPU/subgroup-distribute-unit.mlir       |  57 ++++++----
 4 files changed, 142 insertions(+), 35 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 5334470e2e3a0..743fb51bab1ab 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -398,7 +398,8 @@ getInnerRegionEscapingValues(WarpExecuteOnLane0Op warpOp, Region &innerRegion,
       Type distType = operand->get().getType();
       if (auto vecType = dyn_cast<VectorType>(distType)) {
         AffineMap map = distributionMapFn(operand->get());
-        distType = getDistributedType(vecType, map, warpOp.getWarpSize());
+        distType = getDistributedType(vecType, map,
+                                      map.isEmpty() ? 1 : warpOp.getWarpSize());
       }
       escapingValueTypes.push_back(operand->get().getType());
       escapingValueDistTypes.push_back(distType);
@@ -1886,7 +1887,9 @@ struct WarpOpScfIfOp : public WarpDistributionPattern {
         // Fallback to affine map if the dist result was not previously recorded
         distType = ifResultDistTypes.count(i)
                        ? ifResultDistTypes[i]
-                       : getDistributedType(vecType, map, warpOp.getWarpSize());
+                       : getDistributedType(
+                             vecType, map,
+                             map.isEmpty() ? 1 : newWarpOp.getWarpSize());
       }
       newIfOpDistResTypes.push_back(distType);
     }
@@ -2075,9 +2078,11 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
         // we can get the distributed type from `forResultDistTypes` map.
         // Otherwise, we compute it using distributionMapFn.
         AffineMap map = distributionMapFn(initArg);
-        distType = forResultDistTypes.count(i)
-                       ? forResultDistTypes[i]
-                       : getDistributedType(vecType, map, warpOp.getWarpSize());
+        distType =
+            forResultDistTypes.count(i)
+                ? forResultDistTypes[i]
+                : getDistributedType(vecType, map,
+                                     map.isEmpty() ? 1 : warpOp.getWarpSize());
       }
       newWarpOpDistTypes.push_back(distType);
     }
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 9113f00ac39f0..30e104271d2b7 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -64,6 +64,14 @@ namespace {
 /// priorities to patterns.
 static constexpr unsigned regularPatternBenefit = 1;
 static constexpr unsigned highPatternBenefit = 2;
+static constexpr unsigned highestPatternBenefit = 3;
+
+enum PatternHierarchy : unsigned {
+  Regular = 1,
+  AboveRegular = 2,
+  High = 3,
+  Highest = 4
+};
 
 /// Helper function to get  distributed vector type for a source vector type
 /// according to the lane_layout. We simply divide each dimension of tensor
@@ -792,9 +800,10 @@ struct PrefetchNdDistribution final : public gpu::WarpDistributionPattern {
     // Collect offsets.
     for (size_t i = 1; i < newRetIndices.size(); ++i)
       newPrefetchOperands.push_back(newWarpOp.getResult(newRetIndices[i]));
-    xegpu::PrefetchNdOp::create(rewriter, newWarpOp.getLoc(), TypeRange{},
-                                newPrefetchOperands, prefetchOp->getAttrs());
-    xegpu::removeLayoutAttrs(prefetchOp);
+    Operation *newPrefetchOp = xegpu::PrefetchNdOp::create(
+        rewriter, newWarpOp.getLoc(), TypeRange{}, newPrefetchOperands,
+        prefetchOp->getAttrs());
+    xegpu::removeLayoutAttrs(newPrefetchOp);
     rewriter.eraseOp(prefetchOp);
     return success();
   }
@@ -1208,6 +1217,77 @@ struct LoadDistribution final : public gpu::WarpDistributionPattern {
   }
 };
 
+// Sink SG-uniform ops. An op is uniform if any of the following is true:
+// 1. It has no vector operands.
+// 2. All of its vector operands and results are uniform.
+// Non-uniform vectors are handled by dedicated patterns.
+// This pattern must have a higher priority than distribution patterns,
+// because a distributable shape may be logically intended as uniform.
+struct SinkUniformOps final : public gpu::WarpDistributionPattern {
+  using gpu::WarpDistributionPattern::WarpDistributionPattern;
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    // Take the op immediately preceding the terminator.
+    Operation *warpRegionPreYieldOp = warpOp.getTerminator()->getPrevNode();
+    // Any ops with nested regions must be handled carefully in dedicated
+    // patterns.
+    if (!warpRegionPreYieldOp || warpRegionPreYieldOp->getNumRegions())
+      return failure();
+
+    int operandIdx = -1;
+    if (warpRegionPreYieldOp->getNumResults()) {
+      OpOperand *operand = getWarpResult(
+          warpOp, [&](Operation *op) { return warpRegionPreYieldOp == op; });
+      if (!operand)
+        return failure();
+      operandIdx = operand->getOperandNumber();
+      if (warpRegionPreYieldOp->getResult(0).getType() !=
+          warpOp.getResult(operandIdx).getType())
+        return rewriter.notifyMatchFailure(warpOp,
+                                           "The op result is not uniform.");
+    }
+
+    // The op must have at least no layout-based operands or results.
+    bool uniformValuesOnly =
+        llvm::all_of(warpRegionPreYieldOp->getResults(), [](Value v) {
+          return !xegpu::getDistributeLayoutAttr(v);
+        });
+    uniformValuesOnly &=
+        llvm::all_of(warpRegionPreYieldOp->getOpOperands(), [](OpOperand &opr) {
+          return !xegpu::getDistributeLayoutAttr(opr);
+        });
+    if (!uniformValuesOnly)
+      return rewriter.notifyMatchFailure(warpOp,
+                                         "Some values are not uniform.");
+
+    // Capture its operands and create a new warp op that yields them.
+    SmallVector<size_t> newRetIndices;
+    SmallVector<Value> operands =
+        llvm::to_vector_of<Value>(warpRegionPreYieldOp->getOperands());
+    SmallVector<Type> operandTypes =
+        llvm::to_vector_of<Type>(warpRegionPreYieldOp->getOperandTypes());
+    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, operands, operandTypes, newRetIndices);
+
+    // Clone the op after the new warp op.
+    rewriter.setInsertionPointAfter(newWarpOp);
+
+    IRMapping operandMapper;
+    for (auto [oldOperandIdx, newOperandIdx] : llvm::enumerate(newRetIndices))
+      operandMapper.map(warpRegionPreYieldOp->getOperand(oldOperandIdx),
+                        newWarpOp->getResult(newOperandIdx));
+    Operation *clonedOp = rewriter.clone(*warpRegionPreYieldOp, operandMapper);
+    if (!clonedOp->getNumResults())
+      rewriter.eraseOp(warpRegionPreYieldOp);
+    else {
+      assert(operandIdx != -1 && "Expected a warp result for the operation");
+      rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIdx),
+                                  clonedOp->getResult(0));
+    }
+    return success();
+  }
+};
+
 /// Helper to rewrite a 2D VectorMultiReductionOp into a sequence of 1D
 /// VectorReductionOps. We also insert layouts for the newly created ops.
 static Value lowerToVectorReductions(TypedValue<VectorType> src,
@@ -2036,14 +2116,16 @@ void xegpu::populateXeGPUSubgroupDistributePatterns(
                StoreMatrixDistribution,
                MemrefExtractAlignedPointerAsIndexDistribution>(
       patterns.getContext(),
-      /*pattern benefit=*/regularPatternBenefit);
+      /*pattern benefit=*/PatternHierarchy::Regular);
   // For following patterns, we need to override the regular vector distribution
   // patterns. Therefore, assign higher benefit.
   patterns
       .add<VectorShapeCastDistribution, VectorExtractStridedSliceDistribution,
            VectorInsertStridedSliceDistribution, VectorBroadcastDistribution>(
           patterns.getContext(),
-          /*pattern benefit=*/highPatternBenefit);
+          /*pattern benefit=*/PatternHierarchy::AboveRegular);
+  patterns.add<SinkUniformOps>(patterns.getContext(),
+                               /*pattern benefit=*/PatternHierarchy::High);
 }
 
 void xegpu::populateXeGPUMoveFuncBodyToWarpOpPatterns(
@@ -2094,10 +2176,9 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
       return AffineMap::get(val.getContext());
     // Get the layout of the vector type.
     xegpu::DistributeLayoutAttr layout = xegpu::getDistributeLayoutAttr(val);
-    // If no layout is specified, that means no distribution.
+    // If no layout is specified, assume uniform case (no distribution).
     if (!layout)
-      return AffineMap::getMultiDimMapWithTargets(vecRank, {},
-                                                  val.getContext());
+      return AffineMap::get(val.getContext());
     // Expecting vector and layout rank to match.
     assert(layout.getRank() == vecRank &&
            "Expecting vector and layout rank to match");
@@ -2133,11 +2214,11 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
 
   vector::populateDistributeReduction(
       patterns, warpReduction,
-      /*pattern benefit=*/regularPatternBenefit);
+      /*pattern benefit=*/PatternHierarchy::Regular);
 
   vector::populatePropagateWarpVectorDistributionPatterns(
       patterns, distributionFn, shuffleFn,
-      /*pattern benefit=*/regularPatternBenefit);
+      /*pattern benefit=*/PatternHierarchy::Regular);
   if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
     signalPassFailure();
     return;
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index 51783b41c4c96..e9d81dd4d1d81 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -361,9 +361,9 @@ bool xegpu::recoverTemporaryLayouts(Operation *rootOp) {
         continue;
       auto layout = xegpu::getDistributeLayoutAttr(operand.get());
       if (!layout) {
-        op->emitError("Could not find layout attribute for operand ")
+        op->emitWarning("Could not find layout attribute for operand ")
             << operand.getOperandNumber() << " of operation " << op->getName();
-        return WalkResult::interrupt();
+        continue;
       }
       xegpu::setDistributeLayoutAttr(operand, layout);
     }
diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir
index b136c89925682..a99f850de6175 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir
@@ -16,7 +16,7 @@ gpu.func @store_nd_1d(%laneid: index) {
   gpu.warp_execute_on_lane_0(%laneid)[16] {
     %0 = "some_op"() : () -> !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
     %cst = "some_op"() : () -> vector<16xf32>
-    xegpu.store_nd %cst, %0 [%c0] {layout_operand_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
+    xegpu.store_nd %cst, %0 [%c0] {layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
       : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
   }
   gpu.return
@@ -39,7 +39,7 @@ gpu.func @store_nd_2d(%laneid : index) {
   gpu.warp_execute_on_lane_0(%laneid)[16] {
     %0 = "some_op"() : () -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
     %cst = "some_op"() : () -> vector<16x16xf16>
-    xegpu.store_nd %cst, %0 [%c0, %c0] {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+    xegpu.store_nd %cst, %0 [%c0, %c0] {layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
       : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
   }
   gpu.return
@@ -60,7 +60,7 @@ gpu.func @load_nd_1d(%laneid: index) {
   %c0 = arith.constant 0 : index
   %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1xf32>) {
     %0 = "some_op"() : () -> !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
-    %1 = xegpu.load_nd %0 [%c0]  {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} :
+    %1 = xegpu.load_nd %0 [%c0]  {layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>} :
       !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>> -> vector<16xf32>
     gpu.yield %1 : vector<16xf32>
   }
@@ -84,7 +84,7 @@ gpu.func @load_nd_2d(%laneid: index) {
   %c0 = arith.constant 0 : index
   %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<16x1xf16>) {
     %0 = "some_op"() : () -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
-    %1 = xegpu.load_nd %0[%c0, %c0]  {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+    %1 = xegpu.load_nd %0[%c0, %c0]  {layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
       : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<16x16xf16>
     gpu.yield %1 : vector<16x16xf16>
   }
@@ -112,7 +112,7 @@ gpu.func @load_nd_array_length(%laneid: index) {
   %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<2x16x1xf16>) {
     %0 = "some_op"() : () -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2 : i64>,
       #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
-    %1 = xegpu.load_nd %0[%c0, %c0]  {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+    %1 = xegpu.load_nd %0[%c0, %c0]  {layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
       : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2 : i64>,
         #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<2x16x16xf16>
     gpu.yield %1 : vector<2x16x16xf16>
@@ -192,7 +192,7 @@ gpu.func @prefetch_2d(%laneid: index) {
     %0 = "some_op"() : ()
       -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
     xegpu.prefetch_nd %0[%c0, %c0]
-      <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+      {layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}
       : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
   }
   gpu.return
@@ -215,7 +215,7 @@ gpu.func @prefetch_1d(%laneid: index) {
     %0 = "some_op"() : ()
       -> !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
     xegpu.prefetch_nd %0[%c0]
-      <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+      {layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>, l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}
       : !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
   }
   gpu.return
@@ -285,14 +285,15 @@ gpu.func @vector_multi_reduction_dim1_distributed_dim0_reduction(%laneid: index)
 
 
 // CHECK-LABEL: gpu.func @vector_multi_reduction_dim1_distributed_dim1_reduction
-// CHECK:      %[[W:.*]] = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<2xf32>) {
+// CHECK:      %[[W:.*]] = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> ({{.*}}) {
 // CHECK-NEXT:   %[[SRC:.*]] = "some_def"() {{.*}} : () -> vector<2x16xf32>
 // CHECK-NEXT:   %[[T2:.*]] = vector.extract %[[SRC]][0] : vector<16xf32> from vector<2x16xf32>
 // CHECK-NEXT:   %[[T3:.*]] = vector.reduction <add>, %[[T2]], %cst : vector<16xf32> into f32
-// CHECK-NEXT:   %[[T4:.*]] = vector.extract %[[SRC]][1] : vector<16xf32> from vector<2x16xf32>
-// CHECK-NEXT:   %[[T5:.*]] = vector.reduction <add>, %[[T4]], %cst : vector<16xf32> into f32
-// CHECK-NEXT:   %[[T6:.*]] = vector.from_elements %[[T3]], %[[T5]] : vector<2xf32>
-// CHECK-NEXT:   gpu.yield %[[T6]] : vector<2xf32>
+// CHECK-NEXT:   %[[T4:.*]] = vector.insert %[[T3]], %cst_0 [0] : f32 into vector<2xf32>
+// CHECK-NEXT:   %[[T5:.*]] = vector.extract %[[SRC]][1] : vector<16xf32> from vector<2x16xf32>
+// CHECK-NEXT:   %[[T6:.*]] = vector.reduction <add>, %[[T5]], %cst : vector<16xf32> into f32
+// CHECK-NEXT:   %[[T7:.*]] = vector.insert %[[T6]], %[[T4]] [1] : f32 into vector<2xf32>
+// CHECK-NEXT:   gpu.yield %[[T7]]
 // CHECK-NEXT: }
 gpu.func @vector_multi_reduction_dim1_distributed_dim1_reduction(%laneid: index) {
   %c0 = arith.constant 0 : index
@@ -317,6 +318,7 @@ gpu.func @vector_multi_reduction_dim1_distributed_dim1_reduction(%laneid: index)
 }
 
 
+
 // CHECK-LABEL:   gpu.func @vector_multi_reduction_dim0_distributed_dim1_reduction
 // CHECK:       %[[ACC:.*]] = arith.constant {{.*}} dense<0.000000e+00> : vector<32xf32>
 // CHECK:       %[[W:.*]]:3 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<2xf32>, vector<2x16xf32>, vector<2xf32>) {
@@ -354,18 +356,19 @@ gpu.func @vector_multi_reduction_dim0_distributed_dim1_reduction(%laneid: index)
 
 
 // CHECK-LABEL: gpu.func @vector_multi_reduction_dim0_distributed_dim0_reduction
-// CHECK:     %[[W:.*]] = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<2xf32>) {
+// CHECK:     %[[W:.*]] = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<2xf32>{{.*}}) {
 // CHECK:       %[[SRC:.*]] = "some_def"() {{.*}} : () -> vector<16x2xf32>
 // CHECK:       %[[T1:.*]] = vector.extract_strided_slice %[[SRC]]
 // CHECK-SAME:    {offsets = [0, 0], sizes = [16, 1], strides = [1, 1]} : vector<16x2xf32> to vector<16x1xf32>
 // CHECK:       %[[T2:.*]] = vector.shape_cast %[[T1]] {{.*}} : vector<16x1xf32> to vector<16xf32>
 // CHECK:       %[[T3:.*]] = vector.reduction <add>, %[[T2]], %{{.*}} : vector<16xf32> into f32
-// CHECK:       %[[T4:.*]] = vector.extract_strided_slice %[[SRC]]
+// CHECK:       %[[T4:.*]] = vector.insert %[[T3]], %cst_0 [0] : f32 into vector<2xf32>
+// CHECK:       %[[T5:.*]] = vector.extract_strided_slice %[[SRC]]
 // CHECK-SAME:     {offsets = [0, 1], sizes = [16, 1], strides = [1, 1]} : vector<16x2xf32> to vector<16x1xf32>
-// CHECK:       %[[T5:.*]] = vector.shape_cast %[[T4]] {{.*}} : vector<16x1xf32> to vector<16xf32>
-// CHECK:       %[[T6:.*]] = vector.reduction <add>, %[[T5]], %{{.*}} : vector<16xf32> into f32
-// CHECK:       %[[T7:.*]] = vector.from_elements %[[T3]], %[[T6]] : vector<2xf32>
-// CHECK:       gpu.yield %[[T7]] : vector<2xf32>
+// CHECK:       %[[T6:.*]] = vector.shape_cast %[[T5]] {{.*}} : vector<16x1xf32> to vector<16xf32>
+// CHECK:       %[[T7:.*]] = vector.reduction <add>, %[[T6]], %{{.*}} : vector<16xf32> into f32
+// CHECK:       %[[T8:.*]] = vector.insert %[[T7]], %[[T4]] [1] : f32 into vector<2xf32>
+// CHECK:       gpu.yield %[[T8]]
 // CHECK:     }
 gpu.func @vector_multi_reduction_dim0_distributed_dim0_reduction(%laneid: index) {
   %c0 = arith.constant 0 : index
@@ -1030,4 +1033,22 @@ gpu.func
   gpu.return
 }
 
+// CHECK-LABEL: gpu.func @vector_shape_cast_scalar_to_vector_uniform
+// CHECK-SAME: (%[[ARG0:.*]]: index)
+// CHECK: %[[R:.*]]:2 = gpu.warp_execute_on_lane_0(%[[ARG0]])[16] -> (vector<16x16xf16>, f16)
+// CHECK: %[[DEF:.*]] = "some_def"()
+// CHECK: %[[BCAST:.*]] = vector.broadcast %[[DEF]] : f16 to vector<16x16xf16>
+// CHECK: gpu.yield %[[BCAST]], %[[DEF]] : vector<16x16xf16>, f16
+// CHECK: %[[RESULT:.*]] = vector.broadcast %[[R]]#1 : f16 to vector<16x16xf16>
+// CHECK: "some_use"(%[[RESULT]])
+  gpu.func @vector_shape_cast_scalar_to_vector_uniform(%arg0: index) {
+    %0 = gpu.warp_execute_on_lane_0(%arg0)[16] -> (vector<16x16xf16>) {
+      %1 = "some_def"() : () -> f16
+      %2 = vector.broadcast %1 : f16 to vector<16x16xf16>
+      gpu.yield %2 : vector<16x16xf16>
+    }
+    "some_use"(%0) : (vector<16x16xf16>) -> ()
+    gpu.return
+  }
+
 }

>From 55e4f7997363db450fa6034aca02bb0b83c617dc Mon Sep 17 00:00:00 2001
From: Artem Kroviakov <artem.kroviakov at intel.com>
Date: Mon, 19 Jan 2026 11:27:33 +0000
Subject: [PATCH 2/3] Cleanup

---
 .../XeGPU/Transforms/XeGPUSubgroupDistribute.cpp   | 14 ++------------
 1 file changed, 2 insertions(+), 12 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 30e104271d2b7..c56793effb2f7 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -62,10 +62,6 @@ namespace {
 /// In certain cases, we may need to favor XeGPU specific distribution patterns
 /// over generic vector distribution patterns. In such cases, we can assign
 /// priorities to patterns.
-static constexpr unsigned regularPatternBenefit = 1;
-static constexpr unsigned highPatternBenefit = 2;
-static constexpr unsigned highestPatternBenefit = 3;
-
 enum PatternHierarchy : unsigned {
   Regular = 1,
   AboveRegular = 2,
@@ -1217,9 +1213,8 @@ struct LoadDistribution final : public gpu::WarpDistributionPattern {
   }
 };
 
-// Sink SG-uniform ops. An op is uniform if any of the following is true:
-// 1. It has no vector operands.
-// 2. All of its vector operands and results are uniform.
+// Sink SG-uniform ops. An op is uniform if any of the following is none
+// of its operands/results has a distribution layout attribute.
 // Non-uniform vectors are handled by dedicated patterns.
 // This pattern must have a higher priority than distribution patterns,
 // because a distributable shape may be logically intended as uniform.
@@ -1233,7 +1228,6 @@ struct SinkUniformOps final : public gpu::WarpDistributionPattern {
     // patterns.
     if (!warpRegionPreYieldOp || warpRegionPreYieldOp->getNumRegions())
       return failure();
-
     int operandIdx = -1;
     if (warpRegionPreYieldOp->getNumResults()) {
       OpOperand *operand = getWarpResult(
@@ -1259,8 +1253,6 @@ struct SinkUniformOps final : public gpu::WarpDistributionPattern {
     if (!uniformValuesOnly)
       return rewriter.notifyMatchFailure(warpOp,
                                          "Some values are not uniform.");
-
-    // Capture its operands and create a new warp op that yields them.
     SmallVector<size_t> newRetIndices;
     SmallVector<Value> operands =
         llvm::to_vector_of<Value>(warpRegionPreYieldOp->getOperands());
@@ -1269,9 +1261,7 @@ struct SinkUniformOps final : public gpu::WarpDistributionPattern {
     gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
         rewriter, warpOp, operands, operandTypes, newRetIndices);
 
-    // Clone the op after the new warp op.
     rewriter.setInsertionPointAfter(newWarpOp);
-
     IRMapping operandMapper;
     for (auto [oldOperandIdx, newOperandIdx] : llvm::enumerate(newRetIndices))
       operandMapper.map(warpRegionPreYieldOp->getOperand(oldOperandIdx),

>From da65a9bc4f112c83b40b7f9ce3cad65ea7e6d3c0 Mon Sep 17 00:00:00 2001
From: Artem Kroviakov <artem.kroviakov at intel.com>
Date: Mon, 26 Jan 2026 11:21:42 +0000
Subject: [PATCH 3/3] Remove hierarchy level

---
 .../Transforms/XeGPUSubgroupDistribute.cpp    | 24 +++++++------------
 1 file changed, 9 insertions(+), 15 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index c56793effb2f7..ac8c02827c1d4 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -62,12 +62,7 @@ namespace {
 /// In certain cases, we may need to favor XeGPU specific distribution patterns
 /// over generic vector distribution patterns. In such cases, we can assign
 /// priorities to patterns.
-enum PatternHierarchy : unsigned {
-  Regular = 1,
-  AboveRegular = 2,
-  High = 3,
-  Highest = 4
-};
+enum PatternHierarchy : unsigned { Regular = 1, AboveRegular = 2 };
 
 /// Helper function to get  distributed vector type for a source vector type
 /// according to the lane_layout. We simply divide each dimension of tensor
@@ -1213,11 +1208,12 @@ struct LoadDistribution final : public gpu::WarpDistributionPattern {
   }
 };
 
-// Sink SG-uniform ops. An op is uniform if any of the following is none
+// Sink SG-uniform ops. An op is uniform if none
 // of its operands/results has a distribution layout attribute.
 // Non-uniform vectors are handled by dedicated patterns.
-// This pattern must have a higher priority than distribution patterns,
-// because a distributable shape may be logically intended as uniform.
+// This pattern must have a higher priority than vector dialect distribution
+// patterns, because a distributable shape may be logically intended as
+// uniform (i.e., no layout), so we want to omit its distribution.
 struct SinkUniformOps final : public gpu::WarpDistributionPattern {
   using gpu::WarpDistributionPattern::WarpDistributionPattern;
   LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
@@ -1241,7 +1237,7 @@ struct SinkUniformOps final : public gpu::WarpDistributionPattern {
                                            "The op result is not uniform.");
     }
 
-    // The op must have at least no layout-based operands or results.
+    // The op must have no layout-based operands or results.
     bool uniformValuesOnly =
         llvm::all_of(warpRegionPreYieldOp->getResults(), [](Value v) {
           return !xegpu::getDistributeLayoutAttr(v);
@@ -2111,11 +2107,9 @@ void xegpu::populateXeGPUSubgroupDistributePatterns(
   // patterns. Therefore, assign higher benefit.
   patterns
       .add<VectorShapeCastDistribution, VectorExtractStridedSliceDistribution,
-           VectorInsertStridedSliceDistribution, VectorBroadcastDistribution>(
-          patterns.getContext(),
-          /*pattern benefit=*/PatternHierarchy::AboveRegular);
-  patterns.add<SinkUniformOps>(patterns.getContext(),
-                               /*pattern benefit=*/PatternHierarchy::High);
+           VectorInsertStridedSliceDistribution, VectorBroadcastDistribution,
+           SinkUniformOps>(patterns.getContext(),
+                           /*pattern benefit=*/PatternHierarchy::AboveRegular);
 }
 
 void xegpu::populateXeGPUMoveFuncBodyToWarpOpPatterns(



More information about the Mlir-commits mailing list