[Mlir-commits] [mlir] [mlir][xegpu] Add SIMT distribution support for GEMM transpose B case. (PR #155517)
Jianhui Li
llvmlistbot at llvm.org
Thu Sep 11 16:51:00 PDT 2025
================
@@ -1001,6 +1016,129 @@ struct LoadDistribution final : public gpu::WarpDistributionPattern {
}
};
+/// Sink a memref::ExtractAlignedPointerAsIndex op feeding into the yield op of
+/// an enclosing `gpu.warp_execute_on_lane_0` region. The op is simply moved
+/// outside of the warp op:
+///   1. its memref source is appended as a new yielded value of the warp op,
+///      keeping the source's original (undistributed) type, and
+///   2. a new ExtractAlignedPointerAsIndex op is created after the new warp op
+///      on that yielded memref, and its result replaces all uses of the warp
+///      op result that the original op was feeding.
+struct MemrefExtractAlignedPointerAsIndexDistribution final
+    : public gpu::WarpDistributionPattern {
+  using gpu::WarpDistributionPattern::WarpDistributionPattern;
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    OpOperand *operand = getWarpResult(
+        warpOp, llvm::IsaPred<memref::ExtractAlignedPointerAsIndexOp>);
+    if (!operand)
+      return rewriter.notifyMatchFailure(
+          warpOp,
+          "warp result is not a memref::ExtractAlignedPointerAsIndex op");
+    auto extractOp =
+        operand->get().getDefiningOp<memref::ExtractAlignedPointerAsIndexOp>();
+    // Index of the yielded value; reused below to look up the matching result
+    // of the new warp op.
+    unsigned operandIdx = operand->getOperandNumber();
+    SmallVector<size_t> newRetIndices;
+    // The source memref is yielded with its type unchanged (no distribution).
+    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, extractOp.getSource(),
+        TypeRange{extractOp.getSource().getType()}, newRetIndices);
+    rewriter.setInsertionPointAfter(newWarpOp);
+    auto newExtractOp = memref::ExtractAlignedPointerAsIndexOp::create(
+        rewriter, newWarpOp.getLoc(), extractOp.getType(),
+        newWarpOp.getResult(newRetIndices[0]));
+    Value distributedVal = newWarpOp.getResult(operandIdx);
+    rewriter.replaceAllUsesWith(distributedVal, newExtractOp.getResult());
+    return success();
+  }
+};
+
+/// Distribute a vector::BitCastOp whose result is yielded by an enclosing
+/// `gpu.warp_execute_on_lane_0` region. A bitcast only affects the innermost
+/// dimension of its source/result vectors, so an equivalent bitcast is
+/// re-created outside of the warp op:
+///   - the source vector is yielded from the warp op with the per-lane type
+///     computed from its assigned layout, and
+///   - the new bitcast produces the distributed type that the warp op already
+///     yields for the original bitcast result.
+struct VectorBitcastDistribution final : public gpu::WarpDistributionPattern {
+  using gpu::WarpDistributionPattern::WarpDistributionPattern;
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    OpOperand *yieldOperand =
+        getWarpResult(warpOp, llvm::IsaPred<vector::BitCastOp>);
+    if (!yieldOperand)
+      return rewriter.notifyMatchFailure(
+          warpOp, "warp result is not a vector::BitCast op");
+
+    auto castOp = yieldOperand->get().getDefiningOp<vector::BitCastOp>();
+    const unsigned resultIdx = yieldOperand->getOperandNumber();
+    Value castSource = castOp.getSource();
+
+    // Per-lane source type derived from the source's layout; a null type means
+    // the source could not be distributed.
+    auto laneSourceType =
+        getDistVecTypeBasedOnLaneLayout(
+            xegpu::getDistributeLayoutAttr(castSource),
+            castOp.getSourceVectorType())
+            .value_or(VectorType());
+    if (!laneSourceType)
+      return rewriter.notifyMatchFailure(
+          castOp, "Failed to distribute the source vector type in "
+                  "vector::BitCast op");
+
+    // The distributed result type is whatever the existing warp op yields for
+    // this value; read it before the region is moved to a new warp op.
+    auto laneResultType =
+        cast<VectorType>(warpOp.getResult(resultIdx).getType());
+
+    SmallVector<size_t> appendedIdx;
+    gpu::WarpExecuteOnLane0Op rebuiltWarpOp =
+        moveRegionToNewWarpOpAndAppendReturns(rewriter, warpOp, castSource,
+                                              TypeRange{laneSourceType},
+                                              appendedIdx);
+    rewriter.setInsertionPointAfter(rebuiltWarpOp);
+    Value laneSource = rebuiltWarpOp.getResult(appendedIdx.front());
+    auto laneCast = vector::BitCastOp::create(rewriter, rebuiltWarpOp.getLoc(),
+                                              laneResultType, laneSource);
+    rewriter.replaceAllUsesWith(rebuiltWarpOp.getResult(resultIdx),
+                                laneCast.getResult());
+    return success();
+  }
+};
+
+struct VectorTransposeDistribution final : public gpu::WarpDistributionPattern {
+ using gpu::WarpDistributionPattern::WarpDistributionPattern;
+ LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+ PatternRewriter &rewriter) const override {
+ OpOperand *operand =
+ getWarpResult(warpOp, llvm::IsaPred<vector::TransposeOp>);
+ if (!operand)
+ return rewriter.notifyMatchFailure(
+ warpOp, "warp result is not a vector::Transpose op");
+ auto transposeOp = operand->get().getDefiningOp<vector::TransposeOp>();
+ unsigned operandIdx = operand->getOperandNumber();
+ xegpu::DistributeLayoutAttr sourceLayout =
+ xegpu::getDistributeLayoutAttr(transposeOp.getVector());
+ xegpu::DistributeLayoutAttr resultLayout =
+ xegpu::getDistributeLayoutAttr(transposeOp.getResult());
+ if (!sourceLayout || !resultLayout)
+ return rewriter.notifyMatchFailure(
+ transposeOp,
+ "the source or result vector of the transpose op lacks layout "
+ "attribute");
+ if (sourceLayout.getRank() != 2 || resultLayout.getRank() != 2)
----------------
Jianhui-Li wrote:
Add a TODO to consider removing this check. As long as one layout is the transpose of the other, the distributed transpose happens within the lane and should work properly.
https://github.com/llvm/llvm-project/pull/155517
More information about the Mlir-commits
mailing list