[Mlir-commits] [mlir] 40f56c8 - [mlir] [VectorOps] Replace zero-scalar + splat into direct zero vector constant

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon May 11 20:20:51 PDT 2020


Author: aartbik
Date: 2020-05-11T20:20:37-07:00
New Revision: 40f56c8cf189249465997dad8bce413b71ccbef0

URL: https://github.com/llvm/llvm-project/commit/40f56c8cf189249465997dad8bce413b71ccbef0
DIFF: https://github.com/llvm/llvm-project/commit/40f56c8cf189249465997dad8bce413b71ccbef0.diff

LOG: [mlir] [VectorOps] Replace zero-scalar + splat with direct zero vector constant

Summary:
Materializing a scalar zero and then splatting it yields more intermediate
code than a direct dense zero vector constant, yet both are ultimately
lowered to exactly the same LLVM IR operations, so there is no point in
generating the extra intermediate ops.

Reviewers: nicolasvasilache, andydavis1, reidtatge

Reviewed By: nicolasvasilache

Subscribers: mehdi_amini, rriddle, jpienaar, shauheen, antiagainst, nicolasvasilache, arpith-jacob, mgester, lucyrfox, liufengdb, stephenneuendorffer, Joonsoo, grosul1, frgossen, Kayjukh, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D79758

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/VectorTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index bc309f0c89c0..6e3e681ad815 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -929,10 +929,10 @@ class ExtractSlicesOpLowering
 /// One:
 ///   %x = vector.insert_slices %0
 /// is replaced by:
-///   %r0 = vector.splat 0
-//    %t1 = vector.tuple_get %0, 0
+///   %r0 = zero-result
+///   %t1 = vector.tuple_get %0, 0
 ///   %r1 = vector.insert_strided_slice %r0, %t1
-//    %t2 = vector.tuple_get %0, 1
+///   %t2 = vector.tuple_get %0, 1
 ///   %r2 = vector.insert_strided_slice %r1, %t2
 ///   ..
 ///   %x  = ..
@@ -953,10 +953,8 @@ class InsertSlicesOpLowering : public OpRewritePattern<vector::InsertSlicesOp> {
     op.getStrides(strides); // all-ones at the moment
 
     // Prepare result.
-    auto elemType = vectorType.getElementType();
-    Value zero = rewriter.create<ConstantOp>(loc, elemType,
-                                             rewriter.getZeroAttr(elemType));
-    Value result = rewriter.create<SplatOp>(loc, vectorType, zero);
+    Value result = rewriter.create<ConstantOp>(
+        loc, vectorType, rewriter.getZeroAttr(vectorType));
 
     // For each element in the tuple, extract the proper strided slice.
     TupleType tupleType = op.getSourceTupleType();
@@ -1015,9 +1013,8 @@ class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
           VectorType::get(dstType.getShape().drop_front(), eltType);
       Value bcst =
           rewriter.create<vector::BroadcastOp>(loc, resType, op.source());
-      Value zero = rewriter.create<ConstantOp>(loc, eltType,
-                                               rewriter.getZeroAttr(eltType));
-      Value result = rewriter.create<SplatOp>(loc, dstType, zero);
+      Value result = rewriter.create<ConstantOp>(loc, dstType,
+                                                 rewriter.getZeroAttr(dstType));
       for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
         result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
       rewriter.replaceOp(op, result);
@@ -1064,9 +1061,8 @@ class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
     //   %x = [%a,%b,%c,%d]
     VectorType resType =
         VectorType::get(dstType.getShape().drop_front(), eltType);
-    Value zero = rewriter.create<ConstantOp>(loc, eltType,
-                                             rewriter.getZeroAttr(eltType));
-    Value result = rewriter.create<SplatOp>(loc, dstType, zero);
+    Value result = rewriter.create<ConstantOp>(loc, dstType,
+                                               rewriter.getZeroAttr(dstType));
     if (m == 0) {
       // Stetch at start.
       Value ext = rewriter.create<vector::ExtractOp>(loc, op.source(), 0);
@@ -1104,7 +1100,6 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
     auto loc = op.getLoc();
 
     VectorType resType = op.getResultType();
-    Type eltType = resType.getElementType();
 
     // Set up convenience transposition table.
     SmallVector<int64_t, 4> transp;
@@ -1112,9 +1107,8 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
       transp.push_back(attr.cast<IntegerAttr>().getInt());
 
     // Generate fully unrolled extract/insert ops.
-    Value zero = rewriter.create<ConstantOp>(loc, eltType,
-                                             rewriter.getZeroAttr(eltType));
-    Value result = rewriter.create<SplatOp>(loc, resType, zero);
+    Value result = rewriter.create<ConstantOp>(loc, resType,
+                                               rewriter.getZeroAttr(resType));
     SmallVector<int64_t, 4> lhs(transp.size(), 0);
     SmallVector<int64_t, 4> rhs(transp.size(), 0);
     rewriter.replaceOp(op, expandIndices(loc, resType, 0, transp, lhs, rhs,
@@ -1173,9 +1167,8 @@ class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
     Type eltType = resType.getElementType();
     Value acc = (op.acc().empty()) ? nullptr : op.acc()[0];
 
-    Value zero = rewriter.create<ConstantOp>(loc, eltType,
-                                             rewriter.getZeroAttr(eltType));
-    Value result = rewriter.create<SplatOp>(loc, resType, zero);
+    Value result = rewriter.create<ConstantOp>(loc, resType,
+                                               rewriter.getZeroAttr(resType));
     for (int64_t d = 0, e = resType.getDimSize(0); d < e; ++d) {
       auto pos = rewriter.getI64ArrayAttr(d);
       Value x = rewriter.create<vector::ExtractOp>(loc, eltType, op.lhs(), pos);
@@ -1346,7 +1339,8 @@ class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
         rewriter.getArrayAttr(adjustIter(op.iterator_types(), iterIndex));
     // Unroll into a series of lower dimensional vector.contract ops.
     Location loc = op.getLoc();
-    Value result = zeroVector(loc, resType, rewriter);
+    Value result = rewriter.create<ConstantOp>(loc, resType,
+                                               rewriter.getZeroAttr(resType));
     for (int64_t d = 0; d < dimSize; ++d) {
       auto lhs = reshapeLoad(loc, op.lhs(), lhsType, lhsIndex, d, rewriter);
       auto rhs = reshapeLoad(loc, op.rhs(), rhsType, rhsIndex, d, rewriter);
@@ -1381,7 +1375,8 @@ class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
     // Base case.
     if (lhsType.getRank() == 1) {
       assert(rhsType.getRank() == 1 && "corrupt contraction");
-      Value zero = zeroVector(loc, lhsType, rewriter);
+      Value zero = rewriter.create<ConstantOp>(loc, lhsType,
+                                               rewriter.getZeroAttr(lhsType));
       Value fma = rewriter.create<vector::FMAOp>(loc, op.lhs(), op.rhs(), zero);
       StringAttr kind = rewriter.getStringAttr("add");
       return rewriter.create<vector::ReductionOp>(loc, resType, kind, fma,
@@ -1409,15 +1404,6 @@ class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
     return result;
   }
 
-  // Helper method to construct a zero vector.
-  static Value zeroVector(Location loc, VectorType vType,
-                          PatternRewriter &rewriter) {
-    Type eltType = vType.getElementType();
-    Value zero = rewriter.create<ConstantOp>(loc, eltType,
-                                             rewriter.getZeroAttr(eltType));
-    return rewriter.create<SplatOp>(loc, vType, zero);
-  }
-
   // Helper to find an index in an affine map.
   static Optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
     for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
@@ -1493,7 +1479,8 @@ class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
     // Unroll leading dimensions.
     VectorType vType = lowType.cast<VectorType>();
     VectorType resType = adjustType(type, index).cast<VectorType>();
-    Value result = zeroVector(loc, resType, rewriter);
+    Value result = rewriter.create<ConstantOp>(loc, resType,
+                                               rewriter.getZeroAttr(resType));
     for (int64_t d = 0, e = resType.getDimSize(0); d < e; d++) {
       auto posAttr = rewriter.getI64ArrayAttr(d);
       Value ext = rewriter.create<vector::ExtractOp>(loc, vType, val, posAttr);
@@ -1555,10 +1542,8 @@ class ShapeCastOp2DDownCastRewritePattern
       return failure();
 
     auto loc = op.getLoc();
-    auto elemType = sourceVectorType.getElementType();
-    Value zero = rewriter.create<ConstantOp>(loc, elemType,
-                                             rewriter.getZeroAttr(elemType));
-    Value desc = rewriter.create<SplatOp>(loc, resultVectorType, zero);
+    Value desc = rewriter.create<ConstantOp>(
+        loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
     unsigned mostMinorVectorSize = sourceVectorType.getShape()[1];
     for (int64_t i = 0, e = sourceVectorType.getShape().front(); i != e; ++i) {
       Value vec = rewriter.create<vector::ExtractOp>(loc, op.source(), i);
@@ -1589,10 +1574,8 @@ class ShapeCastOp2DUpCastRewritePattern
       return failure();
 
     auto loc = op.getLoc();
-    auto elemType = sourceVectorType.getElementType();
-    Value zero = rewriter.create<ConstantOp>(loc, elemType,
-                                             rewriter.getZeroAttr(elemType));
-    Value desc = rewriter.create<SplatOp>(loc, resultVectorType, zero);
+    Value desc = rewriter.create<ConstantOp>(
+        loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
     unsigned mostMinorVectorSize = resultVectorType.getShape()[1];
     for (int64_t i = 0, e = resultVectorType.getShape().front(); i != e; ++i) {
       Value vec = rewriter.create<vector::ExtractStridedSliceOp>(


        


More information about the Mlir-commits mailing list