[Mlir-commits] [mlir] 34259b7 - [MLIR][XeGPU] Refactoring Transpose OP Layout Propagation (#184702)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Mar 5 15:03:53 PST 2026


Author: Jianhui Li
Date: 2026-03-05T15:03:49-08:00
New Revision: 34259b76bf9c728b5877d00cac4be8fedabc6fef

URL: https://github.com/llvm/llvm-project/commit/34259b76bf9c728b5877d00cac4be8fedabc6fef
DIFF: https://github.com/llvm/llvm-project/commit/34259b76bf9c728b5877d00cac4be8fedabc6fef.diff

LOG: [MLIR][XeGPU] Refactoring Transpose OP Layout Propagation (#184702)

This PR refactors layout propagation for the transpose op:
1. Add inferTransposeSourceLayout() to the layout utilities, and update
layout propagation and conflict handling to use this function.
2. Add a layout utility: transposeDims().
3. Refactor isTransposeOf() and fix minor bugs.
4. Fix a minor issue in dropSgLayoutAndData().

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
    mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h
    mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
    mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
    mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
    mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
    mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
    mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir
    mlir/test/Dialect/XeGPU/propagate-layout.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index 6f667f4801673..a98073f3c5cf2 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -254,6 +254,10 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> {
                     "xegpu::DistributeLayoutAttr",
                     "collapseDims",
                     (ins "SmallVector<int64_t>": $dimGroup)>,
+    InterfaceMethod<[{Derive a new layout by trasnposing it using `permutation`.}],
+                    "xegpu::DistributeLayoutAttr",
+                    "transposeDims",
+                    (ins "ArrayRef<int64_t>": $permutation)>,
     InterfaceMethod<[{Generates instructions to compute multidimensional coordinates for dist units
                       assigned to a level identified by linearId. The shape parameter
                       represents the higher-level problem size. Each level may access
@@ -261,56 +265,17 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> {
                     "FailureOr<SmallVector<SmallVector<Value>>>",
                     "computeDistributedCoords",
                     (ins "OpBuilder &": $builder, "Location":$loc, "Value":$linearId, "ArrayRef<int64_t>":$shape)>,
-    InterfaceMethod</*desc=*/[{Check if this layout can be achieved by applying a transpose
-                     to some other layout according to given permutation of (0...n-1).}],
-                    /*retTy=*/"bool",
-                    /*methodName=*/"isTransposeOf",
-                    /*args=*/(ins "const xegpu::DistributeLayoutAttr&": $other, "ArrayRef<int64_t>": $perm),
-                    /*methodBody=*/[{
-                      if (!other)
-                        return false;
-                      if ($_self.getRank() != other.getRank() || perm.size() != static_cast<size_t>($_self.getRank()))
-                        return false;
-                      // Check if the permutation is valid
-                      if (!isPermutationVector(perm))
-                        return false;
-                      auto checkTranspose = [](ArrayRef<int64_t> dst, ArrayRef<int64_t> src, ArrayRef<int64_t> perm) {
-                        // If both `dst` and `src` are empty, conservatively return true
-                        // here because some layout fields can be empty.
-                        if (dst.empty() && src.empty())
-                          return true;
-                        for (const auto &ta : llvm::enumerate(perm)) {
-                          if (src[ta.index()] != dst[ta.value()])
-                            return false;
-                        }
-                        return true;
-                      };
-                      // Check sgLayout
-                      if (!checkTranspose($_self.getEffectiveSgLayoutAsInt(), other.getEffectiveSgLayoutAsInt(), perm))
-                        return false;
-                      // Check sgData
-                      if (!checkTranspose($_self.getEffectiveSgDataAsInt(), other.getEffectiveSgDataAsInt(), perm))
-                        return false;
-                      // Check instData
-                      if (!checkTranspose($_self.getEffectiveInstDataAsInt(), other.getEffectiveInstDataAsInt(), perm))
-                        return false;
-                      // Check laneLayout
-                      if (!checkTranspose($_self.getEffectiveLaneLayoutAsInt(), other.getEffectiveLaneLayoutAsInt(), perm))
-                        return false;
-                      // Check laneData
-                      if (!checkTranspose($_self.getEffectiveLaneDataAsInt(), other.getEffectiveLaneDataAsInt(), perm))
-                        return false;
-                      // Check order
-                      if (!checkTranspose($_self.getEffectiveOrderAsInt(), other.getEffectiveOrderAsInt(), perm))
-                        return false;
-
-                      return true;
-                    }]>,
     InterfaceMethod</*desc=*/[{Check if this layout is a slice of another layout.}],
                     /*retTy=*/"bool",
                     /*methodName=*/"isSliceOf",
                     /*args=*/(ins "const xegpu::DistributeLayoutAttr&": $other)>,
-
+    InterfaceMethod</*desc=*/[{Check if this layout is a transpose of
+                     the other layout according to given permutation of (0...n-1).}],
+                    /*retTy=*/"bool",
+                    /*methodName=*/"isTransposeOf",
+                    /*args=*/(ins "const xegpu::DistributeLayoutAttr&": $other,
+                                  "ArrayRef<int64_t>": $perm,
+                                  "xegpu::LayoutKind": $kind)>,
     InterfaceMethod</*desc=*/[{Check if this layout is compatible with another layout
                      at a specific level of the layout hierarchy. Unlike isEqualTo,
                      this compares only the effective (non-sliced) fields at the
@@ -498,8 +463,11 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttr]> {
       // avoid every field of the attribute is nullptr, which may lead to segment fault
       if (!getInstData() && !getLaneLayout())
         return nullptr;
+      // Only preserve order if lane_layout remains, since order requires
+      // sg_layout or lane_layout to be present.
+      auto order = getLaneLayout() ? getOrder() : nullptr;
       return LayoutAttr::get(getContext(), nullptr, nullptr, getInstData(),
-                             getLaneLayout(), getLaneData(), getOrder());
+                             getLaneLayout(), getLaneData(), order);
     }
 
     LayoutAttr dropInstData() const{
@@ -567,6 +535,9 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttr]> {
     // that are collapsed into a single dimension in the derived layout.
     DistributeLayoutAttr collapseDims(SmallVector<int64_t> dimGroup);
 
+    // Derive a new layout by transposing the layout using `permutation`.
+    DistributeLayoutAttr transposeDims(ArrayRef<int64_t> permutation);
+
     /// Delinearizes a linear ID into its multidimensional indices
     /// based on the effective level of the layout.
     FailureOr<SmallVector<Value>>
@@ -584,6 +555,9 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttr]> {
 
     /// Check if this layout is equal to another layout.
     bool isEqualTo(const xegpu::DistributeLayoutAttr &other);
+
+    /// Check if this layout is a transpose of another layout.
+    bool isTransposeOf(const xegpu::DistributeLayoutAttr &other, ArrayRef<int64_t> perm, const xegpu::LayoutKind kind);
   }];
 
   let assemblyFormat = "`<` struct(params) `>`";
@@ -767,6 +741,9 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> {
     // that are collapsed into a single dimension in the derived layout.
     DistributeLayoutAttr collapseDims(SmallVector<int64_t> dimGroup);
 
+    // Derive a new layout by transposing the layout using `permutation`.
+    DistributeLayoutAttr transposeDims(ArrayRef<int64_t> permutation);
+
     /// flatten a nested SliceAttr, e.g., for 2-level nested SliceAttr
     /// #xegpu.slice<#xegpu.slice<#xegpu.layout<sg_layout = [4, 8, 12]>, dims = [0]>, dims = [0]>
     /// it will coalese two slice operations and return a simplified SliceAttr
@@ -792,6 +769,9 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> {
     /// Check if this layout is equal to another layout.
     bool isEqualTo(const xegpu::DistributeLayoutAttr &other);
 
+    /// Check if this layout is a transpose of another layout.
+    bool isTransposeOf(const xegpu::DistributeLayoutAttr &other, ArrayRef<int64_t> perm, const xegpu::LayoutKind kind);
+
     /// Drop the slice dims to get the original layout.
     SliceAttr dropSliceDims(ArrayRef<int64_t> sliceDimsToDrop);
   }];

diff  --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h b/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h
index 3482d1b9401bb..2ae0ef3ae852d 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h
@@ -81,6 +81,11 @@ DistributeLayoutAttr
 inferMultiReductionSourceLayout(DistributeLayoutAttr resLayout,
                                 SmallVector<int64_t> reduceDims);
 
+/// Infers the source layout attribute for a transpose operation given the
+/// result layout attribute and permutation.
+DistributeLayoutAttr inferTransposeSourceLayout(DistributeLayoutAttr resLayout,
+                                                ArrayRef<int64_t> permutation);
+
 /// Infers the source layout attribute for a bitcast operation given the
 /// result layout attribute, result element type bitwidth, and source element
 /// type bitwidth.

diff  --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index c082600ec27d7..4d412dd92e1b0 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -618,24 +618,97 @@ DistributeLayoutAttr LayoutAttr::collapseDims(SmallVector<int64_t> dimGroup) {
   SmallVector<int32_t> laneLayout32(laneLayout.begin(), laneLayout.end());
   SmallVector<int32_t> laneData32(laneData.begin(), laneData.end());
 
+  auto toAttr = [&](ArrayRef<int32_t> v) -> DenseI32ArrayAttr {
+    return v.empty() ? nullptr : DenseI32ArrayAttr::get(getContext(), v);
+  };
+
   auto collapsedLayout = xegpu::LayoutAttr::get(
-      getContext(),
-      sgLayout32.empty() ? DenseI32ArrayAttr()
-                         : DenseI32ArrayAttr::get(getContext(), sgLayout32),
-      sgData32.empty() ? DenseI32ArrayAttr()
-                       : DenseI32ArrayAttr::get(getContext(), sgData32),
-      instData32.empty() ? DenseI32ArrayAttr()
-                         : DenseI32ArrayAttr::get(getContext(), instData32),
-      laneLayout32.empty() ? DenseI32ArrayAttr()
-                           : DenseI32ArrayAttr::get(getContext(), laneLayout32),
-      laneData32.empty() ? DenseI32ArrayAttr()
-                         : DenseI32ArrayAttr::get(getContext(), laneData32),
-      collapsedOrder.empty()
-          ? DenseI32ArrayAttr()
-          : DenseI32ArrayAttr::get(getContext(), collapsedOrder));
+      getContext(), toAttr(sgLayout32), toAttr(sgData32), toAttr(instData32),
+      toAttr(laneLayout32), toAttr(laneData32), toAttr(collapsedOrder));
   return collapsedLayout;
 }
 
+// Derive a new layout by transpose the layout using `permutation`.
+DistributeLayoutAttr LayoutAttr::transposeDims(ArrayRef<int64_t> permutation) {
+
+  SmallVector<int64_t> origSgLayout = getEffectiveSgLayoutAsInt();
+  SmallVector<int64_t> origSgData = getEffectiveSgDataAsInt();
+  SmallVector<int64_t> origInstData = getEffectiveInstDataAsInt();
+  SmallVector<int64_t> origLaneLayout = getEffectiveLaneLayoutAsInt();
+  SmallVector<int64_t> origLaneData = getEffectiveLaneDataAsInt();
+  SmallVector<int64_t> origOrder = getEffectiveOrderAsInt();
+
+  SmallVector<int32_t> sgLayout;
+  SmallVector<int32_t> sgData;
+  SmallVector<int32_t> instData;
+  SmallVector<int32_t> laneLayout;
+  SmallVector<int32_t> laneData;
+  SmallVector<int32_t> order;
+
+  for (int64_t idx : permutation) {
+    if (!origLaneLayout.empty()) {
+      laneLayout.push_back(static_cast<int32_t>(origLaneLayout[idx]));
+      laneData.push_back(static_cast<int32_t>(origLaneData[idx]));
+    }
+    if (!origInstData.empty())
+      instData.push_back(static_cast<int32_t>(origInstData[idx]));
+    if (!origSgLayout.empty()) {
+      sgLayout.push_back(static_cast<int32_t>(origSgLayout[idx]));
+      sgData.push_back(static_cast<int32_t>(origSgData[idx]));
+    }
+    order.push_back(static_cast<int32_t>(origOrder[idx]));
+  }
+  if (origLaneLayout.empty() && origSgLayout.empty())
+    order.clear();
+
+  auto toAttr = [&](ArrayRef<int32_t> v) -> DenseI32ArrayAttr {
+    return v.empty() ? nullptr : DenseI32ArrayAttr::get(getContext(), v);
+  };
+  return xegpu::LayoutAttr::get(getContext(), toAttr(sgLayout), toAttr(sgData),
+                                toAttr(instData), toAttr(laneLayout),
+                                toAttr(laneData), toAttr(order));
+}
+
+/// Check if this layout is a transpose of another layout.
+bool LayoutAttr::isTransposeOf(const xegpu::DistributeLayoutAttr &other,
+                               ArrayRef<int64_t> perm,
+                               const xegpu::LayoutKind kind) {
+  if (!other)
+    return false;
+  if (getRank() != other.getRank() ||
+      perm.size() != static_cast<size_t>(getRank()))
+    return false;
+  if (!isPermutationVector(perm))
+    return false;
+  auto checkTranspose = [](ArrayRef<int64_t> dst, ArrayRef<int64_t> src,
+                           ArrayRef<int64_t> perm) {
+    for (const auto &ta : llvm::enumerate(perm)) {
+      if (src[ta.index()] != dst[ta.value()])
+        return false;
+    }
+    return true;
+  };
+  if (kind == xegpu::LayoutKind::Subgroup)
+    return checkTranspose(getEffectiveSgLayoutAsInt(),
+                          other.getEffectiveSgLayoutAsInt(), perm) &&
+           checkTranspose(getEffectiveSgDataAsInt(),
+                          other.getEffectiveSgDataAsInt(), perm) &&
+           checkTranspose(getEffectiveOrderAsInt(),
+                          other.getEffectiveOrderAsInt(), perm);
+  if (kind == xegpu::LayoutKind::InstData)
+    return checkTranspose(getEffectiveInstDataAsInt(),
+                          other.getEffectiveInstDataAsInt(), perm);
+  if (kind == xegpu::LayoutKind::Lane)
+    return checkTranspose(getEffectiveLaneLayoutAsInt(),
+                          other.getEffectiveLaneLayoutAsInt(), perm) &&
+           checkTranspose(getEffectiveLaneDataAsInt(),
+                          other.getEffectiveLaneDataAsInt(), perm) &&
+           checkTranspose(getEffectiveOrderAsInt(),
+                          other.getEffectiveOrderAsInt(), perm);
+
+  return false;
+}
+
 //===----------------------------------------------------------------------===//
 // XeGPU_SliceAttr
 //===----------------------------------------------------------------------===//
@@ -881,6 +954,62 @@ DistributeLayoutAttr SliceAttr::collapseDims(SmallVector<int64_t> dimGroup) {
                         DenseI64ArrayAttr::get(getContext(), sliceDims));
 }
 
+SmallVector<int64_t> getPermForParentLayout(ArrayRef<int64_t> sliceDims,
+                                            ArrayRef<int64_t> permutation) {
+  SmallVector<int64_t> sortedSliceDims = llvm::to_vector(sliceDims);
+  llvm::sort(sortedSliceDims);
+
+  for (size_t i = 1; i < sortedSliceDims.size(); ++i) {
+    assert((sortedSliceDims[i] == sortedSliceDims[i - 1] + 1) &&
+           "slice dims non consecutive, cannot be transposed");
+  }
+
+  SmallVector<int64_t> permForParent;
+  if (sortedSliceDims.front() == 0) {
+    // Example: sliceDims.size() = 2, permutation= {1, 0}
+    // result: {3, 2, 1, 0}.
+    for (int64_t dim : permutation)
+      permForParent.push_back(dim + sortedSliceDims.size());
+    for (int64_t i = sortedSliceDims.size() - 1; i >= 0; --i)
+      permForParent.push_back(i);
+  } else {
+    // Example: sliceDims.size() = 2, permutation = {0, 1}
+    // result: {3, 2, 0, 1}.
+    for (int64_t i = sortedSliceDims.size() - 1; i >= 0; --i)
+      permForParent.push_back(i + permutation.size());
+    for (int64_t dim : permutation)
+      permForParent.push_back(dim);
+  }
+  return permForParent;
+}
+
+// Derive a new layout by transpose the layout using `permutation`.
+DistributeLayoutAttr SliceAttr::transposeDims(ArrayRef<int64_t> permutation) {
+  SmallVector<int64_t> sliceDims = llvm::to_vector(getDims().asArrayRef());
+  DistributeLayoutAttr parent = getParent();
+  SmallVector<int64_t> permForParent =
+      getPermForParentLayout(sliceDims, permutation);
+  auto transposedParent = parent.transposeDims(permForParent);
+  return SliceAttr::get(getContext(), transposedParent,
+                        DenseI64ArrayAttr::get(getContext(), sliceDims));
+}
+
+/// Check if this layout is a transpose of another layout.
+bool SliceAttr::isTransposeOf(const xegpu::DistributeLayoutAttr &other,
+                              ArrayRef<int64_t> perm,
+                              const xegpu::LayoutKind kind) {
+  // other must be a SliceAttr with the same slice dims.
+  auto otherSlice = dyn_cast<xegpu::SliceAttr>(other);
+  if (!otherSlice || getDims() != otherSlice.getDims())
+    return false;
+  // check whether the parent layout is transpose of each other.
+  SmallVector<int64_t> sliceDims = llvm::to_vector(getDims().asArrayRef());
+  DistributeLayoutAttr parent = getParent();
+  SmallVector<int64_t> permForParent = getPermForParentLayout(sliceDims, perm);
+  auto otherParent = otherSlice.getParent();
+  return parent.isTransposeOf(otherParent, permForParent, kind);
+}
+
 //===----------------------------------------------------------------------===//
 // XeGPU_RangeAttr
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
index 7aa186bb22224..432886db29d23 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
@@ -178,6 +178,14 @@ xegpu::inferMultiReductionSourceLayout(xegpu::DistributeLayoutAttr resLayout,
   return sliceLayout.getParent();
 }
 
+/// Infers the source layout attribute for a transpose operation given the
+/// result layout attribute and permutation.
+xegpu::DistributeLayoutAttr
+xegpu::inferTransposeSourceLayout(xegpu::DistributeLayoutAttr resLayout,
+                                  ArrayRef<int64_t> permutation) {
+  return resLayout.transposeDims(permutation);
+}
+
 /// Infers the source layout attribute for a bitcast operation given the
 /// result layout attribute, result element type bitwidth, and source element
 /// type bitwidth.
@@ -1144,6 +1152,16 @@ xegpu::DistributeLayoutAttr xegpu::getConsumerLayoutAt(OpOperand &operand) {
     if (idx == 1)
       return resLayout;
   }
+
+  // For vector::TransposeOp, infer source layout from result layout using
+  // permutation.
+  if (auto transpose = dyn_cast<vector::TransposeOp>(op)) {
+    if (!resLayout)
+      return xegpu::DistributeLayoutAttr();
+    return xegpu::inferTransposeSourceLayout(resLayout,
+                                             transpose.getPermutation());
+  }
+
   // For elementwise operations, all operands must have the same layout as the
   // result.
   if (OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() == 1) {

diff  --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 7f7e8d6ad7734..ab8f7e768ec1c 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -912,9 +912,12 @@ void LayoutInfoPropagation::visitTransposeOp(
   LayoutInfo resultLayout = results[0]->getValue();
   if (!resultLayout.isAssigned())
     return;
-  LayoutInfo newLayout = resultLayout.transpose(transpose.getPermutation());
+  auto consumerLayoutAttr =
+      dyn_cast<xegpu::DistributeLayoutAttr>(resultLayout.get());
+  auto srcLayoutAttr = xegpu::inferTransposeSourceLayout(
+      consumerLayoutAttr, transpose.getPermutation());
   // Propagate the new layout to the vector operand.
-  propagateIfChanged(operands[0], operands[0]->meet(newLayout));
+  propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
 }
 
 /// For vector::BitCastOp, the lane_data of the source layout is changed based

diff  --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index bf9fded8a3abe..38bc95d39c2c6 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -1963,7 +1963,8 @@ struct VectorTransposeDistribution final : public gpu::WarpDistributionPattern {
                        "does not have 2D layout");
     ArrayRef<int64_t> perm = transposeOp.getPermutation();
     // Result layout must be a transpose of source layout.
-    if (!resultLayout.isTransposeOf(sourceLayout, perm))
+    if (!resultLayout.isTransposeOf(sourceLayout, perm,
+                                    xegpu::LayoutKind::Lane))
       return rewriter.notifyMatchFailure(
           transposeOp,
           "the source or result vector layouts must be 2D transposes of each "

diff  --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 30e4a956a0add..139a30e76854f 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -1532,7 +1532,8 @@ struct WgToSgVectorTransposeOp
 
     // Check that sgLayout, sgData & order are properly transposed for source
     // and result
-    if (!layout.isTransposeOf(sourceLayout, permutation))
+    if (!layout.isTransposeOf(sourceLayout, permutation,
+                              xegpu::LayoutKind::Subgroup))
       return rewriter.notifyMatchFailure(
           op, "Result layout is not a valid transpose of source layout "
               "according to permutation");
@@ -1540,13 +1541,13 @@ struct WgToSgVectorTransposeOp
     SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
     VectorType newResultType =
         VectorType::get(sgShape, resultType.getElementType());
+
     SmallVector<Value> newTransposeOps;
     for (auto src : adaptor.getVector()) {
       auto newTranspose = vector::TransposeOp::create(
           rewriter, op.getLoc(), newResultType, src, permutation);
       newTransposeOps.push_back(newTranspose.getResult());
     }
-
     rewriter.replaceOpWithMultiple(op, {newTransposeOps});
     return success();
   }

diff  --git a/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir b/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir
index c073045691f56..ffbe95b2a6f84 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir
@@ -58,12 +58,12 @@ gpu.module @test {
   gpu.func @vector_transpose(%src: memref<256x128xf32>, %src1: memref<128x256xf32>) kernel attributes
       {known_block_size = array<i32: 1, 32, 16>} {
     // CHECK: %[[TDESC_LD:.*]] = xegpu.create_nd_tdesc %[[ARG_0]] : memref<256x128xf32> ->
-    // CHECK-SAME: !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>>
+    // CHECK-SAME: !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], order = [0, 1]>>
     // CHECK: %[[TDESC_ST:.*]] = xegpu.create_nd_tdesc %[[ARG_1]] : memref<128x256xf32> ->
     // CHECK-SAME: !xegpu.tensor_desc<128x256xf32, #xegpu.layout<sg_layout = [4, 8], sg_data = [32, 32]>>
 
-    // CHECK: %[[LOAD:.*]] = xegpu.load_nd %[[TDESC_LD]][0, 0] <{layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>}> :
-    // CHECK-SAME: !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>> -> vector<256x128xf32>
+    // CHECK: %[[LOAD:.*]] = xegpu.load_nd %[[TDESC_LD]][0, 0] <{layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], order = [0, 1]>}> :
+    // CHECK-SAME: !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], order = [0, 1]>> -> vector<256x128xf32>
 
     // CHECK: %[[TRANSPOSED:.*]] = vector.transpose %2, [1, 0]
     // CHECK-SAME {layout_result_0 = #xegpu.layout<sg_layout = [4, 8], sg_data = [32, 32]>} : vector<256x128xf32> to vector<128x256xf32>

diff  --git a/mlir/test/Dialect/XeGPU/propagate-layout.mlir b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
index 4f2349a89b1ed..3253d0004caf4 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
@@ -278,9 +278,9 @@ func.func @vector_bitcast_i16_to_f16(%arg0: memref<8x16xi16>, %arg1: memref<16x1
 // -----
 gpu.module @test {
 // CHECK-LABEL: func.func @vector_bitcast_i32_to_f16(
-// CHECK:      %[[LOAD:.*]] = xegpu.load_nd %{{.*}} <{layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}>
-// CHECK-SAME:     !xegpu.tensor_desc<16x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>> -> vector<16x8xi32>
-// CHECK-NEXT: %{{.*}} = vector.bitcast %[[LOAD]] {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>}
+// CHECK:      %[[LOAD:.*]] = xegpu.load_nd %{{.*}} <{layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1], order = [0, 1]>}>
+// CHECK-SAME:     !xegpu.tensor_desc<16x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1], order = [0, 1]>> -> vector<16x8xi32>
+// CHECK-NEXT: %{{.*}} = vector.bitcast %[[LOAD]] {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2], order = [0, 1]>}
 // CHECK-SAME:     vector<16x8xi32> to vector<16x16xf16>
 func.func @vector_bitcast_i32_to_f16(%arg0: memref<8x16xf16>, %arg1: memref<16x8xi32>, %arg2: memref<8x16xf32>) {
   %c0 = arith.constant 0 : index


        


More information about the Mlir-commits mailing list