[Mlir-commits] [mlir] [MLIR][XeGPU] Clean up the temporary layout usage in XeGPU test (PR #195739)

Jianhui Li llvmlistbot at llvm.org
Tue May 5 19:56:41 PDT 2026


================
@@ -484,6 +482,86 @@ xegpu::DistributeLayoutAttr xegpu::inferInsertStridedSliceSourceLayout(
   return resLayout;
 }
 
+/// Infers the source layout attribute for an insert operation
+/// given the result layout attribute, result shape, and source shape. Removes
+/// leading dimensions from the result layout to match the source shape size.
+// TODO: add propagation support for insert op
+xegpu::DistributeLayoutAttr
+xegpu::inferInsertSourceLayout(xegpu::DistributeLayoutAttr resLayout,
+                               ArrayRef<int64_t> resShape,
+                               ArrayRef<int64_t> srcShape) {
+
+  int srcShapeSize = srcShape.size();
+  int resShapeSize = resShape.size();
+  int dimDiff = resShapeSize - srcShapeSize;
+
+  if (dimDiff > 0) {
+    // assert that the leading dimensions being sliced off are not distributed
+    // (i.e. sg_layout and lane_layout for those dimensions are all 1)
+    auto resSgLayout = resLayout.getEffectiveSgLayoutAsInt();
+    auto resLaneLayout = resLayout.getEffectiveLaneLayoutAsInt();
+    for (int i = 0; i < dimDiff; i++) {
+      assert((resSgLayout.size() == 0 || resSgLayout[i] == 1) &&
+             (resLaneLayout.size() == 0 || resLaneLayout[i] == 1) &&
+             "Leading dimensions being sliced off must not be distributed");
+    }
+    return resLayout.dropDims(llvm::to_vector(llvm::seq<int64_t>(0, dimDiff)));
+  }
+  return resLayout;
+}
+
+/// Infers the source layout attribute for an extract operation
+/// given the result layout attribute, result shape, and source shape. Adds
+/// leading unit dimensions to the result layout to match the source shape size.
+// TODO: add layout attribute interface: expandDims() and use it here.
+// TODO: add propagation support for extract op
+xegpu::DistributeLayoutAttr
+xegpu::inferExtractSourceLayout(xegpu::DistributeLayoutAttr resLayout,
+                                ArrayRef<int64_t> resShape,
+                                ArrayRef<int64_t> srcShape) {
+
+  int srcShapeSize = srcShape.size();
+  int resShapeSize = resShape.size();
+  int dimDiff = srcShapeSize - resShapeSize;
+  auto context = resLayout.getContext();
+  // construct the source layout by adding unit dimensions to the front of the
+  // result layout
+  if (dimDiff > 0) {
+    auto sgLayout = resLayout.getEffectiveSgLayoutAsInt();
+    auto sgData = resLayout.getEffectiveSgDataAsInt();
+    auto instData = resLayout.getEffectiveInstDataAsInt();
+    auto laneLayout = resLayout.getEffectiveLaneLayoutAsInt();
+    auto laneData = resLayout.getEffectiveLaneDataAsInt();
+    auto order = resLayout.getEffectiveOrderAsInt();
+
+    for (int i = resShapeSize; i < dimDiff; i++) {
----------------
Jianhui-Li wrote:

fixed.
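The two rules in this hunk are easier to see on plain vectors. Below is a minimal C++ sketch of them. `SimpleLayout`, `dropLeadingDims`, and `expandLeadingDims` are hypothetical stand-ins, not the XeGPU `DistributeLayoutAttr` API. Empty fields model unset parameters, which mirrors the `size() == 0` checks above. The handling of `order` is an assumption about how the real attribute behaves: dropped dimensions are removed and the rest are renumbered, and new leading dimensions are appended as the slowest-varying ones. The prepended dimensions are simply indices 0 through `dimDiff - 1`.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for an XeGPU distribute layout: each field holds one
// entry per dimension, or is empty when the field is not set.
struct SimpleLayout {
  std::vector<int64_t> sgLayout, sgData, instData, laneLayout, laneData, order;
};

// Insert case: the source has `n` fewer dims than the result, so drop the
// `n` leading dims. They must not be distributed across subgroups or lanes.
SimpleLayout dropLeadingDims(const SimpleLayout &l, int64_t n) {
  for (int64_t i = 0; i < n; ++i) {
    assert((l.sgLayout.empty() || l.sgLayout[i] == 1) &&
           (l.laneLayout.empty() || l.laneLayout[i] == 1) &&
           "leading dimensions being dropped must not be distributed");
  }
  // Unset (empty) fields stay unset; set fields lose their first n entries.
  auto drop = [n](const std::vector<int64_t> &v) {
    return v.empty() ? v : std::vector<int64_t>(v.begin() + n, v.end());
  };
  SimpleLayout out{drop(l.sgLayout),   drop(l.sgData),   drop(l.instData),
                   drop(l.laneLayout), drop(l.laneData), {}};
  // Assumption: remove the dropped dims from `order`, renumber the rest.
  for (int64_t d : l.order)
    if (d >= n)
      out.order.push_back(d - n);
  return out;
}

// Extract case: the source has `n` more dims than the result, so prepend
// `n` unit dims (value 1) to every set field of the result layout.
SimpleLayout expandLeadingDims(const SimpleLayout &l, int64_t n) {
  auto prepend = [n](const std::vector<int64_t> &v) {
    if (v.empty())
      return v;
    std::vector<int64_t> r(static_cast<std::size_t>(n), 1);
    r.insert(r.end(), v.begin(), v.end());
    return r;
  };
  SimpleLayout out{prepend(l.sgLayout),   prepend(l.sgData),
                   prepend(l.instData),   prepend(l.laneLayout),
                   prepend(l.laneData),   {}};
  // Assumption: existing dims shift by n; new leading dims are the outermost,
  // so they go last in `order` (order[0] is the fastest-varying dim).
  if (!l.order.empty()) {
    for (int64_t d : l.order)
      out.order.push_back(d + n);
    for (int64_t d = n - 1; d >= 0; --d)
      out.order.push_back(d);
  }
  return out;
}
```

Under these assumptions, expanding a 2-D layout with `order = [1, 0]` by one leading dimension gives `order = [2, 1, 0]` and `sg_layout = [1, ...]`, and dropping that dimension again recovers the original layout.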

https://github.com/llvm/llvm-project/pull/195739

