[Mlir-commits] [mlir] 35b65be - [mlir][Vector] Add transformation + pattern to split vector.transfer_read into full and partial copies.

Nicolas Vasilache llvmlistbot at llvm.org
Mon Aug 3 01:55:43 PDT 2020


Author: Nicolas Vasilache
Date: 2020-08-03T04:53:43-04:00
New Revision: 35b65be041127db9fe23d3128a004c888893cbae

URL: https://github.com/llvm/llvm-project/commit/35b65be041127db9fe23d3128a004c888893cbae
DIFF: https://github.com/llvm/llvm-project/commit/35b65be041127db9fe23d3128a004c888893cbae.diff

LOG: [mlir][Vector] Add transformation + pattern to split vector.transfer_read into full and partial copies.

This revision adds a transformation and a pattern that rewrite a "maybe masked" `vector.transfer_read %view[...], %pad` into a pattern resembling:

```
   %1:3 = scf.if (%inBounds) {
      scf.yield %view, %i, %j : memref<A...>, index, index
   } else {
      %2 = vector.transfer_read %view[%i, %j], %pad : memref<A...>, vector<...>
      %3 = vector.type_cast %extra_alloc : memref<...> to memref<vector<...>>
      store %2, %3[] : memref<vector<...>>
      %4 = memref_cast %extra_alloc : memref<B...> to memref<A...>
      scf.yield %4, %c0, %c0 : memref<A...>, index, index
   }
   %res = vector.transfer_read %1#0[%1#1, %1#2] {masked = [false ... false]}
```
where `%extra_alloc` is a buffer holding exactly one vector, alloca'ed at the top of the function, and `%c0` is the index constant 0.

This rewrite makes it possible to realize the "always full tile" abstraction where vector.transfer_read operations are guaranteed to read from a padded full buffer.
The extra work only occurs on the boundary tiles.

Differential Revision: https://reviews.llvm.org/D84631

Added: 
    mlir/test/Dialect/Vector/vector-transfer-full-partial-split.mlir

Modified: 
    mlir/include/mlir/Dialect/Vector/VectorTransforms.h
    mlir/include/mlir/Interfaces/VectorInterfaces.td
    mlir/lib/Dialect/Vector/VectorTransforms.cpp
    mlir/test/lib/Transforms/TestVectorTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
index 0d18c5aa782d..835ad18a79ad 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
@@ -17,6 +17,11 @@
 namespace mlir {
 class MLIRContext;
 class OwningRewritePatternList;
+class VectorTransferOpInterface;
+
+namespace scf {
+class IfOp;
+} // namespace scf
 
 /// Collect a set of patterns to convert from the Vector dialect to itself.
 /// Should be merged with populateVectorToSCFLoweringPattern.
@@ -104,6 +109,65 @@ struct UnrollVectorPattern : public OpRewritePattern<OpTy> {
   FilterConstraintType filter;
 };
 
+/// Split a vector.transfer operation into an unmasked fastpath vector.transfer
+/// and a slowpath masked vector.transfer. If `ifOp` is not null and the result
+/// is `success`, the `ifOp` points to the newly created conditional upon
+/// function return. To accommodate for the fact that the original
+/// vector.transfer indexing may be arbitrary and the slow path indexes @[0...0]
+/// in the temporary buffer, the scf.if op returns a view and values of type
+/// index. At this time, only vector.transfer_read is implemented.
+///
+/// Example (a 2-D vector.transfer_read):
+/// ```
+///    %1 = vector.transfer_read %0[%i, %j], %pad : memref<A...>, vector<...>
+/// ```
+/// is transformed into:
+/// ```
+///    %1:3 = scf.if (%inBounds) {
+///       scf.yield %0, %i, %j : memref<A...>, index, index
+///    } else {
+///       %2 = vector.transfer_read %0[%i, %j], %pad : memref<A...>, vector<...>
+///       %3 = vector.type_cast %extra_alloc : memref<...> to memref<vector<...>>
+///       store %2, %3[] : memref<vector<...>>
+///       %4 = memref_cast %extra_alloc : memref<B...> to memref<A...>
+///       scf.yield %4, %c0, %c0 : memref<A...>, index, index
+///    }
+///    %5 = vector.transfer_read %1#0[%1#1, %1#2] {masked = [false ... false]}
+/// ```
+/// where `%extra_alloc` is a buffer holding exactly one vector, alloca'ed at
+/// the top of the enclosing function, and `%c0` is the index constant 0.
+///
+/// Preconditions (checked by `splitFullAndPartialTransferPrecondition`):
+///  1. `xferOp.permutation_map()` must be a minor identity map
+///  2. the rank of the `xferOp.memref()` and the rank of the `xferOp.vector()`
+///  must be equal. This will be relaxed in the future but requires
+///  rank-reducing subviews.
+///  3. at least one dimension of `xferOp` must be masked.
+///  4. `xferOp` must not be nested under an scf.if; this keeps the pattern
+///  from being re-applied to the slow-path transfer it creates.
+///
+/// `splitFullAndPartialTransfer` asserts these preconditions. It returns
+/// failure if `xferOp` is not a vector.transfer_read, or if every in-bounds
+/// check folds away statically.
+LogicalResult
+splitFullAndPartialTransferPrecondition(VectorTransferOpInterface xferOp);
+LogicalResult splitFullAndPartialTransfer(OpBuilder &b,
+                                          VectorTransferOpInterface xferOp,
+                                          scf::IfOp *ifOp = nullptr);
+
+/// Rewrite pattern driving `splitFullAndPartialTransfer`. It matches any op
+/// implementing VectorTransferOpInterface. It only fires when
+/// `splitFullAndPartialTransferPrecondition` holds and the user-supplied
+/// `filter` succeeds; the default filter accepts every op.
+struct VectorTransferFullPartialRewriter : public RewritePattern {
+  /// Per-op callback that can veto the rewrite by returning failure.
+  using FilterConstraintType =
+      std::function<LogicalResult(VectorTransferOpInterface op)>;
+
+  explicit VectorTransferFullPartialRewriter(
+      MLIRContext *context,
+      FilterConstraintType filter =
+          [](VectorTransferOpInterface) { return success(); },
+      PatternBenefit benefit = 1)
+      : RewritePattern(benefit, MatchAnyOpTypeTag()), filter(filter) {}
+
+  /// Checks the preconditions and the filter, then splits the transfer in
+  /// place.
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override;
+
+private:
+  /// User-provided selection predicate.
+  FilterConstraintType filter;
+};
+
 } // namespace vector
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Interfaces/VectorInterfaces.td b/mlir/include/mlir/Interfaces/VectorInterfaces.td
index aefbb7d47117..218715318a86 100644
--- a/mlir/include/mlir/Interfaces/VectorInterfaces.td
+++ b/mlir/include/mlir/Interfaces/VectorInterfaces.td
@@ -160,6 +160,19 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {
       /*defaultImplementation=*/
         "return $_op.getMemRefType().getRank() - $_op.getTransferRank();"
     >,
+    // Convenience query: does any transfer dimension still require masking?
+    InterfaceMethod<
+      /*desc=*/[{ Returns true if at least one of the dimensions is masked.}],
+      /*retTy=*/"bool",
+      /*methodName=*/"hasMaskedDim",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        // Scan the transfer dimensions and stop at the first masked one.
+        unsigned rank = $_op.getTransferRank();
+        for (unsigned dim = 0; dim != rank; ++dim) {
+          if ($_op.isMaskedDim(dim))
+            return true;
+        }
+        return false;
+      }]
+    >,
     InterfaceMethod<
       /*desc=*/[{
       Helper function to account for the fact that `permutationMap` results and

diff  --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 197b1c62274b..573b822503f3 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -12,9 +12,13 @@
 
 #include <type_traits>
 
+#include "mlir/Dialect/Affine/EDSC/Intrinsics.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/SCF/EDSC/Intrinsics.h"
+#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/Dialect/Vector/EDSC/Intrinsics.h"
 #include "mlir/Dialect/Vector/VectorOps.h"
 #include "mlir/Dialect/Vector/VectorTransforms.h"
 #include "mlir/Dialect/Vector/VectorUtils.h"
@@ -1985,6 +1989,236 @@ Value ContractionOpLowering::lowerReduction(vector::ContractionOp op,
 
 } // namespace mlir
 
+/// Returns the compile-time value of `v` when it is produced either by a
+/// `constant ... : index` op or by an `affine.apply` whose map is a single
+/// constant; returns None otherwise.
+static Optional<int64_t> extractConstantIndex(Value v) {
+  if (auto constantOp = v.getDefiningOp<ConstantIndexOp>())
+    return constantOp.getValue();
+  auto applyOp = v.getDefiningOp<AffineApplyOp>();
+  if (!applyOp)
+    return None;
+  auto map = applyOp.getAffineMap();
+  if (!map.isSingleConstant())
+    return None;
+  return map.getSingleConstantResult();
+}
+
+// Missing foldings of scf.if make it necessary to perform poor man's folding
+// eagerly, especially in the case of unrolling. In the future, this should go
+// away once scf.if folds properly.
+//
+// Operates under a scoped EDSC context. Builds the signed comparison
+// `v <= ub` (`sle`) unless both operands are known constants for which the
+// comparison statically holds. In that case no comparison is built, and a
+// null Value is returned to tell the caller the check is always true.
+static Value createScopedFoldedSLE(Value v, Value ub) {
+  using namespace edsc::op;
+  auto maybeCstV = extractConstantIndex(v);
+  auto maybeCstUb = extractConstantIndex(ub);
+  // `<=`, not `<`: an access ending exactly at the bound (v == ub) is also
+  // statically in bounds and must not produce a runtime check.
+  if (maybeCstV && maybeCstUb && *maybeCstV <= *maybeCstUb)
+    return Value();
+  return sle(v, ub);
+}
+
+// Builds, under the current scoped EDSC context, the i1 condition under which
+// `xferOp` stays fully in bounds. The condition is the conjunction, over every
+// masked dimension, of `index + vector_size <= dim(memref)`. Unmasked
+// dimensions and checks that fold to "always true" contribute nothing. A null
+// Value is returned when no runtime check remains.
+static Value createScopedInBoundsCond(VectorTransferOpInterface xferOp) {
+  assert(xferOp.permutation_map().isMinorIdentity() &&
+         "Expected minor identity map");
+  Value conjunction;
+  auto addDimCheck = [&](int64_t vectorDim, int64_t memrefDim) {
+    // Dimensions already known to be unmasked need no check.
+    if (!xferOp.isMaskedDim(vectorDim))
+      return;
+    using namespace edsc::op;
+    using namespace edsc::intrinsics;
+    int64_t extent = xferOp.getVectorType().getDimSize(vectorDim);
+    // Exclusive end of the accessed range along this dimension.
+    Value end = xferOp.indices()[memrefDim] + std_constant_index(extent);
+    Value bound = std_dim(xferOp.memref(), memrefDim);
+    Value check = createScopedFoldedSLE(end, bound);
+    // A null check means it folded to "always true".
+    if (!check)
+      return;
+    if (conjunction)
+      conjunction = conjunction && check;
+    else
+      conjunction = check;
+  };
+  xferOp.zipResultAndIndexing(addDimCheck);
+  return conjunction;
+}
+
+/// Succeeds when `xferOp` is a candidate for `splitFullAndPartialTransfer`:
+/// - it uses a minor identity permutation map;
+/// - its memref rank equals its transfer rank;
+/// - it has at least one masked dimension;
+/// - it is not nested under an scf.if. This last check keeps the pattern from
+///   firing again on the slow-path copy it creates.
+LogicalResult mlir::vector::splitFullAndPartialTransferPrecondition(
+    VectorTransferOpInterface xferOp) {
+  // TODO: support non-minor-identity permutation maps.
+  // TODO: support memref rank != transfer rank (needs rank-reducing subviews).
+  // TODO: refine the scf.if nesting check so more transfers qualify.
+  bool isCandidate =
+      xferOp.permutation_map().isMinorIdentity() &&
+      xferOp.getMemRefType().getRank() == xferOp.getTransferRank() &&
+      xferOp.hasMaskedDim() && !xferOp.getParentOfType<scf::IfOp>();
+  return success(isCandidate);
+}
+
+/// Computes a MemRefType to which both `aT` and `bT` can be `memref_cast`.
+/// This gives the two branches of the scf.if built by
+/// `splitFullAndPartialTransfer` a common view type to yield.
+/// - If `bT` is directly cast-compatible with `aT`, returns `aT` itself.
+/// - Otherwise, builds a strided layout that keeps every size, stride and
+///   offset on which both types agree and makes the rest dynamic.
+/// - Returns a null MemRefType when no such type can be computed: different
+///   ranks, element types or memory spaces, or a non-strided layout.
+static MemRefType getCastCompatibleMemRefType(MemRefType aT, MemRefType bT) {
+  if (MemRefCastOp::areCastCompatible(aT, bT))
+    return aT;
+  // memref_cast preserves rank, element type and memory space: if any of them
+  // differ, there is no common cast target.
+  if (aT.getRank() != bT.getRank() ||
+      aT.getElementType() != bT.getElementType() ||
+      aT.getMemorySpace() != bT.getMemorySpace())
+    return MemRefType();
+  int64_t aOffset, bOffset;
+  SmallVector<int64_t, 4> aStrides, bStrides;
+  if (failed(getStridesAndOffset(aT, aStrides, aOffset)) ||
+      failed(getStridesAndOffset(bT, bStrides, bOffset)) ||
+      aStrides.size() != bStrides.size())
+    return MemRefType();
+
+  int64_t rank = aT.getRank();
+  ArrayRef<int64_t> aShape = aT.getShape(), bShape = bT.getShape();
+  SmallVector<int64_t, 4> resShape(rank, 0), resStrides(rank, 0);
+  for (int64_t idx = 0; idx < rank; ++idx) {
+    // Keep static information both types agree on; otherwise go dynamic.
+    resShape[idx] =
+        (aShape[idx] == bShape[idx]) ? aShape[idx] : MemRefType::kDynamicSize;
+    resStrides[idx] = (aStrides[idx] == bStrides[idx])
+                          ? aStrides[idx]
+                          : MemRefType::kDynamicStrideOrOffset;
+  }
+  int64_t resOffset =
+      (aOffset == bOffset) ? aOffset : MemRefType::kDynamicStrideOrOffset;
+  // Preserve the (common) memory space so the result stays cast-compatible.
+  return MemRefType::get(
+      resShape, aT.getElementType(),
+      makeStridedLinearLayoutMap(resStrides, resOffset, aT.getContext()),
+      aT.getMemorySpace());
+}
+
+/// Split a vector.transfer operation into an unmasked fastpath vector.transfer
+/// and a slowpath masked vector.transfer. If `ifOp` is not null and the result
+/// is `success`, the `ifOp` points to the newly created conditional upon
+/// function return. To accommodate for the fact that the original
+/// vector.transfer indexing may be arbitrary and the slow path indexes @[0...0]
+/// in the temporary buffer, the scf.if op returns a view and values of type
+/// index. At this time, only vector.transfer_read is implemented.
+///
+/// Example (a 2-D vector.transfer_read):
+/// ```
+///    %1 = vector.transfer_read %0[%i, %j], %pad : memref<A...>, vector<...>
+/// ```
+/// is transformed into:
+/// ```
+///    %1:3 = scf.if (%inBounds) {
+///       scf.yield %0, %i, %j : memref<A...>, index, index
+///    } else {
+///       %2 = vector.transfer_read %0[%i, %j], %pad : memref<A...>, vector<...>
+///       %3 = vector.type_cast %extra_alloc : memref<...> to memref<vector<...>>
+///       store %2, %3[] : memref<vector<...>>
+///       %4 = memref_cast %extra_alloc : memref<B...> to memref<A...>
+///       scf.yield %4, %c0, %c0 : memref<A...>, index, index
+///    }
+///    %5 = vector.transfer_read %1#0[%1#1, %1#2] {masked = [false ... false]}
+/// ```
+/// where `%extra_alloc` is a buffer holding exactly one vector, alloca'ed at
+/// the top of the enclosing function.
+///
+/// Preconditions:
+///  1. `xferOp.permutation_map()` must be a minor identity map
+///  2. the rank of the `xferOp.memref()` and the rank of the `xferOp.vector()`
+///  must be equal. This will be relaxed in the future but requires
+///  rank-reducing subviews.
+///
+/// Returns failure, without creating the scf.if or the temporary buffer, if:
+/// - `xferOp` is not a vector.transfer_read;
+/// - it is not nested in a FuncOp;
+/// - no memref type can be computed that both the original memref and the
+///   temporary buffer can be cast to;
+/// - every in-bounds check folds away statically.
+LogicalResult mlir::vector::splitFullAndPartialTransfer(
+    OpBuilder &b, VectorTransferOpInterface xferOp, scf::IfOp *ifOp) {
+  using namespace edsc;
+  using namespace edsc::intrinsics;
+
+  assert(succeeded(splitFullAndPartialTransferPrecondition(xferOp)) &&
+         "Expected splitFullAndPartialTransferPrecondition to hold");
+  auto xferReadOp = dyn_cast<vector::TransferReadOp>(xferOp.getOperation());
+
+  // TODO: add support for write case.
+  if (!xferReadOp)
+    return failure();
+
+  // The transient buffer lives at the top of the enclosing function; bail out
+  // before creating any IR if there is no such function.
+  FuncOp funcOp = xferOp.getParentOfType<FuncOp>();
+  if (!funcOp)
+    return failure();
+
+  // Type of the transient buffer: storage for exactly one vector.
+  MemRefType allocType =
+      MemRefType::get(xferOp.getVectorType().getShape(),
+                      xferOp.getVectorType().getElementType());
+  // Both scf.if branches must yield the same view type; bail out before
+  // creating any IR if no common memref_cast target exists.
+  MemRefType compatibleMemRefType =
+      getCastCompatibleMemRefType(xferOp.getMemRefType(), allocType);
+  if (!compatibleMemRefType)
+    return failure();
+
+  OpBuilder::InsertionGuard guard(b);
+  // Insert right before `xferOp`. The in-bounds condition and both branches
+  // use `xferOp.indices()`, which may be defined after `xferOp.memref()`, so
+  // inserting right after the memref definition could break dominance.
+  b.setInsertionPoint(xferOp);
+  ScopedContext scope(b, xferOp.getLoc());
+  // NOTE(review): when this returns null, the index arithmetic it built may
+  // remain in the IR as dead ops.
+  Value inBoundsCond = createScopedInBoundsCond(
+      cast<VectorTransferOpInterface>(xferOp.getOperation()));
+  if (!inBoundsCond)
+    return failure();
+
+  // Top of the function `alloca` for transient storage.
+  Value alloc;
+  {
+    OpBuilder::InsertionGuard guard(b);
+    b.setInsertionPointToStart(&funcOp.getRegion().front());
+    alloc = std_alloca(allocType, ValueRange{}, b.getI64IntegerAttr(32));
+  }
+
+  Value memref = xferOp.memref();
+  SmallVector<bool, 4> bools(xferOp.getTransferRank(), false);
+  auto unmaskedAttr = b.getBoolArrayAttr(bools);
+
+  // Read case: full fill + partial copy -> unmasked vector.xfer_read.
+  Value zero = std_constant_index(0);
+  SmallVector<Type, 4> returnTypes(1 + xferOp.getTransferRank(),
+                                   b.getIndexType());
+  returnTypes[0] = compatibleMemRefType;
+  scf::IfOp fullPartialIfOp;
+  conditionBuilder(
+      returnTypes, inBoundsCond,
+      [&]() -> scf::ValueVector {
+        // Fast path: yield the original view (cast if needed) and indices.
+        Value res = memref;
+        if (compatibleMemRefType != xferOp.getMemRefType())
+          res = std_memref_cast(memref, compatibleMemRefType);
+        scf::ValueVector viewAndIndices{res};
+        viewAndIndices.insert(viewAndIndices.end(), xferOp.indices().begin(),
+                              xferOp.indices().end());
+        return viewAndIndices;
+      },
+      [&]() -> scf::ValueVector {
+        // Slow path: masked read into the transient buffer, then yield the
+        // buffer with all-zero indices.
+        Operation *newXfer =
+            ScopedContext::getBuilderRef().clone(*xferOp.getOperation());
+        Value vector = cast<VectorTransferOpInterface>(newXfer).vector();
+        std_store(vector, vector_type_cast(
+                              MemRefType::get({}, vector.getType()), alloc));
+
+        Value casted = std_memref_cast(alloc, compatibleMemRefType);
+        scf::ValueVector viewAndIndices{casted};
+        viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
+                              zero);
+
+        return viewAndIndices;
+      },
+      &fullPartialIfOp);
+  if (ifOp)
+    *ifOp = fullPartialIfOp;
+
+  // Unmask the existing read op, it always reads from a full buffer.
+  // Operand 0 is the memref and operands 1..rank are the indices.
+  for (unsigned i = 0, e = returnTypes.size(); i != e; ++i)
+    xferReadOp.setOperand(i, fullPartialIfOp.getResult(i));
+  xferOp.setAttr(vector::TransferReadOp::getMaskedAttrName(), unmaskedAttr);
+
+  return success();
+}
+
+/// Splits any VectorTransferOpInterface op that satisfies
+/// `splitFullAndPartialTransferPrecondition` and the user filter. `op` is
+/// updated in place, so the change is wrapped in a root update. That update
+/// is cancelled if the split itself fails.
+LogicalResult mlir::vector::VectorTransferFullPartialRewriter::matchAndRewrite(
+    Operation *op, PatternRewriter &rewriter) const {
+  auto xferOp = dyn_cast<VectorTransferOpInterface>(op);
+  bool isCandidate = xferOp &&
+                     succeeded(splitFullAndPartialTransferPrecondition(xferOp)) &&
+                     succeeded(filter(xferOp));
+  if (!isCandidate)
+    return failure();
+
+  rewriter.startRootUpdate(xferOp);
+  if (failed(splitFullAndPartialTransfer(rewriter, xferOp))) {
+    rewriter.cancelRootUpdate(xferOp);
+    return failure();
+  }
+  rewriter.finalizeRootUpdate(xferOp);
+  return success();
+}
+
 // TODO: Add pattern to rewrite ExtractSlices(ConstantMaskOp).
 // TODO: Add this as DRR pattern.
 void mlir::vector::populateVectorToVectorTransformationPatterns(

diff  --git a/mlir/test/Dialect/Vector/vector-transfer-full-partial-split.mlir b/mlir/test/Dialect/Vector/vector-transfer-full-partial-split.mlir
new file mode 100644
index 000000000000..ef76247ee9d4
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-transfer-full-partial-split.mlir
@@ -0,0 +1,102 @@
+// RUN: mlir-opt %s -test-vector-transfer-full-partial-split | FileCheck %s
+
+// CHECK-DAG: #[[$map_p4:.*]] = affine_map<()[s0] -> (s0 + 4)>
+// CHECK-DAG: #[[$map_p8:.*]] = affine_map<()[s0] -> (s0 + 8)>
+// CHECK-DAG: #[[$map_2d_stride_1:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
+
+// CHECK-LABEL: split_vector_transfer_read_2d(
+//  CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: memref
+//  CHECK-SAME: %[[i:[a-zA-Z0-9]*]]: index
+//  CHECK-SAME: %[[j:[a-zA-Z0-9]*]]: index
+func @split_vector_transfer_read_2d(%A: memref<?x8xf32>, %i: index, %j: index) -> vector<4x8xf32> {
+  %c0 = constant 0 : index
+  %f0 = constant 0.0 : f32
+
+  //  CHECK-DAG: %[[c0:.*]] = constant 0 : index
+  //  CHECK-DAG: %[[c8:.*]] = constant 8 : index
+  //  CHECK-DAG: %[[cst:.*]] = constant 0.000000e+00 : f32
+  // alloca for boundary full tile
+  //      CHECK: %[[alloc:.*]] = alloca() {alignment = 32 : i64} : memref<4x8xf32>
+  // %i + 4 <= dim(%A, 0)
+  //      CHECK: %[[idx0:.*]] = affine.apply #[[$map_p4]]()[%[[i]]]
+  //      CHECK: %[[d0:.*]] = dim %[[A]], %[[c0]] : memref<?x8xf32>
+  //      CHECK: %[[cmp0:.*]] = cmpi "sle", %[[idx0]], %[[d0]] : index
+  // %j + 8 <= dim(%A, 1)
+  //      CHECK: %[[idx1:.*]] = affine.apply #[[$map_p8]]()[%[[j]]]
+  //      CHECK: %[[cmp1:.*]] = cmpi "sle", %[[idx1]], %[[c8]] : index
+  // are both conds true
+  //      CHECK: %[[cond:.*]] = and %[[cmp0]], %[[cmp1]] : i1
+  //      CHECK: %[[ifres:.*]]:3 = scf.if %[[cond]] -> (memref<?x8xf32>, index, index) {
+  //               inBounds, just yield %A
+  //      CHECK:   scf.yield %[[A]], %[[i]], %[[j]] : memref<?x8xf32>, index, index
+  //      CHECK: } else {
+  //               slow path, fill tmp alloc and yield a memref_casted version of it
+  //      CHECK:   %[[slow:.*]] = vector.transfer_read %[[A]][%[[i]], %[[j]]], %[[cst]] :
+  // CHECK-SAME:     memref<?x8xf32>, vector<4x8xf32>
+  //      CHECK:   %[[cast_alloc:.*]] = vector.type_cast %[[alloc]] :
+  // CHECK-SAME:     memref<4x8xf32> to memref<vector<4x8xf32>>
+  //      CHECK:   store %[[slow]], %[[cast_alloc]][] : memref<vector<4x8xf32>>
+  //      CHECK:   %[[yielded:.*]] = memref_cast %[[alloc]] :
+  // CHECK-SAME:     memref<4x8xf32> to memref<?x8xf32>
+  //      CHECK:   scf.yield %[[yielded]], %[[c0]], %[[c0]] :
+  // CHECK-SAME:     memref<?x8xf32>, index, index
+  //      CHECK: }
+  //      CHECK: %[[res:.*]] = vector.transfer_read %[[ifres]]#0[%[[ifres]]#1, %[[ifres]]#2], %[[cst]]
+  // CHECK-SAME:   {masked = [false, false]} : memref<?x8xf32>, vector<4x8xf32>
+  %1 = vector.transfer_read %A[%i, %j], %f0 : memref<?x8xf32>, vector<4x8xf32>
+
+  // CHECK: return %[[res]] : vector<4x8xf32>
+  return %1: vector<4x8xf32>
+}
+
+// CHECK-LABEL: split_vector_transfer_read_strided_2d(
+//  CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: memref
+//  CHECK-SAME: %[[i:[a-zA-Z0-9]*]]: index
+//  CHECK-SAME: %[[j:[a-zA-Z0-9]*]]: index
+func @split_vector_transfer_read_strided_2d(
+    %A: memref<7x8xf32, offset:?, strides:[?, 1]>,
+    %i: index, %j: index) -> vector<4x8xf32> {
+  %c0 = constant 0 : index
+  %f0 = constant 0.0 : f32
+
+  //  CHECK-DAG: %[[c0:.*]] = constant 0 : index
+  //  CHECK-DAG: %[[c7:.*]] = constant 7 : index
+  //  CHECK-DAG: %[[c8:.*]] = constant 8 : index
+  //  CHECK-DAG: %[[cst:.*]] = constant 0.000000e+00 : f32
+  // alloca for boundary full tile
+  //      CHECK: %[[alloc:.*]] = alloca() {alignment = 32 : i64} : memref<4x8xf32>
+  // %i + 4 <= dim(%A, 0)
+  //      CHECK: %[[idx0:.*]] = affine.apply #[[$map_p4]]()[%[[i]]]
+  //      CHECK: %[[cmp0:.*]] = cmpi "sle", %[[idx0]], %[[c7]] : index
+  // %j + 8 <= dim(%A, 1)
+  //      CHECK: %[[idx1:.*]] = affine.apply #[[$map_p8]]()[%[[j]]]
+  //      CHECK: %[[cmp1:.*]] = cmpi "sle", %[[idx1]], %[[c8]] : index
+  // are both conds true
+  //      CHECK: %[[cond:.*]] = and %[[cmp0]], %[[cmp1]] : i1
+  //      CHECK: %[[ifres:.*]]:3 = scf.if %[[cond]] -> (memref<?x8xf32, #[[$map_2d_stride_1]]>, index, index) {
+  //               inBounds but not cast-compatible: yield a memref_casted form of %A
+  //      CHECK:   %[[casted:.*]] = memref_cast %[[A]] :
+  // CHECK-SAME:     memref<7x8xf32, #[[$map_2d_stride_1]]> to memref<?x8xf32, #[[$map_2d_stride_1]]>
+  //      CHECK:   scf.yield %[[casted]], %[[i]], %[[j]] :
+  // CHECK-SAME:     memref<?x8xf32, #[[$map_2d_stride_1]]>, index, index
+  //      CHECK: } else {
+  //               slow path, fill tmp alloc and yield a memref_casted version of it
+  //      CHECK:   %[[slow:.*]] = vector.transfer_read %[[A]][%[[i]], %[[j]]], %[[cst]] :
+  // CHECK-SAME:     memref<7x8xf32, #[[$map_2d_stride_1]]>, vector<4x8xf32>
+  //      CHECK:   %[[cast_alloc:.*]] = vector.type_cast %[[alloc]] :
+  // CHECK-SAME:     memref<4x8xf32> to memref<vector<4x8xf32>>
+  //      CHECK:   store %[[slow]], %[[cast_alloc]][] :
+  // CHECK-SAME:     memref<vector<4x8xf32>>
+  //      CHECK:   %[[yielded:.*]] = memref_cast %[[alloc]] :
+  // CHECK-SAME:     memref<4x8xf32> to memref<?x8xf32, #[[$map_2d_stride_1]]>
+  //      CHECK:   scf.yield %[[yielded]], %[[c0]], %[[c0]] :
+  // CHECK-SAME:     memref<?x8xf32, #[[$map_2d_stride_1]]>, index, index
+  //      CHECK: }
+  //      CHECK: %[[res:.*]] = vector.transfer_read {{.*}} {masked = [false, false]} :
+  // CHECK-SAME:   memref<?x8xf32, #[[$map_2d_stride_1]]>, vector<4x8xf32>
+  %1 = vector.transfer_read %A[%i, %j], %f0 :
+    memref<7x8xf32, offset:?, strides:[?, 1]>, vector<4x8xf32>
+
+  // CHECK: return %[[res]] : vector<4x8xf32>
+  return %1 : vector<4x8xf32>
+}

diff  --git a/mlir/test/lib/Transforms/TestVectorTransforms.cpp b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
index 2058706dcbdd..0bba74e76385 100644
--- a/mlir/test/lib/Transforms/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
@@ -122,6 +122,17 @@ struct TestVectorUnrollingPatterns
   }
 };
 
+/// Test pass: greedily applies `VectorTransferFullPartialRewriter`, with its
+/// default (accept-everything) filter, to the current function.
+struct TestVectorTransferFullPartialSplitPatterns
+    : public PassWrapper<TestVectorTransferFullPartialSplitPatterns,
+                         FunctionPass> {
+  void runOnFunction() override {
+    OwningRewritePatternList patterns;
+    patterns.insert<VectorTransferFullPartialRewriter>(&getContext());
+    applyPatternsAndFoldGreedily(getFunction(), patterns);
+  }
+};
+
 } // end anonymous namespace
 
 namespace mlir {
@@ -141,5 +152,10 @@ void registerTestVectorConversions() {
   PassRegistration<TestVectorUnrollingPatterns> contractionUnrollingPass(
       "test-vector-unrolling-patterns",
       "Test conversion patterns to unroll contract ops in the vector dialect");
+
+  PassRegistration<TestVectorTransferFullPartialSplitPatterns>
+      vectorTransformFullPartialPass("test-vector-transfer-full-partial-split",
+                                     "Test conversion patterns to split "
+                                     "transfer ops via scf.if + linalg ops");
 }
 } // namespace mlir


        


More information about the Mlir-commits mailing list