[Mlir-commits] [mlir] 34259b7 - [MLIR][XeGPU] Refactoring Transpose OP Layout Propagation (#184702)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Mar 5 15:03:53 PST 2026
Author: Jianhui Li
Date: 2026-03-05T15:03:49-08:00
New Revision: 34259b76bf9c728b5877d00cac4be8fedabc6fef
URL: https://github.com/llvm/llvm-project/commit/34259b76bf9c728b5877d00cac4be8fedabc6fef
DIFF: https://github.com/llvm/llvm-project/commit/34259b76bf9c728b5877d00cac4be8fedabc6fef.diff
LOG: [MLIR][XeGPU] Refactoring Transpose OP Layout Propagation (#184702)
This PR refactors Transpose Op Layout Propagation:
1. Add inferTransposeSourceLayout() to layout utility, enhance layout
propagation and conflict handling to use this function
2. Add Layout utility: TransposeDims()
3. Refactor IsTransposeOf() and fix minor bugs
4. Fix minor issue in dropSgLayoutAndData()
Added:
Modified:
mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h
mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir
mlir/test/Dialect/XeGPU/propagate-layout.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index 6f667f4801673..a98073f3c5cf2 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -254,6 +254,10 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> {
"xegpu::DistributeLayoutAttr",
"collapseDims",
(ins "SmallVector<int64_t>": $dimGroup)>,
+ InterfaceMethod<[{Derive a new layout by trasnposing it using `permutation`.}],
+ "xegpu::DistributeLayoutAttr",
+ "transposeDims",
+ (ins "ArrayRef<int64_t>": $permutation)>,
InterfaceMethod<[{Generates instructions to compute multidimensional coordinates for dist units
assigned to a level identified by linearId. The shape parameter
represents the higher-level problem size. Each level may access
@@ -261,56 +265,17 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> {
"FailureOr<SmallVector<SmallVector<Value>>>",
"computeDistributedCoords",
(ins "OpBuilder &": $builder, "Location":$loc, "Value":$linearId, "ArrayRef<int64_t>":$shape)>,
- InterfaceMethod</*desc=*/[{Check if this layout can be achieved by applying a transpose
- to some other layout according to given permutation of (0...n-1).}],
- /*retTy=*/"bool",
- /*methodName=*/"isTransposeOf",
- /*args=*/(ins "const xegpu::DistributeLayoutAttr&": $other, "ArrayRef<int64_t>": $perm),
- /*methodBody=*/[{
- if (!other)
- return false;
- if ($_self.getRank() != other.getRank() || perm.size() != static_cast<size_t>($_self.getRank()))
- return false;
- // Check if the permutation is valid
- if (!isPermutationVector(perm))
- return false;
- auto checkTranspose = [](ArrayRef<int64_t> dst, ArrayRef<int64_t> src, ArrayRef<int64_t> perm) {
- // If both `dst` and `src` are empty, conservatively return true
- // here because some layout fields can be empty.
- if (dst.empty() && src.empty())
- return true;
- for (const auto &ta : llvm::enumerate(perm)) {
- if (src[ta.index()] != dst[ta.value()])
- return false;
- }
- return true;
- };
- // Check sgLayout
- if (!checkTranspose($_self.getEffectiveSgLayoutAsInt(), other.getEffectiveSgLayoutAsInt(), perm))
- return false;
- // Check sgData
- if (!checkTranspose($_self.getEffectiveSgDataAsInt(), other.getEffectiveSgDataAsInt(), perm))
- return false;
- // Check instData
- if (!checkTranspose($_self.getEffectiveInstDataAsInt(), other.getEffectiveInstDataAsInt(), perm))
- return false;
- // Check laneLayout
- if (!checkTranspose($_self.getEffectiveLaneLayoutAsInt(), other.getEffectiveLaneLayoutAsInt(), perm))
- return false;
- // Check laneData
- if (!checkTranspose($_self.getEffectiveLaneDataAsInt(), other.getEffectiveLaneDataAsInt(), perm))
- return false;
- // Check order
- if (!checkTranspose($_self.getEffectiveOrderAsInt(), other.getEffectiveOrderAsInt(), perm))
- return false;
-
- return true;
- }]>,
InterfaceMethod</*desc=*/[{Check if this layout is a slice of another layout.}],
/*retTy=*/"bool",
/*methodName=*/"isSliceOf",
/*args=*/(ins "const xegpu::DistributeLayoutAttr&": $other)>,
-
+ InterfaceMethod</*desc=*/[{Check if this layout is a transpose of
+ the other layout according to given permutation of (0...n-1).}],
+ /*retTy=*/"bool",
+ /*methodName=*/"isTransposeOf",
+ /*args=*/(ins "const xegpu::DistributeLayoutAttr&": $other,
+ "ArrayRef<int64_t>": $perm,
+ "xegpu::LayoutKind": $kind)>,
InterfaceMethod</*desc=*/[{Check if this layout is compatible with another layout
at a specific level of the layout hierarchy. Unlike isEqualTo,
this compares only the effective (non-sliced) fields at the
@@ -498,8 +463,11 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttr]> {
// avoid every field of the attribute is nullptr, which may lead to segment fault
if (!getInstData() && !getLaneLayout())
return nullptr;
+ // Only preserve order if lane_layout remains, since order requires
+ // sg_layout or lane_layout to be present.
+ auto order = getLaneLayout() ? getOrder() : nullptr;
return LayoutAttr::get(getContext(), nullptr, nullptr, getInstData(),
- getLaneLayout(), getLaneData(), getOrder());
+ getLaneLayout(), getLaneData(), order);
}
LayoutAttr dropInstData() const{
@@ -567,6 +535,9 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttr]> {
// that are collapsed into a single dimension in the derived layout.
DistributeLayoutAttr collapseDims(SmallVector<int64_t> dimGroup);
+ // Derive a new layout by transposing the layout using `permutation`.
+ DistributeLayoutAttr transposeDims(ArrayRef<int64_t> permutation);
+
/// Delinearizes a linear ID into its multidimensional indices
/// based on the effective level of the layout.
FailureOr<SmallVector<Value>>
@@ -584,6 +555,9 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttr]> {
/// Check if this layout is equal to another layout.
bool isEqualTo(const xegpu::DistributeLayoutAttr &other);
+
+ /// Check if this layout is a transpose of another layout.
+ bool isTransposeOf(const xegpu::DistributeLayoutAttr &other, ArrayRef<int64_t> perm, const xegpu::LayoutKind kind);
}];
let assemblyFormat = "`<` struct(params) `>`";
@@ -767,6 +741,9 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> {
// that are collapsed into a single dimension in the derived layout.
DistributeLayoutAttr collapseDims(SmallVector<int64_t> dimGroup);
+ // Derive a new layout by transposing the layout using `permutation`.
+ DistributeLayoutAttr transposeDims(ArrayRef<int64_t> permutation);
+
/// flatten a nested SliceAttr, e.g., for 2-level nested SliceAttr
/// #xegpu.slice<#xegpu.slice<#xegpu.layout<sg_layout = [4, 8, 12]>, dims = [0]>, dims = [0]>
/// it will coalese two slice operations and return a simplified SliceAttr
@@ -792,6 +769,9 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> {
/// Check if this layout is equal to another layout.
bool isEqualTo(const xegpu::DistributeLayoutAttr &other);
+ /// Check if this layout is a transpose of another layout.
+ bool isTransposeOf(const xegpu::DistributeLayoutAttr &other, ArrayRef<int64_t> perm, const xegpu::LayoutKind kind);
+
/// Drop the slice dims to get the original layout.
SliceAttr dropSliceDims(ArrayRef<int64_t> sliceDimsToDrop);
}];
diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h b/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h
index 3482d1b9401bb..2ae0ef3ae852d 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h
@@ -81,6 +81,11 @@ DistributeLayoutAttr
inferMultiReductionSourceLayout(DistributeLayoutAttr resLayout,
SmallVector<int64_t> reduceDims);
+/// Infers the source layout attribute for a transpose operation given the
+/// result layout attribute and permutation.
+DistributeLayoutAttr inferTransposeSourceLayout(DistributeLayoutAttr resLayout,
+ ArrayRef<int64_t> permutation);
+
/// Infers the source layout attribute for a bitcast operation given the
/// result layout attribute, result element type bitwidth, and source element
/// type bitwidth.
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index c082600ec27d7..4d412dd92e1b0 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -618,24 +618,97 @@ DistributeLayoutAttr LayoutAttr::collapseDims(SmallVector<int64_t> dimGroup) {
SmallVector<int32_t> laneLayout32(laneLayout.begin(), laneLayout.end());
SmallVector<int32_t> laneData32(laneData.begin(), laneData.end());
+ auto toAttr = [&](ArrayRef<int32_t> v) -> DenseI32ArrayAttr {
+ return v.empty() ? nullptr : DenseI32ArrayAttr::get(getContext(), v);
+ };
+
auto collapsedLayout = xegpu::LayoutAttr::get(
- getContext(),
- sgLayout32.empty() ? DenseI32ArrayAttr()
- : DenseI32ArrayAttr::get(getContext(), sgLayout32),
- sgData32.empty() ? DenseI32ArrayAttr()
- : DenseI32ArrayAttr::get(getContext(), sgData32),
- instData32.empty() ? DenseI32ArrayAttr()
- : DenseI32ArrayAttr::get(getContext(), instData32),
- laneLayout32.empty() ? DenseI32ArrayAttr()
- : DenseI32ArrayAttr::get(getContext(), laneLayout32),
- laneData32.empty() ? DenseI32ArrayAttr()
- : DenseI32ArrayAttr::get(getContext(), laneData32),
- collapsedOrder.empty()
- ? DenseI32ArrayAttr()
- : DenseI32ArrayAttr::get(getContext(), collapsedOrder));
+ getContext(), toAttr(sgLayout32), toAttr(sgData32), toAttr(instData32),
+ toAttr(laneLayout32), toAttr(laneData32), toAttr(collapsedOrder));
return collapsedLayout;
}
+// Derive a new layout by transpose the layout using `permutation`.
+DistributeLayoutAttr LayoutAttr::transposeDims(ArrayRef<int64_t> permutation) {
+
+ SmallVector<int64_t> origSgLayout = getEffectiveSgLayoutAsInt();
+ SmallVector<int64_t> origSgData = getEffectiveSgDataAsInt();
+ SmallVector<int64_t> origInstData = getEffectiveInstDataAsInt();
+ SmallVector<int64_t> origLaneLayout = getEffectiveLaneLayoutAsInt();
+ SmallVector<int64_t> origLaneData = getEffectiveLaneDataAsInt();
+ SmallVector<int64_t> origOrder = getEffectiveOrderAsInt();
+
+ SmallVector<int32_t> sgLayout;
+ SmallVector<int32_t> sgData;
+ SmallVector<int32_t> instData;
+ SmallVector<int32_t> laneLayout;
+ SmallVector<int32_t> laneData;
+ SmallVector<int32_t> order;
+
+ for (int64_t idx : permutation) {
+ if (!origLaneLayout.empty()) {
+ laneLayout.push_back(static_cast<int32_t>(origLaneLayout[idx]));
+ laneData.push_back(static_cast<int32_t>(origLaneData[idx]));
+ }
+ if (!origInstData.empty())
+ instData.push_back(static_cast<int32_t>(origInstData[idx]));
+ if (!origSgLayout.empty()) {
+ sgLayout.push_back(static_cast<int32_t>(origSgLayout[idx]));
+ sgData.push_back(static_cast<int32_t>(origSgData[idx]));
+ }
+ order.push_back(static_cast<int32_t>(origOrder[idx]));
+ }
+ if (origLaneLayout.empty() && origSgLayout.empty())
+ order.clear();
+
+ auto toAttr = [&](ArrayRef<int32_t> v) -> DenseI32ArrayAttr {
+ return v.empty() ? nullptr : DenseI32ArrayAttr::get(getContext(), v);
+ };
+ return xegpu::LayoutAttr::get(getContext(), toAttr(sgLayout), toAttr(sgData),
+ toAttr(instData), toAttr(laneLayout),
+ toAttr(laneData), toAttr(order));
+}
+
+/// Check if this layout is a transpose of another layout.
+bool LayoutAttr::isTransposeOf(const xegpu::DistributeLayoutAttr &other,
+ ArrayRef<int64_t> perm,
+ const xegpu::LayoutKind kind) {
+ if (!other)
+ return false;
+ if (getRank() != other.getRank() ||
+ perm.size() != static_cast<size_t>(getRank()))
+ return false;
+ if (!isPermutationVector(perm))
+ return false;
+ auto checkTranspose = [](ArrayRef<int64_t> dst, ArrayRef<int64_t> src,
+ ArrayRef<int64_t> perm) {
+ for (const auto &ta : llvm::enumerate(perm)) {
+ if (src[ta.index()] != dst[ta.value()])
+ return false;
+ }
+ return true;
+ };
+ if (kind == xegpu::LayoutKind::Subgroup)
+ return checkTranspose(getEffectiveSgLayoutAsInt(),
+ other.getEffectiveSgLayoutAsInt(), perm) &&
+ checkTranspose(getEffectiveSgDataAsInt(),
+ other.getEffectiveSgDataAsInt(), perm) &&
+ checkTranspose(getEffectiveOrderAsInt(),
+ other.getEffectiveOrderAsInt(), perm);
+ if (kind == xegpu::LayoutKind::InstData)
+ return checkTranspose(getEffectiveInstDataAsInt(),
+ other.getEffectiveInstDataAsInt(), perm);
+ if (kind == xegpu::LayoutKind::Lane)
+ return checkTranspose(getEffectiveLaneLayoutAsInt(),
+ other.getEffectiveLaneLayoutAsInt(), perm) &&
+ checkTranspose(getEffectiveLaneDataAsInt(),
+ other.getEffectiveLaneDataAsInt(), perm) &&
+ checkTranspose(getEffectiveOrderAsInt(),
+ other.getEffectiveOrderAsInt(), perm);
+
+ return false;
+}
+
//===----------------------------------------------------------------------===//
// XeGPU_SliceAttr
//===----------------------------------------------------------------------===//
@@ -881,6 +954,62 @@ DistributeLayoutAttr SliceAttr::collapseDims(SmallVector<int64_t> dimGroup) {
DenseI64ArrayAttr::get(getContext(), sliceDims));
}
+SmallVector<int64_t> getPermForParentLayout(ArrayRef<int64_t> sliceDims,
+ ArrayRef<int64_t> permutation) {
+ SmallVector<int64_t> sortedSliceDims = llvm::to_vector(sliceDims);
+ llvm::sort(sortedSliceDims);
+
+ for (size_t i = 1; i < sortedSliceDims.size(); ++i) {
+ assert((sortedSliceDims[i] == sortedSliceDims[i - 1] + 1) &&
+ "slice dims non consecutive, cannot be transposed");
+ }
+
+ SmallVector<int64_t> permForParent;
+ if (sortedSliceDims.front() == 0) {
+ // Example: sliceDims.size() = 2, permutation= {1, 0}
+ // result: {3, 2, 1, 0}.
+ for (int64_t dim : permutation)
+ permForParent.push_back(dim + sortedSliceDims.size());
+ for (int64_t i = sortedSliceDims.size() - 1; i >= 0; --i)
+ permForParent.push_back(i);
+ } else {
+ // Example: sliceDims.size() = 2, permutation = {0, 1}
+ // result: {3, 2, 0, 1}.
+ for (int64_t i = sortedSliceDims.size() - 1; i >= 0; --i)
+ permForParent.push_back(i + permutation.size());
+ for (int64_t dim : permutation)
+ permForParent.push_back(dim);
+ }
+ return permForParent;
+}
+
+// Derive a new layout by transpose the layout using `permutation`.
+DistributeLayoutAttr SliceAttr::transposeDims(ArrayRef<int64_t> permutation) {
+ SmallVector<int64_t> sliceDims = llvm::to_vector(getDims().asArrayRef());
+ DistributeLayoutAttr parent = getParent();
+ SmallVector<int64_t> permForParent =
+ getPermForParentLayout(sliceDims, permutation);
+ auto transposedParent = parent.transposeDims(permForParent);
+ return SliceAttr::get(getContext(), transposedParent,
+ DenseI64ArrayAttr::get(getContext(), sliceDims));
+}
+
+/// Check if this layout is a transpose of another layout.
+bool SliceAttr::isTransposeOf(const xegpu::DistributeLayoutAttr &other,
+ ArrayRef<int64_t> perm,
+ const xegpu::LayoutKind kind) {
+ // other must be a SliceAttr with the same slice dims.
+ auto otherSlice = dyn_cast<xegpu::SliceAttr>(other);
+ if (!otherSlice || getDims() != otherSlice.getDims())
+ return false;
+ // check whether the parent layout is transpose of each other.
+ SmallVector<int64_t> sliceDims = llvm::to_vector(getDims().asArrayRef());
+ DistributeLayoutAttr parent = getParent();
+ SmallVector<int64_t> permForParent = getPermForParentLayout(sliceDims, perm);
+ auto otherParent = otherSlice.getParent();
+ return parent.isTransposeOf(otherParent, permForParent, kind);
+}
+
//===----------------------------------------------------------------------===//
// XeGPU_RangeAttr
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
index 7aa186bb22224..432886db29d23 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
@@ -178,6 +178,14 @@ xegpu::inferMultiReductionSourceLayout(xegpu::DistributeLayoutAttr resLayout,
return sliceLayout.getParent();
}
+/// Infers the source layout attribute for a transpose operation given the
+/// result layout attribute and permutation.
+xegpu::DistributeLayoutAttr
+xegpu::inferTransposeSourceLayout(xegpu::DistributeLayoutAttr resLayout,
+ ArrayRef<int64_t> permutation) {
+ return resLayout.transposeDims(permutation);
+}
+
/// Infers the source layout attribute for a bitcast operation given the
/// result layout attribute, result element type bitwidth, and source element
/// type bitwidth.
@@ -1144,6 +1152,16 @@ xegpu::DistributeLayoutAttr xegpu::getConsumerLayoutAt(OpOperand &operand) {
if (idx == 1)
return resLayout;
}
+
+ // For vector::TransposeOp, infer source layout from result layout using
+ // permutation.
+ if (auto transpose = dyn_cast<vector::TransposeOp>(op)) {
+ if (!resLayout)
+ return xegpu::DistributeLayoutAttr();
+ return xegpu::inferTransposeSourceLayout(resLayout,
+ transpose.getPermutation());
+ }
+
// For elementwise operations, all operands must have the same layout as the
// result.
if (OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() == 1) {
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 7f7e8d6ad7734..ab8f7e768ec1c 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -912,9 +912,12 @@ void LayoutInfoPropagation::visitTransposeOp(
LayoutInfo resultLayout = results[0]->getValue();
if (!resultLayout.isAssigned())
return;
- LayoutInfo newLayout = resultLayout.transpose(transpose.getPermutation());
+ auto consumerLayoutAttr =
+ dyn_cast<xegpu::DistributeLayoutAttr>(resultLayout.get());
+ auto srcLayoutAttr = xegpu::inferTransposeSourceLayout(
+ consumerLayoutAttr, transpose.getPermutation());
// Propagate the new layout to the vector operand.
- propagateIfChanged(operands[0], operands[0]->meet(newLayout));
+ propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
}
/// For vector::BitCastOp, the lane_data of the source layout is changed based
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index bf9fded8a3abe..38bc95d39c2c6 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -1963,7 +1963,8 @@ struct VectorTransposeDistribution final : public gpu::WarpDistributionPattern {
"does not have 2D layout");
ArrayRef<int64_t> perm = transposeOp.getPermutation();
// Result layout must be a transpose of source layout.
- if (!resultLayout.isTransposeOf(sourceLayout, perm))
+ if (!resultLayout.isTransposeOf(sourceLayout, perm,
+ xegpu::LayoutKind::Lane))
return rewriter.notifyMatchFailure(
transposeOp,
"the source or result vector layouts must be 2D transposes of each "
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 30e4a956a0add..139a30e76854f 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -1532,7 +1532,8 @@ struct WgToSgVectorTransposeOp
// Check that sgLayout, sgData & order are properly transposed for source
// and result
- if (!layout.isTransposeOf(sourceLayout, permutation))
+ if (!layout.isTransposeOf(sourceLayout, permutation,
+ xegpu::LayoutKind::Subgroup))
return rewriter.notifyMatchFailure(
op, "Result layout is not a valid transpose of source layout "
"according to permutation");
@@ -1540,13 +1541,13 @@ struct WgToSgVectorTransposeOp
SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
VectorType newResultType =
VectorType::get(sgShape, resultType.getElementType());
+
SmallVector<Value> newTransposeOps;
for (auto src : adaptor.getVector()) {
auto newTranspose = vector::TransposeOp::create(
rewriter, op.getLoc(), newResultType, src, permutation);
newTransposeOps.push_back(newTranspose.getResult());
}
-
rewriter.replaceOpWithMultiple(op, {newTransposeOps});
return success();
}
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir b/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir
index c073045691f56..ffbe95b2a6f84 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir
@@ -58,12 +58,12 @@ gpu.module @test {
gpu.func @vector_transpose(%src: memref<256x128xf32>, %src1: memref<128x256xf32>) kernel attributes
{known_block_size = array<i32: 1, 32, 16>} {
// CHECK: %[[TDESC_LD:.*]] = xegpu.create_nd_tdesc %[[ARG_0]] : memref<256x128xf32> ->
- // CHECK-SAME: !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>>
+ // CHECK-SAME: !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], order = [0, 1]>>
// CHECK: %[[TDESC_ST:.*]] = xegpu.create_nd_tdesc %[[ARG_1]] : memref<128x256xf32> ->
// CHECK-SAME: !xegpu.tensor_desc<128x256xf32, #xegpu.layout<sg_layout = [4, 8], sg_data = [32, 32]>>
- // CHECK: %[[LOAD:.*]] = xegpu.load_nd %[[TDESC_LD]][0, 0] <{layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>}> :
- // CHECK-SAME: !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>> -> vector<256x128xf32>
+ // CHECK: %[[LOAD:.*]] = xegpu.load_nd %[[TDESC_LD]][0, 0] <{layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], order = [0, 1]>}> :
+ // CHECK-SAME: !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], order = [0, 1]>> -> vector<256x128xf32>
// CHECK: %[[TRANSPOSED:.*]] = vector.transpose %2, [1, 0]
// CHECK-SAME {layout_result_0 = #xegpu.layout<sg_layout = [4, 8], sg_data = [32, 32]>} : vector<256x128xf32> to vector<128x256xf32>
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout.mlir b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
index 4f2349a89b1ed..3253d0004caf4 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
@@ -278,9 +278,9 @@ func.func @vector_bitcast_i16_to_f16(%arg0: memref<8x16xi16>, %arg1: memref<16x1
// -----
gpu.module @test {
// CHECK-LABEL: func.func @vector_bitcast_i32_to_f16(
-// CHECK: %[[LOAD:.*]] = xegpu.load_nd %{{.*}} <{layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}>
-// CHECK-SAME: !xegpu.tensor_desc<16x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>> -> vector<16x8xi32>
-// CHECK-NEXT: %{{.*}} = vector.bitcast %[[LOAD]] {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>}
+// CHECK: %[[LOAD:.*]] = xegpu.load_nd %{{.*}} <{layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1], order = [0, 1]>}>
+// CHECK-SAME: !xegpu.tensor_desc<16x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1], order = [0, 1]>> -> vector<16x8xi32>
+// CHECK-NEXT: %{{.*}} = vector.bitcast %[[LOAD]] {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2], order = [0, 1]>}
// CHECK-SAME: vector<16x8xi32> to vector<16x16xf16>
func.func @vector_bitcast_i32_to_f16(%arg0: memref<8x16xf16>, %arg1: memref<16x8xi32>, %arg2: memref<8x16xf32>) {
%c0 = arith.constant 0 : index
More information about the Mlir-commits
mailing list