[Mlir-commits] [mlir] [MLIR] [XeGPU] Add distribution patterns for vector transpose, bitcast & mask ops in sg to wi pass (PR #187392)
Nishant Patel
llvmlistbot at llvm.org
Sun Mar 29 14:53:07 PDT 2026
================
@@ -687,6 +688,162 @@ struct SgToWiLoadMatrix : public OpConversionPattern<xegpu::LoadMatrixOp> {
}
};
+/// Distributes a subgroup-level vector.transpose op to workitem-level.
+///
+/// The pattern applies only when both the source operand and the result carry
+/// a temporary layout and the result layout is reported by `isTransposeOf` to
+/// be the transpose of the source layout (for `LayoutKind::Lane`) under the
+/// op's permutation. The transpose is then re-created on the adaptor's source
+/// vector with the same permutation, and its result is cast to the vector type
+/// obtained by distributing the original result type over the result layout.
+struct SgToWiVectorTranspose : public OpConversionPattern<vector::TransposeOp> {
+  using OpConversionPattern<vector::TransposeOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::TransposeOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Both ends of the transpose must be annotated; otherwise there is nothing
+    // to validate the permutation against.
+    xegpu::DistributeLayoutAttr sourceLayout =
+        xegpu::getTemporaryLayout(op->getOpOperand(0));
+    xegpu::DistributeLayoutAttr resultLayout =
+        xegpu::getTemporaryLayout(op->getOpResult(0));
+    if (!sourceLayout || !resultLayout)
+      return rewriter.notifyMatchFailure(
+          op, "the source or result vector of the transpose op lacks layout "
+              "attribute");
+
+    ArrayRef<int64_t> perm = op.getPermutation();
+    // Result layout must be a transpose of source layout.
+    if (!resultLayout.isTransposeOf(sourceLayout, perm,
+                                    xegpu::LayoutKind::Lane))
+      return rewriter.notifyMatchFailure(
+          op, "the source and result vector layouts must be transposes of "
+              "each other");
+
+    // Workitem-level type the replacement value has to end up with.
+    FailureOr<VectorType> distributedResultTypeOrFailure =
+        getDistVecTypeBasedOnLaneLayout(resultLayout, op.getResultVectorType());
+    if (failed(distributedResultTypeOrFailure))
+      return rewriter.notifyMatchFailure(
+          op, "failed to distribute the result vector type in "
+              "vector::Transpose op");
+
+    // Re-create the transpose on the adaptor's operand, then cast the result
+    // to the distributed result type computed above.
+    auto newOp = vector::TransposeOp::create(rewriter, op.getLoc(),
+                                             adaptor.getVector(), perm);
+    rewriter.replaceOp(op, castValueTo(rewriter, newOp.getResult(),
+                                       distributedResultTypeOrFailure.value()));
+    return success();
+  }
+};
+
+/// Distributes a subgroup-level vector.bitcast op to workitem-level.
+/// Bitcast only impacts the innermost dimension of the source/result vectors.
+///
+/// Only the result's temporary layout is consulted: the original result type
+/// is distributed over that layout and a new vector.bitcast producing the
+/// distributed type is created from the adaptor's source operand.
+/// NOTE(review): the source layout is not inspected here; presumably the
+/// adaptor's source type is already bitcast-compatible with the distributed
+/// result type -- confirm against the layout propagation rules.
+struct SgToWiVectorBitcast : public OpConversionPattern<vector::BitCastOp> {
+  using OpConversionPattern<vector::BitCastOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::BitCastOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    xegpu::DistributeLayoutAttr resultLayout =
+        xegpu::getTemporaryLayout(op->getOpResult(0));
+    if (!resultLayout)
+      return rewriter.notifyMatchFailure(
+          op, "result vector of the bitcast op lacks layout attribute");
+
+    // Workitem-level result type derived from the result layout.
+    FailureOr<VectorType> distributedResultTypeOrFailure =
+        getDistVecTypeBasedOnLaneLayout(resultLayout, op.getResultVectorType());
+    if (failed(distributedResultTypeOrFailure))
+      return rewriter.notifyMatchFailure(
+          op, "failed to distribute the result vector type in "
+              "vector::BitCast op");
+
+    // Replace the op with a bitcast that yields the distributed type directly.
+    auto newOp = vector::BitCastOp::create(
+        rewriter, op.getLoc(), distributedResultTypeOrFailure.value(),
+        adaptor.getSource());
+    rewriter.replaceOp(op, newOp.getResult());
+    return success();
+  }
+};
+
+/// Distributes a subgroup-level vector.create_mask or vector.constant_mask op
+/// to workitem-level. Each lane computes its own mask bounds based on its
+/// lane coordinates. For each dimension i, the new mask bound is:
+/// new_bound[i] = original_bound[i] - lane_coord[i] * wi_elem_count[i]
+/// where `wi_elem_count[i]` is the number of elements each workitem holds
+/// along dimension i (i.e., `distType.getShape()[i]`).
+/// `vector.create_mask` implicitly clamps the bounds to
+/// `[0, wi_elem_count[i]]`, so no explicit clamping is needed.
+/// For constant_mask, the constant dim sizes are first materialized as
+/// Values, then the same logic applies, producing a vector.create_mask.
+///
+/// Example:
+/// layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>
+/// %mask = vector.create_mask %m0 : vector<16xi1>
+/// For lane k, wi_elem_count = [1], so:
+/// new_bound = m0 - k * 1
+/// Distributed to:
+/// %lane = gpu.lane_id
+/// %new_bound = affine.apply affine_map<()[s0, s1] -> (-s0 + s1)>
+/// ()[%lane, %m0]
+/// %mask = vector.create_mask %new_bound : vector<1xi1>
+template <typename OpType,
+ typename = std::enable_if_t<llvm::is_one_of<
+ OpType, vector::CreateMaskOp, vector::ConstantMaskOp>::value>>
+struct SgToWiCreateMask : public OpConversionPattern<OpType> {
+ using OpConversionPattern<OpType>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ xegpu::DistributeLayoutAttr layout =
+ xegpu::getTemporaryLayout(op->getOpResult(0));
+ if (!layout || !layout.isForSubgroup())
+ return rewriter.notifyMatchFailure(
+ op, "operation result does not have subgroup distribute layout");
+
+ VectorType origType = op.getType();
+ FailureOr<VectorType> distTypeOrFailure =
+ getDistVecTypeBasedOnLaneLayout(layout, origType);
+ if (failed(distTypeOrFailure))
+ return rewriter.notifyMatchFailure(
+ op, "unable to compute workitem vector type from the layout");
+
+ VectorType distType = distTypeOrFailure.value();
+ Location loc = op.getLoc();
+
+ // Materialize the original mask operands as Values.
+ SmallVector<Value> origOperands;
+ if constexpr (std::is_same_v<OpType, vector::CreateMaskOp>) {
+ origOperands.append(op.getOperands().begin(), op.getOperands().end());
+ } else {
+ auto dimSizes = op.getMaskDimSizesAttr().asArrayRef();
+ for (auto dimSize : dimSizes)
+ origOperands.push_back(
+ arith::ConstantIndexOp::create(rewriter, loc, dimSize).getResult());
+ }
+
+ ArrayRef<int64_t> origShape = origType.getShape();
+ ArrayRef<int64_t> distShape = distType.getShape();
+
+ // Delinearize lane ID using the layout.
+ Value laneId = gpu::LaneIdOp::create(rewriter, loc, rewriter.getIndexType(),
+ /*upperBound=*/mlir::IntegerAttr());
+ auto maybeIds = layout.delinearizeId(rewriter, loc, laneId);
+ if (failed(maybeIds))
+ return rewriter.notifyMatchFailure(
+ op, "failed to delinearize lane ID from layout");
+ SmallVector<Value> laneIds = maybeIds.value();
+
+ // Compute new mask operands.
+ AffineExpr s0, s1;
+ bindSymbols(rewriter.getContext(), s0, s1);
+ SmallVector<Value> newOperands;
+ for (int i = 0, e = distShape.size(); i < e; ++i) {
+ if (origShape[i] == distShape[i]) {
+ // Dimension is not distributed, keep the original operand.
+ newOperands.push_back(origOperands[i]);
+ } else {
+ // new_bound = original_bound - lane_coord * dist_size
----------------
nbpatel wrote:
added
https://github.com/llvm/llvm-project/pull/187392
More information about the Mlir-commits
mailing list