[Mlir-commits] [mlir] [MLIR][XeGPU] Xegpu distribution patterns for load_nd, store_nd, and create_nd_tdesc. (PR #112945)
Petr Kurapov
llvmlistbot at llvm.org
Tue Oct 29 02:23:22 PDT 2024
https://github.com/kurapov-peter updated https://github.com/llvm/llvm-project/pull/112945
>From d292605dbd568e946efeac0260e7d16d416ac95c Mon Sep 17 00:00:00 2001
From: Petr Kurapov <petr.a.kurapov at intel.com>
Date: Thu, 17 Oct 2024 17:29:00 +0000
Subject: [PATCH 1/4] [MLIR][Vector] Allow any shaped type to be distributed
for vector.warp_execute_on_lane_0's return values
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 12 ++++++------
mlir/test/Dialect/Vector/invalid.mlir | 6 +++---
2 files changed, 9 insertions(+), 9 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index a2abe1619454f2..51d3691fd107ae 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6558,14 +6558,14 @@ static LogicalResult verifyDistributedType(Type expanded, Type distributed,
// If the types matches there is no distribution.
if (expanded == distributed)
return success();
- auto expandedVecType = llvm::dyn_cast<VectorType>(expanded);
- auto distributedVecType = llvm::dyn_cast<VectorType>(distributed);
+ auto expandedVecType = llvm::dyn_cast<ShapedType>(expanded);
+ auto distributedVecType = llvm::dyn_cast<ShapedType>(distributed);
if (!expandedVecType || !distributedVecType)
- return op->emitOpError("expected vector type for distributed operands.");
+ return op->emitOpError("expected shaped type for distributed operands.");
if (expandedVecType.getRank() != distributedVecType.getRank() ||
expandedVecType.getElementType() != distributedVecType.getElementType())
return op->emitOpError(
- "expected distributed vectors to have same rank and element type.");
+ "expected distributed types to have same rank and element type.");
SmallVector<int64_t> scales(expandedVecType.getRank(), 1);
for (int64_t i = 0, e = expandedVecType.getRank(); i < e; i++) {
@@ -6575,8 +6575,8 @@ static LogicalResult verifyDistributedType(Type expanded, Type distributed,
continue;
if (eDim % dDim != 0)
return op->emitOpError()
- << "expected expanded vector dimension #" << i << " (" << eDim
- << ") to be a multipler of the distributed vector dimension ("
+ << "expected expanded type dimension #" << i << " (" << eDim
+ << ") to be a multipler of the distributed type dimension ("
<< dDim << ")";
scales[i] = eDim / dDim;
}
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 36d04bb77e3b96..69346574177929 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1665,7 +1665,7 @@ func.func @warp_2_distributed_dims(%laneid: index) {
// -----
func.func @warp_2_distributed_dims(%laneid: index) {
- // expected-error at +1 {{expected expanded vector dimension #1 (8) to be a multipler of the distributed vector dimension (3)}}
+ // expected-error at +1 {{expected expanded type dimension #1 (8) to be a multipler of the distributed type dimension (3)}}
%2 = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<1x3xi32>) {
%0 = arith.constant dense<2>: vector<4x8xi32>
vector.yield %0 : vector<4x8xi32>
@@ -1676,7 +1676,7 @@ func.func @warp_2_distributed_dims(%laneid: index) {
// -----
func.func @warp_mismatch_rank(%laneid: index) {
- // expected-error at +1 {{'vector.warp_execute_on_lane_0' op expected distributed vectors to have same rank and element type.}}
+ // expected-error at +1 {{'vector.warp_execute_on_lane_0' op expected distributed types to have same rank and element type.}}
%2 = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<4x4xi32>) {
%0 = arith.constant dense<2>: vector<128xi32>
vector.yield %0 : vector<128xi32>
@@ -1687,7 +1687,7 @@ func.func @warp_mismatch_rank(%laneid: index) {
// -----
func.func @warp_mismatch_rank(%laneid: index) {
- // expected-error at +1 {{'vector.warp_execute_on_lane_0' op expected vector type for distributed operands.}}
+ // expected-error at +1 {{'vector.warp_execute_on_lane_0' op expected shaped type for distributed operands.}}
%2 = vector.warp_execute_on_lane_0(%laneid)[32] -> (i32) {
%0 = arith.constant dense<2>: vector<128xi32>
vector.yield %0 : vector<128xi32>
>From 7627cf747d98c33ce24e585f4b695962d681139d Mon Sep 17 00:00:00 2001
From: Petr Kurapov <petr.a.kurapov at intel.com>
Date: Fri, 18 Oct 2024 15:50:31 +0000
Subject: [PATCH 2/4] [MLIR][Vector] Move helper functions to vector
distribution utils
---
.../mlir/Dialect/Vector/Utils/VectorUtils.h | 20 ++++
.../Vector/Transforms/VectorDistribute.cpp | 80 +---------------
mlir/lib/Dialect/Vector/Utils/CMakeLists.txt | 1 +
.../Vector/Utils/VectorDistributeUtils.cpp | 96 +++++++++++++++++++
4 files changed, 118 insertions(+), 79 deletions(-)
create mode 100644 mlir/lib/Dialect/Vector/Utils/VectorDistributeUtils.cpp
diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
index 5f32aca88a2734..6bd924307376dc 100644
--- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
@@ -20,6 +20,8 @@
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/TypeSwitch.h"
+#include <utility>
+
namespace mlir {
// Forward declarations.
@@ -324,6 +326,24 @@ namespace matcher {
bool operatesOnSuperVectorsOf(Operation &op, VectorType subVectorType);
} // namespace matcher
+
+/// Return the operand of `warpOp`'s vector.yield whose defining op satisfies
+/// the filter lambda `fn` and whose corresponding warp result has at least one
+/// use. Returns nullptr if no such operand exists.
+OpOperand *getWarpResult(vector::WarpExecuteOnLane0Op warpOp,
+ const std::function<bool(Operation *)> &fn);
+
+/// Create a new WarpExecuteOnLane0Op right before `warpOp` with result types
+/// `newReturnTypes`, move the body of `warpOp` into it and make its terminator
+/// yield exactly `newYieldedValues`. `warpOp` itself is neither replaced nor
+/// erased; that is left to the caller.
+vector::WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns(
+ RewriterBase &rewriter, vector::WarpExecuteOnLane0Op warpOp,
+ ValueRange newYieldedValues, TypeRange newReturnTypes);
+
+/// Create a new WarpExecuteOnLane0Op that yields everything `warpOp` yields
+/// plus `newYieldedValues` (with result types `newReturnTypes`), and replace
+/// `warpOp` with the leading results of the new op. A value that is already
+/// yielded is not duplicated (its entry in `newReturnTypes` is ignored). For
+/// every entry of `newYieldedValues`, the index of the new op's result that
+/// carries it is appended to `indices`.
+vector::WarpExecuteOnLane0Op moveRegionToNewWarpOpAndAppendReturns(
+ RewriterBase &rewriter, vector::WarpExecuteOnLane0Op warpOp,
+ ValueRange newYieldedValues, TypeRange newReturnTypes,
+ llvm::SmallVector<size_t> &indices);
+
} // namespace mlir
#endif // MLIR_DIALECT_VECTOR_UTILS_VECTORUTILS_H_
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 2289fd1ff1364e..c80c3179b5e025 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -12,6 +12,7 @@
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Transforms/RegionUtils.h"
@@ -160,68 +161,6 @@ struct DistributedLoadStoreHelper {
} // namespace
-/// Helper to create a new WarpExecuteOnLane0Op with different signature.
-static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns(
- RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
- ValueRange newYieldedValues, TypeRange newReturnTypes) {
- // Create a new op before the existing one, with the extra operands.
- OpBuilder::InsertionGuard g(rewriter);
- rewriter.setInsertionPoint(warpOp);
- auto newWarpOp = rewriter.create<WarpExecuteOnLane0Op>(
- warpOp.getLoc(), newReturnTypes, warpOp.getLaneid(), warpOp.getWarpSize(),
- warpOp.getArgs(), warpOp.getBody()->getArgumentTypes());
-
- Region &opBody = warpOp.getBodyRegion();
- Region &newOpBody = newWarpOp.getBodyRegion();
- Block &newOpFirstBlock = newOpBody.front();
- rewriter.inlineRegionBefore(opBody, newOpBody, newOpBody.begin());
- rewriter.eraseBlock(&newOpFirstBlock);
- assert(newWarpOp.getWarpRegion().hasOneBlock() &&
- "expected WarpOp with single block");
-
- auto yield =
- cast<vector::YieldOp>(newOpBody.getBlocks().begin()->getTerminator());
-
- rewriter.modifyOpInPlace(
- yield, [&]() { yield.getOperandsMutable().assign(newYieldedValues); });
- return newWarpOp;
-}
-
-/// Helper to create a new WarpExecuteOnLane0Op region with extra outputs.
-/// `indices` return the index of each new output.
-static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndAppendReturns(
- RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
- ValueRange newYieldedValues, TypeRange newReturnTypes,
- llvm::SmallVector<size_t> &indices) {
- SmallVector<Type> types(warpOp.getResultTypes().begin(),
- warpOp.getResultTypes().end());
- auto yield = cast<vector::YieldOp>(
- warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
- llvm::SmallSetVector<Value, 32> yieldValues(yield.getOperands().begin(),
- yield.getOperands().end());
- for (auto newRet : llvm::zip(newYieldedValues, newReturnTypes)) {
- if (yieldValues.insert(std::get<0>(newRet))) {
- types.push_back(std::get<1>(newRet));
- indices.push_back(yieldValues.size() - 1);
- } else {
- // If the value already exit the region don't create a new output.
- for (auto [idx, yieldOperand] :
- llvm::enumerate(yieldValues.getArrayRef())) {
- if (yieldOperand == std::get<0>(newRet)) {
- indices.push_back(idx);
- break;
- }
- }
- }
- }
- yieldValues.insert(newYieldedValues.begin(), newYieldedValues.end());
- WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
- rewriter, warpOp, yieldValues.getArrayRef(), types);
- rewriter.replaceOp(warpOp,
- newWarpOp.getResults().take_front(warpOp.getNumResults()));
- return newWarpOp;
-}
-
/// Helper to know if an op can be hoisted out of the region.
static bool canBeHoisted(Operation *op,
function_ref<bool(Value)> definedOutside) {
@@ -229,23 +168,6 @@ static bool canBeHoisted(Operation *op,
isMemoryEffectFree(op) && op->getNumRegions() == 0;
}
-/// Return a value yielded by `warpOp` which statifies the filter lamdba
-/// condition and is not dead.
-static OpOperand *getWarpResult(WarpExecuteOnLane0Op warpOp,
- const std::function<bool(Operation *)> &fn) {
- auto yield = cast<vector::YieldOp>(
- warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
- for (OpOperand &yieldOperand : yield->getOpOperands()) {
- Value yieldValues = yieldOperand.get();
- Operation *definedOp = yieldValues.getDefiningOp();
- if (definedOp && fn(definedOp)) {
- if (!warpOp.getResult(yieldOperand.getOperandNumber()).use_empty())
- return &yieldOperand;
- }
- }
- return {};
-}
-
// Clones `op` into a new operation that takes `operands` and returns
// `resultTypes`.
static Operation *cloneOpWithOperandsAndTypes(RewriterBase &rewriter,
diff --git a/mlir/lib/Dialect/Vector/Utils/CMakeLists.txt b/mlir/lib/Dialect/Vector/Utils/CMakeLists.txt
index fa3971695d4bf2..9db0c172fec5ce 100644
--- a/mlir/lib/Dialect/Vector/Utils/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Utils/CMakeLists.txt
@@ -1,5 +1,6 @@
add_mlir_dialect_library(MLIRVectorUtils
VectorUtils.cpp
+ VectorDistributeUtils.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Vector/Utils
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorDistributeUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorDistributeUtils.cpp
new file mode 100644
index 00000000000000..f41581c6d47f2a
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/Utils/VectorDistributeUtils.cpp
@@ -0,0 +1,96 @@
+//===- VectorDistributeUtils.cpp - MLIR Utilities VectorOps distribution -===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements utility methods for working with the Vector dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+
+using namespace mlir;
+
+/// Return the operand of `warpOp`'s terminator whose defining op satisfies the
+/// filter lambda `fn` and whose matching warp result is still used. Returns
+/// nullptr when no such operand exists.
+mlir::OpOperand *
+mlir::getWarpResult(vector::WarpExecuteOnLane0Op warpOp,
+                    const std::function<bool(Operation *)> &fn) {
+  Block &body = warpOp.getBodyRegion().front();
+  auto yieldOp = cast<vector::YieldOp>(body.getTerminator());
+  for (mlir::OpOperand &candidate : yieldOp->getOpOperands()) {
+    Operation *producer = candidate.get().getDefiningOp();
+    // Block arguments have no producer; also skip ops rejected by the filter.
+    if (!producer || !fn(producer))
+      continue;
+    // A warp result without uses is dead; keep looking.
+    if (warpOp.getResult(candidate.getOperandNumber()).use_empty())
+      continue;
+    return &candidate;
+  }
+  return nullptr;
+}
+
+/// Create a new WarpExecuteOnLane0Op in front of `warpOp` whose results have
+/// types `newReturnTypes`, move `warpOp`'s body into it and make the moved
+/// terminator yield `newYieldedValues`. `warpOp` is left in place.
+vector::WarpExecuteOnLane0Op mlir::moveRegionToNewWarpOpAndReplaceReturns(
+    RewriterBase &rewriter, vector::WarpExecuteOnLane0Op warpOp,
+    ValueRange newYieldedValues, TypeRange newReturnTypes) {
+  // Build the replacement right before the old op; the guard restores the
+  // caller's insertion point on exit.
+  OpBuilder::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPoint(warpOp);
+  auto replacement = rewriter.create<vector::WarpExecuteOnLane0Op>(
+      warpOp.getLoc(), newReturnTypes, warpOp.getLaneid(), warpOp.getWarpSize(),
+      warpOp.getArgs(), warpOp.getBody()->getArgumentTypes());
+
+  // The freshly built op already owns an entry block: move the old body in
+  // front of it, then drop that placeholder block.
+  Region &srcRegion = warpOp.getBodyRegion();
+  Region &dstRegion = replacement.getBodyRegion();
+  Block &placeholder = dstRegion.front();
+  rewriter.inlineRegionBefore(srcRegion, dstRegion, dstRegion.begin());
+  rewriter.eraseBlock(&placeholder);
+  assert(replacement.getWarpRegion().hasOneBlock() &&
+         "expected WarpOp with single block");
+
+  // Rewire the moved terminator so it yields exactly `newYieldedValues`.
+  auto terminator = cast<vector::YieldOp>(dstRegion.front().getTerminator());
+  rewriter.modifyOpInPlace(terminator, [&]() {
+    terminator.getOperandsMutable().assign(newYieldedValues);
+  });
+  return replacement;
+}
+
+/// Helper to create a new WarpExecuteOnLane0Op region with extra outputs.
+/// The new op yields everything `warpOp` yields followed by those entries of
+/// `newYieldedValues` that are not yielded yet (typed by the matching entry of
+/// `newReturnTypes`); `warpOp` is replaced by the leading results of the new
+/// op. For each entry of `newYieldedValues`, the index of the result that
+/// carries it is appended to `indices`. `newYieldedValues` and
+/// `newReturnTypes` must have the same length.
+vector::WarpExecuteOnLane0Op mlir::moveRegionToNewWarpOpAndAppendReturns(
+    RewriterBase &rewriter, vector::WarpExecuteOnLane0Op warpOp,
+    ValueRange newYieldedValues, TypeRange newReturnTypes,
+    llvm::SmallVector<size_t> &indices) {
+  SmallVector<Type> types(warpOp.getResultTypes().begin(),
+                          warpOp.getResultTypes().end());
+  auto yield = cast<vector::YieldOp>(
+      warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+  llvm::SmallSetVector<Value, 32> yieldValues(yield.getOperands().begin(),
+                                              yield.getOperands().end());
+  // zip_equal (unlike zip) asserts that every new value has exactly one type;
+  // a silently truncated zip would leave yielded values without result types.
+  for (auto [value, type] :
+       llvm::zip_equal(newYieldedValues, newReturnTypes)) {
+    if (yieldValues.insert(value)) {
+      types.push_back(type);
+      indices.push_back(yieldValues.size() - 1);
+      continue;
+    }
+    // The value already exits the region: reuse its existing result instead
+    // of creating a duplicate output.
+    ArrayRef<Value> yielded = yieldValues.getArrayRef();
+    indices.push_back(
+        std::distance(yielded.begin(), llvm::find(yielded, value)));
+  }
+  // Every new value has been inserted by the loop above, so `yieldValues`
+  // already lists the complete set of values to yield.
+  vector::WarpExecuteOnLane0Op newWarpOp =
+      moveRegionToNewWarpOpAndReplaceReturns(rewriter, warpOp,
+                                             yieldValues.getArrayRef(), types);
+  rewriter.replaceOp(warpOp,
+                     newWarpOp.getResults().take_front(warpOp.getNumResults()));
+  return newWarpOp;
+}
>From de3ae899d02f6f151cdbd9badcd47c6f1df39861 Mon Sep 17 00:00:00 2001
From: Petr Kurapov <petr.a.kurapov at intel.com>
Date: Fri, 18 Oct 2024 16:18:32 +0000
Subject: [PATCH 3/4] [MLIR][XeGPU] Add distribution patterns for load_nd,
store_nd, and create_nd_tdesc
---
.../Dialect/XeGPU/Transforms/Transforms.h | 1 +
.../Dialect/XeGPU/Transforms/CMakeLists.txt | 5 +
.../XeGPU/Transforms/XeGPUDistribute.cpp | 393 ++++++++++++++++++
mlir/test/Dialect/XeGPU/xegpu-distribute.mlir | 81 ++++
mlir/test/lib/Dialect/CMakeLists.txt | 1 +
mlir/test/lib/Dialect/XeGPU/CMakeLists.txt | 17 +
.../lib/Dialect/XeGPU/TestXeGPUTransforms.cpp | 58 +++
mlir/tools/mlir-opt/CMakeLists.txt | 1 +
mlir/tools/mlir-opt/mlir-opt.cpp | 2 +
9 files changed, 559 insertions(+)
create mode 100644 mlir/lib/Dialect/XeGPU/Transforms/XeGPUDistribute.cpp
create mode 100644 mlir/test/Dialect/XeGPU/xegpu-distribute.mlir
create mode 100644 mlir/test/lib/Dialect/XeGPU/CMakeLists.txt
create mode 100644 mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h b/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
index 63ea26df069372..fe5198d1ac6dba 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
@@ -16,6 +16,7 @@ namespace xegpu {
/// Appends patterns for folding aliasing ops into XeGPU ops into `patterns`.
void populateXeGPUFoldAliasOpsPatterns(RewritePatternSet &patterns);
+/// Appends patterns that move xegpu.create_nd_tdesc, xegpu.load_nd and
+/// xegpu.store_nd ops out of a vector.warp_execute_on_lane_0 region, using the
+/// tensor descriptor's sg_map attribute to compute per-work-item types, into
+/// `patterns`.
+void populateXeGPUDistributePatterns(RewritePatternSet &patterns);
} // namespace xegpu
} // namespace mlir
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
index 7fb64d3b97b87d..148ff46ba41b72 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
add_mlir_dialect_library(MLIRXeGPUTransforms
XeGPUFoldAliasOps.cpp
+ XeGPUDistribute.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/XeGPU
@@ -12,6 +13,10 @@ add_mlir_dialect_library(MLIRXeGPUTransforms
MLIRIR
MLIRMemRefDialect
MLIRXeGPUDialect
+ MLIRVectorDialect
+ MLIRVectorUtils
+ MLIRArithDialect
+ MLIRFuncDialect
MLIRPass
MLIRTransforms
)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUDistribute.cpp
new file mode 100644
index 00000000000000..78a010ff1c941b
--- /dev/null
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUDistribute.cpp
@@ -0,0 +1,393 @@
+//===- XeGPUDistribute.cpp - XeGPU distribute ops to work items -*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "xegpu-distribute"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+
+using namespace mlir;
+
+namespace {
+/// Returns true if `lhs` is an exact multiple of `rhs`, treating both as
+/// unsigned integers. A zero `rhs` is not guarded against here.
+bool divisible(APInt lhs, APInt rhs) {
+  APInt remainder = lhs.urem(rhs);
+  return remainder.isZero();
+}
+
+/// Clone a create_nd_tdesc feeding into vector.yield op for the enclosing
+/// `vector.warp_execute_on_lane_0` and put it after the warp op.
+/// The warp op will still contain the original op that will not be used by the
+/// yield op (and should be cleaned up later with DCE). The warp op is extended
+/// to also yield the create_nd_tdesc's source memref.
+/// The rewrite creates a subview of the size used by a single work item at the
+/// lane-specific offset. The distributed create_nd_tdesc points into that
+/// subview with zero offsets. The tensor descriptor type is distributed
+/// according to the sg_map attribute.
+///
+/// Example:
+///
+/// ```
+/// #sg_map_8 = #xegpu.sg_map<wi_layout = [1, 8], wi_data = [1, 1]>
+/// %r = vector.warp_execute_on_lane_0(%laneid) ->
+///                   (!xegpu.tensor_desc<4x1xf32, #sg_map_8>) {
+///   ...
+///   %td = xegpu.create_nd_tdesc %arg0[0, 0]
+///               : memref<4x8xf32> -> !xegpu.tensor_desc<4x8xf32, #sg_map_8>
+///   vector.yield %td
+/// }
+/// ```
+/// To
+/// ```
+/// %r:2 = vector.warp_execute_on_lane_0(%laneid) ->
+///     (!xegpu.tensor_desc<4x1xf32, #sg_map_8>, memref<4x8xf32>) {
+///   ...
+///   %dead = xegpu.create_nd_tdesc %arg0[0, 0]
+///               : memref<4x8xf32> -> !xegpu.tensor_desc<4x8xf32, #sg_map_8>
+///   vector.yield %dead, %arg0
+/// }
+/// %view = memref.subview %r#1[0, %laneid] [4, 1] [1, 1]
+///     : memref<4x8xf32> to memref<4x1xf32, strided<[8, 1], offset: ?>>
+/// %td = xegpu.create_nd_tdesc %view[0, 0]
+///     : memref<4x1xf32, strided<[8, 1], offset: ?>>
+///     -> !xegpu.tensor_desc<4x1xf32, #sg_map_8>
+/// ```
+struct WarpOpTensorDescOp final
+    : public OpRewritePattern<vector::WarpExecuteOnLane0Op> {
+  using OpRewritePattern<vector::WarpExecuteOnLane0Op>::OpRewritePattern;
+  LogicalResult matchAndRewrite(vector::WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override;
+};
+
+/// Sink a store_nd that immediately precedes the vector.yield op of the
+/// enclosing `vector.warp_execute_on_lane_0` out of the warp op. The store's
+/// tensor descriptor and value are appended to the warp op's results (values
+/// that are already yielded are reused), and the store is recreated after the
+/// warp op on those results.
+/// Both the stored vector type and tensor descriptor types are distributed
+/// according to the sg_map attribute.
+///
+/// Example:
+///
+/// ```
+/// #sg_map_8 = #xegpu.sg_map<wi_layout = [1, 8], wi_data = [1, 1]>
+/// vector.warp_execute_on_lane_0(%laneid) -> () {
+///   ...
+///   xegpu.store_nd %arg0, %arg1 : vector<4x8xf32>,
+///                                 !xegpu.tensor_desc<4x8xf32, #sg_map_8>
+///   vector.yield
+/// }
+/// ```
+/// To
+/// ```
+/// %r:2 = vector.warp_execute_on_lane_0(%laneid) ->
+///     (!xegpu.tensor_desc<4x1xf32, #sg_map_8>, vector<4x1xf32>) {
+///   ...
+///   vector.yield %arg1, %arg0 : !xegpu.tensor_desc<4x8xf32, #sg_map_8>,
+///                               vector<4x8xf32>
+/// }
+/// xegpu.store_nd %r#1, %r#0 : vector<4x1xf32>,
+///                             !xegpu.tensor_desc<4x1xf32, #sg_map_8>
+/// ```
+struct WarpOpStoreNd final
+    : public OpRewritePattern<vector::WarpExecuteOnLane0Op> {
+  using OpRewritePattern<vector::WarpExecuteOnLane0Op>::OpRewritePattern;
+  LogicalResult matchAndRewrite(vector::WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override;
+};
+
+/// Clone a load_nd feeding into vector.yield op for the enclosing
+/// `vector.warp_execute_on_lane_0` and put it after the warp op.
+/// The warp op will still contain the original op that will not be used by the
+/// yield op (and should be cleaned up later with DCE). The warp op is extended
+/// to also yield the load's tensor descriptor, and uses of the original warp
+/// result are replaced by the new load.
+/// Both the loaded vector type and tensor descriptor types are distributed
+/// according to the sg_map attribute.
+///
+/// Example:
+///
+/// ```
+/// #sg_map_8 = #xegpu.sg_map<wi_layout = [1, 8], wi_data = [1, 1]>
+/// %r = vector.warp_execute_on_lane_0(%laneid) -> (vector<4x1xf32>) {
+///   ...
+///   %ld = xegpu.load_nd %arg0 : !xegpu.tensor_desc<4x8xf32, #sg_map_8>
+///                               -> vector<4x8xf32>
+///   vector.yield %ld : vector<4x8xf32>
+/// }
+/// ```
+/// To
+/// ```
+/// %r:2 = vector.warp_execute_on_lane_0(%laneid) ->
+///     (vector<4x1xf32>, !xegpu.tensor_desc<4x1xf32, #sg_map_8>) {
+///   ...
+///   %dead = xegpu.load_nd %arg0 : !xegpu.tensor_desc<4x8xf32, #sg_map_8>
+///                                 -> vector<4x8xf32>
+///   vector.yield %dead, %arg0
+///       : vector<4x8xf32>, !xegpu.tensor_desc<4x8xf32, #sg_map_8>
+/// }
+/// %ld = xegpu.load_nd %r#1 : !xegpu.tensor_desc<4x1xf32, #sg_map_8>
+///                            -> vector<4x1xf32>
+/// ```
+struct WarpOpLoadNd final
+    : public OpRewritePattern<vector::WarpExecuteOnLane0Op> {
+  using OpRewritePattern<vector::WarpExecuteOnLane0Op>::OpRewritePattern;
+  LogicalResult matchAndRewrite(vector::WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override;
+};
+
+/// Compute the per-work-item vector type for `originalT` by dividing every
+/// dimension by the matching `wi_layout` entry of `sgMap`. Element type and
+/// scalable dims are preserved. Fails if the layout rank differs from the
+/// vector rank, a layout entry is zero, or a dimension is not evenly
+/// divisible by its layout entry.
+FailureOr<VectorType> getDistributedVectorType(VectorType originalT,
+                                               xegpu::SGMapAttr sgMap) {
+  auto layout = sgMap.getWiLayout();
+  auto shape = originalT.getShape();
+  // llvm::zip_equal asserts on mismatched lengths; report a failure instead.
+  if (layout.size() != shape.size())
+    return failure();
+  llvm::SmallVector<int64_t, 2> distributedShape;
+  for (const auto [l, o] : llvm::zip_equal(layout, shape)) {
+    // A zero layout entry would make the divisibility check divide by zero.
+    if (l == 0 || !divisible(APInt(64, o), APInt(64, l)))
+      return failure();
+    distributedShape.push_back(o / l);
+  }
+  return VectorType::get(distributedShape, originalT.getElementType(),
+                         originalT.getScalableDims());
+}
+
+/// Compute the per-work-item tensor descriptor type for `originalT` by
+/// dividing every dimension by the matching `wi_layout` entry of `sgMap`.
+/// All other descriptor properties (element type, chunk size or array
+/// length / boundary check, memory space, sg_map) are copied from
+/// `originalT`. Fails on a rank mismatch, a zero layout entry, or a dimension
+/// that is not evenly divisible.
+/// NOTE: `memSpace` is currently unused; the memory space is taken from
+/// `originalT`.
+FailureOr<xegpu::TensorDescType>
+getDistributedTensorDescType(xegpu::TensorDescType originalT,
+                             xegpu::SGMapAttr sgMap,
+                             xegpu::MemorySpace memSpace) {
+  auto layout = sgMap.getWiLayout();
+  auto shape = originalT.getShape();
+  // llvm::zip_equal asserts on mismatched lengths; report a failure instead.
+  if (layout.size() != shape.size())
+    return failure();
+  llvm::SmallVector<int64_t, 2> distributedShape;
+  for (const auto [l, o] : llvm::zip_equal(layout, shape)) {
+    // A zero layout entry would make the divisibility check divide by zero.
+    if (l == 0 || !divisible(APInt(64, o), APInt(64, l)))
+      return failure();
+    distributedShape.push_back(o / l);
+  }
+  if (originalT.isScattered())
+    return xegpu::TensorDescType::get(
+        distributedShape, originalT.getElementType(), originalT.getChunkSize(),
+        originalT.getMemorySpace(), originalT.getSGMapAttr());
+  // The block-descriptor builder takes (array_length, boundary_check) in this
+  // order. Both convert implicitly between int and bool, so passing them
+  // swapped compiles but corrupts non-default values.
+  return xegpu::TensorDescType::get(
+      distributedShape, originalT.getElementType(),
+      /*array_length=*/originalT.getArrayLength(),
+      /*boundary_check=*/originalT.getBoundaryCheck(),
+      originalT.getMemorySpace(), originalT.getSGMapAttr());
+}
+} // namespace
+
+/// Sink the store_nd that immediately precedes the warp op's terminator to
+/// after the warp op. The store's descriptor and value are yielded with
+/// distributed types.
+LogicalResult
+WarpOpStoreNd::matchAndRewrite(vector::WarpExecuteOnLane0Op warpOp,
+                               PatternRewriter &rewriter) const {
+  auto yield = cast<vector::YieldOp>(
+      warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+  // Only a store_nd placed right before the terminator is handled.
+  Operation *lastNode = yield->getPrevNode();
+  auto storeOp = dyn_cast_or_null<xegpu::StoreNdOp>(lastNode);
+  if (!storeOp)
+    return failure();
+
+  xegpu::TensorDescType origType = storeOp.getTensorDescType();
+  xegpu::SGMapAttr sgMap = origType.getSGMapAttr();
+  if (!sgMap)
+    return rewriter.notifyMatchFailure(
+        storeOp, "the source tensor descriptor lacks sg_map attribute");
+
+  if (origType.getShape().size() != 2)
+    return rewriter.notifyMatchFailure(storeOp, "unsupported shape");
+  // Without LLVM_DEBUG this would be printed unconditionally on every match.
+  LLVM_DEBUG(DBGS() << "Matched store_nd: " << storeOp << "\n");
+
+  auto distributedTypeOrFailure =
+      getDistributedVectorType(storeOp.getValueType(), sgMap);
+  if (failed(distributedTypeOrFailure))
+    return rewriter.notifyMatchFailure(storeOp,
+                                       "Failed to distribute the type");
+  VectorType newVectorType = distributedTypeOrFailure.value();
+
+  auto distributedDescTypeOrFailure =
+      getDistributedTensorDescType(origType, sgMap, origType.getMemorySpace());
+  if (failed(distributedDescTypeOrFailure))
+    return rewriter.notifyMatchFailure(storeOp,
+                                       "Failed to distribute the desc type");
+  xegpu::TensorDescType newTDescType = distributedDescTypeOrFailure.value();
+
+  SmallVector<size_t> newRetIndices;
+  vector::WarpExecuteOnLane0Op newWarpOp =
+      moveRegionToNewWarpOpAndAppendReturns(
+          rewriter, warpOp,
+          ValueRange{storeOp.getTensorDesc(), storeOp.getValue()},
+          TypeRange{newTDescType, newVectorType}, newRetIndices);
+
+  // Recreate the store after the warp op and point it at the distributed
+  // values returned by the new warp op. The operand update goes through
+  // modifyOpInPlace so the rewriter (and its listeners) see the change.
+  rewriter.setInsertionPointAfter(newWarpOp);
+  auto newStoreOp =
+      cast<xegpu::StoreNdOp>(rewriter.clone(*storeOp.getOperation()));
+  rewriter.modifyOpInPlace(newStoreOp, [&]() {
+    newStoreOp.getTensorDescMutable().assign(
+        newWarpOp.getResult(newRetIndices[0]));
+    newStoreOp.getValueMutable().assign(newWarpOp.getResult(newRetIndices[1]));
+  });
+  rewriter.eraseOp(storeOp);
+
+  return success();
+}
+
+/// Move a load_nd whose result is yielded by the warp op to after the warp op.
+/// The load's tensor descriptor is yielded with a distributed type, and the
+/// new load produces the distributed vector directly.
+LogicalResult WarpOpLoadNd::matchAndRewrite(vector::WarpExecuteOnLane0Op warpOp,
+                                            PatternRewriter &rewriter) const {
+  OpOperand *operand = getWarpResult(warpOp, [](Operation *op) {
+    return isa<xegpu::LoadNdOp>(op) && op->hasOneUse();
+  });
+
+  if (!operand)
+    return rewriter.notifyMatchFailure(warpOp,
+                                       "warp result is not a xegpu::LoadNd op");
+
+  auto loadOp = operand->get().getDefiningOp<xegpu::LoadNdOp>();
+
+  if (loadOp.getPacked())
+    return rewriter.notifyMatchFailure(
+        loadOp, "Packed load distribution not supported");
+
+  xegpu::TensorDescType origType = loadOp.getTensorDescType();
+  xegpu::SGMapAttr sgMap = origType.getSGMapAttr();
+  if (!sgMap)
+    return rewriter.notifyMatchFailure(
+        loadOp, "the source tensor descriptor lacks sg_map attribute");
+
+  auto origShape = origType.getShape();
+  if (origShape.size() != 2)
+    return rewriter.notifyMatchFailure(loadOp, "unsupported shape");
+
+  auto distributedTypeOrFailure =
+      getDistributedVectorType(loadOp.getType(), sgMap);
+  if (failed(distributedTypeOrFailure))
+    return rewriter.notifyMatchFailure(loadOp, "Failed to distribute the type");
+  VectorType newVectorType = distributedTypeOrFailure.value();
+
+  auto distributedDescTypeOrFailure =
+      getDistributedTensorDescType(origType, sgMap, origType.getMemorySpace());
+  if (failed(distributedDescTypeOrFailure))
+    return rewriter.notifyMatchFailure(loadOp,
+                                       "Failed to distribute the desc type");
+  xegpu::TensorDescType newTDescType = distributedDescTypeOrFailure.value();
+
+  unsigned operandIdx = operand->getOperandNumber();
+  // The new load replaces every use of this warp result, so the result must
+  // already have the distributed vector type. Check before touching the IR.
+  if (warpOp.getResult(operandIdx).getType() != newVectorType)
+    return rewriter.notifyMatchFailure(
+        warpOp, "warp result type does not match the distributed load type");
+
+  SmallVector<size_t> newRetIndices;
+  vector::WarpExecuteOnLane0Op newWarpOp =
+      moveRegionToNewWarpOpAndAppendReturns(
+          rewriter, warpOp, loadOp.getTensorDesc(), TypeRange{newTDescType},
+          newRetIndices);
+
+  // Build the new load after the warp op directly on the distributed
+  // descriptor returned by the new warp op (no in-place operand patching).
+  rewriter.setInsertionPointAfter(newWarpOp);
+  Value distributedDesc = newWarpOp.getResult(newRetIndices[0]);
+  auto newLoadOp = rewriter.create<xegpu::LoadNdOp>(
+      loadOp.getLoc(), newVectorType, distributedDesc, loadOp.getPackedAttr(),
+      loadOp.getTransposeAttr(), loadOp.getL1HintAttr(),
+      loadOp.getL2HintAttr(), loadOp.getL3HintAttr());
+
+  Value distributedVal = newWarpOp.getResult(operandIdx);
+  rewriter.replaceAllUsesWith(distributedVal, newLoadOp.getResult());
+
+  return success();
+}
+
+/// Move a create_nd_tdesc whose result is yielded by the warp op to after the
+/// warp op. The new descriptor points into a lane-specific subview of the
+/// source memref, which is yielded from the warp op.
+LogicalResult
+WarpOpTensorDescOp::matchAndRewrite(vector::WarpExecuteOnLane0Op warpOp,
+                                    PatternRewriter &rewriter) const {
+  OpOperand *operand = getWarpResult(warpOp, [](Operation *op) {
+    return isa<xegpu::CreateNdDescOp>(op) && op->hasOneUse();
+  });
+
+  if (!operand)
+    return rewriter.notifyMatchFailure(
+        warpOp, "warp result is not a xegpu::CreateNdDesc op");
+  auto descOp = operand->get().getDefiningOp<xegpu::CreateNdDescOp>();
+  assert(descOp && "desc op must be not null");
+  unsigned operandIdx = operand->getOperandNumber();
+
+  // Every match check runs before any IR is created: a pattern that returns
+  // failure must leave the IR untouched.
+
+  // TODO: is memref uniform in the region
+  if (!isa<MemRefType>(descOp.getSource().getType()))
+    return rewriter.notifyMatchFailure(descOp,
+                                       "only memref sources are supported");
+
+  auto descOffsets = descOp.getMixedOffsets();
+  if (descOffsets.size() != 2)
+    return rewriter.notifyMatchFailure(descOp,
+                                       "offsets size is expected to be 2");
+
+  // The lane offsets are materialized in front of the warp op, so dynamic base
+  // offsets have to be defined outside of the warp region to be usable there.
+  Region &warpRegion = warpOp.getBodyRegion();
+  for (OpFoldResult offset : descOffsets) {
+    Value dynOffset = llvm::dyn_cast_if_present<Value>(offset);
+    if (dynOffset && warpRegion.isAncestor(dynOffset.getParentRegion()))
+      return rewriter.notifyMatchFailure(
+          descOp, "dynamic offsets defined inside the warp region");
+  }
+
+  xegpu::SGMapAttr sgMap = descOp.getType().getSGMapAttr();
+  if (!sgMap)
+    return rewriter.notifyMatchFailure(
+        descOp, "the tensor descriptor lacks sg_map attribute");
+
+  auto distributedDescTypeOrFailure = getDistributedTensorDescType(
+      descOp.getType(), sgMap, descOp.getType().getMemorySpace());
+  if (failed(distributedDescTypeOrFailure))
+    return rewriter.notifyMatchFailure(descOp,
+                                       "Failed to distribute the desc type");
+  xegpu::TensorDescType newTDescType = distributedDescTypeOrFailure.value();
+
+  auto layout = sgMap.getWiLayout();
+
+  // Calculate the offset within tensor descriptor for the current lane_id. The
+  // access to proper element for a work item is done through a lane-specific
+  // subview (tdesc offsets are used as base, lane shift is added on top).
+  rewriter.setInsertionPoint(warpOp);
+  auto laneid = warpOp.getLaneid();
+  auto xDim =
+      rewriter.create<arith::ConstantIndexOp>(laneid.getLoc(), layout[0]);
+  auto shiftx = rewriter.create<arith::RemUIOp>(laneid.getLoc(), laneid, xDim);
+  auto shifty = rewriter.create<arith::DivUIOp>(laneid.getLoc(), laneid, xDim);
+
+  auto basex = getValueOrCreateConstantIndexOp(rewriter, laneid.getLoc(),
+                                               descOffsets[0]);
+  auto basey = getValueOrCreateConstantIndexOp(rewriter, laneid.getLoc(),
+                                               descOffsets[1]);
+  auto offsetx = rewriter.create<arith::AddIOp>(laneid.getLoc(), shiftx, basex);
+  auto offsety = rewriter.create<arith::AddIOp>(laneid.getLoc(), shifty, basey);
+
+  // use the base memref strides
+  SmallVector<OpFoldResult> overwriteStrides =
+      getAsIndexOpFoldResult(rewriter.getContext(), SmallVector<int64_t>{1, 1});
+  SmallVector<OpFoldResult> overwriteSizes =
+      getAsIndexOpFoldResult(rewriter.getContext(), newTDescType.getShape());
+
+  SmallVector<size_t> newRetIndices;
+  vector::WarpExecuteOnLane0Op newWarpOp =
+      moveRegionToNewWarpOpAndAppendReturns(
+          rewriter, warpOp, descOp.getSource(), descOp.getSourceType(),
+          newRetIndices);
+
+  // Build the lane-specific subview directly on the memref returned by the new
+  // warp op (no in-place operand patching).
+  rewriter.setInsertionPointAfter(newWarpOp);
+  Value distributedSrc = newWarpOp.getResult(newRetIndices[0]);
+  auto subview = rewriter.create<memref::SubViewOp>(
+      newWarpOp.getLoc(), distributedSrc, getAsOpFoldResult({offsetx, offsety}),
+      overwriteSizes, overwriteStrides);
+
+  auto zero = rewriter.create<arith::ConstantIndexOp>(laneid.getLoc(), 0);
+  auto newDescOp = rewriter.create<xegpu::CreateNdDescOp>(
+      newWarpOp.getLoc(), newTDescType, subview,
+      getAsOpFoldResult({zero, zero}));
+
+  Value distributedVal = newWarpOp.getResult(operandIdx);
+  rewriter.replaceAllUsesWith(distributedVal, newDescOp);
+
+  return success();
+}
+
+void xegpu::populateXeGPUDistributePatterns(RewritePatternSet &patterns) {
+  // All three patterns match vector.warp_execute_on_lane_0; they are
+  // registered in the order create_nd_tdesc, store_nd, load_nd.
+  patterns.add<WarpOpTensorDescOp, WarpOpStoreNd, WarpOpLoadNd>(
+      patterns.getContext());
+}
diff --git a/mlir/test/Dialect/XeGPU/xegpu-distribute.mlir b/mlir/test/Dialect/XeGPU/xegpu-distribute.mlir
new file mode 100644
index 00000000000000..ec01fc82688156
--- /dev/null
+++ b/mlir/test/Dialect/XeGPU/xegpu-distribute.mlir
@@ -0,0 +1,81 @@
+// RUN: mlir-opt -test-xegpu-distribute -split-input-file %s | FileCheck %s
+
+#sg_map_16 = #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>
+#blk_tdesc = #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>
+
+// CHECK-LABEL: test_store_nd_distribution
+// CHECK: %[[laneid:.*]] = gpu.lane_id
+// CHECK: %[[res:.*]]:2 = vector.warp_execute_on_lane_0(%[[laneid]])[16] args(%{{.*}}, %{{.*}} : vector<24x32xf16>, !xegpu.tensor_desc<24x32xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>)
+// CHECK-SAME: -> (!xegpu.tensor_desc<24x2xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>, vector<24x2xf16>)
+// CHECK: ^bb0(%[[src:.*]]: vector<24x32xf16>, %[[dst:.*]]: !xegpu.tensor_desc<24x32xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>)
+// CHECK: vector.yield %[[dst]], %[[src]] : !xegpu.tensor_desc<24x32xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>, vector<24x32xf16>
+// CHECK: xegpu.store_nd %[[res]]#1, %[[res]]#0 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> :
+// CHECK-SAME: vector<24x2xf16>, !xegpu.tensor_desc<24x2xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+
+func.func @test_store_nd_distribution(%src: vector<24x32xf16>, %dst: !xegpu.tensor_desc<24x32xf16, #blk_tdesc, #sg_map_16>) -> () {
+ %laneid = gpu.lane_id
+ vector.warp_execute_on_lane_0(%laneid)[16]
+ args(%src, %dst: vector<24x32xf16>, !xegpu.tensor_desc<24x32xf16, #blk_tdesc, #sg_map_16>) {
+ ^bb0(%arg0: vector<24x32xf16>, %arg1: !xegpu.tensor_desc<24x32xf16, #blk_tdesc, #sg_map_16>):
+ xegpu.store_nd %arg0, %arg1 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>: vector<24x32xf16>, !xegpu.tensor_desc<24x32xf16, #blk_tdesc, #sg_map_16>
+ }
+ return
+}
+
+// -----
+
+#sg_map_16 = #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>
+#blk_tdesc = #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>
+
+// CHECK-LABEL: test_load_nd_distribution
+// CHECK: %[[laneid:.*]] = gpu.lane_id
+// CHECK: %[[res:.*]]:2 = vector.warp_execute_on_lane_0(%[[laneid]])[16] args(%{{.*}} : !xegpu.tensor_desc<24x32xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>)
+// CHECK-SAME: -> (vector<24x2xf16>, !xegpu.tensor_desc<24x2xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>)
+// CHECK: ^bb0(%[[dst:.*]]: !xegpu.tensor_desc<24x32xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>)
+// CHECK: %[[dead:.*]] = xegpu.load_nd
+// CHECK: vector.yield %[[dead]], %[[dst]] : vector<24x32xf16>, !xegpu.tensor_desc<24x32xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+// CHECK: %[[load:.*]] = xegpu.load_nd %[[res]]#1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> :
+// CHECK-SAME: !xegpu.tensor_desc<24x2xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<24x2xf16>
+// CHECK: return %[[load]]
+
+func.func @test_load_nd_distribution(%dst: !xegpu.tensor_desc<24x32xf16, #blk_tdesc, #sg_map_16>) -> (vector<24x2xf16>) {
+ %laneid = gpu.lane_id
+ %r = vector.warp_execute_on_lane_0(%laneid)[16]
+ args(%dst: !xegpu.tensor_desc<24x32xf16, #blk_tdesc, #sg_map_16>) -> (vector<24x2xf16>) {
+ ^bb0(%arg0: !xegpu.tensor_desc<24x32xf16, #blk_tdesc, #sg_map_16>):
+ %0 = xegpu.load_nd %arg0 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+ : !xegpu.tensor_desc<24x32xf16, #blk_tdesc, #sg_map_16> -> vector<24x32xf16>
+ vector.yield %0 : vector<24x32xf16>
+ }
+ return %r : vector<24x2xf16>
+}
+
+// -----
+
+#sg_map_16 = #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>
+#blk_tdesc = #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>
+
+// CHECK-LABEL: test_create_nd_desc_distribution
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[laneid:.*]] = gpu.lane_id
+// CHECK: %[[res:.*]]:2 = vector.warp_execute_on_lane_0(%[[laneid]])[16] args(%{{.*}} : memref<24x32xf16>)
+// CHECK-SAME: -> (!xegpu.tensor_desc<24x2xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>, memref<24x32xf16>)
+// CHECK: ^bb0(%[[dst:.*]]: memref<24x32xf16>)
+// CHECK: %[[dead:.*]] = xegpu.create_nd_tdesc
+// CHECK: vector.yield %[[dead]], %[[dst]] :
+// CHECK-SAME: !xegpu.tensor_desc<24x32xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>, memref<24x32xf16>
+// CHECK: %[[view:.*]] = memref.subview %[[res]]#1[%[[C0]], %[[laneid]]] [24, 2] [1, 1] : memref<24x32xf16> to memref<24x2xf16, strided<[32, 1], offset: ?>>
+// CHECK: %[[desc:.*]] = xegpu.create_nd_tdesc %[[view]][0, 0] : memref<24x2xf16, strided<[32, 1], offset: ?>>
+// CHECK-SAME: -> !xegpu.tensor_desc<24x2xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+// CHECK: return %[[desc]]
+
+func.func @test_create_nd_desc_distribution(%dst: memref<24x32xf16>) -> (!xegpu.tensor_desc<24x2xf16, #blk_tdesc, #sg_map_16>) {
+ %laneid = gpu.lane_id
+ %r = vector.warp_execute_on_lane_0(%laneid)[16]
+ args(%dst: memref<24x32xf16>) -> (!xegpu.tensor_desc<24x2xf16, #blk_tdesc, #sg_map_16>) {
+ ^bb0(%arg0: memref<24x32xf16>):
+ %0 = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<24x32xf16, #blk_tdesc, #sg_map_16>
+ vector.yield %0 : !xegpu.tensor_desc<24x32xf16, #blk_tdesc, #sg_map_16>
+ }
+ return %r : !xegpu.tensor_desc<24x2xf16, #blk_tdesc, #sg_map_16>
+}
diff --git a/mlir/test/lib/Dialect/CMakeLists.txt b/mlir/test/lib/Dialect/CMakeLists.txt
index 29fb4441a24fd2..a8fd70e6397a52 100644
--- a/mlir/test/lib/Dialect/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/CMakeLists.txt
@@ -22,3 +22,4 @@ add_subdirectory(TestDyn)
add_subdirectory(Tosa)
add_subdirectory(Transform)
add_subdirectory(Vector)
+add_subdirectory(XeGPU)
diff --git a/mlir/test/lib/Dialect/XeGPU/CMakeLists.txt b/mlir/test/lib/Dialect/XeGPU/CMakeLists.txt
new file mode 100644
index 00000000000000..c8fe0db5f6213a
--- /dev/null
+++ b/mlir/test/lib/Dialect/XeGPU/CMakeLists.txt
@@ -0,0 +1,17 @@
+# Exclude tests from libMLIR.so
+add_mlir_library(MLIRXeGPUTestPasses
+ TestXeGPUTransforms.cpp
+
+ EXCLUDE_FROM_LIBMLIR
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/XeGPU
+
+ LINK_LIBS PUBLIC
+ MLIRIR
+ MLIRPass
+ MLIRXeGPUTransforms
+ MLIRXeGPUDialect
+ MLIRSupport
+ )
+
diff --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
new file mode 100644
index 00000000000000..eda68b83748139
--- /dev/null
+++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
@@ -0,0 +1,58 @@
+//===- TestXeGPUTransforms.cpp - Test XeGPU transforms and lowerings ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace mlir::xegpu;
+using namespace mlir::vector;
+
+namespace {
+/// Test pass that greedily applies populateXeGPUDistributePatterns to a func.
+struct TestXeGPUDistribution
+    : public PassWrapper<TestXeGPUDistribution, OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestXeGPUDistribution)
+
+  TestXeGPUDistribution() = default;
+  TestXeGPUDistribution(const TestXeGPUDistribution &pass)
+      : PassWrapper(pass) {}
+
+  StringRef getArgument() const final { return "test-xegpu-distribute"; }
+  StringRef getDescription() const final {
+    return "Test patterns for operations work item distribution";
+  }
+
+  // Dialects whose ops the distribution patterns create (subview, addi, ...).
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<xegpu::XeGPUDialect, vector::VectorDialect,
+                    arith::ArithDialect, memref::MemRefDialect>();
+  }
+
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    populateXeGPUDistributePatterns(patterns);
+    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+  }
+};
+} // namespace
+
+namespace mlir::test {
+/// Registers the XeGPU test passes so mlir-opt can run them by name
+/// (TestXeGPUDistribution, i.e. -test-xegpu-distribute).
+void registerTestXeGPUTransforms() {
+  PassRegistration<TestXeGPUDistribution>();
+}
+} // namespace mlir::test
diff --git a/mlir/tools/mlir-opt/CMakeLists.txt b/mlir/tools/mlir-opt/CMakeLists.txt
index 8b79de58fa1028..e4ffbbee7a1d94 100644
--- a/mlir/tools/mlir-opt/CMakeLists.txt
+++ b/mlir/tools/mlir-opt/CMakeLists.txt
@@ -47,6 +47,7 @@ if(MLIR_INCLUDE_TESTS)
MLIRTilingInterfaceTestPasses
MLIRVectorTestPasses
MLIRTestVectorToSPIRV
+ MLIRXeGPUTestPasses
MLIRLLVMTestPasses
)
set(test_libs ${test_libs}
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 36b142484bb04a..b53e9513b0598d 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -151,6 +151,7 @@ void registerTestTransformDialectEraseSchedulePass();
void registerTestPassStateExtensionCommunication();
void registerTestVectorLowerings();
void registerTestVectorReductionToSPIRVDotProd();
+void registerTestXeGPUTransforms();
void registerTestWrittenToPass();
#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
void registerTestDialectConversionPasses();
@@ -286,6 +287,7 @@ void registerTestPasses() {
mlir::test::registerTestTransformDialectEraseSchedulePass();
mlir::test::registerTestPassStateExtensionCommunication();
mlir::test::registerTestVectorLowerings();
+ mlir::test::registerTestXeGPUTransforms();
mlir::test::registerTestVectorReductionToSPIRVDotProd();
mlir::test::registerTestWrittenToPass();
#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
>From f907be64e1bb4af07df070c6e3d49e989ec6c6db Mon Sep 17 00:00:00 2001
From: Petr Kurapov <petr.a.kurapov at intel.com>
Date: Tue, 29 Oct 2024 09:23:01 +0000
Subject: [PATCH 4/4] Remove duplicate comments
---
mlir/lib/Dialect/Vector/Utils/VectorDistributeUtils.cpp | 5 -----
1 file changed, 5 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorDistributeUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorDistributeUtils.cpp
index f41581c6d47f2a..91dac58bacf66c 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorDistributeUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorDistributeUtils.cpp
@@ -14,8 +14,6 @@
using namespace mlir;
-/// Return a value yielded by `warpOp` which statifies the filter lamdba
-/// condition and is not dead.
mlir::OpOperand *
mlir::getWarpResult(vector::WarpExecuteOnLane0Op warpOp,
const std::function<bool(Operation *)> &fn) {
@@ -32,7 +30,6 @@ mlir::getWarpResult(vector::WarpExecuteOnLane0Op warpOp,
return {};
}
-/// Helper to create a new WarpExecuteOnLane0Op with different signature.
vector::WarpExecuteOnLane0Op mlir::moveRegionToNewWarpOpAndReplaceReturns(
RewriterBase &rewriter, vector::WarpExecuteOnLane0Op warpOp,
ValueRange newYieldedValues, TypeRange newReturnTypes) {
@@ -59,8 +56,6 @@ vector::WarpExecuteOnLane0Op mlir::moveRegionToNewWarpOpAndReplaceReturns(
return newWarpOp;
}
-/// Helper to create a new WarpExecuteOnLane0Op region with extra outputs.
-/// `indices` return the index of each new output.
vector::WarpExecuteOnLane0Op mlir::moveRegionToNewWarpOpAndAppendReturns(
RewriterBase &rewriter, vector::WarpExecuteOnLane0Op warpOp,
ValueRange newYieldedValues, TypeRange newReturnTypes,
More information about the Mlir-commits
mailing list