[Mlir-commits] [mlir] [MLIR][XeGPU] Xegpu distribution patterns for load_nd, store_nd, and create_nd_tdesc. (PR #112945)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Oct 18 10:52:04 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-vector
@llvm/pr-subscribers-mlir-core
Author: Petr Kurapov (kurapov-peter)
<details>
<summary>Changes</summary>
This PR introduces distribution patterns for a portion of the XeGPU dialect, similar to those of the vector dialect, and moves some of the common functionality to the vector utilities.
XeGPU op rewrite patterns distribute the vector and XeGPU tensor descriptor types when sunk through the yield op of a `vector.warp_execute_on_lane_0`, according to the `xegpu.sg_map` attribute. The validation of distributed types in `warp_execute_on_lane_0` was therefore relaxed to allow `ShapedType` return values to have distributed shapes.
---
Patch is 43.46 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/112945.diff
15 Files Affected:
- (modified) mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h (+20)
- (modified) mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h (+1)
- (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+6-6)
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp (+1-79)
- (modified) mlir/lib/Dialect/Vector/Utils/CMakeLists.txt (+1)
- (added) mlir/lib/Dialect/Vector/Utils/VectorDistributeUtils.cpp (+96)
- (modified) mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt (+5)
- (added) mlir/lib/Dialect/XeGPU/Transforms/XeGPUDistribute.cpp (+393)
- (modified) mlir/test/Dialect/Vector/invalid.mlir (+3-3)
- (added) mlir/test/Dialect/XeGPU/xegpu-distribute.mlir (+81)
- (modified) mlir/test/lib/Dialect/CMakeLists.txt (+1)
- (added) mlir/test/lib/Dialect/XeGPU/CMakeLists.txt (+17)
- (added) mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp (+58)
- (modified) mlir/tools/mlir-opt/CMakeLists.txt (+1)
- (modified) mlir/tools/mlir-opt/mlir-opt.cpp (+2)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
index 5f32aca88a2734..6bd924307376dc 100644
--- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
@@ -20,6 +20,8 @@
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/TypeSwitch.h"
+#include <utility>
+
namespace mlir {
// Forward declarations.
@@ -324,6 +326,24 @@ namespace matcher {
bool operatesOnSuperVectorsOf(Operation &op, VectorType subVectorType);
} // namespace matcher
+
+/// Return a value yielded by `warpOp` which satisfies the filter lambda
+/// condition and is not dead.
+OpOperand *getWarpResult(vector::WarpExecuteOnLane0Op warpOp,
+ const std::function<bool(Operation *)> &fn);
+
+/// Helper to create a new WarpExecuteOnLane0Op with different signature.
+vector::WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns(
+ RewriterBase &rewriter, vector::WarpExecuteOnLane0Op warpOp,
+ ValueRange newYieldedValues, TypeRange newReturnTypes);
+
+/// Helper to create a new WarpExecuteOnLane0Op region with extra outputs.
+/// `indices` returns the index of each new output.
+vector::WarpExecuteOnLane0Op moveRegionToNewWarpOpAndAppendReturns(
+ RewriterBase &rewriter, vector::WarpExecuteOnLane0Op warpOp,
+ ValueRange newYieldedValues, TypeRange newReturnTypes,
+ llvm::SmallVector<size_t> &indices);
+
} // namespace mlir
#endif // MLIR_DIALECT_VECTOR_UTILS_VECTORUTILS_H_
diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h b/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
index 63ea26df069372..fe5198d1ac6dba 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
@@ -16,6 +16,7 @@ namespace xegpu {
/// Appends patterns for folding aliasing ops into XeGPU ops into `patterns`.
void populateXeGPUFoldAliasOpsPatterns(RewritePatternSet &patterns);
+void populateXeGPUDistributePatterns(RewritePatternSet &patterns);
} // namespace xegpu
} // namespace mlir
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index a2abe1619454f2..51d3691fd107ae 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6558,14 +6558,14 @@ static LogicalResult verifyDistributedType(Type expanded, Type distributed,
// If the types matches there is no distribution.
if (expanded == distributed)
return success();
- auto expandedVecType = llvm::dyn_cast<VectorType>(expanded);
- auto distributedVecType = llvm::dyn_cast<VectorType>(distributed);
+ auto expandedVecType = llvm::dyn_cast<ShapedType>(expanded);
+ auto distributedVecType = llvm::dyn_cast<ShapedType>(distributed);
if (!expandedVecType || !distributedVecType)
- return op->emitOpError("expected vector type for distributed operands.");
+ return op->emitOpError("expected shaped type for distributed operands.");
if (expandedVecType.getRank() != distributedVecType.getRank() ||
expandedVecType.getElementType() != distributedVecType.getElementType())
return op->emitOpError(
- "expected distributed vectors to have same rank and element type.");
+ "expected distributed types to have same rank and element type.");
SmallVector<int64_t> scales(expandedVecType.getRank(), 1);
for (int64_t i = 0, e = expandedVecType.getRank(); i < e; i++) {
@@ -6575,8 +6575,8 @@ static LogicalResult verifyDistributedType(Type expanded, Type distributed,
continue;
if (eDim % dDim != 0)
return op->emitOpError()
- << "expected expanded vector dimension #" << i << " (" << eDim
- << ") to be a multipler of the distributed vector dimension ("
+ << "expected expanded type dimension #" << i << " (" << eDim
+ << ") to be a multipler of the distributed type dimension ("
<< dDim << ")";
scales[i] = eDim / dDim;
}
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 2289fd1ff1364e..c80c3179b5e025 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -12,6 +12,7 @@
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Transforms/RegionUtils.h"
@@ -160,68 +161,6 @@ struct DistributedLoadStoreHelper {
} // namespace
-/// Helper to create a new WarpExecuteOnLane0Op with different signature.
-static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns(
- RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
- ValueRange newYieldedValues, TypeRange newReturnTypes) {
- // Create a new op before the existing one, with the extra operands.
- OpBuilder::InsertionGuard g(rewriter);
- rewriter.setInsertionPoint(warpOp);
- auto newWarpOp = rewriter.create<WarpExecuteOnLane0Op>(
- warpOp.getLoc(), newReturnTypes, warpOp.getLaneid(), warpOp.getWarpSize(),
- warpOp.getArgs(), warpOp.getBody()->getArgumentTypes());
-
- Region &opBody = warpOp.getBodyRegion();
- Region &newOpBody = newWarpOp.getBodyRegion();
- Block &newOpFirstBlock = newOpBody.front();
- rewriter.inlineRegionBefore(opBody, newOpBody, newOpBody.begin());
- rewriter.eraseBlock(&newOpFirstBlock);
- assert(newWarpOp.getWarpRegion().hasOneBlock() &&
- "expected WarpOp with single block");
-
- auto yield =
- cast<vector::YieldOp>(newOpBody.getBlocks().begin()->getTerminator());
-
- rewriter.modifyOpInPlace(
- yield, [&]() { yield.getOperandsMutable().assign(newYieldedValues); });
- return newWarpOp;
-}
-
-/// Helper to create a new WarpExecuteOnLane0Op region with extra outputs.
-/// `indices` return the index of each new output.
-static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndAppendReturns(
- RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
- ValueRange newYieldedValues, TypeRange newReturnTypes,
- llvm::SmallVector<size_t> &indices) {
- SmallVector<Type> types(warpOp.getResultTypes().begin(),
- warpOp.getResultTypes().end());
- auto yield = cast<vector::YieldOp>(
- warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
- llvm::SmallSetVector<Value, 32> yieldValues(yield.getOperands().begin(),
- yield.getOperands().end());
- for (auto newRet : llvm::zip(newYieldedValues, newReturnTypes)) {
- if (yieldValues.insert(std::get<0>(newRet))) {
- types.push_back(std::get<1>(newRet));
- indices.push_back(yieldValues.size() - 1);
- } else {
- // If the value already exit the region don't create a new output.
- for (auto [idx, yieldOperand] :
- llvm::enumerate(yieldValues.getArrayRef())) {
- if (yieldOperand == std::get<0>(newRet)) {
- indices.push_back(idx);
- break;
- }
- }
- }
- }
- yieldValues.insert(newYieldedValues.begin(), newYieldedValues.end());
- WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
- rewriter, warpOp, yieldValues.getArrayRef(), types);
- rewriter.replaceOp(warpOp,
- newWarpOp.getResults().take_front(warpOp.getNumResults()));
- return newWarpOp;
-}
-
/// Helper to know if an op can be hoisted out of the region.
static bool canBeHoisted(Operation *op,
function_ref<bool(Value)> definedOutside) {
@@ -229,23 +168,6 @@ static bool canBeHoisted(Operation *op,
isMemoryEffectFree(op) && op->getNumRegions() == 0;
}
-/// Return a value yielded by `warpOp` which statifies the filter lamdba
-/// condition and is not dead.
-static OpOperand *getWarpResult(WarpExecuteOnLane0Op warpOp,
- const std::function<bool(Operation *)> &fn) {
- auto yield = cast<vector::YieldOp>(
- warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
- for (OpOperand &yieldOperand : yield->getOpOperands()) {
- Value yieldValues = yieldOperand.get();
- Operation *definedOp = yieldValues.getDefiningOp();
- if (definedOp && fn(definedOp)) {
- if (!warpOp.getResult(yieldOperand.getOperandNumber()).use_empty())
- return &yieldOperand;
- }
- }
- return {};
-}
-
// Clones `op` into a new operation that takes `operands` and returns
// `resultTypes`.
static Operation *cloneOpWithOperandsAndTypes(RewriterBase &rewriter,
diff --git a/mlir/lib/Dialect/Vector/Utils/CMakeLists.txt b/mlir/lib/Dialect/Vector/Utils/CMakeLists.txt
index fa3971695d4bf2..9db0c172fec5ce 100644
--- a/mlir/lib/Dialect/Vector/Utils/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Utils/CMakeLists.txt
@@ -1,5 +1,6 @@
add_mlir_dialect_library(MLIRVectorUtils
VectorUtils.cpp
+ VectorDistributeUtils.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Vector/Utils
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorDistributeUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorDistributeUtils.cpp
new file mode 100644
index 00000000000000..f41581c6d47f2a
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/Utils/VectorDistributeUtils.cpp
@@ -0,0 +1,96 @@
+//===- VectorDistributeUtils.cpp - MLIR Utilities VectorOps distribution -===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements utility methods for working with the Vector dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+
+using namespace mlir;
+
+/// Return a value yielded by `warpOp` which satisfies the filter lambda
+/// condition and is not dead.
+mlir::OpOperand *
+mlir::getWarpResult(vector::WarpExecuteOnLane0Op warpOp,
+ const std::function<bool(Operation *)> &fn) {
+ auto yield = cast<vector::YieldOp>(
+ warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+ for (mlir::OpOperand &yieldOperand : yield->getOpOperands()) {
+ Value yieldValues = yieldOperand.get();
+ Operation *definedOp = yieldValues.getDefiningOp();
+ if (definedOp && fn(definedOp)) {
+ if (!warpOp.getResult(yieldOperand.getOperandNumber()).use_empty())
+ return &yieldOperand;
+ }
+ }
+ return {};
+}
+
+/// Helper to create a new WarpExecuteOnLane0Op with different signature.
+vector::WarpExecuteOnLane0Op mlir::moveRegionToNewWarpOpAndReplaceReturns(
+ RewriterBase &rewriter, vector::WarpExecuteOnLane0Op warpOp,
+ ValueRange newYieldedValues, TypeRange newReturnTypes) {
+ // Create a new op before the existing one, with the extra operands.
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(warpOp);
+ auto newWarpOp = rewriter.create<vector::WarpExecuteOnLane0Op>(
+ warpOp.getLoc(), newReturnTypes, warpOp.getLaneid(), warpOp.getWarpSize(),
+ warpOp.getArgs(), warpOp.getBody()->getArgumentTypes());
+
+ Region &opBody = warpOp.getBodyRegion();
+ Region &newOpBody = newWarpOp.getBodyRegion();
+ Block &newOpFirstBlock = newOpBody.front();
+ rewriter.inlineRegionBefore(opBody, newOpBody, newOpBody.begin());
+ rewriter.eraseBlock(&newOpFirstBlock);
+ assert(newWarpOp.getWarpRegion().hasOneBlock() &&
+ "expected WarpOp with single block");
+
+ auto yield =
+ cast<vector::YieldOp>(newOpBody.getBlocks().begin()->getTerminator());
+
+ rewriter.modifyOpInPlace(
+ yield, [&]() { yield.getOperandsMutable().assign(newYieldedValues); });
+ return newWarpOp;
+}
+
+/// Helper to create a new WarpExecuteOnLane0Op region with extra outputs.
+/// `indices` returns the index of each new output.
+vector::WarpExecuteOnLane0Op mlir::moveRegionToNewWarpOpAndAppendReturns(
+ RewriterBase &rewriter, vector::WarpExecuteOnLane0Op warpOp,
+ ValueRange newYieldedValues, TypeRange newReturnTypes,
+ llvm::SmallVector<size_t> &indices) {
+ SmallVector<Type> types(warpOp.getResultTypes().begin(),
+ warpOp.getResultTypes().end());
+ auto yield = cast<vector::YieldOp>(
+ warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+ llvm::SmallSetVector<Value, 32> yieldValues(yield.getOperands().begin(),
+ yield.getOperands().end());
+ for (auto newRet : llvm::zip(newYieldedValues, newReturnTypes)) {
+ if (yieldValues.insert(std::get<0>(newRet))) {
+ types.push_back(std::get<1>(newRet));
+ indices.push_back(yieldValues.size() - 1);
+ } else {
+ // If the value already exits the region, don't create a new output.
+ for (auto [idx, yieldOperand] :
+ llvm::enumerate(yieldValues.getArrayRef())) {
+ if (yieldOperand == std::get<0>(newRet)) {
+ indices.push_back(idx);
+ break;
+ }
+ }
+ }
+ }
+ yieldValues.insert(newYieldedValues.begin(), newYieldedValues.end());
+ vector::WarpExecuteOnLane0Op newWarpOp =
+ moveRegionToNewWarpOpAndReplaceReturns(rewriter, warpOp,
+ yieldValues.getArrayRef(), types);
+ rewriter.replaceOp(warpOp,
+ newWarpOp.getResults().take_front(warpOp.getNumResults()));
+ return newWarpOp;
+}
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
index 7fb64d3b97b87d..148ff46ba41b72 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
add_mlir_dialect_library(MLIRXeGPUTransforms
XeGPUFoldAliasOps.cpp
+ XeGPUDistribute.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/XeGPU
@@ -12,6 +13,10 @@ add_mlir_dialect_library(MLIRXeGPUTransforms
MLIRIR
MLIRMemRefDialect
MLIRXeGPUDialect
+ MLIRVectorDialect
+ MLIRVectorUtils
+ MLIRArithDialect
+ MLIRFuncDialect
MLIRPass
MLIRTransforms
)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUDistribute.cpp
new file mode 100644
index 00000000000000..78a010ff1c941b
--- /dev/null
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUDistribute.cpp
@@ -0,0 +1,393 @@
+//===- XeGPUDistribute.cpp - XeGPU distribute ops to work items -*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "xegpu-distribute"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+
+using namespace mlir;
+
+namespace {
+bool divisible(APInt lhs, APInt rhs) { return !lhs.urem(rhs); }
+
+/// Clone a create_nd_tdesc feeding into vector.yield op for the enclosing
+/// `vector.warp_execute_on_lane_0` and put it after the warp op.
+/// The warp op will still contain the original op that will not be used by the
+/// yield op (and should be cleaned up later with dce). The yield op will bypass
+/// the create_nd_tdesc's arguments.
+/// The rewrite will create a subview of the size used by a single work item and
+/// appropriate offset. The distributed create_nd_tdesc points into the subview
+/// without offset. The tensor descriptor type is distributed according to
+/// sg_map attribute.
+///
+/// Example:
+///
+/// ```
+/// #sg_map_8 = #xegpu.sg_map<wi_layout = [1, 8], wi_data = [1, 1]>
+/// %r = vector.warp_execute_on_lane_0(%laneid) ->
+/// (!xegpu.tensor_desc<4x8xf32>) {
+/// ...
+/// %td = xegpu.create_nd_tdesc %arg0[0, 0]
+/// : memref<4x8xf32> -> !xegpu.tensor_desc<4x8xf32>
+/// vector.yield %td
+/// }
+/// ```
+/// To
+/// ```
+/// %r:2 = vector.warp_execute_on_lane_0(%laneid) -> () {
+/// ...
+/// %dead = xegpu.create_nd_tdesc %arg0[0, 0]
+/// : memref<4x8xf32> -> !xegpu.tensor_desc<4x8xf32>
+/// vector.yield %arg0, %dead
+/// }
+/// %view = memref.subview %r#0[0, %laneid] [4, 1] [1, 1]
+/// : memref<4x8xf32> to memref<4x1xf32>
+/// %td = xegpu.create_nd_tdesc %view[0, 0]: memref<4x1xf32>
+/// -> !xegpu.tensor_desc<4x1xf32>
+///
+/// ```
+struct WarpOpTensorDescOp final
+ : public OpRewritePattern<vector::WarpExecuteOnLane0Op> {
+ using OpRewritePattern<vector::WarpExecuteOnLane0Op>::OpRewritePattern;
+ LogicalResult matchAndRewrite(vector::WarpExecuteOnLane0Op warpOp,
+ PatternRewriter &rewriter) const override;
+};
+
+/// Sink a store_nd feeding into vector.yield op for the enclosing
+/// `vector.warp_execute_on_lane_0`. In case arguments for the store are passed
+/// through the warp op interface they would be propagated as returned values.
+/// Both the stored vector type and tensor descriptor types are distributed
+/// according to sg_map attribute.
+///
+/// Example:
+///
+/// ```
+/// #sg_map_8 = #xegpu.sg_map<wi_layout = [1, 8], wi_data = [1, 1]>
+/// vector.warp_execute_on_lane_0(%laneid) -> () {
+/// ...
+/// xegpu.store_nd %arg0, %arg1: vector<4x8xf32>,
+/// !xegpu.tensor_desc<4x8xf32>
+/// vector.yield
+/// }
+/// ```
+/// To
+/// ```
+/// %r = vector.warp_execute_on_lane_0(%laneid) -> () {
+/// ...
+/// vector.yield
+/// }
+/// xegpu.store_nd %arg0, %arg1: vector<4x1xf32>, !xegpu.tensor_desc<4x1xf32>
+///
+/// ```
+struct WarpOpStoreNd final
+ : public OpRewritePattern<vector::WarpExecuteOnLane0Op> {
+ using OpRewritePattern<vector::WarpExecuteOnLane0Op>::OpRewritePattern;
+ LogicalResult matchAndRewrite(vector::WarpExecuteOnLane0Op warpOp,
+ PatternRewriter &rewriter) const override;
+};
+
+/// Clone a load_nd feeding into vector.yield op for the enclosing
+/// `vector.warp_execute_on_lane_0` and put it after the warp op.
+/// The warp op will still contain the original op that will not be used by the
+/// yield op (and should be cleaned up later with dce). The yield op will bypass
+/// the load's arguments.
+/// Both the loaded vector type and tensor descriptor types are distributed
+/// according to sg_map attribute.
+///
+/// Example:
+///
+/// ```
+/// #sg_map_8 = #xegpu.sg_map<wi_layout = [1, 8], wi_data = [1, 1]>
+/// %r = vector.warp_execute_on_lane_0(%laneid) ->
+/// (!xegpu.tensor_desc<4x8xf32>) {
+/// ...
+/// %ld = xegpu.load_nd %arg0, %arg1: !xegpu.tensor_desc<4x8xf32>,
+/// vector<4x8xf32> vector.yield %ld
+/// }
+/// ```
+/// To
+/// ```
+/// %r:2 = vector.warp_execute_on_lane_0(%laneid) -> () {
+/// ...
+/// %dead = xegpu.load_nd %arg0, %arg1:
+/// !xegpu.tensor_desc<4x8xf32>, vector<4x8xf32>
+/// vector.yield %arg0, %arg1
+/// }
+/// xegpu.store_nd %r#0, %r#1: vector<4x1xf32>, !xegpu.tensor_desc<4x1xf32>
+///
+/// ```
+struct WarpOpLoadNd final
+ : public OpRewritePattern<vector::WarpExecuteOnLane0Op> {
+ using OpRewritePattern<vector::WarpExecuteOnLane0Op>::OpRewritePattern;
+ LogicalResult matchAndRewrite(vector::WarpExecuteOnLane0Op warpOp,
+ PatternRewriter &rewriter) const override;
+};
+
+FailureOr<VectorType> getDistributedVectorType(VectorType originalT,
+ ...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/112945
More information about the Mlir-commits
mailing list