[Mlir-commits] [mlir] [MLIR][XeGPU] Clean up stale convert_layout on single-element vector in peephole (PR #194043)

Nishant Patel llvmlistbot at llvm.org
Tue Apr 28 19:59:02 PDT 2026


https://github.com/nbpatel updated https://github.com/llvm/llvm-project/pull/194043

>From 09934bfc97b21bc352c729db26545c50b1f56484 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Fri, 24 Apr 2026 20:02:53 +0000
Subject: [PATCH 1/3] Fix stale convert_layout on single-element vector reductions in XeGPUPeepHoleOptimizer

---
 .../Transforms/XeGPUPeepHoleOptimizer.cpp     | 11 +++--
 .../test/Dialect/XeGPU/peephole-optimize.mlir | 41 +++++++++++++++++++
 2 files changed, 48 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPeepHoleOptimizer.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPeepHoleOptimizer.cpp
index 8ade936724480..f4790aa96c920 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPeepHoleOptimizer.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPeepHoleOptimizer.cpp
@@ -449,11 +449,14 @@ class MultiRed2dOpPattern
     auto loc = reductionOp.getLoc();
     auto acc = reductionOp.getAcc();
 
-    // If the result is scalar after reduction, look for consumer
-    // convert_layout op and remove it. The layout propagation pass will
-    // re-install it properly after the decomposition.
+    // If the result is scalar or a single-element vector after reduction,
+    // look for consumer convert_layout op and remove it. The layout
+    // propagation pass will re-install it properly after the decomposition.
     Type resultType = reductionOp.getResult().getType();
-    if (resultType.isIntOrFloat()) {
+    bool isSingleElementVector = false;
+    if (auto vecTy = dyn_cast<VectorType>(resultType))
+      isSingleElementVector = vecTy.getNumElements() == 1;
+    if (resultType.isIntOrFloat() || isSingleElementVector) {
       for (auto &use : reductionOp.getResult().getUses()) {
         if (auto convertLayoutOp =
                 llvm::dyn_cast<xegpu::ConvertLayoutOp>(use.getOwner())) {
diff --git a/mlir/test/Dialect/XeGPU/peephole-optimize.mlir b/mlir/test/Dialect/XeGPU/peephole-optimize.mlir
index f8dfd9a082ba2..5816507cee385 100644
--- a/mlir/test/Dialect/XeGPU/peephole-optimize.mlir
+++ b/mlir/test/Dialect/XeGPU/peephole-optimize.mlir
@@ -399,3 +399,44 @@ gpu.module @xevm_test {
   }
 }
 
+
+// -----
+// CHECK-LABEL: gpu.func @reduce_2d_vec1_convert_layout(
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<4x16xf32>, %[[ARG1:[0-9a-zA-Z]+]]: memref<256xf32>) {
+// CHECK:      %[[ACC_2D:.*]] = arith.constant dense<0.000000e+00> : vector<1x16xf32>
+// CHECK:      %[[ACC_1D:.*]] = arith.constant dense<1.000000e+00> : vector<1xf32>
+// CHECK:      %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG0]] : memref<4x16xf32> -> !xegpu.tensor_desc<4x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+// CHECK:      %[[LOADED:.*]] = xegpu.load_nd %[[TDESC]][0, 0] : !xegpu.tensor_desc<4x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<4x16xf32>
+// CHECK:      %[[SHAPED:.*]] = vector.shape_cast %[[LOADED]] : vector<4x16xf32> to vector<1x4x16xf32>
+// CHECK:      %[[REDUCE_1:.*]] = vector.multi_reduction <add>, %[[SHAPED]], %[[ACC_2D]] [1] : vector<1x4x16xf32> to vector<1x16xf32>
+// CHECK:      %[[REDUCE_2:.*]] = vector.multi_reduction <add>, %[[REDUCE_1]], %[[ACC_1D]] [1] : vector<1x16xf32> to vector<1xf32>
+// CHECK-NOT:  xegpu.convert_layout
+// CHECK:      %[[BCAST:.*]] = vector.broadcast %[[REDUCE_2]] : vector<1xf32> to vector<16xf32>
+// CHECK:      xegpu.store %[[BCAST]], %[[ARG1]]
+gpu.module @xevm_test {
+  gpu.func @reduce_2d_vec1_convert_layout(%src: memref<4x16xf32>, %dst: memref<256xf32>) {
+    %cst = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>, dims = [1, 2]>} dense<1.000000e+00> : vector<1xf32>
+    %tdesc = xegpu.create_nd_tdesc %src : memref<4x16xf32>
+      -> !xegpu.tensor_desc<4x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    %load = xegpu.load_nd %tdesc[0, 0]
+      : !xegpu.tensor_desc<4x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+      -> vector<4x16xf32>
+    %load1 = vector.broadcast %load {layout_result_0 = #xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>}: vector<4x16xf32> to vector<1x4x16xf32>
+    %reduce = vector.multi_reduction <add>, %load1, %cst
+     {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>, dims = [1, 2]>}
+     [1, 2] : vector<1x4x16xf32> to vector<1xf32>
+    %cvt = xegpu.convert_layout %reduce
+     <{input_layout = #xegpu.slice<#xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>, dims = [1, 2]>,
+       target_layout = #xegpu.slice<#xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>, dims = [1, 2]>}>
+     : vector<1xf32>
+    %reduce_bcast = vector.broadcast %cvt
+     {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
+     : vector<1xf32> to vector<16xf32>
+
+    %offset = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<0> : vector<16xindex>
+    %mask = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<1> : vector<16xi1>
+
+    xegpu.store %reduce_bcast, %dst[%offset], %mask {layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>} : vector<16xf32>, memref<256xf32>, vector<16xindex>, vector<16xi1>
+    gpu.return
+  }
+}

>From c43004b25ec92bc3cf62dfe86e938c2996ac56fe Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Sun, 26 Apr 2026 20:26:27 +0000
Subject: [PATCH 2/3] Address review feedback: bridge the stale convert_layout instead of removing it

---
 .../mlir/Dialect/XeGPU/IR/XeGPUAttrs.td       |  3 +-
 .../Transforms/XeGPUPeepHoleOptimizer.cpp     | 73 +++++++++++++++----
 .../test/Dialect/XeGPU/peephole-optimize.mlir | 12 +--
 3 files changed, 68 insertions(+), 20 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index 313a4355701a8..40edce8a60429 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -686,7 +686,8 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> {
     }
 
     bool isForLane() const {
-      auto parent = dyn_cast<LayoutAttr>(getParent());
+      SliceAttr attr = flatten();
+      auto parent = dyn_cast<LayoutAttr>(attr.getParent());
       return parent.isForLane();
     }
 
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPeepHoleOptimizer.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPeepHoleOptimizer.cpp
index f4790aa96c920..3f453d1a2ce68 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPeepHoleOptimizer.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPeepHoleOptimizer.cpp
@@ -449,20 +449,49 @@ class MultiRed2dOpPattern
     auto loc = reductionOp.getLoc();
     auto acc = reductionOp.getAcc();
 
-    // If the result is scalar or a single-element vector after reduction,
-    // look for consumer convert_layout op and remove it. The layout
-    // propagation pass will re-install it properly after the decomposition.
-    Type resultType = reductionOp.getResult().getType();
-    bool isSingleElementVector = false;
-    if (auto vecTy = dyn_cast<VectorType>(resultType))
-      isSingleElementVector = vecTy.getNumElements() == 1;
-    if (resultType.isIntOrFloat() || isSingleElementVector) {
-      for (auto &use : reductionOp.getResult().getUses()) {
-        if (auto convertLayoutOp =
-                llvm::dyn_cast<xegpu::ConvertLayoutOp>(use.getOwner())) {
-          rewriter.replaceOp(convertLayoutOp, reductionOp.getResult());
-          break;
-        }
+    // The decomposition below splits the 2D reduction into an intra-lane
+    // then a cross-lane 1D reduction. If a consumer xegpu.convert_layout
+    // exists on the reduction result, its input_layout was stamped by
+    // layout propagation against the original 2D reduction's slice and
+    // is therefore stale once we replace the producer with two 1D
+    // reductions.
+    //
+    // Hence insert a NEW xegpu.convert_layout between the decomposed
+    // reduction result and the existing convert_layout. The new op
+    // bridges from the natural post-decomposition producer layout
+    // to the layout that the existing convert_layout currently expects on its
+    // input. The existing convert_layout is left untouched.
+    xegpu::ConvertLayoutOp consumerConvertOp;
+    for (auto &use : reductionOp.getResult().getUses()) {
+      if (auto convertLayoutOp =
+              llvm::dyn_cast<xegpu::ConvertLayoutOp>(use.getOwner())) {
+        consumerConvertOp = convertLayoutOp;
+        break;
+      }
+    }
+    xegpu::DistributeLayoutAttr postDecompLayout;
+    if (consumerConvertOp) {
+      // Derive the source vector's layout.
+      xegpu::DistributeLayoutAttr srcLayoutForCvt;
+      if (auto resSlice = dyn_cast_if_present<xegpu::SliceAttr>(resLayout))
+        srcLayoutForCvt = resSlice.getParent();
+      if (!srcLayoutForCvt)
+        srcLayoutForCvt =
+            xegpu::getDistributeLayoutAttr(reductionOp.getSource());
+      if (srcLayoutForCvt) {
+        // The natural layout of the post-decomposition reduction result
+        // is a nested SliceAttr: REDUCE_1 (reduces `intraLaneDim` from
+        // the source) yields `slice<src, [intraLaneDim]>`; REDUCE_2
+        // then reduces `adjCrossLaneDim` from that intermediate, giving
+        // `slice<slice<src, [intraLaneDim]>, [adjCrossLaneDim]>`.
+        MLIRContext *ctx = consumerConvertOp.getContext();
+        int64_t adjCrossLaneDim =
+            crossLaneDim > intraLaneDim ? crossLaneDim - 1 : crossLaneDim;
+        auto intermediateLayout = xegpu::SliceAttr::get(
+            ctx, srcLayoutForCvt, DenseI64ArrayAttr::get(ctx, {intraLaneDim}));
+        postDecompLayout = xegpu::SliceAttr::get(
+            ctx, intermediateLayout,
+            DenseI64ArrayAttr::get(ctx, {adjCrossLaneDim}));
       }
     }
 
@@ -484,6 +513,19 @@ class MultiRed2dOpPattern
         ArrayRef<int64_t>(crossLaneDim));
     assert(crossLaneReduced.getType() == reductionOp.getResult().getType() &&
            "Type mismatch");
+
+    if (consumerConvertOp && postDecompLayout) {
+      auto consumerInputLayout = consumerConvertOp.getInputLayoutAttr();
+      OpBuilder::InsertionGuard g(rewriter);
+      rewriter.setInsertionPoint(consumerConvertOp);
+      auto bridgeOp = xegpu::ConvertLayoutOp::create(
+          rewriter, loc, crossLaneReduced.getType(), crossLaneReduced,
+          postDecompLayout, consumerInputLayout);
+      rewriter.modifyOpInPlace(consumerConvertOp, [&]() {
+        consumerConvertOp.getSourceMutable().set(bridgeOp.getResult());
+      });
+    }
+
     rewriter.replaceOp(reductionOp, crossLaneReduced);
     return success();
   }
@@ -587,6 +629,9 @@ struct XeGPUPeepHoleOptimizerPass final
 
     target.addLegalDialect<arith::ArithDialect, memref::MemRefDialect,
                            vector::VectorDialect>();
+    // xegpu.convert_layout is left untouched by this pass; mark it legal
+    // so in-place updates don't trigger re-legalization failures.
+    target.addLegalOp<xegpu::ConvertLayoutOp>();
     scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
                                                          target);
     xegpu::populateXeGPUPeepHoleOptimizerPatterns(patterns);
diff --git a/mlir/test/Dialect/XeGPU/peephole-optimize.mlir b/mlir/test/Dialect/XeGPU/peephole-optimize.mlir
index 5816507cee385..543f059aef5bf 100644
--- a/mlir/test/Dialect/XeGPU/peephole-optimize.mlir
+++ b/mlir/test/Dialect/XeGPU/peephole-optimize.mlir
@@ -369,8 +369,9 @@ gpu.module @xevm_test {
 // CHECK:      %[[LOADED:.*]] = xegpu.load_nd %[[TDESC]][0, 0] : !xegpu.tensor_desc<4x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<4x16xf32>
 // CHECK:      %[[REDUCE_1:.*]] = vector.multi_reduction <add>, %[[LOADED]], %[[ACC_VEC]] [0] : vector<4x16xf32> to vector<16xf32>
 // CHECK:      %[[REDUCE_2:.*]] = vector.multi_reduction <add>, %[[REDUCE_1]], %[[ACC_SCALAR]] [0] : vector<16xf32> to f32
-// CHECK-NOT:  xegpu.convert_layout
-// CHECK:      %[[BCAST:.*]] = vector.broadcast %[[REDUCE_2]] : f32 to vector<16xf32>
+// CHECK:      %[[BRIDGE:.*]] = xegpu.convert_layout %[[REDUCE_2]] <{input_layout = #xegpu.slice<#xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>, dims = [0]>, target_layout = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0, 1]>}> : f32
+// CHECK:      %[[CVT:.*]] = xegpu.convert_layout %[[BRIDGE]] <{input_layout = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0, 1]>, target_layout = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0, 1]>}> : f32
+// CHECK:      %[[BCAST:.*]] = vector.broadcast %[[CVT]] : f32 to vector<16xf32>
 // CHECK:      xegpu.store %[[BCAST]], %[[ARG1]]
 gpu.module @xevm_test {
   gpu.func @reduce_2d_scalar_convert_layout(%src: memref<4x16xf32>, %dst: memref<256xf32>) {
@@ -410,8 +411,9 @@ gpu.module @xevm_test {
 // CHECK:      %[[SHAPED:.*]] = vector.shape_cast %[[LOADED]] : vector<4x16xf32> to vector<1x4x16xf32>
 // CHECK:      %[[REDUCE_1:.*]] = vector.multi_reduction <add>, %[[SHAPED]], %[[ACC_2D]] [1] : vector<1x4x16xf32> to vector<1x16xf32>
 // CHECK:      %[[REDUCE_2:.*]] = vector.multi_reduction <add>, %[[REDUCE_1]], %[[ACC_1D]] [1] : vector<1x16xf32> to vector<1xf32>
-// CHECK-NOT:  xegpu.convert_layout
-// CHECK:      %[[BCAST:.*]] = vector.broadcast %[[REDUCE_2]] : vector<1xf32> to vector<16xf32>
+// CHECK:      %[[BRIDGE:.*]] = xegpu.convert_layout %[[REDUCE_2]] <{input_layout = #xegpu.slice<#xegpu.slice<#xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>, dims = [1]>, dims = [1]>, target_layout = #xegpu.slice<#xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>, dims = [1, 2]>}> : vector<1xf32>
+// CHECK:      %[[CVT:.*]] = xegpu.convert_layout %[[BRIDGE]] <{input_layout = #xegpu.slice<#xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>, dims = [1, 2]>, target_layout = #xegpu.slice<#xegpu.layout<lane_layout = [1, 1, 1], lane_data = [1, 1, 1]>, dims = [1, 2]>}> : vector<1xf32>
+// CHECK:      %[[BCAST:.*]] = vector.broadcast %[[CVT]] : vector<1xf32> to vector<16xf32>
 // CHECK:      xegpu.store %[[BCAST]], %[[ARG1]]
 gpu.module @xevm_test {
   gpu.func @reduce_2d_vec1_convert_layout(%src: memref<4x16xf32>, %dst: memref<256xf32>) {
@@ -427,7 +429,7 @@ gpu.module @xevm_test {
      [1, 2] : vector<1x4x16xf32> to vector<1xf32>
     %cvt = xegpu.convert_layout %reduce
      <{input_layout = #xegpu.slice<#xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>, dims = [1, 2]>,
-       target_layout = #xegpu.slice<#xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>, dims = [1, 2]>}>
+       target_layout = #xegpu.slice<#xegpu.layout<lane_layout = [1, 1, 1], lane_data = [1, 1, 1]>, dims = [1, 2]>}>
      : vector<1xf32>
     %reduce_bcast = vector.broadcast %cvt
      {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}

>From 528fd63f75a092075a949dad7142a51900ee9206 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Wed, 29 Apr 2026 02:49:34 +0000
Subject: [PATCH 3/3] Always insert a bridge convert_layout after the decomposed reduction

---
 .../Transforms/XeGPUPeepHoleOptimizer.cpp     | 51 ++++++++-----------
 .../test/Dialect/XeGPU/peephole-optimize.mlir |  6 ++-
 2 files changed, 25 insertions(+), 32 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPeepHoleOptimizer.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPeepHoleOptimizer.cpp
index 3f453d1a2ce68..9f39deb06959e 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPeepHoleOptimizer.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPeepHoleOptimizer.cpp
@@ -450,27 +450,17 @@ class MultiRed2dOpPattern
     auto acc = reductionOp.getAcc();
 
     // The decomposition below splits the 2D reduction into an intra-lane
-    // then a cross-lane 1D reduction. If a consumer xegpu.convert_layout
-    // exists on the reduction result, its input_layout was stamped by
-    // layout propagation against the original 2D reduction's slice and
-    // is therefore stale once we replace the producer with two 1D
-    // reductions.
-    //
-    // Hence insert a NEW xegpu.convert_layout between the decomposed
-    // reduction result and the existing convert_layout. The new op
-    // bridges from the natural post-decomposition producer layout
-    // to the layout that the existing convert_layout currently expects on its
-    // input. The existing convert_layout is left untouched.
-    xegpu::ConvertLayoutOp consumerConvertOp;
-    for (auto &use : reductionOp.getResult().getUses()) {
-      if (auto convertLayoutOp =
-              llvm::dyn_cast<xegpu::ConvertLayoutOp>(use.getOwner())) {
-        consumerConvertOp = convertLayoutOp;
-        break;
-      }
-    }
+    // then a cross-lane 1D reduction. The natural result layout of the
+    // decomposed sequence (a doubly-sliced layout) differs from the
+    // original 2D reduction's result layout that the rest of the IR was
+    // written/propagated against. To keep the post-peephole IR
+    // self-consistent without depending on a follow-up layout
+    // propagation pass, we always insert a bridge xegpu.convert_layout
+    // from the natural post-decomposition layout to the original
+    // reduction's result layout. Trivial bridges fold away in
+    // canonicalization.
     xegpu::DistributeLayoutAttr postDecompLayout;
-    if (consumerConvertOp) {
+    if (resLayout) {
       // Derive the source vector's layout.
       xegpu::DistributeLayoutAttr srcLayoutForCvt;
       if (auto resSlice = dyn_cast_if_present<xegpu::SliceAttr>(resLayout))
@@ -484,7 +474,7 @@ class MultiRed2dOpPattern
         // the source) yields `slice<src, [intraLaneDim]>`; REDUCE_2
         // then reduces `adjCrossLaneDim` from that intermediate, giving
         // `slice<slice<src, [intraLaneDim]>, [adjCrossLaneDim]>`.
-        MLIRContext *ctx = consumerConvertOp.getContext();
+        MLIRContext *ctx = reductionOp.getContext();
         int64_t adjCrossLaneDim =
             crossLaneDim > intraLaneDim ? crossLaneDim - 1 : crossLaneDim;
         auto intermediateLayout = xegpu::SliceAttr::get(
@@ -514,19 +504,20 @@ class MultiRed2dOpPattern
     assert(crossLaneReduced.getType() == reductionOp.getResult().getType() &&
            "Type mismatch");
 
-    if (consumerConvertOp && postDecompLayout) {
-      auto consumerInputLayout = consumerConvertOp.getInputLayoutAttr();
-      OpBuilder::InsertionGuard g(rewriter);
-      rewriter.setInsertionPoint(consumerConvertOp);
+    Value replacement = crossLaneReduced;
+    if (resLayout && postDecompLayout) {
+      // Bridge from the natural post-decomposition layout to the
+      // original reduction's result layout. This preserves the contract
+      // any consumer (convert_layout, anchor op, or otherwise) was
+      // written against, so the rewrite is correct independent of
+      // whether layout propagation runs afterwards.
       auto bridgeOp = xegpu::ConvertLayoutOp::create(
           rewriter, loc, crossLaneReduced.getType(), crossLaneReduced,
-          postDecompLayout, consumerInputLayout);
-      rewriter.modifyOpInPlace(consumerConvertOp, [&]() {
-        consumerConvertOp.getSourceMutable().set(bridgeOp.getResult());
-      });
+          postDecompLayout, resLayout);
+      replacement = bridgeOp.getResult();
     }
 
-    rewriter.replaceOp(reductionOp, crossLaneReduced);
+    rewriter.replaceOp(reductionOp, replacement);
     return success();
   }
 
diff --git a/mlir/test/Dialect/XeGPU/peephole-optimize.mlir b/mlir/test/Dialect/XeGPU/peephole-optimize.mlir
index 543f059aef5bf..d360ade3f1adf 100644
--- a/mlir/test/Dialect/XeGPU/peephole-optimize.mlir
+++ b/mlir/test/Dialect/XeGPU/peephole-optimize.mlir
@@ -293,7 +293,8 @@ gpu.func @array_length(%arg0: vector<8x16xf16>, %arg1: memref<256x256xf16>, %arg
 // CHECK:      %[[LOADED:.*]] = xegpu.load_nd %[[TDESC]][0, 0] : !xegpu.tensor_desc<4x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<4x16xf32>
 // CHECK:      %[[REDUCE_1:.*]] = vector.multi_reduction <add>, %[[LOADED]], %[[ACC_VEC]] [0] : vector<4x16xf32> to vector<16xf32>
 // CHECK:      %[[REDUCE_2:.*]] = vector.multi_reduction <add>, %[[REDUCE_1]], %[[ACC_SCALAR]] [0] : vector<16xf32> to f32
-// CHECK:      %[[BCAST:.*]] = vector.broadcast %[[REDUCE_2]] : f32 to vector<16xf32>
+// CHECK:      %[[BRIDGE:.*]] = xegpu.convert_layout %[[REDUCE_2]] <{input_layout = #xegpu.slice<#xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>, dims = [0]>, target_layout = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0, 1]>}> : f32
+// CHECK:      %[[BCAST:.*]] = vector.broadcast %[[BRIDGE]] : f32 to vector<16xf32>
 // CHECK:      xegpu.store %[[BCAST]], %[[ARG1]][%[[OFFSET]]], %[[MASK]]
 // CHECK-SAME: <{layout = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>}>
 // CHECK-SAME: : vector<16xf32>, memref<256xf32>, vector<16xindex>, vector<16xi1>
@@ -332,7 +333,8 @@ gpu.module @xevm_test {
 // CHECK:      %[[SHAPED:.*]] = vector.shape_cast %[[LOADED]] : vector<4x16xf32> to vector<1x4x16xf32>
 // CHECK:      %[[REDUCE_1:.*]] = vector.multi_reduction <add>, %[[SHAPED]], %[[ACC_2D]] [1] : vector<1x4x16xf32> to vector<1x16xf32>
 // CHECK:      %[[REDUCE_2:.*]] = vector.multi_reduction <add>, %[[REDUCE_1]], %[[ACC_1D]] [1] : vector<1x16xf32> to vector<1xf32>
-// CHECK:      %[[BCAST:.*]] = vector.broadcast %[[REDUCE_2]] : vector<1xf32> to vector<16xf32>
+// CHECK:      %[[BRIDGE:.*]] = xegpu.convert_layout %[[REDUCE_2]] <{input_layout = #xegpu.slice<#xegpu.slice<#xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>, dims = [1]>, dims = [1]>, target_layout = #xegpu.slice<#xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>, dims = [1, 2]>}> : vector<1xf32>
+// CHECK:      %[[BCAST:.*]] = vector.broadcast %[[BRIDGE]] : vector<1xf32> to vector<16xf32>
 // CHECK:      xegpu.store %[[BCAST]], %[[ARG1]][%[[OFFSET]]], %[[MASK]]
 // CHECK-SAME: <{layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>}>
 // CHECK-SAME: : vector<16xf32>, memref<256xf32>, vector<16xindex>, vector<16xi1>



More information about the Mlir-commits mailing list