[Mlir-commits] [mlir] cd4c9d2 - [mlir][xegpu] Add initial support for layout conflict handling. (#173090)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Jan 28 11:50:47 PST 2026
Author: Charitha Saumya
Date: 2026-01-28T11:50:42-08:00
New Revision: cd4c9d200beeafaf6710e3226c11845794d5ffa9
URL: https://github.com/llvm/llvm-project/commit/cd4c9d200beeafaf6710e3226c11845794d5ffa9
DIFF: https://github.com/llvm/llvm-project/commit/cd4c9d200beeafaf6710e3226c11845794d5ffa9.diff
LOG: [mlir][xegpu] Add initial support for layout conflict handling. (#173090)
This PR adds initial support for layout conflict resolution in XeGPU.
Layout conflict occurs when some op's use point expects a different
layout than what the op can currently provide. This conflict needs to be
resolved by adding certain other xegpu ops.
Initially, we only focus on conflict handling at tensor descriptor use points.
Added:
mlir/test/Dialect/XeGPU/resolve-layout-conflicts.mlir
Modified:
mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir
mlir/test/Dialect/XeGPU/propagate-layout.mlir
mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h b/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
index 5942f69b4a66d..9628d3064eabf 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
@@ -9,6 +9,7 @@
#ifndef MLIR_DIALECT_XEGPU_TRANSFORMS_TRANSFORMS_H
#define MLIR_DIALECT_XEGPU_TRANSFORMS_TRANSFORMS_H
+#include "mlir/IR/Builders.h"
#include "mlir/IR/Operation.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/LogicalResult.h"
@@ -91,6 +92,12 @@ void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns);
void populateXeGPUUnrollPatterns(RewritePatternSet &patterns,
const UnrollOptions &options);
+/// Granularity of the layouts computed by `propagateLayouts`: lane-level
+/// (SIMT `lane_layout`/`lane_data`) layouts, `inst_data` layouts, or subgroup
+/// (`sg_layout`/`sg_data`) layouts.
+enum class LayoutKind { Lane, InstData, Subgroup };
+/// Run layout propagation of kind `layoutKind` over the operations nested in
+/// `target` and record the results in the IR. When `printOnly` is set, the
+/// analysis result is printed to stdout and the IR is left unmodified.
+/// Returns failure if updating the IR fails.
+LogicalResult propagateLayouts(OpBuilder &builder, Operation *target,
+ LayoutKind layoutKind, bool printOnly = false);
+
+/// Resolve layout conflicts in the operations nested in `target`, i.e. use
+/// points where a consumer expects a different layout than its operand
+/// provides. Currently only tensor descriptor operands are handled, by
+/// duplicating the defining `xegpu.create_nd_tdesc` with the expected layout.
+/// Returns failure if a consumer cannot be handled (e.g. the defining
+/// `xegpu.create_nd_tdesc` cannot be found).
+LogicalResult resolveLayoutConflicts(Operation *target);
+
} // namespace xegpu
} // namespace mlir
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 9a88310ccd3c9..96fdced39d9ab 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -15,7 +15,9 @@
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
+#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
+#include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
@@ -25,6 +27,7 @@
#include "mlir/IR/Visitors.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
+#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
@@ -36,8 +39,6 @@
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/raw_ostream.h"
-#include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h"
-
namespace mlir {
namespace xegpu {
#define GEN_PASS_DEF_XEGPUPROPAGATELAYOUT
@@ -53,8 +54,6 @@ using namespace mlir::dataflow;
namespace {
-enum class LayoutKind { Lane, InstData, Subgroup };
-
//===----------------------------------------------------------------------===//
// LayoutInfo
//===----------------------------------------------------------------------===//
@@ -368,7 +367,7 @@ getSIMTLayoutInfoForDPASOperand(VectorType vectorTy, unsigned operandNum,
class LayoutInfoPropagation
: public SparseBackwardDataFlowAnalysis<LayoutInfoLattice> {
private:
- LayoutKind layoutKind;
+ xegpu::LayoutKind layoutKind;
void visitDpasOp(xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results);
@@ -428,7 +427,7 @@ class LayoutInfoPropagation
public:
LayoutInfoPropagation(DataFlowSolver &solver,
SymbolTableCollection &symbolTable,
- LayoutKind layoutKind)
+ xegpu::LayoutKind layoutKind)
: SparseBackwardDataFlowAnalysis(solver, symbolTable),
layoutKind(layoutKind) {}
using SparseBackwardDataFlowAnalysis::SparseBackwardDataFlowAnalysis;
@@ -527,12 +526,12 @@ bool LayoutInfoPropagation::hasParamsOfLayoutKind(
if (anchorLayout == nullptr) {
return false;
}
- if (layoutKind == LayoutKind::InstData) {
+ if (layoutKind == xegpu::LayoutKind::InstData) {
return !(anchorLayout.getEffectiveInstDataAsInt().empty());
- } else if (layoutKind == LayoutKind::Lane) {
+ } else if (layoutKind == xegpu::LayoutKind::Lane) {
return !(anchorLayout.getEffectiveLaneLayoutAsInt().empty() ||
anchorLayout.getEffectiveLaneDataAsInt().empty());
- } else if (layoutKind == LayoutKind::Subgroup) {
+ } else if (layoutKind == xegpu::LayoutKind::Subgroup) {
return !(anchorLayout.getEffectiveSgLayoutAsInt().empty() ||
anchorLayout.getEffectiveSgDataAsInt().empty());
}
@@ -628,7 +627,7 @@ void LayoutInfoPropagation::visitPrefetchNdOp(
instData = {instHeight, instWidth};
}
- if (layoutKind == LayoutKind::InstData)
+ if (layoutKind == xegpu::LayoutKind::InstData)
prefetchLayout =
LayoutInfo(xegpu::LayoutAttr::get(tdescTy.getContext(), instData));
else
@@ -812,7 +811,7 @@ void LayoutInfoPropagation::visitDpasOp(
}
instDataCD = {maxALen, maxCLen};
}
- if (layoutKind == LayoutKind::InstData) {
+ if (layoutKind == xegpu::LayoutKind::InstData) {
dpasALayout =
LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataA));
dpasBLayout =
@@ -821,7 +820,7 @@ void LayoutInfoPropagation::visitDpasOp(
dpasCDLayout =
LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataCD));
}
- } else if (layoutKind == LayoutKind::Lane) {
+ } else if (layoutKind == xegpu::LayoutKind::Lane) {
dpasALayout = getSIMTLayoutInfoForDPASOperand(
aTy, 0, uArch, uArchInstruction->getPackedFormatBitSizeA());
dpasBLayout = getSIMTLayoutInfoForDPASOperand(
@@ -980,10 +979,10 @@ void LayoutInfoPropagation::visitStoreNdOp(
instData = {instHeight, instWidth};
}
- if (layoutKind == LayoutKind::InstData)
+ if (layoutKind == xegpu::LayoutKind::InstData)
storeLayout =
LayoutInfo(xegpu::LayoutAttr::get(dataTy.getContext(), instData));
- else if (layoutKind == LayoutKind::Lane)
+ else if (layoutKind == xegpu::LayoutKind::Lane)
storeLayout =
getSIMTLayoutInfoBlockIO(store.getValueType(), uArch,
uArchInstruction->getPackedFormatBitSize());
@@ -1173,7 +1172,7 @@ void LayoutInfoPropagation::visitLoadGatherOp(
uArch->getInstruction(xegpu::uArch::InstructionKind::LoadGather));
// Check if value inst_data complies with uArch
- if (layoutKind == LayoutKind::InstData) {
+ if (layoutKind == xegpu::LayoutKind::InstData) {
// Each lane loads either one element
SmallVector<int> instDataUarch{subgroupSize};
// Or multiple elements as 2D with lane's elements in the inner dimension
@@ -1215,10 +1214,10 @@ void LayoutInfoPropagation::visitLoadGatherOp(
// Rank >1 data: Enforce the default xegpu 1D layout for mask.
if (!hasParamsOfLayoutKind(anchorLayout) ||
load.getValueType().getRank() > 1) {
- if (layoutKind == LayoutKind::InstData)
+ if (layoutKind == xegpu::LayoutKind::InstData)
maskLayout = LayoutInfo(
xegpu::LayoutAttr::get(load->getContext(), {subgroupSize}));
- else if (layoutKind == LayoutKind::Lane)
+ else if (layoutKind == xegpu::LayoutKind::Lane)
maskLayout =
getDefaultSIMTLayoutInfo(load->getContext(), 1, subgroupSize);
}
@@ -1273,7 +1272,7 @@ void LayoutInfoPropagation::visitStoreScatterOp(
return;
}
- if (layoutKind == LayoutKind::InstData) {
+ if (layoutKind == xegpu::LayoutKind::InstData) {
const auto *uArchInstruction =
dyn_cast<xegpu::uArch::StoreScatterInstruction>(uArch->getInstruction(
xegpu::uArch::InstructionKind::StoreScatter));
@@ -1310,10 +1309,10 @@ void LayoutInfoPropagation::visitStoreScatterOp(
// Rank >1 data: Enforce the default xegpu 1D layout for mask.
if (!hasParamsOfLayoutKind(anchorLayout) ||
storeScatter.getValueType().getRank() > 1) {
- if (layoutKind == LayoutKind::InstData)
+ if (layoutKind == xegpu::LayoutKind::InstData)
maskLayout = LayoutInfo(
xegpu::LayoutAttr::get(storeScatter->getContext(), {subgroupSize}));
- else if (layoutKind == LayoutKind::Lane)
+ else if (layoutKind == xegpu::LayoutKind::Lane)
maskLayout =
getDefaultSIMTLayoutInfo(storeScatter->getContext(), 1, subgroupSize);
}
@@ -1347,7 +1346,7 @@ void LayoutInfoPropagation::visitStoreMatrixOp(
assert(payloadTy.getRank() == 2 && "Expecting 2D vector for store matrix.");
auto uArch = getUArch(getChipStr(storeMatrix).value_or(""));
SmallVector<int> instData = {1, uArch->getSubgroupSize()};
- if (layoutKind == LayoutKind::InstData)
+ if (layoutKind == xegpu::LayoutKind::InstData)
layout = LayoutInfo(
xegpu::LayoutAttr::get(storeMatrix.getContext(), instData));
else
@@ -1367,7 +1366,8 @@ class RunLayoutInfoPropagation {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(RunLayoutInfoPropagation)
- RunLayoutInfoPropagation(Operation *op, LayoutKind layoutKind) : target(op) {
+ RunLayoutInfoPropagation(Operation *op, xegpu::LayoutKind layoutKind)
+ : target(op) {
SymbolTableCollection symbolTable;
loadBaselineAnalyses(solver);
solver.load<LayoutInfoPropagation>(symbolTable, layoutKind);
@@ -1441,6 +1441,121 @@ void RunLayoutInfoPropagation::printAnalysisResult(llvm::raw_ostream &os) {
printFunctionResult(funcOp);
}
+namespace {
+
+//===----------------------------------------------------------------------===//
+// ResolveLayoutConflicts
+//===----------------------------------------------------------------------===//
+
+/// Walks all operations nested in `parentOp` and resolves layout conflicts at
+/// operand use points, i.e. places where a consumer expects a different layout
+/// than the one carried by the operand value.
+struct ResolveLayoutConflicts {
+  explicit ResolveLayoutConflicts(Operation *parentOp)
+      : parentOp(parentOp), builder(parentOp->getContext()) {}
+
+  /// Resolve the conflicts nested in `parentOp`. Returns failure if some
+  /// consumer could not be handled.
+  LogicalResult run();
+
+private:
+  /// Root operation whose nested operations are scanned.
+  Operation *parentOp;
+  /// Builder used to materialize the ops that resolve a conflict.
+  OpBuilder builder;
+
+  /// Handle a tensor descriptor operand of an anchor-layout op.
+  LogicalResult resolveTensorDescConsumer(OpOperand &operand);
+  /// Handle a vector operand (not implemented yet).
+  LogicalResult resolveVectorConsumer(OpOperand &operand);
+};
+
+} // namespace
+
+LogicalResult ResolveLayoutConflicts::run() {
+  // Scan all operations nested in the parent op and resolve layout conflicts
+  // at tensor descriptor and vector use points. Every operand of an op must be
+  // visited: an op may take a vector operand before its tensor descriptor
+  // operand (e.g. the stored value of `xegpu.store_nd`), so stopping after the
+  // first handled operand would leave the tensor descriptor conflict
+  // unresolved.
+  WalkResult result = parentOp->walk([&](Operation *op) -> WalkResult {
+    bool isAnchorOp = isa<xegpu::AnchorLayoutInterface>(op);
+    for (OpOperand &operand : op->getOpOperands()) {
+      Type operandType = operand.get().getType();
+      LogicalResult res = success();
+      if (isAnchorOp && isa<xegpu::TensorDescType>(operandType)) {
+        // Tensor descriptor operand of an anchor-layout op.
+        res = resolveTensorDescConsumer(operand);
+      } else if (isa<VectorType>(operandType)) {
+        // Vector operand.
+        res = resolveVectorConsumer(operand);
+      }
+      if (failed(res))
+        return WalkResult::interrupt();
+    }
+    return WalkResult::advance();
+  });
+
+  return result.wasInterrupted() ? failure() : success();
+}
+
+/// Helper to get the defining CreateNdDescOp of a tensor descriptor value. This
+/// function tries to find the defining CreateNdDescOp recursively across
+/// loop boundaries: if `tdescValue` is a region iter_arg of a loop-like op, the
+/// search continues from the tied init value. This is only done when the loop
+/// yields the iter_arg back unchanged; otherwise the descriptor seen inside the
+/// loop may differ from the init value in later iterations. Returns a null op
+/// if no suitable defining CreateNdDescOp is found.
+static xegpu::CreateNdDescOp getDefiningCreateNdDescOp(Value tdescValue) {
+  // Try to get the defining CreateNdDescOp of the tensor descriptor.
+  if (auto definingOp = tdescValue.getDefiningOp<xegpu::CreateNdDescOp>())
+    return definingOp;
+  // If tdescValue is a block argument, try to get the tied init value from the
+  // parent loop-like op.
+  auto arg = dyn_cast<BlockArgument>(tdescValue);
+  if (!arg)
+    return nullptr;
+  // The owning block may be detached / top-level, so the parent may be null.
+  auto loop =
+      dyn_cast_if_present<LoopLikeOpInterface>(arg.getOwner()->getParentOp());
+  if (!loop)
+    return nullptr;
+  OpOperand *tiedInit = loop.getTiedLoopInit(arg);
+  if (!tiedInit)
+    return nullptr;
+  // The init value only describes the iter_arg if the loop carries the tensor
+  // descriptor through unchanged (e.g. `scf.yield ..., %tdesc`).
+  OpOperand *tiedYield = loop.getTiedLoopYieldedValue(arg);
+  if (!tiedYield || tiedYield->get() != arg)
+    return nullptr;
+  return getDefiningCreateNdDescOp(tiedInit->get());
+}
+
+/// Resolve a layout conflict at a vector operand. Not implemented yet: every
+/// vector use point is accepted as-is and reported as handled.
+LogicalResult
+ResolveLayoutConflicts::resolveVectorConsumer(OpOperand & /*operand*/) {
+  // TODO: Implement vector consumer layout conflict resolution. Requires layout
+  // utilities.
+  return success();
+}
+
+/// Resolve a layout conflict at the tensor descriptor `operand` of an
+/// anchor-layout op. A conflict exists when the layout in the tensor
+/// descriptor type differs from the anchor layout expected by the consumer. It
+/// is resolved by duplicating the defining CreateNdDescOp with the expected
+/// layout and rewiring the consumer to the duplicate.
+LogicalResult
+ResolveLayoutConflicts::resolveTensorDescConsumer(OpOperand &operand) {
+  Operation *consumerOp = operand.getOwner();
+  Value tdescValue = operand.get();
+  auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(consumerOp);
+  auto currTDescType = dyn_cast<xegpu::TensorDescType>(tdescValue.getType());
+  assert(anchorOp && currTDescType &&
+         "Expected anchor layout op and tensor descriptor consumer.");
+  Attribute currLayout = currTDescType.getLayout();
+  Attribute expectedLayout = anchorOp.getAnchorLayout();
+  // A conflict exists in a tensor descriptor operand if the tensor
+  // descriptor's layout is different from the anchor layout expected by the
+  // consumer. Without a conflict there is nothing to do.
+  if (!expectedLayout || !currLayout || expectedLayout == currLayout)
+    return success();
+  // TODO: Scattered tensor desc is not supported for now. Only bail out when a
+  // conflict actually needs resolving, so IR using scattered tensor
+  // descriptors without conflicts is left untouched instead of failing.
+  if (currTDescType.isScattered()) {
+    DBGS() << "Scattered tensor descriptor not supported: " << tdescValue
+           << "\n";
+    return failure();
+  }
+  // Try to get the defining CreateNdDescOp of the tensor descriptor.
+  auto conflictingCreateNdOp = getDefiningCreateNdDescOp(tdescValue);
+  if (!conflictingCreateNdOp) {
+    DBGS() << "Unable to find defining CreateNdDescOp for tensor descriptor: "
+           << tdescValue << "\n";
+    return failure();
+  }
+  // Duplicate the CreateNdDescOp with the expected layout. The original op is
+  // kept because other consumers may still use its layout.
+  builder.setInsertionPointAfter(conflictingCreateNdOp);
+  auto newTensorDescType = xegpu::TensorDescType::get(
+      conflictingCreateNdOp.getContext(), currTDescType.getShape(),
+      currTDescType.getElementType(), currTDescType.getEncoding(),
+      expectedLayout);
+  xegpu::CreateNdDescOp newOp = xegpu::CreateNdDescOp::create(
+      builder, consumerOp->getLoc(), newTensorDescType,
+      conflictingCreateNdOp->getOperands(), conflictingCreateNdOp->getAttrs());
+  // Replace the tensor descriptor operand in the consumer op with the new
+  // tensor descriptor.
+  consumerOp->replaceUsesOfWith(tdescValue, newOp.getResult());
+  return success();
+}
+
using GetLayoutFnTy = function_ref<xegpu::DistributeLayoutAttr(Value)>;
/// Update an operation with the layout of its results. If the result type is
/// a vector type, a temporary layout attribute is added to the operation. If
@@ -1604,26 +1719,14 @@ struct XeGPUPropagateLayoutPass final
} // namespace
-void XeGPUPropagateLayoutPass::runOnOperation() {
- LayoutKind layoutKind;
- if (this->layoutKind == "lane") {
- layoutKind = LayoutKind::Lane;
- } else if (this->layoutKind == "inst") {
- layoutKind = LayoutKind::InstData;
- } else if (this->layoutKind == "subgroup") {
- layoutKind = LayoutKind::Subgroup;
- } else {
- getOperation()->emitError("Unsupported layout kind option: " +
- this->layoutKind);
- signalPassFailure();
- return;
- }
- RunLayoutInfoPropagation analysis(getOperation(), layoutKind);
+LogicalResult xegpu::propagateLayouts(OpBuilder &builder, Operation *target,
+ LayoutKind layoutKind, bool printOnly) {
+ RunLayoutInfoPropagation analysis(target, layoutKind);
// Print the analysis result and exit. (for debugging purposes)
if (printOnly) {
auto &os = llvm::outs();
analysis.printAnalysisResult(os);
- return;
+ return success();
}
// Helper to convert LayoutInfo to xegpu::LayoutAttr.
auto getXeGPULayoutForValue = [&](Value val) -> xegpu::DistributeLayoutAttr {
@@ -1637,8 +1740,7 @@ void XeGPUPropagateLayoutPass::runOnOperation() {
return cast<xegpu::LayoutAttr>(layoutAttr);
};
- mlir::OpBuilder builder(&getContext());
- Operation *op = getOperation();
+ Operation *op = target;
auto walkResult = op->walk([&](mlir::Block *block) -> WalkResult {
for (mlir::Operation &op : llvm::reverse(block->getOperations())) {
LogicalResult r = success();
@@ -1661,7 +1763,39 @@ void XeGPUPropagateLayoutPass::runOnOperation() {
}
return WalkResult::advance();
});
- if (walkResult.wasInterrupted()) {
+ if (walkResult.wasInterrupted())
+ return failure();
+
+ return success();
+}
+
+/// Public entry point: run conflict resolution over everything nested in
+/// `target` and forward its result.
+LogicalResult xegpu::resolveLayoutConflicts(Operation *target) {
+  return ResolveLayoutConflicts(target).run();
+}
+
+void XeGPUPropagateLayoutPass::runOnOperation() {
+ xegpu::LayoutKind layoutKind;
+ if (this->layoutKind == "lane") {
+ layoutKind = xegpu::LayoutKind::Lane;
+ } else if (this->layoutKind == "inst") {
+ layoutKind = xegpu::LayoutKind::InstData;
+ } else if (this->layoutKind == "subgroup") {
+ layoutKind = xegpu::LayoutKind::Subgroup;
+ } else {
+ getOperation()->emitError("Unsupported layout kind option: " +
+ this->layoutKind);
+ signalPassFailure();
+ return;
+ }
+ OpBuilder builder(&getContext());
+ if (failed(xegpu::propagateLayouts(builder, getOperation(), layoutKind,
+ this->printOnly))) {
+ signalPassFailure();
+ return;
+ }
+ // Resolve layout conflicts if any.
+ if (failed(xegpu::resolveLayoutConflicts(getOperation()))) {
signalPassFailure();
return;
}
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir b/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
index 5aad0f592abed..9de2881d05d0b 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -xevm-attach-target='chip=pvc' -xegpu-propagate-layout="layout-kind=inst" -split-input-file %s | FileCheck %s
+// RUN: mlir-opt -xevm-attach-target='chip=pvc' -test-xegpu-propagate-layouts="layout-kind=inst" -split-input-file %s | FileCheck %s
// CHECK-LABEL: func.func @load_store_no_array_len(
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir b/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir
index a0adb731605de..29e5b51627fb6 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -xevm-attach-target='chip=pvc' -xegpu-propagate-layout="layout-kind=subgroup" -split-input-file %s | FileCheck %s
+// RUN: mlir-opt -xevm-attach-target='chip=pvc' -test-xegpu-propagate-layouts="layout-kind=subgroup" -split-input-file %s | FileCheck %s
gpu.module @test {
// CHECK-LABEL: store_nd
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout.mlir b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
index bf6c5d992a47f..f4859fe324b19 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -xevm-attach-target='chip=pvc' -xegpu-propagate-layout="layout-kind=lane" -split-input-file %s | FileCheck %s
+// RUN: mlir-opt -xevm-attach-target='chip=pvc' -test-xegpu-propagate-layouts="layout-kind=lane" -split-input-file %s | FileCheck %s
gpu.module @test {
// CHECK-LABEL: func.func @dpas_f16(
diff --git a/mlir/test/Dialect/XeGPU/resolve-layout-conflicts.mlir b/mlir/test/Dialect/XeGPU/resolve-layout-conflicts.mlir
new file mode 100644
index 0000000000000..d1dbe8bcff509
--- /dev/null
+++ b/mlir/test/Dialect/XeGPU/resolve-layout-conflicts.mlir
@@ -0,0 +1,81 @@
+// RUN: mlir-opt --test-xegpu-resolve-layout-conflicts -split-input-file %s | FileCheck %s
+
+// Layout aliases used below. A conflict arises when a consumer's anchor
+// layout (its `layout = ...` attribute) differs from the layout carried by the
+// type of its tensor descriptor operand.
+#load_lo = #xegpu.layout<inst_data = [8, 16]>
+#prefetch_lo = #xegpu.layout<inst_data = [16, 16]>
+#load_lo1 = #xegpu.layout<inst_data = [32, 16]>
+gpu.module @test {
+
+// Single conflict: load_nd expects #load_lo while the descriptor carries
+// #prefetch_lo. A duplicate create_nd_tdesc with #load_lo is expected right
+// after the original and the load is rewired to it; prefetch_nd has no
+// conflict and keeps using the original descriptor.
+// CHECK-LABEL: func.func @load_nd_with_conflicting_tensor_desc
+// CHECK: %{{.*}} = xegpu.create_nd_tdesc %{{.*}} : memref<64x64xf16>
+// CHECK-SAME: -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<inst_data = [16, 16]>>
+// CHECK-NEXT: %[[T1:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<64x64xf16>
+// CHECK-SAME: -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<inst_data = [8, 16]>>
+// CHECK-NEXT: %{{.*}} = xegpu.load_nd %[[T1]][%{{.*}}, %{{.*}}] <{layout = #xegpu.layout<inst_data = [8, 16]>}> :
+// CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<inst_data = [8, 16]>> -> vector<16x16xf16>
+func.func @load_nd_with_conflicting_tensor_desc(%arg0: memref<64x64xf16>) -> vector<16x16xf16> {
+ %c0 = arith.constant 0 : index
+ %0 = xegpu.create_nd_tdesc %arg0 : memref<64x64xf16>
+ -> !xegpu.tensor_desc<16x16xf16, #prefetch_lo>
+ %1 = xegpu.load_nd %0 [%c0, %c0] {layout = #load_lo} : !xegpu.tensor_desc<16x16xf16, #prefetch_lo>
+ -> vector<16x16xf16>
+ xegpu.prefetch_nd %0 [%c0, %c0] {layout = #prefetch_lo} : !xegpu.tensor_desc<16x16xf16, #prefetch_lo>
+ return %1 : vector<16x16xf16>
+}
+
+// Two conflicting consumers of one descriptor (#load_lo): the second load
+// expects #load_lo1 and the prefetch expects #prefetch_lo, so one duplicate is
+// created per conflict. Each duplicate is inserted directly after the original
+// create_nd_tdesc, so the one created last ([16, 16]) appears first.
+// CHECK-LABEL: func.func @multiple_tensor_desc_conflicts
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-NEXT: %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<64x64xf16>
+// CHECK-SAME: -> !xegpu.tensor_desc<32x16xf16, #xegpu.layout<inst_data = [8, 16]>>
+// CHECK-NEXT: %[[T1:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<64x64xf16>
+// CHECK-SAME: -> !xegpu.tensor_desc<32x16xf16, #xegpu.layout<inst_data = [16, 16]>>
+// CHECK-NEXT: %[[T2:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<64x64xf16>
+// CHECK-SAME: -> !xegpu.tensor_desc<32x16xf16, #xegpu.layout<inst_data = [32, 16]>>
+// CHECK-NEXT: %{{.*}} = xegpu.load_nd %[[T0]][%[[C0]], %[[C0]]] <{layout = #xegpu.layout<inst_data = [8, 16]>}> :
+// CHECK-SAME: !xegpu.tensor_desc<32x16xf16, #xegpu.layout<inst_data = [8, 16]>> -> vector<32x16xf16>
+// CHECK-NEXT: %{{.*}} = xegpu.load_nd %[[T2]][%[[C0]], %[[C0]]] <{layout = #xegpu.layout<inst_data = [32, 16]>}> :
+// CHECK-SAME: !xegpu.tensor_desc<32x16xf16, #xegpu.layout<inst_data = [32, 16]>> -> vector<32x16xf16>
+// CHECK-NEXT: xegpu.prefetch_nd %[[T1]][%[[C0]], %[[C0]]] <{layout = #xegpu.layout<inst_data = [16, 16]>}> :
+// CHECK-SAME: !xegpu.tensor_desc<32x16xf16, #xegpu.layout<inst_data = [16, 16]>>
+func.func @multiple_tensor_desc_conflicts(%arg0: memref<64x64xf16>) -> vector<32x16xf16> {
+ %c0 = arith.constant 0 : index
+ %tdesc1 = xegpu.create_nd_tdesc %arg0 : memref<64x64xf16>
+ -> !xegpu.tensor_desc<32x16xf16, #load_lo>
+ %load1 = xegpu.load_nd %tdesc1 [%c0, %c0] {layout = #load_lo} : !xegpu.tensor_desc<32x16xf16, #load_lo>
+ -> vector<32x16xf16>
+ %load2 = xegpu.load_nd %tdesc1 [%c0, %c0] {layout = #load_lo1} : !xegpu.tensor_desc<32x16xf16, #load_lo>
+ -> vector<32x16xf16>
+ xegpu.prefetch_nd %tdesc1 [%c0, %c0] {layout = #prefetch_lo} : !xegpu.tensor_desc<32x16xf16, #load_lo>
+ %result = arith.addf %load1, %load2 : vector<32x16xf16>
+ return %result : vector<32x16xf16>
+}
+
+// Conflict on a loop-carried descriptor: the load inside scf.for uses the
+// iter_arg, which is yielded back unchanged. The duplicate is created from the
+// loop's init create_nd_tdesc outside the loop, while the loop-carried
+// descriptor and the prefetch after the loop keep #prefetch_lo.
+// CHECK-LABEL: func.func @load_nd_with_conflicting_tensor_desc_in_loop
+// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<64x64xf16>
+// CHECK-SAME: -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<inst_data = [16, 16]>>
+// CHECK-NEXT: %[[T1:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<64x64xf16>
+// CHECK-SAME: -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<inst_data = [8, 16]>>
+// CHECK-NEXT: %{{.*}}:2 = scf.for %{{.*}} = %{{.*}} iter_args(%{{.*}} = %{{.*}}, %{{.*}} = %[[T0]])
+// CHECK-SAME: -> (vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16, #xegpu.layout<inst_data = [16, 16]>>) {
+// CHECK-NEXT: %{{.*}} = xegpu.load_nd %[[T1]][%{{.*}}] <{layout = #xegpu.layout<inst_data = [8, 16]>}> :
+// CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<inst_data = [8, 16]>> -> vector<16x16xf16>
+// CHECK: scf.yield %{{.*}}, %{{.*}} : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16, #xegpu.layout<inst_data = [16, 16]>>
+// CHECK: xegpu.prefetch_nd %[[T0]][%{{.*}}] <{layout = #xegpu.layout<inst_data = [16, 16]>}> :
+// CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<inst_data = [16, 16]>>
+// CHECK-NEXT: return %{{.*}}#0 : vector<16x16xf16>
+func.func @load_nd_with_conflicting_tensor_desc_in_loop(%arg0: memref<64x64xf16>) -> vector<16x16xf16> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c4 = arith.constant 4 : index
+ %cst = arith.constant dense<0.0> : vector<16x16xf16>
+ %0 = xegpu.create_nd_tdesc %arg0 : memref<64x64xf16>
+ -> !xegpu.tensor_desc<16x16xf16, #prefetch_lo>
+ %1:2 = scf.for %i = %c0 to %c4 step %c1 iter_args(%acc = %cst, %tdesc = %0) -> (vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16, #prefetch_lo>) {
+ %2 = xegpu.load_nd %tdesc [%c0, %c0] {layout = #load_lo} : !xegpu.tensor_desc<16x16xf16, #prefetch_lo>
+ -> vector<16x16xf16>
+ %3 = arith.addf %acc, %2 : vector<16x16xf16>
+ scf.yield %3, %tdesc : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16, #prefetch_lo>
+ }
+ xegpu.prefetch_nd %0 [%c0, %c0] {layout = #prefetch_lo} : !xegpu.tensor_desc<16x16xf16, #prefetch_lo>
+ return %1#0 : vector<16x16xf16>
+}
+}
diff --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
index 1a1520dfa975d..c8a6a6d7b8eb8 100644
--- a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
+++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
@@ -277,6 +277,80 @@ struct TestXeGPUMoveFuncBodyToWarpOp
}
};
+/// Test pass that runs XeGPU layout propagation, of the kind selected by the
+/// `layout-kind` option, on a gpu.module.
+struct TestXeGPUPropagateLayouts
+    : public PassWrapper<TestXeGPUPropagateLayouts,
+                         OperationPass<gpu::GPUModuleOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestXeGPUPropagateLayouts)
+
+  StringRef getArgument() const final { return "test-xegpu-propagate-layouts"; }
+
+  StringRef getDescription() const final {
+    return "Test the implementation of XeGPU propagate layouts.";
+  }
+
+  void getDependentDialects(::mlir::DialectRegistry &registry) const override {
+    registry.insert<xegpu::XeGPUDialect>();
+    registry.insert<gpu::GPUDialect>();
+  }
+
+  TestXeGPUPropagateLayouts() = default;
+  // User-provided copy constructor: the `Option` member below is not copyable.
+  TestXeGPUPropagateLayouts(const TestXeGPUPropagateLayouts &pass)
+      : PassWrapper(pass) {}
+
+  Option<std::string> layoutKind{
+      *this, "layout-kind",
+      llvm::cl::desc("Propagate `subgroup` / `inst` / `lane` level of xegpu "
+                     "layouts."),
+      llvm::cl::init("lane")};
+
+  void runOnOperation() override {
+    xegpu::LayoutKind kind;
+    if (layoutKind == "subgroup") {
+      kind = xegpu::LayoutKind::Subgroup;
+    } else if (layoutKind == "inst") {
+      kind = xegpu::LayoutKind::InstData;
+    } else if (layoutKind == "lane") {
+      kind = xegpu::LayoutKind::Lane;
+    } else {
+      // Report the bad option instead of failing silently.
+      getOperation()->emitError("Unsupported layout kind option: " +
+                                layoutKind);
+      signalPassFailure();
+      return;
+    }
+    OpBuilder builder(getOperation());
+    if (failed(xegpu::propagateLayouts(builder, getOperation(), kind)))
+      signalPassFailure();
+  }
+};
+
+/// Test pass that runs XeGPU layout conflict resolution on a gpu.module and
+/// fails if any conflict cannot be handled.
+struct TestXeGPUResolveLayoutConflicts
+    : public PassWrapper<TestXeGPUResolveLayoutConflicts,
+                         OperationPass<gpu::GPUModuleOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestXeGPUResolveLayoutConflicts)
+
+  StringRef getArgument() const final {
+    return "test-xegpu-resolve-layout-conflicts";
+  }
+
+  StringRef getDescription() const final {
+    return "Test the implementation of XeGPU layout conflict resolution.";
+  }
+
+  void getDependentDialects(::mlir::DialectRegistry &registry) const override {
+    registry.insert<xegpu::XeGPUDialect>();
+    registry.insert<gpu::GPUDialect>();
+  }
+
+  TestXeGPUResolveLayoutConflicts() = default;
+  TestXeGPUResolveLayoutConflicts(const TestXeGPUResolveLayoutConflicts &pass) =
+      default;
+
+  void runOnOperation() override {
+    if (failed(xegpu::resolveLayoutConflicts(getOperation())))
+      signalPassFailure();
+  }
+};
+
struct TestXeGPULayoutInterface
: public PassWrapper<TestXeGPULayoutInterface,
OperationPass<gpu::GPUModuleOp>> {
@@ -342,6 +416,8 @@ void registerTestXeGPULowerings() {
PassRegistration<TestXeGPULayoutInterface>();
PassRegistration<TestXeGPUSGDistribute>();
PassRegistration<TestXeGPUMoveFuncBodyToWarpOp>();
+ PassRegistration<TestXeGPUPropagateLayouts>();
+ PassRegistration<TestXeGPUResolveLayoutConflicts>();
}
} // namespace test
} // namespace mlir
More information about the Mlir-commits
mailing list