[Mlir-commits] [mlir] [MLIR][XeGPU][Draft] Xegpu distribution patterns for load_nd, store_nd, and create_nd_tdesc. (PR #119783)

Charitha Saumya llvmlistbot at llvm.org
Thu Dec 12 14:45:17 PST 2024


https://github.com/charithaintc updated https://github.com/llvm/llvm-project/pull/119783

>From dab68412dc7b133baec0b2ec17428f3378c65cc8 Mon Sep 17 00:00:00 2001
From: Petr Kurapov <petr.a.kurapov at intel.com>
Date: Mon, 9 Dec 2024 20:25:50 +0000
Subject: [PATCH 1/7] [MLIR] Create GPU utils library & move distribution utils

---
 .../mlir/Conversion/GPUCommon/GPUCommonPass.h |   2 +-
 .../mlir/Dialect/GPU/Transforms/Passes.h      |   2 +-
 .../Dialect/GPU/Utils/DistributionUtils.h     |  57 +++++++
 .../{Transforms/Utils.h => Utils/GPUUtils.h}  |   0
 mlir/lib/Dialect/GPU/CMakeLists.txt           |   3 +-
 .../GPU/Transforms/AsyncRegionRewriter.cpp    |   2 +-
 .../GPU/Transforms/KernelOutlining.cpp        |   2 +-
 .../GPU/Transforms/SubgroupReduceLowering.cpp |   2 +-
 mlir/lib/Dialect/GPU/Utils/CMakeLists.txt     |  14 ++
 .../Dialect/GPU/Utils/DistributionUtils.cpp   | 149 ++++++++++++++++++
 .../GPU/{Transforms => Utils}/Utils.cpp       |   2 +-
 .../Dialect/Vector/Transforms/CMakeLists.txt  |   1 +
 .../Vector/Transforms/VectorDistribute.cpp    | 139 +---------------
 13 files changed, 230 insertions(+), 145 deletions(-)
 create mode 100644 mlir/include/mlir/Dialect/GPU/Utils/DistributionUtils.h
 rename mlir/include/mlir/Dialect/GPU/{Transforms/Utils.h => Utils/GPUUtils.h} (100%)
 create mode 100644 mlir/lib/Dialect/GPU/Utils/CMakeLists.txt
 create mode 100644 mlir/lib/Dialect/GPU/Utils/DistributionUtils.cpp
 rename mlir/lib/Dialect/GPU/{Transforms => Utils}/Utils.cpp (96%)

diff --git a/mlir/include/mlir/Conversion/GPUCommon/GPUCommonPass.h b/mlir/include/mlir/Conversion/GPUCommon/GPUCommonPass.h
index 5f40315a849094..094360e75ab617 100644
--- a/mlir/include/mlir/Conversion/GPUCommon/GPUCommonPass.h
+++ b/mlir/include/mlir/Conversion/GPUCommon/GPUCommonPass.h
@@ -8,7 +8,7 @@
 #ifndef MLIR_CONVERSION_GPUCOMMON_GPUCOMMONPASS_H_
 #define MLIR_CONVERSION_GPUCOMMON_GPUCOMMONPASS_H_
 
-#include "mlir/Dialect/GPU/Transforms/Utils.h"
+#include "mlir/Dialect/GPU/Utils/GPUUtils.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/Types.h"
diff --git a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
index 8eb711962583da..eb51d477e23f86 100644
--- a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
@@ -13,7 +13,7 @@
 #ifndef MLIR_DIALECT_GPU_TRANSFORMS_PASSES_H_
 #define MLIR_DIALECT_GPU_TRANSFORMS_PASSES_H_
 
-#include "Utils.h"
+#include "mlir/Dialect/GPU/Utils/GPUUtils.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
diff --git a/mlir/include/mlir/Dialect/GPU/Utils/DistributionUtils.h b/mlir/include/mlir/Dialect/GPU/Utils/DistributionUtils.h
new file mode 100644
index 00000000000000..6efd2326971982
--- /dev/null
+++ b/mlir/include/mlir/Dialect/GPU/Utils/DistributionUtils.h
@@ -0,0 +1,57 @@
+//===- DistributionUtils.h - Distribution Utilities -------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_GPU_UTILS_DISTRIBUTIONUTILS_H_
+#define MLIR_DIALECT_GPU_UTILS_DISTRIBUTIONUTILS_H_
+
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/IR/PatternMatch.h"
+
+#include <functional>
+#include <utility>
+
+namespace mlir {
+namespace gpu {
+/// Return a value yielded by `warpOp` which satisfies the filter lambda
+/// condition and is not dead. Returns null if no yielded value qualifies.
+OpOperand *getWarpResult(WarpExecuteOnLane0Op warpOp,
+                         const std::function<bool(Operation *)> &fn);
+
+/// Helper to create a new WarpExecuteOnLane0Op with different signature. The
+/// body of `warpOp` is moved into the new op; `warpOp` itself is not erased.
+WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns(
+    RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
+    ValueRange newYieldedValues, TypeRange newReturnTypes);
+
+/// Helper to create a new WarpExecuteOnLane0Op region with extra outputs.
+/// `indices` return the index of each new output. Values that are already
+/// yielded are not duplicated; their existing index is reported instead.
+WarpExecuteOnLane0Op moveRegionToNewWarpOpAndAppendReturns(
+    RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
+    ValueRange newYieldedValues, TypeRange newReturnTypes,
+    llvm::SmallVector<size_t> &indices);
+
+/// Helper to know if an op can be hoisted out of the region: all operands must
+/// satisfy `definedOutside`, and the op must be memory-effect free and hold no
+/// regions.
+bool canBeHoisted(Operation *op, function_ref<bool(Value)> definedOutside);
+
+/// Delinearize the given `laneId` into multiple dimensions, where each
+/// dimension's size is determined by `originalShape` and `distributedShape`
+/// together. This function expects the total numbers of threads needed for
+/// distribution is equal to `warpSize`. Returns true and updates
+/// `delinearizedIds` if so.
+bool delinearizeLaneId(OpBuilder &builder, Location loc,
+                       ArrayRef<int64_t> originalShape,
+                       ArrayRef<int64_t> distributedShape, int64_t warpSize,
+                       Value laneId, SmallVectorImpl<Value> &delinearizedIds);
+
+} // namespace gpu
+} // namespace mlir
+
+#endif // MLIR_DIALECT_GPU_UTILS_DISTRIBUTIONUTILS_H_
diff --git a/mlir/include/mlir/Dialect/GPU/Transforms/Utils.h b/mlir/include/mlir/Dialect/GPU/Utils/GPUUtils.h
similarity index 100%
rename from mlir/include/mlir/Dialect/GPU/Transforms/Utils.h
rename to mlir/include/mlir/Dialect/GPU/Utils/GPUUtils.h
diff --git a/mlir/lib/Dialect/GPU/CMakeLists.txt b/mlir/lib/Dialect/GPU/CMakeLists.txt
index a59645480aba21..1026e9b509332a 100644
--- a/mlir/lib/Dialect/GPU/CMakeLists.txt
+++ b/mlir/lib/Dialect/GPU/CMakeLists.txt
@@ -40,7 +40,6 @@ add_mlir_dialect_library(MLIRGPUTransforms
   Transforms/ShuffleRewriter.cpp
   Transforms/SPIRVAttachTarget.cpp
   Transforms/SubgroupReduceLowering.cpp
-  Transforms/Utils.cpp
   
   OBJECT
 
@@ -59,6 +58,7 @@ add_mlir_dialect_library(MLIRGPUTransforms
   MLIRDataLayoutInterfaces
   MLIRExecutionEngineUtils
   MLIRGPUDialect
+  MLIRGPUUtils
   MLIRIR
   MLIRIndexDialect
   MLIRLLVMDialect
@@ -76,3 +76,4 @@ add_mlir_dialect_library(MLIRGPUTransforms
 
 add_subdirectory(TransformOps)
 add_subdirectory(Pipelines)
+add_subdirectory(Utils)
diff --git a/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp b/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp
index b2fa3a99c53fc3..41a5e39e55064e 100644
--- a/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp
@@ -16,7 +16,7 @@
 #include "mlir/Dialect/Async/IR/Async.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
-#include "mlir/Dialect/GPU/Transforms/Utils.h"
+#include "mlir/Dialect/GPU/Utils/GPUUtils.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/IRMapping.h"
 #include "mlir/IR/PatternMatch.h"
diff --git a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
index ba0c80c50211e3..a6a36848b5635d 100644
--- a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
@@ -18,7 +18,7 @@
 #include "mlir/Dialect/DLTI/DLTI.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
-#include "mlir/Dialect/GPU/Transforms/Utils.h"
+#include "mlir/Dialect/GPU/Utils/GPUUtils.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributes.h"
diff --git a/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp b/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp
index 185f824351a230..43eff3eddcc491 100644
--- a/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp
@@ -13,7 +13,7 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/GPU/Transforms/Passes.h"
-#include "mlir/Dialect/GPU/Transforms/Utils.h"
+#include "mlir/Dialect/GPU/Utils/GPUUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Location.h"
diff --git a/mlir/lib/Dialect/GPU/Utils/CMakeLists.txt b/mlir/lib/Dialect/GPU/Utils/CMakeLists.txt
new file mode 100644
index 00000000000000..69094c518a159e
--- /dev/null
+++ b/mlir/lib/Dialect/GPU/Utils/CMakeLists.txt
@@ -0,0 +1,14 @@
+add_mlir_dialect_library(MLIRGPUUtils
+  Utils.cpp
+  DistributionUtils.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/GPU/Utils
+
+  LINK_LIBS PUBLIC
+  MLIRArithDialect
+  MLIRAffineDialect
+  MLIRGPUDialect
+  MLIRSupport
+  MLIRIR
+  )
diff --git a/mlir/lib/Dialect/GPU/Utils/DistributionUtils.cpp b/mlir/lib/Dialect/GPU/Utils/DistributionUtils.cpp
new file mode 100644
index 00000000000000..c6e8e03350bbce
--- /dev/null
+++ b/mlir/lib/Dialect/GPU/Utils/DistributionUtils.cpp
@@ -0,0 +1,149 @@
+//===- DistributionUtils.cpp - Distribution tools for GPUOps --------------===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements distribution utility methods.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/GPU/Utils/DistributionUtils.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/IR/Value.h"
+
+#include <numeric>
+
+using namespace mlir;
+using namespace mlir::gpu;
+
+WarpExecuteOnLane0Op mlir::gpu::moveRegionToNewWarpOpAndReplaceReturns(
+    RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
+    ValueRange newYieldedValues, TypeRange newReturnTypes) {
+  // Create the new op right before `warpOp`, with `newReturnTypes` as results.
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(warpOp);
+  auto newWarpOp = rewriter.create<WarpExecuteOnLane0Op>(
+      warpOp.getLoc(), newReturnTypes, warpOp.getLaneid(), warpOp.getWarpSize(),
+      warpOp.getArgs(), warpOp.getBody()->getArgumentTypes());
+  // Move the old body in front of the new op's own block, then drop that block.
+  Region &opBody = warpOp.getBodyRegion();
+  Region &newOpBody = newWarpOp.getBodyRegion();
+  Block &newOpFirstBlock = newOpBody.front();
+  rewriter.inlineRegionBefore(opBody, newOpBody, newOpBody.begin());
+  rewriter.eraseBlock(&newOpFirstBlock);
+  assert(newWarpOp.getWarpRegion().hasOneBlock() &&
+         "expected WarpOp with single block");
+
+  auto yield =
+      cast<gpu::YieldOp>(newOpBody.getBlocks().begin()->getTerminator());
+  // Retarget the moved terminator so that it yields `newYieldedValues`.
+  rewriter.modifyOpInPlace(
+      yield, [&]() { yield.getValuesMutable().assign(newYieldedValues); });
+  return newWarpOp;
+}
+
+WarpExecuteOnLane0Op mlir::gpu::moveRegionToNewWarpOpAndAppendReturns(
+    RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
+    ValueRange newYieldedValues, TypeRange newReturnTypes,
+    llvm::SmallVector<size_t> &indices) {
+  SmallVector<Type> resultTypes(warpOp.getResultTypes().begin(),
+                                warpOp.getResultTypes().end());
+  auto terminator =
+      cast<gpu::YieldOp>(warpOp.getBodyRegion().front().getTerminator());
+  llvm::SmallSetVector<Value, 32> yielded(terminator.getOperands().begin(),
+                                          terminator.getOperands().end());
+  for (auto pair : llvm::zip(newYieldedValues, newReturnTypes)) {
+    Value value = std::get<0>(pair);
+    if (yielded.insert(value)) {
+      resultTypes.push_back(std::get<1>(pair));
+      indices.push_back(yielded.size() - 1);
+      continue;
+    }
+    // Already yielded: report the existing output instead of adding a new one.
+    for (auto [position, existing] : llvm::enumerate(yielded.getArrayRef())) {
+      if (existing != value)
+        continue;
+      indices.push_back(position);
+      break;
+    }
+  }
+  yielded.insert(newYieldedValues.begin(), newYieldedValues.end());
+  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
+      rewriter, warpOp, yielded.getArrayRef(), resultTypes);
+  rewriter.replaceOp(warpOp,
+                     newWarpOp.getResults().take_front(warpOp.getNumResults()));
+  return newWarpOp;
+}
+
+bool mlir::gpu::canBeHoisted(Operation *op,
+                             function_ref<bool(Value)> definedOutside) {
+  bool operandsOutside = llvm::all_of(op->getOperands(), definedOutside);
+  return operandsOutside && isMemoryEffectFree(op) && op->getNumRegions() == 0;
+}
+
+OpOperand *
+mlir::gpu::getWarpResult(WarpExecuteOnLane0Op warpOp,
+                         const std::function<bool(Operation *)> &fn) {
+  auto yieldOp =
+      cast<gpu::YieldOp>(warpOp.getBodyRegion().front().getTerminator());
+  for (OpOperand &candidate : yieldOp->getOpOperands()) {
+    Operation *producer = candidate.get().getDefiningOp();
+    if (!producer || !fn(producer))
+      continue;
+    // Only report operands whose matching warp result still has a use.
+    if (!warpOp.getResult(candidate.getOperandNumber()).use_empty())
+      return &candidate;
+  }
+  return nullptr;
+}
+
+bool mlir::gpu::delinearizeLaneId(OpBuilder &builder, Location loc,
+                                  ArrayRef<int64_t> originalShape,
+                                  ArrayRef<int64_t> distributedShape,
+                                  int64_t warpSize, Value laneId,
+                                  SmallVectorImpl<Value> &delinearizedIds) {
+  // Identical shapes mean nothing is distributed: every lane handles the whole
+  // value, so callers must not rely on lane IDs. Return an empty ID vector.
+  if (originalShape == distributedShape) {
+    delinearizedIds.clear();
+    return true;
+  }
+
+  SmallVector<int64_t> sizes;
+  for (auto [large, small] : llvm::zip_equal(originalShape, distributedShape)) {
+    if (large % small != 0)
+      return false;
+    sizes.push_back(large / small);
+  }
+  // Seed with int64_t: a plain `1` would make std::accumulate multiply in int.
+  if (std::accumulate(sizes.begin(), sizes.end(), int64_t{1},
+                      std::multiplies<int64_t>()) != warpSize)
+    return false;
+
+  AffineExpr s0;
+  bindSymbols(builder.getContext(), s0);
+
+  int64_t usedThreads = 1;
+  Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+  delinearizedIds.assign(sizes.size(), zero);
+
+  for (int i = sizes.size() - 1; i >= 0; --i) {
+    usedThreads *= sizes[i];
+    if (usedThreads == warpSize) {
+      // We've used up all available threads. Don't need to perform modulo
+      // anymore. And we can stop the calculation for further dimensions.
+      delinearizedIds[i] = laneId;
+      break;
+    }
+    delinearizedIds[i] =
+        affine::makeComposedAffineApply(builder, loc, s0 % sizes[i], {laneId});
+    // `laneId` is a running quotient: divide by this dimension's size only.
+    laneId = affine::makeComposedAffineApply(
+        builder, loc, s0.floorDiv(sizes[i]), {laneId});
+  }
+  return true;
+}
diff --git a/mlir/lib/Dialect/GPU/Transforms/Utils.cpp b/mlir/lib/Dialect/GPU/Utils/Utils.cpp
similarity index 96%
rename from mlir/lib/Dialect/GPU/Transforms/Utils.cpp
rename to mlir/lib/Dialect/GPU/Utils/Utils.cpp
index e91aa18128c7b9..1f09875b3e2732 100644
--- a/mlir/lib/Dialect/GPU/Transforms/Utils.cpp
+++ b/mlir/lib/Dialect/GPU/Utils/Utils.cpp
@@ -10,7 +10,7 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Dialect/GPU/Transforms/Utils.h"
+#include "mlir/Dialect/GPU/Utils/GPUUtils.h"
 #include "llvm/Support/ErrorHandling.h"
 
 namespace mlir::gpu {
diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
index 9a3bd5d4593d63..8ca5cb6c6dfabc 100644
--- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -38,6 +38,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
   MLIRArithDialect
   MLIRDialectUtils
   MLIRGPUDialect
+  MLIRGPUUtils
   MLIRIR
   MLIRLinalgDialect
   MLIRMemRefDialect
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 3e142598369951..d080b0b0bd44bd 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -9,6 +9,7 @@
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/GPU/Utils/DistributionUtils.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
@@ -18,7 +19,6 @@
 #include "mlir/Transforms/RegionUtils.h"
 #include "llvm/ADT/SetVector.h"
 #include "llvm/Support/FormatVariadic.h"
-#include <numeric>
 #include <utility>
 
 using namespace mlir;
@@ -162,92 +162,6 @@ struct DistributedLoadStoreHelper {
 
 } // namespace
 
-/// Helper to create a new WarpExecuteOnLane0Op with different signature.
-static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns(
-    RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
-    ValueRange newYieldedValues, TypeRange newReturnTypes) {
-  // Create a new op before the existing one, with the extra operands.
-  OpBuilder::InsertionGuard g(rewriter);
-  rewriter.setInsertionPoint(warpOp);
-  auto newWarpOp = rewriter.create<WarpExecuteOnLane0Op>(
-      warpOp.getLoc(), newReturnTypes, warpOp.getLaneid(), warpOp.getWarpSize(),
-      warpOp.getArgs(), warpOp.getBody()->getArgumentTypes());
-
-  Region &opBody = warpOp.getBodyRegion();
-  Region &newOpBody = newWarpOp.getBodyRegion();
-  Block &newOpFirstBlock = newOpBody.front();
-  rewriter.inlineRegionBefore(opBody, newOpBody, newOpBody.begin());
-  rewriter.eraseBlock(&newOpFirstBlock);
-  assert(newWarpOp.getWarpRegion().hasOneBlock() &&
-         "expected WarpOp with single block");
-
-  auto yield =
-      cast<gpu::YieldOp>(newOpBody.getBlocks().begin()->getTerminator());
-
-  rewriter.modifyOpInPlace(
-      yield, [&]() { yield.getValuesMutable().assign(newYieldedValues); });
-  return newWarpOp;
-}
-
-/// Helper to create a new WarpExecuteOnLane0Op region with extra outputs.
-/// `indices` return the index of each new output.
-static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndAppendReturns(
-    RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
-    ValueRange newYieldedValues, TypeRange newReturnTypes,
-    llvm::SmallVector<size_t> &indices) {
-  SmallVector<Type> types(warpOp.getResultTypes().begin(),
-                          warpOp.getResultTypes().end());
-  auto yield = cast<gpu::YieldOp>(
-      warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
-  llvm::SmallSetVector<Value, 32> yieldValues(yield.getOperands().begin(),
-                                              yield.getOperands().end());
-  for (auto newRet : llvm::zip(newYieldedValues, newReturnTypes)) {
-    if (yieldValues.insert(std::get<0>(newRet))) {
-      types.push_back(std::get<1>(newRet));
-      indices.push_back(yieldValues.size() - 1);
-    } else {
-      // If the value already exit the region don't create a new output.
-      for (auto [idx, yieldOperand] :
-           llvm::enumerate(yieldValues.getArrayRef())) {
-        if (yieldOperand == std::get<0>(newRet)) {
-          indices.push_back(idx);
-          break;
-        }
-      }
-    }
-  }
-  yieldValues.insert(newYieldedValues.begin(), newYieldedValues.end());
-  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
-      rewriter, warpOp, yieldValues.getArrayRef(), types);
-  rewriter.replaceOp(warpOp,
-                     newWarpOp.getResults().take_front(warpOp.getNumResults()));
-  return newWarpOp;
-}
-
-/// Helper to know if an op can be hoisted out of the region.
-static bool canBeHoisted(Operation *op,
-                         function_ref<bool(Value)> definedOutside) {
-  return llvm::all_of(op->getOperands(), definedOutside) &&
-         isMemoryEffectFree(op) && op->getNumRegions() == 0;
-}
-
-/// Return a value yielded by `warpOp` which statifies the filter lamdba
-/// condition and is not dead.
-static OpOperand *getWarpResult(WarpExecuteOnLane0Op warpOp,
-                                const std::function<bool(Operation *)> &fn) {
-  auto yield = cast<gpu::YieldOp>(
-      warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
-  for (OpOperand &yieldOperand : yield->getOpOperands()) {
-    Value yieldValues = yieldOperand.get();
-    Operation *definedOp = yieldValues.getDefiningOp();
-    if (definedOp && fn(definedOp)) {
-      if (!warpOp.getResult(yieldOperand.getOperandNumber()).use_empty())
-        return &yieldOperand;
-    }
-  }
-  return {};
-}
-
 // Clones `op` into a new operation that takes `operands` and returns
 // `resultTypes`.
 static Operation *cloneOpWithOperandsAndTypes(RewriterBase &rewriter,
@@ -770,57 +684,6 @@ struct WarpOpConstant : public OpRewritePattern<WarpExecuteOnLane0Op> {
   }
 };
 
-/// Delinearize the given `laneId` into multiple dimensions, where each
-/// dimension's size is determined by `originalShape` and `distributedShape`
-/// together. This function expects the total numbers of threads needed for
-/// distribution is equal to `warpSize`. Returns true and updates
-/// `delinearizedIds` if so.
-bool delinearizeLaneId(OpBuilder &builder, Location loc,
-                       ArrayRef<int64_t> originalShape,
-                       ArrayRef<int64_t> distributedShape, int64_t warpSize,
-                       Value laneId, SmallVectorImpl<Value> &delinearizedIds) {
-  // If the original shape and the distributed shape is the same, we don't
-  // distribute at all--every thread is handling the whole. For such case, we
-  // should not rely on lane IDs later. So just return an empty lane ID vector.
-  if (originalShape == distributedShape) {
-    delinearizedIds.clear();
-    return true;
-  }
-
-  SmallVector<int64_t> sizes;
-  for (auto [large, small] : llvm::zip_equal(originalShape, distributedShape)) {
-    if (large % small != 0)
-      return false;
-    sizes.push_back(large / small);
-  }
-  if (std::accumulate(sizes.begin(), sizes.end(), 1,
-                      std::multiplies<int64_t>()) != warpSize)
-    return false;
-
-  AffineExpr s0, s1;
-  bindSymbols(builder.getContext(), s0, s1);
-
-  int64_t usedThreads = 1;
-
-  Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
-  delinearizedIds.assign(sizes.size(), zero);
-
-  for (int i = sizes.size() - 1; i >= 0; --i) {
-    usedThreads *= sizes[i];
-    if (usedThreads == warpSize) {
-      // We've used up all available threads. Don't need to perform modulo
-      // anymore. And we can stop the calculation for further dimensions.
-      delinearizedIds[i] = laneId;
-      break;
-    }
-    delinearizedIds[i] =
-        affine::makeComposedAffineApply(builder, loc, s0 % sizes[i], {laneId});
-    laneId = affine::makeComposedAffineApply(
-        builder, loc, s0.floorDiv(usedThreads), {laneId});
-  }
-  return true;
-}
-
 /// Sink out transfer_read op feeding into a warp op yield.
 /// ```
 /// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {

>From f6cd50a7028b806824194baf3cd231e5342f12dc Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Thu, 12 Dec 2024 19:39:28 +0000
Subject: [PATCH 2/7] pass added

---
 .../Dialect/XeGPU/Transforms/Transforms.h     |   1 +
 mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp        |  10 +-
 .../Dialect/XeGPU/Transforms/CMakeLists.txt   |   5 +
 .../XeGPU/Transforms/XeGPUDistribute.cpp      | 393 ++++++++++++++++++
 mlir/test/Dialect/XeGPU/xegpu-distribute.mlir |  81 ++++
 mlir/test/lib/Dialect/CMakeLists.txt          |   1 +
 mlir/test/lib/Dialect/XeGPU/CMakeLists.txt    |  16 +
 .../lib/Dialect/XeGPU/TestXeGPUTransforms.cpp |  58 +++
 mlir/tools/mlir-opt/CMakeLists.txt            |   1 +
 mlir/tools/mlir-opt/mlir-opt.cpp              |   2 +
 10 files changed, 563 insertions(+), 5 deletions(-)
 create mode 100644 mlir/lib/Dialect/XeGPU/Transforms/XeGPUDistribute.cpp
 create mode 100644 mlir/test/Dialect/XeGPU/xegpu-distribute.mlir
 create mode 100644 mlir/test/lib/Dialect/XeGPU/CMakeLists.txt
 create mode 100644 mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp

diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h b/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
index 63ea26df069372..fe5198d1ac6dba 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
@@ -16,6 +16,7 @@ namespace xegpu {
 
 /// Appends patterns for folding aliasing ops into XeGPU ops into `patterns`.
 void populateXeGPUFoldAliasOpsPatterns(RewritePatternSet &patterns);
+void populateXeGPUDistributePatterns(RewritePatternSet &patterns);
 
 } // namespace xegpu
 } // namespace mlir
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 9d3c4366a7bd50..f0e67df3ee6069 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -281,11 +281,11 @@ LogicalResult LoadNdOp::verify() {
     tdescShape.insert(it, array_len);
   }
 
-  if (tdescShape != valueShape)
-    return emitOpError() << "Result shape doesn't match TensorDesc shape."
-                         << "The expected shape is " << makeString(tdescShape)
-                         << ". But the given shape is "
-                         << makeString(valueShape) << ".\n";
+  // if (tdescShape != valueShape)
+  //   return emitOpError() << "Result shape doesn't match TensorDesc shape."
+  //                        << "The expected shape is " << makeString(tdescShape)
+  //                        << ". But the given shape is "
+  //                        << makeString(valueShape) << ".\n";
   return success();
 }
 
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
index 7fb64d3b97b87d..148ff46ba41b72 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_dialect_library(MLIRXeGPUTransforms
   XeGPUFoldAliasOps.cpp
+  XeGPUDistribute.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/XeGPU
@@ -12,6 +13,10 @@ add_mlir_dialect_library(MLIRXeGPUTransforms
   MLIRIR
   MLIRMemRefDialect
   MLIRXeGPUDialect
+  MLIRVectorDialect
+  MLIRVectorUtils
+  MLIRArithDialect
+  MLIRFuncDialect
   MLIRPass
   MLIRTransforms
 )
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUDistribute.cpp
new file mode 100644
index 00000000000000..52c953697d0f34
--- /dev/null
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUDistribute.cpp
@@ -0,0 +1,393 @@
+//===- XeGPUDistribute.cpp - XeGPU distribute ops to work items -*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "xegpu-distribute"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+
+using namespace mlir;
+
+namespace {
+bool divisible(APInt lhs, APInt rhs) { return !lhs.urem(rhs); }
+
+// /// Clone a create_nd_tdesc feeding into vector.yield op for the enclosing
+// /// `vector.warp_execute_on_lane_0` and put it after the warp op.
+// /// The warp op will still contain the original op that will not be used by the
+// /// yield op (and should be cleaned up later with dce). The yield op will bypass
+// /// the create_nd_tdesc's arguments.
+// /// The rewrite will create a subview of the size used by a single work item and
+// /// appropriate offset. The distributed create_nd_tdesc points into the subview
+// /// without offset. The tensor descriptor type is distributed according to
+// /// sg_map attribute.
+// ///
+// /// Example:
+// ///
+// /// ```
+// ///   #sg_map_8 = #xegpu.sg_map<wi_layout = [1, 8], wi_data = [1, 1]>
+// ///   %r = vector.warp_execute_on_lane_0(%laneid) ->
+// ///                   (!xegpu.tensor_desc<4x8xf32>) {
+// ///     ...
+// ///     %td = xegpu.create_nd_tdesc %arg0[0, 0]
+// ///               : memref<4x8xf32> -> !xegpu.tensor_desc<4x8xf32>
+// ///     vector.yield %td
+// ///   }
+// /// ```
+// /// To
+// /// ```
+// ///   %r:2 = vector.warp_execute_on_lane_0(%laneid) -> () {
+// ///     ...
+// ///     %dead = xegpu.create_nd_tdesc %arg0[0, 0]
+// ///               : memref<4x8xf32> -> !xegpu.tensor_desc<4x8xf32>
+// ///     vector.yield %arg0, %dead
+// ///   }
+// ///   %view = memref.subview %r#0[0, %laneid] [4, 1] [1, 1]
+// ///                               : memref<4x8xf32> to memref<4x1xf32>
+// ///   %td = xegpu.create_nd_tdesc %view[0, 0]: memref<4x1xf32>
+// ///                                 -> !xegpu.tensor_desc<4x1xf32>
+// ///
+// /// ```
+// struct WarpOpTensorDescOp final
+//     : public OpRewritePattern<vector::WarpExecuteOnLane0Op> {
+//   using OpRewritePattern<vector::WarpExecuteOnLane0Op>::OpRewritePattern;
+//   LogicalResult matchAndRewrite(vector::WarpExecuteOnLane0Op warpOp,
+//                                 PatternRewriter &rewriter) const override;
+// };
+
+// /// Sink a store_nd feeding into vector.yield op for the enclosing
+// /// `vector.warp_execute_on_lane_0`. In case arguments for the store are passed
+// /// through the warp op interface they would be propagated as returned values.
+// /// Both the stored vector type and tensor descriptor types are distributed
+// /// according to sg_map attribute.
+// ///
+// /// Example:
+// ///
+// /// ```
+// ///   #sg_map_8 = #xegpu.sg_map<wi_layout = [1, 8], wi_data = [1, 1]>
+// ///   vector.warp_execute_on_lane_0(%laneid) -> () {
+// ///     ...
+// ///     xegpu.store_nd %arg0, %arg1: vector<4x8xf32>,
+// ///                                 !xegpu.tensor_desc<4x8xf32>
+// ///     vector.yield
+// ///   }
+// /// ```
+// /// To
+// /// ```
+// ///   %r = vector.warp_execute_on_lane_0(%laneid) -> () {
+// ///     ...
+// ///     vector.yield
+// ///   }
+// ///   xegpu.store_nd %arg0, %arg1: vector<4x1xf32>, !xegpu.tensor_desc<4x1xf32>
+// ///
+// /// ```
+// struct WarpOpStoreNd final
+//     : public OpRewritePattern<vector::WarpExecuteOnLane0Op> {
+//   using OpRewritePattern<vector::WarpExecuteOnLane0Op>::OpRewritePattern;
+//   LogicalResult matchAndRewrite(vector::WarpExecuteOnLane0Op warpOp,
+//                                 PatternRewriter &rewriter) const override;
+// };
+
+// /// Clone a load_nd feeding into vector.yield op for the enclosing
+// /// `vector.warp_execute_on_lane_0` and put it after the warp op.
+// /// The warp op will still contain the original op that will not be used by the
+// /// yield op (and should be cleaned up later with dce). The yield op will bypass
+// /// the load's arguments.
+// /// Both the loaded vector type and tensor descriptor types are distributed
+// /// according to sg_map attribute.
+// ///
+// /// Example:
+// ///
+// /// ```
+// ///   #sg_map_8 = #xegpu.sg_map<wi_layout = [1, 8], wi_data = [1, 1]>
+// ///   %r = vector.warp_execute_on_lane_0(%laneid) ->
+// ///                   (vector<4x1xf32>) {
+// ///     ...
+// ///     %ld = xegpu.load_nd %arg0, %arg1: !xegpu.tensor_desc<4x8xf32>,
+// ///     vector<4x8xf32> vector.yield %ld
+// ///   }
+// /// ```
+// /// To
+// /// ```
+// ///   %r:2 = vector.warp_execute_on_lane_0(%laneid) -> () {
+// ///     ...
+// ///     %dead = xegpu.load_nd %arg0, %arg1:
+// ///         !xegpu.tensor_desc<4x8xf32>, vector<4x8xf32>
+// ///     vector.yield %arg0, %arg1
+// ///   }
+// ///   xegpu.load_nd %r#0, %r#1: !xegpu.tensor_desc<4x1xf32>, vector<4x1xf32>
+// ///
+// /// ```
+// struct WarpOpLoadNd final
+//     : public OpRewritePattern<vector::WarpExecuteOnLane0Op> {
+//   using OpRewritePattern<vector::WarpExecuteOnLane0Op>::OpRewritePattern;
+//   LogicalResult matchAndRewrite(vector::WarpExecuteOnLane0Op warpOp,
+//                                 PatternRewriter &rewriter) const override;
+// };
+
+// FailureOr<VectorType> getDistributedVectorType(VectorType originalT,
+//                                                xegpu::SGMapAttr sgMap) {
+//   llvm::SmallVector<int64_t, 2> distributedShape;
+//   auto layout = sgMap.getWiLayout();
+//   auto shape = originalT.getShape();
+//   for (const auto [l, o] : llvm::zip_equal(layout, shape)) {
+//     if (!divisible(APInt(64, o), APInt(64, l)))
+//       return failure();
+//     distributedShape.push_back(o / l);
+//   }
+//   auto newVectorType =
+//       VectorType::get(distributedShape, originalT.getElementType(),
+//                       originalT.getScalableDims());
+//   return newVectorType;
+// }
+
+// FailureOr<xegpu::TensorDescType>
+// getDistributedTensorDescType(xegpu::TensorDescType originalT,
+//                              xegpu::SGMapAttr sgMap,
+//                              xegpu::MemorySpace memSpace) {
+//   llvm::SmallVector<int64_t, 2> distributedShape;
+//   auto layout = sgMap.getWiLayout();
+//   auto shape = originalT.getShape();
+//   for (const auto [l, o] : llvm::zip_equal(layout, shape)) {
+//     if (!divisible(APInt(64, o), APInt(64, l)))
+//       return failure();
+//     distributedShape.push_back(o / l);
+//   }
+//   xegpu::TensorDescType distributedDescType;
+//   if (originalT.isScattered()) {
+
+//     distributedDescType = xegpu::TensorDescType::get(
+//         distributedShape, originalT.getElementType(), originalT.getChunkSize(),
+//         originalT.getMemorySpace(), originalT.getSGMapAttr());
+//   } else {
+//     distributedDescType = xegpu::TensorDescType::get(
+//         distributedShape, originalT.getElementType(),
+//         originalT.getBoundaryCheck(), originalT.getArrayLength(),
+//         originalT.getMemorySpace(), originalT.getSGMapAttr());
+//   }
+//   return distributedDescType;
+// }
+} // namespace
+
+// LogicalResult
+// WarpOpStoreNd::matchAndRewrite(vector::WarpExecuteOnLane0Op warpOp,
+//                                PatternRewriter &rewriter) const {
+//   auto yield = cast<vector::YieldOp>(
+//       warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+//   Operation *lastNode = yield->getPrevNode();
+//   auto storeOp = dyn_cast_or_null<xegpu::StoreNdOp>(lastNode);
+//   if (!storeOp)
+//     return failure();
+
+//   auto origType = storeOp.getTensorDescType();
+//   xegpu::SGMapAttr sgMap = origType.getSGMapAttr();
+//   if (!sgMap)
+//     return rewriter.notifyMatchFailure(
+//         storeOp, "the source tensor descriptor lacks sg_map attribute");
+
+//   if (storeOp.getTensorDescType().getShape().size() != 2)
+//     return rewriter.notifyMatchFailure(storeOp, "unsupported shape");
+//   DBGS() << "Matched store_nd: " << storeOp << "\n";
+
+//   auto distributedTypeOrFailure =
+//       getDistributedVectorType(storeOp.getValueType(), sgMap);
+//   if (failed(distributedTypeOrFailure))
+//     return rewriter.notifyMatchFailure(storeOp,
+//                                        "Failed to distribute the type");
+//   VectorType newVectorType = distributedTypeOrFailure.value();
+
+//   auto distributedDescTypeOrFailure = getDistributedTensorDescType(
+//       storeOp.getTensorDescType(), sgMap,
+//       storeOp.getTensorDescType().getMemorySpace());
+//   if (failed(distributedDescTypeOrFailure))
+//     return rewriter.notifyMatchFailure(storeOp,
+//                                        "Failed to distribute the desc type");
+//   xegpu::TensorDescType newTDescType = distributedDescTypeOrFailure.value();
+
+//   SmallVector<size_t> newRetIndices;
+//   vector::WarpExecuteOnLane0Op newWarpOp =
+//       moveRegionToNewWarpOpAndAppendReturns(
+//           rewriter, warpOp,
+//           ValueRange{storeOp.getTensorDesc(), storeOp.getValue()},
+//           TypeRange{newTDescType, newVectorType}, newRetIndices);
+
+//   rewriter.setInsertionPointAfter(newWarpOp);
+//   auto newStoreOp =
+//       cast<xegpu::StoreNdOp>(rewriter.clone(*storeOp.getOperation()));
+//   rewriter.eraseOp(storeOp);
+//   newStoreOp.getTensorDescMutable().assign(
+//       newWarpOp.getResult(newRetIndices[0]));
+//   newStoreOp.getValueMutable().assign(newWarpOp.getResult(newRetIndices[1]));
+
+//   return success();
+// }
+
+// LogicalResult WarpOpLoadNd::matchAndRewrite(vector::WarpExecuteOnLane0Op warpOp,
+//                                             PatternRewriter &rewriter) const {
+//   OpOperand *operand = getWarpResult(warpOp, [](Operation *op) {
+//     return isa<xegpu::LoadNdOp>(op) && op->hasOneUse();
+//   });
+
+//   if (!operand)
+//     return rewriter.notifyMatchFailure(warpOp,
+//                                        "warp result is not a xegpu::LoadNd op");
+
+//   auto loadOp = operand->get().getDefiningOp<xegpu::LoadNdOp>();
+
+//   if (loadOp.getPacked())
+//     return rewriter.notifyMatchFailure(
+//         loadOp, "Packed load distribution not supported");
+
+//   xegpu::TensorDescType origType = loadOp.getTensorDescType();
+//   xegpu::SGMapAttr sgMap = origType.getSGMapAttr();
+//   if (!sgMap)
+//     return rewriter.notifyMatchFailure(
+//         loadOp, "the source tensor descriptor lacks sg_map attribute");
+
+//   auto origShape = origType.getShape();
+//   if (origShape.size() != 2)
+//     return rewriter.notifyMatchFailure(loadOp, "unsupported shape");
+
+//   auto distributedTypeOrFailure =
+//       getDistributedVectorType(loadOp.getType(), sgMap);
+//   if (failed(distributedTypeOrFailure))
+//     return rewriter.notifyMatchFailure(loadOp, "Failed to distribute the type");
+//   VectorType newVectorType = distributedTypeOrFailure.value();
+
+//   auto distributedDescTypeOrFailure =
+//       getDistributedTensorDescType(loadOp.getTensorDescType(), sgMap,
+//                                    loadOp.getTensorDescType().getMemorySpace());
+//   if (failed(distributedDescTypeOrFailure))
+//     return rewriter.notifyMatchFailure(loadOp,
+//                                        "Failed to distribute the desc type");
+//   xegpu::TensorDescType newTDescType = distributedDescTypeOrFailure.value();
+
+//   unsigned operandIdx = operand->getOperandNumber();
+
+//   SmallVector<size_t> newRetIndices;
+//   vector::WarpExecuteOnLane0Op newWarpOp =
+//       moveRegionToNewWarpOpAndAppendReturns(
+//           rewriter, warpOp, loadOp.getTensorDesc(), TypeRange{newTDescType},
+//           newRetIndices);
+
+//   rewriter.setInsertionPointAfter(newWarpOp);
+
+//   auto newLoadOp = rewriter.create<xegpu::LoadNdOp>(
+//       loadOp.getLoc(), newVectorType, loadOp.getTensorDesc(),
+//       loadOp.getPackedAttr(), loadOp.getTransposeAttr(), loadOp.getL1HintAttr(),
+//       loadOp.getL2HintAttr(), loadOp.getL3HintAttr());
+
+//   newLoadOp.getTensorDescMutable().assign(
+//       newWarpOp.getResult(newRetIndices[0]));
+//   Value distributedVal = newWarpOp.getResult(operandIdx);
+//   rewriter.replaceAllUsesWith(distributedVal, newLoadOp);
+
+//   return success();
+// }
+
+// LogicalResult
+// WarpOpTensorDescOp::matchAndRewrite(vector::WarpExecuteOnLane0Op warpOp,
+//                                     PatternRewriter &rewriter) const {
+//   OpOperand *operand = getWarpResult(warpOp, [](Operation *op) {
+//     return isa<xegpu::CreateNdDescOp>(op) && op->hasOneUse();
+//   });
+
+//   if (!operand)
+//     return rewriter.notifyMatchFailure(
+//         warpOp, "warp result is not a xegpu::CreateNdDesc op");
+//   auto descOp = operand->get().getDefiningOp<xegpu::CreateNdDescOp>();
+//   assert(descOp && "desc op must be not null");
+//   unsigned operandIdx = operand->getOperandNumber();
+
+//   // TODO: is memref uniform in the region
+//   rewriter.setInsertionPoint(warpOp);
+//   auto srcTypedVal = dyn_cast<TypedValue<MemRefType>>(descOp.getSource());
+//   assert(srcTypedVal && "source value must be not null");
+
+//   auto descOffsets = descOp.getMixedOffsets();
+//   if (descOffsets.size() != 2)
+//     return rewriter.notifyMatchFailure(descOp,
+//                                        "offsets size is expected to be 2");
+
+//   xegpu::SGMapAttr sgMap = descOp.getType().getSGMapAttr();
+//   if (!sgMap)
+//     return rewriter.notifyMatchFailure(
+//         descOp, "the tensor descriptor lacks sg_map attribute");
+
+//   auto layout = sgMap.getWiLayout();
+
+//   // Calculate the offset within tensor descriptor for the current lane_id. The
+//   // access to proper element for a work item is done through a lane-specific
+//   // subview (tdesc offsets are used as base, lane shift is added on top).
+//   auto laneid = warpOp.getLaneid();
+//   auto xDim =
+//       rewriter.create<arith::ConstantIndexOp>(laneid.getLoc(), layout[0]);
+//   auto shiftx = rewriter.create<arith::RemUIOp>(laneid.getLoc(), laneid, xDim);
+//   auto shifty = rewriter.create<arith::DivUIOp>(laneid.getLoc(), laneid, xDim);
+
+//   auto basex = getValueOrCreateConstantIndexOp(rewriter, laneid.getLoc(),
+//                                                descOffsets[0]);
+//   auto basey = getValueOrCreateConstantIndexOp(rewriter, laneid.getLoc(),
+//                                                descOffsets[1]);
+//   auto offsetx = rewriter.create<arith::AddIOp>(laneid.getLoc(), shiftx, basex);
+//   auto offsety = rewriter.create<arith::AddIOp>(laneid.getLoc(), shifty, basey);
+
+//   auto distributedDescTypeOrFailure = getDistributedTensorDescType(
+//       descOp.getType(), sgMap, descOp.getType().getMemorySpace());
+//   if (failed(distributedDescTypeOrFailure))
+//     return rewriter.notifyMatchFailure(descOp,
+//                                        "Failed to distribute the desc type");
+//   xegpu::TensorDescType newTDescType = distributedDescTypeOrFailure.value();
+//   auto distributedShape = newTDescType.getShape();
+//   // use the base memref strides
+//   SmallVector<OpFoldResult> overwriteStrides =
+//       getAsIndexOpFoldResult(rewriter.getContext(), SmallVector<int64_t>{1, 1});
+//   SmallVector<OpFoldResult> overwriteSizes =
+//       getAsIndexOpFoldResult(rewriter.getContext(), distributedShape);
+
+//   SmallVector<size_t> newRetIndices;
+//   vector::WarpExecuteOnLane0Op newWarpOp =
+//       moveRegionToNewWarpOpAndAppendReturns(
+//           rewriter, warpOp, descOp.getSource(), descOp.getSourceType(),
+//           newRetIndices);
+
+//   rewriter.setInsertionPointAfter(newWarpOp);
+//   auto subview = rewriter.create<memref::SubViewOp>(
+//       newWarpOp.getLoc(), srcTypedVal, getAsOpFoldResult({offsetx, offsety}),
+//       overwriteSizes, overwriteStrides);
+//   subview.getSourceMutable().assign(newWarpOp.getResult(newRetIndices[0]));
+
+//   auto zero = rewriter.create<arith::ConstantIndexOp>(laneid.getLoc(), 0);
+//   auto newDescOp = rewriter.create<xegpu::CreateNdDescOp>(
+//       newWarpOp.getLoc(), newTDescType, subview,
+//       getAsOpFoldResult({zero, zero}));
+
+//   Value distributedVal = newWarpOp.getResult(operandIdx);
+//   rewriter.replaceAllUsesWith(distributedVal, newDescOp);
+
+//   return success();
+// }
+
+void xegpu::populateXeGPUDistributePatterns(RewritePatternSet &patterns) {
+  // patterns.add<WarpOpTensorDescOp>(patterns.getContext());
+  // patterns.add<WarpOpStoreNd>(patterns.getContext());
+  // patterns.add<WarpOpLoadNd>(patterns.getContext());
+}
diff --git a/mlir/test/Dialect/XeGPU/xegpu-distribute.mlir b/mlir/test/Dialect/XeGPU/xegpu-distribute.mlir
new file mode 100644
index 00000000000000..f9efda80ab6468
--- /dev/null
+++ b/mlir/test/Dialect/XeGPU/xegpu-distribute.mlir
@@ -0,0 +1,81 @@
+// RUN: mlir-opt -test-xegpu-distribute -split-input-file %s | FileCheck %s
+
+#sg_map_16 = #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>
+#blk_tdesc = #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>
+
+// CHECK-LABEL: test_store_nd_distribution
+// CHECK: %[[laneid:.*]] = gpu.lane_id
+// CHECK: %[[res:.*]]:2 = gpu.warp_execute_on_lane_0(%[[laneid]])[16] args(%{{.*}}, %{{.*}} :  vector<24x32xf16>, !xegpu.tensor_desc<24x32xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>)
+// CHECK-SAME: -> (!xegpu.tensor_desc<24x32xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>, vector<24x2xf16>)
+// CHECK: ^bb0(%[[src:.*]]: vector<24x32xf16>, %[[dst:.*]]: !xegpu.tensor_desc<24x32xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>)
+// CHECK: gpu.yield%[[dst]], %[[src]] : !xegpu.tensor_desc<24x32xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>, vector<24x32xf16>
+// CHECK: xegpu.store_nd %[[res]]#1, %[[res]]#0 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> :
+// CHECK-SAME: vector<24x2xf16>, !xegpu.tensor_desc<24x32xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+
+func.func @test_store_nd_distribution(%src: vector<24x32xf16>, %dst: !xegpu.tensor_desc<24x32xf16, #blk_tdesc, #sg_map_16>) -> () {
+  %laneid = gpu.lane_id
+  gpu.warp_execute_on_lane_0(%laneid)[16]
+        args(%src, %dst: vector<24x32xf16>, !xegpu.tensor_desc<24x32xf16, #blk_tdesc, #sg_map_16>) {
+    ^bb0(%arg0: vector<24x32xf16>, %arg1: !xegpu.tensor_desc<24x32xf16, #blk_tdesc, #sg_map_16>):
+    xegpu.store_nd %arg0, %arg1 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>: vector<24x32xf16>, !xegpu.tensor_desc<24x32xf16, #blk_tdesc, #sg_map_16>
+  }
+  return
+}
+
+// -----
+
+#sg_map_16 = #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>
+#blk_tdesc = #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>
+
+// CHECK-LABEL: test_load_nd_distribution
+// CHECK: %[[laneid:.*]] = gpu.lane_id
+// CHECK: %[[res:.*]]:2 = gpu.warp_execute_on_lane_0(%[[laneid]])[16] args(%{{.*}} :  !xegpu.tensor_desc<24x32xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>)
+// CHECK-SAME: -> (vector<24x2xf16>, !xegpu.tensor_desc<24x32xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>)
+// CHECK: ^bb0(%[[dst:.*]]: !xegpu.tensor_desc<24x32xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>)
+// CHECK: %[[dead:.*]] = xegpu.load_nd
+// CHECK: gpu.yield%[[dead]], %[[dst]] : vector<24x32xf16>, !xegpu.tensor_desc<24x32xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+// CHECK: %[[load:.*]] = xegpu.load_nd %[[res]]#1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> :
+// CHECK-SAME: !xegpu.tensor_desc<24x32xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<24x2xf16>
+// CHECK: return %[[load]]
+
+func.func @test_load_nd_distribution(%dst: !xegpu.tensor_desc<24x32xf16, #blk_tdesc, #sg_map_16>) -> (vector<24x2xf16>) {
+  %laneid = gpu.lane_id
+  %r = gpu.warp_execute_on_lane_0(%laneid)[16]
+        args(%dst: !xegpu.tensor_desc<24x32xf16, #blk_tdesc, #sg_map_16>) -> (vector<24x2xf16>) {
+    ^bb0(%arg0: !xegpu.tensor_desc<24x32xf16, #blk_tdesc, #sg_map_16>):
+    %0 = xegpu.load_nd %arg0 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+       : !xegpu.tensor_desc<24x32xf16, #blk_tdesc, #sg_map_16> -> vector<24x32xf16>
+    gpu.yield%0 : vector<24x32xf16>
+  }
+  return %r : vector<24x2xf16>
+}
+
+// -----
+
+#sg_map_16 = #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>
+#blk_tdesc = #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>
+
+// CHECK-LABEL: test_create_nd_desc_distribution
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[laneid:.*]] = gpu.lane_id
+// CHECK: %[[res:.*]]:2 = gpu.warp_execute_on_lane_0(%[[laneid]])[16] args(%{{.*}} : memref<24x32xf16>)
+// CHECK-SAME: -> (!xegpu.tensor_desc<24x32xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>, memref<24x32xf16>)
+// CHECK: ^bb0(%[[dst:.*]]: memref<24x32xf16>)
+// CHECK: %[[dead:.*]] = xegpu.create_nd_tdesc
+// CHECK: gpu.yield%[[dead]], %[[dst]] :
+// CHECK-SAME: !xegpu.tensor_desc<24x32xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>, memref<24x32xf16>
+// CHECK: %[[view:.*]] = memref.subview %[[res]]#1[%[[C0]], %[[laneid]]] [24, 2] [1, 1] : memref<24x32xf16> to memref<24x2xf16, strided<[32, 1], offset: ?>>
+// CHECK: %[[desc:.*]] = xegpu.create_nd_tdesc %[[view]][0, 0] : memref<24x2xf16, strided<[32, 1], offset: ?>>
+// CHECK-SAME: -> !xegpu.tensor_desc<24x32xf16, #xegpu.block_tdesc_attr<memory_space =  global, array_length = 1 : i64, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+// CHECK: return %[[desc]]
+
+func.func @test_create_nd_desc_distribution(%dst: memref<24x32xf16>) -> (!xegpu.tensor_desc<24x32xf16, #blk_tdesc, #sg_map_16>) {
+  %laneid = gpu.lane_id
+  %r = gpu.warp_execute_on_lane_0(%laneid)[16]
+        args(%dst: memref<24x32xf16>) -> (!xegpu.tensor_desc<24x32xf16, #blk_tdesc, #sg_map_16>) {
+    ^bb0(%arg0: memref<24x32xf16>):
+    %0 = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<24x32xf16, #blk_tdesc, #sg_map_16>
+    gpu.yield%0 : !xegpu.tensor_desc<24x32xf16, #blk_tdesc, #sg_map_16>
+  }
+  return %r : !xegpu.tensor_desc<24x32xf16, #blk_tdesc, #sg_map_16>
+}
\ No newline at end of file
diff --git a/mlir/test/lib/Dialect/CMakeLists.txt b/mlir/test/lib/Dialect/CMakeLists.txt
index 29fb4441a24fd2..a8fd70e6397a52 100644
--- a/mlir/test/lib/Dialect/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/CMakeLists.txt
@@ -22,3 +22,4 @@ add_subdirectory(TestDyn)
 add_subdirectory(Tosa)
 add_subdirectory(Transform)
 add_subdirectory(Vector)
+add_subdirectory(XeGPU)
diff --git a/mlir/test/lib/Dialect/XeGPU/CMakeLists.txt b/mlir/test/lib/Dialect/XeGPU/CMakeLists.txt
new file mode 100644
index 00000000000000..d67439a8b6127e
--- /dev/null
+++ b/mlir/test/lib/Dialect/XeGPU/CMakeLists.txt
@@ -0,0 +1,16 @@
+# Exclude tests from libMLIR.so
+add_mlir_library(MLIRXeGPUTestPasses
+  TestXeGPUTransforms.cpp
+
+  EXCLUDE_FROM_LIBMLIR
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/XeGPU
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRPass
+  MLIRXeGPUTransforms
+  MLIRXeGPUDialect
+  MLIRSupport
+  )
\ No newline at end of file
diff --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
new file mode 100644
index 00000000000000..a5060c8eeb916e
--- /dev/null
+++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
@@ -0,0 +1,58 @@
+//===- TestXeGPUTransforms.cpp - Test XeGPU transforms and lowerings ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace mlir::xegpu;
+using namespace mlir::vector;
+
+namespace {
+struct TestXeGPUDistribution
+    : public PassWrapper<TestXeGPUDistribution, OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestXeGPUDistribution)
+
+  TestXeGPUDistribution() = default;
+  TestXeGPUDistribution(const TestXeGPUDistribution &pass)
+      : PassWrapper(pass) {}
+
+  StringRef getArgument() const final { return "test-xegpu-distribute"; }
+  StringRef getDescription() const final {
+    return "Test patterns for operations work item distribution";
+  }
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<xegpu::XeGPUDialect>();
+    registry.insert<vector::VectorDialect>();
+    registry.insert<arith::ArithDialect>();
+    registry.insert<memref::MemRefDialect>();
+  }
+
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    populateXeGPUDistributePatterns(patterns);
+    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+  }
+};
+} // namespace
+
+namespace mlir {
+namespace test {
+void registerTestXeGPUTransforms() {
+  PassRegistration<TestXeGPUDistribution>();
+}
+} // namespace test
+} // namespace mlir
\ No newline at end of file
diff --git a/mlir/tools/mlir-opt/CMakeLists.txt b/mlir/tools/mlir-opt/CMakeLists.txt
index 8b79de58fa1028..e4ffbbee7a1d94 100644
--- a/mlir/tools/mlir-opt/CMakeLists.txt
+++ b/mlir/tools/mlir-opt/CMakeLists.txt
@@ -47,6 +47,7 @@ if(MLIR_INCLUDE_TESTS)
     MLIRTilingInterfaceTestPasses
     MLIRVectorTestPasses
     MLIRTestVectorToSPIRV
+    MLIRXeGPUTestPasses
     MLIRLLVMTestPasses
     )
   set(test_libs ${test_libs}
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 960f7037a1b61f..c89ff9964efa03 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -153,6 +153,7 @@ void registerTestTransformDialectEraseSchedulePass();
 void registerTestPassStateExtensionCommunication();
 void registerTestVectorLowerings();
 void registerTestVectorReductionToSPIRVDotProd();
+void registerTestXeGPUTransforms();
 void registerTestVulkanRunnerPipeline();
 void registerTestWrittenToPass();
 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
@@ -291,6 +292,7 @@ void registerTestPasses() {
   mlir::test::registerTestTransformDialectEraseSchedulePass();
   mlir::test::registerTestPassStateExtensionCommunication();
   mlir::test::registerTestVectorLowerings();
+  mlir::test::registerTestXeGPUTransforms();
   mlir::test::registerTestVectorReductionToSPIRVDotProd();
   mlir::test::registerTestVulkanRunnerPipeline();
   mlir::test::registerTestWrittenToPass();

>From 1c0692085ae2518768132e53e6436bf443e5e0a4 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Thu, 12 Dec 2024 19:47:16 +0000
Subject: [PATCH 3/7] fix

---
 .../XeGPU/Transforms/XeGPUDistribute.cpp      | 231 +++++++++---------
 1 file changed, 116 insertions(+), 115 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUDistribute.cpp
index 52c953697d0f34..5c45a293a8d468 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUDistribute.cpp
@@ -12,6 +12,7 @@
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/GPU/Utils/DistributionUtils.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
@@ -66,12 +67,12 @@ bool divisible(APInt lhs, APInt rhs) { return !lhs.urem(rhs); }
 // ///                                 -> !xegpu.tensor_desc<4x1xf32>
 // ///
 // /// ```
-// struct WarpOpTensorDescOp final
-//     : public OpRewritePattern<vector::WarpExecuteOnLane0Op> {
-//   using OpRewritePattern<vector::WarpExecuteOnLane0Op>::OpRewritePattern;
-//   LogicalResult matchAndRewrite(vector::WarpExecuteOnLane0Op warpOp,
-//                                 PatternRewriter &rewriter) const override;
-// };
+struct WarpOpTensorDescOp final
+    : public OpRewritePattern<gpu::WarpExecuteOnLane0Op> {
+  using OpRewritePattern<gpu::WarpExecuteOnLane0Op>::OpRewritePattern;
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override;
+};
 
 // /// Sink a store_nd feeding into vector.yield op for the enclosing
 // /// `vector.warp_execute_on_lane_0`. In case arguments for the store are passed
@@ -159,32 +160,32 @@ bool divisible(APInt lhs, APInt rhs) { return !lhs.urem(rhs); }
 //   return newVectorType;
 // }
 
-// FailureOr<xegpu::TensorDescType>
-// getDistributedTensorDescType(xegpu::TensorDescType originalT,
-//                              xegpu::SGMapAttr sgMap,
-//                              xegpu::MemorySpace memSpace) {
-//   llvm::SmallVector<int64_t, 2> distributedShape;
-//   auto layout = sgMap.getWiLayout();
-//   auto shape = originalT.getShape();
-//   for (const auto [l, o] : llvm::zip_equal(layout, shape)) {
-//     if (!divisible(APInt(64, o), APInt(64, l)))
-//       return failure();
-//     distributedShape.push_back(o / l);
-//   }
-//   xegpu::TensorDescType distributedDescType;
-//   if (originalT.isScattered()) {
-
-//     distributedDescType = xegpu::TensorDescType::get(
-//         distributedShape, originalT.getElementType(), originalT.getChunkSize(),
-//         originalT.getMemorySpace(), originalT.getSGMapAttr());
-//   } else {
-//     distributedDescType = xegpu::TensorDescType::get(
-//         distributedShape, originalT.getElementType(),
-//         originalT.getBoundaryCheck(), originalT.getArrayLength(),
-//         originalT.getMemorySpace(), originalT.getSGMapAttr());
-//   }
-//   return distributedDescType;
-// }
+FailureOr<xegpu::TensorDescType>
+getDistributedTensorDescType(xegpu::TensorDescType originalT,
+                             xegpu::SGMapAttr sgMap,
+                             xegpu::MemorySpace memSpace) {
+  llvm::SmallVector<int64_t, 2> distributedShape;
+  auto layout = sgMap.getWiLayout();
+  auto shape = originalT.getShape();
+  for (const auto [l, o] : llvm::zip_equal(layout, shape)) {
+    if (!divisible(APInt(64, o), APInt(64, l)))
+      return failure();
+    distributedShape.push_back(o / l);
+  }
+  xegpu::TensorDescType distributedDescType;
+  if (originalT.isScattered()) {
+
+    distributedDescType = xegpu::TensorDescType::get(
+        distributedShape, originalT.getElementType(), originalT.getChunkSize(),
+        originalT.getMemorySpace(), originalT.getSGMapAttr());
+  } else {
+    distributedDescType = xegpu::TensorDescType::get(
+        distributedShape, originalT.getElementType(),
+        originalT.getBoundaryCheck(), originalT.getArrayLength(),
+        originalT.getMemorySpace(), originalT.getSGMapAttr());
+  }
+  return distributedDescType;
+}
 } // namespace
 
 // LogicalResult
@@ -303,91 +304,91 @@ bool divisible(APInt lhs, APInt rhs) { return !lhs.urem(rhs); }
 //   return success();
 // }
 
-// LogicalResult
-// WarpOpTensorDescOp::matchAndRewrite(vector::WarpExecuteOnLane0Op warpOp,
-//                                     PatternRewriter &rewriter) const {
-//   OpOperand *operand = getWarpResult(warpOp, [](Operation *op) {
-//     return isa<xegpu::CreateNdDescOp>(op) && op->hasOneUse();
-//   });
-
-//   if (!operand)
-//     return rewriter.notifyMatchFailure(
-//         warpOp, "warp result is not a xegpu::CreateNdDesc op");
-//   auto descOp = operand->get().getDefiningOp<xegpu::CreateNdDescOp>();
-//   assert(descOp && "desc op must be not null");
-//   unsigned operandIdx = operand->getOperandNumber();
-
-//   // TODO: is memref uniform in the region
-//   rewriter.setInsertionPoint(warpOp);
-//   auto srcTypedVal = dyn_cast<TypedValue<MemRefType>>(descOp.getSource());
-//   assert(srcTypedVal && "source value must be not null");
-
-//   auto descOffsets = descOp.getMixedOffsets();
-//   if (descOffsets.size() != 2)
-//     return rewriter.notifyMatchFailure(descOp,
-//                                        "offsets size is expected to be 2");
-
-//   xegpu::SGMapAttr sgMap = descOp.getType().getSGMapAttr();
-//   if (!sgMap)
-//     return rewriter.notifyMatchFailure(
-//         descOp, "the tensor descriptor lacks sg_map attribute");
-
-//   auto layout = sgMap.getWiLayout();
-
-//   // Calculate the offset within tensor descriptor for the current lane_id. The
-//   // access to proper element for a work item is done through a lane-specific
-//   // subview (tdesc offsets are used as base, lane shift is added on top).
-//   auto laneid = warpOp.getLaneid();
-//   auto xDim =
-//       rewriter.create<arith::ConstantIndexOp>(laneid.getLoc(), layout[0]);
-//   auto shiftx = rewriter.create<arith::RemUIOp>(laneid.getLoc(), laneid, xDim);
-//   auto shifty = rewriter.create<arith::DivUIOp>(laneid.getLoc(), laneid, xDim);
-
-//   auto basex = getValueOrCreateConstantIndexOp(rewriter, laneid.getLoc(),
-//                                                descOffsets[0]);
-//   auto basey = getValueOrCreateConstantIndexOp(rewriter, laneid.getLoc(),
-//                                                descOffsets[1]);
-//   auto offsetx = rewriter.create<arith::AddIOp>(laneid.getLoc(), shiftx, basex);
-//   auto offsety = rewriter.create<arith::AddIOp>(laneid.getLoc(), shifty, basey);
-
-//   auto distributedDescTypeOrFailure = getDistributedTensorDescType(
-//       descOp.getType(), sgMap, descOp.getType().getMemorySpace());
-//   if (failed(distributedDescTypeOrFailure))
-//     return rewriter.notifyMatchFailure(descOp,
-//                                        "Failed to distribute the desc type");
-//   xegpu::TensorDescType newTDescType = distributedDescTypeOrFailure.value();
-//   auto distributedShape = newTDescType.getShape();
-//   // use the base memref strides
-//   SmallVector<OpFoldResult> overwriteStrides =
-//       getAsIndexOpFoldResult(rewriter.getContext(), SmallVector<int64_t>{1, 1});
-//   SmallVector<OpFoldResult> overwriteSizes =
-//       getAsIndexOpFoldResult(rewriter.getContext(), distributedShape);
-
-//   SmallVector<size_t> newRetIndices;
-//   vector::WarpExecuteOnLane0Op newWarpOp =
-//       moveRegionToNewWarpOpAndAppendReturns(
-//           rewriter, warpOp, descOp.getSource(), descOp.getSourceType(),
-//           newRetIndices);
-
-//   rewriter.setInsertionPointAfter(newWarpOp);
-//   auto subview = rewriter.create<memref::SubViewOp>(
-//       newWarpOp.getLoc(), srcTypedVal, getAsOpFoldResult({offsetx, offsety}),
-//       overwriteSizes, overwriteStrides);
-//   subview.getSourceMutable().assign(newWarpOp.getResult(newRetIndices[0]));
-
-//   auto zero = rewriter.create<arith::ConstantIndexOp>(laneid.getLoc(), 0);
-//   auto newDescOp = rewriter.create<xegpu::CreateNdDescOp>(
-//       newWarpOp.getLoc(), newTDescType, subview,
-//       getAsOpFoldResult({zero, zero}));
-
-//   Value distributedVal = newWarpOp.getResult(operandIdx);
-//   rewriter.replaceAllUsesWith(distributedVal, newDescOp);
-
-//   return success();
-// }
+LogicalResult
+WarpOpTensorDescOp::matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+                                    PatternRewriter &rewriter) const {
+  OpOperand *operand = getWarpResult(warpOp, [](Operation *op) {
+    return isa<xegpu::CreateNdDescOp>(op) && op->hasOneUse();
+  });
+
+  if (!operand)
+    return rewriter.notifyMatchFailure(
+        warpOp, "warp result is not a xegpu::CreateNdDesc op");
+  auto descOp = operand->get().getDefiningOp<xegpu::CreateNdDescOp>();
+  assert(descOp && "desc op must be not null");
+  unsigned operandIdx = operand->getOperandNumber();
+
+  // TODO: is memref uniform in the region
+  rewriter.setInsertionPoint(warpOp);
+  auto srcTypedVal = dyn_cast<TypedValue<MemRefType>>(descOp.getSource());
+  assert(srcTypedVal && "source value must be not null");
+
+  auto descOffsets = descOp.getMixedOffsets();
+  if (descOffsets.size() != 2)
+    return rewriter.notifyMatchFailure(descOp,
+                                       "offsets size is expected to be 2");
+
+  xegpu::SGMapAttr sgMap = descOp.getType().getSGMapAttr();
+  if (!sgMap)
+    return rewriter.notifyMatchFailure(
+        descOp, "the tensor descriptor lacks sg_map attribute");
+
+  auto layout = sgMap.getWiLayout();
+
+  // Calculate the offset within tensor descriptor for the current lane_id. The
+  // access to proper element for a work item is done through a lane-specific
+  // subview (tdesc offsets are used as base, lane shift is added on top).
+  auto laneid = warpOp.getLaneid();
+  auto xDim =
+      rewriter.create<arith::ConstantIndexOp>(laneid.getLoc(), layout[0]);
+  auto shiftx = rewriter.create<arith::RemUIOp>(laneid.getLoc(), laneid, xDim);
+  auto shifty = rewriter.create<arith::DivUIOp>(laneid.getLoc(), laneid, xDim);
+
+  auto basex = getValueOrCreateConstantIndexOp(rewriter, laneid.getLoc(),
+                                               descOffsets[0]);
+  auto basey = getValueOrCreateConstantIndexOp(rewriter, laneid.getLoc(),
+                                               descOffsets[1]);
+  auto offsetx = rewriter.create<arith::AddIOp>(laneid.getLoc(), shiftx, basex);
+  auto offsety = rewriter.create<arith::AddIOp>(laneid.getLoc(), shifty, basey);
+
+  auto distributedDescTypeOrFailure = getDistributedTensorDescType(
+      descOp.getType(), sgMap, descOp.getType().getMemorySpace());
+  if (failed(distributedDescTypeOrFailure))
+    return rewriter.notifyMatchFailure(descOp,
+                                       "Failed to distribute the desc type");
+  xegpu::TensorDescType newTDescType = distributedDescTypeOrFailure.value();
+  auto distributedShape = newTDescType.getShape();
+  // use the base memref strides
+  SmallVector<OpFoldResult> overwriteStrides =
+      getAsIndexOpFoldResult(rewriter.getContext(), SmallVector<int64_t>{1, 1});
+  SmallVector<OpFoldResult> overwriteSizes =
+      getAsIndexOpFoldResult(rewriter.getContext(), distributedShape);
+
+  SmallVector<size_t> newRetIndices;
+  gpu::WarpExecuteOnLane0Op newWarpOp =
+      moveRegionToNewWarpOpAndAppendReturns(
+          rewriter, warpOp, descOp.getSource(), descOp.getSourceType(),
+          newRetIndices);
+
+  rewriter.setInsertionPointAfter(newWarpOp);
+  auto subview = rewriter.create<memref::SubViewOp>(
+      newWarpOp.getLoc(), srcTypedVal, getAsOpFoldResult({offsetx, offsety}),
+      overwriteSizes, overwriteStrides);
+  subview.getSourceMutable().assign(newWarpOp.getResult(newRetIndices[0]));
+
+  auto zero = rewriter.create<arith::ConstantIndexOp>(laneid.getLoc(), 0);
+  auto newDescOp = rewriter.create<xegpu::CreateNdDescOp>(
+      newWarpOp.getLoc(), newTDescType, subview,
+      getAsOpFoldResult({zero, zero}));
+
+  Value distributedVal = newWarpOp.getResult(operandIdx);
+  rewriter.replaceAllUsesWith(distributedVal, newDescOp);
+
+  return success();
+}
 
 void xegpu::populateXeGPUDistributePatterns(RewritePatternSet &patterns) {
-  // patterns.add<WarpOpTensorDescOp>(patterns.getContext());
+  patterns.add<WarpOpTensorDescOp>(patterns.getContext());
   // patterns.add<WarpOpStoreNd>(patterns.getContext());
   // patterns.add<WarpOpLoadNd>(patterns.getContext());
 }

>From 9888c849327aa0caef34325c0a8333851503714d Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Thu, 12 Dec 2024 21:15:39 +0000
Subject: [PATCH 4/7] fix

---
 mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td        | 3 +--
 mlir/lib/Dialect/XeGPU/Transforms/XeGPUDistribute.cpp | 6 +++++-
 2 files changed, 6 insertions(+), 3 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 5910aa3f7f2dae..f3ffbd0f5a027d 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -327,8 +327,7 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [AllElementTypesMatch<["value", "Tensor
   let hasVerifier = 1;
 }
 
-def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [AllShapesMatch<["value", "TensorDesc"]>,
-                                       AllElementTypesMatch<["value", "TensorDesc"]>]> {
+def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [AllElementTypesMatch<["value", "TensorDesc"]>]> {
   let summary = "stores a n-D block register region back to memory, currently only supports 2D";
 
   let description = [{
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUDistribute.cpp
index 5c45a293a8d468..ac02be78280218 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUDistribute.cpp
@@ -170,7 +170,11 @@ getDistributedTensorDescType(xegpu::TensorDescType originalT,
   for (const auto [l, o] : llvm::zip_equal(layout, shape)) {
     if (!divisible(APInt(64, o), APInt(64, l)))
       return failure();
-    distributedShape.push_back(o / l);
+    // Tensor descriptor is distributed only for the scattered case.
+    if (originalT.isScattered())
+      distributedShape.push_back(o / l);
+    else
+      distributedShape.push_back(o);
   }
   xegpu::TensorDescType distributedDescType;
   if (originalT.isScattered()) {

>From 491625d81d80cb0092ec5704f0de23c039e9c681 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Thu, 12 Dec 2024 21:34:10 +0000
Subject: [PATCH 5/7] fix

---
 .../XeGPU/Transforms/XeGPUDistribute.cpp      | 319 +++++++++---------
 1 file changed, 161 insertions(+), 158 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUDistribute.cpp
index ac02be78280218..95610a9af7ac85 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUDistribute.cpp
@@ -33,11 +33,15 @@ bool divisible(APInt lhs, APInt rhs) { return !lhs.urem(rhs); }
 
 // /// Clone a create_nd_tdesc feeding into vector.yield op for the enclosing
 // /// `vector.warp_execute_on_lane_0` and put it after the warp op.
-// /// The warp op will still contain the original op that will not be used by the
-// /// yield op (and should be cleaned up later with dce). The yield op will bypass
+// /// The warp op will still contain the original op
+// /// that will not be used by the yield op
+// /// (and should be cleaned up later with dce).
+// /// The yield op will bypass
 // /// the create_nd_tdesc's arguments.
-// /// The rewrite will create a subview of the size used by a single work item and
-// /// appropriate offset. The distributed create_nd_tdesc points into the subview
+// /// The rewrite will create a subview of the size
+// /// used by a single work item and appropriate offset.
+// /// The distributed create_nd_tdesc points
+// /// into the subview
 // /// without offset. The tensor descriptor types is distributed according to
 // /// sg_map attribute.
 // ///
@@ -75,8 +79,10 @@ struct WarpOpTensorDescOp final
 };
 
 // /// Sink a store_nd feeding into vector.yield op for the enclosing
-// /// `vector.warp_execute_on_lane_0`. In case arguments for the store are passed
-// /// through the warp op interface they would be propagated as returned values.
+// /// `vector.warp_execute_on_lane_0`. In case arguments
+// /// for the store are passed through the warp op
+// /// interface they would be propagated as returned
+// /// values.
 // /// Both the stored vector type and tensor descriptor types are distributed
 // /// according to sg_map attribute.
 // ///
@@ -97,20 +103,23 @@ struct WarpOpTensorDescOp final
 // ///     ...
 // ///     vector.yield
 // ///   }
-// ///   xegpu.store_nd %arg0, %arg1: vector<4x1xf32>, !xegpu.tensor_desc<4x1xf32>
+// ///   xegpu.store_nd %arg0, %arg1: vector<4x1xf32>,
+// ///       !xegpu.tensor_desc<4x1xf32>
 // ///
 // /// ```
-// struct WarpOpStoreNd final
-//     : public OpRewritePattern<vector::WarpExecuteOnLane0Op> {
-//   using OpRewritePattern<vector::WarpExecuteOnLane0Op>::OpRewritePattern;
-//   LogicalResult matchAndRewrite(vector::WarpExecuteOnLane0Op warpOp,
-//                                 PatternRewriter &rewriter) const override;
-// };
+struct WarpOpStoreNd final
+    : public OpRewritePattern<gpu::WarpExecuteOnLane0Op> {
+  using OpRewritePattern<gpu::WarpExecuteOnLane0Op>::OpRewritePattern;
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override;
+};
 
 // /// Clone a load_nd feeding into vector.yield op for the enclosing
 // /// `vector.warp_execute_on_lane_0` and put it after the warp op.
-// /// The warp op will still contain the original op that will not be used by the
-// /// yield op (and should be cleaned up later with dce). The yield op will bypass
+// /// The warp op will still contain the original op
+// /// that will not be used by the yield op
+// /// (and should be cleaned up later with dce).
+// /// The yield op will bypass
 // /// the load's arguments.
 // /// Both the loaded vector type and tensor descriptor types are distributed
 // /// according to sg_map attribute.
@@ -137,28 +146,27 @@ struct WarpOpTensorDescOp final
 // ///   xegpu.store_nd %r#0, %r#1: vector<4x1xf32>, !xegpu.tensor_desc<4x1xf32>
 // ///
 // /// ```
-// struct WarpOpLoadNd final
-//     : public OpRewritePattern<vector::WarpExecuteOnLane0Op> {
-//   using OpRewritePattern<vector::WarpExecuteOnLane0Op>::OpRewritePattern;
-//   LogicalResult matchAndRewrite(vector::WarpExecuteOnLane0Op warpOp,
-//                                 PatternRewriter &rewriter) const override;
-// };
-
-// FailureOr<VectorType> getDistributedVectorType(VectorType originalT,
-//                                                xegpu::SGMapAttr sgMap) {
-//   llvm::SmallVector<int64_t, 2> distributedShape;
-//   auto layout = sgMap.getWiLayout();
-//   auto shape = originalT.getShape();
-//   for (const auto [l, o] : llvm::zip_equal(layout, shape)) {
-//     if (!divisible(APInt(64, o), APInt(64, l)))
-//       return failure();
-//     distributedShape.push_back(o / l);
-//   }
-//   auto newVectorType =
-//       VectorType::get(distributedShape, originalT.getElementType(),
-//                       originalT.getScalableDims());
-//   return newVectorType;
-// }
+struct WarpOpLoadNd final : public OpRewritePattern<gpu::WarpExecuteOnLane0Op> {
+  using OpRewritePattern<gpu::WarpExecuteOnLane0Op>::OpRewritePattern;
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override;
+};
+
+FailureOr<VectorType> getDistributedVectorType(VectorType originalT,
+                                               xegpu::SGMapAttr sgMap) {
+  llvm::SmallVector<int64_t, 2> distributedShape;
+  auto layout = sgMap.getWiLayout();
+  auto shape = originalT.getShape();
+  for (const auto [l, o] : llvm::zip_equal(layout, shape)) {
+    if (!divisible(APInt(64, o), APInt(64, l)))
+      return failure();
+    distributedShape.push_back(o / l);
+  }
+  auto newVectorType =
+      VectorType::get(distributedShape, originalT.getElementType(),
+                      originalT.getScalableDims());
+  return newVectorType;
+}
 
 FailureOr<xegpu::TensorDescType>
 getDistributedTensorDescType(xegpu::TensorDescType originalT,
@@ -192,121 +200,117 @@ getDistributedTensorDescType(xegpu::TensorDescType originalT,
 }
 } // namespace
 
-// LogicalResult
-// WarpOpStoreNd::matchAndRewrite(vector::WarpExecuteOnLane0Op warpOp,
-//                                PatternRewriter &rewriter) const {
-//   auto yield = cast<vector::YieldOp>(
-//       warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
-//   Operation *lastNode = yield->getPrevNode();
-//   auto storeOp = dyn_cast_or_null<xegpu::StoreNdOp>(lastNode);
-//   if (!storeOp)
-//     return failure();
-
-//   auto origType = storeOp.getTensorDescType();
-//   xegpu::SGMapAttr sgMap = origType.getSGMapAttr();
-//   if (!sgMap)
-//     return rewriter.notifyMatchFailure(
-//         storeOp, "the source tensor descriptor lacks sg_map attribute");
-
-//   if (storeOp.getTensorDescType().getShape().size() != 2)
-//     return rewriter.notifyMatchFailure(storeOp, "unsupported shape");
-//   DBGS() << "Matched store_nd: " << storeOp << "\n";
-
-//   auto distributedTypeOrFailure =
-//       getDistributedVectorType(storeOp.getValueType(), sgMap);
-//   if (failed(distributedTypeOrFailure))
-//     return rewriter.notifyMatchFailure(storeOp,
-//                                        "Failed to distribute the type");
-//   VectorType newVectorType = distributedTypeOrFailure.value();
-
-//   auto distributedDescTypeOrFailure = getDistributedTensorDescType(
-//       storeOp.getTensorDescType(), sgMap,
-//       storeOp.getTensorDescType().getMemorySpace());
-//   if (failed(distributedDescTypeOrFailure))
-//     return rewriter.notifyMatchFailure(storeOp,
-//                                        "Failed to distribute the desc type");
-//   xegpu::TensorDescType newTDescType = distributedDescTypeOrFailure.value();
-
-//   SmallVector<size_t> newRetIndices;
-//   vector::WarpExecuteOnLane0Op newWarpOp =
-//       moveRegionToNewWarpOpAndAppendReturns(
-//           rewriter, warpOp,
-//           ValueRange{storeOp.getTensorDesc(), storeOp.getValue()},
-//           TypeRange{newTDescType, newVectorType}, newRetIndices);
-
-//   rewriter.setInsertionPointAfter(newWarpOp);
-//   auto newStoreOp =
-//       cast<xegpu::StoreNdOp>(rewriter.clone(*storeOp.getOperation()));
-//   rewriter.eraseOp(storeOp);
-//   newStoreOp.getTensorDescMutable().assign(
-//       newWarpOp.getResult(newRetIndices[0]));
-//   newStoreOp.getValueMutable().assign(newWarpOp.getResult(newRetIndices[1]));
-
-//   return success();
-// }
-
-// LogicalResult WarpOpLoadNd::matchAndRewrite(vector::WarpExecuteOnLane0Op warpOp,
-//                                             PatternRewriter &rewriter) const {
-//   OpOperand *operand = getWarpResult(warpOp, [](Operation *op) {
-//     return isa<xegpu::LoadNdOp>(op) && op->hasOneUse();
-//   });
-
-//   if (!operand)
-//     return rewriter.notifyMatchFailure(warpOp,
-//                                        "warp result is not a xegpu::LoadNd op");
-
-//   auto loadOp = operand->get().getDefiningOp<xegpu::LoadNdOp>();
-
-//   if (loadOp.getPacked())
-//     return rewriter.notifyMatchFailure(
-//         loadOp, "Packed load distribution not supported");
-
-//   xegpu::TensorDescType origType = loadOp.getTensorDescType();
-//   xegpu::SGMapAttr sgMap = origType.getSGMapAttr();
-//   if (!sgMap)
-//     return rewriter.notifyMatchFailure(
-//         loadOp, "the source tensor descriptor lacks sg_map attribute");
-
-//   auto origShape = origType.getShape();
-//   if (origShape.size() != 2)
-//     return rewriter.notifyMatchFailure(loadOp, "unsupported shape");
-
-//   auto distributedTypeOrFailure =
-//       getDistributedVectorType(loadOp.getType(), sgMap);
-//   if (failed(distributedTypeOrFailure))
-//     return rewriter.notifyMatchFailure(loadOp, "Failed to distribute the type");
-//   VectorType newVectorType = distributedTypeOrFailure.value();
-
-//   auto distributedDescTypeOrFailure =
-//       getDistributedTensorDescType(loadOp.getTensorDescType(), sgMap,
-//                                    loadOp.getTensorDescType().getMemorySpace());
-//   if (failed(distributedDescTypeOrFailure))
-//     return rewriter.notifyMatchFailure(loadOp,
-//                                        "Failed to distribute the desc type");
-//   xegpu::TensorDescType newTDescType = distributedDescTypeOrFailure.value();
-
-//   unsigned operandIdx = operand->getOperandNumber();
-
-//   SmallVector<size_t> newRetIndices;
-//   vector::WarpExecuteOnLane0Op newWarpOp =
-//       moveRegionToNewWarpOpAndAppendReturns(
-//           rewriter, warpOp, loadOp.getTensorDesc(), TypeRange{newTDescType},
-//           newRetIndices);
-
-//   rewriter.setInsertionPointAfter(newWarpOp);
-
-//   auto newLoadOp = rewriter.create<xegpu::LoadNdOp>(
-//       loadOp.getLoc(), newVectorType, loadOp.getTensorDesc(),
-//       loadOp.getPackedAttr(), loadOp.getTransposeAttr(), loadOp.getL1HintAttr(),
-//       loadOp.getL2HintAttr(), loadOp.getL3HintAttr());
-
-//   newLoadOp.getTensorDescMutable().assign(
-//       newWarpOp.getResult(newRetIndices[0]));
-//   Value distributedVal = newWarpOp.getResult(operandIdx);
-//   rewriter.replaceAllUsesWith(distributedVal, newLoadOp);
-
-//   return success();
-// }
+LogicalResult WarpOpStoreNd::matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+                                             PatternRewriter &rewriter) const {
+  auto yield = cast<gpu::YieldOp>(
+      warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+  Operation *lastNode = yield->getPrevNode();
+  auto storeOp = dyn_cast_or_null<xegpu::StoreNdOp>(lastNode);
+  if (!storeOp)
+    return failure();
+
+  auto origType = storeOp.getTensorDescType();
+  xegpu::SGMapAttr sgMap = origType.getSGMapAttr();
+  if (!sgMap)
+    return rewriter.notifyMatchFailure(
+        storeOp, "the source tensor descriptor lacks sg_map attribute");
+
+  if (storeOp.getTensorDescType().getShape().size() != 2)
+    return rewriter.notifyMatchFailure(storeOp, "unsupported shape");
+  DBGS() << "Matched store_nd: " << storeOp << "\n";
+
+  auto distributedTypeOrFailure =
+      getDistributedVectorType(storeOp.getValueType(), sgMap);
+  if (failed(distributedTypeOrFailure))
+    return rewriter.notifyMatchFailure(storeOp,
+                                       "Failed to distribute the type");
+  VectorType newVectorType = distributedTypeOrFailure.value();
+
+  auto distributedDescTypeOrFailure = getDistributedTensorDescType(
+      storeOp.getTensorDescType(), sgMap,
+      storeOp.getTensorDescType().getMemorySpace());
+  if (failed(distributedDescTypeOrFailure))
+    return rewriter.notifyMatchFailure(storeOp,
+                                       "Failed to distribute the desc type");
+  xegpu::TensorDescType newTDescType = distributedDescTypeOrFailure.value();
+
+  SmallVector<size_t> newRetIndices;
+  gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+      rewriter, warpOp, ValueRange{storeOp.getTensorDesc(), storeOp.getValue()},
+      TypeRange{newTDescType, newVectorType}, newRetIndices);
+
+  rewriter.setInsertionPointAfter(newWarpOp);
+  auto newStoreOp =
+      cast<xegpu::StoreNdOp>(rewriter.clone(*storeOp.getOperation()));
+  rewriter.eraseOp(storeOp);
+  newStoreOp.getTensorDescMutable().assign(
+      newWarpOp.getResult(newRetIndices[0]));
+  newStoreOp.getValueMutable().assign(newWarpOp.getResult(newRetIndices[1]));
+
+  return success();
+}
+
+LogicalResult WarpOpLoadNd::matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+                                            PatternRewriter &rewriter) const {
+  OpOperand *operand = getWarpResult(warpOp, [](Operation *op) {
+    return isa<xegpu::LoadNdOp>(op) && op->hasOneUse();
+  });
+
+  if (!operand)
+    return rewriter.notifyMatchFailure(warpOp,
+                                       "warp result is not a xegpu::LoadNd op");
+
+  auto loadOp = operand->get().getDefiningOp<xegpu::LoadNdOp>();
+
+  if (loadOp.getPacked())
+    return rewriter.notifyMatchFailure(
+        loadOp, "Packed load distribution not supported");
+
+  xegpu::TensorDescType origType = loadOp.getTensorDescType();
+  xegpu::SGMapAttr sgMap = origType.getSGMapAttr();
+  if (!sgMap)
+    return rewriter.notifyMatchFailure(
+        loadOp, "the source tensor descriptor lacks sg_map attribute");
+
+  auto origShape = origType.getShape();
+  if (origShape.size() != 2)
+    return rewriter.notifyMatchFailure(loadOp, "unsupported shape");
+
+  auto distributedTypeOrFailure =
+      getDistributedVectorType(loadOp.getType(), sgMap);
+  if (failed(distributedTypeOrFailure))
+    return rewriter.notifyMatchFailure(loadOp, "Failed to distribute the type");
+  VectorType newVectorType = distributedTypeOrFailure.value();
+
+  auto distributedDescTypeOrFailure =
+      getDistributedTensorDescType(loadOp.getTensorDescType(), sgMap,
+                                   loadOp.getTensorDescType().getMemorySpace());
+  if (failed(distributedDescTypeOrFailure))
+    return rewriter.notifyMatchFailure(loadOp,
+                                       "Failed to distribute the desc type");
+  xegpu::TensorDescType newTDescType = distributedDescTypeOrFailure.value();
+
+  unsigned operandIdx = operand->getOperandNumber();
+
+  SmallVector<size_t> newRetIndices;
+  gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+      rewriter, warpOp, loadOp.getTensorDesc(), TypeRange{newTDescType},
+      newRetIndices);
+
+  rewriter.setInsertionPointAfter(newWarpOp);
+
+  auto newLoadOp = rewriter.create<xegpu::LoadNdOp>(
+      loadOp.getLoc(), newVectorType, loadOp.getTensorDesc(),
+      loadOp.getPackedAttr(), loadOp.getTransposeAttr(), loadOp.getL1HintAttr(),
+      loadOp.getL2HintAttr(), loadOp.getL3HintAttr());
+
+  newLoadOp.getTensorDescMutable().assign(
+      newWarpOp.getResult(newRetIndices[0]));
+  Value distributedVal = newWarpOp.getResult(operandIdx);
+  rewriter.replaceAllUsesWith(distributedVal, newLoadOp);
+
+  return success();
+}
 
 LogicalResult
 WarpOpTensorDescOp::matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
@@ -369,10 +373,9 @@ WarpOpTensorDescOp::matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
       getAsIndexOpFoldResult(rewriter.getContext(), distributedShape);
 
   SmallVector<size_t> newRetIndices;
-  gpu::WarpExecuteOnLane0Op newWarpOp =
-      moveRegionToNewWarpOpAndAppendReturns(
-          rewriter, warpOp, descOp.getSource(), descOp.getSourceType(),
-          newRetIndices);
+  gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+      rewriter, warpOp, descOp.getSource(), descOp.getSourceType(),
+      newRetIndices);
 
   rewriter.setInsertionPointAfter(newWarpOp);
   auto subview = rewriter.create<memref::SubViewOp>(
@@ -393,6 +396,6 @@ WarpOpTensorDescOp::matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
 
 void xegpu::populateXeGPUDistributePatterns(RewritePatternSet &patterns) {
   patterns.add<WarpOpTensorDescOp>(patterns.getContext());
-  // patterns.add<WarpOpStoreNd>(patterns.getContext());
-  // patterns.add<WarpOpLoadNd>(patterns.getContext());
+  patterns.add<WarpOpStoreNd>(patterns.getContext());
+  patterns.add<WarpOpLoadNd>(patterns.getContext());
 }

>From 07f9f9f77953c4278912afd12910f2264f3e36e3 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Thu, 12 Dec 2024 22:37:28 +0000
Subject: [PATCH 6/7] fix

---
 .../XeGPU/Transforms/XeGPUDistribute.cpp      | 29 ++----------------
 mlir/test/Dialect/XeGPU/xegpu-distribute.mlir | 30 ++++++++-----------
 2 files changed, 16 insertions(+), 43 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUDistribute.cpp
index 95610a9af7ac85..297cff8d0144d5 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUDistribute.cpp
@@ -341,24 +341,6 @@ WarpOpTensorDescOp::matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
     return rewriter.notifyMatchFailure(
         descOp, "the tensor descriptor lacks sg_map attribute");
 
-  auto layout = sgMap.getWiLayout();
-
-  // Calculate the offset within tensor descriptor for the current lane_id. The
-  // access to proper element for a work item is done through a lane-specific
-  // subview (tdesc offsets are used as base, lane shift is added on top).
-  auto laneid = warpOp.getLaneid();
-  auto xDim =
-      rewriter.create<arith::ConstantIndexOp>(laneid.getLoc(), layout[0]);
-  auto shiftx = rewriter.create<arith::RemUIOp>(laneid.getLoc(), laneid, xDim);
-  auto shifty = rewriter.create<arith::DivUIOp>(laneid.getLoc(), laneid, xDim);
-
-  auto basex = getValueOrCreateConstantIndexOp(rewriter, laneid.getLoc(),
-                                               descOffsets[0]);
-  auto basey = getValueOrCreateConstantIndexOp(rewriter, laneid.getLoc(),
-                                               descOffsets[1]);
-  auto offsetx = rewriter.create<arith::AddIOp>(laneid.getLoc(), shiftx, basex);
-  auto offsety = rewriter.create<arith::AddIOp>(laneid.getLoc(), shifty, basey);
-
   auto distributedDescTypeOrFailure = getDistributedTensorDescType(
       descOp.getType(), sgMap, descOp.getType().getMemorySpace());
   if (failed(distributedDescTypeOrFailure))
@@ -378,15 +360,10 @@ WarpOpTensorDescOp::matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
       newRetIndices);
 
   rewriter.setInsertionPointAfter(newWarpOp);
-  auto subview = rewriter.create<memref::SubViewOp>(
-      newWarpOp.getLoc(), srcTypedVal, getAsOpFoldResult({offsetx, offsety}),
-      overwriteSizes, overwriteStrides);
-  subview.getSourceMutable().assign(newWarpOp.getResult(newRetIndices[0]));
-
-  auto zero = rewriter.create<arith::ConstantIndexOp>(laneid.getLoc(), 0);
   auto newDescOp = rewriter.create<xegpu::CreateNdDescOp>(
-      newWarpOp.getLoc(), newTDescType, subview,
-      getAsOpFoldResult({zero, zero}));
+      newWarpOp.getLoc(), newTDescType,
+      dyn_cast<TypedValue<MemRefType>>(newWarpOp.getResult(newRetIndices[0])),
+      descOffsets);
 
   Value distributedVal = newWarpOp.getResult(operandIdx);
   rewriter.replaceAllUsesWith(distributedVal, newDescOp);
diff --git a/mlir/test/Dialect/XeGPU/xegpu-distribute.mlir b/mlir/test/Dialect/XeGPU/xegpu-distribute.mlir
index f9efda80ab6468..755f4922bfaa23 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-distribute.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-distribute.mlir
@@ -8,9 +8,9 @@
 // CHECK: %[[res:.*]]:2 = gpu.warp_execute_on_lane_0(%[[laneid]])[16] args(%{{.*}}, %{{.*}} :  vector<24x32xf16>, !xegpu.tensor_desc<24x32xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>)
 // CHECK-SAME: -> (!xegpu.tensor_desc<24x32xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>, vector<24x2xf16>)
 // CHECK: ^bb0(%[[src:.*]]: vector<24x32xf16>, %[[dst:.*]]: !xegpu.tensor_desc<24x32xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>)
-// CHECK: gpu.yield%[[dst]], %[[src]] : !xegpu.tensor_desc<24x32xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>, vector<24x32xf16>
+// CHECK: gpu.yield %[[dst]], %[[src]] : !xegpu.tensor_desc<24x32xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>, vector<24x32xf16>
 // CHECK: xegpu.store_nd %[[res]]#1, %[[res]]#0 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> :
-// CHECK-SAME: vector<24x2xf16>, !xegpu.tensor_desc<24x32xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+// CHECK: vector<24x2xf16>, !xegpu.tensor_desc<24x32xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
 
 func.func @test_store_nd_distribution(%src: vector<24x32xf16>, %dst: !xegpu.tensor_desc<24x32xf16, #blk_tdesc, #sg_map_16>) -> () {
   %laneid = gpu.lane_id
@@ -23,7 +23,6 @@ func.func @test_store_nd_distribution(%src: vector<24x32xf16>, %dst: !xegpu.tens
 }
 
 // -----
-
 #sg_map_16 = #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>
 #blk_tdesc = #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>
 
@@ -33,7 +32,7 @@ func.func @test_store_nd_distribution(%src: vector<24x32xf16>, %dst: !xegpu.tens
 // CHECK-SAME: -> (vector<24x2xf16>, !xegpu.tensor_desc<24x32xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>)
 // CHECK: ^bb0(%[[dst:.*]]: !xegpu.tensor_desc<24x32xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>)
 // CHECK: %[[dead:.*]] = xegpu.load_nd
-// CHECK: gpu.yield%[[dead]], %[[dst]] : vector<24x32xf16>, !xegpu.tensor_desc<24x32xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+// CHECK: gpu.yield %[[dead]], %[[dst]] : vector<24x32xf16>, !xegpu.tensor_desc<24x32xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
 // CHECK: %[[load:.*]] = xegpu.load_nd %[[res]]#1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> :
 // CHECK-SAME: !xegpu.tensor_desc<24x32xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<24x2xf16>
 // CHECK: return %[[load]]
@@ -56,26 +55,23 @@ func.func @test_load_nd_distribution(%dst: !xegpu.tensor_desc<24x32xf16, #blk_td
 #blk_tdesc = #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>
 
 // CHECK-LABEL: test_create_nd_desc_distribution
-// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C12:.*]] = arith.constant 12 : index
 // CHECK: %[[laneid:.*]] = gpu.lane_id
 // CHECK: %[[res:.*]]:2 = gpu.warp_execute_on_lane_0(%[[laneid]])[16] args(%{{.*}} : memref<24x32xf16>)
-// CHECK-SAME: -> (!xegpu.tensor_desc<24x32xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>, memref<24x32xf16>)
+// CHECK-SAME: -> (!xegpu.tensor_desc<12x32xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>, memref<24x32xf16>)
 // CHECK: ^bb0(%[[dst:.*]]: memref<24x32xf16>)
 // CHECK: %[[dead:.*]] = xegpu.create_nd_tdesc
-// CHECK: gpu.yield%[[dead]], %[[dst]] :
-// CHECK-SAME: !xegpu.tensor_desc<24x32xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>, memref<24x32xf16>
-// CHECK: %[[view:.*]] = memref.subview %[[res]]#1[%[[C0]], %[[laneid]]] [24, 2] [1, 1] : memref<24x32xf16> to memref<24x2xf16, strided<[32, 1], offset: ?>>
-// CHECK: %[[desc:.*]] = xegpu.create_nd_tdesc %[[view]][0, 0] : memref<24x2xf16, strided<[32, 1], offset: ?>>
-// CHECK-SAME: -> !xegpu.tensor_desc<24x32xf16, #xegpu.block_tdesc_attr<memory_space =  global, array_length = 1 : i64, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
-// CHECK: return %[[desc]]
+// CHECK: gpu.yield %[[dead]], %[[dst]] :
+
 
-func.func @test_create_nd_desc_distribution(%dst: memref<24x32xf16>) -> (!xegpu.tensor_desc<24x32xf16, #blk_tdesc, #sg_map_16>) {
+func.func @test_create_nd_desc_distribution(%dst: memref<24x32xf16>) -> (!xegpu.tensor_desc<12x32xf16, #blk_tdesc, #sg_map_16>) {
   %laneid = gpu.lane_id
+  %c12 = arith.constant 12 : index
   %r = gpu.warp_execute_on_lane_0(%laneid)[16]
-        args(%dst: memref<24x32xf16>) -> (!xegpu.tensor_desc<24x32xf16, #blk_tdesc, #sg_map_16>) {
+        args(%dst: memref<24x32xf16>) -> (!xegpu.tensor_desc<12x32xf16, #blk_tdesc, #sg_map_16>) {
     ^bb0(%arg0: memref<24x32xf16>):
-    %0 = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<24x32xf16, #blk_tdesc, #sg_map_16>
-    gpu.yield%0 : !xegpu.tensor_desc<24x32xf16, #blk_tdesc, #sg_map_16>
+    %0 = xegpu.create_nd_tdesc %arg0[%c12, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<12x32xf16, #blk_tdesc, #sg_map_16>
+    gpu.yield%0 : !xegpu.tensor_desc<12x32xf16, #blk_tdesc, #sg_map_16>
   }
-  return %r : !xegpu.tensor_desc<24x32xf16, #blk_tdesc, #sg_map_16>
+  return %r : !xegpu.tensor_desc<12x32xf16, #blk_tdesc, #sg_map_16>
 }
\ No newline at end of file

>From b842f331ffab253aaf8924b5a6574d35c4fdcf9d Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Thu, 12 Dec 2024 22:45:02 +0000
Subject: [PATCH 7/7] fix

---
 .../XeGPU/Transforms/XeGPUDistribute.cpp      | 200 +++++++++---------
 1 file changed, 99 insertions(+), 101 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUDistribute.cpp
index 297cff8d0144d5..e3f48919c39266 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUDistribute.cpp
@@ -31,46 +31,46 @@ using namespace mlir;
 namespace {
 bool divisible(APInt lhs, APInt rhs) { return !lhs.urem(rhs); }
 
-// /// Clone a create_nd_tdesc feeding into vector.yield op for the enclosing
-// /// `vector.warp_execute_on_lane_0` and put it after the warp op.
-// /// The warp op will still contain the original op that will not be used by
-// the
-// /// yield op (and should be cleaned up later with dce). The yield op will
-// bypass
-// /// the create_nd_tdesc's arguments.
-// /// The rewrite will create a subview of the size used by a single work item
-// and
-// /// appropriate offset. The distributed create_nd_tdesc points into the
-// subview
-// /// without offset. The tensor descriptor types is distributed according to
-// /// sg_map attribute.
-// ///
-// /// Example:
-// ///
-// /// ```
-// ///   #sg_map_8 = #xegpu.sg_map<wi_layout = [1, 8], wi_data = [1, 1]>
-// ///   %r = vector.warp_execute_on_lane_0(%laneid) ->
-// ///                   (!xegpu.tensor_desc<4x8xf32>) {
-// ///     ...
-// ///     %td = xegpu.create_nd_tdesc %arg0[0, 0]
-// ///               : memref<4x8xf32> -> !xegpu.tensor_desc<4x8xf32>
-// ///     vector.yield %td
-// ///   }
-// /// ```
-// /// To
-// /// ```
-// ///   %r:2 = vector.warp_execute_on_lane_0(%laneid) -> () {
-// ///     ...
-// ///     %dead = xegpu.create_nd_tdesc %arg0[0, 0]
-// ///               : memref<4x8xf32> -> !xegpu.tensor_desc<4x8xf32>
-// ///     vector.yield %arg0, %dead
-// ///   }
-// ///   %view = memref.subview %r#0[0, %laneid] [4, 1] [1, 1]
-// ///                               : memref<4x8xf32> to memref<4x1xf32>
-// ///   %td = xegpu.create_nd_tdesc %view[0, 0]: memref<4x1xf32>
-// ///                                 -> !xegpu.tensor_desc<4x1xf32>
-// ///
-// /// ```
+/// Clone a create_nd_tdesc feeding into vector.yield op for the enclosing
+/// `vector.warp_execute_on_lane_0` and put it after the warp op.
+/// The warp op will still contain the original op that will not be used by
+/// the
+/// yield op (and should be cleaned up later with dce). The yield op will
+/// bypass
+/// the create_nd_tdesc's arguments.
+/// The rewrite will create a subview of the size used by a single work item
+/// and
+/// appropriate offset. The distributed create_nd_tdesc points into the
+/// subview
+/// without offset. The tensor descriptor type is distributed according to
+/// sg_map attribute.
+///
+/// Example:
+///
+/// ```
+///   #sg_map_8 = #xegpu.sg_map<wi_layout = [1, 8], wi_data = [1, 1]>
+///   %r = vector.warp_execute_on_lane_0(%laneid) ->
+///                   (!xegpu.tensor_desc<4x8xf32>) {
+///     ...
+///     %td = xegpu.create_nd_tdesc %arg0[0, 0]
+///               : memref<4x8xf32> -> !xegpu.tensor_desc<4x8xf32>
+///     vector.yield %td
+///   }
+/// ```
+/// To
+/// ```
+///   %r:2 = vector.warp_execute_on_lane_0(%laneid) -> () {
+///     ...
+///     %dead = xegpu.create_nd_tdesc %arg0[0, 0]
+///               : memref<4x8xf32> -> !xegpu.tensor_desc<4x8xf32>
+///     vector.yield %arg0, %dead
+///   }
+///   %view = memref.subview %r#0[0, %laneid] [4, 1] [1, 1]
+///                               : memref<4x8xf32> to memref<4x1xf32>
+///   %td = xegpu.create_nd_tdesc %view[0, 0]: memref<4x1xf32>
+///                                 -> !xegpu.tensor_desc<4x1xf32>
+///
+/// ```
 struct WarpOpTensorDescOp final
     : public OpRewritePattern<gpu::WarpExecuteOnLane0Op> {
   using OpRewritePattern<gpu::WarpExecuteOnLane0Op>::OpRewritePattern;
@@ -78,35 +78,35 @@ struct WarpOpTensorDescOp final
                                 PatternRewriter &rewriter) const override;
 };
 
-// /// Sink a store_nd feeding into vector.yield op for the enclosing
-// /// `vector.warp_execute_on_lane_0`. In case arguments for the store are
-// passed
-// /// through the warp op interface they would be propagated as returned
-// values.
-// /// Both the stored vector type and tensor descriptor types are distributed
-// /// according to sg_map attribute.
-// ///
-// /// Example:
-// ///
-// /// ```
-// ///   #sg_map_8 = #xegpu.sg_map<wi_layout = [1, 8], wi_data = [1, 1]>
-// ///   vector.warp_execute_on_lane_0(%laneid) -> () {
-// ///     ...
-// ///     xegpu.store_nd %arg0, %arg1: vector<4x8xf32>,
-// ///                                 !xegpu.tensor_desc<4x8xf32>
-// ///     vector.yield
-// ///   }
-// /// ```
-// /// To
-// /// ```
-// ///   %r = vector.warp_execute_on_lane_0(%laneid) -> () {
-// ///     ...
-// ///     vector.yield
-// ///   }
-// ///   xegpu.store_nd %arg0, %arg1: vector<4x1xf32>,
-// !xegpu.tensor_desc<4x1xf32>
-// ///
-// /// ```
+/// Sink a store_nd feeding into vector.yield op for the enclosing
+/// `vector.warp_execute_on_lane_0`. In case arguments for the store are
+/// passed
+/// through the warp op interface they would be propagated as returned
+/// values.
+/// Both the stored vector type and tensor descriptor types are distributed
+/// according to sg_map attribute.
+///
+/// Example:
+///
+/// ```
+///   #sg_map_8 = #xegpu.sg_map<wi_layout = [1, 8], wi_data = [1, 1]>
+///   vector.warp_execute_on_lane_0(%laneid) -> () {
+///     ...
+///     xegpu.store_nd %arg0, %arg1: vector<4x8xf32>,
+///                                 !xegpu.tensor_desc<4x8xf32>
+///     vector.yield
+///   }
+/// ```
+/// To
+/// ```
+///   %r = vector.warp_execute_on_lane_0(%laneid) -> () {
+///     ...
+///     vector.yield
+///   }
+///   xegpu.store_nd %arg0, %arg1: vector<4x1xf32>,
+///                               !xegpu.tensor_desc<4x1xf32>
+///
+/// ```
 struct WarpOpStoreNd final
     : public OpRewritePattern<gpu::WarpExecuteOnLane0Op> {
   using OpRewritePattern<gpu::WarpExecuteOnLane0Op>::OpRewritePattern;
@@ -114,38 +114,36 @@ struct WarpOpStoreNd final
                                 PatternRewriter &rewriter) const override;
 };
 
-// /// Clone a load_nd feeding into vector.yield op for the enclosing
-// /// `vector.warp_execute_on_lane_0` and put it after the warp op.
-// /// The warp op will still contain the original op that will not be used by
-// the
-// /// yield op (and should be cleaned up later with dce). The yield op will
-// bypass
-// /// the load's arguments.
-// /// Both the loaded vector type and tensor descriptor types are distributed
-// /// according to sg_map attribute.
-// ///
-// /// Example:
-// ///
-// /// ```
-// ///   #sg_map_8 = #xegpu.sg_map<wi_layout = [1, 8], wi_data = [1, 1]>
-// ///   %r = vector.warp_execute_on_lane_0(%laneid) ->
-// ///                   (!xegpu.tensor_desc<4x8xf32>) {
-// ///     ...
-// ///     %ld = xegpu.load_nd %arg0, %arg1: !xegpu.tensor_desc<4x8xf32>,
-// ///     vector<4x8xf32> vector.yield %ld
-// ///   }
-// /// ```
-// /// To
-// /// ```
-// ///   %r:2 = vector.warp_execute_on_lane_0(%laneid) -> () {
-// ///     ...
-// ///     %dead = xegpu.load_nd %arg0, %arg1:
-// ///         !xegpu.tensor_desc<4x8xf32>, vector<4x8xf32>
-// ///     vector.yield %arg0, %arg1
-// ///   }
-// ///   xegpu.store_nd %r#0, %r#1: vector<4x1xf32>, !xegpu.tensor_desc<4x1xf32>
-// ///
-// /// ```
+/// Clone a load_nd feeding into vector.yield op for the enclosing
+/// `vector.warp_execute_on_lane_0` and put it after the warp op.
+/// The warp op will still contain the original op that will not be used by
+/// the yield op (and should be cleaned up later with dce). The yield op will
+/// bypass the load's arguments. Both the loaded vector type and tensor
+/// descriptor types are distributed according to sg_map attribute.
+///
+/// Example:
+///
+/// ```
+///   #sg_map_8 = #xegpu.sg_map<wi_layout = [1, 8], wi_data = [1, 1]>
+///   %r = vector.warp_execute_on_lane_0(%laneid) ->
+///                   (!xegpu.tensor_desc<4x8xf32>) {
+///     ...
+///     %ld = xegpu.load_nd %arg0, %arg1: !xegpu.tensor_desc<4x8xf32>,
+///     vector<4x8xf32> vector.yield %ld
+///   }
+/// ```
+/// To
+/// ```
+///   %r:2 = vector.warp_execute_on_lane_0(%laneid) -> () {
+///     ...
+///     %dead = xegpu.load_nd %arg0, %arg1:
+///         !xegpu.tensor_desc<4x8xf32>, vector<4x8xf32>
+///     vector.yield %arg0, %arg1
+///   }
+///   %ld = xegpu.load_nd %r#0, %r#1: !xegpu.tensor_desc<4x1xf32>,
+///       vector<4x1xf32>
+///
+/// ```
 struct WarpOpLoadNd final : public OpRewritePattern<gpu::WarpExecuteOnLane0Op> {
   using OpRewritePattern<gpu::WarpExecuteOnLane0Op>::OpRewritePattern;
   LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,



More information about the Mlir-commits mailing list