[Mlir-commits] [mlir] 0af2680 - [mlir][vector] Add pattern to distribute splat constant
Thomas Raoux
llvmlistbot at llvm.org
Mon Jul 11 08:52:20 PDT 2022
Author: Thomas Raoux
Date: 2022-07-11T15:50:26Z
New Revision: 0af268059636647798b00bd85dc4faecf537ce52
URL: https://github.com/llvm/llvm-project/commit/0af268059636647798b00bd85dc4faecf537ce52
DIFF: https://github.com/llvm/llvm-project/commit/0af268059636647798b00bd85dc4faecf537ce52.diff
LOG: [mlir][vector] Add pattern to distribute splat constant
Distribute splat constants out of the WarpExecuteOnLane0Op region.
Differential Revision: https://reviews.llvm.org/D129467
Added:
Modified:
mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
mlir/test/Dialect/Vector/vector-warp-distribute.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 57fa863320906..1fb7a215a285f 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -524,6 +524,60 @@ struct WarpOpElementwise : public OpRewritePattern<WarpExecuteOnLane0Op> {
}
};
+/// Sink out a splat constant op feeding into a warp op yield. The constant is
+/// recreated after the warp op with the distributed (per-lane) result type,
+/// and all uses of the corresponding warp op result are redirected to it.
+/// ```
+/// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
+///   ...
+///   %cst = arith.constant dense<2.0> : vector<32xf32>
+///   vector.yield %cst : vector<32xf32>
+/// }
+/// ```
+/// To
+/// ```
+/// vector.warp_execute_on_lane_0(%arg0) {
+///   ...
+/// }
+/// %0 = arith.constant dense<2.0> : vector<1xf32>
+/// ```
+/// Only splat constants are handled: a non-splat constant would need a
+/// different slice in each lane, so such yield operands are left untouched.
+struct WarpOpConstant : public OpRewritePattern<WarpExecuteOnLane0Op> {
+  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
+  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    // Test the splat property inside the predicate: getWarpResult hands back
+    // a single matching operand, so rejecting a non-splat constant after the
+    // fact could mask a splat constant yielded by another operand.
+    OpOperand *yieldOperand = getWarpResult(warpOp, [](Operation *op) {
+      auto constantOp = dyn_cast<arith::ConstantOp>(op);
+      return constantOp && constantOp.getValue().isa<SplatElementsAttr>();
+    });
+    if (!yieldOperand)
+      return failure();
+    auto dense = yieldOperand->get()
+                     .getDefiningOp<arith::ConstantOp>()
+                     .getValue()
+                     .cast<SplatElementsAttr>();
+    unsigned operandIndex = yieldOperand->getOperandNumber();
+    Value oldResult = warpOp.getResult(operandIndex);
+    // Every lane sees the same splat value, so the per-lane constant is the
+    // same scalar splatted into the distributed result type.
+    Attribute scalarAttr = dense.getSplatValue<Attribute>();
+    Attribute newAttr = DenseElementsAttr::get(
+        oldResult.getType().cast<ShapedType>(), scalarAttr);
+    rewriter.setInsertionPointAfter(warpOp);
+    Value distConstant =
+        rewriter.create<arith::ConstantOp>(warpOp.getLoc(), newAttr);
+    // Redirect the uses through the rewriter so the pattern driver is
+    // notified of every user it has to revisit.
+    for (OpOperand &use : llvm::make_early_inc_range(oldResult.getUses()))
+      rewriter.updateRootInPlace(use.getOwner(), [&] { use.set(distConstant); });
+    return success();
+  }
+};
+
/// Sink out transfer_read op feeding into a warp op yield.
/// ```
/// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
@@ -868,8 +906,9 @@ void mlir::vector::populateDistributeTransferWriteOpPatterns(
void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
RewritePatternSet &patterns) {
patterns.add<WarpOpElementwise, WarpOpTransferRead, WarpOpDeadResult,
- WarpOpBroadcast, WarpOpForwardOperand, WarpOpScfForOp>(
- patterns.getContext());
+               WarpOpBroadcast, WarpOpForwardOperand, WarpOpScfForOp,
+               // Sinks splat constants yielded by the warp op out of its region.
+               WarpOpConstant>(patterns.getContext());
}
void mlir::vector::populateDistributeReduction(
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 4a04f988be979..55a8490049d8d 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -562,3 +562,18 @@ func.func @warp_duplicate_yield(%laneid: index) -> (vector<1xf32>, vector<1xf32>
}
return %r#0, %r#1 : vector<1xf32>, vector<1xf32>
}
+
+// -----
+
+// A splat constant yielded from the warp region is rematerialized outside it
+// with the distributed type.
+// CHECK-PROP-LABEL: func @warp_constant(
+// CHECK-PROP: %[[C:.*]] = arith.constant dense<2.000000e+00> : vector<1xf32>
+// CHECK-PROP: return %[[C]] : vector<1xf32>
+func.func @warp_constant(%laneid: index) -> (vector<1xf32>) {
+  %res = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
+    %splat = arith.constant dense<2.0> : vector<32xf32>
+    vector.yield %splat : vector<32xf32>
+  }
+  return %res : vector<1xf32>
+}
More information about the Mlir-commits mailing list