[Mlir-commits] [mlir] [MLIR][XeGPU] Refactor isEvenlyDistributable() to Layout attribute interface (PR #191945)

Jianhui Li llvmlistbot at llvm.org
Mon Apr 13 22:07:18 PDT 2026


https://github.com/Jianhui-Li created https://github.com/llvm/llvm-project/pull/191945

This PR refactors isEvenlyDistributable() into the layout attribute interface method isDistributable(), and uses it in all anchor operations to check that the shape can be distributed with the anchor layout.

>From 2bb5d45698c9dc18a5f150a79e51a4b1258cae6c Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Tue, 14 Apr 2026 02:53:36 +0000
Subject: [PATCH 1/2] refactor isEvenlyDistributable

---
 .../mlir/Dialect/XeGPU/IR/XeGPUAttrs.td       | 60 +++++++++++--
 .../mlir/Dialect/XeGPU/IR/XeGPUDialect.td     |  4 -
 mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp    | 88 +------------------
 mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp        | 40 ++++++++-
 .../Transforms/XeGPUWgToSgDistribute.cpp      |  2 +-
 mlir/test/Dialect/XeGPU/invalid.mlir          |  8 --
 .../XeGPU/propagate-layout-subgroup.mlir      | 16 ++--
 mlir/test/Dialect/XeGPU/transform-ops.mlir    | 22 ++---
 8 files changed, 114 insertions(+), 126 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index f8a2beabb9b95..a7bea9881602f 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -188,6 +188,9 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> {
     InterfaceMethod<"Check the availability of subgroup level layouts",
                     "bool",
                     "isForSubgroup">,
+    InterfaceMethod<"Check the availability of lane level layouts",
+                    "bool",
+                    "isForLane">,
     InterfaceMethod<"Get the rank of attribute",
                     "int64_t",
                     "getRank">,
@@ -299,26 +302,60 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> {
                       } else {
                         return failure();
                       }
-                      assert(
-                          !subShape.empty() &&
-                          "sgdata or lanedata cannot be empty for distributed shape computation");
+                      // sgdata or lanedata cannot be empty for distributed shape computation
+                      if (subShape.empty())
+                        return failure();
                       SmallVector<int64_t> distributedShape(shape.size());
                       for (auto [i, dim] : llvm::enumerate(shape)) {
                         int64_t distriUnit = layout[i]*subShape[i];
                         if ((dim % distriUnit) == 0) {
                           // Evenly divisible case, divide the dimension by the layout factor.
                           distributedShape[i] = dim / layout[i];
-                          assert((distributedShape[i] % subShape[i] == 0) &&
-                                "Even distribution: sgdata or lanedata must divide the distributed dimension");
+                          if (distributedShape[i] % subShape[i] != 0)
+                            return failure();
                         } else {
                           // wrap around case, the dimension size must be equal to subShape value
-                          assert(dim == subShape[i] &&
-                                "Wrap-around distribution: sgdata or lanedata must be same as tensor tile shape");
+                          if(dim != subShape[i])
+                            return failure();
                           distributedShape[i] = dim;
                         }
                       }
                       return distributedShape;
                     }]>,
+    InterfaceMethod<[{Checks if the given shape can be distributed by the layout}],
+                    /*retTy=*/"bool",
+                    /*methodName=*/"isDistributable",
+                    /*args=*/(ins "SmallVector<int64_t>":$shape),
+                    /*methodBody=*/[{
+                        DistributeLayoutAttr curLayoutAttr = $_self;
+                        SmallVector<int64_t> curShape = shape;
+                        // Phase 1: Distribute across subgroups (sg_layout + sg_data).
+                        if (curLayoutAttr.isForWorkgroup()) {
+                          auto maybeSgShape = curLayoutAttr.computeDistributedShape(curShape);
+                          if (failed(maybeSgShape))
+                            return false;
+                          curShape = maybeSgShape.value();
+                          curLayoutAttr = curLayoutAttr.dropSgLayoutAndData();
+                          if (!curLayoutAttr)
+                            return true;
+                        }
+                        // Phase 2: Distribute across instruction data (inst_data).
+                        if (curLayoutAttr.isForSubgroup() && !curLayoutAttr.isForLane()) {
+                          SmallVector<int64_t> instData = curLayoutAttr.getEffectiveInstDataAsInt();
+                          for (size_t i = 0; i < curShape.size(); ++i) {
+                            if (curShape[i] % instData[i] != 0)
+                              return false;
+                          }
+                          // inst_data becomes the new shape for next phase
+                          curShape = instData;
+                          curLayoutAttr = curLayoutAttr.dropInstData();
+                          if (!curLayoutAttr)
+                            return true;
+                        }
+                        // Phase 3: Distribute across lanes (lane_layout + lane_data).
+                        auto maybeLaneShape = curLayoutAttr.computeDistributedShape(curShape);
+                        return succeeded(maybeLaneShape);
+                      }]>,
     InterfaceMethod</*desc=*/[{Check if this layout is a slice of another layout.}],
                     /*retTy=*/"bool",
                     /*methodName=*/"isSliceOf",
@@ -487,6 +524,10 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttr]> {
       return !isForWorkgroup();
     }
 
+    bool isForLane() {
+      return !isForWorkgroup() && (getInstData() == nullptr);
+    }
+
     int64_t getRank() const {
       if (auto attr = getSgLayout())
         return attr.size();
@@ -687,6 +728,11 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> {
       return parent.isForSubgroup();
     }
 
+    bool isForLane() const {
+      auto parent = dyn_cast<LayoutAttr>(getParent());
+      return parent.isForLane();
+    }
+
     /// Returns the SgLayout of the attribute, computed by applying
     /// the slice dimensions to the underlying LayoutAttr.
     SmallVector<int64_t> getEffectiveSgLayoutAsInt() const {
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUDialect.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUDialect.td
index c173b93face98..84fd8f9e0060c 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUDialect.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUDialect.td
@@ -38,10 +38,6 @@ def XeGPU_Dialect : Dialect {
     let useDefaultAttributePrinterParser = true;
 
     let extraClassDeclaration = [{
-      /// Checks if the given shape can be evenly distributed based on the layout
-      /// and data factors provided by the LayoutAttr.
-      static bool isEvenlyDistributable(llvm::ArrayRef<int64_t> shape, xegpu::DistributeLayoutAttr attr);
-
       /// drops/slices the shape in the specified dims, and return the rest. e.g.,
       /// for shape = [32, 64, 8], dims = [0, 2], it will return [64]
       template<typename T, typename U>
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index eaa43c02946d8..80a3fc91f1c4f 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -121,74 +121,6 @@ static SmallVector<SmallVector<int64_t>> genStaticCoordinates(
   return coordinates;
 }
 
-// Checks if the given shape can be evenly distributed based on the layout
-// and data factors provided by the LayoutAttr.
-bool XeGPUDialect::isEvenlyDistributable(llvm::ArrayRef<int64_t> shape,
-                                         xegpu::DistributeLayoutAttr attr) {
-  assert(attr && "Layout attribute is missing.");
-
-  // Checks whether the given shape can be evenly distributed using the
-  // specified layout and data attributes. If successful, it returns the work
-  // size for each compute unit; otherwise, it returns `std::nullopt`. The work
-  // size per compute unit is calculated as follows:
-  //   - If `data` is null: newShape[i] = shape[i] / layout[i]
-  //   - If `data` is not null: newShape[i] = data[i]
-  // When round-robin distribution (`rr`) is enabled, `shape[i]` can be
-  // smaller than `layout[i] * data[i]`, allowing multiple compute units to
-  // share the data.
-  auto tryDistribute = [&](llvm::ArrayRef<int64_t> shape,
-                           SmallVector<int64_t> layout,
-                           SmallVector<int64_t> data,
-                           bool rr = true) -> optional<SmallVector<int64_t>> {
-    llvm::SmallVector<int64_t> newShape(shape);
-    if (layout.size()) {
-      if (layout.size() != shape.size())
-        return std::nullopt;
-      auto ratio = computeShapeRatio(shape, layout);
-      if (ratio.has_value()) {
-        newShape = ratio.value();
-      } else if (!rr || !computeShapeRatio(layout, shape).has_value()) {
-        return std::nullopt;
-      }
-      // Round-robin case: continue with original newShape
-    }
-
-    if (data.size()) {
-      if (data.size() != shape.size())
-        return std::nullopt;
-      auto ratio = computeShapeRatio(newShape, data);
-      if (!ratio.has_value() && rr)
-        ratio = computeShapeRatio(data, newShape);
-      if (!ratio.has_value())
-        return std::nullopt;
-
-      // if data is not null, we always return it for next phase.
-      newShape = data;
-    }
-    return newShape;
-  };
-
-  // check the sgLayout and sgData
-  auto maybeSgShape = tryDistribute(shape, attr.getEffectiveSgLayoutAsInt(),
-                                    attr.getEffectiveSgDataAsInt());
-  if (!maybeSgShape)
-    return false;
-  auto sgShape = maybeSgShape.value();
-
-  // check InstData, it neither have layout nor need round-robin
-  auto maybeInstShape =
-      tryDistribute(sgShape, {}, attr.getEffectiveInstDataAsInt(), false);
-  if (!maybeInstShape)
-    return false;
-  auto instShape = maybeInstShape.value();
-
-  // check LaneLayout and LaneData
-  auto maybeLaneShape =
-      tryDistribute(instShape, attr.getEffectiveLaneLayoutAsInt(),
-                    attr.getEffectiveLaneDataAsInt());
-  return maybeLaneShape.has_value();
-}
-
 //===----------------------------------------------------------------------===//
 // XeGPU_BlockTensorDescAttr
 //===----------------------------------------------------------------------===//
@@ -1431,25 +1363,12 @@ TensorDescType::verify(llvm::function_ref<InFlightDiagnostic()> emitError,
                            << chunkAlignmentFactor;
     }
   }
-
-  auto layoutAttr = llvm::dyn_cast_if_present<LayoutAttr>(layout);
-  if (layoutAttr) {
+  if (auto layoutAttr =
+          mlir::dyn_cast_if_present<DistributeLayoutAttr>(layout)) {
     if (rank != (size_t)layoutAttr.getRank())
       return emitError() << "expected layout rank to match tensor rank";
 
-    auto laneData = layoutAttr.getLaneData();
-    if (scatterAttr && laneData) {
-      // Validate subgroup mapping rules for scattered tensors.
-      // if chunkSize > 1, the last dimension of the tensor should
-      // be distributed in the units divisible by chunkAlignmentFactor.
-      int64_t chunkSize = scatterAttr.getChunkSizeAsInt();
-      if (chunkSize > 1 && laneData[rank - 1] % chunkAlignmentFactor)
-        return emitError()
-               << "expected last dim of lane_data to be a multiple of: "
-               << chunkAlignmentFactor;
-    }
-
-    if (!XeGPUDialect::isEvenlyDistributable(shape, layoutAttr)) {
+    if (!layoutAttr.isDistributable(SmallVector<int64_t>(shape))) {
       std::string shapeStr;
       llvm::raw_string_ostream stream(shapeStr);
       llvm::interleaveComma(shape, stream);
@@ -1457,6 +1376,7 @@ TensorDescType::verify(llvm::function_ref<InFlightDiagnostic()> emitError,
                          << layoutAttr;
     }
   }
+
   return success();
 }
 
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 5697097a4c999..dae041544753c 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -504,6 +504,14 @@ LogicalResult PrefetchNdOp::verify() {
     return emitOpError(
         "Mismatched ranks between offsets and tensor descriptor");
 
+  if (getAnchorLayout()) {
+    auto layout = getAnchorLayout();
+    auto tdescShape = getShapeOf(tdescTy);
+    if (!layout.isDistributable(tdescShape))
+      return emitOpError(
+          "TensorDesc shape is not distributable with the layout");
+  }
+
   return success();
 }
 
@@ -628,6 +636,13 @@ LogicalResult LoadNdOp::verify() {
     return emitOpError(
         "Mismatched ranks between offsets and tensor descriptor");
 
+  if (getAnchorLayout()) {
+    auto layout = getAnchorLayout();
+    if (!layout.isDistributable(tdescShape))
+      return emitOpError(
+          "TensorDesc shape is not distributable with the layout");
+  }
+
   return success();
 }
 
@@ -721,6 +736,13 @@ LogicalResult StoreNdOp::verify() {
     return emitOpError(
         "Mismatched ranks between offsets and tensor descriptor");
 
+  if (getAnchorLayout()) {
+    auto layout = getAnchorLayout();
+    if (!layout.isDistributable(tdescShape))
+      return emitOpError(
+          "TensorDesc shape is not distributable with the layout");
+  }
+
   return success();
 }
 
@@ -823,6 +845,18 @@ LogicalResult PrefetchOp::verify() {
   if (getOffsetAlignByteAttr() && !srcTy.isInteger())
     return emitOpError("offset_align_byte only allowed with integer source.");
 
+  if (getAnchorLayout()) {
+    auto layout = getAnchorLayout();
+    // get the offset operand and its shape
+    auto offsets = getOffsets();
+    auto offsetsTy = offsets.getType();
+    if (!llvm::isa<VectorType>(offsetsTy))
+      return emitOpError("Offsets should be a vector.");
+    auto offsetShape = getShapeOf(offsetsTy);
+    if (!layout.isDistributable(offsetShape))
+      return emitOpError("offset shape is not distributable with the layout");
+  }
+
   return success();
 }
 
@@ -1103,12 +1137,12 @@ LogicalResult ConvertLayoutOp::verify() {
 
   Type srcType = getSource().getType();
   if (llvm::isa<VectorType>(srcType)) {
-    auto shape = llvm::cast<VectorType>(srcType).getShape();
-    if (!XeGPUDialect::isEvenlyDistributable(shape, srcLayout))
+    SmallVector<int64_t> shape(llvm::cast<VectorType>(srcType).getShape());
+    if (!srcLayout.isDistributable(shape))
       return emitOpError(
           "invalid input layout, data cannot be evenly distributed.");
 
-    if (!XeGPUDialect::isEvenlyDistributable(shape, resLayout))
+    if (!resLayout.isDistributable(shape))
       return emitOpError(
           "invalid target layout, data cannot be evenly distributed.");
   }
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index a095c19d66c15..d637b6828deab 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -498,7 +498,7 @@ struct WgToSgVectorBroadcastOp
     VectorType newResultType =
         VectorType::get(sgShape, resultType.getElementType());
 
-    if (!xegpu::XeGPUDialect::isEvenlyDistributable(wgShape, layout))
+    if (!layout.isDistributable(SmallVector<int64_t>(wgShape)))
       return failure();
 
     SmallVector<Value> newBroadcastOps;
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index 7390b47b3f8d9..82c7879c79d56 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -325,14 +325,6 @@ func.func @create_tdesc_layout_1(%src: ui64) {
   return
 }
 
-// -----
-func.func @create_tdesc_layout_2(%src: ui64) {
-  %cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
-  // expected-error at +1 {{expected last dim of lane_data to be a multiple of: 2}}
-  %1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x4xf16, #xegpu.scatter_tdesc_attr<chunk_size = 4>, #xegpu.layout<lane_layout = [4, 1], lane_data = [1, 1]>>
-  return
-}
-
 // -----
 func.func @load_gather_simt_1(%src: ui64) {
   %0 = arith.constant dense<1>: vector<4xi1>
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir b/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir
index 831d1e05967f8..62426d619445b 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir
@@ -27,24 +27,24 @@ gpu.module @test {
   // CHECK-SAME: %[[ARG_1:.*]]: memref<128x256xf32>
   func.func @vector_transpose(%src: memref<256x128xf32>, %src1: memref<128x256xf32>) {
     // CHECK: %[[TDESC_LD:.*]] = xegpu.create_nd_tdesc %[[ARG_0]] : memref<256x128xf32> ->
-    // CHECK-SAME: !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [4, 8], sg_data = [64, 32], order = [0, 1]>>
+    // CHECK-SAME: !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [4, 8], sg_data = [64, 16], order = [0, 1]>>
     // CHECK: %[[TDESC_ST:.*]] = xegpu.create_nd_tdesc %[[ARG_1]] : memref<128x256xf32> ->
-    // CHECK-SAME: !xegpu.tensor_desc<128x256xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], order = [1, 0]>>
+    // CHECK-SAME: !xegpu.tensor_desc<128x256xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 64], order = [1, 0]>>
 
-    // CHECK: %[[LOAD:.*]] = xegpu.load_nd %[[TDESC_LD]][0, 0] <{layout = #xegpu.layout<sg_layout = [4, 8], sg_data = [64, 32], order = [0, 1]>}>
-    // CHECK-SAME: !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [4, 8], sg_data = [64, 32], order = [0, 1]>> -> vector<256x128xf32>
+    // CHECK: %[[LOAD:.*]] = xegpu.load_nd %[[TDESC_LD]][0, 0] <{layout = #xegpu.layout<sg_layout = [4, 8], sg_data = [64, 16], order = [0, 1]>}>
+    // CHECK-SAME: !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [4, 8], sg_data = [64, 16], order = [0, 1]>> -> vector<256x128xf32>
 
     // CHECK: %[[TRANSPOSED:.*]] = vector.transpose %2, [1, 0]
-    // CHECK-SAME {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], order = [1, 0]>} : vector<256x128xf32> to vector<128x256xf32>
+    // CHECK-SAME {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 64], order = [1, 0]>} : vector<256x128xf32> to vector<128x256xf32>
 
     // CHECK: xegpu.store_nd %[[TRANSPOSED]], %[[TDESC_ST]][0, 0]
-    // CHECK-SAME: <{layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], order = [1, 0]>}> : vector<128x256xf32>,
-    // CHECK-SAME: !xegpu.tensor_desc<128x256xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], order = [1, 0]>>
+    // CHECK-SAME: <{layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 64], order = [1, 0]>}> : vector<128x256xf32>,
+    // CHECK-SAME: !xegpu.tensor_desc<128x256xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 64], order = [1, 0]>>
     %tdesc = xegpu.create_nd_tdesc %src : memref<256x128xf32> -> !xegpu.tensor_desc<256x128xf32>
     %tdesc1 = xegpu.create_nd_tdesc %src1 : memref<128x256xf32> -> !xegpu.tensor_desc<128x256xf32>
     %load = xegpu.load_nd %tdesc[0, 0] : !xegpu.tensor_desc<256x128xf32> -> vector<256x128xf32>
     %trans = vector.transpose %load, [1, 0] : vector<256x128xf32> to vector<128x256xf32>
-    xegpu.store_nd %trans, %tdesc1[0, 0] <{layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], order = [1, 0]>}>
+    xegpu.store_nd %trans, %tdesc1[0, 0] <{layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 64], order = [1, 0]>}>
         : vector<128x256xf32>, !xegpu.tensor_desc<128x256xf32>
     return
   }
diff --git a/mlir/test/Dialect/XeGPU/transform-ops.mlir b/mlir/test/Dialect/XeGPU/transform-ops.mlir
index acba80d870253..3daa74d223946 100644
--- a/mlir/test/Dialect/XeGPU/transform-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/transform-ops.mlir
@@ -83,7 +83,7 @@ module attributes {transform.with_named_sequence} {
 func.func @set_anchor_layout(%arg0: memref<4096x4096xf16>) {
   %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
   // CHECK: = xegpu.load_nd %0[0, 0]
-  // CHECK-SAME: <{layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], inst_data = [8, 16]>}>
+  // CHECK-SAME: <{layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [8, 16]>}>
   %1 = xegpu.load_nd %0[0, 0]  : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
   return
 }
@@ -92,7 +92,7 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["xegpu.load_nd"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     // CHECK: transform.xegpu.set_anchor_layout %{{.*}}
-    transform.xegpu.set_anchor_layout %0 index = 0 sg_layout = [8, 4] sg_data = [32, 64] inst_data = [8, 16] : !transform.any_op
+    transform.xegpu.set_anchor_layout %0 index = 0 sg_layout = [8, 4] sg_data = [32, 32] inst_data = [8, 16] : !transform.any_op
     transform.yield
   }
 }
@@ -103,9 +103,9 @@ module attributes {transform.with_named_sequence} {
 func.func @set_anchor_layout_multiple(%arg0: memref<4096x4096xf16>) {
   %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
   // CHECK: xegpu.prefetch_nd %0[0, 0]
-  // CHECK-SAME: <{layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], inst_data = [8, 16]>}>
+  // CHECK-SAME: <{layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [8, 16]>}>
   // CHECK: xegpu.prefetch_nd %0[16, 0]
-  // CHECK-SAME: <{layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], inst_data = [8, 16]>}>
+  // CHECK-SAME: <{layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [8, 16]>}>
   xegpu.prefetch_nd %0[0, 0] : !xegpu.tensor_desc<256x32xf16>
   xegpu.prefetch_nd %0[16, 0] : !xegpu.tensor_desc<256x32xf16>
   return
@@ -115,7 +115,7 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["xegpu.prefetch_nd"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     // CHECK: transform.xegpu.set_anchor_layout %{{.*}}
-    transform.xegpu.set_anchor_layout %0 index = 0 sg_layout = [8, 4] sg_data = [32, 64] inst_data = [8, 16] : !transform.any_op
+    transform.xegpu.set_anchor_layout %0 index = 0 sg_layout = [8, 4] sg_data = [32, 32] inst_data = [8, 16] : !transform.any_op
     transform.yield
   }
 }
@@ -126,7 +126,7 @@ module attributes {transform.with_named_sequence} {
 func.func @set_anchor_layout_param(%arg0: memref<4096x4096xf16>) {
   %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
   // CHECK: = xegpu.load_nd %0[0, 0]
-  // CHECK-SAME: <{layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], inst_data = [8, 16]>}>
+  // CHECK-SAME: <{layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [8, 16]>}>
   %1 = xegpu.load_nd %0[0, 0]  : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
   return
 }
@@ -136,7 +136,7 @@ module attributes {transform.with_named_sequence} {
     %0 = transform.structured.match ops{["xegpu.load_nd"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     // CHECK: transform.xegpu.set_anchor_layout %{{.*}}
     %layout0 = transform.param.constant 8 : i64 -> !transform.param<i64>
-    transform.xegpu.set_anchor_layout %0 index = 0 sg_layout = [%layout0, 4] sg_data = [32, 64] inst_data = [8, 16] : !transform.any_op, !transform.param<i64>
+    transform.xegpu.set_anchor_layout %0 index = 0 sg_layout = [%layout0, 4] sg_data = [32, 32] inst_data = [8, 16] : !transform.any_op, !transform.param<i64>
     transform.yield
   }
 }
@@ -147,7 +147,7 @@ module attributes {transform.with_named_sequence} {
 func.func @set_anchor_layout_param2(%arg0: memref<4096x4096xf16>) {
   %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
   // CHECK: = xegpu.load_nd %0[0, 0]
-  // CHECK-SAME: <{layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], inst_data = [8, 16]>}>
+  // CHECK-SAME: <{layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [8, 16]>}>
   %1 = xegpu.load_nd %0[0, 0]  : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
   return
 }
@@ -158,7 +158,7 @@ module attributes {transform.with_named_sequence} {
     // CHECK: transform.xegpu.set_anchor_layout %{{.*}}
     %layout0 = transform.param.constant 8 : i64 -> !transform.param<i64>
     %layout1 = transform.param.constant 4 : i64 -> !transform.param<i64>
-    transform.xegpu.set_anchor_layout %0 index = 0 sg_layout = [%layout0, %layout1] sg_data = [32, 64] inst_data = [8, 16] : !transform.any_op, !transform.param<i64>, !transform.param<i64>
+    transform.xegpu.set_anchor_layout %0 index = 0 sg_layout = [%layout0, %layout1] sg_data = [32, 32] inst_data = [8, 16] : !transform.any_op, !transform.param<i64>, !transform.param<i64>
     transform.yield
   }
 }
@@ -192,7 +192,7 @@ module attributes {transform.with_named_sequence} {
 func.func @set_anchor_layout_order(%arg0: memref<4096x4096xf16>) {
   %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
   // CHECK: = xegpu.load_nd %0[0, 0]
-  // CHECK-SAME: <{layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], inst_data = [8, 16], order = [1, 0]>}>
+  // CHECK-SAME: <{layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [8, 16], order = [1, 0]>}>
   %1 = xegpu.load_nd %0[0, 0]  : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
   return
 }
@@ -201,7 +201,7 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["xegpu.load_nd"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     // CHECK: transform.xegpu.set_anchor_layout %{{.*}}
-    transform.xegpu.set_anchor_layout %0 index = 0 sg_layout = [8, 4] sg_data = [32, 64] inst_data = [8, 16] order = [1, 0] : !transform.any_op
+    transform.xegpu.set_anchor_layout %0 index = 0 sg_layout = [8, 4] sg_data = [32, 32] inst_data = [8, 16] order = [1, 0] : !transform.any_op
     transform.yield
   }
 }

>From aa7525b4beb0e35a95c3b78c74bd1f833e0d40a6 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Tue, 14 Apr 2026 05:03:35 +0000
Subject: [PATCH 2/2] add isDistributableCheck for all anchor operations

---
 mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 52 ++++++++++++++++++++++----
 1 file changed, 44 insertions(+), 8 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index dae041544753c..9107cda30a8fa 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -218,6 +218,11 @@ IsValidMatrixOpParams(VectorType dataTy, MemDescType mdescTy,
       }
     }
   }
+
+  if (layout && !layout.isDistributable(
+                    SmallVector<int64_t>(dataShape.begin(), dataShape.end())))
+    return emitError() << "Value shape is not distributable with the layout";
+
   if (dataShape.size() == 2) {
     if (llvm::any_of(llvm::zip_equal(dataShape, mdescShape),
                      [](auto p) { return std::get<0>(p) > std::get<1>(p); }))
@@ -638,7 +643,8 @@ LogicalResult LoadNdOp::verify() {
 
   if (getAnchorLayout()) {
     auto layout = getAnchorLayout();
-    if (!layout.isDistributable(tdescShape))
+    auto origTdescShape = getShapeOf(tdescTy);
+    if (!layout.isDistributable(origTdescShape))
       return emitOpError(
           "TensorDesc shape is not distributable with the layout");
   }
@@ -848,13 +854,14 @@ LogicalResult PrefetchOp::verify() {
   if (getAnchorLayout()) {
     auto layout = getAnchorLayout();
     // get the offset operand and its shape
-    auto offsets = getOffsets();
-    auto offsetsTy = offsets.getType();
-    if (!llvm::isa<VectorType>(offsetsTy))
-      return emitOpError("Offsets should be a vector.");
-    auto offsetShape = getShapeOf(offsetsTy);
-    if (!layout.isDistributable(offsetShape))
-      return emitOpError("offset shape is not distributable with the layout");
+    if (auto offsets = getOffsets()) {
+      auto offsetsTy = offsets.getType();
+      if (!llvm::isa<VectorType>(offsetsTy))
+        return emitOpError("Offsets should be a vector.");
+      auto offsetShape = getShapeOf(offsetsTy);
+      if (!layout.isDistributable(offsetShape))
+        return emitOpError("offset shape is not distributable with the layout");
+    }
   }
 
   return success();
@@ -904,6 +911,13 @@ LogicalResult LoadGatherOp::verify() {
   if (memTy && (getElementType() != memTy.getElementType()))
     return emitError() << "Value should have the same element type as MemRef.";
 
+  if (getAnchorLayout()) {
+    auto layout = getAnchorLayout();
+    auto valShape = getShapeOf(valueTy);
+    if (!layout.isDistributable(valShape))
+      return emitOpError("Value shape is not distributable with the layout");
+  }
+
   auto offsetsTy = getOffsets().getType();
   return isValidGatherScatterBufferParams(offsetsTy, maskTy, valueTy, chunkSize,
                                           [&]() { return emitOpError(); });
@@ -988,6 +1002,13 @@ LogicalResult StoreScatterOp::verify() {
   if (memTy && (getElementType() != memTy.getElementType()))
     return emitError() << "Value should have the same element type as MemRef.";
 
+  if (getAnchorLayout()) {
+    auto layout = getAnchorLayout();
+    auto valShape = getShapeOf(valueTy);
+    if (!layout.isDistributable(valShape))
+      return emitOpError("Value shape is not distributable with the layout");
+  }
+
   auto offsetsTy = getOffsets().getType();
   return isValidGatherScatterBufferParams(offsetsTy, maskTy, valueTy, chunkSize,
                                           [&]() { return emitOpError(); });
@@ -1086,6 +1107,21 @@ LogicalResult DpasOp::verify() {
   auto rhsShape = getRhsType().getShape();
   auto resShape = getResultType().getShape();
 
+  if (auto cdLayout = getLayoutCd())
+    if (!cdLayout->isDistributable(
+            SmallVector<int64_t>(resShape.begin(), resShape.end())))
+      return emitOpError("Value shape is not distributable with the layout");
+
+  if (auto aLayout = getLayoutA())
+    if (!aLayout->isDistributable(
+            SmallVector<int64_t>(lhsShape.begin(), lhsShape.end())))
+      return emitOpError("Value shape is not distributable with the layout");
+
+  if (auto bLayout = getLayoutB())
+    if (!bLayout->isDistributable(
+            SmallVector<int64_t>(rhsShape.begin(), rhsShape.end())))
+      return emitOpError("Value shape is not distributable with the layout");
+
   if (getAcc() && getAcc().getType() != getResultType())
     return emitOpError("Expecting the acc type to be the same as result.");
 



More information about the Mlir-commits mailing list