[Mlir-commits] [mlir] [mlir][xegpu] Refine layout assignment in XeGPU SIMT distribution. (PR #142687)

Charitha Saumya llvmlistbot at llvm.org
Tue Jun 3 17:01:52 PDT 2025


https://github.com/charithaintc updated https://github.com/llvm/llvm-project/pull/142687

>From ff1012e2208ef866a0313289d4bf6e130d1a0eaf Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Tue, 27 May 2025 23:40:57 +0000
Subject: [PATCH 01/10] add bug fix

---
 .../Vector/Transforms/VectorDistribute.cpp    | 42 ++++++++++++++-----
 1 file changed, 32 insertions(+), 10 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 045c192787f10..1649fb5f91b42 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -15,10 +15,13 @@
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
 #include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/Value.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Transforms/RegionUtils.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SetVector.h"
 #include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/raw_ostream.h"
 #include <utility>
 
 using namespace mlir;
@@ -1554,22 +1557,36 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
     llvm::SmallSetVector<Value, 32> escapingValues;
     SmallVector<Type> inputTypes;
     SmallVector<Type> distTypes;
+    auto collectEscapingValues = [&](Value value) {
+      if (!escapingValues.insert(value))
+        return;
+      Type distType = value.getType();
+      if (auto vecType = dyn_cast<VectorType>(distType)) {
+        AffineMap map = distributionMapFn(value);
+        distType = getDistributedType(vecType, map, warpOp.getWarpSize());
+      }
+      inputTypes.push_back(value.getType());
+      distTypes.push_back(distType);
+    };
+
     mlir::visitUsedValuesDefinedAbove(
         forOp.getBodyRegion(), [&](OpOperand *operand) {
           Operation *parent = operand->get().getParentRegion()->getParentOp();
           if (warpOp->isAncestor(parent)) {
-            if (!escapingValues.insert(operand->get()))
-              return;
-            Type distType = operand->get().getType();
-            if (auto vecType = dyn_cast<VectorType>(distType)) {
-              AffineMap map = distributionMapFn(operand->get());
-              distType = getDistributedType(vecType, map, warpOp.getWarpSize());
-            }
-            inputTypes.push_back(operand->get().getType());
-            distTypes.push_back(distType);
+            collectEscapingValues(operand->get());
           }
         });
 
+    // Any forOp result that is not already yielded by the warpOp
+    // region is also considered escaping.
+    for (OpResult forResult : forOp.getResults()) {
+      // Check if this forResult is already yielded by the yield op.
+      if (llvm::is_contained(yield->getOperands(), forResult)) {
+        continue;
+      }
+      collectEscapingValues(forResult);
+    }
+
     if (llvm::is_contained(distTypes, Type{}))
       return failure();
 
@@ -1609,7 +1626,12 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
                                     forOp.getResultTypes().end());
     llvm::SmallDenseMap<Value, int64_t> argIndexMapping;
     for (auto [i, retIdx] : llvm::enumerate(newRetIndices)) {
-      warpInput.push_back(newWarpOp.getResult(retIdx));
+      auto newWarpResult = newWarpOp.getResult(retIdx);
+      // Unused forOp results yielded by the warpOp region are already included
+      // in the new ForOp.
+      if (llvm::is_contained(newOperands, newWarpResult))
+        continue;
+      warpInput.push_back(newWarpResult);
       argIndexMapping[escapingValues[i]] = warpInputType.size();
       warpInputType.push_back(inputTypes[i]);
     }

>From c6eb53fefded7152c2d627c4094b66f616bc53ed Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Wed, 28 May 2025 20:22:47 +0000
Subject: [PATCH 02/10] add test

---
 .../Vector/vector-warp-distribute.mlir        | 36 +++++++++++++++++++
 1 file changed, 36 insertions(+)

diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 38771f2593449..6c7ac7a5196a7 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -584,6 +584,42 @@ func.func @warp_scf_for_multiple_yield(%arg0: index, %arg1: memref<?xf32>, %arg2
   return
 }
 
+// -----
+// CHECK-PROP-LABEL: func.func @warp_scf_for_unused_yield(
+//       CHECK-PROP: %[[W0:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<4xf32>, vector<4xf32>) {
+//       CHECK-PROP: %[[INI0:.*]] = "some_def"() : () -> vector<128xf32>
+//       CHECK-PROP: %[[INI1:.*]] = "some_def"() : () -> vector<128xf32>
+//       CHECK-PROP: gpu.yield %[[INI0]], %[[INI1]] : vector<128xf32>, vector<128xf32>
+//       CHECK-PROP: }
+//       CHECK-PROP: %[[F:.*]]:2 = scf.for %{{.*}} iter_args(%{{.*}} = %[[W0]]#0, %{{.*}} = %[[W0]]#1) -> (vector<4xf32>, vector<4xf32>) {
+//       CHECK-PROP: %[[W1:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[32] args(%{{.*}} : vector<4xf32>, vector<4xf32>) -> (vector<4xf32>, vector<4xf32>) {
+//       CHECK-PROP: %[[ACC0:.*]] = "some_def"(%{{.*}}) : (vector<128xf32>, index) -> vector<128xf32>
+//       CHECK-PROP: %[[ACC1:.*]] = "some_def"(%{{.*}}) : (index, vector<128xf32>, vector<128xf32>) -> vector<128xf32>
+//       CHECK-PROP: gpu.yield %[[ACC1]], %[[ACC0]] : vector<128xf32>, vector<128xf32>
+//       CHECK-PROP: }
+//       CHECK-PROP: scf.yield %[[W1]]#0, %[[W1]]#1 : vector<4xf32>, vector<4xf32>
+//       CHECK-PROP: }
+//       CHECK-PROP: "some_use"(%[[F]]#0) : (vector<4xf32>) -> ()
+func.func @warp_scf_for_unused_yield(%arg0: index) {
+  %c128 = arith.constant 128 : index
+  %c1 = arith.constant 1 : index
+  %c0 = arith.constant 0 : index
+  %0 = gpu.warp_execute_on_lane_0(%arg0)[32] -> (vector<4xf32>) {
+    %ini = "some_def"() : () -> (vector<128xf32>)
+    %ini1 = "some_def"() : () -> (vector<128xf32>)
+    %3:2 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %ini, %arg5 = %ini1) -> (vector<128xf32>, vector<128xf32>) {
+      %add = arith.addi %arg3, %c1 : index
+      %1  = "some_def"(%arg5, %add) : (vector<128xf32>, index) -> (vector<128xf32>)
+      %acc = "some_def"(%add, %arg4, %1) : (index, vector<128xf32>, vector<128xf32>) -> (vector<128xf32>)
+      scf.yield %acc, %1 : vector<128xf32>, vector<128xf32>
+    }
+    gpu.yield %3#0 : vector<128xf32>
+  }
+  "some_use"(%0) : (vector<4xf32>) -> ()
+  return
+}
+
+
 // -----
 
 // CHECK-PROP-LABEL: func @vector_reduction(

>From 3bdb5961d48bf70b63560820375d24e0682dbff8 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Wed, 28 May 2025 20:26:01 +0000
Subject: [PATCH 03/10] add comments

---
 mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 1649fb5f91b42..94435588459e6 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -1578,7 +1578,8 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
         });
 
     // Any forOp result that is not already yielded by the warpOp
-    // region is also considered escaping.
+    // region is also considered escaping and must be returned by the
+    // original warpOp.
     for (OpResult forResult : forOp.getResults()) {
       // Check if this forResult is already yielded by the yield op.
       if (llvm::is_contained(yield->getOperands(), forResult)) {

>From fe3ab99da99bfe47dd257a458d01ddd4e24df63e Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Wed, 28 May 2025 21:32:04 +0000
Subject: [PATCH 04/10] remove unused headers

---
 mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 94435588459e6..bd833ddb773f7 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -15,13 +15,10 @@
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
 #include "mlir/IR/AffineExpr.h"
-#include "mlir/IR/Value.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Transforms/RegionUtils.h"
-#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SetVector.h"
 #include "llvm/Support/FormatVariadic.h"
-#include "llvm/Support/raw_ostream.h"
 #include <utility>
 
 using namespace mlir;

>From f91b64c88ef893a9a7d620cd76345c21a4a46d33 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Mon, 2 Jun 2025 18:24:58 +0000
Subject: [PATCH 05/10] save work

---
 .../Transforms/XeGPUSubgroupDistribute.cpp    | 218 +++++++++++++-----
 1 file changed, 164 insertions(+), 54 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 992700524146a..d178c2c33245e 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -12,6 +12,8 @@
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/GPU/Utils/DistributionUtils.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Transforms/Patterns.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
@@ -30,6 +32,7 @@
 #include "mlir/IR/Value.h"
 #include "mlir/IR/Visitors.h"
 #include "mlir/Interfaces/FunctionInterfaces.h"
+#include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "mlir/Transforms/InliningUtils.h"
 #include "llvm/ADT/ArrayRef.h"
@@ -38,6 +41,7 @@
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/FormatVariadic.h"
 #include "llvm/Support/InterleavedRange.h"
+#include "llvm/Support/LogicalResult.h"
 #include "llvm/Support/raw_ostream.h"
 
 namespace mlir {
@@ -701,7 +705,47 @@ namespace {
 //===----------------------------------------------------------------------===//
 // LayoutAttrAssignment
 //===----------------------------------------------------------------------===//
+template <typename OpTy>
+class UpdateTensorDescType : public OpConversionPattern<OpTy> {
+public:
+  UpdateTensorDescType(MLIRContext *context,
+                       function_ref<xegpu::LayoutAttr(Value)> getLayoutOfValue,
+                       TypeConverter &typeConverter, PatternBenefit benefit = 1)
+      : OpConversionPattern<OpTy>(typeConverter, context, benefit),
+        getLayoutOfValue(getLayoutOfValue) {}
+  using OpConversionPattern<OpTy>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Op must have single result.
+    if (op->getNumResults() != 1)
+      return failure();
+    Type resultType = op->getResult(0).getType();
+    // Result type must be a tensor descriptor type.
+    if (!isa<xegpu::TensorDescType>(resultType)) {
+      LLVM_DEBUG(DBGS() << "Result type is not a tensor descriptor type: "
+                        << resultType << "\n");
+      return failure();
+    }
+    auto assignedLayout = getLayoutOfValue(op.getResult());
+    if (!assignedLayout) {
+      LLVM_DEBUG(DBGS() << "No layout assigned for " << *op << "\n");
+      return failure();
+    }
+    // Get the original tensor descriptor type.
+    auto origTensorDescTy = dyn_cast<xegpu::TensorDescType>(resultType);
+    auto newTensorDescTy = xegpu::TensorDescType::get(
+        origTensorDescTy.getContext(), origTensorDescTy.getShape(),
+        origTensorDescTy.getElementType(), origTensorDescTy.getEncoding(),
+        assignedLayout);
+    rewriter.replaceOpWithNewOp<OpTy>(op, newTensorDescTy,
+                                      adaptor.getOperands(), op->getAttrs());
+    return success();
+  }
 
+private:
+  function_ref<xegpu::LayoutAttr(Value)> getLayoutOfValue;
+};
 /// This class is responsible for assigning the layout attributes to the ops and
 /// their users based on the layout propagation analysis result.
 class LayoutAttrAssignment {
@@ -739,15 +783,19 @@ void LayoutAttrAssignment::assignToUsers(Value v, xegpu::LayoutAttr layout) {
 
 /// Convert the layout assigned to a value to xegpu::LayoutAttr.
 xegpu::LayoutAttr LayoutAttrAssignment::getLayoutAttrForValue(Value v) {
+  llvm::errs() << "getLayoutAttrForValue: " << v << "\n";
   LayoutInfo layout = getAnalysisResult(v);
-  if (!layout.isAssigned())
+  if (!layout.isAssigned()) {
+    llvm::errs() << "No layout assigned for value\n";
     return {};
+  }
   SmallVector<int, 2> laneLayout, laneData;
   for (auto [layout, data] : llvm::zip_equal(layout.getLayoutAsArrayRef(),
                                              layout.getDataAsArrayRef())) {
     laneLayout.push_back(static_cast<int>(layout));
     laneData.push_back(static_cast<int>(data));
   }
+  llvm::errs() << "return layout\n";
   return xegpu::LayoutAttr::get(v.getContext(), laneLayout, laneData);
 }
 
@@ -820,14 +868,23 @@ LogicalResult LayoutAttrAssignment::assign(Operation *op) {
 
 /// Walk the IR and attach xegpu::LayoutAttr to all ops and their users.
 LogicalResult LayoutAttrAssignment::run() {
-  auto walkResult = top->walk([&](Operation *op) {
-    if (failed(assign(op)))
-      return WalkResult::interrupt();
-    return WalkResult::advance();
-  });
-
-  if (walkResult.wasInterrupted())
-    return failure();
+  // auto walkResult = top->walk([&](Operation *op) {
+  //   if (failed(assign(op)))
+  //     return WalkResult::interrupt();
+  //   return WalkResult::advance();
+  // });
+
+  // if (walkResult.wasInterrupted())
+  //   return failure();
+  // apply the UpdateTensorDescType pattern to all ops
+  // RewritePatternSet patterns(top->getContext());
+  // patterns.add<UpdateTensorDescType>(
+  //     top->getContext(), [&](Value v) -> xegpu::LayoutAttr {
+  //       llvm::errs() << "invoking callback for value\n";
+  //       return getLayoutAttrForValue(v);
+  //     });
+  // if (failed(applyPatternsGreedily(top, std::move(patterns))))
+  //   return failure();
 
   return resolveConflicts();
 }
@@ -1597,56 +1654,109 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
     analyis.printAnalysisResult(os);
     return;
   }
-  auto getPropagatedLayout = [&](Value val) {
-    return analyis.getLayoutInfo(val);
+  // auto getPropagatedLayout = [&](Value val) {
+  //   return analyis.getLayoutInfo(val);
+  // };
+  auto getXeGpuLayoutForValue = [&](Value val) -> xegpu::LayoutAttr {
+    LayoutInfo layout = analyis.getLayoutInfo(val);
+    if (!layout.isAssigned()) {
+      llvm::errs() << "No layout assigned for value\n";
+      return {};
+    }
+    SmallVector<int, 2> laneLayout, laneData;
+    for (auto [layout, data] : llvm::zip_equal(layout.getLayoutAsArrayRef(),
+                                               layout.getDataAsArrayRef())) {
+      laneLayout.push_back(static_cast<int>(layout));
+      laneData.push_back(static_cast<int>(data));
+    }
+    return xegpu::LayoutAttr::get(val.getContext(), laneLayout, laneData);
+  };
+
+  ConversionTarget target(getContext());
+  target.addDynamicallyLegalOp<xegpu::CreateNdDescOp, xegpu::UpdateNdOffsetOp>(
+      [&](Operation *op) {
+        return llvm::all_of(op->getResults(), [&](Value val) {
+          if (auto descType = dyn_cast<xegpu::TensorDescType>(val.getType())) {
+            return descType.getLayoutAttr() != nullptr;
+          }
+          return true; // Non-tensor descriptor types are always legal.
+        });
+      });
+  target.addLegalOp<UnrealizedConversionCastOp>();
+  TypeConverter typeConverter;
+  typeConverter.addConversion([](Type type) { return type; });
+  // // typeConverter.addConversion([](xegpu::TensorDescType type) {
+  // //   return xegpu::TensorDescType::get(
+  // //       type.getContext(), type.getShape(), type.getElementType(),
+  // //       type.getEncoding(),
+  // //       xegpu::LayoutAttr::get(type.getContext(), {1, 1}, {1, 1}));
+  // // });
+  auto addUnrealizedCast = [](OpBuilder &builder, Type type, ValueRange inputs,
+                              Location loc) -> Value {
+    auto cast = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
+    return cast.getResult(0);
   };
 
+  typeConverter.addSourceMaterialization(addUnrealizedCast);
+  typeConverter.addTargetMaterialization(addUnrealizedCast);
+
+  RewritePatternSet patterns(&getContext());
+  patterns.add<UpdateTensorDescType<xegpu::CreateNdDescOp>,
+               UpdateTensorDescType<xegpu::UpdateNdOffsetOp>>(
+      &getContext(), getXeGpuLayoutForValue, typeConverter);
+  if (failed(
+          applyPartialConversion(getOperation(), target, std::move(patterns))))
+    signalPassFailure();
+
   // Assign xegpu::LayoutAttr to all ops and their users based on the layout
   // propagation analysis result.
-  LayoutAttrAssignment layoutAssignment(getOperation(), getPropagatedLayout);
-  if (failed(layoutAssignment.run())) {
-    signalPassFailure();
-    return;
-  }
+  // LayoutAttrAssignment layoutAssignment(getOperation(), getPropagatedLayout);
+  // if (failed(layoutAssignment.run())) {
+  //   signalPassFailure();
+  //   return;
+  // }
 
   // Move all operations of a GPU function inside gpu.warp_execute_on_lane_0
   // operation.
-  {
-    RewritePatternSet patterns(&getContext());
-    patterns.add<MoveFuncBodyToWarpExecuteOnLane0>(&getContext());
-
-    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
-      signalPassFailure();
-      return;
-    }
-    // At this point, we have moved the entire function body inside the warpOp.
-    // Now move any scalar uniform code outside of the warpOp (like GPU index
-    // ops, scalar constants, etc.). This will simplify the later lowering and
-    // avoid custom patterns for these ops.
-    getOperation()->walk([&](Operation *op) {
-      if (auto warpOp = dyn_cast<gpu::WarpExecuteOnLane0Op>(op)) {
-        vector::moveScalarUniformCode(warpOp);
-      }
-    });
-  }
-  // Finally, do the SIMD to SIMT distribution.
-  RewritePatternSet patterns(&getContext());
-  xegpu::populateXeGPUSubgroupDistributePatterns(patterns);
-  // TODO: distributionFn and shuffleFn are not used at this point.
-  auto distributionFn = [](Value val) {
-    VectorType vecType = dyn_cast<VectorType>(val.getType());
-    int64_t vecRank = vecType ? vecType.getRank() : 0;
-    OpBuilder builder(val.getContext());
-    if (vecRank == 0)
-      return AffineMap::get(val.getContext());
-    return AffineMap::getMultiDimIdentityMap(vecRank, val.getContext());
-  };
-  auto shuffleFn = [](Location loc, OpBuilder &builder, Value val, Value srcIdx,
-                      int64_t warpSz) { return Value(); };
-  vector::populatePropagateWarpVectorDistributionPatterns(
-      patterns, distributionFn, shuffleFn);
-  if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
-    signalPassFailure();
-    return;
-  }
+  // {
+  //   RewritePatternSet patterns(&getContext());
+  //   patterns.add<MoveFuncBodyToWarpExecuteOnLane0>(&getContext());
+
+  //   if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
+  //     signalPassFailure();
+  //     return;
+  //   }
+  //   // At this point, we have moved the entire function body inside the
+  //   warpOp.
+  //   // Now move any scalar uniform code outside of the warpOp (like GPU index
+  //   // ops, scalar constants, etc.). This will simplify the later lowering
+  //   and
+  //   // avoid custom patterns for these ops.
+  //   getOperation()->walk([&](Operation *op) {
+  //     if (auto warpOp = dyn_cast<gpu::WarpExecuteOnLane0Op>(op)) {
+  //       vector::moveScalarUniformCode(warpOp);
+  //     }
+  //   });
+  // }
+  // // Finally, do the SIMD to SIMT distribution.
+  // RewritePatternSet patterns(&getContext());
+  // xegpu::populateXeGPUSubgroupDistributePatterns(patterns);
+  // // TODO: distributionFn and shuffleFn are not used at this point.
+  // auto distributionFn = [](Value val) {
+  //   VectorType vecType = dyn_cast<VectorType>(val.getType());
+  //   int64_t vecRank = vecType ? vecType.getRank() : 0;
+  //   OpBuilder builder(val.getContext());
+  //   if (vecRank == 0)
+  //     return AffineMap::get(val.getContext());
+  //   return AffineMap::getMultiDimIdentityMap(vecRank, val.getContext());
+  // };
+  // auto shuffleFn = [](Location loc, OpBuilder &builder, Value val, Value
+  // srcIdx,
+  //                     int64_t warpSz) { return Value(); };
+  // vector::populatePropagateWarpVectorDistributionPatterns(
+  //     patterns, distributionFn, shuffleFn);
+  // if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
+  //   signalPassFailure();
+  //   return;
+  // }
 }

>From 5cacace6c3f56f3d84b2a63003c2f3d9947b195a Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Mon, 2 Jun 2025 22:57:27 +0000
Subject: [PATCH 06/10] initial version

---
 .../Transforms/XeGPUSubgroupDistribute.cpp    | 487 ++++++++++--------
 1 file changed, 267 insertions(+), 220 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index d178c2c33245e..aa982ae779d1e 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -32,6 +32,7 @@
 #include "mlir/IR/Value.h"
 #include "mlir/IR/Visitors.h"
 #include "mlir/Interfaces/FunctionInterfaces.h"
+#include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "mlir/Transforms/InliningUtils.h"
@@ -700,203 +701,264 @@ void RunLayoutInfoPropagation::printAnalysisResult(llvm::raw_ostream &os) {
   }
 }
 
-namespace {
+// namespace {
 
 //===----------------------------------------------------------------------===//
 // LayoutAttrAssignment
 //===----------------------------------------------------------------------===//
-template <typename OpTy>
-class UpdateTensorDescType : public OpConversionPattern<OpTy> {
-public:
-  UpdateTensorDescType(MLIRContext *context,
-                       function_ref<xegpu::LayoutAttr(Value)> getLayoutOfValue,
-                       TypeConverter &typeConverter, PatternBenefit benefit = 1)
-      : OpConversionPattern<OpTy>(typeConverter, context, benefit),
-        getLayoutOfValue(getLayoutOfValue) {}
-  using OpConversionPattern<OpTy>::OpConversionPattern;
-  LogicalResult
-  matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    // Op must have single result.
-    if (op->getNumResults() != 1)
-      return failure();
-    Type resultType = op->getResult(0).getType();
-    // Result type must be a tensor descriptor type.
-    if (!isa<xegpu::TensorDescType>(resultType)) {
-      LLVM_DEBUG(DBGS() << "Result type is not a tensor descriptor type: "
-                        << resultType << "\n");
-      return failure();
+// template <typename OpTy>
+// class UpdateTensorDescType : public OpConversionPattern<OpTy> {
+// public:
+//   UpdateTensorDescType(MLIRContext *context,
+//                        function_ref<xegpu::LayoutAttr(Value)>
+//                        getLayoutOfValue, TypeConverter &typeConverter,
+//                        PatternBenefit benefit = 1)
+//       : OpConversionPattern<OpTy>(typeConverter, context, benefit),
+//         getLayoutOfValue(getLayoutOfValue) {}
+//   using OpConversionPattern<OpTy>::OpConversionPattern;
+//   LogicalResult
+//   matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
+//                   ConversionPatternRewriter &rewriter) const override {
+//     // Op must have single result.
+//     if (op->getNumResults() != 1)
+//       return failure();
+//     Type resultType = op->getResult(0).getType();
+//     // Result type must be a tensor descriptor type.
+//     if (!isa<xegpu::TensorDescType>(resultType)) {
+//       LLVM_DEBUG(DBGS() << "Result type is not a tensor descriptor type: "
+//                         << resultType << "\n");
+//       return failure();
+//     }
+//     auto assignedLayout = getLayoutOfValue(op.getResult());
+//     if (!assignedLayout) {
+//       LLVM_DEBUG(DBGS() << "No layout assigned for " << *op << "\n");
+//       return failure();
+//     }
+//     // Get the original tensor descriptor type.
+//     auto origTensorDescTy = dyn_cast<xegpu::TensorDescType>(resultType);
+//     auto newTensorDescTy = xegpu::TensorDescType::get(
+//         origTensorDescTy.getContext(), origTensorDescTy.getShape(),
+//         origTensorDescTy.getElementType(), origTensorDescTy.getEncoding(),
+//         assignedLayout);
+//     rewriter.replaceOpWithNewOp<OpTy>(op, newTensorDescTy,
+//                                       adaptor.getOperands(), op->getAttrs());
+//     return success();
+//   }
+
+// private:
+//   function_ref<xegpu::LayoutAttr(Value)> getLayoutOfValue;
+// };
+// /// This class is responsible for assigning the layout attributes to the ops
+// and
+// /// their users based on the layout propagation analysis result.
+// class LayoutAttrAssignment {
+// public:
+//   LayoutAttrAssignment(Operation *top,
+//                        function_ref<LayoutInfo(Value)> getLayout)
+//       : getAnalysisResult(getLayout), top(top) {}
+
+//   LogicalResult run();
+
+// private:
+//   LogicalResult assign(Operation *op);
+//   void assignToUsers(Value v, xegpu::LayoutAttr layout);
+//   xegpu::LayoutAttr getLayoutAttrForValue(Value v);
+//   LogicalResult resolveConflicts();
+//   // Callable to get the layout of a value based on the layout propagation
+//   // analysis.
+//   function_ref<LayoutInfo(Value)> getAnalysisResult;
+//   Operation *top;
+// };
+
+// } // namespace
+
+// /// Helper to assign the layout attribute to the users of the value.
+// void LayoutAttrAssignment::assignToUsers(Value v, xegpu::LayoutAttr layout) {
+//   for (OpOperand &user : v.getUses()) {
+//     Operation *owner = user.getOwner();
+//     unsigned operandNumber = user.getOperandNumber();
+//     // Use a generic name for ease of querying the layout attribute later.
+//     std::string attrName =
+//         operandLayoutNamePrefix + std::to_string(operandNumber);
+//     owner->setAttr(attrName, layout);
+//   }
+// }
+
+// /// Convert the layout assigned to a value to xegpu::LayoutAttr.
+// xegpu::LayoutAttr LayoutAttrAssignment::getLayoutAttrForValue(Value v) {
+//   llvm::errs() << "getLayoutAttrForValue: " << v << "\n";
+//   LayoutInfo layout = getAnalysisResult(v);
+//   if (!layout.isAssigned()) {
+//     llvm::errs() << "No layout assigned for value\n";
+//     return {};
+//   }
+//   SmallVector<int, 2> laneLayout, laneData;
+//   for (auto [layout, data] : llvm::zip_equal(layout.getLayoutAsArrayRef(),
+//                                              layout.getDataAsArrayRef())) {
+//     laneLayout.push_back(static_cast<int>(layout));
+//     laneData.push_back(static_cast<int>(data));
+//   }
+//   llvm::errs() << "return layout\n";
+//   return xegpu::LayoutAttr::get(v.getContext(), laneLayout, laneData);
+// }
+
+// /// Assign xegpu::LayoutAttr to the op and its users. The layout is assigned
+// /// based on the layout propagation analysis result.
+// LogicalResult LayoutAttrAssignment::assign(Operation *op) {
+//   // For function ops, propagate the function argument layout to the users.
+//   if (auto func = dyn_cast<FunctionOpInterface>(op)) {
+//     for (BlockArgument arg : func.getArguments()) {
+//       xegpu::LayoutAttr layoutInfo = getLayoutAttrForValue(arg);
+//       if (layoutInfo) {
+//         assignToUsers(arg, layoutInfo);
+//       }
+//     }
+//     return success();
+//   }
+//   // If no results, move on.
+//   if (op->getNumResults() == 0)
+//     return success();
+//   // If all the results are scalars, move on.
+//   if (llvm::all_of(op->getResultTypes(),
+//                    [](Type t) { return t.isIntOrIndexOrFloat(); }))
+//     return success();
+//   // If the op has more than one result and at least one result is a tensor
+//   // descriptor, exit. This case is not supported yet.
+//   // TODO: Support this case.
+//   if (op->getNumResults() > 1 && llvm::any_of(op->getResultTypes(), [](Type
+//   t) {
+//         return isa<xegpu::TensorDescType>(t);
+//       })) {
+//     LLVM_DEBUG(
+//         DBGS() << op->getName()
+//                << " op has more than one result and at least one is a tensor
+//                "
+//                   "descriptor. This case is not handled.\n");
+//     return failure();
+//   }
+//   // If the result is a tensor descriptor, attach the layout to the tensor
+//   // descriptor itself.
+//   if (auto tensorDescTy =
+//           dyn_cast<xegpu::TensorDescType>(op->getResultTypes()[0])) {
+//     xegpu::LayoutAttr layoutInfo = getLayoutAttrForValue(op->getResult(0));
+//     if (!layoutInfo) {
+//       LLVM_DEBUG(DBGS() << "No layout for result of " << *op << "\n");
+//       return failure();
+//     }
+
+//     // Clone the op, attach the layout to the result tensor descriptor, and
+//     // remove the original op.
+//     OpBuilder builder(op);
+//     Operation *newOp = builder.clone(*op);
+//     auto newTensorDescTy = xegpu::TensorDescType::get(
+//         tensorDescTy.getContext(), tensorDescTy.getShape(),
+//         tensorDescTy.getElementType(), tensorDescTy.getEncoding(),
+//         layoutInfo);
+//     newOp->getResult(0).setType(newTensorDescTy);
+//     op->replaceAllUsesWith(newOp->getResults());
+//     op->erase();
+//     return success();
+//   }
+//   // Otherwise simply attach the layout to the op itself.
+//   for (auto [i, r] : llvm::enumerate(op->getResults())) {
+//     xegpu::LayoutAttr layoutInfo = getLayoutAttrForValue(r);
+//     if (layoutInfo) {
+//       std::string attrName = resultLayoutNamePrefix + std::to_string(i);
+//       op->setAttr(attrName, layoutInfo);
+//       // Attach the layout attribute to the users of the result.
+//       assignToUsers(r, layoutInfo);
+//     }
+//   }
+//   return success();
+// }
+
+// /// Walk the IR and attach xegpu::LayoutAttr to all ops and their users.
+// LogicalResult LayoutAttrAssignment::run() {
+//   // auto walkResult = top->walk([&](Operation *op) {
+//   //   if (failed(assign(op)))
+//   //     return WalkResult::interrupt();
+//   //   return WalkResult::advance();
+//   // });
+
+//   // if (walkResult.wasInterrupted())
+//   //   return failure();
+//   // apply the UpdateTensorDescType pattern to all ops
+//   // RewritePatternSet patterns(top->getContext());
+//   // patterns.add<UpdateTensorDescType>(
+//   //     top->getContext(), [&](Value v) -> xegpu::LayoutAttr {
+//   //       llvm::errs() << "invoking callback for value\n";
+//   //       return getLayoutAttrForValue(v);
+//   //     });
+//   // if (failed(applyPatternsGreedily(top, std::move(patterns))))
+//   //   return failure();
+
+//   return resolveConflicts();
+// }
+
+// /// TODO: Implement the layout conflict resolution. This must ensure mainly
+// two
+// /// things:
+// /// 1) Is a given layout supported by the op? (need to query the target
+// ///    HW info). Otherwise can we achieve this layout using a layout
+// conversion?
+// /// 2) Do all the operands have the required layout? If not, can it
+// ///    be resolved using a layout conversion?
+// LogicalResult LayoutAttrAssignment::resolveConflicts() { return success(); }
+using GetLayoutCallbackFnTy = function_ref<xegpu::LayoutAttr(Value)>;
+static void handleBranchTerminatorOpInterface(
+    mlir::OpBuilder &builder,
+    mlir::RegionBranchTerminatorOpInterface terminator,
+    GetLayoutCallbackFnTy getLayoutOfValue) {}
+static void handleBranchOpInterface(mlir::OpBuilder &builder,
+                                    mlir::RegionBranchOpInterface branch,
+                                    GetLayoutCallbackFnTy getLayoutOfValue) {}
+static void updateBlockTypes(mlir::OpBuilder &builder, mlir::Block &block,
+                             GetLayoutCallbackFnTy getLayoutOfValue) {}
+static void updateOp(mlir::OpBuilder &builder, mlir::Operation *op,
+                     GetLayoutCallbackFnTy getLayoutOfValue) {
+
+  auto updateValue = [&](Value v, unsigned vIndex,
+                         const std::string &layoutAttrName) {
+    // Layouts are needed only for vector and tensor descriptor types.
+    if (!isa<VectorType, xegpu::TensorDescType>(v.getType()))
+      return;
+    xegpu::LayoutAttr layout = getLayoutOfValue(v);
+    if (!layout) {
+      // TODO : handle error.
+      LLVM_DEBUG(DBGS() << "Expecting layout for value: " << v
+                        << " but got none.\n");
+      return;
     }
-    auto assignedLayout = getLayoutOfValue(op.getResult());
-    if (!assignedLayout) {
-      LLVM_DEBUG(DBGS() << "No layout assigned for " << *op << "\n");
-      return failure();
+    auto tensorDescTy = dyn_cast<xegpu::TensorDescType>(v.getType());
+
+    if (tensorDescTy) {
+      auto newTensorDescTy = xegpu::TensorDescType::get(
+          tensorDescTy.getContext(), tensorDescTy.getShape(),
+          tensorDescTy.getElementType(), tensorDescTy.getEncoding(), layout);
+      v.setType(newTensorDescTy);
+      return;
     }
-    // Get the original tensor descriptor type.
-    auto origTensorDescTy = dyn_cast<xegpu::TensorDescType>(resultType);
-    auto newTensorDescTy = xegpu::TensorDescType::get(
-        origTensorDescTy.getContext(), origTensorDescTy.getShape(),
-        origTensorDescTy.getElementType(), origTensorDescTy.getEncoding(),
-        assignedLayout);
-    rewriter.replaceOpWithNewOp<OpTy>(op, newTensorDescTy,
-                                      adaptor.getOperands(), op->getAttrs());
-    return success();
-  }
-
-private:
-  function_ref<xegpu::LayoutAttr(Value)> getLayoutOfValue;
-};
-/// This class is responsible for assigning the layout attributes to the ops and
-/// their users based on the layout propagation analysis result.
-class LayoutAttrAssignment {
-public:
-  LayoutAttrAssignment(Operation *top,
-                       function_ref<LayoutInfo(Value)> getLayout)
-      : getAnalysisResult(getLayout), top(top) {}
-
-  LogicalResult run();
-
-private:
-  LogicalResult assign(Operation *op);
-  void assignToUsers(Value v, xegpu::LayoutAttr layout);
-  xegpu::LayoutAttr getLayoutAttrForValue(Value v);
-  LogicalResult resolveConflicts();
-  // Callable to get the layout of a value based on the layout propagation
-  // analysis.
-  function_ref<LayoutInfo(Value)> getAnalysisResult;
-  Operation *top;
-};
-
-} // namespace
-
-/// Helper to assign the layout attribute to the users of the value.
-void LayoutAttrAssignment::assignToUsers(Value v, xegpu::LayoutAttr layout) {
-  for (OpOperand &user : v.getUses()) {
-    Operation *owner = user.getOwner();
-    unsigned operandNumber = user.getOperandNumber();
-    // Use a generic name for ease of querying the layout attribute later.
-    std::string attrName =
-        operandLayoutNamePrefix + std::to_string(operandNumber);
-    owner->setAttr(attrName, layout);
-  }
-}
-
-/// Convert the layout assigned to a value to xegpu::LayoutAttr.
-xegpu::LayoutAttr LayoutAttrAssignment::getLayoutAttrForValue(Value v) {
-  llvm::errs() << "getLayoutAttrForValue: " << v << "\n";
-  LayoutInfo layout = getAnalysisResult(v);
-  if (!layout.isAssigned()) {
-    llvm::errs() << "No layout assigned for value\n";
-    return {};
-  }
-  SmallVector<int, 2> laneLayout, laneData;
-  for (auto [layout, data] : llvm::zip_equal(layout.getLayoutAsArrayRef(),
-                                             layout.getDataAsArrayRef())) {
-    laneLayout.push_back(static_cast<int>(layout));
-    laneData.push_back(static_cast<int>(data));
-  }
-  llvm::errs() << "return layout\n";
-  return xegpu::LayoutAttr::get(v.getContext(), laneLayout, laneData);
-}
+    // If type is vector, add a temporary layout attribute to the op.
+    op->setAttr(layoutAttrName, layout);
+  };
 
-/// Assign xegpu::LayoutAttr to the op and its users. The layout is assigned
-/// based on the layout propagation analysis result.
-LogicalResult LayoutAttrAssignment::assign(Operation *op) {
-  // For function ops, propagate the function argument layout to the users.
-  if (auto func = dyn_cast<FunctionOpInterface>(op)) {
-    for (BlockArgument arg : func.getArguments()) {
-      xegpu::LayoutAttr layoutInfo = getLayoutAttrForValue(arg);
-      if (layoutInfo) {
-        assignToUsers(arg, layoutInfo);
-      }
-    }
-    return success();
-  }
-  // If no results, move on.
-  if (op->getNumResults() == 0)
-    return success();
-  // If all the results are scalars, move on.
-  if (llvm::all_of(op->getResultTypes(),
-                   [](Type t) { return t.isIntOrIndexOrFloat(); }))
-    return success();
-  // If the op has more than one result and at least one result is a tensor
-  // descriptor, exit. This case is not supported yet.
-  // TODO: Support this case.
-  if (op->getNumResults() > 1 && llvm::any_of(op->getResultTypes(), [](Type t) {
-        return isa<xegpu::TensorDescType>(t);
-      })) {
-    LLVM_DEBUG(
-        DBGS() << op->getName()
-               << " op has more than one result and at least one is a tensor "
-                  "descriptor. This case is not handled.\n");
-    return failure();
+  // Iterate over all the operands.
+  for (OpOperand &operand : op->getOpOperands()) {
+    unsigned operandIndex = operand.getOperandNumber();
+    std::string operandLayoutName =
+        operandLayoutNamePrefix + std::to_string(operandIndex);
+    updateValue(operand.get(), operandIndex, operandLayoutName);
   }
-  // If the result is a tensor descriptor, attach the layout to the tensor
-  // descriptor itself.
-  if (auto tensorDescTy =
-          dyn_cast<xegpu::TensorDescType>(op->getResultTypes()[0])) {
-    xegpu::LayoutAttr layoutInfo = getLayoutAttrForValue(op->getResult(0));
-    if (!layoutInfo) {
-      LLVM_DEBUG(DBGS() << "No layout for result of " << *op << "\n");
-      return failure();
-    }
 
-    // Clone the op, attach the layout to the result tensor descriptor, and
-    // remove the original op.
-    OpBuilder builder(op);
-    Operation *newOp = builder.clone(*op);
-    auto newTensorDescTy = xegpu::TensorDescType::get(
-        tensorDescTy.getContext(), tensorDescTy.getShape(),
-        tensorDescTy.getElementType(), tensorDescTy.getEncoding(), layoutInfo);
-    newOp->getResult(0).setType(newTensorDescTy);
-    op->replaceAllUsesWith(newOp->getResults());
-    op->erase();
-    return success();
+  // Iterate over all the results.
+  for (OpResult result : op->getResults()) {
+    unsigned resultIndex = result.getResultNumber();
+    std::string resultLayoutName =
+        resultLayoutNamePrefix + std::to_string(resultIndex);
+    updateValue(result, resultIndex, resultLayoutName);
   }
-  // Otherwise simply attach the layout to the op itself.
-  for (auto [i, r] : llvm::enumerate(op->getResults())) {
-    xegpu::LayoutAttr layoutInfo = getLayoutAttrForValue(r);
-    if (layoutInfo) {
-      std::string attrName = resultLayoutNamePrefix + std::to_string(i);
-      op->setAttr(attrName, layoutInfo);
-      // Attach the layout attribute to the users of the result.
-      assignToUsers(r, layoutInfo);
-    }
-  }
-  return success();
 }
 
-/// Walk the IR and attach xegpu::LayoutAttr to all ops and their users.
-LogicalResult LayoutAttrAssignment::run() {
-  // auto walkResult = top->walk([&](Operation *op) {
-  //   if (failed(assign(op)))
-  //     return WalkResult::interrupt();
-  //   return WalkResult::advance();
-  // });
-
-  // if (walkResult.wasInterrupted())
-  //   return failure();
-  // apply the UpdateTensorDescType pattern to all ops
-  // RewritePatternSet patterns(top->getContext());
-  // patterns.add<UpdateTensorDescType>(
-  //     top->getContext(), [&](Value v) -> xegpu::LayoutAttr {
-  //       llvm::errs() << "invoking callback for value\n";
-  //       return getLayoutAttrForValue(v);
-  //     });
-  // if (failed(applyPatternsGreedily(top, std::move(patterns))))
-  //   return failure();
-
-  return resolveConflicts();
-}
-
-/// TODO: Implement the layout conflict resolution. This must ensure mainly two
-/// things:
-/// 1) Is a given layout supported by the op? (need to query the target
-///    HW info). Otherwise can we achieve this layout using a layout conversion?
-/// 2) Do all the operands have the required layout? If not, can it
-///    be resolved using a layout conversion?
-LogicalResult LayoutAttrAssignment::resolveConflicts() { return success(); }
-
 namespace {
 
 //===----------------------------------------------------------------------===//
@@ -1657,10 +1719,10 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
   // auto getPropagatedLayout = [&](Value val) {
   //   return analyis.getLayoutInfo(val);
   // };
-  auto getXeGpuLayoutForValue = [&](Value val) -> xegpu::LayoutAttr {
+  auto getXeGPULayoutForValue = [&](Value val) -> xegpu::LayoutAttr {
     LayoutInfo layout = analyis.getLayoutInfo(val);
     if (!layout.isAssigned()) {
-      llvm::errs() << "No layout assigned for value\n";
+      llvm::errs() << "No layout assigned for value" << val << "\n";
       return {};
     }
     SmallVector<int, 2> laneLayout, laneData;
@@ -1672,41 +1734,26 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
     return xegpu::LayoutAttr::get(val.getContext(), laneLayout, laneData);
   };
 
-  ConversionTarget target(getContext());
-  target.addDynamicallyLegalOp<xegpu::CreateNdDescOp, xegpu::UpdateNdOffsetOp>(
-      [&](Operation *op) {
-        return llvm::all_of(op->getResults(), [&](Value val) {
-          if (auto descType = dyn_cast<xegpu::TensorDescType>(val.getType())) {
-            return descType.getLayoutAttr() != nullptr;
-          }
-          return true; // Non-tensor descriptor types are always legal.
-        });
-      });
-  target.addLegalOp<UnrealizedConversionCastOp>();
-  TypeConverter typeConverter;
-  typeConverter.addConversion([](Type type) { return type; });
-  // // typeConverter.addConversion([](xegpu::TensorDescType type) {
-  // //   return xegpu::TensorDescType::get(
-  // //       type.getContext(), type.getShape(), type.getElementType(),
-  // //       type.getEncoding(),
-  // //       xegpu::LayoutAttr::get(type.getContext(), {1, 1}, {1, 1}));
-  // // });
-  auto addUnrealizedCast = [](OpBuilder &builder, Type type, ValueRange inputs,
-                              Location loc) -> Value {
-    auto cast = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
-    return cast.getResult(0);
-  };
+  mlir::OpBuilder builder(&getContext());
+  Operation *op = getOperation();
+  op->walk([&](mlir::Block *block) {
+    for (mlir::Operation &op : llvm::reverse(block->getOperations())) {
+      if (auto terminator =
+              mlir::dyn_cast<mlir::RegionBranchTerminatorOpInterface>(op)) {
+        handleBranchTerminatorOpInterface(builder, terminator,
+                                          getXeGPULayoutForValue);
+        continue;
+      }
 
-  typeConverter.addSourceMaterialization(addUnrealizedCast);
-  typeConverter.addTargetMaterialization(addUnrealizedCast);
+      if (auto iface = mlir::dyn_cast<mlir::RegionBranchOpInterface>(op)) {
+        handleBranchOpInterface(builder, iface, getXeGPULayoutForValue);
+        continue;
+      }
+      updateOp(builder, &op, getXeGPULayoutForValue);
+    }
 
-  RewritePatternSet patterns(&getContext());
-  patterns.add<UpdateTensorDescType<xegpu::CreateNdDescOp>,
-               UpdateTensorDescType<xegpu::UpdateNdOffsetOp>>(
-      &getContext(), getXeGpuLayoutForValue, typeConverter);
-  if (failed(
-          applyPartialConversion(getOperation(), target, std::move(patterns))))
-    signalPassFailure();
+    updateBlockTypes(builder, *block, getXeGPULayoutForValue);
+  });
 
   // Assign xegpu::LayoutAttr to all ops and their users based on the layout
   // propagation analysis result.

>From 7d54194f0c726db4461015de87abf9ad380bbfa3 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Tue, 3 Jun 2025 19:50:30 +0000
Subject: [PATCH 07/10] working version

---
 .../Transforms/XeGPUSubgroupDistribute.cpp    | 159 +++++++++++++-----
 1 file changed, 120 insertions(+), 39 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index aa982ae779d1e..6b3ff8312e365 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -40,6 +40,7 @@
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/Debug.h"
 #include "llvm/Support/FormatVariadic.h"
 #include "llvm/Support/InterleavedRange.h"
 #include "llvm/Support/LogicalResult.h"
@@ -905,59 +906,140 @@ void RunLayoutInfoPropagation::printAnalysisResult(llvm::raw_ostream &os) {
 // ///    be resolved using a layout conversion?
 // LogicalResult LayoutAttrAssignment::resolveConflicts() { return success(); }
 using GetLayoutCallbackFnTy = function_ref<xegpu::LayoutAttr(Value)>;
+static void updateOp(mlir::OpBuilder &builder, mlir::Operation *op,
+                     GetLayoutCallbackFnTy getLayoutOfValue) {
+
+  // Iterate over all the results.
+  for (OpResult result : op->getResults()) {
+    Type resultType = result.getType();
+    // Layouts are needed only for vector and tensor descriptor types.
+    if (!isa<VectorType, xegpu::TensorDescType>(resultType))
+      continue;
+    // If the result has any users, we expect it to have a layout.
+    xegpu::LayoutAttr layout = getLayoutOfValue(result);
+    if (!layout && result.getNumUses() > 0) {
+      LLVM_DEBUG(DBGS() << "Expecting layout for result: " << result
+                        << " but got none.\n");
+      continue;
+    }
+    if (auto tensorDescTy = dyn_cast<xegpu::TensorDescType>(resultType)) {
+      // TODO: Handle error.
+      auto typeWithLayout = xegpu::TensorDescType::get(
+          tensorDescTy.getContext(), tensorDescTy.getShape(),
+          tensorDescTy.getElementType(), tensorDescTy.getEncoding(), layout);
+      result.setType(typeWithLayout);
+      continue;
+    }
+    // If the result is a vector type, add a temporary layout attribute to the
+    // op.
+    std::string resultLayoutName =
+        resultLayoutNamePrefix + std::to_string(result.getResultNumber());
+    op->setAttr(resultLayoutName, layout);
+    // Update all users of the result with the layout.
+    for (OpOperand &user : result.getUses()) {
+      Operation *owner = user.getOwner();
+      unsigned operandNumber = user.getOperandNumber();
+      // Add temporary layout attribute at the user op.
+      std::string attrName =
+          operandLayoutNamePrefix + std::to_string(operandNumber);
+      owner->setAttr(attrName, layout);
+    }
+  }
+}
 static void handleBranchTerminatorOpInterface(
     mlir::OpBuilder &builder,
     mlir::RegionBranchTerminatorOpInterface terminator,
     GetLayoutCallbackFnTy getLayoutOfValue) {}
 static void handleBranchOpInterface(mlir::OpBuilder &builder,
                                     mlir::RegionBranchOpInterface branch,
-                                    GetLayoutCallbackFnTy getLayoutOfValue) {}
-static void updateBlockTypes(mlir::OpBuilder &builder, mlir::Block &block,
-                             GetLayoutCallbackFnTy getLayoutOfValue) {}
-static void updateOp(mlir::OpBuilder &builder, mlir::Operation *op,
-                     GetLayoutCallbackFnTy getLayoutOfValue) {
+                                    GetLayoutCallbackFnTy getLayoutOfValue) {
+  mlir::Operation *op = branch.getOperation();
+  llvm::SmallVector<mlir::RegionSuccessor> successors;
+  llvm::SmallVector<mlir::Attribute> operands(op->getNumOperands(), nullptr);
+  branch.getEntrySuccessorRegions(operands, successors);
+  DenseMap<Value, xegpu::LayoutAttr> resultToLayouts;
+  mlir::ValueRange results = op->getResults();
+
+  for (mlir::RegionSuccessor &successor : successors) {
+    if (successor.isParent())
+      continue;
 
-  auto updateValue = [&](Value v, unsigned vIndex,
-                         const std::string &layoutAttrName) {
-    // Layouts are needed only for vector and tensor descriptor types.
-    if (!isa<VectorType, xegpu::TensorDescType>(v.getType()))
-      return;
-    xegpu::LayoutAttr layout = getLayoutOfValue(v);
+    mlir::OperandRange initArgs = branch.getEntrySuccessorOperands(successor);
+    mlir::ValueRange blockArgs = successor.getSuccessorInputs();
+    unsigned index = 0;
+
+    for (auto [initArg, blockArg, result] :
+         llvm::zip(initArgs, blockArgs, results)) {
+      Type inputType = blockArg.getType();
+      if (!isa<xegpu::TensorDescType>(inputType))
+        continue;
+      xegpu::LayoutAttr blockArgLayout = getLayoutOfValue(blockArg);
+      xegpu::LayoutAttr initArgLayout = getLayoutOfValue(initArg);
+
+      if (!blockArgLayout || !initArgLayout) {
+        LLVM_DEBUG(DBGS() << "No layout assigned for block arg: " << blockArg
+                          << " or init arg: " << initArg << "\n");
+        continue;
+      }
+
+      // TODO: We expect these two to match. Data flow analysis will ensure
+      // this.
+      assert(blockArgLayout == initArgLayout &&
+             "Expexing block arg and init arg to have the same layout.");
+      // Get tensor descriptor type with the layout.
+      auto tdescTy = dyn_cast<xegpu::TensorDescType>(inputType);
+      auto newTdescTy = xegpu::TensorDescType::get(
+          tdescTy.getContext(), tdescTy.getShape(), tdescTy.getElementType(),
+          tdescTy.getEncoding(), blockArgLayout);
+      blockArg.setType(newTdescTy);
+      // Store the layout for the result.
+      if (resultToLayouts.count(result) != 0 &&
+          resultToLayouts[result] != blockArgLayout) {
+        LLVM_DEBUG(DBGS() << "Conflicting layouts for result: " << result
+                          << " - " << resultToLayouts[result] << " vs "
+                          << blockArgLayout << "\n");
+      } else {
+        resultToLayouts[result] = blockArgLayout;
+      }
+    }
+  }
+  for (auto [i, r] : llvm::enumerate(op->getResults())) {
+    Type resultType = r.getType();
+    if (!isa<xegpu::TensorDescType, VectorType>(resultType))
+      continue;
+    xegpu::LayoutAttr layout = getLayoutOfValue(r);
+    if (!layout)
+      layout = resultToLayouts[r];
     if (!layout) {
-      // TODO : handle error.
-      LLVM_DEBUG(DBGS() << "Expecting layout for value: " << v
-                        << " but got none.\n");
-      return;
+      LLVM_DEBUG(DBGS() << "No layout assigned for vector/tensor desc result: "
+                        << r << "\n");
+      continue;
     }
-    auto tensorDescTy = dyn_cast<xegpu::TensorDescType>(v.getType());
-
-    if (tensorDescTy) {
-      auto newTensorDescTy = xegpu::TensorDescType::get(
+    if (auto tensorDescTy = dyn_cast<xegpu::TensorDescType>(resultType)) {
+      auto newTdescTy = xegpu::TensorDescType::get(
           tensorDescTy.getContext(), tensorDescTy.getShape(),
           tensorDescTy.getElementType(), tensorDescTy.getEncoding(), layout);
-      v.setType(newTensorDescTy);
-      return;
+      r.setType(newTdescTy);
+      continue;
     }
-    // If type is vector, add a temporary layout attribute to the op.
-    op->setAttr(layoutAttrName, layout);
-  };
-
-  // Iterate over all the operands.
-  for (OpOperand &operand : op->getOpOperands()) {
-    unsigned operandIndex = operand.getOperandNumber();
-    std::string operandLayoutName =
-        operandLayoutNamePrefix + std::to_string(operandIndex);
-    updateValue(operand.get(), operandIndex, operandLayoutName);
-  }
-
-  // Iterate over all the results.
-  for (OpResult result : op->getResults()) {
-    unsigned resultIndex = result.getResultNumber();
+    // If the result is a vector type, add a temporary layout attribute to the
+    // op.
     std::string resultLayoutName =
-        resultLayoutNamePrefix + std::to_string(resultIndex);
-    updateValue(result, resultIndex, resultLayoutName);
+        resultLayoutNamePrefix + std::to_string(r.getResultNumber());
+    op->setAttr(resultLayoutName, layout);
+    // Update all users of the result with the layout.
+    for (OpOperand &user : r.getUses()) {
+      Operation *owner = user.getOwner();
+      unsigned operandNumber = user.getOperandNumber();
+      // Add temporary layout attribute at the user op.
+      std::string attrName =
+          operandLayoutNamePrefix + std::to_string(operandNumber);
+      owner->setAttr(attrName, layout);
+    }
   }
 }
+static void updateBlockTypes(mlir::OpBuilder &builder, mlir::Block &block,
+                             GetLayoutCallbackFnTy getLayoutOfValue) {}
 
 namespace {
 
@@ -1722,7 +1804,6 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
   auto getXeGPULayoutForValue = [&](Value val) -> xegpu::LayoutAttr {
     LayoutInfo layout = analyis.getLayoutInfo(val);
     if (!layout.isAssigned()) {
-      llvm::errs() << "No layout assigned for value" << val << "\n";
       return {};
     }
     SmallVector<int, 2> laneLayout, laneData;

>From b289399e44bf56e91149cbfc37a729c14949c4d2 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Tue, 3 Jun 2025 21:36:11 +0000
Subject: [PATCH 08/10] working expect for unreal cast

---
 .../Transforms/XeGPUSubgroupDistribute.cpp    | 97 +++++++++----------
 1 file changed, 46 insertions(+), 51 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 6b3ff8312e365..dfb7b0668d2be 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -1291,11 +1291,14 @@ struct CreateNdDescDistribution final : public gpu::WarpDistributionPattern {
     xegpu::TensorDescType distributedTensorDescTy =
         descOp.getType().dropLayouts(); // Distributed tensor descriptor type
                                         // does not contain layout info.
-    auto newDescOp = rewriter.create<xegpu::CreateNdDescOp>(
+    Value newDescOp = rewriter.create<xegpu::CreateNdDescOp>(
         newWarpOp.getLoc(), distributedTensorDescTy, newDescOperands,
         descOp->getAttrs());
 
     Value distributedVal = newWarpOp.getResult(operandIdx);
+    // Resolve the distributed type to the expected type.
+    newDescOp =
+        resolveDistributedTy(newDescOp, distributedVal.getType(), rewriter);
     rewriter.replaceAllUsesWith(distributedVal, newDescOp);
     return success();
   }
@@ -1697,10 +1700,13 @@ struct UpdateNdOffsetDistribution final : public gpu::WarpDistributionPattern {
       }
     }
     // Create a new update op outside the warp op.
-    auto newUpdateOp = rewriter.create<xegpu::UpdateNdOffsetOp>(
+    Value newUpdateOp = rewriter.create<xegpu::UpdateNdOffsetOp>(
         newWarpOp.getLoc(), newTensorDescTy, newUpdateOperands,
         removeTemporaryLayoutAttributes(updateOp->getAttrs()));
     Value distributedVal = newWarpOp.getResult(operandIdx);
+    // Resolve the distributed type with the original type.
+    newUpdateOp =
+        resolveDistributedTy(newUpdateOp, distributedVal.getType(), rewriter);
     rewriter.replaceAllUsesWith(distributedVal, newUpdateOp);
     return success();
   }
@@ -1836,55 +1842,44 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
     updateBlockTypes(builder, *block, getXeGPULayoutForValue);
   });
 
-  // Assign xegpu::LayoutAttr to all ops and their users based on the layout
-  // propagation analysis result.
-  // LayoutAttrAssignment layoutAssignment(getOperation(), getPropagatedLayout);
-  // if (failed(layoutAssignment.run())) {
-  //   signalPassFailure();
-  //   return;
-  // }
-
   // Move all operations of a GPU function inside gpu.warp_execute_on_lane_0
   // operation.
-  // {
-  //   RewritePatternSet patterns(&getContext());
-  //   patterns.add<MoveFuncBodyToWarpExecuteOnLane0>(&getContext());
-
-  //   if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
-  //     signalPassFailure();
-  //     return;
-  //   }
-  //   // At this point, we have moved the entire function body inside the
-  //   warpOp.
-  //   // Now move any scalar uniform code outside of the warpOp (like GPU index
-  //   // ops, scalar constants, etc.). This will simplify the later lowering
-  //   and
-  //   // avoid custom patterns for these ops.
-  //   getOperation()->walk([&](Operation *op) {
-  //     if (auto warpOp = dyn_cast<gpu::WarpExecuteOnLane0Op>(op)) {
-  //       vector::moveScalarUniformCode(warpOp);
-  //     }
-  //   });
-  // }
-  // // Finally, do the SIMD to SIMT distribution.
-  // RewritePatternSet patterns(&getContext());
-  // xegpu::populateXeGPUSubgroupDistributePatterns(patterns);
-  // // TODO: distributionFn and shuffleFn are not used at this point.
-  // auto distributionFn = [](Value val) {
-  //   VectorType vecType = dyn_cast<VectorType>(val.getType());
-  //   int64_t vecRank = vecType ? vecType.getRank() : 0;
-  //   OpBuilder builder(val.getContext());
-  //   if (vecRank == 0)
-  //     return AffineMap::get(val.getContext());
-  //   return AffineMap::getMultiDimIdentityMap(vecRank, val.getContext());
-  // };
-  // auto shuffleFn = [](Location loc, OpBuilder &builder, Value val, Value
-  // srcIdx,
-  //                     int64_t warpSz) { return Value(); };
-  // vector::populatePropagateWarpVectorDistributionPatterns(
-  //     patterns, distributionFn, shuffleFn);
-  // if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
-  //   signalPassFailure();
-  //   return;
-  // }
+  {
+    RewritePatternSet patterns(&getContext());
+    patterns.add<MoveFuncBodyToWarpExecuteOnLane0>(&getContext());
+
+    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
+      signalPassFailure();
+      return;
+    }
+    // At this point, we have moved the entire function body inside the
+    // warpOp. Now move any scalar uniform code outside of the warpOp (like GPU
+    // index ops, scalar constants, etc.). This will simplify the later lowering
+    // and avoid custom patterns for these ops.
+    getOperation()->walk([&](Operation *op) {
+      if (auto warpOp = dyn_cast<gpu::WarpExecuteOnLane0Op>(op)) {
+        vector::moveScalarUniformCode(warpOp);
+      }
+    });
+  }
+  // Finally, do the SIMD to SIMT distribution.
+  RewritePatternSet patterns(&getContext());
+  xegpu::populateXeGPUSubgroupDistributePatterns(patterns);
+  // TODO: distributionFn and shuffleFn are not used at this point.
+  auto distributionFn = [](Value val) {
+    VectorType vecType = dyn_cast<VectorType>(val.getType());
+    int64_t vecRank = vecType ? vecType.getRank() : 0;
+    OpBuilder builder(val.getContext());
+    if (vecRank == 0)
+      return AffineMap::get(val.getContext());
+    return AffineMap::getMultiDimIdentityMap(vecRank, val.getContext());
+  };
+  auto shuffleFn = [](Location loc, OpBuilder &builder, Value val, Value srcIdx,
+                      int64_t warpSz) { return Value(); };
+  vector::populatePropagateWarpVectorDistributionPatterns(
+      patterns, distributionFn, shuffleFn);
+  if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
+    signalPassFailure();
+    return;
+  }
 }

>From 4318343ead59cda8f70741ca45e9255a6ce66bba Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Tue, 3 Jun 2025 22:57:01 +0000
Subject: [PATCH 09/10] some fixes

---
 .../Transforms/XeGPUSubgroupDistribute.cpp    | 70 ++++++++++++++++---
 1 file changed, 60 insertions(+), 10 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index dfb7b0668d2be..56ec1eaa118e5 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -68,8 +68,14 @@ constexpr unsigned packedSizeInBitsForDefault =
     16; // Minimum packing size per register for DPAS A.
 constexpr unsigned packedSizeInBitsForDpasB =
     32; // Minimum packing size per register for DPAS B.
-static const char *const operandLayoutNamePrefix = "layout_operand_";
-static const char *const resultLayoutNamePrefix = "layout_result_";
+static const char *const operandLayoutNamePrefix =
+    "layout_operand_"; // Attribute name for identifying operand layouts.
+static const char *const resultLayoutNamePrefix =
+    "layout_result_"; // Attribute name for identifying result layouts.
+static const char *const resolveSIMTTypeMismatch =
+    "resolve_simt_type_mismatch"; // Attribute name for identifying
+                                  // UnrealizedConversionCastOp added to resolve
+                                  // SIMT type mismatches.
 
 namespace {
 
@@ -946,11 +952,11 @@ static void updateOp(mlir::OpBuilder &builder, mlir::Operation *op,
     }
   }
 }
-static void handleBranchTerminatorOpInterface(
+static void updateBranchTerminatorOpInterface(
     mlir::OpBuilder &builder,
     mlir::RegionBranchTerminatorOpInterface terminator,
     GetLayoutCallbackFnTy getLayoutOfValue) {}
-static void handleBranchOpInterface(mlir::OpBuilder &builder,
+static void updateBranchOpInterface(mlir::OpBuilder &builder,
                                     mlir::RegionBranchOpInterface branch,
                                     GetLayoutCallbackFnTy getLayoutOfValue) {
   mlir::Operation *op = branch.getOperation();
@@ -966,7 +972,6 @@ static void handleBranchOpInterface(mlir::OpBuilder &builder,
 
     mlir::OperandRange initArgs = branch.getEntrySuccessorOperands(successor);
     mlir::ValueRange blockArgs = successor.getSuccessorInputs();
-    unsigned index = 0;
 
     for (auto [initArg, blockArg, result] :
          llvm::zip(initArgs, blockArgs, results)) {
@@ -1117,6 +1122,7 @@ static Value resolveDistributedTy(Value orig, T expected,
   if (isa<xegpu::TensorDescType>(orig.getType())) {
     auto castOp = rewriter.create<UnrealizedConversionCastOp>(orig.getLoc(),
                                                               expected, orig);
+    castOp->setAttr(resolveSIMTTypeMismatch, rewriter.getUnitAttr());
     return castOp.getResult(0);
   }
   llvm_unreachable("Unsupported type for reconciliation");
@@ -1804,9 +1810,7 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
     analyis.printAnalysisResult(os);
     return;
   }
-  // auto getPropagatedLayout = [&](Value val) {
-  //   return analyis.getLayoutInfo(val);
-  // };
+
   auto getXeGPULayoutForValue = [&](Value val) -> xegpu::LayoutAttr {
     LayoutInfo layout = analyis.getLayoutInfo(val);
     if (!layout.isAssigned()) {
@@ -1827,13 +1831,13 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
     for (mlir::Operation &op : llvm::reverse(block->getOperations())) {
       if (auto terminator =
               mlir::dyn_cast<mlir::RegionBranchTerminatorOpInterface>(op)) {
-        handleBranchTerminatorOpInterface(builder, terminator,
+        updateBranchTerminatorOpInterface(builder, terminator,
                                           getXeGPULayoutForValue);
         continue;
       }
 
       if (auto iface = mlir::dyn_cast<mlir::RegionBranchOpInterface>(op)) {
-        handleBranchOpInterface(builder, iface, getXeGPULayoutForValue);
+        updateBranchOpInterface(builder, iface, getXeGPULayoutForValue);
         continue;
       }
       updateOp(builder, &op, getXeGPULayoutForValue);
@@ -1882,4 +1886,50 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
     signalPassFailure();
     return;
   }
+
+  // Clean up UnrealizedConversionCastOps that were inserted due to tensor desc
+  // type mismatches created by using upstream distribution patterns (scf.for)
+  getOperation()->walk([&](mlir::UnrealizedConversionCastOp op) {
+    // We are only interested in UnrealizedConversionCastOps there were added
+    // for resolving SIMT type mismatches.
+    if (!op->getAttr(resolveSIMTTypeMismatch))
+      return WalkResult::skip();
+
+    Value input = op.getOperand(0);
+    Value output = op.getResult(0);
+
+    // Both input and output must have tensor descriptor types.
+    xegpu::TensorDescType inputDescType =
+        mlir::dyn_cast<xegpu::TensorDescType>(input.getType());
+    xegpu::TensorDescType outputDescType =
+        mlir::dyn_cast<xegpu::TensorDescType>(output.getType());
+    assert(inputDescType && outputDescType &&
+           "Unrealized conversion cast must have tensor descriptor types");
+
+    // tensor_desc<shape, layout> -> tensor_desc<shape> Type of conversions.
+    // This occurs iside scf.for body to resolve the block argument type to SIMT
+    // type.
+    if (inputDescType.getLayout()) {
+      auto argument = mlir::dyn_cast<mlir::BlockArgument>(input);
+      if (argument) {
+        argument.setType(output.getType());
+        output.replaceAllUsesWith(argument);
+        if (auto loopOp = mlir::dyn_cast<mlir::LoopLikeOpInterface>(
+                argument.getOwner()->getParentOp())) {
+          auto result = loopOp.getTiedLoopResult(argument);
+          result.setType(output.getType());
+        }
+      }
+    }
+
+    // tensor_desc<shape> -> tensor_desc<shape, layout> Type of
+    // conversions. This occurs at the yield op of scf.for body to go back from
+    // SIMT type to original type.
+    if (outputDescType.getLayout())
+      output.replaceAllUsesWith(input);
+
+    if (op->use_empty())
+      op->erase();
+    return WalkResult::advance();
+  });
 }

>From 20a641545534132b59c934d7bc31b6c088134605 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Wed, 4 Jun 2025 00:01:15 +0000
Subject: [PATCH 10/10] branch terminator iface

---
 .../Transforms/XeGPUSubgroupDistribute.cpp    | 332 ++++++++++--------
 1 file changed, 193 insertions(+), 139 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 56ec1eaa118e5..27d912b87c6dc 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -955,7 +955,54 @@ static void updateOp(mlir::OpBuilder &builder, mlir::Operation *op,
 static void updateBranchTerminatorOpInterface(
     mlir::OpBuilder &builder,
     mlir::RegionBranchTerminatorOpInterface terminator,
-    GetLayoutCallbackFnTy getLayoutOfValue) {}
+    GetLayoutCallbackFnTy getLayoutOfValue) {
+  if (!mlir::isa<mlir::RegionBranchOpInterface>(terminator->getParentOp()))
+    return;
+
+  llvm::SmallVector<mlir::RegionSuccessor> successors;
+  llvm::SmallVector<mlir::Attribute> operands(terminator->getNumOperands(),
+                                              nullptr);
+  terminator.getSuccessorRegions(operands, successors);
+
+  for (mlir::RegionSuccessor &successor : successors) {
+    if (!successor.isParent())
+      continue;
+
+    mlir::OperandRange operands = terminator.getSuccessorOperands(successor);
+    mlir::ValueRange inputs = successor.getSuccessorInputs();
+    for (auto [operand, input] : llvm::zip(operands, inputs)) {
+      // print arg and inp
+      // llvm::errs() << "arg: " << operand << ", inp: " << input << "\n";
+      Type inputType = input.getType();
+      if (!isa<xegpu::TensorDescType>(inputType))
+        continue;
+      xegpu::LayoutAttr inputLayout = getLayoutOfValue(input);
+      xegpu::LayoutAttr operandLayout = getLayoutOfValue(operand);
+
+      if (!operandLayout) {
+        LLVM_DEBUG(DBGS() << "Expecting layout for region successor operand : "
+                          << operand << " but got none.\n");
+        continue;
+      }
+
+      if (inputLayout && inputLayout != operandLayout) {
+        LLVM_DEBUG(
+            DBGS()
+            << "Conflicting layouts for region successor operand and input: "
+            << inputLayout << " vs " << operandLayout << "\n");
+        continue;
+      }
+      llvm::errs() << "Setting layout for input to "
+                   << ": " << operandLayout << "\n";
+      // Get tensor descriptor type with the layout.
+      auto tdescTy = dyn_cast<xegpu::TensorDescType>(inputType);
+      auto newTdescTy = xegpu::TensorDescType::get(
+          tdescTy.getContext(), tdescTy.getShape(), tdescTy.getElementType(),
+          tdescTy.getEncoding(), operandLayout);
+      input.setType(newTdescTy);
+    }
+  }
+}
 static void updateBranchOpInterface(mlir::OpBuilder &builder,
                                     mlir::RegionBranchOpInterface branch,
                                     GetLayoutCallbackFnTy getLayoutOfValue) {
@@ -970,20 +1017,19 @@ static void updateBranchOpInterface(mlir::OpBuilder &builder,
     if (successor.isParent())
       continue;
 
-    mlir::OperandRange initArgs = branch.getEntrySuccessorOperands(successor);
-    mlir::ValueRange blockArgs = successor.getSuccessorInputs();
+    mlir::OperandRange operands = branch.getEntrySuccessorOperands(successor);
+    mlir::ValueRange inputs = successor.getSuccessorInputs();
 
-    for (auto [initArg, blockArg, result] :
-         llvm::zip(initArgs, blockArgs, results)) {
-      Type inputType = blockArg.getType();
+    for (auto [operand, input, result] : llvm::zip(operands, inputs, results)) {
+      Type inputType = input.getType();
       if (!isa<xegpu::TensorDescType>(inputType))
         continue;
-      xegpu::LayoutAttr blockArgLayout = getLayoutOfValue(blockArg);
-      xegpu::LayoutAttr initArgLayout = getLayoutOfValue(initArg);
+      xegpu::LayoutAttr blockArgLayout = getLayoutOfValue(input);
+      xegpu::LayoutAttr initArgLayout = getLayoutOfValue(operand);
 
       if (!blockArgLayout || !initArgLayout) {
-        LLVM_DEBUG(DBGS() << "No layout assigned for block arg: " << blockArg
-                          << " or init arg: " << initArg << "\n");
+        LLVM_DEBUG(DBGS() << "No layout assigned for block arg: " << input
+                          << " or init arg: " << operand << "\n");
         continue;
       }
 
@@ -996,52 +1042,54 @@ static void updateBranchOpInterface(mlir::OpBuilder &builder,
       auto newTdescTy = xegpu::TensorDescType::get(
           tdescTy.getContext(), tdescTy.getShape(), tdescTy.getElementType(),
           tdescTy.getEncoding(), blockArgLayout);
-      blockArg.setType(newTdescTy);
+      input.setType(newTdescTy);
       // Store the layout for the result.
-      if (resultToLayouts.count(result) != 0 &&
-          resultToLayouts[result] != blockArgLayout) {
-        LLVM_DEBUG(DBGS() << "Conflicting layouts for result: " << result
-                          << " - " << resultToLayouts[result] << " vs "
-                          << blockArgLayout << "\n");
-      } else {
-        resultToLayouts[result] = blockArgLayout;
-      }
-    }
-  }
-  for (auto [i, r] : llvm::enumerate(op->getResults())) {
-    Type resultType = r.getType();
-    if (!isa<xegpu::TensorDescType, VectorType>(resultType))
-      continue;
-    xegpu::LayoutAttr layout = getLayoutOfValue(r);
-    if (!layout)
-      layout = resultToLayouts[r];
-    if (!layout) {
-      LLVM_DEBUG(DBGS() << "No layout assigned for vector/tensor desc result: "
-                        << r << "\n");
-      continue;
-    }
-    if (auto tensorDescTy = dyn_cast<xegpu::TensorDescType>(resultType)) {
-      auto newTdescTy = xegpu::TensorDescType::get(
-          tensorDescTy.getContext(), tensorDescTy.getShape(),
-          tensorDescTy.getElementType(), tensorDescTy.getEncoding(), layout);
-      r.setType(newTdescTy);
-      continue;
-    }
-    // If the result is a vector type, add a temporary layout attribute to the
-    // op.
-    std::string resultLayoutName =
-        resultLayoutNamePrefix + std::to_string(r.getResultNumber());
-    op->setAttr(resultLayoutName, layout);
-    // Update all users of the result with the layout.
-    for (OpOperand &user : r.getUses()) {
-      Operation *owner = user.getOwner();
-      unsigned operandNumber = user.getOperandNumber();
-      // Add temporary layout attribute at the user op.
-      std::string attrName =
-          operandLayoutNamePrefix + std::to_string(operandNumber);
-      owner->setAttr(attrName, layout);
+      // if (resultToLayouts.count(result) != 0 &&
+      //     resultToLayouts[result] != blockArgLayout) {
+      //   LLVM_DEBUG(DBGS() << "Conflicting layouts for result: " << result
+      //                     << " - " << resultToLayouts[result] << " vs "
+      //                     << blockArgLayout << "\n");
+      // } else {
+      //   resultToLayouts[result] = blockArgLayout;
+      // }
     }
   }
+  // for (auto [i, r] : llvm::enumerate(op->getResults())) {
+  //   Type resultType = r.getType();
+  //   if (!isa<xegpu::TensorDescType, VectorType>(resultType))
+  //     continue;
+  //   xegpu::LayoutAttr layout = getLayoutOfValue(r);
+  //   if (!layout)
+  //     layout = resultToLayouts[r];
+  //   if (!layout) {
+  //     LLVM_DEBUG(DBGS() << "No layout assigned for vector/tensor desc result:
+  //     "
+  //                       << r << "\n");
+  //     continue;
+  //   }
+  //   if (auto tensorDescTy = dyn_cast<xegpu::TensorDescType>(resultType)) {
+  //     auto newTdescTy = xegpu::TensorDescType::get(
+  //         tensorDescTy.getContext(), tensorDescTy.getShape(),
+  //         tensorDescTy.getElementType(), tensorDescTy.getEncoding(), layout);
+  //     r.setType(newTdescTy);
+  //     continue;
+  //   }
+  //   // If the result is a vector type, add a temporary layout attribute to
+  //   the
+  //   // op.
+  //   std::string resultLayoutName =
+  //       resultLayoutNamePrefix + std::to_string(r.getResultNumber());
+  //   op->setAttr(resultLayoutName, layout);
+  //   // Update all users of the result with the layout.
+  //   for (OpOperand &user : r.getUses()) {
+  //     Operation *owner = user.getOwner();
+  //     unsigned operandNumber = user.getOperandNumber();
+  //     // Add temporary layout attribute at the user op.
+  //     std::string attrName =
+  //         operandLayoutNamePrefix + std::to_string(operandNumber);
+  //     owner->setAttr(attrName, layout);
+  //   }
+  // }
 }
 static void updateBlockTypes(mlir::OpBuilder &builder, mlir::Block &block,
                              GetLayoutCallbackFnTy getLayoutOfValue) {}
@@ -1846,90 +1894,96 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
     updateBlockTypes(builder, *block, getXeGPULayoutForValue);
   });
 
-  // Move all operations of a GPU function inside gpu.warp_execute_on_lane_0
-  // operation.
-  {
-    RewritePatternSet patterns(&getContext());
-    patterns.add<MoveFuncBodyToWarpExecuteOnLane0>(&getContext());
-
-    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
-      signalPassFailure();
-      return;
-    }
-    // At this point, we have moved the entire function body inside the
-    // warpOp. Now move any scalar uniform code outside of the warpOp (like GPU
-    // index ops, scalar constants, etc.). This will simplify the later lowering
-    // and avoid custom patterns for these ops.
-    getOperation()->walk([&](Operation *op) {
-      if (auto warpOp = dyn_cast<gpu::WarpExecuteOnLane0Op>(op)) {
-        vector::moveScalarUniformCode(warpOp);
-      }
-    });
-  }
-  // Finally, do the SIMD to SIMT distribution.
-  RewritePatternSet patterns(&getContext());
-  xegpu::populateXeGPUSubgroupDistributePatterns(patterns);
-  // TODO: distributionFn and shuffleFn are not used at this point.
-  auto distributionFn = [](Value val) {
-    VectorType vecType = dyn_cast<VectorType>(val.getType());
-    int64_t vecRank = vecType ? vecType.getRank() : 0;
-    OpBuilder builder(val.getContext());
-    if (vecRank == 0)
-      return AffineMap::get(val.getContext());
-    return AffineMap::getMultiDimIdentityMap(vecRank, val.getContext());
-  };
-  auto shuffleFn = [](Location loc, OpBuilder &builder, Value val, Value srcIdx,
-                      int64_t warpSz) { return Value(); };
-  vector::populatePropagateWarpVectorDistributionPatterns(
-      patterns, distributionFn, shuffleFn);
-  if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
-    signalPassFailure();
-    return;
-  }
-
-  // Clean up UnrealizedConversionCastOps that were inserted due to tensor desc
-  // type mismatches created by using upstream distribution patterns (scf.for)
-  getOperation()->walk([&](mlir::UnrealizedConversionCastOp op) {
-    // We are only interested in UnrealizedConversionCastOps there were added
-    // for resolving SIMT type mismatches.
-    if (!op->getAttr(resolveSIMTTypeMismatch))
-      return WalkResult::skip();
-
-    Value input = op.getOperand(0);
-    Value output = op.getResult(0);
-
-    // Both input and output must have tensor descriptor types.
-    xegpu::TensorDescType inputDescType =
-        mlir::dyn_cast<xegpu::TensorDescType>(input.getType());
-    xegpu::TensorDescType outputDescType =
-        mlir::dyn_cast<xegpu::TensorDescType>(output.getType());
-    assert(inputDescType && outputDescType &&
-           "Unrealized conversion cast must have tensor descriptor types");
-
-    // tensor_desc<shape, layout> -> tensor_desc<shape> Type of conversions.
-    // This occurs iside scf.for body to resolve the block argument type to SIMT
-    // type.
-    if (inputDescType.getLayout()) {
-      auto argument = mlir::dyn_cast<mlir::BlockArgument>(input);
-      if (argument) {
-        argument.setType(output.getType());
-        output.replaceAllUsesWith(argument);
-        if (auto loopOp = mlir::dyn_cast<mlir::LoopLikeOpInterface>(
-                argument.getOwner()->getParentOp())) {
-          auto result = loopOp.getTiedLoopResult(argument);
-          result.setType(output.getType());
-        }
-      }
-    }
-
-    // tensor_desc<shape> -> tensor_desc<shape, layout> Type of
-    // conversions. This occurs at the yield op of scf.for body to go back from
-    // SIMT type to original type.
-    if (outputDescType.getLayout())
-      output.replaceAllUsesWith(input);
-
-    if (op->use_empty())
-      op->erase();
-    return WalkResult::advance();
-  });
+  // // Move all operations of a GPU function inside gpu.warp_execute_on_lane_0
+  // // operation.
+  // {
+  //   RewritePatternSet patterns(&getContext());
+  //   patterns.add<MoveFuncBodyToWarpExecuteOnLane0>(&getContext());
+
+  //   if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
+  //     signalPassFailure();
+  //     return;
+  //   }
+  //   // At this point, we have moved the entire function body inside the
+  //   // warpOp. Now move any scalar uniform code outside of the warpOp (like
+  //   GPU
+  //   // index ops, scalar constants, etc.). This will simplify the later
+  //   lowering
+  //   // and avoid custom patterns for these ops.
+  //   getOperation()->walk([&](Operation *op) {
+  //     if (auto warpOp = dyn_cast<gpu::WarpExecuteOnLane0Op>(op)) {
+  //       vector::moveScalarUniformCode(warpOp);
+  //     }
+  //   });
+  // }
+  // // Finally, do the SIMD to SIMT distribution.
+  // RewritePatternSet patterns(&getContext());
+  // xegpu::populateXeGPUSubgroupDistributePatterns(patterns);
+  // // TODO: distributionFn and shuffleFn are not used at this point.
+  // auto distributionFn = [](Value val) {
+  //   VectorType vecType = dyn_cast<VectorType>(val.getType());
+  //   int64_t vecRank = vecType ? vecType.getRank() : 0;
+  //   OpBuilder builder(val.getContext());
+  //   if (vecRank == 0)
+  //     return AffineMap::get(val.getContext());
+  //   return AffineMap::getMultiDimIdentityMap(vecRank, val.getContext());
+  // };
+  // auto shuffleFn = [](Location loc, OpBuilder &builder, Value val, Value
+  // srcIdx,
+  //                     int64_t warpSz) { return Value(); };
+  // vector::populatePropagateWarpVectorDistributionPatterns(
+  //     patterns, distributionFn, shuffleFn);
+  // if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
+  //   signalPassFailure();
+  //   return;
+  // }
+
+  // // Clean up UnrealizedConversionCastOps that were inserted due to tensor
+  // // desc type mismatches created by using upstream distribution patterns
+  // // (scf.for)
+  // getOperation()->walk([&](mlir::UnrealizedConversionCastOp op) {
+  //   // We are only interested in UnrealizedConversionCastOps that were added
+  //   // for resolving SIMT type mismatches.
+  //   if (!op->getAttr(resolveSIMTTypeMismatch))
+  //     return WalkResult::skip();
+
+  //   Value input = op.getOperand(0);
+  //   Value output = op.getResult(0);
+
+  //   // Both input and output must have tensor descriptor types.
+  //   xegpu::TensorDescType inputDescType =
+  //       mlir::dyn_cast<xegpu::TensorDescType>(input.getType());
+  //   xegpu::TensorDescType outputDescType =
+  //       mlir::dyn_cast<xegpu::TensorDescType>(output.getType());
+  //   assert(inputDescType && outputDescType &&
+  //          "Unrealized conversion cast must have tensor descriptor types");
+
+  //   // tensor_desc<shape, layout> -> tensor_desc<shape> Type of conversions.
+  //   // This occurs inside scf.for body to resolve the block argument type to
+  //   SIMT
+  //   // type.
+  //   if (inputDescType.getLayout()) {
+  //     auto argument = mlir::dyn_cast<mlir::BlockArgument>(input);
+  //     if (argument) {
+  //       argument.setType(output.getType());
+  //       output.replaceAllUsesWith(argument);
+  //       if (auto loopOp = mlir::dyn_cast<mlir::LoopLikeOpInterface>(
+  //               argument.getOwner()->getParentOp())) {
+  //         auto result = loopOp.getTiedLoopResult(argument);
+  //         result.setType(output.getType());
+  //       }
+  //     }
+  //   }
+
+  //   // tensor_desc<shape> -> tensor_desc<shape, layout> Type of
+  //   // conversions. This occurs at the yield op of scf.for body to go back
+  //   from
+  //   // SIMT type to original type.
+  //   if (outputDescType.getLayout())
+  //     output.replaceAllUsesWith(input);
+
+  //   if (op->use_empty())
+  //     op->erase();
+  //   return WalkResult::advance();
+  // });
 }



More information about the Mlir-commits mailing list