[Mlir-commits] [mlir] [mlir][xegpu] Add SIMT distribution support for GEMM transpose B case. (PR #155517)
Charitha Saumya
llvmlistbot at llvm.org
Wed Sep 10 09:26:28 PDT 2025
================
@@ -807,6 +822,138 @@ struct GpuBarrierDistribution final : public gpu::WarpDistributionPattern {
   }
};
+/// Sink a memref::ExtractAlignedPointerAsIndexOp feeding into the yield op of
+/// an enclosing `gpu.warp_execute_on_lane_0` region. This simply moves the op
+/// outside of the warp op.
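+/// For illustration, a sketch of the rewrite (value names and the memref type
+/// are hypothetical, not taken from this patch):
+///   %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (index) {
+///     %ptr = memref.extract_aligned_pointer_as_index %src : memref<256xf16> -> index
+///     gpu.yield %ptr : index
+///   }
+/// becomes, roughly,
+///   %w:2 = gpu.warp_execute_on_lane_0(%laneid)[16] -> (index, memref<256xf16>) {
+///     %ptr = memref.extract_aligned_pointer_as_index %src : memref<256xf16> -> index
+///     gpu.yield %ptr, %src : index, memref<256xf16>
+///   }
+///   %r = memref.extract_aligned_pointer_as_index %w#1 : memref<256xf16> -> index
+/// with uses of the old warp result replaced by the new %r (the now-dead yield
+/// is expected to be cleaned up by later dead-result folding).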
+struct MemrefExtractAlignedPointerAsIndexDistribution final
+    : public gpu::WarpDistributionPattern {
+  using gpu::WarpDistributionPattern::WarpDistributionPattern;
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    OpOperand *operand = getWarpResult(
+        warpOp, llvm::IsaPred<memref::ExtractAlignedPointerAsIndexOp>);
+    if (!operand)
+      return rewriter.notifyMatchFailure(
+          warpOp,
+          "warp result is not a memref::ExtractAlignedPointerAsIndex op");
+    auto extractOp =
+        operand->get().getDefiningOp<memref::ExtractAlignedPointerAsIndexOp>();
+    unsigned operandIdx = operand->getOperandNumber();
+    SmallVector<size_t> newRetIndices;
+    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, extractOp.getSource(),
+        TypeRange{extractOp.getSource().getType()}, newRetIndices);
+    rewriter.setInsertionPointAfter(newWarpOp);
+    auto newExtractOp = memref::ExtractAlignedPointerAsIndexOp::create(
+        rewriter, newWarpOp.getLoc(), extractOp.getType(),
+        newWarpOp.getResult(newRetIndices[0]));
+    Value distributedVal = newWarpOp.getResult(operandIdx);
+    rewriter.replaceAllUsesWith(distributedVal, newExtractOp.getResult());
+    return success();
+  }
+};
+
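+/// Distribute a vector::BitCastOp feeding into the yield op of an enclosing
+/// `gpu.warp_execute_on_lane_0` region. The bitcast source is yielded from the
+/// warp op in its distributed type (derived from its layout attribute) and the
+/// bitcast is recreated outside the warp op on the distributed source. Only 2D
+/// distributed source and result vectors are handled.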
+struct VectorBitcastDistribution final : public gpu::WarpDistributionPattern {
+  using gpu::WarpDistributionPattern::WarpDistributionPattern;
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    OpOperand *operand =
+        getWarpResult(warpOp, llvm::IsaPred<vector::BitCastOp>);
+    if (!operand)
+      return rewriter.notifyMatchFailure(
+          warpOp, "warp result is not a vector::BitCast op");
+    auto bitcastOp = operand->get().getDefiningOp<vector::BitCastOp>();
+    unsigned operandIdx = operand->getOperandNumber();
+    VectorType distributedSourceType =
+        getDistVecTypeBasedOnLaneLayout(
+            xegpu::getDistributeLayoutAttr(bitcastOp.getSource()),
+            bitcastOp.getSourceVectorType())
+            .value_or(VectorType());
+    if (!distributedSourceType)
+      return rewriter.notifyMatchFailure(
+          bitcastOp, "failed to distribute the source vector type in "
+                     "vector::BitCast op");
+    VectorType distributedResultType =
+        cast<VectorType>(warpOp.getResult(operandIdx).getType());
+    if (distributedSourceType.getRank() != 2 ||
+        distributedResultType.getRank() != 2)
+      return rewriter.notifyMatchFailure(
+          bitcastOp, "the source or result vector of the bitcast op "
+                     "is not a 2D vector");
+    SmallVector<size_t> newRetIndices;
+    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, bitcastOp.getSource(),
+        TypeRange{distributedSourceType}, newRetIndices);
+    rewriter.setInsertionPointAfter(newWarpOp);
+    auto newBitcastOp = vector::BitCastOp::create(
+        rewriter, newWarpOp.getLoc(), distributedResultType,
+        newWarpOp.getResult(newRetIndices[0]));
+    Value distributedVal = newWarpOp.getResult(operandIdx);
+    rewriter.replaceAllUsesWith(distributedVal, newBitcastOp.getResult());
+    return success();
+  }
+};
+
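+/// Distribute a vector::TransposeOp feeding into the yield op of an enclosing
+/// `gpu.warp_execute_on_lane_0` region. Both the source and the result vectors
+/// must carry layout attributes, and their 2D lane layouts must be transposes
+/// of each other (e.g. lane_layout [1, 16] for the source and [16, 1] for the
+/// result).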
+struct VectorTransposeDistribution final : public gpu::WarpDistributionPattern {
+  using gpu::WarpDistributionPattern::WarpDistributionPattern;
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    OpOperand *operand =
+        getWarpResult(warpOp, llvm::IsaPred<vector::TransposeOp>);
+    if (!operand)
+      return rewriter.notifyMatchFailure(
+          warpOp, "warp result is not a vector::Transpose op");
+    auto transposeOp = operand->get().getDefiningOp<vector::TransposeOp>();
+    unsigned operandIdx = operand->getOperandNumber();
+    xegpu::DistributeLayoutAttr sourceLayout =
+        xegpu::getDistributeLayoutAttr(transposeOp.getVector());
+    xegpu::DistributeLayoutAttr resultLayout =
+        xegpu::getDistributeLayoutAttr(transposeOp.getResult());
+    if (!sourceLayout || !resultLayout)
+      return rewriter.notifyMatchFailure(
+          transposeOp,
+          "the source or result vector of the transpose op lacks a layout "
+          "attribute");
+    SmallVector<int64_t> sourceLaneLayout = sourceLayout.getLaneLayoutAsInt();
+    SmallVector<int64_t> resultLaneLayout = resultLayout.getLaneLayoutAsInt();
+    SmallVector<int64_t> sourceLaneData = sourceLayout.getLaneDataAsInt();
+    SmallVector<int64_t> resultLaneData = resultLayout.getLaneDataAsInt();
+    if (sourceLaneLayout.size() != 2 || resultLaneLayout.size() != 2)
+      return rewriter.notifyMatchFailure(
+          transposeOp, "the source or result vector of the transpose op "
+                       "does not have a 2D layout");
+    auto is2DTranspose = [](ArrayRef<int64_t> input, ArrayRef<int64_t> output) {
+      return input.size() == 2 && output.size() == 2 &&
+             input[0] == output[1] && input[1] == output[0];
+    };
+
+    if (!is2DTranspose(sourceLaneLayout, resultLaneLayout) ||
----------------
charithaintc wrote:
I am working on this. I prefer to do this in a separate PR. I will change the use locations once it is ready.
https://github.com/llvm/llvm-project/pull/155517