[Mlir-commits] [mlir] Akroviak/vector dist uniform vecs (PR #176737)

Artem Kroviakov llvmlistbot at llvm.org
Mon Jan 19 03:48:21 PST 2026


https://github.com/akroviakov created https://github.com/llvm/llvm-project/pull/176737

This PR introduces the notion of uniform operations into the distribution logic of XeGPU.

Uniform operations have no `xegpu.layout` assigned to any of their operands or results. That is, a vector operation (even one that `VectorDistribute.cpp` could distribute) is treated as uniform if none of its operands or results has a layout.

Some ops with nested regions (such as `scf.if`, which `VectorDistribute.cpp` also distributes) cannot simply be cloned outside of the warp op with remapped operands. For such cases, `VectorDistribute.cpp` needs to be aware of uniform values while distributing them. We reuse the affine map normally used for distribution and leave it empty for uniform values: an empty map means the value is not changed by the distribution (its distributed type is computed with a warp size of 1).

Minor:
Some XeGPU tests had to be modified so that their layouts remain accessible and the new SinkUniformOps pattern does not take over their distribution.
PrefetchNd distribution was modified so that layout attributes are removed from the newly created (distributed) op rather than from the original op, which is erased anyway.

>From 9ab4768e074b5da16c17f681ea9beec9c9577356 Mon Sep 17 00:00:00 2001
From: Artem Kroviakov <artem.kroviakov at intel.com>
Date: Sun, 18 Jan 2026 16:25:43 +0000
Subject: [PATCH 1/2] [MLIR][XeGPU] Add uniform values "distribution" pattern

---
 .../Vector/Transforms/VectorDistribute.cpp    |  15 ++-
 .../Transforms/XeGPUSubgroupDistribute.cpp    | 101 ++++++++++++++++--
 mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp   |   4 +-
 .../XeGPU/subgroup-distribute-unit.mlir       |  57 ++++++----
 4 files changed, 142 insertions(+), 35 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 5334470e2e3a0..743fb51bab1ab 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -398,7 +398,8 @@ getInnerRegionEscapingValues(WarpExecuteOnLane0Op warpOp, Region &innerRegion,
       Type distType = operand->get().getType();
       if (auto vecType = dyn_cast<VectorType>(distType)) {
         AffineMap map = distributionMapFn(operand->get());
-        distType = getDistributedType(vecType, map, warpOp.getWarpSize());
+        distType = getDistributedType(vecType, map,
+                                      map.isEmpty() ? 1 : warpOp.getWarpSize());
       }
       escapingValueTypes.push_back(operand->get().getType());
       escapingValueDistTypes.push_back(distType);
@@ -1886,7 +1887,9 @@ struct WarpOpScfIfOp : public WarpDistributionPattern {
         // Fallback to affine map if the dist result was not previously recorded
         distType = ifResultDistTypes.count(i)
                        ? ifResultDistTypes[i]
-                       : getDistributedType(vecType, map, warpOp.getWarpSize());
+                       : getDistributedType(
+                             vecType, map,
+                             map.isEmpty() ? 1 : newWarpOp.getWarpSize());
       }
       newIfOpDistResTypes.push_back(distType);
     }
@@ -2075,9 +2078,11 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
         // we can get the distributed type from `forResultDistTypes` map.
         // Otherwise, we compute it using distributionMapFn.
         AffineMap map = distributionMapFn(initArg);
-        distType = forResultDistTypes.count(i)
-                       ? forResultDistTypes[i]
-                       : getDistributedType(vecType, map, warpOp.getWarpSize());
+        distType =
+            forResultDistTypes.count(i)
+                ? forResultDistTypes[i]
+                : getDistributedType(vecType, map,
+                                     map.isEmpty() ? 1 : warpOp.getWarpSize());
       }
       newWarpOpDistTypes.push_back(distType);
     }
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 9113f00ac39f0..30e104271d2b7 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -64,6 +64,14 @@ namespace {
 /// priorities to patterns.
 static constexpr unsigned regularPatternBenefit = 1;
 static constexpr unsigned highPatternBenefit = 2;
+static constexpr unsigned highestPatternBenefit = 3;
+
+enum PatternHierarchy : unsigned {
+  Regular = 1,
+  AboveRegular = 2,
+  High = 3,   // Used for SinkUniformOps so it outranks distribution patterns.
+  Highest = 4 // NOTE(review): no user visible here; confirm it is needed.
+};
 
 /// Helper function to get  distributed vector type for a source vector type
 /// according to the lane_layout. We simply divide each dimension of tensor
@@ -792,9 +800,10 @@ struct PrefetchNdDistribution final : public gpu::WarpDistributionPattern {
     // Collect offsets.
     for (size_t i = 1; i < newRetIndices.size(); ++i)
       newPrefetchOperands.push_back(newWarpOp.getResult(newRetIndices[i]));
-    xegpu::PrefetchNdOp::create(rewriter, newWarpOp.getLoc(), TypeRange{},
-                                newPrefetchOperands, prefetchOp->getAttrs());
-    xegpu::removeLayoutAttrs(prefetchOp);
+    Operation *newPrefetchOp = xegpu::PrefetchNdOp::create(
+        rewriter, newWarpOp.getLoc(), TypeRange{}, newPrefetchOperands,
+        prefetchOp->getAttrs());
+    xegpu::removeLayoutAttrs(newPrefetchOp);
     rewriter.eraseOp(prefetchOp);
     return success();
   }
@@ -1208,6 +1217,77 @@ struct LoadDistribution final : public gpu::WarpDistributionPattern {
   }
 };
 
+// Sink SG-uniform ops. An op is uniform if any of the following is true:
+// 1. It has no vector operands.
+// 2. All of its vector operands and results are uniform.
+// Non-uniform vectors are handled by dedicated patterns.
+// This pattern must have a higher priority than distribution patterns,
+// because a distributable shape may be logically intended as uniform.
+struct SinkUniformOps final : public gpu::WarpDistributionPattern {
+  using gpu::WarpDistributionPattern::WarpDistributionPattern;
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    // Only the op right before the terminator is considered for sinking.
+    Operation *warpRegionPreYieldOp = warpOp.getTerminator()->getPrevNode();
+    // Any ops with nested regions must be handled carefully in dedicated
+    // patterns.
+    if (!warpRegionPreYieldOp || warpRegionPreYieldOp->getNumRegions())
+      return failure();
+
+    int operandIdx = -1;
+    if (warpRegionPreYieldOp->getNumResults()) {
+      OpOperand *operand = getWarpResult(
+          warpOp, [&](Operation *op) { return warpRegionPreYieldOp == op; });
+      if (!operand)
+        return failure();
+      operandIdx = operand->getOperandNumber();
+      if (operand->get().getType() != warpOp.getResult(operandIdx).getType())
+        return rewriter.notifyMatchFailure(warpOp,
+                                           "The op result is not uniform.");
+    }
+
+    // No operand or result of the op may carry a distribution layout.
+    bool uniformValuesOnly =
+        llvm::all_of(warpRegionPreYieldOp->getResults(), [](Value v) {
+          return !xegpu::getDistributeLayoutAttr(v);
+        }) &&
+        llvm::all_of(warpRegionPreYieldOp->getOpOperands(), [](OpOperand &opr) {
+          return !xegpu::getDistributeLayoutAttr(opr);
+        });
+    if (!uniformValuesOnly)
+      return rewriter.notifyMatchFailure(warpOp,
+                                         "Some values are not uniform.");
+
+    // Capture its operands and create a new warp op that yields them.
+    SmallVector<size_t> newRetIndices;
+    SmallVector<Value> operands =
+        llvm::to_vector_of<Value>(warpRegionPreYieldOp->getOperands());
+    SmallVector<Type> operandTypes(warpRegionPreYieldOp->getOperandTypes());
+    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, operands, operandTypes, newRetIndices);
+
+    // Clone the op after the new warp op.
+    rewriter.setInsertionPointAfter(newWarpOp);
+
+    IRMapping operandMapper;
+    for (auto [oldOperandIdx, newOperandIdx] : llvm::enumerate(newRetIndices))
+      operandMapper.map(warpRegionPreYieldOp->getOperand(oldOperandIdx),
+                        newWarpOp->getResult(newOperandIdx));
+    Operation *clonedOp = rewriter.clone(*warpRegionPreYieldOp, operandMapper);
+    if (!clonedOp->getNumResults()) {
+      rewriter.eraseOp(warpRegionPreYieldOp);
+    } else {
+      assert(operandIdx != -1 && "Expected a warp result for the operation");
+      // The yielded value may be any result of the op, not only result #0.
+      Value yielded = newWarpOp.getTerminator()->getOperand(operandIdx);
+      rewriter.replaceAllUsesWith(
+          newWarpOp.getResult(operandIdx),
+          clonedOp->getResult(cast<OpResult>(yielded).getResultNumber()));
+    }
+    return success();
+  }
+};
+
 /// Helper to rewrite a 2D VectorMultiReductionOp into a sequence of 1D
 /// VectorReductionOps. We also insert layouts for the newly created ops.
 static Value lowerToVectorReductions(TypedValue<VectorType> src,
@@ -2036,14 +2116,16 @@ void xegpu::populateXeGPUSubgroupDistributePatterns(
                StoreMatrixDistribution,
                MemrefExtractAlignedPointerAsIndexDistribution>(
       patterns.getContext(),
-      /*pattern benefit=*/regularPatternBenefit);
+      /*pattern benefit=*/PatternHierarchy::Regular);
   // For following patterns, we need to override the regular vector distribution
   // patterns. Therefore, assign higher benefit.
   patterns
       .add<VectorShapeCastDistribution, VectorExtractStridedSliceDistribution,
            VectorInsertStridedSliceDistribution, VectorBroadcastDistribution>(
           patterns.getContext(),
-          /*pattern benefit=*/highPatternBenefit);
+          /*pattern benefit=*/PatternHierarchy::AboveRegular);
+  patterns.add<SinkUniformOps>(patterns.getContext(),
+                               /*pattern benefit=*/PatternHierarchy::High);
 }
 
 void xegpu::populateXeGPUMoveFuncBodyToWarpOpPatterns(
@@ -2094,10 +2176,9 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
       return AffineMap::get(val.getContext());
     // Get the layout of the vector type.
     xegpu::DistributeLayoutAttr layout = xegpu::getDistributeLayoutAttr(val);
-    // If no layout is specified, that means no distribution.
+    // If no layout is specified, assume uniform case (no distribution).
     if (!layout)
-      return AffineMap::getMultiDimMapWithTargets(vecRank, {},
-                                                  val.getContext());
+      return AffineMap::get(val.getContext());
     // Expecting vector and layout rank to match.
     assert(layout.getRank() == vecRank &&
            "Expecting vector and layout rank to match");
@@ -2133,11 +2214,11 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
 
   vector::populateDistributeReduction(
       patterns, warpReduction,
-      /*pattern benefit=*/regularPatternBenefit);
+      /*pattern benefit=*/PatternHierarchy::Regular);
 
   vector::populatePropagateWarpVectorDistributionPatterns(
       patterns, distributionFn, shuffleFn,
-      /*pattern benefit=*/regularPatternBenefit);
+      /*pattern benefit=*/PatternHierarchy::Regular);
   if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
     signalPassFailure();
     return;
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index 51783b41c4c96..e9d81dd4d1d81 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -361,9 +361,9 @@ bool xegpu::recoverTemporaryLayouts(Operation *rootOp) {
         continue;
       auto layout = xegpu::getDistributeLayoutAttr(operand.get());
       if (!layout) {
-        op->emitError("Could not find layout attribute for operand ")
+        op->emitWarning("Could not find layout attribute for operand ")
             << operand.getOperandNumber() << " of operation " << op->getName();
-        return WalkResult::interrupt();
+        continue;
       }
       xegpu::setDistributeLayoutAttr(operand, layout);
     }
diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir
index b136c89925682..a99f850de6175 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir
@@ -16,7 +16,7 @@ gpu.func @store_nd_1d(%laneid: index) {
   gpu.warp_execute_on_lane_0(%laneid)[16] {
     %0 = "some_op"() : () -> !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
     %cst = "some_op"() : () -> vector<16xf32>
-    xegpu.store_nd %cst, %0 [%c0] {layout_operand_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
+    xegpu.store_nd %cst, %0 [%c0] {layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
       : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
   }
   gpu.return
@@ -39,7 +39,7 @@ gpu.func @store_nd_2d(%laneid : index) {
   gpu.warp_execute_on_lane_0(%laneid)[16] {
     %0 = "some_op"() : () -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
     %cst = "some_op"() : () -> vector<16x16xf16>
-    xegpu.store_nd %cst, %0 [%c0, %c0] {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+    xegpu.store_nd %cst, %0 [%c0, %c0] {layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
       : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
   }
   gpu.return
@@ -60,7 +60,7 @@ gpu.func @load_nd_1d(%laneid: index) {
   %c0 = arith.constant 0 : index
   %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1xf32>) {
     %0 = "some_op"() : () -> !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
-    %1 = xegpu.load_nd %0 [%c0]  {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} :
+    %1 = xegpu.load_nd %0 [%c0]  {layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>} :
       !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>> -> vector<16xf32>
     gpu.yield %1 : vector<16xf32>
   }
@@ -84,7 +84,7 @@ gpu.func @load_nd_2d(%laneid: index) {
   %c0 = arith.constant 0 : index
   %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<16x1xf16>) {
     %0 = "some_op"() : () -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
-    %1 = xegpu.load_nd %0[%c0, %c0]  {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+    %1 = xegpu.load_nd %0[%c0, %c0]  {layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
       : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<16x16xf16>
     gpu.yield %1 : vector<16x16xf16>
   }
@@ -112,7 +112,7 @@ gpu.func @load_nd_array_length(%laneid: index) {
   %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<2x16x1xf16>) {
     %0 = "some_op"() : () -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2 : i64>,
       #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
-    %1 = xegpu.load_nd %0[%c0, %c0]  {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+    %1 = xegpu.load_nd %0[%c0, %c0]  {layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
       : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2 : i64>,
         #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<2x16x16xf16>
     gpu.yield %1 : vector<2x16x16xf16>
@@ -192,7 +192,7 @@ gpu.func @prefetch_2d(%laneid: index) {
     %0 = "some_op"() : ()
       -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
     xegpu.prefetch_nd %0[%c0, %c0]
-      <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+      {layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}
       : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
   }
   gpu.return
@@ -215,7 +215,7 @@ gpu.func @prefetch_1d(%laneid: index) {
     %0 = "some_op"() : ()
       -> !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
     xegpu.prefetch_nd %0[%c0]
-      <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+      {layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>, l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}
       : !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
   }
   gpu.return
@@ -285,14 +285,15 @@ gpu.func @vector_multi_reduction_dim1_distributed_dim0_reduction(%laneid: index)
 
 
 // CHECK-LABEL: gpu.func @vector_multi_reduction_dim1_distributed_dim1_reduction
-// CHECK:      %[[W:.*]] = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<2xf32>) {
+// CHECK:      %[[W:.*]] = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> ({{.*}}) {
 // CHECK-NEXT:   %[[SRC:.*]] = "some_def"() {{.*}} : () -> vector<2x16xf32>
 // CHECK-NEXT:   %[[T2:.*]] = vector.extract %[[SRC]][0] : vector<16xf32> from vector<2x16xf32>
 // CHECK-NEXT:   %[[T3:.*]] = vector.reduction <add>, %[[T2]], %cst : vector<16xf32> into f32
-// CHECK-NEXT:   %[[T4:.*]] = vector.extract %[[SRC]][1] : vector<16xf32> from vector<2x16xf32>
-// CHECK-NEXT:   %[[T5:.*]] = vector.reduction <add>, %[[T4]], %cst : vector<16xf32> into f32
-// CHECK-NEXT:   %[[T6:.*]] = vector.from_elements %[[T3]], %[[T5]] : vector<2xf32>
-// CHECK-NEXT:   gpu.yield %[[T6]] : vector<2xf32>
+// CHECK-NEXT:   %[[T4:.*]] = vector.insert %[[T3]], %cst_0 [0] : f32 into vector<2xf32>
+// CHECK-NEXT:   %[[T5:.*]] = vector.extract %[[SRC]][1] : vector<16xf32> from vector<2x16xf32>
+// CHECK-NEXT:   %[[T6:.*]] = vector.reduction <add>, %[[T5]], %cst : vector<16xf32> into f32
+// CHECK-NEXT:   %[[T7:.*]] = vector.insert %[[T6]], %[[T4]] [1] : f32 into vector<2xf32>
+// CHECK-NEXT:   gpu.yield %[[T7]]
 // CHECK-NEXT: }
 gpu.func @vector_multi_reduction_dim1_distributed_dim1_reduction(%laneid: index) {
   %c0 = arith.constant 0 : index
@@ -317,6 +318,7 @@ gpu.func @vector_multi_reduction_dim1_distributed_dim1_reduction(%laneid: index)
 }
 
 
+
 // CHECK-LABEL:   gpu.func @vector_multi_reduction_dim0_distributed_dim1_reduction
 // CHECK:       %[[ACC:.*]] = arith.constant {{.*}} dense<0.000000e+00> : vector<32xf32>
 // CHECK:       %[[W:.*]]:3 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<2xf32>, vector<2x16xf32>, vector<2xf32>) {
@@ -354,18 +356,19 @@ gpu.func @vector_multi_reduction_dim0_distributed_dim1_reduction(%laneid: index)
 
 
 // CHECK-LABEL: gpu.func @vector_multi_reduction_dim0_distributed_dim0_reduction
-// CHECK:     %[[W:.*]] = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<2xf32>) {
+// CHECK:     %[[W:.*]] = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<2xf32>{{.*}}) {
 // CHECK:       %[[SRC:.*]] = "some_def"() {{.*}} : () -> vector<16x2xf32>
 // CHECK:       %[[T1:.*]] = vector.extract_strided_slice %[[SRC]]
 // CHECK-SAME:    {offsets = [0, 0], sizes = [16, 1], strides = [1, 1]} : vector<16x2xf32> to vector<16x1xf32>
 // CHECK:       %[[T2:.*]] = vector.shape_cast %[[T1]] {{.*}} : vector<16x1xf32> to vector<16xf32>
 // CHECK:       %[[T3:.*]] = vector.reduction <add>, %[[T2]], %{{.*}} : vector<16xf32> into f32
-// CHECK:       %[[T4:.*]] = vector.extract_strided_slice %[[SRC]]
+// CHECK:       %[[T4:.*]] = vector.insert %[[T3]], %cst_0 [0] : f32 into vector<2xf32>
+// CHECK:       %[[T5:.*]] = vector.extract_strided_slice %[[SRC]]
 // CHECK-SAME:     {offsets = [0, 1], sizes = [16, 1], strides = [1, 1]} : vector<16x2xf32> to vector<16x1xf32>
-// CHECK:       %[[T5:.*]] = vector.shape_cast %[[T4]] {{.*}} : vector<16x1xf32> to vector<16xf32>
-// CHECK:       %[[T6:.*]] = vector.reduction <add>, %[[T5]], %{{.*}} : vector<16xf32> into f32
-// CHECK:       %[[T7:.*]] = vector.from_elements %[[T3]], %[[T6]] : vector<2xf32>
-// CHECK:       gpu.yield %[[T7]] : vector<2xf32>
+// CHECK:       %[[T6:.*]] = vector.shape_cast %[[T5]] {{.*}} : vector<16x1xf32> to vector<16xf32>
+// CHECK:       %[[T7:.*]] = vector.reduction <add>, %[[T6]], %{{.*}} : vector<16xf32> into f32
+// CHECK:       %[[T8:.*]] = vector.insert %[[T7]], %[[T4]] [1] : f32 into vector<2xf32>
+// CHECK:       gpu.yield %[[T8]]
 // CHECK:     }
 gpu.func @vector_multi_reduction_dim0_distributed_dim0_reduction(%laneid: index) {
   %c0 = arith.constant 0 : index
@@ -1030,4 +1033,22 @@ gpu.func
   gpu.return
 }
 
+// CHECK-LABEL: gpu.func @vector_broadcast_scalar_to_vector_uniform
+// CHECK-SAME: (%[[ARG0:.*]]: index)
+// CHECK: %[[R:.*]]:2 = gpu.warp_execute_on_lane_0(%[[ARG0]])[16] -> (vector<16x16xf16>, f16)
+// CHECK: %[[DEF:.*]] = "some_def"()
+// CHECK: %[[BCAST:.*]] = vector.broadcast %[[DEF]] : f16 to vector<16x16xf16>
+// CHECK: gpu.yield %[[BCAST]], %[[DEF]] : vector<16x16xf16>, f16
+// CHECK: %[[RESULT:.*]] = vector.broadcast %[[R]]#1 : f16 to vector<16x16xf16>
+// CHECK: "some_use"(%[[RESULT]])
+gpu.func @vector_broadcast_scalar_to_vector_uniform(%arg0: index) {
+  %0 = gpu.warp_execute_on_lane_0(%arg0)[16] -> (vector<16x16xf16>) {
+    %1 = "some_def"() : () -> f16
+    %2 = vector.broadcast %1 : f16 to vector<16x16xf16>
+    gpu.yield %2 : vector<16x16xf16>
+  }
+  "some_use"(%0) : (vector<16x16xf16>) -> ()
+  gpu.return
+}
+
 }

>From 55e4f7997363db450fa6034aca02bb0b83c617dc Mon Sep 17 00:00:00 2001
From: Artem Kroviakov <artem.kroviakov at intel.com>
Date: Mon, 19 Jan 2026 11:27:33 +0000
Subject: [PATCH 2/2] Cleanup

---
 .../XeGPU/Transforms/XeGPUSubgroupDistribute.cpp   | 14 ++------------
 1 file changed, 2 insertions(+), 12 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 30e104271d2b7..c56793effb2f7 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -62,10 +62,6 @@ namespace {
 /// In certain cases, we may need to favor XeGPU specific distribution patterns
 /// over generic vector distribution patterns. In such cases, we can assign
 /// priorities to patterns.
-static constexpr unsigned regularPatternBenefit = 1;
-static constexpr unsigned highPatternBenefit = 2;
-static constexpr unsigned highestPatternBenefit = 3;
-
 enum PatternHierarchy : unsigned {
   Regular = 1,
   AboveRegular = 2,
@@ -1217,9 +1213,8 @@ struct LoadDistribution final : public gpu::WarpDistributionPattern {
   }
 };
 
-// Sink SG-uniform ops. An op is uniform if any of the following is true:
-// 1. It has no vector operands.
-// 2. All of its vector operands and results are uniform.
+// Sink SG-uniform ops. An op is uniform if none of its operands or results
+// has a distribution layout attribute.
 // Non-uniform vectors are handled by dedicated patterns.
 // This pattern must have a higher priority than distribution patterns,
 // because a distributable shape may be logically intended as uniform.
@@ -1233,7 +1228,6 @@ struct SinkUniformOps final : public gpu::WarpDistributionPattern {
     // patterns.
     if (!warpRegionPreYieldOp || warpRegionPreYieldOp->getNumRegions())
       return failure();
-
     int operandIdx = -1;
     if (warpRegionPreYieldOp->getNumResults()) {
       OpOperand *operand = getWarpResult(
@@ -1259,8 +1253,6 @@ struct SinkUniformOps final : public gpu::WarpDistributionPattern {
     if (!uniformValuesOnly)
       return rewriter.notifyMatchFailure(warpOp,
                                          "Some values are not uniform.");
-
-    // Capture its operands and create a new warp op that yields them.
     SmallVector<size_t> newRetIndices;
     SmallVector<Value> operands =
         llvm::to_vector_of<Value>(warpRegionPreYieldOp->getOperands());
@@ -1269,9 +1261,7 @@ struct SinkUniformOps final : public gpu::WarpDistributionPattern {
     gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
         rewriter, warpOp, operands, operandTypes, newRetIndices);
 
-    // Clone the op after the new warp op.
     rewriter.setInsertionPointAfter(newWarpOp);
-
     IRMapping operandMapper;
     for (auto [oldOperandIdx, newOperandIdx] : llvm::enumerate(newRetIndices))
       operandMapper.map(warpRegionPreYieldOp->getOperand(oldOperandIdx),



More information about the Mlir-commits mailing list