[Mlir-commits] [mlir] 20df17f - [mlir][vector] Extend WarpExecutionOnLane0 pattern support to allow deduplicating identical yield values.
Nicolas Vasilache
llvmlistbot at llvm.org
Fri Sep 9 06:53:45 PDT 2022
Author: Nicolas Vasilache
Date: 2022-09-09T06:53:36-07:00
New Revision: 20df17fd2db16b4df20fe275bf9f062eaecf9745
URL: https://github.com/llvm/llvm-project/commit/20df17fd2db16b4df20fe275bf9f062eaecf9745
DIFF: https://github.com/llvm/llvm-project/commit/20df17fd2db16b4df20fe275bf9f062eaecf9745.diff
LOG: [mlir][vector] Extend the WarpOpDeadResult pattern on WarpExecuteOnLane0Op to deduplicate identical yield values.
Differential Revision: https://reviews.llvm.org/D133573
Added:
Modified:
mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
mlir/test/Dialect/Vector/vector-warp-distribute.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index f356248f18311..2c757be09a6a0 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -753,27 +753,56 @@ struct WarpOpDeadResult : public OpRewritePattern<WarpExecuteOnLane0Op> {
using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
- SmallVector<Type> resultTypes;
- SmallVector<Value> yieldValues;
+ SmallVector<Type> newResultTypes;
+ newResultTypes.reserve(warpOp->getNumResults());
+ SmallVector<Value> newYieldValues;
+ newYieldValues.reserve(warpOp->getNumResults());
+ DenseMap<Value, int64_t> dedupYieldOperandPositionMap;
+ DenseMap<OpResult, int64_t> dedupResultPositionMap;
auto yield = cast<vector::YieldOp>(
warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+
+ // Some values may be yielded multiple times and correspond to multiple
+ // results. Deduplicating occurs by taking each used result with its
+ // matching yielded value, and:
+ // 1. recording the unique first position at which the value is yielded.
+ // 2. recording for the result, the first position at which the dedup'ed
+ // value is yielded.
+ // 3. skipping from the new result types / new yielded values any result
+ // whose yielded value has already been seen.
+ // Results without uses are skipped before anything is recorded: otherwise
+ // an unused result could reserve a position that the new op never
+ // yields, and a later used duplicate of the same value would be mapped
+ // to a wrong or out-of-bounds result index.
for (OpResult result : warpOp.getResults()) {
if (result.use_empty())
continue;
+ Value yieldOperand = yield.getOperand(result.getResultNumber());
+ auto it = dedupYieldOperandPositionMap.insert(
+ std::make_pair(yieldOperand, newResultTypes.size()));
+ dedupResultPositionMap.insert(std::make_pair(result, it.first->second));
+ if (!it.second)
+ continue;
- resultTypes.push_back(result.getType());
- yieldValues.push_back(yield.getOperand(result.getResultNumber()));
+ newResultTypes.push_back(result.getType());
+ newYieldValues.push_back(yieldOperand);
}
- if (yield.getNumOperands() == yieldValues.size())
+ // No modification, exit early.
+ if (yield.getNumOperands() == newYieldValues.size())
return failure();
+ // Move the body of the old warpOp to a new warpOp.
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
- rewriter, warpOp, yieldValues, resultTypes);
- unsigned resultIndex = 0;
+ rewriter, warpOp, newYieldValues, newResultTypes);
+ // Replace results of the old warpOp by the new, deduplicated results.
+ SmallVector<Value> newValues;
+ newValues.reserve(warpOp->getNumResults());
for (OpResult result : warpOp.getResults()) {
if (result.use_empty())
- continue;
- result.replaceAllUsesWith(newWarpOp.getResult(resultIndex++));
+ newValues.push_back(Value());
+ else
+ newValues.push_back(
+ newWarpOp.getResult(dedupResultPositionMap.lookup(result)));
}
- rewriter.eraseOp(warpOp);
+ rewriter.replaceOp(warpOp, newValues);
return success();
}
};
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 09d41db642fb2..5a70ae8a4994a 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -669,3 +669,49 @@ func.func @dont_duplicate_read(
}
return %r : vector<1xf32>
}
+
+// -----
+
+// CHECK-PROP: func @dedup
+func.func @dedup(%laneid: index, %v0: vector<4xf32>, %v1: vector<4xf32>)
+ -> (vector<1xf32>, vector<1xf32>) {
+
+ // CHECK-PROP: %[[SINGLE_RES:.*]] = vector.warp_execute_on_lane_0{{.*}} -> (vector<1xf32>) {
+ %r:2 = vector.warp_execute_on_lane_0(%laneid)[32]
+ args(%v0, %v1 : vector<4xf32>, vector<4xf32>) -> (vector<1xf32>, vector<1xf32>) {
+ ^bb0(%arg0: vector<128xf32>, %arg1: vector<128xf32>):
+
+ // CHECK-PROP: %[[SINGLE_VAL:.*]] = "some_def"(%{{.*}}) : (vector<128xf32>) -> vector<32xf32>
+ %2 = "some_def"(%arg0) : (vector<128xf32>) -> vector<32xf32>
+
+ // CHECK-PROP: vector.yield %[[SINGLE_VAL]] : vector<32xf32>
+ vector.yield %2, %2 : vector<32xf32>, vector<32xf32>
+ }
+
+ // CHECK-PROP: return %[[SINGLE_RES]], %[[SINGLE_RES]] : vector<1xf32>, vector<1xf32>
+ return %r#0, %r#1 : vector<1xf32>, vector<1xf32>
+}
+
+// -----
+
+// An unused result that yields the same value as a used result must not
+// reserve a result position in the deduplicated warp op.
+// CHECK-PROP: func @dedup_unused_result
+func.func @dedup_unused_result(%laneid: index, %v0: vector<4xf32>)
+ -> (vector<1xf32>) {
+
+ // CHECK-PROP: %[[UNIQUE_RES:.*]] = vector.warp_execute_on_lane_0{{.*}} -> (vector<1xf32>) {
+ %r:2 = vector.warp_execute_on_lane_0(%laneid)[32]
+ args(%v0 : vector<4xf32>) -> (vector<1xf32>, vector<1xf32>) {
+ ^bb0(%arg0: vector<128xf32>):
+
+ // CHECK-PROP: %[[UNIQUE_VAL:.*]] = "some_def"(%{{.*}}) : (vector<128xf32>) -> vector<32xf32>
+ %2 = "some_def"(%arg0) : (vector<128xf32>) -> vector<32xf32>
+
+ // CHECK-PROP: vector.yield %[[UNIQUE_VAL]] : vector<32xf32>
+ vector.yield %2, %2 : vector<32xf32>, vector<32xf32>
+ }
+
+ // CHECK-PROP: return %[[UNIQUE_RES]] : vector<1xf32>
+ return %r#1 : vector<1xf32>
+}
More information about the Mlir-commits mailing list