[Mlir-commits] [mlir] 7006daa - [MLIR][Vector] Update ShapeCastOp folder to use producer-consumer value forwarding.

Andy Davis llvmlistbot at llvm.org
Wed Apr 8 08:55:45 PDT 2020


Author: Andy Davis
Date: 2020-04-08T08:55:37-07:00
New Revision: 7006daa548c25960dbb5a50e9b9987d4dd01798b

URL: https://github.com/llvm/llvm-project/commit/7006daa548c25960dbb5a50e9b9987d4dd01798b
DIFF: https://github.com/llvm/llvm-project/commit/7006daa548c25960dbb5a50e9b9987d4dd01798b.diff

LOG: [MLIR][Vector] Update ShapeCastOp folder to use producer-consumer value forwarding.

Summary:
Update ShapeCastOp folder to use producer-consumer value forwarding.
Support is added for tracking sub-vectors through trivial shape cast operations,
where the sub-vector shape is preserved across the shape cast and only leading
unit (size-1) dimensions are added or removed.
Folding of pairs of cancelling shape cast operations (e.g. A->B followed by B->A) is still supported.
One unit test is added and two are updated.

Reviewers: aartbik, nicolasvasilache

Reviewed By: aartbik, nicolasvasilache

Subscribers: frgossen, mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, arpith-jacob, mgester, lucyrfox, liufengdb, Joonsoo, grosul1, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D77253

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/VectorTransforms.cpp
    mlir/test/Dialect/Vector/vector-transforms.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index dbb0bf443700..7a197ef14334 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -676,10 +676,10 @@ struct ShapeCastOpDecomposer : public OpRewritePattern<vector::ShapeCastOp> {
 
 /// Returns the producer Value of the same type as 'consumerValue', by tracking
 /// the tuple index and offsets of the consumer vector value through the
-/// chain of operations (TupleGetOp, InsertSlicesOp, ExtractSlicesOp, TupleOp)
-/// from consumer to producer. Each operation in the chain is structured, and
-/// so the tuple index and offsets can be mapped from result to input, while
-/// visiting each operation in the chain.
+/// chain of operations (TupleGetOp, InsertSlicesOp, ExtractSlicesOp, TupleOp,
+/// and ShapeCastOp) from consumer to producer. Each operation in the chain is
+/// structured, and so the tuple index and offsets can be mapped from result to
+/// input, while visiting each operation in the chain.
 /// Returns nullptr on failure.
 static Value getProducerValue(Value consumerValue) {
   auto consumerVectorType = consumerValue.getType().cast<VectorType>();
@@ -760,8 +760,57 @@ static Value getProducerValue(Value consumerValue) {
       // Update 'tupleIndex' and next defining 'op' to visit.
       tupleIndex = -1;
       op = value.getDefiningOp();
+    } else if (auto shapeCastOp = dyn_cast<vector::ShapeCastOp>(op)) {
+      if (shapeCastOp.source().getType().isa<TupleType>())
+        return nullptr;
+      assert(tupleIndex == -1);
+      auto sourceVectorType = shapeCastOp.getSourceVectorType();
+      auto sourceVectorShape = sourceVectorType.getShape();
+      unsigned sourceVectorRank = sourceVectorType.getRank();
+      auto resultVectorType = shapeCastOp.getResultVectorType();
+      auto resultVectorShape = resultVectorType.getShape();
+      unsigned resultVectorRank = resultVectorType.getRank();
+
+      int i = sourceVectorRank - 1;
+      int j = resultVectorRank - 1;
+
+      // Check that source/result vector shape prefixes match while
+      // updating 'newOffsets'.
+      bool canShapeCastFold = true;
+      SmallVector<int64_t, 4> newOffsets(sourceVectorRank, 0);
+
+      auto apply = [&](int64_t sourceSize, int64_t resultSize) {
+        canShapeCastFold = sourceSize == resultSize;
+        newOffsets[i--] = offsets[j--];
+      };
+      functional::zipApply(apply, llvm::reverse(sourceVectorShape),
+                           llvm::reverse(resultVectorShape));
+      if (!canShapeCastFold)
+        return nullptr;
+
+      // Check that remaining prefix of source/result vector shapes are all 1s.
+      // Currently we only support producer/consumer tracking through trivial
+      // shape cast ops. Examples:
+      //   %1 = vector.shape_cast %0 : vector<1x1x2x4xf32> to vector<2x4xf32>
+      //   %3 = vector.shape_cast %2 : vector<16x8xf32> to vector<1x16x8xf32>
+      assert(i == -1 || j == -1);
+      if (i >= 0 &&
+          !std::all_of(sourceVectorShape.begin(), sourceVectorShape.begin() + i,
+                       [](int64_t v) { return v == 1; }))
+        return nullptr;
+      if (j >= 0 &&
+          !std::all_of(resultVectorShape.begin(), resultVectorShape.begin() + j,
+                       [](int64_t v) { return v == 1; }))
+        return nullptr;
+
+      offsets.swap(newOffsets);
+      op = shapeCastOp.source().getDefiningOp();
     } else {
-      break;
+      // Check if 'op' produces a Value with the same type as 'consumerValue'.
+      if (op->getNumResults() == 1 &&
+          op->getResult(0).getType() == consumerVectorType)
+        return op->getResult(0);
+      return nullptr;
     }
   }
   return nullptr;
@@ -788,6 +837,12 @@ struct ShapeCastOpFolder : public OpRewritePattern<vector::ShapeCastOp> {
 
   LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
                                 PatternRewriter &rewriter) const override {
+    // Check if we can replace 'shapeCastOp' result with its producer.
+    if (auto producer = getProducerValue(shapeCastOp.getResult())) {
+      rewriter.replaceOp(shapeCastOp, producer);
+      return success();
+    }
+
     // Check if 'shapeCastOp' has vector source/result type.
     auto sourceVectorType =
         shapeCastOp.source().getType().dyn_cast_or_null<VectorType>();

diff  --git a/mlir/test/Dialect/Vector/vector-transforms.mlir b/mlir/test/Dialect/Vector/vector-transforms.mlir
index 082afbaff0b0..2e4e9033fb81 100644
--- a/mlir/test/Dialect/Vector/vector-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-transforms.mlir
@@ -341,16 +341,21 @@ func @tuple_get_producer_consumer(
   %2 = vector.extract_slices %1, [4, 8], [1, 1]
     : vector<4x16xf32> into tuple<vector<4x8xf32>, vector<4x8xf32>>
   // %arg7 == %2 at tupleIndex = 1, offsets = [2, 4]
-  %3 = vector.tuple_get %2, 1 : tuple<vector<4x8xf32>, vector<4x8xf32>>
-  // %arg7 == %3 at tupleIndex = -1, offsets = [2, 4]
-  %4 = vector.extract_slices %3, [2, 4], [1, 1]
+  %3 = vector.shape_cast %2 : tuple<vector<4x8xf32>, vector<4x8xf32>> to
+                              tuple<vector<1x1x4x8xf32>, vector<1x1x4x8xf32>>
+  // %arg7 = %3 at tupleIndex = 1, offsets = [0, 0, 2, 4]
+  %4 = vector.tuple_get %3, 1 : tuple<vector<1x1x4x8xf32>, vector<1x1x4x8xf32>>
+  // %arg7 == %4 at tupleIndex = -1, offsets = [0, 0, 2, 4]
+  %5 = vector.shape_cast %4 : vector<1x1x4x8xf32> to vector<4x8xf32>
+  // %arg7 == %5 at tupleIndex = -1, offsets = [2, 4]
+  %6 = vector.extract_slices %5, [2, 4], [1, 1]
     : vector<4x8xf32> into
       tuple<vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>>
-  // %arg7 == %4 at tupleIndex = 3, offsets = [0, 0]
-  %5 = vector.tuple_get %4, 3
+  // %arg7 == %6 at tupleIndex = 3, offsets = [0, 0]
+  %7 = vector.tuple_get %6, 3
     : tuple<vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>>
-  // %arg7 == %5
-  return %5 : vector<2x4xf32>
+  // %arg7 == %7
+  return %7 : vector<2x4xf32>
 }
 
 // CHECK-LABEL: func @tuple_get_producer_consumer_swizzle
@@ -381,25 +386,40 @@ func @tuple_get_producer_consumer_swizzle(
   %2 = vector.extract_slices %1, [4, 8], [1, 1]
     : vector<4x16xf32> into tuple<vector<4x8xf32>, vector<4x8xf32>>
   // %arg7 == %2 at tupleIndex = 1, offsets = [2, 4]
+  %3= vector.shape_cast %2 : tuple<vector<4x8xf32>, vector<4x8xf32>> to
+                             tuple<vector<1x1x4x8xf32>, vector<1x1x4x8xf32>>
+  // %arg7 = %3 at tupleIndex = 1, offsets = [0, 0, 2, 4]
 
   // Extract tuple elements.
-  %3 = vector.tuple_get %2, 0 : tuple<vector<4x8xf32>, vector<4x8xf32>>
-  %4 = vector.tuple_get %2, 1 : tuple<vector<4x8xf32>, vector<4x8xf32>>
-  // %arg7 == %4 at tupleIndex = -1, offsets = [2, 4]
+  %4 = vector.tuple_get %3, 0 : tuple<vector<1x1x4x8xf32>, vector<1x1x4x8xf32>>
+  %5 = vector.tuple_get %3, 1 : tuple<vector<1x1x4x8xf32>, vector<1x1x4x8xf32>>
+  // %arg7 == %5 at tupleIndex = -1, offsets = [0, 0, 2, 4]
 
   // Swizzle tuple elements.
-  %5 = vector.tuple %4, %3 : vector<4x8xf32>, vector<4x8xf32>
-  // %arg7 == %5 at tupleIndex = 0, offsets = [2, 4]
-  %6 = vector.tuple_get %5, 0 : tuple<vector<4x8xf32>, vector<4x8xf32>>
-  // %arg7 == %6 at tupleIndex = -1, offsets = [2, 4]
-  %7 = vector.extract_slices %6, [2, 4], [1, 1]
+  %6 = vector.tuple %5, %4 : vector<1x1x4x8xf32>, vector<1x1x4x8xf32>
+  // %arg7 == %6 at tupleIndex = 0, offsets = [0, 0, 2, 4]
+  %7 = vector.shape_cast %6 : tuple<vector<1x1x4x8xf32>, vector<1x1x4x8xf32>> to
+                              tuple<vector<4x8xf32>, vector<4x8xf32>>
+  // %arg7 = %7 at tupleIndex = 0, offsets = [2, 4]
+  %8 = vector.tuple_get %7, 0 : tuple<vector<4x8xf32>, vector<4x8xf32>>
+  // %arg7 == %8 at tupleIndex = -1, offsets = [2, 4]
+  %9 = vector.extract_slices %8, [2, 4], [1, 1]
     : vector<4x8xf32> into
       tuple<vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>>
-  // %arg7 == %7 at tupleIndex = 3, offsets = [0, 0]
-  %8 = vector.tuple_get %7, 3
+  // %arg7 == %9 at tupleIndex = 3, offsets = [0, 0]
+  %10 = vector.tuple_get %9, 3
     : tuple<vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>>
-  // %arg7 == %8
-  return %8 : vector<2x4xf32>
+  // %arg7 == %10
+  return %10 : vector<2x4xf32>
+}
+
+// CHECK-LABEL: func @cancelling_shape_cast_ops
+//  CHECK-SAME: %[[A0:.*0]]: vector<2x4xf32>
+//       CHECK: return %[[A0]] : vector<2x4xf32>
+func @cancelling_shape_cast_ops(%arg0 : vector<2x4xf32>) -> vector<2x4xf32> {
+  %0 = vector.shape_cast %arg0 : vector<2x4xf32> to vector<8xf32>
+  %1 = vector.shape_cast %0 : vector<8xf32> to vector<2x4xf32>
+  return %1 : vector<2x4xf32>
 }
 
 // CHECK-LABEL: func @vector_transfers_vector_element_type


        


More information about the Mlir-commits mailing list