[Mlir-commits] [mlir] [MLIR][XeGPU] Use context-aware type converter in WgToSgDistribute and Blocking pass (PR #194685)

Jianhui Li llvmlistbot at llvm.org
Fri May 22 12:56:29 PDT 2026


================
@@ -986,3 +842,155 @@ bool xegpu::matchSplitDimExpansion(
   }
   return srcIdx == src.size();
 }
+
+//===----------------------------------------------------------------------===//
+// Context-aware type conversion utilities
+//===----------------------------------------------------------------------===//
+
+void xegpu::addSCFStructuralMaterializations(TypeConverter &converter) {
+  auto materializeCast = [](OpBuilder &builder, Type type, ValueRange inputs,
+                            Location loc) -> Value {
+    return UnrealizedConversionCastOp::create(builder, loc, type, inputs)
+        .getResult(0);
+  };
+  // Source materialization: N:1 (N converted values -> 1 original value).
+  converter.addSourceMaterialization(materializeCast);
+  // Target materialization: 1:1 (single value type conversion).
+  converter.addTargetMaterialization(materializeCast);
+}
+
+void xegpu::addContextAwareVectorTypeConversion(
+    TypeConverter &converter, Operation *topLevelOp,
+    SubShapeAndCountFn getSubShapeAndCount) {
+  // Pre-compute 1:N type mappings for scf.while block arguments only.
+  // During scf.while structural conversion, blocks are detached from their
+  // parent region before convertBlockSignature is called. Block::getParent()
+  // crashes on detached blocks (LLVM ilist assertion), so we cannot look up
+  // layout attributes at that point. Other SCF ops (scf.for, scf.if) keep
+  // blocks attached during conversion.
+  auto whileArgTypeMap = std::make_shared<DenseMap<Value, SmallVector<Type>>>();
+  auto recordBlockArgTypes = [&](Value init, BlockArgument arg) {
+    auto vecTy = dyn_cast<VectorType>(init.getType());
+    if (!vecTy)
+      return;
+    auto layout = xegpu::getDistributeLayoutAttr(init);
+    if (!layout)
+      return;
+    auto [subShape, count] = getSubShapeAndCount(vecTy, layout);
+    if (count <= 0)
+      return;
+    auto newTy = VectorType::get(subShape, vecTy.getElementType());
+    SmallVector<Type> types(count, newTy);
+    (*whileArgTypeMap)[arg] = std::move(types);
+  };
+  topLevelOp->walk([&](scf::WhileOp whileOp) {
+    // "before" region block arguments correspond to the `inits` operands.
+    for (auto [init, arg] :
+         llvm::zip(whileOp.getInits(), whileOp.getBeforeArguments()))
+      recordBlockArgTypes(init, arg);
+    // "after" region block arguments correspond to the operands of the
+    // embedded `scf.condition` op (not the `inits`). In general the two
+    // type lists may differ.
+    scf::ConditionOp condOp = whileOp.getConditionOp();
+    for (auto [condArg, arg] :
+         llvm::zip(condOp.getArgs(), whileOp.getAfterArguments()))
+      recordBlockArgTypes(condArg, arg);
+  });
+
+  // Context-aware 1:N conversion for VectorType. For scf.while block
+  // arguments, uses the pre-computed map. For all other Values, retrieves
+  // the layout directly via getDistributeLayoutAttr.
+  converter.addConversion(
+      [whileArgTypeMap, getSubShapeAndCount](
+          Value v,
+          SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
+        if (!isa<VectorType>(v.getType()))
+          return std::nullopt;
+
+        // Check pre-computed map first (for scf.while block args).
+        if (isa<BlockArgument>(v)) {
----------------
Jianhui-Li wrote:

Yes. The layout recovery process guarantees this.

https://github.com/llvm/llvm-project/pull/194685
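The comments in the diff describe a two-phase scheme. Phase one walks every `scf.while` up front, while its blocks are still attached, and records the 1:N result types for its block arguments. Before-region arguments are keyed from the `inits`, and after-region arguments are keyed from the `scf.condition` operands. Phase two runs the conversion callback, which consults that map before falling back to a live layout lookup.

The standalone C++ sketch below models that scheme. Values are plain integer ids, and the per-value expansion is a hypothetical "split into equal tiles" rule. `ValueInfo`, `WhileModel`, `precompute`, and `convertValue` are illustrative names, not MLIR or XeGPU APIs.

```cpp
#include <cassert>
#include <cstdint>
#include <map>
#include <optional>
#include <utility>
#include <vector>

using ValueId = int; // hypothetical stand-in for mlir::Value
using Shape = std::vector<int64_t>;

// Hypothetical model of what is known about a value while it is attached:
// its vector shape and, if a layout is present, a per-subgroup tile shape.
struct ValueInfo {
  Shape shape;
  std::optional<Shape> tile;
};

// Hypothetical expansion rule: split `shape` into equal `tile`-shaped pieces.
// Returns an empty list when there is no tile or the shape does not divide.
std::vector<Shape> expand(const ValueInfo &info) {
  if (!info.tile || info.tile->size() != info.shape.size())
    return {};
  int64_t count = 1;
  for (size_t i = 0; i < info.shape.size(); ++i) {
    int64_t t = (*info.tile)[i];
    if (t <= 0 || info.shape[i] % t != 0)
      return {};
    count *= info.shape[i] / t;
  }
  return std::vector<Shape>(static_cast<size_t>(count), *info.tile);
}

// Minimal model of an scf.while: before-region args pair with `inits`,
// after-region args pair with the scf.condition operands.
struct WhileModel {
  std::vector<ValueInfo> inits, condArgs;
  std::vector<ValueId> beforeArgs, afterArgs;
};

using TypeMap = std::map<ValueId, std::vector<Shape>>;

// Phase 1: record results while layouts are still reachable. Like llvm::zip,
// each loop stops at the shorter of the two lists.
void precompute(const WhileModel &w, TypeMap &map) {
  for (size_t i = 0; i < w.inits.size() && i < w.beforeArgs.size(); ++i)
    if (auto types = expand(w.inits[i]); !types.empty())
      map[w.beforeArgs[i]] = std::move(types);
  for (size_t i = 0; i < w.condArgs.size() && i < w.afterArgs.size(); ++i)
    if (auto types = expand(w.condArgs[i]); !types.empty())
      map[w.afterArgs[i]] = std::move(types);
}

// Phase 2: conversion callback. Prefer the pre-computed entry, otherwise fall
// back to a live lookup; std::nullopt means "this rule does not apply".
std::optional<std::vector<Shape>> convertValue(
    ValueId v, const TypeMap &map, const std::map<ValueId, ValueInfo> &live) {
  if (auto it = map.find(v); it != map.end())
    return it->second;
  auto lit = live.find(v);
  if (lit == live.end())
    return std::nullopt;
  auto types = expand(lit->second);
  if (types.empty())
    return std::nullopt;
  return types;
}
```

The map is keyed by the value itself, mirroring the `DenseMap<Value, SmallVector<Type>>` in the diff, rather than by type. Keying by value lets two block arguments with the same vector type but different layouts expand differently. It also lets an after-region argument pick up its shape from the `scf.condition` operand even when that operand's type differs from the corresponding init.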
