[Mlir-commits] [mlir] d7d6443 - [mlir][vector] Avoid creating duplicate output in warpOp
Thomas Raoux
llvmlistbot at llvm.org
Mon Jul 11 08:42:08 PDT 2022
Author: Thomas Raoux
Date: 2022-07-11T15:37:50Z
New Revision: d7d6443d501839ef806f9dc872900451d7b41927
URL: https://github.com/llvm/llvm-project/commit/d7d6443d501839ef806f9dc872900451d7b41927
DIFF: https://github.com/llvm/llvm-project/commit/d7d6443d501839ef806f9dc872900451d7b41927.diff
LOG: [mlir][vector] Avoid creating duplicate output in warpOp
Prevent creating multiple outputs for the same Value when distributing
operations out of WarpExecuteOnLane0Op. This avoids a combinatorial
explosion of outputs.
Differential Revision: https://reviews.llvm.org/D129465
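
Below is a minimal standalone sketch (not the MLIR helper itself) of the
deduplication scheme this patch adds to moveRegionToNewWarpOpAndAppendReturns:
requested yield values are inserted into a SetVector, and `indices` records
the result slot each one ends up in, reusing the slot of an already-yielded
value instead of appending a duplicate. `Value` is stood in for by a plain
int and the literal values are made up for illustration; only LLVM's ADT
headers are assumed.

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include <cstdio>

int main() {
  // Values already yielded by the existing warp op terminator.
  llvm::SmallSetVector<int, 32> yieldValues;
  yieldValues.insert(7);
  yieldValues.insert(9);

  // Newly requested yield values; 9 is already yielded.
  llvm::SmallVector<int> newYieldedValues = {9, 11};
  llvm::SmallVector<size_t> indices;

  for (int newVal : newYieldedValues) {
    if (yieldValues.insert(newVal)) {
      // Genuinely new value: it occupies the next result slot.
      indices.push_back(yieldValues.size() - 1);
    } else {
      // Duplicate: point at the slot of the existing yield operand.
      for (const auto &yieldOperand :
           llvm::enumerate(yieldValues.getArrayRef())) {
        if (yieldOperand.value() == newVal) {
          indices.push_back(yieldOperand.index());
          break;
        }
      }
    }
  }

  // Prints "1 2": the duplicate 9 reuses slot 1, while 11 gets new slot 2.
  for (size_t idx : indices)
    std::printf("%zu ", idx);
  std::printf("\n");
  return 0;
}
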
Added:
Modified:
mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
mlir/test/Dialect/Vector/vector-warp-distribute.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 2b9635835d7b1..57fa863320906 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -14,7 +14,7 @@
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/Transforms/SideEffectUtils.h"
-
+#include "llvm/ADT/SetVector.h"
#include <utility>
using namespace mlir;
@@ -165,19 +165,34 @@ static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns(
}
/// Helper to create a new WarpExecuteOnLane0Op region with extra outputs.
+/// `indices` returns the index of each new output.
static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndAppendReturns(
RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
- ValueRange newYieldedValues, TypeRange newReturnTypes) {
+ ValueRange newYieldedValues, TypeRange newReturnTypes,
+ llvm::SmallVector<size_t> &indices) {
SmallVector<Type> types(warpOp.getResultTypes().begin(),
warpOp.getResultTypes().end());
- types.append(newReturnTypes.begin(), newReturnTypes.end());
auto yield = cast<vector::YieldOp>(
warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
- SmallVector<Value> yieldValues(yield.getOperands().begin(),
- yield.getOperands().end());
- yieldValues.append(newYieldedValues.begin(), newYieldedValues.end());
+ llvm::SmallSetVector<Value, 32> yieldValues(yield.getOperands().begin(),
+ yield.getOperands().end());
+ for (auto newRet : llvm::zip(newYieldedValues, newReturnTypes)) {
+ if (yieldValues.insert(std::get<0>(newRet))) {
+ types.push_back(std::get<1>(newRet));
+ indices.push_back(yieldValues.size() - 1);
+ } else {
+ // If the value already exits the region, don't create a new output.
+ for (auto &yieldOperand : llvm::enumerate(yieldValues.getArrayRef())) {
+ if (yieldOperand.value() == std::get<0>(newRet)) {
+ indices.push_back(yieldOperand.index());
+ break;
+ }
+ }
+ }
+ }
+ yieldValues.insert(newYieldedValues.begin(), newYieldedValues.end());
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
- rewriter, warpOp, yieldValues, types);
+ rewriter, warpOp, yieldValues.getArrayRef(), types);
rewriter.replaceOp(warpOp,
newWarpOp.getResults().take_front(warpOp.getNumResults()));
return newWarpOp;
@@ -273,14 +288,15 @@ static vector::TransferWriteOp cloneWriteOp(RewriterBase &rewriter,
assert(writeOp->getParentOp() == warpOp &&
"write must be nested immediately under warp");
OpBuilder::InsertionGuard g(rewriter);
+ SmallVector<size_t> newRetIndices;
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
rewriter, warpOp, ValueRange{{writeOp.getVector()}},
- TypeRange{targetType});
+ TypeRange{targetType}, newRetIndices);
rewriter.setInsertionPointAfter(newWarpOp);
auto newWriteOp =
cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
rewriter.eraseOp(writeOp);
- newWriteOp.getVectorMutable().assign(newWarpOp.getResults().back());
+ newWriteOp.getVectorMutable().assign(newWarpOp.getResult(newRetIndices[0]));
return newWriteOp;
}
@@ -387,8 +403,9 @@ struct WarpOpTransferWrite : public OpRewritePattern<vector::TransferWriteOp> {
SmallVector<Value> yieldValues = {writeOp.getVector()};
SmallVector<Type> retTypes = {vecType};
+ SmallVector<size_t> newRetIndices;
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
- rewriter, warpOp, yieldValues, retTypes);
+ rewriter, warpOp, yieldValues, retTypes, newRetIndices);
rewriter.setInsertionPointAfter(newWarpOp);
// Create a second warp op that contains only writeOp.
@@ -398,8 +415,7 @@ struct WarpOpTransferWrite : public OpRewritePattern<vector::TransferWriteOp> {
rewriter.setInsertionPointToStart(&body);
auto newWriteOp =
cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
- newWriteOp.getVectorMutable().assign(
- newWarpOp.getResult(newWarpOp.getNumResults() - 1));
+ newWriteOp.getVectorMutable().assign(newWarpOp.getResult(newRetIndices[0]));
rewriter.eraseOp(writeOp);
rewriter.create<vector::YieldOp>(newWarpOp.getLoc());
return success();
@@ -489,14 +505,14 @@ struct WarpOpElementwise : public OpRewritePattern<WarpExecuteOnLane0Op> {
retTypes.push_back(targetType);
yieldValues.push_back(operand.get());
}
- unsigned numResults = warpOp.getNumResults();
+ SmallVector<size_t> newRetIndices;
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
- rewriter, warpOp, yieldValues, retTypes);
+ rewriter, warpOp, yieldValues, retTypes, newRetIndices);
rewriter.setInsertionPointAfter(newWarpOp);
SmallVector<Value> newOperands(elementWise->getOperands().begin(),
elementWise->getOperands().end());
for (unsigned i : llvm::seq(unsigned(0), elementWise->getNumOperands())) {
- newOperands[i] = newWarpOp.getResult(i + numResults);
+ newOperands[i] = newWarpOp.getResult(newRetIndices[i]);
}
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPointAfter(newWarpOp);
@@ -653,12 +669,13 @@ struct WarpOpBroadcast : public OpRewritePattern<WarpExecuteOnLane0Op> {
Location loc = broadcastOp.getLoc();
auto destVecType =
warpOp->getResultTypes()[operandNumber].cast<VectorType>();
+ SmallVector<size_t> newRetIndices;
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
rewriter, warpOp, {broadcastOp.getSource()},
- {broadcastOp.getSource().getType()});
+ {broadcastOp.getSource().getType()}, newRetIndices);
rewriter.setInsertionPointAfter(newWarpOp);
Value broadcasted = rewriter.create<vector::BroadcastOp>(
- loc, destVecType, newWarpOp->getResults().back());
+ loc, destVecType, newWarpOp->getResult(newRetIndices[0]));
newWarpOp->getResult(operandNumber).replaceAllUsesWith(broadcasted);
return success();
}
@@ -814,12 +831,12 @@ struct WarpOpReduction : public OpRewritePattern<WarpExecuteOnLane0Op> {
SmallVector<Value> yieldValues = {reductionOp.getVector()};
SmallVector<Type> retTypes = {
VectorType::get({numElements}, reductionOp.getType())};
- unsigned numResults = warpOp.getNumResults();
+ SmallVector<size_t> newRetIndices;
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
- rewriter, warpOp, yieldValues, retTypes);
+ rewriter, warpOp, yieldValues, retTypes, newRetIndices);
rewriter.setInsertionPointAfter(newWarpOp);
- Value laneValVec = newWarpOp.getResult(numResults);
+ Value laneValVec = newWarpOp.getResult(newRetIndices[0]);
// First reduce on a single thread.
Value perLaneReduction = rewriter.create<vector::ReductionOp>(
reductionOp.getLoc(), reductionOp.getKind(), laneValVec);
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 82f6299634578..4a04f988be979 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -545,3 +545,20 @@ func.func @vector_reduction_large(%laneid: index) -> (f32) {
}
return %r : f32
}
+
+// -----
+
+// CHECK-PROP-LABEL: func @warp_duplicate_yield(
+func.func @warp_duplicate_yield(%laneid: index) -> (vector<1xf32>, vector<1xf32>) {
+ // CHECK-PROP: %{{.*}}:2 = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<1xf32>, vector<1xf32>)
+ %r:2 = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>, vector<1xf32>) {
+ %2 = "some_def"() : () -> (vector<32xf32>)
+ %3 = "some_def"() : () -> (vector<32xf32>)
+ %4 = arith.addf %2, %3 : vector<32xf32>
+ %5 = arith.addf %2, %2 : vector<32xf32>
+// CHECK-PROP-NOT: arith.addf
+// CHECK-PROP: vector.yield %{{.*}}, %{{.*}} : vector<32xf32>, vector<32xf32>
+ vector.yield %4, %5 : vector<32xf32>, vector<32xf32>
+ }
+ return %r#0, %r#1 : vector<1xf32>, vector<1xf32>
+}