[Mlir-commits] [mlir] [MLIR][Vector][NFC] Move helper functions to vector distribution utils (PR #114208)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Oct 30 04:29:10 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-vector
Author: Petr Kurapov (kurapov-peter)
<details>
<summary>Changes</summary>
The first portion of https://github.com/llvm/llvm-project/pull/112945.
---
Full diff: https://github.com/llvm/llvm-project/pull/114208.diff
4 Files Affected:
- (added) mlir/include/mlir/Dialect/Vector/Utils/VectorDistributionUtils.h (+37)
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp (+1-79)
- (modified) mlir/lib/Dialect/Vector/Utils/CMakeLists.txt (+1)
- (added) mlir/lib/Dialect/Vector/Utils/VectorDistributionUtils.cpp (+93)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorDistributionUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorDistributionUtils.h
new file mode 100644
index 00000000000000..460f1f2a49d89c
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorDistributionUtils.h
@@ -0,0 +1,37 @@
+//===- VectorDistributionUtils.h - Distribution Utilities -------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_VECTOR_UTILS_VECTORDISTRIBUTIONUTILS_H_
+#define MLIR_DIALECT_VECTOR_UTILS_VECTORDISTRIBUTIONUTILS_H_
+
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include <functional>
+#include <utility>
+
+namespace mlir {
+namespace vector {
+/// Return the first yield operand of `warpOp` whose defining op satisfies the
+/// filter lambda `fn` and whose matching warp result has uses; else nullptr.
+OpOperand *getWarpResult(WarpExecuteOnLane0Op warpOp,
+                         const std::function<bool(Operation *)> &fn);
+
+/// Helper to create a new WarpExecuteOnLane0Op with a different signature.
+WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns(
+    RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
+    ValueRange newYieldedValues, TypeRange newReturnTypes);
+
+/// Helper to create a new WarpExecuteOnLane0Op region with extra outputs.
+/// `indices` returns the result index of each value in `newYieldedValues`.
+WarpExecuteOnLane0Op moveRegionToNewWarpOpAndAppendReturns(
+    RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
+    ValueRange newYieldedValues, TypeRange newReturnTypes,
+    llvm::SmallVector<size_t> &indices);
+} // namespace vector
+} // namespace mlir
+
+#endif // MLIR_DIALECT_VECTOR_UTILS_VECTORDISTRIBUTIONUTILS_H_
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 2289fd1ff1364e..1745345a640dbc 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -12,6 +12,7 @@
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
+#include "mlir/Dialect/Vector/Utils/VectorDistributionUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Transforms/RegionUtils.h"
@@ -160,68 +161,6 @@ struct DistributedLoadStoreHelper {
} // namespace
-/// Helper to create a new WarpExecuteOnLane0Op with different signature.
-static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns(
- RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
- ValueRange newYieldedValues, TypeRange newReturnTypes) {
- // Create a new op before the existing one, with the extra operands.
- OpBuilder::InsertionGuard g(rewriter);
- rewriter.setInsertionPoint(warpOp);
- auto newWarpOp = rewriter.create<WarpExecuteOnLane0Op>(
- warpOp.getLoc(), newReturnTypes, warpOp.getLaneid(), warpOp.getWarpSize(),
- warpOp.getArgs(), warpOp.getBody()->getArgumentTypes());
-
- Region &opBody = warpOp.getBodyRegion();
- Region &newOpBody = newWarpOp.getBodyRegion();
- Block &newOpFirstBlock = newOpBody.front();
- rewriter.inlineRegionBefore(opBody, newOpBody, newOpBody.begin());
- rewriter.eraseBlock(&newOpFirstBlock);
- assert(newWarpOp.getWarpRegion().hasOneBlock() &&
- "expected WarpOp with single block");
-
- auto yield =
- cast<vector::YieldOp>(newOpBody.getBlocks().begin()->getTerminator());
-
- rewriter.modifyOpInPlace(
- yield, [&]() { yield.getOperandsMutable().assign(newYieldedValues); });
- return newWarpOp;
-}
-
-/// Helper to create a new WarpExecuteOnLane0Op region with extra outputs.
-/// `indices` return the index of each new output.
-static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndAppendReturns(
- RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
- ValueRange newYieldedValues, TypeRange newReturnTypes,
- llvm::SmallVector<size_t> &indices) {
- SmallVector<Type> types(warpOp.getResultTypes().begin(),
- warpOp.getResultTypes().end());
- auto yield = cast<vector::YieldOp>(
- warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
- llvm::SmallSetVector<Value, 32> yieldValues(yield.getOperands().begin(),
- yield.getOperands().end());
- for (auto newRet : llvm::zip(newYieldedValues, newReturnTypes)) {
- if (yieldValues.insert(std::get<0>(newRet))) {
- types.push_back(std::get<1>(newRet));
- indices.push_back(yieldValues.size() - 1);
- } else {
- // If the value already exit the region don't create a new output.
- for (auto [idx, yieldOperand] :
- llvm::enumerate(yieldValues.getArrayRef())) {
- if (yieldOperand == std::get<0>(newRet)) {
- indices.push_back(idx);
- break;
- }
- }
- }
- }
- yieldValues.insert(newYieldedValues.begin(), newYieldedValues.end());
- WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
- rewriter, warpOp, yieldValues.getArrayRef(), types);
- rewriter.replaceOp(warpOp,
- newWarpOp.getResults().take_front(warpOp.getNumResults()));
- return newWarpOp;
-}
-
/// Helper to know if an op can be hoisted out of the region.
static bool canBeHoisted(Operation *op,
function_ref<bool(Value)> definedOutside) {
@@ -229,23 +168,6 @@ static bool canBeHoisted(Operation *op,
isMemoryEffectFree(op) && op->getNumRegions() == 0;
}
-/// Return a value yielded by `warpOp` which statifies the filter lamdba
-/// condition and is not dead.
-static OpOperand *getWarpResult(WarpExecuteOnLane0Op warpOp,
- const std::function<bool(Operation *)> &fn) {
- auto yield = cast<vector::YieldOp>(
- warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
- for (OpOperand &yieldOperand : yield->getOpOperands()) {
- Value yieldValues = yieldOperand.get();
- Operation *definedOp = yieldValues.getDefiningOp();
- if (definedOp && fn(definedOp)) {
- if (!warpOp.getResult(yieldOperand.getOperandNumber()).use_empty())
- return &yieldOperand;
- }
- }
- return {};
-}
-
// Clones `op` into a new operation that takes `operands` and returns
// `resultTypes`.
static Operation *cloneOpWithOperandsAndTypes(RewriterBase &rewriter,
diff --git a/mlir/lib/Dialect/Vector/Utils/CMakeLists.txt b/mlir/lib/Dialect/Vector/Utils/CMakeLists.txt
index fa3971695d4bf2..41fd677c845ae2 100644
--- a/mlir/lib/Dialect/Vector/Utils/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Utils/CMakeLists.txt
@@ -1,5 +1,6 @@
add_mlir_dialect_library(MLIRVectorUtils
VectorUtils.cpp
+ VectorDistributionUtils.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Vector/Utils
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorDistributionUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorDistributionUtils.cpp
new file mode 100644
index 00000000000000..19df2105de861f
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/Utils/VectorDistributionUtils.cpp
@@ -0,0 +1,93 @@
+//===- VectorDistributionUtils.cpp - Distribution tools for VectorOps -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements distribution utility methods.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Vector/Utils/VectorDistributionUtils.h"
+
+#include "mlir/IR/Value.h"
+
+using namespace mlir;
+using namespace mlir::vector;
+
+mlir::OpOperand *
+mlir::vector::getWarpResult(WarpExecuteOnLane0Op warpOp,
+                            const std::function<bool(Operation *)> &fn) {
+  Block &warpBody = warpOp.getBodyRegion().front();
+  auto yield = cast<vector::YieldOp>(warpBody.getTerminator());
+  for (mlir::OpOperand &operand : yield->getOpOperands()) {
+    Operation *producer = operand.get().getDefiningOp();
+    // Skip operands with no defining op or one rejected by the filter.
+    if (!producer || !fn(producer))
+      continue;
+    // A match only counts if the corresponding warp result is still used.
+    if (!warpOp.getResult(operand.getOperandNumber()).use_empty())
+      return &operand;
+  }
+  return nullptr;
+}
+
+WarpExecuteOnLane0Op mlir::vector::moveRegionToNewWarpOpAndReplaceReturns(
+    RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
+    ValueRange newYieldedValues, TypeRange newReturnTypes) {
+  // Create the new op just before `warpOp`, with the requested result types.
+  OpBuilder::InsertionGuard insertionGuard(rewriter);
+  rewriter.setInsertionPoint(warpOp);
+  auto replacement = rewriter.create<WarpExecuteOnLane0Op>(
+      warpOp.getLoc(), newReturnTypes, warpOp.getLaneid(), warpOp.getWarpSize(),
+      warpOp.getArgs(), warpOp.getBody()->getArgumentTypes());
+
+  // Move the old body into the new op and drop the block it started with.
+  Region &oldRegion = warpOp.getBodyRegion();
+  Region &newRegion = replacement.getBodyRegion();
+  Block &placeholder = newRegion.front();
+  rewriter.inlineRegionBefore(oldRegion, newRegion, newRegion.begin());
+  rewriter.eraseBlock(&placeholder);
+  assert(replacement.getWarpRegion().hasOneBlock() &&
+         "expected WarpOp with single block");
+
+  // Make the moved terminator yield exactly `newYieldedValues`.
+  auto yield = cast<vector::YieldOp>(newRegion.front().getTerminator());
+  rewriter.modifyOpInPlace(
+      yield, [&]() { yield.getOperandsMutable().assign(newYieldedValues); });
+  return replacement;
+}
+
+WarpExecuteOnLane0Op mlir::vector::moveRegionToNewWarpOpAndAppendReturns(
+    RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
+    ValueRange newYieldedValues, TypeRange newReturnTypes,
+    llvm::SmallVector<size_t> &indices) {
+  // Start from the current result types and the current yield operands.
+  auto yield = cast<vector::YieldOp>(
+      warpOp.getBodyRegion().front().getTerminator());
+  SmallVector<Type> resultTypes(warpOp.getResultTypes().begin(),
+                                warpOp.getResultTypes().end());
+  llvm::SmallSetVector<Value, 32> yielded(yield.getOperands().begin(),
+                                          yield.getOperands().end());
+  for (auto [value, type] : llvm::zip(newYieldedValues, newReturnTypes)) {
+    if (!yielded.insert(value)) {
+      // Already yielded: reuse its position rather than adding an output.
+      ArrayRef<Value> current = yielded.getArrayRef();
+      indices.push_back(llvm::find(current, value) - current.begin());
+      continue;
+    }
+    // First occurrence: it becomes a new trailing output.
+    resultTypes.push_back(type);
+    indices.push_back(yielded.size() - 1);
+  }
+  // Every zipped value is already in the set; this only matters if
+  // `newYieldedValues` is longer than `newReturnTypes`.
+  yielded.insert(newYieldedValues.begin(), newYieldedValues.end());
+  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
+      rewriter, warpOp, yielded.getArrayRef(), resultTypes);
+  // Existing users keep seeing the original results, now the leading ones.
+  rewriter.replaceOp(warpOp,
+                     newWarpOp.getResults().take_front(warpOp.getNumResults()));
+  return newWarpOp;
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/114208
More information about the Mlir-commits
mailing list