[Mlir-commits] [mlir] [MLIR][XeGPU] Refactor layout propagation utilities (PR #179016)

Igor Zamyatin llvmlistbot at llvm.org
Tue Feb 3 06:51:55 PST 2026


================
@@ -0,0 +1,810 @@
+//===- XeGPULayoutImpls.cpp - MLIR Utilities for XeGPUOps -----------------===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements layout utility functions for XeGPU dialect
+// transformation.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpls.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
+#include "mlir/Dialect/SCF/Transforms/Patterns.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/ValueRange.h"
+#include "mlir/Interfaces/LoopLikeInterface.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/Support/FormatVariadic.h"
+#include <cstdint>
+#include <numeric>
+
+using namespace mlir;
+
+/// For every operation visited by walking `op`, looks up the layout of each
+/// operand value and of each result with getDistributeLayoutAttr and stores it
+/// back on that operand/result with setDistributeLayoutAttr.
+void xegpu::recoverTemporaryLayoutsDeprecated(Operation *op) {
+  op->walk([](Operation *nested) {
+    // Operands: the layout is queried on the operand's value.
+    for (OpOperand &operand : nested->getOpOperands())
+      setDistributeLayoutAttr(operand, getDistributeLayoutAttr(operand.get()));
+
+    // Results: the layout is queried on the result itself.
+    for (OpResult res : nested->getOpResults())
+      setDistributeLayoutAttr(res, getDistributeLayoutAttr(res));
+  });
+}
+
+/// Returns a copy of `attrs` in which the value of every DistributeLayoutAttr
+/// entry is replaced by the result of dropSgLayoutAndData(). Entries whose
+/// reduced layout is null are omitted; all other attributes are forwarded
+/// unchanged.
+SmallVector<NamedAttribute>
+xegpu::dropSgLayoutAndDataOnAttrs(ArrayRef<NamedAttribute> attrs) {
+  SmallVector<NamedAttribute> result;
+  result.reserve(attrs.size());
+
+  for (NamedAttribute named : attrs) {
+    auto layout = dyn_cast<xegpu::DistributeLayoutAttr>(named.getValue());
+    if (!layout) {
+      // Not a layout attribute: keep as-is.
+      result.push_back(named);
+      continue;
+    }
+    if (auto reduced = layout.dropSgLayoutAndData())
+      result.emplace_back(named.getName(), reduced);
+  }
+
+  return result;
+}
+
+/// Returns a copy of `attrs` in which the value of every DistributeLayoutAttr
+/// entry is replaced by the result of dropInstData(). Entries whose reduced
+/// layout is null are omitted; all other attributes are forwarded unchanged.
+SmallVector<NamedAttribute>
+xegpu::dropInstDataOnAttrs(ArrayRef<NamedAttribute> attrs) {
+  SmallVector<NamedAttribute> result;
+  result.reserve(attrs.size());
+
+  for (NamedAttribute named : attrs) {
+    auto layout = dyn_cast<xegpu::DistributeLayoutAttr>(named.getValue());
+    if (!layout) {
+      // Not a layout attribute: keep as-is.
+      result.push_back(named);
+      continue;
+    }
+    if (auto reduced = layout.dropInstData())
+      result.emplace_back(named.getName(), reduced);
+  }
+
+  return result;
+}
+
+// For every operation visited by walking `rootOp`, attach to each vector-typed
+// operand the layout attribute found for its value. On the first vector
+// operand without a layout, an error is emitted on the owning operation and
+// the walk is interrupted.
+// Returns true if the walk ran to completion, false if it was interrupted.
+bool xegpu::recoverTemporaryLayouts(Operation *rootOp) {
+  WalkResult walkStatus = rootOp->walk([&](Operation *nestedOp) -> WalkResult {
+    for (OpOperand &opr : nestedOp->getOpOperands()) {
+      Value value = opr.get();
+      // Only vector-typed operands need a layout.
+      if (!isa<VectorType>(value.getType()))
+        continue;
+      if (auto layout = xegpu::getDistributeLayoutAttr(value)) {
+        xegpu::setDistributeLayoutAttr(opr, layout);
+        continue;
+      }
+      nestedOp->emitError("Could not find layout attribute for operand ")
+          << opr.getOperandNumber() << " of operation " << nestedOp->getName();
+      return WalkResult::interrupt();
+    }
+    return WalkResult::advance();
+  });
+  return !walkStatus.wasInterrupted();
+}
+
+/// Removes from the owner of `operandOrResult` the attribute stored under the
+/// name returned by getTemporaryLayoutName, provided the owner holds a
+/// DistributeLayoutAttr under that name. Otherwise does nothing.
+template <typename T, typename>
+void xegpu::removeLayoutAttr(const T &operandOrResult) {
+  Operation *ownerOp = operandOrResult.getOwner();
+  std::string attrName = xegpu::getTemporaryLayoutName(operandOrResult);
+  if (!ownerOp->hasAttrOfType<DistributeLayoutAttr>(attrName))
+    return;
+  ownerOp->removeAttr(attrName);
+}
+
+// Explicit instantiation: OpResult.
+template void
+xegpu::removeLayoutAttr<mlir::OpResult>(const mlir::OpResult &result);
+
+// Explicit instantiation: OpOperand.
+template void
+xegpu::removeLayoutAttr<mlir::OpOperand>(const mlir::OpOperand &operand);
+
+/// Strips every attribute whose value is a DistributeLayoutAttr from each
+/// operation visited by walking `op`.
+void xegpu::removeLayoutAttrs(Operation *op) {
+  op->walk([](Operation *nested) {
+    // Collect the names first; removal happens only after the scan is done.
+    SmallVector<StringAttr> layoutAttrNames;
+    for (NamedAttribute entry : nested->getAttrs()) {
+      if (isa<DistributeLayoutAttr>(entry.getValue()))
+        layoutAttrNames.push_back(entry.getName());
+    }
+
+    for (StringAttr name : layoutAttrNames)
+      nested->removeAttr(name);
+  });
+}
+
+/// Infers the source layout attribute of a broadcast operation from the
+/// result layout attribute, the result shape and the source shape.
+///
+/// When the source has fewer dims than the result, the returned layout is a
+/// SliceAttr over the result layout whose sliced dims are the missing leading
+/// dims. Otherwise the result layout is returned unchanged.
+xegpu::DistributeLayoutAttr
+xegpu::inferBroadcastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
+                                  ArrayRef<int64_t> resShape,
+                                  ArrayRef<int64_t> srcShape) {
+  // Number of leading result dims with no counterpart in the source.
+  int rankGap = resShape.size() - srcShape.size();
+  if (rankGap <= 0)
+    return resLayout;
+
+  // Handling the low-rank to high-rank broadcast (e.g., 1D to 2D) case: the
+  // missing leading dims [0, rankGap) are the broadcast dims.
+  SmallVector<int64_t> leadingDims;
+  for (int dim = 0; dim < rankGap; ++dim)
+    leadingDims.push_back(dim);
+
+  // The source layout is the result layout sliced along those dims.
+  auto *ctx = resLayout.getContext();
+  return xegpu::SliceAttr::get(ctx, resLayout,
+                               DenseI64ArrayAttr::get(ctx, leadingDims));
+}
+
+/// Infers the source layout attribute for a reduction operation given the
+/// result layout attribute and reduced dims.
+///
+/// `resLayout` must be a SliceAttr whose slice dims equal `reduceDims`; both
+/// preconditions are checked with assertions. Returns the parent layout of
+/// that slice layout.
+xegpu::DistributeLayoutAttr
+xegpu::inferMultiReductionSourceLayout(xegpu::DistributeLayoutAttr resLayout,
+                                       SmallVector<int64_t> reduceDims) {
+  // The result layout of a reduction must be a slice layout.
+  assert(isa<xegpu::SliceAttr>(resLayout) &&
+         "reduction result layout must be slice layout");
+  // cast<> (rather than an unchecked dyn_cast<>) states that the conversion is
+  // unconditional and never silently yields a null attribute.
+  auto sliceLayout = cast<xegpu::SliceAttr>(resLayout);
+
+  // The reduced dims must be exactly the dims sliced away in resLayout. The
+  // comparison is kept inside the assert so that no local variable becomes
+  // unused when assertions are compiled out.
+  assert(reduceDims == sliceLayout.getDims().asArrayRef() &&
+         "reduction dims must match with slice dims");
+
+  // The source layout is the parent of the slice layout.
+  return sliceLayout.getParent();
+}
+
+/// Infers the source layout attribute for a bitcast operation given the
+/// result layout attribute, result element type bitwidth, and source element
+/// type bitwidth.
+xegpu::DistributeLayoutAttr
+xegpu::inferBitCastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
+                                int resElemTyBitWidth, int srcElemTyBitWidth) {
+  // the result and source layout must be the same
+  // only adjust the sg_data, inst_data, lane_data accordingly
+  // based on the bitwidth ratio between source and result element type
+
+  SmallVector<int64_t> sgData = resLayout.getEffectiveSgDataAsInt();
+  SmallVector<int64_t> instData = resLayout.getEffectiveInstDataAsInt();
+  SmallVector<int64_t> laneData = resLayout.getEffectiveLaneDataAsInt();
+  size_t sgDataSize = sgData.size();
+  size_t instDataSize = instData.size();
+  size_t laneDataSize = laneData.size();
+  int64_t sgDataValue = -1;
+  int64_t instDataValue = -1;
+  int64_t laneDataValue = -1;
+  int64_t dim = resLayout.getRank() - 1;
+
+  if (srcElemTyBitWidth <= resElemTyBitWidth) {
+    int bitWidthRatio = resElemTyBitWidth / srcElemTyBitWidth;
+    if (sgDataSize)
+      sgDataValue = sgData[sgDataSize - 1] * bitWidthRatio;
+    if (instDataSize)
+      instDataValue = instData[instDataSize - 1] * bitWidthRatio;
+    if (laneDataSize)
+      laneDataValue = laneData[laneDataSize - 1] * bitWidthRatio;
+  } else {
+    int bitWidthRatio = srcElemTyBitWidth / resElemTyBitWidth;
+    if (sgDataSize) {
+      assert((sgData[sgDataSize - 1] % bitWidthRatio) == 0 &&
+             "sgData not divisible by bitWidthRatio");
+      sgDataValue = sgData[sgDataSize - 1] / bitWidthRatio;
+    }
+    if (instDataSize) {
+      assert((instData[instDataSize - 1] % bitWidthRatio) == 0 &&
+             "instData not divisible by bitWidthRatio");
+      instDataValue = instData[instDataSize - 1] / bitWidthRatio;
+    }
+    if (laneDataSize) {
+      assert((laneData[laneDataSize - 1] % bitWidthRatio) == 0 &&
+             "laneData not divisible by bitWidthRatio");
+      laneDataValue = laneData[laneDataSize - 1] / bitWidthRatio;
+    }
+  }
+
+  // Now set only instData and laneData, preserving sgData
----------------
Garra1980 wrote:

Maybe I'm missing something, but is this comment relevant?

https://github.com/llvm/llvm-project/pull/179016


More information about the Mlir-commits mailing list