[Mlir-commits] [mlir] [MLIR][XeGPU] Refactor layout propagation utilities (PR #179016)
Jianhui Li
llvmlistbot at llvm.org
Mon Feb 2 16:01:53 PST 2026
================
@@ -0,0 +1,830 @@
+//===---- XeGPUUtils.cpp - MLIR Utilities for XeGPUOps ------------------===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements utility methods for working with the XeGPU dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpls.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
+#include "mlir/Dialect/SCF/Transforms/Patterns.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/ValueRange.h"
+#include "mlir/Interfaces/LoopLikeInterface.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/Support/FormatVariadic.h"
+#include <cstdint>
+#include <numeric>
+
+using namespace mlir;
+
+void xegpu::recoverTemporaryLayoutsDeprecated(Operation *op) {
+ op->walk([&](Operation *nestOp) {
+ for (OpOperand &opr : nestOp->getOpOperands()) {
+ auto layout = getDistributeLayoutAttr(opr.get());
+ setDistributeLayoutAttr(opr, layout);
+ }
+
+ for (OpResult result : nestOp->getOpResults()) {
+ auto layout = getDistributeLayoutAttr(result);
+ setDistributeLayoutAttr(result, layout);
+ }
+ });
+}
+
+SmallVector<NamedAttribute>
+xegpu::dropSgLayoutAndDataOnAttrs(ArrayRef<NamedAttribute> attrs) {
+ SmallVector<NamedAttribute> out;
+ out.reserve(attrs.size());
+
+ for (auto attr : attrs) {
+ if (auto dist = dyn_cast<xegpu::DistributeLayoutAttr>(attr.getValue())) {
+ auto newLayout = dist.dropSgLayoutAndData();
+ if (newLayout)
+ out.emplace_back(attr.getName(), newLayout);
+ } else {
+ out.push_back(attr);
+ }
+ }
+
+ return out;
+}
+
+SmallVector<NamedAttribute>
+xegpu::dropInstDataOnAttrs(ArrayRef<NamedAttribute> attrs) {
+ SmallVector<NamedAttribute> out;
+ out.reserve(attrs.size());
+
+ for (auto attr : attrs) {
+ if (auto dist = dyn_cast<xegpu::DistributeLayoutAttr>(attr.getValue())) {
+ auto newLayout = dist.dropInstData();
+ if (newLayout)
+ out.emplace_back(attr.getName(), newLayout);
+ } else {
+ out.push_back(attr);
+ }
+ }
+
+ return out;
+}
+
+// Attach layout attributes to all vector-type operands of operations within
+// the given operation's region. Reports an error if any vector operand lacks
+// a layout attribute.
+bool xegpu::recoverTemporaryLayouts(Operation *rootOp) {
+ auto result = rootOp->walk([&](Operation *op) {
+ for (OpOperand &operand : op->getOpOperands()) {
+ // Layouts are needed for vector type only.
+ if (!isa<VectorType>(operand.get().getType()))
+ continue;
+ auto layout = xegpu::getDistributeLayoutAttr(operand.get());
+ if (!layout) {
+ op->emitError("Could not find layout attribute for operand ")
+ << operand.getOperandNumber() << " of operation " << op->getName();
+ return WalkResult::interrupt();
+ }
+ xegpu::setDistributeLayoutAttr(operand, layout);
+ }
+ return WalkResult::advance();
+ });
+ return !result.wasInterrupted();
+}
+
+template <typename T, typename>
+void xegpu::removeLayoutAttr(const T &operandOrResult) {
+ Operation *owner = operandOrResult.getOwner();
+ std::string name = xegpu::getTemporaryLayoutName(operandOrResult);
+ if (owner->hasAttrOfType<DistributeLayoutAttr>(name))
+ owner->removeAttr(name);
+}
+
+// Explicit instantiation for OpResult
+template void
+xegpu::removeLayoutAttr<mlir::OpResult>(const mlir::OpResult &result);
+
+// Explicit instantiation for OpOperand
+template void
+xegpu::removeLayoutAttr<mlir::OpOperand>(const mlir::OpOperand &operand);
+
+void xegpu::removeLayoutAttrs(Operation *op) {
+ op->walk([&](Operation *nestOp) {
+ for (OpOperand &opr : nestOp->getOpOperands())
+ removeLayoutAttr(opr);
+ for (OpResult result : nestOp->getOpResults())
+ removeLayoutAttr(result);
+ if (op->hasAttrOfType<DistributeLayoutAttr>("layout"))
+ op->removeAttr("layout");
+ if (op->hasAttrOfType<DistributeLayoutAttr>("layout_a"))
+ op->removeAttr("layout_a");
+ if (op->hasAttrOfType<DistributeLayoutAttr>("layout_b"))
+ op->removeAttr("layout_b");
+ if (op->hasAttrOfType<DistributeLayoutAttr>("layout_cd"))
+ op->removeAttr("layout_cd");
+ });
+}
+
+/// Infers the source layout attribute for a broadcast operation given the
+/// result layout attribute, result shape, source shape.
+xegpu::DistributeLayoutAttr
+xegpu::inferBroadcastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
+ ArrayRef<int64_t> resShape,
+ ArrayRef<int64_t> srcShape) {
+
+ SmallVector<int64_t> bcastDims;
+ auto returnLayout = resLayout;
+
+ // Hanlding broadcast from low-rank to high-rank (e.g., 1D to 2D) case.
+ int dimDiff = resShape.size() - srcShape.size();
+
+ if (dimDiff > 0) {
+ // adding the missing leading dims
+ for (int i = 0; i < dimDiff; i++)
+ bcastDims.push_back(i);
+
+ // create a slice layout for the source
+ returnLayout = xegpu::SliceAttr::get(
+ resLayout.getContext(), resLayout,
+ DenseI64ArrayAttr::get(resLayout.getContext(), bcastDims));
+ }
+ return returnLayout;
+}
+
+/// Infers the source layout attribute for a reduction operation given the
+/// result layout attribute and reduced dims.
+xegpu::DistributeLayoutAttr
+xegpu::inferMultiReductionSourceLayout(xegpu::DistributeLayoutAttr resLayout,
+ SmallVector<int64_t> reduceDims) {
+
+ // assert the resLayout must be slice layout
+ assert(isa<xegpu::SliceAttr>(resLayout) &&
+ "reduction result layout must be slice layout");
+
+ // assert that the reduceDims must match with the slice dims of resLayout
+ xegpu::SliceAttr sliceLayout = dyn_cast<xegpu::SliceAttr>(resLayout);
+ auto sliceDims = sliceLayout.getDims().asArrayRef();
+ assert(reduceDims == sliceDims &&
+ "reduction dims must match with slice dims");
+
+ // then return the parent layout of sliceLayout
+ return sliceLayout.getParent();
+}
+
+/// Infers the source layout attribute for a bitcast operation given the
+/// result layout attribute, result element type bitwidth, and source element
+/// type bitwidth.
+xegpu::DistributeLayoutAttr
+xegpu::inferBitCastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
+ int resElemTyBitWidth, int srcElemTyBitWidth) {
+ // the result and source layout must be the same
+ // only adjust the sg_data, inst_data, lane_data accordingly
+ // based on the bitwidth ratio between source and result element type
+
+ SmallVector<int64_t> sgData = resLayout.getEffectiveSgDataAsInt();
+ SmallVector<int64_t> instData = resLayout.getEffectiveInstDataAsInt();
+ SmallVector<int64_t> laneData = resLayout.getEffectiveLaneDataAsInt();
+ size_t sgDataSize = sgData.size();
+ size_t instDataSize = instData.size();
+ size_t laneDataSize = laneData.size();
+ int64_t sgDataValue = -1;
+ int64_t instDataValue = -1;
+ int64_t laneDataValue = -1;
+ int64_t dim = resLayout.getRank() - 1;
+
+ if (srcElemTyBitWidth <= resElemTyBitWidth) {
+ int bitWidthRatio = resElemTyBitWidth / srcElemTyBitWidth;
+ if (sgDataSize)
+ sgDataValue = sgData[sgDataSize - 1] * bitWidthRatio;
+ if (instDataSize)
+ instDataValue = instData[instDataSize - 1] * bitWidthRatio;
+ if (laneDataSize)
+ laneDataValue = laneData[laneDataSize - 1] * bitWidthRatio;
+ } else {
+ int bitWidthRatio = srcElemTyBitWidth / resElemTyBitWidth;
+ if (sgDataSize) {
+ assert((sgData[sgDataSize - 1] % bitWidthRatio) == 0 &&
+ "sgData not divisible by bitWidthRatio");
+ sgDataValue = sgData[sgDataSize - 1] / bitWidthRatio;
+ }
+ if (instDataSize) {
+ assert((instData[instDataSize - 1] % bitWidthRatio) == 0 &&
+ "instData not divisible by bitWidthRatio");
+ instDataValue = instData[instDataSize - 1] / bitWidthRatio;
+ }
+ if (laneDataSize) {
+ assert((laneData[laneDataSize - 1] % bitWidthRatio) == 0 &&
+ "laneData not divisible by bitWidthRatio");
+ laneDataValue = laneData[laneDataSize - 1] / bitWidthRatio;
+ }
+ }
+
+ // Now set only instData and laneData, preserving sgData
+ xegpu::DistributeLayoutAttr finalSrcLayout;
+ finalSrcLayout =
+ resLayout.setDimData(dim, sgDataValue, instDataValue, laneDataValue);
+
+ return finalSrcLayout;
+}
+
+/// Infers the source layout attribute for an insert strided slice operation
+/// given the result layout attribute, result shape, and source shape. Removes
+/// leading dimensions from the result layout to match the source shape size.
+xegpu::DistributeLayoutAttr xegpu::inferInsertStridedSliceSourceLayout(
+ xegpu::DistributeLayoutAttr resLayout, ArrayRef<int64_t> resShape,
+ ArrayRef<int64_t> srcShape) {
+
+ int srcShapeSize = srcShape.size();
+ int resShapeSize = resShape.size();
+ int dimDiff = resShapeSize - srcShapeSize;
+
+ // assert resLayout must be a plain layout
+ assert(isa<xegpu::LayoutAttr>(resLayout) &&
+ "insertStridedSlice result layout must be plain layout");
+ auto context = resLayout.getContext();
+ auto resInstData = resLayout.getEffectiveInstDataAsInt();
+ auto resLaneLayout = resLayout.getEffectiveLaneLayoutAsInt();
+ auto resLaneData = resLayout.getEffectiveLaneDataAsInt();
+
+ if (resInstData.size() != 0) {
+ SmallVector<int> inferredInstData(srcShapeSize);
+ // remove the initial dims in resInstData to match srcShapeSize
+ for (int i = 0; i < srcShapeSize; i++)
+ inferredInstData[i] = resInstData[i + dimDiff];
+ return xegpu::LayoutAttr::get(context, inferredInstData);
+ }
+
+ if (resLaneLayout.size() != 0) {
+ // construct source lane_layout like [1, ..., 1, subgroupSize]
+ SmallVector<int> inferredLaneLayout(srcShapeSize);
+ SmallVector<int> inferredLaneData(srcShapeSize);
+ // remove the initial dims in resInstData to match srcShapeSize
+ for (int i = 0; i < srcShapeSize; i++) {
+ inferredLaneLayout[i] = resLaneLayout[i + dimDiff];
+ inferredLaneData[i] = resLaneData[i + dimDiff];
+ }
+ return xegpu::LayoutAttr::get(context, inferredLaneLayout,
+ inferredLaneData);
+ }
+ return nullptr;
+}
+
+/// Infers the source layout attribute for a shape cast operation given the
+/// result layout attribute, result shape, and source shape.
+xegpu::DistributeLayoutAttr
+xegpu::inferShapeCastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
+ ArrayRef<int64_t> resShape,
+ ArrayRef<int64_t> srcShape) {
+
+ // there are three use cases:
+ // 1. expand dims of low-rank dimensions (e.g., 1D to 2D): to set up the
+ // tensor before broadcast
+ // 2. split dim of a high-rank dimension (e.g., 1D to 2D): to setup tensor
+ // for multi-stage reduction
+ // 3. combines all dims to a single dim and put in the innermost dim in 2d as
+ // [1, combinedData] or [combinedData]. Only used after workgroup
+ // distribution. Example like cross-sg reduction saves multidimension data to
+ // 1D slm buffer, shapecast inserted by cse/canonicalization passes.
+
+ // Use case 1: Check if shapes only differ by expanding unit dimensions (like
+ // expand_dims)
+ SmallVector<int64_t> expandedUnitDims;
+ auto checkOnlyExpandUnitDims = [&](ArrayRef<int64_t> src,
+ ArrayRef<int64_t> dst) -> bool {
+ // All unit dimensions in dst that don't appear in src are the expanded
+ // unit dimensions
+ size_t srcIdx = 0;
+ for (size_t dstIdx = 0; dstIdx < dst.size(); ++dstIdx)
+ if (srcIdx < src.size() && src[srcIdx] == dst[dstIdx])
+ srcIdx++;
+ else if (dst[dstIdx] == 1)
+ expandedUnitDims.push_back(dstIdx);
+ else
+ return false;
+ return srcIdx == src.size();
+ };
+
+ if (checkOnlyExpandUnitDims(srcShape, resShape)) {
+ // create a slice layout for the source by removing the expanded unit dims
+ auto sliceDimsAttr = DenseI64ArrayAttr::get(
+ resLayout.getContext(), ArrayRef<int64_t>(expandedUnitDims));
+ auto srcLayout =
+ xegpu::SliceAttr::get(resLayout.getContext(), resLayout, sliceDimsAttr);
+ return srcLayout;
+ }
+
+ // Maps each source dimension to the range of destination dimensions it splits
+ // into
+ SmallVector<SmallVector<int64_t>> splitDimGroups;
+
+ auto checkSplitDims = [&](ArrayRef<int64_t> src,
+ ArrayRef<int64_t> dst) -> bool {
+ // each dim in src can be mapped to one or more dims in dst whose product
+ // equals to the src dim
+ splitDimGroups.clear();
+ size_t srcIdx = 0;
+ int64_t accumulatedSize = 1;
+ SmallVector<int64_t> currentDstDims;
+
+ for (size_t dstIdx = 0; dstIdx < dst.size(); ++dstIdx) {
+ if (srcIdx >= src.size())
+ return false;
+ accumulatedSize *= dst[dstIdx];
+ currentDstDims.push_back(dstIdx);
+
+ if (accumulatedSize == src[srcIdx]) {
+ // Record the mapping: srcIdx -> currentDstDims
+ splitDimGroups.push_back(currentDstDims);
+ // move to next src dim
+ srcIdx++;
+ accumulatedSize = 1;
+ currentDstDims.clear();
+ } else if (accumulatedSize > src[srcIdx]) {
+ return false;
+ }
+ }
+ return srcIdx == src.size();
+ };
+
+ if (checkSplitDims(srcShape, resShape)) {
+ return resLayout.collapseDims(splitDimGroups);
+ }
+
+ auto checkCombineToInnerMostDim = [&](ArrayRef<int64_t> src,
+ ArrayRef<int64_t> dst) -> bool {
+ // only one non-unit dim in dst which is the innermost dim
+ if ((dst.size() != 2) && (dst.size() != 1))
+ return false;
+ int64_t srcSize = std::accumulate(src.begin(), src.end(), 1LL,
+ std::multiplies<int64_t>());
+ if (dst.size() == 1)
+ return (dst[0] == srcSize);
+ return (dst[0] == 1) && (dst[1] == srcSize);
+ };
+
+ if (checkCombineToInnerMostDim(srcShape, resShape)) {
+ int srcShapeSize = srcShape.size();
+ int resShapeSize = resShape.size();
+ auto context = resLayout.getContext();
+ auto resInstData = resLayout.getEffectiveInstDataAsInt();
+ auto resLaneLayout = resLayout.getEffectiveLaneLayoutAsInt();
+ auto resLaneData = resLayout.getEffectiveLaneDataAsInt();
+
+ // get the layout info from the innermost dim of result layout
+ if (resInstData.size() != 0) {
+ SmallVector<int> inferredInstData(srcShapeSize, 1);
+ assert((resShapeSize == 1 || resInstData[0] == 1) &&
+ "only innermost dim can have data and instData layout");
+ inferredInstData[srcShapeSize - 1] = resInstData[resShapeSize - 1];
+ return xegpu::LayoutAttr::get(context, inferredInstData);
+ }
+
+ if (resLaneLayout.size() != 0) {
+ SmallVector<int> inferredLaneLayout(srcShapeSize, 1);
+ SmallVector<int> inferredLaneData(srcShapeSize, 1);
+ assert((resShapeSize == 1 || resLaneLayout[0] == 1) &&
+ "only innermost dim can have data and lane layout");
+ inferredLaneLayout[srcShapeSize - 1] = resLaneLayout[resShapeSize - 1];
+ inferredLaneData[srcShapeSize - 1] = resLaneData[resShapeSize - 1];
+ return xegpu::LayoutAttr::get(context, inferredLaneLayout,
+ inferredLaneData);
+ }
+ }
+ assert("running into unsupported shape cast scenarios");
+ return nullptr;
+}
+
+/// Sets up layout for reduction operations by creating a SliceAttr for the
+/// result.
+///
+/// Algorithm Overview:
+/// This function attempts to construct a source layout that, when sliced along
+/// reduction dimensions, produces a result layout compatible with the
+/// consumer layout.
+///
+/// For subgroup layouts, it first tries to align the source layout's subgroup
+/// layout and data with the consumer's layout on non-reduction dimensions.
+/// Then, it distributes remaining subgroups across reduction dimensions. This
+/// avoids subgroup data redistribution overhead between the reduced result and
+/// its consumer.
+///
+/// InstData requires {1, ..., min(maxReduceVectorSize, srcShape), subgroupSize}
+/// Lane Layout requires {1, ..., 1, subgroupSize}
+/// Lane data requires {1, ..., min(maxReduceVectorSize, srcShape), 1}
+
+xegpu::SliceAttr xegpu::setupMultiReductionResultLayout(
+ xegpu::LayoutKind layoutKind, VectorType srcVecTy,
+ DistributeLayoutAttr consumerLayout, SmallVector<int64_t> reductionDims,
+ const xegpu::uArch::uArch *uArch) {
+
+ auto srcShape = srcVecTy.getShape();
+ int srcRank = srcShape.size();
+ auto context = consumerLayout.getContext();
+
+ // Reduction layout requires at least 2D tensors
+ if (srcRank < 2)
+ return nullptr;
+
+ // Helper lambda to convert int64 vectors to int32 DenseArrayAttr
+ auto toInt32Attr = [&](ArrayRef<int64_t> vec) {
+ SmallVector<int32_t> vec32(vec.begin(), vec.end());
+ return DenseI32ArrayAttr::get(context, vec32);
+ };
+
+ // Extract original plain layout for workgroup/subgroup size recovery
+ xegpu::SliceAttr consumerSliceLayout =
+ dyn_cast<xegpu::SliceAttr>(consumerLayout);
+ DistributeLayoutAttr plainLayout =
+ consumerSliceLayout ? consumerSliceLayout.flatten().getParent()
+ : consumerLayout;
+
+ auto sgLayoutVec = plainLayout.getEffectiveSgLayoutAsInt();
----------------
Jianhui-Li wrote:
move to under subgroup layoutKind.
https://github.com/llvm/llvm-project/pull/179016
More information about the Mlir-commits
mailing list