[Mlir-commits] [mlir] 6b5c440 - [mlir][xegpu] Add support for `vector.reduction` and `vector.multi_reduction` subgroup to work-item distribution. (#180308)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Feb 13 11:50:00 PST 2026
Author: Charitha Saumya
Date: 2026-02-13T11:49:55-08:00
New Revision: 6b5c440a676defbbbef3f6a42404b0be2b3da54c
URL: https://github.com/llvm/llvm-project/commit/6b5c440a676defbbbef3f6a42404b0be2b3da54c
DIFF: https://github.com/llvm/llvm-project/commit/6b5c440a676defbbbef3f6a42404b0be2b3da54c.diff
LOG: [mlir][xegpu] Add support for `vector.reduction` and `vector.multi_reduction` subgroup to work-item distribution. (#180308)
This PR adds support for lowering of `vector.reduction` and
`vector.multi_reduction` ops in subgroup to work-item distribution.
The following cases are currently considered (more support will be added
later):
* `vector.reduction`: This assumes the source vector is distributed across
all lanes, so lanes must shuffle data to perform a collaborative reduction.
The result is shared by all lanes. This is done by emitting
`gpu::ShuffleOp`s and performing a butterfly reduction. Refer to
`VectorDistribution` for more details.
* `vector.multi_reduction`: two cases are considered:
1. **Reduction is lane-local:** simply lower to a lane-local multi
reduction op. Each lane does its own reduction, and the result is distributed.
2. **Reduction is not lane-local:** This case is handled indirectly. We
rewrite the reduction in terms of `vector.reduction` ops (plus extract and
insert ops) before WI distribution even begins. The whole thing is then
distributed using `gpu::ShuffleOp`s later (not fully
supported yet).
Added:
Modified:
mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir
mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h b/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
index ea01975da582f..6f6d58d4ab605 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
@@ -82,6 +82,14 @@ void populateXeGPUSgToWiDistributeTypeConversions(TypeConverter &typeConverter);
void populateXeGPUSgToWiDistributeTypeConversionAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target);
+/// Appends patterns to rewrite vector::MultiDimReductionOp in terms of
+/// vector::ReductionOps if the multi-reduction involves cross-lane data
+/// movement, and marks such multi-reductions illegal on `target`. This is
+/// used as a pre-processing step before applying the subgroup to workitem
+/// distribution patterns. The rewrite produces a series of simpler extract,
+/// reduction and insert ops; lane-local multi-reductions are left unchanged.
+void populateXeGPUSgToWiLowerVectorMultiReductionAndLegality(
+ RewritePatternSet &patterns, ConversionTarget &target);
/// Collect a set of patterns to unroll xegpu operations to a smaller shapes.
/// Users can control whether an operation to be unrolled or not, as well as
@@ -93,7 +101,7 @@ void populateXeGPUSgToWiDistributeTypeConversionAndLegality(
/// 1. the unrolled type `unrolledType` and number of unrolled instances
/// `numUnrolledInstances` are computed from the `targetShape`.
/// 2. pack each operand. ExtractStridedSlice are created to break-up the
-/// vector operands. And BuiltinUnrealizedCastop are created to break-up
+/// vector operands. And BuiltinUnrealizedCastOp are created to break-up
/// the TensorDesc operands.
/// 3. the original op is cloned `numUnrolledInstances` times, once for each
/// result.
diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
index 4443f86d1e4e2..ebf50c4cd57de 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
@@ -129,6 +129,24 @@ SmallVector<OpFoldResult> addWithRightAligned(OpBuilder &builder, Location loc,
ArrayRef<OpFoldResult> lhs,
ArrayRef<OpFoldResult> rhs);
+/// Given an `input` value representing per-lane data, this function returns the
+/// result after performing a reduction on the input over all lanes (number of
+/// lanes given by `size`). This uses butterfly shuffles to perform the
+/// reduction in a log2(size) number of steps.
+/// NOTE: Implementation taken from TestVectorTransforms.cpp
+Value subgroupReduction(Location loc, OpBuilder &builder, Value input,
+ vector::CombiningKind kind, uint32_t size);
+
+/// Given the `src` and `acc` arguments of a vector::MultiDimReductionOp,
+/// lower to a set of vector::ReductionOp ops over 1D slices extracted from
+/// `src`. The reduction is performed along `reductionDim`. The result is a
+/// vector with the same shape as `acc`.
+/// TODO: Only 2D to 1D reduction is supported for now.
+Value lowerToVectorReductions(TypedValue<VectorType> src,
+ TypedValue<VectorType> acc,
+ vector::CombiningKind kind, int64_t reductionDim,
+ Location loc, PatternRewriter &rewriter);
+
/// Helper Function to find a proper instruction multiple for the user-supplied
/// sg-level data shape (diven by `dim`). `candidates` are uArch allowed shapes.
/// `candidateMultiples` are uArch multiples of such shapes (i.e. block count or
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
index 8e530642d9c7a..3787fbb44e1b8 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
@@ -88,6 +88,39 @@ static LogicalResult verifyLayouts(Operation *root) {
return walkResult.wasInterrupted() ? failure() : success();
}
+/// A subgroup-level vector::MultiDimReductionOp is in the expected form if
+/// it has exactly 1 reduction dimension, a valid subgroup result layout, and
+/// a result type that can be distributed to lanes using that layout.
+static bool isValidSubgroupMultiReductionOp(vector::MultiDimReductionOp op) {
+ auto resLayout = xegpu::getTemporaryLayout(op->getOpResult(0));
+ // If no layout, not valid.
+ if (!resLayout || !resLayout.isForSubgroup())
+ return false;
+ VectorType resTy = dyn_cast<VectorType>(op.getType());
+ if (!resTy)
+ return false;
+ // Compute the distributed result vector type based on the layout.
+ FailureOr<VectorType> resDistTypeOrFailure =
+ getDistVecTypeBasedOnLaneLayout(resLayout, resTy);
+ if (failed(resDistTypeOrFailure))
+ return false;
+ return op.getReductionDims().size() == 1;
+}
+
+/// A vector::MultiDimReductionOp performs a lane-local reduction if each
+/// workitem does its own local reduction. In this case the result layout
+/// distributes the result vector to lanes, i.e. the result vector type is
+/// different from the distributed result vector type.
+static bool isReductionLaneLocal(vector::MultiDimReductionOp op) {
+ // Must be valid MultiDimReductionOp.
+ assert(isValidSubgroupMultiReductionOp(op) && "Expecting a valid subgroup "
+ "MultiDimReductionOp");
+ auto resLayout = xegpu::getTemporaryLayout(op->getOpResult(0));
+ VectorType resTy = dyn_cast<VectorType>(op.getType());
+ auto resDistTypeOrFailure = getDistVecTypeBasedOnLaneLayout(resLayout, resTy);
+ return resTy != resDistTypeOrFailure.value();
+}
+
/// Distributes a subgroup-level CreateNdDesc op to workitem-level CreateNdDesc
/// op. This simply drops the layout attribute from the tensor descriptor type.
struct SgToWiCreateNdDesc : public OpConversionPattern<xegpu::CreateNdDescOp> {
@@ -362,6 +395,133 @@ struct SgToWiPrefetchNd : public OpConversionPattern<xegpu::PrefetchNdOp> {
}
};
+/// This pattern distributes a subgroup-level vector.reduction op to
+/// workitem-level. This requires shuffling the data across the workitems (using
+/// gpu::ShuffleOp) and reducing in stages until all workitems have the final
+/// result.
+struct SgToWiVectorReduction : public OpConversionPattern<vector::ReductionOp> {
+ using OpConversionPattern<vector::ReductionOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(vector::ReductionOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto layout = xegpu::getDistributeLayoutAttr(op.getVector());
+
+ // If no layout, nothing to do.
+ if (!layout || !layout.isForSubgroup())
+ return failure();
+
+ VectorType srcVecType = op.getSourceVectorType();
+ // Only rank 1 vectors supported.
+ if (srcVecType.getRank() != 1)
+ return rewriter.notifyMatchFailure(
+ op, "Only rank 1 reductions can be distributed.");
+ // Lane layout must have the same rank as the vector.
+ if (layout.getRank() != srcVecType.getRank())
+ return rewriter.notifyMatchFailure(
+ op, "Layout rank does not match vector rank.");
+
+ // Get the subgroup size from the layout.
+ int64_t sgSize = layout.getEffectiveLaneLayoutAsInt()[0];
+ const auto *uArch = getUArch(xegpu::getChipStr(op).value_or(""));
+ if (!uArch)
+ return rewriter.notifyMatchFailure(
+ op, "xegpu::ReductionOp require target attribute attached to "
+ "determine subgroup size");
+
+ // Only subgroup-sized vectors supported.
+ if (sgSize != uArch->getSubgroupSize() ||
+ srcVecType.getShape()[0] % sgSize != 0)
+ return rewriter.notifyMatchFailure(op,
+ "Invalid layout or reduction vector "
+ "dimension must match subgroup size.");
+
+ if (!op.getType().isIntOrFloat())
+ return rewriter.notifyMatchFailure(
+ op, "Reduction distribution currently only supports floats and "
+ "integer types.");
+
+ // Get the distributed vector (per work-item portion).
+ Value laneValVec = adaptor.getVector();
+
+ // Distribute and reduce across work-items in the subgroup.
+ Value fullReduce = xegpu::subgroupReduction(
+ op.getLoc(), rewriter, laneValVec, op.getKind(), sgSize);
+
+ // If there's an accumulator, combine it with the reduced value.
+ if (adaptor.getAcc())
+ fullReduce = vector::makeArithReduction(
+ rewriter, op.getLoc(), op.getKind(), fullReduce, adaptor.getAcc());
+
+ rewriter.replaceOp(op, fullReduce);
+ return success();
+ }
+};
+
+/// This pattern distributes a subgroup-level vector.multi_reduction op to
+/// workitem-level only if the reduction is lane-local. This means that
+/// reduction dimension is not distributed to lanes and each lane does its own
+/// local reduction.
+struct SgToWiMultiDimReduction
+ : public OpConversionPattern<vector::MultiDimReductionOp> {
+ using OpConversionPattern<vector::MultiDimReductionOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(vector::MultiDimReductionOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ // Only lane-local reduction is handled here.
+ if (!isReductionLaneLocal(op))
+ return rewriter.notifyMatchFailure(
+ op, "Only lane-local reduction is supported, expected reduction "
+ "dimension to be "
+ "not distributed.");
+ auto resLayout = xegpu::getTemporaryLayout(op->getOpResult(0));
+ VectorType resVecTy = dyn_cast<VectorType>(op.getType());
+ auto resDistVecTyOrFailure =
+ getDistVecTypeBasedOnLaneLayout(resLayout, resVecTy);
+ // Simply create a new MultiDimReductionOp using adaptor operands and the
+ // new result type.
+ auto newOp = vector::MultiDimReductionOp::create(
+ rewriter, op.getLoc(), resDistVecTyOrFailure.value(), op.getKind(),
+ adaptor.getSource(), adaptor.getAcc(), op.getReductionDims());
+ rewriter.replaceOp(op, newOp.getResult());
+ return success();
+ }
+};
+
+/// This pattern rewrites a subgroup-level vector.multi_reduction op to a series
+/// of vector.extract_strided_slice, vector.shape_cast, vector.extract,
+/// vector.reduction and vector.insert ops. This is used when the reduction
+/// dimension is distributed to lanes and a naive (lane-local) distribution is
+/// not possible. These partially lowered subgroup-level ops are later lowered
+/// to workitem-level by their respective patterns.
+struct LowerVectorMultiReductionPattern
+ : public OpConversionPattern<vector::MultiDimReductionOp> {
+ using OpConversionPattern<vector::MultiDimReductionOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(vector::MultiDimReductionOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ // Only non-lane-local reduction is handled here.
+ if (isReductionLaneLocal(op))
+ return rewriter.notifyMatchFailure(
+ op, "Reduction is lane-local, it does not require rewrite.");
+ ArrayRef<int64_t> reductionDims = op.getReductionDims();
+ assert(
+ reductionDims.size() == 1 &&
+ "Expecting single reduction dimension for subgroup multi reduction op");
+
+ // Rewrite MultiDimReductionOp into a sequence of ReductionOps.
+ Value result = xegpu::lowerToVectorReductions(
+ cast<TypedValue<VectorType>>(op.getSource()),
+ cast<TypedValue<VectorType>>(op.getAcc()), op.getKind(),
+ reductionDims[0], op.getLoc(), rewriter);
+
+ rewriter.replaceOp(op, result);
+ return success();
+ }
+};
+
struct XeGPUSgToWiDistributeExperimentalPass
: public xegpu::impl::XeGPUSgToWiDistributeExperimentalBase<
XeGPUSgToWiDistributeExperimentalPass> {
@@ -551,8 +711,44 @@ void xegpu::populateXeGPUSgToWiDistributeTypeConversionAndLegality(
}
return !xegpu::getTemporaryLayout(dyn_cast<OpResult>(op->getResult(0)));
});
+ // vector::ReductionOp is legal only if its source has no distribute layout
+ // attribute.
+ target.addDynamicallyLegalOp<vector::ReductionOp>(
+ [=](vector::ReductionOp op) -> bool {
+ auto layout = xegpu::getDistributeLayoutAttr(op.getVector());
+ return !layout;
+ });
+ // vector::MultiDimReductionOp op legality.
+ target.addDynamicallyLegalOp<vector::MultiDimReductionOp>(
+ [=](vector::MultiDimReductionOp op) -> bool {
+ // Check common conditions for subgroup multi reduction op.
+ if (!isValidSubgroupMultiReductionOp(op))
+ return true;
+ // Lane local reductions are illegal at this point and must be lowered.
+ return !isReductionLaneLocal(op);
+ });
target.markUnknownOpDynamicallyLegal([](Operation *op) { return true; });
patterns.add<SgToWiCreateNdDesc, SgToWiLoadNd, SgToWiStoreNd, SgToWiDpas,
- SgToWiElementWise, SgToWiArithConstant, SgToWiPrefetchNd>(
+ SgToWiElementWise, SgToWiArithConstant, SgToWiPrefetchNd,
+ SgToWiVectorReduction, SgToWiMultiDimReduction>(
typeConverter, patterns.getContext());
}
+
+void xegpu::populateXeGPUSgToWiLowerVectorMultiReductionAndLegality(
+ RewritePatternSet &patterns, ConversionTarget &target) {
+ // vector::MultiDimReductionOp legality.
+ target.addDynamicallyLegalOp<vector::MultiDimReductionOp>(
+ [&](vector::MultiDimReductionOp op) {
+ // Check common conditions for subgroup multi reduction op.
+ if (!isValidSubgroupMultiReductionOp(op))
+ return true;
+ // Lane local reductions are legal. We only rewrite non-lane-local
+ // reductions.
+ return isReductionLaneLocal(op);
+ });
+ // vector::ReductionOp is legal.
+ target.addDynamicallyLegalOp<vector::ReductionOp>(
+ [&](vector::ReductionOp op) { return true; });
+ target.markUnknownOpDynamicallyLegal([](Operation *op) { return true; });
+ patterns.add<LowerVectorMultiReductionPattern>(patterns.getContext());
+}
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 7671e2bbc3322..b8c4a309b8eb2 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -1253,69 +1253,6 @@ struct SinkUniformOps final : public gpu::WarpDistributionPattern {
}
};
-/// Helper to rewrite a 2D VectorMultiReductionOp into a sequence of 1D
-/// VectorReductionOps. We also insert layouts for the newly created ops.
-static Value lowerToVectorReductions(TypedValue<VectorType> src,
- TypedValue<VectorType> acc,
- vector::CombiningKind kind,
- int64_t reductionDim, Location loc,
- PatternRewriter &rewriter) {
- // Expecting a 2D source vector.
- assert(src.getType().getRank() == 2 && "expected a 2D source vector");
- VectorType sourceType = src.getType();
- int64_t sourceH = sourceType.getShape()[0];
- int64_t sourceW = sourceType.getShape()[1];
- int nSlices = (reductionDim == 0) ? sourceW : sourceH;
- // Create a constant vector to hold the result of the reduction.
- TypedAttr zeroAttr = rewriter.getZeroAttr(sourceType.getElementType());
- Value reductionResult = arith::ConstantOp::create(
- rewriter, loc, acc.getType(),
- DenseElementsAttr::get(acc.getType(), zeroAttr));
- // Reduction result should have the same layout as the accumulator.
- xegpu::setTemporaryLayout(cast<OpResult>(reductionResult),
- xegpu::getTemporaryLayout(dyn_cast<OpResult>(acc)));
- // For each slice of the source, extract the slice vector, do a reduction
- // and, insert the reduced value back to the result vector.
- for (int i = 0; i < nSlices; ++i) {
- SmallVector<int64_t, 2> sliceOffsets, sliceSizes;
- if (reductionDim == 1) {
- sliceOffsets = {i, 0};
- sliceSizes = {1, sourceW};
- } else {
- sliceOffsets = {0, i};
- sliceSizes = {sourceH, 1};
- }
- vector::ExtractStridedSliceOp extractOp =
- vector::ExtractStridedSliceOp::create(rewriter, loc, src, sliceOffsets,
- sliceSizes, {1, 1});
-
- int64_t nSliceElements = extractOp.getResult().getType().getNumElements();
-
- vector::ShapeCastOp slice = vector::ShapeCastOp::create(
- rewriter, loc,
- VectorType::get({nSliceElements}, sourceType.getElementType()),
- extractOp.getResult());
-
- // Shape cast is currently handled in xegpu side. So layouts must be
- // retained during lowering. Shape cast output has the same layout as the
- // accumulator. Shape cast source has the same layout as the original
- // reduction source.
- // TODO: other ops generated here may also need layout attributes.
- auto srcLayout = xegpu::getTemporaryLayout(dyn_cast<OpResult>(src));
- auto accLayout = xegpu::getTemporaryLayout(dyn_cast<OpResult>(acc));
-
- xegpu::setTemporaryLayout(slice->getOpOperand(0), srcLayout);
- xegpu::setTemporaryLayout(slice->getOpResult(0), accLayout);
- // Extract and reduction results in scalars, so no result layout is needed.
- Value accExtract = vector::ExtractOp::create(rewriter, loc, acc, i);
- Value reduction = vector::ReductionOp::create(
- rewriter, loc, kind, slice.getResult(), accExtract);
- reductionResult =
- vector::InsertOp::create(rewriter, loc, reduction, reductionResult, i);
- }
- return reductionResult;
-}
-
/// This patterns distribute the `vector.multi_reduction` operation across
/// lanes in a warp. Currently only 2D to 1D reductions are supported. Given
/// layouts for the source and accumulator vectors,
@@ -1453,7 +1390,7 @@ struct VectorMultiReductionDistribution : public gpu::WarpDistributionPattern {
rewriter, warpOp, {reductionOp.getSource(), reductionOp.getAcc()},
{sourceDistType, distributedResultType}, newRetIndices);
rewriter.setInsertionPointAfter(newWarpOp);
- Value result = lowerToVectorReductions(
+ Value result = xegpu::lowerToVectorReductions(
cast<TypedValue<VectorType>>(newWarpOp->getResult(newRetIndices[0])),
cast<TypedValue<VectorType>>(newWarpOp->getResult(newRetIndices[1])),
reductionOp.getKind(), reductionDim, reductionOp.getLoc(), rewriter);
@@ -1465,7 +1402,7 @@ struct VectorMultiReductionDistribution : public gpu::WarpDistributionPattern {
// of multiple ReductionOps. Actual distribution is done by the
// WarpOpReduction pattern.
rewriter.setInsertionPointAfter(reductionOp);
- Value result = lowerToVectorReductions(
+ Value result = xegpu::lowerToVectorReductions(
cast<TypedValue<VectorType>>(reductionOp.getSource()),
cast<TypedValue<VectorType>>(reductionOp.getAcc()),
reductionOp.getKind(), reductionDim, reductionOp.getLoc(), rewriter);
@@ -2151,23 +2088,8 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
auto shuffleFn = [](Location loc, OpBuilder &builder, Value val, Value srcIdx,
int64_t warpSz) { return Value(); };
- auto warpReduction = [](Location loc, OpBuilder &builder, Value input,
- vector::CombiningKind kind, uint32_t size) {
- // First reduce on a single thread to get per lane reduction value.
- Value laneVal = vector::ReductionOp::create(builder, loc, kind, input);
- // Parallel reduction using butterfly shuffles.
- for (uint64_t i = 1; i < size; i <<= 1) {
- Value shuffled = gpu::ShuffleOp::create(builder, loc, laneVal, i,
- /*width=*/size,
- /*mode=*/gpu::ShuffleMode::XOR)
- .getShuffleResult();
- laneVal = makeArithReduction(builder, loc, kind, laneVal, shuffled);
- }
- return laneVal;
- };
-
vector::populateDistributeReduction(
- patterns, warpReduction,
+ patterns, xegpu::subgroupReduction,
/*pattern benefit=*/PatternHierarchy::Regular);
vector::populatePropagateWarpVectorDistributionPatterns(
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index c47fd92fe46d7..5fdab1e759deb 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -651,6 +651,82 @@ int xegpu::getLargestDivisor(T dim, ArrayRef<T> candidates,
return largest;
}
+Value xegpu::subgroupReduction(Location loc, OpBuilder &builder, Value input,
+ vector::CombiningKind kind, uint32_t size) {
+ // First reduce on a single thread to get per lane reduction value.
+ Value laneVal = vector::ReductionOp::create(builder, loc, kind, input);
+ // Parallel reduction using butterfly shuffles.
+ for (uint64_t i = 1; i < size; i <<= 1) {
+ Value shuffled =
+ gpu::ShuffleOp::create(builder, loc, laneVal, i, /** width = **/ size,
+ /** mode = **/ gpu::ShuffleMode::XOR)
+ .getShuffleResult();
+ laneVal = makeArithReduction(builder, loc, kind, laneVal, shuffled);
+ }
+ return laneVal;
+}
+
+Value xegpu::lowerToVectorReductions(TypedValue<VectorType> src,
+ TypedValue<VectorType> acc,
+ vector::CombiningKind kind,
+ int64_t reductionDim, Location loc,
+ PatternRewriter &rewriter) {
+ // Expecting a 2D source vector.
+ assert(src.getType().getRank() == 2 && "expected a 2D source vector");
+ VectorType sourceType = src.getType();
+ int64_t sourceH = sourceType.getShape()[0];
+ int64_t sourceW = sourceType.getShape()[1];
+ int nSlices = (reductionDim == 0) ? sourceW : sourceH;
+ // Create a constant vector to hold the result of the reduction.
+ TypedAttr zeroAttr = rewriter.getZeroAttr(sourceType.getElementType());
+ Value reductionResult = arith::ConstantOp::create(
+ rewriter, loc, acc.getType(),
+ DenseElementsAttr::get(acc.getType(), zeroAttr));
+ auto srcLayout = xegpu::getTemporaryLayout(dyn_cast<OpResult>(src));
+ auto accLayout = xegpu::getTemporaryLayout(dyn_cast<OpResult>(acc));
+ // Reduction result should have the same layout as the accumulator.
+ xegpu::setTemporaryLayout(cast<OpResult>(reductionResult), accLayout);
+ // For each slice of the source, extract the slice vector, do a reduction
+ // and, insert the reduced value back to the result vector.
+ for (int i = 0; i < nSlices; ++i) {
+ SmallVector<int64_t, 2> sliceOffsets, sliceSizes;
+ if (reductionDim == 1) {
+ sliceOffsets = {i, 0};
+ sliceSizes = {1, sourceW};
+ } else {
+ sliceOffsets = {0, i};
+ sliceSizes = {sourceH, 1};
+ }
+
+ vector::ExtractStridedSliceOp extractOp =
+ vector::ExtractStridedSliceOp::create(rewriter, loc, src, sliceOffsets,
+ sliceSizes, {1, 1});
+ // Extract strided slice has the same layout as src.
+ xegpu::setTemporaryLayout(extractOp->getOpResult(0), srcLayout);
+
+ int64_t nSliceElements = extractOp.getResult().getType().getNumElements();
+
+ vector::ShapeCastOp slice = vector::ShapeCastOp::create(
+ rewriter, loc,
+ VectorType::get({nSliceElements}, sourceType.getElementType()),
+ extractOp.getResult());
+
+ // Shape cast output has the same layout as the accumulator. Shape cast
+ // source has the same layout as the original reduction source.
+ xegpu::setTemporaryLayout(slice->getOpOperand(0), srcLayout);
+ xegpu::setTemporaryLayout(slice->getOpResult(0), accLayout);
+ // Extract and reduction results in scalars, so no result layout is needed.
+ Value accExtract = vector::ExtractOp::create(rewriter, loc, acc, i);
+ Value reduction = vector::ReductionOp::create(
+ rewriter, loc, kind, slice.getResult(), accExtract);
+ reductionResult =
+ vector::InsertOp::create(rewriter, loc, reduction, reductionResult, i);
+ // Insert op should have the same layout as the accumulator.
+ xegpu::setTemporaryLayout(cast<OpResult>(reductionResult), accLayout);
+ }
+ return reductionResult;
+}
+
/// Explicit instantiations
template int xegpu::getLargestDivisor<int>(int dim, ArrayRef<int> candidates,
ArrayRef<int> candidateMultiples);
diff --git a/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir b/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
index 0e9843f4626d4..1ec0879d4fb47 100644
--- a/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
+++ b/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
@@ -2,6 +2,10 @@
// RUN: mlir-opt --xevm-attach-target='module=xevm_* chip=pvc' --allow-unregistered-dialect \
// RUN: --test-xegpu-sg-to-wi-distribute-experimental --split-input-file %s | FileCheck %s
+// RUN: mlir-opt --allow-unregistered-dialect \
+// RUN: --test-xegpu-sg-to-wi-distribute-experimental="enable-rewrite-multi-reduction-to-reductions" \
+// RUN: --split-input-file %s | FileCheck --check-prefix=CHECK-REWRITE %s
+
gpu.module @xevm_module {
@@ -149,4 +153,168 @@ gpu.func @prefetch_nd() {
: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
gpu.return
}
+
+// CHECK-LABEL: gpu.func @vector_reduction
+// CHECK: %[[CST:.*]] = arith.constant 1.000000e+00 : f32
+// CHECK: %[[LANE_RED:.*]] = vector.reduction <add>, %[[CAST:.*]] : vector<2xf32> into f32
+// CHECK-DAG: %[[C16_1:.*]] = arith.constant 16 : i32
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : i32
+// CHECK: %[[SHUFFLE1:.*]], %{{.*}} = gpu.shuffle xor %[[LANE_RED]], %[[C1]], %[[C16_1]] : f32
+// CHECK: %[[ADD1:.*]] = arith.addf %[[LANE_RED]], %[[SHUFFLE1]] : f32
+// CHECK-DAG: %[[C16_2:.*]] = arith.constant 16 : i32
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : i32
+// CHECK: %[[SHUFFLE2:.*]], %{{.*}} = gpu.shuffle xor %[[ADD1]], %[[C2]], %[[C16_2]] : f32
+// CHECK: %[[ADD2:.*]] = arith.addf %[[ADD1]], %[[SHUFFLE2]] : f32
+// CHECK-DAG: %[[C16_3:.*]] = arith.constant 16 : i32
+// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : i32
+// CHECK: %[[SHUFFLE3:.*]], %{{.*}} = gpu.shuffle xor %[[ADD2]], %[[C4]], %[[C16_3]] : f32
+// CHECK: %[[ADD3:.*]] = arith.addf %[[ADD2]], %[[SHUFFLE3]] : f32
+// CHECK-DAG: %[[C16_4:.*]] = arith.constant 16 : i32
+// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : i32
+// CHECK: %[[SHUFFLE4:.*]], %{{.*}} = gpu.shuffle xor %[[ADD3]], %[[C8]], %[[C16_4]] : f32
+// CHECK: %[[ADD4:.*]] = arith.addf %[[ADD3]], %[[SHUFFLE4]] : f32
+// CHECK: %[[FINAL:.*]] = arith.addf %[[ADD4]], %[[CST]] : f32
+gpu.func @vector_reduction() {
+ %acc = arith.constant 1.0 : f32
+ %0 = "some_op"() {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} : () -> vector<32xf32>
+ %2 = vector.reduction <add>, %0, %acc : vector<32xf32> into f32
+ gpu.return
+}
+
+
+// CHECK-REWRITE-LABEL: gpu.func @vector_multi_reduction_dim1_distributed_dim1_reduction
+// CHECK-REWRITE-DAG: %[[SRC:.*]] = "some_def"() {layout_result_0 =
+// CHECK-REWRITE-SAME: #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : () -> vector<2x16xf32>
+// CHECK-REWRITE-DAG: %[[ACC:.*]] = arith.constant
+// CHECK-REWRITE-SAME: {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [1]>}
+// CHECK-REWRITE-SAME: dense<0.000000e+00> : vector<2xf32>
+// CHECK-REWRITE-DAG: %[[ZERO:.*]] = arith.constant
+// CHECK-REWRITE-SAME: {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [1]>}
+// CHECK-REWRITE-SAME: dense<0.000000e+00> : vector<2xf32>
+// CHECK-REWRITE: %[[SLICE0:.*]] = vector.extract_strided_slice %[[SRC]]
+// CHECK-REWRITE-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+// CHECK-REWRITE-SAME: offsets = [0, 0], sizes = [1, 16], strides = [1, 1]} : vector<2x16xf32> to vector<1x16xf32>
+// CHECK-REWRITE-NEXT: %[[CAST0:.*]] = vector.shape_cast %[[SLICE0]]
+// CHECK-REWRITE-SAME: {{{.*}}, layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [1]>}
+// CHECK-REWRITE-SAME: : vector<1x16xf32> to vector<16xf32>
+// CHECK-REWRITE-NEXT: %[[ACC0:.*]] = vector.extract %[[ACC]][0] : f32 from vector<2xf32>
+// CHECK-REWRITE-NEXT: %[[RED0:.*]] = vector.reduction <add>, %[[CAST0]], %[[ACC0]] : vector<16xf32> into f32
+// CHECK-REWRITE-NEXT: %[[INS0:.*]] = vector.insert %[[RED0]], %[[ZERO]] [0]
+// CHECK-REWRITE-SAME: {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [1]>}
+// CHECK-REWRITE-SAME: : f32 into vector<2xf32>
+// CHECK-REWRITE-NEXT: %[[SLICE1:.*]] = vector.extract_strided_slice %[[SRC]]
+// CHECK-REWRITE-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+// CHECK-REWRITE-SAME: offsets = [1, 0], sizes = [1, 16], strides = [1, 1]} : vector<2x16xf32> to vector<1x16xf32>
+// CHECK-REWRITE-NEXT: %[[CAST1:.*]] = vector.shape_cast %[[SLICE1]]
+// CHECK-REWRITE-SAME: {{{.*}}, layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [1]>}
+// CHECK-REWRITE-SAME: : vector<1x16xf32> to vector<16xf32>
+// CHECK-REWRITE-NEXT: %[[ACC1:.*]] = vector.extract %[[ACC]][1] : f32 from vector<2xf32>
+// CHECK-REWRITE-NEXT: %[[RED1:.*]] = vector.reduction <add>, %[[CAST1]], %[[ACC1]] : vector<16xf32> into f32
+// CHECK-REWRITE-NEXT: %[[INS1:.*]] = vector.insert %[[RED1]], %[[INS0]] [1]
+// CHECK-REWRITE-SAME: {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [1]>}
+// CHECK-REWRITE-SAME: : f32 into vector<2xf32>
+gpu.func @vector_multi_reduction_dim1_distributed_dim1_reduction(%laneid: index) {
+ %c0 = arith.constant 0 : index
+ %src = "some_def"()
+ {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+ : () -> (vector<2x16xf32>)
+ %acc = arith.constant
+ {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [1]>}
+ dense<0.0> : vector<2xf32>
+ %1 = vector.multi_reduction <add>, %src, %acc
+ {
+ layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [1]>
+ }
+ [1] : vector<2x16xf32> to vector<2xf32>
+ gpu.return
+}
+
+// CHECK-REWRITE-LABEL: gpu.func @vector_multi_reduction_dim0_distributed_dim0_reduction
+// CHECK-REWRITE-DAG: %[[SRC:.*]] = "some_def"() {layout_result_0 =
+// CHECK-REWRITE-SAME: #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>} : () -> vector<16x2xf32>
+// CHECK-REWRITE-DAG: %[[ACC:.*]] = arith.constant
+// CHECK-REWRITE-SAME: {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>, dims = [0]>}
+// CHECK-REWRITE-SAME: dense<0.000000e+00> : vector<2xf32>
+// CHECK-REWRITE-DAG: %[[ZERO:.*]] = arith.constant
+// CHECK-REWRITE-SAME: {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>, dims = [0]>}
+// CHECK-REWRITE-SAME: dense<0.000000e+00> : vector<2xf32>
+// CHECK-REWRITE: %[[SLICE0:.*]] = vector.extract_strided_slice %[[SRC]]
+// CHECK-REWRITE-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>,
+// CHECK-REWRITE-SAME: offsets = [0, 0], sizes = [16, 1], strides = [1, 1]} : vector<16x2xf32> to vector<16x1xf32>
+// CHECK-REWRITE-NEXT: %[[CAST0:.*]] = vector.shape_cast %[[SLICE0]]
+// CHECK-REWRITE-SAME: {{.*}}, layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>, dims = [0]>}
+// CHECK-REWRITE-SAME: : vector<16x1xf32> to vector<16xf32>
+// CHECK-REWRITE-NEXT: %[[ACC0:.*]] = vector.extract %[[ACC]][0] : f32 from vector<2xf32>
+// CHECK-REWRITE-NEXT: %[[RED0:.*]] = vector.reduction <add>, %[[CAST0]], %[[ACC0]] : vector<16xf32> into f32
+// CHECK-REWRITE-NEXT: %[[INS0:.*]] = vector.insert %[[RED0]], %[[ZERO]] [0]
+// CHECK-REWRITE-SAME: {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>, dims = [0]>}
+// CHECK-REWRITE-SAME: : f32 into vector<2xf32>
+// CHECK-REWRITE-NEXT: %[[SLICE1:.*]] = vector.extract_strided_slice %[[SRC]]
+// CHECK-REWRITE-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>,
+// CHECK-REWRITE-SAME: offsets = [0, 1], sizes = [16, 1], strides = [1, 1]} : vector<16x2xf32> to vector<16x1xf32>
+// CHECK-REWRITE-NEXT: %[[CAST1:.*]] = vector.shape_cast %[[SLICE1]]
+// CHECK-REWRITE-SAME: {{{.*}}, layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>, dims = [0]>
+// CHECK-REWRITE-SAME: : vector<16x1xf32> to vector<16xf32>
+// CHECK-REWRITE-NEXT: %[[ACC1:.*]] = vector.extract %[[ACC]][1] : f32 from vector<2xf32>
+// CHECK-REWRITE-NEXT: %[[RED1:.*]] = vector.reduction <add>, %[[CAST1]], %[[ACC1]] : vector<16xf32> into f32
+// CHECK-REWRITE-NEXT: %[[INS1:.*]] = vector.insert %[[RED1]], %[[INS0]] [1]
+// CHECK-REWRITE-SAME: {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>, dims = [0]>}
+// CHECK-REWRITE-SAME: : f32 into vector<2xf32>
+// Multi-reduction over dim 0 of a vector<16x2xf32> whose dim 0 is spread
+// across 16 lanes (lane_layout = [16, 1]), so the reduction is NOT
+// lane-local. The expected rewrite (directives above) handles each of the
+// 2 columns separately: extract_strided_slice a 16x1 column,
+// shape_cast it to vector<16xf32>, extract the matching accumulator
+// element, vector.reduction <add>, then vector.insert the scalar into a
+// zero-initialized vector<2xf32>.
+gpu.func @vector_multi_reduction_dim0_distributed_dim0_reduction(%laneid: index) {
+ // NOTE(review): %c0 and %laneid are not used by this function body.
+ %c0 = arith.constant 0 : index
+ %src = "some_def"()
+ {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}
+ : () -> (vector<16x2xf32>)
+ // Accumulator layout is the source layout sliced along the reduced dim 0.
+ %acc = arith.constant
+ {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>, dims = [0]>}
+ dense<0.0> : vector<2xf32>
+ %1 = vector.multi_reduction <add>, %src, %acc
+ {
+ layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>, dims = [0]>
+ }
+ [0] : vector<16x2xf32> to vector<2xf32>
+ gpu.return
+}
+
+// CHECK-LABEL: gpu.func @vector_multi_reduction_dim1_distributed_dim0_reduction
+// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<4x1xf32>
+// CHECK: %[[CST_0:.*]] = arith.constant dense<0.000000e+00> : vector<1xf32>
+// CHECK: %[[RED:.*]] = vector.multi_reduction <add>, %[[CST]], %[[CST_0]] [0] : vector<4x1xf32> to vector<1xf32>
+// CHECK: gpu.return
+// Multi-reduction over dim 0 of a vector<4x16xf32> whose dim 1 (not the
+// reduced dim) is spread across 16 lanes (lane_layout = [1, 16]). Each lane
+// owns a full 4-element column, so the reduction is lane-local. The
+// expected output is a per-lane vector.multi_reduction of vector<4x1xf32>
+// into vector<1xf32>.
+gpu.func @vector_multi_reduction_dim1_distributed_dim0_reduction(%laneid: index) {
+ // NOTE(review): %c0 and %laneid are not used by this function body.
+ %c0 = arith.constant 0 : index
+ // Zero source; expected to appear distributed as vector<4x1xf32> per lane.
+ %src = arith.constant
+ {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+ dense<0.0> : vector<4x16xf32>
+ // Accumulator layout is the source layout sliced along the reduced dim 0.
+ %acc = arith.constant
+ {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>}
+ dense<0.0> : vector<16xf32>
+ %1 = vector.multi_reduction <add>, %src, %acc
+ {
+ layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>
+ }
+ [0] : vector<4x16xf32> to vector<16xf32>
+ gpu.return
+}
+
+// CHECK-LABEL: gpu.func @vector_multi_reduction_dim0_distributed_dim1_reduction
+// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<1x12xf32>
+// CHECK: %[[CST_0:.*]] = arith.constant dense<0.000000e+00> : vector<1xf32>
+// CHECK: %[[RED:.*]] = vector.multi_reduction <add>, %[[CST]], %[[CST_0]] [1] : vector<1x12xf32> to vector<1xf32>
+// CHECK: gpu.return
+// Multi-reduction over dim 1 of a vector<16x12xf32> whose dim 0 (not the
+// reduced dim) is spread across 16 lanes (lane_layout = [16, 1]). Each lane
+// owns a full 12-element row, so the reduction is lane-local. The expected
+// output is a per-lane vector.multi_reduction of vector<1x12xf32> into
+// vector<1xf32>.
+gpu.func @vector_multi_reduction_dim0_distributed_dim1_reduction(%laneid: index) {
+ // NOTE(review): %c0 and %laneid are not used by this function body.
+ %c0 = arith.constant 0 : index
+ // Zero source; expected to appear distributed as vector<1x12xf32> per lane.
+ %src = arith.constant
+ {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}
+ dense<0.0> : vector<16x12xf32>
+ // Accumulator layout is the source layout sliced along the reduced dim 1.
+ %acc = arith.constant
+ {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>, dims = [1]>}
+ dense<0.0> : vector<16xf32>
+ %1 = vector.multi_reduction <add>, %src, %acc
+ {
+ layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>, dims = [1]>
+ }
+ [1] : vector<16x12xf32> to vector<16xf32>
+ gpu.return
+}
}
diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir
index 55f8b870b6238..645e889d40657 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir
@@ -288,13 +288,9 @@ gpu.func @vector_multi_reduction_dim1_distributed_dim0_reduction(%laneid: index)
// CHECK: %[[W:.*]] = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> ({{.*}}) {
// CHECK-NEXT: %[[SRC:.*]] = "some_def"() {{.*}} : () -> vector<2x16xf32>
// CHECK-NEXT: %[[T2:.*]] = vector.extract %[[SRC]][0] : vector<16xf32> from vector<2x16xf32>
-// CHECK-NEXT: %[[T3:.*]] = vector.reduction <add>, %[[T2]], %cst : vector<16xf32> into f32
-// CHECK-NEXT: %[[T4:.*]] = vector.insert %[[T3]], %cst_0 [0] : f32 into vector<2xf32>
+// CHECK-NEXT: %[[T3:.*]] = vector.reduction <add>, %[[T2]], %{{.*}} : vector<16xf32> into f32
// CHECK-NEXT: %[[T5:.*]] = vector.extract %[[SRC]][1] : vector<16xf32> from vector<2x16xf32>
-// CHECK-NEXT: %[[T6:.*]] = vector.reduction <add>, %[[T5]], %cst : vector<16xf32> into f32
-// CHECK-NEXT: %[[T7:.*]] = vector.insert %[[T6]], %[[T4]] [1] : f32 into vector<2xf32>
-// CHECK-NEXT: gpu.yield %[[T7]]
-// CHECK-NEXT: }
+// CHECK-NEXT: %[[T6:.*]] = vector.reduction <add>, %[[T5]], %{{.*}} : vector<16xf32> into f32
gpu.func @vector_multi_reduction_dim1_distributed_dim1_reduction(%laneid: index) {
%c0 = arith.constant 0 : index
%r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<2xf32>) {
@@ -356,20 +352,27 @@ gpu.func @vector_multi_reduction_dim0_distributed_dim1_reduction(%laneid: index)
// CHECK-LABEL: gpu.func @vector_multi_reduction_dim0_distributed_dim0_reduction
-// CHECK: %[[W:.*]] = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<2xf32>{{.*}}) {
-// CHECK: %[[SRC:.*]] = "some_def"() {{.*}} : () -> vector<16x2xf32>
+// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[W:.*]] = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<2xf32>) {
+// CHECK: %[[SRC:.*]] = "some_def"()
+// CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}
+// CHECK-SAME: : () -> vector<16x2xf32>
// CHECK: %[[T1:.*]] = vector.extract_strided_slice %[[SRC]]
-// CHECK-SAME: {offsets = [0, 0], sizes = [16, 1], strides = [1, 1]} : vector<16x2xf32> to vector<16x1xf32>
-// CHECK: %[[T2:.*]] = vector.shape_cast %[[T1]] {{.*}} : vector<16x1xf32> to vector<16xf32>
-// CHECK: %[[T3:.*]] = vector.reduction <add>, %[[T2]], %{{.*}} : vector<16xf32> into f32
-// CHECK: %[[T4:.*]] = vector.insert %[[T3]], %cst_0 [0] : f32 into vector<2xf32>
-// CHECK: %[[T5:.*]] = vector.extract_strided_slice %[[SRC]]
-// CHECK-SAME: {offsets = [0, 1], sizes = [16, 1], strides = [1, 1]} : vector<16x2xf32> to vector<16x1xf32>
-// CHECK: %[[T6:.*]] = vector.shape_cast %[[T5]] {{.*}} : vector<16x1xf32> to vector<16xf32>
-// CHECK: %[[T7:.*]] = vector.reduction <add>, %[[T6]], %{{.*}} : vector<16xf32> into f32
-// CHECK: %[[T8:.*]] = vector.insert %[[T7]], %[[T4]] [1] : f32 into vector<2xf32>
-// CHECK: gpu.yield %[[T8]]
-// CHECK: }
+// CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>,
+// CHECK-SAME: offsets = [0, 0], sizes = [16, 1], strides = [1, 1]} : vector<16x2xf32> to vector<16x1xf32>
+// CHECK: %[[T2:.*]] = vector.shape_cast %[[T1]]
+// CHECK-SAME: {layout_operand_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>,
+// CHECK-SAME: layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>, dims = [0]>}
+// CHECK-SAME: : vector<16x1xf32> to vector<16xf32>
+// CHECK: %[[T3:.*]] = vector.reduction <add>, %[[T2]], %[[CST]] : vector<16xf32> into f32
+// CHECK: %[[T4:.*]] = vector.extract_strided_slice %[[SRC]]
+// CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>,
+// CHECK-SAME: offsets = [0, 1], sizes = [16, 1], strides = [1, 1]} : vector<16x2xf32> to vector<16x1xf32>
+// CHECK: %[[T5:.*]] = vector.shape_cast %[[T4]]
+// CHECK-SAME: {layout_operand_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>,
+// CHECK-SAME: layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>, dims = [0]>}
+// CHECK-SAME: : vector<16x1xf32> to vector<16xf32>
+// CHECK: %[[T6:.*]] = vector.reduction <add>, %[[T5]], %[[CST]] : vector<16xf32> into f32
gpu.func @vector_multi_reduction_dim0_distributed_dim0_reduction(%laneid: index) {
%c0 = arith.constant 0 : index
%r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<2xf32>) {
diff --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
index 20bcb24a301e6..33af2c5b33d89 100644
--- a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
+++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
@@ -273,6 +273,12 @@ struct TestXeGPUSgToWiDistributeExperimental
"Work-item Distribution";
}
+ Option<bool> enableRewriteMultiReductionToReductions{
+ *this, "enable-rewrite-multi-reduction-to-reductions",
+ llvm::cl::desc("Partially lower multi-reduction ops to reduction ops if "
+ "the reduction dimension is distributed."),
+ llvm::cl::init(false)};
+
void getDependentDialects(::mlir::DialectRegistry ®istry) const override {
registry.insert<arith::ArithDialect>();
registry.insert<memref::MemRefDialect>();
@@ -284,7 +290,8 @@ struct TestXeGPUSgToWiDistributeExperimental
TestXeGPUSgToWiDistributeExperimental() = default;
TestXeGPUSgToWiDistributeExperimental(
- const TestXeGPUSgToWiDistributeExperimental &pass) = default;
+ const TestXeGPUSgToWiDistributeExperimental &pass)
+ : PassWrapper(pass) {}
void runOnOperation() override {
MLIRContext *ctx = &getContext();
@@ -298,6 +305,19 @@ struct TestXeGPUSgToWiDistributeExperimental
};
typeConverter.addSourceMaterialization(materializeCast);
typeConverter.addTargetMaterialization(materializeCast);
+
+ // If `enableRewriteMultiReductionToReductions` is set, only focus on
+ // testing the partial lowering of vector::MultiReductionOp.
+ if (enableRewriteMultiReductionToReductions) {
+ xegpu::populateXeGPUSgToWiDistributeTypeConversions(typeConverter);
+ ConversionTarget target(*ctx);
+ RewritePatternSet patterns(ctx);
+ xegpu::populateXeGPUSgToWiLowerVectorMultiReductionAndLegality(patterns,
+ target);
+ (void)applyPartialConversion(getOperation(), target, std::move(patterns));
+ return;
+ }
+
ConversionTarget target(*ctx);
RewritePatternSet patterns(ctx);
xegpu::populateXeGPUSgToWiDistributeTypeConversionAndLegality(
More information about the Mlir-commits
mailing list