[Mlir-commits] [mlir] [mlir][xegpu] Add support for `vector.multi_reduction` and `vector.shape_cast` SIMT distribution. (PR #157560)

Nishant Patel llvmlistbot at llvm.org
Fri Sep 12 10:13:23 PDT 2025


================
@@ -1001,12 +1011,282 @@ struct LoadDistribution final : public gpu::WarpDistributionPattern {
   }
 };
 
+/// Helper to rewrite a 2D `vector.multi_reduction` into a sequence of 1D
+/// `vector.reduction` ops.
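+/// For example (an illustrative sketch only; SSA names are made up), reducing
+/// dim 1 of a vector<2x4xf32> source with a vector<2xf32> accumulator
+/// produces roughly:
+/// ```
+/// %c = arith.constant dense<0.0> : vector<2xf32>
+/// %s0 = vector.extract_strided_slice %src {offsets = [0, 0], sizes = [1, 4],
+///   strides = [1, 1]} : vector<2x4xf32> to vector<1x4xf32>
+/// %f0 = vector.shape_cast %s0 : vector<1x4xf32> to vector<4xf32>
+/// %a0 = vector.extract %acc[0] : f32 from vector<2xf32>
+/// %r0 = vector.reduction <add>, %f0, %a0 : vector<4xf32> into f32
+/// %v0 = vector.insert %r0, %c[0] : f32 into vector<2xf32>
+/// // ... same for row 1, inserting into %v0.
+/// ```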
+static Value lowerToVectorReductions(TypedValue<VectorType> src,
+                                     TypedValue<VectorType> acc,
+                                     vector::CombiningKind kind,
+                                     int64_t reductionDim, Location loc,
+                                     PatternRewriter &rewriter) {
+  // Expecting a 2D source vector.
+  assert(src.getType().getRank() == 2 && "expected a 2D source vector");
+  VectorType sourceType = src.getType();
+  int64_t sourceH = sourceType.getShape()[0];
+  int64_t sourceW = sourceType.getShape()[1];
+  int nSlices = (reductionDim == 0) ? sourceW : sourceH;
+  // Create a constant vector to hold the result of the reduction.
+  TypedAttr zeroAttr = rewriter.getZeroAttr(sourceType.getElementType());
+  Value reductionResult = arith::ConstantOp::create(
+      rewriter, loc, acc.getType(),
+      DenseElementsAttr::get(acc.getType(), zeroAttr));
+  // For each slice of the source, extract the slice vector, perform a
+  // reduction, and insert the reduced value back into the result vector.
+  for (int i = 0; i < nSlices; ++i) {
+    SmallVector<int64_t, 2> sliceOffsets, sliceSizes;
+    if (reductionDim == 1) {
+      sliceOffsets = {i, 0};
+      sliceSizes = {1, sourceW};
+    } else {
+      sliceOffsets = {0, i};
+      sliceSizes = {sourceH, 1};
+    }
+    vector::ExtractStridedSliceOp extractOp =
+        vector::ExtractStridedSliceOp::create(rewriter, loc, src, sliceOffsets,
+                                              sliceSizes, {1, 1});
+    int64_t nSliceElements = extractOp.getResult().getType().getNumElements();
+    Value slice = vector::ShapeCastOp::create(
+        rewriter, loc,
+        VectorType::get({nSliceElements}, sourceType.getElementType()),
+        extractOp.getResult());
+    Value accExtract = vector::ExtractOp::create(rewriter, loc, acc, i);
+    Value reduction =
+        vector::ReductionOp::create(rewriter, loc, kind, slice, accExtract);
+    reductionResult =
+        vector::InsertOp::create(rewriter, loc, reduction, reductionResult, i);
+  }
+  return reductionResult;
+}
+
+/// This pattern distributes the `vector.multi_reduction` operation across
+/// lanes in a warp. Currently only 2D to 1D reductions are supported. Given
+/// layouts for the source and accumulator vectors,
+/// * If the reduction dimension is distributed across lanes, the reduction is
+///   non-lane-local and must be done using warp shuffles. Here we simply
+///   rewrite the MultiDimReductionOp to a sequence of ReductionOps inside
+///   the warp op body.
+/// * If the reduction dimension is not distributed across lanes, the reduction
+///   is lane-local. In this case, we yield the source and accumulator vectors
+///   from the warp op and perform the lane-local reduction outside the warp op
+///   using a sequence of ReductionOps.
+/// Example 1 (Reduction is lane-local):
+/// ```
+/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
+///   %0 = "some_def"() : () -> (vector<16x32xf32>)
+///   %acc = "some_def"() : () -> (vector<32xf32>)
+///   %1 = vector.multi_reduction <add>, %0, %acc [0]
+///     : vector<16x32xf32> to vector<32xf32>
+///   gpu.yield %1 : vector<32xf32>
+/// }
+/// ```
+/// is lowered to:
+/// ```
+/// %r:2 = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<16x1xf32>,
+/// vector<1xf32>) {
+///   %0 = "some_def"() : () -> (vector<16x32xf32>)
+///   %acc = "some_def"() : () -> (vector<32xf32>)
+///   gpu.yield %0, %acc : vector<16x32xf32>, vector<32xf32>
+/// }
+/// %c = arith.constant dense<0.0> : vector<1xf32>
+/// %1 = vector.shape_cast %r#0 : vector<16x1xf32> to vector<16xf32>
+/// %2 = vector.extract %r#1[0] : f32 from vector<1xf32>
+/// %3 = vector.reduction <add>, %1, %2 : vector<16xf32> into f32
+/// %4 = vector.insert %3, %c[0] : f32 into vector<1xf32>
+/// ```
+/// Example 2 (Reduction is non-lane-local):
+/// ```
+/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) {
+///   %0 = "some_def"() : () -> (vector<2x32xf32>)
+///   %acc = "some_def"() : () -> (vector<2xf32>)
+///   %1 = vector.multi_reduction <add>, %0, %acc [1]
+///     : vector<2x32xf32> to vector<2xf32>
+///   gpu.yield %1 : vector<2xf32>
+/// }
+/// ```
+/// is lowered to:
+/// ```
+/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) {
+///   %0 = "some_def"() : () -> (vector<2x32xf32>)
+///   %acc = "some_def"() : () -> (vector<2xf32>)
+///   %1 = arith.constant dense<0.0> : vector<2xf32>
+///   %2 = vector.extract %0[0] : vector<32xf32> from vector<2x32xf32>
+///   %3 = ("warp.reduction %2") : f32
+///   %4 = vector.insert %3, %1[0] : f32 into vector<2xf32>
+///   %5 = vector.extract %0[1] : vector<32xf32> from vector<2x32xf32>
+///   %6 = ("warp.reduction %5") : f32
+///   %7 = vector.insert %6, %4[1] : f32 into vector<2xf32>
+///   gpu.yield %7 : vector<2xf32>
+/// }
+/// ```
+struct VectorMultiReductionDistribution : public gpu::WarpDistributionPattern {
+  using gpu::WarpDistributionPattern::WarpDistributionPattern;
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    OpOperand *yieldOperand =
+        getWarpResult(warpOp, llvm::IsaPred<vector::MultiDimReductionOp>);
+    if (!yieldOperand)
+      return failure();
+    auto reductionOp =
+        cast<vector::MultiDimReductionOp>(yieldOperand->get().getDefiningOp());
+    unsigned operandNumber = yieldOperand->getOperandNumber();
+    VectorType sourceType = reductionOp.getSourceVectorType();
+    // Only 2D vectors are supported.
+    if (sourceType.getRank() != 2)
+      return rewriter.notifyMatchFailure(warpOp,
+                                         "Only 2D reductions are supported.");
+    ArrayRef<int64_t> reductionDims = reductionOp.getReductionDims();
+    // Only 1 reduction dimension supported. This also ensures that the result
+    // is vector type.
+    if (reductionDims.size() != 1)
+      return rewriter.notifyMatchFailure(
+          warpOp, "Only 1 reduction dimension is supported.");
+    int64_t reductionDim = reductionDims[0];
+    VectorType distributedResultType =
+        cast<VectorType>(warpOp.getResult(operandNumber).getType());
+    VectorType resultType = cast<VectorType>(reductionOp.getType());
+    xegpu::DistributeLayoutAttr sourceLayout =
+        xegpu::getDistributeLayoutAttr(reductionOp.getSource());
+
+    FailureOr<VectorType> sourceDistTypeOrFailure =
+        getDistVecTypeBasedOnLaneLayout(sourceLayout, sourceType);
+    if (failed(sourceDistTypeOrFailure))
+      return rewriter.notifyMatchFailure(
+          warpOp, "Failed to distribute the source vector type.");
+    VectorType sourceDistType = sourceDistTypeOrFailure.value();
+    // Only single dimension distribution is supported.
+    bool dim0Distributed =
+        sourceDistType.getShape()[0] != sourceType.getShape()[0];
+    bool dim1Distributed =
+        sourceDistType.getShape()[1] != sourceType.getShape()[1];
+    if (dim0Distributed && dim1Distributed)
+      return rewriter.notifyMatchFailure(
+          warpOp, "Expecting source to be distributed in a single dimension.");
+    int64_t sourceDistDim = dim0Distributed ? 0 : (dim1Distributed ? 1 : -1);
+    if (sourceDistDim == -1)
+      return rewriter.notifyMatchFailure(
+          warpOp, "Expecting a distributed source vector.");
+    bool resultDistributed =
+        distributedResultType.getNumElements() < resultType.getNumElements();
+    // If the lane owns all the data required for the reduction (i.e. the
+    // reduction is fully parallel across lanes), then each lane owns part of
+    // the result (i.e. the result is distributed). If the reduction requires
+    // cross-lane shuffling, then the result is shared among all lanes
+    // (broadcasted). Therefore we expect the following cases:
+    //
+    // | Source vector        | Reduction dim  | Result vector  |
+    // |----------------------|----------------|----------------|
+    // |  dim-0 distributed   |       0        | broadcasted    |
+    // |  dim-0 distributed   |       1        | distributed    |
+    // |  dim-1 distributed   |       0        | distributed    |
+    // |  dim-1 distributed   |       1        | broadcasted    |
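+    //
+    // E.g. with the source distributed as lane_layout = [1, 32] (dim-1
+    // distributed) and reductionDim = 1, every result element needs data
+    // from all 32 lanes, so the result must be broadcasted.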
+
+    bool isReductionLaneLocal = (sourceDistDim == 0 && reductionDim == 1) ||
+                                (sourceDistDim == 1 && reductionDim == 0);
+    if (isReductionLaneLocal && !resultDistributed)
+      return rewriter.notifyMatchFailure(
+          warpOp, "Expecting a distributed result for lane-local reduction.");
+
+    if (!isReductionLaneLocal && resultDistributed)
+      return rewriter.notifyMatchFailure(
+          warpOp,
+          "Expecting a broadcasted result for non-lane-local reduction.");
+
+    // Handle lane-local reduction case. In this case we fully distribute the
+    // reduction result.
+    if (isReductionLaneLocal) {
+      // Yield the source and acc vectors from the WarpOp.
+      SmallVector<size_t> newRetIndices;
+      auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+          rewriter, warpOp, {reductionOp.getSource(), reductionOp.getAcc()},
+          {sourceDistType, distributedResultType}, newRetIndices);
+      rewriter.setInsertionPointAfter(newWarpOp);
+      Value result = lowerToVectorReductions(
+          cast<TypedValue<VectorType>>(newWarpOp->getResult(newRetIndices[0])),
+          cast<TypedValue<VectorType>>(newWarpOp->getResult(newRetIndices[1])),
+          reductionOp.getKind(), reductionDim, reductionOp.getLoc(), rewriter);
+      // Replace the warp op result with the final result.
+      rewriter.replaceAllUsesWith(reductionOp.getResult(), result);
+      return success();
+    }
+    // For the non-lane-local case, we simply rewrite the MultiDimReductionOp
+    // as a sequence of ReductionOps. Actual distribution is done by the
+    // WarpOpReduction pattern.
+    rewriter.setInsertionPointAfter(reductionOp);
+    Value result = lowerToVectorReductions(
+        cast<TypedValue<VectorType>>(reductionOp.getSource()),
+        cast<TypedValue<VectorType>>(reductionOp.getAcc()),
+        reductionOp.getKind(), reductionDim, reductionOp.getLoc(), rewriter);
+    // Replace the warp op result with the final result.
+    rewriter.replaceAllUsesWith(reductionOp.getResult(), result);
+    return success();
+  }
+};
+
+/// Distribute a `vector.shape_cast` op feeding into yield op of an enclosing
+/// `gpu.warp_execute_on_lane_0` region.
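+/// A sketch of the intended rewrite (illustrative only; it assumes the usual
+/// scheme of yielding the source from the warp op and applying the cast to
+/// the distributed value, with a [1, 32] lane layout on the 2-D source):
+/// ```
+/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
+///   %0 = "some_def"() : () -> (vector<1x32xf32>)
+///   %1 = vector.shape_cast %0 : vector<1x32xf32> to vector<32xf32>
+///   gpu.yield %1 : vector<32xf32>
+/// }
+/// ```
+/// is lowered to:
+/// ```
+/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1x1xf32>) {
+///   %0 = "some_def"() : () -> (vector<1x32xf32>)
+///   gpu.yield %0 : vector<1x32xf32>
+/// }
+/// %1 = vector.shape_cast %r : vector<1x1xf32> to vector<1xf32>
+/// ```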
+struct VectorShapeCastDistribution : public gpu::WarpDistributionPattern {
+  using gpu::WarpDistributionPattern::WarpDistributionPattern;
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    OpOperand *yieldOperand =
+        getWarpResult(warpOp, llvm::IsaPred<vector::ShapeCastOp>);
+    if (!yieldOperand)
+      return failure();
+    auto shapeCastOp =
+        cast<vector::ShapeCastOp>(yieldOperand->get().getDefiningOp());
+    unsigned operandNumber = yieldOperand->getOperandNumber();
+    auto resultDistTy =
+        cast<VectorType>(warpOp.getResult(operandNumber).getType());
+    xegpu::DistributeLayoutAttr sourceLayout =
+        xegpu::getDistributeLayoutAttr(shapeCastOp.getSource());
+    xegpu::DistributeLayoutAttr resultLayout =
+        xegpu::getDistributeLayoutAttr(shapeCastOp.getResult());
+    if (!sourceLayout || !resultLayout)
+      return rewriter.notifyMatchFailure(
+          warpOp,
+          "the source or result of shape_cast op lacks distribution layout");
+
+    // For rank-reducing or rank-increasing shape_cast ops, the lower-rank
+    // layout must be a slice of the higher-rank layout.
+    int64_t sourceRank = shapeCastOp.getSourceVectorType().getRank();
+    int64_t resultRank = shapeCastOp.getResultVectorType().getRank();
+    if (sourceRank < resultRank && !sourceLayout.isSliceOf(resultLayout))
+      return rewriter.notifyMatchFailure(
+          warpOp, "shape_cast is rank increasing but source layout is not a "
+                  "slice of result layout");
+    if (sourceRank > resultRank && !resultLayout.isSliceOf(sourceLayout))
+      return rewriter.notifyMatchFailure(
+          warpOp, "shape_cast is rank reducing but result layout is not a "
+                  "slice of source layout");
+
----------------
nbpatel wrote:

The shape_cast pattern needs to check that only unit dims can be squeezed/expanded.
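
A minimal sketch of such a check (a hypothetical helper, not code from this PR): accept the cast only when dropping all unit dims from the source and result shapes leaves identical dim lists.

```cpp
// Hypothetical: return true iff the shape_cast only squeezes/expands unit
// dims, i.e. the non-unit dims of source and result match exactly.
static bool castsOnlyUnitDims(VectorType srcTy, VectorType resTy) {
  SmallVector<int64_t> srcDims, resDims;
  llvm::copy_if(srcTy.getShape(), std::back_inserter(srcDims),
                [](int64_t d) { return d != 1; });
  llvm::copy_if(resTy.getShape(), std::back_inserter(resDims),
                [](int64_t d) { return d != 1; });
  return srcDims == resDims;
}
```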

https://github.com/llvm/llvm-project/pull/157560

