[Mlir-commits] [mlir] 5236af8 - [MLIR][XeGPU] Extend propagation and sg_to_lane distribution pass support broadcast with low rank and scalar source input (#170409)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Dec 9 08:48:31 PST 2025
Author: Jianhui Li
Date: 2025-12-09T08:48:27-08:00
New Revision: 5236af88e5ed0a3449b2292ef02be28b8722b172
URL: https://github.com/llvm/llvm-project/commit/5236af88e5ed0a3449b2292ef02be28b8722b172
DIFF: https://github.com/llvm/llvm-project/commit/5236af88e5ed0a3449b2292ef02be28b8722b172.diff
LOG: [MLIR][XeGPU] Extend propagation and sg_to_lane distribution pass support broadcast with low rank and scalar source input (#170409)
This PR extends XeGPU layout propagation and distribution for
vector.broadcast operation.
It relaxes the restriction of layout propagation to allow low-rank and
scalar source input, and adds a pattern in sg-to-wi distribution to
support the lowering.
Added:
Modified:
mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
mlir/test/Dialect/XeGPU/propagate-layout.mlir
mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir
mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index 93c5187b00756..eae0bd4e68a84 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -223,6 +223,14 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> {
InterfaceMethod<"Derive a new layout by dropping InstData",
"xegpu::DistributeLayoutAttr",
"dropInstData">,
+ InterfaceMethod<"Derive a new layout with sg_data, inst_data and lane_data set to 1 for the specified unit dims",
+ "xegpu::DistributeLayoutAttr",
+ "setUnitDimData",
+ /*args=*/(ins "const llvm::SetVector<int64_t>": $unitDims)>,
+ InterfaceMethod<"Derive a new layout with sg_layout and lane_layout set to 1 for the specified unit dims",
+ "xegpu::DistributeLayoutAttr",
+ "setUnitDimLayout",
+ /*args=*/(ins "const llvm::SetVector<int64_t>": $unitDims)>,
InterfaceMethod<[{Delinearizes a linear ID into its multidimensional
indices based on the effective layout level.}],
"FailureOr<SmallVector<Value>>",
@@ -283,9 +291,14 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> {
}
return true;
}]>,
- InterfaceMethod</*desc=*/[{Check if this layout is a slice of some other layout.}],
+ InterfaceMethod</*desc=*/[{Check if this layout is a slice of another layout.}],
/*retTy=*/"bool",
/*methodName=*/"isSliceOf",
+ /*args=*/(ins "const xegpu::DistributeLayoutAttr&": $other)>,
+
+ InterfaceMethod</*desc=*/[{Check if this layout is identical to another layout.}],
+ /*retTy=*/"bool",
+ /*methodName=*/"isEqualTo",
/*args=*/(ins "const xegpu::DistributeLayoutAttr&": $other)>
];
}
@@ -487,6 +500,12 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttr]> {
return {};
}
+ // Set sg_data, inst_data and lane_data to 1 for the specified unit dims.
+ DistributeLayoutAttr setUnitDimData(SetVector<int64_t> unitDims);
+
+ // Set sg_layout and lane_layout to 1 for the specified unit dims.
+ DistributeLayoutAttr setUnitDimLayout(SetVector<int64_t> unitDims);
+
/// Delinearizes a linear ID into its multidimensional indices
/// based on the effective level of the layout.
FailureOr<SmallVector<Value>>
@@ -501,6 +520,9 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttr]> {
/// Check if this is slice of some other layout.
bool isSliceOf(const xegpu::DistributeLayoutAttr &other) { return false; }
+
+ /// Check if this is identical to some other layout.
+ bool isEqualTo(const xegpu::DistributeLayoutAttr &other);
}];
@@ -649,6 +671,12 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> {
return SliceAttr::get(getContext(), parent, attr.getDims());
}
+ // Set sg_data, inst_data and lane_data to 1 for the specified unit dims.
+ DistributeLayoutAttr setUnitDimData(SetVector<int64_t> unitDims);
+
+ // Set sg_layout and lane_layout to 1 for the specified unit dims.
+ DistributeLayoutAttr setUnitDimLayout(SetVector<int64_t> unitDims);
+
/// flatten a nested SliceAttr, e.g., for 2-level nested SliceAttr
/// #xegpu.slice<#xegpu.slice<#xegpu.layout<sg_layout = [4, 8, 12]>, dims = [0]>, dims = [0]>
/// it will coalesce the two slice operations and return a simplified SliceAttr
@@ -670,7 +698,9 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> {
/// Check if this is slice of some other layout.
bool isSliceOf(const xegpu::DistributeLayoutAttr &other);
-
+
+ /// Check if this is identical to some other layout.
+ bool isEqualTo(const xegpu::DistributeLayoutAttr &other);
}];
let assemblyFormat = "`<` qualified($parent) `,` `dims` `=` $dims `>`";
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 4fe1087d18879..b54d620c3c0c3 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -405,7 +405,7 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [
OptionalAttr<DenseI64ArrayAttr>: $transpose,
OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
- OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint,
+ OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint,
OptionalAttr<DistributeLayoutAttr>:$layout);
let results = (outs XeGPU_ValueType: $value);
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 7ab2e612ed890..1a19ab5fd970b 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -390,6 +390,86 @@ LayoutAttr::computeDistributedCoords(OpBuilder &builder, Location loc,
return genCoordinates(builder, loc, ids, layout, subShape, shape);
}
+/// Returns true only when `other` is a LayoutAttr identical to this one. A
+/// SliceAttr never compares equal to a plain layout.
+bool LayoutAttr::isEqualTo(const xegpu::DistributeLayoutAttr &other) {
+  // Slices are rejected up front; they cannot match a non-sliced layout.
+  if (isa<xegpu::SliceAttr>(other))
+    return false;
+
+  // A failed cast yields a null attribute, which compares unequal to *this.
+  auto otherLayout = dyn_cast<xegpu::LayoutAttr>(other);
+  return *this == otherLayout;
+}
+
+// Returns a copy of this layout in which sg_data, inst_data and lane_data are
+// set to 1 along every dimension listed in `unitDims`. Fields that are absent
+// (or empty) stay absent; sg_layout, lane_layout and order are carried over
+// unchanged.
+DistributeLayoutAttr LayoutAttr::setUnitDimData(SetVector<int64_t> unitDims) {
+  auto *ctx = getContext();
+
+  // Rebuilds one optional data field with the unit dims overwritten by 1.
+  // Unit dims at or beyond the field's length are ignored.
+  auto overwriteUnitDims = [&](DenseI32ArrayAttr field) -> DenseI32ArrayAttr {
+    SmallVector<int32_t> values;
+    if (field)
+      values = llvm::to_vector(field.asArrayRef());
+    for (int64_t unitDim : unitDims)
+      if (unitDim < static_cast<int64_t>(values.size()))
+        values[unitDim] = 1;
+    if (values.empty())
+      return DenseI32ArrayAttr();
+    return DenseI32ArrayAttr::get(ctx, values);
+  };
+
+  DenseI32ArrayAttr newSgData = overwriteUnitDims(getSgData());
+  DenseI32ArrayAttr newInstData = overwriteUnitDims(getInstData());
+  DenseI32ArrayAttr newLaneData = overwriteUnitDims(getLaneData());
+
+  return LayoutAttr::get(ctx, getSgLayout(), newSgData, newInstData,
+                         getLaneLayout(), newLaneData, getOrder());
+}
+
+// Returns a copy of this layout in which sg_layout and lane_layout are set to
+// 1 along every dimension listed in `unitDims`. Fields that are absent (or
+// empty) stay absent; sg_data, inst_data, lane_data and order are carried over
+// unchanged.
+DistributeLayoutAttr LayoutAttr::setUnitDimLayout(SetVector<int64_t> unitDims) {
+  auto *ctx = getContext();
+
+  // Rebuilds one optional layout field with the unit dims overwritten by 1.
+  // Unit dims at or beyond the field's length are ignored.
+  auto overwriteUnitDims = [&](DenseI32ArrayAttr field) -> DenseI32ArrayAttr {
+    SmallVector<int32_t> values;
+    if (field)
+      values = llvm::to_vector(field.asArrayRef());
+    for (int64_t unitDim : unitDims)
+      if (unitDim < static_cast<int64_t>(values.size()))
+        values[unitDim] = 1;
+    if (values.empty())
+      return DenseI32ArrayAttr();
+    return DenseI32ArrayAttr::get(ctx, values);
+  };
+
+  DenseI32ArrayAttr newSgLayout = overwriteUnitDims(getSgLayout());
+  DenseI32ArrayAttr newLaneLayout = overwriteUnitDims(getLaneLayout());
+
+  return LayoutAttr::get(ctx, newSgLayout, getSgData(), getInstData(),
+                         newLaneLayout, getLaneData(), getOrder());
+}
+
//===----------------------------------------------------------------------===//
// XeGPU_SliceAttr
//===----------------------------------------------------------------------===//
@@ -510,6 +590,69 @@ bool SliceAttr::isSliceOf(const xegpu::DistributeLayoutAttr &other) {
[&](int64_t dim) { return thisDims.contains(dim); });
}
+/// Returns true when `other` is a SliceAttr that, once nested slices are
+/// flattened, has the same parent layout and the same sliced dims as this
+/// attribute. Any non-slice layout (e.g. a plain LayoutAttr) is never equal.
+bool SliceAttr::isEqualTo(const xegpu::DistributeLayoutAttr &other) {
+  // Test the cast result itself, so that every non-slice layout kind returns
+  // false instead of reaching flatten() on a null attribute.
+  auto otherSlice = dyn_cast<xegpu::SliceAttr>(other);
+  if (!otherSlice)
+    return false;
+
+  // Compare in flattened form so differently nested but equivalent slices
+  // are treated as identical.
+  SliceAttr flattenedThis = flatten();
+  SliceAttr flattenedOther = otherSlice.flatten();
+
+  return flattenedThis.getParent() == flattenedOther.getParent() &&
+         flattenedThis.getDims() == flattenedOther.getDims();
+}
+
+// Helper function to adjust unit dimensions from sliced space to parent space.
+//
+// Dim `d` of the sliced space corresponds to the d-th parent dim that is NOT
+// listed in `sliceDims`. For example, with sliceDims = [0] the sliced dim 1
+// maps to parent dim 2.
+//
+// The rank of the sliced space is not available here, so the non-sliced parent
+// dims are enumerated until the largest requested unit dim is covered. Sizing
+// the enumeration from the number of unit dims would drop valid dims whenever
+// only some of the sliced-space dims are unit dims.
+static SetVector<int64_t>
+adjustUnitDimsWithSliceDims(const SetVector<int64_t> &unitDims,
+                            ArrayRef<int64_t> sliceDims) {
+  llvm::SmallDenseSet<int64_t> slicedDimsSet(sliceDims.begin(),
+                                             sliceDims.end());
+
+  // Largest sliced-space dim that needs a parent-space counterpart.
+  int64_t maxUnitDim = -1;
+  for (int64_t dim : unitDims)
+    if (dim > maxUnitDim)
+      maxUnitDim = dim;
+
+  // Collect parent dims that survive slicing, in increasing order, until
+  // index `maxUnitDim` exists. This terminates because `sliceDims` is finite.
+  SmallVector<int64_t> nonSlicedDims;
+  for (int64_t parentDim = 0;
+       static_cast<int64_t>(nonSlicedDims.size()) <= maxUnitDim; ++parentDim) {
+    if (!slicedDimsSet.contains(parentDim))
+      nonSlicedDims.push_back(parentDim);
+  }
+
+  // Map unit dims from sliced space to parent space. Negative dims have no
+  // counterpart and are skipped rather than used as an index.
+  SetVector<int64_t> adjustUnitDims;
+  for (int64_t dim : unitDims) {
+    if (dim >= 0)
+      adjustUnitDims.insert(nonSlicedDims[dim]);
+  }
+
+  return adjustUnitDims;
+}
+
+// Sets sg_data, inst_data and lane_data to 1 for the given unit dims of this
+// sliced layout. `unitDims` index the sliced space; they are translated to
+// parent dims and applied to the parent of the flattened slice, and the result
+// is re-wrapped in a slice over the same dims.
+DistributeLayoutAttr SliceAttr::setUnitDimData(SetVector<int64_t> unitDims) {
+  SliceAttr flat = flatten();
+  auto slicedDims = flat.getDims();
+
+  SetVector<int64_t> parentUnitDims =
+      adjustUnitDimsWithSliceDims(unitDims, slicedDims.asArrayRef());
+
+  auto parentLayout = dyn_cast<LayoutAttr>(flat.getParent());
+  return SliceAttr::get(getContext(),
+                        parentLayout.setUnitDimData(parentUnitDims),
+                        slicedDims);
+}
+
+// Sets sg_layout and lane_layout to 1 for the given unit dims of this sliced
+// layout. `unitDims` index the sliced space; they are translated to parent
+// dims and applied to the parent of the flattened slice, and the result is
+// re-wrapped in a slice over the same dims.
+DistributeLayoutAttr SliceAttr::setUnitDimLayout(SetVector<int64_t> unitDims) {
+  SliceAttr flat = flatten();
+  auto slicedDims = flat.getDims();
+
+  SetVector<int64_t> parentUnitDims =
+      adjustUnitDimsWithSliceDims(unitDims, slicedDims.asArrayRef());
+
+  auto parentLayout = dyn_cast<LayoutAttr>(flat.getParent());
+  return SliceAttr::get(getContext(),
+                        parentLayout.setUnitDimLayout(parentUnitDims),
+                        slicedDims);
+}
+
//===----------------------------------------------------------------------===//
// XeGPU_RangeAttr
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 59a1ad9dbe189..dc9eb96c169b4 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -580,23 +580,39 @@ void LayoutInfoPropagation::visitVectorBroadCastOp(
// Only consider vector to vector broadcasts for now.
VectorType resultTy = broadcast.getResultVectorType();
VectorType sourceTy = dyn_cast<VectorType>(broadcast.getSourceType());
- if (!sourceTy) {
- broadcast.emitWarning("Expecting source type to be a vector type.");
+ // skip layout propagation for non-vector source operand.
+ if (!sourceTy)
return;
- }
- // Only consider nD -> nD broadcast.
+ // Handle broadcasts from a low-rank to a high-rank vector (e.g., 1D to 2D).
if (sourceTy.getRank() != resultTy.getRank()) {
- broadcast.emitWarning("Expecting source and result to have same rank.");
+ auto sourceDims = sourceTy.getShape();
+ auto resultDims = resultTy.getShape();
+ SmallVector<int64_t> bcastDims;
+ auto dimDiff = resultTy.getRank() - sourceTy.getRank();
+ // adding the missing leading dims
+ for (int i = 0; i < dimDiff; i++)
+ bcastDims.push_back(i);
+
+ // For the remaining dims of resultTy: if the matching sourceTy dim is 1 and
+ // the result dim is not, it is a broadcasted dim.
+ for (size_t i = 0; i < sourceDims.size(); i++)
+ if ((sourceDims[i] == 1) && (resultDims[i + dimDiff] != 1))
+ bcastDims.push_back(i + dimDiff);
+
+ // create a slice layout for the source
+ xegpu::SliceAttr sliceLayout = xegpu::SliceAttr::get(
+ broadcast->getContext(),
+ cast<xegpu::DistributeLayoutAttr>(resultLayout.get()),
+ DenseI64ArrayAttr::get(broadcast->getContext(), bcastDims));
+
+ propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(sliceLayout)));
return;
}
+
SetVector<int64_t> broadcastUnitDims = broadcast.computeBroadcastedUnitDims();
- if (broadcastUnitDims.size() != 1) {
- broadcast.emitWarning("Expecting source type to be nD vector only with "
- "one broadcasted dimension.");
- return;
- }
- // Propagate the result layout to the source operand.
+ resultLayout = cast<xegpu::DistributeLayoutAttr>(resultLayout.get())
+ .setUnitDimData(broadcastUnitDims);
propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
}
@@ -917,7 +933,7 @@ void LayoutInfoPropagation::visitLoadGatherOp(
} else {
// The layout is strictly determined by the payload type.
- auto payloadTy = dyn_cast<VectorType>(load.getValueType());
+ VectorType payloadTy = load.getValueType();
if (!payloadTy) {
load.emitWarning("Not propagating, non-vector payload supplied.");
return;
@@ -987,7 +1003,7 @@ void LayoutInfoPropagation::visitStoreScatterOp(
// Currently, for 2D StoreScatterOp we expect that the height dimension of
// the tensor descriptor is equal to the subgroup size. This is ensured by
// the op verifier.
- auto payloadTy = dyn_cast<VectorType>(storeScatter.getValueType());
+ VectorType payloadTy = storeScatter.getValueType();
if (!payloadTy) {
storeScatter.emitWarning("Not propagating, non-vector payload supplied.");
return;
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 0d1c5eeeff711..ca81c3cd7be42 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -99,7 +99,6 @@ getDistVecTypeBasedOnLaneLayout(xegpu::DistributeLayoutAttr layout,
for (auto [i, dim] : llvm::enumerate(originalType.getShape())) {
if (i < distributionStart)
continue;
-
// Check if the dimension can be distributed evenly.
if (dim % effectiveLaneLayout[i - distributionStart] != 0)
return failure();
@@ -1424,6 +1423,166 @@ struct VectorMultiReductionDistribution : public gpu::WarpDistributionPattern {
}
};
+/// This pattern distributes the `vector.broadcast` operation across lanes in a
+/// warp. The pattern supports three use cases:
+///
+/// 1) Broadcast a low-rank vector to a high-rank vector: the low-rank input
+///    vector must have a slice layout of the result. If the distributed source
+///    and target vector types are identical, this lowers to a no-op; otherwise,
+///    it remains a broadcast but operates on distributed vectors.
+///
+/// 2) Broadcast a same-rank vector with identical layouts for source and
+///    target: the source vector must have unit dimensions, and lane_data must
+///    be unit size for those unit dims. This always lowers to a no-op.
+///
+/// 3) Broadcast a scalar with no layout: this always lowers to a broadcast from
+///    the scalar to the distributed result type.
+///
+/// A vector source (cases 1 and 2) requires a layout on both the source operand
+/// and the result; the pattern fails to match when either is missing.
+///
+/// Example 1 (lowering to a broadcast with distributed types):
+/// ```
+/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<8x1xf32>) {
+///   %0 = "some_def"() {layout_result_0 =
+///     #xegpu.slice<#xegpu.layout<lane_layout = [1, 32], lane_data = [1, 1]>,
+///     dims = [0]> } : () -> (vector<32xf32>)
+///   %2 = vector.broadcast %0 {layout_result_0 =
+///     #xegpu.layout<lane_layout = [1, 32], lane_data = [1, 1]>}
+///     : vector<32xf32> to vector<8x32xf32>
+///   gpu.yield %2 : vector<8x32xf32>
+/// }
+/// ```
+/// is lowered to:
+/// ```
+/// %r:1 = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
+///   %0 = "some_def"() {layout_result_0 =
+///     #xegpu.slice<#xegpu.layout<lane_layout = [1, 32], lane_data = [1, 1]>,
+///     dims = [0]> } : () -> (vector<32xf32>)
+///   gpu.yield %0 : vector<32xf32>
+/// }
+/// %2 = vector.broadcast %r#0 : vector<1xf32> to vector<8x1xf32>
+/// ```
+///
+/// Example 2 (no-op):
+/// ```
+/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<8x1xf32>) {
+///   %0 = "some_def"() {layout_result_0 =
+///     #xegpu.slice<#xegpu.layout<lane_layout = [1, 32], lane_data = [1, 1]>,
+///     dims = [1]> } : () -> (vector<8xf32>)
+///   %1 = vector.shape_cast %0
+///     {layout_result_0 = #xegpu.layout<lane_layout = [1, 32], lane_data = [1,
+///     1]>}: vector<8xf32> to vector<8x1xf32>
+///   %2 = vector.broadcast %1
+///     {layout_result_0 = #xegpu.layout<lane_layout = [1, 32], lane_data = [1,
+///     1]>}: vector<8x1xf32> to vector<8x32xf32>
+///   gpu.yield %2 : vector<8x32xf32>
+/// }
+/// ```
+/// is lowered to:
+/// ```
+/// %r:1 = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<8x1xf32>) {
+///   %0 = "some_def"() {layout_result_0 =
+///     #xegpu.slice<#xegpu.layout<lane_layout = [1, 32], lane_data = [1, 1]>,
+///     dims = [1]> } : () -> (vector<8xf32>)
+///   %1 = vector.shape_cast %0
+///     {layout_result_0 = #xegpu.layout<lane_layout = [1, 32], lane_data = [1,
+///     1]>}: vector<8xf32> to vector<8x1xf32>
+///   gpu.yield %1 : vector<8x1xf32>
+/// }
+/// // The broadcast is implicit through layout transformation (no-op)
+/// "some_use"(%r#0)
+/// ```
+struct VectorBroadcastDistribution : public gpu::WarpDistributionPattern {
+  using gpu::WarpDistributionPattern::WarpDistributionPattern;
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    OpOperand *yieldOperand =
+        getWarpResult(warpOp, llvm::IsaPred<vector::BroadcastOp>);
+    if (!yieldOperand)
+      return failure();
+    auto broadcastOp =
+        cast<vector::BroadcastOp>(yieldOperand->get().getDefiningOp());
+    unsigned operandIdx = yieldOperand->getOperandNumber();
+
+    // A null sourceType means the broadcast source is a scalar (case 3).
+    VectorType sourceType = dyn_cast<VectorType>(broadcastOp.getSourceType());
+    VectorType destType = broadcastOp.getResultVectorType();
+
+    xegpu::DistributeLayoutAttr sourceLayout =
+        xegpu::getDistributeLayoutAttr(broadcastOp->getOpOperand(0));
+    xegpu::DistributeLayoutAttr resultLayout =
+        xegpu::getDistributeLayoutAttr(broadcastOp.getResult());
+
+    // Type yielded from the new warp op for the broadcast source: the
+    // distributed vector type for a vector source, or the scalar type itself.
+    Type sourceElemOrDistType;
+    if (sourceType) {
+      // Case 1 and 2: source is a vector type. Both layouts are queried
+      // below, so bail out instead of dereferencing a missing layout.
+      if (!sourceLayout || !resultLayout)
+        return rewriter.notifyMatchFailure(
+            warpOp, "Broadcast from vector requires layouts on both the "
+                    "source and the result.");
+
+      int64_t rankDiff = destType.getRank() - sourceType.getRank();
+      if (rankDiff > 0) {
+        // Case 1: source is lower-rank than result.
+        if (!sourceLayout.isSliceOf(resultLayout))
+          return rewriter.notifyMatchFailure(
+              warpOp,
+              "Broadcast input layout must be a slice of result layout.");
+      }
+      // Case 2: source and result have same rank.
+      if (rankDiff == 0) {
+        SetVector<int64_t> broadcastUnitDims =
+            broadcastOp.computeBroadcastedUnitDims();
+        // The source is compared against the result layout with its data
+        // fields set to 1 on the broadcasted unit dims.
+        resultLayout = resultLayout.setUnitDimData(broadcastUnitDims);
+        if (!sourceLayout.isEqualTo(resultLayout))
+          return rewriter.notifyMatchFailure(
+              warpOp, "For same-rank broadcast, source must be identical to "
+                      "adjusted result layouts with unit dims.");
+        // Unit dims of the source are not split across lanes/subgroups.
+        sourceLayout = sourceLayout.setUnitDimLayout(broadcastUnitDims);
+      }
+
+      FailureOr<VectorType> sourceDistType =
+          getDistVecTypeBasedOnLaneLayout(sourceLayout, sourceType);
+      if (failed(sourceDistType)) {
+        return rewriter.notifyMatchFailure(
+            warpOp, "Failed to distribute the source vector type.");
+      }
+      sourceElemOrDistType = sourceDistType.value();
+
+    } else {
+      // Case 3: source is a scalar type.
+      if (sourceLayout) {
+        return rewriter.notifyMatchFailure(
+            warpOp, "Broadcast from scalar must not have a layout attribute.");
+      }
+      sourceElemOrDistType = broadcastOp.getSourceType();
+    }
+    FailureOr<VectorType> destDistType =
+        getDistVecTypeBasedOnLaneLayout(resultLayout, destType);
+    if (failed(destDistType)) {
+      return rewriter.notifyMatchFailure(
+          warpOp, "Failed to distribute the dest vector type.");
+    }
+
+    // Yield the broadcast source from a new warp op so it is available
+    // (in distributed form) outside the warp region.
+    SmallVector<size_t> newRetIndices;
+    auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, {broadcastOp.getSource()}, sourceElemOrDistType,
+        newRetIndices);
+
+    Value distributedSource = newWarpOp.getResult(newRetIndices[0]);
+
+    Value newBroadcast = distributedSource;
+
+    // Only materialize a broadcast when the distributed source and result
+    // types differ; otherwise the yielded source is used directly (no-op).
+    if (sourceElemOrDistType != destDistType.value()) {
+      rewriter.setInsertionPointAfter(newWarpOp);
+      newBroadcast =
+          vector::BroadcastOp::create(rewriter, newWarpOp.getLoc(),
+                                      destDistType.value(), distributedSource);
+    }
+
+    rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIdx), newBroadcast);
+    return success();
+  }
+};
+
/// Distribute a `vector.shape_cast` op feeding into yield op of an enclosing
/// `gpu.warp_execute_on_lane_0` region.
struct VectorShapeCastDistribution : public gpu::WarpDistributionPattern {
@@ -1865,7 +2024,7 @@ void xegpu::populateXeGPUSubgroupDistributePatterns(
// patterns. Therefore, assign higher benefit.
patterns
.add<VectorShapeCastDistribution, VectorExtractStridedSliceDistribution,
- VectorInsertStridedSliceDistribution>(
+ VectorInsertStridedSliceDistribution, VectorBroadcastDistribution>(
patterns.getContext(),
/*pattern benefit=*/highPatternBenefit);
}
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout.mlir b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
index f8b59b87a122b..48e77d867508b 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
@@ -640,3 +640,61 @@ func.func @vector_shape_cast_1d_to_2d_dim0_broadcasted(%arg0: !xegpu.tensor_desc
return
}
}
+// -----
+gpu.module @test {
+// CHECK-LABEL: func.func @vector_broadcast_1d_to_2d_broadcast_along_row(
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>,
+// CHECK-SAME: %[[ARG1:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>) {
+// CHECK: %[[LOAD:.*]] = xegpu.load_nd %[[ARG0]] <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}> {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+// CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<16x16xf16>
+// CHECK-NEXT: %[[REDUCE:.*]] = vector.multi_reduction <add>, %[[LOAD]], %{{[0-9a-zA-Z]+}}
+// CHECK-SAME: {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>} [0] : vector<16x16xf16> to vector<16xf16>
+// CHECK-NEXT: %[[BROADCAST:.*]] = vector.broadcast %[[REDUCE]]
+// CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16xf16> to vector<16x16xf16>
+func.func @vector_broadcast_1d_to_2d_broadcast_along_row(%arg0: !xegpu.tensor_desc<16x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>) {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant dense<0.0000> : vector<16xf16>
+ %3 = xegpu.load_nd %arg0 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+ %4 = vector.multi_reduction <add>, %3, %cst [0] : vector<16x16xf16> to vector<16xf16>
+ %5 = vector.broadcast %4 : vector<16xf16> to vector<16x16xf16>
+ xegpu.store_nd %5, %arg1 : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16>
+ return
+}
+}
+
+// -----
+gpu.module @test {
+// CHECK-LABEL: func.func @vector_broadcast_2d_to_2d_along_column(
+// CHECK: %[[REDUCE:.*]] = vector.multi_reduction <add>
+// CHECK-SAME: {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [1]>} [1] : vector<16x16xf16> to vector<16xf16>
+// CHECK-NEXT: %[[SHAPECAST:.*]] = vector.shape_cast %[[REDUCE]]
+// CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16xf16> to vector<16x1xf16>
+// CHECK-NEXT: vector.broadcast %[[SHAPECAST]]
+// CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16x1xf16> to vector<16x16xf16>
+
+func.func @vector_broadcast_2d_to_2d_along_column(%arg0: !xegpu.tensor_desc<16x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>) {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant dense<0.0000> : vector<16xf16>
+ %3 = xegpu.load_nd %arg0 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+ %4 = vector.multi_reduction <add>, %3, %cst [1] : vector<16x16xf16> to vector<16xf16>
+ %5 = vector.shape_cast %4 : vector<16xf16> to vector<16x1xf16>
+ %6 = vector.broadcast %5 : vector<16x1xf16> to vector<16x16xf16>
+ xegpu.store_nd %6, %arg1 : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16>
+ return
+}
+}
+
+// -----
+gpu.module @test {
+// CHECK-LABEL: func.func @vector_broadcast_scalar_to_vector(
+// CHECK: %[[CST:.*]] = arith.constant 0.{{.*}} : f16
+// CHECK-NEXT: %[[BROADCAST:.*]] = vector.broadcast %[[CST]]
+// CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : f16 to vector<16x16xf16>
+
+func.func @vector_broadcast_scalar_to_vector(%arg0: !xegpu.tensor_desc<16x16xf16>) {
+ %cst = arith.constant 0.0000 : f16
+ %6 = vector.broadcast %cst : f16 to vector<16x16xf16>
+ xegpu.store_nd %6, %arg0 : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16>
+ return
+}
+}
\ No newline at end of file
diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir
index 44ec21359593f..216f3d19cff94 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir
@@ -920,4 +920,69 @@ gpu.func @vector_insert_strided_slice_unsupported_offset(%laneid: index) {
gpu.return
}
+// CHECK-LABEL: gpu.func @vector_broadcast_1d_to_2d_broadcast_within_lane
+// CHECK-SAME: (%[[ARG0:.*]]: index) {
+// CHECK: %[[R:.*]]:2 = gpu.warp_execute_on_lane_0(%[[ARG0]])[16] -> (vector<16x1xf16>, vector<1xf16>)
+// CHECK: %[[DEF:.*]] = "some_def"()
+// CHECK: %[[BCAST_INNER:.*]] = vector.broadcast %[[DEF]]
+// CHECK: gpu.yield %[[BCAST_INNER]], %[[DEF]]
+// CHECK: %[[BCAST:.*]] = vector.broadcast %[[R]]#1 : vector<1xf16> to vector<16x1xf16>
+// CHECK: "some_use"(%[[BCAST]])
+gpu.func @vector_broadcast_1d_to_2d_broadcast_within_lane(%laneid: index) {
+
+ %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<16x1xf16>) {
+
+ %1 = "some_def"() : () -> vector<16xf16>
+
+ %2 = vector.broadcast %1 {
+ layout_operand_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>,
+ layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+ } : vector<16xf16> to vector<16x16xf16>
+
+ gpu.yield %2 : vector<16x16xf16>
+ }
+ "some_use"(%r) : (vector<16x1xf16>) -> ()
+ gpu.return
+}
+
+// CHECK-LABEL: gpu.func @vector_broadcast_2d_to_2d_across_lane_lower_to_noop_case
+// CHECK-SAME: (%[[ARG0:.*]]: index)
+// CHECK: %[[R:.*]]:2 = gpu.warp_execute_on_lane_0(%[[ARG0]])[16] -> (vector<16x1xf16>, vector<16x1xf16>)
+// CHECK: %[[DEF:.*]] = "some_def"() : () -> vector<16x1xf16>
+// CHECK: %[[BCAST:.*]] = vector.broadcast %[[DEF]]
+// CHECK-SAME: : vector<16x1xf16> to vector<16x16xf16>
+// CHECK: gpu.yield %[[BCAST]], %[[DEF]] : vector<16x16xf16>, vector<16x1xf16>
+// CHECK: "some_use"(%[[R]]#1) : (vector<16x1xf16>) -> ()
+gpu.func @vector_broadcast_2d_to_2d_across_lane_lower_to_noop_case(%arg0: index) {
+ %0 = gpu.warp_execute_on_lane_0(%arg0)[16] -> (vector<16x1xf16>) {
+ %1 = "some_def"() : () -> vector<16x1xf16>
+ %2 = vector.broadcast %1 {
+ layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+ layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+ } : vector<16x1xf16> to vector<16x16xf16>
+ gpu.yield %2: vector<16x16xf16>
+ }
+ "some_use"(%0) : (vector<16x1xf16>) -> ()
+ gpu.return
+}
+
+// CHECK-LABEL: gpu.func @vector_shape_cast_scalar_to_vector
+// CHECK-SAME: (%[[ARG0:.*]]: index)
+// CHECK: %[[R:.*]]:2 = gpu.warp_execute_on_lane_0(%[[ARG0]])[16] -> (vector<16x1xf16>, f16)
+// CHECK: %[[DEF:.*]] = "some_def"()
+// CHECK: %[[BCAST:.*]] = vector.broadcast %[[DEF]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : f16 to vector<16x16xf16>
+// CHECK: gpu.yield %[[BCAST]], %[[DEF]] : vector<16x16xf16>, f16
+// CHECK: %[[RESULT:.*]] = vector.broadcast %[[R]]#1 : f16 to vector<16x1xf16>
+// CHECK: "some_use"(%[[RESULT]])
+gpu.func @vector_shape_cast_scalar_to_vector(%arg0: index) {
+ %0 = gpu.warp_execute_on_lane_0(%arg0)[16] -> (vector<16x1xf16>) {
+ %1 = "some_def"() : () -> f16
+ %2 = vector.broadcast %1 {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : f16 to vector<16x16xf16>
+ gpu.yield %2 : vector<16x16xf16>
+ }
+ "some_use"(%0) : (vector<16x1xf16>) -> ()
+ gpu.return
+}
+
}
diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
index 22177f8f6a15f..e5e3d2a1c1ad5 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
@@ -330,3 +330,64 @@ gpu.module @xevm_module{
gpu.return
}
}
+
+// -----
+// CHECK-LABEL: gpu.func @vector_broadcast_1d_to_2d_broadcast_within_lane({{.*}}) {
+gpu.module @xevm_module{
+ gpu.func @vector_broadcast_1d_to_2d_broadcast_within_lane(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>) {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>} dense<0.000000e+00> : vector<16xf16>
+ %tdesc0 = xegpu.create_nd_tdesc %arg0 : memref<16x16xf16>
+ -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ %tdesc1 = xegpu.create_nd_tdesc %arg1 : memref<16x16xf16>
+ -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ %0 = xegpu.load_nd %tdesc0[%c0, %c0] <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}> {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<16x16xf16>
+ %1 = vector.multi_reduction <add>, %0, %cst {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>} [0] : vector<16x16xf16> to vector<16xf16>
+ // CHECK: %[[BCAST:.*]] = vector.broadcast %{{.*}} : f16 to vector<16xf16>
+ %2 = vector.broadcast %1 {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16xf16> to vector<16x16xf16>
+ xegpu.store_nd %2, %tdesc1[%c0, %c0] <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}> : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ gpu.return
+ }
+}
+
+// -----
+// CHECK-LABEL: gpu.func @vector_broadcast_2d_to_2d_across_lane_lower_to_noop_case({{.*}}) {
+gpu.module @xevm_module{
+ gpu.func @vector_broadcast_2d_to_2d_across_lane_lower_to_noop_case(%arg0: memref<16xf16>, %arg1: memref<16x16xf16>) {
+ %c0 = arith.constant 0 : index
+ %mask = vector.constant_mask [16] {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [1]>}: vector<16xi1>
+ %1 = xegpu.load %arg0[%c0], %mask {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [1]>}: memref<16xf16>, index, vector<16xi1> -> vector<16xf16>
+
+ %11 = vector.shape_cast %1 {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16xf16> to vector<16x1xf16>
+ %2 = vector.broadcast %11 {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16x1xf16> to vector<16x16xf16>
+ // CHECK-NOT: vector.broadcast
+ // CHECK-NOT: vector.shape_cast
+
+ %tdesc1 = xegpu.create_nd_tdesc %arg1 : memref<16x16xf16>
+ -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ // CHECK: xegpu.store_nd {{.*}}, {{.*}}[{{.*}}, {{.*}}]
+ // CHECK-SAME: : vector<16xf16>, !xegpu.tensor_desc<16x16xf16>
+
+ xegpu.store_nd %2, %tdesc1[%c0, %c0] <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}> : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ gpu.return
+ }
+}
+
+// -----
+// CHECK-LABEL: gpu.func @vector_shape_cast_scalar_to_vector({{.*}}) {
+gpu.module @xevm_module{
+ gpu.func @vector_shape_cast_scalar_to_vector(%arg0: memref<16xf16>, %arg1: memref<16x16xf16>) {
+ %c0 = arith.constant 0 : index
+ %9 = gpu.block_id x
+ %10 = arith.index_cast %9 : index to i16
+ %11 = arith.bitcast %10 : i16 to f16
+ // CHECK: vector.broadcast {{.*}} : f16 to vector<16xf16>
+ %2 = vector.broadcast %11 {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : f16 to vector<16x16xf16>
+ %tdesc1 = xegpu.create_nd_tdesc %arg1 : memref<16x16xf16>
+ -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ xegpu.store_nd %2, %tdesc1[%c0, %c0] <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}> : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ gpu.return
+ }
+}
+
+
More information about the Mlir-commits
mailing list