[Mlir-commits] [mlir] [mlir][xegpu] Add XeGPU subgroup map propagation analysis for XeGPU SIMT distribution. (PR #130240)

Charitha Saumya llvmlistbot at llvm.org
Tue Mar 11 12:09:27 PDT 2025


https://github.com/charithaintc updated https://github.com/llvm/llvm-project/pull/130240

>From ff972397d8c1f8acf51781fd26d42ae9ddd26db2 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Fri, 28 Feb 2025 23:57:37 +0000
Subject: [PATCH 01/17] save work

---
 .../mlir/Dialect/XeGPU/Transforms/Passes.td   |  10 +
 .../Dialect/XeGPU/Transforms/Transforms.h     |   2 +
 .../Dialect/XeGPU/Transforms/CMakeLists.txt   |   1 +
 .../Transforms/XeGPUSubgroupDistribute.cpp    | 249 ++++++++++++++++++
 4 files changed, 262 insertions(+)
 create mode 100644 mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp

diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
index 1ecd6ce95322b..cb9d403566645 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
@@ -23,4 +23,14 @@ def XeGPUFoldAliasOps : Pass<"xegpu-fold-alias-ops"> {
   ];
 }
 
+def XeGPUSubgroupDistribute : Pass<"xegpu-subgroup-distribute"> {
+  let summary = "Distribute XeGPU ops to work items";
+  let description = [{
+    The pass distributes subgroup level (SIMD) XeGPU ops to work items.
+  }];
+  let dependentDialects = [
+      "memref::MemRefDialect", "xegpu::XeGPUDialect", "vector::VectorDialect"
+  ];
+}
+
 #endif // MLIR_DIALECT_XEGPU_TRANSFORMS_PASSES_TD
diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h b/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
index 63ea26df06937..86b95721df60c 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
@@ -16,6 +16,8 @@ namespace xegpu {
 
 /// Appends patterns for folding aliasing ops into XeGPU ops into `patterns`.
 void populateXeGPUFoldAliasOpsPatterns(RewritePatternSet &patterns);
+/// Appends patterns for distributing XeGPU ops to work items into `patterns`.
+void populateXeGPUSubgroupDistributePatterns(RewritePatternSet &patterns);
 
 } // namespace xegpu
 } // namespace mlir
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
index 7fb64d3b97b87..124e904edb543 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_dialect_library(MLIRXeGPUTransforms
   XeGPUFoldAliasOps.cpp
+  XeGPUSubgroupDistribute.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/XeGPU
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
new file mode 100644
index 0000000000000..99995d92e24b6
--- /dev/null
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -0,0 +1,249 @@
+//===- XeGPUSubgroupDistribute.cpp - XeGPU Subgroup Distribute Pass -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
+#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
+#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
+#include "mlir/Analysis/DataFlowFramework.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
+#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/raw_ostream.h"
+#include <algorithm>
+
+namespace mlir {
+namespace xegpu {
+#define GEN_PASS_DEF_XEGPUSUBGROUPDISTRIBUTE
+#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
+} // namespace xegpu
+} // namespace mlir
+
+#define DEBUG_TYPE "xegpu-subgroup-distribute"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+
+using namespace mlir;
+using namespace mlir::dataflow;
+
+constexpr unsigned subgroupSize = 16;
+constexpr unsigned packedASizeInBits = 16;
+constexpr unsigned packedBSizeInBits = 32;
+
+namespace {
+struct Layout2D {
+  SmallVector<int64_t, 2> layout;
+  Layout2D() = default;
+  Layout2D(int64_t x, int64_t y) { layout.insert(layout.end(), {x, y}); }
+  bool operator==(const Layout2D &rhs) const {
+    return this->layout == rhs.layout;
+  }
+  bool operator<(const Layout2D &rhs) const {
+    return this->layout < rhs.layout;
+  }
+  void print(llvm::raw_ostream &os) const {
+    os << "{";
+    llvm::interleave(
+        layout, os, [&](int64_t a) { os << a; }, ", ");
+    os << "}";
+  }
+};
+
+using WiLayout = Layout2D;
+using WiData = Layout2D;
+
+struct SGMapInfo {
+  Layout2D wiLayout;
+  Layout2D wiData;
+  SGMapInfo() = default;
+  SGMapInfo(const Layout2D &layout, const Layout2D &data, unsigned bitWidth)
+      : wiLayout(layout), wiData(data) {}
+  bool operator==(const SGMapInfo &rhs) const {
+    return this->wiLayout == rhs.wiLayout && this->wiData == rhs.wiData;
+  }
+  bool operator<(const SGMapInfo &rhs) const {
+    return this->wiLayout < rhs.wiLayout || this->wiData < rhs.wiData;
+  }
+  void print(llvm::raw_ostream &os) const {
+    os << "{";
+    os << "layout: ";
+    wiLayout.print(os);
+    os << ", ";
+    os << "data: ";
+    wiData.print(os);
+    os << "}";
+  }
+};
+
+struct SGMapLatticeValue {
+private:
+  std::set<SGMapInfo> layouts;
+
+public:
+  SGMapLatticeValue() = default;
+  SGMapLatticeValue(const SGMapLatticeValue &other) = default;
+  SGMapLatticeValue(const WiLayout &layout, const WiData &data) {
+    layouts.insert(SGMapInfo(layout, data, 16));
+  }
+
+  bool operator==(const SGMapLatticeValue &other) const {
+    return this->layouts == other.layouts;
+  }
+
+  /// This function depends on a partial ordering of the lattice values.
+  static SGMapLatticeValue meet(const SGMapLatticeValue &lhs,
+                                const SGMapLatticeValue &rhs) {
+    SGMapLatticeValue res = lhs;
+    (void)res.addLayouts(rhs.layouts);
+    return res;
+  }
+
+  static SGMapLatticeValue join(const SGMapLatticeValue &lhs,
+                                const SGMapLatticeValue &rhs) {
+    // Should not be triggered by this analysis, but required by `Lattice<T>`
+    llvm_unreachable("Join should not be triggered by this test");
+  }
+
+  ChangeResult addLayouts(const std::set<SGMapInfo> &layouts) {
+    int sizeBefore = this->layouts.size();
+    this->layouts.insert(layouts.begin(), layouts.end());
+    int sizeAfter = this->layouts.size();
+    return sizeBefore == sizeAfter ? ChangeResult::NoChange
+                                   : ChangeResult::Change;
+  }
+
+  void print(raw_ostream &os) const {
+    os << "[";
+    llvm::interleave(
+        layouts, os, [&](const SGMapInfo &a) { a.print(os); }, ", ");
+    os << "]";
+  }
+
+  void clear() { layouts.clear(); }
+
+  std::set<SGMapInfo> getLayouts() const { return layouts; }
+};
+
+struct SGMap : public Lattice<SGMapLatticeValue> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SGMap)
+  using Lattice::Lattice;
+};
+
+static SGMapLatticeValue getSGMapForDPASOperand(Type operandTy,
+                                                unsigned operandNum) {
+  int packingFactorForB = packedBSizeInBits / operandTy.getIntOrFloatBitWidth();
+  int packingFactorForA =
+      operandTy.getIntOrFloatBitWidth() < packedBSizeInBits
+          ? packedASizeInBits / operandTy.getIntOrFloatBitWidth()
+          : 1;
+  return SGMapLatticeValue(WiLayout(1, subgroupSize),
+                           WiData(operandNum == 1 ? packingFactorForB : 1,
+                                  operandNum == 0 ? packingFactorForA : 1));
+}
+
+class SGMapPropagation : public SparseBackwardDataFlowAnalysis<SGMap> {
+public:
+  SGMapPropagation(DataFlowSolver &solver, SymbolTableCollection &symbolTable)
+      : SparseBackwardDataFlowAnalysis(solver, symbolTable) {}
+  using SparseBackwardDataFlowAnalysis::SparseBackwardDataFlowAnalysis;
+
+  LogicalResult visitOperation(Operation *op, ArrayRef<SGMap *> operands,
+                               ArrayRef<const SGMap *> results) override;
+
+  void visitBranchOperand(OpOperand &operand) override{};
+
+  void visitCallOperand(OpOperand &operand) override{};
+
+  void visitExternalCall(CallOpInterface call, ArrayRef<SGMap *> operands,
+                         ArrayRef<const SGMap *> results) override{};
+
+  void setToExitState(SGMap *lattice) override {
+    (void)lattice->meet(SGMapLatticeValue());
+  }
+};
+} // namespace
+
+LogicalResult
+SGMapPropagation::visitOperation(Operation *op, ArrayRef<SGMap *> operands,
+                                 ArrayRef<const SGMap *> results) {
+  /// Handle dpas
+  if (auto dpas = dyn_cast<xegpu::DpasOp>(op)) {
+    auto aTy = dpas.getLhsType().getElementType();
+    auto bTy = dpas.getRhsType().getElementType();
+    propagateIfChanged(operands[0],
+                       operands[0]->meet(getSGMapForDPASOperand(aTy, 0)));
+    propagateIfChanged(operands[1],
+                       operands[1]->meet(getSGMapForDPASOperand(bTy, 1)));
+    if (operands.size() > 2) {
+      auto cTy = dpas.getAccType().getElementType();
+      propagateIfChanged(operands[2],
+                         operands[2]->meet(getSGMapForDPASOperand(cTy, 2)));
+    }
+    return success();
+  }
+  for (const SGMap *r : results) {
+    // For each operand assume a default layout.
+    for (SGMap *operand : operands) {
+      meet(operand, *r);
+    }
+    addDependency(const_cast<SGMap *>(r), getProgramPointAfter(op));
+  }
+  return success();
+}
+
+void xegpu::populateXeGPUSubgroupDistributePatterns(
+    RewritePatternSet &patterns) {}
+
+namespace {
+
+class RunSGMapPropagation {
+public:
+  RunSGMapPropagation(Operation *op) {
+    SymbolTableCollection symbolTable;
+
+    solver.load<DeadCodeAnalysis>();
+    solver.load<SparseConstantPropagation>();
+    solver.load<SGMapPropagation>(symbolTable);
+    (void)solver.initializeAndRun(op);
+  }
+
+  SGMapLatticeValue getSGMap(Value val) {
+    auto *state = solver.lookupState<SGMap>(val);
+    if (!state)
+      return {};
+    return state->getValue();
+  }
+
+private:
+  DataFlowSolver solver;
+};
+
+struct XeGPUSubgroupDistributePass final
+    : public xegpu::impl::XeGPUSubgroupDistributeBase<
+          XeGPUSubgroupDistributePass> {
+  void runOnOperation() override;
+};
+
+} // namespace
+
+void XeGPUSubgroupDistributePass::runOnOperation() {
+  Operation *op = getOperation();
+
+  RunSGMapPropagation solver(op);
+
+  // Print analysis results
+  auto &os = llvm::outs();
+  op->walk([&](Operation *op) {
+    if (op->getResults().empty())
+      return;
+    auto layouts = solver.getSGMap(op->getResult(0));
+    os << "SGMap for " << op->getName() << ": ";
+    layouts.print(os);
+    os << "\n";
+  });
+}

>From d0c1144513b0c6c077c0b545889d8e0d90cf394d Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Sat, 1 Mar 2025 00:04:34 +0000
Subject: [PATCH 02/17] save work

---
 .../Transforms/XeGPUSubgroupDistribute.cpp    | 33 ++++++++++++++-----
 1 file changed, 25 insertions(+), 8 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 99995d92e24b6..d2f43b3448f95 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -9,11 +9,13 @@
 #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
 #include "mlir/Analysis/DataFlow/SparseAnalysis.h"
 #include "mlir/Analysis/DataFlowFramework.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
 #include "mlir/Dialect/XeGPU/Transforms/Passes.h"
 #include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
 #include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/Interfaces/DataLayoutInterfaces.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/Support/raw_ostream.h"
 #include <algorithm>
@@ -238,12 +240,27 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
 
   // Print analysis results
   auto &os = llvm::outs();
-  op->walk([&](Operation *op) {
-    if (op->getResults().empty())
-      return;
-    auto layouts = solver.getSGMap(op->getResult(0));
-    os << "SGMap for " << op->getName() << ": ";
-    layouts.print(os);
-    os << "\n";
-  });
+  // check if op is a function
+  // llvm::errs() << op->getName() << "\n";
+  if (auto modOp = dyn_cast<ModuleOp>(op)) {
+    for (auto funcOp : modOp.getOps<func::FuncOp>()) {
+      os << "SGMap for " << funcOp.getName() << ":\n";
+      // Function args
+      for (auto arg : funcOp.getArguments()) {
+        auto layouts = solver.getSGMap(arg);
+        os << "SGMap for " << arg << ": ";
+        layouts.print(os);
+        os << "\n";
+      }
+      // Function ops
+      funcOp.walk([&](Operation *op) {
+        if (op->getResults().empty())
+          return;
+        auto layouts = solver.getSGMap(op->getResult(0));
+        os << "SGMap for " << op->getName() << ": ";
+        layouts.print(os);
+        os << "\n";
+      });
+    }
+  }
 }

>From 0b8752c921b17a99ab58f48cebc7564cd1387256 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Tue, 4 Mar 2025 18:14:40 +0000
Subject: [PATCH 03/17] save work

---
 .../Transforms/XeGPUSubgroupDistribute.cpp    | 361 +++++++++++-------
 .../XeGPU/subgroup-map-propagation.mlir       | 186 +++++++++
 2 files changed, 402 insertions(+), 145 deletions(-)
 create mode 100644 mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index d2f43b3448f95..c80ba2ca5f3d1 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -11,12 +11,15 @@
 #include "mlir/Analysis/DataFlowFramework.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
 #include "mlir/Dialect/XeGPU/Transforms/Passes.h"
 #include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/Interfaces/DataLayoutInterfaces.h"
+#include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/LogicalResult.h"
 #include "llvm/Support/raw_ostream.h"
 #include <algorithm>
 
@@ -38,217 +41,264 @@ constexpr unsigned packedASizeInBits = 16;
 constexpr unsigned packedBSizeInBits = 32;
 
 namespace {
-struct Layout2D {
-  SmallVector<int64_t, 2> layout;
-  Layout2D() = default;
-  Layout2D(int64_t x, int64_t y) { layout.insert(layout.end(), {x, y}); }
-  bool operator==(const Layout2D &rhs) const {
-    return this->layout == rhs.layout;
-  }
-  bool operator<(const Layout2D &rhs) const {
-    return this->layout < rhs.layout;
-  }
-  void print(llvm::raw_ostream &os) const {
-    os << "{";
-    llvm::interleave(
-        layout, os, [&](int64_t a) { os << a; }, ", ");
-    os << "}";
-  }
+struct Layout {
+  SmallVector<int64_t, 3> layout;
+  Layout() = default;
+  Layout(const Layout &other) = default;
+  Layout(std::initializer_list<int64_t> list) : layout(list) {}
+  void print(llvm::raw_ostream &os) const;
+  size_t size() const { return layout.size(); }
 };
 
-using WiLayout = Layout2D;
-using WiData = Layout2D;
-
-struct SGMapInfo {
-  Layout2D wiLayout;
-  Layout2D wiData;
-  SGMapInfo() = default;
-  SGMapInfo(const Layout2D &layout, const Layout2D &data, unsigned bitWidth)
-      : wiLayout(layout), wiData(data) {}
-  bool operator==(const SGMapInfo &rhs) const {
-    return this->wiLayout == rhs.wiLayout && this->wiData == rhs.wiData;
-  }
-  bool operator<(const SGMapInfo &rhs) const {
-    return this->wiLayout < rhs.wiLayout || this->wiData < rhs.wiData;
-  }
-  void print(llvm::raw_ostream &os) const {
-    os << "{";
-    os << "layout: ";
-    wiLayout.print(os);
-    os << ", ";
-    os << "data: ";
-    wiData.print(os);
-    os << "}";
-  }
-};
+void Layout::print(llvm::raw_ostream &os) const {
+  os << "["; // Output shape: '[', comma-separated entries, ']'.
+  llvm::interleaveComma(layout, os);
+  os << "]";
+}
+
+using WiLayout = Layout;
+using WiData = Layout;
 
-struct SGMapLatticeValue {
+struct SGMap {
 private:
-  std::set<SGMapInfo> layouts;
+  WiLayout layout;
+  WiData data;
 
 public:
-  SGMapLatticeValue() = default;
-  SGMapLatticeValue(const SGMapLatticeValue &other) = default;
-  SGMapLatticeValue(const WiLayout &layout, const WiData &data) {
-    layouts.insert(SGMapInfo(layout, data, 16));
+  SGMap() = default;
+  SGMap(const SGMap &other) = default;
+  SGMap(const WiLayout &layout, const WiData &data)
+      : layout(layout), data(data) {}
+
+  // Two lattice values are equal if they have `some` layout. The actual
+  // content of the layout does not matter.
+  bool operator==(const SGMap &other) const {
+    return this->isAssigned() == other.isAssigned();
   }
 
-  bool operator==(const SGMapLatticeValue &other) const {
-    return this->layouts == other.layouts;
-  }
+  static SGMap meet(const SGMap &lhs, const SGMap &rhs);
 
-  /// This function depends on a partial ordering of the lattice values.
-  static SGMapLatticeValue meet(const SGMapLatticeValue &lhs,
-                                const SGMapLatticeValue &rhs) {
-    SGMapLatticeValue res = lhs;
-    (void)res.addLayouts(rhs.layouts);
-    return res;
-  }
+  static SGMap join(const SGMap &lhs, const SGMap &rhs);
 
-  static SGMapLatticeValue join(const SGMapLatticeValue &lhs,
-                                const SGMapLatticeValue &rhs) {
-    // Should not be triggered by this analysis, but required by `Lattice<T>`
-    llvm_unreachable("Join should not be triggered by this test");
-  }
+  void print(raw_ostream &os) const;
 
-  ChangeResult addLayouts(const std::set<SGMapInfo> &layouts) {
-    int sizeBefore = this->layouts.size();
-    this->layouts.insert(layouts.begin(), layouts.end());
-    int sizeAfter = this->layouts.size();
-    return sizeBefore == sizeAfter ? ChangeResult::NoChange
-                                   : ChangeResult::Change;
-  }
+  bool isAssigned() const { return layout.size() > 0 && data.size() > 0; }
 
-  void print(raw_ostream &os) const {
-    os << "[";
-    llvm::interleave(
-        layouts, os, [&](const SGMapInfo &a) { a.print(os); }, ", ");
-    os << "]";
-  }
+  SGMap getTransposedLayout(ArrayRef<int64_t> permutation) const;
+};
+
+void SGMap::print(raw_ostream &os) const {
+  if (isAssigned()) { // Both `layout` and `data` are non-empty.
+    os << "Layout: ";
+    layout.print(os);
+    os << ", Data: ";
+    data.print(os);
+  } else
+    os << "Not initialized"; // Text emitted for an unassigned map.
+}
 
-  void clear() { layouts.clear(); }
+SGMap SGMap::meet(const SGMap &lhs, const SGMap &rhs) {
+  if (!lhs.isAssigned()) // An unassigned lhs defers to rhs, whatever it is.
+    return rhs;
+  return lhs; // An assigned lhs is kept as-is; rhs is not consulted.
+}
 
-  std::set<SGMapInfo> getLayouts() const { return layouts; }
-};
+SGMap SGMap::join(const SGMap &lhs, const SGMap &rhs) {
+  // Not expected to run in this analysis; defined because `Lattice<T>` needs it.
+  llvm_unreachable("Join should not be triggered by this test");
+}
 
-struct SGMap : public Lattice<SGMapLatticeValue> {
-  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SGMap)
+SGMap SGMap::getTransposedLayout(ArrayRef<int64_t> permutation) const {
+  if (!isAssigned()) // Nothing to permute; hand back an unassigned map.
+    return {};
+  WiLayout newLayout;
+  WiData newData;
+  for (auto idx : permutation) { // Result entry i takes source entry idx.
+    newLayout.layout.push_back(layout.layout[idx]);
+    newData.layout.push_back(data.layout[idx]);
+  }
+  return SGMap(newLayout, newData); // Pair both permuted halves, not old data.
+}
+
+struct SGMapLattice : public Lattice<SGMap> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SGMapLattice)
   using Lattice::Lattice;
 };
 
-static SGMapLatticeValue getSGMapForDPASOperand(Type operandTy,
-                                                unsigned operandNum) {
+/// Helper Functions
+///
+
+static SGMap getSGMapForDPASOperand(Type operandTy, unsigned operandNum) {
   int packingFactorForB = packedBSizeInBits / operandTy.getIntOrFloatBitWidth();
   int packingFactorForA =
       operandTy.getIntOrFloatBitWidth() < packedBSizeInBits
           ? packedASizeInBits / operandTy.getIntOrFloatBitWidth()
           : 1;
-  return SGMapLatticeValue(WiLayout(1, subgroupSize),
-                           WiData(operandNum == 1 ? packingFactorForB : 1,
-                                  operandNum == 0 ? packingFactorForA : 1));
+  return SGMap(WiLayout({1, subgroupSize}),
+               WiData({operandNum == 1 ? packingFactorForB : 1,
+                       operandNum == 0 ? packingFactorForA : 1}));
+}
+
+static SGMap getDefaultSgMap(Type ty) {
+  int packingFactor = 1; // NOTE(review): 16-bit test, 32-bit dividend; confirm.
+  if (ty.getIntOrFloatBitWidth() < packedASizeInBits)
+    packingFactor = packedBSizeInBits / ty.getIntOrFloatBitWidth();
+  return SGMap(WiLayout({1, subgroupSize}), WiData({1, packingFactor}));
 }
 
-class SGMapPropagation : public SparseBackwardDataFlowAnalysis<SGMap> {
+class SGMapPropagation : public SparseBackwardDataFlowAnalysis<SGMapLattice> {
+private:
+  void visitDpasOp(xegpu::DpasOp dpas, ArrayRef<SGMapLattice *> operands,
+                   ArrayRef<const SGMapLattice *> results);
+
+  void visitStoreNdOp(xegpu::StoreNdOp store, ArrayRef<SGMapLattice *> operands,
+                      ArrayRef<const SGMapLattice *> results);
+
+  void visitLoadNdOp(xegpu::LoadNdOp load, ArrayRef<SGMapLattice *> operands,
+                     ArrayRef<const SGMapLattice *> results);
+
+  void visitTransposeOp(vector::TransposeOp transpose,
+                        ArrayRef<SGMapLattice *> operands,
+                        ArrayRef<const SGMapLattice *> results);
+
 public:
   SGMapPropagation(DataFlowSolver &solver, SymbolTableCollection &symbolTable)
       : SparseBackwardDataFlowAnalysis(solver, symbolTable) {}
   using SparseBackwardDataFlowAnalysis::SparseBackwardDataFlowAnalysis;
 
-  LogicalResult visitOperation(Operation *op, ArrayRef<SGMap *> operands,
-                               ArrayRef<const SGMap *> results) override;
+  LogicalResult visitOperation(Operation *op, ArrayRef<SGMapLattice *> operands,
+                               ArrayRef<const SGMapLattice *> results) override;
 
   void visitBranchOperand(OpOperand &operand) override{};
 
   void visitCallOperand(OpOperand &operand) override{};
 
-  void visitExternalCall(CallOpInterface call, ArrayRef<SGMap *> operands,
-                         ArrayRef<const SGMap *> results) override{};
+  void visitExternalCall(CallOpInterface call,
+                         ArrayRef<SGMapLattice *> operands,
+                         ArrayRef<const SGMapLattice *> results) override{};
 
-  void setToExitState(SGMap *lattice) override {
-    (void)lattice->meet(SGMapLatticeValue());
+  void setToExitState(SGMapLattice *lattice) override {
+    (void)lattice->meet(SGMap());
   }
 };
 } // namespace
 
 LogicalResult
-SGMapPropagation::visitOperation(Operation *op, ArrayRef<SGMap *> operands,
-                                 ArrayRef<const SGMap *> results) {
-  /// Handle dpas
-  if (auto dpas = dyn_cast<xegpu::DpasOp>(op)) {
-    auto aTy = dpas.getLhsType().getElementType();
-    auto bTy = dpas.getRhsType().getElementType();
-    propagateIfChanged(operands[0],
-                       operands[0]->meet(getSGMapForDPASOperand(aTy, 0)));
-    propagateIfChanged(operands[1],
-                       operands[1]->meet(getSGMapForDPASOperand(bTy, 1)));
-    if (operands.size() > 2) {
-      auto cTy = dpas.getAccType().getElementType();
-      propagateIfChanged(operands[2],
-                         operands[2]->meet(getSGMapForDPASOperand(cTy, 2)));
-    }
-    return success();
-  }
-  for (const SGMap *r : results) {
-    // For each operand assume a default layout.
-    for (SGMap *operand : operands) {
-      meet(operand, *r);
+SGMapPropagation::visitOperation(Operation *op,
+                                 ArrayRef<SGMapLattice *> operands,
+                                 ArrayRef<const SGMapLattice *> results) {
+  if (auto dpas = dyn_cast<xegpu::DpasOp>(op))
+    visitDpasOp(dpas, operands, results);
+  else if (auto store = dyn_cast<xegpu::StoreNdOp>(op))
+    visitStoreNdOp(store, operands, results);
+  else if (auto load = dyn_cast<xegpu::LoadNdOp>(op))
+    visitLoadNdOp(load, operands, results);
+  else if (auto transpose = dyn_cast<vector::TransposeOp>(op))
+    visitTransposeOp(transpose, operands, results);
+  /// All other ops
+  else {
+    for (const SGMapLattice *r : results) {
+      for (SGMapLattice *operand : operands) {
+        if (r->getValue().isAssigned())
+          meet(operand, *r);
+      }
+      addDependency(const_cast<SGMapLattice *>(r), getProgramPointAfter(op));
     }
-    addDependency(const_cast<SGMap *>(r), getProgramPointAfter(op));
   }
   return success();
 }
 
-void xegpu::populateXeGPUSubgroupDistributePatterns(
-    RewritePatternSet &patterns) {}
+void SGMapPropagation::visitDpasOp(xegpu::DpasOp dpas,
+                                   ArrayRef<SGMapLattice *> operands,
+                                   ArrayRef<const SGMapLattice *> results) {
+  auto aTy = dpas.getLhsType().getElementType(); // Element type of operand 0.
+  auto bTy = dpas.getRhsType().getElementType(); // Element type of operand 1.
+  propagateIfChanged(operands[0],
+                     operands[0]->meet(getSGMapForDPASOperand(aTy, 0)));
+  propagateIfChanged(operands[1],
+                     operands[1]->meet(getSGMapForDPASOperand(bTy, 1)));
+  if (operands.size() > 2) { // A third (accumulator) operand is present.
+    auto cTy = dpas.getAccType().getElementType();
+    propagateIfChanged(operands[2],
+                       operands[2]->meet(getSGMapForDPASOperand(cTy, 2)));
+  }
+}
+
+void SGMapPropagation::visitStoreNdOp(xegpu::StoreNdOp store,
+                                      ArrayRef<SGMapLattice *> operands,
+                                      ArrayRef<const SGMapLattice *> results) {
+  auto storeLayout = // Default map for the descriptor's element type.
+      getDefaultSgMap(store.getTensorDescType().getElementType());
+  /// Both operands should have the same layout, so meet each with storeLayout.
+  for (SGMapLattice *operand : operands) {
+    propagateIfChanged(operand, operand->meet(storeLayout));
+  }
+}
+
+void SGMapPropagation::visitLoadNdOp(xegpu::LoadNdOp load,
+                                     ArrayRef<SGMapLattice *> operands,
+                                     ArrayRef<const SGMapLattice *> results) {
+  auto valueLayout = results[0]->getValue(); // Map of the loaded value.
+  /// Need the layout of the value to propagate to the tensor descriptor.
+  if (!valueLayout.isAssigned())
+    return; // Nothing to propagate to the descriptor in this visit.
+  SGMap tensorDescLayout = valueLayout;
+  if (auto transpose = load.getTranspose()) // Only when a transpose is set.
+    tensorDescLayout = valueLayout.getTransposedLayout(transpose.value());
+  /// Propagate the new layout to the tensor descriptor operand.
+  propagateIfChanged(operands[0], operands[0]->meet(tensorDescLayout));
+}
+
+void SGMapPropagation::visitTransposeOp(
+    vector::TransposeOp transpose, ArrayRef<SGMapLattice *> operands,
+    ArrayRef<const SGMapLattice *> results) {
+  /// Need the layout of transpose result to propagate to the operands.
+  auto operandLayout = results[0]->getValue(); // Holds the result's map.
+  if (!operandLayout.isAssigned())
+    return; // Result has no assigned map; nothing to permute.
+  auto newLayout =
+      operandLayout.getTransposedLayout(transpose.getPermutation());
+  /// Propagate the new layout to the vector operand.
+  propagateIfChanged(operands[0], operands[0]->meet(newLayout));
+}
 
 namespace {
 
 class RunSGMapPropagation {
 public:
-  RunSGMapPropagation(Operation *op) {
+  RunSGMapPropagation(Operation *op) : target(op) {
     SymbolTableCollection symbolTable;
-
     solver.load<DeadCodeAnalysis>();
     solver.load<SparseConstantPropagation>();
     solver.load<SGMapPropagation>(symbolTable);
     (void)solver.initializeAndRun(op);
   }
 
-  SGMapLatticeValue getSGMap(Value val) {
-    auto *state = solver.lookupState<SGMap>(val);
-    if (!state)
-      return {};
-    return state->getValue();
-  }
+  SGMap getSGMap(Value val);
+
+  void printAnalysisResult(llvm::raw_ostream &os);
 
 private:
   DataFlowSolver solver;
+  const Operation *target;
 };
-
-struct XeGPUSubgroupDistributePass final
-    : public xegpu::impl::XeGPUSubgroupDistributeBase<
-          XeGPUSubgroupDistributePass> {
-  void runOnOperation() override;
-};
-
 } // namespace
 
-void XeGPUSubgroupDistributePass::runOnOperation() {
-  Operation *op = getOperation();
-
-  RunSGMapPropagation solver(op);
+SGMap RunSGMapPropagation::getSGMap(Value val) {
+  auto *state = solver.lookupState<SGMapLattice>(val);
+  if (!state) // The solver holds no SGMapLattice state for `val`.
+    return {}; // Default-constructed, i.e. unassigned, SGMap.
+  return state->getValue();
+}
 
-  // Print analysis results
-  auto &os = llvm::outs();
-  // check if op is a function
-  // llvm::errs() << op->getName() << "\n";
-  if (auto modOp = dyn_cast<ModuleOp>(op)) {
+void RunSGMapPropagation::printAnalysisResult(llvm::raw_ostream &os) {
+  if (auto modOp = dyn_cast<ModuleOp>(target)) {
     for (auto funcOp : modOp.getOps<func::FuncOp>()) {
-      os << "SGMap for " << funcOp.getName() << ":\n";
+      os << "sg_map for " << funcOp.getName() << ":\n";
       // Function args
       for (auto arg : funcOp.getArguments()) {
-        auto layouts = solver.getSGMap(arg);
-        os << "SGMap for " << arg << ": ";
+        auto layouts = getSGMap(arg);
+        os << "sg_map for " << arg << ": ";
         layouts.print(os);
         os << "\n";
       }
@@ -256,11 +306,32 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
       funcOp.walk([&](Operation *op) {
         if (op->getResults().empty())
           return;
-        auto layouts = solver.getSGMap(op->getResult(0));
-        os << "SGMap for " << op->getName() << ": ";
+        auto layouts = getSGMap(op->getResult(0));
+        os << "sg_map for " << op->getName() << ": ";
         layouts.print(os);
         os << "\n";
       });
     }
   }
 }
+
+namespace {
+struct XeGPUSubgroupDistributePass final
+    : public xegpu::impl::XeGPUSubgroupDistributeBase<
+          XeGPUSubgroupDistributePass> {
+  void runOnOperation() override; // Defined out of line, after this namespace.
+};
+} // namespace
+
+void XeGPUSubgroupDistributePass::runOnOperation() {
+  Operation *op = getOperation();
+
+  RunSGMapPropagation solver(op); // The constructor loads and runs the solver.
+
+  // Print analysis results; no IR is rewritten here.
+  auto &os = llvm::outs();
+  solver.printAnalysisResult(os);
+}
+
+void xegpu::populateXeGPUSubgroupDistributePatterns(
+    RewritePatternSet &patterns) {} // Currently adds no patterns.
diff --git a/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir b/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
new file mode 100644
index 0000000000000..217d510973a4f
--- /dev/null
+++ b/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
@@ -0,0 +1,186 @@
+func.func @test(%a: memref<8x16xf16>, %b : memref<16x16xf16>, %c : memref<8x16xf32>) {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant dense<0.0> : vector<8x16xf32>
+  %0 = xegpu.create_nd_tdesc %a[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
+  %1 = xegpu.create_nd_tdesc %b[%c0, %c0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
+  %2 = xegpu.load_nd %0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+  %3 = xegpu.load_nd %1 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+  %4 = xegpu.dpas %2, %3, %cst : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+  %5 = xegpu.create_nd_tdesc %c[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+  xegpu.store_nd %4, %5 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+  return
+}
+
+// -----
+func.func @test(%a: memref<8x16xf16>, %b : memref<16x16xf16>, %c : memref<8x16xf32>) {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant dense<0.0> : vector<8x16xf32>
+  %0 = xegpu.create_nd_tdesc %a[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
+  %1 = xegpu.create_nd_tdesc %b[%c0, %c0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
+  %2 = xegpu.load_nd %0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+  %3 = xegpu.load_nd %1 {transpose = array<i64: 1, 0>} : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+  %4 = xegpu.dpas %2, %3, %cst : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+  %5 = xegpu.create_nd_tdesc %c[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+  xegpu.store_nd %4, %5 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+  return
+}
+
+
+// -----
+func.func @test(%arg0: vector<8x16xf16>, %arg1: vector<16x16xf16>) -> vector<8x16xf32> {
+  %0 = xegpu.dpas %arg0, %arg1 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+  return %0 : vector<8x16xf32>
+}
+
+// -----
+func.func @test(%arg0: vector<8x32xi8>, %arg1: vector<32x32xi8>) -> vector<8x32xi32> {
+  %0 = xegpu.dpas %arg0, %arg1 : vector<8x32xi8>, vector<32x32xi8> -> vector<8x32xi32>
+  return %0 : vector<8x32xi32>
+}
+
+// -----
+func.func @test(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>) -> vector<8x16xf32> {
+  %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+  %1 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+  %2 = xegpu.dpas %0, %1 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+  return %2 : vector<8x16xf32>
+}
+
+// -----
+func.func @test(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>) -> (vector<8x16xf32>, vector<8x16xf32>) {
+  %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+  %1 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+  %2 = xegpu.dpas %0, %1 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+  %3 = xegpu.dpas %0, %1 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+  return %2, %3 : vector<8x16xf32>, vector<8x16xf32>
+}
+
+// -----
+func.func @test(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>) -> (vector<8x16xf32>, vector<16x16xf16>) {
+  %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+  %1 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+  %2 = xegpu.dpas %0, %1 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+  return %2, %1 : vector<8x16xf32>, vector<16x16xf16>
+}
+
+// -----
+func.func @test(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>) -> vector<8x16xf32> {
+  %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+  %1 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+  %2 = arith.extf %1 : vector<16x16xf16> to vector<16x16xf32>
+  %3 = arith.truncf %2 : vector<16x16xf32> to vector<16x16xf16>
+  %4 = xegpu.dpas %0, %3 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+  return %4 : vector<8x16xf32>
+}
+
+// -----
+func.func @test(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>) -> vector<8x16xf32> {
+  %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+  %1 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+  %2 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+  %3 = arith.addf %1, %2 : vector<16x16xf16>
+  %4 = xegpu.dpas %0, %3 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+  return %4 : vector<8x16xf32>
+}
+
+// -----
+func.func @test(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>) -> vector<8x16xf32> {
+  %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+  %1 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+  %cst = arith.constant dense<1.000000e+00> : vector<16x16xf16>
+  %2 = arith.addf %1, %cst : vector<16x16xf16>
+  %3 = xegpu.dpas %0, %2 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+  return %3 : vector<8x16xf32>
+}
+
+// -----
+func.func @test(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>) -> (vector<8x16xf32>, vector<16x16xf16>) {
+  %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+  %1 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+  %2 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+  %3 = arith.addf %1, %2 : vector<16x16xf16>
+  %4 = xegpu.dpas %0, %3 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+  return %4, %2 : vector<8x16xf32>, vector<16x16xf16>
+}
+
+// -----
+func.func @test(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: index) -> vector<8x16xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+  %1 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+  %2 = xegpu.dpas %0, %1 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+  %3:2 = scf.for %arg3 = %c0 to %arg2 step %c1 iter_args(%arg4 = %1, %arg5 = %2) -> (vector<16x16xf16>, vector<8x16xf32>) {
+    %4 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+    %5 = arith.addf %arg4, %4 : vector<16x16xf16>
+    %6 = xegpu.dpas %0, %5 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+    scf.yield %5, %6 : vector<16x16xf16>, vector<8x16xf32>
+  }
+  return %3#1 : vector<8x16xf32>
+}
+
+// -----
+func.func @test(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: i1) -> vector<8x16xf32> {
+  %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+  %1 = scf.if %arg2 -> (vector<16x16xf16>) {
+    %3 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+    scf.yield %3 : vector<16x16xf16>
+  } else {
+    %3 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+    scf.yield %3 : vector<16x16xf16>
+  }
+  %2 = xegpu.dpas %0, %1 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+  return %2 : vector<8x16xf32>
+}
+
+// -----
+func.func @test(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: i1) -> (vector<8x16xf32>, vector<16x16xf16>) {
+  %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+  %1 = scf.if %arg2 -> (vector<16x16xf16>) {
+    %3 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+    scf.yield %3 : vector<16x16xf16>
+  } else {
+    %3 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+    scf.yield %3 : vector<16x16xf16>
+  }
+  %2 = xegpu.dpas %0, %1 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+  return %2, %1 : vector<8x16xf32>, vector<16x16xf16>
+}
+
+// -----
+func.func @test(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: vector<16x16xf16>, %arg3: i1) -> vector<8x16xf32> {
+  %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+  %1 = scf.if %arg3 -> (vector<16x16xf16>) {
+    %3 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+    scf.yield %3 : vector<16x16xf16>
+  } else {
+    scf.yield %arg2 : vector<16x16xf16>
+  }
+  %2 = xegpu.dpas %0, %1 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+  return %2 : vector<8x16xf32>
+}
+
+// -----
+func.func @test(%arg0: !xegpu.tensor_desc<8x16xi16>, %arg1: !xegpu.tensor_desc<16x16xi16>) -> vector<8x16xf32> {
+  %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xi16> -> vector<8x16xi16>
+  %1 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xi16> -> vector<16x16xi16>
+  %2 = arith.bitcast %0 : vector<8x16xi16> to vector<8x16xf16>
+  %3 = arith.bitcast %1 : vector<16x16xi16> to vector<16x16xf16>
+  %4 = xegpu.dpas %2, %3 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+  return %4 : vector<8x16xf32>
+}
+
+// -----
+func.func @test(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: vector<16x16xf16>, %arg3: i1) -> vector<8x16xf32> {
+  %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+  %1 = scf.if %arg3 -> (vector<8x16xf32>) {
+    %2 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+    %3 = arith.addf %2, %arg2 : vector<16x16xf16>
+    %4 = xegpu.dpas %0, %3 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+    scf.yield %4 : vector<8x16xf32>
+  } else {
+    %2 = xegpu.dpas %0, %arg2 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+    scf.yield %2 : vector<8x16xf32>
+  }
+  return %1 : vector<8x16xf32>
+}

>From 2890d5031eab8c07e8b1e3526b1caa13f2574e0d Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Tue, 4 Mar 2025 22:58:28 +0000
Subject: [PATCH 04/17] save work

---
 .../Transforms/XeGPUSubgroupDistribute.cpp    | 51 ++++++++++++-
 .../XeGPU/subgroup-map-propagation.mlir       | 75 +++++++++++--------
 2 files changed, 95 insertions(+), 31 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index c80ba2ca5f3d1..4f6091d7bebfb 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -48,6 +48,7 @@ struct Layout {
   Layout(std::initializer_list<int64_t> list) : layout(list) {}
   void print(llvm::raw_ostream &os) const;
   size_t size() const { return layout.size(); }
+  int64_t operator[](size_t idx) const { return layout[idx]; }
 };
 
 void Layout::print(llvm::raw_ostream &os) const {
@@ -85,6 +86,9 @@ struct SGMap {
   bool isAssigned() const { return layout.size() > 0 && data.size() > 0; }
 
   SGMap getTransposedLayout(ArrayRef<int64_t> permutation) const;
+
+  const WiLayout &getLayout() const { return layout; }
+  const WiData &getData() const { return data; }
 };
 
 void SGMap::print(raw_ostream &os) const {
@@ -160,6 +164,9 @@ class SGMapPropagation : public SparseBackwardDataFlowAnalysis<SGMapLattice> {
   void visitTransposeOp(vector::TransposeOp transpose,
                         ArrayRef<SGMapLattice *> operands,
                         ArrayRef<const SGMapLattice *> results);
+  void visitVectorBitcastOp(vector::BitCastOp bitcast,
+                            ArrayRef<SGMapLattice *> operands,
+                            ArrayRef<const SGMapLattice *> results);
 
 public:
   SGMapPropagation(DataFlowSolver &solver, SymbolTableCollection &symbolTable)
@@ -195,16 +202,23 @@ SGMapPropagation::visitOperation(Operation *op,
     visitLoadNdOp(load, operands, results);
   else if (auto transpose = dyn_cast<vector::TransposeOp>(op))
     visitTransposeOp(transpose, operands, results);
+  else if (auto bitcast = dyn_cast<vector::BitCastOp>(op))
+    visitVectorBitcastOp(bitcast, operands, results);
   /// All other ops
   else {
     for (const SGMapLattice *r : results) {
       for (SGMapLattice *operand : operands) {
+        /// Propagate the layout of the result to the operand.
         if (r->getValue().isAssigned())
           meet(operand, *r);
       }
-      addDependency(const_cast<SGMapLattice *>(r), getProgramPointAfter(op));
     }
   }
+  /// Add a dependency from each result to program point after the operation.
+  /// NOTE: not sure if this is required, but all other passes do this.
+  for (const SGMapLattice *r : results) {
+    addDependency(const_cast<SGMapLattice *>(r), getProgramPointAfter(op));
+  }
   return success();
 }
 
@@ -262,6 +276,41 @@ void SGMapPropagation::visitTransposeOp(
   propagateIfChanged(operands[0], operands[0]->meet(newLayout));
 }
 
+void SGMapPropagation::visitVectorBitcastOp(
+    vector::BitCastOp bitcast, ArrayRef<SGMapLattice *> operands,
+    ArrayRef<const SGMapLattice *> results) {
+  /// Need the layout of bitcast result to propagate to the operands.
+  auto resultLayout = results[0]->getValue();
+  if (!resultLayout.isAssigned())
+    return;
+  auto inElemTyBitWidth =
+      bitcast.getSourceVectorType().getElementType().getIntOrFloatBitWidth();
+  auto outElemTyBitWidth =
+      bitcast.getResultVectorType().getElementType().getIntOrFloatBitWidth();
+
+  /// WiLayout does not change.
+  WiLayout newWiLayout = resultLayout.getLayout();
+  WiData newWiData;
+  /// It's a widening bitcast
+  if (inElemTyBitWidth < outElemTyBitWidth) {
+    auto ratio = outElemTyBitWidth / inElemTyBitWidth;
+    const auto &currData = resultLayout.getData();
+    newWiData = resultLayout.getData()[0] == 1
+                    ? WiData({1, currData[1] * ratio})
+                    : WiData({currData[0] * ratio, 1});
+  } else {
+    /// It's a narrowing bitcast
+    auto ratio = inElemTyBitWidth / outElemTyBitWidth;
+    const auto &currData = resultLayout.getData();
+    newWiData = resultLayout.getData()[0] == 1
+                    ? WiData({1, currData[1] / ratio})
+                    : WiData({currData[0] / ratio, 1});
+  }
+
+  propagateIfChanged(operands[0],
+                     operands[0]->meet(SGMap(newWiLayout, newWiData)));
+}
+
 namespace {
 
 class RunSGMapPropagation {
diff --git a/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir b/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
index 217d510973a4f..a806dfd9f0a6c 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
@@ -25,52 +25,67 @@ func.func @test(%a: memref<8x16xf16>, %b : memref<16x16xf16>, %c : memref<8x16xf
   return
 }
 
-
 // -----
-func.func @test(%arg0: vector<8x16xf16>, %arg1: vector<16x16xf16>) -> vector<8x16xf32> {
-  %0 = xegpu.dpas %arg0, %arg1 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
-  return %0 : vector<8x16xf32>
+func.func @test(%a: memref<8x16xf16>, %b : memref<16x16xf16>, %c : memref<8x16xf32>) {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant dense<0.0> : vector<8x16xf32>
+  %0 = xegpu.create_nd_tdesc %a[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
+  %1 = xegpu.create_nd_tdesc %b[%c0, %c0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
+  %2 = xegpu.load_nd %0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+  %3 = xegpu.load_nd %1 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+  %6 = vector.transpose %3, [1, 0] : vector<16x16xf16> to vector<16x16xf16>
+  %4 = xegpu.dpas %2, %6, %cst : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+  %5 = xegpu.create_nd_tdesc %c[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+  xegpu.store_nd %4, %5 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+  return
 }
 
 // -----
-func.func @test(%arg0: vector<8x32xi8>, %arg1: vector<32x32xi8>) -> vector<8x32xi32> {
-  %0 = xegpu.dpas %arg0, %arg1 : vector<8x32xi8>, vector<32x32xi8> -> vector<8x32xi32>
-  return %0 : vector<8x32xi32>
+func.func @test(%arg0: vector<8x32xi8>, %arg1: vector<32x16xi8>, %c : memref<8x16xi32>) {
+  %c0 = arith.constant 0 : index
+  %0 = xegpu.dpas %arg0, %arg1 : vector<8x32xi8>, vector<32x16xi8> -> vector<8x16xi32>
+  %1 = xegpu.create_nd_tdesc %c [%c0, %c0] : memref<8x16xi32> -> !xegpu.tensor_desc<8x16xi32>
+  xegpu.store_nd %0, %1 : vector<8x16xi32>, !xegpu.tensor_desc<8x16xi32>
+  return
 }
 
 // -----
 func.func @test(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>) -> vector<8x16xf32> {
   %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
   %1 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
-  %2 = xegpu.dpas %0, %1 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
-  return %2 : vector<8x16xf32>
-}
-
-// -----
-func.func @test(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>) -> (vector<8x16xf32>, vector<8x16xf32>) {
-  %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
-  %1 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
-  %2 = xegpu.dpas %0, %1 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
-  %3 = xegpu.dpas %0, %1 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
-  return %2, %3 : vector<8x16xf32>, vector<8x16xf32>
+  %2 = arith.extf %1 : vector<16x16xf16> to vector<16x16xf32>
+  %3 = arith.truncf %2 : vector<16x16xf32> to vector<16x16xf16>
+  %4 = xegpu.dpas %0, %3 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+  return %4 : vector<8x16xf32>
 }
 
 // -----
-func.func @test(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>) -> (vector<8x16xf32>, vector<16x16xf16>) {
-  %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
-  %1 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
-  %2 = xegpu.dpas %0, %1 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
-  return %2, %1 : vector<8x16xf32>, vector<16x16xf16>
+func.func @test_bitcast(%arg0: memref<8x16xi16>, %arg1: memref<32x16xi8>, %c : memref<8x16xi32>) {
+  %c0 = arith.constant 0 : index
+  %2 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xi16> -> !xegpu.tensor_desc<8x16xi16>
+  %3 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<32x16xi8> -> !xegpu.tensor_desc<32x16xi8>
+  %4 = xegpu.load_nd %2 : !xegpu.tensor_desc<8x16xi16> -> vector<8x16xi16>
+  %5 = xegpu.load_nd %3 : !xegpu.tensor_desc<32x16xi8> -> vector<32x16xi8>
+  %6 = vector.bitcast %4 : vector<8x16xi16> to vector<8x32xi8>
+  %0 = xegpu.dpas %6, %5 : vector<8x32xi8>, vector<32x16xi8> -> vector<8x16xi32>
+  %1 = xegpu.create_nd_tdesc %c [%c0, %c0] : memref<8x16xi32> -> !xegpu.tensor_desc<8x16xi32>
+  xegpu.store_nd %0, %1 : vector<8x16xi32>, !xegpu.tensor_desc<8x16xi32>
+  return
 }
 
 // -----
-func.func @test(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>) -> vector<8x16xf32> {
-  %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
-  %1 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
-  %2 = arith.extf %1 : vector<16x16xf16> to vector<16x16xf32>
-  %3 = arith.truncf %2 : vector<16x16xf32> to vector<16x16xf16>
-  %4 = xegpu.dpas %0, %3 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
-  return %4 : vector<8x16xf32>
+func.func @test_bitcast(%arg0: memref<8x32xi8>, %arg1: memref<16x32xi8>, %c : memref<8x16xf32>) {
+  %c0 = arith.constant 0 : index
+  %2 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x32xi8> -> !xegpu.tensor_desc<8x32xi8>
+  %3 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<16x32xi8> -> !xegpu.tensor_desc<16x32xi8>
+  %4 = xegpu.load_nd %2 : !xegpu.tensor_desc<8x32xi8> -> vector<8x32xi8>
+  %5 = xegpu.load_nd %3 : !xegpu.tensor_desc<16x32xi8> -> vector<16x32xi8>
+  %6 = vector.bitcast %4 : vector<8x32xi8> to vector<8x16xf16>
+  %7 = vector.bitcast %5 : vector<16x32xi8> to vector<16x16xf16>
+  %0 = xegpu.dpas %6, %7 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+  %1 = xegpu.create_nd_tdesc %c [%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+  xegpu.store_nd %0, %1 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+  return
 }
 
 // -----

>From 38a388e013ff0f98fa1d24f11c6bf24b6baa649a Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Wed, 5 Mar 2025 20:40:20 +0000
Subject: [PATCH 05/17] save work

---
 .../Transforms/XeGPUSubgroupDistribute.cpp    | 76 ++++++++++++++++++-
 .../XeGPU/subgroup-map-propagation.mlir       | 15 ++++
 2 files changed, 87 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 4f6091d7bebfb..eaf87087c6fe2 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -161,13 +161,26 @@ class SGMapPropagation : public SparseBackwardDataFlowAnalysis<SGMapLattice> {
   void visitLoadNdOp(xegpu::LoadNdOp load, ArrayRef<SGMapLattice *> operands,
                      ArrayRef<const SGMapLattice *> results);
 
+  void visitLoadGatherOp(xegpu::LoadGatherOp load,
+                         ArrayRef<SGMapLattice *> operands,
+                         ArrayRef<const SGMapLattice *> results);
+
   void visitTransposeOp(vector::TransposeOp transpose,
                         ArrayRef<SGMapLattice *> operands,
                         ArrayRef<const SGMapLattice *> results);
+
   void visitVectorBitcastOp(vector::BitCastOp bitcast,
                             ArrayRef<SGMapLattice *> operands,
                             ArrayRef<const SGMapLattice *> results);
 
+  void visitCreateNdDescOp(xegpu::CreateNdDescOp createNdDesc,
+                           ArrayRef<SGMapLattice *> operands,
+                           ArrayRef<const SGMapLattice *> results);
+
+  void visitCreateDescOp(xegpu::CreateDescOp createDesc,
+                         ArrayRef<SGMapLattice *> operands,
+                         ArrayRef<const SGMapLattice *> results);
+
 public:
   SGMapPropagation(DataFlowSolver &solver, SymbolTableCollection &symbolTable)
       : SparseBackwardDataFlowAnalysis(solver, symbolTable) {}
@@ -204,6 +217,12 @@ SGMapPropagation::visitOperation(Operation *op,
     visitTransposeOp(transpose, operands, results);
   else if (auto bitcast = dyn_cast<vector::BitCastOp>(op))
     visitVectorBitcastOp(bitcast, operands, results);
+  else if (auto loadGather = dyn_cast<xegpu::LoadGatherOp>(op))
+    visitLoadGatherOp(loadGather, operands, results);
+  else if (auto createNdDesc = dyn_cast<xegpu::CreateNdDescOp>(op))
+    visitCreateNdDescOp(createNdDesc, operands, results);
+  else if (auto createDesc = dyn_cast<xegpu::CreateDescOp>(op))
+    visitCreateDescOp(createDesc, operands, results);
   /// All other ops
   else {
     for (const SGMapLattice *r : results) {
@@ -215,7 +234,8 @@ SGMapPropagation::visitOperation(Operation *op,
     }
   }
   /// Add a dependency from each result to program point after the operation.
-  /// NOTE: not sure if this is required, but all other passes do this.
+  /// NOTE: not sure if this is required, but all other similar analyses do
+  /// this.
   for (const SGMapLattice *r : results) {
     addDependency(const_cast<SGMapLattice *>(r), getProgramPointAfter(op));
   }
@@ -289,19 +309,18 @@ void SGMapPropagation::visitVectorBitcastOp(
       bitcast.getResultVectorType().getElementType().getIntOrFloatBitWidth();
 
   /// WiLayout does not change.
-  WiLayout newWiLayout = resultLayout.getLayout();
+  const WiLayout &newWiLayout = resultLayout.getLayout();
+  const WiData &currData = resultLayout.getData();
   WiData newWiData;
   /// It's a widening bitcast
   if (inElemTyBitWidth < outElemTyBitWidth) {
     auto ratio = outElemTyBitWidth / inElemTyBitWidth;
-    const auto &currData = resultLayout.getData();
     newWiData = resultLayout.getData()[0] == 1
                     ? WiData({1, currData[1] * ratio})
                     : WiData({currData[0] * ratio, 1});
   } else {
     /// It's a narrowing bitcast
     auto ratio = inElemTyBitWidth / outElemTyBitWidth;
-    const auto &currData = resultLayout.getData();
     newWiData = resultLayout.getData()[0] == 1
                     ? WiData({1, currData[1] / ratio})
                     : WiData({currData[0] / ratio, 1});
@@ -311,6 +330,55 @@ void SGMapPropagation::visitVectorBitcastOp(
                      operands[0]->meet(SGMap(newWiLayout, newWiData)));
 }
 
+void SGMapPropagation::visitLoadGatherOp(
+    xegpu::LoadGatherOp load, ArrayRef<SGMapLattice *> operands,
+    ArrayRef<const SGMapLattice *> results) {
+  auto valueLayout = results[0]->getValue();
+  /// Need the layout of the value to propagate to the tensor descriptor.
+  if (!valueLayout.isAssigned())
+    return;
+  /// LoadGatherOp has the transpose effect, so propagate the transposed layout
+  /// to the tensor descriptor.
+  SGMap tensorDescLayout = valueLayout.getTransposedLayout({1, 0});
+  /// Mask operand should have the same layout as the value but with wi_data =
+  /// [1, 1]
+  SGMap maskLayout = SGMap(valueLayout.getLayout(), WiData({1, 1}));
+  /// Propagate the new layout to the tensor descriptor operand.
+  propagateIfChanged(operands[0], operands[0]->meet(tensorDescLayout));
+  /// Propagate the new layout to the mask operand.
+  propagateIfChanged(operands[1], operands[1]->meet(maskLayout));
+}
+
+void SGMapPropagation::visitCreateNdDescOp(
+    xegpu::CreateNdDescOp createNdDesc, ArrayRef<SGMapLattice *> operands,
+    ArrayRef<const SGMapLattice *> results) {
+  auto descLayout = results[0]->getValue();
+  /// Need the layout of the descriptor to propagate to the operands.
+  if (!descLayout.isAssigned())
+    return;
+  /// Propagate the layout to the source operand.
+  propagateIfChanged(operands[0], operands[0]->meet(descLayout));
+  /// For all other operands propagate the same layout with wi_data = [1, 1]
+  SGMap layout = SGMap(descLayout.getLayout(), WiData({1, 1}));
+  for (size_t i = 1; i < operands.size(); ++i) {
+    propagateIfChanged(operands[i], operands[i]->meet(layout));
+  }
+}
+
+void SGMapPropagation::visitCreateDescOp(
+    xegpu::CreateDescOp createDesc, ArrayRef<SGMapLattice *> operands,
+    ArrayRef<const SGMapLattice *> results) {
+  auto descLayout = results[0]->getValue();
+  /// Need the layout of the descriptor to propagate to the operands.
+  if (!descLayout.isAssigned())
+    return;
+  /// Propagate the layout to the source operand.
+  propagateIfChanged(operands[0], operands[0]->meet(descLayout));
+  /// For offset operand propagate the same layout with wi_data = [1, 1]
+  SGMap layout = SGMap(descLayout.getLayout(), WiData({1, 1}));
+  propagateIfChanged(operands[1], operands[1]->meet(layout));
+}
+
 namespace {
 
 class RunSGMapPropagation {
diff --git a/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir b/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
index a806dfd9f0a6c..dfe9c75c36ed6 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
@@ -59,6 +59,21 @@ func.func @test(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<1
   return %4 : vector<8x16xf32>
 }
 
+// -----
+func.func @test_load_gather(%a: memref<8x16xf16>, %b : memref<256xf16>, %c : memref<8x16xf32>) {
+  %c0 = arith.constant 0 : index
+  %6 = xegpu.create_nd_tdesc %a[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
+  %7 = xegpu.load_nd %6 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+  %0 = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
+  %1 = arith.constant dense<1>: vector<16xi1>
+  %2 = xegpu.create_tdesc %b, %0 : memref<256xf16>, vector<16xindex> -> !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16>>
+  %3 = xegpu.load %2, %1 <{transpose}> : !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16>>, vector<16xi1> -> vector<16x16xf16>
+  %4 = xegpu.dpas %7, %3 : vector<8x16xf16>, vector<16x16xf16>-> vector<8x16xf32>
+  %5 = xegpu.create_nd_tdesc %c[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+  xegpu.store_nd %4, %5 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+  return
+}
+
 // -----
 func.func @test_bitcast(%arg0: memref<8x16xi16>, %arg1: memref<32x16xi8>, %c : memref<8x16xi32>) {
   %c0 = arith.constant 0 : index

>From 8996f11f39931c9540572a630f9e726d810cd5fa Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Wed, 5 Mar 2025 22:07:34 +0000
Subject: [PATCH 06/17] save work

---
 .../Transforms/XeGPUSubgroupDistribute.cpp    | 56 ++++++++++++++++---
 .../XeGPU/subgroup-map-propagation.mlir       | 36 ++++++++----
 2 files changed, 72 insertions(+), 20 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index eaf87087c6fe2..871138b8dbc16 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -158,6 +158,10 @@ class SGMapPropagation : public SparseBackwardDataFlowAnalysis<SGMapLattice> {
   void visitStoreNdOp(xegpu::StoreNdOp store, ArrayRef<SGMapLattice *> operands,
                       ArrayRef<const SGMapLattice *> results);
 
+  void visitStoreScatterOp(xegpu::StoreScatterOp storeScatter,
+                           ArrayRef<SGMapLattice *> operands,
+                           ArrayRef<const SGMapLattice *> results);
+
   void visitLoadNdOp(xegpu::LoadNdOp load, ArrayRef<SGMapLattice *> operands,
                      ArrayRef<const SGMapLattice *> results);
 
@@ -223,6 +227,8 @@ SGMapPropagation::visitOperation(Operation *op,
     visitCreateNdDescOp(createNdDesc, operands, results);
   else if (auto createDesc = dyn_cast<xegpu::CreateDescOp>(op))
     visitCreateDescOp(createDesc, operands, results);
+  else if (auto storeScatter = dyn_cast<xegpu::StoreScatterOp>(op))
+    visitStoreScatterOp(storeScatter, operands, results);
   /// All other ops
   else {
     for (const SGMapLattice *r : results) {
@@ -379,6 +385,22 @@ void SGMapPropagation::visitCreateDescOp(
   propagateIfChanged(operands[1], operands[1]->meet(layout));
 }
 
+void SGMapPropagation::visitStoreScatterOp(
+    xegpu::StoreScatterOp storeScatter, ArrayRef<SGMapLattice *> operands,
+    ArrayRef<const SGMapLattice *> results) {
+  /// StoreScatterOp has the transpose effect. Value has the regular layout,
+  /// while the tensor descriptor has the transposed layout.
+  auto valueLayout =
+      getDefaultSgMap(storeScatter.getTensorDescType().getElementType());
+  auto storeScatterLayout = valueLayout.getTransposedLayout({1, 0});
+  /// Propagate the value layout.
+  propagateIfChanged(operands[0], operands[0]->meet(valueLayout));
+  /// Propagate the tensor descriptor layout.
+  propagateIfChanged(operands[1], operands[1]->meet(storeScatterLayout));
+  /// Propagate the mask layout.
+  propagateIfChanged(operands[2], operands[2]->meet(valueLayout));
+}
+
 namespace {
 
 class RunSGMapPropagation {
@@ -411,20 +433,25 @@ SGMap RunSGMapPropagation::getSGMap(Value val) {
 void RunSGMapPropagation::printAnalysisResult(llvm::raw_ostream &os) {
   if (auto modOp = dyn_cast<ModuleOp>(target)) {
     for (auto funcOp : modOp.getOps<func::FuncOp>()) {
-      os << "sg_map for " << funcOp.getName() << ":\n";
-      // Function args
+      os << "function: " << funcOp.getName() << ":\n";
+      // Function arguments
       for (auto arg : funcOp.getArguments()) {
-        auto layouts = getSGMap(arg);
-        os << "sg_map for " << arg << ": ";
-        layouts.print(os);
+        auto layout = getSGMap(arg);
+        os << "argument: " << arg << "\n";
+        os << "sg_map  : ";
+        layout.print(os);
         os << "\n";
       }
       // Function ops
       funcOp.walk([&](Operation *op) {
+        // Skip ops that do not have results
         if (op->getResults().empty())
           return;
         auto layouts = getSGMap(op->getResult(0));
-        os << "sg_map for " << op->getName() << ": ";
+        os << "op    : ";
+        op->print(os);
+        os << "\n";
+        os << "sg_map: ";
         layouts.print(os);
         os << "\n";
       });
@@ -436,7 +463,18 @@ namespace {
 struct XeGPUSubgroupDistributePass final
     : public xegpu::impl::XeGPUSubgroupDistributeBase<
           XeGPUSubgroupDistributePass> {
+  XeGPUSubgroupDistributePass() = default;
+  XeGPUSubgroupDistributePass(const XeGPUSubgroupDistributePass &other)
+      : xegpu::impl::XeGPUSubgroupDistributeBase<XeGPUSubgroupDistributePass>(
+            other) {
+    this->printOnly = other.printOnly;
+  }
   void runOnOperation() override;
+  /// Print sg map propagation analysis result and exit for testing purposes.
+  Option<bool> printOnly{
+      *this, "print-only", llvm::cl::init(false),
+      llvm::cl::desc(
+          "Print the result of the subgroup map propagation analysis")};
 };
 } // namespace
 
@@ -446,8 +484,10 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
   RunSGMapPropagation solver(op);
 
   // Print analysis results
-  auto &os = llvm::outs();
-  solver.printAnalysisResult(os);
+  if (printOnly) {
+    auto &os = llvm::outs();
+    solver.printAnalysisResult(os);
+  }
 }
 
 void xegpu::populateXeGPUSubgroupDistributePatterns(
diff --git a/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir b/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
index dfe9c75c36ed6..76a44fe16440b 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
@@ -1,4 +1,4 @@
-func.func @test(%a: memref<8x16xf16>, %b : memref<16x16xf16>, %c : memref<8x16xf32>) {
+func.func @test_dpas_1(%a: memref<8x16xf16>, %b : memref<16x16xf16>, %c : memref<8x16xf32>) {
   %c0 = arith.constant 0 : index
   %cst = arith.constant dense<0.0> : vector<8x16xf32>
   %0 = xegpu.create_nd_tdesc %a[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
@@ -12,7 +12,16 @@ func.func @test(%a: memref<8x16xf16>, %b : memref<16x16xf16>, %c : memref<8x16xf
 }
 
 // -----
-func.func @test(%a: memref<8x16xf16>, %b : memref<16x16xf16>, %c : memref<8x16xf32>) {
+func.func @test_dpas_2(%arg0: vector<8x32xi8>, %arg1: vector<32x16xi8>, %c : memref<8x16xi32>) {
+  %c0 = arith.constant 0 : index
+  %0 = xegpu.dpas %arg0, %arg1 : vector<8x32xi8>, vector<32x16xi8> -> vector<8x16xi32>
+  %1 = xegpu.create_nd_tdesc %c [%c0, %c0] : memref<8x16xi32> -> !xegpu.tensor_desc<8x16xi32>
+  xegpu.store_nd %0, %1 : vector<8x16xi32>, !xegpu.tensor_desc<8x16xi32>
+  return
+}
+
+// -----
+func.func @test_transpose_1(%a: memref<8x16xf16>, %b : memref<16x16xf16>, %c : memref<8x16xf32>) {
   %c0 = arith.constant 0 : index
   %cst = arith.constant dense<0.0> : vector<8x16xf32>
   %0 = xegpu.create_nd_tdesc %a[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
@@ -26,7 +35,7 @@ func.func @test(%a: memref<8x16xf16>, %b : memref<16x16xf16>, %c : memref<8x16xf
 }
 
 // -----
-func.func @test(%a: memref<8x16xf16>, %b : memref<16x16xf16>, %c : memref<8x16xf32>) {
+func.func @test_transpose_2(%a: memref<8x16xf16>, %b : memref<16x16xf16>, %c : memref<8x16xf32>) {
   %c0 = arith.constant 0 : index
   %cst = arith.constant dense<0.0> : vector<8x16xf32>
   %0 = xegpu.create_nd_tdesc %a[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
@@ -40,17 +49,10 @@ func.func @test(%a: memref<8x16xf16>, %b : memref<16x16xf16>, %c : memref<8x16xf
   return
 }
 
-// -----
-func.func @test(%arg0: vector<8x32xi8>, %arg1: vector<32x16xi8>, %c : memref<8x16xi32>) {
-  %c0 = arith.constant 0 : index
-  %0 = xegpu.dpas %arg0, %arg1 : vector<8x32xi8>, vector<32x16xi8> -> vector<8x16xi32>
-  %1 = xegpu.create_nd_tdesc %c [%c0, %c0] : memref<8x16xi32> -> !xegpu.tensor_desc<8x16xi32>
-  xegpu.store_nd %0, %1 : vector<8x16xi32>, !xegpu.tensor_desc<8x16xi32>
-  return
-}
+
 
 // -----
-func.func @test(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>) -> vector<8x16xf32> {
+func.func @test_extf_truncf(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>) -> vector<8x16xf32> {
   %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
   %1 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
   %2 = arith.extf %1 : vector<16x16xf16> to vector<16x16xf32>
@@ -74,6 +76,16 @@ func.func @test_load_gather(%a: memref<8x16xf16>, %b : memref<256xf16>, %c : mem
   return
 }
 
+// -----
+func.func @test_store_scatter(%c: memref<128xf32>){
+  %cst = arith.constant dense<1.000000e+00> : vector<8x16xf32>
+  %1 = arith.constant dense<1> : vector<16xi1>
+  %2 = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
+  %3 = xegpu.create_tdesc %c, %2 : memref<128xf32>, vector<16xindex> -> !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size = 8>>
+  xegpu.store %cst, %3, %1 <{transpose}> : vector<8x16xf32>, !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size = 8>>, vector<16xi1>
+  return
+}
+
 // -----
 func.func @test_bitcast(%arg0: memref<8x16xi16>, %arg1: memref<32x16xi8>, %c : memref<8x16xi32>) {
   %c0 = arith.constant 0 : index

>From 47888451211e237cb1f0c2b443552921a73805cc Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Thu, 6 Mar 2025 05:20:50 +0000
Subject: [PATCH 07/17] save work

---
 .../Transforms/XeGPUSubgroupDistribute.cpp    | 23 +++--
 .../XeGPU/subgroup-map-propagation.mlir       | 97 ++++++++++---------
 2 files changed, 64 insertions(+), 56 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 871138b8dbc16..4182f2176f04b 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -93,12 +93,12 @@ struct SGMap {
 
 void SGMap::print(raw_ostream &os) const {
   if (isAssigned()) {
-    os << "Layout: ";
+    os << "wi_layout: ";
     layout.print(os);
-    os << ", Data: ";
+    os << ", wi_data: ";
     data.print(os);
   } else
-    os << "Not initialized";
+    os << "Not assigned.";
 }
 
 SGMap SGMap::meet(const SGMap &lhs, const SGMap &rhs) {
@@ -447,13 +447,20 @@ void RunSGMapPropagation::printAnalysisResult(llvm::raw_ostream &os) {
         // Skip ops that do not have results
         if (op->getResults().empty())
           return;
-        auto layouts = getSGMap(op->getResult(0));
         os << "op    : ";
-        op->print(os);
-        os << "\n";
-        os << "sg_map: ";
-        layouts.print(os);
+        /// For control-flow ops, print the op name only.
+        if (isa<BranchOpInterface>(op) || isa<RegionBranchOpInterface>(op))
+          os << op->getName();
+        else
+          op->print(os);
         os << "\n";
+        /// Print the sg_map for each result.
+        for (auto [i, r] : llvm::enumerate(op->getResults())) {
+          auto layout = getSGMap(r);
+          os << "sg_map for result #" << i << ": ";
+          layout.print(os);
+          os << "\n";
+        }
       });
     }
   }
diff --git a/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir b/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
index 76a44fe16440b..81c4244091e87 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
@@ -1,4 +1,4 @@
-func.func @test_dpas_1(%a: memref<8x16xf16>, %b : memref<16x16xf16>, %c : memref<8x16xf32>) {
+func.func @test_dpas_op_1(%a: memref<8x16xf16>, %b : memref<16x16xf16>, %c : memref<8x16xf32>) {
   %c0 = arith.constant 0 : index
   %cst = arith.constant dense<0.0> : vector<8x16xf32>
   %0 = xegpu.create_nd_tdesc %a[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
@@ -12,7 +12,7 @@ func.func @test_dpas_1(%a: memref<8x16xf16>, %b : memref<16x16xf16>, %c : memref
 }
 
 // -----
-func.func @test_dpas_2(%arg0: vector<8x32xi8>, %arg1: vector<32x16xi8>, %c : memref<8x16xi32>) {
+func.func @test_dpas_op_2(%arg0: vector<8x32xi8>, %arg1: vector<32x16xi8>, %c : memref<8x16xi32>) {
   %c0 = arith.constant 0 : index
   %0 = xegpu.dpas %arg0, %arg1 : vector<8x32xi8>, vector<32x16xi8> -> vector<8x16xi32>
   %1 = xegpu.create_nd_tdesc %c [%c0, %c0] : memref<8x16xi32> -> !xegpu.tensor_desc<8x16xi32>
@@ -21,7 +21,7 @@ func.func @test_dpas_2(%arg0: vector<8x32xi8>, %arg1: vector<32x16xi8>, %c : mem
 }
 
 // -----
-func.func @test_transpose_1(%a: memref<8x16xf16>, %b : memref<16x16xf16>, %c : memref<8x16xf32>) {
+func.func @test_transpose_op_1(%a: memref<8x16xf16>, %b : memref<16x16xf16>, %c : memref<8x16xf32>) {
   %c0 = arith.constant 0 : index
   %cst = arith.constant dense<0.0> : vector<8x16xf32>
   %0 = xegpu.create_nd_tdesc %a[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
@@ -35,7 +35,7 @@ func.func @test_transpose_1(%a: memref<8x16xf16>, %b : memref<16x16xf16>, %c : m
 }
 
 // -----
-func.func @test_transpose_2(%a: memref<8x16xf16>, %b : memref<16x16xf16>, %c : memref<8x16xf32>) {
+func.func @test_transpose_op_2(%a: memref<8x16xf16>, %b : memref<16x16xf16>, %c : memref<8x16xf32>) {
   %c0 = arith.constant 0 : index
   %cst = arith.constant dense<0.0> : vector<8x16xf32>
   %0 = xegpu.create_nd_tdesc %a[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
@@ -52,7 +52,7 @@ func.func @test_transpose_2(%a: memref<8x16xf16>, %b : memref<16x16xf16>, %c : m
 
 
 // -----
-func.func @test_extf_truncf(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>) -> vector<8x16xf32> {
+func.func @test_extf_truncf_op(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>) -> vector<8x16xf32> {
   %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
   %1 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
   %2 = arith.extf %1 : vector<16x16xf16> to vector<16x16xf32>
@@ -62,7 +62,7 @@ func.func @test_extf_truncf(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.t
 }
 
 // -----
-func.func @test_load_gather(%a: memref<8x16xf16>, %b : memref<256xf16>, %c : memref<8x16xf32>) {
+func.func @test_load_gather_op(%a: memref<8x16xf16>, %b : memref<256xf16>, %c : memref<8x16xf32>) {
   %c0 = arith.constant 0 : index
   %6 = xegpu.create_nd_tdesc %a[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
   %7 = xegpu.load_nd %6 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
@@ -77,7 +77,7 @@ func.func @test_load_gather(%a: memref<8x16xf16>, %b : memref<256xf16>, %c : mem
 }
 
 // -----
-func.func @test_store_scatter(%c: memref<128xf32>){
+func.func @test_store_scatter_op(%c: memref<128xf32>){
   %cst = arith.constant dense<1.000000e+00> : vector<8x16xf32>
   %1 = arith.constant dense<1> : vector<16xi1>
   %2 = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
@@ -87,7 +87,7 @@ func.func @test_store_scatter(%c: memref<128xf32>){
 }
 
 // -----
-func.func @test_bitcast(%arg0: memref<8x16xi16>, %arg1: memref<32x16xi8>, %c : memref<8x16xi32>) {
+func.func @test_bitcast_op_1(%arg0: memref<8x16xi16>, %arg1: memref<32x16xi8>, %c : memref<8x16xi32>) {
   %c0 = arith.constant 0 : index
   %2 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xi16> -> !xegpu.tensor_desc<8x16xi16>
   %3 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<32x16xi8> -> !xegpu.tensor_desc<32x16xi8>
@@ -101,7 +101,7 @@ func.func @test_bitcast(%arg0: memref<8x16xi16>, %arg1: memref<32x16xi8>, %c : m
 }
 
 // -----
-func.func @test_bitcast(%arg0: memref<8x32xi8>, %arg1: memref<16x32xi8>, %c : memref<8x16xf32>) {
+func.func @test_bitcast_op_2(%arg0: memref<8x32xi8>, %arg1: memref<16x32xi8>, %c : memref<8x16xf32>) {
   %c0 = arith.constant 0 : index
   %2 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x32xi8> -> !xegpu.tensor_desc<8x32xi8>
   %3 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<16x32xi8> -> !xegpu.tensor_desc<16x32xi8>
@@ -116,53 +116,51 @@ func.func @test_bitcast(%arg0: memref<8x32xi8>, %arg1: memref<16x32xi8>, %c : me
 }
 
 // -----
-func.func @test(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>) -> vector<8x16xf32> {
+func.func @test_binary_op_1(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: !xegpu.tensor_desc<8x16xf32>) {
   %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
   %1 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
   %2 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
   %3 = arith.addf %1, %2 : vector<16x16xf16>
   %4 = xegpu.dpas %0, %3 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
-  return %4 : vector<8x16xf32>
+  xegpu.store_nd %4, %arg2 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+  return
 }
 
 // -----
-func.func @test(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>) -> vector<8x16xf32> {
+func.func @test_binary_op_2(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: !xegpu.tensor_desc<8x16xf32>, %arg3: !xegpu.tensor_desc<16x16xf16>) {
   %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
   %1 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
   %cst = arith.constant dense<1.000000e+00> : vector<16x16xf16>
   %2 = arith.addf %1, %cst : vector<16x16xf16>
   %3 = xegpu.dpas %0, %2 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
-  return %3 : vector<8x16xf32>
-}
-
-// -----
-func.func @test(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>) -> (vector<8x16xf32>, vector<16x16xf16>) {
-  %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
-  %1 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
-  %2 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
-  %3 = arith.addf %1, %2 : vector<16x16xf16>
-  %4 = xegpu.dpas %0, %3 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
-  return %4, %2 : vector<8x16xf32>, vector<16x16xf16>
+  xegpu.store_nd %3, %arg2 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+  xegpu.store_nd %2, %arg3 : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16>
+  return
 }
 
 // -----
-func.func @test(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: index) -> vector<8x16xf32> {
+func.func @test_for_op(%a: memref<8x128xf16>, %b : memref<128x16xf16>, %c : memref<8x16xf32>) {
   %c0 = arith.constant 0 : index
-  %c1 = arith.constant 1 : index
-  %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
-  %1 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
-  %2 = xegpu.dpas %0, %1 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
-  %3:2 = scf.for %arg3 = %c0 to %arg2 step %c1 iter_args(%arg4 = %1, %arg5 = %2) -> (vector<16x16xf16>, vector<8x16xf32>) {
-    %4 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
-    %5 = arith.addf %arg4, %4 : vector<16x16xf16>
-    %6 = xegpu.dpas %0, %5 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
-    scf.yield %5, %6 : vector<16x16xf16>, vector<8x16xf32>
+  %c128 = arith.constant 128 : index
+  %c16 = arith.constant 16 : index
+  %0 = xegpu.create_nd_tdesc %a[%c0, %c0] : memref<8x128xf16> -> !xegpu.tensor_desc<8x16xf16>
+  %1 = xegpu.create_nd_tdesc %b[%c0, %c0] : memref<128x16xf16> -> !xegpu.tensor_desc<16x16xf16>
+  %2 = arith.constant dense<0.0> : vector<8x16xf32>
+  %3:3 = scf.for %k = %c0 to %c128 step %c16 iter_args(%arg0 = %0, %arg1 = %1, %arg2 = %2) -> (!xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<16x16xf16>, vector<8x16xf32>) {
+    %4 = xegpu.load_nd %arg0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+    %5 = xegpu.load_nd %arg1 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+    %6 = xegpu.dpas %4, %5, %arg2 : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+    %7 = xegpu.update_nd_offset %arg0, [%c0, %c16] : !xegpu.tensor_desc<8x16xf16>
+    %8 = xegpu.update_nd_offset %arg1, [%c16, %c0] : !xegpu.tensor_desc<16x16xf16>
+    scf.yield %7, %8, %6 : !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<16x16xf16>, vector<8x16xf32>
   }
-  return %3#1 : vector<8x16xf32>
+  %9 = xegpu.create_nd_tdesc %c[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+  xegpu.store_nd %3#2, %9 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+  return
 }
 
 // -----
-func.func @test(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: i1) -> vector<8x16xf32> {
+func.func @test_if_op_1(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: i1, %arg3: !xegpu.tensor_desc<8x16xf32>){
   %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
   %1 = scf.if %arg2 -> (vector<16x16xf16>) {
     %3 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
@@ -172,11 +170,12 @@ func.func @test(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<1
     scf.yield %3 : vector<16x16xf16>
   }
   %2 = xegpu.dpas %0, %1 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
-  return %2 : vector<8x16xf32>
+  xegpu.store_nd %2, %arg3 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+  return
 }
 
 // -----
-func.func @test(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: i1) -> (vector<8x16xf32>, vector<16x16xf16>) {
+func.func @test_if_op_2(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: i1, %arg3: !xegpu.tensor_desc<8x16xf32>, %arg4: !xegpu.tensor_desc<16x16xf16>){
   %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
   %1 = scf.if %arg2 -> (vector<16x16xf16>) {
     %3 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
@@ -186,7 +185,9 @@ func.func @test(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<1
     scf.yield %3 : vector<16x16xf16>
   }
   %2 = xegpu.dpas %0, %1 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
-  return %2, %1 : vector<8x16xf32>, vector<16x16xf16>
+  xegpu.store_nd %2, %arg3 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+  xegpu.store_nd %1, %arg4 : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16>
+  return
 }
 
 // -----
@@ -202,16 +203,6 @@ func.func @test(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<1
   return %2 : vector<8x16xf32>
 }
 
-// -----
-func.func @test(%arg0: !xegpu.tensor_desc<8x16xi16>, %arg1: !xegpu.tensor_desc<16x16xi16>) -> vector<8x16xf32> {
-  %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xi16> -> vector<8x16xi16>
-  %1 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xi16> -> vector<16x16xi16>
-  %2 = arith.bitcast %0 : vector<8x16xi16> to vector<8x16xf16>
-  %3 = arith.bitcast %1 : vector<16x16xi16> to vector<16x16xf16>
-  %4 = xegpu.dpas %2, %3 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
-  return %4 : vector<8x16xf32>
-}
-
 // -----
 func.func @test(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: vector<16x16xf16>, %arg3: i1) -> vector<8x16xf32> {
   %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
@@ -226,3 +217,13 @@ func.func @test(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<1
   }
   return %1 : vector<8x16xf32>
 }
+
+// -----
+func.func @test(%arg0: !xegpu.tensor_desc<8x16xi16>, %arg1: !xegpu.tensor_desc<16x16xi16>) -> vector<8x16xf32> {
+  %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xi16> -> vector<8x16xi16>
+  %1 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xi16> -> vector<16x16xi16>
+  %2 = arith.bitcast %0 : vector<8x16xi16> to vector<8x16xf16>
+  %3 = arith.bitcast %1 : vector<16x16xi16> to vector<16x16xf16>
+  %4 = xegpu.dpas %2, %3 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+  return %4 : vector<8x16xf32>
+}

>From c7b60d156c6a91e10fe84146f56ae2efa8a344b5 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Thu, 6 Mar 2025 22:29:27 +0000
Subject: [PATCH 08/17] save work

---
 .../Transforms/XeGPUSubgroupDistribute.cpp    | 169 +++++++++++-------
 .../XeGPU/subgroup-map-propagation.mlir       |  61 ++++---
 2 files changed, 142 insertions(+), 88 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 4182f2176f04b..13030d5c74ea0 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -15,13 +15,7 @@
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
 #include "mlir/Dialect/XeGPU/Transforms/Passes.h"
 #include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
-#include "mlir/IR/BuiltinAttributes.h"
-#include "mlir/Interfaces/DataLayoutInterfaces.h"
-#include "mlir/Support/LLVM.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "llvm/Support/LogicalResult.h"
 #include "llvm/Support/raw_ostream.h"
-#include <algorithm>
 
 namespace mlir {
 namespace xegpu {
@@ -41,6 +35,13 @@ constexpr unsigned packedASizeInBits = 16;
 constexpr unsigned packedBSizeInBits = 32;
 
 namespace {
+
+///===----------------------------------------------------------------------===///
+/// Layout
+///===----------------------------------------------------------------------===///
+
+/// Helper class to store the ND layout of work items within a subgroup and data
+/// owned by each work item.
 struct Layout {
   SmallVector<int64_t, 3> layout;
   Layout() = default;
@@ -57,9 +58,30 @@ void Layout::print(llvm::raw_ostream &os) const {
   os << "]";
 }
 
+/// WiLayout represents the layout of work items within a subgroup when it
+/// accesses some value. WiData represents the layout of data owned by each work
+/// item.
 using WiLayout = Layout;
 using WiData = Layout;
 
+///===----------------------------------------------------------------------===///
+/// SGMap
+///===----------------------------------------------------------------------===///
+
+/// Helper class for tracking the analysis state of a value. For SGPropagation,
+/// the analysis state is simply the wi_layout and wi_data of each value.
+/// The purpose of this analysis is to propagate a unique layout for each value
+/// in the program, starting from some known values (like DPAS, StoreNd, etc.).
+///
+/// Given this, SGMap satisfies the following properties:
+///  1) SGMap is a lattice with two states - assigned and not assigned.
+///  2) Two SGMap values are equal if they are both assigned or both not
+///  assigned. The concrete value of assigned state does not matter.
+///  3) The meet operator works as follows:
+///     - If the current state is assigned, return the current state (a
+///     unique layout is already assigned; don't change it).
+///     - Otherwise, return the other state.
+
 struct SGMap {
 private:
   WiLayout layout;
@@ -71,8 +93,8 @@ struct SGMap {
   SGMap(const WiLayout &layout, const WiData &data)
       : layout(layout), data(data) {}
 
-  // Two lattice values are equal if they have `some` layout. The actual
-  // content of the layout does not matter.
+  /// Two lattice values are equal if they have `some` layout. The actual
+  /// content of the layout does not matter.
   bool operator==(const SGMap &other) const {
     return this->isAssigned() == other.isAssigned();
   }
@@ -107,9 +129,9 @@ SGMap SGMap::meet(const SGMap &lhs, const SGMap &rhs) {
   return lhs;
 }
 
+/// Since this is a backward analysis, join method is not used.
 SGMap SGMap::join(const SGMap &lhs, const SGMap &rhs) {
-  // Should not be triggered by this analysis, but required by `Lattice<T>`
-  llvm_unreachable("Join should not be triggered by this test");
+  llvm_unreachable("Join should not be triggered by SGMapPropagation.");
 }
 
 SGMap SGMap::getTransposedLayout(ArrayRef<int64_t> permutation) const {
@@ -124,14 +146,17 @@ SGMap SGMap::getTransposedLayout(ArrayRef<int64_t> permutation) const {
   return SGMap(newLayout, data);
 }
 
+///===----------------------------------------------------------------------===///
+/// SGMapLattice
+///===----------------------------------------------------------------------===///
+
+/// Lattice holding the SGMap for each value.
 struct SGMapLattice : public Lattice<SGMap> {
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SGMapLattice)
   using Lattice::Lattice;
 };
 
-/// Helper Functions
-///
-
+/// Helper Function to get the expected layouts for DPAS operands.
 static SGMap getSGMapForDPASOperand(Type operandTy, unsigned operandNum) {
   int packingFactorForB = packedBSizeInBits / operandTy.getIntOrFloatBitWidth();
   int packingFactorForA =
@@ -143,6 +168,11 @@ static SGMap getSGMapForDPASOperand(Type operandTy, unsigned operandNum) {
                        operandNum == 0 ? packingFactorForA : 1}));
 }
 
+/// Helper Function to get the default layout for a given type. Usually this is,
+/// wi_layout = [1, subgroupSize] and wi_data = [1, 1].
+/// However, the minimum granularity of data access per work item is 16-bits.
+/// So, if the bitwidth of the type is less than 16, we need to pack the data to
+/// 16-bits.
 static SGMap getDefaultSgMap(Type ty) {
   int packingFactor = 1;
   if (ty.getIntOrFloatBitWidth() < packedASizeInBits)
@@ -150,6 +180,15 @@ static SGMap getDefaultSgMap(Type ty) {
   return SGMap(WiLayout({1, subgroupSize}), WiData({1, packingFactor}));
 }
 
+///===----------------------------------------------------------------------===///
+/// SGMapPropagation
+///===----------------------------------------------------------------------===///
+
+/// Backward data flow analysis to propagate the wi_layout and wi_data of each
+/// value in the program. Currently, the layouts for the operands of DPAS,
+/// StoreNd, and StoreScatter are fixed (known before propagation). The purpose
+/// of this analysis is to propagate those known layouts to all their producers
+/// and (other) consumers.
 class SGMapPropagation : public SparseBackwardDataFlowAnalysis<SGMapLattice> {
 private:
   void visitDpasOp(xegpu::DpasOp dpas, ArrayRef<SGMapLattice *> operands,
@@ -177,14 +216,6 @@ class SGMapPropagation : public SparseBackwardDataFlowAnalysis<SGMapLattice> {
                             ArrayRef<SGMapLattice *> operands,
                             ArrayRef<const SGMapLattice *> results);
 
-  void visitCreateNdDescOp(xegpu::CreateNdDescOp createNdDesc,
-                           ArrayRef<SGMapLattice *> operands,
-                           ArrayRef<const SGMapLattice *> results);
-
-  void visitCreateDescOp(xegpu::CreateDescOp createDesc,
-                         ArrayRef<SGMapLattice *> operands,
-                         ArrayRef<const SGMapLattice *> results);
-
 public:
   SGMapPropagation(DataFlowSolver &solver, SymbolTableCollection &symbolTable)
       : SparseBackwardDataFlowAnalysis(solver, symbolTable) {}
@@ -223,10 +254,6 @@ SGMapPropagation::visitOperation(Operation *op,
     visitVectorBitcastOp(bitcast, operands, results);
   else if (auto loadGather = dyn_cast<xegpu::LoadGatherOp>(op))
     visitLoadGatherOp(loadGather, operands, results);
-  else if (auto createNdDesc = dyn_cast<xegpu::CreateNdDescOp>(op))
-    visitCreateNdDescOp(createNdDesc, operands, results);
-  else if (auto createDesc = dyn_cast<xegpu::CreateDescOp>(op))
-    visitCreateDescOp(createDesc, operands, results);
   else if (auto storeScatter = dyn_cast<xegpu::StoreScatterOp>(op))
     visitStoreScatterOp(storeScatter, operands, results);
   /// All other ops
@@ -248,6 +275,7 @@ SGMapPropagation::visitOperation(Operation *op,
   return success();
 }
 
+/// Set the layouts for DPAS A, B, and C operands.
 void SGMapPropagation::visitDpasOp(xegpu::DpasOp dpas,
                                    ArrayRef<SGMapLattice *> operands,
                                    ArrayRef<const SGMapLattice *> results) {
@@ -264,6 +292,7 @@ void SGMapPropagation::visitDpasOp(xegpu::DpasOp dpas,
   }
 };
 
+/// Set the layout for the value and tensor descriptor operands in StoreNdOp.
 void SGMapPropagation::visitStoreNdOp(xegpu::StoreNdOp store,
                                       ArrayRef<SGMapLattice *> operands,
                                       ArrayRef<const SGMapLattice *> results) {
@@ -275,6 +304,8 @@ void SGMapPropagation::visitStoreNdOp(xegpu::StoreNdOp store,
   }
 }
 
+/// Propagate the layout of the value to the tensor descriptor operand in
+/// LoadNdOp.
 void SGMapPropagation::visitLoadNdOp(xegpu::LoadNdOp load,
                                      ArrayRef<SGMapLattice *> operands,
                                      ArrayRef<const SGMapLattice *> results) {
@@ -283,12 +314,20 @@ void SGMapPropagation::visitLoadNdOp(xegpu::LoadNdOp load,
   if (!valueLayout.isAssigned())
     return;
   SGMap tensorDescLayout = valueLayout;
-  if (auto transpose = load.getTranspose())
+  /// LoadNdOp has the transpose effect. However, at the stage of this analysis
+  /// this effect is not expected and should be abstracted away. Emit a warning.
+  /// TODO: Handle this case properly when `order` is introduced in the sg_map.
+  if (auto transpose = load.getTranspose()) {
+    emitWarning(load.getLoc())
+        << "Transpose effect is not expected for LoadNdOp";
     tensorDescLayout = valueLayout.getTransposedLayout(transpose.value());
+  }
   /// Propagate the new layout to the tensor descriptor operand.
   propagateIfChanged(operands[0], operands[0]->meet(tensorDescLayout));
 }
 
+/// For vector::TransposeOp, the layout of the result is transposed and
+/// propagated to the operand.
 void SGMapPropagation::visitTransposeOp(
     vector::TransposeOp transpose, ArrayRef<SGMapLattice *> operands,
     ArrayRef<const SGMapLattice *> results) {
@@ -302,6 +341,8 @@ void SGMapPropagation::visitTransposeOp(
   propagateIfChanged(operands[0], operands[0]->meet(newLayout));
 }
 
+/// For vector::BitCastOp, the wi_data of the source layout is changed based on
+/// the bit width of the source and result types.
 void SGMapPropagation::visitVectorBitcastOp(
     vector::BitCastOp bitcast, ArrayRef<SGMapLattice *> operands,
     ArrayRef<const SGMapLattice *> results) {
@@ -336,6 +377,8 @@ void SGMapPropagation::visitVectorBitcastOp(
                      operands[0]->meet(SGMap(newWiLayout, newWiData)));
 }
 
+/// Propagate the layout of the result to the tensor descriptor and mask
+/// operands in LoadGatherOp.
 void SGMapPropagation::visitLoadGatherOp(
     xegpu::LoadGatherOp load, ArrayRef<SGMapLattice *> operands,
     ArrayRef<const SGMapLattice *> results) {
@@ -343,56 +386,47 @@ void SGMapPropagation::visitLoadGatherOp(
   /// Need the layout of the value to propagate to the tensor descriptor.
   if (!valueLayout.isAssigned())
     return;
-  /// LoadGatherOp has the transpose effect, so propagate the transposed layout
-  /// to the tensor descriptor.
-  SGMap tensorDescLayout = valueLayout.getTransposedLayout({1, 0});
+
+  SGMap tensorDescLayout;
+  if (load.getTranspose()) {
+    /// LoadGatherOp has the transpose effect. However, at the stage of this
+    /// analysis this effect is not expected and should be abstracted away. Emit
+    /// a warning.
+    /// TODO: Handle this case properly when `order` is introduced in the
+    /// sg_map.
+    emitWarning(load.getLoc())
+        << "Transpose effect is not expected for LoadGatherOp";
+    tensorDescLayout = valueLayout.getTransposedLayout({1, 0});
+  } else
+    tensorDescLayout = valueLayout;
   /// Mask operand should have the same layout as the value but with wi_data =
   /// [1, 1]
-  SGMap maskLayout = SGMap(valueLayout.getLayout(), WiData({1, 1}));
+  SGMap maskLayout = getDefaultSgMap(load.getTensorDescType().getElementType());
   /// Propagate the new layout to the tensor descriptor operand.
   propagateIfChanged(operands[0], operands[0]->meet(tensorDescLayout));
   /// Propagate the new layout to the mask operand.
   propagateIfChanged(operands[1], operands[1]->meet(maskLayout));
 }
 
-void SGMapPropagation::visitCreateNdDescOp(
-    xegpu::CreateNdDescOp createNdDesc, ArrayRef<SGMapLattice *> operands,
-    ArrayRef<const SGMapLattice *> results) {
-  auto descLayout = results[0]->getValue();
-  /// Need the layout of the descriptor to propagate to the operands.
-  if (!descLayout.isAssigned())
-    return;
-  /// Propagate the layout to the source operand.
-  propagateIfChanged(operands[0], operands[0]->meet(descLayout));
-  /// For all other operands propagate the same layout with wi_data = [1, 1]
-  SGMap layout = SGMap(descLayout.getLayout(), WiData({1, 1}));
-  for (size_t i = 1; i < operands.size(); ++i) {
-    propagateIfChanged(operands[i], operands[i]->meet(layout));
-  }
-}
-
-void SGMapPropagation::visitCreateDescOp(
-    xegpu::CreateDescOp createDesc, ArrayRef<SGMapLattice *> operands,
-    ArrayRef<const SGMapLattice *> results) {
-  auto descLayout = results[0]->getValue();
-  /// Need the layout of the descriptor to propagate to the operands.
-  if (!descLayout.isAssigned())
-    return;
-  /// Propagate the layout to the source operand.
-  propagateIfChanged(operands[0], operands[0]->meet(descLayout));
-  /// For offset operand propagate the same layout with wi_data = [1, 1]
-  SGMap layout = SGMap(descLayout.getLayout(), WiData({1, 1}));
-  propagateIfChanged(operands[1], operands[1]->meet(layout));
-}
-
+/// Set the layout for the value, tensor descriptor, and mask operands in
+/// StoreScatterOp.
 void SGMapPropagation::visitStoreScatterOp(
     xegpu::StoreScatterOp storeScatter, ArrayRef<SGMapLattice *> operands,
     ArrayRef<const SGMapLattice *> results) {
-  /// StoreScatterOp has the transpose effect. Value has the regular layout,
-  /// while the tensor descriptor has the transposed layout.
   auto valueLayout =
       getDefaultSgMap(storeScatter.getTensorDescType().getElementType());
-  auto storeScatterLayout = valueLayout.getTransposedLayout({1, 0});
+  SGMap storeScatterLayout;
+  if (storeScatter.getTranspose()) {
+    /// StoreScatterOp allows transpose effect. However, at the stage of this
+    /// analysis this effect is not expected and should be abstracted away. Emit
+    /// a warning.
+    /// TODO: Handle this case properly when `order` is introduced in the
+    /// sg_map.
+    emitWarning(storeScatter.getLoc())
+        << "Transpose effect is not expected for StoreScatterOp";
+    storeScatterLayout = valueLayout.getTransposedLayout({1, 0});
+  } else
+    storeScatterLayout = valueLayout;
   /// Propagate the value layout.
   propagateIfChanged(operands[0], operands[0]->meet(valueLayout));
   /// Propagate the tensor descriptor layout.
@@ -403,6 +437,11 @@ void SGMapPropagation::visitStoreScatterOp(
 
 namespace {
 
+///===----------------------------------------------------------------------===///
+/// RunSGMapPropagation
+///===----------------------------------------------------------------------===///
+
+/// Driver class for running the SGMapPropagation analysis.
 class RunSGMapPropagation {
 public:
   RunSGMapPropagation(Operation *op) : target(op) {
@@ -487,13 +526,13 @@ struct XeGPUSubgroupDistributePass final
 
 void XeGPUSubgroupDistributePass::runOnOperation() {
   Operation *op = getOperation();
-
   RunSGMapPropagation solver(op);
 
-  // Print analysis results
+  // Print the analysis result and exit.
   if (printOnly) {
     auto &os = llvm::outs();
     solver.printAnalysisResult(os);
+    return;
   }
 }
 
diff --git a/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir b/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
index 81c4244091e87..871e6e2f9139e 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
@@ -62,7 +62,7 @@ func.func @test_extf_truncf_op(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegp
 }
 
 // -----
-func.func @test_load_gather_op(%a: memref<8x16xf16>, %b : memref<256xf16>, %c : memref<8x16xf32>) {
+func.func @test_load_gather_op_1(%a: memref<8x16xf16>, %b : memref<256xf16>, %c : memref<8x16xf32>) {
   %c0 = arith.constant 0 : index
   %6 = xegpu.create_nd_tdesc %a[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
   %7 = xegpu.load_nd %6 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
@@ -77,7 +77,17 @@ func.func @test_load_gather_op(%a: memref<8x16xf16>, %b : memref<256xf16>, %c :
 }
 
 // -----
-func.func @test_store_scatter_op(%c: memref<128xf32>){
+func.func @test_load_gather_op_2(%arg0: memref<256xf32>, %arg1: !xegpu.tensor_desc<16xf32>){
+  %0 = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
+  %1 = arith.constant dense<1>: vector<16xi1>
+  %2 = xegpu.create_tdesc %arg0, %0 : memref<256xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<chunk_size = 1>>
+  %3 = xegpu.load %2, %1 : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<chunk_size = 1>>, vector<16xi1> -> vector<16xf32>
+  xegpu.store_nd %3, %arg1 : vector<16xf32>, !xegpu.tensor_desc<16xf32>
+  return
+}
+
+// -----
+func.func @test_store_scatter_op_1(%c: memref<128xf32>){
   %cst = arith.constant dense<1.000000e+00> : vector<8x16xf32>
   %1 = arith.constant dense<1> : vector<16xi1>
   %2 = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
@@ -86,6 +96,15 @@ func.func @test_store_scatter_op(%c: memref<128xf32>){
   return
 }
 
+// -----
+func.func @test_store_scatter_op_2(%arg0: vector<16xf32>, %arg1: memref<256xf32>){
+  %0 = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
+  %1 = arith.constant dense<1>: vector<16xi1>
+  %2 = xegpu.create_tdesc %arg1, %0 : memref<256xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<chunk_size = 1>>
+  xegpu.store %arg0, %2, %1 : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<chunk_size = 1>>, vector<16xi1>
+  return
+}
+
 // -----
 func.func @test_bitcast_op_1(%arg0: memref<8x16xi16>, %arg1: memref<32x16xi8>, %c : memref<8x16xi32>) {
   %c0 = arith.constant 0 : index
@@ -191,7 +210,7 @@ func.func @test_if_op_2(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tenso
 }
 
 // -----
-func.func @test(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: vector<16x16xf16>, %arg3: i1) -> vector<8x16xf32> {
+func.func @test_if_op_3(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: vector<16x16xf16>, %arg3: i1, %arg4: !xegpu.tensor_desc<8x16xf32>){
   %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
   %1 = scf.if %arg3 -> (vector<16x16xf16>) {
     %3 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
@@ -200,30 +219,26 @@ func.func @test(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<1
     scf.yield %arg2 : vector<16x16xf16>
   }
   %2 = xegpu.dpas %0, %1 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
-  return %2 : vector<8x16xf32>
+  xegpu.store_nd %2, %arg4 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+  return
 }
 
 // -----
-func.func @test(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: vector<16x16xf16>, %arg3: i1) -> vector<8x16xf32> {
-  %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
-  %1 = scf.if %arg3 -> (vector<8x16xf32>) {
-    %2 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
-    %3 = arith.addf %2, %arg2 : vector<16x16xf16>
-    %4 = xegpu.dpas %0, %3 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
-    scf.yield %4 : vector<8x16xf32>
-  } else {
-    %2 = xegpu.dpas %0, %arg2 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
-    scf.yield %2 : vector<8x16xf32>
-  }
-  return %1 : vector<8x16xf32>
+
+func.func @test_vector_reduction_1(%arg0: vector<16x16xf32>, %arg1: !xegpu.tensor_desc<16xf32>) {
+  %0 = arith.constant dense<0.000000e+00> : vector<16xf32>
+  %1 = vector.multi_reduction <add>, %arg0, %0 [0] : vector<16x16xf32> to vector<16xf32>
+  xegpu.store_nd %1, %arg1 : vector<16xf32>, !xegpu.tensor_desc<16xf32>
+  return
 }
 
+
+
 // -----
-func.func @test(%arg0: !xegpu.tensor_desc<8x16xi16>, %arg1: !xegpu.tensor_desc<16x16xi16>) -> vector<8x16xf32> {
-  %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xi16> -> vector<8x16xi16>
-  %1 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xi16> -> vector<16x16xi16>
-  %2 = arith.bitcast %0 : vector<8x16xi16> to vector<8x16xf16>
-  %3 = arith.bitcast %1 : vector<16x16xi16> to vector<16x16xf16>
-  %4 = xegpu.dpas %2, %3 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
-  return %4 : vector<8x16xf32>
+
+func.func @test_vector_reduction_2(%arg0: vector<16x16xf32>, %arg1: !xegpu.tensor_desc<16xf32>) {
+  %0 = arith.constant dense<0.000000e+00> : vector<16xf32>
+  %1 = vector.multi_reduction <add>, %arg0, %0 [1] : vector<16x16xf32> to vector<16xf32>
+  xegpu.store_nd %1, %arg1 : vector<16xf32>, !xegpu.tensor_desc<16xf32>
+  return
 }

>From 462b3c09d15254f9d6dbdbe4c11aef9761e56aea Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Fri, 7 Mar 2025 04:00:35 +0000
Subject: [PATCH 09/17] save work

---
 .../Transforms/XeGPUSubgroupDistribute.cpp    |  77 +++-
 .../XeGPU/subgroup-map-propagation.mlir       | 341 ++++++++++++------
 2 files changed, 296 insertions(+), 122 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 13030d5c74ea0..de65f285d75b4 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -134,6 +134,7 @@ SGMap SGMap::join(const SGMap &lhs, const SGMap &rhs) {
   llvm_unreachable("Join should not be triggered by SGMapPropagation.");
 }
 
+/// Get the transposed layout according to the given permutation.
 SGMap SGMap::getTransposedLayout(ArrayRef<int64_t> permutation) const {
   if (!isAssigned())
     return {};
@@ -180,6 +181,11 @@ static SGMap getDefaultSgMap(Type ty) {
   return SGMap(WiLayout({1, subgroupSize}), WiData({1, packingFactor}));
 }
 
+/// Helper Function to get the default layout representing constants.
+static SGMap getDefaultSgMap() {
+  return SGMap(WiLayout({1, subgroupSize}), WiData({1, 1}));
+}
+
 ///===----------------------------------------------------------------------===///
 /// SGMapPropagation
 ///===----------------------------------------------------------------------===///
@@ -216,6 +222,14 @@ class SGMapPropagation : public SparseBackwardDataFlowAnalysis<SGMapLattice> {
                             ArrayRef<SGMapLattice *> operands,
                             ArrayRef<const SGMapLattice *> results);
 
+  void visitCreateNdDescOp(xegpu::CreateNdDescOp createNdDesc,
+                           ArrayRef<SGMapLattice *> operands,
+                           ArrayRef<const SGMapLattice *> results);
+
+  void visitCreateDescOp(xegpu::CreateDescOp createDesc,
+                         ArrayRef<SGMapLattice *> operands,
+                         ArrayRef<const SGMapLattice *> results);
+
 public:
   SGMapPropagation(DataFlowSolver &solver, SymbolTableCollection &symbolTable)
       : SparseBackwardDataFlowAnalysis(solver, symbolTable) {}
@@ -254,6 +268,10 @@ SGMapPropagation::visitOperation(Operation *op,
     visitVectorBitcastOp(bitcast, operands, results);
   else if (auto loadGather = dyn_cast<xegpu::LoadGatherOp>(op))
     visitLoadGatherOp(loadGather, operands, results);
+  else if (auto createNdDesc = dyn_cast<xegpu::CreateNdDescOp>(op))
+    visitCreateNdDescOp(createNdDesc, operands, results);
+  else if (auto createDesc = dyn_cast<xegpu::CreateDescOp>(op))
+    visitCreateDescOp(createDesc, operands, results);
   else if (auto storeScatter = dyn_cast<xegpu::StoreScatterOp>(op))
     visitStoreScatterOp(storeScatter, operands, results);
   /// All other ops
@@ -318,8 +336,7 @@ void SGMapPropagation::visitLoadNdOp(xegpu::LoadNdOp load,
   /// this effect is not expected and should be abstracted away. Emit a warning.
   /// TODO: Handle this case properly when `order` is introduced in the sg_map.
   if (auto transpose = load.getTranspose()) {
-    emitWarning(load.getLoc())
-        << "Transpose effect is not expected for LoadNdOp";
+    load.emitWarning("Transpose effect is not expected for LoadNdOp");
     tensorDescLayout = valueLayout.getTransposedLayout(transpose.value());
   }
   /// Propagate the new layout to the tensor descriptor operand.
@@ -394,21 +411,53 @@ void SGMapPropagation::visitLoadGatherOp(
     /// a warning.
     /// TODO: Handle this case properly when `order` is introduced in the
     /// sg_map.
-    emitWarning(load.getLoc())
-        << "Transpose effect is not expected for LoadGatherOp";
+    load.emitWarning("Transpose effect is not expected for LoadGatherOp");
     tensorDescLayout = valueLayout.getTransposedLayout({1, 0});
   } else
     tensorDescLayout = valueLayout;
   /// Mask operand should have the same layout as the value but with wi_data =
   /// [1, 1]
-  SGMap maskLayout = getDefaultSgMap(load.getTensorDescType().getElementType());
+  SGMap maskLayout = SGMap(valueLayout.getLayout(), WiData({1, 1}));
   /// Propagate the new layout to the tensor descriptor operand.
   propagateIfChanged(operands[0], operands[0]->meet(tensorDescLayout));
   /// Propagate the new layout to the mask operand.
   propagateIfChanged(operands[1], operands[1]->meet(maskLayout));
 }
 
-/// Set the layout for the value, tensor descriptor, and mask operands in
+/// Propagate the layout of the descriptor to the operands in CreateNdDescOp.
+void SGMapPropagation::visitCreateNdDescOp(
+    xegpu::CreateNdDescOp createNdDesc, ArrayRef<SGMapLattice *> operands,
+    ArrayRef<const SGMapLattice *> results) {
+  auto descLayout = results[0]->getValue();
+  /// Need the layout of the descriptor to propagate to the operands.
+  if (!descLayout.isAssigned())
+    return;
+  /// Propagate the layout to the source operand.
+  propagateIfChanged(operands[0], operands[0]->meet(descLayout));
+  /// For all other operands propagate the default layout.
+  SGMap layout = getDefaultSgMap();
+  for (size_t i = 1; i < operands.size(); ++i) {
+    propagateIfChanged(operands[i], operands[i]->meet(layout));
+  }
+}
+
+/// Propagate the layout of the descriptor to the source and offset operands in
+/// CreateDescOp.
+void SGMapPropagation::visitCreateDescOp(
+    xegpu::CreateDescOp createDesc, ArrayRef<SGMapLattice *> operands,
+    ArrayRef<const SGMapLattice *> results) {
+  auto descLayout = results[0]->getValue();
+  /// Need the layout of the descriptor to propagate to the operands.
+  if (!descLayout.isAssigned())
+    return;
+  /// Propagate the layout to the source operand.
+  propagateIfChanged(operands[0], operands[0]->meet(descLayout));
+  /// For offset operand propagate the default layout.
+  SGMap layout = getDefaultSgMap();
+  propagateIfChanged(operands[1], operands[1]->meet(layout));
+}
+
+/// Set the layout for the value, tensor descriptor, and mask operands in the
 /// StoreScatterOp.
 void SGMapPropagation::visitStoreScatterOp(
     xegpu::StoreScatterOp storeScatter, ArrayRef<SGMapLattice *> operands,
@@ -422,8 +471,8 @@ void SGMapPropagation::visitStoreScatterOp(
     /// a warning.
     /// TODO: Handle this case properly when `order` is introduced in the
     /// sg_map.
-    emitWarning(storeScatter.getLoc())
-        << "Transpose effect is not expected for StoreScatterOp";
+    storeScatter.emitWarning(
+        "Transpose effect is not expected for StoreScatterOp");
     storeScatterLayout = valueLayout.getTransposedLayout({1, 0});
   } else
     storeScatterLayout = valueLayout;
@@ -431,8 +480,9 @@ void SGMapPropagation::visitStoreScatterOp(
   propagateIfChanged(operands[0], operands[0]->meet(valueLayout));
   /// Propagate the tensor descriptor layout.
   propagateIfChanged(operands[1], operands[1]->meet(storeScatterLayout));
-  /// Propagate the mask layout.
-  propagateIfChanged(operands[2], operands[2]->meet(valueLayout));
+  /// Use default layout for mask operand.
+  auto maskLayout = getDefaultSgMap();
+  propagateIfChanged(operands[2], operands[2]->meet(maskLayout));
 }
 
 namespace {
@@ -517,10 +567,9 @@ struct XeGPUSubgroupDistributePass final
   }
   void runOnOperation() override;
   /// Print sg map propagation analysis result and exit for testing purposes.
-  Option<bool> printOnly{
-      *this, "print-only", llvm::cl::init(false),
-      llvm::cl::desc(
-          "Print the result of the subgroup map propagation analysis")};
+  Option<bool> printOnly{*this, "print-analysis-only", llvm::cl::init(false),
+                         llvm::cl::desc("Print the result of the subgroup map "
+                                        "propagation analysis and exit.")};
 };
 } // namespace
 
diff --git a/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir b/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
index 871e6e2f9139e..c75931cd6a37b 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
@@ -1,57 +1,149 @@
-func.func @test_dpas_op_1(%a: memref<8x16xf16>, %b : memref<16x16xf16>, %c : memref<8x16xf32>) {
+// RUN: mlir-opt -xegpu-subgroup-distribute='print-analysis-only=true' -split-input-file %s | FileCheck %s
+
+// CHECK: function: test_dpas_op_1:
+// CHECK: op    : %{{.*}} = arith.constant 0 : index
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %{{.*}} = arith.constant dense<0.000000e+00> : vector<8x16xf32>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[T1:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1]
+// CHECK-NEXT: op    : %[[T2:.*]] = xegpu.load_nd %[[T0]]  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[T3:.*]] = xegpu.load_nd %[[T1]]  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1]
+// CHECK-NEXT: op    : %[[T4:.*]] = xegpu.dpas %[[T2]], %[[T3]], %{{.*}} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %{{.*}} = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+func.func @test_dpas_op_1(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<8x16xf32>) {
   %c0 = arith.constant 0 : index
-  %cst = arith.constant dense<0.0> : vector<8x16xf32>
-  %0 = xegpu.create_nd_tdesc %a[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
-  %1 = xegpu.create_nd_tdesc %b[%c0, %c0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
-  %2 = xegpu.load_nd %0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
-  %3 = xegpu.load_nd %1 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+  %cst = arith.constant dense<0.000000e+00> : vector<8x16xf32>
+  %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
+  %1 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
+  %2 = xegpu.load_nd %0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+  %3 = xegpu.load_nd %1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
   %4 = xegpu.dpas %2, %3, %cst : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
-  %5 = xegpu.create_nd_tdesc %c[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
-  xegpu.store_nd %4, %5 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+  %5 = xegpu.create_nd_tdesc %arg2[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+  xegpu.store_nd %4, %5  : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
   return
 }
 
+
 // -----
-func.func @test_dpas_op_2(%arg0: vector<8x32xi8>, %arg1: vector<32x16xi8>, %c : memref<8x16xi32>) {
+// CHECK: function: test_dpas_op_2:
+// CHECK-NEXT: argument: <block argument> of type 'vector<8x32xi8>' at index: 0
+// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 2]
+// CHECK-NEXT: argument: <block argument> of type 'vector<32x16xi8>' at index: 1
+// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [4, 1]
+// CHECK-NEXT: argument: <block argument> of type 'memref<8x16xi32>' at index: 2
+// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %{{.*}} = arith.constant 0 : index
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[T0:.*]] = xegpu.dpas %{{.*}} : vector<8x32xi8>, vector<32x16xi8> -> vector<8x16xi32>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[T1:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xi32> -> !xegpu.tensor_desc<8x16xi32>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+func.func @test_dpas_op_2(%arg0: vector<8x32xi8>, %arg1: vector<32x16xi8>, %arg2: memref<8x16xi32>) {
   %c0 = arith.constant 0 : index
   %0 = xegpu.dpas %arg0, %arg1 : vector<8x32xi8>, vector<32x16xi8> -> vector<8x16xi32>
-  %1 = xegpu.create_nd_tdesc %c [%c0, %c0] : memref<8x16xi32> -> !xegpu.tensor_desc<8x16xi32>
-  xegpu.store_nd %0, %1 : vector<8x16xi32>, !xegpu.tensor_desc<8x16xi32>
+  %1 = xegpu.create_nd_tdesc %arg2[%c0, %c0] : memref<8x16xi32> -> !xegpu.tensor_desc<8x16xi32>
+  xegpu.store_nd %0, %1  : vector<8x16xi32>, !xegpu.tensor_desc<8x16xi32>
   return
 }
 
 // -----
-func.func @test_transpose_op_1(%a: memref<8x16xf16>, %b : memref<16x16xf16>, %c : memref<8x16xf32>) {
+// CHECK: function: test_transpose_op_1:
+// CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf16>' at index: 0
+// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: argument: <block argument> of type 'memref<16x16xf16>' at index: 1
+// CHECK-NEXT: sg_map  : wi_layout: [16, 1], wi_data: [2, 1]
+// CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf32>' at index: 2
+// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %{{.*}} = arith.constant 0 : index
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<8x16xf32>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[T1:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [16, 1], wi_data: [2, 1]
+// CHECK-NEXT: op    : %[[T2:.*]] = xegpu.load_nd %[[T0]]  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[T3:.*]] = xegpu.load_nd %[[T1]] <{transpose = array<i64: 1, 0>}> : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1]
+// CHECK-NEXT: op    : %[[T4:.*]] = xegpu.dpas %[[T2]], %[[T3]], %[[CST]] : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[T5:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+func.func @test_transpose_op_1(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<8x16xf32>) {
   %c0 = arith.constant 0 : index
-  %cst = arith.constant dense<0.0> : vector<8x16xf32>
-  %0 = xegpu.create_nd_tdesc %a[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
-  %1 = xegpu.create_nd_tdesc %b[%c0, %c0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
-  %2 = xegpu.load_nd %0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
-  %3 = xegpu.load_nd %1 {transpose = array<i64: 1, 0>} : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+  %cst = arith.constant dense<0.000000e+00> : vector<8x16xf32>
+  %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
+  %1 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
+  %2 = xegpu.load_nd %0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+  %3 = xegpu.load_nd %1 <{transpose = array<i64: 1, 0>}> : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
   %4 = xegpu.dpas %2, %3, %cst : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
-  %5 = xegpu.create_nd_tdesc %c[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
-  xegpu.store_nd %4, %5 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+  %5 = xegpu.create_nd_tdesc %arg2[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+  xegpu.store_nd %4, %5  : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
   return
 }
 
 // -----
-func.func @test_transpose_op_2(%a: memref<8x16xf16>, %b : memref<16x16xf16>, %c : memref<8x16xf32>) {
+// CHECK: argument: <block argument> of type 'memref<8x16xf16>' at index: 0
+// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: argument: <block argument> of type 'memref<16x16xf16>' at index: 1
+// CHECK-NEXT: sg_map  : wi_layout: [16, 1], wi_data: [2, 1]
+// CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf32>' at index: 2
+// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %{{.*}} = arith.constant 0 : index
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<8x16xf32>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[T1:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [16, 1], wi_data: [2, 1]
+// CHECK-NEXT: op    : %[[T2:.*]] = xegpu.load_nd %[[T0]]  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[T3:.*]] = xegpu.load_nd %[[T1]]  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [16, 1], wi_data: [2, 1]
+// CHECK-NEXT: op    : %[[T4:.*]] = vector.transpose %[[T3]], [1, 0] : vector<16x16xf16> to vector<16x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1]
+// CHECK-NEXT: op    : %[[T5:.*]] = xegpu.dpas %[[T2]], %[[T4]], %[[CST]] : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[T6:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+func.func @test_transpose_op_2(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<8x16xf32>) {
   %c0 = arith.constant 0 : index
-  %cst = arith.constant dense<0.0> : vector<8x16xf32>
-  %0 = xegpu.create_nd_tdesc %a[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
-  %1 = xegpu.create_nd_tdesc %b[%c0, %c0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
-  %2 = xegpu.load_nd %0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
-  %3 = xegpu.load_nd %1 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
-  %6 = vector.transpose %3, [1, 0] : vector<16x16xf16> to vector<16x16xf16>
-  %4 = xegpu.dpas %2, %6, %cst : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
-  %5 = xegpu.create_nd_tdesc %c[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
-  xegpu.store_nd %4, %5 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+  %cst = arith.constant dense<0.000000e+00> : vector<8x16xf32>
+  %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
+  %1 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
+  %2 = xegpu.load_nd %0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+  %3 = xegpu.load_nd %1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+  %4 = vector.transpose %3, [1, 0] : vector<16x16xf16> to vector<16x16xf16>
+  %5 = xegpu.dpas %2, %4, %cst : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+  %6 = xegpu.create_nd_tdesc %arg2[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+  xegpu.store_nd %5, %6  : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
   return
 }
 
-
-
 // -----
+// CHECK: argument: <block argument> of type '!xegpu.tensor_desc<8x16xf16>' at index: 0
+// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<16x16xf16>' at index: 1
+// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [2, 1]
+// CHECK-NEXT: op    : %[[T0:.*]] = xegpu.load_nd %{{.*}}  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[T1:.*]] = xegpu.load_nd %{{.*}}  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1]
+// CHECK-NEXT: op    : %[[T2:.*]] = arith.extf %[[T1]] : vector<16x16xf16> to vector<16x16xf32>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1]
+// CHECK-NEXT: op    : %[[T3:.*]] = arith.truncf %[[T2]] : vector<16x16xf32> to vector<16x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1]
+// CHECK-NEXT: op    : %[[T4:.*]] = xegpu.dpas %[[T0]], %[[T3]] : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+// CHECK-NEXT: sg_map for result #0: Not assigned.
 func.func @test_extf_truncf_op(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>) -> vector<8x16xf32> {
   %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
   %1 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
@@ -62,75 +154,112 @@ func.func @test_extf_truncf_op(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegp
 }
 
 // -----
-func.func @test_load_gather_op_1(%a: memref<8x16xf16>, %b : memref<256xf16>, %c : memref<8x16xf32>) {
+// CHECK: function: test_load_gather_op_1:
+// CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf16>' at index: 0
+// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: argument: <block argument> of type 'memref<256xf16>' at index: 1
+// CHECK-NEXT: sg_map  : wi_layout: [16, 1], wi_data: [2, 1]
+// CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf32>' at index: 2
+// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %{{.*}} = arith.constant 0 : index
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[T1:.*]] = xegpu.load_nd %[[T0]]  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[CST:.*]] = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[CST0:.*]] = arith.constant dense<true> : vector<16xi1>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[T2:.*]] = xegpu.create_tdesc %{{.*}}, %[[CST]] : memref<256xf16>, vector<16xindex> -> !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [16, 1], wi_data: [2, 1]
+// CHECK-NEXT: op    : %[[T3:.*]] = xegpu.load %[[T2]], %[[CST0]] <{transpose}> : !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>>, vector<16xi1> -> vector<16x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1]
+// CHECK-NEXT: op    : %[[T4:.*]] = xegpu.dpas %[[T1]], %[[T3]] : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[T5:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+func.func @test_load_gather_op_1(%arg0: memref<8x16xf16>, %arg1: memref<256xf16>, %arg2: memref<8x16xf32>) {
   %c0 = arith.constant 0 : index
-  %6 = xegpu.create_nd_tdesc %a[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
-  %7 = xegpu.load_nd %6 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
-  %0 = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
-  %1 = arith.constant dense<1>: vector<16xi1>
-  %2 = xegpu.create_tdesc %b, %0 : memref<256xf16>, vector<16xindex> -> !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16>>
-  %3 = xegpu.load %2, %1 <{transpose}> : !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16>>, vector<16xi1> -> vector<16x16xf16>
-  %4 = xegpu.dpas %7, %3 : vector<8x16xf16>, vector<16x16xf16>-> vector<8x16xf32>
-  %5 = xegpu.create_nd_tdesc %c[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
-  xegpu.store_nd %4, %5 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+  %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
+  %1 = xegpu.load_nd %0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+  %cst = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
+  %cst_0 = arith.constant dense<true> : vector<16xi1>
+  %2 = xegpu.create_tdesc %arg1, %cst : memref<256xf16>, vector<16xindex> -> !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>>
+  %3 = xegpu.load %2, %cst_0 <{transpose}> : !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>>, vector<16xi1> -> vector<16x16xf16>
+  %4 = xegpu.dpas %1, %3 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+  %5 = xegpu.create_nd_tdesc %arg2[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+  xegpu.store_nd %4, %5  : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
   return
 }
 
 // -----
-func.func @test_load_gather_op_2(%arg0: memref<256xf32>, %arg1: !xegpu.tensor_desc<16xf32>){
-  %0 = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
-  %1 = arith.constant dense<1>: vector<16xi1>
-  %2 = xegpu.create_tdesc %arg0, %0 : memref<256xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<chunk_size = 1>>
-  %3 = xegpu.load %2, %1 : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<chunk_size = 1>>, vector<16xi1> -> vector<16xf32>
-  xegpu.store_nd %3, %arg1 : vector<16xf32>, !xegpu.tensor_desc<16xf32>
+// CHECK: argument: <block argument> of type 'memref<256xf32>' at index: 0
+// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<16xf32>' at index: 1
+// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[CST:.*]] = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[CST0:.*]] = arith.constant dense<true> : vector<16xi1>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[T0:.*]] = xegpu.create_tdesc %{{.*}}, %[[CST]] : memref<256xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[T1:.*]] = xegpu.load %[[T0]], %[[CST0]]  : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> -> vector<16xf32>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+func.func @test_load_gather_op_2(%arg0: memref<256xf32>, %arg1: !xegpu.tensor_desc<16xf32>) {
+  %cst = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
+  %cst_0 = arith.constant dense<true> : vector<16xi1>
+  %0 = xegpu.create_tdesc %arg0, %cst : memref<256xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
+  %1 = xegpu.load %0, %cst_0  : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> -> vector<16xf32>
+  xegpu.store_nd %1, %arg1  : vector<16xf32>, !xegpu.tensor_desc<16xf32>
   return
 }
 
 // -----
-func.func @test_store_scatter_op_1(%c: memref<128xf32>){
+func.func @test_store_scatter_op_1(%arg0: memref<128xf32>) {
   %cst = arith.constant dense<1.000000e+00> : vector<8x16xf32>
-  %1 = arith.constant dense<1> : vector<16xi1>
-  %2 = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
-  %3 = xegpu.create_tdesc %c, %2 : memref<128xf32>, vector<16xindex> -> !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size = 8>>
-  xegpu.store %cst, %3, %1 <{transpose}> : vector<8x16xf32>, !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size = 8>>, vector<16xi1>
+  %cst_0 = arith.constant dense<true> : vector<16xi1>
+  %cst_1 = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
+  %0 = xegpu.create_tdesc %arg0, %cst_1 : memref<128xf32>, vector<16xindex> -> !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size = 8 : i64>>
+  xegpu.store %cst, %0, %cst_0 <{transpose}> : vector<8x16xf32>, !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size = 8 : i64>>, vector<16xi1>
   return
 }
 
 // -----
-func.func @test_store_scatter_op_2(%arg0: vector<16xf32>, %arg1: memref<256xf32>){
-  %0 = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
-  %1 = arith.constant dense<1>: vector<16xi1>
-  %2 = xegpu.create_tdesc %arg1, %0 : memref<256xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<chunk_size = 1>>
-  xegpu.store %arg0, %2, %1 : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<chunk_size = 1>>, vector<16xi1>
+func.func @test_store_scatter_op_2(%arg0: vector<16xf32>, %arg1: memref<256xf32>) {
+  %cst = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
+  %cst_0 = arith.constant dense<true> : vector<16xi1>
+  %0 = xegpu.create_tdesc %arg1, %cst : memref<256xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
+  xegpu.store %arg0, %0, %cst_0  : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1>
   return
 }
 
 // -----
-func.func @test_bitcast_op_1(%arg0: memref<8x16xi16>, %arg1: memref<32x16xi8>, %c : memref<8x16xi32>) {
+func.func @test_bitcast_op_1(%arg0: memref<8x16xi16>, %arg1: memref<32x16xi8>, %arg2: memref<8x16xi32>) {
   %c0 = arith.constant 0 : index
-  %2 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xi16> -> !xegpu.tensor_desc<8x16xi16>
-  %3 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<32x16xi8> -> !xegpu.tensor_desc<32x16xi8>
-  %4 = xegpu.load_nd %2 : !xegpu.tensor_desc<8x16xi16> -> vector<8x16xi16>
-  %5 = xegpu.load_nd %3 : !xegpu.tensor_desc<32x16xi8> -> vector<32x16xi8>
-  %6 = vector.bitcast %4 : vector<8x16xi16> to vector<8x32xi8>
-  %0 = xegpu.dpas %6, %5 : vector<8x32xi8>, vector<32x16xi8> -> vector<8x16xi32>
-  %1 = xegpu.create_nd_tdesc %c [%c0, %c0] : memref<8x16xi32> -> !xegpu.tensor_desc<8x16xi32>
-  xegpu.store_nd %0, %1 : vector<8x16xi32>, !xegpu.tensor_desc<8x16xi32>
+  %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xi16> -> !xegpu.tensor_desc<8x16xi16>
+  %1 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<32x16xi8> -> !xegpu.tensor_desc<32x16xi8>
+  %2 = xegpu.load_nd %0  : !xegpu.tensor_desc<8x16xi16> -> vector<8x16xi16>
+  %3 = xegpu.load_nd %1  : !xegpu.tensor_desc<32x16xi8> -> vector<32x16xi8>
+  %4 = vector.bitcast %2 : vector<8x16xi16> to vector<8x32xi8>
+  %5 = xegpu.dpas %4, %3 : vector<8x32xi8>, vector<32x16xi8> -> vector<8x16xi32>
+  %6 = xegpu.create_nd_tdesc %arg2[%c0, %c0] : memref<8x16xi32> -> !xegpu.tensor_desc<8x16xi32>
+  xegpu.store_nd %5, %6  : vector<8x16xi32>, !xegpu.tensor_desc<8x16xi32>
   return
 }
 
 // -----
-func.func @test_bitcast_op_2(%arg0: memref<8x32xi8>, %arg1: memref<16x32xi8>, %c : memref<8x16xf32>) {
+func.func @test_bitcast_op_2(%arg0: memref<8x32xi8>, %arg1: memref<16x32xi8>, %arg2: memref<8x16xf32>) {
   %c0 = arith.constant 0 : index
-  %2 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x32xi8> -> !xegpu.tensor_desc<8x32xi8>
-  %3 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<16x32xi8> -> !xegpu.tensor_desc<16x32xi8>
-  %4 = xegpu.load_nd %2 : !xegpu.tensor_desc<8x32xi8> -> vector<8x32xi8>
-  %5 = xegpu.load_nd %3 : !xegpu.tensor_desc<16x32xi8> -> vector<16x32xi8>
-  %6 = vector.bitcast %4 : vector<8x32xi8> to vector<8x16xf16>
-  %7 = vector.bitcast %5 : vector<16x32xi8> to vector<16x16xf16>
-  %0 = xegpu.dpas %6, %7 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
-  %1 = xegpu.create_nd_tdesc %c [%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
-  xegpu.store_nd %0, %1 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+  %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x32xi8> -> !xegpu.tensor_desc<8x32xi8>
+  %1 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<16x32xi8> -> !xegpu.tensor_desc<16x32xi8>
+  %2 = xegpu.load_nd %0  : !xegpu.tensor_desc<8x32xi8> -> vector<8x32xi8>
+  %3 = xegpu.load_nd %1  : !xegpu.tensor_desc<16x32xi8> -> vector<16x32xi8>
+  %4 = vector.bitcast %2 : vector<8x32xi8> to vector<8x16xf16>
+  %5 = vector.bitcast %3 : vector<16x32xi8> to vector<16x16xf16>
+  %6 = xegpu.dpas %4, %5 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+  %7 = xegpu.create_nd_tdesc %arg2[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+  xegpu.store_nd %6, %7  : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
   return
 }
 
@@ -141,7 +270,7 @@ func.func @test_binary_op_1(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.t
   %2 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
   %3 = arith.addf %1, %2 : vector<16x16xf16>
   %4 = xegpu.dpas %0, %3 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
-  xegpu.store_nd %4, %arg2 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+  xegpu.store_nd %4, %arg2  : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
   return
 }
 
@@ -152,34 +281,34 @@ func.func @test_binary_op_2(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.t
   %cst = arith.constant dense<1.000000e+00> : vector<16x16xf16>
   %2 = arith.addf %1, %cst : vector<16x16xf16>
   %3 = xegpu.dpas %0, %2 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
-  xegpu.store_nd %3, %arg2 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
-  xegpu.store_nd %2, %arg3 : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16>
+  xegpu.store_nd %3, %arg2  : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+  xegpu.store_nd %2, %arg3  : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16>
   return
 }
 
 // -----
-func.func @test_for_op(%a: memref<8x128xf16>, %b : memref<128x16xf16>, %c : memref<8x16xf32>) {
+func.func @test_for_op(%arg0: memref<8x128xf16>, %arg1: memref<128x16xf16>, %arg2: memref<8x16xf32>) {
   %c0 = arith.constant 0 : index
   %c128 = arith.constant 128 : index
   %c16 = arith.constant 16 : index
-  %0 = xegpu.create_nd_tdesc %a[%c0, %c0] : memref<8x128xf16> -> !xegpu.tensor_desc<8x16xf16>
-  %1 = xegpu.create_nd_tdesc %b[%c0, %c0] : memref<128x16xf16> -> !xegpu.tensor_desc<16x16xf16>
-  %2 = arith.constant dense<0.0> : vector<8x16xf32>
-  %3:3 = scf.for %k = %c0 to %c128 step %c16 iter_args(%arg0 = %0, %arg1 = %1, %arg2 = %2) -> (!xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<16x16xf16>, vector<8x16xf32>) {
-    %4 = xegpu.load_nd %arg0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
-    %5 = xegpu.load_nd %arg1 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
-    %6 = xegpu.dpas %4, %5, %arg2 : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
-    %7 = xegpu.update_nd_offset %arg0, [%c0, %c16] : !xegpu.tensor_desc<8x16xf16>
-    %8 = xegpu.update_nd_offset %arg1, [%c16, %c0] : !xegpu.tensor_desc<16x16xf16>
+  %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x128xf16> -> !xegpu.tensor_desc<8x16xf16>
+  %1 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<128x16xf16> -> !xegpu.tensor_desc<16x16xf16>
+  %cst = arith.constant dense<0.000000e+00> : vector<8x16xf32>
+  %2:3 = scf.for %arg3 = %c0 to %c128 step %c16 iter_args(%arg4 = %0, %arg5 = %1, %arg6 = %cst) -> (!xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<16x16xf16>, vector<8x16xf32>) {
+    %4 = xegpu.load_nd %arg4  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+    %5 = xegpu.load_nd %arg5  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+    %6 = xegpu.dpas %4, %5, %arg6 : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+    %7 = xegpu.update_nd_offset %arg4, [%c0, %c16] : !xegpu.tensor_desc<8x16xf16>
+    %8 = xegpu.update_nd_offset %arg5, [%c16, %c0] : !xegpu.tensor_desc<16x16xf16>
     scf.yield %7, %8, %6 : !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<16x16xf16>, vector<8x16xf32>
   }
-  %9 = xegpu.create_nd_tdesc %c[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
-  xegpu.store_nd %3#2, %9 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+  %3 = xegpu.create_nd_tdesc %arg2[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+  xegpu.store_nd %2#2, %3  : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
   return
 }
 
 // -----
-func.func @test_if_op_1(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: i1, %arg3: !xegpu.tensor_desc<8x16xf32>){
+func.func @test_if_op_1(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: i1, %arg3: !xegpu.tensor_desc<8x16xf32>) {
   %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
   %1 = scf.if %arg2 -> (vector<16x16xf16>) {
     %3 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
@@ -189,12 +318,12 @@ func.func @test_if_op_1(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tenso
     scf.yield %3 : vector<16x16xf16>
   }
   %2 = xegpu.dpas %0, %1 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
-  xegpu.store_nd %2, %arg3 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+  xegpu.store_nd %2, %arg3  : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
   return
 }
 
 // -----
-func.func @test_if_op_2(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: i1, %arg3: !xegpu.tensor_desc<8x16xf32>, %arg4: !xegpu.tensor_desc<16x16xf16>){
+func.func @test_if_op_2(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: i1, %arg3: !xegpu.tensor_desc<8x16xf32>, %arg4: !xegpu.tensor_desc<16x16xf16>) {
   %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
   %1 = scf.if %arg2 -> (vector<16x16xf16>) {
     %3 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
@@ -204,13 +333,13 @@ func.func @test_if_op_2(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tenso
     scf.yield %3 : vector<16x16xf16>
   }
   %2 = xegpu.dpas %0, %1 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
-  xegpu.store_nd %2, %arg3 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
-  xegpu.store_nd %1, %arg4 : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16>
+  xegpu.store_nd %2, %arg3  : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+  xegpu.store_nd %1, %arg4  : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16>
   return
 }
 
 // -----
-func.func @test_if_op_3(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: vector<16x16xf16>, %arg3: i1, %arg4: !xegpu.tensor_desc<8x16xf32>){
+func.func @test_if_op_3(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: vector<16x16xf16>, %arg3: i1, %arg4: !xegpu.tensor_desc<8x16xf32>) {
   %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
   %1 = scf.if %arg3 -> (vector<16x16xf16>) {
     %3 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
@@ -219,26 +348,22 @@ func.func @test_if_op_3(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tenso
     scf.yield %arg2 : vector<16x16xf16>
   }
   %2 = xegpu.dpas %0, %1 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
-  xegpu.store_nd %2, %arg4 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+  xegpu.store_nd %2, %arg4  : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
   return
 }
 
 // -----
-
 func.func @test_vector_reduction_1(%arg0: vector<16x16xf32>, %arg1: !xegpu.tensor_desc<16xf32>) {
-  %0 = arith.constant dense<0.000000e+00> : vector<16xf32>
-  %1 = vector.multi_reduction <add>, %arg0, %0 [0] : vector<16x16xf32> to vector<16xf32>
-  xegpu.store_nd %1, %arg1 : vector<16xf32>, !xegpu.tensor_desc<16xf32>
+  %cst = arith.constant dense<0.000000e+00> : vector<16xf32>
+  %0 = vector.multi_reduction <add>, %arg0, %cst [0] : vector<16x16xf32> to vector<16xf32>
+  xegpu.store_nd %0, %arg1  : vector<16xf32>, !xegpu.tensor_desc<16xf32>
   return
 }
 
-
-
 // -----
-
 func.func @test_vector_reduction_2(%arg0: vector<16x16xf32>, %arg1: !xegpu.tensor_desc<16xf32>) {
-  %0 = arith.constant dense<0.000000e+00> : vector<16xf32>
-  %1 = vector.multi_reduction <add>, %arg0, %0 [1] : vector<16x16xf32> to vector<16xf32>
-  xegpu.store_nd %1, %arg1 : vector<16xf32>, !xegpu.tensor_desc<16xf32>
+  %cst = arith.constant dense<0.000000e+00> : vector<16xf32>
+  %0 = vector.multi_reduction <add>, %arg0, %cst [1] : vector<16x16xf32> to vector<16xf32>
+  xegpu.store_nd %0, %arg1  : vector<16xf32>, !xegpu.tensor_desc<16xf32>
   return
 }

>From 27f011fbedc582bb85f3555bf38b00247906a6bf Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Fri, 7 Mar 2025 04:35:52 +0000
Subject: [PATCH 10/17] save work

---
 .../XeGPU/subgroup-map-propagation.mlir       | 138 +++++++++++++++++-
 1 file changed, 136 insertions(+), 2 deletions(-)

diff --git a/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir b/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
index c75931cd6a37b..d9d1ac80169ea 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
@@ -216,6 +216,16 @@ func.func @test_load_gather_op_2(%arg0: memref<256xf32>, %arg1: !xegpu.tensor_de
 }
 
 // -----
+// CHECK: argument: <block argument> of type 'memref<128xf32>' at index: 0
+// CHECK-NEXT: sg_map  : wi_layout: [16, 1], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<8x16xf32>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[CST0:.*]] = arith.constant dense<true> : vector<16xi1>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[CST1:.*]] = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[T0:.*]] = xegpu.create_tdesc %{{.*}}, %[[CST1]] : memref<128xf32>, vector<16xindex> -> !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size = 8 : i64>>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [16, 1], wi_data: [1, 1]
 func.func @test_store_scatter_op_1(%arg0: memref<128xf32>) {
   %cst = arith.constant dense<1.000000e+00> : vector<8x16xf32>
   %cst_0 = arith.constant dense<true> : vector<16xi1>
@@ -226,6 +236,16 @@ func.func @test_store_scatter_op_1(%arg0: memref<128xf32>) {
 }
 
 // -----
+// CHECK: argument: <block argument> of type 'vector<16xf32>' at index: 0
+// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: argument: <block argument> of type 'memref<256xf32>' at index: 1
+// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[CST:.*]] = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[CST1:.*]] = arith.constant dense<true> : vector<16xi1>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[T0:.*]] = xegpu.create_tdesc %{{.*}}, %[[CST]] : memref<256xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
 func.func @test_store_scatter_op_2(%arg0: vector<16xf32>, %arg1: memref<256xf32>) {
   %cst = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
   %cst_0 = arith.constant dense<true> : vector<16xi1>
@@ -235,7 +255,29 @@ func.func @test_store_scatter_op_2(%arg0: vector<16xf32>, %arg1: memref<256xf32>
 }
 
 // -----
-func.func @test_bitcast_op_1(%arg0: memref<8x16xi16>, %arg1: memref<32x16xi8>, %arg2: memref<8x16xi32>) {
+// CHECK: argument: <block argument> of type 'memref<8x16xi16>' at index: 0
+// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: argument: <block argument> of type 'memref<32x16xi8>' at index: 1
+// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [4, 1]
+// CHECK-NEXT: argument: <block argument> of type 'memref<8x16xi32>' at index: 2
+// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %{{.*}} = arith.constant 0 : index
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xi16> -> !xegpu.tensor_desc<8x16xi16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[T1:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<32x16xi8> -> !xegpu.tensor_desc<32x16xi8>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [4, 1]
+// CHECK-NEXT: op    : %[[T2:.*]] = xegpu.load_nd %[[T0]]  : !xegpu.tensor_desc<8x16xi16> -> vector<8x16xi16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[T3:.*]] = xegpu.load_nd %[[T1]]  : !xegpu.tensor_desc<32x16xi8> -> vector<32x16xi8>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [4, 1]
+// CHECK-NEXT: op    : %[[T4:.*]] = vector.bitcast %[[T2]] : vector<8x16xi16> to vector<8x32xi8>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 2]
+// CHECK-NEXT: op    : %[[T5:.*]] = xegpu.dpas %[[T4]], %[[T3]] : vector<8x32xi8>, vector<32x16xi8> -> vector<8x16xi32>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[T6:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xi32> -> !xegpu.tensor_desc<8x16xi32>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+func.func @test_vector_bitcast_op_1(%arg0: memref<8x16xi16>, %arg1: memref<32x16xi8>, %arg2: memref<8x16xi32>) {
   %c0 = arith.constant 0 : index
   %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xi16> -> !xegpu.tensor_desc<8x16xi16>
   %1 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<32x16xi8> -> !xegpu.tensor_desc<32x16xi8>
@@ -249,7 +291,31 @@ func.func @test_bitcast_op_1(%arg0: memref<8x16xi16>, %arg1: memref<32x16xi8>, %
 }
 
 // -----
-func.func @test_bitcast_op_2(%arg0: memref<8x32xi8>, %arg1: memref<16x32xi8>, %arg2: memref<8x16xf32>) {
+// CHECK: argument: <block argument> of type 'memref<8x32xi8>' at index: 0
+// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 2]
+// CHECK-NEXT: argument: <block argument> of type 'memref<16x32xi8>' at index: 1
+// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [4, 1]
+// CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf32>' at index: 2
+// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %{{.*}} = arith.constant 0 : index
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x32xi8> -> !xegpu.tensor_desc<8x32xi8>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 2]
+// CHECK-NEXT: op    : %[[T1:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<16x32xi8> -> !xegpu.tensor_desc<16x32xi8>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [4, 1]
+// CHECK-NEXT: op    : %[[T2:.*]] = xegpu.load_nd %[[T0]]  : !xegpu.tensor_desc<8x32xi8> -> vector<8x32xi8>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 2]
+// CHECK-NEXT: op    : %[[T3:.*]] = xegpu.load_nd %[[T1]]  : !xegpu.tensor_desc<16x32xi8> -> vector<16x32xi8>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [4, 1]
+// CHECK-NEXT: op    : %[[T4:.*]] = vector.bitcast %[[T2]] : vector<8x32xi8> to vector<8x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[T5:.*]] = vector.bitcast %[[T3]] : vector<16x32xi8> to vector<16x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1]
+// CHECK-NEXT: op    : %[[T6:.*]] = xegpu.dpas %[[T4]], %[[T5]] : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[T7:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+func.func @test_vector_bitcast_op_2(%arg0: memref<8x32xi8>, %arg1: memref<16x32xi8>, %arg2: memref<8x16xf32>) {
   %c0 = arith.constant 0 : index
   %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x32xi8> -> !xegpu.tensor_desc<8x32xi8>
   %1 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<16x32xi8> -> !xegpu.tensor_desc<16x32xi8>
@@ -264,6 +330,22 @@ func.func @test_bitcast_op_2(%arg0: memref<8x32xi8>, %arg1: memref<16x32xi8>, %a
 }
 
 // -----
+// CHECK: argument: <block argument> of type '!xegpu.tensor_desc<8x16xf16>' at index: 0
+// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<16x16xf16>' at index: 1
+// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [2, 1]
+// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<8x16xf32>' at index: 2
+// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[T0:.*]] = xegpu.load_nd %{{.*}}  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[T1:.*]] = xegpu.load_nd %{{.*}}  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1]
+// CHECK-NEXT: op    : %[[T2:.*]] = xegpu.load_nd %{{.*}}  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1]
+// CHECK-NEXT: op    : %[[T3:.*]] = arith.addf %[[T1]], %[[T2]] : vector<16x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1]
+// CHECK-NEXT: op    : %[[T4:.*]] = xegpu.dpas %[[T0]], %[[T3]] : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
 func.func @test_binary_op_1(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: !xegpu.tensor_desc<8x16xf32>) {
   %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
   %1 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
@@ -275,6 +357,24 @@ func.func @test_binary_op_1(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.t
 }
 
 // -----
+// CHECK: argument: <block argument> of type '!xegpu.tensor_desc<8x16xf16>' at index: 0
+// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<16x16xf16>' at index: 1
+// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<8x16xf32>' at index: 2
+// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<16x16xf16>' at index: 3
+// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[T0:.*]] = xegpu.load_nd %{{.*}}  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[T1:.*]] = xegpu.load_nd %{{.*}}  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<16x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[T2:.*]] = arith.addf %[[T1]], %[[CST]] : vector<16x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[T3:.*]] = xegpu.dpas %[[T0]], %[[T2]] : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
 func.func @test_binary_op_2(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: !xegpu.tensor_desc<8x16xf32>, %arg3: !xegpu.tensor_desc<16x16xf16>) {
   %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
   %1 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
@@ -287,6 +387,40 @@ func.func @test_binary_op_2(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.t
 }
 
 // -----
+// CHECK: argument: <block argument> of type 'memref<8x128xf16>' at index: 0
+// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: argument: <block argument> of type 'memref<128x16xf16>' at index: 1
+// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [2, 1]
+// CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf32>' at index: 2
+// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %{{.*}} = arith.constant 0 : index
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %{{.*}} = arith.constant 128 : index
+// CHECK-NEXT: sg_map for result #0: Not assigned.
+// CHECK-NEXT: op    : %{{.*}} = arith.constant 16 : index
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x128xf16> -> !xegpu.tensor_desc<8x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[T1:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<128x16xf16> -> !xegpu.tensor_desc<16x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1]
+// CHECK-NEXT: op    : %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<8x16xf32>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[T4:.*]] = xegpu.load_nd %{{.*}}  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[T5:.*]] = xegpu.load_nd %{{.*}}  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1]
+// CHECK-NEXT: op    : %[[T6:.*]] = xegpu.dpas %[[T4]], %[[T5]], %{{.*}} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[T7:.*]] = xegpu.update_nd_offset %{{.*}} : !xegpu.tensor_desc<8x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[T8:.*]] = xegpu.update_nd_offset %{{.8}} : !xegpu.tensor_desc<16x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1]
+// CHECK-NEXT: op    : scf.for
+// CHECK-NEXT: sg_map for result #0: Not assigned.
+// CHECK-NEXT: sg_map for result #1: Not assigned.
+// CHECK-NEXT: sg_map for result #2: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[T3:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
 func.func @test_for_op(%arg0: memref<8x128xf16>, %arg1: memref<128x16xf16>, %arg2: memref<8x16xf32>) {
   %c0 = arith.constant 0 : index
   %c128 = arith.constant 128 : index

>From 898b6195b8676a0e3db141cc5e6077c22d0c1ec3 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Fri, 7 Mar 2025 04:45:17 +0000
Subject: [PATCH 11/17] save work

---
 .../XeGPU/subgroup-map-propagation.mlir       | 78 ++++++++++++++++++-
 1 file changed, 77 insertions(+), 1 deletion(-)

diff --git a/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir b/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
index d9d1ac80169ea..83a17dd9d1a7e 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
@@ -413,7 +413,7 @@ func.func @test_binary_op_2(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.t
 // CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
 // CHECK-NEXT: op    : %[[T7:.*]] = xegpu.update_nd_offset %{{.*}} : !xegpu.tensor_desc<8x16xf16>
 // CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
-// CHECK-NEXT: op    : %[[T8:.*]] = xegpu.update_nd_offset %{{.8}} : !xegpu.tensor_desc<16x16xf16>
+// CHECK-NEXT: op    : %[[T8:.*]] = xegpu.update_nd_offset %{{.*}} : !xegpu.tensor_desc<16x16xf16>
 // CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1]
 // CHECK-NEXT: op    : scf.for
 // CHECK-NEXT: sg_map for result #0: Not assigned.
@@ -442,6 +442,25 @@ func.func @test_for_op(%arg0: memref<8x128xf16>, %arg1: memref<128x16xf16>, %arg
 }
 
 // -----
+// CHECK: function: test_if_op_1:
+// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<8x16xf16>' at index: 0
+// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<16x16xf16>' at index: 1
+// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [2, 1]
+// CHECK-NEXT: argument: <block argument> of type 'i1' at index: 2
+// CHECK-NEXT: sg_map  : Not assigned.
+// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<8x16xf32>' at index: 3
+// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[T0:.*]] = xegpu.load_nd %{{.*}}  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[T3:.*]] = xegpu.load_nd %{{.*}}  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1]
+// CHECK-NEXT: op    : %[[T4:.*]] = xegpu.load_nd %{{.*}}  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1]
+// CHECK-NEXT: op    : scf.if
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1]
+// CHECK-NEXT: op    : %[[T2:.*]] = xegpu.dpas %[[T0]], %{{.*}} : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
 func.func @test_if_op_1(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: i1, %arg3: !xegpu.tensor_desc<8x16xf32>) {
   %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
   %1 = scf.if %arg2 -> (vector<16x16xf16>) {
@@ -457,6 +476,27 @@ func.func @test_if_op_1(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tenso
 }
 
 // -----
+// CHECK: function: test_if_op_2:
+// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<8x16xf16>' at index: 0
+// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<16x16xf16>' at index: 1
+// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: argument: <block argument> of type 'i1' at index: 2
+// CHECK-NEXT: sg_map  : Not assigned.
+// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<8x16xf32>' at index: 3
+// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<16x16xf16>' at index: 4
+// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[T0:.*]] = xegpu.load_nd %{{.*}}  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[T3:.*]] = xegpu.load_nd %{{.*}}  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[T4:.*]] = xegpu.load_nd %{{.*}}  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : scf.if
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[T2:.*]] = xegpu.dpas %[[T0]], %{{.*}} : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
 func.func @test_if_op_2(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: i1, %arg3: !xegpu.tensor_desc<8x16xf32>, %arg4: !xegpu.tensor_desc<16x16xf16>) {
   %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
   %1 = scf.if %arg2 -> (vector<16x16xf16>) {
@@ -473,6 +513,25 @@ func.func @test_if_op_2(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tenso
 }
 
 // -----
+// CHECK: function: test_if_op_3:
+// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<8x16xf16>' at index: 0
+// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<16x16xf16>' at index: 1
+// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [2, 1]
+// CHECK-NEXT: argument: <block argument> of type 'vector<16x16xf16>' at index: 2
+// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [2, 1]
+// CHECK-NEXT: argument: <block argument> of type 'i1' at index: 3
+// CHECK-NEXT: sg_map  : Not assigned.
+// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<8x16xf32>' at index: 4
+// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[T0:.*]] = xegpu.load_nd %{{.*}}  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[T3:.*]] = xegpu.load_nd %{{.*}}  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1]
+// CHECK-NEXT: op    : scf.if
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1]
+// CHECK-NEXT: op    : %[[T2:.*]] = xegpu.dpas %[[T0]], %{{.*}} : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
 func.func @test_if_op_3(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: vector<16x16xf16>, %arg3: i1, %arg4: !xegpu.tensor_desc<8x16xf32>) {
   %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
   %1 = scf.if %arg3 -> (vector<16x16xf16>) {
@@ -487,6 +546,15 @@ func.func @test_if_op_3(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tenso
 }
 
 // -----
+// CHECK: function: test_vector_reduction_1:
+// CHECK-NEXT: argument: <block argument> of type 'vector<16x16xf32>' at index: 0
+// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<16xf32>' at index: 1
+// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<16xf32>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[T0:.*]] = vector.multi_reduction <add>, %{{.*}}, %[[CST]] [0] : vector<16x16xf32> to vector<16xf32>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
 func.func @test_vector_reduction_1(%arg0: vector<16x16xf32>, %arg1: !xegpu.tensor_desc<16xf32>) {
   %cst = arith.constant dense<0.000000e+00> : vector<16xf32>
   %0 = vector.multi_reduction <add>, %arg0, %cst [0] : vector<16x16xf32> to vector<16xf32>
@@ -495,6 +563,14 @@ func.func @test_vector_reduction_1(%arg0: vector<16x16xf32>, %arg1: !xegpu.tenso
 }
 
 // -----
+// CHECK: argument: <block argument> of type 'vector<16x16xf32>' at index: 0
+// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<16xf32>' at index: 1
+// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<16xf32>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op    : %[[T0:.*]] = vector.multi_reduction <add>, %{{.*}}, %[[CST]] [1] : vector<16x16xf32> to vector<16xf32>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
 func.func @test_vector_reduction_2(%arg0: vector<16x16xf32>, %arg1: !xegpu.tensor_desc<16xf32>) {
   %cst = arith.constant dense<0.000000e+00> : vector<16xf32>
   %0 = vector.multi_reduction <add>, %arg0, %cst [1] : vector<16x16xf32> to vector<16xf32>

>From 83404020a7459ca2c99cfb8e15027c2f31968213 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Fri, 7 Mar 2025 19:55:40 +0000
Subject: [PATCH 12/17] save work

---
 .../Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp    | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index de65f285d75b4..aba0586f30cfd 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -238,13 +238,13 @@ class SGMapPropagation : public SparseBackwardDataFlowAnalysis<SGMapLattice> {
   LogicalResult visitOperation(Operation *op, ArrayRef<SGMapLattice *> operands,
                                ArrayRef<const SGMapLattice *> results) override;
 
-  void visitBranchOperand(OpOperand &operand) override{};
+  void visitBranchOperand(OpOperand &operand) override {};
 
-  void visitCallOperand(OpOperand &operand) override{};
+  void visitCallOperand(OpOperand &operand) override {};
 
   void visitExternalCall(CallOpInterface call,
                          ArrayRef<SGMapLattice *> operands,
-                         ArrayRef<const SGMapLattice *> results) override{};
+                         ArrayRef<const SGMapLattice *> results) override {};
 
   void setToExitState(SGMapLattice *lattice) override {
     (void)lattice->meet(SGMap());

>From 549e8781d8899f7056a19548a3c6848639be8587 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Fri, 7 Mar 2025 20:07:39 +0000
Subject: [PATCH 13/17] save work

---
 mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index aba0586f30cfd..d4618bf0e24d6 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -144,7 +144,7 @@ SGMap SGMap::getTransposedLayout(ArrayRef<int64_t> permutation) const {
     newLayout.layout.push_back(layout.layout[idx]);
     newData.layout.push_back(data.layout[idx]);
   }
-  return SGMap(newLayout, data);
+  return SGMap(newLayout, newData);
 }
 
 ///===----------------------------------------------------------------------===///

>From abca365b5bf0851b8a31885395ce576df15b1bf3 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Fri, 7 Mar 2025 20:12:56 +0000
Subject: [PATCH 14/17] save work

---
 .../Dialect/XeGPU/subgroup-map-propagation.mlir    | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir b/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
index 83a17dd9d1a7e..92e37a20bf555 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
@@ -58,7 +58,7 @@ func.func @test_dpas_op_2(%arg0: vector<8x32xi8>, %arg1: vector<32x16xi8>, %arg2
 // CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf16>' at index: 0
 // CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
 // CHECK-NEXT: argument: <block argument> of type 'memref<16x16xf16>' at index: 1
-// CHECK-NEXT: sg_map  : wi_layout: [16, 1], wi_data: [2, 1]
+// CHECK-NEXT: sg_map  : wi_layout: [16, 1], wi_data: [1, 2]
 // CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf32>' at index: 2
 // CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
 // CHECK-NEXT: op    : %{{.*}} = arith.constant 0 : index
@@ -68,7 +68,7 @@ func.func @test_dpas_op_2(%arg0: vector<8x32xi8>, %arg1: vector<32x16xi8>, %arg2
 // CHECK-NEXT: op    : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
 // CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
 // CHECK-NEXT: op    : %[[T1:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [16, 1], wi_data: [2, 1]
+// CHECK-NEXT: sg_map for result #0: wi_layout: [16, 1], wi_data: [1, 2]
 // CHECK-NEXT: op    : %[[T2:.*]] = xegpu.load_nd %[[T0]]  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
 // CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
 // CHECK-NEXT: op    : %[[T3:.*]] = xegpu.load_nd %[[T1]] <{transpose = array<i64: 1, 0>}> : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
@@ -94,7 +94,7 @@ func.func @test_transpose_op_1(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>
 // CHECK: argument: <block argument> of type 'memref<8x16xf16>' at index: 0
 // CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
 // CHECK-NEXT: argument: <block argument> of type 'memref<16x16xf16>' at index: 1
-// CHECK-NEXT: sg_map  : wi_layout: [16, 1], wi_data: [2, 1]
+// CHECK-NEXT: sg_map  : wi_layout: [16, 1], wi_data: [1, 2]
 // CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf32>' at index: 2
 // CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
 // CHECK-NEXT: op    : %{{.*}} = arith.constant 0 : index
@@ -104,11 +104,11 @@ func.func @test_transpose_op_1(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>
 // CHECK-NEXT: op    : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
 // CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
 // CHECK-NEXT: op    : %[[T1:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [16, 1], wi_data: [2, 1]
+// CHECK-NEXT: sg_map for result #0: wi_layout: [16, 1], wi_data: [1, 2]
 // CHECK-NEXT: op    : %[[T2:.*]] = xegpu.load_nd %[[T0]]  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
 // CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
 // CHECK-NEXT: op    : %[[T3:.*]] = xegpu.load_nd %[[T1]]  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [16, 1], wi_data: [2, 1]
+// CHECK-NEXT: sg_map for result #0: wi_layout: [16, 1], wi_data: [1, 2]
 // CHECK-NEXT: op    : %[[T4:.*]] = vector.transpose %[[T3]], [1, 0] : vector<16x16xf16> to vector<16x16xf16>
 // CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1]
 // CHECK-NEXT: op    : %[[T5:.*]] = xegpu.dpas %[[T2]], %[[T4]], %[[CST]] : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
@@ -158,7 +158,7 @@ func.func @test_extf_truncf_op(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegp
 // CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf16>' at index: 0
 // CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
 // CHECK-NEXT: argument: <block argument> of type 'memref<256xf16>' at index: 1
-// CHECK-NEXT: sg_map  : wi_layout: [16, 1], wi_data: [2, 1]
+// CHECK-NEXT: sg_map  : wi_layout: [16, 1], wi_data: [1, 2]
 // CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf32>' at index: 2
 // CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
 // CHECK-NEXT: op    : %{{.*}} = arith.constant 0 : index
@@ -172,7 +172,7 @@ func.func @test_extf_truncf_op(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegp
 // CHECK-NEXT: op    : %[[CST0:.*]] = arith.constant dense<true> : vector<16xi1>
 // CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
 // CHECK-NEXT: op    : %[[T2:.*]] = xegpu.create_tdesc %{{.*}}, %[[CST]] : memref<256xf16>, vector<16xindex> -> !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [16, 1], wi_data: [2, 1]
+// CHECK-NEXT: sg_map for result #0: wi_layout: [16, 1], wi_data: [1, 2]
 // CHECK-NEXT: op    : %[[T3:.*]] = xegpu.load %[[T2]], %[[CST0]] <{transpose}> : !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>>, vector<16xi1> -> vector<16x16xf16>
 // CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1]
 // CHECK-NEXT: op    : %[[T4:.*]] = xegpu.dpas %[[T1]], %[[T3]] : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>

>From 5bd5db50d8bdd6c3802e270741b9e0e1d744bd34 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Fri, 7 Mar 2025 23:14:23 +0000
Subject: [PATCH 15/17] save work

---
 .../Transforms/XeGPUSubgroupDistribute.cpp    | 113 ++++++++++++++----
 .../XeGPU/subgroup-map-propagation.mlir       |  60 +++++-----
 2 files changed, 117 insertions(+), 56 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index d4618bf0e24d6..7136993d0e968 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
 #include "mlir/Dialect/XeGPU/Transforms/Passes.h"
 #include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
+#include "mlir/IR/Builders.h"
 #include "llvm/Support/raw_ostream.h"
 
 namespace mlir {
@@ -157,6 +158,8 @@ struct SGMapLattice : public Lattice<SGMap> {
   using Lattice::Lattice;
 };
 
+/// Helper Functions
+
 /// Helper Function to get the expected layouts for DPAS operands.
 static SGMap getSGMapForDPASOperand(Type operandTy, unsigned operandNum) {
   int packingFactorForB = packedBSizeInBits / operandTy.getIntOrFloatBitWidth();
@@ -174,16 +177,33 @@ static SGMap getSGMapForDPASOperand(Type operandTy, unsigned operandNum) {
 /// However, the minimum granularity of data access per work item is 16-bits.
 /// So, if the bitwidth of the type is less than 16, we need to pack the data to
 /// 16-bits.
-static SGMap getDefaultSgMap(Type ty) {
-  int packingFactor = 1;
-  if (ty.getIntOrFloatBitWidth() < packedASizeInBits)
-    packingFactor = packedBSizeInBits / ty.getIntOrFloatBitWidth();
-  return SGMap(WiLayout({1, subgroupSize}), WiData({1, packingFactor}));
+// static SGMap getDefaultSgMap(Type ty, unsigned rank) {
+//   int packingFactor = 1;
+//   if (ty.getIntOrFloatBitWidth() < packedASizeInBits)
+//     packingFactor = packedBSizeInBits / ty.getIntOrFloatBitWidth();
+//   return SGMap(WiLayout({1, subgroupSize}), WiData({1, packingFactor}));
+// }
+
+/// Default layout for uniform values (e.g. constants) of rank 1 or rank 2.
+static SGMap getDefaultSgMap(unsigned rank) {
+  assert((rank == 1 || rank == 2) && "Expected 1D or 2D vector.");
+  if (rank == 1)
+    return SGMap(WiLayout({subgroupSize}), WiData({1}));
+  return SGMap(WiLayout({1, subgroupSize}), WiData({1, 1}));
 }
 
-/// Helper Function to get the default layout representing constants.
-static SGMap getDefaultSgMap() {
-  return SGMap(WiLayout({1, subgroupSize}), WiData({1, 1}));
+static SGMap getDefaultSgMap(VectorType vectorTy) {
+  /// Only rank-1 and rank-2 vector types are supported.
+  assert((vectorTy.getRank() == 1 || vectorTy.getRank() == 2) &&
+         "Expected 1D or 2D vector.");
+  /// Rank-1 vectors take the 1D default layout; no packing factor is applied.
+  if (vectorTy.getRank() == 1)
+    return getDefaultSgMap(1);
+  int packingFactor = 1;
+  auto bitwidth = vectorTy.getElementType().getIntOrFloatBitWidth();
+  if (bitwidth < packedASizeInBits)
+    packingFactor = packedBSizeInBits / bitwidth;
+  return SGMap(WiLayout({1, subgroupSize}), WiData({1, packingFactor}));
 }
 
 ///===----------------------------------------------------------------------===///
@@ -230,6 +250,14 @@ class SGMapPropagation : public SparseBackwardDataFlowAnalysis<SGMapLattice> {
                          ArrayRef<SGMapLattice *> operands,
                          ArrayRef<const SGMapLattice *> results);
 
+  void visitUpdateNdOffsetOp(xegpu::UpdateNdOffsetOp updateNdOffset,
+                             ArrayRef<SGMapLattice *> operands,
+                             ArrayRef<const SGMapLattice *> results);
+
+  void visitVectorMultiReductionOp(vector::MultiDimReductionOp reduction,
+                                   ArrayRef<SGMapLattice *> operands,
+                                   ArrayRef<const SGMapLattice *> results);
+
 public:
   SGMapPropagation(DataFlowSolver &solver, SymbolTableCollection &symbolTable)
       : SparseBackwardDataFlowAnalysis(solver, symbolTable) {}
@@ -274,6 +302,10 @@ SGMapPropagation::visitOperation(Operation *op,
     visitCreateDescOp(createDesc, operands, results);
   else if (auto storeScatter = dyn_cast<xegpu::StoreScatterOp>(op))
     visitStoreScatterOp(storeScatter, operands, results);
+  else if (auto updateNdOffset = dyn_cast<xegpu::UpdateNdOffsetOp>(op))
+    visitUpdateNdOffsetOp(updateNdOffset, operands, results);
+  else if (auto reduction = dyn_cast<vector::MultiDimReductionOp>(op))
+    visitVectorMultiReductionOp(reduction, operands, results);
   /// All other ops
   else {
     for (const SGMapLattice *r : results) {
@@ -293,6 +325,43 @@ SGMapPropagation::visitOperation(Operation *op,
   return success();
 }
 
+/// Set the layouts for the reduced operand and the accumulator of a reduction.
+void SGMapPropagation::visitVectorMultiReductionOp(
+    vector::MultiDimReductionOp reduction, ArrayRef<SGMapLattice *> operands,
+    ArrayRef<const SGMapLattice *> results) {
+  /// The layout of the result must be present.
+  auto resultLayout = results[0]->getValue();
+  if (!resultLayout.isAssigned())
+    return;
+  /// We only consider 2D -> 1D reductions at this point.
+  assert(resultLayout.getLayout().size() == 1 &&
+         "Expected 1D layout for reduction result.");
+  /// Given that the result is 1D, the layout of the operand should be 2D with
+  /// default layout.
+  auto operandLayout = getDefaultSgMap(2);
+  propagateIfChanged(operands[0], operands[0]->meet(operandLayout));
+  /// Accumulator should have the same layout as the result.
+  propagateIfChanged(operands[1], operands[1]->meet(resultLayout));
+}
+
+/// Propagate the result layout of UpdateNdOffsetOp to its source tensor
+/// descriptor; every other operand gets the 1D default layout.
+void SGMapPropagation::visitUpdateNdOffsetOp(
+    xegpu::UpdateNdOffsetOp updateNdOffset, ArrayRef<SGMapLattice *> operands,
+    ArrayRef<const SGMapLattice *> results) {
+  /// The layout of the result must be present.
+  auto resultLayout = results[0]->getValue();
+  if (!resultLayout.isAssigned())
+    return;
+  /// Propagate the layout to the source operand.
+  propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
+  /// For all other operands use 1D default layout.
+  SGMap layout = getDefaultSgMap(1);
+  for (size_t i = 1; i < operands.size(); ++i) {
+    propagateIfChanged(operands[i], operands[i]->meet(layout));
+  }
+}
+
 /// Set the layouts for DPAS A, B, and C operands.
 void SGMapPropagation::visitDpasOp(xegpu::DpasOp dpas,
                                    ArrayRef<SGMapLattice *> operands,
@@ -314,8 +383,7 @@ void SGMapPropagation::visitDpasOp(xegpu::DpasOp dpas,
 void SGMapPropagation::visitStoreNdOp(xegpu::StoreNdOp store,
                                       ArrayRef<SGMapLattice *> operands,
                                       ArrayRef<const SGMapLattice *> results) {
-  auto storeLayout =
-      getDefaultSgMap(store.getTensorDescType().getElementType());
+  auto storeLayout = getDefaultSgMap(store.getValueType());
   /// Both operands should have the same layout
   for (SGMapLattice *operand : operands) {
     propagateIfChanged(operand, operand->meet(storeLayout));
@@ -334,7 +402,6 @@ void SGMapPropagation::visitLoadNdOp(xegpu::LoadNdOp load,
   SGMap tensorDescLayout = valueLayout;
   /// LoadNdOp has the transpose effect. However, at the stage of this analyis
   /// this effect is not expected and should be abstracted away. Emit a warning.
-  /// TODO: Handle this case properly when `order` is introduced in the sg_map.
   if (auto transpose = load.getTranspose()) {
     load.emitWarning("Transpose effect is not expected for LoadNdOp");
     tensorDescLayout = valueLayout.getTransposedLayout(transpose.value());
@@ -409,15 +476,12 @@ void SGMapPropagation::visitLoadGatherOp(
     /// LoadGatherOp has the transpose effect. However, at the stage of this
     /// analyis this effect is not expected and should be abstracted away. Emit
     /// a warning.
-    /// TODO: Handle this case properly when `order` is introduced in the
-    /// sg_map.
     load.emitWarning("Transpose effect is not expected for LoadGatherOp");
     tensorDescLayout = valueLayout.getTransposedLayout({1, 0});
   } else
     tensorDescLayout = valueLayout;
-  /// Mask operand should have the same layout as the value but with wi_data =
-  /// [1, 1]
-  SGMap maskLayout = SGMap(valueLayout.getLayout(), WiData({1, 1}));
+  /// Mask operand should have 1D default layout.
+  auto maskLayout = getDefaultSgMap(1);
   /// Propagate the new layout to the tensor descriptor operand.
   propagateIfChanged(operands[0], operands[0]->meet(tensorDescLayout));
   /// Propagate the new layout to the mask operand.
@@ -434,8 +498,8 @@ void SGMapPropagation::visitCreateNdDescOp(
     return;
   /// Propagate the layout to the source operand.
   propagateIfChanged(operands[0], operands[0]->meet(descLayout));
-  /// For all other operands propagate the descriptor layout.
-  SGMap layout = getDefaultSgMap();
+  /// For all other operands use 1D default layout.
+  SGMap layout = getDefaultSgMap(1);
   for (size_t i = 1; i < operands.size(); ++i) {
     propagateIfChanged(operands[i], operands[i]->meet(layout));
   }
@@ -452,8 +516,8 @@ void SGMapPropagation::visitCreateDescOp(
     return;
   /// Propagate the layout to the source operand.
   propagateIfChanged(operands[0], operands[0]->meet(descLayout));
-  /// For offset operand propagate the default layout.
-  SGMap layout = getDefaultSgMap();
+  /// For offset operand propagate 1D default layout.
+  SGMap layout = getDefaultSgMap(1);
   propagateIfChanged(operands[1], operands[1]->meet(layout));
 }
 
@@ -462,15 +526,12 @@ void SGMapPropagation::visitCreateDescOp(
 void SGMapPropagation::visitStoreScatterOp(
     xegpu::StoreScatterOp storeScatter, ArrayRef<SGMapLattice *> operands,
     ArrayRef<const SGMapLattice *> results) {
-  auto valueLayout =
-      getDefaultSgMap(storeScatter.getTensorDescType().getElementType());
+  auto valueLayout = getDefaultSgMap(storeScatter.getValueType());
   SGMap storeScatterLayout;
   if (storeScatter.getTranspose()) {
     /// StoreScatteOp allows transpose effect. However, at the stage of this
     /// analyis this effect is not expected and should be abstracted away. Emit
     /// a warning.
-    /// TODO: Handle this case properly when `order` is introduced in the
-    /// sg_map.
     storeScatter.emitWarning(
         "Transpose effect is not expected for StoreScatterOp");
     storeScatterLayout = valueLayout.getTransposedLayout({1, 0});
@@ -480,8 +541,8 @@ void SGMapPropagation::visitStoreScatterOp(
   propagateIfChanged(operands[0], operands[0]->meet(valueLayout));
   /// Propagate the tensor descriptor layout.
   propagateIfChanged(operands[1], operands[1]->meet(storeScatterLayout));
-  /// Use default layout for mask operand.
-  auto maskLayout = getDefaultSgMap();
+  /// Use default 1D layout for mask operand.
+  auto maskLayout = getDefaultSgMap(1);
   propagateIfChanged(operands[2], operands[2]->meet(maskLayout));
 }
 
diff --git a/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir b/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
index 92e37a20bf555..7d8bd20391463 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
@@ -2,7 +2,7 @@
 
 // CHECK: function: test_dpas_op_1:
 // CHECK: op    : %{{.*}} = arith.constant 0 : index
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
 // CHECK-NEXT: op    : %{{.*}} = arith.constant dense<0.000000e+00> : vector<8x16xf32>
 // CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
 // CHECK-NEXT: op    : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
@@ -40,7 +40,7 @@ func.func @test_dpas_op_1(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %ar
 // CHECK-NEXT: argument: <block argument> of type 'memref<8x16xi32>' at index: 2
 // CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
 // CHECK-NEXT: op    : %{{.*}} = arith.constant 0 : index
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
 // CHECK-NEXT: op    : %[[T0:.*]] = xegpu.dpas %{{.*}} : vector<8x32xi8>, vector<32x16xi8> -> vector<8x16xi32>
 // CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
 // CHECK-NEXT: op    : %[[T1:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xi32> -> !xegpu.tensor_desc<8x16xi32>
@@ -62,7 +62,7 @@ func.func @test_dpas_op_2(%arg0: vector<8x32xi8>, %arg1: vector<32x16xi8>, %arg2
 // CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf32>' at index: 2
 // CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
 // CHECK-NEXT: op    : %{{.*}} = arith.constant 0 : index
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
 // CHECK-NEXT: op    : %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<8x16xf32>
 // CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
 // CHECK-NEXT: op    : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
@@ -98,7 +98,7 @@ func.func @test_transpose_op_1(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>
 // CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf32>' at index: 2
 // CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
 // CHECK-NEXT: op    : %{{.*}} = arith.constant 0 : index
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
 // CHECK-NEXT: op    : %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<8x16xf32>
 // CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
 // CHECK-NEXT: op    : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
@@ -162,15 +162,15 @@ func.func @test_extf_truncf_op(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegp
 // CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf32>' at index: 2
 // CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
 // CHECK-NEXT: op    : %{{.*}} = arith.constant 0 : index
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
 // CHECK-NEXT: op    : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
 // CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
 // CHECK-NEXT: op    : %[[T1:.*]] = xegpu.load_nd %[[T0]]  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
 // CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
 // CHECK-NEXT: op    : %[[CST:.*]] = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
 // CHECK-NEXT: op    : %[[CST0:.*]] = arith.constant dense<true> : vector<16xi1>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
 // CHECK-NEXT: op    : %[[T2:.*]] = xegpu.create_tdesc %{{.*}}, %[[CST]] : memref<256xf16>, vector<16xindex> -> !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>>
 // CHECK-NEXT: sg_map for result #0: wi_layout: [16, 1], wi_data: [1, 2]
 // CHECK-NEXT: op    : %[[T3:.*]] = xegpu.load %[[T2]], %[[CST0]] <{transpose}> : !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>>, vector<16xi1> -> vector<16x16xf16>
@@ -195,17 +195,17 @@ func.func @test_load_gather_op_1(%arg0: memref<8x16xf16>, %arg1: memref<256xf16>
 
 // -----
 // CHECK: argument: <block argument> of type 'memref<256xf32>' at index: 0
-// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map  : wi_layout: [16], wi_data: [1]
 // CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<16xf32>' at index: 1
-// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map  : wi_layout: [16], wi_data: [1]
 // CHECK-NEXT: op    : %[[CST:.*]] = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
 // CHECK-NEXT: op    : %[[CST0:.*]] = arith.constant dense<true> : vector<16xi1>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
 // CHECK-NEXT: op    : %[[T0:.*]] = xegpu.create_tdesc %{{.*}}, %[[CST]] : memref<256xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
 // CHECK-NEXT: op    : %[[T1]] = xegpu.load %[[T0]], %[[CST0]]  : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> -> vector<16xf32>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
 func.func @test_load_gather_op_2(%arg0: memref<256xf32>, %arg1: !xegpu.tensor_desc<16xf32>) {
   %cst = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
   %cst_0 = arith.constant dense<true> : vector<16xi1>
@@ -221,9 +221,9 @@ func.func @test_load_gather_op_2(%arg0: memref<256xf32>, %arg1: !xegpu.tensor_de
 // CHECK-NEXT: op    : %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<8x16xf32>
 // CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
 // CHECK-NEXT: op    : %[[CST0:.*]] = arith.constant dense<true> : vector<16xi1>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
 // CHECK-NEXT: op    : %[[CST1:.*]] = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
 // CHECK-NEXT: op    : %[[T0:.*]] = xegpu.create_tdesc %{{.*}}, %[[CST1]] : memref<128xf32>, vector<16xindex> -> !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size = 8 : i64>>
 // CHECK-NEXT: sg_map for result #0: wi_layout: [16, 1], wi_data: [1, 1]
 func.func @test_store_scatter_op_1(%arg0: memref<128xf32>) {
@@ -237,15 +237,15 @@ func.func @test_store_scatter_op_1(%arg0: memref<128xf32>) {
 
 // -----
 // CHECK: argument: <block argument> of type 'vector<16xf32>' at index: 0
-// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map  : wi_layout: [16], wi_data: [1]
 // CHECK-NEXT: argument: <block argument> of type 'memref<256xf32>' at index: 1
-// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map  : wi_layout: [16], wi_data: [1]
 // CHECK-NEXT: op    : %[[CST:.*]] = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
 // CHECK-NEXT: op    : %[[CST1:.*]] = arith.constant dense<true> : vector<16xi1>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
 // CHECK-NEXT: op    : %[[T0:.*]] = xegpu.create_tdesc %{{.*}}, %[[CST]] : memref<256xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
 func.func @test_store_scatter_op_2(%arg0: vector<16xf32>, %arg1: memref<256xf32>) {
   %cst = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
   %cst_0 = arith.constant dense<true> : vector<16xi1>
@@ -262,7 +262,7 @@ func.func @test_store_scatter_op_2(%arg0: vector<16xf32>, %arg1: memref<256xf32>
 // CHECK-NEXT: argument: <block argument> of type 'memref<8x16xi32>' at index: 2
 // CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
 // CHECK-NEXT: op    : %{{.*}} = arith.constant 0 : index
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
 // CHECK-NEXT: op    : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xi16> -> !xegpu.tensor_desc<8x16xi16>
 // CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
 // CHECK-NEXT: op    : %[[T1:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<32x16xi8> -> !xegpu.tensor_desc<32x16xi8>
@@ -298,7 +298,7 @@ func.func @test_vector_bitcast_op_1(%arg0: memref<8x16xi16>, %arg1: memref<32x16
 // CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf32>' at index: 2
 // CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
 // CHECK-NEXT: op    : %{{.*}} = arith.constant 0 : index
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
 // CHECK-NEXT: op    : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x32xi8> -> !xegpu.tensor_desc<8x32xi8>
 // CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 2]
 // CHECK-NEXT: op    : %[[T1:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<16x32xi8> -> !xegpu.tensor_desc<16x32xi8>
@@ -394,11 +394,11 @@ func.func @test_binary_op_2(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.t
 // CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf32>' at index: 2
 // CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
 // CHECK-NEXT: op    : %{{.*}} = arith.constant 0 : index
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
 // CHECK-NEXT: op    : %{{.*}} = arith.constant 128 : index
 // CHECK-NEXT: sg_map for result #0: Not assigned.
 // CHECK-NEXT: op    : %{{.*}} = arith.constant 16 : index
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
 // CHECK-NEXT: op    : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x128xf16> -> !xegpu.tensor_desc<8x16xf16>
 // CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
 // CHECK-NEXT: op    : %[[T1:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<128x16xf16> -> !xegpu.tensor_desc<16x16xf16>
@@ -550,11 +550,11 @@ func.func @test_if_op_3(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tenso
 // CHECK-NEXT: argument: <block argument> of type 'vector<16x16xf32>' at index: 0
 // CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
 // CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<16xf32>' at index: 1
-// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map  : wi_layout: [16], wi_data: [1]
 // CHECK-NEXT: op    : %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<16xf32>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
 // CHECK-NEXT: op    : %[[T0:.*]] = vector.multi_reduction <add>, %{{.*}}, %[[CST]] [0] : vector<16x16xf32> to vector<16xf32>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
 func.func @test_vector_reduction_1(%arg0: vector<16x16xf32>, %arg1: !xegpu.tensor_desc<16xf32>) {
   %cst = arith.constant dense<0.000000e+00> : vector<16xf32>
   %0 = vector.multi_reduction <add>, %arg0, %cst [0] : vector<16x16xf32> to vector<16xf32>
@@ -566,11 +566,11 @@ func.func @test_vector_reduction_1(%arg0: vector<16x16xf32>, %arg1: !xegpu.tenso
 // CHECK: argument: <block argument> of type 'vector<16x16xf32>' at index: 0
 // CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
 // CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<16xf32>' at index: 1
-// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map  : wi_layout: [16], wi_data: [1]
 // CHECK-NEXT: op    : %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<16xf32>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
 // CHECK-NEXT: op    : %[[T0:.*]] = vector.multi_reduction <add>, %{{.*}}, %[[CST]] [1] : vector<16x16xf32> to vector<16xf32>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
 func.func @test_vector_reduction_2(%arg0: vector<16x16xf32>, %arg1: !xegpu.tensor_desc<16xf32>) {
   %cst = arith.constant dense<0.000000e+00> : vector<16xf32>
   %0 = vector.multi_reduction <add>, %arg0, %cst [1] : vector<16x16xf32> to vector<16xf32>

>From 5820c15c393a5b29e5aa8533627de51e480b6275 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Tue, 11 Mar 2025 18:54:34 +0000
Subject: [PATCH 16/17] save work

---
 .../Transforms/XeGPUSubgroupDistribute.cpp    | 56 +++++--------
 .../XeGPU/subgroup-map-propagation.mlir       | 78 ++++++++++---------
 2 files changed, 63 insertions(+), 71 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 7136993d0e968..14a9f5954edcf 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/XeGPU/Transforms/Passes.h"
 #include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/Support/LLVM.h"
 #include "llvm/Support/raw_ostream.h"
 
 namespace mlir {
@@ -242,10 +243,6 @@ class SGMapPropagation : public SparseBackwardDataFlowAnalysis<SGMapLattice> {
                             ArrayRef<SGMapLattice *> operands,
                             ArrayRef<const SGMapLattice *> results);
 
-  void visitCreateNdDescOp(xegpu::CreateNdDescOp createNdDesc,
-                           ArrayRef<SGMapLattice *> operands,
-                           ArrayRef<const SGMapLattice *> results);
-
   void visitCreateDescOp(xegpu::CreateDescOp createDesc,
                          ArrayRef<SGMapLattice *> operands,
                          ArrayRef<const SGMapLattice *> results);
@@ -296,8 +293,6 @@ SGMapPropagation::visitOperation(Operation *op,
     visitVectorBitcastOp(bitcast, operands, results);
   else if (auto loadGather = dyn_cast<xegpu::LoadGatherOp>(op))
     visitLoadGatherOp(loadGather, operands, results);
-  else if (auto createNdDesc = dyn_cast<xegpu::CreateNdDescOp>(op))
-    visitCreateNdDescOp(createNdDesc, operands, results);
   else if (auto createDesc = dyn_cast<xegpu::CreateDescOp>(op))
     visitCreateDescOp(createDesc, operands, results);
   else if (auto storeScatter = dyn_cast<xegpu::StoreScatterOp>(op))
@@ -306,6 +301,10 @@ SGMapPropagation::visitOperation(Operation *op,
     visitUpdateNdOffsetOp(updateNdOffset, operands, results);
   else if (auto reduction = dyn_cast<vector::MultiDimReductionOp>(op))
     visitVectorMultiReductionOp(reduction, operands, results);
+  /// No need to propagate the layout to operands in CreateNdDescOp because they
+  /// are scalars (offsets, sizes, etc.).
+  else if (auto createNdDesc = dyn_cast<xegpu::CreateNdDescOp>(op))
+    return success();
   /// All other ops
   else {
     for (const SGMapLattice *r : results) {
@@ -355,11 +354,6 @@ void SGMapPropagation::visitUpdateNdOffsetOp(
     return;
   /// Propagate the layout to the source operand.
   propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
-  /// For all other operands use 1D default layout.
-  SGMap layout = getDefaultSgMap(1);
-  for (size_t i = 1; i < operands.size(); ++i) {
-    propagateIfChanged(operands[i], operands[i]->meet(layout));
-  }
 }
 
 /// Set the layouts for DPAS A, B, and C operands.
@@ -403,7 +397,8 @@ void SGMapPropagation::visitLoadNdOp(xegpu::LoadNdOp load,
   /// LoadNdOp has the transpose effect. However, at the stage of this analysis
   /// this effect is not expected and should be abstracted away. Emit a warning.
   if (auto transpose = load.getTranspose()) {
-    load.emitWarning("Transpose effect is not expected for LoadNdOp");
+    load.emitWarning("Transpose effect is not expected for LoadNdOp at "
+                     "SGMapPropagation stage.");
     tensorDescLayout = valueLayout.getTransposedLayout(transpose.value());
   }
   /// Propagate the new layout to the tensor descriptor operand.
@@ -476,7 +471,8 @@ void SGMapPropagation::visitLoadGatherOp(
     /// LoadGatherOp has the transpose effect. However, at the stage of this
     /// analysis this effect is not expected and should be abstracted away. Emit
     /// a warning.
-    load.emitWarning("Transpose effect is not expected for LoadGatherOp");
+    load.emitWarning("Transpose effect is not expected for LoadGatherOp at "
+                     "SGMapPropagation stage.");
     tensorDescLayout = valueLayout.getTransposedLayout({1, 0});
   } else
     tensorDescLayout = valueLayout;
@@ -488,24 +484,7 @@ void SGMapPropagation::visitLoadGatherOp(
   propagateIfChanged(operands[1], operands[1]->meet(maskLayout));
 }
 
-/// Propagate the layout of the descriptor to the operands in CreateNdDescOp.
-void SGMapPropagation::visitCreateNdDescOp(
-    xegpu::CreateNdDescOp createNdDesc, ArrayRef<SGMapLattice *> operands,
-    ArrayRef<const SGMapLattice *> results) {
-  auto descLayout = results[0]->getValue();
-  /// Need the layout of the descriptor to propagate to the operands.
-  if (!descLayout.isAssigned())
-    return;
-  /// Propagate the layout to the source operand.
-  propagateIfChanged(operands[0], operands[0]->meet(descLayout));
-  /// For all other operands use 1D default layout.
-  SGMap layout = getDefaultSgMap(1);
-  for (size_t i = 1; i < operands.size(); ++i) {
-    propagateIfChanged(operands[i], operands[i]->meet(layout));
-  }
-}
-
-/// Propagate the layout of the descriptor to the source and offset operands in
+/// Propagate the layout of the descriptor to the vector offset operand in
 /// CreateDescOp.
 void SGMapPropagation::visitCreateDescOp(
     xegpu::CreateDescOp createDesc, ArrayRef<SGMapLattice *> operands,
@@ -514,8 +493,6 @@ void SGMapPropagation::visitCreateDescOp(
   /// Need the layout of the descriptor to propagate to the operands.
   if (!descLayout.isAssigned())
     return;
-  /// Propagate the layout to the source operand.
-  propagateIfChanged(operands[0], operands[0]->meet(descLayout));
   /// For offset operand propagate 1D default layout.
   SGMap layout = getDefaultSgMap(1);
   propagateIfChanged(operands[1], operands[1]->meet(layout));
@@ -526,14 +503,23 @@ void SGMapPropagation::visitCreateDescOp(
 void SGMapPropagation::visitStoreScatterOp(
     xegpu::StoreScatterOp storeScatter, ArrayRef<SGMapLattice *> operands,
     ArrayRef<const SGMapLattice *> results) {
+  /// Currently, for 2D StoreScatterOp we expect that the height dimension of
+  /// the tensor descriptor is evenly divisible by the subgroup size.
+  /// TODO: Add support for other 2D shapes.
+  auto tdescShape = storeScatter.getTensorDescType().getShape();
+  if (tdescShape.size() > 1 && tdescShape[0] % subgroupSize != 0) {
+    storeScatter.emitError("Height dimension of the tensor descriptor should "
+                           "be evenly divisible by the subgroup size.");
+    return;
+  }
   auto valueLayout = getDefaultSgMap(storeScatter.getValueType());
   SGMap storeScatterLayout;
   if (storeScatter.getTranspose()) {
     /// StoreScatterOp allows transpose effect. However, at the stage of this
     /// analysis this effect is not expected and should be abstracted away. Emit
     /// a warning.
-    storeScatter.emitWarning(
-        "Transpose effect is not expected for StoreScatterOp");
+    storeScatter.emitWarning("Transpose effect is not expected for "
+                             "StoreScatterOp at SGMapPropagation stage.");
     storeScatterLayout = valueLayout.getTransposedLayout({1, 0});
   } else
     storeScatterLayout = valueLayout;
diff --git a/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir b/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
index 7d8bd20391463..1f16c0f738af9 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
@@ -1,8 +1,14 @@
 // RUN: mlir-opt -xegpu-subgroup-distribute='print-analysis-only=true' -split-input-file %s | FileCheck %s
 
-// CHECK: function: test_dpas_op_1:
-// CHECK: op    : %{{.*}} = arith.constant 0 : index
-// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
+// CHECK: function: test_dpas_f16:
+// CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf16>' at index: 0
+// CHECK-NEXT: sg_map  : Not assigned.
+// CHECK-NEXT: argument: <block argument> of type 'memref<16x16xf16>' at index: 1
+// CHECK-NEXT: sg_map  : Not assigned.
+// CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf32>' at index: 2
+// CHECK-NEXT: sg_map  : Not assigned.
+// CHECK-NEXT: op    : %{{.*}} = arith.constant 0 : index
+// CHECK-NEXT: sg_map for result #0: Not assigned.
 // CHECK-NEXT: op    : %{{.*}} = arith.constant dense<0.000000e+00> : vector<8x16xf32>
 // CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
 // CHECK-NEXT: op    : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
@@ -17,7 +23,7 @@
 // CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
 // CHECK-NEXT: op    : %{{.*}} = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
 // CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
-func.func @test_dpas_op_1(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<8x16xf32>) {
+func.func @test_dpas_f16(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<8x16xf32>) {
   %c0 = arith.constant 0 : index
   %cst = arith.constant dense<0.000000e+00> : vector<8x16xf32>
   %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
@@ -32,20 +38,20 @@ func.func @test_dpas_op_1(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %ar
 
 
 // -----
-// CHECK: function: test_dpas_op_2:
+// CHECK: function: test_dpas_i8:
 // CHECK-NEXT: argument: <block argument> of type 'vector<8x32xi8>' at index: 0
 // CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 2]
 // CHECK-NEXT: argument: <block argument> of type 'vector<32x16xi8>' at index: 1
 // CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [4, 1]
 // CHECK-NEXT: argument: <block argument> of type 'memref<8x16xi32>' at index: 2
-// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map  : Not assigned.
 // CHECK-NEXT: op    : %{{.*}} = arith.constant 0 : index
-// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
+// CHECK-NEXT: sg_map for result #0: Not assigned.
 // CHECK-NEXT: op    : %[[T0:.*]] = xegpu.dpas %{{.*}} : vector<8x32xi8>, vector<32x16xi8> -> vector<8x16xi32>
 // CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
 // CHECK-NEXT: op    : %[[T1:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xi32> -> !xegpu.tensor_desc<8x16xi32>
 // CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
-func.func @test_dpas_op_2(%arg0: vector<8x32xi8>, %arg1: vector<32x16xi8>, %arg2: memref<8x16xi32>) {
+func.func @test_dpas_i8(%arg0: vector<8x32xi8>, %arg1: vector<32x16xi8>, %arg2: memref<8x16xi32>) {
   %c0 = arith.constant 0 : index
   %0 = xegpu.dpas %arg0, %arg1 : vector<8x32xi8>, vector<32x16xi8> -> vector<8x16xi32>
   %1 = xegpu.create_nd_tdesc %arg2[%c0, %c0] : memref<8x16xi32> -> !xegpu.tensor_desc<8x16xi32>
@@ -56,13 +62,13 @@ func.func @test_dpas_op_2(%arg0: vector<8x32xi8>, %arg1: vector<32x16xi8>, %arg2
 // -----
 // CHECK: function: test_transpose_op_1:
 // CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf16>' at index: 0
-// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map  : Not assigned.
 // CHECK-NEXT: argument: <block argument> of type 'memref<16x16xf16>' at index: 1
-// CHECK-NEXT: sg_map  : wi_layout: [16, 1], wi_data: [1, 2]
+// CHECK-NEXT: sg_map  : Not assigned.
 // CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf32>' at index: 2
-// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map  : Not assigned.
 // CHECK-NEXT: op    : %{{.*}} = arith.constant 0 : index
-// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
+// CHECK-NEXT: sg_map for result #0: Not assigned.
 // CHECK-NEXT: op    : %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<8x16xf32>
 // CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
 // CHECK-NEXT: op    : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
@@ -92,13 +98,13 @@ func.func @test_transpose_op_1(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>
 
 // -----
 // CHECK: argument: <block argument> of type 'memref<8x16xf16>' at index: 0
-// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map  : Not assigned.
 // CHECK-NEXT: argument: <block argument> of type 'memref<16x16xf16>' at index: 1
-// CHECK-NEXT: sg_map  : wi_layout: [16, 1], wi_data: [1, 2]
+// CHECK-NEXT: sg_map  : Not assigned.
 // CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf32>' at index: 2
-// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map  : Not assigned.
 // CHECK-NEXT: op    : %{{.*}} = arith.constant 0 : index
-// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
+// CHECK-NEXT: sg_map for result #0: Not assigned.
 // CHECK-NEXT: op    : %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<8x16xf32>
 // CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
 // CHECK-NEXT: op    : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
@@ -156,13 +162,13 @@ func.func @test_extf_truncf_op(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegp
 // -----
 // CHECK: function: test_load_gather_op_1:
 // CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf16>' at index: 0
-// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map  : Not assigned.
 // CHECK-NEXT: argument: <block argument> of type 'memref<256xf16>' at index: 1
-// CHECK-NEXT: sg_map  : wi_layout: [16, 1], wi_data: [1, 2]
+// CHECK-NEXT: sg_map  : Not assigned.
 // CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf32>' at index: 2
-// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map  : Not assigned.
 // CHECK-NEXT: op    : %{{.*}} = arith.constant 0 : index
-// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
+// CHECK-NEXT: sg_map for result #0: Not assigned.
 // CHECK-NEXT: op    : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
 // CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
 // CHECK-NEXT: op    : %[[T1:.*]] = xegpu.load_nd %[[T0]]  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
@@ -195,7 +201,7 @@ func.func @test_load_gather_op_1(%arg0: memref<8x16xf16>, %arg1: memref<256xf16>
 
 // -----
 // CHECK: argument: <block argument> of type 'memref<256xf32>' at index: 0
-// CHECK-NEXT: sg_map  : wi_layout: [16], wi_data: [1]
+// CHECK-NEXT: sg_map  : Not assigned.
 // CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<16xf32>' at index: 1
 // CHECK-NEXT: sg_map  : wi_layout: [16], wi_data: [1]
 // CHECK-NEXT: op    : %[[CST:.*]] = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
@@ -217,7 +223,7 @@ func.func @test_load_gather_op_2(%arg0: memref<256xf32>, %arg1: !xegpu.tensor_de
 
 // -----
 // CHECK: argument: <block argument> of type 'memref<128xf32>' at index: 0
-// CHECK-NEXT: sg_map  : wi_layout: [16, 1], wi_data: [1, 1]
+// CHECK-NEXT: sg_map  : Not assigned.
 // CHECK-NEXT: op    : %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<8x16xf32>
 // CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
 // CHECK-NEXT: op    : %[[CST0:.*]] = arith.constant dense<true> : vector<16xi1>
@@ -239,7 +245,7 @@ func.func @test_store_scatter_op_1(%arg0: memref<128xf32>) {
 // CHECK: argument: <block argument> of type 'vector<16xf32>' at index: 0
 // CHECK-NEXT: sg_map  : wi_layout: [16], wi_data: [1]
 // CHECK-NEXT: argument: <block argument> of type 'memref<256xf32>' at index: 1
-// CHECK-NEXT: sg_map  : wi_layout: [16], wi_data: [1]
+// CHECK-NEXT: sg_map  : Not assigned.
 // CHECK-NEXT: op    : %[[CST:.*]] = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
 // CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
 // CHECK-NEXT: op    : %[[CST1:.*]] = arith.constant dense<true> : vector<16xi1>
@@ -256,13 +262,13 @@ func.func @test_store_scatter_op_2(%arg0: vector<16xf32>, %arg1: memref<256xf32>
 
 // -----
 // CHECK: argument: <block argument> of type 'memref<8x16xi16>' at index: 0
-// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map  : Not assigned.
 // CHECK-NEXT: argument: <block argument> of type 'memref<32x16xi8>' at index: 1
-// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [4, 1]
+// CHECK-NEXT: sg_map  : Not assigned.
 // CHECK-NEXT: argument: <block argument> of type 'memref<8x16xi32>' at index: 2
-// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map  : Not assigned.
 // CHECK-NEXT: op    : %{{.*}} = arith.constant 0 : index
-// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
+// CHECK-NEXT: sg_map for result #0: Not assigned.
 // CHECK-NEXT: op    : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xi16> -> !xegpu.tensor_desc<8x16xi16>
 // CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
 // CHECK-NEXT: op    : %[[T1:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<32x16xi8> -> !xegpu.tensor_desc<32x16xi8>
@@ -292,13 +298,13 @@ func.func @test_vector_bitcast_op_1(%arg0: memref<8x16xi16>, %arg1: memref<32x16
 
 // -----
 // CHECK: argument: <block argument> of type 'memref<8x32xi8>' at index: 0
-// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 2]
+// CHECK-NEXT: sg_map  : Not assigned.
 // CHECK-NEXT: argument: <block argument> of type 'memref<16x32xi8>' at index: 1
-// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [4, 1]
+// CHECK-NEXT: sg_map  : Not assigned.
 // CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf32>' at index: 2
-// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map  : Not assigned.
 // CHECK-NEXT: op    : %{{.*}} = arith.constant 0 : index
-// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
+// CHECK-NEXT: sg_map for result #0: Not assigned.
 // CHECK-NEXT: op    : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x32xi8> -> !xegpu.tensor_desc<8x32xi8>
 // CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 2]
 // CHECK-NEXT: op    : %[[T1:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<16x32xi8> -> !xegpu.tensor_desc<16x32xi8>
@@ -388,17 +394,17 @@ func.func @test_binary_op_2(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.t
 
 // -----
 // CHECK: argument: <block argument> of type 'memref<8x128xf16>' at index: 0
-// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map  : Not assigned.
 // CHECK-NEXT: argument: <block argument> of type 'memref<128x16xf16>' at index: 1
-// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [2, 1]
+// CHECK-NEXT: sg_map  : Not assigned.
 // CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf32>' at index: 2
-// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map  : Not assigned.
 // CHECK-NEXT: op    : %{{.*}} = arith.constant 0 : index
-// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
+// CHECK-NEXT: sg_map for result #0: Not assigned.
 // CHECK-NEXT: op    : %{{.*}} = arith.constant 128 : index
 // CHECK-NEXT: sg_map for result #0: Not assigned.
 // CHECK-NEXT: op    : %{{.*}} = arith.constant 16 : index
-// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
+// CHECK-NEXT: sg_map for result #0: Not assigned.
 // CHECK-NEXT: op    : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x128xf16> -> !xegpu.tensor_desc<8x16xf16>
 // CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
 // CHECK-NEXT: op    : %[[T1:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<128x16xf16> -> !xegpu.tensor_desc<16x16xf16>

>From f7b7bdfc5fb2fbf64eb4db05c0918e912ce30825 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Tue, 11 Mar 2025 19:08:55 +0000
Subject: [PATCH 17/17] save work

---
 .../XeGPU/subgroup-map-propagation.mlir       | 104 +++++++-----------
 1 file changed, 41 insertions(+), 63 deletions(-)

diff --git a/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir b/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
index 1f16c0f738af9..1ae4348af33e6 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
@@ -60,7 +60,7 @@ func.func @test_dpas_i8(%arg0: vector<8x32xi8>, %arg1: vector<32x16xi8>, %arg2:
 }
 
 // -----
-// CHECK: function: test_transpose_op_1:
+// CHECK: function: test_load_with_transpose_effect:
 // CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf16>' at index: 0
 // CHECK-NEXT: sg_map  : Not assigned.
 // CHECK-NEXT: argument: <block argument> of type 'memref<16x16xf16>' at index: 1
@@ -83,7 +83,7 @@ func.func @test_dpas_i8(%arg0: vector<8x32xi8>, %arg1: vector<32x16xi8>, %arg2:
 // CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
 // CHECK-NEXT: op    : %[[T5:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
 // CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
-func.func @test_transpose_op_1(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<8x16xf32>) {
+func.func @test_load_with_transpose_effect(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<8x16xf32>) {
   %c0 = arith.constant 0 : index
   %cst = arith.constant dense<0.000000e+00> : vector<8x16xf32>
   %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
@@ -97,7 +97,8 @@ func.func @test_transpose_op_1(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>
 }
 
 // -----
-// CHECK: argument: <block argument> of type 'memref<8x16xf16>' at index: 0
+// CHECK: function: test_vector_transpose:
+// CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf16>' at index: 0
 // CHECK-NEXT: sg_map  : Not assigned.
 // CHECK-NEXT: argument: <block argument> of type 'memref<16x16xf16>' at index: 1
 // CHECK-NEXT: sg_map  : Not assigned.
@@ -121,7 +122,7 @@ func.func @test_transpose_op_1(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>
 // CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
 // CHECK-NEXT: op    : %[[T6:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
 // CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
-func.func @test_transpose_op_2(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<8x16xf32>) {
+func.func @test_vector_transpose(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<8x16xf32>) {
   %c0 = arith.constant 0 : index
   %cst = arith.constant dense<0.000000e+00> : vector<8x16xf32>
   %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
@@ -136,7 +137,8 @@ func.func @test_transpose_op_2(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>
 }
 
 // -----
-// CHECK: argument: <block argument> of type '!xegpu.tensor_desc<8x16xf16>' at index: 0
+// CHECK: function: test_extf_truncf:
+// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<8x16xf16>' at index: 0
 // CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
 // CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<16x16xf16>' at index: 1
 // CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [2, 1]
@@ -150,7 +152,7 @@ func.func @test_transpose_op_2(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>
 // CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1]
 // CHECK-NEXT: op    : %[[T4:.*]] = xegpu.dpas %[[T0]], %[[T3]] : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
 // CHECK-NEXT: sg_map for result #0: Not assigned.
-func.func @test_extf_truncf_op(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>) -> vector<8x16xf32> {
+func.func @test_extf_truncf(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>) -> vector<8x16xf32> {
   %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
   %1 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
   %2 = arith.extf %1 : vector<16x16xf16> to vector<16x16xf32>
@@ -160,7 +162,7 @@ func.func @test_extf_truncf_op(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegp
 }
 
 // -----
-// CHECK: function: test_load_gather_op_1:
+// CHECK: function: test_load_gather_with_transpose_effect:
 // CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf16>' at index: 0
 // CHECK-NEXT: sg_map  : Not assigned.
 // CHECK-NEXT: argument: <block argument> of type 'memref<256xf16>' at index: 1
@@ -185,7 +187,7 @@ func.func @test_extf_truncf_op(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegp
 // CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
 // CHECK-NEXT: op    : %[[T5:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
 // CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
-func.func @test_load_gather_op_1(%arg0: memref<8x16xf16>, %arg1: memref<256xf16>, %arg2: memref<8x16xf32>) {
+func.func @test_load_gather_with_transpose_effect(%arg0: memref<8x16xf16>, %arg1: memref<256xf16>, %arg2: memref<8x16xf32>) {
   %c0 = arith.constant 0 : index
   %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
   %1 = xegpu.load_nd %0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
@@ -200,6 +202,7 @@ func.func @test_load_gather_op_1(%arg0: memref<8x16xf16>, %arg1: memref<256xf16>
 }
 
 // -----
+// CHECK: function: test_load_gather_1d:
 // CHECK: argument: <block argument> of type 'memref<256xf32>' at index: 0
 // CHECK-NEXT: sg_map  : Not assigned.
 // CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<16xf32>' at index: 1
@@ -212,7 +215,7 @@ func.func @test_load_gather_op_1(%arg0: memref<8x16xf16>, %arg1: memref<256xf16>
 // CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
 // CHECK-NEXT: op    : %[[T1]] = xegpu.load %[[T0]], %[[CST0]]  : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> -> vector<16xf32>
 // CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
-func.func @test_load_gather_op_2(%arg0: memref<256xf32>, %arg1: !xegpu.tensor_desc<16xf32>) {
+func.func @test_load_gather_1d(%arg0: memref<256xf32>, %arg1: !xegpu.tensor_desc<16xf32>) {
   %cst = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
   %cst_0 = arith.constant dense<true> : vector<16xi1>
   %0 = xegpu.create_tdesc %arg0, %cst : memref<256xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
@@ -222,7 +225,8 @@ func.func @test_load_gather_op_2(%arg0: memref<256xf32>, %arg1: !xegpu.tensor_de
 }
 
 // -----
-// CHECK: argument: <block argument> of type 'memref<128xf32>' at index: 0
+// CHECK: function: test_store_scatter_with_transpose_effect:
+// CHECK-NEXT: argument: <block argument> of type 'memref<128xf32>' at index: 0
 // CHECK-NEXT: sg_map  : Not assigned.
 // CHECK-NEXT: op    : %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<8x16xf32>
 // CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
@@ -232,7 +236,7 @@ func.func @test_load_gather_op_2(%arg0: memref<256xf32>, %arg1: !xegpu.tensor_de
 // CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
 // CHECK-NEXT: op    : %[[T0:.*]] = xegpu.create_tdesc %{{.*}}, %[[CST1]] : memref<128xf32>, vector<16xindex> -> !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size = 8 : i64>>
 // CHECK-NEXT: sg_map for result #0: wi_layout: [16, 1], wi_data: [1, 1]
-func.func @test_store_scatter_op_1(%arg0: memref<128xf32>) {
+func.func @test_store_scatter_with_transpose_effect(%arg0: memref<128xf32>) {
   %cst = arith.constant dense<1.000000e+00> : vector<8x16xf32>
   %cst_0 = arith.constant dense<true> : vector<16xi1>
   %cst_1 = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
@@ -242,7 +246,8 @@ func.func @test_store_scatter_op_1(%arg0: memref<128xf32>) {
 }
 
 // -----
-// CHECK: argument: <block argument> of type 'vector<16xf32>' at index: 0
+// CHECK: function: test_store_scatter_1d:
+// CHECK-NEXT: argument: <block argument> of type 'vector<16xf32>' at index: 0
 // CHECK-NEXT: sg_map  : wi_layout: [16], wi_data: [1]
 // CHECK-NEXT: argument: <block argument> of type 'memref<256xf32>' at index: 1
 // CHECK-NEXT: sg_map  : Not assigned.
@@ -252,7 +257,7 @@ func.func @test_store_scatter_op_1(%arg0: memref<128xf32>) {
 // CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
 // CHECK-NEXT: op    : %[[T0:.*]] = xegpu.create_tdesc %{{.*}}, %[[CST]] : memref<256xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
 // CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
-func.func @test_store_scatter_op_2(%arg0: vector<16xf32>, %arg1: memref<256xf32>) {
+func.func @test_store_scatter_1d(%arg0: vector<16xf32>, %arg1: memref<256xf32>) {
   %cst = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
   %cst_0 = arith.constant dense<true> : vector<16xi1>
   %0 = xegpu.create_tdesc %arg1, %cst : memref<256xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
@@ -261,7 +266,8 @@ func.func @test_store_scatter_op_2(%arg0: vector<16xf32>, %arg1: memref<256xf32>
 }
 
 // -----
-// CHECK: argument: <block argument> of type 'memref<8x16xi16>' at index: 0
+// CHECK: function: test_vector_bitcast_i16_to_i8:
+// CHECK-NEXT: argument: <block argument> of type 'memref<8x16xi16>' at index: 0
 // CHECK-NEXT: sg_map  : Not assigned.
 // CHECK-NEXT: argument: <block argument> of type 'memref<32x16xi8>' at index: 1
 // CHECK-NEXT: sg_map  : Not assigned.
@@ -283,7 +289,7 @@ func.func @test_store_scatter_op_2(%arg0: vector<16xf32>, %arg1: memref<256xf32>
 // CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
 // CHECK-NEXT: op    : %[[T6:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xi32> -> !xegpu.tensor_desc<8x16xi32>
 // CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
-func.func @test_vector_bitcast_op_1(%arg0: memref<8x16xi16>, %arg1: memref<32x16xi8>, %arg2: memref<8x16xi32>) {
+func.func @test_vector_bitcast_i16_to_i8(%arg0: memref<8x16xi16>, %arg1: memref<32x16xi8>, %arg2: memref<8x16xi32>) {
   %c0 = arith.constant 0 : index
   %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xi16> -> !xegpu.tensor_desc<8x16xi16>
   %1 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<32x16xi8> -> !xegpu.tensor_desc<32x16xi8>
@@ -297,7 +303,8 @@ func.func @test_vector_bitcast_op_1(%arg0: memref<8x16xi16>, %arg1: memref<32x16
 }
 
 // -----
-// CHECK: argument: <block argument> of type 'memref<8x32xi8>' at index: 0
+// CHECK: function: test_vector_bitcast_i8_to_f16:
+// CHECK-NEXT: argument: <block argument> of type 'memref<8x32xi8>' at index: 0
 // CHECK-NEXT: sg_map  : Not assigned.
 // CHECK-NEXT: argument: <block argument> of type 'memref<16x32xi8>' at index: 1
 // CHECK-NEXT: sg_map  : Not assigned.
@@ -321,7 +328,7 @@ func.func @test_vector_bitcast_op_1(%arg0: memref<8x16xi16>, %arg1: memref<32x16
 // CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
 // CHECK-NEXT: op    : %[[T7:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
 // CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
-func.func @test_vector_bitcast_op_2(%arg0: memref<8x32xi8>, %arg1: memref<16x32xi8>, %arg2: memref<8x16xf32>) {
+func.func @test_vector_bitcast_i8_to_f16(%arg0: memref<8x32xi8>, %arg1: memref<16x32xi8>, %arg2: memref<8x16xf32>) {
   %c0 = arith.constant 0 : index
   %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x32xi8> -> !xegpu.tensor_desc<8x32xi8>
   %1 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<16x32xi8> -> !xegpu.tensor_desc<16x32xi8>
@@ -336,7 +343,8 @@ func.func @test_vector_bitcast_op_2(%arg0: memref<8x32xi8>, %arg1: memref<16x32x
 }
 
 // -----
-// CHECK: argument: <block argument> of type '!xegpu.tensor_desc<8x16xf16>' at index: 0
+// CHECK: function: test_binary_op_one_use:
+// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<8x16xf16>' at index: 0
 // CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
 // CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<16x16xf16>' at index: 1
 // CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [2, 1]
@@ -352,7 +360,7 @@ func.func @test_vector_bitcast_op_2(%arg0: memref<8x32xi8>, %arg1: memref<16x32x
 // CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1]
 // CHECK-NEXT: op    : %[[T4:.*]] = xegpu.dpas %[[T0]], %[[T3]] : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
 // CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
-func.func @test_binary_op_1(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: !xegpu.tensor_desc<8x16xf32>) {
+func.func @test_binary_op_one_use(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: !xegpu.tensor_desc<8x16xf32>) {
   %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
   %1 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
   %2 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
@@ -363,7 +371,8 @@ func.func @test_binary_op_1(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.t
 }
 
 // -----
-// CHECK: argument: <block argument> of type '!xegpu.tensor_desc<8x16xf16>' at index: 0
+// CHECK: function: test_binary_op_multiple_uses:
+// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<8x16xf16>' at index: 0
 // CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
 // CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<16x16xf16>' at index: 1
 // CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
@@ -381,7 +390,7 @@ func.func @test_binary_op_1(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.t
 // CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
 // CHECK-NEXT: op    : %[[T3:.*]] = xegpu.dpas %[[T0]], %[[T2]] : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
 // CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
-func.func @test_binary_op_2(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: !xegpu.tensor_desc<8x16xf32>, %arg3: !xegpu.tensor_desc<16x16xf16>) {
+func.func @test_binary_op_multiple_uses(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: !xegpu.tensor_desc<8x16xf32>, %arg3: !xegpu.tensor_desc<16x16xf16>) {
   %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
   %1 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
   %cst = arith.constant dense<1.000000e+00> : vector<16x16xf16>
@@ -393,7 +402,8 @@ func.func @test_binary_op_2(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.t
 }
 
 // -----
-// CHECK: argument: <block argument> of type 'memref<8x128xf16>' at index: 0
+// CHECK: function: test_for_op:
+// CHECK-NEXT: argument: <block argument> of type 'memref<8x128xf16>' at index: 0
 // CHECK-NEXT: sg_map  : Not assigned.
 // CHECK-NEXT: argument: <block argument> of type 'memref<128x16xf16>' at index: 1
 // CHECK-NEXT: sg_map  : Not assigned.
@@ -448,7 +458,7 @@ func.func @test_for_op(%arg0: memref<8x128xf16>, %arg1: memref<128x16xf16>, %arg
 }
 
 // -----
-// CHECK: function: test_if_op_1:
+// CHECK: function: test_if_single_use:
 // CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<8x16xf16>' at index: 0
 // CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
 // CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<16x16xf16>' at index: 1
@@ -467,7 +477,7 @@ func.func @test_for_op(%arg0: memref<8x128xf16>, %arg1: memref<128x16xf16>, %arg
 // CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1]
 // CHECK-NEXT: op    : %[[T2:.*]] = xegpu.dpas %[[T0]], %{{.*}} : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
 // CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
-func.func @test_if_op_1(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: i1, %arg3: !xegpu.tensor_desc<8x16xf32>) {
+func.func @test_if_single_use(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: i1, %arg3: !xegpu.tensor_desc<8x16xf32>) {
   %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
   %1 = scf.if %arg2 -> (vector<16x16xf16>) {
     %3 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
@@ -482,7 +492,7 @@ func.func @test_if_op_1(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tenso
 }
 
 // -----
-// CHECK: function: test_if_op_2:
+// CHECK: function: test_if_multiple_uses:
 // CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<8x16xf16>' at index: 0
 // CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
 // CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<16x16xf16>' at index: 1
@@ -503,7 +513,7 @@ func.func @test_if_op_1(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tenso
 // CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
 // CHECK-NEXT: op    : %[[T2:.*]] = xegpu.dpas %[[T0]], %{{.*}} : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
 // CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
-func.func @test_if_op_2(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: i1, %arg3: !xegpu.tensor_desc<8x16xf32>, %arg4: !xegpu.tensor_desc<16x16xf16>) {
+func.func @test_if_multiple_uses(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: i1, %arg3: !xegpu.tensor_desc<8x16xf32>, %arg4: !xegpu.tensor_desc<16x16xf16>) {
   %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
   %1 = scf.if %arg2 -> (vector<16x16xf16>) {
     %3 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
@@ -519,40 +529,7 @@ func.func @test_if_op_2(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tenso
 }
 
 // -----
-// CHECK: function: test_if_op_3:
-// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<8x16xf16>' at index: 0
-// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
-// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<16x16xf16>' at index: 1
-// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [2, 1]
-// CHECK-NEXT: argument: <block argument> of type 'vector<16x16xf16>' at index: 2
-// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [2, 1]
-// CHECK-NEXT: argument: <block argument> of type 'i1' at index: 3
-// CHECK-NEXT: sg_map  : Not assigned.
-// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<8x16xf32>' at index: 4
-// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
-// CHECK-NEXT: op    : %[[T0:.*]] = xegpu.load_nd %{{.*}}  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
-// CHECK-NEXT: op    : %[[T3:.*]] = xegpu.load_nd %{{.*}}  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1]
-// CHECK-NEXT: op    : scf.if
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1]
-// CHECK-NEXT: op    : %[[T2:.*]] = xegpu.dpas %[[T0]], %{{.*}} : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
-func.func @test_if_op_3(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: vector<16x16xf16>, %arg3: i1, %arg4: !xegpu.tensor_desc<8x16xf32>) {
-  %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
-  %1 = scf.if %arg3 -> (vector<16x16xf16>) {
-    %3 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
-    scf.yield %3 : vector<16x16xf16>
-  } else {
-    scf.yield %arg2 : vector<16x16xf16>
-  }
-  %2 = xegpu.dpas %0, %1 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
-  xegpu.store_nd %2, %arg4  : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
-  return
-}
-
-// -----
-// CHECK: function: test_vector_reduction_1:
+// CHECK: function: test_vector_outer_reduction:
 // CHECK-NEXT: argument: <block argument> of type 'vector<16x16xf32>' at index: 0
 // CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
 // CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<16xf32>' at index: 1
@@ -561,7 +538,7 @@ func.func @test_if_op_3(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tenso
 // CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
 // CHECK-NEXT: op    : %[[T0:.*]] = vector.multi_reduction <add>, %{{.*}}, %[[CST]] [0] : vector<16x16xf32> to vector<16xf32>
 // CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
-func.func @test_vector_reduction_1(%arg0: vector<16x16xf32>, %arg1: !xegpu.tensor_desc<16xf32>) {
+func.func @test_vector_outer_reduction(%arg0: vector<16x16xf32>, %arg1: !xegpu.tensor_desc<16xf32>) {
   %cst = arith.constant dense<0.000000e+00> : vector<16xf32>
   %0 = vector.multi_reduction <add>, %arg0, %cst [0] : vector<16x16xf32> to vector<16xf32>
   xegpu.store_nd %0, %arg1  : vector<16xf32>, !xegpu.tensor_desc<16xf32>
@@ -569,7 +546,8 @@ func.func @test_vector_reduction_1(%arg0: vector<16x16xf32>, %arg1: !xegpu.tenso
 }
 
 // -----
-// CHECK: argument: <block argument> of type 'vector<16x16xf32>' at index: 0
+// CHECK: function: test_vector_inner_reduction:
+// CHECK-NEXT: argument: <block argument> of type 'vector<16x16xf32>' at index: 0
 // CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
 // CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<16xf32>' at index: 1
 // CHECK-NEXT: sg_map  : wi_layout: [16], wi_data: [1]
@@ -577,7 +555,7 @@ func.func @test_vector_reduction_1(%arg0: vector<16x16xf32>, %arg1: !xegpu.tenso
 // CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
 // CHECK-NEXT: op    : %[[T0:.*]] = vector.multi_reduction <add>, %{{.*}}, %[[CST]] [1] : vector<16x16xf32> to vector<16xf32>
 // CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
-func.func @test_vector_reduction_2(%arg0: vector<16x16xf32>, %arg1: !xegpu.tensor_desc<16xf32>) {
+func.func @test_vector_inner_reduction(%arg0: vector<16x16xf32>, %arg1: !xegpu.tensor_desc<16xf32>) {
   %cst = arith.constant dense<0.000000e+00> : vector<16xf32>
   %0 = vector.multi_reduction <add>, %arg0, %cst [1] : vector<16x16xf32> to vector<16xf32>
   xegpu.store_nd %0, %arg1  : vector<16xf32>, !xegpu.tensor_desc<16xf32>



More information about the Mlir-commits mailing list