[Mlir-commits] [mlir] 9b77be5 - [mlir] Unrolled progressive-vector-to-scf.

Matthias Springer llvmlistbot at llvm.org
Wed May 12 21:09:09 PDT 2021


Author: Matthias Springer
Date: 2021-05-13T13:08:48+09:00
New Revision: 9b77be5583d2da03f2ccd7319d33a2daedf8b1b3

URL: https://github.com/llvm/llvm-project/commit/9b77be5583d2da03f2ccd7319d33a2daedf8b1b3
DIFF: https://github.com/llvm/llvm-project/commit/9b77be5583d2da03f2ccd7319d33a2daedf8b1b3.diff

LOG: [mlir] Unrolled progressive-vector-to-scf.

Instead of an SCF for loop, these patterns generate fully unrolled loops with no temporary buffer allocations.

Differential Revision: https://reviews.llvm.org/D101981

Added: 
    mlir/test/Conversion/VectorToSCF/unrolled-vector-to-loops.mlir

Modified: 
    mlir/include/mlir/Conversion/VectorToSCF/ProgressiveVectorToSCF.h
    mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp
    mlir/lib/Interfaces/VectorInterfaces.cpp
    mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-1d.mlir
    mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-2d.mlir
    mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-3d.mlir
    mlir/test/lib/Transforms/TestVectorTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/VectorToSCF/ProgressiveVectorToSCF.h b/mlir/include/mlir/Conversion/VectorToSCF/ProgressiveVectorToSCF.h
index 5f765b18d339..b69ec01a0d5c 100644
--- a/mlir/include/mlir/Conversion/VectorToSCF/ProgressiveVectorToSCF.h
+++ b/mlir/include/mlir/Conversion/VectorToSCF/ProgressiveVectorToSCF.h
@@ -47,12 +47,24 @@ class RewritePatternSet;
 /// When applying the pattern a second time, the existing alloca() operation
 /// is reused and only a second vector.type_cast is added.
 
+struct ProgressiveVectorTransferToSCFOptions {
+  bool unroll = false; ///< Select the fully unrolled, buffer-free patterns.
+  ProgressiveVectorTransferToSCFOptions &setUnroll(bool u) {
+    unroll = u;
+    return *this; // Return *this to allow call chaining.
+  }
+};
+
 /// Collect a set of patterns to convert from the Vector dialect to SCF + std.
 void populateProgressiveVectorToSCFConversionPatterns(
-    RewritePatternSet &patterns);
+    RewritePatternSet &patterns,
+    const ProgressiveVectorTransferToSCFOptions &options =
+        ProgressiveVectorTransferToSCFOptions());
 
 /// Create a pass to convert a subset of vector ops to SCF.
-std::unique_ptr<Pass> createProgressiveConvertVectorToSCFPass();
+std::unique_ptr<Pass> createProgressiveConvertVectorToSCFPass(
+    const ProgressiveVectorTransferToSCFOptions &options =
+        ProgressiveVectorTransferToSCFOptions());
 
 } // namespace mlir
 

diff  --git a/mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp
index 3b46df8fcc1b..7a016bd88547 100644
--- a/mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp
@@ -262,6 +262,14 @@ static ArrayAttr dropFirstElem(OpBuilder &builder, ArrayAttr attr) {
   return ArrayAttr::get(builder.getContext(), attr.getValue().drop_front());
 }
 
+/// Add the pass label to a vector transfer op if its rank is greater than the
+/// target rank, i.e., if the op still has to be unpacked further.
+template <typename OpTy>
+static void maybeApplyPassLabel(OpBuilder &builder, OpTy newXferOp) {
+  if (newXferOp.getVectorType().getRank() > kTargetRank)
+    newXferOp->setAttr(kPassLabel, builder.getUnitAttr());
+}
+
 /// Given a transfer op, find the memref from which the mask is loaded. This
 /// is similar to Strategy<TransferWriteOp>::getBuffer.
 template <typename OpTy>
@@ -352,8 +360,8 @@ struct Strategy<TransferReadOp> {
         AffineMapAttr::get(unpackedPermutationMap(xferOp, builder)),
         xferOp.padding(), Value(), inBoundsAttr).value;
 
-    if (vecType.getRank() > kTargetRank)
-        newXfer.getDefiningOp()->setAttr(kPassLabel, builder.getUnitAttr());
+    maybeApplyPassLabel(builder,
+                        dyn_cast<TransferReadOp>(newXfer.getDefiningOp()));
 
     memref_store(newXfer, buffer, storeIndices);
     return newXfer.getDefiningOp<TransferReadOp>();
@@ -424,15 +432,13 @@ struct Strategy<TransferWriteOp> {
     getXferIndices(xferOp, iv, xferIndices);
 
     auto vec = memref_load(buffer, loadIndices);
-    auto vecType = vec.value.getType().dyn_cast<VectorType>();
     auto inBoundsAttr = dropFirstElem(builder, xferOp.in_boundsAttr());
     auto newXfer = vector_transfer_write(
         Type(), vec, xferOp.source(), xferIndices,
         AffineMapAttr::get(unpackedPermutationMap(xferOp, builder)),
         Value(), inBoundsAttr);
 
-    if (vecType.getRank() > kTargetRank)
-        newXfer.op->setAttr(kPassLabel, builder.getUnitAttr());
+    maybeApplyPassLabel(builder, newXfer.op);
 
     return newXfer;
   }
@@ -663,6 +669,264 @@ struct TransferOpConversion : public OpRewritePattern<OpTy> {
   }
 };
 
+/// If the original transfer op has a mask, compute the mask of the new transfer
+/// op (for the current iteration `i`) and assign it.
+template <typename OpTy>
+static void maybeAssignMask(OpBuilder &builder, OpTy xferOp, OpTy newXferOp,
+                            int64_t i) {
+  if (!xferOp.mask())
+    return;
+
+  if (xferOp.isBroadcastDim(0)) {
+    // To-be-unpacked dimension is a broadcast, which does not have a
+    // corresponding mask dimension. Mask operand remains unchanged.
+    newXferOp.maskMutable().assign(xferOp.mask());
+    return;
+  }
+
+  if (xferOp.getMaskType().getRank() > 1) {
+    // Unpack one dimension of the mask: Use mask slice `i`.
+    OpBuilder::InsertionGuard guard(builder);
+    builder.setInsertionPoint(newXferOp); // Insert extract before newXfer.
+    llvm::SmallVector<int64_t, 1> indices({i});
+    auto newMask = vector_extract(xferOp.mask(), indices).value;
+    newXferOp.maskMutable().assign(newMask);
+    return;
+  }
+
+  // If we end up here: The mask of the old transfer op is 1D and the unpacked
+  // dim is not a broadcast, so no mask is needed on the new transfer op.
+  // `generateInBoundsCheck` will have evaluated the mask already.
+}
+
+/// Progressive lowering of vector TransferReadOp with unrolling: Unpack one
+/// dimension. This is similar to TransferOpConversion<TransferReadOp>, but no
+/// memref buffer is allocated and the SCF loop is fully unrolled.
+///
+/// E.g.:
+///
+/// ```
+/// %vec = vector.transfer_read %A[%a, %b, %c], %padding
+///     : memref<?x?x?xf32>, vector<5x4xf32>
+/// ```
+/// is rewritten to IR such as (simplified):
+/// ```
+/// %v_init = splat %padding : vector<5x4xf32>
+/// %tmp0 = vector.transfer_read %A[%a, %b, %c], %padding
+///     : memref<?x?x?xf32>, vector<4xf32>
+/// %v0 = vector.insert %tmp0, %v_init[0] : vector<4xf32> into vector<5x4xf32>
+/// %tmp1 = vector.transfer_read %A[%a, %b + 1, %c], %padding
+///     : memref<?x?x?xf32>, vector<4xf32>
+/// %v1 = vector.insert %tmp1, %v0[1] : vector<4xf32> into vector<5x4xf32>
+/// ...
+/// %tmp4 = vector.transfer_read %A[%a, %b + 4, %c], %padding
+///     : memref<?x?x?xf32>, vector<4xf32>
+/// %vec = vector.insert %tmp4, %v3[4] : vector<4xf32> into vector<5x4xf32>
+/// ```
+///
+/// Note: A pass label is attached to new TransferReadOps, so that subsequent
+/// applications of this pattern do not create an additional %v_init vector.
+struct UnrollTransferReadConversion : public OpRewritePattern<TransferReadOp> {
+  using OpRewritePattern<TransferReadOp>::OpRewritePattern;
+
+  /// Find the result vector %v_init or create a new vector if this is the
+  /// first application of the pattern.
+  Value getResultVector(TransferReadOp xferOp,
+                        PatternRewriter &rewriter) const {
+    if (xferOp->hasAttr(kPassLabel)) {
+      return getInsertOp(xferOp).dest();
+    }
+    return std_splat(xferOp.getVectorType(), xferOp.padding()).value;
+  }
+
+  /// Assuming that this is not the first application of the pattern, return
+  /// the vector.insert op in which the result of this transfer op is used.
+  vector::InsertOp getInsertOp(TransferReadOp xferOp) const {
+    Operation *xferOpUser = *xferOp->getUsers().begin(); // Expected sole user.
+    return dyn_cast<vector::InsertOp>(xferOpUser);
+  }
+
+  /// Assuming that this is not the first application of the pattern, return
+  /// the indices of the vector.insert op in which the result of this transfer
+  /// op is used.
+  void getInsertionIndices(TransferReadOp xferOp,
+                           SmallVector<int64_t, 8> &indices) const {
+    if (xferOp->hasAttr(kPassLabel)) {
+      llvm::for_each(getInsertOp(xferOp).position(), [&](Attribute attr) {
+        indices.push_back(attr.dyn_cast<IntegerAttr>().getInt());
+      });
+    }
+  }
+
+  /// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds
+  /// accesses, and broadcasts and transposes in permutation maps.
+  LogicalResult matchAndRewrite(TransferReadOp xferOp,
+                                PatternRewriter &rewriter) const override {
+    if (xferOp.getVectorType().getRank() <= kTargetRank)
+      return failure();
+
+    ScopedContext scope(rewriter, xferOp.getLoc());
+    auto vec = getResultVector(xferOp, rewriter);
+    auto vecType = vec.getType().dyn_cast<VectorType>();
+    auto xferVecType = xferOp.getVectorType();
+    auto newXferVecType = VectorType::get(xferVecType.getShape().drop_front(),
+                                          xferVecType.getElementType());
+    int64_t dimSize = xferVecType.getShape()[0];
+
+    // Generate fully unrolled loop of transfer ops.
+    for (int64_t i = 0; i < dimSize; ++i) {
+      Value iv = std_constant_index(i);
+
+      vec = generateInBoundsCheck(
+          xferOp, iv, rewriter, unpackedDim(xferOp), TypeRange(vecType),
+          /*inBoundsCase=*/
+          [&](OpBuilder &b, Location loc) {
+            ScopedContext scope(b, loc);
+
+            // Indices for the new transfer op.
+            SmallVector<Value, 8> xferIndices;
+            getXferIndices(xferOp, iv, xferIndices);
+
+            // Indices for the new vector.insert op.
+            SmallVector<int64_t, 8> insertionIndices;
+            getInsertionIndices(xferOp, insertionIndices);
+            insertionIndices.push_back(i);
+
+            auto inBoundsAttr = dropFirstElem(b, xferOp.in_boundsAttr());
+            auto newXferOpVal =
+                vector_transfer_read(
+                    newXferVecType, xferOp.source(), xferIndices,
+                    AffineMapAttr::get(unpackedPermutationMap(xferOp, b)),
+                    xferOp.padding(), Value(), inBoundsAttr)
+                    .value;
+            auto newXferOp =
+                dyn_cast<TransferReadOp>(newXferOpVal.getDefiningOp());
+
+            maybeAssignMask(b, xferOp, newXferOp, i);
+            maybeApplyPassLabel(b, newXferOp);
+
+            return vector_insert(newXferOp, vec, insertionIndices).value;
+          },
+          /*outOfBoundsCase=*/
+          [&](OpBuilder &b, Location loc) {
+            // Pass through the original (unmodified) vector.
+            return vec;
+          });
+    }
+
+    if (xferOp->hasAttr(kPassLabel)) {
+      rewriter.replaceOp(getInsertOp(xferOp), vec);
+      rewriter.eraseOp(xferOp);
+    } else {
+      rewriter.replaceOp(xferOp, vec);
+    }
+
+    return success();
+  }
+};
+
+/// Progressive lowering of vector TransferWriteOp with unrolling: Unpack one
+/// dimension. This is similar to TransferOpConversion<TransferWriteOp>, but no
+/// memref buffer is allocated and the SCF loop is fully unrolled.
+///
+/// E.g.:
+///
+/// ```
+/// vector.transfer_write %vec, %A[%a, %b, %c]
+///     : vector<5x4xf32>, memref<?x?x?xf32>
+/// ```
+/// is rewritten to IR such as (simplified):
+/// ```
+/// %v0 = vector.extract %vec[0] : vector<5x4xf32>
+/// vector.transfer_write %v0, %A[%a, %b, %c] : vector<4xf32>, memref<...>
+/// %v1 = vector.extract %vec[1] : vector<5x4xf32>
+/// vector.transfer_write %v1, %A[%a, %b + 1, %c] : vector<4xf32>, memref<...>
+/// ...
+/// %v4 = vector.extract %vec[4] : vector<5x4xf32>
+/// vector.transfer_write %v4, %A[%a, %b + 4, %c] : vector<4xf32>, memref<...>
+/// ```
+///
+/// Note: A pass label is attached to new TransferWriteOps, so that subsequent
+/// applications of this pattern can read the indices of previously generated
+/// vector.extract ops.
+struct UnrollTransferWriteConversion
+    : public OpRewritePattern<TransferWriteOp> {
+  using OpRewritePattern<TransferWriteOp>::OpRewritePattern;
+
+  /// If this is not the first application of the pattern, find the original
+  /// vector %vec that is written by this transfer op. Otherwise, return the
+  /// vector of this transfer op.
+  Value getDataVector(TransferWriteOp xferOp) const {
+    if (xferOp->hasAttr(kPassLabel))
+      return getExtractOp(xferOp).vector();
+    return xferOp.vector();
+  }
+
+  /// Assuming that this is not the first application of the pattern, find the
+  /// vector.extract op whose result is written by this transfer op.
+  vector::ExtractOp getExtractOp(TransferWriteOp xferOp) const {
+    return dyn_cast<vector::ExtractOp>(xferOp.vector().getDefiningOp());
+  }
+  /// If the op carries the pass label, append its vector.extract position.
+  void getExtractionIndices(TransferWriteOp xferOp,
+                            SmallVector<int64_t, 8> &indices) const {
+    if (xferOp->hasAttr(kPassLabel)) {
+      llvm::for_each(getExtractOp(xferOp).position(), [&](Attribute attr) {
+        indices.push_back(attr.dyn_cast<IntegerAttr>().getInt());
+      });
+    }
+  }
+
+  /// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds
+  /// accesses, and broadcasts and transposes in permutation maps.
+  LogicalResult matchAndRewrite(TransferWriteOp xferOp,
+                                PatternRewriter &rewriter) const override {
+    if (xferOp.getVectorType().getRank() <= kTargetRank)
+      return failure();
+
+    ScopedContext scope(rewriter, xferOp.getLoc());
+    auto vec = getDataVector(xferOp);
+    auto xferVecType = xferOp.getVectorType();
+    int64_t dimSize = xferVecType.getShape()[0];
+
+    // Generate fully unrolled loop of transfer ops.
+    for (int64_t i = 0; i < dimSize; ++i) {
+      Value iv = std_constant_index(i);
+
+      generateInBoundsCheck(
+          xferOp, iv, rewriter, unpackedDim(xferOp),
+          /*inBoundsCase=*/[&](OpBuilder &b, Location loc) {
+            ScopedContext scope(b, loc);
+
+            // Indices for the new transfer op.
+            SmallVector<Value, 8> xferIndices;
+            getXferIndices(xferOp, iv, xferIndices);
+
+            // Indices for the new vector.extract op.
+            SmallVector<int64_t, 8> extractionIndices;
+            getExtractionIndices(xferOp, extractionIndices);
+            extractionIndices.push_back(i);
+
+            auto extracted = vector_extract(vec, extractionIndices).value;
+            auto inBoundsAttr = dropFirstElem(b, xferOp.in_boundsAttr());
+
+            auto newXferOp =
+                vector_transfer_write(
+                    Type(), extracted, xferOp.source(), xferIndices,
+                    AffineMapAttr::get(unpackedPermutationMap(xferOp, b)),
+                    Value(), inBoundsAttr)
+                    .op;
+
+            maybeAssignMask(b, xferOp, newXferOp, i);
+            maybeApplyPassLabel(b, newXferOp);
+          });
+    }
+
+    rewriter.eraseOp(xferOp);
+    return success();
+  }
+};
+
 /// Compute the indices into the memref for the LoadOp/StoreOp generated as
 /// part of TransferOp1dConversion. Return the memref dimension on which
 /// the transfer is operating. A return value of None indicates a broadcast.
@@ -819,11 +1083,16 @@ struct TransferOp1dConversion : public OpRewritePattern<OpTy> {
 namespace mlir {
 
 void populateProgressiveVectorToSCFConversionPatterns(
-    RewritePatternSet &patterns) {
-  patterns.add<PrepareTransferReadConversion,
-               PrepareTransferWriteConversion,
-               TransferOpConversion<TransferReadOp>,
-               TransferOpConversion<TransferWriteOp>>(patterns.getContext());
+    RewritePatternSet &patterns,
+    const ProgressiveVectorTransferToSCFOptions &options) {
+  if (options.unroll) {
+    patterns.add<UnrollTransferReadConversion, UnrollTransferWriteConversion>(
+        patterns.getContext());
+  } else {
+    patterns.add<PrepareTransferReadConversion, PrepareTransferWriteConversion,
+                 TransferOpConversion<TransferReadOp>,
+                 TransferOpConversion<TransferWriteOp>>(patterns.getContext());
+  }
 
   if (kTargetRank == 1) {
     patterns.add<TransferOp1dConversion<TransferReadOp>,
@@ -834,16 +1103,22 @@ void populateProgressiveVectorToSCFConversionPatterns(
 
 struct ConvertProgressiveVectorToSCFPass
     : public ConvertVectorToSCFBase<ConvertProgressiveVectorToSCFPass> {
+  explicit ConvertProgressiveVectorToSCFPass(
+      const ProgressiveVectorTransferToSCFOptions &opt)
+      : options(opt) {}
+
   void runOnFunction() override {
     RewritePatternSet patterns(getFunction().getContext());
-    populateProgressiveVectorToSCFConversionPatterns(patterns);
+    populateProgressiveVectorToSCFConversionPatterns(patterns, options);
     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
   }
+
+  ProgressiveVectorTransferToSCFOptions options;
 };
 
 }  // namespace mlir
 
-std::unique_ptr<Pass>
-mlir::createProgressiveConvertVectorToSCFPass() {
-  return std::make_unique<ConvertProgressiveVectorToSCFPass>();
+std::unique_ptr<Pass> mlir::createProgressiveConvertVectorToSCFPass(
+    const ProgressiveVectorTransferToSCFOptions &options) {
+  return std::make_unique<ConvertProgressiveVectorToSCFPass>(options);
 }

diff  --git a/mlir/lib/Interfaces/VectorInterfaces.cpp b/mlir/lib/Interfaces/VectorInterfaces.cpp
index 36dfd4ff87a5..625ffa985239 100644
--- a/mlir/lib/Interfaces/VectorInterfaces.cpp
+++ b/mlir/lib/Interfaces/VectorInterfaces.cpp
@@ -16,7 +16,7 @@ VectorType mlir::vector::detail::transferMaskType(VectorType vecType,
   SmallVector<int64_t, 8> shape;
   for (int64_t i = 0; i < vecType.getRank(); ++i) {
     // Only result dims have a corresponding dim in the mask.
-    if (auto expr = map.getResult(i).template isa<AffineDimExpr>()) {
+    if (map.getResult(i).template isa<AffineDimExpr>()) {
       shape.push_back(vecType.getDimSize(i));
     }
   }

diff  --git a/mlir/test/Conversion/VectorToSCF/unrolled-vector-to-loops.mlir b/mlir/test/Conversion/VectorToSCF/unrolled-vector-to-loops.mlir
new file mode 100644
index 000000000000..f90d20a518a6
--- /dev/null
+++ b/mlir/test/Conversion/VectorToSCF/unrolled-vector-to-loops.mlir
@@ -0,0 +1,86 @@
+// RUN: mlir-opt %s -test-unrolled-progressive-convert-vector-to-scf -split-input-file -allow-unregistered-dialect | FileCheck %s
+
+// CHECK-LABEL: func @transfer_read_inbounds
+func @transfer_read_inbounds(%A : memref<?x?x?xf32>) -> (vector<2x3x4xf32>) {
+  %f0 = constant 0.0: f32
+  %c0 = constant 0: index
+
+  // CHECK:      vector.transfer_read {{.*}} : memref<?x?x?xf32>, vector<4xf32>
+  // CHECK-NEXT: vector.insert {{.*}} [0, 0] : vector<4xf32> into vector<2x3x4xf32>
+  // CHECK-NEXT: vector.transfer_read {{.*}} : memref<?x?x?xf32>, vector<4xf32>
+  // CHECK-NEXT: vector.insert {{.*}} [0, 1] : vector<4xf32> into vector<2x3x4xf32>
+  // CHECK-NEXT: vector.transfer_read {{.*}} : memref<?x?x?xf32>, vector<4xf32>
+  // CHECK-NEXT: vector.insert {{.*}} [0, 2] : vector<4xf32> into vector<2x3x4xf32>
+  // CHECK-NEXT: vector.transfer_read {{.*}} : memref<?x?x?xf32>, vector<4xf32>
+  // CHECK-NEXT: vector.insert {{.*}} [1, 0] : vector<4xf32> into vector<2x3x4xf32>
+  // CHECK-NEXT: vector.transfer_read {{.*}} : memref<?x?x?xf32>, vector<4xf32>
+  // CHECK-NEXT: vector.insert {{.*}} [1, 1] : vector<4xf32> into vector<2x3x4xf32>
+  // CHECK-NEXT: vector.transfer_read {{.*}} : memref<?x?x?xf32>, vector<4xf32>
+  // CHECK-NEXT: vector.insert {{.*}} [1, 2] : vector<4xf32> into vector<2x3x4xf32>
+  // CHECK-NOT: scf.if
+  // CHECK-NOT: scf.for
+  %vec = vector.transfer_read %A[%c0, %c0, %c0], %f0 {in_bounds = [true, true, true]} : memref<?x?x?xf32>, vector<2x3x4xf32>
+  return %vec : vector<2x3x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @transfer_read_out_of_bounds
+func @transfer_read_out_of_bounds(%A : memref<?x?x?xf32>) -> (vector<2x3x4xf32>) {
+  %f0 = constant 0.0: f32
+  %c0 = constant 0: index
+
+  // CHECK: scf.if
+  // CHECK: scf.if
+  // CHECK: vector.transfer_read {{.*}} : memref<?x?x?xf32>, vector<4xf32>
+  // CHECK: vector.insert {{.*}} [0, 0] : vector<4xf32> into vector<2x3x4xf32>
+  // CHECK: scf.if
+  // CHECK: vector.transfer_read {{.*}} : memref<?x?x?xf32>, vector<4xf32>
+  // CHECK: vector.insert {{.*}} [0, 1] : vector<4xf32> into vector<2x3x4xf32>
+  // CHECK: scf.if
+  // CHECK: vector.transfer_read {{.*}} : memref<?x?x?xf32>, vector<4xf32>
+  // CHECK: vector.insert {{.*}} [0, 2] : vector<4xf32> into vector<2x3x4xf32>
+  // CHECK: scf.if
+  // CHECK: scf.if
+  // CHECK: vector.transfer_read {{.*}} : memref<?x?x?xf32>, vector<4xf32>
+  // CHECK: vector.insert {{.*}} [1, 0] : vector<4xf32> into vector<2x3x4xf32>
+  // CHECK: scf.if
+  // CHECK: vector.transfer_read {{.*}} : memref<?x?x?xf32>, vector<4xf32>
+  // CHECK: vector.insert {{.*}} [1, 1] : vector<4xf32> into vector<2x3x4xf32>
+  // CHECK: scf.if
+  // CHECK: vector.transfer_read {{.*}} : memref<?x?x?xf32>, vector<4xf32>
+  // CHECK: vector.insert {{.*}} [1, 2] : vector<4xf32> into vector<2x3x4xf32>
+  // CHECK-NOT: scf.for
+  %vec = vector.transfer_read %A[%c0, %c0, %c0], %f0 : memref<?x?x?xf32>, vector<2x3x4xf32>
+  return %vec : vector<2x3x4xf32>
+}
+
+// -----
+// CHECK-LABEL: func @transfer_read_mask
+func @transfer_read_mask(%A : memref<?x?x?xf32>, %mask : vector<2x3x4xi1>) -> (vector<2x3x4xf32>) {
+  %f0 = constant 0.0: f32
+  %c0 = constant 0: index
+
+  // CHECK:      vector.extract %{{.*}}[0, 0] : vector<2x3x4xi1>
+  // CHECK-NEXT: vector.transfer_read {{.*}} : memref<?x?x?xf32>, vector<4xf32>
+  // CHECK-NEXT: vector.insert {{.*}} [0, 0] : vector<4xf32> into vector<2x3x4xf32>
+  // CHECK-NEXT: vector.extract %{{.*}}[0, 1] : vector<2x3x4xi1>
+  // CHECK-NEXT: vector.transfer_read {{.*}} : memref<?x?x?xf32>, vector<4xf32>
+  // CHECK-NEXT: vector.insert {{.*}} [0, 1] : vector<4xf32> into vector<2x3x4xf32>
+  // CHECK-NEXT: vector.extract %{{.*}}[0, 2] : vector<2x3x4xi1>
+  // CHECK-NEXT: vector.transfer_read {{.*}} : memref<?x?x?xf32>, vector<4xf32>
+  // CHECK-NEXT: vector.insert {{.*}} [0, 2] : vector<4xf32> into vector<2x3x4xf32>
+  // CHECK-NEXT: vector.extract %{{.*}}[1, 0] : vector<2x3x4xi1>
+  // CHECK-NEXT: vector.transfer_read {{.*}} : memref<?x?x?xf32>, vector<4xf32>
+  // CHECK-NEXT: vector.insert {{.*}} [1, 0] : vector<4xf32> into vector<2x3x4xf32>
+  // CHECK-NEXT: vector.extract %{{.*}}[1, 1] : vector<2x3x4xi1>
+  // CHECK-NEXT: vector.transfer_read {{.*}} : memref<?x?x?xf32>, vector<4xf32>
+  // CHECK-NEXT: vector.insert {{.*}} [1, 1] : vector<4xf32> into vector<2x3x4xf32>
+  // CHECK-NEXT: vector.extract %{{.*}}[1, 2] : vector<2x3x4xi1>
+  // CHECK-NEXT: vector.transfer_read {{.*}} : memref<?x?x?xf32>, vector<4xf32>
+  // CHECK-NEXT: vector.insert {{.*}} [1, 2] : vector<4xf32> into vector<2x3x4xf32>
+  // CHECK-NOT: scf.if
+  // CHECK-NOT: scf.for
+  %vec = vector.transfer_read %A[%c0, %c0, %c0], %f0, %mask {in_bounds = [true, true, true]}: memref<?x?x?xf32>, vector<2x3x4xf32>
+  return %vec : vector<2x3x4xf32>
+}

diff  --git a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-1d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-1d.mlir
index 4521b293ddb3..d7dc9d6f1e59 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-1d.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-1d.mlir
@@ -1,5 +1,10 @@
 // RUN: mlir-opt %s -test-progressive-convert-vector-to-scf -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
-// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void  \
+// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+// RUN: mlir-opt %s -test-unrolled-progressive-convert-vector-to-scf -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void  \
 // RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
 // RUN: FileCheck %s
 

diff  --git a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-2d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-2d.mlir
index c391e3d39a2c..1fc11fab8528 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-2d.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-2d.mlir
@@ -3,6 +3,11 @@
 // RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
 // RUN: FileCheck %s
 
+// RUN: mlir-opt %s -test-unrolled-progressive-convert-vector-to-scf -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void  \
+// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
 memref.global "private" @gv : memref<3x4xf32> = dense<[[0. , 1. , 2. , 3. ],
                                                        [10., 11., 12., 13.],
                                                        [20., 21., 22., 23.]]>

diff  --git a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-3d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-3d.mlir
index ff64dbbc8e4c..902f4f50d223 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-3d.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-3d.mlir
@@ -3,6 +3,11 @@
 // RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
 // RUN: FileCheck %s
 
+// RUN: mlir-opt %s -test-unrolled-progressive-convert-vector-to-scf -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void  \
+// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
 func @transfer_read_3d(%A : memref<?x?x?x?xf32>,
                        %o: index, %a: index, %b: index, %c: index) {
   %fm42 = constant -42.0: f32

diff  --git a/mlir/test/lib/Transforms/TestVectorTransforms.cpp b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
index 0e62d2c09be7..d1ac5e1b994f 100644
--- a/mlir/test/lib/Transforms/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
@@ -390,16 +390,20 @@ struct TestVectorMultiReductionLoweringPatterns
   }
 };
 
+template <bool Unroll>
 struct TestProgressiveVectorToSCFLoweringPatterns
-    : public PassWrapper<TestProgressiveVectorToSCFLoweringPatterns,
+    : public PassWrapper<TestProgressiveVectorToSCFLoweringPatterns<Unroll>,
                          FunctionPass> {
   void getDependentDialects(DialectRegistry &registry) const override {
     registry.insert<memref::MemRefDialect, scf::SCFDialect, AffineDialect>();
   }
   void runOnFunction() override {
-    RewritePatternSet patterns(&getContext());
-    populateProgressiveVectorToSCFConversionPatterns(patterns);
-    (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
+    RewritePatternSet patterns(&this->getContext());
+    ProgressiveVectorTransferToSCFOptions options;
+    options.unroll = Unroll;
+    populateProgressiveVectorToSCFConversionPatterns(patterns, options);
+    (void)applyPatternsAndFoldGreedily(this->getFunction(),
+                                       std::move(patterns));
   }
 };
 
@@ -450,9 +454,18 @@ void registerTestVectorConversions() {
       "test-vector-transfer-lowering-patterns",
       "Test conversion patterns to lower transfer ops to other vector ops");
 
-  PassRegistration<TestProgressiveVectorToSCFLoweringPatterns> transferOpToSCF(
-      "test-progressive-convert-vector-to-scf",
-      "Test conversion patterns to progressively lower transfer ops to SCF");
+  PassRegistration<TestProgressiveVectorToSCFLoweringPatterns<
+      /*Unroll=*/false>>
+      transferOpToSCF("test-progressive-convert-vector-to-scf",
+                      "Test conversion patterns to progressively lower "
+                      "transfer ops to SCF");
+
+  PassRegistration<TestProgressiveVectorToSCFLoweringPatterns<
+      /*Unroll=*/true>>
+      transferOpToSCFUnrolled(
+          "test-unrolled-progressive-convert-vector-to-scf",
+          "Test conversion patterns to progressively lower transfer ops to SCF"
+          " (unrolled variant)");
 
   PassRegistration<TestVectorMultiReductionLoweringPatterns>
       multiDimReductionOpLoweringPass(


        


More information about the Mlir-commits mailing list