[Mlir-commits] [mlir] [mlir][xegpu] Add initial support for layout conflict handling. (PR #173090)
Charitha Saumya
llvmlistbot at llvm.org
Wed Jan 28 08:25:45 PST 2026
================
@@ -1439,6 +1441,121 @@ void RunLayoutInfoPropagation::printAnalysisResult(llvm::raw_ostream &os) {
printFunctionResult(funcOp);
}
+namespace {
+
+//===----------------------------------------------------------------------===//
+// ResolveLayoutConflicts
+//===----------------------------------------------------------------------===//
+/// Scans the IR nested under `parentOp` and resolves layout conflicts at the
+/// use points of tensor descriptor and vector values.
+struct ResolveLayoutConflicts {
+  /// `explicit` so an `Operation *` is never silently converted into a
+  /// resolver object.
+  explicit ResolveLayoutConflicts(Operation *parentOp)
+      : parentOp(parentOp), builder(parentOp->getContext()) {}
+
+  /// Runs conflict resolution over every op nested under `parentOp`.
+  /// Returns failure if any individual resolution step fails.
+  LogicalResult run();
+
+private:
+  /// Root operation whose nested IR is scanned by `run()`.
+  Operation *parentOp;
+  /// Builder created on `parentOp`'s context; presumably used by the
+  /// resolution helpers to materialize IR (not used in the code shown here).
+  OpBuilder builder;
+  /// Resolves a conflict between a tensor descriptor operand's layout and the
+  /// anchor layout expected by the consuming op.
+  LogicalResult resolveTensorDescConsumer(OpOperand &operand);
+  /// Resolves a layout conflict for a vector operand.
+  LogicalResult resolveVectorConsumer(OpOperand &operand);
+};
+
+} // namespace
+
+LogicalResult ResolveLayoutConflicts::run() {
+  // Scan all operations nested under (and including) the parent operation and
+  // resolve layout conflicts at every tensor descriptor and vector use point.
+  // Every operand of an op is visited. The walk is interrupted only when a
+  // resolution step fails; returning right after the first handled operand
+  // would skip later operands of the same op (e.g. the tensor descriptor of
+  // an op whose first operand is a vector).
+  WalkResult result = parentOp->walk([&](Operation *op) -> WalkResult {
+    for (OpOperand &operand : op->getOpOperands()) {
+      Type operandType = operand.get().getType();
+      // Tensor descriptor operands are only resolved for ops that carry an
+      // anchor layout, since the expected layout is taken from that anchor.
+      if (isa<xegpu::AnchorLayoutInterface>(op) &&
+          isa<xegpu::TensorDescType>(operandType)) {
+        if (failed(resolveTensorDescConsumer(operand)))
+          return WalkResult::interrupt();
+        continue;
+      }
+      // Handle conflicts in vector operands.
+      if (isa<VectorType>(operandType) &&
+          failed(resolveVectorConsumer(operand)))
+        return WalkResult::interrupt();
+    }
+    return WalkResult::advance();
+  });
+
+  return result.wasInterrupted() ? failure() : success();
+}
+
+/// Helper to get the defining CreateNdDescOp of a tensor descriptor value.
+/// If the value is a loop-carried block argument of a LoopLikeOpInterface op,
+/// the search continues recursively from the tied init value. This allows the
+/// CreateNdDescOp to be found across one or more enclosing loops. Returns a
+/// null op if no defining CreateNdDescOp can be found.
+static xegpu::CreateNdDescOp getDefiningCreateNdDescOp(Value tdescValue) {
+  // Directly defined by a CreateNdDescOp.
+  if (auto definingOp = tdescValue.getDefiningOp<xegpu::CreateNdDescOp>())
+    return definingOp;
+  // Otherwise only block arguments of loop-like ops can be traced further.
+  auto arg = dyn_cast<BlockArgument>(tdescValue);
+  if (!arg)
+    return nullptr;
+  // The owner block may not be attached to an op, in which case getParentOp()
+  // is null; dyn_cast_if_present tolerates that instead of asserting.
+  auto loop =
+      dyn_cast_if_present<LoopLikeOpInterface>(arg.getOwner()->getParentOp());
+  if (!loop)
+    return nullptr;
+  // Follow the loop-carried argument back to the value it is initialized with.
+  if (OpOperand *tiedInit = loop.getTiedLoopInit(arg))
+    return getDefiningCreateNdDescOp(tiedInit->get());
+  // The argument has no tied init value in this loop; give up.
+  return nullptr;
+}
+
+LogicalResult
+ResolveLayoutConflicts::resolveVectorConsumer(OpOperand &operand) {
+  // Vector operands are currently accepted unchanged and always succeed.
+  // TODO: Implement vector consumer layout conflict resolution; this depends
+  // on layout utilities that are not available yet.
+  (void)operand;
+  return LogicalResult::success();
+}
+
+LogicalResult
+ResolveLayoutConflicts::resolveTensorDescConsumer(OpOperand &operand) {
+ Operation *consumerOp = operand.getOwner();
+ Value tdescValue = operand.get();
+ auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(consumerOp);
+ auto currTDescType = dyn_cast<xegpu::TensorDescType>(tdescValue.getType());
+ assert(anchorOp && currTDescType &&
+ "Expected anchor layout op and tensor descriptor consumer.");
+ // TODO: Scattered tensor desc is not supported for now.
+ if (currTDescType.isScattered()) {
+ DBGS() << "Scattered tensor descriptor not supported: " << tdescValue
+ << "\n";
+ return failure();
+ }
+ Attribute currLayout = currTDescType.getLayout();
+ Attribute expectedLayout = anchorOp.getAnchorLayout();
+ // A conflict exists in tensot descriptor operand if tensor descriptor's
----------------
charithaintc wrote:
fixed.
https://github.com/llvm/llvm-project/pull/173090
More information about the Mlir-commits
mailing list