[Mlir-commits] [mlir] [mlir][xegpu] Add SIMT distribution support for GEMM transpose B case. (PR #155517)
Charitha Saumya
llvmlistbot at llvm.org
Fri Sep 12 17:04:44 PDT 2025
================
@@ -441,14 +471,128 @@ void LayoutInfoPropagation::visitVectorMultiReductionOp(
}
// Given that the result is 1D, the layout of the operand should be 2D with
// default layout.
- LayoutInfo operandLayout = getDefaultSIMTLayoutInfo(2);
+ LayoutInfo operandLayout =
+ getDefaultSIMTLayoutInfo(reduction->getContext(), 2);
propagateIfChanged(operands[0], operands[0]->meet(operandLayout));
// Accumulator should have the same layout as the result.
propagateIfChanged(operands[1], operands[1]->meet(resultLayout));
}
-/// Propagate the layout of the result tensor to the source tensor descriptor in
-/// UpdateNdOffsetOp.
+/// Propagate the layout of the result of a vector.broadcast op to its source
+/// operand. Only vector-to-vector broadcasts where both the source and the
+/// result are 2D and exactly one broadcasted unit dimension is reported by
+/// computeBroadcastedUnitDims() are handled. Every other form (scalar source,
+/// non-2D ranks, zero or several broadcasted dims) emits a warning and returns
+/// without touching the source operand's lattice.
+void LayoutInfoPropagation::visitVectorBroadCastOp(
+ vector::BroadcastOp broadcast, ArrayRef<LayoutInfoLattice *> operands,
+ ArrayRef<const LayoutInfoLattice *> results) {
+ // The layout of the result must be present.
+ LayoutInfo resultLayout = results[0]->getValue();
+ if (!resultLayout.isAssigned())
+ return;
+ // Only consider vector to vector broadcasts for now.
+ VectorType resultTy = broadcast.getResultVectorType();
+ VectorType sourceTy = dyn_cast<VectorType>(broadcast.getSourceType());
+ if (!sourceTy) {
+ broadcast.emitWarning("Expecting source type to be a vector type.");
+ return;
+ }
+
+ // Only consider 2D -> 2D broadcast.
+ if (sourceTy.getRank() != 2 || resultTy.getRank() != 2) {
+ broadcast.emitWarning("Expecting source type to be 2D vector and "
+ "result type to be 2D vector.");
+ return;
+ }
+ // Require exactly one broadcasted unit dimension; the 2D layout is only
+ // forwarded in that case.
+ SetVector<int64_t> broadcastUnitDims = broadcast.computeBroadcastedUnitDims();
+ if (broadcastUnitDims.size() != 1) {
+ broadcast.emitWarning("Expecting source type to be 2D vector only with "
+ "one broadcasted dimension.");
+ return;
+ }
+ // Propagate the result layout to the source operand.
+ // NOTE(review): the result layout is met into the source unchanged, including
+ // along the broadcasted (unit-sized) source dimension. Confirm this is the
+ // intended layout for that dimension.
+ propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
+}
+
+void LayoutInfoPropagation::visitShapeCastOp(
+ vector::ShapeCastOp shapeCast, ArrayRef<LayoutInfoLattice *> operands,
+ ArrayRef<const LayoutInfoLattice *> results) {
+ // The layout of the result must be present.
+ LayoutInfo resultLayout = results[0]->getValue();
+ if (!resultLayout.isAssigned())
+ return;
+ int64_t sourceRank = shapeCast.getSourceVectorType().getRank();
+ int64_t resultRank = shapeCast.getResultVectorType().getRank();
+ // Expecting source rank to be 1D or 2D.
+ if (sourceRank != 1 && sourceRank != 2) {
+ shapeCast.emitWarning("Expecting source type to be 1D or 2D vector.");
+ return;
+ }
+ // Expecting result rank to be 1D or 2D.
+ if (resultRank != 1 && resultRank != 2) {
+ shapeCast.emitWarning("Expecting result type to be 1D or 2D vector.");
+ return;
+ }
+ // For 2D -> 2D shape cast, propagate the result layout to the source.
----------------
charithaintc wrote:
Fixed. I also added this condition for now:
4) The result layout cannot be a slice layout, and it must have the same rank as the result.
https://github.com/llvm/llvm-project/pull/155517
More information about the Mlir-commits
mailing list