[Mlir-commits] [mlir] [MLIR][XeGPU] Refactor isEvenlyDistributable() to Layout attribute interface (PR #191945)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Apr 13 22:07:52 PDT 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-gpu

Author: Jianhui Li (Jianhui-Li)

<details>
<summary>Changes</summary>

This PR refactors isEvenlyDistributable() into the layout attribute interface method isDistributable(), and uses it in all anchor operations to check that the shape can be distributed with the anchor layout.

---

Patch is 27.95 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/191945.diff


8 Files Affected:

- (modified) mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td (+53-7) 
- (modified) mlir/include/mlir/Dialect/XeGPU/IR/XeGPUDialect.td (-4) 
- (modified) mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp (+4-84) 
- (modified) mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp (+73-3) 
- (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp (+1-1) 
- (modified) mlir/test/Dialect/XeGPU/invalid.mlir (-8) 
- (modified) mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir (+8-8) 
- (modified) mlir/test/Dialect/XeGPU/transform-ops.mlir (+11-11) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index f8a2beabb9b95..a7bea9881602f 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -188,6 +188,9 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> {
     InterfaceMethod<"Check the availability of subgroup level layouts",
                     "bool",
                     "isForSubgroup">,
+    InterfaceMethod<"Check the availability of lane level layouts",
+                    "bool",
+                    "isForLane">,
     InterfaceMethod<"Get the rank of attribute",
                     "int64_t",
                     "getRank">,
@@ -299,26 +302,60 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> {
                       } else {
                         return failure();
                       }
-                      assert(
-                          !subShape.empty() &&
-                          "sgdata or lanedata cannot be empty for distributed shape computation");
+                      // sgdata or lanedata cannot be empty for distributed shape computation
+                      if (subShape.empty())
+                        return failure();
                       SmallVector<int64_t> distributedShape(shape.size());
                       for (auto [i, dim] : llvm::enumerate(shape)) {
                         int64_t distriUnit = layout[i]*subShape[i];
                         if ((dim % distriUnit) == 0) {
                           // Evenly divisible case, divide the dimension by the layout factor.
                           distributedShape[i] = dim / layout[i];
-                          assert((distributedShape[i] % subShape[i] == 0) &&
-                                "Even distribution: sgdata or lanedata must divide the distributed dimension");
+                          if (distributedShape[i] % subShape[i] != 0)
+                            return failure();
                         } else {
                           // wrap around case, the dimension size must be equal to subShape value
-                          assert(dim == subShape[i] &&
-                                "Wrap-around distribution: sgdata or lanedata must be same as tensor tile shape");
+                          if(dim != subShape[i])
+                            return failure();
                           distributedShape[i] = dim;
                         }
                       }
                       return distributedShape;
                     }]>,
+    InterfaceMethod<[{Checks if the given shape can be distributed by the layout}],
+                    /*retTy=*/"bool",
+                    /*methodName=*/"isDistributable",
+                    /*args=*/(ins "SmallVector<int64_t>":$shape),
+                    /*methodBody=*/[{
+                        DistributeLayoutAttr curLayoutAttr = $_self;
+                        SmallVector<int64_t> curShape = shape;
+                        // Phase 1: Distribute across subgroups (sg_layout + sg_data).
+                        if (curLayoutAttr.isForWorkgroup()) {
+                          auto maybeSgShape = curLayoutAttr.computeDistributedShape(curShape);
+                          if (failed(maybeSgShape))
+                            return false;
+                          curShape = maybeSgShape.value();
+                          curLayoutAttr = curLayoutAttr.dropSgLayoutAndData();
+                          if (!curLayoutAttr)
+                            return true;
+                        }
+                        // Phase 2: Distribute across instruction data (inst_data).
+                        if (curLayoutAttr.isForSubgroup() && !curLayoutAttr.isForLane()) {
+                          SmallVector<int64_t> instData = curLayoutAttr.getEffectiveInstDataAsInt();
+                          for (size_t i = 0; i < curShape.size(); ++i) {
+                            if (curShape[i] % instData[i] != 0)
+                              return false;
+                          }
+                          // inst_data becomes the new shape for next phase
+                          curShape = instData;
+                          curLayoutAttr = curLayoutAttr.dropInstData();
+                          if (!curLayoutAttr)
+                            return true;
+                        }
+                        // Phase 3: Distribute across lanes (lane_layout + lane_data).
+                        auto maybeLaneShape = curLayoutAttr.computeDistributedShape(curShape);
+                        return succeeded(maybeLaneShape);
+                      }]>,
     InterfaceMethod</*desc=*/[{Check if this layout is a slice of another layout.}],
                     /*retTy=*/"bool",
                     /*methodName=*/"isSliceOf",
@@ -487,6 +524,10 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttr]> {
       return !isForWorkgroup();
     }
 
+    bool isForLane() {
+      return !isForWorkgroup() && (getInstData() == nullptr);
+    }
+
     int64_t getRank() const {
       if (auto attr = getSgLayout())
         return attr.size();
@@ -687,6 +728,11 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> {
       return parent.isForSubgroup();
     }
 
+    bool isForLane() const {
+      auto parent = dyn_cast<LayoutAttr>(getParent());
+      return parent.isForLane();
+    }
+
     /// Returns the SgLayout of the attribute, computed by applying
     /// the slice dimensions to the underlying LayoutAttr.
     SmallVector<int64_t> getEffectiveSgLayoutAsInt() const {
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUDialect.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUDialect.td
index c173b93face98..84fd8f9e0060c 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUDialect.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUDialect.td
@@ -38,10 +38,6 @@ def XeGPU_Dialect : Dialect {
     let useDefaultAttributePrinterParser = true;
 
     let extraClassDeclaration = [{
-      /// Checks if the given shape can be evenly distributed based on the layout
-      /// and data factors provided by the LayoutAttr.
-      static bool isEvenlyDistributable(llvm::ArrayRef<int64_t> shape, xegpu::DistributeLayoutAttr attr);
-
       /// drops/slices the shape in the specified dims, and return the rest. e.g.,
       /// for shape = [32, 64, 8], dims = [0, 2], it will return [64]
       template<typename T, typename U>
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index eaa43c02946d8..80a3fc91f1c4f 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -121,74 +121,6 @@ static SmallVector<SmallVector<int64_t>> genStaticCoordinates(
   return coordinates;
 }
 
-// Checks if the given shape can be evenly distributed based on the layout
-// and data factors provided by the LayoutAttr.
-bool XeGPUDialect::isEvenlyDistributable(llvm::ArrayRef<int64_t> shape,
-                                         xegpu::DistributeLayoutAttr attr) {
-  assert(attr && "Layout attribute is missing.");
-
-  // Checks whether the given shape can be evenly distributed using the
-  // specified layout and data attributes. If successful, it returns the work
-  // size for each compute unit; otherwise, it returns `std::nullopt`. The work
-  // size per compute unit is calculated as follows:
-  //   - If `data` is null: newShape[i] = shape[i] / layout[i]
-  //   - If `data` is not null: newShape[i] = data[i]
-  // When round-robin distribution (`rr`) is enabled, `shape[i]` can be
-  // smaller than `layout[i] * data[i]`, allowing multiple compute units to
-  // share the data.
-  auto tryDistribute = [&](llvm::ArrayRef<int64_t> shape,
-                           SmallVector<int64_t> layout,
-                           SmallVector<int64_t> data,
-                           bool rr = true) -> optional<SmallVector<int64_t>> {
-    llvm::SmallVector<int64_t> newShape(shape);
-    if (layout.size()) {
-      if (layout.size() != shape.size())
-        return std::nullopt;
-      auto ratio = computeShapeRatio(shape, layout);
-      if (ratio.has_value()) {
-        newShape = ratio.value();
-      } else if (!rr || !computeShapeRatio(layout, shape).has_value()) {
-        return std::nullopt;
-      }
-      // Round-robin case: continue with original newShape
-    }
-
-    if (data.size()) {
-      if (data.size() != shape.size())
-        return std::nullopt;
-      auto ratio = computeShapeRatio(newShape, data);
-      if (!ratio.has_value() && rr)
-        ratio = computeShapeRatio(data, newShape);
-      if (!ratio.has_value())
-        return std::nullopt;
-
-      // if data is not null, we always return it for next phase.
-      newShape = data;
-    }
-    return newShape;
-  };
-
-  // check the sgLayout and sgData
-  auto maybeSgShape = tryDistribute(shape, attr.getEffectiveSgLayoutAsInt(),
-                                    attr.getEffectiveSgDataAsInt());
-  if (!maybeSgShape)
-    return false;
-  auto sgShape = maybeSgShape.value();
-
-  // check InstData, it neither have layout nor need round-robin
-  auto maybeInstShape =
-      tryDistribute(sgShape, {}, attr.getEffectiveInstDataAsInt(), false);
-  if (!maybeInstShape)
-    return false;
-  auto instShape = maybeInstShape.value();
-
-  // check LaneLayout and LaneData
-  auto maybeLaneShape =
-      tryDistribute(instShape, attr.getEffectiveLaneLayoutAsInt(),
-                    attr.getEffectiveLaneDataAsInt());
-  return maybeLaneShape.has_value();
-}
-
 //===----------------------------------------------------------------------===//
 // XeGPU_BlockTensorDescAttr
 //===----------------------------------------------------------------------===//
@@ -1431,25 +1363,12 @@ TensorDescType::verify(llvm::function_ref<InFlightDiagnostic()> emitError,
                            << chunkAlignmentFactor;
     }
   }
-
-  auto layoutAttr = llvm::dyn_cast_if_present<LayoutAttr>(layout);
-  if (layoutAttr) {
+  if (auto layoutAttr =
+          mlir::dyn_cast_if_present<DistributeLayoutAttr>(layout)) {
     if (rank != (size_t)layoutAttr.getRank())
       return emitError() << "expected layout rank to match tensor rank";
 
-    auto laneData = layoutAttr.getLaneData();
-    if (scatterAttr && laneData) {
-      // Validate subgroup mapping rules for scattered tensors.
-      // if chunkSize > 1, the last dimension of the tensor should
-      // be distributed in the units divisible by chunkAlignmentFactor.
-      int64_t chunkSize = scatterAttr.getChunkSizeAsInt();
-      if (chunkSize > 1 && laneData[rank - 1] % chunkAlignmentFactor)
-        return emitError()
-               << "expected last dim of lane_data to be a multiple of: "
-               << chunkAlignmentFactor;
-    }
-
-    if (!XeGPUDialect::isEvenlyDistributable(shape, layoutAttr)) {
+    if (!layoutAttr.isDistributable(SmallVector<int64_t>(shape))) {
       std::string shapeStr;
       llvm::raw_string_ostream stream(shapeStr);
       llvm::interleaveComma(shape, stream);
@@ -1457,6 +1376,7 @@ TensorDescType::verify(llvm::function_ref<InFlightDiagnostic()> emitError,
                          << layoutAttr;
     }
   }
+
   return success();
 }
 
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 5697097a4c999..9107cda30a8fa 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -218,6 +218,11 @@ IsValidMatrixOpParams(VectorType dataTy, MemDescType mdescTy,
       }
     }
   }
+
+  if (layout && !layout.isDistributable(
+                    SmallVector<int64_t>(dataShape.begin(), dataShape.end())))
+    return emitError() << "Value shape is not distributable with the layout";
+
   if (dataShape.size() == 2) {
     if (llvm::any_of(llvm::zip_equal(dataShape, mdescShape),
                      [](auto p) { return std::get<0>(p) > std::get<1>(p); }))
@@ -504,6 +509,14 @@ LogicalResult PrefetchNdOp::verify() {
     return emitOpError(
         "Mismatched ranks between offsets and tensor descriptor");
 
+  if (getAnchorLayout()) {
+    auto layout = getAnchorLayout();
+    auto tdescShape = getShapeOf(tdescTy);
+    if (!layout.isDistributable(tdescShape))
+      return emitOpError(
+          "TensorDesc shape is not distributable with the layout");
+  }
+
   return success();
 }
 
@@ -628,6 +641,14 @@ LogicalResult LoadNdOp::verify() {
     return emitOpError(
         "Mismatched ranks between offsets and tensor descriptor");
 
+  if (getAnchorLayout()) {
+    auto layout = getAnchorLayout();
+    auto origTdescShape = getShapeOf(tdescTy);
+    if (!layout.isDistributable(origTdescShape))
+      return emitOpError(
+          "TensorDesc shape is not distributable with the layout");
+  }
+
   return success();
 }
 
@@ -721,6 +742,13 @@ LogicalResult StoreNdOp::verify() {
     return emitOpError(
         "Mismatched ranks between offsets and tensor descriptor");
 
+  if (getAnchorLayout()) {
+    auto layout = getAnchorLayout();
+    if (!layout.isDistributable(tdescShape))
+      return emitOpError(
+          "TensorDesc shape is not distributable with the layout");
+  }
+
   return success();
 }
 
@@ -823,6 +851,19 @@ LogicalResult PrefetchOp::verify() {
   if (getOffsetAlignByteAttr() && !srcTy.isInteger())
     return emitOpError("offset_align_byte only allowed with integer source.");
 
+  if (getAnchorLayout()) {
+    auto layout = getAnchorLayout();
+    // get the offset operand and its shape
+    if (auto offsets = getOffsets()) {
+      auto offsetsTy = offsets.getType();
+      if (!llvm::isa<VectorType>(offsetsTy))
+        return emitOpError("Offsets should be a vector.");
+      auto offsetShape = getShapeOf(offsetsTy);
+      if (!layout.isDistributable(offsetShape))
+        return emitOpError("offset shape is not distributable with the layout");
+    }
+  }
+
   return success();
 }
 
@@ -870,6 +911,13 @@ LogicalResult LoadGatherOp::verify() {
   if (memTy && (getElementType() != memTy.getElementType()))
     return emitError() << "Value should have the same element type as MemRef.";
 
+  if (getAnchorLayout()) {
+    auto layout = getAnchorLayout();
+    auto valShape = getShapeOf(valueTy);
+    if (!layout.isDistributable(valShape))
+      return emitOpError("Value shape is not distributable with the layout");
+  }
+
   auto offsetsTy = getOffsets().getType();
   return isValidGatherScatterBufferParams(offsetsTy, maskTy, valueTy, chunkSize,
                                           [&]() { return emitOpError(); });
@@ -954,6 +1002,13 @@ LogicalResult StoreScatterOp::verify() {
   if (memTy && (getElementType() != memTy.getElementType()))
     return emitError() << "Value should have the same element type as MemRef.";
 
+  if (getAnchorLayout()) {
+    auto layout = getAnchorLayout();
+    auto valShape = getShapeOf(valueTy);
+    if (!layout.isDistributable(valShape))
+      return emitOpError("Value shape is not distributable with the layout");
+  }
+
   auto offsetsTy = getOffsets().getType();
   return isValidGatherScatterBufferParams(offsetsTy, maskTy, valueTy, chunkSize,
                                           [&]() { return emitOpError(); });
@@ -1052,6 +1107,21 @@ LogicalResult DpasOp::verify() {
   auto rhsShape = getRhsType().getShape();
   auto resShape = getResultType().getShape();
 
+  if (auto cdLayout = getLayoutCd())
+    if (!cdLayout->isDistributable(
+            SmallVector<int64_t>(resShape.begin(), resShape.end())))
+      return emitOpError("Value shape is not distributable with the layout");
+
+  if (auto aLayout = getLayoutA())
+    if (!aLayout->isDistributable(
+            SmallVector<int64_t>(lhsShape.begin(), lhsShape.end())))
+      return emitOpError("Value shape is not distributable with the layout");
+
+  if (auto bLayout = getLayoutB())
+    if (!bLayout->isDistributable(
+            SmallVector<int64_t>(rhsShape.begin(), rhsShape.end())))
+      return emitOpError("Value shape is not distributable with the layout");
+
   if (getAcc() && getAcc().getType() != getResultType())
     return emitOpError("Expecting the acc type to be the same as result.");
 
@@ -1103,12 +1173,12 @@ LogicalResult ConvertLayoutOp::verify() {
 
   Type srcType = getSource().getType();
   if (llvm::isa<VectorType>(srcType)) {
-    auto shape = llvm::cast<VectorType>(srcType).getShape();
-    if (!XeGPUDialect::isEvenlyDistributable(shape, srcLayout))
+    SmallVector<int64_t> shape(llvm::cast<VectorType>(srcType).getShape());
+    if (!srcLayout.isDistributable(shape))
       return emitOpError(
           "invalid input layout, data cannot be evenly distributed.");
 
-    if (!XeGPUDialect::isEvenlyDistributable(shape, resLayout))
+    if (!resLayout.isDistributable(shape))
       return emitOpError(
           "invalid target layout, data cannot be evenly distributed.");
   }
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index a095c19d66c15..d637b6828deab 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -498,7 +498,7 @@ struct WgToSgVectorBroadcastOp
     VectorType newResultType =
         VectorType::get(sgShape, resultType.getElementType());
 
-    if (!xegpu::XeGPUDialect::isEvenlyDistributable(wgShape, layout))
+    if (!layout.isDistributable(SmallVector<int64_t>(wgShape)))
       return failure();
 
     SmallVector<Value> newBroadcastOps;
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index 7390b47b3f8d9..82c7879c79d56 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -325,14 +325,6 @@ func.func @create_tdesc_layout_1(%src: ui64) {
   return
 }
 
-// -----
-func.func @create_tdesc_layout_2(%src: ui64) {
-  %cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
-  // expected-error at +1 {{expected last dim of lane_data to be a multiple of: 2}}
-  %1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x4xf16, #xegpu.scatter_tdesc_attr<chunk_size = 4>, #xegpu.layout<lane_layout = [4, 1], lane_data = [1, 1]>>
-  return
-}
-
 // -----
 func.func @load_gather_simt_1(%src: ui64) {
   %0 = arith.constant dense<1>: vector<4xi1>
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir b/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir
index 831d1e05967f8..62426d619445b 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir
@@ -27,24 +27,24 @@ gpu.module @test {
   // CHECK-SAME: %[[ARG_1:.*]]: memref<128x256xf32>
   func.func @vector_transpose(%src: memref<256x128xf32>, %src1: memref<128x256xf32>) {
     // CHECK: %[[TDESC_LD:.*]] = xegpu.create_nd_tdesc %[[ARG_0]] : memref<256x128xf32> ->
-    // CHECK-SAME: !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [4, 8], sg_data = [64, 32], order = [0, 1]>>
+    // CHECK-SAME: !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [4, 8], sg_data = [64, 16], order = [0, 1]>>
     // CHECK: %[[TDESC_ST:.*]] = xegpu.create_nd_tdesc %[[ARG_1]] : memref<128x256xf32> ->
-    // CHECK-SAME: !xegpu.tensor_desc<128x256xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], order = [1, 0]>>
+    // CHECK-SAME: !xegpu.tensor_desc<128x256xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 64], order = [1, 0]>>
 
-    // CHECK: %[[LOAD:.*]] = xegpu.load_nd %[[TDESC_LD]][0, 0] <{layout = #xegpu.layout<sg_layout = [4, 8], sg_data = [64, 32], order = [0, 1]>}>
-    // CHECK-SAME: !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [4, 8], sg_data = [64, 32], order = [0, 1]>> -> vector<256x128xf32>
+    // CHECK: %[[LOAD:.*]] = xegpu.load_nd %[[TDESC_LD]][0, 0] <{layout = #xegpu.layout<sg_layout = [4, 8], sg_data = [64, 16], order = [0, 1]>}>
+    // CHECK-SAME: !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [4, 8], sg_data = [64, 16], order = [0, 1]>> -> vector<256x128xf32>
 
     // CHECK: %[[TRANSPOSED:.*]] = vector.transpose %2, [1, 0]
-    // CHECK-SAME {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], order = [1, 0]>} : vector<256x128xf32> to vector<128x256xf32>
+    // CHECK-SAME {layout_resu...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/191945


More information about the Mlir-commits mailing list