[Mlir-commits] [mlir] [MLIR][XeGPU] Recover temporary layout from Anchor Layout (PR #191947)
Jianhui Li
llvmlistbot at llvm.org
Thu Apr 16 11:24:13 PDT 2026
https://github.com/Jianhui-Li updated https://github.com/llvm/llvm-project/pull/191947
>From 3a6c2fe41fa7953ca42e94e5663231b33052ce00 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Thu, 2 Apr 2026 22:42:50 +0000
Subject: [PATCH 01/19] initial implementation
---
.../XeGPU/Transforms/XeGPULayoutImpl.h | 5 +-
.../XeGPU/Transforms/XeGPULayoutImpl.cpp | 473 ++++++++++++++++--
mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp | 77 ++-
3 files changed, 499 insertions(+), 56 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h b/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h
index 9cf9a8705209b..5f46eab7b74c7 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h
@@ -183,10 +183,13 @@ setupDpasLayout(LayoutKind layoutKind, VectorType aTy, VectorType bTy,
VectorType cdTy, DistributeLayoutAttr consumerLayout, int numSg,
const uArch::uArch *uArch);
+/// Infers the layout that `operand` must have so that its owning op produces
+/// a result with layout `resLayout`. Handles vector broadcast, multi-dim
+/// reduction, reduction, bitcast, shape_cast, insert_strided_slice, transpose
+/// and elementwise ops; returns a null attribute when `resLayout` is null or
+/// the owning op is not one of these.
+DistributeLayoutAttr
+inferSourceLayoutFromResult(OpOperand &operand, DistributeLayoutAttr resLayout);
+
/// Gets the expected layout for a given consumer operand. This will check if
/// the owning operation of the consumer operand is one of the special layout
/// users and determine the expected layout accordingly.
-xegpu::DistributeLayoutAttr getConsumerLayoutAt(OpOperand &operand);
+DistributeLayoutAttr getConsumerLayoutAt(OpOperand &operand);
} // namespace xegpu
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
index 55cd6ec04970c..06cd0eaa0059e 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
@@ -18,16 +18,22 @@
#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/ValueRange.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/Support/Debug.h"
#include "llvm/Support/FormatVariadic.h"
#include <cstdint>
#include <numeric>
+#define DEBUG_TYPE "xegpu-layout-recovery"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+
using namespace mlir;
void xegpu::recoverTemporaryLayoutsDeprecated(Operation *op) {
@@ -80,32 +86,330 @@ xegpu::dropInstDataOnAttrs(ArrayRef<NamedAttribute> attrs) {
return out;
}
-// Attach layout attributes to all vector-type operands of operations within
-// the given operation's region. Reports an error if any vector operand lacks
-// a layout attribute.
-bool xegpu::recoverTemporaryLayouts(Operation *rootOp) {
- auto result = rootOp->walk([&](Operation *op) {
- for (OpOperand &operand : op->getOpOperands()) {
- // Layouts are needed for vector type only.
- if (!isa<VectorType>(operand.get().getType()))
- continue;
- // Skip block arguments since they don't have defining ops to attach
- // layout attributes to.
- if (isa<BlockArgument>(operand.get()))
+// Prerequisite for Layout Recovery
+// It relies on the following invariant:
+// 1. there is no layout conflict between different uses of the same definition.
+// 2. each definition has a well-defined layout requirement at its use point.
+// - Every definition must have at least one use that appears after it in
+// topological order.
+// - If a definition has no such use (e.g., a loop result or region output),
+// an explicit convert_layout operation is inserted to create a use.
+// - Only the result of convert_layout is permitted to have no subsequent
+// use.
+
+// The recovery proceeds by scanning the operation in reverse topological order
+// as follows:
+// For regular operations: First the result layouts are propagated from uses.
+// Then the result layouts are propagated to operands.
+//
+// For region operations (e.g., loops):
+// - Nested regions are walked before the op that owns them, so a region's
+// terminator (yield) is visited first: the layouts of the region op's
+// results are taken from their use points, like for regular ops, and
+// propagated to the corresponding yield operands.
+// - Once the region body has been processed and the walk reaches the
+// region op itself, the layouts of the region arguments are propagated
+// to the corresponding initialization operands of the region op.
+// - This ensures that layouts are consistently propagated
+// across region boundaries while preserving a single well-defined use for
+// each definition at the region-op level.
+
+// recoverTemporaryLayouts applies the recursive walk below to the body of
+// every func.func and gpu.func nested in the input rootOp; each region is
+// processed in reverse topological order.
+
+/// Walks `region` back to front (blocks in reverse order and, within each
+/// block, ops in reverse order), calling `visit` on every op. The regions
+/// nested in an op are walked (also back to front) before `visit` is called
+/// on the op itself, so a region terminator such as a yield is visited before
+/// the op that owns it.
+static void walkRegionBackward(Region &region,
+                               llvm::function_ref<void(Operation *)> visit) {
+  // blocks: back -> front
+  for (Block &block : llvm::reverse(region)) {
+    // ops: back -> front; early-inc so visit() may erase the op it is given
+    // without invalidating the loop.
+    for (Operation &op : llvm::make_early_inc_range(llvm::reverse(block))) {
+      // Visit inside the region op first (so its yield op comes first) and
+      // only then the region op itself.
+      for (Region &nested : llvm::reverse(op.getRegions()))
+        walkRegionBackward(nested, visit);
+
+      visit(&op);
+    }
+  }
+}
+
+/// Returns the layout expected by the uses of `result`, i.e. a non-null
+/// xegpu::getDistributeLayoutAttr(use) of one of its uses, or null when no use
+/// provides one. With assertions enabled every use is inspected and a use
+/// whose layout differs from an earlier one trips the assert; otherwise the
+/// scan stops at the first use that has a layout.
+static xegpu::DistributeLayoutAttr getLayoutFromUsePoints(Value result) {
+ xegpu::DistributeLayoutAttr layout = nullptr;
+ for (OpOperand &use : result.getUses()) {
+ if (auto tmpLayout = xegpu::getDistributeLayoutAttr(use)) {
+ // debug print the use and op, and the tmpLayout
+ LLVM_DEBUG({
+ DBGS() << " use: " << use.getOwner()->getName() << use.getOwner();
+ llvm::dbgs() << ", tmpLayout=" << tmpLayout << "\n";
+ });
+ // under debug mode, we want to check all the use points to make sure
+ // there is no conflict, so we do not break here. In release mode, we can
+ // break at the first use
+#ifndef NDEBUG
+ assert(!layout || layout == tmpLayout);
+ layout = tmpLayout;
+#else
+ layout = tmpLayout;
+ break;
+#endif
+ }
+ }
+ return layout;
+}
+
+/// Recovers layouts for an op handled as "regular" (neither a region op, a
+/// region terminator nor an anchor op): the layout of the first result is
+/// taken from its use points and then propagated to the operands. A
+/// tensor_desc first result gets the layout folded into its type; every
+/// vector operand gets a temporary layout inferred from the result layout by
+/// xegpu::inferSourceLayoutFromResult.
+static void propagateResultsToRegularOperands(Operation *op) {
+  LLVM_DEBUG(DBGS() << "propagateResultsToRegularOperands: " << op->getName()
+                    << " (" << op->getNumOperands() << " operands, "
+                    << op->getNumResults() << " results)\n");
+
+  if (op->getNumResults() == 0) {
+    LLVM_DEBUG(DBGS() << " skipping (no results)\n");
+    return;
+  }
+
+  Value result = op->getResult(0);
+  xegpu::DistributeLayoutAttr resLayout = getLayoutFromUsePoints(result);
+  Type resultType = result.getType();
+
+  // A tensor descriptor carries its layout in the type itself rather than in
+  // an attribute, so rebuild the type with the recovered layout. Vector
+  // layouts are attached to the operands below instead.
+  if (auto tensorDescTy = dyn_cast<xegpu::TensorDescType>(resultType)) {
+    auto typeWithLayout = xegpu::TensorDescType::get(
+        tensorDescTy.getContext(), tensorDescTy.getShape(),
+        tensorDescTy.getElementType(), tensorDescTy.getEncoding(), resLayout);
+    result.setType(typeWithLayout);
+  }
+
+  for (OpOperand &opr : op->getOpOperands()) {
+    // Layouts are needed for vector type only; check this before running the
+    // inference so skipped operands do no inference work.
+    if (!isa<VectorType>(opr.get().getType())) {
+      LLVM_DEBUG(DBGS() << " operand #" << opr.getOperandNumber()
+                        << ": skipped (non-vector type: " << opr.get().getType()
+                        << ")\n");
+      continue;
+    }
+
+    xegpu::DistributeLayoutAttr operandLayout =
+        xegpu::inferSourceLayoutFromResult(opr, resLayout);
+    xegpu::setTemporaryLayout(opr, operandLayout);
+    // All tracing stays inside LLVM_DEBUG: a bare llvm::dbgs() still writes
+    // to stderr when -debug is off (and in release builds), which would
+    // print a line for every vector operand of every visited op.
+    LLVM_DEBUG(DBGS() << "after propagateResultsToRegularOperands op: "
+                      << op->getName() << op << " operand #"
+                      << opr.getOperandNumber()
+                      << ": type=" << opr.get().getType()
+                      << ", temp Layout=" << xegpu::getTemporaryLayout(opr)
+                      << "\n");
+  }
+}
+
+/// Handles a region terminator (e.g. a yield) of a RegionBranchOpInterface
+/// op: for each successor that returns control to the parent op, the layout
+/// of every parent result is taken from the result's use points and recorded
+/// on that result (when it is an OpResult) and on yield operand #i.
+/// Terminators whose parent is a func::FuncOp are skipped.
+/// NOTE(review): yield operand #i is assumed to be the one forwarded to
+/// result #i; that does not hold for terminators whose forwarded operands do
+/// not start at operand 0 (e.g. scf.condition, whose operand 0 is the
+/// condition) -- confirm, or map through the terminator's successor operands.
+/// NOTE(review): cast<RegionBranchOpInterface> asserts when the parent does
+/// not implement that interface (e.g. possibly a gpu.func) -- confirm.
+static void propagateRegionResultsToYieldOperands(
+ mlir::RegionBranchTerminatorOpInterface yieldOp) {
+ LLVM_DEBUG(DBGS() << "propagateRegionResultsToYieldOperands: "
+ << yieldOp->getName() << " (" << yieldOp->getNumOperands()
+ << " operands), parent="
+ << yieldOp->getParentOp()->getName() << "\n");
+
+ if (func::FuncOp func = dyn_cast<func::FuncOp>(yieldOp->getParentOp())) {
+ LLVM_DEBUG(DBGS() << " skipping (parent is FuncOp)\n");
+ return;
+ }
+ llvm::SmallVector<mlir::RegionSuccessor> successors;
+ llvm::SmallVector<mlir::Attribute> operands(yieldOp->getNumOperands(),
+ nullptr);
+ yieldOp.getSuccessorRegions(operands, successors);
+
+ auto regionBranchOp = cast<RegionBranchOpInterface>(yieldOp->getParentOp());
+
+ LLVM_DEBUG(DBGS() << " found " << successors.size() << " successors\n");
+ for (mlir::RegionSuccessor &successor : successors) {
+ // debug print out successorr
+ LLVM_DEBUG({
+ DBGS() << " successor: ";
+ if (successor.isParent()) {
+ DBGS() << "(parent operation)";
+ } else {
+ DBGS() << "region with " << successor.getSuccessor()->getNumArguments()
+ << " arguments";
+ }
+ DBGS() << "\n";
+ });
+ // find out the successor which is the parent region of yieldOp
+ // if (successor.getSuccessor() != yieldOp->getParentRegion()) {
+ // LLVM_DEBUG(DBGS() << " skipping successor (not parent region)\n");
+ // continue;
+ // }
+ if (!successor.isParent())
+ continue;
+ // propagate the layout from region result to yield operands
+ ValueRange successorInputs = regionBranchOp.getSuccessorInputs(successor);
+ LLVM_DEBUG(DBGS() << " propagating " << successorInputs.size()
+ << " region results to yield operands\n");
+ for (unsigned i = 0; i < successorInputs.size(); ++i) {
+ Value regionResult = successorInputs[i];
+
+ // debug print regionResult
+ LLVM_DEBUG({
+ DBGS() << " before propagateRegionResultsToYieldOperands, Region IR:";
+ DBGS() << " region result #" << i
+ << ": type=" << regionResult.getType();
+ llvm::dbgs() << regionResult;
+ llvm::dbgs() << "\n";
+ });
+ // find all the use of region result, and propagate the layout to the
+ // corresponding yield operand for all use of region result, get its
+ // layout from temporary operand layout if any of these use have it
+ xegpu::DistributeLayoutAttr layout = getLayoutFromUsePoints(regionResult);
+
+ // auto layout = xegpu::getDistributeLayoutAttr(regionResult);
+ if (layout == nullptr) {
+ LLVM_DEBUG(DBGS() << " region result #" << i
+ << ": skipped (no layout)\n");
continue;
- auto layout = xegpu::getDistributeLayoutAttr(operand.get());
- if (!layout) {
- op->emitWarning("Could not find layout attribute for operand ")
- << operand.getOperandNumber() << " of operation " << op->getName();
+ }
+ assert(
+ layout &&
+ "region result layout must be defined before propagating to yield");
+
+ if (auto opResult = dyn_cast<OpResult>(regionResult))
+ xegpu::setTemporaryLayout(opResult, layout);
+ xegpu::setTemporaryLayout(yieldOp->getOpOperand(i), layout);
+
+ LLVM_DEBUG({
+ DBGS() << " after propagateRegionResultsToYieldOperands, Region IR:";
+ regionResult.print(llvm::dbgs());
+ if (Operation *defOp = regionResult.getDefiningOp())
+ defOp->print(llvm::dbgs(), OpPrintingFlags().printGenericOpForm());
+ llvm::dbgs() << "\n";
+ });
+ }
+ }
+}
+
+/// Propagates layouts from the arguments of the entry regions of `regionOp`
+/// (e.g. loop iter_args) to the init operands forwarded to them. Runs when
+/// the backward walk reaches `regionOp`, i.e. after its regions have been
+/// processed, so the layouts required at the region arguments' uses are
+/// already recorded.
+static void propagateRegionArgsToInits(mlir::RegionBranchOpInterface regionOp) {
+  // All tracing, including the IR dumps, stays inside LLVM_DEBUG: bare
+  // llvm::dbgs() output is emitted even when -debug is off.
+  LLVM_DEBUG({
+    DBGS() << "propagateRegionArgsToInits: " << regionOp->getName() << " ("
+           << regionOp->getNumOperands() << " operands, "
+           << regionOp->getNumRegions() << " regions)\n";
+    DBGS() << " before propagateRegionArgsToInits, Region IR:";
+    regionOp.print(llvm::dbgs());
+    DBGS() << " complex debug Region IR:";
+    regionOp.print(llvm::dbgs(), OpPrintingFlags().printGenericOpForm());
+  });
+  // Get entry successors (regions that can be entered initially). The
+  // interface expects one (possibly null) constant attribute per operand;
+  // ops such as scf.if index into this array by operand position.
+  SmallVector<Attribute> operandAttrs(regionOp->getNumOperands(), nullptr);
+  SmallVector<RegionSuccessor> successors;
+  regionOp.getEntrySuccessorRegions(operandAttrs, successors);
+
+  LLVM_DEBUG(DBGS() << " found " << successors.size()
+                    << " entry successors\n");
+  for (RegionSuccessor &successor : successors) {
+    // Control may go straight back to the parent (e.g. scf.if without an else
+    // region); there are no region arguments to read layouts from.
+    if (successor.isParent())
+      continue;
+    // Pair the forwarded init operands with the successor inputs, not with
+    // the raw block arguments: a region may have leading arguments that no
+    // init feeds (e.g. the induction variable of scf.for).
+    OperandRange initOperands = regionOp.getEntrySuccessorOperands(successor);
+    ValueRange regionArgs = regionOp.getSuccessorInputs(successor);
+    unsigned beginIdx = initOperands.getBeginOperandIndex();
+    assert(initOperands.size() == regionArgs.size() &&
+           "entry successor operands and inputs must correspond 1:1");
+    LLVM_DEBUG(DBGS() << " successor region: " << regionArgs.size()
+                      << " inputs, initOperands beginIdx=" << beginIdx
+                      << ", count=" << initOperands.size() << "\n");
+    for (auto [i, regionArg] : llvm::enumerate(regionArgs)) {
+      // Like any other definition, a region argument takes the layout
+      // required at its use points (as region results do).
+      xegpu::DistributeLayoutAttr layout = getLayoutFromUsePoints(regionArg);
+      if (!layout) {
+        LLVM_DEBUG(DBGS() << " region argument #" << i
+                          << ": skipped (no layout)\n");
continue;
}
- xegpu::setTemporaryLayout(operand, layout);
+      LLVM_DEBUG(DBGS() << " regionArg #" << i << ": type="
+                        << regionArg.getType() << ", layout=" << layout
+                        << " -> init operand #" << (beginIdx + i) << "\n");
+      xegpu::setTemporaryLayout(regionOp->getOpOperand(beginIdx + i), layout);
}
- return WalkResult::advance();
+  }
+  LLVM_DEBUG({
+    DBGS() << " after propagateRegionArgsToInits, Region IR:";
+    regionOp.print(llvm::dbgs());
+  });
+}
+
+// Walks the body of every func.func and gpu.func under `rootOp` backward and
+// dispatches each visited op to the matching propagation helper above.
+bool xegpu::recoverTemporaryLayouts(Operation *rootOp) {
+  LLVM_DEBUG(DBGS() << "=== recoverTemporaryLayouts START ===\n");
+
+  auto processFunc = [&](Region &body, StringRef funcName) {
+    LLVM_DEBUG(DBGS() << "Processing func: " << funcName << "\n");
+    walkRegionBackward(body, [&](Operation *op) {
+      // The result layouts are looked up for tracing only, so the lookups
+      // live inside LLVM_DEBUG and cost nothing when debug output is off.
+      LLVM_DEBUG({
+        DBGS() << "Visiting op: " << op->getName();
+        if (op->getNumResults() > 0) {
+          llvm::dbgs() << " [results: " << op->getNumResults();
+          for (OpResult res : op->getResults())
+            llvm::dbgs() << " r#" << res.getResultNumber() << "="
+                         << xegpu::getDistributeLayoutAttr(res);
+          llvm::dbgs() << "]";
+        }
+        llvm::dbgs() << "\n";
+      });
+      if (auto regionOp = dyn_cast<mlir::RegionBranchOpInterface>(op)) {
+        // Reached the region op after its regions have been visited.
+        LLVM_DEBUG(DBGS() << " -> dispatching as RegionBranchOp\n");
+        propagateRegionArgsToInits(regionOp);
+      } else if (auto yieldOp =
+                     dyn_cast<mlir::RegionBranchTerminatorOpInterface>(op)) {
+        // Terminator (e.g. yield) of a region op.
+        LLVM_DEBUG(DBGS() << " -> dispatching as YieldOp\n");
+        propagateRegionResultsToYieldOperands(yieldOp);
+      } else if (!isa<xegpu::AnchorLayoutInterface>(op)) {
+        // Anchor ops are skipped; every other op is treated as regular.
+        LLVM_DEBUG(DBGS() << " -> dispatching as regular op\n");
+        propagateResultsToRegularOperands(op);
+      }
+    });
+  };
+
+  rootOp->walk([&](func::FuncOp func) {
+    processFunc(func.getBody(), func.getSymName());
});
- return !result.wasInterrupted();
+  rootOp->walk([&](gpu::GPUFuncOp func) {
+    processFunc(func.getBody(), func.getName());
+  });
+
+  LLVM_DEBUG(DBGS() << "=== recoverTemporaryLayouts END ===\n");
+  return true;
}
+// // Attach layout attributes to all vector-type operands of operations within
+// // the given operation's region. Reports an error if any vector operand lacks
+// // a layout attribute.
+// bool xegpu::recoverTemporaryLayouts(Operation *rootOp) {
+// auto result = rootOp->walk([&](Operation *op) {
+// for (OpOperand &operand : op->getOpOperands()) {
+// // Layouts are needed for vector type only.
+// if (!isa<VectorType>(operand.get().getType()))
+// continue;
+// // Skip block arguments since they don't have defining ops to attach
+// // layout attributes to.
+// if (isa<BlockArgument>(operand.get()))
+// continue;
+// auto layout = xegpu::getDistributeLayoutAttr(operand.get());
+// if (!layout) {
+// op->emitWarning("Could not find layout attribute for operand ")
+// << operand.getOperandNumber() << " of operation " <<
+// op->getName();
+// xegpu::setTemporaryLayout(operand, layout);
+// continue;
+// }
+// }
+// return WalkResult::advance();
+// });
+// return !result.wasInterrupted();
+// }
+
template <typename T, typename>
void xegpu::removeLayoutAttr(const T &operandOrResult) {
Operation *owner = operandOrResult.getOwner();
@@ -1108,99 +1412,178 @@ xegpu::setupDpasLayout(xegpu::LayoutKind layoutKind, VectorType aTy,
return std::nullopt;
}
-xegpu::DistributeLayoutAttr xegpu::getConsumerLayoutAt(OpOperand &operand) {
+// Implementation note: each supported op kind maps the result layout back to
+// the operand through its dedicated xegpu::infer*SourceLayout helper; the
+// second operand of multi_reduction (acc) and of insert_strided_slice (dest)
+// reuses the result layout unchanged. Any other case yields a null layout.
+xegpu::DistributeLayoutAttr
+xegpu::inferSourceLayoutFromResult(OpOperand &operand,
+ xegpu::DistributeLayoutAttr resLayout) {
Operation *op = operand.getOwner();
unsigned idx = operand.getOperandNumber();
- xegpu::DistributeLayoutAttr resLayout;
- if (op->getNumResults() == 1)
- resLayout = xegpu::getDistributeLayoutAttr(op->getResult(0));
// For vector::BroadcastOp, infer the source layout from the result layout.
if (auto broadcast = dyn_cast<vector::BroadcastOp>(op)) {
- if (!resLayout)
+ LLVM_DEBUG(DBGS() << " -> BroadcastOp\n");
+ if (!resLayout) {
+ LLVM_DEBUG(DBGS() << " no resLayout, returning null\n");
return xegpu::DistributeLayoutAttr();
+ }
auto srcTy = dyn_cast<VectorType>(broadcast.getSourceType());
- if (!srcTy)
+ if (!srcTy) {
+ LLVM_DEBUG(DBGS() << " source is not VectorType, returning null\n");
return xegpu::DistributeLayoutAttr();
- return xegpu::inferBroadcastSourceLayout(
+ }
+ auto inferred = xegpu::inferBroadcastSourceLayout(
resLayout, broadcast.getResultVectorType().getShape(),
srcTy.getShape());
+ LLVM_DEBUG(DBGS() << " inferred=" << inferred << "\n");
+ return inferred;
}
// For vector::MultiDimReductionOp, infer source layout from result layout
// using reduction dims. Acc operand is expected to have the same layout as
// the result.
if (auto reduction = dyn_cast<vector::MultiDimReductionOp>(op)) {
- if (!resLayout)
+ LLVM_DEBUG(DBGS() << " -> MultiDimReductionOp, operand idx=" << idx
+ << "\n");
+ if (!resLayout) {
+ LLVM_DEBUG(DBGS() << " no resLayout, returning null\n");
return xegpu::DistributeLayoutAttr();
+ }
if (idx == 0) {
SmallVector<int64_t> reductionDims(reduction.getReductionDims());
- return xegpu::inferMultiReductionSourceLayout(resLayout, reductionDims);
+ LLVM_DEBUG({
+ DBGS() << " reductionDims=[";
+ llvm::interleaveComma(reductionDims, llvm::dbgs());
+ llvm::dbgs() << "]\n";
+ });
+ auto inferred =
+ xegpu::inferMultiReductionSourceLayout(resLayout, reductionDims);
+ LLVM_DEBUG(DBGS() << " inferred source layout=" << inferred << "\n");
+ return inferred;
}
- if (idx == 1)
+ if (idx == 1) {
+ LLVM_DEBUG(DBGS() << " acc operand, using resLayout\n");
return resLayout;
+ }
}
if (auto reduction = dyn_cast<vector::ReductionOp>(op)) {
- if (!resLayout)
+ LLVM_DEBUG(DBGS() << " -> ReductionOp\n");
+ if (!resLayout) {
+ LLVM_DEBUG(DBGS() << " no resLayout, returning null\n");
return xegpu::DistributeLayoutAttr();
- return xegpu::inferReductionSourceLayout(resLayout);
+ }
+ auto inferred = xegpu::inferReductionSourceLayout(resLayout);
+ LLVM_DEBUG(DBGS() << " inferred=" << inferred << "\n");
+ return inferred;
}
// For vector::BitCastOp, infer source layout from result layout using
// element type bitwidths.
if (auto bitcast = dyn_cast<vector::BitCastOp>(op)) {
- if (!resLayout)
+ LLVM_DEBUG(DBGS() << " -> BitCastOp\n");
+ if (!resLayout) {
+ LLVM_DEBUG(DBGS() << " no resLayout, returning null\n");
return xegpu::DistributeLayoutAttr();
+ }
int resElemBitWidth =
bitcast.getResultVectorType().getElementType().getIntOrFloatBitWidth();
int srcElemBitWidth =
bitcast.getSourceVectorType().getElementType().getIntOrFloatBitWidth();
- return xegpu::inferBitCastSourceLayout(resLayout, resElemBitWidth,
- srcElemBitWidth);
+ LLVM_DEBUG(DBGS() << " resBitWidth=" << resElemBitWidth
+ << ", srcBitWidth=" << srcElemBitWidth << "\n");
+ auto inferred = xegpu::inferBitCastSourceLayout(resLayout, resElemBitWidth,
+ srcElemBitWidth);
+ LLVM_DEBUG(DBGS() << " inferred=" << inferred << "\n");
+ return inferred;
}
// For vector::ShapeCastOp, infer source layout from result layout using
// shapes.
if (auto shapeCast = dyn_cast<vector::ShapeCastOp>(op)) {
- if (!resLayout)
+ LLVM_DEBUG({
+ DBGS() << " -> ShapeCastOp: resShape=[";
+ llvm::interleaveComma(shapeCast.getResultVectorType().getShape(),
+ llvm::dbgs());
+ llvm::dbgs() << "], srcShape=[";
+ llvm::interleaveComma(shapeCast.getSourceVectorType().getShape(),
+ llvm::dbgs());
+ llvm::dbgs() << "]\n";
+ });
+ if (!resLayout) {
+ LLVM_DEBUG(DBGS() << " no resLayout, returning null\n");
return xegpu::DistributeLayoutAttr();
- return xegpu::inferShapeCastSourceLayout(
+ }
+ auto inferred = xegpu::inferShapeCastSourceLayout(
resLayout, shapeCast.getResultVectorType().getShape(),
shapeCast.getSourceVectorType().getShape());
+ LLVM_DEBUG(DBGS() << " inferred=" << inferred << "\n");
+ return inferred;
}
// For vector::InsertStridedSliceOp, infer source layout from result layout.
// Dest vector must have the same layout as the result.
if (auto insertSlice = dyn_cast<vector::InsertStridedSliceOp>(op)) {
- if (!resLayout)
+ LLVM_DEBUG(DBGS() << " -> InsertStridedSliceOp, operand idx=" << idx
+ << "\n");
+ if (!resLayout) {
+ LLVM_DEBUG(DBGS() << " no resLayout, returning null\n");
return xegpu::DistributeLayoutAttr();
- if (idx == 0)
- return xegpu::inferInsertStridedSliceSourceLayout(
+ }
+ if (idx == 0) {
+ auto inferred = xegpu::inferInsertStridedSliceSourceLayout(
resLayout, insertSlice.getDestVectorType().getShape(),
insertSlice.getSourceVectorType().getShape());
- if (idx == 1)
+ LLVM_DEBUG(DBGS() << " inferred source layout=" << inferred << "\n");
+ return inferred;
+ }
+ if (idx == 1) {
+ LLVM_DEBUG(DBGS() << " dest operand, using resLayout\n");
return resLayout;
+ }
}
// For vector::TransposeOp, infer source layout from result layout using
// permutation.
if (auto transpose = dyn_cast<vector::TransposeOp>(op)) {
- if (!resLayout)
+ LLVM_DEBUG({
+ DBGS() << " -> TransposeOp, perm=[";
+ llvm::interleaveComma(transpose.getPermutation(), llvm::dbgs());
+ llvm::dbgs() << "]\n";
+ });
+ if (!resLayout) {
+ LLVM_DEBUG(DBGS() << " no resLayout, returning null\n");
return xegpu::DistributeLayoutAttr();
- return xegpu::inferTransposeSourceLayout(resLayout,
- transpose.getPermutation());
+ }
+ auto inferred = xegpu::inferTransposeSourceLayout(
+ resLayout, transpose.getPermutation());
+ LLVM_DEBUG(DBGS() << " inferred=" << inferred << "\n");
+ return inferred;
}
// For elementwise operations, all operands must have the same layout as the
// result.
if (OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() == 1) {
+ LLVM_DEBUG(DBGS() << " -> elementwise op, using resLayout="
+ << (resLayout ? resLayout : nullptr) << "\n");
if (!resLayout)
return xegpu::DistributeLayoutAttr();
return resLayout;
}
- // TODO: Handle more cases as needed here.
+ return xegpu::DistributeLayoutAttr();
+}
+
+xegpu::DistributeLayoutAttr xegpu::getConsumerLayoutAt(OpOperand &operand) {
+  Operation *op = operand.getOwner();
+  // Inference is driven by the result layout, which is looked up only for
+  // single-result ops (otherwise a null layout is passed on).
+  xegpu::DistributeLayoutAttr resLayout;
+  if (op->getNumResults() == 1)
+    resLayout = xegpu::getDistributeLayoutAttr(op->getResult(0));
+  if (xegpu::DistributeLayoutAttr inferred =
+          xegpu::inferSourceLayoutFromResult(operand, resLayout))
+    return inferred;
// By default, assume no layout conflict and return the current layout of
// the operand.
- return xegpu::getDistributeLayoutAttr(operand.get());
+  xegpu::DistributeLayoutAttr fallback =
+      xegpu::getDistributeLayoutAttr(operand);
+  // Reached for ops without an inference rule and also for handled ops whose
+  // inference yields nothing (e.g. the result has no layout yet).
+  LLVM_DEBUG(DBGS() << " -> fallback (no inferred layout for "
+                    << op->getName() << "), returning operand layout="
+                    << fallback << "\n");
+  return fallback;
}
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index 243581b4ce522..a762458105e47 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -23,10 +23,14 @@
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/Casting.h"
+#include "llvm/Support/Debug.h"
#include "llvm/Support/FormatVariadic.h"
#include <cstdint>
#include <numeric>
+#define DEBUG_TYPE "xegpu-utils"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+
using namespace mlir;
/// convert ArrayRef<ValueRange> into SmallVector<Value>
@@ -145,19 +149,31 @@ std::string xegpu::getTemporaryLayoutName(const OpResult result) {
}
xegpu::DistributeLayoutAttr xegpu::getDistributeLayoutAttr(const Value value) {
- if (!value)
+  if (!value) {
+    LLVM_DEBUG(DBGS() << " -> null value, returning nullptr\n");
return nullptr;
+  }
+  // Trace the type only once the value is known to be non-null:
+  // Value::getType() on a null Value dereferences a null pointer.
+  LLVM_DEBUG(DBGS() << "getDistributeLayoutAttr(Value): type="
+                    << value.getType() << "\n");
if (auto tdescTy =
- dyn_cast_if_present<xegpu::TensorDescType>(value.getType()))
- return tdescTy.getLayoutAttr();
+ dyn_cast_if_present<xegpu::TensorDescType>(value.getType())) {
+ auto layout = tdescTy.getLayoutAttr();
+ LLVM_DEBUG(DBGS() << " -> TensorDescType, layout="
+ << (layout ? layout : nullptr) << "\n");
+ return layout;
+ }
if (auto result = dyn_cast<OpResult>(value)) {
Operation *defOp = result.getDefiningOp();
assert(defOp && "result must have a defining op");
+ LLVM_DEBUG(DBGS() << " OpResult #" << result.getResultNumber() << " from "
+ << defOp->getName() << "\n");
if (auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(defOp)) {
auto layout = anchorOp.getAnchorLayout();
+ LLVM_DEBUG(DBGS() << " -> AnchorLayoutInterface, layout="
+ << (layout ? layout : nullptr) << "\n");
return layout;
}
@@ -165,59 +181,100 @@ xegpu::DistributeLayoutAttr xegpu::getDistributeLayoutAttr(const Value value) {
if (defOp->hasAttr(layoutName)) {
auto layout =
defOp->getAttrOfType<xegpu::DistributeLayoutAttr>(layoutName);
+ LLVM_DEBUG(DBGS() << " -> temporary attr '" << layoutName
+ << "', layout=" << layout << "\n");
return layout;
}
+ LLVM_DEBUG(DBGS() << " -> OpResult: no layout found (checked '"
+ << layoutName << "')\n");
}
if (auto arg = dyn_cast<BlockArgument>(value)) {
auto *parentOp = arg.getOwner()->getParentOp();
+ LLVM_DEBUG(DBGS() << " BlockArgument #" << arg.getArgNumber() << " of "
+ << (parentOp ? parentOp->getName().getStringRef()
+ : StringRef("(null)"))
+ << "\n");
if (auto loop = dyn_cast_if_present<LoopLikeOpInterface>(parentOp)) {
OpOperand *tiedInit = loop.getTiedLoopInit(arg);
- if (tiedInit)
+ if (tiedInit) {
+ LLVM_DEBUG(DBGS() << " -> LoopLikeOp, recursing into tiedInit "
+ << "operand #" << tiedInit->getOperandNumber()
+ << "\n");
return getDistributeLayoutAttr(tiedInit->get());
+ }
+ LLVM_DEBUG(DBGS() << " -> LoopLikeOp, no tiedInit\n");
}
}
+ LLVM_DEBUG(DBGS() << " -> returning nullptr\n");
return nullptr;
}
+// Implementation note: anchor ops answer from their own layout attributes
+// (dpas layout A/B/Cd for operands 0/1/2, convert_layout's input layout,
+// otherwise the anchor layout for operand 0 and, for store ops, operand 1);
+// every other case falls back to the temporary layout attribute named by
+// xegpu::getTemporaryLayoutName(opr), or null when it is absent.
xegpu::DistributeLayoutAttr
xegpu::getDistributeLayoutAttr(const OpOperand &opr) {
Operation *op = opr.getOwner();
unsigned idx = const_cast<OpOperand &>(opr).getOperandNumber();
+ LLVM_DEBUG(DBGS() << "getDistributeLayoutAttr(OpOperand): operand #" << idx
+ << " of " << op->getName()
+ << ", type=" << opr.get().getType() << "\n");
if (auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(op)) {
if (auto dpasOp = dyn_cast<xegpu::DpasOp>(op)) {
if (idx == 0) {
- return dpasOp.getLayoutAAttr();
+ auto layout = dpasOp.getLayoutAAttr();
+ LLVM_DEBUG(DBGS() << " -> DpasOp layoutA="
+ << (layout ? layout : nullptr) << "\n");
+ return layout;
} else if (idx == 1) {
- return dpasOp.getLayoutBAttr();
+ auto layout = dpasOp.getLayoutBAttr();
+ LLVM_DEBUG(DBGS() << " -> DpasOp layoutB="
+ << (layout ? layout : nullptr) << "\n");
+ return layout;
} else if (idx == 2) {
- return dpasOp.getLayoutCdAttr();
+ auto layout = dpasOp.getLayoutCdAttr();
+ LLVM_DEBUG(DBGS() << " -> DpasOp layoutCd="
+ << (layout ? layout : nullptr) << "\n");
+ return layout;
}
}
if (auto convertOp = dyn_cast<xegpu::ConvertLayoutOp>(op)) {
- return convertOp.getInputLayoutAttr();
+ auto layout = convertOp.getInputLayoutAttr();
+ LLVM_DEBUG(DBGS() << " -> ConvertLayoutOp inputLayout="
+ << (layout ? layout : nullptr) << "\n");
+ return layout;
}
auto layout = anchorOp.getAnchorLayout();
- if (idx == 0)
+ if (idx == 0) {
+ LLVM_DEBUG(DBGS() << " -> AnchorLayoutInterface idx=0, layout="
+ << (layout ? layout : nullptr) << "\n");
return layout;
+ }
// For store operations (StoreScatterOp, StoreNdOp, StoreMatrixOp),
// the layout is valid for the first two operands: value and memref/tdesc.
// For other operations, the layout applies to the first operand only.
if (isa<xegpu::StoreScatterOp, xegpu::StoreNdOp, xegpu::StoreMatrixOp>(
op) &&
- (idx < 2))
+ (idx < 2)) {
+ LLVM_DEBUG(DBGS() << " -> Store op idx=" << idx
+ << ", layout=" << (layout ? layout : nullptr) << "\n");
return layout;
+ }
+ LLVM_DEBUG(DBGS() << " -> AnchorLayoutInterface idx=" << idx
+ << " not covered, falling through\n");
}
std::string layoutName = xegpu::getTemporaryLayoutName(opr);
if (op->hasAttr(layoutName)) {
auto layout = op->getAttrOfType<xegpu::DistributeLayoutAttr>(layoutName);
+ LLVM_DEBUG(DBGS() << " -> temporary attr '" << layoutName
+ << "', layout=" << layout << "\n");
return layout;
}
+ LLVM_DEBUG(DBGS() << " -> returning nullptr (checked '" << layoutName
+ << "')\n");
return nullptr;
}
>From f77e110d9dc81257b2deeb9cccb20e10bea3739b Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Fri, 3 Apr 2026 05:31:42 +0000
Subject: [PATCH 02/19] pass while
---
.../XeGPU/Transforms/XeGPULayoutImpl.cpp | 253 +++++++-----------
.../XeGPU/Transforms/XeGPUPropagateLayout.cpp | 2 +-
2 files changed, 103 insertions(+), 152 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
index 06cd0eaa0059e..47148870eeaae 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
@@ -147,13 +147,8 @@ static xegpu::DistributeLayoutAttr getLayoutFromUsePoints(Value result) {
// under debug mode, we want to check all the use points to make sure
// there is no conflict, so we do not break here. In release mode, we can
// break at the first use
-#ifndef NDEBUG
- assert(!layout || layout == tmpLayout);
- layout = tmpLayout;
-#else
- layout = tmpLayout;
- break;
-#endif
+ if (!layout)
+ layout = tmpLayout;
}
}
return layout;
@@ -215,127 +210,118 @@ static void propagateRegionResultsToYieldOperands(
<< " operands), parent="
<< yieldOp->getParentOp()->getName() << "\n");
- if (func::FuncOp func = dyn_cast<func::FuncOp>(yieldOp->getParentOp())) {
+ if (isa<func::FuncOp>(yieldOp->getParentOp())) {
LLVM_DEBUG(DBGS() << " skipping (parent is FuncOp)\n");
return;
}
- llvm::SmallVector<mlir::RegionSuccessor> successors;
- llvm::SmallVector<mlir::Attribute> operands(yieldOp->getNumOperands(),
- nullptr);
- yieldOp.getSuccessorRegions(operands, successors);
- auto regionBranchOp = cast<RegionBranchOpInterface>(yieldOp->getParentOp());
+ auto regionBranchOp =
+ dyn_cast<RegionBranchOpInterface>(yieldOp->getParentOp());
+ if (!regionBranchOp) {
+ LLVM_DEBUG(DBGS() << " skipping (parent is not RegionBranchOp)\n");
+ return;
+ }
- LLVM_DEBUG(DBGS() << " found " << successors.size() << " successors\n");
- for (mlir::RegionSuccessor &successor : successors) {
- // debug print out successorr
- LLVM_DEBUG({
- DBGS() << " successor: ";
- if (successor.isParent()) {
- DBGS() << "(parent operation)";
- } else {
- DBGS() << "region with " << successor.getSuccessor()->getNumArguments()
- << " arguments";
- }
- DBGS() << "\n";
- });
- // find out the successor which is the parent region of yieldOp
- // if (successor.getSuccessor() != yieldOp->getParentRegion()) {
- // LLVM_DEBUG(DBGS() << " skipping successor (not parent region)\n");
- // continue;
- // }
- if (!successor.isParent())
- continue;
- // propagate the layout from region result to yield operands
- ValueRange successorInputs = regionBranchOp.getSuccessorInputs(successor);
- LLVM_DEBUG(DBGS() << " propagating " << successorInputs.size()
- << " region results to yield operands\n");
- for (unsigned i = 0; i < successorInputs.size(); ++i) {
- Value regionResult = successorInputs[i];
-
- // debug print regionResult
- LLVM_DEBUG({
- DBGS() << " before propagateRegionResultsToYieldOperands, Region IR:";
- DBGS() << " region result #" << i
- << ": type=" << regionResult.getType();
- llvm::dbgs() << regionResult;
- llvm::dbgs() << "\n";
- });
- // find all the use of region result, and propagate the layout to the
- // corresponding yield operand for all use of region result, get its
- // layout from temporary operand layout if any of these use have it
- xegpu::DistributeLayoutAttr layout = getLayoutFromUsePoints(regionResult);
-
- // auto layout = xegpu::getDistributeLayoutAttr(regionResult);
- if (layout == nullptr) {
- LLVM_DEBUG(DBGS() << " region result #" << i
- << ": skipped (no layout)\n");
- continue;
- }
- assert(
- layout &&
- "region result layout must be defined before propagating to yield");
+ // Gather layouts for each result of the parent region op from external
+ // use points.
+ unsigned numResults = regionBranchOp->getNumResults();
+ LLVM_DEBUG(DBGS() << " parent op has " << numResults << " results\n");
+
+ SmallVector<xegpu::DistributeLayoutAttr> resultLayouts(numResults, nullptr);
+ for (unsigned i = 0; i < numResults; ++i) {
+ OpResult result = regionBranchOp->getResult(i);
+ resultLayouts[i] = getLayoutFromUsePoints(result);
+ if (resultLayouts[i]) {
+ LLVM_DEBUG(DBGS() << " result #" << i << ": type=" << result.getType()
+ << ", layout=" << resultLayouts[i] << "\n");
+ xegpu::setTemporaryLayout(result, resultLayouts[i]);
+ } else {
+ LLVM_DEBUG(DBGS() << " result #" << i
+ << ": skipped (no layout from use points)\n");
+ }
+ }
- if (auto opResult = dyn_cast<OpResult>(regionResult))
- xegpu::setTemporaryLayout(opResult, layout);
- xegpu::setTemporaryLayout(yieldOp->getOpOperand(i), layout);
+ // Use getSuccessorOperands to find which operands of the terminator
+ // flow to a successor. This handles index offsets automatically (e.g.,
+ // scf.condition's predicate at operand #0 is excluded).
+ // Pick the first successor to determine the operand range.
+ SmallVector<RegionSuccessor> successors;
+ SmallVector<Attribute> operandAttrs(yieldOp->getNumOperands(), nullptr);
+ yieldOp.getSuccessorRegions(operandAttrs, successors);
+ assert(!successors.empty() && "terminator must have at least one successor");
- LLVM_DEBUG({
- DBGS() << " after propagateRegionResultsToYieldOperands, Region IR:";
- regionResult.print(llvm::dbgs());
- if (Operation *defOp = regionResult.getDefiningOp())
- defOp->print(llvm::dbgs(), OpPrintingFlags().printGenericOpForm());
- llvm::dbgs() << "\n";
- });
- }
+ OperandRange succOps = yieldOp.getSuccessorOperands(successors.front());
+ unsigned beginIdx = succOps.getBeginOperandIndex();
+ unsigned count = std::min(static_cast<unsigned>(succOps.size()), numResults);
+
+ LLVM_DEBUG(DBGS() << " " << count << " successor operands starting at index "
+ << beginIdx << "\n");
+
+ for (unsigned i = 0; i < count; ++i) {
+ if (!resultLayouts[i])
+ continue;
+ LLVM_DEBUG(DBGS() << " -> setting layout on operand #" << (beginIdx + i)
+ << "\n");
+ xegpu::setTemporaryLayout(yieldOp->getOpOperand(beginIdx + i),
+ resultLayouts[i]);
}
+
+ LLVM_DEBUG({
+ DBGS() << " after propagateRegionResultsToYieldOperands:\n";
+ yieldOp->print(llvm::dbgs(), OpPrintingFlags().printGenericOpForm());
+ llvm::dbgs() << "\n";
+ });
}
static void propagateRegionArgsToInits(mlir::RegionBranchOpInterface regionOp) {
LLVM_DEBUG(DBGS() << "propagateRegionArgsToInits: " << regionOp->getName()
<< " (" << regionOp->getNumOperands() << " operands, "
<< regionOp->getNumRegions() << " regions)\n");
- DBGS() << " before propagateRegionArgsToInits, Region IR:";
- regionOp.print(llvm::dbgs());
- DBGS() << " complex debug Region IR:";
- regionOp.print(llvm::dbgs(), OpPrintingFlags().printGenericOpForm());
- // Get entry successors (regions that can be entered initially)
- SmallVector<RegionSuccessor> successors;
- regionOp.getEntrySuccessorRegions(/*operands=*/ArrayRef<Attribute>(),
- successors);
-
- LLVM_DEBUG(DBGS() << " found " << successors.size()
- << " entry successors\n");
- // For each possible entry region, get the operands forwarded to it
- for (RegionSuccessor &successor : successors) {
- OperandRange initOperands = regionOp.getEntrySuccessorOperands(successor);
- unsigned beginIdx = initOperands.getBeginOperandIndex();
- unsigned numArgs = successor.getSuccessor()->getNumArguments();
- LLVM_DEBUG(DBGS() << " successor region: " << numArgs
- << " args, initOperands beginIdx=" << beginIdx
- << ", count=" << initOperands.size() << "\n");
- // initOperands are the initialization arguments for this successor
- // iterate the region arguments
- for (unsigned i = 0; i < numArgs; ++i) {
- Value regionArg =
- successor.getSuccessor()->getArgument(i); // region argument
- auto layout = xegpu::getDistributeLayoutAttr(regionArg);
- if (layout == nullptr) {
- LLVM_DEBUG(DBGS() << " region argument #" << i
- << ": skipped (no layout)\n");
+ LLVM_DEBUG({
+ DBGS() << " before propagateRegionArgsToInits, Region IR:\n";
+ regionOp.print(llvm::dbgs(), OpPrintingFlags().printGenericOpForm());
+ llvm::dbgs() << "\n";
+ });
+
+ // Iterate all regions of the region op. For each block argument that has a
+ // layout (determined from its use points), trace back to find the
+ // corresponding init operand of the regionOp and set the layout on it.
+ // This works generically for scf.for, scf.while, and other
+ // RegionBranchOpInterface ops.
+ for (Region &region : regionOp->getRegions()) {
+ RegionSuccessor regionSuccessor(&region);
+ for (auto [argIdx, regionArg] : llvm::enumerate(region.getArguments())) {
+ auto layout = getLayoutFromUsePoints(regionArg);
+ if (!layout) {
+ LLVM_DEBUG(DBGS() << " region #" << region.getRegionNumber()
+ << " arg #" << argIdx << ": skipped (no layout)\n");
continue;
}
- assert(
- layout &&
- "region argument layout must be defined before propagating to init");
- LLVM_DEBUG(DBGS() << " regionArg #" << i << ": type="
- << regionArg.getType() << ", layout=" << layout
- << " -> init operand #" << (beginIdx + i) << "\n");
- xegpu::setTemporaryLayout(regionOp->getOpOperand(beginIdx + i), layout);
+ LLVM_DEBUG(DBGS() << " region #" << region.getRegionNumber() << " arg #"
+ << argIdx << ": type=" << regionArg.getType()
+ << ", layout=" << layout << "\n");
+
+ // Find all predecessor values that flow into this block argument.
+ SmallVector<Value> predValues;
+ regionOp.getPredecessorValues(regionSuccessor, argIdx, predValues);
+ for (Value predVal : predValues) {
+ // Match predecessor value to an operand of the regionOp.
+ for (OpOperand &operand : regionOp->getOpOperands()) {
+ if (operand.get() == predVal) {
+ LLVM_DEBUG(DBGS() << " -> setting layout on init operand #"
+ << operand.getOperandNumber() << "\n");
+ xegpu::setTemporaryLayout(operand, layout);
+ }
+ }
+ }
}
}
- DBGS() << " after propagateRegionArgsToInits, Region IR:";
- regionOp.print(llvm::dbgs());
+
+ LLVM_DEBUG({
+ DBGS() << " after propagateRegionArgsToInits, Region IR:\n";
+ regionOp.print(llvm::dbgs(), OpPrintingFlags().printGenericOpForm());
+ llvm::dbgs() << "\n";
+ });
}
bool xegpu::recoverTemporaryLayouts(Operation *rootOp) {
@@ -345,16 +331,6 @@ bool xegpu::recoverTemporaryLayouts(Operation *rootOp) {
LLVM_DEBUG(DBGS() << "Processing func: " << funcName << "\n");
walkRegionBackward(body, [&](Operation *op) {
LLVM_DEBUG(DBGS() << "Visiting op: " << op->getName());
- if (op->getNumResults() > 0) {
- LLVM_DEBUG(llvm::dbgs() << " [results: " << op->getNumResults());
- for (OpResult res : op->getResults()) {
- auto layout = xegpu::getDistributeLayoutAttr(res);
- LLVM_DEBUG(llvm::dbgs() << " r#" << res.getResultNumber() << "="
- << (layout ? layout : nullptr));
- }
- LLVM_DEBUG(llvm::dbgs() << "]");
- }
- LLVM_DEBUG(llvm::dbgs() << "\n");
if (auto regionOp = dyn_cast<mlir::RegionBranchOpInterface>(op)) {
// hit the region op after visiting inside region
LLVM_DEBUG(DBGS() << " -> dispatching as RegionBranchOp\n");
@@ -1415,16 +1391,16 @@ xegpu::setupDpasLayout(xegpu::LayoutKind layoutKind, VectorType aTy,
xegpu::DistributeLayoutAttr
xegpu::inferSourceLayoutFromResult(OpOperand &operand,
xegpu::DistributeLayoutAttr resLayout) {
+ if (!resLayout) {
+ LLVM_DEBUG(DBGS() << "no resLayout, returning null\n");
+ return xegpu::DistributeLayoutAttr();
+ }
Operation *op = operand.getOwner();
unsigned idx = operand.getOperandNumber();
// For vector::BroadcastOp, infer the source layout from the result layout.
if (auto broadcast = dyn_cast<vector::BroadcastOp>(op)) {
LLVM_DEBUG(DBGS() << " -> BroadcastOp\n");
- if (!resLayout) {
- LLVM_DEBUG(DBGS() << " no resLayout, returning null\n");
- return xegpu::DistributeLayoutAttr();
- }
auto srcTy = dyn_cast<VectorType>(broadcast.getSourceType());
if (!srcTy) {
LLVM_DEBUG(DBGS() << " source is not VectorType, returning null\n");
@@ -1443,10 +1419,6 @@ xegpu::inferSourceLayoutFromResult(OpOperand &operand,
if (auto reduction = dyn_cast<vector::MultiDimReductionOp>(op)) {
LLVM_DEBUG(DBGS() << " -> MultiDimReductionOp, operand idx=" << idx
<< "\n");
- if (!resLayout) {
- LLVM_DEBUG(DBGS() << " no resLayout, returning null\n");
- return xegpu::DistributeLayoutAttr();
- }
if (idx == 0) {
SmallVector<int64_t> reductionDims(reduction.getReductionDims());
LLVM_DEBUG({
@@ -1467,10 +1439,6 @@ xegpu::inferSourceLayoutFromResult(OpOperand &operand,
if (auto reduction = dyn_cast<vector::ReductionOp>(op)) {
LLVM_DEBUG(DBGS() << " -> ReductionOp\n");
- if (!resLayout) {
- LLVM_DEBUG(DBGS() << " no resLayout, returning null\n");
- return xegpu::DistributeLayoutAttr();
- }
auto inferred = xegpu::inferReductionSourceLayout(resLayout);
LLVM_DEBUG(DBGS() << " inferred=" << inferred << "\n");
return inferred;
@@ -1480,10 +1448,6 @@ xegpu::inferSourceLayoutFromResult(OpOperand &operand,
// element type bitwidths.
if (auto bitcast = dyn_cast<vector::BitCastOp>(op)) {
LLVM_DEBUG(DBGS() << " -> BitCastOp\n");
- if (!resLayout) {
- LLVM_DEBUG(DBGS() << " no resLayout, returning null\n");
- return xegpu::DistributeLayoutAttr();
- }
int resElemBitWidth =
bitcast.getResultVectorType().getElementType().getIntOrFloatBitWidth();
int srcElemBitWidth =
@@ -1508,10 +1472,6 @@ xegpu::inferSourceLayoutFromResult(OpOperand &operand,
llvm::dbgs());
llvm::dbgs() << "]\n";
});
- if (!resLayout) {
- LLVM_DEBUG(DBGS() << " no resLayout, returning null\n");
- return xegpu::DistributeLayoutAttr();
- }
auto inferred = xegpu::inferShapeCastSourceLayout(
resLayout, shapeCast.getResultVectorType().getShape(),
shapeCast.getSourceVectorType().getShape());
@@ -1524,10 +1484,6 @@ xegpu::inferSourceLayoutFromResult(OpOperand &operand,
if (auto insertSlice = dyn_cast<vector::InsertStridedSliceOp>(op)) {
LLVM_DEBUG(DBGS() << " -> InsertStridedSliceOp, operand idx=" << idx
<< "\n");
- if (!resLayout) {
- LLVM_DEBUG(DBGS() << " no resLayout, returning null\n");
- return xegpu::DistributeLayoutAttr();
- }
if (idx == 0) {
auto inferred = xegpu::inferInsertStridedSliceSourceLayout(
resLayout, insertSlice.getDestVectorType().getShape(),
@@ -1549,10 +1505,6 @@ xegpu::inferSourceLayoutFromResult(OpOperand &operand,
llvm::interleaveComma(transpose.getPermutation(), llvm::dbgs());
llvm::dbgs() << "]\n";
});
- if (!resLayout) {
- LLVM_DEBUG(DBGS() << " no resLayout, returning null\n");
- return xegpu::DistributeLayoutAttr();
- }
auto inferred = xegpu::inferTransposeSourceLayout(
resLayout, transpose.getPermutation());
LLVM_DEBUG(DBGS() << " inferred=" << inferred << "\n");
@@ -1564,8 +1516,7 @@ xegpu::inferSourceLayoutFromResult(OpOperand &operand,
if (OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() == 1) {
LLVM_DEBUG(DBGS() << " -> elementwise op, using resLayout="
<< (resLayout ? resLayout : nullptr) << "\n");
- if (!resLayout)
- return xegpu::DistributeLayoutAttr();
+
return resLayout;
}
return xegpu::DistributeLayoutAttr();
@@ -1581,7 +1532,7 @@ xegpu::DistributeLayoutAttr xegpu::getConsumerLayoutAt(OpOperand &operand) {
return inferredOperandLayout;
// By default, assume no layout conflict and return the current layout of
// the operand.
- auto fallback = xegpu::getDistributeLayoutAttr(operand);
+ auto fallback = xegpu::getDistributeLayoutAttr(operand.get());
LLVM_DEBUG(DBGS() << " -> fallback (unhandled op " << op->getName()
<< "), returning operand layout="
<< (fallback ? fallback : nullptr) << "\n");
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 4c30dacae8850..f0ff771f4cbc4 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -1338,7 +1338,7 @@ LogicalResult ResolveLayoutConflicts::run() {
// as anchor op for the reduction op's layout.
if (isa<vector::MultiDimReductionOp>(op) || isa<vector::ReductionOp>(op)) {
for (OpResult result : op->getResults()) {
- if (result.getType().isIntOrFloat()) {
+ if (result.getType().isIntOrFloat() || result.use_empty()) {
auto res = assignResultLayout(result);
if (failed(res)) {
DBGS() << "Failed to resolve vector consumer for multi-reduction "
>From 27cc56acf41eb3380d3195fca5a9215b4414a413 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Fri, 3 Apr 2026 18:00:50 +0000
Subject: [PATCH 03/19] adding support for DistributeLayoutAttr in TensorDesc
instead of just LayoutAttr
---
.../mlir/Dialect/XeGPU/IR/XeGPUTypes.td | 6 +++---
.../mlir/Dialect/XeGPU/Utils/XeGPUUtils.h | 5 +++--
.../XeGPU/Transforms/XeGPUBlocking.cpp | 13 ++++++------
.../XeGPU/Transforms/XeGPULayoutImpl.cpp | 21 +++++++++++++------
.../Transforms/XeGPUPeepHoleOptimizer.cpp | 11 +++++++---
.../Transforms/XeGPUSubgroupDistribute.cpp | 9 ++++----
.../Transforms/XeGPUWgToSgDistribute.cpp | 2 +-
mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp | 4 ++--
.../lib/Dialect/XeGPU/TestXeGPUTransforms.cpp | 13 ++++++------
9 files changed, 49 insertions(+), 35 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index 7e142b20c0894..b13f5a9f2c9d9 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -82,7 +82,7 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
static-dim-list ::= decimal-literal `x` decimal-literal
attr-list = (, encoding-attr)? (, layout-attr)?
enconding-attr = (, memory_space = value)? (, arr_len = value)? (, boundary_check = value)? (, scattered = value)?
- layout-attr = (, layout `<`sg_layout = value, sg_data = value, inst_data = value, lane_layout = value, lane_data = value, order = value`>`)?
+ layout-attr = DistributeLayoutAttr
```
Examples:
@@ -158,8 +158,8 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
return llvm::dyn_cast_if_present<T>(getEncoding());
}
- LayoutAttr getLayoutAttr() const {
- return llvm::dyn_cast_if_present<LayoutAttr>(getLayout());
+ DistributeLayoutAttr getLayoutAttr() const {
+ return llvm::dyn_cast_if_present<DistributeLayoutAttr>(getLayout());
}
xegpu::MemorySpace getMemorySpace() const {
diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
index 0aa2cd45088f3..1b594f17e15ec 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
@@ -219,10 +219,11 @@ void setTemporaryLayout(const T &operandOrResult,
/// Helper function to check if the layout is packed. Layout is packed if it is
/// 2D and lane_data[0] != 1 (data packed from col dimension).
/// TODO: Move to target info.
-bool requirePacked(const LayoutAttr layout);
+bool requirePacked(const DistributeLayoutAttr layout);
/// Helper function to check if the layout requires a transpose effect.
-bool requireTranspose(const LayoutAttr layout, const uArch::uArch *uArch);
+bool requireTranspose(const DistributeLayoutAttr layout,
+ const uArch::uArch *uArch);
// Check if dst shape is an expansion of src shape by inserting unit dimensions.
bool matchUnitDimExpansion(ArrayRef<int64_t> src, ArrayRef<int64_t> dst,
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
index 1ee0bc6ad9507..ef6a494b76638 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
@@ -270,12 +270,11 @@ void XeGPUBlockingPass::runOnOperation() {
}
auto getTileShapeAndCount = [](llvm::ArrayRef<int64_t> shape,
- xegpu::LayoutAttr layout) {
+ xegpu::DistributeLayoutAttr layout) {
int count = 1;
SmallVector<int64_t> tileShape(shape);
- if (layout && layout.getInstData()) {
- DenseI32ArrayAttr instData = layout.getInstData();
- tileShape = llvm::to_vector_of<int64_t>(instData.asArrayRef());
+ if (layout && !layout.getEffectiveInstDataAsInt().empty()) {
+ tileShape = layout.getEffectiveInstDataAsInt();
count = computeProduct(shape) / computeProduct(tileShape);
}
return std::make_pair(tileShape, count);
@@ -308,7 +307,7 @@ void XeGPUBlockingPass::runOnOperation() {
Type elemTy = type.getElementType();
ArrayRef<int64_t> shape = type.getShape();
- xegpu::LayoutAttr layout = type.getLayoutAttr();
+ xegpu::DistributeLayoutAttr layout = type.getLayoutAttr();
if (layout && layout.isForWorkgroup())
return failure();
@@ -348,9 +347,9 @@ void XeGPUBlockingPass::runOnOperation() {
if (chunkSize > 1) {
int64_t blockedChunkSize = chunkSize;
- auto instData = tdescTy.getLayoutAttr().getInstData();
+ auto instData = tdescTy.getLayoutAttr().getEffectiveInstDataAsInt();
if (!instData.empty())
- blockedChunkSize = instData.asArrayRef().back();
+ blockedChunkSize = instData.back();
// To create a new attribute with a different chunk_size:
auto newEncoding = xegpu::ScatterTensorDescAttr::get(
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
index 47148870eeaae..535239e869af1 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
@@ -141,7 +141,8 @@ static xegpu::DistributeLayoutAttr getLayoutFromUsePoints(Value result) {
if (auto tmpLayout = xegpu::getDistributeLayoutAttr(use)) {
// debug print the use and op, and the tmpLayout
LLVM_DEBUG({
- DBGS() << " use: " << use.getOwner()->getName() << use.getOwner();
+ DBGS() << "getLayoutFromUsePoints use: " << use.getOwner()->getName()
+ << use.getOwner();
llvm::dbgs() << ", tmpLayout=" << tmpLayout << "\n";
});
// under debug mode, we want to check all the use points to make sure
@@ -175,10 +176,16 @@ static void propagateResultsToRegularOperands(Operation *op) {
// its layout is not stored as an attribute but encoded in the type itself.
// For vector type, we attach the layout as an attribute to op.
if (auto tensorDescTy = dyn_cast<xegpu::TensorDescType>(resultType)) {
- auto typeWithLayout = xegpu::TensorDescType::get(
- tensorDescTy.getContext(), tensorDescTy.getShape(),
- tensorDescTy.getElementType(), tensorDescTy.getEncoding(), resLayout);
- result.setType(typeWithLayout);
+ auto layout = tensorDescTy.getLayoutAttr();
+ // TODO: remove the layout check. The tensorDescType's layout is treated as
+ // temporary layout, which needs to be set by layout recovery.
+ // allow it now to pass some legacy test case
+ if (!layout) {
+ auto typeWithLayout = xegpu::TensorDescType::get(
+ tensorDescTy.getContext(), tensorDescTy.getShape(),
+ tensorDescTy.getElementType(), tensorDescTy.getEncoding(), resLayout);
+ result.setType(typeWithLayout);
+ }
}
for (OpOperand &opr : op->getOpOperands()) {
@@ -226,6 +233,8 @@ static void propagateRegionResultsToYieldOperands(
// use points.
unsigned numResults = regionBranchOp->getNumResults();
LLVM_DEBUG(DBGS() << " parent op has " << numResults << " results\n");
+ if (numResults == 0)
+ return;
SmallVector<xegpu::DistributeLayoutAttr> resultLayouts(numResults, nullptr);
for (unsigned i = 0; i < numResults; ++i) {
@@ -303,7 +312,7 @@ static void propagateRegionArgsToInits(mlir::RegionBranchOpInterface regionOp) {
// Find all predecessor values that flow into this block argument.
SmallVector<Value> predValues;
- regionOp.getPredecessorValues(regionSuccessor, argIdx, predValues);
+ regionOp.getPredecessorValues(regionSuccessor, argIdx - 1, predValues);
for (Value predVal : predValues) {
// Match predecessor value to an operand of the regionOp.
for (OpOperand &operand : regionOp->getOpOperands()) {
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPeepHoleOptimizer.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPeepHoleOptimizer.cpp
index 0ece695aed512..9288ba9a0cb56 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPeepHoleOptimizer.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPeepHoleOptimizer.cpp
@@ -145,10 +145,15 @@ static xegpu::TensorDescType tryOptimize(xegpu::TensorDescType tdescType,
return tdescType;
SmallVector<int64_t> supportedShape = {supportedHeight, supportedWidth};
+ auto ctx = tdescType.getContext();
+ auto origLayout = tdescType.getLayoutAttr();
+ SmallVector<int32_t> laneLayoutI32(
+ origLayout.getEffectiveLaneLayoutAsInt().begin(),
+ origLayout.getEffectiveLaneLayoutAsInt().end());
xegpu::LayoutAttr newLayout = xegpu::LayoutAttr::get(
- tdescType.getContext(), tdescType.getLayoutAttr().getLaneLayout(),
- DenseI32ArrayAttr::get(tdescType.getContext(), {1, 1}),
- tdescType.getLayoutAttr().getOrder());
+ ctx, /*lane_layout=*/DenseI32ArrayAttr::get(ctx, laneLayoutI32),
+ /*lane_data=*/DenseI32ArrayAttr::get(ctx, {1, 1}),
+ /*order=*/origLayout.getOrder());
// Array length can not be larger than 1 for transpose case.
return xegpu::TensorDescType::get(supportedShape, newElemTy, arrayLen,
tdescType.getBoundaryCheck(),
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index ecdf253d68182..d8ce24ddd5cb0 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -256,7 +256,7 @@ struct CreateNdDescDistribution final : public gpu::WarpDistributionPattern {
auto descOp = operand->get().getDefiningOp<xegpu::CreateNdDescOp>();
unsigned operandIdx = operand->getOperandNumber();
- xegpu::LayoutAttr layout = descOp.getType().getLayoutAttr();
+ xegpu::DistributeLayoutAttr layout = descOp.getType().getLayoutAttr();
if (!layout)
return rewriter.notifyMatchFailure(
descOp, "the tensor descriptor lacks layout attribute");
@@ -342,7 +342,7 @@ struct StoreNdDistribution final : public gpu::WarpDistributionPattern {
SmallVector<Type> offsetTypes = llvm::map_to_vector(
offsetsAsValues, [](Value v) { return v.getType(); });
xegpu::TensorDescType tensorDescTy = storeOp.getTensorDescType();
- xegpu::LayoutAttr layout = tensorDescTy.getLayoutAttr();
+ xegpu::DistributeLayoutAttr layout = tensorDescTy.getLayoutAttr();
if (!layout)
return rewriter.notifyMatchFailure(
storeOp, "the source tensor descriptor lacks layout attribute");
@@ -474,7 +474,7 @@ struct LoadNdDistribution final : public gpu::WarpDistributionPattern {
offsetsAsValues, [](Value v) { return v.getType(); });
xegpu::TensorDescType tensorDescTy = loadOp.getTensorDescType();
- xegpu::LayoutAttr layout = tensorDescTy.getLayoutAttr();
+ xegpu::DistributeLayoutAttr layout = tensorDescTy.getLayoutAttr();
if (!layout)
return rewriter.notifyMatchFailure(
loadOp, "the source tensor descriptor lacks layout attribute");
@@ -709,7 +709,8 @@ struct PrefetchNdDistribution final : public gpu::WarpDistributionPattern {
SmallVector<Type> offsetTypes = llvm::map_to_vector(
offsetsAsValues, [](Value v) { return v.getType(); });
- xegpu::LayoutAttr layout = prefetchOp.getTensorDescType().getLayoutAttr();
+ xegpu::DistributeLayoutAttr layout =
+ prefetchOp.getTensorDescType().getLayoutAttr();
if (!layout)
return rewriter.notifyMatchFailure(
prefetchOp, "the source tensor descriptor lacks layout attribute");
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 0aead9172858f..e47224bbe755c 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -1647,7 +1647,7 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
converter.addConversion(
[&](xegpu::TensorDescType type,
SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
- xegpu::LayoutAttr layout = type.getLayoutAttr();
+ xegpu::DistributeLayoutAttr layout = type.getLayoutAttr();
// Only convert WG-level tensor descs. SG-level or layout-less types
// are already legal and should pass through unchanged.
if (!layout || !layout.isForWorkgroup())
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index a762458105e47..55cf47e38dfd0 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -936,7 +936,7 @@ template int
xegpu::getLargestDivisor<unsigned>(unsigned dim, ArrayRef<unsigned> candidates,
ArrayRef<unsigned> candidateMultiples);
-bool xegpu::requirePacked(const xegpu::LayoutAttr layout) {
+bool xegpu::requirePacked(const xegpu::DistributeLayoutAttr layout) {
if (!layout)
return false;
auto laneData = layout.getEffectiveLaneDataAsInt();
@@ -945,7 +945,7 @@ bool xegpu::requirePacked(const xegpu::LayoutAttr layout) {
return laneData[0] != 1;
}
-bool xegpu::requireTranspose(const xegpu::LayoutAttr layout,
+bool xegpu::requireTranspose(const xegpu::DistributeLayoutAttr layout,
const xegpu::uArch::uArch *uArch) {
// Return false for unsupported targets.
// TODO: Add more support or move to target info.
diff --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
index 0d10ab7c74da6..4760016bdcea4 100644
--- a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
+++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
@@ -106,10 +106,9 @@ struct TestXeGPUUnrollingPatterns
}
if (auto layout = tdescTy.getLayoutAttr()) {
- auto inst_data = layout.getInstData();
- if (inst_data && layout.isForSubgroup())
- return SmallVector<int64_t>(inst_data.asArrayRef().begin(),
- inst_data.asArrayRef().end());
+ auto inst_data = layout.getEffectiveInstDataAsInt();
+ if (!inst_data.empty() && layout.isForSubgroup())
+ return SmallVector<int64_t>(inst_data.begin(), inst_data.end());
}
}
@@ -138,9 +137,9 @@ struct TestXeGPUUnrollingPatterns
if (chunkSize > 1) {
int64_t blockedChunkSize = chunkSize;
- auto instData = layout.getInstData();
+ auto instData = layout.getEffectiveInstDataAsInt();
if (!instData.empty())
- blockedChunkSize = instData.asArrayRef().back();
+ blockedChunkSize = instData.back();
// To create a new attribute with a different chunk_size:
auto newEncoding = xegpu::ScatterTensorDescAttr::get(
@@ -150,7 +149,7 @@ struct TestXeGPUUnrollingPatterns
}
}
if (layout) {
- if (layout.getLaneLayout() == nullptr)
+ if (layout.getEffectiveLaneLayoutAsInt().empty())
layout = xegpu::LayoutAttr();
else
layout = layout.dropInstData();
>From 0690c6cc01e121b137c1056e62d22ff207a82777 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Fri, 3 Apr 2026 20:02:43 +0000
Subject: [PATCH 04/19] fix bugs
---
mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 2 +-
.../XeGPU/Transforms/XeGPULayoutImpl.cpp | 6 +++
.../Transforms/XeGPUPeepHoleOptimizer.cpp | 19 ++++++++--
.../Transforms/XeGPUSubgroupDistribute.cpp | 38 ++++++++++++++++++-
.../XeGPU/sg-to-wi-experimental-unit.mlir | 19 +---------
mlir/test/Dialect/XeGPU/xegpu-blocking.mlir | 4 +-
6 files changed, 64 insertions(+), 24 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 950371e17255f..64c56b5adf5d7 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -1318,7 +1318,7 @@ mlir::Type TensorDescType::parse(AsmParser &parser) {
mlir::Attribute attr;
ParseResult res = parser.parseAttribute(attr);
if (mlir::succeeded(res)) {
- if (mlir::isa<LayoutAttr>(attr)) {
+ if (mlir::isa<DistributeLayoutAttr>(attr)) {
layout = attr;
continue;
}
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
index 535239e869af1..33c9086566d3c 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
@@ -365,6 +365,12 @@ bool xegpu::recoverTemporaryLayouts(Operation *rootOp) {
});
LLVM_DEBUG(DBGS() << "=== recoverTemporaryLayouts END ===\n");
+ // print the root op after
+ LLVM_DEBUG({
+ DBGS() << "After recoverTemporaryLayouts, IR:\n";
+ rootOp->print(llvm::dbgs(), OpPrintingFlags().printGenericOpForm());
+ llvm::dbgs() << "\n";
+ });
return true;
}
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPeepHoleOptimizer.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPeepHoleOptimizer.cpp
index 9288ba9a0cb56..c43eaba5b3ee6 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPeepHoleOptimizer.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPeepHoleOptimizer.cpp
@@ -28,6 +28,7 @@
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Debug.h"
#include <optional>
namespace mlir {
@@ -147,13 +148,25 @@ static xegpu::TensorDescType tryOptimize(xegpu::TensorDescType tdescType,
SmallVector<int64_t> supportedShape = {supportedHeight, supportedWidth};
auto ctx = tdescType.getContext();
auto origLayout = tdescType.getLayoutAttr();
- SmallVector<int32_t> laneLayoutI32(
- origLayout.getEffectiveLaneLayoutAsInt().begin(),
- origLayout.getEffectiveLaneLayoutAsInt().end());
+ auto laneLayoutI64 = origLayout.getEffectiveLaneLayoutAsInt();
+ SmallVector<int32_t> laneLayoutI32(laneLayoutI64.begin(),
+ laneLayoutI64.end());
+ LLVM_DEBUG({
+ DBGS() << "tryOptimize: origLayout=" << origLayout << "\n";
+ DBGS() << " laneLayoutI32=[";
+ llvm::interleaveComma(laneLayoutI32, llvm::dbgs());
+ llvm::dbgs() << "], laneData=[1, 1]";
+ if (origLayout.getOrder())
+ llvm::dbgs() << ", order=" << origLayout.getOrder();
+ llvm::dbgs() << "\n";
+ DBGS() << " supportedShape=[" << supportedHeight << ", " << supportedWidth
+ << "], newElemTy=" << newElemTy << ", arrayLen=" << arrayLen << "\n";
+ });
xegpu::LayoutAttr newLayout = xegpu::LayoutAttr::get(
ctx, /*lane_layout=*/DenseI32ArrayAttr::get(ctx, laneLayoutI32),
/*lane_data=*/DenseI32ArrayAttr::get(ctx, {1, 1}),
/*order=*/origLayout.getOrder());
+ LLVM_DEBUG(DBGS() << " newLayout=" << newLayout << "\n");
// Array length can not be larger than 1 for transpose case.
return xegpu::TensorDescType::get(supportedShape, newElemTy, arrayLen,
tdescType.getBoundaryCheck(),
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index d8ce24ddd5cb0..27cf788933f18 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -800,10 +800,17 @@ struct StoreDistribution final : public gpu::WarpDistributionPattern {
using gpu::WarpDistributionPattern::WarpDistributionPattern;
LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
+ LLVM_DEBUG(DBGS() << "StoreDistribution: attempting to match\n");
Operation *lastNode = warpOp.getTerminator()->getPrevNode();
auto storeScatterOp = dyn_cast_or_null<xegpu::StoreScatterOp>(lastNode);
- if (!storeScatterOp)
+ if (!storeScatterOp) {
+ LLVM_DEBUG(
+ DBGS()
+ << "StoreDistribution: last node is not StoreScatterOp, skipping\n");
return failure();
+ }
+ LLVM_DEBUG(DBGS() << "StoreDistribution: matched StoreScatterOp: "
+ << *storeScatterOp << "\n");
auto offsets = storeScatterOp.getOffsets();
if (!offsets || !isa<VectorType>(offsets.getType()))
return rewriter.notifyMatchFailure(
@@ -811,10 +818,15 @@ struct StoreDistribution final : public gpu::WarpDistributionPattern {
VectorType offsetsTy = cast<VectorType>(offsets.getType());
VectorType maskTy = cast<VectorType>(storeScatterOp.getMask().getType());
VectorType storeVecTy = cast<VectorType>(storeScatterOp.getValueType());
+ LLVM_DEBUG(DBGS() << "StoreDistribution: offsetsTy=" << offsetsTy
+ << ", maskTy=" << maskTy << ", storeVecTy=" << storeVecTy
+ << "\n");
// Add handling for leading unit dimensions support
int chunkSize = storeScatterOp.getChunkSize().value_or(1);
int effectiveVecRank = (chunkSize == 1) ? 1 : 2;
+ LLVM_DEBUG(DBGS() << "StoreDistribution: chunkSize=" << chunkSize
+ << ", effectiveVecRank=" << effectiveVecRank << "\n");
// Check that all leading dimensions are unit dimensions
for (int i = 0; i < storeVecTy.getRank() - effectiveVecRank; i++) {
@@ -831,6 +843,24 @@ struct StoreDistribution final : public gpu::WarpDistributionPattern {
xegpu::getTemporaryLayout(storeScatterOp->getOpOperand(2));
auto layoutMask =
xegpu::getTemporaryLayout(storeScatterOp->getOpOperand(3));
+ LLVM_DEBUG({
+ DBGS() << "StoreDistribution: layoutPayload=";
+ if (layoutPayload)
+ DBGS() << layoutPayload;
+ else
+ DBGS() << "(null)";
+ DBGS() << ", layoutOffsets=";
+ if (layoutOffsets)
+ DBGS() << layoutOffsets;
+ else
+ DBGS() << "(null)";
+ DBGS() << ", layoutMask=";
+ if (layoutMask)
+ DBGS() << layoutMask;
+ else
+ DBGS() << "(null)";
+ DBGS() << "\n";
+ });
FailureOr<VectorType> distStoreVecByWarpOpOrFailure =
getDistVecTypeBasedOnLaneLayout(layoutPayload, storeVecTy);
@@ -849,6 +879,9 @@ struct StoreDistribution final : public gpu::WarpDistributionPattern {
VectorType distPayloadTy = distStoreVecByWarpOpOrFailure.value();
VectorType distOffsetsTy = distOffsetsByWarpOpOrFailure.value();
VectorType distMaskTy = distMaskByWarpOpOrFailure.value();
+ LLVM_DEBUG(DBGS() << "StoreDistribution: distPayloadTy=" << distPayloadTy
+ << ", distOffsetsTy=" << distOffsetsTy
+ << ", distMaskTy=" << distMaskTy << "\n");
SmallVector<size_t> newRetIndices;
SmallVector<Value> operands = storeScatterOp->getOperands();
@@ -885,7 +918,10 @@ struct StoreDistribution final : public gpu::WarpDistributionPattern {
rewriter, newWarpOp.getLoc(), TypeRange{}, newStoreScatterOpOperands,
storeScatterOp->getAttrs());
xegpu::removeLayoutAttrs(newOp);
+ LLVM_DEBUG(DBGS() << "StoreDistribution: created new op: " << newOp
+ << "\n");
rewriter.eraseOp(storeScatterOp);
+ LLVM_DEBUG(DBGS() << "StoreDistribution: done\n");
return success();
}
};
diff --git a/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir b/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
index 842c2375dd31d..0d1bfd5480aa2 100644
--- a/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
+++ b/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
@@ -473,22 +473,6 @@ gpu.func @vector_multi_reduction_dim0_distributed_dim1_reduction(%laneid: index)
gpu.return
}
-// CHECK-LABEL: gpu.func @vector_transpose
-// CHECK: %[[SRC:.*]] = "some_op"()
-// CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[SRC]] : vector<16x2xf32> to vector<1x2xf32>
-// CHECK-NEXT: %[[T:.*]] = vector.transpose %[[CAST]], [1, 0] : vector<1x2xf32> to vector<2x1xf32>
-// CHECK-NEXT: gpu.return
-gpu.func @vector_transpose() {
- %cst = "some_op"()
- {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1], order = [0, 1]>}
- : () -> (vector<16x2xf32>)
- %transpose = vector.transpose %cst, [1, 0]
- {
- layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
- }
- : vector<16x2xf32> to vector<2x16xf32>
- gpu.return
-}
// CHECK-LABEL: gpu.func @vector_bitcast
// CHECK: %[[SRC:.*]] = "some_op"()
@@ -1092,7 +1076,8 @@ gpu.module @xevm_module {
gpu.func @vector_broadcast_2d_to_2d_noop(%laneid: index) {
%0 = "some_op"() {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : () -> vector<16x1xf16>
%1 = vector.broadcast %0 {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16x1xf16> to vector<16x16xf16>
- "some_use"(%1) : (vector<16x16xf16>) -> ()
+ %2 = xegpu.convert_layout %1 <{input_layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, target_layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}> : vector<16x16xf16>
+ "some_use"(%2) : (vector<16x16xf16>) -> ()
gpu.return
}
}
diff --git a/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir b/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
index 9ca424374335f..61b8046bd04e5 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
@@ -257,7 +257,7 @@ gpu.module @test_kernel {
// -----
#l = #xegpu.layout<inst_data = [16, 16]>
-#r = #xegpu.layout<inst_data = [16]>
+#r = #xegpu.slice<#xegpu.layout<inst_data = [16, 16]>, dims = [0]>
gpu.module @test_kernel {
gpu.func @reduce_dim_0(%a: memref<16x512xf32>, %b: memref<512xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
%acc = arith.constant {layout_result_0 = #r} dense<0.0> : vector<64xf32>
@@ -277,7 +277,7 @@ gpu.module @test_kernel {
// -----
#l = #xegpu.layout<inst_data = [16, 16]>
-#r = #xegpu.layout<inst_data = [16]>
+#r = #xegpu.slice<#xegpu.layout<inst_data = [16, 16]>, dims = [1]>
gpu.module @test_kernel {
gpu.func @reduce_dim_1(%a: memref<512x32xf32>, %b: memref<512xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
%c1 = arith.constant 1 : index
>From ac36ceaccbd9bff10bf933ffef9b0b0d1e557cdc Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Fri, 3 Apr 2026 20:24:35 +0000
Subject: [PATCH 05/19] separate recover temporary layout out to another PR
---
.../XeGPU/Transforms/XeGPULayoutImpl.h | 5 +-
.../XeGPU/Transforms/XeGPULayoutImpl.cpp | 457 +++---------------
.../XeGPU/Transforms/XeGPUPropagateLayout.cpp | 2 +-
.../XeGPU/sg-to-wi-experimental-unit.mlir | 19 +-
4 files changed, 73 insertions(+), 410 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h b/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h
index 5f46eab7b74c7..9cf9a8705209b 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h
@@ -183,13 +183,10 @@ setupDpasLayout(LayoutKind layoutKind, VectorType aTy, VectorType bTy,
VectorType cdTy, DistributeLayoutAttr consumerLayout, int numSg,
const uArch::uArch *uArch);
-DistributeLayoutAttr
-inferSourceLayoutFromResult(OpOperand &operand, DistributeLayoutAttr resLayout);
-
/// Gets the expected layout for a given consumer operand. This will check if
/// the owning operation of the consumer operand is one of the special layout
/// users and determine the expected layout accordingly.
-DistributeLayoutAttr getConsumerLayoutAt(OpOperand &operand);
+xegpu::DistributeLayoutAttr getConsumerLayoutAt(OpOperand &operand);
} // namespace xegpu
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
index 33c9086566d3c..55cd6ec04970c 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
@@ -18,22 +18,16 @@
#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
-#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/ValueRange.h"
-#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Transforms/DialectConversion.h"
-#include "llvm/Support/Debug.h"
#include "llvm/Support/FormatVariadic.h"
#include <cstdint>
#include <numeric>
-#define DEBUG_TYPE "xegpu-layout-recovery"
-#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
-
using namespace mlir;
void xegpu::recoverTemporaryLayoutsDeprecated(Operation *op) {
@@ -86,321 +80,32 @@ xegpu::dropInstDataOnAttrs(ArrayRef<NamedAttribute> attrs) {
return out;
}
-// Prerequisite for Layout Recovery
-// It relies on the following invariant:
-// 1. there is no layout conflict between different uses of the same definition.
-// 2. each definition has a well-defined layout requirement at its use point.
-// - Every definition must have at least one use that appears after it in
-// topological order.
-// - If a definition has no such use (e.g., a loop result or region output),
-// an explicit convert_layout operation is inserted to create a use.
-// - Only the result of convert_layout is permitted to have no subsequent
-// use.
-
-// The recovery proceeds by scanning the operation in reverse topological order
-// as follows:
-// For regular operations: First the result layouts are propagated from uses.
-// Then the result layouts are propagated to operands.
-//
-// For region operations (e.g., loops):
-// - When backward propagation reaches a region op, it sets the layout of
-// the region op’s results according to use points like regular ops.
-// - Then, the result layouts (such as a loop output) are propagated to
-// their corresponding operands in the yield.
-// - When backward propagation reaches the first operation inside the
-// region, the pass examines the region op’s initialization list,
-// propagating from region arguments to the corresponding initialization
-// operands.
-// - This ensures that layouts are consistently propagated
-// across region boundaries while preserving a single well-defined use for
-// each definition at the region-op level.
-
-// the inner function for recoverTemporaryLayouts is a recursive function
-// the input rootOp is the function operation, which is also a region op.
-// it recursivley process the region op in reverse topological order.
-
-static void walkRegionBackward(Region ®ion,
- llvm::function_ref<void(Operation *)> visit) {
- // blocks: back -> front
- for (Block &block : llvm::reverse(region)) {
- // ops: back -> front, early-inc so visit() may erase current op safely
- for (Operation &op : llvm::reverse(block)) {
- // make sure we first visit inside the region op (so yield op first)
- // and then move to region op itself
- for (Region &nested : llvm::reverse(op.getRegions()))
- walkRegionBackward(nested, visit);
-
- visit(&op);
- }
- }
-}
-
-static xegpu::DistributeLayoutAttr getLayoutFromUsePoints(Value result) {
- xegpu::DistributeLayoutAttr layout = nullptr;
- for (OpOperand &use : result.getUses()) {
- if (auto tmpLayout = xegpu::getDistributeLayoutAttr(use)) {
- // debug print the use and op, and the tmpLayout
- LLVM_DEBUG({
- DBGS() << "getLayoutFromUsePoints use: " << use.getOwner()->getName()
- << use.getOwner();
- llvm::dbgs() << ", tmpLayout=" << tmpLayout << "\n";
- });
- // under debug mode, we want to check all the use points to make sure
- // there is no conflict, so we do not break here. In release mode, we can
- // break at the first use
- if (!layout)
- layout = tmpLayout;
- }
- }
- return layout;
-}
-
-// For regular operations: First the result layouts are propagated from uses.
-// Then the result layouts are propagated to uses (operands).
-static void propagateResultsToRegularOperands(Operation *op) {
- LLVM_DEBUG(DBGS() << "propagateResultsToRegularOperands: " << op->getName()
- << " (" << op->getNumOperands() << " operands, "
- << op->getNumResults() << " results)\n");
-
- if (op->getNumResults() == 0) {
- LLVM_DEBUG(DBGS() << " skipping (no results)\n");
- return;
- }
-
- Value result = op->getResult(0);
- xegpu::DistributeLayoutAttr resLayout =
- getLayoutFromUsePoints(op->getResult(0));
- Type resultType = result.getType();
-
- // recover layout for tensor Descriptor type, which is a special case since
- // its layout is not stored as an attribute but encoded in the type itself.
- // For vector type, we attach the layout as an attribute to op.
- if (auto tensorDescTy = dyn_cast<xegpu::TensorDescType>(resultType)) {
- auto layout = tensorDescTy.getLayoutAttr();
- // TODO: remove the layout check. The tensorDescType's layout is treated as
- // temporary layout, which needs to be set by layout recovery.
- // allow it now to pass some legacy test case
- if (!layout) {
- auto typeWithLayout = xegpu::TensorDescType::get(
- tensorDescTy.getContext(), tensorDescTy.getShape(),
- tensorDescTy.getElementType(), tensorDescTy.getEncoding(), resLayout);
- result.setType(typeWithLayout);
- }
- }
-
- for (OpOperand &opr : op->getOpOperands()) {
- // Layouts are needed for vector type only.
- xegpu::DistributeLayoutAttr operandLayout =
- xegpu::inferSourceLayoutFromResult(opr, resLayout);
- if (!isa<VectorType>(opr.get().getType())) {
- LLVM_DEBUG(DBGS() << " operand #" << opr.getOperandNumber()
- << ": skipped (non-vector type: " << opr.get().getType()
- << ")\n");
- continue;
- }
-
- xegpu::setTemporaryLayout(opr, operandLayout);
- // debug print op
- LLVM_DEBUG(DBGS() << "after propagateResultsToRegularOperands op: "
- << op->getName() << op << " operand #"
- << opr.getOperandNumber()
- << ": type=" << opr.get().getType());
- llvm::dbgs() << ", temp Layout=" << xegpu::getTemporaryLayout(opr);
- llvm::dbgs() << "\n";
- }
-}
-
-static void propagateRegionResultsToYieldOperands(
- mlir::RegionBranchTerminatorOpInterface yieldOp) {
- LLVM_DEBUG(DBGS() << "propagateRegionResultsToYieldOperands: "
- << yieldOp->getName() << " (" << yieldOp->getNumOperands()
- << " operands), parent="
- << yieldOp->getParentOp()->getName() << "\n");
-
- if (isa<func::FuncOp>(yieldOp->getParentOp())) {
- LLVM_DEBUG(DBGS() << " skipping (parent is FuncOp)\n");
- return;
- }
-
- auto regionBranchOp =
- dyn_cast<RegionBranchOpInterface>(yieldOp->getParentOp());
- if (!regionBranchOp) {
- LLVM_DEBUG(DBGS() << " skipping (parent is not RegionBranchOp)\n");
- return;
- }
-
- // Gather layouts for each result of the parent region op from external
- // use points.
- unsigned numResults = regionBranchOp->getNumResults();
- LLVM_DEBUG(DBGS() << " parent op has " << numResults << " results\n");
- if (numResults == 0)
- return;
-
- SmallVector<xegpu::DistributeLayoutAttr> resultLayouts(numResults, nullptr);
- for (unsigned i = 0; i < numResults; ++i) {
- OpResult result = regionBranchOp->getResult(i);
- resultLayouts[i] = getLayoutFromUsePoints(result);
- if (resultLayouts[i]) {
- LLVM_DEBUG(DBGS() << " result #" << i << ": type=" << result.getType()
- << ", layout=" << resultLayouts[i] << "\n");
- xegpu::setTemporaryLayout(result, resultLayouts[i]);
- } else {
- LLVM_DEBUG(DBGS() << " result #" << i
- << ": skipped (no layout from use points)\n");
- }
- }
-
- // Use getSuccessorOperands to find which operands of the terminator
- // flow to a successor. This handles index offsets automatically (e.g.,
- // scf.condition's predicate at operand #0 is excluded).
- // Pick the first successor to determine the operand range.
- SmallVector<RegionSuccessor> successors;
- SmallVector<Attribute> operandAttrs(yieldOp->getNumOperands(), nullptr);
- yieldOp.getSuccessorRegions(operandAttrs, successors);
- assert(!successors.empty() && "terminator must have at least one successor");
-
- OperandRange succOps = yieldOp.getSuccessorOperands(successors.front());
- unsigned beginIdx = succOps.getBeginOperandIndex();
- unsigned count = std::min(static_cast<unsigned>(succOps.size()), numResults);
-
- LLVM_DEBUG(DBGS() << " " << count << " successor operands starting at index "
- << beginIdx << "\n");
-
- for (unsigned i = 0; i < count; ++i) {
- if (!resultLayouts[i])
- continue;
- LLVM_DEBUG(DBGS() << " -> setting layout on operand #" << (beginIdx + i)
- << "\n");
- xegpu::setTemporaryLayout(yieldOp->getOpOperand(beginIdx + i),
- resultLayouts[i]);
- }
-
- LLVM_DEBUG({
- DBGS() << " after propagateRegionResultsToYieldOperands:\n";
- yieldOp->print(llvm::dbgs(), OpPrintingFlags().printGenericOpForm());
- llvm::dbgs() << "\n";
- });
-}
-
-static void propagateRegionArgsToInits(mlir::RegionBranchOpInterface regionOp) {
- LLVM_DEBUG(DBGS() << "propagateRegionArgsToInits: " << regionOp->getName()
- << " (" << regionOp->getNumOperands() << " operands, "
- << regionOp->getNumRegions() << " regions)\n");
- LLVM_DEBUG({
- DBGS() << " before propagateRegionArgsToInits, Region IR:\n";
- regionOp.print(llvm::dbgs(), OpPrintingFlags().printGenericOpForm());
- llvm::dbgs() << "\n";
- });
-
- // Iterate all regions of the region op. For each block argument that has a
- // layout (determined from its use points), trace back to find the
- // corresponding init operand of the regionOp and set the layout on it.
- // This works generically for scf.for, scf.while, and other
- // RegionBranchOpInterface ops.
- for (Region ®ion : regionOp->getRegions()) {
- RegionSuccessor regionSuccessor(®ion);
- for (auto [argIdx, regionArg] : llvm::enumerate(region.getArguments())) {
- auto layout = getLayoutFromUsePoints(regionArg);
+// Attach layout attributes to all vector-type operands of operations within
+// the given operation's region. Reports an error if any vector operand lacks
+// a layout attribute.
+bool xegpu::recoverTemporaryLayouts(Operation *rootOp) {
+ auto result = rootOp->walk([&](Operation *op) {
+ for (OpOperand &operand : op->getOpOperands()) {
+ // Layouts are needed for vector type only.
+ if (!isa<VectorType>(operand.get().getType()))
+ continue;
+ // Skip block arguments since they don't have defining ops to attach
+ // layout attributes to.
+ if (isa<BlockArgument>(operand.get()))
+ continue;
+ auto layout = xegpu::getDistributeLayoutAttr(operand.get());
if (!layout) {
- LLVM_DEBUG(DBGS() << " region #" << region.getRegionNumber()
- << " arg #" << argIdx << ": skipped (no layout)\n");
+ op->emitWarning("Could not find layout attribute for operand ")
+ << operand.getOperandNumber() << " of operation " << op->getName();
continue;
}
- LLVM_DEBUG(DBGS() << " region #" << region.getRegionNumber() << " arg #"
- << argIdx << ": type=" << regionArg.getType()
- << ", layout=" << layout << "\n");
-
- // Find all predecessor values that flow into this block argument.
- SmallVector<Value> predValues;
- regionOp.getPredecessorValues(regionSuccessor, argIdx - 1, predValues);
- for (Value predVal : predValues) {
- // Match predecessor value to an operand of the regionOp.
- for (OpOperand &operand : regionOp->getOpOperands()) {
- if (operand.get() == predVal) {
- LLVM_DEBUG(DBGS() << " -> setting layout on init operand #"
- << operand.getOperandNumber() << "\n");
- xegpu::setTemporaryLayout(operand, layout);
- }
- }
- }
+ xegpu::setTemporaryLayout(operand, layout);
}
- }
-
- LLVM_DEBUG({
- DBGS() << " after propagateRegionArgsToInits, Region IR:\n";
- regionOp.print(llvm::dbgs(), OpPrintingFlags().printGenericOpForm());
- llvm::dbgs() << "\n";
- });
-}
-
-bool xegpu::recoverTemporaryLayouts(Operation *rootOp) {
- LLVM_DEBUG(DBGS() << "=== recoverTemporaryLayouts START ===\n");
-
- auto processFunc = [&](Region &body, StringRef funcName) {
- LLVM_DEBUG(DBGS() << "Processing func: " << funcName << "\n");
- walkRegionBackward(body, [&](Operation *op) {
- LLVM_DEBUG(DBGS() << "Visiting op: " << op->getName());
- if (auto regionOp = dyn_cast<mlir::RegionBranchOpInterface>(op)) {
- // hit the region op after visiting inside region
- LLVM_DEBUG(DBGS() << " -> dispatching as RegionBranchOp\n");
- propagateRegionArgsToInits(regionOp);
- } else if (auto yieldOp =
- dyn_cast<mlir::RegionBranchTerminatorOpInterface>(op)) {
- // yield op inside region op
- LLVM_DEBUG(DBGS() << " -> dispatching as YieldOp\n");
- propagateRegionResultsToYieldOperands(yieldOp);
- } else if (!dyn_cast<xegpu::AnchorLayoutInterface>(op)) {
- // if the op is regular op, calling propagateResultsToRegularOperands
- LLVM_DEBUG(DBGS() << " -> dispatching as regular op\n");
- propagateResultsToRegularOperands(op);
- }
- });
- };
-
- rootOp->walk([&](func::FuncOp func) {
- processFunc(func.getBody(), func.getSymName());
- });
- rootOp->walk([&](gpu::GPUFuncOp func) {
- processFunc(func.getBody(), func.getName());
- });
-
- LLVM_DEBUG(DBGS() << "=== recoverTemporaryLayouts END ===\n");
- // print the root op after
- LLVM_DEBUG({
- DBGS() << "After recoverTemporaryLayouts, IR:\n";
- rootOp->print(llvm::dbgs(), OpPrintingFlags().printGenericOpForm());
- llvm::dbgs() << "\n";
+ return WalkResult::advance();
});
- return true;
+ return !result.wasInterrupted();
}
-// // Attach layout attributes to all vector-type operands of operations within
-// // the given operation's region. Reports an error if any vector operand lacks
-// // a layout attribute.
-// bool xegpu::recoverTemporaryLayouts(Operation *rootOp) {
-// auto result = rootOp->walk([&](Operation *op) {
-// for (OpOperand &operand : op->getOpOperands()) {
-// // Layouts are needed for vector type only.
-// if (!isa<VectorType>(operand.get().getType()))
-// continue;
-// // Skip block arguments since they don't have defining ops to attach
-// // layout attributes to.
-// if (isa<BlockArgument>(operand.get()))
-// continue;
-// auto layout = xegpu::getDistributeLayoutAttr(operand.get());
-// if (!layout) {
-// op->emitWarning("Could not find layout attribute for operand ")
-// << operand.getOperandNumber() << " of operation " <<
-// op->getName();
-// xegpu::setTemporaryLayout(operand, layout);
-// continue;
-// }
-// }
-// return WalkResult::advance();
-// });
-// return !result.wasInterrupted();
-// }
-
template <typename T, typename>
void xegpu::removeLayoutAttr(const T &operandOrResult) {
Operation *owner = operandOrResult.getOwner();
@@ -1403,153 +1108,99 @@ xegpu::setupDpasLayout(xegpu::LayoutKind layoutKind, VectorType aTy,
return std::nullopt;
}
-xegpu::DistributeLayoutAttr
-xegpu::inferSourceLayoutFromResult(OpOperand &operand,
- xegpu::DistributeLayoutAttr resLayout) {
- if (!resLayout) {
- LLVM_DEBUG(DBGS() << "no resLayout, returning null\n");
- return xegpu::DistributeLayoutAttr();
- }
+xegpu::DistributeLayoutAttr xegpu::getConsumerLayoutAt(OpOperand &operand) {
Operation *op = operand.getOwner();
unsigned idx = operand.getOperandNumber();
+ xegpu::DistributeLayoutAttr resLayout;
+ if (op->getNumResults() == 1)
+ resLayout = xegpu::getDistributeLayoutAttr(op->getResult(0));
// For vector::BroadcastOp, infer the source layout from the result layout.
if (auto broadcast = dyn_cast<vector::BroadcastOp>(op)) {
- LLVM_DEBUG(DBGS() << " -> BroadcastOp\n");
+ if (!resLayout)
+ return xegpu::DistributeLayoutAttr();
auto srcTy = dyn_cast<VectorType>(broadcast.getSourceType());
- if (!srcTy) {
- LLVM_DEBUG(DBGS() << " source is not VectorType, returning null\n");
+ if (!srcTy)
return xegpu::DistributeLayoutAttr();
- }
- auto inferred = xegpu::inferBroadcastSourceLayout(
+ return xegpu::inferBroadcastSourceLayout(
resLayout, broadcast.getResultVectorType().getShape(),
srcTy.getShape());
- LLVM_DEBUG(DBGS() << " inferred=" << inferred << "\n");
- return inferred;
}
// For vector::MultiDimReductionOp, infer source layout from result layout
// using reduction dims. Acc operand is expected to have the same layout as
// the result.
if (auto reduction = dyn_cast<vector::MultiDimReductionOp>(op)) {
- LLVM_DEBUG(DBGS() << " -> MultiDimReductionOp, operand idx=" << idx
- << "\n");
+ if (!resLayout)
+ return xegpu::DistributeLayoutAttr();
if (idx == 0) {
SmallVector<int64_t> reductionDims(reduction.getReductionDims());
- LLVM_DEBUG({
- DBGS() << " reductionDims=[";
- llvm::interleaveComma(reductionDims, llvm::dbgs());
- llvm::dbgs() << "]\n";
- });
- auto inferred =
- xegpu::inferMultiReductionSourceLayout(resLayout, reductionDims);
- LLVM_DEBUG(DBGS() << " inferred source layout=" << inferred << "\n");
- return inferred;
+ return xegpu::inferMultiReductionSourceLayout(resLayout, reductionDims);
}
- if (idx == 1) {
- LLVM_DEBUG(DBGS() << " acc operand, using resLayout\n");
+ if (idx == 1)
return resLayout;
- }
}
if (auto reduction = dyn_cast<vector::ReductionOp>(op)) {
- LLVM_DEBUG(DBGS() << " -> ReductionOp\n");
- auto inferred = xegpu::inferReductionSourceLayout(resLayout);
- LLVM_DEBUG(DBGS() << " inferred=" << inferred << "\n");
- return inferred;
+ if (!resLayout)
+ return xegpu::DistributeLayoutAttr();
+ return xegpu::inferReductionSourceLayout(resLayout);
}
// For vector::BitCastOp, infer source layout from result layout using
// element type bitwidths.
if (auto bitcast = dyn_cast<vector::BitCastOp>(op)) {
- LLVM_DEBUG(DBGS() << " -> BitCastOp\n");
+ if (!resLayout)
+ return xegpu::DistributeLayoutAttr();
int resElemBitWidth =
bitcast.getResultVectorType().getElementType().getIntOrFloatBitWidth();
int srcElemBitWidth =
bitcast.getSourceVectorType().getElementType().getIntOrFloatBitWidth();
- LLVM_DEBUG(DBGS() << " resBitWidth=" << resElemBitWidth
- << ", srcBitWidth=" << srcElemBitWidth << "\n");
- auto inferred = xegpu::inferBitCastSourceLayout(resLayout, resElemBitWidth,
- srcElemBitWidth);
- LLVM_DEBUG(DBGS() << " inferred=" << inferred << "\n");
- return inferred;
+ return xegpu::inferBitCastSourceLayout(resLayout, resElemBitWidth,
+ srcElemBitWidth);
}
// For vector::ShapeCastOp, infer source layout from result layout using
// shapes.
if (auto shapeCast = dyn_cast<vector::ShapeCastOp>(op)) {
- LLVM_DEBUG({
- DBGS() << " -> ShapeCastOp: resShape=[";
- llvm::interleaveComma(shapeCast.getResultVectorType().getShape(),
- llvm::dbgs());
- llvm::dbgs() << "], srcShape=[";
- llvm::interleaveComma(shapeCast.getSourceVectorType().getShape(),
- llvm::dbgs());
- llvm::dbgs() << "]\n";
- });
- auto inferred = xegpu::inferShapeCastSourceLayout(
+ if (!resLayout)
+ return xegpu::DistributeLayoutAttr();
+ return xegpu::inferShapeCastSourceLayout(
resLayout, shapeCast.getResultVectorType().getShape(),
shapeCast.getSourceVectorType().getShape());
- LLVM_DEBUG(DBGS() << " inferred=" << inferred << "\n");
- return inferred;
}
// For vector::InsertStridedSliceOp, infer source layout from result layout.
// Dest vector must have the same layout as the result.
if (auto insertSlice = dyn_cast<vector::InsertStridedSliceOp>(op)) {
- LLVM_DEBUG(DBGS() << " -> InsertStridedSliceOp, operand idx=" << idx
- << "\n");
- if (idx == 0) {
- auto inferred = xegpu::inferInsertStridedSliceSourceLayout(
+ if (!resLayout)
+ return xegpu::DistributeLayoutAttr();
+ if (idx == 0)
+ return xegpu::inferInsertStridedSliceSourceLayout(
resLayout, insertSlice.getDestVectorType().getShape(),
insertSlice.getSourceVectorType().getShape());
- LLVM_DEBUG(DBGS() << " inferred source layout=" << inferred << "\n");
- return inferred;
- }
- if (idx == 1) {
- LLVM_DEBUG(DBGS() << " dest operand, using resLayout\n");
+ if (idx == 1)
return resLayout;
- }
}
// For vector::TransposeOp, infer source layout from result layout using
// permutation.
if (auto transpose = dyn_cast<vector::TransposeOp>(op)) {
- LLVM_DEBUG({
- DBGS() << " -> TransposeOp, perm=[";
- llvm::interleaveComma(transpose.getPermutation(), llvm::dbgs());
- llvm::dbgs() << "]\n";
- });
- auto inferred = xegpu::inferTransposeSourceLayout(
- resLayout, transpose.getPermutation());
- LLVM_DEBUG(DBGS() << " inferred=" << inferred << "\n");
- return inferred;
+ if (!resLayout)
+ return xegpu::DistributeLayoutAttr();
+ return xegpu::inferTransposeSourceLayout(resLayout,
+ transpose.getPermutation());
}
// For elementwise operations, all operands must have the same layout as the
// result.
if (OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() == 1) {
- LLVM_DEBUG(DBGS() << " -> elementwise op, using resLayout="
- << (resLayout ? resLayout : nullptr) << "\n");
-
+ if (!resLayout)
+ return xegpu::DistributeLayoutAttr();
return resLayout;
}
- return xegpu::DistributeLayoutAttr();
-}
-
-xegpu::DistributeLayoutAttr xegpu::getConsumerLayoutAt(OpOperand &operand) {
- Operation *op = operand.getOwner();
- xegpu::DistributeLayoutAttr resLayout;
- if (op->getNumResults() == 1)
- resLayout = xegpu::getDistributeLayoutAttr(op->getResult(0));
- auto inferredOperandLayout = inferSourceLayoutFromResult(operand, resLayout);
- if (inferredOperandLayout)
- return inferredOperandLayout;
+ // TODO: Handle more cases as needed here.
// By default, assume no layout conflict and return the current layout of
// the operand.
- auto fallback = xegpu::getDistributeLayoutAttr(operand.get());
- LLVM_DEBUG(DBGS() << " -> fallback (unhandled op " << op->getName()
- << "), returning operand layout="
- << (fallback ? fallback : nullptr) << "\n");
- return fallback;
+ return xegpu::getDistributeLayoutAttr(operand.get());
}
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index f0ff771f4cbc4..4c30dacae8850 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -1338,7 +1338,7 @@ LogicalResult ResolveLayoutConflicts::run() {
// as anchor op for the reduction op's layout.
if (isa<vector::MultiDimReductionOp>(op) || isa<vector::ReductionOp>(op)) {
for (OpResult result : op->getResults()) {
- if (result.getType().isIntOrFloat() || result.use_empty()) {
+ if (result.getType().isIntOrFloat()) {
auto res = assignResultLayout(result);
if (failed(res)) {
DBGS() << "Failed to resolve vector consumer for multi-reduction "
diff --git a/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir b/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
index 0d1bfd5480aa2..842c2375dd31d 100644
--- a/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
+++ b/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
@@ -473,6 +473,22 @@ gpu.func @vector_multi_reduction_dim0_distributed_dim1_reduction(%laneid: index)
gpu.return
}
+// CHECK-LABEL: gpu.func @vector_transpose
+// CHECK: %[[SRC:.*]] = "some_op"()
+// CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[SRC]] : vector<16x2xf32> to vector<1x2xf32>
+// CHECK-NEXT: %[[T:.*]] = vector.transpose %[[CAST]], [1, 0] : vector<1x2xf32> to vector<2x1xf32>
+// CHECK-NEXT: gpu.return
+gpu.func @vector_transpose() {
+ %cst = "some_op"()
+ {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1], order = [0, 1]>}
+ : () -> (vector<16x2xf32>)
+ %transpose = vector.transpose %cst, [1, 0]
+ {
+ layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+ }
+ : vector<16x2xf32> to vector<2x16xf32>
+ gpu.return
+}
// CHECK-LABEL: gpu.func @vector_bitcast
// CHECK: %[[SRC:.*]] = "some_op"()
@@ -1076,8 +1092,7 @@ gpu.module @xevm_module {
gpu.func @vector_broadcast_2d_to_2d_noop(%laneid: index) {
%0 = "some_op"() {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : () -> vector<16x1xf16>
%1 = vector.broadcast %0 {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16x1xf16> to vector<16x16xf16>
- %2 = xegpu.convert_layout %1 <{input_layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, target_layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}> : vector<16x16xf16>
- "some_use"(%2) : (vector<16x16xf16>) -> ()
+ "some_use"(%1) : (vector<16x16xf16>) -> ()
gpu.return
}
}
>From 1328c5ff1981598a4ed9ff102f1ac17360cbd6c4 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Fri, 3 Apr 2026 20:37:55 +0000
Subject: [PATCH 06/19] cleanup
---
.../Transforms/XeGPUPeepHoleOptimizer.cpp | 15 +---
.../Transforms/XeGPUSubgroupDistribute.cpp | 38 +---------
mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp | 74 +++----------------
3 files changed, 12 insertions(+), 115 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPeepHoleOptimizer.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPeepHoleOptimizer.cpp
index c43eaba5b3ee6..c488bca363da6 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPeepHoleOptimizer.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPeepHoleOptimizer.cpp
@@ -28,7 +28,6 @@
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
-#include "llvm/Support/Debug.h"
#include <optional>
namespace mlir {
@@ -151,22 +150,12 @@ static xegpu::TensorDescType tryOptimize(xegpu::TensorDescType tdescType,
auto laneLayoutI64 = origLayout.getEffectiveLaneLayoutAsInt();
SmallVector<int32_t> laneLayoutI32(laneLayoutI64.begin(),
laneLayoutI64.end());
- LLVM_DEBUG({
- DBGS() << "tryOptimize: origLayout=" << origLayout << "\n";
- DBGS() << " laneLayoutI32=[";
- llvm::interleaveComma(laneLayoutI32, llvm::dbgs());
- llvm::dbgs() << "], laneData=[1, 1]";
- if (origLayout.getOrder())
- llvm::dbgs() << ", order=" << origLayout.getOrder();
- llvm::dbgs() << "\n";
- DBGS() << " supportedShape=[" << supportedHeight << ", " << supportedWidth
- << "], newElemTy=" << newElemTy << ", arrayLen=" << arrayLen << "\n";
- });
+
xegpu::LayoutAttr newLayout = xegpu::LayoutAttr::get(
ctx, /*lane_layout=*/DenseI32ArrayAttr::get(ctx, laneLayoutI32),
/*lane_data=*/DenseI32ArrayAttr::get(ctx, {1, 1}),
/*order=*/origLayout.getOrder());
- LLVM_DEBUG(DBGS() << " newLayout=" << newLayout << "\n");
+
// Array length can not be larger than 1 for transpose case.
return xegpu::TensorDescType::get(supportedShape, newElemTy, arrayLen,
tdescType.getBoundaryCheck(),
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 27cf788933f18..d8ce24ddd5cb0 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -800,17 +800,10 @@ struct StoreDistribution final : public gpu::WarpDistributionPattern {
using gpu::WarpDistributionPattern::WarpDistributionPattern;
LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
- LLVM_DEBUG(DBGS() << "StoreDistribution: attempting to match\n");
Operation *lastNode = warpOp.getTerminator()->getPrevNode();
auto storeScatterOp = dyn_cast_or_null<xegpu::StoreScatterOp>(lastNode);
- if (!storeScatterOp) {
- LLVM_DEBUG(
- DBGS()
- << "StoreDistribution: last node is not StoreScatterOp, skipping\n");
+ if (!storeScatterOp)
return failure();
- }
- LLVM_DEBUG(DBGS() << "StoreDistribution: matched StoreScatterOp: "
- << *storeScatterOp << "\n");
auto offsets = storeScatterOp.getOffsets();
if (!offsets || !isa<VectorType>(offsets.getType()))
return rewriter.notifyMatchFailure(
@@ -818,15 +811,10 @@ struct StoreDistribution final : public gpu::WarpDistributionPattern {
VectorType offsetsTy = cast<VectorType>(offsets.getType());
VectorType maskTy = cast<VectorType>(storeScatterOp.getMask().getType());
VectorType storeVecTy = cast<VectorType>(storeScatterOp.getValueType());
- LLVM_DEBUG(DBGS() << "StoreDistribution: offsetsTy=" << offsetsTy
- << ", maskTy=" << maskTy << ", storeVecTy=" << storeVecTy
- << "\n");
// Add handling for leading unit dimensions support
int chunkSize = storeScatterOp.getChunkSize().value_or(1);
int effectiveVecRank = (chunkSize == 1) ? 1 : 2;
- LLVM_DEBUG(DBGS() << "StoreDistribution: chunkSize=" << chunkSize
- << ", effectiveVecRank=" << effectiveVecRank << "\n");
// Check that all leading dimensions are unit dimensions
for (int i = 0; i < storeVecTy.getRank() - effectiveVecRank; i++) {
@@ -843,24 +831,6 @@ struct StoreDistribution final : public gpu::WarpDistributionPattern {
xegpu::getTemporaryLayout(storeScatterOp->getOpOperand(2));
auto layoutMask =
xegpu::getTemporaryLayout(storeScatterOp->getOpOperand(3));
- LLVM_DEBUG({
- DBGS() << "StoreDistribution: layoutPayload=";
- if (layoutPayload)
- DBGS() << layoutPayload;
- else
- DBGS() << "(null)";
- DBGS() << ", layoutOffsets=";
- if (layoutOffsets)
- DBGS() << layoutOffsets;
- else
- DBGS() << "(null)";
- DBGS() << ", layoutMask=";
- if (layoutMask)
- DBGS() << layoutMask;
- else
- DBGS() << "(null)";
- DBGS() << "\n";
- });
FailureOr<VectorType> distStoreVecByWarpOpOrFailure =
getDistVecTypeBasedOnLaneLayout(layoutPayload, storeVecTy);
@@ -879,9 +849,6 @@ struct StoreDistribution final : public gpu::WarpDistributionPattern {
VectorType distPayloadTy = distStoreVecByWarpOpOrFailure.value();
VectorType distOffsetsTy = distOffsetsByWarpOpOrFailure.value();
VectorType distMaskTy = distMaskByWarpOpOrFailure.value();
- LLVM_DEBUG(DBGS() << "StoreDistribution: distPayloadTy=" << distPayloadTy
- << ", distOffsetsTy=" << distOffsetsTy
- << ", distMaskTy=" << distMaskTy << "\n");
SmallVector<size_t> newRetIndices;
SmallVector<Value> operands = storeScatterOp->getOperands();
@@ -918,10 +885,7 @@ struct StoreDistribution final : public gpu::WarpDistributionPattern {
rewriter, newWarpOp.getLoc(), TypeRange{}, newStoreScatterOpOperands,
storeScatterOp->getAttrs());
xegpu::removeLayoutAttrs(newOp);
- LLVM_DEBUG(DBGS() << "StoreDistribution: created new op: " << newOp
- << "\n");
rewriter.eraseOp(storeScatterOp);
- LLVM_DEBUG(DBGS() << "StoreDistribution: done\n");
return success();
}
};
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index 55cf47e38dfd0..bcac517937754 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -23,14 +23,10 @@
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/Casting.h"
-#include "llvm/Support/Debug.h"
#include "llvm/Support/FormatVariadic.h"
#include <cstdint>
#include <numeric>
-#define DEBUG_TYPE "xegpu-utils"
-#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
-
using namespace mlir;
/// convert ArrayRef<ValueRange> into SmallVector<Value>
@@ -149,31 +145,19 @@ std::string xegpu::getTemporaryLayoutName(const OpResult result) {
}
xegpu::DistributeLayoutAttr xegpu::getDistributeLayoutAttr(const Value value) {
- LLVM_DEBUG(DBGS() << "getDistributeLayoutAttr(Value): type="
- << value.getType() << "\n");
- if (!value) {
- LLVM_DEBUG(DBGS() << " -> null value, returning nullptr\n");
+ if (!value)
return nullptr;
- }
if (auto tdescTy =
- dyn_cast_if_present<xegpu::TensorDescType>(value.getType())) {
- auto layout = tdescTy.getLayoutAttr();
- LLVM_DEBUG(DBGS() << " -> TensorDescType, layout="
- << (layout ? layout : nullptr) << "\n");
- return layout;
- }
+ dyn_cast_if_present<xegpu::TensorDescType>(value.getType()))
+ return tdescTy.getLayoutAttr();
if (auto result = dyn_cast<OpResult>(value)) {
Operation *defOp = result.getDefiningOp();
assert(defOp && "result must have a defining op");
- LLVM_DEBUG(DBGS() << " OpResult #" << result.getResultNumber() << " from "
- << defOp->getName() << "\n");
if (auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(defOp)) {
auto layout = anchorOp.getAnchorLayout();
- LLVM_DEBUG(DBGS() << " -> AnchorLayoutInterface, layout="
- << (layout ? layout : nullptr) << "\n");
return layout;
}
@@ -181,100 +165,60 @@ xegpu::DistributeLayoutAttr xegpu::getDistributeLayoutAttr(const Value value) {
if (defOp->hasAttr(layoutName)) {
auto layout =
defOp->getAttrOfType<xegpu::DistributeLayoutAttr>(layoutName);
- LLVM_DEBUG(DBGS() << " -> temporary attr '" << layoutName
- << "', layout=" << layout << "\n");
return layout;
}
- LLVM_DEBUG(DBGS() << " -> OpResult: no layout found (checked '"
- << layoutName << "')\n");
}
if (auto arg = dyn_cast<BlockArgument>(value)) {
auto *parentOp = arg.getOwner()->getParentOp();
- LLVM_DEBUG(DBGS() << " BlockArgument #" << arg.getArgNumber() << " of "
- << (parentOp ? parentOp->getName().getStringRef()
- : StringRef("(null)"))
- << "\n");
if (auto loop = dyn_cast_if_present<LoopLikeOpInterface>(parentOp)) {
OpOperand *tiedInit = loop.getTiedLoopInit(arg);
if (tiedInit) {
- LLVM_DEBUG(DBGS() << " -> LoopLikeOp, recursing into tiedInit "
- << "operand #" << tiedInit->getOperandNumber()
- << "\n");
return getDistributeLayoutAttr(tiedInit->get());
}
- LLVM_DEBUG(DBGS() << " -> LoopLikeOp, no tiedInit\n");
}
}
- LLVM_DEBUG(DBGS() << " -> returning nullptr\n");
return nullptr;
}
+/// Returns the layout associated with operand `opr` of its owning op:
+///  - xegpu::DpasOp operands 0/1/2: layout_a / layout_b / layout_cd;
+///  - xegpu::ConvertLayoutOp: its input layout;
+///  - other anchor ops (AnchorLayoutInterface): the anchor layout for operand
+///    0, and also for operand 1 of StoreScatterOp / StoreNdOp / StoreMatrixOp;
+///  - otherwise: the temporary layout attribute named by
+///    getTemporaryLayoutName(opr), or null if the op has no such attribute.
xegpu::DistributeLayoutAttr
xegpu::getDistributeLayoutAttr(const OpOperand &opr) {
Operation *op = opr.getOwner();
+// const_cast because OpOperand::getOperandNumber() is not const-qualified.
unsigned idx = const_cast<OpOperand &>(opr).getOperandNumber();
- LLVM_DEBUG(DBGS() << "getDistributeLayoutAttr(OpOperand): operand #" << idx
- << " of " << op->getName()
- << ", type=" << opr.get().getType() << "\n");
if (auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(op)) {
if (auto dpasOp = dyn_cast<xegpu::DpasOp>(op)) {
if (idx == 0) {
- auto layout = dpasOp.getLayoutAAttr();
- LLVM_DEBUG(DBGS() << " -> DpasOp layoutA="
- << (layout ? layout : nullptr) << "\n");
- return layout;
+ return dpasOp.getLayoutAAttr();
} else if (idx == 1) {
- auto layout = dpasOp.getLayoutBAttr();
- LLVM_DEBUG(DBGS() << " -> DpasOp layoutB="
- << (layout ? layout : nullptr) << "\n");
- return layout;
+ return dpasOp.getLayoutBAttr();
} else if (idx == 2) {
- auto layout = dpasOp.getLayoutCdAttr();
- LLVM_DEBUG(DBGS() << " -> DpasOp layoutCd="
- << (layout ? layout : nullptr) << "\n");
- return layout;
+ return dpasOp.getLayoutCdAttr();
}
}
if (auto convertOp = dyn_cast<xegpu::ConvertLayoutOp>(op)) {
- auto layout = convertOp.getInputLayoutAttr();
- LLVM_DEBUG(DBGS() << " -> ConvertLayoutOp inputLayout="
- << (layout ? layout : nullptr) << "\n");
- return layout;
+ return convertOp.getInputLayoutAttr();
}
auto layout = anchorOp.getAnchorLayout();
- if (idx == 0) {
- LLVM_DEBUG(DBGS() << " -> AnchorLayoutInterface idx=0, layout="
- << (layout ? layout : nullptr) << "\n");
+ if (idx == 0)
return layout;
- }
// For store operations (StoreScatterOp, StoreNdOp, StoreMatrixOp),
// the layout is valid for the first two operands: value and memref/tdesc.
// For other operations, the layout applies to the first operand only.
if (isa<xegpu::StoreScatterOp, xegpu::StoreNdOp, xegpu::StoreMatrixOp>(
op) &&
- (idx < 2)) {
- LLVM_DEBUG(DBGS() << " -> Store op idx=" << idx
- << ", layout=" << (layout ? layout : nullptr) << "\n");
+ (idx < 2))
return layout;
- }
- LLVM_DEBUG(DBGS() << " -> AnchorLayoutInterface idx=" << idx
- << " not covered, falling through\n");
}
+// Fallback: a temporary layout attached to the owning op for this operand.
std::string layoutName = xegpu::getTemporaryLayoutName(opr);
if (op->hasAttr(layoutName)) {
auto layout = op->getAttrOfType<xegpu::DistributeLayoutAttr>(layoutName);
- LLVM_DEBUG(DBGS() << " -> temporary attr '" << layoutName
- << "', layout=" << layout << "\n");
return layout;
}
- LLVM_DEBUG(DBGS() << " -> returning nullptr (checked '" << layoutName
- << "')\n");
return nullptr;
}
>From 2617a0258e5435f141292f85c5633a8574bdaebd Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Fri, 3 Apr 2026 20:41:17 +0000
Subject: [PATCH 07/19] cleanup
---
mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp | 3 +--
1 file changed, 1 insertion(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index bcac517937754..f0508a30621f2 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -173,9 +173,8 @@ xegpu::DistributeLayoutAttr xegpu::getDistributeLayoutAttr(const Value value) {
auto *parentOp = arg.getOwner()->getParentOp();
if (auto loop = dyn_cast_if_present<LoopLikeOpInterface>(parentOp)) {
OpOperand *tiedInit = loop.getTiedLoopInit(arg);
- if (tiedInit) {
+ if (tiedInit)
return getDistributeLayoutAttr(tiedInit->get());
- }
}
}
>From 182711fa0be85a820dfd3d3e769f27071762014d Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Fri, 3 Apr 2026 20:51:43 +0000
Subject: [PATCH 08/19] change needed for recoverTemporaryLayouts; only sg
 distribution tests fail
---
.../XeGPU/Transforms/XeGPULayoutImpl.h | 5 +-
.../XeGPU/Transforms/XeGPULayoutImpl.cpp | 457 +++++++++++++++---
.../XeGPU/Transforms/XeGPUPropagateLayout.cpp | 2 +-
.../XeGPU/sg-to-wi-experimental-unit.mlir | 3 +-
4 files changed, 410 insertions(+), 57 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h b/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h
index 9cf9a8705209b..5f46eab7b74c7 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h
@@ -183,10 +183,13 @@ setupDpasLayout(LayoutKind layoutKind, VectorType aTy, VectorType bTy,
VectorType cdTy, DistributeLayoutAttr consumerLayout, int numSg,
const uArch::uArch *uArch);
+/// Infers the layout required for `operand` from `resLayout`, the layout of
+/// the result of `operand`'s owning op. Returns a null layout when
+/// `resLayout` is null.
+DistributeLayoutAttr
+inferSourceLayoutFromResult(OpOperand &operand, DistributeLayoutAttr resLayout);
+
/// Gets the expected layout for a given consumer operand. This will check if
/// the owning operation of the consumer operand is one of the special layout
/// users and determine the expected layout accordingly.
-xegpu::DistributeLayoutAttr getConsumerLayoutAt(OpOperand &operand);
+DistributeLayoutAttr getConsumerLayoutAt(OpOperand &operand);
} // namespace xegpu
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
index 55cd6ec04970c..33c9086566d3c 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
@@ -18,16 +18,22 @@
#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/ValueRange.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/Support/Debug.h"
#include "llvm/Support/FormatVariadic.h"
#include <cstdint>
#include <numeric>
+#define DEBUG_TYPE "xegpu-layout-recovery"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+
using namespace mlir;
void xegpu::recoverTemporaryLayoutsDeprecated(Operation *op) {
@@ -80,32 +86,321 @@ xegpu::dropInstDataOnAttrs(ArrayRef<NamedAttribute> attrs) {
return out;
}
-// Attach layout attributes to all vector-type operands of operations within
-// the given operation's region. Reports an error if any vector operand lacks
-// a layout attribute.
-bool xegpu::recoverTemporaryLayouts(Operation *rootOp) {
- auto result = rootOp->walk([&](Operation *op) {
- for (OpOperand &operand : op->getOpOperands()) {
- // Layouts are needed for vector type only.
- if (!isa<VectorType>(operand.get().getType()))
- continue;
- // Skip block arguments since they don't have defining ops to attach
- // layout attributes to.
- if (isa<BlockArgument>(operand.get()))
- continue;
- auto layout = xegpu::getDistributeLayoutAttr(operand.get());
+// Prerequisites for layout recovery
+// The recovery relies on the following invariants:
+// 1. There is no layout conflict between different uses of the same
+//    definition.
+// 2. Each definition has a well-defined layout requirement at its use points:
+//    - Every definition must have at least one use that appears after it in
+//      topological order.
+//    - If a definition has no such use (e.g., a loop result or region output),
+//      an explicit convert_layout operation is inserted to create a use.
+//    - Only the result of convert_layout is permitted to have no subsequent
+//      use.
+
+// The recovery scans operations in reverse topological order as follows:
+// For regular operations: first, the result layout is recovered from its
+// uses; then it is propagated to the operation's operands.
+//
+// For region operations (e.g., loops):
+// - The regions of a region op are walked before the op itself, so the
+//   region's terminator (e.g., yield) is reached first. At that point the
+//   layouts of the region op's results are recovered from their use points,
+//   as for regular ops, and propagated to the corresponding terminator
+//   operands.
+// - After all operations inside the regions have been processed, the region
+//   op itself is visited, and the layouts of its region arguments are
+//   propagated to the corresponding initialization operands.
+// - This ensures that layouts are consistently propagated across region
+//   boundaries while preserving a single well-defined use for each
+//   definition at the region-op level.
+
+// recoverTemporaryLayouts walks the body of every function nested in its
+// input op backward, recursively descending into region ops.
+
+/// Visits every operation nested in `region` in reverse order: blocks are
+/// traversed back to front, and the operations of each block back to front.
+/// For an operation that has regions, those regions are walked (also in
+/// reverse) *before* the operation itself is visited, so a region's
+/// terminator is seen before its parent region op.
+///
+/// Note: for multi-block regions, reverse block-list order is not necessarily
+/// reverse topological order.
+static void walkRegionBackward(Region &region,
+                               llvm::function_ref<void(Operation *)> visit) {
+  // Blocks: back -> front.
+  for (Block &block : llvm::reverse(region)) {
+    // Ops: back -> front. Early-increment iteration so that `visit` may erase
+    // the op it is given without invalidating the traversal.
+    for (Operation &op : llvm::make_early_inc_range(llvm::reverse(block))) {
+      // Walk nested regions first (their terminators are visited first), then
+      // visit the region op itself.
+      for (Region &nested : llvm::reverse(op.getRegions()))
+        walkRegionBackward(nested, visit);
+
+      visit(&op);
+    }
+  }
+}
+
+/// Returns the layout required for `result` by its users: the first non-null
+/// layout that xegpu::getDistributeLayoutAttr reports for one of the uses of
+/// `result`, or null if no use provides a layout.
+///
+/// Recovery assumes all uses agree on a single layout. Every use is still
+/// examined so that disagreements show up in the debug log.
+static xegpu::DistributeLayoutAttr getLayoutFromUsePoints(Value result) {
+  xegpu::DistributeLayoutAttr layout = nullptr;
+  for (OpOperand &use : result.getUses()) {
+    xegpu::DistributeLayoutAttr useLayout = xegpu::getDistributeLayoutAttr(use);
+    if (!useLayout)
+      continue;
+    LLVM_DEBUG(DBGS() << "getLayoutFromUsePoints use: "
+                      << use.getOwner()->getName() << " operand #"
+                      << use.getOperandNumber() << ", tmpLayout=" << useLayout
+                      << "\n");
+    if (!layout) {
+      layout = useLayout;
+      continue;
+    }
+    // A differing layout violates the "no conflict between uses" prerequisite
+    // of the recovery. Keep the first layout but make the conflict visible.
+    if (useLayout != layout)
+      LLVM_DEBUG(DBGS() << "getLayoutFromUsePoints: conflicting use layout "
+                        << useLayout << ", keeping " << layout << "\n");
+  }
+  return layout;
+}
+
+/// Backward-propagates layouts through a regular (non-region, non-anchor) op.
+///
+/// The layout of the op's first result is recovered from its use points. If
+/// that result is a TensorDescType without a layout, the layout is encoded
+/// into the result type. The result layout is then used to infer a layout
+/// for every vector-typed operand, which is attached as a temporary layout.
+///
+/// Only result #0 is consulted; ops without results are skipped.
+static void propagateResultsToRegularOperands(Operation *op) {
+  LLVM_DEBUG(DBGS() << "propagateResultsToRegularOperands: " << op->getName()
+                    << " (" << op->getNumOperands() << " operands, "
+                    << op->getNumResults() << " results)\n");
+
+  if (op->getNumResults() == 0) {
+    LLVM_DEBUG(DBGS() << " skipping (no results)\n");
+    return;
+  }
+
+  Value result = op->getResult(0);
+  xegpu::DistributeLayoutAttr resLayout = getLayoutFromUsePoints(result);
+
+  // A TensorDescType carries its layout in the type rather than in an
+  // attribute, so it is recovered by rewriting the result type. Vector
+  // layouts are attached to the operands as temporary attributes below.
+  if (auto tensorDescTy = dyn_cast<xegpu::TensorDescType>(result.getType())) {
+    // TODO: remove the layout check. The tensorDescType's layout is treated
+    // as a temporary layout, which needs to be set by layout recovery. It is
+    // allowed for now to pass some legacy test cases.
+    if (!tensorDescTy.getLayoutAttr()) {
+      auto typeWithLayout = xegpu::TensorDescType::get(
+          tensorDescTy.getContext(), tensorDescTy.getShape(),
+          tensorDescTy.getElementType(), tensorDescTy.getEncoding(), resLayout);
+      result.setType(typeWithLayout);
+    }
+  }
+
+  for (OpOperand &opr : op->getOpOperands()) {
+    // Layouts are needed for vector type only; skip inference otherwise.
+    if (!isa<VectorType>(opr.get().getType())) {
+      LLVM_DEBUG(DBGS() << " operand #" << opr.getOperandNumber()
+                        << ": skipped (non-vector type: " << opr.get().getType()
+                        << ")\n");
+      continue;
+    }
+
+    xegpu::DistributeLayoutAttr operandLayout =
+        xegpu::inferSourceLayoutFromResult(opr, resLayout);
+    xegpu::setTemporaryLayout(opr, operandLayout);
+    // Keep all output inside LLVM_DEBUG so nothing is written to llvm::dbgs()
+    // unless debug output is enabled for this DEBUG_TYPE.
+    LLVM_DEBUG(DBGS() << "after propagateResultsToRegularOperands op: "
+                      << op->getName() << " operand #"
+                      << opr.getOperandNumber()
+                      << ": type=" << opr.get().getType()
+                      << ", temp Layout=" << xegpu::getTemporaryLayout(opr)
+                      << "\n");
+  }
+}
+
+/// Handles a region terminator (e.g. a yield): recovers the layouts of the
+/// parent region op's results from their use points, records them as
+/// temporary layouts on those results, and propagates them to the terminator
+/// operands that are forwarded to the first successor reported by
+/// getSuccessorRegions. Terminators whose parent is a func::FuncOp, or whose
+/// parent does not implement RegionBranchOpInterface, are skipped.
+static void propagateRegionResultsToYieldOperands(
+ mlir::RegionBranchTerminatorOpInterface yieldOp) {
+ LLVM_DEBUG(DBGS() << "propagateRegionResultsToYieldOperands: "
+ << yieldOp->getName() << " (" << yieldOp->getNumOperands()
+ << " operands), parent="
+ << yieldOp->getParentOp()->getName() << "\n");
+
+ // A function-level terminator has no parent op results to take layouts from.
+ if (isa<func::FuncOp>(yieldOp->getParentOp())) {
+ LLVM_DEBUG(DBGS() << " skipping (parent is FuncOp)\n");
+ return;
+ }
+
+ auto regionBranchOp =
+ dyn_cast<RegionBranchOpInterface>(yieldOp->getParentOp());
+ if (!regionBranchOp) {
+ LLVM_DEBUG(DBGS() << " skipping (parent is not RegionBranchOp)\n");
+ return;
+ }
+
+ // Gather layouts for each result of the parent region op from external
+ // use points.
+ unsigned numResults = regionBranchOp->getNumResults();
+ LLVM_DEBUG(DBGS() << " parent op has " << numResults << " results\n");
+ if (numResults == 0)
+ return;
+
+ SmallVector<xegpu::DistributeLayoutAttr> resultLayouts(numResults, nullptr);
+ for (unsigned i = 0; i < numResults; ++i) {
+ OpResult result = regionBranchOp->getResult(i);
+ resultLayouts[i] = getLayoutFromUsePoints(result);
+ if (resultLayouts[i]) {
+ LLVM_DEBUG(DBGS() << " result #" << i << ": type=" << result.getType()
+ << ", layout=" << resultLayouts[i] << "\n");
+ xegpu::setTemporaryLayout(result, resultLayouts[i]);
+ } else {
+ LLVM_DEBUG(DBGS() << " result #" << i
+ << ": skipped (no layout from use points)\n");
+ }
+ }
+
+ // Use getSuccessorOperands to find which operands of the terminator
+ // flow to a successor. This handles index offsets automatically (e.g.,
+ // scf.condition's predicate at operand #0 is excluded).
+ // Pick the first successor to determine the operand range.
+ // NOTE(review): the operands forwarded to successors.front() are assumed to
+ // correspond 1:1, by position, to the parent's results. A terminator that
+ // only branches to another region (e.g. the scf.yield of an scf.while
+ // "after" region) forwards values to region arguments instead -- confirm
+ // whether the parent successor (RegionSuccessor::isParent()) should be
+ // selected here.
+ SmallVector<RegionSuccessor> successors;
+ SmallVector<Attribute> operandAttrs(yieldOp->getNumOperands(), nullptr);
+ yieldOp.getSuccessorRegions(operandAttrs, successors);
+ assert(!successors.empty() && "terminator must have at least one successor");
+
+ OperandRange succOps = yieldOp.getSuccessorOperands(successors.front());
+ unsigned beginIdx = succOps.getBeginOperandIndex();
+ // Clamp so that only operand positions with a matching parent result are
+ // updated.
+ unsigned count = std::min(static_cast<unsigned>(succOps.size()), numResults);
+
+ LLVM_DEBUG(DBGS() << " " << count << " successor operands starting at index "
+ << beginIdx << "\n");
+
+ for (unsigned i = 0; i < count; ++i) {
+ if (!resultLayouts[i])
+ continue;
+ LLVM_DEBUG(DBGS() << " -> setting layout on operand #" << (beginIdx + i)
+ << "\n");
+ xegpu::setTemporaryLayout(yieldOp->getOpOperand(beginIdx + i),
+ resultLayouts[i]);
+ }
+
+ LLVM_DEBUG({
+ DBGS() << " after propagateRegionResultsToYieldOperands:\n";
+ yieldOp->print(llvm::dbgs(), OpPrintingFlags().printGenericOpForm());
+ llvm::dbgs() << "\n";
+ });
+}
+
+/// Propagates layouts from the block arguments of `regionOp`'s regions to the
+/// operands of `regionOp` that initialize them.
+///
+/// This runs when the region op is visited, i.e. after every operation nested
+/// in it has been processed. For each region argument whose layout can be
+/// recovered from its use points, the predecessor values flowing into it are
+/// queried via getPredecessorValues, and every operand of `regionOp` that is
+/// one of those values receives the layout as a temporary layout.
+static void propagateRegionArgsToInits(mlir::RegionBranchOpInterface regionOp) {
+  LLVM_DEBUG(DBGS() << "propagateRegionArgsToInits: " << regionOp->getName()
+                    << " (" << regionOp->getNumOperands() << " operands, "
+                    << regionOp->getNumRegions() << " regions)\n");
+  LLVM_DEBUG({
+    DBGS() << " before propagateRegionArgsToInits, Region IR:\n";
+    regionOp.print(llvm::dbgs(), OpPrintingFlags().printGenericOpForm());
+    llvm::dbgs() << "\n";
+  });
+
+  for (Region &region : regionOp->getRegions()) {
+    RegionSuccessor regionSuccessor(&region);
+    for (auto [argIdx, regionArg] : llvm::enumerate(region.getArguments())) {
+      auto layout = getLayoutFromUsePoints(regionArg);
if (!layout) {
- op->emitWarning("Could not find layout attribute for operand ")
- << operand.getOperandNumber() << " of operation " << op->getName();
+        LLVM_DEBUG(DBGS() << " region #" << region.getRegionNumber()
+                          << " arg #" << argIdx << ": skipped (no layout)\n");
continue;
}
- xegpu::setTemporaryLayout(operand, layout);
+      LLVM_DEBUG(DBGS() << " region #" << region.getRegionNumber() << " arg #"
+                        << argIdx << ": type=" << regionArg.getType()
+                        << ", layout=" << layout << "\n");
+
+      // The predecessor index below is `argIdx - 1`, i.e. it assumes exactly
+      // one leading region argument that is not forwarded from the op's
+      // operands (e.g. the scf.for induction variable). That argument has no
+      // init, and `argIdx - 1` would wrap around for it, so skip it.
+      // NOTE(review): for region ops without such a leading argument (e.g.
+      // scf.while) this index is off by one -- confirm the intended mapping.
+      if (argIdx == 0)
+        continue;
+
+      // Find all predecessor values that flow into this block argument.
+      SmallVector<Value> predValues;
+      regionOp.getPredecessorValues(regionSuccessor, argIdx - 1, predValues);
+      for (Value predVal : predValues) {
+        // Match the predecessor value to an operand of the regionOp.
+        for (OpOperand &operand : regionOp->getOpOperands()) {
+          if (operand.get() != predVal)
+            continue;
+          LLVM_DEBUG(DBGS() << " -> setting layout on init operand #"
+                            << operand.getOperandNumber() << "\n");
+          xegpu::setTemporaryLayout(operand, layout);
+        }
+      }
}
- return WalkResult::advance();
+  }
+
+  LLVM_DEBUG({
+    DBGS() << " after propagateRegionArgsToInits, Region IR:\n";
+    regionOp.print(llvm::dbgs(), OpPrintingFlags().printGenericOpForm());
+    llvm::dbgs() << "\n";
+  });
+}
+
+/// Recovers temporary layouts for every func.func and gpu.func nested in
+/// `rootOp`. Each function body is walked backward (nested regions before
+/// their region op, see walkRegionBackward), and each visited op is handled
+/// as follows:
+///  - RegionBranchOpInterface ops: region-argument layouts are propagated to
+///    the op's init operands (propagateRegionArgsToInits);
+///  - RegionBranchTerminatorOpInterface ops: the parent's result layouts are
+///    propagated to the forwarded terminator operands
+///    (propagateRegionResultsToYieldOperands);
+///  - other ops that do not implement xegpu::AnchorLayoutInterface: the
+///    result layout is propagated to their operands
+///    (propagateResultsToRegularOperands).
+/// Anchor ops are skipped.
+///
+/// Always returns true.
+bool xegpu::recoverTemporaryLayouts(Operation *rootOp) {
+  LLVM_DEBUG(DBGS() << "=== recoverTemporaryLayouts START ===\n");
+
+  auto processFunc = [&](Region &body, StringRef funcName) {
+    LLVM_DEBUG(DBGS() << "Processing func: " << funcName << "\n");
+    walkRegionBackward(body, [&](Operation *op) {
+      LLVM_DEBUG(DBGS() << "Visiting op: " << op->getName() << "\n");
+      if (auto regionOp = dyn_cast<mlir::RegionBranchOpInterface>(op)) {
+        // Reached the region op after everything inside its regions.
+        LLVM_DEBUG(DBGS() << " -> dispatching as RegionBranchOp\n");
+        propagateRegionArgsToInits(regionOp);
+      } else if (auto yieldOp =
+                     dyn_cast<mlir::RegionBranchTerminatorOpInterface>(op)) {
+        // Terminator (e.g. yield) inside a region op.
+        LLVM_DEBUG(DBGS() << " -> dispatching as YieldOp\n");
+        propagateRegionResultsToYieldOperands(yieldOp);
+      } else if (!isa<xegpu::AnchorLayoutInterface>(op)) {
+        // Regular op: propagate its result layout to its operands.
+        LLVM_DEBUG(DBGS() << " -> dispatching as regular op\n");
+        propagateResultsToRegularOperands(op);
+      }
+    });
+  };
+
+  rootOp->walk([&](func::FuncOp func) {
+    processFunc(func.getBody(), func.getSymName());
+  });
+  rootOp->walk([&](gpu::GPUFuncOp func) {
+    processFunc(func.getBody(), func.getName());
+  });
+
+  LLVM_DEBUG(DBGS() << "=== recoverTemporaryLayouts END ===\n");
+  LLVM_DEBUG({
+    DBGS() << "After recoverTemporaryLayouts, IR:\n";
+    rootOp->print(llvm::dbgs(), OpPrintingFlags().printGenericOpForm());
+    llvm::dbgs() << "\n";
});
- return !result.wasInterrupted();
+  return true;
}
+
template <typename T, typename>
void xegpu::removeLayoutAttr(const T &operandOrResult) {
Operation *owner = operandOrResult.getOwner();
@@ -1108,99 +1403,153 @@ xegpu::setupDpasLayout(xegpu::LayoutKind layoutKind, VectorType aTy,
return std::nullopt;
}
-xegpu::DistributeLayoutAttr xegpu::getConsumerLayoutAt(OpOperand &operand) {
+xegpu::DistributeLayoutAttr
+xegpu::inferSourceLayoutFromResult(OpOperand &operand,
+ xegpu::DistributeLayoutAttr resLayout) {
+ if (!resLayout) {
+ LLVM_DEBUG(DBGS() << "no resLayout, returning null\n");
+ return xegpu::DistributeLayoutAttr();
+ }
Operation *op = operand.getOwner();
unsigned idx = operand.getOperandNumber();
- xegpu::DistributeLayoutAttr resLayout;
- if (op->getNumResults() == 1)
- resLayout = xegpu::getDistributeLayoutAttr(op->getResult(0));
// For vector::BroadcastOp, infer the source layout from the result layout.
if (auto broadcast = dyn_cast<vector::BroadcastOp>(op)) {
- if (!resLayout)
- return xegpu::DistributeLayoutAttr();
+ LLVM_DEBUG(DBGS() << " -> BroadcastOp\n");
auto srcTy = dyn_cast<VectorType>(broadcast.getSourceType());
- if (!srcTy)
+ if (!srcTy) {
+ LLVM_DEBUG(DBGS() << " source is not VectorType, returning null\n");
return xegpu::DistributeLayoutAttr();
- return xegpu::inferBroadcastSourceLayout(
+ }
+ auto inferred = xegpu::inferBroadcastSourceLayout(
resLayout, broadcast.getResultVectorType().getShape(),
srcTy.getShape());
+ LLVM_DEBUG(DBGS() << " inferred=" << inferred << "\n");
+ return inferred;
}
// For vector::MultiDimReductionOp, infer source layout from result layout
// using reduction dims. Acc operand is expected to have the same layout as
// the result.
if (auto reduction = dyn_cast<vector::MultiDimReductionOp>(op)) {
- if (!resLayout)
- return xegpu::DistributeLayoutAttr();
+ LLVM_DEBUG(DBGS() << " -> MultiDimReductionOp, operand idx=" << idx
+ << "\n");
if (idx == 0) {
SmallVector<int64_t> reductionDims(reduction.getReductionDims());
- return xegpu::inferMultiReductionSourceLayout(resLayout, reductionDims);
+ LLVM_DEBUG({
+ DBGS() << " reductionDims=[";
+ llvm::interleaveComma(reductionDims, llvm::dbgs());
+ llvm::dbgs() << "]\n";
+ });
+ auto inferred =
+ xegpu::inferMultiReductionSourceLayout(resLayout, reductionDims);
+ LLVM_DEBUG(DBGS() << " inferred source layout=" << inferred << "\n");
+ return inferred;
}
- if (idx == 1)
+ if (idx == 1) {
+ LLVM_DEBUG(DBGS() << " acc operand, using resLayout\n");
return resLayout;
+ }
}
if (auto reduction = dyn_cast<vector::ReductionOp>(op)) {
- if (!resLayout)
- return xegpu::DistributeLayoutAttr();
- return xegpu::inferReductionSourceLayout(resLayout);
+ LLVM_DEBUG(DBGS() << " -> ReductionOp\n");
+ auto inferred = xegpu::inferReductionSourceLayout(resLayout);
+ LLVM_DEBUG(DBGS() << " inferred=" << inferred << "\n");
+ return inferred;
}
// For vector::BitCastOp, infer source layout from result layout using
// element type bitwidths.
if (auto bitcast = dyn_cast<vector::BitCastOp>(op)) {
- if (!resLayout)
- return xegpu::DistributeLayoutAttr();
+ LLVM_DEBUG(DBGS() << " -> BitCastOp\n");
int resElemBitWidth =
bitcast.getResultVectorType().getElementType().getIntOrFloatBitWidth();
int srcElemBitWidth =
bitcast.getSourceVectorType().getElementType().getIntOrFloatBitWidth();
- return xegpu::inferBitCastSourceLayout(resLayout, resElemBitWidth,
- srcElemBitWidth);
+ LLVM_DEBUG(DBGS() << " resBitWidth=" << resElemBitWidth
+ << ", srcBitWidth=" << srcElemBitWidth << "\n");
+ auto inferred = xegpu::inferBitCastSourceLayout(resLayout, resElemBitWidth,
+ srcElemBitWidth);
+ LLVM_DEBUG(DBGS() << " inferred=" << inferred << "\n");
+ return inferred;
}
// For vector::ShapeCastOp, infer source layout from result layout using
// shapes.
if (auto shapeCast = dyn_cast<vector::ShapeCastOp>(op)) {
- if (!resLayout)
- return xegpu::DistributeLayoutAttr();
- return xegpu::inferShapeCastSourceLayout(
+ LLVM_DEBUG({
+ DBGS() << " -> ShapeCastOp: resShape=[";
+ llvm::interleaveComma(shapeCast.getResultVectorType().getShape(),
+ llvm::dbgs());
+ llvm::dbgs() << "], srcShape=[";
+ llvm::interleaveComma(shapeCast.getSourceVectorType().getShape(),
+ llvm::dbgs());
+ llvm::dbgs() << "]\n";
+ });
+ auto inferred = xegpu::inferShapeCastSourceLayout(
resLayout, shapeCast.getResultVectorType().getShape(),
shapeCast.getSourceVectorType().getShape());
+ LLVM_DEBUG(DBGS() << " inferred=" << inferred << "\n");
+ return inferred;
}
// For vector::InsertStridedSliceOp, infer source layout from result layout.
// Dest vector must have the same layout as the result.
if (auto insertSlice = dyn_cast<vector::InsertStridedSliceOp>(op)) {
- if (!resLayout)
- return xegpu::DistributeLayoutAttr();
- if (idx == 0)
- return xegpu::inferInsertStridedSliceSourceLayout(
+ LLVM_DEBUG(DBGS() << " -> InsertStridedSliceOp, operand idx=" << idx
+ << "\n");
+ if (idx == 0) {
+ auto inferred = xegpu::inferInsertStridedSliceSourceLayout(
resLayout, insertSlice.getDestVectorType().getShape(),
insertSlice.getSourceVectorType().getShape());
- if (idx == 1)
+ LLVM_DEBUG(DBGS() << " inferred source layout=" << inferred << "\n");
+ return inferred;
+ }
+ if (idx == 1) {
+ LLVM_DEBUG(DBGS() << " dest operand, using resLayout\n");
return resLayout;
+ }
}
// For vector::TransposeOp, infer source layout from result layout using
// permutation.
if (auto transpose = dyn_cast<vector::TransposeOp>(op)) {
- if (!resLayout)
- return xegpu::DistributeLayoutAttr();
- return xegpu::inferTransposeSourceLayout(resLayout,
- transpose.getPermutation());
+ LLVM_DEBUG({
+ DBGS() << " -> TransposeOp, perm=[";
+ llvm::interleaveComma(transpose.getPermutation(), llvm::dbgs());
+ llvm::dbgs() << "]\n";
+ });
+ auto inferred = xegpu::inferTransposeSourceLayout(
+ resLayout, transpose.getPermutation());
+ LLVM_DEBUG(DBGS() << " inferred=" << inferred << "\n");
+ return inferred;
}
// For elementwise operations, all operands must have the same layout as the
// result.
if (OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() == 1) {
- if (!resLayout)
- return xegpu::DistributeLayoutAttr();
+ LLVM_DEBUG(DBGS() << " -> elementwise op, using resLayout="
+ << (resLayout ? resLayout : nullptr) << "\n");
+
return resLayout;
}
- // TODO: Handle more cases as needed here.
+ return xegpu::DistributeLayoutAttr();
+}
+
+xegpu::DistributeLayoutAttr xegpu::getConsumerLayoutAt(OpOperand &operand) {
+ Operation *op = operand.getOwner();
+ xegpu::DistributeLayoutAttr resLayout;
+ if (op->getNumResults() == 1)
+ resLayout = xegpu::getDistributeLayoutAttr(op->getResult(0));
+ auto inferredOperandLayout = inferSourceLayoutFromResult(operand, resLayout);
+ if (inferredOperandLayout)
+ return inferredOperandLayout;
// By default, assume no layout conflict and return the current layout of
// the operand.
- return xegpu::getDistributeLayoutAttr(operand.get());
+ auto fallback = xegpu::getDistributeLayoutAttr(operand.get());
+ LLVM_DEBUG(DBGS() << " -> fallback (unhandled op " << op->getName()
+ << "), returning operand layout="
+ << (fallback ? fallback : nullptr) << "\n");
+ return fallback;
}
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 4c30dacae8850..f0ff771f4cbc4 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -1338,7 +1338,7 @@ LogicalResult ResolveLayoutConflicts::run() {
// as anchor op for the reduction op's layout.
if (isa<vector::MultiDimReductionOp>(op) || isa<vector::ReductionOp>(op)) {
for (OpResult result : op->getResults()) {
- if (result.getType().isIntOrFloat()) {
+ if (result.getType().isIntOrFloat() || result.use_empty()) {
auto res = assignResultLayout(result);
if (failed(res)) {
DBGS() << "Failed to resolve vector consumer for multi-reduction "
diff --git a/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir b/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
index 842c2375dd31d..ae0347d507e9c 100644
--- a/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
+++ b/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
@@ -1092,7 +1092,8 @@ gpu.module @xevm_module {
gpu.func @vector_broadcast_2d_to_2d_noop(%laneid: index) {
%0 = "some_op"() {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : () -> vector<16x1xf16>
%1 = vector.broadcast %0 {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16x1xf16> to vector<16x16xf16>
- "some_use"(%1) : (vector<16x16xf16>) -> ()
+ %2 = xegpu.convert_layout %1 <{input_layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, target_layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}> : vector<16x16xf16>
+ "some_use"(%2) : (vector<16x16xf16>) -> ()
gpu.return
}
}
>From 6d4955dd6cc75162045db403e7822c07b5c4976e Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Wed, 8 Apr 2026 05:29:22 +0000
Subject: [PATCH 09/19] almost there - only sg-to-wi-experimental-unit.mlir
fails
---
.../XeGPU/Transforms/XeGPULayoutImpl.h | 6 ++
mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 2 +-
.../XeGPU/Transforms/XeGPULayoutImpl.cpp | 18 ++++-
.../XeGPU/Transforms/XeGPUPropagateLayout.cpp | 70 +++++++++++--------
.../XeGPUSgToWiDistributeExperimental.cpp | 10 +++
.../Transforms/XeGPUSubgroupDistribute.cpp | 50 +++++++++++--
.../Transforms/XeGPUWgToSgDistribute.cpp | 3 -
mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp | 25 +++++--
.../XeGPU/sg-to-wi-experimental-unit.mlir | 5 ++
.../Dialect/XeGPU/sg-to-wi-experimental.mlir | 30 +++-----
.../XeGPU/subgroup-distribute-unit.mlir | 69 +++++-------------
.../Dialect/XeGPU/subgroup-distribute.mlir | 19 ++---
.../XeGPU/xegpu-wg-to-sg-unify-ops.mlir | 2 +-
13 files changed, 178 insertions(+), 131 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h b/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h
index 5f46eab7b74c7..c36201c2f0d9e 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h
@@ -112,6 +112,12 @@ inferInsertStridedSliceSourceLayout(DistributeLayoutAttr resLayout,
ArrayRef<int64_t> resShape,
ArrayRef<int64_t> srcShape);
+/// Infers the mask/offset operand layout of a scatter load/store from its
+/// payload layout: drops the last dim if chunkSize > 1, else reuses it as-is.
+DistributeLayoutAttr
+inferMaskOffsetLayoutForScatterIO(DistributeLayoutAttr payloadLayout,
+                                  int chunkSize);
+
/// Sets up layout for Multi-Reduction operations by creating a SliceAttr for
/// the result.
///
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 64c56b5adf5d7..eaa43c02946d8 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -585,7 +585,7 @@ DistributeLayoutAttr LayoutAttr::dropDims(SmallVector<int64_t> dimGroup) {
int64_t offset = llvm::count_if(dimGroup, [&](int64_t s) { return s < d; });
newOrder.push_back(d - offset);
}
- if (sgLayout.empty() && laneLayout.empty())
+ if ((sgLayout.empty() && laneLayout.empty()) || newOrder.size() == 1)
newOrder.clear();
auto toAttr = [&](ArrayRef<int64_t> v) -> DenseI32ArrayAttr {
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
index 29af3d3f1d95a..d5c5e05549745 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
@@ -167,9 +167,8 @@ static void propagateResultsToRegularOperands(Operation *op) {
return;
}
- Value result = op->getResult(0);
- xegpu::DistributeLayoutAttr resLayout =
- getLayoutFromUsePoints(op->getResult(0));
+ OpResult result = op->getResult(0);
+ xegpu::DistributeLayoutAttr resLayout = getLayoutFromUsePoints(result);
Type resultType = result.getType();
// recover layout for tensor Descriptor type, which is a special case since
@@ -188,6 +187,8 @@ static void propagateResultsToRegularOperands(Operation *op) {
}
}
+ xegpu::setTemporaryLayout(result, resLayout);
+
for (OpOperand &opr : op->getOpOperands()) {
// Layouts are needed for vector type only.
xegpu::DistributeLayoutAttr operandLayout =
@@ -676,6 +677,17 @@ xegpu::inferShapeCastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
return nullptr;
}
+/// Infers the mask/offset layout of a scatter load/store from its payload
+/// layout: drops the trailing (chunk) dim if chunkSize > 1. Null-safe.
+xegpu::DistributeLayoutAttr xegpu::inferMaskOffsetLayoutForScatterIO(
+    xegpu::DistributeLayoutAttr payloadLayout, int chunkSize) {
+  if (!payloadLayout || chunkSize <= 1) // Check null before getRank().
+    return payloadLayout;
+  auto rank = payloadLayout.getRank();
+  return payloadLayout.dropDims(
+      llvm::to_vector(llvm::seq<int64_t>(rank - 1, rank)));
+}
+
/// Sets up layout for reduction operations by creating a SliceAttr for the
/// result.
///
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index f0ff771f4cbc4..8abb1fc4bae52 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -1027,7 +1027,7 @@ void LayoutInfoPropagation::visitLoadGatherOp(
const uArch *uArch = getUArch(getChipStr(load).value_or(""));
if (!uArch)
return;
- auto subgroupSize = uArch->getSubgroupSize();
+ // auto subgroupSize = uArch->getSubgroupSize();
VectorType resVecTy = load.getValueType();
int chunkSize = load.getChunkSize().value_or(1);
@@ -1049,20 +1049,24 @@ void LayoutInfoPropagation::visitLoadGatherOp(
load.setLayoutAttr(requiredAnchorLayoutAttr);
}
- auto maskLayoutAttr = requiredAnchorLayoutAttr;
- // Special handling mask layout for chunked ops: Enforce the default xegpu 1D
- // layout for mask.
- if (chunkSize > 1) {
- if (layoutKind == xegpu::LayoutKind::InstData)
- maskLayoutAttr =
- xegpu::LayoutAttr::get(load->getContext(), {subgroupSize});
- else if (layoutKind == xegpu::LayoutKind::Lane)
- maskLayoutAttr =
- xegpu::LayoutAttr::get(load->getContext(), {subgroupSize}, {1});
- else
- assert(false &&
- "chunked StoreScatterOp should not be used at workgroup level");
- }
+ assert((chunkSize <= 1) || (layoutKind != xegpu::LayoutKind::Subgroup));
+ auto maskLayoutAttr = xegpu::inferMaskOffsetLayoutForScatterIO(
+ requiredAnchorLayoutAttr, chunkSize);
+
+ // // Special handling mask layout for chunked ops: Enforce the default xegpu
+ // 1D
+ // // layout for mask.
+ // if (chunkSize > 1) {
+ // if (layoutKind == xegpu::LayoutKind::InstData)
+ // maskLayoutAttr =
+ // xegpu::LayoutAttr::get(load->getContext(), {subgroupSize});
+ // else if (layoutKind == xegpu::LayoutKind::Lane)
+ // maskLayoutAttr =
+ // xegpu::LayoutAttr::get(load->getContext(), {subgroupSize}, {1});
+ // else
+ // assert(false &&
+ // "chunked StoreScatterOp should not be used at workgroup level");
+ // }
LayoutInfo maskLayoutInfo = LayoutInfo(maskLayoutAttr);
auto loadLayoutInfo = LayoutInfo(requiredAnchorLayoutAttr);
@@ -1105,7 +1109,7 @@ void LayoutInfoPropagation::visitStoreScatterOp(
const uArch *uArch = getUArch(getChipStr(storeScatter).value_or(""));
if (!uArch)
return;
- auto subgroupSize = uArch->getSubgroupSize();
+ // auto subgroupSize = uArch->getSubgroupSize();
VectorType srcVecTy = storeScatter.getValueType();
int chunkSize = storeScatter.getChunkSize().value_or(1);
@@ -1122,23 +1126,27 @@ void LayoutInfoPropagation::visitStoreScatterOp(
}
LayoutInfo srcLayoutInfo = LayoutInfo(requiredAnchorLayoutAttr);
- auto maskLayoutAttr = requiredAnchorLayoutAttr;
- // Special handling mask layout for chunked ops: Enforce the default xegpu 1D
- // layout for mask.
- if (chunkSize > 1) {
- if (layoutKind == xegpu::LayoutKind::InstData)
- maskLayoutAttr =
- xegpu::LayoutAttr::get(storeScatter->getContext(), {subgroupSize});
- else if (layoutKind == xegpu::LayoutKind::Lane)
- maskLayoutAttr = xegpu::LayoutAttr::get(storeScatter->getContext(),
- {subgroupSize}, {1});
- else
- assert(false &&
- "chunked StoreScatterOp should not be used at workgroup level");
- }
-
+ assert((chunkSize <= 1) || (layoutKind != xegpu::LayoutKind::Subgroup));
+ auto maskLayoutAttr = xegpu::inferMaskOffsetLayoutForScatterIO(
+ requiredAnchorLayoutAttr, chunkSize);
LayoutInfo maskLayoutInfo = LayoutInfo(maskLayoutAttr);
+ // auto maskLayoutAttr = requiredAnchorLayoutAttr;
+ // // Special handling mask layout for chunked ops: Enforce the default xegpu
+ // 1D
+ // // layout for mask.
+ // if (chunkSize > 1) {
+ // if (layoutKind == xegpu::LayoutKind::InstData)
+ // maskLayoutAttr =
+ // xegpu::LayoutAttr::get(storeScatter->getContext(), {subgroupSize});
+ // else if (layoutKind == xegpu::LayoutKind::Lane)
+ // maskLayoutAttr = xegpu::LayoutAttr::get(storeScatter->getContext(),
+ // {subgroupSize}, {1});
+ // else
+ // assert(false &&
+ // "chunked StoreScatterOp should not be used at workgroup level");
+ // }
+
// Propagate the payload operand layout
propagateIfChanged(operands[0], operands[0]->meet(srcLayoutInfo));
// Propagate the destination (if tdesc) operand layout
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
index e3227c7f5b149..c029ee1d8ae0d 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
@@ -1610,6 +1610,16 @@ void XeGPUSgToWiDistributeExperimentalPass::runOnOperation() {
}
});
}
+ // Remove layout attributes from SCF ops
+ getOperation()->walk([](Operation *op) {
+ SmallVector<StringAttr> attrsToRemove;
+ for (auto namedAttr : op->getDiscardableAttrs()) {
+ if (isa<xegpu::DistributeLayoutAttr>(namedAttr.getValue()))
+ attrsToRemove.push_back(namedAttr.getName());
+ }
+ for (auto attrName : attrsToRemove)
+ op->removeDiscardableAttr(attrName);
+ });
}
void xegpu::populateXeGPUSgToWiDistributeTypeConversions(
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index d8ce24ddd5cb0..012d7aefafb06 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -825,12 +825,10 @@ struct StoreDistribution final : public gpu::WarpDistributionPattern {
}
}
- auto layoutPayload =
- xegpu::getTemporaryLayout(storeScatterOp->getOpOperand(0));
+ auto layoutPayload = storeScatterOp.getLayoutAttr();
auto layoutOffsets =
- xegpu::getTemporaryLayout(storeScatterOp->getOpOperand(2));
- auto layoutMask =
- xegpu::getTemporaryLayout(storeScatterOp->getOpOperand(3));
+ xegpu::inferMaskOffsetLayoutForScatterIO(layoutPayload, chunkSize);
+ auto layoutMask = layoutOffsets;
FailureOr<VectorType> distStoreVecByWarpOpOrFailure =
getDistVecTypeBasedOnLaneLayout(layoutPayload, storeVecTy);
@@ -1132,9 +1130,36 @@ struct LoadDistribution final : public gpu::WarpDistributionPattern {
}
}
+ auto layoutPayload = loadGatherOp.getLayoutAttr();
auto layoutOffsets =
- xegpu::getTemporaryLayout(loadGatherOp->getOpOperand(1));
- auto layoutMask = xegpu::getTemporaryLayout(loadGatherOp->getOpOperand(2));
+ xegpu::inferMaskOffsetLayoutForScatterIO(layoutPayload, chunkSize);
+ auto layoutMask = layoutOffsets;
+
+ // print the layouts for debug
+ LLVM_DEBUG({
+ llvm::dbgs() << "In LoadDistribution pattern:\n";
+ llvm::dbgs() << "Payload layout: ";
+ if (layoutPayload)
+ llvm::dbgs() << layoutPayload;
+ else
+ llvm::dbgs() << "none";
+ llvm::dbgs() << "\nOffsets layout: ";
+ if (layoutOffsets)
+ llvm::dbgs() << layoutOffsets;
+ else
+ llvm::dbgs() << "none";
+ llvm::dbgs() << "\nMask layout: ";
+ if (layoutMask)
+ llvm::dbgs() << layoutMask;
+ else
+ llvm::dbgs() << "none";
+ llvm::dbgs() << "\n";
+ });
+
+ // auto layoutOffsets =
+ // xegpu::getTemporaryLayout(loadGatherOp->getOpOperand(1));
+ // auto layoutMask =
+ // xegpu::getTemporaryLayout(loadGatherOp->getOpOperand(2));
FailureOr<VectorType> distOffsetsByWarpOpOrFailure =
getDistVecTypeBasedOnLaneLayout(layoutOffsets, offsetsTy);
@@ -2281,4 +2306,15 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
op->erase();
return WalkResult::advance();
});
+
+ // Remove layout attributes from SCF ops
+ getOperation()->walk([](Operation *op) {
+ SmallVector<StringAttr> attrsToRemove;
+ for (auto namedAttr : op->getDiscardableAttrs()) {
+ if (isa<xegpu::DistributeLayoutAttr>(namedAttr.getValue()))
+ attrsToRemove.push_back(namedAttr.getName());
+ }
+ for (auto attrName : attrsToRemove)
+ op->removeDiscardableAttr(attrName);
+ });
}
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index a095c19d66c15..02f88828f667f 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -1786,9 +1786,6 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
// Remove layout attributes from SCF ops
getOperation()->walk([](Operation *op) {
- if (!isa<RegionBranchOpInterface, RegionBranchTerminatorOpInterface>(op))
- return;
-
SmallVector<StringAttr> attrsToRemove;
for (auto namedAttr : op->getDiscardableAttrs()) {
if (isa<xegpu::DistributeLayoutAttr>(namedAttr.getValue()))
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index f0508a30621f2..0dab06d206ceb 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -16,6 +16,7 @@
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h"
#include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Operation.h"
@@ -203,13 +204,27 @@ xegpu::getDistributeLayoutAttr(const OpOperand &opr) {
if (idx == 0)
return layout;
- // For store operations (StoreScatterOp, StoreNdOp, StoreMatrixOp),
+ // For StoreNdOp and StoreMatrixOp,
// the layout is valid for the first two operands: value and memref/tdesc.
- // For other operations, the layout applies to the first operand only.
- if (isa<xegpu::StoreScatterOp, xegpu::StoreNdOp, xegpu::StoreMatrixOp>(
- op) &&
- (idx < 2))
+ if (isa<xegpu::StoreNdOp, xegpu::StoreMatrixOp>(op) && (idx < 2))
+ return layout;
+
+ if (isa<xegpu::StoreScatterOp>(op)) {
+ xegpu::StoreScatterOp store(op);
+ int chunkSize = store.getChunkSize().value_or(1);
+ if (layout && idx >= 2 && chunkSize > 1)
+ return layout.dropDims(llvm::to_vector(
+ llvm::seq<int64_t>(layout.getRank() - 1, layout.getRank())));
+ return layout;
+ }
+ if (isa<xegpu::LoadGatherOp>(op)) {
+ xegpu::LoadGatherOp load(op);
+ int chunkSize = load.getChunkSize().value_or(1);
+ if (layout && idx >= 1 && chunkSize > 1)
+ return layout.dropDims(llvm::to_vector(
+ llvm::seq<int64_t>(layout.getRank() - 1, layout.getRank())));
return layout;
+ }
}
std::string layoutName = xegpu::getTemporaryLayoutName(opr);
diff --git a/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir b/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
index c7b5542f91cbd..d53d0463db4cb 100644
--- a/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
+++ b/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
@@ -487,6 +487,11 @@ gpu.func @vector_transpose() {
layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
}
: vector<16x2xf32> to vector<2x16xf32>
+ %transpose2 = xegpu.convert_layout %transpose
+ <{
+ input_layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+ target_layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+ }> : vector<2x16xf32>
gpu.return
}
diff --git a/mlir/test/Dialect/XeGPU/sg-to-wi-experimental.mlir b/mlir/test/Dialect/XeGPU/sg-to-wi-experimental.mlir
index 9febd79c7adc3..723c70a09931e 100644
--- a/mlir/test/Dialect/XeGPU/sg-to-wi-experimental.mlir
+++ b/mlir/test/Dialect/XeGPU/sg-to-wi-experimental.mlir
@@ -235,28 +235,24 @@ gpu.func @gemm_with_postop(%arg0: memref<1024x1024xbf16>, %arg1: memref<1024x102
gpu.module @xevm_module{
gpu.func @load_dpas_postop_store(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<8x16xf32>) {
%c0 = arith.constant 0 : index
- %cst = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} dense<0.0> : vector<8x16xf32>
+ %cst = arith.constant dense<0.0> : vector<8x16xf32>
%0 = xegpu.create_nd_tdesc %arg0 : memref<8x16xf16>
-> !xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
%1 = xegpu.load_nd %0[%c0, %c0]
- {layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
- layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} :
+ {layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} :
!xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<8x16xf16>
%2 = xegpu.create_nd_tdesc %arg1: memref<16x16xf16>
-> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>>
%3 = xegpu.load_nd %2[%c0, %c0]
- {layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>,
- layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>}
+ {layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>}
: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>>
-> vector<16x16xf16>
%4 = xegpu.dpas %1, %3, %cst
{layout_a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
layout_b = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>,
- layout_cd = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
- layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+ layout_cd = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
: vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
%5 = math.exp %4
- {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
: vector<8x16xf32>
%6 = xegpu.create_nd_tdesc %arg2 : memref<8x16xf32> ->
!xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
@@ -279,7 +275,7 @@ gpu.module @xevm_module{
// CHECK-NEXT: scf.yield %[[LD_CAST]] : vector<1x8xf16>
// CHECK-NEXT: } else {
// CHECK-NEXT: scf.yield %[[CST]] : vector<1x8xf16>
-// CHECK-NEXT: } {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>}
+// CHECK-NEXT: }
// CHECK-NEXT: %[[IF_CAST:.*]] = vector.shape_cast %[[IF]] : vector<1x8xf16> to vector<8xf16>
// CHECK-NEXT: xegpu.store %[[IF_CAST]], %{{.*}}[%[[OFFSET]]], %[[MASK]] <{chunk_size = 8 : i64}>
// CHECK-SAME: vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
@@ -289,16 +285,13 @@ gpu.module @xevm_module{
%offset = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<12> : vector<16xindex>
%loaded = scf.if %pred -> (vector<16x8xf16>) {
%3 = xegpu.load %src[%offset], %1 <{chunk_size=8}> {
- layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>,
- layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>
+ layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>
} : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
scf.yield %3 : vector<16x8xf16>
} else {
- %3 = arith.constant {
- layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>
- } dense<12.> : vector<16x8xf16>
+ %3 = arith.constant dense<12.> : vector<16x8xf16>
scf.yield %3 : vector<16x8xf16>
- } { layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]> }
+ }
xegpu.store %loaded, %src[%offset], %1 <{chunk_size=8}> {layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>} : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
gpu.return
}
@@ -318,12 +311,11 @@ gpu.module @xevm_module{
gpu.module @xevm_module{
gpu.func @scatter_ops_scf_non_yield(%src: memref<256xf16>) {
%pred = llvm.mlir.poison : i1
- %1 = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<1>: vector<16xi1>
- %offset = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<12> : vector<16xindex>
+ %1 = arith.constant dense<1>: vector<16xi1>
+ %offset = arith.constant dense<12> : vector<16xindex>
scf.if %pred {
%3 = xegpu.load %src[%offset], %1 <{chunk_size=8}> {
- layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>,
- layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>
+ layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>
} : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
xegpu.store %3, %src[%offset], %1 <{chunk_size=8}> {layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>} : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
}
diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir
index 18ebc09caa5aa..27c5bd497b948 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir
@@ -469,10 +469,9 @@ gpu.func @vector_multi_reduction_3d_trivial_reduction(%laneid: index) {
gpu.return
}
-
// CHECK-LABEL: gpu.func @scatter_ops_chunksize({{.*}}) {
-// CHECK: %[[OFFSETS:.*]] = arith.constant {{.*}} dense<12> : vector<16xindex>
-// CHECK: %[[MASKS:.*]] = arith.constant {{.*}} dense<true> : vector<16xi1>
+// CHECK: %[[OFFSETS:.*]] = arith.constant dense<12> : vector<16xindex>
+// CHECK: %[[MASKS:.*]] = arith.constant dense<true> : vector<16xi1>
// CHECK: %[[W:.*]]:4 = gpu.warp_execute_on_lane_0(%{{.*}})[16]
// CHECK-SAME: -> (vector<1x8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>) {
// CHECK: gpu.yield %{{.*}}, %{{.*}}, %[[OFFSETS]], %[[MASKS]] :
@@ -484,34 +483,19 @@ gpu.func @vector_multi_reduction_3d_trivial_reduction(%laneid: index) {
// CHECK-SAME: : vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
gpu.func @scatter_ops_chunksize(%laneid: index, %src: memref<256xf16>) {
gpu.warp_execute_on_lane_0(%laneid)[16] {
- %1 = arith.constant
- {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
- dense<1>: vector<16xi1>
- %offset = arith.constant
- {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
- dense<12> : vector<16xindex>
- %3 = xegpu.load %src[%offset], %1 <{chunk_size=8}>
- {
- layout_operand_1 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
- layout_operand_2 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
- layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>
- }
+ %1 = arith.constant dense<1>: vector<16xi1>
+ %offset = arith.constant dense<12> : vector<16xindex>
+ %3 = xegpu.load %src[%offset], %1 <{chunk_size=8, layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>}>
: memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
- xegpu.store %3, %src[%offset], %1 <{chunk_size=8}>
- {
- layout_operand_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>,
- layout_operand_2 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
- layout_operand_3 = #xegpu.layout<lane_layout = [16], lane_data = [1]>
- }
+ xegpu.store %3, %src[%offset], %1 <{chunk_size=8, layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>}>
: vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
}
gpu.return
}
-
// CHECK-LABEL: gpu.func @scatter_ops({{.*}}) {
-// CHECK: %[[OFFSETS:.*]] = arith.constant {{.*}} dense<12> : vector<16xindex>
-// CHECK: %[[MASKS:.*]] = arith.constant {{.*}} dense<true> : vector<16xi1>
+// CHECK: %[[OFFSETS:.*]] = arith.constant dense<12> : vector<16xindex>
+// CHECK: %[[MASKS:.*]] = arith.constant dense<true> : vector<16xi1>
// CHECK: %[[W:.*]]:4 = gpu.warp_execute_on_lane_0(%{{.*}})[16]
// CHECK-SAME: -> (vector<1xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>) {
// CHECK: gpu.yield %{{.*}}, %{{.*}}, %[[OFFSETS]], %[[MASKS]]
@@ -523,23 +507,15 @@ gpu.func @scatter_ops_chunksize(%laneid: index, %src: memref<256xf16>) {
// CHECK-SAME: : vector<1xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
gpu.func @scatter_ops(%src: memref<256xf16>, %laneid: index) {
gpu.warp_execute_on_lane_0(%laneid)[16] {
- %1 = arith.constant
- {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
- dense<1> : vector<16xi1>
- %offset = arith.constant
- {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
- dense<12> : vector<16xindex>
+ %1 = arith.constant dense<1> : vector<16xi1>
+ %offset = arith.constant dense<12> : vector<16xindex>
%3 = xegpu.load %src[%offset], %1
{
- layout_operand_1 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
- layout_operand_2 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
- layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>
+ layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>
} : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16>
xegpu.store %3, %src[%offset], %1
{
- layout_operand_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
- layout_operand_2 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
- layout_operand_3 = #xegpu.layout<lane_layout = [16], lane_data = [1]>
+ layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>
}
: vector<16xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
}
@@ -547,8 +523,8 @@ gpu.func @scatter_ops(%src: memref<256xf16>, %laneid: index) {
}
// CHECK-LABEL: gpu.func @scatter_ops_with_leading_dims({{.*}}) {
-// CHECK: %[[OFFSETS:.*]] = arith.constant {{.*}} dense<12> : vector<1x1x16xindex>
-// CHECK: %[[MASKS:.*]] = arith.constant {{.*}} dense<true> : vector<1x1x16xi1>
+// CHECK: %[[OFFSETS:.*]] = arith.constant dense<12> : vector<1x1x16xindex>
+// CHECK: %[[MASKS:.*]] = arith.constant dense<true> : vector<1x1x16xi1>
// CHECK: %[[W:.*]]:4 = gpu.warp_execute_on_lane_0(%{{.*}})[16]
// CHECK-SAME: -> (vector<1x1x1xf16>, memref<256xf16>, vector<1x1x1xindex>, vector<1x1x1xi1>) {
// CHECK: gpu.yield %{{.*}}, %{{.*}}, %[[OFFSETS]], %[[MASKS]]
@@ -563,23 +539,12 @@ gpu.func @scatter_ops(%src: memref<256xf16>, %laneid: index) {
gpu.func @scatter_ops_with_leading_dims(%src: memref<256xf16>, %laneid: index) {
gpu.warp_execute_on_lane_0(%laneid)[16] {
%1 = arith.constant
- {layout_result_0 = #xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>}
dense<1> : vector<1x1x16xi1>
%offset = arith.constant
- {layout_result_0 = #xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>}
dense<12> : vector<1x1x16xindex>
- %3 = xegpu.load %src[%offset], %1
- {
- layout_operand_1 = #xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>,
- layout_operand_2 = #xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>,
- layout_result_0 = #xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>
- } : memref<256xf16>, vector<1x1x16xindex>, vector<1x1x16xi1> -> vector<1x1x16xf16>
- xegpu.store %3, %src[%offset], %1
- {
- layout_operand_0 = #xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>,
- layout_operand_2 = #xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>,
- layout_operand_3 = #xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>
- }
+ %3 = xegpu.load %src[%offset], %1 {layout = #xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>}
+ : memref<256xf16>, vector<1x1x16xindex>, vector<1x1x16xi1> -> vector<1x1x16xf16>
+ xegpu.store %3, %src[%offset], %1 { layout = #xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>}
: vector<1x1x16xf16>, memref<256xf16>, vector<1x1x16xindex>, vector<1x1x16xi1>
}
gpu.return
diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
index c3cdc79d9f70e..285669cae7174 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
@@ -10,7 +10,7 @@
// CHECK: %[[T1:.*]] = xegpu.load_nd %[[T0]][%{{.*}}] <{packed}> : !xegpu.tensor_desc<16x16xf16> -> vector<16xf16>
// CHECK-DAG: %[[T4:.*]] = xegpu.dpas %[[T3]], %[[T1]] : vector<8xf16>, vector<16xf16> -> vector<8xf32>
// CHECK: %[[T5:.*]] = vector.shape_cast %[[T4]] : vector<8xf32> to vector<8x1xf32>
-// CHECK: %[[T6:.*]] = math.exp %[[T5]] {{{.*}}} : vector<8x1xf32>
+// CHECK: %[[T6:.*]] = math.exp %[[T5]] : vector<8x1xf32>
// CHECK-DAG: %[[T8:.*]] = vector.shape_cast %[[T6]] : vector<8x1xf32> to vector<8xf32>
// CHECK-DAG: %[[T7:.*]] = xegpu.create_nd_tdesc %[[ARG2]] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
// CHECK: xegpu.store_nd %[[T8]], %[[T7]][{{.*}}] : vector<8xf32>, !xegpu.tensor_desc<8x16xf32>
@@ -423,16 +423,17 @@ gpu.module @xevm_test {
// CHECK: %[[VEC_RED:.*]] = vector.broadcast %[[VEC_RED_3]] : f32 to vector<1xf32>
// CHECK: xegpu.store %[[VEC_RED]], %arg1[%[[CST]]], %[[CST_0]] : vector<1xf32>, memref<256xf32>, vector<1xindex>, vector<1xi1>
gpu.func @vector_reduce_2d(%arg0: memref<4x16xf32>, %arg1: memref<256xf32>) {
- %cst = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0, 1]>} 1.000000e+00 : f32
+ %cst = arith.constant 1.000000e+00 : f32
%0 = xegpu.create_nd_tdesc %arg0 : memref<4x16xf32> -> !xegpu.tensor_desc<4x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
%1 = xegpu.load_nd %0[0, 0] <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}> : !xegpu.tensor_desc<4x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<4x16xf32>
- %2 = vector.broadcast %cst {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>} : f32 to vector<16xf32>
- %3 = vector.multi_reduction <add>, %1, %2 {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>} [0] : vector<4x16xf32> to vector<16xf32>
- %4 = vector.reduction <add>, %3 {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0, 1]>} : vector<16xf32> into f32
- %5 = vector.broadcast %4 {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>} : f32 to vector<16xf32>
- %cst_0 = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<0> : vector<16xindex>
- %cst_1 = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1>
- xegpu.store %5, %arg1[%cst_0], %cst_1 <{layout = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>}> : vector<16xf32>, memref<256xf32>, vector<16xindex>, vector<16xi1>
+ %2 = vector.broadcast %cst : f32 to vector<16xf32>
+ %3 = vector.multi_reduction <add>, %1, %2 [0] : vector<4x16xf32> to vector<16xf32>
+ %4 = vector.reduction <add>, %3 : vector<16xf32> into f32
+ %40 = xegpu.convert_layout %4 <{input_layout = #xegpu.slice<#xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>, dims = [0]>, target_layout = #xegpu.slice<#xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>, dims = [0]>}>: f32
+ %5 = vector.broadcast %40 : f32 to vector<16xf32>
+ %cst_0 = arith.constant dense<0> : vector<16xindex>
+ %cst_1 = arith.constant dense<true> : vector<16xi1>
+ xegpu.store %5, %arg1[%cst_0], %cst_1 <{layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>}> : vector<16xf32>, memref<256xf32>, vector<16xindex>, vector<16xi1>
gpu.return
}
}
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
index bbdffa0986962..3bc43b780ade2 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -405,7 +405,7 @@ gpu.module @test_distribution {
// CHECK-LABEL: gpu.func @vector_reduce_scalar_cross_sg
// CHECK-SAME: (%[[ARG0:.*]]: memref<32x32xf32>)
- // CHECK-DAG: %[[CST:.*]] = arith.constant {{.*}} 0.000000e+00 : f32
+ // CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-DAG: %[[LOAD:.*]] = xegpu.load_nd %{{.*}} : !xegpu.tensor_desc<8x8xf32> -> vector<8x8xf32>
// CHECK-DAG: %[[CST_ACC:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-DAG: %[[LOCAL:.*]] = vector.multi_reduction <add>, %[[LOAD]], %[[CST_ACC]] [0, 1] : vector<8x8xf32> to f32
>From 01602b9a946cdded77cc3d09b2033ecfe495e615 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Wed, 8 Apr 2026 20:51:46 +0000
Subject: [PATCH 10/19] fix tests
---
mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 22 +--
.../XeGPU/Transforms/XeGPULayoutImpl.cpp | 12 +-
.../XeGPU/Transforms/XeGPUPropagateLayout.cpp | 51 ++++-
.../XeGPU/sg-to-wi-experimental-unit.mlir | 175 ++++++++++++++++++
4 files changed, 237 insertions(+), 23 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 5697097a4c999..d6981052a7a5c 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -1101,17 +1101,17 @@ LogicalResult ConvertLayoutOp::verify() {
return emitOpError("expected input layout and target layout be WgLayout or "
"SgLayout at the same time.");
- Type srcType = getSource().getType();
- if (llvm::isa<VectorType>(srcType)) {
- auto shape = llvm::cast<VectorType>(srcType).getShape();
- if (!XeGPUDialect::isEvenlyDistributable(shape, srcLayout))
- return emitOpError(
- "invalid input layout, data cannot be evenly distributed.");
-
- if (!XeGPUDialect::isEvenlyDistributable(shape, resLayout))
- return emitOpError(
- "invalid target layout, data cannot be evenly distributed.");
- }
+ // Type srcType = getSource().getType();
+ // if (llvm::isa<VectorType>(srcType)) {
+ // auto shape = llvm::cast<VectorType>(srcType).getShape();
+ // if (!XeGPUDialect::isEvenlyDistributable(shape, srcLayout))
+ // return emitOpError(
+ // "invalid input layout, data cannot be evenly distributed.");
+
+ // if (!XeGPUDialect::isEvenlyDistributable(shape, resLayout))
+ // return emitOpError(
+ // "invalid target layout, data cannot be evenly distributed.");
+ // }
return mlir::success();
}
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
index d5c5e05549745..47c60eaf7d4e0 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
@@ -1541,12 +1541,14 @@ xegpu::inferSourceLayoutFromResult(OpOperand &operand,
return inferred;
}
- // For elementwise operations, all operands must have the same layout as the
- // result.
- if (OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() == 1) {
- LLVM_DEBUG(DBGS() << " -> elementwise op, using resLayout="
+ if (isa<VectorType>(operand.get().getType()) &&
+ !dyn_cast<xegpu::AnchorLayoutInterface>(op)) {
+ // For elementwise operations, all operands must have the same layout as the
+ // result.
+ // if (OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() ==
+ // 1) {
+ LLVM_DEBUG(DBGS() << " -> other vector or tensorDesc ops using resLayout="
<< (resLayout ? resLayout : nullptr) << "\n");
-
return resLayout;
}
return xegpu::DistributeLayoutAttr();
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 8abb1fc4bae52..6aeabacf4b30f 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -1538,37 +1538,63 @@ static LogicalResult
updateControlFlowOps(mlir::OpBuilder &builder,
mlir::RegionBranchTerminatorOpInterface terminator,
GetLayoutFnTy getLayoutOfValue) {
+ LLVM_DEBUG(DBGS() << "updateControlFlowOps: processing terminator: "
+ << *terminator << "\n");
// Only process if the terminator is inside a region branch op.
auto branchOp = dyn_cast<RegionBranchOpInterface>(terminator->getParentOp());
- if (!branchOp)
+ if (!branchOp) {
+ LLVM_DEBUG(
+ DBGS() << " parent is not a RegionBranchOpInterface, skipping\n");
return success();
+ }
+ LLVM_DEBUG(DBGS() << " parent branch op: " << *branchOp << "\n");
RegionBranchSuccessorMapping mapping;
branchOp.getSuccessorOperandInputMapping(mapping,
RegionBranchPoint(terminator));
+ LLVM_DEBUG(DBGS() << " successor mapping has " << mapping.size()
+ << " entries\n");
for (const auto &[successorOperand, successorInputs] : mapping) {
+ LLVM_DEBUG(DBGS() << " processing successor operand: "
+ << successorOperand->get()
+ << " (type: " << successorOperand->get().getType()
+ << "), num successor inputs: " << successorInputs.size()
+ << "\n");
for (Value successorInput : successorInputs) {
Type inputType = successorInput.getType();
+ LLVM_DEBUG(DBGS() << " successor input: " << successorInput
+ << ", type: " << inputType << "\n");
// We only need to operate on tensor descriptor or vector types.
- if (!isa<xegpu::TensorDescType, VectorType>(inputType))
+ if (!isa<xegpu::TensorDescType, VectorType>(inputType)) {
+ LLVM_DEBUG(
+ DBGS() << " skipping: not a TensorDescType or VectorType\n");
continue;
+ }
xegpu::DistributeLayoutAttr successorInputLayout =
getLayoutOfValue(successorInput);
xegpu::DistributeLayoutAttr successorOperandLayout =
getLayoutOfValue(successorOperand->get());
+ LLVM_DEBUG(DBGS() << " successor input layout: ");
+ LLVM_DEBUG(if (successorInputLayout) llvm::dbgs() << successorInputLayout;
+ else llvm::dbgs() << "<<NULL>>"; llvm::dbgs() << "\n");
+ LLVM_DEBUG(DBGS() << " successor operand layout: ");
+ LLVM_DEBUG(if (successorOperandLayout) llvm::dbgs()
+ << successorOperandLayout;
+ else llvm::dbgs() << "<<NULL>>"; llvm::dbgs() << "\n");
+
// If either of the layouts is not assigned, we cannot proceed.
if (!successorOperandLayout) {
- LLVM_DEBUG(DBGS() << "No layout assigned for forwarded operand in "
- "branch terminator: "
+ LLVM_DEBUG(DBGS() << " FAILURE: No layout assigned for forwarded "
+ "operand in branch terminator: "
<< successorOperand->get() << "\n");
return failure();
}
// We expect the layouts to match.
if (successorInputLayout &&
successorInputLayout != successorOperandLayout) {
- LLVM_DEBUG(DBGS() << "Conflicting layouts for region argument and "
- "operand forwarded as the argument: "
+ LLVM_DEBUG(DBGS() << " FAILURE: Conflicting layouts for region "
+ "argument and operand forwarded as the argument: "
<< successorInputLayout << " vs "
<< successorOperandLayout << "\n");
return failure();
@@ -1578,15 +1604,26 @@ updateControlFlowOps(mlir::OpBuilder &builder,
auto newTdescTy = xegpu::TensorDescType::get(
tdescTy.getContext(), tdescTy.getShape(), tdescTy.getElementType(),
tdescTy.getEncoding(), successorOperandLayout);
+ LLVM_DEBUG(DBGS() << " updating tensor desc type: " << tdescTy
+ << " -> " << newTdescTy << "\n");
successorInput.setType(newTdescTy);
continue;
}
// If the type is a vector type and this region argument is an OpResult,
// set the layout attribute on the OpResult.
- if (auto result = dyn_cast<OpResult>(successorInput))
+ if (auto result = dyn_cast<OpResult>(successorInput)) {
+ LLVM_DEBUG(DBGS() << " setting layout on OpResult #"
+ << result.getResultNumber() << " of "
+ << *result.getOwner() << " to "
+ << successorOperandLayout << "\n");
xegpu::setDistributeLayoutAttr(result, successorOperandLayout);
+ } else {
+ LLVM_DEBUG(DBGS() << " successor input is a BlockArgument, "
+ "not setting layout attribute\n");
+ }
}
}
+ LLVM_DEBUG(DBGS() << " updateControlFlowOps: success\n");
return success();
}
diff --git a/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir b/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
index d53d0463db4cb..056e5e7e34a63 100644
--- a/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
+++ b/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
@@ -136,6 +136,11 @@ gpu.func @elementwise() {
%3 = arith.addf %0, %2
{layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
: vector<16x16xf32>
+ %cl3 = xegpu.convert_layout %3
+ <{
+ input_layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+ target_layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+ }> : vector<16x16xf32>
gpu.return
}
@@ -145,6 +150,11 @@ gpu.func @elementwise() {
gpu.func @arith_constant() {
%0 = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
dense<1.0> : vector<16x16xf32>
+ %cl0 = xegpu.convert_layout %0
+ <{
+ input_layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+ target_layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+ }> : vector<16x16xf32>
gpu.return
}
@@ -363,6 +373,11 @@ gpu.func @vector_multi_reduction_dim1_distributed_dim1_reduction(%laneid: index)
layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [1]>
}
[1] : vector<2x16xf32> to vector<2xf32>
+ %cl1 = xegpu.convert_layout %1
+ <{
+ input_layout = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [1]>,
+ target_layout = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [1]>
+ }> : vector<2xf32>
gpu.return
}
@@ -428,6 +443,11 @@ gpu.func @vector_multi_reduction_dim0_distributed_dim0_reduction(%laneid: index)
layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>, dims = [0]>
}
[0] : vector<16x2xf32> to vector<2xf32>
+ %cl1 = xegpu.convert_layout %1
+ <{
+ input_layout = #xegpu.slice<#xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>, dims = [0]>,
+ target_layout = #xegpu.slice<#xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>, dims = [0]>
+ }> : vector<2xf32>
gpu.return
}
@@ -449,6 +469,11 @@ gpu.func @vector_multi_reduction_dim1_distributed_dim0_reduction(%laneid: index)
layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>
}
[0] : vector<4x16xf32> to vector<16xf32>
+ %cl1 = xegpu.convert_layout %1
+ <{
+ input_layout = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>,
+ target_layout = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>
+ }> : vector<16xf32>
gpu.return
}
@@ -470,6 +495,11 @@ gpu.func @vector_multi_reduction_dim0_distributed_dim1_reduction(%laneid: index)
layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>, dims = [1]>
}
[1] : vector<16x12xf32> to vector<16xf32>
+ %cl1 = xegpu.convert_layout %1
+ <{
+ input_layout = #xegpu.slice<#xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>, dims = [1]>,
+ target_layout = #xegpu.slice<#xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>, dims = [1]>
+ }> : vector<16xf32>
gpu.return
}
@@ -524,6 +554,11 @@ gpu.func @create_mask_1d(%m0: index) {
%mask = vector.create_mask %m0
{layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
: vector<16xi1>
+ %mask_cl = xegpu.convert_layout %mask
+ <{
+ input_layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
+ target_layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>
+ }> : vector<16xi1>
gpu.return
}
@@ -539,6 +574,11 @@ gpu.func @constant_mask_1d() {
%mask = vector.constant_mask [4]
{layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
: vector<16xi1>
+ %mask_cl = xegpu.convert_layout %mask
+ <{
+ input_layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
+ target_layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>
+ }> : vector<16xi1>
gpu.return
}
@@ -560,6 +600,11 @@ gpu.func @create_mask_2d(%m0: index, %m1: index) {
%mask = vector.create_mask %m0, %m1
{layout_result_0 = #xegpu.layout<lane_layout = [8, 2], lane_data = [1, 1]>}
: vector<8x4xi1>
+ %mask_cl = xegpu.convert_layout %mask
+ <{
+ input_layout = #xegpu.layout<lane_layout = [8, 2], lane_data = [1, 1]>,
+ target_layout = #xegpu.layout<lane_layout = [8, 2], lane_data = [1, 1]>
+ }> : vector<8x4xi1>
gpu.return
}
@@ -582,6 +627,11 @@ gpu.func @constant_mask_2d() {
%mask = vector.constant_mask [2, 3]
{layout_result_0 = #xegpu.layout<lane_layout = [8, 2], lane_data = [1, 1]>}
: vector<8x4xi1>
+ %mask_cl = xegpu.convert_layout %mask
+ <{
+ input_layout = #xegpu.layout<lane_layout = [8, 2], lane_data = [1, 1]>,
+ target_layout = #xegpu.layout<lane_layout = [8, 2], lane_data = [1, 1]>
+ }> : vector<8x4xi1>
gpu.return
}
@@ -603,6 +653,11 @@ gpu.func @vector_multi_reduction_3d_leading_unit_dim_lane_local() {
layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>, dims = [1]>
}
[1] : vector<1x16x32xf32> to vector<1x32xf32>
+ %cl1 = xegpu.convert_layout %1
+ <{
+ input_layout = #xegpu.slice<#xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>, dims = [1]>,
+ target_layout = #xegpu.slice<#xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>, dims = [1]>
+ }> : vector<1x32xf32>
gpu.return
}
@@ -636,6 +691,11 @@ gpu.func @vector_multi_reduction_3d_leading_unit_dim_cross_lane() {
layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16, 1], lane_data = [1, 1, 1]>, dims = [1]>
}
[1] : vector<1x16x2xf32> to vector<1x2xf32>
+ %cl1 = xegpu.convert_layout %1
+ <{
+ input_layout = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16, 1], lane_data = [1, 1, 1]>, dims = [1]>,
+ target_layout = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16, 1], lane_data = [1, 1, 1]>, dims = [1]>
+ }> : vector<1x2xf32>
gpu.return
}
@@ -648,6 +708,11 @@ gpu.func @vector_extract_from_2d() {
%0 = vector.extract %src[0]
{layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
: vector<16xf32> from vector<4x16xf32>
+ %cl0 = xegpu.convert_layout %0
+ <{
+ input_layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
+ target_layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>
+ }> : vector<16xf32>
gpu.return
}
@@ -660,6 +725,11 @@ gpu.func @vector_extract_from_2d_offset2() {
%0 = vector.extract %src[2]
{layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
: vector<16xf32> from vector<8x16xf32>
+ %cl0 = xegpu.convert_layout %0
+ <{
+ input_layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
+ target_layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>
+ }> : vector<16xf32>
gpu.return
}
@@ -675,6 +745,11 @@ gpu.func @vector_insert_into_2d() {
%0 = vector.insert %val, %dst[0]
{layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
: vector<16xf32> into vector<4x16xf32>
+ %cl0 = xegpu.convert_layout %0
+ <{
+ input_layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+ target_layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+ }> : vector<4x16xf32>
gpu.return
}
@@ -690,6 +765,11 @@ gpu.func @vector_insert_into_2d_offset2() {
%0 = vector.insert %val, %dst[2]
{layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
: vector<16xf32> into vector<8x16xf32>
+ %cl0 = xegpu.convert_layout %0
+ <{
+ input_layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+ target_layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+ }> : vector<8x16xf32>
gpu.return
}
@@ -703,6 +783,11 @@ gpu.func @vector_extract_strided_slice_distributed_dim_fully_extracted() {
layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
}
: vector<24x16xf32> to vector<8x16xf32>
+ %cl1 = xegpu.convert_layout %1
+ <{
+ input_layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+ target_layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+ }> : vector<8x16xf32>
gpu.return
}
@@ -716,6 +801,11 @@ gpu.func @vector_extract_strided_slice_inner_distributed() {
layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
}
: vector<24x64xf32> to vector<8x16xf32>
+ %cl1 = xegpu.convert_layout %1
+ <{
+ input_layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+ target_layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+ }> : vector<8x16xf32>
gpu.return
}
@@ -729,6 +819,11 @@ gpu.func @vector_extract_strided_slice_outer_distributed() {
layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>
}
: vector<32x16xf32> to vector<16x16xf32>
+ %cl1 = xegpu.convert_layout %1
+ <{
+ input_layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>,
+ target_layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>
+ }> : vector<16x16xf32>
gpu.return
}
@@ -742,6 +837,11 @@ gpu.func @vector_extract_strided_slice_1d() {
layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>
}
: vector<64xf32> to vector<32xf32>
+ %cl1 = xegpu.convert_layout %1
+ <{
+ input_layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
+ target_layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>
+ }> : vector<32xf32>
gpu.return
}
@@ -755,6 +855,11 @@ gpu.func @vector_extract_strided_slice_partial_offsets() {
layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
}
: vector<24x16xf32> to vector<8x16xf32>
+ %cl1 = xegpu.convert_layout %1
+ <{
+ input_layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+ target_layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+ }> : vector<8x16xf32>
gpu.return
}
@@ -771,6 +876,11 @@ gpu.func @vector_insert_strided_slice_distributed_dim_fully_inserted() {
layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
}
: vector<16x16xf32> into vector<64x16xf32>
+ %cl2 = xegpu.convert_layout %2
+ <{
+ input_layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+ target_layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+ }> : vector<64x16xf32>
gpu.return
}
@@ -787,6 +897,11 @@ gpu.func @vector_insert_strided_slice_inner_distributed() {
layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
}
: vector<16x16xf32> into vector<64x32xf32>
+ %cl2 = xegpu.convert_layout %2
+ <{
+ input_layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+ target_layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+ }> : vector<64x32xf32>
gpu.return
}
@@ -803,6 +918,11 @@ gpu.func @vector_insert_strided_slice_outer_distributed() {
layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>
}
: vector<16x16xf32> into vector<48x32xf32>
+ %cl2 = xegpu.convert_layout %2
+ <{
+ input_layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>,
+ target_layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>
+ }> : vector<48x32xf32>
gpu.return
}
@@ -819,6 +939,11 @@ gpu.func @vector_insert_strided_slice_1d() {
layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>
}
: vector<16xf32> into vector<48xf32>
+ %cl2 = xegpu.convert_layout %2
+ <{
+ input_layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
+ target_layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>
+ }> : vector<48xf32>
gpu.return
}
@@ -835,6 +960,11 @@ gpu.func @vector_insert_strided_slice_different_ranks() {
layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
}
: vector<16xf32> into vector<64x16xf32>
+ %cl2 = xegpu.convert_layout %2
+ <{
+ input_layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+ target_layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+ }> : vector<64x16xf32>
gpu.return
}
@@ -953,6 +1083,11 @@ gpu.func @elementwise_wrap_around_dim() {
: () -> vector<16x1xf16>
%1 = arith.negf %0 {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
: vector<16x1xf16>
+ %cl1 = xegpu.convert_layout %1
+ <{
+ input_layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+ target_layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+ }> : vector<16x1xf16>
gpu.return
}
}
@@ -967,6 +1102,11 @@ gpu.module @xevm_module {
// CHECK: %[[VEC:.*]] = vector.from_elements %[[REM2]] : vector<1xindex>
gpu.func @vector_step_slice() {
%0 = vector.step {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 1, 1, 16], lane_data = [1, 1, 1, 1]>, dims = [0, 1, 2]>} : vector<16xindex>
+ %cl0 = xegpu.convert_layout %0
+ <{
+ input_layout = #xegpu.slice<#xegpu.layout<lane_layout = [1, 1, 1, 16], lane_data = [1, 1, 1, 1]>, dims = [0, 1, 2]>,
+ target_layout = #xegpu.slice<#xegpu.layout<lane_layout = [1, 1, 1, 16], lane_data = [1, 1, 1, 1]>, dims = [0, 1, 2]>
+ }> : vector<16xindex>
gpu.return
}
}
@@ -977,6 +1117,11 @@ gpu.module @xevm_module {
// CHECK: %[[VEC:.*]] = vector.from_elements %{{.*}} : vector<1xindex>
gpu.func @vector_step_slice_unit() {
%0 = vector.step {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 1, 1, 16], lane_data = [1, 1, 1, 1]>, dims = [0, 1, 3]>} : vector<1xindex>
+ %cl0 = xegpu.convert_layout %0
+ <{
+ input_layout = #xegpu.slice<#xegpu.layout<lane_layout = [1, 1, 1, 16], lane_data = [1, 1, 1, 1]>, dims = [0, 1, 3]>,
+ target_layout = #xegpu.slice<#xegpu.layout<lane_layout = [1, 1, 1, 16], lane_data = [1, 1, 1, 1]>, dims = [0, 1, 3]>
+ }> : vector<1xindex>
gpu.return
}
}
@@ -994,6 +1139,11 @@ gpu.module @xevm_module {
// CHECK: %[[VEC:.*]] = vector.from_elements %[[V0]], %[[V1]], %[[V2]], %[[V3]] : vector<4xindex>
gpu.func @vector_step_slice_multi_dist() {
%0 = vector.step {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [2, 4, 2], lane_data = [1, 2, 1]>, dims = [0, 2]>} : vector<16xindex>
+ %cl0 = xegpu.convert_layout %0
+ <{
+ input_layout = #xegpu.slice<#xegpu.layout<lane_layout = [2, 4, 2], lane_data = [1, 2, 1]>, dims = [0, 2]>,
+ target_layout = #xegpu.slice<#xegpu.layout<lane_layout = [2, 4, 2], lane_data = [1, 2, 1]>, dims = [0, 2]>
+ }> : vector<16xindex>
gpu.return
}
}
@@ -1011,6 +1161,11 @@ gpu.func @vector_shapecast_rank_increasing() {
layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
}
: vector<16xf32> to vector<1x16xf32>
+ %cast_cl = xegpu.convert_layout %cast
+ <{
+ input_layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+ target_layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+ }> : vector<1x16xf32>
gpu.return
}
}
@@ -1028,6 +1183,11 @@ gpu.func @vector_shapecast_rank_reducing() {
layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>
}
: vector<1x16xf32> to vector<16xf32>
+ %cast_cl = xegpu.convert_layout %cast
+ <{
+ input_layout = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>,
+ target_layout = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>
+ }> : vector<16xf32>
gpu.return
}
}
@@ -1045,6 +1205,11 @@ gpu.func @vector_shapecast_rank_increasing_without_slicing_layout() {
layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
}
: vector<16xf32> to vector<1x16xf32>
+ %cast_cl = xegpu.convert_layout %cast
+ <{
+ input_layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+ target_layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+ }> : vector<1x16xf32>
gpu.return
}
}
@@ -1071,6 +1236,11 @@ gpu.module @xevm_module {
gpu.func @constant_wrap_around_dim() {
%0 = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
dense<1.0> : vector<16x1xf16>
+ %cl0 = xegpu.convert_layout %0
+ <{
+ input_layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+ target_layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+ }> : vector<16x1xf16>
gpu.return
}
}
@@ -1160,6 +1330,11 @@ gpu.func @vector_multi_reduction_1d_to_scalar() {
layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [16], lane_data = [1]>, dims = [0]>
}
[0] : vector<32xf32> to f32
+ %cl1 = xegpu.convert_layout %1
+ <{
+ input_layout = #xegpu.slice<#xegpu.layout<lane_layout = [16], lane_data = [1]>, dims = [0]>,
+ target_layout = #xegpu.slice<#xegpu.layout<lane_layout = [16], lane_data = [1]>, dims = [0]>
+ }> : f32
gpu.return
}
}
>From df3106ac7ba17724f662c838bb3a8878624d1f11 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Thu, 9 Apr 2026 17:54:43 +0000
Subject: [PATCH 11/19] fix issues
---
.../include/mlir/Dialect/XeGPU/IR/XeGPUOps.td | 2 +-
.../XeGPU/Transforms/XeGPUPropagateLayout.cpp | 87 ++++++++++++++++++-
2 files changed, 84 insertions(+), 5 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index e001419257d8f..72584f40681f4 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -1494,7 +1494,7 @@ def XeGPU_FenceOp: XeGPU_Op<"fence", []> {
let extraClassDeclaration = extraBaseClassDeclaration;
}
-def XeGPU_ConvertLayoutOp: XeGPU_Op<"convert_layout", [Pure, AllTypesMatch<["source", "result"]>, AnchorLayoutInterface]> {
+def XeGPU_ConvertLayoutOp: XeGPU_Op<"convert_layout", [AllTypesMatch<["source", "result"]>, MemoryEffects<[MemRead, MemWrite]>, AnchorLayoutInterface]> {
let summary = "Convert the layout of the input operand";
let description = [{
`convert_layout` redistribute data across subgroups and/or lanes from the `input_layout` to
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 6aeabacf4b30f..c53e865729eb2 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -396,6 +396,10 @@ class LayoutInfoPropagation
ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results);
+ void visitConvertLayoutOp(xegpu::ConvertLayoutOp convertLayout,
+ ArrayRef<LayoutInfoLattice *> operands,
+ ArrayRef<const LayoutInfoLattice *> results);
+
bool hasParamsOfLayoutKind(xegpu::DistributeLayoutAttr anchorLayout);
public:
@@ -483,6 +487,9 @@ LogicalResult LayoutInfoPropagation::visitOperation(
.Case([&](xegpu::StoreMatrixOp storeMatrixOp) {
visitStoreMatrixOp(storeMatrixOp, operands, results);
})
+ .Case([&](xegpu::ConvertLayoutOp convertLayoutOp) {
+ visitConvertLayoutOp(convertLayoutOp, operands, results);
+ })
// All other ops.
.Default([&](Operation *op) {
for (const LayoutInfoLattice *resultInfo : results) {
@@ -936,6 +943,17 @@ void LayoutInfoPropagation::visitLoadNdOp(
propagateIfChanged(operands[0], operands[0]->meet(loadLayout));
}
+/// Treat ConvertLayoutOp as a layout anchor: propagate its target_layout
+/// backward to the source operand (operand 0, a vector or scalar value).
+void LayoutInfoPropagation::visitConvertLayoutOp(
+    xegpu::ConvertLayoutOp convert, ArrayRef<LayoutInfoLattice *> operands,
+    ArrayRef<const LayoutInfoLattice *> results) {
+  xegpu::DistributeLayoutAttr anchorLayout = convert.getTargetLayoutAttr();
+  LayoutInfo convertLayout(anchorLayout);
+  // NOTE(review): source gets target_layout, not input_layout - confirm.
+  propagateIfChanged(operands[0], operands[0]->meet(convertLayout));
+}
+
/// For vector::TransposeOp, the layout of the result is transposed and
/// propagated to the operand.
void LayoutInfoPropagation::visitTransposeOp(
@@ -1346,7 +1364,7 @@ LogicalResult ResolveLayoutConflicts::run() {
// as anchor op for the reduction op's layout.
if (isa<vector::MultiDimReductionOp>(op) || isa<vector::ReductionOp>(op)) {
for (OpResult result : op->getResults()) {
- if (result.getType().isIntOrFloat() || result.use_empty()) {
+ if (result.getType().isIntOrFloat()) {
auto res = assignResultLayout(result);
if (failed(res)) {
DBGS() << "Failed to resolve vector consumer for multi-reduction "
@@ -1356,6 +1374,22 @@ LogicalResult ResolveLayoutConflicts::run() {
}
}
}
+ if (isa<RegionBranchOpInterface>(op)) {
+ auto regionBranchOp = dyn_cast<RegionBranchOpInterface>(op);
+ unsigned numResults = regionBranchOp->getNumResults();
+ for (unsigned i = 0; i < numResults; ++i) {
+ OpResult result = regionBranchOp->getResult(i);
+ if (result.use_empty()) {
+ auto res = assignResultLayout(result);
+ if (failed(res)) {
+ DBGS() << "Failed to resolve vector consumer for loop/switch "
+ "result with no use: "
+ << *op << "\n";
+ return WalkResult::interrupt();
+ }
+ }
+ }
+ }
for (OpOperand &operand : op->getOpOperands()) {
// Handle conflicts in tensor descriptor operands.
Type operandType = operand.get().getType();
@@ -1369,6 +1403,8 @@ LogicalResult ResolveLayoutConflicts::run() {
}
}
// Handle conflicts in vector operands.
+ LLVM_DEBUG(DBGS() << "Handling vector operand #" << operand.getOperandNumber()
+ << ": " << operand.get() << " in operation: " << *op << "\n");
if (isa<VectorType>(operandType)) {
auto res = resolveVectorConsumer(operand);
if (failed(res)) {
@@ -1507,7 +1543,7 @@ static LogicalResult updateOp(mlir::OpBuilder &builder, mlir::Operation *op,
}
// If the result is a vector type, add a temporary layout attribute to the
// op.
- xegpu::setDistributeLayoutAttr(result, layout);
+ // xegpu::setDistributeLayoutAttr(result, layout);
}
return success();
}
@@ -1537,7 +1573,8 @@ static LogicalResult updateOp(mlir::OpBuilder &builder, mlir::Operation *op,
static LogicalResult
updateControlFlowOps(mlir::OpBuilder &builder,
mlir::RegionBranchTerminatorOpInterface terminator,
- GetLayoutFnTy getLayoutOfValue) {
+ GetLayoutFnTy getLayoutOfValue,
+ xegpu::LayoutKind layoutKind) {
LLVM_DEBUG(DBGS() << "updateControlFlowOps: processing terminator: "
<< *terminator << "\n");
// Only process if the terminator is inside a region branch op.
@@ -1570,6 +1607,12 @@ updateControlFlowOps(mlir::OpBuilder &builder,
DBGS() << " skipping: not a TensorDescType or VectorType\n");
continue;
}
+
+ // debug print successorInput and successorOperand
+ LLVM_DEBUG(DBGS() << " successor input: " << successorInput << "\n");
+ LLVM_DEBUG(DBGS() << " successor operand: " << successorOperand->get()
+ << "\n");
+
xegpu::DistributeLayoutAttr successorInputLayout =
getLayoutOfValue(successorInput);
xegpu::DistributeLayoutAttr successorOperandLayout =
@@ -1601,6 +1644,15 @@ updateControlFlowOps(mlir::OpBuilder &builder,
}
// Get tensor descriptor type with the layout.
if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(inputType)) {
+ // if (successorInputLayout != successorOperandLayout) {
+ // LLVM_DEBUG(DBGS()
+ // << " FAILURE: Conflicting layouts for region "
+ // "argument and operand forwarded as the argument: "
+ // << successorInputLayout << " vs " <<
+ // successorOperandLayout
+ // << "\n");
+ // return failure();
+ // }
auto newTdescTy = xegpu::TensorDescType::get(
tdescTy.getContext(), tdescTy.getShape(), tdescTy.getElementType(),
tdescTy.getEncoding(), successorOperandLayout);
@@ -1609,6 +1661,22 @@ updateControlFlowOps(mlir::OpBuilder &builder,
successorInput.setType(newTdescTy);
continue;
}
+
+ // if (auto vectorTy = dyn_cast<VectorType>(inputType)) {
+ // SmallVector<int64_t> vectorShape(vectorTy.getShape().begin(),
+ // vectorTy.getShape().end());
+ // if (!successorInputLayout.isCompatibleWith(successorOperandLayout,
+ // vectorShape, layoutKind))
+ // {
+ // LLVM_DEBUG(DBGS()
+ // << " FAILURE: Conflicting layouts for region "
+ // "argument and operand forwarded as the argument: "
+ // << successorInputLayout << " vs " <<
+ // successorOperandLayout
+ // << "\n");
+ // return failure();
+ // }
+ // }
// If the type is a vector type and this region argument is an OpResult,
// set the layout attribute on the OpResult.
if (auto result = dyn_cast<OpResult>(successorInput)) {
@@ -1720,7 +1788,7 @@ LogicalResult xegpu::propagateLayouts(OpBuilder &builder, Operation *target,
TypeSwitch<Operation *>(&op)
.Case([&](mlir::RegionBranchTerminatorOpInterface branchTermOp) {
r = updateControlFlowOps(builder, branchTermOp,
- getXeGPULayoutForValue);
+ getXeGPULayoutForValue, layoutKind);
})
.Case([&](mlir::FunctionOpInterface funcOp) {
r = updateFunctionOpInterface(builder, funcOp,
@@ -1748,6 +1816,17 @@ LogicalResult xegpu::resolveLayoutConflicts(Operation *target) {
}
void XeGPUPropagateLayoutPass::runOnOperation() {
+ // Remove layout attributes from SCF ops
+ getOperation()->walk([](Operation *op) {
+ SmallVector<StringAttr> attrsToRemove;
+ for (auto namedAttr : op->getDiscardableAttrs()) {
+ if (isa<xegpu::DistributeLayoutAttr>(namedAttr.getValue()))
+ attrsToRemove.push_back(namedAttr.getName());
+ }
+ for (auto attrName : attrsToRemove)
+ op->removeDiscardableAttr(attrName);
+ });
+
xegpu::LayoutKind layoutKind;
if (this->layoutKind == "lane") {
layoutKind = xegpu::LayoutKind::Lane;
>From ad40b1f79e13b9def43a29ca53e8096308f894d9 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Thu, 9 Apr 2026 20:56:45 +0000
Subject: [PATCH 12/19] fix convert_layout pattern
---
.../XeGPU/Transforms/XeGPUPropagateLayout.cpp | 12 +++++++++---
1 file changed, 9 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index c53e865729eb2..84ecaf34606cd 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -948,7 +948,7 @@ void LayoutInfoPropagation::visitLoadNdOp(
void LayoutInfoPropagation::visitConvertLayoutOp(
xegpu::ConvertLayoutOp convert, ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results) {
- xegpu::DistributeLayoutAttr anchorLayout = convert.getTargetLayoutAttr();
+ xegpu::DistributeLayoutAttr anchorLayout = convert.getInputLayoutAttr();
LayoutInfo convertLayout(anchorLayout);
// Propagate the new layout to the tensor descriptor operand.
propagateIfChanged(operands[0], operands[0]->meet(convertLayout));
@@ -1543,7 +1543,7 @@ static LogicalResult updateOp(mlir::OpBuilder &builder, mlir::Operation *op,
}
// If the result is a vector type, add a temporary layout attribute to the
// op.
- // xegpu::setDistributeLayoutAttr(result, layout);
+ xegpu::setDistributeLayoutAttr(result, layout);
}
return success();
}
@@ -1780,8 +1780,10 @@ LogicalResult xegpu::propagateLayouts(OpBuilder &builder, Operation *target,
return cast<xegpu::LayoutAttr>(layoutAttr);
};
-
+ // dump the op before update
+ llvm::dbgs() << "Before layout propagation and conflict resolution:\n";
Operation *op = target;
+ op->dump();
auto walkResult = op->walk([&](mlir::Block *block) -> WalkResult {
for (mlir::Operation &op : llvm::reverse(block->getOperations())) {
LogicalResult r = success();
@@ -1804,6 +1806,10 @@ LogicalResult xegpu::propagateLayouts(OpBuilder &builder, Operation *target,
}
return WalkResult::advance();
});
+ // dump the op after update, print some headline
+ llvm::dbgs() << "After layout propagation and conflict resolution:\n";
+ op->dump();
+
if (walkResult.wasInterrupted())
return failure();
>From e2f581b14bb8a3f4a5277febe05c1e906a9325bf Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Fri, 10 Apr 2026 06:12:47 +0000
Subject: [PATCH 13/19] fix, all tests passed now
---
.../XeGPU/Transforms/XeGPUPropagateLayout.cpp | 57 ++++++++++++-------
1 file changed, 37 insertions(+), 20 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 84ecaf34606cd..4f1680648fcdf 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -1374,22 +1374,22 @@ LogicalResult ResolveLayoutConflicts::run() {
}
}
}
- if (isa<RegionBranchOpInterface>(op)) {
- auto regionBranchOp = dyn_cast<RegionBranchOpInterface>(op);
- unsigned numResults = regionBranchOp->getNumResults();
- for (unsigned i = 0; i < numResults; ++i) {
- OpResult result = regionBranchOp->getResult(i);
- if (result.use_empty()) {
- auto res = assignResultLayout(result);
- if (failed(res)) {
- DBGS() << "Failed to resolve vector consumer for loop/switch "
- "result with no use: "
- << *op << "\n";
- return WalkResult::interrupt();
- }
- }
- }
- }
+ // if (isa<RegionBranchOpInterface>(op)) {
+ // auto regionBranchOp = dyn_cast<RegionBranchOpInterface>(op);
+ // unsigned numResults = regionBranchOp->getNumResults();
+ // for (unsigned i = 0; i < numResults; ++i) {
+ // OpResult result = regionBranchOp->getResult(i);
+ // if (result.use_empty()) {
+ // auto res = assignResultLayout(result);
+ // if (failed(res)) {
+ // DBGS() << "Failed to resolve vector consumer for loop/switch "
+ // "result with no use: "
+ // << *op << "\n";
+ // return WalkResult::interrupt();
+ // }
+ // }
+ // }
+ // }
for (OpOperand &operand : op->getOpOperands()) {
// Handle conflicts in tensor descriptor operands.
Type operandType = operand.get().getType();
@@ -1403,8 +1403,9 @@ LogicalResult ResolveLayoutConflicts::run() {
}
}
// Handle conflicts in vector operands.
- LLVM_DEBUG(DBGS() << "Handling vector operand #" << operand.getOperandNumber()
- << ": " << operand.get() << " in operation: " << *op << "\n");
+ LLVM_DEBUG(DBGS() << "Handling vector operand #"
+ << operand.getOperandNumber() << ": " << operand.get()
+ << " in operation: " << *op << "\n");
if (isa<VectorType>(operandType)) {
auto res = resolveVectorConsumer(operand);
if (failed(res)) {
@@ -1758,8 +1759,7 @@ LogicalResult xegpu::propagateLayouts(OpBuilder &builder, Operation *target,
// Helper to convert LayoutInfo to xegpu::LayoutAttr.
auto getXeGPULayoutForValue = [&](Value val) -> xegpu::DistributeLayoutAttr {
LayoutInfo layout = analysis.getLayoutInfo(val);
- if (!layout.isAssigned())
- return {};
+
if (auto opResult = dyn_cast<OpResult>(val)) {
Operation *defOp = opResult.getDefiningOp();
@@ -1773,6 +1773,8 @@ LogicalResult xegpu::propagateLayouts(OpBuilder &builder, Operation *target,
if (requiredResLayoutAttr != nullptr)
return requiredResLayoutAttr;
}
+ if (!layout.isAssigned())
+ return {};
xegpu::DistributeLayoutAttr layoutAttr =
cast<xegpu::DistributeLayoutAttr>(layout.get());
if (layout.isSliceLayout())
@@ -1787,6 +1789,21 @@ LogicalResult xegpu::propagateLayouts(OpBuilder &builder, Operation *target,
auto walkResult = op->walk([&](mlir::Block *block) -> WalkResult {
for (mlir::Operation &op : llvm::reverse(block->getOperations())) {
LogicalResult r = success();
+
+ // for (OpOperand &operand : op.getOpOperands()) {
+ // Type operandType = operand.get().getType();
+ // // We only need to operate on tensor descriptor or vector types.
+ // if (!isa<xegpu::TensorDescType, VectorType>(operandType))
+ // continue;
+ // xegpu::DistributeLayoutAttr consumerLayout =
+ // getXeGPULayoutForValue(operand.get());
+ // // dump the layout info for debugging
+ // // dump operation and operand info for debugging
+ // DBGS() << "Processing operation: " << op << "\n";
+ // llvm::outs() << "operand #" << operand.getOperandNumber() << " = "
+ // << operand.get() << "\n";
+ // DBGS() << "Layout: " << consumerLayout << "\n";
+ // }
TypeSwitch<Operation *>(&op)
.Case([&](mlir::RegionBranchTerminatorOpInterface branchTermOp) {
r = updateControlFlowOps(builder, branchTermOp,
>From cb4aeb2b3b0ec20ef180147ff610d03ca93ec550 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Mon, 13 Apr 2026 20:39:57 +0000
Subject: [PATCH 14/19] remove separated PR
---
.../XeGPU/Transforms/XeGPULayoutImpl.h | 10 +-
mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 2 +-
.../XeGPU/Transforms/XeGPULayoutImpl.cpp | 40 +-
.../XeGPU/Transforms/XeGPUPropagateLayout.cpp | 233 +--
.../XeGPUSgToWiDistributeExperimental.cp | 1744 +++++++++++++++++
.../XeGPUSgToWiDistributeExperimental.cpp | 1 -
.../Transforms/XeGPUSubgroupDistribute.cpp | 41 +-
.../Transforms/XeGPUWgToSgDistribute.cpp | 1 -
mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp | 1 -
9 files changed, 1799 insertions(+), 274 deletions(-)
create mode 100644 mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cp
diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h b/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h
index c36201c2f0d9e..651b001a9d931 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h
@@ -112,11 +112,10 @@ inferInsertStridedSliceSourceLayout(DistributeLayoutAttr resLayout,
ArrayRef<int64_t> resShape,
ArrayRef<int64_t> srcShape);
-/// Infers the layout attribute for mask and offset operand for Chunked load
-/// and store, given the anchor layout attribute for the value being load/store.
+/// Infers the source layout attribute for an operand using result layout
+/// attribute
DistributeLayoutAttr
-inferMaskOffsetLayoutForScatterIO(DistributeLayoutAttr payloadLayout,
- int chunkSize);
+inferSourceLayoutFromResult(OpOperand &operand, DistributeLayoutAttr resLayout);
/// Sets up layout for Multi-Reduction operations by creating a SliceAttr for
/// the result.
@@ -189,9 +188,6 @@ setupDpasLayout(LayoutKind layoutKind, VectorType aTy, VectorType bTy,
VectorType cdTy, DistributeLayoutAttr consumerLayout, int numSg,
const uArch::uArch *uArch);
-DistributeLayoutAttr
-inferSourceLayoutFromResult(OpOperand &operand, DistributeLayoutAttr resLayout);
-
/// Gets the expected layout for a given consumer operand. This will check if
/// the owning operation of the consumer operand is one of the special layout
/// users and determine the expected layout accordingly.
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index eaa43c02946d8..64c56b5adf5d7 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -585,7 +585,7 @@ DistributeLayoutAttr LayoutAttr::dropDims(SmallVector<int64_t> dimGroup) {
int64_t offset = llvm::count_if(dimGroup, [&](int64_t s) { return s < d; });
newOrder.push_back(d - offset);
}
- if ((sgLayout.empty() && laneLayout.empty()) || newOrder.size() == 1)
+ if (sgLayout.empty() && laneLayout.empty())
newOrder.clear();
auto toAttr = [&](ArrayRef<int64_t> v) -> DenseI32ArrayAttr {
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
index 47c60eaf7d4e0..65fab3e60166d 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
@@ -26,11 +26,11 @@
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Transforms/DialectConversion.h"
-#include "llvm/Support/Debug.h"
#include "llvm/Support/FormatVariadic.h"
#include <cstdint>
#include <numeric>
+#include "llvm/Support/Debug.h"
#define DEBUG_TYPE "xegpu-layout-recovery"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
@@ -375,33 +375,6 @@ bool xegpu::recoverTemporaryLayouts(Operation *rootOp) {
return true;
}
-// // Attach layout attributes to all vector-type operands of operations within
-// // the given operation's region. Reports an error if any vector operand lacks
-// // a layout attribute.
-// bool xegpu::recoverTemporaryLayouts(Operation *rootOp) {
-// auto result = rootOp->walk([&](Operation *op) {
-// for (OpOperand &operand : op->getOpOperands()) {
-// // Layouts are needed for vector type only.
-// if (!isa<VectorType>(operand.get().getType()))
-// continue;
-// // Skip block arguments since they don't have defining ops to attach
-// // layout attributes to.
-// if (isa<BlockArgument>(operand.get()))
-// continue;
-// auto layout = xegpu::getDistributeLayoutAttr(operand.get());
-// if (!layout) {
-// op->emitWarning("Could not find layout attribute for operand ")
-// << operand.getOperandNumber() << " of operation " <<
-// op->getName();
-// xegpu::setTemporaryLayout(operand, layout);
-// continue;
-// }
-// }
-// return WalkResult::advance();
-// });
-// return !result.wasInterrupted();
-// }
-
template <typename T, typename>
void xegpu::removeLayoutAttr(const T &operandOrResult) {
Operation *owner = operandOrResult.getOwner();
@@ -677,17 +650,6 @@ xegpu::inferShapeCastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
return nullptr;
}
-/// Infers the layout attribute for mask and offset operand for Chunked load
-/// and store, given the anchor layout attribute for the value being load/store.
-xegpu::DistributeLayoutAttr xegpu::inferMaskOffsetLayoutForScatterIO(
- xegpu::DistributeLayoutAttr payloadLayout, int chunkSize) {
- auto rank = payloadLayout.getRank();
- if (chunkSize > 1)
- return payloadLayout.dropDims(
- llvm::to_vector(llvm::seq<int64_t>(rank - 1, rank)));
- return payloadLayout;
-}
-
/// Sets up layout for reduction operations by creating a SliceAttr for the
/// result.
///
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 4f1680648fcdf..4c30dacae8850 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -396,10 +396,6 @@ class LayoutInfoPropagation
ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results);
- void visitConvertLayoutOp(xegpu::ConvertLayoutOp convertLayout,
- ArrayRef<LayoutInfoLattice *> operands,
- ArrayRef<const LayoutInfoLattice *> results);
-
bool hasParamsOfLayoutKind(xegpu::DistributeLayoutAttr anchorLayout);
public:
@@ -487,9 +483,6 @@ LogicalResult LayoutInfoPropagation::visitOperation(
.Case([&](xegpu::StoreMatrixOp storeMatrixOp) {
visitStoreMatrixOp(storeMatrixOp, operands, results);
})
- .Case([&](xegpu::ConvertLayoutOp convertLayoutOp) {
- visitConvertLayoutOp(convertLayoutOp, operands, results);
- })
// All other ops.
.Default([&](Operation *op) {
for (const LayoutInfoLattice *resultInfo : results) {
@@ -943,17 +936,6 @@ void LayoutInfoPropagation::visitLoadNdOp(
propagateIfChanged(operands[0], operands[0]->meet(loadLayout));
}
-/// Propagate the layout of the value to the tensor descriptor operand in
-/// ConvertLayoutOp.
-void LayoutInfoPropagation::visitConvertLayoutOp(
- xegpu::ConvertLayoutOp convert, ArrayRef<LayoutInfoLattice *> operands,
- ArrayRef<const LayoutInfoLattice *> results) {
- xegpu::DistributeLayoutAttr anchorLayout = convert.getInputLayoutAttr();
- LayoutInfo convertLayout(anchorLayout);
- // Propagate the new layout to the tensor descriptor operand.
- propagateIfChanged(operands[0], operands[0]->meet(convertLayout));
-}
-
/// For vector::TransposeOp, the layout of the result is transposed and
/// propagated to the operand.
void LayoutInfoPropagation::visitTransposeOp(
@@ -1045,7 +1027,7 @@ void LayoutInfoPropagation::visitLoadGatherOp(
const uArch *uArch = getUArch(getChipStr(load).value_or(""));
if (!uArch)
return;
- // auto subgroupSize = uArch->getSubgroupSize();
+ auto subgroupSize = uArch->getSubgroupSize();
VectorType resVecTy = load.getValueType();
int chunkSize = load.getChunkSize().value_or(1);
@@ -1067,24 +1049,20 @@ void LayoutInfoPropagation::visitLoadGatherOp(
load.setLayoutAttr(requiredAnchorLayoutAttr);
}
- assert((chunkSize <= 1) || (layoutKind != xegpu::LayoutKind::Subgroup));
- auto maskLayoutAttr = xegpu::inferMaskOffsetLayoutForScatterIO(
- requiredAnchorLayoutAttr, chunkSize);
-
- // // Special handling mask layout for chunked ops: Enforce the default xegpu
- // 1D
- // // layout for mask.
- // if (chunkSize > 1) {
- // if (layoutKind == xegpu::LayoutKind::InstData)
- // maskLayoutAttr =
- // xegpu::LayoutAttr::get(load->getContext(), {subgroupSize});
- // else if (layoutKind == xegpu::LayoutKind::Lane)
- // maskLayoutAttr =
- // xegpu::LayoutAttr::get(load->getContext(), {subgroupSize}, {1});
- // else
- // assert(false &&
- // "chunked StoreScatterOp should not be used at workgroup level");
- // }
+ auto maskLayoutAttr = requiredAnchorLayoutAttr;
+ // Special handling mask layout for chunked ops: Enforce the default xegpu 1D
+ // layout for mask.
+ if (chunkSize > 1) {
+ if (layoutKind == xegpu::LayoutKind::InstData)
+ maskLayoutAttr =
+ xegpu::LayoutAttr::get(load->getContext(), {subgroupSize});
+ else if (layoutKind == xegpu::LayoutKind::Lane)
+ maskLayoutAttr =
+ xegpu::LayoutAttr::get(load->getContext(), {subgroupSize}, {1});
+ else
+ assert(false &&
+ "chunked StoreScatterOp should not be used at workgroup level");
+ }
LayoutInfo maskLayoutInfo = LayoutInfo(maskLayoutAttr);
auto loadLayoutInfo = LayoutInfo(requiredAnchorLayoutAttr);
@@ -1127,7 +1105,7 @@ void LayoutInfoPropagation::visitStoreScatterOp(
const uArch *uArch = getUArch(getChipStr(storeScatter).value_or(""));
if (!uArch)
return;
- // auto subgroupSize = uArch->getSubgroupSize();
+ auto subgroupSize = uArch->getSubgroupSize();
VectorType srcVecTy = storeScatter.getValueType();
int chunkSize = storeScatter.getChunkSize().value_or(1);
@@ -1144,26 +1122,22 @@ void LayoutInfoPropagation::visitStoreScatterOp(
}
LayoutInfo srcLayoutInfo = LayoutInfo(requiredAnchorLayoutAttr);
- assert((chunkSize <= 1) || (layoutKind != xegpu::LayoutKind::Subgroup));
- auto maskLayoutAttr = xegpu::inferMaskOffsetLayoutForScatterIO(
- requiredAnchorLayoutAttr, chunkSize);
- LayoutInfo maskLayoutInfo = LayoutInfo(maskLayoutAttr);
+ auto maskLayoutAttr = requiredAnchorLayoutAttr;
+ // Special handling mask layout for chunked ops: Enforce the default xegpu 1D
+ // layout for mask.
+ if (chunkSize > 1) {
+ if (layoutKind == xegpu::LayoutKind::InstData)
+ maskLayoutAttr =
+ xegpu::LayoutAttr::get(storeScatter->getContext(), {subgroupSize});
+ else if (layoutKind == xegpu::LayoutKind::Lane)
+ maskLayoutAttr = xegpu::LayoutAttr::get(storeScatter->getContext(),
+ {subgroupSize}, {1});
+ else
+ assert(false &&
+ "chunked StoreScatterOp should not be used at workgroup level");
+ }
- // auto maskLayoutAttr = requiredAnchorLayoutAttr;
- // // Special handling mask layout for chunked ops: Enforce the default xegpu
- // 1D
- // // layout for mask.
- // if (chunkSize > 1) {
- // if (layoutKind == xegpu::LayoutKind::InstData)
- // maskLayoutAttr =
- // xegpu::LayoutAttr::get(storeScatter->getContext(), {subgroupSize});
- // else if (layoutKind == xegpu::LayoutKind::Lane)
- // maskLayoutAttr = xegpu::LayoutAttr::get(storeScatter->getContext(),
- // {subgroupSize}, {1});
- // else
- // assert(false &&
- // "chunked StoreScatterOp should not be used at workgroup level");
- // }
+ LayoutInfo maskLayoutInfo = LayoutInfo(maskLayoutAttr);
// Propagate the payload operand layout
propagateIfChanged(operands[0], operands[0]->meet(srcLayoutInfo));
@@ -1374,22 +1348,6 @@ LogicalResult ResolveLayoutConflicts::run() {
}
}
}
- // if (isa<RegionBranchOpInterface>(op)) {
- // auto regionBranchOp = dyn_cast<RegionBranchOpInterface>(op);
- // unsigned numResults = regionBranchOp->getNumResults();
- // for (unsigned i = 0; i < numResults; ++i) {
- // OpResult result = regionBranchOp->getResult(i);
- // if (result.use_empty()) {
- // auto res = assignResultLayout(result);
- // if (failed(res)) {
- // DBGS() << "Failed to resolve vector consumer for loop/switch "
- // "result with no use: "
- // << *op << "\n";
- // return WalkResult::interrupt();
- // }
- // }
- // }
- // }
for (OpOperand &operand : op->getOpOperands()) {
// Handle conflicts in tensor descriptor operands.
Type operandType = operand.get().getType();
@@ -1403,9 +1361,6 @@ LogicalResult ResolveLayoutConflicts::run() {
}
}
// Handle conflicts in vector operands.
- LLVM_DEBUG(DBGS() << "Handling vector operand #"
- << operand.getOperandNumber() << ": " << operand.get()
- << " in operation: " << *op << "\n");
if (isa<VectorType>(operandType)) {
auto res = resolveVectorConsumer(operand);
if (failed(res)) {
@@ -1574,125 +1529,56 @@ static LogicalResult updateOp(mlir::OpBuilder &builder, mlir::Operation *op,
static LogicalResult
updateControlFlowOps(mlir::OpBuilder &builder,
mlir::RegionBranchTerminatorOpInterface terminator,
- GetLayoutFnTy getLayoutOfValue,
- xegpu::LayoutKind layoutKind) {
- LLVM_DEBUG(DBGS() << "updateControlFlowOps: processing terminator: "
- << *terminator << "\n");
+ GetLayoutFnTy getLayoutOfValue) {
// Only process if the terminator is inside a region branch op.
auto branchOp = dyn_cast<RegionBranchOpInterface>(terminator->getParentOp());
- if (!branchOp) {
- LLVM_DEBUG(
- DBGS() << " parent is not a RegionBranchOpInterface, skipping\n");
+ if (!branchOp)
return success();
- }
- LLVM_DEBUG(DBGS() << " parent branch op: " << *branchOp << "\n");
RegionBranchSuccessorMapping mapping;
branchOp.getSuccessorOperandInputMapping(mapping,
RegionBranchPoint(terminator));
- LLVM_DEBUG(DBGS() << " successor mapping has " << mapping.size()
- << " entries\n");
for (const auto &[successorOperand, successorInputs] : mapping) {
- LLVM_DEBUG(DBGS() << " processing successor operand: "
- << successorOperand->get()
- << " (type: " << successorOperand->get().getType()
- << "), num successor inputs: " << successorInputs.size()
- << "\n");
for (Value successorInput : successorInputs) {
Type inputType = successorInput.getType();
- LLVM_DEBUG(DBGS() << " successor input: " << successorInput
- << ", type: " << inputType << "\n");
// We only need to operate on tensor descriptor or vector types.
- if (!isa<xegpu::TensorDescType, VectorType>(inputType)) {
- LLVM_DEBUG(
- DBGS() << " skipping: not a TensorDescType or VectorType\n");
+ if (!isa<xegpu::TensorDescType, VectorType>(inputType))
continue;
- }
-
- // debug print successorInput and successorOperand
- LLVM_DEBUG(DBGS() << " successor input: " << successorInput << "\n");
- LLVM_DEBUG(DBGS() << " successor operand: " << successorOperand->get()
- << "\n");
-
xegpu::DistributeLayoutAttr successorInputLayout =
getLayoutOfValue(successorInput);
xegpu::DistributeLayoutAttr successorOperandLayout =
getLayoutOfValue(successorOperand->get());
- LLVM_DEBUG(DBGS() << " successor input layout: ");
- LLVM_DEBUG(if (successorInputLayout) llvm::dbgs() << successorInputLayout;
- else llvm::dbgs() << "<<NULL>>"; llvm::dbgs() << "\n");
- LLVM_DEBUG(DBGS() << " successor operand layout: ");
- LLVM_DEBUG(if (successorOperandLayout) llvm::dbgs()
- << successorOperandLayout;
- else llvm::dbgs() << "<<NULL>>"; llvm::dbgs() << "\n");
-
// If either of the layouts is not assigned, we cannot proceed.
if (!successorOperandLayout) {
- LLVM_DEBUG(DBGS() << " FAILURE: No layout assigned for forwarded "
- "operand in branch terminator: "
+ LLVM_DEBUG(DBGS() << "No layout assigned for forwarded operand in "
+ "branch terminator: "
<< successorOperand->get() << "\n");
return failure();
}
// We expect the layouts to match.
if (successorInputLayout &&
successorInputLayout != successorOperandLayout) {
- LLVM_DEBUG(DBGS() << " FAILURE: Conflicting layouts for region "
- "argument and operand forwarded as the argument: "
+ LLVM_DEBUG(DBGS() << "Conflicting layouts for region argument and "
+ "operand forwarded as the argument: "
<< successorInputLayout << " vs "
<< successorOperandLayout << "\n");
return failure();
}
// Get tensor descriptor type with the layout.
if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(inputType)) {
- // if (successorInputLayout != successorOperandLayout) {
- // LLVM_DEBUG(DBGS()
- // << " FAILURE: Conflicting layouts for region "
- // "argument and operand forwarded as the argument: "
- // << successorInputLayout << " vs " <<
- // successorOperandLayout
- // << "\n");
- // return failure();
- // }
auto newTdescTy = xegpu::TensorDescType::get(
tdescTy.getContext(), tdescTy.getShape(), tdescTy.getElementType(),
tdescTy.getEncoding(), successorOperandLayout);
- LLVM_DEBUG(DBGS() << " updating tensor desc type: " << tdescTy
- << " -> " << newTdescTy << "\n");
successorInput.setType(newTdescTy);
continue;
}
-
- // if (auto vectorTy = dyn_cast<VectorType>(inputType)) {
- // SmallVector<int64_t> vectorShape(vectorTy.getShape().begin(),
- // vectorTy.getShape().end());
- // if (!successorInputLayout.isCompatibleWith(successorOperandLayout,
- // vectorShape, layoutKind))
- // {
- // LLVM_DEBUG(DBGS()
- // << " FAILURE: Conflicting layouts for region "
- // "argument and operand forwarded as the argument: "
- // << successorInputLayout << " vs " <<
- // successorOperandLayout
- // << "\n");
- // return failure();
- // }
- // }
// If the type is a vector type and this region argument is an OpResult,
// set the layout attribute on the OpResult.
- if (auto result = dyn_cast<OpResult>(successorInput)) {
- LLVM_DEBUG(DBGS() << " setting layout on OpResult #"
- << result.getResultNumber() << " of "
- << *result.getOwner() << " to "
- << successorOperandLayout << "\n");
+ if (auto result = dyn_cast<OpResult>(successorInput))
xegpu::setDistributeLayoutAttr(result, successorOperandLayout);
- } else {
- LLVM_DEBUG(DBGS() << " successor input is a BlockArgument, "
- "not setting layout attribute\n");
- }
}
}
- LLVM_DEBUG(DBGS() << " updateControlFlowOps: success\n");
return success();
}
@@ -1759,7 +1645,8 @@ LogicalResult xegpu::propagateLayouts(OpBuilder &builder, Operation *target,
// Helper to convert LayoutInfo to xegpu::LayoutAttr.
auto getXeGPULayoutForValue = [&](Value val) -> xegpu::DistributeLayoutAttr {
LayoutInfo layout = analysis.getLayoutInfo(val);
-
+ if (!layout.isAssigned())
+ return {};
if (auto opResult = dyn_cast<OpResult>(val)) {
Operation *defOp = opResult.getDefiningOp();
@@ -1773,8 +1660,6 @@ LogicalResult xegpu::propagateLayouts(OpBuilder &builder, Operation *target,
if (requiredResLayoutAttr != nullptr)
return requiredResLayoutAttr;
}
- if (!layout.isAssigned())
- return {};
xegpu::DistributeLayoutAttr layoutAttr =
cast<xegpu::DistributeLayoutAttr>(layout.get());
if (layout.isSliceLayout())
@@ -1782,32 +1667,15 @@ LogicalResult xegpu::propagateLayouts(OpBuilder &builder, Operation *target,
return cast<xegpu::LayoutAttr>(layoutAttr);
};
- // dump the op before update
- llvm::dbgs() << "Before layout propagation and conflict resolution:\n";
+
Operation *op = target;
- op->dump();
auto walkResult = op->walk([&](mlir::Block *block) -> WalkResult {
for (mlir::Operation &op : llvm::reverse(block->getOperations())) {
LogicalResult r = success();
-
- // for (OpOperand &operand : op.getOpOperands()) {
- // Type operandType = operand.get().getType();
- // // We only need to operate on tensor descriptor or vector types.
- // if (!isa<xegpu::TensorDescType, VectorType>(operandType))
- // continue;
- // xegpu::DistributeLayoutAttr consumerLayout =
- // getXeGPULayoutForValue(operand.get());
- // // dump the layout info for debugging
- // // dump operation and operand info for debugging
- // DBGS() << "Processing operation: " << op << "\n";
- // llvm::outs() << "operand #" << operand.getOperandNumber() << " = "
- // << operand.get() << "\n";
- // DBGS() << "Layout: " << consumerLayout << "\n";
- // }
TypeSwitch<Operation *>(&op)
.Case([&](mlir::RegionBranchTerminatorOpInterface branchTermOp) {
r = updateControlFlowOps(builder, branchTermOp,
- getXeGPULayoutForValue, layoutKind);
+ getXeGPULayoutForValue);
})
.Case([&](mlir::FunctionOpInterface funcOp) {
r = updateFunctionOpInterface(builder, funcOp,
@@ -1823,10 +1691,6 @@ LogicalResult xegpu::propagateLayouts(OpBuilder &builder, Operation *target,
}
return WalkResult::advance();
});
- // dump the op after update, print some headline
- llvm::dbgs() << "After layout propagation and conflict resolution:\n";
- op->dump();
-
if (walkResult.wasInterrupted())
return failure();
@@ -1839,17 +1703,6 @@ LogicalResult xegpu::resolveLayoutConflicts(Operation *target) {
}
void XeGPUPropagateLayoutPass::runOnOperation() {
- // Remove layout attributes from SCF ops
- getOperation()->walk([](Operation *op) {
- SmallVector<StringAttr> attrsToRemove;
- for (auto namedAttr : op->getDiscardableAttrs()) {
- if (isa<xegpu::DistributeLayoutAttr>(namedAttr.getValue()))
- attrsToRemove.push_back(namedAttr.getName());
- }
- for (auto attrName : attrsToRemove)
- op->removeDiscardableAttr(attrName);
- });
-
xegpu::LayoutKind layoutKind;
if (this->layoutKind == "lane") {
layoutKind = xegpu::LayoutKind::Lane;
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cp
new file mode 100644
index 0000000000000..e3227c7f5b149
--- /dev/null
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cp
@@ -0,0 +1,1744 @@
+//===- XeGPUSgToWiDistributeExperimental.cpp - XeGPU SG to WI Pass --------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/Index/IR/IndexDialect.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/Transforms/Patterns.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
+#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
+#include "mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h"
+#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
+#include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/Value.h"
+#include "mlir/IR/ValueRange.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/Support/LogicalResult.h"
+#include "llvm/Support/raw_ostream.h"
+#include <optional>
+
+namespace mlir {
+namespace xegpu {
+#define GEN_PASS_DEF_XEGPUSGTOWIDISTRIBUTEEXPERIMENTAL
+#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
+} // namespace xegpu
+} // namespace mlir
+
+using namespace mlir;
+
+#define DEBUG_TYPE "xegpu-sg-to-wi-distribute-experimental"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+
+namespace {
+
+/// Casts the vector value `v` to `expectedTy`.
+///
+/// * If `v` already has type `expectedTy`, it is returned unchanged.
+/// * If only the shape differs (same element type and same number of
+///   elements), a `vector.shape_cast` is emitted. shape_cast cannot change the
+///   element type, so both conditions are required.
+/// * Otherwise an `unrealized_conversion_cast` is emitted.
+static Value castValueTo(ConversionPatternRewriter &rewriter,
+                         TypedValue<VectorType> v, VectorType expectedTy) {
+  VectorType srcTy = v.getType();
+  if (srcTy == expectedTy)
+    return v;
+  // `v` is statically typed as a vector, so no isa<> check is needed; only the
+  // element type and the element count decide whether shape_cast is legal.
+  if (srcTy.getElementType() == expectedTy.getElementType() &&
+      srcTy.getNumElements() == expectedTy.getNumElements())
+    return vector::ShapeCastOp::create(rewriter, v.getLoc(), expectedTy, v);
+
+  // Else create an unrealized cast.
+  auto newOp = UnrealizedConversionCastOp::create(rewriter, v.getLoc(),
+                                                  expectedTy, ValueRange{v});
+  return newOp.getResult(0);
+}
+
+/// Returns true when a subgroup-level vector::MultiDimReductionOp is in the
+/// form this pass expects: its result carries a subgroup layout, it reduces
+/// exactly one dimension, and (for vector results) the result type can be
+/// distributed to lanes using that layout.
+static bool isValidSubgroupMultiReductionOp(vector::MultiDimReductionOp op) {
+  xegpu::DistributeLayoutAttr layout =
+      xegpu::getTemporaryLayout(op->getOpResult(0));
+  // A missing or non-subgroup layout disqualifies the op outright.
+  if (!layout || !layout.isForSubgroup())
+    return false;
+
+  const bool hasSingleReductionDim = op.getReductionDims().size() == 1;
+  Type resultType = op.getType();
+  // A scalar result (e.g. vector<32xf32> -> f32) needs no distribution check.
+  if (resultType.isIntOrFloat())
+    return hasSingleReductionDim;
+
+  auto resultVecTy = dyn_cast<VectorType>(resultType);
+  if (!resultVecTy)
+    return false;
+  // The vector result must be distributable with the result layout.
+  if (failed(getDistVecTypeBasedOnLaneLayout(layout, resultVecTy)))
+    return false;
+  return hasSingleReductionDim;
+}
+
+/// Returns true when a valid subgroup-level vector::MultiDimReductionOp is
+/// lane local, i.e. every workitem reduces only its own data. This holds when
+/// the result layout distributes the result vector across lanes, so the
+/// distributed result type differs from the original result type. A scalar
+/// result is never lane local: it needs a cross-lane reduction.
+static bool isReductionLaneLocal(vector::MultiDimReductionOp op) {
+  // Must be valid MultiDimReductionOp.
+  assert(isValidSubgroupMultiReductionOp(op) && "Expecting a valid subgroup "
+                                                "MultiDimReductionOp");
+  auto resTy = dyn_cast<VectorType>(op.getType());
+  // Scalar results are accepted by isValidSubgroupMultiReductionOp but have no
+  // result vector to distribute; bail out before handing a null type to the
+  // layout helper.
+  if (!resTy)
+    return false;
+  auto resLayout = xegpu::getTemporaryLayout(op->getOpResult(0));
+  // Validity implies the layout exists and distribution succeeds, but the
+  // assertion above is compiled out in release builds, so check explicitly.
+  if (!resLayout)
+    return false;
+  FailureOr<VectorType> resDistTy =
+      getDistVecTypeBasedOnLaneLayout(resLayout, resTy);
+  if (failed(resDistTy))
+    return false;
+  return resTy != *resDistTy;
+}
+
+/// Returns the indices of the dimensions whose extent differs between
+/// `originalType` and its distributed counterpart `distributedType`, i.e. the
+/// dimensions that were distributed. Both types must have the same rank.
+static SmallVector<int64_t> getDistributedDims(VectorType originalType,
+                                               VectorType distributedType) {
+  assert(originalType.getRank() == distributedType.getRank() &&
+         "original and distributed vector types must have the same rank");
+  SmallVector<int64_t> changedDims;
+  const int64_t rank = originalType.getRank();
+  for (int64_t dim = 0; dim < rank; ++dim)
+    if (originalType.getDimSize(dim) != distributedType.getDimSize(dim))
+      changedDims.push_back(dim);
+  return changedDims;
+}
+
+/// Rewrites a subgroup-level xegpu.create_nd_tdesc into its workitem-level
+/// form. The only change is that the result tensor descriptor type loses its
+/// layout attribute; operands and attributes are carried over unchanged.
+struct SgToWiCreateNdDesc : public OpConversionPattern<xegpu::CreateNdDescOp> {
+  using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(xegpu::CreateNdDescOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    xegpu::TensorDescType sgDescTy = op.getType();
+    // A descriptor without a layout has nothing to distribute.
+    if (!sgDescTy.getLayout())
+      return failure();
+
+    auto wiDescTy = sgDescTy.dropLayouts();
+    auto wiOp = xegpu::CreateNdDescOp::create(rewriter, op.getLoc(), wiDescTy,
+                                              op.getOperands(), op->getAttrs());
+    rewriter.replaceOp(op, wiOp.getResult());
+    return success();
+  }
+};
+
+/// Distributes a subgroup-level xegpu.load_nd to a workitem-level load_nd. The
+/// workitem-level load produces the vector type derived from the tensor
+/// descriptor; the result is then cast to the distributed type implied by the
+/// lane layout so that users see the original rank.
+struct SgToWiLoadNd : public OpConversionPattern<xegpu::LoadNdOp> {
+  using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(xegpu::LoadNdOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    xegpu::DistributeLayoutAttr layout = op.getAnchorLayout();
+    // If no layout, nothing to do.
+    if (!layout)
+      return failure();
+    // Check if the layout attached to the tensor descriptor is same as the
+    // anchor layout. Otherwise, this is a conflict.
+    if (op.getTensorDescType().getLayout() != layout)
+      return rewriter.notifyMatchFailure(
+          op, "conflicting layout attributes on tensor descriptor and anchor");
+    // requirePacked/requireTranspose take a plain LayoutAttr, but the anchor
+    // is a generic DistributeLayoutAttr. Check the kind here (before any IR is
+    // created) instead of using cast<>, which would assert on other kinds.
+    auto laneLayout = dyn_cast<xegpu::LayoutAttr>(layout);
+    if (!laneLayout)
+      return rewriter.notifyMatchFailure(
+          op, "anchor layout of LoadNdOp must be an xegpu::LayoutAttr");
+    auto uArch = getUArch(xegpu::getChipStr(op).value_or(""));
+    if (!uArch)
+      return rewriter.notifyMatchFailure(
+          op, "xegpu::LoadNdOp require target attribute attached to "
+              "determine transpose "
+              "requirement");
+    auto supportedWiResultTyOrFailure =
+        xegpu::getDistributedVectorType(op.getTensorDescType());
+    auto expectedWiResultTyOrFailure =
+        xegpu::getDistVecTypeBasedOnLaneLayout(layout, op.getType());
+    if (failed(supportedWiResultTyOrFailure))
+      return rewriter.notifyMatchFailure(
+          op, "unable to compute the workitem vector type for LoadNdOp");
+    if (failed(expectedWiResultTyOrFailure))
+      return rewriter.notifyMatchFailure(
+          op,
+          "unable to compute expected workitem vector type from lane layout");
+    auto newOp = xegpu::LoadNdOp::create(
+        rewriter, op.getLoc(), supportedWiResultTyOrFailure.value(),
+        adaptor.getTensorDesc(), op.getMixedOffsets(), op.getPackedAttr(),
+        op.getTransposeAttr(), op.getL1HintAttr(), op.getL2HintAttr(),
+        op.getL3HintAttr(), /**layout**/ nullptr);
+    // Set the packed attribute if the layout requires it.
+    newOp.setPacked(xegpu::requirePacked(laneLayout));
+    // Set the transpose attribute if the layout requires it.
+    if (xegpu::requireTranspose(laneLayout, uArch))
+      newOp.setTranspose(DenseI64ArrayAttr::get(rewriter.getContext(), {1, 0}));
+    rewriter.replaceOp(op, castValueTo(rewriter, newOp.getResult(),
+                                       expectedWiResultTyOrFailure.value()));
+    return success();
+  }
+};
+
+/// Distributes a subgroup-level xegpu.store_nd to a workitem-level store_nd.
+/// The workitem-level store expects the vector type derived from the tensor
+/// descriptor, so the incoming (already converted) value is cast to it first.
+struct SgToWiStoreNd : public OpConversionPattern<xegpu::StoreNdOp> {
+  using OpConversionPattern<xegpu::StoreNdOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(xegpu::StoreNdOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    xegpu::DistributeLayoutAttr anchor = op.getAnchorLayout();
+    // Stores without an anchor layout are left untouched.
+    if (!anchor)
+      return failure();
+
+    // Both the descriptor layout and the stored value's layout must agree with
+    // the anchor layout; any mismatch is a conflict.
+    if (op.getTensorDescType().getLayout() != anchor)
+      return rewriter.notifyMatchFailure(
+          op, "conflicting layout attributes on tensor descriptor and anchor");
+    if (xegpu::getDistributeLayoutAttr(op->getOpOperand(0)) != anchor)
+      return rewriter.notifyMatchFailure(
+          op, "conflicting layout attributes on value and anchor");
+
+    auto wiValueTy = xegpu::getDistributedVectorType(op.getTensorDescType());
+    if (failed(wiValueTy))
+      return rewriter.notifyMatchFailure(
+          op,
+          "unable to compute wi vector type for StoreNdOp value from tensor "
+          "descriptor");
+
+    auto sgValue = cast<TypedValue<VectorType>>(adaptor.getValue());
+    Value wiValue = castValueTo(rewriter, sgValue, wiValueTy.value());
+    xegpu::StoreNdOp::create(rewriter, op.getLoc(), wiValue,
+                             adaptor.getTensorDesc(), op.getMixedOffsets(),
+                             op.getL1HintAttr(), op.getL2HintAttr(),
+                             op.getL3HintAttr(), /**layout**/ nullptr);
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+/// Distributes a subgroup-level xegpu.dpas to a workitem-level dpas. All inputs
+/// and the output of the workitem-level dpas are 1D vectors, so casts convert
+/// the operands to, and the result from, that form.
+struct SgToWiDpas : public OpConversionPattern<xegpu::DpasOp> {
+  using OpConversionPattern<xegpu::DpasOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(xegpu::DpasOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // The A, B and CD layouts must all be present and be plain LayoutAttrs.
+    // dyn_cast_if_present yields null for a missing or differently-kinded
+    // attribute; cast<> would assert before the check below could run.
+    auto layoutA = dyn_cast_if_present<xegpu::LayoutAttr>(op.getLayoutAAttr());
+    auto layoutB = dyn_cast_if_present<xegpu::LayoutAttr>(op.getLayoutBAttr());
+    auto layoutCd =
+        dyn_cast_if_present<xegpu::LayoutAttr>(op.getLayoutCdAttr());
+    if (!layoutA || !layoutB || !layoutCd)
+      return failure();
+    auto wiResultTyOrFailure =
+        xegpu::getDistributedVectorType(op.getType(), layoutCd);
+    auto wiATypeOrFailure =
+        xegpu::getDistributedVectorType(op.getLhs().getType(), layoutA);
+    auto wiBTypeOrFailure =
+        xegpu::getDistributedVectorType(op.getRhs().getType(), layoutB);
+    auto expectedWiResultTyOrFailure =
+        xegpu::getDistVecTypeBasedOnLaneLayout(layoutCd, op.getType());
+    if (failed(wiResultTyOrFailure) || failed(wiATypeOrFailure) ||
+        failed(wiBTypeOrFailure))
+      return rewriter.notifyMatchFailure(
+          op, "failed to calculate supported workitem vector types for DpasOp "
+              "from layouts");
+    if (failed(expectedWiResultTyOrFailure))
+      return rewriter.notifyMatchFailure(
+          op, "unable to compute expected workitem vector type for DpasOp from "
+              "lane layout");
+
+    auto toWi = [&](Value v, VectorType wiTy) {
+      return castValueTo(rewriter, cast<TypedValue<VectorType>>(v), wiTy);
+    };
+    Value wiLhs = toWi(adaptor.getLhs(), wiATypeOrFailure.value());
+    Value wiRhs = toWi(adaptor.getRhs(), wiBTypeOrFailure.value());
+    // The accumulator operand is optional; cast it only when present so a dpas
+    // without accumulator does not cast a null value.
+    Value wiAcc;
+    if (Value acc = adaptor.getAcc())
+      wiAcc = toWi(acc, wiResultTyOrFailure.value());
+
+    auto newOp = xegpu::DpasOp::create(
+        rewriter, op->getLoc(), wiResultTyOrFailure.value(), wiLhs, wiRhs,
+        wiAcc, /** layoutA**/ nullptr,
+        /** layoutB**/ nullptr, /** layoutCd**/ nullptr);
+    // Cast the 1D workitem result to the type implied by the CD lane layout.
+    rewriter.replaceOp(op, castValueTo(rewriter, newOp.getResult(),
+                                       expectedWiResultTyOrFailure.value()));
+    return success();
+  }
+};
+
+/// Distributes elementwise ops with a single vector result to workitem level.
+/// The new op reuses the converted operands and all attributes of the original
+/// op except DistributeLayoutAttr ones, and its result type is the distributed
+/// type computed from the result's subgroup layout.
+struct SgToWiElementWise : public ConversionPattern {
+  // Forward the type converter to the base pattern; previously the parameter
+  // was accepted but silently dropped, so operands were not materialized to
+  // their converted types before matchAndRewrite.
+  SgToWiElementWise(TypeConverter &typeConverter, MLIRContext *ctx)
+      : ConversionPattern(typeConverter, MatchAnyOpTypeTag(), /*benefit=*/1,
+                          ctx) {}
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Only match ops with elementwise trait and single result.
+    if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1)
+      return failure();
+
+    auto resultType = dyn_cast<VectorType>(op->getResult(0).getType());
+    if (!resultType)
+      return rewriter.notifyMatchFailure(
+          op, "operation result is not a vector type");
+
+    xegpu::DistributeLayoutAttr layout =
+        xegpu::getTemporaryLayout(op->getOpResult(0));
+    if (!layout || !layout.isForSubgroup())
+      return rewriter.notifyMatchFailure(
+          op, "operation result does not have subgroup distribute layout");
+
+    auto wiShapeOrFailure =
+        xegpu::getDistVecTypeBasedOnLaneLayout(layout, resultType);
+
+    if (failed(wiShapeOrFailure))
+      return rewriter.notifyMatchFailure(
+          op, "unable to compute workitem vector type from the layout");
+
+    VectorType newResultType = wiShapeOrFailure.value();
+    OperationState state(op->getLoc(), op->getName());
+    state.addOperands(operands);
+    state.addTypes(newResultType);
+    // Copy all attributes except for DistributeLayoutAttr.
+    for (auto attr : op->getAttrs()) {
+      if (!isa<xegpu::DistributeLayoutAttr>(attr.getValue()))
+        state.addAttribute(attr.getName(), attr.getValue());
+    }
+    Operation *newOp = rewriter.create(state);
+
+    rewriter.replaceOp(op, newOp->getResult(0));
+    return success();
+  }
+};
+
+/// Distributes a subgroup-level splat vector arith.constant to workitem level
+/// by re-creating the same splat value with the distributed vector type.
+struct SgToWiArithConstant : public OpConversionPattern<arith::ConstantOp> {
+  using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto sgVecTy = dyn_cast<VectorType>(op.getType());
+    if (!sgVecTy)
+      return failure();
+
+    // Only splat constants can be distributed by just shrinking the type.
+    auto splat = dyn_cast<SplatElementsAttr>(op.getValue());
+    if (!splat)
+      return rewriter.notifyMatchFailure(
+          op, "only dense splat vector constants are supported");
+
+    xegpu::DistributeLayoutAttr layout =
+        xegpu::getTemporaryLayout(llvm::cast<OpResult>(op.getResult()));
+    if (!layout || !layout.isForSubgroup())
+      return rewriter.notifyMatchFailure(
+          op, "operation result does not have subgroup distribute layout");
+
+    FailureOr<VectorType> wiVecTy =
+        xegpu::getDistVecTypeBasedOnLaneLayout(layout, sgVecTy);
+    if (failed(wiVecTy))
+      return rewriter.notifyMatchFailure(
+          op, "unable to compute workitem vector type from the layout");
+
+    Attribute splatElement = splat.getSplatValue<Attribute>();
+    auto wiValue = DenseElementsAttr::get(*wiVecTy, splatElement);
+    auto wiConst =
+        arith::ConstantOp::create(rewriter, op.getLoc(), *wiVecTy, wiValue);
+    rewriter.replaceOp(op, wiConst.getResult());
+    return success();
+  }
+};
+
+/// Distributes a subgroup-level xegpu.prefetch_nd to workitem level by
+/// re-creating it on the converted tensor descriptor without a layout.
+struct SgToWiPrefetchNd : public OpConversionPattern<xegpu::PrefetchNdOp> {
+  using OpConversionPattern<xegpu::PrefetchNdOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(xegpu::PrefetchNdOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Prefetches without an anchor layout are left untouched.
+    if (!op.getAnchorLayout())
+      return failure();
+
+    Location loc = op.getLoc();
+    xegpu::PrefetchNdOp::create(rewriter, loc, adaptor.getTensorDesc(),
+                                op.getMixedOffsets(), op.getL1HintAttr(),
+                                op.getL2HintAttr(), op.getL3HintAttr(),
+                                /**layout**/ nullptr);
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+/// Distributes a subgroup-level xegpu.load (LoadGatherOp) to workitem level.
+/// The workitem-level load always works on 1D vectors: the distributed offsets
+/// and mask are flattened to 1D, and the 1D result is cast back to the
+/// distributed result type. Dimensions in front of the effective vector rank
+/// (1 without chunk size, 2 with chunk size) must be unit.
+///
+/// Example 1 (1D, no chunk size):
+///   layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>
+///   %mask = producer_op : vector<16xi1>
+///   %offset = producer_op : vector<16xindex>
+///   %0 = xegpu.load %src[%offset], %mask : memref<256xf16>,
+///     vector<16xindex>, vector<16xi1> -> vector<16xf16>
+/// Distributed to:
+///   %mask = producer_op : vector<1xi1>
+///   %offset = producer_op : vector<1xindex>
+///   %0 = xegpu.load %src[%offset], %mask : memref<256xf16>,
+///     vector<1xindex>, vector<1xi1> -> vector<1xf16>
+///
+/// Example 2 (2D with chunk size, same mask & offset):
+///   layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>
+///   %0 = xegpu.load %src[%offset], %mask <{chunk_size=8}> :
+///     memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
+/// Distributed to:
+///   %0 = xegpu.load %src[%offset], %mask <{chunk_size=8}> :
+///     memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<8xf16>
+///
+/// Example 3 (3D with leading unit dims):
+///   layout = #xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>
+///   %mask = producer_op : vector<1x1x16xi1>
+///   %offset = producer_op : vector<1x1x16xindex>
+///   %0 = xegpu.load %src[%offset], %mask : memref<256xf16>,
+///     vector<1x1x16xindex>, vector<1x1x16xi1> -> vector<1x1x16xf16>
+/// Distributed to:
+///   %mask = producer_op : vector<1x1x1xi1>
+///   %offset = producer_op : vector<1x1x1xindex>
+///   %0 = xegpu.load %src[%offset], %mask : memref<256xf16>,
+///     vector<1xindex>, vector<1xi1> -> vector<1xf16>
+struct SgToWiLoadGather : public OpConversionPattern<xegpu::LoadGatherOp> {
+  using OpConversionPattern<xegpu::LoadGatherOp>::OpConversionPattern;
+
+  /// Casts a vector value to the 1D vector with the same element type and the
+  /// same number of elements.
+  static Value flattenTo1D(ConversionPatternRewriter &rewriter, Value v) {
+    auto vecTy = cast<VectorType>(v.getType());
+    VectorType flatTy =
+        VectorType::get({vecTy.getNumElements()}, vecTy.getElementType());
+    return castValueTo(rewriter, cast<TypedValue<VectorType>>(v), flatTy);
+  }
+
+  LogicalResult
+  matchAndRewrite(xegpu::LoadGatherOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    xegpu::DistributeLayoutAttr layout = op.getAnchorLayout();
+    if (!layout)
+      return failure();
+    VectorType sgResultTy = op.getValueType();
+    if (!sgResultTy)
+      return failure();
+
+    // Every dimension in front of the effective vector rank must be 1.
+    int chunkSize = op.getChunkSize().value_or(1);
+    const int effectiveVecRank = chunkSize == 1 ? 1 : 2;
+    ArrayRef<int64_t> leadingDims = sgResultTy.getShape().take_front(
+        sgResultTy.getRank() - effectiveVecRank);
+    if (!llvm::all_of(leadingDims, [](int64_t d) { return d == 1; }))
+      return rewriter.notifyMatchFailure(
+          op, "Only unit dimensions allowed for the leading "
+              "dimensions of the load vector!");
+
+    FailureOr<VectorType> wiResultTy =
+        xegpu::getDistVecTypeBasedOnLaneLayout(layout, sgResultTy);
+    if (failed(wiResultTy))
+      return rewriter.notifyMatchFailure(
+          op,
+          "unable to compute expected workitem vector type from lane layout");
+    VectorType wiResultTy1D = VectorType::get({wiResultTy->getNumElements()},
+                                              wiResultTy->getElementType());
+
+    Value wiOffsets = flattenTo1D(rewriter, adaptor.getOffsets());
+    Value wiMask = flattenTo1D(rewriter, adaptor.getMask());
+    auto wiLoad = xegpu::LoadGatherOp::create(
+        rewriter, op.getLoc(), wiResultTy1D, adaptor.getSource(), wiOffsets,
+        wiMask, op.getChunkSizeAttr(), op.getL1HintAttr(), op.getL2HintAttr(),
+        op.getL3HintAttr(), /*layout=*/nullptr);
+
+    // castValueTo returns its input unchanged when the distributed result is
+    // already 1D, so no separate type comparison is needed here.
+    Value wiResult = castValueTo(
+        rewriter, cast<TypedValue<VectorType>>(wiLoad->getResult(0)),
+        *wiResultTy);
+    rewriter.replaceOp(op, wiResult);
+    return success();
+  }
+};
+
+/// Distributes a subgroup-level vector.reduction to workitem level. Each
+/// workitem's slice is reduced and the partial results are combined across the
+/// subgroup with shuffles (via xegpu::subgroupReduction), after which every
+/// workitem holds the full result. An accumulator, if any, is folded in last.
+struct SgToWiVectorReduction : public OpConversionPattern<vector::ReductionOp> {
+  using OpConversionPattern<vector::ReductionOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::ReductionOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto layout = xegpu::getDistributeLayoutAttr(op.getVector());
+    // Reductions without a subgroup layout are left untouched.
+    if (!layout || !layout.isForSubgroup())
+      return failure();
+
+    VectorType srcVecType = op.getSourceVectorType();
+    if (srcVecType.getRank() != 1)
+      return rewriter.notifyMatchFailure(
+          op, "Only rank 1 reductions can be distributed.");
+    if (layout.getRank() != srcVecType.getRank())
+      return rewriter.notifyMatchFailure(
+          op, "Layout rank does not match vector rank.");
+
+    const int64_t sgSize = layout.getEffectiveLaneLayoutAsInt()[0];
+    const uArch *uArch = getUArch(xegpu::getChipStr(op).value_or(""));
+    if (!uArch)
+      return rewriter.notifyMatchFailure(
+          op, "xegpu::ReductionOp require target attribute attached to "
+              "determine subgroup size");
+
+    // The lane layout must match the hardware subgroup size, and the vector
+    // length must be a multiple of it.
+    const bool sizeMatchesHw = sgSize == uArch->getSubgroupSize();
+    const bool evenlyDivided = srcVecType.getShape()[0] % sgSize == 0;
+    if (!sizeMatchesHw || !evenlyDivided)
+      return rewriter.notifyMatchFailure(op,
+                                         "Invalid layout or reduction vector "
+                                         "dimension must match subgroup size.");
+
+    if (!op.getType().isIntOrFloat())
+      return rewriter.notifyMatchFailure(
+          op, "Reduction distribution currently only supports floats and "
+              "integer types.");
+
+    Value reduced =
+        xegpu::subgroupReduction(op.getLoc(), rewriter, adaptor.getVector(),
+                                 op.getKind(), sgSize);
+    if (Value acc = adaptor.getAcc())
+      reduced = vector::makeArithReduction(rewriter, op.getLoc(), op.getKind(),
+                                           reduced, acc);
+    rewriter.replaceOp(op, reduced);
+    return success();
+  }
+};
+
+/// Distributes a subgroup-level vector.multi_reduction with a single reduction
+/// dimension to workitem level. Three cases are handled:
+///  - scalar result: local reduction followed by cross-lane shuffles
+///    (xegpu::subgroupReduction);
+///  - lane-local reduction (see isReductionLaneLocal): each workitem runs its
+///    own multi_reduction on its distributed slice;
+///  - otherwise: the reduction is lowered to cross-lane shuffles
+///    (xegpu::lowerCrossLaneReductionToShuffles).
+/// For sources of rank > 2, all leading dimensions must be unit.
+struct SgToWiMultiDimReduction
+    : public OpConversionPattern<vector::MultiDimReductionOp> {
+  using OpConversionPattern<vector::MultiDimReductionOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::MultiDimReductionOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    ArrayRef<int64_t> reductionDims = op.getReductionDims();
+    // Report unsupported forms as match failures instead of asserting, so
+    // release builds do not index reductionDims blindly and the conversion
+    // driver can diagnose the op.
+    if (reductionDims.size() != 1)
+      return rewriter.notifyMatchFailure(
+          op, "expecting single reduction dimension for subgroup multi "
+              "reduction op");
+    const int64_t reductionDim = reductionDims[0];
+    VectorType sourceType = op.getSourceVectorType();
+    const int64_t reductionDimSize = sourceType.getShape()[reductionDim];
+
+    // For rank > 2, ensure leading dimensions are unit.
+    int64_t rank = sourceType.getRank();
+    if (rank > 2 && llvm::any_of(sourceType.getShape().take_front(rank - 2),
+                                 [](int64_t d) { return d != 1; }))
+      return rewriter.notifyMatchFailure(
+          op, "only unit leading dimensions are supported for "
+              "multi_reduction with rank > 2");
+
+    // Scalar result: reduce locally, then combine across lanes with shuffles.
+    if (op.getType().isIntOrFloat()) {
+      Value result =
+          xegpu::subgroupReduction(op.getLoc(), rewriter, adaptor.getSource(),
+                                   op.getKind(), reductionDimSize);
+      if (Value acc = adaptor.getAcc())
+        result = vector::makeArithReduction(rewriter, op.getLoc(),
+                                            op.getKind(), result, acc);
+      rewriter.replaceOp(op, result);
+      return success();
+    }
+
+    // Vector result: isReductionLaneLocal requires a valid result layout, so
+    // reject ops without one (before creating any IR) instead of tripping its
+    // assertion.
+    if (!isValidSubgroupMultiReductionOp(op))
+      return rewriter.notifyMatchFailure(
+          op, "multi_reduction result lacks a valid subgroup layout");
+
+    Value result;
+    if (isReductionLaneLocal(op)) {
+      // Lane local: each workitem reduces its own distributed slice and
+      // produces the distributed result type.
+      auto resLayout = xegpu::getTemporaryLayout(op->getOpResult(0));
+      auto resVecTy = cast<VectorType>(op.getType());
+      FailureOr<VectorType> resDistVecTy =
+          getDistVecTypeBasedOnLaneLayout(resLayout, resVecTy);
+      result = vector::MultiDimReductionOp::create(
+          rewriter, op.getLoc(), resDistVecTy.value(), op.getKind(),
+          adaptor.getSource(), adaptor.getAcc(), reductionDims);
+    } else {
+      // Otherwise the reduction crosses lanes: lower it to shuffles.
+      result = xegpu::lowerCrossLaneReductionToShuffles(
+          cast<TypedValue<VectorType>>(adaptor.getSource()),
+          cast<TypedValue<VectorType>>(adaptor.getAcc()), op.getKind(),
+          reductionDim, reductionDimSize, op.getLoc(), rewriter);
+    }
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
+/// Computes the per-workitem coordinates for a matrix op that does not use
+/// subgroup_block_io. The lane id selects this workitem's coordinates from
+/// `layout`, and `origOffsets` are added to them right-aligned. Returns an
+/// empty vector when the layout cannot compute coordinates.
+static SmallVector<Value> computeDistributedCoordsForMatrixOp(
+    ConversionPatternRewriter &rewriter, Location loc,
+    xegpu::DistributeLayoutAttr layout, ArrayRef<int64_t> payloadShape,
+    ValueRange origOffsets) {
+  Value laneId = gpu::LaneIdOp::create(rewriter, loc, rewriter.getIndexType(),
+                                       /*upperBound=*/mlir::IntegerAttr());
+  auto coordSets =
+      layout.computeDistributedCoords(rewriter, loc, laneId, payloadShape);
+  if (failed(coordSets))
+    return {};
+  assert(coordSets->size() == 1 && "Expected one set of distributed offsets");
+
+  SmallVector<OpFoldResult> summed = xegpu::addWithRightAligned(
+      rewriter, loc, getAsOpFoldResult((*coordSets)[0]),
+      getAsOpFoldResult(origOffsets));
+  SmallVector<Value> coords;
+  coords.reserve(summed.size());
+  for (OpFoldResult ofr : summed)
+    coords.push_back(cast<Value>(ofr));
+  return coords;
+}
+
+/// Distributes a subgroup-level xegpu.load_matrix to workitem level. The
+/// payload type is distributed with the op's layout. Unless the op uses
+/// subgroup_block_io, each workitem also computes its own coordinates from the
+/// layout and lane id. All offsets are then passed as dynamic operands.
+struct SgToWiLoadMatrix : public OpConversionPattern<xegpu::LoadMatrixOp> {
+  using OpConversionPattern<xegpu::LoadMatrixOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(xegpu::LoadMatrixOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto layout = op.getLayoutAttr();
+    // Loads without a layout are left untouched.
+    if (!layout)
+      return failure();
+
+    auto sgPayloadTy = dyn_cast<VectorType>(op.getResult().getType());
+    if (!sgPayloadTy)
+      return rewriter.notifyMatchFailure(
+          op, "the matrix op payload must be a vector type");
+
+    Location loc = op.getLoc();
+    auto mixedOffsets = op.getMixedOffsets();
+    if (mixedOffsets.empty())
+      return rewriter.notifyMatchFailure(op, "the load op must have offsets");
+
+    FailureOr<VectorType> wiPayloadTy =
+        getDistVecTypeBasedOnLaneLayout(layout, sgPayloadTy);
+    if (failed(wiPayloadTy))
+      return rewriter.notifyMatchFailure(
+          op, "Failed to distribute matrix op payload based on layout.");
+
+    SmallVector<Value> sgOffsets =
+        vector::getAsValues(rewriter, loc, mixedOffsets);
+    // With subgroup_block_io the subgroup offsets are used as-is; otherwise
+    // each workitem derives its own coordinates from the layout.
+    SmallVector<Value> wiCoords;
+    if (op.getSubgroupBlockIoAttr()) {
+      wiCoords = sgOffsets;
+    } else {
+      wiCoords = computeDistributedCoordsForMatrixOp(
+          rewriter, loc, layout, sgPayloadTy.getShape(), sgOffsets);
+      if (wiCoords.empty())
+        return rewriter.notifyMatchFailure(
+            op, "Failed to compute distributed coordinates.");
+    }
+
+    // Every coordinate is now an SSA value, so all constant offset slots are
+    // marked dynamic.
+    SmallVector<int64_t> dynamicOffsets(op.getConstOffsets().size(),
+                                        ShapedType::kDynamic);
+    auto wiOp = xegpu::LoadMatrixOp::create(
+        rewriter, loc, *wiPayloadTy, adaptor.getMemDesc(), ValueRange(wiCoords),
+        rewriter.getDenseI64ArrayAttr(dynamicOffsets),
+        op.getSubgroupBlockIoAttr(), xegpu::DistributeLayoutAttr{});
+    rewriter.replaceOp(op, wiOp.getResult());
+    return success();
+  }
+};
+
+/// Distributes a subgroup-level vector.transpose to workitem level. The source
+/// and result layouts must be transposes of each other at lane granularity
+/// under the op's permutation. The transpose is re-created on the converted
+/// source and its result is cast to the distributed result type if needed.
+struct SgToWiVectorTranspose : public OpConversionPattern<vector::TransposeOp> {
+  using OpConversionPattern<vector::TransposeOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::TransposeOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    xegpu::DistributeLayoutAttr srcLayout =
+        xegpu::getTemporaryLayout(op->getOpOperand(0));
+    xegpu::DistributeLayoutAttr dstLayout =
+        xegpu::getTemporaryLayout(op->getOpResult(0));
+    if (!srcLayout || !dstLayout)
+      return rewriter.notifyMatchFailure(
+          op, "the source or result vector of the transpose op lacks layout "
+              "attribute");
+
+    ArrayRef<int64_t> permutation = op.getPermutation();
+    const bool layoutsAgree = dstLayout.isTransposeOf(srcLayout, permutation,
+                                                      xegpu::LayoutKind::Lane);
+    if (!layoutsAgree)
+      return rewriter.notifyMatchFailure(
+          op, "the source or result vector layouts must be transposes of "
+              "each other");
+
+    FailureOr<VectorType> wiResultTy =
+        getDistVecTypeBasedOnLaneLayout(dstLayout, op.getResultVectorType());
+    if (failed(wiResultTy))
+      return rewriter.notifyMatchFailure(
+          op, "Failed to distribute the result vector type in "
+              "vector::Transpose op");
+
+    auto wiTranspose = vector::TransposeOp::create(rewriter, op.getLoc(),
+                                                   adaptor.getVector(),
+                                                   permutation);
+    Value wiResult =
+        castValueTo(rewriter, wiTranspose.getResult(), *wiResultTy);
+    rewriter.replaceOp(op, wiResult);
+    return success();
+  }
+};
+
+/// Distributes a subgroup-level vector.bitcast to workitem level. A bitcast
+/// only affects the innermost dimension, so the result type is distributed
+/// with the result layout and the bitcast is re-created on the converted
+/// source.
+struct SgToWiVectorBitcast : public OpConversionPattern<vector::BitCastOp> {
+  using OpConversionPattern<vector::BitCastOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::BitCastOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    xegpu::DistributeLayoutAttr dstLayout =
+        xegpu::getTemporaryLayout(op->getOpResult(0));
+    if (!dstLayout)
+      return rewriter.notifyMatchFailure(
+          op, "result vector of the bitcast op lacks layout attribute");
+
+    FailureOr<VectorType> wiResultTy =
+        getDistVecTypeBasedOnLaneLayout(dstLayout, op.getResultVectorType());
+    if (failed(wiResultTy))
+      return rewriter.notifyMatchFailure(
+          op, "Failed to distribute the result vector type in "
+              "vector::BitCast op");
+
+    auto wiBitcast = vector::BitCastOp::create(rewriter, op.getLoc(),
+                                               *wiResultTy,
+                                               adaptor.getSource());
+    rewriter.replaceOp(op, wiBitcast.getResult());
+    return success();
+  }
+};
+
+/// Lowers a subgroup-level vector.create_mask / vector.constant_mask to the
+/// workitem level.
+///
+/// `computeDistributedCoords()` yields, for the current lane, the coordinates
+/// of every element that lane owns. Each bit of the distributed mask is set
+/// iff every coordinate is strictly below (`arith.cmpi slt`) the matching
+/// mask bound:
+///   bit = (coord[0] < bound[0]) && ... && (coord[r-1] < bound[r-1])
+///
+/// Example (1D):
+///   layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>
+///   %mask = vector.create_mask %m0 : vector<16xi1>
+/// Lane k owns coord = [k], producing:
+///   %in_bounds = arith.cmpi slt, %coord, %m0 → i1
+///   %mask = vector.broadcast %in_bounds : i1 to vector<1xi1>
+///
+/// Example (2D):
+///   layout = #xegpu.layout<lane_layout = [8, 2], lane_data = [1, 1]>
+///   %mask = vector.create_mask %m0, %m1 : vector<8x4xi1>
+/// Each WI owns a 1x2 slice, i.e. coords [[r0, c0], [r0, c1]], and each bit
+/// is (r < m0) && (c < m1):
+///   %mask = vector.from_elements %bit0, %bit1 : vector<1x2xi1>
+template <typename OpType,
+          typename = std::enable_if_t<llvm::is_one_of<
+              OpType, vector::CreateMaskOp, vector::ConstantMaskOp>::value>>
+struct SgToWiCreateMask : public OpConversionPattern<OpType> {
+  using OpConversionPattern<OpType>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    xegpu::DistributeLayoutAttr layout =
+        xegpu::getTemporaryLayout(op->getOpResult(0));
+    if (!layout || !layout.isForSubgroup())
+      return rewriter.notifyMatchFailure(
+          op, "operation result does not have subgroup distribute layout");
+
+    VectorType sgMaskTy = op.getType();
+    FailureOr<VectorType> wiMaskTyOrFailure =
+        getDistVecTypeBasedOnLaneLayout(layout, sgMaskTy);
+    if (failed(wiMaskTyOrFailure))
+      return rewriter.notifyMatchFailure(
+          op, "unable to compute workitem vector type from the layout");
+    VectorType wiMaskTy = *wiMaskTyOrFailure;
+    Location loc = op.getLoc();
+
+    // Upper bound of the mask along each dimension, as an SSA value.
+    SmallVector<Value> bounds;
+    if constexpr (std::is_same_v<OpType, vector::CreateMaskOp>) {
+      llvm::append_range(bounds, op.getOperands());
+    } else {
+      for (int64_t dimSize : op.getMaskDimSizesAttr().asArrayRef())
+        bounds.push_back(
+            arith::ConstantIndexOp::create(rewriter, loc, dimSize).getResult());
+    }
+
+    // Coordinates (in the subgroup-level shape) of every element this lane
+    // owns.
+    Value laneId = gpu::LaneIdOp::create(rewriter, loc, rewriter.getIndexType(),
+                                         /*upperBound=*/mlir::IntegerAttr());
+    auto ownedCoords = layout.computeDistributedCoords(rewriter, loc, laneId,
+                                                       sgMaskTy.getShape());
+    if (failed(ownedCoords))
+      return rewriter.notifyMatchFailure(
+          op, "failed to compute distributed coordinates from layout");
+
+    int64_t numWiElems = wiMaskTy.getNumElements();
+    assert(static_cast<int64_t>(ownedCoords->size()) == numWiElems &&
+           "number of coordinate sets must match number of distributed "
+           "elements");
+
+    // AND together the per-dimension bound checks of each owned element.
+    Value trueVal =
+        arith::ConstantIntOp::create(rewriter, loc, /*value=*/1, /*width=*/1);
+    SmallVector<Value> bits;
+    bits.reserve(numWiElems);
+    for (const auto &coord : *ownedCoords) {
+      Value bit = trueVal;
+      for (size_t d = 0, e = coord.size(); d != e; ++d) {
+        Value lt = arith::CmpIOp::create(
+            rewriter, loc, arith::CmpIPredicate::slt, coord[d], bounds[d]);
+        bit = arith::AndIOp::create(rewriter, loc, bit, lt);
+      }
+      bits.push_back(bit);
+    }
+
+    // A lone bit is splatted; otherwise the vector is assembled element-wise.
+    Value wiMask;
+    if (numWiElems != 1)
+      wiMask = vector::FromElementsOp::create(rewriter, loc, wiMaskTy, bits);
+    else
+      wiMask =
+          vector::BroadcastOp::create(rewriter, loc, wiMaskTy, bits.front());
+    rewriter.replaceOp(op, wiMask);
+    return success();
+  }
+};
+
+/// Lowers a subgroup-level xegpu.store_matrix to the workitem level.
+/// The payload is cast to its per-lane vector type. Unless the op uses
+/// subgroup block IO, the offsets are rewritten into per-lane coordinates via
+/// computeDistributedCoordsForMatrixOp. In both cases all offsets are passed
+/// as SSA values, and the new op is created without a layout.
+struct SgToWiStoreMatrix : public OpConversionPattern<xegpu::StoreMatrixOp> {
+  using OpConversionPattern<xegpu::StoreMatrixOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(xegpu::StoreMatrixOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Without a layout there is nothing to distribute.
+    auto layout = op.getLayoutAttr();
+    if (!layout)
+      return failure();
+
+    auto sgDataTy = dyn_cast<VectorType>(op.getData().getType());
+    if (!sgDataTy)
+      return rewriter.notifyMatchFailure(
+          op, "the matrix op payload must be a vector type");
+
+    Location loc = op.getLoc();
+    auto mixedOffsets = op.getMixedOffsets();
+    if (mixedOffsets.empty())
+      return rewriter.notifyMatchFailure(op, "the store op must have offsets");
+
+    FailureOr<VectorType> wiDataTy =
+        getDistVecTypeBasedOnLaneLayout(layout, sgDataTy);
+    if (failed(wiDataTy))
+      return rewriter.notifyMatchFailure(
+          op, "Failed to distribute matrix op payload based on layout.");
+
+    SmallVector<Value> offsetVals =
+        vector::getAsValues(rewriter, loc, mixedOffsets);
+    // Per-lane coordinates are only needed when not using subgroup block IO.
+    SmallVector<Value> wiCoords = offsetVals;
+    if (!op.getSubgroupBlockIoAttr()) {
+      wiCoords = computeDistributedCoordsForMatrixOp(
+          rewriter, loc, layout, sgDataTy.getShape(), offsetVals);
+      if (wiCoords.empty())
+        return rewriter.notifyMatchFailure(
+            op, "Failed to compute distributed coordinates.");
+    }
+
+    // Every offset is now an SSA value, so each constant slot is dynamic.
+    DenseI64ArrayAttr dynamicOffsets =
+        rewriter.getDenseI64ArrayAttr(SmallVector<int64_t>(
+            op.getConstOffsets().size(), ShapedType::kDynamic));
+
+    Value wiData =
+        castValueTo(rewriter, cast<TypedValue<VectorType>>(adaptor.getData()),
+                    *wiDataTy);
+    xegpu::StoreMatrixOp::create(rewriter, loc, TypeRange{}, wiData,
+                                 adaptor.getMemDesc(), ValueRange(wiCoords),
+                                 dynamicOffsets, op.getSubgroupBlockIoAttr(),
+                                 xegpu::DistributeLayoutAttr{});
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+/// Lowers a subgroup-level xegpu.store (StoreScatterOp) to the workitem level.
+/// Every dimension except the innermost one (no chunk) or two (chunked) must
+/// be unit-sized. The distributed payload, offsets and mask are all flattened
+/// to 1-D vectors before the per-lane store is created without a layout.
+///
+/// Example 1 (1D, no chunk size):
+///   layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>
+///   %mask = producer_op : vector<16xi1>
+///   %offset = producer_op : vector<16xindex>
+///   xegpu.store %payload, %src[%offset], %mask : vector<16xf16>,
+///     memref<256xf16>, vector<16xindex>, vector<16xi1>
+/// Distributed to:
+///   %mask = producer_op : vector<1xi1>
+///   %offset = producer_op : vector<1xindex>
+///   xegpu.store %payload, %src[%offset], %mask : vector<1xf16>,
+///     memref<256xf16>, vector<1xindex>, vector<1xi1>
+///
+/// Example 2 (2D with chunk size, same mask & offset):
+///   layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>
+///   xegpu.store %payload, %src[%offset], %mask <{chunk_size=8}> :
+///     vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
+/// Distributed to:
+///   xegpu.store %payload, %src[%offset], %mask <{chunk_size=8}> :
+///     vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
+///
+/// Example 3 (3D with leading unit dims):
+///   layout = #xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>
+///   %mask = producer_op : vector<1x1x16xi1>
+///   %offset = producer_op : vector<1x1x16xindex>
+///   xegpu.store %payload, %src[%offset], %mask : vector<1x1x16xf16>,
+///     memref<256xf16>, vector<1x1x16xindex>, vector<1x1x16xi1>
+/// Distributed to:
+///   %mask = producer_op : vector<1x1x1xi1>
+///   %offset = producer_op : vector<1x1x1xindex>
+///   xegpu.store %payload, %src[%offset], %mask : vector<1xf16>,
+///     memref<256xf16>, vector<1xindex>, vector<1xi1>
+struct SgToWiStoreScatter : public OpConversionPattern<xegpu::StoreScatterOp> {
+  using OpConversionPattern<xegpu::StoreScatterOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(xegpu::StoreScatterOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    xegpu::DistributeLayoutAttr layout = op.getAnchorLayout();
+    if (!layout)
+      return failure();
+    VectorType sgValueTy = op.getValueType();
+    if (!sgValueTy)
+      return failure();
+
+    // All dims outside the innermost one (no chunk) or two (chunked) must be
+    // unit-sized.
+    int chunkSize = op.getChunkSize().value_or(1);
+    int innerRank = chunkSize == 1 ? 1 : 2;
+    ArrayRef<int64_t> leadingDims =
+        sgValueTy.getShape().take_front(sgValueTy.getRank() - innerRank);
+    if (!llvm::all_of(leadingDims, [](int64_t d) { return d == 1; }))
+      return rewriter.notifyMatchFailure(
+          op, "Only unit dimensions allowed for the leading "
+              "dimensions of the store vector!");
+
+    auto wiValueTyOrFailure =
+        xegpu::getDistVecTypeBasedOnLaneLayout(layout, sgValueTy);
+    if (failed(wiValueTyOrFailure))
+      return rewriter.notifyMatchFailure(
+          op,
+          "unable to compute expected workitem vector type from lane layout");
+
+    // 1-D vector type holding the same number of elements as `ty`.
+    auto flatTypeOf = [](VectorType ty) {
+      return VectorType::get({ty.getNumElements()}, ty.getElementType());
+    };
+
+    VectorType wiValueTy1D = flatTypeOf(*wiValueTyOrFailure);
+    Value wiValue = adaptor.getValue();
+    if (wiValue.getType() != wiValueTy1D)
+      wiValue = castValueTo(rewriter, cast<TypedValue<VectorType>>(wiValue),
+                            wiValueTy1D);
+
+    // Offsets and mask are flattened too so they line up with the payload.
+    Value wiOffsets = adaptor.getOffsets();
+    wiOffsets = castValueTo(rewriter, cast<TypedValue<VectorType>>(wiOffsets),
+                            flatTypeOf(cast<VectorType>(wiOffsets.getType())));
+    Value wiMask = adaptor.getMask();
+    wiMask = castValueTo(rewriter, cast<TypedValue<VectorType>>(wiMask),
+                         flatTypeOf(cast<VectorType>(wiMask.getType())));
+
+    xegpu::StoreScatterOp::create(rewriter, op.getLoc(), wiValue,
+                                  adaptor.getDest(), wiOffsets, wiMask,
+                                  op.getChunkSizeAttr(), op.getL1HintAttr(),
+                                  op.getL2HintAttr(), op.getL3HintAttr(),
+                                  /*layout=*/nullptr);
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+/// Distribute a vector::StepOp to workitem-level.
+/// The result layout must have exactly 1 effective lane dimension (the result
+/// of vector.step is 1-D). This is verified, and the pattern fails to match
+/// otherwise. The vector::StepOp is resolved completely: for every
+/// lane_data-sized block owned by the lane, the block's start coordinate and
+/// its consecutive successors are materialized and packed with
+/// vector.from_elements.
+struct SgToWiVectorStep : public OpConversionPattern<vector::StepOp> {
+  using OpConversionPattern<vector::StepOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::StepOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    xegpu::DistributeLayoutAttr resultLayout =
+        xegpu::getTemporaryLayout(op->getResult(0));
+    if (!resultLayout || !resultLayout.isForSubgroup())
+      return rewriter.notifyMatchFailure(
+          op, "the result vector of the step op lacks subgroup layout");
+
+    auto loc = op.getLoc();
+    auto stepResultVecTy = op.getResult().getType();
+    auto wiShapeOrFailure =
+        xegpu::getDistVecTypeBasedOnLaneLayout(resultLayout, stepResultVecTy);
+    if (failed(wiShapeOrFailure))
+      return rewriter.notifyMatchFailure(
+          op, "unable to compute workitem vector type from the layout");
+    VectorType newVecTy = wiShapeOrFailure.value();
+
+    // lane_data[0] is read below, so the single-lane-dimension precondition
+    // must hold. Check it before any IR is created.
+    auto laneData = resultLayout.getEffectiveLaneDataAsInt();
+    if (laneData.size() != 1)
+      return rewriter.notifyMatchFailure(
+          op, "expecting exactly one effective lane dimension for step op");
+    int64_t laneDataBlockLength = laneData[0];
+
+    Value laneId = gpu::LaneIdOp::create(rewriter, loc, rewriter.getIndexType(),
+                                         /*upperBound=*/mlir::IntegerAttr());
+    auto laneDataBlockCoords = resultLayout.computeDistributedCoords(
+        rewriter, loc, laneId, stepResultVecTy.getShape());
+    if (failed(laneDataBlockCoords))
+      return rewriter.notifyMatchFailure(
+          op, "failed to compute lane data block coordinates");
+
+    const auto &laneDataBlockCoordsVec = laneDataBlockCoords.value();
+    assert(static_cast<int64_t>(laneDataBlockCoordsVec.size()) ==
+           newVecTy.getNumElements() / laneDataBlockLength);
+    SmallVector<Value> stepVals;
+    // For each lane_data block, reconstruct its sub-range from the range of
+    // the SG-level vector.step. Example:
+    //   vector.step {slice<layout<lane_layout=[2,4,2], lane_data=[1,2,1]>,
+    //                      dims=[0,2]>} : vector<16xindex>
+    // Each logical lane holds 4 elements as 2 blocks of 2 elements each. The
+    // blocks are round-robin distributed, so logical lane id 0 holds values
+    // [0,1, 8,9].
+    for (const auto &blockCoords : laneDataBlockCoordsVec) {
+      Value blockStart = blockCoords[0];
+      stepVals.push_back(blockStart);
+      for (int64_t i = 1; i < laneDataBlockLength; ++i) {
+        auto offset = arith::ConstantIndexOp::create(rewriter, loc, i);
+        stepVals.push_back(
+            arith::AddIOp::create(rewriter, loc, blockStart, offset));
+      }
+    }
+    assert(static_cast<int64_t>(stepVals.size()) == newVecTy.getNumElements() &&
+           "Expecting the number of step values to match the number of "
+           "elements in the vector");
+    auto stepOpVal =
+        vector::FromElementsOp::create(rewriter, loc, newVecTy, stepVals);
+    rewriter.replaceOp(op, stepOpVal);
+    return success();
+  }
+};
+
+/// Distributes a subgroup-level vector.extract op to workitem-level. Only
+/// handles sub-vector extraction (result is VectorType, not scalar).
+///
+/// The original positions are reused unchanged. That is only valid when
+/// distribution happens solely along the innermost dimension. The result
+/// layout is checked for this. The source layout is checked too when one is
+/// recorded, because the positions index source dimensions that the result
+/// layout does not describe.
+struct SgToWiVectorExtract : public OpConversionPattern<vector::ExtractOp> {
+  using OpConversionPattern<vector::ExtractOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::ExtractOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Only handle vector results (not scalar extraction).
+    auto resultType = dyn_cast<VectorType>(op.getType());
+    if (!resultType)
+      return rewriter.notifyMatchFailure(op, "scalar extract not supported");
+
+    xegpu::DistributeLayoutAttr layout =
+        xegpu::getTemporaryLayout(op->getOpResult(0));
+    if (!layout || !layout.isForSubgroup())
+      return failure();
+
+    // True if any dimension other than the innermost one is spread over more
+    // than one lane. Callers must pass a non-empty lane layout.
+    auto distributesOuterDims = [](ArrayRef<int64_t> laneLayout) {
+      return llvm::any_of(laneLayout.drop_back(1),
+                          [](int64_t v) { return v != 1; });
+    };
+
+    // This implementation assumes distribution only happens on the innermost
+    // dimension. Verify that lane_layout[0...n-2] are all unit. An empty lane
+    // layout cannot be checked (drop_back(1) would be out of bounds).
+    auto laneLayout = layout.getEffectiveLaneLayoutAsInt();
+    if (laneLayout.empty())
+      return rewriter.notifyMatchFailure(
+          op, "result layout of vector.extract lacks lane layout");
+    if (distributesOuterDims(laneLayout))
+      return rewriter.notifyMatchFailure(
+          op, "only innermost dimension distribution is supported for "
+              "vector.extract");
+
+    // The extracted positions live in the source's leading dims, so those
+    // must not be distributed either.
+    if (xegpu::DistributeLayoutAttr sourceLayout =
+            xegpu::getTemporaryLayout(op->getOpOperand(0))) {
+      auto sourceLaneLayout = sourceLayout.getEffectiveLaneLayoutAsInt();
+      if (!sourceLaneLayout.empty() && distributesOuterDims(sourceLaneLayout))
+        return rewriter.notifyMatchFailure(
+            op, "only innermost dimension distribution is supported for "
+                "vector.extract");
+    }
+
+    auto newOp = vector::ExtractOp::create(
+        rewriter, op.getLoc(), adaptor.getSource(), op.getMixedPosition());
+    rewriter.replaceOp(op, newOp.getResult());
+    return success();
+  }
+};
+
+/// Lowers a subgroup-level vector.shape_cast to the workitem level.
+/// A new shape_cast is issued on the adaptor's (type-converted) source, using
+/// the result type derived from the result's lane layout.
+struct SgToWiVectorShapeCast : public OpConversionPattern<vector::ShapeCastOp> {
+  using OpConversionPattern<vector::ShapeCastOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::ShapeCastOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto layout = xegpu::getTemporaryLayout(op->getOpResult(0));
+    bool hasSgLayout = layout && layout.isForSubgroup();
+    if (!hasSgLayout)
+      return rewriter.notifyMatchFailure(
+          op, "the result vector of the shape_cast op lacks subgroup layout");
+
+    FailureOr<VectorType> wiResultTy = xegpu::getDistVecTypeBasedOnLaneLayout(
+        layout, op.getResultVectorType());
+    if (failed(wiResultTy))
+      return rewriter.notifyMatchFailure(
+          op, "failed to get distributed vector type for result");
+
+    rewriter.replaceOp(op, vector::ShapeCastOp::create(rewriter, op.getLoc(),
+                                                       *wiResultTy,
+                                                       adaptor.getSource()));
+    return success();
+  }
+};
+
+/// Lowers a subgroup-level vector.extract_strided_slice to the workitem level.
+/// Offsets, sizes and strides are first padded out to the full source rank.
+/// If the result is distributed along a (single) dimension, that dimension's
+/// size becomes the per-lane size and its offset is divided by the subgroup
+/// size. This requires the source extent and offset there to be multiples of
+/// the subgroup size, and the source lane_data there to be 1.
+struct SgToWiVectorExtractStridedSlice
+    : public OpConversionPattern<vector::ExtractStridedSliceOp> {
+  using OpConversionPattern<vector::ExtractStridedSliceOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::ExtractStridedSliceOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    xegpu::DistributeLayoutAttr resLayout =
+        xegpu::getTemporaryLayout(op->getOpResult(0));
+    if (!resLayout || !resLayout.isForSubgroup())
+      return failure();
+
+    VectorType sgResTy = op.getType();
+    auto wiResTyOrFailure =
+        xegpu::getDistVecTypeBasedOnLaneLayout(resLayout, sgResTy);
+    if (failed(wiResTyOrFailure))
+      return rewriter.notifyMatchFailure(
+          op, "unable to compute distributed vector type from lane layout");
+    VectorType wiResTy = *wiResTyOrFailure;
+    SmallVector<int64_t> distDims = getDistributedDims(sgResTy, wiResTy);
+
+    // Copy the slice attributes, then pad the untouched trailing dims with
+    // (full extent, offset 0, stride 1).
+    VectorType srcTy = op.getSourceVectorType();
+    SmallVector<Attribute> sizes = llvm::to_vector(op.getSizes());
+    SmallVector<Attribute> offsets = llvm::to_vector(op.getOffsets());
+    SmallVector<Attribute> strides = llvm::to_vector(op.getStrides());
+    for (int64_t d = op.getSizes().size(), rank = srcTy.getRank(); d < rank;
+         ++d) {
+      sizes.push_back(rewriter.getI64IntegerAttr(srcTy.getDimSize(d)));
+      offsets.push_back(rewriter.getI64IntegerAttr(0));
+      strides.push_back(rewriter.getI64IntegerAttr(1));
+    }
+
+    if (!distDims.empty()) {
+      if (distDims.size() != 1)
+        return rewriter.notifyMatchFailure(
+            op, "only single dimension distribution is supported");
+      int64_t dim = distDims.front();
+
+      const uArch *arch = getUArch(xegpu::getChipStr(op).value_or(""));
+      if (!arch)
+        return rewriter.notifyMatchFailure(
+            op, "target attribute required to determine subgroup size");
+      int sgSize = arch->getSubgroupSize();
+
+      auto srcLayout = xegpu::getTemporaryLayout(op->getOpOperand(0));
+      if (!srcLayout || srcLayout.getEffectiveLaneLayoutAsInt().empty())
+        return rewriter.notifyMatchFailure(
+            op, "source of extract_strided_slice lacks distribution layout");
+
+      int srcDimSize = srcTy.getShape()[dim];
+      if (srcDimSize % sgSize != 0)
+        return rewriter.notifyMatchFailure(
+            op, "source size along distributed dim is not a multiple of "
+                "subgroup size");
+
+      // Only the distributed dim needs unit lane_data; the other dims may be
+      // packed.
+      auto srcLaneData = srcLayout.getEffectiveLaneDataAsInt();
+      bool nonUnitLaneData = dim < static_cast<int64_t>(srcLaneData.size()) &&
+                             srcLaneData[dim] != 1;
+      if (nonUnitLaneData)
+        return rewriter.notifyMatchFailure(
+            op, "expecting unit lane data along the distributed dimension");
+
+      int64_t offset = cast<IntegerAttr>(offsets[dim]).getInt();
+      if (offset % sgSize != 0)
+        return rewriter.notifyMatchFailure(
+            op, "offset along distributed dim is not a multiple of "
+                "subgroup size");
+
+      sizes[dim] = rewriter.getI64IntegerAttr(wiResTy.getDimSize(dim));
+      offsets[dim] = rewriter.getI64IntegerAttr(offset / sgSize);
+    }
+
+    MLIRContext *ctx = rewriter.getContext();
+    auto wiOp = vector::ExtractStridedSliceOp::create(
+        rewriter, op.getLoc(), wiResTy, adaptor.getSource(),
+        ArrayAttr::get(ctx, offsets), ArrayAttr::get(ctx, sizes),
+        ArrayAttr::get(ctx, strides));
+    rewriter.replaceOp(op, wiOp.getResult());
+    return success();
+  }
+};
+
+/// This pattern distributes a subgroup-level `vector.broadcast` op to
+/// workitem-level. The pattern supports three cases:
+///
+/// 1) Broadcast a low-rank vector to high-rank vector: The low-rank input
+///    vector is expected to have a slice layout of the result (a warning is
+///    emitted otherwise). If the distributed source and target vector types
+///    are identical, this lowers to a no-op. Otherwise, it remains a
+///    broadcast but operates on distributed vectors.
+///
+/// 2) Broadcast a same-rank vector with identical layouts for source and
+///    target: The source vector must have unit dimensions, and its layout
+///    data must be 1 for those broadcasted unit dims. The source must carry a
+///    layout. If either requirement fails, the pattern fails to match. When
+///    the distributed source already has the distributed result type, this
+///    lowers to a no-op.
+///
+/// 3) Broadcast a scalar with no layout: This always lowers to a broadcast
+///    from scalar to distributed result type.
+///
+/// Example 1 (low-rank to high-rank broadcast):
+/// ```
+/// %0 = "some_op"() {layout_result_0 =
+///   #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+///   dims = [0]>} : () -> vector<16xf16>
+/// %1 = vector.broadcast %0 {layout_result_0 =
+///   #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+///   : vector<16xf16> to vector<16x16xf16>
+/// ```
+/// is distributed to:
+/// ```
+/// %0 = "some_op"() : () -> vector<1xf16>
+/// %1 = vector.broadcast %0 : vector<1xf16> to vector<16x1xf16>
+/// ```
+///
+/// Example 2 (same-rank broadcast, no-op):
+/// ```
+/// %0 = "some_op"() {layout_result_0 =
+///   #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+///   : () -> vector<16x1xf16>
+/// %1 = vector.broadcast %0 {layout_result_0 =
+///   #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+///   : vector<16x1xf16> to vector<16x16xf16>
+/// ```
+/// is distributed to (no-op, source already matches distributed result type):
+/// ```
+/// %0 = "some_op"() : () -> vector<16x1xf16>
+/// // broadcast is eliminated, %0 is used directly
+/// ```
+///
+/// Example 3 (scalar to vector broadcast):
+/// ```
+/// %0 = "some_op"() : () -> f16
+/// %1 = vector.broadcast %0 {layout_result_0 =
+///   #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+///   : f16 to vector<16x16xf16>
+/// ```
+/// is distributed to:
+/// ```
+/// %0 = "some_op"() : () -> f16
+/// %1 = vector.broadcast %0 : f16 to vector<16x1xf16>
+/// ```
+struct SgToWiBroadcast : public OpConversionPattern<vector::BroadcastOp> {
+  using OpConversionPattern<vector::BroadcastOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::BroadcastOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    xegpu::DistributeLayoutAttr resultLayout =
+        xegpu::getTemporaryLayout(cast<OpResult>(op.getResult()));
+    if (!resultLayout || !resultLayout.isForSubgroup())
+      return rewriter.notifyMatchFailure(
+          op, "result does not have subgroup distribute layout");
+
+    VectorType destType = op.getResultVectorType();
+    VectorType sourceType = dyn_cast<VectorType>(op.getSourceType());
+
+    xegpu::DistributeLayoutAttr sourceLayout =
+        xegpu::getTemporaryLayout(op->getOpOperand(0));
+
+    if (sourceType) {
+      int64_t rankDiff = destType.getRank() - sourceType.getRank();
+      if (rankDiff > 0) {
+        // Case 1: Low-rank to high-rank broadcast.
+        if (!sourceLayout || !sourceLayout.isSliceOf(resultLayout))
+          op.emitWarning(
+              "broadcast source layout must be a slice of result layout");
+      } else if (rankDiff == 0) {
+        // Case 2: Same-rank broadcast. The check below queries the source
+        // layout, so it must be present.
+        if (!sourceLayout)
+          return rewriter.notifyMatchFailure(
+              op, "same-rank broadcast source lacks layout attribute");
+        auto broadcastUnitDimsSet = op.computeBroadcastedUnitDims();
+        SmallVector<int64_t> broadcastUnitDims(broadcastUnitDimsSet.begin(),
+                                               broadcastUnitDimsSet.end());
+        // Report a violation as a match failure rather than an assert, so
+        // release builds do not silently distribute an unsupported layout.
+        if (!sourceLayout.isEqualTo(
+                sourceLayout.setUnitDimData(broadcastUnitDims)))
+          return rewriter.notifyMatchFailure(
+              op, "the layout data of broadcasted unit dimensions must be 1");
+      }
+    } else {
+      // Case 3: Scalar to vector broadcast.
+      if (sourceLayout)
+        return rewriter.notifyMatchFailure(
+            op, "broadcast from scalar must not have a layout attribute");
+    }
+
+    auto destDistType =
+        xegpu::getDistVecTypeBasedOnLaneLayout(resultLayout, destType);
+    if (failed(destDistType))
+      return rewriter.notifyMatchFailure(
+          op, "failed to distribute the result vector type");
+
+    Value source = adaptor.getSource();
+    // If the adapted source already matches the dest dist type, it's a no-op.
+    if (source.getType() == destDistType.value()) {
+      rewriter.replaceOp(op, source);
+      return success();
+    }
+
+    auto newOp = vector::BroadcastOp::create(rewriter, op.getLoc(),
+                                             destDistType.value(), source);
+    rewriter.replaceOp(op, newOp);
+    return success();
+  }
+};
+
+/// Lowers a subgroup-level vector.insert_strided_slice to the workitem level.
+/// When the destination is distributed along a (single) dimension, that
+/// dimension's offset is divided by the subgroup size. Strides are passed
+/// through unchanged. The distributed dim must fall inside the source, and
+/// the source extent and dest offset there must be multiples of the subgroup
+/// size.
+struct SgToWiVectorInsertStridedSlice
+    : public OpConversionPattern<vector::InsertStridedSliceOp> {
+  using OpConversionPattern<vector::InsertStridedSliceOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::InsertStridedSliceOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    xegpu::DistributeLayoutAttr resLayout =
+        xegpu::getTemporaryLayout(op->getOpResult(0));
+    if (!resLayout || !resLayout.isForSubgroup())
+      return failure();
+
+    VectorType sgDestTy = op.getDestVectorType();
+    auto wiDestTyOrFailure =
+        xegpu::getDistVecTypeBasedOnLaneLayout(resLayout, sgDestTy);
+    if (failed(wiDestTyOrFailure))
+      return rewriter.notifyMatchFailure(
+          op, "unable to compute distributed vector type from lane layout");
+    VectorType wiDestTy = *wiDestTyOrFailure;
+
+    SmallVector<int64_t> distDims = getDistributedDims(sgDestTy, wiDestTy);
+    SmallVector<Attribute> offsets = llvm::to_vector(op.getOffsets());
+
+    if (!distDims.empty()) {
+      if (distDims.size() != 1)
+        return rewriter.notifyMatchFailure(
+            op, "only single dimension distribution is supported");
+      int64_t destDim = distDims.front();
+
+      const uArch *arch = getUArch(xegpu::getChipStr(op).value_or(""));
+      if (!arch)
+        return rewriter.notifyMatchFailure(
+            op, "target attribute required to determine subgroup size");
+      int sgSize = arch->getSubgroupSize();
+
+      // The source occupies the trailing (source-rank) dims of the dest, so
+      // the distributed dim maps into the source only if it is among them.
+      VectorType srcTy = op.getSourceVectorType();
+      int64_t srcDim = destDim - (sgDestTy.getRank() - srcTy.getRank());
+      if (srcDim < 0)
+        return rewriter.notifyMatchFailure(
+            op, "distributed dimension must be in the last k dims of dest");
+
+      auto destLayout = xegpu::getTemporaryLayout(op->getOpOperand(1));
+      auto srcLayout = xegpu::getTemporaryLayout(op->getOpOperand(0));
+      bool missingLaneLayout =
+          !destLayout || !srcLayout ||
+          destLayout.getEffectiveLaneLayoutAsInt().empty() ||
+          srcLayout.getEffectiveLaneLayoutAsInt().empty();
+      if (missingLaneLayout)
+        return rewriter.notifyMatchFailure(
+            op, "source or dest of insert_strided_slice lacks distribution "
+                "layout");
+
+      // Only the distributed dim needs unit lane_data; the other dims may be
+      // packed.
+      auto hasNonUnitLaneData = [](ArrayRef<int64_t> laneData, int64_t d) {
+        return d < static_cast<int64_t>(laneData.size()) && laneData[d] != 1;
+      };
+      if (hasNonUnitLaneData(destLayout.getEffectiveLaneDataAsInt(), destDim) ||
+          hasNonUnitLaneData(srcLayout.getEffectiveLaneDataAsInt(), srcDim))
+        return rewriter.notifyMatchFailure(
+            op, "expecting unit lane data along the distributed dimension");
+
+      if (srcTy.getDimSize(srcDim) % sgSize != 0)
+        return rewriter.notifyMatchFailure(
+            op, "source distributed dim size is not a multiple of "
+                "subgroup size");
+
+      int64_t destOffset =
+          cast<IntegerAttr>(op.getOffsets()[destDim]).getInt();
+      if (destOffset % sgSize != 0)
+        return rewriter.notifyMatchFailure(
+            op, "offset along distributed dim is not a multiple of "
+                "subgroup size");
+      offsets[destDim] = rewriter.getI64IntegerAttr(destOffset / sgSize);
+    }
+
+    auto wiOp = vector::InsertStridedSliceOp::create(
+        rewriter, op.getLoc(), wiDestTy, adaptor.getValueToStore(),
+        adaptor.getDest(), ArrayAttr::get(rewriter.getContext(), offsets),
+        op.getStrides());
+    rewriter.replaceOp(op, wiOp.getResult());
+    return success();
+  }
+};
+
+/// Distributes a subgroup-level vector.insert op to workitem-level. Only
+/// handles sub-vector insertion (value to store is VectorType, not scalar).
+/// The original positions are reused unchanged. This is only valid when the
+/// result (dest) layout distributes nothing but the innermost dimension,
+/// which is checked.
+struct SgToWiVectorInsert : public OpConversionPattern<vector::InsertOp> {
+  using OpConversionPattern<vector::InsertOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::InsertOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Only handle vector value-to-store (not scalar insertion).
+    auto valueType = dyn_cast<VectorType>(op.getValueToStoreType());
+    if (!valueType)
+      return rewriter.notifyMatchFailure(op, "scalar insert not supported");
+
+    xegpu::DistributeLayoutAttr layout =
+        xegpu::getTemporaryLayout(op->getOpResult(0));
+    if (!layout || !layout.isForSubgroup())
+      return failure();
+
+    // Verify that the outer dimensions (indexed by the positions) don't have
+    // non-unit lane_layout. An empty lane layout cannot be checked
+    // (drop_back(1) would be out of bounds), so reject it explicitly.
+    auto laneLayout = layout.getEffectiveLaneLayoutAsInt();
+    if (laneLayout.empty())
+      return rewriter.notifyMatchFailure(
+          op, "result layout of vector.insert lacks lane layout");
+    if (llvm::any_of(ArrayRef<int64_t>(laneLayout).drop_back(1),
+                     [](int64_t v) { return v != 1; }))
+      return rewriter.notifyMatchFailure(
+          op, "only innermost dimension distribution is supported for "
+              "vector.insert");
+
+    auto newOp = vector::InsertOp::create(
+        rewriter, op.getLoc(), adaptor.getValueToStore(), adaptor.getDest(),
+        op.getMixedPosition());
+    rewriter.replaceOp(op, newOp.getResult());
+    return success();
+  }
+};
+
+/// Folds a subgroup-level ConvertLayout op with compatible lane layouts.
+/// Integer/float scalars carry no distribution and are forwarded directly.
+/// For vectors, the op folds into its converted source when the input and
+/// target layouts are compatible at lane level. Any other case is reported as
+/// a match failure.
+struct SgToWiConvertLayout
+    : public OpConversionPattern<xegpu::ConvertLayoutOp> {
+  using OpConversionPattern<xegpu::ConvertLayoutOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(xegpu::ConvertLayoutOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto inputLayout = op.getInputLayoutAttr();
+    auto targetLayout = op.getTargetLayoutAttr();
+    Type valType = op.getResult().getType();
+
+    // Scalars are not distributed. Forward the remapped source from the
+    // adaptor, as required inside a conversion pattern.
+    if (valType.isIntOrFloat()) {
+      rewriter.replaceOp(op, adaptor.getSource());
+      return success();
+    }
+
+    // Anything else must be a vector. Fail gracefully instead of asserting
+    // in cast<>.
+    auto vecTy = dyn_cast<VectorType>(valType);
+    if (!vecTy)
+      return rewriter.notifyMatchFailure(
+          op, "expected a scalar or vector value for convert_layout");
+
+    SmallVector<int64_t> resShapeVec = llvm::to_vector(vecTy.getShape());
+    if (!inputLayout.isCompatibleWith(targetLayout, resShapeVec,
+                                      xegpu::LayoutKind::Lane)) {
+      return rewriter.notifyMatchFailure(
+          op, "lowering incompatible convert_layout not yet supported");
+    }
+
+    rewriter.replaceOp(op, adaptor.getSource());
+    return success();
+  }
+};
+
+/// Pass wrapper for the experimental subgroup-to-workitem (SgToWi)
+/// distribution. The work is done in runOnOperation().
+struct XeGPUSgToWiDistributeExperimentalPass
+    : xegpu::impl::XeGPUSgToWiDistributeExperimentalBase<
+          XeGPUSgToWiDistributeExperimentalPass> {
+  void runOnOperation() override;
+};
+
+} // namespace
+
+void XeGPUSgToWiDistributeExperimentalPass::runOnOperation() {
+  // Recover temporary operand layouts for usage in patterns.
+  Operation *root = getOperation();
+  if (!xegpu::recoverTemporaryLayouts(root)) {
+    signalPassFailure();
+    return;
+  }
+
+  // Collect existing UnrealizedConversionCastOps. These must be preserved;
+  // the cleanup below only rewrites/erases casts not in this set.
+  llvm::SmallSetVector<UnrealizedConversionCastOp, 8> existingCasts;
+  root->walk(
+      [&](UnrealizedConversionCastOp castOp) { existingCasts.insert(castOp); });
+
+  // Perform a structural type conversion to convert structural ops to have WI
+  // types. This will insert UnrealizedConversionCastOps to make the IR
+  // valid.
+  auto materializeCast = [&](mlir::OpBuilder &builder, mlir::Type type,
+                             mlir::ValueRange inputs,
+                             mlir::Location loc) -> mlir::Value {
+    UnrealizedConversionCastOp castOp =
+        UnrealizedConversionCastOp::create(builder, loc, type, inputs);
+    return castOp.getResult(0);
+  };
+  {
+    ConversionTarget target(getContext());
+    TypeConverter typeConverter;
+    RewritePatternSet patterns(&getContext());
+    typeConverter.addSourceMaterialization(materializeCast);
+    typeConverter.addTargetMaterialization(materializeCast);
+    xegpu::populateXeGPUSgToWiDistributeTypeConversions(typeConverter);
+    scf::populateSCFStructuralTypeConversionsAndLegality(typeConverter,
+                                                         patterns, target);
+    xegpu::populateXeGPUSgToWiDistributeTypeConversionAndLegality(
+        typeConverter, patterns, target);
+    target.addLegalOp<UnrealizedConversionCastOp>();
+    // The conversion result is discarded, so ops that fail to legalize are
+    // left in place instead of failing the pass.
+    (void)applyPartialConversion(root, target, std::move(patterns));
+  }
+
+  // Structural type conversion can generate some redundant
+  // UnrealizedConversionCastOps to materialize the SG type from type converted
+  // WI type. These are redundant at this point and can be eliminated by
+  // inserting shape casts instead.
+  // Example:
+  //   %1 = UnrealizedConversionCastOp %0 : vector<16x1xf32> to vector<16x16xf32>
+  //   %2 = UnrealizedConversionCastOp %1 : vector<16x16xf32> to vector<16xf32>
+  // This can be replaced with:
+  //   %2 = vector.shape_cast %0 : vector<16x1xf32> to vector<16xf32>
+  OpBuilder builder(root);
+  root->walk([&](UnrealizedConversionCastOp op) {
+    // If this op existed before, nothing to do.
+    if (existingCasts.contains(op))
+      return;
+    // Number of inputs and outputs must be 1.
+    if (op.getNumOperands() != 1 || op.getNumResults() != 1)
+      return;
+    // Both input and output types must be vector types.
+    auto singleInput = op.getInputs()[0];
+    auto inputTy = dyn_cast<VectorType>(singleInput.getType());
+    auto outputTy = dyn_cast<VectorType>(op.getResult(0).getType());
+    if (!inputTy || !outputTy)
+      return;
+
+    // The defining op of the input must also be a 1:1
+    // UnrealizedConversionCastOp whose single user is this op. Casts may be
+    // N:M, so the operand/result counts are checked before indexing
+    // getInputs()[0].
+    auto definingOp = singleInput.getDefiningOp<UnrealizedConversionCastOp>();
+    if (!definingOp || definingOp.getNumOperands() != 1 ||
+        definingOp.getNumResults() != 1 || !definingOp->hasOneUse())
+      return;
+    auto inputOfDefiningOp = definingOp.getInputs()[0];
+    // If the input of the defining op and the output type are both vector
+    // types with the same number of elements, insert a shape cast.
+    auto inputOfDefiningOpTy =
+        dyn_cast<VectorType>(inputOfDefiningOp.getType());
+    if (inputOfDefiningOpTy &&
+        inputOfDefiningOpTy.getNumElements() == outputTy.getNumElements()) {
+      builder.setInsertionPoint(op);
+      auto shapeCast = vector::ShapeCastOp::create(builder, op.getLoc(),
+                                                   outputTy, inputOfDefiningOp);
+      op.replaceAllUsesWith(ValueRange{shapeCast.getResult()});
+      return;
+    }
+  });
+
+  // At this point, we will have some dead UnrealizedConversionCastOps. Erase
+  // them, repeating until no more become dead (erasing one may leave its
+  // producer cast without uses).
+  bool changed = true;
+  while (changed) {
+    changed = false;
+    root->walk([&](UnrealizedConversionCastOp op) {
+      // Skip existing casts.
+      if (existingCasts.contains(op))
+        return;
+      if (op.use_empty()) {
+        op.erase();
+        changed = true;
+      }
+    });
+  }
+}
+
+/// Registers on `typeConverter` the type conversions used by SG-to-WI
+/// distribution:
+///   - any type that is neither a TensorDescType nor a VectorType is kept
+///     unchanged;
+///   - a TensorDescType maps to the same type with its layout dropped (when
+///     it has one);
+///   - a vector Value that carries a subgroup distribute layout maps to the
+///     vector type computed from the layout's lane layout. If there is no
+///     such layout, or the distributed type cannot be computed, the original
+///     type is kept.
+void xegpu::populateXeGPUSgToWiDistributeTypeConversions(
+    TypeConverter &typeConverter) {
+  // Types other than TensorDescType and VectorType are legal as is; those
+  // two are declined here and left to the conversions registered below.
+  typeConverter.addConversion([](Type ty) -> std::optional<Type> {
+    if (isa<TensorDescType, VectorType>(ty))
+      return std::nullopt;
+    return ty;
+  });
+  // TensorDescType: strip the layout attribute when present.
+  typeConverter.addConversion([](TensorDescType descTy) -> Type {
+    if (!descTy.getLayoutAttr())
+      return descTy;
+    return descTy.dropLayouts();
+  });
+  // Vector values: distribute based on the value's subgroup layout, if any.
+  typeConverter.addConversion([](Value val) -> std::optional<Type> {
+    Type ty = val.getType();
+    auto vecTy = dyn_cast<VectorType>(ty);
+    // Non-vector values are not handled by this conversion.
+    if (!vecTy)
+      return std::nullopt;
+    auto layout = xegpu::getDistributeLayoutAttr(val);
+    if (!layout || !layout.isForSubgroup())
+      return ty;
+    // The distributed vector type is derived from the lane layout.
+    auto distTyOrFailure = getDistVecTypeBasedOnLaneLayout(layout, vecTy);
+    if (failed(distTyOrFailure))
+      return ty;
+    return *distTyOrFailure;
+  });
+}
+
+/// Configures SG-to-WI distribution on the given conversion infrastructure.
+/// Registers the SG-to-WI type conversions on `typeConverter`, marks on
+/// `target` which ops still need distribution (they are illegal while they
+/// carry the relevant layout), and adds the SgToWi* rewrite patterns to
+/// `patterns`. Ops not mentioned here are treated as legal.
+void xegpu::populateXeGPUSgToWiDistributeTypeConversionAndLegality(
+    TypeConverter &typeConverter, RewritePatternSet &patterns,
+    ConversionTarget &target) {
+  populateXeGPUSgToWiDistributeTypeConversions(typeConverter);
+
+  // Shared predicate for the vector ops registered below: the op is legal
+  // once its first result has no temporary layout.
+  auto resultHasNoLayout = [](Operation *op) -> bool {
+    return !xegpu::getTemporaryLayout(op->getOpResult(0));
+  };
+
+  // CreateNdDescOp is legal only if its result type has no layout attribute.
+  target.addDynamicallyLegalOp<xegpu::CreateNdDescOp>(
+      [](xegpu::CreateNdDescOp descOp) {
+        return !descOp.getType().getLayoutAttr();
+      });
+  // Within the XeGPU dialect, ops without AnchorLayoutInterface are always
+  // legal; anchor ops are legal only once they have no anchor layout.
+  target.addDynamicallyLegalDialect<xegpu::XeGPUDialect>([](Operation *op) {
+    auto anchorOp = dyn_cast<AnchorLayoutInterface>(op);
+    return !anchorOp || !anchorOp.getAnchorLayout();
+  });
+  // Arith constants: scalar constants are legal; vector constants are legal
+  // only without a temporary result layout.
+  target.addDynamicallyLegalOp<arith::ConstantOp>(
+      [](arith::ConstantOp cstOp) -> bool {
+        return !isa<VectorType>(cstOp.getResult().getType()) ||
+               !xegpu::getTemporaryLayout(
+                   dyn_cast<OpResult>(cstOp.getResult()));
+      });
+  // Math and arith dialects: only single-result elementwise-mappable ops whose
+  // operands are all vectors shaped like the result are candidates; those are
+  // illegal while their result carries a temporary layout.
+  target.addDynamicallyLegalDialect<math::MathDialect, arith::ArithDialect>(
+      [](Operation *op) -> std::optional<bool> {
+        if (!OpTrait::hasElementwiseMappableTraits(op) ||
+            op->getNumResults() != 1)
+          return true;
+
+        auto resultTy = dyn_cast<VectorType>(op->getResult(0).getType());
+        if (!resultTy)
+          return true;
+
+        bool hasIncompatibleOperand =
+            llvm::any_of(op->getOperands(), [&](Value operand) {
+              auto operandTy = dyn_cast<VectorType>(operand.getType());
+              return !operandTy || operandTy.getShape() != resultTy.getShape();
+            });
+        if (hasIncompatibleOperand)
+          return true;
+        return !xegpu::getTemporaryLayout(op->getOpResult(0));
+      });
+  // vector::ReductionOp is legal only if its source vector has no distribute
+  // layout attribute.
+  target.addDynamicallyLegalOp<vector::ReductionOp>(
+      [](vector::ReductionOp reductionOp) -> bool {
+        return !xegpu::getDistributeLayoutAttr(reductionOp.getVector());
+      });
+  // vector::MultiDimReductionOp is illegal exactly when it is a valid
+  // subgroup multi-reduction.
+  target.addDynamicallyLegalOp<vector::MultiDimReductionOp>(
+      [](vector::MultiDimReductionOp reductionOp) -> bool {
+        return !isValidSubgroupMultiReductionOp(reductionOp);
+      });
+  // These vector ops are legal once their result has no temporary layout.
+  target.addDynamicallyLegalOp<
+      vector::CreateMaskOp, vector::ConstantMaskOp, vector::TransposeOp,
+      vector::BitCastOp, vector::ShapeCastOp, vector::StepOp,
+      vector::BroadcastOp, vector::InsertOp, vector::ExtractStridedSliceOp,
+      vector::InsertStridedSliceOp>(resultHasNoLayout);
+  // vector::ExtractOp: scalar extracts are always legal; vector extracts
+  // follow the same rule as the ops above.
+  target.addDynamicallyLegalOp<vector::ExtractOp>(
+      [](vector::ExtractOp extractOp) -> bool {
+        return !isa<VectorType>(extractOp.getType()) ||
+               !xegpu::getTemporaryLayout(extractOp->getOpResult(0));
+      });
+  // Every other op is left untouched.
+  target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
+  patterns.add<SgToWiCreateNdDesc, SgToWiLoadNd, SgToWiStoreNd, SgToWiDpas,
+               SgToWiElementWise, SgToWiArithConstant, SgToWiPrefetchNd,
+               SgToWiLoadGather, SgToWiStoreScatter, SgToWiVectorReduction,
+               SgToWiMultiDimReduction, SgToWiVectorExtract, SgToWiVectorInsert,
+               SgToWiVectorExtractStridedSlice, SgToWiVectorInsertStridedSlice,
+               SgToWiLoadMatrix, SgToWiStoreMatrix, SgToWiConvertLayout,
+               SgToWiVectorTranspose, SgToWiVectorBitcast, SgToWiVectorStep,
+               SgToWiVectorShapeCast, SgToWiBroadcast,
+               SgToWiCreateMask<vector::CreateMaskOp>,
+               SgToWiCreateMask<vector::ConstantMaskOp>>(typeConverter,
+                                                         patterns.getContext());
+}
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
index c029ee1d8ae0d..6ef0a63926c8d 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
@@ -1610,7 +1610,6 @@ void XeGPUSgToWiDistributeExperimentalPass::runOnOperation() {
}
});
}
- // Remove layout attributes from SCF ops
getOperation()->walk([](Operation *op) {
SmallVector<StringAttr> attrsToRemove;
for (auto namedAttr : op->getDiscardableAttrs()) {
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 012d7aefafb06..d16d0cfc6c587 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -825,10 +825,12 @@ struct StoreDistribution final : public gpu::WarpDistributionPattern {
}
}
- auto layoutPayload = storeScatterOp.getLayoutAttr();
+ auto layoutPayload =
+ xegpu::getTemporaryLayout(storeScatterOp->getOpOperand(0));
auto layoutOffsets =
- xegpu::inferMaskOffsetLayoutForScatterIO(layoutPayload, chunkSize);
- auto layoutMask = layoutOffsets;
+ xegpu::getTemporaryLayout(storeScatterOp->getOpOperand(2));
+ auto layoutMask =
+ xegpu::getTemporaryLayout(storeScatterOp->getOpOperand(3));
FailureOr<VectorType> distStoreVecByWarpOpOrFailure =
getDistVecTypeBasedOnLaneLayout(layoutPayload, storeVecTy);
@@ -1130,36 +1132,9 @@ struct LoadDistribution final : public gpu::WarpDistributionPattern {
}
}
- auto layoutPayload = loadGatherOp.getLayoutAttr();
auto layoutOffsets =
- xegpu::inferMaskOffsetLayoutForScatterIO(layoutPayload, chunkSize);
- auto layoutMask = layoutOffsets;
-
- // print the layouts for debug
- LLVM_DEBUG({
- llvm::dbgs() << "In LoadDistribution pattern:\n";
- llvm::dbgs() << "Payload layout: ";
- if (layoutPayload)
- llvm::dbgs() << layoutPayload;
- else
- llvm::dbgs() << "none";
- llvm::dbgs() << "\nOffsets layout: ";
- if (layoutOffsets)
- llvm::dbgs() << layoutOffsets;
- else
- llvm::dbgs() << "none";
- llvm::dbgs() << "\nMask layout: ";
- if (layoutMask)
- llvm::dbgs() << layoutMask;
- else
- llvm::dbgs() << "none";
- llvm::dbgs() << "\n";
- });
-
- // auto layoutOffsets =
- // xegpu::getTemporaryLayout(loadGatherOp->getOpOperand(1));
- // auto layoutMask =
- // xegpu::getTemporaryLayout(loadGatherOp->getOpOperand(2));
+ xegpu::getTemporaryLayout(loadGatherOp->getOpOperand(1));
+ auto layoutMask = xegpu::getTemporaryLayout(loadGatherOp->getOpOperand(2));
FailureOr<VectorType> distOffsetsByWarpOpOrFailure =
getDistVecTypeBasedOnLaneLayout(layoutOffsets, offsetsTy);
@@ -2306,8 +2281,6 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
op->erase();
return WalkResult::advance();
});
-
- // Remove layout attributes from SCF ops
getOperation()->walk([](Operation *op) {
SmallVector<StringAttr> attrsToRemove;
for (auto namedAttr : op->getDiscardableAttrs()) {
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 02f88828f667f..dabdcb61f0500 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -1784,7 +1784,6 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
applyPartialConversion(getOperation(), target, std::move(patterns))))
return signalPassFailure();
- // Remove layout attributes from SCF ops
getOperation()->walk([](Operation *op) {
SmallVector<StringAttr> attrsToRemove;
for (auto namedAttr : op->getDiscardableAttrs()) {
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index 0dab06d206ceb..a6f621de0dd82 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -16,7 +16,6 @@
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
-#include "mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h"
#include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Operation.h"
>From 79764d19d6f664e3b46a5cabec803be021e64dac Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Mon, 13 Apr 2026 21:11:24 +0000
Subject: [PATCH 15/19] cleanup
---
.../XeGPU/Transforms/XeGPULayoutImpl.cpp | 200 ++----------------
1 file changed, 22 insertions(+), 178 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
index c326e6c917df8..8f319bd161798 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
@@ -30,10 +30,6 @@
#include <cstdint>
#include <numeric>
-#include "llvm/Support/Debug.h"
-#define DEBUG_TYPE "xegpu-layout-recovery"
-#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
-
using namespace mlir;
void xegpu::recoverTemporaryLayoutsDeprecated(Operation *op) {
@@ -139,15 +135,6 @@ static xegpu::DistributeLayoutAttr getLayoutFromUsePoints(Value result) {
xegpu::DistributeLayoutAttr layout = nullptr;
for (OpOperand &use : result.getUses()) {
if (auto tmpLayout = xegpu::getDistributeLayoutAttr(use)) {
- // debug print the use and op, and the tmpLayout
- LLVM_DEBUG({
- DBGS() << "getLayoutFromUsePoints use: " << use.getOwner()->getName()
- << use.getOwner();
- llvm::dbgs() << ", tmpLayout=" << tmpLayout << "\n";
- });
- // under debug mode, we want to check all the use points to make sure
- // there is no conflict, so we do not break here. In release mode, we can
- // break at the first use
if (!layout)
layout = tmpLayout;
}
@@ -158,14 +145,8 @@ static xegpu::DistributeLayoutAttr getLayoutFromUsePoints(Value result) {
// For regular operations: First the result layouts are propagated from uses.
// Then the result layouts are propagated to uses (operands).
static void propagateResultsToRegularOperands(Operation *op) {
- LLVM_DEBUG(DBGS() << "propagateResultsToRegularOperands: " << op->getName()
- << " (" << op->getNumOperands() << " operands, "
- << op->getNumResults() << " results)\n");
-
- if (op->getNumResults() == 0) {
- LLVM_DEBUG(DBGS() << " skipping (no results)\n");
+ if (op->getNumResults() == 0)
return;
- }
OpResult result = op->getResult(0);
xegpu::DistributeLayoutAttr resLayout = getLayoutFromUsePoints(result);
@@ -193,47 +174,26 @@ static void propagateResultsToRegularOperands(Operation *op) {
// Layouts are needed for vector type only.
xegpu::DistributeLayoutAttr operandLayout =
xegpu::inferSourceLayoutFromResult(opr, resLayout);
- if (!isa<VectorType>(opr.get().getType())) {
- LLVM_DEBUG(DBGS() << " operand #" << opr.getOperandNumber()
- << ": skipped (non-vector type: " << opr.get().getType()
- << ")\n");
+ if (!isa<VectorType>(opr.get().getType()))
continue;
- }
xegpu::setTemporaryLayout(opr, operandLayout);
- // debug print op
- LLVM_DEBUG(DBGS() << "after propagateResultsToRegularOperands op: "
- << op->getName() << op << " operand #"
- << opr.getOperandNumber()
- << ": type=" << opr.get().getType());
- llvm::dbgs() << ", temp Layout=" << xegpu::getTemporaryLayout(opr);
- llvm::dbgs() << "\n";
}
}
static void propagateRegionResultsToYieldOperands(
mlir::RegionBranchTerminatorOpInterface yieldOp) {
- LLVM_DEBUG(DBGS() << "propagateRegionResultsToYieldOperands: "
- << yieldOp->getName() << " (" << yieldOp->getNumOperands()
- << " operands), parent="
- << yieldOp->getParentOp()->getName() << "\n");
-
- if (isa<func::FuncOp>(yieldOp->getParentOp())) {
- LLVM_DEBUG(DBGS() << " skipping (parent is FuncOp)\n");
+ if (isa<func::FuncOp>(yieldOp->getParentOp()))
return;
- }
auto regionBranchOp =
dyn_cast<RegionBranchOpInterface>(yieldOp->getParentOp());
- if (!regionBranchOp) {
- LLVM_DEBUG(DBGS() << " skipping (parent is not RegionBranchOp)\n");
+ if (!regionBranchOp)
return;
- }
// Gather layouts for each result of the parent region op from external
// use points.
unsigned numResults = regionBranchOp->getNumResults();
- LLVM_DEBUG(DBGS() << " parent op has " << numResults << " results\n");
if (numResults == 0)
return;
@@ -241,14 +201,8 @@ static void propagateRegionResultsToYieldOperands(
for (unsigned i = 0; i < numResults; ++i) {
OpResult result = regionBranchOp->getResult(i);
resultLayouts[i] = getLayoutFromUsePoints(result);
- if (resultLayouts[i]) {
- LLVM_DEBUG(DBGS() << " result #" << i << ": type=" << result.getType()
- << ", layout=" << resultLayouts[i] << "\n");
+ if (resultLayouts[i])
xegpu::setTemporaryLayout(result, resultLayouts[i]);
- } else {
- LLVM_DEBUG(DBGS() << " result #" << i
- << ": skipped (no layout from use points)\n");
- }
}
// Use getSuccessorOperands to find which operands of the terminator
@@ -264,35 +218,15 @@ static void propagateRegionResultsToYieldOperands(
unsigned beginIdx = succOps.getBeginOperandIndex();
unsigned count = std::min(static_cast<unsigned>(succOps.size()), numResults);
- LLVM_DEBUG(DBGS() << " " << count << " successor operands starting at index "
- << beginIdx << "\n");
-
for (unsigned i = 0; i < count; ++i) {
if (!resultLayouts[i])
continue;
- LLVM_DEBUG(DBGS() << " -> setting layout on operand #" << (beginIdx + i)
- << "\n");
xegpu::setTemporaryLayout(yieldOp->getOpOperand(beginIdx + i),
resultLayouts[i]);
}
-
- LLVM_DEBUG({
- DBGS() << " after propagateRegionResultsToYieldOperands:\n";
- yieldOp->print(llvm::dbgs(), OpPrintingFlags().printGenericOpForm());
- llvm::dbgs() << "\n";
- });
}
static void propagateRegionArgsToInits(mlir::RegionBranchOpInterface regionOp) {
- LLVM_DEBUG(DBGS() << "propagateRegionArgsToInits: " << regionOp->getName()
- << " (" << regionOp->getNumOperands() << " operands, "
- << regionOp->getNumRegions() << " regions)\n");
- LLVM_DEBUG({
- DBGS() << " before propagateRegionArgsToInits, Region IR:\n";
- regionOp.print(llvm::dbgs(), OpPrintingFlags().printGenericOpForm());
- llvm::dbgs() << "\n";
- });
-
// Iterate all regions of the region op. For each block argument that has a
// layout (determined from its use points), trace back to find the
// corresponding init operand of the regionOp and set the layout on it.
@@ -302,14 +236,8 @@ static void propagateRegionArgsToInits(mlir::RegionBranchOpInterface regionOp) {
RegionSuccessor regionSuccessor(&region);
for (auto [argIdx, regionArg] : llvm::enumerate(region.getArguments())) {
auto layout = getLayoutFromUsePoints(regionArg);
- if (!layout) {
- LLVM_DEBUG(DBGS() << " region #" << region.getRegionNumber()
- << " arg #" << argIdx << ": skipped (no layout)\n");
+ if (!layout)
continue;
- }
- LLVM_DEBUG(DBGS() << " region #" << region.getRegionNumber() << " arg #"
- << argIdx << ": type=" << regionArg.getType()
- << ", layout=" << layout << "\n");
// Find all predecessor values that flow into this block argument.
SmallVector<Value> predValues;
@@ -317,42 +245,23 @@ static void propagateRegionArgsToInits(mlir::RegionBranchOpInterface regionOp) {
for (Value predVal : predValues) {
// Match predecessor value to an operand of the regionOp.
for (OpOperand &operand : regionOp->getOpOperands()) {
- if (operand.get() == predVal) {
- LLVM_DEBUG(DBGS() << " -> setting layout on init operand #"
- << operand.getOperandNumber() << "\n");
+ if (operand.get() == predVal)
xegpu::setTemporaryLayout(operand, layout);
- }
}
}
}
}
-
- LLVM_DEBUG({
- DBGS() << " after propagateRegionArgsToInits, Region IR:\n";
- regionOp.print(llvm::dbgs(), OpPrintingFlags().printGenericOpForm());
- llvm::dbgs() << "\n";
- });
}
bool xegpu::recoverTemporaryLayouts(Operation *rootOp) {
- LLVM_DEBUG(DBGS() << "=== recoverTemporaryLayouts START ===\n");
-
auto processFunc = [&](Region &body, StringRef funcName) {
- LLVM_DEBUG(DBGS() << "Processing func: " << funcName << "\n");
walkRegionBackward(body, [&](Operation *op) {
- LLVM_DEBUG(DBGS() << "Visiting op: " << op->getName());
if (auto regionOp = dyn_cast<mlir::RegionBranchOpInterface>(op)) {
- // hit the region op after visiting inside region
- LLVM_DEBUG(DBGS() << " -> dispatching as RegionBranchOp\n");
propagateRegionArgsToInits(regionOp);
} else if (auto yieldOp =
dyn_cast<mlir::RegionBranchTerminatorOpInterface>(op)) {
- // yield op inside region op
- LLVM_DEBUG(DBGS() << " -> dispatching as YieldOp\n");
propagateRegionResultsToYieldOperands(yieldOp);
} else if (!dyn_cast<xegpu::AnchorLayoutInterface>(op)) {
- // if the op is regular op, calling propagateResultsToRegularOperands
- LLVM_DEBUG(DBGS() << " -> dispatching as regular op\n");
propagateResultsToRegularOperands(op);
}
});
@@ -365,13 +274,6 @@ bool xegpu::recoverTemporaryLayouts(Operation *rootOp) {
processFunc(func.getBody(), func.getName());
});
- LLVM_DEBUG(DBGS() << "=== recoverTemporaryLayouts END ===\n");
- // print the root op after
- LLVM_DEBUG({
- DBGS() << "After recoverTemporaryLayouts, IR:\n";
- rootOp->print(llvm::dbgs(), OpPrintingFlags().printGenericOpForm());
- llvm::dbgs() << "\n";
- });
return true;
}
@@ -1394,124 +1296,72 @@ xegpu::setupDpasLayout(xegpu::LayoutKind layoutKind, VectorType aTy,
xegpu::DistributeLayoutAttr
xegpu::inferSourceLayoutFromResult(OpOperand &operand,
xegpu::DistributeLayoutAttr resLayout) {
- if (!resLayout) {
- LLVM_DEBUG(DBGS() << "no resLayout, returning null\n");
+ if (!resLayout)
return xegpu::DistributeLayoutAttr();
- }
Operation *op = operand.getOwner();
unsigned idx = operand.getOperandNumber();
// For vector::BroadcastOp, infer the source layout from the result layout.
if (auto broadcast = dyn_cast<vector::BroadcastOp>(op)) {
- LLVM_DEBUG(DBGS() << " -> BroadcastOp\n");
auto srcTy = dyn_cast<VectorType>(broadcast.getSourceType());
- if (!srcTy) {
- LLVM_DEBUG(DBGS() << " source is not VectorType, returning null\n");
+ if (!srcTy)
return xegpu::DistributeLayoutAttr();
- }
- auto inferred = xegpu::inferBroadcastSourceLayout(
+ return xegpu::inferBroadcastSourceLayout(
resLayout, broadcast.getResultVectorType().getShape(),
srcTy.getShape());
- LLVM_DEBUG(DBGS() << " inferred=" << inferred << "\n");
- return inferred;
}
// For vector::MultiDimReductionOp, infer source layout from result layout
// using reduction dims. Acc operand is expected to have the same layout as
// the result.
if (auto reduction = dyn_cast<vector::MultiDimReductionOp>(op)) {
- LLVM_DEBUG(DBGS() << " -> MultiDimReductionOp, operand idx=" << idx
- << "\n");
if (idx == 0) {
SmallVector<int64_t> reductionDims(reduction.getReductionDims());
- LLVM_DEBUG({
- DBGS() << " reductionDims=[";
- llvm::interleaveComma(reductionDims, llvm::dbgs());
- llvm::dbgs() << "]\n";
- });
- auto inferred =
- xegpu::inferMultiReductionSourceLayout(resLayout, reductionDims);
- LLVM_DEBUG(DBGS() << " inferred source layout=" << inferred << "\n");
- return inferred;
+ return xegpu::inferMultiReductionSourceLayout(resLayout, reductionDims);
}
- if (idx == 1) {
- LLVM_DEBUG(DBGS() << " acc operand, using resLayout\n");
+ if (idx == 1)
return resLayout;
- }
}
- if (auto reduction = dyn_cast<vector::ReductionOp>(op)) {
- LLVM_DEBUG(DBGS() << " -> ReductionOp\n");
- auto inferred = xegpu::inferReductionSourceLayout(resLayout);
- LLVM_DEBUG(DBGS() << " inferred=" << inferred << "\n");
- return inferred;
- }
+ if (auto reduction = dyn_cast<vector::ReductionOp>(op))
+ return xegpu::inferReductionSourceLayout(resLayout);
// For vector::BitCastOp, infer source layout from result layout using
// element type bitwidths.
if (auto bitcast = dyn_cast<vector::BitCastOp>(op)) {
- LLVM_DEBUG(DBGS() << " -> BitCastOp\n");
int resElemBitWidth =
bitcast.getResultVectorType().getElementType().getIntOrFloatBitWidth();
int srcElemBitWidth =
bitcast.getSourceVectorType().getElementType().getIntOrFloatBitWidth();
- LLVM_DEBUG(DBGS() << " resBitWidth=" << resElemBitWidth
- << ", srcBitWidth=" << srcElemBitWidth << "\n");
- auto inferred = xegpu::inferBitCastSourceLayout(resLayout, resElemBitWidth,
- srcElemBitWidth);
- LLVM_DEBUG(DBGS() << " inferred=" << inferred << "\n");
- return inferred;
+ return xegpu::inferBitCastSourceLayout(resLayout, resElemBitWidth,
+ srcElemBitWidth);
}
// For vector::ShapeCastOp, infer source layout from result layout using
// shapes.
if (auto shapeCast = dyn_cast<vector::ShapeCastOp>(op)) {
- LLVM_DEBUG({
- DBGS() << " -> ShapeCastOp: resShape=[";
- llvm::interleaveComma(shapeCast.getResultVectorType().getShape(),
- llvm::dbgs());
- llvm::dbgs() << "], srcShape=[";
- llvm::interleaveComma(shapeCast.getSourceVectorType().getShape(),
- llvm::dbgs());
- llvm::dbgs() << "]\n";
- });
- auto inferred = xegpu::inferShapeCastSourceLayout(
+ return xegpu::inferShapeCastSourceLayout(
resLayout, shapeCast.getResultVectorType().getShape(),
shapeCast.getSourceVectorType().getShape());
- LLVM_DEBUG(DBGS() << " inferred=" << inferred << "\n");
- return inferred;
}
// For vector::InsertStridedSliceOp, infer source layout from result layout.
// Dest vector must have the same layout as the result.
if (auto insertSlice = dyn_cast<vector::InsertStridedSliceOp>(op)) {
- LLVM_DEBUG(DBGS() << " -> InsertStridedSliceOp, operand idx=" << idx
- << "\n");
if (idx == 0) {
- auto inferred = xegpu::inferInsertStridedSliceSourceLayout(
+ return xegpu::inferInsertStridedSliceSourceLayout(
resLayout, insertSlice.getDestVectorType().getShape(),
insertSlice.getSourceVectorType().getShape());
- LLVM_DEBUG(DBGS() << " inferred source layout=" << inferred << "\n");
- return inferred;
}
- if (idx == 1) {
- LLVM_DEBUG(DBGS() << " dest operand, using resLayout\n");
+ if (idx == 1)
return resLayout;
- }
}
// For vector::TransposeOp, infer source layout from result layout using
// permutation.
if (auto transpose = dyn_cast<vector::TransposeOp>(op)) {
- LLVM_DEBUG({
- DBGS() << " -> TransposeOp, perm=[";
- llvm::interleaveComma(transpose.getPermutation(), llvm::dbgs());
- llvm::dbgs() << "]\n";
- });
- auto inferred = xegpu::inferTransposeSourceLayout(
- resLayout, transpose.getPermutation());
- LLVM_DEBUG(DBGS() << " inferred=" << inferred << "\n");
- return inferred;
+ return xegpu::inferTransposeSourceLayout(resLayout,
+ transpose.getPermutation());
}
if (isa<VectorType>(operand.get().getType()) &&
@@ -1520,8 +1370,6 @@ xegpu::inferSourceLayoutFromResult(OpOperand &operand,
// result.
// if (OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() ==
// 1) {
- LLVM_DEBUG(DBGS() << " -> other vector or tensorDesc ops using resLayout="
- << (resLayout ? resLayout : nullptr) << "\n");
return resLayout;
}
return xegpu::DistributeLayoutAttr();
@@ -1537,9 +1385,5 @@ xegpu::DistributeLayoutAttr xegpu::getConsumerLayoutAt(OpOperand &operand) {
return inferredOperandLayout;
// By default, assume no layout conflict and return the current layout of
// the operand.
- auto fallback = xegpu::getDistributeLayoutAttr(operand.get());
- LLVM_DEBUG(DBGS() << " -> fallback (unhandled op " << op->getName()
- << "), returning operand layout="
- << (fallback ? fallback : nullptr) << "\n");
- return fallback;
+ return xegpu::getDistributeLayoutAttr(operand.get());
}
>From eb6048ed52046e0aa13d566e34d49ec22e1a13ac Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Mon, 13 Apr 2026 21:14:53 +0000
Subject: [PATCH 16/19] clean up
---
.../XeGPUSgToWiDistributeExperimental.cp | 1744 -----------------
1 file changed, 1744 deletions(-)
delete mode 100644 mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cp
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cp
deleted file mode 100644
index e3227c7f5b149..0000000000000
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cp
+++ /dev/null
@@ -1,1744 +0,0 @@
-//===- XeGPUSgToWiDistributeExperimental.cpp - XeGPU SG to WI Pass --------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-#include "mlir/Dialect/GPU/IR/GPUDialect.h"
-#include "mlir/Dialect/Index/IR/IndexDialect.h"
-#include "mlir/Dialect/Math/IR/Math.h"
-#include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "mlir/Dialect/SCF/Transforms/Patterns.h"
-#include "mlir/Dialect/Vector/IR/VectorOps.h"
-#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
-#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
-#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
-#include "mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h"
-#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
-#include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/BuiltinAttributes.h"
-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/MLIRContext.h"
-#include "mlir/IR/Operation.h"
-#include "mlir/IR/Value.h"
-#include "mlir/IR/ValueRange.h"
-#include "mlir/Transforms/DialectConversion.h"
-#include "llvm/ADT/SetVector.h"
-#include "llvm/Support/LogicalResult.h"
-#include "llvm/Support/raw_ostream.h"
-#include <optional>
-
-namespace mlir {
-namespace xegpu {
-#define GEN_PASS_DEF_XEGPUSGTOWIDISTRIBUTEEXPERIMENTAL
-#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
-} // namespace xegpu
-} // namespace mlir
-
-using namespace mlir;
-
-#define DEBUG_TYPE "xegpu-sg-to-wi-distribute-experimental"
-#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
-
-namespace {
-
-/// Casts the given vector value `v` to the expected vector type `expectedTy`.
-static Value castValueTo(ConversionPatternRewriter &rewriter,
- TypedValue<VectorType> v, VectorType expectedTy) {
- // If the type matches, simply return the value itself.
- if (v.getType() == expectedTy)
- return v;
- // If only shape differs, use shape cast.
- if (isa<VectorType>(v.getType()) &&
- v.getType().getNumElements() == expectedTy.getNumElements())
- return vector::ShapeCastOp::create(rewriter, v.getLoc(), expectedTy, v);
-
- // Else create an unrealized cast.
- auto newOp = UnrealizedConversionCastOp::create(rewriter, v.getLoc(),
- expectedTy, ValueRange{v});
- return newOp.getResult(0);
-}
-
-/// A vector::MultiDimReductionOp at subgroup level in expected form if, it has
-/// exactly 1 reduction dimension, it had valid result layout attribute, and
-/// result type can be distributed to lanes using the layout.
-static bool isValidSubgroupMultiReductionOp(vector::MultiDimReductionOp op) {
- auto resLayout = xegpu::getTemporaryLayout(op->getOpResult(0));
- // If no layout, not valid.
- if (!resLayout || !resLayout.isForSubgroup())
- return false;
- // Scalar result (e.g., vector<32xf32> to f32) is valid.
- if (op.getType().isIntOrFloat())
- return op.getReductionDims().size() == 1;
- VectorType resTy = dyn_cast<VectorType>(op.getType());
- if (!resTy)
- return false;
- // Compute the distributed result vector type based on the layout.
- FailureOr<VectorType> resDistTypeOrFailure =
- getDistVecTypeBasedOnLaneLayout(resLayout, resTy);
- if (failed(resDistTypeOrFailure))
- return false;
- return op.getReductionDims().size() == 1;
-}
-
-/// A vector::MultiDimReductionOp is doing lane-local reduction if each workitem
-/// is doing its own local reduction. In this case the result layout ensures
-/// that result vector is distributed to lanes, i.e. the result vector type is
-/// different from the distributed result vector type.
-static bool isReductionLaneLocal(vector::MultiDimReductionOp op) {
- // Must be valid MultiDimReductionOp.
- assert(isValidSubgroupMultiReductionOp(op) && "Expecting a valid subgroup "
- "MultiDimReductionOp");
- auto resLayout = xegpu::getTemporaryLayout(op->getOpResult(0));
- VectorType resTy = dyn_cast<VectorType>(op.getType());
- auto resDistTypeOrFailure = getDistVecTypeBasedOnLaneLayout(resLayout, resTy);
- return resTy != resDistTypeOrFailure.value();
-}
-
-/// Given a vector type and its distributed vector type, return the list of
-/// dimensions that are distributed.
-static SmallVector<int64_t> getDistributedDims(VectorType originalType,
- VectorType distributedType) {
- assert(originalType.getRank() == distributedType.getRank() &&
- "original and distributed vector types must have the same rank");
- SmallVector<int64_t> distributedDims;
- for (int64_t i = 0; i < originalType.getRank(); ++i) {
- if (distributedType.getDimSize(i) != originalType.getDimSize(i))
- distributedDims.push_back(i);
- }
- return distributedDims;
-}
-
-/// Distributes a subgroup-level CreateNdDesc op to workitem-level CreateNdDesc
-/// op. This simply drops the layout attribute from the tensor descriptor type.
-struct SgToWiCreateNdDesc : public OpConversionPattern<xegpu::CreateNdDescOp> {
- using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(xegpu::CreateNdDescOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- xegpu::TensorDescType resultType = op.getType();
- // If no layout, nothing to do.
- if (!resultType.getLayout())
- return failure();
-
- auto newOp = xegpu::CreateNdDescOp::create(
- rewriter, op.getLoc(), resultType.dropLayouts(), op.getOperands(),
- op->getAttrs());
- rewriter.replaceOp(op, newOp.getResult());
- return success();
- }
-};
-
-/// Distributes a subgroup-level LoadNd op to workitem-level LoadNd op. Output
-/// of workitem-level LoadNd op is 1D. ShapeCast is added to restore the
-/// original rank.
-struct SgToWiLoadNd : public OpConversionPattern<xegpu::LoadNdOp> {
- using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(xegpu::LoadNdOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- xegpu::DistributeLayoutAttr layout = op.getAnchorLayout();
- // If no layout, nothing to do.
- if (!layout)
- return failure();
- // Check if the layout attached to the tensor descriptor is same as the
- // anchor layout. Otherwise, this is a conflict.
- if (op.getTensorDescType().getLayout() != layout)
- return rewriter.notifyMatchFailure(
- op, "conflicting layout attributes on tensor descriptor and anchor");
- auto uArch = getUArch(xegpu::getChipStr(op).value_or(""));
- if (!uArch)
- return rewriter.notifyMatchFailure(
- op, "xegpu::LoadNdOp require target attribute attached to "
- "determine transpose "
- "requirement");
- auto supportedWiResultTyOrFailure =
- xegpu::getDistributedVectorType(op.getTensorDescType());
- auto expectedWiResultTyOrFailure =
- xegpu::getDistVecTypeBasedOnLaneLayout(layout, op.getType());
- if (failed(supportedWiResultTyOrFailure))
- return rewriter.notifyMatchFailure(
- op, "unable to compute the workitem vector type for LoadNdOp");
- if (failed(expectedWiResultTyOrFailure))
- return rewriter.notifyMatchFailure(
- op,
- "unable to compute expected workitem vector type from lane layout");
- auto newOp = xegpu::LoadNdOp::create(
- rewriter, op.getLoc(), supportedWiResultTyOrFailure.value(),
- adaptor.getTensorDesc(), op.getMixedOffsets(), op.getPackedAttr(),
- op.getTransposeAttr(), op.getL1HintAttr(), op.getL2HintAttr(),
- op.getL3HintAttr(), /**layout**/ nullptr);
- // Set the packed attribute if the layout requires it.
- newOp.setPacked(xegpu::requirePacked(cast<xegpu::LayoutAttr>(layout)));
- // Set the transpose attribute if the layout requires it.
- if (xegpu::requireTranspose(cast<xegpu::LayoutAttr>(layout), uArch))
- newOp.setTranspose(DenseI64ArrayAttr::get(rewriter.getContext(), {1, 0}));
- rewriter.replaceOp(op, castValueTo(rewriter, newOp.getResult(),
- expectedWiResultTyOrFailure.value()));
- return success();
- }
-};
-
-/// Distributes a subgroup-level StoreNd op to workitem-level StoreNd op. Stored
-/// value in workitem-level StoreNd op is 1D. ShapeCast is added to cast the
-/// incoming value to 1D.
-struct SgToWiStoreNd : public OpConversionPattern<xegpu::StoreNdOp> {
- using OpConversionPattern<xegpu::StoreNdOp>::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(xegpu::StoreNdOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- xegpu::DistributeLayoutAttr layout = op.getAnchorLayout();
- // If no layout, nothing to do.
- if (!layout)
- return failure();
- // Check if the layout attached to the tensor descriptor and value layout is
- // same as the anchor layout. Otherwise, this is a conflict.
- if (op.getTensorDescType().getLayout() != layout)
- return rewriter.notifyMatchFailure(
- op, "conflicting layout attributes on tensor descriptor and anchor");
- auto valueLayout = xegpu::getDistributeLayoutAttr(op->getOpOperand(0));
- if (valueLayout != layout)
- return rewriter.notifyMatchFailure(
- op, "conflicting layout attributes on value and anchor");
- auto supportedWiValueTyOrFailure =
- xegpu::getDistributedVectorType(op.getTensorDescType());
- if (failed(supportedWiValueTyOrFailure))
- return rewriter.notifyMatchFailure(
- op,
- "unable to compute wi vector type for StoreNdOp value from tensor "
- "descriptor");
-
- xegpu::StoreNdOp::create(
- rewriter, op.getLoc(),
- castValueTo(rewriter, cast<TypedValue<VectorType>>(adaptor.getValue()),
- supportedWiValueTyOrFailure.value()),
- adaptor.getTensorDesc(), op.getMixedOffsets(), op.getL1HintAttr(),
- op.getL2HintAttr(), op.getL3HintAttr(), /**layout**/ nullptr);
- rewriter.eraseOp(op);
- return success();
- }
-};
-
-/// Distributes a subgroup-level Dpas op to workitem-level Dpas op. All inpputs
-/// and output of workitem-level Dpas op are 1D. Necessary casts are added to
-/// convert the inputs and output to/from 1D.
-struct SgToWiDpas : public OpConversionPattern<xegpu::DpasOp> {
- using OpConversionPattern<xegpu::DpasOp>::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(xegpu::DpasOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- // llvm::errs() << "DpasOpPattern matchAndRewrite called\n";
- // Check if the op has A, B and CD layouts attached.
- auto layoutA = cast<xegpu::LayoutAttr>(op.getLayoutAAttr());
- auto layoutB = cast<xegpu::LayoutAttr>(op.getLayoutBAttr());
- auto layoutCd = cast<xegpu::LayoutAttr>(op.getLayoutCdAttr());
- if (!layoutA || !layoutB || !layoutCd)
- return failure();
- // llvm::errs() << "tryning to calculate wi types for dpas op\n";
- auto wiResultTyOrFailure =
- xegpu::getDistributedVectorType(op.getType(), layoutCd);
- auto wiATypeOrFailure =
- xegpu::getDistributedVectorType(op.getLhs().getType(), layoutA);
- auto wiBTypeOrFailure =
- xegpu::getDistributedVectorType(op.getRhs().getType(), layoutB);
- auto expectedWiResultTyOrFailure =
- xegpu::getDistVecTypeBasedOnLaneLayout(layoutCd, op.getType());
- if (failed(wiResultTyOrFailure) || failed(wiATypeOrFailure) ||
- failed(wiBTypeOrFailure))
- return rewriter.notifyMatchFailure(
- op, "failed to calculate supported workitem vector types for DpasOp "
- "from layouts");
- if (failed(expectedWiResultTyOrFailure))
- return rewriter.notifyMatchFailure(
- op, "unable to compute expected workitem vector type for DpasOp from "
- "lane layout");
- auto newOp = xegpu::DpasOp::create(
- rewriter, op->getLoc(), wiResultTyOrFailure.value(),
- castValueTo(rewriter, cast<TypedValue<VectorType>>(adaptor.getLhs()),
- wiATypeOrFailure.value()),
- castValueTo(rewriter, cast<TypedValue<VectorType>>(adaptor.getRhs()),
- wiBTypeOrFailure.value()),
- castValueTo(rewriter, cast<TypedValue<VectorType>>(adaptor.getAcc()),
- wiResultTyOrFailure.value()),
- /** layoutA**/ nullptr,
- /** layoutB**/ nullptr, /** layoutCd**/ nullptr);
- // Explicitly set the new types to enable correct type materializations.
- rewriter.replaceOp(op, castValueTo(rewriter, newOp.getResult(),
- expectedWiResultTyOrFailure.value()));
- return success();
- }
-};
-
-/// Distributes elementwise ops to workitem-level elementwise ops. This
-/// currently handles elementwise ops with single result only.
-struct SgToWiElementWise : public ConversionPattern {
- SgToWiElementWise(TypeConverter &typeConverter, MLIRContext *ctx)
- : ConversionPattern(MatchAnyOpTypeTag(), /*benefit=*/1, ctx) {}
-
- LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const override {
- // Only match ops with elementwise trait and single result.
- if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1)
- return failure();
-
- auto resultType = dyn_cast<VectorType>(op->getResult(0).getType());
- if (!resultType)
- return rewriter.notifyMatchFailure(
- op, "operation result is not a vector type");
-
- xegpu::DistributeLayoutAttr layout =
- xegpu::getTemporaryLayout(llvm::cast<OpResult>(op->getResult(0)));
- if (!layout || !layout.isForSubgroup())
- return rewriter.notifyMatchFailure(
- op, "operation result does not have subgroup distribute layout");
-
- auto wiShapeOrFailure =
- xegpu::getDistVecTypeBasedOnLaneLayout(layout, resultType);
-
- if (failed(wiShapeOrFailure))
- return rewriter.notifyMatchFailure(
- op, "unable to compute workitem vector type from the layout");
-
- VectorType newResultType = wiShapeOrFailure.value();
- OperationState state(op->getLoc(), op->getName());
- state.addOperands(operands);
- state.addTypes(newResultType);
- // Copy all attributes except for DistributeLayoutAttr.
- for (auto attr : op->getAttrs()) {
- if (!isa<xegpu::DistributeLayoutAttr>(attr.getValue()))
- state.addAttribute(attr.getName(), attr.getValue());
- }
- Operation *newOp = rewriter.create(state);
-
- rewriter.replaceOp(op, newOp->getResult(0));
- return success();
- }
-};
-
-/// Distributes a subgroup-level arith ConstantOp to workitem-level arith
-/// ConstantOp.
-struct SgToWiArithConstant : public OpConversionPattern<arith::ConstantOp> {
- using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- auto resultType = dyn_cast<VectorType>(op.getType());
- if (!resultType)
- return failure();
-
- // Only handle dense vector constants
- auto dense = dyn_cast<SplatElementsAttr>(op.getValue());
- if (!dense)
- return rewriter.notifyMatchFailure(
- op, "only dense splat vector constants are supported");
-
- xegpu::DistributeLayoutAttr layout =
- xegpu::getTemporaryLayout(llvm::cast<OpResult>(op.getResult()));
- if (!layout || !layout.isForSubgroup())
- return rewriter.notifyMatchFailure(
- op, "operation result does not have subgroup distribute layout");
-
- auto wiShapeOrFailure =
- xegpu::getDistVecTypeBasedOnLaneLayout(layout, resultType);
-
- if (failed(wiShapeOrFailure))
- return rewriter.notifyMatchFailure(
- op, "unable to compute workitem vector type from the layout");
-
- VectorType newResultType = wiShapeOrFailure.value();
- auto sclarValue = dense.getSplatValue<Attribute>();
- auto newDenseAttr = DenseElementsAttr::get(newResultType, sclarValue);
-
- auto newOp = arith::ConstantOp::create(rewriter, op.getLoc(), newResultType,
- newDenseAttr);
- rewriter.replaceOp(op, newOp.getResult());
- return success();
- }
-};
-
-/// Distributes a subgroup-level PrefetchNd op to workitem-level PrefetchNd op.
-struct SgToWiPrefetchNd : public OpConversionPattern<xegpu::PrefetchNdOp> {
- using OpConversionPattern<xegpu::PrefetchNdOp>::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(xegpu::PrefetchNdOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- xegpu::DistributeLayoutAttr layout = op.getAnchorLayout();
- // If no layout, nothing to do.
- if (!layout)
- return failure();
-
- xegpu::PrefetchNdOp::create(rewriter, op.getLoc(), adaptor.getTensorDesc(),
- op.getMixedOffsets(), op.getL1HintAttr(),
- op.getL2HintAttr(), op.getL3HintAttr(),
- /**layout**/ nullptr);
- rewriter.eraseOp(op);
- return success();
- }
-};
-
-/// Distributes a subgroup-level LoadGather (xegpu.load) op to workitem-level.
-///
-/// Example 1 (1D, no chunk size):
-/// layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>
-/// %mask = producer_op : vector<16xi1>
-/// %offset = producer_op : vector<16xindex>
-/// %0 = xegpu.load %src[%offset], %mask : memref<256xf16>,
-/// vector<16xindex>, vector<16xi1> -> vector<16xf16>
-/// Distributed to:
-/// %mask = producer_op : vector<1xi1>
-/// %offset = producer_op : vector<1xindex>
-/// %0 = xegpu.load %src[%offset], %mask : memref<256xf16>,
-/// vector<1xindex>, vector<1xi1> -> vector<1xf16>
-///
-/// Example 2 (2D with chunk size, same mask & offset):
-/// layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>
-/// %0 = xegpu.load %src[%offset], %mask <{chunk_size=8}> :
-/// memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
-/// Distributed to:
-/// %0 = xegpu.load %src[%offset], %mask <{chunk_size=8}> :
-/// memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<8xf16>
-///
-/// Example 3 (3D with leading unit dims):
-/// layout = #xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>
-/// %mask = producer_op : vector<1x1x16xi1>
-/// %offset = producer_op : vector<1x1x16xindex>
-/// %0 = xegpu.load %src[%offset], %mask : memref<256xf16>,
-/// vector<1x1x16xindex>, vector<1x1x16xi1> -> vector<1x1x16xf16>
-/// Distributed to:
-/// %mask = producer_op : vector<1x1x1xi1>
-/// %offset = producer_op : vector<1x1x1xindex>
-/// %0 = xegpu.load %src[%offset], %mask : memref<256xf16>,
-/// vector<1xindex>, vector<1xi1> -> vector<1xf16>
-struct SgToWiLoadGather : public OpConversionPattern<xegpu::LoadGatherOp> {
- using OpConversionPattern<xegpu::LoadGatherOp>::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(xegpu::LoadGatherOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- xegpu::DistributeLayoutAttr layout = op.getAnchorLayout();
- if (!layout)
- return failure();
-
- VectorType origResultTy = op.getValueType();
- if (!origResultTy)
- return failure();
-
- // Check that leading dimensions are unit.
- int chunkSize = op.getChunkSize().value_or(1);
- int effectiveVecRank = (chunkSize == 1) ? 1 : 2;
- ArrayRef<int64_t> shape = origResultTy.getShape();
- if (llvm::any_of(
- shape.take_front(origResultTy.getRank() - effectiveVecRank),
- [](int64_t d) { return d != 1; }))
- return rewriter.notifyMatchFailure(
- op, "Only unit dimensions allowed for the leading "
- "dimensions of the load vector!");
-
- auto distResultTyOrFailure =
- xegpu::getDistVecTypeBasedOnLaneLayout(layout, origResultTy);
- if (failed(distResultTyOrFailure))
- return rewriter.notifyMatchFailure(
- op,
- "unable to compute expected workitem vector type from lane layout");
-
- VectorType distResultTy = distResultTyOrFailure.value();
- VectorType distResultTy1D = VectorType::get({distResultTy.getNumElements()},
- distResultTy.getElementType());
-
- // Flatten offsets and mask to 1D to match the 1D result type.
- Value distOffsets = adaptor.getOffsets();
- auto distOffsetsTy = cast<VectorType>(distOffsets.getType());
- VectorType offsetsTy1D = VectorType::get({distOffsetsTy.getNumElements()},
- distOffsetsTy.getElementType());
- distOffsets = castValueTo(
- rewriter, cast<TypedValue<VectorType>>(distOffsets), offsetsTy1D);
-
- Value distMask = adaptor.getMask();
- auto distMaskTy = cast<VectorType>(distMask.getType());
- VectorType maskTy1D = VectorType::get({distMaskTy.getNumElements()},
- distMaskTy.getElementType());
- distMask =
- castValueTo(rewriter, cast<TypedValue<VectorType>>(distMask), maskTy1D);
-
- Value distSource = adaptor.getSource();
- auto newOp = xegpu::LoadGatherOp::create(
- rewriter, op.getLoc(), distResultTy1D, distSource, distOffsets,
- distMask, op.getChunkSizeAttr(), op.getL1HintAttr(), op.getL2HintAttr(),
- op.getL3HintAttr(), /*layout=*/nullptr);
-
- Value result = newOp->getResult(0);
- if (distResultTy1D != distResultTy)
- result = castValueTo(rewriter, cast<TypedValue<VectorType>>(result),
- distResultTy);
- rewriter.replaceOp(op, result);
- return success();
- }
-};
-
-/// This pattern distributes a subgroup-level vector.reduction op to
-/// workitem-level. This require shuffling the data across the workitems (using
-/// gpu::ShuffleOp) and reducing in stages until all workitems have the final
-/// result.
-struct SgToWiVectorReduction : public OpConversionPattern<vector::ReductionOp> {
- using OpConversionPattern<vector::ReductionOp>::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(vector::ReductionOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- auto layout = xegpu::getDistributeLayoutAttr(op.getVector());
-
- // If no layout, nothing to do.
- if (!layout || !layout.isForSubgroup())
- return failure();
-
- VectorType srcVecType = op.getSourceVectorType();
- // Only rank 1 vectors supported.
- if (srcVecType.getRank() != 1)
- return rewriter.notifyMatchFailure(
- op, "Only rank 1 reductions can be distributed.");
- // Lane layout must have the same rank as the vector.
- if (layout.getRank() != srcVecType.getRank())
- return rewriter.notifyMatchFailure(
- op, "Layout rank does not match vector rank.");
-
- // Get the subgroup size from the layout.
- int64_t sgSize = layout.getEffectiveLaneLayoutAsInt()[0];
- const uArch *uArch = getUArch(xegpu::getChipStr(op).value_or(""));
- if (!uArch)
- return rewriter.notifyMatchFailure(
- op, "xegpu::ReductionOp require target attribute attached to "
- "determine subgroup size");
-
- // Only subgroup-sized vectors supported.
- if (sgSize != uArch->getSubgroupSize() ||
- srcVecType.getShape()[0] % sgSize != 0)
- return rewriter.notifyMatchFailure(op,
- "Invalid layout or reduction vector "
- "dimension must match subgroup size.");
-
- if (!op.getType().isIntOrFloat())
- return rewriter.notifyMatchFailure(
- op, "Reduction distribution currently only supports floats and "
- "integer types.");
-
- // Get the distributed vector (per work-item portion).
- Value laneValVec = adaptor.getVector();
-
- // Distribute and reduce across work-items in the subgroup.
- Value fullReduce = xegpu::subgroupReduction(
- op.getLoc(), rewriter, laneValVec, op.getKind(), sgSize);
-
- // If there's an accumulator, combine it with the reduced value.
- if (adaptor.getAcc())
- fullReduce = vector::makeArithReduction(
- rewriter, op.getLoc(), op.getKind(), fullReduce, adaptor.getAcc());
-
- rewriter.replaceOp(op, fullReduce);
- return success();
- }
-};
-
-/// This pattern distributes a subgroup-level vector.multi_reduction op to
-/// workitem-level only if the reduction is lane-local. This means that
-/// reduction dimension is not distributed to lanes and each lane does its own
-/// local reduction.
-struct SgToWiMultiDimReduction
- : public OpConversionPattern<vector::MultiDimReductionOp> {
- using OpConversionPattern<vector::MultiDimReductionOp>::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(vector::MultiDimReductionOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- Value result;
- ArrayRef<int64_t> reductionDims = op.getReductionDims();
- assert(reductionDims.size() == 1 &&
- "Expecting single reduction dimension for subgroup multi "
- "reduction op");
- // For rank > 2, ensure leading dimensions are unit.
- VectorType sourceType = op.getSourceVectorType();
- int64_t rank = sourceType.getRank();
- if (rank > 2) {
- ArrayRef<int64_t> shape = sourceType.getShape();
- if (llvm::any_of(shape.take_front(rank - 2),
- [](int64_t d) { return d != 1; }))
- return rewriter.notifyMatchFailure(
- op, "only unit leading dimensions are supported for "
- "multi_reduction with rank > 2");
- }
- // Handle scalar result: full reduction of a distributed vector to a
- // scalar. First do a local vector reduction, then cross-lane shuffles.
- if (op.getType().isIntOrFloat()) {
- auto reductionDim = reductionDims[0];
- VectorType origSourceType = op.getSourceVectorType();
- int64_t reductionDimSize = origSourceType.getShape()[reductionDim];
- // Local reduction to scalar, then cross-lane butterfly shuffles.
- result =
- xegpu::subgroupReduction(op.getLoc(), rewriter, adaptor.getSource(),
- op.getKind(), reductionDimSize);
- // Combine with accumulator if present.
- if (adaptor.getAcc())
- result = vector::makeArithReduction(rewriter, op.getLoc(), op.getKind(),
- result, adaptor.getAcc());
- } else if (isReductionLaneLocal(op)) {
- auto resLayout = xegpu::getTemporaryLayout(op->getOpResult(0));
- VectorType resVecTy = dyn_cast<VectorType>(op.getType());
- auto resDistVecTyOrFailure =
- getDistVecTypeBasedOnLaneLayout(resLayout, resVecTy);
- // For lane local reduction, simply create a new MultiDimReductionOp using
- // adaptor operands and the new result type.
- result = vector::MultiDimReductionOp::create(
- rewriter, op.getLoc(), resDistVecTyOrFailure.value(), op.getKind(),
- adaptor.getSource(), adaptor.getAcc(), op.getReductionDims());
- } else {
- auto reductionDim = reductionDims[0];
- VectorType sourceType = op.getSourceVectorType();
- int64_t reductionDimSize = sourceType.getShape()[reductionDim];
- result = xegpu::lowerCrossLaneReductionToShuffles(
- cast<TypedValue<VectorType>>(adaptor.getSource()),
- cast<TypedValue<VectorType>>(adaptor.getAcc()), op.getKind(),
- reductionDim, reductionDimSize, op.getLoc(), rewriter);
- }
- rewriter.replaceOp(op, result);
- return success();
- }
-};
-
-/// Helper to compute distributed coordinates for matrix ops.
-/// When not using subgroup_block_io, each workitem computes its own
-/// coordinates based on the layout and lane ID.
-static SmallVector<Value> computeDistributedCoordsForMatrixOp(
- ConversionPatternRewriter &rewriter, Location loc,
- xegpu::DistributeLayoutAttr layout, ArrayRef<int64_t> payloadShape,
- ValueRange origOffsets) {
- Value laneId = gpu::LaneIdOp::create(rewriter, loc, rewriter.getIndexType(),
- /*upperBound=*/mlir::IntegerAttr());
- auto maybeCoords =
- layout.computeDistributedCoords(rewriter, loc, laneId, payloadShape);
- if (failed(maybeCoords))
- return {};
- assert(maybeCoords.value().size() == 1 &&
- "Expected one set of distributed offsets");
- SmallVector<OpFoldResult> ofrVec = xegpu::addWithRightAligned(
- rewriter, loc, getAsOpFoldResult(maybeCoords.value()[0]),
- getAsOpFoldResult(origOffsets));
- return llvm::map_to_vector(ofrVec, llvm::CastTo<Value>);
-}
-
-/// This pattern distributes a subgroup-level LoadMatrix op to workitem-level.
-struct SgToWiLoadMatrix : public OpConversionPattern<xegpu::LoadMatrixOp> {
- using OpConversionPattern<xegpu::LoadMatrixOp>::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(xegpu::LoadMatrixOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- auto layout = op.getLayoutAttr();
- // If no layout, nothing to do.
- if (!layout)
- return failure();
-
- VectorType sgPayloadTy = dyn_cast<VectorType>(op.getResult().getType());
- if (!sgPayloadTy)
- return rewriter.notifyMatchFailure(
- op, "the matrix op payload must be a vector type");
-
- auto loc = op.getLoc();
- auto offsets = op.getMixedOffsets();
- if (offsets.empty())
- return rewriter.notifyMatchFailure(op, "the load op must have offsets");
-
- FailureOr<VectorType> distPayloadTyOrFailure =
- getDistVecTypeBasedOnLaneLayout(layout, sgPayloadTy);
- if (failed(distPayloadTyOrFailure))
- return rewriter.notifyMatchFailure(
- op, "Failed to distribute matrix op payload based on layout.");
-
- SmallVector<Value> offsetsAsValues =
- vector::getAsValues(rewriter, loc, offsets);
-
- SmallVector<Value> newCoords = offsetsAsValues;
- if (!op.getSubgroupBlockIoAttr()) {
- newCoords = computeDistributedCoordsForMatrixOp(
- rewriter, loc, layout, sgPayloadTy.getShape(), offsetsAsValues);
- if (newCoords.empty())
- return rewriter.notifyMatchFailure(
- op, "Failed to compute distributed coordinates.");
- }
-
- SmallVector<int64_t> newConstOffsets(op.getConstOffsets().size(),
- ShapedType::kDynamic);
- DenseI64ArrayAttr newConstOffsetsAttr =
- rewriter.getDenseI64ArrayAttr(newConstOffsets);
-
- auto newOp = xegpu::LoadMatrixOp::create(
- rewriter, loc, *distPayloadTyOrFailure, adaptor.getMemDesc(),
- ValueRange(newCoords), newConstOffsetsAttr, op.getSubgroupBlockIoAttr(),
- xegpu::DistributeLayoutAttr{});
- rewriter.replaceOp(op, newOp.getResult());
- return success();
- }
-};
-
-/// Distributes a subgroup-level vector.transpose op to workitem-level.
-struct SgToWiVectorTranspose : public OpConversionPattern<vector::TransposeOp> {
- using OpConversionPattern<vector::TransposeOp>::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(vector::TransposeOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- xegpu::DistributeLayoutAttr sourceLayout =
- xegpu::getTemporaryLayout(op->getOpOperand(0));
- xegpu::DistributeLayoutAttr resultLayout =
- xegpu::getTemporaryLayout(op->getOpResult(0));
- if (!sourceLayout || !resultLayout)
- return rewriter.notifyMatchFailure(
- op, "the source or result vector of the transpose op lacks layout "
- "attribute");
- ArrayRef<int64_t> perm = op.getPermutation();
- // Result layout must be a transpose of source layout.
- if (!resultLayout.isTransposeOf(sourceLayout, perm,
- xegpu::LayoutKind::Lane))
- return rewriter.notifyMatchFailure(
- op, "the source or result vector layouts must be transposes of "
- "each other");
- FailureOr<VectorType> distributedResultTypeOrFailure =
- getDistVecTypeBasedOnLaneLayout(resultLayout, op.getResultVectorType());
- if (failed(distributedResultTypeOrFailure))
- return rewriter.notifyMatchFailure(
- op, "Failed to distribute the result vector type in "
- "vector::Transpose op");
- auto newOp = vector::TransposeOp::create(rewriter, op.getLoc(),
- adaptor.getVector(), perm);
- rewriter.replaceOp(op, castValueTo(rewriter, newOp.getResult(),
- distributedResultTypeOrFailure.value()));
- return success();
- }
-};
-
-/// Distributes a subgroup-level vector.bitcast op to workitem-level.
-/// Bitcast only impacts the innermost dimension of the source/result vectors.
-struct SgToWiVectorBitcast : public OpConversionPattern<vector::BitCastOp> {
- using OpConversionPattern<vector::BitCastOp>::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(vector::BitCastOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- xegpu::DistributeLayoutAttr resultLayout =
- xegpu::getTemporaryLayout(op->getOpResult(0));
- if (!resultLayout)
- return rewriter.notifyMatchFailure(
- op, "result vector of the bitcast op lacks layout attribute");
- FailureOr<VectorType> distributedResultTypeOrFailure =
- getDistVecTypeBasedOnLaneLayout(resultLayout, op.getResultVectorType());
- if (failed(distributedResultTypeOrFailure))
- return rewriter.notifyMatchFailure(
- op, "Failed to distribute the result vector type in "
- "vector::BitCast op");
- auto newOp = vector::BitCastOp::create(
- rewriter, op.getLoc(), distributedResultTypeOrFailure.value(),
- adaptor.getSource());
- rewriter.replaceOp(op, newOp.getResult());
- return success();
- }
-};
-
-/// Distributes a subgroup-level vector.create_mask or vector.constant_mask op
-/// to workitem-level. Uses `computeDistributedCoords()` to obtain the
-/// coordinates each workitem owns, then compares each coordinate against the
-/// original mask bounds using `arith.cmpi slt`. The per-element boolean
-/// results are assembled into the distributed mask vector.
-///
-/// For multi-dimensional masks, the element is in-bounds when ALL dimensions
-/// satisfy `coord[i] < bound[i]`.
-///
-/// Example (1D):
-/// layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>
-/// %mask = vector.create_mask %m0 : vector<16xi1>
-/// For lane k, computeDistributedCoords gives coord = [k], so:
-/// %in_bounds = arith.cmpi slt, %coord, %m0 → i1
-/// %mask = vector.broadcast %in_bounds : i1 to vector<1xi1>
-///
-/// Example (2D):
-/// layout = #xegpu.layout<lane_layout = [8, 2], lane_data = [1, 1]>
-/// %mask = vector.create_mask %m0, %m1 : vector<8x4xi1>
-/// Each WI owns a 1x2 slice. computeDistributedCoords returns 2 coords:
-/// [[r0, c0], [r0, c1]]
-/// For each coord: in_bounds = (r < m0) && (c < m1)
-/// %mask = vector.from_elements %bit0, %bit1 : vector<1x2xi1>
-template <typename OpType,
- typename = std::enable_if_t<llvm::is_one_of<
- OpType, vector::CreateMaskOp, vector::ConstantMaskOp>::value>>
-struct SgToWiCreateMask : public OpConversionPattern<OpType> {
- using OpConversionPattern<OpType>::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- xegpu::DistributeLayoutAttr layout =
- xegpu::getTemporaryLayout(op->getOpResult(0));
- if (!layout || !layout.isForSubgroup())
- return rewriter.notifyMatchFailure(
- op, "operation result does not have subgroup distribute layout");
-
- VectorType origType = op.getType();
- FailureOr<VectorType> distTypeOrFailure =
- getDistVecTypeBasedOnLaneLayout(layout, origType);
- if (failed(distTypeOrFailure))
- return rewriter.notifyMatchFailure(
- op, "unable to compute workitem vector type from the layout");
-
- VectorType distType = distTypeOrFailure.value();
- Location loc = op.getLoc();
-
- // Materialize the original mask bounds as Values.
- SmallVector<Value> origBounds;
- if constexpr (std::is_same_v<OpType, vector::CreateMaskOp>) {
- origBounds.append(op.getOperands().begin(), op.getOperands().end());
- } else {
- auto dimSizes = op.getMaskDimSizesAttr().asArrayRef();
- for (auto dimSize : dimSizes)
- origBounds.push_back(
- arith::ConstantIndexOp::create(rewriter, loc, dimSize).getResult());
- }
-
- ArrayRef<int64_t> origShape = origType.getShape();
-
- // Use computeDistributedCoords to get the coordinates each WI owns.
- Value laneId = gpu::LaneIdOp::create(rewriter, loc, rewriter.getIndexType(),
- /*upperBound=*/mlir::IntegerAttr());
- auto maybeCoordsVec =
- layout.computeDistributedCoords(rewriter, loc, laneId, origShape);
- if (failed(maybeCoordsVec))
- return rewriter.notifyMatchFailure(
- op, "failed to compute distributed coordinates from layout");
-
- SmallVector<SmallVector<Value>> coordsVec = maybeCoordsVec.value();
- int64_t numElements = distType.getNumElements();
- assert(static_cast<int64_t>(coordsVec.size()) == numElements &&
- "number of coordinate sets must match number of distributed "
- "elements");
-
- // For each element, compare all coordinates against bounds.
- Value trueVal =
- arith::ConstantIntOp::create(rewriter, loc, /*value=*/1, /*width=*/1);
- SmallVector<Value> maskBits;
- for (auto &coords : coordsVec) {
- Value inBounds = trueVal;
- for (size_t i = 0; i < coords.size(); ++i) {
- Value cmp = arith::CmpIOp::create(
- rewriter, loc, arith::CmpIPredicate::slt, coords[i], origBounds[i]);
- inBounds = arith::AndIOp::create(rewriter, loc, inBounds, cmp);
- }
- maskBits.push_back(inBounds);
- }
-
- // Build the distributed mask vector.
- Value result;
- if (numElements == 1) {
- result =
- vector::BroadcastOp::create(rewriter, loc, distType, maskBits[0]);
- } else {
- result =
- vector::FromElementsOp::create(rewriter, loc, distType, maskBits);
- }
- rewriter.replaceOp(op, result);
- return success();
- }
-};
-
-/// This pattern distributes a subgroup-level StoreMatrix op to workitem-level.
-struct SgToWiStoreMatrix : public OpConversionPattern<xegpu::StoreMatrixOp> {
- using OpConversionPattern<xegpu::StoreMatrixOp>::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(xegpu::StoreMatrixOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- auto layout = op.getLayoutAttr();
- // If no layout, nothing to do.
- if (!layout)
- return failure();
-
- VectorType sgPayloadTy = dyn_cast<VectorType>(op.getData().getType());
- if (!sgPayloadTy)
- return rewriter.notifyMatchFailure(
- op, "the matrix op payload must be a vector type");
-
- auto loc = op.getLoc();
- auto offsets = op.getMixedOffsets();
- if (offsets.empty())
- return rewriter.notifyMatchFailure(op, "the store op must have offsets");
-
- FailureOr<VectorType> distPayloadTyOrFailure =
- getDistVecTypeBasedOnLaneLayout(layout, sgPayloadTy);
- if (failed(distPayloadTyOrFailure))
- return rewriter.notifyMatchFailure(
- op, "Failed to distribute matrix op payload based on layout.");
-
- SmallVector<Value> offsetsAsValues =
- vector::getAsValues(rewriter, loc, offsets);
-
- SmallVector<Value> newCoords = offsetsAsValues;
- if (!op.getSubgroupBlockIoAttr()) {
- newCoords = computeDistributedCoordsForMatrixOp(
- rewriter, loc, layout, sgPayloadTy.getShape(), offsetsAsValues);
- if (newCoords.empty())
- return rewriter.notifyMatchFailure(
- op, "Failed to compute distributed coordinates.");
- }
-
- SmallVector<int64_t> newConstOffsets(op.getConstOffsets().size(),
- ShapedType::kDynamic);
- DenseI64ArrayAttr newConstOffsetsAttr =
- rewriter.getDenseI64ArrayAttr(newConstOffsets);
-
- xegpu::StoreMatrixOp::create(
- rewriter, loc, TypeRange{},
- castValueTo(rewriter, cast<TypedValue<VectorType>>(adaptor.getData()),
- distPayloadTyOrFailure.value()),
- adaptor.getMemDesc(), ValueRange(newCoords), newConstOffsetsAttr,
- op.getSubgroupBlockIoAttr(), xegpu::DistributeLayoutAttr{});
- rewriter.eraseOp(op);
- return success();
- }
-};
-
-/// Distributes a subgroup-level StoreScatter (xegpu.store) op to
-/// workitem-level.
-///
-/// Example 1 (1D, no chunk size):
-/// layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>
-/// %mask = producer_op : vector<16xi1>
-/// %offset = producer_op : vector<16xindex>
-/// xegpu.store %payload, %src[%offset], %mask : vector<16xf16>,
-/// memref<256xf16>, vector<16xindex>, vector<16xi1>
-/// Distributed to:
-/// %mask = producer_op : vector<1xi1>
-/// %offset = producer_op : vector<1xindex>
-/// xegpu.store %payload, %src[%offset], %mask : vector<1xf16>,
-/// memref<256xf16>, vector<1xindex>, vector<1xi1>
-///
-/// Example 2 (2D with chunk size, same mask & offset):
-/// layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>
-/// xegpu.store %payload, %src[%offset], %mask <{chunk_size=8}> :
-/// vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
-/// Distributed to:
-/// xegpu.store %payload, %src[%offset], %mask <{chunk_size=8}> :
-/// vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
-///
-/// Example 3 (3D with leading unit dims):
-/// layout = #xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>
-/// %mask = producer_op : vector<1x1x16xi1>
-/// %offset = producer_op : vector<1x1x16xindex>
-/// xegpu.store %payload, %src[%offset], %mask : vector<1x1x16xf16>,
-/// memref<256xf16>, vector<1x1x16xindex>, vector<1x1x16xi1>
-/// Distributed to:
-/// %mask = producer_op : vector<1x1x1xi1>
-/// %offset = producer_op : vector<1x1x1xindex>
-/// xegpu.store %payload, %src[%offset], %mask : vector<1xf16>,
-/// memref<256xf16>, vector<1xindex>, vector<1xi1>
-struct SgToWiStoreScatter : public OpConversionPattern<xegpu::StoreScatterOp> {
- using OpConversionPattern<xegpu::StoreScatterOp>::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(xegpu::StoreScatterOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- xegpu::DistributeLayoutAttr layout = op.getAnchorLayout();
- if (!layout)
- return failure();
-
- VectorType origValueTy = op.getValueType();
- if (!origValueTy)
- return failure();
-
- // Check that all leading dimensions are unit dimensions.
- int chunkSize = op.getChunkSize().value_or(1);
- int effectiveVecRank = (chunkSize == 1) ? 1 : 2;
- ArrayRef<int64_t> shape = origValueTy.getShape();
- if (llvm::any_of(shape.take_front(origValueTy.getRank() - effectiveVecRank),
- [](int64_t d) { return d != 1; }))
- return rewriter.notifyMatchFailure(
- op, "Only unit dimensions allowed for the leading "
- "dimensions of the store vector!");
-
- auto distValueTyOrFailure =
- xegpu::getDistVecTypeBasedOnLaneLayout(layout, origValueTy);
- if (failed(distValueTyOrFailure))
- return rewriter.notifyMatchFailure(
- op,
- "unable to compute expected workitem vector type from lane layout");
-
- VectorType distValueTy = distValueTyOrFailure.value();
- VectorType distValueTy1D = VectorType::get({distValueTy.getNumElements()},
- distValueTy.getElementType());
-
- Value distValue = adaptor.getValue();
- if (distValue.getType() != distValueTy1D)
- distValue = castValueTo(rewriter, cast<TypedValue<VectorType>>(distValue),
- distValueTy1D);
-
- // Flatten offsets and mask to 1D to match the 1D value type.
- Value distOffsets = adaptor.getOffsets();
- auto distOffsetsTy = cast<VectorType>(distOffsets.getType());
- VectorType offsetsTy1D = VectorType::get({distOffsetsTy.getNumElements()},
- distOffsetsTy.getElementType());
- distOffsets = castValueTo(
- rewriter, cast<TypedValue<VectorType>>(distOffsets), offsetsTy1D);
-
- Value distMask = adaptor.getMask();
- auto distMaskTy = cast<VectorType>(distMask.getType());
- VectorType maskTy1D = VectorType::get({distMaskTy.getNumElements()},
- distMaskTy.getElementType());
- distMask =
- castValueTo(rewriter, cast<TypedValue<VectorType>>(distMask), maskTy1D);
-
- Value distDest = adaptor.getDest();
- xegpu::StoreScatterOp::create(rewriter, op.getLoc(), distValue, distDest,
- distOffsets, distMask, op.getChunkSizeAttr(),
- op.getL1HintAttr(), op.getL2HintAttr(),
- op.getL3HintAttr(), /*layout=*/nullptr);
- rewriter.eraseOp(op);
- return success();
- }
-};
-
-/// Distribute a vector::StepOp to workitem-level.
-/// The layout must have exactly 1 effective lane dimension.
-/// We completely resolve the vector::StepOp by computing the lane_data-sized
-/// subranges.
-struct SgToWiVectorStep : public OpConversionPattern<vector::StepOp> {
- using OpConversionPattern<vector::StepOp>::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(vector::StepOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- xegpu::DistributeLayoutAttr resultLayout =
- xegpu::getTemporaryLayout(op->getResult(0));
- if (!resultLayout || !resultLayout.isForSubgroup())
- return rewriter.notifyMatchFailure(
- op, "the result vector of the step op lacks subgroup layout");
-
- auto loc = op.getLoc();
- auto stepResultVecTy = op.getResult().getType();
- auto wiShapeOrFailure =
- xegpu::getDistVecTypeBasedOnLaneLayout(resultLayout, stepResultVecTy);
- if (failed(wiShapeOrFailure))
- return rewriter.notifyMatchFailure(
- op, "unable to compute workitem vector type from the layout");
- VectorType newVecTy = wiShapeOrFailure.value();
-
- Value laneId = gpu::LaneIdOp::create(rewriter, loc, rewriter.getIndexType(),
- /*upperBound=*/mlir::IntegerAttr());
- auto laneDataBlockCoords = resultLayout.computeDistributedCoords(
- rewriter, loc, laneId, stepResultVecTy.getShape());
- if (failed(laneDataBlockCoords))
- return rewriter.notifyMatchFailure(
- op, "failed to compute lane data block coordinates");
-
- auto laneDataBlockCoordsVec = laneDataBlockCoords.value();
- auto laneDataBlockLength = resultLayout.getEffectiveLaneDataAsInt()[0];
- assert(static_cast<int64_t>(laneDataBlockCoordsVec.size()) ==
- newVecTy.getNumElements() / laneDataBlockLength);
- SmallVector<Value> stepVals;
- // For each lane_data block, reconstruct its sub-range
- // from the range of SG-level vector.step.Example: vector.step
- // {slice<layout<lane_layout=[2,4,2], lane_data=[1,2,1]>, dims=[0,2]>} :
- // vector<16xindex>
- // Each logical lane holds 4 elements as 2 blocks of 2 elements each.
- // The blocks are round-robin distributed, so logical lane id 0
- // holds values [0,1, 8,9].
- for (auto &laneDataBlockCoords : laneDataBlockCoordsVec) {
- auto laneDataBlockStartCoord = laneDataBlockCoords[0];
- stepVals.push_back(laneDataBlockStartCoord);
- for (int i = 1; i < laneDataBlockLength; ++i) {
- auto offset = arith::ConstantIndexOp::create(rewriter, loc, i);
- stepVals.push_back(arith::AddIOp::create(
- rewriter, loc, laneDataBlockStartCoord, offset));
- }
- }
- assert(static_cast<int64_t>(stepVals.size()) == newVecTy.getNumElements() &&
- "Expecting the number of step values to match the number of "
- "elements in the vector");
- auto stepOpVal =
- vector::FromElementsOp::create(rewriter, loc, newVecTy, stepVals);
- rewriter.replaceOp(op, stepOpVal);
- return success();
- }
-};
-
-/// Distributes a subgroup-level vector.extract op to workitem-level. Only
-/// handles sub-vector extraction (result is VectorType, not scalar).
-struct SgToWiVectorExtract : public OpConversionPattern<vector::ExtractOp> {
- using OpConversionPattern<vector::ExtractOp>::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(vector::ExtractOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- // Only handle vector results (not scalar extraction).
- auto resultType = dyn_cast<VectorType>(op.getType());
- if (!resultType)
- return rewriter.notifyMatchFailure(op, "scalar extract not supported");
-
- xegpu::DistributeLayoutAttr layout =
- xegpu::getTemporaryLayout(op->getOpResult(0));
- if (!layout || !layout.isForSubgroup())
- return failure();
-
- // This implementation assumes distribution only happens on the innermost
- // dimension. Verify that lane_layout[0...n-2] are all unit.
- auto laneLayout = layout.getEffectiveLaneLayoutAsInt();
- if (llvm::any_of(ArrayRef<int64_t>(laneLayout).drop_back(1),
- [](int64_t v) { return v != 1; }))
- return rewriter.notifyMatchFailure(
- op, "only innermost dimension distribution is supported for "
- "vector.extract");
-
- auto newOp = vector::ExtractOp::create(
- rewriter, op.getLoc(), adaptor.getSource(), op.getMixedPosition());
- rewriter.replaceOp(op, newOp.getResult());
- return success();
- }
-};
-
-/// This pattern distributes a subgroup-level ShapeCast op to workitem-level.
-struct SgToWiVectorShapeCast : public OpConversionPattern<vector::ShapeCastOp> {
- using OpConversionPattern<vector::ShapeCastOp>::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(vector::ShapeCastOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- xegpu::DistributeLayoutAttr resultLayout =
- xegpu::getTemporaryLayout(op->getOpResult(0));
- if (!resultLayout || !resultLayout.isForSubgroup())
- return rewriter.notifyMatchFailure(
- op, "the result vector of the shape_cast op lacks subgroup layout");
-
- auto resultDistTypeOrFailure = xegpu::getDistVecTypeBasedOnLaneLayout(
- resultLayout, op.getResultVectorType());
- if (failed(resultDistTypeOrFailure))
- return rewriter.notifyMatchFailure(
- op, "failed to get distributed vector type for result");
-
- Value source = adaptor.getSource();
- auto newShapeCast = vector::ShapeCastOp::create(
- rewriter, op.getLoc(), resultDistTypeOrFailure.value(), source);
- rewriter.replaceOp(op, newShapeCast);
- return success();
- }
-};
-
-/// Distributes a subgroup-level vector.extract_strided_slice op to
-/// workitem-level. If the result is distributed, the offsets and sizes are
-/// adjusted to match the distributed types.
-struct SgToWiVectorExtractStridedSlice
- : public OpConversionPattern<vector::ExtractStridedSliceOp> {
- using OpConversionPattern<vector::ExtractStridedSliceOp>::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(vector::ExtractStridedSliceOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- xegpu::DistributeLayoutAttr resultLayout =
- xegpu::getTemporaryLayout(op->getOpResult(0));
- if (!resultLayout || !resultLayout.isForSubgroup())
- return failure();
-
- VectorType resultType = op.getType();
- auto distResultTyOrFailure =
- xegpu::getDistVecTypeBasedOnLaneLayout(resultLayout, resultType);
- if (failed(distResultTyOrFailure))
- return rewriter.notifyMatchFailure(
- op, "unable to compute distributed vector type from lane layout");
- VectorType distResultTy = *distResultTyOrFailure;
-
- SmallVector<int64_t> distributedDims =
- getDistributedDims(resultType, distResultTy);
-
- // Collect updated sizes, offsets, strides. Pad to full source rank.
- int64_t sourceRank = op.getSourceVectorType().getRank();
- SmallVector<Attribute> updatedSizes =
- llvm::map_to_vector(op.getSizes(), [](Attribute attr) { return attr; });
- SmallVector<Attribute> updatedOffsets = llvm::map_to_vector(
- op.getOffsets(), [](Attribute attr) { return attr; });
- SmallVector<Attribute> updatedStrides = llvm::map_to_vector(
- op.getStrides(), [](Attribute attr) { return attr; });
- for (int64_t i = op.getSizes().size(); i < sourceRank; ++i) {
- updatedSizes.push_back(
- rewriter.getI64IntegerAttr(op.getSourceVectorType().getDimSize(i)));
- updatedOffsets.push_back(rewriter.getI64IntegerAttr(0));
- updatedStrides.push_back(rewriter.getI64IntegerAttr(1));
- }
-
- // If the result is distributed, adjust offsets and sizes in the
- // distributed dimension.
- if (!distributedDims.empty()) {
- if (distributedDims.size() != 1)
- return rewriter.notifyMatchFailure(
- op, "only single dimension distribution is supported");
- int64_t distDim = distributedDims[0];
- const uArch *uArch = getUArch(xegpu::getChipStr(op).value_or(""));
- if (!uArch)
- return rewriter.notifyMatchFailure(
- op, "target attribute required to determine subgroup size");
- int subgroupSize = uArch->getSubgroupSize();
- auto sourceLayout = xegpu::getTemporaryLayout(op->getOpOperand(0));
- if (!sourceLayout || sourceLayout.getEffectiveLaneLayoutAsInt().empty())
- return rewriter.notifyMatchFailure(
- op, "source of extract_strided_slice lacks distribution layout");
- int sourceDistrDimSize = op.getSourceVectorType().getShape()[distDim];
- if (sourceDistrDimSize % subgroupSize != 0)
- return rewriter.notifyMatchFailure(
- op, "source size along distributed dim is not a multiple of "
- "subgroup size");
- auto sourceLaneData = sourceLayout.getEffectiveLaneDataAsInt();
- // Only check lane_data for the distributed dimension. Non-distributed
- // dimensions may have non-unit lane_data (e.g., packed layouts).
- if (distDim < static_cast<int64_t>(sourceLaneData.size()) &&
- sourceLaneData[distDim] != 1)
- return rewriter.notifyMatchFailure(
- op, "expecting unit lane data along the distributed dimension");
- int64_t distrDimOffset =
- cast<IntegerAttr>(updatedOffsets[distDim]).getInt();
- if (distrDimOffset % subgroupSize != 0)
- return rewriter.notifyMatchFailure(
- op, "offset along distributed dim is not a multiple of "
- "subgroup size");
- // Adjust sizes and offsets for the distributed dimension.
- updatedSizes[distDim] =
- rewriter.getI64IntegerAttr(distResultTy.getDimSize(distDim));
- updatedOffsets[distDim] =
- rewriter.getI64IntegerAttr(distrDimOffset / subgroupSize);
- }
-
- auto newOp = vector::ExtractStridedSliceOp::create(
- rewriter, op.getLoc(), distResultTy, adaptor.getSource(),
- ArrayAttr::get(rewriter.getContext(), updatedOffsets),
- ArrayAttr::get(rewriter.getContext(), updatedSizes),
- ArrayAttr::get(rewriter.getContext(), updatedStrides));
- rewriter.replaceOp(op, newOp.getResult());
- return success();
- }
-};
-
-/// This pattern distributes a subgroup-level `vector.broadcast` op to
-/// workitem-level. The pattern supports three cases:
-///
-/// 1) Broadcast a low-rank vector to high-rank vector: The low-rank input
-/// vector must have a slice layout of the result. If the distributed source
-/// and target vector types are identical, this lowers to a no-op; otherwise,
-/// it remains a broadcast but operates on distributed vectors.
-///
-/// 2) Broadcast a same-rank vector with identical layouts for source and
-/// target: The source vector must have unit dimensions, and lane_data must
-/// be unit size for those unit dims. This always lowers to a no-op.
-///
-/// 3) Broadcast a scalar with no layout: This always lowers to a broadcast
-/// from scalar to distributed result type.
-///
-/// Example 1 (low-rank to high-rank broadcast):
-/// ```
-/// %0 = "some_op"() {layout_result_0 =
-/// #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
-/// dims = [0]>} : () -> vector<16xf16>
-/// %1 = vector.broadcast %0 {layout_result_0 =
-/// #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
-/// : vector<16xf16> to vector<16x16xf16>
-/// ```
-/// is distributed to:
-/// ```
-/// %0 = "some_op"() : () -> vector<1xf16>
-/// %1 = vector.broadcast %0 : vector<1xf16> to vector<16x1xf16>
-/// ```
-///
-/// Example 2 (same-rank broadcast, no-op):
-/// ```
-/// %0 = "some_op"() {layout_result_0 =
-/// #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
-/// : () -> vector<16x1xf16>
-/// %1 = vector.broadcast %0 {layout_result_0 =
-/// #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
-/// : vector<16x1xf16> to vector<16x16xf16>
-/// ```
-/// is distributed to (no-op, source already matches distributed result type):
-/// ```
-/// %0 = "some_op"() : () -> vector<16x1xf16>
-/// // broadcast is eliminated, %0 is used directly
-/// ```
-///
-/// Example 3 (scalar to vector broadcast):
-/// ```
-/// %0 = "some_op"() : () -> f16
-/// %1 = vector.broadcast %0 {layout_result_0 =
-/// #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
-/// : f16 to vector<16x16xf16>
-/// ```
-/// is distributed to:
-/// ```
-/// %0 = "some_op"() : f16
-/// %1 = vector.broadcast %0 : f16 to vector<16x1xf16>
-/// ```
-struct SgToWiBroadcast : public OpConversionPattern<vector::BroadcastOp> {
- using OpConversionPattern<vector::BroadcastOp>::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(vector::BroadcastOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- xegpu::DistributeLayoutAttr resultLayout =
- xegpu::getTemporaryLayout(cast<OpResult>(op.getResult()));
- if (!resultLayout || !resultLayout.isForSubgroup())
- return rewriter.notifyMatchFailure(
- op, "result does not have subgroup distribute layout");
-
- VectorType destType = op.getResultVectorType();
- VectorType sourceType = dyn_cast<VectorType>(op.getSourceType());
-
- xegpu::DistributeLayoutAttr sourceLayout =
- xegpu::getTemporaryLayout(op->getOpOperand(0));
-
- if (sourceType) {
- int64_t rankDiff = destType.getRank() - sourceType.getRank();
- if (rankDiff > 0) {
- // Case 1: Low-rank to high-rank broadcast.
- if (!sourceLayout || !sourceLayout.isSliceOf(resultLayout))
- op.emitWarning(
- "broadcast source layout must be a slice of result layout");
- } else if (rankDiff == 0) {
- // Case 2: Same-rank broadcast.
- auto broadcastUnitDimsSet = op.computeBroadcastedUnitDims();
- SmallVector<int64_t> broadcastUnitDims(broadcastUnitDimsSet.begin(),
- broadcastUnitDimsSet.end());
- assert(sourceLayout.isEqualTo(
- sourceLayout.setUnitDimData(broadcastUnitDims)) &&
- "The sg_data for unit dimensions should be set as 1");
- sourceLayout = sourceLayout.setUnitDimLayout(broadcastUnitDims);
- }
- } else {
- // Case 3: Scalar to vector broadcast.
- if (sourceLayout)
- return rewriter.notifyMatchFailure(
- op, "broadcast from scalar must not have a layout attribute");
- }
-
- auto destDistType =
- xegpu::getDistVecTypeBasedOnLaneLayout(resultLayout, destType);
- if (failed(destDistType))
- return rewriter.notifyMatchFailure(
- op, "failed to distribute the result vector type");
-
- Value source = adaptor.getSource();
- // If the adapted source already matches the dest dist type, it's a no-op.
- if (source.getType() == destDistType.value()) {
- rewriter.replaceOp(op, source);
- return success();
- }
-
- auto newOp = vector::BroadcastOp::create(rewriter, op.getLoc(),
- destDistType.value(), source);
- rewriter.replaceOp(op, newOp);
- return success();
- }
-};
-
-/// Distributes a subgroup-level vector.insert_strided_slice op to
-/// workitem-level. If the dest is distributed, the offsets are adjusted to
-/// match the distributed types.
-struct SgToWiVectorInsertStridedSlice
- : public OpConversionPattern<vector::InsertStridedSliceOp> {
- using OpConversionPattern<vector::InsertStridedSliceOp>::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(vector::InsertStridedSliceOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- xegpu::DistributeLayoutAttr resultLayout =
- xegpu::getTemporaryLayout(op->getOpResult(0));
- if (!resultLayout || !resultLayout.isForSubgroup())
- return failure();
-
- VectorType destType = op.getDestVectorType();
- auto distDestTyOrFailure =
- xegpu::getDistVecTypeBasedOnLaneLayout(resultLayout, destType);
- if (failed(distDestTyOrFailure))
- return rewriter.notifyMatchFailure(
- op, "unable to compute distributed vector type from lane layout");
- VectorType distDestTy = *distDestTyOrFailure;
-
- SmallVector<int64_t> destDistributedDims =
- getDistributedDims(destType, distDestTy);
-
- SmallVector<Attribute> updatedOffsets = llvm::map_to_vector(
- op.getOffsets(), [](Attribute attr) { return attr; });
-
- if (!destDistributedDims.empty()) {
- if (destDistributedDims.size() != 1)
- return rewriter.notifyMatchFailure(
- op, "only single dimension distribution is supported");
- int64_t destDistDim = destDistributedDims[0];
-
- const uArch *uArch = getUArch(xegpu::getChipStr(op).value_or(""));
- if (!uArch)
- return rewriter.notifyMatchFailure(
- op, "target attribute required to determine subgroup size");
- int subgroupSize = uArch->getSubgroupSize();
-
- VectorType srcType = op.getSourceVectorType();
- // The distributed dim must be in the last k (source rank) dims of dest.
- int64_t sourceDistDim =
- destDistDim - (destType.getRank() - srcType.getRank());
- if (sourceDistDim < 0)
- return rewriter.notifyMatchFailure(
- op, "distributed dimension must be in the last k dims of dest");
-
- auto destLayout = xegpu::getTemporaryLayout(op->getOpOperand(1));
- auto sourceLayout = xegpu::getTemporaryLayout(op->getOpOperand(0));
- if (!destLayout || !sourceLayout ||
- destLayout.getEffectiveLaneLayoutAsInt().empty() ||
- sourceLayout.getEffectiveLaneLayoutAsInt().empty())
- return rewriter.notifyMatchFailure(
- op, "source or dest of insert_strided_slice lacks distribution "
- "layout");
-
- auto destLaneData = destLayout.getEffectiveLaneDataAsInt();
- auto sourceLaneData = sourceLayout.getEffectiveLaneDataAsInt();
- // Only check lane_data for the distributed dimension. Non-distributed
- // dimensions may have non-unit lane_data (e.g., packed layouts).
- if ((destDistDim < static_cast<int64_t>(destLaneData.size()) &&
- destLaneData[destDistDim] != 1) ||
- (sourceDistDim < static_cast<int64_t>(sourceLaneData.size()) &&
- sourceLaneData[sourceDistDim] != 1))
- return rewriter.notifyMatchFailure(
- op, "expecting unit lane data along the distributed dimension");
-
- int64_t srcDistrDimSize = srcType.getDimSize(sourceDistDim);
- if (srcDistrDimSize % subgroupSize != 0)
- return rewriter.notifyMatchFailure(
- op, "source distributed dim size is not a multiple of "
- "subgroup size");
-
- int64_t destDistrDimOffset =
- cast<IntegerAttr>(op.getOffsets()[destDistDim]).getInt();
- if (destDistrDimOffset % subgroupSize != 0)
- return rewriter.notifyMatchFailure(
- op, "offset along distributed dim is not a multiple of "
- "subgroup size");
- // Adjust offset for the distributed dimension.
- updatedOffsets[destDistDim] =
- rewriter.getI64IntegerAttr(destDistrDimOffset / subgroupSize);
- }
-
- auto newOp = vector::InsertStridedSliceOp::create(
- rewriter, op.getLoc(), distDestTy, adaptor.getValueToStore(),
- adaptor.getDest(),
- ArrayAttr::get(rewriter.getContext(), updatedOffsets), op.getStrides());
- rewriter.replaceOp(op, newOp.getResult());
- return success();
- }
-};
-
-/// Distributes a subgroup-level vector.insert op to workitem-level. Only
-/// handles sub-vector insertion (value to store is VectorType, not scalar).
-struct SgToWiVectorInsert : public OpConversionPattern<vector::InsertOp> {
- using OpConversionPattern<vector::InsertOp>::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(vector::InsertOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- // Only handle vector value-to-store (not scalar insertion).
- auto valueType = dyn_cast<VectorType>(op.getValueToStoreType());
- if (!valueType)
- return rewriter.notifyMatchFailure(op, "scalar insert not supported");
-
- xegpu::DistributeLayoutAttr layout =
- xegpu::getTemporaryLayout(op->getOpResult(0));
- if (!layout || !layout.isForSubgroup())
- return failure();
-
- // verify that the outer k dimensions (for offsets)
- // don't have non-unit lane_layout.
- auto laneLayout = layout.getEffectiveLaneLayoutAsInt();
- if (llvm::any_of(ArrayRef<int64_t>(laneLayout).drop_back(1),
- [](int64_t v) { return v != 1; }))
- return rewriter.notifyMatchFailure(
- op, "only innermost dimension distribution is supported for "
- "vector.insert");
-
- auto newOp = vector::InsertOp::create(
- rewriter, op.getLoc(), adaptor.getValueToStore(), adaptor.getDest(),
- op.getMixedPosition());
- rewriter.replaceOp(op, newOp.getResult());
- return success();
- }
-};
-
-/// Folds a subgroup-level ConvertLayout op with compatible lane layouts.
-struct SgToWiConvertLayout
- : public OpConversionPattern<xegpu::ConvertLayoutOp> {
- using OpConversionPattern<xegpu::ConvertLayoutOp>::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(xegpu::ConvertLayoutOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- auto inputLayout = op.getInputLayoutAttr();
- auto targetLayout = op.getTargetLayoutAttr();
- Type valType = op.getResult().getType();
-
- if (valType.isIntOrFloat()) {
- rewriter.replaceOp(op, op.getSource());
- return success();
- }
-
- auto resShape = cast<VectorType>(valType).getShape();
- SmallVector<int64_t> resShapeVec(resShape.begin(), resShape.end());
- if (!inputLayout.isCompatibleWith(targetLayout, resShapeVec,
- xegpu::LayoutKind::Lane)) {
- return rewriter.notifyMatchFailure(
- op, "lowering incompatible convert_layout not yet supported");
- }
-
- rewriter.replaceOp(op, adaptor.getSource());
- return success();
- }
-};
-
-struct XeGPUSgToWiDistributeExperimentalPass
- : public xegpu::impl::XeGPUSgToWiDistributeExperimentalBase<
- XeGPUSgToWiDistributeExperimentalPass> {
- void runOnOperation() override;
-};
-
-} // namespace
-
-void XeGPUSgToWiDistributeExperimentalPass::runOnOperation() {
-
- // Recover temporary operand layouts for usage in patterns.
- Operation *root = getOperation();
- if (!xegpu::recoverTemporaryLayouts(root)) {
- signalPassFailure();
- return;
- }
-
- // Collect existing UnrealizedConversionCastOps. These must be preserved.
- llvm::SmallSetVector<UnrealizedConversionCastOp, 8> existingCasts;
- root->walk(
- [&](UnrealizedConversionCastOp castOp) { existingCasts.insert(castOp); });
- // Perform a structural type conversion to convert structural ops to have WI
- // types. This will insert UnrealizedConversionCastOps to make the IR
- // valid.
- auto materializeCast = [&](mlir::OpBuilder &builder, mlir::Type type,
- mlir::ValueRange inputs,
- mlir::Location loc) -> mlir::Value {
- UnrealizedConversionCastOp castOp =
- UnrealizedConversionCastOp::create(builder, loc, type, inputs);
- return castOp.getResult(0);
- };
- {
- ConversionTarget target(getContext());
- TypeConverter typeConverter;
- RewritePatternSet patterns(&getContext());
- typeConverter.addSourceMaterialization(materializeCast);
- typeConverter.addTargetMaterialization(materializeCast);
- xegpu::populateXeGPUSgToWiDistributeTypeConversions(typeConverter);
- scf::populateSCFStructuralTypeConversionsAndLegality(typeConverter,
- patterns, target);
- xegpu::populateXeGPUSgToWiDistributeTypeConversionAndLegality(
- typeConverter, patterns, target);
- target.addLegalOp<UnrealizedConversionCastOp>();
- (void)applyPartialConversion(root, target, std::move(patterns));
- }
- // Structural type conversion can generate some redundant
- // UnrealizedConversionCastOps to materialize the SG type from type converted
- // WI type. These are redundant at this point and can be eliminated by
- // inserting shape casts instead.
- // Example:
- // %1 = UnrealizedConversionCastOp %0 : vector<16x1xf32> to vector<16x16xf32>
- // %2 = UnrealizedConversionCastOp %1 : vector<16x16xf32> to vector<16xf32>
- // This can be replaced with:
- // %2 = vector.shape_cast %0 : vector<16x1xf32> to vector<16xf32>
- OpBuilder builder(root);
- root->walk([&](UnrealizedConversionCastOp op) {
- // If this op existed before, nothing to do.
- if (existingCasts.contains(op))
- return;
- // number of inputs and outputs must be 1.
- if (op.getNumOperands() != 1 || op.getNumResults() != 1)
- return;
- // Both input and output types must be vector types.
- auto singleInput = op.getInputs()[0];
- auto inputTy = dyn_cast<VectorType>(singleInput.getType());
- auto outputTy = dyn_cast<VectorType>(op.getResult(0).getType());
- if (!inputTy || !outputTy)
- return;
-
- // Check if the defining op of the input is also an
- // UnrealizedConversionCastOp and it has a single user (which is this
- // op).
- auto definingOp = singleInput.getDefiningOp<UnrealizedConversionCastOp>();
- if (!definingOp || !definingOp->hasOneUse())
- return;
- auto inputOfDefiningOp = definingOp.getInputs()[0];
- // If the input of the defining op and output type are both vector types
- // have same number of elements, insert a shape cast.
- auto inputOfDefiningOpTy =
- dyn_cast<VectorType>(inputOfDefiningOp.getType());
- if (inputOfDefiningOpTy &&
- inputOfDefiningOpTy.getNumElements() == outputTy.getNumElements()) {
- builder.setInsertionPoint(op);
- auto shapeCast = vector::ShapeCastOp::create(builder, op.getLoc(),
- outputTy, inputOfDefiningOp);
- op.replaceAllUsesWith(ValueRange{shapeCast.getResult()});
- return;
- }
- });
- // At this point, we will have some dead UnrealizedConversionCastOps. Just
- // erase them.
- bool changed = true;
- while (changed) {
- changed = false;
- root->walk([&](UnrealizedConversionCastOp op) {
- // Skip existing casts.
- if (existingCasts.contains(op))
- return;
- if (op.use_empty()) {
- op.erase();
- changed = true;
- }
- });
- }
-}
-
-void xegpu::populateXeGPUSgToWiDistributeTypeConversions(
- TypeConverter &typeConverter) {
- // Any type other than TensorDescType and VectorType are legal as is.
- typeConverter.addConversion([](Type type) -> std::optional<Type> {
- if (!isa<TensorDescType, VectorType>(type))
- return type;
- return std::nullopt;
- });
- // For TensorDescType, drop the layout attribute if any.
- typeConverter.addConversion([](TensorDescType type) -> Type {
- if (type.getLayoutAttr()) {
- return type.dropLayouts();
- }
- return type;
- });
- // For VectorType, check if there is a distribute layout attribute on the
- // value. If so, convert to the distributed vector type based on the layout.
- typeConverter.addConversion([](Value v) -> std::optional<Type> {
- auto type = v.getType();
- // If value is not vector type, nothing to do.
- if (!isa<VectorType>(type))
- return std::nullopt;
- auto layout = xegpu::getDistributeLayoutAttr(v);
- if (!layout || !layout.isForSubgroup())
- return type;
- // Vector type is distributed based on lane layout.
- auto newTyOrFailure =
- getDistVecTypeBasedOnLaneLayout(layout, cast<VectorType>(type));
- if (failed(newTyOrFailure))
- return type;
- return *newTyOrFailure;
- });
-}
-
-void xegpu::populateXeGPUSgToWiDistributeTypeConversionAndLegality(
- TypeConverter &typeConverter, RewritePatternSet &patterns,
- ConversionTarget &target) {
- populateXeGPUSgToWiDistributeTypeConversions(typeConverter);
- // CreateNdDescOp is legal only if its result type has no layout attribute.
- target.addDynamicallyLegalOp<xegpu::CreateNdDescOp>(
- [&](xegpu::CreateNdDescOp op) { return !op.getType().getLayoutAttr(); });
- // Any anchor XeGPU op is legal only if it has no anchor layout.
- target.addDynamicallyLegalDialect<xegpu::XeGPUDialect>([](Operation *op) {
- auto anchorOp = dyn_cast<AnchorLayoutInterface>(op);
- if (!anchorOp)
- return true;
- return !anchorOp.getAnchorLayout();
- });
- // Arith constants are legal only if they have no temporary layout attribute.
- target.addDynamicallyLegalOp<arith::ConstantOp>(
- [=](arith::ConstantOp op) -> bool {
- // If the result type is not a vector, it's legal.
- if (!isa<VectorType>(op.getResult().getType()))
- return true;
- return !xegpu::getTemporaryLayout(dyn_cast<OpResult>(op.getResult()));
- });
- // In math and arith dialects, only handle elementwise ops with a single
- // result and with a result layout attribute.
- target.addDynamicallyLegalDialect<math::MathDialect, arith::ArithDialect>(
- [=](Operation *op) -> std::optional<bool> {
- // Only handle elementwise mappable ops
- if (!OpTrait::hasElementwiseMappableTraits(op))
- return true;
- // Only handle ops with single vector result
- if (op->getNumResults() != 1)
- return true;
-
- VectorType resultType =
- dyn_cast<VectorType>(op->getResult(0).getType());
- if (!resultType)
- return true;
-
- // Check if all operands are vectors of the same shape
- for (Value operand : op->getOperands()) {
- VectorType operandType = dyn_cast<VectorType>(operand.getType());
- if (!operandType || operandType.getShape() != resultType.getShape()) {
- return true;
- }
- }
- return !xegpu::getTemporaryLayout(dyn_cast<OpResult>(op->getResult(0)));
- });
- // vector::ReductionOp is legal only if its source has no distribute layout
- // attribute.
- target.addDynamicallyLegalOp<vector::ReductionOp>(
- [=](vector::ReductionOp op) -> bool {
- auto layout = xegpu::getDistributeLayoutAttr(op.getVector());
- return !layout;
- });
- // vector::MultiDimReductionOp op legality.
- target.addDynamicallyLegalOp<vector::MultiDimReductionOp>(
- [=](vector::MultiDimReductionOp op) -> bool {
- return !isValidSubgroupMultiReductionOp(op);
- });
- target.addDynamicallyLegalOp<vector::CreateMaskOp, vector::ConstantMaskOp,
- vector::TransposeOp, vector::BitCastOp,
- vector::ShapeCastOp, vector::StepOp,
- vector::BroadcastOp>([=](Operation *op) -> bool {
- return !xegpu::getTemporaryLayout(op->getOpResult(0));
- });
- target.addDynamicallyLegalOp<vector::ExtractOp>(
- [=](vector::ExtractOp op) -> bool {
- if (!isa<VectorType>(op.getType()))
- return true;
- return !xegpu::getTemporaryLayout(op->getOpResult(0));
- });
- target.addDynamicallyLegalOp<vector::InsertOp>(
- [=](vector::InsertOp op) -> bool {
- return !xegpu::getTemporaryLayout(op->getOpResult(0));
- });
- target.addDynamicallyLegalOp<vector::ExtractStridedSliceOp>(
- [=](vector::ExtractStridedSliceOp op) -> bool {
- return !xegpu::getTemporaryLayout(op->getOpResult(0));
- });
- target.addDynamicallyLegalOp<vector::InsertStridedSliceOp>(
- [=](vector::InsertStridedSliceOp op) -> bool {
- return !xegpu::getTemporaryLayout(op->getOpResult(0));
- });
- target.markUnknownOpDynamicallyLegal([](Operation *op) { return true; });
- patterns.add<SgToWiCreateNdDesc, SgToWiLoadNd, SgToWiStoreNd, SgToWiDpas,
- SgToWiElementWise, SgToWiArithConstant, SgToWiPrefetchNd,
- SgToWiLoadGather, SgToWiStoreScatter, SgToWiVectorReduction,
- SgToWiMultiDimReduction, SgToWiVectorExtract, SgToWiVectorInsert,
- SgToWiVectorExtractStridedSlice, SgToWiVectorInsertStridedSlice,
- SgToWiLoadMatrix, SgToWiStoreMatrix, SgToWiConvertLayout,
- SgToWiVectorTranspose, SgToWiVectorBitcast, SgToWiVectorStep,
- SgToWiVectorShapeCast, SgToWiBroadcast,
- SgToWiCreateMask<vector::CreateMaskOp>,
- SgToWiCreateMask<vector::ConstantMaskOp>>(typeConverter,
- patterns.getContext());
-}
>From 6845ef13d5ccfe86004bbe8c00f5e5b9446e9526 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Mon, 13 Apr 2026 23:16:51 +0000
Subject: [PATCH 17/19] remove isEvenlyDistributable definition and its uses
---
mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 76 -------------------
mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 11 ---
.../Transforms/XeGPUWgToSgDistribute.cpp | 3 -
3 files changed, 90 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index eaa43c02946d8..cba29b1a926d0 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -121,74 +121,6 @@ static SmallVector<SmallVector<int64_t>> genStaticCoordinates(
return coordinates;
}
-// Checks if the given shape can be evenly distributed based on the layout
-// and data factors provided by the LayoutAttr.
-bool XeGPUDialect::isEvenlyDistributable(llvm::ArrayRef<int64_t> shape,
- xegpu::DistributeLayoutAttr attr) {
- assert(attr && "Layout attribute is missing.");
-
- // Checks whether the given shape can be evenly distributed using the
- // specified layout and data attributes. If successful, it returns the work
- // size for each compute unit; otherwise, it returns `std::nullopt`. The work
- // size per compute unit is calculated as follows:
- // - If `data` is null: newShape[i] = shape[i] / layout[i]
- // - If `data` is not null: newShape[i] = data[i]
- // When round-robin distribution (`rr`) is enabled, `shape[i]` can be
- // smaller than `layout[i] * data[i]`, allowing multiple compute units to
- // share the data.
- auto tryDistribute = [&](llvm::ArrayRef<int64_t> shape,
- SmallVector<int64_t> layout,
- SmallVector<int64_t> data,
- bool rr = true) -> optional<SmallVector<int64_t>> {
- llvm::SmallVector<int64_t> newShape(shape);
- if (layout.size()) {
- if (layout.size() != shape.size())
- return std::nullopt;
- auto ratio = computeShapeRatio(shape, layout);
- if (ratio.has_value()) {
- newShape = ratio.value();
- } else if (!rr || !computeShapeRatio(layout, shape).has_value()) {
- return std::nullopt;
- }
- // Round-robin case: continue with original newShape
- }
-
- if (data.size()) {
- if (data.size() != shape.size())
- return std::nullopt;
- auto ratio = computeShapeRatio(newShape, data);
- if (!ratio.has_value() && rr)
- ratio = computeShapeRatio(data, newShape);
- if (!ratio.has_value())
- return std::nullopt;
-
- // if data is not null, we always return it for next phase.
- newShape = data;
- }
- return newShape;
- };
-
- // check the sgLayout and sgData
- auto maybeSgShape = tryDistribute(shape, attr.getEffectiveSgLayoutAsInt(),
- attr.getEffectiveSgDataAsInt());
- if (!maybeSgShape)
- return false;
- auto sgShape = maybeSgShape.value();
-
- // check InstData, it neither have layout nor need round-robin
- auto maybeInstShape =
- tryDistribute(sgShape, {}, attr.getEffectiveInstDataAsInt(), false);
- if (!maybeInstShape)
- return false;
- auto instShape = maybeInstShape.value();
-
- // check LaneLayout and LaneData
- auto maybeLaneShape =
- tryDistribute(instShape, attr.getEffectiveLaneLayoutAsInt(),
- attr.getEffectiveLaneDataAsInt());
- return maybeLaneShape.has_value();
-}
-
//===----------------------------------------------------------------------===//
// XeGPU_BlockTensorDescAttr
//===----------------------------------------------------------------------===//
@@ -1448,14 +1380,6 @@ TensorDescType::verify(llvm::function_ref<InFlightDiagnostic()> emitError,
<< "expected last dim of lane_data to be a multiple of: "
<< chunkAlignmentFactor;
}
-
- if (!XeGPUDialect::isEvenlyDistributable(shape, layoutAttr)) {
- std::string shapeStr;
- llvm::raw_string_ostream stream(shapeStr);
- llvm::interleaveComma(shape, stream);
- return emitError() << "cannot distribute [" << shapeStr << "] using "
- << layoutAttr;
- }
}
return success();
}
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index d6981052a7a5c..d3fdbb81a1cbd 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -1101,17 +1101,6 @@ LogicalResult ConvertLayoutOp::verify() {
return emitOpError("expected input layout and target layout be WgLayout or "
"SgLayout at the same time.");
- // Type srcType = getSource().getType();
- // if (llvm::isa<VectorType>(srcType)) {
- // auto shape = llvm::cast<VectorType>(srcType).getShape();
- // if (!XeGPUDialect::isEvenlyDistributable(shape, srcLayout))
- // return emitOpError(
- // "invalid input layout, data cannot be evenly distributed.");
-
- // if (!XeGPUDialect::isEvenlyDistributable(shape, resLayout))
- // return emitOpError(
- // "invalid target layout, data cannot be evenly distributed.");
- // }
return mlir::success();
}
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index dabdcb61f0500..57188fae5e8c8 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -498,9 +498,6 @@ struct WgToSgVectorBroadcastOp
VectorType newResultType =
VectorType::get(sgShape, resultType.getElementType());
- if (!xegpu::XeGPUDialect::isEvenlyDistributable(wgShape, layout))
- return failure();
-
SmallVector<Value> newBroadcastOps;
auto distSource = adaptor.getOperands().front();
int numDistributions = count / distSource.size();
>From 0c7917e246f5f78f38b1872ad97e1cab4fea3e47 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Mon, 13 Apr 2026 23:17:41 +0000
Subject: [PATCH 18/19] remove isEvenlyDistributable declaration from XeGPUDialect.td
---
mlir/include/mlir/Dialect/XeGPU/IR/XeGPUDialect.td | 4 ----
1 file changed, 4 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUDialect.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUDialect.td
index c173b93face98..84fd8f9e0060c 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUDialect.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUDialect.td
@@ -38,10 +38,6 @@ def XeGPU_Dialect : Dialect {
let useDefaultAttributePrinterParser = true;
let extraClassDeclaration = [{
- /// Checks if the given shape can be evenly distributed based on the layout
- /// and data factors provided by the LayoutAttr.
- static bool isEvenlyDistributable(llvm::ArrayRef<int64_t> shape, xegpu::DistributeLayoutAttr attr);
-
/// drops/slices the shape in the specified dims, and return the rest. e.g.,
/// for shape = [32, 64, 8], dims = [0, 2], it will return [64]
template<typename T, typename U>
>From b688a2f15c140d2538cd8c40c97ae1a875500bbb Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Thu, 16 Apr 2026 18:23:40 +0000
Subject: [PATCH 19/19] refactor: move temporary layout cleanup into removeTemporaryLayoutAttrs utility
---
.../mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h | 5 +++++
.../Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp | 12 ++++++++++++
.../XeGPU/Transforms/XeGPUPeepHoleOptimizer.cpp | 11 +----------
.../XeGPU/Transforms/XeGPUPropagateLayout.cpp | 13 +++----------
.../XeGPUSgToWiDistributeExperimental.cpp | 11 ++---------
.../XeGPU/Transforms/XeGPUSubgroupDistribute.cpp | 11 ++---------
.../XeGPU/Transforms/XeGPUWgToSgDistribute.cpp | 10 +---------
7 files changed, 26 insertions(+), 47 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h b/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h
index a7ca51bc4fdf7..83eb939cf1bec 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h
@@ -61,6 +61,11 @@ void removeLayoutAttr(const T &operandOrResult);
/// applied recursively to the contained operations
void removeLayoutAttrs(Operation *op);
+/// Removes the temporary layouts, i.e. every discardable attribute holding a
+/// DistributeLayoutAttr, from the given operation and, recursively, from all
+/// operations nested in its regions.
+void removeTemporaryLayoutAttrs(Operation *op);
+
/// Updates the NamedAttribute sequence by dropping sg-layout and
/// sg-data information from any DistributeLayoutAttr found.
SmallVector<NamedAttribute>
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
index 8f319bd161798..0dd9348bc06c2 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
@@ -306,6 +306,18 @@ void xegpu::removeLayoutAttrs(Operation *op) {
});
}
+void xegpu::removeTemporaryLayoutAttrs(Operation *op) {
+ op->walk([&](Operation *nestOp) {
+ SmallVector<StringAttr> attrsToRemove;
+ for (auto namedAttr : nestOp->getDiscardableAttrs()) {
+ if (isa<xegpu::DistributeLayoutAttr>(namedAttr.getValue()))
+ attrsToRemove.push_back(namedAttr.getName());
+ }
+ for (auto attrName : attrsToRemove)
+ nestOp->removeDiscardableAttr(attrName);
+ });
+}
+
/// Infers the source layout attribute for a broadcast operation given the
/// result layout attribute, result shape, source shape.
xegpu::DistributeLayoutAttr
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPeepHoleOptimizer.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPeepHoleOptimizer.cpp
index 3496756e8a6d3..8ade936724480 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPeepHoleOptimizer.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPeepHoleOptimizer.cpp
@@ -598,16 +598,7 @@ struct XeGPUPeepHoleOptimizerPass final
RewritePatternSet emptyPatterns(ctx);
(void)applyPatternsGreedily(getOperation(), std::move(emptyPatterns));
- // Remove the temporary layout after all patterns are applied.
- getOperation()->walk([](Operation *op) {
- SmallVector<StringAttr> attrsToRemove;
- for (auto namedAttr : op->getDiscardableAttrs()) {
- if (isa<xegpu::DistributeLayoutAttr>(namedAttr.getValue()))
- attrsToRemove.push_back(namedAttr.getName());
- }
- for (auto attrName : attrsToRemove)
- op->removeDiscardableAttr(attrName);
- });
+ xegpu::removeTemporaryLayoutAttrs(getOperation());
}
};
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index ff9ff4937c293..630a314b1ce40 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -1694,16 +1694,9 @@ LogicalResult xegpu::resolveLayoutConflicts(Operation *target) {
}
void XeGPUPropagateLayoutPass::runOnOperation() {
- // Clean up temporary layout attributes
- getOperation()->walk([](Operation *op) {
- SmallVector<StringAttr> attrsToRemove;
- for (auto namedAttr : op->getDiscardableAttrs()) {
- if (isa<xegpu::DistributeLayoutAttr>(namedAttr.getValue()))
- attrsToRemove.push_back(namedAttr.getName());
- }
- for (auto attrName : attrsToRemove)
- op->removeDiscardableAttr(attrName);
- });
+
+ xegpu::removeTemporaryLayoutAttrs(getOperation());
+
xegpu::LayoutKind layoutKind;
if (this->layoutKind == "lane") {
layoutKind = xegpu::LayoutKind::Lane;
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
index 9cb9d34401216..c153db431c035 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
@@ -1609,15 +1609,8 @@ void XeGPUSgToWiDistributeExperimentalPass::runOnOperation() {
}
});
}
- getOperation()->walk([](Operation *op) {
- SmallVector<StringAttr> attrsToRemove;
- for (auto namedAttr : op->getDiscardableAttrs()) {
- if (isa<xegpu::DistributeLayoutAttr>(namedAttr.getValue()))
- attrsToRemove.push_back(namedAttr.getName());
- }
- for (auto attrName : attrsToRemove)
- op->removeDiscardableAttr(attrName);
- });
+
+ xegpu::removeTemporaryLayoutAttrs(getOperation());
}
void xegpu::populateXeGPUSgToWiDistributeTypeConversions(
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index c6c48515fcf0c..1dab7d9808756 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -2280,13 +2280,6 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
op->erase();
return WalkResult::advance();
});
- getOperation()->walk([](Operation *op) {
- SmallVector<StringAttr> attrsToRemove;
- for (auto namedAttr : op->getDiscardableAttrs()) {
- if (isa<xegpu::DistributeLayoutAttr>(namedAttr.getValue()))
- attrsToRemove.push_back(namedAttr.getName());
- }
- for (auto attrName : attrsToRemove)
- op->removeDiscardableAttr(attrName);
- });
+
+ xegpu::removeTemporaryLayoutAttrs(getOperation());
}
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 57188fae5e8c8..3fadc593849bf 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -1781,13 +1781,5 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
applyPartialConversion(getOperation(), target, std::move(patterns))))
return signalPassFailure();
- getOperation()->walk([](Operation *op) {
- SmallVector<StringAttr> attrsToRemove;
- for (auto namedAttr : op->getDiscardableAttrs()) {
- if (isa<xegpu::DistributeLayoutAttr>(namedAttr.getValue()))
- attrsToRemove.push_back(namedAttr.getName());
- }
- for (auto attrName : attrsToRemove)
- op->removeDiscardableAttr(attrName);
- });
+ xegpu::removeTemporaryLayoutAttrs(getOperation());
}
More information about the Mlir-commits
mailing list