[Mlir-commits] [mlir] 40f56c8 - [mlir] [VectorOps] Replace zero-scalar + splat into direct zero vector constant
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon May 11 20:20:51 PDT 2020
Author: aartbik
Date: 2020-05-11T20:20:37-07:00
New Revision: 40f56c8cf189249465997dad8bce413b71ccbef0
URL: https://github.com/llvm/llvm-project/commit/40f56c8cf189249465997dad8bce413b71ccbef0
DIFF: https://github.com/llvm/llvm-project/commit/40f56c8cf189249465997dad8bce413b71ccbef0.diff
LOG: [mlir] [VectorOps] Replace zero-scalar + splat into direct zero vector constant
Summary:
The scalar zero + splat yields more intermediate code than the direct
dense zero constant, and both forms are ultimately lowered to exactly the
same LLVM IR operations, so there is no point in generating the extra
intermediate code.
Reviewers: nicolasvasilache, andydavis1, reidtatge
Reviewed By: nicolasvasilache
Subscribers: mehdi_amini, rriddle, jpienaar, shauheen, antiagainst, nicolasvasilache, arpith-jacob, mgester, lucyrfox, liufengdb, stephenneuendorffer, Joonsoo, grosul1, frgossen, Kayjukh, llvm-commits
Tags: #llvm
Differential Revision: https://reviews.llvm.org/D79758
Added:
Modified:
mlir/lib/Dialect/Vector/VectorTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index bc309f0c89c0..6e3e681ad815 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -929,10 +929,10 @@ class ExtractSlicesOpLowering
/// One:
/// %x = vector.insert_slices %0
/// is replaced by:
-/// %r0 = vector.splat 0
-// %t1 = vector.tuple_get %0, 0
+/// %r0 = zero-result
+/// %t1 = vector.tuple_get %0, 0
/// %r1 = vector.insert_strided_slice %r0, %t1
-// %t2 = vector.tuple_get %0, 1
+/// %t2 = vector.tuple_get %0, 1
/// %r2 = vector.insert_strided_slice %r1, %t2
/// ..
/// %x = ..
@@ -953,10 +953,8 @@ class InsertSlicesOpLowering : public OpRewritePattern<vector::InsertSlicesOp> {
op.getStrides(strides); // all-ones at the moment
// Prepare result.
- auto elemType = vectorType.getElementType();
- Value zero = rewriter.create<ConstantOp>(loc, elemType,
- rewriter.getZeroAttr(elemType));
- Value result = rewriter.create<SplatOp>(loc, vectorType, zero);
+ Value result = rewriter.create<ConstantOp>(
+ loc, vectorType, rewriter.getZeroAttr(vectorType));
// For each element in the tuple, extract the proper strided slice.
TupleType tupleType = op.getSourceTupleType();
@@ -1015,9 +1013,8 @@ class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
VectorType::get(dstType.getShape().drop_front(), eltType);
Value bcst =
rewriter.create<vector::BroadcastOp>(loc, resType, op.source());
- Value zero = rewriter.create<ConstantOp>(loc, eltType,
- rewriter.getZeroAttr(eltType));
- Value result = rewriter.create<SplatOp>(loc, dstType, zero);
+ Value result = rewriter.create<ConstantOp>(loc, dstType,
+ rewriter.getZeroAttr(dstType));
for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
rewriter.replaceOp(op, result);
@@ -1064,9 +1061,8 @@ class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
// %x = [%a,%b,%c,%d]
VectorType resType =
VectorType::get(dstType.getShape().drop_front(), eltType);
- Value zero = rewriter.create<ConstantOp>(loc, eltType,
- rewriter.getZeroAttr(eltType));
- Value result = rewriter.create<SplatOp>(loc, dstType, zero);
+ Value result = rewriter.create<ConstantOp>(loc, dstType,
+ rewriter.getZeroAttr(dstType));
if (m == 0) {
// Stretch at start.
Value ext = rewriter.create<vector::ExtractOp>(loc, op.source(), 0);
@@ -1104,7 +1100,6 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
auto loc = op.getLoc();
VectorType resType = op.getResultType();
- Type eltType = resType.getElementType();
// Set up convenience transposition table.
SmallVector<int64_t, 4> transp;
@@ -1112,9 +1107,8 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
transp.push_back(attr.cast<IntegerAttr>().getInt());
// Generate fully unrolled extract/insert ops.
- Value zero = rewriter.create<ConstantOp>(loc, eltType,
- rewriter.getZeroAttr(eltType));
- Value result = rewriter.create<SplatOp>(loc, resType, zero);
+ Value result = rewriter.create<ConstantOp>(loc, resType,
+ rewriter.getZeroAttr(resType));
SmallVector<int64_t, 4> lhs(transp.size(), 0);
SmallVector<int64_t, 4> rhs(transp.size(), 0);
rewriter.replaceOp(op, expandIndices(loc, resType, 0, transp, lhs, rhs,
@@ -1173,9 +1167,8 @@ class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
Type eltType = resType.getElementType();
Value acc = (op.acc().empty()) ? nullptr : op.acc()[0];
- Value zero = rewriter.create<ConstantOp>(loc, eltType,
- rewriter.getZeroAttr(eltType));
- Value result = rewriter.create<SplatOp>(loc, resType, zero);
+ Value result = rewriter.create<ConstantOp>(loc, resType,
+ rewriter.getZeroAttr(resType));
for (int64_t d = 0, e = resType.getDimSize(0); d < e; ++d) {
auto pos = rewriter.getI64ArrayAttr(d);
Value x = rewriter.create<vector::ExtractOp>(loc, eltType, op.lhs(), pos);
@@ -1346,7 +1339,8 @@ class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
rewriter.getArrayAttr(adjustIter(op.iterator_types(), iterIndex));
// Unroll into a series of lower dimensional vector.contract ops.
Location loc = op.getLoc();
- Value result = zeroVector(loc, resType, rewriter);
+ Value result = rewriter.create<ConstantOp>(loc, resType,
+ rewriter.getZeroAttr(resType));
for (int64_t d = 0; d < dimSize; ++d) {
auto lhs = reshapeLoad(loc, op.lhs(), lhsType, lhsIndex, d, rewriter);
auto rhs = reshapeLoad(loc, op.rhs(), rhsType, rhsIndex, d, rewriter);
@@ -1381,7 +1375,8 @@ class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
// Base case.
if (lhsType.getRank() == 1) {
assert(rhsType.getRank() == 1 && "corrupt contraction");
- Value zero = zeroVector(loc, lhsType, rewriter);
+ Value zero = rewriter.create<ConstantOp>(loc, lhsType,
+ rewriter.getZeroAttr(lhsType));
Value fma = rewriter.create<vector::FMAOp>(loc, op.lhs(), op.rhs(), zero);
StringAttr kind = rewriter.getStringAttr("add");
return rewriter.create<vector::ReductionOp>(loc, resType, kind, fma,
@@ -1409,15 +1404,6 @@ class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
return result;
}
- // Helper method to construct a zero vector.
- static Value zeroVector(Location loc, VectorType vType,
- PatternRewriter &rewriter) {
- Type eltType = vType.getElementType();
- Value zero = rewriter.create<ConstantOp>(loc, eltType,
- rewriter.getZeroAttr(eltType));
- return rewriter.create<SplatOp>(loc, vType, zero);
- }
-
// Helper to find an index in an affine map.
static Optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
@@ -1493,7 +1479,8 @@ class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
// Unroll leading dimensions.
VectorType vType = lowType.cast<VectorType>();
VectorType resType = adjustType(type, index).cast<VectorType>();
- Value result = zeroVector(loc, resType, rewriter);
+ Value result = rewriter.create<ConstantOp>(loc, resType,
+ rewriter.getZeroAttr(resType));
for (int64_t d = 0, e = resType.getDimSize(0); d < e; d++) {
auto posAttr = rewriter.getI64ArrayAttr(d);
Value ext = rewriter.create<vector::ExtractOp>(loc, vType, val, posAttr);
@@ -1555,10 +1542,8 @@ class ShapeCastOp2DDownCastRewritePattern
return failure();
auto loc = op.getLoc();
- auto elemType = sourceVectorType.getElementType();
- Value zero = rewriter.create<ConstantOp>(loc, elemType,
- rewriter.getZeroAttr(elemType));
- Value desc = rewriter.create<SplatOp>(loc, resultVectorType, zero);
+ Value desc = rewriter.create<ConstantOp>(
+ loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
unsigned mostMinorVectorSize = sourceVectorType.getShape()[1];
for (int64_t i = 0, e = sourceVectorType.getShape().front(); i != e; ++i) {
Value vec = rewriter.create<vector::ExtractOp>(loc, op.source(), i);
@@ -1589,10 +1574,8 @@ class ShapeCastOp2DUpCastRewritePattern
return failure();
auto loc = op.getLoc();
- auto elemType = sourceVectorType.getElementType();
- Value zero = rewriter.create<ConstantOp>(loc, elemType,
- rewriter.getZeroAttr(elemType));
- Value desc = rewriter.create<SplatOp>(loc, resultVectorType, zero);
+ Value desc = rewriter.create<ConstantOp>(
+ loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
unsigned mostMinorVectorSize = resultVectorType.getShape()[1];
for (int64_t i = 0, e = resultVectorType.getShape().front(); i != e; ++i) {
Value vec = rewriter.create<vector::ExtractStridedSliceOp>(
More information about the Mlir-commits
mailing list