[Mlir-commits] [mlir] [MLIR][XeGPU] Refactor layout propagation utilities (PR #179016)
Jianhui Li
llvmlistbot at llvm.org
Tue Feb 3 21:11:49 PST 2026
https://github.com/Jianhui-Li updated https://github.com/llvm/llvm-project/pull/179016
>From 20d44c2ed8d34a6f90426f0b5984732c1912ecf9 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Thu, 18 Dec 2025 17:46:38 +0000
Subject: [PATCH 01/35] add layout utitilies interface
---
.../Dialect/XeGPU/Utils/XeGPULayoutUtils.h | 131 +++++++
.../mlir/Dialect/XeGPU/Utils/XeGPUUtils.h | 64 +---
.../XeGPU/Transforms/XeGPUPropagateLayout.cpp | 41 +-
mlir/lib/Dialect/XeGPU/Utils/CMakeLists.txt | 1 +
.../Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp | 350 ++++++++++++++++++
mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp | 304 ---------------
6 files changed, 504 insertions(+), 387 deletions(-)
create mode 100644 mlir/include/mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h
create mode 100644 mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp
diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h
new file mode 100644
index 0000000000000..6f528214f538b
--- /dev/null
+++ b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h
@@ -0,0 +1,131 @@
+//===- XeGPULayoutUtils.h - Layout Utilities --------------------------*- C++
+//-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_XEGPU_UTILS_XEGPULAYOUTUTILS_H_
+#define MLIR_DIALECT_XEGPU_UTILS_XEGPULAYOUTUTILS_H_
+
+#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/OpDefinition.h"
+namespace mlir {
+
+class VectorType;
+class OpOperand;
+class OpResult;
+class OpBuilder;
+class ValueRange;
+class TypeConverter;
+class OpFoldResult;
+
+namespace xegpu {
+class DistributeLayoutAttr;
+class LayoutAttr;
+class TensorDescType;
+} // namespace xegpu
+
+namespace xegpu {
+
+/// Return the attribute name for the OpOperand to attach DistributeLayoutAttr
+std::string getTemporaryLayoutName(const OpOperand &operand);
+
+/// Return the attribute name for the OpResult to attach DistributeLayoutAttr
+std::string getTemporaryLayoutName(const OpResult result);
+
+/// Retrieves the DistributeLayoutAttr associated with a given Value. For
+/// TensorDescType values, the DistributeLayoutAttr is extracted from the
+/// TensorDescType itself. For other values, it is obtained from the attributes
+/// of the defining operation. Returns nullptr if no DistributeLayoutAttr is
+/// found.
+DistributeLayoutAttr getDistributeLayoutAttr(const Value value);
+
+/// Retrieves the DistributeLayoutAttr associated with a given OpOperand. It
+/// will first check the operand_layout_{id} of the owner operation. If not
+/// found, it will check the operand itself and its defining op.
+DistributeLayoutAttr getDistributeLayoutAttr(const OpOperand &opr);
+
+/// [to-be-deprecated] Sets the DistributeLayoutAttr for a given OpResult
+/// user should use setAnchorLayout instead
+void setDistributeLayoutAttr(const OpResult &Result,
+ const DistributeLayoutAttr layout);
+
+/// [to-be-deprecated] Sets the DistributeLayoutAttr for a given OpOperand
+/// user should use setAnchorLayout instead
+void setDistributeLayoutAttr(const OpOperand &opr,
+ const DistributeLayoutAttr layout);
+
+/// get and set distribute layout attribute for non-anchor operations
+/// (and offsets/masks of load/store ops before we get rid of their temp attrs)
+template <typename T,
+ typename = std::enable_if_t<std::is_same_v<T, OpOperand> ||
+ std::is_same_v<T, OpResult>>>
+DistributeLayoutAttr getTemporaryLayout(const T &operandOrResult);
+
+template <typename T,
+ typename = std::enable_if_t<std::is_same_v<T, OpOperand> ||
+ std::is_same_v<T, OpResult>>>
+void setTemporaryLayout(const T &operandOrResult,
+ const DistributeLayoutAttr layout);
+
+/// [to-be-deprecated] Set the DistributeLayoutAttr for each OpOperand and
+/// OpResult of the given operation. If the operation contains regions, it is
+/// also applied recursively to the contained operations.
+/// TODO: To be replaced by recoverTemporaryLayouts()
+void recoverTemporaryLayoutsDeprecated(Operation *op);
+
+/// Attach layout attributes to all vector-type operands of operations within
+/// the given operation's region. Reports an error if any vector operand lacks
+/// a layout attribute.
+bool recoverTemporaryLayouts(Operation *rootOp);
+
+/// Removes the LayoutAttr for a given OpOperand or OpResult if it exists.
+template <typename T,
+ typename = std::enable_if_t<std::is_same_v<T, OpOperand> ||
+ std::is_same_v<T, OpResult>>>
+void removeLayoutAttr(const T &operandOrResult);
+
+/// Removes the DistributeLayoutAttr for each OpOperand and OpResult of the
+/// given operation if they exist. If the operation contains regions, it is also
+/// applied recursively to the contained operations
+void removeLayoutAttrs(Operation *op);
+
+/// Infers the source layout attribute for a broadcast operation given the
+/// result layout attribute, result shape, and source shape.
+DistributeLayoutAttr inferBroadCastSourceLayout(MLIRContext *context,
+ DistributeLayoutAttr resLayout,
+ ArrayRef<int64_t> resShape,
+ ArrayRef<int64_t> srcShape);
+
+/// Infers the source layout attribute for a reduction operation given the
+/// result layout attribute, result shape, source shape, and reduced dims.
+DistributeLayoutAttr
+inferReductionSourceLayout(MLIRContext *context, DistributeLayoutAttr resLayout,
+ ArrayRef<int64_t> resShape,
+ ArrayRef<int64_t> srcShape,
+ SmallVector<int64_t> reduceDims);
+
+/// Infers the source layout attribute for a bitcast operation given the
+/// result layout attribute, result element type bitwidth, and source element
+/// type bitwidth.
+DistributeLayoutAttr inferBitCastSourceLayout(MLIRContext *context,
+ DistributeLayoutAttr resLayout,
+ int resElemTyBitWidth,
+ int srcElemTyBitWidth);
+
+/// Infers the source layout attribute for a shape cast operation given the
+/// result layout attribute, result shape, and source shape.
+DistributeLayoutAttr inferShapeCastSourceLayout(MLIRContext *context,
+ DistributeLayoutAttr resLayout,
+ ArrayRef<int64_t> resShape,
+ ArrayRef<int64_t> srcShape);
+
+} // namespace xegpu
+
+} // namespace mlir
+
+#endif // MLIR_DIALECT_XEGPU_UTILS_XEGPULAYOUTUTILS_H_
diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
index 46d52516cbc15..3dbbe7e4c5dff 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
@@ -9,6 +9,7 @@
#ifndef MLIR_DIALECT_XEGPU_UTILS_XEGPUUTILS_H_
#define MLIR_DIALECT_XEGPU_UTILS_XEGPUUTILS_H_
+#include "XeGPULayoutUtils.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
@@ -118,69 +119,6 @@ template <typename T>
int getLargestDivisor(T dim, ArrayRef<T> candidates,
ArrayRef<T> candidateMultiples = {});
-/// Return the attribute name for the OpOperand to attach DistributeLayoutAttr
-std::string getTemporaryLayoutName(const OpOperand &operand);
-
-/// Return the attribute name for the OpResult to attach DistributeLayoutAttr
-std::string getTemporaryLayoutName(const OpResult result);
-
-/// Retrieves the DistributeLayoutAttr associated with a given Value. For
-/// TensorDescType values, the DistributeLayoutAttr is extracted from the
-/// TensorDescType itself. For other values, it is obtained from the attributes
-/// of the defining operation. Returns nullptr if no DistributeLayoutAttr is
-/// found.
-DistributeLayoutAttr getDistributeLayoutAttr(const Value value);
-
-/// Retrieves the DistributeLayoutAttr associated with a given OpOperand. It
-/// will first check the operand_layout_{id} of the owner operation. If not
-/// found, it will check the operand itself and its defining op.
-DistributeLayoutAttr getDistributeLayoutAttr(const OpOperand &opr);
-
-/// Removes the LayoutAttr for a given OpOperand or OpResult if it exists.
-template <typename T,
- typename = std::enable_if_t<std::is_same_v<T, OpOperand> ||
- std::is_same_v<T, OpResult>>>
-void removeLayoutAttr(const T &operandOrResult);
-
-/// Removes the DistributeLayoutAttr for each OpOperand and OpResult of the
-/// given operation if they exist. If the operation contains regions, it is also
-/// applied recursively to the contained operations
-void removeLayoutAttrs(Operation *op);
-
-/// [to-be-deprecated] Sets the DistributeLayoutAttr for a given OpResult
-/// user should use setAnchorLayout instead
-void setDistributeLayoutAttr(const OpResult &Result,
- const DistributeLayoutAttr layout);
-
-/// [to-be-deprecated] Sets the DistributeLayoutAttr for a given OpOperand
-/// user should use setAnchorLayout instead
-void setDistributeLayoutAttr(const OpOperand &opr,
- const DistributeLayoutAttr layout);
-
-/// get and set distribute layout attribute for non-anchor operations
-/// (and offsets/masks of load/store ops before we get rid of their temp attrs)
-template <typename T,
- typename = std::enable_if_t<std::is_same_v<T, OpOperand> ||
- std::is_same_v<T, OpResult>>>
-DistributeLayoutAttr getTemporaryLayout(const T &operandOrResult);
-
-template <typename T,
- typename = std::enable_if_t<std::is_same_v<T, OpOperand> ||
- std::is_same_v<T, OpResult>>>
-void setTemporaryLayout(const T &operandOrResult,
- const DistributeLayoutAttr layout);
-
-/// [to-be-deprecated] Set the DistributeLayoutAttr for each OpOperand and
-/// OpResult of of the given operation. If the operation contains regions, it is
-/// also applied recursively to the contained operations operation.
-/// TODO: To be replaced by recoverTemporaryLayouts()
-void recoverTemporaryLayoutsDeprecated(Operation *op);
-
-/// Attach layout attributes to all vector-type operands of operations within
-/// the given operation's region. Reports an error if any vector operand lacks
-/// a layout attribute.
-bool recoverTemporaryLayouts(Operation *rootOp);
-
} // namespace xegpu
} // namespace mlir
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 7fc75e7294ea3..60dedf373a07b 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -601,8 +601,21 @@ void LayoutInfoPropagation::visitVectorMultiReductionOp(
LayoutInfo resultLayout = results[0]->getValue();
if (!resultLayout.isAssigned())
return;
- // We only consider 2D -> 1D reductions at this point.
+
VectorType resultTy = llvm::dyn_cast<VectorType>(reduction.getDestType());
+ VectorType sourceTy =
+ llvm::dyn_cast<VectorType>(reduction.getSourceVectorType());
+ SmallVector<int64_t> reductionDims(reduction.getReductionDims().begin(),
+ reduction.getReductionDims().end());
+ // xegpu::DistributeLayoutAttr operandLayout =
+ // xegpu::inferReductionSourceLayout(
+ // reduction.getContext(),
+ // dyn_cast<xegpu::DistributeLayoutAttr>(resultLayout.get()),
+ // resultTy.getShape(),
+ // sourceTy.getShape(),
+ // reductionDims);
+ // We only consider 2D -> 1D reductions at this point.
+
if (!resultTy || resultTy.getRank() != 1) {
reduction.emitWarning("Expecting output type to be 1D vector.");
return;
@@ -633,25 +646,13 @@ void LayoutInfoPropagation::visitVectorBroadCastOp(
// Hanlding broadcast from low-rank to high-rank (e.g., 1D to 2D) case.
if (sourceTy.getRank() != resultTy.getRank()) {
- auto sourceDims = sourceTy.getShape();
- auto resultDims = resultTy.getShape();
- SmallVector<int64_t> bcastDims;
- auto dimDiff = resultTy.getRank() - sourceTy.getRank();
- // adding the missing leading dims
- for (int i = 0; i < dimDiff; i++)
- bcastDims.push_back(i);
-
- // for the rest dims in the resultTy, if sourceTy dim is 1, then it's
- // broadcasted dim
- for (size_t i = 0; i < sourceDims.size(); i++)
- if ((sourceDims[i] == 1) && (resultDims[i + dimDiff] != 1))
- bcastDims.push_back(i + dimDiff);
-
- // create a slice layout for the source
- xegpu::SliceAttr sliceLayout = xegpu::SliceAttr::get(
- broadcast->getContext(),
- cast<xegpu::DistributeLayoutAttr>(resultLayout.get()),
- DenseI64ArrayAttr::get(broadcast->getContext(), bcastDims));
+ auto srcShape = sourceTy.getShape();
+ auto resShape = resultTy.getShape();
+ auto resultLayoutAttr =
+ dyn_cast<xegpu::DistributeLayoutAttr>(resultLayout.get());
+
+ xegpu::DistributeLayoutAttr sliceLayout = xegpu::inferBroadCastSourceLayout(
+ broadcast.getContext(), resultLayoutAttr, resShape, srcShape);
propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(sliceLayout)));
return;
diff --git a/mlir/lib/Dialect/XeGPU/Utils/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/Utils/CMakeLists.txt
index d9bf4a1461c27..bde8324aab5fb 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/CMakeLists.txt
+++ b/mlir/lib/Dialect/XeGPU/Utils/CMakeLists.txt
@@ -1,5 +1,6 @@
add_mlir_dialect_library(MLIRXeGPUUtils
XeGPUUtils.cpp
+ XeGPULayoutUtils.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/XeGPU/Utils
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp
new file mode 100644
index 0000000000000..2f68aa7bda924
--- /dev/null
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp
@@ -0,0 +1,350 @@
+//===---- XeGPULayoutUtils.cpp - Layout Utilities for XeGPU ---------------===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements layout utility methods for the XeGPU dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
+#include "mlir/Dialect/SCF/Transforms/Patterns.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/ValueRange.h"
+#include "mlir/Interfaces/LoopLikeInterface.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/Support/FormatVariadic.h"
+#include <cstdint>
+#include <numeric>
+
+using namespace mlir;
+
+std::string xegpu::getTemporaryLayoutName(const OpOperand &operand) {
+ const StringRef prefix("layout_operand_");
+ unsigned idx = const_cast<OpOperand &>(operand).getOperandNumber();
+ return llvm::formatv("{0}{1}", prefix, idx).str();
+}
+
+std::string xegpu::getTemporaryLayoutName(const OpResult result) {
+ const StringRef prefix = "layout_result_";
+ return llvm::formatv("{0}{1}", prefix, result.getResultNumber()).str();
+}
+
+xegpu::DistributeLayoutAttr xegpu::getDistributeLayoutAttr(const Value value) {
+ if (!value)
+ return nullptr;
+
+ if (auto tdescTy =
+ dyn_cast_if_present<xegpu::TensorDescType>(value.getType()))
+ return tdescTy.getLayoutAttr();
+
+ if (auto result = dyn_cast<OpResult>(value)) {
+ Operation *defOp = result.getDefiningOp();
+ assert(defOp && "result must have a defining op");
+
+ if (auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(defOp)) {
+ auto layout = anchorOp.getAnchorLayout();
+ return layout;
+ }
+
+ std::string layoutName = getTemporaryLayoutName(result);
+ if (defOp->hasAttr(layoutName)) {
+ auto layout =
+ defOp->getAttrOfType<xegpu::DistributeLayoutAttr>(layoutName);
+ return layout;
+ }
+ }
+
+ if (auto arg = dyn_cast<BlockArgument>(value)) {
+ auto *parentOp = arg.getOwner()->getParentOp();
+ if (auto loop = dyn_cast<LoopLikeOpInterface>(parentOp)) {
+ OpOperand *tiedInit = loop.getTiedLoopInit(arg);
+ if (tiedInit)
+ return getDistributeLayoutAttr(tiedInit->get());
+ }
+ }
+
+ return nullptr;
+}
+xegpu::DistributeLayoutAttr
+xegpu::getDistributeLayoutAttr(const OpOperand &opr) {
+ Operation *op = opr.getOwner();
+ unsigned idx = const_cast<OpOperand &>(opr).getOperandNumber();
+
+ if (auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(op)) {
+ if (auto dpasOp = dyn_cast<xegpu::DpasOp>(op)) {
+ if (idx == 0) {
+ return dpasOp.getLayoutAAttr();
+ } else if (idx == 1) {
+ return dpasOp.getLayoutBAttr();
+ } else if (idx == 2) {
+ return dpasOp.getLayoutCdAttr();
+ }
+ }
+ if (auto convertOp = dyn_cast<xegpu::ConvertLayoutOp>(op)) {
+ return convertOp.getInputLayoutAttr();
+ }
+ auto layout = anchorOp.getAnchorLayout();
+
+ if (idx == 0)
+ return layout;
+
+ // For store operations (StoreScatterOp, StoreNdOp, StoreMatrixOp),
+ // the layout is valid for the first two operands: value and memref/tdesc.
+ // For other operations, the layout applies to the first operand only.
+ if (isa<xegpu::StoreScatterOp, xegpu::StoreNdOp, xegpu::StoreMatrixOp>(
+ op) &&
+ (idx < 2))
+ return layout;
+ }
+
+ std::string layoutName = xegpu::getTemporaryLayoutName(opr);
+ if (op->hasAttr(layoutName)) {
+ auto layout = op->getAttrOfType<xegpu::DistributeLayoutAttr>(layoutName);
+ return layout;
+ }
+
+ auto layout = getDistributeLayoutAttr(opr.get());
+ return layout;
+}
+
+// TODO-LayoutRefactor: Remove this function after replacing use
+// with setTemporaryLayout or setAnchorLayout
+void xegpu::setDistributeLayoutAttr(
+ const mlir::OpResult &result,
+ const mlir::xegpu::DistributeLayoutAttr layout) {
+ Operation *owner = result.getOwner();
+
+ if (auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(owner)) {
+ if (anchorOp.getAnchorLayout() == layout)
+ return;
+ anchorOp.setAnchorLayout(layout);
+ return;
+ }
+
+ std::string name = xegpu::getTemporaryLayoutName(result);
+ if (owner->hasAttrOfType<DistributeLayoutAttr>(name)) {
+ return;
+ }
+ if (layout) {
+ owner->setAttr(name, layout);
+ }
+}
+
+// TODO-LayoutRefactor: Remove this function after replacing use
+// with setTemporaryLayout or setAnchorLayout
+void xegpu::setDistributeLayoutAttr(const OpOperand &operand,
+ const DistributeLayoutAttr layout) {
+ Operation *owner = operand.getOwner();
+ unsigned idx = const_cast<OpOperand &>(operand).getOperandNumber();
+
+ if (!layout) {
+ return;
+ }
+ if (auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(owner)) {
+ if (auto dpasOp = dyn_cast<xegpu::DpasOp>(owner)) {
+ if (idx == 0) {
+ return dpasOp.setLayoutAAttr(layout);
+ } else if (idx == 1) {
+ return dpasOp.setLayoutBAttr(layout);
+ } else if (idx == 2) {
+ return dpasOp.setLayoutCdAttr(layout);
+ }
+ }
+ if (auto convertOp = dyn_cast<xegpu::ConvertLayoutOp>(owner)) {
+ return convertOp.setInputLayoutAttr(layout);
+ }
+
+ // For store operations (StoreScatterOp, StoreNdOp, StoreMatrixOp),
+ // the layout is valid for the first two operands: value and memref/tdesc.
+ // For other operations, the layout applies to the first operand only.
+ if (isa<xegpu::StoreScatterOp, xegpu::StoreNdOp, xegpu::StoreMatrixOp>(
+ owner)) {
+ if (idx < 2) {
+ anchorOp.setAnchorLayout(layout);
+ }
+ } else {
+ if (idx == 0) {
+ anchorOp.setAnchorLayout(layout);
+ }
+ }
+ }
+
+ std::string name = xegpu::getTemporaryLayoutName(operand);
+ if (owner->hasAttrOfType<DistributeLayoutAttr>(name)) {
+ return;
+ }
+ if (layout) {
+ owner->setAttr(name, layout);
+ }
+}
+
+template <typename T, typename>
+xegpu::DistributeLayoutAttr
+xegpu::getTemporaryLayout(const T &operandOrResult) {
+ Operation *op = operandOrResult.getOwner();
+
+ std::string layoutName = xegpu::getTemporaryLayoutName(operandOrResult);
+ if (op->hasAttr(layoutName)) {
+ auto layout = op->getAttrOfType<xegpu::DistributeLayoutAttr>(layoutName);
+ return layout;
+ }
+
+ return nullptr;
+}
+
+template xegpu::DistributeLayoutAttr
+xegpu::getTemporaryLayout<mlir::OpResult>(const OpResult &result);
+template xegpu::DistributeLayoutAttr
+xegpu::getTemporaryLayout<mlir::OpOperand>(const OpOperand &operand);
+
+template <typename T, typename>
+void xegpu::setTemporaryLayout(const T &operandOrResult,
+ const xegpu::DistributeLayoutAttr layout) {
+ Operation *owner = operandOrResult.getOwner();
+ std::string name = xegpu::getTemporaryLayoutName(operandOrResult);
+ if (owner->hasAttrOfType<xegpu::DistributeLayoutAttr>(name)) {
+ return;
+ }
+ if (layout) {
+ owner->setAttr(name, layout);
+ }
+}
+
+template void xegpu::setTemporaryLayout<mlir::OpResult>(
+ const mlir::OpResult &result,
+ const mlir::xegpu::DistributeLayoutAttr layout);
+
+template void xegpu::setTemporaryLayout<mlir::OpOperand>(
+ const mlir::OpOperand &operand,
+ const mlir::xegpu::DistributeLayoutAttr layout);
+
+void xegpu::recoverTemporaryLayoutsDeprecated(Operation *op) {
+ op->walk([&](Operation *nestOp) {
+ for (OpOperand &opr : nestOp->getOpOperands()) {
+ auto layout = getDistributeLayoutAttr(opr.get());
+ setDistributeLayoutAttr(opr, layout);
+ }
+
+ for (OpResult result : nestOp->getOpResults()) {
+ auto layout = getDistributeLayoutAttr(result);
+ setDistributeLayoutAttr(result, layout);
+ }
+ });
+}
+
+/// Attach layout attributes to all vector-type operands of operations within
+/// the given operation's region. Reports an error if any vector operand lacks
+/// a layout attribute.
+bool xegpu::recoverTemporaryLayouts(Operation *rootOp) {
+ auto result = rootOp->walk([&](Operation *op) {
+ for (OpOperand &operand : op->getOpOperands()) {
+ // Layouts are needed for vector type only.
+ if (!isa<VectorType>(operand.get().getType()))
+ continue;
+ auto layout = xegpu::getDistributeLayoutAttr(operand.get());
+ if (!layout) {
+ op->emitError("Could not find layout attribute for operand ")
+ << operand.getOperandNumber() << " of operation " << op->getName();
+ return WalkResult::interrupt();
+ }
+ xegpu::setDistributeLayoutAttr(operand, layout);
+ }
+ return WalkResult::advance();
+ });
+ return !result.wasInterrupted();
+}
+
+template <typename T, typename>
+void xegpu::removeLayoutAttr(const T &operandOrResult) {
+ Operation *owner = operandOrResult.getOwner();
+ std::string name = xegpu::getTemporaryLayoutName(operandOrResult);
+ if (owner->hasAttrOfType<DistributeLayoutAttr>(name))
+ owner->removeAttr(name);
+}
+
+// Explicit instantiation for OpResult
+template void
+xegpu::removeLayoutAttr<mlir::OpResult>(const mlir::OpResult &result);
+
+// Explicit instantiation for OpOperand
+template void
+xegpu::removeLayoutAttr<mlir::OpOperand>(const mlir::OpOperand &operand);
+
+void xegpu::removeLayoutAttrs(Operation *op) {
+  // Strip layouts from `op` and from every operation nested inside it. All
+  // removals target the visited `nestOp`, not the captured root `op`, so
+  // that nested operations lose their named layout attributes as well.
+  op->walk([&](Operation *nestOp) {
+    // Temporary per-operand / per-result layout attributes.
+    for (OpOperand &opr : nestOp->getOpOperands())
+      removeLayoutAttr(opr);
+    for (OpResult result : nestOp->getOpResults())
+      removeLayoutAttr(result);
+    // Layouts stored under fixed attribute names on the operation itself.
+    for (StringRef name : {"layout", "layout_a", "layout_b", "layout_cd"})
+      if (nestOp->hasAttrOfType<DistributeLayoutAttr>(name))
+        nestOp->removeAttr(name);
+  });
+}
+
+/// Infers the source layout attribute for a broadcast operation given the
+/// result layout attribute, result shape, and source shape.
+xegpu::DistributeLayoutAttr xegpu::inferBroadCastSourceLayout(
+ MLIRContext *context, xegpu::DistributeLayoutAttr resLayout,
+ ArrayRef<int64_t> resShape, ArrayRef<int64_t> srcShape) {
+
+ SmallVector<int64_t> bcastDims;
+ int dimDiff = resShape.size() - srcShape.size();
+ // adding the missing leading dims
+ for (int i = 0; i < dimDiff; i++)
+ bcastDims.push_back(i);
+
+ // for the rest dims in the resultTy, if sourceTy dim is 1, then it's
+ // broadcasted dim
+ // for (size_t i = 0; i < srcShape.size(); i++)
+ // if ((srcShape[i] == 1) && (resShape[i + dimDiff] != 1))
+ // bcastDims.push_back(i + dimDiff);
+
+ // create a slice layout for the source
+ xegpu::SliceAttr sliceLayout = xegpu::SliceAttr::get(
+ context, resLayout, DenseI64ArrayAttr::get(context, bcastDims));
+
+ return sliceLayout;
+}
+
+// /// Infers the source layout attribute for a reduction operation given the
+// /// result layout attribute, result shape, source shape, and reduced dims.
+// xegpu::DistributeLayoutAttr xegpu::inferReductionSourceLayout(
+// MLIRContext *context,
+// xegpu::DistributeLayoutAttr resLayout,
+// SmallVector<int64_t> reduceDims){
+
+// // flatten the reslayout first
+// xegpu::DistributeLayoutAttr flatResLayout =
+// resLayout.flatten();
+// // then unslice the reduceDims in the flattened layout
+
+//}
+
+/// Infers the source layout attribute for a bitcast operation given the
+/// result layout attribute, result element type bitwidth, and source element
+/// type bitwidth.
+xegpu::DistributeLayoutAttr
+xegpu::inferBitCastSourceLayout(MLIRContext *context,
+ xegpu::DistributeLayoutAttr resLayout,
+ int resElemTyBitWidth, int srcElemTyBitWidth);
+
+/// Infers the source layout attribute for a shape cast operation given the
+/// result layout attribute, result shape, and source shape.
+xegpu::DistributeLayoutAttr xegpu::inferShapeCastSourceLayout(
+ MLIRContext *context, xegpu::DistributeLayoutAttr resLayout,
+ ArrayRef<int64_t> resShape, ArrayRef<int64_t> srcShape);
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index d3906e37ffbf1..181b7e9673fef 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -101,310 +101,6 @@ mlir::xegpu::getDistributedVectorType(VectorType originalType,
return xegpu::getDistributedVectorType(helperTdescTy);
}
-std::string xegpu::getTemporaryLayoutName(const OpOperand &operand) {
- const StringRef prefix("layout_operand_");
- unsigned idx = const_cast<OpOperand &>(operand).getOperandNumber();
- return llvm::formatv("{0}{1}", prefix, idx).str();
-}
-
-std::string xegpu::getTemporaryLayoutName(const OpResult result) {
- const StringRef prefix = "layout_result_";
- return llvm::formatv("{0}{1}", prefix, result.getResultNumber()).str();
-}
-
-xegpu::DistributeLayoutAttr xegpu::getDistributeLayoutAttr(const Value value) {
- if (!value)
- return nullptr;
-
- if (auto tdescTy =
- dyn_cast_if_present<xegpu::TensorDescType>(value.getType()))
- return tdescTy.getLayoutAttr();
-
- if (auto result = dyn_cast<OpResult>(value)) {
- Operation *defOp = result.getDefiningOp();
- assert(defOp && "result must have a defining op");
-
- if (auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(defOp)) {
- auto layout = anchorOp.getAnchorLayout();
- return layout;
- }
-
- std::string layoutName = getTemporaryLayoutName(result);
- if (defOp->hasAttr(layoutName)) {
- auto layout =
- defOp->getAttrOfType<xegpu::DistributeLayoutAttr>(layoutName);
- return layout;
- }
- }
-
- if (auto arg = dyn_cast<BlockArgument>(value)) {
- auto *parentOp = arg.getOwner()->getParentOp();
- if (auto loop = dyn_cast<LoopLikeOpInterface>(parentOp)) {
- OpOperand *tiedInit = loop.getTiedLoopInit(arg);
- if (tiedInit)
- return getDistributeLayoutAttr(tiedInit->get());
- }
- }
-
- return nullptr;
-}
-xegpu::DistributeLayoutAttr
-xegpu::getDistributeLayoutAttr(const OpOperand &opr) {
- Operation *op = opr.getOwner();
- unsigned idx = const_cast<OpOperand &>(opr).getOperandNumber();
-
- if (auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(op)) {
- if (auto dpasOp = dyn_cast<xegpu::DpasOp>(op)) {
- if (idx == 0) {
- return dpasOp.getLayoutAAttr();
- } else if (idx == 1) {
- return dpasOp.getLayoutBAttr();
- } else if (idx == 2) {
- return dpasOp.getLayoutCdAttr();
- }
- }
- if (auto convertOp = dyn_cast<xegpu::ConvertLayoutOp>(op)) {
- return convertOp.getInputLayoutAttr();
- }
- auto layout = anchorOp.getAnchorLayout();
-
- if (idx == 0)
- return layout;
-
- // For store operations (StoreScatterOp, StoreNdOp, StoreMatrixOp),
- // the layout is valid for the first two operands: value and memref/tdesc.
- // For other operations, the layout applies to the first operand only.
- if (isa<xegpu::StoreScatterOp, xegpu::StoreNdOp, xegpu::StoreMatrixOp>(
- op) &&
- (idx < 2))
- return layout;
- }
-
- std::string layoutName = xegpu::getTemporaryLayoutName(opr);
- if (op->hasAttr(layoutName)) {
- auto layout = op->getAttrOfType<xegpu::DistributeLayoutAttr>(layoutName);
- return layout;
- }
-
- auto layout = getDistributeLayoutAttr(opr.get());
- return layout;
-}
-
-// Returns the permanent layout attribute for the given result if it's
-// available on the defining op. Otherwise returns the provided layout.
-xegpu::DistributeLayoutAttr
-maybePickPermanentLayout(xegpu::DistributeLayoutAttr layout,
- const OpResult &result, mlir::Operation *owner,
- const std::string &name) {
- xegpu::DistributeLayoutAttr candidate = layout;
-
- if (auto loadOp = dyn_cast<xegpu::LoadGatherOp>(owner)) {
- if (auto perm = loadOp.getLayoutAttr())
- candidate = perm;
- }
-
- return candidate;
-}
-
-// Returns the permanent layout attribute for the given operand if it's
-// available on the defining op. Otherwise returns the provided layout.
-xegpu::DistributeLayoutAttr
-maybePickPermanentLayout(xegpu::DistributeLayoutAttr layout,
- const OpOperand &operand, mlir::Operation *owner,
- const std::string &name) {
- xegpu::DistributeLayoutAttr candidate = layout;
- unsigned idx = const_cast<OpOperand &>(operand).getOperandNumber();
-
- if (auto storeOp = dyn_cast<xegpu::StoreScatterOp>(owner)) {
- if (idx == 0) {
- if (auto perm = storeOp.getLayoutAttr())
- candidate = perm;
- }
- }
-
- return candidate;
-}
-
-// TODO-LayoutRefactor: Remove this function after replacing use
-// with setTemporaryLayout or setAnchorLayout
-void xegpu::setDistributeLayoutAttr(
- const mlir::OpResult &result,
- const mlir::xegpu::DistributeLayoutAttr layout) {
- Operation *owner = result.getOwner();
-
- if (auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(owner)) {
- if (anchorOp.getAnchorLayout() == layout)
- return;
- anchorOp.setAnchorLayout(layout);
- return;
- }
-
- std::string name = xegpu::getTemporaryLayoutName(result);
- if (owner->hasAttrOfType<DistributeLayoutAttr>(name)) {
- return;
- }
- if (layout) {
- owner->setAttr(name, layout);
- }
-}
-
-// TODO-LayoutRefactor: Remove this function after replacing use
-// with setTemporaryLayout or setAnchorLayout
-void xegpu::setDistributeLayoutAttr(const OpOperand &operand,
- const DistributeLayoutAttr layout) {
- Operation *owner = operand.getOwner();
- unsigned idx = const_cast<OpOperand &>(operand).getOperandNumber();
-
- if (!layout) {
- return;
- }
- if (auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(owner)) {
- if (auto dpasOp = dyn_cast<xegpu::DpasOp>(owner)) {
- if (idx == 0) {
- return dpasOp.setLayoutAAttr(layout);
- } else if (idx == 1) {
- return dpasOp.setLayoutBAttr(layout);
- } else if (idx == 2) {
- return dpasOp.setLayoutCdAttr(layout);
- }
- }
- if (auto convertOp = dyn_cast<xegpu::ConvertLayoutOp>(owner)) {
- return convertOp.setInputLayoutAttr(layout);
- }
-
- // For store operations (StoreScatterOp, StoreNdOp, StoreMatrixOp),
- // the layout is valid for the first two operands: value and memref/tdesc.
- // For other operations, the layout applies to the first operand only.
- if (isa<xegpu::StoreScatterOp, xegpu::StoreNdOp, xegpu::StoreMatrixOp>(
- owner)) {
- if (idx < 2) {
- anchorOp.setAnchorLayout(layout);
- }
- } else {
- if (idx == 0) {
- anchorOp.setAnchorLayout(layout);
- }
- }
- }
-
- std::string name = xegpu::getTemporaryLayoutName(operand);
- if (owner->hasAttrOfType<DistributeLayoutAttr>(name)) {
- return;
- }
- if (layout) {
- owner->setAttr(name, layout);
- }
-}
-
-template <typename T, typename>
-xegpu::DistributeLayoutAttr
-xegpu::getTemporaryLayout(const T &operandOrResult) {
- Operation *op = operandOrResult.getOwner();
-
- std::string layoutName = xegpu::getTemporaryLayoutName(operandOrResult);
- if (op->hasAttr(layoutName)) {
- auto layout = op->getAttrOfType<xegpu::DistributeLayoutAttr>(layoutName);
- return layout;
- }
-
- return nullptr;
-}
-
-template xegpu::DistributeLayoutAttr
-xegpu::getTemporaryLayout<mlir::OpResult>(const OpResult &result);
-template xegpu::DistributeLayoutAttr
-xegpu::getTemporaryLayout<mlir::OpOperand>(const OpOperand &operand);
-
-template <typename T, typename>
-void xegpu::setTemporaryLayout(const T &operandOrResult,
- const xegpu::DistributeLayoutAttr layout) {
- Operation *owner = operandOrResult.getOwner();
- std::string name = xegpu::getTemporaryLayoutName(operandOrResult);
- if (owner->hasAttrOfType<xegpu::DistributeLayoutAttr>(name)) {
- return;
- }
- if (layout) {
- owner->setAttr(name, layout);
- }
-}
-
-template void xegpu::setTemporaryLayout<mlir::OpResult>(
- const mlir::OpResult &result,
- const mlir::xegpu::DistributeLayoutAttr layout);
-
-template void xegpu::setTemporaryLayout<mlir::OpOperand>(
- const mlir::OpOperand &operand,
- const mlir::xegpu::DistributeLayoutAttr layout);
-
-void xegpu::recoverTemporaryLayoutsDeprecated(Operation *op) {
- op->walk([&](Operation *nestOp) {
- for (OpOperand &opr : nestOp->getOpOperands()) {
- auto layout = getDistributeLayoutAttr(opr.get());
- setDistributeLayoutAttr(opr, layout);
- }
-
- for (OpResult result : nestOp->getOpResults()) {
- auto layout = getDistributeLayoutAttr(result);
- setDistributeLayoutAttr(result, layout);
- }
- });
-}
-
-/// Attach layout attributes to all vector-type operands of operations within
-/// the given operation's region. Reports an error if any vector operand lacks
-/// a layout attribute.
-bool xegpu::recoverTemporaryLayouts(Operation *rootOp) {
- auto result = rootOp->walk([&](Operation *op) {
- for (OpOperand &operand : op->getOpOperands()) {
- // Layouts are needed for vector type only.
- if (!isa<VectorType>(operand.get().getType()))
- continue;
- auto layout = xegpu::getDistributeLayoutAttr(operand.get());
- if (!layout) {
- op->emitError("Could not find layout attribute for operand ")
- << operand.getOperandNumber() << " of operation " << op->getName();
- return WalkResult::interrupt();
- }
- xegpu::setDistributeLayoutAttr(operand, layout);
- }
- return WalkResult::advance();
- });
- return !result.wasInterrupted();
-}
-
-template <typename T, typename>
-void xegpu::removeLayoutAttr(const T &operandOrResult) {
- Operation *owner = operandOrResult.getOwner();
- std::string name = xegpu::getTemporaryLayoutName(operandOrResult);
- if (owner->hasAttrOfType<DistributeLayoutAttr>(name))
- owner->removeAttr(name);
-}
-
-// Explicit instantiation for OpResult
-template void
-xegpu::removeLayoutAttr<mlir::OpResult>(const mlir::OpResult &result);
-
-// Explicit instantiation for OpOperand
-template void
-xegpu::removeLayoutAttr<mlir::OpOperand>(const mlir::OpOperand &operand);
-
-void xegpu::removeLayoutAttrs(Operation *op) {
- op->walk([&](Operation *nestOp) {
- for (OpOperand &opr : nestOp->getOpOperands())
- removeLayoutAttr(opr);
- for (OpResult result : nestOp->getOpResults())
- removeLayoutAttr(result);
- if (op->hasAttrOfType<DistributeLayoutAttr>("layout"))
- op->removeAttr("layout");
- if (op->hasAttrOfType<DistributeLayoutAttr>("layout_a"))
- op->removeAttr("layout_a");
- if (op->hasAttrOfType<DistributeLayoutAttr>("layout_b"))
- op->removeAttr("layout_b");
- if (op->hasAttrOfType<DistributeLayoutAttr>("layout_cd"))
- op->removeAttr("layout_cd");
- });
-}
-
SmallVector<Value>
xegpu::extractVectorsWithShapeFromValue(OpBuilder &builder, Location loc,
Value value, ArrayRef<int64_t> shape) {
>From a45ca448e6900a024e2d22d2724af17c90ce9243 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Tue, 23 Dec 2025 23:56:39 +0000
Subject: [PATCH 02/35] add layout set up rule for reduction
---
.../mlir/Dialect/XeGPU/IR/XeGPUAttrs.td | 24 +-
.../Dialect/XeGPU/Utils/XeGPULayoutUtils.h | 14 +-
mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 148 +++++--
.../XeGPU/Transforms/XeGPUPropagateLayout.cpp | 98 +++--
.../Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp | 412 ++++++++++++++++--
mlir/test/Dialect/XeGPU/propagate-layout.mlir | 4 +-
6 files changed, 584 insertions(+), 116 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index 446f64fffa468..eae2cbe24a72c 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -236,6 +236,14 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> {
"FailureOr<SmallVector<Value>>",
"delinearizeId",
(ins "OpBuilder &": $builder, "Location":$loc, "Value":$linearId)>,
+ InterfaceMethod<[{Derive a new layout with sg_data, inst_data and lane_data set to the
+ specified values for the given dimension}],
+ "xegpu::DistributeLayoutAttr",
+ "setDimData",
+ (ins "int64_t": $dim,
+ "int64_t": $sgData,
+ "int64_t": $instData,
+ "int64_t": $laneData)>,
InterfaceMethod<[{Generates instructions to compute multidimensional coordinates for dist units
assigned to a level identified by linearId. The shape parameter
represents the higher-level problem size. Each level may access
@@ -501,10 +509,14 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttr]> {
}
//set the layout for the sepcified unit dims: sg_data, inst_data and lane_data to 1
- DistributeLayoutAttr setUnitDimData(SetVector<int64_t> unitDims);
+ DistributeLayoutAttr setUnitDimData(SetVector<int64_t> unitDims) const;
//set the layout for the sepcified unit dims: sg_lane and lane_layout to 1
- DistributeLayoutAttr setUnitDimLayout(SetVector<int64_t> unitDims);
+ DistributeLayoutAttr setUnitDimLayout(SetVector<int64_t> unitDims) const;
+
+ // Derive a new layout with sg_data, inst_data and lane_data set to the
+ // specified values for the given dimension
+ DistributeLayoutAttr setDimData(int64_t dim, int64_t sgData, int64_t instData, int64_t laneData);
/// Delinearizes a linear ID into its multidimensional indices
/// based on the effective level of the layout.
@@ -672,10 +684,14 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> {
}
//set the layout for the sepcified unit dims: sg_data, inst_data and lane_data to 1
- DistributeLayoutAttr setUnitDimData(SetVector<int64_t> unitDims);
+ DistributeLayoutAttr setUnitDimData(SetVector<int64_t> unitDims) const;
//set the layout for the sepcified unit dims: sg_lane and lane_layout to 1
- DistributeLayoutAttr setUnitDimLayout(SetVector<int64_t> unitDims);
+ DistributeLayoutAttr setUnitDimLayout(SetVector<int64_t> unitDims) const;
+
+ // Derive a new layout with sg_data, inst_data and lane_data set to the
+ // specified values for the given dimension
+ DistributeLayoutAttr setDimData(int64_t dim, int64_t sgData, int64_t instData, int64_t laneData);
/// flatten a nested SliceAttr, e.g., for 2-level nested SliceAttr
/// #xegpu.slice<#xegpu.slice<#xegpu.layout<sg_layout = [4, 8, 12]>, dims = [0]>, dims = [0]>
diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h
index 6f528214f538b..01cb43b73d5ca 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h
@@ -102,11 +102,9 @@ DistributeLayoutAttr inferBroadCastSourceLayout(MLIRContext *context,
ArrayRef<int64_t> srcShape);
/// Infers the source layout attribute for a reduction operation given the
-/// result layout attribute, result shape, source shape, and reduced dims.
+/// result layout attribute and reduced dims.
DistributeLayoutAttr
-inferReductionSourceLayout(MLIRContext *context, DistributeLayoutAttr resLayout,
- ArrayRef<int64_t> resShape,
- ArrayRef<int64_t> srcShape,
+inferReductionSourceLayout(DistributeLayoutAttr resLayout,
SmallVector<int64_t> reduceDims);
/// Infers the source layout attribute for a bitcast operation given the
@@ -124,6 +122,14 @@ DistributeLayoutAttr inferShapeCastSourceLayout(MLIRContext *context,
ArrayRef<int64_t> resShape,
ArrayRef<int64_t> srcShape);
+/// Sets the layout attribute for the result based on a preferred layout
+/// propagated from the consumer.
+/// The output must be a slice attribute.
+SliceAttr
+reductionLayoutSetupRule(ArrayRef<int64_t> srcShape,
+ SmallVector<int64_t> reductionDims,
+ DistributeLayoutAttr consumerPreferredLayout);
+
} // namespace xegpu
} // namespace mlir
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index ccf17da26c942..0ecfe3eac650c 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -400,7 +400,8 @@ bool LayoutAttr::isEqualTo(const xegpu::DistributeLayoutAttr &other) {
}
// set the layout for unit dims: sg_data, inst_data and lane_data to 1
-DistributeLayoutAttr LayoutAttr::setUnitDimData(SetVector<int64_t> unitDims) {
+DistributeLayoutAttr
+LayoutAttr::setUnitDimData(SetVector<int64_t> unitDims) const {
auto sgDataOpt = getSgData();
auto instDataOpt = getInstData();
auto laneDataOpt = getLaneData();
@@ -409,15 +410,14 @@ DistributeLayoutAttr LayoutAttr::setUnitDimData(SetVector<int64_t> unitDims) {
SmallVector<int32_t> instData;
SmallVector<int32_t> laneData;
- if (sgDataOpt) {
+ if (sgDataOpt)
sgData = llvm::to_vector(sgDataOpt.asArrayRef());
- }
- if (instDataOpt) {
+
+ if (instDataOpt)
instData = llvm::to_vector(instDataOpt.asArrayRef());
- }
- if (laneDataOpt) {
+
+ if (laneDataOpt)
laneData = llvm::to_vector(laneDataOpt.asArrayRef());
- }
for (auto dim : unitDims) {
if (dim < static_cast<int64_t>(sgData.size()))
@@ -441,19 +441,18 @@ DistributeLayoutAttr LayoutAttr::setUnitDimData(SetVector<int64_t> unitDims) {
}
// set the layout for the sepcified unit dims: sg_lane and lane_layout to 1
-DistributeLayoutAttr LayoutAttr::setUnitDimLayout(SetVector<int64_t> unitDims) {
+DistributeLayoutAttr
+LayoutAttr::setUnitDimLayout(SetVector<int64_t> unitDims) const {
auto sgLayoutOpt = getSgLayout();
auto laneLayoutOpt = getLaneLayout();
SmallVector<int32_t> sgLayout;
SmallVector<int32_t> laneLayout;
- if (sgLayoutOpt) {
+ if (sgLayoutOpt)
sgLayout = llvm::to_vector(sgLayoutOpt.asArrayRef());
- }
- if (laneLayoutOpt) {
+ if (laneLayoutOpt)
laneLayout = llvm::to_vector(laneLayoutOpt.asArrayRef());
- }
for (auto dim : unitDims) {
if (dim < static_cast<int64_t>(sgLayout.size()))
@@ -472,6 +471,47 @@ DistributeLayoutAttr LayoutAttr::setUnitDimLayout(SetVector<int64_t> unitDims) {
getLaneData(), getOrder());
}
+// Derive a new layout with sg_data, inst_data and lane_data set to the
+// specified values for the given dimension
+DistributeLayoutAttr LayoutAttr::setDimData(int64_t dim, int64_t sgData,
+ int64_t instData,
+ int64_t laneData) {
+ auto sgDataOpt = getSgData();
+ auto instDataOpt = getInstData();
+ auto laneDataOpt = getLaneData();
+
+ SmallVector<int32_t> sgDataVec;
+ SmallVector<int32_t> instDataVec;
+ SmallVector<int32_t> laneDataVec;
+
+ if (sgDataOpt)
+ sgDataVec = llvm::to_vector(sgDataOpt.asArrayRef());
+
+ if (instDataOpt)
+ instDataVec = llvm::to_vector(instDataOpt.asArrayRef());
+
+ if (laneDataOpt)
+ laneDataVec = llvm::to_vector(laneDataOpt.asArrayRef());
+
+ if (dim < static_cast<int64_t>(sgDataVec.size()))
+ sgDataVec[dim] = sgData;
+ if (dim < static_cast<int64_t>(instDataVec.size()))
+ instDataVec[dim] = instData;
+ if (dim < static_cast<int64_t>(laneDataVec.size()))
+ laneDataVec[dim] = laneData;
+
+ return LayoutAttr::get(
+ getContext(), getSgLayout(),
+ sgDataVec.empty() ? DenseI32ArrayAttr()
+ : DenseI32ArrayAttr::get(getContext(), sgDataVec),
+ instDataVec.empty() ? DenseI32ArrayAttr()
+ : DenseI32ArrayAttr::get(getContext(), instDataVec),
+ getLaneLayout(),
+ laneDataVec.empty() ? DenseI32ArrayAttr()
+ : DenseI32ArrayAttr::get(getContext(), laneDataVec),
+ getOrder());
+}
+
//===----------------------------------------------------------------------===//
// XeGPU_SliceAttr
//===----------------------------------------------------------------------===//
@@ -604,55 +644,83 @@ bool SliceAttr::isEqualTo(const xegpu::DistributeLayoutAttr &other) {
}
// Helper function to adjust unit dimensions from sliced space to parent space
+// say we have a parent shape of rank 4, and slice dims [1,3], so the sliced
+// shape is of rank 2, if we want to set unit dim [0] in sliced space, it maps
+// to dim [0] in parent space; if we want to set unit dim [1] in sliced space,
+// it maps to dim [2] in parent space.
static SetVector<int64_t>
-adjustUnitDimsWithSliceDims(const SetVector<int64_t> &unitDims,
- ArrayRef<int64_t> sliceDims) {
- // Reconstruct parent's non-sliced dimensions
-
- int64_t parentRank = sliceDims.size() + unitDims.size();
+mapDimsFromSlicedSpace(const SetVector<int64_t> &dimsInSlice,
+ ArrayRef<int64_t> sliceDims) {
+ // get the max of sliceDims and dimsInSlice to bound the parent space rank
+ int64_t maxDim = -1;
+ maxDim =
+ std::max(maxDim, *std::max_element(sliceDims.begin(), sliceDims.end()));
+ maxDim = std::max(maxDim,
+ *std::max_element(dimsInSlice.begin(), dimsInSlice.end()));
+ int64_t parentSpaceRank = maxDim + sliceDims.size() + 1;
+
+ // get remaining dims in parent space after applying slicing with parent's
+ // slice Dims
llvm::SmallDenseSet<int64_t> slicedDimsSet(sliceDims.begin(),
sliceDims.end());
- SmallVector<int64_t> nonSlicedDims;
- for (int64_t i = 0; i < parentRank; ++i) {
+ SmallVector<int64_t> remainingDims;
+ for (int64_t i = 0; i < parentSpaceRank; ++i) {
if (!slicedDimsSet.contains(i))
- nonSlicedDims.push_back(i);
+ remainingDims.push_back(i);
}
// Map unit dims from sliced space to parent space
- SetVector<int64_t> adjustUnitDims;
- for (auto dim : unitDims) {
- if (dim < static_cast<int64_t>(nonSlicedDims.size())) {
- adjustUnitDims.insert(nonSlicedDims[dim]);
- }
+ SetVector<int64_t> dimsInUnSlicedSpace;
+ for (auto dim : dimsInSlice) {
+ int64_t mappedDim = remainingDims[dim];
+ dimsInUnSlicedSpace.insert(mappedDim);
}
- return adjustUnitDims;
+ return dimsInUnSlicedSpace;
}
// set the layout for unit dims: sg_data, inst_data and lane_data to 1
-DistributeLayoutAttr SliceAttr::setUnitDimData(SetVector<int64_t> unitDims) {
- SliceAttr attr = flatten();
- ArrayRef<int64_t> sliceDims = attr.getDims().asArrayRef();
- auto parent = dyn_cast<LayoutAttr>(attr.getParent());
+DistributeLayoutAttr
+SliceAttr::setUnitDimData(SetVector<int64_t> unitDims) const {
+ DistributeLayoutAttr parentLayout = getParent();
+
+ ArrayRef<int64_t> sliceDims = getDims().asArrayRef();
SetVector<int64_t> adjustUnitDims =
- adjustUnitDimsWithSliceDims(unitDims, sliceDims);
+ mapDimsFromSlicedSpace(unitDims, sliceDims);
- return SliceAttr::get(getContext(), parent.setUnitDimData(adjustUnitDims),
- attr.getDims());
+ return SliceAttr::get(getContext(),
+ parentLayout.setUnitDimData(adjustUnitDims), getDims());
}
// set the layout for the sepcified unit dims: sg_lane and lane_layout to 1
-DistributeLayoutAttr SliceAttr::setUnitDimLayout(SetVector<int64_t> unitDims) {
- SliceAttr attr = flatten();
- ArrayRef<int64_t> sliceDims = attr.getDims().asArrayRef();
- auto parent = dyn_cast<LayoutAttr>(attr.getParent());
+DistributeLayoutAttr
+SliceAttr::setUnitDimLayout(SetVector<int64_t> unitDims) const {
+ DistributeLayoutAttr parentLayout = getParent();
+
+ ArrayRef<int64_t> sliceDims = getDims().asArrayRef();
SetVector<int64_t> adjustUnitDims =
- adjustUnitDimsWithSliceDims(unitDims, sliceDims);
+ mapDimsFromSlicedSpace(unitDims, sliceDims);
- return SliceAttr::get(getContext(), parent.setUnitDimLayout(adjustUnitDims),
- attr.getDims());
+ return SliceAttr::get(
+ getContext(), parentLayout.setUnitDimLayout(adjustUnitDims), getDims());
+}
+
+// Derive a new layout with sg_data, inst_data and lane_data set to the
+// specified values for the given dimension
+DistributeLayoutAttr SliceAttr::setDimData(int64_t dim, int64_t sgData,
+ int64_t instData, int64_t laneData) {
+ ArrayRef<int64_t> sliceDims = getDims().asArrayRef();
+ auto parent = dyn_cast<LayoutAttr>(getParent());
+
+ SetVector<int64_t> dimSet;
+ dimSet.insert(dim);
+ SetVector<int64_t> adjustDims = mapDimsFromSlicedSpace(dimSet, sliceDims);
+
+ return SliceAttr::get(
+ getContext(),
+ parent.setDimData(adjustDims[0], sgData, instData, laneData), getDims());
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 60dedf373a07b..1be44480de01d 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -128,6 +128,7 @@ struct LayoutInfo {
}
Attribute get() { return storage; }
+ void set(const xegpu::DistributeLayoutAttr &layout) { storage = layout; }
};
SmallVector<int> LayoutInfo::getLaneLayout() const {
@@ -607,25 +608,58 @@ void LayoutInfoPropagation::visitVectorMultiReductionOp(
llvm::dyn_cast<VectorType>(reduction.getSourceVectorType());
SmallVector<int64_t> reductionDims(reduction.getReductionDims().begin(),
reduction.getReductionDims().end());
- // xegpu::DistributeLayoutAttr operandLayout =
- // xegpu::inferReductionSourceLayout(
- // reduction.getContext(),
- // dyn_cast<xegpu::DistributeLayoutAttr>(resultLayout.get()),
- // resultTy.getShape(),
- // sourceTy.getShape(),
- // reductionDims);
- // We only consider 2D -> 1D reductions at this point.
-
- if (!resultTy || resultTy.getRank() != 1) {
- reduction.emitWarning("Expecting output type to be 1D vector.");
- return;
+
+ auto srcShape = sourceTy.getShape();
+
+ LLVM_DEBUG(DBGS() << "visitVectorMultiReductionOp: srcShape = [";
+ for (auto dim
+ : srcShape) llvm::dbgs()
+ << dim << " ";
+ llvm::dbgs() << "]\n");
+ LLVM_DEBUG(DBGS() << "visitVectorMultiReductionOp: reductionDims = [";
+ for (auto dim
+ : reductionDims) llvm::dbgs()
+ << dim << " ";
+ llvm::dbgs() << "]\n");
+
+ auto resultLayoutAttr =
+ dyn_cast<xegpu::DistributeLayoutAttr>(resultLayout.get());
+
+ LLVM_DEBUG(DBGS() << "visitVectorMultiReductionOp: resultLayoutAttr = "
+ << resultLayoutAttr << "\n");
+
+ // A dominant layout is for the result and represents the layout requirements
+ // of the operation. It is recorded as an anchor layout or temporary layout.
+ // It must be honored by the current op and may conflict with the layout
+ // propagated from the consumer op; the conflict is resolved in a later phase
+ // by converting the dominant layout to the source layout.
+
+ xegpu::DistributeLayoutAttr dominantLayout = xegpu::reductionLayoutSetupRule(
+ srcShape, reductionDims, resultLayoutAttr);
+
+ LLVM_DEBUG(DBGS() << "visitVectorMultiReductionOp: dominantLayout = "
+ << dominantLayout << "\n");
+
+ if (layoutKind == LayoutKind::Lane) {
+ // only lane layout/data is considered
+ dominantLayout = dominantLayout.dropInstData();
+ dominantLayout = dominantLayout.dropSgLayoutAndData();
+ } else if (layoutKind == LayoutKind::InstData) {
+ dominantLayout = dominantLayout.dropSgLayoutAndData();
}
- auto uArch = getUArch(xegpu::getChipStr(reduction).value_or(""));
- // Given that the result is 1D, the layout of the operand should be 2D with
- // default layout.
- LayoutInfo operandLayout = getDefaultSIMTLayoutInfo(
- reduction->getContext(), 2, uArch->getSubgroupSize());
- propagateIfChanged(operands[0], operands[0]->meet(operandLayout));
+
+ // record the dominant layout to the reduction op
+ xegpu::setTemporaryLayout(reduction->getResult(0), dominantLayout);
+
+ // derive the source layout from the dominant layout and reduction dims
+ auto srcLayoutAttr =
+ xegpu::inferReductionSourceLayout(dominantLayout, reductionDims);
+ // void set(const xegpu::DistributeLayoutAttr &layout) { storage = layout; }
+
+ LLVM_DEBUG(DBGS() << "visitVectorMultiReductionOp: srcLayoutAttr = "
+ << srcLayoutAttr << "\n");
+
+ propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
// Accumulator should have the same layout as the result.
propagateIfChanged(operands[1], operands[1]->meet(resultLayout));
}
@@ -637,6 +671,7 @@ void LayoutInfoPropagation::visitVectorBroadCastOp(
LayoutInfo resultLayout = results[0]->getValue();
if (!resultLayout.isAssigned())
return;
+
// Only consider vector to vector broadcasts for now.
VectorType resultTy = broadcast.getResultVectorType();
VectorType sourceTy = dyn_cast<VectorType>(broadcast.getSourceType());
@@ -644,20 +679,23 @@ void LayoutInfoPropagation::visitVectorBroadCastOp(
if (!sourceTy)
return;
- // Hanlding broadcast from low-rank to high-rank (e.g., 1D to 2D) case.
- if (sourceTy.getRank() != resultTy.getRank()) {
- auto srcShape = sourceTy.getShape();
- auto resShape = resultTy.getShape();
- auto resultLayoutAttr =
- dyn_cast<xegpu::DistributeLayoutAttr>(resultLayout.get());
+ auto srcShape = sourceTy.getShape();
+ auto resShape = resultTy.getShape();
- xegpu::DistributeLayoutAttr sliceLayout = xegpu::inferBroadCastSourceLayout(
- broadcast.getContext(), resultLayoutAttr, resShape, srcShape);
+ size_t dimDiff = resultTy.getRank() - sourceTy.getRank();
+ for (size_t i = 0; i < srcShape.size(); i++)
+ if ((srcShape[i] == 1) && (resShape[i + dimDiff] != 1))
+ broadcast.emitWarning("broadcast must either from low-rank or same-rank "
+ "with unit-dim, mixed scenario is not supported!");
- propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(sliceLayout)));
- return;
- }
- propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
+ auto resultLayoutAttr =
+ dyn_cast<xegpu::DistributeLayoutAttr>(resultLayout.get());
+
+ xegpu::DistributeLayoutAttr resLayout = xegpu::inferBroadCastSourceLayout(
+ broadcast.getContext(), resultLayoutAttr, resShape, srcShape);
+
+ propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(resLayout)));
+ return;
}
void LayoutInfoPropagation::visitShapeCastOp(
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp
index 2f68aa7bda924..58e222812661f 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp
@@ -263,6 +263,35 @@ bool xegpu::recoverTemporaryLayouts(Operation *rootOp) {
return !result.wasInterrupted();
}
+// Prerequisite for Layout Recovery
+// It relies on the following invariant:
+// 1. there is no layout conflict between different uses of the same definition.
+// 2. each definition has a well-defined layout requirement at its use point.
+// - Every definition must have at least one use that appears after it in
+// topological order.
+// - If a definition has no such use (e.g., a loop result or region output),
+// an explicit convert_layout operation is inserted to create a use.
+// - Only the result of convert_layout is permitted to have no subsequent
+// use.
+
+// The recovery proceeds by scanning the operations in reverse topological
+// order as follows. Across operations: layouts are propagated from uses to
+// definitions. Within an operation: layouts are propagated from definitions
+// (results) to uses (operands).
+// For region operations (e.g., loops):
+// - When backward propagation reaches a region op, it sets the layout of
+// the region op’s results according to use points like regular ops.
+// - Then, the result layouts (such as a loop output) are propagated to
+// their corresponding operands in the yield.
+// - When backward propagation reaches the first operation inside the
+// region, the pass examines the region op’s initialization list,
+// propagating from region arguments to the corresponding initialization
+// operands.
+// - This ensures that layout constraints are consistently propagated
+// across region boundaries
+// while preserving a single well-defined use for each definition at the
+// region-op level.
+
template <typename T, typename>
void xegpu::removeLayoutAttr(const T &operandOrResult) {
Operation *owner = operandOrResult.getOwner();
@@ -297,54 +326,365 @@ void xegpu::removeLayoutAttrs(Operation *op) {
}
/// Infers the source layout attribute for a broadcast operation given the
-/// result layout attribute, result shape, source shape, and broadcasted dims.
+/// result layout attribute, result shape, source shape.
xegpu::DistributeLayoutAttr xegpu::inferBroadCastSourceLayout(
MLIRContext *context, xegpu::DistributeLayoutAttr resLayout,
ArrayRef<int64_t> resShape, ArrayRef<int64_t> srcShape) {
SmallVector<int64_t> bcastDims;
+ auto returnLayout = resLayout;
+
+ // Handling broadcast from low-rank to high-rank (e.g., 1D to 2D) case.
int dimDiff = resShape.size() - srcShape.size();
- // adding the missing leading dims
- for (int i = 0; i < dimDiff; i++)
- bcastDims.push_back(i);
- // for the rest dims in the resultTy, if sourceTy dim is 1, then it's
- // broadcasted dim
- // for (size_t i = 0; i < srcShape.size(); i++)
- // if ((srcShape[i] == 1) && (resShape[i + dimDiff] != 1))
- // bcastDims.push_back(i + dimDiff);
+ if (dimDiff > 0) {
+ // adding the missing leading dims
+ for (int i = 0; i < dimDiff; i++)
+ bcastDims.push_back(i);
+
+ // create a slice layout for the source
+ returnLayout = xegpu::SliceAttr::get(
+ context, resLayout, DenseI64ArrayAttr::get(context, bcastDims));
+ }
+ return returnLayout;
+}
+
+/// Infers the source layout attribute for a reduction operation given the
+/// result layout attribute and reduced dims.
+xegpu::DistributeLayoutAttr
+xegpu::inferReductionSourceLayout(xegpu::DistributeLayoutAttr resLayout,
+ SmallVector<int64_t> reduceDims) {
+
+ // assert the resLayout must be slice layout
+ assert(isa<xegpu::SliceAttr>(resLayout) &&
+ "reduction result layout must be slice layout");
- // create a slice layout for the source
- xegpu::SliceAttr sliceLayout = xegpu::SliceAttr::get(
- context, resLayout, DenseI64ArrayAttr::get(context, bcastDims));
+ // assert that the reduceDims must match with the slice dims of resLayout
+ xegpu::SliceAttr sliceLayout = dyn_cast<xegpu::SliceAttr>(resLayout);
+ auto sliceDims = sliceLayout.getDims().asArrayRef();
+ assert(reduceDims == sliceDims &&
+ "reduction dims must match with slice dims");
- return sliceLayout;
+ // then return the parent layout of sliceLayout
+ return sliceLayout.getParent();
}
-// /// Infers the source layout attribute for a reduction operation given the
-// /// result layout attribute, result shape, source shape, and reduced dims.
-// xegpu::DistributeLayoutAttr xegpu::inferReductionSourceLayout(
-// MLIRContext *context,
-// xegpu::DistributeLayoutAttr resLayout,
-// SmallVector<int64_t> reduceDims){
+// /// Infers the source layout attribute for a bitcast operation given the
+// /// result layout attribute, result element type bitwidth, and source element
+// /// type bitwidth.
+// xegpu::DistributeLayoutAttr
+// xegpu::inferBitCastSourceLayout(MLIRContext *context,
+// xegpu::DistributeLayoutAttr resLayout,
+// int resElemTyBitWidth, int
+// srcElemTyBitWidth){
+// // the result and source layout must be the same
+// // if resLayout is SliceAttr, we need to first get its root layout
+// xegpu::DistributeLayoutAttr resRootLayout = resLayout;
+// while (auto sliceLayout = dyn_cast<xegpu::SliceAttr>(resRootLayout)) {
+// resRootLayout = sliceLayout.getParent();
+// }
+// // change the laneData of resRootLayout according to the bitwidth ratio
+// xegpu::LayoutAttr resRootPlainLayout =
+// dyn_cast<xegpu::LayoutAttr>(resRootLayout); SmallVector<int64_t> laneData =
+// resRootPlainLayout.getEffectiveLaneDataAsInt();
+
+// if (srcElemTyBitWidth >= resElemTyBitWidth) {
+// int bitWidthRatio = srcElemTyBitWidth / resElemTyBitWidth;
+// laneData[laneData.size()-1] = laneData[laneData.size()-1] *
+// bitWidthRatio;
+// } else {
+// int bitWidthRatio = resElemTyBitWidth / srcElemTyBitWidth;
+// assert((laneData[laneData.size()-2] % bitWidthRatio) == 0 &&
+// "laneData not divisible by bitWidthRatio");
+// laneData[laneData.size()-1] = laneData[laneData.size()-1] /
+// bitWidthRatio;
+// }
+
+// // now reconstruct the source layout with updated laneData
+// // by updating the root layout and going throught the slice layers
+// SmallVector<int32_t> laneData32(laneData.begin(), laneData.end());
+// xegpu::LayoutAttr proposedSrcLayout = xegpu::LayoutAttr::get(
+// context,
+// resRootPlainLayout.getSgLayout(),
+// resRootPlainLayout.getSgData(),
+// resRootPlainLayout.getInstData(),
+// resRootPlainLayout.getLaneLayout(),
+// DenseI32ArrayAttr::get(context, laneData32),
+// resRootPlainLayout.getOrder());
+
+// // reconstruct slice layers if any
+// // First collect all slice layers from innermost to outermost
+// SmallVector<DenseI64ArrayAttr> sliceDims;
+// xegpu::DistributeLayoutAttr currentLayout = resLayout;
+// while (auto sliceLayout = dyn_cast<xegpu::SliceAttr>(currentLayout)) {
+// sliceDims.push_back(sliceLayout.getDims());
+// currentLayout = sliceLayout.getParent();
+// }
+
+// // Now rebuild from outermost to innermost (reverse order)
+// xegpu::DistributeLayoutAttr finalSrcLayout = proposedSrcLayout;
+// for (auto it = sliceDims.rbegin(); it != sliceDims.rend(); ++it) {
+// finalSrcLayout = xegpu::SliceAttr::get(context, finalSrcLayout, *it);
+// }
+// return finalSrcLayout;
+// }
+
+// /// Infers the source layout attribute for a shape cast operation given the
+// /// result layout attribute, result shape, and source shape.
+// xegpu::DistributeLayoutAttr xegpu::inferShapeCastSourceLayout(
+// MLIRContext *context, xegpu::DistributeLayoutAttr resLayout,
+// ArrayRef<int64_t> resShape, ArrayRef<int64_t> srcShape){
+
+// // there are two use cases:
+// // 1. expand dims of low-rank dimensions (e.g., 1D to 2D): to set up the
+// tensor before broadcast
+// // 2. split dim of a high-rank dimension (e.g., 1D to 2D): to setup tensor
+// for multi-stage reduction
+
+// SmallVector<int64_t> shapeCastDims;
+// auto returnLayout = resLayout;
+
+// int resRank = resShape.size();
+// int srcRank = srcShape.size();
+
+// if (srcRank < resRank) {
+// // Case 1: expand dims of low-rank dimensions (e.g., 1D to 2D)
+// int dimDiff = resRank - srcRank;
+// // adding the missing leading dims
+// for (int i = 0; i < dimDiff; i++)
+// shapeCastDims.push_back(i);
+
+// // create a slice layout for the source
+// returnLayout = xegpu::SliceAttr::get(
+// context, resLayout, DenseI64ArrayAttr::get(context, shapeCastDims));
+// } else if (srcRank > resRank) {
+// // Case 2: split dim of a high-rank dimension (e.g., 1D to 2D)
+// // find the split dims by comparing srcShape and resShape
+// int srcIdx = 0;
+// int resIdx = 0;
+// while (srcIdx < srcRank && resIdx < resRank) {
+// if (srcShape[srcIdx] == resShape[resIdx]) {
+// srcIdx++;
+// resIdx++;
+// } else if (srcShape[srcIdx] < resShape[resIdx]) {
+// shapeCastDims.push_back(srcIdx);
+// srcIdx++;
+// } else {
+// // this should not happen in valid shape cast
+// assert(false && "Invalid shape cast: source shape dimension smaller
+// than result shape dimension");
+// }
+// }
+// // handle remaining src dims
+// while (srcIdx < srcRank) {
+// shapeCastDims.push_back(srcIdx);
+// srcIdx++;
+// }
+
+// // create a slice layout for the source
+// returnLayout = xegpu::SliceAttr::get(
+// context, resLayout, DenseI64ArrayAttr::get(context, shapeCastDims));
+// }
+// return returnLayout;
+
+// }
+
+xegpu::SliceAttr
+xegpu::reductionLayoutSetupRule(ArrayRef<int64_t> srcShape,
+ SmallVector<int64_t> reductionDims,
+ DistributeLayoutAttr consumerPreferredLayout) {
+
+ xegpu::SliceAttr sliceCPL =
+ dyn_cast<xegpu::SliceAttr>(consumerPreferredLayout);
+
+ // try to align with the consumer's preferred layout so that the slice layout
+ // structure is preserved, and thus avoid potential data movement across sg or
+ // lanes.
+
+ const int workgroupSize = 16; // assuming 16 subgroups for now
+ const int subgroupSize = 16; // assuming 16 lanes per subgroup
+ const int vectorSize = 8; // assuming 8 elements per vector lane
+ int srcShapeSize = srcShape.size();
+ xegpu::DistributeLayoutAttr proposedSrcLayout;
+ auto context = consumerPreferredLayout.getContext();
+ // if srcShapeSize is less than 2, we cannot proceed
+ if (srcShapeSize < 2)
+ return nullptr;
-// // flatten the reslayout first
-// xegpu::DistributeLayoutAttr flatResLayout =
-// resLayout.flatten();
-// // then unslice the reduceDims in the flattened layout
+ llvm::errs() << "DEBUG: Entering \n";
-//}
+ SmallVector<int64_t> sgLayout(srcShapeSize);
+ SmallVector<int64_t> sgData(srcShapeSize);
-/// Infers the source layout attribute for a bitcast operation given the
-/// result layout attribute, result element type bitwidth, and source element
-/// type bitwidth.
-xegpu::DistributeLayoutAttr
-xegpu::inferBitCastSourceLayout(MLIRContext *context,
- xegpu::DistributeLayoutAttr resLayout,
- int resElemTyBitWidth, int srcElemTyBitWidth);
+ SmallVector<int64_t> instData(srcShapeSize, 1);
+ instData[srcShapeSize - 1] = subgroupSize;
+ instData[srcShapeSize - 2] =
+ vectorSize; // assuming 8 elements per instruction as starting point
+ llvm::errs() << "DEBUG: Initial instData = [";
+ for (size_t i = 0; i < instData.size(); i++) {
+ llvm::errs() << instData[i];
+ if (i < instData.size() - 1)
+ llvm::errs() << ", ";
+ }
+ llvm::errs() << "]\n";
+ // construct a vector layout with lane_layout = [1, ..., 1, subgroupSize]
+ SmallVector<int64_t> laneLayout(srcShapeSize, 1);
+ laneLayout[srcShapeSize - 1] = subgroupSize;
+ llvm::errs() << "DEBUG: laneLayout = [";
+ for (size_t i = 0; i < laneLayout.size(); i++) {
+ llvm::errs() << laneLayout[i];
+ if (i < laneLayout.size() - 1)
+ llvm::errs() << ", ";
+ }
+ llvm::errs() << "]\n";
+ // construct a vector layout with lane_data = [1, ..., 1]
+ SmallVector<int64_t> laneData(srcShapeSize, 1);
+
+ bool failToAlignSliceStruct = false;
+ if (sliceCPL && sliceCPL.getDims().asArrayRef().equals(reductionDims)) {
+
+ xegpu::DistributeLayoutAttr parentCPL = sliceCPL.getParent();
+
+ // for each slice dim in source shape, if the dim size is differnt than the
+ // result shape, try to adjust the sg_data/inst_data accordingly.
+ SmallVector<int64_t> pcplSgLayout = parentCPL.getEffectiveSgLayoutAsInt();
+ SmallVector<int64_t> pcplLaneLayout =
+ parentCPL.getEffectiveLaneLayoutAsInt();
+ SmallVector<int64_t> pcplLaneData = parentCPL.getEffectiveLaneDataAsInt();
+
+ assert(srcShapeSize == parentCPL.getRank() &&
+ "parent layout rank must match source shape rank");
+
+ proposedSrcLayout = parentCPL;
+
+ llvm::errs() << "DEBUG: srcShapeSize = " << srcShapeSize << "\n";
+ llvm::errs() << "DEBUG: parentCPL rank = " << parentCPL.getRank() << "\n";
+ llvm::errs() << "DEBUG: srcShape = [";
+ for (int i = 0; i < srcShapeSize; i++) {
+ llvm::errs() << srcShape[i];
+ if (i < srcShapeSize - 1)
+ llvm::errs() << ", ";
+ }
+ llvm::errs() << "]\n";
+
+ if (pcplSgLayout.size() == static_cast<size_t>(srcShapeSize)) {
+ for (int i = 0; i < srcShapeSize; i++) {
+ if (srcShape[i] % pcplSgLayout[i] == 0) {
+ sgLayout[i] = pcplSgLayout[i];
+ sgData[i] = srcShape[i] / sgLayout[i];
+ instData[i] = std::min(instData[i], sgData[i]);
+ } else {
+ failToAlignSliceStruct = true;
+ break;
+ }
+ }
+ }
-/// Infers the source layout attribute for a shape cast operation given the
-/// result layout attribute, result shape, and source shape.
-xegpu::DistributeLayoutAttr xegpu::inferShapeCastSourceLayout(
- MLIRContext *context, xegpu::DistributeLayoutAttr resLayout,
- ArrayRef<int64_t> resShape, ArrayRef<int64_t> srcShape);
+ if (pcplLaneLayout.size() == static_cast<size_t>(srcShapeSize)) {
+ for (int i = 0; i < srcShapeSize; i++) {
+ if (instData[i] % pcplLaneLayout[i] == 0) {
+ laneLayout[i] = pcplLaneLayout[i];
+ laneData[i] = pcplLaneData[i];
+ } else {
+ failToAlignSliceStruct = true;
+ break;
+ }
+ }
+ }
+ } else {
+ failToAlignSliceStruct = true;
+ }
+
+ if (failToAlignSliceStruct) {
+
+ // try to align the sg layout
+ SmallVector<int64_t> cplSgLayout =
+ consumerPreferredLayout.getEffectiveSgLayoutAsInt();
+ llvm::errs() << "DEBUG: cplSgLayout size = " << cplSgLayout.size() << "\n";
+ // if sg layout doesn't cover all the sg ids, distribute rest to
+ // non-reduction dims
+ int remainingSgCount = workgroupSize;
+
+ SmallVector<int64_t> remainingDims;
+ // print debug info for consumerPreferredLayout and cplSgLayout
+ llvm::errs() << "DEBUG: consumerPreferredLayout sgLayout = [";
+ auto cplSgLayoutFull = consumerPreferredLayout.getEffectiveSgLayoutAsInt();
+ for (size_t i = 0; i < cplSgLayoutFull.size(); i++) {
+ llvm::errs() << cplSgLayoutFull[i];
+ if (i < cplSgLayoutFull.size() - 1)
+ llvm::errs() << ", ";
+ }
+ // if cplSgLayout is not empty, try to align the sg layout first
+ int cplId = cplSgLayout.size() - 1;
+ llvm::errs() << "DEBUG: Starting first loop, cplId = " << cplId << "\n";
+ for (int i = srcShapeSize - 1; i >= 0; i--) {
+ llvm::errs() << "DEBUG: Loop 1, i = " << i << ", is_reduction_dim = "
+ << llvm::is_contained(reductionDims, i) << "\n";
+ // try to align with cplSgLayout first for non-reduction dims
+ if (!llvm::is_contained(reductionDims, i) && cplId >= 0) {
+ if (srcShape[i] % cplSgLayout[cplId] == 0) {
+ sgLayout[i] = cplSgLayout[cplId];
+ sgData[i] = srcShape[i] / sgLayout[i];
+ instData[i] = std::min(instData[i], sgData[i]);
+ remainingSgCount /= sgLayout[i];
+ llvm::errs() << "DEBUG: Set sgLayout[" << i << "] = " << sgLayout[i]
+ << ", sgData[" << i << "] = " << sgData[i]
+ << ", remainingSgCount = " << remainingSgCount << "\n";
+ cplId--;
+ continue;
+ }
+ }
+ remainingDims.push_back(i);
+ llvm::errs() << "DEBUG: Added i = " << i << " to remainingDims\n";
+ }
+
+ llvm::errs() << "DEBUG: Starting second loop\n";
+ for (int i = srcShapeSize - 1; i >= 0; i--) {
+ llvm::errs() << "DEBUG: Loop 2, i = " << i << ", is_remaining_dim = "
+ << llvm::is_contained(remainingDims, i) << "\n";
+ if (llvm::is_contained(remainingDims, i)) {
+
+ llvm::errs() << "DEBUG: Before Set sgLayout[" << i
+ << "] = " << sgLayout[i] << ", sgData[" << i
+ << "] = " << sgData[i]
+ << ", remainingSgCount = " << remainingSgCount << "\n";
+
+ sgLayout[i] = std::min((srcShape[i] / laneLayout[i]),
+ static_cast<int64_t>(remainingSgCount));
+ sgData[i] = srcShape[i] / sgLayout[i];
+ instData[i] = std::min(instData[i], sgData[i]);
+ remainingSgCount /= sgLayout[i];
+
+ llvm::errs() << "DEBUG: After Set sgLayout[" << i
+ << "] = " << sgLayout[i] << ", sgData[" << i
+ << "] = " << sgData[i]
+ << ", remainingSgCount = " << remainingSgCount << "\n";
+
+ if (remainingSgCount == 1) {
+ llvm::errs() << "DEBUG: Breaking from loop 2, remainingSgCount = 1\n";
+ break;
+ }
+ }
+ }
+ }
+ // Convert int64_t vectors to int32_t for DenseI32ArrayAttr
+ SmallVector<int32_t> sgLayout32(sgLayout.begin(), sgLayout.end());
+ SmallVector<int32_t> sgData32(sgData.begin(), sgData.end());
+ SmallVector<int32_t> instData32(instData.begin(), instData.end());
+ SmallVector<int32_t> laneLayout32(laneLayout.begin(), laneLayout.end());
+ SmallVector<int32_t> laneData32(laneData.begin(), laneData.end());
+ proposedSrcLayout = xegpu::LayoutAttr::get(
+ context, DenseI32ArrayAttr::get(context, sgLayout32),
+ DenseI32ArrayAttr::get(context, sgData32),
+ DenseI32ArrayAttr::get(context, instData32),
+ DenseI32ArrayAttr::get(context, laneLayout32),
+ DenseI32ArrayAttr::get(context, laneData32),
+ consumerPreferredLayout.getOrder());
+
+ // finally, create the slice layout for reduction source
+ xegpu::SliceAttr reductionSrcLayout =
+ xegpu::SliceAttr::get(context, proposedSrcLayout,
+ DenseI64ArrayAttr::get(context, reductionDims));
+
+ return reductionSrcLayout;
+}
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout.mlir b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
index b88d8e1a78a26..cccc446f2e4c4 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
@@ -483,7 +483,7 @@ func.func @if_multiple_uses(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.t
gpu.module @test {
// CHECK-LABEL: func.func @vector_outer_reduction(
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: vector<16x16xf32>, %[[ARG1:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>) {
-// CHECK: %{{.*}} = vector.multi_reduction <add>, %[[ARG0]], %{{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} [0] : vector<16x16xf32> to vector<16xf32>
+// CHECK: %{{.*}} = vector.multi_reduction <add>, %[[ARG0]], %{{.*}} {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>} [0] : vector<16x16xf32> to vector<16xf32>
func.func @vector_outer_reduction(%arg0: vector<16x16xf32>, %arg1: !xegpu.tensor_desc<16xf32>) {
%cst = arith.constant dense<0.000000e+00> : vector<16xf32>
%0 = vector.multi_reduction <add>, %arg0, %cst [0] : vector<16x16xf32> to vector<16xf32>
@@ -495,7 +495,7 @@ func.func @vector_outer_reduction(%arg0: vector<16x16xf32>, %arg1: !xegpu.tensor
gpu.module @test {
// CHECK-LABEL: func.func @vector_inner_reduction(
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: vector<16x16xf32>, %[[ARG1:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>) {
-// CHECK: %{{.*}} = vector.multi_reduction <add>, %[[ARG0]], %{{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} [1] : vector<16x16xf32> to vector<16xf32>
+// CHECK: %{{.*}} = vector.multi_reduction <add>, %[[ARG0]], %{{.*}} {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [1]>} [1] : vector<16x16xf32> to vector<16xf32>
func.func @vector_inner_reduction(%arg0: vector<16x16xf32>, %arg1: !xegpu.tensor_desc<16xf32>) {
%cst = arith.constant dense<0.000000e+00> : vector<16xf32>
%0 = vector.multi_reduction <add>, %arg0, %cst [1] : vector<16x16xf32> to vector<16xf32>
>From a347f16fdcf87e69eda55fcfd325fc0594541ae3 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Wed, 24 Dec 2025 06:26:56 +0000
Subject: [PATCH 03/35] add infer rule for bitcast and shapecast
---
.../mlir/Dialect/XeGPU/IR/XeGPUAttrs.td | 18 +-
mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 118 ++++++++
.../Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp | 286 ++++++++++--------
3 files changed, 300 insertions(+), 122 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index eae2cbe24a72c..473e485c1faee 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -243,7 +243,13 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> {
(ins "int64_t": $dim,
"int64_t": $sgData,
"int64_t": $instData,
- "int64_t": $laneData)>,
+ "int64_t": $laneData)>,
+ InterfaceMethod<[{Derive a new layout by collapsing groups of dimensions. Each inner array in
+ `dimGroups` specifies a group of dimensions that are collapsed into a single
+ dimension in the derived layout.}],
+ "xegpu::DistributeLayoutAttr",
+ "collapseDims",
+ (ins "ArrayRef<ArrayRef<int64_t>>": $dimGroups)>,
InterfaceMethod<[{Generates instructions to compute multidimensional coordinates for dist units
assigned to a level identified by linearId. The shape parameter
represents the higher-level problem size. Each level may access
@@ -518,6 +524,11 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttr]> {
// specified values for the given dimension
DistributeLayoutAttr setDimData(int64_t dim, int64_t sgData, int64_t instData, int64_t laneData);
+ // Derive a new layout by collapsing groups of dimensions.
+ // Each inner array in `dimGroups` specifies a set of dimensions
+ // that are collapsed into a single dimension in the derived layout.
+ DistributeLayoutAttr collapseDims(ArrayRef<ArrayRef<int64_t>> dimGroups);
+
/// Delinearizes a linear ID into its multidimensional indices
/// based on the effective level of the layout.
FailureOr<SmallVector<Value>>
@@ -693,6 +704,11 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> {
// specified values for the given dimension
DistributeLayoutAttr setDimData(int64_t dim, int64_t sgData, int64_t instData, int64_t laneData);
+ // Derive a new layout by collapsing groups of dimensions.
+ // Each inner array in `dimGroups` specifies a set of dimensions
+ // that are collapsed into a single dimension in the derived layout.
+ DistributeLayoutAttr collapseDims(ArrayRef<ArrayRef<int64_t>> dimGroups);
+
/// flatten a nested SliceAttr, e.g., for 2-level nested SliceAttr
/// #xegpu.slice<#xegpu.slice<#xegpu.layout<sg_layout = [4, 8, 12]>, dims = [0]>, dims = [0]>
/// it will coalese two slice operations and return a simplified SliceAttr
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 0ecfe3eac650c..613651ec7f964 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -512,6 +512,102 @@ DistributeLayoutAttr LayoutAttr::setDimData(int64_t dim, int64_t sgData,
getOrder());
}
+// Derive a new layout by collapsing groups of dimensions.
+// Each inner array in `dimGroups` specifies a set of dimensions that are
+// collapsed into a single dimension in the derived layout, so the result has
+// `dimGroups.size()` dimensions. Every layout field of a collapsed dimension is
+// the product of that field over the group. A field that is not set on this
+// layout stays unset in the result.
+DistributeLayoutAttr
+LayoutAttr::collapseDims(ArrayRef<ArrayRef<int64_t>> dimGroups) {
+  MLIRContext *ctx = getContext();
+  const int64_t rank = getRank();
+
+  // Per-dimension order values. Without an explicit order the default is
+  // [rank-1, ..., 1, 0], e.g. [1, 0] for 2D (row-major). The rank is used
+  // rather than the size of one field because individual fields may be unset.
+  DenseI32ArrayAttr orderAttr = getOrder();
+  SmallVector<int64_t> order;
+  if (orderAttr && !orderAttr.empty()) {
+    order = llvm::to_vector(
+        llvm::map_range(orderAttr.asArrayRef(),
+                        [](int32_t idx) { return static_cast<int64_t>(idx); }));
+  } else {
+    order = llvm::to_vector(llvm::reverse(llvm::seq<int64_t>(0, rank)));
+  }
+
+  // Collapses one layout field by multiplying its values across each group.
+  // An empty input means the field is not set; a null attribute is returned
+  // for it instead of indexing into the empty vector.
+  auto collapseField = [&](ArrayRef<int64_t> values) -> DenseI32ArrayAttr {
+    if (values.empty())
+      return DenseI32ArrayAttr();
+    SmallVector<int32_t> collapsed;
+    collapsed.reserve(dimGroups.size());
+    for (ArrayRef<int64_t> group : dimGroups) {
+      int64_t product = 1;
+      for (int64_t dimIdx : group) {
+        assert(dimIdx >= 0 && static_cast<size_t>(dimIdx) < values.size() &&
+               "dimension index out of range");
+        product *= values[dimIdx];
+      }
+      collapsed.push_back(static_cast<int32_t>(product));
+    }
+    return DenseI32ArrayAttr::get(ctx, collapsed);
+  };
+
+  // Each collapsed dimension takes the order value of the last dimension in
+  // its group (-1 for an empty group).
+  SmallVector<int64_t> collapsedOrder;
+  collapsedOrder.reserve(dimGroups.size());
+  for (ArrayRef<int64_t> group : dimGroups) {
+    int64_t orderValue = -1;
+    if (!group.empty()) {
+      assert(group.back() >= 0 &&
+             static_cast<size_t>(group.back()) < order.size() &&
+             "dimension index out of range");
+      orderValue = order[group.back()];
+    }
+    collapsedOrder.push_back(orderValue);
+  }
+
+  // Re-map the order values into [0, N-1], N being the collapsed rank: each
+  // value becomes the number of collapsed order values smaller than it.
+  SmallVector<int32_t> remappedOrder;
+  remappedOrder.reserve(collapsedOrder.size());
+  for (int64_t orderValue : collapsedOrder) {
+    remappedOrder.push_back(static_cast<int32_t>(llvm::count_if(
+        collapsedOrder, [&](int64_t other) { return other < orderValue; })));
+  }
+
+  return xegpu::LayoutAttr::get(
+      ctx, collapseField(getEffectiveSgLayoutAsInt()),
+      collapseField(getEffectiveSgDataAsInt()),
+      collapseField(getEffectiveInstDataAsInt()),
+      collapseField(getEffectiveLaneLayoutAsInt()),
+      collapseField(getEffectiveLaneDataAsInt()),
+      DenseI32ArrayAttr::get(ctx, remappedOrder));
+}
+
+
//===----------------------------------------------------------------------===//
// XeGPU_SliceAttr
//===----------------------------------------------------------------------===//
@@ -723,6 +819,28 @@ DistributeLayoutAttr SliceAttr::setDimData(int64_t dim, int64_t sgData,
parent.setDimData(adjustDims[0], sgData, instData, laneData), getDims());
}
+// Derive a new layout by collapsing groups of dimensions.
+// Each inner array in `dimGroups` specifies a set of dimensions, expressed in
+// the sliced (rank-reduced) space, that are collapsed into a single dimension.
+// The groups are translated to the parent's dimension space, the parent is
+// collapsed, and the result is sliced again over this attribute's dims.
+//
+// NOTE(review): the slice dims are reused unchanged although they index the
+// un-collapsed parent, and the sliced dims themselves belong to no group.
+// Confirm whether they need to be remapped and re-added for the collapsed
+// parent.
+DistributeLayoutAttr
+SliceAttr::collapseDims(ArrayRef<ArrayRef<int64_t>> dimGroups) {
+  ArrayRef<int64_t> sliceDims = getDims().asArrayRef();
+
+  // Translate every group from the sliced space to the parent space. All
+  // groups are materialized before any view is taken so that the views below
+  // are not invalidated by the outer vector growing.
+  SmallVector<SmallVector<int64_t>> parentGroups;
+  parentGroups.reserve(dimGroups.size());
+  for (ArrayRef<int64_t> group : dimGroups) {
+    SetVector<int64_t> mappedDims = mapDimsFromSlicedSpace(group, sliceDims);
+    parentGroups.emplace_back(mappedDims.begin(), mappedDims.end());
+  }
+
+  // collapseDims takes an array of array views, which a nested SmallVector
+  // does not convert to implicitly.
+  SmallVector<ArrayRef<int64_t>> parentGroupRefs;
+  parentGroupRefs.reserve(parentGroups.size());
+  for (const SmallVector<int64_t> &group : parentGroups)
+    parentGroupRefs.push_back(group);
+
+  DistributeLayoutAttr collapsedParent =
+      getParent().collapseDims(parentGroupRefs);
+
+  return SliceAttr::get(getContext(), collapsedParent,
+                        DenseI64ArrayAttr::get(getContext(), sliceDims));
+}
+
+
//===----------------------------------------------------------------------===//
// XeGPU_RangeAttr
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp
index 58e222812661f..acb5f07893ee2 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp
@@ -291,6 +291,7 @@ bool xegpu::recoverTemporaryLayouts(Operation *rootOp) {
// across region boundaries
// while preserving a single well-defined use for each definition at the
// region-op level.
+// Placeholder without an implementation yet. It must still return a value:
+// flowing off the end of a function returning bool is undefined behavior.
+// NOTE(review): `false` is a conservative guess; confirm the intended result
+// once the function is implemented.
+bool xegpu::recoverTemporaryLayouts_first(Operation *rootOp) { return false; }
template <typename T, typename>
void xegpu::removeLayoutAttr(const T &operandOrResult) {
@@ -369,125 +370,168 @@ xegpu::inferReductionSourceLayout(xegpu::DistributeLayoutAttr resLayout,
return sliceLayout.getParent();
}
-// /// Infers the source layout attribute for a bitcast operation given the
-// /// result layout attribute, result element type bitwidth, and source element
-// /// type bitwidth.
-// xegpu::DistributeLayoutAttr
-// xegpu::inferBitCastSourceLayout(MLIRContext *context,
-// xegpu::DistributeLayoutAttr resLayout,
-// int resElemTyBitWidth, int
-// srcElemTyBitWidth){
-// // the result and source layout must be the same
-// // if resLayout is SliceAttr, we need to first get its root layout
-// xegpu::DistributeLayoutAttr resRootLayout = resLayout;
-// while (auto sliceLayout = dyn_cast<xegpu::SliceAttr>(resRootLayout)) {
-// resRootLayout = sliceLayout.getParent();
-// }
-// // change the laneData of resRootLayout according to the bitwidth ratio
-// xegpu::LayoutAttr resRootPlainLayout =
-// dyn_cast<xegpu::LayoutAttr>(resRootLayout); SmallVector<int64_t> laneData =
-// resRootPlainLayout.getEffectiveLaneDataAsInt();
-
-// if (srcElemTyBitWidth >= resElemTyBitWidth) {
-// int bitWidthRatio = srcElemTyBitWidth / resElemTyBitWidth;
-// laneData[laneData.size()-1] = laneData[laneData.size()-1] *
-// bitWidthRatio;
-// } else {
-// int bitWidthRatio = resElemTyBitWidth / srcElemTyBitWidth;
-// assert((laneData[laneData.size()-2] % bitWidthRatio) == 0 &&
-// "laneData not divisible by bitWidthRatio");
-// laneData[laneData.size()-1] = laneData[laneData.size()-1] /
-// bitWidthRatio;
-// }
-
-// // now reconstruct the source layout with updated laneData
-// // by updating the root layout and going throught the slice layers
-// SmallVector<int32_t> laneData32(laneData.begin(), laneData.end());
-// xegpu::LayoutAttr proposedSrcLayout = xegpu::LayoutAttr::get(
-// context,
-// resRootPlainLayout.getSgLayout(),
-// resRootPlainLayout.getSgData(),
-// resRootPlainLayout.getInstData(),
-// resRootPlainLayout.getLaneLayout(),
-// DenseI32ArrayAttr::get(context, laneData32),
-// resRootPlainLayout.getOrder());
-
-// // reconstruct slice layers if any
-// // First collect all slice layers from innermost to outermost
-// SmallVector<DenseI64ArrayAttr> sliceDims;
-// xegpu::DistributeLayoutAttr currentLayout = resLayout;
-// while (auto sliceLayout = dyn_cast<xegpu::SliceAttr>(currentLayout)) {
-// sliceDims.push_back(sliceLayout.getDims());
-// currentLayout = sliceLayout.getParent();
-// }
-
-// // Now rebuild from outermost to innermost (reverse order)
-// xegpu::DistributeLayoutAttr finalSrcLayout = proposedSrcLayout;
-// for (auto it = sliceDims.rbegin(); it != sliceDims.rend(); ++it) {
-// finalSrcLayout = xegpu::SliceAttr::get(context, finalSrcLayout, *it);
-// }
-// return finalSrcLayout;
-// }
-
-// /// Infers the source layout attribute for a shape cast operation given the
-// /// result layout attribute, result shape, and source shape.
-// xegpu::DistributeLayoutAttr xegpu::inferShapeCastSourceLayout(
-// MLIRContext *context, xegpu::DistributeLayoutAttr resLayout,
-// ArrayRef<int64_t> resShape, ArrayRef<int64_t> srcShape){
-
-// // there are two use cases:
-// // 1. expand dims of low-rank dimensions (e.g., 1D to 2D): to set up the
-// tensor before broadcast
-// // 2. split dim of a high-rank dimension (e.g., 1D to 2D): to setup tensor
-// for multi-stage reduction
-
-// SmallVector<int64_t> shapeCastDims;
-// auto returnLayout = resLayout;
-
-// int resRank = resShape.size();
-// int srcRank = srcShape.size();
-
-// if (srcRank < resRank) {
-// // Case 1: expand dims of low-rank dimensions (e.g., 1D to 2D)
-// int dimDiff = resRank - srcRank;
-// // adding the missing leading dims
-// for (int i = 0; i < dimDiff; i++)
-// shapeCastDims.push_back(i);
-
-// // create a slice layout for the source
-// returnLayout = xegpu::SliceAttr::get(
-// context, resLayout, DenseI64ArrayAttr::get(context, shapeCastDims));
-// } else if (srcRank > resRank) {
-// // Case 2: split dim of a high-rank dimension (e.g., 1D to 2D)
-// // find the split dims by comparing srcShape and resShape
-// int srcIdx = 0;
-// int resIdx = 0;
-// while (srcIdx < srcRank && resIdx < resRank) {
-// if (srcShape[srcIdx] == resShape[resIdx]) {
-// srcIdx++;
-// resIdx++;
-// } else if (srcShape[srcIdx] < resShape[resIdx]) {
-// shapeCastDims.push_back(srcIdx);
-// srcIdx++;
-// } else {
-// // this should not happen in valid shape cast
-// assert(false && "Invalid shape cast: source shape dimension smaller
-// than result shape dimension");
-// }
-// }
-// // handle remaining src dims
-// while (srcIdx < srcRank) {
-// shapeCastDims.push_back(srcIdx);
-// srcIdx++;
-// }
-
-// // create a slice layout for the source
-// returnLayout = xegpu::SliceAttr::get(
-// context, resLayout, DenseI64ArrayAttr::get(context, shapeCastDims));
-// }
-// return returnLayout;
-
-// }
+/// Infers the source layout attribute for a bitcast operation given the
+/// result layout attribute, result element type bitwidth, and source element
+/// type bitwidth.
+///
+/// A bitcast keeps the bits owned by each lane, so only the number of elements
+/// a lane holds along the innermost dimension changes: it is scaled by
+/// resElemTyBitWidth / srcElemTyBitWidth. For example, a result lane_data of
+/// [1, 4] on i8 elements corresponds to a source lane_data of [1, 1] on i32
+/// elements. All other fields of the root layout, and any slice wrappers
+/// around it, are carried over unchanged. If the root layout has no lane_data
+/// there is nothing to rescale and `resLayout` is returned as is.
+///
+/// NOTE(review): the innermost dimension of the root layout is rescaled even
+/// when a slice wrapper removes that dimension; confirm this is intended.
+xegpu::DistributeLayoutAttr
+xegpu::inferBitCastSourceLayout(MLIRContext *context,
+                                xegpu::DistributeLayoutAttr resLayout,
+                                int resElemTyBitWidth, int srcElemTyBitWidth) {
+  assert(resElemTyBitWidth > 0 && srcElemTyBitWidth > 0 &&
+         "element bitwidths must be positive");
+
+  // Peel off the slice wrappers, remembering their dims from the outermost to
+  // the innermost wrapper, until the root layout is reached.
+  SmallVector<DenseI64ArrayAttr> sliceDims;
+  xegpu::DistributeLayoutAttr resRootLayout = resLayout;
+  while (auto sliceLayout = dyn_cast<xegpu::SliceAttr>(resRootLayout)) {
+    sliceDims.push_back(sliceLayout.getDims());
+    resRootLayout = sliceLayout.getParent();
+  }
+  // cast<> asserts that the root of the slice chain is a plain LayoutAttr.
+  auto resRootPlainLayout = cast<xegpu::LayoutAttr>(resRootLayout);
+
+  SmallVector<int64_t> laneData =
+      resRootPlainLayout.getEffectiveLaneDataAsInt();
+  if (laneData.empty())
+    return resLayout;
+
+  // Rescale the per-lane element count of the innermost dimension.
+  int64_t &innermostLaneData = laneData.back();
+  if (srcElemTyBitWidth >= resElemTyBitWidth) {
+    // Wider source elements: each lane holds fewer of them.
+    int bitWidthRatio = srcElemTyBitWidth / resElemTyBitWidth;
+    assert((innermostLaneData % bitWidthRatio) == 0 &&
+           "laneData not divisible by bitWidthRatio");
+    innermostLaneData /= bitWidthRatio;
+  } else {
+    // Narrower source elements: each lane holds more of them.
+    int bitWidthRatio = resElemTyBitWidth / srcElemTyBitWidth;
+    innermostLaneData *= bitWidthRatio;
+  }
+
+  // Rebuild the root layout with the updated laneData.
+  SmallVector<int32_t> laneData32(laneData.begin(), laneData.end());
+  xegpu::DistributeLayoutAttr finalSrcLayout = xegpu::LayoutAttr::get(
+      context, resRootPlainLayout.getSgLayout(), resRootPlainLayout.getSgData(),
+      resRootPlainLayout.getInstData(), resRootPlainLayout.getLaneLayout(),
+      DenseI32ArrayAttr::get(context, laneData32),
+      resRootPlainLayout.getOrder());
+
+  // Re-apply the slice wrappers, innermost first, so the original nesting is
+  // restored.
+  for (DenseI64ArrayAttr dims : llvm::reverse(sliceDims))
+    finalSrcLayout = xegpu::SliceAttr::get(context, finalSrcLayout, dims);
+  return finalSrcLayout;
+}
+
+/// Infers the source layout attribute for a shape cast operation given the
+/// result layout attribute, result shape, and source shape.
+///
+/// Two kinds of shape casts are handled:
+///  1. The result only adds unit dimensions to the source (like expand_dims,
+///     e.g. to set up a tensor before a broadcast). The source layout is the
+///     result layout sliced over the added unit dimensions; if no dimension
+///     was added, the result layout itself is returned.
+///  2. Every source dimension is split into one or more consecutive result
+///     dimensions whose sizes multiply to the source size (e.g. to set up a
+///     tensor for a multi-stage reduction). The source layout is the result
+///     layout with each such group of dimensions collapsed.
+/// A null attribute is returned for every other shape cast.
+///
+/// TODO: handle the case that combines all source dims into the innermost
+/// dim of a 2D result, [1, combinedData]. It is only used after workgroup
+/// distribution to save multi-dimensional data to a 1D SLM buffer, so
+/// sg_layout and sg_data need no handling there.
+xegpu::DistributeLayoutAttr xegpu::inferShapeCastSourceLayout(
+    MLIRContext *context, xegpu::DistributeLayoutAttr resLayout,
+    ArrayRef<int64_t> resShape, ArrayRef<int64_t> srcShape) {
+
+  // Use case 1: the shapes only differ by unit dimensions added in the result.
+  {
+    // Result dims are matched against source dims in order; an unmatched
+    // result dim must be a unit dim and is recorded as expanded.
+    SmallVector<int64_t> expandedUnitDims;
+    size_t srcIdx = 0;
+    bool onlyUnitDimsAdded = true;
+    for (size_t resIdx = 0; resIdx < resShape.size(); ++resIdx) {
+      if (srcIdx < srcShape.size() && srcShape[srcIdx] == resShape[resIdx]) {
+        ++srcIdx;
+      } else if (resShape[resIdx] == 1) {
+        expandedUnitDims.push_back(resIdx);
+      } else {
+        onlyUnitDimsAdded = false;
+        break;
+      }
+    }
+
+    if (onlyUnitDimsAdded && srcIdx == srcShape.size()) {
+      // Identical shapes: no slicing is needed.
+      if (expandedUnitDims.empty())
+        return resLayout;
+      // Create a slice layout for the source by removing the expanded unit
+      // dims from the result layout.
+      return xegpu::SliceAttr::get(
+          context, resLayout,
+          DenseI64ArrayAttr::get(context, expandedUnitDims));
+    }
+  }
+
+  // Use case 2: each source dim maps to one or more consecutive result dims
+  // whose product equals the source dim.
+  {
+    // splitDimGroups[i] lists the result dims that source dim i splits into.
+    SmallVector<SmallVector<int64_t>> splitDimGroups;
+    SmallVector<int64_t> currentResDims;
+    size_t srcIdx = 0;
+    int64_t accumulatedSize = 1;
+    bool isSplit = true;
+    for (size_t resIdx = 0; resIdx < resShape.size(); ++resIdx) {
+      // Result dims left over after all source dims were consumed.
+      if (srcIdx >= srcShape.size()) {
+        isSplit = false;
+        break;
+      }
+      accumulatedSize *= resShape[resIdx];
+      currentResDims.push_back(resIdx);
+
+      if (accumulatedSize == srcShape[srcIdx]) {
+        // The group is complete; move to the next source dim.
+        splitDimGroups.push_back(currentResDims);
+        currentResDims.clear();
+        accumulatedSize = 1;
+        ++srcIdx;
+      } else if (accumulatedSize > srcShape[srcIdx]) {
+        isSplit = false;
+        break;
+      }
+    }
+
+    if (isSplit && srcIdx == srcShape.size()) {
+      // collapseDims takes an array of array views, which a nested
+      // SmallVector does not convert to implicitly.
+      SmallVector<ArrayRef<int64_t>> groupRefs;
+      groupRefs.reserve(splitDimGroups.size());
+      for (const SmallVector<int64_t> &group : splitDimGroups)
+        groupRefs.push_back(group);
+      return resLayout.collapseDims(groupRefs);
+    }
+  }
+
+  // Unsupported shape cast: no source layout can be inferred.
+  return nullptr;
+}
xegpu::SliceAttr
xegpu::reductionLayoutSetupRule(ArrayRef<int64_t> srcShape,
@@ -530,6 +574,8 @@ xegpu::reductionLayoutSetupRule(ArrayRef<int64_t> srcShape,
// construct a vector layout with lane_layout = [1, ..., 1, subgroupSize]
SmallVector<int64_t> laneLayout(srcShapeSize, 1);
laneLayout[srcShapeSize - 1] = subgroupSize;
+ // construct a vector layout with lane_data = [1, ..., 1]
+ SmallVector<int64_t> laneData(srcShapeSize, 1);
llvm::errs() << "DEBUG: laneLayout = [";
for (size_t i = 0; i < laneLayout.size(); i++) {
llvm::errs() << laneLayout[i];
@@ -537,8 +583,6 @@ xegpu::reductionLayoutSetupRule(ArrayRef<int64_t> srcShape,
llvm::errs() << ", ";
}
llvm::errs() << "]\n";
- // construct a vector layout with lane_data = [1, ..., 1]
- SmallVector<int64_t> laneData(srcShapeSize, 1);
bool failToAlignSliceStruct = false;
if (sliceCPL && sliceCPL.getDims().asArrayRef().equals(reductionDims)) {
>From 7e7ae5d0f59bab3382776463d07f870a45b1c21a Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Wed, 24 Dec 2025 19:21:41 +0000
Subject: [PATCH 04/35] add bitcast set rule
---
.../mlir/Dialect/XeGPU/IR/XeGPUAttrs.td | 6 +-
.../Dialect/XeGPU/Utils/XeGPULayoutUtils.h | 24 ++-
mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 47 +++--
.../XeGPU/Transforms/XeGPUPropagateLayout.cpp | 5 +-
.../Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp | 168 ++++++++++++------
5 files changed, 159 insertions(+), 91 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index 473e485c1faee..0e13d063498b1 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -249,7 +249,7 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> {
dimension in the derived layout.}],
"xegpu::DistributeLayoutAttr",
"collapseDims",
- (ins "ArrayRef<ArrayRef<int64_t>>": $dimGroups)>,
+ (ins "SmallVector<SmallVector<int64_t>>": $dimGroups)>,
InterfaceMethod<[{Generates instructions to compute multidimensional coordinates for dist units
assigned to a level identified by linearId. The shape parameter
represents the higher-level problem size. Each level may access
@@ -527,7 +527,7 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttr]> {
// Derive a new layout by collapsing groups of dimensions.
// Each inner array in `dimGroups` specifies a set of dimensions
// that are collapsed into a single dimension in the derived layout.
- DistributeLayoutAttr collapseDims(ArrayRef<ArrayRef<int64_t>> dimGroups);
+ DistributeLayoutAttr collapseDims(SmallVector<SmallVector<int64_t>> dimGroups) const;
/// Delinearizes a linear ID into its multidimensional indices
/// based on the effective level of the layout.
@@ -707,7 +707,7 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> {
// Derive a new layout by collapsing groups of dimensions.
// Each inner array in `dimGroups` specifies a set of dimensions
// that are collapsed into a single dimension in the derived layout.
- DistributeLayoutAttr collapseDims(ArrayRef<ArrayRef<int64_t>> dimGroups);
+ DistributeLayoutAttr collapseDims(SmallVector<SmallVector<int64_t>> dimGroups) const;
/// flatten a nested SliceAttr, e.g., for 2-level nested SliceAttr
/// #xegpu.slice<#xegpu.slice<#xegpu.layout<sg_layout = [4, 8, 12]>, dims = [0]>, dims = [0]>
diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h
index 01cb43b73d5ca..fca57eb85517b 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h
@@ -122,14 +122,32 @@ DistributeLayoutAttr inferShapeCastSourceLayout(MLIRContext *context,
ArrayRef<int64_t> resShape,
ArrayRef<int64_t> srcShape);
-/// Sets the the layout attribute for result based on a preferred Layout
-/// propagated from consumer
-/// the ouput must be a slice attribute
+/// Sets up layout for reduction operations by creating a SliceAttr for the
+/// result.
+///
+/// This function first attempts to construct a source layout that, when sliced
+/// along reduction dimensions, produces a result layout compatible with the
+/// consumer's preferred layout. This minimizes data redistribution overhead.
+/// The SliceAttr for the result is then created based on the derived source
+/// layout and the specified reduction dimensions.
SliceAttr
reductionLayoutSetupRule(ArrayRef<int64_t> srcShape,
SmallVector<int64_t> reductionDims,
DistributeLayoutAttr consumerPreferredLayout);
+/// Setup the result layout attribute for a bitcast operation based on element
+/// type bitwidths. This ensures the source layout can always be derived from
+/// the result layout.
+///
+/// When casting from a narrower to a wider element type (srcElemTyBitWidth <
+/// resElemTyBitWidth), the result layout's innermost dimension data sizes
+/// (sg_data, inst_data, lane_data) are scaled up by the bitwidth ratio. This
+/// maintains the invariant that the source layout can be recovered by inverse
+/// scaling during layout inference.
+DistributeLayoutAttr bitCastLayoutSetupRule(DistributeLayoutAttr resLayout,
+ int resElemTyBitWidth,
+ int srcElemTyBitWidth);
+
} // namespace xegpu
} // namespace mlir
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 613651ec7f964..1b50e6dd33802 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -476,39 +476,31 @@ LayoutAttr::setUnitDimLayout(SetVector<int64_t> unitDims) const {
DistributeLayoutAttr LayoutAttr::setDimData(int64_t dim, int64_t sgData,
int64_t instData,
int64_t laneData) {
- auto sgDataOpt = getSgData();
- auto instDataOpt = getInstData();
- auto laneDataOpt = getLaneData();
- SmallVector<int32_t> sgDataVec;
- SmallVector<int32_t> instDataVec;
- SmallVector<int32_t> laneDataVec;
+ SmallVector<int64_t> sgDataVec = getEffectiveSgDataAsInt();
+ SmallVector<int64_t> instDataVec = getEffectiveInstDataAsInt();
+ SmallVector<int64_t> laneDataVec = getEffectiveLaneDataAsInt();
- if (sgDataOpt)
- sgDataVec = llvm::to_vector(sgDataOpt.asArrayRef());
-
- if (instDataOpt)
- instDataVec = llvm::to_vector(instDataOpt.asArrayRef());
-
- if (laneDataOpt)
- laneDataVec = llvm::to_vector(laneDataOpt.asArrayRef());
-
- if (dim < static_cast<int64_t>(sgDataVec.size()))
+ if (dim < static_cast<int64_t>(sgDataVec.size()) && sgData != -1)
sgDataVec[dim] = sgData;
- if (dim < static_cast<int64_t>(instDataVec.size()))
+ if (dim < static_cast<int64_t>(instDataVec.size()) && instData != -1)
instDataVec[dim] = instData;
- if (dim < static_cast<int64_t>(laneDataVec.size()))
+ if (dim < static_cast<int64_t>(laneDataVec.size()) && laneData != -1)
laneDataVec[dim] = laneData;
+ SmallVector<int32_t> sgDataVec32(sgDataVec.begin(), sgDataVec.end());
+ SmallVector<int32_t> instDataVec32(instDataVec.begin(), instDataVec.end());
+ SmallVector<int32_t> laneDataVec32(laneDataVec.begin(), laneDataVec.end());
+
return LayoutAttr::get(
getContext(), getSgLayout(),
sgDataVec.empty() ? DenseI32ArrayAttr()
- : DenseI32ArrayAttr::get(getContext(), sgDataVec),
+ : DenseI32ArrayAttr::get(getContext(), sgDataVec32),
instDataVec.empty() ? DenseI32ArrayAttr()
- : DenseI32ArrayAttr::get(getContext(), instDataVec),
+ : DenseI32ArrayAttr::get(getContext(), instDataVec32),
getLaneLayout(),
laneDataVec.empty() ? DenseI32ArrayAttr()
- : DenseI32ArrayAttr::get(getContext(), laneDataVec),
+ : DenseI32ArrayAttr::get(getContext(), laneDataVec32),
getOrder());
}
@@ -516,9 +508,8 @@ DistributeLayoutAttr LayoutAttr::setDimData(int64_t dim, int64_t sgData,
// Each inner array in `dimGroups` specifies a set of dimensions
// that are collapsed into a single dimension in the derived layout.
DistributeLayoutAttr
-LayoutAttr::collapseDims(ArrayRef<ArrayRef<int64_t>> dimGroups) {
+LayoutAttr::collapseDims(SmallVector<SmallVector<int64_t>> dimGroups) const {
- // Extract layout attributes as vectors
SmallVector<int64_t> sgLayout = getEffectiveSgLayoutAsInt();
SmallVector<int64_t> sgData = getEffectiveSgDataAsInt();
SmallVector<int64_t> instData = getEffectiveInstDataAsInt();
@@ -823,16 +814,18 @@ DistributeLayoutAttr SliceAttr::setDimData(int64_t dim, int64_t sgData,
// Each inner array in `dimGroups` specifies a set of dimensions
// that are collapsed into a single dimension in the derived layout.
DistributeLayoutAttr
-SliceAttr::collapseDims(ArrayRef<ArrayRef<int64_t>> dimGroups) const {
+SliceAttr::collapseDims(SmallVector<SmallVector<int64_t>> dimGroups) const {
// Map the sliced dims from parent space to collapsed space
- ArrayRef<int64_t> sliceDims = getDims().asArrayRef();
+ SmallVector<int64_t> sliceDims = llvm::to_vector(getDims().asArrayRef());
// go through dimGroups and map each dim from sliced space to parent space
SmallVector<SmallVector<int64_t>> adjustedDimGroups;
for (const auto &group : dimGroups) {
- SetVector<int64_t> mappedDims = mapDimsFromSlicedSpace(group, sliceDims);
- adjustedDimGroups.push_back(mappedDims.getArrayRef());
+ SetVector<int64_t> groupSet(group.begin(), group.end());
+ SetVector<int64_t> mappedDims = mapDimsFromSlicedSpace(groupSet, sliceDims);
+ adjustedDimGroups.push_back(
+ SmallVector<int64_t>(mappedDims.begin(), mappedDims.end()));
}
auto collapsedParent = getParent().collapseDims(adjustedDimGroups);
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 1be44480de01d..7c8d0329140da 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -603,7 +603,6 @@ void LayoutInfoPropagation::visitVectorMultiReductionOp(
if (!resultLayout.isAssigned())
return;
- VectorType resultTy = llvm::dyn_cast<VectorType>(reduction.getDestType());
VectorType sourceTy =
llvm::dyn_cast<VectorType>(reduction.getSourceVectorType());
SmallVector<int64_t> reductionDims(reduction.getReductionDims().begin(),
@@ -691,10 +690,10 @@ void LayoutInfoPropagation::visitVectorBroadCastOp(
auto resultLayoutAttr =
dyn_cast<xegpu::DistributeLayoutAttr>(resultLayout.get());
- xegpu::DistributeLayoutAttr resLayout = xegpu::inferBroadCastSourceLayout(
+ xegpu::DistributeLayoutAttr srcLayout = xegpu::inferBroadCastSourceLayout(
broadcast.getContext(), resultLayoutAttr, resShape, srcShape);
- propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(resLayout)));
+ propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayout)));
return;
}
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp
index acb5f07893ee2..4a2b3557f1679 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp
@@ -291,7 +291,7 @@ bool xegpu::recoverTemporaryLayouts(Operation *rootOp) {
// across region boundaries
// while preserving a single well-defined use for each definition at the
// region-op level.
-bool xegpu::recoverTemporaryLayouts_first(Operation *rootOp) {}
+// bool xegpu::recoverTemporaryLayouts_first(Operation *rootOp) {}
template <typename T, typename>
void xegpu::removeLayoutAttr(const T &operandOrResult) {
@@ -378,52 +378,38 @@ xegpu::inferBitCastSourceLayout(MLIRContext *context,
xegpu::DistributeLayoutAttr resLayout,
int resElemTyBitWidth, int srcElemTyBitWidth) {
// the result and source layout must be the same
- // if resLayout is SliceAttr, we need to first get its root layout
- xegpu::DistributeLayoutAttr resRootLayout = resLayout;
- while (auto sliceLayout = dyn_cast<xegpu::SliceAttr>(resRootLayout)) {
- resRootLayout = sliceLayout.getParent();
- }
- // change the laneData of resRootLayout according to the bitwidth ratio
- xegpu::LayoutAttr resRootPlainLayout =
- dyn_cast<xegpu::LayoutAttr>(resRootLayout);
- SmallVector<int64_t> laneData =
- resRootPlainLayout.getEffectiveLaneDataAsInt();
+ // only adjust the sg_data, inst_data, lane_data accordingly
+ // based on the bitwidth ratio between source and result element type
+
+ SmallVector<int64_t> sgData = resLayout.getEffectiveSgDataAsInt();
+ SmallVector<int64_t> instData = resLayout.getEffectiveInstDataAsInt();
+ SmallVector<int64_t> laneData = resLayout.getEffectiveLaneDataAsInt();
+ size_t dim = sgData.size() - 1;
+ int64_t sgDataValue, instDataValue, laneDataValue;
if (srcElemTyBitWidth >= resElemTyBitWidth) {
int bitWidthRatio = srcElemTyBitWidth / resElemTyBitWidth;
- laneData[laneData.size() - 1] =
- laneData[laneData.size() - 1] * bitWidthRatio;
+ sgDataValue = (dim < sgData.size()) ? sgData[dim] * bitWidthRatio : -1;
+ instDataValue =
+ (dim < instData.size()) ? instData[dim] * bitWidthRatio : -1;
+ laneDataValue =
+ (dim < laneData.size()) ? laneData[dim] * bitWidthRatio : -1;
} else {
int bitWidthRatio = resElemTyBitWidth / srcElemTyBitWidth;
- assert((laneData[laneData.size() - 2] % bitWidthRatio) == 0 &&
+ assert((laneData[dim] % bitWidthRatio) == 0 &&
"laneData not divisible by bitWidthRatio");
- laneData[laneData.size() - 1] =
- laneData[laneData.size() - 1] / bitWidthRatio;
+ sgDataValue = (dim < sgData.size()) ? sgData[dim] / bitWidthRatio : -1;
+ instDataValue =
+ (dim < instData.size()) ? instData[dim] / bitWidthRatio : -1;
+ laneDataValue =
+ (dim < laneData.size()) ? laneData[dim] / bitWidthRatio : -1;
}
- // now reconstruct the source layout with updated laneData
- // by updating the root layout and going throught the slice layers
- SmallVector<int32_t> laneData32(laneData.begin(), laneData.end());
- xegpu::LayoutAttr proposedSrcLayout = xegpu::LayoutAttr::get(
- context, resRootPlainLayout.getSgLayout(), resRootPlainLayout.getSgData(),
- resRootPlainLayout.getInstData(), resRootPlainLayout.getLaneLayout(),
- DenseI32ArrayAttr::get(context, laneData32),
- resRootPlainLayout.getOrder());
-
- // reconstruct slice layers if any
- // First collect all slice layers from innermost to outermost
- SmallVector<DenseI64ArrayAttr> sliceDims;
- xegpu::DistributeLayoutAttr currentLayout = resLayout;
- while (auto sliceLayout = dyn_cast<xegpu::SliceAttr>(currentLayout)) {
- sliceDims.push_back(sliceLayout.getDims());
- currentLayout = sliceLayout.getParent();
- }
+ // Now set only instData and laneData, preserving sgData
+ xegpu::DistributeLayoutAttr finalSrcLayout;
+ finalSrcLayout =
+ resLayout.setDimData(dim, sgDataValue, instDataValue, laneDataValue);
- // Now rebuild from outermost to innermost (reverse order)
- xegpu::DistributeLayoutAttr finalSrcLayout = proposedSrcLayout;
- for (auto it = sliceDims.rbegin(); it != sliceDims.rend(); ++it) {
- finalSrcLayout = xegpu::SliceAttr::get(context, finalSrcLayout, *it);
- }
return finalSrcLayout;
}
@@ -471,7 +457,7 @@ xegpu::DistributeLayoutAttr xegpu::inferShapeCastSourceLayout(
// Maps each source dimension to the range of destination dimensions it splits
// into
- SmallVector<SmallVector<size_t>> splitDimGroups;
+ SmallVector<SmallVector<int64_t>> splitDimGroups;
auto checkSplitDims = [&](ArrayRef<int64_t> src,
ArrayRef<int64_t> dst) -> bool {
@@ -480,7 +466,7 @@ xegpu::DistributeLayoutAttr xegpu::inferShapeCastSourceLayout(
splitDimGroups.clear();
size_t srcIdx = 0;
int64_t accumulatedSize = 1;
- SmallVector<size_t> currentDstDims;
+ SmallVector<int64_t> currentDstDims;
for (size_t dstIdx = 0; dstIdx < dst.size(); ++dstIdx) {
if (srcIdx >= src.size())
@@ -531,8 +517,30 @@ xegpu::DistributeLayoutAttr xegpu::inferShapeCastSourceLayout(
// construct a vector layout with lane_data = [1, ..., 1]
SmallVector<int64_t> laneData(srcShapeSize, 1);
}
+
+ // TODO: Complete implementation for other shape cast scenarios
+ return nullptr;
}
+/// Sets up layout for reduction operations by creating a SliceAttr for the
+/// result.
+///
+/// Algorithm Overview:
+/// This function attempts to construct a source layout that, when sliced along
+/// reduction dimensions, produces a result layout compatible with the
+/// consumer's preferred layout. This minimizes data redistribution overhead.
+/// The SliceAttr for the result is created based on the derived source layout
+/// and the specified reduction dimensions.
+///
+/// Strategy:
+/// 1. First, check if the consumer's preferred layout is already a SliceAttr
+/// with matching reduction dimensions. If so, use its parent layout directly
+/// and adjust the sg_data/inst_data acccording to source shape.
+/// 2. If step 1 fails, construct a new layout by distributing
+/// workgroup/subgroup
+/// resources across dimensions, prioritizing alignment with the consumer's
+/// sg_layout for non-reduction dimensions.
+
xegpu::SliceAttr
xegpu::reductionLayoutSetupRule(ArrayRef<int64_t> srcShape,
SmallVector<int64_t> reductionDims,
@@ -541,29 +549,40 @@ xegpu::reductionLayoutSetupRule(ArrayRef<int64_t> srcShape,
xegpu::SliceAttr sliceCPL =
dyn_cast<xegpu::SliceAttr>(consumerPreferredLayout);
- // try to align wiht customer's preferred layout so that the slice layout
- // structure is preserved, and thus avoid potential data movement acorss sg or
+ // Strategy 1: Try to preserve the consumer's slice layout structure
+ // If the consumer already expects a slice layout with the same reduction
+ // dims, we can directly use its parent layout as our source layout. This
+ // ensures perfect alignment and avoids any data movement across subgroups or
// lanes.
- const int workgroupSize = 16; // assuming 16 subgroups for now
- const int subgroupSize = 16; // assuming 16 lanes per subgroup
- const int vectorSize = 8; // assuming 8 elements per vector lane
+ // Hardware constraints (these should ideally be queried from device
+ // capabilities)
+ const int workgroupSize = 16; // Total number of subgroups in a workgroup
+ const int subgroupSize = 16; // Number of SIMD lanes per subgroup
+ const int vectorSize = 8; // Elements processed per vector instruction
int srcShapeSize = srcShape.size();
xegpu::DistributeLayoutAttr proposedSrcLayout;
auto context = consumerPreferredLayout.getContext();
- // if srcShapeSize is less than 2, we cannot proceed
+ // Reduction layout requires at least 2D tensors
if (srcShapeSize < 2)
return nullptr;
llvm::errs() << "DEBUG: Entering \n";
+ // Initialize layout components:
+ // - sgLayout[i]: Number of subgroups covering dimension i
+ // - sgData[i]: Data elements per subgroup in dimension i (srcShape[i] /
+ // sgLayout[i])
SmallVector<int64_t> sgLayout(srcShapeSize);
SmallVector<int64_t> sgData(srcShapeSize);
+ // Initialize instruction-level parallelism with SIMD-friendly defaults:
+ // - Last dimension gets subgroupSize (16) to match lane width
+ // - Second-to-last dimension gets vectorSize (8) as starting point
SmallVector<int64_t> instData(srcShapeSize, 1);
instData[srcShapeSize - 1] = subgroupSize;
instData[srcShapeSize - 2] =
- vectorSize; // assuming 8 elements per instruction as starting point
+ vectorSize; // This will be adjusted based on actual data distribution
llvm::errs() << "DEBUG: Initial instData = [";
for (size_t i = 0; i < instData.size(); i++) {
llvm::errs() << instData[i];
@@ -571,7 +590,10 @@ xegpu::reductionLayoutSetupRule(ArrayRef<int64_t> srcShape,
llvm::errs() << ", ";
}
llvm::errs() << "]\n";
- // construct a vector layout with lane_layout = [1, ..., 1, subgroupSize]
+ // Initialize lane-level distribution:
+ // - laneLayout[i]: How lanes are distributed across dimension i
+ // (innermost dimension gets all subgroupSize lanes)
+ // - laneData[i]: Data elements per lane in dimension i (starts at 1 per lane)
SmallVector<int64_t> laneLayout(srcShapeSize, 1);
laneLayout[srcShapeSize - 1] = subgroupSize;
// construct a vector layout with lane_data = [1, ..., 1]
@@ -582,13 +604,16 @@ xegpu::reductionLayoutSetupRule(ArrayRef<int64_t> srcShape,
if (i < laneLayout.size() - 1)
llvm::errs() << ", ";
}
- llvm::errs() << "]\n";
-
+ // Attempt Strategy 1: Align with consumer's slice structure
bool failToAlignSliceStruct = false;
if (sliceCPL && sliceCPL.getDims().asArrayRef().equals(reductionDims)) {
-
+ // The consumer is already expecting a slice along our reduction dimensions!
+ // Extract the parent layout (the layout before slicing) as our candidate.
xegpu::DistributeLayoutAttr parentCPL = sliceCPL.getParent();
+ // Verify that the parent layout can be adapted to our source shape:
+ // For each dimension, check if srcShape[i] is divisible by the parent's
+ // sg_layout[i]. If so, we can reuse the subgroup distribution pattern
// for each slice dim in source shape, if the dim size is differnt than the
// result shape, try to adjust the sg_data/inst_data accordingly.
SmallVector<int64_t> pcplSgLayout = parentCPL.getEffectiveSgLayoutAsInt();
@@ -664,7 +689,9 @@ xegpu::reductionLayoutSetupRule(ArrayRef<int64_t> srcShape,
for (int i = srcShapeSize - 1; i >= 0; i--) {
llvm::errs() << "DEBUG: Loop 1, i = " << i << ", is_reduction_dim = "
<< llvm::is_contained(reductionDims, i) << "\n";
- // try to align with cplSgLayout first for non-reduction dims
+
+ // For non-reduction dimensions, try to match consumer's sg_layout
+ // This ensures the result after reduction has the expected distribution
if (!llvm::is_contained(reductionDims, i) && cplId >= 0) {
if (srcShape[i] % cplSgLayout[cplId] == 0) {
sgLayout[i] = cplSgLayout[cplId];
@@ -678,10 +705,14 @@ xegpu::reductionLayoutSetupRule(ArrayRef<int64_t> srcShape,
continue;
}
}
+ // Dimension couldn't be aligned; defer to second pass
remainingDims.push_back(i);
llvm::errs() << "DEBUG: Added i = " << i << " to remainingDims\n";
}
+ // Second pass: Distribute remaining subgroups across unhandled dimensions
+ // This handles reduction dimensions and dimensions that didn't align with
+ // consumer
llvm::errs() << "DEBUG: Starting second loop\n";
for (int i = srcShapeSize - 1; i >= 0; i--) {
llvm::errs() << "DEBUG: Loop 2, i = " << i << ", is_remaining_dim = "
@@ -725,10 +756,37 @@ xegpu::reductionLayoutSetupRule(ArrayRef<int64_t> srcShape,
DenseI32ArrayAttr::get(context, laneData32),
consumerPreferredLayout.getOrder());
- // finally, create the slice layout for reduction source
- xegpu::SliceAttr reductionSrcLayout =
+ // finally, create the slice layout for reduction result
+ xegpu::SliceAttr reductionResLayout =
xegpu::SliceAttr::get(context, proposedSrcLayout,
DenseI64ArrayAttr::get(context, reductionDims));
- return reductionSrcLayout;
+ return reductionResLayout;
}
+
+/// Sets up the result layout of a bitcast from the element bit widths.
+///
+/// When the source element type is narrower than the result element type, the
+/// innermost-dimension sg_data, inst_data and lane_data of `resLayout` are
+/// multiplied by resElemTyBitWidth / srcElemTyBitWidth. In every other case
+/// `resLayout` is returned unchanged.
+xegpu::DistributeLayoutAttr
+xegpu::bitCastLayoutSetupRule(xegpu::DistributeLayoutAttr resLayout,
+                              int resElemTyBitWidth, int srcElemTyBitWidth) {
+  // No widening means nothing to scale. Returning here also guarantees that
+  // setDimData() is never called with uninitialized values.
+  if (srcElemTyBitWidth >= resElemTyBitWidth)
+    return resLayout;
+  assert(srcElemTyBitWidth > 0 && "source element bit width must be positive");
+
+  SmallVector<int64_t> sgData = resLayout.getEffectiveSgDataAsInt();
+  SmallVector<int64_t> instData = resLayout.getEffectiveInstDataAsInt();
+  SmallVector<int64_t> laneData = resLayout.getEffectiveLaneDataAsInt();
+
+  // Take the rank from whichever data field is present, so that a layout
+  // without sg_data neither wraps `size() - 1` nor loses its inst_data /
+  // lane_data scaling.
+  size_t rank = sgData.size();
+  if (instData.size() > rank)
+    rank = instData.size();
+  if (laneData.size() > rank)
+    rank = laneData.size();
+  if (rank == 0)
+    return resLayout;
+
+  const size_t dim = rank - 1;
+  const int64_t bitWidthRatio = resElemTyBitWidth / srcElemTyBitWidth;
+
+  // -1 asks setDimData() to keep the field's current value; it is used for
+  // any field that has no entry for `dim`.
+  int64_t sgDataValue = -1;
+  int64_t instDataValue = -1;
+  int64_t laneDataValue = -1;
+  if (dim < sgData.size())
+    sgDataValue = sgData[dim] * bitWidthRatio;
+  if (dim < instData.size())
+    instDataValue = instData[dim] * bitWidthRatio;
+  if (dim < laneData.size())
+    laneDataValue = laneData[dim] * bitWidthRatio;
+
+  // Scale sg_data, inst_data and lane_data of the innermost dimension.
+  return resLayout.setDimData(dim, sgDataValue, instDataValue, laneDataValue);
+}
\ No newline at end of file
>From 1e0542258a34ba8d073a4e5b103221cfcf1e1733 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Wed, 24 Dec 2025 22:16:43 +0000
Subject: [PATCH 05/35] add recover temporary layout implementation, not
compiled yet
---
.../mlir/Dialect/XeGPU/IR/XeGPUAttrs.td | 9 +-
.../Dialect/XeGPU/Utils/XeGPULayoutUtils.h | 4 +-
.../Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp | 234 +++++++++++++++---
3 files changed, 209 insertions(+), 38 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index 0e13d063498b1..14dbcc8d94630 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -237,7 +237,8 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> {
"delinearizeId",
(ins "OpBuilder &": $builder, "Location":$loc, "Value":$linearId)>,
InterfaceMethod<[{Derive a new layout with sg_data, inst_data and lane_data set to the
- specified values for the given dimension}],
+ specified values for the given dimension. Passing -1 for any parameter
+ preserves its original value.}],
"xegpu::DistributeLayoutAttr",
"setDimData",
(ins "int64_t": $dim,
@@ -521,7 +522,8 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttr]> {
DistributeLayoutAttr setUnitDimLayout(SetVector<int64_t> unitDims) const;
// Derive a new layout with sg_data, inst_data and lane_data set to the
- // specified values for the given dimension
+ // specified values for the given dimension. Passing -1 for any parameter
+ // preserves its original value.
DistributeLayoutAttr setDimData(int64_t dim, int64_t sgData, int64_t instData, int64_t laneData);
// Derive a new layout by collapsing groups of dimensions.
@@ -701,7 +703,8 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> {
DistributeLayoutAttr setUnitDimLayout(SetVector<int64_t> unitDims) const;
// Derive a new layout with sg_data, inst_data and lane_data set to the
- // specified values for the given dimension
+ // specified values for the given dimension. Passing -1 for any parameter
+ // preserves its original value.
DistributeLayoutAttr setDimData(int64_t dim, int64_t sgData, int64_t instData, int64_t laneData);
// Derive a new layout by collapsing groups of dimensions.
diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h
index fca57eb85517b..bff005e3701df 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h
@@ -96,7 +96,7 @@ void removeLayoutAttrs(Operation *op);
/// Infers the source layout attribute for a broadcast operation given the
/// result layout attribute, result shape, source shape, and broadcasted dims.
-DistributeLayoutAttr inferBroadCastSourceLayout(MLIRContext *context,
+DistributeLayoutAttr inferBroadcastSourceLayout(MLIRContext *context,
DistributeLayoutAttr resLayout,
ArrayRef<int64_t> resShape,
ArrayRef<int64_t> srcShape);
@@ -117,7 +117,7 @@ DistributeLayoutAttr inferBitCastSourceLayout(MLIRContext *context,
/// Infers the source layout attribute for a shape cast operation given the
/// result layout attribute, result shape, and source shape.
-DistributeLayoutAttr inferShapeCastSourceLayout(MLIRContext *context,
+DistributeLayoutAttr inferShapecastSourceLayout(MLIRContext *context,
DistributeLayoutAttr resLayout,
ArrayRef<int64_t> resShape,
ArrayRef<int64_t> srcShape);
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp
index 4a2b3557f1679..ccda369accdd6 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp
@@ -10,6 +10,7 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
@@ -244,24 +245,25 @@ void xegpu::recoverTemporaryLayoutsDeprecated(Operation *op) {
/// Attach layout attributes to all vector-type operands of operations within
/// the given operation's region. Reports an error if any vector operand lacks
/// a layout attribute.
-bool xegpu::recoverTemporaryLayouts(Operation *rootOp) {
- auto result = rootOp->walk([&](Operation *op) {
- for (OpOperand &operand : op->getOpOperands()) {
- // Layouts are needed for vector type only.
- if (!isa<VectorType>(operand.get().getType()))
- continue;
- auto layout = xegpu::getDistributeLayoutAttr(operand.get());
- if (!layout) {
- op->emitError("Could not find layout attribute for operand ")
- << operand.getOperandNumber() << " of operation " << op->getName();
- return WalkResult::interrupt();
- }
- xegpu::setDistributeLayoutAttr(operand, layout);
- }
- return WalkResult::advance();
- });
- return !result.wasInterrupted();
-}
+// bool xegpu::recoverTemporaryLayouts(Operation *rootOp) {
+// auto result = rootOp->walk([&](Operation *op) {
+// for (OpOperand &operand : op->getOpOperands()) {
+// // Layouts are needed for vector type only.
+// if (!isa<VectorType>(operand.get().getType()))
+// continue;
+// auto layout = xegpu::getDistributeLayoutAttr(operand.get());
+// if (!layout) {
+// op->emitError("Could not find layout attribute for operand ")
+// << operand.getOperandNumber() << " of operation " <<
+// op->getName();
+// return WalkResult::interrupt();
+// }
+// xegpu::setDistributeLayoutAttr(operand, layout);
+// }
+// return WalkResult::advance();
+// });
+// return !result.wasInterrupted();
+// }
// Prerequisite for Layout Recovery
// It relies on the following invariant:
@@ -274,10 +276,11 @@ bool xegpu::recoverTemporaryLayouts(Operation *rootOp) {
// - Only the result of convert_layout is permitted to have no subsequent
// use.
-// The recover proceeds by scanning the operation in reverse topological orderas
-// follows: Across operations: layouts are propagated from uses to definitions.
-// Within an operation: layouts are propagated from definitions (result) to uses
-// (operands).
+// The recover proceeds by scanning the operation in reverse topological order
+// as follows:
+// For regular operations: First the result layouts are propagated from uses.
+// Then the result layouts are propagated to uses (operands).
+//
// For region operations (e.g., loops):
// - When backward propagation reaches a region op, it sets the layout of
// the region op’s results according to use points like regular ops.
@@ -291,7 +294,168 @@ bool xegpu::recoverTemporaryLayouts(Operation *rootOp) {
// across region boundaries
// while preserving a single well-defined use for each definition at the
// region-op level.
-// bool xegpu::recoverTemporaryLayouts_first(Operation *rootOp) {}
+
+// Helpers used by recoverTemporaryLayouts(). They are defined further down in
+// this file, so they are declared here to be visible at the point of use.
+static void walkRegionBackward(Region &region,
+                               llvm::function_ref<void(Operation *)> visit);
+static void propagateResultsToRegularOperands(Operation *op);
+static void propagateRegionResultsToYieldOperands(
+    mlir::RegionBranchTerminatorOpInterface yieldOp);
+// NOTE(review): declared by value to match the call below; the definition
+// must take the interface by value as well.
+void propagateRegionArgsToInits(mlir::RegionBranchOpInterface regionOp);
+
+// Entry point of temporary-layout recovery.
+//
+// For every func::FuncOp nested under `rootOp`, the function body is walked
+// backward with walkRegionBackward(), which recursively processes nested
+// regions before the op that owns them. Each visited op is dispatched by kind:
+//  - RegionBranchOpInterface: propagateRegionArgsToInits()
+//  - RegionBranchTerminatorOpInterface: propagateRegionResultsToYieldOperands()
+//  - anything else: propagateResultsToRegularOperands()
+//
+// Always returns true: the helpers return void and flag a missing layout with
+// an assertion rather than through a status value.
+bool xegpu::recoverTemporaryLayouts(Operation *rootOp) {
+  rootOp->walk([&](func::FuncOp func) {
+    walkRegionBackward(func.getBody(), [&](Operation *op) {
+      if (auto regionOp = dyn_cast<mlir::RegionBranchOpInterface>(op)) {
+        // The region op is reached only after everything inside it.
+        propagateRegionArgsToInits(regionOp);
+      } else if (auto yieldOp =
+                     dyn_cast<mlir::RegionBranchTerminatorOpInterface>(op)) {
+        // Terminator inside a region op.
+        propagateRegionResultsToYieldOperands(yieldOp);
+      } else {
+        // Regular op: push the result layout down to its operands.
+        propagateResultsToRegularOperands(op);
+      }
+    });
+  });
+  return true;
+}
+
+// Visits every operation of `region` in reverse order: blocks back to front,
+// and within each block operations back to front. Before an operation is
+// handed to `visit`, all of its nested regions are traversed (last region
+// first), so e.g. a terminator inside a region op is visited before the region
+// op itself.
+static void walkRegionBackward(Region &region,
+                               llvm::function_ref<void(Operation *)> visit) {
+  // blocks: back -> front
+  for (Block &block : llvm::reverse(region)) {
+    // ops: back -> front. Early-inc iteration advances the iterator before the
+    // loop body runs, so visit() may erase the current op safely.
+    for (Operation &op : llvm::make_early_inc_range(llvm::reverse(block))) {
+      // First descend into the op's regions, then visit the op itself.
+      for (Region &nested : llvm::reverse(op.getRegions()))
+        walkRegionBackward(nested, visit);
+
+      visit(&op);
+    }
+  }
+}
+
+// Propagates layouts backward through a regular operation, i.e. one that is
+// neither a region op nor a region terminator.
+//
+// The layout already attached to the op's first result is the starting point:
+//  - vector::MultiDimReductionOp, vector::BroadcastOp, vector::BitCastOp and
+//    vector::ShapeCastOp derive the source layout through the matching
+//    xegpu::infer*SourceLayout helper;
+//  - ops implementing xegpu::AnchorLayoutInterface are left untouched;
+//  - for every other op the result layout is set on each vector operand.
+// Ops without results, or whose first result is not a vector, are skipped.
+static void propagateResultsToRegularOperands(Operation *op) {
+  // If op is an anchor op, no need to do anything. This is checked before the
+  // result is inspected so that it does not depend on the op having a result.
+  if (isa<xegpu::AnchorLayoutInterface>(op))
+    return;
+
+  // Indexing result 0 of a result-less op would be out of bounds, and layouts
+  // are needed for vector type only.
+  if (op->getNumResults() == 0)
+    return;
+  OpResult result = op->getOpResults()[0];
+  if (!isa<VectorType>(result.getType()))
+    return;
+
+  auto resLayout = xegpu::getDistributeLayoutAttr(result);
+  assert(resLayout &&
+         "result layout must be defined before propagating to uses");
+
+  // Reduction: the source layout is inferred from the result layout and the
+  // reduction dims; the accumulator shares the result layout.
+  if (auto reduceOp = dyn_cast<vector::MultiDimReductionOp>(op)) {
+    SmallVector<int64_t> reduceDims(reduceOp.getReductionDims().begin(),
+                                    reduceOp.getReductionDims().end());
+    auto srcLayout = xegpu::inferReductionSourceLayout(resLayout, reduceDims);
+    xegpu::setTemporaryLayout(reduceOp.getSource(), srcLayout);
+    xegpu::setTemporaryLayout(reduceOp.getAcc(), resLayout);
+    return;
+  }
+
+  // Broadcast: infer the source layout from the result/source shapes.
+  if (auto broadcastOp = dyn_cast<vector::BroadcastOp>(op)) {
+    ArrayRef<int64_t> resShape =
+        llvm::cast<VectorType>(broadcastOp.getResult().getType()).getShape();
+    ArrayRef<int64_t> srcShape =
+        llvm::cast<VectorType>(broadcastOp.getSource().getType()).getShape();
+    auto srcLayout =
+        xegpu::inferBroadcastSourceLayout(resLayout, resShape, srcShape);
+    xegpu::setTemporaryLayout(broadcastOp.getSource(), srcLayout);
+    return;
+  }
+
+  // Bitcast: infer the source layout from the element bit widths.
+  if (auto bitcastOp = dyn_cast<vector::BitCastOp>(op)) {
+    int resElemTyBitWidth =
+        llvm::cast<VectorType>(bitcastOp.getResult().getType())
+            .getElementTypeBitWidth();
+    int srcElemTyBitWidth =
+        llvm::cast<VectorType>(bitcastOp.getSource().getType())
+            .getElementTypeBitWidth();
+    auto srcLayout = xegpu::inferBitCastSourceLayout(
+        op->getContext(), resLayout, resElemTyBitWidth, srcElemTyBitWidth);
+    xegpu::setTemporaryLayout(bitcastOp.getSource(), srcLayout);
+    return;
+  }
+
+  // Shape cast: infer the source layout from the result/source shapes.
+  if (auto shapeCastOp = dyn_cast<vector::ShapeCastOp>(op)) {
+    ArrayRef<int64_t> resShape =
+        llvm::cast<VectorType>(shapeCastOp.getResult().getType()).getShape();
+    ArrayRef<int64_t> srcShape =
+        llvm::cast<VectorType>(shapeCastOp.getSource().getType()).getShape();
+    auto srcLayout =
+        xegpu::inferShapecastSourceLayout(resLayout, resShape, srcShape);
+    xegpu::setTemporaryLayout(shapeCastOp.getSource(), srcLayout);
+    return;
+  }
+
+  // For other regular ops, propagate the result layout to all vector operands.
+  for (OpOperand &opr : op->getOpOperands()) {
+    // Layouts are needed for vector type only.
+    if (!isa<VectorType>(opr.get().getType()))
+      continue;
+    xegpu::setTemporaryLayout(opr, resLayout);
+  }
+}
+
+// Copies layouts onto the operands of a region terminator.
+//
+// The terminator's successors are queried with all operand attributes unknown
+// (null). For each successor that passes the filter below, the layout found on
+// successor input `i` is set as a temporary layout on terminator operand `i`.
+// A successor input without a layout triggers an assertion.
+static void propagateRegionResultsToYieldOperands(
+ mlir::RegionBranchTerminatorOpInterface yieldOp) {
+ llvm::SmallVector<mlir::RegionSuccessor> successors;
+ llvm::SmallVector<mlir::Attribute> operands(yieldOp->getNumOperands(),
+ nullptr);
+ yieldOp.getSuccessorRegions(operands, successors);
+
+ for (mlir::RegionSuccessor &successor : successors) {
+ // find out the successor which is the parent region of yieldOp
+ // NOTE(review): this keeps only successors whose region is the region that
+ // contains the terminator, and skips all others. Given the function name,
+ // the successor of interest is presumably the one that returns to the
+ // parent op - confirm which one is intended (open question marked by the
+ // author).
+ if (successor.getSuccessorRegion() != yieldOp->getParentRegion()) //????//
+ continue;
+ // propagate the layout from region result to yield operands
+ // NOTE(review): operand `i` of the terminator is paired with successor
+ // input `i`; this assumes every terminator operand is forwarded, in order
+ // - TODO confirm.
+ for (unsigned i = 0; i < successor.getSuccessorInputs().size(); ++i) {
+ Value regionResult = successor.getSuccessorInputs()[i]; // region argument
+ Value yieldOperand = yieldOp->getOperand(i); // yield operand
+
+ auto layout = xegpu::getDistributeLayoutAttr(regionResult);
+ assert(
+ layout &&
+ "region result layout must be defined before propagating to yield");
+ xegpu::setTemporaryLayout(yieldOperand, layout);
+ }
+ }
+}
+
+// Propagates the layouts of a region op's entry-region arguments back to the
+// operands ("inits") that the op forwards into them.
+//
+// For every region that can be entered first, each forwarded operand receives,
+// as a temporary layout, the layout found on the region argument it feeds.
+// Non-vector arguments are skipped; a vector argument without a layout
+// triggers an assertion.
+//
+// The interface is taken by value, which is how the caller passes it and how
+// the body uses it.
+void propagateRegionArgsToInits(mlir::RegionBranchOpInterface regionOp) {
+  // Get entry successors (regions that can be entered initially).
+  SmallVector<RegionSuccessor> successors;
+  regionOp.getEntrySuccessorRegions(/*operands=*/ArrayRef<Attribute>(),
+                                    successors);
+
+  for (RegionSuccessor &successor : successors) {
+    // A successor that returns straight to the parent op has no region
+    // arguments to read layouts from.
+    if (successor.isParent())
+      continue;
+
+    // Operands forwarded into this successor, paired with the successor
+    // inputs they initialize. The region may own further arguments that are
+    // not fed by any operand, so the pairing uses the successor inputs rather
+    // than the region's full argument list.
+    OperandRange initOperands = regionOp.getEntrySuccessorOperands(successor);
+    for (auto [initOperand, regionArg] :
+         llvm::zip(initOperands, successor.getSuccessorInputs())) {
+      // Layouts are needed for vector type only.
+      if (!isa<VectorType>(regionArg.getType()))
+        continue;
+      auto layout = xegpu::getDistributeLayoutAttr(regionArg);
+      assert(
+          layout &&
+          "region argument layout must be defined before propagating to init");
+      xegpu::setTemporaryLayout(initOperand, layout);
+    }
+  }
+}
template <typename T, typename>
void xegpu::removeLayoutAttr(const T &operandOrResult) {
@@ -328,9 +492,10 @@ void xegpu::removeLayoutAttrs(Operation *op) {
/// Infers the source layout attribute for a broadcast operation given the
/// result layout attribute, result shape, source shape.
-xegpu::DistributeLayoutAttr xegpu::inferBroadCastSourceLayout(
- MLIRContext *context, xegpu::DistributeLayoutAttr resLayout,
- ArrayRef<int64_t> resShape, ArrayRef<int64_t> srcShape) {
+xegpu::DistributeLayoutAttr
+xegpu::inferBroadcastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
+ ArrayRef<int64_t> resShape,
+ ArrayRef<int64_t> srcShape) {
SmallVector<int64_t> bcastDims;
auto returnLayout = resLayout;
@@ -345,7 +510,8 @@ xegpu::DistributeLayoutAttr xegpu::inferBroadCastSourceLayout(
// create a slice layout for the source
returnLayout = xegpu::SliceAttr::get(
- context, resLayout, DenseI64ArrayAttr::get(context, bcastDims));
+ resLayout.getContext(), resLayout,
+ DenseI64ArrayAttr::get(resLayout.getContext(), bcastDims));
}
return returnLayout;
}
@@ -415,9 +581,10 @@ xegpu::inferBitCastSourceLayout(MLIRContext *context,
/// Infers the source layout attribute for a shape cast operation given the
/// result layout attribute, result shape, and source shape.
-xegpu::DistributeLayoutAttr xegpu::inferShapeCastSourceLayout(
- MLIRContext *context, xegpu::DistributeLayoutAttr resLayout,
- ArrayRef<int64_t> resShape, ArrayRef<int64_t> srcShape) {
+xegpu::DistributeLayoutAttr
+xegpu::inferShapecastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
+ ArrayRef<int64_t> resShape,
+ ArrayRef<int64_t> srcShape) {
// there are two use cases:
// 1. expand dims of low-rank dimensions (e.g., 1D to 2D): to set up the
@@ -449,9 +616,10 @@ xegpu::DistributeLayoutAttr xegpu::inferShapeCastSourceLayout(
if (checkOnlyExpandUnitDims(srcShape, resShape)) {
// create a slice layout for the source by removing the expanded unit dims
- auto sliceDimsAttr =
- DenseI64ArrayAttr::get(context, ArrayRef<int64_t>(expandedUnitDims));
- auto srcLayout = xegpu::SliceAttr::get(context, resLayout, sliceDimsAttr);
+ auto sliceDimsAttr = DenseI64ArrayAttr::get(
+ resLayout.getContext(), ArrayRef<int64_t>(expandedUnitDims));
+ auto srcLayout =
+ xegpu::SliceAttr::get(resLayout.getContext(), resLayout, sliceDimsAttr);
return srcLayout;
}
>From e73664ce582e072f3444cf50e71ca6d73a7bbe3c Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Wed, 24 Dec 2025 22:25:07 +0000
Subject: [PATCH 06/35] remove debug print
---
.../Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp | 150 +++++-------------
1 file changed, 38 insertions(+), 112 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp
index ccda369accdd6..a995249d87e95 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp
@@ -717,13 +717,7 @@ xegpu::reductionLayoutSetupRule(ArrayRef<int64_t> srcShape,
xegpu::SliceAttr sliceCPL =
dyn_cast<xegpu::SliceAttr>(consumerPreferredLayout);
- // Strategy 1: Try to preserve the consumer's slice layout structure
- // If the consumer already expects a slice layout with the same reduction
- // dims, we can directly use its parent layout as our source layout. This
- // ensures perfect alignment and avoids any data movement across subgroups or
- // lanes.
-
- // Hardware constraints (these should ideally be queried from device
+ // Hardware constraints (TODO: these should ideally be queried from device
// capabilities)
const int workgroupSize = 16; // Total number of subgroups in a workgroup
const int subgroupSize = 16; // Number of SIMD lanes per subgroup
@@ -735,44 +729,23 @@ xegpu::reductionLayoutSetupRule(ArrayRef<int64_t> srcShape,
if (srcShapeSize < 2)
return nullptr;
- llvm::errs() << "DEBUG: Entering \n";
-
- // Initialize layout components:
- // - sgLayout[i]: Number of subgroups covering dimension i
- // - sgData[i]: Data elements per subgroup in dimension i (srcShape[i] /
- // sgLayout[i])
SmallVector<int64_t> sgLayout(srcShapeSize);
SmallVector<int64_t> sgData(srcShapeSize);
- // Initialize instruction-level parallelism with SIMD-friendly defaults:
- // - Last dimension gets subgroupSize (16) to match lane width
- // - Second-to-last dimension gets vectorSize (8) as starting point
SmallVector<int64_t> instData(srcShapeSize, 1);
instData[srcShapeSize - 1] = subgroupSize;
instData[srcShapeSize - 2] =
vectorSize; // This will be adjusted based on actual data distribution
- llvm::errs() << "DEBUG: Initial instData = [";
- for (size_t i = 0; i < instData.size(); i++) {
- llvm::errs() << instData[i];
- if (i < instData.size() - 1)
- llvm::errs() << ", ";
- }
- llvm::errs() << "]\n";
- // Initialize lane-level distribution:
- // - laneLayout[i]: How lanes are distributed across dimension i
- // (innermost dimension gets all subgroupSize lanes)
- // - laneData[i]: Data elements per lane in dimension i (starts at 1 per lane)
+
SmallVector<int64_t> laneLayout(srcShapeSize, 1);
laneLayout[srcShapeSize - 1] = subgroupSize;
- // construct a vector layout with lane_data = [1, ..., 1]
SmallVector<int64_t> laneData(srcShapeSize, 1);
- llvm::errs() << "DEBUG: laneLayout = [";
- for (size_t i = 0; i < laneLayout.size(); i++) {
- llvm::errs() << laneLayout[i];
- if (i < laneLayout.size() - 1)
- llvm::errs() << ", ";
- }
- // Attempt Strategy 1: Align with consumer's slice structure
+
+ // Strategy 1: Try to preserve the consumer's slice layout structure
+ // If the consumer already expects a slice layout with the same reduction
+ // dims, we can directly use its parent layout as our source layout. This
+ // ensures perfect alignment and avoids any data movement across subgroups or
+ // lanes.
bool failToAlignSliceStruct = false;
if (sliceCPL && sliceCPL.getDims().asArrayRef().equals(reductionDims)) {
// The consumer is already expecting a slice along our reduction dimensions!
@@ -782,7 +755,7 @@ xegpu::reductionLayoutSetupRule(ArrayRef<int64_t> srcShape,
// Verify that the parent layout can be adapted to our source shape:
// For each dimension, check if srcShape[i] is divisible by the parent's
// sg_layout[i]. If so, we can reuse the subgroup distribution pattern
- // for each slice dim in source shape, if the dim size is differnt than the
+ // for each slice dim in source shape, if the dim size is different than the
// result shape, try to adjust the sg_data/inst_data accordingly.
SmallVector<int64_t> pcplSgLayout = parentCPL.getEffectiveSgLayoutAsInt();
SmallVector<int64_t> pcplLaneLayout =
@@ -794,16 +767,6 @@ xegpu::reductionLayoutSetupRule(ArrayRef<int64_t> srcShape,
proposedSrcLayout = parentCPL;
- llvm::errs() << "DEBUG: srcShapeSize = " << srcShapeSize << "\n";
- llvm::errs() << "DEBUG: parentCPL rank = " << parentCPL.getRank() << "\n";
- llvm::errs() << "DEBUG: srcShape = [";
- for (int i = 0; i < srcShapeSize; i++) {
- llvm::errs() << srcShape[i];
- if (i < srcShapeSize - 1)
- llvm::errs() << ", ";
- }
- llvm::errs() << "]\n";
-
if (pcplSgLayout.size() == static_cast<size_t>(srcShapeSize)) {
for (int i = 0; i < srcShapeSize; i++) {
if (srcShape[i] % pcplSgLayout[i] == 0) {
@@ -832,85 +795,49 @@ xegpu::reductionLayoutSetupRule(ArrayRef<int64_t> srcShape,
failToAlignSliceStruct = true;
}
+ // Strategy 2: Construct a new layout aligned with consumer's sg_layout for
+ // the result (non-reduction dims) then distribute remaining subgroups across
+ // other dimensions
if (failToAlignSliceStruct) {
-
- // try to align the sg layout
SmallVector<int64_t> cplSgLayout =
consumerPreferredLayout.getEffectiveSgLayoutAsInt();
- llvm::errs() << "DEBUG: cplSgLayout size = " << cplSgLayout.size() << "\n";
- // if sg layout doesn't cover all the sg ids, distribute rest to
- // non-reduction dims
int remainingSgCount = workgroupSize;
-
SmallVector<int64_t> remainingDims;
- // print debug info for consumerPreferredLayout and cplSgLayout
- llvm::errs() << "DEBUG: consumerPreferredLayout sgLayout = [";
- auto cplSgLayoutFull = consumerPreferredLayout.getEffectiveSgLayoutAsInt();
- for (size_t i = 0; i < cplSgLayoutFull.size(); i++) {
- llvm::errs() << cplSgLayoutFull[i];
- if (i < cplSgLayoutFull.size() - 1)
- llvm::errs() << ", ";
- }
- // if cplSgLayout is not empty, try to align the sg layout first
int cplId = cplSgLayout.size() - 1;
- llvm::errs() << "DEBUG: Starting first loop, cplId = " << cplId << "\n";
- for (int i = srcShapeSize - 1; i >= 0; i--) {
- llvm::errs() << "DEBUG: Loop 1, i = " << i << ", is_reduction_dim = "
- << llvm::is_contained(reductionDims, i) << "\n";
-
- // For non-reduction dimensions, try to match consumer's sg_layout
- // This ensures the result after reduction has the expected distribution
- if (!llvm::is_contained(reductionDims, i) && cplId >= 0) {
- if (srcShape[i] % cplSgLayout[cplId] == 0) {
- sgLayout[i] = cplSgLayout[cplId];
- sgData[i] = srcShape[i] / sgLayout[i];
- instData[i] = std::min(instData[i], sgData[i]);
- remainingSgCount /= sgLayout[i];
- llvm::errs() << "DEBUG: Set sgLayout[" << i << "] = " << sgLayout[i]
- << ", sgData[" << i << "] = " << sgData[i]
- << ", remainingSgCount = " << remainingSgCount << "\n";
- cplId--;
- continue;
- }
- }
- // Dimension couldn't be aligned; defer to second pass
- remainingDims.push_back(i);
- llvm::errs() << "DEBUG: Added i = " << i << " to remainingDims\n";
- }
- // Second pass: Distribute remaining subgroups across unhandled dimensions
- // This handles reduction dimensions and dimensions that didn't align with
- // consumer
- llvm::errs() << "DEBUG: Starting second loop\n";
- for (int i = srcShapeSize - 1; i >= 0; i--) {
- llvm::errs() << "DEBUG: Loop 2, i = " << i << ", is_remaining_dim = "
- << llvm::is_contained(remainingDims, i) << "\n";
- if (llvm::is_contained(remainingDims, i)) {
-
- llvm::errs() << "DEBUG: Before Set sgLayout[" << i
- << "] = " << sgLayout[i] << ", sgData[" << i
- << "] = " << sgData[i]
- << ", remainingSgCount = " << remainingSgCount << "\n";
-
- sgLayout[i] = std::min((srcShape[i] / laneLayout[i]),
- static_cast<int64_t>(remainingSgCount));
+ // For non-reduction dimensions, try to match consumer's sg_layout
+ // This ensures the result after reduction has the expected distribution
+ if (!llvm::is_contained(reductionDims, i) && cplId >= 0) {
+ if (srcShape[i] % cplSgLayout[cplId] == 0) {
+ sgLayout[i] = cplSgLayout[cplId];
sgData[i] = srcShape[i] / sgLayout[i];
instData[i] = std::min(instData[i], sgData[i]);
remainingSgCount /= sgLayout[i];
+ cplId--;
+ continue;
+ }
+ }
+ remainingDims.push_back(i);
+ }
- llvm::errs() << "DEBUG: After Set sgLayout[" << i
- << "] = " << sgLayout[i] << ", sgData[" << i
- << "] = " << sgData[i]
- << ", remainingSgCount = " << remainingSgCount << "\n";
+ // Second pass: Distribute remaining subgroups across unhandled dimensions
+ // This handles reduction dimensions and dimensions that didn't align with
+ // consumer
+ for (int i = srcShapeSize - 1; i >= 0; i--) {
- if (remainingSgCount == 1) {
- llvm::errs() << "DEBUG: Breaking from loop 2, remainingSgCount = 1\n";
- break;
- }
- }
+ sgLayout[i] = std::min((srcShape[i] / laneLayout[i]),
+ static_cast<int64_t>(remainingSgCount));
+ sgData[i] = srcShape[i] / sgLayout[i];
+ instData[i] = std::min(instData[i], sgData[i]);
+ remainingSgCount /= sgLayout[i];
+
+ if (remainingSgCount == 1) {
+ break;
+ }
+ }
}
}
- // Convert int64_t vectors to int32_t for DenseI32ArrayAttr
+
SmallVector<int32_t> sgLayout32(sgLayout.begin(), sgLayout.end());
SmallVector<int32_t> sgData32(sgData.begin(), sgData.end());
SmallVector<int32_t> instData32(instData.begin(), instData.end());
@@ -924,7 +851,6 @@ xegpu::reductionLayoutSetupRule(ArrayRef<int64_t> srcShape,
DenseI32ArrayAttr::get(context, laneData32),
consumerPreferredLayout.getOrder());
- // finally, create the slice layout for reduction result
xegpu::SliceAttr reductionResLayout =
xegpu::SliceAttr::get(context, proposedSrcLayout,
DenseI64ArrayAttr::get(context, reductionDims));
>From 5565f60ff77e6d87bf780fab11d370054a3c59f0 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Fri, 23 Jan 2026 19:29:08 +0000
Subject: [PATCH 07/35] improve reductionSetupRule
---
.../Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp | 138 +++++++++---------
1 file changed, 73 insertions(+), 65 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp
index 327a1901a1bea..fb416b4c81f15 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp
@@ -637,7 +637,7 @@ xegpu::inferShapecastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
ArrayRef<int64_t> resShape,
ArrayRef<int64_t> srcShape) {
- // there are two use cases:
+ // there are three use cases:
// 1. expand dims of low-rank dimensions (e.g., 1D to 2D): to set up the
// tensor before broadcast
// 2. split dim of a high-rank dimension (e.g., 1D to 2D): to setup tensor
@@ -763,10 +763,10 @@ xegpu::inferShapecastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
xegpu::SliceAttr
xegpu::reductionLayoutSetupRule(ArrayRef<int64_t> srcShape,
SmallVector<int64_t> reductionDims,
- DistributeLayoutAttr consumerPreferredLayout) {
+ DistributeLayoutAttr consumerLayout) {
- xegpu::SliceAttr sliceCPL =
- dyn_cast<xegpu::SliceAttr>(consumerPreferredLayout);
+ xegpu::SliceAttr consumerSliceLayout =
+ dyn_cast<xegpu::SliceAttr>(consumerLayout);
// Hardware constraints (TODO: these should ideally be queried from device
// capabilities)
@@ -775,7 +775,7 @@ xegpu::reductionLayoutSetupRule(ArrayRef<int64_t> srcShape,
const int vectorSize = 8; // Elements processed per vector instruction
int srcShapeSize = srcShape.size();
xegpu::DistributeLayoutAttr proposedSrcLayout;
- auto context = consumerPreferredLayout.getContext();
+ auto context = consumerLayout.getContext();
// Reduction layout requires at least 2D tensors
if (srcShapeSize < 2)
return nullptr;
@@ -797,58 +797,69 @@ xegpu::reductionLayoutSetupRule(ArrayRef<int64_t> srcShape,
// dims, we can directly use its parent layout as our source layout. This
// ensures best alignment and avoids any data movement across subgroups
// and lanes.
- bool failToAlignSliceStruct = false;
- if (sliceCPL && sliceCPL.getDims().asArrayRef().equals(reductionDims) &&
- sliceCPL.getParent().getRank() == srcShapeSize) {
- // The consumer is expecting a slice along the current reduction dimensions
- xegpu::DistributeLayoutAttr parentCPL = sliceCPL.getParent();
+  auto canPreserveSliceLayout =
+      [&](ArrayRef<int64_t> srcShape, SmallVector<int64_t> reductionDims,
+          DistributeLayoutAttr consumerLayout) -> bool {
     // Verify that the consumer layout can be adapted to the source shape:
     // For each dimension, check if srcShape[i] is divisible by the parent's
-    // sg_layout[i]. If so, we can reuse the subgroup distribution pattern
+    // sg_layout[i] and instData[i] by the parent's lane_layout[i]. If so,
+    // sg_layout and lane_layout can be reused.
+
+    if (!consumerSliceLayout)
+      return false;
+    if (!consumerSliceLayout.getDims().asArrayRef().equals(reductionDims))
+      return false;
+    xegpu::DistributeLayoutAttr parentLayout = consumerSliceLayout.getParent();
+    // The parent layout must have the same rank as the source. `!=` is
+    // required here: `!rank == size` negates the rank before comparing.
+    if (parentLayout.getRank() != srcShapeSize)
+      return false;
+
+    SmallVector<int64_t> parentSgLayout =
+        parentLayout.getEffectiveSgLayoutAsInt();
+    SmallVector<int64_t> parentLaneLayout =
+        parentLayout.getEffectiveLaneLayoutAsInt();
+
+    if (parentSgLayout.size() != static_cast<size_t>(srcShapeSize))
+      return false;
+    if (parentLaneLayout.size() != static_cast<size_t>(srcShapeSize))
+      return false;
+    for (int i = 0; i < srcShapeSize; i++) {
+      if (srcShape[i] % parentSgLayout[i] != 0)
+        return false;
+      if (instData[i] % parentLaneLayout[i] != 0)
+        return false;
+    }
+    return true;
+  };
+
+ if (canPreserveSliceLayout(srcShape, reductionDims, consumerLayout)) {
// for each slice dim in source shape, if the dim size is different than the
// result shape, try to adjust the sg_data/inst_data accordingly.
- SmallVector<int64_t> pcplSgLayout = parentCPL.getEffectiveSgLayoutAsInt();
- SmallVector<int64_t> pcplLaneLayout =
- parentCPL.getEffectiveLaneLayoutAsInt();
- SmallVector<int64_t> pcplLaneData = parentCPL.getEffectiveLaneDataAsInt();
-
- proposedSrcLayout = parentCPL;
-
- if (pcplSgLayout.size() == static_cast<size_t>(srcShapeSize)) {
- for (int i = 0; i < srcShapeSize; i++) {
- if (srcShape[i] % pcplSgLayout[i] == 0) {
- sgLayout[i] = pcplSgLayout[i];
- sgData[i] = srcShape[i] / sgLayout[i];
- instData[i] = std::min(instData[i], sgData[i]);
- } else {
- failToAlignSliceStruct = true;
- break;
- }
- }
+ SmallVector<int64_t> parentSgLayout =
+ consumerSliceLayout.getEffectiveSgLayoutAsInt();
+ SmallVector<int64_t> parentLaneLayout =
+ consumerSliceLayout.getEffectiveLaneLayoutAsInt();
+ SmallVector<int64_t> parentLaneData =
+ consumerSliceLayout.getEffectiveLaneDataAsInt();
+
+ for (int i = 0; i < srcShapeSize; i++) {
+ sgLayout[i] = parentSgLayout[i];
+ sgData[i] = srcShape[i] / sgLayout[i];
+ instData[i] = std::min(instData[i], sgData[i]);
+ laneLayout[i] = parentLaneLayout[i];
+ laneData[i] = parentLaneData[i];
}
- if (pcplLaneLayout.size() == static_cast<size_t>(srcShapeSize)) {
- for (int i = 0; i < srcShapeSize; i++) {
- if (instData[i] % pcplLaneLayout[i] == 0) {
- laneLayout[i] = pcplLaneLayout[i];
- laneData[i] = pcplLaneData[i];
- } else {
- failToAlignSliceStruct = true;
- break;
- }
- }
- }
} else {
- failToAlignSliceStruct = true;
- }
+ // Strategy 2: Construct a new layout aligned with consumer's sg_layout for
+ // the result (non-reduction dims) then distribute remaining subgroups
+ // across reduced dimensions
- // Strategy 2: Construct a new layout aligned with consumer's sg_layout for
- // the result (non-reduction dims) then distribute remaining subgroups across
- // other dimensions
- if (failToAlignSliceStruct) {
SmallVector<int64_t> cplSgLayout =
- consumerPreferredLayout.getEffectiveSgLayoutAsInt();
+ consumerLayout.getEffectiveSgLayoutAsInt();
+ SmallVector<int64_t> cplSgData = consumerLayout.getEffectiveSgDataAsInt();
+ SmallVector<int64_t> cplInstData =
+ consumerLayout.getEffectiveInstDataAsInt();
int remainingSgCount = workgroupSize;
SmallVector<int64_t> remainingDims;
int cplId = cplSgLayout.size() - 1;
@@ -856,31 +867,29 @@ xegpu::reductionLayoutSetupRule(ArrayRef<int64_t> srcShape,
// This ensures the result after reduction has the expected distribution
for (int i = srcShapeSize - 1; i >= 0; i--) {
if (!llvm::is_contained(reductionDims, i) && cplId >= 0) {
- if (srcShape[i] % cplSgLayout[cplId] == 0) {
- sgLayout[i] = cplSgLayout[cplId];
- sgData[i] = srcShape[i] / sgLayout[i];
- instData[i] = std::min(instData[i], sgData[i]);
- remainingSgCount /= sgLayout[i];
- cplId--;
- continue;
- }
+ assert((srcShape[i] % cplSgLayout[cplId] == 0) &&
+ "source shape not divisible by consumer sg_layout");
+ sgLayout[i] = cplSgLayout[cplId];
+ sgData[i] = srcShape[i] / sgLayout[i];
+ instData[i] = std::min(cplInstData[cplId], sgData[i]);
+ remainingSgCount /= sgLayout[i];
+ cplId--;
}
- remainingDims.push_back(i);
}
// Second pass: Distribute remaining subgroups across unhandled dimensions
// This handles reduction dimensions and dimensions that didn't align with
// consumer
for (int i = srcShapeSize - 1; i >= 0; i--) {
- sgLayout[i] = std::min((srcShape[i] / laneLayout[i]),
- static_cast<int64_t>(remainingSgCount));
- sgData[i] = srcShape[i] / sgLayout[i];
- instData[i] = std::min(instData[i], sgData[i]);
- remainingSgCount /= sgLayout[i];
- if (remainingSgCount == 1) {
- break;
+ if (llvm::is_contained(reductionDims, i)) {
+ sgLayout[i] = std::min((srcShape[i] / laneLayout[i]),
+ static_cast<int64_t>(remainingSgCount));
+ sgData[i] = srcShape[i] / sgLayout[i];
+ instData[i] = std::min(instData[i], sgData[i]);
+ remainingSgCount /= sgLayout[i];
}
}
+ assert(remainingSgCount == 1 && "not all subgroups have been distributed");
}
SmallVector<int32_t> sgLayout32(sgLayout.begin(), sgLayout.end());
@@ -893,8 +902,7 @@ xegpu::reductionLayoutSetupRule(ArrayRef<int64_t> srcShape,
DenseI32ArrayAttr::get(context, sgData32),
DenseI32ArrayAttr::get(context, instData32),
DenseI32ArrayAttr::get(context, laneLayout32),
- DenseI32ArrayAttr::get(context, laneData32),
- consumerPreferredLayout.getOrder());
+ DenseI32ArrayAttr::get(context, laneData32), consumerLayout.getOrder());
xegpu::SliceAttr reductionResLayout =
xegpu::SliceAttr::get(context, proposedSrcLayout,
DenseI64ArrayAttr::get(context, reductionDims));
>From 540d1e0af9a057e574c2969aff770478738ec1b1 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Sun, 25 Jan 2026 04:51:39 +0000
Subject: [PATCH 08/35] add layoutKind parameter for SetupResultLayout
functions
---
.../Dialect/XeGPU/Utils/XeGPULayoutUtils.h | 63 +--
.../mlir/Dialect/XeGPU/Utils/XeGPUUtils.h | 44 +-
.../XeGPU/Transforms/XeGPUBlocking.cpp | 1 +
.../Transforms/XeGPUPeepHoleOptimizer.cpp | 1 +
.../XeGPU/Transforms/XeGPUPropagateLayout.cpp | 184 +++----
.../Transforms/XeGPUSubgroupDistribute.cpp | 1 +
.../Dialect/XeGPU/Transforms/XeGPUUnroll.cpp | 1 +
.../Transforms/XeGPUWgToSgDistribute.cpp | 1 +
.../Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp | 473 +++++++-----------
9 files changed, 303 insertions(+), 466 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h
index 7ea949ceaaa94..206b3d85df30b 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h
@@ -11,8 +11,10 @@
#define MLIR_DIALECT_XEGPU_UTILS_XEGPULAYOUTUTILS_H_
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
+
namespace mlir {
class VectorType;
@@ -31,47 +33,6 @@ class TensorDescType;
namespace xegpu {
-/// Return the attribute name for the OpOperand to attach DistributeLayoutAttr
-std::string getTemporaryLayoutName(const OpOperand &operand);
-
-/// Return the attribute name for the OpResult to attach DistributeLayoutAttr
-std::string getTemporaryLayoutName(const OpResult result);
-
-/// Retrieves the DistributeLayoutAttr associated with a given Value. For
-/// TensorDescType values, the DistributeLayoutAttr is extracted from the
-/// TensorDescType itself. For other values, it is obtained from the attributes
-/// of the defining operation. Returns nullptr if no DistributeLayoutAttr is
-/// found.
-DistributeLayoutAttr getDistributeLayoutAttr(const Value value);
-
-/// Retrieves the DistributeLayoutAttr associated with a given OpOperand. It
-/// will first check the operand_layout_{id} of the owner operation. If not
-/// found, it will check the operand itself and its defining op.
-DistributeLayoutAttr getDistributeLayoutAttr(const OpOperand &opr);
-
-/// [to-be-deprecated] Sets the DistributeLayoutAttr for a given OpResult
-/// user should use setAnchorLayout instead
-void setDistributeLayoutAttr(const OpResult &Result,
- const DistributeLayoutAttr layout);
-
-/// [to-be-deprecated] Sets the DistributeLayoutAttr for a given OpOperand
-/// user should use setAnchorLayout instead
-void setDistributeLayoutAttr(const OpOperand &opr,
- const DistributeLayoutAttr layout);
-
-/// get and set distribute layout attribute for non-anchor operations
-/// (and offsets/masks of load/store ops before we get rid of their temp attrs)
-template <typename T,
- typename = std::enable_if_t<std::is_same_v<T, OpOperand> ||
- std::is_same_v<T, OpResult>>>
-DistributeLayoutAttr getTemporaryLayout(const T &operandOrResult);
-
-template <typename T,
- typename = std::enable_if_t<std::is_same_v<T, OpOperand> ||
- std::is_same_v<T, OpResult>>>
-void setTemporaryLayout(const T &operandOrResult,
- const DistributeLayoutAttr layout);
-
/// [to-be-deprecated] Set the DistributeLayoutAttr for each OpOperand and
/// OpResult of of the given operation. If the operation contains regions, it is
/// also applied recursively to the contained operations operation.
@@ -124,7 +85,7 @@ DistributeLayoutAttr inferBitCastSourceLayout(DistributeLayoutAttr resLayout,
/// Infers the source layout attribute for a shape cast operation given the
/// result layout attribute, result shape, and source shape.
-DistributeLayoutAttr inferShapecastSourceLayout(DistributeLayoutAttr resLayout,
+DistributeLayoutAttr inferShapeCastSourceLayout(DistributeLayoutAttr resLayout,
ArrayRef<int64_t> resShape,
ArrayRef<int64_t> srcShape);
@@ -136,10 +97,10 @@ DistributeLayoutAttr inferShapecastSourceLayout(DistributeLayoutAttr resLayout,
/// consumer's preferred layout. This minimizes data redistribution overhead.
/// The SliceAttr for the result is then created based on the derived source
/// layout and the specified reduction dimensions.
-SliceAttr
-reductionLayoutSetupRule(ArrayRef<int64_t> srcShape,
- SmallVector<int64_t> reductionDims,
- DistributeLayoutAttr consumerPreferredLayout);
+SliceAttr reductionSetupResultLayout(xegpu::LayoutKind layoutKind,
+ ArrayRef<int64_t> srcShape,
+ DistributeLayoutAttr consumerLayout,
+ SmallVector<int64_t> reductionDims);
/// Setup the result layout attribute for a bitcast operation based on element
/// type bitwidths. This ensures the source layout can always be derived from
@@ -147,12 +108,14 @@ reductionLayoutSetupRule(ArrayRef<int64_t> srcShape,
///
/// When casting from a narrower to a wider element type (srcElemTyBitWidth <
/// resElemTyBitWidth), the result layout's innermost dimension data sizes
-/// (sg_data, inst_data, lane_data) are scaled up by the bitwidth ratio. This
+/// (inst_data, lane_data) are scaled up by the bitwidth ratio. This
/// maintains the invariant that the source layout can be recovered by inverse
/// scaling during layout inference.
-DistributeLayoutAttr bitCastLayoutSetupRule(DistributeLayoutAttr resLayout,
- int resElemTyBitWidth,
- int srcElemTyBitWidth);
+DistributeLayoutAttr
+bitCastSetupResultLayout(xegpu::LayoutKind layoutKind,
+ ArrayRef<int64_t> srcShape,
+ DistributeLayoutAttr consumerLayout,
+ int resElemTyBitWidth, int srcElemTyBitWidth);
} // namespace xegpu
diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
index 3dbbe7e4c5dff..0f1ca8e38c873 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
@@ -9,7 +9,6 @@
#ifndef MLIR_DIALECT_XEGPU_UTILS_XEGPUUTILS_H_
#define MLIR_DIALECT_XEGPU_UTILS_XEGPUUTILS_H_
-#include "XeGPULayoutUtils.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
@@ -31,6 +30,8 @@ class TensorDescType;
namespace xegpu {
+enum class LayoutKind { Lane, InstData, Subgroup };
+
/// Flatten a set of ValueRange into a single SmallVector<Value>
SmallVector<Value> flattenValues(ArrayRef<ValueRange> values);
@@ -119,6 +120,47 @@ template <typename T>
int getLargestDivisor(T dim, ArrayRef<T> candidates,
ArrayRef<T> candidateMultiples = {});
+/// [to-be-deprecated] Sets the DistributeLayoutAttr for a given OpResult.
+/// Users should use setAnchorLayout instead.
+void setDistributeLayoutAttr(const OpResult &Result,
+ const DistributeLayoutAttr layout);
+
+/// [to-be-deprecated] Sets the DistributeLayoutAttr for a given OpOperand.
+/// Users should use setAnchorLayout instead.
+void setDistributeLayoutAttr(const OpOperand &opr,
+ const DistributeLayoutAttr layout);
+
+/// Retrieves the DistributeLayoutAttr associated with a given Value. For
+/// TensorDescType values, the DistributeLayoutAttr is extracted from the
+/// TensorDescType itself. For other values, it is obtained from the attributes
+/// of the defining operation. Returns nullptr if no DistributeLayoutAttr is
+/// found.
+DistributeLayoutAttr getDistributeLayoutAttr(const Value value);
+
+/// Retrieves the DistributeLayoutAttr associated with a given OpOperand. It
+/// will first check the operand_layout_{id} of the owner operation. If not
+/// found, it will check the operand itself and its defining op.
+DistributeLayoutAttr getDistributeLayoutAttr(const OpOperand &opr);
+
+/// Return the attribute name for the OpOperand to attach DistributeLayoutAttr
+std::string getTemporaryLayoutName(const OpOperand &operand);
+
+/// Return the attribute name for the OpResult to attach DistributeLayoutAttr
+std::string getTemporaryLayoutName(const OpResult result);
+
+/// get and set distribute layout attribute for non-anchor operations
+/// (and offsets/masks of load/store ops before we get rid of their temp attrs)
+template <typename T,
+ typename = std::enable_if_t<std::is_same_v<T, OpOperand> ||
+ std::is_same_v<T, OpResult>>>
+DistributeLayoutAttr getTemporaryLayout(const T &operandOrResult);
+
+template <typename T,
+ typename = std::enable_if_t<std::is_same_v<T, OpOperand> ||
+ std::is_same_v<T, OpResult>>>
+void setTemporaryLayout(const T &operandOrResult,
+ const DistributeLayoutAttr layout);
+
} // namespace xegpu
} // namespace mlir
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
index 931834ba16d9a..a50c955ea83b5 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
@@ -12,6 +12,7 @@
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
+#include "mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Pass/PassManager.h"
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPeepHoleOptimizer.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPeepHoleOptimizer.cpp
index 6a3e533fb2df4..9a8925d357d25 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPeepHoleOptimizer.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPeepHoleOptimizer.cpp
@@ -16,6 +16,7 @@
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
+#include "mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h"
#include "mlir/Dialect/XeGPU/uArch/uArchBase.h"
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 36fe26f15049f..8f49647b153a4 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -15,6 +15,7 @@
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
+#include "mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
@@ -53,8 +54,6 @@ using namespace mlir::dataflow;
namespace {
-enum class LayoutKind { Lane, InstData, Subgroup };
-
//===----------------------------------------------------------------------===//
// LayoutInfo
//===----------------------------------------------------------------------===//
@@ -370,7 +369,7 @@ getSIMTLayoutInfoForDPASOperand(VectorType vectorTy, unsigned operandNum,
class LayoutInfoPropagation
: public SparseBackwardDataFlowAnalysis<LayoutInfoLattice> {
private:
- LayoutKind layoutKind;
+ xegpu::LayoutKind layoutKind;
void visitDpasOp(xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results);
@@ -426,7 +425,7 @@ class LayoutInfoPropagation
public:
LayoutInfoPropagation(DataFlowSolver &solver,
SymbolTableCollection &symbolTable,
- LayoutKind layoutKind)
+ xegpu::LayoutKind layoutKind)
: SparseBackwardDataFlowAnalysis(solver, symbolTable),
layoutKind(layoutKind) {}
using SparseBackwardDataFlowAnalysis::SparseBackwardDataFlowAnalysis;
@@ -520,12 +519,12 @@ bool LayoutInfoPropagation::hasParamsOfLayoutKind(
if (anchorLayout == nullptr) {
return false;
}
- if (layoutKind == LayoutKind::InstData) {
+ if (layoutKind == xegpu::LayoutKind::InstData) {
return !(anchorLayout.getEffectiveInstDataAsInt().empty());
- } else if (layoutKind == LayoutKind::Lane) {
+ } else if (layoutKind == xegpu::LayoutKind::Lane) {
return !(anchorLayout.getEffectiveLaneLayoutAsInt().empty() ||
anchorLayout.getEffectiveLaneDataAsInt().empty());
- } else if (layoutKind == LayoutKind::Subgroup) {
+ } else if (layoutKind == xegpu::LayoutKind::Subgroup) {
return !(anchorLayout.getEffectiveSgLayoutAsInt().empty() ||
anchorLayout.getEffectiveSgDataAsInt().empty());
}
@@ -573,7 +572,7 @@ void LayoutInfoPropagation::visitPrefetchNdOp(
instData = {instHeight, instWidth};
}
- if (layoutKind == LayoutKind::InstData)
+ if (layoutKind == xegpu::LayoutKind::InstData)
prefetchLayout =
LayoutInfo(xegpu::LayoutAttr::get(tdescTy.getContext(), instData));
else
@@ -592,8 +591,8 @@ void LayoutInfoPropagation::visitVectorMultiReductionOp(
ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results) {
// The layout of the result must be present.
- LayoutInfo resultLayout = results[0]->getValue();
- if (!resultLayout.isAssigned())
+ LayoutInfo resLayoutInfo = results[0]->getValue();
+ if (!resLayoutInfo.isAssigned())
return;
VectorType sourceTy =
@@ -614,54 +613,42 @@ void LayoutInfoPropagation::visitVectorMultiReductionOp(
<< dim << " ";
llvm::dbgs() << "]\n");
- auto resultLayoutAttr =
- dyn_cast<xegpu::DistributeLayoutAttr>(resultLayout.get());
-
- LLVM_DEBUG(DBGS() << "visitVectorMultiReductionOp: resultLayoutAttr = "
- << resultLayoutAttr << "\n");
+ auto consumerLayoutAttr =
+ dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
+ LLVM_DEBUG(DBGS() << "visitVectorMultiReductionOp: consumerLayoutAttr = "
+ << consumerLayoutAttr << "\n");
// A dominant layout is for the result and represents the layout requirements
// for the operation. It is recorded to the anchor layout or temporary layout.
// It must be honored for the current op and may conflict with the layout
// propagated from the consumer op; the conflict is resolved in a later phase
// by converting the dominant layout to the source layout.
- xegpu::DistributeLayoutAttr dominantLayout = xegpu::reductionLayoutSetupRule(
- srcShape, reductionDims, resultLayoutAttr);
+ auto requiredResLayoutAttr = xegpu::reductionSetupResultLayout(
+ layoutKind, srcShape, consumerLayoutAttr, reductionDims);
+ LLVM_DEBUG(DBGS() << "visitVectorMultiReductionOp: requiredResLayoutAttr = "
+ << requiredResLayoutAttr << "\n");
- LLVM_DEBUG(DBGS() << "visitVectorMultiReductionOp: dominantLayout = "
- << dominantLayout << "\n");
-
- if (layoutKind == LayoutKind::Lane) {
- // only lane layout/data is considered
- dominantLayout = dominantLayout.dropInstData();
- dominantLayout = dominantLayout.dropSgLayoutAndData();
- } else if (layoutKind == LayoutKind::InstData) {
- dominantLayout = dominantLayout.dropSgLayoutAndData();
- }
-
- // record the dominant layout to the reduction op
- xegpu::setTemporaryLayout(reduction->getResult(0), dominantLayout);
+ resLayoutInfo.set(requiredResLayoutAttr);
// derive the source layout from the dominant layout and reduction dims
auto srcLayoutAttr =
- xegpu::inferReductionSourceLayout(dominantLayout, reductionDims);
- // void set(const xegpu::DistributeLayoutAttr &layout) { storage = layout; }
+ xegpu::inferReductionSourceLayout(requiredResLayoutAttr, reductionDims);
LLVM_DEBUG(DBGS() << "visitVectorMultiReductionOp: srcLayoutAttr = "
<< srcLayoutAttr << "\n");
propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
// Accumulator should have the same layout as the result.
- propagateIfChanged(operands[1], operands[1]->meet(resultLayout));
+ propagateIfChanged(operands[1], operands[1]->meet(resLayoutInfo));
}
void LayoutInfoPropagation::visitVectorBroadCastOp(
vector::BroadcastOp broadcast, ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results) {
// The layout of the result must be present.
- LayoutInfo resultLayout = results[0]->getValue();
- if (!resultLayout.isAssigned())
+ LayoutInfo resLayoutInfo = results[0]->getValue();
+ if (!resLayoutInfo.isAssigned())
return;
// Only consider vector to vector broadcasts for now.
@@ -681,12 +668,12 @@ void LayoutInfoPropagation::visitVectorBroadCastOp(
"with unit-dim, mixed scenario is not supported!");
auto resultLayoutAttr =
- dyn_cast<xegpu::DistributeLayoutAttr>(resultLayout.get());
+ dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
- xegpu::DistributeLayoutAttr srcLayout =
+ xegpu::DistributeLayoutAttr srcLayoutAttr =
xegpu::inferBroadcastSourceLayout(resultLayoutAttr, resShape, srcShape);
- propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayout)));
+ propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
return;
}
@@ -694,23 +681,18 @@ void LayoutInfoPropagation::visitShapeCastOp(
vector::ShapeCastOp shapeCast, ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results) {
// The layout of the result must be present.
- LayoutInfo resultLayout = results[0]->getValue();
- if (!resultLayout.isAssigned())
- return;
- VectorType sourceTy = shapeCast.getSourceVectorType();
- VectorType resultTy = shapeCast.getResultVectorType();
- // Shape cast layout propagation only supports 1D -> 2D shape casts.
- // TODO: Support kD -> nD shape casts (k < n, n >= 2) where expanded dims are
- // unit dimensions and non-unit dims match.
- if (sourceTy.getRank() != 1 || resultTy.getRank() != 2) {
- shapeCast.emitWarning("Expecting shape cast to be 1D -> 2D.");
+ LayoutInfo resLayoutInfo = results[0]->getValue();
+ if (!resLayoutInfo.isAssigned())
return;
- }
- int64_t slicedDim = resultTy.getShape()[0] == 1 ? 0 : 1;
- xegpu::SliceAttr sliceLayout = xegpu::SliceAttr::get(
- shapeCast->getContext(), cast<xegpu::LayoutAttr>(resultLayout.get()),
- DenseI64ArrayAttr::get(shapeCast->getContext(), {slicedDim}));
- propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(sliceLayout)));
+ ArrayRef<int64_t> resShape = shapeCast.getResultVectorType().getShape();
+ ArrayRef<int64_t> srcShape = shapeCast.getSourceVectorType().getShape();
+ auto resultLayoutAttr =
+ dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
+
+ xegpu::DistributeLayoutAttr srcLayoutAttr =
+ xegpu::inferShapeCastSourceLayout(resultLayoutAttr, resShape, srcShape);
+
+ propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
}
/// Propagate the layout of the result tensor to the source tensor descriptor
@@ -779,7 +761,7 @@ void LayoutInfoPropagation::visitDpasOp(
SmallVector<int> instDataA = {maxALen, subgroupSize};
SmallVector<int> instDataB = {subgroupSize, maxBLen};
- if (layoutKind == LayoutKind::InstData) {
+ if (layoutKind == xegpu::LayoutKind::InstData) {
dpasALayout =
LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataA));
dpasBLayout =
@@ -793,7 +775,7 @@ void LayoutInfoPropagation::visitDpasOp(
if (operands.size() > 2) {
VectorType cTy = dpas.getAccType();
- if (layoutKind == LayoutKind::InstData) {
+ if (layoutKind == xegpu::LayoutKind::InstData) {
const unsigned dataCLen = bTy.getShape().back();
auto supportedCLen =
uArchInstruction->getSupportedN(bTy.getElementType());
@@ -863,7 +845,7 @@ void LayoutInfoPropagation::visitStoreNdOp(
instData = {instHeight, instWidth};
}
- if (layoutKind == LayoutKind::InstData)
+ if (layoutKind == xegpu::LayoutKind::InstData)
storeLayout =
LayoutInfo(xegpu::LayoutAttr::get(dataTy.getContext(), instData));
else
@@ -930,66 +912,29 @@ void LayoutInfoPropagation::visitVectorBitcastOp(
vector::BitCastOp bitcast, ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results) {
// Need the layout of bitcast result to propagate to the operands.
- LayoutInfo resultLayout = results[0]->getValue();
- if (!resultLayout.isAssigned())
+ LayoutInfo resLayoutInfo = results[0]->getValue();
+ if (!resLayoutInfo.isAssigned())
return;
int inElemTyBitWidth =
bitcast.getSourceVectorType().getElementType().getIntOrFloatBitWidth();
int outElemTyBitWidth =
bitcast.getResultVectorType().getElementType().getIntOrFloatBitWidth();
- // If the element bit widths are the same, then the layout does not change.
- if (inElemTyBitWidth == outElemTyBitWidth) {
- propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
- return;
- }
- // Check if the result layout is valid. i.e. result vector can be distributed.
- auto resultLaneLayout = resultLayout.getLaneLayout();
- auto resultLaneData = resultLayout.getLaneData();
- if (failed(xegpu::getDistributedVectorType(
- bitcast.getResultVectorType(),
- xegpu::LayoutAttr::get(bitcast->getContext(), resultLaneLayout,
- resultLaneData)))) {
- bitcast.emitWarning(
- "Result vector type can not be evenly distributed across lanes.");
- return;
- }
- int64_t rank = bitcast.getSourceVectorType().getRank();
- // Bitcast is a `narrowing` if the input element type bit width larger than
- // the output element type bit width. eg. f32 -> f16 is a narrowing bitcast.
- bool isNarrowing = inElemTyBitWidth > outElemTyBitWidth;
- int bitCastRatio = isNarrowing ? inElemTyBitWidth / outElemTyBitWidth
- : outElemTyBitWidth / inElemTyBitWidth;
- SmallVector<int> sourceLaneLayout =
- resultLayout.getLaneLayout(); // Lane layout does not change for bitcast.
- SmallVector<int> outData = resultLayout.getLaneData();
-
- // TODO: Currently we assume that bitcasts does not require cross lane
- // communication. So each lane must own the required number of elements to
- // perform the bitcast locally without cross-lane communication.
- int outInnerBitsPerLane = outData[rank - 1] * outElemTyBitWidth;
- if (outInnerBitsPerLane < inElemTyBitWidth) {
- bitcast.emitWarning(
- "Narrowing bitcast with cross lane communication is not supported.");
- return;
- }
- // Check if each lane owns a single element in all dimensions except the
- // innermost dimension.
- SmallVector<int> sourceLaneData(outData.begin(), outData.end() - 1);
- if (llvm::any_of(sourceLaneData, [](int64_t d) { return d != 1; })) {
- bitcast.emitWarning("Each lane must not own multiple elements in any "
- "dimension other than "
- "the innermost dimension.");
- return;
- }
- // Decide lane data based on whether the bitcast is narrowing or widening.
- int64_t innerMostLaneData = isNarrowing ? outData[rank - 1] / bitCastRatio
- : outData[rank - 1] * bitCastRatio;
- sourceLaneData.push_back(innerMostLaneData);
-
- propagateIfChanged(
- operands[0],
- operands[0]->meet(LayoutInfo(xegpu::LayoutAttr::get(
- bitcast->getContext(), sourceLaneLayout, sourceLaneData))));
+
+ ArrayRef<int64_t> srcShape = bitcast.getSourceVectorType().getShape();
+ auto consumerLayoutAttr =
+ dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
+
+ auto requiredResLayoutAttr =
+ bitCastSetupResultLayout(layoutKind, srcShape, consumerLayoutAttr,
+ outElemTyBitWidth, inElemTyBitWidth);
+
+ resLayoutInfo.set(requiredResLayoutAttr);
+
+ // derive the source layout from the required result layout and bit widths
+ auto srcLayoutAttr = xegpu::inferBitCastSourceLayout(
+ requiredResLayoutAttr, outElemTyBitWidth, inElemTyBitWidth);
+
+ propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
}
/// Propagate the layout of the result to the tensor descriptor, mask and offset
@@ -1023,7 +968,7 @@ void LayoutInfoPropagation::visitLoadGatherOp(
instData.push_back(chunkSize);
}
- if (layoutKind == LayoutKind::InstData)
+ if (layoutKind == xegpu::LayoutKind::InstData)
loadLayout =
LayoutInfo(xegpu::LayoutAttr::get(load.getContext(), instData));
else
@@ -1085,7 +1030,7 @@ void LayoutInfoPropagation::visitStoreScatterOp(
auto uArch = getUArch(getChipStr(storeScatter).value_or(""));
const int subgroupSize = uArch->getSubgroupSize();
- if (layoutKind == LayoutKind::InstData) {
+ if (layoutKind == xegpu::LayoutKind::InstData) {
SmallVector<int> instData{subgroupSize};
if (auto chunkSize = storeScatter.getChunkSize().value_or(0);
chunkSize > 1)
@@ -1135,7 +1080,8 @@ class RunLayoutInfoPropagation {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(RunLayoutInfoPropagation)
- RunLayoutInfoPropagation(Operation *op, LayoutKind layoutKind) : target(op) {
+ RunLayoutInfoPropagation(Operation *op, xegpu::LayoutKind layoutKind)
+ : target(op) {
SymbolTableCollection symbolTable;
loadBaselineAnalyses(solver);
solver.load<LayoutInfoPropagation>(symbolTable, layoutKind);
@@ -1373,13 +1319,13 @@ struct XeGPUPropagateLayoutPass final
} // namespace
void XeGPUPropagateLayoutPass::runOnOperation() {
- LayoutKind layoutKind;
+ xegpu::LayoutKind layoutKind;
if (this->layoutKind == "lane") {
- layoutKind = LayoutKind::Lane;
+ layoutKind = xegpu::LayoutKind::Lane;
} else if (this->layoutKind == "inst") {
- layoutKind = LayoutKind::InstData;
+ layoutKind = xegpu::LayoutKind::InstData;
} else if (this->layoutKind == "subgroup") {
- layoutKind = LayoutKind::Subgroup;
+ layoutKind = xegpu::LayoutKind::Subgroup;
} else {
getOperation()->emitError("Unsupported layout kind option: " +
this->layoutKind);
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 9113f00ac39f0..8898be5f13dab 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -14,6 +14,7 @@
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
+#include "mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h"
#include "mlir/IR/AffineMap.h"
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index 8f4e2bb0451d8..6bafdc955c9d3 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -15,6 +15,7 @@
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
+#include "mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/DebugLog.h"
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 45a002b63abd6..72c8ea3de9976 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -19,6 +19,7 @@
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
+#include "mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/Transforms/DialectConversion.h"
#include <optional>
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp
index fb416b4c81f15..c21e22c966f3d 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp
@@ -10,13 +10,13 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
-#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/ValueRange.h"
@@ -278,235 +278,26 @@ xegpu::dropInstDataOnAttrs(ArrayRef<NamedAttribute> attrs) {
return out;
}
-/// Attach layout attributes to all vector-type operands of operations within
-/// the given operation's region. Reports an error if any vector operand lacks
-/// a layout attribute.
-// bool xegpu::recoverTemporaryLayouts(Operation *rootOp) {
-// auto result = rootOp->walk([&](Operation *op) {
-// for (OpOperand &operand : op->getOpOperands()) {
-// // Layouts are needed for vector type only.
-// if (!isa<VectorType>(operand.get().getType()))
-// continue;
-// auto layout = xegpu::getDistributeLayoutAttr(operand.get());
-// if (!layout) {
-// op->emitError("Could not find layout attribute for operand ")
-// << operand.getOperandNumber() << " of operation " <<
-// op->getName();
-// return WalkResult::interrupt();
-// }
-// xegpu::setDistributeLayoutAttr(operand, layout);
-// }
-// return WalkResult::advance();
-// });
-// return !result.wasInterrupted();
-// }
-
-// Prerequisite for Layout Recovery
-// It relies on the following invariant:
-// 1. there is no layout conflict between different uses of the same definition.
-// 2. each definition has a well-defined layout requirement at its use point.
-// - Every definition must have at least one use that appears after it in
-// topological order.
-// - If a definition has no such use (e.g., a loop result or region output),
-// an explicit convert_layout operation is inserted to create a use.
-// - Only the result of convert_layout is permitted to have no subsequent
-// use.
-
-// The recover proceeds by scanning the operation in reverse topological order
-// as follows:
-// For regular operations: First the result layouts are propagated from uses.
-// Then the result layouts are propagated to uses (operands).
-//
-// For region operations (e.g., loops):
-// - When backward propagation reaches a region op, it sets the layout of
-// the region op’s results according to use points like regular ops.
-// - Then, the result layouts (such as a loop output) are propagated to
-// thiers corresponding operands in the yield.
-// - When backward propagation reaches the first operation inside the
-// region, the pass examines the region op’s initialization list,
-// propagating from region arguments to the corresponding initialization
-// operands.
-// - This ensures that layout constraints are consistently propagated
-// across region boundaries
-// while preserving a single well-defined use for each definition at the
-// region-op level.
-
-// Forward declarations
-static void walkRegionBackward(Region &region,
- llvm::function_ref<void(Operation *)> visit);
-static void propagateResultsToRegularOperands(Operation *op);
-static void propagateRegionResultsToYieldOperands(
- mlir::RegionBranchTerminatorOpInterface yieldOp);
-static void propagateRegionArgsToInits(mlir::RegionBranchOpInterface *regionOp);
-
-// the inner function for recoverTemporaryLayouts is a recursive function
-// the input rootOp is the function operation, which is also a region op.
-// it recursivley process the region op in reverse topological order.
+// Attach layout attributes to all vector-type operands of operations within
+// the given operation's region. Reports an error if any vector operand lacks
+// a layout attribute.
bool xegpu::recoverTemporaryLayouts(Operation *rootOp) {
- rootOp->walk([&](func::FuncOp func) {
- walkRegionBackward(func.getBody(), [&](Operation *op) {
- if (auto regionOp = dyn_cast<mlir::RegionBranchOpInterface>(op)) {
- // hit the region op after visiting inside region
- propagateRegionArgsToInits(&regionOp);
- } else if (auto yieldOp =
- dyn_cast<mlir::RegionBranchTerminatorOpInterface>(op)) {
- // yield op inside region op
- propagateRegionResultsToYieldOperands(yieldOp);
- } else {
- // if the op is regular op, calling propagateResultsToRegularOperands
- propagateResultsToRegularOperands(op);
- }
- });
- });
-}
-
-static void walkRegionBackward(Region &region,
- llvm::function_ref<void(Operation *)> visit) {
- // blocks: back -> front
- for (Block &block : llvm::reverse(region)) {
- // ops: back -> front, early-inc so visit() may erase current op safely
- for (Operation &op : llvm::reverse(block)) {
- // make sure we first visit inside the region op (so yield op first)
- // and then move to region op itself
- for (Region &nested : llvm::reverse(op.getRegions()))
- walkRegionBackward(nested, visit);
-
- visit(&op);
- }
- }
-}
-
-// For regular operations: First the result layouts are propagated from uses.
-// Then the result layouts are propagated to uses (operands).
-static void propagateResultsToRegularOperands(Operation *op) {
- OpResult result = op->getOpResults()[0];
- auto resLayout = xegpu::getDistributeLayoutAttr(result);
- assert(resLayout &&
- "result layout must be defined before propagating to uses");
-
- // if op is reduction op, call inferReductionSourceLayout
- if (auto reduceOp = dyn_cast<vector::MultiDimReductionOp>(op)) {
- SmallVector<int64_t> reduceDims(reduceOp.getReductionDims().begin(),
- reduceOp.getReductionDims().end());
- auto srcLayout = xegpu::inferReductionSourceLayout(resLayout, reduceDims);
- // set the layout to the operand
- xegpu::setTemporaryLayout(reduceOp->getOpOperand(0), srcLayout);
- xegpu::setTemporaryLayout(reduceOp->getOpOperand(1), resLayout);
- return;
- }
-
- // if op is broadcast op, call inferBroadcastSourceLayout
- if (auto broadcastOp = dyn_cast<vector::BroadcastOp>(op)) {
- ArrayRef<int64_t> resShape =
- llvm::cast<VectorType>(broadcastOp.getResult().getType()).getShape();
- ArrayRef<int64_t> srcShape =
- llvm::cast<VectorType>(broadcastOp.getSource().getType()).getShape();
- auto srcLayout =
- xegpu::inferBroadcastSourceLayout(resLayout, resShape, srcShape);
- // set the layout to the operand
- xegpu::setTemporaryLayout(broadcastOp->getOpOperand(0), srcLayout);
- return;
- }
-
- // if op is bitcast op, call inferBitCastSourceLayout
- if (auto bitcastOp = dyn_cast<vector::BitCastOp>(op)) {
- int resElemTyBitWidth =
- llvm::cast<VectorType>(bitcastOp.getResult().getType())
- .getElementTypeBitWidth();
- int srcElemTyBitWidth =
- llvm::cast<VectorType>(bitcastOp.getSource().getType())
- .getElementTypeBitWidth();
- auto srcLayout = xegpu::inferBitCastSourceLayout(
- resLayout, resElemTyBitWidth, srcElemTyBitWidth);
- // set the layout to the operand
- xegpu::setTemporaryLayout(bitcastOp->getOpOperand(0), srcLayout);
- return;
- }
-
- // if op is shape_cast op, call inferShapecastSourceLayout
- if (auto shapeCastOp = dyn_cast<vector::ShapeCastOp>(op)) {
- ArrayRef<int64_t> resShape =
- llvm::cast<VectorType>(shapeCastOp.getResult().getType()).getShape();
- ArrayRef<int64_t> srcShape =
- llvm::cast<VectorType>(shapeCastOp.getSource().getType()).getShape();
- auto srcLayout =
- xegpu::inferShapecastSourceLayout(resLayout, resShape, srcShape);
- // set the layout to the operand
- xegpu::setTemporaryLayout(shapeCastOp->getOpOperand(0), srcLayout);
- return;
- }
-
- // if op is a anchor op, no need to do anything
- if (isa<xegpu::AnchorLayoutInterface>(op)) {
- return;
- }
-
- // for other regular ops, propagate the result layout to all vector operands
- for (OpOperand &opr : op->getOpOperands()) {
- // Layouts are needed for vector type only.
- if (!isa<VectorType>(opr.get().getType()))
- continue;
- xegpu::setTemporaryLayout(opr, resLayout);
- }
-}
-
-static void propagateRegionResultsToYieldOperands(
- mlir::RegionBranchTerminatorOpInterface yieldOp) {
- llvm::SmallVector<mlir::RegionSuccessor> successors;
- llvm::SmallVector<mlir::Attribute> operands(yieldOp->getNumOperands(),
- nullptr);
- yieldOp.getSuccessorRegions(operands, successors);
-
- for (mlir::RegionSuccessor &successor : successors) {
- // find out the successor which is the parent operation
- if (!successor.isParent())
- continue;
- // For parent successor, the region arguments of the current region
- // correspond to the results of the parent operation
- Operation *parentOp = yieldOp->getParentOp();
- for (unsigned i = 0; i < yieldOp->getNumOperands(); ++i) {
- Value parentResult = parentOp->getResult(i);
- auto layout = xegpu::getDistributeLayoutAttr(parentResult);
- assert(
- layout &&
- "region result layout must be defined before propagating to yield");
- xegpu::setTemporaryLayout(yieldOp->getOpOperand(i), layout);
- }
- }
-}
-
-void propagateRegionArgsToInits(mlir::RegionBranchOpInterface *regionOp) {
-
- // Get entry successors (regions that can be entered initially)
- SmallVector<RegionSuccessor> successors;
- regionOp->getEntrySuccessorRegions(/*operands=*/ArrayRef<Attribute>(),
- successors);
-
- // For each possible entry region, get the operands forwarded to it
- for (RegionSuccessor &successor : successors) {
- if (successor.isParent())
- continue;
- OperandRange initOperands = regionOp->getEntrySuccessorOperands(successor);
- // initOperands are the initialization arguments for this successor
- // iterate the region arguments
- Region *successorRegion = successor.getSuccessor();
- for (unsigned i = 0; i < successorRegion->getNumArguments(); ++i) {
- Value regionArg = successorRegion->getArgument(i);
- auto layout = xegpu::getDistributeLayoutAttr(regionArg);
- if (!layout)
+ auto result = rootOp->walk([&](Operation *op) {
+ for (OpOperand &operand : op->getOpOperands()) {
+ // Layouts are needed for vector type only.
+ if (!isa<VectorType>(operand.get().getType()))
continue;
- // The initOperands are a subset of the parent operation's operands
- // We need to find which operand index this corresponds to
- Value initOperand = initOperands[i];
- // Find the operand index in the parent operation
- for (OpOperand &operand : (*regionOp)->getOpOperands()) {
- if (operand.get() == initOperand) {
- xegpu::setTemporaryLayout(operand, layout);
- break;
- }
+ auto layout = xegpu::getDistributeLayoutAttr(operand.get());
+ if (!layout) {
+ op->emitError("Could not find layout attribute for operand ")
+ << operand.getOperandNumber() << " of operation " << op->getName();
+ return WalkResult::interrupt();
}
+ xegpu::setDistributeLayoutAttr(operand, layout);
}
- }
+ return WalkResult::advance();
+ });
+ return !result.wasInterrupted();
}
template <typename T, typename>
@@ -633,7 +424,7 @@ xegpu::inferBitCastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
/// Infers the source layout attribute for a shape cast operation given the
/// result layout attribute, result shape, and source shape.
xegpu::DistributeLayoutAttr
-xegpu::inferShapecastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
+xegpu::inferShapeCastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
ArrayRef<int64_t> resShape,
ArrayRef<int64_t> srcShape) {
@@ -760,10 +551,9 @@ xegpu::inferShapecastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
/// reduced result stays with the same subgroup distribution as expected by
/// the consumer.
-xegpu::SliceAttr
-xegpu::reductionLayoutSetupRule(ArrayRef<int64_t> srcShape,
- SmallVector<int64_t> reductionDims,
- DistributeLayoutAttr consumerLayout) {
+xegpu::SliceAttr xegpu::reductionSetupResultLayout(
+ xegpu::LayoutKind layoutKind, ArrayRef<int64_t> srcShape,
+ DistributeLayoutAttr consumerLayout, SmallVector<int64_t> reductionDims) {
xegpu::SliceAttr consumerSliceLayout =
dyn_cast<xegpu::SliceAttr>(consumerLayout);
@@ -782,15 +572,18 @@ xegpu::reductionLayoutSetupRule(ArrayRef<int64_t> srcShape,
SmallVector<int64_t> sgLayout(srcShapeSize);
SmallVector<int64_t> sgData(srcShapeSize);
+ SmallVector<int64_t> instData(srcShapeSize);
+ SmallVector<int64_t> laneLayout(srcShapeSize);
+ SmallVector<int64_t> laneData(srcShapeSize);
- SmallVector<int64_t> instData(srcShapeSize, 1);
- instData[srcShapeSize - 1] = subgroupSize;
- instData[srcShapeSize - 2] =
+ SmallVector<int64_t> defaultInstData(srcShapeSize, 1);
+ defaultInstData[srcShapeSize - 1] = subgroupSize;
+ defaultInstData[srcShapeSize - 2] =
vectorSize; // This will be adjusted based on actual data distribution
- SmallVector<int64_t> laneLayout(srcShapeSize, 1);
- laneLayout[srcShapeSize - 1] = subgroupSize;
- SmallVector<int64_t> laneData(srcShapeSize, 1);
+ SmallVector<int64_t> defaultLaneLayout(srcShapeSize, 1);
+ defaultLaneLayout[srcShapeSize - 1] = subgroupSize;
+ SmallVector<int64_t> defaultLaneData(srcShapeSize, 1);
// Strategy 1: Try to preserve the consumer's slice layout structure
// If the consumer already expects a slice layout with the same reduction
@@ -811,7 +604,7 @@ xegpu::reductionLayoutSetupRule(ArrayRef<int64_t> srcShape,
if (!consumerSliceLayout.getDims().asArrayRef().equals(reductionDims))
return false;
xegpu::DistributeLayoutAttr parentLayout = consumerSliceLayout.getParent();
- if (!parentLayout.getRank() == srcShapeSize)
+ if (parentLayout.getRank() != srcShapeSize)
return false;
SmallVector<int64_t> parentSgLayout =
@@ -842,54 +635,103 @@ xegpu::reductionLayoutSetupRule(ArrayRef<int64_t> srcShape,
SmallVector<int64_t> parentLaneData =
consumerSliceLayout.getEffectiveLaneDataAsInt();
- for (int i = 0; i < srcShapeSize; i++) {
- sgLayout[i] = parentSgLayout[i];
- sgData[i] = srcShape[i] / sgLayout[i];
- instData[i] = std::min(instData[i], sgData[i]);
- laneLayout[i] = parentLaneLayout[i];
- laneData[i] = parentLaneData[i];
+ switch (layoutKind) {
+ case xegpu::LayoutKind::Subgroup:
+ sgLayout = parentSgLayout;
+ for (int i = 0; i < srcShapeSize; i++)
+ if (llvm::is_contained(reductionDims, i))
+ sgData[i] = srcShape[i] / sgLayout[i];
+ else
+ sgData[i] = parentSgLayout[i];
+ break;
+ case xegpu::LayoutKind::InstData:
+ for (int i = 0; i < srcShapeSize; i++)
+ instData[i] = std::min(defaultInstData[i], srcShape[i]);
+ break;
+ case xegpu::LayoutKind::Lane:
+ laneLayout = parentLaneLayout;
+ for (int i = 0; i < srcShapeSize; i++) {
+ assert((srcShape[i] % laneLayout[i] == 0) &&
+ "source shape not divisible by lane layout");
+ laneData[i] = srcShape[i] / laneLayout[i];
+ }
+ break;
+ default:
+ llvm_unreachable("unsupported layout kind");
}
-
} else {
// Strategy 2: Construct a new layout aligned with consumer's sg_layout for
// the result (non-reduction dims) then distribute remaining subgroups
// across reduced dimensions
- SmallVector<int64_t> cplSgLayout =
+ SmallVector<int64_t> consumerSgLayout =
consumerLayout.getEffectiveSgLayoutAsInt();
- SmallVector<int64_t> cplSgData = consumerLayout.getEffectiveSgDataAsInt();
- SmallVector<int64_t> cplInstData =
- consumerLayout.getEffectiveInstDataAsInt();
+ SmallVector<int64_t> consumerLaneLayout =
+ consumerLayout.getEffectiveLaneLayoutAsInt();
int remainingSgCount = workgroupSize;
- SmallVector<int64_t> remainingDims;
- int cplId = cplSgLayout.size() - 1;
- // For non-reduction dimensions, try to match consumer's sg_layout
- // This ensures the result after reduction has the expected distribution
- for (int i = srcShapeSize - 1; i >= 0; i--) {
- if (!llvm::is_contained(reductionDims, i) && cplId >= 0) {
- assert((srcShape[i] % cplSgLayout[cplId] == 0) &&
- "source shape not divisible by consumer sg_layout");
- sgLayout[i] = cplSgLayout[cplId];
- sgData[i] = srcShape[i] / sgLayout[i];
- instData[i] = std::min(cplInstData[cplId], sgData[i]);
- remainingSgCount /= sgLayout[i];
- cplId--;
- }
- }
-
- // Second pass: Distribute remaining subgroups across unhandled dimensions
- // This handles reduction dimensions and dimensions that didn't align with
- // consumer
- for (int i = srcShapeSize - 1; i >= 0; i--) {
- if (llvm::is_contained(reductionDims, i)) {
- sgLayout[i] = std::min((srcShape[i] / laneLayout[i]),
- static_cast<int64_t>(remainingSgCount));
- sgData[i] = srcShape[i] / sgLayout[i];
- instData[i] = std::min(instData[i], sgData[i]);
- remainingSgCount /= sgLayout[i];
- }
+ int remainingLaneCount = subgroupSize;
+ int consumerSgId, consumerLaneId;
+
+ switch (layoutKind) {
+ case xegpu::LayoutKind::Subgroup:
+ consumerSgId = consumerSgLayout.size() - 1;
+ // For non-reduction dimensions, try to match consumer's sg_layout
+ // This ensures the result after reduction has the expected distribution
+ for (int i = srcShapeSize - 1; i >= 0; i--)
+ if (!llvm::is_contained(reductionDims, i) && consumerSgId >= 0) {
+ sgLayout[i] = consumerSgLayout[consumerSgId];
+ assert((srcShape[i] % sgLayout[i] == 0) &&
+ "source shape not divisible by consumer sg_layout");
+ sgData[i] = srcShape[i] / sgLayout[i];
+ remainingSgCount /= sgLayout[i];
+ consumerSgId--;
+ }
+ // Second pass: Distribute remaining subgroups across unhandled dimensions
+ // This handles reduction dimensions that don't necessarily align with
+ // consumer
+ for (int i = srcShapeSize - 1; i >= 0; i--)
+ if (llvm::is_contained(reductionDims, i)) {
+ sgLayout[i] = std::min((srcShape[i] / defaultLaneLayout[i]),
+ static_cast<int64_t>(remainingSgCount));
+ assert((srcShape[i] % sgLayout[i] == 0) &&
+ "source shape not divisible by consumer sg_layout");
+ sgData[i] = srcShape[i] / sgLayout[i];
+ remainingSgCount /= sgLayout[i];
+ }
+ assert(remainingSgCount == 1 &&
+ "not all subgroups have been distributed");
+ break;
+ case xegpu::LayoutKind::InstData:
+ for (int i = 0; i < srcShapeSize; i++)
+ instData[i] = std::min(defaultInstData[i], srcShape[i]);
+ break;
+ case xegpu::LayoutKind::Lane:
+ consumerLaneId = consumerLaneLayout.size() - 1;
+ // For non-reduction dimensions, try to match consumer's lane_layout
+ // This ensures the result after reduction has the expected distribution
+ for (int i = 0; i < srcShapeSize; i++)
+ if (!llvm::is_contained(reductionDims, i) && consumerLaneId >= 0) {
+ laneLayout[i] = consumerLaneLayout[consumerLaneId];
+ assert((srcShape[i] % laneLayout[i] == 0) &&
+ "source shape not divisible by consumer lane_layout");
+ laneData[i] = srcShape[i] / laneLayout[i];
+ remainingLaneCount /= laneLayout[i];
+ consumerLaneId--;
+ }
+ for (int i = 0; i < srcShapeSize; i++)
+ if (llvm::is_contained(reductionDims, i)) {
+ laneLayout[i] = std::min((srcShape[i] / defaultLaneLayout[i]),
+ static_cast<int64_t>(remainingLaneCount));
+ assert((srcShape[i] % laneLayout[i] == 0) &&
+ "source shape not divisible by consumer lane_layout");
+ laneData[i] = srcShape[i] / laneLayout[i];
+ remainingLaneCount /= laneLayout[i];
+ }
+ assert(remainingLaneCount == 1 && "not all lanes have been distributed");
+ break;
+ default:
+ llvm_unreachable("unsupported layout kind");
}
- assert(remainingSgCount == 1 && "not all subgroups have been distributed");
}
SmallVector<int32_t> sgLayout32(sgLayout.begin(), sgLayout.end());
@@ -897,37 +739,76 @@ xegpu::reductionLayoutSetupRule(ArrayRef<int64_t> srcShape,
SmallVector<int32_t> instData32(instData.begin(), instData.end());
SmallVector<int32_t> laneLayout32(laneLayout.begin(), laneLayout.end());
SmallVector<int32_t> laneData32(laneData.begin(), laneData.end());
+
proposedSrcLayout = xegpu::LayoutAttr::get(
context, DenseI32ArrayAttr::get(context, sgLayout32),
DenseI32ArrayAttr::get(context, sgData32),
DenseI32ArrayAttr::get(context, instData32),
DenseI32ArrayAttr::get(context, laneLayout32),
DenseI32ArrayAttr::get(context, laneData32), consumerLayout.getOrder());
- xegpu::SliceAttr reductionResLayout =
+
+ xegpu::SliceAttr resLayout =
xegpu::SliceAttr::get(context, proposedSrcLayout,
DenseI64ArrayAttr::get(context, reductionDims));
- return reductionResLayout;
+ return resLayout;
}
xegpu::DistributeLayoutAttr
-xegpu::bitCastLayoutSetupRule(xegpu::DistributeLayoutAttr resLayout,
- int resElemTyBitWidth, int srcElemTyBitWidth) {
- SmallVector<int64_t> sgData = resLayout.getEffectiveSgDataAsInt();
- SmallVector<int64_t> instData = resLayout.getEffectiveInstDataAsInt();
- SmallVector<int64_t> laneData = resLayout.getEffectiveLaneDataAsInt();
- size_t dim = sgData.size() - 1;
+xegpu::bitCastSetupResultLayout(xegpu::LayoutKind layoutKind,
+ ArrayRef<int64_t> srcShape,
+ DistributeLayoutAttr consumerLayout,
+ int resElemTyBitWidth, int srcElemTyBitWidth) {
+ SmallVector<int64_t> sgData = consumerLayout.getEffectiveSgDataAsInt();
+ SmallVector<int64_t> instData = consumerLayout.getEffectiveInstDataAsInt();
+ SmallVector<int64_t> laneData = consumerLayout.getEffectiveLaneDataAsInt();
+ size_t dim = srcShape.size() - 1;
int64_t sgDataValue, instDataValue, laneDataValue;
- if (srcElemTyBitWidth < resElemTyBitWidth) {
- int bitWidthRatio = resElemTyBitWidth / srcElemTyBitWidth;
- sgDataValue = (dim < sgData.size()) ? sgData[dim] * bitWidthRatio : -1;
- instDataValue =
- (dim < instData.size()) ? instData[dim] * bitWidthRatio : -1;
- laneDataValue =
- (dim < laneData.size()) ? laneData[dim] * bitWidthRatio : -1;
+ const int subgroupSize = 16; // assuming 16 lanes per subgroup
+
+ if (srcElemTyBitWidth > resElemTyBitWidth) {
+ // When casting to a smaller bitwidth, multiply the result layout
+ // accordingly to ensure it can be divided by the ratio back to the
+ // source layout.
+ int bitWidthRatio = srcElemTyBitWidth / resElemTyBitWidth;
+ int innermostDimLaneLayout = subgroupSize;
+ switch (layoutKind) {
+ case xegpu::LayoutKind::Subgroup:
+ assert(sgData.size() == srcShape.size() &&
+ "sgData must be available for all dimensions");
+ sgDataValue = sgData[dim];
+ break;
+ case xegpu::LayoutKind::InstData:
+ assert(instData.size() == srcShape.size() &&
+ "instData must be available for all dimensions");
+ instDataValue = instData[dim];
+ // adjust instDataValue so it still fits within an instruction after
+ // dividing by bitWidthRatio
+ while ((instDataValue <= srcShape[dim]) &&
+ (instDataValue % (innermostDimLaneLayout * bitWidthRatio) != 0))
+ instDataValue *= 2;
+ assert(srcShape[dim] % instDataValue == 0 &&
+ "srcShape, instData, and lanelayout for innermost must be 2^n !");
+ break;
+ case xegpu::LayoutKind::Lane:
+ assert(laneData.size() == srcShape.size() &&
+ "laneData must be available for all dimensions");
+ laneDataValue = laneData[dim];
+ while ((laneDataValue <= srcShape[dim]) &&
+ (laneDataValue % bitWidthRatio != 0))
+ laneDataValue *= 2;
+ break;
+ default:
+ llvm_unreachable("unsupported layout kind");
+ }
+ } else {
+ sgDataValue = sgData[dim];
+ instDataValue = instData[dim];
+ laneDataValue = laneData[dim];
}
+
// Now set only instData and laneData, preserving sgData
- xegpu::DistributeLayoutAttr finalResLayout;
- finalResLayout =
- resLayout.setDimData(dim, sgDataValue, instDataValue, laneDataValue);
- return finalResLayout;
+ xegpu::DistributeLayoutAttr resLayout;
+ resLayout =
+ consumerLayout.setDimData(dim, sgDataValue, instDataValue, laneDataValue);
+ return resLayout;
}
\ No newline at end of file
>From a0537d61e9c0aa5debc09008b3b34f630dbdb5dd Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Mon, 26 Jan 2026 19:50:50 +0000
Subject: [PATCH 09/35] reverse the slice attribute hint setting for
store_matrix generated from cross-sg reduction
---
.../Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp | 9 +--------
mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir | 8 ++++----
2 files changed, 5 insertions(+), 12 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 72c8ea3de9976..88dc670e0ed6a 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -1489,15 +1489,8 @@ struct WgToSgMultiDimReductionOp
SmallVector<OpFoldResult> storeOffsets2D = {rowOffsetStore, colOffset};
- auto storeMatrixLayout = xegpu::SliceAttr::get(
- rewriter.getContext(),
- xegpu::LayoutAttr::get(rewriter.getContext(), /*sg_layout =*/nullptr,
- /*sg_data =*/nullptr,
- /*inst_data =*/nullptr, /*lane_layout =*/nullptr,
- /*lane_data =*/nullptr, /*order =*/nullptr),
- dyn_cast<xegpu::SliceAttr>(layout).getDims());
xegpu::StoreMatrixOp::create(rewriter, loc, storeData, memDesc.getResult(),
- storeOffsets2D, /*layout=*/storeMatrixLayout);
+ storeOffsets2D, /*layout=*/nullptr);
gpu::BarrierOp::create(rewriter, loc);
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
index 1fc2328d09046..9cb96775b4ee4 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -674,7 +674,7 @@ gpu.module @test_distribution {
// CHECK-DAG: %[[MUL3:.*]] = arith.muli %[[AFFINE3]], %[[C1:.*]] : index
// CHECK-DAG: %[[ADD2:.*]] = arith.addi %[[ADD1]], %[[MUL3]] : index
// CHECK-DAG: %[[COL_OFFSET:.*]] = arith.muli %[[ADD2]], %[[C32:.*]] : index
- // CHECK-DAG: xegpu.store_matrix %[[SHAPE_CAST]], %[[MEM_DESC]][%[[ROW_OFFSET]], %[[COL_OFFSET]]] <{layout = #xegpu.slice<#xegpu.layout<>, dims = [1]>}>: vector<1x32xf32>, !xegpu.mem_desc<32x32xf32>, index, index
+ // CHECK-DAG: xegpu.store_matrix %[[SHAPE_CAST]], %[[MEM_DESC]][%[[ROW_OFFSET]], %[[COL_OFFSET]]] : vector<1x32xf32>, !xegpu.mem_desc<32x32xf32>, index, index
// CHECK-DAG: gpu.barrier
// CHECK-DAG: %[[LOAD_SLM:.*]] = xegpu.load_matrix %[[MEM_DESC]][%[[C0:.*]], %[[COL_OFFSET]]] : !xegpu.mem_desc<32x32xf32>, index, index -> vector<32x32xf32>
// CHECK-DAG: %[[CST_3:.*]] = arith.constant dense<0.000000e+00> : vector<32xf32>
@@ -717,7 +717,7 @@ gpu.module @test_distribution {
// CHECK-DAG: %[[MUL4:.*]] = arith.muli {{.*}}, %[[C1:.*]] : index
// CHECK-DAG: %[[ADD1:.*]] = arith.addi %[[C0:.*]], %[[MUL4]] : index
// CHECK-DAG: %[[COL_OFFSET:.*]] = arith.muli %[[ADD1]], %[[C32:.*]] : index
- // CHECK-DAG: xegpu.store_matrix %[[SHAPE_CAST]], %[[MEM_DESC]][%[[ROW_OFFSET]], %[[COL_OFFSET]]] <{layout = #xegpu.slice<#xegpu.layout<>, dims = [0]>}>: vector<1x32xf32>, !xegpu.mem_desc<8x128xf32>, index, index
+ // CHECK-DAG: xegpu.store_matrix %[[SHAPE_CAST]], %[[MEM_DESC]][%[[ROW_OFFSET]], %[[COL_OFFSET]]] : vector<1x32xf32>, !xegpu.mem_desc<8x128xf32>, index, index
// CHECK-DAG: gpu.barrier
// CHECK-DAG: %[[LOAD_SLM:.*]] = xegpu.load_matrix %[[MEM_DESC]][%[[C0:.*]], %[[COL_OFFSET]]] : !xegpu.mem_desc<8x128xf32>, index, index -> vector<8x32xf32>
// CHECK-DAG: %[[CST_CROSS_SG_1:.*]] = arith.constant dense<0.000000e+00> : vector<32xf32>
@@ -766,7 +766,7 @@ gpu.module @test_distribution {
// CHECK-DAG: %[[MUL4:.*]] = arith.muli {{.*}}, %[[C2:.*]] : index
// CHECK-DAG: %[[ADD3:.*]] = arith.addi %[[ADD2]], %[[MUL4]] : index
// CHECK-DAG: %[[COL_OFFSET:.*]] = arith.muli %[[ADD3]], %[[C1:.*]] : index
- // CHECK-DAG: xegpu.store_matrix %[[SHAPE_CAST]], %[[MEM_DESC]][%[[ROW_OFFSET]], %[[COL_OFFSET]]] <{layout = #xegpu.slice<#xegpu.layout<>, dims = [2, 3]>}>: vector<1x1xf32>, !xegpu.mem_desc<16x4xf32>, index, index
+ // CHECK-DAG: xegpu.store_matrix %[[SHAPE_CAST]], %[[MEM_DESC]][%[[ROW_OFFSET]], %[[COL_OFFSET]]] : vector<1x1xf32>, !xegpu.mem_desc<16x4xf32>, index, index
// CHECK-DAG: gpu.barrier
// CHECK-DAG: %[[LOAD_SLM:.*]] = xegpu.load_matrix %[[MEM_DESC]][%[[C0:.*]], %[[COL_OFFSET]]] : !xegpu.mem_desc<16x4xf32>, index, index -> vector<16x1xf32>
// CHECK-DAG: %[[CST_3:.*]] = arith.constant dense<0.000000e+00> : vector<1xf32>
@@ -810,7 +810,7 @@ gpu.module @test_distribution {
// CHECK-DAG: %[[MUL4:.*]] = arith.muli {{.*}}, %[[C2:.*]] : index
// CHECK-DAG: %[[ADD3:.*]] = arith.addi %[[ADD2]], %[[MUL4]] : index
// CHECK-DAG: %[[COL_OFFSET:.*]] = arith.muli %[[ADD3]], %[[C256:.*]] : index
- // CHECK-DAG: xegpu.store_matrix %[[SHAPE_CAST]], %[[MEM_DESC]][%[[ROW_OFFSET]], %[[COL_OFFSET]]] <{layout = #xegpu.slice<#xegpu.layout<>, dims = [2, 3]>}>: vector<1x256xf32>, !xegpu.mem_desc<16x1024xf32>, index, index
+ // CHECK-DAG: xegpu.store_matrix %[[SHAPE_CAST]], %[[MEM_DESC]][%[[ROW_OFFSET]], %[[COL_OFFSET]]] : vector<1x256xf32>, !xegpu.mem_desc<16x1024xf32>, index, index
// CHECK-DAG: gpu.barrier
// CHECK-DAG: %[[LOAD_SLM:.*]] = xegpu.load_matrix %[[MEM_DESC]][%[[C0:.*]], %[[COL_OFFSET]]] : !xegpu.mem_desc<16x1024xf32>, index, index -> vector<16x256xf32>
// CHECK-DAG: %[[CST_3:.*]] = arith.constant dense<0.000000e+00> : vector<256xf32>
>From cbdbf38a02d1562935427f87aa65f0e7f3b5f1f2 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Tue, 27 Jan 2026 00:01:46 +0000
Subject: [PATCH 10/35] add store_matrix support
---
.../Dialect/XeGPU/Utils/XeGPULayoutUtils.h | 15 ++-
.../XeGPU/Transforms/XeGPUPropagateLayout.cpp | 55 +++++++----
.../Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp | 92 +++++++++++++++----
mlir/test/Dialect/XeGPU/propagate-layout.mlir | 2 +-
4 files changed, 122 insertions(+), 42 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h
index 206b3d85df30b..39afa4d7c1b8c 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h
@@ -12,6 +12,7 @@
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
+#include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
@@ -100,7 +101,8 @@ DistributeLayoutAttr inferShapeCastSourceLayout(DistributeLayoutAttr resLayout,
SliceAttr reductionSetupResultLayout(xegpu::LayoutKind layoutKind,
ArrayRef<int64_t> srcShape,
DistributeLayoutAttr consumerLayout,
- SmallVector<int64_t> reductionDims);
+ SmallVector<int64_t> reductionDims,
+ const uArch::uArch *uArch);
/// Setup the result layout attribute for a bitcast operation based on element
/// type bitwidths. This ensures the source layout can always be derived from
@@ -112,10 +114,15 @@ SliceAttr reductionSetupResultLayout(xegpu::LayoutKind layoutKind,
/// maintains the invariant that the source layout can be recovered by inverse
/// scaling during layout inference.
DistributeLayoutAttr
-bitCastSetupResultLayout(xegpu::LayoutKind layoutKind,
- ArrayRef<int64_t> srcShape,
+bitCastSetupResultLayout(LayoutKind layoutKind, ArrayRef<int64_t> srcShape,
DistributeLayoutAttr consumerLayout,
- int resElemTyBitWidth, int srcElemTyBitWidth);
+ int resElemTyBitWidth, int srcElemTyBitWidth,
+ const uArch::uArch *uArch);
+
+// Setup the anchor layout attribute for a storeMatrix operation
+DistributeLayoutAttr storeMatrixSetupAnchorLayout(LayoutKind layoutKind,
+ VectorType vectorTy,
+ const uArch::uArch *uArch);
} // namespace xegpu
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 07cef91ff3291..e7aad5337d59a 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -649,6 +649,10 @@ void LayoutInfoPropagation::visitVectorMultiReductionOp(
if (!resLayoutInfo.isAssigned())
return;
+ // debug print resLayoutInfo
+ LLVM_DEBUG(DBGS() << "visitVectorMultiReductionOp: resLayoutInfo = ";
+ resLayoutInfo.print(llvm::dbgs()); llvm::dbgs() << "\n");
+
VectorType sourceTy =
llvm::dyn_cast<VectorType>(reduction.getSourceVectorType());
SmallVector<int64_t> reductionDims(reduction.getReductionDims().begin(),
@@ -672,18 +676,28 @@ void LayoutInfoPropagation::visitVectorMultiReductionOp(
LLVM_DEBUG(DBGS() << "visitVectorMultiReductionOp: consumerLayoutAttr = "
<< consumerLayoutAttr << "\n");
- // An dominant layout is for the result and represents the layout requirements
- // for the operation it is recorded to anchor layout or temporary layout it
- // must be honored for current op and may conflict with the layout propagated
- // from consumer op the conflict is resolved in later phase by converting the
- // dominant layout to the source layout
-
+ // The required result layout represents the layout requirements of the
+ // operation; it is recorded as an anchor layout or a temporary layout.
+ // It must be honored by the current op and may conflict with the layout
+ // propagated from the consumer op; the conflict is resolved later by
+ // converting the required result layout to the consumer layout
+ auto uArch = getUArch(xegpu::getChipStr(reduction).value_or(""));
auto requiredResLayoutAttr = xegpu::reductionSetupResultLayout(
- layoutKind, srcShape, consumerLayoutAttr, reductionDims);
+ layoutKind, srcShape, consumerLayoutAttr, reductionDims, uArch);
LLVM_DEBUG(DBGS() << "visitVectorMultiReductionOp: requiredResLayoutAttr = "
<< requiredResLayoutAttr << "\n");
- resLayoutInfo.set(requiredResLayoutAttr);
+ // resLayoutInfo.set(requiredResLayoutAttr);
+ xegpu::setTemporaryLayout(reduction->getResult(0), requiredResLayoutAttr);
+
+ // debug print resLayoutInfo
+ LLVM_DEBUG(
+ DBGS() << "visitVectorMultiReductionOp: after change resLayoutInfo = ";
+ resLayoutInfo.print(llvm::dbgs()); llvm::dbgs() << "\n");
+
+ LLVM_DEBUG(DBGS() << "visitVectorMultiReductionOp: after change "
+ "results[0]->getValue() = ";
+ results[0]->getValue().print(llvm::dbgs()); llvm::dbgs() << "\n");
// derive the source layout from the dominant layout and reduction dims
auto srcLayoutAttr =
@@ -694,7 +708,8 @@ void LayoutInfoPropagation::visitVectorMultiReductionOp(
propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
// Accumulator should have the same layout as the result.
- propagateIfChanged(operands[1], operands[1]->meet(resLayoutInfo));
+ propagateIfChanged(operands[1],
+ operands[1]->meet(LayoutInfo(requiredResLayoutAttr)));
}
void LayoutInfoPropagation::visitVectorBroadCastOp(
@@ -1102,12 +1117,12 @@ void LayoutInfoPropagation::visitVectorBitcastOp(
ArrayRef<int64_t> srcShape = bitcast.getSourceVectorType().getShape();
auto consumerLayoutAttr =
dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
-
+ auto uArch = getUArch(xegpu::getChipStr(bitcast).value_or(""));
auto requiredResLayoutAttr =
bitCastSetupResultLayout(layoutKind, srcShape, consumerLayoutAttr,
- outElemTyBitWidth, inElemTyBitWidth);
+ outElemTyBitWidth, inElemTyBitWidth, uArch);
- resLayoutInfo.set(requiredResLayoutAttr);
+ xegpu::setTemporaryLayout(bitcast->getResult(0), requiredResLayoutAttr);
// derive the source layout from the dominant layout and reduction dims
auto srcLayoutAttr = xegpu::inferBitCastSourceLayout(
@@ -1327,12 +1342,10 @@ void LayoutInfoPropagation::visitStoreMatrixOp(
VectorType payloadTy = llvm::cast<VectorType>(operand.getType());
assert(payloadTy.getRank() == 2 && "Expecting 2D vector for store matrix.");
auto uArch = getUArch(getChipStr(storeMatrix).value_or(""));
- SmallVector<int> instData = {1, uArch->getSubgroupSize()};
- if (layoutKind == xegpu::LayoutKind::InstData)
- layout = LayoutInfo(
- xegpu::LayoutAttr::get(storeMatrix.getContext(), instData));
- else
- layout = getSIMTLayoutInfoScatterIO(payloadTy, uArch);
+ auto requiredAnchorLayoutAttr =
+ storeMatrixSetupAnchorLayout(layoutKind, payloadTy, uArch);
+ storeMatrix.setLayoutAttr(requiredAnchorLayoutAttr);
+ layout = LayoutInfo(requiredAnchorLayoutAttr);
}
propagateIfChanged(operands[index], operands[index]->meet(layout));
@@ -1612,6 +1625,12 @@ void XeGPUPropagateLayoutPass::runOnOperation() {
LayoutInfo layout = analysis.getLayoutInfo(val);
if (!layout.isAssigned())
return {};
+ if (auto opResult = dyn_cast<OpResult>(val)) {
+ xegpu::DistributeLayoutAttr requiredResLayoutAttr =
+ xegpu::getTemporaryLayout(opResult);
+ if (requiredResLayoutAttr != nullptr)
+ return requiredResLayoutAttr;
+ }
xegpu::DistributeLayoutAttr layoutAttr =
cast<xegpu::DistributeLayoutAttr>(layout.get());
if (layout.isSliceLayout())
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp
index c21e22c966f3d..d42f0d30ce8c0 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp
@@ -553,16 +553,12 @@ xegpu::inferShapeCastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
xegpu::SliceAttr xegpu::reductionSetupResultLayout(
xegpu::LayoutKind layoutKind, ArrayRef<int64_t> srcShape,
- DistributeLayoutAttr consumerLayout, SmallVector<int64_t> reductionDims) {
+ DistributeLayoutAttr consumerLayout, SmallVector<int64_t> reductionDims,
+ const xegpu::uArch::uArch *uArch) {
xegpu::SliceAttr consumerSliceLayout =
dyn_cast<xegpu::SliceAttr>(consumerLayout);
- // Hardware constraints (TODO: these should ideally be queried from device
- // capabilities)
- const int workgroupSize = 16; // Total number of subgroups in a workgroup
- const int subgroupSize = 16; // Number of SIMD lanes per subgroup
- const int vectorSize = 8; // Elements processed per vector instruction
int srcShapeSize = srcShape.size();
xegpu::DistributeLayoutAttr proposedSrcLayout;
auto context = consumerLayout.getContext();
@@ -576,6 +572,23 @@ xegpu::SliceAttr xegpu::reductionSetupResultLayout(
SmallVector<int64_t> laneLayout(srcShapeSize);
SmallVector<int64_t> laneData(srcShapeSize);
+ // recover workgroup and subgroup size from consumer layout
+ DistributeLayoutAttr origPlainLayout;
+ if (consumerSliceLayout) {
+ origPlainLayout = consumerSliceLayout.flatten().getParent();
+ } else {
+ origPlainLayout = consumerLayout;
+ }
+
+ const int workgroupSize =
+ std::accumulate(origPlainLayout.getEffectiveSgLayoutAsInt().begin(),
+ origPlainLayout.getEffectiveSgLayoutAsInt().end(), 1,
+ std::multiplies<int64_t>());
+
+ const int subgroupSize = uArch->getSubgroupSize();
+
+ const int vectorSize = 16; // vector size from SPIR-V vector restriction
+
SmallVector<int64_t> defaultInstData(srcShapeSize, 1);
defaultInstData[srcShapeSize - 1] = subgroupSize;
defaultInstData[srcShapeSize - 2] =
@@ -702,8 +715,11 @@ xegpu::SliceAttr xegpu::reductionSetupResultLayout(
"not all subgroups have been distributed");
break;
case xegpu::LayoutKind::InstData:
- for (int i = 0; i < srcShapeSize; i++)
+ for (int i = 0; i < srcShapeSize; i++) {
instData[i] = std::min(defaultInstData[i], srcShape[i]);
+ llvm::dbgs() << "MultiReductionOp: Strategy 2: instData [" << i
+ << "] = " << instData[i] << "\n";
+ }
break;
case xegpu::LayoutKind::Lane:
consumerLaneId = consumerLaneLayout.size() - 1;
@@ -740,12 +756,24 @@ xegpu::SliceAttr xegpu::reductionSetupResultLayout(
SmallVector<int32_t> laneLayout32(laneLayout.begin(), laneLayout.end());
SmallVector<int32_t> laneData32(laneData.begin(), laneData.end());
- proposedSrcLayout = xegpu::LayoutAttr::get(
- context, DenseI32ArrayAttr::get(context, sgLayout32),
- DenseI32ArrayAttr::get(context, sgData32),
- DenseI32ArrayAttr::get(context, instData32),
- DenseI32ArrayAttr::get(context, laneLayout32),
- DenseI32ArrayAttr::get(context, laneData32), consumerLayout.getOrder());
+ switch (layoutKind) {
+ case xegpu::LayoutKind::Subgroup:
+ proposedSrcLayout = xegpu::LayoutAttr::get(
+ context, DenseI32ArrayAttr::get(context, sgLayout32),
+ DenseI32ArrayAttr::get(context, sgData32), consumerLayout.getOrder());
+ break;
+ case xegpu::LayoutKind::InstData:
+ proposedSrcLayout = xegpu::LayoutAttr::get(
+ context, DenseI32ArrayAttr::get(context, instData32));
+ break;
+ case xegpu::LayoutKind::Lane:
+ proposedSrcLayout = xegpu::LayoutAttr::get(
+ context, DenseI32ArrayAttr::get(context, laneLayout32),
+ DenseI32ArrayAttr::get(context, laneData32), consumerLayout.getOrder());
+ break;
+ default:
+ llvm_unreachable("unsupported layout kind");
+ }
xegpu::SliceAttr resLayout =
xegpu::SliceAttr::get(context, proposedSrcLayout,
@@ -753,17 +781,17 @@ xegpu::SliceAttr xegpu::reductionSetupResultLayout(
return resLayout;
}
-xegpu::DistributeLayoutAttr
-xegpu::bitCastSetupResultLayout(xegpu::LayoutKind layoutKind,
- ArrayRef<int64_t> srcShape,
- DistributeLayoutAttr consumerLayout,
- int resElemTyBitWidth, int srcElemTyBitWidth) {
+xegpu::DistributeLayoutAttr xegpu::bitCastSetupResultLayout(
+ xegpu::LayoutKind layoutKind, ArrayRef<int64_t> srcShape,
+ DistributeLayoutAttr consumerLayout, int resElemTyBitWidth,
+ int srcElemTyBitWidth, const xegpu::uArch::uArch *uArch) {
SmallVector<int64_t> sgData = consumerLayout.getEffectiveSgDataAsInt();
SmallVector<int64_t> instData = consumerLayout.getEffectiveInstDataAsInt();
SmallVector<int64_t> laneData = consumerLayout.getEffectiveLaneDataAsInt();
size_t dim = srcShape.size() - 1;
int64_t sgDataValue, instDataValue, laneDataValue;
- const int subgroupSize = 16; // assuming 16 lanes per subgroup
+
+ const int subgroupSize = uArch->getSubgroupSize();
if (srcElemTyBitWidth > resElemTyBitWidth) {
// When casting to a smaller bitwidth, multiply the result layout
@@ -811,4 +839,30 @@ xegpu::bitCastSetupResultLayout(xegpu::LayoutKind layoutKind,
resLayout =
consumerLayout.setDimData(dim, sgDataValue, instDataValue, laneDataValue);
return resLayout;
+}
+
+xegpu::DistributeLayoutAttr
+xegpu::storeMatrixSetupAnchorLayout(xegpu::LayoutKind layoutKind,
+ VectorType vectorTy,
+ const xegpu::uArch::uArch *uArch) {
+
+ xegpu::DistributeLayoutAttr requiredLayout;
+ SmallVector<int> instData = {1, uArch->getSubgroupSize()};
+ switch (layoutKind) {
+ case xegpu::LayoutKind::Subgroup:
+ assert(false &&
+ "subgroup layout assignment not supported yet for storeMatrix.");
+ break;
+ case xegpu::LayoutKind::InstData:
+ requiredLayout = xegpu::LayoutAttr::get(vectorTy.getContext(), instData);
+ break;
+ case xegpu::LayoutKind::Lane:
+ requiredLayout = xegpu::LayoutAttr::get(
+ vectorTy.getContext(), {1, uArch->getSubgroupSize()}, {1, 1});
+
+ break;
+ default:
+ llvm_unreachable("unsupported layout kind");
+ }
+ return requiredLayout;
}
\ No newline at end of file
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout.mlir b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
index 9f309917a7d60..acfd2e34c805c 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
@@ -702,7 +702,7 @@ func.func @vector_broadcast_scalar_to_vector(%arg0: !xegpu.tensor_desc<16x16xf16
// -----
gpu.module @test {
// CHECK-LABEL: func.func @store_matrix(
-// CHECK: %[[CST:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>} dense<0.000000e+00> : vector<16x16xf16>
+// CHECK: %[[CST:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} dense<0.000000e+00> : vector<16x16xf16>
// CHECK-NEXT: xegpu.store_matrix %[[CST]], %arg0[8, 8] : vector<16x16xf16>, !xegpu.mem_desc<16x64xf16>
func.func @store_matrix(%arg0: !xegpu.mem_desc<16x64xf16>) {
>From 6c4bdb442c9fd8ff2b180692556bec097ded1893 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Tue, 27 Jan 2026 04:33:00 +0000
Subject: [PATCH 11/35] add more storescatter layout utilities
---
.../Dialect/XeGPU/Utils/XeGPULayoutUtils.h | 20 +-
.../XeGPU/Transforms/XeGPUPropagateLayout.cpp | 186 +++++++++---------
.../Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp | 98 ++++++++-
3 files changed, 197 insertions(+), 107 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h
index 39afa4d7c1b8c..fe23510845ae3 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h
@@ -99,7 +99,7 @@ DistributeLayoutAttr inferShapeCastSourceLayout(DistributeLayoutAttr resLayout,
/// The SliceAttr for the result is then created based on the derived source
/// layout and the specified reduction dimensions.
SliceAttr reductionSetupResultLayout(xegpu::LayoutKind layoutKind,
- ArrayRef<int64_t> srcShape,
+ VectorType srcVectorTy,
DistributeLayoutAttr consumerLayout,
SmallVector<int64_t> reductionDims,
const uArch::uArch *uArch);
@@ -113,17 +113,25 @@ SliceAttr reductionSetupResultLayout(xegpu::LayoutKind layoutKind,
/// (inst_data, lane_data) are scaled up by the bitwidth ratio. This
/// maintains the invariant that the source layout can be recovered by inverse
/// scaling during layout inference.
-DistributeLayoutAttr
-bitCastSetupResultLayout(LayoutKind layoutKind, ArrayRef<int64_t> srcShape,
- DistributeLayoutAttr consumerLayout,
- int resElemTyBitWidth, int srcElemTyBitWidth,
- const uArch::uArch *uArch);
+DistributeLayoutAttr bitCastSetupResultLayout(
+ LayoutKind layoutKind, VectorType srcVectorTy, VectorType resVectorTy,
+ DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch);
// Setup the anchor layout attribute for a storeMatrix operation
DistributeLayoutAttr storeMatrixSetupAnchorLayout(LayoutKind layoutKind,
VectorType vectorTy,
const uArch::uArch *uArch);
+xegpu::DistributeLayoutAttr
+xegpu::loadMatrixSetupAnchorLayout(LayoutKind layoutKind, VectorType vectorTy,
+ const uArch::uArch *uArch);
+
+xegpu::DistributeLayoutAttr
+xegpu::loadGatherSetupAnchorLayout(LayoutKind layoutKind, VectorType vectorTy,
+ const uArch::uArch *uArch);
+xegpu::DistributeLayoutAttr
+xegpu::storeScatterSetupAnchorLayout(LayoutKind layoutKind, VectorType vectorTy,
+ const uArch::uArch *uArch);
} // namespace xegpu
} // namespace mlir
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index e7aad5337d59a..6f356f490726b 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -419,6 +419,10 @@ class LayoutInfoPropagation
ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results);
+ void visitLoadMatrixOp(xegpu::LoadMatrixOp load,
+ ArrayRef<LayoutInfoLattice *> operands,
+ ArrayRef<const LayoutInfoLattice *> results);
+
void visitStoreMatrixOp(xegpu::StoreMatrixOp store,
ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results);
@@ -496,6 +500,9 @@ LogicalResult LayoutInfoPropagation::visitOperation(
.Case<vector::ShapeCastOp>([&](auto shapeCastOp) {
visitShapeCastOp(shapeCastOp, operands, results);
})
+ .Case<xegpu::LoadMatrixOp>([&](auto loadMatrixOp) {
+ visitLoadMatrixOp(loadMatrixOp, operands, results);
+ })
.Case<xegpu::StoreMatrixOp>([&](auto storeMatrixOp) {
visitStoreMatrixOp(storeMatrixOp, operands, results);
})
@@ -658,11 +665,9 @@ void LayoutInfoPropagation::visitVectorMultiReductionOp(
SmallVector<int64_t> reductionDims(reduction.getReductionDims().begin(),
reduction.getReductionDims().end());
- auto srcShape = sourceTy.getShape();
-
LLVM_DEBUG(DBGS() << "visitVectorMultiReductionOp: srcShape = [";
for (auto dim
- : srcShape) llvm::dbgs()
+ : sourceTy.getShape()) llvm::dbgs()
<< dim << " ";
llvm::dbgs() << "]\n");
LLVM_DEBUG(DBGS() << "visitVectorMultiReductionOp: reductionDims = [";
@@ -683,7 +688,7 @@ void LayoutInfoPropagation::visitVectorMultiReductionOp(
// converting the required result layout to the consumer layout
auto uArch = getUArch(xegpu::getChipStr(reduction).value_or(""));
auto requiredResLayoutAttr = xegpu::reductionSetupResultLayout(
- layoutKind, srcShape, consumerLayoutAttr, reductionDims, uArch);
+ layoutKind, sourceTy, consumerLayoutAttr, reductionDims, uArch);
LLVM_DEBUG(DBGS() << "visitVectorMultiReductionOp: requiredResLayoutAttr = "
<< requiredResLayoutAttr << "\n");
@@ -1109,21 +1114,21 @@ void LayoutInfoPropagation::visitVectorBitcastOp(
LayoutInfo resLayoutInfo = results[0]->getValue();
if (!resLayoutInfo.isAssigned())
return;
- int inElemTyBitWidth =
- bitcast.getSourceVectorType().getElementType().getIntOrFloatBitWidth();
- int outElemTyBitWidth =
- bitcast.getResultVectorType().getElementType().getIntOrFloatBitWidth();
- ArrayRef<int64_t> srcShape = bitcast.getSourceVectorType().getShape();
+ auto srcVecType = bitcast.getSourceVectorType();
+ auto resVecType = bitcast.getResultVectorType();
+
auto consumerLayoutAttr =
dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
auto uArch = getUArch(xegpu::getChipStr(bitcast).value_or(""));
- auto requiredResLayoutAttr =
- bitCastSetupResultLayout(layoutKind, srcShape, consumerLayoutAttr,
- outElemTyBitWidth, inElemTyBitWidth, uArch);
+ auto requiredResLayoutAttr = bitCastSetupResultLayout(
+ layoutKind, srcVecType, resVecType, consumerLayoutAttr, uArch);
xegpu::setTemporaryLayout(bitcast->getResult(0), requiredResLayoutAttr);
+ int inElemTyBitWidth = srcVecType.getElementType().getIntOrFloatBitWidth();
+ int outElemTyBitWidth = resVecType.getElementType().getIntOrFloatBitWidth();
+
// derive the source layout from the dominant layout and reduction dims
auto srcLayoutAttr = xegpu::inferBitCastSourceLayout(
requiredResLayoutAttr, outElemTyBitWidth, inElemTyBitWidth);
@@ -1250,107 +1255,96 @@ void LayoutInfoPropagation::visitStoreScatterOp(
xegpu::StoreScatterOp storeScatter, ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results) {
- LayoutInfo payloadLayout;
- LayoutInfo maskLayout;
+ LayoutInfo srcLayoutInfo;
+ LayoutInfo maskLayoutInfo;
+ xegpu::DistributeLayoutAttr requiredAnchorLayoutAttr;
xegpu::DistributeLayoutAttr anchorLayout = storeScatter.getLayoutAttr();
auto uArch = getUArch(getChipStr(storeScatter).value_or(""));
- const int subgroupSize = uArch->getSubgroupSize();
+ VectorType srcVecTy = storeScatter.getValueType();
+ VectorType maskTy =
+ llvm::dyn_cast<VectorType>(storeScatter.getMask().getType());
if (hasParamsOfLayoutKind(anchorLayout)) {
- payloadLayout = LayoutInfo(anchorLayout);
- maskLayout = payloadLayout;
+ requiredAnchorLayoutAttr = LayoutInfo(anchorLayout);
} else {
- // Currently, for 2D StoreScatterOp we expect that the height dimension of
- // the tensor descriptor is equal to the subgroup size. This is ensured by
- // the op verifier.
- VectorType payloadTy = storeScatter.getValueType();
- if (!payloadTy) {
+
+ if (!srcVecTy) {
storeScatter.emitWarning("Not propagating, non-vector payload supplied.");
return;
}
-
- if (layoutKind == xegpu::LayoutKind::InstData) {
- const auto *uArchInstruction =
- dyn_cast<xegpu::uArch::StoreScatterInstruction>(uArch->getInstruction(
- xegpu::uArch::InstructionKind::StoreScatter));
- const int subgroupSize = uArch->getSubgroupSize();
- SmallVector<int> instDataUarch{subgroupSize};
- if (payloadTy.getRank() != 1) {
- if (payloadTy.getRank() != 2) {
- storeScatter.emitWarning("Expected 2D payload for StoreScatterOp.");
- return;
- }
- instDataUarch.push_back(
- (std::min(static_cast<int>(payloadTy.getShape().back()),
- uArchInstruction->getMaxLaneLoadStoreSize())));
- }
- payloadLayout = LayoutInfo(
- xegpu::LayoutAttr::get(storeScatter.getContext(), instDataUarch));
- } else {
- auto payloadShape = payloadTy.getShape();
- if (payloadShape.size() > 1)
- assert(payloadShape[0] == subgroupSize &&
- "Expected the first dimension of 2D tensor descriptor to be "
- "equal to "
- "subgroup size.");
- payloadLayout = getSIMTLayoutInfoScatterIO(payloadTy, uArch);
- }
-
- storeScatter.setLayoutAttr(
- dyn_cast<xegpu::DistributeLayoutAttr>(payloadLayout.get()));
+ assert(srcVecTy.getRank() == 2 && "Expecting 2D vector for store matrix.");
+ auto uArch = getUArch(getChipStr(storeScatter).value_or(""));
+ requiredAnchorLayoutAttr =
+ xegpu::storeScatterSetupAnchorLayout(layoutKind, srcVecTy, uArch);
+ storeScatter.setLayoutAttr(requiredAnchorLayoutAttr);
}
- // If no user-defined anchor or we deal with a chunked op, set the default
- // mask layout.
- // Rank 1 data : Keep the mask layout aligned with data.
- // Rank >1 data: Enforce the default xegpu 1D layout for mask.
- if (!hasParamsOfLayoutKind(anchorLayout) ||
- storeScatter.getValueType().getRank() > 1) {
- if (layoutKind == xegpu::LayoutKind::InstData)
- maskLayout = LayoutInfo(
- xegpu::LayoutAttr::get(storeScatter->getContext(), {subgroupSize}));
- else if (layoutKind == xegpu::LayoutKind::Lane)
- maskLayout =
- getDefaultSIMTLayoutInfo(storeScatter->getContext(), 1, subgroupSize);
+ srcLayoutInfo = LayoutInfo(requiredAnchorLayoutAttr);
+ maskLayoutInfo = srcLayoutInfo;
+ if (maskTy.getRank() < srcVecTy.getRank()) {
+ assert((maskTy.getRank() == (srcVecTy.getRank() - 1)) &&
+ "Expecting mask vector only 1 dimension less than value vector.");
+ maskLayoutInfo = LayoutInfo(requiredAnchorLayoutAttr.dropDim(-1));
+
+ // Propagate the payload operand layout
+ propagateIfChanged(operands[0], operands[0]->meet(srcLayoutInfo));
+ // Propagate the destination (if tdesc) operand layout
+ if (isa<xegpu::TensorDescType>(storeScatter.getDestType()))
+ propagateIfChanged(operands[1], operands[1]->meet(srcLayoutInfo));
+ // Propagate the new layout to the mask and optional offset operand.
+ propagateIfChanged(operands[2], operands[2]->meet(maskLayoutInfo));
+ if (storeScatter.getOffsets())
+ propagateIfChanged(operands[3], operands[3]->meet(maskLayoutInfo));
}
- // Propagate the payload operand layout
- propagateIfChanged(operands[0], operands[0]->meet(payloadLayout));
- // Propagate the destination (if tdesc) operand layout
- if (isa<xegpu::TensorDescType>(storeScatter.getDestType()))
- propagateIfChanged(operands[1], operands[1]->meet(payloadLayout));
- // Propagate the new layout to the mask and optional offset operand.
- propagateIfChanged(operands[2], operands[2]->meet(maskLayout));
- if (storeScatter.getOffsets())
- propagateIfChanged(operands[3], operands[3]->meet(maskLayout));
-}
+ void LayoutInfoPropagation::visitLoadMatrixOp(
+ xegpu::LoadMatrixOp loadMatrixOp, ArrayRef<LayoutInfoLattice *> operands,
+ ArrayRef<const LayoutInfoLattice *> results) {
+ Value resVal = loadMatrixOp.getRes();
+ unsigned index =
+ std::distance(loadMatrixOp.operand_begin(),
+ llvm::find(loadMatrixOp.getOperands(), operand));
+
+ xegpu::DistributeLayoutAttr anchorLayout = loadMatrixOp.getLayoutAttr();
+ LayoutInfo layout;
+ if (hasParamsOfLayoutKind(anchorLayout)) {
+ layout = LayoutInfo(anchorLayout);
+ } else {
+ VectorType resVecTy = llvm::cast<VectorType>(resVal.getType());
+ assert(resVecTy.getRank() == 2 &&
+ "Expecting 2D vector for store matrix.");
+ auto uArch = getUArch(getChipStr(loadMatrixOp).value_or(""));
+ auto requiredAnchorLayoutAttr =
+ xegpu::loadMatrixSetupAnchorLayout(layoutKind, resVecTy, uArch);
+ loadMatrixOp.setLayoutAttr(requiredAnchorLayoutAttr);
+ layout = LayoutInfo(requiredAnchorLayoutAttr);
+ }
+ }
// Store matrix is a flavor of scattered store for 2D shapes.
-void LayoutInfoPropagation::visitStoreMatrixOp(
- xegpu::StoreMatrixOp storeMatrix, ArrayRef<LayoutInfoLattice *> operands,
- ArrayRef<const LayoutInfoLattice *> results) {
- Value operand = storeMatrix.getData();
- unsigned index =
- std::distance(storeMatrix.operand_begin(),
- llvm::find(storeMatrix->getOperands(), operand));
+ void LayoutInfoPropagation::visitStoreMatrixOp(
+ xegpu::StoreMatrixOp storeMatrix, ArrayRef<LayoutInfoLattice *> operands,
+ ArrayRef<const LayoutInfoLattice *> results) {
+
+ xegpu::DistributeLayoutAttr anchorLayout = storeMatrix.getLayoutAttr();
+ LayoutInfo layout;
+ if (hasParamsOfLayoutKind(anchorLayout)) {
+ layout = LayoutInfo(anchorLayout);
+ } else {
+ VectorType srcVecTy =
+ llvm::cast<VectorType>(storeMatrix.getData().getType());
+ assert(srcVecTy.getRank() == 2 &&
+ "Expecting 2D vector for store matrix.");
+ auto uArch = getUArch(getChipStr(storeMatrix).value_or(""));
+ auto requiredAnchorLayoutAttr =
+ xegpu::storeMatrixSetupAnchorLayout(layoutKind, srcVecTy, uArch);
+ storeMatrix.setLayoutAttr(requiredAnchorLayoutAttr);
+ layout = LayoutInfo(requiredAnchorLayoutAttr);
+ }
- xegpu::DistributeLayoutAttr anchorLayout = storeMatrix.getLayoutAttr();
- LayoutInfo layout;
- if (hasParamsOfLayoutKind(anchorLayout)) {
- layout = LayoutInfo(anchorLayout);
- } else {
- VectorType payloadTy = llvm::cast<VectorType>(operand.getType());
- assert(payloadTy.getRank() == 2 && "Expecting 2D vector for store matrix.");
- auto uArch = getUArch(getChipStr(storeMatrix).value_or(""));
- auto requiredAnchorLayoutAttr =
- storeMatrixSetupAnchorLayout(layoutKind, payloadTy, uArch);
- storeMatrix.setLayoutAttr(requiredAnchorLayoutAttr);
- layout = LayoutInfo(requiredAnchorLayoutAttr);
+ propagateIfChanged(operands[0], operands[0]->meet(layout));
}
- propagateIfChanged(operands[index], operands[index]->meet(layout));
-}
-
namespace {
//===----------------------------------------------------------------------===//
// RunLayoutInfoPropagation
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp
index d42f0d30ce8c0..fb806d5859c97 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp
@@ -552,10 +552,11 @@ xegpu::inferShapeCastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
/// the consumer.
xegpu::SliceAttr xegpu::reductionSetupResultLayout(
- xegpu::LayoutKind layoutKind, ArrayRef<int64_t> srcShape,
+ xegpu::LayoutKind layoutKind, VectorType srcVecTy,
DistributeLayoutAttr consumerLayout, SmallVector<int64_t> reductionDims,
const xegpu::uArch::uArch *uArch) {
+ auto srcShape = srcVecTy.getShape();
xegpu::SliceAttr consumerSliceLayout =
dyn_cast<xegpu::SliceAttr>(consumerLayout);
@@ -782,9 +783,13 @@ xegpu::SliceAttr xegpu::reductionSetupResultLayout(
}
xegpu::DistributeLayoutAttr xegpu::bitCastSetupResultLayout(
- xegpu::LayoutKind layoutKind, ArrayRef<int64_t> srcShape,
- DistributeLayoutAttr consumerLayout, int resElemTyBitWidth,
- int srcElemTyBitWidth, const xegpu::uArch::uArch *uArch) {
+ xegpu::LayoutKind layoutKind, VectorType srcVecTy, VectorType resVecTy,
+ DistributeLayoutAttr consumerLayout, const xegpu::uArch::uArch *uArch) {
+
+ int srcElemTyBitWidth = srcVecTy.getElementType().getIntOrFloatBitWidth();
+ int resElemTyBitWidth = resVecTy.getElementType().getIntOrFloatBitWidth();
+
+ ArrayRef<int64_t> srcShape = srcVecTy.getShape();
SmallVector<int64_t> sgData = consumerLayout.getEffectiveSgDataAsInt();
SmallVector<int64_t> instData = consumerLayout.getEffectiveInstDataAsInt();
SmallVector<int64_t> laneData = consumerLayout.getEffectiveLaneDataAsInt();
@@ -850,7 +855,7 @@ xegpu::storeMatrixSetupAnchorLayout(xegpu::LayoutKind layoutKind,
SmallVector<int> instData = {1, uArch->getSubgroupSize()};
switch (layoutKind) {
case xegpu::LayoutKind::Subgroup:
- assert(false &&
+ assert(true &&
"subgroup layout assignment not supported yet for storeMatrix.");
break;
case xegpu::LayoutKind::InstData:
@@ -865,4 +870,87 @@ xegpu::storeMatrixSetupAnchorLayout(xegpu::LayoutKind layoutKind,
llvm_unreachable("unsupported layout kind");
}
return requiredLayout;
+}
+
+xegpu::DistributeLayoutAttr
+xegpu::loadMatrixSetupAnchorLayout(xegpu::LayoutKind layoutKind,
+ VectorType vectorTy,
+ const xegpu::uArch::uArch *uArch) {
+ xegpu::DistributeLayoutAttr requiredLayout;
+ SmallVector<int> instData = {1, uArch->getSubgroupSize()};
+ switch (layoutKind) {
+ case xegpu::LayoutKind::Subgroup:
+ assert(true &&
+ "subgroup layout assignment not supported yet for loadMatrix.");
+ break;
+ case xegpu::LayoutKind::InstData:
+ requiredLayout = xegpu::LayoutAttr::get(vectorTy.getContext(), instData);
+ break;
+ case xegpu::LayoutKind::Lane:
+ requiredLayout = xegpu::LayoutAttr::get(
+ vectorTy.getContext(), {1, uArch->getSubgroupSize()}, {1, 1});
+ break;
+ default:
+ llvm_unreachable("unsupported layout kind");
+ }
+ return requiredLayout;
+}
+
+static xegpu::DistributeLayoutAttr
+getDefaultSIMTLaneLayoutAttr(mlir::MLIRContext *ctx, unsigned rank,
+ int subgroupSize) {
+ assert((rank == 1 || rank == 2) && "Expected 1D or 2D vector.");
+ if (rank == 1) {
+ return xegpu::LayoutAttr::get(ctx, {subgroupSize}, {1});
+ }
+ return xegpu::LayoutAttr::get(ctx, {1, subgroupSize}, {1, 1});
+}
+
+xegpu::DistributeLayoutAttr
+xegpu::loadGatherSetupAnchorLayout(LayoutKind layoutKind, VectorType vectorTy,
+ const uArch::uArch *uArch) {}
+
+xegpu::DistributeLayoutAttr
+xegpu::storeScatterSetupAnchorLayout(LayoutKind layoutKind, VectorType srcVecTy,
+ const uArch::uArch *uArch) {
+
+ xegpu::DistributeLayoutAttr requiredLayout;
+ const int subgroupSize = uArch->getSubgroupSize();
+ const int spirVectorSize = 16; // vector size from SPRIV vector restriction
+
+ auto srcShape = srcVecTy.getShape();
+ int srcShapeSize = srcVecTy.getShape().size();
+
+ SmallVector<int64_t> instData(srcShapeSize);
+ SmallVector<int64_t> laneLayout(srcShapeSize);
+ SmallVector<int64_t> laneData(srcShapeSize);
+
+ const auto *uArchInstruction =
+ dyn_cast<xegpu::uArch::StoreScatterInstruction>(
+ uArch->getInstruction(xegpu::uArch::InstructionKind::StoreScatter));
+
+ switch (layoutKind) {
+ case xegpu::LayoutKind::Subgroup:
+ assert(false &&
+ "subgroup layout assignment not supported yet for loadMatrix.");
+ break;
+ case xegpu::LayoutKind::InstData:
+ assert((srcVecTy.getRank() > 2) && "StoreScatterOp can access 2D tensor "
+ "tile at maximum at subgroup level.");
+ if (srcVecTy.getRank() == 1)
+ instData[0] = subgroupSize;
+ else {
+ instData[0] = std::min(srcShape[0], static_cast<int64_t>(spirVectorSize));
+ instData[1] = subgroupSize;
+ }
+ requiredLayout = xegpu::LayoutAttr::get(srcVecTy.getContext(), instData);
+ break;
+ case xegpu::LayoutKind::Lane:
+ requiredLayout = getDefaultSIMTLaneLayoutAttr(
+ srcVecTy.getContext(), srcVecTy.getRank(), subgroupSize);
+ break;
+ default:
+ llvm_unreachable("unsupported layout kind");
+ }
+ return requiredLayout;
}
\ No newline at end of file
>From 3a4b60fa9b94e8b7fe2e586f6d311f0357db6390 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Tue, 27 Jan 2026 07:03:20 +0000
Subject: [PATCH 12/35] code polish
---
.../Dialect/XeGPU/Utils/XeGPULayoutUtils.h | 17 +--
.../XeGPU/Transforms/XeGPUPropagateLayout.cpp | 122 ++++++++--------
.../Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp | 136 ++++++++++++++----
3 files changed, 182 insertions(+), 93 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h
index fe23510845ae3..c083619292d0c 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h
@@ -117,21 +117,22 @@ DistributeLayoutAttr bitCastSetupResultLayout(
LayoutKind layoutKind, VectorType srcVectorTy, VectorType resVectorTy,
DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch);
-// Setup the anchor layout attribute for a storeMatrix operation
-DistributeLayoutAttr storeMatrixSetupAnchorLayout(LayoutKind layoutKind,
- VectorType vectorTy,
- const uArch::uArch *uArch);
xegpu::DistributeLayoutAttr
xegpu::loadMatrixSetupAnchorLayout(LayoutKind layoutKind, VectorType vectorTy,
+ xegpu::DistributeLayoutAttr consumerLayout,
const uArch::uArch *uArch);
-xegpu::DistributeLayoutAttr
-xegpu::loadGatherSetupAnchorLayout(LayoutKind layoutKind, VectorType vectorTy,
- const uArch::uArch *uArch);
+DistributeLayoutAttr storeMatrixSetupAnchorLayout(LayoutKind layoutKind,
+ VectorType vectorTy,
+ const uArch::uArch *uArch);
+
+xegpu::DistributeLayoutAttr xegpu::loadGatherSetupAnchorLayout(
+ LayoutKind layoutKind, VectorType vectorTy, int chunkSize,
+ DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch);
xegpu::DistributeLayoutAttr
xegpu::storeScatterSetupAnchorLayout(LayoutKind layoutKind, VectorType vectorTy,
- const uArch::uArch *uArch);
+ int chunkSize, const uArch::uArch *uArch);
} // namespace xegpu
} // namespace mlir
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 6f356f490726b..0236ba3437b6b 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -427,6 +427,14 @@ class LayoutInfoPropagation
ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results);
+ void visitLoadGatherOp(xegpu::LoadMatrixOp load,
+ ArrayRef<LayoutInfoLattice *> operands,
+ ArrayRef<const LayoutInfoLattice *> results);
+
+ void visitStoreScatterOp(xegpu::StoreMatrixOp store,
+ ArrayRef<LayoutInfoLattice *> operands,
+ ArrayRef<const LayoutInfoLattice *> results);
+
bool hasParamsOfLayoutKind(xegpu::DistributeLayoutAttr anchorLayout);
public:
@@ -1263,19 +1271,17 @@ void LayoutInfoPropagation::visitStoreScatterOp(
VectorType srcVecTy = storeScatter.getValueType();
VectorType maskTy =
llvm::dyn_cast<VectorType>(storeScatter.getMask().getType());
+ int chunkSize = storeScatter.getChunkSize().value_or(1);
if (hasParamsOfLayoutKind(anchorLayout)) {
requiredAnchorLayoutAttr = LayoutInfo(anchorLayout);
} else {
-
if (!srcVecTy) {
storeScatter.emitWarning("Not propagating, non-vector payload supplied.");
return;
}
- assert(srcVecTy.getRank() == 2 && "Expecting 2D vector for store matrix.");
- auto uArch = getUArch(getChipStr(storeScatter).value_or(""));
- requiredAnchorLayoutAttr =
- xegpu::storeScatterSetupAnchorLayout(layoutKind, srcVecTy, uArch);
+ requiredAnchorLayoutAttr = xegpu::storeScatterSetupAnchorLayout(
+ layoutKind, srcVecTy, chunkSize, uArch);
storeScatter.setLayoutAttr(requiredAnchorLayoutAttr);
}
@@ -1284,67 +1290,67 @@ void LayoutInfoPropagation::visitStoreScatterOp(
if (maskTy.getRank() < srcVecTy.getRank()) {
assert((maskTy.getRank() == (srcVecTy.getRank() - 1)) &&
"Expecting mask vector only 1 dimension less than value vector.");
- maskLayoutInfo = LayoutInfo(requiredAnchorLayoutAttr.dropDim(-1));
-
- // Propagate the payload operand layout
- propagateIfChanged(operands[0], operands[0]->meet(srcLayoutInfo));
- // Propagate the destination (if tdesc) operand layout
- if (isa<xegpu::TensorDescType>(storeScatter.getDestType()))
- propagateIfChanged(operands[1], operands[1]->meet(srcLayoutInfo));
- // Propagate the new layout to the mask and optional offset operand.
- propagateIfChanged(operands[2], operands[2]->meet(maskLayoutInfo));
- if (storeScatter.getOffsets())
- propagateIfChanged(operands[3], operands[3]->meet(maskLayoutInfo));
+ // ToDO: Infer the proper mask layout based on the value layout.
+ maskLayoutInfo = LayoutInfo(requiredAnchorLayoutAttr);
}
+ // Propagate the payload operand layout
+ propagateIfChanged(operands[0], operands[0]->meet(srcLayoutInfo));
+ // Propagate the destination (if tdesc) operand layout
+ if (isa<xegpu::TensorDescType>(storeScatter.getDestType()))
+ propagateIfChanged(operands[1], operands[1]->meet(srcLayoutInfo));
+ // Propagate the new layout to the mask and optional offset operand.
+ propagateIfChanged(operands[2], operands[2]->meet(maskLayoutInfo));
+ if (storeScatter.getOffsets())
+ propagateIfChanged(operands[3], operands[3]->meet(maskLayoutInfo));
+}
- void LayoutInfoPropagation::visitLoadMatrixOp(
- xegpu::LoadMatrixOp loadMatrixOp, ArrayRef<LayoutInfoLattice *> operands,
- ArrayRef<const LayoutInfoLattice *> results) {
- Value resVal = loadMatrixOp.getRes();
- unsigned index =
- std::distance(loadMatrixOp.operand_begin(),
- llvm::find(loadMatrixOp.getOperands(), operand));
-
- xegpu::DistributeLayoutAttr anchorLayout = loadMatrixOp.getLayoutAttr();
- LayoutInfo layout;
- if (hasParamsOfLayoutKind(anchorLayout)) {
- layout = LayoutInfo(anchorLayout);
- } else {
- VectorType resVecTy = llvm::cast<VectorType>(resVal.getType());
- assert(resVecTy.getRank() == 2 &&
- "Expecting 2D vector for store matrix.");
- auto uArch = getUArch(getChipStr(loadMatrixOp).value_or(""));
- auto requiredAnchorLayoutAttr =
- xegpu::loadMatrixSetupAnchorLayout(layoutKind, resVecTy, uArch);
- loadMatrixOp.setLayoutAttr(requiredAnchorLayoutAttr);
- layout = LayoutInfo(requiredAnchorLayoutAttr);
- }
+void LayoutInfoPropagation::visitLoadMatrixOp(
+ xegpu::LoadMatrixOp loadMatrixOp, ArrayRef<LayoutInfoLattice *> operands,
+ ArrayRef<const LayoutInfoLattice *> results) {
+
+ LayoutInfo resLayoutInfo = results[0]->getValue();
+ if (!resLayoutInfo.isAssigned())
+ return;
+ auto consumerLayoutAttr =
+ dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
+
+ xegpu::DistributeLayoutAttr anchorLayout = loadMatrixOp.getLayoutAttr();
+
+ // only need to set anchor layout, no need to porpagate to memdesc and offset
+ if (!hasParamsOfLayoutKind(anchorLayout)) {
+ VectorType resVecTy =
+ llvm::cast<VectorType>(loadMatrixOp.getRes().getType());
+ assert(resVecTy.getRank() == 2 && "Expecting 2D vector for store matrix.");
+ auto uArch = getUArch(getChipStr(loadMatrixOp).value_or(""));
+ auto requiredAnchorLayoutAttr = xegpu::loadMatrixSetupAnchorLayout(
+ layoutKind, resVecTy, consumerLayoutAttr, uArch);
+ loadMatrixOp.setLayoutAttr(requiredAnchorLayoutAttr);
}
+}
// Store matrix is a flavor of scattered store for 2D shapes.
- void LayoutInfoPropagation::visitStoreMatrixOp(
- xegpu::StoreMatrixOp storeMatrix, ArrayRef<LayoutInfoLattice *> operands,
- ArrayRef<const LayoutInfoLattice *> results) {
-
- xegpu::DistributeLayoutAttr anchorLayout = storeMatrix.getLayoutAttr();
- LayoutInfo layout;
- if (hasParamsOfLayoutKind(anchorLayout)) {
- layout = LayoutInfo(anchorLayout);
- } else {
- VectorType srcVecTy =
- llvm::cast<VectorType>(storeMatrix.getData().getType());
- assert(srcVecTy.getRank() == 2 &&
- "Expecting 2D vector for store matrix.");
- auto uArch = getUArch(getChipStr(storeMatrix).value_or(""));
- auto requiredAnchorLayoutAttr =
- xegpu::storeMatrixSetupAnchorLayout(layoutKind, srcVecTy, uArch);
- storeMatrix.setLayoutAttr(requiredAnchorLayoutAttr);
- layout = LayoutInfo(requiredAnchorLayoutAttr);
- }
+void LayoutInfoPropagation::visitStoreMatrixOp(
+ xegpu::StoreMatrixOp storeMatrix, ArrayRef<LayoutInfoLattice *> operands,
+ ArrayRef<const LayoutInfoLattice *> results) {
- propagateIfChanged(operands[0], operands[0]->meet(layout));
+ xegpu::DistributeLayoutAttr anchorLayout = storeMatrix.getLayoutAttr();
+ LayoutInfo layout;
+ if (hasParamsOfLayoutKind(anchorLayout)) {
+ layout = LayoutInfo(anchorLayout);
+ } else {
+ VectorType srcVecTy =
+ llvm::cast<VectorType>(storeMatrix.getData().getType());
+ assert(srcVecTy.getRank() == 2 && "Expecting 2D vector for store matrix.");
+ auto uArch = getUArch(getChipStr(storeMatrix).value_or(""));
+ auto requiredAnchorLayoutAttr =
+ xegpu::storeMatrixSetupAnchorLayout(layoutKind, srcVecTy, uArch);
+ storeMatrix.setLayoutAttr(requiredAnchorLayoutAttr);
+ layout = LayoutInfo(requiredAnchorLayoutAttr);
}
+ propagateIfChanged(operands[0], operands[0]->meet(layout));
+}
+
namespace {
//===----------------------------------------------------------------------===//
// RunLayoutInfoPropagation
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp
index fb806d5859c97..3968f0f2db0af 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp
@@ -875,20 +875,33 @@ xegpu::storeMatrixSetupAnchorLayout(xegpu::LayoutKind layoutKind,
xegpu::DistributeLayoutAttr
xegpu::loadMatrixSetupAnchorLayout(xegpu::LayoutKind layoutKind,
VectorType vectorTy,
+ xegpu::DistributeLayoutAttr consumerLayout,
const xegpu::uArch::uArch *uArch) {
xegpu::DistributeLayoutAttr requiredLayout;
SmallVector<int> instData = {1, uArch->getSubgroupSize()};
+
+ SmallVector<int64_t> consumerSgLayout =
+ consumerLayout.getEffectiveSgLayoutAsInt();
+ SmallVector<int64_t> consumerSgData =
+ consumerLayout.getEffectiveSgDataAsInt();
+ SmallVector<int32_t> sgLayout32(consumerSgLayout.begin(),
+ consumerSgLayout.end());
+ SmallVector<int32_t> sgData32(consumerSgData.begin(), consumerSgData.end());
+
+ auto context = vectorTy.getContext();
+
switch (layoutKind) {
case xegpu::LayoutKind::Subgroup:
- assert(true &&
- "subgroup layout assignment not supported yet for loadMatrix.");
+ requiredLayout = xegpu::LayoutAttr::get(
+ context, DenseI32ArrayAttr::get(context, sgLayout32),
+ DenseI32ArrayAttr::get(context, sgData32), consumerLayout.getOrder());
break;
case xegpu::LayoutKind::InstData:
- requiredLayout = xegpu::LayoutAttr::get(vectorTy.getContext(), instData);
+ requiredLayout = xegpu::LayoutAttr::get(context, instData);
break;
case xegpu::LayoutKind::Lane:
- requiredLayout = xegpu::LayoutAttr::get(
- vectorTy.getContext(), {1, uArch->getSubgroupSize()}, {1, 1});
+ requiredLayout =
+ xegpu::LayoutAttr::get(context, {1, uArch->getSubgroupSize()}, {1, 1});
break;
default:
llvm_unreachable("unsupported layout kind");
@@ -897,37 +910,99 @@ xegpu::loadMatrixSetupAnchorLayout(xegpu::LayoutKind layoutKind,
}
static xegpu::DistributeLayoutAttr
-getDefaultSIMTLaneLayoutAttr(mlir::MLIRContext *ctx, unsigned rank,
- int subgroupSize) {
+getDefaultLaneLayoutAttr(mlir::MLIRContext *ctx, unsigned rank,
+ const xegpu::uArch::uArch *uArch) {
assert((rank == 1 || rank == 2) && "Expected 1D or 2D vector.");
if (rank == 1) {
- return xegpu::LayoutAttr::get(ctx, {subgroupSize}, {1});
+ return xegpu::LayoutAttr::get(ctx, {uArch->getSubgroupSize()}, {1});
}
- return xegpu::LayoutAttr::get(ctx, {1, subgroupSize}, {1, 1});
+ return xegpu::LayoutAttr::get(ctx, {1, uArch->getSubgroupSize()}, {1, 1});
}
-xegpu::DistributeLayoutAttr
-xegpu::loadGatherSetupAnchorLayout(LayoutKind layoutKind, VectorType vectorTy,
- const uArch::uArch *uArch) {}
+xegpu::DistributeLayoutAttr xegpu::loadGatherSetupAnchorLayout(
+ LayoutKind layoutKind, VectorType resVecTy, int chunkSize,
+ DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch) {
+
+ xegpu::DistributeLayoutAttr requiredLayout;
+ const int subgroupSize = uArch->getSubgroupSize();
+ const int spirVectorSize = 16; // vector size from SPRIV vector restriction
+
+ auto resShape = resVecTy.getShape();
+ int resShapeSize = resShape.size();
+ SmallVector<int64_t> instData(subgroupSize);
+ auto context = resVecTy.getContext();
+
+ const auto *uArchInstruction =
+ dyn_cast<xegpu::uArch::StoreScatterInstruction>(
+ uArch->getInstruction(xegpu::uArch::InstructionKind::StoreScatter));
+
+ SmallVector<int64_t> consumerSgLayout =
+ consumerLayout.getEffectiveSgLayoutAsInt();
+ SmallVector<int64_t> consumerSgData =
+ consumerLayout.getEffectiveSgDataAsInt();
+ SmallVector<int32_t> sgLayout32(consumerSgLayout.begin(),
+ consumerSgLayout.end());
+ SmallVector<int32_t> sgData32(consumerSgData.begin(), consumerSgData.end());
+
+ SmallVector<int64_t> consumerInstData =
+ consumerLayout.getEffectiveInstDataAsInt();
+ SmallVector<int32_t> instData32;
+
+ switch (layoutKind) {
+ case xegpu::LayoutKind::Subgroup:
+ requiredLayout = xegpu::LayoutAttr::get(
+ context, DenseI32ArrayAttr::get(context, sgLayout32),
+ DenseI32ArrayAttr::get(context, sgData32), consumerLayout.getOrder());
+ break;
+ case xegpu::LayoutKind::InstData:
+ if (resVecTy.getRank() == 1) {
+ instData[0] = subgroupSize;
+ } else {
+ assert((resVecTy.getRank() > 2) && "StoreScatterOp can access 2D tensor "
+ "tile at maximum at subgroup level.");
+ if (chunkSize == 1) {
+ instData[0] =
+ std::min(resShape[0], static_cast<int64_t>(spirVectorSize));
+ instData[0] = std::min(instData[0], consumerInstData[0]);
+ instData[1] = subgroupSize;
+ } else {
+ instData[0] = subgroupSize;
+ instData[1] = std::min(
+ resShape[1],
+ static_cast<int64_t>(uArchInstruction->getMaxLaneLoadStoreSize()));
+ instData[1] = std::min(instData[1], consumerInstData[1]);
+ }
+ }
+ instData32 = SmallVector<int32_t>(instData.begin(), instData.end());
+ requiredLayout = xegpu::LayoutAttr::get(
+ context, DenseI32ArrayAttr::get(context, instData32));
+ break;
+ case xegpu::LayoutKind::Lane:
+ requiredLayout =
+ getDefaultLaneLayoutAttr(context, resVecTy.getRank(), uArch);
+ break;
+ default:
+ llvm_unreachable("unsupported layout kind");
+ }
+ return requiredLayout;
+}
xegpu::DistributeLayoutAttr
xegpu::storeScatterSetupAnchorLayout(LayoutKind layoutKind, VectorType srcVecTy,
- const uArch::uArch *uArch) {
+ int chunkSize, const uArch::uArch *uArch) {
xegpu::DistributeLayoutAttr requiredLayout;
const int subgroupSize = uArch->getSubgroupSize();
const int spirVectorSize = 16; // vector size from SPRIV vector restriction
auto srcShape = srcVecTy.getShape();
- int srcShapeSize = srcVecTy.getShape().size();
-
- SmallVector<int64_t> instData(srcShapeSize);
- SmallVector<int64_t> laneLayout(srcShapeSize);
- SmallVector<int64_t> laneData(srcShapeSize);
+ int srcShapeSize = srcShape.size();
+ SmallVector<int64_t> instData(subgroupSize);
const auto *uArchInstruction =
dyn_cast<xegpu::uArch::StoreScatterInstruction>(
uArch->getInstruction(xegpu::uArch::InstructionKind::StoreScatter));
+ auto context = srcVecTy.getContext();
switch (layoutKind) {
case xegpu::LayoutKind::Subgroup:
@@ -935,19 +1010,26 @@ xegpu::storeScatterSetupAnchorLayout(LayoutKind layoutKind, VectorType srcVecTy,
"subgroup layout assignment not supported yet for loadMatrix.");
break;
case xegpu::LayoutKind::InstData:
- assert((srcVecTy.getRank() > 2) && "StoreScatterOp can access 2D tensor "
- "tile at maximum at subgroup level.");
- if (srcVecTy.getRank() == 1)
+ if (srcVecTy.getRank() == 1) {
instData[0] = subgroupSize;
- else {
- instData[0] = std::min(srcShape[0], static_cast<int64_t>(spirVectorSize));
- instData[1] = subgroupSize;
+ } else {
+ assert((srcVecTy.getRank() > 2) && "StoreScatterOp can access 2D tensor "
+ "tile at maximum at subgroup level.");
+ if (chunkSize == 1) {
+ instData[0] =
+ std::min(srcShape[0], static_cast<int64_t>(spirVectorSize));
+ instData[1] = subgroupSize;
+ } else {
+ instData[0] = subgroupSize;
+ instData[1] = std::min(
+ srcShape[1],
+ static_cast<int64_t>(uArchInstruction->getMaxLaneLoadStoreSize()));
+ }
}
- requiredLayout = xegpu::LayoutAttr::get(srcVecTy.getContext(), instData);
break;
case xegpu::LayoutKind::Lane:
- requiredLayout = getDefaultSIMTLaneLayoutAttr(
- srcVecTy.getContext(), srcVecTy.getRank(), subgroupSize);
+ requiredLayout =
+ getDefaultLaneLayoutAttr(context, srcVecTy.getRank(), uArch);
break;
default:
llvm_unreachable("unsupported layout kind");
>From ef5a279afb8394e0bdfd9c46de48e4fb4324ba67 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Tue, 27 Jan 2026 20:08:28 +0000
Subject: [PATCH 13/35] add scatter IO mask handling
---
.../Dialect/XeGPU/Utils/XeGPULayoutUtils.h | 26 ++--
mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 52 +++----
.../XeGPU/Transforms/XeGPUPropagateLayout.cpp | 135 +++++++-----------
.../Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp | 130 ++++++++++++-----
4 files changed, 184 insertions(+), 159 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h
index c083619292d0c..bba2b1e06129a 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h
@@ -90,6 +90,12 @@ DistributeLayoutAttr inferShapeCastSourceLayout(DistributeLayoutAttr resLayout,
ArrayRef<int64_t> resShape,
ArrayRef<int64_t> srcShape);
+/// Infers the source layout attribute for mask operand of scatter IO operation
+/// given the result layout attribute, value shape, and mask shape.
+DistributeLayoutAttr inferScatterIOMaskLayout(DistributeLayoutAttr resLayout,
+ ArrayRef<int64_t> valShape,
+ ArrayRef<int64_t> maskShape);
+
/// Sets up layout for reduction operations by creating a SliceAttr for the
/// result.
///
@@ -98,11 +104,11 @@ DistributeLayoutAttr inferShapeCastSourceLayout(DistributeLayoutAttr resLayout,
/// consumer's preferred layout. This minimizes data redistribution overhead.
/// The SliceAttr for the result is then created based on the derived source
/// layout and the specified reduction dimensions.
-SliceAttr reductionSetupResultLayout(xegpu::LayoutKind layoutKind,
- VectorType srcVectorTy,
- DistributeLayoutAttr consumerLayout,
- SmallVector<int64_t> reductionDims,
- const uArch::uArch *uArch);
+SliceAttr setupMultiReductionResultLayout(xegpu::LayoutKind layoutKind,
+ VectorType srcVectorTy,
+ DistributeLayoutAttr consumerLayout,
+ SmallVector<int64_t> reductionDims,
+ const uArch::uArch *uArch);
/// Setup the result layout attribute for a bitcast operation based on element
/// type bitwidths. This ensures the source layout can always be derived from
@@ -113,25 +119,25 @@ SliceAttr reductionSetupResultLayout(xegpu::LayoutKind layoutKind,
/// (inst_data, lane_data) are scaled up by the bitwidth ratio. This
/// maintains the invariant that the source layout can be recovered by inverse
/// scaling during layout inference.
-DistributeLayoutAttr bitCastSetupResultLayout(
+DistributeLayoutAttr setupBitCastResultLayout(
LayoutKind layoutKind, VectorType srcVectorTy, VectorType resVectorTy,
DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch);
xegpu::DistributeLayoutAttr
-xegpu::loadMatrixSetupAnchorLayout(LayoutKind layoutKind, VectorType vectorTy,
+xegpu::setupLoadMatrixAnchorLayout(LayoutKind layoutKind, VectorType vectorTy,
xegpu::DistributeLayoutAttr consumerLayout,
const uArch::uArch *uArch);
-DistributeLayoutAttr storeMatrixSetupAnchorLayout(LayoutKind layoutKind,
+DistributeLayoutAttr setupStoreMatrixAnchorLayout(LayoutKind layoutKind,
VectorType vectorTy,
const uArch::uArch *uArch);
-xegpu::DistributeLayoutAttr xegpu::loadGatherSetupAnchorLayout(
+xegpu::DistributeLayoutAttr xegpu::setupLoadGatherAnchorLayout(
LayoutKind layoutKind, VectorType vectorTy, int chunkSize,
DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch);
xegpu::DistributeLayoutAttr
-xegpu::storeScatterSetupAnchorLayout(LayoutKind layoutKind, VectorType vectorTy,
+xegpu::setupStoreScatterAnchorLayout(LayoutKind layoutKind, VectorType vectorTy,
int chunkSize, const uArch::uArch *uArch);
} // namespace xegpu
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 5bf57376cbbaf..78d1ad50dd1dd 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -513,15 +513,11 @@ LayoutAttr::collapseDims(SmallVector<SmallVector<int64_t>> dimGroups) const {
SmallVector<int64_t> laneData = getEffectiveLaneDataAsInt();
DenseI32ArrayAttr orderAttr = getOrder();
- SmallVector<int64_t> order;
+ SmallVector<int64_t> orderVec;
if (orderAttr && !orderAttr.empty()) {
- order = llvm::to_vector(
+ orderVec = llvm::to_vector(
llvm::map_range(orderAttr.asArrayRef(),
[](int32_t idx) { return static_cast<int64_t>(idx); }));
- } else {
- // Default order: [1, 0] for 2D (row-major), [2, 1, 0] for 3D, etc.
- order =
- llvm::to_vector(llvm::reverse(llvm::seq<int64_t>(0, sgLayout.size())));
}
SmallVector<int64_t> collapsedSgLayout;
@@ -555,22 +551,6 @@ LayoutAttr::collapseDims(SmallVector<SmallVector<int64_t>> dimGroups) const {
collapsedOrder.push_back(collapsedOrderValue);
}
- // go through the values inside collapsedOrder, and re-map the order values to
- // be in range of [0, N-1] where N is the number of dimensions in collapsed
- // shape
- int64_t orderSize = static_cast<int64_t>(collapsedOrder.size());
- SmallVector<int64_t> remappedOrder(orderSize, -1);
- for (int64_t i = 0; i < orderSize; ++i) {
- int64_t originalOrderValue = collapsedOrder[i];
- // count how many values in collapsedOrder are less than originalOrderValue
- int64_t count = 0;
- for (int64_t j = 0; j < orderSize; ++j) {
- if (collapsedOrder[j] < originalOrderValue)
- count++;
- }
- remappedOrder[i] = count;
- }
-
// Create collapsed layout
SmallVector<int32_t> collapsedSgLayout32(collapsedSgLayout.begin(),
collapsedSgLayout.end());
@@ -582,8 +562,28 @@ LayoutAttr::collapseDims(SmallVector<SmallVector<int64_t>> dimGroups) const {
collapsedLaneLayout.end());
SmallVector<int32_t> collapsedLaneData32(collapsedLaneData.begin(),
collapsedLaneData.end());
- SmallVector<int32_t> remappedOrder32(remappedOrder.begin(),
- remappedOrder.end());
+
+ // go through the values inside collapsedOrder, and re-map the order values to
+ // be in range of [0, N-1] where N is the number of dimensions in collapsed
+ // shape
+ SmallVector<int32_t> remappedOrder32;
+ if (!orderVec.empty()) {
+ int64_t orderSize = static_cast<int64_t>(collapsedOrder.size());
+ SmallVector<int64_t> remappedOrder(orderSize, -1);
+ for (int64_t i = 0; i < orderSize; ++i) {
+ int64_t originalOrderValue = collapsedOrder[i];
+ // count how many values in collapsedOrder are less than
+ // originalOrderValue
+ int64_t count = 0;
+ for (int64_t j = 0; j < orderSize; ++j) {
+ if (collapsedOrder[j] < originalOrderValue)
+ count++;
+ }
+ remappedOrder[i] = count;
+ }
+ remappedOrder32 =
+ SmallVector<int32_t>(remappedOrder.begin(), remappedOrder.end());
+ }
auto collapsedLayout = xegpu::LayoutAttr::get(
getContext(), DenseI32ArrayAttr::get(getContext(), collapsedSgLayout32),
@@ -591,7 +591,9 @@ LayoutAttr::collapseDims(SmallVector<SmallVector<int64_t>> dimGroups) const {
DenseI32ArrayAttr::get(getContext(), collapsedInstData32),
DenseI32ArrayAttr::get(getContext(), collapsedLaneLayout32),
DenseI32ArrayAttr::get(getContext(), collapsedLaneData32),
- DenseI32ArrayAttr::get(getContext(), remappedOrder32));
+ remappedOrder32.empty()
+ ? nullptr
+ : DenseI32ArrayAttr::get(getContext(), remappedOrder32));
return collapsedLayout;
}
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 0236ba3437b6b..e6fa8bfd12474 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -1150,88 +1150,43 @@ void LayoutInfoPropagation::visitLoadGatherOp(
xegpu::LoadGatherOp load, ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results) {
- LayoutInfo loadLayout;
- LayoutInfo maskLayout;
+ xegpu::DistributeLayoutAttr requiredAnchorLayoutAttr;
+ xegpu::DistributeLayoutAttr anchorLayoutAttr = load.getLayoutAttr();
auto uArch = getUArch(getChipStr(load).value_or(""));
- const int subgroupSize = uArch->getSubgroupSize();
- xegpu::DistributeLayoutAttr anchorLayout = load.getLayoutAttr();
- if (hasParamsOfLayoutKind(anchorLayout)) {
- loadLayout = LayoutInfo(anchorLayout);
- maskLayout = loadLayout;
- } else {
- LayoutInfo valueLayout = results[0]->getValue();
- // Need the layout of the value to propagate to the tensor descriptor.
- if (!valueLayout.isAssigned())
- return;
+ auto subgroupSize = uArch->getSubgroupSize();
+ VectorType resVecTy = load.getValueType();
+ VectorType maskTy = llvm::dyn_cast<VectorType>(load.getMask().getType());
+ int chunkSize = load.getChunkSize().value_or(1);
- auto resAttr = dyn_cast<xegpu::DistributeLayoutAttr>(valueLayout.get());
- auto instDataIncoming = resAttr.getEffectiveInstDataAsInt();
- if (auto sliceAttr = dyn_cast<xegpu::SliceAttr>(resAttr))
- instDataIncoming = SmallVector<int64_t>(
- cast<xegpu::LayoutAttr>(sliceAttr.flatten().getParent())
- .getInstData()
- .asArrayRef());
-
- VectorType payloadTy = load.getValueType();
- if (!payloadTy) {
+ if (hasParamsOfLayoutKind(anchorLayoutAttr)) {
+ requiredAnchorLayoutAttr = anchorLayoutAttr;
+ } else {
+ if (!resVecTy) {
load.emitWarning("Not propagating, non-vector payload supplied.");
return;
}
- const auto *uArchInstruction =
- dyn_cast<xegpu::uArch::LoadGatherInstruction>(
- uArch->getInstruction(xegpu::uArch::InstructionKind::LoadGather));
-
- // Check if value inst_data complies with uArch
- if (layoutKind == xegpu::LayoutKind::InstData) {
- // Each lane loads either one element
- SmallVector<int> instDataUarch{subgroupSize};
- // Or multiple elements as 2D with lane's elements in the inner dimension
- if (payloadTy.getRank() != 1) {
- if (payloadTy.getRank() != 2) {
- load.emitWarning("Expected 2D payload for LoadGatherOp.");
- return;
- }
- instDataUarch.push_back(
- (std::min(static_cast<int>(payloadTy.getShape().back()),
- uArchInstruction->getMaxLaneLoadStoreSize())));
- }
- // If inst data does not match, enforce the uArch-based one
- if (!llvm::equal(instDataIncoming, instDataUarch)) {
- xegpu::LayoutAttr sourceAttr = dyn_cast<xegpu::LayoutAttr>(resAttr);
- if (auto sliceAttr = dyn_cast<xegpu::SliceAttr>(resAttr)) {
- sourceAttr = cast<xegpu::LayoutAttr>(sliceAttr.flatten().getParent());
- }
- assert(sourceAttr);
- xegpu::DistributeLayoutAttr updatedLayoutAttr = xegpu::LayoutAttr::get(
- load.getContext(), sourceAttr.getSgLayout(), sourceAttr.getSgData(),
- DenseI32ArrayAttr::get(load.getContext(), instDataUarch),
- sourceAttr.getLaneLayout(), sourceAttr.getLaneData(),
- sourceAttr.getOrder());
-
- if (auto sliceAttr = dyn_cast<xegpu::SliceAttr>(resAttr))
- updatedLayoutAttr = xegpu::SliceAttr::get(
- load.getContext(), updatedLayoutAttr, sliceAttr.getDims());
- valueLayout = LayoutInfo(updatedLayoutAttr);
- }
- }
- loadLayout = valueLayout;
- load.setLayoutAttr(dyn_cast<xegpu::DistributeLayoutAttr>(loadLayout.get()));
+ requiredAnchorLayoutAttr = xegpu::setupLoadGatherAnchorLayout(
+ layoutKind, resVecTy, chunkSize, uArch);
+ load.setLayoutAttr(requiredAnchorLayoutAttr);
}
- // If no user-defined anchor or we deal with a chunked op, set the default
- // mask layout.
- // Rank 1 data : Keep the mask layout aligned with data.
- // Rank >1 data: Enforce the default xegpu 1D layout for mask.
- if (!hasParamsOfLayoutKind(anchorLayout) ||
- load.getValueType().getRank() > 1) {
+ auto maskLayoutAttr = anchorLayoutAttr;
+ // Special handling mask layout for chunked ops: Enforce the default xegpu 1D
+ // layout for mask.
+ if (chunkSize > 1) {
if (layoutKind == xegpu::LayoutKind::InstData)
- maskLayout = LayoutInfo(
- xegpu::LayoutAttr::get(load->getContext(), {subgroupSize}));
+ maskLayoutAttr =
+ xegpu::LayoutAttr::get(load->getContext(), {subgroupSize});
else if (layoutKind == xegpu::LayoutKind::Lane)
- maskLayout =
- getDefaultSIMTLayoutInfo(load->getContext(), 1, subgroupSize);
+ maskLayoutAttr =
+ xegpu::LayoutAttr::get(load->getContext(), {subgroupSize}, {1});
+ else
+ assert(false &&
+ "chunked StoreScatterOp should not be used at workgroup level");
}
+ LayoutInfo maskLayoutInfo = LayoutInfo(maskLayoutAttr);
+
// Propagate the new layout to the tensor descriptor operand.
if (isa<xegpu::TensorDescType>(load.getSourceType()))
propagateIfChanged(operands[0], operands[0]->meet(loadLayout));
@@ -1263,36 +1218,45 @@ void LayoutInfoPropagation::visitStoreScatterOp(
xegpu::StoreScatterOp storeScatter, ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results) {
- LayoutInfo srcLayoutInfo;
- LayoutInfo maskLayoutInfo;
xegpu::DistributeLayoutAttr requiredAnchorLayoutAttr;
- xegpu::DistributeLayoutAttr anchorLayout = storeScatter.getLayoutAttr();
+ xegpu::DistributeLayoutAttr anchorLayoutAttr = storeScatter.getLayoutAttr();
auto uArch = getUArch(getChipStr(storeScatter).value_or(""));
+ auto subgroupSize = uArch->getSubgroupSize();
VectorType srcVecTy = storeScatter.getValueType();
VectorType maskTy =
llvm::dyn_cast<VectorType>(storeScatter.getMask().getType());
int chunkSize = storeScatter.getChunkSize().value_or(1);
- if (hasParamsOfLayoutKind(anchorLayout)) {
- requiredAnchorLayoutAttr = LayoutInfo(anchorLayout);
+ if (hasParamsOfLayoutKind(anchorLayoutAttr)) {
+ requiredAnchorLayoutAttr = anchorLayoutAttr;
} else {
if (!srcVecTy) {
storeScatter.emitWarning("Not propagating, non-vector payload supplied.");
return;
}
- requiredAnchorLayoutAttr = xegpu::storeScatterSetupAnchorLayout(
+ requiredAnchorLayoutAttr = xegpu::setupStoreScatterAnchorLayout(
layoutKind, srcVecTy, chunkSize, uArch);
storeScatter.setLayoutAttr(requiredAnchorLayoutAttr);
}
- srcLayoutInfo = LayoutInfo(requiredAnchorLayoutAttr);
- maskLayoutInfo = srcLayoutInfo;
- if (maskTy.getRank() < srcVecTy.getRank()) {
- assert((maskTy.getRank() == (srcVecTy.getRank() - 1)) &&
- "Expecting mask vector only 1 dimension less than value vector.");
- // ToDO: Infer the proper mask layout based on the value layout.
- maskLayoutInfo = LayoutInfo(requiredAnchorLayoutAttr);
+ LayoutInfo srcLayoutInfo = LayoutInfo(requiredAnchorLayoutAttr);
+ auto maskLayoutAttr = anchorLayoutAttr;
+ // Special handling mask layout for chunked ops: Enforce the default xegpu 1D
+ // layout for mask.
+ if (chunkSize > 1) {
+ if (layoutKind == xegpu::LayoutKind::InstData)
+ maskLayoutAttr =
+ xegpu::LayoutAttr::get(storeScatter->getContext(), {subgroupSize});
+ else if (layoutKind == xegpu::LayoutKind::Lane)
+ maskLayoutAttr = xegpu::LayoutAttr::get(storeScatter->getContext(),
+ {subgroupSize}, {1});
+ else
+ assert(false &&
+ "chunked StoreScatterOp should not be used at workgroup level");
}
+
+ LayoutInfo maskLayoutInfo = LayoutInfo(maskLayoutAttr);
+
// Propagate the payload operand layout
propagateIfChanged(operands[0], operands[0]->meet(srcLayoutInfo));
// Propagate the destination (if tdesc) operand layout
@@ -1316,7 +1280,8 @@ void LayoutInfoPropagation::visitLoadMatrixOp(
xegpu::DistributeLayoutAttr anchorLayout = loadMatrixOp.getLayoutAttr();
- // only need to set anchor layout, no need to porpagate to memdesc and offset
+ // only need to set anchor layout, no need to propagate to memdesc and
+ // offset
if (!hasParamsOfLayoutKind(anchorLayout)) {
VectorType resVecTy =
llvm::cast<VectorType>(loadMatrixOp.getRes().getType());
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp
index 3968f0f2db0af..8d7d1b96b363e 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp
@@ -532,6 +532,61 @@ xegpu::inferShapeCastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
return nullptr;
}
+/// Infers the layout of the mask operand of a scatter IO operation from the
+/// layout attached to the value (`resLayout`).
+///
+/// \param resLayout  Layout of the value; must be a plain `xegpu::LayoutAttr`
+///                   (any other layout kind is rejected by assertion).
+/// \param valShape   Shape of the value vector.
+/// \param maskShape  Shape of the mask vector.
+/// \returns `resLayout` unchanged when the mask rank is not smaller than the
+///          value rank. When the mask has exactly one dimension fewer than
+///          the value, returns the layout obtained by calling
+///          `setUnitDimData` and then `collapseDim` on the innermost value
+///          dimension.
+// NOTE(review): `setUnitDimData` / `collapseDim` are defined outside this
+// section; confirm that together they drop the innermost dimension.
+xegpu::DistributeLayoutAttr
+xegpu::inferScatterIOMaskLayout(xegpu::DistributeLayoutAttr resLayout,
+                                ArrayRef<int64_t> valShape,
+                                ArrayRef<int64_t> maskShape) {
+  auto resPlainLayout = dyn_cast<xegpu::LayoutAttr>(resLayout);
+  assert(resPlainLayout &&
+         "Expecting plain layout for scatter IO mask inference.");
+  xegpu::LayoutAttr maskLayout = resPlainLayout;
+
+  // Only a mask of lower rank than the value needs a derived layout.
+  if (maskShape.size() < valShape.size()) {
+    // Compare the sizes directly so no local is left unused when asserts are
+    // compiled out.
+    assert((maskShape.size() + 1 == valShape.size()) &&
+           "Expecting mask vector only 1 dimension less than value vector.");
+
+    int innermostDim = static_cast<int>(valShape.size()) - 1;
+    maskLayout = resPlainLayout.setUnitDimData(innermostDim);
+    maskLayout = maskLayout.collapseDim(innermostDim);
+  }
+  return maskLayout;
+}
+
/// Sets up layout for reduction operations by creating a SliceAttr for the
/// result.
///
@@ -551,7 +606,7 @@ xegpu::inferShapeCastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
/// reduced result stays with the same subgroup distribution as expected by
/// the consumer.
-xegpu::SliceAttr xegpu::reductionSetupResultLayout(
+xegpu::SliceAttr xegpu::setupMultiReductionResultLayout(
xegpu::LayoutKind layoutKind, VectorType srcVecTy,
DistributeLayoutAttr consumerLayout, SmallVector<int64_t> reductionDims,
const xegpu::uArch::uArch *uArch) {
@@ -782,7 +837,7 @@ xegpu::SliceAttr xegpu::reductionSetupResultLayout(
return resLayout;
}
-xegpu::DistributeLayoutAttr xegpu::bitCastSetupResultLayout(
+xegpu::DistributeLayoutAttr xegpu::setupBitCastResultLayout(
xegpu::LayoutKind layoutKind, VectorType srcVecTy, VectorType resVecTy,
DistributeLayoutAttr consumerLayout, const xegpu::uArch::uArch *uArch) {
@@ -847,7 +902,7 @@ xegpu::DistributeLayoutAttr xegpu::bitCastSetupResultLayout(
}
xegpu::DistributeLayoutAttr
-xegpu::storeMatrixSetupAnchorLayout(xegpu::LayoutKind layoutKind,
+xegpu::setupStoreMatrixAnchorLayout(xegpu::LayoutKind layoutKind,
VectorType vectorTy,
const xegpu::uArch::uArch *uArch) {
@@ -873,28 +928,18 @@ xegpu::storeMatrixSetupAnchorLayout(xegpu::LayoutKind layoutKind,
}
xegpu::DistributeLayoutAttr
-xegpu::loadMatrixSetupAnchorLayout(xegpu::LayoutKind layoutKind,
+xegpu::setupLoadMatrixAnchorLayout(xegpu::LayoutKind layoutKind,
VectorType vectorTy,
xegpu::DistributeLayoutAttr consumerLayout,
const xegpu::uArch::uArch *uArch) {
xegpu::DistributeLayoutAttr requiredLayout;
SmallVector<int> instData = {1, uArch->getSubgroupSize()};
- SmallVector<int64_t> consumerSgLayout =
- consumerLayout.getEffectiveSgLayoutAsInt();
- SmallVector<int64_t> consumerSgData =
- consumerLayout.getEffectiveSgDataAsInt();
- SmallVector<int32_t> sgLayout32(consumerSgLayout.begin(),
- consumerSgLayout.end());
- SmallVector<int32_t> sgData32(consumerSgData.begin(), consumerSgData.end());
-
auto context = vectorTy.getContext();
switch (layoutKind) {
case xegpu::LayoutKind::Subgroup:
- requiredLayout = xegpu::LayoutAttr::get(
- context, DenseI32ArrayAttr::get(context, sgLayout32),
- DenseI32ArrayAttr::get(context, sgData32), consumerLayout.getOrder());
+ requiredLayout = consumerLayout;
break;
case xegpu::LayoutKind::InstData:
requiredLayout = xegpu::LayoutAttr::get(context, instData);
@@ -919,7 +964,7 @@ getDefaultLaneLayoutAttr(mlir::MLIRContext *ctx, unsigned rank,
return xegpu::LayoutAttr::get(ctx, {1, uArch->getSubgroupSize()}, {1, 1});
}
-xegpu::DistributeLayoutAttr xegpu::loadGatherSetupAnchorLayout(
+xegpu::DistributeLayoutAttr xegpu::setupLoadGatherAnchorLayout(
LayoutKind layoutKind, VectorType resVecTy, int chunkSize,
DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch) {
@@ -936,23 +981,13 @@ xegpu::DistributeLayoutAttr xegpu::loadGatherSetupAnchorLayout(
dyn_cast<xegpu::uArch::StoreScatterInstruction>(
uArch->getInstruction(xegpu::uArch::InstructionKind::StoreScatter));
- SmallVector<int64_t> consumerSgLayout =
- consumerLayout.getEffectiveSgLayoutAsInt();
- SmallVector<int64_t> consumerSgData =
- consumerLayout.getEffectiveSgDataAsInt();
- SmallVector<int32_t> sgLayout32(consumerSgLayout.begin(),
- consumerSgLayout.end());
- SmallVector<int32_t> sgData32(consumerSgData.begin(), consumerSgData.end());
-
SmallVector<int64_t> consumerInstData =
consumerLayout.getEffectiveInstDataAsInt();
SmallVector<int32_t> instData32;
switch (layoutKind) {
case xegpu::LayoutKind::Subgroup:
- requiredLayout = xegpu::LayoutAttr::get(
- context, DenseI32ArrayAttr::get(context, sgLayout32),
- DenseI32ArrayAttr::get(context, sgData32), consumerLayout.getOrder());
+ requiredLayout = consumerLayout;
break;
case xegpu::LayoutKind::InstData:
if (resVecTy.getRank() == 1) {
@@ -961,9 +996,7 @@ xegpu::DistributeLayoutAttr xegpu::loadGatherSetupAnchorLayout(
assert((resVecTy.getRank() > 2) && "StoreScatterOp can access 2D tensor "
"tile at maximum at subgroup level.");
if (chunkSize == 1) {
- instData[0] =
- std::min(resShape[0], static_cast<int64_t>(spirVectorSize));
- instData[0] = std::min(instData[0], consumerInstData[0]);
+ instData[0] = 1;
instData[1] = subgroupSize;
} else {
instData[0] = subgroupSize;
@@ -978,8 +1011,18 @@ xegpu::DistributeLayoutAttr xegpu::loadGatherSetupAnchorLayout(
context, DenseI32ArrayAttr::get(context, instData32));
break;
case xegpu::LayoutKind::Lane:
- requiredLayout =
- getDefaultLaneLayoutAttr(context, resVecTy.getRank(), uArch);
+ if (chunkSize == 1)
+ requiredLayout =
+ getDefaultLaneLayoutAttr(context, resVecTy.getRank(), uArch);
+ else {
+ assert((resVecTy.getRank() <= 2) && "StoreScatterOp can access 2D tensor "
+ "tile at maximum at subgroup level.");
+ assert(resShape[1] <= static_cast<int64_t>(
+ uArchInstruction->getMaxLaneLoadStoreSize()) &&
+ "StoreScatterOp lane size exceeds max lane load/store size.");
+ requiredLayout = xegpu::LayoutAttr::get(
+ context, {subgroupSize, 1}, {1, static_cast<int>(resShape[1])});
+ }
break;
default:
llvm_unreachable("unsupported layout kind");
@@ -988,7 +1031,7 @@ xegpu::DistributeLayoutAttr xegpu::loadGatherSetupAnchorLayout(
}
xegpu::DistributeLayoutAttr
-xegpu::storeScatterSetupAnchorLayout(LayoutKind layoutKind, VectorType srcVecTy,
+xegpu::setupStoreScatterAnchorLayout(LayoutKind layoutKind, VectorType srcVecTy,
int chunkSize, const uArch::uArch *uArch) {
xegpu::DistributeLayoutAttr requiredLayout;
@@ -1013,11 +1056,10 @@ xegpu::storeScatterSetupAnchorLayout(LayoutKind layoutKind, VectorType srcVecTy,
if (srcVecTy.getRank() == 1) {
instData[0] = subgroupSize;
} else {
- assert((srcVecTy.getRank() > 2) && "StoreScatterOp can access 2D tensor "
- "tile at maximum at subgroup level.");
+ assert((srcVecTy.getRank() <= 2) && "StoreScatterOp can access 2D tensor "
+ "tile at maximum at subgroup level.");
if (chunkSize == 1) {
- instData[0] =
- std::min(srcShape[0], static_cast<int64_t>(spirVectorSize));
+ instData[0] = 1;
instData[1] = subgroupSize;
} else {
instData[0] = subgroupSize;
@@ -1028,8 +1070,18 @@ xegpu::storeScatterSetupAnchorLayout(LayoutKind layoutKind, VectorType srcVecTy,
}
break;
case xegpu::LayoutKind::Lane:
- requiredLayout =
- getDefaultLaneLayoutAttr(context, srcVecTy.getRank(), uArch);
+ if (chunkSize == 1)
+ requiredLayout =
+ getDefaultLaneLayoutAttr(context, srcVecTy.getRank(), uArch);
+ else {
+ assert((srcVecTy.getRank() > 2) && "StoreScatterOp can access 2D tensor "
+ "tile at maximum at subgroup level.");
+ assert(srcShape[1] <= static_cast<int64_t>(
+ uArchInstruction->getMaxLaneLoadStoreSize()) &&
+ "StoreScatterOp lane size exceeds max lane load/store size.");
+ requiredLayout = xegpu::LayoutAttr::get(
+ context, {subgroupSize, 1}, {1, static_cast<int>(srcShape[1])});
+ }
break;
default:
llvm_unreachable("unsupported layout kind");
>From a717336e6c639abf8b3636f300e119c9cefc4ea2 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Tue, 27 Jan 2026 22:48:06 +0000
Subject: [PATCH 14/35] setting correct inst layout for 1dreduction example
---
.../Dialect/XeGPU/Utils/XeGPULayoutUtils.h | 34 ++--
mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 3 +-
.../XeGPU/Transforms/XeGPUPropagateLayout.cpp | 126 +++++++++++++--
.../Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp | 145 +++++++++---------
4 files changed, 200 insertions(+), 108 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h
index bba2b1e06129a..95122104f22e5 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h
@@ -74,8 +74,8 @@ DistributeLayoutAttr inferBroadcastSourceLayout(DistributeLayoutAttr resLayout,
/// Infers the source layout attribute for a reduction operation given the
/// result layout attribute and reduced dims.
DistributeLayoutAttr
-inferReductionSourceLayout(DistributeLayoutAttr resLayout,
- SmallVector<int64_t> reduceDims);
+inferMultiReductionSourceLayout(DistributeLayoutAttr resLayout,
+ SmallVector<int64_t> reduceDims);
/// Infers the source layout attribute for a bitcast operation given the
/// result layout attribute, result element type bitwidth, and source element
@@ -90,12 +90,6 @@ DistributeLayoutAttr inferShapeCastSourceLayout(DistributeLayoutAttr resLayout,
ArrayRef<int64_t> resShape,
ArrayRef<int64_t> srcShape);
-/// Infers the source layout attribute for mask operand of scatter IO operation
-/// given the result layout attribute, value shape, and mask shape.
-DistributeLayoutAttr inferScatterIOMaskLayout(DistributeLayoutAttr resLayout,
- ArrayRef<int64_t> valShape,
- ArrayRef<int64_t> maskShape);
-
/// Sets up layout for reduction operations by creating a SliceAttr for the
/// result.
///
@@ -123,22 +117,24 @@ DistributeLayoutAttr setupBitCastResultLayout(
LayoutKind layoutKind, VectorType srcVectorTy, VectorType resVectorTy,
DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch);
-xegpu::DistributeLayoutAttr
-xegpu::setupLoadMatrixAnchorLayout(LayoutKind layoutKind, VectorType vectorTy,
- xegpu::DistributeLayoutAttr consumerLayout,
- const uArch::uArch *uArch);
+DistributeLayoutAttr
+setupLoadMatrixAnchorLayout(LayoutKind layoutKind, VectorType vectorTy,
+ DistributeLayoutAttr consumerLayout,
+ const uArch::uArch *uArch);
DistributeLayoutAttr setupStoreMatrixAnchorLayout(LayoutKind layoutKind,
VectorType vectorTy,
const uArch::uArch *uArch);
-xegpu::DistributeLayoutAttr xegpu::setupLoadGatherAnchorLayout(
- LayoutKind layoutKind, VectorType vectorTy, int chunkSize,
- DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch);
-
-xegpu::DistributeLayoutAttr
-xegpu::setupStoreScatterAnchorLayout(LayoutKind layoutKind, VectorType vectorTy,
- int chunkSize, const uArch::uArch *uArch);
+DistributeLayoutAttr
+setupLoadGatherAnchorLayout(LayoutKind layoutKind, VectorType vectorTy,
+ int chunkSize, DistributeLayoutAttr consumerLayout,
+ const uArch::uArch *uArch);
+
+DistributeLayoutAttr setupStoreScatterAnchorLayout(LayoutKind layoutKind,
+ VectorType vectorTy,
+ int chunkSize,
+ const uArch::uArch *uArch);
} // namespace xegpu
} // namespace mlir
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 78d1ad50dd1dd..c7b73dd59767e 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -540,7 +540,8 @@ LayoutAttr::collapseDims(SmallVector<SmallVector<int64_t>> dimGroups) const {
collapsedInst *= instData[dimIdx];
collapsedLaneL *= laneLayout[dimIdx];
collapsedLaneD *= laneData[dimIdx];
- collapsedOrderValue = order[dimIdx]; // take the last one's order
+ if (!orderVec.empty())
+ collapsedOrderValue = orderVec[dimIdx]; // take the last one's order
}
collapsedSgLayout.push_back(collapsedSg);
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index e6fa8bfd12474..44127d27de522 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -695,7 +695,7 @@ void LayoutInfoPropagation::visitVectorMultiReductionOp(
// propagated from consumer op, the conflict is resolved in later phase by
// converting the required result layout to the consumer layout
auto uArch = getUArch(xegpu::getChipStr(reduction).value_or(""));
- auto requiredResLayoutAttr = xegpu::reductionSetupResultLayout(
+ auto requiredResLayoutAttr = xegpu::setupMultiReductionResultLayout(
layoutKind, sourceTy, consumerLayoutAttr, reductionDims, uArch);
LLVM_DEBUG(DBGS() << "visitVectorMultiReductionOp: requiredResLayoutAttr = "
<< requiredResLayoutAttr << "\n");
@@ -713,8 +713,8 @@ void LayoutInfoPropagation::visitVectorMultiReductionOp(
results[0]->getValue().print(llvm::dbgs()); llvm::dbgs() << "\n");
// derive the source layout from the dominant layout and reduction dims
- auto srcLayoutAttr =
- xegpu::inferReductionSourceLayout(requiredResLayoutAttr, reductionDims);
+ auto srcLayoutAttr = xegpu::inferMultiReductionSourceLayout(
+ requiredResLayoutAttr, reductionDims);
LLVM_DEBUG(DBGS() << "visitVectorMultiReductionOp: srcLayoutAttr = "
<< srcLayoutAttr << "\n");
@@ -1129,7 +1129,7 @@ void LayoutInfoPropagation::visitVectorBitcastOp(
auto consumerLayoutAttr =
dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
auto uArch = getUArch(xegpu::getChipStr(bitcast).value_or(""));
- auto requiredResLayoutAttr = bitCastSetupResultLayout(
+ auto requiredResLayoutAttr = setupBitCastResultLayout(
layoutKind, srcVecType, resVecType, consumerLayoutAttr, uArch);
xegpu::setTemporaryLayout(bitcast->getResult(0), requiredResLayoutAttr);
@@ -1144,6 +1144,40 @@ void LayoutInfoPropagation::visitVectorBitcastOp(
propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
}
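+// TODO: Disabled placeholder for visitVectorInsertStridedSliceOp. The body
+// below still takes a vector::BitCastOp and its doc comment still describes
+// bitcast, so it must be adapted before being enabled.
+// /// For vector::BitCastOp, the lane_data of the source layout is changed
+// /// based on the bit width of the source and result types.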
+// void LayoutInfoPropagation::visitVectorInsertStridedSliceOp(
+// vector::BitCastOp bitcast, ArrayRef<LayoutInfoLattice *> operands,
+// ArrayRef<const LayoutInfoLattice *> results) {
+// // Need the layout of bitcast result to propagate to the operands.
+// LayoutInfo resLayoutInfo = results[0]->getValue();
+// if (!resLayoutInfo.isAssigned())
+// return;
+
+// auto srcVecType = bitcast.getSourceVectorType();
+// auto resVecType = bitcast.getResultVectorType();
+
+// auto consumerLayoutAttr =
+// dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
+// auto uArch = getUArch(xegpu::getChipStr(bitcast).value_or(""));
+// auto requiredResLayoutAttr = setupInsertStridedSliceResultLayout(
+// layoutKind, srcVecType, resVecType, consumerLayoutAttr, uArch);
+
+// xegpu::setTemporaryLayout(bitcast->getResult(0), requiredResLayoutAttr);
+
+// int inElemTyBitWidth = srcVecType.getElementType().getIntOrFloatBitWidth();
+// int outElemTyBitWidth =
+// resVecType.getElementType().getIntOrFloatBitWidth();
+
+// // derive the source layout from the dominant layout and reduction dims
+// auto srcLayoutAttr = xegpu::inferBitCastSourceLayout(
+// requiredResLayoutAttr, outElemTyBitWidth, inElemTyBitWidth);
+
+// propagateIfChanged(operands[0],
+// operands[0]->meet(LayoutInfo(srcLayoutAttr)));
+// }
+
/// Propagate the layout of the result to the tensor descriptor, mask and offset
/// operands in LoadGatherOp.
void LayoutInfoPropagation::visitLoadGatherOp(
@@ -1155,9 +1189,14 @@ void LayoutInfoPropagation::visitLoadGatherOp(
auto uArch = getUArch(getChipStr(load).value_or(""));
auto subgroupSize = uArch->getSubgroupSize();
VectorType resVecTy = load.getValueType();
- VectorType maskTy = llvm::dyn_cast<VectorType>(load.getMask().getType());
int chunkSize = load.getChunkSize().value_or(1);
+ LayoutInfo resLayoutInfo = results[0]->getValue();
+ if (!resLayoutInfo.isAssigned())
+ return;
+ auto consumerLayoutAttr =
+ dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
+
if (hasParamsOfLayoutKind(anchorLayoutAttr)) {
requiredAnchorLayoutAttr = anchorLayoutAttr;
} else {
@@ -1166,11 +1205,11 @@ void LayoutInfoPropagation::visitLoadGatherOp(
return;
}
requiredAnchorLayoutAttr = xegpu::setupLoadGatherAnchorLayout(
- layoutKind, resVecTy, chunkSize, uArch);
+ layoutKind, resVecTy, chunkSize, consumerLayoutAttr, uArch);
load.setLayoutAttr(requiredAnchorLayoutAttr);
}
- auto maskLayoutAttr = anchorLayoutAttr;
+ auto maskLayoutAttr = requiredAnchorLayoutAttr;
// Special handling mask layout for chunked ops: Enforce the default xegpu 1D
// layout for mask.
if (chunkSize > 1) {
@@ -1186,14 +1225,15 @@ void LayoutInfoPropagation::visitLoadGatherOp(
}
LayoutInfo maskLayoutInfo = LayoutInfo(maskLayoutAttr);
+ auto loadLayoutInfo = LayoutInfo(requiredAnchorLayoutAttr);
// Propagate the new layout to the tensor descriptor operand.
if (isa<xegpu::TensorDescType>(load.getSourceType()))
- propagateIfChanged(operands[0], operands[0]->meet(loadLayout));
+ propagateIfChanged(operands[0], operands[0]->meet(loadLayoutInfo));
// Propagate the new layout to the mask and optional offset operand.
- propagateIfChanged(operands[1], operands[1]->meet(maskLayout));
+ propagateIfChanged(operands[1], operands[1]->meet(maskLayoutInfo));
if (load.getOffsets())
- propagateIfChanged(operands[2], operands[2]->meet(maskLayout));
+ propagateIfChanged(operands[2], operands[2]->meet(maskLayoutInfo));
}
/// Propagate the layout of the descriptor to the vector offset operand in
@@ -1218,6 +1258,8 @@ void LayoutInfoPropagation::visitStoreScatterOp(
xegpu::StoreScatterOp storeScatter, ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results) {
+ LLVM_DEBUG(DBGS() << "visitStoreScatterOp: Processing store scatter op\n");
+
xegpu::DistributeLayoutAttr requiredAnchorLayoutAttr;
xegpu::DistributeLayoutAttr anchorLayoutAttr = storeScatter.getLayoutAttr();
auto uArch = getUArch(getChipStr(storeScatter).value_or(""));
@@ -1227,9 +1269,19 @@ void LayoutInfoPropagation::visitStoreScatterOp(
llvm::dyn_cast<VectorType>(storeScatter.getMask().getType());
int chunkSize = storeScatter.getChunkSize().value_or(1);
+ LLVM_DEBUG(DBGS() << "visitStoreScatterOp: anchorLayoutAttr = "
+ << anchorLayoutAttr << "\n");
+ LLVM_DEBUG(DBGS() << "visitStoreScatterOp: subgroupSize = " << subgroupSize
+ << "\n");
+ LLVM_DEBUG(DBGS() << "visitStoreScatterOp: chunkSize = " << chunkSize
+ << "\n");
+ LLVM_DEBUG(DBGS() << "visitStoreScatterOp: srcVecTy = " << srcVecTy << "\n");
+
if (hasParamsOfLayoutKind(anchorLayoutAttr)) {
+ LLVM_DEBUG(DBGS() << "visitStoreScatterOp: Using existing anchor layout\n");
requiredAnchorLayoutAttr = anchorLayoutAttr;
} else {
+ LLVM_DEBUG(DBGS() << "visitStoreScatterOp: Setting up new anchor layout\n");
if (!srcVecTy) {
storeScatter.emitWarning("Not propagating, non-vector payload supplied.");
return;
@@ -1239,11 +1291,16 @@ void LayoutInfoPropagation::visitStoreScatterOp(
storeScatter.setLayoutAttr(requiredAnchorLayoutAttr);
}
+ LLVM_DEBUG(DBGS() << "visitStoreScatterOp: requiredAnchorLayoutAttr = "
+ << requiredAnchorLayoutAttr << "\n");
+
LayoutInfo srcLayoutInfo = LayoutInfo(requiredAnchorLayoutAttr);
- auto maskLayoutAttr = anchorLayoutAttr;
+ auto maskLayoutAttr = requiredAnchorLayoutAttr;
// Special handling mask layout for chunked ops: Enforce the default xegpu 1D
// layout for mask.
if (chunkSize > 1) {
+ LLVM_DEBUG(DBGS() << "visitStoreScatterOp: Setting mask layout for chunked "
+ "operation\n");
if (layoutKind == xegpu::LayoutKind::InstData)
maskLayoutAttr =
xegpu::LayoutAttr::get(storeScatter->getContext(), {subgroupSize});
@@ -1256,41 +1313,72 @@ void LayoutInfoPropagation::visitStoreScatterOp(
}
LayoutInfo maskLayoutInfo = LayoutInfo(maskLayoutAttr);
+ LLVM_DEBUG(DBGS() << "visitStoreScatterOp: maskLayoutAttr = "
+ << maskLayoutAttr << "\n");
// Propagate the payload operand layout
+ LLVM_DEBUG(DBGS() << "visitStoreScatterOp: Propagating payload layout\n");
propagateIfChanged(operands[0], operands[0]->meet(srcLayoutInfo));
// Propagate the destination (if tdesc) operand layout
- if (isa<xegpu::TensorDescType>(storeScatter.getDestType()))
+ if (isa<xegpu::TensorDescType>(storeScatter.getDestType())) {
+ LLVM_DEBUG(
+ DBGS() << "visitStoreScatterOp: Propagating destination layout\n");
propagateIfChanged(operands[1], operands[1]->meet(srcLayoutInfo));
+ }
// Propagate the new layout to the mask and optional offset operand.
+ LLVM_DEBUG(DBGS() << "visitStoreScatterOp: Propagating mask layout\n");
propagateIfChanged(operands[2], operands[2]->meet(maskLayoutInfo));
- if (storeScatter.getOffsets())
+ if (storeScatter.getOffsets()) {
+ LLVM_DEBUG(DBGS() << "visitStoreScatterOp: Propagating offset layout\n");
propagateIfChanged(operands[3], operands[3]->meet(maskLayoutInfo));
+ }
+ LLVM_DEBUG(DBGS() << "visitStoreScatterOp: Done\n");
}
void LayoutInfoPropagation::visitLoadMatrixOp(
xegpu::LoadMatrixOp loadMatrixOp, ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results) {
+ LLVM_DEBUG(DBGS() << "visitLoadMatrixOp: Processing load matrix op\n");
+
LayoutInfo resLayoutInfo = results[0]->getValue();
- if (!resLayoutInfo.isAssigned())
+ if (!resLayoutInfo.isAssigned()) {
+ LLVM_DEBUG(DBGS() << "visitLoadMatrixOp: Result layout not assigned\n");
return;
+ }
+
+ LLVM_DEBUG(DBGS() << "visitLoadMatrixOp: resLayoutInfo = ";
+ resLayoutInfo.print(llvm::dbgs()); llvm::dbgs() << "\n");
+
auto consumerLayoutAttr =
dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
+ LLVM_DEBUG(DBGS() << "visitLoadMatrixOp: consumerLayoutAttr = "
+ << consumerLayoutAttr << "\n");
+
xegpu::DistributeLayoutAttr anchorLayout = loadMatrixOp.getLayoutAttr();
+ LLVM_DEBUG(DBGS() << "visitLoadMatrixOp: anchorLayout = " << anchorLayout
+ << "\n");
+
// only need to set anchor layout, no need to porpagate to memdesc and
// offset
if (!hasParamsOfLayoutKind(anchorLayout)) {
+ LLVM_DEBUG(DBGS() << "visitLoadMatrixOp: Setting up new anchor layout\n");
VectorType resVecTy =
llvm::cast<VectorType>(loadMatrixOp.getRes().getType());
assert(resVecTy.getRank() == 2 && "Expecting 2D vector for store matrix.");
+ LLVM_DEBUG(DBGS() << "visitLoadMatrixOp: resVecTy = " << resVecTy << "\n");
auto uArch = getUArch(getChipStr(loadMatrixOp).value_or(""));
- auto requiredAnchorLayoutAttr = xegpu::loadMatrixSetupAnchorLayout(
+ auto requiredAnchorLayoutAttr = xegpu::setupLoadMatrixAnchorLayout(
layoutKind, resVecTy, consumerLayoutAttr, uArch);
+ LLVM_DEBUG(DBGS() << "visitLoadMatrixOp: requiredAnchorLayoutAttr = "
+ << requiredAnchorLayoutAttr << "\n");
loadMatrixOp.setLayoutAttr(requiredAnchorLayoutAttr);
+ } else {
+ LLVM_DEBUG(DBGS() << "visitLoadMatrixOp: Using existing anchor layout\n");
}
+ LLVM_DEBUG(DBGS() << "visitLoadMatrixOp: Done\n");
}
// Store matrix is a flavor of scattered store for 2D shapes.
@@ -1308,7 +1396,7 @@ void LayoutInfoPropagation::visitStoreMatrixOp(
assert(srcVecTy.getRank() == 2 && "Expecting 2D vector for store matrix.");
auto uArch = getUArch(getChipStr(storeMatrix).value_or(""));
auto requiredAnchorLayoutAttr =
- xegpu::storeMatrixSetupAnchorLayout(layoutKind, srcVecTy, uArch);
+ xegpu::setupStoreMatrixAnchorLayout(layoutKind, srcVecTy, uArch);
storeMatrix.setLayoutAttr(requiredAnchorLayoutAttr);
layout = LayoutInfo(requiredAnchorLayoutAttr);
}
@@ -1591,6 +1679,12 @@ void XeGPUPropagateLayoutPass::runOnOperation() {
if (!layout.isAssigned())
return {};
if (auto opResult = dyn_cast<OpResult>(val)) {
+
+ Operation *defOp = opResult.getDefiningOp();
+ if (auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(defOp)) {
+ return anchorOp.getAnchorLayout();
+ }
+
xegpu::DistributeLayoutAttr requiredResLayoutAttr =
xegpu::getTemporaryLayout(opResult);
if (requiredResLayoutAttr != nullptr)
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp
index 8d7d1b96b363e..50f2caaaba1ea 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp
@@ -362,8 +362,8 @@ xegpu::inferBroadcastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
/// Infers the source layout attribute for a reduction operation given the
/// result layout attribute and reduced dims.
xegpu::DistributeLayoutAttr
-xegpu::inferReductionSourceLayout(xegpu::DistributeLayoutAttr resLayout,
- SmallVector<int64_t> reduceDims) {
+xegpu::inferMultiReductionSourceLayout(xegpu::DistributeLayoutAttr resLayout,
+ SmallVector<int64_t> reduceDims) {
// assert the resLayout must be slice layout
assert(isa<xegpu::SliceAttr>(resLayout) &&
@@ -532,61 +532,6 @@ xegpu::inferShapeCastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
return nullptr;
}
-xegpu::DistributeLayoutAttr
-inferScatterIOMaskLayout(xegpu::DistributeLayoutAttr resLayout,
- ArrayRef<int64_t> valShape,
- ArrayRef<int64_t> maskShape) {
-
- xegpu::LayoutAttr resPlainLayout = dyn_cast<xegpu::LayoutAttr>(resLayout);
- assert(resPlainLayout &&
- "Expecting plain layout for scatter IO mask inference.");
- xegpu::LayoutAttr maskLayout = resPlainLayout;
-
- if (maskShape.size() < valShape.size()) {
- int valShapeSize = valShape.size();
- int maskShapeSize = maskShape.size();
- assert((maskShapeSize == (valShapeSize - 1)) &&
- "Expecting mask vector only 1 dimension less than value vector.");
-
- maskLayout = resPlainLayout.setUnitDimData(valShapeSize - 1);
- maskLayout = maskLayout.collapseDim(valShapeSize - 1);
- // SmallVector<int> sgLayout(valShapeSize);
- // sgLayout = resPlainLayout.getSgLayout();
- // SmallVector<int> sgData(valShapeSize) = resPlainLayout.getSgData();
- // SmallVector<int> instData(valShapeSize) = resPlainLayout.getSgLayout();
- // SmallVector<int> laneLayout(valShapeSize) =
- // resPlainLayout.getLaneLayout(); SmallVector<int> laneData(valShapeSize) =
- // resPlainLayout.getSgLayout(); SmallVector<int> order(valShapeSize) =
- // resPlainLayout.getOrder();
-
- // // drop the innermost dimension for mask layout
- // sgLayout.pop_back();
- // sgData.pop_back();
- // instData.pop_back();
- // laneLayout.pop_back();
- // laneData.pop_back();
- // order.pop_back();
-
- // SmallVector<int64_t> remappedOrder(maskShapeSize, -1);
- // for (int64_t i = 0; i < maskShapeSize; ++i) {
- // int64_t originalOrderValue = order[i];
- // // count how many values in collapsed Order are less than
- // // originalOrderValue
- // int64_t count = 0;
- // for (int64_t j = 0; j < maskShapeSize; ++j) {
- // if (order[j] < originalOrderValue)
- // count++;
- // }
- // remappedOrder[i] = count;
- // }
-
- // maskLayout =
- // xegpu::LayoutAttr::get(resLayout.getContext(), sgLayout, sgData,
- // instData, laneLayout, laneData, order);
- }
- return maskLayout;
-}
-
/// Sets up layout for reduction operations by creating a SliceAttr for the
/// result.
///
@@ -968,15 +913,28 @@ xegpu::DistributeLayoutAttr xegpu::setupLoadGatherAnchorLayout(
LayoutKind layoutKind, VectorType resVecTy, int chunkSize,
DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch) {
+ llvm::dbgs() << "setupLoadGatherAnchorLayout: layoutKind="
+ << static_cast<int>(layoutKind) << ", chunkSize=" << chunkSize
+ << "\n";
+
xegpu::DistributeLayoutAttr requiredLayout;
const int subgroupSize = uArch->getSubgroupSize();
const int spirVectorSize = 16; // vector size from SPRIV vector restriction
auto resShape = resVecTy.getShape();
int resShapeSize = resShape.size();
- SmallVector<int64_t> instData(subgroupSize);
+ SmallVector<int> instData(resShapeSize);
auto context = resVecTy.getContext();
+ llvm::dbgs() << "setupLoadGatherAnchorLayout: resVecTy.getRank()="
+ << resVecTy.getRank() << ", resShape=[";
+ for (size_t i = 0; i < resShape.size(); ++i) {
+ if (i > 0)
+ llvm::dbgs() << ", ";
+ llvm::dbgs() << resShape[i];
+ }
+ llvm::dbgs() << "]\n";
+
const auto *uArchInstruction =
dyn_cast<xegpu::uArch::StoreScatterInstruction>(
uArch->getInstruction(xegpu::uArch::InstructionKind::StoreScatter));
@@ -985,32 +943,48 @@ xegpu::DistributeLayoutAttr xegpu::setupLoadGatherAnchorLayout(
consumerLayout.getEffectiveInstDataAsInt();
SmallVector<int32_t> instData32;
+ llvm::dbgs() << "setupLoadGatherAnchorLayout: consumerInstData=[";
+ for (size_t i = 0; i < consumerInstData.size(); ++i) {
+ if (i > 0)
+ llvm::dbgs() << ", ";
+ llvm::dbgs() << consumerInstData[i];
+ }
+ llvm::dbgs() << "]\n";
+
switch (layoutKind) {
case xegpu::LayoutKind::Subgroup:
+ llvm::dbgs() << "setupLoadGatherAnchorLayout: LayoutKind::Subgroup\n";
requiredLayout = consumerLayout;
break;
case xegpu::LayoutKind::InstData:
+ llvm::dbgs() << "setupLoadGatherAnchorLayout: LayoutKind::InstData\n";
if (resVecTy.getRank() == 1) {
instData[0] = subgroupSize;
+ llvm::dbgs() << "setupLoadGatherAnchorLayout: 1D case, instData[0]="
+ << instData[0] << "\n";
} else {
- assert((resVecTy.getRank() > 2) && "StoreScatterOp can access 2D tensor "
- "tile at maximum at subgroup level.");
+ assert((resVecTy.getRank() == 2) && "StoreScatterOp can access 2D tensor "
+ "tile at maximum at subgroup level.");
if (chunkSize == 1) {
instData[0] = 1;
instData[1] = subgroupSize;
+ llvm::dbgs() << "setupLoadGatherAnchorLayout: chunkSize==1, instData=["
+ << instData[0] << ", " << instData[1] << "]\n";
} else {
instData[0] = subgroupSize;
- instData[1] = std::min(
- resShape[1],
- static_cast<int64_t>(uArchInstruction->getMaxLaneLoadStoreSize()));
- instData[1] = std::min(instData[1], consumerInstData[1]);
+ instData[1] = std::min(static_cast<int>(resShape[1]),
+ uArchInstruction->getMaxLaneLoadStoreSize());
+ instData[1] =
+ std::min(instData[1], static_cast<int>(consumerInstData[1]));
+ llvm::dbgs() << "setupLoadGatherAnchorLayout: chunkSize>1, instData=["
+ << instData[0] << ", " << instData[1] << "]\n";
}
}
- instData32 = SmallVector<int32_t>(instData.begin(), instData.end());
requiredLayout = xegpu::LayoutAttr::get(
- context, DenseI32ArrayAttr::get(context, instData32));
+ context, DenseI32ArrayAttr::get(context, instData));
break;
case xegpu::LayoutKind::Lane:
+ llvm::dbgs() << "setupLoadGatherAnchorLayout: LayoutKind::Lane\n";
if (chunkSize == 1)
requiredLayout =
getDefaultLaneLayoutAttr(context, resVecTy.getRank(), uArch);
@@ -1027,6 +1001,7 @@ xegpu::DistributeLayoutAttr xegpu::setupLoadGatherAnchorLayout(
default:
llvm_unreachable("unsupported layout kind");
}
+ llvm::dbgs() << "setupLoadGatherAnchorLayout: returning requiredLayout\n";
return requiredLayout;
}
@@ -1034,13 +1009,26 @@ xegpu::DistributeLayoutAttr
xegpu::setupStoreScatterAnchorLayout(LayoutKind layoutKind, VectorType srcVecTy,
int chunkSize, const uArch::uArch *uArch) {
+ llvm::dbgs() << "setupStoreScatterAnchorLayout: layoutKind="
+ << static_cast<int>(layoutKind) << ", chunkSize=" << chunkSize
+ << "\n";
+
xegpu::DistributeLayoutAttr requiredLayout;
const int subgroupSize = uArch->getSubgroupSize();
const int spirVectorSize = 16; // vector size from SPRIV vector restriction
auto srcShape = srcVecTy.getShape();
int srcShapeSize = srcShape.size();
- SmallVector<int64_t> instData(subgroupSize);
+ SmallVector<int> instData(srcShapeSize);
+
+ llvm::dbgs() << "setupStoreScatterAnchorLayout: srcVecTy.getRank()="
+ << srcVecTy.getRank() << ", srcShape=[";
+ for (size_t i = 0; i < srcShape.size(); ++i) {
+ if (i > 0)
+ llvm::dbgs() << ", ";
+ llvm::dbgs() << srcShape[i];
+ }
+ llvm::dbgs() << "]\n";
const auto *uArchInstruction =
dyn_cast<xegpu::uArch::StoreScatterInstruction>(
@@ -1049,27 +1037,39 @@ xegpu::setupStoreScatterAnchorLayout(LayoutKind layoutKind, VectorType srcVecTy,
switch (layoutKind) {
case xegpu::LayoutKind::Subgroup:
- assert(false &&
- "subgroup layout assignment not supported yet for loadMatrix.");
+ llvm::dbgs() << "setupStoreScatterAnchorLayout: LayoutKind::Subgroup\n";
+ assert(
+ false &&
+ "subgroup layout assignment not supported yet for store scatter op.");
break;
case xegpu::LayoutKind::InstData:
+ llvm::dbgs() << "setupStoreScatterAnchorLayout: LayoutKind::InstData\n";
if (srcVecTy.getRank() == 1) {
instData[0] = subgroupSize;
+ llvm::dbgs() << "setupStoreScatterAnchorLayout: 1D case, instData[0]="
+ << instData[0] << "\n";
} else {
assert((srcVecTy.getRank() <= 2) && "StoreScatterOp can access 2D tensor "
"tile at maximum at subgroup level.");
if (chunkSize == 1) {
instData[0] = 1;
instData[1] = subgroupSize;
+ llvm::dbgs()
+ << "setupStoreScatterAnchorLayout: chunkSize==1, instData=["
+ << instData[0] << ", " << instData[1] << "]\n";
} else {
instData[0] = subgroupSize;
- instData[1] = std::min(
- srcShape[1],
- static_cast<int64_t>(uArchInstruction->getMaxLaneLoadStoreSize()));
+ instData[1] = std::min(static_cast<int>(srcShape[1]),
+ uArchInstruction->getMaxLaneLoadStoreSize());
+ llvm::dbgs() << "setupStoreScatterAnchorLayout: chunkSize>1, instData=["
+ << instData[0] << ", " << instData[1] << "]\n";
}
}
+ requiredLayout = xegpu::LayoutAttr::get(
+ context, DenseI32ArrayAttr::get(context, instData));
break;
case xegpu::LayoutKind::Lane:
+ llvm::dbgs() << "setupStoreScatterAnchorLayout: LayoutKind::Lane\n";
if (chunkSize == 1)
requiredLayout =
getDefaultLaneLayoutAttr(context, srcVecTy.getRank(), uArch);
@@ -1086,5 +1086,6 @@ xegpu::setupStoreScatterAnchorLayout(LayoutKind layoutKind, VectorType srcVecTy,
default:
llvm_unreachable("unsupported layout kind");
}
+ llvm::dbgs() << "setupStoreScatterAnchorLayout: returning requiredLayout\n";
return requiredLayout;
}
\ No newline at end of file
>From c7f1c8603d45d68e26b82b219a9f38aa5c60f06a Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Wed, 28 Jan 2026 20:01:57 +0000
Subject: [PATCH 15/35] adding insert_strided_slice and improve shapecast rules
---
.../Dialect/XeGPU/Utils/XeGPULayoutUtils.h | 18 +-
.../XeGPU/Transforms/XeGPUBlocking.cpp | 4 +-
.../XeGPU/Transforms/XeGPUPropagateLayout.cpp | 52 +++++-
.../Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp | 166 ++++++++++++++----
4 files changed, 194 insertions(+), 46 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h
index 95122104f22e5..110496bb34fb3 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h
@@ -90,14 +90,19 @@ DistributeLayoutAttr inferShapeCastSourceLayout(DistributeLayoutAttr resLayout,
ArrayRef<int64_t> resShape,
ArrayRef<int64_t> srcShape);
+DistributeLayoutAttr
+inferInsertStridedSliceSourceLayout(DistributeLayoutAttr resLayout,
+ ArrayRef<int64_t> resShape,
+ ArrayRef<int64_t> srcShape);
+
/// Sets up layout for reduction operations by creating a SliceAttr for the
/// result.
///
-/// This function first attempts to construct a source layout that, when sliced
-/// along reduction dimensions, produces a result layout compatible with the
-/// consumer's preferred layout. This minimizes data redistribution overhead.
-/// The SliceAttr for the result is then created based on the derived source
-/// layout and the specified reduction dimensions.
+/// This function first attempts to construct a source layout that, when
+/// sliced along reduction dimensions, produces a result layout compatible
+/// with the consumer's preferred layout. This minimizes data redistribution
+/// overhead. The SliceAttr for the result is then created based on the
+/// derived source layout and the specified reduction dimensions.
SliceAttr setupMultiReductionResultLayout(xegpu::LayoutKind layoutKind,
VectorType srcVectorTy,
DistributeLayoutAttr consumerLayout,
@@ -117,6 +122,9 @@ DistributeLayoutAttr setupBitCastResultLayout(
LayoutKind layoutKind, VectorType srcVectorTy, VectorType resVectorTy,
DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch);
+DistributeLayoutAttr setupInsertStridedSliceResultLayout(
+ LayoutKind layoutKind, VectorType resVectorTy, const uArch::uArch *uArch);
+
DistributeLayoutAttr
setupLoadMatrixAnchorLayout(LayoutKind layoutKind, VectorType vectorTy,
DistributeLayoutAttr consumerLayout,
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
index 7af7622375ef4..29cd74ad461dd 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
@@ -162,8 +162,8 @@ XeGPUBlockingPass::getTileShape(const T &operandOrResult) const {
ownerOp &&
(isa<xegpu::CreateNdDescOp, xegpu::DpasOp, xegpu::ConvertLayoutOp,
xegpu::LoadMatrixOp, xegpu::StoreMatrixOp, xegpu::AtomicRMWOp,
- xegpu::LoadNdOp, xegpu::StoreNdOp, xegpu::PrefetchNdOp,
- vector::TransposeOp, vector::ShapeCastOp,
+ xegpu::LoadGatherOp, xegpu::LoadNdOp, xegpu::StoreNdOp,
+ xegpu::PrefetchNdOp, vector::TransposeOp, vector::ShapeCastOp,
vector::MultiDimReductionOp, vector::BroadcastOp>(ownerOp));
if (!skipLeadingUnitDimRemoval) {
auto it = llvm::find_if(instData, [](auto val) { return val != 1; });
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 44127d27de522..5cb0e259c503b 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -418,6 +418,10 @@ class LayoutInfoPropagation
void visitShapeCastOp(vector::ShapeCastOp shapeCast,
ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results);
+ void
+ visitInsertStridedSliceOp(vector::InsertStridedSliceOp insertStridedSlice,
+ ArrayRef<LayoutInfoLattice *> operands,
+ ArrayRef<const LayoutInfoLattice *> results);
void visitLoadMatrixOp(xegpu::LoadMatrixOp load,
ArrayRef<LayoutInfoLattice *> operands,
@@ -508,6 +512,9 @@ LogicalResult LayoutInfoPropagation::visitOperation(
.Case<vector::ShapeCastOp>([&](auto shapeCastOp) {
visitShapeCastOp(shapeCastOp, operands, results);
})
+ .Case<vector::InsertStridedSliceOp>([&](auto insertStridedSliceOp) {
+ visitInsertStridedSliceOp(insertStridedSliceOp, operands, results);
+ })
.Case<xegpu::LoadMatrixOp>([&](auto loadMatrixOp) {
visitLoadMatrixOp(loadMatrixOp, operands, results);
})
@@ -795,7 +802,6 @@ void LayoutInfoPropagation::visitUpdateNdOffsetOp(
void LayoutInfoPropagation::visitDpasOp(
xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results) {
-
LayoutInfo dpasALayout;
LayoutInfo dpasBLayout;
LayoutInfo dpasCDLayout;
@@ -992,7 +998,6 @@ void LayoutInfoPropagation::visitDpasOp(
void LayoutInfoPropagation::visitStoreNdOp(
xegpu::StoreNdOp store, ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results) {
-
LayoutInfo storeLayout;
xegpu::DistributeLayoutAttr anchorLayout = store.getLayoutAttr();
if (hasParamsOfLayoutKind(anchorLayout)) {
@@ -1073,7 +1078,6 @@ void LayoutInfoPropagation::visitStoreNdOp(
void LayoutInfoPropagation::visitLoadNdOp(
xegpu::LoadNdOp load, ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results) {
-
LayoutInfo loadLayout;
xegpu::DistributeLayoutAttr anchorLayout = load.getLayoutAttr();
if (hasParamsOfLayoutKind(anchorLayout)) {
@@ -1144,6 +1148,37 @@ void LayoutInfoPropagation::visitVectorBitcastOp(
propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
}
+void LayoutInfoPropagation::visitInsertStridedSliceOp(
+ vector::InsertStridedSliceOp insertStridedSlice,
+ ArrayRef<LayoutInfoLattice *> operands,
+ ArrayRef<const LayoutInfoLattice *> results) {
+ // The layout of the result must be present.
+ LayoutInfo resLayoutInfo = results[0]->getValue();
+ if (!resLayoutInfo.isAssigned())
+ return;
+
+ auto srcVecType = insertStridedSlice.getSourceVectorType();
+ auto resVecType = insertStridedSlice.getDestVectorType();
+
+ auto consumerLayoutAttr =
+ dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
+ auto uArch = getUArch(xegpu::getChipStr(insertStridedSlice).value_or(""));
+
+ auto requiredResLayoutAttr =
+ xegpu::setupInsertStridedSliceResultLayout(layoutKind, resVecType, uArch);
+
+ xegpu::setTemporaryLayout(insertStridedSlice->getResult(0),
+ requiredResLayoutAttr);
+
+ auto srcLayoutAttr = xegpu::inferInsertStridedSliceSourceLayout(
+ requiredResLayoutAttr, resVecType.getShape(), srcVecType.getShape());
+
+ propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
+ propagateIfChanged(operands[1],
+ operands[1]->meet(LayoutInfo(requiredResLayoutAttr)));
+ return;
+}
+
// /// For vector::BitCastOp, the lane_data of the source layout is changed
// based
// /// on the bit width of the source and result types.
@@ -1183,7 +1218,6 @@ void LayoutInfoPropagation::visitVectorBitcastOp(
void LayoutInfoPropagation::visitLoadGatherOp(
xegpu::LoadGatherOp load, ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results) {
-
xegpu::DistributeLayoutAttr requiredAnchorLayoutAttr;
xegpu::DistributeLayoutAttr anchorLayoutAttr = load.getLayoutAttr();
auto uArch = getUArch(getChipStr(load).value_or(""));
@@ -1257,7 +1291,6 @@ void LayoutInfoPropagation::visitCreateDescOp(
void LayoutInfoPropagation::visitStoreScatterOp(
xegpu::StoreScatterOp storeScatter, ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results) {
-
LLVM_DEBUG(DBGS() << "visitStoreScatterOp: Processing store scatter op\n");
xegpu::DistributeLayoutAttr requiredAnchorLayoutAttr;
@@ -1338,7 +1371,6 @@ void LayoutInfoPropagation::visitStoreScatterOp(
void LayoutInfoPropagation::visitLoadMatrixOp(
xegpu::LoadMatrixOp loadMatrixOp, ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results) {
-
LLVM_DEBUG(DBGS() << "visitLoadMatrixOp: Processing load matrix op\n");
LayoutInfo resLayoutInfo = results[0]->getValue();
@@ -1385,7 +1417,6 @@ void LayoutInfoPropagation::visitLoadMatrixOp(
void LayoutInfoPropagation::visitStoreMatrixOp(
xegpu::StoreMatrixOp storeMatrix, ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results) {
-
xegpu::DistributeLayoutAttr anchorLayout = storeMatrix.getLayoutAttr();
LayoutInfo layout;
if (hasParamsOfLayoutKind(anchorLayout)) {
@@ -1681,7 +1712,14 @@ void XeGPUPropagateLayoutPass::runOnOperation() {
if (auto opResult = dyn_cast<OpResult>(val)) {
Operation *defOp = opResult.getDefiningOp();
+ LLVM_DEBUG(DBGS() << "Try op: " << *defOp << "\n");
if (auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(defOp)) {
+ // inject debug print here
+ LLVM_DEBUG(DBGS() << "AnchorLayoutInterface found for op: " << *defOp
+ << "\n");
+ // print the anchor layout
+ LLVM_DEBUG(DBGS() << "Anchor layout: " << anchorOp.getAnchorLayout()
+ << "\n");
return anchorOp.getAnchorLayout();
}
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp
index 50f2caaaba1ea..9cfcad834a924 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp
@@ -434,9 +434,9 @@ xegpu::inferShapeCastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
// 2. split dim of a high-rank dimension (e.g., 1D to 2D): to setup tensor
// for multi-stage reduction
// 3. combines all dims to a single dim and put in the innermost dim in 2d as
- // [1, combinedData]. only used after workgroup distribution to save
- // multidimension data to 1D slm buffer so no need to handle sg_layout and
- // sg_data.
+ // [1, combinedData] or [combinedData]. Only used after workgroup
+ // distribution, e.g., when a cross-sg reduction saves multi-dimensional data
+ // to a 1D SLM buffer, or for shape casts inserted by CSE/canonicalization.
// Use case 1: Check if shapes only differ by expanding unit dimensions (like
// expand_dims)
@@ -505,27 +505,47 @@ xegpu::inferShapeCastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
auto checkCombineToInnerMostDim = [&](ArrayRef<int64_t> src,
ArrayRef<int64_t> dst) -> bool {
// only one non-unit dim in dst which is the innermost dim
- assert((dst.size() == 2) && "dst shape must be 2D");
+ if ((dst.size() != 2) && (dst.size() != 1))
+ return false;
int64_t srcSize = std::accumulate(src.begin(), src.end(), 1LL,
std::multiplies<int64_t>());
+ if (dst.size() == 1)
+ return (dst[0] == srcSize);
return (dst[0] == 1) && (dst[1] == srcSize);
};
if (checkCombineToInnerMostDim(srcShape, resShape)) {
const int subgroupSize = 16; // assuming 16 lanes per subgroup
- const int vectorSize = 8; // assuming 8 elements per vector lane
int srcShapeSize = srcShape.size();
+ auto context = resLayout.getContext();
+ auto resInstData = resLayout.getEffectiveInstDataAsInt();
+ auto resLaneLayout = resLayout.getEffectiveLaneLayoutAsInt();
+ auto resLaneData = resLayout.getEffectiveLaneDataAsInt();
+ if (resInstData.size() != 0) {
+ if (resInstData.size() == 2)
+ assert(resInstData[0] == 1 &&
+ "only innermost dim can have inst_data for combine-to-1d");
+ // construct source inst_data layout like [1, ..., 1, subgroupSize]
+ SmallVector<int> inferredInstData(srcShapeSize, 1);
+ inferredInstData[srcShapeSize - 1] = resInstData[resInstData.size() - 1];
+ return xegpu::LayoutAttr::get(context, inferredInstData);
+ }
- SmallVector<int64_t> instData(srcShapeSize, 1);
- instData[srcShapeSize - 1] = subgroupSize;
- instData[srcShapeSize - 2] =
- vectorSize; // assuming 8 elements per instruction as starting point
-
- // construct a vector layout with lane_layout = [1, ..., 1, subgroupSize]
- SmallVector<int64_t> laneLayout(srcShapeSize, 1);
- laneLayout[srcShapeSize - 1] = subgroupSize;
- // construct a vector layout with lane_data = [1, ..., 1]
- SmallVector<int64_t> laneData(srcShapeSize, 1);
+ if (resLaneLayout.size() != 0) {
+ if (resInstData.size() == 2)
+ assert(resInstData[0] == 1 &&
+ "only innermost dim can have inst_data for combine-to-1d");
+ // construct source lane_layout like [1, ..., 1, subgroupSize]
+ SmallVector<int> inferredLaneLayout(srcShapeSize, 1);
+ SmallVector<int> inferredLaneData(srcShapeSize, 1);
+ inferredLaneLayout[srcShapeSize - 1] =
+ resLaneLayout[resLaneLayout.size() - 1];
+
+ inferredLaneData[srcShapeSize - 1] =
+ resLaneLayout[resLaneLayout.size() - 1];
+ return xegpu::LayoutAttr::get(context, inferredLaneLayout,
+ inferredLaneData);
+ }
}
// TODO: Complete implementation for other shape cast scenarios
@@ -846,24 +866,58 @@ xegpu::DistributeLayoutAttr xegpu::setupBitCastResultLayout(
return resLayout;
}
+xegpu::DistributeLayoutAttr
+xegpu::setupLoadMatrixAnchorLayout(xegpu::LayoutKind layoutKind,
+ VectorType resVectorTy,
+ xegpu::DistributeLayoutAttr consumerLayout,
+ const xegpu::uArch::uArch *uArch) {
+ xegpu::DistributeLayoutAttr requiredLayout;
+ auto subgroupSize = uArch->getSubgroupSize();
+ SmallVector<int> defaultInstData = {1, subgroupSize};
+ SmallVector<int> defaultLaneLayout = {1, subgroupSize};
+ SmallVector<int> defaultLaneData = {1, 1};
+ auto context = resVectorTy.getContext();
+
+ switch (layoutKind) {
+ case xegpu::LayoutKind::Subgroup:
+ requiredLayout = consumerLayout;
+ break;
+ case xegpu::LayoutKind::InstData:
+ requiredLayout = xegpu::LayoutAttr::get(context, defaultInstData);
+ break;
+ case xegpu::LayoutKind::Lane:
+ requiredLayout =
+ xegpu::LayoutAttr::get(context, defaultLaneData, defaultLaneLayout);
+ break;
+ default:
+ llvm_unreachable("unsupported layout kind");
+ }
+ return requiredLayout;
+}
+
xegpu::DistributeLayoutAttr
xegpu::setupStoreMatrixAnchorLayout(xegpu::LayoutKind layoutKind,
- VectorType vectorTy,
+ VectorType srcVectorTy,
const xegpu::uArch::uArch *uArch) {
xegpu::DistributeLayoutAttr requiredLayout;
- SmallVector<int> instData = {1, uArch->getSubgroupSize()};
+ auto subgroupSize = uArch->getSubgroupSize();
+ SmallVector<int> defaultInstData = {1, subgroupSize};
+ SmallVector<int> defaultLaneLayout = {1, subgroupSize};
+ SmallVector<int> defaultLaneData = {1, 1};
+ auto context = srcVectorTy.getContext();
+
switch (layoutKind) {
case xegpu::LayoutKind::Subgroup:
assert(true &&
"subgroup layout assignment not supported yet for storeMatrix.");
break;
case xegpu::LayoutKind::InstData:
- requiredLayout = xegpu::LayoutAttr::get(vectorTy.getContext(), instData);
+ requiredLayout = xegpu::LayoutAttr::get(context, defaultInstData);
break;
case xegpu::LayoutKind::Lane:
- requiredLayout = xegpu::LayoutAttr::get(
- vectorTy.getContext(), {1, uArch->getSubgroupSize()}, {1, 1});
+ requiredLayout =
+ xegpu::LayoutAttr::get(context, defaultLaneData, defaultLaneLayout);
break;
default:
@@ -873,30 +927,78 @@ xegpu::setupStoreMatrixAnchorLayout(xegpu::LayoutKind layoutKind,
}
xegpu::DistributeLayoutAttr
-xegpu::setupLoadMatrixAnchorLayout(xegpu::LayoutKind layoutKind,
- VectorType vectorTy,
- xegpu::DistributeLayoutAttr consumerLayout,
- const xegpu::uArch::uArch *uArch) {
- xegpu::DistributeLayoutAttr requiredLayout;
- SmallVector<int> instData = {1, uArch->getSubgroupSize()};
+xegpu::setupInsertStridedSliceResultLayout(xegpu::LayoutKind layoutKind,
+ VectorType resVectorTy,
+ const xegpu::uArch::uArch *uArch) {
+
+ xegpu::DistributeLayoutAttr requiredResLayout;
+ auto subgroupSize = uArch->getSubgroupSize();
+ auto context = resVectorTy.getContext();
+ auto resShape = resVectorTy.getShape();
+ int resShapeSize = resShape.size();
+ SmallVector<int> defaultInstData(resShapeSize, 1);
+ SmallVector<int> defaultLaneLayout(resShapeSize, 1);
+ SmallVector<int> defaultLaneData(resShapeSize, 1);
- auto context = vectorTy.getContext();
+ defaultInstData[resShapeSize - 1] = subgroupSize;
+ defaultLaneLayout[resShapeSize - 1] = subgroupSize;
switch (layoutKind) {
case xegpu::LayoutKind::Subgroup:
- requiredLayout = consumerLayout;
+ assert(true &&
+ "subgroup layout assignment not supported for insertStridedSlice.");
break;
case xegpu::LayoutKind::InstData:
- requiredLayout = xegpu::LayoutAttr::get(context, instData);
+ requiredResLayout = xegpu::LayoutAttr::get(context, defaultInstData);
break;
case xegpu::LayoutKind::Lane:
- requiredLayout =
- xegpu::LayoutAttr::get(context, {1, uArch->getSubgroupSize()}, {1, 1});
+ requiredResLayout =
+ xegpu::LayoutAttr::get(context, defaultLaneLayout, defaultLaneData);
break;
default:
llvm_unreachable("unsupported layout kind");
}
- return requiredLayout;
+ return requiredResLayout;
+}
+
+xegpu::DistributeLayoutAttr xegpu::inferInsertStridedSliceSourceLayout(
+ xegpu::DistributeLayoutAttr resLayout, ArrayRef<int64_t> resShape,
+ ArrayRef<int64_t> srcShape) {
+
+ const int subgroupSize = 16; // assuming 16 lanes per subgroup
+ int srcShapeSize = srcShape.size();
+ int resShapeSize = resShape.size();
+ int dimDiff = resShapeSize - srcShapeSize;
+
+ // assert resLayout must be a plain layout
+ assert(isa<xegpu::LayoutAttr>(resLayout) &&
+ "insertStridedSlice result layout must be plain layout");
+ auto context = resLayout.getContext();
+ auto resInstData = resLayout.getEffectiveInstDataAsInt();
+ auto resLaneLayout = resLayout.getEffectiveLaneLayoutAsInt();
+ auto resLaneData = resLayout.getEffectiveLaneDataAsInt();
+
+ if (resInstData.size() != 0) {
+ SmallVector<int> inferredInstData(srcShapeSize);
+ // remove the initial dims in resInstData to match srcShapeSize
+ for (int i = 0; i < srcShapeSize; i++)
+ inferredInstData[i] = resInstData[i + dimDiff];
+ return xegpu::LayoutAttr::get(context, inferredInstData);
+ }
+
+ if (resLaneLayout.size() != 0) {
+ // derive source lane_layout/lane_data from the result's trailing dims
+ SmallVector<int> inferredLaneLayout(srcShapeSize);
+ SmallVector<int> inferredLaneData(srcShapeSize);
+ // drop the leading dims of resLaneLayout/resLaneData to match srcShapeSize
+ for (int i = 0; i < srcShapeSize; i++) {
+ inferredLaneLayout[i] = resLaneLayout[i + dimDiff];
+ inferredLaneData[i] = resLaneData[i + dimDiff];
+ }
+ return xegpu::LayoutAttr::get(context, inferredLaneLayout,
+ inferredLaneData);
+ }
+ return nullptr;
}
static xegpu::DistributeLayoutAttr
>From 1c50dd3514c61ff927279032b4120d199cea65f7 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Thu, 29 Jan 2026 00:14:27 +0000
Subject: [PATCH 16/35] add sg distribution for creatmemdesc, able to
distribute
---
.../XeGPU/Transforms/XeGPUBlocking.cpp | 4 +-
.../Transforms/XeGPUSubgroupDistribute.cpp | 73 ++++++++++++++++---
.../Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp | 4 +-
3 files changed, 66 insertions(+), 15 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
index 29cd74ad461dd..7af7622375ef4 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
@@ -162,8 +162,8 @@ XeGPUBlockingPass::getTileShape(const T &operandOrResult) const {
ownerOp &&
(isa<xegpu::CreateNdDescOp, xegpu::DpasOp, xegpu::ConvertLayoutOp,
xegpu::LoadMatrixOp, xegpu::StoreMatrixOp, xegpu::AtomicRMWOp,
- xegpu::LoadGatherOp, xegpu::LoadNdOp, xegpu::StoreNdOp,
- xegpu::PrefetchNdOp, vector::TransposeOp, vector::ShapeCastOp,
+ xegpu::LoadNdOp, xegpu::StoreNdOp, xegpu::PrefetchNdOp,
+ vector::TransposeOp, vector::ShapeCastOp,
vector::MultiDimReductionOp, vector::BroadcastOp>(ownerOp));
if (!skipLeadingUnitDimRemoval) {
auto it = llvm::find_if(instData, [](auto val) { return val != 1; });
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 8898be5f13dab..80fd5feea9f97 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -1619,14 +1619,16 @@ struct VectorShapeCastDistribution : public gpu::WarpDistributionPattern {
// must be a slice of higher rank layout.
int64_t sourceRank = shapeCastOp.getSourceVectorType().getRank();
int64_t resultRank = shapeCastOp.getResultVectorType().getRank();
- if (sourceRank < resultRank && !sourceLayout.isSliceOf(resultLayout))
- return rewriter.notifyMatchFailure(
- warpOp, "shape_cast is rank reducing but source layout is not a "
- "slice of result layout");
- if (sourceRank > resultRank && !resultLayout.isSliceOf(sourceLayout))
- return rewriter.notifyMatchFailure(
- warpOp, "shape_cast is rank increasing but result layout is not a "
- "slice of source layout");
+ // if (sourceRank < resultRank && !sourceLayout.isSliceOf(resultLayout)) {
+ // return rewriter.notifyMatchFailure(
+ // warpOp, "shape_cast is rank reducing but source layout is not a "
+ // "slice of result layout");
+ // }
+ // if (sourceRank > resultRank && !resultLayout.isSliceOf(sourceLayout)) {
+ // return rewriter.notifyMatchFailure(
+ // warpOp, "shape_cast is rank increasing but result layout is not a "
+ // "slice of source layout");
+ // }
FailureOr<VectorType> sourceDistTypeOrFailure =
getDistVecTypeBasedOnLaneLayout(sourceLayout,
@@ -1906,8 +1908,56 @@ struct MemrefExtractAlignedPointerAsIndexDistribution final
auto newExtractOp = memref::ExtractAlignedPointerAsIndexOp::create(
rewriter, newWarpOp.getLoc(), extractOp.getType(),
newWarpOp.getResult(newRetIndices[0]));
- Value distributedVal = newWarpOp.getResult(operandIdx);
- rewriter.replaceAllUsesWith(distributedVal, newExtractOp.getResult());
+ Value resultVal = newWarpOp.getResult(operandIdx);
+ rewriter.replaceAllUsesWith(resultVal, newExtractOp.getResult());
+ return success();
+ }
+};
+
+struct MemrefAllocaDistribution final : public gpu::WarpDistributionPattern {
+ using gpu::WarpDistributionPattern::WarpDistributionPattern;
+ LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+ PatternRewriter &rewriter) const override {
+ OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<memref::AllocaOp>);
+ if (!operand)
+ return rewriter.notifyMatchFailure(
+ warpOp, "warp result is not a memref::Alloca op");
+ auto allocaOp = operand->get().getDefiningOp<memref::AllocaOp>();
+ unsigned operandIdx = operand->getOperandNumber();
+ SmallVector<size_t> newRetIndices;
+ gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+ rewriter, warpOp, ValueRange{}, TypeRange{}, newRetIndices);
+ rewriter.setInsertionPointAfter(newWarpOp);
+ auto newAllocaOp = memref::AllocaOp::create(rewriter, newWarpOp.getLoc(),
+ allocaOp.getType(), nullptr);
+ Value resultVal = newWarpOp.getResult(operandIdx);
+ rewriter.replaceAllUsesWith(resultVal, newAllocaOp.getResult());
+ return success();
+ }
+};
+
+struct CreateMemDescDistribution final : public gpu::WarpDistributionPattern {
+ using gpu::WarpDistributionPattern::WarpDistributionPattern;
+ LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+ PatternRewriter &rewriter) const override {
+ OpOperand *operand =
+ getWarpResult(warpOp, llvm::IsaPred<xegpu::CreateMemDescOp>);
+ if (!operand)
+ return rewriter.notifyMatchFailure(
+ warpOp, "warp result is not a xegpu::CreateMemDesc op");
+ auto createMemDescOp =
+ operand->get().getDefiningOp<xegpu::CreateMemDescOp>();
+ unsigned operandIdx = operand->getOperandNumber();
+ SmallVector<size_t> newRetIndices;
+ gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+ rewriter, warpOp, createMemDescOp.getSource(),
+ TypeRange{createMemDescOp.getSource().getType()}, newRetIndices);
+ rewriter.setInsertionPointAfter(newWarpOp);
+ auto newCreateMemDescOp = xegpu::CreateMemDescOp::create(
+ rewriter, newWarpOp.getLoc(), createMemDescOp.getType(),
+ newWarpOp.getResult(newRetIndices[0]));
+ Value resultVal = newWarpOp.getResult(operandIdx);
+ rewriter.replaceAllUsesWith(resultVal, newCreateMemDescOp.getResult());
return success();
}
};
@@ -2035,7 +2085,8 @@ void xegpu::populateXeGPUSubgroupDistributePatterns(
LoadDistribution, StoreDistribution, VectorTransposeDistribution,
VectorBitcastDistribution, LoadMatrixDistribution,
StoreMatrixDistribution,
- MemrefExtractAlignedPointerAsIndexDistribution>(
+ MemrefExtractAlignedPointerAsIndexDistribution,
+ MemrefAllocaDistribution, CreateMemDescDistribution>(
patterns.getContext(),
/*pattern benefit=*/regularPatternBenefit);
// For following patterns, we need to override the regular vector distribution
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp
index 9cfcad834a924..225a3874fb2ae 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp
@@ -887,7 +887,7 @@ xegpu::setupLoadMatrixAnchorLayout(xegpu::LayoutKind layoutKind,
break;
case xegpu::LayoutKind::Lane:
requiredLayout =
- xegpu::LayoutAttr::get(context, defaultLaneData, defaultLaneLayout);
+ xegpu::LayoutAttr::get(context, defaultLaneLayout, defaultLaneData);
break;
default:
llvm_unreachable("unsupported layout kind");
@@ -917,7 +917,7 @@ xegpu::setupStoreMatrixAnchorLayout(xegpu::LayoutKind layoutKind,
break;
case xegpu::LayoutKind::Lane:
requiredLayout =
- xegpu::LayoutAttr::get(context, defaultLaneData, defaultLaneLayout);
+ xegpu::LayoutAttr::get(context, defaultLaneLayout, defaultLaneData);
break;
default:
>From a7038d260f86b170d4310cba38fa51403c3a0ee3 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Thu, 29 Jan 2026 01:39:05 +0000
Subject: [PATCH 17/35] xegpu to xevm type convert fix
---
mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 8a06271eadd84..8efbb0702f0d3 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -1190,7 +1190,7 @@ struct ConvertXeGPUToXeVMPass
return {};
auto input = inputs.front();
if (auto vecTy = dyn_cast<VectorType>(input.getType())) {
- if (vecTy.getNumElements() == 1) {
+ if (vecTy.getRank() == 1 && vecTy.getNumElements() == 1) {
// If the vector has a single element, return the element type.
Value cast =
vector::ExtractOp::create(builder, loc, input, 0).getResult();
>From 53bc42bf29d4e0b2f3e54ce64305a5d0e7f4e13b Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Thu, 29 Jan 2026 21:48:40 +0000
Subject: [PATCH 18/35] fix bitcast, reduction propagation rules, passing lane
layout tests
---
.../XeGPU/Transforms/XeGPUPropagateLayout.cpp | 28 ++-
.../Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp | 188 ++++++++++++------
mlir/test/Dialect/XeGPU/propagate-layout.mlir | 46 ++---
3 files changed, 173 insertions(+), 89 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 5cb0e259c503b..4966c657c6eca 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -1714,24 +1714,40 @@ void XeGPUPropagateLayoutPass::runOnOperation() {
Operation *defOp = opResult.getDefiningOp();
LLVM_DEBUG(DBGS() << "Try op: " << *defOp << "\n");
if (auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(defOp)) {
- // inject debug print here
LLVM_DEBUG(DBGS() << "AnchorLayoutInterface found for op: " << *defOp
<< "\n");
- // print the anchor layout
- LLVM_DEBUG(DBGS() << "Anchor layout: " << anchorOp.getAnchorLayout()
+ auto anchorLayout = anchorOp.getAnchorLayout();
+ LLVM_DEBUG(DBGS() << "Anchor layout: " << anchorLayout << "\n");
+ LLVM_DEBUG(DBGS() << "Anchor layout is null: "
+ << (anchorLayout == nullptr ? "true" : "false")
<< "\n");
- return anchorOp.getAnchorLayout();
+ if (anchorLayout != nullptr) {
+ LLVM_DEBUG(DBGS()
+ << "Returning anchor layout: " << anchorLayout << "\n");
+ return anchorLayout;
+ }
+ LLVM_DEBUG(DBGS() << "Anchor layout is null, continuing...\n");
}
xegpu::DistributeLayoutAttr requiredResLayoutAttr =
xegpu::getTemporaryLayout(opResult);
- if (requiredResLayoutAttr != nullptr)
+ LLVM_DEBUG(DBGS() << "Temporary layout for value: " << val << " is "
+ << requiredResLayoutAttr << "\n");
+ if (requiredResLayoutAttr != nullptr) {
+ LLVM_DEBUG(DBGS() << "Returning temporary layout: "
+ << requiredResLayoutAttr << "\n");
return requiredResLayoutAttr;
+ }
}
xegpu::DistributeLayoutAttr layoutAttr =
cast<xegpu::DistributeLayoutAttr>(layout.get());
- if (layout.isSliceLayout())
+ LLVM_DEBUG(DBGS() << "Layout attr for value: " << val << " is "
+ << layoutAttr << "\n");
+ if (layout.isSliceLayout()) {
+ LLVM_DEBUG(DBGS() << "Returning slice layout: " << layoutAttr << "\n");
return cast<xegpu::SliceAttr>(layoutAttr);
+ }
+ LLVM_DEBUG(DBGS() << "Returning layout attr: " << layoutAttr << "\n");
return cast<xegpu::LayoutAttr>(layoutAttr);
};
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp
index 225a3874fb2ae..86c328616b92a 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp
@@ -389,28 +389,63 @@ xegpu::inferBitCastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
// only adjust the sg_data, inst_data, lane_data accordingly
// based on the bitwidth ratio between source and result element type
+ llvm::dbgs() << "inferBitCastSourceLayout: resElemTyBitWidth="
+ << resElemTyBitWidth
+ << ", srcElemTyBitWidth=" << srcElemTyBitWidth << "\n";
+
SmallVector<int64_t> sgData = resLayout.getEffectiveSgDataAsInt();
SmallVector<int64_t> instData = resLayout.getEffectiveInstDataAsInt();
SmallVector<int64_t> laneData = resLayout.getEffectiveLaneDataAsInt();
- size_t dim = sgData.size() - 1;
- int64_t sgDataValue, instDataValue, laneDataValue;
-
- if (srcElemTyBitWidth >= resElemTyBitWidth) {
- int bitWidthRatio = srcElemTyBitWidth / resElemTyBitWidth;
- sgDataValue = (dim < sgData.size()) ? sgData[dim] * bitWidthRatio : -1;
- instDataValue =
- (dim < instData.size()) ? instData[dim] * bitWidthRatio : -1;
- laneDataValue =
- (dim < laneData.size()) ? laneData[dim] * bitWidthRatio : -1;
- } else {
+ size_t sgDataSize = sgData.size();
+ size_t instDataSize = instData.size();
+ size_t laneDataSize = laneData.size();
+ int64_t sgDataValue = -1;
+ int64_t instDataValue = -1;
+ int64_t laneDataValue = -1;
+ int64_t dim = resLayout.getRank() - 1;
+
+ llvm::dbgs() << "inferBitCastSourceLayout: dim=" << dim
+ << ", sgData.size()=" << sgData.size()
+ << ", instData.size()=" << instData.size()
+ << ", laneData.size()=" << laneData.size() << "\n";
+
+ if (srcElemTyBitWidth <= resElemTyBitWidth) {
int bitWidthRatio = resElemTyBitWidth / srcElemTyBitWidth;
- assert((laneData[dim] % bitWidthRatio) == 0 &&
- "laneData not divisible by bitWidthRatio");
- sgDataValue = (dim < sgData.size()) ? sgData[dim] / bitWidthRatio : -1;
- instDataValue =
- (dim < instData.size()) ? instData[dim] / bitWidthRatio : -1;
- laneDataValue =
- (dim < laneData.size()) ? laneData[dim] / bitWidthRatio : -1;
+ llvm::dbgs() << "inferBitCastSourceLayout: srcElemTyBitWidth >= "
+ "resElemTyBitWidth, bitWidthRatio="
+ << bitWidthRatio << "\n";
+ if (sgDataSize)
+ sgDataValue = sgData[sgDataSize - 1] * bitWidthRatio;
+ if (instDataSize)
+ instDataValue = instData[instDataSize - 1] * bitWidthRatio;
+ if (laneDataSize)
+ laneDataValue = laneData[laneDataSize - 1] * bitWidthRatio;
+ llvm::dbgs() << "inferBitCastSourceLayout: sgDataValue=" << sgDataValue
+ << ", instDataValue=" << instDataValue
+ << ", laneDataValue=" << laneDataValue << "\n";
+ } else {
+ int bitWidthRatio = srcElemTyBitWidth / resElemTyBitWidth;
+ llvm::dbgs() << "inferBitCastSourceLayout: srcElemTyBitWidth < "
+ "resElemTyBitWidth, bitWidthRatio="
+ << bitWidthRatio << "\n";
+ if (sgDataSize) {
+ assert((sgData[sgDataSize - 1] % bitWidthRatio) == 0 &&
+ "sgData not divisible by bitWidthRatio");
+ sgDataValue = sgData[sgDataSize - 1] / bitWidthRatio;
+ }
+ if (instDataSize) {
+ assert((instData[instDataSize - 1] % bitWidthRatio) == 0 &&
+ "instData not divisible by bitWidthRatio");
+ instDataValue = instData[instDataSize - 1] / bitWidthRatio;
+ }
+ if (laneDataSize) {
+ assert((laneData[laneDataSize - 1] % bitWidthRatio) == 0 &&
+ "laneData not divisible by bitWidthRatio");
+ laneDataValue = laneData[laneDataSize - 1] / bitWidthRatio;
+ }
+ llvm::dbgs() << "inferBitCastSourceLayout: sgDataValue=" << sgDataValue
+ << ", instDataValue=" << instDataValue
+ << ", laneDataValue=" << laneDataValue << "\n";
}
// Now set only instData and laneData, preserving sgData
@@ -418,6 +453,7 @@ xegpu::inferBitCastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
finalSrcLayout =
resLayout.setDimData(dim, sgDataValue, instDataValue, laneDataValue);
+ llvm::dbgs() << "inferBitCastSourceLayout: returning finalSrcLayout\n";
return finalSrcLayout;
}
@@ -608,16 +644,18 @@ xegpu::SliceAttr xegpu::setupMultiReductionResultLayout(
const int subgroupSize = uArch->getSubgroupSize();
- const int vectorSize = 16; // vector size from SPRIV vector restriction
+ const int vectorSize = 1; // vector size from SPIR-V vector restriction
SmallVector<int64_t> defaultInstData(srcShapeSize, 1);
- defaultInstData[srcShapeSize - 1] = subgroupSize;
- defaultInstData[srcShapeSize - 2] =
- vectorSize; // This will be adjusted based on actual data distribution
SmallVector<int64_t> defaultLaneLayout(srcShapeSize, 1);
- defaultLaneLayout[srcShapeSize - 1] = subgroupSize;
+
SmallVector<int64_t> defaultLaneData(srcShapeSize, 1);
+ defaultInstData[srcShapeSize - 2] = vectorSize;
+ defaultInstData[srcShapeSize - 1] = subgroupSize;
+ defaultLaneLayout[srcShapeSize - 1] = subgroupSize;
+ defaultLaneData[srcShapeSize - 2] = vectorSize;
+ defaultLaneData[srcShapeSize - 1] = 1;
// Strategy 1: Try to preserve the consumer's slice layout structure
// If the consumer already expects a slice layout with the same reduction
@@ -743,28 +781,34 @@ xegpu::SliceAttr xegpu::setupMultiReductionResultLayout(
}
break;
case xegpu::LayoutKind::Lane:
- consumerLaneId = consumerLaneLayout.size() - 1;
- // For non-reduction dimensions, try to match consumer's lane_layout
- // This ensures the result after reduction has the expected distribution
- for (int i = 0; i < srcShapeSize; i++)
- if (!llvm::is_contained(reductionDims, i) && consumerLaneId >= 0) {
- laneLayout[i] = consumerLaneLayout[consumerLaneId];
- assert((srcShape[i] % laneLayout[i] == 0) &&
- "source shape not divisible by consumer lane_layout");
- laneData[i] = srcShape[i] / laneLayout[i];
- remainingLaneCount /= laneLayout[i];
- consumerLaneId--;
- }
- for (int i = 0; i < srcShapeSize; i++)
- if (llvm::is_contained(reductionDims, i)) {
- laneLayout[i] = std::min((srcShape[i] / defaultLaneLayout[i]),
- static_cast<int64_t>(remainingLaneCount));
- assert((srcShape[i] % laneLayout[i] == 0) &&
- "source shape not divisible by consumer lane_layout");
- laneData[i] = srcShape[i] / laneLayout[i];
- remainingLaneCount /= laneLayout[i];
- }
- assert(remainingLaneCount == 1 && "not all lanes have been distributed");
+ laneLayout = defaultLaneLayout;
+ laneData = defaultLaneData;
+ // consumerLaneId = consumerLaneLayout.size() - 1;
+ // // For non-reduction dimensions, try to match consumer's lane_layout
+ // // This ensures the result after reduction has the expected
+ // distribution for (int i = 0; i < srcShapeSize; i++)
+ // if (!llvm::is_contained(reductionDims, i) && consumerLaneId >= 0) {
+ // laneLayout[i] = consumerLaneLayout[consumerLaneId];
+ // assert((srcShape[i] % laneLayout[i] == 0) &&
+ // "source shape not divisible by consumer lane_layout");
+ // laneData[i] = srcShape[i] / laneLayout[i];
+ // remainingLaneCount /= laneLayout[i];
+ // consumerLaneId--;
+ // }
+ // for (int i = 0; i < srcShapeSize; i++)
+ // if (llvm::is_contained(reductionDims, i)) {
+ // laneLayout[i] =
+ // std::min(srcShape[i],
+ // static_cast<int64_t>(remainingLaneCount));
+ // assert((srcShape[i] % laneLayout[i] == 0) &&
+ // "source shape not divisible by consumer lane_layout");
+ // laneData[i] = srcShape[i] / laneLayout[i];
+ // laneData[i] = std::min(laneData[i],
+ // static_cast<int64_t>(vectorSize)); remainingLaneCount /=
+ // laneLayout[i];
+ // }
+ // assert(remainingLaneCount == 1 && "not all lanes have been
+ // distributed");
break;
default:
llvm_unreachable("unsupported layout kind");
@@ -806,17 +850,39 @@ xegpu::DistributeLayoutAttr xegpu::setupBitCastResultLayout(
xegpu::LayoutKind layoutKind, VectorType srcVecTy, VectorType resVecTy,
DistributeLayoutAttr consumerLayout, const xegpu::uArch::uArch *uArch) {
+ // print out consumerLayout for debugging
+ llvm::dbgs() << "setupBitCastResultLayout: consumerLayout=";
+ consumerLayout.print(llvm::dbgs());
+ llvm::dbgs() << "\n";
+
+ llvm::dbgs() << "setupBitCastResultLayout: layoutKind="
+ << static_cast<int>(layoutKind) << "\n";
+
int srcElemTyBitWidth = srcVecTy.getElementType().getIntOrFloatBitWidth();
int resElemTyBitWidth = resVecTy.getElementType().getIntOrFloatBitWidth();
+ llvm::dbgs() << "setupBitCastResultLayout: srcElemTyBitWidth="
+ << srcElemTyBitWidth
+ << ", resElemTyBitWidth=" << resElemTyBitWidth << "\n";
+
ArrayRef<int64_t> srcShape = srcVecTy.getShape();
SmallVector<int64_t> sgData = consumerLayout.getEffectiveSgDataAsInt();
SmallVector<int64_t> instData = consumerLayout.getEffectiveInstDataAsInt();
SmallVector<int64_t> laneData = consumerLayout.getEffectiveLaneDataAsInt();
size_t dim = srcShape.size() - 1;
- int64_t sgDataValue, instDataValue, laneDataValue;
+ int64_t sgDataValue = -1;
+ int64_t instDataValue = -1;
+ int64_t laneDataValue = -1;
+
+ llvm::dbgs() << "setupBitCastResultLayout: srcShape.size()="
+ << srcShape.size() << ", dim=" << dim << "\n";
+ llvm::dbgs() << "setupBitCastResultLayout: sgData.size()=" << sgData.size()
+ << ", instData.size()=" << instData.size()
+ << ", laneData.size()=" << laneData.size() << "\n";
const int subgroupSize = uArch->getSubgroupSize();
+ llvm::dbgs() << "setupBitCastResultLayout: subgroupSize=" << subgroupSize
+ << "\n";
if (srcElemTyBitWidth > resElemTyBitWidth) {
// When casting to a smaller bitwidth, multiply the result layout
@@ -853,17 +919,19 @@ xegpu::DistributeLayoutAttr xegpu::setupBitCastResultLayout(
default:
llvm_unreachable("unsupported layout kind");
}
- } else {
- sgDataValue = sgData[dim];
- instDataValue = instData[dim];
- laneDataValue = laneData[dim];
+ // Now set only instData and laneData, preserving sgData
+ xegpu::DistributeLayoutAttr resLayout;
+ llvm::dbgs() << "setupBitCastResultLayout: Setting dimension data - dim="
+ << dim << ", sgDataValue=" << sgDataValue
+ << ", instDataValue=" << instDataValue
+ << ", laneDataValue=" << laneDataValue << "\n";
+ resLayout = consumerLayout.setDimData(dim, sgDataValue, instDataValue,
+ laneDataValue);
+ llvm::dbgs()
+ << "setupBitCastResultLayout: resLayout created successfully\n";
+ return resLayout;
}
-
- // Now set only instData and laneData, preserving sgData
- xegpu::DistributeLayoutAttr resLayout;
- resLayout =
- consumerLayout.setDimData(dim, sgDataValue, instDataValue, laneDataValue);
- return resLayout;
+ return consumerLayout;
}
xegpu::DistributeLayoutAttr
@@ -1096,8 +1164,12 @@ xegpu::DistributeLayoutAttr xegpu::setupLoadGatherAnchorLayout(
assert(resShape[1] <= static_cast<int64_t>(
uArchInstruction->getMaxLaneLoadStoreSize()) &&
"StoreScatterOp lane size exceeds max lane load/store size.");
+ llvm::dbgs() << "setupLoadGatherAnchorLayout: Creating LayoutAttr with "
+ << "laneLayout=[" << subgroupSize << ", 1], "
+ << "laneData=[1, " << resShape[1] << "]\n";
requiredLayout = xegpu::LayoutAttr::get(
context, {subgroupSize, 1}, {1, static_cast<int>(resShape[1])});
+ llvm::dbgs() << "setupLoadGatherAnchorLayout: Created requiredLayout\n";
}
break;
default:
@@ -1176,8 +1248,8 @@ xegpu::setupStoreScatterAnchorLayout(LayoutKind layoutKind, VectorType srcVecTy,
requiredLayout =
getDefaultLaneLayoutAttr(context, srcVecTy.getRank(), uArch);
else {
- assert((srcVecTy.getRank() > 2) && "StoreScatterOp can access 2D tensor "
- "tile at maximum at subgroup level.");
+ assert((srcVecTy.getRank() <= 2) && "StoreScatterOp can access 2D tensor "
+ "tile at maximum at subgroup level.");
assert(srcShape[1] <= static_cast<int64_t>(
uArchInstruction->getMaxLaneLoadStoreSize()) &&
"StoreScatterOp lane size exceeds max lane load/store size.");
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout.mlir b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
index acfd2e34c805c..a8eccfab53c37 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
@@ -104,21 +104,18 @@ func.func @extf_truncf(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor
gpu.module @test {
// CHECK-LABEL: func.func @load_gather_with_chunksize(
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<8x16xf16>, %[[ARG1:[0-9a-zA-Z]+]]: memref<256xf16>, %[[ARG2:[0-9a-zA-Z]+]]: memref<8x16xf32>) {
-// CHECK: %[[CST:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
+// CHECK: %[[OFFSET:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
// CHECK-SAME: dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
-// CHECK-NEXT: %[[CST0:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1>
-// CHECK-NEXT: %[[T2:.*]] = xegpu.create_tdesc %[[ARG1]], %[[CST]] : memref<256xf16>, vector<16xindex> ->
-// CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>>
-// CHECK-NEXT: %{{.*}} = xegpu.load %[[T2]], %[[CST0]] <{layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>}>
-// CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>>, vector<16xi1> -> vector<16x16xf16>
+// CHECK-NEXT: %[[MASK:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1>
+// CHECK-NEXT: %{{.*}} = xegpu.load %arg1[%[[OFFSET]]], %[[MASK]] <{chunk_size = 16 : i64, layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 16]>}> : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x16xf16>
func.func @load_gather_with_chunksize(%arg0: memref<8x16xf16>, %arg1: memref<256xf16>, %arg2: memref<8x16xf32>) {
%c0 = arith.constant 0 : index
%0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
%1 = xegpu.load_nd %0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
- %cst = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
- %cst_0 = arith.constant dense<true> : vector<16xi1>
- %2 = xegpu.create_tdesc %arg1, %cst : memref<256xf16>, vector<16xindex> -> !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>>
- %3 = xegpu.load %2, %cst_0 : !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>>, vector<16xi1> -> vector<16x16xf16>
+ %offset = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
+ %mask = arith.constant dense<true> : vector<16xi1>
+ %3 = xegpu.load %arg1[%offset], %mask <{chunk_size=16}>
+ : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x16xf16>
%4 = vector.transpose %3, [1, 0] : vector<16x16xf16> to vector<16x16xf16>
%5 = xegpu.dpas %1, %4 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
%6 = xegpu.create_nd_tdesc %arg2[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
@@ -151,16 +148,15 @@ func.func @load_gather_1d(%arg0: memref<256xf32>, %arg1: !xegpu.tensor_desc<16xf
gpu.module @test {
// CHECK-LABEL: func.func @store_scatter_with_chunksize(
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<128xf32>) {
-// CHECK: %[[T0:.*]] = xegpu.create_tdesc %[[ARG0]], %{{.*}} : memref<128xf32>, vector<16xindex> ->
-// CHECK-SAME: !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size = 8 : i64>, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>>
-// CHECK-NEXT: xegpu.store %{{.*}}, %[[T0]], %{{.*}} : vector<16x8xf32>, !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size = 8 : i64>,
-// CHECK-SAME: #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>>, vector<16xi1>
+// CHECK-NEXT: %[[CST:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 8]>} dense<1.000000e+00> : vector<16x8xf32>
+// CHECK-NEXT: %[[CST_0:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1>
+// CHECK-NEXT: %[[CST_1:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
+// CHECK-NEXT: xegpu.store %[[CST]], %[[ARG0]][%[[CST_1]]], %[[CST_0]] <{chunk_size = 8 : i64, layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 8]>}> : vector<16x8xf32>, memref<128xf32>, vector<16xindex>, vector<16xi1>
func.func @store_scatter_with_chunksize(%arg0: memref<128xf32>) {
- %cst = arith.constant dense<1.000000e+00> : vector<16x8xf32>
- %cst_0 = arith.constant dense<true> : vector<16xi1>
- %cst_1 = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
- %0 = xegpu.create_tdesc %arg0, %cst_1 : memref<128xf32>, vector<16xindex> -> !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size = 8 : i64>>
- xegpu.store %cst, %0, %cst_0 : vector<16x8xf32>, !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size = 8 : i64>>, vector<16xi1>
+ %val = arith.constant dense<1.000000e+00> : vector<16x8xf32>
+ %mask = arith.constant dense<true> : vector<16xi1>
+ %offset = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
+ xegpu.store %val, %arg0[%offset], %mask <{chunk_size = 8}>: vector<16x8xf32>, memref<128xf32>, vector<16xindex>, vector<16xi1>
return
}
}
@@ -184,9 +180,9 @@ gpu.module @test {
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<256xf16>) {
// CHECK: %[[MASK:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1>
// CHECK: %[[OFFSETS:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<12> : vector<16xindex>
-// CHECK: %[[LOAD_VEC:.*]] = xegpu.load %[[ARG0]][%[[OFFSETS]]], %[[MASK]] <{chunk_size = 8 : i64, layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>}>
+// CHECK: %[[LOAD_VEC:.*]] = xegpu.load %[[ARG0]][%[[OFFSETS]]], %[[MASK]] <{chunk_size = 8 : i64, layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 8]>}>
// CHECK-SAME: memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
-// CHECK: xegpu.store %[[LOAD_VEC]], %[[ARG0]][%[[OFFSETS]]], %[[MASK]] <{chunk_size = 8 : i64, layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>}> : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
+// CHECK: xegpu.store %[[LOAD_VEC]], %[[ARG0]][%[[OFFSETS]]], %[[MASK]] <{chunk_size = 8 : i64, layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 8]>}> : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
func.func @scatter_ops_chunksize(%src: memref<256xf16>) {
%1 = arith.constant dense<1>: vector<16xi1>
%offset = arith.constant dense<12> : vector<16xindex>
@@ -320,8 +316,9 @@ func.func @vector_bitcast_i16_to_i32(%arg0: memref<8x32xi16>, %arg1: memref<8x16
// -----
gpu.module @test {
// CHECK-LABEL: func.func @vector_bitcast_require_cross_lane_shuffle(
-// CHECK: %[[LOAD:.*]] = xegpu.load_nd %{{.*}} : !xegpu.tensor_desc<8x16xi32> -> vector<8x16xi32>
-// CHECK: %{{.*}} = vector.bitcast %[[LOAD]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+// CHECK: %[[LOAD:.*]] = xegpu.load_nd %{{.*}} <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}>
+// CHECK-SAME: !xegpu.tensor_desc<8x16xi32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+// CHECK: %{{.*}} = vector.bitcast %[[LOAD]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>}
// CHECK-SAME: vector<8x16xi32> to vector<8x32xi16>
func.func @vector_bitcast_require_cross_lane_shuffle(%arg0: memref<8x16xi32>, %arg1: memref<8x32xi16>) {
%c0 = arith.constant 0 : index
@@ -703,8 +700,7 @@ func.func @vector_broadcast_scalar_to_vector(%arg0: !xegpu.tensor_desc<16x16xf16
gpu.module @test {
// CHECK-LABEL: func.func @store_matrix(
// CHECK: %[[CST:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} dense<0.000000e+00> : vector<16x16xf16>
-// CHECK-NEXT: xegpu.store_matrix %[[CST]], %arg0[8, 8] : vector<16x16xf16>, !xegpu.mem_desc<16x64xf16>
-
+// CHECK-NEXT: xegpu.store_matrix %[[CST]], %arg0[8, 8] <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}>
func.func @store_matrix(%arg0: !xegpu.mem_desc<16x64xf16>) {
%cst = arith.constant dense<0.0000> : vector<16x16xf16>
xegpu.store_matrix %cst, %arg0[8, 8]: vector<16x16xf16>, !xegpu.mem_desc<16x64xf16>
>From 84dc02a8d420e4a323eb39a98f86568eeb03a3d6 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Thu, 29 Jan 2026 23:16:59 +0000
Subject: [PATCH 19/35] passing all tests
---
.../XeGPU/Transforms/XeGPUSubgroupDistribute.cpp | 10 ----------
.../XeGPU/propagate-layout-inst-data.mlir | 2 +-
.../Dialect/XeGPU/subgroup-distribute-unit.mlir | 16 ++++++++--------
3 files changed, 9 insertions(+), 19 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 80fd5feea9f97..4de0905a9b7e2 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -1619,16 +1619,6 @@ struct VectorShapeCastDistribution : public gpu::WarpDistributionPattern {
// must be a slice of higher rank layout.
int64_t sourceRank = shapeCastOp.getSourceVectorType().getRank();
int64_t resultRank = shapeCastOp.getResultVectorType().getRank();
- // if (sourceRank < resultRank && !sourceLayout.isSliceOf(resultLayout)) {
- // return rewriter.notifyMatchFailure(
- // warpOp, "shape_cast is rank reducing but source layout is not a "
- // "slice of result layout");
- // }
- // if (sourceRank > resultRank && !resultLayout.isSliceOf(sourceLayout)) {
- // return rewriter.notifyMatchFailure(
- // warpOp, "shape_cast is rank increasing but result layout is not a "
- // "slice of source layout");
- // }
FailureOr<VectorType> sourceDistTypeOrFailure =
getDistVecTypeBasedOnLaneLayout(sourceLayout,
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir b/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
index 5aad0f592abed..3bc80e9b1596d 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
@@ -217,7 +217,7 @@ gpu.module @test {
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<1024xf32>) {
// CHECK: %{{.*}} = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [16]>} dense<true> : vector<16xi1>
// CHECK: %{{.*}} = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [16]>} dense<12> : vector<16xindex>
-// CHECK: %[[LOADED:.*]] = xegpu.load %[[ARG0]][%{{.*}}], %{{.*}} <{layout = #xegpu.slice<#xegpu.layout<inst_data = [16, 16]>, dims = [0]>}> :
+// CHECK: %[[LOADED:.*]] = xegpu.load %[[ARG0]][%{{.*}}], %{{.*}} <{layout = #xegpu.layout<inst_data = [16]>}> :
// CHECK-SAME: memref<1024xf32>, vector<16xindex>, vector<16xi1> -> vector<16xf32>
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[LOADED]] {layout_result_0 = #xegpu.layout<inst_data = [16, 16]>} : vector<16xf32> to vector<16x16xf32>
// CHECK: xegpu.store %[[BCAST]], %[[ARG0]][%{{.*}}], %{{.*}} <{chunk_size = 16 : i64, layout = #xegpu.layout<inst_data = [16, 16]>}> :
diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir
index b136c89925682..3a978f68ad9c0 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir
@@ -579,16 +579,15 @@ gpu.func @vector_shapecast_rank_reducing(%laneid: index) {
}
-// NOTE: Layouts are still valid, but distribution still requires a slice layout for the operand.
-//
-// CHECK-LABEL: gpu.func @vector_shapecast_unsupported
-// CHECK: %[[W:.*]] = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<1x1xf32>) {
-// CHECK: %[[T1:.*]] = vector.shape_cast %{{.*}} : vector<16xf32> to vector<1x16xf32>
-// CHECK: gpu.yield %[[T1]] : vector<1x16xf32>
+// CHECK-LABEL: gpu.func @vector_shapecast_rank_increasing_without_slicing_layout
+// CHECK: %[[W:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<1x1xf32>, vector<1xf32>) {
+// CHECK: %[[T1:.*]] = vector.shape_cast %{{.*}} {layout_operand_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>, layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16xf32> to vector<1x16xf32>
+// CHECK: gpu.yield %[[T1]], %{{.*}} : vector<1x16xf32>, vector<16xf32>
// CHECK: }
-// CHECK: "some_user_op"(%[[W]]) : (vector<1x1xf32>) -> ()
+// CHECK: %{{.*}} = vector.shape_cast %[[W]]#1 : vector<1xf32> to vector<1x1xf32>
// CHECK: gpu.return
-gpu.func @vector_shapecast_unsupported(%laneid: index) {
+gpu.module @xevm_module{
+gpu.func @vector_shapecast_rank_increasing_without_slicing_layout(%laneid: index) {
%r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1x1xf32>) {
%cst = "some_op"()
{layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]> }
@@ -604,6 +603,7 @@ gpu.func @vector_shapecast_unsupported(%laneid: index) {
"some_user_op"(%r) : (vector<1x1xf32>) -> ()
gpu.return
}
+}
// CHECK-LABEL: gpu.func @vector_extract_strided_slice_distributed_dim_fully_extracted
>From ed1e3a02459f09f25476b13928c73c5ac8bb760a Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Thu, 29 Jan 2026 23:58:38 +0000
Subject: [PATCH 20/35] move xegpulayoututils to transforms directory
---
.../XeGPULayoutImpls.h} | 0
.../mlir/Dialect/XeGPU/Utils/XeGPUUtils.h | 20 +-
.../Dialect/XeGPU/Transforms/CMakeLists.txt | 1 +
.../XeGPU/Transforms/XeGPUBlocking.cpp | 2 +-
.../XeGPULayoutImpls.cpp} | 202 +-----------------
.../Transforms/XeGPUPeepHoleOptimizer.cpp | 2 +-
.../XeGPU/Transforms/XeGPUPropagateLayout.cpp | 2 +-
.../Transforms/XeGPUSubgroupDistribute.cpp | 2 +-
.../Dialect/XeGPU/Transforms/XeGPUUnroll.cpp | 2 +-
.../Transforms/XeGPUWgToSgDistribute.cpp | 2 +-
mlir/lib/Dialect/XeGPU/Utils/CMakeLists.txt | 1 -
mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp | 200 +++++++++++++++++
12 files changed, 218 insertions(+), 218 deletions(-)
rename mlir/include/mlir/Dialect/XeGPU/{Utils/XeGPULayoutUtils.h => Transforms/XeGPULayoutImpls.h} (100%)
rename mlir/lib/Dialect/XeGPU/{Utils/XeGPULayoutUtils.cpp => Transforms/XeGPULayoutImpls.cpp} (86%)
diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h b/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpls.h
similarity index 100%
rename from mlir/include/mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h
rename to mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpls.h
diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
index 0f1ca8e38c873..572c174bccbf1 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
@@ -120,16 +120,6 @@ template <typename T>
int getLargestDivisor(T dim, ArrayRef<T> candidates,
ArrayRef<T> candidateMultiples = {});
-/// [to-be-deprecated] Sets the DistributeLayoutAttr for a given OpResult
-/// user should use setAnchorLayout instead
-void setDistributeLayoutAttr(const OpResult &Result,
- const DistributeLayoutAttr layout);
-
-/// [to-be-deprecated] Sets the DistributeLayoutAttr for a given OpOperand
-/// user should use setAnchorLayout instead
-void setDistributeLayoutAttr(const OpOperand &opr,
- const DistributeLayoutAttr layout);
-
/// Retrieves the DistributeLayoutAttr associated with a given Value. For
/// TensorDescType values, the DistributeLayoutAttr is extracted from the
/// TensorDescType itself. For other values, it is obtained from the attributes
@@ -142,6 +132,16 @@ DistributeLayoutAttr getDistributeLayoutAttr(const Value value);
/// found, it will check the operand itself and its defining op.
DistributeLayoutAttr getDistributeLayoutAttr(const OpOperand &opr);
+/// [to-be-deprecated] Sets the DistributeLayoutAttr for a given OpResult.
+/// Users should call setAnchorLayout instead.
+void setDistributeLayoutAttr(const OpResult &Result,
+                             const DistributeLayoutAttr layout);
+
+/// [to-be-deprecated] Sets the DistributeLayoutAttr for a given OpOperand.
+/// Users should call setAnchorLayout instead.
+void setDistributeLayoutAttr(const OpOperand &opr,
+                             const DistributeLayoutAttr layout);
+
/// Return the attribute name for the OpOperand to attach DistributeLayoutAttr
std::string getTemporaryLayoutName(const OpOperand &operand);
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
index 15d31eadcb6df..3492bf78bf785 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
@@ -7,6 +7,7 @@ add_mlir_dialect_library(MLIRXeGPUTransforms
XeGPUPropagateLayout.cpp
XeGPUVectorLinearize.cpp
XeGPUPeepHoleOptimizer.cpp
+ XeGPULayoutImpls.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/XeGPU
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
index 7af7622375ef4..c62b8e237596e 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
@@ -12,7 +12,7 @@
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
-#include "mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h"
+#include "mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpls.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Pass/PassManager.h"
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpls.cpp
similarity index 86%
rename from mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp
rename to mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpls.cpp
index 86c328616b92a..1b0ec6f11b5f9 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpls.cpp
@@ -10,7 +10,7 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h"
+#include "mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpls.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
@@ -28,206 +28,6 @@
using namespace mlir;
-std::string xegpu::getTemporaryLayoutName(const OpOperand &operand) {
- const StringRef prefix("layout_operand_");
- unsigned idx = const_cast<OpOperand &>(operand).getOperandNumber();
- return llvm::formatv("{0}{1}", prefix, idx).str();
-}
-
-std::string xegpu::getTemporaryLayoutName(const OpResult result) {
- const StringRef prefix = "layout_result_";
- return llvm::formatv("{0}{1}", prefix, result.getResultNumber()).str();
-}
-
-xegpu::DistributeLayoutAttr xegpu::getDistributeLayoutAttr(const Value value) {
- if (!value)
- return nullptr;
-
- if (auto tdescTy =
- dyn_cast_if_present<xegpu::TensorDescType>(value.getType()))
- return tdescTy.getLayoutAttr();
-
- if (auto result = dyn_cast<OpResult>(value)) {
- Operation *defOp = result.getDefiningOp();
- assert(defOp && "result must have a defining op");
-
- if (auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(defOp)) {
- auto layout = anchorOp.getAnchorLayout();
- return layout;
- }
-
- std::string layoutName = getTemporaryLayoutName(result);
- if (defOp->hasAttr(layoutName)) {
- auto layout =
- defOp->getAttrOfType<xegpu::DistributeLayoutAttr>(layoutName);
- return layout;
- }
- }
-
- if (auto arg = dyn_cast<BlockArgument>(value)) {
- auto *parentOp = arg.getOwner()->getParentOp();
- if (auto loop = dyn_cast<LoopLikeOpInterface>(parentOp)) {
- OpOperand *tiedInit = loop.getTiedLoopInit(arg);
- if (tiedInit)
- return getDistributeLayoutAttr(tiedInit->get());
- }
- }
-
- return nullptr;
-}
-xegpu::DistributeLayoutAttr
-xegpu::getDistributeLayoutAttr(const OpOperand &opr) {
- Operation *op = opr.getOwner();
- unsigned idx = const_cast<OpOperand &>(opr).getOperandNumber();
-
- if (auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(op)) {
- if (auto dpasOp = dyn_cast<xegpu::DpasOp>(op)) {
- if (idx == 0) {
- return dpasOp.getLayoutAAttr();
- } else if (idx == 1) {
- return dpasOp.getLayoutBAttr();
- } else if (idx == 2) {
- return dpasOp.getLayoutCdAttr();
- }
- }
- if (auto convertOp = dyn_cast<xegpu::ConvertLayoutOp>(op)) {
- return convertOp.getInputLayoutAttr();
- }
- auto layout = anchorOp.getAnchorLayout();
-
- if (idx == 0)
- return layout;
-
- // For store operations (StoreScatterOp, StoreNdOp, StoreMatrixOp),
- // the layout is valid for the first two operands: value and memref/tdesc.
- // For other operations, the layout applies to the first operand only.
- if (isa<xegpu::StoreScatterOp, xegpu::StoreNdOp, xegpu::StoreMatrixOp>(
- op) &&
- (idx < 2))
- return layout;
- }
-
- std::string layoutName = xegpu::getTemporaryLayoutName(opr);
- if (op->hasAttr(layoutName)) {
- auto layout = op->getAttrOfType<xegpu::DistributeLayoutAttr>(layoutName);
- return layout;
- }
-
- auto layout = getDistributeLayoutAttr(opr.get());
- return layout;
-}
-
-// TODO-LayoutRefactor: Remove this function after replacing use
-// with setTemporaryLayout or setAnchorLayout
-void xegpu::setDistributeLayoutAttr(
- const mlir::OpResult &result,
- const mlir::xegpu::DistributeLayoutAttr layout) {
- Operation *owner = result.getOwner();
-
- if (auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(owner)) {
- if (anchorOp.getAnchorLayout() == layout)
- return;
- anchorOp.setAnchorLayout(layout);
- return;
- }
-
- std::string name = xegpu::getTemporaryLayoutName(result);
- if (owner->hasAttrOfType<DistributeLayoutAttr>(name)) {
- return;
- }
- if (layout) {
- owner->setAttr(name, layout);
- }
-}
-
-// TODO-LayoutRefactor: Remove this function after replacing use
-// with setTemporaryLayout or setAnchorLayout
-void xegpu::setDistributeLayoutAttr(const OpOperand &operand,
- const DistributeLayoutAttr layout) {
- Operation *owner = operand.getOwner();
- unsigned idx = const_cast<OpOperand &>(operand).getOperandNumber();
-
- if (!layout) {
- return;
- }
- if (auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(owner)) {
- if (auto dpasOp = dyn_cast<xegpu::DpasOp>(owner)) {
- if (idx == 0) {
- return dpasOp.setLayoutAAttr(layout);
- } else if (idx == 1) {
- return dpasOp.setLayoutBAttr(layout);
- } else if (idx == 2) {
- return dpasOp.setLayoutCdAttr(layout);
- }
- }
- if (auto convertOp = dyn_cast<xegpu::ConvertLayoutOp>(owner)) {
- return convertOp.setInputLayoutAttr(layout);
- }
-
- // For store operations (StoreScatterOp, StoreNdOp, StoreMatrixOp),
- // the layout is valid for the first two operands: value and memref/tdesc.
- // For other operations, the layout applies to the first operand only.
- if (isa<xegpu::StoreScatterOp, xegpu::StoreNdOp, xegpu::StoreMatrixOp>(
- owner)) {
- if (idx < 2) {
- anchorOp.setAnchorLayout(layout);
- }
- } else {
- if (idx == 0) {
- anchorOp.setAnchorLayout(layout);
- }
- }
- }
-
- std::string name = xegpu::getTemporaryLayoutName(operand);
- if (owner->hasAttrOfType<DistributeLayoutAttr>(name)) {
- return;
- }
- if (layout) {
- owner->setAttr(name, layout);
- }
-}
-
-template <typename T, typename>
-xegpu::DistributeLayoutAttr
-xegpu::getTemporaryLayout(const T &operandOrResult) {
- Operation *op = operandOrResult.getOwner();
-
- std::string layoutName = xegpu::getTemporaryLayoutName(operandOrResult);
- if (op->hasAttr(layoutName)) {
- auto layout = op->getAttrOfType<xegpu::DistributeLayoutAttr>(layoutName);
- return layout;
- }
-
- return nullptr;
-}
-
-template xegpu::DistributeLayoutAttr
-xegpu::getTemporaryLayout<mlir::OpResult>(const OpResult &result);
-template xegpu::DistributeLayoutAttr
-xegpu::getTemporaryLayout<mlir::OpOperand>(const OpOperand &operand);
-
-template <typename T, typename>
-void xegpu::setTemporaryLayout(const T &operandOrResult,
- const xegpu::DistributeLayoutAttr layout) {
- Operation *owner = operandOrResult.getOwner();
- std::string name = xegpu::getTemporaryLayoutName(operandOrResult);
- if (owner->hasAttrOfType<xegpu::DistributeLayoutAttr>(name)) {
- return;
- }
- if (layout) {
- owner->setAttr(name, layout);
- }
-}
-
-template void xegpu::setTemporaryLayout<mlir::OpResult>(
- const mlir::OpResult &result,
- const mlir::xegpu::DistributeLayoutAttr layout);
-
-template void xegpu::setTemporaryLayout<mlir::OpOperand>(
- const mlir::OpOperand &operand,
- const mlir::xegpu::DistributeLayoutAttr layout);
-
void xegpu::recoverTemporaryLayoutsDeprecated(Operation *op) {
op->walk([&](Operation *nestOp) {
for (OpOperand &opr : nestOp->getOpOperands()) {
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPeepHoleOptimizer.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPeepHoleOptimizer.cpp
index 9a8925d357d25..c1d658b0d73e7 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPeepHoleOptimizer.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPeepHoleOptimizer.cpp
@@ -16,7 +16,7 @@
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
-#include "mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h"
+#include "mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpls.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h"
#include "mlir/Dialect/XeGPU/uArch/uArchBase.h"
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 4966c657c6eca..c148ba459527e 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -15,7 +15,7 @@
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
-#include "mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h"
+#include "mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpls.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 4de0905a9b7e2..121810215913c 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -14,7 +14,7 @@
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
-#include "mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h"
+#include "mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpls.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h"
#include "mlir/IR/AffineMap.h"
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index 6bafdc955c9d3..725d72de94043 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -15,7 +15,7 @@
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
-#include "mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h"
+#include "mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpls.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/DebugLog.h"
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 88dc670e0ed6a..c0487cd709c3d 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -19,7 +19,7 @@
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
-#include "mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h"
+#include "mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpls.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/Transforms/DialectConversion.h"
#include <optional>
diff --git a/mlir/lib/Dialect/XeGPU/Utils/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/Utils/CMakeLists.txt
index bde8324aab5fb..d9bf4a1461c27 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/CMakeLists.txt
+++ b/mlir/lib/Dialect/XeGPU/Utils/CMakeLists.txt
@@ -1,6 +1,5 @@
add_mlir_dialect_library(MLIRXeGPUUtils
XeGPUUtils.cpp
- XeGPULayoutUtils.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/XeGPU/Utils
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index 181b7e9673fef..cc6ac76ec89fd 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -392,3 +392,203 @@ template int xegpu::getLargestDivisor<int>(int dim, ArrayRef<int> candidates,
template int
xegpu::getLargestDivisor<unsigned>(unsigned dim, ArrayRef<unsigned> candidates,
ArrayRef<unsigned> candidateMultiples);
+
+// Builds the attribute name "layout_operand_<N>" for operand number N.
+std::string xegpu::getTemporaryLayoutName(const OpOperand &operand) {
+  unsigned operandNo = const_cast<OpOperand &>(operand).getOperandNumber();
+  return llvm::formatv("layout_operand_{0}", operandNo).str();
+}
+
+// Builds the attribute name "layout_result_<N>" for result number N.
+std::string xegpu::getTemporaryLayoutName(const OpResult result) {
+  return llvm::formatv("layout_result_{0}", result.getResultNumber()).str();
+}
+
+// Looks up the layout of `value`: from its TensorDescType, from the op that
+// defines it, or (for loop-carried block arguments) from the tied loop init.
+xegpu::DistributeLayoutAttr xegpu::getDistributeLayoutAttr(const Value value) {
+  if (!value)
+    return nullptr;
+
+  if (auto tdescTy =
+          dyn_cast_if_present<xegpu::TensorDescType>(value.getType()))
+    return tdescTy.getLayoutAttr();
+
+  if (auto result = dyn_cast<OpResult>(value)) {
+    Operation *defOp = result.getDefiningOp();
+    assert(defOp && "result must have a defining op");
+
+    // Anchor ops report their own layout; the temporary attr is not consulted.
+    if (auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(defOp))
+      return anchorOp.getAnchorLayout();
+
+    std::string layoutName = getTemporaryLayoutName(result);
+    if (defOp->hasAttr(layoutName))
+      return defOp->getAttrOfType<xegpu::DistributeLayoutAttr>(layoutName);
+  }
+
+  if (auto arg = dyn_cast<BlockArgument>(value)) {
+    // getParentOp() is null for a block that is not attached to an op, and
+    // plain dyn_cast asserts on null, so use the null-tolerant cast here.
+    Operation *parentOp = arg.getOwner()->getParentOp();
+    if (auto loop = dyn_cast_if_present<LoopLikeOpInterface>(parentOp)) {
+      OpOperand *tiedInit = loop.getTiedLoopInit(arg);
+      if (tiedInit)
+        return getDistributeLayoutAttr(tiedInit->get());
+    }
+  }
+
+  return nullptr;
+}
+// Resolves the layout seen by operand `opr`. Anchor ops answer for the
+// operands their layout covers; otherwise the temporary operand attribute on
+// the owner is tried, then the layout of the operand's value.
+xegpu::DistributeLayoutAttr
+xegpu::getDistributeLayoutAttr(const OpOperand &opr) {
+  Operation *op = opr.getOwner();
+  unsigned idx = const_cast<OpOperand &>(opr).getOperandNumber();
+
+  if (auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(op)) {
+    if (auto dpasOp = dyn_cast<xegpu::DpasOp>(op)) {
+      // DpasOp keeps a separate layout for each of its first three operands.
+      switch (idx) {
+      case 0:
+        return dpasOp.getLayoutAAttr();
+      case 1:
+        return dpasOp.getLayoutBAttr();
+      case 2:
+        return dpasOp.getLayoutCdAttr();
+      default:
+        break;
+      }
+    }
+    // ConvertLayoutOp: the operand is described by the input layout.
+    if (auto convertOp = dyn_cast<xegpu::ConvertLayoutOp>(op))
+      return convertOp.getInputLayoutAttr();
+
+    // The anchor layout covers operand 0 of every anchor op, and additionally
+    // operand 1 (memref/tdesc) of the store ops.
+    bool isStore =
+        isa<xegpu::StoreScatterOp, xegpu::StoreNdOp, xegpu::StoreMatrixOp>(op);
+    if (idx == 0 || (isStore && idx == 1))
+      return anchorOp.getAnchorLayout();
+  }
+
+  // Fall back to the temporary attribute on the owner, then to the value.
+  std::string layoutName = xegpu::getTemporaryLayoutName(opr);
+  if (op->hasAttr(layoutName))
+    return op->getAttrOfType<xegpu::DistributeLayoutAttr>(layoutName);
+
+  return getDistributeLayoutAttr(opr.get());
+}
+
+// TODO-LayoutRefactor: Remove this function after replacing use
+// with setTemporaryLayout or setAnchorLayout
+//
+// Anchor ops take `layout` as their anchor layout. Any other owner gets it as
+// the temporary result attribute, unless one is already present.
+void xegpu::setDistributeLayoutAttr(
+    const mlir::OpResult &result,
+    const mlir::xegpu::DistributeLayoutAttr layout) {
+  Operation *owner = result.getOwner();
+
+  if (auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(owner)) {
+    // Skip the write when the anchor layout already matches.
+    if (anchorOp.getAnchorLayout() != layout)
+      anchorOp.setAnchorLayout(layout);
+    return;
+  }
+
+  std::string attrName = xegpu::getTemporaryLayoutName(result);
+  bool alreadySet = owner->hasAttrOfType<DistributeLayoutAttr>(attrName);
+  if (layout && !alreadySet)
+    owner->setAttr(attrName, layout);
+}
+
+// TODO-LayoutRefactor: Remove this function after replacing use
+// with setTemporaryLayout or setAnchorLayout
+//
+// A null `layout` is ignored. DpasOp (operands 0-2) and ConvertLayoutOp store
+// the layout in their own attribute and return. Other anchor ops update the
+// anchor layout for the operands it covers, then the temporary attr is set.
+// NOTE(review): the OpResult overload returns after updating an anchor op;
+// confirm that falling through to the temporary attribute is intended here.
+void xegpu::setDistributeLayoutAttr(const OpOperand &operand,
+                                    const DistributeLayoutAttr layout) {
+  if (!layout)
+    return;
+
+  Operation *owner = operand.getOwner();
+  unsigned idx = const_cast<OpOperand &>(operand).getOperandNumber();
+
+  if (auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(owner)) {
+    if (auto dpasOp = dyn_cast<xegpu::DpasOp>(owner)) {
+      switch (idx) {
+      case 0:
+        return dpasOp.setLayoutAAttr(layout);
+      case 1:
+        return dpasOp.setLayoutBAttr(layout);
+      case 2:
+        return dpasOp.setLayoutCdAttr(layout);
+      default:
+        break;
+      }
+    }
+    if (auto convertOp = dyn_cast<xegpu::ConvertLayoutOp>(owner))
+      return convertOp.setInputLayoutAttr(layout);
+
+    // For store operations (StoreScatterOp, StoreNdOp, StoreMatrixOp),
+    // the layout is valid for the first two operands: value and memref/tdesc.
+    // For other operations, the layout applies to the first operand only.
+    bool isStore =
+        isa<xegpu::StoreScatterOp, xegpu::StoreNdOp, xegpu::StoreMatrixOp>(
+            owner);
+    if (idx == 0 || (isStore && idx == 1))
+      anchorOp.setAnchorLayout(layout);
+  }
+
+  // An existing temporary attribute is never overwritten.
+  std::string attrName = xegpu::getTemporaryLayoutName(operand);
+  if (!owner->hasAttrOfType<DistributeLayoutAttr>(attrName))
+    owner->setAttr(attrName, layout);
+}
+
+// Reads the temporary layout attribute attached to the owner of an operand or
+// result. Yields a null attribute when none of the expected type is present.
+template <typename T, typename>
+xegpu::DistributeLayoutAttr
+xegpu::getTemporaryLayout(const T &operandOrResult) {
+  Operation *owner = operandOrResult.getOwner();
+  std::string attrName = xegpu::getTemporaryLayoutName(operandOrResult);
+
+  // getAttrOfType already returns null for a missing or differently-typed
+  // attribute, so no separate presence check is needed.
+  return owner->getAttrOfType<xegpu::DistributeLayoutAttr>(attrName);
+}
+
+// Explicit instantiations for OpResult and OpOperand.
+template xegpu::DistributeLayoutAttr
+xegpu::getTemporaryLayout<mlir::OpResult>(const OpResult &result);
+template xegpu::DistributeLayoutAttr
+xegpu::getTemporaryLayout<mlir::OpOperand>(const OpOperand &operand);
+
+// Attaches `layout` as the temporary layout attribute of an operand or result.
+// A null `layout` or an already-present layout attribute makes this a no-op.
+template <typename T, typename>
+void xegpu::setTemporaryLayout(const T &operandOrResult,
+                               const xegpu::DistributeLayoutAttr layout) {
+  if (!layout)
+    return;
+  Operation *owner = operandOrResult.getOwner();
+  std::string attrName = xegpu::getTemporaryLayoutName(operandOrResult);
+  if (!owner->hasAttrOfType<xegpu::DistributeLayoutAttr>(attrName))
+    owner->setAttr(attrName, layout);
+}
+
+template void xegpu::setTemporaryLayout<mlir::OpResult>(
+    const mlir::OpResult &result,
+    const mlir::xegpu::DistributeLayoutAttr layout);
+
+template void xegpu::setTemporaryLayout<mlir::OpOperand>(
+    const mlir::OpOperand &operand,
+    const mlir::xegpu::DistributeLayoutAttr layout);
>From 72d1c59de6f30ed9bc1b30df12d85d12e365cfd0 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Fri, 30 Jan 2026 04:10:10 +0000
Subject: [PATCH 21/35] clean up
---
.../Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp | 2 +-
.../XeGPU/Transforms/XeGPULayoutImpls.cpp | 350 +++---------------
.../XeGPU/Transforms/XeGPUPropagateLayout.cpp | 119 +-----
3 files changed, 68 insertions(+), 403 deletions(-)
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 8efbb0702f0d3..8a06271eadd84 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -1190,7 +1190,7 @@ struct ConvertXeGPUToXeVMPass
return {};
auto input = inputs.front();
if (auto vecTy = dyn_cast<VectorType>(input.getType())) {
- if (vecTy.getRank() == 1 && vecTy.getNumElements() == 1) {
+ if (vecTy.getNumElements() == 1) {
// If the vector has a single element, return the element type.
Value cast =
vector::ExtractOp::create(builder, loc, input, 0).getResult();
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpls.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpls.cpp
index 1b0ec6f11b5f9..45c6563ff966a 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpls.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpls.cpp
@@ -189,10 +189,6 @@ xegpu::inferBitCastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
// only adjust the sg_data, inst_data, lane_data accordingly
// based on the bitwidth ratio between source and result element type
- llvm::dbgs() << "inferBitCastSourceLayout: resElemTyBitWidth="
- << resElemTyBitWidth
- << ", srcElemTyBitWidth=" << srcElemTyBitWidth << "\n";
-
SmallVector<int64_t> sgData = resLayout.getEffectiveSgDataAsInt();
SmallVector<int64_t> instData = resLayout.getEffectiveInstDataAsInt();
SmallVector<int64_t> laneData = resLayout.getEffectiveLaneDataAsInt();
@@ -204,30 +200,16 @@ xegpu::inferBitCastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
int64_t laneDataValue = -1;
int64_t dim = resLayout.getRank() - 1;
- llvm::dbgs() << "inferBitCastSourceLayout: dim=" << dim
- << ", sgData.size()=" << sgData.size()
- << ", instData.size()=" << instData.size()
- << ", laneData.size()=" << laneData.size() << "\n";
-
if (srcElemTyBitWidth <= resElemTyBitWidth) {
int bitWidthRatio = resElemTyBitWidth / srcElemTyBitWidth;
- llvm::dbgs() << "inferBitCastSourceLayout: srcElemTyBitWidth >= "
- "resElemTyBitWidth, bitWidthRatio="
- << bitWidthRatio << "\n";
if (sgDataSize)
sgDataValue = sgData[sgDataSize - 1] * bitWidthRatio;
if (instDataSize)
instDataValue = instData[instDataSize - 1] * bitWidthRatio;
if (laneDataSize)
laneDataValue = laneData[laneDataSize - 1] * bitWidthRatio;
- llvm::dbgs() << "inferBitCastSourceLayout: sgDataValue=" << sgDataValue
- << ", instDataValue=" << instDataValue
- << ", laneDataValue=" << laneDataValue << "\n";
} else {
int bitWidthRatio = srcElemTyBitWidth / resElemTyBitWidth;
- llvm::dbgs() << "inferBitCastSourceLayout: srcElemTyBitWidth < "
- "resElemTyBitWidth, bitWidthRatio="
- << bitWidthRatio << "\n";
if (sgDataSize) {
assert((sgData[sgDataSize - 1] % bitWidthRatio) == 0 &&
"sgData not divisible by bitWidthRatio");
@@ -243,9 +225,6 @@ xegpu::inferBitCastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
"laneData not divisible by bitWidthRatio");
laneDataValue = laneData[laneDataSize - 1] / bitWidthRatio;
}
- llvm::dbgs() << "inferBitCastSourceLayout: sgDataValue=" << sgDataValue
- << ", instDataValue=" << instDataValue
- << ", laneDataValue=" << laneDataValue << "\n";
}
// Now set only instData and laneData, preserving sgData
@@ -253,7 +232,6 @@ xegpu::inferBitCastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
finalSrcLayout =
resLayout.setDimData(dim, sgDataValue, instDataValue, laneDataValue);
- llvm::dbgs() << "inferBitCastSourceLayout: returning finalSrcLayout\n";
return finalSrcLayout;
}
@@ -351,40 +329,32 @@ xegpu::inferShapeCastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
};
if (checkCombineToInnerMostDim(srcShape, resShape)) {
- const int subgroupSize = 16; // assuming 16 lanes per subgroup
int srcShapeSize = srcShape.size();
+ int resShapeSize = resShape.size();
auto context = resLayout.getContext();
auto resInstData = resLayout.getEffectiveInstDataAsInt();
auto resLaneLayout = resLayout.getEffectiveLaneLayoutAsInt();
auto resLaneData = resLayout.getEffectiveLaneDataAsInt();
+ if (resShapeSize == 2)
+ assert(resInstData[0] == 1 &&
+ "only innermost dim can have data combined from sources");
+
if (resInstData.size() != 0) {
- if (resInstData.size() == 2)
- assert(resInstData[0] == 1 &&
- "only innermost dim can have inst_data for combine-to-1d");
- // construct source inst_data layout like [1, ..., 1, subgroupSize]
SmallVector<int> inferredInstData(srcShapeSize, 1);
- inferredInstData[srcShapeSize - 1] = resInstData[resInstData.size() - 1];
+ inferredInstData[srcShapeSize - 1] = resInstData[resShapeSize - 1];
return xegpu::LayoutAttr::get(context, inferredInstData);
}
if (resLaneLayout.size() != 0) {
- if (resInstData.size() == 2)
- assert(resInstData[0] == 1 &&
- "only innermost dim can have inst_data for combine-to-1d");
- // construct source lane_layout like [1, ..., 1, subgroupSize]
SmallVector<int> inferredLaneLayout(srcShapeSize, 1);
SmallVector<int> inferredLaneData(srcShapeSize, 1);
- inferredLaneLayout[srcShapeSize - 1] =
- resLaneLayout[resLaneLayout.size() - 1];
-
- inferredLaneData[srcShapeSize - 1] =
- resLaneLayout[resLaneLayout.size() - 1];
+ inferredLaneLayout[srcShapeSize - 1] = resLaneLayout[resShapeSize - 1];
+ inferredLaneData[srcShapeSize - 1] = resLaneData[resShapeSize - 1];
return xegpu::LayoutAttr::get(context, inferredLaneLayout,
inferredLaneData);
}
}
-
- // TODO: Complete implementation for other shape cast scenarios
+ assert("running into unsupported shape cast scenarios");
return nullptr;
}
@@ -394,18 +364,12 @@ xegpu::inferShapeCastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
/// Algorithm Overview:
/// This function attempts to construct a source layout that, when sliced along
/// reduction dimensions, produces a result layout compatible with the
-/// consumer's preferred layout. This minimizes data redistribution overhead
+/// consumer layout. This minimizes data redistribution overhead
/// between the reduction operation and its consumer.
///
-/// Strategy:
-/// 1. First, check if the consumer's preferred layout is already a SliceAttr
-/// with matching reduction dimensions. If so, use its parent layout directly
-/// and adjust the sg_data/inst_data acccording to source shape.
-/// 2. If step 1 fails, construct a new layout by distributing
-/// workgroup/subgroup resources across dimensions. It will try to align
-/// with the consumer's sg_layout for non-reduction dimensions, so that the
-/// reduced result stays with the same subgroup distribution as expected by
-/// the consumer.
+/// It first tries to align the source layout's subgroup layout and data with
+/// the consumer's layout on non-reduction dimensions. Then, it distributes
+/// remaining subgroups across reduction dimensions.
xegpu::SliceAttr xegpu::setupMultiReductionResultLayout(
xegpu::LayoutKind layoutKind, VectorType srcVecTy,
@@ -457,162 +421,52 @@ xegpu::SliceAttr xegpu::setupMultiReductionResultLayout(
defaultLaneData[srcShapeSize - 2] = vectorSize;
defaultLaneData[srcShapeSize - 1] = 1;
- // Strategy 1: Try to preserve the consumer's slice layout structure
- // If the consumer already expects a slice layout with the same reduction
- // dims, we can directly use its parent layout as our source layout. This
- // ensures best alignment and avoids any data movement across subgroups
- // and lanes.
-
- auto canPreserveSliceLayout =
- [&](ArrayRef<int64_t> srcShape, SmallVector<int64_t> reductionDims,
- DistributeLayoutAttr consumerLayout) -> bool {
- // Verify that the consumer layout can be adapted to the source shape:
- // For each dimension, check if srcShape[i] is divisible by the parent's
- // sg_layout[i] and lane_layout[i]. If so, sg_layout and lane_layout can be
- // reused.
-
- if (!consumerSliceLayout)
- return false;
- if (!consumerSliceLayout.getDims().asArrayRef().equals(reductionDims))
- return false;
- xegpu::DistributeLayoutAttr parentLayout = consumerSliceLayout.getParent();
- if (parentLayout.getRank() != srcShapeSize)
- return false;
-
- SmallVector<int64_t> parentSgLayout =
- parentLayout.getEffectiveSgLayoutAsInt();
- SmallVector<int64_t> parentLaneLayout =
- parentLayout.getEffectiveLaneLayoutAsInt();
-
- if (parentSgLayout.size() != static_cast<size_t>(srcShapeSize))
- return false;
- if (parentLaneLayout.size() != static_cast<size_t>(srcShapeSize))
- return false;
- for (int i = 0; i < srcShapeSize; i++) {
- if (srcShape[i] % parentSgLayout[i] != 0)
- return false;
- if (instData[i] % parentLaneLayout[i] != 0)
- return false;
- }
- return true;
- };
-
- if (canPreserveSliceLayout(srcShape, reductionDims, consumerLayout)) {
- // for each slice dim in source shape, if the dim size is different than the
- // result shape, try to adjust the sg_data/inst_data accordingly.
- SmallVector<int64_t> parentSgLayout =
- consumerSliceLayout.getEffectiveSgLayoutAsInt();
- SmallVector<int64_t> parentLaneLayout =
- consumerSliceLayout.getEffectiveLaneLayoutAsInt();
- SmallVector<int64_t> parentLaneData =
- consumerSliceLayout.getEffectiveLaneDataAsInt();
+ SmallVector<int64_t> consumerSgLayout =
+ consumerLayout.getEffectiveSgLayoutAsInt();
+ SmallVector<int64_t> consumerLaneLayout =
+ consumerLayout.getEffectiveLaneLayoutAsInt();
+ int remainingSgCount = workgroupSize;
+ int consumerSgId;
- switch (layoutKind) {
- case xegpu::LayoutKind::Subgroup:
- sgLayout = parentSgLayout;
- for (int i = 0; i < srcShapeSize; i++)
- if (llvm::is_contained(reductionDims, i))
- sgData[i] = srcShape[i] / sgLayout[i];
- else
- sgData[i] = parentSgLayout[i];
- break;
- case xegpu::LayoutKind::InstData:
- for (int i = 0; i < srcShapeSize; i++)
- instData[i] = std::min(defaultInstData[i], srcShape[i]);
- break;
- case xegpu::LayoutKind::Lane:
- laneLayout = parentLaneLayout;
- for (int i = 0; i < srcShapeSize; i++) {
- assert((srcShape[i] % laneLayout[i] == 0) &&
- "source shape not divisible by lane layout");
- laneData[i] = srcShape[i] / laneLayout[i];
+ switch (layoutKind) {
+ case xegpu::LayoutKind::Subgroup:
+ consumerSgId = consumerSgLayout.size() - 1;
+ // For non-reduction dimensions, try to match consumer's sg_layout
+ // This ensures the result after reduction has the expected distribution
+ for (int i = srcShapeSize - 1; i >= 0; i--)
+ if (!llvm::is_contained(reductionDims, i) && consumerSgId >= 0) {
+ sgLayout[i] = consumerSgLayout[consumerSgId];
+ assert((srcShape[i] % sgLayout[i] == 0) &&
+ "source shape not divisible by consumer sg_layout");
+ sgData[i] = srcShape[i] / sgLayout[i];
+ remainingSgCount /= sgLayout[i];
+ consumerSgId--;
}
- break;
- default:
- llvm_unreachable("unsupported layout kind");
- }
- } else {
- // Strategy 2: Construct a new layout aligned with consumer's sg_layout for
- // the result (non-reduction dims) then distribute remaining subgroups
- // across reduced dimensions
-
- SmallVector<int64_t> consumerSgLayout =
- consumerLayout.getEffectiveSgLayoutAsInt();
- SmallVector<int64_t> consumerLaneLayout =
- consumerLayout.getEffectiveLaneLayoutAsInt();
- int remainingSgCount = workgroupSize;
- int remainingLaneCount = subgroupSize;
- int consumerSgId, consumerLaneId;
-
- switch (layoutKind) {
- case xegpu::LayoutKind::Subgroup:
- consumerSgId = consumerSgLayout.size() - 1;
- // For non-reduction dimensions, try to match consumer's sg_layout
- // This ensures the result after reduction has the expected distribution
- for (int i = srcShapeSize - 1; i >= 0; i--)
- if (!llvm::is_contained(reductionDims, i) && consumerSgId >= 0) {
- sgLayout[i] = consumerSgLayout[consumerSgId];
- assert((srcShape[i] % sgLayout[i] == 0) &&
- "source shape not divisible by consumer sg_layout");
- sgData[i] = srcShape[i] / sgLayout[i];
- remainingSgCount /= sgLayout[i];
- consumerSgId--;
- }
- // Second pass: Distribute remaining subgroups across unhandled dimensions
- // This handles reduction dimensions that don't necessarily align with
- // consumer
- for (int i = srcShapeSize - 1; i >= 0; i--)
- if (llvm::is_contained(reductionDims, i)) {
- sgLayout[i] = std::min((srcShape[i] / defaultLaneLayout[i]),
- static_cast<int64_t>(remainingSgCount));
- assert((srcShape[i] % sgLayout[i] == 0) &&
- "source shape not divisible by consumer sg_layout");
- sgData[i] = srcShape[i] / sgLayout[i];
- remainingSgCount /= sgLayout[i];
- }
- assert(remainingSgCount == 1 &&
- "not all subgroups have been distributed");
- break;
- case xegpu::LayoutKind::InstData:
- for (int i = 0; i < srcShapeSize; i++) {
- instData[i] = std::min(defaultInstData[i], srcShape[i]);
- llvm::dbgs() << "MultiReductionOp: Strategy 2: instData [" << i
- << "] = " << instData[i] << "\n";
+ // Second pass: Distribute remaining subgroups across unhandled dimensions
+ // This handles reduction dimensions that don't necessarily align with
+ // consumer
+ for (int i = srcShapeSize - 1; i >= 0; i--)
+ if (llvm::is_contained(reductionDims, i)) {
+ sgLayout[i] = std::min((srcShape[i] / defaultLaneLayout[i]),
+ static_cast<int64_t>(remainingSgCount));
+ assert((srcShape[i] % sgLayout[i] == 0) &&
+ "source shape not divisible by consumer sg_layout");
+ sgData[i] = srcShape[i] / sgLayout[i];
+ remainingSgCount /= sgLayout[i];
}
- break;
- case xegpu::LayoutKind::Lane:
- laneLayout = defaultLaneLayout;
- laneData = defaultLaneData;
- // consumerLaneId = consumerLaneLayout.size() - 1;
- // // For non-reduction dimensions, try to match consumer's lane_layout
- // // This ensures the result after reduction has the expected
- // distribution for (int i = 0; i < srcShapeSize; i++)
- // if (!llvm::is_contained(reductionDims, i) && consumerLaneId >= 0) {
- // laneLayout[i] = consumerLaneLayout[consumerLaneId];
- // assert((srcShape[i] % laneLayout[i] == 0) &&
- // "source shape not divisible by consumer lane_layout");
- // laneData[i] = srcShape[i] / laneLayout[i];
- // remainingLaneCount /= laneLayout[i];
- // consumerLaneId--;
- // }
- // for (int i = 0; i < srcShapeSize; i++)
- // if (llvm::is_contained(reductionDims, i)) {
- // laneLayout[i] =
- // std::min(srcShape[i],
- // static_cast<int64_t>(remainingLaneCount));
- // assert((srcShape[i] % laneLayout[i] == 0) &&
- // "source shape not divisible by consumer lane_layout");
- // laneData[i] = srcShape[i] / laneLayout[i];
- // laneData[i] = std::min(laneData[i],
- // static_cast<int64_t>(vectorSize)); remainingLaneCount /=
- // laneLayout[i];
- // }
- // assert(remainingLaneCount == 1 && "not all lanes have been
- // distributed");
- break;
- default:
- llvm_unreachable("unsupported layout kind");
+ assert(remainingSgCount == 1 && "not all subgroups have been distributed");
+ break;
+ case xegpu::LayoutKind::InstData:
+ for (int i = 0; i < srcShapeSize; i++) {
+ instData[i] = std::min(defaultInstData[i], srcShape[i]);
}
+ break;
+ case xegpu::LayoutKind::Lane:
+ laneLayout = defaultLaneLayout;
+ laneData = defaultLaneData;
+ break;
+ default:
+ llvm_unreachable("unsupported layout kind");
}
SmallVector<int32_t> sgLayout32(sgLayout.begin(), sgLayout.end());
@@ -650,21 +504,9 @@ xegpu::DistributeLayoutAttr xegpu::setupBitCastResultLayout(
xegpu::LayoutKind layoutKind, VectorType srcVecTy, VectorType resVecTy,
DistributeLayoutAttr consumerLayout, const xegpu::uArch::uArch *uArch) {
- // print out consumerLayout for debugging
- llvm::dbgs() << "setupBitCastResultLayout: consumerLayout=";
- consumerLayout.print(llvm::dbgs());
- llvm::dbgs() << "\n";
-
- llvm::dbgs() << "setupBitCastResultLayout: layoutKind="
- << static_cast<int>(layoutKind) << "\n";
-
int srcElemTyBitWidth = srcVecTy.getElementType().getIntOrFloatBitWidth();
int resElemTyBitWidth = resVecTy.getElementType().getIntOrFloatBitWidth();
- llvm::dbgs() << "setupBitCastResultLayout: srcElemTyBitWidth="
- << srcElemTyBitWidth
- << ", resElemTyBitWidth=" << resElemTyBitWidth << "\n";
-
ArrayRef<int64_t> srcShape = srcVecTy.getShape();
SmallVector<int64_t> sgData = consumerLayout.getEffectiveSgDataAsInt();
SmallVector<int64_t> instData = consumerLayout.getEffectiveInstDataAsInt();
@@ -674,15 +516,7 @@ xegpu::DistributeLayoutAttr xegpu::setupBitCastResultLayout(
int64_t instDataValue = -1;
int64_t laneDataValue = -1;
- llvm::dbgs() << "setupBitCastResultLayout: srcShape.size()="
- << srcShape.size() << ", dim=" << dim << "\n";
- llvm::dbgs() << "setupBitCastResultLayout: sgData.size()=" << sgData.size()
- << ", instData.size()=" << instData.size()
- << ", laneData.size()=" << laneData.size() << "\n";
-
const int subgroupSize = uArch->getSubgroupSize();
- llvm::dbgs() << "setupBitCastResultLayout: subgroupSize=" << subgroupSize
- << "\n";
if (srcElemTyBitWidth > resElemTyBitWidth) {
// When casting to a smaller bitwidth, multiply the result layout
@@ -705,7 +539,7 @@ xegpu::DistributeLayoutAttr xegpu::setupBitCastResultLayout(
while ((instDataValue <= srcShape[dim]) &&
(instDataValue % (innermostDimLaneLayout * bitWidthRatio) != 0))
instDataValue *= 2;
- assert(srcShape[dim] % instDataValue == 0 &&
+ assert((srcShape[dim] % instDataValue) == 0 &&
"srcShape, instData, and lanelayout for innermost must be 2^n !");
break;
case xegpu::LayoutKind::Lane:
@@ -721,14 +555,8 @@ xegpu::DistributeLayoutAttr xegpu::setupBitCastResultLayout(
}
// Now set only instData and laneData, preserving sgData
xegpu::DistributeLayoutAttr resLayout;
- llvm::dbgs() << "setupBitCastResultLayout: Setting dimension data - dim="
- << dim << ", sgDataValue=" << sgDataValue
- << ", instDataValue=" << instDataValue
- << ", laneDataValue=" << laneDataValue << "\n";
resLayout = consumerLayout.setDimData(dim, sgDataValue, instDataValue,
laneDataValue);
- llvm::dbgs()
- << "setupBitCastResultLayout: resLayout created successfully\n";
return resLayout;
}
return consumerLayout;
@@ -804,10 +632,10 @@ xegpu::setupInsertStridedSliceResultLayout(xegpu::LayoutKind layoutKind,
auto context = resVectorTy.getContext();
auto resShape = resVectorTy.getShape();
int resShapeSize = resShape.size();
+
SmallVector<int> defaultInstData(resShapeSize, 1);
SmallVector<int> defaultLaneLayout(resShapeSize, 1);
SmallVector<int> defaultLaneData(resShapeSize, 1);
-
defaultInstData[resShapeSize - 1] = subgroupSize;
defaultLaneLayout[resShapeSize - 1] = subgroupSize;
@@ -833,7 +661,6 @@ xegpu::DistributeLayoutAttr xegpu::inferInsertStridedSliceSourceLayout(
xegpu::DistributeLayoutAttr resLayout, ArrayRef<int64_t> resShape,
ArrayRef<int64_t> srcShape) {
- const int subgroupSize = 16; // assuming 16 lanes per subgroup
int srcShapeSize = srcShape.size();
int resShapeSize = resShape.size();
int dimDiff = resShapeSize - srcShapeSize;
@@ -883,28 +710,14 @@ xegpu::DistributeLayoutAttr xegpu::setupLoadGatherAnchorLayout(
LayoutKind layoutKind, VectorType resVecTy, int chunkSize,
DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch) {
- llvm::dbgs() << "setupLoadGatherAnchorLayout: layoutKind="
- << static_cast<int>(layoutKind) << ", chunkSize=" << chunkSize
- << "\n";
-
xegpu::DistributeLayoutAttr requiredLayout;
const int subgroupSize = uArch->getSubgroupSize();
- const int spirVectorSize = 16; // vector size from SPRIV vector restriction
auto resShape = resVecTy.getShape();
int resShapeSize = resShape.size();
SmallVector<int> instData(resShapeSize);
auto context = resVecTy.getContext();
- llvm::dbgs() << "setupLoadGatherAnchorLayout: resVecTy.getRank()="
- << resVecTy.getRank() << ", resShape=[";
- for (size_t i = 0; i < resShape.size(); ++i) {
- if (i > 0)
- llvm::dbgs() << ", ";
- llvm::dbgs() << resShape[i];
- }
- llvm::dbgs() << "]\n";
-
const auto *uArchInstruction =
dyn_cast<xegpu::uArch::StoreScatterInstruction>(
uArch->getInstruction(xegpu::uArch::InstructionKind::StoreScatter));
@@ -913,48 +726,31 @@ xegpu::DistributeLayoutAttr xegpu::setupLoadGatherAnchorLayout(
consumerLayout.getEffectiveInstDataAsInt();
SmallVector<int32_t> instData32;
- llvm::dbgs() << "setupLoadGatherAnchorLayout: consumerInstData=[";
- for (size_t i = 0; i < consumerInstData.size(); ++i) {
- if (i > 0)
- llvm::dbgs() << ", ";
- llvm::dbgs() << consumerInstData[i];
- }
- llvm::dbgs() << "]\n";
-
switch (layoutKind) {
case xegpu::LayoutKind::Subgroup:
- llvm::dbgs() << "setupLoadGatherAnchorLayout: LayoutKind::Subgroup\n";
requiredLayout = consumerLayout;
break;
case xegpu::LayoutKind::InstData:
- llvm::dbgs() << "setupLoadGatherAnchorLayout: LayoutKind::InstData\n";
if (resVecTy.getRank() == 1) {
instData[0] = subgroupSize;
- llvm::dbgs() << "setupLoadGatherAnchorLayout: 1D case, instData[0]="
- << instData[0] << "\n";
} else {
assert((resVecTy.getRank() == 2) && "StoreScatterOp can access 2D tensor "
"tile at maximum at subgroup level.");
if (chunkSize == 1) {
instData[0] = 1;
instData[1] = subgroupSize;
- llvm::dbgs() << "setupLoadGatherAnchorLayout: chunkSize==1, instData=["
- << instData[0] << ", " << instData[1] << "]\n";
} else {
instData[0] = subgroupSize;
instData[1] = std::min(static_cast<int>(resShape[1]),
uArchInstruction->getMaxLaneLoadStoreSize());
instData[1] =
std::min(instData[1], static_cast<int>(consumerInstData[1]));
- llvm::dbgs() << "setupLoadGatherAnchorLayout: chunkSize>1, instData=["
- << instData[0] << ", " << instData[1] << "]\n";
}
}
requiredLayout = xegpu::LayoutAttr::get(
context, DenseI32ArrayAttr::get(context, instData));
break;
case xegpu::LayoutKind::Lane:
- llvm::dbgs() << "setupLoadGatherAnchorLayout: LayoutKind::Lane\n";
if (chunkSize == 1)
requiredLayout =
getDefaultLaneLayoutAttr(context, resVecTy.getRank(), uArch);
@@ -964,18 +760,13 @@ xegpu::DistributeLayoutAttr xegpu::setupLoadGatherAnchorLayout(
assert(resShape[1] <= static_cast<int64_t>(
uArchInstruction->getMaxLaneLoadStoreSize()) &&
"StoreScatterOp lane size exceeds max lane load/store size.");
- llvm::dbgs() << "setupLoadGatherAnchorLayout: Creating LayoutAttr with "
- << "laneLayout=[" << subgroupSize << ", 1], "
- << "laneData=[1, " << resShape[1] << "]\n";
requiredLayout = xegpu::LayoutAttr::get(
context, {subgroupSize, 1}, {1, static_cast<int>(resShape[1])});
- llvm::dbgs() << "setupLoadGatherAnchorLayout: Created requiredLayout\n";
}
break;
default:
llvm_unreachable("unsupported layout kind");
}
- llvm::dbgs() << "setupLoadGatherAnchorLayout: returning requiredLayout\n";
return requiredLayout;
}
@@ -983,27 +774,13 @@ xegpu::DistributeLayoutAttr
xegpu::setupStoreScatterAnchorLayout(LayoutKind layoutKind, VectorType srcVecTy,
int chunkSize, const uArch::uArch *uArch) {
- llvm::dbgs() << "setupStoreScatterAnchorLayout: layoutKind="
- << static_cast<int>(layoutKind) << ", chunkSize=" << chunkSize
- << "\n";
-
xegpu::DistributeLayoutAttr requiredLayout;
const int subgroupSize = uArch->getSubgroupSize();
- const int spirVectorSize = 16; // vector size from SPRIV vector restriction
auto srcShape = srcVecTy.getShape();
int srcShapeSize = srcShape.size();
SmallVector<int> instData(srcShapeSize);
- llvm::dbgs() << "setupStoreScatterAnchorLayout: srcVecTy.getRank()="
- << srcVecTy.getRank() << ", srcShape=[";
- for (size_t i = 0; i < srcShape.size(); ++i) {
- if (i > 0)
- llvm::dbgs() << ", ";
- llvm::dbgs() << srcShape[i];
- }
- llvm::dbgs() << "]\n";
-
const auto *uArchInstruction =
dyn_cast<xegpu::uArch::StoreScatterInstruction>(
uArch->getInstruction(xegpu::uArch::InstructionKind::StoreScatter));
@@ -1011,39 +788,29 @@ xegpu::setupStoreScatterAnchorLayout(LayoutKind layoutKind, VectorType srcVecTy,
switch (layoutKind) {
case xegpu::LayoutKind::Subgroup:
- llvm::dbgs() << "setupStoreScatterAnchorLayout: LayoutKind::Subgroup\n";
assert(
- false &&
+ true &&
"subgroup layout assignment not supported yet for store scatter op.");
break;
case xegpu::LayoutKind::InstData:
- llvm::dbgs() << "setupStoreScatterAnchorLayout: LayoutKind::InstData\n";
if (srcVecTy.getRank() == 1) {
instData[0] = subgroupSize;
- llvm::dbgs() << "setupStoreScatterAnchorLayout: 1D case, instData[0]="
- << instData[0] << "\n";
} else {
assert((srcVecTy.getRank() <= 2) && "StoreScatterOp can access 2D tensor "
"tile at maximum at subgroup level.");
if (chunkSize == 1) {
instData[0] = 1;
instData[1] = subgroupSize;
- llvm::dbgs()
- << "setupStoreScatterAnchorLayout: chunkSize==1, instData=["
- << instData[0] << ", " << instData[1] << "]\n";
} else {
instData[0] = subgroupSize;
instData[1] = std::min(static_cast<int>(srcShape[1]),
uArchInstruction->getMaxLaneLoadStoreSize());
- llvm::dbgs() << "setupStoreScatterAnchorLayout: chunkSize>1, instData=["
- << instData[0] << ", " << instData[1] << "]\n";
}
}
requiredLayout = xegpu::LayoutAttr::get(
context, DenseI32ArrayAttr::get(context, instData));
break;
case xegpu::LayoutKind::Lane:
- llvm::dbgs() << "setupStoreScatterAnchorLayout: LayoutKind::Lane\n";
if (chunkSize == 1)
requiredLayout =
getDefaultLaneLayoutAttr(context, srcVecTy.getRank(), uArch);
@@ -1060,6 +827,5 @@ xegpu::setupStoreScatterAnchorLayout(LayoutKind layoutKind, VectorType srcVecTy,
default:
llvm_unreachable("unsupported layout kind");
}
- llvm::dbgs() << "setupStoreScatterAnchorLayout: returning requiredLayout\n";
return requiredLayout;
}
\ No newline at end of file
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index c148ba459527e..6385995317f38 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -671,61 +671,29 @@ void LayoutInfoPropagation::visitVectorMultiReductionOp(
if (!resLayoutInfo.isAssigned())
return;
- // debug print resLayoutInfo
- LLVM_DEBUG(DBGS() << "visitVectorMultiReductionOp: resLayoutInfo = ";
- resLayoutInfo.print(llvm::dbgs()); llvm::dbgs() << "\n");
-
VectorType sourceTy =
llvm::dyn_cast<VectorType>(reduction.getSourceVectorType());
SmallVector<int64_t> reductionDims(reduction.getReductionDims().begin(),
reduction.getReductionDims().end());
- LLVM_DEBUG(DBGS() << "visitVectorMultiReductionOp: srcShape = [";
- for (auto dim
- : sourceTy.getShape()) llvm::dbgs()
- << dim << " ";
- llvm::dbgs() << "]\n");
- LLVM_DEBUG(DBGS() << "visitVectorMultiReductionOp: reductionDims = [";
- for (auto dim
- : reductionDims) llvm::dbgs()
- << dim << " ";
- llvm::dbgs() << "]\n");
-
+ auto uArch = getUArch(xegpu::getChipStr(reduction).value_or(""));
auto consumerLayoutAttr =
dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
- LLVM_DEBUG(DBGS() << "visitVectorMultiReductionOp: consumerLayoutAttr = "
- << consumerLayoutAttr << "\n");
- // The required result layout represents the layout requirements
- // for the operation it is recorded to anchor layout or temporary layout.
+ // The result layout represents the layout requirements of the operation.
+ // It is recorded as the anchor layout or a temporary layout.
// it must be honored for current op and may conflict with the layout
// propagated from consumer op, the conflict is resolved in later phase by
// converting the required result layout to the consumer layout
- auto uArch = getUArch(xegpu::getChipStr(reduction).value_or(""));
auto requiredResLayoutAttr = xegpu::setupMultiReductionResultLayout(
layoutKind, sourceTy, consumerLayoutAttr, reductionDims, uArch);
- LLVM_DEBUG(DBGS() << "visitVectorMultiReductionOp: requiredResLayoutAttr = "
- << requiredResLayoutAttr << "\n");
-
// resLayoutInfo.set(requiredResLayoutAttr);
xegpu::setTemporaryLayout(reduction->getResult(0), requiredResLayoutAttr);
- // debug print resLayoutInfo
- LLVM_DEBUG(
- DBGS() << "visitVectorMultiReductionOp: after change resLayoutInfo = ";
- resLayoutInfo.print(llvm::dbgs()); llvm::dbgs() << "\n");
-
- LLVM_DEBUG(DBGS() << "visitVectorMultiReductionOp: after change "
- "results[0]->getValue() = ";
- results[0]->getValue().print(llvm::dbgs()); llvm::dbgs() << "\n");
-
// derive the source layout from the dominant layout and reduction dims
auto srcLayoutAttr = xegpu::inferMultiReductionSourceLayout(
requiredResLayoutAttr, reductionDims);
- LLVM_DEBUG(DBGS() << "visitVectorMultiReductionOp: srcLayoutAttr = "
- << srcLayoutAttr << "\n");
-
propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
// Accumulator should have the same layout as the result.
propagateIfChanged(operands[1],
@@ -1291,7 +1259,6 @@ void LayoutInfoPropagation::visitCreateDescOp(
void LayoutInfoPropagation::visitStoreScatterOp(
xegpu::StoreScatterOp storeScatter, ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results) {
- LLVM_DEBUG(DBGS() << "visitStoreScatterOp: Processing store scatter op\n");
xegpu::DistributeLayoutAttr requiredAnchorLayoutAttr;
xegpu::DistributeLayoutAttr anchorLayoutAttr = storeScatter.getLayoutAttr();
@@ -1302,19 +1269,9 @@ void LayoutInfoPropagation::visitStoreScatterOp(
llvm::dyn_cast<VectorType>(storeScatter.getMask().getType());
int chunkSize = storeScatter.getChunkSize().value_or(1);
- LLVM_DEBUG(DBGS() << "visitStoreScatterOp: anchorLayoutAttr = "
- << anchorLayoutAttr << "\n");
- LLVM_DEBUG(DBGS() << "visitStoreScatterOp: subgroupSize = " << subgroupSize
- << "\n");
- LLVM_DEBUG(DBGS() << "visitStoreScatterOp: chunkSize = " << chunkSize
- << "\n");
- LLVM_DEBUG(DBGS() << "visitStoreScatterOp: srcVecTy = " << srcVecTy << "\n");
-
if (hasParamsOfLayoutKind(anchorLayoutAttr)) {
- LLVM_DEBUG(DBGS() << "visitStoreScatterOp: Using existing anchor layout\n");
requiredAnchorLayoutAttr = anchorLayoutAttr;
} else {
- LLVM_DEBUG(DBGS() << "visitStoreScatterOp: Setting up new anchor layout\n");
if (!srcVecTy) {
storeScatter.emitWarning("Not propagating, non-vector payload supplied.");
return;
@@ -1324,16 +1281,11 @@ void LayoutInfoPropagation::visitStoreScatterOp(
storeScatter.setLayoutAttr(requiredAnchorLayoutAttr);
}
- LLVM_DEBUG(DBGS() << "visitStoreScatterOp: requiredAnchorLayoutAttr = "
- << requiredAnchorLayoutAttr << "\n");
-
LayoutInfo srcLayoutInfo = LayoutInfo(requiredAnchorLayoutAttr);
auto maskLayoutAttr = requiredAnchorLayoutAttr;
// Special handling mask layout for chunked ops: Enforce the default xegpu 1D
// layout for mask.
if (chunkSize > 1) {
- LLVM_DEBUG(DBGS() << "visitStoreScatterOp: Setting mask layout for chunked "
- "operation\n");
if (layoutKind == xegpu::LayoutKind::InstData)
maskLayoutAttr =
xegpu::LayoutAttr::get(storeScatter->getContext(), {subgroupSize});
@@ -1346,71 +1298,39 @@ void LayoutInfoPropagation::visitStoreScatterOp(
}
LayoutInfo maskLayoutInfo = LayoutInfo(maskLayoutAttr);
- LLVM_DEBUG(DBGS() << "visitStoreScatterOp: maskLayoutAttr = "
- << maskLayoutAttr << "\n");
// Propagate the payload operand layout
- LLVM_DEBUG(DBGS() << "visitStoreScatterOp: Propagating payload layout\n");
propagateIfChanged(operands[0], operands[0]->meet(srcLayoutInfo));
// Propagate the destination (if tdesc) operand layout
- if (isa<xegpu::TensorDescType>(storeScatter.getDestType())) {
- LLVM_DEBUG(
- DBGS() << "visitStoreScatterOp: Propagating destination layout\n");
+ if (isa<xegpu::TensorDescType>(storeScatter.getDestType()))
propagateIfChanged(operands[1], operands[1]->meet(srcLayoutInfo));
- }
// Propagate the new layout to the mask and optional offset operand.
- LLVM_DEBUG(DBGS() << "visitStoreScatterOp: Propagating mask layout\n");
propagateIfChanged(operands[2], operands[2]->meet(maskLayoutInfo));
- if (storeScatter.getOffsets()) {
- LLVM_DEBUG(DBGS() << "visitStoreScatterOp: Propagating offset layout\n");
+ if (storeScatter.getOffsets())
propagateIfChanged(operands[3], operands[3]->meet(maskLayoutInfo));
- }
- LLVM_DEBUG(DBGS() << "visitStoreScatterOp: Done\n");
}
void LayoutInfoPropagation::visitLoadMatrixOp(
xegpu::LoadMatrixOp loadMatrixOp, ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results) {
- LLVM_DEBUG(DBGS() << "visitLoadMatrixOp: Processing load matrix op\n");
LayoutInfo resLayoutInfo = results[0]->getValue();
- if (!resLayoutInfo.isAssigned()) {
- LLVM_DEBUG(DBGS() << "visitLoadMatrixOp: Result layout not assigned\n");
- return;
- }
-
- LLVM_DEBUG(DBGS() << "visitLoadMatrixOp: resLayoutInfo = ";
- resLayoutInfo.print(llvm::dbgs()); llvm::dbgs() << "\n");
-
auto consumerLayoutAttr =
dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
- LLVM_DEBUG(DBGS() << "visitLoadMatrixOp: consumerLayoutAttr = "
- << consumerLayoutAttr << "\n");
-
xegpu::DistributeLayoutAttr anchorLayout = loadMatrixOp.getLayoutAttr();
- LLVM_DEBUG(DBGS() << "visitLoadMatrixOp: anchorLayout = " << anchorLayout
- << "\n");
-
// only need to set anchor layout, no need to porpagate to memdesc and
// offset
if (!hasParamsOfLayoutKind(anchorLayout)) {
- LLVM_DEBUG(DBGS() << "visitLoadMatrixOp: Setting up new anchor layout\n");
VectorType resVecTy =
llvm::cast<VectorType>(loadMatrixOp.getRes().getType());
assert(resVecTy.getRank() == 2 && "Expecting 2D vector for store matrix.");
- LLVM_DEBUG(DBGS() << "visitLoadMatrixOp: resVecTy = " << resVecTy << "\n");
auto uArch = getUArch(getChipStr(loadMatrixOp).value_or(""));
auto requiredAnchorLayoutAttr = xegpu::setupLoadMatrixAnchorLayout(
layoutKind, resVecTy, consumerLayoutAttr, uArch);
- LLVM_DEBUG(DBGS() << "visitLoadMatrixOp: requiredAnchorLayoutAttr = "
- << requiredAnchorLayoutAttr << "\n");
loadMatrixOp.setLayoutAttr(requiredAnchorLayoutAttr);
- } else {
- LLVM_DEBUG(DBGS() << "visitLoadMatrixOp: Using existing anchor layout\n");
}
- LLVM_DEBUG(DBGS() << "visitLoadMatrixOp: Done\n");
}
// Store matrix is a flavor of scattered store for 2D shapes.
@@ -1712,42 +1632,21 @@ void XeGPUPropagateLayoutPass::runOnOperation() {
if (auto opResult = dyn_cast<OpResult>(val)) {
Operation *defOp = opResult.getDefiningOp();
- LLVM_DEBUG(DBGS() << "Try op: " << *defOp << "\n");
if (auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(defOp)) {
- LLVM_DEBUG(DBGS() << "AnchorLayoutInterface found for op: " << *defOp
- << "\n");
auto anchorLayout = anchorOp.getAnchorLayout();
- LLVM_DEBUG(DBGS() << "Anchor layout: " << anchorLayout << "\n");
- LLVM_DEBUG(DBGS() << "Anchor layout is null: "
- << (anchorLayout == nullptr ? "true" : "false")
- << "\n");
- if (anchorLayout != nullptr) {
- LLVM_DEBUG(DBGS()
- << "Returning anchor layout: " << anchorLayout << "\n");
+ if (anchorLayout != nullptr)
return anchorLayout;
- }
- LLVM_DEBUG(DBGS() << "Anchor layout is null, continuing...\n");
}
-
xegpu::DistributeLayoutAttr requiredResLayoutAttr =
xegpu::getTemporaryLayout(opResult);
- LLVM_DEBUG(DBGS() << "Temporary layout for value: " << val << " is "
- << requiredResLayoutAttr << "\n");
- if (requiredResLayoutAttr != nullptr) {
- LLVM_DEBUG(DBGS() << "Returning temporary layout: "
- << requiredResLayoutAttr << "\n");
+ if (requiredResLayoutAttr != nullptr)
return requiredResLayoutAttr;
- }
}
xegpu::DistributeLayoutAttr layoutAttr =
cast<xegpu::DistributeLayoutAttr>(layout.get());
- LLVM_DEBUG(DBGS() << "Layout attr for value: " << val << " is "
- << layoutAttr << "\n");
- if (layout.isSliceLayout()) {
- LLVM_DEBUG(DBGS() << "Returning slice layout: " << layoutAttr << "\n");
+ if (layout.isSliceLayout())
return cast<xegpu::SliceAttr>(layoutAttr);
- }
- LLVM_DEBUG(DBGS() << "Returning layout attr: " << layoutAttr << "\n");
+
return cast<xegpu::LayoutAttr>(layoutAttr);
};
>From 4a945718a1a602c3269cbf6724e188a63d7298f8 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Fri, 30 Jan 2026 04:20:42 +0000
Subject: [PATCH 22/35] fix creatememdesc and memalloca distribution issue
---
.../XeGPU/Transforms/XeGPUSubgroupDistribute.cpp | 15 ++++++++++++---
1 file changed, 12 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 121810215913c..cc278c6813fc9 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -1908,7 +1908,12 @@ struct MemrefAllocaDistribution final : public gpu::WarpDistributionPattern {
using gpu::WarpDistributionPattern::WarpDistributionPattern;
LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
- OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<memref::AllocaOp>);
+ OpOperand *operand = getWarpResult(warpOp, [&](Operation *op) {
+ // Match only a yield operand produced by the memref::AllocaOp that is the
+ // last op before the yield; avoids duplicate copies with multiple users.
+ return llvm::IsaPred<memref::AllocaOp>(op) &&
+ warpOp.getTerminator()->getPrevNode() == op;
+ });
if (!operand)
return rewriter.notifyMatchFailure(
warpOp, "warp result is not a memref::Alloca op");
@@ -1930,8 +1935,12 @@ struct CreateMemDescDistribution final : public gpu::WarpDistributionPattern {
using gpu::WarpDistributionPattern::WarpDistributionPattern;
LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
- OpOperand *operand =
- getWarpResult(warpOp, llvm::IsaPred<xegpu::CreateMemDescOp>);
+ OpOperand *operand = getWarpResult(warpOp, [&](Operation *op) {
+ // Match only a yield operand produced by the xegpu::CreateMemDescOp that is
+ // the last op before the yield; avoids duplicate copies with multiple users.
+ return llvm::IsaPred<xegpu::CreateMemDescOp>(op) &&
+ warpOp.getTerminator()->getPrevNode() == op;
+ });
if (!operand)
return rewriter.notifyMatchFailure(
warpOp, "warp result is not a xegpu::CreateMemDesc op");
>From 20c4fbdd75663b9297b7cb621f2dc7886063e178 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Fri, 30 Jan 2026 06:34:24 +0000
Subject: [PATCH 23/35] add comments & polish
---
.../XeGPU/Transforms/XeGPULayoutImpls.h | 3 +
.../XeGPU/Transforms/XeGPULayoutImpls.cpp | 245 ++++++++----------
.../XeGPU/Transforms/XeGPUPropagateLayout.cpp | 36 ---
.../Transforms/XeGPUSubgroupDistribute.cpp | 5 -
4 files changed, 116 insertions(+), 173 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpls.h b/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpls.h
index 110496bb34fb3..e5f2b1be93cc2 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpls.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpls.h
@@ -90,6 +90,9 @@ DistributeLayoutAttr inferShapeCastSourceLayout(DistributeLayoutAttr resLayout,
ArrayRef<int64_t> resShape,
ArrayRef<int64_t> srcShape);
+/// Infers the source layout attribute for an insert strided slice operation
+/// given the result layout attribute, result shape, and source shape. Removes
+/// leading dimensions from the result layout to match the source shape size.
DistributeLayoutAttr
inferInsertStridedSliceSourceLayout(DistributeLayoutAttr resLayout,
ArrayRef<int64_t> resShape,
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpls.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpls.cpp
index 45c6563ff966a..8509f709c6fec 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpls.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpls.cpp
@@ -235,6 +235,48 @@ xegpu::inferBitCastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
return finalSrcLayout;
}
+/// Infers the source layout attribute for an insert strided slice operation
+/// given the result layout attribute, result shape, and source shape. Removes
+/// leading dimensions from the result layout to match the source shape size.
+xegpu::DistributeLayoutAttr xegpu::inferInsertStridedSliceSourceLayout(
+ xegpu::DistributeLayoutAttr resLayout, ArrayRef<int64_t> resShape,
+ ArrayRef<int64_t> srcShape) {
+
+ int srcShapeSize = srcShape.size();
+ int resShapeSize = resShape.size();
+ int dimDiff = resShapeSize - srcShapeSize;
+
+ // assert resLayout must be a plain layout
+ assert(isa<xegpu::LayoutAttr>(resLayout) &&
+ "insertStridedSlice result layout must be plain layout");
+ auto context = resLayout.getContext();
+ auto resInstData = resLayout.getEffectiveInstDataAsInt();
+ auto resLaneLayout = resLayout.getEffectiveLaneLayoutAsInt();
+ auto resLaneData = resLayout.getEffectiveLaneDataAsInt();
+
+ if (resInstData.size() != 0) {
+ SmallVector<int> inferredInstData(srcShapeSize);
+ // remove the initial dims in resInstData to match srcShapeSize
+ for (int i = 0; i < srcShapeSize; i++)
+ inferredInstData[i] = resInstData[i + dimDiff];
+ return xegpu::LayoutAttr::get(context, inferredInstData);
+ }
+
+ if (resLaneLayout.size() != 0) {
+ // construct source lane_layout like [1, ..., 1, subgroupSize]
+ SmallVector<int> inferredLaneLayout(srcShapeSize);
+ SmallVector<int> inferredLaneData(srcShapeSize);
+ // remove the initial dims in resLaneLayout/resLaneData to match the source
+ for (int i = 0; i < srcShapeSize; i++) {
+ inferredLaneLayout[i] = resLaneLayout[i + dimDiff];
+ inferredLaneData[i] = resLaneData[i + dimDiff];
+ }
+ return xegpu::LayoutAttr::get(context, inferredLaneLayout,
+ inferredLaneData);
+ }
+ return nullptr;
+}
+
/// Infers the source layout attribute for a shape cast operation given the
/// result layout attribute, result shape, and source shape.
xegpu::DistributeLayoutAttr
@@ -335,12 +377,12 @@ xegpu::inferShapeCastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
auto resInstData = resLayout.getEffectiveInstDataAsInt();
auto resLaneLayout = resLayout.getEffectiveLaneLayoutAsInt();
auto resLaneData = resLayout.getEffectiveLaneDataAsInt();
- if (resShapeSize == 2)
- assert(resInstData[0] == 1 &&
- "only innermost dim can have data combined from sources");
+ // get the layout info from the innermost dim of result layout
if (resInstData.size() != 0) {
SmallVector<int> inferredInstData(srcShapeSize, 1);
+ assert((resShapeSize == 1 || resInstData[0] == 1) &&
+ "only innermost dim can have data and instData layout");
inferredInstData[srcShapeSize - 1] = resInstData[resShapeSize - 1];
return xegpu::LayoutAttr::get(context, inferredInstData);
}
@@ -348,6 +390,8 @@ xegpu::inferShapeCastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
if (resLaneLayout.size() != 0) {
SmallVector<int> inferredLaneLayout(srcShapeSize, 1);
SmallVector<int> inferredLaneData(srcShapeSize, 1);
+ assert((resShapeSize == 1 || resLaneLayout[0] == 1) &&
+ "only innermost dim can have data and lane layout");
inferredLaneLayout[srcShapeSize - 1] = resLaneLayout[resShapeSize - 1];
inferredLaneData[srcShapeSize - 1] = resLaneData[resShapeSize - 1];
return xegpu::LayoutAttr::get(context, inferredLaneLayout,
@@ -370,6 +414,10 @@ xegpu::inferShapeCastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
/// It first tries to align the source layout's subgroup layout and data with
/// the consumer's layout on non-reduction dimensions. Then, it distributes
/// remaining subgroups across reduction dimensions.
+/// For InstData, it requries {1, ..., min(maxReduceVectorSize,
+/// srcshape),subgroupSize} For lane layout, it requires {1, ..., 1,
+/// subgroupSize} For lane data, it requires {1, ..., min(maxReduceVectorSize,
+/// srcshape), 1}
xegpu::SliceAttr xegpu::setupMultiReductionResultLayout(
xegpu::LayoutKind layoutKind, VectorType srcVecTy,
@@ -377,127 +425,99 @@ xegpu::SliceAttr xegpu::setupMultiReductionResultLayout(
const xegpu::uArch::uArch *uArch) {
auto srcShape = srcVecTy.getShape();
- xegpu::SliceAttr consumerSliceLayout =
- dyn_cast<xegpu::SliceAttr>(consumerLayout);
-
- int srcShapeSize = srcShape.size();
- xegpu::DistributeLayoutAttr proposedSrcLayout;
+ int srcRank = srcShape.size();
auto context = consumerLayout.getContext();
+
// Reduction layout requires at least 2D tensors
- if (srcShapeSize < 2)
+ if (srcRank < 2)
return nullptr;
- SmallVector<int64_t> sgLayout(srcShapeSize);
- SmallVector<int64_t> sgData(srcShapeSize);
- SmallVector<int64_t> instData(srcShapeSize);
- SmallVector<int64_t> laneLayout(srcShapeSize);
- SmallVector<int64_t> laneData(srcShapeSize);
-
- // recover workgroup and subgroup size from consumer layout
- DistributeLayoutAttr origPlainLayout;
- if (consumerSliceLayout) {
- origPlainLayout = consumerSliceLayout.flatten().getParent();
- } else {
- origPlainLayout = consumerLayout;
- }
+ // Helper lambda to convert int64 vectors to int32 DenseArrayAttr
+ auto toInt32Attr = [&](ArrayRef<int64_t> vec) {
+ SmallVector<int32_t> vec32(vec.begin(), vec.end());
+ return DenseI32ArrayAttr::get(context, vec32);
+ };
- const int workgroupSize =
- std::accumulate(origPlainLayout.getEffectiveSgLayoutAsInt().begin(),
- origPlainLayout.getEffectiveSgLayoutAsInt().end(), 1,
- std::multiplies<int64_t>());
+ // Extract original plain layout for workgroup/subgroup size recovery
+ xegpu::SliceAttr consumerSliceLayout =
+ dyn_cast<xegpu::SliceAttr>(consumerLayout);
+ DistributeLayoutAttr plainLayout =
+ consumerSliceLayout ? consumerSliceLayout.flatten().getParent()
+ : consumerLayout;
+ auto sgLayoutVec = plainLayout.getEffectiveSgLayoutAsInt();
+ const int workgroupSize = std::accumulate(
+ sgLayoutVec.begin(), sgLayoutVec.end(), 1, std::multiplies<int64_t>());
const int subgroupSize = uArch->getSubgroupSize();
+ int64_t maxReduceVectorSize = 1; // could extend to spirv vector Size
- const int vectorSize = 1; // vector size from SPRIV vector restriction
-
- SmallVector<int64_t> defaultInstData(srcShapeSize, 1);
-
- SmallVector<int64_t> defaultLaneLayout(srcShapeSize, 1);
-
- SmallVector<int64_t> defaultLaneData(srcShapeSize, 1);
- defaultInstData[srcShapeSize - 2] = vectorSize;
- defaultInstData[srcShapeSize - 1] = subgroupSize;
- defaultLaneLayout[srcShapeSize - 1] = subgroupSize;
- defaultLaneData[srcShapeSize - 2] = vectorSize;
- defaultLaneData[srcShapeSize - 1] = 1;
-
- SmallVector<int64_t> consumerSgLayout =
- consumerLayout.getEffectiveSgLayoutAsInt();
- SmallVector<int64_t> consumerLaneLayout =
- consumerLayout.getEffectiveLaneLayoutAsInt();
- int remainingSgCount = workgroupSize;
- int consumerSgId;
+ xegpu::DistributeLayoutAttr srcLayout;
switch (layoutKind) {
- case xegpu::LayoutKind::Subgroup:
- consumerSgId = consumerSgLayout.size() - 1;
- // For non-reduction dimensions, try to match consumer's sg_layout
- // This ensures the result after reduction has the expected distribution
- for (int i = srcShapeSize - 1; i >= 0; i--)
- if (!llvm::is_contained(reductionDims, i) && consumerSgId >= 0) {
- sgLayout[i] = consumerSgLayout[consumerSgId];
+ case xegpu::LayoutKind::Subgroup: {
+ SmallVector<int64_t> sgLayout(srcRank), sgData(srcRank);
+ SmallVector<int64_t> consumerSgLayout =
+ consumerLayout.getEffectiveSgLayoutAsInt();
+ int remainingSgCount = workgroupSize;
+ int consumerIdx = consumerSgLayout.size() - 1;
+
+ // First pass: Match consumer's layout on non-reduction dimensions
+ for (int i = srcRank - 1; i >= 0; i--) {
+ if (!llvm::is_contained(reductionDims, i) && consumerIdx >= 0) {
+ sgLayout[i] = consumerSgLayout[consumerIdx];
assert((srcShape[i] % sgLayout[i] == 0) &&
"source shape not divisible by consumer sg_layout");
sgData[i] = srcShape[i] / sgLayout[i];
remainingSgCount /= sgLayout[i];
- consumerSgId--;
+ consumerIdx--;
}
- // Second pass: Distribute remaining subgroups across unhandled dimensions
- // This handles reduction dimensions that don't necessarily align with
- // consumer
- for (int i = srcShapeSize - 1; i >= 0; i--)
+ }
+
+ // Second pass: Distribute remaining subgroups across reduction dimensions
+ for (int i = srcRank - 1; i >= 0; i--) {
if (llvm::is_contained(reductionDims, i)) {
- sgLayout[i] = std::min((srcShape[i] / defaultLaneLayout[i]),
+ sgLayout[i] = std::min(srcShape[i] / subgroupSize,
static_cast<int64_t>(remainingSgCount));
assert((srcShape[i] % sgLayout[i] == 0) &&
- "source shape not divisible by consumer sg_layout");
+ "source shape not divisible by sg_layout");
sgData[i] = srcShape[i] / sgLayout[i];
remainingSgCount /= sgLayout[i];
}
- assert(remainingSgCount == 1 && "not all subgroups have been distributed");
- break;
- case xegpu::LayoutKind::InstData:
- for (int i = 0; i < srcShapeSize; i++) {
- instData[i] = std::min(defaultInstData[i], srcShape[i]);
}
+
+ assert(remainingSgCount == 1 && "not all subgroups distributed");
+ srcLayout =
+ xegpu::LayoutAttr::get(context, toInt32Attr(sgLayout),
+ toInt32Attr(sgData), consumerLayout.getOrder());
break;
- case xegpu::LayoutKind::Lane:
- laneLayout = defaultLaneLayout;
- laneData = defaultLaneData;
- break;
- default:
- llvm_unreachable("unsupported layout kind");
}
- SmallVector<int32_t> sgLayout32(sgLayout.begin(), sgLayout.end());
- SmallVector<int32_t> sgData32(sgData.begin(), sgData.end());
- SmallVector<int32_t> instData32(instData.begin(), instData.end());
- SmallVector<int32_t> laneLayout32(laneLayout.begin(), laneLayout.end());
- SmallVector<int32_t> laneData32(laneData.begin(), laneData.end());
-
- switch (layoutKind) {
- case xegpu::LayoutKind::Subgroup:
- proposedSrcLayout = xegpu::LayoutAttr::get(
- context, DenseI32ArrayAttr::get(context, sgLayout32),
- DenseI32ArrayAttr::get(context, sgData32), consumerLayout.getOrder());
+ case xegpu::LayoutKind::InstData: {
+ SmallVector<int64_t> instData(srcRank, 1);
+ instData[srcRank - 2] =
+ std::min(maxReduceVectorSize, srcShape[srcRank - 2]);
+ instData[srcRank - 1] = subgroupSize;
+ srcLayout = xegpu::LayoutAttr::get(context, toInt32Attr(instData));
break;
- case xegpu::LayoutKind::InstData:
- proposedSrcLayout = xegpu::LayoutAttr::get(
- context, DenseI32ArrayAttr::get(context, instData32));
- break;
- case xegpu::LayoutKind::Lane:
- proposedSrcLayout = xegpu::LayoutAttr::get(
- context, DenseI32ArrayAttr::get(context, laneLayout32),
- DenseI32ArrayAttr::get(context, laneData32), consumerLayout.getOrder());
+ }
+
+ case xegpu::LayoutKind::Lane: {
+ SmallVector<int64_t> laneLayout(srcRank, 1), laneData(srcRank, 1);
+ laneLayout[srcRank - 1] = subgroupSize;
+ laneData[srcRank - 2] =
+ std::min(maxReduceVectorSize, srcShape[srcRank - 2]);
+ srcLayout = xegpu::LayoutAttr::get(context, toInt32Attr(laneLayout),
+ toInt32Attr(laneData),
+ consumerLayout.getOrder());
break;
+ }
+
default:
llvm_unreachable("unsupported layout kind");
}
- xegpu::SliceAttr resLayout =
- xegpu::SliceAttr::get(context, proposedSrcLayout,
- DenseI64ArrayAttr::get(context, reductionDims));
- return resLayout;
+ return xegpu::SliceAttr::get(context, srcLayout,
+ DenseI64ArrayAttr::get(context, reductionDims));
}
xegpu::DistributeLayoutAttr xegpu::setupBitCastResultLayout(
@@ -657,45 +677,6 @@ xegpu::setupInsertStridedSliceResultLayout(xegpu::LayoutKind layoutKind,
return requiredResLayout;
}
-xegpu::DistributeLayoutAttr xegpu::inferInsertStridedSliceSourceLayout(
- xegpu::DistributeLayoutAttr resLayout, ArrayRef<int64_t> resShape,
- ArrayRef<int64_t> srcShape) {
-
- int srcShapeSize = srcShape.size();
- int resShapeSize = resShape.size();
- int dimDiff = resShapeSize - srcShapeSize;
-
- // assert resLayout must be a plain layout
- assert(isa<xegpu::LayoutAttr>(resLayout) &&
- "insertStridedSlice result layout must be plain layout");
- auto context = resLayout.getContext();
- auto resInstData = resLayout.getEffectiveInstDataAsInt();
- auto resLaneLayout = resLayout.getEffectiveLaneLayoutAsInt();
- auto resLaneData = resLayout.getEffectiveLaneDataAsInt();
-
- if (resInstData.size() != 0) {
- SmallVector<int> inferredInstData(srcShapeSize);
- // remove the initial dims in resInstData to match srcShapeSize
- for (int i = 0; i < srcShapeSize; i++)
- inferredInstData[i] = resInstData[i + dimDiff];
- return xegpu::LayoutAttr::get(context, inferredInstData);
- }
-
- if (resLaneLayout.size() != 0) {
- // construct source lane_layout like [1, ..., 1, subgroupSize]
- SmallVector<int> inferredLaneLayout(srcShapeSize);
- SmallVector<int> inferredLaneData(srcShapeSize);
- // remove the initial dims in resInstData to match srcShapeSize
- for (int i = 0; i < srcShapeSize; i++) {
- inferredLaneLayout[i] = resLaneLayout[i + dimDiff];
- inferredLaneData[i] = resLaneData[i + dimDiff];
- }
- return xegpu::LayoutAttr::get(context, inferredLaneLayout,
- inferredLaneData);
- }
- return nullptr;
-}
-
static xegpu::DistributeLayoutAttr
getDefaultLaneLayoutAttr(mlir::MLIRContext *ctx, unsigned rank,
const xegpu::uArch::uArch *uArch) {
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 6385995317f38..951c6e092daf2 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -1147,40 +1147,6 @@ void LayoutInfoPropagation::visitInsertStridedSliceOp(
return;
}
-// /// For vector::BitCastOp, the lane_data of the source layout is changed
-// based
-// /// on the bit width of the source and result types.
-// void LayoutInfoPropagation::visitVectorInsertStridedSliceOp(
-// vector::BitCastOp bitcast, ArrayRef<LayoutInfoLattice *> operands,
-// ArrayRef<const LayoutInfoLattice *> results) {
-// // Need the layout of bitcast result to propagate to the operands.
-// LayoutInfo resLayoutInfo = results[0]->getValue();
-// if (!resLayoutInfo.isAssigned())
-// return;
-
-// auto srcVecType = bitcast.getSourceVectorType();
-// auto resVecType = bitcast.getResultVectorType();
-
-// auto consumerLayoutAttr =
-// dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
-// auto uArch = getUArch(xegpu::getChipStr(bitcast).value_or(""));
-// auto requiredResLayoutAttr = setupInsertStridedSliceResultLayout(
-// layoutKind, srcVecType, resVecType, consumerLayoutAttr, uArch);
-
-// xegpu::setTemporaryLayout(bitcast->getResult(0), requiredResLayoutAttr);
-
-// int inElemTyBitWidth = srcVecType.getElementType().getIntOrFloatBitWidth();
-// int outElemTyBitWidth =
-// resVecType.getElementType().getIntOrFloatBitWidth();
-
-// // derive the source layout from the dominant layout and reduction dims
-// auto srcLayoutAttr = xegpu::inferBitCastSourceLayout(
-// requiredResLayoutAttr, outElemTyBitWidth, inElemTyBitWidth);
-
-// propagateIfChanged(operands[0],
-// operands[0]->meet(LayoutInfo(srcLayoutAttr)));
-// }
-
/// Propagate the layout of the result to the tensor descriptor, mask and offset
/// operands in LoadGatherOp.
void LayoutInfoPropagation::visitLoadGatherOp(
@@ -1265,8 +1231,6 @@ void LayoutInfoPropagation::visitStoreScatterOp(
auto uArch = getUArch(getChipStr(storeScatter).value_or(""));
auto subgroupSize = uArch->getSubgroupSize();
VectorType srcVecTy = storeScatter.getValueType();
- VectorType maskTy =
- llvm::dyn_cast<VectorType>(storeScatter.getMask().getType());
int chunkSize = storeScatter.getChunkSize().value_or(1);
if (hasParamsOfLayoutKind(anchorLayoutAttr)) {
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index cc278c6813fc9..7d06ff0ff0fb3 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -1615,11 +1615,6 @@ struct VectorShapeCastDistribution : public gpu::WarpDistributionPattern {
warpOp,
"the source or result of shape_cast op lacks distribution layout");
- // For rank reducing or increasing shape_cast ops, the lower rank layout
- // must be a slice of higher rank layout.
- int64_t sourceRank = shapeCastOp.getSourceVectorType().getRank();
- int64_t resultRank = shapeCastOp.getResultVectorType().getRank();
-
FailureOr<VectorType> sourceDistTypeOrFailure =
getDistVecTypeBasedOnLaneLayout(sourceLayout,
shapeCastOp.getSourceVectorType());
>From 97e9211a376ecb634a41b7e506a7a8cce8a60865 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Fri, 30 Jan 2026 18:02:56 +0000
Subject: [PATCH 24/35] improve InsertStridedSlice result layout setup
---
.../XeGPU/Transforms/XeGPULayoutImpls.h | 3 +-
.../XeGPU/Transforms/XeGPULayoutImpls.cpp | 69 +++++++++++++------
.../XeGPU/Transforms/XeGPUPropagateLayout.cpp | 4 +-
3 files changed, 52 insertions(+), 24 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpls.h b/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpls.h
index e5f2b1be93cc2..27325987c64e7 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpls.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpls.h
@@ -126,7 +126,8 @@ DistributeLayoutAttr setupBitCastResultLayout(
DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch);
DistributeLayoutAttr setupInsertStridedSliceResultLayout(
- LayoutKind layoutKind, VectorType resVectorTy, const uArch::uArch *uArch);
+ LayoutKind layoutKind, VectorType resVectorTy,
+ DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch);
DistributeLayoutAttr
setupLoadMatrixAnchorLayout(LayoutKind layoutKind, VectorType vectorTy,
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpls.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpls.cpp
index 8509f709c6fec..1165d9d2cefab 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpls.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpls.cpp
@@ -408,16 +408,17 @@ xegpu::inferShapeCastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
/// Algorithm Overview:
/// This function attempts to construct a source layout that, when sliced along
/// reduction dimensions, produces a result layout compatible with the
-/// consumer layout. This minimizes data redistribution overhead
-/// between the reduction operation and its consumer.
+/// consumer layout.
///
-/// It first tries to align the source layout's subgroup layout and data with
-/// the consumer's layout on non-reduction dimensions. Then, it distributes
-/// remaining subgroups across reduction dimensions.
-/// For InstData, it requries {1, ..., min(maxReduceVectorSize,
-/// srcshape),subgroupSize} For lane layout, it requires {1, ..., 1,
-/// subgroupSize} For lane data, it requires {1, ..., min(maxReduceVectorSize,
-/// srcshape), 1}
+/// For subgroup layouts, it first tries to align the source layout's subgroup
+/// layout and data with the consumer's layout on non-reduction dimensions.
+/// Then, it distributes remaining subgroups across reduction dimensions. This
+/// avoids subgroup data redistribution overhead between the reduced result and
+/// its consumer.
+///
+/// InstData requires {1, ..., min(maxReduceVectorSize, srcShape),subgroupSize}
+/// Lane Layout requires {1, ..., 1, subgroupSize}
+/// Lane data requires {1, ..., min(maxReduceVectorSize, srcShape), 1}
xegpu::SliceAttr xegpu::setupMultiReductionResultLayout(
xegpu::LayoutKind layoutKind, VectorType srcVecTy,
@@ -520,6 +521,11 @@ xegpu::SliceAttr xegpu::setupMultiReductionResultLayout(
DenseI64ArrayAttr::get(context, reductionDims));
}
+/// Sets up the result layout for a bitcast operation.
+/// When casting to a smaller bitwidth, adjusts the layout dimensions (sgData,
+/// instData, or laneData) by multiplying by the bitwidth ratio to ensure the
+/// result layout can be correctly divided back to the source layout during
+/// inference.
xegpu::DistributeLayoutAttr xegpu::setupBitCastResultLayout(
xegpu::LayoutKind layoutKind, VectorType srcVecTy, VectorType resVecTy,
DistributeLayoutAttr consumerLayout, const xegpu::uArch::uArch *uArch) {
@@ -642,22 +648,37 @@ xegpu::setupStoreMatrixAnchorLayout(xegpu::LayoutKind layoutKind,
return requiredLayout;
}
-xegpu::DistributeLayoutAttr
-xegpu::setupInsertStridedSliceResultLayout(xegpu::LayoutKind layoutKind,
- VectorType resVectorTy,
- const xegpu::uArch::uArch *uArch) {
+/// Sets up the result layout for an insert strided slice operation.
+/// Creates a default layout based on the specified layout kind (InstData or
+/// Lane).
+/// Subgroup layout is currently not supported for this operation.
+/// InstData layout requires {1, .., subgroupSize} by default.
+/// Lane layout requires {1, ..., subgroupSize} with lane data {1, ..., 1}.
+/// The instData and laneData are adjusted to hold packed data when the
+/// consumerLayout's innermost dimension already uses the packed size.
+xegpu::DistributeLayoutAttr xegpu::setupInsertStridedSliceResultLayout(
+ xegpu::LayoutKind layoutKind, VectorType resVectorTy,
+ xegpu::DistributeLayoutAttr consumerLayout,
+ const xegpu::uArch::uArch *uArch) {
xegpu::DistributeLayoutAttr requiredResLayout;
auto subgroupSize = uArch->getSubgroupSize();
auto context = resVectorTy.getContext();
auto resShape = resVectorTy.getShape();
int resShapeSize = resShape.size();
+ SmallVector<int64_t> consumerInstData =
+ consumerLayout.getEffectiveInstDataAsInt();
+ SmallVector<int64_t> consumerLaneData =
+ consumerLayout.getEffectiveLaneDataAsInt();
+
+ SmallVector<int> instData(resShapeSize, 1);
+ SmallVector<int> laneLayout(resShapeSize, 1);
+ SmallVector<int> laneData(resShapeSize, 1);
- SmallVector<int> defaultInstData(resShapeSize, 1);
- SmallVector<int> defaultLaneLayout(resShapeSize, 1);
- SmallVector<int> defaultLaneData(resShapeSize, 1);
- defaultInstData[resShapeSize - 1] = subgroupSize;
- defaultLaneLayout[resShapeSize - 1] = subgroupSize;
+ const unsigned packingSize{uArch->getGeneralPackedFormatBitSize()};
+ unsigned bitwidth = resVectorTy.getElementType().getIntOrFloatBitWidth();
+ int packingFactor = bitwidth < packingSize ? packingSize / bitwidth : 1;
+ int packedDataSize = subgroupSize * packingFactor;
switch (layoutKind) {
case xegpu::LayoutKind::Subgroup:
@@ -665,11 +686,17 @@ xegpu::setupInsertStridedSliceResultLayout(xegpu::LayoutKind layoutKind,
"subgroup layout assignment not supported for insertStridedSlice.");
break;
case xegpu::LayoutKind::InstData:
- requiredResLayout = xegpu::LayoutAttr::get(context, defaultInstData);
+ instData[resShapeSize - 1] = subgroupSize;
+ if (consumerInstData[resShapeSize - 1] == packedDataSize)
+ instData[resShapeSize - 1] = packedDataSize;
+ requiredResLayout = xegpu::LayoutAttr::get(context, instData);
break;
case xegpu::LayoutKind::Lane:
- requiredResLayout =
- xegpu::LayoutAttr::get(context, defaultLaneLayout, defaultLaneData);
+ laneLayout[resShapeSize - 1] = subgroupSize;
+ laneData[resShapeSize - 1] = 1;
+ if (consumerLaneData[resShapeSize - 1] == packingFactor)
+ laneData[resShapeSize - 1] = packingFactor;
+ requiredResLayout = xegpu::LayoutAttr::get(context, laneLayout, laneData);
break;
default:
llvm_unreachable("unsupported layout kind");
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 951c6e092daf2..15547b33765fd 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -1132,8 +1132,8 @@ void LayoutInfoPropagation::visitInsertStridedSliceOp(
dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
auto uArch = getUArch(xegpu::getChipStr(insertStridedSlice).value_or(""));
- auto requiredResLayoutAttr =
- xegpu::setupInsertStridedSliceResultLayout(layoutKind, resVecType, uArch);
+ auto requiredResLayoutAttr = xegpu::setupInsertStridedSliceResultLayout(
+ layoutKind, resVecType, consumerLayoutAttr, uArch);
xegpu::setTemporaryLayout(insertStridedSlice->getResult(0),
requiredResLayoutAttr);
>From 44bd04fc50637023ac44fa8a1762f7b4f388811c Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Fri, 30 Jan 2026 19:08:22 +0000
Subject: [PATCH 25/35] polish loadgather layout setup
---
.../XeGPU/Transforms/XeGPULayoutImpls.cpp | 89 ++++++++++---------
mlir/test/Dialect/XeGPU/propagate-layout.mlir | 2 +-
2 files changed, 48 insertions(+), 43 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpls.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpls.cpp
index 1165d9d2cefab..664edbf8c0003 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpls.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpls.cpp
@@ -714,6 +714,17 @@ getDefaultLaneLayoutAttr(mlir::MLIRContext *ctx, unsigned rank,
return xegpu::LayoutAttr::get(ctx, {1, uArch->getSubgroupSize()}, {1, 1});
}
+/// Sets up the result layout for a load gather operation.
+/// For Subgroup layout, uses the consumer layout directly.
+/// non-chunked loads:
+/// InstData = {1, ..., min(consumer, maxLaneLoadStoreSize *
+/// subgroupSize)}
+/// LaneLayout = {1, ..., subgroupSize}
+/// lane_data = {1, ..., min(consumer, maxLaneLoadStoreSize)}
+/// chunked loads:
+/// InstData = {subgroupSize, min(consumer, maxLaneLoadStoreSize)}
+/// LaneLayout = {subgroupSize, 1}
+/// lane_data={1,min(consumer, maxLaneLoadStoreSize)}
xegpu::DistributeLayoutAttr xegpu::setupLoadGatherAnchorLayout(
LayoutKind layoutKind, VectorType resVecTy, int chunkSize,
DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch) {
@@ -721,9 +732,7 @@ xegpu::DistributeLayoutAttr xegpu::setupLoadGatherAnchorLayout(
xegpu::DistributeLayoutAttr requiredLayout;
const int subgroupSize = uArch->getSubgroupSize();
- auto resShape = resVecTy.getShape();
- int resShapeSize = resShape.size();
- SmallVector<int> instData(resShapeSize);
+ int resShapeSize = resVecTy.getShape().size();
auto context = resVecTy.getContext();
const auto *uArchInstruction =
@@ -732,48 +741,44 @@ xegpu::DistributeLayoutAttr xegpu::setupLoadGatherAnchorLayout(
SmallVector<int64_t> consumerInstData =
consumerLayout.getEffectiveInstDataAsInt();
- SmallVector<int32_t> instData32;
+ SmallVector<int64_t> consumerLaneData =
+ consumerLayout.getEffectiveLaneDataAsInt();
- switch (layoutKind) {
- case xegpu::LayoutKind::Subgroup:
- requiredLayout = consumerLayout;
- break;
- case xegpu::LayoutKind::InstData:
- if (resVecTy.getRank() == 1) {
- instData[0] = subgroupSize;
- } else {
- assert((resVecTy.getRank() == 2) && "StoreScatterOp can access 2D tensor "
- "tile at maximum at subgroup level.");
- if (chunkSize == 1) {
- instData[0] = 1;
- instData[1] = subgroupSize;
- } else {
- instData[0] = subgroupSize;
- instData[1] = std::min(static_cast<int>(resShape[1]),
- uArchInstruction->getMaxLaneLoadStoreSize());
- instData[1] =
- std::min(instData[1], static_cast<int>(consumerInstData[1]));
- }
+ SmallVector<int> instData(resShapeSize, 1);
+ SmallVector<int> laneLayout(resShapeSize, 1);
+ SmallVector<int> laneData(resShapeSize, 1);
+
+ if (layoutKind == xegpu::LayoutKind::Subgroup) {
+ return consumerLayout;
+ }
+
+ if (chunkSize == 1) {
+ if (layoutKind == xegpu::LayoutKind::InstData) {
+ instData[resShapeSize - 1] =
+ std::min(static_cast<int>(consumerInstData[resShapeSize - 1]),
+ uArchInstruction->getMaxLaneLoadStoreSize() * subgroupSize);
+ requiredLayout = xegpu::LayoutAttr::get(context, instData);
+ } else if (layoutKind == xegpu::LayoutKind::Lane) {
+ laneLayout[resShapeSize - 1] = subgroupSize;
+ laneData[resShapeSize - 1] =
+ std::min(static_cast<int>(consumerLaneData[resShapeSize - 1]),
+ uArchInstruction->getMaxLaneLoadStoreSize());
+ requiredLayout = xegpu::LayoutAttr::get(context, laneLayout, laneData);
}
- requiredLayout = xegpu::LayoutAttr::get(
- context, DenseI32ArrayAttr::get(context, instData));
- break;
- case xegpu::LayoutKind::Lane:
- if (chunkSize == 1)
- requiredLayout =
- getDefaultLaneLayoutAttr(context, resVecTy.getRank(), uArch);
- else {
- assert((resVecTy.getRank() <= 2) && "StoreScatterOp can access 2D tensor "
- "tile at maximum at subgroup level.");
- assert(resShape[1] <= static_cast<int64_t>(
- uArchInstruction->getMaxLaneLoadStoreSize()) &&
- "StoreScatterOp lane size exceeds max lane load/store size.");
- requiredLayout = xegpu::LayoutAttr::get(
- context, {subgroupSize, 1}, {1, static_cast<int>(resShape[1])});
+ } else {
+ assert(resVecTy.getRank() == 2 &&
+ "Chunked Store must access 2D tensor tile.");
+ if (layoutKind == xegpu::LayoutKind::InstData) {
+ instData[0] = subgroupSize;
+ instData[1] = std::min(static_cast<int>(consumerInstData[1]),
+ uArchInstruction->getMaxLaneLoadStoreSize());
+ requiredLayout = xegpu::LayoutAttr::get(context, instData);
+ } else if (layoutKind == xegpu::LayoutKind::Lane) {
+ laneLayout[0] = subgroupSize;
+ laneData[1] = std::min(static_cast<int>(consumerLaneData[1]),
+ uArchInstruction->getMaxLaneLoadStoreSize());
+ requiredLayout = xegpu::LayoutAttr::get(context, laneLayout, laneData);
}
- break;
- default:
- llvm_unreachable("unsupported layout kind");
}
return requiredLayout;
}
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout.mlir b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
index a8eccfab53c37..85b9b82179e57 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
@@ -107,7 +107,7 @@ gpu.module @test {
// CHECK: %[[OFFSET:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
// CHECK-SAME: dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
// CHECK-NEXT: %[[MASK:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1>
-// CHECK-NEXT: %{{.*}} = xegpu.load %arg1[%[[OFFSET]]], %[[MASK]] <{chunk_size = 16 : i64, layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 16]>}> : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x16xf16>
+// CHECK-NEXT: %{{.*}} = xegpu.load %arg1[%[[OFFSET]]], %[[MASK]] <{chunk_size = 16 : i64, layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>}> : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x16xf16>
func.func @load_gather_with_chunksize(%arg0: memref<8x16xf16>, %arg1: memref<256xf16>, %arg2: memref<8x16xf32>) {
%c0 = arith.constant 0 : index
%0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
>From 6feb7c824ed3fec6b598de5fd3522bd4e36be01e Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Fri, 30 Jan 2026 23:33:15 +0000
Subject: [PATCH 26/35] polish load store layout setup
---
.../mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h | 30 +-
.../mlir/Dialect/XeGPU/uArch/uArchBase.h | 8 +-
.../XeGPU/Transforms/XeGPULayoutImpls.cpp | 296 +++++++++---------
.../XeGPU/Transforms/XeGPUPropagateLayout.cpp | 21 --
4 files changed, 175 insertions(+), 180 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
index 29e75b57f4a5f..0f138b9defb5c 100644
--- a/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
@@ -237,6 +237,28 @@ struct LoadGatherInstruction : public Instruction {
int32_t getMaxLaneLoadStoreSize() const { return 16; }
};
+struct StoreMatrixInstruction : public Instruction {
+ StoreMatrixInstruction()
+ : Instruction(InstructionKind::StoreMatrix, InstructionScope::Lane) {}
+ static bool classof(const Instruction *B) {
+ return B->getInstructionKind() == InstructionKind::StoreMatrix;
+ }
+
+ // SPIRV restricts vector size
+ int32_t getMaxLaneLoadStoreSize() const { return 16; }
+};
+
+struct LoadMatrixInstruction : public Instruction {
+ LoadMatrixInstruction()
+ : Instruction(InstructionKind::LoadMatrix, InstructionScope::Lane) {}
+ static bool classof(const Instruction *B) {
+ return B->getInstructionKind() == InstructionKind::LoadMatrix;
+ }
+
+ // SPIRV restricts vector size
+ int32_t getMaxLaneLoadStoreSize() const { return 16; }
+};
+
//===----------------------------------------------------------------------===//
// uArch instances
//===----------------------------------------------------------------------===//
@@ -249,9 +271,11 @@ struct PVCuArch final : public Xe2Plus {
static const Subgroup2DBlockPrefetchInstruction prefetchNdInst;
static const StoreScatterInstruction storeScatterInst;
static const LoadGatherInstruction loadGatherInst;
- static const Instruction *arr[] = {&dpasInst, &loadNdInst,
- &storeNdInst, &prefetchNdInst,
- &storeScatterInst, &loadGatherInst};
+ static const StoreMatrixInstruction storeMatrixInst;
+ static const LoadMatrixInstruction loadMatrixInst;
+ static const Instruction *arr[] = {
+ &dpasInst, &loadNdInst, &storeNdInst, &prefetchNdInst,
+ &storeScatterInst, &loadGatherInst, &storeMatrixInst, &loadMatrixInst};
return arr;
}
diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
index db1984b2edb1d..75a483baa116d 100644
--- a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
@@ -40,7 +40,9 @@ enum class InstructionKind {
Subgroup2DBlockLoad, // Subgroup-level 2D block load instruction
Subgroup2DBlockPrefetch, // Subgroup-level 2D block prefetch instruction
StoreScatter, // Lane-level store (scalar, vector)
- LoadGather // Lane-level load (scalar, vector)
+ LoadGather, // Lane-level load (scalar, vector)
+ StoreMatrix, // Lane-level matrix store to slm
+ LoadMatrix // Lane-level matrix load from slm
// @TODO: Add more instructions as needed
};
@@ -71,6 +73,10 @@ struct Instruction {
return "store";
case InstructionKind::LoadGather:
return "load";
+ case InstructionKind::StoreMatrix:
+ return "store_matrix";
+ case InstructionKind::LoadMatrix:
+ return "load_matrix";
}
llvm_unreachable("Unknown InstructionKind");
}
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpls.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpls.cpp
index 664edbf8c0003..a9428be4cadeb 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpls.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpls.cpp
@@ -588,66 +588,6 @@ xegpu::DistributeLayoutAttr xegpu::setupBitCastResultLayout(
return consumerLayout;
}
-xegpu::DistributeLayoutAttr
-xegpu::setupLoadMatrixAnchorLayout(xegpu::LayoutKind layoutKind,
- VectorType resVectorTy,
- xegpu::DistributeLayoutAttr consumerLayout,
- const xegpu::uArch::uArch *uArch) {
- xegpu::DistributeLayoutAttr requiredLayout;
- auto subgroupSize = uArch->getSubgroupSize();
- SmallVector<int> defaultInstData = {1, subgroupSize};
- SmallVector<int> defaultLaneLayout = {1, subgroupSize};
- SmallVector<int> defaultLaneData = {1, 1};
- auto context = resVectorTy.getContext();
-
- switch (layoutKind) {
- case xegpu::LayoutKind::Subgroup:
- requiredLayout = consumerLayout;
- break;
- case xegpu::LayoutKind::InstData:
- requiredLayout = xegpu::LayoutAttr::get(context, defaultInstData);
- break;
- case xegpu::LayoutKind::Lane:
- requiredLayout =
- xegpu::LayoutAttr::get(context, defaultLaneLayout, defaultLaneData);
- break;
- default:
- llvm_unreachable("unsupported layout kind");
- }
- return requiredLayout;
-}
-
-xegpu::DistributeLayoutAttr
-xegpu::setupStoreMatrixAnchorLayout(xegpu::LayoutKind layoutKind,
- VectorType srcVectorTy,
- const xegpu::uArch::uArch *uArch) {
-
- xegpu::DistributeLayoutAttr requiredLayout;
- auto subgroupSize = uArch->getSubgroupSize();
- SmallVector<int> defaultInstData = {1, subgroupSize};
- SmallVector<int> defaultLaneLayout = {1, subgroupSize};
- SmallVector<int> defaultLaneData = {1, 1};
- auto context = srcVectorTy.getContext();
-
- switch (layoutKind) {
- case xegpu::LayoutKind::Subgroup:
- assert(true &&
- "subgroup layout assignment not supported yet for storeMatrix.");
- break;
- case xegpu::LayoutKind::InstData:
- requiredLayout = xegpu::LayoutAttr::get(context, defaultInstData);
- break;
- case xegpu::LayoutKind::Lane:
- requiredLayout =
- xegpu::LayoutAttr::get(context, defaultLaneLayout, defaultLaneData);
-
- break;
- default:
- llvm_unreachable("unsupported layout kind");
- }
- return requiredLayout;
-}
-
/// Sets up the result layout for an insert strided slice operation.
/// Creates a default layout based on the specified layout kind (InstData or
/// Lane).
@@ -704,141 +644,187 @@ xegpu::DistributeLayoutAttr xegpu::setupInsertStridedSliceResultLayout(
return requiredResLayout;
}
-static xegpu::DistributeLayoutAttr
-getDefaultLaneLayoutAttr(mlir::MLIRContext *ctx, unsigned rank,
- const xegpu::uArch::uArch *uArch) {
- assert((rank == 1 || rank == 2) && "Expected 1D or 2D vector.");
- if (rank == 1) {
- return xegpu::LayoutAttr::get(ctx, {uArch->getSubgroupSize()}, {1});
- }
- return xegpu::LayoutAttr::get(ctx, {1, uArch->getSubgroupSize()}, {1, 1});
-}
-
-/// Sets up the result layout for a load gather operation.
+/// Sets up the anchor layout for load gather and load matrix operations.
+/// Load matrix lowers to load gather and 1d block load, so all of them share
+/// the same layout setup logic.
/// For Subgroup layout, uses the consumer layout directly.
/// non-chunked loads:
-/// InstData = {1, ..., min(consumer, maxLaneLoadStoreSize *
-/// subgroupSize)}
+/// InstData = {1, ..., min(consumer, maxLaneLoadStoreSize * subgroupSize)}
/// LaneLayout = {1, ..., subgroupSize}
/// lane_data = {1, ..., min(consumer, maxLaneLoadStoreSize)}
/// chunked loads:
/// InstData = {subgroupSize, min(consumer, maxLaneLoadStoreSize)}
/// LaneLayout = {subgroupSize, 1}
/// lane_data={1,min(consumer, maxLaneLoadStoreSize)}
+static xegpu::DistributeLayoutAttr setupGenericLoadAnchorLayout(
+ xegpu::LayoutKind layoutKind, mlir::MLIRContext *context,
+ xegpu::DistributeLayoutAttr consumerLayout, bool isChunkedLoad,
+ int maxChunkSize, int valShapeSize, int subgroupSize) {
+ SmallVector<int64_t> consumerInstData =
+ consumerLayout.getEffectiveInstDataAsInt();
+ SmallVector<int64_t> consumerLaneData =
+ consumerLayout.getEffectiveLaneDataAsInt();
+
+ SmallVector<int> instData(valShapeSize, 1);
+ SmallVector<int> laneLayout(valShapeSize, 1);
+ SmallVector<int> laneData(valShapeSize, 1);
+
+ if (layoutKind == xegpu::LayoutKind::Subgroup) {
+ return consumerLayout;
+ }
+
+ if (!isChunkedLoad) {
+ if (layoutKind == xegpu::LayoutKind::InstData) {
+ instData[valShapeSize - 1] =
+ std::min(static_cast<int>(consumerInstData[valShapeSize - 1]),
+ maxChunkSize * subgroupSize);
+ return xegpu::LayoutAttr::get(context, instData);
+ } else if (layoutKind == xegpu::LayoutKind::Lane) {
+ laneLayout[valShapeSize - 1] = subgroupSize;
+ laneData[valShapeSize - 1] = std::min(
+ static_cast<int>(consumerLaneData[valShapeSize - 1]), maxChunkSize);
+ return xegpu::LayoutAttr::get(context, laneLayout, laneData);
+ }
+ } else {
+ assert(valShapeSize == 2 && "Chunked Store must access 2D tensor tile.");
+ if (layoutKind == xegpu::LayoutKind::InstData) {
+ instData[0] = subgroupSize;
+ instData[1] =
+ std::min(static_cast<int>(consumerInstData[1]), maxChunkSize);
+ return xegpu::LayoutAttr::get(context, instData);
+ } else if (layoutKind == xegpu::LayoutKind::Lane) {
+ laneLayout[0] = subgroupSize;
+ laneData[1] =
+ std::min(static_cast<int>(consumerLaneData[1]), maxChunkSize);
+ return xegpu::LayoutAttr::get(context, laneLayout, laneData);
+ }
+ }
+ return nullptr;
+}
+
+/// Sets up the anchor layout for a load gather operation.
xegpu::DistributeLayoutAttr xegpu::setupLoadGatherAnchorLayout(
- LayoutKind layoutKind, VectorType resVecTy, int chunkSize,
- DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch) {
+ xegpu::LayoutKind layoutKind, VectorType resVecTy, int chunkSize,
+ xegpu::DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch) {
- xegpu::DistributeLayoutAttr requiredLayout;
const int subgroupSize = uArch->getSubgroupSize();
+ int resShapeSize = resVecTy.getShape().size();
+ auto context = resVecTy.getContext();
+
+ const auto *uArchInstruction = dyn_cast<xegpu::uArch::LoadGatherInstruction>(
+ uArch->getInstruction(xegpu::uArch::InstructionKind::LoadGather));
+ int maxChunkSize = uArchInstruction->getMaxLaneLoadStoreSize();
+
+ return setupGenericLoadAnchorLayout(layoutKind, context, consumerLayout,
+ (chunkSize > 1), maxChunkSize,
+ resShapeSize, subgroupSize);
+}
+
+/// Sets up the anchor layout for load matrix operation.
+/// TODO: enhance load matrix to indicate lowering to chunked load or not.
+xegpu::DistributeLayoutAttr
+xegpu::setupLoadMatrixAnchorLayout(xegpu::LayoutKind layoutKind,
+ VectorType resVecTy,
+ xegpu::DistributeLayoutAttr consumerLayout,
+ const xegpu::uArch::uArch *uArch) {
+ const int subgroupSize = uArch->getSubgroupSize();
int resShapeSize = resVecTy.getShape().size();
auto context = resVecTy.getContext();
- const auto *uArchInstruction =
- dyn_cast<xegpu::uArch::StoreScatterInstruction>(
- uArch->getInstruction(xegpu::uArch::InstructionKind::StoreScatter));
+ const auto *uArchInstruction = dyn_cast<xegpu::uArch::LoadMatrixInstruction>(
+ uArch->getInstruction(xegpu::uArch::InstructionKind::LoadMatrix));
+ int maxChunkSize = uArchInstruction->getMaxLaneLoadStoreSize();
- SmallVector<int64_t> consumerInstData =
- consumerLayout.getEffectiveInstDataAsInt();
- SmallVector<int64_t> consumerLaneData =
- consumerLayout.getEffectiveLaneDataAsInt();
+ return setupGenericLoadAnchorLayout(layoutKind, context, consumerLayout,
+ false, maxChunkSize, resShapeSize,
+ subgroupSize);
+}
- SmallVector<int> instData(resShapeSize, 1);
- SmallVector<int> laneLayout(resShapeSize, 1);
- SmallVector<int> laneData(resShapeSize, 1);
+/// Sets up the anchor layout for store scatter and store matrix operations.
+/// Store matrix lowers to store scatter and 1d block store, so all of them
+/// share the same layout setup logic. Subgroup layout is not supported yet.
+/// non-chunked stores:
+/// InstData = {1, ..., subgroupSize}
+/// LaneLayout = {1, ..., subgroupSize}
+/// lane_data = {1, ..., 1}
+/// chunked stores:
+/// InstData = {subgroupSize, min(srcVec, maxLaneLoadStoreSize)}
+/// LaneLayout = {subgroupSize, 1}
+/// lane_data = {1, min(srcVec, maxLaneLoadStoreSize)}
+static xegpu::DistributeLayoutAttr
+setupGenericStoreAnchorLayout(xegpu::LayoutKind layoutKind,
+ mlir::MLIRContext *context, bool isChunkedStore,
+ int maxChunkSize, ArrayRef<int64_t> srcShape,
+ int subgroupSize) {
+
+ int srcShapeSize = srcShape.size();
+ SmallVector<int> instData(srcShapeSize, 1);
+ SmallVector<int> laneLayout(srcShapeSize, 1);
+ SmallVector<int> laneData(srcShapeSize, 1);
if (layoutKind == xegpu::LayoutKind::Subgroup) {
- return consumerLayout;
+ assert(true &&
+ "subgroup layout assignment not supported for storeScatter.");
+ return nullptr;
}
- if (chunkSize == 1) {
+ if (!isChunkedStore) {
if (layoutKind == xegpu::LayoutKind::InstData) {
- instData[resShapeSize - 1] =
- std::min(static_cast<int>(consumerInstData[resShapeSize - 1]),
- uArchInstruction->getMaxLaneLoadStoreSize() * subgroupSize);
- requiredLayout = xegpu::LayoutAttr::get(context, instData);
+ instData[srcShapeSize - 1] = subgroupSize;
+ return xegpu::LayoutAttr::get(context, instData);
} else if (layoutKind == xegpu::LayoutKind::Lane) {
- laneLayout[resShapeSize - 1] = subgroupSize;
- laneData[resShapeSize - 1] =
- std::min(static_cast<int>(consumerLaneData[resShapeSize - 1]),
- uArchInstruction->getMaxLaneLoadStoreSize());
- requiredLayout = xegpu::LayoutAttr::get(context, laneLayout, laneData);
+ laneLayout[srcShapeSize - 1] = subgroupSize;
+ return xegpu::LayoutAttr::get(context, laneLayout, laneData);
}
} else {
- assert(resVecTy.getRank() == 2 &&
- "Chunked Store must access 2D tensor tile.");
+ assert(srcShapeSize == 2 && "Chunked Store must access 2D tensor tile.");
if (layoutKind == xegpu::LayoutKind::InstData) {
instData[0] = subgroupSize;
- instData[1] = std::min(static_cast<int>(consumerInstData[1]),
- uArchInstruction->getMaxLaneLoadStoreSize());
- requiredLayout = xegpu::LayoutAttr::get(context, instData);
+ instData[1] = std::min(static_cast<int>(srcShape[1]), maxChunkSize);
+ return xegpu::LayoutAttr::get(context, instData);
} else if (layoutKind == xegpu::LayoutKind::Lane) {
laneLayout[0] = subgroupSize;
- laneData[1] = std::min(static_cast<int>(consumerLaneData[1]),
- uArchInstruction->getMaxLaneLoadStoreSize());
- requiredLayout = xegpu::LayoutAttr::get(context, laneLayout, laneData);
+ laneData[1] = std::min(static_cast<int>(srcShape[1]), maxChunkSize);
+ return xegpu::LayoutAttr::get(context, laneLayout, laneData);
}
}
- return requiredLayout;
+ return nullptr;
}
+/// Sets up the anchor layout for a store scatter operation.
xegpu::DistributeLayoutAttr
-xegpu::setupStoreScatterAnchorLayout(LayoutKind layoutKind, VectorType srcVecTy,
- int chunkSize, const uArch::uArch *uArch) {
+xegpu::setupStoreScatterAnchorLayout(xegpu::LayoutKind layoutKind,
+ VectorType srcVecTy, int chunkSize,
+ const uArch::uArch *uArch) {
- xegpu::DistributeLayoutAttr requiredLayout;
const int subgroupSize = uArch->getSubgroupSize();
-
- auto srcShape = srcVecTy.getShape();
- int srcShapeSize = srcShape.size();
- SmallVector<int> instData(srcShapeSize);
+ ArrayRef<int64_t> srcShape = srcVecTy.getShape();
+ auto context = srcVecTy.getContext();
const auto *uArchInstruction =
dyn_cast<xegpu::uArch::StoreScatterInstruction>(
uArch->getInstruction(xegpu::uArch::InstructionKind::StoreScatter));
+ int maxChunkSize = uArchInstruction->getMaxLaneLoadStoreSize();
+
+ return setupGenericStoreAnchorLayout(layoutKind, context, (chunkSize > 1),
+ maxChunkSize, srcShape, subgroupSize);
+}
+
+/// Sets up the anchor layout for a store matrix operation.
+xegpu::DistributeLayoutAttr
+xegpu::setupStoreMatrixAnchorLayout(xegpu::LayoutKind layoutKind,
+ VectorType srcVecTy,
+ const xegpu::uArch::uArch *uArch) {
+
+ const int subgroupSize = uArch->getSubgroupSize();
+ ArrayRef<int64_t> srcShape = srcVecTy.getShape();
auto context = srcVecTy.getContext();
- switch (layoutKind) {
- case xegpu::LayoutKind::Subgroup:
- assert(
- true &&
- "subgroup layout assignment not supported yet for store scatter op.");
- break;
- case xegpu::LayoutKind::InstData:
- if (srcVecTy.getRank() == 1) {
- instData[0] = subgroupSize;
- } else {
- assert((srcVecTy.getRank() <= 2) && "StoreScatterOp can access 2D tensor "
- "tile at maximum at subgroup level.");
- if (chunkSize == 1) {
- instData[0] = 1;
- instData[1] = subgroupSize;
- } else {
- instData[0] = subgroupSize;
- instData[1] = std::min(static_cast<int>(srcShape[1]),
- uArchInstruction->getMaxLaneLoadStoreSize());
- }
- }
- requiredLayout = xegpu::LayoutAttr::get(
- context, DenseI32ArrayAttr::get(context, instData));
- break;
- case xegpu::LayoutKind::Lane:
- if (chunkSize == 1)
- requiredLayout =
- getDefaultLaneLayoutAttr(context, srcVecTy.getRank(), uArch);
- else {
- assert((srcVecTy.getRank() <= 2) && "StoreScatterOp can access 2D tensor "
- "tile at maximum at subgroup level.");
- assert(srcShape[1] <= static_cast<int64_t>(
- uArchInstruction->getMaxLaneLoadStoreSize()) &&
- "StoreScatterOp lane size exceeds max lane load/store size.");
- requiredLayout = xegpu::LayoutAttr::get(
- context, {subgroupSize, 1}, {1, static_cast<int>(srcShape[1])});
- }
- break;
- default:
- llvm_unreachable("unsupported layout kind");
- }
- return requiredLayout;
+ const auto *uArchInstruction = dyn_cast<xegpu::uArch::StoreMatrixInstruction>(
+ uArch->getInstruction(xegpu::uArch::InstructionKind::StoreMatrix));
+ int maxChunkSize = uArchInstruction->getMaxLaneLoadStoreSize();
+
+ return setupGenericStoreAnchorLayout(layoutKind, context, false, maxChunkSize,
+ srcShape, subgroupSize);
}
\ No newline at end of file
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 15547b33765fd..a445df492bf1e 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -308,27 +308,6 @@ static LayoutInfo getSIMTLayoutInfoBlockIO(Ty ty,
ty.getContext(), {1, uArch->getSubgroupSize()}, {1, packingFactor}));
}
-/// Helper to get the default layout for a vector type.
-static LayoutInfo getSIMTLayoutInfoScatterIO(VectorType vectorTy,
- const xegpu::uArch::uArch *uArch) {
- // Expecting a 1D or 2D vector.
- assert((vectorTy.getRank() == 1 || vectorTy.getRank() == 2) &&
- "Expected 1D or 2D vector.");
- // Expecting int or float element type.
- assert(vectorTy.getElementType().isIntOrFloat() &&
- "Expected int or float element type.");
- // If the rank is 1, then return default layout for 1D vector.
- const unsigned packingSize{uArch->getGeneralPackedFormatBitSize()};
- if (vectorTy.getRank() == 1)
- return getDefaultSIMTLayoutInfo(vectorTy.getContext(), 1, uArch);
- // Packing factor is determined by the element type bitwidth.
- unsigned bitwidth = vectorTy.getElementType().getIntOrFloatBitWidth();
- int packingFactor = bitwidth < packingSize ? packingSize / bitwidth : 1;
- return LayoutInfo(xegpu::LayoutAttr::get(vectorTy.getContext(),
- {uArch->getSubgroupSize(), 1},
- {1, packingFactor}));
-}
-
/// Helper Function to get the expected layouts for DPAS operands. `lane_data`
/// is set according to the following criteria:
/// * For A operand, the data must be packed in minimum
>From e1f27e3b33097214f60d913dbd8d4c3d23a5d0db Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Sat, 31 Jan 2026 01:15:52 +0000
Subject: [PATCH 27/35] add test for create_memdesc and alloca
---
.../XeGPU/subgroup-distribute-unit.mlir | 29 +++++++++++++++++++
1 file changed, 29 insertions(+)
diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir
index 3a978f68ad9c0..4a6c81d1309c0 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir
@@ -483,7 +483,36 @@ gpu.func @memref_extract_aligned_pointer_as_index(%arg0 : memref<256x256xf16>, %
gpu.return
}
+// CHECK-LABEL: gpu.func @memref_alloca(
+// CHECK-NEXT: %[[ALLOCA:.*]] = memref.alloca() : memref<2048xi8, 3>
+// CHECK-NEXT: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ALLOCA]] : memref<2048xi8, 3> -> index
+// CHECK-NEXT: %[[CAST:.*]] = arith.index_cast %[[INTPTR]] : index to i64
+gpu.func @memref_alloca(%laneid: index) {
+ %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (memref<2048xi8, 3>) {
+ %alloca = memref.alloca() : memref<2048xi8, 3>
+ gpu.yield %alloca : memref<2048xi8, 3>
+ }
+ %ptr = memref.extract_aligned_pointer_as_index %r : memref<2048xi8, 3> -> index
+ %ptr_i64 = arith.index_cast %ptr : index to i64
+ "some_user_op"(%ptr_i64) : (i64) -> ()
+ gpu.return
+}
+// CHECK-LABEL: gpu.func @create_memdesc(
+// CHECK: %[[W:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (!xegpu.mem_desc<4x128xf32>, memref<2048xi8, 3>) {
+// CHECK: gpu.yield %{{.*}}, %{{.*}} : !xegpu.mem_desc<4x128xf32>, memref<2048xi8, 3>
+// CHECK-NEXT: }
+// CHECK-NEXT: %[[MDesc:.*]] = xegpu.create_mem_desc %[[W]]#1 : memref<2048xi8, 3> -> !xegpu.mem_desc<4x128xf32>
+gpu.func @create_memdesc(%laneid: index, %arg0 : memref<2048xi8, 3>) {
+ %c0 = arith.constant 0 : index
+ %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (!xegpu.mem_desc<4x128xf32>) {
+ %mdesc = xegpu.create_mem_desc %arg0 : memref<2048xi8, 3> -> !xegpu.mem_desc<4x128xf32>
+ gpu.yield %mdesc : !xegpu.mem_desc<4x128xf32>
+ }
+ %25 = xegpu.load_matrix %r[%c0, %c0]: !xegpu.mem_desc<4x128xf32>, index, index -> vector<1x16xf32>
+ "some_user_op"(%25) : (vector<1x16xf32>) -> ()
+ gpu.return
+}
// CHECK-LABEL: gpu.func @vector_transpose(
// CHECK: %[[W:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<2x1xf32>, vector<1x2xf32>) {
>From bea8f8994ba1f31dc4b90bde1607daf271af85ed Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Mon, 2 Feb 2026 19:26:43 +0000
Subject: [PATCH 28/35] remove alloca and create_memdesc pattern
---
.../Transforms/XeGPUSubgroupDistribute.cpp | 60 +------------------
1 file changed, 1 insertion(+), 59 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index fe0042fc3827f..75875c60f1012 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -1896,63 +1896,6 @@ struct MemrefExtractAlignedPointerAsIndexDistribution final
}
};
-struct MemrefAllocaDistribution final : public gpu::WarpDistributionPattern {
- using gpu::WarpDistributionPattern::WarpDistributionPattern;
- LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
- PatternRewriter &rewriter) const override {
- OpOperand *operand = getWarpResult(warpOp, [&](Operation *op) {
- // Check if the yield operand that was produced by the *last* scattered
- // load op to avoid creating multiple copies due to multiple users.
- return llvm::IsaPred<memref::AllocaOp>(op) &&
- warpOp.getTerminator()->getPrevNode() == op;
- });
- if (!operand)
- return rewriter.notifyMatchFailure(
- warpOp, "warp result is not a memref::Alloca op");
- auto allocaOp = operand->get().getDefiningOp<memref::AllocaOp>();
- unsigned operandIdx = operand->getOperandNumber();
- SmallVector<size_t> newRetIndices;
- gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
- rewriter, warpOp, ValueRange{}, TypeRange{}, newRetIndices);
- rewriter.setInsertionPointAfter(newWarpOp);
- auto newAllocaOp = memref::AllocaOp::create(rewriter, newWarpOp.getLoc(),
- allocaOp.getType(), nullptr);
- Value resultVal = newWarpOp.getResult(operandIdx);
- rewriter.replaceAllUsesWith(resultVal, newAllocaOp.getResult());
- return success();
- }
-};
-
-struct CreateMemDescDistribution final : public gpu::WarpDistributionPattern {
- using gpu::WarpDistributionPattern::WarpDistributionPattern;
- LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
- PatternRewriter &rewriter) const override {
- OpOperand *operand = getWarpResult(warpOp, [&](Operation *op) {
- // Check if the yield operand that was produced by the *last* scattered
- // load op to avoid creating multiple copies due to multiple users.
- return llvm::IsaPred<xegpu::CreateMemDescOp>(op) &&
- warpOp.getTerminator()->getPrevNode() == op;
- });
- if (!operand)
- return rewriter.notifyMatchFailure(
- warpOp, "warp result is not a xegpu::CreateMemDesc op");
- auto createMemDescOp =
- operand->get().getDefiningOp<xegpu::CreateMemDescOp>();
- unsigned operandIdx = operand->getOperandNumber();
- SmallVector<size_t> newRetIndices;
- gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
- rewriter, warpOp, createMemDescOp.getSource(),
- TypeRange{createMemDescOp.getSource().getType()}, newRetIndices);
- rewriter.setInsertionPointAfter(newWarpOp);
- auto newCreateMemDescOp = xegpu::CreateMemDescOp::create(
- rewriter, newWarpOp.getLoc(), createMemDescOp.getType(),
- newWarpOp.getResult(newRetIndices[0]));
- Value resultVal = newWarpOp.getResult(operandIdx);
- rewriter.replaceAllUsesWith(resultVal, newCreateMemDescOp.getResult());
- return success();
- }
-};
-
/// Distribute a vector::BitCastOp feeding into yield op of an enclosing
/// `gpu.warp_execute_on_lane_0` region. Bitcast only impacts the innermost
/// dimension of the source/result vectors. Equivalent vector::BitCastOp is
@@ -2076,8 +2019,7 @@ void xegpu::populateXeGPUSubgroupDistributePatterns(
LoadDistribution, StoreDistribution, VectorTransposeDistribution,
VectorBitcastDistribution, LoadMatrixDistribution,
StoreMatrixDistribution,
- MemrefExtractAlignedPointerAsIndexDistribution,
- MemrefAllocaDistribution, CreateMemDescDistribution>(
+ MemrefExtractAlignedPointerAsIndexDistribution>(
patterns.getContext(),
/*pattern benefit=*/PatternHierarchy::Regular);
// For following patterns, we need to override the regular vector distribution
>From 9b96365ada303d1819100847aeb5902ca23007b8 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Mon, 2 Feb 2026 23:08:57 +0000
Subject: [PATCH 29/35] remove switch default
---
.../XeGPU/Transforms/XeGPULayoutImpls.h | 31 ++++++----
.../XeGPU/Transforms/XeGPULayoutImpls.cpp | 58 +++++++------------
2 files changed, 39 insertions(+), 50 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpls.h b/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpls.h
index a4aebb3eb7cac..b7b7cf9dafd38 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpls.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpls.h
@@ -1,5 +1,4 @@
-//===- XeGPULayoutUtils.h - Layout Utilities --------------------------*- C++
-//-*-===//
+//===- XeGPULayoutImpls.h - Layout utility functions ------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -7,8 +6,8 @@
//
//===----------------------------------------------------------------------===//
-#ifndef MLIR_DIALECT_XEGPU_UTILS_XEGPULAYOUTUTILS_H_
-#define MLIR_DIALECT_XEGPU_UTILS_XEGPULAYOUTUTILS_H_
+#ifndef MLIR_DIALECT_XEGPU_UTILS_XEGPULAYOUTIMPLS_H_
+#define MLIR_DIALECT_XEGPU_UTILS_XEGPULAYOUTIMPLS_H_
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
@@ -132,28 +131,36 @@ DistributeLayoutAttr setupBitCastResultLayout(
LayoutKind layoutKind, VectorType srcVectorTy, VectorType resVectorTy,
DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch);
+/// Sets up the result layout for an insert strided slice operation.
+/// Creates a result layout based on the specified layout kind (InstData or
+/// Lane).
DistributeLayoutAttr setupInsertStridedSliceResultLayout(
LayoutKind layoutKind, VectorType resVectorTy,
DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch);
+/// Sets up the anchor layout for a load gather operation.
DistributeLayoutAttr
-setupLoadMatrixAnchorLayout(LayoutKind layoutKind, VectorType vectorTy,
- DistributeLayoutAttr consumerLayout,
+setupLoadGatherAnchorLayout(LayoutKind layoutKind, VectorType vectorTy,
+ int chunkSize, DistributeLayoutAttr consumerLayout,
const uArch::uArch *uArch);
-DistributeLayoutAttr setupStoreMatrixAnchorLayout(LayoutKind layoutKind,
- VectorType vectorTy,
- const uArch::uArch *uArch);
-
+/// Sets up the anchor layout for load matrix operation.
DistributeLayoutAttr
-setupLoadGatherAnchorLayout(LayoutKind layoutKind, VectorType vectorTy,
- int chunkSize, DistributeLayoutAttr consumerLayout,
+setupLoadMatrixAnchorLayout(LayoutKind layoutKind, VectorType vectorTy,
+ DistributeLayoutAttr consumerLayout,
const uArch::uArch *uArch);
+/// Sets up the anchor layout for a store scatter operation.
DistributeLayoutAttr setupStoreScatterAnchorLayout(LayoutKind layoutKind,
VectorType vectorTy,
int chunkSize,
const uArch::uArch *uArch);
+
+/// Sets up the anchor layout for a store matrix operation.
+DistributeLayoutAttr setupStoreMatrixAnchorLayout(LayoutKind layoutKind,
+ VectorType vectorTy,
+ const uArch::uArch *uArch);
+
} // namespace xegpu
} // namespace mlir
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpls.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpls.cpp
index a9428be4cadeb..4732da6db6f7a 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpls.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpls.cpp
@@ -1,4 +1,5 @@
-//===---- XeGPUUtils.cpp - MLIR Utilities for XeGPUOps ------------------===//
+//===---- XeGPULayoutImpls.cpp - MLIR Utilities for XeGPUOps
+//------------------===//
//
// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,7 +7,8 @@
//
//===----------------------------------------------------------------------===//
//
-// This file implements utility methods for working with the XeGPU dialect.
+// This file implements layout utility functions for XeGPU dialect
+// transformation.
//
//===----------------------------------------------------------------------===//
@@ -454,8 +456,7 @@ xegpu::SliceAttr xegpu::setupMultiReductionResultLayout(
xegpu::DistributeLayoutAttr srcLayout;
- switch (layoutKind) {
- case xegpu::LayoutKind::Subgroup: {
+ if (layoutKind == xegpu::LayoutKind::Subgroup) {
SmallVector<int64_t> sgLayout(srcRank), sgData(srcRank);
SmallVector<int64_t> consumerSgLayout =
consumerLayout.getEffectiveSgLayoutAsInt();
@@ -490,19 +491,17 @@ xegpu::SliceAttr xegpu::setupMultiReductionResultLayout(
srcLayout =
xegpu::LayoutAttr::get(context, toInt32Attr(sgLayout),
toInt32Attr(sgData), consumerLayout.getOrder());
- break;
- }
- case xegpu::LayoutKind::InstData: {
+ } else if (layoutKind == xegpu::LayoutKind::InstData) {
+
SmallVector<int64_t> instData(srcRank, 1);
instData[srcRank - 2] =
std::min(maxReduceVectorSize, srcShape[srcRank - 2]);
instData[srcRank - 1] = subgroupSize;
srcLayout = xegpu::LayoutAttr::get(context, toInt32Attr(instData));
- break;
- }
- case xegpu::LayoutKind::Lane: {
+ } else if (layoutKind == xegpu::LayoutKind::Lane) {
+
SmallVector<int64_t> laneLayout(srcRank, 1), laneData(srcRank, 1);
laneLayout[srcRank - 1] = subgroupSize;
laneData[srcRank - 2] =
@@ -510,11 +509,6 @@ xegpu::SliceAttr xegpu::setupMultiReductionResultLayout(
srcLayout = xegpu::LayoutAttr::get(context, toInt32Attr(laneLayout),
toInt32Attr(laneData),
consumerLayout.getOrder());
- break;
- }
-
- default:
- llvm_unreachable("unsupported layout kind");
}
return xegpu::SliceAttr::get(context, srcLayout,
@@ -550,13 +544,11 @@ xegpu::DistributeLayoutAttr xegpu::setupBitCastResultLayout(
// source layout.
int bitWidthRatio = srcElemTyBitWidth / resElemTyBitWidth;
int innermostDimLaneLayout = subgroupSize;
- switch (layoutKind) {
- case xegpu::LayoutKind::Subgroup:
+ if (layoutKind == xegpu::LayoutKind::Subgroup) {
assert(sgData.size() == srcShape.size() &&
"sgData must be available for all dimensions");
sgDataValue = sgData[dim];
- break;
- case xegpu::LayoutKind::InstData:
+ } else if (layoutKind == xegpu::LayoutKind::InstData) {
assert(instData.size() == srcShape.size() &&
"instData must be available for all dimensions");
instDataValue = instData[dim];
@@ -567,17 +559,13 @@ xegpu::DistributeLayoutAttr xegpu::setupBitCastResultLayout(
instDataValue *= 2;
assert((srcShape[dim] % instDataValue) == 0 &&
"srcShape, instData, and lanelayout for innermost must be 2^n !");
- break;
- case xegpu::LayoutKind::Lane:
+ } else if (layoutKind == xegpu::LayoutKind::Lane) {
assert(laneData.size() == srcShape.size() &&
"laneData must be available for all dimensions");
laneDataValue = laneData[dim];
while ((laneDataValue <= srcShape[dim]) &&
(laneDataValue % bitWidthRatio != 0))
laneDataValue *= 2;
- break;
- default:
- llvm_unreachable("unsupported layout kind");
}
// Now set only instData and laneData, preserving sgData
xegpu::DistributeLayoutAttr resLayout;
@@ -589,13 +577,13 @@ xegpu::DistributeLayoutAttr xegpu::setupBitCastResultLayout(
}
/// Sets up the result layout for an insert strided slice operation.
-/// Creates a default layout based on the specified layout kind (InstData or
+/// Creates a result layout based on the specified layout kind (InstData or
/// Lane).
/// Subgroup layout is currently not supported for this operation.
-/// InstData layout requires {1, .., subgroupSize} by default.
-/// Lane layout requires {1, ..., subgroupSize} with lane data {1, ..., 1}.
-/// The instData and laneData is adjusted to contain packed data, by checking if
-/// the consumerLayout's innermost dimension.
+/// InstData layout is first set to be {1, .., subgroupSize}.
+/// Lane layout is first set to be {1, ..., subgroupSize} with lane data {1,
+/// ..., 1}. The instData and laneData are then adjusted to hold packed data
+/// when the consumerLayout's innermost dimension matches the packed size.
xegpu::DistributeLayoutAttr xegpu::setupInsertStridedSliceResultLayout(
xegpu::LayoutKind layoutKind, VectorType resVectorTy,
xegpu::DistributeLayoutAttr consumerLayout,
@@ -620,26 +608,20 @@ xegpu::DistributeLayoutAttr xegpu::setupInsertStridedSliceResultLayout(
int packingFactor = bitwidth < packingSize ? packingSize / bitwidth : 1;
int packedDataSize = subgroupSize * packingFactor;
- switch (layoutKind) {
- case xegpu::LayoutKind::Subgroup:
+ if (layoutKind == xegpu::LayoutKind::Subgroup) {
assert(true &&
"subgroup layout assignment not supported for insertStridedSlice.");
- break;
- case xegpu::LayoutKind::InstData:
+ } else if (layoutKind == xegpu::LayoutKind::InstData) {
instData[resShapeSize - 1] = subgroupSize;
if (consumerInstData[resShapeSize - 1] == packedDataSize)
instData[resShapeSize - 1] = packedDataSize;
requiredResLayout = xegpu::LayoutAttr::get(context, instData);
- break;
- case xegpu::LayoutKind::Lane:
+ } else if (layoutKind == xegpu::LayoutKind::Lane) {
laneLayout[resShapeSize - 1] = subgroupSize;
laneData[resShapeSize - 1] = 1;
if (consumerLaneData[resShapeSize - 1] == packingFactor)
laneData[resShapeSize - 1] = packingFactor;
requiredResLayout = xegpu::LayoutAttr::get(context, laneLayout, laneData);
- break;
- default:
- llvm_unreachable("unsupported layout kind");
}
return requiredResLayout;
}
>From f92c9629a4923a4c5d16decab39830e3e0f6df53 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Tue, 3 Feb 2026 00:03:28 +0000
Subject: [PATCH 30/35] address feedback
---
.../XeGPU/Transforms/XeGPULayoutImpls.h | 2 +-
.../XeGPU/Transforms/XeGPULayoutImpls.cpp | 46 +++++++++----------
2 files changed, 22 insertions(+), 26 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpls.h b/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpls.h
index b7b7cf9dafd38..078a2d1b2ad0d 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpls.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpls.h
@@ -72,7 +72,7 @@ dropSgLayoutAndDataOnAttrs(ArrayRef<NamedAttribute> attrs);
SmallVector<NamedAttribute> dropInstDataOnAttrs(ArrayRef<NamedAttribute> attrs);
/// Infers the source layout attribute for a broadcast operation given the
-/// result layout attribute, result shape, source shape, and broadcasted dims.
+/// result layout attribute, result shape, and source shape.
DistributeLayoutAttr inferBroadcastSourceLayout(DistributeLayoutAttr resLayout,
ArrayRef<int64_t> resShape,
ArrayRef<int64_t> srcShape);
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpls.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpls.cpp
index 4732da6db6f7a..213e3c8a8293c 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpls.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpls.cpp
@@ -120,18 +120,14 @@ xegpu::removeLayoutAttr<mlir::OpOperand>(const mlir::OpOperand &operand);
void xegpu::removeLayoutAttrs(Operation *op) {
op->walk([&](Operation *nestOp) {
- for (OpOperand &opr : nestOp->getOpOperands())
- removeLayoutAttr(opr);
- for (OpResult result : nestOp->getOpResults())
- removeLayoutAttr(result);
- if (op->hasAttrOfType<DistributeLayoutAttr>("layout"))
- op->removeAttr("layout");
- if (op->hasAttrOfType<DistributeLayoutAttr>("layout_a"))
- op->removeAttr("layout_a");
- if (op->hasAttrOfType<DistributeLayoutAttr>("layout_b"))
- op->removeAttr("layout_b");
- if (op->hasAttrOfType<DistributeLayoutAttr>("layout_cd"))
- op->removeAttr("layout_cd");
+ // Remove all attributes of DistributeLayoutAttr type
+ SmallVector<StringAttr> attrsToRemove;
+ for (auto namedAttr : nestOp->getAttrs()) {
+ if (isa<DistributeLayoutAttr>(namedAttr.getValue()))
+ attrsToRemove.push_back(namedAttr.getName());
+ }
+ for (auto attrName : attrsToRemove)
+ nestOp->removeAttr(attrName);
});
}
@@ -448,15 +444,15 @@ xegpu::SliceAttr xegpu::setupMultiReductionResultLayout(
consumerSliceLayout ? consumerSliceLayout.flatten().getParent()
: consumerLayout;
- auto sgLayoutVec = plainLayout.getEffectiveSgLayoutAsInt();
- const int workgroupSize = std::accumulate(
- sgLayoutVec.begin(), sgLayoutVec.end(), 1, std::multiplies<int64_t>());
const int subgroupSize = uArch->getSubgroupSize();
int64_t maxReduceVectorSize = 1; // could extend to spirv vector Size
xegpu::DistributeLayoutAttr srcLayout;
if (layoutKind == xegpu::LayoutKind::Subgroup) {
+ auto sgLayoutVec = plainLayout.getEffectiveSgLayoutAsInt();
+ const int workgroupSize = std::accumulate(
+ sgLayoutVec.begin(), sgLayoutVec.end(), 1, std::multiplies<int64_t>());
SmallVector<int64_t> sgLayout(srcRank), sgData(srcRank);
SmallVector<int64_t> consumerSgLayout =
consumerLayout.getEffectiveSgLayoutAsInt();
@@ -631,17 +627,21 @@ xegpu::DistributeLayoutAttr xegpu::setupInsertStridedSliceResultLayout(
/// same layout setup logic.
/// For Subgroup layout, uses the consumer layout directly.
/// non-chunked loads:
-/// InstData = {1, ..., min(consumer, maxLaneLoadStoreSize * subgroupSize)}
+/// InstData = {1, ..., min(consumer, maxLaneLoadSize * subgroupSize)}
/// LaneLayout = {1, ..., subgroupSize}
-/// lane_data = {1, ..., min(consumer, maxLaneLoadStoreSize)}
+/// lane_data = {1, ..., min(consumer, maxLaneLoadSize)}
/// chunked loads:
-/// InstData = {subgroupSize, min(consumer, maxLaneLoadStoreSize)}
+/// InstData = {subgroupSize, min(consumer, maxLaneLoadSize)}
/// LaneLayout = {subgroupSize, 1}
-/// lane_data={1,min(consumer, maxLaneLoadStoreSize)}
+/// lane_data={1,min(consumer, maxLaneLoadSize)}
static xegpu::DistributeLayoutAttr setupGenericLoadAnchorLayout(
xegpu::LayoutKind layoutKind, mlir::MLIRContext *context,
xegpu::DistributeLayoutAttr consumerLayout, bool isChunkedLoad,
int maxChunkSize, int valShapeSize, int subgroupSize) {
+
+ if (layoutKind == xegpu::LayoutKind::Subgroup)
+ return consumerLayout;
+
SmallVector<int64_t> consumerInstData =
consumerLayout.getEffectiveInstDataAsInt();
SmallVector<int64_t> consumerLaneData =
@@ -651,10 +651,6 @@ static xegpu::DistributeLayoutAttr setupGenericLoadAnchorLayout(
SmallVector<int> laneLayout(valShapeSize, 1);
SmallVector<int> laneData(valShapeSize, 1);
- if (layoutKind == xegpu::LayoutKind::Subgroup) {
- return consumerLayout;
- }
-
if (!isChunkedLoad) {
if (layoutKind == xegpu::LayoutKind::InstData) {
instData[valShapeSize - 1] =
@@ -731,9 +727,9 @@ xegpu::setupLoadMatrixAnchorLayout(xegpu::LayoutKind layoutKind,
/// LaneLayout = {1, ..., subgroupSize}
/// lane_data = {1, ..., 1}
/// chunked stores:
-/// InstData = {subgroupSize, min(srcVec, maxLaneLoadStoreSize)}
+/// InstData = {subgroupSize, min(srcVec, maxLaneStoreSize)}
/// LaneLayout = {subgroupSize, 1}
-/// lane_data={1,min(srcVec, maxLaneLoadStoreSize)}
+/// lane_data={1,min(srcVec, maxLaneStoreSize)}
static xegpu::DistributeLayoutAttr
setupGenericStoreAnchorLayout(xegpu::LayoutKind layoutKind,
mlir::MLIRContext *context, bool isChunkedStore,
>From e9a38ecd5a1b25c74ceefeba319e6931c46a9136 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Tue, 3 Feb 2026 00:24:49 +0000
Subject: [PATCH 31/35] add load/store instruction interface
---
.../mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h | 44 +++++--------------
.../mlir/Dialect/XeGPU/uArch/uArchBase.h | 44 +++++++++++++++++++
.../XeGPU/Transforms/XeGPULayoutImpls.cpp | 14 +++---
3 files changed, 64 insertions(+), 38 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
index 0f138b9defb5c..92b47b9c7d448 100644
--- a/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
@@ -215,48 +215,28 @@ struct SubgroupMatrixMultiplyAcc : public Instruction,
const unsigned packedFormatBitSizeB;
};
-struct StoreScatterInstruction : public Instruction {
- StoreScatterInstruction()
- : Instruction(InstructionKind::StoreScatter, InstructionScope::Lane) {}
- static bool classof(const Instruction *B) {
- return B->getInstructionKind() == InstructionKind::StoreScatter;
+struct LoadGatherInstruction : public LoadGatherInstructionInterface {
+ int32_t getMaxLaneLoadSize(int32_t bitWidth) const override {
+ return 16; // SPIRV restricts vector size
}
-
- // SPIRV restricts vector size
- int32_t getMaxLaneLoadStoreSize() const { return 16; }
};
-struct LoadGatherInstruction : public Instruction {
- LoadGatherInstruction()
- : Instruction(InstructionKind::LoadGather, InstructionScope::Lane) {}
- static bool classof(const Instruction *B) {
- return B->getInstructionKind() == InstructionKind::LoadGather;
+struct StoreScatterInstruction : public StoreScatterInstructionInterface {
+ int32_t getMaxLaneStoreSize(int32_t bitWidth) const override {
+ return 16; // SPIRV restricts vector size
}
-
- // SPIRV restricts vector size
- int32_t getMaxLaneLoadStoreSize() const { return 16; }
};
-struct StoreMatrixInstruction : public Instruction {
- StoreMatrixInstruction()
- : Instruction(InstructionKind::StoreMatrix, InstructionScope::Lane) {}
- static bool classof(const Instruction *B) {
- return B->getInstructionKind() == InstructionKind::StoreMatrix;
+struct LoadMatrixInstruction : public LoadMatrixInstructionInterface {
+ int32_t getMaxLaneLoadSize(int32_t bitWidth) const override {
+ return 16; // SPIRV restricts vector size
}
-
- // SPIRV restricts vector size
- int32_t getMaxLaneLoadStoreSize() const { return 16; }
};
-struct LoadMatrixInstruction : public Instruction {
- LoadMatrixInstruction()
- : Instruction(InstructionKind::LoadMatrix, InstructionScope::Lane) {}
- static bool classof(const Instruction *B) {
- return B->getInstructionKind() == InstructionKind::LoadMatrix;
+struct StoreMatrixInstruction : public StoreMatrixInstructionInterface {
+ int32_t getMaxLaneStoreSize(int32_t bitWidth) const override {
+ return 16; // SPIRV restricts vector size
}
-
- // SPIRV restricts vector size
- int32_t getMaxLaneLoadStoreSize() const { return 16; }
};
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
index 75a483baa116d..f927c75d3bbe5 100644
--- a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
@@ -256,6 +256,50 @@ struct MMAInstructionInterface {
virtual ~MMAInstructionInterface() = default;
};
+struct LoadGatherInstructionInterface : public Instruction {
+ LoadGatherInstructionInterface()
+ : Instruction(InstructionKind::LoadGather, InstructionScope::Lane) {}
+ static bool classof(const Instruction *B) {
+ return B->getInstructionKind() == InstructionKind::LoadGather;
+ }
+
+ virtual int32_t getMaxLaneLoadSize(int32_t bitWidth) const = 0;
+ virtual ~LoadGatherInstructionInterface() = default;
+};
+
+struct StoreScatterInstructionInterface : public Instruction {
+ StoreScatterInstructionInterface()
+ : Instruction(InstructionKind::StoreScatter, InstructionScope::Lane) {}
+ static bool classof(const Instruction *B) {
+ return B->getInstructionKind() == InstructionKind::StoreScatter;
+ }
+
+ virtual int32_t getMaxLaneStoreSize(int32_t bitWidth) const = 0;
+ virtual ~StoreScatterInstructionInterface() = default;
+};
+
+struct LoadMatrixInstructionInterface : public Instruction {
+ LoadMatrixInstructionInterface()
+ : Instruction(InstructionKind::LoadMatrix, InstructionScope::Lane) {}
+ static bool classof(const Instruction *B) {
+ return B->getInstructionKind() == InstructionKind::LoadMatrix;
+ }
+
+ virtual int32_t getMaxLaneLoadSize(int32_t bitWidth) const = 0;
+ virtual ~LoadMatrixInstructionInterface() = default;
+};
+
+struct StoreMatrixInstructionInterface : public Instruction {
+ StoreMatrixInstructionInterface()
+ : Instruction(InstructionKind::StoreMatrix, InstructionScope::Lane) {}
+ static bool classof(const Instruction *B) {
+ return B->getInstructionKind() == InstructionKind::StoreMatrix;
+ }
+
+ virtual int32_t getMaxLaneStoreSize(int32_t bitWidth) const = 0;
+ virtual ~StoreMatrixInstructionInterface() = default;
+};
+
} // namespace uArch
} // namespace xegpu
} // namespace mlir
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpls.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpls.cpp
index 213e3c8a8293c..c3845abbc81b9 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpls.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpls.cpp
@@ -688,10 +688,11 @@ xegpu::DistributeLayoutAttr xegpu::setupLoadGatherAnchorLayout(
const int subgroupSize = uArch->getSubgroupSize();
int resShapeSize = resVecTy.getShape().size();
auto context = resVecTy.getContext();
+ auto elemBitWidth = resVecTy.getElementType().getIntOrFloatBitWidth();
const auto *uArchInstruction = dyn_cast<xegpu::uArch::LoadGatherInstruction>(
uArch->getInstruction(xegpu::uArch::InstructionKind::LoadGather));
- int maxChunkSize = uArchInstruction->getMaxLaneLoadStoreSize();
+ int maxChunkSize = uArchInstruction->getMaxLaneLoadSize(elemBitWidth);
return setupGenericLoadAnchorLayout(layoutKind, context, consumerLayout,
(chunkSize > 1), maxChunkSize,
@@ -709,11 +710,11 @@ xegpu::setupLoadMatrixAnchorLayout(xegpu::LayoutKind layoutKind,
const int subgroupSize = uArch->getSubgroupSize();
int resShapeSize = resVecTy.getShape().size();
auto context = resVecTy.getContext();
+ auto elemBitWidth = resVecTy.getElementType().getIntOrFloatBitWidth();
const auto *uArchInstruction = dyn_cast<xegpu::uArch::LoadMatrixInstruction>(
uArch->getInstruction(xegpu::uArch::InstructionKind::LoadMatrix));
- int maxChunkSize = uArchInstruction->getMaxLaneLoadStoreSize();
-
+ int maxChunkSize = uArchInstruction->getMaxLaneLoadSize(elemBitWidth);
return setupGenericLoadAnchorLayout(layoutKind, context, consumerLayout,
false, maxChunkSize, resShapeSize,
subgroupSize);
@@ -779,12 +780,12 @@ xegpu::setupStoreScatterAnchorLayout(xegpu::LayoutKind layoutKind,
const int subgroupSize = uArch->getSubgroupSize();
ArrayRef<int64_t> srcShape = srcVecTy.getShape();
auto context = srcVecTy.getContext();
+ auto elemBitWidth = srcVecTy.getElementType().getIntOrFloatBitWidth();
const auto *uArchInstruction =
dyn_cast<xegpu::uArch::StoreScatterInstruction>(
uArch->getInstruction(xegpu::uArch::InstructionKind::StoreScatter));
- int maxChunkSize = uArchInstruction->getMaxLaneLoadStoreSize();
-
+ int maxChunkSize = uArchInstruction->getMaxLaneStoreSize(elemBitWidth);
return setupGenericStoreAnchorLayout(layoutKind, context, (chunkSize > 1),
maxChunkSize, srcShape, subgroupSize);
}
@@ -798,10 +799,11 @@ xegpu::setupStoreMatrixAnchorLayout(xegpu::LayoutKind layoutKind,
const int subgroupSize = uArch->getSubgroupSize();
ArrayRef<int64_t> srcShape = srcVecTy.getShape();
auto context = srcVecTy.getContext();
+ auto elemBitWidth = srcVecTy.getElementType().getIntOrFloatBitWidth();
const auto *uArchInstruction = dyn_cast<xegpu::uArch::StoreMatrixInstruction>(
uArch->getInstruction(xegpu::uArch::InstructionKind::StoreMatrix));
- int maxChunkSize = uArchInstruction->getMaxLaneLoadStoreSize();
+ int maxChunkSize = uArchInstruction->getMaxLaneStoreSize(elemBitWidth);
return setupGenericStoreAnchorLayout(layoutKind, context, false, maxChunkSize,
srcShape, subgroupSize);
>From e2fe8b892bf623d03d164df0f66094d960d6f51e Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Tue, 3 Feb 2026 18:30:07 +0000
Subject: [PATCH 32/35] address feedback: add check for collapseDims
---
.../mlir/Dialect/XeGPU/IR/XeGPUAttrs.td | 8 ++++----
mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 19 +++++++++++++++++--
.../XeGPU/Transforms/XeGPULayoutImpls.cpp | 3 +--
3 files changed, 22 insertions(+), 8 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index 62ce7623ea14f..20dc8dd367203 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -246,8 +246,8 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> {
"int64_t": $instData,
"int64_t": $laneData)>,
InterfaceMethod<[{Derive a new layout by collapsing groups of dimensions. Each inner array in
- `dimGroups` specifies a group of dimensions that are collapsed into a single
- dimension in the derived layout.}],
+ `dimGroups` specifies a group of adjacent dimensions that are collapsed into
+ a single dimension in the derived layout.}],
"xegpu::DistributeLayoutAttr",
"collapseDims",
(ins "SmallVector<SmallVector<int64_t>>": $dimGroups)>,
@@ -527,7 +527,7 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttr]> {
DistributeLayoutAttr setDimData(int64_t dim, int64_t sgData, int64_t instData, int64_t laneData);
// Derive a new layout by collapsing groups of dimensions.
- // Each inner array in `dimGroups` specifies a set of dimensions
+ // Each inner array in `dimGroups` specifies a set of adjacent dimensions
// that are collapsed into a single dimension in the derived layout.
DistributeLayoutAttr collapseDims(SmallVector<SmallVector<int64_t>> dimGroups) const;
@@ -708,7 +708,7 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> {
DistributeLayoutAttr setDimData(int64_t dim, int64_t sgData, int64_t instData, int64_t laneData);
// Derive a new layout by collapsing groups of dimensions.
- // Each inner array in `dimGroups` specifies a set of dimensions
+ // Each inner array in `dimGroups` specifies a set of adjacent dimensions
// that are collapsed into a single dimension in the derived layout.
DistributeLayoutAttr collapseDims(SmallVector<SmallVector<int64_t>> dimGroups) const;
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 7df4af66e7b01..b36149baadec9 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -502,7 +502,7 @@ DistributeLayoutAttr LayoutAttr::setDimData(int64_t dim, int64_t sgData,
}
// Derive a new layout by collapsing groups of dimensions.
-// Each inner array in `dimGroups` specifies a set of dimensions
+// Each inner array in `dimGroups` specifies a set of adjacent dimensions
// that are collapsed into a single dimension in the derived layout.
DistributeLayoutAttr
LayoutAttr::collapseDims(SmallVector<SmallVector<int64_t>> dimGroups) const {
@@ -527,6 +527,7 @@ LayoutAttr::collapseDims(SmallVector<SmallVector<int64_t>> dimGroups) const {
SmallVector<int64_t> collapsedLaneLayout;
SmallVector<int64_t> collapsedLaneData;
SmallVector<int64_t> collapsedOrder;
+ SetVector<int64_t> coveredDims;
for (const auto &group : dimGroups) {
@@ -534,8 +535,17 @@ LayoutAttr::collapseDims(SmallVector<SmallVector<int64_t>> dimGroups) const {
int64_t collapsedSg = 1, collapsedSgD = 1, collapsedInst = 1;
int64_t collapsedLaneL = 1, collapsedLaneD = 1;
int64_t collapsedOrderValue = -1;
-
+ int64_t dimBeforeCurrent = group.front() - 1;
for (int64_t dimIdx : group) {
+ // no two groups can cover the same dimension
+ if (!coveredDims.insert(dimIdx))
+ llvm::report_fatal_error(Twine("dimension ") + Twine(dimIdx) +
+ " is covered more than once");
+ // dims within group must be adjacent
+ if (dimBeforeCurrent != (dimIdx - 1))
+ llvm::report_fatal_error("dimensions being collapsed must be adjacent");
+ dimBeforeCurrent = dimIdx;
+
collapsedSg *= sgLayout[dimIdx];
collapsedSgD *= sgData[dimIdx];
collapsedInst *= instData[dimIdx];
@@ -553,6 +563,11 @@ LayoutAttr::collapseDims(SmallVector<SmallVector<int64_t>> dimGroups) const {
collapsedOrder.push_back(collapsedOrderValue);
}
+ // check that all dimensions are covered
+ if (coveredDims.size() != sgLayout.size())
+ llvm::report_fatal_error(
+ "not all dimensions are covered in collapseGroups");
+
// Create collapsed layout
SmallVector<int32_t> collapsedSgLayout32(collapsedSgLayout.begin(),
collapsedSgLayout.end());
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpls.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpls.cpp
index c3845abbc81b9..201075b900fdb 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpls.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpls.cpp
@@ -225,7 +225,6 @@ xegpu::inferBitCastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
}
}
- // Now set only instData and laneData, preserving sgData
xegpu::DistributeLayoutAttr finalSrcLayout;
finalSrcLayout =
resLayout.setDimData(dim, sgDataValue, instDataValue, laneDataValue);
@@ -327,11 +326,11 @@ xegpu::inferShapeCastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
ArrayRef<int64_t> dst) -> bool {
// each dim in src can be mapped to one or more dims in dst whose product
// equals to the src dim
- splitDimGroups.clear();
size_t srcIdx = 0;
int64_t accumulatedSize = 1;
SmallVector<int64_t> currentDstDims;
+ splitDimGroups.clear();
for (size_t dstIdx = 0; dstIdx < dst.size(); ++dstIdx) {
if (srcIdx >= src.size())
return false;
>From 3b0bf7a3782200b811fba713d0713993a025f03f Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Tue, 3 Feb 2026 19:42:29 +0000
Subject: [PATCH 33/35] address feedback: move Dim expansion checkers to
utility functions
---
.../mlir/Dialect/XeGPU/IR/XeGPUAttrs.td | 12 ++--
.../mlir/Dialect/XeGPU/Utils/XeGPUUtils.h | 9 +++
mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 32 +++++------
.../XeGPU/Transforms/XeGPULayoutImpls.cpp | 55 ++-----------------
.../Transforms/XeGPUSubgroupDistribute.cpp | 5 +-
.../Transforms/XeGPUWgToSgDistribute.cpp | 23 +-------
mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp | 55 +++++++++++++++++++
7 files changed, 97 insertions(+), 94 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index 20dc8dd367203..1251631955580 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -226,11 +226,11 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> {
InterfaceMethod<"Derive a new layout with sg_data, inst_data and lane_data set to 1 for the specified unit dims",
"xegpu::DistributeLayoutAttr",
"setUnitDimData",
- /*args=*/(ins "const llvm::SetVector<int64_t>": $unitDims)>,
+ /*args=*/(ins "const SmallVector<int64_t>": $unitDims)>,
InterfaceMethod<"Derive a new layout with sg_lane and lane_layout set to 1 for the specified unit dims",
"xegpu::DistributeLayoutAttr",
"setUnitDimLayout",
- /*args=*/(ins "const llvm::SetVector<int64_t>": $unitDims)>,
+ /*args=*/(ins "const SmallVector<int64_t>": $unitDims)>,
InterfaceMethod<[{Delinearizes a linear ID into its multidimensional
indices based on the effective layout level.}],
"FailureOr<SmallVector<Value>>",
@@ -516,10 +516,10 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttr]> {
}
//set the layout for the sepcified unit dims: sg_data, inst_data and lane_data to 1
- DistributeLayoutAttr setUnitDimData(SetVector<int64_t> unitDims) const;
+ DistributeLayoutAttr setUnitDimData(SmallVector<int64_t> unitDims) const;
//set the layout for the sepcified unit dims: sg_lane and lane_layout to 1
- DistributeLayoutAttr setUnitDimLayout(SetVector<int64_t> unitDims) const;
+ DistributeLayoutAttr setUnitDimLayout(SmallVector<int64_t> unitDims) const;
// Derive a new layout with sg_data, inst_data and lane_data set to the
// specified values for the given dimension. Passing -1 for any parameter
@@ -697,10 +697,10 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> {
}
//set the layout for the sepcified unit dims: sg_data, inst_data and lane_data to 1
- DistributeLayoutAttr setUnitDimData(SetVector<int64_t> unitDims) const;
+ DistributeLayoutAttr setUnitDimData(SmallVector<int64_t> unitDims) const;
//set the layout for the sepcified unit dims: sg_lane and lane_layout to 1
- DistributeLayoutAttr setUnitDimLayout(SetVector<int64_t> unitDims) const;
+ DistributeLayoutAttr setUnitDimLayout(SmallVector<int64_t> unitDims) const;
// Derive a new layout with sg_data, inst_data and lane_data set to the
// specified values for the given dimension. Passing -1 for any parameter
diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
index d1fee2126daee..4443f86d1e4e2 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
@@ -186,6 +186,15 @@ bool requirePacked(const LayoutAttr layout);
/// Helper function to check if the layout requires a transpose effect.
bool requireTranspose(const LayoutAttr layout, const uArch::uArch *uArch);
+// Check if dst shape is an expansion of src shape by inserting unit dimensions.
+bool matchUnitDimExpansion(ArrayRef<int64_t> src, ArrayRef<int64_t> dst,
+ SmallVector<int64_t> &expandedUnitDims);
+
+// Check if dst shape is an expansion of src shape where each dimension in src
+// is split into one or more consecutive dimensions in dst.
+bool matchSplitDimExpansion(ArrayRef<int64_t> src, ArrayRef<int64_t> dst,
+ SmallVector<SmallVector<int64_t>> &splitDimGroups);
+
} // namespace xegpu
} // namespace mlir
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index b36149baadec9..dcf128c94d20e 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -398,7 +398,7 @@ bool LayoutAttr::isEqualTo(const xegpu::DistributeLayoutAttr &other) {
// set the layout for unit dims: sg_data, inst_data and lane_data to 1
DistributeLayoutAttr
-LayoutAttr::setUnitDimData(SetVector<int64_t> unitDims) const {
+LayoutAttr::setUnitDimData(SmallVector<int64_t> unitDims) const {
auto sgDataOpt = getSgData();
auto instDataOpt = getInstData();
auto laneDataOpt = getLaneData();
@@ -439,7 +439,7 @@ LayoutAttr::setUnitDimData(SetVector<int64_t> unitDims) const {
// set the layout for the sepcified unit dims: sg_lane and lane_layout to 1
DistributeLayoutAttr
-LayoutAttr::setUnitDimLayout(SetVector<int64_t> unitDims) const {
+LayoutAttr::setUnitDimLayout(SmallVector<int64_t> unitDims) const {
auto sgLayoutOpt = getSgLayout();
auto laneLayoutOpt = getLaneLayout();
@@ -767,8 +767,8 @@ bool SliceAttr::isEqualTo(const xegpu::DistributeLayoutAttr &other) {
// shape is of rank 2, if we want to set unit dim [0] in sliced space, it maps
// to dim [0] in parent space; if we want to set unit dim [1] in sliced space,
// it maps to dim [2] in parent space.
-static SetVector<int64_t>
-mapSlicedDimsToParentSpace(const SetVector<int64_t> &dimsToMap,
+static SmallVector<int64_t>
+mapSlicedDimsToParentSpace(const SmallVector<int64_t> &dimsToMap,
ArrayRef<int64_t> sliceDims) {
// Rather than recovering the exact parent rank, we compute a safe upper bound
// so that dimsToMap can be adjusted safely. This upper bound is defined as
@@ -791,10 +791,10 @@ mapSlicedDimsToParentSpace(const SetVector<int64_t> &dimsToMap,
}
// Map unit dims from sliced space to parent space
- SetVector<int64_t> adjustUnitDims;
+ SmallVector<int64_t> adjustUnitDims;
for (auto dim : dimsToMap) {
int64_t mappedDim = remainingDims[dim];
- adjustUnitDims.insert(mappedDim);
+ adjustUnitDims.push_back(mappedDim);
}
return adjustUnitDims;
@@ -802,12 +802,12 @@ mapSlicedDimsToParentSpace(const SetVector<int64_t> &dimsToMap,
// set the layout for unit dims: sg_data, inst_data and lane_data to 1
DistributeLayoutAttr
-SliceAttr::setUnitDimData(SetVector<int64_t> unitDims) const {
+SliceAttr::setUnitDimData(SmallVector<int64_t> unitDims) const {
DistributeLayoutAttr parentLayout = getParent();
ArrayRef<int64_t> sliceDims = getDims().asArrayRef();
- SetVector<int64_t> adjustUnitDims =
+ SmallVector<int64_t> adjustUnitDims =
mapSlicedDimsToParentSpace(unitDims, sliceDims);
return SliceAttr::get(getContext(),
@@ -816,12 +816,12 @@ SliceAttr::setUnitDimData(SetVector<int64_t> unitDims) const {
// set the layout for the sepcified unit dims: sg_lane and lane_layout to 1
DistributeLayoutAttr
-SliceAttr::setUnitDimLayout(SetVector<int64_t> unitDims) const {
+SliceAttr::setUnitDimLayout(SmallVector<int64_t> unitDims) const {
DistributeLayoutAttr parentLayout = getParent();
ArrayRef<int64_t> sliceDims = getDims().asArrayRef();
- SetVector<int64_t> adjustUnitDims =
+ SmallVector<int64_t> adjustUnitDims =
mapSlicedDimsToParentSpace(unitDims, sliceDims);
return SliceAttr::get(
@@ -835,10 +835,10 @@ DistributeLayoutAttr SliceAttr::setDimData(int64_t dim, int64_t sgData,
ArrayRef<int64_t> sliceDims = getDims().asArrayRef();
auto parent = dyn_cast<LayoutAttr>(getParent());
- SetVector<int64_t> dimSet;
- dimSet.insert(dim);
- SetVector<int64_t> adjustDims = mapSlicedDimsToParentSpace(dimSet, sliceDims);
-
+ SmallVector<int64_t> dimSet;
+ dimSet.push_back(dim);
+ SmallVector<int64_t> adjustDims =
+ mapSlicedDimsToParentSpace(dimSet, sliceDims);
return SliceAttr::get(
getContext(),
parent.setDimData(adjustDims[0], sgData, instData, laneData), getDims());
@@ -856,8 +856,8 @@ SliceAttr::collapseDims(SmallVector<SmallVector<int64_t>> dimGroups) const {
// go through dimGroups and map each dim from sliced space to parent space
SmallVector<SmallVector<int64_t>> adjustedDimGroups;
for (const auto &group : dimGroups) {
- SetVector<int64_t> groupSet(group.begin(), group.end());
- SetVector<int64_t> mappedDims =
+ SmallVector<int64_t> groupSet(group.begin(), group.end());
+ SmallVector<int64_t> mappedDims =
mapSlicedDimsToParentSpace(groupSet, sliceDims);
adjustedDimGroups.push_back(
SmallVector<int64_t>(mappedDims.begin(), mappedDims.end()));
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpls.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpls.cpp
index 201075b900fdb..407ee36d16769 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpls.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpls.cpp
@@ -294,22 +294,8 @@ xegpu::inferShapeCastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
// Use case 1: Check if shapes only differ by expanding unit dimensions (like
// expand_dims)
SmallVector<int64_t> expandedUnitDims;
- auto checkOnlyExpandUnitDims = [&](ArrayRef<int64_t> src,
- ArrayRef<int64_t> dst) -> bool {
- // All unit dimensions in dst that don't appear in src are the expanded
- // unit dimensions
- size_t srcIdx = 0;
- for (size_t dstIdx = 0; dstIdx < dst.size(); ++dstIdx)
- if (srcIdx < src.size() && src[srcIdx] == dst[dstIdx])
- srcIdx++;
- else if (dst[dstIdx] == 1)
- expandedUnitDims.push_back(dstIdx);
- else
- return false;
- return srcIdx == src.size();
- };
- if (checkOnlyExpandUnitDims(srcShape, resShape)) {
+ if (xegpu::matchUnitDimExpansion(srcShape, resShape, expandedUnitDims)) {
// create a slice layout for the source by removing the expanded unit dims
auto sliceDimsAttr = DenseI64ArrayAttr::get(
resLayout.getContext(), ArrayRef<int64_t>(expandedUnitDims));
@@ -321,42 +307,11 @@ xegpu::inferShapeCastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
// Maps each source dimension to the range of destination dimensions it splits
// into
SmallVector<SmallVector<int64_t>> splitDimGroups;
-
- auto checkSplitDims = [&](ArrayRef<int64_t> src,
- ArrayRef<int64_t> dst) -> bool {
- // each dim in src can be mapped to one or more dims in dst whose product
- // equals to the src dim
- size_t srcIdx = 0;
- int64_t accumulatedSize = 1;
- SmallVector<int64_t> currentDstDims;
-
- splitDimGroups.clear();
- for (size_t dstIdx = 0; dstIdx < dst.size(); ++dstIdx) {
- if (srcIdx >= src.size())
- return false;
- accumulatedSize *= dst[dstIdx];
- currentDstDims.push_back(dstIdx);
-
- if (accumulatedSize == src[srcIdx]) {
- // Record the mapping: srcIdx -> currentDstDims
- splitDimGroups.push_back(currentDstDims);
- // move to next src dim
- srcIdx++;
- accumulatedSize = 1;
- currentDstDims.clear();
- } else if (accumulatedSize > src[srcIdx]) {
- return false;
- }
- }
- return srcIdx == src.size();
- };
-
- if (checkSplitDims(srcShape, resShape)) {
+ if (xegpu::matchSplitDimExpansion(srcShape, resShape, splitDimGroups))
return resLayout.collapseDims(splitDimGroups);
- }
- auto checkCombineToInnerMostDim = [&](ArrayRef<int64_t> src,
- ArrayRef<int64_t> dst) -> bool {
+ auto matchCollapseToInnermostDim = [&](ArrayRef<int64_t> src,
+ ArrayRef<int64_t> dst) -> bool {
// only one non-unit dim in dst which is the innermost dim
if ((dst.size() != 2) && (dst.size() != 1))
return false;
@@ -367,7 +322,7 @@ xegpu::inferShapeCastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
return (dst[0] == 1) && (dst[1] == srcSize);
};
- if (checkCombineToInnerMostDim(srcShape, resShape)) {
+ if (matchCollapseToInnermostDim(srcShape, resShape)) {
int srcShapeSize = srcShape.size();
int resShapeSize = resShape.size();
auto context = resLayout.getContext();
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 75875c60f1012..4b762fb363fca 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -1533,8 +1533,9 @@ struct VectorBroadcastDistribution : public gpu::WarpDistributionPattern {
}
// case 2: source and result have same rank
if (rankDiff == 0) {
- SetVector<int64_t> broadcastUnitDims =
- broadcastOp.computeBroadcastedUnitDims();
+ auto broadcastUnitDimsSet = broadcastOp.computeBroadcastedUnitDims();
+ SmallVector<int64_t> broadcastUnitDims(broadcastUnitDimsSet.begin(),
+ broadcastUnitDimsSet.end());
bool isEqualTo = sourceLayout.isEqualTo(resultLayout);
if (!isEqualTo)
return rewriter.notifyMatchFailure(
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index c0487cd709c3d..a90c218bdaae3 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -1114,27 +1114,10 @@ struct WgToSgVectorShapeCastOp
return failure();
ArrayRef<int64_t> srcShape = srcType.getShape();
- llvm::SetVector<int64_t> expandedUnitDims;
-
- // Check if shapes only differ by expanding unit dimensions (like
- // expand_dims)
- auto checkOnlyExpandUnitDims = [&](ArrayRef<int64_t> src,
- ArrayRef<int64_t> dst) -> bool {
- // All unit dimensions in dst that don't appear in src are the expanded
- // unit dimensions
- size_t srcIdx = 0;
- for (size_t dstIdx = 0; dstIdx < dst.size(); ++dstIdx)
- if (srcIdx < src.size() && src[srcIdx] == dst[dstIdx])
- srcIdx++;
- else if (dst[dstIdx] == 1)
- expandedUnitDims.insert(dstIdx);
- else
- return false;
- return srcIdx == src.size();
- };
- xegpu::DistributeLayoutAttr layoutToDistribute = layout;
- if (checkOnlyExpandUnitDims(srcShape, wgShape)) {
+ xegpu::DistributeLayoutAttr layoutToDistribute = layout;
+ SmallVector<int64_t> expandedUnitDims;
+ if (xegpu::matchUnitDimExpansion(srcShape, wgShape, expandedUnitDims)) {
xegpu::DistributeLayoutAttr sourceLayout =
xegpu::getTemporaryLayout(op->getOpOperand(0));
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index 259c5b3fa89c8..13d5a5fd54023 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -681,3 +681,58 @@ bool xegpu::requireTranspose(const xegpu::LayoutAttr layout,
return false;
return laneLayout[0] == uArch->getSubgroupSize() && laneLayout[1] == 1;
}
+
+// Checks whether dst is src with extra unit dimensions inserted. Returns true
+// if every dim of src matches, in order, a dim of dst and every other dim of
+// dst is 1. expandedUnitDims is reset on entry and then receives the dst
+// indices of the inserted unit dims; its contents are partial on failure.
+// Example: src=[2,3], dst=[1,2,3,1] -> true, expandedUnitDims=[0,3]
+bool xegpu::matchUnitDimExpansion(ArrayRef<int64_t> src, ArrayRef<int64_t> dst,
+                                  SmallVector<int64_t> &expandedUnitDims) {
+  // Drop stale entries so a reused output vector only reports this match.
+  expandedUnitDims.clear();
+  size_t srcIdx = 0;
+  for (size_t dstIdx = 0; dstIdx < dst.size(); ++dstIdx)
+    if (srcIdx < src.size() && src[srcIdx] == dst[dstIdx])
+      srcIdx++;
+    else if (dst[dstIdx] == 1)
+      expandedUnitDims.push_back(dstIdx);
+    else
+      return false;
+  return srcIdx == src.size();
+}
+
+// Checks whether dst splits each dim of src into one or more consecutive dims
+// whose product equals that src dim. On success, splitDimGroups[i] holds the
+// dst indices that src dim i was split into. Example: src=[6,4],
+// dst=[2,3,2,2] -> true, splitDimGroups=[[0,1],[2,3]]. The output is reset on
+// entry and may hold a partial result when false is returned.
+bool xegpu::matchSplitDimExpansion(
+    ArrayRef<int64_t> src, ArrayRef<int64_t> dst,
+    SmallVector<SmallVector<int64_t>> &splitDimGroups) {
+  splitDimGroups.clear();
+
+  // Grow a group of consecutive dst dims until its product reaches the
+  // current src dim, then start a new group for the next src dim.
+  SmallVector<int64_t> group;
+  int64_t product = 1;
+  size_t srcPos = 0;
+  for (size_t dstPos = 0; dstPos < dst.size(); ++dstPos) {
+    if (srcPos >= src.size())
+      return false;
+    group.push_back(dstPos);
+    product *= dst[dstPos];
+
+    // Overshot the current src dim: reject.
+    if (product > src[srcPos])
+      return false;
+    if (product < src[srcPos])
+      continue;
+
+    // Exact match: record the group and advance to the next src dim.
+    splitDimGroups.push_back(group);
+    group.clear();
+    product = 1;
+    ++srcPos;
+  }
+  return srcPos == src.size();
+}
>From ff8691babeb45bce083d1755c259f20f7b97f191 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Tue, 3 Feb 2026 23:52:27 +0000
Subject: [PATCH 34/35] adding comments and tests
---
.../XeGPU/Transforms/XeGPULayoutImpls.h | 2 +-
.../XeGPU/Transforms/XeGPULayoutImpls.cpp | 120 +++++++++++++++---
.../XeGPU/Transforms/XeGPUPropagateLayout.cpp | 2 +-
.../XeGPU/propagate-layout-inst-data.mlir | 38 ++++++
.../XeGPU/propagate-layout-subgroup.mlir | 18 +++
mlir/test/Dialect/XeGPU/propagate-layout.mlir | 39 ++++++
6 files changed, 202 insertions(+), 17 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpls.h b/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpls.h
index 078a2d1b2ad0d..758b783953732 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpls.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpls.h
@@ -135,7 +135,7 @@ DistributeLayoutAttr setupBitCastResultLayout(
/// Creates a result layout based on the specified layout kind (InstData or
/// Lane).
DistributeLayoutAttr setupInsertStridedSliceResultLayout(
- LayoutKind layoutKind, VectorType resVectorTy,
+ LayoutKind layoutKind, VectorType srcVectorTy, VectorType resVectorTy,
DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch);
/// Sets up the anchor layout for a load gather operation.
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpls.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpls.cpp
index 407ee36d16769..a008990a0a063 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpls.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpls.cpp
@@ -330,22 +330,42 @@ xegpu::inferShapeCastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
auto resLaneLayout = resLayout.getEffectiveLaneLayoutAsInt();
auto resLaneData = resLayout.getEffectiveLaneDataAsInt();
- // get the layout info from the innermost dim of result layout
+ // Extract layout info from result's innermost dimension and apply to
+ // source's innermost dimension while setting all other dimensions to 1.
+ // The inferred layout is restricted by srcShape to ensure it fits within
+ // the source dimensions.
+ // Examples:
+ // srcShape=[8, 16, 32], resShape=[1, 4096], resInstData=[1, 16]
+ // -> inferredInstData=[1, 1, min(16, 32)]=[1, 1, 16]
+ // srcShape=[4, 8, 64], resShape=[2048], resLaneLayout=[16],
+ // resLaneData=[2]
+ // -> inferredLaneData=[1, 1, min(2, 64/16)]=[1, 1, 2]
+ // -> inferredLaneLayout=[1, 1, 16]
if (resInstData.size() != 0) {
+ // assert resInstData must be 1 for all but the innermost dim
+ for (int i = 0; i < resShapeSize - 1; i++) {
+ assert(resInstData[i] == 1 &&
+ "only innermost dim can have non-unit instData");
+ }
SmallVector<int> inferredInstData(srcShapeSize, 1);
- assert((resShapeSize == 1 || resInstData[0] == 1) &&
- "only innermost dim can have data and instData layout");
- inferredInstData[srcShapeSize - 1] = resInstData[resShapeSize - 1];
+ inferredInstData[srcShapeSize - 1] =
+ std::min(resLaneData[resShapeSize - 1], srcShape[srcShapeSize - 1]);
return xegpu::LayoutAttr::get(context, inferredInstData);
}
if (resLaneLayout.size() != 0) {
+ for (int i = 0; i < resShapeSize - 1; i++) {
+ assert(resLaneData[i] == 1 &&
+ "only innermost dim can have non-unit instData");
+ }
+ assert("srcShape[srcShapeSize - 1] >= resLaneLayout[resShapeSize - 1]" &&
+ "source innermost dim must be >= result lane layout");
SmallVector<int> inferredLaneLayout(srcShapeSize, 1);
SmallVector<int> inferredLaneData(srcShapeSize, 1);
- assert((resShapeSize == 1 || resLaneLayout[0] == 1) &&
- "only innermost dim can have data and lane layout");
inferredLaneLayout[srcShapeSize - 1] = resLaneLayout[resShapeSize - 1];
- inferredLaneData[srcShapeSize - 1] = resLaneData[resShapeSize - 1];
+ inferredLaneData[srcShapeSize - 1] = std::min(
+ resLaneData[resShapeSize - 1],
+ srcShape[srcShapeSize - 1] / inferredLaneLayout[srcShapeSize - 1]);
return xegpu::LayoutAttr::get(context, inferredLaneLayout,
inferredLaneData);
}
@@ -365,12 +385,36 @@ xegpu::inferShapeCastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
/// For subgroup layouts, it first tries to align the source layout's subgroup
/// layout and data with the consumer's layout on non-reduction dimensions.
/// Then, it distributes remaining subgroups across reduction dimensions. This
-/// avoid subgroup data redistribution overhead between the reduced result and
+/// avoids subgroup data redistribution overhead between the reduced result and
/// its consumer.
///
/// InstData requries {1, ..., min(maxReduceVectorSize, srcShape),subgroupSize}
/// Lane Layout requires {1, ..., 1, subgroupSize}
/// Lane data requires {1, ..., min(maxReduceVectorSize, srcShape), 1}
+///
+/// Examples:
+/// 1. Subgroup layout - Row reduction on 2D tensor:
+/// srcShape=[32, 64], reductionDims=[1], resShape=[32], subgroupSize=16,
+/// workgroupSize=32
+/// Consumer layout:
+/// #xegpu.slice<#xegpu.layout<sg_layout=[4, 8], sg_data=[8, 8]>,
+/// dims = [0]>
+/// Result: srcLayout with sgLayout=[4, 8], sgData=[8, 8] (matches consumer on
+/// non-reduction dim, minimizing data redistribution on reduction dim)
+/// 2. Subgroup layout - Same example above but consumer has different layout:
+/// sgLayout=[32], sgData=[1]
+/// Result: srcLayout with sgLayout=[32,1], sgData=[1, 64]
+/// (distributes all subgroups on non reduction dim)
+///
+/// 3. InstData layout - Column reduction:
+/// srcShape=[32, 64], reductionDims=[0], subgroupSize=16
+/// Result: instData=[1, 16] (maxReduceVectorSize=1, subgroupSize on
+/// innermost)
+///
+/// 4. Lane layout - Multi-dimensional reduction:
+/// srcShape=[16, 32, 64], reductionDims=[1], subgroupSize=16
+/// Result: laneLayout=[1, 1, 16], laneData=[1, 1, 1]
+/// (subgroupSize on innermost dim, max vector size on reduction dim)
xegpu::SliceAttr xegpu::setupMultiReductionResultLayout(
xegpu::LayoutKind layoutKind, VectorType srcVecTy,
@@ -438,9 +482,10 @@ xegpu::SliceAttr xegpu::setupMultiReductionResultLayout(
}
assert(remainingSgCount == 1 && "not all subgroups distributed");
- srcLayout =
- xegpu::LayoutAttr::get(context, toInt32Attr(sgLayout),
- toInt32Attr(sgData), consumerLayout.getOrder());
+ srcLayout = xegpu::LayoutAttr::get(
+ context, toInt32Attr(sgLayout), toInt32Attr(sgData),
+ /*inst_data =*/nullptr, /*lane_layout =*/nullptr,
+ /*lane_data =*/nullptr, /*order =*/nullptr);
} else if (layoutKind == xegpu::LayoutKind::InstData) {
@@ -470,6 +515,23 @@ xegpu::SliceAttr xegpu::setupMultiReductionResultLayout(
/// instData, or laneData) by multiplying by the bitwidth ratio to ensure the
/// result layout can be correctly divided back to the source layout during
/// inference.
+///
+/// Examples:
+/// 1. Casting f32 -> f16 (32-bit to 16-bit, bitWidthRatio = 2):
+/// Consumer layout: instData=[1, 16], subgroupSize=16
+/// Source shape: [8, 32]
+/// Result layout: instData=[1, 32] (16 * 2)
+/// The innermost dimension is multiplied by 2 to maintain consistency.
+///
+/// 2. Casting f32 -> i8 (32-bit to 8-bit, bitWidthRatio = 4):
+/// Consumer layout: instData=[1, 16], subgroupSize=16
+/// Source shape: [4, 128]
+/// Result layout: instData=[1, 64] (16 * 4)
+///
+/// 3. Casting i8 -> i32 (8-bit to 32-bit, bitWidthRatio = 1/4):
+/// Consumer layout: laneLayout=[1, 16], laneData=[1, 4]
+/// No adjustment needed - returns consumer layout directly.
+///
xegpu::DistributeLayoutAttr xegpu::setupBitCastResultLayout(
xegpu::LayoutKind layoutKind, VectorType srcVecTy, VectorType resVecTy,
DistributeLayoutAttr consumerLayout, const xegpu::uArch::uArch *uArch) {
@@ -534,9 +596,31 @@ xegpu::DistributeLayoutAttr xegpu::setupBitCastResultLayout(
/// Lane layout is first set to be {1, ..., subgroupSize} with lane data {1,
/// ..., 1}. The instData and laneData is then adjusted to contain packed data,
/// by checking if the consumerLayout's innermost dimension.
+///
+/// Examples:
+/// 1. InstData layout without packing:
+/// resShape=[8, 32], subgroupSize=16, bitwidth=32
+/// packingFactor=1, packedDataSize=16
+/// consumerLayout: instData=[1, 16]
+/// Result: instData=[1, 16]
+///
+/// 2. InstData layout with packing:
+/// resShape=[8, 64], subgroupSize=16, bitwidth=8, packingFactor=4
+/// consumerLayout: instData=[1, 64]
+/// Result: instData=[1, 64] (adjusted for packed data)
+///
+/// 3. Lane layout without packing:
+/// resShape=[4, 64], subgroupSize=16, bitwidth=32
+/// consumerLayout: laneLayout=[1, 16], laneData=[1, 1]
+/// Result: laneLayout=[1, 16], laneData=[1, 1]
+///
+/// 4. Lane layout with packing:
+/// resShape=[4, 64], subgroupSize=16, bitwidth=16, packingFactor=2
+/// consumerLayout: laneLayout=[1, 16], laneData=[1, 2]
+/// Result: laneLayout=[1, 16], laneData=[1, 2] (adjusted for packed data)
xegpu::DistributeLayoutAttr xegpu::setupInsertStridedSliceResultLayout(
- xegpu::LayoutKind layoutKind, VectorType resVectorTy,
- xegpu::DistributeLayoutAttr consumerLayout,
+ xegpu::LayoutKind layoutKind, VectorType srcVectorTy,
+ VectorType resVectorTy, xegpu::DistributeLayoutAttr consumerLayout,
const xegpu::uArch::uArch *uArch) {
xegpu::DistributeLayoutAttr requiredResLayout;
@@ -544,6 +628,8 @@ xegpu::DistributeLayoutAttr xegpu::setupInsertStridedSliceResultLayout(
auto context = resVectorTy.getContext();
auto resShape = resVectorTy.getShape();
int resShapeSize = resShape.size();
+ auto srcShape = srcVectorTy.getShape();
+ int srcShapeSize = srcVectorTy.getShape().size();
SmallVector<int64_t> consumerInstData =
consumerLayout.getEffectiveInstDataAsInt();
SmallVector<int64_t> consumerLaneData =
@@ -562,14 +648,18 @@ xegpu::DistributeLayoutAttr xegpu::setupInsertStridedSliceResultLayout(
assert(true &&
"subgroup layout assignment not supported for insertStridedSlice.");
} else if (layoutKind == xegpu::LayoutKind::InstData) {
+ assert(srcShape[srcShapeSize - 1] >= subgroupSize &&
+ "source innermost dim must be >= subgroupSize");
instData[resShapeSize - 1] = subgroupSize;
- if (consumerInstData[resShapeSize - 1] == packedDataSize)
+ if (consumerInstData[resShapeSize - 1] == packedDataSize &&
+ srcShape[srcShapeSize - 1] >= packedDataSize)
instData[resShapeSize - 1] = packedDataSize;
requiredResLayout = xegpu::LayoutAttr::get(context, instData);
} else if (layoutKind == xegpu::LayoutKind::Lane) {
laneLayout[resShapeSize - 1] = subgroupSize;
laneData[resShapeSize - 1] = 1;
- if (consumerLaneData[resShapeSize - 1] == packingFactor)
+ if (consumerLaneData[resShapeSize - 1] == packingFactor &&
+ srcShape[srcShapeSize - 1] >= packedDataSize)
laneData[resShapeSize - 1] = packingFactor;
requiredResLayout = xegpu::LayoutAttr::get(context, laneLayout, laneData);
}
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index b723036449914..107192f162da4 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -1114,7 +1114,7 @@ void LayoutInfoPropagation::visitInsertStridedSliceOp(
auto uArch = getUArch(xegpu::getChipStr(insertStridedSlice).value_or(""));
auto requiredResLayoutAttr = xegpu::setupInsertStridedSliceResultLayout(
- layoutKind, resVecType, consumerLayoutAttr, uArch);
+ layoutKind, srcVecType, resVecType, consumerLayoutAttr, uArch);
xegpu::setTemporaryLayout(insertStridedSlice->getResult(0),
requiredResLayoutAttr);
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir b/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
index d53d8ea5bd643..595be3dfa29e5 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
@@ -234,3 +234,41 @@ func.func @scatter_ops_chunksize_slice(%src: memref<1024xf32>) {
return
}
}
+
+// -----
+gpu.module @test {
+// CHECK-LABEL: func.func @insert_strided_slice_inst_data_no_packing(
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<8x32xf32>) {
+// CHECK: %[[CST_SMALL:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [1, 16]>} dense<1.000000e+00> : vector<4x16xf32>
+// CHECK: %[[CST_LARGE:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [1, 16]>} dense<0.000000e+00> : vector<8x32xf32>
+// CHECK: %[[INSERT:.*]] = vector.insert_strided_slice %[[CST_SMALL]], %[[CST_LARGE]] {layout_result_0 = #xegpu.layout<inst_data = [1, 16]>, offsets = [0, 0], strides = [1, 1]} : vector<4x16xf32> into vector<8x32xf32>
+// CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG0]][{{.*}}] : memref<8x32xf32> -> !xegpu.tensor_desc<8x32xf32, #xegpu.layout<inst_data = [8, 16]>>
+// CHECK: xegpu.store_nd %[[INSERT]], %[[TDESC]] <{layout = #xegpu.layout<inst_data = [8, 16]>}> : vector<8x32xf32>, !xegpu.tensor_desc<8x32xf32, #xegpu.layout<inst_data = [8, 16]>>
+func.func @insert_strided_slice_inst_data_no_packing(%arg0: memref<8x32xf32>) {
+ %c0 = arith.constant 0 : index
+ %cst_small = arith.constant dense<1.0> : vector<4x16xf32>
+ %cst_large = arith.constant dense<0.0> : vector<8x32xf32>
+ %insert = vector.insert_strided_slice %cst_small, %cst_large {offsets = [0, 0], strides = [1, 1]} : vector<4x16xf32> into vector<8x32xf32>
+ %tdesc = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x32xf32> -> !xegpu.tensor_desc<8x32xf32>
+ xegpu.store_nd %insert, %tdesc : vector<8x32xf32>, !xegpu.tensor_desc<8x32xf32>
+ return
+}
+}
+
+// -----
+gpu.module @test {
+// CHECK-LABEL: func.func @insert_strided_slice_inst_data_with_packing(
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<8x64xi8>) {
+// CHECK: %[[CST_SMALL:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [1, 64]>} dense<1> : vector<4x64xi8>
+// CHECK: %[[CST_LARGE:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [1, 64]>} dense<0> : vector<8x64xi8>
+// CHECK: %[[INSERT:.*]] = vector.insert_strided_slice %[[CST_SMALL]], %[[CST_LARGE]] {layout_result_0 = #xegpu.layout<inst_data = [1, 64]>, offsets = [0, 0], strides = [1, 1]} : vector<4x64xi8> into vector<8x64xi8>
+func.func @insert_strided_slice_inst_data_with_packing(%arg0: memref<8x64xi8>) {
+ %c0 = arith.constant 0 : index
+ %cst_small = arith.constant dense<1> : vector<4x64xi8>
+ %cst_large = arith.constant dense<0> : vector<8x64xi8>
+ %insert = vector.insert_strided_slice %cst_small, %cst_large {offsets = [0, 0], strides = [1, 1]} : vector<4x64xi8> into vector<8x64xi8>
+ %tdesc = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x64xi8> -> !xegpu.tensor_desc<8x64xi8, #xegpu.layout<inst_data = [8, 64]>>
+ xegpu.store_nd %insert, %tdesc <{layout = #xegpu.layout<inst_data = [8, 64]>}>: vector<8x64xi8>, !xegpu.tensor_desc<8x64xi8, #xegpu.layout<inst_data = [8, 64]>>
+ return
+}
+}
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir b/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir
index 29e5b51627fb6..d8668236116c8 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir
@@ -123,3 +123,21 @@ gpu.module @test {
gpu.return
}
}
+
+// -----
+gpu.module @test {
+// CHECK-LABEL: vector_row_reduction_1
+// CHECK: %[[REDUCE:.*]] = vector.multi_reduction <add>, %{{.*}}, %{{.*}} {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [32, 1], sg_data = [1, 64]>, dims = [1]>}
+ gpu.func @vector_row_reduction_1(%src: memref<32x64xf32>, %dst: memref<32xf32>) kernel attributes
+ {known_block_size = array<i32: 1, 32, 1>} {
+ %cst = arith.constant dense<0.000000e+00> : vector<32xf32>
+ %tdesc_src = xegpu.create_nd_tdesc %src : memref<32x64xf32> -> !xegpu.tensor_desc<32x64xf32>
+ %load = xegpu.load_nd %tdesc_src : !xegpu.tensor_desc<32x64xf32> -> vector<32x64xf32>
+ %reduce = vector.multi_reduction <add>, %load, %cst [1] : vector<32x64xf32> to vector<32xf32>
+ %tdesc_dst = xegpu.create_nd_tdesc %dst : memref<32xf32> -> !xegpu.tensor_desc<32xf32, #xegpu.layout<sg_layout = [32], sg_data = [1]>>
+ xegpu.store_nd %reduce, %tdesc_dst <{layout = #xegpu.layout<sg_layout = [32], sg_data = [1]>}>
+ : vector<32xf32>, !xegpu.tensor_desc<32xf32, #xegpu.layout<sg_layout = [32], sg_data = [1]>>
+ gpu.return
+ }
+}
+
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout.mlir b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
index 91e790ccdb4fe..90d580cd5901e 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
@@ -707,3 +707,42 @@ func.func @store_matrix(%arg0: !xegpu.mem_desc<16x64xf16>) {
return
}
}
+
+// -----
+gpu.module @test {
+// CHECK-LABEL: func.func @insert_strided_slice_lane_layout_no_packing(
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<4x64xf32>) {
+// CHECK: %[[CST_SMALL:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} dense<1.000000e+00> : vector<2x32xf32>
+// CHECK: %[[CST_LARGE:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} dense<0.000000e+00> : vector<4x64xf32>
+// CHECK: %[[INSERT:.*]] = vector.insert_strided_slice %[[CST_SMALL]], %[[CST_LARGE]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, offsets = [0, 0], strides = [1, 1]} : vector<2x32xf32> into vector<4x64xf32>
+// CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG0]][{{.*}}] : memref<4x64xf32> -> !xegpu.tensor_desc<4x64xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+// CHECK: xegpu.store_nd %[[INSERT]], %[[TDESC]] <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}> : vector<4x64xf32>, !xegpu.tensor_desc<4x64xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+func.func @insert_strided_slice_lane_layout_no_packing(%arg0: memref<4x64xf32>) {
+ %c0 = arith.constant 0 : index
+ %cst_small = arith.constant dense<1.0> : vector<2x32xf32>
+ %cst_large = arith.constant dense<0.0> : vector<4x64xf32>
+ %insert = vector.insert_strided_slice %cst_small, %cst_large {offsets = [0, 0], strides = [1, 1]} : vector<2x32xf32> into vector<4x64xf32>
+ %tdesc = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<4x64xf32> -> !xegpu.tensor_desc<4x64xf32>
+ xegpu.store_nd %insert, %tdesc : vector<4x64xf32>, !xegpu.tensor_desc<4x64xf32>
+ return
+}
+}
+
+// -----
+gpu.module @test {
+// CHECK-LABEL: func.func @insert_strided_slice_lane_layout_with_packing(
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<4x64xf16>) {
+// CHECK: %[[CST_SMALL:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>} dense<1.000000e+00> : vector<2x32xf16>
+// CHECK: %[[CST_LARGE:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>} dense<0.000000e+00> : vector<4x64xf16>
+// CHECK: %[[INSERT:.*]] = vector.insert_strided_slice %[[CST_SMALL]], %[[CST_LARGE]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>, offsets = [0, 0], strides = [1, 1]} : vector<2x32xf16> into vector<4x64xf16>
+func.func @insert_strided_slice_lane_layout_with_packing(%arg0: memref<4x64xf16>) {
+ %c0 = arith.constant 0 : index
+ %cst_small = arith.constant dense<1.0> : vector<2x32xf16>
+ %cst_large = arith.constant dense<0.0> : vector<4x64xf16>
+ %insert = vector.insert_strided_slice %cst_small, %cst_large {offsets = [0, 0], strides = [1, 1]} : vector<2x32xf16> into vector<4x64xf16>
+ %tdesc = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<4x64xf16> -> !xegpu.tensor_desc<4x64xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>>
+ xegpu.store_nd %insert, %tdesc <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>}>: vector<4x64xf16>, !xegpu.tensor_desc<4x64xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>>
+ return
+}
+}
+
>From 4044aabb0c9e4959cc7a92e2a223a275df88b0db Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Wed, 4 Feb 2026 05:11:11 +0000
Subject: [PATCH 35/35] add reduction tests
---
.../XeGPU/Transforms/XeGPULayoutImpls.cpp | 4 ++--
.../XeGPU/propagate-layout-subgroup.mlir | 23 +++++++++++++++++++
2 files changed, 25 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpls.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpls.cpp
index a008990a0a063..cd940198ef0c6 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpls.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpls.cpp
@@ -472,8 +472,8 @@ xegpu::SliceAttr xegpu::setupMultiReductionResultLayout(
// Second pass: Distribute remaining subgroups across reduction dimensions
for (int i = srcRank - 1; i >= 0; i--) {
if (llvm::is_contained(reductionDims, i)) {
- sgLayout[i] = std::min(srcShape[i] / subgroupSize,
- static_cast<int64_t>(remainingSgCount));
+ sgLayout[i] =
+ std::min(srcShape[i], static_cast<int64_t>(remainingSgCount));
assert((srcShape[i] % sgLayout[i] == 0) &&
"source shape not divisible by sg_layout");
sgData[i] = srcShape[i] / sgLayout[i];
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir b/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir
index d8668236116c8..5e07e5336b382 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir
@@ -141,3 +141,26 @@ gpu.module @test {
}
}
+// -----
+gpu.module @test {
+// CHECK-LABEL: vector_row_reduction_2
+ gpu.func @vector_row_reduction_2(%src: memref<32x128xf32>, %dst: memref<32xf32>) kernel attributes
+ {known_block_size = array<i32: 1, 32, 1>} {
+ %cst = arith.constant dense<0.000000e+00> : vector<32xf32>
+ %cst1 = arith.constant dense<0.000000e+00> : vector<32x128xf32>
+ %tdesc_src = xegpu.create_nd_tdesc %src : memref<32x128xf32> -> !xegpu.tensor_desc<32x128xf32>
+ %load = xegpu.load_nd %tdesc_src : !xegpu.tensor_desc<32x128xf32> -> vector<32x128xf32>
+ %bcast1 = vector.broadcast %load: vector<32x128xf32> to vector<4x32x128xf32>
+
+ // CHECK: %[[BCAST1:.*]] = vector.broadcast %{{.*}} {layout_result_0 = #xegpu.layout<sg_layout = [1, 4, 8], sg_data = [4, 8, 16]>} : vector<32x128xf32> to vector<4x32x128xf32>
+ // CHECK: %[[BCAST:.*]] = vector.multi_reduction <add>, %[[BCAST1]], %{{.*}} {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [1, 4, 8], sg_data = [4, 8, 16]>, dims = [0]>} [0] : vector<4x32x128xf32> to vector<32x128xf32>
+ // CHECK: %[[REDUCE:.*]] = vector.multi_reduction <add>, %[[BCAST]], %{{.*}} {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [4, 8], sg_data = [8, 16]>, dims = [1]>} [1] : vector<32x128xf32> to vector<32xf32>
+
+ %bcast = vector.multi_reduction <add>, %bcast1, %cst1 [0]: vector<4x32x128xf32> to vector<32x128xf32>
+ %reduce = vector.multi_reduction <add>, %bcast, %cst [1] : vector<32x128xf32> to vector<32xf32>
+ %mask = arith.constant dense<1>: vector<32xi1>
+ %offset = vector.step : vector<32xindex>
+ xegpu.store %reduce, %dst[%offset], %mask {layout = #xegpu.slice<#xegpu.layout<sg_layout=[4, 8], sg_data=[8, 16]>, dims = [1]>} : vector<32xf32>, memref<32xf32>, vector<32xindex>, vector<32xi1>
+ gpu.return
+ }
+}
\ No newline at end of file
More information about the Mlir-commits
mailing list