[Mlir-commits] [mlir] [MLIR][XeGPU] Use context-aware type converter in WgToSgDistribute and Blocking pass (PR #194685)
Jianhui Li
llvmlistbot at llvm.org
Fri May 22 12:56:29 PDT 2026
================
@@ -986,3 +842,155 @@ bool xegpu::matchSplitDimExpansion(
}
return srcIdx == src.size();
}
+
+//===----------------------------------------------------------------------===//
+// Context-aware type conversion utilities
+//===----------------------------------------------------------------------===//
+
+void xegpu::addSCFStructuralMaterializations(TypeConverter &converter) { // Registers one cast-building callback on `converter` as both its source and target materialization.
+ auto materializeCast = [](OpBuilder &builder, Type type, ValueRange inputs,
+ Location loc) -> Value {
+ return UnrealizedConversionCastOp::create(builder, loc, type, inputs) // Cast op from `inputs` to the requested `type`, created at `loc`.
+ .getResult(0); // The callback hands back the cast's first result.
+ };
+ // Source materialization: N:1 (N converted values -> 1 original value), bridged by the cast built in materializeCast.
+ converter.addSourceMaterialization(materializeCast);
+ // Target materialization: 1:1 (single value type conversion); reuses the same materializeCast callback.
+ converter.addTargetMaterialization(materializeCast);
+}
+
+void xegpu::addContextAwareVectorTypeConversion(
+ TypeConverter &converter, Operation *topLevelOp,
+ SubShapeAndCountFn getSubShapeAndCount) {
+ // Pre-compute 1:N type mappings for scf.while block arguments only.
+ // During scf.while structural conversion, blocks are detached from their
+ // parent region before convertBlockSignature is called. Block::getParent()
+ // crashes on detached blocks (LLVM ilist assertion), so we cannot look up
+ // layout attributes at that point. Other SCF ops (scf.for, scf.if) keep
+ // blocks attached during conversion.
+ auto whileArgTypeMap = std::make_shared<DenseMap<Value, SmallVector<Type>>>();
+ auto recordBlockArgTypes = [&](Value init, BlockArgument arg) {
+ auto vecTy = dyn_cast<VectorType>(init.getType());
+ if (!vecTy)
+ return;
+ auto layout = xegpu::getDistributeLayoutAttr(init);
+ if (!layout)
+ return;
+ auto [subShape, count] = getSubShapeAndCount(vecTy, layout);
+ if (count <= 0)
+ return;
+ auto newTy = VectorType::get(subShape, vecTy.getElementType());
+ SmallVector<Type> types(count, newTy);
+ (*whileArgTypeMap)[arg] = std::move(types);
+ };
+ topLevelOp->walk([&](scf::WhileOp whileOp) {
+ // "before" region block arguments correspond to the `inits` operands.
+ for (auto [init, arg] :
+ llvm::zip(whileOp.getInits(), whileOp.getBeforeArguments()))
+ recordBlockArgTypes(init, arg);
+ // "after" region block arguments correspond to the operands of the
+ // embedded `scf.condition` op (not the `inits`). In general the two
+ // type lists may differ.
+ scf::ConditionOp condOp = whileOp.getConditionOp();
+ for (auto [condArg, arg] :
+ llvm::zip(condOp.getArgs(), whileOp.getAfterArguments()))
+ recordBlockArgTypes(condArg, arg);
+ });
+
+ // Context-aware 1:N conversion for VectorType. For scf.while block
+ // arguments, uses the pre-computed map. For all other Values, retrieves
+ // the layout directly via getDistributeLayoutAttr.
+ converter.addConversion(
+ [whileArgTypeMap, getSubShapeAndCount](
+ Value v,
+ SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
+ if (!isa<VectorType>(v.getType()))
+ return std::nullopt;
+
+ // Check pre-computed map first (for scf.while block args).
+ if (isa<BlockArgument>(v)) {
----------------
Jianhui-Li wrote:
Yes. The layout recovery process guarantees this.
https://github.com/llvm/llvm-project/pull/194685
More information about the Mlir-commits
mailing list