[Mlir-commits] [mlir] Revert "[mlir][xegpu] Add XeGPU subgroup map propagation analysis for XeGPU SIMT distribution." (PR #131364)

Charitha Saumya llvmlistbot at llvm.org
Fri Mar 14 10:32:24 PDT 2025


https://github.com/charithaintc created https://github.com/llvm/llvm-project/pull/131364

Reverts llvm/llvm-project#130240

>From 07f03ef6ab2a35ad904a62a85c45c4f7644d7e75 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <136391709+charithaintc at users.noreply.github.com>
Date: Fri, 14 Mar 2025 10:32:05 -0700
Subject: [PATCH] =?UTF-8?q?Revert=20"[mlir][xegpu]=20Add=20XeGPU=20subgrou?=
 =?UTF-8?q?p=20map=20propagation=20analysis=20for=20XeGPU=20S=E2=80=A6"?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

This reverts commit 5eb557774df637c9e581bd3008cfc6d156a61902.
---
 .../mlir/Dialect/XeGPU/Transforms/Passes.td   |  15 -
 .../Dialect/XeGPU/Transforms/CMakeLists.txt   |   1 -
 .../Transforms/XeGPUSubgroupDistribute.cpp    | 662 ------------------
 .../XeGPU/subgroup-map-propagation.mlir       | 563 ---------------
 4 files changed, 1241 deletions(-)
 delete mode 100644 mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
 delete mode 100644 mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir

diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
index 3e81f2d0ed786..1ecd6ce95322b 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
@@ -23,19 +23,4 @@ def XeGPUFoldAliasOps : Pass<"xegpu-fold-alias-ops"> {
   ];
 }
 
-def XeGPUSubgroupDistribute : Pass<"xegpu-subgroup-distribute"> {
-  let summary = "Distribute XeGPU ops to work items";
-  let description = [{
-    The pass distributes subgroup level (SIMD) XeGPU ops to work items.
-  }];
-  let dependentDialects = [
-      "memref::MemRefDialect", "xegpu::XeGPUDialect", "vector::VectorDialect"
-  ];
-  let options = [
-    Option<"printOnly", "print-analysis-only", "bool",
-         /*default=*/"false",
-         "Print the result of the subgroup map propagation analysis and exit.">
-  ];
-}
-
 #endif // MLIR_DIALECT_XEGPU_TRANSFORMS_PASSES_TD
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
index 124e904edb543..7fb64d3b97b87 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
@@ -1,6 +1,5 @@
 add_mlir_dialect_library(MLIRXeGPUTransforms
   XeGPUFoldAliasOps.cpp
-  XeGPUSubgroupDistribute.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/XeGPU
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
deleted file mode 100644
index 86e07697f437c..0000000000000
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ /dev/null
@@ -1,662 +0,0 @@
-//===- XeGPUSubgroupDistribute.cpp - XeGPU Subgroup Distribute Pass -------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
-#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
-#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
-#include "mlir/Analysis/DataFlowFramework.h"
-#include "mlir/Dialect/GPU/IR/GPUDialect.h"
-#include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "mlir/Dialect/Vector/IR/VectorOps.h"
-#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
-#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
-#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/Interfaces/FunctionInterfaces.h"
-#include "llvm/ADT/TypeSwitch.h"
-#include "llvm/Support/raw_ostream.h"
-
-namespace mlir {
-namespace xegpu {
-#define GEN_PASS_DEF_XEGPUSUBGROUPDISTRIBUTE
-#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
-} // namespace xegpu
-} // namespace mlir
-
-using namespace mlir;
-using namespace mlir::dataflow;
-
-/// HW dependent constants.
-/// TODO: These constants should be queried from the target information.
-constexpr unsigned subgroupSize = 16; // How many work items in a subgroup.
-/// If DPAS A or B operands have low precision element types they must be packed
-/// according to the following sizes.
-constexpr unsigned packedSizeInBitsForDefault =
-    16; // Minimum packing size per register for DPAS A.
-constexpr unsigned packedSizeInBitsForDpasB =
-    32; // Minimum packing size per register for DPAS B.
-
-namespace {
-
-///===----------------------------------------------------------------------===///
-/// Layout
-///===----------------------------------------------------------------------===///
-
-/// Helper class to store the ND layout of work items within a subgroup and data
-/// owned by each work item.
-struct Layout {
-  SmallVector<int64_t, 3> layout;
-  Layout() = default;
-  Layout(const Layout &other) = default;
-  Layout(std::initializer_list<int64_t> list) : layout(list) {}
-  void print(llvm::raw_ostream &os) const;
-  size_t size() const { return layout.size(); }
-  int64_t operator[](size_t idx) const;
-};
-
-void Layout::print(llvm::raw_ostream &os) const {
-  os << "[";
-  llvm::interleaveComma(layout, os);
-  os << "]";
-}
-
-int64_t Layout::operator[](size_t idx) const {
-  assert(idx < layout.size() && "Index out of bounds.");
-  return layout[idx];
-}
-
-/// WiLayout represents the layout of work items within a subgroup when it
-/// accesses some value. WiData represents the layout of data owned by each work
-/// item.
-using WiLayout = Layout;
-using WiData = Layout;
-
-///===----------------------------------------------------------------------===///
-/// SGMap
-///===----------------------------------------------------------------------===///
-
-/// Helper class for tracking the analysis state of a value. For SGPropagation,
-/// the analysis state is simply the wi_layout and wi_data of each value.
-/// Purpose of this analysis to propagate some unique layout for each value in
-/// the program starting from some known values (like DPAS, StoreNd, etc.).
-///
-/// Given this, SGMap satisifies the following properties:
-///  1) SGMap is a lattice with two states - assigned and not assigned.
-///  2) Two SGMap values are equal if they are both assigned or both not
-///  assigned. The concrete value of assigned state does not matter.
-///  3) The meet operator works as follows:
-///     - If current state is assigned, return the current state. (already
-///     a unique layout is assigned. don't change it)
-///     - Otherwise, return the other state.
-
-struct SGMap {
-private:
-  WiLayout wiLayout;
-  WiData wiData;
-
-public:
-  SGMap() = default;
-  SGMap(const SGMap &other) = default;
-  SGMap(const WiLayout &layout, const WiData &data)
-      : wiLayout(layout), wiData(data) {}
-
-  /// Two lattice values are equal if they have `some` layout. The actual
-  /// content of the layout does not matter.
-  bool operator==(const SGMap &other) const {
-    return this->isAssigned() == other.isAssigned();
-  }
-
-  static SGMap meet(const SGMap &lhs, const SGMap &rhs);
-
-  static SGMap join(const SGMap &lhs, const SGMap &rhs);
-
-  void print(raw_ostream &os) const;
-
-  bool isAssigned() const { return wiLayout.size() > 0 && wiData.size() > 0; }
-
-  SGMap getTransposedLayout(ArrayRef<int64_t> permutation) const;
-
-  const WiLayout &getLayout() const { return wiLayout; }
-  const WiData &getData() const { return wiData; }
-};
-
-void SGMap::print(raw_ostream &os) const {
-  if (isAssigned()) {
-    os << "wi_layout: ";
-    wiLayout.print(os);
-    os << ", wi_data: ";
-    wiData.print(os);
-  } else
-    os << "Not assigned.";
-}
-
-SGMap SGMap::meet(const SGMap &lhs, const SGMap &rhs) {
-  if (!lhs.isAssigned())
-    return rhs;
-  return lhs;
-}
-
-/// Since this is a backward analysis, join method is not used.
-SGMap SGMap::join(const SGMap &lhs, const SGMap &rhs) {
-  llvm_unreachable("Join should not be triggered by SGMapPropagation.");
-}
-
-/// Get the transposed layout according to the given permutation.
-SGMap SGMap::getTransposedLayout(ArrayRef<int64_t> permutation) const {
-  if (!isAssigned())
-    return {};
-  WiLayout newLayout;
-  WiData newData;
-  for (auto idx : permutation) {
-    newLayout.layout.push_back(wiLayout.layout[idx]);
-    newData.layout.push_back(wiData.layout[idx]);
-  }
-  return SGMap(newLayout, newData);
-}
-
-///===----------------------------------------------------------------------===///
-/// SGMapLattice
-///===----------------------------------------------------------------------===///
-
-/// Lattice holding the SGMap for each value.
-struct SGMapLattice : public Lattice<SGMap> {
-  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SGMapLattice)
-  using Lattice::Lattice;
-};
-
-/// Helper Functions to get default layouts. A `default layout` is a layout that
-/// is assigned to a value when the layout is not fixed by some anchor operation
-/// (like DPAS). This is the natural layout work items are arranged in a
-/// subgroup.
-
-/// Helper Function to get the default layout for uniform values like constants.
-/// For 1D vector, wi_layout is [subgroupSize] and wi_data is [1].
-/// For 2D vector, wi_layout is [1, subgroupSize] and wi_data is [1, 1].
-static SGMap getDefaultSgMap(unsigned rank) {
-  assert((rank == 1 || rank == 2) && "Expected 1D or 2D vector.");
-  if (rank == 1)
-    return SGMap(WiLayout({subgroupSize}), WiData({1}));
-  return SGMap(WiLayout({1, subgroupSize}), WiData({1, 1}));
-}
-
-/// Helper to get the default layout for a vector type.
-static SGMap getDefaultSgMap(VectorType vectorTy) {
-  /// Expecting a 1D or 2D vector.
-  assert((vectorTy.getRank() == 1 || vectorTy.getRank() == 2) &&
-         "Expected 1D or 2D vector.");
-  /// Expecting int or float element type.
-  assert(vectorTy.getElementType().isIntOrFloat() &&
-         "Expected int or float element type.");
-  /// If the rank is 1, then return default layout for 1D vector.
-  if (vectorTy.getRank() == 1)
-    return getDefaultSgMap(1);
-  /// Packing factor is determined by the element type bitwidth.
-  int packingFactor = 1;
-  auto bitwidth = vectorTy.getElementType().getIntOrFloatBitWidth();
-  if (bitwidth < packedSizeInBitsForDefault)
-    packingFactor = packedSizeInBitsForDefault / bitwidth;
-  return SGMap(WiLayout({1, subgroupSize}), WiData({1, packingFactor}));
-}
-
-/// Helper Function to get the expected layouts for DPAS operands. `wi_data` is
-/// set according to the following criteria:
-/// * For A operand, the data must be packed in minimum
-/// `packedSizeInBitsForDefault`
-/// * For B operand, the data must be packed in minimum
-/// `packedSizeInBitsForDpasB`
-static SGMap getSGMapForDPASOperand(VectorType vectorTy, unsigned operandNum) {
-  auto elementTy = vectorTy.getElementType();
-  assert(elementTy.isIntOrFloat() &&
-         "Expected int or float type in DPAS operands");
-  WiLayout layout({1, subgroupSize});
-  /// For B operand, data must be packed in minimum `packedDpasBSizeInBits` and
-  /// must have the VNNI format.
-  if (operandNum == 1 &&
-      elementTy.getIntOrFloatBitWidth() < packedSizeInBitsForDpasB) {
-    WiData data(
-        {packedSizeInBitsForDpasB / elementTy.getIntOrFloatBitWidth(), 1});
-    return SGMap(layout, data);
-  }
-  /// Otherwise, return the default layout for the vector type.
-  return getDefaultSgMap(vectorTy);
-}
-
-///===----------------------------------------------------------------------===///
-/// SGMapPropagation
-///===----------------------------------------------------------------------===///
-
-/// Backward data flow analysis to propagate the wi_layout and wi_data of each
-/// value in the program. Currently, the layouts for operands DPAS, StoreNd, and
-/// StoreScatter are fixed (known before propagation). Purpose of this analysis
-/// is to propagate those known layouts to all their producers and (other)
-/// consumers.
-class SGMapPropagation : public SparseBackwardDataFlowAnalysis<SGMapLattice> {
-private:
-  void visitDpasOp(xegpu::DpasOp dpas, ArrayRef<SGMapLattice *> operands,
-                   ArrayRef<const SGMapLattice *> results);
-
-  void visitStoreNdOp(xegpu::StoreNdOp store, ArrayRef<SGMapLattice *> operands,
-                      ArrayRef<const SGMapLattice *> results);
-
-  void visitStoreScatterOp(xegpu::StoreScatterOp storeScatter,
-                           ArrayRef<SGMapLattice *> operands,
-                           ArrayRef<const SGMapLattice *> results);
-
-  void visitLoadNdOp(xegpu::LoadNdOp load, ArrayRef<SGMapLattice *> operands,
-                     ArrayRef<const SGMapLattice *> results);
-
-  void visitLoadGatherOp(xegpu::LoadGatherOp load,
-                         ArrayRef<SGMapLattice *> operands,
-                         ArrayRef<const SGMapLattice *> results);
-
-  void visitTransposeOp(vector::TransposeOp transpose,
-                        ArrayRef<SGMapLattice *> operands,
-                        ArrayRef<const SGMapLattice *> results);
-
-  void visitVectorBitcastOp(vector::BitCastOp bitcast,
-                            ArrayRef<SGMapLattice *> operands,
-                            ArrayRef<const SGMapLattice *> results);
-
-  void visitCreateDescOp(xegpu::CreateDescOp createDesc,
-                         ArrayRef<SGMapLattice *> operands,
-                         ArrayRef<const SGMapLattice *> results);
-
-  void visitUpdateNdOffsetOp(xegpu::UpdateNdOffsetOp updateNdOffset,
-                             ArrayRef<SGMapLattice *> operands,
-                             ArrayRef<const SGMapLattice *> results);
-
-  void visitVectorMultiReductionOp(vector::MultiDimReductionOp reduction,
-                                   ArrayRef<SGMapLattice *> operands,
-                                   ArrayRef<const SGMapLattice *> results);
-
-public:
-  SGMapPropagation(DataFlowSolver &solver, SymbolTableCollection &symbolTable)
-      : SparseBackwardDataFlowAnalysis(solver, symbolTable) {}
-  using SparseBackwardDataFlowAnalysis::SparseBackwardDataFlowAnalysis;
-
-  LogicalResult visitOperation(Operation *op, ArrayRef<SGMapLattice *> operands,
-                               ArrayRef<const SGMapLattice *> results) override;
-
-  void visitBranchOperand(OpOperand &operand) override {};
-
-  void visitCallOperand(OpOperand &operand) override {};
-
-  void visitExternalCall(CallOpInterface call,
-                         ArrayRef<SGMapLattice *> operands,
-                         ArrayRef<const SGMapLattice *> results) override {};
-
-  void setToExitState(SGMapLattice *lattice) override {
-    (void)lattice->meet(SGMap());
-  }
-};
-} // namespace
-
-LogicalResult
-SGMapPropagation::visitOperation(Operation *op,
-                                 ArrayRef<SGMapLattice *> operands,
-                                 ArrayRef<const SGMapLattice *> results) {
-  TypeSwitch<Operation *>(op)
-      .Case<xegpu::DpasOp>(
-          [&](auto dpasOp) { visitDpasOp(dpasOp, operands, results); })
-      .Case<xegpu::StoreNdOp>(
-          [&](auto storeNdOp) { visitStoreNdOp(storeNdOp, operands, results); })
-      .Case<xegpu::StoreScatterOp>([&](auto storeScatterOp) {
-        visitStoreScatterOp(storeScatterOp, operands, results);
-      })
-      .Case<xegpu::LoadNdOp>(
-          [&](auto loadNdOp) { visitLoadNdOp(loadNdOp, operands, results); })
-      .Case<xegpu::LoadGatherOp>([&](auto loadGatherOp) {
-        visitLoadGatherOp(loadGatherOp, operands, results);
-      })
-      .Case<xegpu::CreateDescOp>([&](auto createDescOp) {
-        visitCreateDescOp(createDescOp, operands, results);
-      })
-      .Case<xegpu::UpdateNdOffsetOp>([&](auto updateNdOffsetOp) {
-        visitUpdateNdOffsetOp(updateNdOffsetOp, operands, results);
-      })
-      /// No need to propagate the layout to operands in CreateNdDescOp because
-      /// they are scalars (offsets, sizes, etc.).
-      .Case<xegpu::CreateNdDescOp>([&](auto createNdDescOp) {})
-      .Case<vector::TransposeOp>([&](auto transposeOp) {
-        visitTransposeOp(transposeOp, operands, results);
-      })
-      .Case<vector::BitCastOp>([&](auto bitcastOp) {
-        visitVectorBitcastOp(bitcastOp, operands, results);
-      })
-      .Case<vector::MultiDimReductionOp>([&](auto reductionOp) {
-        visitVectorMultiReductionOp(reductionOp, operands, results);
-      })
-      /// All other ops.
-      .Default([&](Operation *op) {
-        for (const SGMapLattice *r : results) {
-          for (SGMapLattice *operand : operands) {
-            /// Propagate the layout of the result to the operand.
-            if (r->getValue().isAssigned())
-              meet(operand, *r);
-          }
-        }
-      });
-  /// Add a dependency from each result to program point after the operation.
-  for (const SGMapLattice *r : results) {
-    addDependency(const_cast<SGMapLattice *>(r), getProgramPointAfter(op));
-  }
-  return success();
-}
-
-void SGMapPropagation::visitVectorMultiReductionOp(
-    vector::MultiDimReductionOp reduction, ArrayRef<SGMapLattice *> operands,
-    ArrayRef<const SGMapLattice *> results) {
-  /// The layout of the result must be present.
-  auto resultLayout = results[0]->getValue();
-  if (!resultLayout.isAssigned())
-    return;
-  /// We only consider 2D -> 1D reductions at this point.
-  assert(resultLayout.getLayout().size() == 1 &&
-         "Expected 1D layout for reduction result.");
-  /// Given that the result is 1D, the layout of the operand should be 2D with
-  /// default layout.
-  auto operandLayout = getDefaultSgMap(2);
-  propagateIfChanged(operands[0], operands[0]->meet(operandLayout));
-  /// Accumulator should have the same layout as the result.
-  propagateIfChanged(operands[1], operands[1]->meet(resultLayout));
-}
-
-/// Propagate the layout of the result tensor to the source tensor descriptor in
-/// UpdateNdOffsetOp.
-void SGMapPropagation::visitUpdateNdOffsetOp(
-    xegpu::UpdateNdOffsetOp updateNdOffset, ArrayRef<SGMapLattice *> operands,
-    ArrayRef<const SGMapLattice *> results) {
-  /// The layout of the result must be present.
-  auto resultLayout = results[0]->getValue();
-  if (!resultLayout.isAssigned())
-    return;
-  /// Propagate the layout to the source operand.
-  propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
-}
-
-/// Set the layouts for DPAS A, B, and C operands.
-void SGMapPropagation::visitDpasOp(xegpu::DpasOp dpas,
-                                   ArrayRef<SGMapLattice *> operands,
-                                   ArrayRef<const SGMapLattice *> results) {
-  auto aTy = dpas.getLhsType();
-  auto bTy = dpas.getRhsType();
-  propagateIfChanged(operands[0],
-                     operands[0]->meet(getSGMapForDPASOperand(aTy, 0)));
-  propagateIfChanged(operands[1],
-                     operands[1]->meet(getSGMapForDPASOperand(bTy, 1)));
-  if (operands.size() > 2) {
-    auto cTy = dpas.getAccType();
-    propagateIfChanged(operands[2],
-                       operands[2]->meet(getSGMapForDPASOperand(cTy, 2)));
-  }
-};
-
-/// Set the layout for the value and tensor descriptor operands in StoreNdOp.
-void SGMapPropagation::visitStoreNdOp(xegpu::StoreNdOp store,
-                                      ArrayRef<SGMapLattice *> operands,
-                                      ArrayRef<const SGMapLattice *> results) {
-  auto storeLayout = getDefaultSgMap(store.getValueType());
-  /// Both operands should have the same layout
-  for (SGMapLattice *operand : operands) {
-    propagateIfChanged(operand, operand->meet(storeLayout));
-  }
-}
-
-/// Propagate the layout of the value to the tensor descriptor operand in
-/// LoadNdOp.
-void SGMapPropagation::visitLoadNdOp(xegpu::LoadNdOp load,
-                                     ArrayRef<SGMapLattice *> operands,
-                                     ArrayRef<const SGMapLattice *> results) {
-  auto valueLayout = results[0]->getValue();
-  /// Need the layout of the value to propagate to the tensor descriptor.
-  if (!valueLayout.isAssigned())
-    return;
-  SGMap tensorDescLayout = valueLayout;
-  /// LoadNdOp has the transpose effect. However, at the stage of this analysis
-  /// this effect is not expected and should be abstracted away. Emit a warning.
-  if (auto transpose = load.getTranspose()) {
-    load.emitWarning("Transpose effect is not expected for LoadNdOp at "
-                     "SGMapPropagation stage.");
-    tensorDescLayout = valueLayout.getTransposedLayout(transpose.value());
-  }
-  /// Propagate the new layout to the tensor descriptor operand.
-  propagateIfChanged(operands[0], operands[0]->meet(tensorDescLayout));
-}
-
-/// For vector::TransposeOp, the layout of the result is transposed and
-/// propagated to the operand.
-void SGMapPropagation::visitTransposeOp(
-    vector::TransposeOp transpose, ArrayRef<SGMapLattice *> operands,
-    ArrayRef<const SGMapLattice *> results) {
-  /// Need the layout of transpose result to propagate to the operands.
-  auto resultLayout = results[0]->getValue();
-  if (!resultLayout.isAssigned())
-    return;
-  auto newLayout = resultLayout.getTransposedLayout(transpose.getPermutation());
-  /// Propagate the new layout to the vector operand.
-  propagateIfChanged(operands[0], operands[0]->meet(newLayout));
-}
-
-/// For vector::BitCastOp, the wi_data of the source layout is changed based on
-/// the bit width of the source and result types.
-void SGMapPropagation::visitVectorBitcastOp(
-    vector::BitCastOp bitcast, ArrayRef<SGMapLattice *> operands,
-    ArrayRef<const SGMapLattice *> results) {
-  /// Need the layout of bitcast result to propagate to the operands.
-  auto resultLayout = results[0]->getValue();
-  if (!resultLayout.isAssigned())
-    return;
-  auto inElemTyBitWidth =
-      bitcast.getSourceVectorType().getElementType().getIntOrFloatBitWidth();
-  auto outElemTyBitWidth =
-      bitcast.getResultVectorType().getElementType().getIntOrFloatBitWidth();
-
-  /// WiLayout does not change.
-  const WiLayout &newWiLayout = resultLayout.getLayout();
-  const WiData &currData = resultLayout.getData();
-  WiData newWiData;
-  /// It's a widening bitcast
-  if (inElemTyBitWidth < outElemTyBitWidth) {
-    auto ratio = outElemTyBitWidth / inElemTyBitWidth;
-    newWiData = resultLayout.getData()[0] == 1
-                    ? WiData({1, currData[1] * ratio})
-                    : WiData({currData[0] * ratio, 1});
-  } else {
-    /// It's a narrowing bitcast
-    auto ratio = inElemTyBitWidth / outElemTyBitWidth;
-    newWiData = resultLayout.getData()[0] == 1
-                    ? WiData({1, currData[1] / ratio})
-                    : WiData({currData[0] / ratio, 1});
-  }
-
-  propagateIfChanged(operands[0],
-                     operands[0]->meet(SGMap(newWiLayout, newWiData)));
-}
-
-/// Propagate the layout of the result to the tensor descriptor and mask
-/// operands in LoadGatherOp.
-void SGMapPropagation::visitLoadGatherOp(
-    xegpu::LoadGatherOp load, ArrayRef<SGMapLattice *> operands,
-    ArrayRef<const SGMapLattice *> results) {
-  auto valueLayout = results[0]->getValue();
-  /// Need the layout of the value to propagate to the tensor descriptor.
-  if (!valueLayout.isAssigned())
-    return;
-
-  SGMap tensorDescLayout = valueLayout;
-  if (load.getTranspose()) {
-    /// LoadGatherOp has the transpose effect. However, at the stage of this
-    /// analyis this effect is not expected and should be abstracted away. Emit
-    /// a warning.
-    load.emitWarning("Transpose effect is not expected for LoadGatherOp at "
-                     "SGMapPropagation stage.");
-    tensorDescLayout = valueLayout.getTransposedLayout({1, 0});
-  }
-  /// Mask operand should have 1D default layout.
-  auto maskLayout = getDefaultSgMap(1);
-  /// Propagate the new layout to the tensor descriptor operand.
-  propagateIfChanged(operands[0], operands[0]->meet(tensorDescLayout));
-  /// Propagate the new layout to the mask operand.
-  propagateIfChanged(operands[1], operands[1]->meet(maskLayout));
-}
-
-/// Propagate the layout of the descriptor to the vector offset operand in
-/// CreateDescOp.
-void SGMapPropagation::visitCreateDescOp(
-    xegpu::CreateDescOp createDesc, ArrayRef<SGMapLattice *> operands,
-    ArrayRef<const SGMapLattice *> results) {
-  auto descLayout = results[0]->getValue();
-  /// Need the layout of the descriptor to propagate to the operands.
-  if (!descLayout.isAssigned())
-    return;
-  /// For offset operand propagate 1D default layout.
-  SGMap layout = getDefaultSgMap(1);
-  propagateIfChanged(operands[1], operands[1]->meet(layout));
-}
-
-/// Set the layout for the value, tensor descriptor, and mask operands in the
-/// StoreScatterOp.
-void SGMapPropagation::visitStoreScatterOp(
-    xegpu::StoreScatterOp storeScatter, ArrayRef<SGMapLattice *> operands,
-    ArrayRef<const SGMapLattice *> results) {
-  /// Currently, for 2D StoreScatterOp we expect that the height dimension of
-  /// the tensor descriptor is evenly divisible by the subgroup size.
-  /// TODO: Add support for other 2D shapes.
-  auto tdescShape = storeScatter.getTensorDescType().getShape();
-  if (tdescShape.size() > 1 && tdescShape[0] % subgroupSize != 0) {
-    storeScatter.emitError("Height dimension of the tensor descriptor should "
-                           "be evenly divisible by the subgroup size.");
-    return;
-  }
-  auto valueLayout = getDefaultSgMap(storeScatter.getValueType());
-  SGMap storeScatterLayout = valueLayout;
-  if (storeScatter.getTranspose()) {
-    /// StoreScatteOp allows transpose effect. However, at the stage of this
-    /// analyis this effect is not expected and should be abstracted away. Emit
-    /// a warning.
-    storeScatter.emitWarning("Transpose effect is not expected for "
-                             "StoreScatterOp at SGMapPropagation stage.");
-    storeScatterLayout = valueLayout.getTransposedLayout({1, 0});
-  }
-  /// Propagate the value layout.
-  propagateIfChanged(operands[0], operands[0]->meet(valueLayout));
-  /// Propagate the tensor descriptor layout.
-  propagateIfChanged(operands[1], operands[1]->meet(storeScatterLayout));
-  /// Use default 1D layout for mask operand.
-  auto maskLayout = getDefaultSgMap(1);
-  propagateIfChanged(operands[2], operands[2]->meet(maskLayout));
-}
-
-namespace {
-
-///===----------------------------------------------------------------------===///
-/// RunSGMapPropagation
-///===----------------------------------------------------------------------===///
-
-/// Driver class for running the SGMapPropagation analysis.
-class RunSGMapPropagation {
-public:
-  RunSGMapPropagation(Operation *op) : target(op) {
-    SymbolTableCollection symbolTable;
-    solver.load<DeadCodeAnalysis>();
-    solver.load<SparseConstantPropagation>();
-    solver.load<SGMapPropagation>(symbolTable);
-    (void)solver.initializeAndRun(op);
-  }
-
-  SGMap getSGMap(Value val);
-
-  void printAnalysisResult(llvm::raw_ostream &os);
-
-private:
-  DataFlowSolver solver;
-  const Operation *target;
-};
-} // namespace
-
-SGMap RunSGMapPropagation::getSGMap(Value val) {
-  auto *state = solver.lookupState<SGMapLattice>(val);
-  if (!state)
-    return {};
-  return state->getValue();
-}
-
-void RunSGMapPropagation::printAnalysisResult(llvm::raw_ostream &os) {
-  auto printFunctionResult = [&](FunctionOpInterface funcOp) {
-    os << "function: " << funcOp.getName() << ":\n";
-    // Function arguments
-    for (auto arg : funcOp.getArguments()) {
-      auto layout = getSGMap(arg);
-      os << "argument: " << arg << "\n";
-      os << "sg_map  : ";
-      layout.print(os);
-      os << "\n";
-    }
-    // Function ops
-    funcOp.walk([&](Operation *op) {
-      // Skip ops that do not have results
-      if (op->getResults().empty())
-        return;
-      os << "op    : ";
-      /// For control-flow ops, print the op name only.
-      if (isa<BranchOpInterface>(op) || isa<RegionBranchOpInterface>(op))
-        os << op->getName();
-      else
-        op->print(os);
-      os << "\n";
-      /// Print the sg_map for each result.
-      for (auto [i, r] : llvm::enumerate(op->getResults())) {
-        auto layout = getSGMap(r);
-        os << "sg_map for result #" << i << ": ";
-        layout.print(os);
-        os << "\n";
-      }
-    });
-  };
-
-  SmallVector<FunctionOpInterface> funcOps;
-  if (auto modOp = dyn_cast<ModuleOp>(target)) {
-    for (auto funcOp : modOp.getOps<FunctionOpInterface>()) {
-      funcOps.push_back(funcOp);
-    }
-    /// Collect all GpuFuncOps in the module.
-    for (auto gpuModOp : modOp.getOps<gpu::GPUModuleOp>()) {
-      for (auto gpuFuncOp : gpuModOp.getOps<FunctionOpInterface>()) {
-        funcOps.push_back(gpuFuncOp);
-      }
-    }
-  }
-  /// Print the analysis result for each function.
-  for (auto funcOp : funcOps) {
-    printFunctionResult(funcOp);
-  }
-}
-
-namespace {
-struct XeGPUSubgroupDistributePass final
-    : public xegpu::impl::XeGPUSubgroupDistributeBase<
-          XeGPUSubgroupDistributePass> {
-  XeGPUSubgroupDistributePass() = default;
-  XeGPUSubgroupDistributePass(const XeGPUSubgroupDistributePass &other) =
-      default;
-  XeGPUSubgroupDistributePass(xegpu::XeGPUSubgroupDistributeOptions options)
-      : XeGPUSubgroupDistributeBase(options) {}
-  void runOnOperation() override;
-};
-} // namespace
-
-void XeGPUSubgroupDistributePass::runOnOperation() {
-  Operation *op = getOperation();
-  RunSGMapPropagation solver(op);
-
-  // Print the analysis result and exit.
-  if (printOnly) {
-    auto &os = llvm::outs();
-    solver.printAnalysisResult(os);
-    return;
-  }
-}
diff --git a/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir b/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
deleted file mode 100644
index 1ae4348af33e6..0000000000000
--- a/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
+++ /dev/null
@@ -1,563 +0,0 @@
-// RUN: mlir-opt -xegpu-subgroup-distribute='print-analysis-only=true' -split-input-file %s | FileCheck %s
-
-// CHECK: function: test_dpas_f16:
-// CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf16>' at index: 0
-// CHECK-NEXT: sg_map  : Not assigned.
-// CHECK-NEXT: argument: <block argument> of type 'memref<16x16xf16>' at index: 1
-// CHECK-NEXT: sg_map  : Not assigned.
-// CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf32>' at index: 2
-// CHECK-NEXT: sg_map  : Not assigned.
-// CHECK-NEXT: op    : %{{.*}} = arith.constant 0 : index
-// CHECK-NEXT: sg_map for result #0: Not assigned.
-// CHECK-NEXT: op    : %{{.*}} = arith.constant dense<0.000000e+00> : vector<8x16xf32>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
-// CHECK-NEXT: op    : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
-// CHECK-NEXT: op    : %[[T1:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1]
-// CHECK-NEXT: op    : %[[T2:.*]] = xegpu.load_nd %[[T0]]  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
-// CHECK-NEXT: op    : %[[T3:.*]] = xegpu.load_nd %[[T1]]  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1]
-// CHECK-NEXT: op    : %[[T4:.*]] = xegpu.dpas %[[T2]], %[[T3]], %{{.*}} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
-// CHECK-NEXT: op    : %{{.*}} = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
-func.func @test_dpas_f16(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<8x16xf32>) {
-  %c0 = arith.constant 0 : index
-  %cst = arith.constant dense<0.000000e+00> : vector<8x16xf32>
-  %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
-  %1 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
-  %2 = xegpu.load_nd %0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
-  %3 = xegpu.load_nd %1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
-  %4 = xegpu.dpas %2, %3, %cst : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
-  %5 = xegpu.create_nd_tdesc %arg2[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
-  xegpu.store_nd %4, %5  : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
-  return
-}
-
-
-// -----
-// CHECK: function: test_dpas_i8:
-// CHECK-NEXT: argument: <block argument> of type 'vector<8x32xi8>' at index: 0
-// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 2]
-// CHECK-NEXT: argument: <block argument> of type 'vector<32x16xi8>' at index: 1
-// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [4, 1]
-// CHECK-NEXT: argument: <block argument> of type 'memref<8x16xi32>' at index: 2
-// CHECK-NEXT: sg_map  : Not assigned.
-// CHECK-NEXT: op    : %{{.*}} = arith.constant 0 : index
-// CHECK-NEXT: sg_map for result #0: Not assigned.
-// CHECK-NEXT: op    : %[[T0:.*]] = xegpu.dpas %{{.*}} : vector<8x32xi8>, vector<32x16xi8> -> vector<8x16xi32>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
-// CHECK-NEXT: op    : %[[T1:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xi32> -> !xegpu.tensor_desc<8x16xi32>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
-func.func @test_dpas_i8(%arg0: vector<8x32xi8>, %arg1: vector<32x16xi8>, %arg2: memref<8x16xi32>) {
-  %c0 = arith.constant 0 : index
-  %0 = xegpu.dpas %arg0, %arg1 : vector<8x32xi8>, vector<32x16xi8> -> vector<8x16xi32>
-  %1 = xegpu.create_nd_tdesc %arg2[%c0, %c0] : memref<8x16xi32> -> !xegpu.tensor_desc<8x16xi32>
-  xegpu.store_nd %0, %1  : vector<8x16xi32>, !xegpu.tensor_desc<8x16xi32>
-  return
-}
-
-// -----
-// CHECK: function: test_load_with_transpose_effect:
-// CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf16>' at index: 0
-// CHECK-NEXT: sg_map  : Not assigned.
-// CHECK-NEXT: argument: <block argument> of type 'memref<16x16xf16>' at index: 1
-// CHECK-NEXT: sg_map  : Not assigned.
-// CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf32>' at index: 2
-// CHECK-NEXT: sg_map  : Not assigned.
-// CHECK-NEXT: op    : %{{.*}} = arith.constant 0 : index
-// CHECK-NEXT: sg_map for result #0: Not assigned.
-// CHECK-NEXT: op    : %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<8x16xf32>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
-// CHECK-NEXT: op    : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
-// CHECK-NEXT: op    : %[[T1:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [16, 1], wi_data: [1, 2]
-// CHECK-NEXT: op    : %[[T2:.*]] = xegpu.load_nd %[[T0]]  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
-// CHECK-NEXT: op    : %[[T3:.*]] = xegpu.load_nd %[[T1]] <{transpose = array<i64: 1, 0>}> : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1]
-// CHECK-NEXT: op    : %[[T4:.*]] = xegpu.dpas %[[T2]], %[[T3]], %[[CST]] : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
-// CHECK-NEXT: op    : %[[T5:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
-func.func @test_load_with_transpose_effect(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<8x16xf32>) {
-  %c0 = arith.constant 0 : index
-  %cst = arith.constant dense<0.000000e+00> : vector<8x16xf32>
-  %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
-  %1 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
-  %2 = xegpu.load_nd %0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
-  %3 = xegpu.load_nd %1 <{transpose = array<i64: 1, 0>}> : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
-  %4 = xegpu.dpas %2, %3, %cst : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
-  %5 = xegpu.create_nd_tdesc %arg2[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
-  xegpu.store_nd %4, %5  : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
-  return
-}
-
-// -----
-// CHECK: function: test_vector_transpose:
-// CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf16>' at index: 0
-// CHECK-NEXT: sg_map  : Not assigned.
-// CHECK-NEXT: argument: <block argument> of type 'memref<16x16xf16>' at index: 1
-// CHECK-NEXT: sg_map  : Not assigned.
-// CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf32>' at index: 2
-// CHECK-NEXT: sg_map  : Not assigned.
-// CHECK-NEXT: op    : %{{.*}} = arith.constant 0 : index
-// CHECK-NEXT: sg_map for result #0: Not assigned.
-// CHECK-NEXT: op    : %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<8x16xf32>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
-// CHECK-NEXT: op    : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
-// CHECK-NEXT: op    : %[[T1:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [16, 1], wi_data: [1, 2]
-// CHECK-NEXT: op    : %[[T2:.*]] = xegpu.load_nd %[[T0]]  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
-// CHECK-NEXT: op    : %[[T3:.*]] = xegpu.load_nd %[[T1]]  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [16, 1], wi_data: [1, 2]
-// CHECK-NEXT: op    : %[[T4:.*]] = vector.transpose %[[T3]], [1, 0] : vector<16x16xf16> to vector<16x16xf16>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1]
-// CHECK-NEXT: op    : %[[T5:.*]] = xegpu.dpas %[[T2]], %[[T4]], %[[CST]] : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
-// CHECK-NEXT: op    : %[[T6:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
-func.func @test_vector_transpose(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<8x16xf32>) {
-  %c0 = arith.constant 0 : index
-  %cst = arith.constant dense<0.000000e+00> : vector<8x16xf32>
-  %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
-  %1 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
-  %2 = xegpu.load_nd %0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
-  %3 = xegpu.load_nd %1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
-  %4 = vector.transpose %3, [1, 0] : vector<16x16xf16> to vector<16x16xf16>
-  %5 = xegpu.dpas %2, %4, %cst : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
-  %6 = xegpu.create_nd_tdesc %arg2[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
-  xegpu.store_nd %5, %6  : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
-  return
-}
-
-// -----
-// CHECK: function: test_extf_truncf:
-// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<8x16xf16>' at index: 0
-// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
-// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<16x16xf16>' at index: 1
-// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [2, 1]
-// CHECK-NEXT: op    : %[[T0:.*]] = xegpu.load_nd %{{.*}}  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
-// CHECK-NEXT: op    : %[[T1:.*]] = xegpu.load_nd %{{.*}}  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1]
-// CHECK-NEXT: op    : %[[T2:.*]] = arith.extf %[[T1]] : vector<16x16xf16> to vector<16x16xf32>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1]
-// CHECK-NEXT: op    : %[[T3:.*]] = arith.truncf %[[T2]] : vector<16x16xf32> to vector<16x16xf16>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1]
-// CHECK-NEXT: op    : %[[T4:.*]] = xegpu.dpas %[[T0]], %[[T3]] : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
-// CHECK-NEXT: sg_map for result #0: Not assigned.
-func.func @test_extf_truncf(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>) -> vector<8x16xf32> {
-  %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
-  %1 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
-  %2 = arith.extf %1 : vector<16x16xf16> to vector<16x16xf32>
-  %3 = arith.truncf %2 : vector<16x16xf32> to vector<16x16xf16>
-  %4 = xegpu.dpas %0, %3 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
-  return %4 : vector<8x16xf32>
-}
-
-// -----
-// CHECK: function: test_load_gather_with_transpose_effect:
-// CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf16>' at index: 0
-// CHECK-NEXT: sg_map  : Not assigned.
-// CHECK-NEXT: argument: <block argument> of type 'memref<256xf16>' at index: 1
-// CHECK-NEXT: sg_map  : Not assigned.
-// CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf32>' at index: 2
-// CHECK-NEXT: sg_map  : Not assigned.
-// CHECK-NEXT: op    : %{{.*}} = arith.constant 0 : index
-// CHECK-NEXT: sg_map for result #0: Not assigned.
-// CHECK-NEXT: op    : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
-// CHECK-NEXT: op    : %[[T1:.*]] = xegpu.load_nd %[[T0]]  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
-// CHECK-NEXT: op    : %[[CST:.*]] = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
-// CHECK-NEXT: op    : %[[CST0:.*]] = arith.constant dense<true> : vector<16xi1>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
-// CHECK-NEXT: op    : %[[T2:.*]] = xegpu.create_tdesc %{{.*}}, %[[CST]] : memref<256xf16>, vector<16xindex> -> !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [16, 1], wi_data: [1, 2]
-// CHECK-NEXT: op    : %[[T3:.*]] = xegpu.load %[[T2]], %[[CST0]] <{transpose}> : !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>>, vector<16xi1> -> vector<16x16xf16>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1]
-// CHECK-NEXT: op    : %[[T4:.*]] = xegpu.dpas %[[T1]], %[[T3]] : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
-// CHECK-NEXT: op    : %[[T5:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
-func.func @test_load_gather_with_transpose_effect(%arg0: memref<8x16xf16>, %arg1: memref<256xf16>, %arg2: memref<8x16xf32>) {
-  %c0 = arith.constant 0 : index
-  %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
-  %1 = xegpu.load_nd %0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
-  %cst = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
-  %cst_0 = arith.constant dense<true> : vector<16xi1>
-  %2 = xegpu.create_tdesc %arg1, %cst : memref<256xf16>, vector<16xindex> -> !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>>
-  %3 = xegpu.load %2, %cst_0 <{transpose}> : !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>>, vector<16xi1> -> vector<16x16xf16>
-  %4 = xegpu.dpas %1, %3 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
-  %5 = xegpu.create_nd_tdesc %arg2[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
-  xegpu.store_nd %4, %5  : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
-  return
-}
-
-// -----
-// CHECK: function: test_load_gather_1d:
-// CHECK: argument: <block argument> of type 'memref<256xf32>' at index: 0
-// CHECK-NEXT: sg_map  : Not assigned.
-// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<16xf32>' at index: 1
-// CHECK-NEXT: sg_map  : wi_layout: [16], wi_data: [1]
-// CHECK-NEXT: op    : %[[CST:.*]] = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
-// CHECK-NEXT: op    : %[[CST0:.*]] = arith.constant dense<true> : vector<16xi1>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
-// CHECK-NEXT: op    : %[[T0:.*]] = xegpu.create_tdesc %{{.*}}, %[[CST]] : memref<256xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
-// CHECK-NEXT: op    : %[[T1]] = xegpu.load %[[T0]], %[[CST0]]  : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> -> vector<16xf32>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
-func.func @test_load_gather_1d(%arg0: memref<256xf32>, %arg1: !xegpu.tensor_desc<16xf32>) {
-  %cst = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
-  %cst_0 = arith.constant dense<true> : vector<16xi1>
-  %0 = xegpu.create_tdesc %arg0, %cst : memref<256xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
-  %1 = xegpu.load %0, %cst_0  : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> -> vector<16xf32>
-  xegpu.store_nd %1, %arg1  : vector<16xf32>, !xegpu.tensor_desc<16xf32>
-  return
-}
-
-// -----
-// CHECK: function: test_store_scatter_with_transpose_effect:
-// CHECK-NEXT: argument: <block argument> of type 'memref<128xf32>' at index: 0
-// CHECK-NEXT: sg_map  : Not assigned.
-// CHECK-NEXT: op    : %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<8x16xf32>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
-// CHECK-NEXT: op    : %[[CST0:.*]] = arith.constant dense<true> : vector<16xi1>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
-// CHECK-NEXT: op    : %[[CST1:.*]] = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
-// CHECK-NEXT: op    : %[[T0:.*]] = xegpu.create_tdesc %{{.*}}, %[[CST1]] : memref<128xf32>, vector<16xindex> -> !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size = 8 : i64>>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [16, 1], wi_data: [1, 1]
-func.func @test_store_scatter_with_transpose_effect(%arg0: memref<128xf32>) {
-  %cst = arith.constant dense<1.000000e+00> : vector<8x16xf32>
-  %cst_0 = arith.constant dense<true> : vector<16xi1>
-  %cst_1 = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
-  %0 = xegpu.create_tdesc %arg0, %cst_1 : memref<128xf32>, vector<16xindex> -> !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size = 8 : i64>>
-  xegpu.store %cst, %0, %cst_0 <{transpose}> : vector<8x16xf32>, !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size = 8 : i64>>, vector<16xi1>
-  return
-}
-
-// -----
-// CHECK: function: test_store_scatter_1d:
-// CHECK-NEXT: argument: <block argument> of type 'vector<16xf32>' at index: 0
-// CHECK-NEXT: sg_map  : wi_layout: [16], wi_data: [1]
-// CHECK-NEXT: argument: <block argument> of type 'memref<256xf32>' at index: 1
-// CHECK-NEXT: sg_map  : Not assigned.
-// CHECK-NEXT: op    : %[[CST:.*]] = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
-// CHECK-NEXT: op    : %[[CST1:.*]] = arith.constant dense<true> : vector<16xi1>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
-// CHECK-NEXT: op    : %[[T0:.*]] = xegpu.create_tdesc %{{.*}}, %[[CST]] : memref<256xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
-func.func @test_store_scatter_1d(%arg0: vector<16xf32>, %arg1: memref<256xf32>) {
-  %cst = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
-  %cst_0 = arith.constant dense<true> : vector<16xi1>
-  %0 = xegpu.create_tdesc %arg1, %cst : memref<256xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
-  xegpu.store %arg0, %0, %cst_0  : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1>
-  return
-}
-
-// -----
-// CHECK: function: test_vector_bitcast_i16_to_i8:
-// CHECK-NEXT: argument: <block argument> of type 'memref<8x16xi16>' at index: 0
-// CHECK-NEXT: sg_map  : Not assigned.
-// CHECK-NEXT: argument: <block argument> of type 'memref<32x16xi8>' at index: 1
-// CHECK-NEXT: sg_map  : Not assigned.
-// CHECK-NEXT: argument: <block argument> of type 'memref<8x16xi32>' at index: 2
-// CHECK-NEXT: sg_map  : Not assigned.
-// CHECK-NEXT: op    : %{{.*}} = arith.constant 0 : index
-// CHECK-NEXT: sg_map for result #0: Not assigned.
-// CHECK-NEXT: op    : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xi16> -> !xegpu.tensor_desc<8x16xi16>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
-// CHECK-NEXT: op    : %[[T1:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<32x16xi8> -> !xegpu.tensor_desc<32x16xi8>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [4, 1]
-// CHECK-NEXT: op    : %[[T2:.*]] = xegpu.load_nd %[[T0]]  : !xegpu.tensor_desc<8x16xi16> -> vector<8x16xi16>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
-// CHECK-NEXT: op    : %[[T3:.*]] = xegpu.load_nd %[[T1]]  : !xegpu.tensor_desc<32x16xi8> -> vector<32x16xi8>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [4, 1]
-// CHECK-NEXT: op    : %[[T4:.*]] = vector.bitcast %[[T2]] : vector<8x16xi16> to vector<8x32xi8>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 2]
-// CHECK-NEXT: op    : %[[T5:.*]] = xegpu.dpas %[[T4]], %[[T3]] : vector<8x32xi8>, vector<32x16xi8> -> vector<8x16xi32>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
-// CHECK-NEXT: op    : %[[T6:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xi32> -> !xegpu.tensor_desc<8x16xi32>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
-func.func @test_vector_bitcast_i16_to_i8(%arg0: memref<8x16xi16>, %arg1: memref<32x16xi8>, %arg2: memref<8x16xi32>) {
-  %c0 = arith.constant 0 : index
-  %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xi16> -> !xegpu.tensor_desc<8x16xi16>
-  %1 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<32x16xi8> -> !xegpu.tensor_desc<32x16xi8>
-  %2 = xegpu.load_nd %0  : !xegpu.tensor_desc<8x16xi16> -> vector<8x16xi16>
-  %3 = xegpu.load_nd %1  : !xegpu.tensor_desc<32x16xi8> -> vector<32x16xi8>
-  %4 = vector.bitcast %2 : vector<8x16xi16> to vector<8x32xi8>
-  %5 = xegpu.dpas %4, %3 : vector<8x32xi8>, vector<32x16xi8> -> vector<8x16xi32>
-  %6 = xegpu.create_nd_tdesc %arg2[%c0, %c0] : memref<8x16xi32> -> !xegpu.tensor_desc<8x16xi32>
-  xegpu.store_nd %5, %6  : vector<8x16xi32>, !xegpu.tensor_desc<8x16xi32>
-  return
-}
-
-// -----
-// CHECK: function: test_vector_bitcast_i8_to_f16:
-// CHECK-NEXT: argument: <block argument> of type 'memref<8x32xi8>' at index: 0
-// CHECK-NEXT: sg_map  : Not assigned.
-// CHECK-NEXT: argument: <block argument> of type 'memref<16x32xi8>' at index: 1
-// CHECK-NEXT: sg_map  : Not assigned.
-// CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf32>' at index: 2
-// CHECK-NEXT: sg_map  : Not assigned.
-// CHECK-NEXT: op    : %{{.*}} = arith.constant 0 : index
-// CHECK-NEXT: sg_map for result #0: Not assigned.
-// CHECK-NEXT: op    : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x32xi8> -> !xegpu.tensor_desc<8x32xi8>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 2]
-// CHECK-NEXT: op    : %[[T1:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<16x32xi8> -> !xegpu.tensor_desc<16x32xi8>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [4, 1]
-// CHECK-NEXT: op    : %[[T2:.*]] = xegpu.load_nd %[[T0]]  : !xegpu.tensor_desc<8x32xi8> -> vector<8x32xi8>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 2]
-// CHECK-NEXT: op    : %[[T3:.*]] = xegpu.load_nd %[[T1]]  : !xegpu.tensor_desc<16x32xi8> -> vector<16x32xi8>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [4, 1]
-// CHECK-NEXT: op    : %[[T4:.*]] = vector.bitcast %[[T2]] : vector<8x32xi8> to vector<8x16xf16>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
-// CHECK-NEXT: op    : %[[T5:.*]] = vector.bitcast %[[T3]] : vector<16x32xi8> to vector<16x16xf16>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1]
-// CHECK-NEXT: op    : %[[T6:.*]] = xegpu.dpas %[[T4]], %[[T5]] : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
-// CHECK-NEXT: op    : %[[T7:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
-func.func @test_vector_bitcast_i8_to_f16(%arg0: memref<8x32xi8>, %arg1: memref<16x32xi8>, %arg2: memref<8x16xf32>) {
-  %c0 = arith.constant 0 : index
-  %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x32xi8> -> !xegpu.tensor_desc<8x32xi8>
-  %1 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<16x32xi8> -> !xegpu.tensor_desc<16x32xi8>
-  %2 = xegpu.load_nd %0  : !xegpu.tensor_desc<8x32xi8> -> vector<8x32xi8>
-  %3 = xegpu.load_nd %1  : !xegpu.tensor_desc<16x32xi8> -> vector<16x32xi8>
-  %4 = vector.bitcast %2 : vector<8x32xi8> to vector<8x16xf16>
-  %5 = vector.bitcast %3 : vector<16x32xi8> to vector<16x16xf16>
-  %6 = xegpu.dpas %4, %5 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
-  %7 = xegpu.create_nd_tdesc %arg2[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
-  xegpu.store_nd %6, %7  : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
-  return
-}
-
-// -----
-// CHECK: function: test_binary_op_one_use:
-// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<8x16xf16>' at index: 0
-// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
-// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<16x16xf16>' at index: 1
-// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [2, 1]
-// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<8x16xf32>' at index: 2
-// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
-// CHECK-NEXT: op    : %[[T0:.*]] = xegpu.load_nd %{{.*}}  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
-// CHECK-NEXT: op    : %[[T1:.*]] = xegpu.load_nd %{{.*}}  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1]
-// CHECK-NEXT: op    : %[[T2:.*]] = xegpu.load_nd %{{.*}}  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1]
-// CHECK-NEXT: op    : %[[T3:.*]] = arith.addf %[[T1]], %[[T2]] : vector<16x16xf16>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1]
-// CHECK-NEXT: op    : %[[T4:.*]] = xegpu.dpas %[[T0]], %[[T3]] : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
-func.func @test_binary_op_one_use(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: !xegpu.tensor_desc<8x16xf32>) {
-  %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
-  %1 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
-  %2 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
-  %3 = arith.addf %1, %2 : vector<16x16xf16>
-  %4 = xegpu.dpas %0, %3 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
-  xegpu.store_nd %4, %arg2  : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
-  return
-}
-
-// -----
-// CHECK: function: test_binary_op_multiple_uses:
-// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<8x16xf16>' at index: 0
-// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
-// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<16x16xf16>' at index: 1
-// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
-// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<8x16xf32>' at index: 2
-// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
-// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<16x16xf16>' at index: 3
-// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
-// CHECK-NEXT: op    : %[[T0:.*]] = xegpu.load_nd %{{.*}}  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
-// CHECK-NEXT: op    : %[[T1:.*]] = xegpu.load_nd %{{.*}}  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
-// CHECK-NEXT: op    : %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<16x16xf16>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
-// CHECK-NEXT: op    : %[[T2:.*]] = arith.addf %[[T1]], %[[CST]] : vector<16x16xf16>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
-// CHECK-NEXT: op    : %[[T3:.*]] = xegpu.dpas %[[T0]], %[[T2]] : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
-func.func @test_binary_op_multiple_uses(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: !xegpu.tensor_desc<8x16xf32>, %arg3: !xegpu.tensor_desc<16x16xf16>) {
-  %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
-  %1 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
-  %cst = arith.constant dense<1.000000e+00> : vector<16x16xf16>
-  %2 = arith.addf %1, %cst : vector<16x16xf16>
-  %3 = xegpu.dpas %0, %2 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
-  xegpu.store_nd %3, %arg2  : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
-  xegpu.store_nd %2, %arg3  : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16>
-  return
-}
-
-// -----
-// CHECK: function: test_for_op:
-// CHECK-NEXT: argument: <block argument> of type 'memref<8x128xf16>' at index: 0
-// CHECK-NEXT: sg_map  : Not assigned.
-// CHECK-NEXT: argument: <block argument> of type 'memref<128x16xf16>' at index: 1
-// CHECK-NEXT: sg_map  : Not assigned.
-// CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf32>' at index: 2
-// CHECK-NEXT: sg_map  : Not assigned.
-// CHECK-NEXT: op    : %{{.*}} = arith.constant 0 : index
-// CHECK-NEXT: sg_map for result #0: Not assigned.
-// CHECK-NEXT: op    : %{{.*}} = arith.constant 128 : index
-// CHECK-NEXT: sg_map for result #0: Not assigned.
-// CHECK-NEXT: op    : %{{.*}} = arith.constant 16 : index
-// CHECK-NEXT: sg_map for result #0: Not assigned.
-// CHECK-NEXT: op    : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x128xf16> -> !xegpu.tensor_desc<8x16xf16>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
-// CHECK-NEXT: op    : %[[T1:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<128x16xf16> -> !xegpu.tensor_desc<16x16xf16>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1]
-// CHECK-NEXT: op    : %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<8x16xf32>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
-// CHECK-NEXT: op    : %[[T4:.*]] = xegpu.load_nd %{{.*}}  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
-// CHECK-NEXT: op    : %[[T5:.*]] = xegpu.load_nd %{{.*}}  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1]
-// CHECK-NEXT: op    : %[[T6:.*]] = xegpu.dpas %[[T4]], %[[T5]], %{{.*}} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
-// CHECK-NEXT: op    : %[[T7:.*]] = xegpu.update_nd_offset %{{.*}} : !xegpu.tensor_desc<8x16xf16>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
-// CHECK-NEXT: op    : %[[T8:.*]] = xegpu.update_nd_offset %{{.*}} : !xegpu.tensor_desc<16x16xf16>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1]
-// CHECK-NEXT: op    : scf.for
-// CHECK-NEXT: sg_map for result #0: Not assigned.
-// CHECK-NEXT: sg_map for result #1: Not assigned.
-// CHECK-NEXT: sg_map for result #2: wi_layout: [1, 16], wi_data: [1, 1]
-// CHECK-NEXT: op    : %[[T3:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
-func.func @test_for_op(%arg0: memref<8x128xf16>, %arg1: memref<128x16xf16>, %arg2: memref<8x16xf32>) {
-  %c0 = arith.constant 0 : index
-  %c128 = arith.constant 128 : index
-  %c16 = arith.constant 16 : index
-  %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x128xf16> -> !xegpu.tensor_desc<8x16xf16>
-  %1 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<128x16xf16> -> !xegpu.tensor_desc<16x16xf16>
-  %cst = arith.constant dense<0.000000e+00> : vector<8x16xf32>
-  %2:3 = scf.for %arg3 = %c0 to %c128 step %c16 iter_args(%arg4 = %0, %arg5 = %1, %arg6 = %cst) -> (!xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<16x16xf16>, vector<8x16xf32>) {
-    %4 = xegpu.load_nd %arg4  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
-    %5 = xegpu.load_nd %arg5  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
-    %6 = xegpu.dpas %4, %5, %arg6 : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
-    %7 = xegpu.update_nd_offset %arg4, [%c0, %c16] : !xegpu.tensor_desc<8x16xf16>
-    %8 = xegpu.update_nd_offset %arg5, [%c16, %c0] : !xegpu.tensor_desc<16x16xf16>
-    scf.yield %7, %8, %6 : !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<16x16xf16>, vector<8x16xf32>
-  }
-  %3 = xegpu.create_nd_tdesc %arg2[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
-  xegpu.store_nd %2#2, %3  : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
-  return
-}
-
-// -----
-// CHECK: function: test_if_single_use:
-// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<8x16xf16>' at index: 0
-// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
-// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<16x16xf16>' at index: 1
-// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [2, 1]
-// CHECK-NEXT: argument: <block argument> of type 'i1' at index: 2
-// CHECK-NEXT: sg_map  : Not assigned.
-// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<8x16xf32>' at index: 3
-// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
-// CHECK-NEXT: op    : %[[T0:.*]] = xegpu.load_nd %{{.*}}  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
-// CHECK-NEXT: op    : %[[T3:.*]] = xegpu.load_nd %{{.*}}  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1]
-// CHECK-NEXT: op    : %[[T4:.*]] = xegpu.load_nd %{{.*}}  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1]
-// CHECK-NEXT: op    : scf.if
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1]
-// CHECK-NEXT: op    : %[[T2:.*]] = xegpu.dpas %[[T0]], %{{.*}} : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
-func.func @test_if_single_use(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: i1, %arg3: !xegpu.tensor_desc<8x16xf32>) {
-  %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
-  %1 = scf.if %arg2 -> (vector<16x16xf16>) {
-    %3 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
-    scf.yield %3 : vector<16x16xf16>
-  } else {
-    %3 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
-    scf.yield %3 : vector<16x16xf16>
-  }
-  %2 = xegpu.dpas %0, %1 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
-  xegpu.store_nd %2, %arg3  : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
-  return
-}
-
-// -----
-// CHECK: function: test_if_multiple_uses:
-// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<8x16xf16>' at index: 0
-// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
-// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<16x16xf16>' at index: 1
-// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
-// CHECK-NEXT: argument: <block argument> of type 'i1' at index: 2
-// CHECK-NEXT: sg_map  : Not assigned.
-// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<8x16xf32>' at index: 3
-// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
-// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<16x16xf16>' at index: 4
-// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
-// CHECK-NEXT: op    : %[[T0:.*]] = xegpu.load_nd %{{.*}}  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
-// CHECK-NEXT: op    : %[[T3:.*]] = xegpu.load_nd %{{.*}}  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
-// CHECK-NEXT: op    : %[[T4:.*]] = xegpu.load_nd %{{.*}}  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
-// CHECK-NEXT: op    : scf.if
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
-// CHECK-NEXT: op    : %[[T2:.*]] = xegpu.dpas %[[T0]], %{{.*}} : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
-func.func @test_if_multiple_uses(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: i1, %arg3: !xegpu.tensor_desc<8x16xf32>, %arg4: !xegpu.tensor_desc<16x16xf16>) {
-  %0 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
-  %1 = scf.if %arg2 -> (vector<16x16xf16>) {
-    %3 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
-    scf.yield %3 : vector<16x16xf16>
-  } else {
-    %3 = xegpu.load_nd %arg1  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
-    scf.yield %3 : vector<16x16xf16>
-  }
-  %2 = xegpu.dpas %0, %1 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
-  xegpu.store_nd %2, %arg3  : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
-  xegpu.store_nd %1, %arg4  : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16>
-  return
-}
-
-// -----
-// CHECK: function: test_vector_outer_reduction:
-// CHECK-NEXT: argument: <block argument> of type 'vector<16x16xf32>' at index: 0
-// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
-// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<16xf32>' at index: 1
-// CHECK-NEXT: sg_map  : wi_layout: [16], wi_data: [1]
-// CHECK-NEXT: op    : %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<16xf32>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
-// CHECK-NEXT: op    : %[[T0:.*]] = vector.multi_reduction <add>, %{{.*}}, %[[CST]] [0] : vector<16x16xf32> to vector<16xf32>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
-func.func @test_vector_outer_reduction(%arg0: vector<16x16xf32>, %arg1: !xegpu.tensor_desc<16xf32>) {
-  %cst = arith.constant dense<0.000000e+00> : vector<16xf32>
-  %0 = vector.multi_reduction <add>, %arg0, %cst [0] : vector<16x16xf32> to vector<16xf32>
-  xegpu.store_nd %0, %arg1  : vector<16xf32>, !xegpu.tensor_desc<16xf32>
-  return
-}
-
-// -----
-// CHECK: function: test_vector_inner_reduction:
-// CHECK-NEXT: argument: <block argument> of type 'vector<16x16xf32>' at index: 0
-// CHECK-NEXT: sg_map  : wi_layout: [1, 16], wi_data: [1, 1]
-// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<16xf32>' at index: 1
-// CHECK-NEXT: sg_map  : wi_layout: [16], wi_data: [1]
-// CHECK-NEXT: op    : %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<16xf32>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
-// CHECK-NEXT: op    : %[[T0:.*]] = vector.multi_reduction <add>, %{{.*}}, %[[CST]] [1] : vector<16x16xf32> to vector<16xf32>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
-func.func @test_vector_inner_reduction(%arg0: vector<16x16xf32>, %arg1: !xegpu.tensor_desc<16xf32>) {
-  %cst = arith.constant dense<0.000000e+00> : vector<16xf32>
-  %0 = vector.multi_reduction <add>, %arg0, %cst [1] : vector<16x16xf32> to vector<16xf32>
-  xegpu.store_nd %0, %arg1  : vector<16xf32>, !xegpu.tensor_desc<16xf32>
-  return
-}



More information about the Mlir-commits mailing list