[Mlir-commits] [mlir] [mlir][xegpu] Add SIMT distribution support for GEMM transpose B case. (PR #155517)
    Jianhui Li 
    llvmlistbot at llvm.org
       
    Thu Sep 11 16:51:00 PDT 2025
    
    
  
================
@@ -539,16 +682,48 @@ void LayoutInfoPropagation::visitVectorBitcastOp(
       bitcast.getSourceVectorType().getElementType().getIntOrFloatBitWidth();
   int outElemTyBitWidth =
       bitcast.getResultVectorType().getElementType().getIntOrFloatBitWidth();
-
-  // NOTE: We do not expect widening or narrowing bitcasts at this stage. Emit
-  // a warning and return.
-  if (inElemTyBitWidth != outElemTyBitWidth) {
-    bitcast.emitWarning("Widening or narrowing bitcasts are not expected at "
-                        "layout propagation stage.");
+  // If the element bit widths are the same, then the layout does not change.
+  if (inElemTyBitWidth == outElemTyBitWidth) {
+    propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
+    return;
+  }
+  int64_t rank = bitcast.getSourceVectorType().getRank();
+  // Bitcast is `narrowing` if the input element type bit width is larger than
+  // the output element type bit width, e.g. f32 -> f16 is a narrowing bitcast.
+  bool isNarrowing = inElemTyBitWidth > outElemTyBitWidth;
+  int bitCastRatio = isNarrowing ? inElemTyBitWidth / outElemTyBitWidth
+                                 : outElemTyBitWidth / inElemTyBitWidth;
+  SmallVector<int> sourceLaneLayout =
+      resultLayout.getLaneLayout(); // Lane layout does not change for bitcast.
+  SmallVector<int> outData = resultLayout.getLaneData();
+
+  // TODO: Currently we assume that bitcasts do not require cross-lane
+  // communication. So each lane must own the required number of elements to
+  // perform the bitcast locally without cross-lane communication.
+  int outInnerBitsPerLane = outData[rank - 1] * outElemTyBitWidth;
+  if (outInnerBitsPerLane < inElemTyBitWidth) {
----------------
Jianhui-Li wrote:
Please check this condition:
srcInnerBitsPerLane = inElemTyBitWidth * sourceLayout.getLaneData()[rank - 1]
if (outInnerBitsPerLane != srcInnerBitsPerLane)
https://github.com/llvm/llvm-project/pull/155517
    
    
More information about the Mlir-commits
mailing list