[Mlir-commits] [mlir] Reapply [mlir][xegpu] Add XeGPU subgroup map propagation analysis for XeGPU SIMT distribution. (PR #131380)

Charitha Saumya llvmlistbot at llvm.org
Fri Mar 14 12:18:33 PDT 2025


https://github.com/charithaintc created https://github.com/llvm/llvm-project/pull/131380

Originally introduced in #130240 and reverted in #131364 

Reproduced the issue locally on Linux with a shared-library build. The fix includes adding the missing LINK_LIBS.

**Original commit message:**

This PR adds the SG map propagation step of the XeGPU SIMT distribution. SG map propagation is a sparse backward dataflow analysis that propagates the sg_map backward, starting from the operands of certain operations (DPAS, store, etc.).

This is the first step of XeGPU subgroup distribution. The result of this analysis is used to attach layout information to each XeGPU SIMD subgroup op. The lowering patterns in XeGPUSubgroupDistribute will consume this layout information to distribute SIMD ops into SIMT ops that work on work-item-level data fragments.

Summary of lowering XeGPU SIMD -> SIMT:
1. Subgroup map propagation (this PR).
2. Attach sg_map to each op and move all ops inside the gpu.warp_execute_on_lane0 region.
3. Distribute each op using its sg_map.
4. Additional legalization steps to align more closely with Xe HW.

>From ff972397d8c1f8acf51781fd26d42ae9ddd26db2 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Fri, 28 Feb 2025 23:57:37 +0000
Subject: [PATCH 01/28] save work

---
 .../mlir/Dialect/XeGPU/Transforms/Passes.td   |  10 +
 .../Dialect/XeGPU/Transforms/Transforms.h     |   2 +
 .../Dialect/XeGPU/Transforms/CMakeLists.txt   |   1 +
 .../Transforms/XeGPUSubgroupDistribute.cpp    | 249 ++++++++++++++++++
 4 files changed, 262 insertions(+)
 create mode 100644 mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp

diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
index 1ecd6ce95322b..cb9d403566645 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
@@ -23,4 +23,14 @@ def XeGPUFoldAliasOps : Pass<"xegpu-fold-alias-ops"> {
   ];
 }
 
+def XeGPUSubgroupDistribute : Pass<"xegpu-subgroup-distribute"> {
+  let summary = "Distribute XeGPU ops to work items";
+  let description = [{
+    The pass distributes subgroup level (SIMD) XeGPU ops to work items.
+  }];
+  let dependentDialects = [
+      "memref::MemRefDialect", "xegpu::XeGPUDialect", "vector::VectorDialect"
+  ];
+}
+
 #endif // MLIR_DIALECT_XEGPU_TRANSFORMS_PASSES_TD
diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h b/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
index 63ea26df06937..86b95721df60c 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
@@ -16,6 +16,8 @@ namespace xegpu {
 
 /// Appends patterns for folding aliasing ops into XeGPU ops into `patterns`.
 void populateXeGPUFoldAliasOpsPatterns(RewritePatternSet &patterns);
+/// Appends patterns for distributing XeGPU ops to work items into `patterns`.
+void populateXeGPUSubgroupDistributePatterns(RewritePatternSet &patterns);
 
 } // namespace xegpu
 } // namespace mlir
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
index 7fb64d3b97b87..124e904edb543 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_dialect_library(MLIRXeGPUTransforms
   XeGPUFoldAliasOps.cpp
+  XeGPUSubgroupDistribute.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/XeGPU
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
new file mode 100644
index 0000000000000..99995d92e24b6
--- /dev/null
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -0,0 +1,249 @@
+//===- XeGPUSubgroupDistribute.cpp - XeGPU Subgroup Distribute Pass -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
+#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
+#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
+#include "mlir/Analysis/DataFlowFramework.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
+#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/raw_ostream.h"
+#include <algorithm>
+
+namespace mlir {
+namespace xegpu {
+#define GEN_PASS_DEF_XEGPUSUBGROUPDISTRIBUTE
+#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
+} // namespace xegpu
+} // namespace mlir
+
+#define DEBUG_TYPE "xegpu-subgroup-distribute"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+
+using namespace mlir;
+using namespace mlir::dataflow;
+
+constexpr unsigned subgroupSize = 16;
+constexpr unsigned packedASizeInBits = 16;
+constexpr unsigned packedBSizeInBits = 32;
+
+namespace {
+struct Layout2D {
+  SmallVector<int64_t, 2> layout;
+  Layout2D() = default;
+  Layout2D(int64_t x, int64_t y) { layout.insert(layout.end(), {x, y}); }
+  bool operator==(const Layout2D &rhs) const {
+    return this->layout == rhs.layout;
+  }
+  bool operator<(const Layout2D &rhs) const {
+    return this->layout < rhs.layout;
+  }
+  void print(llvm::raw_ostream &os) const {
+    os << "{";
+    llvm::interleave(
+        layout, os, [&](int64_t a) { os << a; }, ", ");
+    os << "}";
+  }
+};
+
+using WiLayout = Layout2D;
+using WiData = Layout2D;
+
+struct SGMapInfo {
+  Layout2D wiLayout;
+  Layout2D wiData;
+  SGMapInfo() = default;
+  SGMapInfo(const Layout2D &layout, const Layout2D &data, unsigned bitWidth)
+      : wiLayout(layout), wiData(data) {}
+  bool operator==(const SGMapInfo &rhs) const {
+    return this->wiLayout == rhs.wiLayout && this->wiData == rhs.wiData;
+  }
+  bool operator<(const SGMapInfo &rhs) const {
+    return this->wiLayout < rhs.wiLayout || this->wiData < rhs.wiData;
+  }
+  void print(llvm::raw_ostream &os) const {
+    os << "{";
+    os << "layout: ";
+    wiLayout.print(os);
+    os << ", ";
+    os << "data: ";
+    wiData.print(os);
+    os << "}";
+  }
+};
+
+struct SGMapLatticeValue {
+private:
+  std::set<SGMapInfo> layouts;
+
+public:
+  SGMapLatticeValue() = default;
+  SGMapLatticeValue(const SGMapLatticeValue &other) = default;
+  SGMapLatticeValue(const WiLayout &layout, const WiData &data) {
+    layouts.insert(SGMapInfo(layout, data, 16));
+  }
+
+  bool operator==(const SGMapLatticeValue &other) const {
+    return this->layouts == other.layouts;
+  }
+
+  /// This function depends on a partial ordering of the lattice values.
+  static SGMapLatticeValue meet(const SGMapLatticeValue &lhs,
+                                const SGMapLatticeValue &rhs) {
+    SGMapLatticeValue res = lhs;
+    (void)res.addLayouts(rhs.layouts);
+    return res;
+  }
+
+  static SGMapLatticeValue join(const SGMapLatticeValue &lhs,
+                                const SGMapLatticeValue &rhs) {
+    // Should not be triggered by this analysis, but required by `Lattice<T>`
+    llvm_unreachable("Join should not be triggered by this test");
+  }
+
+  ChangeResult addLayouts(const std::set<SGMapInfo> &layouts) {
+    int sizeBefore = this->layouts.size();
+    this->layouts.insert(layouts.begin(), layouts.end());
+    int sizeAfter = this->layouts.size();
+    return sizeBefore == sizeAfter ? ChangeResult::NoChange
+                                   : ChangeResult::Change;
+  }
+
+  void print(raw_ostream &os) const {
+    os << "[";
+    llvm::interleave(
+        layouts, os, [&](const SGMapInfo &a) { a.print(os); }, ", ");
+    os << "]";
+  }
+
+  void clear() { layouts.clear(); }
+
+  std::set<SGMapInfo> getLayouts() const { return layouts; }
+};
+
+struct SGMap : public Lattice<SGMapLatticeValue> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SGMap)
+  using Lattice::Lattice;
+};
+
+static SGMapLatticeValue getSGMapForDPASOperand(Type operandTy,
+                                                unsigned operandNum) {
+  int packingFactorForB = packedBSizeInBits / operandTy.getIntOrFloatBitWidth();
+  int packingFactorForA =
+      operandTy.getIntOrFloatBitWidth() < packedBSizeInBits
+          ? packedASizeInBits / operandTy.getIntOrFloatBitWidth()
+          : 1;
+  return SGMapLatticeValue(WiLayout(1, subgroupSize),
+                           WiData(operandNum == 1 ? packingFactorForB : 1,
+                                  operandNum == 0 ? packingFactorForA : 1));
+}
+
+class SGMapPropagation : public SparseBackwardDataFlowAnalysis<SGMap> {
+public:
+  SGMapPropagation(DataFlowSolver &solver, SymbolTableCollection &symbolTable)
+      : SparseBackwardDataFlowAnalysis(solver, symbolTable) {}
+  using SparseBackwardDataFlowAnalysis::SparseBackwardDataFlowAnalysis;
+
+  LogicalResult visitOperation(Operation *op, ArrayRef<SGMap *> operands,
+                               ArrayRef<const SGMap *> results) override;
+
+  void visitBranchOperand(OpOperand &operand) override{};
+
+  void visitCallOperand(OpOperand &operand) override{};
+
+  void visitExternalCall(CallOpInterface call, ArrayRef<SGMap *> operands,
+                         ArrayRef<const SGMap *> results) override{};
+
+  void setToExitState(SGMap *lattice) override {
+    (void)lattice->meet(SGMapLatticeValue());
+  }
+};
+} // namespace
+
+LogicalResult
+SGMapPropagation::visitOperation(Operation *op, ArrayRef<SGMap *> operands,
+                                 ArrayRef<const SGMap *> results) {
+  /// Handle dpas
+  if (auto dpas = dyn_cast<xegpu::DpasOp>(op)) {
+    auto aTy = dpas.getLhsType().getElementType();
+    auto bTy = dpas.getRhsType().getElementType();
+    propagateIfChanged(operands[0],
+                       operands[0]->meet(getSGMapForDPASOperand(aTy, 0)));
+    propagateIfChanged(operands[1],
+                       operands[1]->meet(getSGMapForDPASOperand(bTy, 1)));
+    if (operands.size() > 2) {
+      auto cTy = dpas.getAccType().getElementType();
+      propagateIfChanged(operands[2],
+                         operands[2]->meet(getSGMapForDPASOperand(cTy, 2)));
+    }
+    return success();
+  }
+  for (const SGMap *r : results) {
+    // For each operand assume a default layout.
+    for (SGMap *operand : operands) {
+      meet(operand, *r);
+    }
+    addDependency(const_cast<SGMap *>(r), getProgramPointAfter(op));
+  }
+  return success();
+}
+
+void xegpu::populateXeGPUSubgroupDistributePatterns(
+    RewritePatternSet &patterns) {}
+
+namespace {
+
+class RunSGMapPropagation {
+public:
+  RunSGMapPropagation(Operation *op) {
+    SymbolTableCollection symbolTable;
+
+    solver.load<DeadCodeAnalysis>();
+    solver.load<SparseConstantPropagation>();
+    solver.load<SGMapPropagation>(symbolTable);
+    (void)solver.initializeAndRun(op);
+  }
+
+  SGMapLatticeValue getSGMap(Value val) {
+    auto *state = solver.lookupState<SGMap>(val);
+    if (!state)
+      return {};
+    return state->getValue();
+  }
+
+private:
+  DataFlowSolver solver;
+};
+
+struct XeGPUSubgroupDistributePass final
+    : public xegpu::impl::XeGPUSubgroupDistributeBase<
+          XeGPUSubgroupDistributePass> {
+  void runOnOperation() override;
+};
+
+} // namespace
+
+void XeGPUSubgroupDistributePass::runOnOperation() {
+  Operation *op = getOperation();
+
+  RunSGMapPropagation solver(op);
+
+  // Print analysis results
+  auto &os = llvm::outs();
+  op->walk([&](Operation *op) {
+    if (op->getResults().empty())
+      return;
+    auto layouts = solver.getSGMap(op->getResult(0));
+    os << "SGMap for " << op->getName() << ": ";
+    layouts.print(os);
+    os << "\n";
+  });
+}

>From d0c1144513b0c6c077c0b545889d8e0d90cf394d Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Sat, 1 Mar 2025 00:04:34 +0000
Subject: [PATCH 02/28] save work

---
 .../Transforms/XeGPUSubgroupDistribute.cpp    | 33 ++++++++++++++-----
 1 file changed, 25 insertions(+), 8 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 99995d92e24b6..d2f43b3448f95 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -9,11 +9,13 @@
 #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
 #include "mlir/Analysis/DataFlow/SparseAnalysis.h"
 #include "mlir/Analysis/DataFlowFramework.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
 #include "mlir/Dialect/XeGPU/Transforms/Passes.h"
 #include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
 #include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/Interfaces/DataLayoutInterfaces.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/Support/raw_ostream.h"
 #include <algorithm>
@@ -238,12 +240,27 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
 
   // Print analysis results
   auto &os = llvm::outs();
-  op->walk([&](Operation *op) {
-    if (op->getResults().empty())
-      return;
-    auto layouts = solver.getSGMap(op->getResult(0));
-    os << "SGMap for " << op->getName() << ": ";
-    layouts.print(os);
-    os << "\n";
-  });
+  // check if op is a function
+  // llvm::errs() << op->getName() << "\n";
+  if (auto modOp = dyn_cast<ModuleOp>(op)) {
+    for (auto funcOp : modOp.getOps<func::FuncOp>()) {
+      os << "SGMap for " << funcOp.getName() << ":\n";
+      // Function args
+      for (auto arg : funcOp.getArguments()) {
+        auto layouts = solver.getSGMap(arg);
+        os << "SGMap for " << arg << ": ";
+        layouts.print(os);
+        os << "\n";
+      }
+      // Function ops
+      funcOp.walk([&](Operation *op) {
+        if (op->getResults().empty())
+          return;
+        auto layouts = solver.getSGMap(op->getResult(0));
+        os << "SGMap for " << op->getName() << ": ";
+        layouts.print(os);
+        os << "\n";
+      });
+    }
+  }
 }

>From 0b8752c921b17a99ab58f48cebc7564cd1387256 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Tue, 4 Mar 2025 18:14:40 +0000
Subject: [PATCH 03/28] save work

---
 .../Transforms/XeGPUSubgroupDistribute.cpp    | 361 +++++++++++-------
 .../XeGPU/subgroup-map-propagation.mlir       | 186 +++++++++
 2 files changed, 402 insertions(+), 145 deletions(-)
 create mode 100644 mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index d2f43b3448f95..c80ba2ca5f3d1 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -11,12 +11,15 @@
 #include "mlir/Analysis/DataFlowFramework.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
 #include "mlir/Dialect/XeGPU/Transforms/Passes.h"
 #include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/Interfaces/DataLayoutInterfaces.h"
+#include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/LogicalResult.h"
 #include "llvm/Support/raw_ostream.h"
 #include <algorithm>
 
@@ -38,217 +41,264 @@ constexpr unsigned packedASizeInBits = 16;
 constexpr unsigned packedBSizeInBits = 32;
 
 namespace {
-struct Layout2D {
-  SmallVector<int64_t, 2> layout;
-  Layout2D() = default;
-  Layout2D(int64_t x, int64_t y) { layout.insert(layout.end(), {x, y}); }
-  bool operator==(const Layout2D &rhs) const {
-    return this->layout == rhs.layout;
-  }
-  bool operator<(const Layout2D &rhs) const {
-    return this->layout < rhs.layout;
-  }
-  void print(llvm::raw_ostream &os) const {
-    os << "{";
-    llvm::interleave(
-        layout, os, [&](int64_t a) { os << a; }, ", ");
-    os << "}";
-  }
+struct Layout {
+  SmallVector<int64_t, 3> layout;
+  Layout() = default;
+  Layout(const Layout &other) = default;
+  Layout(std::initializer_list<int64_t> list) : layout(list) {}
+  void print(llvm::raw_ostream &os) const;
+  size_t size() const { return layout.size(); }
 };
 
-using WiLayout = Layout2D;
-using WiData = Layout2D;
-
-struct SGMapInfo {
-  Layout2D wiLayout;
-  Layout2D wiData;
-  SGMapInfo() = default;
-  SGMapInfo(const Layout2D &layout, const Layout2D &data, unsigned bitWidth)
-      : wiLayout(layout), wiData(data) {}
-  bool operator==(const SGMapInfo &rhs) const {
-    return this->wiLayout == rhs.wiLayout && this->wiData == rhs.wiData;
-  }
-  bool operator<(const SGMapInfo &rhs) const {
-    return this->wiLayout < rhs.wiLayout || this->wiData < rhs.wiData;
-  }
-  void print(llvm::raw_ostream &os) const {
-    os << "{";
-    os << "layout: ";
-    wiLayout.print(os);
-    os << ", ";
-    os << "data: ";
-    wiData.print(os);
-    os << "}";
-  }
-};
+void Layout::print(llvm::raw_ostream &os) const {
+  os << "[";
+  llvm::interleaveComma(layout, os);
+  os << "]";
+}
+
+using WiLayout = Layout;
+using WiData = Layout;
 
-struct SGMapLatticeValue {
+struct SGMap {
 private:
-  std::set<SGMapInfo> layouts;
+  WiLayout layout;
+  WiData data;
 
 public:
-  SGMapLatticeValue() = default;
-  SGMapLatticeValue(const SGMapLatticeValue &other) = default;
-  SGMapLatticeValue(const WiLayout &layout, const WiData &data) {
-    layouts.insert(SGMapInfo(layout, data, 16));
+  SGMap() = default;
+  SGMap(const SGMap &other) = default;
+  SGMap(const WiLayout &layout, const WiData &data)
+      : layout(layout), data(data) {}
+
+  // Two lattice values are equal if they have `some` layout. The actual
+  // content of the layout does not matter.
+  bool operator==(const SGMap &other) const {
+    return this->isAssigned() == other.isAssigned();
   }
 
-  bool operator==(const SGMapLatticeValue &other) const {
-    return this->layouts == other.layouts;
-  }
+  static SGMap meet(const SGMap &lhs, const SGMap &rhs);
 
-  /// This function depends on a partial ordering of the lattice values.
-  static SGMapLatticeValue meet(const SGMapLatticeValue &lhs,
-                                const SGMapLatticeValue &rhs) {
-    SGMapLatticeValue res = lhs;
-    (void)res.addLayouts(rhs.layouts);
-    return res;
-  }
+  static SGMap join(const SGMap &lhs, const SGMap &rhs);
 
-  static SGMapLatticeValue join(const SGMapLatticeValue &lhs,
-                                const SGMapLatticeValue &rhs) {
-    // Should not be triggered by this analysis, but required by `Lattice<T>`
-    llvm_unreachable("Join should not be triggered by this test");
-  }
+  void print(raw_ostream &os) const;
 
-  ChangeResult addLayouts(const std::set<SGMapInfo> &layouts) {
-    int sizeBefore = this->layouts.size();
-    this->layouts.insert(layouts.begin(), layouts.end());
-    int sizeAfter = this->layouts.size();
-    return sizeBefore == sizeAfter ? ChangeResult::NoChange
-                                   : ChangeResult::Change;
-  }
+  bool isAssigned() const { return layout.size() > 0 && data.size() > 0; }
 
-  void print(raw_ostream &os) const {
-    os << "[";
-    llvm::interleave(
-        layouts, os, [&](const SGMapInfo &a) { a.print(os); }, ", ");
-    os << "]";
-  }
+  SGMap getTransposedLayout(ArrayRef<int64_t> permutation) const;
+};
+
+void SGMap::print(raw_ostream &os) const {
+  if (isAssigned()) {
+    os << "Layout: ";
+    layout.print(os);
+    os << ", Data: ";
+    data.print(os);
+  } else
+    os << "Not initialized";
+}
 
-  void clear() { layouts.clear(); }
+SGMap SGMap::meet(const SGMap &lhs, const SGMap &rhs) {
+  if (!lhs.isAssigned())
+    return rhs;
+  return lhs;
+}
 
-  std::set<SGMapInfo> getLayouts() const { return layouts; }
-};
+SGMap SGMap::join(const SGMap &lhs, const SGMap &rhs) {
+  // Should not be triggered by this analysis, but required by `Lattice<T>`
+  llvm_unreachable("Join should not be triggered by this test");
+}
 
-struct SGMap : public Lattice<SGMapLatticeValue> {
-  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SGMap)
+SGMap SGMap::getTransposedLayout(ArrayRef<int64_t> permutation) const {
+  if (!isAssigned())
+    return {};
+  WiLayout newLayout;
+  WiData newData;
+  for (auto idx : permutation) {
+    newLayout.layout.push_back(layout.layout[idx]);
+    newData.layout.push_back(data.layout[idx]);
+  }
+  return SGMap(newLayout, newData);
+}
+
+struct SGMapLattice : public Lattice<SGMap> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SGMapLattice)
   using Lattice::Lattice;
 };
 
-static SGMapLatticeValue getSGMapForDPASOperand(Type operandTy,
-                                                unsigned operandNum) {
+/// Helper Functions
+///
+
+static SGMap getSGMapForDPASOperand(Type operandTy, unsigned operandNum) {
   int packingFactorForB = packedBSizeInBits / operandTy.getIntOrFloatBitWidth();
   int packingFactorForA =
       operandTy.getIntOrFloatBitWidth() < packedBSizeInBits
           ? packedASizeInBits / operandTy.getIntOrFloatBitWidth()
           : 1;
-  return SGMapLatticeValue(WiLayout(1, subgroupSize),
-                           WiData(operandNum == 1 ? packingFactorForB : 1,
-                                  operandNum == 0 ? packingFactorForA : 1));
+  return SGMap(WiLayout({1, subgroupSize}),
+               WiData({operandNum == 1 ? packingFactorForB : 1,
+                       operandNum == 0 ? packingFactorForA : 1}));
+}
+
+static SGMap getDefaultSgMap(Type ty) {
+  int packingFactor = 1;
+  if (ty.getIntOrFloatBitWidth() < packedASizeInBits)
+    packingFactor = packedBSizeInBits / ty.getIntOrFloatBitWidth();
+  return SGMap(WiLayout({1, subgroupSize}), WiData({1, packingFactor}));
 }
 
-class SGMapPropagation : public SparseBackwardDataFlowAnalysis<SGMap> {
+class SGMapPropagation : public SparseBackwardDataFlowAnalysis<SGMapLattice> {
+private:
+  void visitDpasOp(xegpu::DpasOp dpas, ArrayRef<SGMapLattice *> operands,
+                   ArrayRef<const SGMapLattice *> results);
+
+  void visitStoreNdOp(xegpu::StoreNdOp store, ArrayRef<SGMapLattice *> operands,
+                      ArrayRef<const SGMapLattice *> results);
+
+  void visitLoadNdOp(xegpu::LoadNdOp load, ArrayRef<SGMapLattice *> operands,
+                     ArrayRef<const SGMapLattice *> results);
+
+  void visitTransposeOp(vector::TransposeOp transpose,
+                        ArrayRef<SGMapLattice *> operands,
+                        ArrayRef<const SGMapLattice *> results);
+
 public:
   SGMapPropagation(DataFlowSolver &solver, SymbolTableCollection &symbolTable)
       : SparseBackwardDataFlowAnalysis(solver, symbolTable) {}
   using SparseBackwardDataFlowAnalysis::SparseBackwardDataFlowAnalysis;
 
-  LogicalResult visitOperation(Operation *op, ArrayRef<SGMap *> operands,
-                               ArrayRef<const SGMap *> results) override;
+  LogicalResult visitOperation(Operation *op, ArrayRef<SGMapLattice *> operands,
+                               ArrayRef<const SGMapLattice *> results) override;
 
   void visitBranchOperand(OpOperand &operand) override{};
 
   void visitCallOperand(OpOperand &operand) override{};
 
-  void visitExternalCall(CallOpInterface call, ArrayRef<SGMap *> operands,
-                         ArrayRef<const SGMap *> results) override{};
+  void visitExternalCall(CallOpInterface call,
+                         ArrayRef<SGMapLattice *> operands,
+                         ArrayRef<const SGMapLattice *> results) override{};
 
-  void setToExitState(SGMap *lattice) override {
-    (void)lattice->meet(SGMapLatticeValue());
+  void setToExitState(SGMapLattice *lattice) override {
+    (void)lattice->meet(SGMap());
   }
 };
 } // namespace
 
 LogicalResult
-SGMapPropagation::visitOperation(Operation *op, ArrayRef<SGMap *> operands,
-                                 ArrayRef<const SGMap *> results) {
-  /// Handle dpas
-  if (auto dpas = dyn_cast<xegpu::DpasOp>(op)) {
-    auto aTy = dpas.getLhsType().getElementType();
-    auto bTy = dpas.getRhsType().getElementType();
-    propagateIfChanged(operands[0],
-                       operands[0]->meet(getSGMapForDPASOperand(aTy, 0)));
-    propagateIfChanged(operands[1],
-                       operands[1]->meet(getSGMapForDPASOperand(bTy, 1)));
-    if (operands.size() > 2) {
-      auto cTy = dpas.getAccType().getElementType();
-      propagateIfChanged(operands[2],
-                         operands[2]->meet(getSGMapForDPASOperand(cTy, 2)));
-    }
-    return success();
-  }
-  for (const SGMap *r : results) {
-    // For each operand assume a default layout.
-    for (SGMap *operand : operands) {
-      meet(operand, *r);
+SGMapPropagation::visitOperation(Operation *op,
+                                 ArrayRef<SGMapLattice *> operands,
+                                 ArrayRef<const SGMapLattice *> results) {
+  if (auto dpas = dyn_cast<xegpu::DpasOp>(op))
+    visitDpasOp(dpas, operands, results);
+  else if (auto store = dyn_cast<xegpu::StoreNdOp>(op))
+    visitStoreNdOp(store, operands, results);
+  else if (auto load = dyn_cast<xegpu::LoadNdOp>(op))
+    visitLoadNdOp(load, operands, results);
+  else if (auto transpose = dyn_cast<vector::TransposeOp>(op))
+    visitTransposeOp(transpose, operands, results);
+  /// All other ops
+  else {
+    for (const SGMapLattice *r : results) {
+      for (SGMapLattice *operand : operands) {
+        if (r->getValue().isAssigned())
+          meet(operand, *r);
+      }
+      addDependency(const_cast<SGMapLattice *>(r), getProgramPointAfter(op));
     }
-    addDependency(const_cast<SGMap *>(r), getProgramPointAfter(op));
   }
   return success();
 }
 
-void xegpu::populateXeGPUSubgroupDistributePatterns(
-    RewritePatternSet &patterns) {}
+void SGMapPropagation::visitDpasOp(xegpu::DpasOp dpas,
+                                   ArrayRef<SGMapLattice *> operands,
+                                   ArrayRef<const SGMapLattice *> results) {
+  auto aTy = dpas.getLhsType().getElementType();
+  auto bTy = dpas.getRhsType().getElementType();
+  propagateIfChanged(operands[0],
+                     operands[0]->meet(getSGMapForDPASOperand(aTy, 0)));
+  propagateIfChanged(operands[1],
+                     operands[1]->meet(getSGMapForDPASOperand(bTy, 1)));
+  if (operands.size() > 2) {
+    auto cTy = dpas.getAccType().getElementType();
+    propagateIfChanged(operands[2],
+                       operands[2]->meet(getSGMapForDPASOperand(cTy, 2)));
+  }
+}
+
+void SGMapPropagation::visitStoreNdOp(xegpu::StoreNdOp store,
+                                      ArrayRef<SGMapLattice *> operands,
+                                      ArrayRef<const SGMapLattice *> results) {
+  auto storeLayout =
+      getDefaultSgMap(store.getTensorDescType().getElementType());
+  /// Both operands should have the same layout
+  for (SGMapLattice *operand : operands) {
+    propagateIfChanged(operand, operand->meet(storeLayout));
+  }
+}
+
+void SGMapPropagation::visitLoadNdOp(xegpu::LoadNdOp load,
+                                     ArrayRef<SGMapLattice *> operands,
+                                     ArrayRef<const SGMapLattice *> results) {
+  auto valueLayout = results[0]->getValue();
+  /// Need the layout of the value to propagate to the tensor descriptor.
+  if (!valueLayout.isAssigned())
+    return;
+  SGMap tensorDescLayout = valueLayout;
+  if (auto transpose = load.getTranspose())
+    tensorDescLayout = valueLayout.getTransposedLayout(transpose.value());
+  /// Propagate the new layout to the tensor descriptor operand.
+  propagateIfChanged(operands[0], operands[0]->meet(tensorDescLayout));
+}
+
+void SGMapPropagation::visitTransposeOp(
+    vector::TransposeOp transpose, ArrayRef<SGMapLattice *> operands,
+    ArrayRef<const SGMapLattice *> results) {
+  /// Need the layout of transpose result to propagate to the operands.
+  auto operandLayout = results[0]->getValue();
+  if (!operandLayout.isAssigned())
+    return;
+  auto newLayout =
+      operandLayout.getTransposedLayout(transpose.getPermutation());
+  /// Propagate the new layout to the vector operand.
+  propagateIfChanged(operands[0], operands[0]->meet(newLayout));
+}
 
 namespace {
 
 class RunSGMapPropagation {
 public:
-  RunSGMapPropagation(Operation *op) {
+  RunSGMapPropagation(Operation *op) : target(op) {
     SymbolTableCollection symbolTable;
-
     solver.load<DeadCodeAnalysis>();
     solver.load<SparseConstantPropagation>();
     solver.load<SGMapPropagation>(symbolTable);
     (void)solver.initializeAndRun(op);
   }
 
-  SGMapLatticeValue getSGMap(Value val) {
-    auto *state = solver.lookupState<SGMap>(val);
-    if (!state)
-      return {};
-    return state->getValue();
-  }
+  SGMap getSGMap(Value val);
+
+  void printAnalysisResult(llvm::raw_ostream &os);
 
 private:
   DataFlowSolver solver;
+  const Operation *target;
 };
-
-struct XeGPUSubgroupDistributePass final
-    : public xegpu::impl::XeGPUSubgroupDistributeBase<
-          XeGPUSubgroupDistributePass> {
-  void runOnOperation() override;
-};
-
 } // namespace
 
-void XeGPUSubgroupDistributePass::runOnOperation() {
-  Operation *op = getOperation();
-
-  RunSGMapPropagation solver(op);
+SGMap RunSGMapPropagation::getSGMap(Value val) {
+  auto *state = solver.lookupState<SGMapLattice>(val);
+  if (!state)
+    return {};
+  return state->getValue();
+}
 
-  // Print analysis results
-  auto &os = llvm::outs();
-  // check if op is a function
-  // llvm::errs() << op->getName() << "\n";
-  if (auto modOp = dyn_cast<ModuleOp>(op)) {
+void RunSGMapPropagation::printAnalysisResult(llvm::raw_ostream &os) {
+  if (auto modOp = dyn_cast<ModuleOp>(target)) {
     for (auto funcOp : modOp.getOps<func::FuncOp>()) {
-      os << "SGMap for " << funcOp.getName() << ":\n";
+      os << "sg_map for " << funcOp.getName() << ":\n";
       // Function args
       for (auto arg : funcOp.getArguments()) {
-        auto layouts = solver.getSGMap(arg);
-        os << "SGMap for " << arg << ": ";
+        auto layouts = getSGMap(arg);
+        os << "sg_map for " << arg << ": ";
         layouts.print(os);
         os << "\n";
       }
@@ -256,11 +306,32 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
       funcOp.walk([&](Operation *op) {
         if (op->getResults().empty())
           return;
-        auto layouts = solver.getSGMap(op->getResult(0));
-        os << "SGMap for " << op->getName() << ": ";
+        auto layouts = getSGMap(op->getResult(0));
+        os << "sg_map for " << op->getName() << ": ";
         layouts.print(os);
         os << "\n";
       });
     }
   }
 }
+
+namespace {
+struct XeGPUSubgroupDistributePass final
+    : public xegpu::impl::XeGPUSubgroupDistributeBase<
+          XeGPUSubgroupDistributePass> {
+  void runOnOperation() override;
+};
+} // namespace
+
+void XeGPUSubgroupDistributePass::runOnOperation() {
+  Operation *op = getOperation();
+
+  RunSGMapPropagation solver(op);
+
+  // Print analysis results
+  auto &os = llvm::outs();
+  solver.printAnalysisResult(os);
+}
+
+void xegpu::populateXeGPUSubgroupDistributePatterns(
+    RewritePatternSet &patterns) {}
diff --git a/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir b/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
new file mode 100644
index 0000000000000..217d510973a4f
--- /dev/null
+++ b/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
@@ -0,0 +1,186 @@
+func.func @test(%a: memref<8x16xf16>, %b : memref<16x16xf16>, %c : memref<8x16xf32>) {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant dense<0.0> : vector<8x16xf32>
+  %0 = xegpu.create_nd_tdesc %a[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
+  %1 = xegpu.create_nd_tdesc %b[%c0, %c0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
+  %2 = xegpu.load_nd %0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+  %3 = xegpu.load_nd %1 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+  %4 = xegpu.dpas %2, %3, %cst : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+  %5 = xegpu.create_nd_tdesc %c[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+  xegpu.store_nd %4, %5 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+  return
+}
+
+// -----
+func.func @test(%a: memref<8x16xf16>, %b : memref<16x16xf16>, %c : memref<8x16xf32>) {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant dense<0.0> : vector<8x16xf32>
+  %0 = xegpu.create_nd_tdesc %a[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
+  %1 = xegpu.create_nd_tdesc %b[%c0, %c0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
+  %2 = xegpu.load_nd %0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+  %3 = xegpu.load_nd %1 {transpose = array<i64: 1, 0>} : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+  %4 = xegpu.dpas %2, %3, %cst : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+  %5 = xegpu.create_nd_tdesc %c[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+  xegpu.store_nd %4, %5 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+  return
+}
+
+
+// -----
+func.func @test(%arg0: vector<8x16xf16>, %arg1: vector<16x16xf16>) -> vector<8x16xf32> {
+  %0 = xegpu.dpas %arg0, %arg1 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+  return %0 : vector<8x16xf32>
+}
+
+// -----
+func.func @test(%arg0: vector<8x32xi8>, %arg1: vector<32x32xi8>) -> vector<8x32xi32> {
+  %0 = xegpu.dpas %arg0, %arg1 : vector<8x32xi8>, vector<32x32xi8> -> vector<8x32xi32>
+  return %0 : vector<8x32xi32>
+}
+
+// -----
+func.func @test(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>) -> vector<8x16xf32> {
+  %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+  %1 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+  %2 = xegpu.dpas %0, %1 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+  return %2 : vector<8x16xf32>
+}
+
+// -----
+func.func @test(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>) -> (vector<8x16xf32>, vector<8x16xf32>) {
+  %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+  %1 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+  %2 = xegpu.dpas %0, %1 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+  %3 = xegpu.dpas %0, %1 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+  return %2, %3 : vector<8x16xf32>, vector<8x16xf32>
+}
+
+// -----
+func.func @test(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>) -> (vector<8x16xf32>, vector<16x16xf16>) {
+  %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+  %1 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+  %2 = xegpu.dpas %0, %1 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+  return %2, %1 : vector<8x16xf32>, vector<16x16xf16>
+}
+
+// -----
+func.func @test(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>) -> vector<8x16xf32> {
+  %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+  %1 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+  %2 = arith.extf %1 : vector<16x16xf16> to vector<16x16xf32>
+  %3 = arith.truncf %2 : vector<16x16xf32> to vector<16x16xf16>
+  %4 = xegpu.dpas %0, %3 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+  return %4 : vector<8x16xf32>
+}
+
+// -----
+func.func @test(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>) -> vector<8x16xf32> {
+  %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+  %1 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+  %2 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+  %3 = arith.addf %1, %2 : vector<16x16xf16>
+  %4 = xegpu.dpas %0, %3 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+  return %4 : vector<8x16xf32>
+}
+
+// -----
+func.func @test(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>) -> vector<8x16xf32> {
+  %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+  %1 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+  %cst = arith.constant dense<1.000000e+00> : vector<16x16xf16>
+  %2 = arith.addf %1, %cst : vector<16x16xf16>
+  %3 = xegpu.dpas %0, %2 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+  return %3 : vector<8x16xf32>
+}
+
+// -----
+func.func @test(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>) -> (vector<8x16xf32>, vector<16x16xf16>) {
+  %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+  %1 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+  %2 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+  %3 = arith.addf %1, %2 : vector<16x16xf16>
+  %4 = xegpu.dpas %0, %3 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+  return %4, %2 : vector<8x16xf32>, vector<16x16xf16>
+}
+
+// -----
+func.func @test(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: index) -> vector<8x16xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+  %1 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+  %2 = xegpu.dpas %0, %1 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+  %3:2 = scf.for %arg3 = %c0 to %arg2 step %c1 iter_args(%arg4 = %1, %arg5 = %2) -> (vector<16x16xf16>, vector<8x16xf32>) {
+    %4 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+    %5 = arith.addf %arg4, %4 : vector<16x16xf16>
+    %6 = xegpu.dpas %0, %5 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+    scf.yield %5, %6 : vector<16x16xf16>, vector<8x16xf32>
+  }
+  return %3#1 : vector<8x16xf32>
+}
+
+// -----
+func.func @test(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: i1) -> vector<8x16xf32> {
+  %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+  %1 = scf.if %arg2 -> (vector<16x16xf16>) {
+    %3 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+    scf.yield %3 : vector<16x16xf16>
+  } else {
+    %3 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+    scf.yield %3 : vector<16x16xf16>
+  }
+  %2 = xegpu.dpas %0, %1 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+  return %2 : vector<8x16xf32>
+}
+
+// -----
+func.func @test(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: i1) -> (vector<8x16xf32>, vector<16x16xf16>) {
+  %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+  %1 = scf.if %arg2 -> (vector<16x16xf16>) {
+    %3 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+    scf.yield %3 : vector<16x16xf16>
+  } else {
+    %3 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+    scf.yield %3 : vector<16x16xf16>
+  }
+  %2 = xegpu.dpas %0, %1 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+  return %2, %1 : vector<8x16xf32>, vector<16x16xf16>
+}
+
+// -----
+func.func @test(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: vector<16x16xf16>, %arg3: i1) -> vector<8x16xf32> {
+  %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+  %1 = scf.if %arg3 -> (vector<16x16xf16>) {
+    %3 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+    scf.yield %3 : vector<16x16xf16>
+  } else {
+    scf.yield %arg2 : vector<16x16xf16>
+  }
+  %2 = xegpu.dpas %0, %1 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+  return %2 : vector<8x16xf32>
+}
+
+// -----
+func.func @test(%arg0: !xegpu.tensor_desc<8x16xi16>, %arg1: !xegpu.tensor_desc<16x16xi16>) -> vector<8x16xf32> {
+  %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xi16> -> vector<8x16xi16>
+  %1 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xi16> -> vector<16x16xi16>
+  %2 = arith.bitcast %0 : vector<8x16xi16> to vector<8x16xf16>
+  %3 = arith.bitcast %1 : vector<16x16xi16> to vector<16x16xf16>
+  %4 = xegpu.dpas %2, %3 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+  return %4 : vector<8x16xf32>
+}
+
+// -----
+func.func @test(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: vector<16x16xf16>, %arg3: i1) -> vector<8x16xf32> {
+  %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+  %1 = scf.if %arg3 -> (vector<8x16xf32>) {
+    %2 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+    %3 = arith.addf %2, %arg2 : vector<16x16xf16>
+    %4 = xegpu.dpas %0, %3 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+    scf.yield %4 : vector<8x16xf32>
+  } else {
+    %2 = xegpu.dpas %0, %arg2 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+    scf.yield %2 : vector<8x16xf32>
+  }
+  return %1 : vector<8x16xf32>
+}

>From 2890d5031eab8c07e8b1e3526b1caa13f2574e0d Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Tue, 4 Mar 2025 22:58:28 +0000
Subject: [PATCH 04/28] save work

---
 .../Transforms/XeGPUSubgroupDistribute.cpp    | 51 ++++++++++++-
 .../XeGPU/subgroup-map-propagation.mlir       | 75 +++++++++++--------
 2 files changed, 95 insertions(+), 31 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index c80ba2ca5f3d1..4f6091d7bebfb 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -48,6 +48,7 @@ struct Layout {
   Layout(std::initializer_list<int64_t> list) : layout(list) {}
   void print(llvm::raw_ostream &os) const;
   size_t size() const { return layout.size(); }
+  int64_t operator[](size_t idx) const { return layout[idx]; }
 };
 
 void Layout::print(llvm::raw_ostream &os) const {
@@ -85,6 +86,9 @@ struct SGMap {
   bool isAssigned() const { return layout.size() > 0 && data.size() > 0; }
 
   SGMap getTransposedLayout(ArrayRef<int64_t> permutation) const;
+
+  const WiLayout &getLayout() const { return layout; }
+  const WiData &getData() const { return data; }
 };
 
 void SGMap::print(raw_ostream &os) const {
@@ -160,6 +164,9 @@ class SGMapPropagation : public SparseBackwardDataFlowAnalysis<SGMapLattice> {
   void visitTransposeOp(vector::TransposeOp transpose,
                         ArrayRef<SGMapLattice *> operands,
                         ArrayRef<const SGMapLattice *> results);
+  void visitVectorBitcastOp(vector::BitCastOp bitcast,
+                            ArrayRef<SGMapLattice *> operands,
+                            ArrayRef<const SGMapLattice *> results);
 
 public:
   SGMapPropagation(DataFlowSolver &solver, SymbolTableCollection &symbolTable)
@@ -195,16 +202,23 @@ SGMapPropagation::visitOperation(Operation *op,
     visitLoadNdOp(load, operands, results);
   else if (auto transpose = dyn_cast<vector::TransposeOp>(op))
     visitTransposeOp(transpose, operands, results);
+  else if (auto bitcast = dyn_cast<vector::BitCastOp>(op))
+    visitVectorBitcastOp(bitcast, operands, results);
   /// All other ops
   else {
     for (const SGMapLattice *r : results) {
       for (SGMapLattice *operand : operands) {
+        /// Propagate the layout of the result to the operand.
         if (r->getValue().isAssigned())
           meet(operand, *r);
       }
-      addDependency(const_cast<SGMapLattice *>(r), getProgramPointAfter(op));
     }
   }
+  /// Add a dependency from each reult to program point after the operation.
+  /// NOTE: not sure if this is required, but all other passes do this.
+  for (const SGMapLattice *r : results) {
+    addDependency(const_cast<SGMapLattice *>(r), getProgramPointAfter(op));
+  }
   return success();
 }
 
@@ -262,6 +276,41 @@ void SGMapPropagation::visitTransposeOp(
   propagateIfChanged(operands[0], operands[0]->meet(newLayout));
 }
 
+/// Backward transfer function for vector.bitcast: derive the source operand's
+/// sg_map from the (already assigned) result sg_map. wi_layout is unchanged;
+/// only the packed dimension of wi_data is rescaled by the element bit-width
+/// ratio between source and result.
+void SGMapPropagation::visitVectorBitcastOp(
+    vector::BitCastOp bitcast, ArrayRef<SGMapLattice *> operands,
+    ArrayRef<const SGMapLattice *> results) {
+  /// Need the layout of bitcast result to propagate to the operands.
+  auto resultLayout = results[0]->getValue();
+  if (!resultLayout.isAssigned())
+    return;
+  int64_t inElemTyBitWidth =
+      bitcast.getSourceVectorType().getElementType().getIntOrFloatBitWidth();
+  int64_t outElemTyBitWidth =
+      bitcast.getResultVectorType().getElementType().getIntOrFloatBitWidth();
+
+  /// Same-width bitcast: each lane owns the same elements before and after,
+  /// so the result layout applies to the source unchanged.
+  if (inElemTyBitWidth == outElemTyBitWidth) {
+    propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
+    return;
+  }
+
+  /// WiLayout does not change; only the packed dimension of WiData scales.
+  const WiLayout &newWiLayout = resultLayout.getLayout();
+  const WiData &currData = resultLayout.getData();
+  /// Dimension along which elements are packed (dim 1 when wi_data[0] == 1).
+  size_t packedDim = currData[0] == 1 ? 1 : 0;
+  int64_t newPacked;
+  if (inElemTyBitWidth < outElemTyBitWidth) {
+    /// Widening bitcast: one result element spans `ratio` source elements.
+    newPacked = currData[packedDim] * (outElemTyBitWidth / inElemTyBitWidth);
+  } else {
+    /// Narrowing bitcast: `ratio` result elements share one source element.
+    /// If the result wi_data is not a multiple of the ratio, a lane would own
+    /// a fraction of a source element (integer division would yield 0), so
+    /// no valid source layout exists; do not propagate.
+    int64_t ratio = inElemTyBitWidth / outElemTyBitWidth;
+    if (currData[packedDim] % ratio != 0)
+      return;
+    newPacked = currData[packedDim] / ratio;
+  }
+  WiData newWiData =
+      packedDim == 1 ? WiData({1, newPacked}) : WiData({newPacked, 1});
+
+  propagateIfChanged(operands[0],
+                     operands[0]->meet(SGMap(newWiLayout, newWiData)));
+}
+
 namespace {
 
 class RunSGMapPropagation {
diff --git a/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir b/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
index 217d510973a4f..a806dfd9f0a6c 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
@@ -25,52 +25,67 @@ func.func @test(%a: memref<8x16xf16>, %b : memref<16x16xf16>, %c : memref<8x16xf
   return
 }
 
-
 // -----
-func.func @test(%arg0: vector<8x16xf16>, %arg1: vector<16x16xf16>) -> vector<8x16xf32> {
-  %0 = xegpu.dpas %arg0, %arg1 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
-  return %0 : vector<8x16xf32>
+func.func @test(%a: memref<8x16xf16>, %b : memref<16x16xf16>, %c : memref<8x16xf32>) {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant dense<0.0> : vector<8x16xf32>
+  %0 = xegpu.create_nd_tdesc %a[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
+  %1 = xegpu.create_nd_tdesc %b[%c0, %c0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
+  %2 = xegpu.load_nd %0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+  %3 = xegpu.load_nd %1 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+  %6 = vector.transpose %3, [1, 0] : vector<16x16xf16> to vector<16x16xf16>
+  %4 = xegpu.dpas %2, %6, %cst : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+  %5 = xegpu.create_nd_tdesc %c[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+  xegpu.store_nd %4, %5 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+  return
 }
 
 // -----
-func.func @test(%arg0: vector<8x32xi8>, %arg1: vector<32x32xi8>) -> vector<8x32xi32> {
-  %0 = xegpu.dpas %arg0, %arg1 : vector<8x32xi8>, vector<32x32xi8> -> vector<8x32xi32>
-  return %0 : vector<8x32xi32>
+func.func @test(%arg0: vector<8x32xi8>, %arg1: vector<32x16xi8>, %c : memref<8x16xi32>) {
+  %c0 = arith.constant 0 : index
+  %0 = xegpu.dpas %arg0, %arg1 : vector<8x32xi8>, vector<32x16xi8> -> vector<8x16xi32>
+  %1 = xegpu.create_nd_tdesc %c [%c0, %c0] : memref<8x16xi32> -> !xegpu.tensor_desc<8x16xi32>
+  xegpu.store_nd %0, %1 : vector<8x16xi32>, !xegpu.tensor_desc<8x16xi32>
+  return
 }
 
 // -----
 func.func @test(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>) -> vector<8x16xf32> {
   %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
   %1 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
-  %2 = xegpu.dpas %0, %1 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
-  return %2 : vector<8x16xf32>
-}
-
-// -----
-func.func @test(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>) -> (vector<8x16xf32>, vector<8x16xf32>) {
-  %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
-  %1 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
-  %2 = xegpu.dpas %0, %1 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
-  %3 = xegpu.dpas %0, %1 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
-  return %2, %3 : vector<8x16xf32>, vector<8x16xf32>
+  %2 = arith.extf %1 : vector<16x16xf16> to vector<16x16xf32>
+  %3 = arith.truncf %2 : vector<16x16xf32> to vector<16x16xf16>
+  %4 = xegpu.dpas %0, %3 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+  return %4 : vector<8x16xf32>
 }
 
 // -----
-func.func @test(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>) -> (vector<8x16xf32>, vector<16x16xf16>) {
-  %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
-  %1 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
-  %2 = xegpu.dpas %0, %1 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
-  return %2, %1 : vector<8x16xf32>, vector<16x16xf16>
+func.func @test_bitcast(%arg0: memref<8x16xi16>, %arg1: memref<32x16xi8>, %c : memref<8x16xi32>) {
+  %c0 = arith.constant 0 : index
+  %2 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xi16> -> !xegpu.tensor_desc<8x16xi16>
+  %3 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<32x16xi8> -> !xegpu.tensor_desc<32x16xi8>
+  %4 = xegpu.load_nd %2 : !xegpu.tensor_desc<8x16xi16> -> vector<8x16xi16>
+  %5 = xegpu.load_nd %3 : !xegpu.tensor_desc<32x16xi8> -> vector<32x16xi8>
+  %6 = vector.bitcast %4 : vector<8x16xi16> to vector<8x32xi8>
+  %0 = xegpu.dpas %6, %5 : vector<8x32xi8>, vector<32x16xi8> -> vector<8x16xi32>
+  %1 = xegpu.create_nd_tdesc %c [%c0, %c0] : memref<8x16xi32> -> !xegpu.tensor_desc<8x16xi32>
+  xegpu.store_nd %0, %1 : vector<8x16xi32>, !xegpu.tensor_desc<8x16xi32>
+  return
 }
 
 // -----
-func.func @test(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>) -> vector<8x16xf32> {
-  %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
-  %1 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
-  %2 = arith.extf %1 : vector<16x16xf16> to vector<16x16xf32>
-  %3 = arith.truncf %2 : vector<16x16xf32> to vector<16x16xf16>
-  %4 = xegpu.dpas %0, %3 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
-  return %4 : vector<8x16xf32>
+func.func @test_bitcast(%arg0: memref<8x32xi8>, %arg1: memref<16x32xi8>, %c : memref<8x16xf32>) {
+  %c0 = arith.constant 0 : index
+  %2 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x32xi8> -> !xegpu.tensor_desc<8x32xi8>
+  %3 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<16x32xi8> -> !xegpu.tensor_desc<16x32xi8>
+  %4 = xegpu.load_nd %2 : !xegpu.tensor_desc<8x32xi8> -> vector<8x32xi8>
+  %5 = xegpu.load_nd %3 : !xegpu.tensor_desc<16x32xi8> -> vector<16x32xi8>
+  %6 = vector.bitcast %4 : vector<8x32xi8> to vector<8x16xf16>
+  %7 = vector.bitcast %5 : vector<16x32xi8> to vector<16x16xf16>
+  %0 = xegpu.dpas %6, %7 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+  %1 = xegpu.create_nd_tdesc %c [%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+  xegpu.store_nd %0, %1 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+  return
 }
 
 // -----

>From 38a388e013ff0f98fa1d24f11c6bf24b6baa649a Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Wed, 5 Mar 2025 20:40:20 +0000
Subject: [PATCH 05/28] save work

---
 .../Transforms/XeGPUSubgroupDistribute.cpp    | 76 ++++++++++++++++++-
 .../XeGPU/subgroup-map-propagation.mlir       | 15 ++++
 2 files changed, 87 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 4f6091d7bebfb..eaf87087c6fe2 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -161,13 +161,26 @@ class SGMapPropagation : public SparseBackwardDataFlowAnalysis<SGMapLattice> {
   void visitLoadNdOp(xegpu::LoadNdOp load, ArrayRef<SGMapLattice *> operands,
                      ArrayRef<const SGMapLattice *> results);
 
+  void visitLoadGatherOp(xegpu::LoadGatherOp load,
+                         ArrayRef<SGMapLattice *> operands,
+                         ArrayRef<const SGMapLattice *> results);
+
   void visitTransposeOp(vector::TransposeOp transpose,
                         ArrayRef<SGMapLattice *> operands,
                         ArrayRef<const SGMapLattice *> results);
+
   void visitVectorBitcastOp(vector::BitCastOp bitcast,
                             ArrayRef<SGMapLattice *> operands,
                             ArrayRef<const SGMapLattice *> results);
 
+  void visitCreateNdDescOp(xegpu::CreateNdDescOp createNdDesc,
+                           ArrayRef<SGMapLattice *> operands,
+                           ArrayRef<const SGMapLattice *> results);
+
+  void visitCreateDescOp(xegpu::CreateDescOp createDesc,
+                         ArrayRef<SGMapLattice *> operands,
+                         ArrayRef<const SGMapLattice *> results);
+
 public:
   SGMapPropagation(DataFlowSolver &solver, SymbolTableCollection &symbolTable)
       : SparseBackwardDataFlowAnalysis(solver, symbolTable) {}
@@ -204,6 +217,12 @@ SGMapPropagation::visitOperation(Operation *op,
     visitTransposeOp(transpose, operands, results);
   else if (auto bitcast = dyn_cast<vector::BitCastOp>(op))
     visitVectorBitcastOp(bitcast, operands, results);
+  else if (auto loadGather = dyn_cast<xegpu::LoadGatherOp>(op))
+    visitLoadGatherOp(loadGather, operands, results);
+  else if (auto createNdDesc = dyn_cast<xegpu::CreateNdDescOp>(op))
+    visitCreateNdDescOp(createNdDesc, operands, results);
+  else if (auto createDesc = dyn_cast<xegpu::CreateDescOp>(op))
+    visitCreateDescOp(createDesc, operands, results);
   /// All other ops
   else {
     for (const SGMapLattice *r : results) {
@@ -215,7 +234,8 @@ SGMapPropagation::visitOperation(Operation *op,
     }
   }
   /// Add a dependency from each reult to program point after the operation.
-  /// NOTE: not sure if this is required, but all other passes do this.
+  /// NOTE: not sure if this is required, but all other similar analysis do
+  /// this.
   for (const SGMapLattice *r : results) {
     addDependency(const_cast<SGMapLattice *>(r), getProgramPointAfter(op));
   }
@@ -289,19 +309,18 @@ void SGMapPropagation::visitVectorBitcastOp(
       bitcast.getResultVectorType().getElementType().getIntOrFloatBitWidth();
 
   /// WiLayout does not change.
-  WiLayout newWiLayout = resultLayout.getLayout();
+  const WiLayout &newWiLayout = resultLayout.getLayout();
+  const WiData &currData = resultLayout.getData();
   WiData newWiData;
   /// It's a widening bitcast
   if (inElemTyBitWidth < outElemTyBitWidth) {
     auto ratio = outElemTyBitWidth / inElemTyBitWidth;
-    const auto &currData = resultLayout.getData();
     newWiData = resultLayout.getData()[0] == 1
                     ? WiData({1, currData[1] * ratio})
                     : WiData({currData[0] * ratio, 1});
   } else {
     /// It's a narrowing bitcast
     auto ratio = inElemTyBitWidth / outElemTyBitWidth;
-    const auto &currData = resultLayout.getData();
     newWiData = resultLayout.getData()[0] == 1
                     ? WiData({1, currData[1] / ratio})
                     : WiData({currData[0] / ratio, 1});
@@ -311,6 +330,55 @@ void SGMapPropagation::visitVectorBitcastOp(
                      operands[0]->meet(SGMap(newWiLayout, newWiData)));
 }
 
+/// Backward transfer function for xegpu.load (gather): propagate the loaded
+/// value's sg_map to the tensor descriptor and mask operands.
+void SGMapPropagation::visitLoadGatherOp(
+    xegpu::LoadGatherOp load, ArrayRef<SGMapLattice *> operands,
+    ArrayRef<const SGMapLattice *> results) {
+  auto valueLayout = results[0]->getValue();
+  /// Need the layout of the value to propagate to the tensor descriptor.
+  if (!valueLayout.isAssigned())
+    return;
+  /// Only a load carrying the `transpose` attribute has the transpose effect;
+  /// in that case the tensor descriptor gets the transposed value layout,
+  /// otherwise it shares the value layout.
+  SGMap tensorDescLayout = load.getTranspose()
+                               ? valueLayout.getTransposedLayout({1, 0})
+                               : valueLayout;
+  /// Mask operand should have the same layout as the value but with wi_data =
+  /// [1, 1]
+  SGMap maskLayout = SGMap(valueLayout.getLayout(), WiData({1, 1}));
+  /// Propagate the new layout to the tensor descriptor operand.
+  propagateIfChanged(operands[0], operands[0]->meet(tensorDescLayout));
+  /// Propagate the new layout to the mask operand.
+  propagateIfChanged(operands[1], operands[1]->meet(maskLayout));
+}
+
+void SGMapPropagation::visitCreateNdDescOp(
+    xegpu::CreateNdDescOp createNdDesc, ArrayRef<SGMapLattice *> operands,
+    ArrayRef<const SGMapLattice *> results) {
+  auto descLayout = results[0]->getValue();
+  /// Need the layout of the descriptor to propagate to the operands.
+  if (!descLayout.isAssigned())
+    return;
+  /// Propagate the layout to the source operand.
+  propagateIfChanged(operands[0], operands[0]->meet(descLayout));
+  /// For all other operands propagate the same layout with wi_data = [1, 1]
+  SGMap layout = SGMap(descLayout.getLayout(), WiData({1, 1}));
+  for (size_t i = 1; i < operands.size(); ++i) {
+    propagateIfChanged(operands[i], operands[i]->meet(layout));
+  }
+}
+
+void SGMapPropagation::visitCreateDescOp(
+    xegpu::CreateDescOp createDesc, ArrayRef<SGMapLattice *> operands,
+    ArrayRef<const SGMapLattice *> results) {
+  auto descLayout = results[0]->getValue();
+  /// Need the layout of the descriptor to propagate to the operands.
+  if (!descLayout.isAssigned())
+    return;
+  /// Propagate the layout to the source operand.
+  propagateIfChanged(operands[0], operands[0]->meet(descLayout));
+  /// For offset operand propagate the same layout with wi_data = [1, 1]
+  SGMap layout = SGMap(descLayout.getLayout(), WiData({1, 1}));
+  propagateIfChanged(operands[1], operands[1]->meet(layout));
+}
+
 namespace {
 
 class RunSGMapPropagation {
diff --git a/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir b/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
index a806dfd9f0a6c..dfe9c75c36ed6 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
@@ -59,6 +59,21 @@ func.func @test(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<1
   return %4 : vector<8x16xf32>
 }
 
+// -----
+func.func @test_load_gather(%a: memref<8x16xf16>, %b : memref<256xf16>, %c : memref<8x16xf32>) {
+  %c0 = arith.constant 0 : index
+  %6 = xegpu.create_nd_tdesc %a[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
+  %7 = xegpu.load_nd %6 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+  %0 = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
+  %1 = arith.constant dense<1>: vector<16xi1>
+  %2 = xegpu.create_tdesc %b, %0 : memref<256xf16>, vector<16xindex> -> !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16>>
+  %3 = xegpu.load %2, %1 <{transpose}> : !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16>>, vector<16xi1> -> vector<16x16xf16>
+  %4 = xegpu.dpas %7, %3 : vector<8x16xf16>, vector<16x16xf16>-> vector<8x16xf32>
+  %5 = xegpu.create_nd_tdesc %c[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+  xegpu.store_nd %4, %5 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+  return
+}
+
 // -----
 func.func @test_bitcast(%arg0: memref<8x16xi16>, %arg1: memref<32x16xi8>, %c : memref<8x16xi32>) {
   %c0 = arith.constant 0 : index

>From 8996f11f39931c9540572a630f9e726d810cd5fa Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Wed, 5 Mar 2025 22:07:34 +0000
Subject: [PATCH 06/28] save work

---
 .../Transforms/XeGPUSubgroupDistribute.cpp    | 56 ++++++++++++++++---
 .../XeGPU/subgroup-map-propagation.mlir       | 36 ++++++++----
 2 files changed, 72 insertions(+), 20 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index eaf87087c6fe2..871138b8dbc16 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -158,6 +158,10 @@ class SGMapPropagation : public SparseBackwardDataFlowAnalysis<SGMapLattice> {
   void visitStoreNdOp(xegpu::StoreNdOp store, ArrayRef<SGMapLattice *> operands,
                       ArrayRef<const SGMapLattice *> results);
 
+  void visitStoreScatterOp(xegpu::StoreScatterOp storeScatter,
+                           ArrayRef<SGMapLattice *> operands,
+                           ArrayRef<const SGMapLattice *> results);
+
   void visitLoadNdOp(xegpu::LoadNdOp load, ArrayRef<SGMapLattice *> operands,
                      ArrayRef<const SGMapLattice *> results);
 
@@ -223,6 +227,8 @@ SGMapPropagation::visitOperation(Operation *op,
     visitCreateNdDescOp(createNdDesc, operands, results);
   else if (auto createDesc = dyn_cast<xegpu::CreateDescOp>(op))
     visitCreateDescOp(createDesc, operands, results);
+  else if (auto storeScatter = dyn_cast<xegpu::StoreScatterOp>(op))
+    visitStoreScatterOp(storeScatter, operands, results);
   /// All other ops
   else {
     for (const SGMapLattice *r : results) {
@@ -379,6 +385,22 @@ void SGMapPropagation::visitCreateDescOp(
   propagateIfChanged(operands[1], operands[1]->meet(layout));
 }
 
+/// Backward transfer function for xegpu.store (scatter): seed the value,
+/// tensor descriptor and mask operands with layouts derived from the default
+/// sg_map of the descriptor's element type.
+void SGMapPropagation::visitStoreScatterOp(
+    xegpu::StoreScatterOp storeScatter, ArrayRef<SGMapLattice *> operands,
+    ArrayRef<const SGMapLattice *> results) {
+  /// The stored value gets the default layout for the element type.
+  auto valueLayout =
+      getDefaultSgMap(storeScatter.getTensorDescType().getElementType());
+  /// Only a store carrying the `transpose` attribute has the transpose effect;
+  /// in that case the tensor descriptor gets the transposed value layout.
+  SGMap storeScatterLayout = storeScatter.getTranspose()
+                                 ? valueLayout.getTransposedLayout({1, 0})
+                                 : valueLayout;
+  /// The mask uses the value's wi_layout with wi_data = [1, 1], matching the
+  /// mask convention used for LoadGatherOp.
+  SGMap maskLayout = SGMap(valueLayout.getLayout(), WiData({1, 1}));
+  /// Propagate the value layout.
+  propagateIfChanged(operands[0], operands[0]->meet(valueLayout));
+  /// Propagate the tensor descriptor layout.
+  propagateIfChanged(operands[1], operands[1]->meet(storeScatterLayout));
+  /// Propagate the mask layout.
+  propagateIfChanged(operands[2], operands[2]->meet(maskLayout));
+}
+
 namespace {
 
 class RunSGMapPropagation {
@@ -411,20 +433,25 @@ SGMap RunSGMapPropagation::getSGMap(Value val) {
 void RunSGMapPropagation::printAnalysisResult(llvm::raw_ostream &os) {
   if (auto modOp = dyn_cast<ModuleOp>(target)) {
     for (auto funcOp : modOp.getOps<func::FuncOp>()) {
-      os << "sg_map for " << funcOp.getName() << ":\n";
-      // Function args
+      os << "function: " << funcOp.getName() << ":\n";
+      // Function arguments
       for (auto arg : funcOp.getArguments()) {
-        auto layouts = getSGMap(arg);
-        os << "sg_map for " << arg << ": ";
-        layouts.print(os);
+        auto layout = getSGMap(arg);
+        os << "argument: " << arg << "\n";
+        os << "sg_map  : ";
+        layout.print(os);
         os << "\n";
       }
       // Function ops
       funcOp.walk([&](Operation *op) {
+        // Skip ops that do not have results
         if (op->getResults().empty())
           return;
         auto layouts = getSGMap(op->getResult(0));
-        os << "sg_map for " << op->getName() << ": ";
+        os << "op    : ";
+        op->print(os);
+        os << "\n";
+        os << "sg_map: ";
         layouts.print(os);
         os << "\n";
       });
@@ -436,7 +463,18 @@ namespace {
 struct XeGPUSubgroupDistributePass final
     : public xegpu::impl::XeGPUSubgroupDistributeBase<
           XeGPUSubgroupDistributePass> {
+  XeGPUSubgroupDistributePass() = default;
+  XeGPUSubgroupDistributePass(const XeGPUSubgroupDistributePass &other)
+      : xegpu::impl::XeGPUSubgroupDistributeBase<XeGPUSubgroupDistributePass>(
+            other) {
+    this->printOnly = other.printOnly;
+  }
   void runOnOperation() override;
+  /// Print sg map propagation analysis result and exit for testing purposes.
+  Option<bool> printOnly{
+      *this, "print-only", llvm::cl::init(false),
+      llvm::cl::desc(
+          "Print the result of the subgroup map propagation analysis")};
 };
 } // namespace
 
@@ -446,8 +484,10 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
   RunSGMapPropagation solver(op);
 
   // Print analysis results
-  auto &os = llvm::outs();
-  solver.printAnalysisResult(os);
+  if (printOnly) {
+    auto &os = llvm::outs();
+    solver.printAnalysisResult(os);
+  }
 }
 
 void xegpu::populateXeGPUSubgroupDistributePatterns(
diff --git a/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir b/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
index dfe9c75c36ed6..76a44fe16440b 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
@@ -1,4 +1,4 @@
-func.func @test(%a: memref<8x16xf16>, %b : memref<16x16xf16>, %c : memref<8x16xf32>) {
+func.func @test_dpas_1(%a: memref<8x16xf16>, %b : memref<16x16xf16>, %c : memref<8x16xf32>) {
   %c0 = arith.constant 0 : index
   %cst = arith.constant dense<0.0> : vector<8x16xf32>
   %0 = xegpu.create_nd_tdesc %a[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
@@ -12,7 +12,16 @@ func.func @test(%a: memref<8x16xf16>, %b : memref<16x16xf16>, %c : memref<8x16xf
 }
 
 // -----
-func.func @test(%a: memref<8x16xf16>, %b : memref<16x16xf16>, %c : memref<8x16xf32>) {
+func.func @test_dpas_2(%arg0: vector<8x32xi8>, %arg1: vector<32x16xi8>, %c : memref<8x16xi32>) {
+  %c0 = arith.constant 0 : index
+  %0 = xegpu.dpas %arg0, %arg1 : vector<8x32xi8>, vector<32x16xi8> -> vector<8x16xi32>
+  %1 = xegpu.create_nd_tdesc %c [%c0, %c0] : memref<8x16xi32> -> !xegpu.tensor_desc<8x16xi32>
+  xegpu.store_nd %0, %1 : vector<8x16xi32>, !xegpu.tensor_desc<8x16xi32>
+  return
+}
+
+// -----
+func.func @test_transpose_1(%a: memref<8x16xf16>, %b : memref<16x16xf16>, %c : memref<8x16xf32>) {
   %c0 = arith.constant 0 : index
   %cst = arith.constant dense<0.0> : vector<8x16xf32>
   %0 = xegpu.create_nd_tdesc %a[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
@@ -26,7 +35,7 @@ func.func @test(%a: memref<8x16xf16>, %b : memref<16x16xf16>, %c : memref<8x16xf
 }
 
 // -----
-func.func @test(%a: memref<8x16xf16>, %b : memref<16x16xf16>, %c : memref<8x16xf32>) {
+func.func @test_transpose_2(%a: memref<8x16xf16>, %b : memref<16x16xf16>, %c : memref<8x16xf32>) {
   %c0 = arith.constant 0 : index
   %cst = arith.constant dense<0.0> : vector<8x16xf32>
   %0 = xegpu.create_nd_tdesc %a[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
@@ -40,17 +49,10 @@ func.func @test(%a: memref<8x16xf16>, %b : memref<16x16xf16>, %c : memref<8x16xf
   return
 }
 
-// -----
-func.func @test(%arg0: vector<8x32xi8>, %arg1: vector<32x16xi8>, %c : memref<8x16xi32>) {
-  %c0 = arith.constant 0 : index
-  %0 = xegpu.dpas %arg0, %arg1 : vector<8x32xi8>, vector<32x16xi8> -> vector<8x16xi32>
-  %1 = xegpu.create_nd_tdesc %c [%c0, %c0] : memref<8x16xi32> -> !xegpu.tensor_desc<8x16xi32>
-  xegpu.store_nd %0, %1 : vector<8x16xi32>, !xegpu.tensor_desc<8x16xi32>
-  return
-}
+
 
 // -----
-func.func @test(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>) -> vector<8x16xf32> {
+func.func @test_extf_truncf(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>) -> vector<8x16xf32> {
   %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
   %1 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
   %2 = arith.extf %1 : vector<16x16xf16> to vector<16x16xf32>
@@ -74,6 +76,16 @@ func.func @test_load_gather(%a: memref<8x16xf16>, %b : memref<256xf16>, %c : mem
   return
 }
 
+// -----
+func.func @test_store_scatter(%c: memref<128xf32>){
+  %cst = arith.constant dense<1.000000e+00> : vector<8x16xf32>
+  %1 = arith.constant dense<1> : vector<16xi1>
+  %2 = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
+  %3 = xegpu.create_tdesc %c, %2 : memref<128xf32>, vector<16xindex> -> !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size = 8>>
+  xegpu.store %cst, %3, %1 <{transpose}> : vector<8x16xf32>, !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size = 8>>, vector<16xi1>
+  return
+}
+
 // -----
 func.func @test_bitcast(%arg0: memref<8x16xi16>, %arg1: memref<32x16xi8>, %c : memref<8x16xi32>) {
   %c0 = arith.constant 0 : index

>From 47888451211e237cb1f0c2b443552921a73805cc Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Thu, 6 Mar 2025 05:20:50 +0000
Subject: [PATCH 07/28] save work

---
 .../Transforms/XeGPUSubgroupDistribute.cpp    | 23 +++--
 .../XeGPU/subgroup-map-propagation.mlir       | 97 ++++++++++---------
 2 files changed, 64 insertions(+), 56 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 871138b8dbc16..4182f2176f04b 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -93,12 +93,12 @@ struct SGMap {
 
 void SGMap::print(raw_ostream &os) const {
   if (isAssigned()) {
-    os << "Layout: ";
+    os << "wi_layout: ";
     layout.print(os);
-    os << ", Data: ";
+    os << ", wi_data: ";
     data.print(os);
   } else
-    os << "Not initialized";
+    os << "Not assigned.";
 }
 
 SGMap SGMap::meet(const SGMap &lhs, const SGMap &rhs) {
@@ -447,13 +447,20 @@ void RunSGMapPropagation::printAnalysisResult(llvm::raw_ostream &os) {
         // Skip ops that do not have results
         if (op->getResults().empty())
           return;
-        auto layouts = getSGMap(op->getResult(0));
         os << "op    : ";
-        op->print(os);
-        os << "\n";
-        os << "sg_map: ";
-        layouts.print(os);
+        /// For control-flow ops, print the op name only.
+        if (isa<BranchOpInterface>(op) || isa<RegionBranchOpInterface>(op))
+          os << op->getName();
+        else
+          op->print(os);
         os << "\n";
+        /// Print the sg_map for each result.
+        for (auto [i, r] : llvm::enumerate(op->getResults())) {
+          auto layout = getSGMap(r);
+          os << "sg_map for result #" << i << ": ";
+          layout.print(os);
+          os << "\n";
+        }
       });
     }
   }
diff --git a/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir b/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
index 76a44fe16440b..81c4244091e87 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
@@ -1,4 +1,4 @@
-func.func @test_dpas_1(%a: memref<8x16xf16>, %b : memref<16x16xf16>, %c : memref<8x16xf32>) {
+func.func @test_dpas_op_1(%a: memref<8x16xf16>, %b : memref<16x16xf16>, %c : memref<8x16xf32>) {
   %c0 = arith.constant 0 : index
   %cst = arith.constant dense<0.0> : vector<8x16xf32>
   %0 = xegpu.create_nd_tdesc %a[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
@@ -12,7 +12,7 @@ func.func @test_dpas_1(%a: memref<8x16xf16>, %b : memref<16x16xf16>, %c : memref
 }
 
 // -----
-func.func @test_dpas_2(%arg0: vector<8x32xi8>, %arg1: vector<32x16xi8>, %c : memref<8x16xi32>) {
+func.func @test_dpas_op_2(%arg0: vector<8x32xi8>, %arg1: vector<32x16xi8>, %c : memref<8x16xi32>) {
   %c0 = arith.constant 0 : index
   %0 = xegpu.dpas %arg0, %arg1 : vector<8x32xi8>, vector<32x16xi8> -> vector<8x16xi32>
   %1 = xegpu.create_nd_tdesc %c [%c0, %c0] : memref<8x16xi32> -> !xegpu.tensor_desc<8x16xi32>
@@ -21,7 +21,7 @@ func.func @test_dpas_2(%arg0: vector<8x32xi8>, %arg1: vector<32x16xi8>, %c : mem
 }
 
 // -----
-func.func @test_transpose_1(%a: memref<8x16xf16>, %b : memref<16x16xf16>, %c : memref<8x16xf32>) {
+func.func @test_transpose_op_1(%a: memref<8x16xf16>, %b : memref<16x16xf16>, %c : memref<8x16xf32>) {
   %c0 = arith.constant 0 : index
   %cst = arith.constant dense<0.0> : vector<8x16xf32>
   %0 = xegpu.create_nd_tdesc %a[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
@@ -35,7 +35,7 @@ func.func @test_transpose_1(%a: memref<8x16xf16>, %b : memref<16x16xf16>, %c : m
 }
 
 // -----
-func.func @test_transpose_2(%a: memref<8x16xf16>, %b : memref<16x16xf16>, %c : memref<8x16xf32>) {
+func.func @test_transpose_op_2(%a: memref<8x16xf16>, %b : memref<16x16xf16>, %c : memref<8x16xf32>) {
   %c0 = arith.constant 0 : index
   %cst = arith.constant dense<0.0> : vector<8x16xf32>
   %0 = xegpu.create_nd_tdesc %a[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
@@ -52,7 +52,7 @@ func.func @test_transpose_2(%a: memref<8x16xf16>, %b : memref<16x16xf16>, %c : m
 
 
 // -----
-func.func @test_extf_truncf(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>) -> vector<8x16xf32> {
+func.func @test_extf_truncf_op(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>) -> vector<8x16xf32> {
   %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
   %1 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
   %2 = arith.extf %1 : vector<16x16xf16> to vector<16x16xf32>
@@ -62,7 +62,7 @@ func.func @test_extf_truncf(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.t
 }
 
 // -----
-func.func @test_load_gather(%a: memref<8x16xf16>, %b : memref<256xf16>, %c : memref<8x16xf32>) {
+func.func @test_load_gather_op(%a: memref<8x16xf16>, %b : memref<256xf16>, %c : memref<8x16xf32>) {
   %c0 = arith.constant 0 : index
   %6 = xegpu.create_nd_tdesc %a[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
   %7 = xegpu.load_nd %6 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
@@ -77,7 +77,7 @@ func.func @test_load_gather(%a: memref<8x16xf16>, %b : memref<256xf16>, %c : mem
 }
 
 // -----
-func.func @test_store_scatter(%c: memref<128xf32>){
+func.func @test_store_scatter_op(%c: memref<128xf32>){
   %cst = arith.constant dense<1.000000e+00> : vector<8x16xf32>
   %1 = arith.constant dense<1> : vector<16xi1>
   %2 = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
@@ -87,7 +87,7 @@ func.func @test_store_scatter(%c: memref<128xf32>){
 }
 
 // -----
-func.func @test_bitcast(%arg0: memref<8x16xi16>, %arg1: memref<32x16xi8>, %c : memref<8x16xi32>) {
+func.func @test_bitcast_op_1(%arg0: memref<8x16xi16>, %arg1: memref<32x16xi8>, %c : memref<8x16xi32>) {
   %c0 = arith.constant 0 : index
   %2 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xi16> -> !xegpu.tensor_desc<8x16xi16>
   %3 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<32x16xi8> -> !xegpu.tensor_desc<32x16xi8>
@@ -101,7 +101,7 @@ func.func @test_bitcast(%arg0: memref<8x16xi16>, %arg1: memref<32x16xi8>, %c : m
 }
 
 // -----
-func.func @test_bitcast(%arg0: memref<8x32xi8>, %arg1: memref<16x32xi8>, %c : memref<8x16xf32>) {
+func.func @test_bitcast_op_2(%arg0: memref<8x32xi8>, %arg1: memref<16x32xi8>, %c : memref<8x16xf32>) {
   %c0 = arith.constant 0 : index
   %2 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x32xi8> -> !xegpu.tensor_desc<8x32xi8>
   %3 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<16x32xi8> -> !xegpu.tensor_desc<16x32xi8>
@@ -116,53 +116,51 @@ func.func @test_bitcast(%arg0: memref<8x32xi8>, %arg1: memref<16x32xi8>, %c : me
 }
 
 // -----
-func.func @test(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>) -> vector<8x16xf32> {
+func.func @test_binary_op_1(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: !xegpu.tensor_desc<8x16xf32>) {
   %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
   %1 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
   %2 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
   %3 = arith.addf %1, %2 : vector<16x16xf16>
   %4 = xegpu.dpas %0, %3 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
-  return %4 : vector<8x16xf32>
+  xegpu.store_nd %4, %arg2 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+  return
 }
 
 // -----
-func.func @test(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>) -> vector<8x16xf32> {
+func.func @test_binary_op_2(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: !xegpu.tensor_desc<8x16xf32>, %arg3: !xegpu.tensor_desc<16x16xf16>) {
   %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
   %1 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
   %cst = arith.constant dense<1.000000e+00> : vector<16x16xf16>
   %2 = arith.addf %1, %cst : vector<16x16xf16>
   %3 = xegpu.dpas %0, %2 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
-  return %3 : vector<8x16xf32>
-}
-
-// -----
-func.func @test(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>) -> (vector<8x16xf32>, vector<16x16xf16>) {
-  %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
-  %1 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
-  %2 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
-  %3 = arith.addf %1, %2 : vector<16x16xf16>
-  %4 = xegpu.dpas %0, %3 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
-  return %4, %2 : vector<8x16xf32>, vector<16x16xf16>
+  xegpu.store_nd %3, %arg2 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+  xegpu.store_nd %2, %arg3 : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16>
+  return
 }
 
 // -----
-func.func @test(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: index) -> vector<8x16xf32> {
+func.func @test_for_op(%a: memref<8x128xf16>, %b : memref<128x16xf16>, %c : memref<8x16xf32>) {
   %c0 = arith.constant 0 : index
-  %c1 = arith.constant 1 : index
-  %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
-  %1 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
-  %2 = xegpu.dpas %0, %1 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
-  %3:2 = scf.for %arg3 = %c0 to %arg2 step %c1 iter_args(%arg4 = %1, %arg5 = %2) -> (vector<16x16xf16>, vector<8x16xf32>) {
-    %4 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
-    %5 = arith.addf %arg4, %4 : vector<16x16xf16>
-    %6 = xegpu.dpas %0, %5 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
-    scf.yield %5, %6 : vector<16x16xf16>, vector<8x16xf32>
+  %c128 = arith.constant 128 : index
+  %c16 = arith.constant 16 : index
+  %0 = xegpu.create_nd_tdesc %a[%c0, %c0] : memref<8x128xf16> -> !xegpu.tensor_desc<8x16xf16>
+  %1 = xegpu.create_nd_tdesc %b[%c0, %c0] : memref<128x16xf16> -> !xegpu.tensor_desc<16x16xf16>
+  %2 = arith.constant dense<0.0> : vector<8x16xf32>
+  %3:3 = scf.for %k = %c0 to %c128 step %c16 iter_args(%arg0 = %0, %arg1 = %1, %arg2 = %2) -> (!xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<16x16xf16>, vector<8x16xf32>) {
+    %4 = xegpu.load_nd %arg0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+    %5 = xegpu.load_nd %arg1 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+    %6 = xegpu.dpas %4, %5, %arg2 : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+    %7 = xegpu.update_nd_offset %arg0, [%c0, %c16] : !xegpu.tensor_desc<8x16xf16>
+    %8 = xegpu.update_nd_offset %arg1, [%c16, %c0] : !xegpu.tensor_desc<16x16xf16>
+    scf.yield %7, %8, %6 : !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<16x16xf16>, vector<8x16xf32>
   }
-  return %3#1 : vector<8x16xf32>
+  %9 = xegpu.create_nd_tdesc %c[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+  xegpu.store_nd %3#2, %9 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+  return
 }
 
 // -----
-func.func @test(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: i1) -> vector<8x16xf32> {
+func.func @test_if_op_1(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: i1, %arg3: !xegpu.tensor_desc<8x16xf32>){
   %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
   %1 = scf.if %arg2 -> (vector<16x16xf16>) {
     %3 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
@@ -172,11 +170,12 @@ func.func @test(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<1
     scf.yield %3 : vector<16x16xf16>
   }
   %2 = xegpu.dpas %0, %1 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
-  return %2 : vector<8x16xf32>
+  xegpu.store_nd %2, %arg3 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+  return
 }
 
 // -----
-func.func @test(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: i1) -> (vector<8x16xf32>, vector<16x16xf16>) {
+func.func @test_if_op_2(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: i1, %arg3: !xegpu.tensor_desc<8x16xf32>, %arg4: !xegpu.tensor_desc<16x16xf16>){
   %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
   %1 = scf.if %arg2 -> (vector<16x16xf16>) {
     %3 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
@@ -186,7 +185,9 @@ func.func @test(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<1
     scf.yield %3 : vector<16x16xf16>
   }
   %2 = xegpu.dpas %0, %1 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
-  return %2, %1 : vector<8x16xf32>, vector<16x16xf16>
+  xegpu.store_nd %2, %arg3 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+  xegpu.store_nd %1, %arg4 : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16>
+  return
 }
 
 // -----
@@ -202,16 +203,6 @@ func.func @test(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<1
   return %2 : vector<8x16xf32>
 }
 
-// -----
-func.func @test(%arg0: !xegpu.tensor_desc<8x16xi16>, %arg1: !xegpu.tensor_desc<16x16xi16>) -> vector<8x16xf32> {
-  %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xi16> -> vector<8x16xi16>
-  %1 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xi16> -> vector<16x16xi16>
-  %2 = arith.bitcast %0 : vector<8x16xi16> to vector<8x16xf16>
-  %3 = arith.bitcast %1 : vector<16x16xi16> to vector<16x16xf16>
-  %4 = xegpu.dpas %2, %3 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
-  return %4 : vector<8x16xf32>
-}
-
 // -----
 func.func @test(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: vector<16x16xf16>, %arg3: i1) -> vector<8x16xf32> {
   %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
@@ -226,3 +217,13 @@ func.func @test(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<1
   }
   return %1 : vector<8x16xf32>
 }
+
+// -----
+func.func @test(%arg0: !xegpu.tensor_desc<8x16xi16>, %arg1: !xegpu.tensor_desc<16x16xi16>) -> vector<8x16xf32> {
+  %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xi16> -> vector<8x16xi16>
+  %1 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xi16> -> vector<16x16xi16>
+  %2 = arith.bitcast %0 : vector<8x16xi16> to vector<8x16xf16>
+  %3 = arith.bitcast %1 : vector<16x16xi16> to vector<16x16xf16>
+  %4 = xegpu.dpas %2, %3 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+  return %4 : vector<8x16xf32>
+}

>From c7b60d156c6a91e10fe84146f56ae2efa8a344b5 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Thu, 6 Mar 2025 22:29:27 +0000
Subject: [PATCH 08/28] save work

---
 .../Transforms/XeGPUSubgroupDistribute.cpp    | 169 +++++++++++-------
 .../XeGPU/subgroup-map-propagation.mlir       |  61 ++++---
 2 files changed, 142 insertions(+), 88 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 4182f2176f04b..13030d5c74ea0 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -15,13 +15,7 @@
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
 #include "mlir/Dialect/XeGPU/Transforms/Passes.h"
 #include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
-#include "mlir/IR/BuiltinAttributes.h"
-#include "mlir/Interfaces/DataLayoutInterfaces.h"
-#include "mlir/Support/LLVM.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "llvm/Support/LogicalResult.h"
 #include "llvm/Support/raw_ostream.h"
-#include <algorithm>
 
 namespace mlir {
 namespace xegpu {
@@ -41,6 +35,13 @@ constexpr unsigned packedASizeInBits = 16;
 constexpr unsigned packedBSizeInBits = 32;
 
 namespace {
+
+///===----------------------------------------------------------------------===///
+/// Layout
+///===----------------------------------------------------------------------===///
+
+/// Helper class to store the ND layout of work items within a subgroup and data
+/// owned by each work item.
 struct Layout {
   SmallVector<int64_t, 3> layout;
   Layout() = default;
@@ -57,9 +58,30 @@ void Layout::print(llvm::raw_ostream &os) const {
   os << "]";
 }
 
+/// WiLayout represents the layout of work items within a subgroup when it
+/// accesses some value. WiData represents the layout of data owned by each work
+/// item.
 using WiLayout = Layout;
 using WiData = Layout;
 
+///===----------------------------------------------------------------------===///
+/// SGMap
+///===----------------------------------------------------------------------===///
+
+/// Helper class for tracking the analysis state of a value. For SGPropagation,
+/// the analysis state is simply the wi_layout and wi_data of each value.
+/// The purpose of this analysis is to propagate a unique layout to each value
+/// in the program, starting from known values (like DPAS, StoreNd, etc.).
+///
+/// Given this, SGMap satisfies the following properties:
+///  1) SGMap is a lattice with two states - assigned and not assigned.
+///  2) Two SGMap values are equal if they are both assigned or both not
+///  assigned. The concrete value of assigned state does not matter.
+///  3) The meet operator works as follows:
+///     - If the current state is assigned, return the current state (a unique
+///     layout has already been assigned, so don't change it).
+///     - Otherwise, return the other state.
+
 struct SGMap {
 private:
   WiLayout layout;
@@ -71,8 +93,8 @@ struct SGMap {
   SGMap(const WiLayout &layout, const WiData &data)
       : layout(layout), data(data) {}
 
-  // Two lattice values are equal if they have `some` layout. The actual
-  // content of the layout does not matter.
+  /// Two lattice values are equal if they have `some` layout. The actual
+  /// content of the layout does not matter.
   bool operator==(const SGMap &other) const {
     return this->isAssigned() == other.isAssigned();
   }
@@ -107,9 +129,9 @@ SGMap SGMap::meet(const SGMap &lhs, const SGMap &rhs) {
   return lhs;
 }
 
+/// Since this is a backward analysis, join method is not used.
 SGMap SGMap::join(const SGMap &lhs, const SGMap &rhs) {
-  // Should not be triggered by this analysis, but required by `Lattice<T>`
-  llvm_unreachable("Join should not be triggered by this test");
+  llvm_unreachable("Join should not be triggered by SGMapPropagation.");
 }
 
 SGMap SGMap::getTransposedLayout(ArrayRef<int64_t> permutation) const {
@@ -124,14 +146,17 @@ SGMap SGMap::getTransposedLayout(ArrayRef<int64_t> permutation) const {
   return SGMap(newLayout, data);
 }
 
+///===----------------------------------------------------------------------===///
+/// SGMapLattice
+///===----------------------------------------------------------------------===///
+
+/// Lattice holding the SGMap for each value.
 struct SGMapLattice : public Lattice<SGMap> {
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SGMapLattice)
   using Lattice::Lattice;
 };
 
-/// Helper Functions
-///
-
+/// Helper Function to get the expected layouts for DPAS operands.
 static SGMap getSGMapForDPASOperand(Type operandTy, unsigned operandNum) {
   int packingFactorForB = packedBSizeInBits / operandTy.getIntOrFloatBitWidth();
   int packingFactorForA =
@@ -143,6 +168,11 @@ static SGMap getSGMapForDPASOperand(Type operandTy, unsigned operandNum) {
                        operandNum == 0 ? packingFactorForA : 1}));
 }
 
+/// Helper function to get the default layout for a given type. Usually this is
+/// wi_layout = [1, subgroupSize] and wi_data = [1, 1].
+/// However, the minimum granularity of data access per work item is 16 bits.
+/// So, if the bitwidth of the type is less than 16, we need to pack the data to
+/// 16 bits.
 static SGMap getDefaultSgMap(Type ty) {
   int packingFactor = 1;
   if (ty.getIntOrFloatBitWidth() < packedASizeInBits)
@@ -150,6 +180,15 @@ static SGMap getDefaultSgMap(Type ty) {
   return SGMap(WiLayout({1, subgroupSize}), WiData({1, packingFactor}));
 }
 
+///===----------------------------------------------------------------------===///
+/// SGMapPropagation
+///===----------------------------------------------------------------------===///
+
+/// Backward data flow analysis to propagate the wi_layout and wi_data of each
+/// value in the program. Currently, the layouts for the operands of DPAS,
+/// StoreNd, and StoreScatter are fixed (known before propagation). The purpose
+/// of this analysis is to propagate those known layouts to all their producers
+/// and (other) consumers.
 class SGMapPropagation : public SparseBackwardDataFlowAnalysis<SGMapLattice> {
 private:
   void visitDpasOp(xegpu::DpasOp dpas, ArrayRef<SGMapLattice *> operands,
@@ -177,14 +216,6 @@ class SGMapPropagation : public SparseBackwardDataFlowAnalysis<SGMapLattice> {
                             ArrayRef<SGMapLattice *> operands,
                             ArrayRef<const SGMapLattice *> results);
 
-  void visitCreateNdDescOp(xegpu::CreateNdDescOp createNdDesc,
-                           ArrayRef<SGMapLattice *> operands,
-                           ArrayRef<const SGMapLattice *> results);
-
-  void visitCreateDescOp(xegpu::CreateDescOp createDesc,
-                         ArrayRef<SGMapLattice *> operands,
-                         ArrayRef<const SGMapLattice *> results);
-
 public:
   SGMapPropagation(DataFlowSolver &solver, SymbolTableCollection &symbolTable)
       : SparseBackwardDataFlowAnalysis(solver, symbolTable) {}
@@ -223,10 +254,6 @@ SGMapPropagation::visitOperation(Operation *op,
     visitVectorBitcastOp(bitcast, operands, results);
   else if (auto loadGather = dyn_cast<xegpu::LoadGatherOp>(op))
     visitLoadGatherOp(loadGather, operands, results);
-  else if (auto createNdDesc = dyn_cast<xegpu::CreateNdDescOp>(op))
-    visitCreateNdDescOp(createNdDesc, operands, results);
-  else if (auto createDesc = dyn_cast<xegpu::CreateDescOp>(op))
-    visitCreateDescOp(createDesc, operands, results);
   else if (auto storeScatter = dyn_cast<xegpu::StoreScatterOp>(op))
     visitStoreScatterOp(storeScatter, operands, results);
   /// All other ops
@@ -248,6 +275,7 @@ SGMapPropagation::visitOperation(Operation *op,
   return success();
 }
 
+/// Set the layouts for DPAS A, B, and C operands.
 void SGMapPropagation::visitDpasOp(xegpu::DpasOp dpas,
                                    ArrayRef<SGMapLattice *> operands,
                                    ArrayRef<const SGMapLattice *> results) {
@@ -264,6 +292,7 @@ void SGMapPropagation::visitDpasOp(xegpu::DpasOp dpas,
   }
 };
 
+/// Set the layout for the value and tensor descriptor operands in StoreNdOp.
 void SGMapPropagation::visitStoreNdOp(xegpu::StoreNdOp store,
                                       ArrayRef<SGMapLattice *> operands,
                                       ArrayRef<const SGMapLattice *> results) {
@@ -275,6 +304,8 @@ void SGMapPropagation::visitStoreNdOp(xegpu::StoreNdOp store,
   }
 }
 
+/// Propagate the layout of the value to the tensor descriptor operand in
+/// LoadNdOp.
 void SGMapPropagation::visitLoadNdOp(xegpu::LoadNdOp load,
                                      ArrayRef<SGMapLattice *> operands,
                                      ArrayRef<const SGMapLattice *> results) {
@@ -283,12 +314,20 @@ void SGMapPropagation::visitLoadNdOp(xegpu::LoadNdOp load,
   if (!valueLayout.isAssigned())
     return;
   SGMap tensorDescLayout = valueLayout;
-  if (auto transpose = load.getTranspose())
+  /// LoadNdOp has the transpose effect. However, at the stage of this analyis
+  /// this effect is not expected and should be abstracted away. Emit a warning.
+  /// TODO: Handle this case properly when `order` is introduced in the sg_map.
+  if (auto transpose = load.getTranspose()) {
+    emitWarning(load.getLoc())
+        << "Transpose effect is not expected for LoadNdOp";
     tensorDescLayout = valueLayout.getTransposedLayout(transpose.value());
+  }
   /// Propagate the new layout to the tensor descriptor operand.
   propagateIfChanged(operands[0], operands[0]->meet(tensorDescLayout));
 }
 
+/// For vector::TransposeOp, the layout of the result is transposed and
+/// propagated to the operand.
 void SGMapPropagation::visitTransposeOp(
     vector::TransposeOp transpose, ArrayRef<SGMapLattice *> operands,
     ArrayRef<const SGMapLattice *> results) {
@@ -302,6 +341,8 @@ void SGMapPropagation::visitTransposeOp(
   propagateIfChanged(operands[0], operands[0]->meet(newLayout));
 }
 
+/// For vector::BitCastOp, the wi_data of the source layout is changed based on
+/// the bit width of the source and result types.
 void SGMapPropagation::visitVectorBitcastOp(
     vector::BitCastOp bitcast, ArrayRef<SGMapLattice *> operands,
     ArrayRef<const SGMapLattice *> results) {
@@ -336,6 +377,8 @@ void SGMapPropagation::visitVectorBitcastOp(
                      operands[0]->meet(SGMap(newWiLayout, newWiData)));
 }
 
+/// Propagate the layout of the result to the tensor descriptor and mask
+/// operands in LoadGatherOp.
 void SGMapPropagation::visitLoadGatherOp(
     xegpu::LoadGatherOp load, ArrayRef<SGMapLattice *> operands,
     ArrayRef<const SGMapLattice *> results) {
@@ -343,56 +386,47 @@ void SGMapPropagation::visitLoadGatherOp(
   /// Need the layout of the value to propagate to the tensor descriptor.
   if (!valueLayout.isAssigned())
     return;
-  /// LoadGatherOp has the transpose effect, so propagate the transposed layout
-  /// to the tensor descriptor.
-  SGMap tensorDescLayout = valueLayout.getTransposedLayout({1, 0});
+
+  SGMap tensorDescLayout;
+  if (load.getTranspose()) {
+    /// LoadGatherOp has the transpose effect. However, at the stage of this
+    /// analyis this effect is not expected and should be abstracted away. Emit
+    /// a warning.
+    /// TODO: Handle this case properly when `order` is introduced in the
+    /// sg_map.
+    emitWarning(load.getLoc())
+        << "Transpose effect is not expected for LoadGatherOp";
+    tensorDescLayout = valueLayout.getTransposedLayout({1, 0});
+  } else
+    tensorDescLayout = valueLayout;
   /// Mask operand should have the same layout as the value but with wi_data =
   /// [1, 1]
-  SGMap maskLayout = SGMap(valueLayout.getLayout(), WiData({1, 1}));
+  SGMap maskLayout = getDefaultSgMap(load.getTensorDescType().getElementType());
   /// Propagate the new layout to the tensor descriptor operand.
   propagateIfChanged(operands[0], operands[0]->meet(tensorDescLayout));
   /// Propagate the new layout to the mask operand.
   propagateIfChanged(operands[1], operands[1]->meet(maskLayout));
 }
 
-void SGMapPropagation::visitCreateNdDescOp(
-    xegpu::CreateNdDescOp createNdDesc, ArrayRef<SGMapLattice *> operands,
-    ArrayRef<const SGMapLattice *> results) {
-  auto descLayout = results[0]->getValue();
-  /// Need the layout of the descriptor to propagate to the operands.
-  if (!descLayout.isAssigned())
-    return;
-  /// Propagate the layout to the source operand.
-  propagateIfChanged(operands[0], operands[0]->meet(descLayout));
-  /// For all other operands propagate the same layout with wi_data = [1, 1]
-  SGMap layout = SGMap(descLayout.getLayout(), WiData({1, 1}));
-  for (size_t i = 1; i < operands.size(); ++i) {
-    propagateIfChanged(operands[i], operands[i]->meet(layout));
-  }
-}
-
-void SGMapPropagation::visitCreateDescOp(
-    xegpu::CreateDescOp createDesc, ArrayRef<SGMapLattice *> operands,
-    ArrayRef<const SGMapLattice *> results) {
-  auto descLayout = results[0]->getValue();
-  /// Need the layout of the descriptor to propagate to the operands.
-  if (!descLayout.isAssigned())
-    return;
-  /// Propagate the layout to the source operand.
-  propagateIfChanged(operands[0], operands[0]->meet(descLayout));
-  /// For offset operand propagate the same layout with wi_data = [1, 1]
-  SGMap layout = SGMap(descLayout.getLayout(), WiData({1, 1}));
-  propagateIfChanged(operands[1], operands[1]->meet(layout));
-}
-
+/// Set the layout for the value, tensor descriptor, and mask operands in
+/// StoreScatterOp.
 void SGMapPropagation::visitStoreScatterOp(
     xegpu::StoreScatterOp storeScatter, ArrayRef<SGMapLattice *> operands,
     ArrayRef<const SGMapLattice *> results) {
-  /// StoreScatterOp has the transpose effect. Value has the regular layout,
-  /// while the tensor descriptor has the transposed layout.
   auto valueLayout =
       getDefaultSgMap(storeScatter.getTensorDescType().getElementType());
-  auto storeScatterLayout = valueLayout.getTransposedLayout({1, 0});
+  SGMap storeScatterLayout;
+  if (storeScatter.getTranspose()) {
+    /// StoreScatteOp allows transpose effect. However, at the stage of this
+    /// analyis this effect is not expected and should be abstracted away. Emit
+    /// a warning.
+    /// TODO: Handle this case properly when `order` is introduced in the
+    /// sg_map.
+    emitWarning(storeScatter.getLoc())
+        << "Transpose effect is not expected for StoreScatterOp";
+    storeScatterLayout = valueLayout.getTransposedLayout({1, 0});
+  } else
+    storeScatterLayout = valueLayout;
   /// Propagate the value layout.
   propagateIfChanged(operands[0], operands[0]->meet(valueLayout));
   /// Propagate the tensor descriptor layout.
@@ -403,6 +437,11 @@ void SGMapPropagation::visitStoreScatterOp(
 
 namespace {
 
+///===----------------------------------------------------------------------===///
+/// RunSGMapPropagation
+///===----------------------------------------------------------------------===///
+
+/// Driver class for running the SGMapPropagation analysis.
 class RunSGMapPropagation {
 public:
   RunSGMapPropagation(Operation *op) : target(op) {
@@ -487,13 +526,13 @@ struct XeGPUSubgroupDistributePass final
 
 void XeGPUSubgroupDistributePass::runOnOperation() {
   Operation *op = getOperation();
-
   RunSGMapPropagation solver(op);
 
-  // Print analysis results
+  // Print the analysis result and exit.
   if (printOnly) {
     auto &os = llvm::outs();
     solver.printAnalysisResult(os);
+    return;
   }
 }
 
diff --git a/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir b/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
index 81c4244091e87..871e6e2f9139e 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
@@ -62,7 +62,7 @@ func.func @test_extf_truncf_op(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegp
 }
 
 // -----
-func.func @test_load_gather_op(%a: memref<8x16xf16>, %b : memref<256xf16>, %c : memref<8x16xf32>) {
+func.func @test_load_gather_op_1(%a: memref<8x16xf16>, %b : memref<256xf16>, %c : memref<8x16xf32>) {
   %c0 = arith.constant 0 : index
   %6 = xegpu.create_nd_tdesc %a[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
   %7 = xegpu.load_nd %6 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
@@ -77,7 +77,17 @@ func.func @test_load_gather_op(%a: memref<8x16xf16>, %b : memref<256xf16>, %c :
 }
 
 // -----
-func.func @test_store_scatter_op(%c: memref<128xf32>){
+func.func @test_load_gather_op_2(%arg0: memref<256xf32>, %arg1: !xegpu.tensor_desc<16xf32>){
+  %0 = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
+  %1 = arith.constant dense<1>: vector<16xi1>
+  %2 = xegpu.create_tdesc %arg0, %0 : memref<256xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<chunk_size = 1>>
+  %3 = xegpu.load %2, %1 : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<chunk_size = 1>>, vector<16xi1> -> vector<16xf32>
+  xegpu.store_nd %3, %arg1 : vector<16xf32>, !xegpu.tensor_desc<16xf32>
+  return
+}
+
+// -----
+func.func @test_store_scatter_op_1(%c: memref<128xf32>){
   %cst = arith.constant dense<1.000000e+00> : vector<8x16xf32>
   %1 = arith.constant dense<1> : vector<16xi1>
   %2 = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
@@ -86,6 +96,15 @@ func.func @test_store_scatter_op(%c: memref<128xf32>){
   return
 }
 
+// -----
+func.func @test_store_scatter_op_2(%arg0: vector<16xf32>, %arg1: memref<256xf32>){
+  %0 = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
+  %1 = arith.constant dense<1>: vector<16xi1>
+  %2 = xegpu.create_tdesc %arg1, %0 : memref<256xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<chunk_size = 1>>
+  xegpu.store %arg0, %2, %1 : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<chunk_size = 1>>, vector<16xi1>
+  return
+}
+
 // -----
 func.func @test_bitcast_op_1(%arg0: memref<8x16xi16>, %arg1: memref<32x16xi8>, %c : memref<8x16xi32>) {
   %c0 = arith.constant 0 : index
@@ -191,7 +210,7 @@ func.func @test_if_op_2(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tenso
 }
 
 // -----
-func.func @test(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: vector<16x16xf16>, %arg3: i1) -> vector<8x16xf32> {
+func.func @test_if_op_3(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: vector<16x16xf16>, %arg3: i1, %arg4: !xegpu.tensor_desc<8x16xf32>){
   %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
   %1 = scf.if %arg3 -> (vector<16x16xf16>) {
     %3 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
@@ -200,30 +219,26 @@ func.func @test(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<1
     scf.yield %arg2 : vector<16x16xf16>
   }
   %2 = xegpu.dpas %0, %1 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
-  return %2 : vector<8x16xf32>
+  xegpu.store_nd %2, %arg4 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+  return
 }
 
 // -----
-func.func @test(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: vector<16x16xf16>, %arg3: i1) -> vector<8x16xf32> {
-  %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
-  %1 = scf.if %arg3 -> (vector<8x16xf32>) {
-    %2 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
-    %3 = arith.addf %2, %arg2 : vector<16x16xf16>
-    %4 = xegpu.dpas %0, %3 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
-    scf.yield %4 : vector<8x16xf32>
-  } else {
-    %2 = xegpu.dpas %0, %arg2 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
-    scf.yield %2 : vector<8x16xf32>
-  }
-  return %1 : vector<8x16xf32>
+
+func.func @test_vector_reduction_1(%arg0: vector<16x16xf32>, %arg1: !xegpu.tensor_desc<16xf32>) {
+  %0 = arith.constant dense<0.000000e+00> : vector<16xf32>
+  %1 = vector.multi_reduction <add>, %arg0, %0 [0] : vector<16x16xf32> to vector<16xf32>
+  xegpu.store_nd %1, %arg1 : vector<16xf32>, !xegpu.tensor_desc<16xf32>
+  return
 }
 
+
+
 // -----
-func.func @test(%arg0: !xegpu.tensor_desc<8x16xi16>, %arg1: !xegpu.tensor_desc<16x16xi16>) -> vector<8x16xf32> {
-  %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xi16> -> vector<8x16xi16>
-  %1 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xi16> -> vector<16x16xi16>
-  %2 = arith.bitcast %0 : vector<8x16xi16> to vector<8x16xf16>
-  %3 = arith.bitcast %1 : vector<16x16xi16> to vector<16x16xf16>
-  %4 = xegpu.dpas %2, %3 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
-  return %4 : vector<8x16xf32>
+
+func.func @test_vector_reduction_2(%arg0: vector<16x16xf32>, %arg1: !xegpu.tensor_desc<16xf32>) {
+  %0 = arith.constant dense<0.000000e+00> : vector<16xf32>
+  %1 = vector.multi_reduction <add>, %arg0, %0 [1] : vector<16x16xf32> to vector<16xf32>
+  xegpu.store_nd %1, %arg1 : vector<16xf32>, !xegpu.tensor_desc<16xf32>
+  return
 }

>From 462b3c09d15254f9d6dbdbe4c11aef9761e56aea Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Fri, 7 Mar 2025 04:00:35 +0000
Subject: [PATCH 09/28] save work

---
 .../Transforms/XeGPUSubgroupDistribute.cpp    |  77 +++-
 .../XeGPU/subgroup-map-propagation.mlir       | 341 ++++++++++++------
 2 files changed, 296 insertions(+), 122 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 13030d5c74ea0..de65f285d75b4 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -134,6 +134,7 @@ SGMap SGMap::join(const SGMap &lhs, const SGMap &rhs) {
   llvm_unreachable("Join should not be triggered by SGMapPropagation.");
 }
 
+/// Get the transposed layout according to the given permutation.
 SGMap SGMap::getTransposedLayout(ArrayRef<int64_t> permutation) const {
   if (!isAssigned())
     return {};
@@ -180,6 +181,11 @@ static SGMap getDefaultSgMap(Type ty) {
   return SGMap(WiLayout({1, subgroupSize}), WiData({1, packingFactor}));
 }
 
+/// Helper Function to get the default layout representing constants.
+static SGMap getDefaultSgMap() {
+  return SGMap(WiLayout({1, subgroupSize}), WiData({1, 1}));
+}
+
 ///===----------------------------------------------------------------------===///
 /// SGMapPropagation
 ///===----------------------------------------------------------------------===///
@@ -216,6 +222,14 @@ class SGMapPropagation : public SparseBackwardDataFlowAnalysis<SGMapLattice> {
                             ArrayRef<SGMapLattice *> operands,
                             ArrayRef<const SGMapLattice *> results);
 
+  void visitCreateNdDescOp(xegpu::CreateNdDescOp createNdDesc,
+                           ArrayRef<SGMapLattice *> operands,
+                           ArrayRef<const SGMapLattice *> results);
+
+  void visitCreateDescOp(xegpu::CreateDescOp createDesc,
+                         ArrayRef<SGMapLattice *> operands,
+                         ArrayRef<const SGMapLattice *> results);
+
 public:
   SGMapPropagation(DataFlowSolver &solver, SymbolTableCollection &symbolTable)
       : SparseBackwardDataFlowAnalysis(solver, symbolTable) {}
@@ -254,6 +268,10 @@ SGMapPropagation::visitOperation(Operation *op,
     visitVectorBitcastOp(bitcast, operands, results);
   else if (auto loadGather = dyn_cast<xegpu::LoadGatherOp>(op))
     visitLoadGatherOp(loadGather, operands, results);
+  else if (auto createNdDesc = dyn_cast<xegpu::CreateNdDescOp>(op))
+    visitCreateNdDescOp(createNdDesc, operands, results);
+  else if (auto createDesc = dyn_cast<xegpu::CreateDescOp>(op))
+    visitCreateDescOp(createDesc, operands, results);
   else if (auto storeScatter = dyn_cast<xegpu::StoreScatterOp>(op))
     visitStoreScatterOp(storeScatter, operands, results);
   /// All other ops
@@ -318,8 +336,7 @@ void SGMapPropagation::visitLoadNdOp(xegpu::LoadNdOp load,
   /// this effect is not expected and should be abstracted away. Emit a warning.
   /// TODO: Handle this case properly when `order` is introduced in the sg_map.
   if (auto transpose = load.getTranspose()) {
-    emitWarning(load.getLoc())
-        << "Transpose effect is not expected for LoadNdOp";
+    load.emitWarning("Transpose effect is not expected for LoadNdOp");
     tensorDescLayout = valueLayout.getTransposedLayout(transpose.value());
   }
   /// Propagate the new layout to the tensor descriptor operand.
@@ -394,21 +411,53 @@ void SGMapPropagation::visitLoadGatherOp(
     /// a warning.
     /// TODO: Handle this case properly when `order` is introduced in the
     /// sg_map.
-    emitWarning(load.getLoc())
-        << "Transpose effect is not expected for LoadGatherOp";
+    load.emitWarning("Transpose effect is not expected for LoadGatherOp");
     tensorDescLayout = valueLayout.getTransposedLayout({1, 0});
   } else
     tensorDescLayout = valueLayout;
   /// Mask operand should have the same layout as the value but with wi_data =
   /// [1, 1]
-  SGMap maskLayout = getDefaultSgMap(load.getTensorDescType().getElementType());
+  SGMap maskLayout = SGMap(valueLayout.getLayout(), WiData({1, 1}));
   /// Propagate the new layout to the tensor descriptor operand.
   propagateIfChanged(operands[0], operands[0]->meet(tensorDescLayout));
   /// Propagate the new layout to the mask operand.
   propagateIfChanged(operands[1], operands[1]->meet(maskLayout));
 }
 
-/// Set the layout for the value, tensor descriptor, and mask operands in
+/// Propagate the layout of the descriptor to the operands in CreateNdDescOp.
+void SGMapPropagation::visitCreateNdDescOp(
+    xegpu::CreateNdDescOp createNdDesc, ArrayRef<SGMapLattice *> operands,
+    ArrayRef<const SGMapLattice *> results) {
+  auto descLayout = results[0]->getValue();
+  /// Need the layout of the descriptor to propagate to the operands.
+  if (!descLayout.isAssigned())
+    return;
+  /// Propagate the layout to the source operand.
+  propagateIfChanged(operands[0], operands[0]->meet(descLayout));
+  /// For all other operands propagate the descriptor layout.
+  SGMap layout = getDefaultSgMap();
+  for (size_t i = 1; i < operands.size(); ++i) {
+    propagateIfChanged(operands[i], operands[i]->meet(layout));
+  }
+}
+
+/// Propagate the layout of the descriptor to the source and offset operands in
+/// CreateDescOp.
+void SGMapPropagation::visitCreateDescOp(
+    xegpu::CreateDescOp createDesc, ArrayRef<SGMapLattice *> operands,
+    ArrayRef<const SGMapLattice *> results) {
+  auto descLayout = results[0]->getValue();
+  /// Need the layout of the descriptor to propagate to the operands.
+  if (!descLayout.isAssigned())
+    return;
+  /// Propagate the layout to the source operand.
+  propagateIfChanged(operands[0], operands[0]->meet(descLayout));
+  /// For offset operand propagate the default layout.
+  SGMap layout = getDefaultSgMap();
+  propagateIfChanged(operands[1], operands[1]->meet(layout));
+}
+
+/// Set the layout for the value, tensor descriptor, and mask operands in the
 /// StoreScatterOp.
 void SGMapPropagation::visitStoreScatterOp(
     xegpu::StoreScatterOp storeScatter, ArrayRef<SGMapLattice *> operands,
@@ -422,8 +471,8 @@ void SGMapPropagation::visitStoreScatterOp(
     /// a warning.
     /// TODO: Handle this case properly when `order` is introduced in the
     /// sg_map.
-    emitWarning(storeScatter.getLoc())
-        << "Transpose effect is not expected for StoreScatterOp";
+    storeScatter.emitWarning(
+        "Transpose effect is not expected for StoreScatterOp");
     storeScatterLayout = valueLayout.getTransposedLayout({1, 0});
   } else
     storeScatterLayout = valueLayout;
@@ -431,8 +480,9 @@ void SGMapPropagation::visitStoreScatterOp(
   propagateIfChanged(operands[0], operands[0]->meet(valueLayout));
   /// Propagate the tensor descriptor layout.
   propagateIfChanged(operands[1], operands[1]->meet(storeScatterLayout));
-  /// Propagate the mask layout.
-  propagateIfChanged(operands[2], operands[2]->meet(valueLayout));
+  /// Use default layout for mask operand.
+  auto maskLayout = getDefaultSgMap();
+  propagateIfChanged(operands[2], operands[2]->meet(maskLayout));
 }
 
 namespace {
@@ -517,10 +567,9 @@ struct XeGPUSubgroupDistributePass final
   }
   void runOnOperation() override;
   /// Print sg map propagation analysis result and exit for testing purposes.
-  Option<bool> printOnly{
-      *this, "print-only", llvm::cl::init(false),
-      llvm::cl::desc(
-          "Print the result of the subgroup map propagation analysis")};
+  Option<bool> printOnly{*this, "print-analysis-only", llvm::cl::init(false),
+                         llvm::cl::desc("Print the result of the subgroup map "
+                                        "propagation analysis and exit.")};
 };
 } // namespace
 
diff --git a/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir b/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
index 871e6e2f9139e..c75931cd6a37b 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
@@ -1,57 +1,149 @@
-func.func @test_dpas_op_1(%a: memref<8x16xf16>, %b : memref<16x16xf16>, %c : memref<8x16xf32>) {
+// RUN: mlir-opt -xegpu-subgroup-distribute='print-analysis-only=true' -split-input-file %s | FileCheck %s
+
+// CHECK: function: test_dpas_op_1:
+// CHECK: op    : %{{.*}} = arith.constant 0 : index
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %{{.*}} = arith.constant dense<0.000000e+00> : vector<8x16xf32>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[T1:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1]
+// CHECK-NEXT: op    : %[[T2:.*]] = xegpu.load_nd %[[T0]]  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[T3:.*]] = xegpu.load_nd %[[T1]]  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1]
+// CHECK-NEXT: op    : %[[T4:.*]] = xegpu.dpas %[[T2]], %[[T3]], %{{.*}} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %{{.*}} = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+func.func @test_dpas_op_1(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<8x16xf32>) {
   %c0 = arith.constant 0 : index
-  %cst = arith.constant dense<0.0> : vector<8x16xf32>
-  %0 = xegpu.create_nd_tdesc %a[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
-  %1 = xegpu.create_nd_tdesc %b[%c0, %c0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
-  %2 = xegpu.load_nd %0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
-  %3 = xegpu.load_nd %1 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+  %cst = arith.constant dense<0.000000e+00> : vector<8x16xf32>
+  %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
+  %1 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
+  %2 = xegpu.load_nd %0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+  %3 = xegpu.load_nd %1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
   %4 = xegpu.dpas %2, %3, %cst : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
-  %5 = xegpu.create_nd_tdesc %c[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
-  xegpu.store_nd %4, %5 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+  %5 = xegpu.create_nd_tdesc %arg2[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+  xegpu.store_nd %4, %5  : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
   return
 }
 
+
 // -----
-func.func @test_dpas_op_2(%arg0: vector<8x32xi8>, %arg1: vector<32x16xi8>, %c : memref<8x16xi32>) {
+// CHECK: function: test_dpas_op_2:
+// CHECK-NEXT: argument: <block argument> of type 'vector<8x32xi8>' at index: 0
+// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 2]
+// CHECK-NEXT: argument: <block argument> of type 'vector<32x16xi8>' at index: 1
+// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [4, 1]
+// CHECK-NEXT: argument: <block argument> of type 'memref<8x16xi32>' at index: 2
+// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %{{.*}} = arith.constant 0 : index
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[T0:.*]] = xegpu.dpas %{{.*}} : vector<8x32xi8>, vector<32x16xi8> -> vector<8x16xi32>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[T1:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xi32> -> !xegpu.tensor_desc<8x16xi32>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+func.func @test_dpas_op_2(%arg0: vector<8x32xi8>, %arg1: vector<32x16xi8>, %arg2: memref<8x16xi32>) {
   %c0 = arith.constant 0 : index
   %0 = xegpu.dpas %arg0, %arg1 : vector<8x32xi8>, vector<32x16xi8> -> vector<8x16xi32>
-  %1 = xegpu.create_nd_tdesc %c [%c0, %c0] : memref<8x16xi32> -> !xegpu.tensor_desc<8x16xi32>
-  xegpu.store_nd %0, %1 : vector<8x16xi32>, !xegpu.tensor_desc<8x16xi32>
+  %1 = xegpu.create_nd_tdesc %arg2[%c0, %c0] : memref<8x16xi32> -> !xegpu.tensor_desc<8x16xi32>
+  xegpu.store_nd %0, %1  : vector<8x16xi32>, !xegpu.tensor_desc<8x16xi32>
   return
 }
 
 // -----
-func.func @test_transpose_op_1(%a: memref<8x16xf16>, %b : memref<16x16xf16>, %c : memref<8x16xf32>) {
+// CHECK: function: test_transpose_op_1:
+// CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf16>' at index: 0
+// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: argument: <block argument> of type 'memref<16x16xf16>' at index: 1
+// CHECK-NEXT: sg_map  : wi_layout: [16, 1], wi_data: [2, 1]
+// CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf32>' at index: 2
+// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %{{.*}} = arith.constant 0 : index
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<8x16xf32>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[T1:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [16, 1], wi_data: [2, 1]
+// CHECK-NEXT: op    : %[[T2:.*]] = xegpu.load_nd %[[T0]]  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[T3:.*]] = xegpu.load_nd %[[T1]] <{transpose = array<i64: 1, 0>}> : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1]
+// CHECK-NEXT: op    : %[[T4:.*]] = xegpu.dpas %[[T2]], %[[T3]], %[[CST]] : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[T5:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+func.func @test_transpose_op_1(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<8x16xf32>) {
   %c0 = arith.constant 0 : index
-  %cst = arith.constant dense<0.0> : vector<8x16xf32>
-  %0 = xegpu.create_nd_tdesc %a[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
-  %1 = xegpu.create_nd_tdesc %b[%c0, %c0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
-  %2 = xegpu.load_nd %0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
-  %3 = xegpu.load_nd %1 {transpose = array<i64: 1, 0>} : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+  %cst = arith.constant dense<0.000000e+00> : vector<8x16xf32>
+  %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
+  %1 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
+  %2 = xegpu.load_nd %0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+  %3 = xegpu.load_nd %1 <{transpose = array<i64: 1, 0>}> : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
   %4 = xegpu.dpas %2, %3, %cst : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
-  %5 = xegpu.create_nd_tdesc %c[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
-  xegpu.store_nd %4, %5 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+  %5 = xegpu.create_nd_tdesc %arg2[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+  xegpu.store_nd %4, %5  : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
   return
 }
 
 // -----
-func.func @test_transpose_op_2(%a: memref<8x16xf16>, %b : memref<16x16xf16>, %c : memref<8x16xf32>) {
+// CHECK: argument: <block argument> of type 'memref<8x16xf16>' at index: 0
+// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: argument: <block argument> of type 'memref<16x16xf16>' at index: 1
+// CHECK-NEXT: sg_map  : wi_layout: [16, 1], wi_data: [2, 1]
+// CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf32>' at index: 2
+// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %{{.*}} = arith.constant 0 : index
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<8x16xf32>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[T1:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [16, 1], wi_data: [2, 1]
+// CHECK-NEXT: op    : %[[T2:.*]] = xegpu.load_nd %[[T0]]  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[T3:.*]] = xegpu.load_nd %[[T1]]  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [16, 1], wi_data: [2, 1]
+// CHECK-NEXT: op    : %[[T4:.*]] = vector.transpose %[[T3]], [1, 0] : vector<16x16xf16> to vector<16x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1]
+// CHECK-NEXT: op    : %[[T5:.*]] = xegpu.dpas %[[T2]], %[[T4]], %[[CST]] : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[T6:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+func.func @test_transpose_op_2(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<8x16xf32>) {
   %c0 = arith.constant 0 : index
-  %cst = arith.constant dense<0.0> : vector<8x16xf32>
-  %0 = xegpu.create_nd_tdesc %a[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
-  %1 = xegpu.create_nd_tdesc %b[%c0, %c0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
-  %2 = xegpu.load_nd %0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
-  %3 = xegpu.load_nd %1 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
-  %6 = vector.transpose %3, [1, 0] : vector<16x16xf16> to vector<16x16xf16>
-  %4 = xegpu.dpas %2, %6, %cst : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
-  %5 = xegpu.create_nd_tdesc %c[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
-  xegpu.store_nd %4, %5 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+  %cst = arith.constant dense<0.000000e+00> : vector<8x16xf32>
+  %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
+  %1 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
+  %2 = xegpu.load_nd %0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+  %3 = xegpu.load_nd %1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+  %4 = vector.transpose %3, [1, 0] : vector<16x16xf16> to vector<16x16xf16>
+  %5 = xegpu.dpas %2, %4, %cst : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+  %6 = xegpu.create_nd_tdesc %arg2[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+  xegpu.store_nd %5, %6  : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
   return
 }
 
-
-
 // -----
+// CHECK: argument: <block argument> of type '!xegpu.tensor_desc<8x16xf16>' at index: 0
+// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<16x16xf16>' at index: 1
+// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [2, 1]
+// CHECK-NEXT: op    : %[[T0:.*]] = xegpu.load_nd %{{.*}}  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[T1:.*]] = xegpu.load_nd %{{.*}}  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1]
+// CHECK-NEXT: op    : %[[T2:.*]] = arith.extf %[[T1]] : vector<16x16xf16> to vector<16x16xf32>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1]
+// CHECK-NEXT: op    : %[[T3:.*]] = arith.truncf %[[T2]] : vector<16x16xf32> to vector<16x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1]
+// CHECK-NEXT: op    : %[[T4:.*]] = xegpu.dpas %[[T0]], %[[T3]] : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+// CHECK-NEXT: sg_map for result #0: Not assigned.
 func.func @test_extf_truncf_op(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>) -> vector<8x16xf32> {
   %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
   %1 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
@@ -62,75 +154,112 @@ func.func @test_extf_truncf_op(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegp
 }
 
 // -----
-func.func @test_load_gather_op_1(%a: memref<8x16xf16>, %b : memref<256xf16>, %c : memref<8x16xf32>) {
+// DPAS whose B operand comes from a transposed scattered load (chunk_size = 16); A comes from load_nd.
+// CHECK: function: test_load_gather_op_1:
+// CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf16>' at index: 0
+// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: argument: <block argument> of type 'memref<256xf16>' at index: 1
+// CHECK-NEXT: sg_map  : wi_layout: [16, 1], wi_data: [2, 1]
+// CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf32>' at index: 2
+// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %{{.*}} = arith.constant 0 : index
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[T1:.*]] = xegpu.load_nd %[[T0]]  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[CST:.*]] = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[CST0:.*]] = arith.constant dense<true> : vector<16xi1>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[T2:.*]] = xegpu.create_tdesc %{{.*}}, %[[CST]] : memref<256xf16>, vector<16xindex> -> !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [16, 1], wi_data: [2, 1]
+// CHECK-NEXT: op    : %[[T3:.*]] = xegpu.load %[[T2]], %[[CST0]] <{transpose}> : !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>>, vector<16xi1> -> vector<16x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1]
+// CHECK-NEXT: op    : %[[T4:.*]] = xegpu.dpas %[[T1]], %[[T3]] : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[T5:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+func.func @test_load_gather_op_1(%arg0: memref<8x16xf16>, %arg1: memref<256xf16>, %arg2: memref<8x16xf32>) {
   %c0 = arith.constant 0 : index
-  %6 = xegpu.create_nd_tdesc %a[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
-  %7 = xegpu.load_nd %6 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
-  %0 = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
-  %1 = arith.constant dense<1>: vector<16xi1>
-  %2 = xegpu.create_tdesc %b, %0 : memref<256xf16>, vector<16xindex> -> !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16>>
-  %3 = xegpu.load %2, %1 <{transpose}> : !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16>>, vector<16xi1> -> vector<16x16xf16>
-  %4 = xegpu.dpas %7, %3 : vector<8x16xf16>, vector<16x16xf16>-> vector<8x16xf32>
-  %5 = xegpu.create_nd_tdesc %c[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
-  xegpu.store_nd %4, %5 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+  %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
+  %1 = xegpu.load_nd %0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+  %cst = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
+  %cst_0 = arith.constant dense<true> : vector<16xi1>
+  %2 = xegpu.create_tdesc %arg1, %cst : memref<256xf16>, vector<16xindex> -> !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>>
+  %3 = xegpu.load %2, %cst_0 <{transpose}> : !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>>, vector<16xi1> -> vector<16x16xf16>
+  %4 = xegpu.dpas %1, %3 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+  %5 = xegpu.create_nd_tdesc %arg2[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+  xegpu.store_nd %4, %5  : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
   return
 }
 
 // -----
-func.func @test_load_gather_op_2(%arg0: memref<256xf32>, %arg1: !xegpu.tensor_desc<16xf32>){
-  %0 = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
-  %1 = arith.constant dense<1>: vector<16xi1>
-  %2 = xegpu.create_tdesc %arg0, %0 : memref<256xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<chunk_size = 1>>
-  %3 = xegpu.load %2, %1 : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<chunk_size = 1>>, vector<16xi1> -> vector<16xf32>
-  xegpu.store_nd %3, %arg1 : vector<16xf32>, !xegpu.tensor_desc<16xf32>
+// 1-D scattered load (scatter_tdesc_attr<>) whose result is stored with store_nd.
+// CHECK: function: test_load_gather_op_2:
+// CHECK-NEXT: argument: <block argument> of type 'memref<256xf32>' at index: 0
+// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<16xf32>' at index: 1
+// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[CST:.*]] = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[CST0:.*]] = arith.constant dense<true> : vector<16xi1>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[T0:.*]] = xegpu.create_tdesc %{{.*}}, %[[CST]] : memref<256xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[T1:.*]] = xegpu.load %[[T0]], %[[CST0]]  : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> -> vector<16xf32>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+func.func @test_load_gather_op_2(%arg0: memref<256xf32>, %arg1: !xegpu.tensor_desc<16xf32>) {
+  %cst = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
+  %cst_0 = arith.constant dense<true> : vector<16xi1>
+  %0 = xegpu.create_tdesc %arg0, %cst : memref<256xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
+  %1 = xegpu.load %0, %cst_0  : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> -> vector<16xf32>
+  xegpu.store_nd %1, %arg1  : vector<16xf32>, !xegpu.tensor_desc<16xf32>
   return
 }
 
 // -----
-func.func @test_store_scatter_op_1(%c: memref<128xf32>){
+// Transposed scattered store (chunk_size = 8) of an 8x16 constant.
+func.func @test_store_scatter_op_1(%arg0: memref<128xf32>) {
   %cst = arith.constant dense<1.000000e+00> : vector<8x16xf32>
-  %1 = arith.constant dense<1> : vector<16xi1>
-  %2 = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
-  %3 = xegpu.create_tdesc %c, %2 : memref<128xf32>, vector<16xindex> -> !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size = 8>>
-  xegpu.store %cst, %3, %1 <{transpose}> : vector<8x16xf32>, !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size = 8>>, vector<16xi1>
+  %cst_0 = arith.constant dense<true> : vector<16xi1>
+  %cst_1 = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
+  %0 = xegpu.create_tdesc %arg0, %cst_1 : memref<128xf32>, vector<16xindex> -> !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size = 8 : i64>>
+  xegpu.store %cst, %0, %cst_0 <{transpose}> : vector<8x16xf32>, !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size = 8 : i64>>, vector<16xi1>
   return
 }
 
 // -----
-func.func @test_store_scatter_op_2(%arg0: vector<16xf32>, %arg1: memref<256xf32>){
-  %0 = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
-  %1 = arith.constant dense<1>: vector<16xi1>
-  %2 = xegpu.create_tdesc %arg1, %0 : memref<256xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<chunk_size = 1>>
-  xegpu.store %arg0, %2, %1 : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<chunk_size = 1>>, vector<16xi1>
+// 1-D scattered store of the vector<16xf32> function argument.
+func.func @test_store_scatter_op_2(%arg0: vector<16xf32>, %arg1: memref<256xf32>) {
+  %cst = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
+  %cst_0 = arith.constant dense<true> : vector<16xi1>
+  %0 = xegpu.create_tdesc %arg1, %cst : memref<256xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
+  xegpu.store %arg0, %0, %cst_0  : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1>
   return
 }
 
 // -----
-func.func @test_bitcast_op_1(%arg0: memref<8x16xi16>, %arg1: memref<32x16xi8>, %c : memref<8x16xi32>) {
+// The DPAS A operand is loaded as i16 and bitcast to i8 (8x16xi16 -> 8x32xi8) before an i8 DPAS.
+func.func @test_bitcast_op_1(%arg0: memref<8x16xi16>, %arg1: memref<32x16xi8>, %arg2: memref<8x16xi32>) {
   %c0 = arith.constant 0 : index
-  %2 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xi16> -> !xegpu.tensor_desc<8x16xi16>
-  %3 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<32x16xi8> -> !xegpu.tensor_desc<32x16xi8>
-  %4 = xegpu.load_nd %2 : !xegpu.tensor_desc<8x16xi16> -> vector<8x16xi16>
-  %5 = xegpu.load_nd %3 : !xegpu.tensor_desc<32x16xi8> -> vector<32x16xi8>
-  %6 = vector.bitcast %4 : vector<8x16xi16> to vector<8x32xi8>
-  %0 = xegpu.dpas %6, %5 : vector<8x32xi8>, vector<32x16xi8> -> vector<8x16xi32>
-  %1 = xegpu.create_nd_tdesc %c [%c0, %c0] : memref<8x16xi32> -> !xegpu.tensor_desc<8x16xi32>
-  xegpu.store_nd %0, %1 : vector<8x16xi32>, !xegpu.tensor_desc<8x16xi32>
+  %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xi16> -> !xegpu.tensor_desc<8x16xi16>
+  %1 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<32x16xi8> -> !xegpu.tensor_desc<32x16xi8>
+  %2 = xegpu.load_nd %0  : !xegpu.tensor_desc<8x16xi16> -> vector<8x16xi16>
+  %3 = xegpu.load_nd %1  : !xegpu.tensor_desc<32x16xi8> -> vector<32x16xi8>
+  %4 = vector.bitcast %2 : vector<8x16xi16> to vector<8x32xi8>
+  %5 = xegpu.dpas %4, %3 : vector<8x32xi8>, vector<32x16xi8> -> vector<8x16xi32>
+  %6 = xegpu.create_nd_tdesc %arg2[%c0, %c0] : memref<8x16xi32> -> !xegpu.tensor_desc<8x16xi32>
+  xegpu.store_nd %5, %6  : vector<8x16xi32>, !xegpu.tensor_desc<8x16xi32>
   return
 }
 
 // -----
-func.func @test_bitcast_op_2(%arg0: memref<8x32xi8>, %arg1: memref<16x32xi8>, %c : memref<8x16xf32>) {
+// Both DPAS operands are loaded as i8 and bitcast to f16 before an f16 DPAS.
+func.func @test_bitcast_op_2(%arg0: memref<8x32xi8>, %arg1: memref<16x32xi8>, %arg2: memref<8x16xf32>) {
   %c0 = arith.constant 0 : index
-  %2 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x32xi8> -> !xegpu.tensor_desc<8x32xi8>
-  %3 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<16x32xi8> -> !xegpu.tensor_desc<16x32xi8>
-  %4 = xegpu.load_nd %2 : !xegpu.tensor_desc<8x32xi8> -> vector<8x32xi8>
-  %5 = xegpu.load_nd %3 : !xegpu.tensor_desc<16x32xi8> -> vector<16x32xi8>
-  %6 = vector.bitcast %4 : vector<8x32xi8> to vector<8x16xf16>
-  %7 = vector.bitcast %5 : vector<16x32xi8> to vector<16x16xf16>
-  %0 = xegpu.dpas %6, %7 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
-  %1 = xegpu.create_nd_tdesc %c [%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
-  xegpu.store_nd %0, %1 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+  %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x32xi8> -> !xegpu.tensor_desc<8x32xi8>
+  %1 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<16x32xi8> -> !xegpu.tensor_desc<16x32xi8>
+  %2 = xegpu.load_nd %0  : !xegpu.tensor_desc<8x32xi8> -> vector<8x32xi8>
+  %3 = xegpu.load_nd %1  : !xegpu.tensor_desc<16x32xi8> -> vector<16x32xi8>
+  %4 = vector.bitcast %2 : vector<8x32xi8> to vector<8x16xf16>
+  %5 = vector.bitcast %3 : vector<16x32xi8> to vector<16x16xf16>
+  %6 = xegpu.dpas %4, %5 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+  %7 = xegpu.create_nd_tdesc %arg2[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+  xegpu.store_nd %6, %7  : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
   return
 }
 
@@ -141,7 +270,7 @@ func.func @test_binary_op_1(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.t
   %2 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
   %3 = arith.addf %1, %2 : vector<16x16xf16>
   %4 = xegpu.dpas %0, %3 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
-  xegpu.store_nd %4, %arg2 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+  xegpu.store_nd %4, %arg2  : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
   return
 }
 
@@ -152,34 +281,34 @@ func.func @test_binary_op_2(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.t
   %cst = arith.constant dense<1.000000e+00> : vector<16x16xf16>
   %2 = arith.addf %1, %cst : vector<16x16xf16>
   %3 = xegpu.dpas %0, %2 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
-  xegpu.store_nd %3, %arg2 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
-  xegpu.store_nd %2, %arg3 : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16>
+  xegpu.store_nd %3, %arg2  : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+  xegpu.store_nd %2, %arg3  : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16>
   return
 }
 
 // -----
-func.func @test_for_op(%a: memref<8x128xf16>, %b : memref<128x16xf16>, %c : memref<8x16xf32>) {
+// DPAS accumulation in scf.for; loop-carried descriptors are advanced with update_nd_offset.
+func.func @test_for_op(%arg0: memref<8x128xf16>, %arg1: memref<128x16xf16>, %arg2: memref<8x16xf32>) {
   %c0 = arith.constant 0 : index
   %c128 = arith.constant 128 : index
   %c16 = arith.constant 16 : index
-  %0 = xegpu.create_nd_tdesc %a[%c0, %c0] : memref<8x128xf16> -> !xegpu.tensor_desc<8x16xf16>
-  %1 = xegpu.create_nd_tdesc %b[%c0, %c0] : memref<128x16xf16> -> !xegpu.tensor_desc<16x16xf16>
-  %2 = arith.constant dense<0.0> : vector<8x16xf32>
-  %3:3 = scf.for %k = %c0 to %c128 step %c16 iter_args(%arg0 = %0, %arg1 = %1, %arg2 = %2) -> (!xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<16x16xf16>, vector<8x16xf32>) {
-    %4 = xegpu.load_nd %arg0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
-    %5 = xegpu.load_nd %arg1 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
-    %6 = xegpu.dpas %4, %5, %arg2 : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
-    %7 = xegpu.update_nd_offset %arg0, [%c0, %c16] : !xegpu.tensor_desc<8x16xf16>
-    %8 = xegpu.update_nd_offset %arg1, [%c16, %c0] : !xegpu.tensor_desc<16x16xf16>
+  %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x128xf16> -> !xegpu.tensor_desc<8x16xf16>
+  %1 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<128x16xf16> -> !xegpu.tensor_desc<16x16xf16>
+  %cst = arith.constant dense<0.000000e+00> : vector<8x16xf32>
+  %2:3 = scf.for %arg3 = %c0 to %c128 step %c16 iter_args(%arg4 = %0, %arg5 = %1, %arg6 = %cst) -> (!xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<16x16xf16>, vector<8x16xf32>) {
+    %4 = xegpu.load_nd %arg4  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+    %5 = xegpu.load_nd %arg5  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+    %6 = xegpu.dpas %4, %5, %arg6 : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+    %7 = xegpu.update_nd_offset %arg4, [%c0, %c16] : !xegpu.tensor_desc<8x16xf16>
+    %8 = xegpu.update_nd_offset %arg5, [%c16, %c0] : !xegpu.tensor_desc<16x16xf16>
     scf.yield %7, %8, %6 : !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<16x16xf16>, vector<8x16xf32>
   }
-  %9 = xegpu.create_nd_tdesc %c[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
-  xegpu.store_nd %3#2, %9 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+  %3 = xegpu.create_nd_tdesc %arg2[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+  xegpu.store_nd %2#2, %3  : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
   return
 }
 
 // -----
-func.func @test_if_op_1(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: i1, %arg3: !xegpu.tensor_desc<8x16xf32>){
+func.func @test_if_op_1(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: i1, %arg3: !xegpu.tensor_desc<8x16xf32>) {
   %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
   %1 = scf.if %arg2 -> (vector<16x16xf16>) {
     %3 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
@@ -189,12 +318,12 @@ func.func @test_if_op_1(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tenso
     scf.yield %3 : vector<16x16xf16>
   }
   %2 = xegpu.dpas %0, %1 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
-  xegpu.store_nd %2, %arg3 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+  xegpu.store_nd %2, %arg3  : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
   return
 }
 
 // -----
-func.func @test_if_op_2(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: i1, %arg3: !xegpu.tensor_desc<8x16xf32>, %arg4: !xegpu.tensor_desc<16x16xf16>){
+func.func @test_if_op_2(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: i1, %arg3: !xegpu.tensor_desc<8x16xf32>, %arg4: !xegpu.tensor_desc<16x16xf16>) {
   %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
   %1 = scf.if %arg2 -> (vector<16x16xf16>) {
     %3 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
@@ -204,13 +333,13 @@ func.func @test_if_op_2(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tenso
     scf.yield %3 : vector<16x16xf16>
   }
   %2 = xegpu.dpas %0, %1 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
-  xegpu.store_nd %2, %arg3 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
-  xegpu.store_nd %1, %arg4 : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16>
+  xegpu.store_nd %2, %arg3  : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+  xegpu.store_nd %1, %arg4  : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16>
   return
 }
 
 // -----
-func.func @test_if_op_3(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: vector<16x16xf16>, %arg3: i1, %arg4: !xegpu.tensor_desc<8x16xf32>){
+func.func @test_if_op_3(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: vector<16x16xf16>, %arg3: i1, %arg4: !xegpu.tensor_desc<8x16xf32>) {
   %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
   %1 = scf.if %arg3 -> (vector<16x16xf16>) {
     %3 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
@@ -219,26 +348,22 @@ func.func @test_if_op_3(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tenso
     scf.yield %arg2 : vector<16x16xf16>
   }
   %2 = xegpu.dpas %0, %1 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
-  xegpu.store_nd %2, %arg4 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+  xegpu.store_nd %2, %arg4  : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
   return
 }
 
 // -----
-
+// Add-reduction (vector.multi_reduction) over dim 0 of a 16x16 argument, stored with store_nd.
 func.func @test_vector_reduction_1(%arg0: vector<16x16xf32>, %arg1: !xegpu.tensor_desc<16xf32>) {
-  %0 = arith.constant dense<0.000000e+00> : vector<16xf32>
-  %1 = vector.multi_reduction <add>, %arg0, %0 [0] : vector<16x16xf32> to vector<16xf32>
-  xegpu.store_nd %1, %arg1 : vector<16xf32>, !xegpu.tensor_desc<16xf32>
+  %cst = arith.constant dense<0.000000e+00> : vector<16xf32>
+  %0 = vector.multi_reduction <add>, %arg0, %cst [0] : vector<16x16xf32> to vector<16xf32>
+  xegpu.store_nd %0, %arg1  : vector<16xf32>, !xegpu.tensor_desc<16xf32>
   return
 }
 
-
-
 // -----
-
+// Add-reduction (vector.multi_reduction) over dim 1 of a 16x16 argument, stored with store_nd.
 func.func @test_vector_reduction_2(%arg0: vector<16x16xf32>, %arg1: !xegpu.tensor_desc<16xf32>) {
-  %0 = arith.constant dense<0.000000e+00> : vector<16xf32>
-  %1 = vector.multi_reduction <add>, %arg0, %0 [1] : vector<16x16xf32> to vector<16xf32>
-  xegpu.store_nd %1, %arg1 : vector<16xf32>, !xegpu.tensor_desc<16xf32>
+  %cst = arith.constant dense<0.000000e+00> : vector<16xf32>
+  %0 = vector.multi_reduction <add>, %arg0, %cst [1] : vector<16x16xf32> to vector<16xf32>
+  xegpu.store_nd %0, %arg1  : vector<16xf32>, !xegpu.tensor_desc<16xf32>
   return
 }

>From 27f011fbedc582bb85f3555bf38b00247906a6bf Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Fri, 7 Mar 2025 04:35:52 +0000
Subject: [PATCH 10/28] save work

---
 .../XeGPU/subgroup-map-propagation.mlir       | 138 +++++++++++++++++-
 1 file changed, 136 insertions(+), 2 deletions(-)

diff --git a/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir b/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
index c75931cd6a37b..d9d1ac80169ea 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
@@ -216,6 +216,16 @@ func.func @test_load_gather_op_2(%arg0: memref<256xf32>, %arg1: !xegpu.tensor_de
 }
 
 // -----
+// CHECK: argument: <block argument> of type 'memref<128xf32>' at index: 0
+// CHECK-NEXT: sg_map  : wi_layout: [16, 1], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<8x16xf32>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[CST0:.*]] = arith.constant dense<true> : vector<16xi1>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[CST1:.*]] = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[T0:.*]] = xegpu.create_tdesc %{{.*}}, %[[CST1]] : memref<128xf32>, vector<16xindex> -> !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size = 8 : i64>>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [16, 1], wi_data: [1, 1]
 func.func @test_store_scatter_op_1(%arg0: memref<128xf32>) {
   %cst = arith.constant dense<1.000000e+00> : vector<8x16xf32>
   %cst_0 = arith.constant dense<true> : vector<16xi1>
@@ -226,6 +236,16 @@ func.func @test_store_scatter_op_1(%arg0: memref<128xf32>) {
 }
 
 // -----
+// CHECK: argument: <block argument> of type 'vector<16xf32>' at index: 0
+// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: argument: <block argument> of type 'memref<256xf32>' at index: 1
+// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[CST:.*]] = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[CST1:.*]] = arith.constant dense<true> : vector<16xi1>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[T0:.*]] = xegpu.create_tdesc %{{.*}}, %[[CST]] : memref<256xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
 func.func @test_store_scatter_op_2(%arg0: vector<16xf32>, %arg1: memref<256xf32>) {
   %cst = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
   %cst_0 = arith.constant dense<true> : vector<16xi1>
@@ -235,7 +255,29 @@ func.func @test_store_scatter_op_2(%arg0: vector<16xf32>, %arg1: memref<256xf32>
 }
 
 // -----
-func.func @test_bitcast_op_1(%arg0: memref<8x16xi16>, %arg1: memref<32x16xi8>, %arg2: memref<8x16xi32>) {
+// CHECK: argument: <block argument> of type 'memref<8x16xi16>' at index: 0
+// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: argument: <block argument> of type 'memref<32x16xi8>' at index: 1
+// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [4, 1]
+// CHECK-NEXT: argument: <block argument> of type 'memref<8x16xi32>' at index: 2
+// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %{{.*}} = arith.constant 0 : index
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xi16> -> !xegpu.tensor_desc<8x16xi16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[T1:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<32x16xi8> -> !xegpu.tensor_desc<32x16xi8>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [4, 1]
+// CHECK-NEXT: op    : %[[T2:.*]] = xegpu.load_nd %[[T0]]  : !xegpu.tensor_desc<8x16xi16> -> vector<8x16xi16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[T3:.*]] = xegpu.load_nd %[[T1]]  : !xegpu.tensor_desc<32x16xi8> -> vector<32x16xi8>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [4, 1]
+// CHECK-NEXT: op    : %[[T4:.*]] = vector.bitcast %[[T2]] : vector<8x16xi16> to vector<8x32xi8>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 2]
+// CHECK-NEXT: op    : %[[T5:.*]] = xegpu.dpas %[[T4]], %[[T3]] : vector<8x32xi8>, vector<32x16xi8> -> vector<8x16xi32>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[T6:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xi32> -> !xegpu.tensor_desc<8x16xi32>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+func.func @test_vector_bitcast_op_1(%arg0: memref<8x16xi16>, %arg1: memref<32x16xi8>, %arg2: memref<8x16xi32>) {
   %c0 = arith.constant 0 : index
   %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xi16> -> !xegpu.tensor_desc<8x16xi16>
   %1 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<32x16xi8> -> !xegpu.tensor_desc<32x16xi8>
@@ -249,7 +291,31 @@ func.func @test_bitcast_op_1(%arg0: memref<8x16xi16>, %arg1: memref<32x16xi8>, %
 }
 
 // -----
-func.func @test_bitcast_op_2(%arg0: memref<8x32xi8>, %arg1: memref<16x32xi8>, %arg2: memref<8x16xf32>) {
+// CHECK: argument: <block argument> of type 'memref<8x32xi8>' at index: 0
+// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 2]
+// CHECK-NEXT: argument: <block argument> of type 'memref<16x32xi8>' at index: 1
+// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [4, 1]
+// CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf32>' at index: 2
+// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %{{.*}} = arith.constant 0 : index
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x32xi8> -> !xegpu.tensor_desc<8x32xi8>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 2]
+// CHECK-NEXT: op    : %[[T1:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<16x32xi8> -> !xegpu.tensor_desc<16x32xi8>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [4, 1]
+// CHECK-NEXT: op    : %[[T2:.*]] = xegpu.load_nd %[[T0]]  : !xegpu.tensor_desc<8x32xi8> -> vector<8x32xi8>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 2]
+// CHECK-NEXT: op    : %[[T3:.*]] = xegpu.load_nd %[[T1]]  : !xegpu.tensor_desc<16x32xi8> -> vector<16x32xi8>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [4, 1]
+// CHECK-NEXT: op    : %[[T4:.*]] = vector.bitcast %[[T2]] : vector<8x32xi8> to vector<8x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[T5:.*]] = vector.bitcast %[[T3]] : vector<16x32xi8> to vector<16x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1]
+// CHECK-NEXT: op    : %[[T6:.*]] = xegpu.dpas %[[T4]], %[[T5]] : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[T7:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+func.func @test_vector_bitcast_op_2(%arg0: memref<8x32xi8>, %arg1: memref<16x32xi8>, %arg2: memref<8x16xf32>) {
   %c0 = arith.constant 0 : index
   %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x32xi8> -> !xegpu.tensor_desc<8x32xi8>
   %1 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<16x32xi8> -> !xegpu.tensor_desc<16x32xi8>
@@ -264,6 +330,22 @@ func.func @test_bitcast_op_2(%arg0: memref<8x32xi8>, %arg1: memref<16x32xi8>, %a
 }
 
 // -----
+// CHECK: argument: <block argument> of type '!xegpu.tensor_desc<8x16xf16>' at index: 0
+// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<16x16xf16>' at index: 1
+// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [2, 1]
+// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<8x16xf32>' at index: 2
+// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[T0:.*]] = xegpu.load_nd %{{.*}}  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[T1:.*]] = xegpu.load_nd %{{.*}}  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1]
+// CHECK-NEXT: op    : %[[T2:.*]] = xegpu.load_nd %{{.*}}  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1]
+// CHECK-NEXT: op    : %[[T3:.*]] = arith.addf %[[T1]], %[[T2]] : vector<16x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1]
+// CHECK-NEXT: op    : %[[T4:.*]] = xegpu.dpas %[[T0]], %[[T3]] : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
 func.func @test_binary_op_1(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: !xegpu.tensor_desc<8x16xf32>) {
   %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
   %1 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
@@ -275,6 +357,24 @@ func.func @test_binary_op_1(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.t
 }
 
 // -----
+// CHECK: argument: <block argument> of type '!xegpu.tensor_desc<8x16xf16>' at index: 0
+// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<16x16xf16>' at index: 1
+// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<8x16xf32>' at index: 2
+// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<16x16xf16>' at index: 3
+// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[T0:.*]] = xegpu.load_nd %{{.*}}  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[T1:.*]] = xegpu.load_nd %{{.*}}  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<16x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[T2:.*]] = arith.addf %[[T1]], %[[CST]] : vector<16x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[T3:.*]] = xegpu.dpas %[[T0]], %[[T2]] : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
 func.func @test_binary_op_2(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: !xegpu.tensor_desc<8x16xf32>, %arg3: !xegpu.tensor_desc<16x16xf16>) {
   %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
   %1 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
@@ -287,6 +387,40 @@ func.func @test_binary_op_2(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.t
 }
 
 // -----
+// CHECK: argument: <block argument> of type 'memref<8x128xf16>' at index: 0
+// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: argument: <block argument> of type 'memref<128x16xf16>' at index: 1
+// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [2, 1]
+// CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf32>' at index: 2
+// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %{{.*}} = arith.constant 0 : index
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %{{.*}} = arith.constant 128 : index
+// CHECK-NEXT: sg_map for result #0: Not assigned.
+// CHECK-NEXT: op    : %{{.*}} = arith.constant 16 : index
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x128xf16> -> !xegpu.tensor_desc<8x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[T1:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<128x16xf16> -> !xegpu.tensor_desc<16x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1]
+// CHECK-NEXT: op    : %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<8x16xf32>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[T4:.*]] = xegpu.load_nd %{{.*}}  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[T5:.*]] = xegpu.load_nd %{{.*}}  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1]
+// CHECK-NEXT: op    : %[[T6:.*]] = xegpu.dpas %[[T4]], %[[T5]], %{{.*}} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[T7:.*]] = xegpu.update_nd_offset %{{.*}} : !xegpu.tensor_desc<8x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[T8:.*]] = xegpu.update_nd_offset %{{.8}} : !xegpu.tensor_desc<16x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1]
+// CHECK-NEXT: op    : scf.for
+// CHECK-NEXT: sg_map for result #0: Not assigned.
+// CHECK-NEXT: sg_map for result #1: Not assigned.
+// CHECK-NEXT: sg_map for result #2: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[T3:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
 func.func @test_for_op(%arg0: memref<8x128xf16>, %arg1: memref<128x16xf16>, %arg2: memref<8x16xf32>) {
   %c0 = arith.constant 0 : index
   %c128 = arith.constant 128 : index

>From 898b6195b8676a0e3db141cc5e6077c22d0c1ec3 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Fri, 7 Mar 2025 04:45:17 +0000
Subject: [PATCH 11/28] save work

---
 .../XeGPU/subgroup-map-propagation.mlir       | 78 ++++++++++++++++++-
 1 file changed, 77 insertions(+), 1 deletion(-)

diff --git a/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir b/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
index d9d1ac80169ea..83a17dd9d1a7e 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
@@ -413,7 +413,7 @@ func.func @test_binary_op_2(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.t
 // CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
 // CHECK-NEXT: op    : %[[T7:.*]] = xegpu.update_nd_offset %{{.*}} : !xegpu.tensor_desc<8x16xf16>
 // CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
-// CHECK-NEXT: op    : %[[T8:.*]] = xegpu.update_nd_offset %{{.8}} : !xegpu.tensor_desc<16x16xf16>
+// CHECK-NEXT: op    : %[[T8:.*]] = xegpu.update_nd_offset %{{.*}} : !xegpu.tensor_desc<16x16xf16>
 // CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1]
 // CHECK-NEXT: op    : scf.for
 // CHECK-NEXT: sg_map for result #0: Not assigned.
@@ -442,6 +442,25 @@ func.func @test_for_op(%arg0: memref<8x128xf16>, %arg1: memref<128x16xf16>, %arg
 }
 
 // -----
+// CHECK: function: test_if_op_1:
+// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<8x16xf16>' at index: 0
+// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<16x16xf16>' at index: 1
+// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [2, 1]
+// CHECK-NEXT: argument: <block argument> of type 'i1' at index: 2
+// CHECK-NEXT: sg_map  : Not assigned.
+// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<8x16xf32>' at index: 3
+// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[T0:.*]] = xegpu.load_nd %{{.*}}  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[T3:.*]] = xegpu.load_nd %{{.*}}  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1]
+// CHECK-NEXT: op    : %[[T4:.*]] = xegpu.load_nd %{{.*}}  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1]
+// CHECK-NEXT: op    : scf.if
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1]
+// CHECK-NEXT: op    : %[[T2:.*]] = xegpu.dpas %[[T0]], %{{.*}} : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
 func.func @test_if_op_1(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: i1, %arg3: !xegpu.tensor_desc<8x16xf32>) {
   %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
   %1 = scf.if %arg2 -> (vector<16x16xf16>) {
@@ -457,6 +476,27 @@ func.func @test_if_op_1(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tenso
 }
 
 // -----
+// CHECK: function: test_if_op_2:
+// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<8x16xf16>' at index: 0
+// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<16x16xf16>' at index: 1
+// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: argument: <block argument> of type 'i1' at index: 2
+// CHECK-NEXT: sg_map  : Not assigned.
+// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<8x16xf32>' at index: 3
+// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<16x16xf16>' at index: 4
+// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[T0:.*]] = xegpu.load_nd %{{.*}}  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[T3:.*]] = xegpu.load_nd %{{.*}}  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[T4:.*]] = xegpu.load_nd %{{.*}}  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : scf.if
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[T2:.*]] = xegpu.dpas %[[T0]], %{{.*}} : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
 func.func @test_if_op_2(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: i1, %arg3: !xegpu.tensor_desc<8x16xf32>, %arg4: !xegpu.tensor_desc<16x16xf16>) {
   %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
   %1 = scf.if %arg2 -> (vector<16x16xf16>) {
@@ -473,6 +513,25 @@ func.func @test_if_op_2(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tenso
 }
 
 // -----
+// CHECK: function: test_if_op_3:
+// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<8x16xf16>' at index: 0
+// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<16x16xf16>' at index: 1
+// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [2, 1]
+// CHECK-NEXT: argument: <block argument> of type 'vector<16x16xf16>' at index: 2
+// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [2, 1]
+// CHECK-NEXT: argument: <block argument> of type 'i1' at index: 3
+// CHECK-NEXT: sg_map  : Not assigned.
+// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<8x16xf32>' at index: 4
+// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[T0:.*]] = xegpu.load_nd %{{.*}}  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[T3:.*]] = xegpu.load_nd %{{.*}}  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1]
+// CHECK-NEXT: op    : scf.if
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1]
+// CHECK-NEXT: op    : %[[T2:.*]] = xegpu.dpas %[[T0]], %{{.*}} : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
 func.func @test_if_op_3(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: vector<16x16xf16>, %arg3: i1, %arg4: !xegpu.tensor_desc<8x16xf32>) {
   %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
   %1 = scf.if %arg3 -> (vector<16x16xf16>) {
@@ -487,6 +546,15 @@ func.func @test_if_op_3(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tenso
 }
 
 // -----
+// CHECK: function: test_vector_reduction_1:
+// CHECK-NEXT: argument: <block argument> of type 'vector<16x16xf32>' at index: 0
+// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<16xf32>' at index: 1
+// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<16xf32>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[T0:.*]] = vector.multi_reduction <add>, %{{.*}}, %[[CST]] [0] : vector<16x16xf32> to vector<16xf32>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
 func.func @test_vector_reduction_1(%arg0: vector<16x16xf32>, %arg1: !xegpu.tensor_desc<16xf32>) {
   %cst = arith.constant dense<0.000000e+00> : vector<16xf32>
   %0 = vector.multi_reduction <add>, %arg0, %cst [0] : vector<16x16xf32> to vector<16xf32>
@@ -495,6 +563,14 @@ func.func @test_vector_reduction_1(%arg0: vector<16x16xf32>, %arg1: !xegpu.tenso
 }
 
 // -----
+// CHECK: argument: <block argument> of type 'vector<16x16xf32>' at index: 0
+// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<16xf32>' at index: 1
+// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<16xf32>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[T0:.*]] = vector.multi_reduction <add>, %{{.*}}, %[[CST]] [1] : vector<16x16xf32> to vector<16xf32>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
 func.func @test_vector_reduction_2(%arg0: vector<16x16xf32>, %arg1: !xegpu.tensor_desc<16xf32>) {
   %cst = arith.constant dense<0.000000e+00> : vector<16xf32>
   %0 = vector.multi_reduction <add>, %arg0, %cst [1] : vector<16x16xf32> to vector<16xf32>

>From 83404020a7459ca2c99cfb8e15027c2f31968213 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Fri, 7 Mar 2025 19:55:40 +0000
Subject: [PATCH 12/28] save work

---
 .../Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp    | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index de65f285d75b4..aba0586f30cfd 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -238,13 +238,13 @@ class SGMapPropagation : public SparseBackwardDataFlowAnalysis<SGMapLattice> {
   LogicalResult visitOperation(Operation *op, ArrayRef<SGMapLattice *> operands,
                                ArrayRef<const SGMapLattice *> results) override;
 
-  void visitBranchOperand(OpOperand &operand) override{};
+  void visitBranchOperand(OpOperand &operand) override {};
 
-  void visitCallOperand(OpOperand &operand) override{};
+  void visitCallOperand(OpOperand &operand) override {};
 
   void visitExternalCall(CallOpInterface call,
                          ArrayRef<SGMapLattice *> operands,
-                         ArrayRef<const SGMapLattice *> results) override{};
+                         ArrayRef<const SGMapLattice *> results) override {};
 
   void setToExitState(SGMapLattice *lattice) override {
     (void)lattice->meet(SGMap());

>From 549e8781d8899f7056a19548a3c6848639be8587 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Fri, 7 Mar 2025 20:07:39 +0000
Subject: [PATCH 13/28] save work

---
 mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index aba0586f30cfd..d4618bf0e24d6 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -144,7 +144,7 @@ SGMap SGMap::getTransposedLayout(ArrayRef<int64_t> permutation) const {
     newLayout.layout.push_back(layout.layout[idx]);
     newData.layout.push_back(data.layout[idx]);
   }
-  return SGMap(newLayout, data);
+  return SGMap(newLayout, newData);
 }
 
 ///===----------------------------------------------------------------------===///

>From abca365b5bf0851b8a31885395ce576df15b1bf3 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Fri, 7 Mar 2025 20:12:56 +0000
Subject: [PATCH 14/28] save work

---
 .../Dialect/XeGPU/subgroup-map-propagation.mlir    | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir b/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
index 83a17dd9d1a7e..92e37a20bf555 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
@@ -58,7 +58,7 @@ func.func @test_dpas_op_2(%arg0: vector<8x32xi8>, %arg1: vector<32x16xi8>, %arg2
 // CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf16>' at index: 0
 // CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
 // CHECK-NEXT: argument: <block argument> of type 'memref<16x16xf16>' at index: 1
-// CHECK-NEXT: sg_map  : wi_layout: [16, 1], wi_data: [2, 1]
+// CHECK-NEXT: sg_map  : wi_layout: [16, 1], wi_data: [1, 2]
 // CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf32>' at index: 2
 // CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
 // CHECK-NEXT: op    : %{{.*}} = arith.constant 0 : index
@@ -68,7 +68,7 @@ func.func @test_dpas_op_2(%arg0: vector<8x32xi8>, %arg1: vector<32x16xi8>, %arg2
 // CHECK-NEXT: op    : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
 // CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
 // CHECK-NEXT: op    : %[[T1:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [16, 1], wi_data: [2, 1]
+// CHECK-NEXT: sg_map for result #0: wi_layout: [16, 1], wi_data: [1, 2]
 // CHECK-NEXT: op    : %[[T2:.*]] = xegpu.load_nd %[[T0]]  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
 // CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
 // CHECK-NEXT: op    : %[[T3:.*]] = xegpu.load_nd %[[T1]] <{transpose = array<i64: 1, 0>}> : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
@@ -94,7 +94,7 @@ func.func @test_transpose_op_1(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>
 // CHECK: argument: <block argument> of type 'memref<8x16xf16>' at index: 0
 // CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
 // CHECK-NEXT: argument: <block argument> of type 'memref<16x16xf16>' at index: 1
-// CHECK-NEXT: sg_map  : wi_layout: [16, 1], wi_data: [2, 1]
+// CHECK-NEXT: sg_map  : wi_layout: [16, 1], wi_data: [1, 2]
 // CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf32>' at index: 2
 // CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
 // CHECK-NEXT: op    : %{{.*}} = arith.constant 0 : index
@@ -104,11 +104,11 @@ func.func @test_transpose_op_1(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>
 // CHECK-NEXT: op    : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
 // CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
 // CHECK-NEXT: op    : %[[T1:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [16, 1], wi_data: [2, 1]
+// CHECK-NEXT: sg_map for result #0: wi_layout: [16, 1], wi_data: [1, 2]
 // CHECK-NEXT: op    : %[[T2:.*]] = xegpu.load_nd %[[T0]]  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
 // CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
 // CHECK-NEXT: op    : %[[T3:.*]] = xegpu.load_nd %[[T1]]  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [16, 1], wi_data: [2, 1]
+// CHECK-NEXT: sg_map for result #0: wi_layout: [16, 1], wi_data: [1, 2]
 // CHECK-NEXT: op    : %[[T4:.*]] = vector.transpose %[[T3]], [1, 0] : vector<16x16xf16> to vector<16x16xf16>
 // CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1]
 // CHECK-NEXT: op    : %[[T5:.*]] = xegpu.dpas %[[T2]], %[[T4]], %[[CST]] : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
@@ -158,7 +158,7 @@ func.func @test_extf_truncf_op(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegp
 // CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf16>' at index: 0
 // CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
 // CHECK-NEXT: argument: <block argument> of type 'memref<256xf16>' at index: 1
-// CHECK-NEXT: sg_map  : wi_layout: [16, 1], wi_data: [2, 1]
+// CHECK-NEXT: sg_map  : wi_layout: [16, 1], wi_data: [1, 2]
 // CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf32>' at index: 2
 // CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
 // CHECK-NEXT: op    : %{{.*}} = arith.constant 0 : index
@@ -172,7 +172,7 @@ func.func @test_extf_truncf_op(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegp
 // CHECK-NEXT: op    : %[[CST0:.*]] = arith.constant dense<true> : vector<16xi1>
 // CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
 // CHECK-NEXT: op    : %[[T2:.*]] = xegpu.create_tdesc %{{.*}}, %[[CST]] : memref<256xf16>, vector<16xindex> -> !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [16, 1], wi_data: [2, 1]
+// CHECK-NEXT: sg_map for result #0: wi_layout: [16, 1], wi_data: [1, 2]
 // CHECK-NEXT: op    : %[[T3:.*]] = xegpu.load %[[T2]], %[[CST0]] <{transpose}> : !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>>, vector<16xi1> -> vector<16x16xf16>
 // CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1]
 // CHECK-NEXT: op    : %[[T4:.*]] = xegpu.dpas %[[T1]], %[[T3]] : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>

>From 5bd5db50d8bdd6c3802e270741b9e0e1d744bd34 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Fri, 7 Mar 2025 23:14:23 +0000
Subject: [PATCH 15/28] save work

---
 .../Transforms/XeGPUSubgroupDistribute.cpp    | 113 ++++++++++++++----
 .../XeGPU/subgroup-map-propagation.mlir       |  60 +++++-----
 2 files changed, 117 insertions(+), 56 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index d4618bf0e24d6..7136993d0e968 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
 #include "mlir/Dialect/XeGPU/Transforms/Passes.h"
 #include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
+#include "mlir/IR/Builders.h"
 #include "llvm/Support/raw_ostream.h"
 
 namespace mlir {
@@ -157,6 +158,8 @@ struct SGMapLattice : public Lattice<SGMap> {
   using Lattice::Lattice;
 };
 
+/// Helper Functions
+
 /// Helper Function to get the expected layouts for DPAS operands.
 static SGMap getSGMapForDPASOperand(Type operandTy, unsigned operandNum) {
   int packingFactorForB = packedBSizeInBits / operandTy.getIntOrFloatBitWidth();
@@ -174,16 +177,33 @@ static SGMap getSGMapForDPASOperand(Type operandTy, unsigned operandNum) {
 /// However, the minimum granularity of data access per work item is 16-bits.
 /// So, if the bitwidth of the type is less than 16, we need to pack the data to
 /// 16-bits.
-static SGMap getDefaultSgMap(Type ty) {
-  int packingFactor = 1;
-  if (ty.getIntOrFloatBitWidth() < packedASizeInBits)
-    packingFactor = packedBSizeInBits / ty.getIntOrFloatBitWidth();
-  return SGMap(WiLayout({1, subgroupSize}), WiData({1, packingFactor}));
+// static SGMap getDefaultSgMap(Type ty, unsigned rank) {
+//   int packingFactor = 1;
+//   if (ty.getIntOrFloatBitWidth() < packedASizeInBits)
+//     packingFactor = packedBSizeInBits / ty.getIntOrFloatBitWidth();
+//   return SGMap(WiLayout({1, subgroupSize}), WiData({1, packingFactor}));
+// }
+
+/// Helper Function to get the default layout for uniform values like constants.
+static SGMap getDefaultSgMap(unsigned rank) {
+  assert((rank == 1 || rank == 2) && "Expected 0D or 1D vector.");
+  if (rank == 1)
+    return SGMap(WiLayout({subgroupSize}), WiData({1}));
+  return SGMap(WiLayout({1, subgroupSize}), WiData({1, 1}));
 }
 
-/// Helper Function to get the default layout representing constants.
-static SGMap getDefaultSgMap() {
-  return SGMap(WiLayout({1, subgroupSize}), WiData({1, 1}));
+static SGMap getDefaultSgMap(VectorType vectorTy) {
+  /// Expecting a 1D or 2D vector.
+  assert((vectorTy.getRank() == 1 || vectorTy.getRank() == 2) &&
+         "Expected 1D or 2D vector.");
+  /// If the rank is 1, then return default layout for 1D vector.
+  if (vectorTy.getRank() == 1)
+    return getDefaultSgMap(1);
+  int packingFactor = 1;
+  auto bitwidth = vectorTy.getElementType().getIntOrFloatBitWidth();
+  if (bitwidth < packedASizeInBits)
+    packingFactor = packedBSizeInBits / bitwidth;
+  return SGMap(WiLayout({1, subgroupSize}), WiData({1, packingFactor}));
 }
 
 ///===----------------------------------------------------------------------===///
@@ -230,6 +250,14 @@ class SGMapPropagation : public SparseBackwardDataFlowAnalysis<SGMapLattice> {
                          ArrayRef<SGMapLattice *> operands,
                          ArrayRef<const SGMapLattice *> results);
 
+  void visitUpdateNdOffsetOp(xegpu::UpdateNdOffsetOp updateNdOffset,
+                             ArrayRef<SGMapLattice *> operands,
+                             ArrayRef<const SGMapLattice *> results);
+
+  void visitVectorMultiReductionOp(vector::MultiDimReductionOp reduction,
+                                   ArrayRef<SGMapLattice *> operands,
+                                   ArrayRef<const SGMapLattice *> results);
+
 public:
   SGMapPropagation(DataFlowSolver &solver, SymbolTableCollection &symbolTable)
       : SparseBackwardDataFlowAnalysis(solver, symbolTable) {}
@@ -274,6 +302,10 @@ SGMapPropagation::visitOperation(Operation *op,
     visitCreateDescOp(createDesc, operands, results);
   else if (auto storeScatter = dyn_cast<xegpu::StoreScatterOp>(op))
     visitStoreScatterOp(storeScatter, operands, results);
+  else if (auto updateNdOffset = dyn_cast<xegpu::UpdateNdOffsetOp>(op))
+    visitUpdateNdOffsetOp(updateNdOffset, operands, results);
+  else if (auto reduction = dyn_cast<vector::MultiDimReductionOp>(op))
+    visitVectorMultiReductionOp(reduction, operands, results);
   /// All other ops
   else {
     for (const SGMapLattice *r : results) {
@@ -293,6 +325,43 @@ SGMapPropagation::visitOperation(Operation *op,
   return success();
 }
 
+void SGMapPropagation::visitVectorMultiReductionOp(
+    vector::MultiDimReductionOp reduction, ArrayRef<SGMapLattice *> operands,
+    ArrayRef<const SGMapLattice *> results) {
+  /// The layout of the result must be present.
+  auto resultLayout = results[0]->getValue();
+  if (!resultLayout.isAssigned())
+    return;
+  /// We only consider 2D -> 1D reductions at this point.
+  assert(resultLayout.getLayout().size() == 1 &&
+         "Expected 1D layout for reduction result.");
+  /// Given that the result is 1D, the layout of the operand should be 2D with
+  /// default layout.
+  auto operandLayout = getDefaultSgMap(2);
+  operandLayout.print(llvm::outs());
+  propagateIfChanged(operands[0], operands[0]->meet(operandLayout));
+  /// Accumulator should have the same layout as the result.
+  propagateIfChanged(operands[1], operands[1]->meet(resultLayout));
+}
+
+/// Propagate the layout of the result tensor to the source tensor descriptor in
+/// UpdateNdOffsetOp.
+void SGMapPropagation::visitUpdateNdOffsetOp(
+    xegpu::UpdateNdOffsetOp updateNdOffset, ArrayRef<SGMapLattice *> operands,
+    ArrayRef<const SGMapLattice *> results) {
+  /// The layout of the result must be present.
+  auto resultLayout = results[0]->getValue();
+  if (!resultLayout.isAssigned())
+    return;
+  /// Propagate the layout to the source operand.
+  propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
+  /// For all other operands use 1D default layout.
+  SGMap layout = getDefaultSgMap(1);
+  for (size_t i = 1; i < operands.size(); ++i) {
+    propagateIfChanged(operands[i], operands[i]->meet(layout));
+  }
+}
+
 /// Set the layouts for DPAS A, B, and C operands.
 void SGMapPropagation::visitDpasOp(xegpu::DpasOp dpas,
                                    ArrayRef<SGMapLattice *> operands,
@@ -314,8 +383,7 @@ void SGMapPropagation::visitDpasOp(xegpu::DpasOp dpas,
 void SGMapPropagation::visitStoreNdOp(xegpu::StoreNdOp store,
                                       ArrayRef<SGMapLattice *> operands,
                                       ArrayRef<const SGMapLattice *> results) {
-  auto storeLayout =
-      getDefaultSgMap(store.getTensorDescType().getElementType());
+  auto storeLayout = getDefaultSgMap(store.getValueType());
   /// Both operands should have the same layout
   for (SGMapLattice *operand : operands) {
     propagateIfChanged(operand, operand->meet(storeLayout));
@@ -334,7 +402,6 @@ void SGMapPropagation::visitLoadNdOp(xegpu::LoadNdOp load,
   SGMap tensorDescLayout = valueLayout;
   /// LoadNdOp has the transpose effect. However, at the stage of this analyis
   /// this effect is not expected and should be abstracted away. Emit a warning.
-  /// TODO: Handle this case properly when `order` is introduced in the sg_map.
   if (auto transpose = load.getTranspose()) {
     load.emitWarning("Transpose effect is not expected for LoadNdOp");
     tensorDescLayout = valueLayout.getTransposedLayout(transpose.value());
@@ -409,15 +476,12 @@ void SGMapPropagation::visitLoadGatherOp(
     /// LoadGatherOp has the transpose effect. However, at the stage of this
     /// analyis this effect is not expected and should be abstracted away. Emit
     /// a warning.
-    /// TODO: Handle this case properly when `order` is introduced in the
-    /// sg_map.
     load.emitWarning("Transpose effect is not expected for LoadGatherOp");
     tensorDescLayout = valueLayout.getTransposedLayout({1, 0});
   } else
     tensorDescLayout = valueLayout;
-  /// Mask operand should have the same layout as the value but with wi_data =
-  /// [1, 1]
-  SGMap maskLayout = SGMap(valueLayout.getLayout(), WiData({1, 1}));
+  /// Mask operand should have 1D default layout.
+  auto maskLayout = getDefaultSgMap(1);
   /// Propagate the new layout to the tensor descriptor operand.
   propagateIfChanged(operands[0], operands[0]->meet(tensorDescLayout));
   /// Propagate the new layout to the mask operand.
@@ -434,8 +498,8 @@ void SGMapPropagation::visitCreateNdDescOp(
     return;
   /// Propagate the layout to the source operand.
   propagateIfChanged(operands[0], operands[0]->meet(descLayout));
-  /// For all other operands propagate the descriptor layout.
-  SGMap layout = getDefaultSgMap();
+  /// For all other operands use 1D default layout.
+  SGMap layout = getDefaultSgMap(1);
   for (size_t i = 1; i < operands.size(); ++i) {
     propagateIfChanged(operands[i], operands[i]->meet(layout));
   }
@@ -452,8 +516,8 @@ void SGMapPropagation::visitCreateDescOp(
     return;
   /// Propagate the layout to the source operand.
   propagateIfChanged(operands[0], operands[0]->meet(descLayout));
-  /// For offset operand propagate the default layout.
-  SGMap layout = getDefaultSgMap();
+  /// For offset operand propagate 1D default layout.
+  SGMap layout = getDefaultSgMap(1);
   propagateIfChanged(operands[1], operands[1]->meet(layout));
 }
 
@@ -462,15 +526,12 @@ void SGMapPropagation::visitCreateDescOp(
 void SGMapPropagation::visitStoreScatterOp(
     xegpu::StoreScatterOp storeScatter, ArrayRef<SGMapLattice *> operands,
     ArrayRef<const SGMapLattice *> results) {
-  auto valueLayout =
-      getDefaultSgMap(storeScatter.getTensorDescType().getElementType());
+  auto valueLayout = getDefaultSgMap(storeScatter.getValueType());
   SGMap storeScatterLayout;
   if (storeScatter.getTranspose()) {
     /// StoreScatteOp allows transpose effect. However, at the stage of this
     /// analyis this effect is not expected and should be abstracted away. Emit
     /// a warning.
-    /// TODO: Handle this case properly when `order` is introduced in the
-    /// sg_map.
     storeScatter.emitWarning(
         "Transpose effect is not expected for StoreScatterOp");
     storeScatterLayout = valueLayout.getTransposedLayout({1, 0});
@@ -480,8 +541,8 @@ void SGMapPropagation::visitStoreScatterOp(
   propagateIfChanged(operands[0], operands[0]->meet(valueLayout));
   /// Propagate the tensor descriptor layout.
   propagateIfChanged(operands[1], operands[1]->meet(storeScatterLayout));
-  /// Use default layout for mask operand.
-  auto maskLayout = getDefaultSgMap();
+  /// Use default 1D layout for mask operand.
+  auto maskLayout = getDefaultSgMap(1);
   propagateIfChanged(operands[2], operands[2]->meet(maskLayout));
 }
 
diff --git a/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir b/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
index 92e37a20bf555..7d8bd20391463 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
@@ -2,7 +2,7 @@
 
 // CHECK: function: test_dpas_op_1:
 // CHECK: op    : %{{.*}} = arith.constant 0 : index
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
 // CHECK-NEXT: op    : %{{.*}} = arith.constant dense<0.000000e+00> : vector<8x16xf32>
 // CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
 // CHECK-NEXT: op    : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
@@ -40,7 +40,7 @@ func.func @test_dpas_op_1(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %ar
 // CHECK-NEXT: argument: <block argument> of type 'memref<8x16xi32>' at index: 2
 // CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
 // CHECK-NEXT: op    : %{{.*}} = arith.constant 0 : index
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
 // CHECK-NEXT: op    : %[[T0:.*]] = xegpu.dpas %{{.*}} : vector<8x32xi8>, vector<32x16xi8> -> vector<8x16xi32>
 // CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
 // CHECK-NEXT: op    : %[[T1:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xi32> -> !xegpu.tensor_desc<8x16xi32>
@@ -62,7 +62,7 @@ func.func @test_dpas_op_2(%arg0: vector<8x32xi8>, %arg1: vector<32x16xi8>, %arg2
 // CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf32>' at index: 2
 // CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
 // CHECK-NEXT: op    : %{{.*}} = arith.constant 0 : index
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
 // CHECK-NEXT: op    : %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<8x16xf32>
 // CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
 // CHECK-NEXT: op    : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
@@ -98,7 +98,7 @@ func.func @test_transpose_op_1(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>
 // CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf32>' at index: 2
 // CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
 // CHECK-NEXT: op    : %{{.*}} = arith.constant 0 : index
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
 // CHECK-NEXT: op    : %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<8x16xf32>
 // CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
 // CHECK-NEXT: op    : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
@@ -162,15 +162,15 @@ func.func @test_extf_truncf_op(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegp
 // CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf32>' at index: 2
 // CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
 // CHECK-NEXT: op    : %{{.*}} = arith.constant 0 : index
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
 // CHECK-NEXT: op    : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
 // CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
 // CHECK-NEXT: op    : %[[T1:.*]] = xegpu.load_nd %[[T0]]  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
 // CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
 // CHECK-NEXT: op    : %[[CST:.*]] = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
 // CHECK-NEXT: op    : %[[CST0:.*]] = arith.constant dense<true> : vector<16xi1>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
 // CHECK-NEXT: op    : %[[T2:.*]] = xegpu.create_tdesc %{{.*}}, %[[CST]] : memref<256xf16>, vector<16xindex> -> !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>>
 // CHECK-NEXT: sg_map for result #0: wi_layout: [16, 1], wi_data: [1, 2]
 // CHECK-NEXT: op    : %[[T3:.*]] = xegpu.load %[[T2]], %[[CST0]] <{transpose}> : !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>>, vector<16xi1> -> vector<16x16xf16>
@@ -195,17 +195,17 @@ func.func @test_load_gather_op_1(%arg0: memref<8x16xf16>, %arg1: memref<256xf16>
 
 // -----
 // CHECK: argument: <block argument> of type 'memref<256xf32>' at index: 0
-// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map  : wi_layout: [16], wi_data: [1]
 // CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<16xf32>' at index: 1
-// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map  : wi_layout: [16], wi_data: [1]
 // CHECK-NEXT: op    : %[[CST:.*]] = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
 // CHECK-NEXT: op    : %[[CST0:.*]] = arith.constant dense<true> : vector<16xi1>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
 // CHECK-NEXT: op    : %[[T0:.*]] = xegpu.create_tdesc %{{.*}}, %[[CST]] : memref<256xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
 // CHECK-NEXT: op    : %[[T1]] = xegpu.load %[[T0]], %[[CST0]]  : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> -> vector<16xf32>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
 func.func @test_load_gather_op_2(%arg0: memref<256xf32>, %arg1: !xegpu.tensor_desc<16xf32>) {
   %cst = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
   %cst_0 = arith.constant dense<true> : vector<16xi1>
@@ -221,9 +221,9 @@ func.func @test_load_gather_op_2(%arg0: memref<256xf32>, %arg1: !xegpu.tensor_de
 // CHECK-NEXT: op    : %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<8x16xf32>
 // CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
 // CHECK-NEXT: op    : %[[CST0:.*]] = arith.constant dense<true> : vector<16xi1>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
 // CHECK-NEXT: op    : %[[CST1:.*]] = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
 // CHECK-NEXT: op    : %[[T0:.*]] = xegpu.create_tdesc %{{.*}}, %[[CST1]] : memref<128xf32>, vector<16xindex> -> !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size = 8 : i64>>
 // CHECK-NEXT: sg_map for result #0: wi_layout: [16, 1], wi_data: [1, 1]
 func.func @test_store_scatter_op_1(%arg0: memref<128xf32>) {
@@ -237,15 +237,15 @@ func.func @test_store_scatter_op_1(%arg0: memref<128xf32>) {
 
 // -----
 // CHECK: argument: <block argument> of type 'vector<16xf32>' at index: 0
-// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map  : wi_layout: [16], wi_data: [1]
 // CHECK-NEXT: argument: <block argument> of type 'memref<256xf32>' at index: 1
-// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map  : wi_layout: [16], wi_data: [1]
 // CHECK-NEXT: op    : %[[CST:.*]] = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
 // CHECK-NEXT: op    : %[[CST1:.*]] = arith.constant dense<true> : vector<16xi1>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
 // CHECK-NEXT: op    : %[[T0:.*]] = xegpu.create_tdesc %{{.*}}, %[[CST]] : memref<256xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
 func.func @test_store_scatter_op_2(%arg0: vector<16xf32>, %arg1: memref<256xf32>) {
   %cst = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
   %cst_0 = arith.constant dense<true> : vector<16xi1>
@@ -262,7 +262,7 @@ func.func @test_store_scatter_op_2(%arg0: vector<16xf32>, %arg1: memref<256xf32>
 // CHECK-NEXT: argument: <block argument> of type 'memref<8x16xi32>' at index: 2
 // CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
 // CHECK-NEXT: op    : %{{.*}} = arith.constant 0 : index
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
 // CHECK-NEXT: op    : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xi16> -> !xegpu.tensor_desc<8x16xi16>
 // CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
 // CHECK-NEXT: op    : %[[T1:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<32x16xi8> -> !xegpu.tensor_desc<32x16xi8>
@@ -298,7 +298,7 @@ func.func @test_vector_bitcast_op_1(%arg0: memref<8x16xi16>, %arg1: memref<32x16
 // CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf32>' at index: 2
 // CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
 // CHECK-NEXT: op    : %{{.*}} = arith.constant 0 : index
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
 // CHECK-NEXT: op    : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x32xi8> -> !xegpu.tensor_desc<8x32xi8>
 // CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 2]
 // CHECK-NEXT: op    : %[[T1:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<16x32xi8> -> !xegpu.tensor_desc<16x32xi8>
@@ -394,11 +394,11 @@ func.func @test_binary_op_2(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.t
 // CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf32>' at index: 2
 // CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
 // CHECK-NEXT: op    : %{{.*}} = arith.constant 0 : index
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
 // CHECK-NEXT: op    : %{{.*}} = arith.constant 128 : index
 // CHECK-NEXT: sg_map for result #0: Not assigned.
 // CHECK-NEXT: op    : %{{.*}} = arith.constant 16 : index
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
 // CHECK-NEXT: op    : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x128xf16> -> !xegpu.tensor_desc<8x16xf16>
 // CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
 // CHECK-NEXT: op    : %[[T1:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<128x16xf16> -> !xegpu.tensor_desc<16x16xf16>
@@ -550,11 +550,11 @@ func.func @test_if_op_3(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tenso
 // CHECK-NEXT: argument: <block argument> of type 'vector<16x16xf32>' at index: 0
 // CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
 // CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<16xf32>' at index: 1
-// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map  : wi_layout: [16], wi_data: [1]
 // CHECK-NEXT: op    : %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<16xf32>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
 // CHECK-NEXT: op    : %[[T0:.*]] = vector.multi_reduction <add>, %{{.*}}, %[[CST]] [0] : vector<16x16xf32> to vector<16xf32>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
 func.func @test_vector_reduction_1(%arg0: vector<16x16xf32>, %arg1: !xegpu.tensor_desc<16xf32>) {
   %cst = arith.constant dense<0.000000e+00> : vector<16xf32>
   %0 = vector.multi_reduction <add>, %arg0, %cst [0] : vector<16x16xf32> to vector<16xf32>
@@ -566,11 +566,11 @@ func.func @test_vector_reduction_1(%arg0: vector<16x16xf32>, %arg1: !xegpu.tenso
 // CHECK: argument: <block argument> of type 'vector<16x16xf32>' at index: 0
 // CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
 // CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<16xf32>' at index: 1
-// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map  : wi_layout: [16], wi_data: [1]
 // CHECK-NEXT: op    : %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<16xf32>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
 // CHECK-NEXT: op    : %[[T0:.*]] = vector.multi_reduction <add>, %{{.*}}, %[[CST]] [1] : vector<16x16xf32> to vector<16xf32>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
 func.func @test_vector_reduction_2(%arg0: vector<16x16xf32>, %arg1: !xegpu.tensor_desc<16xf32>) {
   %cst = arith.constant dense<0.000000e+00> : vector<16xf32>
   %0 = vector.multi_reduction <add>, %arg0, %cst [1] : vector<16x16xf32> to vector<16xf32>

>From 5820c15c393a5b29e5aa8533627de51e480b6275 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Tue, 11 Mar 2025 18:54:34 +0000
Subject: [PATCH 16/28] save work

---
 .../Transforms/XeGPUSubgroupDistribute.cpp    | 56 +++++--------
 .../XeGPU/subgroup-map-propagation.mlir       | 78 ++++++++++---------
 2 files changed, 63 insertions(+), 71 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 7136993d0e968..14a9f5954edcf 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/XeGPU/Transforms/Passes.h"
 #include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/Support/LLVM.h"
 #include "llvm/Support/raw_ostream.h"
 
 namespace mlir {
@@ -242,10 +243,6 @@ class SGMapPropagation : public SparseBackwardDataFlowAnalysis<SGMapLattice> {
                             ArrayRef<SGMapLattice *> operands,
                             ArrayRef<const SGMapLattice *> results);
 
-  void visitCreateNdDescOp(xegpu::CreateNdDescOp createNdDesc,
-                           ArrayRef<SGMapLattice *> operands,
-                           ArrayRef<const SGMapLattice *> results);
-
   void visitCreateDescOp(xegpu::CreateDescOp createDesc,
                          ArrayRef<SGMapLattice *> operands,
                          ArrayRef<const SGMapLattice *> results);
@@ -296,8 +293,6 @@ SGMapPropagation::visitOperation(Operation *op,
     visitVectorBitcastOp(bitcast, operands, results);
   else if (auto loadGather = dyn_cast<xegpu::LoadGatherOp>(op))
     visitLoadGatherOp(loadGather, operands, results);
-  else if (auto createNdDesc = dyn_cast<xegpu::CreateNdDescOp>(op))
-    visitCreateNdDescOp(createNdDesc, operands, results);
   else if (auto createDesc = dyn_cast<xegpu::CreateDescOp>(op))
     visitCreateDescOp(createDesc, operands, results);
   else if (auto storeScatter = dyn_cast<xegpu::StoreScatterOp>(op))
@@ -306,6 +301,10 @@ SGMapPropagation::visitOperation(Operation *op,
     visitUpdateNdOffsetOp(updateNdOffset, operands, results);
   else if (auto reduction = dyn_cast<vector::MultiDimReductionOp>(op))
     visitVectorMultiReductionOp(reduction, operands, results);
+  /// No need to propagate the layout to operands in CreateNdDescOp because they
+  /// are scalars (offsets, sizes, etc.).
+  else if (auto createNdDesc = dyn_cast<xegpu::CreateNdDescOp>(op))
+    return success();
   /// All other ops
   else {
     for (const SGMapLattice *r : results) {
@@ -355,11 +354,6 @@ void SGMapPropagation::visitUpdateNdOffsetOp(
     return;
   /// Propagate the layout to the source operand.
   propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
-  /// For all other operands use 1D default layout.
-  SGMap layout = getDefaultSgMap(1);
-  for (size_t i = 1; i < operands.size(); ++i) {
-    propagateIfChanged(operands[i], operands[i]->meet(layout));
-  }
 }
 
 /// Set the layouts for DPAS A, B, and C operands.
@@ -403,7 +397,8 @@ void SGMapPropagation::visitLoadNdOp(xegpu::LoadNdOp load,
   /// LoadNdOp has the transpose effect. However, at the stage of this analyis
   /// this effect is not expected and should be abstracted away. Emit a warning.
   if (auto transpose = load.getTranspose()) {
-    load.emitWarning("Transpose effect is not expected for LoadNdOp");
+    load.emitWarning("Transpose effect is not expected for LoadNdOp at "
+                     "SGMapPropagation stage.");
     tensorDescLayout = valueLayout.getTransposedLayout(transpose.value());
   }
   /// Propagate the new layout to the tensor descriptor operand.
@@ -476,7 +471,8 @@ void SGMapPropagation::visitLoadGatherOp(
     /// LoadGatherOp has the transpose effect. However, at the stage of this
     /// analyis this effect is not expected and should be abstracted away. Emit
     /// a warning.
-    load.emitWarning("Transpose effect is not expected for LoadGatherOp");
+    load.emitWarning("Transpose effect is not expected for LoadGatherOp at "
+                     "SGMapPropagation stage.");
     tensorDescLayout = valueLayout.getTransposedLayout({1, 0});
   } else
     tensorDescLayout = valueLayout;
@@ -488,24 +484,7 @@ void SGMapPropagation::visitLoadGatherOp(
   propagateIfChanged(operands[1], operands[1]->meet(maskLayout));
 }
 
-/// Propagate the layout of the descriptor to the operands in CreateNdDescOp.
-void SGMapPropagation::visitCreateNdDescOp(
-    xegpu::CreateNdDescOp createNdDesc, ArrayRef<SGMapLattice *> operands,
-    ArrayRef<const SGMapLattice *> results) {
-  auto descLayout = results[0]->getValue();
-  /// Need the layout of the descriptor to propagate to the operands.
-  if (!descLayout.isAssigned())
-    return;
-  /// Propagate the layout to the source operand.
-  propagateIfChanged(operands[0], operands[0]->meet(descLayout));
-  /// For all other operands use 1D default layout.
-  SGMap layout = getDefaultSgMap(1);
-  for (size_t i = 1; i < operands.size(); ++i) {
-    propagateIfChanged(operands[i], operands[i]->meet(layout));
-  }
-}
-
-/// Propagate the layout of the descriptor to the source and offset operands in
+/// Propagate the layout of the descriptor to the vector offset operand in
 /// CreateDescOp.
 void SGMapPropagation::visitCreateDescOp(
     xegpu::CreateDescOp createDesc, ArrayRef<SGMapLattice *> operands,
@@ -514,8 +493,6 @@ void SGMapPropagation::visitCreateDescOp(
   /// Need the layout of the descriptor to propagate to the operands.
   if (!descLayout.isAssigned())
     return;
-  /// Propagate the layout to the source operand.
-  propagateIfChanged(operands[0], operands[0]->meet(descLayout));
   /// For offset operand propagate 1D default layout.
   SGMap layout = getDefaultSgMap(1);
   propagateIfChanged(operands[1], operands[1]->meet(layout));
@@ -526,14 +503,23 @@ void SGMapPropagation::visitCreateDescOp(
 void SGMapPropagation::visitStoreScatterOp(
     xegpu::StoreScatterOp storeScatter, ArrayRef<SGMapLattice *> operands,
     ArrayRef<const SGMapLattice *> results) {
+  /// Currently, for 2D StoreScatterOp we expect that the height dimension of
+  /// the tensor descriptor is evenly divisible by the subgroup size.
+  /// TODO: Add support for other 2D shapes.
+  auto tdescShape = storeScatter.getTensorDescType().getShape();
+  if (tdescShape.size() > 1 && tdescShape[0] % subgroupSize != 0) {
+    storeScatter.emitError("Height dimension of the tensor descriptor should "
+                           "be evenly divisible by the subgroup size.");
+    return;
+  }
   auto valueLayout = getDefaultSgMap(storeScatter.getValueType());
   SGMap storeScatterLayout;
   if (storeScatter.getTranspose()) {
     /// StoreScatteOp allows transpose effect. However, at the stage of this
     /// analyis this effect is not expected and should be abstracted away. Emit
     /// a warning.
-    storeScatter.emitWarning(
-        "Transpose effect is not expected for StoreScatterOp");
+    storeScatter.emitWarning("Transpose effect is not expected for "
+                             "StoreScatterOp at SGMapPropagation stage.");
     storeScatterLayout = valueLayout.getTransposedLayout({1, 0});
   } else
     storeScatterLayout = valueLayout;
diff --git a/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir b/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
index 7d8bd20391463..1f16c0f738af9 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
@@ -1,8 +1,14 @@
 // RUN: mlir-opt -xegpu-subgroup-distribute='print-analysis-only=true' -split-input-file %s | FileCheck %s
 
-// CHECK: function: test_dpas_op_1:
-// CHECK: op    : %{{.*}} = arith.constant 0 : index
-// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
+// CHECK: function: test_dpas_f16:
+// CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf16>' at index: 0
+// CHECK-NEXT: sg_map  : Not assigned.
+// CHECK-NEXT: argument: <block argument> of type 'memref<16x16xf16>' at index: 1
+// CHECK-NEXT: sg_map  : Not assigned.
+// CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf32>' at index: 2
+// CHECK-NEXT: sg_map  : Not assigned.
+// CHECK-NEXT: op    : %{{.*}} = arith.constant 0 : index
+// CHECK-NEXT: sg_map for result #0: Not assigned.
 // CHECK-NEXT: op    : %{{.*}} = arith.constant dense<0.000000e+00> : vector<8x16xf32>
 // CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
 // CHECK-NEXT: op    : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
@@ -17,7 +23,7 @@
 // CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
 // CHECK-NEXT: op    : %{{.*}} = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
 // CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
-func.func @test_dpas_op_1(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<8x16xf32>) {
+func.func @test_dpas_f16(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<8x16xf32>) {
   %c0 = arith.constant 0 : index
   %cst = arith.constant dense<0.000000e+00> : vector<8x16xf32>
   %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
@@ -32,20 +38,20 @@ func.func @test_dpas_op_1(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %ar
 
 
 // -----
-// CHECK: function: test_dpas_op_2:
+// CHECK: function: test_dpas_i8:
 // CHECK-NEXT: argument: <block argument> of type 'vector<8x32xi8>' at index: 0
 // CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 2]
 // CHECK-NEXT: argument: <block argument> of type 'vector<32x16xi8>' at index: 1
 // CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [4, 1]
 // CHECK-NEXT: argument: <block argument> of type 'memref<8x16xi32>' at index: 2
-// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map  : Not assigned.
 // CHECK-NEXT: op    : %{{.*}} = arith.constant 0 : index
-// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
+// CHECK-NEXT: sg_map for result #0: Not assigned.
 // CHECK-NEXT: op    : %[[T0:.*]] = xegpu.dpas %{{.*}} : vector<8x32xi8>, vector<32x16xi8> -> vector<8x16xi32>
 // CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
 // CHECK-NEXT: op    : %[[T1:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xi32> -> !xegpu.tensor_desc<8x16xi32>
 // CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
-func.func @test_dpas_op_2(%arg0: vector<8x32xi8>, %arg1: vector<32x16xi8>, %arg2: memref<8x16xi32>) {
+func.func @test_dpas_i8(%arg0: vector<8x32xi8>, %arg1: vector<32x16xi8>, %arg2: memref<8x16xi32>) {
   %c0 = arith.constant 0 : index
   %0 = xegpu.dpas %arg0, %arg1 : vector<8x32xi8>, vector<32x16xi8> -> vector<8x16xi32>
   %1 = xegpu.create_nd_tdesc %arg2[%c0, %c0] : memref<8x16xi32> -> !xegpu.tensor_desc<8x16xi32>
@@ -56,13 +62,13 @@ func.func @test_dpas_op_2(%arg0: vector<8x32xi8>, %arg1: vector<32x16xi8>, %arg2
 // -----
 // CHECK: function: test_transpose_op_1:
 // CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf16>' at index: 0
-// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map  : Not assigned.
 // CHECK-NEXT: argument: <block argument> of type 'memref<16x16xf16>' at index: 1
-// CHECK-NEXT: sg_map  : wi_layout: [16, 1], wi_data: [1, 2]
+// CHECK-NEXT: sg_map  : Not assigned.
 // CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf32>' at index: 2
-// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map  : Not assigned.
 // CHECK-NEXT: op    : %{{.*}} = arith.constant 0 : index
-// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
+// CHECK-NEXT: sg_map for result #0: Not assigned.
 // CHECK-NEXT: op    : %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<8x16xf32>
 // CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
 // CHECK-NEXT: op    : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
@@ -92,13 +98,13 @@ func.func @test_transpose_op_1(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>
 
 // -----
 // CHECK: argument: <block argument> of type 'memref<8x16xf16>' at index: 0
-// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map  : Not assigned.
 // CHECK-NEXT: argument: <block argument> of type 'memref<16x16xf16>' at index: 1
-// CHECK-NEXT: sg_map  : wi_layout: [16, 1], wi_data: [1, 2]
+// CHECK-NEXT: sg_map  : Not assigned.
 // CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf32>' at index: 2
-// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map  : Not assigned.
 // CHECK-NEXT: op    : %{{.*}} = arith.constant 0 : index
-// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
+// CHECK-NEXT: sg_map for result #0: Not assigned.
 // CHECK-NEXT: op    : %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<8x16xf32>
 // CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
 // CHECK-NEXT: op    : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
@@ -156,13 +162,13 @@ func.func @test_extf_truncf_op(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegp
 // -----
 // CHECK: function: test_load_gather_op_1:
 // CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf16>' at index: 0
-// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map  : Not assigned.
 // CHECK-NEXT: argument: <block argument> of type 'memref<256xf16>' at index: 1
-// CHECK-NEXT: sg_map  : wi_layout: [16, 1], wi_data: [1, 2]
+// CHECK-NEXT: sg_map  : Not assigned.
 // CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf32>' at index: 2
-// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map  : Not assigned.
 // CHECK-NEXT: op    : %{{.*}} = arith.constant 0 : index
-// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
+// CHECK-NEXT: sg_map for result #0: Not assigned.
 // CHECK-NEXT: op    : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
 // CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
 // CHECK-NEXT: op    : %[[T1:.*]] = xegpu.load_nd %[[T0]]  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
@@ -195,7 +201,7 @@ func.func @test_load_gather_op_1(%arg0: memref<8x16xf16>, %arg1: memref<256xf16>
 
 // -----
 // CHECK: argument: <block argument> of type 'memref<256xf32>' at index: 0
-// CHECK-NEXT: sg_map  : wi_layout: [16], wi_data: [1]
+// CHECK-NEXT: sg_map  : Not assigned.
 // CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<16xf32>' at index: 1
 // CHECK-NEXT: sg_map  : wi_layout: [16], wi_data: [1]
 // CHECK-NEXT: op    : %[[CST:.*]] = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
@@ -217,7 +223,7 @@ func.func @test_load_gather_op_2(%arg0: memref<256xf32>, %arg1: !xegpu.tensor_de
 
 // -----
 // CHECK: argument: <block argument> of type 'memref<128xf32>' at index: 0
-// CHECK-NEXT: sg_map  : wi_layout: [16, 1], wi_data: [1, 1]
+// CHECK-NEXT: sg_map  : Not assigned.
 // CHECK-NEXT: op    : %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<8x16xf32>
 // CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
 // CHECK-NEXT: op    : %[[CST0:.*]] = arith.constant dense<true> : vector<16xi1>
@@ -239,7 +245,7 @@ func.func @test_store_scatter_op_1(%arg0: memref<128xf32>) {
 // CHECK: argument: <block argument> of type 'vector<16xf32>' at index: 0
 // CHECK-NEXT: sg_map  : wi_layout: [16], wi_data: [1]
 // CHECK-NEXT: argument: <block argument> of type 'memref<256xf32>' at index: 1
-// CHECK-NEXT: sg_map  : wi_layout: [16], wi_data: [1]
+// CHECK-NEXT: sg_map  : Not assigned.
 // CHECK-NEXT: op    : %[[CST:.*]] = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
 // CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
 // CHECK-NEXT: op    : %[[CST1:.*]] = arith.constant dense<true> : vector<16xi1>
@@ -256,13 +262,13 @@ func.func @test_store_scatter_op_2(%arg0: vector<16xf32>, %arg1: memref<256xf32>
 
 // -----
 // CHECK: argument: <block argument> of type 'memref<8x16xi16>' at index: 0
-// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map  : Not assigned.
 // CHECK-NEXT: argument: <block argument> of type 'memref<32x16xi8>' at index: 1
-// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [4, 1]
+// CHECK-NEXT: sg_map  : Not assigned.
 // CHECK-NEXT: argument: <block argument> of type 'memref<8x16xi32>' at index: 2
-// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map  : Not assigned.
 // CHECK-NEXT: op    : %{{.*}} = arith.constant 0 : index
-// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
+// CHECK-NEXT: sg_map for result #0: Not assigned.
 // CHECK-NEXT: op    : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xi16> -> !xegpu.tensor_desc<8x16xi16>
 // CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
 // CHECK-NEXT: op    : %[[T1:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<32x16xi8> -> !xegpu.tensor_desc<32x16xi8>
@@ -292,13 +298,13 @@ func.func @test_vector_bitcast_op_1(%arg0: memref<8x16xi16>, %arg1: memref<32x16
 
 // -----
 // CHECK: argument: <block argument> of type 'memref<8x32xi8>' at index: 0
-// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 2]
+// CHECK-NEXT: sg_map  : Not assigned.
 // CHECK-NEXT: argument: <block argument> of type 'memref<16x32xi8>' at index: 1
-// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [4, 1]
+// CHECK-NEXT: sg_map  : Not assigned.
 // CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf32>' at index: 2
-// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map  : Not assigned.
 // CHECK-NEXT: op    : %{{.*}} = arith.constant 0 : index
-// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
+// CHECK-NEXT: sg_map for result #0: Not assigned.
 // CHECK-NEXT: op    : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x32xi8> -> !xegpu.tensor_desc<8x32xi8>
 // CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 2]
 // CHECK-NEXT: op    : %[[T1:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<16x32xi8> -> !xegpu.tensor_desc<16x32xi8>
@@ -388,17 +394,17 @@ func.func @test_binary_op_2(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.t
 
 // -----
 // CHECK: argument: <block argument> of type 'memref<8x128xf16>' at index: 0
-// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map  : Not assigned.
 // CHECK-NEXT: argument: <block argument> of type 'memref<128x16xf16>' at index: 1
-// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [2, 1]
+// CHECK-NEXT: sg_map  : Not assigned.
 // CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf32>' at index: 2
-// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map  : Not assigned.
 // CHECK-NEXT: op    : %{{.*}} = arith.constant 0 : index
-// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
+// CHECK-NEXT: sg_map for result #0: Not assigned.
 // CHECK-NEXT: op    : %{{.*}} = arith.constant 128 : index
 // CHECK-NEXT: sg_map for result #0: Not assigned.
 // CHECK-NEXT: op    : %{{.*}} = arith.constant 16 : index
-// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
+// CHECK-NEXT: sg_map for result #0: Not assigned.
 // CHECK-NEXT: op    : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x128xf16> -> !xegpu.tensor_desc<8x16xf16>
 // CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
 // CHECK-NEXT: op    : %[[T1:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<128x16xf16> -> !xegpu.tensor_desc<16x16xf16>

>From f7b7bdfc5fb2fbf64eb4db05c0918e912ce30825 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Tue, 11 Mar 2025 19:08:55 +0000
Subject: [PATCH 17/28] save work

---
 .../XeGPU/subgroup-map-propagation.mlir       | 104 +++++++-----------
 1 file changed, 41 insertions(+), 63 deletions(-)

diff --git a/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir b/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
index 1f16c0f738af9..1ae4348af33e6 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
@@ -60,7 +60,7 @@ func.func @test_dpas_i8(%arg0: vector<8x32xi8>, %arg1: vector<32x16xi8>, %arg2:
 }
 
 // -----
-// CHECK: function: test_transpose_op_1:
+// CHECK: function: test_load_with_transpose_effect:
 // CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf16>' at index: 0
 // CHECK-NEXT: sg_map  : Not assigned.
 // CHECK-NEXT: argument: <block argument> of type 'memref<16x16xf16>' at index: 1
@@ -83,7 +83,7 @@ func.func @test_dpas_i8(%arg0: vector<8x32xi8>, %arg1: vector<32x16xi8>, %arg2:
 // CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
 // CHECK-NEXT: op    : %[[T5:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
 // CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
-func.func @test_transpose_op_1(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<8x16xf32>) {
+func.func @test_load_with_transpose_effect(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<8x16xf32>) {
   %c0 = arith.constant 0 : index
   %cst = arith.constant dense<0.000000e+00> : vector<8x16xf32>
   %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
@@ -97,7 +97,8 @@ func.func @test_transpose_op_1(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>
 }
 
 // -----
-// CHECK: argument: <block argument> of type 'memref<8x16xf16>' at index: 0
+// CHECK: function: test_vector_transpose:
+// CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf16>' at index: 0
 // CHECK-NEXT: sg_map  : Not assigned.
 // CHECK-NEXT: argument: <block argument> of type 'memref<16x16xf16>' at index: 1
 // CHECK-NEXT: sg_map  : Not assigned.
@@ -121,7 +122,7 @@ func.func @test_transpose_op_1(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>
 // CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
 // CHECK-NEXT: op    : %[[T6:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
 // CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
-func.func @test_transpose_op_2(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<8x16xf32>) {
+func.func @test_vector_transpose(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<8x16xf32>) {
   %c0 = arith.constant 0 : index
   %cst = arith.constant dense<0.000000e+00> : vector<8x16xf32>
   %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
@@ -136,7 +137,8 @@ func.func @test_transpose_op_2(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>
 }
 
 // -----
-// CHECK: argument: <block argument> of type '!xegpu.tensor_desc<8x16xf16>' at index: 0
+// CHECK: function: test_extf_truncf:
+// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<8x16xf16>' at index: 0
 // CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
 // CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<16x16xf16>' at index: 1
 // CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [2, 1]
@@ -150,7 +152,7 @@ func.func @test_transpose_op_2(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>
 // CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1]
 // CHECK-NEXT: op    : %[[T4:.*]] = xegpu.dpas %[[T0]], %[[T3]] : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
 // CHECK-NEXT: sg_map for result #0: Not assigned.
-func.func @test_extf_truncf_op(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>) -> vector<8x16xf32> {
+func.func @test_extf_truncf(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>) -> vector<8x16xf32> {
   %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
   %1 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
   %2 = arith.extf %1 : vector<16x16xf16> to vector<16x16xf32>
@@ -160,7 +162,7 @@ func.func @test_extf_truncf_op(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegp
 }
 
 // -----
-// CHECK: function: test_load_gather_op_1:
+// CHECK: function: test_load_gather_with_transpose_effect:
 // CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf16>' at index: 0
 // CHECK-NEXT: sg_map  : Not assigned.
 // CHECK-NEXT: argument: <block argument> of type 'memref<256xf16>' at index: 1
@@ -185,7 +187,7 @@ func.func @test_extf_truncf_op(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegp
 // CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
 // CHECK-NEXT: op    : %[[T5:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
 // CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
-func.func @test_load_gather_op_1(%arg0: memref<8x16xf16>, %arg1: memref<256xf16>, %arg2: memref<8x16xf32>) {
+func.func @test_load_gather_with_transpose_effect(%arg0: memref<8x16xf16>, %arg1: memref<256xf16>, %arg2: memref<8x16xf32>) {
   %c0 = arith.constant 0 : index
   %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
   %1 = xegpu.load_nd %0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
@@ -200,6 +202,7 @@ func.func @test_load_gather_op_1(%arg0: memref<8x16xf16>, %arg1: memref<256xf16>
 }
 
 // -----
+// CHECK: function: test_load_gather_1d:
 // CHECK: argument: <block argument> of type 'memref<256xf32>' at index: 0
 // CHECK-NEXT: sg_map  : Not assigned.
 // CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<16xf32>' at index: 1
@@ -212,7 +215,7 @@ func.func @test_load_gather_op_1(%arg0: memref<8x16xf16>, %arg1: memref<256xf16>
 // CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
 // CHECK-NEXT: op    : %[[T1]] = xegpu.load %[[T0]], %[[CST0]]  : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> -> vector<16xf32>
 // CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
-func.func @test_load_gather_op_2(%arg0: memref<256xf32>, %arg1: !xegpu.tensor_desc<16xf32>) {
+func.func @test_load_gather_1d(%arg0: memref<256xf32>, %arg1: !xegpu.tensor_desc<16xf32>) {
   %cst = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
   %cst_0 = arith.constant dense<true> : vector<16xi1>
   %0 = xegpu.create_tdesc %arg0, %cst : memref<256xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
@@ -222,7 +225,8 @@ func.func @test_load_gather_op_2(%arg0: memref<256xf32>, %arg1: !xegpu.tensor_de
 }
 
 // -----
-// CHECK: argument: <block argument> of type 'memref<128xf32>' at index: 0
+// CHECK: function: test_store_scatter_with_transpose_effect:
+// CHECK-NEXT: argument: <block argument> of type 'memref<128xf32>' at index: 0
 // CHECK-NEXT: sg_map  : Not assigned.
 // CHECK-NEXT: op    : %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<8x16xf32>
 // CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
@@ -232,7 +236,7 @@ func.func @test_load_gather_op_2(%arg0: memref<256xf32>, %arg1: !xegpu.tensor_de
 // CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
 // CHECK-NEXT: op    : %[[T0:.*]] = xegpu.create_tdesc %{{.*}}, %[[CST1]] : memref<128xf32>, vector<16xindex> -> !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size = 8 : i64>>
 // CHECK-NEXT: sg_map for result #0: wi_layout: [16, 1], wi_data: [1, 1]
-func.func @test_store_scatter_op_1(%arg0: memref<128xf32>) {
+func.func @test_store_scatter_with_transpose_effect(%arg0: memref<128xf32>) {
   %cst = arith.constant dense<1.000000e+00> : vector<8x16xf32>
   %cst_0 = arith.constant dense<true> : vector<16xi1>
   %cst_1 = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
@@ -242,7 +246,8 @@ func.func @test_store_scatter_op_1(%arg0: memref<128xf32>) {
 }
 
 // -----
-// CHECK: argument: <block argument> of type 'vector<16xf32>' at index: 0
+// CHECK: function: test_store_scatter_1d:
+// CHECK-NEXT: argument: <block argument> of type 'vector<16xf32>' at index: 0
 // CHECK-NEXT: sg_map  : wi_layout: [16], wi_data: [1]
 // CHECK-NEXT: argument: <block argument> of type 'memref<256xf32>' at index: 1
 // CHECK-NEXT: sg_map  : Not assigned.
@@ -252,7 +257,7 @@ func.func @test_store_scatter_op_1(%arg0: memref<128xf32>) {
 // CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
 // CHECK-NEXT: op    : %[[T0:.*]] = xegpu.create_tdesc %{{.*}}, %[[CST]] : memref<256xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
 // CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
-func.func @test_store_scatter_op_2(%arg0: vector<16xf32>, %arg1: memref<256xf32>) {
+func.func @test_store_scatter_1d(%arg0: vector<16xf32>, %arg1: memref<256xf32>) {
   %cst = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
   %cst_0 = arith.constant dense<true> : vector<16xi1>
   %0 = xegpu.create_tdesc %arg1, %cst : memref<256xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
@@ -261,7 +266,8 @@ func.func @test_store_scatter_op_2(%arg0: vector<16xf32>, %arg1: memref<256xf32>
 }
 
 // -----
-// CHECK: argument: <block argument> of type 'memref<8x16xi16>' at index: 0
+// CHECK: function: test_vector_bitcast_i16_to_i8:
+// CHECK-NEXT: argument: <block argument> of type 'memref<8x16xi16>' at index: 0
 // CHECK-NEXT: sg_map  : Not assigned.
 // CHECK-NEXT: argument: <block argument> of type 'memref<32x16xi8>' at index: 1
 // CHECK-NEXT: sg_map  : Not assigned.
@@ -283,7 +289,7 @@ func.func @test_store_scatter_op_2(%arg0: vector<16xf32>, %arg1: memref<256xf32>
 // CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
 // CHECK-NEXT: op    : %[[T6:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xi32> -> !xegpu.tensor_desc<8x16xi32>
 // CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
-func.func @test_vector_bitcast_op_1(%arg0: memref<8x16xi16>, %arg1: memref<32x16xi8>, %arg2: memref<8x16xi32>) {
+func.func @test_vector_bitcast_i16_to_i8(%arg0: memref<8x16xi16>, %arg1: memref<32x16xi8>, %arg2: memref<8x16xi32>) {
   %c0 = arith.constant 0 : index
   %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xi16> -> !xegpu.tensor_desc<8x16xi16>
   %1 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<32x16xi8> -> !xegpu.tensor_desc<32x16xi8>
@@ -297,7 +303,8 @@ func.func @test_vector_bitcast_op_1(%arg0: memref<8x16xi16>, %arg1: memref<32x16
 }
 
 // -----
-// CHECK: argument: <block argument> of type 'memref<8x32xi8>' at index: 0
+// CHECK: function: test_vector_bitcast_i8_to_f16:
+// CHECK-NEXT: argument: <block argument> of type 'memref<8x32xi8>' at index: 0
 // CHECK-NEXT: sg_map  : Not assigned.
 // CHECK-NEXT: argument: <block argument> of type 'memref<16x32xi8>' at index: 1
 // CHECK-NEXT: sg_map  : Not assigned.
@@ -321,7 +328,7 @@ func.func @test_vector_bitcast_op_1(%arg0: memref<8x16xi16>, %arg1: memref<32x16
 // CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
 // CHECK-NEXT: op    : %[[T7:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
 // CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
-func.func @test_vector_bitcast_op_2(%arg0: memref<8x32xi8>, %arg1: memref<16x32xi8>, %arg2: memref<8x16xf32>) {
+func.func @test_vector_bitcast_i8_to_f16(%arg0: memref<8x32xi8>, %arg1: memref<16x32xi8>, %arg2: memref<8x16xf32>) {
   %c0 = arith.constant 0 : index
   %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x32xi8> -> !xegpu.tensor_desc<8x32xi8>
   %1 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<16x32xi8> -> !xegpu.tensor_desc<16x32xi8>
@@ -336,7 +343,8 @@ func.func @test_vector_bitcast_op_2(%arg0: memref<8x32xi8>, %arg1: memref<16x32x
 }
 
 // -----
-// CHECK: argument: <block argument> of type '!xegpu.tensor_desc<8x16xf16>' at index: 0
+// CHECK: function: test_binary_op_one_use:
+// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<8x16xf16>' at index: 0
 // CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
 // CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<16x16xf16>' at index: 1
 // CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [2, 1]
@@ -352,7 +360,7 @@ func.func @test_vector_bitcast_op_2(%arg0: memref<8x32xi8>, %arg1: memref<16x32x
 // CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1]
 // CHECK-NEXT: op    : %[[T4:.*]] = xegpu.dpas %[[T0]], %[[T3]] : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
 // CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
-func.func @test_binary_op_1(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: !xegpu.tensor_desc<8x16xf32>) {
+func.func @test_binary_op_one_use(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: !xegpu.tensor_desc<8x16xf32>) {
   %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
   %1 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
   %2 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
@@ -363,7 +371,8 @@ func.func @test_binary_op_1(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.t
 }
 
 // -----
-// CHECK: argument: <block argument> of type '!xegpu.tensor_desc<8x16xf16>' at index: 0
+// CHECK: function: test_binary_op_multiple_uses:
+// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<8x16xf16>' at index: 0
 // CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
 // CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<16x16xf16>' at index: 1
 // CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
@@ -381,7 +390,7 @@ func.func @test_binary_op_1(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.t
 // CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
 // CHECK-NEXT: op    : %[[T3:.*]] = xegpu.dpas %[[T0]], %[[T2]] : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
 // CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
-func.func @test_binary_op_2(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: !xegpu.tensor_desc<8x16xf32>, %arg3: !xegpu.tensor_desc<16x16xf16>) {
+func.func @test_binary_op_multiple_uses(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: !xegpu.tensor_desc<8x16xf32>, %arg3: !xegpu.tensor_desc<16x16xf16>) {
   %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
   %1 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
   %cst = arith.constant dense<1.000000e+00> : vector<16x16xf16>
@@ -393,7 +402,8 @@ func.func @test_binary_op_2(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.t
 }
 
 // -----
-// CHECK: argument: <block argument> of type 'memref<8x128xf16>' at index: 0
+// CHECK: function: test_for_op:
+// CHECK-NEXT: argument: <block argument> of type 'memref<8x128xf16>' at index: 0
 // CHECK-NEXT: sg_map  : Not assigned.
 // CHECK-NEXT: argument: <block argument> of type 'memref<128x16xf16>' at index: 1
 // CHECK-NEXT: sg_map  : Not assigned.
@@ -448,7 +458,7 @@ func.func @test_for_op(%arg0: memref<8x128xf16>, %arg1: memref<128x16xf16>, %arg
 }
 
 // -----
-// CHECK: function: test_if_op_1:
+// CHECK: function: test_if_single_use:
 // CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<8x16xf16>' at index: 0
 // CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
 // CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<16x16xf16>' at index: 1
@@ -467,7 +477,7 @@ func.func @test_for_op(%arg0: memref<8x128xf16>, %arg1: memref<128x16xf16>, %arg
 // CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1]
 // CHECK-NEXT: op    : %[[T2:.*]] = xegpu.dpas %[[T0]], %{{.*}} : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
 // CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
-func.func @test_if_op_1(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: i1, %arg3: !xegpu.tensor_desc<8x16xf32>) {
+func.func @test_if_single_use(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: i1, %arg3: !xegpu.tensor_desc<8x16xf32>) {
   %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
   %1 = scf.if %arg2 -> (vector<16x16xf16>) {
     %3 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
@@ -482,7 +492,7 @@ func.func @test_if_op_1(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tenso
 }
 
 // -----
-// CHECK: function: test_if_op_2:
+// CHECK: function: test_if_multiple_uses:
 // CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<8x16xf16>' at index: 0
 // CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
 // CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<16x16xf16>' at index: 1
@@ -503,7 +513,7 @@ func.func @test_if_op_1(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tenso
 // CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
 // CHECK-NEXT: op    : %[[T2:.*]] = xegpu.dpas %[[T0]], %{{.*}} : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
 // CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
-func.func @test_if_op_2(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: i1, %arg3: !xegpu.tensor_desc<8x16xf32>, %arg4: !xegpu.tensor_desc<16x16xf16>) {
+func.func @test_if_multiple_uses(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: i1, %arg3: !xegpu.tensor_desc<8x16xf32>, %arg4: !xegpu.tensor_desc<16x16xf16>) {
   %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
   %1 = scf.if %arg2 -> (vector<16x16xf16>) {
     %3 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
@@ -519,40 +529,7 @@ func.func @test_if_op_2(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tenso
 }
 
 // -----
-// CHECK: function: test_if_op_3:
-// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<8x16xf16>' at index: 0
-// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
-// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<16x16xf16>' at index: 1
-// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [2, 1]
-// CHECK-NEXT: argument: <block argument> of type 'vector<16x16xf16>' at index: 2
-// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [2, 1]
-// CHECK-NEXT: argument: <block argument> of type 'i1' at index: 3
-// CHECK-NEXT: sg_map  : Not assigned.
-// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<8x16xf32>' at index: 4
-// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
-// CHECK-NEXT: op    : %[[T0:.*]] = xegpu.load_nd %{{.*}}  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
-// CHECK-NEXT: op    : %[[T3:.*]] = xegpu.load_nd %{{.*}}  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1]
-// CHECK-NEXT: op    : scf.if
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1]
-// CHECK-NEXT: op    : %[[T2:.*]] = xegpu.dpas %[[T0]], %{{.*}} : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
-func.func @test_if_op_3(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: vector<16x16xf16>, %arg3: i1, %arg4: !xegpu.tensor_desc<8x16xf32>) {
-  %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
-  %1 = scf.if %arg3 -> (vector<16x16xf16>) {
-    %3 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
-    scf.yield %3 : vector<16x16xf16>
-  } else {
-    scf.yield %arg2 : vector<16x16xf16>
-  }
-  %2 = xegpu.dpas %0, %1 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
-  xegpu.store_nd %2, %arg4  : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
-  return
-}
-
-// -----
-// CHECK: function: test_vector_reduction_1:
+// CHECK: function: test_vector_outer_reduction:
 // CHECK-NEXT: argument: <block argument> of type 'vector<16x16xf32>' at index: 0
 // CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
 // CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<16xf32>' at index: 1
@@ -561,7 +538,7 @@ func.func @test_if_op_3(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tenso
 // CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
 // CHECK-NEXT: op    : %[[T0:.*]] = vector.multi_reduction <add>, %{{.*}}, %[[CST]] [0] : vector<16x16xf32> to vector<16xf32>
 // CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
-func.func @test_vector_reduction_1(%arg0: vector<16x16xf32>, %arg1: !xegpu.tensor_desc<16xf32>) {
+func.func @test_vector_outer_reduction(%arg0: vector<16x16xf32>, %arg1: !xegpu.tensor_desc<16xf32>) {
   %cst = arith.constant dense<0.000000e+00> : vector<16xf32>
   %0 = vector.multi_reduction <add>, %arg0, %cst [0] : vector<16x16xf32> to vector<16xf32>
   xegpu.store_nd %0, %arg1  : vector<16xf32>, !xegpu.tensor_desc<16xf32>
@@ -569,7 +546,8 @@ func.func @test_vector_reduction_1(%arg0: vector<16x16xf32>, %arg1: !xegpu.tenso
 }
 
 // -----
-// CHECK: argument: <block argument> of type 'vector<16x16xf32>' at index: 0
+// CHECK: function: test_vector_inner_reduction:
+// CHECK-NEXT: argument: <block argument> of type 'vector<16x16xf32>' at index: 0
 // CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
 // CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<16xf32>' at index: 1
 // CHECK-NEXT: sg_map  : wi_layout: [16], wi_data: [1]
@@ -577,7 +555,7 @@ func.func @test_vector_reduction_1(%arg0: vector<16x16xf32>, %arg1: !xegpu.tenso
 // CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
 // CHECK-NEXT: op    : %[[T0:.*]] = vector.multi_reduction <add>, %{{.*}}, %[[CST]] [1] : vector<16x16xf32> to vector<16xf32>
 // CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
-func.func @test_vector_reduction_2(%arg0: vector<16x16xf32>, %arg1: !xegpu.tensor_desc<16xf32>) {
+func.func @test_vector_inner_reduction(%arg0: vector<16x16xf32>, %arg1: !xegpu.tensor_desc<16xf32>) {
   %cst = arith.constant dense<0.000000e+00> : vector<16xf32>
   %0 = vector.multi_reduction <add>, %arg0, %cst [1] : vector<16x16xf32> to vector<16xf32>
   xegpu.store_nd %0, %arg1  : vector<16xf32>, !xegpu.tensor_desc<16xf32>

>From f64581a7dfdf9dc15e4d8e5d4305e89acae692fd Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Tue, 11 Mar 2025 20:34:22 +0000
Subject: [PATCH 18/28] save work

---
 .../Transforms/XeGPUSubgroupDistribute.cpp    | 106 ++++++++++--------
 1 file changed, 62 insertions(+), 44 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 14a9f5954edcf..d5e4f6532955e 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -32,9 +32,15 @@ namespace xegpu {
 using namespace mlir;
 using namespace mlir::dataflow;
 
-constexpr unsigned subgroupSize = 16;
-constexpr unsigned packedASizeInBits = 16;
-constexpr unsigned packedBSizeInBits = 32;
+/// HW dependent constants.
+/// TODO: These constants should be queried from the uArch interface.
+constexpr unsigned subgroupSize = 16; // How many work items in a subgroup.
+/// If DPAS A or B operands have low precision element types, they must be
+/// packed according to the following sizes.
+constexpr unsigned packedSizeInBitsForDefault =
+    16; // Minimum packing size per register for DPAS A and default layouts.
+constexpr unsigned packedSizeInBitsForDpasB =
+    32; // Minimum packing size per register for DPAS B.
 
 namespace {
 
@@ -51,7 +57,7 @@ struct Layout {
   Layout(std::initializer_list<int64_t> list) : layout(list) {}
   void print(llvm::raw_ostream &os) const;
   size_t size() const { return layout.size(); }
-  int64_t operator[](size_t idx) const { return layout[idx]; }
+  int64_t operator[](size_t idx) const;
 };
 
 void Layout::print(llvm::raw_ostream &os) const {
@@ -60,6 +66,11 @@ void Layout::print(llvm::raw_ostream &os) const {
   os << "]";
 }
 
+int64_t Layout::operator[](size_t idx) const {
+  assert(idx < layout.size() && "Index out of bounds.");
+  return layout[idx];
+}
+
 /// WiLayout represents the layout of work items within a subgroup when it
 /// accesses some value. WiData represents the layout of data owned by each work
 /// item.
@@ -86,14 +97,14 @@ using WiData = Layout;
 
 struct SGMap {
 private:
-  WiLayout layout;
-  WiData data;
+  WiLayout wiLayout;
+  WiData wiData;
 
 public:
   SGMap() = default;
   SGMap(const SGMap &other) = default;
   SGMap(const WiLayout &layout, const WiData &data)
-      : layout(layout), data(data) {}
+      : wiLayout(layout), wiData(data) {}
 
   /// Two lattice values are equal if they have `some` layout. The actual
   /// content of the layout does not matter.
@@ -107,20 +118,20 @@ struct SGMap {
 
   void print(raw_ostream &os) const;
 
-  bool isAssigned() const { return layout.size() > 0 && data.size() > 0; }
+  bool isAssigned() const { return wiLayout.size() > 0 && wiData.size() > 0; }
 
   SGMap getTransposedLayout(ArrayRef<int64_t> permutation) const;
 
-  const WiLayout &getLayout() const { return layout; }
-  const WiData &getData() const { return data; }
+  const WiLayout &getLayout() const { return wiLayout; }
+  const WiData &getData() const { return wiData; }
 };
 
 void SGMap::print(raw_ostream &os) const {
   if (isAssigned()) {
     os << "wi_layout: ";
-    layout.print(os);
+    wiLayout.print(os);
     os << ", wi_data: ";
-    data.print(os);
+    wiData.print(os);
   } else
     os << "Not assigned.";
 }
@@ -143,8 +154,8 @@ SGMap SGMap::getTransposedLayout(ArrayRef<int64_t> permutation) const {
   WiLayout newLayout;
   WiData newData;
   for (auto idx : permutation) {
-    newLayout.layout.push_back(layout.layout[idx]);
-    newData.layout.push_back(data.layout[idx]);
+    newLayout.layout.push_back(wiLayout.layout[idx]);
+    newData.layout.push_back(wiData.layout[idx]);
   }
   return SGMap(newLayout, newData);
 }
@@ -159,33 +170,14 @@ struct SGMapLattice : public Lattice<SGMap> {
   using Lattice::Lattice;
 };
 
-/// Helper Functions
-
-/// Helper Function to get the expected layouts for DPAS operands.
-static SGMap getSGMapForDPASOperand(Type operandTy, unsigned operandNum) {
-  int packingFactorForB = packedBSizeInBits / operandTy.getIntOrFloatBitWidth();
-  int packingFactorForA =
-      operandTy.getIntOrFloatBitWidth() < packedBSizeInBits
-          ? packedASizeInBits / operandTy.getIntOrFloatBitWidth()
-          : 1;
-  return SGMap(WiLayout({1, subgroupSize}),
-               WiData({operandNum == 1 ? packingFactorForB : 1,
-                       operandNum == 0 ? packingFactorForA : 1}));
-}
-
-/// Helper Function to get the default layout for a given type. Usually this is,
-/// wi_layout = [1, subgroupSize] and wi_data = [1, 1].
-/// However, the minimum granularity of data access per work item is 16-bits.
-/// So, if the bitwidth of the type is less than 16, we need to pack the data to
-/// 16-bits.
-// static SGMap getDefaultSgMap(Type ty, unsigned rank) {
-//   int packingFactor = 1;
-//   if (ty.getIntOrFloatBitWidth() < packedASizeInBits)
-//     packingFactor = packedBSizeInBits / ty.getIntOrFloatBitWidth();
-//   return SGMap(WiLayout({1, subgroupSize}), WiData({1, packingFactor}));
-// }
+/// Helper functions to get default layouts. A `default layout` is the layout
+/// assigned to a value when its layout is not fixed by some anchor operation
+/// (like DPAS). It reflects the natural arrangement of work items in a
+/// subgroup.
 
 /// Helper Function to get the default layout for uniform values like constants.
+/// For 1D vector, wi_layout is [subgroupSize] and wi_data is [1].
+/// For 2D vector, wi_layout is [1, subgroupSize] and wi_data is [1, 1].
 static SGMap getDefaultSgMap(unsigned rank) {
   assert((rank == 1 || rank == 2) && "Expected 0D or 1D vector.");
   if (rank == 1)
@@ -193,20 +185,46 @@ static SGMap getDefaultSgMap(unsigned rank) {
   return SGMap(WiLayout({1, subgroupSize}), WiData({1, 1}));
 }
 
+/// Helper to get the default layout for a vector type.
 static SGMap getDefaultSgMap(VectorType vectorTy) {
   /// Expecting a 1D or 2D vector.
   assert((vectorTy.getRank() == 1 || vectorTy.getRank() == 2) &&
          "Expected 1D or 2D vector.");
+  /// Expecting int or float element type.
+  assert(vectorTy.getElementType().isIntOrFloat() &&
+         "Expected int or float element type.");
   /// If the rank is 1, then return default layout for 1D vector.
   if (vectorTy.getRank() == 1)
     return getDefaultSgMap(1);
+  /// Packing factor is determined by the element type bitwidth.
   int packingFactor = 1;
   auto bitwidth = vectorTy.getElementType().getIntOrFloatBitWidth();
-  if (bitwidth < packedASizeInBits)
-    packingFactor = packedBSizeInBits / bitwidth;
+  if (bitwidth < packedSizeInBitsForDefault)
+    packingFactor = packedSizeInBitsForDefault / bitwidth;
   return SGMap(WiLayout({1, subgroupSize}), WiData({1, packingFactor}));
 }
 
+/// Helper Function to get the expected layouts for DPAS operands. `wi_data` is
+/// set according to the following criteria:
+/// * For A operand, the data must be packed in minimum `packedDpasASizeInBits`
+/// * For B operand, the data must be packed in minimum `packedDpasBSizeInBits`
+static SGMap getSGMapForDPASOperand(VectorType vectorTy, unsigned operandNum) {
+  auto elementTy = vectorTy.getElementType();
+  assert(elementTy.isIntOrFloat() &&
+         "Expected int or float type in DPAS operands");
+  WiLayout layout({1, subgroupSize});
+  /// For B operand, data must be packed in minimum `packedDpasBSizeInBits` and
+  /// must have the VNNI format.
+  if (operandNum == 1 &&
+      elementTy.getIntOrFloatBitWidth() < packedSizeInBitsForDpasB) {
+    WiData data(
+        {packedSizeInBitsForDpasB / elementTy.getIntOrFloatBitWidth(), 1});
+    return SGMap(layout, data);
+  }
+  /// Otherwise, return the default layout for the vector type.
+  return getDefaultSgMap(vectorTy);
+}
+
 ///===----------------------------------------------------------------------===///
 /// SGMapPropagation
 ///===----------------------------------------------------------------------===///
@@ -360,14 +378,14 @@ void SGMapPropagation::visitUpdateNdOffsetOp(
 void SGMapPropagation::visitDpasOp(xegpu::DpasOp dpas,
                                    ArrayRef<SGMapLattice *> operands,
                                    ArrayRef<const SGMapLattice *> results) {
-  auto aTy = dpas.getLhsType().getElementType();
-  auto bTy = dpas.getRhsType().getElementType();
+  auto aTy = dpas.getLhsType();
+  auto bTy = dpas.getRhsType();
   propagateIfChanged(operands[0],
                      operands[0]->meet(getSGMapForDPASOperand(aTy, 0)));
   propagateIfChanged(operands[1],
                      operands[1]->meet(getSGMapForDPASOperand(bTy, 1)));
   if (operands.size() > 2) {
-    auto cTy = dpas.getAccType().getElementType();
+    auto cTy = dpas.getAccType();
     propagateIfChanged(operands[2],
                        operands[2]->meet(getSGMapForDPASOperand(cTy, 2)));
   }

>From 562b5c7f7fc16dddc95ebb3ea914ab9314233e3c Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Tue, 11 Mar 2025 21:34:18 +0000
Subject: [PATCH 19/28] Move print-analysis-only option to Passes.td and remove debug print

---
 .../mlir/Dialect/XeGPU/Transforms/Passes.td        |  5 +++++
 .../XeGPU/Transforms/XeGPUSubgroupDistribute.cpp   | 14 ++++----------
 2 files changed, 9 insertions(+), 10 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
index cb9d403566645..3e81f2d0ed786 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
@@ -31,6 +31,11 @@ def XeGPUSubgroupDistribute : Pass<"xegpu-subgroup-distribute"> {
   let dependentDialects = [
       "memref::MemRefDialect", "xegpu::XeGPUDialect", "vector::VectorDialect"
   ];
+  let options = [
+    Option<"printOnly", "print-analysis-only", "bool",
+         /*default=*/"false",
+         "Print the result of the subgroup map propagation analysis and exit.">
+  ];
 }
 
 #endif // MLIR_DIALECT_XEGPU_TRANSFORMS_PASSES_TD
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index d5e4f6532955e..6213a6b7d1a6e 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -355,7 +355,6 @@ void SGMapPropagation::visitVectorMultiReductionOp(
   /// Given that the result is 1D, the layout of the operand should be 2D with
   /// default layout.
   auto operandLayout = getDefaultSgMap(2);
-  operandLayout.print(llvm::outs());
   propagateIfChanged(operands[0], operands[0]->meet(operandLayout));
   /// Accumulator should have the same layout as the result.
   propagateIfChanged(operands[1], operands[1]->meet(resultLayout));
@@ -625,16 +624,11 @@ struct XeGPUSubgroupDistributePass final
     : public xegpu::impl::XeGPUSubgroupDistributeBase<
           XeGPUSubgroupDistributePass> {
   XeGPUSubgroupDistributePass() = default;
-  XeGPUSubgroupDistributePass(const XeGPUSubgroupDistributePass &other)
-      : xegpu::impl::XeGPUSubgroupDistributeBase<XeGPUSubgroupDistributePass>(
-            other) {
-    this->printOnly = other.printOnly;
-  }
+  XeGPUSubgroupDistributePass(const XeGPUSubgroupDistributePass &other) =
+      default;
+  XeGPUSubgroupDistributePass(xegpu::XeGPUSubgroupDistributeOptions options)
+      : XeGPUSubgroupDistributeBase(options) {}
   void runOnOperation() override;
-  /// Print sg map propagation analysis result and exit for testing purposes.
-  Option<bool> printOnly{*this, "print-analysis-only", llvm::cl::init(false),
-                         llvm::cl::desc("Print the result of the subgroup map "
-                                        "propagation analysis and exit.")};
 };
 } // namespace
 

>From d8b74b34b3b9c781ee039ceb3adb6b05559d0049 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Tue, 11 Mar 2025 21:47:39 +0000
Subject: [PATCH 20/28] Fix getDefaultSgMap assert message and rename transpose result layout

---
 .../Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 6213a6b7d1a6e..d6f39887af8e7 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -179,7 +179,7 @@ struct SGMapLattice : public Lattice<SGMap> {
 /// For 1D vector, wi_layout is [subgroupSize] and wi_data is [1].
 /// For 2D vector, wi_layout is [1, subgroupSize] and wi_data is [1, 1].
 static SGMap getDefaultSgMap(unsigned rank) {
-  assert((rank == 1 || rank == 2) && "Expected 0D or 1D vector.");
+  assert((rank == 1 || rank == 2) && "Expected 1D or 2D vector.");
   if (rank == 1)
     return SGMap(WiLayout({subgroupSize}), WiData({1}));
   return SGMap(WiLayout({1, subgroupSize}), WiData({1, 1}));
@@ -428,11 +428,10 @@ void SGMapPropagation::visitTransposeOp(
     vector::TransposeOp transpose, ArrayRef<SGMapLattice *> operands,
     ArrayRef<const SGMapLattice *> results) {
   /// Need the layout of transpose result to propagate to the operands.
-  auto operandLayout = results[0]->getValue();
-  if (!operandLayout.isAssigned())
+  auto resultLayout = results[0]->getValue();
+  if (!resultLayout.isAssigned())
     return;
-  auto newLayout =
-      operandLayout.getTransposedLayout(transpose.getPermutation());
+  auto newLayout = resultLayout.getTransposedLayout(transpose.getPermutation());
   /// Propagate the new layout to the vector operand.
   propagateIfChanged(operands[0], operands[0]->meet(newLayout));
 }

>From 7196c9f5e5267fe7eaa9fc60092f4fe2544c303f Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Wed, 12 Mar 2025 19:42:08 +0000
Subject: [PATCH 21/28] Use TypeSwitch in visitOperation and clean up comments

---
 .../Transforms/XeGPUSubgroupDistribute.cpp    | 93 ++++++++++---------
 1 file changed, 50 insertions(+), 43 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index d6f39887af8e7..f8a27176ae66c 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/Support/LLVM.h"
+#include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/raw_ostream.h"
 
 namespace mlir {
@@ -26,9 +27,6 @@ namespace xegpu {
 } // namespace xegpu
 } // namespace mlir
 
-#define DEBUG_TYPE "xegpu-subgroup-distribute"
-#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
-
 using namespace mlir;
 using namespace mlir::dataflow;
 
@@ -206,8 +204,10 @@ static SGMap getDefaultSgMap(VectorType vectorTy) {
 
 /// Helper Function to get the expected layouts for DPAS operands. `wi_data` is
 /// set according to the following criteria:
-/// * For A operand, the data must be packed in minimum `packedDpasASizeInBits`
-/// * For B operand, the data must be packed in minimum `packedDpasBSizeInBits`
+/// * For A operand, the data must be packed in minimum
+/// `packedSizeInBitsForDefault`
+/// * For B operand, the data must be packed in minimum
+/// `packedSizeInBitsForDpasB`
 static SGMap getSGMapForDPASOperand(VectorType vectorTy, unsigned operandNum) {
   auto elementTy = vectorTy.getElementType();
   assert(elementTy.isIntOrFloat() &&
@@ -299,40 +299,48 @@ LogicalResult
 SGMapPropagation::visitOperation(Operation *op,
                                  ArrayRef<SGMapLattice *> operands,
                                  ArrayRef<const SGMapLattice *> results) {
-  if (auto dpas = dyn_cast<xegpu::DpasOp>(op))
-    visitDpasOp(dpas, operands, results);
-  else if (auto store = dyn_cast<xegpu::StoreNdOp>(op))
-    visitStoreNdOp(store, operands, results);
-  else if (auto load = dyn_cast<xegpu::LoadNdOp>(op))
-    visitLoadNdOp(load, operands, results);
-  else if (auto transpose = dyn_cast<vector::TransposeOp>(op))
-    visitTransposeOp(transpose, operands, results);
-  else if (auto bitcast = dyn_cast<vector::BitCastOp>(op))
-    visitVectorBitcastOp(bitcast, operands, results);
-  else if (auto loadGather = dyn_cast<xegpu::LoadGatherOp>(op))
-    visitLoadGatherOp(loadGather, operands, results);
-  else if (auto createDesc = dyn_cast<xegpu::CreateDescOp>(op))
-    visitCreateDescOp(createDesc, operands, results);
-  else if (auto storeScatter = dyn_cast<xegpu::StoreScatterOp>(op))
-    visitStoreScatterOp(storeScatter, operands, results);
-  else if (auto updateNdOffset = dyn_cast<xegpu::UpdateNdOffsetOp>(op))
-    visitUpdateNdOffsetOp(updateNdOffset, operands, results);
-  else if (auto reduction = dyn_cast<vector::MultiDimReductionOp>(op))
-    visitVectorMultiReductionOp(reduction, operands, results);
-  /// No need to propagate the layout to operands in CreateNdDescOp because they
-  /// are scalars (offsets, sizes, etc.).
-  else if (auto createNdDesc = dyn_cast<xegpu::CreateNdDescOp>(op))
-    return success();
-  /// All other ops
-  else {
-    for (const SGMapLattice *r : results) {
-      for (SGMapLattice *operand : operands) {
-        /// Propagate the layout of the result to the operand.
-        if (r->getValue().isAssigned())
-          meet(operand, *r);
-      }
-    }
-  }
+  TypeSwitch<Operation *>(op)
+      .Case<xegpu::DpasOp>(
+          [&](auto dpasOp) { visitDpasOp(dpasOp, operands, results); })
+      .Case<xegpu::StoreNdOp>(
+          [&](auto storeNdOp) { visitStoreNdOp(storeNdOp, operands, results); })
+      .Case<xegpu::StoreScatterOp>([&](auto storeScatterOp) {
+        visitStoreScatterOp(storeScatterOp, operands, results);
+      })
+      .Case<xegpu::LoadNdOp>(
+          [&](auto loadNdOp) { visitLoadNdOp(loadNdOp, operands, results); })
+      .Case<xegpu::LoadGatherOp>([&](auto loadGatherOp) {
+        visitLoadGatherOp(loadGatherOp, operands, results);
+      })
+      .Case<xegpu::CreateDescOp>([&](auto createDescOp) {
+        visitCreateDescOp(createDescOp, operands, results);
+      })
+      .Case<xegpu::UpdateNdOffsetOp>([&](auto updateNdOffsetOp) {
+        visitUpdateNdOffsetOp(updateNdOffsetOp, operands, results);
+      })
+      /// No need to propagate the layout to operands in CreateNdDescOp because
+      /// they are scalars (offsets, sizes, etc.).
+      .Case<xegpu::CreateNdDescOp>(
+          [&](auto createNdDescOp) { return success(); })
+      .Case<vector::TransposeOp>([&](auto transposeOp) {
+        visitTransposeOp(transposeOp, operands, results);
+      })
+      .Case<vector::BitCastOp>([&](auto bitcastOp) {
+        visitVectorBitcastOp(bitcastOp, operands, results);
+      })
+      .Case<vector::MultiDimReductionOp>([&](auto reductionOp) {
+        visitVectorMultiReductionOp(reductionOp, operands, results);
+      })
+      /// All other ops.
+      .Default([&](Operation *op) {
+        for (const SGMapLattice *r : results) {
+          for (SGMapLattice *operand : operands) {
+            /// Propagate the layout of the result to the operand.
+            if (r->getValue().isAssigned())
+              meet(operand, *r);
+          }
+        }
+      });
   /// Add a dependency from each reult to program point after the operation.
   /// NOTE: not sure if this is required, but all other similar analysis do
   /// this.
@@ -411,7 +419,7 @@ void SGMapPropagation::visitLoadNdOp(xegpu::LoadNdOp load,
   if (!valueLayout.isAssigned())
     return;
   SGMap tensorDescLayout = valueLayout;
-  /// LoadNdOp has the transpose effect. However, at the stage of this analyis
+  /// LoadNdOp has the transpose effect. However, at the stage of this analysis
   /// this effect is not expected and should be abstracted away. Emit a warning.
   if (auto transpose = load.getTranspose()) {
     load.emitWarning("Transpose effect is not expected for LoadNdOp at "
@@ -529,7 +537,7 @@ void SGMapPropagation::visitStoreScatterOp(
     return;
   }
   auto valueLayout = getDefaultSgMap(storeScatter.getValueType());
-  SGMap storeScatterLayout;
+  SGMap storeScatterLayout = valueLayout;
   if (storeScatter.getTranspose()) {
     /// StoreScatteOp allows transpose effect. However, at the stage of this
     /// analyis this effect is not expected and should be abstracted away. Emit
@@ -537,8 +545,7 @@ void SGMapPropagation::visitStoreScatterOp(
     storeScatter.emitWarning("Transpose effect is not expected for "
                              "StoreScatterOp at SGMapPropagation stage.");
     storeScatterLayout = valueLayout.getTransposedLayout({1, 0});
-  } else
-    storeScatterLayout = valueLayout;
+  }
   /// Propagate the value layout.
   propagateIfChanged(operands[0], operands[0]->meet(valueLayout));
   /// Propagate the tensor descriptor layout.

>From 4c9e641af62f8be956afdb2deacd548f25f1704c Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Wed, 12 Mar 2025 20:02:34 +0000
Subject: [PATCH 22/28] Remove unused populateXeGPUSubgroupDistributePatterns

---
 mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h       | 2 --
 mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp | 3 ---
 2 files changed, 5 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h b/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
index 86b95721df60c..63ea26df06937 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
@@ -16,8 +16,6 @@ namespace xegpu {
 
 /// Appends patterns for folding aliasing ops into XeGPU ops into `patterns`.
 void populateXeGPUFoldAliasOpsPatterns(RewritePatternSet &patterns);
-/// Appends patterns for distributing XeGPU ops to work items into `patterns`.
-void populateXeGPUSubgroupDistributePatterns(RewritePatternSet &patterns);
 
 } // namespace xegpu
 } // namespace mlir
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index f8a27176ae66c..13688e9cba62b 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -649,6 +649,3 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
     return;
   }
 }
-
-void xegpu::populateXeGPUSubgroupDistributePatterns(
-    RewritePatternSet &patterns) {}

>From 861f9075bfa18410c4d6e39cc822f85bb6fad13c Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Fri, 14 Mar 2025 00:17:07 +0000
Subject: [PATCH 23/28] Print analysis results for functions nested in gpu.module

---
 .../Transforms/XeGPUSubgroupDistribute.cpp    | 78 +++++++++++--------
 1 file changed, 47 insertions(+), 31 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 13688e9cba62b..1fdd37f031385 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -10,12 +10,14 @@
 #include "mlir/Analysis/DataFlow/SparseAnalysis.h"
 #include "mlir/Analysis/DataFlowFramework.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
 #include "mlir/Dialect/XeGPU/Transforms/Passes.h"
 #include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/raw_ostream.h"
@@ -490,7 +492,7 @@ void SGMapPropagation::visitLoadGatherOp(
   if (!valueLayout.isAssigned())
     return;
 
-  SGMap tensorDescLayout;
+  SGMap tensorDescLayout = valueLayout;
   if (load.getTranspose()) {
     /// LoadGatherOp has the transpose effect. However, at the stage of this
     /// analyis this effect is not expected and should be abstracted away. Emit
@@ -498,8 +500,7 @@ void SGMapPropagation::visitLoadGatherOp(
     load.emitWarning("Transpose effect is not expected for LoadGatherOp at "
                      "SGMapPropagation stage.");
     tensorDescLayout = valueLayout.getTransposedLayout({1, 0});
-  } else
-    tensorDescLayout = valueLayout;
+  }
   /// Mask operand should have 1D default layout.
   auto maskLayout = getDefaultSgMap(1);
   /// Propagate the new layout to the tensor descriptor operand.
@@ -590,38 +591,53 @@ SGMap RunSGMapPropagation::getSGMap(Value val) {
 }
 
 void RunSGMapPropagation::printAnalysisResult(llvm::raw_ostream &os) {
-  if (auto modOp = dyn_cast<ModuleOp>(target)) {
-    for (auto funcOp : modOp.getOps<func::FuncOp>()) {
-      os << "function: " << funcOp.getName() << ":\n";
-      // Function arguments
-      for (auto arg : funcOp.getArguments()) {
-        auto layout = getSGMap(arg);
-        os << "argument: " << arg << "\n";
-        os << "sg_map  : ";
+  auto printFunctionResult = [&](FunctionOpInterface funcOp) {
+    os << "function: " << funcOp.getName() << ":\n";
+    // Function arguments
+    for (auto arg : funcOp.getArguments()) {
+      auto layout = getSGMap(arg);
+      os << "argument: " << arg << "\n";
+      os << "sg_map  : ";
+      layout.print(os);
+      os << "\n";
+    }
+    // Function ops
+    funcOp.walk([&](Operation *op) {
+      // Skip ops that do not have results
+      if (op->getResults().empty())
+        return;
+      os << "op    : ";
+      /// For control-flow ops, print the op name only.
+      if (isa<BranchOpInterface>(op) || isa<RegionBranchOpInterface>(op))
+        os << op->getName();
+      else
+        op->print(os);
+      os << "\n";
+      /// Print the sg_map for each result.
+      for (auto [i, r] : llvm::enumerate(op->getResults())) {
+        auto layout = getSGMap(r);
+        os << "sg_map for result #" << i << ": ";
         layout.print(os);
         os << "\n";
       }
-      // Function ops
-      funcOp.walk([&](Operation *op) {
-        // Skip ops that do not have results
-        if (op->getResults().empty())
-          return;
-        os << "op    : ";
-        /// For control-flow ops, print the op name only.
-        if (isa<BranchOpInterface>(op) || isa<RegionBranchOpInterface>(op))
-          os << op->getName();
-        else
-          op->print(os);
-        os << "\n";
-        /// Print the sg_map for each result.
-        for (auto [i, r] : llvm::enumerate(op->getResults())) {
-          auto layout = getSGMap(r);
-          os << "sg_map for result #" << i << ": ";
-          layout.print(os);
-          os << "\n";
-        }
-      });
+    });
+  };
+
+  SmallVector<FunctionOpInterface> funcOps;
+  if (auto modOp = dyn_cast<ModuleOp>(target)) {
+    for (auto funcOp : modOp.getOps<FunctionOpInterface>()) {
+      funcOps.push_back(funcOp);
     }
+    /// Collect all GpuFuncOps in the module.
+    for (auto gpuModOp : modOp.getOps<gpu::GPUModuleOp>()) {
+      for (auto gpuFuncOp : gpuModOp.getOps<FunctionOpInterface>()) {
+        funcOps.push_back(gpuFuncOp);
+      }
+    }
+  }
+  /// Print the analysis result for each function.
+  for (auto funcOp : funcOps) {
+    printFunctionResult(funcOp);
   }
 }
 

>From ac337a3659df75a581367d0f139afa0c85b75df8 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Fri, 14 Mar 2025 01:53:08 +0000
Subject: [PATCH 24/28] Remove unused includes

---
 mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp | 2 --
 1 file changed, 2 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 1fdd37f031385..124e65b57521d 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -9,7 +9,6 @@
 #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
 #include "mlir/Analysis/DataFlow/SparseAnalysis.h"
 #include "mlir/Analysis/DataFlowFramework.h"
-#include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
@@ -18,7 +17,6 @@
 #include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/Interfaces/FunctionInterfaces.h"
-#include "mlir/Support/LLVM.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/raw_ostream.h"
 

>From 7ac804c6e99c2cf54ade6618a5b09d958b1fc1fb Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Fri, 14 Mar 2025 16:38:07 +0000
Subject: [PATCH 25/28] Reword TODO comment and drop stale note in visitOperation

---
 mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 124e65b57521d..2de68695fac2b 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -31,7 +31,7 @@ using namespace mlir;
 using namespace mlir::dataflow;
 
 /// HW dependent constants.
-/// TODO: These constants should be queried from the uArch interface.
+/// TODO: These constants should be queried from the target information.
 constexpr unsigned subgroupSize = 16; // How many work items in a subgroup.
 /// If DPAS A or B operands have low precision element types they must be packed
 /// according to the following sizes.
@@ -342,8 +342,6 @@ SGMapPropagation::visitOperation(Operation *op,
         }
       });
   /// Add a dependency from each reult to program point after the operation.
-  /// NOTE: not sure if this is required, but all other similar analysis do
-  /// this.
   for (const SGMapLattice *r : results) {
     addDependency(const_cast<SGMapLattice *>(r), getProgramPointAfter(op));
   }

>From 0ce71f87ce139e3380da4beef6a8a75503c30f8a Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Fri, 14 Mar 2025 16:38:45 +0000
Subject: [PATCH 26/28] Fix typo in visitOperation comment

---
 mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 2de68695fac2b..74fcb6b8c7613 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -341,7 +341,7 @@ SGMapPropagation::visitOperation(Operation *op,
           }
         }
       });
-  /// Add a dependency from each reult to program point after the operation.
+  /// Add a dependency from each result to program point after the operation.
   for (const SGMapLattice *r : results) {
     addDependency(const_cast<SGMapLattice *>(r), getProgramPointAfter(op));
   }

>From 4949a1f286348fd023fcc40da3b9a7c80986e5a9 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Fri, 14 Mar 2025 16:46:44 +0000
Subject: [PATCH 27/28] Remove stray return value from CreateNdDescOp case

---
 mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 74fcb6b8c7613..86e07697f437c 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -320,8 +320,7 @@ SGMapPropagation::visitOperation(Operation *op,
       })
       /// No need to propagate the layout to operands in CreateNdDescOp because
       /// they are scalars (offsets, sizes, etc.).
-      .Case<xegpu::CreateNdDescOp>(
-          [&](auto createNdDescOp) { return success(); })
+      .Case<xegpu::CreateNdDescOp>([&](auto createNdDescOp) {})
       .Case<vector::TransposeOp>([&](auto transposeOp) {
         visitTransposeOp(transposeOp, operands, results);
       })

>From c12692477e768c7160616c93614ec4094f721264 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Fri, 14 Mar 2025 19:13:29 +0000
Subject: [PATCH 28/28] Link MLIRGPUDialect and remove stray semicolon after visitDpasOp

---
 mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt              | 1 +
 mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp | 2 +-
 2 files changed, 2 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
index 124e904edb543..9f041aae511df 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
@@ -15,4 +15,5 @@ add_mlir_dialect_library(MLIRXeGPUTransforms
   MLIRXeGPUDialect
   MLIRPass
   MLIRTransforms
+  MLIRGPUDialect
 )
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 86e07697f437c..55263b15523de 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -393,7 +393,7 @@ void SGMapPropagation::visitDpasOp(xegpu::DpasOp dpas,
     propagateIfChanged(operands[2],
                        operands[2]->meet(getSGMapForDPASOperand(cTy, 2)));
   }
-};
+}
 
 /// Set the layout for the value and tensor descriptor operands in StoreNdOp.
 void SGMapPropagation::visitStoreNdOp(xegpu::StoreNdOp store,



More information about the Mlir-commits mailing list