[Mlir-commits] [mlir] [MLIR][XeGPU] Support order attribute and add pattern for vector.transpose in WgToSg Pass (PR #165307)
Nishant Patel
llvmlistbot at llvm.org
Wed Oct 29 14:46:21 PDT 2025
================
@@ -1217,6 +1217,93 @@ struct WgToSgMultiDimReductionOp
}
};
+// This pattern transforms vector.transpose ops to work at subgroup level.
+struct WgToSgVectorTransposeOp
+ : public OpConversionPattern<vector::TransposeOp> {
+ using OpConversionPattern<vector::TransposeOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(vector::TransposeOp op, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ VectorType resultType = dyn_cast<VectorType>(op.getResult().getType());
+ if (!resultType)
+ return failure();
+
+ ArrayRef<int64_t> wgShape = resultType.getShape();
+ xegpu::DistributeLayoutAttr layout =
+ xegpu::getDistributeLayoutAttr(op.getResult());
+ if (!layout || !layout.isForWorkgroup())
+ return failure();
+
+ xegpu::DistributeLayoutAttr sourceLayout =
+ xegpu::getDistributeLayoutAttr(op.getVector());
+ if (!sourceLayout || !sourceLayout.isForWorkgroup())
+ return failure();
+
+ SmallVector<int64_t> sourceSgLayout =
+ sourceLayout.getEffectiveSgLayoutAsInt();
+ SmallVector<int64_t> sourceSgData = sourceLayout.getEffectiveSgDataAsInt();
+ SmallVector<int64_t> resultSgLayout = layout.getEffectiveSgLayoutAsInt();
+ SmallVector<int64_t> resultSgData = layout.getEffectiveSgDataAsInt();
+ DenseI32ArrayAttr sourceOrder = sourceLayout.getOrder();
+ DenseI32ArrayAttr resultOrder = layout.getOrder();
+
+ if (!sourceOrder || !resultOrder) {
+ return rewriter.notifyMatchFailure(
+ op, "Both source and result must have order attributes");
+ }
+
+ SmallVector<int64_t> sourceOrderVec = llvm::to_vector(
+ llvm::map_range(sourceOrder.asArrayRef(),
+ [](int32_t idx) { return static_cast<int64_t>(idx); }));
+ SmallVector<int64_t> resultOrderVec = llvm::to_vector(
+ llvm::map_range(resultOrder.asArrayRef(),
+ [](int32_t idx) { return static_cast<int64_t>(idx); }));
+
+ ArrayRef<int64_t> permutation = op.getPermutation();
+ size_t expectedSize = permutation.size();
+ if (sourceSgLayout.size() != expectedSize ||
+ sourceSgData.size() != expectedSize ||
+ resultSgLayout.size() != expectedSize ||
+ resultSgData.size() != expectedSize ||
+ sourceOrderVec.size() != expectedSize ||
+ resultOrderVec.size() != expectedSize) {
+ return rewriter.notifyMatchFailure(
+ op, "All layouts and permutation must have the same rank");
+ }
+
+ // Check that sgLayout, sgData & order are properly transposed for operand
+ // and result
+ for (size_t i = 0; i < permutation.size(); ++i) {
+ int64_t srcDim = permutation[i];
+ if (resultSgLayout[i] != sourceSgLayout[srcDim] ||
+ resultSgData[i] != sourceSgData[srcDim] ||
+ resultOrderVec[i] != sourceOrderVec[srcDim]) {
+ return rewriter.notifyMatchFailure(
+ op, "Result layout is not a valid transpose of source layout "
+ "according to permutation");
+ }
+ }
----------------
nbpatel wrote:
Thanks for pointing this out.
https://github.com/llvm/llvm-project/pull/165307
More information about the Mlir-commits mailing list