[Mlir-commits] [mlir] [mlir][xegpu] Add SIMT distribution support for GEMM transpose B case. (PR #155517)

Charitha Saumya llvmlistbot at llvm.org
Fri Sep 12 12:42:18 PDT 2025


================
@@ -1001,6 +1016,129 @@ struct LoadDistribution final : public gpu::WarpDistributionPattern {
   }
 };
 
+/// Sink a memref::ExtractAlignedPointerAsIndexOp that feeds the yield op of an
+/// enclosing `gpu.warp_execute_on_lane_0` region. The op is simply moved
+/// outside of the warp op: its memref source is appended as a new warp result
+/// with its original (undistributed) type, an equivalent
+/// memref::ExtractAlignedPointerAsIndexOp is created right after the new warp
+/// op, and all uses of the original warp result are redirected to it.
+struct MemrefExtractAlignedPointerAsIndexDistribution final
+    : public gpu::WarpDistributionPattern {
+  using gpu::WarpDistributionPattern::WarpDistributionPattern;
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    OpOperand *operand = getWarpResult(
+        warpOp, llvm::IsaPred<memref::ExtractAlignedPointerAsIndexOp>);
+    if (!operand)
+      return rewriter.notifyMatchFailure(
+          warpOp,
+          "warp result is not a memref::ExtractAlignedPointerAsIndex op");
+    auto extractOp =
+        operand->get().getDefiningOp<memref::ExtractAlignedPointerAsIndexOp>();
+    unsigned operandIdx = operand->getOperandNumber();
+    // The source memref is yielded with its unchanged type; this pattern does
+    // not distribute memrefs across lanes.
+    SmallVector<size_t> newRetIndices;
+    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, extractOp.getSource(),
+        TypeRange{extractOp.getSource().getType()}, newRetIndices);
+    rewriter.setInsertionPointAfter(newWarpOp);
+    auto newExtractOp = memref::ExtractAlignedPointerAsIndexOp::create(
+        rewriter, newWarpOp.getLoc(), extractOp.getType(),
+        newWarpOp.getResult(newRetIndices[0]));
+    Value distributedVal = newWarpOp.getResult(operandIdx);
+    rewriter.replaceAllUsesWith(distributedVal, newExtractOp.getResult());
+    return success();
+  }
+};
+
+/// Distribute a vector::BitCastOp that feeds the yield op of an enclosing
+/// `gpu.warp_execute_on_lane_0` region. A bitcast only affects the innermost
+/// dimension of its source/result vectors, so an equivalent vector::BitCastOp
+/// is recreated outside of the warp op. Its source is yielded from a new warp
+/// op with a per-lane type computed from the source's assigned layout. Its
+/// result type is the distributed type already carried by the original warp
+/// result.
+struct VectorBitcastDistribution final : public gpu::WarpDistributionPattern {
+  using gpu::WarpDistributionPattern::WarpDistributionPattern;
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    OpOperand *yieldOperand =
+        getWarpResult(warpOp, llvm::IsaPred<vector::BitCastOp>);
+    if (yieldOperand == nullptr)
+      return rewriter.notifyMatchFailure(
+          warpOp, "warp result is not a vector::BitCast op");
+
+    auto castOp = yieldOperand->get().getDefiningOp<vector::BitCastOp>();
+    const unsigned resultIdx = yieldOperand->getOperandNumber();
+    Value castSource = castOp.getSource();
+
+    // Per-lane source type derived from the source's layout attribute; a null
+    // type means the source cannot be distributed.
+    VectorType laneSourceTy =
+        getDistVecTypeBasedOnLaneLayout(
+            xegpu::getDistributeLayoutAttr(castSource),
+            castOp.getSourceVectorType())
+            .value_or(VectorType());
+    if (!laneSourceTy)
+      return rewriter.notifyMatchFailure(
+          castOp, "Failed to distribute the source vector type in "
+                  "vector::BitCast op");
+
+    // Read the distributed result type before the warp op is rebuilt.
+    VectorType laneResultTy =
+        cast<VectorType>(warpOp.getResult(resultIdx).getType());
+
+    SmallVector<size_t> appendedIdx;
+    gpu::WarpExecuteOnLane0Op expandedWarpOp =
+        moveRegionToNewWarpOpAndAppendReturns(rewriter, warpOp, castSource,
+                                              TypeRange{laneSourceTy},
+                                              appendedIdx);
+    rewriter.setInsertionPointAfter(expandedWarpOp);
+    Value laneSource = expandedWarpOp.getResult(appendedIdx.front());
+    auto laneCast = vector::BitCastOp::create(
+        rewriter, expandedWarpOp.getLoc(), laneResultTy, laneSource);
+    rewriter.replaceAllUsesWith(expandedWarpOp.getResult(resultIdx),
+                                laneCast.getResult());
+    return success();
+  }
+};
+
+struct VectorTransposeDistribution final : public gpu::WarpDistributionPattern {
+  using gpu::WarpDistributionPattern::WarpDistributionPattern;
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    OpOperand *operand =
+        getWarpResult(warpOp, llvm::IsaPred<vector::TransposeOp>);
+    if (!operand)
+      return rewriter.notifyMatchFailure(
+          warpOp, "warp result is not a vector::Transpose op");
+    auto transposeOp = operand->get().getDefiningOp<vector::TransposeOp>();
+    unsigned operandIdx = operand->getOperandNumber();
+    xegpu::DistributeLayoutAttr sourceLayout =
+        xegpu::getDistributeLayoutAttr(transposeOp.getVector());
+    xegpu::DistributeLayoutAttr resultLayout =
+        xegpu::getDistributeLayoutAttr(transposeOp.getResult());
+    if (!sourceLayout || !resultLayout)
+      return rewriter.notifyMatchFailure(
+          transposeOp,
+          "the source or result vector of the transpose op lacks layout "
+          "attribute");
+    if (sourceLayout.getRank() != 2 || resultLayout.getRank() != 2)
----------------
charithaintc wrote:

Added a TODO note. I will improve this later, once we can write more complex test cases; the current SIMT distribution test infrastructure has some limitations.

https://github.com/llvm/llvm-project/pull/155517


More information about the Mlir-commits mailing list