[Mlir-commits] [mlir] [MLIR][XeGPU] XeGPU distribution patterns for load_nd, store_nd, and create_nd_tdesc. (PR #112945)

Petr Kurapov llvmlistbot at llvm.org
Tue Oct 29 02:23:22 PDT 2024


https://github.com/kurapov-peter updated https://github.com/llvm/llvm-project/pull/112945

>From d292605dbd568e946efeac0260e7d16d416ac95c Mon Sep 17 00:00:00 2001
From: Petr Kurapov <petr.a.kurapov at intel.com>
Date: Thu, 17 Oct 2024 17:29:00 +0000
Subject: [PATCH 1/4] [MLIR][Vector] Allow any shaped type to be distributed
 for vector.warp_execute_on_lane_0's return values

---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 12 ++++++------
 mlir/test/Dialect/Vector/invalid.mlir    |  6 +++---
 2 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index a2abe1619454f2..51d3691fd107ae 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6558,14 +6558,14 @@ static LogicalResult verifyDistributedType(Type expanded, Type distributed,
   // If the types match, there is no distribution.
   if (expanded == distributed)
     return success();
-  auto expandedVecType = llvm::dyn_cast<VectorType>(expanded);
-  auto distributedVecType = llvm::dyn_cast<VectorType>(distributed);
+  auto expandedVecType = llvm::dyn_cast<ShapedType>(expanded);
+  auto distributedVecType = llvm::dyn_cast<ShapedType>(distributed);
   if (!expandedVecType || !distributedVecType)
-    return op->emitOpError("expected vector type for distributed operands.");
+    return op->emitOpError("expected shaped type for distributed operands.");
   if (expandedVecType.getRank() != distributedVecType.getRank() ||
       expandedVecType.getElementType() != distributedVecType.getElementType())
     return op->emitOpError(
-        "expected distributed vectors to have same rank and element type.");
+        "expected distributed types to have same rank and element type.");
 
   SmallVector<int64_t> scales(expandedVecType.getRank(), 1);
   for (int64_t i = 0, e = expandedVecType.getRank(); i < e; i++) {
@@ -6575,8 +6575,8 @@ static LogicalResult verifyDistributedType(Type expanded, Type distributed,
       continue;
     if (eDim % dDim != 0)
       return op->emitOpError()
-             << "expected expanded vector dimension #" << i << " (" << eDim
-             << ") to be a multipler of the distributed vector dimension ("
+             << "expected expanded type dimension #" << i << " (" << eDim
+             << ") to be a multiplier of the distributed type dimension ("
              << dDim << ")";
     scales[i] = eDim / dDim;
   }
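
For context, the relaxed verifier keeps the per-dimension divisibility rule: each
expanded dimension must be a multiple of the matching distributed dimension. A
minimal sketch of a distribution it accepts, written in the style of the tests
below (illustration only, not part of the patch):

  %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<1x1xi32>) {
    %0 = arith.constant dense<2> : vector<4x8xi32>
    vector.yield %0 : vector<4x8xi32>
  }

Here the per-dimension scales are 4 and 8 (4/1 and 8/1), so the 4x8 values are
spread across the 32 lanes.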
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 36d04bb77e3b96..69346574177929 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1665,7 +1665,7 @@ func.func @warp_2_distributed_dims(%laneid: index) {
 // -----
 
 func.func @warp_2_distributed_dims(%laneid: index) {
-  // expected-error @+1 {{expected expanded vector dimension #1 (8) to be a multipler of the distributed vector dimension (3)}}
+  // expected-error @+1 {{expected expanded type dimension #1 (8) to be a multiplier of the distributed type dimension (3)}}
   %2 = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<1x3xi32>) {
     %0 = arith.constant dense<2>: vector<4x8xi32>
     vector.yield %0 : vector<4x8xi32>
@@ -1676,7 +1676,7 @@ func.func @warp_2_distributed_dims(%laneid: index) {
 // -----
 
 func.func @warp_mismatch_rank(%laneid: index) {
-  // expected-error @+1 {{'vector.warp_execute_on_lane_0' op expected distributed vectors to have same rank and element type.}}
+  // expected-error @+1 {{'vector.warp_execute_on_lane_0' op expected distributed types to have same rank and element type.}}
   %2 = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<4x4xi32>) {
     %0 = arith.constant dense<2>: vector<128xi32>
     vector.yield %0 : vector<128xi32>
@@ -1687,7 +1687,7 @@ func.func @warp_mismatch_rank(%laneid: index) {
 // -----
 
 func.func @warp_mismatch_rank(%laneid: index) {
-  // expected-error @+1 {{'vector.warp_execute_on_lane_0' op expected vector type for distributed operands.}}
+  // expected-error @+1 {{'vector.warp_execute_on_lane_0' op expected shaped type for distributed operands.}}
   %2 = vector.warp_execute_on_lane_0(%laneid)[32] -> (i32) {
     %0 = arith.constant dense<2>: vector<128xi32>
     vector.yield %0 : vector<128xi32>

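With the result type relaxed from VectorType to ShapedType, a warp op can now
also distribute non-vector shaped results such as XeGPU tensor descriptors,
which the later patches in this PR rely on. A hedged sketch (illustration only;
%src and %laneid are assumed to be bound, and the concrete form appears in the
xegpu-distribute.mlir tests of patch 3):

  %r = vector.warp_execute_on_lane_0(%laneid)[16]
        -> (!xegpu.tensor_desc<24x2xf16>) {
    %td = xegpu.create_nd_tdesc %src[0, 0]
              : memref<24x32xf16> -> !xegpu.tensor_desc<24x32xf16>
    vector.yield %td : !xegpu.tensor_desc<24x32xf16>
  }
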
>From 7627cf747d98c33ce24e585f4b695962d681139d Mon Sep 17 00:00:00 2001
From: Petr Kurapov <petr.a.kurapov at intel.com>
Date: Fri, 18 Oct 2024 15:50:31 +0000
Subject: [PATCH 2/4] [MLIR][Vector] Move helper functions to vector
 distribution utils

---
 .../mlir/Dialect/Vector/Utils/VectorUtils.h   | 20 ++++
 .../Vector/Transforms/VectorDistribute.cpp    | 80 +---------------
 mlir/lib/Dialect/Vector/Utils/CMakeLists.txt  |  1 +
 .../Vector/Utils/VectorDistributeUtils.cpp    | 96 +++++++++++++++++++
 4 files changed, 118 insertions(+), 79 deletions(-)
 create mode 100644 mlir/lib/Dialect/Vector/Utils/VectorDistributeUtils.cpp

diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
index 5f32aca88a2734..6bd924307376dc 100644
--- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
@@ -20,6 +20,8 @@
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/TypeSwitch.h"
 
+#include <functional>
+
 namespace mlir {
 
 // Forward declarations.
@@ -324,6 +326,24 @@ namespace matcher {
 bool operatesOnSuperVectorsOf(Operation &op, VectorType subVectorType);
 
 } // namespace matcher
+
+/// Return a value yielded by `warpOp` which satisfies the filter lambda
+/// condition and is not dead.
+OpOperand *getWarpResult(vector::WarpExecuteOnLane0Op warpOp,
+                         const std::function<bool(Operation *)> &fn);
+
+/// Helper to create a new WarpExecuteOnLane0Op with different signature.
+vector::WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns(
+    RewriterBase &rewriter, vector::WarpExecuteOnLane0Op warpOp,
+    ValueRange newYieldedValues, TypeRange newReturnTypes);
+
+/// Helper to create a new WarpExecuteOnLane0Op region with extra outputs.
+/// `indices` returns the index of each new output.
+vector::WarpExecuteOnLane0Op moveRegionToNewWarpOpAndAppendReturns(
+    RewriterBase &rewriter, vector::WarpExecuteOnLane0Op warpOp,
+    ValueRange newYieldedValues, TypeRange newReturnTypes,
+    llvm::SmallVector<size_t> &indices);
+
 } // namespace mlir
 
 #endif // MLIR_DIALECT_VECTOR_UTILS_VECTORUTILS_H_
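
For reference, a downstream pattern typically uses these helpers roughly as
follows; a minimal C++ sketch mirroring the XeGPU patterns added later in this
PR (yieldedVal and distributedType are placeholder names):

  // Append the value that must escape the warp region as an extra return.
  SmallVector<size_t> newRetIndices;
  vector::WarpExecuteOnLane0Op newWarpOp =
      moveRegionToNewWarpOpAndAppendReturns(
          rewriter, warpOp, ValueRange{yieldedVal},
          TypeRange{distributedType}, newRetIndices);
  // Ops consuming the escaped value are then rebuilt after the new warp op.
  rewriter.setInsertionPointAfter(newWarpOp);
  Value escaped = newWarpOp.getResult(newRetIndices[0]);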
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 2289fd1ff1364e..c80c3179b5e025 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -12,6 +12,7 @@
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Transforms/RegionUtils.h"
@@ -160,68 +161,6 @@ struct DistributedLoadStoreHelper {
 
 } // namespace
 
-/// Helper to create a new WarpExecuteOnLane0Op with different signature.
-static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns(
-    RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
-    ValueRange newYieldedValues, TypeRange newReturnTypes) {
-  // Create a new op before the existing one, with the extra operands.
-  OpBuilder::InsertionGuard g(rewriter);
-  rewriter.setInsertionPoint(warpOp);
-  auto newWarpOp = rewriter.create<WarpExecuteOnLane0Op>(
-      warpOp.getLoc(), newReturnTypes, warpOp.getLaneid(), warpOp.getWarpSize(),
-      warpOp.getArgs(), warpOp.getBody()->getArgumentTypes());
-
-  Region &opBody = warpOp.getBodyRegion();
-  Region &newOpBody = newWarpOp.getBodyRegion();
-  Block &newOpFirstBlock = newOpBody.front();
-  rewriter.inlineRegionBefore(opBody, newOpBody, newOpBody.begin());
-  rewriter.eraseBlock(&newOpFirstBlock);
-  assert(newWarpOp.getWarpRegion().hasOneBlock() &&
-         "expected WarpOp with single block");
-
-  auto yield =
-      cast<vector::YieldOp>(newOpBody.getBlocks().begin()->getTerminator());
-
-  rewriter.modifyOpInPlace(
-      yield, [&]() { yield.getOperandsMutable().assign(newYieldedValues); });
-  return newWarpOp;
-}
-
-/// Helper to create a new WarpExecuteOnLane0Op region with extra outputs.
-/// `indices` return the index of each new output.
-static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndAppendReturns(
-    RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
-    ValueRange newYieldedValues, TypeRange newReturnTypes,
-    llvm::SmallVector<size_t> &indices) {
-  SmallVector<Type> types(warpOp.getResultTypes().begin(),
-                          warpOp.getResultTypes().end());
-  auto yield = cast<vector::YieldOp>(
-      warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
-  llvm::SmallSetVector<Value, 32> yieldValues(yield.getOperands().begin(),
-                                              yield.getOperands().end());
-  for (auto newRet : llvm::zip(newYieldedValues, newReturnTypes)) {
-    if (yieldValues.insert(std::get<0>(newRet))) {
-      types.push_back(std::get<1>(newRet));
-      indices.push_back(yieldValues.size() - 1);
-    } else {
-      // If the value already exit the region don't create a new output.
-      for (auto [idx, yieldOperand] :
-           llvm::enumerate(yieldValues.getArrayRef())) {
-        if (yieldOperand == std::get<0>(newRet)) {
-          indices.push_back(idx);
-          break;
-        }
-      }
-    }
-  }
-  yieldValues.insert(newYieldedValues.begin(), newYieldedValues.end());
-  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
-      rewriter, warpOp, yieldValues.getArrayRef(), types);
-  rewriter.replaceOp(warpOp,
-                     newWarpOp.getResults().take_front(warpOp.getNumResults()));
-  return newWarpOp;
-}
-
 /// Helper to know if an op can be hoisted out of the region.
 static bool canBeHoisted(Operation *op,
                          function_ref<bool(Value)> definedOutside) {
@@ -229,23 +168,6 @@ static bool canBeHoisted(Operation *op,
          isMemoryEffectFree(op) && op->getNumRegions() == 0;
 }
 
-/// Return a value yielded by `warpOp` which statifies the filter lamdba
-/// condition and is not dead.
-static OpOperand *getWarpResult(WarpExecuteOnLane0Op warpOp,
-                                const std::function<bool(Operation *)> &fn) {
-  auto yield = cast<vector::YieldOp>(
-      warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
-  for (OpOperand &yieldOperand : yield->getOpOperands()) {
-    Value yieldValues = yieldOperand.get();
-    Operation *definedOp = yieldValues.getDefiningOp();
-    if (definedOp && fn(definedOp)) {
-      if (!warpOp.getResult(yieldOperand.getOperandNumber()).use_empty())
-        return &yieldOperand;
-    }
-  }
-  return {};
-}
-
 // Clones `op` into a new operation that takes `operands` and returns
 // `resultTypes`.
 static Operation *cloneOpWithOperandsAndTypes(RewriterBase &rewriter,
diff --git a/mlir/lib/Dialect/Vector/Utils/CMakeLists.txt b/mlir/lib/Dialect/Vector/Utils/CMakeLists.txt
index fa3971695d4bf2..9db0c172fec5ce 100644
--- a/mlir/lib/Dialect/Vector/Utils/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Utils/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_dialect_library(MLIRVectorUtils
   VectorUtils.cpp
+  VectorDistributeUtils.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Vector/Utils
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorDistributeUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorDistributeUtils.cpp
new file mode 100644
index 00000000000000..f41581c6d47f2a
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/Utils/VectorDistributeUtils.cpp
@@ -0,0 +1,96 @@
+//===- VectorDistributeUtils.cpp - Utilities for VectorOps distribution --===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements utility methods for distributing Vector dialect ops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+
+using namespace mlir;
+
+/// Return a value yielded by `warpOp` which satisfies the filter lambda
+/// condition and is not dead.
+mlir::OpOperand *
+mlir::getWarpResult(vector::WarpExecuteOnLane0Op warpOp,
+                    const std::function<bool(Operation *)> &fn) {
+  auto yield = cast<vector::YieldOp>(
+      warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+  for (mlir::OpOperand &yieldOperand : yield->getOpOperands()) {
+    Value yieldValues = yieldOperand.get();
+    Operation *definedOp = yieldValues.getDefiningOp();
+    if (definedOp && fn(definedOp)) {
+      if (!warpOp.getResult(yieldOperand.getOperandNumber()).use_empty())
+        return &yieldOperand;
+    }
+  }
+  return {};
+}
+
+/// Helper to create a new WarpExecuteOnLane0Op with different signature.
+vector::WarpExecuteOnLane0Op mlir::moveRegionToNewWarpOpAndReplaceReturns(
+    RewriterBase &rewriter, vector::WarpExecuteOnLane0Op warpOp,
+    ValueRange newYieldedValues, TypeRange newReturnTypes) {
+  // Create a new op before the existing one, with the extra operands.
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(warpOp);
+  auto newWarpOp = rewriter.create<vector::WarpExecuteOnLane0Op>(
+      warpOp.getLoc(), newReturnTypes, warpOp.getLaneid(), warpOp.getWarpSize(),
+      warpOp.getArgs(), warpOp.getBody()->getArgumentTypes());
+
+  Region &opBody = warpOp.getBodyRegion();
+  Region &newOpBody = newWarpOp.getBodyRegion();
+  Block &newOpFirstBlock = newOpBody.front();
+  rewriter.inlineRegionBefore(opBody, newOpBody, newOpBody.begin());
+  rewriter.eraseBlock(&newOpFirstBlock);
+  assert(newWarpOp.getWarpRegion().hasOneBlock() &&
+         "expected WarpOp with single block");
+
+  auto yield =
+      cast<vector::YieldOp>(newOpBody.getBlocks().begin()->getTerminator());
+
+  rewriter.modifyOpInPlace(
+      yield, [&]() { yield.getOperandsMutable().assign(newYieldedValues); });
+  return newWarpOp;
+}
+
+/// Helper to create a new WarpExecuteOnLane0Op region with extra outputs.
+/// `indices` returns the index of each new output.
+vector::WarpExecuteOnLane0Op mlir::moveRegionToNewWarpOpAndAppendReturns(
+    RewriterBase &rewriter, vector::WarpExecuteOnLane0Op warpOp,
+    ValueRange newYieldedValues, TypeRange newReturnTypes,
+    llvm::SmallVector<size_t> &indices) {
+  SmallVector<Type> types(warpOp.getResultTypes().begin(),
+                          warpOp.getResultTypes().end());
+  auto yield = cast<vector::YieldOp>(
+      warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+  llvm::SmallSetVector<Value, 32> yieldValues(yield.getOperands().begin(),
+                                              yield.getOperands().end());
+  for (auto newRet : llvm::zip(newYieldedValues, newReturnTypes)) {
+    if (yieldValues.insert(std::get<0>(newRet))) {
+      types.push_back(std::get<1>(newRet));
+      indices.push_back(yieldValues.size() - 1);
+    } else {
+      // If the value already exits the region, don't create a new output.
+      for (auto [idx, yieldOperand] :
+           llvm::enumerate(yieldValues.getArrayRef())) {
+        if (yieldOperand == std::get<0>(newRet)) {
+          indices.push_back(idx);
+          break;
+        }
+      }
+    }
+  }
+  yieldValues.insert(newYieldedValues.begin(), newYieldedValues.end());
+  vector::WarpExecuteOnLane0Op newWarpOp =
+      moveRegionToNewWarpOpAndReplaceReturns(rewriter, warpOp,
+                                             yieldValues.getArrayRef(), types);
+  rewriter.replaceOp(warpOp,
+                     newWarpOp.getResults().take_front(warpOp.getNumResults()));
+  return newWarpOp;
+}

>From de3ae899d02f6f151cdbd9badcd47c6f1df39861 Mon Sep 17 00:00:00 2001
From: Petr Kurapov <petr.a.kurapov at intel.com>
Date: Fri, 18 Oct 2024 16:18:32 +0000
Subject: [PATCH 3/4] [MLIR][XeGPU] Add distribution patterns for load_nd,
 store_nd, and create_nd_tdesc

---
 .../Dialect/XeGPU/Transforms/Transforms.h     |   1 +
 .../Dialect/XeGPU/Transforms/CMakeLists.txt   |   5 +
 .../XeGPU/Transforms/XeGPUDistribute.cpp      | 393 ++++++++++++++++++
 mlir/test/Dialect/XeGPU/xegpu-distribute.mlir |  81 ++++
 mlir/test/lib/Dialect/CMakeLists.txt          |   1 +
 mlir/test/lib/Dialect/XeGPU/CMakeLists.txt    |  17 +
 .../lib/Dialect/XeGPU/TestXeGPUTransforms.cpp |  58 +++
 mlir/tools/mlir-opt/CMakeLists.txt            |   1 +
 mlir/tools/mlir-opt/mlir-opt.cpp              |   2 +
 9 files changed, 559 insertions(+)
 create mode 100644 mlir/lib/Dialect/XeGPU/Transforms/XeGPUDistribute.cpp
 create mode 100644 mlir/test/Dialect/XeGPU/xegpu-distribute.mlir
 create mode 100644 mlir/test/lib/Dialect/XeGPU/CMakeLists.txt
 create mode 100644 mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp

diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h b/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
index 63ea26df069372..fe5198d1ac6dba 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
@@ -16,6 +16,7 @@ namespace xegpu {
 
 /// Appends patterns for folding aliasing ops into XeGPU ops into `patterns`.
 void populateXeGPUFoldAliasOpsPatterns(RewritePatternSet &patterns);
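+/// Appends patterns for distributing XeGPU ops to work items into `patterns`.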
+void populateXeGPUDistributePatterns(RewritePatternSet &patterns);
 
 } // namespace xegpu
 } // namespace mlir
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
index 7fb64d3b97b87d..148ff46ba41b72 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_dialect_library(MLIRXeGPUTransforms
   XeGPUFoldAliasOps.cpp
+  XeGPUDistribute.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/XeGPU
@@ -12,6 +13,10 @@ add_mlir_dialect_library(MLIRXeGPUTransforms
   MLIRIR
   MLIRMemRefDialect
   MLIRXeGPUDialect
+  MLIRVectorDialect
+  MLIRVectorUtils
+  MLIRArithDialect
+  MLIRFuncDialect
   MLIRPass
   MLIRTransforms
 )
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUDistribute.cpp
new file mode 100644
index 00000000000000..78a010ff1c941b
--- /dev/null
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUDistribute.cpp
@@ -0,0 +1,393 @@
+//===- XeGPUDistribute.cpp - XeGPU distribute ops to work items --*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "xegpu-distribute"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+
+using namespace mlir;
+
+namespace {
+bool divisible(APInt lhs, APInt rhs) { return !lhs.urem(rhs); }
+
+/// Clone a create_nd_tdesc feeding into vector.yield op for the enclosing
+/// `vector.warp_execute_on_lane_0` and put it after the warp op.
+/// The warp op will still contain the original op that will not be used by the
+/// yield op (and should be cleaned up later with DCE). The yield op will bypass
+/// the create_nd_tdesc's arguments.
+/// The rewrite creates a subview of the size used by a single work item with an
+/// appropriate offset. The distributed create_nd_tdesc points into the subview
+/// without an offset. The tensor descriptor type is distributed according to
+/// the sg_map attribute.
+///
+/// Example:
+///
+/// ```
+///   #sg_map_8 = #xegpu.sg_map<wi_layout = [1, 8], wi_data = [1, 1]>
+///   %r = vector.warp_execute_on_lane_0(%laneid) ->
+///                   (!xegpu.tensor_desc<4x1xf32>) {
+///     ...
+///     %td = xegpu.create_nd_tdesc %arg0[0, 0]
+///               : memref<4x8xf32> -> !xegpu.tensor_desc<4x8xf32>
+///     vector.yield %td
+///   }
+/// ```
+/// To
+/// ```
+///   %r:2 = vector.warp_execute_on_lane_0(%laneid) ->
+///                   (!xegpu.tensor_desc<4x1xf32>, memref<4x8xf32>) {
+///     ...
+///     %dead = xegpu.create_nd_tdesc %arg0[0, 0]
+///               : memref<4x8xf32> -> !xegpu.tensor_desc<4x8xf32>
+///     vector.yield %dead, %arg0
+///   }
+///   %view = memref.subview %r#1[0, %laneid] [4, 1] [1, 1]
+///                               : memref<4x8xf32> to memref<4x1xf32>
+///   %td = xegpu.create_nd_tdesc %view[0, 0]: memref<4x1xf32>
+///                                 -> !xegpu.tensor_desc<4x1xf32>
+///
+/// ```
+struct WarpOpTensorDescOp final
+    : public OpRewritePattern<vector::WarpExecuteOnLane0Op> {
+  using OpRewritePattern<vector::WarpExecuteOnLane0Op>::OpRewritePattern;
+  LogicalResult matchAndRewrite(vector::WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override;
+};
+
+/// Sink a store_nd feeding into vector.yield op for the enclosing
+/// `vector.warp_execute_on_lane_0`. If the arguments for the store are passed
+/// through the warp op interface, they are propagated as returned values.
+/// Both the stored vector type and the tensor descriptor type are distributed
+/// according to the sg_map attribute.
+///
+/// Example:
+///
+/// ```
+///   #sg_map_8 = #xegpu.sg_map<wi_layout = [1, 8], wi_data = [1, 1]>
+///   vector.warp_execute_on_lane_0(%laneid) -> () {
+///     ...
+///     xegpu.store_nd %arg0, %arg1: vector<4x8xf32>,
+///                                 !xegpu.tensor_desc<4x8xf32>
+///     vector.yield
+///   }
+/// ```
+/// To
+/// ```
+///   %r:2 = vector.warp_execute_on_lane_0(%laneid) ->
+///                   (!xegpu.tensor_desc<4x1xf32>, vector<4x1xf32>) {
+///     ...
+///     vector.yield %arg1, %arg0
+///         : !xegpu.tensor_desc<4x8xf32>, vector<4x8xf32>
+///   }
+///   xegpu.store_nd %r#1, %r#0: vector<4x1xf32>, !xegpu.tensor_desc<4x1xf32>
+///
+/// ```
+struct WarpOpStoreNd final
+    : public OpRewritePattern<vector::WarpExecuteOnLane0Op> {
+  using OpRewritePattern<vector::WarpExecuteOnLane0Op>::OpRewritePattern;
+  LogicalResult matchAndRewrite(vector::WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override;
+};
+
+/// Clone a load_nd feeding into vector.yield op for the enclosing
+/// `vector.warp_execute_on_lane_0` and put it after the warp op.
+/// The warp op will still contain the original op that will not be used by the
+/// yield op (and should be cleaned up later with DCE). The yield op will bypass
+/// the load's arguments.
+/// Both the loaded vector type and the tensor descriptor type are distributed
+/// according to the sg_map attribute.
+///
+/// Example:
+///
+/// ```
+///   #sg_map_8 = #xegpu.sg_map<wi_layout = [1, 8], wi_data = [1, 1]>
+///   %r = vector.warp_execute_on_lane_0(%laneid) ->
+///                   (vector<4x1xf32>) {
+///     ...
+///     %ld = xegpu.load_nd %arg0
+///               : !xegpu.tensor_desc<4x8xf32> -> vector<4x8xf32>
+///     vector.yield %ld : vector<4x8xf32>
+///   }
+/// ```
+/// To
+/// ```
+///   %r:2 = vector.warp_execute_on_lane_0(%laneid) ->
+///                   (vector<4x1xf32>, !xegpu.tensor_desc<4x1xf32>) {
+///     ...
+///     %dead = xegpu.load_nd %arg0
+///         : !xegpu.tensor_desc<4x8xf32> -> vector<4x8xf32>
+///     vector.yield %dead, %arg0
+///   }
+///   %ld = xegpu.load_nd %r#1: !xegpu.tensor_desc<4x1xf32> -> vector<4x1xf32>
+///
+/// ```
+struct WarpOpLoadNd final
+    : public OpRewritePattern<vector::WarpExecuteOnLane0Op> {
+  using OpRewritePattern<vector::WarpExecuteOnLane0Op>::OpRewritePattern;
+  LogicalResult matchAndRewrite(vector::WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override;
+};
+
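+/// Compute the distributed vector type for a single work item by dividing
+/// each dimension of `originalT` by the matching wi_layout entry of `sgMap`.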
+FailureOr<VectorType> getDistributedVectorType(VectorType originalT,
+                                               xegpu::SGMapAttr sgMap) {
+  llvm::SmallVector<int64_t, 2> distributedShape;
+  auto layout = sgMap.getWiLayout();
+  auto shape = originalT.getShape();
+  for (const auto [l, o] : llvm::zip_equal(layout, shape)) {
+    if (!divisible(APInt(64, o), APInt(64, l)))
+      return failure();
+    distributedShape.push_back(o / l);
+  }
+  auto newVectorType =
+      VectorType::get(distributedShape, originalT.getElementType(),
+                      originalT.getScalableDims());
+  return newVectorType;
+}
+
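+/// Compute the distributed tensor descriptor type for a single work item,
+/// preserving the scattered/blocked flavor of `originalT`.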
+FailureOr<xegpu::TensorDescType>
+getDistributedTensorDescType(xegpu::TensorDescType originalT,
+                             xegpu::SGMapAttr sgMap,
+                             xegpu::MemorySpace memSpace) {
+  llvm::SmallVector<int64_t, 2> distributedShape;
+  auto layout = sgMap.getWiLayout();
+  auto shape = originalT.getShape();
+  for (const auto [l, o] : llvm::zip_equal(layout, shape)) {
+    if (!divisible(APInt(64, o), APInt(64, l)))
+      return failure();
+    distributedShape.push_back(o / l);
+  }
+  xegpu::TensorDescType distributedDescType;
+  if (originalT.isScattered()) {
+    distributedDescType = xegpu::TensorDescType::get(
+        distributedShape, originalT.getElementType(), originalT.getChunkSize(),
+        originalT.getMemorySpace(), originalT.getSGMapAttr());
+  } else {
+    distributedDescType = xegpu::TensorDescType::get(
+        distributedShape, originalT.getElementType(),
+        originalT.getBoundaryCheck(), originalT.getArrayLength(),
+        originalT.getMemorySpace(), originalT.getSGMapAttr());
+  }
+  return distributedDescType;
+}
+} // namespace
+
+LogicalResult
+WarpOpStoreNd::matchAndRewrite(vector::WarpExecuteOnLane0Op warpOp,
+                               PatternRewriter &rewriter) const {
+  auto yield = cast<vector::YieldOp>(
+      warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+  Operation *lastNode = yield->getPrevNode();
+  auto storeOp = dyn_cast_or_null<xegpu::StoreNdOp>(lastNode);
+  if (!storeOp)
+    return failure();
+
+  auto origType = storeOp.getTensorDescType();
+  xegpu::SGMapAttr sgMap = origType.getSGMapAttr();
+  if (!sgMap)
+    return rewriter.notifyMatchFailure(
+        storeOp, "the source tensor descriptor lacks sg_map attribute");
+
+  if (storeOp.getTensorDescType().getShape().size() != 2)
+    return rewriter.notifyMatchFailure(storeOp, "unsupported shape");
+  LLVM_DEBUG(DBGS() << "Matched store_nd: " << storeOp << "\n");
+
+  auto distributedTypeOrFailure =
+      getDistributedVectorType(storeOp.getValueType(), sgMap);
+  if (failed(distributedTypeOrFailure))
+    return rewriter.notifyMatchFailure(storeOp,
+                                       "Failed to distribute the type");
+  VectorType newVectorType = distributedTypeOrFailure.value();
+
+  auto distributedDescTypeOrFailure = getDistributedTensorDescType(
+      storeOp.getTensorDescType(), sgMap,
+      storeOp.getTensorDescType().getMemorySpace());
+  if (failed(distributedDescTypeOrFailure))
+    return rewriter.notifyMatchFailure(storeOp,
+                                       "Failed to distribute the desc type");
+  xegpu::TensorDescType newTDescType = distributedDescTypeOrFailure.value();
+
+  SmallVector<size_t> newRetIndices;
+  vector::WarpExecuteOnLane0Op newWarpOp =
+      moveRegionToNewWarpOpAndAppendReturns(
+          rewriter, warpOp,
+          ValueRange{storeOp.getTensorDesc(), storeOp.getValue()},
+          TypeRange{newTDescType, newVectorType}, newRetIndices);
+
+  rewriter.setInsertionPointAfter(newWarpOp);
+  auto newStoreOp =
+      cast<xegpu::StoreNdOp>(rewriter.clone(*storeOp.getOperation()));
+  rewriter.eraseOp(storeOp);
+  newStoreOp.getTensorDescMutable().assign(
+      newWarpOp.getResult(newRetIndices[0]));
+  newStoreOp.getValueMutable().assign(newWarpOp.getResult(newRetIndices[1]));
+
+  return success();
+}
+
+LogicalResult WarpOpLoadNd::matchAndRewrite(vector::WarpExecuteOnLane0Op warpOp,
+                                            PatternRewriter &rewriter) const {
+  OpOperand *operand = getWarpResult(warpOp, [](Operation *op) {
+    return isa<xegpu::LoadNdOp>(op) && op->hasOneUse();
+  });
+
+  if (!operand)
+    return rewriter.notifyMatchFailure(warpOp,
+                                       "warp result is not a xegpu::LoadNd op");
+
+  auto loadOp = operand->get().getDefiningOp<xegpu::LoadNdOp>();
+
+  if (loadOp.getPacked())
+    return rewriter.notifyMatchFailure(
+        loadOp, "Packed load distribution not supported");
+
+  xegpu::TensorDescType origType = loadOp.getTensorDescType();
+  xegpu::SGMapAttr sgMap = origType.getSGMapAttr();
+  if (!sgMap)
+    return rewriter.notifyMatchFailure(
+        loadOp, "the source tensor descriptor lacks sg_map attribute");
+
+  auto origShape = origType.getShape();
+  if (origShape.size() != 2)
+    return rewriter.notifyMatchFailure(loadOp, "unsupported shape");
+
+  auto distributedTypeOrFailure =
+      getDistributedVectorType(loadOp.getType(), sgMap);
+  if (failed(distributedTypeOrFailure))
+    return rewriter.notifyMatchFailure(loadOp, "Failed to distribute the type");
+  VectorType newVectorType = distributedTypeOrFailure.value();
+
+  auto distributedDescTypeOrFailure =
+      getDistributedTensorDescType(loadOp.getTensorDescType(), sgMap,
+                                   loadOp.getTensorDescType().getMemorySpace());
+  if (failed(distributedDescTypeOrFailure))
+    return rewriter.notifyMatchFailure(loadOp,
+                                       "Failed to distribute the desc type");
+  xegpu::TensorDescType newTDescType = distributedDescTypeOrFailure.value();
+
+  unsigned operandIdx = operand->getOperandNumber();
+
+  SmallVector<size_t> newRetIndices;
+  vector::WarpExecuteOnLane0Op newWarpOp =
+      moveRegionToNewWarpOpAndAppendReturns(
+          rewriter, warpOp, loadOp.getTensorDesc(), TypeRange{newTDescType},
+          newRetIndices);
+
+  rewriter.setInsertionPointAfter(newWarpOp);
+
+  auto newLoadOp = rewriter.create<xegpu::LoadNdOp>(
+      loadOp.getLoc(), newVectorType, loadOp.getTensorDesc(),
+      loadOp.getPackedAttr(), loadOp.getTransposeAttr(), loadOp.getL1HintAttr(),
+      loadOp.getL2HintAttr(), loadOp.getL3HintAttr());
+
+  newLoadOp.getTensorDescMutable().assign(
+      newWarpOp.getResult(newRetIndices[0]));
+  Value distributedVal = newWarpOp.getResult(operandIdx);
+  rewriter.replaceAllUsesWith(distributedVal, newLoadOp);
+
+  return success();
+}
+
+LogicalResult
+WarpOpTensorDescOp::matchAndRewrite(vector::WarpExecuteOnLane0Op warpOp,
+                                    PatternRewriter &rewriter) const {
+  OpOperand *operand = getWarpResult(warpOp, [](Operation *op) {
+    return isa<xegpu::CreateNdDescOp>(op) && op->hasOneUse();
+  });
+
+  if (!operand)
+    return rewriter.notifyMatchFailure(
+        warpOp, "warp result is not a xegpu::CreateNdDesc op");
+  auto descOp = operand->get().getDefiningOp<xegpu::CreateNdDescOp>();
+  assert(descOp && "desc op must not be null");
+  unsigned operandIdx = operand->getOperandNumber();
+
+  // TODO: check that the memref is uniform across the region.
+  rewriter.setInsertionPoint(warpOp);
+  auto srcTypedVal = dyn_cast<TypedValue<MemRefType>>(descOp.getSource());
+  assert(srcTypedVal && "source value must not be null");
+
+  auto descOffsets = descOp.getMixedOffsets();
+  if (descOffsets.size() != 2)
+    return rewriter.notifyMatchFailure(descOp,
+                                       "offsets size is expected to be 2");
+
+  xegpu::SGMapAttr sgMap = descOp.getType().getSGMapAttr();
+  if (!sgMap)
+    return rewriter.notifyMatchFailure(
+        descOp, "the tensor descriptor lacks sg_map attribute");
+
+  auto layout = sgMap.getWiLayout();
+
+  // Calculate the offset within the tensor descriptor for the current lane id.
+  // Access to the proper element for a work item is done through a
+  // lane-specific subview (tdesc offsets are used as the base, the lane shift
+  // is added on top).
+  auto laneid = warpOp.getLaneid();
+  auto xDim =
+      rewriter.create<arith::ConstantIndexOp>(laneid.getLoc(), layout[0]);
+  auto shiftx = rewriter.create<arith::RemUIOp>(laneid.getLoc(), laneid, xDim);
+  auto shifty = rewriter.create<arith::DivUIOp>(laneid.getLoc(), laneid, xDim);
+
+  auto basex = getValueOrCreateConstantIndexOp(rewriter, laneid.getLoc(),
+                                               descOffsets[0]);
+  auto basey = getValueOrCreateConstantIndexOp(rewriter, laneid.getLoc(),
+                                               descOffsets[1]);
+  auto offsetx = rewriter.create<arith::AddIOp>(laneid.getLoc(), shiftx, basex);
+  auto offsety = rewriter.create<arith::AddIOp>(laneid.getLoc(), shifty, basey);
+
+  auto distributedDescTypeOrFailure = getDistributedTensorDescType(
+      descOp.getType(), sgMap, descOp.getType().getMemorySpace());
+  if (failed(distributedDescTypeOrFailure))
+    return rewriter.notifyMatchFailure(descOp,
+                                       "Failed to distribute the desc type");
+  xegpu::TensorDescType newTDescType = distributedDescTypeOrFailure.value();
+  auto distributedShape = newTDescType.getShape();
+  // use the base memref strides
+  SmallVector<OpFoldResult> overwriteStrides =
+      getAsIndexOpFoldResult(rewriter.getContext(), SmallVector<int64_t>{1, 1});
+  SmallVector<OpFoldResult> overwriteSizes =
+      getAsIndexOpFoldResult(rewriter.getContext(), distributedShape);
+
+  SmallVector<size_t> newRetIndices;
+  vector::WarpExecuteOnLane0Op newWarpOp =
+      moveRegionToNewWarpOpAndAppendReturns(
+          rewriter, warpOp, descOp.getSource(), descOp.getSourceType(),
+          newRetIndices);
+
+  rewriter.setInsertionPointAfter(newWarpOp);
+  auto subview = rewriter.create<memref::SubViewOp>(
+      newWarpOp.getLoc(), srcTypedVal, getAsOpFoldResult({offsetx, offsety}),
+      overwriteSizes, overwriteStrides);
+  subview.getSourceMutable().assign(newWarpOp.getResult(newRetIndices[0]));
+
+  auto zero = rewriter.create<arith::ConstantIndexOp>(laneid.getLoc(), 0);
+  auto newDescOp = rewriter.create<xegpu::CreateNdDescOp>(
+      newWarpOp.getLoc(), newTDescType, subview,
+      getAsOpFoldResult({zero, zero}));
+
+  Value distributedVal = newWarpOp.getResult(operandIdx);
+  rewriter.replaceAllUsesWith(distributedVal, newDescOp);
+
+  return success();
+}
+
+void xegpu::populateXeGPUDistributePatterns(RewritePatternSet &patterns) {
+  patterns.add<WarpOpTensorDescOp>(patterns.getContext());
+  patterns.add<WarpOpStoreNd>(patterns.getContext());
+  patterns.add<WarpOpLoadNd>(patterns.getContext());
+}
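
To make the type arithmetic above concrete, getDistributedVectorType and
getDistributedTensorDescType divide each dimension by the matching wi_layout
entry. A standalone sketch using the shapes from the tests below (illustration
only):

  SmallVector<int64_t> shape = {24, 32};  // original tensor_desc/vector shape
  SmallVector<int64_t> layout = {1, 16};  // sg_map wi_layout
  SmallVector<int64_t> distributed;
  for (auto [l, o] : llvm::zip_equal(layout, shape))
    distributed.push_back(o / l);         // yields {24, 2}
  // vector<24x32xf16> thus distributes to vector<24x2xf16> per work item.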
diff --git a/mlir/test/Dialect/XeGPU/xegpu-distribute.mlir b/mlir/test/Dialect/XeGPU/xegpu-distribute.mlir
new file mode 100644
index 00000000000000..ec01fc82688156
--- /dev/null
+++ b/mlir/test/Dialect/XeGPU/xegpu-distribute.mlir
@@ -0,0 +1,81 @@
+// RUN: mlir-opt -test-xegpu-distribute -split-input-file %s | FileCheck %s
+
+#sg_map_16 = #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>
+#blk_tdesc = #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>
+
+// CHECK-LABEL: test_store_nd_distribution
+// CHECK: %[[laneid:.*]] = gpu.lane_id
+// CHECK: %[[res:.*]]:2 = vector.warp_execute_on_lane_0(%[[laneid]])[16] args(%{{.*}}, %{{.*}} :  vector<24x32xf16>, !xegpu.tensor_desc<24x32xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>)
+// CHECK-SAME: -> (!xegpu.tensor_desc<24x2xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>, vector<24x2xf16>)
+// CHECK: ^bb0(%[[src:.*]]: vector<24x32xf16>, %[[dst:.*]]: !xegpu.tensor_desc<24x32xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>)
+// CHECK: vector.yield %[[dst]], %[[src]] : !xegpu.tensor_desc<24x32xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>, vector<24x32xf16>
+// CHECK: xegpu.store_nd %[[res]]#1, %[[res]]#0 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> :
+// CHECK-SAME: vector<24x2xf16>, !xegpu.tensor_desc<24x2xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+
+func.func @test_store_nd_distribution(%src: vector<24x32xf16>, %dst: !xegpu.tensor_desc<24x32xf16, #blk_tdesc, #sg_map_16>) -> () {
+  %laneid = gpu.lane_id
+  vector.warp_execute_on_lane_0(%laneid)[16]
+        args(%src, %dst: vector<24x32xf16>, !xegpu.tensor_desc<24x32xf16, #blk_tdesc, #sg_map_16>) {
+    ^bb0(%arg0: vector<24x32xf16>, %arg1: !xegpu.tensor_desc<24x32xf16, #blk_tdesc, #sg_map_16>):
+    xegpu.store_nd %arg0, %arg1 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>: vector<24x32xf16>, !xegpu.tensor_desc<24x32xf16, #blk_tdesc, #sg_map_16>
+  }
+  return
+}
+
+// -----
+
+#sg_map_16 = #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>
+#blk_tdesc = #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>
+
+// CHECK-LABEL: test_load_nd_distribution
+// CHECK: %[[laneid:.*]] = gpu.lane_id
+// CHECK: %[[res:.*]]:2 = vector.warp_execute_on_lane_0(%[[laneid]])[16] args(%{{.*}} :  !xegpu.tensor_desc<24x32xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>)
+// CHECK-SAME: -> (vector<24x2xf16>, !xegpu.tensor_desc<24x2xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>)
+// CHECK: ^bb0(%[[dst:.*]]: !xegpu.tensor_desc<24x32xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>)
+// CHECK: %[[dead:.*]] = xegpu.load_nd
+// CHECK: vector.yield %[[dead]], %[[dst]] : vector<24x32xf16>, !xegpu.tensor_desc<24x32xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+// CHECK: %[[load:.*]] = xegpu.load_nd %[[res]]#1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> :
+// CHECK-SAME: !xegpu.tensor_desc<24x2xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<24x2xf16>
+// CHECK: return %[[load]]
+
+func.func @test_load_nd_distribution(%dst: !xegpu.tensor_desc<24x32xf16, #blk_tdesc, #sg_map_16>) -> (vector<24x2xf16>) {
+  %laneid = gpu.lane_id
+  %r = vector.warp_execute_on_lane_0(%laneid)[16]
+        args(%dst: !xegpu.tensor_desc<24x32xf16, #blk_tdesc, #sg_map_16>) -> (vector<24x2xf16>) {
+    ^bb0(%arg0: !xegpu.tensor_desc<24x32xf16, #blk_tdesc, #sg_map_16>):
+    %0 = xegpu.load_nd %arg0 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+       : !xegpu.tensor_desc<24x32xf16, #blk_tdesc, #sg_map_16> -> vector<24x32xf16>
+    vector.yield %0 : vector<24x32xf16>
+  }
+  return %r : vector<24x2xf16>
+}
+
+// -----
+
+#sg_map_16 = #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>
+#blk_tdesc = #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>
+
+// CHECK-LABEL: test_create_nd_desc_distribution
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[laneid:.*]] = gpu.lane_id
+// CHECK: %[[res:.*]]:2 = vector.warp_execute_on_lane_0(%[[laneid]])[16] args(%{{.*}} : memref<24x32xf16>)
+// CHECK-SAME: -> (!xegpu.tensor_desc<24x2xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>, memref<24x32xf16>)
+// CHECK: ^bb0(%[[dst:.*]]: memref<24x32xf16>)
+// CHECK: %[[dead:.*]] = xegpu.create_nd_tdesc
+// CHECK: vector.yield %[[dead]], %[[dst]] :
+// CHECK-SAME: !xegpu.tensor_desc<24x32xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>, memref<24x32xf16>
+// CHECK: %[[view:.*]] = memref.subview %[[res]]#1[%[[C0]], %[[laneid]]] [24, 2] [1, 1] : memref<24x32xf16> to memref<24x2xf16, strided<[32, 1], offset: ?>>
+// CHECK: %[[desc:.*]] = xegpu.create_nd_tdesc %[[view]][0, 0] : memref<24x2xf16, strided<[32, 1], offset: ?>>
+// CHECK-SAME: -> !xegpu.tensor_desc<24x2xf16, #xegpu.block_tdesc_attr<memory_space =  global, array_length = 1 : i64, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+// CHECK: return %[[desc]]
+
+func.func @test_create_nd_desc_distribution(%dst: memref<24x32xf16>) -> (!xegpu.tensor_desc<24x2xf16, #blk_tdesc, #sg_map_16>) {
+  %laneid = gpu.lane_id
+  %r = vector.warp_execute_on_lane_0(%laneid)[16]
+        args(%dst: memref<24x32xf16>) -> (!xegpu.tensor_desc<24x2xf16, #blk_tdesc, #sg_map_16>) {
+    ^bb0(%arg0: memref<24x32xf16>):
+    %0 = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<24x32xf16, #blk_tdesc, #sg_map_16>
+    vector.yield %0 : !xegpu.tensor_desc<24x32xf16, #blk_tdesc, #sg_map_16>
+  }
+  return %r : !xegpu.tensor_desc<24x2xf16, #blk_tdesc, #sg_map_16>
+}
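
The subview offsets in the last test follow from the lane arithmetic in
WarpOpTensorDescOp: with wi_layout = [1, 16], layout[0] is 1, so
shiftx = laneid % 1 = 0 and shifty = laneid / 1 = laneid. Before constant
folding, the rewrite produces IR along these lines (a sketch under that
assumption, not checked pass output):

  %c1 = arith.constant 1 : index              // layout[0]
  %shiftx = arith.remui %laneid, %c1 : index  // always 0 here
  %shifty = arith.divui %laneid, %c1 : index  // equals %laneid
  %view = memref.subview %src[%shiftx, %shifty] [24, 2] [1, 1]
      : memref<24x32xf16> to memref<24x2xf16, strided<[32, 1], offset: ?>>

which folds to the memref.subview %r#1[%c0, %laneid] form checked above.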
diff --git a/mlir/test/lib/Dialect/CMakeLists.txt b/mlir/test/lib/Dialect/CMakeLists.txt
index 29fb4441a24fd2..a8fd70e6397a52 100644
--- a/mlir/test/lib/Dialect/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/CMakeLists.txt
@@ -22,3 +22,4 @@ add_subdirectory(TestDyn)
 add_subdirectory(Tosa)
 add_subdirectory(Transform)
 add_subdirectory(Vector)
+add_subdirectory(XeGPU)
diff --git a/mlir/test/lib/Dialect/XeGPU/CMakeLists.txt b/mlir/test/lib/Dialect/XeGPU/CMakeLists.txt
new file mode 100644
index 00000000000000..c8fe0db5f6213a
--- /dev/null
+++ b/mlir/test/lib/Dialect/XeGPU/CMakeLists.txt
@@ -0,0 +1,17 @@
+# Exclude tests from libMLIR.so
+add_mlir_library(MLIRXeGPUTestPasses
+  TestXeGPUTransforms.cpp
+
+  EXCLUDE_FROM_LIBMLIR
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/XeGPU
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRPass
+  MLIRXeGPUTransforms
+  MLIRXeGPUDialect
+  MLIRSupport
+  )
+
diff --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
new file mode 100644
index 00000000000000..eda68b83748139
--- /dev/null
+++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
@@ -0,0 +1,58 @@
+//===- TestXeGPUTransforms.cpp - Test XeGPU transforms and lowerings ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace mlir::xegpu;
+using namespace mlir::vector;
+
+namespace {
+struct TestXeGPUDistribution
+    : public PassWrapper<TestXeGPUDistribution, OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestXeGPUDistribution)
+
+  TestXeGPUDistribution() = default;
+  TestXeGPUDistribution(const TestXeGPUDistribution &pass)
+      : PassWrapper(pass) {}
+
+  StringRef getArgument() const final { return "test-xegpu-distribute"; }
+  StringRef getDescription() const final {
+    return "Test patterns for XeGPU work item distribution";
+  }
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<xegpu::XeGPUDialect>();
+    registry.insert<vector::VectorDialect>();
+    registry.insert<arith::ArithDialect>();
+    registry.insert<memref::MemRefDialect>();
+  }
+
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    populateXeGPUDistributePatterns(patterns);
+    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+  }
+};
+} // namespace
+
+namespace mlir {
+namespace test {
+void registerTestXeGPUTransforms() {
+  PassRegistration<TestXeGPUDistribution>();
+}
+} // namespace test
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/CMakeLists.txt b/mlir/tools/mlir-opt/CMakeLists.txt
index 8b79de58fa1028..e4ffbbee7a1d94 100644
--- a/mlir/tools/mlir-opt/CMakeLists.txt
+++ b/mlir/tools/mlir-opt/CMakeLists.txt
@@ -47,6 +47,7 @@ if(MLIR_INCLUDE_TESTS)
     MLIRTilingInterfaceTestPasses
     MLIRVectorTestPasses
     MLIRTestVectorToSPIRV
+    MLIRXeGPUTestPasses
     MLIRLLVMTestPasses
     )
   set(test_libs ${test_libs}
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 36b142484bb04a..b53e9513b0598d 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -151,6 +151,7 @@ void registerTestTransformDialectEraseSchedulePass();
 void registerTestPassStateExtensionCommunication();
 void registerTestVectorLowerings();
 void registerTestVectorReductionToSPIRVDotProd();
+void registerTestXeGPUTransforms();
 void registerTestWrittenToPass();
 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
 void registerTestDialectConversionPasses();
@@ -286,6 +287,7 @@ void registerTestPasses() {
   mlir::test::registerTestTransformDialectEraseSchedulePass();
   mlir::test::registerTestPassStateExtensionCommunication();
   mlir::test::registerTestVectorLowerings();
+  mlir::test::registerTestXeGPUTransforms();
   mlir::test::registerTestVectorReductionToSPIRVDotProd();
   mlir::test::registerTestWrittenToPass();
 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH

>From f907be64e1bb4af07df070c6e3d49e989ec6c6db Mon Sep 17 00:00:00 2001
From: Petr Kurapov <petr.a.kurapov at intel.com>
Date: Tue, 29 Oct 2024 09:23:01 +0000
Subject: [PATCH 4/4] Remove duplicate comments

---
 mlir/lib/Dialect/Vector/Utils/VectorDistributeUtils.cpp | 5 -----
 1 file changed, 5 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Utils/VectorDistributeUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorDistributeUtils.cpp
index f41581c6d47f2a..91dac58bacf66c 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorDistributeUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorDistributeUtils.cpp
@@ -14,8 +14,6 @@
 
 using namespace mlir;
 
-/// Return a value yielded by `warpOp` which satisfies the filter lambda
-/// condition and is not dead.
 mlir::OpOperand *
 mlir::getWarpResult(vector::WarpExecuteOnLane0Op warpOp,
                     const std::function<bool(Operation *)> &fn) {
@@ -32,7 +30,6 @@ mlir::getWarpResult(vector::WarpExecuteOnLane0Op warpOp,
   return {};
 }
 
-/// Helper to create a new WarpExecuteOnLane0Op with different signature.
 vector::WarpExecuteOnLane0Op mlir::moveRegionToNewWarpOpAndReplaceReturns(
     RewriterBase &rewriter, vector::WarpExecuteOnLane0Op warpOp,
     ValueRange newYieldedValues, TypeRange newReturnTypes) {
@@ -59,8 +56,6 @@ vector::WarpExecuteOnLane0Op mlir::moveRegionToNewWarpOpAndReplaceReturns(
   return newWarpOp;
 }
 
-/// Helper to create a new WarpExecuteOnLane0Op region with extra outputs.
-/// `indices` returns the index of each new output.
 vector::WarpExecuteOnLane0Op mlir::moveRegionToNewWarpOpAndAppendReturns(
     RewriterBase &rewriter, vector::WarpExecuteOnLane0Op warpOp,
     ValueRange newYieldedValues, TypeRange newReturnTypes,


