[Mlir-commits] [mlir] adc6228 - [mlir][xegpu] Refine layout assignment in XeGPU SIMT distribution. (#142687)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jun 20 10:43:23 PDT 2025


Author: Charitha Saumya
Date: 2025-06-20T10:43:19-07:00
New Revision: adc6228ea07eba401481e218c3e0536a4aa6b8ec

URL: https://github.com/llvm/llvm-project/commit/adc6228ea07eba401481e218c3e0536a4aa6b8ec
DIFF: https://github.com/llvm/llvm-project/commit/adc6228ea07eba401481e218c3e0536a4aa6b8ec.diff

LOG: [mlir][xegpu] Refine layout assignment in XeGPU SIMT distribution. (#142687)

Changes:
* Decouple layout propagation from subgroup distribution and move it to
an independent pass.
* Refine layout assignment to handle control-flow ops correctly (scf.for, scf.while).
* Refine test cases.

Added: 
    mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
    mlir/test/Dialect/XeGPU/propagate-layout.mlir
    mlir/test/Dialect/XeGPU/subgroup-distribute.mlir

Modified: 
    mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
    mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
    mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
    mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp

Removed: 
    mlir/test/Dialect/XeGPU/subgroup-distribution.mlir
    mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir


################################################################################
diff  --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
index 8bdf19ac0e47d..3a88dae041dd1 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
@@ -27,10 +27,23 @@ def XeGPUSubgroupDistribute : Pass<"xegpu-subgroup-distribute"> {
   }];
   let dependentDialects = ["memref::MemRefDialect", "xegpu::XeGPUDialect",
                            "vector::VectorDialect"];
+}
+
+def XeGPUPropagateLayout : Pass<"xegpu-propagate-layout"> {
+  let summary = "Propagate and assign XeGPU layout information";
+  let description = [{
+    This pass propagates the XeGPU layout information across ops. Starting
+    from a set of anchor operations (e.g. `dpas`, `store_nd`), it propagates
+    the layouts required for their operands to the producers. Using this
+    propagated layout information, the pass then updates the op result types
+    with the layout information.
+  }];
+  let dependentDialects = ["memref::MemRefDialect", "xegpu::XeGPUDialect",
+                           "vector::VectorDialect"];
   let options = [Option<
-      "printOnly", "print-analysis-only", "bool",
-      /*default=*/"false",
-      "Print the result of the subgroup map propagation analysis and exit.">];
+    "printOnly", "print-analysis-only", "bool",
+    /*default=*/"false",
+    "Print the result of layout propagation analysis and exit.">];
 }
 
 def XeGPUWgToSgDistribute : Pass<"xegpu-wg-to-sg-distribute"> {

diff  --git a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
index 6fea10185402a..772cf73649646 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
@@ -24,6 +24,20 @@ class LayoutAttr;
 class TensorDescType;
 } // namespace xegpu
 
+namespace xegpu {
+/// HW dependent constants.
+/// TODO: These constants should be queried from the target information.
+/// Declared `inline constexpr` (C++17) so every translation unit including
+/// this header refers to a single definition.
+namespace targetinfo {
+/// How many lanes (work items) are in a subgroup.
+inline constexpr unsigned subgroupSize = 16;
+/// If DPAS A or B operands have low precision element types they must be packed
+/// according to the following sizes.
+/// Minimum packing size (in bits) per register for DPAS A; also used to derive
+/// the packing factor of default (non-DPAS) 2D layouts.
+inline constexpr unsigned packedSizeInBitsForDefault = 16;
+/// Minimum packing size (in bits) per register for DPAS B.
+inline constexpr unsigned packedSizeInBitsForDpasB = 32;
+} // namespace targetinfo
+} // namespace xegpu
+
 namespace xegpu {
 
 /// Flatten a set of ValueRange into a single SmallVector<Value>

diff  --git a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
index af0d7f6bd9070..9c178d1d85642 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
@@ -4,6 +4,7 @@ add_mlir_dialect_library(MLIRXeGPUTransforms
   XeGPUSubgroupDistribute.cpp
   XeGPUUnroll.cpp
   XeGPUWgToSgDistribute.cpp
+  XeGPUPropagateLayout.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/XeGPU

diff  --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
new file mode 100644
index 0000000000000..cc22d2bbd8c39
--- /dev/null
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -0,0 +1,889 @@
+//===- XeGPUPropagateLayout.cpp - XeGPU Layout Propagation ------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
+#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
+#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
+#include "mlir/Analysis/DataFlow/Utils.h"
+#include "mlir/Analysis/DataFlowFramework.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
+#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/Value.h"
+#include "mlir/IR/Visitors.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/InterleavedRange.h"
+#include "llvm/Support/LogicalResult.h"
+#include "llvm/Support/raw_ostream.h"
+
+namespace mlir {
+namespace xegpu {
+#define GEN_PASS_DEF_XEGPUPROPAGATELAYOUT
+#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
+} // namespace xegpu
+} // namespace mlir
+
+#define DEBUG_TYPE "xegpu-propagate-layout"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+
+using namespace mlir;
+using namespace mlir::dataflow;
+
+namespace {
+
+//===----------------------------------------------------------------------===//
+// Layout
+//===----------------------------------------------------------------------===//
+
+/// Helper class to store the ND layout of lanes within a subgroup and data
+/// owned by each lane.
+struct Layout {
+  // Per-dimension sizes; up to 3 dimensions are stored inline.
+  SmallVector<int64_t, 3> layout;
+  Layout() = default;
+  // Builds a layout from a brace-enclosed list of dimension sizes.
+  Layout(std::initializer_list<int64_t> list) : layout(list) {}
+  // Prints the dimension sizes to `os`.
+  void print(llvm::raw_ostream &os) const;
+  // Number of dimensions; 0 means the layout is empty.
+  size_t size() const { return layout.size(); }
+};
+
+/// Streams the per-dimension sizes to `out` using llvm::interleaved_array.
+void Layout::print(llvm::raw_ostream &out) const {
+  out << llvm::interleaved_array(this->layout);
+}
+
+/// LaneLayout represents the logical layout of lanes within a subgroup when it
+/// accesses some value. LaneData represents the logical layout of data owned by
+/// each work item.
+/// Both are plain aliases of Layout; the separate names only document intent.
+using LaneLayout = Layout;
+using LaneData = Layout;
+
+//===----------------------------------------------------------------------===//
+// LayoutInfo
+//===----------------------------------------------------------------------===//
+
+/// Helper class for tracking the analysis state of an mlir value. For layout
+/// propagation, the analysis state is simply the lane_layout and lane_data of
+/// each value. The purpose of this analysis is to propagate some unique layout
+/// for each value in the program starting from a set of anchor operations
+/// (like DPAS, StoreNd, etc.).
+///
+/// Given this, LayoutInfo satisfies the following properties:
+///  1) A LayoutInfo value can be in one of two states - `assigned` (both
+///  lane_layout and lane_data are non-empty) or `not assigned`.
+///  2) Two LayoutInfo values are equal if they are both assigned or
+///  both not assigned. The concrete value of assigned state does not matter.
+///  3) The meet operator works as follows:
+///     - If current state is assigned, return the current state. (already
+///     a unique layout is assigned. don't change it)
+///     - Otherwise, return the other state.
+struct LayoutInfo {
+private:
+  LaneLayout laneLayout;
+  LaneData laneData;
+
+public:
+  LayoutInfo() = default;
+  LayoutInfo(const LaneLayout &layout, const LaneData &data)
+      : laneLayout(layout), laneData(data) {}
+
+  // Two lattice values are equal if they have `some` layout. The actual
+  // content of the layout does not matter.
+  bool operator==(const LayoutInfo &other) const {
+    return this->isAssigned() == other.isAssigned();
+  }
+
+  /// Returns `lhs` if it is assigned, otherwise `rhs` (first layout wins).
+  static LayoutInfo meet(const LayoutInfo &lhs, const LayoutInfo &rhs);
+
+  /// Not used by this backward analysis; must never be called.
+  static LayoutInfo join(const LayoutInfo &lhs, const LayoutInfo &rhs);
+
+  void print(raw_ostream &os) const;
+
+  /// True when both lane_layout and lane_data are non-empty.
+  bool isAssigned() const {
+    return laneLayout.size() > 0 && laneData.size() > 0;
+  }
+
+  /// Returns this layout with its dimensions reordered by `permutation`.
+  LayoutInfo getTransposedLayout(ArrayRef<int64_t> permutation) const;
+
+  const LaneLayout &getLayout() const { return laneLayout; }
+  const LaneData &getData() const { return laneData; }
+  ArrayRef<int64_t> getLayoutAsArrayRef() const { return laneLayout.layout; }
+  ArrayRef<int64_t> getDataAsArrayRef() const { return laneData.layout; }
+};
+
+/// Prints "lane_layout: <dims>, lane_data: <dims>" for an assigned layout and
+/// "Not assigned." otherwise.
+void LayoutInfo::print(raw_ostream &os) const {
+  if (!isAssigned()) {
+    os << "Not assigned.";
+    return;
+  }
+  os << "lane_layout: ";
+  laneLayout.print(os);
+  os << ", lane_data: ";
+  laneData.print(os);
+}
+
+/// Meet keeps the first assigned layout: `lhs` wins whenever it already has a
+/// layout, otherwise `rhs` is returned unchanged.
+LayoutInfo LayoutInfo::meet(const LayoutInfo &lhs, const LayoutInfo &rhs) {
+  return lhs.isAssigned() ? lhs : rhs;
+}
+
+/// Since this is a backward analysis, join method is not used; reaching it
+/// indicates misuse of this lattice.
+LayoutInfo LayoutInfo::join(const LayoutInfo & /*lhs*/,
+                            const LayoutInfo & /*rhs*/) {
+  llvm_unreachable("Join should not be triggered by layout propagation.");
+}
+
+/// Get the transposed layout according to the given permutation: dimension
+/// `i` of the result takes the lane_layout/lane_data of dimension
+/// `permutation[i]` of this layout. An unassigned layout stays unassigned.
+/// If `permutation` names a dimension this layout does not have (e.g. a
+/// {1, 0} permutation applied to a 1D layout), there is no meaningful
+/// transposed layout, so an unassigned LayoutInfo is returned instead of
+/// indexing out of bounds.
+LayoutInfo
+LayoutInfo::getTransposedLayout(ArrayRef<int64_t> permutation) const {
+  if (!isAssigned())
+    return {};
+  auto isOutOfRange = [&](int64_t idx) {
+    return idx < 0 || idx >= static_cast<int64_t>(laneLayout.size()) ||
+           idx >= static_cast<int64_t>(laneData.size());
+  };
+  if (llvm::any_of(permutation, isOutOfRange))
+    return {};
+  LaneLayout newLayout;
+  LaneData newData;
+  newLayout.layout.reserve(permutation.size());
+  newData.layout.reserve(permutation.size());
+  for (int64_t idx : permutation) {
+    newLayout.layout.push_back(laneLayout.layout[idx]);
+    newData.layout.push_back(laneData.layout[idx]);
+  }
+  return LayoutInfo(newLayout, newData);
+}
+
+//===----------------------------------------------------------------------===//
+// LayoutInfoLattice
+//===----------------------------------------------------------------------===//
+
+/// Lattice holding the LayoutInfo for each value. Combining and comparing
+/// states is delegated to LayoutInfo (its meet/join and operator==).
+struct LayoutInfoLattice : public Lattice<LayoutInfo> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LayoutInfoLattice)
+  // Reuse the constructors of the generic Lattice wrapper.
+  using Lattice::Lattice;
+};
+
+/// Helper Functions to get default layouts. A `default layout` is a layout that
+/// is assigned to a value when the layout is not fixed by some anchor operation
+/// (like DPAS).
+
+/// Returns the default layout for uniform values (like constants), chosen by
+/// rank alone:
+///   rank 1 -> lane_layout [subgroupSize],    lane_data [1]
+///   rank 2 -> lane_layout [1, subgroupSize], lane_data [1, 1]
+static LayoutInfo getDefaultSIMTLayoutInfo(unsigned rank) {
+  assert((rank == 1 || rank == 2) && "Expected 1D or 2D vector.");
+  if (rank != 1)
+    return LayoutInfo(LaneLayout({1, xegpu::targetinfo::subgroupSize}),
+                      LaneData({1, 1}));
+  return LayoutInfo(LaneLayout({xegpu::targetinfo::subgroupSize}),
+                    LaneData({1}));
+}
+
+/// Returns the default layout for a 1D or 2D vector with int/float elements.
+/// A 1D vector gets the rank-1 default. A 2D vector gets lane_layout
+/// [1, subgroupSize] and lane_data [1, packingFactor]. The packing factor is
+/// packedSizeInBitsForDefault / bitwidth when the element type is narrower
+/// than packedSizeInBitsForDefault bits, and 1 otherwise.
+static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy) {
+  const int64_t rank = vectorTy.getRank();
+  Type elemTy = vectorTy.getElementType();
+  assert((rank == 1 || rank == 2) && "Expected 1D or 2D vector.");
+  assert(elemTy.isIntOrFloat() && "Expected int or float element type.");
+  if (rank == 1)
+    return getDefaultSIMTLayoutInfo(1);
+  const unsigned elemBitWidth = elemTy.getIntOrFloatBitWidth();
+  const unsigned minPackedBits = xegpu::targetinfo::packedSizeInBitsForDefault;
+  const int packingFactor =
+      elemBitWidth < minPackedBits ? minPackedBits / elemBitWidth : 1;
+  return LayoutInfo(LaneLayout({1, xegpu::targetinfo::subgroupSize}),
+                    LaneData({1, packingFactor}));
+}
+
+/// Returns the expected layout for DPAS operand `operandNum` (0 = A, 1 = B,
+/// 2 = accumulator) of type `vectorTy`:
+/// * B operand whose element type is narrower than `packedSizeInBitsForDpasB`
+///   bits: lane_layout [1, subgroupSize], lane_data
+///   [packedSizeInBitsForDpasB / bitwidth, 1] (VNNI format).
+/// * Anything else (including A, which must be packed in at least
+///   `packedSizeInBitsForDefault` bits): the default layout for `vectorTy`.
+static LayoutInfo getSIMTLayoutInfoForDPASOperand(VectorType vectorTy,
+                                                  unsigned operandNum) {
+  Type elementTy = vectorTy.getElementType();
+  assert(elementTy.isIntOrFloat() &&
+         "Expected int or float type in DPAS operands");
+  const bool isOperandB = operandNum == 1;
+  if (isOperandB) {
+    const unsigned elemBits = elementTy.getIntOrFloatBitWidth();
+    const unsigned minBBits = xegpu::targetinfo::packedSizeInBitsForDpasB;
+    if (elemBits < minBBits)
+      return LayoutInfo(LaneLayout({1, xegpu::targetinfo::subgroupSize}),
+                        LaneData({minBBits / elemBits, 1}));
+  }
+  return getDefaultSIMTLayoutInfo(vectorTy);
+}
+
+//===----------------------------------------------------------------------===//
+// LayoutInfoPropagation
+//===----------------------------------------------------------------------===//
+
+/// Backward data flow analysis to propagate the lane_layout and lane_data of
+/// each value in the program. Currently, the layouts for operands of DPAS,
+/// StoreNd, and StoreScatter are fixed (known before propagation). The purpose
+/// of this analysis is to propagate those known layouts to all their producers
+/// and (other) consumers.
+class LayoutInfoPropagation
+    : public SparseBackwardDataFlowAnalysis<LayoutInfoLattice> {
+private:
+  void visitDpasOp(xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands,
+                   ArrayRef<const LayoutInfoLattice *> results);
+
+  void visitStoreNdOp(xegpu::StoreNdOp store,
+                      ArrayRef<LayoutInfoLattice *> operands,
+                      ArrayRef<const LayoutInfoLattice *> results);
+
+  void visitStoreScatterOp(xegpu::StoreScatterOp storeScatter,
+                           ArrayRef<LayoutInfoLattice *> operands,
+                           ArrayRef<const LayoutInfoLattice *> results);
+
+  void visitLoadNdOp(xegpu::LoadNdOp load,
+                     ArrayRef<LayoutInfoLattice *> operands,
+                     ArrayRef<const LayoutInfoLattice *> results);
+
+  void visitLoadGatherOp(xegpu::LoadGatherOp load,
+                         ArrayRef<LayoutInfoLattice *> operands,
+                         ArrayRef<const LayoutInfoLattice *> results);
+
+  void visitTransposeOp(vector::TransposeOp transpose,
+                        ArrayRef<LayoutInfoLattice *> operands,
+                        ArrayRef<const LayoutInfoLattice *> results);
+
+  void visitVectorBitcastOp(vector::BitCastOp bitcast,
+                            ArrayRef<LayoutInfoLattice *> operands,
+                            ArrayRef<const LayoutInfoLattice *> results);
+
+  void visitCreateDescOp(xegpu::CreateDescOp createDesc,
+                         ArrayRef<LayoutInfoLattice *> operands,
+                         ArrayRef<const LayoutInfoLattice *> results);
+
+  void visitUpdateNdOffsetOp(xegpu::UpdateNdOffsetOp updateNdOffset,
+                             ArrayRef<LayoutInfoLattice *> operands,
+                             ArrayRef<const LayoutInfoLattice *> results);
+
+  void visitPrefetchNdOp(xegpu::PrefetchNdOp prefetch,
+                         ArrayRef<LayoutInfoLattice *> operands,
+                         ArrayRef<const LayoutInfoLattice *> results);
+
+  void visitVectorMultiReductionOp(vector::MultiDimReductionOp reduction,
+                                   ArrayRef<LayoutInfoLattice *> operands,
+                                   ArrayRef<const LayoutInfoLattice *> results);
+
+public:
+  // A single explicit constructor; the previously duplicated inheriting
+  // constructor declared the same signature and has been dropped.
+  LayoutInfoPropagation(DataFlowSolver &solver,
+                        SymbolTableCollection &symbolTable)
+      : SparseBackwardDataFlowAnalysis(solver, symbolTable) {}
+
+  /// Transfer function: dispatches to the op-specific visitors above.
+  LogicalResult
+  visitOperation(Operation *op, ArrayRef<LayoutInfoLattice *> operands,
+                 ArrayRef<const LayoutInfoLattice *> results) override;
+
+  /// Intentionally a no-op: no layout is assigned here.
+  void visitBranchOperand(OpOperand &operand) override {}
+
+  /// Intentionally a no-op: no layout is assigned here.
+  void visitCallOperand(OpOperand &operand) override {}
+
+  /// Intentionally a no-op: calls to external functions propagate no layout.
+  void visitExternalCall(CallOpInterface call,
+                         ArrayRef<LayoutInfoLattice *> operands,
+                         ArrayRef<const LayoutInfoLattice *> results) override {}
+
+  /// Meeting with an unassigned LayoutInfo leaves the lattice unchanged (see
+  /// LayoutInfo::meet), so the result is deliberately ignored.
+  void setToExitState(LayoutInfoLattice *lattice) override {
+    (void)lattice->meet(LayoutInfo());
+  }
+};
+} // namespace
+
+/// Transfer function of the backward analysis. Ops with a dedicated visitor
+/// are dispatched to it. For any other op, every result that already has an
+/// assigned layout is met into each vector or tensor-descriptor operand;
+/// operands of other types are skipped. Always returns success().
+LogicalResult LayoutInfoPropagation::visitOperation(
+    Operation *op, ArrayRef<LayoutInfoLattice *> operands,
+    ArrayRef<const LayoutInfoLattice *> results) {
+  TypeSwitch<Operation *>(op)
+      .Case<xegpu::DpasOp>(
+          [&](auto dpasOp) { visitDpasOp(dpasOp, operands, results); })
+      .Case<xegpu::StoreNdOp>(
+          [&](auto storeNdOp) { visitStoreNdOp(storeNdOp, operands, results); })
+      .Case<xegpu::StoreScatterOp>([&](auto storeScatterOp) {
+        visitStoreScatterOp(storeScatterOp, operands, results);
+      })
+      .Case<xegpu::LoadNdOp>(
+          [&](auto loadNdOp) { visitLoadNdOp(loadNdOp, operands, results); })
+      .Case<xegpu::LoadGatherOp>([&](auto loadGatherOp) {
+        visitLoadGatherOp(loadGatherOp, operands, results);
+      })
+      .Case<xegpu::CreateDescOp>([&](auto createDescOp) {
+        visitCreateDescOp(createDescOp, operands, results);
+      })
+      .Case<xegpu::UpdateNdOffsetOp>([&](auto updateNdOffsetOp) {
+        visitUpdateNdOffsetOp(updateNdOffsetOp, operands, results);
+      })
+      .Case<xegpu::PrefetchNdOp>([&](auto prefetchNdOp) {
+        visitPrefetchNdOp(prefetchNdOp, operands, results);
+      })
+      .Case<vector::TransposeOp>([&](auto transposeOp) {
+        visitTransposeOp(transposeOp, operands, results);
+      })
+      .Case<vector::BitCastOp>([&](auto bitcastOp) {
+        visitVectorBitcastOp(bitcastOp, operands, results);
+      })
+      .Case<vector::MultiDimReductionOp>([&](auto reductionOp) {
+        visitVectorMultiReductionOp(reductionOp, operands, results);
+      })
+      // All other ops.
+      .Default([&](Operation *op) {
+        // LayoutInfo::meet keeps an already-assigned layout, so when several
+        // results are assigned, the first one met into an operand wins.
+        for (const LayoutInfoLattice *resultInfo : results) {
+          if (!resultInfo->getValue().isAssigned())
+            continue;
+          for (auto [operandInfo, operand] :
+               llvm::zip(operands, op->getOpOperands())) {
+            // If the operand type is not a vector or tensor descriptor, skip
+            // it.
+            if (!isa<xegpu::TensorDescType, VectorType>(
+                    operand.get().getType()))
+              continue;
+            // Propagate the result layout to the operand.
+            meet(operandInfo, *resultInfo);
+          }
+        }
+      });
+
+  return success();
+}
+
+/// PrefetchNdOp acts as an anchor: its tensor descriptor operand receives the
+/// default layout derived from the descriptor's shape and element type.
+void LayoutInfoPropagation::visitPrefetchNdOp(
+    xegpu::PrefetchNdOp prefetch, ArrayRef<LayoutInfoLattice *> operands,
+    ArrayRef<const LayoutInfoLattice *> results) {
+  auto descTy = prefetch.getTensorDescType();
+  VectorType descAsVector =
+      VectorType::get(descTy.getShape(), descTy.getElementType());
+  LayoutInfoLattice *descLattice = operands[0];
+  propagateIfChanged(descLattice,
+                     descLattice->meet(getDefaultSIMTLayoutInfo(descAsVector)));
+}
+
+/// For vector.multi_reduction with a 1D vector result, the source operand gets
+/// the default 2D layout and the accumulator inherits the result layout. Any
+/// other result type triggers a warning and nothing is propagated.
+void LayoutInfoPropagation::visitVectorMultiReductionOp(
+    vector::MultiDimReductionOp reduction,
+    ArrayRef<LayoutInfoLattice *> operands,
+    ArrayRef<const LayoutInfoLattice *> results) {
+  // Nothing to propagate until the result has a layout.
+  LayoutInfo resLayout = results[0]->getValue();
+  if (!resLayout.isAssigned())
+    return;
+  // Only 2D -> 1D reductions are handled for now.
+  auto destVecTy = llvm::dyn_cast<VectorType>(reduction.getDestType());
+  const bool isRank1Vector = destVecTy && destVecTy.getRank() == 1;
+  if (!isRank1Vector) {
+    reduction.emitWarning("Expecting output type to be 1D vector.");
+    return;
+  }
+  LayoutInfoLattice *sourceLattice = operands[0];
+  LayoutInfoLattice *accLattice = operands[1];
+  // A 1D result comes from a 2D source, which takes the default 2D layout.
+  propagateIfChanged(sourceLattice,
+                     sourceLattice->meet(getDefaultSIMTLayoutInfo(2)));
+  // The accumulator mirrors the result layout.
+  propagateIfChanged(accLattice, accLattice->meet(resLayout));
+}
+
+/// UpdateNdOffsetOp: once the resulting descriptor has a layout, the source
+/// descriptor (operand #0) is given the same layout.
+void LayoutInfoPropagation::visitUpdateNdOffsetOp(
+    xegpu::UpdateNdOffsetOp updateNdOffset,
+    ArrayRef<LayoutInfoLattice *> operands,
+    ArrayRef<const LayoutInfoLattice *> results) {
+  LayoutInfo resLayout = results[0]->getValue();
+  if (resLayout.isAssigned())
+    propagateIfChanged(operands[0], operands[0]->meet(resLayout));
+}
+
+/// DPAS is an anchor: operands A (#0), B (#1) and, when present, the
+/// accumulator C (#2) receive the layouts dictated by their types alone,
+/// independent of the result lattice.
+void LayoutInfoPropagation::visitDpasOp(
+    xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands,
+    ArrayRef<const LayoutInfoLattice *> results) {
+  SmallVector<VectorType, 3> operandTypes = {dpas.getLhsType(),
+                                             dpas.getRhsType()};
+  // The accumulator operand is optional.
+  if (operands.size() > 2)
+    operandTypes.push_back(dpas.getAccType());
+  for (auto [operandNum, operandTy] : llvm::enumerate(operandTypes)) {
+    LayoutInfo expected =
+        getSIMTLayoutInfoForDPASOperand(operandTy, operandNum);
+    propagateIfChanged(operands[operandNum],
+                       operands[operandNum]->meet(expected));
+  }
+}
+
+/// StoreNdOp is an anchor: every operand (the stored value and the destination
+/// descriptor) receives the default layout of the stored value's type.
+void LayoutInfoPropagation::visitStoreNdOp(
+    xegpu::StoreNdOp store, ArrayRef<LayoutInfoLattice *> operands,
+    ArrayRef<const LayoutInfoLattice *> results) {
+  const LayoutInfo defaultLayout =
+      getDefaultSIMTLayoutInfo(store.getValueType());
+  llvm::for_each(operands, [&](LayoutInfoLattice *lattice) {
+    propagateIfChanged(lattice, lattice->meet(defaultLayout));
+  });
+}
+
+/// Propagate the layout of the loaded value back to the tensor descriptor
+/// operand of LoadNdOp. If the load carries a transpose, a warning is emitted
+/// (the effect is not expected at this stage) and the descriptor receives the
+/// correspondingly transposed layout.
+void LayoutInfoPropagation::visitLoadNdOp(
+    xegpu::LoadNdOp load, ArrayRef<LayoutInfoLattice *> operands,
+    ArrayRef<const LayoutInfoLattice *> results) {
+  LayoutInfo loadedLayout = results[0]->getValue();
+  // Wait until the loaded value has a layout.
+  if (!loadedLayout.isAssigned())
+    return;
+  auto transposePerm = load.getTranspose();
+  if (!transposePerm) {
+    propagateIfChanged(operands[0], operands[0]->meet(loadedLayout));
+    return;
+  }
+  load.emitWarning("Transpose effect is not expected for LoadNdOp at "
+                   "LayoutInfoPropagation stage.");
+  LayoutInfo descLayout =
+      loadedLayout.getTransposedLayout(transposePerm.value());
+  propagateIfChanged(operands[0], operands[0]->meet(descLayout));
+}
+
+/// vector::TransposeOp: the source operand receives the result layout permuted
+/// by the op's permutation, once the result layout is known.
+void LayoutInfoPropagation::visitTransposeOp(
+    vector::TransposeOp transpose, ArrayRef<LayoutInfoLattice *> operands,
+    ArrayRef<const LayoutInfoLattice *> results) {
+  LayoutInfo resLayout = results[0]->getValue();
+  if (!resLayout.isAssigned())
+    return;
+  LayoutInfoLattice *srcLattice = operands[0];
+  LayoutInfo srcLayout =
+      resLayout.getTransposedLayout(transpose.getPermutation());
+  propagateIfChanged(srcLattice, srcLattice->meet(srcLayout));
+}
+
+/// vector::BitCastOp: when source and result element types have the same bit
+/// width, the result layout is forwarded unchanged to the source operand.
+/// Widening or narrowing bitcasts are not expected here; they emit a warning
+/// and propagate nothing.
+void LayoutInfoPropagation::visitVectorBitcastOp(
+    vector::BitCastOp bitcast, ArrayRef<LayoutInfoLattice *> operands,
+    ArrayRef<const LayoutInfoLattice *> results) {
+  LayoutInfo resLayout = results[0]->getValue();
+  if (!resLayout.isAssigned())
+    return;
+  const unsigned srcBits =
+      bitcast.getSourceVectorType().getElementType().getIntOrFloatBitWidth();
+  const unsigned dstBits =
+      bitcast.getResultVectorType().getElementType().getIntOrFloatBitWidth();
+  if (srcBits == dstBits) {
+    propagateIfChanged(operands[0], operands[0]->meet(resLayout));
+    return;
+  }
+  bitcast.emitWarning("Widening or narrowing bitcasts are not expected at "
+                      "layout propagation stage.");
+}
+
+/// LoadGatherOp: once the loaded value has a layout, the tensor descriptor
+/// (operand #0) receives it, transposed with {1, 0} if the op carries a
+/// transpose (which also emits a warning). The mask (operand #1) always
+/// receives the default 1D layout.
+void LayoutInfoPropagation::visitLoadGatherOp(
+    xegpu::LoadGatherOp load, ArrayRef<LayoutInfoLattice *> operands,
+    ArrayRef<const LayoutInfoLattice *> results) {
+  LayoutInfo loadedLayout = results[0]->getValue();
+  if (!loadedLayout.isAssigned())
+    return;
+
+  LayoutInfo descLayout = loadedLayout;
+  if (load.getTranspose()) {
+    // The transpose effect should be abstracted away before this analysis
+    // runs; warn, but still transpose the descriptor layout.
+    load.emitWarning("Transpose effect is not expected for LoadGatherOp at "
+                     "LayoutInfoPropagation stage.");
+    descLayout = loadedLayout.getTransposedLayout({1, 0});
+  }
+  LayoutInfoLattice *descLattice = operands[0];
+  LayoutInfoLattice *maskLattice = operands[1];
+  propagateIfChanged(descLattice, descLattice->meet(descLayout));
+  propagateIfChanged(maskLattice,
+                     maskLattice->meet(getDefaultSIMTLayoutInfo(1)));
+}
+
+/// CreateDescOp: once the created descriptor has a layout, the offsets operand
+/// (operand #1) receives the default 1D layout.
+void LayoutInfoPropagation::visitCreateDescOp(
+    xegpu::CreateDescOp createDesc, ArrayRef<LayoutInfoLattice *> operands,
+    ArrayRef<const LayoutInfoLattice *> results) {
+  if (!results[0]->getValue().isAssigned())
+    return;
+  LayoutInfoLattice *offsetsLattice = operands[1];
+  propagateIfChanged(offsetsLattice,
+                     offsetsLattice->meet(getDefaultSIMTLayoutInfo(1)));
+}
+
+/// StoreScatterOp is an anchor. The stored value (operand #0) gets the default
+/// layout of its type. The tensor descriptor (operand #1) gets the same layout,
+/// transposed with {1, 0} if the op carries a transpose (which also emits a
+/// warning). The mask (operand #2) gets the default 1D layout.
+void LayoutInfoPropagation::visitStoreScatterOp(
+    xegpu::StoreScatterOp storeScatter, ArrayRef<LayoutInfoLattice *> operands,
+    ArrayRef<const LayoutInfoLattice *> results) {
+  // For 2D descriptors, the height (first) dimension is expected to equal the
+  // subgroup size (expected to be guaranteed by the op verifier).
+  ArrayRef<int64_t> descShape = storeScatter.getTensorDescType().getShape();
+  if (descShape.size() > 1)
+    assert(
+        descShape[0] == xegpu::targetinfo::subgroupSize &&
+        "Expected the first dimension of 2D tensor descriptor to be equal to "
+        "subgroup size.");
+
+  const LayoutInfo valueLayout =
+      getDefaultSIMTLayoutInfo(storeScatter.getValueType());
+  LayoutInfo descLayout = valueLayout;
+  if (storeScatter.getTranspose()) {
+    // The transpose effect should be abstracted away before this analysis
+    // runs; warn, but still transpose the descriptor layout.
+    storeScatter.emitWarning("Transpose effect is not expected for "
+                             "StoreScatterOp at LayoutInfoPropagation stage.");
+    descLayout = valueLayout.getTransposedLayout({1, 0});
+  }
+  LayoutInfoLattice *valueLattice = operands[0];
+  LayoutInfoLattice *descLattice = operands[1];
+  LayoutInfoLattice *maskLattice = operands[2];
+  propagateIfChanged(valueLattice, valueLattice->meet(valueLayout));
+  propagateIfChanged(descLattice, descLattice->meet(descLayout));
+  propagateIfChanged(maskLattice,
+                     maskLattice->meet(getDefaultSIMTLayoutInfo(1)));
+}
+
+namespace {
+//===----------------------------------------------------------------------===//
+// RunLayoutInfoPropagation
+//===----------------------------------------------------------------------===//
+
+/// Driver class for running the LayoutInfoPropagation analysis. The solver is
+/// run on `op` during construction; results are queried afterwards.
+class RunLayoutInfoPropagation {
+public:
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(RunLayoutInfoPropagation)
+
+  RunLayoutInfoPropagation(Operation *op) : target(op) {
+    loadBaselineAnalyses(solver);
+    solver.load<LayoutInfoPropagation>(symbolTable);
+    (void)solver.initializeAndRun(op);
+  }
+
+  LayoutInfo getLayoutInfo(Value val);
+
+  void printAnalysisResult(llvm::raw_ostream &os);
+
+private:
+  /// LayoutInfoPropagation is constructed with a reference to this
+  /// collection, and `solver` owns that analysis. Keeping it as a member
+  /// declared before `solver` means it outlives the analysis, instead of being
+  /// a constructor-local that dangles once the constructor returns.
+  SymbolTableCollection symbolTable;
+  DataFlowSolver solver;
+  const Operation *target;
+};
+} // namespace
+
+/// Returns the LayoutInfo computed for `val`, or an unassigned LayoutInfo if
+/// the solver holds no lattice state for it.
+LayoutInfo RunLayoutInfoPropagation::getLayoutInfo(Value val) {
+  if (const auto *lattice = solver.lookupState<LayoutInfoLattice>(val))
+    return lattice->getValue();
+  return {};
+}
+
+// Print the analysis result for debugging purposes.
+// Only a ModuleOp target is handled. Results are printed for every
+// FunctionOpInterface op directly inside it, then for those directly inside
+// each nested gpu.module. For any other kind of target nothing is printed.
+// NOTE(review): the exact text emitted here is presumably matched by FileCheck
+// tests (print-analysis-only); keep the format stable.
+void RunLayoutInfoPropagation::printAnalysisResult(llvm::raw_ostream &os) {
+  // Prints argument layouts, then each result-producing op and the layouts of
+  // its results, in the order produced by funcOp.walk.
+  auto printFunctionResult = [&](FunctionOpInterface funcOp) {
+    os << "function: " << funcOp.getName() << ":\n";
+    // Function arguments
+    for (BlockArgument arg : funcOp.getArguments()) {
+      LayoutInfo layout = getLayoutInfo(arg);
+      os << "argument: " << arg << "\n";
+      os << "layout  : ";
+      layout.print(os);
+      os << "\n";
+    }
+    // Function ops
+    funcOp.walk([&](Operation *op) {
+      // Skip ops that do not have results
+      if (op->getResults().empty())
+        return;
+      os << "op    : ";
+      // For control-flow ops, print the op name only.
+      if (isa<BranchOpInterface>(op) || isa<RegionBranchOpInterface>(op))
+        os << op->getName();
+      else
+        op->print(os);
+      os << "\n";
+      // Print the layout for each result.
+      for (auto [i, r] : llvm::enumerate(op->getResults())) {
+        LayoutInfo layout = getLayoutInfo(r);
+        os << "layout for result #" << i << ": ";
+        layout.print(os);
+        os << "\n";
+      }
+    });
+  };
+
+  SmallVector<FunctionOpInterface> funcOps;
+  if (auto modOp = dyn_cast<ModuleOp>(target)) {
+    for (auto funcOp : modOp.getOps<FunctionOpInterface>())
+      funcOps.push_back(funcOp);
+
+    // Collect all GpuFuncOps in the module.
+    // (Only functions directly inside gpu.module ops nested directly in the
+    // module are collected.)
+    for (auto gpuModOp : modOp.getOps<gpu::GPUModuleOp>()) {
+      for (auto gpuFuncOp : gpuModOp.getOps<FunctionOpInterface>())
+        funcOps.push_back(gpuFuncOp);
+    }
+  }
+  // Print the analysis result for each function.
+  for (FunctionOpInterface funcOp : funcOps)
+    printFunctionResult(funcOp);
+}
+
+using GetLayoutFnTy = function_ref<xegpu::LayoutAttr(Value)>;
+/// Update an operation with the layout of its results. For a vector-typed
+/// result, a temporary layout attribute is added to the operation; for a tensor
+/// descriptor result, the result type is rebuilt with the layout attached. The
+/// users of the result are not modified here. Results without an assigned
+/// layout are left unchanged (with a warning if they have users).
+///
+/// Region branch ops are skipped; their results are handled through their
+/// terminators by updateControlFlowOps. `builder` is currently unused.
+static LogicalResult updateOp(mlir::OpBuilder &builder, mlir::Operation *op,
+                              GetLayoutFnTy getLayoutOfValue) {
+  // Region ops (like scf.for) are already handled by the updateControlFlowOps.
+  if (mlir::isa<mlir::RegionBranchOpInterface>(op))
+    return success();
+
+  // Iterate over all the results.
+  for (OpResult result : op->getResults()) {
+    Type resultType = result.getType();
+    // Layouts are needed only for vector and tensor descriptor types.
+    if (!isa<VectorType, xegpu::TensorDescType>(resultType))
+      continue;
+    xegpu::LayoutAttr layout = getLayoutOfValue(result);
+    if (!layout) {
+      // Nothing to attach. Skipping here (instead of rebuilding the descriptor
+      // type with a null layout, or setting a null attribute) also preserves a
+      // layout the type may already carry. A missing layout only warrants a
+      // warning when the result is actually used.
+      if (!result.use_empty())
+        op->emitWarning("op has users but no layout assigned for its result");
+      continue;
+    }
+    // If the result is a tensor descriptor type, update the tensor desc type
+    // with layout.
+    if (auto tensorDescTy = dyn_cast<xegpu::TensorDescType>(resultType)) {
+      auto typeWithLayout = xegpu::TensorDescType::get(
+          tensorDescTy.getContext(), tensorDescTy.getShape(),
+          tensorDescTy.getElementType(), tensorDescTy.getEncoding(), layout);
+      result.setType(typeWithLayout);
+      continue;
+    }
+    // If the result is a vector type, add a temporary layout attribute to the
+    // op.
+    xegpu::setLayoutAttr(result, layout);
+  }
+  return success();
+}
+
+/// Propagate assigned layouts across the control-flow edges of a region
+/// branch op.
+///
+/// Ops such as scf.for pass values into and out of their regions through
+/// terminators (e.g. scf.yield). For every successor of `terminator`, each
+/// forwarded operand is paired with the successor input it feeds: a block
+/// argument of the successor region or a result of the parent op. Such inputs
+/// may have no layout of their own (e.g. a loop result without users), even
+/// though the value forwarded into them does. This helper therefore copies the
+/// operand's layout onto the input:
+///   - tensor descriptor inputs get their type rebuilt with the layout;
+///   - vector inputs that are op results get a layout attribute.
+///
+/// Example (scf.for):
+///   scf.for ... iter_args(...) -> (out types) {
+///   ^bb0(block types):
+///     ...
+///     scf.yield ... : (yield types)
+///   }
+/// At the scf.yield, control may go back to ^bb0 (the loop body) or leave
+/// through the scf.for results, so both the block types and the out types are
+/// updated here. The yield types are updated by their producers inside ^bb0.
+///
+/// Fails if a forwarded operand has no layout, or if the input already has a
+/// different layout. `builder` is currently unused.
+static LogicalResult
+updateControlFlowOps(mlir::OpBuilder &builder,
+                     mlir::RegionBranchTerminatorOpInterface terminator,
+                     GetLayoutFnTy getLayoutOfValue) {
+  // Terminators whose parent is not a region branch op are not handled here.
+  Operation *parent = terminator->getParentOp();
+  if (!mlir::isa<mlir::RegionBranchOpInterface>(parent))
+    return success();
+
+  // Ask for all possible successors, without constant-operand knowledge.
+  llvm::SmallVector<mlir::Attribute> unknownOperands(
+      terminator->getNumOperands(), nullptr);
+  llvm::SmallVector<mlir::RegionSuccessor> regionSuccessors;
+  terminator.getSuccessorRegions(unknownOperands, regionSuccessors);
+
+  for (mlir::RegionSuccessor &succ : regionSuccessors) {
+    mlir::OperandRange forwarded = terminator.getSuccessorOperands(succ);
+    mlir::ValueRange inputs = succ.getSuccessorInputs();
+    for (auto [operand, input] : llvm::zip(forwarded, inputs)) {
+      Type inputTy = input.getType();
+      // Only tensor descriptors and vectors carry layouts.
+      if (!isa<xegpu::TensorDescType, VectorType>(inputTy))
+        continue;
+
+      xegpu::LayoutAttr inputLayout = getLayoutOfValue(input);
+      xegpu::LayoutAttr operandLayout = getLayoutOfValue(operand);
+      // The forwarded operand is the source of truth; it must have a layout.
+      if (!operandLayout) {
+        LLVM_DEBUG(
+            DBGS()
+            << "No layout assigned for forwarded operand in branch terminator: "
+            << operand << "\n");
+        return failure();
+      }
+      // An input that already has a layout must agree with the operand.
+      if (inputLayout && inputLayout != operandLayout) {
+        LLVM_DEBUG(DBGS() << "Conflicting layouts for region argument and "
+                             "operand forwarded as the argument: "
+                          << inputLayout << " vs " << operandLayout << "\n");
+        return failure();
+      }
+
+      if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(inputTy)) {
+        // Rebuild the descriptor type with the forwarded layout attached.
+        input.setType(xegpu::TensorDescType::get(
+            tdescTy.getContext(), tdescTy.getShape(), tdescTy.getElementType(),
+            tdescTy.getEncoding(), operandLayout));
+      } else if (auto opResult = dyn_cast<OpResult>(input)) {
+        // Vector-typed parent results record the layout as an attribute on the
+        // parent op; vector block arguments are left as they are.
+        xegpu::setLayoutAttr(opResult, operandLayout);
+      }
+    }
+  }
+  return success();
+}
+
+/// Update the function arguments with the assigned layouts and rewrite the
+/// function type to match.
+///
+/// Tensor descriptor arguments get their type rebuilt with the layout. Vector
+/// arguments are only checked, because the type cannot hold the layout. A
+/// vector or tensor descriptor argument without a layout is an error only if it
+/// has users: an unused argument has no user to receive a layout from in this
+/// backward analysis, so it is left unchanged. Function declarations are
+/// skipped.
+/// NOTE: We assume that function results are not expected to have layouts.
+static LogicalResult updateFunctionOpInterface(mlir::OpBuilder &builder,
+                                               mlir::FunctionOpInterface funcOp,
+                                               GetLayoutFnTy getLayoutOfValue) {
+  // A declaration has no body and therefore no entry-block arguments to
+  // update. Rebuilding its type from that (empty) argument list would drop its
+  // inputs, so leave it untouched.
+  if (funcOp.isExternal())
+    return success();
+
+  SmallVector<Type> newArgTypes;
+  // Update the function arguments.
+  for (BlockArgument arg : funcOp.getArguments()) {
+    Type argType = arg.getType();
+    newArgTypes.push_back(argType);
+    if (!isa<VectorType, xegpu::TensorDescType>(argType))
+      continue;
+    xegpu::LayoutAttr layout = getLayoutOfValue(arg);
+    if (!layout) {
+      // Unused arguments never receive a layout; keep their type as-is.
+      if (arg.use_empty())
+        continue;
+      LLVM_DEBUG(DBGS() << "Expecting layout for function argument: " << arg
+                        << " but got none.\n");
+      return failure();
+    }
+    if (auto tensorDescTy = dyn_cast<xegpu::TensorDescType>(argType)) {
+      auto newTdescTy = xegpu::TensorDescType::get(
+          tensorDescTy.getContext(), tensorDescTy.getShape(),
+          tensorDescTy.getElementType(), tensorDescTy.getEncoding(), layout);
+      arg.setType(newTdescTy);
+      newArgTypes.back() = newTdescTy;
+    }
+  }
+  // Update the function type with the new argument types.
+  funcOp.setType(FunctionType::get(funcOp.getContext(), newArgTypes,
+                                   funcOp.getResultTypes()));
+  return success();
+}
+
+namespace {
+/// Pass that runs the layout propagation analysis and materializes its result
+/// in the IR (tensor descriptor types with layouts, layout attributes on ops).
+struct XeGPUPropagateLayoutPass final
+    : public xegpu::impl::XeGPUPropagateLayoutBase<XeGPUPropagateLayoutPass> {
+  // Inherit the generated options constructor; the implicitly declared default
+  // and copy constructors forward to the corresponding base constructors.
+  using XeGPUPropagateLayoutBase::XeGPUPropagateLayoutBase;
+
+  void runOnOperation() override;
+};
+} // namespace
+
+void XeGPUPropagateLayoutPass::runOnOperation() {
+  auto &layoutAnalysis = getAnalysis<RunLayoutInfoPropagation>();
+
+  // Debug mode: dump the analysis result instead of rewriting the IR.
+  if (printOnly) {
+    layoutAnalysis.printAnalysisResult(llvm::outs());
+    return;
+  }
+
+  // Translate the analysis' LayoutInfo for `val` into an xegpu::LayoutAttr; a
+  // null attribute means no layout was assigned.
+  auto layoutAttrOf = [&](Value val) -> xegpu::LayoutAttr {
+    LayoutInfo info = layoutAnalysis.getLayoutInfo(val);
+    if (!info.isAssigned())
+      return {};
+    return xegpu::LayoutAttr::get(
+        val.getContext(), llvm::to_vector_of<int>(info.getLayoutAsArrayRef()),
+        llvm::to_vector_of<int>(info.getDataAsArrayRef()));
+  };
+
+  mlir::OpBuilder builder(&getContext());
+  // Ops inside each block are visited bottom-up (last op first). Terminators of
+  // region ops, functions and all other ops each get their own update helper.
+  WalkResult result =
+      getOperation()->walk([&](mlir::Block *block) -> WalkResult {
+        for (mlir::Operation &nestedOp :
+             llvm::reverse(block->getOperations())) {
+          LogicalResult status = success();
+          if (auto terminator =
+                  dyn_cast<mlir::RegionBranchTerminatorOpInterface>(&nestedOp))
+            status = updateControlFlowOps(builder, terminator, layoutAttrOf);
+          else if (auto funcOp = dyn_cast<mlir::FunctionOpInterface>(&nestedOp))
+            status = updateFunctionOpInterface(builder, funcOp, layoutAttrOf);
+          else
+            status = updateOp(builder, &nestedOp, layoutAttrOf);
+
+          if (failed(status)) {
+            nestedOp.emitError("Failed to update operation with the layout.");
+            return WalkResult::interrupt();
+          }
+        }
+        return WalkResult::advance();
+      });
+  if (result.wasInterrupted())
+    signalPassFailure();
+}

diff  --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 66d21dbdaf064..dabcae0bfe4b1 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -5,9 +5,6 @@
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
 //===----------------------------------------------------------------------===//
-#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
-#include "mlir/Analysis/DataFlow/Utils.h"
-#include "mlir/Analysis/DataFlowFramework.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/GPU/Utils/DistributionUtils.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
@@ -29,15 +26,13 @@
 #include "mlir/IR/Value.h"
 #include "mlir/IR/Visitors.h"
 #include "mlir/Interfaces/FunctionInterfaces.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "mlir/Transforms/InliningUtils.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
-#include "llvm/ADT/TypeSwitch.h"
-#include "llvm/Support/FormatVariadic.h"
-#include "llvm/Support/InterleavedRange.h"
-#include "llvm/Support/raw_ostream.h"
 
 namespace mlir {
 namespace xegpu {
@@ -50,788 +45,11 @@ namespace xegpu {
 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
 
 using namespace mlir;
-using namespace mlir::dataflow;
 
-/// HW dependent constants.
-/// TODO: These constants should be queried from the target information.
-constexpr unsigned subgroupSize = 16; // How many lanes in a subgroup.
-/// If DPAS A or B operands have low precision element types they must be packed
-/// according to the following sizes. `packedSizeInBitsForDefault` is also the
-/// packing size used for default 2D layouts (see getDefaultLayoutInfo).
-constexpr unsigned packedSizeInBitsForDefault =
-    16; // Minimum packing size per register for DPAS A.
-constexpr unsigned packedSizeInBitsForDpasB =
-    32; // Minimum packing size per register for DPAS B.
-
-namespace {
-
-//===----------------------------------------------------------------------===//
-// Layout
-//===----------------------------------------------------------------------===//
-
-/// Helper class to store the ND layout of lanes within a subgroup and data
-/// owned by each lane.
-struct Layout {
-  SmallVector<int64_t, 3> layout;
-  Layout() = default;
-  Layout(std::initializer_list<int64_t> list) : layout(list) {}
-  void print(llvm::raw_ostream &os) const;
-  size_t size() const { return layout.size(); }
-  // Element access that asserts on out-of-range indices.
-  int64_t operator[](size_t idx) const;
-};
-
-/// Prints the entries of the layout via llvm::interleaved_array.
-void Layout::print(llvm::raw_ostream &os) const {
-  os << llvm::interleaved_array(layout);
-}
-
-/// Returns entry `idx`; asserts if `idx` is out of range.
-int64_t Layout::operator[](size_t idx) const {
-  assert(idx < layout.size() && "Index out of bounds.");
-  return layout[idx];
-}
-
-/// LaneLayout represents the logical layout of lanes within a subgroup when it
-/// accesses some value. LaneData represents the logical layout of data owned by
-/// each work item.
-using LaneLayout = Layout;
-using LaneData = Layout;
-
-//===----------------------------------------------------------------------===//
-// LayoutInfo
-//===----------------------------------------------------------------------===//
-
-/// Helper class for tracking the analysis state of an mlir value. For layout
-/// propagation, the analysis state is simply the lane_layout and lane_data of
-/// each value. The purpose of this analysis is to propagate some unique layout
-/// for each value in the program starting from a set of anchor operations
-/// (like DPAS, StoreNd, etc.).
-///
-/// Given this, LayoutInfo satisfies the following properties:
-///  1) A LayoutInfo value can be in one of two states - `assigned` or `not
-///  assigned`.
-///  2) Two LayoutInfo values are equal if they are both assigned or
-///  both not assigned. The concrete value of assigned state does not matter.
-///  3) The meet operator works as follows:
-///     - If current state is assigned, return the current state. (already
-///     a unique layout is assigned. don't change it)
-///     - Otherwise, return the other state.
-
-struct LayoutInfo {
-private:
-  LaneLayout laneLayout;
-  LaneData laneData;
-
-public:
-  LayoutInfo() = default;
-  LayoutInfo(const LaneLayout &layout, const LaneData &data)
-      : laneLayout(layout), laneData(data) {}
-
-  // Two lattice values compare equal when both are assigned or both are
-  // unassigned. The actual content of the layout does not matter.
-  bool operator==(const LayoutInfo &other) const {
-    return this->isAssigned() == other.isAssigned();
-  }
-
-  static LayoutInfo meet(const LayoutInfo &lhs, const LayoutInfo &rhs);
-
-  static LayoutInfo join(const LayoutInfo &lhs, const LayoutInfo &rhs);
-
-  void print(raw_ostream &os) const;
-
-  // Assigned means both lane_layout and lane_data are non-empty.
-  bool isAssigned() const {
-    return laneLayout.size() > 0 && laneData.size() > 0;
-  }
-
-  LayoutInfo getTransposedLayout(ArrayRef<int64_t> permutation) const;
-
-  const LaneLayout &getLayout() const { return laneLayout; }
-  const LaneData &getData() const { return laneData; }
-  ArrayRef<int64_t> getLayoutAsArrayRef() const { return laneLayout.layout; }
-  ArrayRef<int64_t> getDataAsArrayRef() const { return laneData.layout; }
-};
-
-/// Prints "lane_layout: <...>, lane_data: <...>" when assigned, otherwise
-/// "Not assigned.".
-void LayoutInfo::print(raw_ostream &os) const {
-  if (isAssigned()) {
-    os << "lane_layout: ";
-    laneLayout.print(os);
-    os << ", lane_data: ";
-    laneData.print(os);
-  } else {
-    os << "Not assigned.";
-  }
-}
-
-/// Keeps `lhs` once it is assigned; otherwise takes `rhs`.
-LayoutInfo LayoutInfo::meet(const LayoutInfo &lhs, const LayoutInfo &rhs) {
-  if (!lhs.isAssigned())
-    return rhs;
-  return lhs;
-}
-
-/// Since this is a backward analysis, join method is not used.
-LayoutInfo LayoutInfo::join(const LayoutInfo &lhs, const LayoutInfo &rhs) {
-  llvm_unreachable("Join should not be triggered by layout propagation.");
-}
-
-/// Get the transposed layout according to the given permutation: entry `i` of
-/// the new layout/data is entry `permutation[i]` of the current one.
-/// NOTE(review): `permutation` is not validated against the layout rank here
-/// (entries index `layout` directly); callers must pass a matching permutation.
-LayoutInfo
-LayoutInfo::getTransposedLayout(ArrayRef<int64_t> permutation) const {
-  if (!isAssigned())
-    return {};
-  LaneLayout newLayout;
-  LaneData newData;
-  for (int64_t idx : permutation) {
-    newLayout.layout.push_back(laneLayout.layout[idx]);
-    newData.layout.push_back(laneData.layout[idx]);
-  }
-  return LayoutInfo(newLayout, newData);
-}
-
-//===----------------------------------------------------------------------===//
-// LayoutInfoLattice
-//===----------------------------------------------------------------------===//
-
-/// Lattice holding the LayoutInfo for each value.
-struct LayoutInfoLattice : public Lattice<LayoutInfo> {
-  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LayoutInfoLattice)
-  using Lattice::Lattice;
-};
-
-/// Helper Functions to get default layouts. A `default layout` is a layout that
-/// is assigned to a value when the layout is not fixed by some anchor operation
-/// (like DPAS).
-
-/// Helper Function to get the default layout for uniform values like constants.
-/// For 1D vector, lane_layout is [subgroupSize] and lane_data is [1].
-/// For 2D vector, lane_layout is [1, subgroupSize] and lane_data is [1, 1].
-static LayoutInfo getDefaultLayoutInfo(unsigned rank) {
-  assert((rank == 1 || rank == 2) && "Expected 1D or 2D vector.");
-  if (rank == 1)
-    return LayoutInfo(LaneLayout({subgroupSize}), LaneData({1}));
-  return LayoutInfo(LaneLayout({1, subgroupSize}), LaneData({1, 1}));
-}
-
-/// Helper to get the default layout for a vector type. 2D vectors with element
-/// types narrower than `packedSizeInBitsForDefault` bits are packed along the
-/// inner dimension, e.g. 8-bit elements give lane_data [1, 2].
-static LayoutInfo getDefaultLayoutInfo(VectorType vectorTy) {
-  // Expecting a 1D or 2D vector.
-  assert((vectorTy.getRank() == 1 || vectorTy.getRank() == 2) &&
-         "Expected 1D or 2D vector.");
-  // Expecting int or float element type.
-  assert(vectorTy.getElementType().isIntOrFloat() &&
-         "Expected int or float element type.");
-  // If the rank is 1, then return default layout for 1D vector.
-  if (vectorTy.getRank() == 1)
-    return getDefaultLayoutInfo(1);
-  // Packing factor is determined by the element type bitwidth.
-  int packingFactor = 1;
-  unsigned bitwidth = vectorTy.getElementType().getIntOrFloatBitWidth();
-  if (bitwidth < packedSizeInBitsForDefault)
-    packingFactor = packedSizeInBitsForDefault / bitwidth;
-  return LayoutInfo(LaneLayout({1, subgroupSize}),
-                    LaneData({1, packingFactor}));
-}
-
-/// Helper Function to get the expected layouts for DPAS operands. `lane_data`
-/// is set according to the following criteria:
-/// * For A operand, the data must be packed in minimum
-/// `packedSizeInBitsForDefault`
-/// * For B operand, the data must be packed in minimum
-/// `packedSizeInBitsForDpasB` (e.g. 16-bit elements give lane_data [2, 1])
-/// Any other operand (e.g. the accumulator, `operandNum == 2`) gets the default
-/// layout for its vector type.
-static LayoutInfo getLayoutInfoForDPASOperand(VectorType vectorTy,
-                                              unsigned operandNum) {
-  Type elementTy = vectorTy.getElementType();
-  assert(elementTy.isIntOrFloat() &&
-         "Expected int or float type in DPAS operands");
-  LaneLayout layout({1, subgroupSize});
-  // For B operand, data must be packed in minimum `packedSizeInBitsForDpasB`
-  // and must have the VNNI format.
-  if (operandNum == 1 &&
-      elementTy.getIntOrFloatBitWidth() < packedSizeInBitsForDpasB) {
-    LaneData data(
-        {packedSizeInBitsForDpasB / elementTy.getIntOrFloatBitWidth(), 1});
-    return LayoutInfo(layout, data);
-  }
-  // Otherwise, return the default layout for the vector type.
-  return getDefaultLayoutInfo(vectorTy);
-}
-
-//===----------------------------------------------------------------------===//
-// LayoutInfoPropagation
-//===----------------------------------------------------------------------===//
-
-/// Backward data flow analysis to propagate the lane_layout and lane_data of
-/// each value in the program. Currently, the layouts for operands of DPAS,
-/// StoreNd, and StoreScatter are fixed (known before propagation). The purpose
-/// of this analysis is to propagate those known layouts to all their producers
-/// and (other) consumers.
-class LayoutInfoPropagation
-    : public SparseBackwardDataFlowAnalysis<LayoutInfoLattice> {
-private:
-  void visitDpasOp(xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands,
-                   ArrayRef<const LayoutInfoLattice *> results);
-
-  void visitStoreNdOp(xegpu::StoreNdOp store,
-                      ArrayRef<LayoutInfoLattice *> operands,
-                      ArrayRef<const LayoutInfoLattice *> results);
-
-  void visitStoreScatterOp(xegpu::StoreScatterOp storeScatter,
-                           ArrayRef<LayoutInfoLattice *> operands,
-                           ArrayRef<const LayoutInfoLattice *> results);
-
-  void visitLoadNdOp(xegpu::LoadNdOp load,
-                     ArrayRef<LayoutInfoLattice *> operands,
-                     ArrayRef<const LayoutInfoLattice *> results);
-
-  void visitLoadGatherOp(xegpu::LoadGatherOp load,
-                         ArrayRef<LayoutInfoLattice *> operands,
-                         ArrayRef<const LayoutInfoLattice *> results);
-
-  void visitTransposeOp(vector::TransposeOp transpose,
-                        ArrayRef<LayoutInfoLattice *> operands,
-                        ArrayRef<const LayoutInfoLattice *> results);
-
-  void visitVectorBitcastOp(vector::BitCastOp bitcast,
-                            ArrayRef<LayoutInfoLattice *> operands,
-                            ArrayRef<const LayoutInfoLattice *> results);
-
-  void visitCreateDescOp(xegpu::CreateDescOp createDesc,
-                         ArrayRef<LayoutInfoLattice *> operands,
-                         ArrayRef<const LayoutInfoLattice *> results);
-
-  void visitUpdateNdOffsetOp(xegpu::UpdateNdOffsetOp updateNdOffset,
-                             ArrayRef<LayoutInfoLattice *> operands,
-                             ArrayRef<const LayoutInfoLattice *> results);
-
-  void visitPrefetchNdOp(xegpu::PrefetchNdOp prefetch,
-                         ArrayRef<LayoutInfoLattice *> operands,
-                         ArrayRef<const LayoutInfoLattice *> results);
-
-  void visitVectorMultiReductionOp(vector::MultiDimReductionOp reduction,
-                                   ArrayRef<LayoutInfoLattice *> operands,
-                                   ArrayRef<const LayoutInfoLattice *> results);
-
-public:
-  LayoutInfoPropagation(DataFlowSolver &solver,
-                        SymbolTableCollection &symbolTable)
-      : SparseBackwardDataFlowAnalysis(solver, symbolTable) {}
-  using SparseBackwardDataFlowAnalysis::SparseBackwardDataFlowAnalysis;
-
-  LogicalResult
-  visitOperation(Operation *op, ArrayRef<LayoutInfoLattice *> operands,
-                 ArrayRef<const LayoutInfoLattice *> results) override;
-
-  // The three overrides below are empty: this analysis assigns no layout to
-  // branch-only operands, call operands, or across external calls.
-  void visitBranchOperand(OpOperand &operand) override {};
-
-  void visitCallOperand(OpOperand &operand) override {};
-
-  void visitExternalCall(CallOpInterface call,
-                         ArrayRef<LayoutInfoLattice *> operands,
-                         ArrayRef<const LayoutInfoLattice *> results) override {
-  };
-
-  // Meeting with an unassigned LayoutInfo never changes a lattice value (see
-  // LayoutInfo::meet), so exit points start out without a layout.
-  void setToExitState(LayoutInfoLattice *lattice) override {
-    (void)lattice->meet(LayoutInfo());
-  }
-};
-} // namespace
-
-/// Transfer function of the analysis: dispatch to the op-specific handlers.
-/// Ops without a dedicated handler forward every assigned result layout to all
-/// of their operands.
-LogicalResult LayoutInfoPropagation::visitOperation(
-    Operation *op, ArrayRef<LayoutInfoLattice *> operands,
-    ArrayRef<const LayoutInfoLattice *> results) {
-  TypeSwitch<Operation *>(op)
-      .Case<xegpu::DpasOp>(
-          [&](auto dpasOp) { visitDpasOp(dpasOp, operands, results); })
-      .Case<xegpu::StoreNdOp>(
-          [&](auto storeNdOp) { visitStoreNdOp(storeNdOp, operands, results); })
-      .Case<xegpu::StoreScatterOp>([&](auto storeScatterOp) {
-        visitStoreScatterOp(storeScatterOp, operands, results);
-      })
-      .Case<xegpu::LoadNdOp>(
-          [&](auto loadNdOp) { visitLoadNdOp(loadNdOp, operands, results); })
-      .Case<xegpu::LoadGatherOp>([&](auto loadGatherOp) {
-        visitLoadGatherOp(loadGatherOp, operands, results);
-      })
-      .Case<xegpu::CreateDescOp>([&](auto createDescOp) {
-        visitCreateDescOp(createDescOp, operands, results);
-      })
-      .Case<xegpu::UpdateNdOffsetOp>([&](auto updateNdOffsetOp) {
-        visitUpdateNdOffsetOp(updateNdOffsetOp, operands, results);
-      })
-      .Case<xegpu::PrefetchNdOp>([&](auto prefetchNdOp) {
-        visitPrefetchNdOp(prefetchNdOp, operands, results);
-      })
-      // No need to propagate the layout to operands in CreateNdDescOp because
-      // they are scalars (offsets, sizes, etc.).
-      .Case<xegpu::CreateNdDescOp>([&](auto createNdDescOp) {})
-      .Case<vector::TransposeOp>([&](auto transposeOp) {
-        visitTransposeOp(transposeOp, operands, results);
-      })
-      .Case<vector::BitCastOp>([&](auto bitcastOp) {
-        visitVectorBitcastOp(bitcastOp, operands, results);
-      })
-      .Case<vector::MultiDimReductionOp>([&](auto reductionOp) {
-        visitVectorMultiReductionOp(reductionOp, operands, results);
-      })
-      // All other ops.
-      // NOTE: operand types are not checked, so scalar operands also receive
-      // the result layout. With several assigned results, each operand keeps
-      // the first layout met, since LayoutInfo::meet preserves an assigned
-      // state.
-      .Default([&](Operation *op) {
-        for (const LayoutInfoLattice *r : results) {
-          for (LayoutInfoLattice *operand : operands) {
-            // Propagate the layout of the result to the operand.
-            if (r->getValue().isAssigned())
-              meet(operand, *r);
-          }
-        }
-      });
-  // Add a dependency from each result to program point after the operation.
-  for (const LayoutInfoLattice *r : results) {
-    addDependency(const_cast<LayoutInfoLattice *>(r), getProgramPointAfter(op));
-  }
-  return success();
-}
-
-/// Anchor the tensor descriptor operand of PrefetchNdOp with the default layout
-/// of a vector that has the descriptor's shape and element type. The shape
-/// must therefore be 1D or 2D (asserted by getDefaultLayoutInfo).
-void LayoutInfoPropagation::visitPrefetchNdOp(
-    xegpu::PrefetchNdOp prefetch, ArrayRef<LayoutInfoLattice *> operands,
-    ArrayRef<const LayoutInfoLattice *> results) {
-  // Here we assign the default layout to the tensor descriptor operand of
-  // prefetch.
-  auto tdescTy = prefetch.getTensorDescType();
-  auto prefetchLayout = getDefaultLayoutInfo(
-      VectorType::get(tdescTy.getShape(), tdescTy.getElementType()));
-  // Propagate the layout to the source tensor descriptor.
-  propagateIfChanged(operands[0], operands[0]->meet(prefetchLayout));
-}
-
-/// For vector::MultiDimReductionOp (assumed 2D -> 1D), give the source operand
-/// the default 2D layout and the accumulator the layout of the result.
-/// NOTE(review): only the result rank is asserted; the source rank and the
-/// reduced dimension are not inspected.
-void LayoutInfoPropagation::visitVectorMultiReductionOp(
-    vector::MultiDimReductionOp reduction,
-    ArrayRef<LayoutInfoLattice *> operands,
-    ArrayRef<const LayoutInfoLattice *> results) {
-  // The layout of the result must be present.
-  LayoutInfo resultLayout = results[0]->getValue();
-  if (!resultLayout.isAssigned())
-    return;
-  // We only consider 2D -> 1D reductions at this point.
-  assert(resultLayout.getLayout().size() == 1 &&
-         "Expected 1D layout for reduction result.");
-  // Given that the result is 1D, the layout of the operand should be 2D with
-  // default layout.
-  LayoutInfo operandLayout = getDefaultLayoutInfo(2);
-  propagateIfChanged(operands[0], operands[0]->meet(operandLayout));
-  // Accumulator should have the same layout as the result.
-  propagateIfChanged(operands[1], operands[1]->meet(resultLayout));
-}
-
-/// Propagate the layout of the result tensor to the source tensor descriptor in
-/// UpdateNdOffsetOp. Only operand 0 is updated; any offset operands receive no
-/// layout.
-void LayoutInfoPropagation::visitUpdateNdOffsetOp(
-    xegpu::UpdateNdOffsetOp updateNdOffset,
-    ArrayRef<LayoutInfoLattice *> operands,
-    ArrayRef<const LayoutInfoLattice *> results) {
-  // The layout of the result must be present.
-  LayoutInfo resultLayout = results[0]->getValue();
-  if (!resultLayout.isAssigned())
-    return;
-  // Propagate the layout to the source operand.
-  propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
-}
-
-/// Set the layouts for DPAS A, B, and C operands. These are anchors: they do
-/// not depend on the layout of the DPAS result.
-void LayoutInfoPropagation::visitDpasOp(
-    xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands,
-    ArrayRef<const LayoutInfoLattice *> results) {
-  VectorType aTy = dpas.getLhsType();
-  VectorType bTy = dpas.getRhsType();
-  propagateIfChanged(operands[0],
-                     operands[0]->meet(getLayoutInfoForDPASOperand(aTy, 0)));
-  propagateIfChanged(operands[1],
-                     operands[1]->meet(getLayoutInfoForDPASOperand(bTy, 1)));
-  // A third operand is only present when the DPAS has an accumulator (C).
-  if (operands.size() > 2) {
-    VectorType cTy = dpas.getAccType();
-    propagateIfChanged(operands[2],
-                       operands[2]->meet(getLayoutInfoForDPASOperand(cTy, 2)));
-  }
-}
-
-/// Set the layout for the value and tensor descriptor operands in StoreNdOp.
-/// The layout is the default layout of the stored value's vector type.
-void LayoutInfoPropagation::visitStoreNdOp(
-    xegpu::StoreNdOp store, ArrayRef<LayoutInfoLattice *> operands,
-    ArrayRef<const LayoutInfoLattice *> results) {
-  LayoutInfo storeLayout = getDefaultLayoutInfo(store.getValueType());
-  // Both operands should have the same layout
-  // NOTE: the loop meets every operand with `storeLayout`, not just the value
-  // and the descriptor.
-  for (LayoutInfoLattice *operand : operands) {
-    propagateIfChanged(operand, operand->meet(storeLayout));
-  }
-}
-
-/// Propagate the layout of the value to the tensor descriptor operand in
-/// LoadNdOp.
-void LayoutInfoPropagation::visitLoadNdOp(
-    xegpu::LoadNdOp load, ArrayRef<LayoutInfoLattice *> operands,
-    ArrayRef<const LayoutInfoLattice *> results) {
-  LayoutInfo valueLayout = results[0]->getValue();
-  // Need the layout of the value to propagate to the tensor descriptor.
-  if (!valueLayout.isAssigned())
-    return;
-  LayoutInfo tensorDescLayout = valueLayout;
-  // LoadNdOp has the transpose effect. However, at the stage of this analysis
-  // this effect is not expected and should be abstracted away. Emit a warning.
-  // NOTE(review): the value layout is permuted by `transpose` itself, not by its
-  // inverse; the two coincide for the 2D [1, 0] permutation.
-  if (auto transpose = load.getTranspose()) {
-    load.emitWarning("Transpose effect is not expected for LoadNdOp at "
-                     "LayoutInfoPropagation stage.");
-    tensorDescLayout = valueLayout.getTransposedLayout(transpose.value());
-  }
-  // Propagate the new layout to the tensor descriptor operand.
-  propagateIfChanged(operands[0], operands[0]->meet(tensorDescLayout));
-}
-
-/// For vector::TransposeOp, the layout of the result is transposed and
-/// propagated to the operand.
-/// NOTE(review): the result layout is permuted by the op's permutation rather
-/// than its inverse. This is equivalent only for self-inverse permutations
-/// (e.g. the 2D [1, 0] case); TODO confirm before supporting higher ranks.
-void LayoutInfoPropagation::visitTransposeOp(
-    vector::TransposeOp transpose, ArrayRef<LayoutInfoLattice *> operands,
-    ArrayRef<const LayoutInfoLattice *> results) {
-  // Need the layout of transpose result to propagate to the operands.
-  LayoutInfo resultLayout = results[0]->getValue();
-  if (!resultLayout.isAssigned())
-    return;
-  LayoutInfo newLayout =
-      resultLayout.getTransposedLayout(transpose.getPermutation());
-  // Propagate the new layout to the vector operand.
-  propagateIfChanged(operands[0], operands[0]->meet(newLayout));
-}
-
-/// For vector::BitCastOp, the lane_data of the source layout is changed based
-/// on the bit width of the source and result types. The lane_layout is reused
-/// unchanged; the non-unit lane_data entry is scaled by the bit-width ratio.
-/// NOTE(review): both branches read `currData[1]`, which is out of range when
-/// the result layout is 1D (Layout::operator[] asserts). A narrowing bitcast
-/// whose lane_data entry is smaller than `ratio` yields a lane_data of 0.
-void LayoutInfoPropagation::visitVectorBitcastOp(
-    vector::BitCastOp bitcast, ArrayRef<LayoutInfoLattice *> operands,
-    ArrayRef<const LayoutInfoLattice *> results) {
-  // Need the layout of bitcast result to propagate to the operands.
-  LayoutInfo resultLayout = results[0]->getValue();
-  if (!resultLayout.isAssigned())
-    return;
-  int inElemTyBitWidth =
-      bitcast.getSourceVectorType().getElementType().getIntOrFloatBitWidth();
-  int outElemTyBitWidth =
-      bitcast.getResultVectorType().getElementType().getIntOrFloatBitWidth();
-
-  // LaneLayout does not change.
-  const LaneLayout &newLaneLayout = resultLayout.getLayout();
-  const LaneData &currData = resultLayout.getData();
-  LaneData newLaneData;
-  // It's a widening bitcast
-  if (inElemTyBitWidth < outElemTyBitWidth) {
-    int ratio = outElemTyBitWidth / inElemTyBitWidth;
-    newLaneData = resultLayout.getData()[0] == 1
-                      ? LaneData({1, currData[1] * ratio})
-                      : LaneData({currData[0] * ratio, 1});
-  } else {
-    // It's a narrowing bitcast (equal bit widths also land here, ratio == 1)
-    int ratio = inElemTyBitWidth / outElemTyBitWidth;
-    newLaneData = resultLayout.getData()[0] == 1
-                      ? LaneData({1, currData[1] / ratio})
-                      : LaneData({currData[0] / ratio, 1});
-  }
-
-  propagateIfChanged(operands[0],
-                     operands[0]->meet(LayoutInfo(newLaneLayout, newLaneData)));
-}
-
-/// Propagate the layout of the result to the tensor descriptor and mask
-/// operands in LoadGatherOp.
-void LayoutInfoPropagation::visitLoadGatherOp(
-    xegpu::LoadGatherOp load, ArrayRef<LayoutInfoLattice *> operands,
-    ArrayRef<const LayoutInfoLattice *> results) {
-  LayoutInfo valueLayout = results[0]->getValue();
-  // Need the layout of the value to propagate to the tensor descriptor.
-  if (!valueLayout.isAssigned())
-    return;
-
-  LayoutInfo tensorDescLayout = valueLayout;
-  if (load.getTranspose()) {
-    // LoadGatherOp has the transpose effect. However, at the stage of this
-    // analyis this effect is not expected and should be abstracted away. Emit
-    // a warning.
-    load.emitWarning("Transpose effect is not expected for LoadGatherOp at "
-                     "LayoutInfoPropagation stage.");
-    tensorDescLayout = valueLayout.getTransposedLayout({1, 0});
-  }
-  // Mask operand should have 1D default layout.
-  LayoutInfo maskLayout = getDefaultLayoutInfo(1);
-  // Propagate the new layout to the tensor descriptor operand.
-  propagateIfChanged(operands[0], operands[0]->meet(tensorDescLayout));
-  // Propagate the new layout to the mask operand.
-  propagateIfChanged(operands[1], operands[1]->meet(maskLayout));
-}
-
-/// Propagate the layout of the descriptor to the vector offset operand in
-/// CreateDescOp.
-void LayoutInfoPropagation::visitCreateDescOp(
-    xegpu::CreateDescOp createDesc, ArrayRef<LayoutInfoLattice *> operands,
-    ArrayRef<const LayoutInfoLattice *> results) {
-  LayoutInfo descLayout = results[0]->getValue();
-  // Need the layout of the descriptor to propagate to the operands.
-  if (!descLayout.isAssigned())
-    return;
-  // For offset operand propagate 1D default layout.
-  LayoutInfo layout = getDefaultLayoutInfo(1);
-  propagateIfChanged(operands[1], operands[1]->meet(layout));
-}
-
-/// Set the layout for the value, tensor descriptor, and mask operands in the
-/// StoreScatterOp.
-void LayoutInfoPropagation::visitStoreScatterOp(
-    xegpu::StoreScatterOp storeScatter, ArrayRef<LayoutInfoLattice *> operands,
-    ArrayRef<const LayoutInfoLattice *> results) {
-  // Currently, for 2D StoreScatterOp we expect that the height dimension of
-  // the tensor descriptor is equal to the subgroup size. This is ensured by
-  // the op verifier.
-  ArrayRef<int64_t> tdescShape = storeScatter.getTensorDescType().getShape();
-  if (tdescShape.size() > 1)
-    assert(
-        tdescShape[0] == subgroupSize &&
-        "Expected the first dimension of 2D tensor descriptor to be equal to "
-        "subgroup size.");
-
-  LayoutInfo valueLayout = getDefaultLayoutInfo(storeScatter.getValueType());
-  LayoutInfo storeScatterLayout = valueLayout;
-  if (storeScatter.getTranspose()) {
-    // StoreScatteOp allows transpose effect. However, at the stage of this
-    // analyis this effect is not expected and should be abstracted away. Emit
-    // a warning.
-    storeScatter.emitWarning("Transpose effect is not expected for "
-                             "StoreScatterOp at LayoutInfoPropagation stage.");
-    storeScatterLayout = valueLayout.getTransposedLayout({1, 0});
-  }
-  // Propagate the value layout.
-  propagateIfChanged(operands[0], operands[0]->meet(valueLayout));
-  // Propagate the tensor descriptor layout.
-  propagateIfChanged(operands[1], operands[1]->meet(storeScatterLayout));
-  // Use default 1D layout for mask operand.
-  LayoutInfo maskLayout = getDefaultLayoutInfo(1);
-  propagateIfChanged(operands[2], operands[2]->meet(maskLayout));
-}
-
-namespace {
-
-//===----------------------------------------------------------------------===//
-// RunLayoutInfoPropagation
-//===----------------------------------------------------------------------===//
-
-/// Driver class for running the LayoutInfoPropagation analysis.
-class RunLayoutInfoPropagation {
-public:
-  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(RunLayoutInfoPropagation)
-
-  RunLayoutInfoPropagation(Operation *op) : target(op) {
-    SymbolTableCollection symbolTable;
-    loadBaselineAnalyses(solver);
-    solver.load<LayoutInfoPropagation>(symbolTable);
-    (void)solver.initializeAndRun(op);
-  }
-
-  LayoutInfo getLayoutInfo(Value val);
-
-  void printAnalysisResult(llvm::raw_ostream &os);
-
-private:
-  DataFlowSolver solver;
-  const Operation *target;
-};
-} // namespace
-
-LayoutInfo RunLayoutInfoPropagation::getLayoutInfo(Value val) {
-  auto *state = solver.lookupState<LayoutInfoLattice>(val);
-  if (!state)
-    return {};
-  return state->getValue();
-}
-
-void RunLayoutInfoPropagation::printAnalysisResult(llvm::raw_ostream &os) {
-  auto printFunctionResult = [&](FunctionOpInterface funcOp) {
-    os << "function: " << funcOp.getName() << ":\n";
-    // Function arguments
-    for (BlockArgument arg : funcOp.getArguments()) {
-      LayoutInfo layout = getLayoutInfo(arg);
-      os << "argument: " << arg << "\n";
-      os << "layout  : ";
-      layout.print(os);
-      os << "\n";
-    }
-    // Function ops
-    funcOp.walk([&](Operation *op) {
-      // Skip ops that do not have results
-      if (op->getResults().empty())
-        return;
-      os << "op    : ";
-      // For control-flow ops, print the op name only.
-      if (isa<BranchOpInterface>(op) || isa<RegionBranchOpInterface>(op))
-        os << op->getName();
-      else
-        op->print(os);
-      os << "\n";
-      // Print the layout for each result.
-      for (auto [i, r] : llvm::enumerate(op->getResults())) {
-        LayoutInfo layout = getLayoutInfo(r);
-        os << "layout for result #" << i << ": ";
-        layout.print(os);
-        os << "\n";
-      }
-    });
-  };
-
-  SmallVector<FunctionOpInterface> funcOps;
-  if (auto modOp = dyn_cast<ModuleOp>(target)) {
-    for (auto funcOp : modOp.getOps<FunctionOpInterface>()) {
-      funcOps.push_back(funcOp);
-    }
-    // Collect all GpuFuncOps in the module.
-    for (auto gpuModOp : modOp.getOps<gpu::GPUModuleOp>()) {
-      for (auto gpuFuncOp : gpuModOp.getOps<FunctionOpInterface>()) {
-        funcOps.push_back(gpuFuncOp);
-      }
-    }
-  }
-  // Print the analysis result for each function.
-  for (FunctionOpInterface funcOp : funcOps) {
-    printFunctionResult(funcOp);
-  }
-}
-
-namespace {
-
-//===----------------------------------------------------------------------===//
-// LayoutAttrAssignment
-//===----------------------------------------------------------------------===//
-
-/// This class is responsible for assigning the layout attributes to the ops and
-/// their users based on the layout propagation analysis result.
-class LayoutAttrAssignment {
-public:
-  LayoutAttrAssignment(Operation *top,
-                       function_ref<LayoutInfo(Value)> getLayout)
-      : getAnalysisResult(getLayout), top(top) {}
-
-  LogicalResult run();
-
-private:
-  LogicalResult assign(Operation *op);
-  void assignToUsers(Value v, xegpu::LayoutAttr layout);
-  xegpu::LayoutAttr getLayoutAttrForValue(Value v);
-  LogicalResult resolveConflicts();
-  // Callable to get the layout of a value based on the layout propagation
-  // analysis.
-  function_ref<LayoutInfo(Value)> getAnalysisResult;
-  Operation *top;
-};
-
-} // namespace
-
-/// Helper to assign the layout attribute to the users of the value.
-void LayoutAttrAssignment::assignToUsers(Value v, xegpu::LayoutAttr layout) {
-  for (OpOperand &user : v.getUses()) {
-    Operation *owner = user.getOwner();
-    std::string attrName = xegpu::getLayoutName(user);
-    owner->setAttr(attrName, layout);
-  }
-}
-
-/// Convert the layout assigned to a value to xegpu::LayoutAttr.
-xegpu::LayoutAttr LayoutAttrAssignment::getLayoutAttrForValue(Value v) {
-  LayoutInfo layout = getAnalysisResult(v);
-  if (!layout.isAssigned())
-    return {};
-  SmallVector<int, 2> laneLayout, laneData;
-  for (auto [layout, data] : llvm::zip_equal(layout.getLayoutAsArrayRef(),
-                                             layout.getDataAsArrayRef())) {
-    laneLayout.push_back(static_cast<int>(layout));
-    laneData.push_back(static_cast<int>(data));
-  }
-  return xegpu::LayoutAttr::get(v.getContext(), laneLayout, laneData);
-}
-
-/// Assign xegpu::LayoutAttr to the op and its users. The layout is assigned
-/// based on the layout propagation analysis result.
-LogicalResult LayoutAttrAssignment::assign(Operation *op) {
-  // For function ops, propagate the function argument layout to the users.
-  if (auto func = dyn_cast<FunctionOpInterface>(op)) {
-    for (BlockArgument arg : func.getArguments()) {
-      xegpu::LayoutAttr layoutInfo = getLayoutAttrForValue(arg);
-      if (layoutInfo) {
-        assignToUsers(arg, layoutInfo);
-      }
-    }
-    return success();
-  }
-  // If no results, move on.
-  if (op->getNumResults() == 0)
-    return success();
-  // If all the results are scalars, move on.
-  if (llvm::all_of(op->getResultTypes(),
-                   [](Type t) { return t.isIntOrIndexOrFloat(); }))
-    return success();
-  // If the op has more than one result and at least one result is a tensor
-  // descriptor, exit. This case is not supported yet.
-  // TODO: Support this case.
-  if (op->getNumResults() > 1 && llvm::any_of(op->getResultTypes(), [](Type t) {
-        return isa<xegpu::TensorDescType>(t);
-      })) {
-    LLVM_DEBUG(
-        DBGS() << op->getName()
-               << " op has more than one result and at least one is a tensor "
-                  "descriptor. This case is not handled.\n");
-    return failure();
-  }
-  // If the result is a tensor descriptor, attach the layout to the tensor
-  // descriptor itself.
-  if (auto tensorDescTy =
-          dyn_cast<xegpu::TensorDescType>(op->getResultTypes()[0])) {
-    xegpu::LayoutAttr layoutInfo = getLayoutAttrForValue(op->getResult(0));
-    if (!layoutInfo) {
-      LLVM_DEBUG(DBGS() << "No layout for result of " << *op << "\n");
-      return failure();
-    }
-
-    // Clone the op, attach the layout to the result tensor descriptor, and
-    // remove the original op.
-    OpBuilder builder(op);
-    Operation *newOp = builder.clone(*op);
-    auto newTensorDescTy = xegpu::TensorDescType::get(
-        tensorDescTy.getContext(), tensorDescTy.getShape(),
-        tensorDescTy.getElementType(), tensorDescTy.getEncoding(), layoutInfo);
-    newOp->getResult(0).setType(newTensorDescTy);
-    op->replaceAllUsesWith(newOp->getResults());
-    op->erase();
-    return success();
-  }
-  // Otherwise simply attach the layout to the op itself.
-  for (auto r : op->getOpResults()) {
-    xegpu::LayoutAttr layoutInfo = getLayoutAttrForValue(r);
-    if (layoutInfo) {
-      std::string attrName = xegpu::getLayoutName(r);
-      op->setAttr(attrName, layoutInfo);
-      // Attach the layout attribute to the users of the result.
-      assignToUsers(r, layoutInfo);
-    }
-  }
-  return success();
-}
-
-/// Walk the IR and attach xegpu::LayoutAttr to all ops and their users.
-LogicalResult LayoutAttrAssignment::run() {
-  auto walkResult = top->walk([&](Operation *op) {
-    if (failed(assign(op)))
-      return WalkResult::interrupt();
-    return WalkResult::advance();
-  });
-
-  if (walkResult.wasInterrupted())
-    return failure();
-
-  return resolveConflicts();
-}
-
-/// TODO: Implement the layout conflict resolution. This must ensure mainly two
-/// things:
-/// 1) Is a given layout supported by the op? (need to query the target
-///    HW info). Otherwise can we achieve this layout using a layout conversion?
-/// 2) Do all the operands have the required layout? If not, can it
-///    be resolved using a layout conversion?
-LogicalResult LayoutAttrAssignment::resolveConflicts() { return success(); }
+// Attribute name for identifying UnrealizedConversionCastOps added to resolve
+// SIMT type mismatches (removed again at the end of the pass).
+static const char *const resolveSIMTTypeMismatch =
+    "resolve_simt_type_mismatch";
 
 namespace {
 
@@ -867,9 +85,9 @@ getDistVecTypeBasedOnLaneLayout(xegpu::LayoutAttr layout,
   // dimensions are not distributed.
   unsigned distributionStart = originalType.getRank() - laneLayout.size();
   for (auto [i, dim] : llvm::enumerate(originalType.getShape())) {
-    if (i < distributionStart) {
+    if (i < distributionStart)
       continue;
-    }
+
     // Check if the dimension can be distributed evenly.
     if (dim % laneLayout[i - distributionStart] != 0)
       return failure();
@@ -909,6 +127,7 @@ static Value resolveDistributedTy(Value orig, T expected,
   if (isa<xegpu::TensorDescType>(orig.getType())) {
     auto castOp = rewriter.create<UnrealizedConversionCastOp>(orig.getLoc(),
                                                               expected, orig);
+    castOp->setAttr(resolveSIMTTypeMismatch, rewriter.getUnitAttr());
     return castOp.getResult(0);
   }
   llvm_unreachable("Unsupported type for reconciliation");
@@ -988,8 +207,9 @@ struct MoveFuncBodyToWarpExecuteOnLane0
         /** upperBound = **/ mlir::IntegerAttr());
     ArrayRef<Type> gpuFuncResultType = gpuFuncOp.getFunctionType().getResults();
     auto warpOp = rewriter.create<gpu::WarpExecuteOnLane0Op>(
-        laneId.getLoc(), gpuFuncResultType, laneId, subgroupSize,
-        newGpuFunc.getArguments(), newGpuFunc.getArgumentTypes());
+        laneId.getLoc(), gpuFuncResultType, laneId,
+        xegpu::targetinfo::subgroupSize, newGpuFunc.getArguments(),
+        newGpuFunc.getArgumentTypes());
     Block &warpBodyBlock = warpOp.getBodyRegion().front();
     // Replace the ReturnOp of the original gpu function with a YieldOp.
     auto origRetunOp =
@@ -1080,11 +300,14 @@ struct CreateNdDescDistribution final : public gpu::WarpDistributionPattern {
     xegpu::TensorDescType distributedTensorDescTy =
         descOp.getType().dropLayouts(); // Distributed tensor descriptor type
                                         // does not contain layout info.
-    auto newDescOp = rewriter.create<xegpu::CreateNdDescOp>(
+    Value newDescOp = rewriter.create<xegpu::CreateNdDescOp>(
         newWarpOp.getLoc(), distributedTensorDescTy, newDescOperands,
         descOp->getAttrs());
 
     Value distributedVal = newWarpOp.getResult(operandIdx);
+    // Resolve the distributed type to the expected type.
+    newDescOp =
+        resolveDistributedTy(newDescOp, distributedVal.getType(), rewriter);
     rewriter.replaceAllUsesWith(distributedVal, newDescOp);
     return success();
   }
@@ -1485,10 +708,13 @@ struct UpdateNdOffsetDistribution final : public gpu::WarpDistributionPattern {
       }
     }
     // Create a new update op outside the warp op.
-    auto newUpdateOp = rewriter.create<xegpu::UpdateNdOffsetOp>(
+    Value newUpdateOp = rewriter.create<xegpu::UpdateNdOffsetOp>(
         newWarpOp.getLoc(), newTensorDescTy, newUpdateOperands,
         removeTemporaryLayoutAttributes(updateOp->getAttrs()));
     Value distributedVal = newWarpOp.getResult(operandIdx);
+    // Resolve the distributed type with the original type.
+    newUpdateOp =
+        resolveDistributedTy(newUpdateOp, distributedVal.getType(), rewriter);
     rewriter.replaceAllUsesWith(distributedVal, newUpdateOp);
     return success();
   }
@@ -1562,11 +788,6 @@ namespace {
 struct XeGPUSubgroupDistributePass final
     : public xegpu::impl::XeGPUSubgroupDistributeBase<
           XeGPUSubgroupDistributePass> {
-  XeGPUSubgroupDistributePass() = default;
-  XeGPUSubgroupDistributePass(const XeGPUSubgroupDistributePass &other) =
-      default;
-  XeGPUSubgroupDistributePass(xegpu::XeGPUSubgroupDistributeOptions options)
-      : XeGPUSubgroupDistributeBase(options) {}
   void runOnOperation() override;
 };
 } // namespace
@@ -1579,27 +800,29 @@ void xegpu::populateXeGPUSubgroupDistributePatterns(
 }
 
 void XeGPUSubgroupDistributePass::runOnOperation() {
-  auto &analyis = getAnalysis<RunLayoutInfoPropagation>();
-  // Print the analysis result and exit. (for testing purposes)
-  if (printOnly) {
-    auto &os = llvm::outs();
-    analyis.printAnalysisResult(os);
-    return;
-  }
-  auto getPropagatedLayout = [&](Value val) {
-    return analyis.getLayoutInfo(val);
-  };
-
-  // Assign xegpu::LayoutAttr to all ops and their users based on the layout
-  // propagation analysis result.
-  LayoutAttrAssignment layoutAssignment(getOperation(), getPropagatedLayout);
-  if (failed(layoutAssignment.run())) {
-    signalPassFailure();
-    return;
-  }
-
-  // Move all operations of a GPU function inside gpu.warp_execute_on_lane_0
-  // operation.
+  // Step 1: Attach layouts to op operands. The pass is aborted if any vector
+  // operand has no layout. TODO: Following assumptions are made:
+  // 1) It is assumed that there are no layout conflicts.
+  // 2) Any existing layout attributes attached to the operands are ignored.
+  WalkResult walkResult = getOperation()->walk([&](Operation *op) {
+    for (OpOperand &operand : op->getOpOperands()) {
+      // Layouts are needed for vector type only.
+      if (!isa<VectorType>(operand.get().getType()))
+        continue;
+      xegpu::LayoutAttr layout = xegpu::getLayoutAttr(operand);
+      if (!layout) {
+        op->emitError("Could not find layout attribute for operand ")
+            << operand.getOperandNumber() << " of operation " << op->getName();
+        return WalkResult::interrupt();
+      }
+      xegpu::setLayoutAttr(operand, layout);
+    }
+    return WalkResult::advance();
+  });
+  if (walkResult.wasInterrupted())
+    return signalPassFailure();
+  // Step 2: Move all operations of a GPU function inside
+  // gpu.warp_execute_on_lane_0 operation.
   {
     RewritePatternSet patterns(&getContext());
     patterns.add<MoveFuncBodyToWarpExecuteOnLane0>(&getContext());
@@ -1608,17 +831,16 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
       signalPassFailure();
       return;
     }
-    // At this point, we have moved the entire function body inside the warpOp.
-    // Now move any scalar uniform code outside of the warpOp (like GPU index
-    // ops, scalar constants, etc.). This will simplify the later lowering and
-    // avoid custom patterns for these ops.
+    // At this point, we have moved the entire function body inside the
+    // warpOp. Now move any scalar uniform code outside of the warpOp (like
+    // GPU index ops, scalar constants, etc.). This will simplify the
+    // later lowering and avoid custom patterns for these ops.
     getOperation()->walk([&](Operation *op) {
-      if (auto warpOp = dyn_cast<gpu::WarpExecuteOnLane0Op>(op)) {
+      if (auto warpOp = dyn_cast<gpu::WarpExecuteOnLane0Op>(op))
         vector::moveScalarUniformCode(warpOp);
-      }
     });
   }
-  // Finally, do the SIMD to SIMT distribution.
+  // Step 3: Apply subgroup to workitem distribution patterns.
   RewritePatternSet patterns(&getContext());
   xegpu::populateXeGPUSubgroupDistributePatterns(patterns);
   // TODO: distributionFn and shuffleFn are not used at this point.
@@ -1638,4 +860,51 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
     signalPassFailure();
     return;
   }
+
+  // Step 4: Finally, clean up UnrealizedConversionCastOps that were inserted
+  // due to tensor desc type mismatches created by using upstream distribution
+  // patterns (scf.for)
+  getOperation()->walk([&](mlir::UnrealizedConversionCastOp op) {
+    // We are only interested in UnrealizedConversionCastOps that were added
+    // for resolving SIMT type mismatches; all other casts are left untouched.
+    if (!op->getAttr(resolveSIMTTypeMismatch))
+      return WalkResult::advance();
+
+    Value input = op.getOperand(0);
+    Value output = op.getResult(0);
+
+    // Both input and output must have tensor descriptor types.
+    xegpu::TensorDescType inputDescType =
+        mlir::dyn_cast<xegpu::TensorDescType>(input.getType());
+    xegpu::TensorDescType outputDescType =
+        mlir::dyn_cast<xegpu::TensorDescType>(output.getType());
+    assert(inputDescType && outputDescType &&
+           "Unrealized conversion cast must have tensor descriptor types");
+
+    // tensor_desc<shape, layout> -> tensor_desc<shape> Type of conversions.
+    // This occurs inside scf.for body to resolve the block argument type to
+    // SIMT type. Any tied loop result is retyped as well.
+    if (inputDescType.getLayout()) {
+      auto argument = mlir::dyn_cast<mlir::BlockArgument>(input);
+      if (argument) {
+        argument.setType(output.getType());
+        output.replaceAllUsesWith(argument);
+        if (auto loopOp = mlir::dyn_cast<mlir::LoopLikeOpInterface>(
+                argument.getOwner()->getParentOp())) {
+          if (OpResult result = loopOp.getTiedLoopResult(argument))
+            result.setType(output.getType());
+        }
+      }
+    }
+
+    // tensor_desc<shape> -> tensor_desc<shape, layout> Type of
+    // conversions. This occurs at the yield op of scf.for body to go back
+    // from SIMT type to original type.
+    if (outputDescType.getLayout())
+      output.replaceAllUsesWith(input);
+
+    if (op->use_empty())
+      op->erase();
+    return WalkResult::advance();
+  });
 }

diff  --git a/mlir/test/Dialect/XeGPU/propagate-layout.mlir b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
new file mode 100644
index 0000000000000..429081079de1e
--- /dev/null
+++ b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
@@ -0,0 +1,430 @@
+// RUN: mlir-opt -xegpu-propagate-layout -split-input-file %s | FileCheck %s
+
+// CHECK-LABEL: func.func @dpas_f16(
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<8x16xf16>, %[[ARG1:[0-9a-zA-Z]+]]: memref<16x16xf16>, %[[ARG2:[0-9a-zA-Z]+]]: memref<8x16xf32>) {
+// CHECK: %[[CST:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} dense<0.000000e+00> : vector<8x16xf32>
+// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][{{.*}}] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+// CHECK: %[[T1:.*]] = xegpu.create_nd_tdesc %[[ARG1]][{{.*}}] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>>
+// CHECK: %[[T2:.*]] = xegpu.load_nd %[[T0]]  {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} :
+// CHECK-SAME: !xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<8x16xf16>
+// CHECK: %[[T3:.*]] = xegpu.load_nd %[[T1]]  {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} :
+// CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>> -> vector<16x16xf16>
+// CHECK: %[[T4:.*]] = xegpu.dpas %[[T2]], %[[T3]], %[[CST]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} :
+// CHECK-SAME: vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+// CHECK: %[[T5:.*]] = xegpu.create_nd_tdesc %[[ARG2]][{{.*}}] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+// CHECK: xegpu.store_nd %[[T4]], %[[T5]] : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+func.func @dpas_f16(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<8x16xf32>) {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant dense<0.000000e+00> : vector<8x16xf32>
+  %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
+  %1 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
+  %2 = xegpu.load_nd %0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+  %3 = xegpu.load_nd %1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+  %4 = xegpu.dpas %2, %3, %cst : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+  %5 = xegpu.create_nd_tdesc %arg2[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+  xegpu.store_nd %4, %5  : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+  return
+}
+
+// -----
+// CHECK-LABEL: func.func @dpas_i8(
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: vector<8x32xi8>, %[[ARG1:[0-9a-zA-Z]+]]: vector<32x16xi8>, %[[ARG2:[0-9a-zA-Z]+]]: memref<8x16xi32>) {
+// CHECK: %[[T0:.*]] = xegpu.dpas %[[ARG0]], %[[ARG1]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16],
+func.func @dpas_i8(%arg0: vector<8x32xi8>, %arg1: vector<32x16xi8>, %arg2: memref<8x16xi32>) {
+  %c0 = arith.constant 0 : index
+  %0 = xegpu.dpas %arg0, %arg1 : vector<8x32xi8>, vector<32x16xi8> -> vector<8x16xi32>
+  %1 = xegpu.create_nd_tdesc %arg2[%c0, %c0] : memref<8x16xi32> -> !xegpu.tensor_desc<8x16xi32>
+  xegpu.store_nd %0, %1  : vector<8x16xi32>, !xegpu.tensor_desc<8x16xi32>
+  return
+}
+
+// -----
+// CHECK-LABEL: func.func @load_with_transpose_effect(
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<8x16xf16>, %[[ARG1:[0-9a-zA-Z]+]]: memref<16x16xf16>, %[[ARG2:[0-9a-zA-Z]+]]: memref<8x16xf32>) {
+// CHECK: %{{.*}} = xegpu.load_nd %{{.*}} <{transpose = array<i64: 1, 0>}> {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} :
+// CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>> -> vector<16x16xf16>
+func.func @load_with_transpose_effect(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<8x16xf32>) {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant dense<0.000000e+00> : vector<8x16xf32>
+  %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
+  %1 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
+  %2 = xegpu.load_nd %0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+  %3 = xegpu.load_nd %1 <{transpose = array<i64: 1, 0>}> : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+  %4 = xegpu.dpas %2, %3, %cst : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+  %5 = xegpu.create_nd_tdesc %arg2[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+  xegpu.store_nd %4, %5  : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+  return
+}
+
+// -----
+// CHECK-LABEL: func.func @vector_transpose(
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<8x16xf16>, %[[ARG1:[0-9a-zA-Z]+]]: memref<16x16xf16>, %[[ARG2:[0-9a-zA-Z]+]]: memref<8x16xf32>) {
+// CHECK: %{{.*}} = vector.transpose %{{.*}}, [1, 0] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} : vector<16x16xf16> to vector<16x16xf16>
+func.func @vector_transpose(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<8x16xf32>) {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant dense<0.000000e+00> : vector<8x16xf32>
+  %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
+  %1 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
+  %2 = xegpu.load_nd %0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+  %3 = xegpu.load_nd %1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+  %4 = vector.transpose %3, [1, 0] : vector<16x16xf16> to vector<16x16xf16>
+  %5 = xegpu.dpas %2, %4, %cst : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+  %6 = xegpu.create_nd_tdesc %arg2[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+  xegpu.store_nd %5, %6  : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+  return
+}
+
+// -----
+// CHECK-LABEL: func.func @extf_truncf(
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>, %[[ARG1:[0-9a-zA-Z]+]]:
+// CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>>) -> vector<8x16xf32> {
+// CHECK: %[[T2:.*]] = arith.extf %{{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} : vector<16x16xf16> to vector<16x16xf32>
+// CHECK-NEXT: %{{.*}} = arith.truncf %[[T2]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} : vector<16x16xf32> to vector<16x16xf16>
+func.func @extf_truncf(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>) -> vector<8x16xf32> {
+  %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+  %1 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+  %2 = arith.extf %1 : vector<16x16xf16> to vector<16x16xf32>
+  %3 = arith.truncf %2 : vector<16x16xf32> to vector<16x16xf16>
+  %4 = xegpu.dpas %0, %3 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+  return %4 : vector<8x16xf32>
+}
+
+// -----
+// CHECK-LABEL: func.func @load_gather_with_transpose_effect(
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<8x16xf16>, %[[ARG1:[0-9a-zA-Z]+]]: memref<256xf16>, %[[ARG2:[0-9a-zA-Z]+]]: memref<8x16xf32>) {
+// CHECK: %[[CST:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
+// CHECK-SAME:  dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
+// CHECK-NEXT: %[[CST0:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1>
+// CHECK-NEXT: %[[T2:.*]] = xegpu.create_tdesc %[[ARG1]], %[[CST]] : memref<256xf16>, vector<16xindex> ->
+// CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>>
+// CHECK-NEXT: %{{.*}} = xegpu.load %[[T2]], %[[CST0]] <{transpose}> {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} :
+// CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>>, vector<16xi1> -> vector<16x16xf16>
+func.func @load_gather_with_transpose_effect(%arg0: memref<8x16xf16>, %arg1: memref<256xf16>, %arg2: memref<8x16xf32>) {
+  %c0 = arith.constant 0 : index
+  %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
+  %1 = xegpu.load_nd %0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+  %cst = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
+  %cst_0 = arith.constant dense<true> : vector<16xi1>
+  %2 = xegpu.create_tdesc %arg1, %cst : memref<256xf16>, vector<16xindex> -> !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>>
+  %3 = xegpu.load %2, %cst_0 <{transpose}> : !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>>, vector<16xi1> -> vector<16x16xf16>
+  %4 = xegpu.dpas %1, %3 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+  %5 = xegpu.create_nd_tdesc %arg2[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+  xegpu.store_nd %4, %5  : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+  return
+}
+
+// -----
+// 1-D gather (no chunk_size): the offsets and mask constants, the scattered
+// tensor descriptor, and the loaded vector<16xf32> are all expected to be
+// annotated with lane_layout = [16], lane_data = [1]. The store_nd descriptor
+// argument %arg1 is expected to receive the same layout in its type.
+// CHECK-LABEL: func.func @load_gather_1d(
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<256xf32>, %[[ARG1:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>) {
+// CHECK: %[[CST:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
+// CHECK-SAME: dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
+// CHECK-NEXT: %[[CST0:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1>
+// CHECK-NEXT: %[[T0:.*]] = xegpu.create_tdesc %[[ARG0]], %[[CST]] : memref<256xf32>, vector<16xindex> ->
+// CHECK-SAME: !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
+// CHECK-NEXT: %{{.*}} = xegpu.load %[[T0]], %[[CST0]]  {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} :
+// CHECK-SAME: !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<lane_layout = [16], lane_data = [1]>>, vector<16xi1> -> vector<16xf32>
+func.func @load_gather_1d(%arg0: memref<256xf32>, %arg1: !xegpu.tensor_desc<16xf32>) {
+  %cst = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
+  %cst_0 = arith.constant dense<true> : vector<16xi1>
+  %0 = xegpu.create_tdesc %arg0, %cst : memref<256xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
+  %1 = xegpu.load %0, %cst_0  : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> -> vector<16xf32>
+  xegpu.store_nd %1, %arg1  : vector<16xf32>, !xegpu.tensor_desc<16xf32>
+  return
+}
+
+// -----
+// Scattered store with <{transpose}> and chunk_size = 8: the 16x8 scattered
+// descriptor is expected to get lane_layout = [16, 1], lane_data = [1, 1],
+// while the stored value has the transposed shape vector<8x16xf32>.
+// NOTE(review): offsets reach 240 (+8 elements per chunk) but the memref has
+// only 128 elements; presumably irrelevant for a layout-only test - confirm.
+// CHECK-LABEL: func.func @store_scatter_with_transpose_effect(
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<128xf32>) {
+// CHECK: %[[T0:.*]] = xegpu.create_tdesc %[[ARG0]], %{{.*}} : memref<128xf32>, vector<16xindex> ->
+// CHECK-SAME: !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size = 8 : i64>, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>>
+// CHECK-NEXT: xegpu.store %{{.*}}, %[[T0]], %{{.*}} <{transpose}> : vector<8x16xf32>, !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size = 8 : i64>,
+// CHECK-SAME: #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>>, vector<16xi1>
+func.func @store_scatter_with_transpose_effect(%arg0: memref<128xf32>) {
+  %cst = arith.constant dense<1.000000e+00> : vector<8x16xf32>
+  %cst_0 = arith.constant dense<true> : vector<16xi1>
+  %cst_1 = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
+  %0 = xegpu.create_tdesc %arg0, %cst_1 : memref<128xf32>, vector<16xindex> -> !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size = 8 : i64>>
+  xegpu.store %cst, %0, %cst_0 <{transpose}> : vector<8x16xf32>, !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size = 8 : i64>>, vector<16xi1>
+  return
+}
+
+// -----
+// 1-D scattered store of a vector that comes in as a function argument: the
+// scattered descriptor is expected to carry lane_layout = [16], lane_data = [1].
+// CHECK-LABEL: func.func @store_scatter_1d(
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: vector<16xf32>, %[[ARG1:[0-9a-zA-Z]+]]: memref<256xf32>) {
+// CHECK: xegpu.store %[[ARG0]], %{{.*}}, %{{.*}}  : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>,
+// CHECK-SAME: #xegpu.layout<lane_layout = [16], lane_data = [1]>>, vector<16xi1>
+func.func @store_scatter_1d(%arg0: vector<16xf32>, %arg1: memref<256xf32>) {
+  %cst = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
+  %cst_0 = arith.constant dense<true> : vector<16xi1>
+  %0 = xegpu.create_tdesc %arg1, %cst : memref<256xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
+  xegpu.store %arg0, %0, %cst_0  : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1>
+  return
+}
+
+// -----
+// i16 -> f16 vector.bitcast feeding both dpas operands: the 8x16 (A operand)
+// bitcast is expected to get lane_data = [1, 1], and the 16x16 (B operand)
+// bitcast lane_data = [2, 1]. These match the dpas operand layouts used by the
+// other dpas tests in this file.
+// CHECK-LABEL: func.func @vector_bitcast_i16_to_f16(
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<8x16xi16>, %[[ARG1:[0-9a-zA-Z]+]]: memref<16x16xi16>, %[[ARG2:[0-9a-zA-Z]+]]: memref<8x16xf32>) {
+// CHECK: %{{.*}} = vector.bitcast %{{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<8x16xi16> to vector<8x16xf16>
+// CHECK: %{{.*}} = vector.bitcast %{{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} : vector<16x16xi16> to vector<16x16xf16>
+func.func @vector_bitcast_i16_to_f16(%arg0: memref<8x16xi16>, %arg1: memref<16x16xi16>, %arg2: memref<8x16xf32>) {
+  %c0 = arith.constant 0 : index
+  %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xi16> -> !xegpu.tensor_desc<8x16xi16>
+  %1 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<16x16xi16> -> !xegpu.tensor_desc<16x16xi16>
+  %2 = xegpu.load_nd %0  : !xegpu.tensor_desc<8x16xi16> -> vector<8x16xi16>
+  %3 = xegpu.load_nd %1  : !xegpu.tensor_desc<16x16xi16> -> vector<16x16xi16>
+  %4 = vector.bitcast %2 : vector<8x16xi16> to vector<8x16xf16>
+  %5 = vector.bitcast %3 : vector<16x16xi16> to vector<16x16xf16>
+  %6 = xegpu.dpas %4, %5 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+  %7 = xegpu.create_nd_tdesc %arg2[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+  xegpu.store_nd %6, %7  : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+  return
+}
+
+// -----
+// arith.addf whose only use is the dpas B operand. Both 16x16 loads and the
+// addf are expected to get the B-operand layout lane_data = [2, 1]. The A and
+// C/D descriptors are expected to get lane_data = [1, 1].
+// CHECK-LABEL: func.func @binary_op_one_use(
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>,
+// CHECK-SAME: %[[ARG1:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>>,
+// CHECK-SAME: %[[ARG2:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>) {
+// CHECK: %[[T1:.*]] = xegpu.load_nd %[[ARG1]]  {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} :
+// CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>> -> vector<16x16xf16>
+// CHECK-NEXT: %[[T2:.*]] = xegpu.load_nd %[[ARG1]]  {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} :
+// CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>> -> vector<16x16xf16>
+// CHECK-NEXT: %{{.*}} = arith.addf %[[T1]], %[[T2]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} : vector<16x16xf16>
+func.func @binary_op_one_use(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: !xegpu.tensor_desc<8x16xf32>) {
+  %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+  %1 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+  %2 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+  %3 = arith.addf %1, %2 : vector<16x16xf16>
+  %4 = xegpu.dpas %0, %3 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+  xegpu.store_nd %4, %arg2  : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+  return
+}
+
+// -----
+// The addf result %2 has two users: the dpas B operand and a store_nd. The
+// expected output assigns lane_data = [1, 1] to the addf, the dpas, and every
+// descriptor argument, including %arg1 and %arg3. The [2, 1] B-operand layout
+// (seen in @binary_op_one_use) is not used here.
+// CHECK-LABEL: func.func @binary_op_multiple_uses(
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>,
+// CHECK-SAME: %[[ARG1:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>,
+// CHECK-SAME: %[[ARG2:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>,
+// CHECK-SAME: %[[ARG3:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>) {
+// CHECK: %[[T2:.*]] = arith.addf %{{.*}}, %{{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16x16xf16>
+// CHECK: %[[T3:.*]] = xegpu.dpas %{{.*}}, %[[T2]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+// CHECK-NEXT: xegpu.store_nd %[[T3]], %[[ARG2]]  : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+// CHECK-NEXT: xegpu.store_nd %[[T2]], %[[ARG3]]  : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+func.func @binary_op_multiple_uses(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: !xegpu.tensor_desc<8x16xf32>, %arg3: !xegpu.tensor_desc<16x16xf16>) {
+  %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+  %1 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+  %cst = arith.constant dense<1.000000e+00> : vector<16x16xf16>
+  %2 = arith.addf %1, %cst : vector<16x16xf16>
+  %3 = xegpu.dpas %0, %2 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+  xegpu.store_nd %3, %arg2  : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+  xegpu.store_nd %2, %arg3  : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16>
+  return
+}
+
+// -----
+// scf.for carrying two tensor descriptors and a vector<8x16xf32> accumulator
+// through iter_args. The layouts are expected to be consistent across the loop
+// init values, the region block arguments, scf.yield, and the loop results:
+//  - the A descriptor keeps lane_data = [1, 1];
+//  - the B descriptor keeps lane_data = [2, 1];
+//  - the vector result (index 2) is annotated as layout_result_2 on the scf.for.
+// CHECK-LABEL: func.func @for_op(
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<8x128xf16>, %[[ARG1:[0-9a-zA-Z]+]]: memref<128x16xf16>, %[[ARG2:[0-9a-zA-Z]+]]: memref<8x16xf32>) {
+// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%{{.*}}] : memref<8x128xf16> -> !xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+// CHECK-NEXT: %[[T1:.*]] = xegpu.create_nd_tdesc %[[ARG1]][%{{.*}}] : memref<128x16xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>>
+// CHECK-NEXT: %[[CST:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} dense<0.000000e+00> : vector<8x16xf32>
+// CHECK-NEXT: %[[T2:.*]]:3 = scf.for %{{.*}} iter_args(%[[ARG4:.*]] = %[[T0]], %[[ARG5:.*]] = %[[T1]], %[[ARG6:.*]] = %[[CST]]) ->
+// CHECK-SAME: (!xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>, !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>>, vector<8x16xf32>) {
+// CHECK-NEXT:   %[[T4:.*]] = xegpu.load_nd %[[ARG4]]  {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} :
+// CHECK-SAME: !xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<8x16xf16>
+// CHECK-NEXT:   %[[T5:.*]] = xegpu.load_nd %[[ARG5]]  {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} :
+// CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>> -> vector<16x16xf16>
+// CHECK-NEXT:   %[[T6:.*]] = xegpu.dpas %[[T4]], %[[T5]], %[[ARG6]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} :
+// CHECK-SAME: vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+// CHECK-NEXT:   %[[T7:.*]] = xegpu.update_nd_offset %[[ARG4]], [{{.*}}] : !xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+// CHECK-NEXT:   %[[T8:.*]] = xegpu.update_nd_offset %[[ARG5]], [{{.*}}] : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>>
+// CHECK-NEXT:   scf.yield %[[T7]], %[[T8]], %[[T6]] : !xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>,
+// CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>>, vector<8x16xf32>
+// CHECK-NEXT: } {layout_result_2 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+// CHECK-NEXT: %[[T3:.*]] = xegpu.create_nd_tdesc %[[ARG2]][{{.*}}] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+// CHECK-NEXT: xegpu.store_nd %[[T2]]#2, %[[T3]] : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+func.func @for_op(%arg0: memref<8x128xf16>, %arg1: memref<128x16xf16>, %arg2: memref<8x16xf32>) {
+  %c0 = arith.constant 0 : index
+  %c128 = arith.constant 128 : index
+  %c16 = arith.constant 16 : index
+  %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x128xf16> -> !xegpu.tensor_desc<8x16xf16>
+  %1 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<128x16xf16> -> !xegpu.tensor_desc<16x16xf16>
+  %cst = arith.constant dense<0.000000e+00> : vector<8x16xf32>
+  %2:3 = scf.for %arg3 = %c0 to %c128 step %c16 iter_args(%arg4 = %0, %arg5 = %1, %arg6 = %cst) -> (!xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<16x16xf16>, vector<8x16xf32>) {
+    %4 = xegpu.load_nd %arg4  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+    %5 = xegpu.load_nd %arg5  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+    %6 = xegpu.dpas %4, %5, %arg6 : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+    %7 = xegpu.update_nd_offset %arg4, [%c0, %c16] : !xegpu.tensor_desc<8x16xf16>
+    %8 = xegpu.update_nd_offset %arg5, [%c16, %c0] : !xegpu.tensor_desc<16x16xf16>
+    scf.yield %7, %8, %6 : !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<16x16xf16>, vector<8x16xf32>
+  }
+  %3 = xegpu.create_nd_tdesc %arg2[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+  xegpu.store_nd %2#2, %3  : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+  return
+}
+
+// -----
+// scf.if whose vector result is used only as the dpas B operand. The loads in
+// both branches and the scf.if result (layout_result_0) are expected to get
+// lane_data = [2, 1], and so is the %arg1 descriptor type.
+// CHECK-LABEL: func.func @if_single_use(
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>,
+// CHECK-SAME: %[[ARG1:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>>,
+// CHECK-SAME: %[[ARG2:[0-9a-zA-Z]+]]: i1, %[[ARG3:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>) {
+// CHECK:  %{{.*}} = scf.if %[[ARG2]] -> (vector<16x16xf16>) {
+// CHECK-NEXT:    %[[T3:.*]] = xegpu.load_nd %[[ARG1]]  {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} :
+// CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>> -> vector<16x16xf16>
+// CHECK-NEXT:    scf.yield %[[T3]] : vector<16x16xf16>
+// CHECK-NEXT:  } else {
+// CHECK-NEXT:    %[[T4:.*]] = xegpu.load_nd %[[ARG1]]  {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} :
+// CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>> -> vector<16x16xf16>
+// CHECK-NEXT:    scf.yield %[[T4]] : vector<16x16xf16>
+// CHECK-NEXT:  } {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>}
+func.func @if_single_use(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: i1, %arg3: !xegpu.tensor_desc<8x16xf32>) {
+  %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+  %1 = scf.if %arg2 -> (vector<16x16xf16>) {
+    %3 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+    scf.yield %3 : vector<16x16xf16>
+  } else {
+    %3 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+    scf.yield %3 : vector<16x16xf16>
+  }
+  %2 = xegpu.dpas %0, %1 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+  xegpu.store_nd %2, %arg3  : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+  return
+}
+
+// -----
+// Like @if_single_use, but the scf.if result is also stored through %arg4.
+// The expected layout for the branch loads, the scf.if result, and the %arg1
+// and %arg4 descriptors becomes lane_data = [1, 1] instead of [2, 1].
+// CHECK-LABEL: func.func @if_multiple_uses(
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>,
+// CHECK-SAME: %[[ARG1:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>,
+// CHECK-SAME: %[[ARG2:[0-9a-zA-Z]+]]: i1, %[[ARG3:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>,
+// CHECK-SAME: %[[ARG4:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>) {
+// CHECK: %[[T1:.*]] = scf.if %[[ARG2]] -> (vector<16x16xf16>) {
+// CHECK-NEXT:       %[[T3:.*]] = xegpu.load_nd %[[ARG1]]  {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} :
+// CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<16x16xf16>
+// CHECK-NEXT:       scf.yield %[[T3]] : vector<16x16xf16>
+// CHECK-NEXT:     } else {
+// CHECK-NEXT:       %[[T4:.*]] = xegpu.load_nd %[[ARG1]]  {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} :
+// CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<16x16xf16>
+// CHECK-NEXT:       scf.yield %[[T4]] : vector<16x16xf16>
+// CHECK-NEXT:     } {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+func.func @if_multiple_uses(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: i1, %arg3: !xegpu.tensor_desc<8x16xf32>, %arg4: !xegpu.tensor_desc<16x16xf16>) {
+  %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+  %1 = scf.if %arg2 -> (vector<16x16xf16>) {
+    %3 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+    scf.yield %3 : vector<16x16xf16>
+  } else {
+    %3 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+    scf.yield %3 : vector<16x16xf16>
+  }
+  %2 = xegpu.dpas %0, %1 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+  xegpu.store_nd %2, %arg3  : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+  xegpu.store_nd %1, %arg4  : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16>
+  return
+}
+
+// -----
+// multi_reduction over dimension 0 of a 16x16 vector, with the 1-D result
+// stored via store_nd. The reduction result is expected to get
+// lane_layout = [16], lane_data = [1], matching the 1-D store descriptor.
+// CHECK-LABEL: func.func @vector_outer_reduction(
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: vector<16x16xf32>, %[[ARG1:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>) {
+// CHECK: %{{.*}} = vector.multi_reduction <add>, %[[ARG0]], %{{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} [0] : vector<16x16xf32> to vector<16xf32>
+func.func @vector_outer_reduction(%arg0: vector<16x16xf32>, %arg1: !xegpu.tensor_desc<16xf32>) {
+  %cst = arith.constant dense<0.000000e+00> : vector<16xf32>
+  %0 = vector.multi_reduction <add>, %arg0, %cst [0] : vector<16x16xf32> to vector<16xf32>
+  xegpu.store_nd %0, %arg1  : vector<16xf32>, !xegpu.tensor_desc<16xf32>
+  return
+}
+
+// -----
+// multi_reduction over dimension 1 of a 16x16 vector, with the 1-D result
+// stored via store_nd. The reduction result is expected to get
+// lane_layout = [16], lane_data = [1], matching the 1-D store descriptor.
+// CHECK-LABEL: func.func @vector_inner_reduction(
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: vector<16x16xf32>, %[[ARG1:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>) {
+// CHECK: %{{.*}} = vector.multi_reduction <add>, %[[ARG0]], %{{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} [1] : vector<16x16xf32> to vector<16xf32>
+func.func @vector_inner_reduction(%arg0: vector<16x16xf32>, %arg1: !xegpu.tensor_desc<16xf32>) {
+  %cst = arith.constant dense<0.000000e+00> : vector<16xf32>
+  %0 = vector.multi_reduction <add>, %arg0, %cst [1] : vector<16x16xf32> to vector<16xf32>
+  xegpu.store_nd %0, %arg1  : vector<16xf32>, !xegpu.tensor_desc<16xf32>
+  return
+}
+
+// -----
+// update_nd_offset on a 1-D descriptor that is later stored to. Both the
+// created descriptor and the offset-updated descriptor are expected to carry
+// lane_layout = [16], lane_data = [1] in their types.
+// CHECK-LABEL: func.func @update_nd_offset_1d(
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<256xf32>) {
+// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%{{.*}}] : memref<256xf32> -> !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
+// CHECK-NEXT: %[[T1:.*]] = xegpu.update_nd_offset %[[T0]], [%{{.*}}] : !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
+func.func @update_nd_offset_1d(%arg0: memref<256xf32>){
+  %c0 = arith.constant 0 : index
+  %c32 = arith.constant 32 : index
+  %1 = arith.constant dense<1.000000e+00> : vector<16xf32>
+  %0 = xegpu.create_nd_tdesc %arg0[%c0] : memref<256xf32> -> !xegpu.tensor_desc<16xf32>
+  %2 = xegpu.update_nd_offset %0, [%c32] : !xegpu.tensor_desc<16xf32>
+  xegpu.store_nd %1, %2 : vector<16xf32>, !xegpu.tensor_desc<16xf32>
+  return
+}
+
+// -----
+// update_nd_offset on a 2-D descriptor that is later stored to. Both the
+// created descriptor and the offset-updated descriptor are expected to carry
+// lane_layout = [1, 16], lane_data = [1, 1] in their types.
+// CHECK-LABEL: func.func @update_nd_offset_2d(
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<256x256xf32>) {
+// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%{{.*}}, %{{.*}}] : memref<256x256xf32> -> !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+// CHECK-NEXT: %[[T1:.*]] = xegpu.update_nd_offset %[[T0]], [%{{.*}}, %{{.*}}] : !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+func.func @update_nd_offset_2d(%arg0: memref<256x256xf32>){
+  %c0 = arith.constant 0 : index
+  %c32 = arith.constant 32 : index
+  %1 = arith.constant dense<1.000000e+00> : vector<16x16xf32>
+  %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<256x256xf32> -> !xegpu.tensor_desc<16x16xf32>
+  %2 = xegpu.update_nd_offset %0, [%c32, %c32] : !xegpu.tensor_desc<16x16xf32>
+  xegpu.store_nd %1, %2 : vector<16x16xf32>, !xegpu.tensor_desc<16x16xf32>
+  return
+}
+
+// -----
+// A 2-D descriptor whose only user is prefetch_nd (no load or store) is
+// expected to get lane_layout = [1, 16], lane_data = [1, 1]. The l1/l2 cache
+// hint attributes are expected to appear unchanged in the output.
+// CHECK-LABEL: func.func @prefetch_2d(
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<256x256xf16>) {
+// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%{{.*}}, %{{.*}}] : memref<256x256xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+// CHECK-NEXT: xegpu.prefetch_nd %[[T0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+func.func @prefetch_2d(%arg0: memref<256x256xf16>){
+  %c0 = arith.constant 0 : index
+  %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<256x256xf16> -> !xegpu.tensor_desc<16x16xf16>
+  xegpu.prefetch_nd %0 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>: !xegpu.tensor_desc<16x16xf16>
+  return
+}
+
+// -----
+// A 1-D descriptor whose only user is prefetch_nd (no load or store) is
+// expected to get lane_layout = [16], lane_data = [1]. The l1/l2 cache hint
+// attributes are expected to appear unchanged in the output.
+// CHECK-LABEL: func.func @prefetch_1d(
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<256xf16>) {
+// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%{{.*}}] : memref<256xf16> -> !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
+// CHECK-NEXT: xegpu.prefetch_nd %[[T0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
+func.func @prefetch_1d(%arg0: memref<256xf16>){
+  %c0 = arith.constant 0 : index
+  %0 = xegpu.create_nd_tdesc %arg0[%c0] : memref<256xf16> -> !xegpu.tensor_desc<16xf16>
+  xegpu.prefetch_nd %0 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>: !xegpu.tensor_desc<16xf16>
+  return
+}
+
+// -----
+// scf.while threading a vector<16xf32>, an i32 counter and a 1-D descriptor
+// through scf.condition (before region) and scf.yield (after region).
+// The descriptor type is expected to carry lane_layout = [16], lane_data = [1]
+// consistently in the while signature, scf.condition, the after-region block
+// argument, and scf.yield. The vector result (index 0) is expected to be
+// annotated as layout_result_0 on the scf.while.
+// CHECK-LABEL: func.func @test_scf_while_and_condition(
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<256xf32>, %[[ARG1:[0-9a-zA-Z]+]]: memref<256xf32>) {
+// CHECK: %{{.*}}:3 = scf.while ({{.*}}) : (vector<16xf32>, i32, !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>)
+// CHECK-SAME: -> (vector<16xf32>, i32, !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>) {
+// CHECK:       scf.condition(%{{.*}}) {{.*}} : vector<16xf32>, i32, !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
+// CHECK-NEXT: } do {
+// CHECK-NEXT: ^bb0(%{{.*}}: vector<16xf32>, %{{.*}}: i32, %{{.*}}: !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>):
+// CHECK:     scf.yield {{.*}} : vector<16xf32>, i32, !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
+// CHECK-NEXT: } attributes {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
+func.func @test_scf_while_and_condition(%arg0: memref<256xf32>, %arg1: memref<256xf32>) {
+  %c0 = arith.constant 0 : i32
+  %c16 = arith.constant 16 : i32
+  %c256 = arith.constant 256 : i32
+  %0 = xegpu.create_nd_tdesc %arg0[0] : memref<256xf32> -> !xegpu.tensor_desc<16xf32>
+  %1 = xegpu.load_nd %0  : !xegpu.tensor_desc<16xf32> -> vector<16xf32>
+  %2 = xegpu.create_nd_tdesc %arg1[0] : memref<256xf32> -> !xegpu.tensor_desc<16xf32>
+
+  %3:3 = scf.while (%arg2 = %1, %arg3 = %c0, %arg4 = %0) : (vector<16xf32>, i32, !xegpu.tensor_desc<16xf32>)
+    -> (vector<16xf32>, i32, !xegpu.tensor_desc<16xf32>) {
+    %4 = arith.cmpi slt, %arg3, %c256 : i32
+    scf.condition(%4) %arg2, %arg3, %arg4 : vector<16xf32>, i32, !xegpu.tensor_desc<16xf32>
+  } do {
+  ^bb0(%arg2: vector<16xf32>, %arg3: i32, %arg4: !xegpu.tensor_desc<16xf32>):
+    xegpu.store_nd %arg2, %2  : vector<16xf32>, !xegpu.tensor_desc<16xf32>
+    %4 = arith.addi %arg3, %c16 : i32
+    %5 = xegpu.update_nd_offset %arg4, [16] : !xegpu.tensor_desc<16xf32>
+    %6 = xegpu.load_nd %5  : !xegpu.tensor_desc<16xf32> -> vector<16xf32>
+    scf.yield %6, %4, %5 : vector<16xf32>, i32, !xegpu.tensor_desc<16xf32>
+  }
+  return
+}

diff  --git a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
new file mode 100644
index 0000000000000..a59633b0cbd9a
--- /dev/null
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
@@ -0,0 +1,280 @@
+// RUN: mlir-opt -xegpu-subgroup-distribute -canonicalize -cse -split-input-file %s | FileCheck %s
+
+// CHECK-LABEL: gpu.func @store_nd_1d
+// CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: memref<16xf32>) {
+// CHECK-DAG: %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<1xf32>
+// CHECK-DAG: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%{{.*}}] : memref<16xf32> -> !xegpu.tensor_desc<16xf32>
+// CHECK: xegpu.store_nd %[[CST]], %[[T0]]  : vector<1xf32>, !xegpu.tensor_desc<16xf32>
+// CHECK: gpu.return
+gpu.module @test {
+  gpu.func @store_nd_1d(%arg0: memref<16xf32>) {
+    %c0 = arith.constant 0 : index
+    %cst = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<1.000000e+00> : vector<16xf32>
+    %0 = xegpu.create_nd_tdesc %arg0[%c0] : memref<16xf32> -> !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
+    xegpu.store_nd %cst, %0  : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
+    gpu.return
+  }
+}
+
+// -----
+// CHECK-LABEL: gpu.func @store_nd_2d
+// CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: memref<16x16xf16>) {
+// CHECK-DAG: %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<16xf16>
+// CHECK-DAG: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%{{.*}}] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
+// CHECK: xegpu.store_nd %[[CST]], %[[T0]]  : vector<16xf16>, !xegpu.tensor_desc<16x16xf16>
+gpu.module @test {
+  gpu.func @store_nd_2d(%arg0: memref<16x16xf16>) {
+    %c0 = arith.constant 0 : index
+    %cst = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} dense<1.000000e+00> : vector<16x16xf16>
+    %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    xegpu.store_nd %cst, %0 : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    gpu.return
+  }
+}
+
+
+
+// -----
+// CHECK-LABEL: gpu.func @load_nd_1d
+// CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: memref<16xf32>, %[[ARG1:[0-9a-zA-Z]+]]: memref<16xf32>) {
+// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%{{.*}}] : memref<16xf32> -> !xegpu.tensor_desc<16xf32>
+// CHECK-DAG: %[[T1:.*]] = xegpu.load_nd %[[T0]]  : !xegpu.tensor_desc<16xf32> -> vector<1xf32>
+// CHECK-DAG: %[[T2:.*]] = xegpu.create_nd_tdesc %[[ARG1]][%{{.*}}] : memref<16xf32> -> !xegpu.tensor_desc<16xf32>
+// CHECK: xegpu.store_nd %[[T1]], %[[T2]]  : vector<1xf32>, !xegpu.tensor_desc<16xf32>
+gpu.module @test {
+  gpu.func @load_nd_1d(%arg0: memref<16xf32>, %arg1: memref<16xf32>) {
+    %c0 = arith.constant 0 : index
+    %0 = xegpu.create_nd_tdesc %arg0[%c0] : memref<16xf32> -> !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
+    %1 = xegpu.load_nd %0  {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} : !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>> -> vector<16xf32>
+    %2 = xegpu.create_nd_tdesc %arg1[%c0] : memref<16xf32> -> !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
+    xegpu.store_nd %1, %2 : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
+    gpu.return
+  }
+}
+
+// -----
+// CHECK-LABEL: gpu.func @load_nd_2d
+// CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: memref<16x16xf16>, %[[ARG1:[0-9a-zA-Z]+]]: memref<16x16xf16>) {
+// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%{{.*}}] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
+// CHECK-DAG: %[[T1:.*]] = xegpu.load_nd %[[T0]]  : !xegpu.tensor_desc<16x16xf16> -> vector<16xf16>
+// CHECK-DAG: %[[T2:.*]] = xegpu.create_nd_tdesc %[[ARG1]][%{{.*}}] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
+// CHECK: xegpu.store_nd %[[T1]], %[[T2]]  : vector<16xf16>, !xegpu.tensor_desc<16x16xf16>
+gpu.module @test {
+  gpu.func @load_nd_2d(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>) {
+    %c0 = arith.constant 0 : index
+    %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    %1 = xegpu.load_nd %0  {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<16x16xf16>
+    %2 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    xegpu.store_nd %1, %2 : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    gpu.return
+  }
+}
+
+// -----
+// CHECK-LABEL: gpu.func @load_nd_array_length
+// CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: memref<16x16xf16>, %[[ARG1:[0-9a-zA-Z]+]]: memref<16x16xf16>) {
+// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%{{.*}}] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2 : i64>>
+// CHECK: %[[T1:.*]] = xegpu.load_nd %[[T0]]  : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2 : i64>> -> vector<32xf16>
+// CHECK: %[[T2:.*]] = vector.shape_cast %[[T1]] : vector<32xf16> to vector<2x16x1xf16>
+// CHECK: %[[T3:.*]] = vector.extract %[[T2]][0] : vector<16x1xf16> from vector<2x16x1xf16>
+// CHECK-DAG: %[[T4:.*]] = xegpu.create_nd_tdesc %[[ARG1]][%{{.*}}] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
+// CHECK-DAG: %[[T5:.*]] = vector.shape_cast %[[T3]] : vector<16x1xf16> to vector<16xf16>
+// CHECK: xegpu.store_nd %[[T5]], %[[T4]]  : vector<16xf16>, !xegpu.tensor_desc<16x16xf16>
+gpu.module @test {
+  gpu.func @load_nd_array_length(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>) {
+    %c0 = arith.constant 0 : index
+    %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2 : i64>, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    %1 = xegpu.load_nd %0  {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2 : i64>, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<2x16x16xf16>
+    %2 = vector.extract %1[%c0] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16x16xf16> from vector<2x16x16xf16>
+    %3 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    xegpu.store_nd %2, %3 : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    gpu.return
+  }
+}
+
+// -----
+// CHECK-LABEL: gpu.func @load_dpas_store
+// CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: memref<8x16xf16>, %[[ARG1:[0-9a-zA-Z]+]]: memref<16x16xf16>, %[[ARG2:[0-9a-zA-Z]+]]: memref<8x16xf32>) {
+// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG1]][%{{.*}}] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
+// CHECK: %[[T1:.*]] = xegpu.load_nd %[[T0]] <{packed}> : !xegpu.tensor_desc<16x16xf16> -> vector<16xf16>
+// CHECK: %[[T2:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%{{.*}}] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
+// CHECK: %[[T3:.*]] = xegpu.load_nd %[[T2]]  : !xegpu.tensor_desc<8x16xf16> -> vector<8xf16>
+// CHECK-DAG: %[[T4:.*]] = xegpu.dpas %[[T3]], %[[T1]] : vector<8xf16>, vector<16xf16> -> vector<8xf32>
+// CHECK-DAG: %[[T5:.*]] = xegpu.create_nd_tdesc %[[ARG2]][%{{.*}}] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+// CHECK: xegpu.store_nd %[[T4]], %[[T5]]  : vector<8xf32>, !xegpu.tensor_desc<8x16xf32>
+gpu.module @test {
+  gpu.func @load_dpas_store(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<8x16xf32>) {
+    %c0 = arith.constant 0 : index
+    %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    %1 = xegpu.load_nd %0  {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : !xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<8x16xf16>
+    %2 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>>
+    %3 = xegpu.load_nd %2  {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>> -> vector<16x16xf16>
+    %4 = xegpu.dpas %1, %3 {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+    %5 = xegpu.create_nd_tdesc %arg2[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    xegpu.store_nd %4, %5  : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    gpu.return
+  }
+}
+
+
+// -----
+// CHECK-LABEL: gpu.func @load_dpas_postop_store
+// CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: memref<8x16xf16>, %[[ARG1:[0-9a-zA-Z]+]]: memref<16x16xf16>, %[[ARG2:[0-9a-zA-Z]+]]: memref<8x16xf32>) {
+// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG1]][%{{.*}}] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
+// CHECK: %[[T1:.*]] = xegpu.load_nd %[[T0]] <{packed}> : !xegpu.tensor_desc<16x16xf16> -> vector<16xf16>
+// CHECK: %[[T2:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%{{.*}}] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
+// CHECK: %[[T3:.*]] = xegpu.load_nd %[[T2]]  : !xegpu.tensor_desc<8x16xf16> -> vector<8xf16>
+// CHECK-DAG: %[[T4:.*]] = xegpu.dpas %[[T3]], %[[T1]] : vector<8xf16>, vector<16xf16> -> vector<8xf32>
+// CHECK: %[[T5:.*]] = vector.shape_cast %[[T4]] : vector<8xf32> to vector<8x1xf32>
+// CHECK: %[[T6:.*]] = math.exp %[[T5]] {{{.*}}} : vector<8x1xf32>
+// CHECK-DAG: %[[T8:.*]] = vector.shape_cast %[[T6]] : vector<8x1xf32> to vector<8xf32>
+// CHECK-DAG: %[[T7:.*]] = xegpu.create_nd_tdesc %[[ARG2]][{{.*}}] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+// CHECK: xegpu.store_nd %[[T8]], %[[T7]]  : vector<8xf32>, !xegpu.tensor_desc<8x16xf32>
+gpu.module @test {
+  gpu.func @load_dpas_postop_store(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<8x16xf32>) {
+    %c0 = arith.constant 0 : index
+    %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    %1 = xegpu.load_nd %0  {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : !xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<8x16xf16>
+    %2 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>>
+    %3 = xegpu.load_nd %2  {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>> -> vector<16x16xf16>
+    %4 = xegpu.dpas %1, %3 {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+    %5 = math.exp %4 {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<8x16xf32>
+    %6 = xegpu.create_nd_tdesc %arg2[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    xegpu.store_nd %5, %6 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    gpu.return
+  }
+}
+
+// -----
+// CHECK-LABEL: gpu.func @create_nd_tdesc_non_memref
+// CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: ui64, %[[ARG1:[0-9a-zA-Z]+]]: ui64, %[[ARG2:[0-9a-zA-Z]+]]: index,
+// CHECK-SAME: %[[ARG3:[0-9a-zA-Z]+]]: index, %[[ARG4:[0-9a-zA-Z]+]]: index,
+// CHECK-SAME: %[[ARG5:[0-9a-zA-Z]+]]: index, %[[ARG6:[0-9a-zA-Z]+]]: index, %[[ARG7:[0-9a-zA-Z]+]]: index) {
+// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][{{.*}}], [%[[ARG2]], %[[ARG3]]], [%[[ARG4]], %[[ARG5]]] : ui64 -> !xegpu.tensor_desc<16x16xf16>
+// CHECK: %[[T1:.*]] = xegpu.load_nd %[[T0]]  : !xegpu.tensor_desc<16x16xf16> -> vector<16xf16>
+// CHECK: %[[T2:.*]] = xegpu.create_nd_tdesc %[[ARG1]][{{.*}}], [%[[ARG2]], %[[ARG3]]], [%[[ARG4]], %[[ARG5]]] : ui64 -> !xegpu.tensor_desc<16x16xf16>
+// CHECK: xegpu.store_nd %[[T1]], %[[T2]]  : vector<16xf16>, !xegpu.tensor_desc<16x16xf16>
+gpu.module @test {
+  gpu.func @create_nd_tdesc_non_memref(%arg0: ui64, %arg1: ui64, %arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: index, %arg7: index) {
+    %c0 = arith.constant 0 : index
+    %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0], [%arg2, %arg3], [%arg4, %arg5] : ui64 -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    %1 = xegpu.load_nd %0  {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<16x16xf16>
+    %2 = xegpu.create_nd_tdesc %arg1[%c0, %c0], [%arg2, %arg3], [%arg4, %arg5] : ui64 -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    xegpu.store_nd %1, %2 : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    gpu.return
+  }
+}
+
+// -----
+// TODO: gemm does not use update_nd_offset because of an issue in scf-for distribution.
+// CHECK-LABEL: gpu.func @gemm
+// CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: memref<1024x1024xbf16>, %[[ARG1:[0-9a-zA-Z]+]]: memref<1024x1024xbf16>, %[[ARG2:[0-9a-zA-Z]+]]: memref<1024x1024xf32>) {
+// CHECK-DAG: %[[BLOCK_ID_X:.*]] = gpu.block_id x
+// CHECK-DAG: %[[BLOCK_ID_Y:.*]] = gpu.block_id y
+// CHECK-DAG: %[[Y_COORD:.*]] = arith.muli %[[BLOCK_ID_Y]], %c16 : index
+// CHECK-DAG: %[[X_COORD:.*]] = arith.muli %[[BLOCK_ID_X]], %c8 : index
+// CHECK: %[[T2:.*]] = xegpu.create_nd_tdesc %[[ARG2]][%[[X_COORD]], %[[Y_COORD]]] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32>
+// CHECK-NEXT: %[[T3:.*]] = xegpu.load_nd %[[T2]] : !xegpu.tensor_desc<8x16xf32> -> vector<8xf32>
+// CHECK-NEXT: %[[T4:.*]] = vector.shape_cast %[[T3]] : vector<8xf32> to vector<8x1xf32>
+// CHECK: %[[T5:.*]] = scf.for %[[K:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG4:.*]] = %[[T4]]) -> (vector<8x1xf32>) {
+// CHECK-DAG: %[[T10:.*]] = xegpu.create_nd_tdesc %[[ARG1]][%[[K]], %[[Y_COORD]]] : memref<1024x1024xbf16> -> !xegpu.tensor_desc<16x16xbf16>
+// CHECK-DAG: %[[T11:.*]] = xegpu.load_nd %[[T10]] <{packed}> : !xegpu.tensor_desc<16x16xbf16> -> vector<16xbf16>
+// CHECK-DAG: %[[T12:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%[[X_COORD]], %[[K]]] : memref<1024x1024xbf16> -> !xegpu.tensor_desc<8x16xbf16>
+// CHECK-DAG: %[[T13:.*]] = xegpu.load_nd %[[T12]] : !xegpu.tensor_desc<8x16xbf16> -> vector<8xbf16>
+// CHECK-DAG: %[[T14:.*]] = vector.shape_cast %[[ARG4]] : vector<8x1xf32> to vector<8xf32>
+// CHECK-NEXT: %[[T15:.*]] = xegpu.dpas %[[T13]], %[[T11]], %[[T14]] : vector<8xbf16>, vector<16xbf16>, vector<8xf32> -> vector<8xf32>
+// CHECK-NEXT: %[[T16:.*]] = vector.shape_cast %[[T15]] : vector<8xf32> to vector<8x1xf32>
+// CHECK-NEXT: scf.yield %[[T16]] : vector<8x1xf32>
+// CHECK-NEXT: }
+// CHECK-NEXT: %[[T9:.*]] = vector.shape_cast %[[T5]] : vector<8x1xf32> to vector<8xf32>
+// CHECK-NEXT: xegpu.store_nd %[[T9]], %[[T2]] : vector<8xf32>, !xegpu.tensor_desc<8x16xf32>
+gpu.module @test {
+gpu.func @gemm(%arg0: memref<1024x1024xbf16>, %arg1: memref<1024x1024xbf16>, %arg2: memref<1024x1024xf32>){
+  %c0 = arith.constant 0 : index
+  %c16 = arith.constant 16 : index
+  %c8 = arith.constant 8 : index
+  %c1024 = arith.constant 1024 : index
+  %block_id_x = gpu.block_id  x
+  %block_id_y = gpu.block_id  y
+  %0 = arith.muli %block_id_x, %c8 : index
+  %1 = arith.muli %block_id_y, %c16 : index
+  %2 = xegpu.create_nd_tdesc %arg2[%0, %1] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+  %3 = xegpu.load_nd %2  {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<8x16xf32>
+  %4 = scf.for %arg3 = %c0 to %c1024 step %c16 iter_args(%arg4 = %3) -> (vector<8x16xf32>) {
+    %5 = xegpu.create_nd_tdesc %arg0[%0, %arg3] : memref<1024x1024xbf16> -> !xegpu.tensor_desc<8x16xbf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    %6 = xegpu.create_nd_tdesc %arg1[%arg3, %1] : memref<1024x1024xbf16> -> !xegpu.tensor_desc<16x16xbf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>>
+    %7 = xegpu.load_nd %5  {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : !xegpu.tensor_desc<8x16xbf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<8x16xbf16>
+    %8 = xegpu.load_nd %6  {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} : !xegpu.tensor_desc<16x16xbf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>> -> vector<16x16xbf16>
+    %9 = xegpu.dpas %7, %8, %arg4 {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<8x16xbf16>, vector<16x16xbf16>, vector<8x16xf32> -> vector<8x16xf32>
+    scf.yield %9 : vector<8x16xf32>
+  } {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+  xegpu.store_nd %4, %2 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+  gpu.return
+}
+}
+
+// -----
+// CHECK-LABEL: gpu.func @update_nd_offset_1d(
+// CHECK: %[[ARG0:[0-9a-zA-Z]+]]: memref<256xf32>) {
+// CHECK: %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<1xf32>
+// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%{{.*}}] : memref<256xf32> -> !xegpu.tensor_desc<16xf32>
+// CHECK: %[[T1:.*]] = xegpu.update_nd_offset %[[T0]], [%c32] : !xegpu.tensor_desc<16xf32>
+// CHECK: xegpu.store_nd %[[CST]], %[[T1]]  : vector<1xf32>, !xegpu.tensor_desc<16xf32>
+gpu.module @test {
+  gpu.func @update_nd_offset_1d(%arg0: memref<256xf32>) {
+    %c0 = arith.constant 0 : index
+    %c32 = arith.constant 32 : index
+    %cst = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<1.000000e+00> : vector<16xf32>
+    %0 = xegpu.create_nd_tdesc %arg0[%c0] : memref<256xf32> -> !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
+    %1 = xegpu.update_nd_offset %0, [%c32] : !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
+    xegpu.store_nd %cst, %1  : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
+    gpu.return
+  }
+}
+
+// -----
+// CHECK-LABEL: gpu.func @update_nd_offset_2d
+// CHECK: %[[ARG0:[0-9a-zA-Z]+]]: memref<256x256xf32>) {
+// CHECK: %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<16xf32>
+// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%{{.*}}] : memref<256x256xf32> -> !xegpu.tensor_desc<16x16xf32>
+// CHECK: %[[T1:.*]] = xegpu.update_nd_offset %[[T0]], [%c32, %c32] : !xegpu.tensor_desc<16x16xf32>
+// CHECK: xegpu.store_nd %[[CST]], %[[T1]]  : vector<16xf32>, !xegpu.tensor_desc<16x16xf32>
+gpu.module @test {
+  gpu.func @update_nd_offset_2d(%arg0: memref<256x256xf32>) {
+    %c0 = arith.constant 0 : index
+    %c32 = arith.constant 32 : index
+    %cst = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} dense<1.000000e+00> : vector<16x16xf32>
+    %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<256x256xf32> -> !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    %1 = xegpu.update_nd_offset %0, [%c32, %c32] : !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    xegpu.store_nd %cst, %1  : vector<16x16xf32>, !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    gpu.return
+  }
+}
+
+// -----
+// CHECK-LABEL: gpu.func @prefetch_2d
+// CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: memref<256x256xf16>) {
+// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%{{.*}}] : memref<256x256xf16> -> !xegpu.tensor_desc<16x16xf16>
+// CHECK: xegpu.prefetch_nd %[[T0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<16x16xf16>
+gpu.module @test {
+  gpu.func @prefetch_2d(%arg0: memref<256x256xf16>) {
+    %c0 = arith.constant 0 : index
+    %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<256x256xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    xegpu.prefetch_nd %0 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    gpu.return
+  }
+}
+
+// -----
+// CHECK-LABEL: gpu.func @prefetch_1d
+// CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: memref<256xf16>) {
+// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%{{.*}}] : memref<256xf16> -> !xegpu.tensor_desc<16xf16>
+// CHECK: xegpu.prefetch_nd %[[T0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<16xf16>
+gpu.module @test {
+  gpu.func @prefetch_1d(%arg0: memref<256xf16>) {
+    %c0 = arith.constant 0 : index
+    %0 = xegpu.create_nd_tdesc %arg0[%c0] : memref<256xf16> -> !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
+    xegpu.prefetch_nd %0 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
+    gpu.return
+  }
+}

diff  --git a/mlir/test/Dialect/XeGPU/subgroup-distribution.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribution.mlir
deleted file mode 100644
index e5606c5642505..0000000000000
--- a/mlir/test/Dialect/XeGPU/subgroup-distribution.mlir
+++ /dev/null
@@ -1,275 +0,0 @@
-// RUN: mlir-opt -xegpu-subgroup-distribute -cse -split-input-file %s | FileCheck %s
-
-// CHECK-LABEL: gpu.func @store_nd_1d
-// CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: memref<16xf32>) {
-// CHECK-DAG: %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<1xf32>
-// CHECK-DAG: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%{{.*}}] : memref<16xf32> -> !xegpu.tensor_desc<16xf32>
-// CHECK: xegpu.store_nd %[[CST]], %[[T0]]  : vector<1xf32>, !xegpu.tensor_desc<16xf32>
-// CHECK: gpu.return
-gpu.module @test {
-gpu.func @store_nd_1d(%arg0: memref<16xf32>){
-  %c0 = arith.constant 0 : index
-  %1 = arith.constant dense<1.000000e+00> : vector<16xf32>
-  %0 = xegpu.create_nd_tdesc %arg0[%c0] : memref<16xf32> -> !xegpu.tensor_desc<16xf32>
-  xegpu.store_nd %1, %0 : vector<16xf32>, !xegpu.tensor_desc<16xf32>
-  gpu.return
-}
-}
-
-// -----
-// CHECK-LABEL: gpu.func @store_nd_2d
-// CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: memref<16x16xf16>) {
-// CHECK-DAG: %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<16xf16>
-// CHECK-DAG: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%{{.*}}] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
-// CHECK: xegpu.store_nd %[[CST]], %[[T0]]  : vector<16xf16>, !xegpu.tensor_desc<16x16xf16>
-gpu.module @test {
-gpu.func @store_nd_2d(%arg0: memref<16x16xf16>){
-  %c0 = arith.constant 0 : index
-  %1 = arith.constant dense<1.000000e+00> : vector<16x16xf16>
-  %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
-  xegpu.store_nd %1, %0 : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16>
-  gpu.return
-}
-}
-
-
-
-// -----
-// CHECK-LABEL: gpu.func @load_nd_1d
-// CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: memref<16xf32>, %[[ARG1:[0-9a-zA-Z]+]]: memref<16xf32>) {
-// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%{{.*}}] : memref<16xf32> -> !xegpu.tensor_desc<16xf32>
-// CHECK-DAG: %[[T1:.*]] = xegpu.load_nd %[[T0]]  : !xegpu.tensor_desc<16xf32> -> vector<1xf32>
-// CHECK-DAG: %[[T2:.*]] = xegpu.create_nd_tdesc %[[ARG1]][%{{.*}}] : memref<16xf32> -> !xegpu.tensor_desc<16xf32>
-// CHECK: xegpu.store_nd %[[T1]], %[[T2]]  : vector<1xf32>, !xegpu.tensor_desc<16xf32>
-gpu.module @test {
-gpu.func @load_nd_1d(%arg0: memref<16xf32>, %arg1: memref<16xf32>){
-  %c0 = arith.constant 0 : index
-  %0 = xegpu.create_nd_tdesc %arg0[%c0] : memref<16xf32> -> !xegpu.tensor_desc<16xf32>
-  %1 = xegpu.load_nd %0 : !xegpu.tensor_desc<16xf32> -> vector<16xf32>
-  %2 = xegpu.create_nd_tdesc %arg1[%c0] : memref<16xf32> -> !xegpu.tensor_desc<16xf32>
-  xegpu.store_nd %1, %2 : vector<16xf32>, !xegpu.tensor_desc<16xf32>
-  gpu.return
-}
-}
-
-// -----
-// CHECK-LABEL: gpu.func @load_nd_2d
-// CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: memref<16x16xf16>, %[[ARG1:[0-9a-zA-Z]+]]: memref<16x16xf16>) {
-// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%{{.*}}] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
-// CHECK-DAG: %[[T1:.*]] = xegpu.load_nd %[[T0]]  : !xegpu.tensor_desc<16x16xf16> -> vector<16xf16>
-// CHECK-DAG: %[[T2:.*]] = xegpu.create_nd_tdesc %[[ARG1]][%{{.*}}] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
-// CHECK: xegpu.store_nd %[[T1]], %[[T2]]  : vector<16xf16>, !xegpu.tensor_desc<16x16xf16>
-gpu.module @test {
-gpu.func @load_nd_2d(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>){
-  %c0 = arith.constant 0 : index
-  %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
-  %1 = xegpu.load_nd %0 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
-  %2 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
-  xegpu.store_nd %1, %2 : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16>
-  gpu.return
-}
-}
-
-// -----
-// CHECK-LABEL: gpu.func @load_nd_array_length
-// CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: memref<16x16xf16>, %[[ARG1:[0-9a-zA-Z]+]]: memref<16x16xf16>) {
-// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%{{.*}}] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2 : i64>>
-// CHECK: %[[T1:.*]] = xegpu.load_nd %[[T0]]  : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2 : i64>> -> vector<32xf16>
-// CHECK: %[[T2:.*]] = vector.shape_cast %[[T1]] : vector<32xf16> to vector<2x16x1xf16>
-// CHECK: %[[T3:.*]] = vector.extract %[[T2]][0] : vector<16x1xf16> from vector<2x16x1xf16>
-// CHECK-DAG: %[[T4:.*]] = xegpu.create_nd_tdesc %[[ARG1]][%{{.*}}] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
-// CHECK-DAG: %[[T5:.*]] = vector.shape_cast %[[T3]] : vector<16x1xf16> to vector<16xf16>
-// CHECK: xegpu.store_nd %[[T5]], %[[T4]]  : vector<16xf16>, !xegpu.tensor_desc<16x16xf16>
-gpu.module @test {
-gpu.func @load_nd_array_length(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>){
-  %c0 = arith.constant 0 : index
-  %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2 : i64>>
-  %1 = xegpu.load_nd %0 : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2 : i64>> -> vector<2x16x16xf16>
-  %2 = vector.extract %1[%c0] : vector<16x16xf16> from vector<2x16x16xf16>
-  %3 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
-  xegpu.store_nd %2, %3 : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16>
-  gpu.return
-}
-}
-
-// -----
-// CHECK-LABEL: gpu.func @dpas
-// CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: vector<8x16xf16>, %[[ARG1:[0-9a-zA-Z]+]]: vector<16x16xf16>, %[[ARG2:[0-9a-zA-Z]+]]: vector<8x16xf32>, %[[ARG3:[0-9a-zA-Z]+]]: memref<8x16xf32>) {
-// CHECK: %[[T1:.*]]:3 = gpu.warp_execute_on_lane_0(%{{.*}})[16] args(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]]
-// CHECK-SAME: vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32>, memref<8x16xf32>) -> (vector<8x1xf16>, vector<16x1xf16>, vector<8x1xf32>) {
-// CHECK: ^bb0(%[[ARG4:[0-9a-zA-Z]+]]: vector<8x16xf16>, %[[ARG5:[0-9a-zA-Z]+]]: vector<16x16xf16>, %[[ARG6:[0-9a-zA-Z]+]]: vector<8x16xf32>, %[[ARG7:[0-9a-zA-Z]+]]: memref<8x16xf32>):
-// CHECK:  gpu.yield %[[ARG4]], %[[ARG5]], %[[ARG6]] : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32>
-// CHECK: }
-// CHECK-DAG: %[[T2:.*]] = vector.shape_cast %[[T1]]#0 : vector<8x1xf16> to vector<8xf16>
-// CHECK-DAG: %[[T3:.*]] = vector.shape_cast %[[T1]]#1 : vector<16x1xf16> to vector<16xf16>
-// CHECK-DAG: %[[T4:.*]] = vector.shape_cast %[[T1]]#2 : vector<8x1xf32> to vector<8xf32>
-// CHECK: %[[T5:.*]] = xegpu.dpas %[[T2]], %[[T3]], %[[T4]] : vector<8xf16>, vector<16xf16>, vector<8xf32> -> vector<8xf32>
-// CHECK: %[[T6:.*]] = xegpu.create_nd_tdesc %[[ARG3]][%{{.*}}] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
-// CHECK: xegpu.store_nd %[[T5]], %[[T6]]  : vector<8xf32>, !xegpu.tensor_desc<8x16xf32>
-gpu.module @test {
-gpu.func @dpas(%arg0: vector<8x16xf16>, %arg1: vector<16x16xf16>, %arg3: vector<8x16xf32>, %arg2: memref<8x16xf32>){
-  %c0 = arith.constant 0 : index
-  %0 = xegpu.dpas %arg0, %arg1, %arg3 : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
-  %3 = xegpu.create_nd_tdesc %arg2[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
-  xegpu.store_nd %0, %3 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
-  gpu.return
-}
-}
-
-// -----
-// CHECK-LABEL: gpu.func @load_dpas_store
-// CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: memref<8x16xf16>, %[[ARG1:[0-9a-zA-Z]+]]: memref<16x16xf16>, %[[ARG2:[0-9a-zA-Z]+]]: memref<8x16xf32>) {
-// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG1]][%{{.*}}] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
-// CHECK: %[[T1:.*]] = xegpu.load_nd %[[T0]] <{packed}> : !xegpu.tensor_desc<16x16xf16> -> vector<16xf16>
-// CHECK: %[[T2:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%{{.*}}] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
-// CHECK: %[[T3:.*]] = xegpu.load_nd %[[T2]]  : !xegpu.tensor_desc<8x16xf16> -> vector<8xf16>
-// CHECK-DAG: %[[T4:.*]] = xegpu.dpas %[[T3]], %[[T1]] : vector<8xf16>, vector<16xf16> -> vector<8xf32>
-// CHECK-DAG: %[[T5:.*]] = xegpu.create_nd_tdesc %[[ARG2]][%{{.*}}] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
-// CHECK: xegpu.store_nd %[[T4]], %[[T5]]  : vector<8xf32>, !xegpu.tensor_desc<8x16xf32>
-gpu.module @test {
-gpu.func @load_dpas_store(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg3: memref<8x16xf32>){
-  %c0 = arith.constant 0 : index
-  %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
-  %1 = xegpu.load_nd %0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
-  %2 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
-  %3 = xegpu.load_nd %2 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
-  %4 = xegpu.dpas %1, %3 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
-  %5 = xegpu.create_nd_tdesc %arg3[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
-  xegpu.store_nd %4, %5 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
-  gpu.return
-}
-}
-
-// -----
-gpu.module @test {
-// CHECK-LABEL: gpu.func @create_nd_tdesc_non_memref
-// CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: ui64, %[[ARG1:[0-9a-zA-Z]+]]: ui64, %[[ARG2:[0-9a-zA-Z]+]]: index,
-// CHECK-SAME: %[[ARG3:[0-9a-zA-Z]+]]: index, %[[ARG4:[0-9a-zA-Z]+]]: index,
-// CHECK-SAME: %[[ARG5:[0-9a-zA-Z]+]]: index, %[[ARG6:[0-9a-zA-Z]+]]: index, %[[ARG7:[0-9a-zA-Z]+]]: index) {
-// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][{{.*}}], [%[[ARG2]], %[[ARG3]]], [%[[ARG4]], %[[ARG5]]] : ui64 -> !xegpu.tensor_desc<16x16xf16>
-// CHECK: %[[T1:.*]] = xegpu.load_nd %[[T0]]  : !xegpu.tensor_desc<16x16xf16> -> vector<16xf16>
-// CHECK: %[[T2:.*]] = xegpu.create_nd_tdesc %[[ARG1]][{{.*}}], [%[[ARG2]], %[[ARG3]]], [%[[ARG4]], %[[ARG5]]] : ui64 -> !xegpu.tensor_desc<16x16xf16>
-// CHECK: xegpu.store_nd %[[T1]], %[[T2]]  : vector<16xf16>, !xegpu.tensor_desc<16x16xf16>
-gpu.func @create_nd_tdesc_non_memref(%arg0: ui64, %arg1: ui64,
-  %arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: index, %arg7: index) {
-  %c0 = arith.constant 0 : index
-  %0 = xegpu.create_nd_tdesc %arg0 [%c0, %c0], [%arg2, %arg3], [%arg4, %arg5] : ui64 -> !xegpu.tensor_desc<16x16xf16>
-  %1 = xegpu.load_nd %0 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
-  %2 = xegpu.create_nd_tdesc %arg1[%c0, %c0], [%arg2, %arg3], [%arg4, %arg5] : ui64 -> !xegpu.tensor_desc<16x16xf16>
-  xegpu.store_nd %1, %2 : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16>
-  gpu.return
-}
-}
-
-// -----
-// CHECK-LABEL: gpu.func @gemm_loop
-// CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: memref<1024x1024xbf16>, %[[ARG1:[0-9a-zA-Z]+]]: memref<1024x1024xbf16>, %[[ARG2:[0-9a-zA-Z]+]]: memref<1024x1024xf32>) {
-// CHECK: %[[BLOCK_ID_X:.*]] = gpu.block_id x
-// CHECK: %[[BLOCK_ID_Y:.*]] = gpu.block_id y
-// CHECK: %[[Y_COORD:.*]] = arith.muli %[[BLOCK_ID_Y]], %c16 : index
-// CHECK: %[[X_COORD:.*]] = arith.muli %[[BLOCK_ID_X]], %c8 : index
-// CHECK: %[[T2:.*]] = xegpu.create_nd_tdesc %[[ARG2]][%[[X_COORD]], %[[Y_COORD]]] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32>
-// CHECK: %[[T3:.*]] = xegpu.load_nd %[[T2]] : !xegpu.tensor_desc<8x16xf32> -> vector<8xf32>
-// CHECK: %[[T4:.*]] = vector.shape_cast %[[T3]] : vector<8xf32> to vector<8x1xf32>
-// CHECK: %[[T5:.*]] = scf.for %[[K:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG4:.*]] = %[[T4]]) -> (vector<8x1xf32>) {
-// CHECK: %[[T10:.*]] = xegpu.create_nd_tdesc %[[ARG1]][%[[K]], %[[Y_COORD]]] : memref<1024x1024xbf16> -> !xegpu.tensor_desc<16x16xbf16>
-// CHECK: %[[T11:.*]] = xegpu.load_nd %[[T10]] <{packed}> : !xegpu.tensor_desc<16x16xbf16> -> vector<16xbf16>
-// CHECK: %[[T12:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%[[X_COORD]], %[[K]]] : memref<1024x1024xbf16> -> !xegpu.tensor_desc<8x16xbf16>
-// CHECK: %[[T13:.*]] = xegpu.load_nd %[[T12]] : !xegpu.tensor_desc<8x16xbf16> -> vector<8xbf16>
-// CHECK: %[[T14:.*]] = vector.shape_cast %[[ARG4]] : vector<8x1xf32> to vector<8xf32>
-// CHECK: %[[T15:.*]] = xegpu.dpas %[[T13]], %[[T11]], %[[T14]] : vector<8xbf16>, vector<16xbf16>, vector<8xf32> -> vector<8xf32>
-// CHECK: %[[T16:.*]] = vector.shape_cast %[[T15]] : vector<8xf32> to vector<8x1xf32>
-// CHECK: scf.yield %[[T16]] : vector<8x1xf32>
-// CHECK: }
-// CHECK: %[[T9:.*]] = vector.shape_cast %[[T5]] : vector<8x1xf32> to vector<8xf32>
-// CHECK: xegpu.store_nd %[[T9]], %[[T2]] : vector<8xf32>, !xegpu.tensor_desc<8x16xf32>
-gpu.module @test {
-gpu.func @gemm_loop(%arg0: memref<1024x1024xbf16>, %arg1: memref<1024x1024xbf16>, %arg2: memref<1024x1024xf32>){
-  %c0 = arith.constant 0 : index
-  %c16 = arith.constant 16 : index
-  %c8 = arith.constant 8 : index
-  %c1024 = arith.constant 1024 : index
-  %0 = gpu.block_id x
-  %1 = gpu.block_id y
-  %2 = arith.muli %0, %c8 : index
-  %3 = arith.muli %1, %c16 : index
-  %4 = xegpu.create_nd_tdesc %arg2[%2, %3] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32>
-  %5 = xegpu.load_nd %4 : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
-  %6 = scf.for %arg3 = %c0 to %c1024 step %c16 iter_args(%arg4 = %5) -> (vector<8x16xf32>) {
-    %7 = xegpu.create_nd_tdesc %arg0[%2, %arg3] : memref<1024x1024xbf16> -> !xegpu.tensor_desc<8x16xbf16>
-    %8 = xegpu.create_nd_tdesc %arg1[%arg3, %3] : memref<1024x1024xbf16> -> !xegpu.tensor_desc<16x16xbf16>
-    %9 = xegpu.load_nd %7 : !xegpu.tensor_desc<8x16xbf16> -> vector<8x16xbf16>
-    %10 = xegpu.load_nd %8 : !xegpu.tensor_desc<16x16xbf16> -> vector<16x16xbf16>
-    %11 = xegpu.dpas %9, %10, %arg4 : vector<8x16xbf16>, vector<16x16xbf16>, vector<8x16xf32> -> vector<8x16xf32>
-    scf.yield %11 : vector<8x16xf32>
-  }
-  xegpu.store_nd %6, %4 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
-  gpu.return
-}
-}
-
-// -----
-// CHECK-LABEL: gpu.func @update_nd_offset_1d(
-// CHECK: %[[ARG0:[0-9a-zA-Z]+]]: memref<256xf32>) {
-// CHECK: %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<1xf32>
-// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%{{.*}}] : memref<256xf32> -> !xegpu.tensor_desc<16xf32>
-// CHECK: %[[T1:.*]] = xegpu.update_nd_offset %[[T0]], [%c32] : !xegpu.tensor_desc<16xf32>
-// CHECK: xegpu.store_nd %[[CST]], %[[T1]]  : vector<1xf32>, !xegpu.tensor_desc<16xf32>
-gpu.module @test {
-gpu.func @update_nd_offset_1d(%arg0: memref<256xf32>){
-  %c0 = arith.constant 0 : index
-  %c32 = arith.constant 32 : index
-  %1 = arith.constant dense<1.000000e+00> : vector<16xf32>
-  %0 = xegpu.create_nd_tdesc %arg0[%c0] : memref<256xf32> -> !xegpu.tensor_desc<16xf32>
-  %2 = xegpu.update_nd_offset %0, [%c32] : !xegpu.tensor_desc<16xf32>
-  xegpu.store_nd %1, %2 : vector<16xf32>, !xegpu.tensor_desc<16xf32>
-  gpu.return
-}
-}
-
-// -----
-// CHECK-LABEL: gpu.func @update_nd_offset_2d
-// CHECK: %[[ARG0:[0-9a-zA-Z]+]]: memref<256x256xf32>) {
-// CHECK: %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<16xf32>
-// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%{{.*}}] : memref<256x256xf32> -> !xegpu.tensor_desc<16x16xf32>
-// CHECK: %[[T1:.*]] = xegpu.update_nd_offset %[[T0]], [%c32, %c32] : !xegpu.tensor_desc<16x16xf32>
-// CHECK: xegpu.store_nd %[[CST]], %[[T1]]  : vector<16xf32>, !xegpu.tensor_desc<16x16xf32>
-gpu.module @test {
-gpu.func @update_nd_offset_2d(%arg0: memref<256x256xf32>){
-  %c0 = arith.constant 0 : index
-  %c32 = arith.constant 32 : index
-  %1 = arith.constant dense<1.000000e+00> : vector<16x16xf32>
-  %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<256x256xf32> -> !xegpu.tensor_desc<16x16xf32>
-  %2 = xegpu.update_nd_offset %0, [%c32, %c32] : !xegpu.tensor_desc<16x16xf32>
-  xegpu.store_nd %1, %2 : vector<16x16xf32>, !xegpu.tensor_desc<16x16xf32>
-  gpu.return
-}
-}
-
-// -----
-// CHECK-LABEL: gpu.func @prefetch_2d
-// CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: memref<256x256xf16>) {
-// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%{{.*}}] : memref<256x256xf16> -> !xegpu.tensor_desc<16x16xf16>
-// CHECK: xegpu.prefetch_nd %[[T0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<16x16xf16>
-gpu.module @test {
-gpu.func @prefetch_2d(%arg0: memref<256x256xf16>){
-  %c0 = arith.constant 0 : index
-  %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<256x256xf16> -> !xegpu.tensor_desc<16x16xf16>
-  xegpu.prefetch_nd %0 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>: !xegpu.tensor_desc<16x16xf16>
-  gpu.return
-}
-}
-
-// -----
-// CHECK-LABEL: gpu.func @prefetch_1d
-// CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: memref<256xf16>) {
-// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%{{.*}}] : memref<256xf16> -> !xegpu.tensor_desc<16xf16>
-// CHECK: xegpu.prefetch_nd %[[T0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<16xf16>
-gpu.module @test {
-gpu.func @prefetch_1d(%arg0: memref<256xf16>){
-  %c0 = arith.constant 0 : index
-  %0 = xegpu.create_nd_tdesc %arg0[%c0] : memref<256xf16> -> !xegpu.tensor_desc<16xf16>
-  xegpu.prefetch_nd %0 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>: !xegpu.tensor_desc<16xf16>
-  gpu.return
-}
-}

diff  --git a/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir b/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
deleted file mode 100644
index 35ac39d074c70..0000000000000
--- a/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
+++ /dev/null
@@ -1,622 +0,0 @@
-// RUN: mlir-opt -xegpu-subgroup-distribute='print-analysis-only=true' -split-input-file %s | FileCheck %s
-
-// CHECK: function: dpas_f16:
-// CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf16>' at index: 0
-// CHECK-NEXT: layout  : Not assigned.
-// CHECK-NEXT: argument: <block argument> of type 'memref<16x16xf16>' at index: 1
-// CHECK-NEXT: layout  : Not assigned.
-// CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf32>' at index: 2
-// CHECK-NEXT: layout  : Not assigned.
-// CHECK-NEXT: op    : %{{.*}} = arith.constant 0 : index
-// CHECK-NEXT: layout for result #0: Not assigned.
-// CHECK-NEXT: op    : %{{.*}} = arith.constant dense<0.000000e+00> : vector<8x16xf32>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: op    : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: op    : %[[T1:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [2, 1]
-// CHECK-NEXT: op    : %[[T2:.*]] = xegpu.load_nd %[[T0]]  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: op    : %[[T3:.*]] = xegpu.load_nd %[[T1]]  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [2, 1]
-// CHECK-NEXT: op    : %[[T4:.*]] = xegpu.dpas %[[T2]], %[[T3]], %{{.*}} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: op    : %{{.*}} = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
-func.func @dpas_f16(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<8x16xf32>) {
-  %c0 = arith.constant 0 : index
-  %cst = arith.constant dense<0.000000e+00> : vector<8x16xf32>
-  %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
-  %1 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
-  %2 = xegpu.load_nd %0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
-  %3 = xegpu.load_nd %1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
-  %4 = xegpu.dpas %2, %3, %cst : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
-  %5 = xegpu.create_nd_tdesc %arg2[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
-  xegpu.store_nd %4, %5  : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
-  return
-}
-
-
-// -----
-// CHECK: function: dpas_i8:
-// CHECK-NEXT: argument: <block argument> of type 'vector<8x32xi8>' at index: 0
-// CHECK-NEXT: layout  : lane_layout: [1, 16], lane_data: [1, 2]
-// CHECK-NEXT: argument: <block argument> of type 'vector<32x16xi8>' at index: 1
-// CHECK-NEXT: layout  : lane_layout: [1, 16], lane_data: [4, 1]
-// CHECK-NEXT: argument: <block argument> of type 'memref<8x16xi32>' at index: 2
-// CHECK-NEXT: layout  : Not assigned.
-// CHECK-NEXT: op    : %{{.*}} = arith.constant 0 : index
-// CHECK-NEXT: layout for result #0: Not assigned.
-// CHECK-NEXT: op    : %[[T0:.*]] = xegpu.dpas %{{.*}} : vector<8x32xi8>, vector<32x16xi8> -> vector<8x16xi32>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: op    : %[[T1:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xi32> -> !xegpu.tensor_desc<8x16xi32>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
-func.func @dpas_i8(%arg0: vector<8x32xi8>, %arg1: vector<32x16xi8>, %arg2: memref<8x16xi32>) {
-  %c0 = arith.constant 0 : index
-  %0 = xegpu.dpas %arg0, %arg1 : vector<8x32xi8>, vector<32x16xi8> -> vector<8x16xi32>
-  %1 = xegpu.create_nd_tdesc %arg2[%c0, %c0] : memref<8x16xi32> -> !xegpu.tensor_desc<8x16xi32>
-  xegpu.store_nd %0, %1  : vector<8x16xi32>, !xegpu.tensor_desc<8x16xi32>
-  return
-}
-
-// -----
-// CHECK: function: load_with_transpose_effect:
-// CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf16>' at index: 0
-// CHECK-NEXT: layout  : Not assigned.
-// CHECK-NEXT: argument: <block argument> of type 'memref<16x16xf16>' at index: 1
-// CHECK-NEXT: layout  : Not assigned.
-// CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf32>' at index: 2
-// CHECK-NEXT: layout  : Not assigned.
-// CHECK-NEXT: op    : %{{.*}} = arith.constant 0 : index
-// CHECK-NEXT: layout for result #0: Not assigned.
-// CHECK-NEXT: op    : %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<8x16xf32>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: op    : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: op    : %[[T1:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
-// CHECK-NEXT: layout for result #0: lane_layout: [16, 1], lane_data: [1, 2]
-// CHECK-NEXT: op    : %[[T2:.*]] = xegpu.load_nd %[[T0]]  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: op    : %[[T3:.*]] = xegpu.load_nd %[[T1]] <{transpose = array<i64: 1, 0>}> : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [2, 1]
-// CHECK-NEXT: op    : %[[T4:.*]] = xegpu.dpas %[[T2]], %[[T3]], %[[CST]] : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: op    : %[[T5:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
-func.func @load_with_transpose_effect(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<8x16xf32>) {
-  %c0 = arith.constant 0 : index
-  %cst = arith.constant dense<0.000000e+00> : vector<8x16xf32>
-  %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
-  %1 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
-  %2 = xegpu.load_nd %0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
-  %3 = xegpu.load_nd %1 <{transpose = array<i64: 1, 0>}> : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
-  %4 = xegpu.dpas %2, %3, %cst : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
-  %5 = xegpu.create_nd_tdesc %arg2[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
-  xegpu.store_nd %4, %5  : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
-  return
-}
-
-// -----
-// CHECK: function: vector_transpose:
-// CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf16>' at index: 0
-// CHECK-NEXT: layout  : Not assigned.
-// CHECK-NEXT: argument: <block argument> of type 'memref<16x16xf16>' at index: 1
-// CHECK-NEXT: layout  : Not assigned.
-// CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf32>' at index: 2
-// CHECK-NEXT: layout  : Not assigned.
-// CHECK-NEXT: op    : %{{.*}} = arith.constant 0 : index
-// CHECK-NEXT: layout for result #0: Not assigned.
-// CHECK-NEXT: op    : %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<8x16xf32>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: op    : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: op    : %[[T1:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
-// CHECK-NEXT: layout for result #0: lane_layout: [16, 1], lane_data: [1, 2]
-// CHECK-NEXT: op    : %[[T2:.*]] = xegpu.load_nd %[[T0]]  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: op    : %[[T3:.*]] = xegpu.load_nd %[[T1]]  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
-// CHECK-NEXT: layout for result #0: lane_layout: [16, 1], lane_data: [1, 2]
-// CHECK-NEXT: op    : %[[T4:.*]] = vector.transpose %[[T3]], [1, 0] : vector<16x16xf16> to vector<16x16xf16>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [2, 1]
-// CHECK-NEXT: op    : %[[T5:.*]] = xegpu.dpas %[[T2]], %[[T4]], %[[CST]] : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: op    : %[[T6:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
-func.func @vector_transpose(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<8x16xf32>) {
-  %c0 = arith.constant 0 : index
-  %cst = arith.constant dense<0.000000e+00> : vector<8x16xf32>
-  %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
-  %1 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
-  %2 = xegpu.load_nd %0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
-  %3 = xegpu.load_nd %1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
-  %4 = vector.transpose %3, [1, 0] : vector<16x16xf16> to vector<16x16xf16>
-  %5 = xegpu.dpas %2, %4, %cst : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
-  %6 = xegpu.create_nd_tdesc %arg2[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
-  xegpu.store_nd %5, %6  : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
-  return
-}
-
-// -----
-// CHECK: function: extf_truncf:
-// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<8x16xf16>' at index: 0
-// CHECK-NEXT: layout  : lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<16x16xf16>' at index: 1
-// CHECK-NEXT: layout  : lane_layout: [1, 16], lane_data: [2, 1]
-// CHECK-NEXT: op    : %[[T0:.*]] = xegpu.load_nd %{{.*}}  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: op    : %[[T1:.*]] = xegpu.load_nd %{{.*}}  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [2, 1]
-// CHECK-NEXT: op    : %[[T2:.*]] = arith.extf %[[T1]] : vector<16x16xf16> to vector<16x16xf32>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [2, 1]
-// CHECK-NEXT: op    : %[[T3:.*]] = arith.truncf %[[T2]] : vector<16x16xf32> to vector<16x16xf16>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [2, 1]
-// CHECK-NEXT: op    : %[[T4:.*]] = xegpu.dpas %[[T0]], %[[T3]] : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
-// CHECK-NEXT: layout for result #0: Not assigned.
-func.func @extf_truncf(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>) -> vector<8x16xf32> {
-  %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
-  %1 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
-  %2 = arith.extf %1 : vector<16x16xf16> to vector<16x16xf32>
-  %3 = arith.truncf %2 : vector<16x16xf32> to vector<16x16xf16>
-  %4 = xegpu.dpas %0, %3 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
-  return %4 : vector<8x16xf32>
-}
-
-// -----
-// CHECK: function: load_gather_with_transpose_effect:
-// CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf16>' at index: 0
-// CHECK-NEXT: layout  : Not assigned.
-// CHECK-NEXT: argument: <block argument> of type 'memref<256xf16>' at index: 1
-// CHECK-NEXT: layout  : Not assigned.
-// CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf32>' at index: 2
-// CHECK-NEXT: layout  : Not assigned.
-// CHECK-NEXT: op    : %{{.*}} = arith.constant 0 : index
-// CHECK-NEXT: layout for result #0: Not assigned.
-// CHECK-NEXT: op    : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: op    : %[[T1:.*]] = xegpu.load_nd %[[T0]]  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: op    : %[[CST:.*]] = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
-// CHECK-NEXT: layout for result #0: lane_layout: [16], lane_data: [1]
-// CHECK-NEXT: op    : %[[CST0:.*]] = arith.constant dense<true> : vector<16xi1>
-// CHECK-NEXT: layout for result #0: lane_layout: [16], lane_data: [1]
-// CHECK-NEXT: op    : %[[T2:.*]] = xegpu.create_tdesc %{{.*}}, %[[CST]] : memref<256xf16>, vector<16xindex> -> !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>>
-// CHECK-NEXT: layout for result #0: lane_layout: [16, 1], lane_data: [1, 2]
-// CHECK-NEXT: op    : %[[T3:.*]] = xegpu.load %[[T2]], %[[CST0]] <{transpose}> : !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>>, vector<16xi1> -> vector<16x16xf16>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [2, 1]
-// CHECK-NEXT: op    : %[[T4:.*]] = xegpu.dpas %[[T1]], %[[T3]] : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: op    : %[[T5:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
-func.func @load_gather_with_transpose_effect(%arg0: memref<8x16xf16>, %arg1: memref<256xf16>, %arg2: memref<8x16xf32>) {
-  %c0 = arith.constant 0 : index
-  %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
-  %1 = xegpu.load_nd %0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
-  %cst = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
-  %cst_0 = arith.constant dense<true> : vector<16xi1>
-  %2 = xegpu.create_tdesc %arg1, %cst : memref<256xf16>, vector<16xindex> -> !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>>
-  %3 = xegpu.load %2, %cst_0 <{transpose}> : !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>>, vector<16xi1> -> vector<16x16xf16>
-  %4 = xegpu.dpas %1, %3 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
-  %5 = xegpu.create_nd_tdesc %arg2[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
-  xegpu.store_nd %4, %5  : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
-  return
-}
-
-// -----
-// CHECK: function: load_gather_1d:
-// CHECK: argument: <block argument> of type 'memref<256xf32>' at index: 0
-// CHECK-NEXT: layout  : Not assigned.
-// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<16xf32>' at index: 1
-// CHECK-NEXT: layout  : lane_layout: [16], lane_data: [1]
-// CHECK-NEXT: op    : %[[CST:.*]] = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
-// CHECK-NEXT: layout for result #0: lane_layout: [16], lane_data: [1]
-// CHECK-NEXT: op    : %[[CST0:.*]] = arith.constant dense<true> : vector<16xi1>
-// CHECK-NEXT: layout for result #0: lane_layout: [16], lane_data: [1]
-// CHECK-NEXT: op    : %[[T0:.*]] = xegpu.create_tdesc %{{.*}}, %[[CST]] : memref<256xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
-// CHECK-NEXT: layout for result #0: lane_layout: [16], lane_data: [1]
-// CHECK-NEXT: op    : %[[T1]] = xegpu.load %[[T0]], %[[CST0]]  : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> -> vector<16xf32>
-// CHECK-NEXT: layout for result #0: lane_layout: [16], lane_data: [1]
-func.func @load_gather_1d(%arg0: memref<256xf32>, %arg1: !xegpu.tensor_desc<16xf32>) {
-  %cst = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
-  %cst_0 = arith.constant dense<true> : vector<16xi1>
-  %0 = xegpu.create_tdesc %arg0, %cst : memref<256xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
-  %1 = xegpu.load %0, %cst_0  : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> -> vector<16xf32>
-  xegpu.store_nd %1, %arg1  : vector<16xf32>, !xegpu.tensor_desc<16xf32>
-  return
-}
-
-// -----
-// CHECK: function: store_scatter_with_transpose_effect:
-// CHECK-NEXT: argument: <block argument> of type 'memref<128xf32>' at index: 0
-// CHECK-NEXT: layout  : Not assigned.
-// CHECK-NEXT: op    : %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<8x16xf32>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: op    : %[[CST0:.*]] = arith.constant dense<true> : vector<16xi1>
-// CHECK-NEXT: layout for result #0: lane_layout: [16], lane_data: [1]
-// CHECK-NEXT: op    : %[[CST1:.*]] = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
-// CHECK-NEXT: layout for result #0: lane_layout: [16], lane_data: [1]
-// CHECK-NEXT: op    : %[[T0:.*]] = xegpu.create_tdesc %{{.*}}, %[[CST1]] : memref<128xf32>, vector<16xindex> -> !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size = 8 : i64>>
-// CHECK-NEXT: layout for result #0: lane_layout: [16, 1], lane_data: [1, 1]
-func.func @store_scatter_with_transpose_effect(%arg0: memref<128xf32>) {
-  %cst = arith.constant dense<1.000000e+00> : vector<8x16xf32>
-  %cst_0 = arith.constant dense<true> : vector<16xi1>
-  %cst_1 = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
-  %0 = xegpu.create_tdesc %arg0, %cst_1 : memref<128xf32>, vector<16xindex> -> !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size = 8 : i64>>
-  xegpu.store %cst, %0, %cst_0 <{transpose}> : vector<8x16xf32>, !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size = 8 : i64>>, vector<16xi1>
-  return
-}
-
-// -----
-// CHECK: function: store_scatter_1d:
-// CHECK-NEXT: argument: <block argument> of type 'vector<16xf32>' at index: 0
-// CHECK-NEXT: layout  : lane_layout: [16], lane_data: [1]
-// CHECK-NEXT: argument: <block argument> of type 'memref<256xf32>' at index: 1
-// CHECK-NEXT: layout  : Not assigned.
-// CHECK-NEXT: op    : %[[CST:.*]] = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
-// CHECK-NEXT: layout for result #0: lane_layout: [16], lane_data: [1]
-// CHECK-NEXT: op    : %[[CST1:.*]] = arith.constant dense<true> : vector<16xi1>
-// CHECK-NEXT: layout for result #0: lane_layout: [16], lane_data: [1]
-// CHECK-NEXT: op    : %[[T0:.*]] = xegpu.create_tdesc %{{.*}}, %[[CST]] : memref<256xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
-// CHECK-NEXT: layout for result #0: lane_layout: [16], lane_data: [1]
-func.func @store_scatter_1d(%arg0: vector<16xf32>, %arg1: memref<256xf32>) {
-  %cst = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
-  %cst_0 = arith.constant dense<true> : vector<16xi1>
-  %0 = xegpu.create_tdesc %arg1, %cst : memref<256xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
-  xegpu.store %arg0, %0, %cst_0  : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1>
-  return
-}
-
-// -----
-// CHECK: function: vector_bitcast_i16_to_i8:
-// CHECK-NEXT: argument: <block argument> of type 'memref<8x16xi16>' at index: 0
-// CHECK-NEXT: layout  : Not assigned.
-// CHECK-NEXT: argument: <block argument> of type 'memref<32x16xi8>' at index: 1
-// CHECK-NEXT: layout  : Not assigned.
-// CHECK-NEXT: argument: <block argument> of type 'memref<8x16xi32>' at index: 2
-// CHECK-NEXT: layout  : Not assigned.
-// CHECK-NEXT: op    : %{{.*}} = arith.constant 0 : index
-// CHECK-NEXT: layout for result #0: Not assigned.
-// CHECK-NEXT: op    : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xi16> -> !xegpu.tensor_desc<8x16xi16>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: op    : %[[T1:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<32x16xi8> -> !xegpu.tensor_desc<32x16xi8>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [4, 1]
-// CHECK-NEXT: op    : %[[T2:.*]] = xegpu.load_nd %[[T0]]  : !xegpu.tensor_desc<8x16xi16> -> vector<8x16xi16>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: op    : %[[T3:.*]] = xegpu.load_nd %[[T1]]  : !xegpu.tensor_desc<32x16xi8> -> vector<32x16xi8>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [4, 1]
-// CHECK-NEXT: op    : %[[T4:.*]] = vector.bitcast %[[T2]] : vector<8x16xi16> to vector<8x32xi8>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 2]
-// CHECK-NEXT: op    : %[[T5:.*]] = xegpu.dpas %[[T4]], %[[T3]] : vector<8x32xi8>, vector<32x16xi8> -> vector<8x16xi32>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: op    : %[[T6:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xi32> -> !xegpu.tensor_desc<8x16xi32>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
-func.func @vector_bitcast_i16_to_i8(%arg0: memref<8x16xi16>, %arg1: memref<32x16xi8>, %arg2: memref<8x16xi32>) {
-  %c0 = arith.constant 0 : index
-  %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xi16> -> !xegpu.tensor_desc<8x16xi16>
-  %1 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<32x16xi8> -> !xegpu.tensor_desc<32x16xi8>
-  %2 = xegpu.load_nd %0  : !xegpu.tensor_desc<8x16xi16> -> vector<8x16xi16>
-  %3 = xegpu.load_nd %1  : !xegpu.tensor_desc<32x16xi8> -> vector<32x16xi8>
-  %4 = vector.bitcast %2 : vector<8x16xi16> to vector<8x32xi8>
-  %5 = xegpu.dpas %4, %3 : vector<8x32xi8>, vector<32x16xi8> -> vector<8x16xi32>
-  %6 = xegpu.create_nd_tdesc %arg2[%c0, %c0] : memref<8x16xi32> -> !xegpu.tensor_desc<8x16xi32>
-  xegpu.store_nd %5, %6  : vector<8x16xi32>, !xegpu.tensor_desc<8x16xi32>
-  return
-}
-
-// -----
-// CHECK: function: vector_bitcast_i8_to_f16:
-// CHECK-NEXT: argument: <block argument> of type 'memref<8x32xi8>' at index: 0
-// CHECK-NEXT: layout  : Not assigned.
-// CHECK-NEXT: argument: <block argument> of type 'memref<16x32xi8>' at index: 1
-// CHECK-NEXT: layout  : Not assigned.
-// CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf32>' at index: 2
-// CHECK-NEXT: layout  : Not assigned.
-// CHECK-NEXT: op    : %{{.*}} = arith.constant 0 : index
-// CHECK-NEXT: layout for result #0: Not assigned.
-// CHECK-NEXT: op    : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x32xi8> -> !xegpu.tensor_desc<8x32xi8>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 2]
-// CHECK-NEXT: op    : %[[T1:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<16x32xi8> -> !xegpu.tensor_desc<16x32xi8>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [4, 1]
-// CHECK-NEXT: op    : %[[T2:.*]] = xegpu.load_nd %[[T0]]  : !xegpu.tensor_desc<8x32xi8> -> vector<8x32xi8>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 2]
-// CHECK-NEXT: op    : %[[T3:.*]] = xegpu.load_nd %[[T1]]  : !xegpu.tensor_desc<16x32xi8> -> vector<16x32xi8>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [4, 1]
-// CHECK-NEXT: op    : %[[T4:.*]] = vector.bitcast %[[T2]] : vector<8x32xi8> to vector<8x16xf16>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: op    : %[[T5:.*]] = vector.bitcast %[[T3]] : vector<16x32xi8> to vector<16x16xf16>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [2, 1]
-// CHECK-NEXT: op    : %[[T6:.*]] = xegpu.dpas %[[T4]], %[[T5]] : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: op    : %[[T7:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
-func.func @vector_bitcast_i8_to_f16(%arg0: memref<8x32xi8>, %arg1: memref<16x32xi8>, %arg2: memref<8x16xf32>) {
-  %c0 = arith.constant 0 : index
-  %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x32xi8> -> !xegpu.tensor_desc<8x32xi8>
-  %1 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<16x32xi8> -> !xegpu.tensor_desc<16x32xi8>
-  %2 = xegpu.load_nd %0  : !xegpu.tensor_desc<8x32xi8> -> vector<8x32xi8>
-  %3 = xegpu.load_nd %1  : !xegpu.tensor_desc<16x32xi8> -> vector<16x32xi8>
-  %4 = vector.bitcast %2 : vector<8x32xi8> to vector<8x16xf16>
-  %5 = vector.bitcast %3 : vector<16x32xi8> to vector<16x16xf16>
-  %6 = xegpu.dpas %4, %5 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
-  %7 = xegpu.create_nd_tdesc %arg2[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
-  xegpu.store_nd %6, %7  : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
-  return
-}
-
-// -----
-// CHECK: function: binary_op_one_use:
-// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<8x16xf16>' at index: 0
-// CHECK-NEXT: layout  : lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<16x16xf16>' at index: 1
-// CHECK-NEXT: layout  : lane_layout: [1, 16], lane_data: [2, 1]
-// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<8x16xf32>' at index: 2
-// CHECK-NEXT: layout  : lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: op    : %[[T0:.*]] = xegpu.load_nd %{{.*}}  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: op    : %[[T1:.*]] = xegpu.load_nd %{{.*}}  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [2, 1]
-// CHECK-NEXT: op    : %[[T2:.*]] = xegpu.load_nd %{{.*}}  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [2, 1]
-// CHECK-NEXT: op    : %[[T3:.*]] = arith.addf %[[T1]], %[[T2]] : vector<16x16xf16>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [2, 1]
-// CHECK-NEXT: op    : %[[T4:.*]] = xegpu.dpas %[[T0]], %[[T3]] : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
-func.func @binary_op_one_use(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: !xegpu.tensor_desc<8x16xf32>) {
-  %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
-  %1 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
-  %2 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
-  %3 = arith.addf %1, %2 : vector<16x16xf16>
-  %4 = xegpu.dpas %0, %3 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
-  xegpu.store_nd %4, %arg2  : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
-  return
-}
-
-// -----
-// CHECK: function: binary_op_multiple_uses:
-// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<8x16xf16>' at index: 0
-// CHECK-NEXT: layout  : lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<16x16xf16>' at index: 1
-// CHECK-NEXT: layout  : lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<8x16xf32>' at index: 2
-// CHECK-NEXT: layout  : lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<16x16xf16>' at index: 3
-// CHECK-NEXT: layout  : lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: op    : %[[T0:.*]] = xegpu.load_nd %{{.*}}  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: op    : %[[T1:.*]] = xegpu.load_nd %{{.*}}  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: op    : %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<16x16xf16>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: op    : %[[T2:.*]] = arith.addf %[[T1]], %[[CST]] : vector<16x16xf16>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: op    : %[[T3:.*]] = xegpu.dpas %[[T0]], %[[T2]] : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
-func.func @binary_op_multiple_uses(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: !xegpu.tensor_desc<8x16xf32>, %arg3: !xegpu.tensor_desc<16x16xf16>) {
-  %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
-  %1 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
-  %cst = arith.constant dense<1.000000e+00> : vector<16x16xf16>
-  %2 = arith.addf %1, %cst : vector<16x16xf16>
-  %3 = xegpu.dpas %0, %2 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
-  xegpu.store_nd %3, %arg2  : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
-  xegpu.store_nd %2, %arg3  : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16>
-  return
-}
-
-// -----
-// CHECK: function: for_op:
-// CHECK-NEXT: argument: <block argument> of type 'memref<8x128xf16>' at index: 0
-// CHECK-NEXT: layout  : Not assigned.
-// CHECK-NEXT: argument: <block argument> of type 'memref<128x16xf16>' at index: 1
-// CHECK-NEXT: layout  : Not assigned.
-// CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf32>' at index: 2
-// CHECK-NEXT: layout  : Not assigned.
-// CHECK-NEXT: op    : %{{.*}} = arith.constant 0 : index
-// CHECK-NEXT: layout for result #0: Not assigned.
-// CHECK-NEXT: op    : %{{.*}} = arith.constant 128 : index
-// CHECK-NEXT: layout for result #0: Not assigned.
-// CHECK-NEXT: op    : %{{.*}} = arith.constant 16 : index
-// CHECK-NEXT: layout for result #0: Not assigned.
-// CHECK-NEXT: op    : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x128xf16> -> !xegpu.tensor_desc<8x16xf16>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: op    : %[[T1:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<128x16xf16> -> !xegpu.tensor_desc<16x16xf16>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [2, 1]
-// CHECK-NEXT: op    : %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<8x16xf32>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: op    : %[[T4:.*]] = xegpu.load_nd %{{.*}}  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: op    : %[[T5:.*]] = xegpu.load_nd %{{.*}}  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [2, 1]
-// CHECK-NEXT: op    : %[[T6:.*]] = xegpu.dpas %[[T4]], %[[T5]], %{{.*}} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: op    : %[[T7:.*]] = xegpu.update_nd_offset %{{.*}} : !xegpu.tensor_desc<8x16xf16>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: op    : %[[T8:.*]] = xegpu.update_nd_offset %{{.*}} : !xegpu.tensor_desc<16x16xf16>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [2, 1]
-// CHECK-NEXT: op    : scf.for
-// CHECK-NEXT: layout for result #0: Not assigned.
-// CHECK-NEXT: layout for result #1: Not assigned.
-// CHECK-NEXT: layout for result #2: lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: op    : %[[T3:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
-func.func @for_op(%arg0: memref<8x128xf16>, %arg1: memref<128x16xf16>, %arg2: memref<8x16xf32>) {
-  %c0 = arith.constant 0 : index
-  %c128 = arith.constant 128 : index
-  %c16 = arith.constant 16 : index
-  %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x128xf16> -> !xegpu.tensor_desc<8x16xf16>
-  %1 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<128x16xf16> -> !xegpu.tensor_desc<16x16xf16>
-  %cst = arith.constant dense<0.000000e+00> : vector<8x16xf32>
-  %2:3 = scf.for %arg3 = %c0 to %c128 step %c16 iter_args(%arg4 = %0, %arg5 = %1, %arg6 = %cst) -> (!xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<16x16xf16>, vector<8x16xf32>) {
-    %4 = xegpu.load_nd %arg4  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
-    %5 = xegpu.load_nd %arg5  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
-    %6 = xegpu.dpas %4, %5, %arg6 : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
-    %7 = xegpu.update_nd_offset %arg4, [%c0, %c16] : !xegpu.tensor_desc<8x16xf16>
-    %8 = xegpu.update_nd_offset %arg5, [%c16, %c0] : !xegpu.tensor_desc<16x16xf16>
-    scf.yield %7, %8, %6 : !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<16x16xf16>, vector<8x16xf32>
-  }
-  %3 = xegpu.create_nd_tdesc %arg2[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
-  xegpu.store_nd %2#2, %3  : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
-  return
-}
-
-// -----
-// CHECK: function: if_single_use:
-// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<8x16xf16>' at index: 0
-// CHECK-NEXT: layout  : lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<16x16xf16>' at index: 1
-// CHECK-NEXT: layout  : lane_layout: [1, 16], lane_data: [2, 1]
-// CHECK-NEXT: argument: <block argument> of type 'i1' at index: 2
-// CHECK-NEXT: layout  : Not assigned.
-// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<8x16xf32>' at index: 3
-// CHECK-NEXT: layout  : lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: op    : %[[T0:.*]] = xegpu.load_nd %{{.*}}  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: op    : %[[T3:.*]] = xegpu.load_nd %{{.*}}  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [2, 1]
-// CHECK-NEXT: op    : %[[T4:.*]] = xegpu.load_nd %{{.*}}  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [2, 1]
-// CHECK-NEXT: op    : scf.if
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [2, 1]
-// CHECK-NEXT: op    : %[[T2:.*]] = xegpu.dpas %[[T0]], %{{.*}} : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
-func.func @if_single_use(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: i1, %arg3: !xegpu.tensor_desc<8x16xf32>) {
-  %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
-  %1 = scf.if %arg2 -> (vector<16x16xf16>) {
-    %3 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
-    scf.yield %3 : vector<16x16xf16>
-  } else {
-    %3 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
-    scf.yield %3 : vector<16x16xf16>
-  }
-  %2 = xegpu.dpas %0, %1 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
-  xegpu.store_nd %2, %arg3  : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
-  return
-}
-
-// -----
-// CHECK: function: if_multiple_uses:
-// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<8x16xf16>' at index: 0
-// CHECK-NEXT: layout  : lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<16x16xf16>' at index: 1
-// CHECK-NEXT: layout  : lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: argument: <block argument> of type 'i1' at index: 2
-// CHECK-NEXT: layout  : Not assigned.
-// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<8x16xf32>' at index: 3
-// CHECK-NEXT: layout  : lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<16x16xf16>' at index: 4
-// CHECK-NEXT: layout  : lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: op    : %[[T0:.*]] = xegpu.load_nd %{{.*}}  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: op    : %[[T3:.*]] = xegpu.load_nd %{{.*}}  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: op    : %[[T4:.*]] = xegpu.load_nd %{{.*}}  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: op    : scf.if
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: op    : %[[T2:.*]] = xegpu.dpas %[[T0]], %{{.*}} : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
-func.func @if_multiple_uses(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: i1, %arg3: !xegpu.tensor_desc<8x16xf32>, %arg4: !xegpu.tensor_desc<16x16xf16>) {
-  %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
-  %1 = scf.if %arg2 -> (vector<16x16xf16>) {
-    %3 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
-    scf.yield %3 : vector<16x16xf16>
-  } else {
-    %3 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
-    scf.yield %3 : vector<16x16xf16>
-  }
-  %2 = xegpu.dpas %0, %1 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
-  xegpu.store_nd %2, %arg3  : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
-  xegpu.store_nd %1, %arg4  : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16>
-  return
-}
-
-// -----
-// CHECK: function: vector_outer_reduction:
-// CHECK-NEXT: argument: <block argument> of type 'vector<16x16xf32>' at index: 0
-// CHECK-NEXT: layout  : lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<16xf32>' at index: 1
-// CHECK-NEXT: layout  : lane_layout: [16], lane_data: [1]
-// CHECK-NEXT: op    : %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<16xf32>
-// CHECK-NEXT: layout for result #0: lane_layout: [16], lane_data: [1]
-// CHECK-NEXT: op    : %[[T0:.*]] = vector.multi_reduction <add>, %{{.*}}, %[[CST]] [0] : vector<16x16xf32> to vector<16xf32>
-// CHECK-NEXT: layout for result #0: lane_layout: [16], lane_data: [1]
-func.func @vector_outer_reduction(%arg0: vector<16x16xf32>, %arg1: !xegpu.tensor_desc<16xf32>) {
-  %cst = arith.constant dense<0.000000e+00> : vector<16xf32>
-  %0 = vector.multi_reduction <add>, %arg0, %cst [0] : vector<16x16xf32> to vector<16xf32>
-  xegpu.store_nd %0, %arg1  : vector<16xf32>, !xegpu.tensor_desc<16xf32>
-  return
-}
-
-// -----
-// CHECK: function: vector_inner_reduction:
-// CHECK-NEXT: argument: <block argument> of type 'vector<16x16xf32>' at index: 0
-// CHECK-NEXT: layout  : lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<16xf32>' at index: 1
-// CHECK-NEXT: layout  : lane_layout: [16], lane_data: [1]
-// CHECK-NEXT: op    : %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<16xf32>
-// CHECK-NEXT: layout for result #0: lane_layout: [16], lane_data: [1]
-// CHECK-NEXT: op    : %[[T0:.*]] = vector.multi_reduction <add>, %{{.*}}, %[[CST]] [1] : vector<16x16xf32> to vector<16xf32>
-// CHECK-NEXT: layout for result #0: lane_layout: [16], lane_data: [1]
-func.func @vector_inner_reduction(%arg0: vector<16x16xf32>, %arg1: !xegpu.tensor_desc<16xf32>) {
-  %cst = arith.constant dense<0.000000e+00> : vector<16xf32>
-  %0 = vector.multi_reduction <add>, %arg0, %cst [1] : vector<16x16xf32> to vector<16xf32>
-  xegpu.store_nd %0, %arg1  : vector<16xf32>, !xegpu.tensor_desc<16xf32>
-  return
-}
-
-// -----
-// CHECK: function: update_nd_offset_1d:
-// CHECK: op    : %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<16xf32>
-// CHECK-NEXT: layout for result #0: lane_layout: [16], lane_data: [1]
-// CHECK-NEXT: op    : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}}[%{{.*}}] : memref<256xf32> -> !xegpu.tensor_desc<16xf32>
-// CHECK-NEXT: layout for result #0: lane_layout: [16], lane_data: [1]
-// CHECK-NEXT: op    : %[[T1:.*]] = xegpu.update_nd_offset %[[T0]], [%{{.*}}] : !xegpu.tensor_desc<16xf32>
-// CHECK-NEXT: layout for result #0: lane_layout: [16], lane_data: [1]
-func.func @update_nd_offset_1d(%arg0: memref<256xf32>){
-  %c0 = arith.constant 0 : index
-  %c32 = arith.constant 32 : index
-  %1 = arith.constant dense<1.000000e+00> : vector<16xf32>
-  %0 = xegpu.create_nd_tdesc %arg0[%c0] : memref<256xf32> -> !xegpu.tensor_desc<16xf32>
-  %2 = xegpu.update_nd_offset %0, [%c32] : !xegpu.tensor_desc<16xf32>
-  xegpu.store_nd %1, %2 : vector<16xf32>, !xegpu.tensor_desc<16xf32>
-  return
-}
-
-// -----
-// CHECK: function: update_nd_offset_2d:
-// CHECK: op    : %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<16x16xf32>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: op    : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}}[%{{.*}}] : memref<256x256xf32> -> !xegpu.tensor_desc<16x16xf32>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
-// CHECK-NEXT: op    : %[[T1:.*]] = xegpu.update_nd_offset %[[T0]], [%{{.*}}] : !xegpu.tensor_desc<16x16xf32>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
-func.func @update_nd_offset_2d(%arg0: memref<256x256xf32>){
-  %c0 = arith.constant 0 : index
-  %c32 = arith.constant 32 : index
-  %1 = arith.constant dense<1.000000e+00> : vector<16x16xf32>
-  %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<256x256xf32> -> !xegpu.tensor_desc<16x16xf32>
-  %2 = xegpu.update_nd_offset %0, [%c32, %c32] : !xegpu.tensor_desc<16x16xf32>
-  xegpu.store_nd %1, %2 : vector<16x16xf32>, !xegpu.tensor_desc<16x16xf32>
-  return
-}
-
-// -----
-// CHECK: function: prefetch_2d:
-// CHECK: layout for result #0: Not assigned.
-// CHECK-NEXT: op    : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}}[%{{.*}}] : memref<256x256xf16> -> !xegpu.tensor_desc<16x16xf16>
-// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
-func.func @prefetch_2d(%arg0: memref<256x256xf16>){
-  %c0 = arith.constant 0 : index
-  %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<256x256xf16> -> !xegpu.tensor_desc<16x16xf16>
-  xegpu.prefetch_nd %0 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>: !xegpu.tensor_desc<16x16xf16>
-  return
-}
-
-// -----
-// CHECK: function: prefetch_1d:
-// CHECK: op    : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}}[%{{.*}}] : memref<256xf16> -> !xegpu.tensor_desc<16xf16>
-// CHECK-NEXT: layout for result #0: lane_layout: [16], lane_data: [1]
-func.func @prefetch_1d(%arg0: memref<256xf16>){
-  %c0 = arith.constant 0 : index
-  %0 = xegpu.create_nd_tdesc %arg0[%c0] : memref<256xf16> -> !xegpu.tensor_desc<16xf16>
-  xegpu.prefetch_nd %0 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>: !xegpu.tensor_desc<16xf16>
-  return
-}


        


More information about the Mlir-commits mailing list