[Mlir-commits] [mlir] [mlir][xegpu] Refine layout assignment in XeGPU SIMT distribution. (PR #142687)

Charitha Saumya llvmlistbot at llvm.org
Fri Jun 13 11:08:39 PDT 2025


https://github.com/charithaintc updated https://github.com/llvm/llvm-project/pull/142687

>From ff1012e2208ef866a0313289d4bf6e130d1a0eaf Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Tue, 27 May 2025 23:40:57 +0000
Subject: [PATCH 01/32] add bug fix

---
 .../Vector/Transforms/VectorDistribute.cpp    | 42 ++++++++++++++-----
 1 file changed, 32 insertions(+), 10 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 045c192787f10..1649fb5f91b42 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -15,10 +15,13 @@
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
 #include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/Value.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Transforms/RegionUtils.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SetVector.h"
 #include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/raw_ostream.h"
 #include <utility>
 
 using namespace mlir;
@@ -1554,22 +1557,36 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
     llvm::SmallSetVector<Value, 32> escapingValues;
     SmallVector<Type> inputTypes;
     SmallVector<Type> distTypes;
+    auto collectEscapingValues = [&](Value value) {
+      if (!escapingValues.insert(value))
+        return;
+      Type distType = value.getType();
+      if (auto vecType = dyn_cast<VectorType>(distType)) {
+        AffineMap map = distributionMapFn(value);
+        distType = getDistributedType(vecType, map, warpOp.getWarpSize());
+      }
+      inputTypes.push_back(value.getType());
+      distTypes.push_back(distType);
+    };
+
     mlir::visitUsedValuesDefinedAbove(
         forOp.getBodyRegion(), [&](OpOperand *operand) {
           Operation *parent = operand->get().getParentRegion()->getParentOp();
           if (warpOp->isAncestor(parent)) {
-            if (!escapingValues.insert(operand->get()))
-              return;
-            Type distType = operand->get().getType();
-            if (auto vecType = dyn_cast<VectorType>(distType)) {
-              AffineMap map = distributionMapFn(operand->get());
-              distType = getDistributedType(vecType, map, warpOp.getWarpSize());
-            }
-            inputTypes.push_back(operand->get().getType());
-            distTypes.push_back(distType);
+            collectEscapingValues(operand->get());
           }
         });
 
+    // Any forOp result that is not already yielded by the warpOp
+    // region is also considered escaping.
+    for (OpResult forResult : forOp.getResults()) {
+      // Check if this forResult is already yielded by the yield op.
+      if (llvm::is_contained(yield->getOperands(), forResult)) {
+        continue;
+      }
+      collectEscapingValues(forResult);
+    }
+
     if (llvm::is_contained(distTypes, Type{}))
       return failure();
 
@@ -1609,7 +1626,12 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
                                     forOp.getResultTypes().end());
     llvm::SmallDenseMap<Value, int64_t> argIndexMapping;
     for (auto [i, retIdx] : llvm::enumerate(newRetIndices)) {
-      warpInput.push_back(newWarpOp.getResult(retIdx));
+      auto newWarpResult = newWarpOp.getResult(retIdx);
+      // Unused forOp results yielded by the warpOp region are already included
+      // in the new ForOp.
+      if (llvm::is_contained(newOperands, newWarpResult))
+        continue;
+      warpInput.push_back(newWarpResult);
       argIndexMapping[escapingValues[i]] = warpInputType.size();
       warpInputType.push_back(inputTypes[i]);
     }

>From c6eb53fefded7152c2d627c4094b66f616bc53ed Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Wed, 28 May 2025 20:22:47 +0000
Subject: [PATCH 02/32] add test

---
 .../Vector/vector-warp-distribute.mlir        | 36 +++++++++++++++++++
 1 file changed, 36 insertions(+)

diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 38771f2593449..6c7ac7a5196a7 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -584,6 +584,42 @@ func.func @warp_scf_for_multiple_yield(%arg0: index, %arg1: memref<?xf32>, %arg2
   return
 }
 
+// -----
+// CHECK-PROP-LABEL: func.func @warp_scf_for_unused_yield(
+//       CHECK-PROP: %[[W0:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<4xf32>, vector<4xf32>) {
+//       CHECK-PROP: %[[INI0:.*]] = "some_def"() : () -> vector<128xf32>
+//       CHECK-PROP: %[[INI1:.*]] = "some_def"() : () -> vector<128xf32>
+//       CHECK-PROP: gpu.yield %[[INI0]], %[[INI1]] : vector<128xf32>, vector<128xf32>
+//       CHECK-PROP: }
+//       CHECK-PROP: %[[F:.*]]:2 = scf.for %{{.*}} iter_args(%{{.*}} = %[[W0]]#0, %{{.*}} = %[[W0]]#1) -> (vector<4xf32>, vector<4xf32>) {
+//       CHECK-PROP: %[[W1:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[32] args(%{{.*}} : vector<4xf32>, vector<4xf32>) -> (vector<4xf32>, vector<4xf32>) {
+//       CHECK-PROP: %[[ACC0:.*]] = "some_def"(%{{.*}}) : (vector<128xf32>, index) -> vector<128xf32>
+//       CHECK-PROP: %[[ACC1:.*]] = "some_def"(%{{.*}}) : (index, vector<128xf32>, vector<128xf32>) -> vector<128xf32>
+//       CHECK-PROP: gpu.yield %[[ACC1]], %[[ACC0]] : vector<128xf32>, vector<128xf32>
+//       CHECK-PROP: }
+//       CHECK-PROP: scf.yield %[[W1]]#0, %[[W1]]#1 : vector<4xf32>, vector<4xf32>
+//       CHECK-PROP: }
+//       CHECK-PROP: "some_use"(%[[F]]#0) : (vector<4xf32>) -> ()
+func.func @warp_scf_for_unused_yield(%arg0: index) {
+  %c128 = arith.constant 128 : index
+  %c1 = arith.constant 1 : index
+  %c0 = arith.constant 0 : index
+  %0 = gpu.warp_execute_on_lane_0(%arg0)[32] -> (vector<4xf32>) {
+    %ini = "some_def"() : () -> (vector<128xf32>)
+    %ini1 = "some_def"() : () -> (vector<128xf32>)
+    %3:2 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %ini, %arg5 = %ini1) -> (vector<128xf32>, vector<128xf32>) {
+      %add = arith.addi %arg3, %c1 : index
+      %1  = "some_def"(%arg5, %add) : (vector<128xf32>, index) -> (vector<128xf32>)
+      %acc = "some_def"(%add, %arg4, %1) : (index, vector<128xf32>, vector<128xf32>) -> (vector<128xf32>)
+      scf.yield %acc, %1 : vector<128xf32>, vector<128xf32>
+    }
+    gpu.yield %3#0 : vector<128xf32>
+  }
+  "some_use"(%0) : (vector<4xf32>) -> ()
+  return
+}
+
+
 // -----
 
 // CHECK-PROP-LABEL: func @vector_reduction(

>From 3bdb5961d48bf70b63560820375d24e0682dbff8 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Wed, 28 May 2025 20:26:01 +0000
Subject: [PATCH 03/32] add comments

---
 mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 1649fb5f91b42..94435588459e6 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -1578,7 +1578,8 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
         });
 
     // Any forOp result that is not already yielded by the warpOp
-    // region is also considered escaping.
+    // region is also considered escaping and must be returned by the
+    // original warpOp.
     for (OpResult forResult : forOp.getResults()) {
       // Check if this forResult is already yielded by the yield op.
       if (llvm::is_contained(yield->getOperands(), forResult)) {

>From fe3ab99da99bfe47dd257a458d01ddd4e24df63e Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Wed, 28 May 2025 21:32:04 +0000
Subject: [PATCH 04/32] remove unused headers

---
 mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 94435588459e6..bd833ddb773f7 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -15,13 +15,10 @@
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
 #include "mlir/IR/AffineExpr.h"
-#include "mlir/IR/Value.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Transforms/RegionUtils.h"
-#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SetVector.h"
 #include "llvm/Support/FormatVariadic.h"
-#include "llvm/Support/raw_ostream.h"
 #include <utility>
 
 using namespace mlir;

>From f91b64c88ef893a9a7d620cd76345c21a4a46d33 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Mon, 2 Jun 2025 18:24:58 +0000
Subject: [PATCH 05/32] save work

---
 .../Transforms/XeGPUSubgroupDistribute.cpp    | 218 +++++++++++++-----
 1 file changed, 164 insertions(+), 54 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 992700524146a..d178c2c33245e 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -12,6 +12,8 @@
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/GPU/Utils/DistributionUtils.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Transforms/Patterns.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
@@ -30,6 +32,7 @@
 #include "mlir/IR/Value.h"
 #include "mlir/IR/Visitors.h"
 #include "mlir/Interfaces/FunctionInterfaces.h"
+#include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "mlir/Transforms/InliningUtils.h"
 #include "llvm/ADT/ArrayRef.h"
@@ -38,6 +41,7 @@
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/FormatVariadic.h"
 #include "llvm/Support/InterleavedRange.h"
+#include "llvm/Support/LogicalResult.h"
 #include "llvm/Support/raw_ostream.h"
 
 namespace mlir {
@@ -701,7 +705,47 @@ namespace {
 //===----------------------------------------------------------------------===//
 // LayoutAttrAssignment
 //===----------------------------------------------------------------------===//
+template <typename OpTy>
+class UpdateTensorDescType : public OpConversionPattern<OpTy> {
+public:
+  UpdateTensorDescType(MLIRContext *context,
+                       function_ref<xegpu::LayoutAttr(Value)> getLayoutOfValue,
+                       TypeConverter &typeConverter, PatternBenefit benefit = 1)
+      : OpConversionPattern<OpTy>(typeConverter, context, benefit),
+        getLayoutOfValue(getLayoutOfValue) {}
+  using OpConversionPattern<OpTy>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Op must have single result.
+    if (op->getNumResults() != 1)
+      return failure();
+    Type resultType = op->getResult(0).getType();
+    // Result type must be a tensor descriptor type.
+    if (!isa<xegpu::TensorDescType>(resultType)) {
+      LLVM_DEBUG(DBGS() << "Result type is not a tensor descriptor type: "
+                        << resultType << "\n");
+      return failure();
+    }
+    auto assignedLayout = getLayoutOfValue(op.getResult());
+    if (!assignedLayout) {
+      LLVM_DEBUG(DBGS() << "No layout assigned for " << *op << "\n");
+      return failure();
+    }
+    // Get the original tensor descriptor type.
+    auto origTensorDescTy = dyn_cast<xegpu::TensorDescType>(resultType);
+    auto newTensorDescTy = xegpu::TensorDescType::get(
+        origTensorDescTy.getContext(), origTensorDescTy.getShape(),
+        origTensorDescTy.getElementType(), origTensorDescTy.getEncoding(),
+        assignedLayout);
+    rewriter.replaceOpWithNewOp<OpTy>(op, newTensorDescTy,
+                                      adaptor.getOperands(), op->getAttrs());
+    return success();
+  }
 
+private:
+  function_ref<xegpu::LayoutAttr(Value)> getLayoutOfValue;
+};
 /// This class is responsible for assigning the layout attributes to the ops and
 /// their users based on the layout propagation analysis result.
 class LayoutAttrAssignment {
@@ -739,15 +783,19 @@ void LayoutAttrAssignment::assignToUsers(Value v, xegpu::LayoutAttr layout) {
 
 /// Convert the layout assigned to a value to xegpu::LayoutAttr.
 xegpu::LayoutAttr LayoutAttrAssignment::getLayoutAttrForValue(Value v) {
+  llvm::errs() << "getLayoutAttrForValue: " << v << "\n";
   LayoutInfo layout = getAnalysisResult(v);
-  if (!layout.isAssigned())
+  if (!layout.isAssigned()) {
+    llvm::errs() << "No layout assigned for value\n";
     return {};
+  }
   SmallVector<int, 2> laneLayout, laneData;
   for (auto [layout, data] : llvm::zip_equal(layout.getLayoutAsArrayRef(),
                                              layout.getDataAsArrayRef())) {
     laneLayout.push_back(static_cast<int>(layout));
     laneData.push_back(static_cast<int>(data));
   }
+  llvm::errs() << "return layout\n";
   return xegpu::LayoutAttr::get(v.getContext(), laneLayout, laneData);
 }
 
@@ -820,14 +868,23 @@ LogicalResult LayoutAttrAssignment::assign(Operation *op) {
 
 /// Walk the IR and attach xegpu::LayoutAttr to all ops and their users.
 LogicalResult LayoutAttrAssignment::run() {
-  auto walkResult = top->walk([&](Operation *op) {
-    if (failed(assign(op)))
-      return WalkResult::interrupt();
-    return WalkResult::advance();
-  });
-
-  if (walkResult.wasInterrupted())
-    return failure();
+  // auto walkResult = top->walk([&](Operation *op) {
+  //   if (failed(assign(op)))
+  //     return WalkResult::interrupt();
+  //   return WalkResult::advance();
+  // });
+
+  // if (walkResult.wasInterrupted())
+  //   return failure();
+  // apply the UpdateTensorDescType pattern to all ops
+  // RewritePatternSet patterns(top->getContext());
+  // patterns.add<UpdateTensorDescType>(
+  //     top->getContext(), [&](Value v) -> xegpu::LayoutAttr {
+  //       llvm::errs() << "invoking callback for value\n";
+  //       return getLayoutAttrForValue(v);
+  //     });
+  // if (failed(applyPatternsGreedily(top, std::move(patterns))))
+  //   return failure();
 
   return resolveConflicts();
 }
@@ -1597,56 +1654,109 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
     analyis.printAnalysisResult(os);
     return;
   }
-  auto getPropagatedLayout = [&](Value val) {
-    return analyis.getLayoutInfo(val);
+  // auto getPropagatedLayout = [&](Value val) {
+  //   return analyis.getLayoutInfo(val);
+  // };
+  auto getXeGpuLayoutForValue = [&](Value val) -> xegpu::LayoutAttr {
+    LayoutInfo layout = analyis.getLayoutInfo(val);
+    if (!layout.isAssigned()) {
+      llvm::errs() << "No layout assigned for value\n";
+      return {};
+    }
+    SmallVector<int, 2> laneLayout, laneData;
+    for (auto [layout, data] : llvm::zip_equal(layout.getLayoutAsArrayRef(),
+                                               layout.getDataAsArrayRef())) {
+      laneLayout.push_back(static_cast<int>(layout));
+      laneData.push_back(static_cast<int>(data));
+    }
+    return xegpu::LayoutAttr::get(val.getContext(), laneLayout, laneData);
+  };
+
+  ConversionTarget target(getContext());
+  target.addDynamicallyLegalOp<xegpu::CreateNdDescOp, xegpu::UpdateNdOffsetOp>(
+      [&](Operation *op) {
+        return llvm::all_of(op->getResults(), [&](Value val) {
+          if (auto descType = dyn_cast<xegpu::TensorDescType>(val.getType())) {
+            return descType.getLayoutAttr() != nullptr;
+          }
+          return true; // Non-tensor descriptor types are always legal.
+        });
+      });
+  target.addLegalOp<UnrealizedConversionCastOp>();
+  TypeConverter typeConverter;
+  typeConverter.addConversion([](Type type) { return type; });
+  // // typeConverter.addConversion([](xegpu::TensorDescType type) {
+  // //   return xegpu::TensorDescType::get(
+  // //       type.getContext(), type.getShape(), type.getElementType(),
+  // //       type.getEncoding(),
+  // //       xegpu::LayoutAttr::get(type.getContext(), {1, 1}, {1, 1}));
+  // // });
+  auto addUnrealizedCast = [](OpBuilder &builder, Type type, ValueRange inputs,
+                              Location loc) -> Value {
+    auto cast = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
+    return cast.getResult(0);
   };
 
+  typeConverter.addSourceMaterialization(addUnrealizedCast);
+  typeConverter.addTargetMaterialization(addUnrealizedCast);
+
+  RewritePatternSet patterns(&getContext());
+  patterns.add<UpdateTensorDescType<xegpu::CreateNdDescOp>,
+               UpdateTensorDescType<xegpu::UpdateNdOffsetOp>>(
+      &getContext(), getXeGpuLayoutForValue, typeConverter);
+  if (failed(
+          applyPartialConversion(getOperation(), target, std::move(patterns))))
+    signalPassFailure();
+
   // Assign xegpu::LayoutAttr to all ops and their users based on the layout
   // propagation analysis result.
-  LayoutAttrAssignment layoutAssignment(getOperation(), getPropagatedLayout);
-  if (failed(layoutAssignment.run())) {
-    signalPassFailure();
-    return;
-  }
+  // LayoutAttrAssignment layoutAssignment(getOperation(), getPropagatedLayout);
+  // if (failed(layoutAssignment.run())) {
+  //   signalPassFailure();
+  //   return;
+  // }
 
   // Move all operations of a GPU function inside gpu.warp_execute_on_lane_0
   // operation.
-  {
-    RewritePatternSet patterns(&getContext());
-    patterns.add<MoveFuncBodyToWarpExecuteOnLane0>(&getContext());
-
-    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
-      signalPassFailure();
-      return;
-    }
-    // At this point, we have moved the entire function body inside the warpOp.
-    // Now move any scalar uniform code outside of the warpOp (like GPU index
-    // ops, scalar constants, etc.). This will simplify the later lowering and
-    // avoid custom patterns for these ops.
-    getOperation()->walk([&](Operation *op) {
-      if (auto warpOp = dyn_cast<gpu::WarpExecuteOnLane0Op>(op)) {
-        vector::moveScalarUniformCode(warpOp);
-      }
-    });
-  }
-  // Finally, do the SIMD to SIMT distribution.
-  RewritePatternSet patterns(&getContext());
-  xegpu::populateXeGPUSubgroupDistributePatterns(patterns);
-  // TODO: distributionFn and shuffleFn are not used at this point.
-  auto distributionFn = [](Value val) {
-    VectorType vecType = dyn_cast<VectorType>(val.getType());
-    int64_t vecRank = vecType ? vecType.getRank() : 0;
-    OpBuilder builder(val.getContext());
-    if (vecRank == 0)
-      return AffineMap::get(val.getContext());
-    return AffineMap::getMultiDimIdentityMap(vecRank, val.getContext());
-  };
-  auto shuffleFn = [](Location loc, OpBuilder &builder, Value val, Value srcIdx,
-                      int64_t warpSz) { return Value(); };
-  vector::populatePropagateWarpVectorDistributionPatterns(
-      patterns, distributionFn, shuffleFn);
-  if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
-    signalPassFailure();
-    return;
-  }
+  // {
+  //   RewritePatternSet patterns(&getContext());
+  //   patterns.add<MoveFuncBodyToWarpExecuteOnLane0>(&getContext());
+
+  //   if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
+  //     signalPassFailure();
+  //     return;
+  //   }
+  //   // At this point, we have moved the entire function body inside the
+  //   warpOp.
+  //   // Now move any scalar uniform code outside of the warpOp (like GPU index
+  //   // ops, scalar constants, etc.). This will simplify the later lowering
+  //   and
+  //   // avoid custom patterns for these ops.
+  //   getOperation()->walk([&](Operation *op) {
+  //     if (auto warpOp = dyn_cast<gpu::WarpExecuteOnLane0Op>(op)) {
+  //       vector::moveScalarUniformCode(warpOp);
+  //     }
+  //   });
+  // }
+  // // Finally, do the SIMD to SIMT distribution.
+  // RewritePatternSet patterns(&getContext());
+  // xegpu::populateXeGPUSubgroupDistributePatterns(patterns);
+  // // TODO: distributionFn and shuffleFn are not used at this point.
+  // auto distributionFn = [](Value val) {
+  //   VectorType vecType = dyn_cast<VectorType>(val.getType());
+  //   int64_t vecRank = vecType ? vecType.getRank() : 0;
+  //   OpBuilder builder(val.getContext());
+  //   if (vecRank == 0)
+  //     return AffineMap::get(val.getContext());
+  //   return AffineMap::getMultiDimIdentityMap(vecRank, val.getContext());
+  // };
+  // auto shuffleFn = [](Location loc, OpBuilder &builder, Value val, Value
+  // srcIdx,
+  //                     int64_t warpSz) { return Value(); };
+  // vector::populatePropagateWarpVectorDistributionPatterns(
+  //     patterns, distributionFn, shuffleFn);
+  // if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
+  //   signalPassFailure();
+  //   return;
+  // }
 }

>From 5cacace6c3f56f3d84b2a63003c2f3d9947b195a Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Mon, 2 Jun 2025 22:57:27 +0000
Subject: [PATCH 06/32] initial version

---
 .../Transforms/XeGPUSubgroupDistribute.cpp    | 487 ++++++++++--------
 1 file changed, 267 insertions(+), 220 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index d178c2c33245e..aa982ae779d1e 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -32,6 +32,7 @@
 #include "mlir/IR/Value.h"
 #include "mlir/IR/Visitors.h"
 #include "mlir/Interfaces/FunctionInterfaces.h"
+#include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "mlir/Transforms/InliningUtils.h"
@@ -700,203 +701,264 @@ void RunLayoutInfoPropagation::printAnalysisResult(llvm::raw_ostream &os) {
   }
 }
 
-namespace {
+// namespace {
 
 //===----------------------------------------------------------------------===//
 // LayoutAttrAssignment
 //===----------------------------------------------------------------------===//
-template <typename OpTy>
-class UpdateTensorDescType : public OpConversionPattern<OpTy> {
-public:
-  UpdateTensorDescType(MLIRContext *context,
-                       function_ref<xegpu::LayoutAttr(Value)> getLayoutOfValue,
-                       TypeConverter &typeConverter, PatternBenefit benefit = 1)
-      : OpConversionPattern<OpTy>(typeConverter, context, benefit),
-        getLayoutOfValue(getLayoutOfValue) {}
-  using OpConversionPattern<OpTy>::OpConversionPattern;
-  LogicalResult
-  matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    // Op must have single result.
-    if (op->getNumResults() != 1)
-      return failure();
-    Type resultType = op->getResult(0).getType();
-    // Result type must be a tensor descriptor type.
-    if (!isa<xegpu::TensorDescType>(resultType)) {
-      LLVM_DEBUG(DBGS() << "Result type is not a tensor descriptor type: "
-                        << resultType << "\n");
-      return failure();
+// template <typename OpTy>
+// class UpdateTensorDescType : public OpConversionPattern<OpTy> {
+// public:
+//   UpdateTensorDescType(MLIRContext *context,
+//                        function_ref<xegpu::LayoutAttr(Value)>
+//                        getLayoutOfValue, TypeConverter &typeConverter,
+//                        PatternBenefit benefit = 1)
+//       : OpConversionPattern<OpTy>(typeConverter, context, benefit),
+//         getLayoutOfValue(getLayoutOfValue) {}
+//   using OpConversionPattern<OpTy>::OpConversionPattern;
+//   LogicalResult
+//   matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
+//                   ConversionPatternRewriter &rewriter) const override {
+//     // Op must have single result.
+//     if (op->getNumResults() != 1)
+//       return failure();
+//     Type resultType = op->getResult(0).getType();
+//     // Result type must be a tensor descriptor type.
+//     if (!isa<xegpu::TensorDescType>(resultType)) {
+//       LLVM_DEBUG(DBGS() << "Result type is not a tensor descriptor type: "
+//                         << resultType << "\n");
+//       return failure();
+//     }
+//     auto assignedLayout = getLayoutOfValue(op.getResult());
+//     if (!assignedLayout) {
+//       LLVM_DEBUG(DBGS() << "No layout assigned for " << *op << "\n");
+//       return failure();
+//     }
+//     // Get the original tensor descriptor type.
+//     auto origTensorDescTy = dyn_cast<xegpu::TensorDescType>(resultType);
+//     auto newTensorDescTy = xegpu::TensorDescType::get(
+//         origTensorDescTy.getContext(), origTensorDescTy.getShape(),
+//         origTensorDescTy.getElementType(), origTensorDescTy.getEncoding(),
+//         assignedLayout);
+//     rewriter.replaceOpWithNewOp<OpTy>(op, newTensorDescTy,
+//                                       adaptor.getOperands(), op->getAttrs());
+//     return success();
+//   }
+
+// private:
+//   function_ref<xegpu::LayoutAttr(Value)> getLayoutOfValue;
+// };
+// /// This class is responsible for assigning the layout attributes to the ops
+// and
+// /// their users based on the layout propagation analysis result.
+// class LayoutAttrAssignment {
+// public:
+//   LayoutAttrAssignment(Operation *top,
+//                        function_ref<LayoutInfo(Value)> getLayout)
+//       : getAnalysisResult(getLayout), top(top) {}
+
+//   LogicalResult run();
+
+// private:
+//   LogicalResult assign(Operation *op);
+//   void assignToUsers(Value v, xegpu::LayoutAttr layout);
+//   xegpu::LayoutAttr getLayoutAttrForValue(Value v);
+//   LogicalResult resolveConflicts();
+//   // Callable to get the layout of a value based on the layout propagation
+//   // analysis.
+//   function_ref<LayoutInfo(Value)> getAnalysisResult;
+//   Operation *top;
+// };
+
+// } // namespace
+
+// /// Helper to assign the layout attribute to the users of the value.
+// void LayoutAttrAssignment::assignToUsers(Value v, xegpu::LayoutAttr layout) {
+//   for (OpOperand &user : v.getUses()) {
+//     Operation *owner = user.getOwner();
+//     unsigned operandNumber = user.getOperandNumber();
+//     // Use a generic name for ease of querying the layout attribute later.
+//     std::string attrName =
+//         operandLayoutNamePrefix + std::to_string(operandNumber);
+//     owner->setAttr(attrName, layout);
+//   }
+// }
+
+// /// Convert the layout assigned to a value to xegpu::LayoutAttr.
+// xegpu::LayoutAttr LayoutAttrAssignment::getLayoutAttrForValue(Value v) {
+//   llvm::errs() << "getLayoutAttrForValue: " << v << "\n";
+//   LayoutInfo layout = getAnalysisResult(v);
+//   if (!layout.isAssigned()) {
+//     llvm::errs() << "No layout assigned for value\n";
+//     return {};
+//   }
+//   SmallVector<int, 2> laneLayout, laneData;
+//   for (auto [layout, data] : llvm::zip_equal(layout.getLayoutAsArrayRef(),
+//                                              layout.getDataAsArrayRef())) {
+//     laneLayout.push_back(static_cast<int>(layout));
+//     laneData.push_back(static_cast<int>(data));
+//   }
+//   llvm::errs() << "return layout\n";
+//   return xegpu::LayoutAttr::get(v.getContext(), laneLayout, laneData);
+// }
+
+// /// Assign xegpu::LayoutAttr to the op and its users. The layout is assigned
+// /// based on the layout propagation analysis result.
+// LogicalResult LayoutAttrAssignment::assign(Operation *op) {
+//   // For function ops, propagate the function argument layout to the users.
+//   if (auto func = dyn_cast<FunctionOpInterface>(op)) {
+//     for (BlockArgument arg : func.getArguments()) {
+//       xegpu::LayoutAttr layoutInfo = getLayoutAttrForValue(arg);
+//       if (layoutInfo) {
+//         assignToUsers(arg, layoutInfo);
+//       }
+//     }
+//     return success();
+//   }
+//   // If no results, move on.
+//   if (op->getNumResults() == 0)
+//     return success();
+//   // If all the results are scalars, move on.
+//   if (llvm::all_of(op->getResultTypes(),
+//                    [](Type t) { return t.isIntOrIndexOrFloat(); }))
+//     return success();
+//   // If the op has more than one result and at least one result is a tensor
+//   // descriptor, exit. This case is not supported yet.
+//   // TODO: Support this case.
+//   if (op->getNumResults() > 1 && llvm::any_of(op->getResultTypes(), [](Type
+//   t) {
+//         return isa<xegpu::TensorDescType>(t);
+//       })) {
+//     LLVM_DEBUG(
+//         DBGS() << op->getName()
+//                << " op has more than one result and at least one is a tensor
+//                "
+//                   "descriptor. This case is not handled.\n");
+//     return failure();
+//   }
+//   // If the result is a tensor descriptor, attach the layout to the tensor
+//   // descriptor itself.
+//   if (auto tensorDescTy =
+//           dyn_cast<xegpu::TensorDescType>(op->getResultTypes()[0])) {
+//     xegpu::LayoutAttr layoutInfo = getLayoutAttrForValue(op->getResult(0));
+//     if (!layoutInfo) {
+//       LLVM_DEBUG(DBGS() << "No layout for result of " << *op << "\n");
+//       return failure();
+//     }
+
+//     // Clone the op, attach the layout to the result tensor descriptor, and
+//     // remove the original op.
+//     OpBuilder builder(op);
+//     Operation *newOp = builder.clone(*op);
+//     auto newTensorDescTy = xegpu::TensorDescType::get(
+//         tensorDescTy.getContext(), tensorDescTy.getShape(),
+//         tensorDescTy.getElementType(), tensorDescTy.getEncoding(),
+//         layoutInfo);
+//     newOp->getResult(0).setType(newTensorDescTy);
+//     op->replaceAllUsesWith(newOp->getResults());
+//     op->erase();
+//     return success();
+//   }
+//   // Otherwise simply attach the layout to the op itself.
+//   for (auto [i, r] : llvm::enumerate(op->getResults())) {
+//     xegpu::LayoutAttr layoutInfo = getLayoutAttrForValue(r);
+//     if (layoutInfo) {
+//       std::string attrName = resultLayoutNamePrefix + std::to_string(i);
+//       op->setAttr(attrName, layoutInfo);
+//       // Attach the layout attribute to the users of the result.
+//       assignToUsers(r, layoutInfo);
+//     }
+//   }
+//   return success();
+// }
+
+// /// Walk the IR and attach xegpu::LayoutAttr to all ops and their users.
+// LogicalResult LayoutAttrAssignment::run() {
+//   // auto walkResult = top->walk([&](Operation *op) {
+//   //   if (failed(assign(op)))
+//   //     return WalkResult::interrupt();
+//   //   return WalkResult::advance();
+//   // });
+
+//   // if (walkResult.wasInterrupted())
+//   //   return failure();
+//   // apply the UpdateTensorDescType pattern to all ops
+//   // RewritePatternSet patterns(top->getContext());
+//   // patterns.add<UpdateTensorDescType>(
+//   //     top->getContext(), [&](Value v) -> xegpu::LayoutAttr {
+//   //       llvm::errs() << "invoking callback for value\n";
+//   //       return getLayoutAttrForValue(v);
+//   //     });
+//   // if (failed(applyPatternsGreedily(top, std::move(patterns))))
+//   //   return failure();
+
+//   return resolveConflicts();
+// }
+
+// /// TODO: Implement the layout conflict resolution. This must ensure mainly
+// two
+// /// things:
+// /// 1) Is a given layout supported by the op? (need to query the target
+// ///    HW info). Otherwise can we achieve this layout using a layout
+// conversion?
+// /// 2) Do all the operands have the required layout? If not, can it
+// ///    be resolved using a layout conversion?
+// LogicalResult LayoutAttrAssignment::resolveConflicts() { return success(); }
+using GetLayoutCallbackFnTy = function_ref<xegpu::LayoutAttr(Value)>;
+static void handleBranchTerminatorOpInterface(
+    mlir::OpBuilder &builder,
+    mlir::RegionBranchTerminatorOpInterface terminator,
+    GetLayoutCallbackFnTy getLayoutOfValue) {}
+static void handleBranchOpInterface(mlir::OpBuilder &builder,
+                                    mlir::RegionBranchOpInterface branch,
+                                    GetLayoutCallbackFnTy getLayoutOfValue) {}
+static void updateBlockTypes(mlir::OpBuilder &builder, mlir::Block &block,
+                             GetLayoutCallbackFnTy getLayoutOfValue) {}
+static void updateOp(mlir::OpBuilder &builder, mlir::Operation *op,
+                     GetLayoutCallbackFnTy getLayoutOfValue) {
+
+  auto updateValue = [&](Value v, unsigned vIndex,
+                         const std::string &layoutAttrName) {
+    // Layouts are needed only for vector and tensor descriptor types.
+    if (!isa<VectorType, xegpu::TensorDescType>(v.getType()))
+      return;
+    xegpu::LayoutAttr layout = getLayoutOfValue(v);
+    if (!layout) {
+      // TODO : handle error.
+      LLVM_DEBUG(DBGS() << "Expecting layout for value: " << v
+                        << " but got none.\n");
+      return;
     }
-    auto assignedLayout = getLayoutOfValue(op.getResult());
-    if (!assignedLayout) {
-      LLVM_DEBUG(DBGS() << "No layout assigned for " << *op << "\n");
-      return failure();
+    auto tensorDescTy = dyn_cast<xegpu::TensorDescType>(v.getType());
+
+    if (tensorDescTy) {
+      auto newTensorDescTy = xegpu::TensorDescType::get(
+          tensorDescTy.getContext(), tensorDescTy.getShape(),
+          tensorDescTy.getElementType(), tensorDescTy.getEncoding(), layout);
+      v.setType(newTensorDescTy);
+      return;
     }
-    // Get the original tensor descriptor type.
-    auto origTensorDescTy = dyn_cast<xegpu::TensorDescType>(resultType);
-    auto newTensorDescTy = xegpu::TensorDescType::get(
-        origTensorDescTy.getContext(), origTensorDescTy.getShape(),
-        origTensorDescTy.getElementType(), origTensorDescTy.getEncoding(),
-        assignedLayout);
-    rewriter.replaceOpWithNewOp<OpTy>(op, newTensorDescTy,
-                                      adaptor.getOperands(), op->getAttrs());
-    return success();
-  }
-
-private:
-  function_ref<xegpu::LayoutAttr(Value)> getLayoutOfValue;
-};
-/// This class is responsible for assigning the layout attributes to the ops and
-/// their users based on the layout propagation analysis result.
-class LayoutAttrAssignment {
-public:
-  LayoutAttrAssignment(Operation *top,
-                       function_ref<LayoutInfo(Value)> getLayout)
-      : getAnalysisResult(getLayout), top(top) {}
-
-  LogicalResult run();
-
-private:
-  LogicalResult assign(Operation *op);
-  void assignToUsers(Value v, xegpu::LayoutAttr layout);
-  xegpu::LayoutAttr getLayoutAttrForValue(Value v);
-  LogicalResult resolveConflicts();
-  // Callable to get the layout of a value based on the layout propagation
-  // analysis.
-  function_ref<LayoutInfo(Value)> getAnalysisResult;
-  Operation *top;
-};
-
-} // namespace
-
-/// Helper to assign the layout attribute to the users of the value.
-void LayoutAttrAssignment::assignToUsers(Value v, xegpu::LayoutAttr layout) {
-  for (OpOperand &user : v.getUses()) {
-    Operation *owner = user.getOwner();
-    unsigned operandNumber = user.getOperandNumber();
-    // Use a generic name for ease of querying the layout attribute later.
-    std::string attrName =
-        operandLayoutNamePrefix + std::to_string(operandNumber);
-    owner->setAttr(attrName, layout);
-  }
-}
-
-/// Convert the layout assigned to a value to xegpu::LayoutAttr.
-xegpu::LayoutAttr LayoutAttrAssignment::getLayoutAttrForValue(Value v) {
-  llvm::errs() << "getLayoutAttrForValue: " << v << "\n";
-  LayoutInfo layout = getAnalysisResult(v);
-  if (!layout.isAssigned()) {
-    llvm::errs() << "No layout assigned for value\n";
-    return {};
-  }
-  SmallVector<int, 2> laneLayout, laneData;
-  for (auto [layout, data] : llvm::zip_equal(layout.getLayoutAsArrayRef(),
-                                             layout.getDataAsArrayRef())) {
-    laneLayout.push_back(static_cast<int>(layout));
-    laneData.push_back(static_cast<int>(data));
-  }
-  llvm::errs() << "return layout\n";
-  return xegpu::LayoutAttr::get(v.getContext(), laneLayout, laneData);
-}
+    // If type is vector, add a temporary layout attribute to the op.
+    op->setAttr(layoutAttrName, layout);
+  };
 
-/// Assign xegpu::LayoutAttr to the op and its users. The layout is assigned
-/// based on the layout propagation analysis result.
-LogicalResult LayoutAttrAssignment::assign(Operation *op) {
-  // For function ops, propagate the function argument layout to the users.
-  if (auto func = dyn_cast<FunctionOpInterface>(op)) {
-    for (BlockArgument arg : func.getArguments()) {
-      xegpu::LayoutAttr layoutInfo = getLayoutAttrForValue(arg);
-      if (layoutInfo) {
-        assignToUsers(arg, layoutInfo);
-      }
-    }
-    return success();
-  }
-  // If no results, move on.
-  if (op->getNumResults() == 0)
-    return success();
-  // If all the results are scalars, move on.
-  if (llvm::all_of(op->getResultTypes(),
-                   [](Type t) { return t.isIntOrIndexOrFloat(); }))
-    return success();
-  // If the op has more than one result and at least one result is a tensor
-  // descriptor, exit. This case is not supported yet.
-  // TODO: Support this case.
-  if (op->getNumResults() > 1 && llvm::any_of(op->getResultTypes(), [](Type t) {
-        return isa<xegpu::TensorDescType>(t);
-      })) {
-    LLVM_DEBUG(
-        DBGS() << op->getName()
-               << " op has more than one result and at least one is a tensor "
-                  "descriptor. This case is not handled.\n");
-    return failure();
+  // Iterate over all the operands.
+  for (OpOperand &operand : op->getOpOperands()) {
+    unsigned operandIndex = operand.getOperandNumber();
+    std::string operandLayoutName =
+        operandLayoutNamePrefix + std::to_string(operandIndex);
+    updateValue(operand.get(), operandIndex, operandLayoutName);
   }
-  // If the result is a tensor descriptor, attach the layout to the tensor
-  // descriptor itself.
-  if (auto tensorDescTy =
-          dyn_cast<xegpu::TensorDescType>(op->getResultTypes()[0])) {
-    xegpu::LayoutAttr layoutInfo = getLayoutAttrForValue(op->getResult(0));
-    if (!layoutInfo) {
-      LLVM_DEBUG(DBGS() << "No layout for result of " << *op << "\n");
-      return failure();
-    }
 
-    // Clone the op, attach the layout to the result tensor descriptor, and
-    // remove the original op.
-    OpBuilder builder(op);
-    Operation *newOp = builder.clone(*op);
-    auto newTensorDescTy = xegpu::TensorDescType::get(
-        tensorDescTy.getContext(), tensorDescTy.getShape(),
-        tensorDescTy.getElementType(), tensorDescTy.getEncoding(), layoutInfo);
-    newOp->getResult(0).setType(newTensorDescTy);
-    op->replaceAllUsesWith(newOp->getResults());
-    op->erase();
-    return success();
+  // Iterate over all the results.
+  for (OpResult result : op->getResults()) {
+    unsigned resultIndex = result.getResultNumber();
+    std::string resultLayoutName =
+        resultLayoutNamePrefix + std::to_string(resultIndex);
+    updateValue(result, resultIndex, resultLayoutName);
   }
-  // Otherwise simply attach the layout to the op itself.
-  for (auto [i, r] : llvm::enumerate(op->getResults())) {
-    xegpu::LayoutAttr layoutInfo = getLayoutAttrForValue(r);
-    if (layoutInfo) {
-      std::string attrName = resultLayoutNamePrefix + std::to_string(i);
-      op->setAttr(attrName, layoutInfo);
-      // Attach the layout attribute to the users of the result.
-      assignToUsers(r, layoutInfo);
-    }
-  }
-  return success();
 }
 
-/// Walk the IR and attach xegpu::LayoutAttr to all ops and their users.
-LogicalResult LayoutAttrAssignment::run() {
-  // auto walkResult = top->walk([&](Operation *op) {
-  //   if (failed(assign(op)))
-  //     return WalkResult::interrupt();
-  //   return WalkResult::advance();
-  // });
-
-  // if (walkResult.wasInterrupted())
-  //   return failure();
-  // apply the UpdateTensorDescType pattern to all ops
-  // RewritePatternSet patterns(top->getContext());
-  // patterns.add<UpdateTensorDescType>(
-  //     top->getContext(), [&](Value v) -> xegpu::LayoutAttr {
-  //       llvm::errs() << "invoking callback for value\n";
-  //       return getLayoutAttrForValue(v);
-  //     });
-  // if (failed(applyPatternsGreedily(top, std::move(patterns))))
-  //   return failure();
-
-  return resolveConflicts();
-}
-
-/// TODO: Implement the layout conflict resolution. This must ensure mainly two
-/// things:
-/// 1) Is a given layout supported by the op? (need to query the target
-///    HW info). Otherwise can we achieve this layout using a layout conversion?
-/// 2) Do all the operands have the required layout? If not, can it
-///    be resolved using a layout conversion?
-LogicalResult LayoutAttrAssignment::resolveConflicts() { return success(); }
-
 namespace {
 
 //===----------------------------------------------------------------------===//
@@ -1657,10 +1719,10 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
   // auto getPropagatedLayout = [&](Value val) {
   //   return analyis.getLayoutInfo(val);
   // };
-  auto getXeGpuLayoutForValue = [&](Value val) -> xegpu::LayoutAttr {
+  auto getXeGPULayoutForValue = [&](Value val) -> xegpu::LayoutAttr {
     LayoutInfo layout = analyis.getLayoutInfo(val);
     if (!layout.isAssigned()) {
-      llvm::errs() << "No layout assigned for value\n";
+      llvm::errs() << "No layout assigned for value" << val << "\n";
       return {};
     }
     SmallVector<int, 2> laneLayout, laneData;
@@ -1672,41 +1734,26 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
     return xegpu::LayoutAttr::get(val.getContext(), laneLayout, laneData);
   };
 
-  ConversionTarget target(getContext());
-  target.addDynamicallyLegalOp<xegpu::CreateNdDescOp, xegpu::UpdateNdOffsetOp>(
-      [&](Operation *op) {
-        return llvm::all_of(op->getResults(), [&](Value val) {
-          if (auto descType = dyn_cast<xegpu::TensorDescType>(val.getType())) {
-            return descType.getLayoutAttr() != nullptr;
-          }
-          return true; // Non-tensor descriptor types are always legal.
-        });
-      });
-  target.addLegalOp<UnrealizedConversionCastOp>();
-  TypeConverter typeConverter;
-  typeConverter.addConversion([](Type type) { return type; });
-  // // typeConverter.addConversion([](xegpu::TensorDescType type) {
-  // //   return xegpu::TensorDescType::get(
-  // //       type.getContext(), type.getShape(), type.getElementType(),
-  // //       type.getEncoding(),
-  // //       xegpu::LayoutAttr::get(type.getContext(), {1, 1}, {1, 1}));
-  // // });
-  auto addUnrealizedCast = [](OpBuilder &builder, Type type, ValueRange inputs,
-                              Location loc) -> Value {
-    auto cast = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
-    return cast.getResult(0);
-  };
+  mlir::OpBuilder builder(&getContext());
+  Operation *op = getOperation();
+  op->walk([&](mlir::Block *block) {
+    for (mlir::Operation &op : llvm::reverse(block->getOperations())) {
+      if (auto terminator =
+              mlir::dyn_cast<mlir::RegionBranchTerminatorOpInterface>(op)) {
+        handleBranchTerminatorOpInterface(builder, terminator,
+                                          getXeGPULayoutForValue);
+        continue;
+      }
 
-  typeConverter.addSourceMaterialization(addUnrealizedCast);
-  typeConverter.addTargetMaterialization(addUnrealizedCast);
+      if (auto iface = mlir::dyn_cast<mlir::RegionBranchOpInterface>(op)) {
+        handleBranchOpInterface(builder, iface, getXeGPULayoutForValue);
+        continue;
+      }
+      updateOp(builder, &op, getXeGPULayoutForValue);
+    }
 
-  RewritePatternSet patterns(&getContext());
-  patterns.add<UpdateTensorDescType<xegpu::CreateNdDescOp>,
-               UpdateTensorDescType<xegpu::UpdateNdOffsetOp>>(
-      &getContext(), getXeGpuLayoutForValue, typeConverter);
-  if (failed(
-          applyPartialConversion(getOperation(), target, std::move(patterns))))
-    signalPassFailure();
+    updateBlockTypes(builder, *block, getXeGPULayoutForValue);
+  });
 
   // Assign xegpu::LayoutAttr to all ops and their users based on the layout
   // propagation analysis result.

>From 7d54194f0c726db4461015de87abf9ad380bbfa3 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Tue, 3 Jun 2025 19:50:30 +0000
Subject: [PATCH 07/32] working version

---
 .../Transforms/XeGPUSubgroupDistribute.cpp    | 159 +++++++++++++-----
 1 file changed, 120 insertions(+), 39 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index aa982ae779d1e..6b3ff8312e365 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -40,6 +40,7 @@
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/Debug.h"
 #include "llvm/Support/FormatVariadic.h"
 #include "llvm/Support/InterleavedRange.h"
 #include "llvm/Support/LogicalResult.h"
@@ -905,59 +906,140 @@ void RunLayoutInfoPropagation::printAnalysisResult(llvm::raw_ostream &os) {
 // ///    be resolved using a layout conversion?
 // LogicalResult LayoutAttrAssignment::resolveConflicts() { return success(); }
 using GetLayoutCallbackFnTy = function_ref<xegpu::LayoutAttr(Value)>;
+static void updateOp(mlir::OpBuilder &builder, mlir::Operation *op,
+                     GetLayoutCallbackFnTy getLayoutOfValue) {
+  // Attach the propagated layout to every vector/tensor-descriptor result.
+  for (OpResult result : op->getResults()) {
+    Type resultType = result.getType();
+    // Layouts are needed only for vector and tensor descriptor types.
+    if (!isa<VectorType, xegpu::TensorDescType>(resultType))
+      continue;
+    // Skip results without a layout (TODO: treat as error if result is used).
+    xegpu::LayoutAttr layout = getLayoutOfValue(result);
+    if (!layout) {
+      if (!result.use_empty())
+        LLVM_DEBUG(DBGS() << "Expecting layout for result: " << result
+                          << " but got none.\n");
+      continue;
+    }
+    if (auto tensorDescTy = dyn_cast<xegpu::TensorDescType>(resultType)) {
+      // Tensor descriptors carry the layout in their type.
+      auto typeWithLayout = xegpu::TensorDescType::get(
+          tensorDescTy.getContext(), tensorDescTy.getShape(),
+          tensorDescTy.getElementType(), tensorDescTy.getEncoding(), layout);
+      result.setType(typeWithLayout);
+      continue;
+    }
+    // Vector result: record the (non-null) layout as a temporary attribute
+    // on the op.
+    std::string resultLayoutName =
+        resultLayoutNamePrefix + std::to_string(result.getResultNumber());
+    op->setAttr(resultLayoutName, layout);
+    // Update all users of the result with the layout.
+    for (OpOperand &user : result.getUses()) {
+      Operation *owner = user.getOwner();
+      unsigned operandNumber = user.getOperandNumber();
+      // Add temporary layout attribute at the user op.
+      std::string attrName =
+          operandLayoutNamePrefix + std::to_string(operandNumber);
+      owner->setAttr(attrName, layout);
+    }
+  }
+}
 static void handleBranchTerminatorOpInterface(
     mlir::OpBuilder &builder,
     mlir::RegionBranchTerminatorOpInterface terminator,
     GetLayoutCallbackFnTy getLayoutOfValue) {}
 static void handleBranchOpInterface(mlir::OpBuilder &builder,
                                     mlir::RegionBranchOpInterface branch,
-                                    GetLayoutCallbackFnTy getLayoutOfValue) {}
-static void updateBlockTypes(mlir::OpBuilder &builder, mlir::Block &block,
-                             GetLayoutCallbackFnTy getLayoutOfValue) {}
-static void updateOp(mlir::OpBuilder &builder, mlir::Operation *op,
-                     GetLayoutCallbackFnTy getLayoutOfValue) {
+                                    GetLayoutCallbackFnTy getLayoutOfValue) {
+  mlir::Operation *op = branch.getOperation();
+  llvm::SmallVector<mlir::RegionSuccessor> successors;
+  llvm::SmallVector<mlir::Attribute> operands(op->getNumOperands(), nullptr);
+  branch.getEntrySuccessorRegions(operands, successors);
+  DenseMap<Value, xegpu::LayoutAttr> resultToLayouts;
+  mlir::ValueRange results = op->getResults();
+  // Retype tensor-descriptor block arguments of each entry region.
+  for (mlir::RegionSuccessor &successor : successors) {
+    if (successor.isParent())
+      continue;
 
-  auto updateValue = [&](Value v, unsigned vIndex,
-                         const std::string &layoutAttrName) {
-    // Layouts are needed only for vector and tensor descriptor types.
-    if (!isa<VectorType, xegpu::TensorDescType>(v.getType()))
-      return;
-    xegpu::LayoutAttr layout = getLayoutOfValue(v);
+    mlir::OperandRange initArgs = branch.getEntrySuccessorOperands(successor);
+    mlir::ValueRange blockArgs = successor.getSuccessorInputs();
+    unsigned index = 0;
+
+    for (auto [initArg, blockArg, result] :
+         llvm::zip(initArgs, blockArgs, results)) {
+      Type inputType = blockArg.getType();
+      if (!isa<xegpu::TensorDescType>(inputType))
+        continue;
+      xegpu::LayoutAttr blockArgLayout = getLayoutOfValue(blockArg);
+      xegpu::LayoutAttr initArgLayout = getLayoutOfValue(initArg);
+
+      if (!blockArgLayout || !initArgLayout) {
+        LLVM_DEBUG(DBGS() << "No layout assigned for block arg: " << blockArg
+                          << " or init arg: " << initArg << "\n");
+        continue;
+      }
+
+      // TODO: We expect these two to match. Data flow analysis will ensure
+      // this.
+      assert(blockArgLayout == initArgLayout &&
+             "Expecting block arg and init arg to have the same layout.");
+      // Get tensor descriptor type with the layout.
+      auto tdescTy = dyn_cast<xegpu::TensorDescType>(inputType);
+      auto newTdescTy = xegpu::TensorDescType::get(
+          tdescTy.getContext(), tdescTy.getShape(), tdescTy.getElementType(),
+          tdescTy.getEncoding(), blockArgLayout);
+      blockArg.setType(newTdescTy);
+      // Store the layout for the result.
+      if (resultToLayouts.count(result) != 0 &&
+          resultToLayouts[result] != blockArgLayout) {
+        LLVM_DEBUG(DBGS() << "Conflicting layouts for result: " << result
+                          << " - " << resultToLayouts[result] << " vs "
+                          << blockArgLayout << "\n");
+      } else {
+        resultToLayouts[result] = blockArgLayout;
+      }
+    }
+  }
+  for (auto [i, r] : llvm::enumerate(op->getResults())) {
+    Type resultType = r.getType();
+    if (!isa<xegpu::TensorDescType, VectorType>(resultType))
+      continue;
+    xegpu::LayoutAttr layout = getLayoutOfValue(r);
+    if (!layout)
+      layout = resultToLayouts[r];
     if (!layout) {
-      // TODO : handle error.
-      LLVM_DEBUG(DBGS() << "Expecting layout for value: " << v
-                        << " but got none.\n");
-      return;
+      LLVM_DEBUG(DBGS() << "No layout assigned for vector/tensor desc result: "
+                        << r << "\n");
+      continue;
     }
-    auto tensorDescTy = dyn_cast<xegpu::TensorDescType>(v.getType());
-
-    if (tensorDescTy) {
-      auto newTensorDescTy = xegpu::TensorDescType::get(
+    if (auto tensorDescTy = dyn_cast<xegpu::TensorDescType>(resultType)) {
+      auto newTdescTy = xegpu::TensorDescType::get(
           tensorDescTy.getContext(), tensorDescTy.getShape(),
           tensorDescTy.getElementType(), tensorDescTy.getEncoding(), layout);
-      v.setType(newTensorDescTy);
-      return;
+      r.setType(newTdescTy);
+      continue;
     }
-    // If type is vector, add a temporary layout attribute to the op.
-    op->setAttr(layoutAttrName, layout);
-  };
-
-  // Iterate over all the operands.
-  for (OpOperand &operand : op->getOpOperands()) {
-    unsigned operandIndex = operand.getOperandNumber();
-    std::string operandLayoutName =
-        operandLayoutNamePrefix + std::to_string(operandIndex);
-    updateValue(operand.get(), operandIndex, operandLayoutName);
-  }
-
-  // Iterate over all the results.
-  for (OpResult result : op->getResults()) {
-    unsigned resultIndex = result.getResultNumber();
+    // If the result is a vector type, add a temporary layout attribute to the
+    // op.
     std::string resultLayoutName =
-        resultLayoutNamePrefix + std::to_string(resultIndex);
-    updateValue(result, resultIndex, resultLayoutName);
+        resultLayoutNamePrefix + std::to_string(r.getResultNumber());
+    op->setAttr(resultLayoutName, layout);
+    // Update all users of the result with the layout.
+    for (OpOperand &user : r.getUses()) {
+      Operation *owner = user.getOwner();
+      unsigned operandNumber = user.getOperandNumber();
+      // Add temporary layout attribute at the user op.
+      std::string attrName =
+          operandLayoutNamePrefix + std::to_string(operandNumber);
+      owner->setAttr(attrName, layout);
+    }
   }
 }
+static void updateBlockTypes(mlir::OpBuilder &builder, mlir::Block &block,
+                             GetLayoutCallbackFnTy getLayoutOfValue) {}
 
 namespace {
 
@@ -1722,7 +1804,6 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
   auto getXeGPULayoutForValue = [&](Value val) -> xegpu::LayoutAttr {
     LayoutInfo layout = analyis.getLayoutInfo(val);
     if (!layout.isAssigned()) {
-      llvm::errs() << "No layout assigned for value" << val << "\n";
       return {};
     }
     SmallVector<int, 2> laneLayout, laneData;

>From b289399e44bf56e91149cbfc37a729c14949c4d2 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Tue, 3 Jun 2025 21:36:11 +0000
Subject: [PATCH 08/32] working except for unrealized cast

---
 .../Transforms/XeGPUSubgroupDistribute.cpp    | 97 +++++++++----------
 1 file changed, 46 insertions(+), 51 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 6b3ff8312e365..dfb7b0668d2be 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -1291,11 +1291,14 @@ struct CreateNdDescDistribution final : public gpu::WarpDistributionPattern {
     xegpu::TensorDescType distributedTensorDescTy =
         descOp.getType().dropLayouts(); // Distributed tensor descriptor type
                                         // does not contain layout info.
-    auto newDescOp = rewriter.create<xegpu::CreateNdDescOp>(
+    Value newDescOp = rewriter.create<xegpu::CreateNdDescOp>(
         newWarpOp.getLoc(), distributedTensorDescTy, newDescOperands,
         descOp->getAttrs());
 
     Value distributedVal = newWarpOp.getResult(operandIdx);
+    // Resolve the distributed type to the expected type.
+    newDescOp =
+        resolveDistributedTy(newDescOp, distributedVal.getType(), rewriter);
     rewriter.replaceAllUsesWith(distributedVal, newDescOp);
     return success();
   }
@@ -1697,10 +1700,13 @@ struct UpdateNdOffsetDistribution final : public gpu::WarpDistributionPattern {
       }
     }
     // Create a new update op outside the warp op.
-    auto newUpdateOp = rewriter.create<xegpu::UpdateNdOffsetOp>(
+    Value newUpdateOp = rewriter.create<xegpu::UpdateNdOffsetOp>(
         newWarpOp.getLoc(), newTensorDescTy, newUpdateOperands,
         removeTemporaryLayoutAttributes(updateOp->getAttrs()));
     Value distributedVal = newWarpOp.getResult(operandIdx);
+    // Resolve the distributed type with the original type.
+    newUpdateOp =
+        resolveDistributedTy(newUpdateOp, distributedVal.getType(), rewriter);
     rewriter.replaceAllUsesWith(distributedVal, newUpdateOp);
     return success();
   }
@@ -1836,55 +1842,44 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
     updateBlockTypes(builder, *block, getXeGPULayoutForValue);
   });
 
-  // Assign xegpu::LayoutAttr to all ops and their users based on the layout
-  // propagation analysis result.
-  // LayoutAttrAssignment layoutAssignment(getOperation(), getPropagatedLayout);
-  // if (failed(layoutAssignment.run())) {
-  //   signalPassFailure();
-  //   return;
-  // }
-
   // Move all operations of a GPU function inside gpu.warp_execute_on_lane_0
   // operation.
-  // {
-  //   RewritePatternSet patterns(&getContext());
-  //   patterns.add<MoveFuncBodyToWarpExecuteOnLane0>(&getContext());
-
-  //   if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
-  //     signalPassFailure();
-  //     return;
-  //   }
-  //   // At this point, we have moved the entire function body inside the
-  //   warpOp.
-  //   // Now move any scalar uniform code outside of the warpOp (like GPU index
-  //   // ops, scalar constants, etc.). This will simplify the later lowering
-  //   and
-  //   // avoid custom patterns for these ops.
-  //   getOperation()->walk([&](Operation *op) {
-  //     if (auto warpOp = dyn_cast<gpu::WarpExecuteOnLane0Op>(op)) {
-  //       vector::moveScalarUniformCode(warpOp);
-  //     }
-  //   });
-  // }
-  // // Finally, do the SIMD to SIMT distribution.
-  // RewritePatternSet patterns(&getContext());
-  // xegpu::populateXeGPUSubgroupDistributePatterns(patterns);
-  // // TODO: distributionFn and shuffleFn are not used at this point.
-  // auto distributionFn = [](Value val) {
-  //   VectorType vecType = dyn_cast<VectorType>(val.getType());
-  //   int64_t vecRank = vecType ? vecType.getRank() : 0;
-  //   OpBuilder builder(val.getContext());
-  //   if (vecRank == 0)
-  //     return AffineMap::get(val.getContext());
-  //   return AffineMap::getMultiDimIdentityMap(vecRank, val.getContext());
-  // };
-  // auto shuffleFn = [](Location loc, OpBuilder &builder, Value val, Value
-  // srcIdx,
-  //                     int64_t warpSz) { return Value(); };
-  // vector::populatePropagateWarpVectorDistributionPatterns(
-  //     patterns, distributionFn, shuffleFn);
-  // if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
-  //   signalPassFailure();
-  //   return;
-  // }
+  {
+    RewritePatternSet patterns(&getContext());
+    patterns.add<MoveFuncBodyToWarpExecuteOnLane0>(&getContext());
+
+    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
+      signalPassFailure();
+      return;
+    }
+    // At this point, we have moved the entire function body inside the
+    // warpOp. Now move any scalar uniform code outside of the warpOp (like GPU
+    // index ops, scalar constants, etc.). This will simplify the later lowering
+    // and avoid custom patterns for these ops.
+    getOperation()->walk([&](Operation *op) {
+      if (auto warpOp = dyn_cast<gpu::WarpExecuteOnLane0Op>(op)) {
+        vector::moveScalarUniformCode(warpOp);
+      }
+    });
+  }
+  // Finally, do the SIMD to SIMT distribution.
+  RewritePatternSet patterns(&getContext());
+  xegpu::populateXeGPUSubgroupDistributePatterns(patterns);
+  // TODO: distributionFn and shuffleFn are not used at this point.
+  auto distributionFn = [](Value val) {
+    VectorType vecType = dyn_cast<VectorType>(val.getType());
+    int64_t vecRank = vecType ? vecType.getRank() : 0;
+    OpBuilder builder(val.getContext());
+    if (vecRank == 0)
+      return AffineMap::get(val.getContext());
+    return AffineMap::getMultiDimIdentityMap(vecRank, val.getContext());
+  };
+  auto shuffleFn = [](Location loc, OpBuilder &builder, Value val, Value srcIdx,
+                      int64_t warpSz) { return Value(); };
+  vector::populatePropagateWarpVectorDistributionPatterns(
+      patterns, distributionFn, shuffleFn);
+  if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
+    signalPassFailure();
+    return;
+  }
 }

>From 4318343ead59cda8f70741ca45e9255a6ce66bba Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Tue, 3 Jun 2025 22:57:01 +0000
Subject: [PATCH 09/32] some fixes

---
 .../Transforms/XeGPUSubgroupDistribute.cpp    | 70 ++++++++++++++++---
 1 file changed, 60 insertions(+), 10 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index dfb7b0668d2be..56ec1eaa118e5 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -68,8 +68,14 @@ constexpr unsigned packedSizeInBitsForDefault =
     16; // Minimum packing size per register for DPAS A.
 constexpr unsigned packedSizeInBitsForDpasB =
     32; // Minimum packing size per register for DPAS B.
-static const char *const operandLayoutNamePrefix = "layout_operand_";
-static const char *const resultLayoutNamePrefix = "layout_result_";
+static const char *const operandLayoutNamePrefix =
+    "layout_operand_"; // Attribute-name prefix for operand layouts.
+static const char *const resultLayoutNamePrefix =
+    "layout_result_"; // Attribute-name prefix for result layouts.
+static const char *const resolveSIMTTypeMismatch =
+    "resolve_simt_type_mismatch"; // Attribute name for identifying
+                                  // UnrealizedConversionCastOps inserted
+                                  // to resolve SIMT type mismatches.
 
 namespace {
 
@@ -946,11 +952,11 @@ static void updateOp(mlir::OpBuilder &builder, mlir::Operation *op,
     }
   }
 }
-static void handleBranchTerminatorOpInterface(
+static void updateBranchTerminatorOpInterface(
     mlir::OpBuilder &builder,
     mlir::RegionBranchTerminatorOpInterface terminator,
     GetLayoutCallbackFnTy getLayoutOfValue) {}
-static void handleBranchOpInterface(mlir::OpBuilder &builder,
+static void updateBranchOpInterface(mlir::OpBuilder &builder,
                                     mlir::RegionBranchOpInterface branch,
                                     GetLayoutCallbackFnTy getLayoutOfValue) {
   mlir::Operation *op = branch.getOperation();
@@ -966,7 +972,6 @@ static void handleBranchOpInterface(mlir::OpBuilder &builder,
 
     mlir::OperandRange initArgs = branch.getEntrySuccessorOperands(successor);
     mlir::ValueRange blockArgs = successor.getSuccessorInputs();
-    unsigned index = 0;
 
     for (auto [initArg, blockArg, result] :
          llvm::zip(initArgs, blockArgs, results)) {
@@ -1117,6 +1122,7 @@ static Value resolveDistributedTy(Value orig, T expected,
   if (isa<xegpu::TensorDescType>(orig.getType())) {
     auto castOp = rewriter.create<UnrealizedConversionCastOp>(orig.getLoc(),
                                                               expected, orig);
+    castOp->setAttr(resolveSIMTTypeMismatch, rewriter.getUnitAttr());
     return castOp.getResult(0);
   }
   llvm_unreachable("Unsupported type for reconciliation");
@@ -1804,9 +1810,7 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
     analyis.printAnalysisResult(os);
     return;
   }
-  // auto getPropagatedLayout = [&](Value val) {
-  //   return analyis.getLayoutInfo(val);
-  // };
+
   auto getXeGPULayoutForValue = [&](Value val) -> xegpu::LayoutAttr {
     LayoutInfo layout = analyis.getLayoutInfo(val);
     if (!layout.isAssigned()) {
@@ -1827,13 +1831,13 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
     for (mlir::Operation &op : llvm::reverse(block->getOperations())) {
       if (auto terminator =
               mlir::dyn_cast<mlir::RegionBranchTerminatorOpInterface>(op)) {
-        handleBranchTerminatorOpInterface(builder, terminator,
+        updateBranchTerminatorOpInterface(builder, terminator,
                                           getXeGPULayoutForValue);
         continue;
       }
 
       if (auto iface = mlir::dyn_cast<mlir::RegionBranchOpInterface>(op)) {
-        handleBranchOpInterface(builder, iface, getXeGPULayoutForValue);
+        updateBranchOpInterface(builder, iface, getXeGPULayoutForValue);
         continue;
       }
       updateOp(builder, &op, getXeGPULayoutForValue);
@@ -1882,4 +1886,50 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
     signalPassFailure();
     return;
   }
+
+  // Clean up UnrealizedConversionCastOps that were inserted due to tensor desc
+  // type mismatches created by using upstream distribution patterns (scf.for)
+  getOperation()->walk([&](mlir::UnrealizedConversionCastOp op) {
+    // We are only interested in UnrealizedConversionCastOps there were added
+    // for resolving SIMT type mismatches.
+    if (!op->getAttr(resolveSIMTTypeMismatch))
+      return WalkResult::skip();
+
+    Value input = op.getOperand(0);
+    Value output = op.getResult(0);
+
+    // Both input and output must have tensor descriptor types.
+    xegpu::TensorDescType inputDescType =
+        mlir::dyn_cast<xegpu::TensorDescType>(input.getType());
+    xegpu::TensorDescType outputDescType =
+        mlir::dyn_cast<xegpu::TensorDescType>(output.getType());
+    assert(inputDescType && outputDescType &&
+           "Unrealized conversion cast must have tensor descriptor types");
+
+    // tensor_desc<shape, layout> -> tensor_desc<shape> Type of conversions.
+    // This occurs iside scf.for body to resolve the block argument type to SIMT
+    // type.
+    if (inputDescType.getLayout()) {
+      auto argument = mlir::dyn_cast<mlir::BlockArgument>(input);
+      if (argument) {
+        argument.setType(output.getType());
+        output.replaceAllUsesWith(argument);
+        if (auto loopOp = mlir::dyn_cast<mlir::LoopLikeOpInterface>(
+                argument.getOwner()->getParentOp())) {
+          auto result = loopOp.getTiedLoopResult(argument);
+          result.setType(output.getType());
+        }
+      }
+    }
+
+    // tensor_desc<shape> -> tensor_desc<shape, layout> Type of
+    // conversions. This occurs at the yield op of scf.for body to go back from
+    // SIMT type to original type.
+    if (outputDescType.getLayout())
+      output.replaceAllUsesWith(input);
+
+    if (op->use_empty())
+      op->erase();
+    return WalkResult::advance();
+  });
 }

>From 20a641545534132b59c934d7bc31b6c088134605 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Wed, 4 Jun 2025 00:01:15 +0000
Subject: [PATCH 10/32] branch terminator iface

---
 .../Transforms/XeGPUSubgroupDistribute.cpp    | 332 ++++++++++--------
 1 file changed, 193 insertions(+), 139 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 56ec1eaa118e5..27d912b87c6dc 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -955,7 +955,54 @@ static void updateOp(mlir::OpBuilder &builder, mlir::Operation *op,
 static void updateBranchTerminatorOpInterface(
     mlir::OpBuilder &builder,
     mlir::RegionBranchTerminatorOpInterface terminator,
-    GetLayoutCallbackFnTy getLayoutOfValue) {}
+    GetLayoutCallbackFnTy getLayoutOfValue) {
+  if (!mlir::isa<mlir::RegionBranchOpInterface>(terminator->getParentOp()))
+    return;
+
+  llvm::SmallVector<mlir::RegionSuccessor> successors;
+  llvm::SmallVector<mlir::Attribute> operands(terminator->getNumOperands(),
+                                              nullptr);
+  terminator.getSuccessorRegions(operands, successors);
+
+  for (mlir::RegionSuccessor &successor : successors) {
+    if (!successor.isParent())
+      continue;
+
+    mlir::OperandRange operands = terminator.getSuccessorOperands(successor);
+    mlir::ValueRange inputs = successor.getSuccessorInputs();
+    for (auto [operand, input] : llvm::zip(operands, inputs)) {
+      // print arg and inp
+      // llvm::errs() << "arg: " << operand << ", inp: " << input << "\n";
+      Type inputType = input.getType();
+      if (!isa<xegpu::TensorDescType>(inputType))
+        continue;
+      xegpu::LayoutAttr inputLayout = getLayoutOfValue(input);
+      xegpu::LayoutAttr operandLayout = getLayoutOfValue(operand);
+
+      if (!operandLayout) {
+        LLVM_DEBUG(DBGS() << "Expecting layout for region successor operand : "
+                          << operand << " but got none.\n");
+        continue;
+      }
+
+      if (inputLayout && inputLayout != operandLayout) {
+        LLVM_DEBUG(
+            DBGS()
+            << "Conflicting layouts for region successor operand and input: "
+            << inputLayout << " vs " << operandLayout << "\n");
+        continue;
+      }
+      llvm::errs() << "Setting layout for input to "
+                   << ": " << operandLayout << "\n";
+      // Get tensor descriptor type with the layout.
+      auto tdescTy = dyn_cast<xegpu::TensorDescType>(inputType);
+      auto newTdescTy = xegpu::TensorDescType::get(
+          tdescTy.getContext(), tdescTy.getShape(), tdescTy.getElementType(),
+          tdescTy.getEncoding(), operandLayout);
+      input.setType(newTdescTy);
+    }
+  }
+}
 static void updateBranchOpInterface(mlir::OpBuilder &builder,
                                     mlir::RegionBranchOpInterface branch,
                                     GetLayoutCallbackFnTy getLayoutOfValue) {
@@ -970,20 +1017,19 @@ static void updateBranchOpInterface(mlir::OpBuilder &builder,
     if (successor.isParent())
       continue;
 
-    mlir::OperandRange initArgs = branch.getEntrySuccessorOperands(successor);
-    mlir::ValueRange blockArgs = successor.getSuccessorInputs();
+    mlir::OperandRange operands = branch.getEntrySuccessorOperands(successor);
+    mlir::ValueRange inputs = successor.getSuccessorInputs();
 
-    for (auto [initArg, blockArg, result] :
-         llvm::zip(initArgs, blockArgs, results)) {
-      Type inputType = blockArg.getType();
+    for (auto [operand, input, result] : llvm::zip(operands, inputs, results)) {
+      Type inputType = input.getType();
       if (!isa<xegpu::TensorDescType>(inputType))
         continue;
-      xegpu::LayoutAttr blockArgLayout = getLayoutOfValue(blockArg);
-      xegpu::LayoutAttr initArgLayout = getLayoutOfValue(initArg);
+      xegpu::LayoutAttr blockArgLayout = getLayoutOfValue(input);
+      xegpu::LayoutAttr initArgLayout = getLayoutOfValue(operand);
 
       if (!blockArgLayout || !initArgLayout) {
-        LLVM_DEBUG(DBGS() << "No layout assigned for block arg: " << blockArg
-                          << " or init arg: " << initArg << "\n");
+        LLVM_DEBUG(DBGS() << "No layout assigned for block arg: " << input
+                          << " or init arg: " << operand << "\n");
         continue;
       }
 
@@ -996,52 +1042,54 @@ static void updateBranchOpInterface(mlir::OpBuilder &builder,
       auto newTdescTy = xegpu::TensorDescType::get(
           tdescTy.getContext(), tdescTy.getShape(), tdescTy.getElementType(),
           tdescTy.getEncoding(), blockArgLayout);
-      blockArg.setType(newTdescTy);
+      input.setType(newTdescTy);
       // Store the layout for the result.
-      if (resultToLayouts.count(result) != 0 &&
-          resultToLayouts[result] != blockArgLayout) {
-        LLVM_DEBUG(DBGS() << "Conflicting layouts for result: " << result
-                          << " - " << resultToLayouts[result] << " vs "
-                          << blockArgLayout << "\n");
-      } else {
-        resultToLayouts[result] = blockArgLayout;
-      }
-    }
-  }
-  for (auto [i, r] : llvm::enumerate(op->getResults())) {
-    Type resultType = r.getType();
-    if (!isa<xegpu::TensorDescType, VectorType>(resultType))
-      continue;
-    xegpu::LayoutAttr layout = getLayoutOfValue(r);
-    if (!layout)
-      layout = resultToLayouts[r];
-    if (!layout) {
-      LLVM_DEBUG(DBGS() << "No layout assigned for vector/tensor desc result: "
-                        << r << "\n");
-      continue;
-    }
-    if (auto tensorDescTy = dyn_cast<xegpu::TensorDescType>(resultType)) {
-      auto newTdescTy = xegpu::TensorDescType::get(
-          tensorDescTy.getContext(), tensorDescTy.getShape(),
-          tensorDescTy.getElementType(), tensorDescTy.getEncoding(), layout);
-      r.setType(newTdescTy);
-      continue;
-    }
-    // If the result is a vector type, add a temporary layout attribute to the
-    // op.
-    std::string resultLayoutName =
-        resultLayoutNamePrefix + std::to_string(r.getResultNumber());
-    op->setAttr(resultLayoutName, layout);
-    // Update all users of the result with the layout.
-    for (OpOperand &user : r.getUses()) {
-      Operation *owner = user.getOwner();
-      unsigned operandNumber = user.getOperandNumber();
-      // Add temporary layout attribute at the user op.
-      std::string attrName =
-          operandLayoutNamePrefix + std::to_string(operandNumber);
-      owner->setAttr(attrName, layout);
+      // if (resultToLayouts.count(result) != 0 &&
+      //     resultToLayouts[result] != blockArgLayout) {
+      //   LLVM_DEBUG(DBGS() << "Conflicting layouts for result: " << result
+      //                     << " - " << resultToLayouts[result] << " vs "
+      //                     << blockArgLayout << "\n");
+      // } else {
+      //   resultToLayouts[result] = blockArgLayout;
+      // }
     }
   }
+  // for (auto [i, r] : llvm::enumerate(op->getResults())) {
+  //   Type resultType = r.getType();
+  //   if (!isa<xegpu::TensorDescType, VectorType>(resultType))
+  //     continue;
+  //   xegpu::LayoutAttr layout = getLayoutOfValue(r);
+  //   if (!layout)
+  //     layout = resultToLayouts[r];
+  //   if (!layout) {
+  //     LLVM_DEBUG(DBGS() << "No layout assigned for vector/tensor desc result:
+  //     "
+  //                       << r << "\n");
+  //     continue;
+  //   }
+  //   if (auto tensorDescTy = dyn_cast<xegpu::TensorDescType>(resultType)) {
+  //     auto newTdescTy = xegpu::TensorDescType::get(
+  //         tensorDescTy.getContext(), tensorDescTy.getShape(),
+  //         tensorDescTy.getElementType(), tensorDescTy.getEncoding(), layout);
+  //     r.setType(newTdescTy);
+  //     continue;
+  //   }
+  //   // If the result is a vector type, add a temporary layout attribute to
+  //   the
+  //   // op.
+  //   std::string resultLayoutName =
+  //       resultLayoutNamePrefix + std::to_string(r.getResultNumber());
+  //   op->setAttr(resultLayoutName, layout);
+  //   // Update all users of the result with the layout.
+  //   for (OpOperand &user : r.getUses()) {
+  //     Operation *owner = user.getOwner();
+  //     unsigned operandNumber = user.getOperandNumber();
+  //     // Add temporary layout attribute at the user op.
+  //     std::string attrName =
+  //         operandLayoutNamePrefix + std::to_string(operandNumber);
+  //     owner->setAttr(attrName, layout);
+  //   }
+  // }
 }
 static void updateBlockTypes(mlir::OpBuilder &builder, mlir::Block &block,
                              GetLayoutCallbackFnTy getLayoutOfValue) {}
@@ -1846,90 +1894,96 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
     updateBlockTypes(builder, *block, getXeGPULayoutForValue);
   });
 
-  // Move all operations of a GPU function inside gpu.warp_execute_on_lane_0
-  // operation.
-  {
-    RewritePatternSet patterns(&getContext());
-    patterns.add<MoveFuncBodyToWarpExecuteOnLane0>(&getContext());
-
-    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
-      signalPassFailure();
-      return;
-    }
-    // At this point, we have moved the entire function body inside the
-    // warpOp. Now move any scalar uniform code outside of the warpOp (like GPU
-    // index ops, scalar constants, etc.). This will simplify the later lowering
-    // and avoid custom patterns for these ops.
-    getOperation()->walk([&](Operation *op) {
-      if (auto warpOp = dyn_cast<gpu::WarpExecuteOnLane0Op>(op)) {
-        vector::moveScalarUniformCode(warpOp);
-      }
-    });
-  }
-  // Finally, do the SIMD to SIMT distribution.
-  RewritePatternSet patterns(&getContext());
-  xegpu::populateXeGPUSubgroupDistributePatterns(patterns);
-  // TODO: distributionFn and shuffleFn are not used at this point.
-  auto distributionFn = [](Value val) {
-    VectorType vecType = dyn_cast<VectorType>(val.getType());
-    int64_t vecRank = vecType ? vecType.getRank() : 0;
-    OpBuilder builder(val.getContext());
-    if (vecRank == 0)
-      return AffineMap::get(val.getContext());
-    return AffineMap::getMultiDimIdentityMap(vecRank, val.getContext());
-  };
-  auto shuffleFn = [](Location loc, OpBuilder &builder, Value val, Value srcIdx,
-                      int64_t warpSz) { return Value(); };
-  vector::populatePropagateWarpVectorDistributionPatterns(
-      patterns, distributionFn, shuffleFn);
-  if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
-    signalPassFailure();
-    return;
-  }
-
-  // Clean up UnrealizedConversionCastOps that were inserted due to tensor desc
-  // type mismatches created by using upstream distribution patterns (scf.for)
-  getOperation()->walk([&](mlir::UnrealizedConversionCastOp op) {
-    // We are only interested in UnrealizedConversionCastOps there were added
-    // for resolving SIMT type mismatches.
-    if (!op->getAttr(resolveSIMTTypeMismatch))
-      return WalkResult::skip();
-
-    Value input = op.getOperand(0);
-    Value output = op.getResult(0);
-
-    // Both input and output must have tensor descriptor types.
-    xegpu::TensorDescType inputDescType =
-        mlir::dyn_cast<xegpu::TensorDescType>(input.getType());
-    xegpu::TensorDescType outputDescType =
-        mlir::dyn_cast<xegpu::TensorDescType>(output.getType());
-    assert(inputDescType && outputDescType &&
-           "Unrealized conversion cast must have tensor descriptor types");
-
-    // tensor_desc<shape, layout> -> tensor_desc<shape> Type of conversions.
-    // This occurs iside scf.for body to resolve the block argument type to SIMT
-    // type.
-    if (inputDescType.getLayout()) {
-      auto argument = mlir::dyn_cast<mlir::BlockArgument>(input);
-      if (argument) {
-        argument.setType(output.getType());
-        output.replaceAllUsesWith(argument);
-        if (auto loopOp = mlir::dyn_cast<mlir::LoopLikeOpInterface>(
-                argument.getOwner()->getParentOp())) {
-          auto result = loopOp.getTiedLoopResult(argument);
-          result.setType(output.getType());
-        }
-      }
-    }
-
-    // tensor_desc<shape> -> tensor_desc<shape, layout> Type of
-    // conversions. This occurs at the yield op of scf.for body to go back from
-    // SIMT type to original type.
-    if (outputDescType.getLayout())
-      output.replaceAllUsesWith(input);
-
-    if (op->use_empty())
-      op->erase();
-    return WalkResult::advance();
-  });
+  // // Move all operations of a GPU function inside gpu.warp_execute_on_lane_0
+  // // operation.
+  // {
+  //   RewritePatternSet patterns(&getContext());
+  //   patterns.add<MoveFuncBodyToWarpExecuteOnLane0>(&getContext());
+
+  //   if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
+  //     signalPassFailure();
+  //     return;
+  //   }
+  //   // At this point, we have moved the entire function body inside the
+  //   // warpOp. Now move any scalar uniform code outside of the warpOp (like
+  //   GPU
+  //   // index ops, scalar constants, etc.). This will simplify the later
+  //   lowering
+  //   // and avoid custom patterns for these ops.
+  //   getOperation()->walk([&](Operation *op) {
+  //     if (auto warpOp = dyn_cast<gpu::WarpExecuteOnLane0Op>(op)) {
+  //       vector::moveScalarUniformCode(warpOp);
+  //     }
+  //   });
+  // }
+  // // Finally, do the SIMD to SIMT distribution.
+  // RewritePatternSet patterns(&getContext());
+  // xegpu::populateXeGPUSubgroupDistributePatterns(patterns);
+  // // TODO: distributionFn and shuffleFn are not used at this point.
+  // auto distributionFn = [](Value val) {
+  //   VectorType vecType = dyn_cast<VectorType>(val.getType());
+  //   int64_t vecRank = vecType ? vecType.getRank() : 0;
+  //   OpBuilder builder(val.getContext());
+  //   if (vecRank == 0)
+  //     return AffineMap::get(val.getContext());
+  //   return AffineMap::getMultiDimIdentityMap(vecRank, val.getContext());
+  // };
+  // auto shuffleFn = [](Location loc, OpBuilder &builder, Value val, Value
+  // srcIdx,
+  //                     int64_t warpSz) { return Value(); };
+  // vector::populatePropagateWarpVectorDistributionPatterns(
+  //     patterns, distributionFn, shuffleFn);
+  // if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
+  //   signalPassFailure();
+  //   return;
+  // }
+
+  // // Clean up UnrealizedConversionCastOps that were inserted due to tensor
+  // desc
+  // // type mismatches created by using upstream distribution patterns
+  // (scf.for) getOperation()->walk([&](mlir::UnrealizedConversionCastOp op) {
+  //   // We are only interested in UnrealizedConversionCastOps there were added
+  //   // for resolving SIMT type mismatches.
+  //   if (!op->getAttr(resolveSIMTTypeMismatch))
+  //     return WalkResult::skip();
+
+  //   Value input = op.getOperand(0);
+  //   Value output = op.getResult(0);
+
+  //   // Both input and output must have tensor descriptor types.
+  //   xegpu::TensorDescType inputDescType =
+  //       mlir::dyn_cast<xegpu::TensorDescType>(input.getType());
+  //   xegpu::TensorDescType outputDescType =
+  //       mlir::dyn_cast<xegpu::TensorDescType>(output.getType());
+  //   assert(inputDescType && outputDescType &&
+  //          "Unrealized conversion cast must have tensor descriptor types");
+
+  //   // tensor_desc<shape, layout> -> tensor_desc<shape> Type of conversions.
+  //   // This occurs iside scf.for body to resolve the block argument type to
+  //   SIMT
+  //   // type.
+  //   if (inputDescType.getLayout()) {
+  //     auto argument = mlir::dyn_cast<mlir::BlockArgument>(input);
+  //     if (argument) {
+  //       argument.setType(output.getType());
+  //       output.replaceAllUsesWith(argument);
+  //       if (auto loopOp = mlir::dyn_cast<mlir::LoopLikeOpInterface>(
+  //               argument.getOwner()->getParentOp())) {
+  //         auto result = loopOp.getTiedLoopResult(argument);
+  //         result.setType(output.getType());
+  //       }
+  //     }
+  //   }
+
+  //   // tensor_desc<shape> -> tensor_desc<shape, layout> Type of
+  //   // conversions. This occurs at the yield op of scf.for body to go back
+  //   from
+  //   // SIMT type to original type.
+  //   if (outputDescType.getLayout())
+  //     output.replaceAllUsesWith(input);
+
+  //   if (op->use_empty())
+  //     op->erase();
+  //   return WalkResult::advance();
+  // });
 }

>From 7bd0be22d02e14f2ca4c5530b8a14e6b18781803 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Wed, 4 Jun 2025 15:17:27 +0000
Subject: [PATCH 11/32] save work

---
 .../Transforms/XeGPUSubgroupDistribute.cpp    | 24 +++++++++----------
 1 file changed, 12 insertions(+), 12 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 27d912b87c6dc..b997af37a072b 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -938,18 +938,18 @@ static void updateOp(mlir::OpBuilder &builder, mlir::Operation *op,
     }
     // If the result is a vector type, add a temporary layout attribute to the
     // op.
-    std::string resultLayoutName =
-        resultLayoutNamePrefix + std::to_string(result.getResultNumber());
-    op->setAttr(resultLayoutName, layout);
-    // Update all users of the result with the layout.
-    for (OpOperand &user : result.getUses()) {
-      Operation *owner = user.getOwner();
-      unsigned operandNumber = user.getOperandNumber();
-      // Add temorary layout attribute at the user op.
-      std::string attrName =
-          operandLayoutNamePrefix + std::to_string(operandNumber);
-      owner->setAttr(attrName, layout);
-    }
+    // std::string resultLayoutName =
+    //     resultLayoutNamePrefix + std::to_string(result.getResultNumber());
+    // op->setAttr(resultLayoutName, layout);
+    // // Update all users of the result with the layout.
+    // for (OpOperand &user : result.getUses()) {
+    //   Operation *owner = user.getOwner();
+    //   unsigned operandNumber = user.getOperandNumber();
+    //   // Add temorary layout attribute at the user op.
+    //   std::string attrName =
+    //       operandLayoutNamePrefix + std::to_string(operandNumber);
+    //   owner->setAttr(attrName, layout);
+    // }
   }
 }
 static void updateBranchTerminatorOpInterface(

>From 00dc2b67a925ac79d9dc6bee5bf4a167217304eb Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Wed, 4 Jun 2025 22:51:37 +0000
Subject: [PATCH 12/32] working

---
 .../Transforms/XeGPUSubgroupDistribute.cpp    | 295 +++++++++---------
 .../Dialect/XeGPU/subgroup-distribution.mlir  |  98 +++---
 2 files changed, 195 insertions(+), 198 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index b997af37a072b..a17c8d8a4f3f3 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -938,18 +938,18 @@ static void updateOp(mlir::OpBuilder &builder, mlir::Operation *op,
     }
     // If the result is a vector type, add a temporary layout attribute to the
     // op.
-    // std::string resultLayoutName =
-    //     resultLayoutNamePrefix + std::to_string(result.getResultNumber());
-    // op->setAttr(resultLayoutName, layout);
-    // // Update all users of the result with the layout.
-    // for (OpOperand &user : result.getUses()) {
-    //   Operation *owner = user.getOwner();
-    //   unsigned operandNumber = user.getOperandNumber();
-    //   // Add temorary layout attribute at the user op.
-    //   std::string attrName =
-    //       operandLayoutNamePrefix + std::to_string(operandNumber);
-    //   owner->setAttr(attrName, layout);
-    // }
+    std::string resultLayoutName =
+        resultLayoutNamePrefix + std::to_string(result.getResultNumber());
+    op->setAttr(resultLayoutName, layout);
+    // Update all users of the result with the layout.
+    for (OpOperand &user : result.getUses()) {
+      Operation *owner = user.getOwner();
+      unsigned operandNumber = user.getOperandNumber();
+      // Add temorary layout attribute at the user op.
+      std::string attrName =
+          operandLayoutNamePrefix + std::to_string(operandNumber);
+      owner->setAttr(attrName, layout);
+    }
   }
 }
 static void updateBranchTerminatorOpInterface(
@@ -992,8 +992,6 @@ static void updateBranchTerminatorOpInterface(
             << inputLayout << " vs " << operandLayout << "\n");
         continue;
       }
-      llvm::errs() << "Setting layout for input to "
-                   << ": " << operandLayout << "\n";
       // Get tensor descriptor type with the layout.
       auto tdescTy = dyn_cast<xegpu::TensorDescType>(inputType);
       auto newTdescTy = xegpu::TensorDescType::get(
@@ -1044,55 +1042,51 @@ static void updateBranchOpInterface(mlir::OpBuilder &builder,
           tdescTy.getEncoding(), blockArgLayout);
       input.setType(newTdescTy);
       // Store the layout for the result.
-      // if (resultToLayouts.count(result) != 0 &&
-      //     resultToLayouts[result] != blockArgLayout) {
-      //   LLVM_DEBUG(DBGS() << "Conflicting layouts for result: " << result
-      //                     << " - " << resultToLayouts[result] << " vs "
-      //                     << blockArgLayout << "\n");
-      // } else {
-      //   resultToLayouts[result] = blockArgLayout;
-      // }
+      if (resultToLayouts.count(result) != 0 &&
+          resultToLayouts[result] != blockArgLayout) {
+        LLVM_DEBUG(DBGS() << "Conflicting layouts for result: " << result
+                          << " - " << resultToLayouts[result] << " vs "
+                          << blockArgLayout << "\n");
+      } else {
+        resultToLayouts[result] = blockArgLayout;
+      }
+    }
+  }
+  for (auto [i, r] : llvm::enumerate(op->getResults())) {
+    Type resultType = r.getType();
+    if (!isa<xegpu::TensorDescType, VectorType>(resultType))
+      continue;
+    xegpu::LayoutAttr layout = getLayoutOfValue(r);
+    if (!layout)
+      layout = resultToLayouts[r];
+    if (!layout) {
+      LLVM_DEBUG(DBGS() << "No layout assigned for vector/tensor desc result:"
+                        << r << "\n");
+      continue;
+    }
+    if (auto tensorDescTy = dyn_cast<xegpu::TensorDescType>(resultType)) {
+      auto newTdescTy = xegpu::TensorDescType::get(
+          tensorDescTy.getContext(), tensorDescTy.getShape(),
+          tensorDescTy.getElementType(), tensorDescTy.getEncoding(), layout);
+      r.setType(newTdescTy);
+      continue;
+    }
+    // If the result is a vector type, add a temporary layout attribute to
+    // the op.
+    std::string resultLayoutName =
+        resultLayoutNamePrefix + std::to_string(r.getResultNumber());
+    op->setAttr(resultLayoutName, layout);
+    // Update all users of the result with the layout.
+    for (OpOperand &user : r.getUses()) {
+      Operation *owner = user.getOwner();
+      unsigned operandNumber = user.getOperandNumber();
+      // Add temporary layout attribute at the user op.
+      std::string attrName =
+          operandLayoutNamePrefix + std::to_string(operandNumber);
+      owner->setAttr(attrName, layout);
     }
   }
-  // for (auto [i, r] : llvm::enumerate(op->getResults())) {
-  //   Type resultType = r.getType();
-  //   if (!isa<xegpu::TensorDescType, VectorType>(resultType))
-  //     continue;
-  //   xegpu::LayoutAttr layout = getLayoutOfValue(r);
-  //   if (!layout)
-  //     layout = resultToLayouts[r];
-  //   if (!layout) {
-  //     LLVM_DEBUG(DBGS() << "No layout assigned for vector/tensor desc result:
-  //     "
-  //                       << r << "\n");
-  //     continue;
-  //   }
-  //   if (auto tensorDescTy = dyn_cast<xegpu::TensorDescType>(resultType)) {
-  //     auto newTdescTy = xegpu::TensorDescType::get(
-  //         tensorDescTy.getContext(), tensorDescTy.getShape(),
-  //         tensorDescTy.getElementType(), tensorDescTy.getEncoding(), layout);
-  //     r.setType(newTdescTy);
-  //     continue;
-  //   }
-  //   // If the result is a vector type, add a temporary layout attribute to
-  //   the
-  //   // op.
-  //   std::string resultLayoutName =
-  //       resultLayoutNamePrefix + std::to_string(r.getResultNumber());
-  //   op->setAttr(resultLayoutName, layout);
-  //   // Update all users of the result with the layout.
-  //   for (OpOperand &user : r.getUses()) {
-  //     Operation *owner = user.getOwner();
-  //     unsigned operandNumber = user.getOperandNumber();
-  //     // Add temporary layout attribute at the user op.
-  //     std::string attrName =
-  //         operandLayoutNamePrefix + std::to_string(operandNumber);
-  //     owner->setAttr(attrName, layout);
-  //   }
-  // }
 }
-static void updateBlockTypes(mlir::OpBuilder &builder, mlir::Block &block,
-                             GetLayoutCallbackFnTy getLayoutOfValue) {}
 
 namespace {
 
@@ -1890,100 +1884,93 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
       }
       updateOp(builder, &op, getXeGPULayoutForValue);
     }
-
-    updateBlockTypes(builder, *block, getXeGPULayoutForValue);
   });
 
-  // // Move all operations of a GPU function inside gpu.warp_execute_on_lane_0
-  // // operation.
-  // {
-  //   RewritePatternSet patterns(&getContext());
-  //   patterns.add<MoveFuncBodyToWarpExecuteOnLane0>(&getContext());
-
-  //   if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
-  //     signalPassFailure();
-  //     return;
-  //   }
-  //   // At this point, we have moved the entire function body inside the
-  //   // warpOp. Now move any scalar uniform code outside of the warpOp (like
-  //   GPU
-  //   // index ops, scalar constants, etc.). This will simplify the later
-  //   lowering
-  //   // and avoid custom patterns for these ops.
-  //   getOperation()->walk([&](Operation *op) {
-  //     if (auto warpOp = dyn_cast<gpu::WarpExecuteOnLane0Op>(op)) {
-  //       vector::moveScalarUniformCode(warpOp);
-  //     }
-  //   });
-  // }
-  // // Finally, do the SIMD to SIMT distribution.
-  // RewritePatternSet patterns(&getContext());
-  // xegpu::populateXeGPUSubgroupDistributePatterns(patterns);
-  // // TODO: distributionFn and shuffleFn are not used at this point.
-  // auto distributionFn = [](Value val) {
-  //   VectorType vecType = dyn_cast<VectorType>(val.getType());
-  //   int64_t vecRank = vecType ? vecType.getRank() : 0;
-  //   OpBuilder builder(val.getContext());
-  //   if (vecRank == 0)
-  //     return AffineMap::get(val.getContext());
-  //   return AffineMap::getMultiDimIdentityMap(vecRank, val.getContext());
-  // };
-  // auto shuffleFn = [](Location loc, OpBuilder &builder, Value val, Value
-  // srcIdx,
-  //                     int64_t warpSz) { return Value(); };
-  // vector::populatePropagateWarpVectorDistributionPatterns(
-  //     patterns, distributionFn, shuffleFn);
-  // if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
-  //   signalPassFailure();
-  //   return;
-  // }
-
-  // // Clean up UnrealizedConversionCastOps that were inserted due to tensor
-  // desc
-  // // type mismatches created by using upstream distribution patterns
-  // (scf.for) getOperation()->walk([&](mlir::UnrealizedConversionCastOp op) {
-  //   // We are only interested in UnrealizedConversionCastOps there were added
-  //   // for resolving SIMT type mismatches.
-  //   if (!op->getAttr(resolveSIMTTypeMismatch))
-  //     return WalkResult::skip();
-
-  //   Value input = op.getOperand(0);
-  //   Value output = op.getResult(0);
-
-  //   // Both input and output must have tensor descriptor types.
-  //   xegpu::TensorDescType inputDescType =
-  //       mlir::dyn_cast<xegpu::TensorDescType>(input.getType());
-  //   xegpu::TensorDescType outputDescType =
-  //       mlir::dyn_cast<xegpu::TensorDescType>(output.getType());
-  //   assert(inputDescType && outputDescType &&
-  //          "Unrealized conversion cast must have tensor descriptor types");
-
-  //   // tensor_desc<shape, layout> -> tensor_desc<shape> Type of conversions.
-  //   // This occurs iside scf.for body to resolve the block argument type to
-  //   SIMT
-  //   // type.
-  //   if (inputDescType.getLayout()) {
-  //     auto argument = mlir::dyn_cast<mlir::BlockArgument>(input);
-  //     if (argument) {
-  //       argument.setType(output.getType());
-  //       output.replaceAllUsesWith(argument);
-  //       if (auto loopOp = mlir::dyn_cast<mlir::LoopLikeOpInterface>(
-  //               argument.getOwner()->getParentOp())) {
-  //         auto result = loopOp.getTiedLoopResult(argument);
-  //         result.setType(output.getType());
-  //       }
-  //     }
-  //   }
-
-  //   // tensor_desc<shape> -> tensor_desc<shape, layout> Type of
-  //   // conversions. This occurs at the yield op of scf.for body to go back
-  //   from
-  //   // SIMT type to original type.
-  //   if (outputDescType.getLayout())
-  //     output.replaceAllUsesWith(input);
-
-  //   if (op->use_empty())
-  //     op->erase();
-  //   return WalkResult::advance();
-  // });
+  // Move all operations of a GPU function inside gpu.warp_execute_on_lane_0
+  // operation.
+  {
+    RewritePatternSet patterns(&getContext());
+    patterns.add<MoveFuncBodyToWarpExecuteOnLane0>(&getContext());
+
+    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
+      signalPassFailure();
+      return;
+    }
+    // At this point, we have moved the entire function body inside the
+    // warpOp. Now move any scalar uniform code outside of the warpOp (like
+    // GPU index ops, scalar constants, etc.). This will simplify the
+    // later lowering and avoid custom patterns for these ops.
+    getOperation()->walk([&](Operation *op) {
+      if (auto warpOp = dyn_cast<gpu::WarpExecuteOnLane0Op>(op)) {
+        vector::moveScalarUniformCode(warpOp);
+      }
+    });
+  }
+  // Finally, do the SIMD to SIMT distribution.
+  RewritePatternSet patterns(&getContext());
+  xegpu::populateXeGPUSubgroupDistributePatterns(patterns);
+  // TODO: distributionFn and shuffleFn are not used at this point.
+  auto distributionFn = [](Value val) {
+    VectorType vecType = dyn_cast<VectorType>(val.getType());
+    int64_t vecRank = vecType ? vecType.getRank() : 0;
+    OpBuilder builder(val.getContext());
+    if (vecRank == 0)
+      return AffineMap::get(val.getContext());
+    return AffineMap::getMultiDimIdentityMap(vecRank, val.getContext());
+  };
+  auto shuffleFn = [](Location loc, OpBuilder &builder, Value val, Value srcIdx,
+                      int64_t warpSz) { return Value(); };
+  vector::populatePropagateWarpVectorDistributionPatterns(
+      patterns, distributionFn, shuffleFn);
+  if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
+    signalPassFailure();
+    return;
+  }
+
+  // Clean up UnrealizedConversionCastOps that were inserted due to tensor
+  // desc type mismatches created by using upstream distribution patterns
+  // (scf.for)
+  getOperation()->walk([&](mlir::UnrealizedConversionCastOp op) {
+    // We are only interested in UnrealizedConversionCastOps that were added
+    // for resolving SIMT type mismatches.
+    if (!op->getAttr(resolveSIMTTypeMismatch))
+      return WalkResult::skip();
+
+    Value input = op.getOperand(0);
+    Value output = op.getResult(0);
+
+    // Both input and output must have tensor descriptor types.
+    xegpu::TensorDescType inputDescType =
+        mlir::dyn_cast<xegpu::TensorDescType>(input.getType());
+    xegpu::TensorDescType outputDescType =
+        mlir::dyn_cast<xegpu::TensorDescType>(output.getType());
+    assert(inputDescType && outputDescType &&
+           "Unrealized conversion cast must have tensor descriptor types");
+
+    // tensor_desc<shape, layout> -> tensor_desc<shape> Type of conversions.
+    // This occurs inside the scf.for body to resolve the block argument type to
+    // SIMT type.
+    if (inputDescType.getLayout()) {
+      auto argument = mlir::dyn_cast<mlir::BlockArgument>(input);
+      if (argument) {
+        argument.setType(output.getType());
+        output.replaceAllUsesWith(argument);
+        if (auto loopOp = mlir::dyn_cast<mlir::LoopLikeOpInterface>(
+                argument.getOwner()->getParentOp())) {
+          auto result = loopOp.getTiedLoopResult(argument);
+          result.setType(output.getType());
+        }
+      }
+    }
+
+    // tensor_desc<shape> -> tensor_desc<shape, layout> Type of
+    // conversions. This occurs at the yield op of scf.for body to go back
+    // from SIMT type to original type.
+    if (outputDescType.getLayout())
+      output.replaceAllUsesWith(input);
+
+    if (op->use_empty())
+      op->erase();
+    return WalkResult::advance();
+  });
 }
diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribution.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribution.mlir
index e5606c5642505..b5f6bda26d830 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribution.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribution.mlir
@@ -93,49 +93,54 @@ gpu.func @load_nd_array_length(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16
 }
 
 // -----
-// CHECK-LABEL: gpu.func @dpas
-// CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: vector<8x16xf16>, %[[ARG1:[0-9a-zA-Z]+]]: vector<16x16xf16>, %[[ARG2:[0-9a-zA-Z]+]]: vector<8x16xf32>, %[[ARG3:[0-9a-zA-Z]+]]: memref<8x16xf32>) {
-// CHECK: %[[T1:.*]]:3 = gpu.warp_execute_on_lane_0(%{{.*}})[16] args(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]]
-// CHECK-SAME: vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32>, memref<8x16xf32>) -> (vector<8x1xf16>, vector<16x1xf16>, vector<8x1xf32>) {
-// CHECK: ^bb0(%[[ARG4:[0-9a-zA-Z]+]]: vector<8x16xf16>, %[[ARG5:[0-9a-zA-Z]+]]: vector<16x16xf16>, %[[ARG6:[0-9a-zA-Z]+]]: vector<8x16xf32>, %[[ARG7:[0-9a-zA-Z]+]]: memref<8x16xf32>):
-// CHECK:  gpu.yield %[[ARG4]], %[[ARG5]], %[[ARG6]] : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32>
-// CHECK: }
-// CHECK-DAG: %[[T2:.*]] = vector.shape_cast %[[T1]]#0 : vector<8x1xf16> to vector<8xf16>
-// CHECK-DAG: %[[T3:.*]] = vector.shape_cast %[[T1]]#1 : vector<16x1xf16> to vector<16xf16>
-// CHECK-DAG: %[[T4:.*]] = vector.shape_cast %[[T1]]#2 : vector<8x1xf32> to vector<8xf32>
-// CHECK: %[[T5:.*]] = xegpu.dpas %[[T2]], %[[T3]], %[[T4]] : vector<8xf16>, vector<16xf16>, vector<8xf32> -> vector<8xf32>
-// CHECK: %[[T6:.*]] = xegpu.create_nd_tdesc %[[ARG3]][%{{.*}}] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
-// CHECK: xegpu.store_nd %[[T5]], %[[T6]]  : vector<8xf32>, !xegpu.tensor_desc<8x16xf32>
+// CHECK-LABEL: gpu.func @load_dpas_store
+// CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: memref<8x16xf16>, %[[ARG1:[0-9a-zA-Z]+]]: memref<16x16xf16>, %[[ARG2:[0-9a-zA-Z]+]]: memref<8x16xf32>) {
+// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG1]][%{{.*}}] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
+// CHECK: %[[T1:.*]] = xegpu.load_nd %[[T0]] <{packed}> : !xegpu.tensor_desc<16x16xf16> -> vector<16xf16>
+// CHECK: %[[T2:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%{{.*}}] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
+// CHECK: %[[T3:.*]] = xegpu.load_nd %[[T2]]  : !xegpu.tensor_desc<8x16xf16> -> vector<8xf16>
+// CHECK-DAG: %[[T4:.*]] = xegpu.dpas %[[T3]], %[[T1]] : vector<8xf16>, vector<16xf16> -> vector<8xf32>
+// CHECK-DAG: %[[T5:.*]] = xegpu.create_nd_tdesc %[[ARG2]][%{{.*}}] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+// CHECK: xegpu.store_nd %[[T4]], %[[T5]]  : vector<8xf32>, !xegpu.tensor_desc<8x16xf32>
 gpu.module @test {
-gpu.func @dpas(%arg0: vector<8x16xf16>, %arg1: vector<16x16xf16>, %arg3: vector<8x16xf32>, %arg2: memref<8x16xf32>){
+gpu.func @load_dpas_store(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg3: memref<8x16xf32>){
   %c0 = arith.constant 0 : index
-  %0 = xegpu.dpas %arg0, %arg1, %arg3 : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
-  %3 = xegpu.create_nd_tdesc %arg2[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
-  xegpu.store_nd %0, %3 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+  %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
+  %1 = xegpu.load_nd %0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+  %2 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
+  %3 = xegpu.load_nd %2 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+  %4 = xegpu.dpas %1, %3 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+  %5 = xegpu.create_nd_tdesc %arg3[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+  xegpu.store_nd %4, %5 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
   gpu.return
 }
 }
 
+
 // -----
-// CHECK-LABEL: gpu.func @load_dpas_store
+// CHECK-LABEL: gpu.func @load_dpas_postop_store
 // CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: memref<8x16xf16>, %[[ARG1:[0-9a-zA-Z]+]]: memref<16x16xf16>, %[[ARG2:[0-9a-zA-Z]+]]: memref<8x16xf32>) {
 // CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG1]][%{{.*}}] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
 // CHECK: %[[T1:.*]] = xegpu.load_nd %[[T0]] <{packed}> : !xegpu.tensor_desc<16x16xf16> -> vector<16xf16>
 // CHECK: %[[T2:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%{{.*}}] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
 // CHECK: %[[T3:.*]] = xegpu.load_nd %[[T2]]  : !xegpu.tensor_desc<8x16xf16> -> vector<8xf16>
 // CHECK-DAG: %[[T4:.*]] = xegpu.dpas %[[T3]], %[[T1]] : vector<8xf16>, vector<16xf16> -> vector<8xf32>
-// CHECK-DAG: %[[T5:.*]] = xegpu.create_nd_tdesc %[[ARG2]][%{{.*}}] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
-// CHECK: xegpu.store_nd %[[T4]], %[[T5]]  : vector<8xf32>, !xegpu.tensor_desc<8x16xf32>
+// CHECK: %[[T5:.*]] = vector.shape_cast %[[T4]] : vector<8xf32> to vector<8x1xf32>
+// CHECK: %[[T6:.*]] = math.exp %[[T5]] {{{.*}}} : vector<8x1xf32>
+// CHECK-DAG: %[[T8:.*]] = vector.shape_cast %[[T6]] : vector<8x1xf32> to vector<8xf32>
+// CHECK-DAG: %[[T7:.*]] = xegpu.create_nd_tdesc %[[ARG2]][{{.*}}] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+// CHECK: xegpu.store_nd %[[T8]], %[[T7]]  : vector<8xf32>, !xegpu.tensor_desc<8x16xf32>
 gpu.module @test {
-gpu.func @load_dpas_store(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg3: memref<8x16xf32>){
+gpu.func @load_dpas_postop_store(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg3: memref<8x16xf32>){
   %c0 = arith.constant 0 : index
   %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
   %1 = xegpu.load_nd %0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
   %2 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
   %3 = xegpu.load_nd %2 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
   %4 = xegpu.dpas %1, %3 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
-  %5 = xegpu.create_nd_tdesc %arg3[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
-  xegpu.store_nd %4, %5 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+  %5 = math.exp %4 : vector<8x16xf32>
+  %6 = xegpu.create_nd_tdesc %arg3[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+  xegpu.store_nd %5, %6 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
   gpu.return
 }
 }
@@ -169,20 +174,22 @@ gpu.func @create_nd_tdesc_non_memref(%arg0: ui64, %arg1: ui64,
 // CHECK: %[[Y_COORD:.*]] = arith.muli %[[BLOCK_ID_Y]], %c16 : index
 // CHECK: %[[X_COORD:.*]] = arith.muli %[[BLOCK_ID_X]], %c8 : index
 // CHECK: %[[T2:.*]] = xegpu.create_nd_tdesc %[[ARG2]][%[[X_COORD]], %[[Y_COORD]]] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32>
-// CHECK: %[[T3:.*]] = xegpu.load_nd %[[T2]] : !xegpu.tensor_desc<8x16xf32> -> vector<8xf32>
-// CHECK: %[[T4:.*]] = vector.shape_cast %[[T3]] : vector<8xf32> to vector<8x1xf32>
-// CHECK: %[[T5:.*]] = scf.for %[[K:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG4:.*]] = %[[T4]]) -> (vector<8x1xf32>) {
-// CHECK: %[[T10:.*]] = xegpu.create_nd_tdesc %[[ARG1]][%[[K]], %[[Y_COORD]]] : memref<1024x1024xbf16> -> !xegpu.tensor_desc<16x16xbf16>
-// CHECK: %[[T11:.*]] = xegpu.load_nd %[[T10]] <{packed}> : !xegpu.tensor_desc<16x16xbf16> -> vector<16xbf16>
-// CHECK: %[[T12:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%[[X_COORD]], %[[K]]] : memref<1024x1024xbf16> -> !xegpu.tensor_desc<8x16xbf16>
-// CHECK: %[[T13:.*]] = xegpu.load_nd %[[T12]] : !xegpu.tensor_desc<8x16xbf16> -> vector<8xbf16>
-// CHECK: %[[T14:.*]] = vector.shape_cast %[[ARG4]] : vector<8x1xf32> to vector<8xf32>
-// CHECK: %[[T15:.*]] = xegpu.dpas %[[T13]], %[[T11]], %[[T14]] : vector<8xbf16>, vector<16xbf16>, vector<8xf32> -> vector<8xf32>
-// CHECK: %[[T16:.*]] = vector.shape_cast %[[T15]] : vector<8xf32> to vector<8x1xf32>
-// CHECK: scf.yield %[[T16]] : vector<8x1xf32>
-// CHECK: }
-// CHECK: %[[T9:.*]] = vector.shape_cast %[[T5]] : vector<8x1xf32> to vector<8xf32>
-// CHECK: xegpu.store_nd %[[T9]], %[[T2]] : vector<8xf32>, !xegpu.tensor_desc<8x16xf32>
+// CHECK: %[[T3:.*]] = xegpu.load_nd %[[T2]]  : !xegpu.tensor_desc<8x16xf32> -> vector<8xf32>
+// CHECK-DAG: %[[C_INIT:.*]] = vector.shape_cast %[[T3]] : vector<8xf32> to vector<8x1xf32>
+// CHECK-DAG: %[[B_TILE:.*]] = xegpu.create_nd_tdesc %[[ARG1]][%{{.*}}, %[[Y_COORD]]] : memref<1024x1024xbf16> -> !xegpu.tensor_desc<16x16xbf16>
+// CHECK-DAG: %[[A_TILE:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%[[X_COORD]], %{{.*}}] : memref<1024x1024xbf16> -> !xegpu.tensor_desc<8x16xbf16>
+// CHECK: %[[T7:.*]]:3 = scf.for {{.*}} iter_args(%[[C_VAL:.*]] = %[[C_INIT]], %[[A_ARG:.*]] = %[[A_TILE]], %[[B_ARG:.*]] = %[[B_TILE]]) -> (vector<8x1xf32>, !xegpu.tensor_desc<8x16xbf16>, !xegpu.tensor_desc<16x16xbf16>) {
+// CHECK-DAG: %[[B_NEXT:.*]] = xegpu.update_nd_offset %[[B_ARG]], [{{.*}}] : !xegpu.tensor_desc<16x16xbf16>
+// CHECK-DAG: %[[A_NEXT:.*]] = xegpu.update_nd_offset %[[A_ARG]], [{{.*}}] : !xegpu.tensor_desc<8x16xbf16>
+// CHECK-DAG: %[[B:.*]] = xegpu.load_nd %[[B_ARG]] <{packed}> : !xegpu.tensor_desc<16x16xbf16> -> vector<16xbf16>
+// CHECK-DAG: %[[A:.*]] = xegpu.load_nd %[[A_ARG]]  : !xegpu.tensor_desc<8x16xbf16> -> vector<8xbf16>
+// CHECK-DAG: %[[C:.*]] = vector.shape_cast %[[C_VAL]] : vector<8x1xf32> to vector<8xf32>
+// CHECK-NEXT: %[[T8:.*]] = xegpu.dpas %[[A]], %[[B]], %[[C]] : vector<8xbf16>, vector<16xbf16>, vector<8xf32> -> vector<8xf32>
+// CHECK-NEXT: %[[C_OUT:.*]] = vector.shape_cast %[[T8]] : vector<8xf32> to vector<8x1xf32>
+// CHECK-NEXT: scf.yield %[[C_OUT]], %[[A_NEXT]], %[[B_NEXT]] : vector<8x1xf32>, !xegpu.tensor_desc<8x16xbf16>, !xegpu.tensor_desc<16x16xbf16>
+// CHECK-NEXT:}
+// CHECK-NEXT: %[[C_FINAL:.*]] = vector.shape_cast %[[T7]]#0 : vector<8x1xf32> to vector<8xf32>
+// CHECK-NEXT: xegpu.store_nd %[[C_FINAL]], %[[T2]]  : vector<8xf32>, !xegpu.tensor_desc<8x16xf32>
 gpu.module @test {
 gpu.func @gemm_loop(%arg0: memref<1024x1024xbf16>, %arg1: memref<1024x1024xbf16>, %arg2: memref<1024x1024xf32>){
   %c0 = arith.constant 0 : index
@@ -195,15 +202,18 @@ gpu.func @gemm_loop(%arg0: memref<1024x1024xbf16>, %arg1: memref<1024x1024xbf16>
   %3 = arith.muli %1, %c16 : index
   %4 = xegpu.create_nd_tdesc %arg2[%2, %3] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32>
   %5 = xegpu.load_nd %4 : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
-  %6 = scf.for %arg3 = %c0 to %c1024 step %c16 iter_args(%arg4 = %5) -> (vector<8x16xf32>) {
-    %7 = xegpu.create_nd_tdesc %arg0[%2, %arg3] : memref<1024x1024xbf16> -> !xegpu.tensor_desc<8x16xbf16>
-    %8 = xegpu.create_nd_tdesc %arg1[%arg3, %3] : memref<1024x1024xbf16> -> !xegpu.tensor_desc<16x16xbf16>
-    %9 = xegpu.load_nd %7 : !xegpu.tensor_desc<8x16xbf16> -> vector<8x16xbf16>
-    %10 = xegpu.load_nd %8 : !xegpu.tensor_desc<16x16xbf16> -> vector<16x16xbf16>
+  %7 = xegpu.create_nd_tdesc %arg0[%2, %c0] : memref<1024x1024xbf16> -> !xegpu.tensor_desc<8x16xbf16>
+  %8 = xegpu.create_nd_tdesc %arg1[%c0, %3] : memref<1024x1024xbf16> -> !xegpu.tensor_desc<16x16xbf16>
+  %6:3 = scf.for %arg3 = %c0 to %c1024 step %c16 iter_args(%arg4 = %5, %arg5 = %7, %arg6 = %8) -> (vector<8x16xf32>, !xegpu.tensor_desc<8x16xbf16>, !xegpu.tensor_desc<16x16xbf16>) {
+    %9 = xegpu.load_nd %arg5 : !xegpu.tensor_desc<8x16xbf16> -> vector<8x16xbf16>
+    %10 = xegpu.load_nd %arg6 : !xegpu.tensor_desc<16x16xbf16> -> vector<16x16xbf16>
+    %12 = xegpu.update_nd_offset %arg5, [%c0, %c16] : !xegpu.tensor_desc<8x16xbf16>
+    %13 = xegpu.update_nd_offset %arg6, [%c16, %c0] : !xegpu.tensor_desc<16x16xbf16>
     %11 = xegpu.dpas %9, %10, %arg4 : vector<8x16xbf16>, vector<16x16xbf16>, vector<8x16xf32> -> vector<8x16xf32>
-    scf.yield %11 : vector<8x16xf32>
+    scf.yield %11, %12, %13 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xbf16>, !xegpu.tensor_desc<16x16xbf16>
   }
-  xegpu.store_nd %6, %4 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+  %12 = xegpu.create_nd_tdesc %arg2[%2, %3] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32>
+  xegpu.store_nd %6#0, %12 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
   gpu.return
 }
 }

>From 35620ec131462b97239409a984d792455289a32e Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Thu, 5 Jun 2025 15:39:43 +0000
Subject: [PATCH 13/32] move out layout prop

---
 .../mlir/Dialect/XeGPU/Transforms/Passes.td   |   12 +
 .../Dialect/XeGPU/Transforms/CMakeLists.txt   |    1 +
 .../XeGPU/Transforms/XeGPULayoutPropagate.cpp |  920 ++++++++++++++
 .../Transforms/XeGPUSubgroupDistribute.cpp    | 1052 -----------------
 4 files changed, 933 insertions(+), 1052 deletions(-)
 create mode 100644 mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutPropagate.cpp

diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
index 6f585f9ceb29b..08e02f295a851 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
@@ -33,6 +33,18 @@ def XeGPUSubgroupDistribute : Pass<"xegpu-subgroup-distribute"> {
       "Print the result of the subgroup map propagation analysis and exit.">];
 }
 
+def XeGPULayoutPropagate : Pass<"xegpu-layout-propagate"> {
+  let summary = "Propagate XeGPU layout information";
+  let description = [{
+    This pass propagates the XeGPU layout information across ops. Starting
+    from a set of anchor operations (e.g. `dpas`, `store_nd`), this will
+    propagate the layouts required for operands and results to the producers or
+    consumers.
+  }];
+  let dependentDialects = ["memref::MemRefDialect", "xegpu::XeGPUDialect",
+                           "vector::VectorDialect"];
+}
+
 def XeGPUWgToSgDistribute : Pass<"xegpu-wg-to-sg-distribute"> {
   let summary = "Transform WorkGroup level XeGPU code to SubGroup level";
   let description = [{
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
index 7d9b5584b0b2b..a72be9cd60b9c 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIRXeGPUTransforms
   XeGPUSubgroupDistribute.cpp
   XeGPUUnroll.cpp
   XeGPUWgToSgDistribute.cpp
+  XeGPULayoutPropagate.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/XeGPU
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutPropagate.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutPropagate.cpp
new file mode 100644
index 0000000000000..f308d338b511a
--- /dev/null
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutPropagate.cpp
@@ -0,0 +1,920 @@
+//===- XeGPULayoutPropagate.cpp - XeGPU Layout Propagation ------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
+#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
+#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
+#include "mlir/Analysis/DataFlowFramework.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/Value.h"
+#include "mlir/IR/Visitors.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/InterleavedRange.h"
+#include "llvm/Support/raw_ostream.h"
+
+namespace mlir {
+namespace xegpu {
+#define GEN_PASS_DEF_XEGPULAYOUTPROPAGATE
+#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
+} // namespace xegpu
+} // namespace mlir
+
+#define DEBUG_TYPE "xegpu-layout-propagate"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+
+using namespace mlir;
+using namespace mlir::dataflow;
+
+/// HW dependent constants.
+/// TODO: These constants should be queried from the target information.
+constexpr unsigned subgroupSize = 16; // How many lanes in a subgroup.
+/// If DPAS A or B operands have low precision element types they must be packed
+/// according to the following sizes.
+constexpr unsigned packedSizeInBitsForDefault =
+    16; // Minimum packing size per register for DPAS A.
+constexpr unsigned packedSizeInBitsForDpasB =
+    32; // Minimum packing size per register for DPAS B.
+static const char *const operandLayoutNamePrefix =
+    "layout_operand_"; // Attribute name for identifying operand layouts.
+static const char *const resultLayoutNamePrefix =
+    "layout_result_"; // Attribute name for identifying result layouts.
+static const char *const resolveSIMTTypeMismatch =
+    "resolve_simt_type_mismatch"; // Attribute name for identifying
+                                  // UnrealizedConversionCastOp added to resolve
+                                  // SIMT type mismatches.
+
+namespace {
+
+//===----------------------------------------------------------------------===//
+// Layout
+//===----------------------------------------------------------------------===//
+
+/// Helper class to store the ND layout of lanes within a subgroup and data
+/// owned by each lane.
+struct Layout {
+  SmallVector<int64_t, 3> layout;
+  Layout() = default;
+  Layout(std::initializer_list<int64_t> list) : layout(list) {}
+  void print(llvm::raw_ostream &os) const;
+  size_t size() const { return layout.size(); }
+  int64_t operator[](size_t idx) const;
+};
+
+void Layout::print(llvm::raw_ostream &os) const {
+  os << llvm::interleaved_array(layout);
+}
+
+int64_t Layout::operator[](size_t idx) const {
+  assert(idx < layout.size() && "Index out of bounds.");
+  return layout[idx];
+}
+
+/// LaneLayout represents the logical layout of lanes within a subgroup when it
+/// accesses some value. LaneData represents the logical layout of data owned by
+/// each work item.
+using LaneLayout = Layout;
+using LaneData = Layout;
+
+//===----------------------------------------------------------------------===//
+// LayoutInfo
+//===----------------------------------------------------------------------===//
+
+/// Helper class for tracking the analysis state of an mlir value. For layout
+/// propagation, the analysis state is simply the lane_layout and lane_data of
+/// each value. The purpose of this analysis is to propagate a unique layout
+/// for each value in the program starting from a set of anchor operations
+/// (like DPAS, StoreNd, etc.).
+///
+/// Given this, LayoutInfo satisfies the following properties:
+///  1) A LayoutInfo value can be in one of two states - `assigned` or `not
+///  assigned`.
+///  2) Two LayoutInfo values are equal if they are both assigned or
+///  both not assigned. The concrete value of assigned state does not matter.
+///  3) The meet operator works as follows:
+///     - If current state is assigned, return the current state. (already
+///     a unique layout is assigned. don't change it)
+///     - Otherwise, return the other state.
+
+struct LayoutInfo {
+private:
+  LaneLayout laneLayout;
+  LaneData laneData;
+
+public:
+  LayoutInfo() = default;
+  LayoutInfo(const LaneLayout &layout, const LaneData &data)
+      : laneLayout(layout), laneData(data) {}
+
+  // Two lattice values are equal if they have `some` layout. The actual
+  // content of the layout does not matter.
+  bool operator==(const LayoutInfo &other) const {
+    return this->isAssigned() == other.isAssigned();
+  }
+
+  static LayoutInfo meet(const LayoutInfo &lhs, const LayoutInfo &rhs);
+
+  static LayoutInfo join(const LayoutInfo &lhs, const LayoutInfo &rhs);
+
+  void print(raw_ostream &os) const;
+
+  bool isAssigned() const {
+    return laneLayout.size() > 0 && laneData.size() > 0;
+  }
+
+  LayoutInfo getTransposedLayout(ArrayRef<int64_t> permutation) const;
+
+  const LaneLayout &getLayout() const { return laneLayout; }
+  const LaneData &getData() const { return laneData; }
+  ArrayRef<int64_t> getLayoutAsArrayRef() const { return laneLayout.layout; }
+  ArrayRef<int64_t> getDataAsArrayRef() const { return laneData.layout; }
+};
+
+void LayoutInfo::print(raw_ostream &os) const {
+  if (isAssigned()) {
+    os << "lane_layout: ";
+    laneLayout.print(os);
+    os << ", lane_data: ";
+    laneData.print(os);
+  } else {
+    os << "Not assigned.";
+  }
+}
+
+LayoutInfo LayoutInfo::meet(const LayoutInfo &lhs, const LayoutInfo &rhs) {
+  if (!lhs.isAssigned())
+    return rhs;
+  return lhs;
+}
+
+/// Since this is a backward analysis, join method is not used.
+LayoutInfo LayoutInfo::join(const LayoutInfo &lhs, const LayoutInfo &rhs) {
+  llvm_unreachable("Join should not be triggered by layout propagation.");
+}
+
+/// Get the transposed layout according to the given permutation.
+LayoutInfo
+LayoutInfo::getTransposedLayout(ArrayRef<int64_t> permutation) const {
+  if (!isAssigned())
+    return {};
+  LaneLayout newLayout;
+  LaneData newData;
+  for (int64_t idx : permutation) {
+    newLayout.layout.push_back(laneLayout.layout[idx]);
+    newData.layout.push_back(laneData.layout[idx]);
+  }
+  return LayoutInfo(newLayout, newData);
+}
+
+//===----------------------------------------------------------------------===//
+// LayoutInfoLattice
+//===----------------------------------------------------------------------===//
+
+/// Lattice holding the LayoutInfo for each value.
+struct LayoutInfoLattice : public Lattice<LayoutInfo> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LayoutInfoLattice)
+  using Lattice::Lattice;
+};
+
+/// Helper Functions to get default layouts. A `default layout` is a layout that
+/// is assigned to a value when the layout is not fixed by some anchor operation
+/// (like DPAS).
+
+/// Helper Function to get the default layout for uniform values like constants.
+/// For 1D vector, lane_layout is [subgroupSize] and lane_data is [1].
+/// For 2D vector, lane_layout is [1, subgroupSize] and lane_data is [1, 1].
+static LayoutInfo getDefaultLayoutInfo(unsigned rank) {
+  assert((rank == 1 || rank == 2) && "Expected 1D or 2D vector.");
+  if (rank == 1)
+    return LayoutInfo(LaneLayout({subgroupSize}), LaneData({1}));
+  return LayoutInfo(LaneLayout({1, subgroupSize}), LaneData({1, 1}));
+}
+
+/// Helper to get the default layout for a vector type.
+static LayoutInfo getDefaultLayoutInfo(VectorType vectorTy) {
+  // Expecting a 1D or 2D vector.
+  assert((vectorTy.getRank() == 1 || vectorTy.getRank() == 2) &&
+         "Expected 1D or 2D vector.");
+  // Expecting int or float element type.
+  assert(vectorTy.getElementType().isIntOrFloat() &&
+         "Expected int or float element type.");
+  // If the rank is 1, then return default layout for 1D vector.
+  if (vectorTy.getRank() == 1)
+    return getDefaultLayoutInfo(1);
+  // Packing factor is determined by the element type bitwidth.
+  int packingFactor = 1;
+  unsigned bitwidth = vectorTy.getElementType().getIntOrFloatBitWidth();
+  if (bitwidth < packedSizeInBitsForDefault)
+    packingFactor = packedSizeInBitsForDefault / bitwidth;
+  return LayoutInfo(LaneLayout({1, subgroupSize}),
+                    LaneData({1, packingFactor}));
+}
+
+/// Returns the expected layout for DPAS operand number `operandNum`
+/// (0 = A, 1 = B, 2 = C). `lane_data` is chosen as follows:
+/// * B operand with elements narrower than `packedSizeInBitsForDpasB`: the
+///   data is packed along the outer dimension (VNNI format) so that each lane
+///   owns `packedSizeInBitsForDpasB` bits, i.e.
+///   lane_layout = [1, subgroupSize],
+///   lane_data = [packedSizeInBitsForDpasB / bitwidth, 1].
+/// * Every other case (A, C, and B with wide enough elements): the default
+///   layout returned by getDefaultLayoutInfo(vectorTy), which packs narrow
+///   elements to `packedSizeInBitsForDefault` bits.
+static LayoutInfo getLayoutInfoForDPASOperand(VectorType vectorTy,
+                                              unsigned operandNum) {
+  Type elementTy = vectorTy.getElementType();
+  assert(elementTy.isIntOrFloat() &&
+         "Expected int or float type in DPAS operands");
+  unsigned bitwidth = elementTy.getIntOrFloatBitWidth();
+  if (operandNum == 1 && bitwidth < packedSizeInBitsForDpasB) {
+    LaneLayout layout({1, subgroupSize});
+    LaneData data({packedSizeInBitsForDpasB / bitwidth, 1});
+    return LayoutInfo(layout, data);
+  }
+  return getDefaultLayoutInfo(vectorTy);
+}
+
+//===----------------------------------------------------------------------===//
+// LayoutInfoPropagation
+//===----------------------------------------------------------------------===//
+
+/// Backward data flow analysis to propagate the lane_layout and lane_data of
+/// each value in the program. The layouts of the operands of the anchor ops
+/// (DPAS, StoreNd, StoreScatter and PrefetchNd) are fixed, i.e. known before
+/// propagation. The purpose of this analysis is to propagate those known
+/// layouts to all their producers and (other) consumers.
+class LayoutInfoPropagation
+    : public SparseBackwardDataFlowAnalysis<LayoutInfoLattice> {
+private:
+  void visitDpasOp(xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands,
+                   ArrayRef<const LayoutInfoLattice *> results);
+
+  void visitStoreNdOp(xegpu::StoreNdOp store,
+                      ArrayRef<LayoutInfoLattice *> operands,
+                      ArrayRef<const LayoutInfoLattice *> results);
+
+  void visitStoreScatterOp(xegpu::StoreScatterOp storeScatter,
+                           ArrayRef<LayoutInfoLattice *> operands,
+                           ArrayRef<const LayoutInfoLattice *> results);
+
+  void visitLoadNdOp(xegpu::LoadNdOp load,
+                     ArrayRef<LayoutInfoLattice *> operands,
+                     ArrayRef<const LayoutInfoLattice *> results);
+
+  void visitLoadGatherOp(xegpu::LoadGatherOp load,
+                         ArrayRef<LayoutInfoLattice *> operands,
+                         ArrayRef<const LayoutInfoLattice *> results);
+
+  void visitTransposeOp(vector::TransposeOp transpose,
+                        ArrayRef<LayoutInfoLattice *> operands,
+                        ArrayRef<const LayoutInfoLattice *> results);
+
+  void visitVectorBitcastOp(vector::BitCastOp bitcast,
+                            ArrayRef<LayoutInfoLattice *> operands,
+                            ArrayRef<const LayoutInfoLattice *> results);
+
+  void visitCreateDescOp(xegpu::CreateDescOp createDesc,
+                         ArrayRef<LayoutInfoLattice *> operands,
+                         ArrayRef<const LayoutInfoLattice *> results);
+
+  void visitUpdateNdOffsetOp(xegpu::UpdateNdOffsetOp updateNdOffset,
+                             ArrayRef<LayoutInfoLattice *> operands,
+                             ArrayRef<const LayoutInfoLattice *> results);
+
+  void visitPrefetchNdOp(xegpu::PrefetchNdOp prefetch,
+                         ArrayRef<LayoutInfoLattice *> operands,
+                         ArrayRef<const LayoutInfoLattice *> results);
+
+  void visitVectorMultiReductionOp(vector::MultiDimReductionOp reduction,
+                                   ArrayRef<LayoutInfoLattice *> operands,
+                                   ArrayRef<const LayoutInfoLattice *> results);
+
+public:
+  LayoutInfoPropagation(DataFlowSolver &solver,
+                        SymbolTableCollection &symbolTable)
+      : SparseBackwardDataFlowAnalysis(solver, symbolTable) {}
+
+  /// Dispatches to the op-specific visitor above. Ops without a dedicated
+  /// visitor forward every assigned result layout to all of their operands.
+  LogicalResult
+  visitOperation(Operation *op, ArrayRef<LayoutInfoLattice *> operands,
+                 ArrayRef<const LayoutInfoLattice *> results) override;
+
+  /// Intentionally a no-op: branch operands get no layout from the branch.
+  void visitBranchOperand(OpOperand &operand) override {}
+
+  /// Intentionally a no-op: call operands get no layout from the call.
+  void visitCallOperand(OpOperand &operand) override {}
+
+  /// Intentionally a no-op: nothing is propagated through external calls.
+  void visitExternalCall(CallOpInterface call,
+                         ArrayRef<LayoutInfoLattice *> operands,
+                         ArrayRef<const LayoutInfoLattice *> results) override {}
+
+  /// Exit states are met with an unassigned layout.
+  void setToExitState(LayoutInfoLattice *lattice) override {
+    (void)lattice->meet(LayoutInfo());
+  }
+};
+} // namespace
+
+LogicalResult LayoutInfoPropagation::visitOperation(
+    Operation *op, ArrayRef<LayoutInfoLattice *> operands,
+    ArrayRef<const LayoutInfoLattice *> results) {
+  // Dispatch to the op-specific handler. Ops without a dedicated handler
+  // forward every assigned result layout to all of their operands.
+  TypeSwitch<Operation *>(op)
+      .Case<xegpu::DpasOp>([&](xegpu::DpasOp dpasOp) {
+        visitDpasOp(dpasOp, operands, results);
+      })
+      .Case<xegpu::StoreNdOp>([&](xegpu::StoreNdOp storeNdOp) {
+        visitStoreNdOp(storeNdOp, operands, results);
+      })
+      .Case<xegpu::StoreScatterOp>([&](xegpu::StoreScatterOp storeScatterOp) {
+        visitStoreScatterOp(storeScatterOp, operands, results);
+      })
+      .Case<xegpu::LoadNdOp>([&](xegpu::LoadNdOp loadNdOp) {
+        visitLoadNdOp(loadNdOp, operands, results);
+      })
+      .Case<xegpu::LoadGatherOp>([&](xegpu::LoadGatherOp loadGatherOp) {
+        visitLoadGatherOp(loadGatherOp, operands, results);
+      })
+      .Case<xegpu::CreateDescOp>([&](xegpu::CreateDescOp createDescOp) {
+        visitCreateDescOp(createDescOp, operands, results);
+      })
+      .Case<xegpu::UpdateNdOffsetOp>(
+          [&](xegpu::UpdateNdOffsetOp updateNdOffsetOp) {
+            visitUpdateNdOffsetOp(updateNdOffsetOp, operands, results);
+          })
+      .Case<xegpu::PrefetchNdOp>([&](xegpu::PrefetchNdOp prefetchNdOp) {
+        visitPrefetchNdOp(prefetchNdOp, operands, results);
+      })
+      // CreateNdDescOp operands are scalars (offsets, sizes, etc.), so there
+      // is no layout to propagate to them.
+      .Case<xegpu::CreateNdDescOp>([](xegpu::CreateNdDescOp) {})
+      .Case<vector::TransposeOp>([&](vector::TransposeOp transposeOp) {
+        visitTransposeOp(transposeOp, operands, results);
+      })
+      .Case<vector::BitCastOp>([&](vector::BitCastOp bitcastOp) {
+        visitVectorBitcastOp(bitcastOp, operands, results);
+      })
+      .Case<vector::MultiDimReductionOp>(
+          [&](vector::MultiDimReductionOp reductionOp) {
+            visitVectorMultiReductionOp(reductionOp, operands, results);
+          })
+      .Default([&](Operation *) {
+        for (const LayoutInfoLattice *resultLattice : results) {
+          // Unassigned result layouts carry no information to forward.
+          if (!resultLattice->getValue().isAssigned())
+            continue;
+          for (LayoutInfoLattice *operandLattice : operands)
+            meet(operandLattice, *resultLattice);
+        }
+      });
+
+  // Register a dependency between every result lattice and the program point
+  // after `op`.
+  for (const LayoutInfoLattice *resultLattice : results)
+    addDependency(const_cast<LayoutInfoLattice *>(resultLattice),
+                  getProgramPointAfter(op));
+  return success();
+}
+
+void LayoutInfoPropagation::visitPrefetchNdOp(
+    xegpu::PrefetchNdOp prefetch, ArrayRef<LayoutInfoLattice *> operands,
+    ArrayRef<const LayoutInfoLattice *> results) {
+  // PrefetchNdOp acts as an anchor: its tensor descriptor operand always gets
+  // the default layout of a vector with the descriptor's shape and element
+  // type, independently of any result layout.
+  auto tdescTy = prefetch.getTensorDescType();
+  VectorType equivalentVecTy =
+      VectorType::get(tdescTy.getShape(), tdescTy.getElementType());
+  LayoutInfo prefetchLayout = getDefaultLayoutInfo(equivalentVecTy);
+  LayoutInfoLattice *tdescLattice = operands[0];
+  propagateIfChanged(tdescLattice, tdescLattice->meet(prefetchLayout));
+}
+
+/// For a (2D -> 1D) vector.multi_reduction, the source operand gets the
+/// default 2D layout and the accumulator inherits the 1D result layout.
+void LayoutInfoPropagation::visitVectorMultiReductionOp(
+    vector::MultiDimReductionOp reduction,
+    ArrayRef<LayoutInfoLattice *> operands,
+    ArrayRef<const LayoutInfoLattice *> results) {
+  LayoutInfo resultLayout = results[0]->getValue();
+  // Nothing to propagate until the result layout is known.
+  if (!resultLayout.isAssigned())
+    return;
+  // Only 2D -> 1D reductions are handled for now.
+  assert(resultLayout.getLayout().size() == 1 &&
+         "Expected 1D layout for reduction result.");
+  LayoutInfoLattice *sourceLattice = operands[0];
+  LayoutInfoLattice *accLattice = operands[1];
+  LayoutInfo sourceLayout = getDefaultLayoutInfo(2);
+  propagateIfChanged(sourceLattice, sourceLattice->meet(sourceLayout));
+  propagateIfChanged(accLattice, accLattice->meet(resultLayout));
+}
+
+/// The source tensor descriptor of an UpdateNdOffsetOp receives the layout of
+/// the op's result descriptor, once that layout is known.
+void LayoutInfoPropagation::visitUpdateNdOffsetOp(
+    xegpu::UpdateNdOffsetOp updateNdOffset,
+    ArrayRef<LayoutInfoLattice *> operands,
+    ArrayRef<const LayoutInfoLattice *> results) {
+  LayoutInfo resultLayout = results[0]->getValue();
+  if (resultLayout.isAssigned())
+    propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
+}
+
+/// DPAS is an anchor: operand #0 (A / lhs), operand #1 (B / rhs) and, when
+/// present, operand #2 (C / acc) get the layouts returned by
+/// getLayoutInfoForDPASOperand.
+void LayoutInfoPropagation::visitDpasOp(
+    xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands,
+    ArrayRef<const LayoutInfoLattice *> results) {
+  SmallVector<VectorType, 3> operandTypes = {dpas.getLhsType(),
+                                             dpas.getRhsType()};
+  // The accumulator operand is optional.
+  if (operands.size() > 2)
+    operandTypes.push_back(dpas.getAccType());
+  for (unsigned idx = 0, e = operandTypes.size(); idx < e; ++idx) {
+    LayoutInfo expected = getLayoutInfoForDPASOperand(operandTypes[idx], idx);
+    propagateIfChanged(operands[idx], operands[idx]->meet(expected));
+  }
+}
+
+/// StoreNdOp is an anchor: every operand (the stored value and the tensor
+/// descriptor) gets the default layout of the stored vector type.
+void LayoutInfoPropagation::visitStoreNdOp(
+    xegpu::StoreNdOp store, ArrayRef<LayoutInfoLattice *> operands,
+    ArrayRef<const LayoutInfoLattice *> results) {
+  LayoutInfo storeLayout = getDefaultLayoutInfo(store.getValueType());
+  llvm::for_each(operands, [&](LayoutInfoLattice *operandLattice) {
+    propagateIfChanged(operandLattice, operandLattice->meet(storeLayout));
+  });
+}
+
+/// The tensor descriptor operand of a LoadNdOp receives the layout of the
+/// loaded value. If the op carries a transpose, a warning is emitted (the
+/// transpose effect is not expected at this stage) and the transposed value
+/// layout is used instead.
+void LayoutInfoPropagation::visitLoadNdOp(
+    xegpu::LoadNdOp load, ArrayRef<LayoutInfoLattice *> operands,
+    ArrayRef<const LayoutInfoLattice *> results) {
+  LayoutInfo valueLayout = results[0]->getValue();
+  if (!valueLayout.isAssigned())
+    return;
+  LayoutInfoLattice *tdescLattice = operands[0];
+  auto transpose = load.getTranspose();
+  if (!transpose) {
+    propagateIfChanged(tdescLattice, tdescLattice->meet(valueLayout));
+    return;
+  }
+  load.emitWarning("Transpose effect is not expected for LoadNdOp at "
+                   "LayoutInfoPropagation stage.");
+  LayoutInfo transposedLayout =
+      valueLayout.getTransposedLayout(transpose.value());
+  propagateIfChanged(tdescLattice, tdescLattice->meet(transposedLayout));
+}
+
+/// For vector::TransposeOp, the operand layout is obtained from the result
+/// layout via getTransposedLayout with the op's permutation.
+void LayoutInfoPropagation::visitTransposeOp(
+    vector::TransposeOp transpose, ArrayRef<LayoutInfoLattice *> operands,
+    ArrayRef<const LayoutInfoLattice *> results) {
+  LayoutInfo resultLayout = results[0]->getValue();
+  if (resultLayout.isAssigned()) {
+    LayoutInfo sourceLayout =
+        resultLayout.getTransposedLayout(transpose.getPermutation());
+    propagateIfChanged(operands[0], operands[0]->meet(sourceLayout));
+  }
+}
+
+/// For vector::BitCastOp, the source keeps the result's lane_layout, and one
+/// lane_data entry is rescaled by the ratio of the element bit widths:
+/// * widening bitcast (narrower source elements): the entry is multiplied;
+/// * narrowing bitcast (wider source elements) or equal widths: the entry is
+///   divided. If it is not divisible by the ratio, a single source element
+///   would be split across lanes, so no layout is propagated.
+/// For a 1D layout the only entry is rescaled. For a 2D layout, dim 1 is
+/// rescaled when lane_data[0] == 1 and dim 0 otherwise; the other dim is set
+/// to 1. Layouts of any other rank are not handled.
+void LayoutInfoPropagation::visitVectorBitcastOp(
+    vector::BitCastOp bitcast, ArrayRef<LayoutInfoLattice *> operands,
+    ArrayRef<const LayoutInfoLattice *> results) {
+  // Need the layout of bitcast result to propagate to the operands.
+  LayoutInfo resultLayout = results[0]->getValue();
+  if (!resultLayout.isAssigned())
+    return;
+  int inElemTyBitWidth =
+      bitcast.getSourceVectorType().getElementType().getIntOrFloatBitWidth();
+  int outElemTyBitWidth =
+      bitcast.getResultVectorType().getElementType().getIntOrFloatBitWidth();
+  const bool isWidening = inElemTyBitWidth < outElemTyBitWidth;
+  const int64_t ratio = isWidening ? outElemTyBitWidth / inElemTyBitWidth
+                                   : inElemTyBitWidth / outElemTyBitWidth;
+  // Rescales one lane_data entry; returns 0 when that is not possible.
+  auto rescale = [&](int64_t data) -> int64_t {
+    if (isWidening)
+      return data * ratio;
+    return data % ratio == 0 ? data / ratio : 0;
+  };
+
+  // LaneLayout does not change.
+  const LaneLayout &newLaneLayout = resultLayout.getLayout();
+  const LaneData &currData = resultLayout.getData();
+  const size_t rank = currData.size();
+  if (rank != 1 && rank != 2)
+    return;
+  // Dimension whose lane_data gets rescaled.
+  const size_t scaledDim = (rank == 2 && currData[0] == 1) ? 1 : 0;
+  int64_t scaled = rescale(currData[scaledDim]);
+  if (scaled == 0)
+    return;
+
+  LaneData newLaneData;
+  if (rank == 1)
+    newLaneData = LaneData({scaled});
+  else if (scaledDim == 1)
+    newLaneData = LaneData({1, scaled});
+  else
+    newLaneData = LaneData({scaled, 1});
+
+  propagateIfChanged(operands[0],
+                     operands[0]->meet(LayoutInfo(newLaneLayout, newLaneData)));
+}
+
+/// For LoadGatherOp, the tensor descriptor (operand #0) receives the layout of
+/// the loaded value, transposed with {1, 0} when the op carries a transpose
+/// (a warning is emitted in that case). The mask (operand #1) gets the
+/// default 1D layout.
+void LayoutInfoPropagation::visitLoadGatherOp(
+    xegpu::LoadGatherOp load, ArrayRef<LayoutInfoLattice *> operands,
+    ArrayRef<const LayoutInfoLattice *> results) {
+  LayoutInfo valueLayout = results[0]->getValue();
+  if (!valueLayout.isAssigned())
+    return;
+
+  if (load.getTranspose())
+    load.emitWarning("Transpose effect is not expected for LoadGatherOp at "
+                     "LayoutInfoPropagation stage.");
+  LayoutInfo tdescLayout = load.getTranspose()
+                               ? valueLayout.getTransposedLayout({1, 0})
+                               : valueLayout;
+  LayoutInfo maskLayout = getDefaultLayoutInfo(1);
+
+  LayoutInfoLattice *tdescLattice = operands[0];
+  LayoutInfoLattice *maskLattice = operands[1];
+  propagateIfChanged(tdescLattice, tdescLattice->meet(tdescLayout));
+  propagateIfChanged(maskLattice, maskLattice->meet(maskLayout));
+}
+
+/// For CreateDescOp, once the result descriptor has a layout, the offsets
+/// operand (operand #1) gets the default 1D layout.
+void LayoutInfoPropagation::visitCreateDescOp(
+    xegpu::CreateDescOp createDesc, ArrayRef<LayoutInfoLattice *> operands,
+    ArrayRef<const LayoutInfoLattice *> results) {
+  if (!results[0]->getValue().isAssigned())
+    return;
+  LayoutInfoLattice *offsetsLattice = operands[1];
+  LayoutInfo offsetsLayout = getDefaultLayoutInfo(1);
+  propagateIfChanged(offsetsLattice, offsetsLattice->meet(offsetsLayout));
+}
+
+/// StoreScatterOp is an anchor. Operand #0 (the stored value) gets the
+/// default layout of its vector type. Operand #1 (the tensor descriptor) gets
+/// the same layout, transposed with {1, 0} when the op carries a transpose (a
+/// warning is emitted in that case). Operand #2 (the mask) gets the default
+/// 1D layout.
+void LayoutInfoPropagation::visitStoreScatterOp(
+    xegpu::StoreScatterOp storeScatter, ArrayRef<LayoutInfoLattice *> operands,
+    ArrayRef<const LayoutInfoLattice *> results) {
+  // For 2D descriptors the height must equal the subgroup size; the op
+  // verifier is expected to guarantee this.
+  ArrayRef<int64_t> tdescShape = storeScatter.getTensorDescType().getShape();
+  if (tdescShape.size() > 1) {
+    assert(
+        tdescShape[0] == subgroupSize &&
+        "Expected the first dimension of 2D tensor descriptor to be equal to "
+        "subgroup size.");
+  }
+
+  LayoutInfo valueLayout = getDefaultLayoutInfo(storeScatter.getValueType());
+  LayoutInfo tdescLayout = valueLayout;
+  if (storeScatter.getTranspose()) {
+    storeScatter.emitWarning("Transpose effect is not expected for "
+                             "StoreScatterOp at LayoutInfoPropagation stage.");
+    tdescLayout = valueLayout.getTransposedLayout({1, 0});
+  }
+  LayoutInfo maskLayout = getDefaultLayoutInfo(1);
+
+  propagateIfChanged(operands[0], operands[0]->meet(valueLayout));
+  propagateIfChanged(operands[1], operands[1]->meet(tdescLayout));
+  propagateIfChanged(operands[2], operands[2]->meet(maskLayout));
+}
+
+namespace {
+
+//===----------------------------------------------------------------------===//
+// RunLayoutInfoPropagation
+//===----------------------------------------------------------------------===//
+
+/// Driver class for running the LayoutInfoPropagation analysis. The analysis
+/// runs eagerly in the constructor; results are queried with getLayoutInfo.
+class RunLayoutInfoPropagation {
+public:
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(RunLayoutInfoPropagation)
+
+  RunLayoutInfoPropagation(Operation *op) : target(op) {
+    solver.load<DeadCodeAnalysis>();
+    solver.load<SparseConstantPropagation>();
+    // The sparse backward analysis keeps a reference to the symbol table it
+    // is constructed with, so the table must be a member rather than a local
+    // of this constructor.
+    solver.load<LayoutInfoPropagation>(symbolTable);
+    (void)solver.initializeAndRun(op);
+  }
+
+  LayoutInfo getLayoutInfo(Value val);
+
+  void printAnalysisResult(llvm::raw_ostream &os);
+
+private:
+  /// Symbol table referenced by the analyses loaded into `solver`. Declared
+  /// before `solver` so that it is destroyed after the solver.
+  SymbolTableCollection symbolTable;
+  DataFlowSolver solver;
+  const Operation *target;
+};
+} // namespace
+
+/// Returns the layout computed for `val`, or a default-constructed
+/// LayoutInfo when the solver holds no lattice for it.
+LayoutInfo RunLayoutInfoPropagation::getLayoutInfo(Value val) {
+  if (const LayoutInfoLattice *state =
+          solver.lookupState<LayoutInfoLattice>(val))
+    return state->getValue();
+  return {};
+}
+
+// Print the analysis result for debugging purposes. Only module targets are
+// reported on: their top-level functions and the functions nested in
+// gpu.module ops.
+[[maybe_unused]] void
+RunLayoutInfoPropagation::printAnalysisResult(llvm::raw_ostream &os) {
+  SmallVector<FunctionOpInterface> funcOps;
+  if (auto modOp = dyn_cast<ModuleOp>(target)) {
+    for (FunctionOpInterface funcOp : modOp.getOps<FunctionOpInterface>())
+      funcOps.push_back(funcOp);
+    for (gpu::GPUModuleOp gpuModOp : modOp.getOps<gpu::GPUModuleOp>())
+      for (FunctionOpInterface gpuFuncOp :
+           gpuModOp.getOps<FunctionOpInterface>())
+        funcOps.push_back(gpuFuncOp);
+  }
+
+  for (FunctionOpInterface funcOp : funcOps) {
+    os << "function: " << funcOp.getName() << ":\n";
+    // Layouts of the function arguments.
+    for (BlockArgument arg : funcOp.getArguments()) {
+      os << "argument: " << arg << "\n";
+      os << "layout  : ";
+      getLayoutInfo(arg).print(os);
+      os << "\n";
+    }
+    // Layouts of every op result; ops without results are skipped.
+    funcOp.walk([&](Operation *nestedOp) {
+      if (nestedOp->getNumResults() == 0)
+        return;
+      os << "op    : ";
+      // Control-flow ops are printed by name only.
+      if (isa<BranchOpInterface, RegionBranchOpInterface>(nestedOp))
+        os << nestedOp->getName();
+      else
+        nestedOp->print(os);
+      os << "\n";
+      for (OpResult result : nestedOp->getResults()) {
+        os << "layout for result #" << result.getResultNumber() << ": ";
+        getLayoutInfo(result).print(os);
+        os << "\n";
+      }
+    });
+  }
+}
+
+using GetLayoutCallbackFnTy = function_ref<xegpu::LayoutAttr(Value)>;
+
+/// Attaches the layouts provided by `getLayoutOfValue` to the vector and
+/// tensor descriptor results of `op`:
+/// * a tensor descriptor result has its type rewritten to carry the layout;
+/// * a vector result gets a temporary attribute named
+///   `resultLayoutNamePrefix` + result number on `op`, and each of its users
+///   gets a temporary attribute named `operandLayoutNamePrefix` + operand
+///   number.
+/// Results without a layout are left untouched. `builder` is currently
+/// unused.
+static void updateOp(mlir::OpBuilder &builder, mlir::Operation *op,
+                     GetLayoutCallbackFnTy getLayoutOfValue) {
+  for (OpResult result : op->getResults()) {
+    Type resultType = result.getType();
+    // Layouts are needed only for vector and tensor descriptor types.
+    if (!isa<VectorType, xegpu::TensorDescType>(resultType))
+      continue;
+    xegpu::LayoutAttr layout = getLayoutOfValue(result);
+    if (!layout) {
+      // An unused result may legitimately have no layout; a used one is
+      // unexpected. Either way there is nothing to attach: a null attribute
+      // must not be set on an op, and rewriting a tensor descriptor type with
+      // a null layout would drop any layout it already has.
+      if (!result.use_empty()) {
+        LLVM_DEBUG(DBGS() << "Expecting layout for result: " << result
+                          << " but got none.\n");
+      }
+      continue;
+    }
+    if (auto tensorDescTy = dyn_cast<xegpu::TensorDescType>(resultType)) {
+      // TODO: Handle error.
+      auto typeWithLayout = xegpu::TensorDescType::get(
+          tensorDescTy.getContext(), tensorDescTy.getShape(),
+          tensorDescTy.getElementType(), tensorDescTy.getEncoding(), layout);
+      result.setType(typeWithLayout);
+      continue;
+    }
+    // Vector result: add a temporary layout attribute to the op...
+    std::string resultLayoutName =
+        resultLayoutNamePrefix + std::to_string(result.getResultNumber());
+    op->setAttr(resultLayoutName, layout);
+    // ...and a temporary layout attribute to every user, keyed by the operand
+    // number at which the value is used.
+    for (OpOperand &use : result.getUses()) {
+      std::string attrName =
+          operandLayoutNamePrefix + std::to_string(use.getOperandNumber());
+      use.getOwner()->setAttr(attrName, layout);
+    }
+  }
+}
+/// Propagates layouts across the edge from a region-branch terminator back to
+/// its parent op. For every forwarded operand whose successor input (a result
+/// of the parent op) is a tensor descriptor, the input's type is rewritten to
+/// carry the operand's layout. Inputs whose existing layout conflicts with the
+/// operand's are left unchanged. `builder` is currently unused.
+static void updateBranchTerminatorOpInterface(
+    mlir::OpBuilder &builder,
+    mlir::RegionBranchTerminatorOpInterface terminator,
+    GetLayoutCallbackFnTy getLayoutOfValue) {
+  if (!mlir::isa<mlir::RegionBranchOpInterface>(terminator->getParentOp()))
+    return;
+
+  // No operand is known to be constant here.
+  llvm::SmallVector<mlir::RegionSuccessor> successors;
+  llvm::SmallVector<mlir::Attribute> constOperands(
+      terminator->getNumOperands(), nullptr);
+  terminator.getSuccessorRegions(constOperands, successors);
+
+  for (mlir::RegionSuccessor &successor : successors) {
+    // Only the edge back to the parent op is handled here.
+    if (!successor.isParent())
+      continue;
+
+    mlir::OperandRange forwardedOperands =
+        terminator.getSuccessorOperands(successor);
+    mlir::ValueRange inputs = successor.getSuccessorInputs();
+    for (auto [operand, input] : llvm::zip(forwardedOperands, inputs)) {
+      auto tdescTy = dyn_cast<xegpu::TensorDescType>(input.getType());
+      if (!tdescTy)
+        continue;
+      xegpu::LayoutAttr inputLayout = getLayoutOfValue(input);
+      xegpu::LayoutAttr operandLayout = getLayoutOfValue(operand);
+
+      if (!operandLayout) {
+        LLVM_DEBUG(DBGS() << "Expecting layout for region successor operand : "
+                          << operand << " but got none.\n");
+        continue;
+      }
+
+      if (inputLayout && inputLayout != operandLayout) {
+        LLVM_DEBUG(
+            DBGS()
+            << "Conflicting layouts for region successor operand and input: "
+            << inputLayout << " vs " << operandLayout << "\n");
+        continue;
+      }
+      // Get tensor descriptor type with the layout.
+      auto newTdescTy = xegpu::TensorDescType::get(
+          tdescTy.getContext(), tdescTy.getShape(), tdescTy.getElementType(),
+          tdescTy.getEncoding(), operandLayout);
+      input.setType(newTdescTy);
+    }
+  }
+}
+/// Attaches layouts to a region-branch op (e.g. scf.for).
+/// * Entry block arguments of tensor descriptor type get their type rewritten
+///   with the layout shared by the block argument and its init operand. The
+///   same layout is recorded for the result at the same position.
+/// * Each vector or tensor descriptor result then receives its own layout,
+///   falling back to the recorded one. Tensor descriptor results get their
+///   type rewritten. Vector results get temporary layout attributes on the op
+///   and on their users, as in updateOp.
+/// `builder` is currently unused.
+static void updateBranchOpInterface(mlir::OpBuilder &builder,
+                                    mlir::RegionBranchOpInterface branch,
+                                    GetLayoutCallbackFnTy getLayoutOfValue) {
+  mlir::Operation *op = branch.getOperation();
+  // No operand is known to be constant here.
+  llvm::SmallVector<mlir::RegionSuccessor> successors;
+  llvm::SmallVector<mlir::Attribute> constOperands(op->getNumOperands(),
+                                                   nullptr);
+  branch.getEntrySuccessorRegions(constOperands, successors);
+  DenseMap<Value, xegpu::LayoutAttr> resultToLayouts;
+  mlir::ValueRange results = op->getResults();
+
+  for (mlir::RegionSuccessor &successor : successors) {
+    if (successor.isParent())
+      continue;
+
+    mlir::OperandRange initOperands =
+        branch.getEntrySuccessorOperands(successor);
+    mlir::ValueRange inputs = successor.getSuccessorInputs();
+
+    // NOTE(review): entry operands/inputs are paired with the op's results
+    // positionally. This holds for ops like scf.for, but presumably not for
+    // every RegionBranchOpInterface (e.g. scf.while); confirm.
+    for (auto [operand, input, result] :
+         llvm::zip(initOperands, inputs, results)) {
+      auto tdescTy = dyn_cast<xegpu::TensorDescType>(input.getType());
+      if (!tdescTy)
+        continue;
+      xegpu::LayoutAttr blockArgLayout = getLayoutOfValue(input);
+      xegpu::LayoutAttr initArgLayout = getLayoutOfValue(operand);
+
+      if (!blockArgLayout || !initArgLayout) {
+        LLVM_DEBUG(DBGS() << "No layout assigned for block arg: " << input
+                          << " or init arg: " << operand << "\n");
+        continue;
+      }
+
+      // TODO: We expect these two to match. Data flow analysis will ensure
+      // this.
+      assert(blockArgLayout == initArgLayout &&
+             "Expexing block arg and init arg to have the same layout.");
+      // Get tensor descriptor type with the layout.
+      auto newTdescTy = xegpu::TensorDescType::get(
+          tdescTy.getContext(), tdescTy.getShape(), tdescTy.getElementType(),
+          tdescTy.getEncoding(), blockArgLayout);
+      input.setType(newTdescTy);
+      // Record the layout for the result. On conflict the first recorded
+      // layout is kept and the conflict is reported.
+      auto [it, inserted] = resultToLayouts.try_emplace(result, blockArgLayout);
+      if (!inserted && it->second != blockArgLayout) {
+        LLVM_DEBUG(DBGS() << "Conflicting layouts for result: " << result
+                          << " - " << it->second << " vs " << blockArgLayout
+                          << "\n");
+      }
+    }
+  }
+  for (OpResult r : op->getResults()) {
+    Type resultType = r.getType();
+    if (!isa<xegpu::TensorDescType, VectorType>(resultType))
+      continue;
+    xegpu::LayoutAttr layout = getLayoutOfValue(r);
+    // Fall back to the layout recorded from the entry block arguments.
+    if (!layout)
+      layout = resultToLayouts.lookup(r);
+    if (!layout) {
+      LLVM_DEBUG(DBGS() << "No layout assigned for vector/tensor desc result:"
+                        << r << "\n");
+      continue;
+    }
+    if (auto tensorDescTy = dyn_cast<xegpu::TensorDescType>(resultType)) {
+      auto newTdescTy = xegpu::TensorDescType::get(
+          tensorDescTy.getContext(), tensorDescTy.getShape(),
+          tensorDescTy.getElementType(), tensorDescTy.getEncoding(), layout);
+      r.setType(newTdescTy);
+      continue;
+    }
+    // If the result is a vector type, add a temporary layout attribute to
+    // the op.
+    std::string resultLayoutName =
+        resultLayoutNamePrefix + std::to_string(r.getResultNumber());
+    op->setAttr(resultLayoutName, layout);
+    // Update all users of the result with the layout.
+    for (OpOperand &use : r.getUses()) {
+      // Add temporary layout attribute at the user op.
+      std::string attrName =
+          operandLayoutNamePrefix + std::to_string(use.getOperandNumber());
+      use.getOwner()->setAttr(attrName, layout);
+    }
+  }
+}
+
+namespace {
+
+/// Pass that runs the layout propagation analysis and attaches the computed
+/// layouts to the IR.
+class XeGPULayoutPropagatePass final
+    : public xegpu::impl::XeGPULayoutPropagateBase<XeGPULayoutPropagatePass> {
+public:
+  void runOnOperation() override;
+};
+
+} // namespace
+
+/// Runs the layout analysis and writes the results into the IR. Every block
+/// is visited and its ops are processed from last to first. Region-branch
+/// terminators, region-branch ops and all other ops each go through their
+/// dedicated update helper.
+void XeGPULayoutPropagatePass::runOnOperation() {
+  auto &analysis = getAnalysis<RunLayoutInfoPropagation>();
+
+  // Converts the analysis result for `val` into an xegpu::LayoutAttr, or a
+  // null attribute when no layout was assigned.
+  auto getXeGPULayoutForValue = [&](Value val) -> xegpu::LayoutAttr {
+    LayoutInfo layoutInfo = analysis.getLayoutInfo(val);
+    if (!layoutInfo.isAssigned())
+      return {};
+    SmallVector<int, 2> laneLayout, laneData;
+    for (auto [laneDim, dataDim] :
+         llvm::zip_equal(layoutInfo.getLayoutAsArrayRef(),
+                         layoutInfo.getDataAsArrayRef())) {
+      laneLayout.push_back(static_cast<int>(laneDim));
+      laneData.push_back(static_cast<int>(dataDim));
+    }
+    return xegpu::LayoutAttr::get(val.getContext(), laneLayout, laneData);
+  };
+
+  mlir::OpBuilder builder(&getContext());
+  Operation *rootOp = getOperation();
+  rootOp->walk([&](mlir::Block *block) {
+    for (mlir::Operation &op : llvm::reverse(block->getOperations())) {
+      if (auto terminator =
+              mlir::dyn_cast<mlir::RegionBranchTerminatorOpInterface>(op)) {
+        updateBranchTerminatorOpInterface(builder, terminator,
+                                          getXeGPULayoutForValue);
+        continue;
+      }
+
+      if (auto iface = mlir::dyn_cast<mlir::RegionBranchOpInterface>(op)) {
+        updateBranchOpInterface(builder, iface, getXeGPULayoutForValue);
+        continue;
+      }
+      updateOp(builder, &op, getXeGPULayoutForValue);
+    }
+  });
+}
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index a17c8d8a4f3f3..2df8701ed3b31 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -57,7 +57,6 @@ namespace xegpu {
 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
 
 using namespace mlir;
-using namespace mlir::dataflow;
 
 /// HW dependent constants.
 /// TODO: These constants should be queried from the target information.
@@ -79,1017 +78,6 @@ static const char *const resolveSIMTTypeMismatch =
 
 namespace {
 
-//===----------------------------------------------------------------------===//
-// Layout
-//===----------------------------------------------------------------------===//
-
-/// Helper class to store the ND layout of lanes within a subgroup and data
-/// owned by each lane.
-struct Layout {
-  SmallVector<int64_t, 3> layout;
-  Layout() = default;
-  Layout(std::initializer_list<int64_t> list) : layout(list) {}
-  void print(llvm::raw_ostream &os) const;
-  size_t size() const { return layout.size(); }
-  int64_t operator[](size_t idx) const;
-};
-
-void Layout::print(llvm::raw_ostream &os) const {
-  os << llvm::interleaved_array(layout);
-}
-
-int64_t Layout::operator[](size_t idx) const {
-  assert(idx < layout.size() && "Index out of bounds.");
-  return layout[idx];
-}
-
-/// LaneLayout represents the logical layout of lanes within a subgroup when it
-/// accesses some value. LaneData represents the logical layout of data owned by
-/// each work item.
-using LaneLayout = Layout;
-using LaneData = Layout;
-
-//===----------------------------------------------------------------------===//
-// LayoutInfo
-//===----------------------------------------------------------------------===//
-
-/// Helper class for tracking the analysis state of an mlir value. For layout
-/// propagation, the analysis state is simply the lane_layout and lane_data of
-/// each value. Purpose of this analysis to propagate some unique layout for
-/// each value in the program starting from a set of anchor operations (like
-/// DPAS, StoreNd, etc.).
-///
-/// Given this, LayoutInfo  satisifies the following properties:
-///  1) A LayoutInfo value can be in one of two states - `assigned` or `not
-///  assigned`.
-///  2) Two LayoutInfo values are equal if they are both assigned or
-///  both not assigned. The concrete value of assigned state does not matter.
-///  3) The meet operator works as follows:
-///     - If current state is assigned, return the current state. (already
-///     a unique layout is assigned. don't change it)
-///     - Otherwise, return the other state.
-
-struct LayoutInfo {
-private:
-  LaneLayout laneLayout;
-  LaneData laneData;
-
-public:
-  LayoutInfo() = default;
-  LayoutInfo(const LaneLayout &layout, const LaneData &data)
-      : laneLayout(layout), laneData(data) {}
-
-  // Two lattice values are equal if they have `some` layout. The actual
-  // content of the layout does not matter.
-  bool operator==(const LayoutInfo &other) const {
-    return this->isAssigned() == other.isAssigned();
-  }
-
-  static LayoutInfo meet(const LayoutInfo &lhs, const LayoutInfo &rhs);
-
-  static LayoutInfo join(const LayoutInfo &lhs, const LayoutInfo &rhs);
-
-  void print(raw_ostream &os) const;
-
-  bool isAssigned() const {
-    return laneLayout.size() > 0 && laneData.size() > 0;
-  }
-
-  LayoutInfo getTransposedLayout(ArrayRef<int64_t> permutation) const;
-
-  const LaneLayout &getLayout() const { return laneLayout; }
-  const LaneData &getData() const { return laneData; }
-  ArrayRef<int64_t> getLayoutAsArrayRef() const { return laneLayout.layout; }
-  ArrayRef<int64_t> getDataAsArrayRef() const { return laneData.layout; }
-};
-
-void LayoutInfo::print(raw_ostream &os) const {
-  if (isAssigned()) {
-    os << "lane_layout: ";
-    laneLayout.print(os);
-    os << ", lane_data: ";
-    laneData.print(os);
-  } else {
-    os << "Not assigned.";
-  }
-}
-
-LayoutInfo LayoutInfo::meet(const LayoutInfo &lhs, const LayoutInfo &rhs) {
-  if (!lhs.isAssigned())
-    return rhs;
-  return lhs;
-}
-
-/// Since this is a backward analysis, join method is not used.
-LayoutInfo LayoutInfo::join(const LayoutInfo &lhs, const LayoutInfo &rhs) {
-  llvm_unreachable("Join should not be triggered by layout propagation.");
-}
-
-/// Get the transposed layout according to the given permutation.
-LayoutInfo
-LayoutInfo::getTransposedLayout(ArrayRef<int64_t> permutation) const {
-  if (!isAssigned())
-    return {};
-  LaneLayout newLayout;
-  LaneData newData;
-  for (int64_t idx : permutation) {
-    newLayout.layout.push_back(laneLayout.layout[idx]);
-    newData.layout.push_back(laneData.layout[idx]);
-  }
-  return LayoutInfo(newLayout, newData);
-}
-
-//===----------------------------------------------------------------------===//
-// LayoutInfoLattice
-//===----------------------------------------------------------------------===//
-
-/// Lattice holding the LayoutInfo for each value.
-struct LayoutInfoLattice : public Lattice<LayoutInfo> {
-  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LayoutInfoLattice)
-  using Lattice::Lattice;
-};
-
-/// Helper Functions to get default layouts. A `default layout` is a layout that
-/// is assigned to a value when the layout is not fixed by some anchor operation
-/// (like DPAS).
-
-/// Helper Function to get the default layout for uniform values like constants.
-/// For 1D vector, lane_layout is [subgroupSize] and lane_data is [1].
-/// For 2D vector, lane_layout is [1, subgroupSize] and lane_data is [1, 1].
-static LayoutInfo getDefaultLayoutInfo(unsigned rank) {
-  assert((rank == 1 || rank == 2) && "Expected 1D or 2D vector.");
-  if (rank == 1)
-    return LayoutInfo(LaneLayout({subgroupSize}), LaneData({1}));
-  return LayoutInfo(LaneLayout({1, subgroupSize}), LaneData({1, 1}));
-}
-
-/// Helper to get the default layout for a vector type.
-static LayoutInfo getDefaultLayoutInfo(VectorType vectorTy) {
-  // Expecting a 1D or 2D vector.
-  assert((vectorTy.getRank() == 1 || vectorTy.getRank() == 2) &&
-         "Expected 1D or 2D vector.");
-  // Expecting int or float element type.
-  assert(vectorTy.getElementType().isIntOrFloat() &&
-         "Expected int or float element type.");
-  // If the rank is 1, then return default layout for 1D vector.
-  if (vectorTy.getRank() == 1)
-    return getDefaultLayoutInfo(1);
-  // Packing factor is determined by the element type bitwidth.
-  int packingFactor = 1;
-  unsigned bitwidth = vectorTy.getElementType().getIntOrFloatBitWidth();
-  if (bitwidth < packedSizeInBitsForDefault)
-    packingFactor = packedSizeInBitsForDefault / bitwidth;
-  return LayoutInfo(LaneLayout({1, subgroupSize}),
-                    LaneData({1, packingFactor}));
-}
-
-/// Helper Function to get the expected layouts for DPAS operands. `lane_data`
-/// is set according to the following criteria:
-/// * For A operand, the data must be packed in minimum
-/// `packedSizeInBitsForDefault`
-/// * For B operand, the data must be packed in minimum
-/// `packedSizeInBitsForDpasB`
-static LayoutInfo getLayoutInfoForDPASOperand(VectorType vectorTy,
-                                              unsigned operandNum) {
-  Type elementTy = vectorTy.getElementType();
-  assert(elementTy.isIntOrFloat() &&
-         "Expected int or float type in DPAS operands");
-  LaneLayout layout({1, subgroupSize});
-  // For B operand, data must be packed in minimum `packedDpasBSizeInBits` and
-  // must have the VNNI format.
-  if (operandNum == 1 &&
-      elementTy.getIntOrFloatBitWidth() < packedSizeInBitsForDpasB) {
-    LaneData data(
-        {packedSizeInBitsForDpasB / elementTy.getIntOrFloatBitWidth(), 1});
-    return LayoutInfo(layout, data);
-  }
-  // Otherwise, return the default layout for the vector type.
-  return getDefaultLayoutInfo(vectorTy);
-}
-
-//===----------------------------------------------------------------------===//
-// LayoutInfoPropagation
-//===----------------------------------------------------------------------===//
-
-/// Backward data flow analysis to propagate the lane_layout and lane_data of
-/// each value in the program. Currently, the layouts for operands DPAS,
-/// StoreNd, and StoreScatter are fixed (known before propagation). Purpose of
-/// this analysis is to propagate those known layouts to all their producers and
-/// (other) consumers.
-class LayoutInfoPropagation
-    : public SparseBackwardDataFlowAnalysis<LayoutInfoLattice> {
-private:
-  void visitDpasOp(xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands,
-                   ArrayRef<const LayoutInfoLattice *> results);
-
-  void visitStoreNdOp(xegpu::StoreNdOp store,
-                      ArrayRef<LayoutInfoLattice *> operands,
-                      ArrayRef<const LayoutInfoLattice *> results);
-
-  void visitStoreScatterOp(xegpu::StoreScatterOp storeScatter,
-                           ArrayRef<LayoutInfoLattice *> operands,
-                           ArrayRef<const LayoutInfoLattice *> results);
-
-  void visitLoadNdOp(xegpu::LoadNdOp load,
-                     ArrayRef<LayoutInfoLattice *> operands,
-                     ArrayRef<const LayoutInfoLattice *> results);
-
-  void visitLoadGatherOp(xegpu::LoadGatherOp load,
-                         ArrayRef<LayoutInfoLattice *> operands,
-                         ArrayRef<const LayoutInfoLattice *> results);
-
-  void visitTransposeOp(vector::TransposeOp transpose,
-                        ArrayRef<LayoutInfoLattice *> operands,
-                        ArrayRef<const LayoutInfoLattice *> results);
-
-  void visitVectorBitcastOp(vector::BitCastOp bitcast,
-                            ArrayRef<LayoutInfoLattice *> operands,
-                            ArrayRef<const LayoutInfoLattice *> results);
-
-  void visitCreateDescOp(xegpu::CreateDescOp createDesc,
-                         ArrayRef<LayoutInfoLattice *> operands,
-                         ArrayRef<const LayoutInfoLattice *> results);
-
-  void visitUpdateNdOffsetOp(xegpu::UpdateNdOffsetOp updateNdOffset,
-                             ArrayRef<LayoutInfoLattice *> operands,
-                             ArrayRef<const LayoutInfoLattice *> results);
-
-  void visitPrefetchNdOp(xegpu::PrefetchNdOp prefetch,
-                         ArrayRef<LayoutInfoLattice *> operands,
-                         ArrayRef<const LayoutInfoLattice *> results);
-
-  void visitVectorMultiReductionOp(vector::MultiDimReductionOp reduction,
-                                   ArrayRef<LayoutInfoLattice *> operands,
-                                   ArrayRef<const LayoutInfoLattice *> results);
-
-public:
-  LayoutInfoPropagation(DataFlowSolver &solver,
-                        SymbolTableCollection &symbolTable)
-      : SparseBackwardDataFlowAnalysis(solver, symbolTable) {}
-  using SparseBackwardDataFlowAnalysis::SparseBackwardDataFlowAnalysis;
-
-  LogicalResult
-  visitOperation(Operation *op, ArrayRef<LayoutInfoLattice *> operands,
-                 ArrayRef<const LayoutInfoLattice *> results) override;
-
-  void visitBranchOperand(OpOperand &operand) override {};
-
-  void visitCallOperand(OpOperand &operand) override {};
-
-  void visitExternalCall(CallOpInterface call,
-                         ArrayRef<LayoutInfoLattice *> operands,
-                         ArrayRef<const LayoutInfoLattice *> results) override {
-  };
-
-  void setToExitState(LayoutInfoLattice *lattice) override {
-    (void)lattice->meet(LayoutInfo());
-  }
-};
-} // namespace
-
-LogicalResult LayoutInfoPropagation::visitOperation(
-    Operation *op, ArrayRef<LayoutInfoLattice *> operands,
-    ArrayRef<const LayoutInfoLattice *> results) {
-  TypeSwitch<Operation *>(op)
-      .Case<xegpu::DpasOp>(
-          [&](auto dpasOp) { visitDpasOp(dpasOp, operands, results); })
-      .Case<xegpu::StoreNdOp>(
-          [&](auto storeNdOp) { visitStoreNdOp(storeNdOp, operands, results); })
-      .Case<xegpu::StoreScatterOp>([&](auto storeScatterOp) {
-        visitStoreScatterOp(storeScatterOp, operands, results);
-      })
-      .Case<xegpu::LoadNdOp>(
-          [&](auto loadNdOp) { visitLoadNdOp(loadNdOp, operands, results); })
-      .Case<xegpu::LoadGatherOp>([&](auto loadGatherOp) {
-        visitLoadGatherOp(loadGatherOp, operands, results);
-      })
-      .Case<xegpu::CreateDescOp>([&](auto createDescOp) {
-        visitCreateDescOp(createDescOp, operands, results);
-      })
-      .Case<xegpu::UpdateNdOffsetOp>([&](auto updateNdOffsetOp) {
-        visitUpdateNdOffsetOp(updateNdOffsetOp, operands, results);
-      })
-      .Case<xegpu::PrefetchNdOp>([&](auto prefetchNdOp) {
-        visitPrefetchNdOp(prefetchNdOp, operands, results);
-      })
-      // No need to propagate the layout to operands in CreateNdDescOp because
-      // they are scalars (offsets, sizes, etc.).
-      .Case<xegpu::CreateNdDescOp>([&](auto createNdDescOp) {})
-      .Case<vector::TransposeOp>([&](auto transposeOp) {
-        visitTransposeOp(transposeOp, operands, results);
-      })
-      .Case<vector::BitCastOp>([&](auto bitcastOp) {
-        visitVectorBitcastOp(bitcastOp, operands, results);
-      })
-      .Case<vector::MultiDimReductionOp>([&](auto reductionOp) {
-        visitVectorMultiReductionOp(reductionOp, operands, results);
-      })
-      // All other ops.
-      .Default([&](Operation *op) {
-        for (const LayoutInfoLattice *r : results) {
-          for (LayoutInfoLattice *operand : operands) {
-            // Propagate the layout of the result to the operand.
-            if (r->getValue().isAssigned())
-              meet(operand, *r);
-          }
-        }
-      });
-  // Add a dependency from each result to program point after the operation.
-  for (const LayoutInfoLattice *r : results) {
-    addDependency(const_cast<LayoutInfoLattice *>(r), getProgramPointAfter(op));
-  }
-  return success();
-}
-
-void LayoutInfoPropagation::visitPrefetchNdOp(
-    xegpu::PrefetchNdOp prefetch, ArrayRef<LayoutInfoLattice *> operands,
-    ArrayRef<const LayoutInfoLattice *> results) {
-  // Here we assign the default layout to the tensor descriptor operand of
-  // prefetch.
-  auto tdescTy = prefetch.getTensorDescType();
-  auto prefetchLayout = getDefaultLayoutInfo(
-      VectorType::get(tdescTy.getShape(), tdescTy.getElementType()));
-  // Propagate the layout to the source tensor descriptor.
-  propagateIfChanged(operands[0], operands[0]->meet(prefetchLayout));
-}
-
-void LayoutInfoPropagation::visitVectorMultiReductionOp(
-    vector::MultiDimReductionOp reduction,
-    ArrayRef<LayoutInfoLattice *> operands,
-    ArrayRef<const LayoutInfoLattice *> results) {
-  // The layout of the result must be present.
-  LayoutInfo resultLayout = results[0]->getValue();
-  if (!resultLayout.isAssigned())
-    return;
-  // We only consider 2D -> 1D reductions at this point.
-  assert(resultLayout.getLayout().size() == 1 &&
-         "Expected 1D layout for reduction result.");
-  // Given that the result is 1D, the layout of the operand should be 2D with
-  // default layout.
-  LayoutInfo operandLayout = getDefaultLayoutInfo(2);
-  propagateIfChanged(operands[0], operands[0]->meet(operandLayout));
-  // Accumulator should have the same layout as the result.
-  propagateIfChanged(operands[1], operands[1]->meet(resultLayout));
-}
-
-/// Propagate the layout of the result tensor to the source tensor descriptor in
-/// UpdateNdOffsetOp.
-void LayoutInfoPropagation::visitUpdateNdOffsetOp(
-    xegpu::UpdateNdOffsetOp updateNdOffset,
-    ArrayRef<LayoutInfoLattice *> operands,
-    ArrayRef<const LayoutInfoLattice *> results) {
-  // The layout of the result must be present.
-  LayoutInfo resultLayout = results[0]->getValue();
-  if (!resultLayout.isAssigned())
-    return;
-  // Propagate the layout to the source operand.
-  propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
-}
-
-/// Set the layouts for DPAS A, B, and C operands.
-void LayoutInfoPropagation::visitDpasOp(
-    xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands,
-    ArrayRef<const LayoutInfoLattice *> results) {
-  VectorType aTy = dpas.getLhsType();
-  VectorType bTy = dpas.getRhsType();
-  propagateIfChanged(operands[0],
-                     operands[0]->meet(getLayoutInfoForDPASOperand(aTy, 0)));
-  propagateIfChanged(operands[1],
-                     operands[1]->meet(getLayoutInfoForDPASOperand(bTy, 1)));
-  if (operands.size() > 2) {
-    VectorType cTy = dpas.getAccType();
-    propagateIfChanged(operands[2],
-                       operands[2]->meet(getLayoutInfoForDPASOperand(cTy, 2)));
-  }
-}
-
-/// Set the layout for the value and tensor descriptor operands in StoreNdOp.
-void LayoutInfoPropagation::visitStoreNdOp(
-    xegpu::StoreNdOp store, ArrayRef<LayoutInfoLattice *> operands,
-    ArrayRef<const LayoutInfoLattice *> results) {
-  LayoutInfo storeLayout = getDefaultLayoutInfo(store.getValueType());
-  // Both operands should have the same layout
-  for (LayoutInfoLattice *operand : operands) {
-    propagateIfChanged(operand, operand->meet(storeLayout));
-  }
-}
-
-/// Propagate the layout of the value to the tensor descriptor operand in
-/// LoadNdOp.
-void LayoutInfoPropagation::visitLoadNdOp(
-    xegpu::LoadNdOp load, ArrayRef<LayoutInfoLattice *> operands,
-    ArrayRef<const LayoutInfoLattice *> results) {
-  LayoutInfo valueLayout = results[0]->getValue();
-  // Need the layout of the value to propagate to the tensor descriptor.
-  if (!valueLayout.isAssigned())
-    return;
-  LayoutInfo tensorDescLayout = valueLayout;
-  // LoadNdOp has the transpose effect. However, at the stage of this analysis
-  // this effect is not expected and should be abstracted away. Emit a warning.
-  if (auto transpose = load.getTranspose()) {
-    load.emitWarning("Transpose effect is not expected for LoadNdOp at "
-                     "LayoutInfoPropagation stage.");
-    tensorDescLayout = valueLayout.getTransposedLayout(transpose.value());
-  }
-  // Propagate the new layout to the tensor descriptor operand.
-  propagateIfChanged(operands[0], operands[0]->meet(tensorDescLayout));
-}
-
-/// For vector::TransposeOp, the layout of the result is transposed and
-/// propagated to the operand.
-void LayoutInfoPropagation::visitTransposeOp(
-    vector::TransposeOp transpose, ArrayRef<LayoutInfoLattice *> operands,
-    ArrayRef<const LayoutInfoLattice *> results) {
-  // Need the layout of transpose result to propagate to the operands.
-  LayoutInfo resultLayout = results[0]->getValue();
-  if (!resultLayout.isAssigned())
-    return;
-  LayoutInfo newLayout =
-      resultLayout.getTransposedLayout(transpose.getPermutation());
-  // Propagate the new layout to the vector operand.
-  propagateIfChanged(operands[0], operands[0]->meet(newLayout));
-}
-
-/// For vector::BitCastOp, the lane_data of the source layout is changed based
-/// on the bit width of the source and result types.
-void LayoutInfoPropagation::visitVectorBitcastOp(
-    vector::BitCastOp bitcast, ArrayRef<LayoutInfoLattice *> operands,
-    ArrayRef<const LayoutInfoLattice *> results) {
-  // Need the layout of bitcast result to propagate to the operands.
-  LayoutInfo resultLayout = results[0]->getValue();
-  if (!resultLayout.isAssigned())
-    return;
-  int inElemTyBitWidth =
-      bitcast.getSourceVectorType().getElementType().getIntOrFloatBitWidth();
-  int outElemTyBitWidth =
-      bitcast.getResultVectorType().getElementType().getIntOrFloatBitWidth();
-
-  // LaneLayout does not change.
-  const LaneLayout &newLaneLayout = resultLayout.getLayout();
-  const LaneData &currData = resultLayout.getData();
-  LaneData newLaneData;
-  // It's a widening bitcast
-  if (inElemTyBitWidth < outElemTyBitWidth) {
-    int ratio = outElemTyBitWidth / inElemTyBitWidth;
-    newLaneData = resultLayout.getData()[0] == 1
-                      ? LaneData({1, currData[1] * ratio})
-                      : LaneData({currData[0] * ratio, 1});
-  } else {
-    // It's a narrowing bitcast
-    int ratio = inElemTyBitWidth / outElemTyBitWidth;
-    newLaneData = resultLayout.getData()[0] == 1
-                      ? LaneData({1, currData[1] / ratio})
-                      : LaneData({currData[0] / ratio, 1});
-  }
-
-  propagateIfChanged(operands[0],
-                     operands[0]->meet(LayoutInfo(newLaneLayout, newLaneData)));
-}
-
-/// Propagate the layout of the result to the tensor descriptor and mask
-/// operands in LoadGatherOp.
-void LayoutInfoPropagation::visitLoadGatherOp(
-    xegpu::LoadGatherOp load, ArrayRef<LayoutInfoLattice *> operands,
-    ArrayRef<const LayoutInfoLattice *> results) {
-  LayoutInfo valueLayout = results[0]->getValue();
-  // Need the layout of the value to propagate to the tensor descriptor.
-  if (!valueLayout.isAssigned())
-    return;
-
-  LayoutInfo tensorDescLayout = valueLayout;
-  if (load.getTranspose()) {
-    // LoadGatherOp has the transpose effect. However, at the stage of this
-    // analyis this effect is not expected and should be abstracted away. Emit
-    // a warning.
-    load.emitWarning("Transpose effect is not expected for LoadGatherOp at "
-                     "LayoutInfoPropagation stage.");
-    tensorDescLayout = valueLayout.getTransposedLayout({1, 0});
-  }
-  // Mask operand should have 1D default layout.
-  LayoutInfo maskLayout = getDefaultLayoutInfo(1);
-  // Propagate the new layout to the tensor descriptor operand.
-  propagateIfChanged(operands[0], operands[0]->meet(tensorDescLayout));
-  // Propagate the new layout to the mask operand.
-  propagateIfChanged(operands[1], operands[1]->meet(maskLayout));
-}
-
-/// Propagate the layout of the descriptor to the vector offset operand in
-/// CreateDescOp.
-void LayoutInfoPropagation::visitCreateDescOp(
-    xegpu::CreateDescOp createDesc, ArrayRef<LayoutInfoLattice *> operands,
-    ArrayRef<const LayoutInfoLattice *> results) {
-  LayoutInfo descLayout = results[0]->getValue();
-  // Need the layout of the descriptor to propagate to the operands.
-  if (!descLayout.isAssigned())
-    return;
-  // For offset operand propagate 1D default layout.
-  LayoutInfo layout = getDefaultLayoutInfo(1);
-  propagateIfChanged(operands[1], operands[1]->meet(layout));
-}
-
-/// Set the layout for the value, tensor descriptor, and mask operands in the
-/// StoreScatterOp.
-void LayoutInfoPropagation::visitStoreScatterOp(
-    xegpu::StoreScatterOp storeScatter, ArrayRef<LayoutInfoLattice *> operands,
-    ArrayRef<const LayoutInfoLattice *> results) {
-  // Currently, for 2D StoreScatterOp we expect that the height dimension of
-  // the tensor descriptor is equal to the subgroup size. This is ensured by
-  // the op verifier.
-  ArrayRef<int64_t> tdescShape = storeScatter.getTensorDescType().getShape();
-  if (tdescShape.size() > 1)
-    assert(
-        tdescShape[0] == subgroupSize &&
-        "Expected the first dimension of 2D tensor descriptor to be equal to "
-        "subgroup size.");
-
-  LayoutInfo valueLayout = getDefaultLayoutInfo(storeScatter.getValueType());
-  LayoutInfo storeScatterLayout = valueLayout;
-  if (storeScatter.getTranspose()) {
-    // StoreScatteOp allows transpose effect. However, at the stage of this
-    // analyis this effect is not expected and should be abstracted away. Emit
-    // a warning.
-    storeScatter.emitWarning("Transpose effect is not expected for "
-                             "StoreScatterOp at LayoutInfoPropagation stage.");
-    storeScatterLayout = valueLayout.getTransposedLayout({1, 0});
-  }
-  // Propagate the value layout.
-  propagateIfChanged(operands[0], operands[0]->meet(valueLayout));
-  // Propagate the tensor descriptor layout.
-  propagateIfChanged(operands[1], operands[1]->meet(storeScatterLayout));
-  // Use default 1D layout for mask operand.
-  LayoutInfo maskLayout = getDefaultLayoutInfo(1);
-  propagateIfChanged(operands[2], operands[2]->meet(maskLayout));
-}
-
-namespace {
-
-//===----------------------------------------------------------------------===//
-// RunLayoutInfoPropagation
-//===----------------------------------------------------------------------===//
-
-/// Driver class for running the LayoutInfoPropagation analysis.
-class RunLayoutInfoPropagation {
-public:
-  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(RunLayoutInfoPropagation)
-
-  RunLayoutInfoPropagation(Operation *op) : target(op) {
-    SymbolTableCollection symbolTable;
-    solver.load<DeadCodeAnalysis>();
-    solver.load<SparseConstantPropagation>();
-    solver.load<LayoutInfoPropagation>(symbolTable);
-    (void)solver.initializeAndRun(op);
-  }
-
-  LayoutInfo getLayoutInfo(Value val);
-
-  void printAnalysisResult(llvm::raw_ostream &os);
-
-private:
-  DataFlowSolver solver;
-  const Operation *target;
-};
-} // namespace
-
-LayoutInfo RunLayoutInfoPropagation::getLayoutInfo(Value val) {
-  auto *state = solver.lookupState<LayoutInfoLattice>(val);
-  if (!state)
-    return {};
-  return state->getValue();
-}
-
-void RunLayoutInfoPropagation::printAnalysisResult(llvm::raw_ostream &os) {
-  auto printFunctionResult = [&](FunctionOpInterface funcOp) {
-    os << "function: " << funcOp.getName() << ":\n";
-    // Function arguments
-    for (BlockArgument arg : funcOp.getArguments()) {
-      LayoutInfo layout = getLayoutInfo(arg);
-      os << "argument: " << arg << "\n";
-      os << "layout  : ";
-      layout.print(os);
-      os << "\n";
-    }
-    // Function ops
-    funcOp.walk([&](Operation *op) {
-      // Skip ops that do not have results
-      if (op->getResults().empty())
-        return;
-      os << "op    : ";
-      // For control-flow ops, print the op name only.
-      if (isa<BranchOpInterface>(op) || isa<RegionBranchOpInterface>(op))
-        os << op->getName();
-      else
-        op->print(os);
-      os << "\n";
-      // Print the layout for each result.
-      for (auto [i, r] : llvm::enumerate(op->getResults())) {
-        LayoutInfo layout = getLayoutInfo(r);
-        os << "layout for result #" << i << ": ";
-        layout.print(os);
-        os << "\n";
-      }
-    });
-  };
-
-  SmallVector<FunctionOpInterface> funcOps;
-  if (auto modOp = dyn_cast<ModuleOp>(target)) {
-    for (auto funcOp : modOp.getOps<FunctionOpInterface>()) {
-      funcOps.push_back(funcOp);
-    }
-    // Collect all GpuFuncOps in the module.
-    for (auto gpuModOp : modOp.getOps<gpu::GPUModuleOp>()) {
-      for (auto gpuFuncOp : gpuModOp.getOps<FunctionOpInterface>()) {
-        funcOps.push_back(gpuFuncOp);
-      }
-    }
-  }
-  // Print the analysis result for each function.
-  for (FunctionOpInterface funcOp : funcOps) {
-    printFunctionResult(funcOp);
-  }
-}
-
-// namespace {
-
-//===----------------------------------------------------------------------===//
-// LayoutAttrAssignment
-//===----------------------------------------------------------------------===//
-// template <typename OpTy>
-// class UpdateTensorDescType : public OpConversionPattern<OpTy> {
-// public:
-//   UpdateTensorDescType(MLIRContext *context,
-//                        function_ref<xegpu::LayoutAttr(Value)>
-//                        getLayoutOfValue, TypeConverter &typeConverter,
-//                        PatternBenefit benefit = 1)
-//       : OpConversionPattern<OpTy>(typeConverter, context, benefit),
-//         getLayoutOfValue(getLayoutOfValue) {}
-//   using OpConversionPattern<OpTy>::OpConversionPattern;
-//   LogicalResult
-//   matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
-//                   ConversionPatternRewriter &rewriter) const override {
-//     // Op must have single result.
-//     if (op->getNumResults() != 1)
-//       return failure();
-//     Type resultType = op->getResult(0).getType();
-//     // Result type must be a tensor descriptor type.
-//     if (!isa<xegpu::TensorDescType>(resultType)) {
-//       LLVM_DEBUG(DBGS() << "Result type is not a tensor descriptor type: "
-//                         << resultType << "\n");
-//       return failure();
-//     }
-//     auto assignedLayout = getLayoutOfValue(op.getResult());
-//     if (!assignedLayout) {
-//       LLVM_DEBUG(DBGS() << "No layout assigned for " << *op << "\n");
-//       return failure();
-//     }
-//     // Get the original tensor descriptor type.
-//     auto origTensorDescTy = dyn_cast<xegpu::TensorDescType>(resultType);
-//     auto newTensorDescTy = xegpu::TensorDescType::get(
-//         origTensorDescTy.getContext(), origTensorDescTy.getShape(),
-//         origTensorDescTy.getElementType(), origTensorDescTy.getEncoding(),
-//         assignedLayout);
-//     rewriter.replaceOpWithNewOp<OpTy>(op, newTensorDescTy,
-//                                       adaptor.getOperands(), op->getAttrs());
-//     return success();
-//   }
-
-// private:
-//   function_ref<xegpu::LayoutAttr(Value)> getLayoutOfValue;
-// };
-// /// This class is responsible for assigning the layout attributes to the ops
-// and
-// /// their users based on the layout propagation analysis result.
-// class LayoutAttrAssignment {
-// public:
-//   LayoutAttrAssignment(Operation *top,
-//                        function_ref<LayoutInfo(Value)> getLayout)
-//       : getAnalysisResult(getLayout), top(top) {}
-
-//   LogicalResult run();
-
-// private:
-//   LogicalResult assign(Operation *op);
-//   void assignToUsers(Value v, xegpu::LayoutAttr layout);
-//   xegpu::LayoutAttr getLayoutAttrForValue(Value v);
-//   LogicalResult resolveConflicts();
-//   // Callable to get the layout of a value based on the layout propagation
-//   // analysis.
-//   function_ref<LayoutInfo(Value)> getAnalysisResult;
-//   Operation *top;
-// };
-
-// } // namespace
-
-// /// Helper to assign the layout attribute to the users of the value.
-// void LayoutAttrAssignment::assignToUsers(Value v, xegpu::LayoutAttr layout) {
-//   for (OpOperand &user : v.getUses()) {
-//     Operation *owner = user.getOwner();
-//     unsigned operandNumber = user.getOperandNumber();
-//     // Use a generic name for ease of querying the layout attribute later.
-//     std::string attrName =
-//         operandLayoutNamePrefix + std::to_string(operandNumber);
-//     owner->setAttr(attrName, layout);
-//   }
-// }
-
-// /// Convert the layout assigned to a value to xegpu::LayoutAttr.
-// xegpu::LayoutAttr LayoutAttrAssignment::getLayoutAttrForValue(Value v) {
-//   llvm::errs() << "getLayoutAttrForValue: " << v << "\n";
-//   LayoutInfo layout = getAnalysisResult(v);
-//   if (!layout.isAssigned()) {
-//     llvm::errs() << "No layout assigned for value\n";
-//     return {};
-//   }
-//   SmallVector<int, 2> laneLayout, laneData;
-//   for (auto [layout, data] : llvm::zip_equal(layout.getLayoutAsArrayRef(),
-//                                              layout.getDataAsArrayRef())) {
-//     laneLayout.push_back(static_cast<int>(layout));
-//     laneData.push_back(static_cast<int>(data));
-//   }
-//   llvm::errs() << "return layout\n";
-//   return xegpu::LayoutAttr::get(v.getContext(), laneLayout, laneData);
-// }
-
-// /// Assign xegpu::LayoutAttr to the op and its users. The layout is assigned
-// /// based on the layout propagation analysis result.
-// LogicalResult LayoutAttrAssignment::assign(Operation *op) {
-//   // For function ops, propagate the function argument layout to the users.
-//   if (auto func = dyn_cast<FunctionOpInterface>(op)) {
-//     for (BlockArgument arg : func.getArguments()) {
-//       xegpu::LayoutAttr layoutInfo = getLayoutAttrForValue(arg);
-//       if (layoutInfo) {
-//         assignToUsers(arg, layoutInfo);
-//       }
-//     }
-//     return success();
-//   }
-//   // If no results, move on.
-//   if (op->getNumResults() == 0)
-//     return success();
-//   // If all the results are scalars, move on.
-//   if (llvm::all_of(op->getResultTypes(),
-//                    [](Type t) { return t.isIntOrIndexOrFloat(); }))
-//     return success();
-//   // If the op has more than one result and at least one result is a tensor
-//   // descriptor, exit. This case is not supported yet.
-//   // TODO: Support this case.
-//   if (op->getNumResults() > 1 && llvm::any_of(op->getResultTypes(), [](Type
-//   t) {
-//         return isa<xegpu::TensorDescType>(t);
-//       })) {
-//     LLVM_DEBUG(
-//         DBGS() << op->getName()
-//                << " op has more than one result and at least one is a tensor
-//                "
-//                   "descriptor. This case is not handled.\n");
-//     return failure();
-//   }
-//   // If the result is a tensor descriptor, attach the layout to the tensor
-//   // descriptor itself.
-//   if (auto tensorDescTy =
-//           dyn_cast<xegpu::TensorDescType>(op->getResultTypes()[0])) {
-//     xegpu::LayoutAttr layoutInfo = getLayoutAttrForValue(op->getResult(0));
-//     if (!layoutInfo) {
-//       LLVM_DEBUG(DBGS() << "No layout for result of " << *op << "\n");
-//       return failure();
-//     }
-
-//     // Clone the op, attach the layout to the result tensor descriptor, and
-//     // remove the original op.
-//     OpBuilder builder(op);
-//     Operation *newOp = builder.clone(*op);
-//     auto newTensorDescTy = xegpu::TensorDescType::get(
-//         tensorDescTy.getContext(), tensorDescTy.getShape(),
-//         tensorDescTy.getElementType(), tensorDescTy.getEncoding(),
-//         layoutInfo);
-//     newOp->getResult(0).setType(newTensorDescTy);
-//     op->replaceAllUsesWith(newOp->getResults());
-//     op->erase();
-//     return success();
-//   }
-//   // Otherwise simply attach the layout to the op itself.
-//   for (auto [i, r] : llvm::enumerate(op->getResults())) {
-//     xegpu::LayoutAttr layoutInfo = getLayoutAttrForValue(r);
-//     if (layoutInfo) {
-//       std::string attrName = resultLayoutNamePrefix + std::to_string(i);
-//       op->setAttr(attrName, layoutInfo);
-//       // Attach the layout attribute to the users of the result.
-//       assignToUsers(r, layoutInfo);
-//     }
-//   }
-//   return success();
-// }
-
-// /// Walk the IR and attach xegpu::LayoutAttr to all ops and their users.
-// LogicalResult LayoutAttrAssignment::run() {
-//   // auto walkResult = top->walk([&](Operation *op) {
-//   //   if (failed(assign(op)))
-//   //     return WalkResult::interrupt();
-//   //   return WalkResult::advance();
-//   // });
-
-//   // if (walkResult.wasInterrupted())
-//   //   return failure();
-//   // apply the UpdateTensorDescType pattern to all ops
-//   // RewritePatternSet patterns(top->getContext());
-//   // patterns.add<UpdateTensorDescType>(
-//   //     top->getContext(), [&](Value v) -> xegpu::LayoutAttr {
-//   //       llvm::errs() << "invoking callback for value\n";
-//   //       return getLayoutAttrForValue(v);
-//   //     });
-//   // if (failed(applyPatternsGreedily(top, std::move(patterns))))
-//   //   return failure();
-
-//   return resolveConflicts();
-// }
-
-// /// TODO: Implement the layout conflict resolution. This must ensure mainly
-// two
-// /// things:
-// /// 1) Is a given layout supported by the op? (need to query the target
-// ///    HW info). Otherwise can we achieve this layout using a layout
-// conversion?
-// /// 2) Do all the operands have the required layout? If not, can it
-// ///    be resolved using a layout conversion?
-// LogicalResult LayoutAttrAssignment::resolveConflicts() { return success(); }
-using GetLayoutCallbackFnTy = function_ref<xegpu::LayoutAttr(Value)>;
-static void updateOp(mlir::OpBuilder &builder, mlir::Operation *op,
-                     GetLayoutCallbackFnTy getLayoutOfValue) {
-
-  // Iterate over all the results.
-  for (OpResult result : op->getResults()) {
-    Type resultType = result.getType();
-    // Layouts are needed only for vector and tensor descriptor types.
-    if (!isa<VectorType, xegpu::TensorDescType>(resultType))
-      continue;
-    // If the result has any users, we expect it to have a layout.
-    xegpu::LayoutAttr layout = getLayoutOfValue(result);
-    if (!layout && result.getNumUses() > 0) {
-      LLVM_DEBUG(DBGS() << "Expecting layout for result: " << result
-                        << " but got none.\n");
-      continue;
-    }
-    if (auto tensorDescTy = dyn_cast<xegpu::TensorDescType>(resultType)) {
-      // TODO: Handle error.
-      auto typeWithLayout = xegpu::TensorDescType::get(
-          tensorDescTy.getContext(), tensorDescTy.getShape(),
-          tensorDescTy.getElementType(), tensorDescTy.getEncoding(), layout);
-      result.setType(typeWithLayout);
-      continue;
-    }
-    // If the result is a vector type, add a temporary layout attribute to the
-    // op.
-    std::string resultLayoutName =
-        resultLayoutNamePrefix + std::to_string(result.getResultNumber());
-    op->setAttr(resultLayoutName, layout);
-    // Update all users of the result with the layout.
-    for (OpOperand &user : result.getUses()) {
-      Operation *owner = user.getOwner();
-      unsigned operandNumber = user.getOperandNumber();
-      // Add temorary layout attribute at the user op.
-      std::string attrName =
-          operandLayoutNamePrefix + std::to_string(operandNumber);
-      owner->setAttr(attrName, layout);
-    }
-  }
-}
-static void updateBranchTerminatorOpInterface(
-    mlir::OpBuilder &builder,
-    mlir::RegionBranchTerminatorOpInterface terminator,
-    GetLayoutCallbackFnTy getLayoutOfValue) {
-  if (!mlir::isa<mlir::RegionBranchOpInterface>(terminator->getParentOp()))
-    return;
-
-  llvm::SmallVector<mlir::RegionSuccessor> successors;
-  llvm::SmallVector<mlir::Attribute> operands(terminator->getNumOperands(),
-                                              nullptr);
-  terminator.getSuccessorRegions(operands, successors);
-
-  for (mlir::RegionSuccessor &successor : successors) {
-    if (!successor.isParent())
-      continue;
-
-    mlir::OperandRange operands = terminator.getSuccessorOperands(successor);
-    mlir::ValueRange inputs = successor.getSuccessorInputs();
-    for (auto [operand, input] : llvm::zip(operands, inputs)) {
-      // print arg and inp
-      // llvm::errs() << "arg: " << operand << ", inp: " << input << "\n";
-      Type inputType = input.getType();
-      if (!isa<xegpu::TensorDescType>(inputType))
-        continue;
-      xegpu::LayoutAttr inputLayout = getLayoutOfValue(input);
-      xegpu::LayoutAttr operandLayout = getLayoutOfValue(operand);
-
-      if (!operandLayout) {
-        LLVM_DEBUG(DBGS() << "Expecting layout for region successor operand : "
-                          << operand << " but got none.\n");
-        continue;
-      }
-
-      if (inputLayout && inputLayout != operandLayout) {
-        LLVM_DEBUG(
-            DBGS()
-            << "Conflicting layouts for region successor operand and input: "
-            << inputLayout << " vs " << operandLayout << "\n");
-        continue;
-      }
-      // Get tensor descriptor type with the layout.
-      auto tdescTy = dyn_cast<xegpu::TensorDescType>(inputType);
-      auto newTdescTy = xegpu::TensorDescType::get(
-          tdescTy.getContext(), tdescTy.getShape(), tdescTy.getElementType(),
-          tdescTy.getEncoding(), operandLayout);
-      input.setType(newTdescTy);
-    }
-  }
-}
-static void updateBranchOpInterface(mlir::OpBuilder &builder,
-                                    mlir::RegionBranchOpInterface branch,
-                                    GetLayoutCallbackFnTy getLayoutOfValue) {
-  mlir::Operation *op = branch.getOperation();
-  llvm::SmallVector<mlir::RegionSuccessor> successors;
-  llvm::SmallVector<mlir::Attribute> operands(op->getNumOperands(), nullptr);
-  branch.getEntrySuccessorRegions(operands, successors);
-  DenseMap<Value, xegpu::LayoutAttr> resultToLayouts;
-  mlir::ValueRange results = op->getResults();
-
-  for (mlir::RegionSuccessor &successor : successors) {
-    if (successor.isParent())
-      continue;
-
-    mlir::OperandRange operands = branch.getEntrySuccessorOperands(successor);
-    mlir::ValueRange inputs = successor.getSuccessorInputs();
-
-    for (auto [operand, input, result] : llvm::zip(operands, inputs, results)) {
-      Type inputType = input.getType();
-      if (!isa<xegpu::TensorDescType>(inputType))
-        continue;
-      xegpu::LayoutAttr blockArgLayout = getLayoutOfValue(input);
-      xegpu::LayoutAttr initArgLayout = getLayoutOfValue(operand);
-
-      if (!blockArgLayout || !initArgLayout) {
-        LLVM_DEBUG(DBGS() << "No layout assigned for block arg: " << input
-                          << " or init arg: " << operand << "\n");
-        continue;
-      }
-
-      // TOOD: We expect these two to match. Data flow analysis will ensure
-      // this.
-      assert(blockArgLayout == initArgLayout &&
-             "Expexing block arg and init arg to have the same layout.");
-      // Get tensor descriptor type with the layout.
-      auto tdescTy = dyn_cast<xegpu::TensorDescType>(inputType);
-      auto newTdescTy = xegpu::TensorDescType::get(
-          tdescTy.getContext(), tdescTy.getShape(), tdescTy.getElementType(),
-          tdescTy.getEncoding(), blockArgLayout);
-      input.setType(newTdescTy);
-      // Store the layout for the result.
-      if (resultToLayouts.count(result) != 0 &&
-          resultToLayouts[result] != blockArgLayout) {
-        LLVM_DEBUG(DBGS() << "Conflicting layouts for result: " << result
-                          << " - " << resultToLayouts[result] << " vs "
-                          << blockArgLayout << "\n");
-      } else {
-        resultToLayouts[result] = blockArgLayout;
-      }
-    }
-  }
-  for (auto [i, r] : llvm::enumerate(op->getResults())) {
-    Type resultType = r.getType();
-    if (!isa<xegpu::TensorDescType, VectorType>(resultType))
-      continue;
-    xegpu::LayoutAttr layout = getLayoutOfValue(r);
-    if (!layout)
-      layout = resultToLayouts[r];
-    if (!layout) {
-      LLVM_DEBUG(DBGS() << "No layout assigned for vector/tensor desc result:"
-                        << r << "\n");
-      continue;
-    }
-    if (auto tensorDescTy = dyn_cast<xegpu::TensorDescType>(resultType)) {
-      auto newTdescTy = xegpu::TensorDescType::get(
-          tensorDescTy.getContext(), tensorDescTy.getShape(),
-          tensorDescTy.getElementType(), tensorDescTy.getEncoding(), layout);
-      r.setType(newTdescTy);
-      continue;
-    }
-    // If the result is a vector type, add a temporary layout attribute to
-    // the op.
-    std::string resultLayoutName =
-        resultLayoutNamePrefix + std::to_string(r.getResultNumber());
-    op->setAttr(resultLayoutName, layout);
-    // Update all users of the result with the layout.
-    for (OpOperand &user : r.getUses()) {
-      Operation *owner = user.getOwner();
-      unsigned operandNumber = user.getOperandNumber();
-      // Add temporary layout attribute at the user op.
-      std::string attrName =
-          operandLayoutNamePrefix + std::to_string(operandNumber);
-      owner->setAttr(attrName, layout);
-    }
-  }
-}
-
-namespace {
-
 //===----------------------------------------------------------------------===//
 // SIMT Distribution Patterns
 //===----------------------------------------------------------------------===//
@@ -1845,46 +833,6 @@ void xegpu::populateXeGPUSubgroupDistributePatterns(
 }
 
 void XeGPUSubgroupDistributePass::runOnOperation() {
-  auto &analyis = getAnalysis<RunLayoutInfoPropagation>();
-  // Print the analysis result and exit. (for testing purposes)
-  if (printOnly) {
-    auto &os = llvm::outs();
-    analyis.printAnalysisResult(os);
-    return;
-  }
-
-  auto getXeGPULayoutForValue = [&](Value val) -> xegpu::LayoutAttr {
-    LayoutInfo layout = analyis.getLayoutInfo(val);
-    if (!layout.isAssigned()) {
-      return {};
-    }
-    SmallVector<int, 2> laneLayout, laneData;
-    for (auto [layout, data] : llvm::zip_equal(layout.getLayoutAsArrayRef(),
-                                               layout.getDataAsArrayRef())) {
-      laneLayout.push_back(static_cast<int>(layout));
-      laneData.push_back(static_cast<int>(data));
-    }
-    return xegpu::LayoutAttr::get(val.getContext(), laneLayout, laneData);
-  };
-
-  mlir::OpBuilder builder(&getContext());
-  Operation *op = getOperation();
-  op->walk([&](mlir::Block *block) {
-    for (mlir::Operation &op : llvm::reverse(block->getOperations())) {
-      if (auto terminator =
-              mlir::dyn_cast<mlir::RegionBranchTerminatorOpInterface>(op)) {
-        updateBranchTerminatorOpInterface(builder, terminator,
-                                          getXeGPULayoutForValue);
-        continue;
-      }
-
-      if (auto iface = mlir::dyn_cast<mlir::RegionBranchOpInterface>(op)) {
-        updateBranchOpInterface(builder, iface, getXeGPULayoutForValue);
-        continue;
-      }
-      updateOp(builder, &op, getXeGPULayoutForValue);
-    }
-  });
 
   // Move all operations of a GPU function inside gpu.warp_execute_on_lane_0
   // operation.

>From 92c23f189b06d0dd5df702774e5788fd53c1d67b Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Thu, 5 Jun 2025 16:40:23 +0000
Subject: [PATCH 14/32] fix test

---
 .../Dialect/XeGPU/subgroup-distribution.mlir  | 252 +++++++++---------
 1 file changed, 125 insertions(+), 127 deletions(-)

diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribution.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribution.mlir
index b5f6bda26d830..0f236d4e8b9dc 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribution.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribution.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -xegpu-subgroup-distribute -cse -split-input-file %s | FileCheck %s
+// RUN: mlir-opt -xegpu-subgroup-distribute -canonicalize -cse -split-input-file %s | FileCheck %s
 
 // CHECK-LABEL: gpu.func @store_nd_1d
 // CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: memref<16xf32>) {
@@ -7,13 +7,13 @@
 // CHECK: xegpu.store_nd %[[CST]], %[[T0]]  : vector<1xf32>, !xegpu.tensor_desc<16xf32>
 // CHECK: gpu.return
 gpu.module @test {
-gpu.func @store_nd_1d(%arg0: memref<16xf32>){
-  %c0 = arith.constant 0 : index
-  %1 = arith.constant dense<1.000000e+00> : vector<16xf32>
-  %0 = xegpu.create_nd_tdesc %arg0[%c0] : memref<16xf32> -> !xegpu.tensor_desc<16xf32>
-  xegpu.store_nd %1, %0 : vector<16xf32>, !xegpu.tensor_desc<16xf32>
-  gpu.return
-}
+  gpu.func @store_nd_1d(%arg0: memref<16xf32>) {
+    %c0 = arith.constant 0 : index
+    %cst = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<1.000000e+00> : vector<16xf32>
+    %0 = xegpu.create_nd_tdesc %arg0[%c0] : memref<16xf32> -> !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
+    xegpu.store_nd %cst, %0  {layout_operand_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
+    gpu.return
+  }
 }
 
 // -----
@@ -23,13 +23,13 @@ gpu.func @store_nd_1d(%arg0: memref<16xf32>){
 // CHECK-DAG: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%{{.*}}] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
 // CHECK: xegpu.store_nd %[[CST]], %[[T0]]  : vector<16xf16>, !xegpu.tensor_desc<16x16xf16>
 gpu.module @test {
-gpu.func @store_nd_2d(%arg0: memref<16x16xf16>){
-  %c0 = arith.constant 0 : index
-  %1 = arith.constant dense<1.000000e+00> : vector<16x16xf16>
-  %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
-  xegpu.store_nd %1, %0 : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16>
-  gpu.return
-}
+  gpu.func @store_nd_2d(%arg0: memref<16x16xf16>) {
+    %c0 = arith.constant 0 : index
+    %cst = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} dense<1.000000e+00> : vector<16x16xf16>
+    %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    xegpu.store_nd %cst, %0  {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    gpu.return
+  }
 }
 
 
@@ -42,14 +42,14 @@ gpu.func @store_nd_2d(%arg0: memref<16x16xf16>){
 // CHECK-DAG: %[[T2:.*]] = xegpu.create_nd_tdesc %[[ARG1]][%{{.*}}] : memref<16xf32> -> !xegpu.tensor_desc<16xf32>
 // CHECK: xegpu.store_nd %[[T1]], %[[T2]]  : vector<1xf32>, !xegpu.tensor_desc<16xf32>
 gpu.module @test {
-gpu.func @load_nd_1d(%arg0: memref<16xf32>, %arg1: memref<16xf32>){
-  %c0 = arith.constant 0 : index
-  %0 = xegpu.create_nd_tdesc %arg0[%c0] : memref<16xf32> -> !xegpu.tensor_desc<16xf32>
-  %1 = xegpu.load_nd %0 : !xegpu.tensor_desc<16xf32> -> vector<16xf32>
-  %2 = xegpu.create_nd_tdesc %arg1[%c0] : memref<16xf32> -> !xegpu.tensor_desc<16xf32>
-  xegpu.store_nd %1, %2 : vector<16xf32>, !xegpu.tensor_desc<16xf32>
-  gpu.return
-}
+  gpu.func @load_nd_1d(%arg0: memref<16xf32>, %arg1: memref<16xf32>) {
+    %c0 = arith.constant 0 : index
+    %0 = xegpu.create_nd_tdesc %arg0[%c0] : memref<16xf32> -> !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
+    %1 = xegpu.load_nd %0  {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} : !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>> -> vector<16xf32>
+    %2 = xegpu.create_nd_tdesc %arg1[%c0] : memref<16xf32> -> !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
+    xegpu.store_nd %1, %2  {layout_operand_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
+    gpu.return
+  }
 }
 
 // -----
@@ -60,14 +60,14 @@ gpu.func @load_nd_1d(%arg0: memref<16xf32>, %arg1: memref<16xf32>){
 // CHECK-DAG: %[[T2:.*]] = xegpu.create_nd_tdesc %[[ARG1]][%{{.*}}] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
 // CHECK: xegpu.store_nd %[[T1]], %[[T2]]  : vector<16xf16>, !xegpu.tensor_desc<16x16xf16>
 gpu.module @test {
-gpu.func @load_nd_2d(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>){
-  %c0 = arith.constant 0 : index
-  %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
-  %1 = xegpu.load_nd %0 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
-  %2 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
-  xegpu.store_nd %1, %2 : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16>
-  gpu.return
-}
+  gpu.func @load_nd_2d(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>) {
+    %c0 = arith.constant 0 : index
+    %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    %1 = xegpu.load_nd %0  {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<16x16xf16>
+    %2 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    xegpu.store_nd %1, %2  {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    gpu.return
+  }
 }
 
 // -----
@@ -81,15 +81,15 @@ gpu.func @load_nd_2d(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>){
 // CHECK-DAG: %[[T5:.*]] = vector.shape_cast %[[T3]] : vector<16x1xf16> to vector<16xf16>
 // CHECK: xegpu.store_nd %[[T5]], %[[T4]]  : vector<16xf16>, !xegpu.tensor_desc<16x16xf16>
 gpu.module @test {
-gpu.func @load_nd_array_length(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>){
-  %c0 = arith.constant 0 : index
-  %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2 : i64>>
-  %1 = xegpu.load_nd %0 : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2 : i64>> -> vector<2x16x16xf16>
-  %2 = vector.extract %1[%c0] : vector<16x16xf16> from vector<2x16x16xf16>
-  %3 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
-  xegpu.store_nd %2, %3 : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16>
-  gpu.return
-}
+  gpu.func @load_nd_array_length(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>) {
+    %c0 = arith.constant 0 : index
+    %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2 : i64>, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    %1 = xegpu.load_nd %0  {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2 : i64>, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<2x16x16xf16>
+    %2 = vector.extract %1[%c0] {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16x16xf16> from vector<2x16x16xf16>
+    %3 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    xegpu.store_nd %2, %3  {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    gpu.return
+  }
 }
 
 // -----
@@ -103,17 +103,17 @@ gpu.func @load_nd_array_length(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16
 // CHECK-DAG: %[[T5:.*]] = xegpu.create_nd_tdesc %[[ARG2]][%{{.*}}] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
 // CHECK: xegpu.store_nd %[[T4]], %[[T5]]  : vector<8xf32>, !xegpu.tensor_desc<8x16xf32>
 gpu.module @test {
-gpu.func @load_dpas_store(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg3: memref<8x16xf32>){
-  %c0 = arith.constant 0 : index
-  %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
-  %1 = xegpu.load_nd %0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
-  %2 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
-  %3 = xegpu.load_nd %2 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
-  %4 = xegpu.dpas %1, %3 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
-  %5 = xegpu.create_nd_tdesc %arg3[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
-  xegpu.store_nd %4, %5 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
-  gpu.return
-}
+  gpu.func @load_dpas_store(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<8x16xf32>) {
+    %c0 = arith.constant 0 : index
+    %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    %1 = xegpu.load_nd %0  {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : !xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<8x16xf16>
+    %2 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>>
+    %3 = xegpu.load_nd %2  {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>> -> vector<16x16xf16>
+    %4 = xegpu.dpas %1, %3 {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, layout_operand_1 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>, layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+    %5 = xegpu.create_nd_tdesc %arg2[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    xegpu.store_nd %4, %5  {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    gpu.return
+  }
 }
 
 
@@ -131,22 +131,21 @@ gpu.func @load_dpas_store(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %ar
 // CHECK-DAG: %[[T7:.*]] = xegpu.create_nd_tdesc %[[ARG2]][{{.*}}] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
 // CHECK: xegpu.store_nd %[[T8]], %[[T7]]  : vector<8xf32>, !xegpu.tensor_desc<8x16xf32>
 gpu.module @test {
-gpu.func @load_dpas_postop_store(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg3: memref<8x16xf32>){
-  %c0 = arith.constant 0 : index
-  %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
-  %1 = xegpu.load_nd %0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
-  %2 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
-  %3 = xegpu.load_nd %2 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
-  %4 = xegpu.dpas %1, %3 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
-  %5 = math.exp %4 : vector<8x16xf32>
-  %6 = xegpu.create_nd_tdesc %arg3[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
-  xegpu.store_nd %5, %6 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
-  gpu.return
-}
+  gpu.func @load_dpas_postop_store(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<8x16xf32>) {
+    %c0 = arith.constant 0 : index
+    %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    %1 = xegpu.load_nd %0  {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : !xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<8x16xf16>
+    %2 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>>
+    %3 = xegpu.load_nd %2  {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>> -> vector<16x16xf16>
+    %4 = xegpu.dpas %1, %3 {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, layout_operand_1 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>, layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+    %5 = math.exp %4 {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<8x16xf32>
+    %6 = xegpu.create_nd_tdesc %arg2[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    xegpu.store_nd %5, %6  {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    gpu.return
+  }
 }
 
 // -----
-gpu.module @test {
 // CHECK-LABEL: gpu.func @create_nd_tdesc_non_memref
 // CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: ui64, %[[ARG1:[0-9a-zA-Z]+]]: ui64, %[[ARG2:[0-9a-zA-Z]+]]: index,
 // CHECK-SAME: %[[ARG3:[0-9a-zA-Z]+]]: index, %[[ARG4:[0-9a-zA-Z]+]]: index,
@@ -155,15 +154,15 @@ gpu.module @test {
 // CHECK: %[[T1:.*]] = xegpu.load_nd %[[T0]]  : !xegpu.tensor_desc<16x16xf16> -> vector<16xf16>
 // CHECK: %[[T2:.*]] = xegpu.create_nd_tdesc %[[ARG1]][{{.*}}], [%[[ARG2]], %[[ARG3]]], [%[[ARG4]], %[[ARG5]]] : ui64 -> !xegpu.tensor_desc<16x16xf16>
 // CHECK: xegpu.store_nd %[[T1]], %[[T2]]  : vector<16xf16>, !xegpu.tensor_desc<16x16xf16>
-gpu.func @create_nd_tdesc_non_memref(%arg0: ui64, %arg1: ui64,
-  %arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: index, %arg7: index) {
-  %c0 = arith.constant 0 : index
-  %0 = xegpu.create_nd_tdesc %arg0 [%c0, %c0], [%arg2, %arg3], [%arg4, %arg5] : ui64 -> !xegpu.tensor_desc<16x16xf16>
-  %1 = xegpu.load_nd %0 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
-  %2 = xegpu.create_nd_tdesc %arg1[%c0, %c0], [%arg2, %arg3], [%arg4, %arg5] : ui64 -> !xegpu.tensor_desc<16x16xf16>
-  xegpu.store_nd %1, %2 : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16>
-  gpu.return
-}
+gpu.module @test {
+  gpu.func @create_nd_tdesc_non_memref(%arg0: ui64, %arg1: ui64, %arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: index, %arg7: index) {
+    %c0 = arith.constant 0 : index
+    %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0], [%arg2, %arg3], [%arg4, %arg5] : ui64 -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    %1 = xegpu.load_nd %0  {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<16x16xf16>
+    %2 = xegpu.create_nd_tdesc %arg1[%c0, %c0], [%arg2, %arg3], [%arg4, %arg5] : ui64 -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    xegpu.store_nd %1, %2  {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    gpu.return
+  }
 }
 
 // -----
@@ -191,31 +190,30 @@ gpu.func @create_nd_tdesc_non_memref(%arg0: ui64, %arg1: ui64,
 // CHECK-NEXT: %[[C_FINAL:.*]] = vector.shape_cast %[[T7]]#0 : vector<8x1xf32> to vector<8xf32>
 // CHECK-NEXT: xegpu.store_nd %[[C_FINAL]], %[[T2]]  : vector<8xf32>, !xegpu.tensor_desc<8x16xf32>
 gpu.module @test {
-gpu.func @gemm_loop(%arg0: memref<1024x1024xbf16>, %arg1: memref<1024x1024xbf16>, %arg2: memref<1024x1024xf32>){
-  %c0 = arith.constant 0 : index
-  %c16 = arith.constant 16 : index
-  %c8 = arith.constant 8 : index
-  %c1024 = arith.constant 1024 : index
-  %0 = gpu.block_id x
-  %1 = gpu.block_id y
-  %2 = arith.muli %0, %c8 : index
-  %3 = arith.muli %1, %c16 : index
-  %4 = xegpu.create_nd_tdesc %arg2[%2, %3] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32>
-  %5 = xegpu.load_nd %4 : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
-  %7 = xegpu.create_nd_tdesc %arg0[%2, %c0] : memref<1024x1024xbf16> -> !xegpu.tensor_desc<8x16xbf16>
-  %8 = xegpu.create_nd_tdesc %arg1[%c0, %3] : memref<1024x1024xbf16> -> !xegpu.tensor_desc<16x16xbf16>
-  %6:3 = scf.for %arg3 = %c0 to %c1024 step %c16 iter_args(%arg4 = %5, %arg5 = %7, %arg6 = %8) -> (vector<8x16xf32>, !xegpu.tensor_desc<8x16xbf16>, !xegpu.tensor_desc<16x16xbf16>) {
-    %9 = xegpu.load_nd %arg5 : !xegpu.tensor_desc<8x16xbf16> -> vector<8x16xbf16>
-    %10 = xegpu.load_nd %arg6 : !xegpu.tensor_desc<16x16xbf16> -> vector<16x16xbf16>
-    %12 = xegpu.update_nd_offset %arg5, [%c0, %c16] : !xegpu.tensor_desc<8x16xbf16>
-    %13 = xegpu.update_nd_offset %arg6, [%c16, %c0] : !xegpu.tensor_desc<16x16xbf16>
-    %11 = xegpu.dpas %9, %10, %arg4 : vector<8x16xbf16>, vector<16x16xbf16>, vector<8x16xf32> -> vector<8x16xf32>
-    scf.yield %11, %12, %13 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xbf16>, !xegpu.tensor_desc<16x16xbf16>
+  gpu.func @gemm_loop(%arg0: memref<1024x1024xbf16>, %arg1: memref<1024x1024xbf16>, %arg2: memref<1024x1024xf32>) {
+    %c0 = arith.constant 0 : index
+    %c16 = arith.constant 16 : index
+    %c8 = arith.constant 8 : index
+    %c1024 = arith.constant 1024 : index
+    %block_id_x = gpu.block_id  x
+    %block_id_y = gpu.block_id  y
+    %0 = arith.muli %block_id_x, %c8 : index
+    %1 = arith.muli %block_id_y, %c16 : index
+    %2 = xegpu.create_nd_tdesc %arg2[%0, %1] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    %3 = xegpu.load_nd %2  {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<8x16xf32>
+    %4 = xegpu.create_nd_tdesc %arg0[%0, %c0] : memref<1024x1024xbf16> -> !xegpu.tensor_desc<8x16xbf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    %5 = xegpu.create_nd_tdesc %arg1[%c0, %1] : memref<1024x1024xbf16> -> !xegpu.tensor_desc<16x16xbf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>>
+    %6:3 = scf.for %arg3 = %c0 to %c1024 step %c16 iter_args(%arg4 = %3, %arg5 = %4, %arg6 = %5) -> (vector<8x16xf32>, !xegpu.tensor_desc<8x16xbf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>, !xegpu.tensor_desc<16x16xbf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>>) {
+      %8 = xegpu.load_nd %arg5  {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : !xegpu.tensor_desc<8x16xbf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<8x16xbf16>
+      %9 = xegpu.load_nd %arg6  {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} : !xegpu.tensor_desc<16x16xbf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>> -> vector<16x16xbf16>
+      %10 = xegpu.update_nd_offset %arg5, [%c0, %c16] : !xegpu.tensor_desc<8x16xbf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+      %11 = xegpu.update_nd_offset %arg6, [%c16, %c0] : !xegpu.tensor_desc<16x16xbf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>>
+      %12 = xegpu.dpas %8, %9, %arg4 {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, layout_operand_1 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>, layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<8x16xbf16>, vector<16x16xbf16>, vector<8x16xf32> -> vector<8x16xf32>
+      scf.yield {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} %12, %10, %11 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xbf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>, !xegpu.tensor_desc<16x16xbf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>>
+    } {layout_operand_3 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+    xegpu.store_nd %6#0, %2  {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    gpu.return
   }
-  %12 = xegpu.create_nd_tdesc %arg2[%2, %3] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32>
-  xegpu.store_nd %6#0, %12 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
-  gpu.return
-}
 }
 
 // -----
@@ -226,15 +224,15 @@ gpu.func @gemm_loop(%arg0: memref<1024x1024xbf16>, %arg1: memref<1024x1024xbf16>
 // CHECK: %[[T1:.*]] = xegpu.update_nd_offset %[[T0]], [%c32] : !xegpu.tensor_desc<16xf32>
 // CHECK: xegpu.store_nd %[[CST]], %[[T1]]  : vector<1xf32>, !xegpu.tensor_desc<16xf32>
 gpu.module @test {
-gpu.func @update_nd_offset_1d(%arg0: memref<256xf32>){
-  %c0 = arith.constant 0 : index
-  %c32 = arith.constant 32 : index
-  %1 = arith.constant dense<1.000000e+00> : vector<16xf32>
-  %0 = xegpu.create_nd_tdesc %arg0[%c0] : memref<256xf32> -> !xegpu.tensor_desc<16xf32>
-  %2 = xegpu.update_nd_offset %0, [%c32] : !xegpu.tensor_desc<16xf32>
-  xegpu.store_nd %1, %2 : vector<16xf32>, !xegpu.tensor_desc<16xf32>
-  gpu.return
-}
+  gpu.func @update_nd_offset_1d(%arg0: memref<256xf32>) {
+    %c0 = arith.constant 0 : index
+    %c32 = arith.constant 32 : index
+    %cst = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<1.000000e+00> : vector<16xf32>
+    %0 = xegpu.create_nd_tdesc %arg0[%c0] : memref<256xf32> -> !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
+    %1 = xegpu.update_nd_offset %0, [%c32] : !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
+    xegpu.store_nd %cst, %1  {layout_operand_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
+    gpu.return
+  }
 }
 
 // -----
@@ -245,15 +243,15 @@ gpu.func @update_nd_offset_1d(%arg0: memref<256xf32>){
 // CHECK: %[[T1:.*]] = xegpu.update_nd_offset %[[T0]], [%c32, %c32] : !xegpu.tensor_desc<16x16xf32>
 // CHECK: xegpu.store_nd %[[CST]], %[[T1]]  : vector<16xf32>, !xegpu.tensor_desc<16x16xf32>
 gpu.module @test {
-gpu.func @update_nd_offset_2d(%arg0: memref<256x256xf32>){
-  %c0 = arith.constant 0 : index
-  %c32 = arith.constant 32 : index
-  %1 = arith.constant dense<1.000000e+00> : vector<16x16xf32>
-  %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<256x256xf32> -> !xegpu.tensor_desc<16x16xf32>
-  %2 = xegpu.update_nd_offset %0, [%c32, %c32] : !xegpu.tensor_desc<16x16xf32>
-  xegpu.store_nd %1, %2 : vector<16x16xf32>, !xegpu.tensor_desc<16x16xf32>
-  gpu.return
-}
+  gpu.func @update_nd_offset_2d(%arg0: memref<256x256xf32>) {
+    %c0 = arith.constant 0 : index
+    %c32 = arith.constant 32 : index
+    %cst = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} dense<1.000000e+00> : vector<16x16xf32>
+    %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<256x256xf32> -> !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    %1 = xegpu.update_nd_offset %0, [%c32, %c32] : !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    xegpu.store_nd %cst, %1  {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16x16xf32>, !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    gpu.return
+  }
 }
 
 // -----
@@ -262,12 +260,12 @@ gpu.func @update_nd_offset_2d(%arg0: memref<256x256xf32>){
 // CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%{{.*}}] : memref<256x256xf16> -> !xegpu.tensor_desc<16x16xf16>
 // CHECK: xegpu.prefetch_nd %[[T0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<16x16xf16>
 gpu.module @test {
-gpu.func @prefetch_2d(%arg0: memref<256x256xf16>){
-  %c0 = arith.constant 0 : index
-  %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<256x256xf16> -> !xegpu.tensor_desc<16x16xf16>
-  xegpu.prefetch_nd %0 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>: !xegpu.tensor_desc<16x16xf16>
-  gpu.return
-}
+  gpu.func @prefetch_2d(%arg0: memref<256x256xf16>) {
+    %c0 = arith.constant 0 : index
+    %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<256x256xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    xegpu.prefetch_nd %0 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    gpu.return
+  }
 }
 
 // -----
@@ -276,10 +274,10 @@ gpu.func @prefetch_2d(%arg0: memref<256x256xf16>){
 // CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%{{.*}}] : memref<256xf16> -> !xegpu.tensor_desc<16xf16>
 // CHECK: xegpu.prefetch_nd %[[T0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<16xf16>
 gpu.module @test {
-gpu.func @prefetch_1d(%arg0: memref<256xf16>){
-  %c0 = arith.constant 0 : index
-  %0 = xegpu.create_nd_tdesc %arg0[%c0] : memref<256xf16> -> !xegpu.tensor_desc<16xf16>
-  xegpu.prefetch_nd %0 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>: !xegpu.tensor_desc<16xf16>
-  gpu.return
-}
+  gpu.func @prefetch_1d(%arg0: memref<256xf16>) {
+    %c0 = arith.constant 0 : index
+    %0 = xegpu.create_nd_tdesc %arg0[%c0] : memref<256xf16> -> !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
+    xegpu.prefetch_nd %0 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
+    gpu.return
+  }
 }

>From 7b69082fa2fd3d54ac164ebeae43ed464ab30d6a Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Thu, 5 Jun 2025 16:46:24 +0000
Subject: [PATCH 15/32] fix names

---
 .../{subgroup-map-propagation.mlir => layout-propagate.mlir}      | 0
 .../{subgroup-distribution.mlir => subgroup-distribute.mlir}      | 0
 2 files changed, 0 insertions(+), 0 deletions(-)
 rename mlir/test/Dialect/XeGPU/{subgroup-map-propagation.mlir => layout-propagate.mlir} (100%)
 rename mlir/test/Dialect/XeGPU/{subgroup-distribution.mlir => subgroup-distribute.mlir} (100%)

diff --git a/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir b/mlir/test/Dialect/XeGPU/layout-propagate.mlir
similarity index 100%
rename from mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
rename to mlir/test/Dialect/XeGPU/layout-propagate.mlir
diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribution.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
similarity index 100%
rename from mlir/test/Dialect/XeGPU/subgroup-distribution.mlir
rename to mlir/test/Dialect/XeGPU/subgroup-distribute.mlir

>From 56696165ff7886a802d1334f0826e50373d47b2b Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Thu, 5 Jun 2025 17:42:15 +0000
Subject: [PATCH 16/32] func op iface support

---
 .../XeGPU/Transforms/XeGPULayoutPropagate.cpp | 53 +++++++++++++++++--
 1 file changed, 49 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutPropagate.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutPropagate.cpp
index f308d338b511a..d876110fe2692 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutPropagate.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutPropagate.cpp
@@ -873,6 +873,46 @@ static void updateBranchOpInterface(mlir::OpBuilder &builder,
   }
 }
 
+/// Propagate layouts onto the arguments of \p funcOp. TensorDesc arguments get
+/// the layout encoded in their type (and the function type is updated to
+/// match); vector arguments get it attached to each user as an operand-layout
+/// attribute. \p builder is currently unused.
+static void updateFunctionOpInterface(mlir::OpBuilder &builder,
+                                      mlir::FunctionOpInterface funcOp,
+                                      GetLayoutCallbackFnTy getLayoutOfValue) {
+  // Seed from the declared signature: declarations have no block arguments, so
+  // rebuilding the list from getArguments() would erase their inputs.
+  SmallVector<Type> newArgTypes(funcOp.getArgumentTypes());
+  for (BlockArgument arg : funcOp.getArguments()) {
+    Type argType = arg.getType();
+    if (!isa<VectorType, xegpu::TensorDescType>(argType))
+      continue;
+    xegpu::LayoutAttr layout = getLayoutOfValue(arg);
+    if (!layout) {
+      LLVM_DEBUG(DBGS() << "Expecting layout for function argument: " << arg
+                        << " but got none.\n");
+      continue;
+    }
+    if (auto tensorDescTy = dyn_cast<xegpu::TensorDescType>(argType)) {
+      auto newTdescTy = xegpu::TensorDescType::get(
+          tensorDescTy.getContext(), tensorDescTy.getShape(),
+          tensorDescTy.getElementType(), tensorDescTy.getEncoding(), layout);
+      arg.setType(newTdescTy);
+      newArgTypes[arg.getArgNumber()] = newTdescTy;
+      continue;
+    }
+    // For vector arguments, attach the layout to every use of the argument.
+    for (OpOperand &use : arg.getUses())
+      use.getOwner()->setAttr(
+          operandLayoutNamePrefix + std::to_string(use.getOperandNumber()),
+          layout);
+  }
+  // Update the function type with the new argument types.
+  // NOTE: We assume that function results are not expected to have layouts.
+  funcOp.setType(FunctionType::get(funcOp.getContext(), newArgTypes,
+                                   funcOp.getResultTypes()));
+}
+
 namespace {
 
 struct XeGPULayoutPropagatePass final
@@ -903,15 +943,20 @@ void XeGPULayoutPropagatePass::runOnOperation() {
   Operation *op = getOperation();
   op->walk([&](mlir::Block *block) {
     for (mlir::Operation &op : llvm::reverse(block->getOperations())) {
-      if (auto terminator =
+      if (auto branchTermOp =
               mlir::dyn_cast<mlir::RegionBranchTerminatorOpInterface>(op)) {
-        updateBranchTerminatorOpInterface(builder, terminator,
+        updateBranchTerminatorOpInterface(builder, branchTermOp,
                                           getXeGPULayoutForValue);
         continue;
       }
 
-      if (auto iface = mlir::dyn_cast<mlir::RegionBranchOpInterface>(op)) {
-        updateBranchOpInterface(builder, iface, getXeGPULayoutForValue);
+      if (auto regionBrOp = mlir::dyn_cast<mlir::RegionBranchOpInterface>(op)) {
+        updateBranchOpInterface(builder, regionBrOp, getXeGPULayoutForValue);
+        continue;
+      }
+
+      if (auto funcOp = mlir::dyn_cast<mlir::FunctionOpInterface>(op)) {
+        updateFunctionOpInterface(builder, funcOp, getXeGPULayoutForValue);
         continue;
       }
       updateOp(builder, &op, getXeGPULayoutForValue);

>From 71902aa6c8eb28ee13c7b802951ae5a5c1195ef7 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Thu, 5 Jun 2025 20:00:53 +0000
Subject: [PATCH 17/32] fix test

---
 mlir/test/Dialect/XeGPU/layout-propagate.mlir | 511 +++++-------------
 1 file changed, 134 insertions(+), 377 deletions(-)

diff --git a/mlir/test/Dialect/XeGPU/layout-propagate.mlir b/mlir/test/Dialect/XeGPU/layout-propagate.mlir
index c7c82fc8dbb3c..f698b997e8cb7 100644
--- a/mlir/test/Dialect/XeGPU/layout-propagate.mlir
+++ b/mlir/test/Dialect/XeGPU/layout-propagate.mlir
@@ -1,29 +1,16 @@
-// RUN: mlir-opt -xegpu-subgroup-distribute='print-analysis-only=true' -split-input-file %s | FileCheck %s
+// RUN: mlir-opt -xegpu-layout-propagate -split-input-file %s | FileCheck %s
 
-// CHECK: function: test_dpas_f16:
-// CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf16>' at index: 0
-// CHECK-NEXT: layout  : Not assigned.
-// CHECK-NEXT: argument: <block argument> of type 'memref<16x16xf16>' at index: 1
-// CHECK-NEXT: layout  : Not assigned.
-// CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf32>' at index: 2
-// CHECK-NEXT: layout  : Not assigned.
-// CHECK-NEXT: op    : %{{.*}} = arith.constant 0 : index
-// CHECK-NEXT: layout for result #0: Not assigned.
-// CHECK-NEXT: op    : %{{.*}} = arith.constant dense<0.000000e+00> : vector<8x16xf32>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: op    : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: op    : %[[T1:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [2, 1]
-// CHECK-NEXT: op    : %[[T2:.*]] = xegpu.load_nd %[[T0]]  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: op    : %[[T3:.*]] = xegpu.load_nd %[[T1]]  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [2, 1]
-// CHECK-NEXT: op    : %[[T4:.*]] = xegpu.dpas %[[T2]], %[[T3]], %{{.*}} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: op    : %{{.*}} = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
-func.func @test_dpas_f16(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<8x16xf32>) {
+// CHECK-LABEL: func.func @dpas_f16(
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<8x16xf16>, %[[ARG1:[0-9a-zA-Z]+]]: memref<16x16xf16>, %[[ARG2:[0-9a-zA-Z]+]]: memref<8x16xf32>) {
+// CHECK: %[[CST:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} dense<0.000000e+00> : vector<8x16xf32>
+// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][{{.*}}] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+// CHECK: %[[T1:.*]] = xegpu.create_nd_tdesc %[[ARG1]][{{.*}}] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>>
+// CHECK: %[[T2:.*]] = xegpu.load_nd %[[T0]]  {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : !xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<8x16xf16>
+// CHECK: %[[T3:.*]] = xegpu.load_nd %[[T1]]  {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>> -> vector<16x16xf16>
+// CHECK: %[[T4:.*]] = xegpu.dpas %[[T2]], %[[T3]], %[[CST]] {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, layout_operand_1 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>, layout_operand_2 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+// CHECK: %[[T5:.*]] = xegpu.create_nd_tdesc %[[ARG2]][{{.*}}] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+// CHECK: xegpu.store_nd %[[T4]], %[[T5]]  {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+func.func @dpas_f16(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<8x16xf32>) {
   %c0 = arith.constant 0 : index
   %cst = arith.constant dense<0.000000e+00> : vector<8x16xf32>
   %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
@@ -36,22 +23,11 @@ func.func @test_dpas_f16(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg
   return
 }
 
-
 // -----
-// CHECK: function: test_dpas_i8:
-// CHECK-NEXT: argument: <block argument> of type 'vector<8x32xi8>' at index: 0
-// CHECK-NEXT: layout  : lane_layout: [1, 16], lane_data: [1, 2]
-// CHECK-NEXT: argument: <block argument> of type 'vector<32x16xi8>' at index: 1
-// CHECK-NEXT: layout  : lane_layout: [1, 16], lane_data: [4, 1]
-// CHECK-NEXT: argument: <block argument> of type 'memref<8x16xi32>' at index: 2
-// CHECK-NEXT: layout  : Not assigned.
-// CHECK-NEXT: op    : %{{.*}} = arith.constant 0 : index
-// CHECK-NEXT: layout for result #0: Not assigned.
-// CHECK-NEXT: op    : %[[T0:.*]] = xegpu.dpas %{{.*}} : vector<8x32xi8>, vector<32x16xi8> -> vector<8x16xi32>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: op    : %[[T1:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xi32> -> !xegpu.tensor_desc<8x16xi32>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
-func.func @test_dpas_i8(%arg0: vector<8x32xi8>, %arg1: vector<32x16xi8>, %arg2: memref<8x16xi32>) {
+// CHECK-LABEL: func.func @dpas_i8(
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: vector<8x32xi8>, %[[ARG1:[0-9a-zA-Z]+]]: vector<32x16xi8>, %[[ARG2:[0-9a-zA-Z]+]]: memref<8x16xi32>) {
+// CHECK: %[[T0:.*]] = xegpu.dpas %[[ARG0]], %[[ARG1]] {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>, layout_operand_1 = #xegpu.layout<lane_layout = [1, 16], lane_data = [4, 1]>, layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<8x32xi8>, vector<32x16xi8> -> vector<8x16xi32>
+func.func @dpas_i8(%arg0: vector<8x32xi8>, %arg1: vector<32x16xi8>, %arg2: memref<8x16xi32>) {
   %c0 = arith.constant 0 : index
   %0 = xegpu.dpas %arg0, %arg1 : vector<8x32xi8>, vector<32x16xi8> -> vector<8x16xi32>
   %1 = xegpu.create_nd_tdesc %arg2[%c0, %c0] : memref<8x16xi32> -> !xegpu.tensor_desc<8x16xi32>
@@ -60,30 +36,10 @@ func.func @test_dpas_i8(%arg0: vector<8x32xi8>, %arg1: vector<32x16xi8>, %arg2:
 }
 
 // -----
-// CHECK: function: test_load_with_transpose_effect:
-// CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf16>' at index: 0
-// CHECK-NEXT: layout  : Not assigned.
-// CHECK-NEXT: argument: <block argument> of type 'memref<16x16xf16>' at index: 1
-// CHECK-NEXT: layout  : Not assigned.
-// CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf32>' at index: 2
-// CHECK-NEXT: layout  : Not assigned.
-// CHECK-NEXT: op    : %{{.*}} = arith.constant 0 : index
-// CHECK-NEXT: layout for result #0: Not assigned.
-// CHECK-NEXT: op    : %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<8x16xf32>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: op    : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: op    : %[[T1:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
-// CHECK-NEXT: layout for result #0: lane_layout: [16, 1], lane_data: [1, 2]
-// CHECK-NEXT: op    : %[[T2:.*]] = xegpu.load_nd %[[T0]]  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: op    : %[[T3:.*]] = xegpu.load_nd %[[T1]] <{transpose = array<i64: 1, 0>}> : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [2, 1]
-// CHECK-NEXT: op    : %[[T4:.*]] = xegpu.dpas %[[T2]], %[[T3]], %[[CST]] : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: op    : %[[T5:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
-func.func @test_load_with_transpose_effect(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<8x16xf32>) {
+// CHECK-LABEL: func.func @load_with_transpose_effect(
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<8x16xf16>, %[[ARG1:[0-9a-zA-Z]+]]: memref<16x16xf16>, %[[ARG2:[0-9a-zA-Z]+]]: memref<8x16xf32>) {
+// CHECK: %{{.*}} = xegpu.load_nd %{{.*}} <{transpose = array<i64: 1, 0>}> {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>> -> vector<16x16xf16>
+func.func @load_with_transpose_effect(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<8x16xf32>) {
   %c0 = arith.constant 0 : index
   %cst = arith.constant dense<0.000000e+00> : vector<8x16xf32>
   %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
@@ -97,32 +53,10 @@ func.func @test_load_with_transpose_effect(%arg0: memref<8x16xf16>, %arg1: memre
 }
 
 // -----
-// CHECK: function: test_vector_transpose:
-// CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf16>' at index: 0
-// CHECK-NEXT: layout  : Not assigned.
-// CHECK-NEXT: argument: <block argument> of type 'memref<16x16xf16>' at index: 1
-// CHECK-NEXT: layout  : Not assigned.
-// CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf32>' at index: 2
-// CHECK-NEXT: layout  : Not assigned.
-// CHECK-NEXT: op    : %{{.*}} = arith.constant 0 : index
-// CHECK-NEXT: layout for result #0: Not assigned.
-// CHECK-NEXT: op    : %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<8x16xf32>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: op    : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: op    : %[[T1:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
-// CHECK-NEXT: layout for result #0: lane_layout: [16, 1], lane_data: [1, 2]
-// CHECK-NEXT: op    : %[[T2:.*]] = xegpu.load_nd %[[T0]]  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: op    : %[[T3:.*]] = xegpu.load_nd %[[T1]]  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
-// CHECK-NEXT: layout for result #0: lane_layout: [16, 1], lane_data: [1, 2]
-// CHECK-NEXT: op    : %[[T4:.*]] = vector.transpose %[[T3]], [1, 0] : vector<16x16xf16> to vector<16x16xf16>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [2, 1]
-// CHECK-NEXT: op    : %[[T5:.*]] = xegpu.dpas %[[T2]], %[[T4]], %[[CST]] : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: op    : %[[T6:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
-func.func @test_vector_transpose(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<8x16xf32>) {
+// CHECK-LABEL: func.func @vector_transpose(
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<8x16xf16>, %[[ARG1:[0-9a-zA-Z]+]]: memref<16x16xf16>, %[[ARG2:[0-9a-zA-Z]+]]: memref<8x16xf32>) {
+// CHECK: %{{.*}} = vector.transpose %{{.*}}, [1, 0] {layout_operand_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>, layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} : vector<16x16xf16> to vector<16x16xf16>
+func.func @vector_transpose(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<8x16xf32>) {
   %c0 = arith.constant 0 : index
   %cst = arith.constant dense<0.000000e+00> : vector<8x16xf32>
   %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
@@ -137,22 +71,11 @@ func.func @test_vector_transpose(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf1
 }
 
 // -----
-// CHECK: function: test_extf_truncf:
-// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<8x16xf16>' at index: 0
-// CHECK-NEXT: layout  : lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<16x16xf16>' at index: 1
-// CHECK-NEXT: layout  : lane_layout: [1, 16], lane_data: [2, 1]
-// CHECK-NEXT: op    : %[[T0:.*]] = xegpu.load_nd %{{.*}}  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: op    : %[[T1:.*]] = xegpu.load_nd %{{.*}}  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [2, 1]
-// CHECK-NEXT: op    : %[[T2:.*]] = arith.extf %[[T1]] : vector<16x16xf16> to vector<16x16xf32>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [2, 1]
-// CHECK-NEXT: op    : %[[T3:.*]] = arith.truncf %[[T2]] : vector<16x16xf32> to vector<16x16xf16>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [2, 1]
-// CHECK-NEXT: op    : %[[T4:.*]] = xegpu.dpas %[[T0]], %[[T3]] : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
-// CHECK-NEXT: layout for result #0: Not assigned.
-func.func @test_extf_truncf(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>) -> vector<8x16xf32> {
+// CHECK-LABEL: func.func @extf_truncf(
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>, %[[ARG1:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>>) -> vector<8x16xf32> {
+// CHECK: %[[T2:.*]] = arith.extf %{{.*}} {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>, layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} : vector<16x16xf16> to vector<16x16xf32>
+// CHECK-NEXT: %{{.*}} = arith.truncf %[[T2]] {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>, layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} : vector<16x16xf32> to vector<16x16xf16>
+func.func @extf_truncf(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>) -> vector<8x16xf32> {
   %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
   %1 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
   %2 = arith.extf %1 : vector<16x16xf16> to vector<16x16xf32>
@@ -162,32 +85,13 @@ func.func @test_extf_truncf(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.t
 }
 
 // -----
-// CHECK: function: test_load_gather_with_transpose_effect:
-// CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf16>' at index: 0
-// CHECK-NEXT: layout  : Not assigned.
-// CHECK-NEXT: argument: <block argument> of type 'memref<256xf16>' at index: 1
-// CHECK-NEXT: layout  : Not assigned.
-// CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf32>' at index: 2
-// CHECK-NEXT: layout  : Not assigned.
-// CHECK-NEXT: op    : %{{.*}} = arith.constant 0 : index
-// CHECK-NEXT: layout for result #0: Not assigned.
-// CHECK-NEXT: op    : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: op    : %[[T1:.*]] = xegpu.load_nd %[[T0]]  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: op    : %[[CST:.*]] = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
-// CHECK-NEXT: layout for result #0: lane_layout: [16], lane_data: [1]
-// CHECK-NEXT: op    : %[[CST0:.*]] = arith.constant dense<true> : vector<16xi1>
-// CHECK-NEXT: layout for result #0: lane_layout: [16], lane_data: [1]
-// CHECK-NEXT: op    : %[[T2:.*]] = xegpu.create_tdesc %{{.*}}, %[[CST]] : memref<256xf16>, vector<16xindex> -> !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>>
-// CHECK-NEXT: layout for result #0: lane_layout: [16, 1], lane_data: [1, 2]
-// CHECK-NEXT: op    : %[[T3:.*]] = xegpu.load %[[T2]], %[[CST0]] <{transpose}> : !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>>, vector<16xi1> -> vector<16x16xf16>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [2, 1]
-// CHECK-NEXT: op    : %[[T4:.*]] = xegpu.dpas %[[T1]], %[[T3]] : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: op    : %[[T5:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
-func.func @test_load_gather_with_transpose_effect(%arg0: memref<8x16xf16>, %arg1: memref<256xf16>, %arg2: memref<8x16xf32>) {
+// CHECK-LABEL: func.func @load_gather_with_transpose_effect(
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<8x16xf16>, %[[ARG1:[0-9a-zA-Z]+]]: memref<256xf16>, %[[ARG2:[0-9a-zA-Z]+]]: memref<8x16xf32>) {
+// CHECK: %[[CST:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
+// CHECK-NEXT: %[[CST0:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1>
+// CHECK-NEXT: %[[T2:.*]] = xegpu.create_tdesc %[[ARG1]], %[[CST]] {layout_operand_1 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} : memref<256xf16>, vector<16xindex> -> !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>>
+// CHECK-NEXT: %{{.*}} = xegpu.load %[[T2]], %[[CST0]] <{transpose}> {layout_operand_1 = #xegpu.layout<lane_layout = [16], lane_data = [1]>, layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} : !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>>, vector<16xi1> -> vector<16x16xf16>
+func.func @load_gather_with_transpose_effect(%arg0: memref<8x16xf16>, %arg1: memref<256xf16>, %arg2: memref<8x16xf32>) {
   %c0 = arith.constant 0 : index
   %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
   %1 = xegpu.load_nd %0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
@@ -202,20 +106,13 @@ func.func @test_load_gather_with_transpose_effect(%arg0: memref<8x16xf16>, %arg1
 }
 
 // -----
-// CHECK: function: test_load_gather_1d:
-// CHECK: argument: <block argument> of type 'memref<256xf32>' at index: 0
-// CHECK-NEXT: layout  : Not assigned.
-// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<16xf32>' at index: 1
-// CHECK-NEXT: layout  : lane_layout: [16], lane_data: [1]
-// CHECK-NEXT: op    : %[[CST:.*]] = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
-// CHECK-NEXT: layout for result #0: lane_layout: [16], lane_data: [1]
-// CHECK-NEXT: op    : %[[CST0:.*]] = arith.constant dense<true> : vector<16xi1>
-// CHECK-NEXT: layout for result #0: lane_layout: [16], lane_data: [1]
-// CHECK-NEXT: op    : %[[T0:.*]] = xegpu.create_tdesc %{{.*}}, %[[CST]] : memref<256xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
-// CHECK-NEXT: layout for result #0: lane_layout: [16], lane_data: [1]
-// CHECK-NEXT: op    : %[[T1]] = xegpu.load %[[T0]], %[[CST0]]  : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> -> vector<16xf32>
-// CHECK-NEXT: layout for result #0: lane_layout: [16], lane_data: [1]
-func.func @test_load_gather_1d(%arg0: memref<256xf32>, %arg1: !xegpu.tensor_desc<16xf32>) {
+// CHECK-LABEL: func.func @load_gather_1d(
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<256xf32>, %[[ARG1:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>) {
+// CHECK: %[[CST:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
+// CHECK-NEXT: %[[CST0:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1>
+// CHECK-NEXT: %[[T0:.*]] = xegpu.create_tdesc %[[ARG0]], %[[CST]] {layout_operand_1 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} : memref<256xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
+// CHECK-NEXT: %{{.*}} = xegpu.load %[[T0]], %[[CST0]]  {layout_operand_1 = #xegpu.layout<lane_layout = [16], lane_data = [1]>, layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<lane_layout = [16], lane_data = [1]>>, vector<16xi1> -> vector<16xf32>
+func.func @load_gather_1d(%arg0: memref<256xf32>, %arg1: !xegpu.tensor_desc<16xf32>) {
   %cst = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
   %cst_0 = arith.constant dense<true> : vector<16xi1>
   %0 = xegpu.create_tdesc %arg0, %cst : memref<256xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
@@ -225,18 +122,11 @@ func.func @test_load_gather_1d(%arg0: memref<256xf32>, %arg1: !xegpu.tensor_desc
 }
 
 // -----
-// CHECK: function: test_store_scatter_with_transpose_effect:
-// CHECK-NEXT: argument: <block argument> of type 'memref<128xf32>' at index: 0
-// CHECK-NEXT: layout  : Not assigned.
-// CHECK-NEXT: op    : %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<8x16xf32>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: op    : %[[CST0:.*]] = arith.constant dense<true> : vector<16xi1>
-// CHECK-NEXT: layout for result #0: lane_layout: [16], lane_data: [1]
-// CHECK-NEXT: op    : %[[CST1:.*]] = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
-// CHECK-NEXT: layout for result #0: lane_layout: [16], lane_data: [1]
-// CHECK-NEXT: op    : %[[T0:.*]] = xegpu.create_tdesc %{{.*}}, %[[CST1]] : memref<128xf32>, vector<16xindex> -> !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size = 8 : i64>>
-// CHECK-NEXT: layout for result #0: lane_layout: [16, 1], lane_data: [1, 1]
-func.func @test_store_scatter_with_transpose_effect(%arg0: memref<128xf32>) {
+// CHECK-LABEL: func.func @store_scatter_with_transpose_effect(
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<128xf32>) {
+// CHECK: %[[T0:.*]] = xegpu.create_tdesc %[[ARG0]], %{{.*}} {layout_operand_1 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} : memref<128xf32>, vector<16xindex> -> !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size = 8 : i64>, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>>
+// CHECK-NEXT: xegpu.store %{{.*}}, %[[T0]], %{{.*}} <{transpose}> {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, layout_operand_2 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} : vector<8x16xf32>, !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size = 8 : i64>, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>>, vector<16xi1>
+func.func @store_scatter_with_transpose_effect(%arg0: memref<128xf32>) {
   %cst = arith.constant dense<1.000000e+00> : vector<8x16xf32>
   %cst_0 = arith.constant dense<true> : vector<16xi1>
   %cst_1 = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
@@ -246,18 +136,10 @@ func.func @test_store_scatter_with_transpose_effect(%arg0: memref<128xf32>) {
 }
 
 // -----
-// CHECK: function: test_store_scatter_1d:
-// CHECK-NEXT: argument: <block argument> of type 'vector<16xf32>' at index: 0
-// CHECK-NEXT: layout  : lane_layout: [16], lane_data: [1]
-// CHECK-NEXT: argument: <block argument> of type 'memref<256xf32>' at index: 1
-// CHECK-NEXT: layout  : Not assigned.
-// CHECK-NEXT: op    : %[[CST:.*]] = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
-// CHECK-NEXT: layout for result #0: lane_layout: [16], lane_data: [1]
-// CHECK-NEXT: op    : %[[CST1:.*]] = arith.constant dense<true> : vector<16xi1>
-// CHECK-NEXT: layout for result #0: lane_layout: [16], lane_data: [1]
-// CHECK-NEXT: op    : %[[T0:.*]] = xegpu.create_tdesc %{{.*}}, %[[CST]] : memref<256xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
-// CHECK-NEXT: layout for result #0: lane_layout: [16], lane_data: [1]
-func.func @test_store_scatter_1d(%arg0: vector<16xf32>, %arg1: memref<256xf32>) {
+// CHECK-LABEL: func.func @store_scatter_1d(
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: vector<16xf32>, %[[ARG1:[0-9a-zA-Z]+]]: memref<256xf32>) {
+// CHECK: xegpu.store %[[ARG0]], %{{.*}}, %{{.*}}  {layout_operand_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>, layout_operand_2 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<lane_layout = [16], lane_data = [1]>>, vector<16xi1>
+func.func @store_scatter_1d(%arg0: vector<16xf32>, %arg1: memref<256xf32>) {
   %cst = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
   %cst_0 = arith.constant dense<true> : vector<16xi1>
   %0 = xegpu.create_tdesc %arg1, %cst : memref<256xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
@@ -266,30 +148,10 @@ func.func @test_store_scatter_1d(%arg0: vector<16xf32>, %arg1: memref<256xf32>)
 }
 
 // -----
-// CHECK: function: test_vector_bitcast_i16_to_i8:
-// CHECK-NEXT: argument: <block argument> of type 'memref<8x16xi16>' at index: 0
-// CHECK-NEXT: layout  : Not assigned.
-// CHECK-NEXT: argument: <block argument> of type 'memref<32x16xi8>' at index: 1
-// CHECK-NEXT: layout  : Not assigned.
-// CHECK-NEXT: argument: <block argument> of type 'memref<8x16xi32>' at index: 2
-// CHECK-NEXT: layout  : Not assigned.
-// CHECK-NEXT: op    : %{{.*}} = arith.constant 0 : index
-// CHECK-NEXT: layout for result #0: Not assigned.
-// CHECK-NEXT: op    : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xi16> -> !xegpu.tensor_desc<8x16xi16>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: op    : %[[T1:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<32x16xi8> -> !xegpu.tensor_desc<32x16xi8>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [4, 1]
-// CHECK-NEXT: op    : %[[T2:.*]] = xegpu.load_nd %[[T0]]  : !xegpu.tensor_desc<8x16xi16> -> vector<8x16xi16>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: op    : %[[T3:.*]] = xegpu.load_nd %[[T1]]  : !xegpu.tensor_desc<32x16xi8> -> vector<32x16xi8>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [4, 1]
-// CHECK-NEXT: op    : %[[T4:.*]] = vector.bitcast %[[T2]] : vector<8x16xi16> to vector<8x32xi8>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 2]
-// CHECK-NEXT: op    : %[[T5:.*]] = xegpu.dpas %[[T4]], %[[T3]] : vector<8x32xi8>, vector<32x16xi8> -> vector<8x16xi32>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: op    : %[[T6:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xi32> -> !xegpu.tensor_desc<8x16xi32>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
-func.func @test_vector_bitcast_i16_to_i8(%arg0: memref<8x16xi16>, %arg1: memref<32x16xi8>, %arg2: memref<8x16xi32>) {
+// CHECK-LABEL: func.func @vector_bitcast_i16_to_i8(
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<8x16xi16>, %[[ARG1:[0-9a-zA-Z]+]]: memref<32x16xi8>, %[[ARG2:[0-9a-zA-Z]+]]: memref<8x16xi32>) {
+// CHECK: %{{.*}} = vector.bitcast %{{.*}} {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>} : vector<8x16xi16> to vector<8x32xi8>
+func.func @vector_bitcast_i16_to_i8(%arg0: memref<8x16xi16>, %arg1: memref<32x16xi8>, %arg2: memref<8x16xi32>) {
   %c0 = arith.constant 0 : index
   %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xi16> -> !xegpu.tensor_desc<8x16xi16>
   %1 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<32x16xi8> -> !xegpu.tensor_desc<32x16xi8>
@@ -303,32 +165,11 @@ func.func @test_vector_bitcast_i16_to_i8(%arg0: memref<8x16xi16>, %arg1: memref<
 }
 
 // -----
-// CHECK: function: test_vector_bitcast_i8_to_f16:
-// CHECK-NEXT: argument: <block argument> of type 'memref<8x32xi8>' at index: 0
-// CHECK-NEXT: layout  : Not assigned.
-// CHECK-NEXT: argument: <block argument> of type 'memref<16x32xi8>' at index: 1
-// CHECK-NEXT: layout  : Not assigned.
-// CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf32>' at index: 2
-// CHECK-NEXT: layout  : Not assigned.
-// CHECK-NEXT: op    : %{{.*}} = arith.constant 0 : index
-// CHECK-NEXT: layout for result #0: Not assigned.
-// CHECK-NEXT: op    : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x32xi8> -> !xegpu.tensor_desc<8x32xi8>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 2]
-// CHECK-NEXT: op    : %[[T1:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<16x32xi8> -> !xegpu.tensor_desc<16x32xi8>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [4, 1]
-// CHECK-NEXT: op    : %[[T2:.*]] = xegpu.load_nd %[[T0]]  : !xegpu.tensor_desc<8x32xi8> -> vector<8x32xi8>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 2]
-// CHECK-NEXT: op    : %[[T3:.*]] = xegpu.load_nd %[[T1]]  : !xegpu.tensor_desc<16x32xi8> -> vector<16x32xi8>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [4, 1]
-// CHECK-NEXT: op    : %[[T4:.*]] = vector.bitcast %[[T2]] : vector<8x32xi8> to vector<8x16xf16>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: op    : %[[T5:.*]] = vector.bitcast %[[T3]] : vector<16x32xi8> to vector<16x16xf16>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [2, 1]
-// CHECK-NEXT: op    : %[[T6:.*]] = xegpu.dpas %[[T4]], %[[T5]] : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: op    : %[[T7:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
-func.func @test_vector_bitcast_i8_to_f16(%arg0: memref<8x32xi8>, %arg1: memref<16x32xi8>, %arg2: memref<8x16xf32>) {
+// CHECK-LABEL: func.func @vector_bitcast_i8_to_f16(
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<8x32xi8>, %[[ARG1:[0-9a-zA-Z]+]]: memref<16x32xi8>, %[[ARG2:[0-9a-zA-Z]+]]: memref<8x16xf32>) {
+// CHECK: %{{.*}} = vector.bitcast %{{.*}} {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>, layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<8x32xi8> to vector<8x16xf16>
+// CHECK: %{{.*}} = vector.bitcast %{{.*}} {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [4, 1]>, layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} : vector<16x32xi8> to vector<16x16xf16>
+func.func @vector_bitcast_i8_to_f16(%arg0: memref<8x32xi8>, %arg1: memref<16x32xi8>, %arg2: memref<8x16xf32>) {
   %c0 = arith.constant 0 : index
   %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x32xi8> -> !xegpu.tensor_desc<8x32xi8>
   %1 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<16x32xi8> -> !xegpu.tensor_desc<16x32xi8>
@@ -343,24 +184,12 @@ func.func @test_vector_bitcast_i8_to_f16(%arg0: memref<8x32xi8>, %arg1: memref<1
 }
 
 // -----
-// CHECK: function: test_binary_op_one_use:
-// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<8x16xf16>' at index: 0
-// CHECK-NEXT: layout  : lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<16x16xf16>' at index: 1
-// CHECK-NEXT: layout  : lane_layout: [1, 16], lane_data: [2, 1]
-// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<8x16xf32>' at index: 2
-// CHECK-NEXT: layout  : lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: op    : %[[T0:.*]] = xegpu.load_nd %{{.*}}  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: op    : %[[T1:.*]] = xegpu.load_nd %{{.*}}  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [2, 1]
-// CHECK-NEXT: op    : %[[T2:.*]] = xegpu.load_nd %{{.*}}  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [2, 1]
-// CHECK-NEXT: op    : %[[T3:.*]] = arith.addf %[[T1]], %[[T2]] : vector<16x16xf16>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [2, 1]
-// CHECK-NEXT: op    : %[[T4:.*]] = xegpu.dpas %[[T0]], %[[T3]] : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
-func.func @test_binary_op_one_use(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: !xegpu.tensor_desc<8x16xf32>) {
+// CHECK-LABEL: func.func @binary_op_one_use(
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>, %[[ARG1:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>>, %[[ARG2:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>) {
+// CHECK: %[[T1:.*]] = xegpu.load_nd %[[ARG1]]  {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>> -> vector<16x16xf16>
+// CHECK-NEXT: %[[T2:.*]] = xegpu.load_nd %[[ARG1]]  {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>> -> vector<16x16xf16>
+// CHECK-NEXT: %{{.*}} = arith.addf %[[T1]], %[[T2]] {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>, layout_operand_1 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>, layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} : vector<16x16xf16>
+func.func @binary_op_one_use(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: !xegpu.tensor_desc<8x16xf32>) {
   %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
   %1 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
   %2 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
@@ -371,26 +200,13 @@ func.func @test_binary_op_one_use(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !x
 }
 
 // -----
-// CHECK: function: test_binary_op_multiple_uses:
-// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<8x16xf16>' at index: 0
-// CHECK-NEXT: layout  : lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<16x16xf16>' at index: 1
-// CHECK-NEXT: layout  : lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<8x16xf32>' at index: 2
-// CHECK-NEXT: layout  : lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<16x16xf16>' at index: 3
-// CHECK-NEXT: layout  : lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: op    : %[[T0:.*]] = xegpu.load_nd %{{.*}}  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: op    : %[[T1:.*]] = xegpu.load_nd %{{.*}}  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: op    : %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<16x16xf16>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: op    : %[[T2:.*]] = arith.addf %[[T1]], %[[CST]] : vector<16x16xf16>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: op    : %[[T3:.*]] = xegpu.dpas %[[T0]], %[[T2]] : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
-func.func @test_binary_op_multiple_uses(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: !xegpu.tensor_desc<8x16xf32>, %arg3: !xegpu.tensor_desc<16x16xf16>) {
+// CHECK-LABEL: func.func @binary_op_multiple_uses(
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>, %[[ARG1:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>, %[[ARG2:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>, %[[ARG3:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>) {
+// CHECK: %[[T2:.*]] = arith.addf %{{.*}}, %{{.*}} {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, layout_operand_1 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16x16xf16>
+// CHECK: %[[T3:.*]] = xegpu.dpas %{{.*}}, %[[T2]] {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, layout_operand_1 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+// CHECK-NEXT: xegpu.store_nd %[[T3]], %[[ARG2]]  {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+// CHECK-NEXT: xegpu.store_nd %[[T2]], %[[ARG3]]  {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+func.func @binary_op_multiple_uses(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: !xegpu.tensor_desc<8x16xf32>, %arg3: !xegpu.tensor_desc<16x16xf16>) {
   %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
   %1 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
   %cst = arith.constant dense<1.000000e+00> : vector<16x16xf16>
@@ -402,42 +218,22 @@ func.func @test_binary_op_multiple_uses(%arg0: !xegpu.tensor_desc<8x16xf16>, %ar
 }
 
 // -----
-// CHECK: function: test_for_op:
-// CHECK-NEXT: argument: <block argument> of type 'memref<8x128xf16>' at index: 0
-// CHECK-NEXT: layout  : Not assigned.
-// CHECK-NEXT: argument: <block argument> of type 'memref<128x16xf16>' at index: 1
-// CHECK-NEXT: layout  : Not assigned.
-// CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf32>' at index: 2
-// CHECK-NEXT: layout  : Not assigned.
-// CHECK-NEXT: op    : %{{.*}} = arith.constant 0 : index
-// CHECK-NEXT: layout for result #0: Not assigned.
-// CHECK-NEXT: op    : %{{.*}} = arith.constant 128 : index
-// CHECK-NEXT: layout for result #0: Not assigned.
-// CHECK-NEXT: op    : %{{.*}} = arith.constant 16 : index
-// CHECK-NEXT: layout for result #0: Not assigned.
-// CHECK-NEXT: op    : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x128xf16> -> !xegpu.tensor_desc<8x16xf16>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: op    : %[[T1:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<128x16xf16> -> !xegpu.tensor_desc<16x16xf16>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [2, 1]
-// CHECK-NEXT: op    : %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<8x16xf32>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: op    : %[[T4:.*]] = xegpu.load_nd %{{.*}}  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: op    : %[[T5:.*]] = xegpu.load_nd %{{.*}}  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [2, 1]
-// CHECK-NEXT: op    : %[[T6:.*]] = xegpu.dpas %[[T4]], %[[T5]], %{{.*}} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: op    : %[[T7:.*]] = xegpu.update_nd_offset %{{.*}} : !xegpu.tensor_desc<8x16xf16>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: op    : %[[T8:.*]] = xegpu.update_nd_offset %{{.*}} : !xegpu.tensor_desc<16x16xf16>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [2, 1]
-// CHECK-NEXT: op    : scf.for
-// CHECK-NEXT: layout for result #0: Not assigned.
-// CHECK-NEXT: layout for result #1: Not assigned.
-// CHECK-NEXT: layout for result #2: lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: op    : %[[T3:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
-func.func @test_for_op(%arg0: memref<8x128xf16>, %arg1: memref<128x16xf16>, %arg2: memref<8x16xf32>) {
+// CHECK-LABEL: func.func @for_op(
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<8x128xf16>, %[[ARG1:[0-9a-zA-Z]+]]: memref<128x16xf16>, %[[ARG2:[0-9a-zA-Z]+]]: memref<8x16xf32>) {
+// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%{{.*}}] : memref<8x128xf16> -> !xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+// CHECK-NEXT: %[[T1:.*]] = xegpu.create_nd_tdesc %[[ARG1]][%{{.*}}] : memref<128x16xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>>
+// CHECK-NEXT: %[[CST:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} dense<0.000000e+00> : vector<8x16xf32>
+// CHECK-NEXT: %[[T2:.*]]:3 = scf.for %{{.*}} iter_args(%[[ARG4:.*]] = %[[T0]], %[[ARG5:.*]] = %[[T1]], %[[ARG6:.*]] = %[[CST]]) -> (!xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>, !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>>, vector<8x16xf32>) {
+// CHECK-NEXT:   %[[T4:.*]] = xegpu.load_nd %[[ARG4]]  {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : !xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<8x16xf16>
+// CHECK-NEXT:   %[[T5:.*]] = xegpu.load_nd %[[ARG5]]  {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>> -> vector<16x16xf16>
+// CHECK-NEXT:   %[[T6:.*]] = xegpu.dpas %[[T4]], %[[T5]], %[[ARG6]] {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, layout_operand_1 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>, layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+// CHECK-NEXT:   %[[T7:.*]] = xegpu.update_nd_offset %[[ARG4]], [{{.*}}] : !xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+// CHECK-NEXT:   %[[T8:.*]] = xegpu.update_nd_offset %[[ARG5]], [{{.*}}] : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>>
+// CHECK-NEXT:   scf.yield {layout_operand_2 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} %[[T7]], %[[T8]], %[[T6]] : !xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>, !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>>, vector<8x16xf32>
+// CHECK-NEXT: } {layout_operand_5 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, layout_result_2 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+// CHECK-NEXT: %[[T3:.*]] = xegpu.create_nd_tdesc %[[ARG2]][{{.*}}] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+// CHECK-NEXT: xegpu.store_nd %[[T2]]#2, %[[T3]]  {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+func.func @for_op(%arg0: memref<8x128xf16>, %arg1: memref<128x16xf16>, %arg2: memref<8x16xf32>) {
   %c0 = arith.constant 0 : index
   %c128 = arith.constant 128 : index
   %c16 = arith.constant 16 : index
@@ -458,26 +254,16 @@ func.func @test_for_op(%arg0: memref<8x128xf16>, %arg1: memref<128x16xf16>, %arg
 }
 
 // -----
-// CHECK: function: test_if_single_use:
-// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<8x16xf16>' at index: 0
-// CHECK-NEXT: layout  : lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<16x16xf16>' at index: 1
-// CHECK-NEXT: layout  : lane_layout: [1, 16], lane_data: [2, 1]
-// CHECK-NEXT: argument: <block argument> of type 'i1' at index: 2
-// CHECK-NEXT: layout  : Not assigned.
-// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<8x16xf32>' at index: 3
-// CHECK-NEXT: layout  : lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: op    : %[[T0:.*]] = xegpu.load_nd %{{.*}}  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: op    : %[[T3:.*]] = xegpu.load_nd %{{.*}}  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [2, 1]
-// CHECK-NEXT: op    : %[[T4:.*]] = xegpu.load_nd %{{.*}}  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [2, 1]
-// CHECK-NEXT: op    : scf.if
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [2, 1]
-// CHECK-NEXT: op    : %[[T2:.*]] = xegpu.dpas %[[T0]], %{{.*}} : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
-func.func @test_if_single_use(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: i1, %arg3: !xegpu.tensor_desc<8x16xf32>) {
+// CHECK-LABEL: func.func @if_single_use(
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>, %[[ARG1:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>>, %[[ARG2:[0-9a-zA-Z]+]]: i1, %[[ARG3:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>) {
+// CHECK:  %{{.*}} = scf.if %[[ARG2]] -> (vector<16x16xf16>) {
+// CHECK-NEXT:    %[[T3:.*]] = xegpu.load_nd %[[ARG1]]  {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>> -> vector<16x16xf16>
+// CHECK-NEXT:    scf.yield {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} %[[T3]] : vector<16x16xf16>
+// CHECK-NEXT:  } else {
+// CHECK-NEXT:    %[[T4:.*]] = xegpu.load_nd %[[ARG1]]  {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>> -> vector<16x16xf16>
+// CHECK-NEXT:    scf.yield {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} %[[T4]] : vector<16x16xf16>
+// CHECK-NEXT:  } {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>}
+func.func @if_single_use(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: i1, %arg3: !xegpu.tensor_desc<8x16xf32>) {
   %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
   %1 = scf.if %arg2 -> (vector<16x16xf16>) {
     %3 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
@@ -492,28 +278,16 @@ func.func @test_if_single_use(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu
 }
 
 // -----
-// CHECK: function: test_if_multiple_uses:
-// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<8x16xf16>' at index: 0
-// CHECK-NEXT: layout  : lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<16x16xf16>' at index: 1
-// CHECK-NEXT: layout  : lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: argument: <block argument> of type 'i1' at index: 2
-// CHECK-NEXT: layout  : Not assigned.
-// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<8x16xf32>' at index: 3
-// CHECK-NEXT: layout  : lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<16x16xf16>' at index: 4
-// CHECK-NEXT: layout  : lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: op    : %[[T0:.*]] = xegpu.load_nd %{{.*}}  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: op    : %[[T3:.*]] = xegpu.load_nd %{{.*}}  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: op    : %[[T4:.*]] = xegpu.load_nd %{{.*}}  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: op    : scf.if
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: op    : %[[T2:.*]] = xegpu.dpas %[[T0]], %{{.*}} : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
-func.func @test_if_multiple_uses(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: i1, %arg3: !xegpu.tensor_desc<8x16xf32>, %arg4: !xegpu.tensor_desc<16x16xf16>) {
+// CHECK-LABEL: func.func @if_multiple_uses(
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>, %[[ARG1:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>, %[[ARG2:[0-9a-zA-Z]+]]: i1, %[[ARG3:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>, %[[ARG4:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>) {
+// CHECK: %[[T1:.*]] = scf.if %[[ARG2]] -> (vector<16x16xf16>) {
+// CHECK-NEXT:       %[[T3:.*]] = xegpu.load_nd %[[ARG1]]  {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<16x16xf16>
+// CHECK-NEXT:       scf.yield {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} %[[T3]] : vector<16x16xf16>
+// CHECK-NEXT:     } else {
+// CHECK-NEXT:       %[[T4:.*]] = xegpu.load_nd %[[ARG1]]  {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<16x16xf16>
+// CHECK-NEXT:       scf.yield {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} %[[T4]] : vector<16x16xf16>
+// CHECK-NEXT:     } {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+func.func @if_multiple_uses(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: i1, %arg3: !xegpu.tensor_desc<8x16xf32>, %arg4: !xegpu.tensor_desc<16x16xf16>) {
   %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
   %1 = scf.if %arg2 -> (vector<16x16xf16>) {
     %3 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
@@ -529,16 +303,10 @@ func.func @test_if_multiple_uses(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xe
 }
 
 // -----
-// CHECK: function: test_vector_outer_reduction:
-// CHECK-NEXT: argument: <block argument> of type 'vector<16x16xf32>' at index: 0
-// CHECK-NEXT: layout  : lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<16xf32>' at index: 1
-// CHECK-NEXT: layout  : lane_layout: [16], lane_data: [1]
-// CHECK-NEXT: op    : %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<16xf32>
-// CHECK-NEXT: layout for result #0: lane_layout: [16], lane_data: [1]
-// CHECK-NEXT: op    : %[[T0:.*]] = vector.multi_reduction <add>, %{{.*}}, %[[CST]] [0] : vector<16x16xf32> to vector<16xf32>
-// CHECK-NEXT: layout for result #0: lane_layout: [16], lane_data: [1]
-func.func @test_vector_outer_reduction(%arg0: vector<16x16xf32>, %arg1: !xegpu.tensor_desc<16xf32>) {
+// CHECK-LABEL: func.func @vector_outer_reduction(
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: vector<16x16xf32>, %[[ARG1:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>) {
+// CHECK: %{{.*}} = vector.multi_reduction <add>, %[[ARG0]], %{{.*}} {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, layout_operand_1 = #xegpu.layout<lane_layout = [16], lane_data = [1]>, layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} [0] : vector<16x16xf32> to vector<16xf32>
+func.func @vector_outer_reduction(%arg0: vector<16x16xf32>, %arg1: !xegpu.tensor_desc<16xf32>) {
   %cst = arith.constant dense<0.000000e+00> : vector<16xf32>
   %0 = vector.multi_reduction <add>, %arg0, %cst [0] : vector<16x16xf32> to vector<16xf32>
   xegpu.store_nd %0, %arg1  : vector<16xf32>, !xegpu.tensor_desc<16xf32>
@@ -546,16 +314,10 @@ func.func @test_vector_outer_reduction(%arg0: vector<16x16xf32>, %arg1: !xegpu.t
 }
 
 // -----
-// CHECK: function: test_vector_inner_reduction:
-// CHECK-NEXT: argument: <block argument> of type 'vector<16x16xf32>' at index: 0
-// CHECK-NEXT: layout  : lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<16xf32>' at index: 1
-// CHECK-NEXT: layout  : lane_layout: [16], lane_data: [1]
-// CHECK-NEXT: op    : %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<16xf32>
-// CHECK-NEXT: layout for result #0: lane_layout: [16], lane_data: [1]
-// CHECK-NEXT: op    : %[[T0:.*]] = vector.multi_reduction <add>, %{{.*}}, %[[CST]] [1] : vector<16x16xf32> to vector<16xf32>
-// CHECK-NEXT: layout for result #0: lane_layout: [16], lane_data: [1]
-func.func @test_vector_inner_reduction(%arg0: vector<16x16xf32>, %arg1: !xegpu.tensor_desc<16xf32>) {
+// CHECK-LABEL: func.func @vector_inner_reduction(
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: vector<16x16xf32>, %[[ARG1:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>) {
+// CHECK: %{{.*}} = vector.multi_reduction <add>, %[[ARG0]], %{{.*}} {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, layout_operand_1 = #xegpu.layout<lane_layout = [16], lane_data = [1]>, layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} [1] : vector<16x16xf32> to vector<16xf32>
+func.func @vector_inner_reduction(%arg0: vector<16x16xf32>, %arg1: !xegpu.tensor_desc<16xf32>) {
   %cst = arith.constant dense<0.000000e+00> : vector<16xf32>
   %0 = vector.multi_reduction <add>, %arg0, %cst [1] : vector<16x16xf32> to vector<16xf32>
   xegpu.store_nd %0, %arg1  : vector<16xf32>, !xegpu.tensor_desc<16xf32>
@@ -563,13 +325,10 @@ func.func @test_vector_inner_reduction(%arg0: vector<16x16xf32>, %arg1: !xegpu.t
 }
 
 // -----
-// CHECK: function: update_nd_offset_1d:
-// CHECK: op    : %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<16xf32>
-// CHECK-NEXT: layout for result #0: lane_layout: [16], lane_data: [1]
-// CHECK-NEXT: op    : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}}[%{{.*}}] : memref<256xf32> -> !xegpu.tensor_desc<16xf32>
-// CHECK-NEXT: layout for result #0: lane_layout: [16], lane_data: [1]
-// CHECK-NEXT: op    : %[[T1:.*]] = xegpu.update_nd_offset %[[T0]], [%{{.*}}] : !xegpu.tensor_desc<16xf32>
-// CHECK-NEXT: layout for result #0: lane_layout: [16], lane_data: [1]
+// CHECK-LABEL: func.func @update_nd_offset_1d(
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<256xf32>) {
+// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%{{.*}}] : memref<256xf32> -> !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
+// CHECK-NEXT: %[[T1:.*]] = xegpu.update_nd_offset %[[T0]], [%{{.*}}] : !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
 func.func @update_nd_offset_1d(%arg0: memref<256xf32>){
   %c0 = arith.constant 0 : index
   %c32 = arith.constant 32 : index
@@ -581,13 +340,10 @@ func.func @update_nd_offset_1d(%arg0: memref<256xf32>){
 }
 
 // -----
-// CHECK: function: update_nd_offset_2d:
-// CHECK: op    : %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<16x16xf32>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: op    : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}}[%{{.*}}] : memref<256x256xf32> -> !xegpu.tensor_desc<16x16xf32>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: op    : %[[T1:.*]] = xegpu.update_nd_offset %[[T0]], [%{{.*}}] : !xegpu.tensor_desc<16x16xf32>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
+// CHECK-LABEL: func.func @update_nd_offset_2d(
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<256x256xf32>) {
+// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%{{.*}}, %{{.*}}] : memref<256x256xf32> -> !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+// CHECK-NEXT: %[[T1:.*]] = xegpu.update_nd_offset %[[T0]], [%{{.*}}, %{{.*}}] : !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
 func.func @update_nd_offset_2d(%arg0: memref<256x256xf32>){
   %c0 = arith.constant 0 : index
   %c32 = arith.constant 32 : index
@@ -599,10 +355,10 @@ func.func @update_nd_offset_2d(%arg0: memref<256x256xf32>){
 }
 
 // -----
-// CHECK: function: prefetch_2d:
-// CHECK: layout for result #0: Not assigned.
-// CHECK-NEXT: op    : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}}[%{{.*}}] : memref<256x256xf16> -> !xegpu.tensor_desc<16x16xf16>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
+// CHECK-LABEL: func.func @prefetch_2d(
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<256x256xf16>) {
+// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%{{.*}}, %{{.*}}] : memref<256x256xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+// CHECK-NEXT: xegpu.prefetch_nd %[[T0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
 func.func @prefetch_2d(%arg0: memref<256x256xf16>){
   %c0 = arith.constant 0 : index
   %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<256x256xf16> -> !xegpu.tensor_desc<16x16xf16>
@@ -611,9 +367,10 @@ func.func @prefetch_2d(%arg0: memref<256x256xf16>){
 }
 
 // -----
-// CHECK: function: prefetch_1d:
-// CHECK: op    : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}}[%{{.*}}] : memref<256xf16> -> !xegpu.tensor_desc<16xf16>
-// CHECK-NEXT: layout for result #0: lane_layout: [16], lane_data: [1]
+// CHECK-LABEL: func.func @prefetch_1d(
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<256xf16>) {
+// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%{{.*}}] : memref<256xf16> -> !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
+// CHECK-NEXT: xegpu.prefetch_nd %[[T0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
 func.func @prefetch_1d(%arg0: memref<256xf16>){
   %c0 = arith.constant 0 : index
   %0 = xegpu.create_nd_tdesc %arg0[%c0] : memref<256xf16> -> !xegpu.tensor_desc<16xf16>

>From 341daff6dd9f95fcd6a73240f6edb108a8e50b77 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Thu, 5 Jun 2025 20:15:04 +0000
Subject: [PATCH 18/32] fix test

---
 .../Dialect/XeGPU/subgroup-distribute.mlir    | 84 +++++++++----------
 1 file changed, 40 insertions(+), 44 deletions(-)

diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
index 0f236d4e8b9dc..3bfabac55faf3 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
@@ -168,52 +168,48 @@ gpu.module @test {
 // -----
 // CHECK-LABEL: gpu.func @gemm_loop
 // CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: memref<1024x1024xbf16>, %[[ARG1:[0-9a-zA-Z]+]]: memref<1024x1024xbf16>, %[[ARG2:[0-9a-zA-Z]+]]: memref<1024x1024xf32>) {
-// CHECK: %[[BLOCK_ID_X:.*]] = gpu.block_id x
-// CHECK: %[[BLOCK_ID_Y:.*]] = gpu.block_id y
-// CHECK: %[[Y_COORD:.*]] = arith.muli %[[BLOCK_ID_Y]], %c16 : index
-// CHECK: %[[X_COORD:.*]] = arith.muli %[[BLOCK_ID_X]], %c8 : index
+// CHECK-DAG: %[[BLOCK_ID_X:.*]] = gpu.block_id x
+// CHECK-DAG: %[[BLOCK_ID_Y:.*]] = gpu.block_id y
+// CHECK-DAG: %[[Y_COORD:.*]] = arith.muli %[[BLOCK_ID_Y]], %c16 : index
+// CHECK-DAG: %[[X_COORD:.*]] = arith.muli %[[BLOCK_ID_X]], %c8 : index
 // CHECK: %[[T2:.*]] = xegpu.create_nd_tdesc %[[ARG2]][%[[X_COORD]], %[[Y_COORD]]] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32>
-// CHECK: %[[T3:.*]] = xegpu.load_nd %[[T2]]  : !xegpu.tensor_desc<8x16xf32> -> vector<8xf32>
-// CHECK-DAG: %[[C_INIT:.*]] = vector.shape_cast %[[T3]] : vector<8xf32> to vector<8x1xf32>
-// CHECK-DAG: %[[B_TILE:.*]] = xegpu.create_nd_tdesc %[[ARG1]][%{{.*}}, %[[Y_COORD]]] : memref<1024x1024xbf16> -> !xegpu.tensor_desc<16x16xbf16>
-// CHECK-DAG: %[[A_TILE:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%[[X_COORD]], %{{.*}}] : memref<1024x1024xbf16> -> !xegpu.tensor_desc<8x16xbf16>
-// CHECK: %[[T7:.*]]:3 = scf.for {{.*}} iter_args(%[[C_VAL:.*]] = %[[C_INIT]], %[[A_ARG:.*]] = %[[A_TILE]], %[[B_ARG:.*]] = %[[B_TILE]]) -> (vector<8x1xf32>, !xegpu.tensor_desc<8x16xbf16>, !xegpu.tensor_desc<16x16xbf16>) {
-// CHECK-DAG: %[[B_NEXT:.*]] = xegpu.update_nd_offset %[[B_ARG]], [{{.*}}] : !xegpu.tensor_desc<16x16xbf16>
-// CHECK-DAG: %[[A_NEXT:.*]] = xegpu.update_nd_offset %[[A_ARG]], [{{.*}}] : !xegpu.tensor_desc<8x16xbf16>
-// CHECK-DAG: %[[B:.*]] = xegpu.load_nd %[[B_ARG]] <{packed}> : !xegpu.tensor_desc<16x16xbf16> -> vector<16xbf16>
-// CHECK-DAG: %[[A:.*]] = xegpu.load_nd %[[A_ARG]]  : !xegpu.tensor_desc<8x16xbf16> -> vector<8xbf16>
-// CHECK-DAG: %[[C:.*]] = vector.shape_cast %[[C_VAL]] : vector<8x1xf32> to vector<8xf32>
-// CHECK-NEXT: %[[T8:.*]] = xegpu.dpas %[[A]], %[[B]], %[[C]] : vector<8xbf16>, vector<16xbf16>, vector<8xf32> -> vector<8xf32>
-// CHECK-NEXT: %[[C_OUT:.*]] = vector.shape_cast %[[T8]] : vector<8xf32> to vector<8x1xf32>
-// CHECK-NEXT: scf.yield %[[C_OUT]], %[[A_NEXT]], %[[B_NEXT]] : vector<8x1xf32>, !xegpu.tensor_desc<8x16xbf16>, !xegpu.tensor_desc<16x16xbf16>
-// CHECK-NEXT:}
-// CHECK-NEXT: %[[C_FINAL:.*]] = vector.shape_cast %[[T7]]#0 : vector<8x1xf32> to vector<8xf32>
-// CHECK-NEXT: xegpu.store_nd %[[C_FINAL]], %[[T2]]  : vector<8xf32>, !xegpu.tensor_desc<8x16xf32>
+// CHECK-NEXT: %[[T3:.*]] = xegpu.load_nd %[[T2]] : !xegpu.tensor_desc<8x16xf32> -> vector<8xf32>
+// CHECK-NEXT: %[[T4:.*]] = vector.shape_cast %[[T3]] : vector<8xf32> to vector<8x1xf32>
+// CHECK: %[[T5:.*]] = scf.for %[[K:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG4:.*]] = %[[T4]]) -> (vector<8x1xf32>) {
+// CHECK-DAG: %[[T10:.*]] = xegpu.create_nd_tdesc %[[ARG1]][%[[K]], %[[Y_COORD]]] : memref<1024x1024xbf16> -> !xegpu.tensor_desc<16x16xbf16>
+// CHECK-DAG: %[[T11:.*]] = xegpu.load_nd %[[T10]] <{packed}> : !xegpu.tensor_desc<16x16xbf16> -> vector<16xbf16>
+// CHECK-DAG: %[[T12:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%[[X_COORD]], %[[K]]] : memref<1024x1024xbf16> -> !xegpu.tensor_desc<8x16xbf16>
+// CHECK-DAG: %[[T13:.*]] = xegpu.load_nd %[[T12]] : !xegpu.tensor_desc<8x16xbf16> -> vector<8xbf16>
+// CHECK-DAG: %[[T14:.*]] = vector.shape_cast %[[ARG4]] : vector<8x1xf32> to vector<8xf32>
+// CHECK-NEXT: %[[T15:.*]] = xegpu.dpas %[[T13]], %[[T11]], %[[T14]] : vector<8xbf16>, vector<16xbf16>, vector<8xf32> -> vector<8xf32>
+// CHECK-NEXT: %[[T16:.*]] = vector.shape_cast %[[T15]] : vector<8xf32> to vector<8x1xf32>
+// CHECK-NEXT: scf.yield %[[T16]] : vector<8x1xf32>
+// CHECK-NEXT: }
+// CHECK-NEXT: %[[T9:.*]] = vector.shape_cast %[[T5]] : vector<8x1xf32> to vector<8xf32>
+// CHECK-NEXT: xegpu.store_nd %[[T9]], %[[T2]] : vector<8xf32>, !xegpu.tensor_desc<8x16xf32>
 gpu.module @test {
-  gpu.func @gemm_loop(%arg0: memref<1024x1024xbf16>, %arg1: memref<1024x1024xbf16>, %arg2: memref<1024x1024xf32>) {
-    %c0 = arith.constant 0 : index
-    %c16 = arith.constant 16 : index
-    %c8 = arith.constant 8 : index
-    %c1024 = arith.constant 1024 : index
-    %block_id_x = gpu.block_id  x
-    %block_id_y = gpu.block_id  y
-    %0 = arith.muli %block_id_x, %c8 : index
-    %1 = arith.muli %block_id_y, %c16 : index
-    %2 = xegpu.create_nd_tdesc %arg2[%0, %1] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
-    %3 = xegpu.load_nd %2  {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<8x16xf32>
-    %4 = xegpu.create_nd_tdesc %arg0[%0, %c0] : memref<1024x1024xbf16> -> !xegpu.tensor_desc<8x16xbf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
-    %5 = xegpu.create_nd_tdesc %arg1[%c0, %1] : memref<1024x1024xbf16> -> !xegpu.tensor_desc<16x16xbf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>>
-    %6:3 = scf.for %arg3 = %c0 to %c1024 step %c16 iter_args(%arg4 = %3, %arg5 = %4, %arg6 = %5) -> (vector<8x16xf32>, !xegpu.tensor_desc<8x16xbf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>, !xegpu.tensor_desc<16x16xbf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>>) {
-      %8 = xegpu.load_nd %arg5  {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : !xegpu.tensor_desc<8x16xbf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<8x16xbf16>
-      %9 = xegpu.load_nd %arg6  {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} : !xegpu.tensor_desc<16x16xbf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>> -> vector<16x16xbf16>
-      %10 = xegpu.update_nd_offset %arg5, [%c0, %c16] : !xegpu.tensor_desc<8x16xbf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
-      %11 = xegpu.update_nd_offset %arg6, [%c16, %c0] : !xegpu.tensor_desc<16x16xbf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>>
-      %12 = xegpu.dpas %8, %9, %arg4 {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, layout_operand_1 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>, layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<8x16xbf16>, vector<16x16xbf16>, vector<8x16xf32> -> vector<8x16xf32>
-      scf.yield {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} %12, %10, %11 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xbf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>, !xegpu.tensor_desc<16x16xbf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>>
-    } {layout_operand_3 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
-    xegpu.store_nd %6#0, %2  {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
-    gpu.return
-  }
+gpu.func @gemm_loop(%arg0: memref<1024x1024xbf16>, %arg1: memref<1024x1024xbf16>, %arg2: memref<1024x1024xf32>){
+  %c0 = arith.constant 0 : index
+  %c16 = arith.constant 16 : index
+  %c8 = arith.constant 8 : index
+  %c1024 = arith.constant 1024 : index
+  %block_id_x = gpu.block_id  x
+  %block_id_y = gpu.block_id  y
+  %0 = arith.muli %block_id_x, %c8 : index
+  %1 = arith.muli %block_id_y, %c16 : index
+  %2 = xegpu.create_nd_tdesc %arg2[%0, %1] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+  %3 = xegpu.load_nd %2  {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<8x16xf32>
+  %4 = scf.for %arg3 = %c0 to %c1024 step %c16 iter_args(%arg4 = %3) -> (vector<8x16xf32>) {
+    %5 = xegpu.create_nd_tdesc %arg0[%0, %arg3] : memref<1024x1024xbf16> -> !xegpu.tensor_desc<8x16xbf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    %6 = xegpu.create_nd_tdesc %arg1[%arg3, %1] : memref<1024x1024xbf16> -> !xegpu.tensor_desc<16x16xbf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>>
+    %7 = xegpu.load_nd %5  {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : !xegpu.tensor_desc<8x16xbf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<8x16xbf16>
+    %8 = xegpu.load_nd %6  {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} : !xegpu.tensor_desc<16x16xbf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>> -> vector<16x16xbf16>
+    %9 = xegpu.dpas %7, %8, %arg4 {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, layout_operand_1 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>, layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<8x16xbf16>, vector<16x16xbf16>, vector<8x16xf32> -> vector<8x16xf32>
+    scf.yield {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} %9 : vector<8x16xf32>
+  } {layout_operand_3 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+  xegpu.store_nd %4, %2  {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+  gpu.return
+}
 }
 
 // -----

>From fdacb63e51af6de3a0deedddef30a10870d5d66b Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Thu, 5 Jun 2025 20:18:17 +0000
Subject: [PATCH 19/32] revert merge

---
 .../Vector/Transforms/VectorDistribute.cpp    | 40 +++++--------------
 .../Vector/vector-warp-distribute.mlir        | 36 -----------------
 2 files changed, 10 insertions(+), 66 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index bd833ddb773f7..045c192787f10 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -1554,37 +1554,22 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
     llvm::SmallSetVector<Value, 32> escapingValues;
     SmallVector<Type> inputTypes;
     SmallVector<Type> distTypes;
-    auto collectEscapingValues = [&](Value value) {
-      if (!escapingValues.insert(value))
-        return;
-      Type distType = value.getType();
-      if (auto vecType = dyn_cast<VectorType>(distType)) {
-        AffineMap map = distributionMapFn(value);
-        distType = getDistributedType(vecType, map, warpOp.getWarpSize());
-      }
-      inputTypes.push_back(value.getType());
-      distTypes.push_back(distType);
-    };
-
     mlir::visitUsedValuesDefinedAbove(
         forOp.getBodyRegion(), [&](OpOperand *operand) {
           Operation *parent = operand->get().getParentRegion()->getParentOp();
           if (warpOp->isAncestor(parent)) {
-            collectEscapingValues(operand->get());
+            if (!escapingValues.insert(operand->get()))
+              return;
+            Type distType = operand->get().getType();
+            if (auto vecType = dyn_cast<VectorType>(distType)) {
+              AffineMap map = distributionMapFn(operand->get());
+              distType = getDistributedType(vecType, map, warpOp.getWarpSize());
+            }
+            inputTypes.push_back(operand->get().getType());
+            distTypes.push_back(distType);
           }
         });
 
-    // Any forOp result that is not already yielded by the warpOp
-    // region is also considered escaping and must be returned by the
-    // original warpOp.
-    for (OpResult forResult : forOp.getResults()) {
-      // Check if this forResult is already yielded by the yield op.
-      if (llvm::is_contained(yield->getOperands(), forResult)) {
-        continue;
-      }
-      collectEscapingValues(forResult);
-    }
-
     if (llvm::is_contained(distTypes, Type{}))
       return failure();
 
@@ -1624,12 +1609,7 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
                                     forOp.getResultTypes().end());
     llvm::SmallDenseMap<Value, int64_t> argIndexMapping;
     for (auto [i, retIdx] : llvm::enumerate(newRetIndices)) {
-      auto newWarpResult = newWarpOp.getResult(retIdx);
-      // Unused forOp results yielded by the warpOp region are already included
-      // in the new ForOp.
-      if (llvm::is_contained(newOperands, newWarpResult))
-        continue;
-      warpInput.push_back(newWarpResult);
+      warpInput.push_back(newWarpOp.getResult(retIdx));
       argIndexMapping[escapingValues[i]] = warpInputType.size();
       warpInputType.push_back(inputTypes[i]);
     }
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 6c7ac7a5196a7..38771f2593449 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -584,42 +584,6 @@ func.func @warp_scf_for_multiple_yield(%arg0: index, %arg1: memref<?xf32>, %arg2
   return
 }
 
-// -----
-// CHECK-PROP-LABEL: func.func @warp_scf_for_unused_yield(
-//       CHECK-PROP: %[[W0:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<4xf32>, vector<4xf32>) {
-//       CHECK-PROP: %[[INI0:.*]] = "some_def"() : () -> vector<128xf32>
-//       CHECK-PROP: %[[INI1:.*]] = "some_def"() : () -> vector<128xf32>
-//       CHECK-PROP: gpu.yield %[[INI0]], %[[INI1]] : vector<128xf32>, vector<128xf32>
-//       CHECK-PROP: }
-//       CHECK-PROP: %[[F:.*]]:2 = scf.for %{{.*}} iter_args(%{{.*}} = %[[W0]]#0, %{{.*}} = %[[W0]]#1) -> (vector<4xf32>, vector<4xf32>) {
-//       CHECK-PROP: %[[W1:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[32] args(%{{.*}} : vector<4xf32>, vector<4xf32>) -> (vector<4xf32>, vector<4xf32>) {
-//       CHECK-PROP: %[[ACC0:.*]] = "some_def"(%{{.*}}) : (vector<128xf32>, index) -> vector<128xf32>
-//       CHECK-PROP: %[[ACC1:.*]] = "some_def"(%{{.*}}) : (index, vector<128xf32>, vector<128xf32>) -> vector<128xf32>
-//       CHECK-PROP: gpu.yield %[[ACC1]], %[[ACC0]] : vector<128xf32>, vector<128xf32>
-//       CHECK-PROP: }
-//       CHECK-PROP: scf.yield %[[W1]]#0, %[[W1]]#1 : vector<4xf32>, vector<4xf32>
-//       CHECK-PROP: }
-//       CHECK-PROP: "some_use"(%[[F]]#0) : (vector<4xf32>) -> ()
-func.func @warp_scf_for_unused_yield(%arg0: index) {
-  %c128 = arith.constant 128 : index
-  %c1 = arith.constant 1 : index
-  %c0 = arith.constant 0 : index
-  %0 = gpu.warp_execute_on_lane_0(%arg0)[32] -> (vector<4xf32>) {
-    %ini = "some_def"() : () -> (vector<128xf32>)
-    %ini1 = "some_def"() : () -> (vector<128xf32>)
-    %3:2 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %ini, %arg5 = %ini1) -> (vector<128xf32>, vector<128xf32>) {
-      %add = arith.addi %arg3, %c1 : index
-      %1  = "some_def"(%arg5, %add) : (vector<128xf32>, index) -> (vector<128xf32>)
-      %acc = "some_def"(%add, %arg4, %1) : (index, vector<128xf32>, vector<128xf32>) -> (vector<128xf32>)
-      scf.yield %acc, %1 : vector<128xf32>, vector<128xf32>
-    }
-    gpu.yield %3#0 : vector<128xf32>
-  }
-  "some_use"(%0) : (vector<4xf32>) -> ()
-  return
-}
-
-
 // -----
 
 // CHECK-PROP-LABEL: func @vector_reduction(

>From 57acc9e1f06bedea779ddb3e0097948f353f3ede Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Thu, 5 Jun 2025 20:20:11 +0000
Subject: [PATCH 20/32] add comment

---
 mlir/test/Dialect/XeGPU/subgroup-distribute.mlir | 1 +
 1 file changed, 1 insertion(+)

diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
index 3bfabac55faf3..7362c175a70a4 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
@@ -166,6 +166,7 @@ gpu.module @test {
 }
 
 // -----
+// TODO: gemm_loop does not use update_nd_offset due to an issue in vector distribution (tracked by llvm-project PR #141853).
 // CHECK-LABEL: gpu.func @gemm_loop
 // CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: memref<1024x1024xbf16>, %[[ARG1:[0-9a-zA-Z]+]]: memref<1024x1024xbf16>, %[[ARG2:[0-9a-zA-Z]+]]: memref<1024x1024xf32>) {
 // CHECK-DAG: %[[BLOCK_ID_X:.*]] = gpu.block_id x

>From a99ee751d4112c152017805449ce2c623d906adb Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Thu, 5 Jun 2025 21:14:25 +0000
Subject: [PATCH 21/32] refactor

---
 .../mlir/Dialect/XeGPU/Utils/XeGPUUtils.h     | 14 ++++++++
 .../XeGPU/Transforms/XeGPULayoutPropagate.cpp | 35 ++++++++-----------
 .../Transforms/XeGPUSubgroupDistribute.cpp    | 14 ++------
 3 files changed, 31 insertions(+), 32 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
index f9327d63869c0..23f44dcb8725d 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
@@ -24,6 +24,20 @@ class LayoutAttr;
 class TensorDescType;
 } // namespace xegpu
 
+namespace xegpu {
+/// HW dependent constants.
+/// TODO: These constants should be queried from the target information.
+namespace targetinfo {
+constexpr unsigned subgroupSize = 16; // How many lanes in a subgroup.
+/// If DPAS A or B operands have low precision element types they must be packed
+/// according to the following sizes.
+constexpr unsigned packedSizeInBitsForDefault =
+    16; // Minimum packing size per register for DPAS A and default layouts.
+constexpr unsigned packedSizeInBitsForDpasB =
+    32; // Minimum packing size per register for DPAS B.
+} // namespace targetinfo
+} // namespace xegpu
+
 namespace xegpu {
 
 /// If tensor descriptor has a layout attribute it is used in SIMT mode.
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutPropagate.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutPropagate.cpp
index ce2b1454fb6a0..fb69498dacb54 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutPropagate.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutPropagate.cpp
@@ -46,16 +46,6 @@ namespace xegpu {
 using namespace mlir;
 using namespace mlir::dataflow;
 
-/// HW dependent constants.
-/// TODO: These constants should be queried from the target information.
-constexpr unsigned subgroupSize = 16; // How many lanes in a subgroup.
-/// If DPAS A or B operands have low precision element types they must be packed
-/// according to the following sizes.
-constexpr unsigned packedSizeInBitsForDefault =
-    16; // Minimum packing size per register for DPAS A.
-constexpr unsigned packedSizeInBitsForDpasB =
-    32; // Minimum packing size per register for DPAS B.
-
 namespace {
 
 //===----------------------------------------------------------------------===//
@@ -198,8 +188,10 @@ struct LayoutInfoLattice : public Lattice<LayoutInfo> {
 static LayoutInfo getDefaultLayoutInfo(unsigned rank) {
   assert((rank == 1 || rank == 2) && "Expected 1D or 2D vector.");
   if (rank == 1)
-    return LayoutInfo(LaneLayout({subgroupSize}), LaneData({1}));
-  return LayoutInfo(LaneLayout({1, subgroupSize}), LaneData({1, 1}));
+    return LayoutInfo(LaneLayout({xegpu::targetinfo::subgroupSize}),
+                      LaneData({1}));
+  return LayoutInfo(LaneLayout({1, xegpu::targetinfo::subgroupSize}),
+                    LaneData({1, 1}));
 }
 
 /// Helper to get the default layout for a vector type.
@@ -216,9 +208,9 @@ static LayoutInfo getDefaultLayoutInfo(VectorType vectorTy) {
   // Packing factor is determined by the element type bitwidth.
   int packingFactor = 1;
   unsigned bitwidth = vectorTy.getElementType().getIntOrFloatBitWidth();
-  if (bitwidth < packedSizeInBitsForDefault)
-    packingFactor = packedSizeInBitsForDefault / bitwidth;
-  return LayoutInfo(LaneLayout({1, subgroupSize}),
+  if (bitwidth < xegpu::targetinfo::packedSizeInBitsForDefault)
+    packingFactor = xegpu::targetinfo::packedSizeInBitsForDefault / bitwidth;
+  return LayoutInfo(LaneLayout({1, xegpu::targetinfo::subgroupSize}),
                     LaneData({1, packingFactor}));
 }
 
@@ -233,13 +225,14 @@ static LayoutInfo getLayoutInfoForDPASOperand(VectorType vectorTy,
   Type elementTy = vectorTy.getElementType();
   assert(elementTy.isIntOrFloat() &&
          "Expected int or float type in DPAS operands");
-  LaneLayout layout({1, subgroupSize});
+  LaneLayout layout({1, xegpu::targetinfo::subgroupSize});
   // For B operand, data must be packed in minimum `packedDpasBSizeInBits` and
   // must have the VNNI format.
-  if (operandNum == 1 &&
-      elementTy.getIntOrFloatBitWidth() < packedSizeInBitsForDpasB) {
-    LaneData data(
-        {packedSizeInBitsForDpasB / elementTy.getIntOrFloatBitWidth(), 1});
+  if (operandNum == 1 && elementTy.getIntOrFloatBitWidth() <
+                             xegpu::targetinfo::packedSizeInBitsForDpasB) {
+    LaneData data({xegpu::targetinfo::packedSizeInBitsForDpasB /
+                       elementTy.getIntOrFloatBitWidth(),
+                   1});
     return LayoutInfo(layout, data);
   }
   // Otherwise, return the default layout for the vector type.
@@ -577,7 +570,7 @@ void LayoutInfoPropagation::visitStoreScatterOp(
   ArrayRef<int64_t> tdescShape = storeScatter.getTensorDescType().getShape();
   if (tdescShape.size() > 1)
     assert(
-        tdescShape[0] == subgroupSize &&
+        tdescShape[0] == xegpu::targetinfo::subgroupSize &&
         "Expected the first dimension of 2D tensor descriptor to be equal to "
         "subgroup size.");
 
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 9ddf3abe667e2..73da16cb2e3fb 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -58,15 +58,6 @@ namespace xegpu {
 
 using namespace mlir;
 
-/// HW dependent constants.
-/// TODO: These constants should be queried from the target information.
-constexpr unsigned subgroupSize = 16; // How many lanes in a subgroup.
-/// If DPAS A or B operands have low precision element types they must be packed
-/// according to the following sizes.
-constexpr unsigned packedSizeInBitsForDefault =
-    16; // Minimum packing size per register for DPAS A.
-constexpr unsigned packedSizeInBitsForDpasB =
-    32; // Minimum packing size per register for DPAS B.
 static const char *const resolveSIMTTypeMismatch =
     "resolve_simt_type_mismatch"; // Attribute name for identifying
                                   // UnrelizedConversionCastOp added to resolve
@@ -228,8 +219,9 @@ struct MoveFuncBodyToWarpExecuteOnLane0
         /** upperBound = **/ mlir::IntegerAttr());
     ArrayRef<Type> gpuFuncResultType = gpuFuncOp.getFunctionType().getResults();
     auto warpOp = rewriter.create<gpu::WarpExecuteOnLane0Op>(
-        laneId.getLoc(), gpuFuncResultType, laneId, subgroupSize,
-        newGpuFunc.getArguments(), newGpuFunc.getArgumentTypes());
+        laneId.getLoc(), gpuFuncResultType, laneId,
+        xegpu::targetinfo::subgroupSize, newGpuFunc.getArguments(),
+        newGpuFunc.getArgumentTypes());
     Block &warpBodyBlock = warpOp.getBodyRegion().front();
     // Replace the ReturnOp of the original gpu function with a YieldOp.
     auto origRetunOp =

>From 739aad7a7743c96b7935622806de50e09ffa85bd Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Thu, 5 Jun 2025 21:26:48 +0000
Subject: [PATCH 22/32] refactor

---
 mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td         | 4 ----
 .../lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp | 5 -----
 2 files changed, 9 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
index ee25eee688095..29f936e81974e 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
@@ -27,10 +27,6 @@ def XeGPUSubgroupDistribute : Pass<"xegpu-subgroup-distribute"> {
   }];
   let dependentDialects = ["memref::MemRefDialect", "xegpu::XeGPUDialect",
                            "vector::VectorDialect"];
-  let options = [Option<
-      "printOnly", "print-analysis-only", "bool",
-      /*default=*/"false",
-      "Print the result of the subgroup map propagation analysis and exit.">];
 }
 
 def XeGPULayoutPropagate : Pass<"xegpu-layout-propagate"> {
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 73da16cb2e3fb..221c309e18a4b 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -800,11 +800,6 @@ namespace {
 struct XeGPUSubgroupDistributePass final
     : public xegpu::impl::XeGPUSubgroupDistributeBase<
           XeGPUSubgroupDistributePass> {
-  XeGPUSubgroupDistributePass() = default;
-  XeGPUSubgroupDistributePass(const XeGPUSubgroupDistributePass &other) =
-      default;
-  XeGPUSubgroupDistributePass(xegpu::XeGPUSubgroupDistributeOptions options)
-      : XeGPUSubgroupDistributeBase(options) {}
   void runOnOperation() override;
 };
 } // namespace

>From 76b7333a088d8a58c5f1aa2b7d2b3740962332cc Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Thu, 5 Jun 2025 22:32:06 +0000
Subject: [PATCH 23/32] refactor

---
 .../XeGPU/Transforms/XeGPULayoutPropagate.cpp | 100 ++++++++++--------
 .../Transforms/XeGPUSubgroupDistribute.cpp    |   1 -
 2 files changed, 56 insertions(+), 45 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutPropagate.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutPropagate.cpp
index fb69498dacb54..5ee034570ad0c 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutPropagate.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutPropagate.cpp
@@ -23,6 +23,7 @@
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/Value.h"
 #include "mlir/IR/Visitors.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/FunctionInterfaces.h"
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/ArrayRef.h"
@@ -683,6 +684,22 @@ RunLayoutInfoPropagation::printAnalysisResult(llvm::raw_ostream &os) {
 }
 
 using GetLayoutCallbackFnTy = function_ref<xegpu::LayoutAttr(Value)>;
+/// Helper to update the users of a value with a given layout.
+static void updateUsers(Value v, xegpu::LayoutAttr layout) {
+  // Update all users of the value with the layout.
+  for (OpOperand &user : v.getUses()) {
+    Operation *owner = user.getOwner();
+    // Add temporary layout attribute at the user op.
+    std::string attrName = xegpu::getLayoutName(user);
+    owner->setAttr(attrName, layout);
+  }
+}
+
+/// Update an operation with the layout of its results. If the result type is a
+/// vector type, a temporary layout attribute is added to the operation. If the
+/// result type is a tensor descriptor type, the type is updated with the layout
+/// attribute. The users of the result are also updated with the layout
+/// attribute.
 static void updateOp(mlir::OpBuilder &builder, mlir::Operation *op,
                      GetLayoutCallbackFnTy getLayoutOfValue) {
 
@@ -712,14 +729,12 @@ static void updateOp(mlir::OpBuilder &builder, mlir::Operation *op,
     std::string resultLayoutName = xegpu::getLayoutName(result);
     op->setAttr(resultLayoutName, layout);
     // Update all users of the result with the layout.
-    for (OpOperand &user : result.getUses()) {
-      Operation *owner = user.getOwner();
-      // Add temorary layout attribute at the user op.
-      std::string attrName = xegpu::getLayoutName(user);
-      owner->setAttr(attrName, layout);
-    }
+    updateUsers(result, layout);
   }
 }
+
+/// Update the types of successor regions of a branch terminator op (scf.yield)
+/// with assigned layouts.
 static void updateBranchTerminatorOpInterface(
     mlir::OpBuilder &builder,
     mlir::RegionBranchTerminatorOpInterface terminator,
@@ -769,6 +784,10 @@ static void updateBranchTerminatorOpInterface(
     }
   }
 }
+
+/// Some operations contain multiple regions (like scf.for), each of which has
+/// block arguments. This function updates the block argument types of such
+/// regions with the assigned layouts.
 static void updateBranchOpInterface(mlir::OpBuilder &builder,
                                     mlir::RegionBranchOpInterface branch,
                                     GetLayoutCallbackFnTy getLayoutOfValue) {
@@ -790,33 +809,32 @@ static void updateBranchOpInterface(mlir::OpBuilder &builder,
       Type inputType = input.getType();
       if (!isa<xegpu::TensorDescType>(inputType))
         continue;
-      xegpu::LayoutAttr blockArgLayout = getLayoutOfValue(input);
-      xegpu::LayoutAttr initArgLayout = getLayoutOfValue(operand);
+      xegpu::LayoutAttr inputLayout = getLayoutOfValue(input);
+      xegpu::LayoutAttr operandLayout = getLayoutOfValue(operand);
 
-      if (!blockArgLayout || !initArgLayout) {
+      if (!inputLayout || !operandLayout) {
         LLVM_DEBUG(DBGS() << "No layout assigned for block arg: " << input
                           << " or init arg: " << operand << "\n");
         continue;
       }
 
-      // TOOD: We expect these two to match. Data flow analysis will ensure
-      // this.
-      assert(blockArgLayout == initArgLayout &&
+      // TODO: We expect these two to match.
+      assert(inputLayout == operandLayout &&
              "Expexing block arg and init arg to have the same layout.");
       // Get tensor descriptor type with the layout.
       auto tdescTy = dyn_cast<xegpu::TensorDescType>(inputType);
       auto newTdescTy = xegpu::TensorDescType::get(
           tdescTy.getContext(), tdescTy.getShape(), tdescTy.getElementType(),
-          tdescTy.getEncoding(), blockArgLayout);
+          tdescTy.getEncoding(), inputLayout);
       input.setType(newTdescTy);
       // Store the layout for the result.
       if (resultToLayouts.count(result) != 0 &&
-          resultToLayouts[result] != blockArgLayout) {
+          resultToLayouts[result] != inputLayout) {
         LLVM_DEBUG(DBGS() << "Conflicting layouts for result: " << result
                           << " - " << resultToLayouts[result] << " vs "
-                          << blockArgLayout << "\n");
+                          << inputLayout << "\n");
       } else {
-        resultToLayouts[result] = blockArgLayout;
+        resultToLayouts[result] = inputLayout;
       }
     }
   }
@@ -844,15 +862,11 @@ static void updateBranchOpInterface(mlir::OpBuilder &builder,
     std::string resultLayoutName = xegpu::getLayoutName(r);
     op->setAttr(resultLayoutName, layout);
     // Update all users of the result with the layout.
-    for (OpOperand &user : r.getUses()) {
-      Operation *owner = user.getOwner();
-      // Add temporary layout attribute at the user op.
-      std::string attrName = xegpu::getLayoutName(user);
-      owner->setAttr(attrName, layout);
-    }
+    updateUsers(r, layout);
   }
 }
 
+/// Update the function arguments with the assigned layouts.
 static void updateFunctionOpInterface(mlir::OpBuilder &builder,
                                       mlir::FunctionOpInterface funcOp,
                                       GetLayoutCallbackFnTy getLayoutOfValue) {
@@ -879,11 +893,7 @@ static void updateFunctionOpInterface(mlir::OpBuilder &builder,
     }
     // If the argument is a vector type, update all the users of the argument
     // with the layout.
-    for (OpOperand &user : arg.getUses()) {
-      Operation *owner = user.getOwner();
-      std::string attrName = xegpu::getLayoutName(user);
-      owner->setAttr(attrName, layout);
-    }
+    updateUsers(arg, layout);
   }
   // Update the function type with the new argument types.
   // NOTE: We assume that function results are not expected to have layouts.
@@ -902,7 +912,7 @@ struct XeGPULayoutPropagatePass final
 
 void XeGPULayoutPropagatePass::runOnOperation() {
   auto &analyis = getAnalysis<RunLayoutInfoPropagation>();
-
+  // Helper to convert LayoutInfo to xegpu::LayoutAttr.
   auto getXeGPULayoutForValue = [&](Value val) -> xegpu::LayoutAttr {
     LayoutInfo layout = analyis.getLayoutInfo(val);
     if (!layout.isAssigned()) {
@@ -921,23 +931,25 @@ void XeGPULayoutPropagatePass::runOnOperation() {
   Operation *op = getOperation();
   op->walk([&](mlir::Block *block) {
     for (mlir::Operation &op : llvm::reverse(block->getOperations())) {
-      if (auto branchTermOp =
-              mlir::dyn_cast<mlir::RegionBranchTerminatorOpInterface>(op)) {
-        updateBranchTerminatorOpInterface(builder, branchTermOp,
+      TypeSwitch<Operation *>(&op)
+          .Case<mlir::RegionBranchTerminatorOpInterface>(
+              [&](mlir::RegionBranchTerminatorOpInterface branchTermOp) {
+                updateBranchTerminatorOpInterface(builder, branchTermOp,
+                                                  getXeGPULayoutForValue);
+              })
+          .Case<mlir::RegionBranchOpInterface>(
+              [&](mlir::RegionBranchOpInterface regionBrOp) {
+                updateBranchOpInterface(builder, regionBrOp,
+                                        getXeGPULayoutForValue);
+              })
+          .Case<mlir::FunctionOpInterface>(
+              [&](mlir::FunctionOpInterface funcOp) {
+                updateFunctionOpInterface(builder, funcOp,
                                           getXeGPULayoutForValue);
-        continue;
-      }
-
-      if (auto regionBrOp = mlir::dyn_cast<mlir::RegionBranchOpInterface>(op)) {
-        updateBranchOpInterface(builder, regionBrOp, getXeGPULayoutForValue);
-        continue;
-      }
-
-      if (auto funcOp = mlir::dyn_cast<mlir::FunctionOpInterface>(op)) {
-        updateFunctionOpInterface(builder, funcOp, getXeGPULayoutForValue);
-        continue;
-      }
-      updateOp(builder, &op, getXeGPULayoutForValue);
+              })
+          .Default([&](Operation *op) {
+            updateOp(builder, op, getXeGPULayoutForValue);
+          });
     }
   });
 }
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 221c309e18a4b..eb8192417f843 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -812,7 +812,6 @@ void xegpu::populateXeGPUSubgroupDistributePatterns(
 }
 
 void XeGPUSubgroupDistributePass::runOnOperation() {
-
   // Move all operations of a GPU function inside gpu.warp_execute_on_lane_0
   // operation.
   {

>From cbcfd61b7c9c0e2d165ef319f57a978350ca6ddf Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Mon, 9 Jun 2025 23:32:50 +0000
Subject: [PATCH 24/32] address comments

---
 .../mlir/Dialect/XeGPU/Transforms/Passes.td   |  7 ++-
 .../XeGPU/Transforms/XeGPULayoutPropagate.cpp | 59 +++++++++++--------
 2 files changed, 37 insertions(+), 29 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
index 29f936e81974e..bf95dae69518d 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
@@ -30,12 +30,13 @@ def XeGPUSubgroupDistribute : Pass<"xegpu-subgroup-distribute"> {
 }
 
 def XeGPULayoutPropagate : Pass<"xegpu-layout-propagate"> {
-  let summary = "Propagate XeGPU layout information";
+  let summary = "Propagate and assign XeGPU layout information";
   let description = [{
     This pass propagates the XeGPU layout information accross ops. Starting
     from a set of anchor operations (e.g. `dpas`, `store_nd`), this will
-    propagate the layouts required for operands and results to the producers or
-    consumers.
+    propagate the layouts required for their operands to the producers. With
+    this propagated layout information, the pass will then update the XeGPU
+    tensor descriptor type with the layout information.
   }];
   let dependentDialects = ["memref::MemRefDialect", "xegpu::XeGPUDialect",
                            "vector::VectorDialect"];
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutPropagate.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutPropagate.cpp
index 5ee034570ad0c..1f6ba5f1a6064 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutPropagate.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutPropagate.cpp
@@ -30,6 +30,7 @@
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/Casting.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/InterleavedRange.h"
 #include "llvm/Support/raw_ostream.h"
@@ -103,6 +104,7 @@ struct LayoutInfo {
 private:
   LaneLayout laneLayout;
   LaneData laneData;
+  xegpu::LayoutAttr layoutAttr;
 
 public:
   LayoutInfo() = default;
@@ -186,7 +188,7 @@ struct LayoutInfoLattice : public Lattice<LayoutInfo> {
 /// Helper Function to get the default layout for uniform values like constants.
 /// For 1D vector, lane_layout is [subgroupSize] and lane_data is [1].
 /// For 2D vector, lane_layout is [1, subgroupSize] and lane_data is [1, 1].
-static LayoutInfo getDefaultLayoutInfo(unsigned rank) {
+static LayoutInfo getDefaultSIMTLayoutInfo(unsigned rank) {
   assert((rank == 1 || rank == 2) && "Expected 1D or 2D vector.");
   if (rank == 1)
     return LayoutInfo(LaneLayout({xegpu::targetinfo::subgroupSize}),
@@ -196,7 +198,7 @@ static LayoutInfo getDefaultLayoutInfo(unsigned rank) {
 }
 
 /// Helper to get the default layout for a vector type.
-static LayoutInfo getDefaultLayoutInfo(VectorType vectorTy) {
+static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy) {
   // Expecting a 1D or 2D vector.
   assert((vectorTy.getRank() == 1 || vectorTy.getRank() == 2) &&
          "Expected 1D or 2D vector.");
@@ -205,7 +207,7 @@ static LayoutInfo getDefaultLayoutInfo(VectorType vectorTy) {
          "Expected int or float element type.");
   // If the rank is 1, then return default layout for 1D vector.
   if (vectorTy.getRank() == 1)
-    return getDefaultLayoutInfo(1);
+    return getDefaultSIMTLayoutInfo(1);
   // Packing factor is determined by the element type bitwidth.
   int packingFactor = 1;
   unsigned bitwidth = vectorTy.getElementType().getIntOrFloatBitWidth();
@@ -221,8 +223,8 @@ static LayoutInfo getDefaultLayoutInfo(VectorType vectorTy) {
 /// `packedSizeInBitsForDefault`
 /// * For B operand, the data must be packed in minimum
 /// `packedSizeInBitsForDpasB`
-static LayoutInfo getLayoutInfoForDPASOperand(VectorType vectorTy,
-                                              unsigned operandNum) {
+static LayoutInfo getSIMTLayoutInfoForDPASOperand(VectorType vectorTy,
+                                                  unsigned operandNum) {
   Type elementTy = vectorTy.getElementType();
   assert(elementTy.isIntOrFloat() &&
          "Expected int or float type in DPAS operands");
@@ -237,7 +239,7 @@ static LayoutInfo getLayoutInfoForDPASOperand(VectorType vectorTy,
     return LayoutInfo(layout, data);
   }
   // Otherwise, return the default layout for the vector type.
-  return getDefaultLayoutInfo(vectorTy);
+  return getDefaultSIMTLayoutInfo(vectorTy);
 }
 
 //===----------------------------------------------------------------------===//
@@ -360,17 +362,18 @@ LogicalResult LayoutInfoPropagation::visitOperation(
       // All other ops.
       .Default([&](Operation *op) {
         for (const LayoutInfoLattice *r : results) {
-          for (LayoutInfoLattice *operand : operands) {
-            // Propagate the layout of the result to the operand.
-            if (r->getValue().isAssigned())
+          if (r->getValue().isAssigned()) {
+            for (LayoutInfoLattice *operand : operands) {
+              // Propagate the layout of the result to the operand.
               meet(operand, *r);
+            }
           }
         }
       });
   // Add a dependency from each result to program point after the operation.
-  for (const LayoutInfoLattice *r : results) {
+  for (const LayoutInfoLattice *r : results)
     addDependency(const_cast<LayoutInfoLattice *>(r), getProgramPointAfter(op));
-  }
+
   return success();
 }
 
@@ -380,7 +383,7 @@ void LayoutInfoPropagation::visitPrefetchNdOp(
   // Here we assign the default layout to the tensor descriptor operand of
   // prefetch.
   auto tdescTy = prefetch.getTensorDescType();
-  auto prefetchLayout = getDefaultLayoutInfo(
+  auto prefetchLayout = getDefaultSIMTLayoutInfo(
       VectorType::get(tdescTy.getShape(), tdescTy.getElementType()));
   // Propagate the layout to the source tensor descriptor.
   propagateIfChanged(operands[0], operands[0]->meet(prefetchLayout));
@@ -395,11 +398,13 @@ void LayoutInfoPropagation::visitVectorMultiReductionOp(
   if (!resultLayout.isAssigned())
     return;
   // We only consider 2D -> 1D reductions at this point.
-  assert(resultLayout.getLayout().size() == 1 &&
-         "Expected 1D layout for reduction result.");
+  if (resultLayout.getLayout().size() != 1) {
+    reduction.emitWarning("Expected 1D layout for reduction result. ");
+    return;
+  }
   // Given that the result is 1D, the layout of the operand should be 2D with
   // default layout.
-  LayoutInfo operandLayout = getDefaultLayoutInfo(2);
+  LayoutInfo operandLayout = getDefaultSIMTLayoutInfo(2);
   propagateIfChanged(operands[0], operands[0]->meet(operandLayout));
   // Accumulator should have the same layout as the result.
   propagateIfChanged(operands[1], operands[1]->meet(resultLayout));
@@ -425,14 +430,15 @@ void LayoutInfoPropagation::visitDpasOp(
     ArrayRef<const LayoutInfoLattice *> results) {
   VectorType aTy = dpas.getLhsType();
   VectorType bTy = dpas.getRhsType();
-  propagateIfChanged(operands[0],
-                     operands[0]->meet(getLayoutInfoForDPASOperand(aTy, 0)));
-  propagateIfChanged(operands[1],
-                     operands[1]->meet(getLayoutInfoForDPASOperand(bTy, 1)));
+  propagateIfChanged(
+      operands[0], operands[0]->meet(getSIMTLayoutInfoForDPASOperand(aTy, 0)));
+  propagateIfChanged(
+      operands[1], operands[1]->meet(getSIMTLayoutInfoForDPASOperand(bTy, 1)));
   if (operands.size() > 2) {
     VectorType cTy = dpas.getAccType();
-    propagateIfChanged(operands[2],
-                       operands[2]->meet(getLayoutInfoForDPASOperand(cTy, 2)));
+    propagateIfChanged(
+        operands[2],
+        operands[2]->meet(getSIMTLayoutInfoForDPASOperand(cTy, 2)));
   }
 }
 
@@ -440,7 +446,7 @@ void LayoutInfoPropagation::visitDpasOp(
 void LayoutInfoPropagation::visitStoreNdOp(
     xegpu::StoreNdOp store, ArrayRef<LayoutInfoLattice *> operands,
     ArrayRef<const LayoutInfoLattice *> results) {
-  LayoutInfo storeLayout = getDefaultLayoutInfo(store.getValueType());
+  LayoutInfo storeLayout = getDefaultSIMTLayoutInfo(store.getValueType());
   // Both operands should have the same layout
   for (LayoutInfoLattice *operand : operands) {
     propagateIfChanged(operand, operand->meet(storeLayout));
@@ -539,7 +545,7 @@ void LayoutInfoPropagation::visitLoadGatherOp(
     tensorDescLayout = valueLayout.getTransposedLayout({1, 0});
   }
   // Mask operand should have 1D default layout.
-  LayoutInfo maskLayout = getDefaultLayoutInfo(1);
+  LayoutInfo maskLayout = getDefaultSIMTLayoutInfo(1);
   // Propagate the new layout to the tensor descriptor operand.
   propagateIfChanged(operands[0], operands[0]->meet(tensorDescLayout));
   // Propagate the new layout to the mask operand.
@@ -556,7 +562,7 @@ void LayoutInfoPropagation::visitCreateDescOp(
   if (!descLayout.isAssigned())
     return;
   // For offset operand propagate 1D default layout.
-  LayoutInfo layout = getDefaultLayoutInfo(1);
+  LayoutInfo layout = getDefaultSIMTLayoutInfo(1);
   propagateIfChanged(operands[1], operands[1]->meet(layout));
 }
 
@@ -575,7 +581,8 @@ void LayoutInfoPropagation::visitStoreScatterOp(
         "Expected the first dimension of 2D tensor descriptor to be equal to "
         "subgroup size.");
 
-  LayoutInfo valueLayout = getDefaultLayoutInfo(storeScatter.getValueType());
+  LayoutInfo valueLayout =
+      getDefaultSIMTLayoutInfo(storeScatter.getValueType());
   LayoutInfo storeScatterLayout = valueLayout;
   if (storeScatter.getTranspose()) {
     // StoreScatteOp allows transpose effect. However, at the stage of this
@@ -590,7 +597,7 @@ void LayoutInfoPropagation::visitStoreScatterOp(
   // Propagate the tensor descriptor layout.
   propagateIfChanged(operands[1], operands[1]->meet(storeScatterLayout));
   // Use default 1D layout for mask operand.
-  LayoutInfo maskLayout = getDefaultLayoutInfo(1);
+  LayoutInfo maskLayout = getDefaultSIMTLayoutInfo(1);
   propagateIfChanged(operands[2], operands[2]->meet(maskLayout));
 }
 

>From 0f796970a0424881f0d8bcc5e260a8462ca81f1c Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Tue, 10 Jun 2025 20:50:06 +0000
Subject: [PATCH 25/32] fix bitcast

---
 .../XeGPU/Transforms/XeGPULayoutPropagate.cpp | 25 ++++---------
 mlir/test/Dialect/XeGPU/layout-propagate.mlir | 35 +++++--------------
 2 files changed, 16 insertions(+), 44 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutPropagate.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutPropagate.cpp
index 1f6ba5f1a6064..c8462140e8788 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutPropagate.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutPropagate.cpp
@@ -503,26 +503,15 @@ void LayoutInfoPropagation::visitVectorBitcastOp(
   int outElemTyBitWidth =
       bitcast.getResultVectorType().getElementType().getIntOrFloatBitWidth();
 
-  // LaneLayout does not change.
-  const LaneLayout &newLaneLayout = resultLayout.getLayout();
-  const LaneData &currData = resultLayout.getData();
-  LaneData newLaneData;
-  // It's a widening bitcast
-  if (inElemTyBitWidth < outElemTyBitWidth) {
-    int ratio = outElemTyBitWidth / inElemTyBitWidth;
-    newLaneData = resultLayout.getData()[0] == 1
-                      ? LaneData({1, currData[1] * ratio})
-                      : LaneData({currData[0] * ratio, 1});
-  } else {
-    // It's a narrowing bitcast
-    int ratio = inElemTyBitWidth / outElemTyBitWidth;
-    newLaneData = resultLayout.getData()[0] == 1
-                      ? LaneData({1, currData[1] / ratio})
-                      : LaneData({currData[0] / ratio, 1});
+  // NOTE: We do not expect widening or narrowing bitcasts at this stage. Emit a
+  // warning and return.
+  if (inElemTyBitWidth != outElemTyBitWidth) {
+    bitcast.emitWarning("Widening or narrowing bitcasts are not expected at "
+                        "layout propagation stage.");
+    return;
   }
 
-  propagateIfChanged(operands[0],
-                     operands[0]->meet(LayoutInfo(newLaneLayout, newLaneData)));
+  propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
 }
 
 /// Propagate the layout of the result to the tensor descriptor and mask
diff --git a/mlir/test/Dialect/XeGPU/layout-propagate.mlir b/mlir/test/Dialect/XeGPU/layout-propagate.mlir
index f698b997e8cb7..b8f5546dd8b6b 100644
--- a/mlir/test/Dialect/XeGPU/layout-propagate.mlir
+++ b/mlir/test/Dialect/XeGPU/layout-propagate.mlir
@@ -148,35 +148,18 @@ func.func @store_scatter_1d(%arg0: vector<16xf32>, %arg1: memref<256xf32>) {
 }
 
 // -----
-// CHECK-LABEL: func.func @vector_bitcast_i16_to_i8(
-// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<8x16xi16>, %[[ARG1:[0-9a-zA-Z]+]]: memref<32x16xi8>, %[[ARG2:[0-9a-zA-Z]+]]: memref<8x16xi32>) {
-// CHECK: %{{.*}} = vector.bitcast %{{.*}} {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>} : vector<8x16xi16> to vector<8x32xi8>
-func.func @vector_bitcast_i16_to_i8(%arg0: memref<8x16xi16>, %arg1: memref<32x16xi8>, %arg2: memref<8x16xi32>) {
+// CHECK-LABEL: func.func @vector_bitcast_i16_to_f16(
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<8x16xi16>, %[[ARG1:[0-9a-zA-Z]+]]: memref<16x16xi16>, %[[ARG2:[0-9a-zA-Z]+]]: memref<8x16xf32>) {
+// CHECK: %{{.*}} = vector.bitcast %{{.*}} {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<8x16xi16> to vector<8x16xf16>
+// CHECK: %{{.*}} = vector.bitcast %{{.*}} {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>, layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} : vector<16x16xi16> to vector<16x16xf16>
+func.func @vector_bitcast_i16_to_f16(%arg0: memref<8x16xi16>, %arg1: memref<16x16xi16>, %arg2: memref<8x16xf32>) {
   %c0 = arith.constant 0 : index
   %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xi16> -> !xegpu.tensor_desc<8x16xi16>
-  %1 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<32x16xi8> -> !xegpu.tensor_desc<32x16xi8>
+  %1 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<16x16xi16> -> !xegpu.tensor_desc<16x16xi16>
   %2 = xegpu.load_nd %0  : !xegpu.tensor_desc<8x16xi16> -> vector<8x16xi16>
-  %3 = xegpu.load_nd %1  : !xegpu.tensor_desc<32x16xi8> -> vector<32x16xi8>
-  %4 = vector.bitcast %2 : vector<8x16xi16> to vector<8x32xi8>
-  %5 = xegpu.dpas %4, %3 : vector<8x32xi8>, vector<32x16xi8> -> vector<8x16xi32>
-  %6 = xegpu.create_nd_tdesc %arg2[%c0, %c0] : memref<8x16xi32> -> !xegpu.tensor_desc<8x16xi32>
-  xegpu.store_nd %5, %6  : vector<8x16xi32>, !xegpu.tensor_desc<8x16xi32>
-  return
-}
-
-// -----
-// CHECK-LABEL: func.func @vector_bitcast_i8_to_f16(
-// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<8x32xi8>, %[[ARG1:[0-9a-zA-Z]+]]: memref<16x32xi8>, %[[ARG2:[0-9a-zA-Z]+]]: memref<8x16xf32>) {
-// CHECK: %{{.*}} = vector.bitcast %{{.*}} {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>, layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<8x32xi8> to vector<8x16xf16>
-// CHECK: %{{.*}} = vector.bitcast %{{.*}} {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [4, 1]>, layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} : vector<16x32xi8> to vector<16x16xf16>
-func.func @vector_bitcast_i8_to_f16(%arg0: memref<8x32xi8>, %arg1: memref<16x32xi8>, %arg2: memref<8x16xf32>) {
-  %c0 = arith.constant 0 : index
-  %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x32xi8> -> !xegpu.tensor_desc<8x32xi8>
-  %1 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<16x32xi8> -> !xegpu.tensor_desc<16x32xi8>
-  %2 = xegpu.load_nd %0  : !xegpu.tensor_desc<8x32xi8> -> vector<8x32xi8>
-  %3 = xegpu.load_nd %1  : !xegpu.tensor_desc<16x32xi8> -> vector<16x32xi8>
-  %4 = vector.bitcast %2 : vector<8x32xi8> to vector<8x16xf16>
-  %5 = vector.bitcast %3 : vector<16x32xi8> to vector<16x16xf16>
+  %3 = xegpu.load_nd %1  : !xegpu.tensor_desc<16x16xi16> -> vector<16x16xi16>
+  %4 = vector.bitcast %2 : vector<8x16xi16> to vector<8x16xf16>
+  %5 = vector.bitcast %3 : vector<16x16xi16> to vector<16x16xf16>
   %6 = xegpu.dpas %4, %5 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
   %7 = xegpu.create_nd_tdesc %arg2[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
   xegpu.store_nd %6, %7  : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>

>From 74bf971a69bff63e47cf555e685418762f069dc4 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Wed, 11 Jun 2025 20:24:24 +0000
Subject: [PATCH 26/32] address comments

---
 .../XeGPU/Transforms/XeGPULayoutPropagate.cpp     | 15 ++++++++-------
 1 file changed, 8 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutPropagate.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutPropagate.cpp
index c8462140e8788..ede190ca4ad44 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutPropagate.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutPropagate.cpp
@@ -398,8 +398,9 @@ void LayoutInfoPropagation::visitVectorMultiReductionOp(
   if (!resultLayout.isAssigned())
     return;
   // We only consider 2D -> 1D reductions at this point.
-  if (resultLayout.getLayout().size() != 1) {
-    reduction.emitWarning("Expected 1D layout for reduction result. ");
+  VectorType resultTy = llvm::dyn_cast<VectorType>(reduction.getDestType());
+  if (!resultTy || resultTy.getRank() != 1) {
+    reduction.emitWarning("Expecting output type to be 1D vector.");
     return;
   }
   // Given that the result is 1D, the layout of the operand should be 2D with
@@ -679,7 +680,7 @@ RunLayoutInfoPropagation::printAnalysisResult(llvm::raw_ostream &os) {
   }
 }
 
-using GetLayoutCallbackFnTy = function_ref<xegpu::LayoutAttr(Value)>;
+using GetLayoutFnTy = function_ref<xegpu::LayoutAttr(Value)>;
 /// Helper to update the users of a value with a given layout.
 static void updateUsers(Value v, xegpu::LayoutAttr layout) {
   // Update all users of the value with the layout.
@@ -697,7 +698,7 @@ static void updateUsers(Value v, xegpu::LayoutAttr layout) {
 /// attribute. The users of the result are also updated with the layout
 /// attribute.
 static void updateOp(mlir::OpBuilder &builder, mlir::Operation *op,
-                     GetLayoutCallbackFnTy getLayoutOfValue) {
+                     GetLayoutFnTy getLayoutOfValue) {
 
   // Iterate over all the results.
   for (OpResult result : op->getResults()) {
@@ -734,7 +735,7 @@ static void updateOp(mlir::OpBuilder &builder, mlir::Operation *op,
 static void updateBranchTerminatorOpInterface(
     mlir::OpBuilder &builder,
     mlir::RegionBranchTerminatorOpInterface terminator,
-    GetLayoutCallbackFnTy getLayoutOfValue) {
+    GetLayoutFnTy getLayoutOfValue) {
   if (!mlir::isa<mlir::RegionBranchOpInterface>(terminator->getParentOp()))
     return;
 
@@ -786,7 +787,7 @@ static void updateBranchTerminatorOpInterface(
 /// regions with the assigned layouts.
 static void updateBranchOpInterface(mlir::OpBuilder &builder,
                                     mlir::RegionBranchOpInterface branch,
-                                    GetLayoutCallbackFnTy getLayoutOfValue) {
+                                    GetLayoutFnTy getLayoutOfValue) {
   mlir::Operation *op = branch.getOperation();
   llvm::SmallVector<mlir::RegionSuccessor> successors;
   llvm::SmallVector<mlir::Attribute> operands(op->getNumOperands(), nullptr);
@@ -865,7 +866,7 @@ static void updateBranchOpInterface(mlir::OpBuilder &builder,
 /// Update the function arguments and results with the layouts.
 static void updateFunctionOpInterface(mlir::OpBuilder &builder,
                                       mlir::FunctionOpInterface funcOp,
-                                      GetLayoutCallbackFnTy getLayoutOfValue) {
+                                      GetLayoutFnTy getLayoutOfValue) {
   SmallVector<Type> newArgTypes;
   // Update the function arguments.
   for (BlockArgument arg : funcOp.getArguments()) {

>From d6969bc8a52bcd906f471e5e6f792bfe7db792be Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Wed, 11 Jun 2025 20:53:13 +0000
Subject: [PATCH 27/32] address comments

---
 mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutPropagate.cpp | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutPropagate.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutPropagate.cpp
index ede190ca4ad44..64e2271d9423b 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutPropagate.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutPropagate.cpp
@@ -370,9 +370,6 @@ LogicalResult LayoutInfoPropagation::visitOperation(
           }
         }
       });
-  // Add a dependency from each result to program point after the operation.
-  for (const LayoutInfoLattice *r : results)
-    addDependency(const_cast<LayoutInfoLattice *>(r), getProgramPointAfter(op));
 
   return success();
 }

>From d5e4c6c55b94ccedf46c1447dc75499025a6e38e Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Thu, 12 Jun 2025 00:27:59 +0000
Subject: [PATCH 28/32] address comments

---
 .../XeGPU/Transforms/XeGPULayoutPropagate.cpp | 35 ++++-------
 mlir/test/Dialect/XeGPU/layout-propagate.mlir | 60 +++++++++----------
 2 files changed, 41 insertions(+), 54 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutPropagate.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutPropagate.cpp
index 64e2271d9423b..8c5a0163d1a43 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutPropagate.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutPropagate.cpp
@@ -62,18 +62,12 @@ struct Layout {
   Layout(std::initializer_list<int64_t> list) : layout(list) {}
   void print(llvm::raw_ostream &os) const;
   size_t size() const { return layout.size(); }
-  int64_t operator[](size_t idx) const;
 };
 
 void Layout::print(llvm::raw_ostream &os) const {
   os << llvm::interleaved_array(layout);
 }
 
-int64_t Layout::operator[](size_t idx) const {
-  assert(idx < layout.size() && "Index out of bounds.");
-  return layout[idx];
-}
-
 /// LaneLayout represents the logical layout of lanes within a subgroup when it
 /// accesses some value. LaneData represents the logical layout of data owned by
 /// each work item.
@@ -679,15 +673,15 @@ RunLayoutInfoPropagation::printAnalysisResult(llvm::raw_ostream &os) {
 
 using GetLayoutFnTy = function_ref<xegpu::LayoutAttr(Value)>;
 /// Helper to update the users of a value with a given layout.
-static void updateUsers(Value v, xegpu::LayoutAttr layout) {
-  // Update all users of the value with the layout.
-  for (OpOperand &user : v.getUses()) {
-    Operation *owner = user.getOwner();
-    // Add temporary layout attribute at the user op.
-    std::string attrName = xegpu::getLayoutName(user);
-    owner->setAttr(attrName, layout);
-  }
-}
+// static void updateUsers(Value v, xegpu::LayoutAttr layout) {
+//   // Update all users of the value with the layout.
+//   for (OpOperand &user : v.getUses()) {
+//     Operation *owner = user.getOwner();
+//     // Add temporary layout attribute at the user op.
+//     std::string attrName = xegpu::getLayoutName(user);
+//     owner->setAttr(attrName, layout);
+//   }
+// }
 
 /// Update an operation with the layout of its results. If the result type is a
 /// vector type, a temporary layout attribute is added to the operation. If the
@@ -721,9 +715,7 @@ static void updateOp(mlir::OpBuilder &builder, mlir::Operation *op,
     // If the result is a vector type, add a temporary layout attribute to the
     // op.
     std::string resultLayoutName = xegpu::getLayoutName(result);
-    op->setAttr(resultLayoutName, layout);
-    // Update all users of the result with the layout.
-    updateUsers(result, layout);
+    xegpu::setLayoutAttr(result, layout);
   }
 }
 
@@ -854,9 +846,7 @@ static void updateBranchOpInterface(mlir::OpBuilder &builder,
     // If the result is a vector type, add a temporary layout attribute to
     // the op.
     std::string resultLayoutName = xegpu::getLayoutName(r);
-    op->setAttr(resultLayoutName, layout);
-    // Update all users of the result with the layout.
-    updateUsers(r, layout);
+    xegpu::setLayoutAttr(r, layout);
   }
 }
 
@@ -885,9 +875,6 @@ static void updateFunctionOpInterface(mlir::OpBuilder &builder,
       newArgTypes.back() = newTdescTy;
       continue;
     }
-    // If the argument is a vector type, update all the users of the argument
-    // with the layout.
-    updateUsers(arg, layout);
   }
   // Update the function type with the new argument types.
   // NOTE: We assume that function results are not expected to have layouts.
diff --git a/mlir/test/Dialect/XeGPU/layout-propagate.mlir b/mlir/test/Dialect/XeGPU/layout-propagate.mlir
index b8f5546dd8b6b..e0534fe29d377 100644
--- a/mlir/test/Dialect/XeGPU/layout-propagate.mlir
+++ b/mlir/test/Dialect/XeGPU/layout-propagate.mlir
@@ -7,9 +7,9 @@
 // CHECK: %[[T1:.*]] = xegpu.create_nd_tdesc %[[ARG1]][{{.*}}] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>>
 // CHECK: %[[T2:.*]] = xegpu.load_nd %[[T0]]  {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : !xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<8x16xf16>
 // CHECK: %[[T3:.*]] = xegpu.load_nd %[[T1]]  {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>> -> vector<16x16xf16>
-// CHECK: %[[T4:.*]] = xegpu.dpas %[[T2]], %[[T3]], %[[CST]] {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, layout_operand_1 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>, layout_operand_2 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+// CHECK: %[[T4:.*]] = xegpu.dpas %[[T2]], %[[T3]], %[[CST]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
 // CHECK: %[[T5:.*]] = xegpu.create_nd_tdesc %[[ARG2]][{{.*}}] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
-// CHECK: xegpu.store_nd %[[T4]], %[[T5]]  {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+// CHECK: xegpu.store_nd %[[T4]], %[[T5]] : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
 func.func @dpas_f16(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<8x16xf32>) {
   %c0 = arith.constant 0 : index
   %cst = arith.constant dense<0.000000e+00> : vector<8x16xf32>
@@ -26,7 +26,7 @@ func.func @dpas_f16(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: me
 // -----
 // CHECK-LABEL: func.func @dpas_i8(
 // CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: vector<8x32xi8>, %[[ARG1:[0-9a-zA-Z]+]]: vector<32x16xi8>, %[[ARG2:[0-9a-zA-Z]+]]: memref<8x16xi32>) {
-// CHECK: %[[T0:.*]] = xegpu.dpas %[[ARG0]], %[[ARG1]] {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>, layout_operand_1 = #xegpu.layout<lane_layout = [1, 16], lane_data = [4, 1]>, layout_result_0 = #xegpu.layout<lane_layout = [1, 16],
+// CHECK: %[[T0:.*]] = xegpu.dpas %[[ARG0]], %[[ARG1]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16],
 func.func @dpas_i8(%arg0: vector<8x32xi8>, %arg1: vector<32x16xi8>, %arg2: memref<8x16xi32>) {
   %c0 = arith.constant 0 : index
   %0 = xegpu.dpas %arg0, %arg1 : vector<8x32xi8>, vector<32x16xi8> -> vector<8x16xi32>
@@ -55,7 +55,7 @@ func.func @load_with_transpose_effect(%arg0: memref<8x16xf16>, %arg1: memref<16x
 // -----
 // CHECK-LABEL: func.func @vector_transpose(
 // CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<8x16xf16>, %[[ARG1:[0-9a-zA-Z]+]]: memref<16x16xf16>, %[[ARG2:[0-9a-zA-Z]+]]: memref<8x16xf32>) {
-// CHECK: %{{.*}} = vector.transpose %{{.*}}, [1, 0] {layout_operand_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>, layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} : vector<16x16xf16> to vector<16x16xf16>
+// CHECK: %{{.*}} = vector.transpose %{{.*}}, [1, 0] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} : vector<16x16xf16> to vector<16x16xf16>
 func.func @vector_transpose(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<8x16xf32>) {
   %c0 = arith.constant 0 : index
   %cst = arith.constant dense<0.000000e+00> : vector<8x16xf32>
@@ -73,8 +73,8 @@ func.func @vector_transpose(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %
 // -----
 // CHECK-LABEL: func.func @extf_truncf(
 // CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>, %[[ARG1:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>>) -> vector<8x16xf32> {
-// CHECK: %[[T2:.*]] = arith.extf %{{.*}} {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>, layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} : vector<16x16xf16> to vector<16x16xf32>
-// CHECK-NEXT: %{{.*}} = arith.truncf %[[T2]] {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>, layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} : vector<16x16xf32> to vector<16x16xf16>
+// CHECK: %[[T2:.*]] = arith.extf %{{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} : vector<16x16xf16> to vector<16x16xf32>
+// CHECK-NEXT: %{{.*}} = arith.truncf %[[T2]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} : vector<16x16xf32> to vector<16x16xf16>
 func.func @extf_truncf(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>) -> vector<8x16xf32> {
   %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
   %1 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
@@ -89,8 +89,8 @@ func.func @extf_truncf(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor
 // CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<8x16xf16>, %[[ARG1:[0-9a-zA-Z]+]]: memref<256xf16>, %[[ARG2:[0-9a-zA-Z]+]]: memref<8x16xf32>) {
 // CHECK: %[[CST:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
 // CHECK-NEXT: %[[CST0:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1>
-// CHECK-NEXT: %[[T2:.*]] = xegpu.create_tdesc %[[ARG1]], %[[CST]] {layout_operand_1 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} : memref<256xf16>, vector<16xindex> -> !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>>
-// CHECK-NEXT: %{{.*}} = xegpu.load %[[T2]], %[[CST0]] <{transpose}> {layout_operand_1 = #xegpu.layout<lane_layout = [16], lane_data = [1]>, layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} : !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>>, vector<16xi1> -> vector<16x16xf16>
+// CHECK-NEXT: %[[T2:.*]] = xegpu.create_tdesc %[[ARG1]], %[[CST]] : memref<256xf16>, vector<16xindex> -> !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>>
+// CHECK-NEXT: %{{.*}} = xegpu.load %[[T2]], %[[CST0]] <{transpose}> {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} : !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>>, vector<16xi1> -> vector<16x16xf16>
 func.func @load_gather_with_transpose_effect(%arg0: memref<8x16xf16>, %arg1: memref<256xf16>, %arg2: memref<8x16xf32>) {
   %c0 = arith.constant 0 : index
   %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
@@ -110,8 +110,8 @@ func.func @load_gather_with_transpose_effect(%arg0: memref<8x16xf16>, %arg1: mem
 // CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<256xf32>, %[[ARG1:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>) {
 // CHECK: %[[CST:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
 // CHECK-NEXT: %[[CST0:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1>
-// CHECK-NEXT: %[[T0:.*]] = xegpu.create_tdesc %[[ARG0]], %[[CST]] {layout_operand_1 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} : memref<256xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
-// CHECK-NEXT: %{{.*}} = xegpu.load %[[T0]], %[[CST0]]  {layout_operand_1 = #xegpu.layout<lane_layout = [16], lane_data = [1]>, layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<lane_layout = [16], lane_data = [1]>>, vector<16xi1> -> vector<16xf32>
+// CHECK-NEXT: %[[T0:.*]] = xegpu.create_tdesc %[[ARG0]], %[[CST]] : memref<256xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
+// CHECK-NEXT: %{{.*}} = xegpu.load %[[T0]], %[[CST0]]  {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<lane_layout = [16], lane_data = [1]>>, vector<16xi1> -> vector<16xf32>
 func.func @load_gather_1d(%arg0: memref<256xf32>, %arg1: !xegpu.tensor_desc<16xf32>) {
   %cst = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
   %cst_0 = arith.constant dense<true> : vector<16xi1>
@@ -124,8 +124,8 @@ func.func @load_gather_1d(%arg0: memref<256xf32>, %arg1: !xegpu.tensor_desc<16xf
 // -----
 // CHECK-LABEL: func.func @store_scatter_with_transpose_effect(
 // CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<128xf32>) {
-// CHECK: %[[T0:.*]] = xegpu.create_tdesc %[[ARG0]], %{{.*}} {layout_operand_1 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} : memref<128xf32>, vector<16xindex> -> !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size = 8 : i64>, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>>
-// CHECK-NEXT: xegpu.store %{{.*}}, %[[T0]], %{{.*}} <{transpose}> {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, layout_operand_2 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} : vector<8x16xf32>, !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size = 8 : i64>, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>>, vector<16xi1>
+// CHECK: %[[T0:.*]] = xegpu.create_tdesc %[[ARG0]], %{{.*}} : memref<128xf32>, vector<16xindex> -> !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size = 8 : i64>, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>>
+// CHECK-NEXT: xegpu.store %{{.*}}, %[[T0]], %{{.*}} <{transpose}> : vector<8x16xf32>, !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size = 8 : i64>, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>>, vector<16xi1>
 func.func @store_scatter_with_transpose_effect(%arg0: memref<128xf32>) {
   %cst = arith.constant dense<1.000000e+00> : vector<8x16xf32>
   %cst_0 = arith.constant dense<true> : vector<16xi1>
@@ -138,7 +138,7 @@ func.func @store_scatter_with_transpose_effect(%arg0: memref<128xf32>) {
 // -----
 // CHECK-LABEL: func.func @store_scatter_1d(
 // CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: vector<16xf32>, %[[ARG1:[0-9a-zA-Z]+]]: memref<256xf32>) {
-// CHECK: xegpu.store %[[ARG0]], %{{.*}}, %{{.*}}  {layout_operand_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>, layout_operand_2 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<lane_layout = [16], lane_data = [1]>>, vector<16xi1>
+// CHECK: xegpu.store %[[ARG0]], %{{.*}}, %{{.*}}  : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<lane_layout = [16], lane_data = [1]>>, vector<16xi1>
 func.func @store_scatter_1d(%arg0: vector<16xf32>, %arg1: memref<256xf32>) {
   %cst = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
   %cst_0 = arith.constant dense<true> : vector<16xi1>
@@ -150,8 +150,8 @@ func.func @store_scatter_1d(%arg0: vector<16xf32>, %arg1: memref<256xf32>) {
 // -----
 // CHECK-LABEL: func.func @vector_bitcast_i16_to_f16(
 // CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<8x16xi16>, %[[ARG1:[0-9a-zA-Z]+]]: memref<16x16xi16>, %[[ARG2:[0-9a-zA-Z]+]]: memref<8x16xf32>) {
-// CHECK: %{{.*}} = vector.bitcast %{{.*}} {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<8x16xi16> to vector<8x16xf16>
-// CHECK: %{{.*}} = vector.bitcast %{{.*}} {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>, layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} : vector<16x16xi16> to vector<16x16xf16>
+// CHECK: %{{.*}} = vector.bitcast %{{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<8x16xi16> to vector<8x16xf16>
+// CHECK: %{{.*}} = vector.bitcast %{{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} : vector<16x16xi16> to vector<16x16xf16>
 func.func @vector_bitcast_i16_to_f16(%arg0: memref<8x16xi16>, %arg1: memref<16x16xi16>, %arg2: memref<8x16xf32>) {
   %c0 = arith.constant 0 : index
   %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xi16> -> !xegpu.tensor_desc<8x16xi16>
@@ -171,7 +171,7 @@ func.func @vector_bitcast_i16_to_f16(%arg0: memref<8x16xi16>, %arg1: memref<16x1
 // CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>, %[[ARG1:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>>, %[[ARG2:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>) {
 // CHECK: %[[T1:.*]] = xegpu.load_nd %[[ARG1]]  {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>> -> vector<16x16xf16>
 // CHECK-NEXT: %[[T2:.*]] = xegpu.load_nd %[[ARG1]]  {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>> -> vector<16x16xf16>
-// CHECK-NEXT: %{{.*}} = arith.addf %[[T1]], %[[T2]] {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>, layout_operand_1 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>, layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} : vector<16x16xf16>
+// CHECK-NEXT: %{{.*}} = arith.addf %[[T1]], %[[T2]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} : vector<16x16xf16>
 func.func @binary_op_one_use(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: !xegpu.tensor_desc<8x16xf32>) {
   %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
   %1 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
@@ -185,10 +185,10 @@ func.func @binary_op_one_use(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.
 // -----
 // CHECK-LABEL: func.func @binary_op_multiple_uses(
 // CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>, %[[ARG1:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>, %[[ARG2:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>, %[[ARG3:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>) {
-// CHECK: %[[T2:.*]] = arith.addf %{{.*}}, %{{.*}} {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, layout_operand_1 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16x16xf16>
-// CHECK: %[[T3:.*]] = xegpu.dpas %{{.*}}, %[[T2]] {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, layout_operand_1 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
-// CHECK-NEXT: xegpu.store_nd %[[T3]], %[[ARG2]]  {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
-// CHECK-NEXT: xegpu.store_nd %[[T2]], %[[ARG3]]  {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+// CHECK: %[[T2:.*]] = arith.addf %{{.*}}, %{{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16x16xf16>
+// CHECK: %[[T3:.*]] = xegpu.dpas %{{.*}}, %[[T2]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+// CHECK-NEXT: xegpu.store_nd %[[T3]], %[[ARG2]]  : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+// CHECK-NEXT: xegpu.store_nd %[[T2]], %[[ARG3]]  : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
 func.func @binary_op_multiple_uses(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: !xegpu.tensor_desc<8x16xf32>, %arg3: !xegpu.tensor_desc<16x16xf16>) {
   %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
   %1 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
@@ -209,13 +209,13 @@ func.func @binary_op_multiple_uses(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !
 // CHECK-NEXT: %[[T2:.*]]:3 = scf.for %{{.*}} iter_args(%[[ARG4:.*]] = %[[T0]], %[[ARG5:.*]] = %[[T1]], %[[ARG6:.*]] = %[[CST]]) -> (!xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>, !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>>, vector<8x16xf32>) {
 // CHECK-NEXT:   %[[T4:.*]] = xegpu.load_nd %[[ARG4]]  {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : !xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<8x16xf16>
 // CHECK-NEXT:   %[[T5:.*]] = xegpu.load_nd %[[ARG5]]  {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>> -> vector<16x16xf16>
-// CHECK-NEXT:   %[[T6:.*]] = xegpu.dpas %[[T4]], %[[T5]], %[[ARG6]] {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, layout_operand_1 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>, layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+// CHECK-NEXT:   %[[T6:.*]] = xegpu.dpas %[[T4]], %[[T5]], %[[ARG6]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
 // CHECK-NEXT:   %[[T7:.*]] = xegpu.update_nd_offset %[[ARG4]], [{{.*}}] : !xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
 // CHECK-NEXT:   %[[T8:.*]] = xegpu.update_nd_offset %[[ARG5]], [{{.*}}] : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>>
-// CHECK-NEXT:   scf.yield {layout_operand_2 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} %[[T7]], %[[T8]], %[[T6]] : !xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>, !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>>, vector<8x16xf32>
-// CHECK-NEXT: } {layout_operand_5 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, layout_result_2 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+// CHECK-NEXT:   scf.yield %[[T7]], %[[T8]], %[[T6]] : !xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>, !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>>, vector<8x16xf32>
+// CHECK-NEXT: } {layout_result_2 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
 // CHECK-NEXT: %[[T3:.*]] = xegpu.create_nd_tdesc %[[ARG2]][{{.*}}] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
-// CHECK-NEXT: xegpu.store_nd %[[T2]]#2, %[[T3]]  {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+// CHECK-NEXT: xegpu.store_nd %[[T2]]#2, %[[T3]] : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
 func.func @for_op(%arg0: memref<8x128xf16>, %arg1: memref<128x16xf16>, %arg2: memref<8x16xf32>) {
   %c0 = arith.constant 0 : index
   %c128 = arith.constant 128 : index
@@ -241,10 +241,10 @@ func.func @for_op(%arg0: memref<8x128xf16>, %arg1: memref<128x16xf16>, %arg2: me
 // CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>, %[[ARG1:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>>, %[[ARG2:[0-9a-zA-Z]+]]: i1, %[[ARG3:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>) {
 // CHECK:  %{{.*}} = scf.if %[[ARG2]] -> (vector<16x16xf16>) {
 // CHECK-NEXT:    %[[T3:.*]] = xegpu.load_nd %[[ARG1]]  {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>> -> vector<16x16xf16>
-// CHECK-NEXT:    scf.yield {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} %[[T3]] : vector<16x16xf16>
+// CHECK-NEXT:    scf.yield %[[T3]] : vector<16x16xf16>
 // CHECK-NEXT:  } else {
 // CHECK-NEXT:    %[[T4:.*]] = xegpu.load_nd %[[ARG1]]  {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>> -> vector<16x16xf16>
-// CHECK-NEXT:    scf.yield {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} %[[T4]] : vector<16x16xf16>
+// CHECK-NEXT:    scf.yield %[[T4]] : vector<16x16xf16>
 // CHECK-NEXT:  } {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>}
 func.func @if_single_use(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: i1, %arg3: !xegpu.tensor_desc<8x16xf32>) {
   %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
@@ -265,10 +265,10 @@ func.func @if_single_use(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tens
 // CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>, %[[ARG1:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>, %[[ARG2:[0-9a-zA-Z]+]]: i1, %[[ARG3:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>, %[[ARG4:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>) {
 // CHECK: %[[T1:.*]] = scf.if %[[ARG2]] -> (vector<16x16xf16>) {
 // CHECK-NEXT:       %[[T3:.*]] = xegpu.load_nd %[[ARG1]]  {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<16x16xf16>
-// CHECK-NEXT:       scf.yield {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} %[[T3]] : vector<16x16xf16>
+// CHECK-NEXT:       scf.yield %[[T3]] : vector<16x16xf16>
 // CHECK-NEXT:     } else {
 // CHECK-NEXT:       %[[T4:.*]] = xegpu.load_nd %[[ARG1]]  {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<16x16xf16>
-// CHECK-NEXT:       scf.yield {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} %[[T4]] : vector<16x16xf16>
+// CHECK-NEXT:       scf.yield %[[T4]] : vector<16x16xf16>
 // CHECK-NEXT:     } {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
 func.func @if_multiple_uses(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: i1, %arg3: !xegpu.tensor_desc<8x16xf32>, %arg4: !xegpu.tensor_desc<16x16xf16>) {
   %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
@@ -288,7 +288,7 @@ func.func @if_multiple_uses(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.t
 // -----
 // CHECK-LABEL: func.func @vector_outer_reduction(
 // CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: vector<16x16xf32>, %[[ARG1:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>) {
-// CHECK: %{{.*}} = vector.multi_reduction <add>, %[[ARG0]], %{{.*}} {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, layout_operand_1 = #xegpu.layout<lane_layout = [16], lane_data = [1]>, layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} [0] : vector<16x16xf32> to vector<16xf32>
+// CHECK: %{{.*}} = vector.multi_reduction <add>, %[[ARG0]], %{{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} [0] : vector<16x16xf32> to vector<16xf32>
 func.func @vector_outer_reduction(%arg0: vector<16x16xf32>, %arg1: !xegpu.tensor_desc<16xf32>) {
   %cst = arith.constant dense<0.000000e+00> : vector<16xf32>
   %0 = vector.multi_reduction <add>, %arg0, %cst [0] : vector<16x16xf32> to vector<16xf32>
@@ -299,7 +299,7 @@ func.func @vector_outer_reduction(%arg0: vector<16x16xf32>, %arg1: !xegpu.tensor
 // -----
 // CHECK-LABEL: func.func @vector_inner_reduction(
 // CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: vector<16x16xf32>, %[[ARG1:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>) {
-// CHECK: %{{.*}} = vector.multi_reduction <add>, %[[ARG0]], %{{.*}} {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, layout_operand_1 = #xegpu.layout<lane_layout = [16], lane_data = [1]>, layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} [1] : vector<16x16xf32> to vector<16xf32>
+// CHECK: %{{.*}} = vector.multi_reduction <add>, %[[ARG0]], %{{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} [1] : vector<16x16xf32> to vector<16xf32>
 func.func @vector_inner_reduction(%arg0: vector<16x16xf32>, %arg1: !xegpu.tensor_desc<16xf32>) {
   %cst = arith.constant dense<0.000000e+00> : vector<16xf32>
   %0 = vector.multi_reduction <add>, %arg0, %cst [1] : vector<16x16xf32> to vector<16xf32>

>From 94da37e54ea094474301250d628d25104e4ff096 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Thu, 12 Jun 2025 18:28:02 +0000
Subject: [PATCH 29/32] address comments

---
 .../XeGPU/Transforms/XeGPULayoutPropagate.cpp | 11 ------
 .../Transforms/XeGPUSubgroupDistribute.cpp    | 21 +++++++++++
 .../Dialect/XeGPU/subgroup-distribute.mlir    | 36 +++++++++----------
 3 files changed, 39 insertions(+), 29 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutPropagate.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutPropagate.cpp
index 8c5a0163d1a43..a26b2e83580da 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutPropagate.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutPropagate.cpp
@@ -672,17 +672,6 @@ RunLayoutInfoPropagation::printAnalysisResult(llvm::raw_ostream &os) {
 }
 
 using GetLayoutFnTy = function_ref<xegpu::LayoutAttr(Value)>;
-/// Helper to update the users of a value with a given layout.
-// static void updateUsers(Value v, xegpu::LayoutAttr layout) {
-//   // Update all users of the value with the layout.
-//   for (OpOperand &user : v.getUses()) {
-//     Operation *owner = user.getOwner();
-//     // Add temporary layout attribute at the user op.
-//     std::string attrName = xegpu::getLayoutName(user);
-//     owner->setAttr(attrName, layout);
-//   }
-// }
-
 /// Update an operation with the layout of its results. If the result type is a
 /// vector type, a temporary layout attribute is added to the operation. If the
 /// result type is a tensor descriptor type, the type is updated with the layout
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index eb8192417f843..747e01f329c03 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -812,6 +812,27 @@ void xegpu::populateXeGPUSubgroupDistributePatterns(
 }
 
 void XeGPUSubgroupDistributePass::runOnOperation() {
+  // Attach layout to operands.
+  Operation *op = getOperation();
+  op->walk([&](Operation *op) {
+    for (OpOperand &operand : op->getOpOperands()) {
+      // Layouts are needed for vector type only.
+      if (!isa<VectorType>(operand.get().getType()))
+        continue;
+      // If the operand already has a layout, skip it.
+      if (xegpu::getLayoutAttr(operand))
+        continue;
+
+      xegpu::LayoutAttr layout = xegpu::getLayoutAttr(operand);
+      if (!layout) {
+        op->emitError("Could not find layout attribute for operand ")
+            << operand.getOperandNumber() << " of operation " << op->getName();
+        signalPassFailure();
+        return;
+      }
+      xegpu::setLayoutAttr(operand, layout);
+    }
+  });
   // Move all operations of a GPU function inside gpu.warp_execute_on_lane_0
   // operation.
   {
diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
index 7362c175a70a4..fef03560dddd7 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
@@ -11,7 +11,7 @@ gpu.module @test {
     %c0 = arith.constant 0 : index
     %cst = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<1.000000e+00> : vector<16xf32>
     %0 = xegpu.create_nd_tdesc %arg0[%c0] : memref<16xf32> -> !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
-    xegpu.store_nd %cst, %0  {layout_operand_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
+    xegpu.store_nd %cst, %0  : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
     gpu.return
   }
 }
@@ -27,7 +27,7 @@ gpu.module @test {
     %c0 = arith.constant 0 : index
     %cst = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} dense<1.000000e+00> : vector<16x16xf16>
     %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
-    xegpu.store_nd %cst, %0  {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    xegpu.store_nd %cst, %0 : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
     gpu.return
   }
 }
@@ -47,7 +47,7 @@ gpu.module @test {
     %0 = xegpu.create_nd_tdesc %arg0[%c0] : memref<16xf32> -> !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
     %1 = xegpu.load_nd %0  {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} : !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>> -> vector<16xf32>
     %2 = xegpu.create_nd_tdesc %arg1[%c0] : memref<16xf32> -> !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
-    xegpu.store_nd %1, %2  {layout_operand_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
+    xegpu.store_nd %1, %2 : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
     gpu.return
   }
 }
@@ -65,7 +65,7 @@ gpu.module @test {
     %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
     %1 = xegpu.load_nd %0  {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<16x16xf16>
     %2 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
-    xegpu.store_nd %1, %2  {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    xegpu.store_nd %1, %2 : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
     gpu.return
   }
 }
@@ -85,9 +85,9 @@ gpu.module @test {
     %c0 = arith.constant 0 : index
     %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2 : i64>, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
     %1 = xegpu.load_nd %0  {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2 : i64>, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<2x16x16xf16>
-    %2 = vector.extract %1[%c0] {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16x16xf16> from vector<2x16x16xf16>
+    %2 = vector.extract %1[%c0] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16x16xf16> from vector<2x16x16xf16>
     %3 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
-    xegpu.store_nd %2, %3  {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    xegpu.store_nd %2, %3 : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
     gpu.return
   }
 }
@@ -109,9 +109,9 @@ gpu.module @test {
     %1 = xegpu.load_nd %0  {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : !xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<8x16xf16>
     %2 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>>
     %3 = xegpu.load_nd %2  {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>> -> vector<16x16xf16>
-    %4 = xegpu.dpas %1, %3 {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, layout_operand_1 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>, layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+    %4 = xegpu.dpas %1, %3 {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
     %5 = xegpu.create_nd_tdesc %arg2[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
-    xegpu.store_nd %4, %5  {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    xegpu.store_nd %4, %5  : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
     gpu.return
   }
 }
@@ -137,10 +137,10 @@ gpu.module @test {
     %1 = xegpu.load_nd %0  {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : !xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<8x16xf16>
     %2 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>>
     %3 = xegpu.load_nd %2  {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>> -> vector<16x16xf16>
-    %4 = xegpu.dpas %1, %3 {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, layout_operand_1 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>, layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
-    %5 = math.exp %4 {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<8x16xf32>
+    %4 = xegpu.dpas %1, %3 {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+    %5 = math.exp %4 {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<8x16xf32>
     %6 = xegpu.create_nd_tdesc %arg2[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
-    xegpu.store_nd %5, %6  {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    xegpu.store_nd %5, %6 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
     gpu.return
   }
 }
@@ -160,7 +160,7 @@ gpu.module @test {
     %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0], [%arg2, %arg3], [%arg4, %arg5] : ui64 -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
     %1 = xegpu.load_nd %0  {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<16x16xf16>
     %2 = xegpu.create_nd_tdesc %arg1[%c0, %c0], [%arg2, %arg3], [%arg4, %arg5] : ui64 -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
-    xegpu.store_nd %1, %2  {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    xegpu.store_nd %1, %2 : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
     gpu.return
   }
 }
@@ -205,10 +205,10 @@ gpu.func @gemm_loop(%arg0: memref<1024x1024xbf16>, %arg1: memref<1024x1024xbf16>
     %6 = xegpu.create_nd_tdesc %arg1[%arg3, %1] : memref<1024x1024xbf16> -> !xegpu.tensor_desc<16x16xbf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>>
     %7 = xegpu.load_nd %5  {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : !xegpu.tensor_desc<8x16xbf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<8x16xbf16>
     %8 = xegpu.load_nd %6  {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} : !xegpu.tensor_desc<16x16xbf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>> -> vector<16x16xbf16>
-    %9 = xegpu.dpas %7, %8, %arg4 {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, layout_operand_1 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>, layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<8x16xbf16>, vector<16x16xbf16>, vector<8x16xf32> -> vector<8x16xf32>
-    scf.yield {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} %9 : vector<8x16xf32>
-  } {layout_operand_3 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
-  xegpu.store_nd %4, %2  {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    %9 = xegpu.dpas %7, %8, %arg4 {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<8x16xbf16>, vector<16x16xbf16>, vector<8x16xf32> -> vector<8x16xf32>
+    scf.yield %9 : vector<8x16xf32>
+  } {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+  xegpu.store_nd %4, %2 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
   gpu.return
 }
 }
@@ -227,7 +227,7 @@ gpu.module @test {
     %cst = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<1.000000e+00> : vector<16xf32>
     %0 = xegpu.create_nd_tdesc %arg0[%c0] : memref<256xf32> -> !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
     %1 = xegpu.update_nd_offset %0, [%c32] : !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
-    xegpu.store_nd %cst, %1  {layout_operand_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
+    xegpu.store_nd %cst, %1  : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
     gpu.return
   }
 }
@@ -246,7 +246,7 @@ gpu.module @test {
     %cst = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} dense<1.000000e+00> : vector<16x16xf32>
     %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<256x256xf32> -> !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
     %1 = xegpu.update_nd_offset %0, [%c32, %c32] : !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
-    xegpu.store_nd %cst, %1  {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16x16xf32>, !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    xegpu.store_nd %cst, %1  : vector<16x16xf32>, !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
     gpu.return
   }
 }

>From 76671e2538bfacce83ad2f594ead5b19eb0de1c4 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Thu, 12 Jun 2025 19:07:54 +0000
Subject: [PATCH 30/32] address comments

---
 .../Transforms/XeGPUSubgroupDistribute.cpp     | 18 +++++++++---------
 1 file changed, 9 insertions(+), 9 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 747e01f329c03..869f99c206c96 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -812,16 +812,16 @@ void xegpu::populateXeGPUSubgroupDistributePatterns(
 }
 
 void XeGPUSubgroupDistributePass::runOnOperation() {
-  // Attach layout to operands.
+  // Step 1: Attach layout to op operands.
+  // TODO: Following assumptions are made:
+  // 1) It is assumed that there are no layout conflicts.
+  // 2) Any existing layout attributes attached to the operands are ignored.
   Operation *op = getOperation();
   op->walk([&](Operation *op) {
     for (OpOperand &operand : op->getOpOperands()) {
       // Layouts are needed for vector type only.
       if (!isa<VectorType>(operand.get().getType()))
         continue;
-      // If the operand already has a layout, skip it.
-      if (xegpu::getLayoutAttr(operand))
-        continue;
 
       xegpu::LayoutAttr layout = xegpu::getLayoutAttr(operand);
       if (!layout) {
@@ -833,8 +833,8 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
       xegpu::setLayoutAttr(operand, layout);
     }
   });
-  // Move all operations of a GPU function inside gpu.warp_execute_on_lane_0
-  // operation.
+  // Step 2: Move all operations of a GPU function inside
+  // gpu.warp_execute_on_lane_0 operation.
   {
     RewritePatternSet patterns(&getContext());
     patterns.add<MoveFuncBodyToWarpExecuteOnLane0>(&getContext());
@@ -853,7 +853,7 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
       }
     });
   }
-  // Apply subgroup to workitem distribution patterns.
+  // Step 3: Finally, Apply subgroup to workitem distribution patterns.
   RewritePatternSet patterns(&getContext());
   xegpu::populateXeGPUSubgroupDistributePatterns(patterns);
   // TODO: distributionFn and shuffleFn are not used at this point.
@@ -874,8 +874,8 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
     return;
   }
 
-  // Clean up UnrealizedConversionCastOps that were inserted due to tensor
-  // desc type mismatches created by using upstream distribution patterns
+  // Step 4: Clean up UnrealizedConversionCastOps that were inserted due to
+  // tensor desc type mismatches created by using upstream distribution patterns
   // (scf.for)
   getOperation()->walk([&](mlir::UnrealizedConversionCastOp op) {
     // We are only interested in UnrealizedConversionCastOps there were added

>From 32f8c799b523c1906a3334893d33587bbbd72866 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Thu, 12 Jun 2025 19:48:25 +0000
Subject: [PATCH 31/32] address comments

---
 .../XeGPU/Transforms/XeGPULayoutPropagate.cpp | 32 ++++++++++---------
 .../Transforms/XeGPUSubgroupDistribute.cpp    | 10 +++---
 2 files changed, 22 insertions(+), 20 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutPropagate.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutPropagate.cpp
index a26b2e83580da..0376d1c8c4ff4 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutPropagate.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutPropagate.cpp
@@ -19,6 +19,7 @@
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/Value.h"
@@ -341,9 +342,6 @@ LogicalResult LayoutInfoPropagation::visitOperation(
       .Case<xegpu::PrefetchNdOp>([&](auto prefetchNdOp) {
         visitPrefetchNdOp(prefetchNdOp, operands, results);
       })
-      // No need to propagate the layout to operands in CreateNdDescOp because
-      // they are scalars (offsets, sizes, etc.).
-      .Case<xegpu::CreateNdDescOp>([&](auto createNdDescOp) {})
       .Case<vector::TransposeOp>([&](auto transposeOp) {
         visitTransposeOp(transposeOp, operands, results);
       })
@@ -355,12 +353,18 @@ LogicalResult LayoutInfoPropagation::visitOperation(
       })
       // All other ops.
       .Default([&](Operation *op) {
-        for (const LayoutInfoLattice *r : results) {
-          if (r->getValue().isAssigned()) {
-            for (LayoutInfoLattice *operand : operands) {
-              // Propagate the layout of the result to the operand.
-              meet(operand, *r);
-            }
+        for (const LayoutInfoLattice *resultInfo : results) {
+          if (!resultInfo->getValue().isAssigned())
+            continue;
+          for (auto [operandInfo, operand] :
+               llvm::zip(operands, op->getOpOperands())) {
+            // If the operand type is not a vector or tensor descriptor, skip
+            // it.
+            if (!isa<xegpu::TensorDescType, VectorType>(
+                    operand.get().getType()))
+              continue;
+            // Propagate the result layout to the operand.
+            meet(operandInfo, *resultInfo);
           }
         }
       });
@@ -456,7 +460,8 @@ void LayoutInfoPropagation::visitLoadNdOp(
     return;
   LayoutInfo tensorDescLayout = valueLayout;
   // LoadNdOp has the transpose effect. However, at the stage of this analysis
-  // this effect is not expected and should be abstracted away. Emit a warning.
+  // this effect is not expected and should be abstracted away. Emit a
+  // warning.
   if (auto transpose = load.getTranspose()) {
     load.emitWarning("Transpose effect is not expected for LoadNdOp at "
                      "LayoutInfoPropagation stage.");
@@ -495,8 +500,8 @@ void LayoutInfoPropagation::visitVectorBitcastOp(
   int outElemTyBitWidth =
       bitcast.getResultVectorType().getElementType().getIntOrFloatBitWidth();
 
-  // NOTE: We do not expect widening or narrowing bitcasts at this stage. Emit a
-  // warning and return.
+  // NOTE: We do not expect widening or narrowing bitcasts at this stage. Emit
+  // a warning and return.
   if (inElemTyBitWidth != outElemTyBitWidth) {
     bitcast.emitWarning("Widening or narrowing bitcasts are not expected at "
                         "layout propagation stage.");
@@ -583,7 +588,6 @@ void LayoutInfoPropagation::visitStoreScatterOp(
 }
 
 namespace {
-
 //===----------------------------------------------------------------------===//
 // RunLayoutInfoPropagation
 //===----------------------------------------------------------------------===//
@@ -679,7 +683,6 @@ using GetLayoutFnTy = function_ref<xegpu::LayoutAttr(Value)>;
 /// attribute.
 static void updateOp(mlir::OpBuilder &builder, mlir::Operation *op,
                      GetLayoutFnTy getLayoutOfValue) {
-
   // Iterate over all the results.
   for (OpResult result : op->getResults()) {
     Type resultType = result.getType();
@@ -872,7 +875,6 @@ static void updateFunctionOpInterface(mlir::OpBuilder &builder,
 }
 
 namespace {
-
 struct XeGPULayoutPropagatePass final
     : public xegpu::impl::XeGPULayoutPropagateBase<XeGPULayoutPropagatePass> {
   void runOnOperation() override;
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 869f99c206c96..8b818b21ca436 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -812,7 +812,7 @@ void xegpu::populateXeGPUSubgroupDistributePatterns(
 }
 
 void XeGPUSubgroupDistributePass::runOnOperation() {
-  // Step 1: Attach layout to op operands.
+  // Step 1: Attach layouts to op operands.
   // TODO: Following assumptions are made:
   // 1) It is assumed that there are no layout conflicts.
   // 2) Any existing layout attributes attached to the operands are ignored.
@@ -853,7 +853,7 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
       }
     });
   }
-  // Step 3: Finally, Apply subgroup to workitem distribution patterns.
+  // Step 3: Apply subgroup to workitem distribution patterns.
   RewritePatternSet patterns(&getContext());
   xegpu::populateXeGPUSubgroupDistributePatterns(patterns);
   // TODO: distributionFn and shuffleFn are not used at this point.
@@ -874,9 +874,9 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
     return;
   }
 
-  // Step 4: Clean up UnrealizedConversionCastOps that were inserted due to
-  // tensor desc type mismatches created by using upstream distribution patterns
-  // (scf.for)
+  // Step 4: Finally, clean up UnrealizedConversionCastOps that were inserted
+  // due to tensor desc type mismatches created by using upstream distribution
+  // patterns (scf.for)
   getOperation()->walk([&](mlir::UnrealizedConversionCastOp op) {
     // We are only interested in UnrealizedConversionCastOps there were added
     // for resolving SIMT type mismatches.

>From 9cefe6fab894b903f22647c0e4f981bd1dcc8d24 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Fri, 13 Jun 2025 18:08:04 +0000
Subject: [PATCH 32/32] address comments

---
 .../XeGPU/Transforms/XeGPULayoutPropagate.cpp | 55 ++++++++++---------
 .../Transforms/XeGPUSubgroupDistribute.cpp    |  7 +--
 .../Dialect/XeGPU/subgroup-distribute.mlir    |  6 +-
 3 files changed, 34 insertions(+), 34 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutPropagate.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutPropagate.cpp
index 0376d1c8c4ff4..c36b2897e7903 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutPropagate.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutPropagate.cpp
@@ -444,9 +444,8 @@ void LayoutInfoPropagation::visitStoreNdOp(
     ArrayRef<const LayoutInfoLattice *> results) {
   LayoutInfo storeLayout = getDefaultSIMTLayoutInfo(store.getValueType());
   // Both operands should have the same layout
-  for (LayoutInfoLattice *operand : operands) {
+  for (LayoutInfoLattice *operand : operands)
     propagateIfChanged(operand, operand->meet(storeLayout));
-  }
 }
 
 /// Propagate the layout of the value to the tensor descriptor operand in
@@ -659,20 +658,18 @@ RunLayoutInfoPropagation::printAnalysisResult(llvm::raw_ostream &os) {
 
   SmallVector<FunctionOpInterface> funcOps;
   if (auto modOp = dyn_cast<ModuleOp>(target)) {
-    for (auto funcOp : modOp.getOps<FunctionOpInterface>()) {
+    for (auto funcOp : modOp.getOps<FunctionOpInterface>())
       funcOps.push_back(funcOp);
-    }
+
     // Collect all GpuFuncOps in the module.
     for (auto gpuModOp : modOp.getOps<gpu::GPUModuleOp>()) {
-      for (auto gpuFuncOp : gpuModOp.getOps<FunctionOpInterface>()) {
+      for (auto gpuFuncOp : gpuModOp.getOps<FunctionOpInterface>())
         funcOps.push_back(gpuFuncOp);
-      }
     }
   }
   // Print the analysis result for each function.
-  for (FunctionOpInterface funcOp : funcOps) {
+  for (FunctionOpInterface funcOp : funcOps)
     printFunctionResult(funcOp);
-  }
 }
 
 using GetLayoutFnTy = function_ref<xegpu::LayoutAttr(Value)>;
@@ -706,7 +703,6 @@ static void updateOp(mlir::OpBuilder &builder, mlir::Operation *op,
     }
     // If the result is a vector type, add a temporary layout attribute to the
     // op.
-    std::string resultLayoutName = xegpu::getLayoutName(result);
     xegpu::setLayoutAttr(result, layout);
   }
 }
@@ -717,6 +713,7 @@ static void updateBranchTerminatorOpInterface(
     mlir::OpBuilder &builder,
     mlir::RegionBranchTerminatorOpInterface terminator,
     GetLayoutFnTy getLayoutOfValue) {
+  // Only process if the terminator is inside a region branch op.
   if (!mlir::isa<mlir::RegionBranchOpInterface>(terminator->getParentOp()))
     return;
 
@@ -729,9 +726,10 @@ static void updateBranchTerminatorOpInterface(
     if (!successor.isParent())
       continue;
 
-    mlir::OperandRange operands = terminator.getSuccessorOperands(successor);
-    mlir::ValueRange inputs = successor.getSuccessorInputs();
-    for (auto [operand, input] : llvm::zip(operands, inputs)) {
+    mlir::OperandRange forwardedOperands =
+        terminator.getSuccessorOperands(successor);
+    mlir::ValueRange regionArgs = successor.getSuccessorInputs();
+    for (auto [operand, input] : llvm::zip(forwardedOperands, regionArgs)) {
       // print arg and inp
       // llvm::errs() << "arg: " << operand << ", inp: " << input << "\n";
       Type inputType = input.getType();
@@ -773,38 +771,43 @@ static void updateBranchOpInterface(mlir::OpBuilder &builder,
   llvm::SmallVector<mlir::RegionSuccessor> successors;
   llvm::SmallVector<mlir::Attribute> operands(op->getNumOperands(), nullptr);
   branch.getEntrySuccessorRegions(operands, successors);
-  DenseMap<Value, xegpu::LayoutAttr> resultToLayouts;
+  DenseMap<Value, xegpu::LayoutAttr>
+      resultToLayouts; // This map keeps track of layouts of any unused results
+                       // of the branch op.
   mlir::ValueRange results = op->getResults();
 
   for (mlir::RegionSuccessor &successor : successors) {
+    // Only interested in successor regions that are contained within the op.
     if (successor.isParent())
       continue;
 
-    mlir::OperandRange operands = branch.getEntrySuccessorOperands(successor);
-    mlir::ValueRange inputs = successor.getSuccessorInputs();
+    mlir::OperandRange forwardedOperands =
+        branch.getEntrySuccessorOperands(successor);
+    mlir::ValueRange regionArgs = successor.getSuccessorInputs();
 
-    for (auto [operand, input, result] : llvm::zip(operands, inputs, results)) {
-      Type inputType = input.getType();
+    for (auto [forwardedOperand, regionArg, result] :
+         llvm::zip(forwardedOperands, regionArgs, results)) {
+      Type inputType = regionArg.getType();
       if (!isa<xegpu::TensorDescType>(inputType))
         continue;
-      xegpu::LayoutAttr inputLayout = getLayoutOfValue(input);
-      xegpu::LayoutAttr operandLayout = getLayoutOfValue(operand);
+      xegpu::LayoutAttr inputLayout = getLayoutOfValue(regionArg);
+      xegpu::LayoutAttr operandLayout = getLayoutOfValue(forwardedOperand);
 
       if (!inputLayout || !operandLayout) {
-        LLVM_DEBUG(DBGS() << "No layout assigned for block arg: " << input
-                          << " or init arg: " << operand << "\n");
+        LLVM_DEBUG(DBGS() << "No layout assigned for block arg: " << regionArg
+                          << " or init arg: " << forwardedOperand << "\n");
         continue;
       }
 
       // TODO: We expect these two to match.
       assert(inputLayout == operandLayout &&
-             "Expexing block arg and init arg to have the same layout.");
+             "Expecting block arg and init arg to have the same layout.");
       // Get tensor descriptor type with the layout.
       auto tdescTy = dyn_cast<xegpu::TensorDescType>(inputType);
       auto newTdescTy = xegpu::TensorDescType::get(
           tdescTy.getContext(), tdescTy.getShape(), tdescTy.getElementType(),
           tdescTy.getEncoding(), inputLayout);
-      input.setType(newTdescTy);
+      regionArg.setType(newTdescTy);
       // Store the layout for the result.
       if (resultToLayouts.count(result) != 0 &&
           resultToLayouts[result] != inputLayout) {
@@ -837,7 +840,6 @@ static void updateBranchOpInterface(mlir::OpBuilder &builder,
     }
     // If the result is a vector type, add a temporary layout attribute to
     // the op.
-    std::string resultLayoutName = xegpu::getLayoutName(r);
     xegpu::setLayoutAttr(r, layout);
   }
 }
@@ -865,7 +867,6 @@ static void updateFunctionOpInterface(mlir::OpBuilder &builder,
           tensorDescTy.getElementType(), tensorDescTy.getEncoding(), layout);
       arg.setType(newTdescTy);
       newArgTypes.back() = newTdescTy;
-      continue;
     }
   }
   // Update the function type with the new argument types.
@@ -887,9 +888,9 @@ void XeGPULayoutPropagatePass::runOnOperation() {
   // Helper to convert LayoutInfo to xegpu::LayoutAttr.
   auto getXeGPULayoutForValue = [&](Value val) -> xegpu::LayoutAttr {
     LayoutInfo layout = analyis.getLayoutInfo(val);
-    if (!layout.isAssigned()) {
+    if (!layout.isAssigned())
       return {};
-    }
+
     SmallVector<int, 2> laneLayout, laneData;
     for (auto [layout, data] : llvm::zip_equal(layout.getLayoutAsArrayRef(),
                                                layout.getDataAsArrayRef())) {
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 8b818b21ca436..dc3dc70e325a3 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -97,9 +97,9 @@ getDistVecTypeBasedOnLaneLayout(xegpu::LayoutAttr layout,
   // dimensions are not distributed.
   unsigned distributionStart = originalType.getRank() - laneLayout.size();
   for (auto [i, dim] : llvm::enumerate(originalType.getShape())) {
-    if (i < distributionStart) {
+    if (i < distributionStart)
       continue;
-    }
+
     // Check if the dimension can be distributed evenly.
     if (dim % laneLayout[i - distributionStart] != 0)
       return failure();
@@ -848,9 +848,8 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
     // GPU index ops, scalar constants, etc.). This will simplify the
     // later lowering and avoid custom patterns for these ops.
     getOperation()->walk([&](Operation *op) {
-      if (auto warpOp = dyn_cast<gpu::WarpExecuteOnLane0Op>(op)) {
+      if (auto warpOp = dyn_cast<gpu::WarpExecuteOnLane0Op>(op))
         vector::moveScalarUniformCode(warpOp);
-      }
     });
   }
   // Step 3: Apply subgroup to workitem distribution patterns.
diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
index fef03560dddd7..a59633b0cbd9a 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
@@ -166,8 +166,8 @@ gpu.module @test {
 }
 
 // -----
-// TODO: gemm does not use update_nd_offset because of an issue in vector distribution. PR141853 tracks this issue.
-// CHECK-LABEL: gpu.func @gemm_loop
+// TODO: gemm does not use update_nd_offset because of an issue in scf-for distribution.
+// CHECK-LABEL: gpu.func @gemm
 // CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: memref<1024x1024xbf16>, %[[ARG1:[0-9a-zA-Z]+]]: memref<1024x1024xbf16>, %[[ARG2:[0-9a-zA-Z]+]]: memref<1024x1024xf32>) {
 // CHECK-DAG: %[[BLOCK_ID_X:.*]] = gpu.block_id x
 // CHECK-DAG: %[[BLOCK_ID_Y:.*]] = gpu.block_id y
@@ -189,7 +189,7 @@ gpu.module @test {
 // CHECK-NEXT: %[[T9:.*]] = vector.shape_cast %[[T5]] : vector<8x1xf32> to vector<8xf32>
 // CHECK-NEXT: xegpu.store_nd %[[T9]], %[[T2]] : vector<8xf32>, !xegpu.tensor_desc<8x16xf32>
 gpu.module @test {
-gpu.func @gemm_loop(%arg0: memref<1024x1024xbf16>, %arg1: memref<1024x1024xbf16>, %arg2: memref<1024x1024xf32>){
+gpu.func @gemm(%arg0: memref<1024x1024xbf16>, %arg1: memref<1024x1024xbf16>, %arg2: memref<1024x1024xf32>){
   %c0 = arith.constant 0 : index
   %c16 = arith.constant 16 : index
   %c8 = arith.constant 8 : index



More information about the Mlir-commits mailing list