[Mlir-commits] [mlir] [MLIR][XeGPU] Xegpu distribution patterns for load_nd, store_nd, and create_nd_tdesc. (PR #112945)

Petr Kurapov llvmlistbot at llvm.org
Fri Oct 18 10:51:33 PDT 2024


https://github.com/kurapov-peter created https://github.com/llvm/llvm-project/pull/112945

This PR introduces distribution patterns for a portion of the XeGPU dialect, similar to those of the vector dialect, and moves some of the common functionality into the vector utilities.

The XeGPU op rewrite patterns distribute the vector and XeGPU tensor descriptor types according to the `xegpu.sg_map` attribute when the ops are sunk through the yield op of a `vector.warp_execute_on_lane_0`. The validation of distributed types in `warp_execute_on_lane_0` was therefore relaxed to allow any `ShapedType` return value to have a distributed shape.

>From d292605dbd568e946efeac0260e7d16d416ac95c Mon Sep 17 00:00:00 2001
From: Petr Kurapov <petr.a.kurapov at intel.com>
Date: Thu, 17 Oct 2024 17:29:00 +0000
Subject: [PATCH 1/3] [MLIR][Vector] Allow any shaped type to be distributed
 for vector.warp_execute_on_lane_0's return values

---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 12 ++++++------
 mlir/test/Dialect/Vector/invalid.mlir    |  6 +++---
 2 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index a2abe1619454f2..51d3691fd107ae 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6558,14 +6558,14 @@ static LogicalResult verifyDistributedType(Type expanded, Type distributed,
   // If the types matches there is no distribution.
   if (expanded == distributed)
     return success();
-  auto expandedVecType = llvm::dyn_cast<VectorType>(expanded);
-  auto distributedVecType = llvm::dyn_cast<VectorType>(distributed);
+  auto expandedVecType = llvm::dyn_cast<ShapedType>(expanded);
+  auto distributedVecType = llvm::dyn_cast<ShapedType>(distributed);
   if (!expandedVecType || !distributedVecType)
-    return op->emitOpError("expected vector type for distributed operands.");
+    return op->emitOpError("expected shaped type for distributed operands.");
   if (expandedVecType.getRank() != distributedVecType.getRank() ||
       expandedVecType.getElementType() != distributedVecType.getElementType())
     return op->emitOpError(
-        "expected distributed vectors to have same rank and element type.");
+        "expected distributed types to have same rank and element type.");
 
   SmallVector<int64_t> scales(expandedVecType.getRank(), 1);
   for (int64_t i = 0, e = expandedVecType.getRank(); i < e; i++) {
@@ -6575,8 +6575,8 @@ static LogicalResult verifyDistributedType(Type expanded, Type distributed,
       continue;
     if (eDim % dDim != 0)
       return op->emitOpError()
-             << "expected expanded vector dimension #" << i << " (" << eDim
-             << ") to be a multipler of the distributed vector dimension ("
+             << "expected expanded type dimension #" << i << " (" << eDim
+             << ") to be a multipler of the distributed type dimension ("
              << dDim << ")";
     scales[i] = eDim / dDim;
   }
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 36d04bb77e3b96..69346574177929 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1665,7 +1665,7 @@ func.func @warp_2_distributed_dims(%laneid: index) {
 // -----
 
 func.func @warp_2_distributed_dims(%laneid: index) {
-  // expected-error at +1 {{expected expanded vector dimension #1 (8) to be a multipler of the distributed vector dimension (3)}}
+  // expected-error at +1 {{expected expanded type dimension #1 (8) to be a multipler of the distributed type dimension (3)}}
   %2 = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<1x3xi32>) {
     %0 = arith.constant dense<2>: vector<4x8xi32>
     vector.yield %0 : vector<4x8xi32>
@@ -1676,7 +1676,7 @@ func.func @warp_2_distributed_dims(%laneid: index) {
 // -----
 
 func.func @warp_mismatch_rank(%laneid: index) {
-  // expected-error at +1 {{'vector.warp_execute_on_lane_0' op expected distributed vectors to have same rank and element type.}}
+  // expected-error at +1 {{'vector.warp_execute_on_lane_0' op expected distributed types to have same rank and element type.}}
   %2 = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<4x4xi32>) {
     %0 = arith.constant dense<2>: vector<128xi32>
     vector.yield %0 : vector<128xi32>
@@ -1687,7 +1687,7 @@ func.func @warp_mismatch_rank(%laneid: index) {
 // -----
 
 func.func @warp_mismatch_rank(%laneid: index) {
-  // expected-error at +1 {{'vector.warp_execute_on_lane_0' op expected vector type for distributed operands.}}
+  // expected-error at +1 {{'vector.warp_execute_on_lane_0' op expected shaped type for distributed operands.}}
   %2 = vector.warp_execute_on_lane_0(%laneid)[32] -> (i32) {
     %0 = arith.constant dense<2>: vector<128xi32>
     vector.yield %0 : vector<128xi32>

>From 7627cf747d98c33ce24e585f4b695962d681139d Mon Sep 17 00:00:00 2001
From: Petr Kurapov <petr.a.kurapov at intel.com>
Date: Fri, 18 Oct 2024 15:50:31 +0000
Subject: [PATCH 2/3] [MLIR][Vector] Move helper functions to vector
 distribution utils

---
 .../mlir/Dialect/Vector/Utils/VectorUtils.h   | 20 ++++
 .../Vector/Transforms/VectorDistribute.cpp    | 80 +---------------
 mlir/lib/Dialect/Vector/Utils/CMakeLists.txt  |  1 +
 .../Vector/Utils/VectorDistributeUtils.cpp    | 96 +++++++++++++++++++
 4 files changed, 118 insertions(+), 79 deletions(-)
 create mode 100644 mlir/lib/Dialect/Vector/Utils/VectorDistributeUtils.cpp

diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
index 5f32aca88a2734..6bd924307376dc 100644
--- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
@@ -20,6 +20,8 @@
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/TypeSwitch.h"
 
+#include <utility>
+
 namespace mlir {
 
 // Forward declarations.
@@ -324,6 +326,24 @@ namespace matcher {
 bool operatesOnSuperVectorsOf(Operation &op, VectorType subVectorType);
 
 } // namespace matcher
+
+/// Return a value yielded by `warpOp` which satisfies the filter lambda
+/// condition and is not dead.
+OpOperand *getWarpResult(vector::WarpExecuteOnLane0Op warpOp,
+                         const std::function<bool(Operation *)> &fn);
+
+/// Helper to create a new WarpExecuteOnLane0Op with different signature.
+vector::WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns(
+    RewriterBase &rewriter, vector::WarpExecuteOnLane0Op warpOp,
+    ValueRange newYieldedValues, TypeRange newReturnTypes);
+
+/// Helper to create a new WarpExecuteOnLane0Op region with extra outputs.
+/// `indices` return the index of each new output.
+vector::WarpExecuteOnLane0Op moveRegionToNewWarpOpAndAppendReturns(
+    RewriterBase &rewriter, vector::WarpExecuteOnLane0Op warpOp,
+    ValueRange newYieldedValues, TypeRange newReturnTypes,
+    llvm::SmallVector<size_t> &indices);
+
 } // namespace mlir
 
 #endif // MLIR_DIALECT_VECTOR_UTILS_VECTORUTILS_H_
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 2289fd1ff1364e..c80c3179b5e025 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -12,6 +12,7 @@
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Transforms/RegionUtils.h"
@@ -160,68 +161,6 @@ struct DistributedLoadStoreHelper {
 
 } // namespace
 
-/// Helper to create a new WarpExecuteOnLane0Op with different signature.
-static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns(
-    RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
-    ValueRange newYieldedValues, TypeRange newReturnTypes) {
-  // Create a new op before the existing one, with the extra operands.
-  OpBuilder::InsertionGuard g(rewriter);
-  rewriter.setInsertionPoint(warpOp);
-  auto newWarpOp = rewriter.create<WarpExecuteOnLane0Op>(
-      warpOp.getLoc(), newReturnTypes, warpOp.getLaneid(), warpOp.getWarpSize(),
-      warpOp.getArgs(), warpOp.getBody()->getArgumentTypes());
-
-  Region &opBody = warpOp.getBodyRegion();
-  Region &newOpBody = newWarpOp.getBodyRegion();
-  Block &newOpFirstBlock = newOpBody.front();
-  rewriter.inlineRegionBefore(opBody, newOpBody, newOpBody.begin());
-  rewriter.eraseBlock(&newOpFirstBlock);
-  assert(newWarpOp.getWarpRegion().hasOneBlock() &&
-         "expected WarpOp with single block");
-
-  auto yield =
-      cast<vector::YieldOp>(newOpBody.getBlocks().begin()->getTerminator());
-
-  rewriter.modifyOpInPlace(
-      yield, [&]() { yield.getOperandsMutable().assign(newYieldedValues); });
-  return newWarpOp;
-}
-
-/// Helper to create a new WarpExecuteOnLane0Op region with extra outputs.
-/// `indices` return the index of each new output.
-static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndAppendReturns(
-    RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
-    ValueRange newYieldedValues, TypeRange newReturnTypes,
-    llvm::SmallVector<size_t> &indices) {
-  SmallVector<Type> types(warpOp.getResultTypes().begin(),
-                          warpOp.getResultTypes().end());
-  auto yield = cast<vector::YieldOp>(
-      warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
-  llvm::SmallSetVector<Value, 32> yieldValues(yield.getOperands().begin(),
-                                              yield.getOperands().end());
-  for (auto newRet : llvm::zip(newYieldedValues, newReturnTypes)) {
-    if (yieldValues.insert(std::get<0>(newRet))) {
-      types.push_back(std::get<1>(newRet));
-      indices.push_back(yieldValues.size() - 1);
-    } else {
-      // If the value already exit the region don't create a new output.
-      for (auto [idx, yieldOperand] :
-           llvm::enumerate(yieldValues.getArrayRef())) {
-        if (yieldOperand == std::get<0>(newRet)) {
-          indices.push_back(idx);
-          break;
-        }
-      }
-    }
-  }
-  yieldValues.insert(newYieldedValues.begin(), newYieldedValues.end());
-  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
-      rewriter, warpOp, yieldValues.getArrayRef(), types);
-  rewriter.replaceOp(warpOp,
-                     newWarpOp.getResults().take_front(warpOp.getNumResults()));
-  return newWarpOp;
-}
-
 /// Helper to know if an op can be hoisted out of the region.
 static bool canBeHoisted(Operation *op,
                          function_ref<bool(Value)> definedOutside) {
@@ -229,23 +168,6 @@ static bool canBeHoisted(Operation *op,
          isMemoryEffectFree(op) && op->getNumRegions() == 0;
 }
 
-/// Return a value yielded by `warpOp` which statifies the filter lamdba
-/// condition and is not dead.
-static OpOperand *getWarpResult(WarpExecuteOnLane0Op warpOp,
-                                const std::function<bool(Operation *)> &fn) {
-  auto yield = cast<vector::YieldOp>(
-      warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
-  for (OpOperand &yieldOperand : yield->getOpOperands()) {
-    Value yieldValues = yieldOperand.get();
-    Operation *definedOp = yieldValues.getDefiningOp();
-    if (definedOp && fn(definedOp)) {
-      if (!warpOp.getResult(yieldOperand.getOperandNumber()).use_empty())
-        return &yieldOperand;
-    }
-  }
-  return {};
-}
-
 // Clones `op` into a new operation that takes `operands` and returns
 // `resultTypes`.
 static Operation *cloneOpWithOperandsAndTypes(RewriterBase &rewriter,
diff --git a/mlir/lib/Dialect/Vector/Utils/CMakeLists.txt b/mlir/lib/Dialect/Vector/Utils/CMakeLists.txt
index fa3971695d4bf2..9db0c172fec5ce 100644
--- a/mlir/lib/Dialect/Vector/Utils/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Utils/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_dialect_library(MLIRVectorUtils
   VectorUtils.cpp
+  VectorDistributeUtils.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Vector/Utils
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorDistributeUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorDistributeUtils.cpp
new file mode 100644
index 00000000000000..f41581c6d47f2a
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/Utils/VectorDistributeUtils.cpp
@@ -0,0 +1,96 @@
+//===- VectorDistributeUtils.cpp - MLIR Utilities VectorOps distribution  -===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements utility methods for working with the Vector dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+
+using namespace mlir;
+
+/// Return a value yielded by `warpOp` which satisfies the filter lambda
+/// condition and is not dead.
+mlir::OpOperand *
+mlir::getWarpResult(vector::WarpExecuteOnLane0Op warpOp,
+                    const std::function<bool(Operation *)> &fn) {
+  auto yield = cast<vector::YieldOp>(
+      warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+  for (mlir::OpOperand &yieldOperand : yield->getOpOperands()) {
+    Value yieldValues = yieldOperand.get();
+    Operation *definedOp = yieldValues.getDefiningOp();
+    if (definedOp && fn(definedOp)) {
+      if (!warpOp.getResult(yieldOperand.getOperandNumber()).use_empty())
+        return &yieldOperand;
+    }
+  }
+  return {};
+}
+
+/// Helper to create a new WarpExecuteOnLane0Op with different signature.
+vector::WarpExecuteOnLane0Op mlir::moveRegionToNewWarpOpAndReplaceReturns(
+    RewriterBase &rewriter, vector::WarpExecuteOnLane0Op warpOp,
+    ValueRange newYieldedValues, TypeRange newReturnTypes) {
+  // Create a new op before the existing one, with the extra operands.
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(warpOp);
+  auto newWarpOp = rewriter.create<vector::WarpExecuteOnLane0Op>(
+      warpOp.getLoc(), newReturnTypes, warpOp.getLaneid(), warpOp.getWarpSize(),
+      warpOp.getArgs(), warpOp.getBody()->getArgumentTypes());
+
+  Region &opBody = warpOp.getBodyRegion();
+  Region &newOpBody = newWarpOp.getBodyRegion();
+  Block &newOpFirstBlock = newOpBody.front();
+  rewriter.inlineRegionBefore(opBody, newOpBody, newOpBody.begin());
+  rewriter.eraseBlock(&newOpFirstBlock);
+  assert(newWarpOp.getWarpRegion().hasOneBlock() &&
+         "expected WarpOp with single block");
+
+  auto yield =
+      cast<vector::YieldOp>(newOpBody.getBlocks().begin()->getTerminator());
+
+  rewriter.modifyOpInPlace(
+      yield, [&]() { yield.getOperandsMutable().assign(newYieldedValues); });
+  return newWarpOp;
+}
+
+/// Helper to create a new WarpExecuteOnLane0Op region with extra outputs.
+/// `indices` return the index of each new output.
+vector::WarpExecuteOnLane0Op mlir::moveRegionToNewWarpOpAndAppendReturns(
+    RewriterBase &rewriter, vector::WarpExecuteOnLane0Op warpOp,
+    ValueRange newYieldedValues, TypeRange newReturnTypes,
+    llvm::SmallVector<size_t> &indices) {
+  SmallVector<Type> types(warpOp.getResultTypes().begin(),
+                          warpOp.getResultTypes().end());
+  auto yield = cast<vector::YieldOp>(
+      warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+  llvm::SmallSetVector<Value, 32> yieldValues(yield.getOperands().begin(),
+                                              yield.getOperands().end());
+  for (auto newRet : llvm::zip(newYieldedValues, newReturnTypes)) {
+    if (yieldValues.insert(std::get<0>(newRet))) {
+      types.push_back(std::get<1>(newRet));
+      indices.push_back(yieldValues.size() - 1);
+    } else {
+      // If the value already exits the region, don't create a new output.
+      for (auto [idx, yieldOperand] :
+           llvm::enumerate(yieldValues.getArrayRef())) {
+        if (yieldOperand == std::get<0>(newRet)) {
+          indices.push_back(idx);
+          break;
+        }
+      }
+    }
+  }
+  yieldValues.insert(newYieldedValues.begin(), newYieldedValues.end());
+  vector::WarpExecuteOnLane0Op newWarpOp =
+      moveRegionToNewWarpOpAndReplaceReturns(rewriter, warpOp,
+                                             yieldValues.getArrayRef(), types);
+  rewriter.replaceOp(warpOp,
+                     newWarpOp.getResults().take_front(warpOp.getNumResults()));
+  return newWarpOp;
+}

>From de3ae899d02f6f151cdbd9badcd47c6f1df39861 Mon Sep 17 00:00:00 2001
From: Petr Kurapov <petr.a.kurapov at intel.com>
Date: Fri, 18 Oct 2024 16:18:32 +0000
Subject: [PATCH 3/3] [MLIR][XeGPU] Add distribution patterns for load_nd,
 store_nd, and create_nd_tdesc

---
 .../Dialect/XeGPU/Transforms/Transforms.h     |   1 +
 .../Dialect/XeGPU/Transforms/CMakeLists.txt   |   5 +
 .../XeGPU/Transforms/XeGPUDistribute.cpp      | 393 ++++++++++++++++++
 mlir/test/Dialect/XeGPU/xegpu-distribute.mlir |  81 ++++
 mlir/test/lib/Dialect/CMakeLists.txt          |   1 +
 mlir/test/lib/Dialect/XeGPU/CMakeLists.txt    |  17 +
 .../lib/Dialect/XeGPU/TestXeGPUTransforms.cpp |  58 +++
 mlir/tools/mlir-opt/CMakeLists.txt            |   1 +
 mlir/tools/mlir-opt/mlir-opt.cpp              |   2 +
 9 files changed, 559 insertions(+)
 create mode 100644 mlir/lib/Dialect/XeGPU/Transforms/XeGPUDistribute.cpp
 create mode 100644 mlir/test/Dialect/XeGPU/xegpu-distribute.mlir
 create mode 100644 mlir/test/lib/Dialect/XeGPU/CMakeLists.txt
 create mode 100644 mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp

diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h b/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
index 63ea26df069372..fe5198d1ac6dba 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
@@ -16,6 +16,7 @@ namespace xegpu {
 
 /// Appends patterns for folding aliasing ops into XeGPU ops into `patterns`.
 void populateXeGPUFoldAliasOpsPatterns(RewritePatternSet &patterns);
+void populateXeGPUDistributePatterns(RewritePatternSet &patterns);
 
 } // namespace xegpu
 } // namespace mlir
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
index 7fb64d3b97b87d..148ff46ba41b72 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_dialect_library(MLIRXeGPUTransforms
   XeGPUFoldAliasOps.cpp
+  XeGPUDistribute.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/XeGPU
@@ -12,6 +13,10 @@ add_mlir_dialect_library(MLIRXeGPUTransforms
   MLIRIR
   MLIRMemRefDialect
   MLIRXeGPUDialect
+  MLIRVectorDialect
+  MLIRVectorUtils
+  MLIRArithDialect
+  MLIRFuncDialect
   MLIRPass
   MLIRTransforms
 )
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUDistribute.cpp
new file mode 100644
index 00000000000000..78a010ff1c941b
--- /dev/null
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUDistribute.cpp
@@ -0,0 +1,393 @@
+//===- XeGPUDistribute.cpp - XeGPU distribute ops to work items -*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "xegpu-distribute"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+
+using namespace mlir;
+
+namespace {
+bool divisible(APInt lhs, APInt rhs) { return !lhs.urem(rhs); }
+
+/// Clone a create_nd_tdesc feeding into vector.yield op for the enclosing
+/// `vector.warp_execute_on_lane_0` and put it after the warp op.
+/// The warp op will still contain the original op that will not be used by the
+/// yield op (and should be cleaned up later with dce). The yield op will bypass
+/// the create_nd_tdesc's arguments.
+/// The rewrite will create a subview of the size used by a single work item and
+/// appropriate offset. The distributed create_nd_tdesc points into the subview
+/// without offset. The tensor descriptor type is distributed according to the
+/// sg_map attribute.
+///
+/// Example:
+///
+/// ```
+///   #sg_map_8 = #xegpu.sg_map<wi_layout = [1, 8], wi_data = [1, 1]>
+///   %r = vector.warp_execute_on_lane_0(%laneid) ->
+///                   (!xegpu.tensor_desc<4x8xf32>) {
+///     ...
+///     %td = xegpu.create_nd_tdesc %arg0[0, 0]
+///               : memref<4x8xf32> -> !xegpu.tensor_desc<4x8xf32>
+///     vector.yield %td
+///   }
+/// ```
+/// To
+/// ```
+///   %r:2 = vector.warp_execute_on_lane_0(%laneid) -> () {
+///     ...
+///     %dead = xegpu.create_nd_tdesc %arg0[0, 0]
+///               : memref<4x8xf32> -> !xegpu.tensor_desc<4x8xf32>
+///     vector.yield %arg0, %dead
+///   }
+///   %view = memref.subview %r#0[0, %laneid] [4, 1] [1, 1]
+///                               : memref<4x8xf32> to memref<4x1xf32>
+///   %td = xegpu.create_nd_tdesc %view[0, 0]: memref<4x1xf32>
+///                                 -> !xegpu.tensor_desc<4x1xf32>
+///
+/// ```
+struct WarpOpTensorDescOp final
+    : public OpRewritePattern<vector::WarpExecuteOnLane0Op> {
+  using OpRewritePattern<vector::WarpExecuteOnLane0Op>::OpRewritePattern;
+  LogicalResult matchAndRewrite(vector::WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override;
+};
+
+/// Sink a store_nd feeding into vector.yield op for the enclosing
+/// `vector.warp_execute_on_lane_0`. In case arguments for the store are passed
+/// through the warp op interface they would be propagated as returned values.
+/// Both the stored vector type and tensor descriptor types are distributed
+/// according to sg_map attribute.
+///
+/// Example:
+///
+/// ```
+///   #sg_map_8 = #xegpu.sg_map<wi_layout = [1, 8], wi_data = [1, 1]>
+///   vector.warp_execute_on_lane_0(%laneid) -> () {
+///     ...
+///     xegpu.store_nd %arg0, %arg1: vector<4x8xf32>,
+///                                 !xegpu.tensor_desc<4x8xf32>
+///     vector.yield
+///   }
+/// ```
+/// To
+/// ```
+///   %r = vector.warp_execute_on_lane_0(%laneid) -> () {
+///     ...
+///     vector.yield
+///   }
+///   xegpu.store_nd %arg0, %arg1: vector<4x1xf32>, !xegpu.tensor_desc<4x1xf32>
+///
+/// ```
+struct WarpOpStoreNd final
+    : public OpRewritePattern<vector::WarpExecuteOnLane0Op> {
+  using OpRewritePattern<vector::WarpExecuteOnLane0Op>::OpRewritePattern;
+  LogicalResult matchAndRewrite(vector::WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override;
+};
+
+/// Clone a load_nd feeding into vector.yield op for the enclosing
+/// `vector.warp_execute_on_lane_0` and put it after the warp op.
+/// The warp op will still contain the original op that will not be used by the
+/// yield op (and should be cleaned up later with dce). The yield op will bypass
+/// the load's arguments.
+/// Both the loaded vector type and tensor descriptor types are distributed
+/// according to sg_map attribute.
+///
+/// Example:
+///
+/// ```
+///   #sg_map_8 = #xegpu.sg_map<wi_layout = [1, 8], wi_data = [1, 1]>
+///   %r = vector.warp_execute_on_lane_0(%laneid) ->
+///                   (!xegpu.tensor_desc<4x8xf32>) {
+///     ...
+///     %ld = xegpu.load_nd %arg0, %arg1: !xegpu.tensor_desc<4x8xf32>,
+///     vector<4x8xf32> vector.yield %ld
+///   }
+/// ```
+/// To
+/// ```
+///   %r:2 = vector.warp_execute_on_lane_0(%laneid) -> () {
+///     ...
+///     %dead = xegpu.load_nd %arg0, %arg1:
+///         !xegpu.tensor_desc<4x8xf32>, vector<4x8xf32>
+///     vector.yield %arg0, %arg1
+///   }
+///   xegpu.load_nd %r#0, %r#1: !xegpu.tensor_desc<4x1xf32>, vector<4x1xf32>
+///
+/// ```
+struct WarpOpLoadNd final
+    : public OpRewritePattern<vector::WarpExecuteOnLane0Op> {
+  using OpRewritePattern<vector::WarpExecuteOnLane0Op>::OpRewritePattern;
+  LogicalResult matchAndRewrite(vector::WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override;
+};
+
+FailureOr<VectorType> getDistributedVectorType(VectorType originalT,
+                                               xegpu::SGMapAttr sgMap) {
+  llvm::SmallVector<int64_t, 2> distributedShape;
+  auto layout = sgMap.getWiLayout();
+  auto shape = originalT.getShape();
+  for (const auto [l, o] : llvm::zip_equal(layout, shape)) {
+    if (!divisible(APInt(64, o), APInt(64, l)))
+      return failure();
+    distributedShape.push_back(o / l);
+  }
+  auto newVectorType =
+      VectorType::get(distributedShape, originalT.getElementType(),
+                      originalT.getScalableDims());
+  return newVectorType;
+}
+
+FailureOr<xegpu::TensorDescType>
+getDistributedTensorDescType(xegpu::TensorDescType originalT,
+                             xegpu::SGMapAttr sgMap,
+                             xegpu::MemorySpace memSpace) {
+  llvm::SmallVector<int64_t, 2> distributedShape;
+  auto layout = sgMap.getWiLayout();
+  auto shape = originalT.getShape();
+  for (const auto [l, o] : llvm::zip_equal(layout, shape)) {
+    if (!divisible(APInt(64, o), APInt(64, l)))
+      return failure();
+    distributedShape.push_back(o / l);
+  }
+  xegpu::TensorDescType distributedDescType;
+  if (originalT.isScattered()) {
+
+    distributedDescType = xegpu::TensorDescType::get(
+        distributedShape, originalT.getElementType(), originalT.getChunkSize(),
+        originalT.getMemorySpace(), originalT.getSGMapAttr());
+  } else {
+    distributedDescType = xegpu::TensorDescType::get(
+        distributedShape, originalT.getElementType(),
+        originalT.getBoundaryCheck(), originalT.getArrayLength(),
+        originalT.getMemorySpace(), originalT.getSGMapAttr());
+  }
+  return distributedDescType;
+}
+} // namespace
+
+LogicalResult
+WarpOpStoreNd::matchAndRewrite(vector::WarpExecuteOnLane0Op warpOp,
+                               PatternRewriter &rewriter) const {
+  // Sink a store_nd that directly precedes the warp op's terminator.
+  auto yield = cast<vector::YieldOp>(
+      warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+  Operation *lastNode = yield->getPrevNode();
+  auto storeOp = dyn_cast_or_null<xegpu::StoreNdOp>(lastNode);
+  if (!storeOp)
+    return failure();
+
+  xegpu::SGMapAttr sgMap = storeOp.getTensorDescType().getSGMapAttr();
+  if (!sgMap)
+    return rewriter.notifyMatchFailure(
+        storeOp, "the source tensor descriptor lacks sg_map attribute");
+
+  if (storeOp.getTensorDescType().getShape().size() != 2)
+    return rewriter.notifyMatchFailure(storeOp, "unsupported shape");
+  LLVM_DEBUG(DBGS() << "Matched store_nd: " << storeOp << "\n");
+  // Compute the per-work-item vector and tensor descriptor types.
+  auto distributedTypeOrFailure =
+      getDistributedVectorType(storeOp.getValueType(), sgMap);
+  if (failed(distributedTypeOrFailure))
+    return rewriter.notifyMatchFailure(storeOp,
+                                       "Failed to distribute the type");
+  VectorType newVectorType = distributedTypeOrFailure.value();
+
+  auto distributedDescTypeOrFailure = getDistributedTensorDescType(
+      storeOp.getTensorDescType(), sgMap,
+      storeOp.getTensorDescType().getMemorySpace());
+  if (failed(distributedDescTypeOrFailure))
+    return rewriter.notifyMatchFailure(storeOp,
+                                       "Failed to distribute the desc type");
+  xegpu::TensorDescType newTDescType = distributedDescTypeOrFailure.value();
+  // Yield the descriptor and the stored value from a new warp op.
+  SmallVector<size_t> newRetIndices;
+  vector::WarpExecuteOnLane0Op newWarpOp =
+      moveRegionToNewWarpOpAndAppendReturns(
+          rewriter, warpOp,
+          ValueRange{storeOp.getTensorDesc(), storeOp.getValue()},
+          TypeRange{newTDescType, newVectorType}, newRetIndices);
+  // Clone the store after the new warp op and feed it the warp results.
+  rewriter.setInsertionPointAfter(newWarpOp);
+  auto newStoreOp =
+      cast<xegpu::StoreNdOp>(rewriter.clone(*storeOp.getOperation()));
+  rewriter.eraseOp(storeOp);
+  newStoreOp.getTensorDescMutable().assign(
+      newWarpOp.getResult(newRetIndices[0]));
+  newStoreOp.getValueMutable().assign(newWarpOp.getResult(newRetIndices[1]));
+
+  return success();
+}
+
+LogicalResult WarpOpLoadNd::matchAndRewrite(vector::WarpExecuteOnLane0Op warpOp,
+                                            PatternRewriter &rewriter) const {
+  OpOperand *operand = getWarpResult(warpOp, [](Operation *op) {
+    return isa<xegpu::LoadNdOp>(op) && op->hasOneUse();
+  });
+
+  if (!operand)
+    return rewriter.notifyMatchFailure(warpOp,
+                                       "warp result is not a xegpu::LoadNd op");
+
+  auto loadOp = operand->get().getDefiningOp<xegpu::LoadNdOp>();
+
+  if (loadOp.getPacked())
+    return rewriter.notifyMatchFailure(
+        loadOp, "Packed load distribution not supported");
+
+  xegpu::TensorDescType origType = loadOp.getTensorDescType();
+  xegpu::SGMapAttr sgMap = origType.getSGMapAttr();
+  if (!sgMap)
+    return rewriter.notifyMatchFailure(
+        loadOp, "the source tensor descriptor lacks sg_map attribute");
+
+  auto origShape = origType.getShape();
+  if (origShape.size() != 2)
+    return rewriter.notifyMatchFailure(loadOp, "unsupported shape");
+
+  auto distributedTypeOrFailure =
+      getDistributedVectorType(loadOp.getType(), sgMap);
+  if (failed(distributedTypeOrFailure))
+    return rewriter.notifyMatchFailure(loadOp, "Failed to distribute the type");
+  VectorType newVectorType = distributedTypeOrFailure.value();
+
+  auto distributedDescTypeOrFailure =
+      getDistributedTensorDescType(loadOp.getTensorDescType(), sgMap,
+                                   loadOp.getTensorDescType().getMemorySpace());
+  if (failed(distributedDescTypeOrFailure))
+    return rewriter.notifyMatchFailure(loadOp,
+                                       "Failed to distribute the desc type");
+  xegpu::TensorDescType newTDescType = distributedDescTypeOrFailure.value();
+
+  unsigned operandIdx = operand->getOperandNumber();
+
+  SmallVector<size_t> newRetIndices;
+  vector::WarpExecuteOnLane0Op newWarpOp =
+      moveRegionToNewWarpOpAndAppendReturns(
+          rewriter, warpOp, loadOp.getTensorDesc(), TypeRange{newTDescType},
+          newRetIndices);
+
+  rewriter.setInsertionPointAfter(newWarpOp);
+
+  auto newLoadOp = rewriter.create<xegpu::LoadNdOp>(
+      loadOp.getLoc(), newVectorType, loadOp.getTensorDesc(),
+      loadOp.getPackedAttr(), loadOp.getTransposeAttr(), loadOp.getL1HintAttr(),
+      loadOp.getL2HintAttr(), loadOp.getL3HintAttr());
+
+  newLoadOp.getTensorDescMutable().assign(
+      newWarpOp.getResult(newRetIndices[0]));
+  Value distributedVal = newWarpOp.getResult(operandIdx);
+  rewriter.replaceAllUsesWith(distributedVal, newLoadOp);
+
+  return success();
+}
+
+LogicalResult
+WarpOpTensorDescOp::matchAndRewrite(vector::WarpExecuteOnLane0Op warpOp,
+                                    PatternRewriter &rewriter) const {
+  // Sink a single-use xegpu.create_nd_tdesc yielded by `warpOp` out of the
+  // region. Every match check runs before any IR is created, so that a failed
+  // match leaves the IR untouched.
+  OpOperand *operand = getWarpResult(warpOp, [](Operation *op) {
+    return isa<xegpu::CreateNdDescOp>(op) && op->hasOneUse();
+  });
+  if (!operand)
+    return rewriter.notifyMatchFailure(
+        warpOp, "warp result is not a xegpu::CreateNdDesc op");
+  auto descOp = operand->get().getDefiningOp<xegpu::CreateNdDescOp>();
+  assert(descOp && "desc op must be not null");
+  unsigned operandIdx = operand->getOperandNumber();
+
+  // A memref.subview of the source is taken below: reject other sources.
+  if (!isa<MemRefType>(descOp.getSource().getType()))
+    return rewriter.notifyMatchFailure(descOp, "source must be a memref");
+  auto descOffsets = descOp.getMixedOffsets();
+  if (descOffsets.size() != 2)
+    return rewriter.notifyMatchFailure(descOp,
+                                       "offsets size is expected to be 2");
+  // Dynamic offsets are reused after the warp op: they must live outside it.
+  if (llvm::any_of(descOffsets, [&](OpFoldResult ofr) {
+        auto val = llvm::dyn_cast_if_present<Value>(ofr);
+        return val && warpOp->isAncestor(val.getParentRegion()->getParentOp());
+      }))
+    return rewriter.notifyMatchFailure(
+        descOp, "offsets defined in the warp region are not supported");
+
+  xegpu::SGMapAttr sgMap = descOp.getType().getSGMapAttr();
+  if (!sgMap)
+    return rewriter.notifyMatchFailure(
+        descOp, "the tensor descriptor lacks sg_map attribute");
+  auto distributedDescTypeOrFailure = getDistributedTensorDescType(
+      descOp.getType(), sgMap, descOp.getType().getMemorySpace());
+  if (failed(distributedDescTypeOrFailure))
+    return rewriter.notifyMatchFailure(descOp,
+                                       "Failed to distribute the desc type");
+  xegpu::TensorDescType newTDescType = distributedDescTypeOrFailure.value();
+
+  // TODO: is memref uniform in the region
+  // Calculate the offset within tensor descriptor for the current lane_id. The
+  // access to proper element for a work item is done through a lane-specific
+  // subview (tdesc offsets are used as base, lane shift is added on top).
+  rewriter.setInsertionPoint(warpOp);
+  auto layout = sgMap.getWiLayout();
+  auto laneid = warpOp.getLaneid();
+  Location loc = laneid.getLoc();
+  auto xDim = rewriter.create<arith::ConstantIndexOp>(loc, layout[0]);
+  auto shiftx = rewriter.create<arith::RemUIOp>(loc, laneid, xDim);
+  auto shifty = rewriter.create<arith::DivUIOp>(loc, laneid, xDim);
+  auto basex = getValueOrCreateConstantIndexOp(rewriter, loc, descOffsets[0]);
+  auto basey = getValueOrCreateConstantIndexOp(rewriter, loc, descOffsets[1]);
+  auto offsetx = rewriter.create<arith::AddIOp>(loc, shiftx, basex);
+  auto offsety = rewriter.create<arith::AddIOp>(loc, shifty, basey);
+
+  // use the base memref strides
+  SmallVector<OpFoldResult> overwriteStrides =
+      getAsIndexOpFoldResult(rewriter.getContext(), SmallVector<int64_t>{1, 1});
+  SmallVector<OpFoldResult> overwriteSizes =
+      getAsIndexOpFoldResult(rewriter.getContext(), newTDescType.getShape());
+  // Yield the source memref so that it is available after the warp op.
+  SmallVector<size_t> newRetIndices;
+  vector::WarpExecuteOnLane0Op newWarpOp =
+      moveRegionToNewWarpOpAndAppendReturns(
+          rewriter, warpOp, descOp.getSource(), descOp.getSourceType(),
+          newRetIndices);
+
+  rewriter.setInsertionPointAfter(newWarpOp);
+  auto subview = rewriter.create<memref::SubViewOp>(
+      newWarpOp.getLoc(), newWarpOp.getResult(newRetIndices[0]),
+      getAsOpFoldResult({offsetx, offsety}), overwriteSizes, overwriteStrides);
+  auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+  auto newDescOp = rewriter.create<xegpu::CreateNdDescOp>(
+      newWarpOp.getLoc(), newTDescType, subview,
+      getAsOpFoldResult({zero, zero}));
+
+  rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIdx), newDescOp);
+  return success();
+}
+
+void xegpu::populateXeGPUDistributePatterns(RewritePatternSet &patterns) {
+  // Register the create_nd_tdesc, store_nd and load_nd distribution patterns.
+  MLIRContext *ctx = patterns.getContext();
+  patterns.add<WarpOpTensorDescOp, WarpOpStoreNd, WarpOpLoadNd>(ctx);
+}
diff --git a/mlir/test/Dialect/XeGPU/xegpu-distribute.mlir b/mlir/test/Dialect/XeGPU/xegpu-distribute.mlir
new file mode 100644
index 00000000000000..ec01fc82688156
--- /dev/null
+++ b/mlir/test/Dialect/XeGPU/xegpu-distribute.mlir
@@ -0,0 +1,81 @@
+// RUN: mlir-opt -test-xegpu-distribute -split-input-file %s | FileCheck %s
+
+#sg_map_16 = #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>
+#blk_tdesc = #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>
+
+// CHECK-LABEL: test_store_nd_distribution
+// CHECK: %[[laneid:.*]] = gpu.lane_id
+// CHECK: %[[res:.*]]:2 = vector.warp_execute_on_lane_0(%[[laneid]])[16] args(%{{.*}}, %{{.*}} :  vector<24x32xf16>, !xegpu.tensor_desc<24x32xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>)
+// CHECK-SAME: -> (!xegpu.tensor_desc<24x2xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>, vector<24x2xf16>)
+// CHECK: ^bb0(%[[src:.*]]: vector<24x32xf16>, %[[dst:.*]]: !xegpu.tensor_desc<24x32xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>)
+// CHECK: vector.yield %[[dst]], %[[src]] : !xegpu.tensor_desc<24x32xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>, vector<24x32xf16>
+// CHECK: xegpu.store_nd %[[res]]#1, %[[res]]#0 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> :
+// CHECK-SAME: vector<24x2xf16>, !xegpu.tensor_desc<24x2xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+
+func.func @test_store_nd_distribution(%src: vector<24x32xf16>, %dst: !xegpu.tensor_desc<24x32xf16, #blk_tdesc, #sg_map_16>) -> () {
+  %laneid = gpu.lane_id // lane id handed to the warp op below
+  vector.warp_execute_on_lane_0(%laneid)[16]
+        args(%src, %dst: vector<24x32xf16>, !xegpu.tensor_desc<24x32xf16, #blk_tdesc, #sg_map_16>) {
+    ^bb0(%arg0: vector<24x32xf16>, %arg1: !xegpu.tensor_desc<24x32xf16, #blk_tdesc, #sg_map_16>):
+    xegpu.store_nd %arg0, %arg1 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>: vector<24x32xf16>, !xegpu.tensor_desc<24x32xf16, #blk_tdesc, #sg_map_16>
+  } // the region yields nothing; the store is its only op
+  return
+}
+
+// -----
+
+#sg_map_16 = #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>
+#blk_tdesc = #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>
+
+// CHECK-LABEL: test_load_nd_distribution
+// CHECK: %[[laneid:.*]] = gpu.lane_id
+// CHECK: %[[res:.*]]:2 = vector.warp_execute_on_lane_0(%[[laneid]])[16] args(%{{.*}} :  !xegpu.tensor_desc<24x32xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>)
+// CHECK-SAME: -> (vector<24x2xf16>, !xegpu.tensor_desc<24x2xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>)
+// CHECK: ^bb0(%[[dst:.*]]: !xegpu.tensor_desc<24x32xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>)
+// CHECK: %[[dead:.*]] = xegpu.load_nd
+// CHECK: vector.yield %[[dead]], %[[dst]] : vector<24x32xf16>, !xegpu.tensor_desc<24x32xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+// CHECK: %[[load:.*]] = xegpu.load_nd %[[res]]#1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> :
+// CHECK-SAME: !xegpu.tensor_desc<24x2xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<24x2xf16>
+// CHECK: return %[[load]]
+
+func.func @test_load_nd_distribution(%dst: !xegpu.tensor_desc<24x32xf16, #blk_tdesc, #sg_map_16>) -> (vector<24x2xf16>) {
+  %laneid = gpu.lane_id // lane id handed to the warp op below
+  %r = vector.warp_execute_on_lane_0(%laneid)[16]
+        args(%dst: !xegpu.tensor_desc<24x32xf16, #blk_tdesc, #sg_map_16>) -> (vector<24x2xf16>) {
+    ^bb0(%arg0: !xegpu.tensor_desc<24x32xf16, #blk_tdesc, #sg_map_16>):
+    %0 = xegpu.load_nd %arg0 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+       : !xegpu.tensor_desc<24x32xf16, #blk_tdesc, #sg_map_16> -> vector<24x32xf16>
+    vector.yield %0 : vector<24x32xf16> // yields 24x32 while the warp op result is typed 24x2
+  }
+  return %r : vector<24x2xf16>
+}
+
+// -----
+
+#sg_map_16 = #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>
+#blk_tdesc = #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>
+
+// CHECK-LABEL: test_create_nd_desc_distribution
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[laneid:.*]] = gpu.lane_id
+// CHECK: %[[res:.*]]:2 = vector.warp_execute_on_lane_0(%[[laneid]])[16] args(%{{.*}} : memref<24x32xf16>)
+// CHECK-SAME: -> (!xegpu.tensor_desc<24x2xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>, memref<24x32xf16>)
+// CHECK: ^bb0(%[[dst:.*]]: memref<24x32xf16>)
+// CHECK: %[[dead:.*]] = xegpu.create_nd_tdesc
+// CHECK: vector.yield %[[dead]], %[[dst]] :
+// CHECK-SAME: !xegpu.tensor_desc<24x32xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>, memref<24x32xf16>
+// CHECK: %[[view:.*]] = memref.subview %[[res]]#1[%[[C0]], %[[laneid]]] [24, 2] [1, 1] : memref<24x32xf16> to memref<24x2xf16, strided<[32, 1], offset: ?>>
+// CHECK: %[[desc:.*]] = xegpu.create_nd_tdesc %[[view]][0, 0] : memref<24x2xf16, strided<[32, 1], offset: ?>>
+// CHECK-SAME: -> !xegpu.tensor_desc<24x2xf16, #xegpu.block_tdesc_attr<memory_space =  global, array_length = 1 : i64, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+// CHECK: return %[[desc]]
+
+func.func @test_create_nd_desc_distribution(%dst: memref<24x32xf16>) -> (!xegpu.tensor_desc<24x2xf16, #blk_tdesc, #sg_map_16>) {
+  %laneid = gpu.lane_id // lane id handed to the warp op below
+  %r = vector.warp_execute_on_lane_0(%laneid)[16]
+        args(%dst: memref<24x32xf16>) -> (!xegpu.tensor_desc<24x2xf16, #blk_tdesc, #sg_map_16>) {
+    ^bb0(%arg0: memref<24x32xf16>):
+    %0 = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<24x32xf16, #blk_tdesc, #sg_map_16>
+    vector.yield %0 : !xegpu.tensor_desc<24x32xf16, #blk_tdesc, #sg_map_16> // yields a 24x32 descriptor while the warp op result is typed 24x2
+  }
+  return %r : !xegpu.tensor_desc<24x2xf16, #blk_tdesc, #sg_map_16>
+}
diff --git a/mlir/test/lib/Dialect/CMakeLists.txt b/mlir/test/lib/Dialect/CMakeLists.txt
index 29fb4441a24fd2..a8fd70e6397a52 100644
--- a/mlir/test/lib/Dialect/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/CMakeLists.txt
@@ -22,3 +22,4 @@ add_subdirectory(TestDyn)
 add_subdirectory(Tosa)
 add_subdirectory(Transform)
 add_subdirectory(Vector)
+add_subdirectory(XeGPU)
diff --git a/mlir/test/lib/Dialect/XeGPU/CMakeLists.txt b/mlir/test/lib/Dialect/XeGPU/CMakeLists.txt
new file mode 100644
index 00000000000000..c8fe0db5f6213a
--- /dev/null
+++ b/mlir/test/lib/Dialect/XeGPU/CMakeLists.txt
@@ -0,0 +1,17 @@
+# Test-only XeGPU pass library built from TestXeGPUTransforms.cpp; it is excluded from libMLIR.so.
+add_mlir_library(MLIRXeGPUTestPasses
+  TestXeGPUTransforms.cpp
+
+  EXCLUDE_FROM_LIBMLIR
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/XeGPU
+
+  LINK_LIBS PUBLIC # NOTE(review): TestXeGPUTransforms.cpp also includes Arith, Func, MemRef, Vector and the greedy rewrite driver headers; confirm those libraries are linked transitively
+  MLIRIR
+  MLIRPass
+  MLIRXeGPUTransforms
+  MLIRXeGPUDialect
+  MLIRSupport
+  )
+
diff --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
new file mode 100644
index 00000000000000..eda68b83748139
--- /dev/null
+++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
@@ -0,0 +1,58 @@
+//===- TestXeGPUTransforms.cpp - Test XeGPU transforms and lowerings ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace mlir::xegpu;
+using namespace mlir::vector;
+
+namespace {
+/// Test-only function pass, exposed as `test-xegpu-distribute`, that greedily
+/// applies the XeGPU distribution patterns to each `func.func` it runs on.
+struct TestXeGPUDistribution
+    : public PassWrapper<TestXeGPUDistribution, OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestXeGPUDistribution)
+
+  TestXeGPUDistribution() = default;
+  TestXeGPUDistribution(const TestXeGPUDistribution &other)
+      : PassWrapper(other) {}
+
+  StringRef getArgument() const final { return "test-xegpu-distribute"; }
+  StringRef getDescription() const final {
+    return "Test patterns for operations work item distribution";
+  }
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<xegpu::XeGPUDialect, vector::VectorDialect,
+                    arith::ArithDialect, memref::MemRefDialect>();
+  }
+
+  void runOnOperation() override {
+    // The result of the greedy driver is deliberately ignored.
+    RewritePatternSet rewrites(&getContext());
+    populateXeGPUDistributePatterns(rewrites);
+    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(rewrites));
+  }
+};
+} // namespace
+
+namespace mlir::test {
+/// Registers the `test-xegpu-distribute` pass so that tools calling this hook
+/// can run it.
+void registerTestXeGPUTransforms() {
+  PassRegistration<TestXeGPUDistribution>();
+}
+} // namespace mlir::test
diff --git a/mlir/tools/mlir-opt/CMakeLists.txt b/mlir/tools/mlir-opt/CMakeLists.txt
index 8b79de58fa1028..e4ffbbee7a1d94 100644
--- a/mlir/tools/mlir-opt/CMakeLists.txt
+++ b/mlir/tools/mlir-opt/CMakeLists.txt
@@ -47,6 +47,7 @@ if(MLIR_INCLUDE_TESTS)
     MLIRTilingInterfaceTestPasses
     MLIRVectorTestPasses
     MLIRTestVectorToSPIRV
+    MLIRXeGPUTestPasses
     MLIRLLVMTestPasses
     )
   set(test_libs ${test_libs}
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 36b142484bb04a..b53e9513b0598d 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -151,6 +151,7 @@ void registerTestTransformDialectEraseSchedulePass();
 void registerTestPassStateExtensionCommunication();
 void registerTestVectorLowerings();
 void registerTestVectorReductionToSPIRVDotProd();
+void registerTestXeGPUTransforms();
 void registerTestWrittenToPass();
 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
 void registerTestDialectConversionPasses();
@@ -286,6 +287,7 @@ void registerTestPasses() {
   mlir::test::registerTestTransformDialectEraseSchedulePass();
   mlir::test::registerTestPassStateExtensionCommunication();
   mlir::test::registerTestVectorLowerings();
+  mlir::test::registerTestXeGPUTransforms();
   mlir::test::registerTestVectorReductionToSPIRVDotProd();
   mlir::test::registerTestWrittenToPass();
 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH



More information about the Mlir-commits mailing list