[Mlir-commits] [mlir] [MLIR][XeGPU] Refactor layout propagation utilities (PR #179016)
Jianhui Li
llvmlistbot at llvm.org
Mon Feb 2 16:01:54 PST 2026
================
@@ -0,0 +1,830 @@
+//===---- XeGPUUtils.cpp - MLIR Utilities for XeGPUOps ------------------===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements utility methods for working with the XeGPU dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpls.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
+#include "mlir/Dialect/SCF/Transforms/Patterns.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/ValueRange.h"
+#include "mlir/Interfaces/LoopLikeInterface.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/Support/FormatVariadic.h"
+#include <cstdint>
+#include <numeric>
+
+using namespace mlir;
+
+/// Walks `op` and every operation nested inside it. For each operand and each
+/// result visited, queries its distribute layout and writes whatever the query
+/// returned back through setDistributeLayoutAttr. No filtering is applied to
+/// the queried value before it is written back.
+void xegpu::recoverTemporaryLayoutsDeprecated(Operation *op) {
+  op->walk([&](Operation *visited) {
+    // Operands first: the layout is looked up from the value feeding the use.
+    for (OpOperand &use : visited->getOpOperands()) {
+      auto operandLayout = getDistributeLayoutAttr(use.get());
+      setDistributeLayoutAttr(use, operandLayout);
+    }
+
+    // Then the results produced by the visited operation.
+    for (OpResult produced : visited->getOpResults()) {
+      auto resultLayout = getDistributeLayoutAttr(produced);
+      setDistributeLayoutAttr(produced, resultLayout);
+    }
+  });
+}
+
+/// Returns a copy of `attrs` in which every DistributeLayoutAttr value has been
+/// replaced by the result of its dropSgLayoutAndData(). Entries whose rewritten
+/// layout is null are omitted; attributes of any other kind are copied through
+/// unchanged. Relative order of the surviving entries is preserved.
+SmallVector<NamedAttribute>
+xegpu::dropSgLayoutAndDataOnAttrs(ArrayRef<NamedAttribute> attrs) {
+  SmallVector<NamedAttribute> filtered;
+  filtered.reserve(attrs.size());
+
+  for (auto entry : attrs) {
+    auto layout = dyn_cast<xegpu::DistributeLayoutAttr>(entry.getValue());
+    // Not a layout attribute: keep it verbatim.
+    if (!layout) {
+      filtered.push_back(entry);
+      continue;
+    }
+    // Layout attribute: keep it only if something remains after stripping.
+    if (auto stripped = layout.dropSgLayoutAndData())
+      filtered.emplace_back(entry.getName(), stripped);
+  }
+
+  return filtered;
+}
+
+/// Returns a copy of `attrs` in which every DistributeLayoutAttr value has been
+/// replaced by the result of its dropInstData(). Entries whose rewritten layout
+/// is null are omitted; attributes of any other kind are copied through
+/// unchanged. Relative order of the surviving entries is preserved.
+SmallVector<NamedAttribute>
+xegpu::dropInstDataOnAttrs(ArrayRef<NamedAttribute> attrs) {
+  SmallVector<NamedAttribute> filtered;
+  filtered.reserve(attrs.size());
+
+  for (auto entry : attrs) {
+    auto layout = dyn_cast<xegpu::DistributeLayoutAttr>(entry.getValue());
+    // Not a layout attribute: keep it verbatim.
+    if (!layout) {
+      filtered.push_back(entry);
+      continue;
+    }
+    // Layout attribute: keep it only if something remains after stripping.
+    if (auto stripped = layout.dropInstData())
+      filtered.emplace_back(entry.getName(), stripped);
+  }
+
+  return filtered;
+}
+
+// Walks `rootOp` and everything nested in it, and for each operand of vector
+// type looks up its distribute layout and attaches it to that operand.
+// Operands of any other type are ignored. If a vector operand has no layout,
+// an error is emitted on the owning operation and the walk stops.
+// Returns true when the walk ran to completion, false when it was interrupted.
+bool xegpu::recoverTemporaryLayouts(Operation *rootOp) {
+  WalkResult status = rootOp->walk([&](Operation *visited) -> WalkResult {
+    for (OpOperand &use : visited->getOpOperands()) {
+      Value operandValue = use.get();
+      // Only vector-typed operands carry a layout.
+      if (!isa<VectorType>(operandValue.getType()))
+        continue;
+
+      if (auto layout = xegpu::getDistributeLayoutAttr(operandValue)) {
+        xegpu::setDistributeLayoutAttr(use, layout);
+        continue;
+      }
+
+      // Missing layout on a vector operand: diagnose and abort the walk.
+      visited->emitError("Could not find layout attribute for operand ")
+          << use.getOperandNumber() << " of operation " << visited->getName();
+      return WalkResult::interrupt();
+    }
+    return WalkResult::advance();
+  });
+  return !status.wasInterrupted();
+}
+
+/// Removes the temporary layout attribute associated with `operandOrResult`
+/// from its owning operation. The attribute name comes from
+/// getTemporaryLayoutName; nothing happens unless the owner holds a
+/// DistributeLayoutAttr under that name.
+template <typename T, typename>
+void xegpu::removeLayoutAttr(const T &operandOrResult) {
+  Operation *ownerOp = operandOrResult.getOwner();
+  std::string attrName = xegpu::getTemporaryLayoutName(operandOrResult);
+  // Leave the owner untouched when no layout is stored under this name.
+  if (!ownerOp->hasAttrOfType<DistributeLayoutAttr>(attrName))
+    return;
+  ownerOp->removeAttr(attrName);
+}
+
+// Explicit instantiation for OpResult. The template body is defined in this
+// translation unit, so the specializations used elsewhere are instantiated
+// here.
+template void
+xegpu::removeLayoutAttr<mlir::OpResult>(const mlir::OpResult &result);
+
+// Explicit instantiation for OpOperand (same reason as for OpResult).
+template void
+xegpu::removeLayoutAttr<mlir::OpOperand>(const mlir::OpOperand &operand);
+
+/// Strips layout information from `op` and from every operation nested in it.
+/// For each visited operation this removes:
+///   * the temporary layout attribute of each operand and each result (via
+///     removeLayoutAttr), and
+///   * the attributes named "layout", "layout_a", "layout_b" and "layout_cd",
+///     when they hold a DistributeLayoutAttr.
+void xegpu::removeLayoutAttrs(Operation *op) {
+  op->walk([&](Operation *nestOp) {
+    for (OpOperand &opr : nestOp->getOpOperands())
+      removeLayoutAttr(opr);
+    for (OpResult result : nestOp->getOpResults())
+      removeLayoutAttr(result);
+    // The named layout attributes must be removed from the operation being
+    // visited, not from the walk root: checking the root here would leave
+    // them in place on every nested operation.
+    for (const char *attrName :
+         {"layout", "layout_a", "layout_b", "layout_cd"}) {
+      if (nestOp->hasAttrOfType<DistributeLayoutAttr>(attrName))
+        nestOp->removeAttr(attrName);
+    }
+  });
+}
+
+/// Infers the layout of a broadcast's source from the layout of its result.
+/// Only the ranks of `resShape` and `srcShape` are consulted. When the result
+/// has more dimensions than the source (e.g. 1D broadcast to 2D), the source
+/// layout is a slice of `resLayout` over the extra leading dimensions
+/// [0, rankGap). Otherwise `resLayout` is returned unchanged.
+xegpu::DistributeLayoutAttr
+xegpu::inferBroadcastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
+                                  ArrayRef<int64_t> resShape,
+                                  ArrayRef<int64_t> srcShape) {
+  // Number of leading dimensions present in the result but not in the source.
+  int rankGap = resShape.size() - srcShape.size();
+  if (rankGap <= 0)
+    return resLayout;
+
+  // The missing leading dims 0, 1, ..., rankGap - 1 are the broadcast dims.
+  SmallVector<int64_t> leadingDims(rankGap);
+  std::iota(leadingDims.begin(), leadingDims.end(), int64_t{0});
+
+  // Build a slice layout for the source on top of the result layout.
+  MLIRContext *ctx = resLayout.getContext();
+  return xegpu::SliceAttr::get(ctx, resLayout,
+                               DenseI64ArrayAttr::get(ctx, leadingDims));
+}
+
+/// Infers the source layout attribute for a reduction operation given the
+/// result layout attribute and reduced dims.
+xegpu::DistributeLayoutAttr
+xegpu::inferMultiReductionSourceLayout(xegpu::DistributeLayoutAttr resLayout,
+ SmallVector<int64_t> reduceDims) {
+
+ // assert the resLayout must be slice layout
+ assert(isa<xegpu::SliceAttr>(resLayout) &&
+ "reduction result layout must be slice layout");
+
+ // assert that the reduceDims must match with the slice dims of resLayout
+ xegpu::SliceAttr sliceLayout = dyn_cast<xegpu::SliceAttr>(resLayout);
+ auto sliceDims = sliceLayout.getDims().asArrayRef();
+ assert(reduceDims == sliceDims &&
----------------
Jianhui-Li wrote:
We do. For that use case, this assertion still holds; setupMultiReductionResultLayout ensures it.
https://github.com/llvm/llvm-project/pull/179016
More information about the Mlir-commits
mailing list