[Mlir-commits] [mlir] [MLIR][XeGPU] Clean up the temporary layout usage in XeGPU test (PR #195739)
Jianhui Li
llvmlistbot at llvm.org
Tue May 5 19:56:41 PDT 2026
================
@@ -484,6 +482,86 @@ xegpu::DistributeLayoutAttr xegpu::inferInsertStridedSliceSourceLayout(
return resLayout;
}
+/// Infers the source layout attribute for an insert operation
+/// given the result layout attribute, result shape, and source shape. Removes
+/// leading dimensions from the result layout to match the source shape size.
+// TODO: add propagation support for insert op
+xegpu::DistributeLayoutAttr
+xegpu::inferInsertSourceLayout(xegpu::DistributeLayoutAttr resLayout,
+ ArrayRef<int64_t> resShape,
+ ArrayRef<int64_t> srcShape) {
+
+ int srcShapeSize = srcShape.size();
+ int resShapeSize = resShape.size();
+ int dimDiff = resShapeSize - srcShapeSize;
+
+ if (dimDiff > 0) {
+ // assert that the leading dimensions being sliced off are not distributed
+ // (i.e. sg_layout and lane_layout for those dimensions are all 1)
+ auto resSgLayout = resLayout.getEffectiveSgLayoutAsInt();
+ auto resLaneLayout = resLayout.getEffectiveLaneLayoutAsInt();
+ for (int i = 0; i < dimDiff; i++) {
+ assert((resSgLayout.size() == 0 || resSgLayout[i] == 1) &&
+ (resLaneLayout.size() == 0 || resLaneLayout[i] == 1) &&
+ "Leading dimensions being sliced off must not be distributed");
+ }
+ return resLayout.dropDims(llvm::to_vector(llvm::seq<int64_t>(0, dimDiff)));
+ }
+ return resLayout;
+}
+
+/// Infers the source layout attribute for an extract operation
+/// given the result layout attribute, result shape, and source shape. Prepends
+/// leading unit dimensions to the result layout to match the source shape size.
+// TODO: add layout attribute interface: expandDims() and use it here.
+// TODO: add propagation support for extract op
+xegpu::DistributeLayoutAttr
+xegpu::inferExtractSourceLayout(xegpu::DistributeLayoutAttr resLayout,
+ ArrayRef<int64_t> resShape,
+ ArrayRef<int64_t> srcShape) {
+
+ int srcShapeSize = srcShape.size();
+ int resShapeSize = resShape.size();
+ int dimDiff = srcShapeSize - resShapeSize;
+ auto context = resLayout.getContext();
+ // construct the source layout by adding unit dimensions to the front of
+ // result layout
+ if (dimDiff > 0) {
+ auto sgLayout = resLayout.getEffectiveSgLayoutAsInt();
+ auto sgData = resLayout.getEffectiveSgDataAsInt();
+ auto instData = resLayout.getEffectiveInstDataAsInt();
+ auto laneLayout = resLayout.getEffectiveLaneLayoutAsInt();
+ auto laneData = resLayout.getEffectiveLaneDataAsInt();
+ auto order = resLayout.getEffectiveOrderAsInt();
+
+ for (int i = resShapeSize; i < dimDiff; i++) {
----------------
Jianhui-Li wrote:
fixed.
https://github.com/llvm/llvm-project/pull/195739
More information about the Mlir-commits
mailing list