[Mlir-commits] [mlir] [mlir][xegpu] Add SIMT distribution support for GEMM transpose B case. (PR #155517)

Charitha Saumya llvmlistbot at llvm.org
Thu Sep 18 14:13:13 PDT 2025


================
@@ -441,14 +478,94 @@ void LayoutInfoPropagation::visitVectorMultiReductionOp(
   }
   // Given that the result is 1D, the layout of the operand should be 2D with
   // default layout.
-  LayoutInfo operandLayout = getDefaultSIMTLayoutInfo(2);
+  LayoutInfo operandLayout =
+      getDefaultSIMTLayoutInfo(reduction->getContext(), 2);
   propagateIfChanged(operands[0], operands[0]->meet(operandLayout));
   // Accumulator should have the same layout as the result.
   propagateIfChanged(operands[1], operands[1]->meet(resultLayout));
 }
 
-/// Propagate the layout of the result tensor to the source tensor descriptor in
-/// UpdateNdOffsetOp.
+void LayoutInfoPropagation::visitVectorBroadCastOp(
+    vector::BroadcastOp broadcast, ArrayRef<LayoutInfoLattice *> operands,
+    ArrayRef<const LayoutInfoLattice *> results) {
+  // Backward layout propagation for vector.broadcast: the layout already
+  // assigned to the broadcast result is met into the source operand's
+  // lattice. Unsupported forms (non-vector source, rank-changing broadcast,
+  // or a number of broadcasted unit dims other than exactly one) emit a
+  // warning and return without touching the source lattice.
+  // The layout of the result must be present.
+  LayoutInfo resultLayout = results[0]->getValue();
+  if (!resultLayout.isAssigned())
+    return;
+  // Only consider vector to vector broadcasts for now.
+  VectorType resultTy = broadcast.getResultVectorType();
+  VectorType sourceTy = dyn_cast<VectorType>(broadcast.getSourceType());
+  if (!sourceTy) {
+    broadcast.emitWarning("Expecting source type to be a vector type.");
+    return;
+  }
+
+  // Only consider nD -> nD broadcast.
+  if (sourceTy.getRank() != resultTy.getRank()) {
+    broadcast.emitWarning("Expecting source and result to have same rank.");
+    return;
+  }
+  // NOTE(review): presumably these are the dims where the source has size 1
+  // and the result is larger (per computeBroadcastedUnitDims) — confirm.
+  // Exactly one such dim is accepted.
+  SetVector<int64_t> broadcastUnitDims = broadcast.computeBroadcastedUnitDims();
+  if (broadcastUnitDims.size() != 1) {
+    broadcast.emitWarning("Expecting source type to be nD vector only with "
+                          "one broadcasted dimension.");
+    return;
+  }
+  // Propagate the result layout to the source operand.
+  // NOTE(review): the source reuses the result layout unchanged even though
+  // its broadcasted dim is a unit dim — confirm downstream distribution
+  // handles this layout on a size-1 dimension.
+  propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
+}
+
+void LayoutInfoPropagation::visitShapeCastOp(
+    vector::ShapeCastOp shapeCast, ArrayRef<LayoutInfoLattice *> operands,
+    ArrayRef<const LayoutInfoLattice *> results) {
+  // The layout of the result must be present.
+  LayoutInfo resultLayout = results[0]->getValue();
+  if (!resultLayout.isAssigned())
+    return;
+  VectorType sourceTy = shapeCast.getSourceVectorType();
+  VectorType resultTy = shapeCast.getResultVectorType();
+  // Shape cast layout propagation has following restrictions:
+  // 1) nD -> nD shape cast is not supported.
+  // 2) Shape cast must always expand the rank (e.g. 1D -> 2D).
+  // 3) Newly expanded dimensions must be 1.
+  // 4) Result layout can not be a slice layout.
+  if (sourceTy.getRank() == resultTy.getRank()) {
+    shapeCast.emitWarning("nD -> nD shape cast is not supported.");
+    return;
+  }
+  if (sourceTy.getRank() > resultTy.getRank()) {
+    shapeCast.emitWarning("Expecting shape cast to expand the rank.");
+    return;
+  }
+  if (resultLayout.getRank() != resultTy.getRank() ||
+      resultLayout.isSliceLayout()) {
+    shapeCast.emitWarning("Expecting result layout to have same rank as the "
+                          "result type and not be a slice layout.");
+    return;
+  }
+  ArrayRef<int64_t> resultShape = shapeCast.getResultVectorType().getShape();
+  ArrayRef<int64_t> sourceShape = shapeCast.getSourceVectorType().getShape();
+  auto findUnitDims = [](ArrayRef<int64_t> shape) {
+    SmallVector<int64_t> unitDims;
+    for (int i = 0, e = shape.size(); i < e; ++i)
+      if (shape[i] == 1)
+        unitDims.push_back(i);
+    return unitDims;
+  };
+  SmallVector<int64_t> resultUnitDims = findUnitDims(resultShape);
+  SmallVector<int64_t> sourceUnitDims = findUnitDims(sourceShape);
+  // Remove first `sourceUnitDims.size()` unit dims from resultUnitDims.
+  auto sliceDims =
+      ArrayRef<int64_t>(resultUnitDims).drop_front(sourceUnitDims.size());
----------------
charithaintc wrote:

Thanks for pointing this out. The earlier logic was wrong.

I fixed the code to detect the unit dimensions that are new in the result shape. It looks a bit complex now, and I don't have a good way to test these cases (I can only write 2D test cases due to store restrictions).

Please take a look. 

In the next PR, I will modify our test structure, and I plan to add more tests for this then.

https://github.com/llvm/llvm-project/pull/155517


More information about the Mlir-commits mailing list