[Mlir-commits] [mlir] [mlir][xegpu] Add support for `vector.multi_reduction` and `vector.shape_cast` SIMT distribution. (PR #157560)
Nishant Patel
llvmlistbot at llvm.org
Fri Sep 12 10:13:23 PDT 2025
================
@@ -1001,12 +1011,282 @@ struct LoadDistribution final : public gpu::WarpDistributionPattern {
}
};
+/// Helper to rewrite a 2D VectorMultiReductionOp into a sequence of 1D
+/// VectorReductionOps.
+static Value lowerToVectorReductions(TypedValue<VectorType> src,
+ TypedValue<VectorType> acc,
+ vector::CombiningKind kind,
+ int64_t reductionDim, Location loc,
+ PatternRewriter &rewriter) {
+ // Expecting a 2D source vector.
+ assert(src.getType().getRank() == 2 && "expected a 2D source vector");
+ VectorType sourceType = src.getType();
+ int64_t sourceH = sourceType.getShape()[0];
+ int64_t sourceW = sourceType.getShape()[1];
+ int nSlices = (reductionDim == 0) ? sourceW : sourceH;
+ // Create a constant vector to hold the result of the reduction.
+ TypedAttr zeroAttr = rewriter.getZeroAttr(sourceType.getElementType());
+ Value reductionResult = arith::ConstantOp::create(
+ rewriter, loc, acc.getType(),
+ DenseElementsAttr::get(acc.getType(), zeroAttr));
+ // For each slice of the source, extract the slice vector, do a reduction
+ // and insert the reduced value back into the result vector.
+ for (int i = 0; i < nSlices; ++i) {
+ SmallVector<int64_t, 2> sliceOffsets, sliceSizes;
+ if (reductionDim == 1) {
+ sliceOffsets = {i, 0};
+ sliceSizes = {1, sourceW};
+ } else {
+ sliceOffsets = {0, i};
+ sliceSizes = {sourceH, 1};
+ }
+ vector::ExtractStridedSliceOp extractOp =
+ vector::ExtractStridedSliceOp::create(rewriter, loc, src, sliceOffsets,
+ sliceSizes, {1, 1});
+ int64_t nSliceElements = extractOp.getResult().getType().getNumElements();
+ Value slice = vector::ShapeCastOp::create(
+ rewriter, loc,
+ VectorType::get({nSliceElements}, sourceType.getElementType()),
+ extractOp.getResult());
+ Value accExtract = vector::ExtractOp::create(rewriter, loc, acc, i);
+ Value reduction =
+ vector::ReductionOp::create(rewriter, loc, kind, slice, accExtract);
+ reductionResult =
+ vector::InsertOp::create(rewriter, loc, reduction, reductionResult, i);
+ }
+ return reductionResult;
+}
+
+/// This pattern distributes the `vector.multi_reduction` operation across
+/// lanes in a warp. Currently only 2D to 1D reductions are supported. Given
+/// layouts for the source and accumulator vectors,
+/// * If the reduction dimension is distributed across lanes, the reduction
+/// is non-lane-local and is done using warp shuffles. Here we simply
+/// rewrite the MultiDimReductionOp to a sequence of ReductionOps in the
+/// warp op body.
+/// * If the reduction dimension is not distributed across lanes, the reduction
+/// is lane-local. In this case, we yield the source and accumulator vectors
+/// from the warp op and perform the lane-local reduction outside the warp op
+/// using a sequence of ReductionOps.
+/// Example 1 (Reduction is lane-local):
+/// ```
+/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
+/// %0 = "some_def"() : () -> (vector<16x32xf32>)
+/// %acc = "some_def"() : () -> (vector<32xf32>)
+/// %1 = vector.multi_reduction <add>, %0, %acc [0] : vector<16x32xf32> to
+/// vector<32xf32>
+/// gpu.yield %1 : vector<32xf32>
+/// }
+/// ```
+/// is lowered to:
+/// ```
+/// %r:2 = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<16x1xf32>,
+/// vector<1xf32>) {
+/// %0 = "some_def"() : () -> (vector<16x32xf32>)
+/// %acc = "some_def"() : () -> (vector<32xf32>)
+/// gpu.yield %0, %acc : vector<16x32xf32>, vector<32xf32>
+/// }
+/// %c = arith.constant dense<0.0> : vector<1xf32>
+/// %1 = vector.shape_cast %r#0 : vector<16x1xf32> to vector<16xf32>
+/// %2 = vector.reduction <add>, %1, %r#1 : vector<16xf32> into f32
+/// %3 = vector.insert %2, %c[0] : f32 into vector<1xf32>
+/// ```
+/// Example 2 (Reduction is non-lane-local):
+/// ```
+/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) {
+/// %0 = "some_def"() : () -> (vector<2x32xf32>)
+/// %acc = "some_def"() : () -> (vector<2xf32>)
+/// %1 = vector.multi_reduction <add>, %0, %acc [1] : vector<2x32xf32> to
+/// vector<2xf32>
+/// gpu.yield %1 : vector<2xf32>
+/// }
+/// ```
+/// is lowered to:
+/// ```
+/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) {
+/// %0 = "some_def"() : () -> (vector<2x32xf32>)
+/// %acc = "some_def"() : () -> (vector<2xf32>)
+/// %1 = arith.constant dense<0.0> : vector<2xf32>
+/// %2 = vector.extract %0[0] : vector<32xf32> from vector<2x32xf32>
+/// %3 = ("warp.reduction %2") : f32
+/// %4 = vector.insert %3, %1[0] : f32 into vector<2xf32>
+/// ... repeat for row 1
+/// gpu.yield %1 : vector<2xf32>
+/// }
+/// ```
+struct VectorMultiReductionDistribution : public gpu::WarpDistributionPattern {
+ using gpu::WarpDistributionPattern::WarpDistributionPattern;
+ LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+ PatternRewriter &rewriter) const override {
+ OpOperand *yieldOperand =
+ getWarpResult(warpOp, llvm::IsaPred<vector::MultiDimReductionOp>);
+ if (!yieldOperand)
+ return failure();
+ auto reductionOp =
+ cast<vector::MultiDimReductionOp>(yieldOperand->get().getDefiningOp());
+ unsigned operandNumber = yieldOperand->getOperandNumber();
+ VectorType sourceType = reductionOp.getSourceVectorType();
+ // Only 2D vectors are supported.
+ if (sourceType.getRank() != 2)
+ return rewriter.notifyMatchFailure(warpOp,
+ "Only 2D reductions are supported.");
+ ArrayRef<int64_t> reductionDims = reductionOp.getReductionDims();
+ // Only 1 reduction dimension is supported. This also ensures that the
+ // result is a vector type.
+ if (reductionDims.size() != 1)
+ return rewriter.notifyMatchFailure(
+ warpOp, "Only 1 reduction dimension is supported.");
+ int64_t reductionDim = reductionDims[0];
+ VectorType distributedResultType =
+ cast<VectorType>(warpOp.getResult(operandNumber).getType());
+ VectorType resultType = cast<VectorType>(reductionOp.getType());
+ xegpu::DistributeLayoutAttr sourceLayout =
+ xegpu::getDistributeLayoutAttr(reductionOp.getSource());
+
+ FailureOr<VectorType> sourceDistTypeOrFailure =
+ getDistVecTypeBasedOnLaneLayout(sourceLayout, sourceType);
+ if (failed(sourceDistTypeOrFailure))
+ return rewriter.notifyMatchFailure(
+ warpOp, "Failed to distribute the source vector type.");
+ VectorType sourceDistType = sourceDistTypeOrFailure.value();
+ // Only single dimension distribution is supported.
+ bool dim0Distributed =
+ sourceDistType.getShape()[0] != sourceType.getShape()[0];
+ bool dim1Distributed =
+ sourceDistType.getShape()[1] != sourceType.getShape()[1];
+ if (dim0Distributed && dim1Distributed)
+ return rewriter.notifyMatchFailure(
+ warpOp, "Expecting source to be distributed in a single dimension.");
+ int64_t sourceDistDim = dim0Distributed ? 0 : (dim1Distributed ? 1 : -1);
+ if (sourceDistDim == -1)
+ return rewriter.notifyMatchFailure(
+ warpOp, "Expecting a distributed source vector.");
+ bool resultDistributed =
+ distributedResultType.getNumElements() < resultType.getNumElements();
+ // If the lane owns all the data required for the reduction (i.e. the
+ // reduction is fully parallel across lanes), then each lane owns part of
+ // the result (i.e. the result is distributed). If the reduction requires
+ // cross-lane shuffling, then the result is shared among all lanes
+ // (broadcasted). Therefore we expect the following cases:
+ //
+ // | Source vector | Reduction dim | Result vector |
+ // |----------------------|----------------|----------------|
+ // | dim-0 distributed | 0 | broadcasted |
+ // | dim-0 distributed | 1 | distributed |
+ // | dim-1 distributed | 0 | distributed |
+ // | dim-1 distributed | 1 | broadcasted |
+
+ bool isReductionLaneLocal = (sourceDistDim == 0 && reductionDim == 1) ||
+ (sourceDistDim == 1 && reductionDim == 0);
+ if (isReductionLaneLocal && !resultDistributed)
+ return rewriter.notifyMatchFailure(
+ warpOp, "Expecting a distributed result for lane-local reduction.");
+
+ if (!isReductionLaneLocal && resultDistributed)
+ return rewriter.notifyMatchFailure(
+ warpOp,
+ "Expecting a broadcasted result for non-lane-local reduction.");
+
+ // Handle lane-local reduction case. In this case we fully distribute the
+ // reduction result.
+ if (isReductionLaneLocal) {
+ // Yield the source and acc vectors from the WarpOp.
+ SmallVector<size_t> newRetIndices;
+ auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+ rewriter, warpOp, {reductionOp.getSource(), reductionOp.getAcc()},
+ {sourceDistType, distributedResultType}, newRetIndices);
+ rewriter.setInsertionPointAfter(newWarpOp);
+ Value result = lowerToVectorReductions(
+ cast<TypedValue<VectorType>>(newWarpOp->getResult(newRetIndices[0])),
+ cast<TypedValue<VectorType>>(newWarpOp->getResult(newRetIndices[1])),
+ reductionOp.getKind(), reductionDim, reductionOp.getLoc(), rewriter);
+ // Replace the warp op result with the final result.
+ rewriter.replaceAllUsesWith(reductionOp.getResult(), result);
+ return success();
+ }
+ // For the non-lane-local case, we simply rewrite the MultiReductionOp in
+ // terms of multiple ReductionOps. The actual distribution is done by the
+ // WarpOpReduction pattern.
+ rewriter.setInsertionPointAfter(reductionOp);
+ Value result = lowerToVectorReductions(
+ cast<TypedValue<VectorType>>(reductionOp.getSource()),
+ cast<TypedValue<VectorType>>(reductionOp.getAcc()),
+ reductionOp.getKind(), reductionDim, reductionOp.getLoc(), rewriter);
+ // Replace the warp op result with the final result.
+ rewriter.replaceAllUsesWith(reductionOp.getResult(), result);
+ return success();
+ }
+};
+
+/// Distribute a `vector.shape_cast` op feeding into the yield op of an
+/// enclosing `gpu.warp_execute_on_lane_0` region.
+struct VectorShapeCastDistribution : public gpu::WarpDistributionPattern {
+ using gpu::WarpDistributionPattern::WarpDistributionPattern;
+ LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+ PatternRewriter &rewriter) const override {
+ OpOperand *yieldOperand =
+ getWarpResult(warpOp, llvm::IsaPred<vector::ShapeCastOp>);
+ if (!yieldOperand)
+ return failure();
+ auto shapeCastOp =
+ cast<vector::ShapeCastOp>(yieldOperand->get().getDefiningOp());
+ unsigned operandNumber = yieldOperand->getOperandNumber();
+ auto resultDistTy =
+ cast<VectorType>(warpOp.getResult(operandNumber).getType());
+ xegpu::DistributeLayoutAttr sourceLayout =
+ xegpu::getDistributeLayoutAttr(shapeCastOp.getSource());
+ xegpu::DistributeLayoutAttr resultLayout =
+ xegpu::getDistributeLayoutAttr(shapeCastOp.getResult());
+ if (!sourceLayout || !resultLayout)
+ return rewriter.notifyMatchFailure(
+ warpOp,
+ "the source or result of shape_cast op lacks distribution layout");
+
+ // For rank-reducing or rank-increasing shape_cast ops, the lower-rank
+ // layout must be a slice of the higher-rank layout.
+ int64_t sourceRank = shapeCastOp.getSourceVectorType().getRank();
+ int64_t resultRank = shapeCastOp.getResultVectorType().getRank();
+ if (sourceRank < resultRank && !sourceLayout.isSliceOf(resultLayout))
+ return rewriter.notifyMatchFailure(
+ warpOp, "shape_cast is rank increasing but source layout is not a "
+ "slice of result layout");
+ if (sourceRank > resultRank && !resultLayout.isSliceOf(sourceLayout))
+ return rewriter.notifyMatchFailure(
+ warpOp, "shape_cast is rank reducing but result layout is not a "
+ "slice of source layout");
+
----------------
nbpatel wrote:
The shape_cast pattern needs to check that only unit dims can be squeezed/expanded.
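A minimal sketch of what such a check could look like, purely as an illustration (the helper name `isUnitDimSqueezeOrExpand` and its placement are not from this PR, and it assumes the usual includes already present in the file):

```cpp
// Illustrative sketch: accept a shape_cast only if it merely squeezes or
// expands unit dimensions, i.e. the non-unit dims of the source and result
// shapes match in value and order. Helper name is hypothetical.
static bool isUnitDimSqueezeOrExpand(VectorType sourceTy, VectorType resultTy) {
  SmallVector<int64_t> srcDims, resDims;
  llvm::copy_if(sourceTy.getShape(), std::back_inserter(srcDims),
                [](int64_t d) { return d != 1; });
  llvm::copy_if(resultTy.getShape(), std::back_inserter(resDims),
                [](int64_t d) { return d != 1; });
  return srcDims == resDims;
}

// Possible use in VectorShapeCastDistribution::matchAndRewrite, next to the
// existing rank/layout checks:
//   if (!isUnitDimSqueezeOrExpand(shapeCastOp.getSourceVectorType(),
//                                 shapeCastOp.getResultVectorType()))
//     return rewriter.notifyMatchFailure(
//         warpOp, "shape_cast must only squeeze/expand unit dimensions");
```

With such a check, e.g. vector<16x1xf32> -> vector<16xf32> would be accepted, while vector<16x2xf32> -> vector<32xf32> would be rejected.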
https://github.com/llvm/llvm-project/pull/157560