[Mlir-commits] [mlir] [MLIR][XeGPU] Lower vector.multi_reduction to vector.reduction for lane local (PR #191037)

Nishant Patel llvmlistbot at llvm.org
Thu Apr 9 13:18:21 PDT 2026


https://github.com/nbpatel updated https://github.com/llvm/llvm-project/pull/191037

>From 65f18bf269c4328bf7711359a278750af56b5a2c Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Wed, 8 Apr 2026 17:02:15 +0000
Subject: [PATCH 1/3] Lower multi_reduction to reduction for lane local

---
 .../mlir/Dialect/XeGPU/Utils/XeGPUUtils.h     |  3 +-
 .../XeGPUSgToWiDistributeExperimental.cpp     | 20 ++++++-----
 mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp   | 30 +++++++++-------
 .../XeGPU/sg-to-wi-experimental-unit.mlir     | 35 ++++++++++++++-----
 .../Dialect/XeGPU/sg-to-wi-experimental.mlir  |  3 +-
 5 files changed, 58 insertions(+), 33 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
index 0aa2cd45088f3..f571aece0daf7 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
@@ -145,7 +145,8 @@ Value subgroupReduction(Location loc, OpBuilder &builder, Value input,
 Value lowerToVectorReductions(TypedValue<VectorType> src,
                               TypedValue<VectorType> acc,
                               vector::CombiningKind kind, int64_t reductionDim,
-                              Location loc, PatternRewriter &rewriter);
+                              Location loc, PatternRewriter &rewriter,
+                              bool setLayout = true);
 
 /// Creates a constant filled with the neutral (identity) value for the
 /// given reduction kind. For example: 0 for ADD/OR/XOR, 1 for MUL/AND,
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
index e3227c7f5b149..31ed21f8be143 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
@@ -590,15 +590,17 @@ struct SgToWiMultiDimReduction
         result = vector::makeArithReduction(rewriter, op.getLoc(), op.getKind(),
                                             result, adaptor.getAcc());
     } else if (isReductionLaneLocal(op)) {
-      auto resLayout = xegpu::getTemporaryLayout(op->getOpResult(0));
-      VectorType resVecTy = dyn_cast<VectorType>(op.getType());
-      auto resDistVecTyOrFailure =
-          getDistVecTypeBasedOnLaneLayout(resLayout, resVecTy);
-      // For lane local reduction, simply create a new MultiDimReductionOp using
-      // adaptor operands and the new result type.
-      result = vector::MultiDimReductionOp::create(
-          rewriter, op.getLoc(), resDistVecTyOrFailure.value(), op.getKind(),
-          adaptor.getSource(), adaptor.getAcc(), op.getReductionDims());
+      // For lane-local reduction, lower to a sequence of vector.reduction ops
+      // over 1D slices extracted from the distributed source vector. This is
+      // required so we dont have 2D source vectors at xegpu-linearize. The
+      // setLayout parameter is to make lowerToVectorReductions generic for both
+      // the old and the new pass. It will be removed once we deprecate the old
+      // pass.
+      auto reductionDim = reductionDims[0];
+      result = xegpu::lowerToVectorReductions(
+          cast<TypedValue<VectorType>>(adaptor.getSource()),
+          cast<TypedValue<VectorType>>(adaptor.getAcc()), op.getKind(),
+          reductionDim, op.getLoc(), rewriter, /*setLayout=*/false);
     } else {
       auto reductionDim = reductionDims[0];
       VectorType sourceType = op.getSourceVectorType();
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index 243581b4ce522..acae9bf1c9562 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -671,7 +671,8 @@ Value xegpu::lowerToVectorReductions(TypedValue<VectorType> src,
                                      TypedValue<VectorType> acc,
                                      vector::CombiningKind kind,
                                      int64_t reductionDim, Location loc,
-                                     PatternRewriter &rewriter) {
+                                     PatternRewriter &rewriter,
+                                     bool setLayout) {
   VectorType sourceType = src.getType();
   int64_t sourceRank = sourceType.getRank();
   // Expecting at least a 2D source vector. Leading dimensions (all except the
@@ -690,10 +691,13 @@ Value xegpu::lowerToVectorReductions(TypedValue<VectorType> src,
   Value reductionResult = arith::ConstantOp::create(
       rewriter, loc, acc.getType(),
       DenseElementsAttr::get(acc.getType(), zeroAttr));
-  auto srcLayout = xegpu::getTemporaryLayout(dyn_cast<OpResult>(src));
-  auto accLayout = xegpu::getTemporaryLayout(dyn_cast<OpResult>(acc));
-  // Reduction result should have the same layout as the accumulator.
-  xegpu::setTemporaryLayout(cast<OpResult>(reductionResult), accLayout);
+  xegpu::DistributeLayoutAttr srcLayout, accLayout;
+  if (setLayout) {
+    srcLayout = xegpu::getTemporaryLayout(dyn_cast<OpResult>(src));
+    accLayout = xegpu::getTemporaryLayout(dyn_cast<OpResult>(acc));
+    // Reduction result should have the same layout as the accumulator.
+    xegpu::setTemporaryLayout(cast<OpResult>(reductionResult), accLayout);
+  }
   // For each slice of the source, extract the slice vector, do a reduction
   // and, insert the reduced value back to the result vector.
   int64_t accRank = acc.getType().getRank();
@@ -714,8 +718,8 @@ Value xegpu::lowerToVectorReductions(TypedValue<VectorType> src,
     vector::ExtractStridedSliceOp extractOp =
         vector::ExtractStridedSliceOp::create(rewriter, loc, src, sliceOffsets,
                                               sliceSizes, strides);
-    // Extract strided slice has the same layout as src.
-    xegpu::setTemporaryLayout(extractOp->getOpResult(0), srcLayout);
+    if (setLayout)
+      xegpu::setTemporaryLayout(extractOp->getOpResult(0), srcLayout);
 
     int64_t nSliceElements = extractOp.getResult().getType().getNumElements();
 
@@ -724,10 +728,10 @@ Value xegpu::lowerToVectorReductions(TypedValue<VectorType> src,
         VectorType::get({nSliceElements}, sourceType.getElementType()),
         extractOp.getResult());
 
-    // Shape cast output has the same layout as the accumulator. Shape cast
-    // source has the same layout as the original reduction source.
-    xegpu::setTemporaryLayout(slice->getOpOperand(0), srcLayout);
-    xegpu::setTemporaryLayout(slice->getOpResult(0), accLayout);
+    if (setLayout) {
+      xegpu::setTemporaryLayout(slice->getOpOperand(0), srcLayout);
+      xegpu::setTemporaryLayout(slice->getOpResult(0), accLayout);
+    }
     // Extract and reduction results in scalars, so no result layout is needed.
     // Build multi-dim index into acc (sourceRank-1 dims, i.e. source shape with
     // the reduction dim removed). Leading unit dims get index 0.
@@ -738,8 +742,8 @@ Value xegpu::lowerToVectorReductions(TypedValue<VectorType> src,
         rewriter, loc, kind, slice.getResult(), accExtract);
     reductionResult = vector::InsertOp::create(rewriter, loc, reduction,
                                                reductionResult, accIdx);
-    // Insert op should have the same layout as the accumulator.
-    xegpu::setTemporaryLayout(cast<OpResult>(reductionResult), accLayout);
+    if (setLayout)
+      xegpu::setTemporaryLayout(cast<OpResult>(reductionResult), accLayout);
   }
   return reductionResult;
 }
diff --git a/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir b/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
index 0335105ebe7f0..4c3727388831b 100644
--- a/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
+++ b/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
@@ -432,9 +432,13 @@ gpu.func @vector_multi_reduction_dim0_distributed_dim0_reduction(%laneid: index)
 }
 
 // CHECK-LABEL: gpu.func @vector_multi_reduction_dim1_distributed_dim0_reduction
-// CHECK:         %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<4x1xf32>
-// CHECK:         %[[CST_0:.*]] = arith.constant dense<0.000000e+00> : vector<1xf32>
-// CHECK:         %[[RED:.*]] = vector.multi_reduction <add>, %[[CST]], %[[CST_0]] [0] : vector<4x1xf32> to vector<1xf32>
+// CHECK-DAG:     %[[SRC:.*]] = arith.constant dense<0.000000e+00> : vector<4x1xf32>
+// CHECK-DAG:     %[[ACC:.*]] = arith.constant dense<0.000000e+00> : vector<1xf32>
+// CHECK:         %[[SLICE:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [0, 0], sizes = [4, 1], strides = [1, 1]} : vector<4x1xf32> to vector<4x1xf32>
+// CHECK:         %[[FLAT:.*]] = vector.shape_cast %[[SLICE]] : vector<4x1xf32> to vector<4xf32>
+// CHECK:         %[[ACC_EL:.*]] = vector.extract %[[ACC]][0] : f32 from vector<1xf32>
+// CHECK:         %[[RED:.*]] = vector.reduction <add>, %[[FLAT]], %[[ACC_EL]] : vector<4xf32> into f32
+// CHECK:         vector.insert %[[RED]], %{{.*}} [0] : f32 into vector<1xf32>
 // CHECK:         gpu.return
 gpu.func @vector_multi_reduction_dim1_distributed_dim0_reduction(%laneid: index) {
   %c0 = arith.constant 0 : index
@@ -453,9 +457,13 @@ gpu.func @vector_multi_reduction_dim1_distributed_dim0_reduction(%laneid: index)
 }
 
 // CHECK-LABEL: gpu.func @vector_multi_reduction_dim0_distributed_dim1_reduction
-// CHECK:         %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<1x12xf32>
-// CHECK:         %[[CST_0:.*]] = arith.constant dense<0.000000e+00> : vector<1xf32>
-// CHECK:         %[[RED:.*]] = vector.multi_reduction <add>, %[[CST]], %[[CST_0]] [1] : vector<1x12xf32> to vector<1xf32>
+// CHECK-DAG:     %[[SRC:.*]] = arith.constant dense<0.000000e+00> : vector<1x12xf32>
+// CHECK-DAG:     %[[ACC:.*]] = arith.constant dense<0.000000e+00> : vector<1xf32>
+// CHECK:         %[[SLICE:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [0, 0], sizes = [1, 12], strides = [1, 1]} : vector<1x12xf32> to vector<1x12xf32>
+// CHECK:         %[[FLAT:.*]] = vector.shape_cast %[[SLICE]] : vector<1x12xf32> to vector<12xf32>
+// CHECK:         %[[ACC_EL:.*]] = vector.extract %[[ACC]][0] : f32 from vector<1xf32>
+// CHECK:         %[[RED:.*]] = vector.reduction <add>, %[[FLAT]], %[[ACC_EL]] : vector<12xf32> into f32
+// CHECK:         vector.insert %[[RED]], %{{.*}} [0] : f32 into vector<1xf32>
 // CHECK:         gpu.return
 gpu.func @vector_multi_reduction_dim0_distributed_dim1_reduction(%laneid: index) {
   %c0 = arith.constant 0 : index
@@ -582,9 +590,18 @@ gpu.func @constant_mask_2d() {
 
 
 // CHECK-LABEL: gpu.func @vector_multi_reduction_3d_leading_unit_dim_lane_local
-// CHECK:         %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<1x16x2xf32>
-// CHECK:         %[[CST_0:.*]] = arith.constant dense<0.000000e+00> : vector<1x2xf32>
-// CHECK:         %[[RED:.*]] = vector.multi_reduction <add>, %[[CST]], %[[CST_0]] [1] : vector<1x16x2xf32> to vector<1x2xf32>
+// CHECK-DAG:     %[[SRC:.*]] = arith.constant dense<0.000000e+00> : vector<1x16x2xf32>
+// CHECK-DAG:     %[[ACC:.*]] = arith.constant dense<0.000000e+00> : vector<1x2xf32>
+// CHECK:         %[[S0:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [0, 0, 0], sizes = [1, 16, 1], strides = [1, 1, 1]} : vector<1x16x2xf32> to vector<1x16x1xf32>
+// CHECK:         %[[F0:.*]] = vector.shape_cast %[[S0]] : vector<1x16x1xf32> to vector<16xf32>
+// CHECK:         %[[A0:.*]] = vector.extract %[[ACC]][0, 0] : f32 from vector<1x2xf32>
+// CHECK:         %[[R0:.*]] = vector.reduction <add>, %[[F0]], %[[A0]] : vector<16xf32> into f32
+// CHECK:         %[[I0:.*]] = vector.insert %[[R0]], %{{.*}} [0, 0] : f32 into vector<1x2xf32>
+// CHECK:         %[[S1:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [0, 0, 1], sizes = [1, 16, 1], strides = [1, 1, 1]} : vector<1x16x2xf32> to vector<1x16x1xf32>
+// CHECK:         %[[F1:.*]] = vector.shape_cast %[[S1]] : vector<1x16x1xf32> to vector<16xf32>
+// CHECK:         %[[A1:.*]] = vector.extract %[[ACC]][0, 1] : f32 from vector<1x2xf32>
+// CHECK:         %[[R1:.*]] = vector.reduction <add>, %[[F1]], %[[A1]] : vector<16xf32> into f32
+// CHECK:         vector.insert %[[R1]], %[[I0]] [0, 1] : f32 into vector<1x2xf32>
 // CHECK:         gpu.return
 gpu.func @vector_multi_reduction_3d_leading_unit_dim_lane_local() {
     %src = arith.constant
diff --git a/mlir/test/Dialect/XeGPU/sg-to-wi-experimental.mlir b/mlir/test/Dialect/XeGPU/sg-to-wi-experimental.mlir
index 9febd79c7adc3..babb01c131792 100644
--- a/mlir/test/Dialect/XeGPU/sg-to-wi-experimental.mlir
+++ b/mlir/test/Dialect/XeGPU/sg-to-wi-experimental.mlir
@@ -445,7 +445,8 @@ gpu.module @xevm_module{
 
 // -----
 // CHECK-LABEL: gpu.func @vector_broadcast_1d_to_2d_broadcast_within_lane({{.*}}) {
-// CHECK: %[[BCAST:.*]] = vector.broadcast %{{.*}} : vector<1xf16> to vector<16xf16>
+// CHECK: %[[RED:.*]] = vector.reduction <add>, %{{.*}}, %{{.*}} : vector<16xf16> into f16
+// CHECK: %[[BCAST:.*]] = vector.broadcast %[[RED]] : f16 to vector<16xf16>
 gpu.module @xevm_module{
    gpu.func @vector_broadcast_1d_to_2d_broadcast_within_lane(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>) {
     %c0 = arith.constant 0 : index

>From 0853c6a37bb80781857a904921955c98b4a2e2a7 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Thu, 9 Apr 2026 03:26:05 +0000
Subject: [PATCH 2/3] Address feedback

---
 .../mlir/Dialect/XeGPU/Utils/XeGPUUtils.h     |  3 +-
 .../XeGPUSgToWiDistributeExperimental.cpp     |  7 ++---
 mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp   | 30 ++++++++-----------
 3 files changed, 16 insertions(+), 24 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
index f571aece0daf7..0aa2cd45088f3 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
@@ -145,8 +145,7 @@ Value subgroupReduction(Location loc, OpBuilder &builder, Value input,
 Value lowerToVectorReductions(TypedValue<VectorType> src,
                               TypedValue<VectorType> acc,
                               vector::CombiningKind kind, int64_t reductionDim,
-                              Location loc, PatternRewriter &rewriter,
-                              bool setLayout = true);
+                              Location loc, PatternRewriter &rewriter);
 
 /// Creates a constant filled with the neutral (identity) value for the
 /// given reduction kind. For example: 0 for ADD/OR/XOR, 1 for MUL/AND,
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
index 31ed21f8be143..b086a6571ddb4 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
@@ -592,15 +592,12 @@ struct SgToWiMultiDimReduction
     } else if (isReductionLaneLocal(op)) {
       // For lane-local reduction, lower to a sequence of vector.reduction ops
       // over 1D slices extracted from the distributed source vector. This is
-      // required so we dont have 2D source vectors at xegpu-linearize. The
-      // setLayout parameter is to make lowerToVectorReductions generic for both
-      // the old and the new pass. It will be removed once we deprecate the old
-      // pass.
+      // required so we don't have 2D source vectors at xegpu-linearize.
       auto reductionDim = reductionDims[0];
       result = xegpu::lowerToVectorReductions(
           cast<TypedValue<VectorType>>(adaptor.getSource()),
           cast<TypedValue<VectorType>>(adaptor.getAcc()), op.getKind(),
-          reductionDim, op.getLoc(), rewriter, /*setLayout=*/false);
+          reductionDim, op.getLoc(), rewriter);
     } else {
       auto reductionDim = reductionDims[0];
       VectorType sourceType = op.getSourceVectorType();
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index acae9bf1c9562..3bb6cf82f9ee4 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -671,8 +671,7 @@ Value xegpu::lowerToVectorReductions(TypedValue<VectorType> src,
                                      TypedValue<VectorType> acc,
                                      vector::CombiningKind kind,
                                      int64_t reductionDim, Location loc,
-                                     PatternRewriter &rewriter,
-                                     bool setLayout) {
+                                     PatternRewriter &rewriter) {
   VectorType sourceType = src.getType();
   int64_t sourceRank = sourceType.getRank();
   // Expecting at least a 2D source vector. Leading dimensions (all except the
@@ -691,13 +690,14 @@ Value xegpu::lowerToVectorReductions(TypedValue<VectorType> src,
   Value reductionResult = arith::ConstantOp::create(
       rewriter, loc, acc.getType(),
       DenseElementsAttr::get(acc.getType(), zeroAttr));
-  xegpu::DistributeLayoutAttr srcLayout, accLayout;
-  if (setLayout) {
-    srcLayout = xegpu::getTemporaryLayout(dyn_cast<OpResult>(src));
-    accLayout = xegpu::getTemporaryLayout(dyn_cast<OpResult>(acc));
-    // Reduction result should have the same layout as the accumulator.
-    xegpu::setTemporaryLayout(cast<OpResult>(reductionResult), accLayout);
-  }
+  // TODO: Remove these get/setTemporaryLayout calls after we deprecate the old
+  // pass.
+  xegpu::DistributeLayoutAttr srcLayout =
+      xegpu::getTemporaryLayout(dyn_cast<OpResult>(src));
+  xegpu::DistributeLayoutAttr accLayout =
+      xegpu::getTemporaryLayout(dyn_cast<OpResult>(acc));
+  // Reduction result should have the same layout as the accumulator.
+  xegpu::setTemporaryLayout(cast<OpResult>(reductionResult), accLayout);
   // For each slice of the source, extract the slice vector, do a reduction
   // and, insert the reduced value back to the result vector.
   int64_t accRank = acc.getType().getRank();
@@ -718,8 +718,7 @@ Value xegpu::lowerToVectorReductions(TypedValue<VectorType> src,
     vector::ExtractStridedSliceOp extractOp =
         vector::ExtractStridedSliceOp::create(rewriter, loc, src, sliceOffsets,
                                               sliceSizes, strides);
-    if (setLayout)
-      xegpu::setTemporaryLayout(extractOp->getOpResult(0), srcLayout);
+    xegpu::setTemporaryLayout(extractOp->getOpResult(0), srcLayout);
 
     int64_t nSliceElements = extractOp.getResult().getType().getNumElements();
 
@@ -728,10 +727,8 @@ Value xegpu::lowerToVectorReductions(TypedValue<VectorType> src,
         VectorType::get({nSliceElements}, sourceType.getElementType()),
         extractOp.getResult());
 
-    if (setLayout) {
-      xegpu::setTemporaryLayout(slice->getOpOperand(0), srcLayout);
-      xegpu::setTemporaryLayout(slice->getOpResult(0), accLayout);
-    }
+    xegpu::setTemporaryLayout(slice->getOpOperand(0), srcLayout);
+    xegpu::setTemporaryLayout(slice->getOpResult(0), accLayout);
     // Extract and reduction results in scalars, so no result layout is needed.
     // Build multi-dim index into acc (sourceRank-1 dims, i.e. source shape with
     // the reduction dim removed). Leading unit dims get index 0.
@@ -742,8 +739,7 @@ Value xegpu::lowerToVectorReductions(TypedValue<VectorType> src,
         rewriter, loc, kind, slice.getResult(), accExtract);
     reductionResult = vector::InsertOp::create(rewriter, loc, reduction,
                                                reductionResult, accIdx);
-    if (setLayout)
-      xegpu::setTemporaryLayout(cast<OpResult>(reductionResult), accLayout);
+    xegpu::setTemporaryLayout(cast<OpResult>(reductionResult), accLayout);
   }
   return reductionResult;
 }

>From 258e8ace7b2e54db271772459441d18d142114e5 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Thu, 9 Apr 2026 04:10:14 +0000
Subject: [PATCH 3/3] clean up

---
 mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index 3bb6cf82f9ee4..d06369c507e31 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -691,11 +691,9 @@ Value xegpu::lowerToVectorReductions(TypedValue<VectorType> src,
       rewriter, loc, acc.getType(),
       DenseElementsAttr::get(acc.getType(), zeroAttr));
   // TODO: Remove these get/setTemporaryLayout calls after we deprecate the old
-  // pass.
-  xegpu::DistributeLayoutAttr srcLayout =
-      xegpu::getTemporaryLayout(dyn_cast<OpResult>(src));
-  xegpu::DistributeLayoutAttr accLayout =
-      xegpu::getTemporaryLayout(dyn_cast<OpResult>(acc));
+  // XeGPUSubgroupDistribute pass.
+  auto srcLayout = xegpu::getTemporaryLayout(dyn_cast<OpResult>(src));
+  auto accLayout = xegpu::getTemporaryLayout(dyn_cast<OpResult>(acc));
   // Reduction result should have the same layout as the accumulator.
   xegpu::setTemporaryLayout(cast<OpResult>(reductionResult), accLayout);
   // For each slice of the source, extract the slice vector, do a reduction
@@ -718,6 +716,7 @@ Value xegpu::lowerToVectorReductions(TypedValue<VectorType> src,
     vector::ExtractStridedSliceOp extractOp =
         vector::ExtractStridedSliceOp::create(rewriter, loc, src, sliceOffsets,
                                               sliceSizes, strides);
+    // Extract strided slice has the same layout as src.
     xegpu::setTemporaryLayout(extractOp->getOpResult(0), srcLayout);
 
     int64_t nSliceElements = extractOp.getResult().getType().getNumElements();
@@ -727,6 +726,8 @@ Value xegpu::lowerToVectorReductions(TypedValue<VectorType> src,
         VectorType::get({nSliceElements}, sourceType.getElementType()),
         extractOp.getResult());
 
+    // Shape cast output has the same layout as the accumulator. Shape cast
+    // source has the same layout as the original reduction source.
     xegpu::setTemporaryLayout(slice->getOpOperand(0), srcLayout);
     xegpu::setTemporaryLayout(slice->getOpResult(0), accLayout);
     // Extract and reduction results in scalars, so no result layout is needed.
@@ -739,6 +740,7 @@ Value xegpu::lowerToVectorReductions(TypedValue<VectorType> src,
         rewriter, loc, kind, slice.getResult(), accExtract);
     reductionResult = vector::InsertOp::create(rewriter, loc, reduction,
                                                reductionResult, accIdx);
+    // Insert op should have the same layout as the accumulator.
     xegpu::setTemporaryLayout(cast<OpResult>(reductionResult), accLayout);
   }
   return reductionResult;



More information about the Mlir-commits mailing list