[Mlir-commits] [mlir] bf068e1 - [mlir] Do not use pass labels in unrolled ProgressiveVectorToSCF

Matthias Springer llvmlistbot at llvm.org
Thu May 13 06:16:35 PDT 2021


Author: Matthias Springer
Date: 2021-05-13T22:01:08+09:00
New Revision: bf068e1077a44fcb52fdf2aeb8f03f80517b64ab

URL: https://github.com/llvm/llvm-project/commit/bf068e1077a44fcb52fdf2aeb8f03f80517b64ab
DIFF: https://github.com/llvm/llvm-project/commit/bf068e1077a44fcb52fdf2aeb8f03f80517b64ab.diff

LOG: [mlir] Do not use pass labels in unrolled ProgressiveVectorToSCF

Do not rely on pass labels to detect whether the pattern was already applied in the past (which enables some extra optimizations that avoid creating additional InsertOps and ExtractOps). Instead, check on the fly whether these optimizations can be applied.

This also fixes a bug where vector.insert and vector.extract ops sometimes disappeared in the middle of the pass because they were folded away, while the next application of the pattern expected them to still be there.

Differential Revision: https://reviews.llvm.org/D102206

Added: 
    

Modified: 
    mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp
    mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-3d.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp
index 7b976cc3c2a5..c62939243186 100644
--- a/mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp
@@ -724,35 +724,39 @@ static void maybeAssignMask(OpBuilder &builder, OpTy xferOp, OpTy newXferOp,
 /// %vec = vector.insert %tmp1, %v3[4] : vector<4xf32> into vector<5x4xf32>
 /// ```
 ///
-/// Note: A pass label is attached to new TransferReadOps, so that subsequent
-/// applications of this pattern do not create an additional %v_init vector.
+/// Note: As an optimization, if the result of the original TransferReadOp
+/// was directly inserted into another vector, no new %v_init vector is created.
+/// Instead, the new TransferReadOp results are inserted into that vector.
 struct UnrollTransferReadConversion : public OpRewritePattern<TransferReadOp> {
   using OpRewritePattern<TransferReadOp>::OpRewritePattern;
 
-  /// Find the result vector %v_init or create a new vector if this the first
-  /// application of the pattern.
+  /// Return the vector into which the newly created TransferReadOp results
+  /// are inserted.
   Value getResultVector(TransferReadOp xferOp,
                         PatternRewriter &rewriter) const {
-    if (xferOp->hasAttr(kPassLabel)) {
-      return getInsertOp(xferOp).dest();
-    }
+    if (auto insertOp = getInsertOp(xferOp))
+      return insertOp.dest();
     return std_splat(xferOp.getVectorType(), xferOp.padding()).value;
   }
 
-  /// Assuming that this not the first application of the pattern, return the
-  /// vector.insert op in which the result of this transfer op is used.
+  /// If the result of the TransferReadOp has exactly one user, which is a
+  /// vector::InsertOp, return that operation.
   vector::InsertOp getInsertOp(TransferReadOp xferOp) const {
-    Operation *xferOpUser = *xferOp->getUsers().begin();
-    return dyn_cast<vector::InsertOp>(xferOpUser);
+    if (xferOp->hasOneUse()) {
+      Operation *xferOpUser = *xferOp->getUsers().begin();
+      if (auto insertOp = dyn_cast<vector::InsertOp>(xferOpUser))
+        return insertOp;
+    }
+
+    return vector::InsertOp();
   }
 
-  /// Assuming that this not the first application of the pattern, return the
-  /// indices of the vector.insert op in which the result of this transfer op
-  /// is used.
+  /// If the result of the TransferReadOp has exactly one user, which is a
+  /// vector::InsertOp, return that operation's indices.
   void getInsertionIndices(TransferReadOp xferOp,
                            SmallVector<int64_t, 8> &indices) const {
-    if (xferOp->hasAttr(kPassLabel)) {
-      llvm::for_each(getInsertOp(xferOp).position(), [&](Attribute attr) {
+    if (auto insertOp = getInsertOp(xferOp)) {
+      llvm::for_each(insertOp.position(), [&](Attribute attr) {
         indices.push_back(attr.dyn_cast<IntegerAttr>().getInt());
       });
     }
@@ -766,6 +770,7 @@ struct UnrollTransferReadConversion : public OpRewritePattern<TransferReadOp> {
       return failure();
 
     ScopedContext scope(rewriter, xferOp.getLoc());
+    auto insertOp = getInsertOp(xferOp);
     auto vec = getResultVector(xferOp, rewriter);
     auto vecType = vec.getType().dyn_cast<VectorType>();
     auto xferVecType = xferOp.getVectorType();
@@ -803,7 +808,6 @@ struct UnrollTransferReadConversion : public OpRewritePattern<TransferReadOp> {
                 dyn_cast<TransferReadOp>(newXferOpVal.getDefiningOp());
 
             maybeAssignMask(b, xferOp, newXferOp, i);
-            maybeApplyPassLabel(b, newXferOp);
 
             return vector_insert(newXferOp, vec, insertionIndices).value;
           },
@@ -814,8 +818,9 @@ struct UnrollTransferReadConversion : public OpRewritePattern<TransferReadOp> {
           });
     }
 
-    if (xferOp->hasAttr(kPassLabel)) {
-      rewriter.replaceOp(getInsertOp(xferOp), vec);
+    if (insertOp) {
+      // Rewrite single user of the old TransferReadOp, which was an InsertOp.
+      rewriter.replaceOp(insertOp, vec);
       rewriter.eraseOp(xferOp);
     } else {
       rewriter.replaceOp(xferOp, vec);
@@ -846,32 +851,33 @@ struct UnrollTransferReadConversion : public OpRewritePattern<TransferReadOp> {
 /// vector.transfer_write %v4, %A[%a, %b + 4, %c] : vector<4xf32>, memref<...>
 /// ```
 ///
-/// Note: A pass label is attached to new TransferWriteOps, so that subsequent
-/// applications of this pattern can read the indices of previously generated
-/// vector.extract ops.
+/// Note: As an optimization, if the vector of the original TransferWriteOp
+/// was directly extracted from another vector via an ExtractOp `a`, extract
+/// the vectors for the newly generated TransferWriteOps from `a`'s input. By
+/// doing so, `a` may become dead, and the number of ExtractOps generated during
+/// recursive application of this pattern will be minimal.
 struct UnrollTransferWriteConversion
     : public OpRewritePattern<TransferWriteOp> {
   using OpRewritePattern<TransferWriteOp>::OpRewritePattern;
 
-  /// If this is not the first application of the pattern, find the original
-  /// vector %vec that is written by this transfer op. Otherwise, return the
-  /// vector of this transfer op.
+  /// Return the vector from which newly generated ExtractOps will extract.
   Value getDataVector(TransferWriteOp xferOp) const {
-    if (xferOp->hasAttr(kPassLabel))
-      return getExtractOp(xferOp).vector();
+    if (auto extractOp = getExtractOp(xferOp))
+      return extractOp.vector();
     return xferOp.vector();
   }
 
-  /// Assuming that this is not the first application of the pattern, find the
-  /// vector.extract op whose result is written by this transfer op.
+  /// If the input of the given TransferWriteOp is an ExtractOp, return it.
   vector::ExtractOp getExtractOp(TransferWriteOp xferOp) const {
     return dyn_cast<vector::ExtractOp>(xferOp.vector().getDefiningOp());
   }
 
+  /// If the input of the given TransferWriteOp is an ExtractOp, return its
+  /// indices.
   void getExtractionIndices(TransferWriteOp xferOp,
                             SmallVector<int64_t, 8> &indices) const {
-    if (xferOp->hasAttr(kPassLabel)) {
-      llvm::for_each(getExtractOp(xferOp).position(), [&](Attribute attr) {
+    if (auto extractOp = getExtractOp(xferOp)) {
+      llvm::for_each(extractOp.position(), [&](Attribute attr) {
         indices.push_back(attr.dyn_cast<IntegerAttr>().getInt());
       });
     }
@@ -918,7 +924,6 @@ struct UnrollTransferWriteConversion
                     .op;
 
             maybeAssignMask(b, xferOp, newXferOp, i);
-            maybeApplyPassLabel(b, newXferOp);
           });
     }
 

diff  --git a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-3d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-3d.mlir
index 902f4f50d223..6de89a6cd6ac 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-3d.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-3d.mlir
@@ -4,7 +4,7 @@
 // RUN: FileCheck %s
 
 // RUN: mlir-opt %s -test-unrolled-progressive-convert-vector-to-scf -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
-// RUN: mlir-cpu-runner -e entry -entry-point-result=void  \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
 // RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
 // RUN: FileCheck %s
 
@@ -17,6 +17,17 @@ func @transfer_read_3d(%A : memref<?x?x?x?xf32>,
   return
 }
 
+func @transfer_read_3d_and_extract(%A : memref<?x?x?x?xf32>,
+                                   %o: index, %a: index, %b: index, %c: index) {
+  %fm42 = constant -42.0: f32
+  %f = vector.transfer_read %A[%o, %a, %b, %c], %fm42
+      {in_bounds = [true, true, true]}
+      : memref<?x?x?x?xf32>, vector<2x5x3xf32>
+  %sub = vector.extract %f[0] : vector<2x5x3xf32>
+  vector.print %sub: vector<5x3xf32>
+  return
+}
+
 func @transfer_read_3d_broadcast(%A : memref<?x?x?x?xf32>,
                                  %o: index, %a: index, %b: index, %c: index) {
   %fm42 = constant -42.0: f32
@@ -94,26 +105,31 @@ func @entry() {
       : (memref<?x?x?x?xf32>, index, index, index, index) -> ()
   // CHECK: ( ( ( 0, 0, -42 ), ( 2, 3, -42 ), ( 4, 6, -42 ), ( 6, 9, -42 ), ( -42, -42, -42 ) ), ( ( 20, 30, -42 ), ( 22, 33, -42 ), ( 24, 36, -42 ), ( 26, 39, -42 ), ( -42, -42, -42 ) ) )
 
-  // 2. Write 3D vector to 4D memref.
+  // 2. Read 3D vector from 4D memref and extract subvector from result.
+  call @transfer_read_3d_and_extract(%A, %c0, %c0, %c0, %c0)
+      : (memref<?x?x?x?xf32>, index, index, index, index) -> ()
+  // CHECK: ( ( 0, 0, 2 ), ( 2, 3, 4 ), ( 4, 6, 6 ), ( 6, 9, 20 ), ( 20, 30, 22 ) )
+
+  // 3. Write 3D vector to 4D memref.
   call @transfer_write_3d(%A, %c0, %c0, %c1, %c1)
       : (memref<?x?x?x?xf32>, index, index, index, index) -> ()
 
-  // 3. Read memref to verify step 2.
+  // 4. Read memref to verify step 3.
   call @transfer_read_3d(%A, %c0, %c0, %c0, %c0)
       : (memref<?x?x?x?xf32>, index, index, index, index) -> ()
   // CHECK: ( ( ( 0, 0, -42 ), ( 2, -1, -42 ), ( 4, -1, -42 ), ( 6, -1, -42 ), ( -42, -42, -42 ) ), ( ( 20, 30, -42 ), ( 22, -1, -42 ), ( 24, -1, -42 ), ( 26, -1, -42 ), ( -42, -42, -42 ) ) )
 
-  // 4. Read 3D vector from 4D memref and transpose vector.
+  // 5. Read 3D vector from 4D memref and transpose vector.
   call @transfer_read_3d_transposed(%A, %c0, %c0, %c0, %c0)
       : (memref<?x?x?x?xf32>, index, index, index, index) -> ()
   // CHECK: ( ( ( 0, 20, 40 ), ( 0, 20, 40 ), ( 0, 20, 40 ), ( 0, 20, 40 ), ( 0, 20, 40 ) ), ( ( 0, 30, 60 ), ( 0, 30, 60 ), ( 0, 30, 60 ), ( 0, 30, 60 ), ( 0, 30, 60 ) ), ( ( -42, -42, -42 ), ( -42, -42, -42 ), ( -42, -42, -42 ), ( -42, -42, -42 ), ( -42, -42, -42 ) ) )
 
-  // 5. Read 1D vector from 4D memref and broadcast vector to 3D.
+  // 6. Read 1D vector from 4D memref and broadcast vector to 3D.
   call @transfer_read_3d_broadcast(%A, %c0, %c0, %c0, %c0)
       : (memref<?x?x?x?xf32>, index, index, index, index) -> ()
   // CHECK: ( ( ( 0, 0, -42 ), ( 0, 0, -42 ), ( 0, 0, -42 ), ( 0, 0, -42 ), ( 0, 0, -42 ) ), ( ( 20, 30, -42 ), ( 20, 30, -42 ), ( 20, 30, -42 ), ( 20, 30, -42 ), ( 20, 30, -42 ) ) )
 
-  // 6. Read 1D vector from 4D memref with mask and broadcast vector to 3D.
+  // 7. Read 1D vector from 4D memref with mask and broadcast vector to 3D.
   call @transfer_read_3d_mask_broadcast(%A, %c0, %c0, %c0, %c0)
       : (memref<?x?x?x?xf32>, index, index, index, index) -> ()
   // CHECK: ( ( ( -42, -42, -42 ), ( -42, -42, -42 ), ( -42, -42, -42 ), ( -42, -42, -42 ), ( -42, -42, -42 ) ), ( ( 20, 20, 20 ), ( 20, 20, 20 ), ( 20, 20, 20 ), ( 20, 20, 20 ), ( 20, 20, 20 ) ) )


        


More information about the Mlir-commits mailing list