[Mlir-commits] [mlir] 50a7eb6 - [MLIR][XeGPU] Add support for vector.multi_reduction in wg to sg pass [1/N] (#157554)

llvmlistbot at llvm.org
Thu Sep 25 10:21:58 PDT 2025


Author: Nishant Patel
Date: 2025-09-25T10:21:54-07:00
New Revision: 50a7eb6fc2977d3a5c2d71d91a799d4275f5a595

URL: https://github.com/llvm/llvm-project/commit/50a7eb6fc2977d3a5c2d71d91a799d4275f5a595
DIFF: https://github.com/llvm/llvm-project/commit/50a7eb6fc2977d3a5c2d71d91a799d4275f5a595.diff

LOG: [MLIR][XeGPU] Add support for vector.multi_reduction in wg to sg pass [1/N] (#157554)

This PR adds a pattern for lowering vector.multi_reduction from workgroup
IR to subgroup IR. It currently supports only subgroup-local reductions.

Added: 
    

Modified: 
    mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
    mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
    mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index d7592fed6d186..9413a9296b184 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -1027,6 +1027,70 @@ struct WgToSgVectorShapeCastOp
   }
 };
 
+/// Pattern for lowering vector.multi_reduction op to subgroup level.
+/// Current limitation: the sg_layout in each reduced dimension must be 1,
+/// so that the reduction is local to a subgroup and no cross-subgroup
+/// communication is needed.
+/// TODO: Add cases to handle more general situations which require SLM access.
+struct WgToSgMultiDimReductionOp
+    : public OpConversionPattern<vector::MultiDimReductionOp> {
+  using OpConversionPattern<vector::MultiDimReductionOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::MultiDimReductionOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    VectorType srcType = op.getSourceVectorType();
+    VectorType dstType = dyn_cast<VectorType>(op.getResult().getType());
+    if (!dstType)
+      return failure();
+
+    auto srcShape = srcType.getShape();
+    xegpu::DistributeLayoutAttr layout =
+        xegpu::getDistributeLayoutAttr(op.getResult());
+    if (!layout || !layout.isForWorkgroup())
+      return failure();
+
+    auto reductionDims = llvm::to_vector(op.getReductionDims());
+
+    SmallVector<int64_t> sgLayout = llvm::cast<xegpu::SliceAttr>(layout)
+                                        .getParent()
+                                        .getEffectiveSgLayoutAsInt();
+    SmallVector<int64_t> sgData = llvm::cast<xegpu::SliceAttr>(layout)
+                                      .getParent()
+                                      .getEffectiveSgDataAsInt();
+
+    // Check that the sgLayout in the reduced dimension is 1 and
+    // each sg gets the entire slice to reduce.
+    for (int64_t dim : reductionDims) {
+      if (sgLayout[dim] != 1 || sgData[dim] != srcShape[dim])
+        return rewriter.notifyMatchFailure(
+            op,
+            "sgLayout in each reduced dimension must be 1 and sgData in the "
+            "reduced dim must match srcShape in that dim");
+    }
+
+    SmallVector<int64_t> sgShape = getSgShapeAndCount(srcShape, layout).first;
+
+    VectorType newDstType =
+        VectorType::get(sgShape, dstType.getElementType());
+
+    SmallVector<Value> newReductions;
+    for (auto sgSrc : adaptor.getSource()) {
+      auto newOp = rewriter.create<vector::MultiDimReductionOp>(
+          op.getLoc(), newDstType, op.getKind(), sgSrc, adaptor.getAcc()[0],
+          op.getReductionDims());
+      if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
+          !layout.getEffectiveInstDataAsInt().empty())
+        xegpu::setDistributeLayoutAttr(newOp->getResult(0),
+                                       layout.dropSgLayoutAndData());
+      newReductions.push_back(newOp.getResult());
+    }
+
+    rewriter.replaceOpWithMultiple(op, {newReductions});
+    return success();
+  }
+};
+
 } // namespace
 
 namespace mlir {
@@ -1040,8 +1104,8 @@ void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
            WgToSgElementwiseOp, WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp,
            WgToSgArithConstantOp, WgToSgLoadGatherOpWithOffset,
            WgToSgStoreScatterOpWithOffset, WgToSgLoadMatrixOp,
-           WgToSgStoreMatrixOp, WgToSgVectorStepOp, WgToSgVectorShapeCastOp>(
-          patterns.getContext());
+           WgToSgStoreMatrixOp, WgToSgVectorStepOp, WgToSgVectorShapeCastOp,
+           WgToSgMultiDimReductionOp>(patterns.getContext());
 }
 } // namespace xegpu
 } // namespace mlir
@@ -1195,6 +1259,11 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
         return isLegal(xegpu::getDistributeLayoutAttr(op.getResult()));
       });
 
+  target.addDynamicallyLegalOp<vector::MultiDimReductionOp>(
+      [=](vector::MultiDimReductionOp op) -> bool {
+        return isLegal(xegpu::getDistributeLayoutAttr(op.getResult()));
+      });
+
   target.addDynamicallyLegalOp<xegpu::ConvertLayoutOp>(
       [=](xegpu::ConvertLayoutOp op) -> bool {
         return isLegal(op.getInputLayout()) && isLegal(op.getTargetLayout());

diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
index 6ff7a94d678a3..dce73dee507e1 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
@@ -82,4 +82,20 @@ gpu.module @test_distribution {
       : vector<256x128xf16>, vector<128x256xf16> -> vector<256x256xf32>
     gpu.return
   }
+
+  // CHECK-LABEL: vector_reduce_dim_1
+  gpu.func @vector_reduce_dim_1(%src: memref<256x64xf32>) {
+    // CHECK: %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<16xf32>
+    %cst = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 1], sg_data = [16, 64]>, dims = [1]>} dense<1.0> : vector<256xf32>
+    %tdesc = xegpu.create_nd_tdesc %src : memref<256x64xf32>
+      -> !xegpu.tensor_desc<256x64xf32, #xegpu.layout<sg_layout = [8, 1], sg_data = [16, 64]>>
+    %load = xegpu.load_nd %tdesc[0, 0]
+      : !xegpu.tensor_desc<256x64xf32, #xegpu.layout<sg_layout = [8, 1], sg_data = [16, 64]>>
+      -> vector<256x64xf32>
+    // CHECK-COUNT-2: vector.multi_reduction <add>, {{.*}}, %[[CST]] [1] : vector<16x64xf32> to vector<16xf32>
+    // CHECK-NOT: vector.multi_reduction
+    %reduce = vector.multi_reduction <add>, %load, %cst {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 1], sg_data = [16, 64]>, dims = [1]>} [1]
+      : vector<256x64xf32> to vector<256xf32>
+    gpu.return
+  }
 }

diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
index 3478a9b91da5f..48fc633974e63 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -367,6 +367,46 @@ gpu.module @test_distribution {
     gpu.return
   }
 
+  // CHECK-LABEL: @vector_reduce_dim_0
+  gpu.func @vector_reduce_dim_0(%src: memref<4x128xf32>) {
+    %cst = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [1, 32], sg_data = [4, 4]>, dims = [0]>} dense<1.0> : vector<128xf32>
+    %tdesc = xegpu.create_nd_tdesc %src : memref<4x128xf32>
+      -> !xegpu.tensor_desc<4x128xf32, #xegpu.layout<sg_layout = [1, 32], sg_data = [4, 4]>>
+    %load = xegpu.load_nd %tdesc[0, 0]
+      : !xegpu.tensor_desc<4x128xf32, #xegpu.layout<sg_layout = [1, 32], sg_data = [4, 4]>>
+      -> vector<4x128xf32>
+    // CHECK: vector.multi_reduction <add>, {{.*}}, {{.*}} [0] : vector<4x4xf32> to vector<4xf32>
+    %reduce = vector.multi_reduction <add>, %load, %cst {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [1, 32], sg_data = [4, 4]>, dims = [0]>} [0]
+      : vector<4x128xf32> to vector<128xf32>
+    gpu.return
+  }
+
+  // CHECK-LABEL: @vector_reduce_dim_1
+  gpu.func @vector_reduce_dim_1(%src: memref<256x64xf32>) {
+    %cst = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [16, 1], sg_data = [16, 64]>, dims = [1]>} dense<1.0> : vector<256xf32>
+    %tdesc = xegpu.create_nd_tdesc %src : memref<256x64xf32>
+      -> !xegpu.tensor_desc<256x64xf32, #xegpu.layout<sg_layout = [16, 1], sg_data = [16, 64]>>
+    %load = xegpu.load_nd %tdesc[0, 0]
+      : !xegpu.tensor_desc<256x64xf32, #xegpu.layout<sg_layout = [16, 1], sg_data = [16, 64]>>
+      -> vector<256x64xf32>
+    // CHECK: vector.multi_reduction <add>, {{.*}}, {{.*}} [1] : vector<16x64xf32> to vector<16xf32>
+    %reduce = vector.multi_reduction <add>, %load, %cst {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [16, 1], sg_data = [16, 64]>, dims = [1]>} [1]
+      : vector<256x64xf32> to vector<256xf32>
+    gpu.return
+  }
+
+  // CHECK-LABEL: @vector_reduce_4D
+  gpu.func @vector_reduce_4D(%src: ui64) {
+    %cst_acc = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [4, 2, 6, 1], sg_data = [1, 1, 1, 32]>, dims = [3]>} dense<0.0> : vector<4x2x6xf16>
+    %offset = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [4, 2, 6, 1], sg_data = [1, 1, 1, 32]>} dense<0> : vector<4x2x6x32xindex>
+    %mask = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [4, 2, 6, 1], sg_data = [1, 1, 1, 32]>} dense<true> : vector<4x2x6x32xi1>
+    %load = xegpu.load %src[%offset], %mask {layout_result_0 = #xegpu.layout<sg_layout = [4, 2, 6, 1], sg_data = [1, 1, 1, 32]>} : ui64, vector<4x2x6x32xindex>, vector<4x2x6x32xi1> -> vector<4x2x6x32xf16>
+    // CHECK: vector.multi_reduction <add>, {{.*}}, {{.*}} [3] : vector<1x1x1x32xf16> to vector<1x1x1xf16>
+    %reduce = vector.multi_reduction <add>, %load, %cst_acc {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [4, 2, 6, 1], sg_data = [1, 1, 1, 32]>, dims = [3]>} [3]
+      : vector<4x2x6x32xf16> to vector<4x2x6xf16>
+    gpu.return
+  }
+
   // CHECK-LABEL: vector_step_op
   gpu.func @vector_step_op_slice_attr() {
     //CHECK: [[sgId:%.+]] = gpu.subgroup_id : index



