[Mlir-commits] [mlir] [MLIR][XeGPU] Use context-aware type converter in WgToSgDistribute and Blocking pass (PR #194685)

Jianhui Li llvmlistbot at llvm.org
Fri May 22 12:56:29 PDT 2026


================
@@ -1580,78 +1543,22 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
     return;
   }
 
-  // Track existing UnrealizedConversionCastOps
-  SmallVector<Operation *> existingCastOps;
-  getOperation()->walk([&](UnrealizedConversionCastOp castOp) {
-    existingCastOps.push_back(castOp.getOperation());
-  });
+  // Collect existing UnrealizedConversionCastOps. These must be preserved.
+  llvm::SmallSetVector<UnrealizedConversionCastOp, 8> existingCasts;
+  getOperation()->walk(
+      [&](UnrealizedConversionCastOp castOp) { existingCasts.insert(castOp); });
 
-  {
-    // Step 1: Apply SCFStructuralTypeConversions to SCF operations with
-    // VectorType operands. This first converts such operands to
-    // RankedTensorType, propagates the layout attribute into the encoding
-    // attribute, and finally converts the RankedTensorType to VectorType based
-    // on the encoding.
-
-    TypeConverter converter;
-    converter.addConversion([&](Type type) -> Type { return type; });
-    converter.addConversion(
-        [&](RankedTensorType type,
-            SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
-          // Only convert RankedTensorTypes that carry an XeGPU layout encoding.
-          // Plain tensors (e.g. tensor<?xi32>) have no XeGPU encoding and must
-          // not be converted: VectorType does not support dynamic dimensions.
-          auto encoding = dyn_cast_if_present<xegpu::DistributeLayoutAttr>(
-              type.getEncoding());
-          if (!encoding)
-            return std::nullopt;
-
-          Type elemTy = type.getElementType();
-          ArrayRef<int64_t> shape = type.getShape();
-
-          int count;
-          SmallVector<int64_t> subShape;
-          std::tie(subShape, count) = getSgShapeAndCount(shape, encoding);
-
-          auto newTy = VectorType::get(subShape, elemTy);
-          result.append(count, newTy);
-          return success();
-        });
-
-    xegpu::doSCFStructuralTypeConversionWithTensorType(getOperation(),
-                                                       converter);
-  }
-
-  // Step 2: Perform workgroup to subgroup distribution for TensorDesc values,
-  // as well as XeGPU, Arith, and Vector operations.
+  // Perform workgroup to subgroup distribution for TensorDesc and Vector
+  // values, as well as XeGPU, Arith, and Vector operations. Uses a
+  // context-aware type converter that inspects Values to retrieve the
+  // distribute layout attribute for 1:N type conversion.
   MLIRContext *ctx = &getContext();
   RewritePatternSet patterns(ctx);
   ConversionTarget target(*ctx);
   TypeConverter converter;
-  converter.addConversion([&](Type type) -> Type { return type; });
-  converter.addConversion(
-      [&](xegpu::TensorDescType type,
-          SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
-        xegpu::DistributeLayoutAttr layout = type.getLayoutAttr();
-        // Only convert WG-level tensor descs. SG-level or layout-less types
-        // are already legal and should pass through unchanged.
-        if (!layout || !layout.isForWorkgroup())
-          return std::nullopt;
-
-        Type elemTy = type.getElementType();
-        ArrayRef<int64_t> shape = type.getShape();
-
-        int count;
-        SmallVector<int64_t> subShape;
-        std::tie(subShape, count) = getSgShapeAndCount(shape, layout);
-
-        layout = layout.dropSgLayoutAndData();
-
-        auto newTy = xegpu::TensorDescType::get(
-            type.getContext(), subShape, elemTy, type.getEncoding(), layout);
-        result.append(count, newTy);
-        return success();
-      });
+  xegpu::addSCFStructuralMaterializations(converter);
----------------
Jianhui-Li wrote:

Why isn't addSCFStructuralMaterializations() inside populateXeGPUWgToSgDistributeTypeConversions(), side by side with addContextAwareVectorTypeConversion(), since both are related to SCF handling? Or am I missing something?

https://github.com/llvm/llvm-project/pull/194685
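For readers following the diff, here is a toy C++ model of the 1:N conversion it describes. It is not MLIR's TypeConverter API, and every name in it (ToyValue, WgLayout, sgShapeAndCount, convertValue) is hypothetical. It shows two ideas. First, with a context-aware converter the workgroup layout is looked up on the value rather than read from the type, so a plain vector type can still be split. Second, as in the removed code, one workgroup-level type becomes `count` copies of a subgroup-sized type (`result.append(count, newTy)`). The internal arithmetic is an assumption: one round of distribution covers sg_layout * sg_data elements per dimension, clamped to the workgroup shape, and the shape is assumed to divide evenly.

```cpp
// Toy model only: hypothetical types, not the MLIR/XeGPU API.
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <numeric>
#include <optional>
#include <utility>
#include <vector>

using Shape = std::vector<int64_t>;

// Hypothetical stand-in for a workgroup-level distribute layout.
struct WgLayout {
  Shape sgLayout; // number of subgroups along each dimension
  Shape sgData;   // elements owned by one subgroup along each dimension
};

// Hypothetical stand-in for a vector-typed SSA value. The layout is NOT part
// of the type; a context-aware converter finds it by inspecting the value.
struct ToyValue {
  Shape shape;
  std::optional<WgLayout> layout;
};

static int64_t product(const Shape &s) {
  return std::accumulate(s.begin(), s.end(), int64_t{1},
                         std::multiplies<int64_t>());
}

// Assumed arithmetic: each subgroup owns an sgData tile, one distribution round
// covers sgLayout * sgData (clamped to the WG shape), and count is the number
// of rounds. Assumes every dimension divides evenly.
std::pair<Shape, int64_t> sgShapeAndCount(const Shape &shape,
                                          const WgLayout &l) {
  Shape distUnit(shape.size());
  for (std::size_t i = 0; i < shape.size(); ++i)
    distUnit[i] = std::min(shape[i], l.sgLayout[i] * l.sgData[i]);
  return {l.sgData, product(shape) / product(distUnit)};
}

// Context-aware 1:N conversion: the result depends on the value, not only on
// its type. Values without a WG layout keep their type (1:1).
std::vector<Shape> convertValue(const ToyValue &v) {
  if (!v.layout)
    return {v.shape};
  auto [subShape, count] = sgShapeAndCount(v.shape, *v.layout);
  return std::vector<Shape>(static_cast<std::size_t>(count), subShape);
}
```

For example, a 256x128 value with an 8x4 subgroup layout and 16x32 subgroup data covers 128x128 per round, so it converts to two 16x32 types, while the same shape with no layout converts 1:1.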


More information about the Mlir-commits mailing list