[Mlir-commits] [mlir] 88a4899 - Support VectorTransfer splitting on writes also.

Tres Popp llvmlistbot at llvm.org
Tue May 11 01:33:48 PDT 2021


Author: Tres Popp
Date: 2021-05-11T10:33:27+02:00
New Revision: 88a48999d249a5478d813596d1cfac6ba82126dc

URL: https://github.com/llvm/llvm-project/commit/88a48999d249a5478d813596d1cfac6ba82126dc
DIFF: https://github.com/llvm/llvm-project/commit/88a48999d249a5478d813596d1cfac6ba82126dc.diff

LOG: Support VectorTransfer splitting on writes also.

VectorTransfer splitting previously only split read xfer ops. This adds
the same logic for write ops. The resulting code involves 2
conditionals for write ops while read ops only need 1, but the created
ops are built upon the same patterns, so pattern matching/expectations
are consistent except with regard to the if/else ops.

Differential Revision: https://reviews.llvm.org/D102157

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/VectorTransforms.cpp
    mlir/test/Dialect/Vector/vector-transfer-full-partial-split.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 8c03cffe9418..30b3c4a32c20 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -2379,6 +2379,7 @@ static Value createScopedSubViewIntersection(VectorTransferOpInterface xferOp,
       xferOp.indices().take_front(xferOp.getLeadingShapedRank());
   SmallVector<OpFoldResult, 4> sizes;
   sizes.append(leadingIndices.begin(), leadingIndices.end());
+  auto isaWrite = isa<vector::TransferWriteOp>(xferOp);
   xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {
     using MapList = ArrayRef<ArrayRef<AffineExpr>>;
     Value dimMemRef = memref_dim(xferOp.source(), indicesIdx);
@@ -2397,7 +2398,7 @@ static Value createScopedSubViewIntersection(VectorTransferOpInterface xferOp,
   SmallVector<OpFoldResult, 4> indices = llvm::to_vector<4>(llvm::map_range(
       xferOp.indices(), [](Value idx) -> OpFoldResult { return idx; }));
   return memref_sub_view(
-      xferOp.source(), indices, sizes,
+      isaWrite ? alloc : xferOp.source(), indices, sizes,
       SmallVector<OpFoldResult>(memrefRank, OpBuilder(xferOp).getIndexAttr(1)));
 }
 
@@ -2509,14 +2510,119 @@ static scf::IfOp createScopedFullPartialVectorTransferRead(
   return fullPartialIfOp;
 }
 
+/// Given an `xferOp` for which:
+///   1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
+///   2. a memref of single vector `alloc` has been allocated.
+/// Produce IR resembling:
+/// ```
+///    %1:3 = scf.if (%inBounds) {
+///      %view = memref.cast %A: memref<A...> to compatibleMemRefType
+///      scf.yield %view, ... : compatibleMemRefType, index, index
+///    } else {
+///      %4 = memref.cast %alloc: memref<B...> to compatibleMemRefType
+///      scf.yield %4, %c0, ... : compatibleMemRefType, index, index
+///   }
+/// ```
+/// The in-bounds cast is omitted when %A already has compatibleMemRefType,
+/// and the slow path yields `alloc` together with all-zero indices.
+static ValueRange getLocationToWriteFullVec(vector::TransferWriteOp xferOp,
+                                            TypeRange returnTypes,
+                                            Value inBoundsCond,
+                                            MemRefType compatibleMemRefType,
+                                            Value alloc) {
+  using namespace edsc;
+  using namespace edsc::intrinsics;
+  Value zero = std_constant_index(0);
+  Value memref = xferOp.source();
+  return conditionBuilder(
+      returnTypes, inBoundsCond,
+      [&]() -> scf::ValueVector {
+        Value res = memref;
+        if (compatibleMemRefType != xferOp.getShapedType())
+          res = memref_cast(memref, compatibleMemRefType);
+        scf::ValueVector viewAndIndices{res};
+        viewAndIndices.insert(viewAndIndices.end(), xferOp.indices().begin(),
+                              xferOp.indices().end());
+        return viewAndIndices;
+      },
+      [&]() -> scf::ValueVector {
+        Value casted = memref_cast(alloc, compatibleMemRefType);
+        scf::ValueVector viewAndIndices{casted};
+        viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
+                              zero);
+        return viewAndIndices;
+      });
+}
+
+/// Given an `xferOp` for which:
+///   1. `inBoundsCond` has been computed.
+///   2. a memref of single vector `alloc` has been allocated.
+///   3. it originally wrote to %view
+/// Produce IR resembling:
+/// ```
+///    %notInBounds = xor %inBounds, %true
+///    scf.if (%notInBounds) {
+///      %3 = subview %alloc [...][...][...]
+///      linalg.copy(%3, %view)
+///   }
+/// ```
+static void createScopedFullPartialLinalgCopy(vector::TransferWriteOp xferOp,
+                                              Value inBoundsCond, Value alloc) {
+  using namespace edsc;
+  using namespace edsc::intrinsics;
+  auto &b = ScopedContext::getBuilderRef();
+  auto notInBounds = b.create<XOrOp>(
+      xferOp->getLoc(), inBoundsCond,
+      b.create<::mlir::ConstantIntOp>(xferOp.getLoc(), true, 1));
+
+  conditionBuilder(notInBounds, [&]() {
+    Value memRefSubView = createScopedSubViewIntersection(
+        cast<VectorTransferOpInterface>(xferOp.getOperation()), alloc);
+    linalg_copy(memRefSubView, xferOp.source());
+  });
+}
+
+/// Given an `xferOp` that originally wrote to %view, for which `inBoundsCond`
+/// has been computed and a memref of single vector `alloc` has been allocated,
+/// produce IR resembling:
+/// ```
+///    %notInBounds = xor %inBounds, %true
+///    scf.if (%notInBounds) {
+///      %2 = vector.type_cast %alloc : memref<...> to memref<vector<...>>
+///      %3 = memref.load %2[] : memref<vector<...>>
+///      vector.transfer_write %3, %view[...] : vector<...>, memref<A...>
+///   }
+/// ```
+/// The write is a clone of `xferOp` with only its vector operand replaced.
+static void
+createScopedFullPartialVectorTransferWrite(vector::TransferWriteOp xferOp,
+                                           Value inBoundsCond, Value alloc) {
+  using namespace edsc;
+  using namespace edsc::intrinsics;
+  auto &b = ScopedContext::getBuilderRef();
+  auto notInBounds = b.create<XOrOp>(
+      xferOp->getLoc(), inBoundsCond,
+      b.create<::mlir::ConstantIntOp>(xferOp.getLoc(), true, 1));
+  conditionBuilder(notInBounds, [&]() {
+    BlockAndValueMapping mapping;
+
+    Value load = memref_load(vector_type_cast(
+        MemRefType::get({}, xferOp.vector().getType()), alloc));
+
+    mapping.map(xferOp.vector(), load);
+    b.clone(*xferOp.getOperation(), mapping);
+  });
+}
+
 /// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds
 /// masking) fastpath and a slowpath.
+///
+/// For vector.transfer_read:
 /// If `ifOp` is not null and the result is `success, the `ifOp` points to the
 /// newly created conditional upon function return.
 /// To accomodate for the fact that the original vector.transfer indexing may be
 /// arbitrary and the slow path indexes @[0...0] in the temporary buffer, the
 /// scf.if op returns a view and values of type index.
-/// At this time, only vector.transfer_read case is implemented.
 ///
 /// Example (a 2-D vector.transfer_read):
 /// ```
@@ -2537,6 +2643,32 @@ static scf::IfOp createScopedFullPartialVectorTransferRead(
 /// ```
 /// where `alloc` is a top of the function alloca'ed buffer of one vector.
 ///
+/// For vector.transfer_write:
+/// There are 2 conditional blocks. The first decides which memref and indices
+/// to use for an unmasked, in-bounds write. The second, taken only in the slow
+/// path, copies the partial buffer into the final destination.
+///
+/// Example (a 2-D vector.transfer_write):
+/// ```
+///    vector.transfer_write %arg, %A[...] : vector<...>, memref<A...>
+/// ```
+/// is transformed into:
+/// ```
+///    %1:3 = scf.if (%inBounds) {
+///      %view = memref.cast %A: memref<A...> to compatibleMemRefType
+///      scf.yield %view, ... : compatibleMemRefType, index, index
+///    } else {
+///      %4 = memref.cast %alloc: memref<B...> to compatibleMemRefType
+///      scf.yield %4, %c0, %c0 : compatibleMemRefType, index, index
+///    }
+///    vector.transfer_write %arg, %1#0[%1#1, %1#2] {in_bounds = [true ...
+///                                                               true]}
+///    scf.if (%notInBounds) {
+///      // slowpath: not in-bounds vector.transfer or linalg.copy.
+///    }
+/// ```
+/// where `alloc` is a one-vector buffer alloca'ed at the top of the function.
+///
 /// Preconditions:
 ///  1. `xferOp.permutation_map()` must be a minor identity map
 ///  2. the rank of the `xferOp.source()` and the rank of the `xferOp.vector()`
@@ -2554,27 +2686,29 @@ LogicalResult mlir::vector::splitFullAndPartialTransfer(
   SmallVector<bool, 4> bools(xferOp.getTransferRank(), true);
   auto inBoundsAttr = b.getBoolArrayAttr(bools);
   if (options.vectorTransferSplit == VectorTransferSplit::ForceInBounds) {
-    xferOp->setAttr(vector::TransferReadOp::getInBoundsAttrName(),
-                    inBoundsAttr);
+    xferOp->setAttr(xferOp.getInBoundsAttrName(), inBoundsAttr);
     return success();
   }
 
-  assert(succeeded(splitFullAndPartialTransferPrecondition(xferOp)) &&
-         "Expected splitFullAndPartialTransferPrecondition to hold");
-  auto xferReadOp = dyn_cast<vector::TransferReadOp>(xferOp.getOperation());
+  // Assert preconditions. Additionally, keep the variables in an inner scope to
+  // ensure they aren't used in the wrong scopes further down.
+  {
+    assert(succeeded(splitFullAndPartialTransferPrecondition(xferOp)) &&
+           "Expected splitFullAndPartialTransferPrecondition to hold");
 
-  // TODO: add support for write case.
-  if (!xferReadOp)
-    return failure();
+    auto xferReadOp = dyn_cast<vector::TransferReadOp>(xferOp.getOperation());
+    auto xferWriteOp = dyn_cast<vector::TransferWriteOp>(xferOp.getOperation());
 
-  if (xferReadOp.mask())
-    return failure();
+    if (!(xferReadOp || xferWriteOp))
+      return failure();
+    if (xferWriteOp && xferWriteOp.mask())
+      return failure();
+    if (xferReadOp && xferReadOp.mask())
+      return failure();
+  }
 
   OpBuilder::InsertionGuard guard(b);
-  if (Operation *sourceOp = xferOp.source().getDefiningOp())
-    b.setInsertionPointAfter(sourceOp);
-  else
-    b.setInsertionPoint(xferOp);
+  b.setInsertionPoint(xferOp);
   ScopedContext scope(b, xferOp.getLoc());
   Value inBoundsCond = createScopedInBoundsCond(
       cast<VectorTransferOpInterface>(xferOp.getOperation()));
@@ -2596,26 +2730,57 @@ LogicalResult mlir::vector::splitFullAndPartialTransfer(
   MemRefType compatibleMemRefType =
       getCastCompatibleMemRefType(xferOp.getShapedType().cast<MemRefType>(),
                                   alloc.getType().cast<MemRefType>());
-
-  // Read case: full fill + partial copy -> in-bounds vector.xfer_read.
   SmallVector<Type, 4> returnTypes(1 + xferOp.getTransferRank(),
                                    b.getIndexType());
   returnTypes[0] = compatibleMemRefType;
-  scf::IfOp fullPartialIfOp =
-      options.vectorTransferSplit == VectorTransferSplit::VectorTransfer
-          ? createScopedFullPartialVectorTransferRead(
-                xferReadOp, returnTypes, inBoundsCond, compatibleMemRefType,
-                alloc)
-          : createScopedFullPartialLinalgCopy(xferReadOp, returnTypes,
-                                              inBoundsCond,
-                                              compatibleMemRefType, alloc);
-  if (ifOp)
-    *ifOp = fullPartialIfOp;
-
-  // Set existing read op to in-bounds, it always reads from a full buffer.
-  for (unsigned i = 0, e = returnTypes.size(); i != e; ++i)
-    xferReadOp.setOperand(i, fullPartialIfOp.getResult(i));
-  xferOp->setAttr(vector::TransferReadOp::getInBoundsAttrName(), inBoundsAttr);
+
+  if (auto xferReadOp =
+          dyn_cast<vector::TransferReadOp>(xferOp.getOperation())) {
+    // Read case: full fill + partial copy -> in-bounds vector.xfer_read.
+    scf::IfOp fullPartialIfOp =
+        options.vectorTransferSplit == VectorTransferSplit::VectorTransfer
+            ? createScopedFullPartialVectorTransferRead(
+                  xferReadOp, returnTypes, inBoundsCond, compatibleMemRefType,
+                  alloc)
+            : createScopedFullPartialLinalgCopy(xferReadOp, returnTypes,
+                                                inBoundsCond,
+                                                compatibleMemRefType, alloc);
+    if (ifOp)
+      *ifOp = fullPartialIfOp;
+
+    // Set existing read op to in-bounds, it always reads from a full buffer.
+    for (unsigned i = 0, e = returnTypes.size(); i != e; ++i)
+      xferReadOp.setOperand(i, fullPartialIfOp.getResult(i));
+
+    xferOp->setAttr(xferOp.getInBoundsAttrName(), inBoundsAttr);
+
+    return success();
+  }
+
+  auto xferWriteOp = cast<vector::TransferWriteOp>(xferOp.getOperation());
+
+  // Decide which location to write the entire vector to.
+  auto memrefAndIndices = getLocationToWriteFullVec(
+      xferWriteOp, returnTypes, inBoundsCond, compatibleMemRefType, alloc);
+
+  // Do an in bounds write to either the output or the extra allocated buffer.
+  // The operation is cloned to prevent deleting information needed for the
+  // later IR creation.
+  BlockAndValueMapping mapping;
+  mapping.map(xferWriteOp.source(), memrefAndIndices.front());
+  mapping.map(xferWriteOp.indices(), memrefAndIndices.drop_front());
+  auto *clone = b.clone(*xferWriteOp, mapping);
+  clone->setAttr(xferWriteOp.getInBoundsAttrName(), inBoundsAttr);
+
+  // Create a potential copy from the allocated buffer to the final output in
+  // the slow path case.
+  if (options.vectorTransferSplit == VectorTransferSplit::VectorTransfer)
+    createScopedFullPartialVectorTransferWrite(xferWriteOp, inBoundsCond,
+                                               alloc);
+  else
+    createScopedFullPartialLinalgCopy(xferWriteOp, inBoundsCond, alloc);
+
+  xferOp->erase();
 
   return success();
 }

diff  --git a/mlir/test/Dialect/Vector/vector-transfer-full-partial-split.mlir b/mlir/test/Dialect/Vector/vector-transfer-full-partial-split.mlir
index 0c83d3e57909..a311e43e1567 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-full-partial-split.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-full-partial-split.mlir
@@ -1,5 +1,5 @@
-// RUN: mlir-opt %s -test-vector-transfer-full-partial-split | FileCheck %s
-// RUN: mlir-opt %s -test-vector-transfer-full-partial-split=use-linalg-copy | FileCheck %s --check-prefix=LINALG
+// RUN: mlir-opt %s -test-vector-transfer-full-partial-split -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-vector-transfer-full-partial-split=use-linalg-copy -split-input-file | FileCheck %s --check-prefix=LINALG
 
 // CHECK-DAG: #[[$map_p4:.*]] = affine_map<()[s0] -> (s0 + 4)>
 // CHECK-DAG: #[[$map_p8:.*]] = affine_map<()[s0] -> (s0 + 8)>
@@ -186,3 +186,206 @@ func @split_vector_transfer_read_strided_2d(
   // CHECK: return %[[res]] : vector<4x8xf32>
   return %1 : vector<4x8xf32>
 }
+
+// -----
+
+func @split_vector_transfer_write_2d(%V: vector<4x8xf32>, %A: memref<?x8xf32>, %i: index, %j: index) {
+  vector.transfer_write %V, %A[%i, %j] :
+    vector<4x8xf32>, memref<?x8xf32>
+  return
+}
+
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 + 4)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<()[s0] -> (s0 + 8)>
+// CHECK:     func @split_vector_transfer_write_2d(
+// CHECK-SAME:                                         %[[VEC:.*]]: vector<4x8xf32>,
+// CHECK-SAME:                                         %[[DEST:.*]]: memref<?x8xf32>,
+// CHECK-SAME:                                         %[[I:.*]]: index,
+// CHECK-SAME:                                         %[[J:.*]]: index) {
+// CHECK-DAG:       %[[C8:.*]] = constant 8 : index
+// CHECK-DAG:       %[[C0:.*]] = constant 0 : index
+// CHECK-DAG:       %[[CT:.*]] = constant true
+// CHECK:           %[[TEMP:.*]] = memref.alloca() {alignment = 32 : i64} : memref<4x8xf32>
+// CHECK:           %[[VAL_8:.*]] = affine.apply #[[MAP0]]()[%[[I]]]
+// CHECK:           %[[DIM0:.*]] = memref.dim %[[DEST]], %[[C0]] : memref<?x8xf32>
+// CHECK:           %[[DIM0_IN:.*]] = cmpi sle, %[[VAL_8]], %[[DIM0]] : index
+// CHECK:           %[[DIM1:.*]] = affine.apply #[[MAP1]]()[%[[J]]]
+// CHECK:           %[[DIM1_IN:.*]] = cmpi sle, %[[DIM1]], %[[C8]] : index
+// CHECK:           %[[IN_BOUNDS:.*]] = and %[[DIM0_IN]], %[[DIM1_IN]] : i1
+// CHECK:           %[[IN_BOUND_DEST:.*]]:3 = scf.if %[[IN_BOUNDS]] ->
+// CHECK-SAME:          (memref<?x8xf32>, index, index) {
+// CHECK:             scf.yield %[[DEST]], %[[I]], %[[J]] : memref<?x8xf32>, index, index
+// CHECK:           } else {
+// CHECK:             %[[VAL_15:.*]] = memref.cast %[[TEMP]]
+// CHECK-SAME:            : memref<4x8xf32> to memref<?x8xf32>
+// CHECK:             scf.yield %[[VAL_15]], %[[C0]], %[[C0]]
+// CHECK-SAME:            : memref<?x8xf32>, index, index
+// CHECK:           }
+// CHECK:           vector.transfer_write %[[VEC]],
+// CHECK-SAME:           %[[IN_BOUND_DEST:.*]]#0[%[[IN_BOUND_DEST]]#1, %[[IN_BOUND_DEST]]#2]
+// CHECK-SAME:           {in_bounds = [true, true]} : vector<4x8xf32>, memref<?x8xf32>
+// CHECK:           %[[OUT_BOUNDS:.*]] = xor %[[IN_BOUNDS]], %[[CT]] : i1
+// CHECK:           scf.if %[[OUT_BOUNDS]] {
+// CHECK:             %[[CASTED:.*]] = vector.type_cast %[[TEMP]]
+// CHECK-SAME:            : memref<4x8xf32> to memref<vector<4x8xf32>>
+// CHECK:             %[[RESULT_COPY:.*]] = memref.load %[[CASTED]][]
+// CHECK-SAME:            : memref<vector<4x8xf32>>
+// CHECK:             vector.transfer_write %[[RESULT_COPY]],
+// CHECK-SAME:            %[[DEST]][%[[I]], %[[J]]]
+// CHECK-SAME:            : vector<4x8xf32>, memref<?x8xf32>
+// CHECK:           }
+// CHECK:           return
+// CHECK:         }
+
+// LINALG-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 + 4)>
+// LINALG-DAG: #[[MAP1:.*]] = affine_map<()[s0] -> (s0 + 8)>
+// LINALG-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2) -> (d0 - d1, 4)>
+// LINALG-DAG: #[[MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0 - d1, 8)>
+// LINALG-DAG: #[[MAP4:.*]] = affine_map<(d0, d1)[s0] -> (d0 * 8 + s0 + d1)>
+// LINALG:     func @split_vector_transfer_write_2d(
+// LINALG-SAME:                                         %[[VEC:.*]]: vector<4x8xf32>,
+// LINALG-SAME:                                         %[[DEST:.*]]: memref<?x8xf32>,
+// LINALG-SAME:                                         %[[I:.*]]: index,
+// LINALG-SAME:                                         %[[J:.*]]: index) {
+// LINALG-DAG:       %[[CT:.*]] = constant true
+// LINALG-DAG:       %[[C0:.*]] = constant 0 : index
+// LINALG-DAG:       %[[C4:.*]] = constant 4 : index
+// LINALG-DAG:       %[[C8:.*]] = constant 8 : index
+// LINALG:           %[[TEMP:.*]] = memref.alloca() {alignment = 32 : i64} : memref<4x8xf32>
+// LINALG:           %[[IDX0:.*]] = affine.apply #[[MAP0]]()[%[[I]]]
+// LINALG:           %[[DIM0:.*]] = memref.dim %[[DEST]], %[[C0]] : memref<?x8xf32>
+// LINALG:           %[[DIM0_IN:.*]] = cmpi sle, %[[IDX0]], %[[DIM0]] : index
+// LINALG:           %[[DIM1:.*]] = affine.apply #[[MAP1]]()[%[[J]]]
+// LINALG:           %[[DIM1_IN:.*]] = cmpi sle, %[[DIM1]], %[[C8]] : index
+// LINALG:           %[[IN_BOUNDS:.*]] = and %[[DIM0_IN]], %[[DIM1_IN]] : i1
+// LINALG:           %[[IN_BOUND_DEST:.*]]:3 = scf.if %[[IN_BOUNDS]]
+// LINALG-SAME:          -> (memref<?x8xf32>, index, index) {
+// LINALG:             scf.yield %[[DEST]], %[[I]], %[[J]] : memref<?x8xf32>, index, index
+// LINALG:           } else {
+// LINALG:             %[[VAL_16:.*]] = memref.cast %[[TEMP]] : memref<4x8xf32> to memref<?x8xf32>
+// LINALG:             scf.yield %[[VAL_16]], %[[C0]], %[[C0]] : memref<?x8xf32>, index, index
+// LINALG:           }
+// LINALG:           vector.transfer_write %[[VEC]],
+// LINALG-SAME:          %[[IN_BOUND_DEST:.*]]#0[%[[IN_BOUND_DEST]]#1, %[[IN_BOUND_DEST]]#2]
+// LINALG-SAME:          {in_bounds = [true, true]} : vector<4x8xf32>, memref<?x8xf32>
+// LINALG:           %[[OUT_BOUNDS:.*]] = xor %[[IN_BOUNDS]], %[[CT]] : i1
+// LINALG:           scf.if %[[OUT_BOUNDS]] {
+// LINALG:             %[[VAL_19:.*]] = memref.dim %[[DEST]], %[[C0]] : memref<?x8xf32>
+// LINALG-DAG:         %[[VAL_20:.*]] = affine.min #[[MAP2]](%[[VAL_19]], %[[I]], %[[C4]])
+// LINALG-DAG:         %[[VAL_21:.*]] = affine.min #[[MAP3]](%[[C8]], %[[J]], %[[C8]])
+// LINALG:             %[[VAL_22:.*]] = memref.subview %[[TEMP]]
+// LINALG-SAME:            [%[[I]], %[[J]]] [%[[VAL_20]], %[[VAL_21]]]
+// LINALG-SAME:            [1, 1] : memref<4x8xf32> to memref<?x?xf32, #[[MAP4]]>
+// LINALG:             linalg.copy(%[[VAL_22]], %[[DEST]])
+// LINALG-SAME:            : memref<?x?xf32, #[[MAP4]]>, memref<?x8xf32>
+// LINALG:           }
+// LINALG:           return
+// LINALG:         }
+
+// -----
+
+func @split_vector_transfer_write_strided_2d(
+    %V: vector<4x8xf32>, %A: memref<7x8xf32, offset:?, strides:[?, 1]>,
+    %i: index, %j: index) {
+  vector.transfer_write %V, %A[%i, %j] :
+    vector<4x8xf32>, memref<7x8xf32, offset:?, strides:[?, 1]>
+  return
+}
+
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<()[s0] -> (s0 + 4)>
+// CHECK-DAG: #[[MAP2:.*]] = affine_map<()[s0] -> (s0 + 8)>
+// CHECK:   func @split_vector_transfer_write_strided_2d(
+// CHECK-SAME:                                                 %[[VEC:.*]]: vector<4x8xf32>,
+// CHECK-SAME:                                                 %[[DEST:.*]]: memref<7x8xf32, #[[MAP0]]>,
+// CHECK-SAME:                                                 %[[I:.*]]: index,
+// CHECK-SAME:                                                 %[[J:.*]]: index) {
+// CHECK-DAG:       %[[C7:.*]] = constant 7 : index
+// CHECK-DAG:       %[[C8:.*]] = constant 8 : index
+// CHECK-DAG:       %[[C0:.*]] = constant 0 : index
+// CHECK-DAG:       %[[CT:.*]] = constant true
+// CHECK:           %[[TEMP:.*]] = memref.alloca() {alignment = 32 : i64} : memref<4x8xf32>
+// CHECK:           %[[DIM0:.*]] = affine.apply #[[MAP1]]()[%[[I]]]
+// CHECK:           %[[DIM0_IN:.*]] = cmpi sle, %[[DIM0]], %[[C7]] : index
+// CHECK:           %[[DIM1:.*]] = affine.apply #[[MAP2]]()[%[[J]]]
+// CHECK:           %[[DIM1_IN:.*]] = cmpi sle, %[[DIM1]], %[[C8]] : index
+// CHECK:           %[[IN_BOUNDS:.*]] = and %[[DIM0_IN]], %[[DIM1_IN]] : i1
+// CHECK:           %[[IN_BOUND_DEST:.*]]:3 = scf.if %[[IN_BOUNDS]]
+// CHECK-SAME:          -> (memref<?x8xf32, #[[MAP0]]>, index, index) {
+// CHECK:             %[[VAL_15:.*]] = memref.cast %[[DEST]]
+// CHECK-SAME:            : memref<7x8xf32, #[[MAP0]]> to memref<?x8xf32, #[[MAP0]]>
+// CHECK:             scf.yield %[[VAL_15]], %[[I]], %[[J]]
+// CHECK-SAME:            : memref<?x8xf32, #[[MAP0]]>, index, index
+// CHECK:           } else {
+// CHECK:             %[[VAL_16:.*]] = memref.cast %[[TEMP]]
+// CHECK-SAME:            : memref<4x8xf32> to memref<?x8xf32, #[[MAP0]]>
+// CHECK:             scf.yield %[[VAL_16]], %[[C0]], %[[C0]]
+// CHECK-SAME:            : memref<?x8xf32, #[[MAP0]]>, index, index
+// CHECK:           }
+// CHECK:           vector.transfer_write %[[VEC]],
+// CHECK-SAME:          %[[IN_BOUND_DEST:.*]]#0
+// CHECK-SAME:          [%[[IN_BOUND_DEST]]#1, %[[IN_BOUND_DEST]]#2]
+// CHECK-SAME:          {in_bounds = [true, true]} : vector<4x8xf32>, memref<?x8xf32, #[[MAP0]]>
+// CHECK:           %[[OUT_BOUNDS:.*]] = xor %[[IN_BOUNDS]], %[[CT]] : i1
+// CHECK:           scf.if %[[OUT_BOUNDS]] {
+// CHECK:             %[[VAL_19:.*]] = vector.type_cast %[[TEMP]]
+// CHECK-SAME:            : memref<4x8xf32> to memref<vector<4x8xf32>>
+// CHECK:             %[[VAL_20:.*]] = memref.load %[[VAL_19]][]
+// CHECK-SAME:            : memref<vector<4x8xf32>>
+// CHECK:             vector.transfer_write %[[VAL_20]], %[[DEST]][%[[I]], %[[J]]]
+// CHECK-SAME:            : vector<4x8xf32>, memref<7x8xf32, #[[MAP0]]>
+// CHECK:           }
+// CHECK:           return
+// CHECK:         }
+
+// LINALG-DAG: #[[MAP0:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
+// LINALG-DAG: #[[MAP1:.*]] = affine_map<()[s0] -> (s0 + 4)>
+// LINALG-DAG: #[[MAP2:.*]] = affine_map<()[s0] -> (s0 + 8)>
+// LINALG-DAG: #[[MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0 - d1, 4)>
+// LINALG-DAG: #[[MAP4:.*]] = affine_map<(d0, d1, d2) -> (d0 - d1, 8)>
+// LINALG-DAG: #[[MAP5:.*]] = affine_map<(d0, d1)[s0] -> (d0 * 8 + s0 + d1)>
+// LINALG:   func @split_vector_transfer_write_strided_2d(
+// LINALG-SAME:                                                 %[[VEC:.*]]: vector<4x8xf32>,
+// LINALG-SAME:                                                 %[[DEST:.*]]: memref<7x8xf32, #[[MAP0]]>,
+// LINALG-SAME:                                                 %[[I:.*]]: index,
+// LINALG-SAME:                                                 %[[J:.*]]: index) {
+// LINALG-DAG:       %[[C0:.*]] = constant 0 : index
+// LINALG-DAG:       %[[CT:.*]] = constant true
+// LINALG-DAG:       %[[C7:.*]] = constant 7 : index
+// LINALG-DAG:       %[[C4:.*]] = constant 4 : index
+// LINALG-DAG:       %[[C8:.*]] = constant 8 : index
+// LINALG:           %[[TEMP:.*]] = memref.alloca() {alignment = 32 : i64} : memref<4x8xf32>
+// LINALG:           %[[DIM0:.*]] = affine.apply #[[MAP1]]()[%[[I]]]
+// LINALG:           %[[DIM0_IN:.*]] = cmpi sle, %[[DIM0]], %[[C7]] : index
+// LINALG:           %[[DIM1:.*]] = affine.apply #[[MAP2]]()[%[[J]]]
+// LINALG:           %[[DIM1_IN:.*]] = cmpi sle, %[[DIM1]], %[[C8]] : index
+// LINALG:           %[[IN_BOUNDS:.*]] = and %[[DIM0_IN]], %[[DIM1_IN]] : i1
+// LINALG:           %[[IN_BOUND_DEST:.*]]:3 = scf.if %[[IN_BOUNDS]]
+// LINALG-SAME:          -> (memref<?x8xf32, #[[MAP0]]>, index, index) {
+// LINALG:             %[[VAL_16:.*]] = memref.cast %[[DEST]]
+// LINALG-SAME:            : memref<7x8xf32, #[[MAP0]]> to memref<?x8xf32, #[[MAP0]]>
+// LINALG:             scf.yield %[[VAL_16]], %[[I]], %[[J]]
+// LINALG-SAME:            : memref<?x8xf32, #[[MAP0]]>, index, index
+// LINALG:           } else {
+// LINALG:             %[[VAL_17:.*]] = memref.cast %[[TEMP]]
+// LINALG-SAME:            : memref<4x8xf32> to memref<?x8xf32, #[[MAP0]]>
+// LINALG:             scf.yield %[[VAL_17]], %[[C0]], %[[C0]]
+// LINALG-SAME:            : memref<?x8xf32, #[[MAP0]]>, index, index
+// LINALG:           }
+// LINALG:           vector.transfer_write %[[VEC]],
+// LINALG-SAME:          %[[IN_BOUND_DEST:.*]]#0
+// LINALG-SAME:          [%[[IN_BOUND_DEST]]#1, %[[IN_BOUND_DEST]]#2]
+// LINALG-SAME:          {in_bounds = [true, true]}
+// LINALG-SAME:          : vector<4x8xf32>, memref<?x8xf32, #[[MAP0]]>
+// LINALG:           %[[OUT_BOUNDS:.*]] = xor %[[IN_BOUNDS]], %[[CT]] : i1
+// LINALG:           scf.if %[[OUT_BOUNDS]] {
+// LINALG-DAG:         %[[VAL_20:.*]] = affine.min #[[MAP3]](%[[C7]], %[[I]], %[[C4]])
+// LINALG-DAG:         %[[VAL_21:.*]] = affine.min #[[MAP4]](%[[C8]], %[[J]], %[[C8]])
+// LINALG:             %[[VAL_22:.*]] = memref.subview %[[TEMP]]
+// LINALG-SAME:            [%[[I]], %[[J]]] [%[[VAL_20]], %[[VAL_21]]]
+// LINALG-SAME:            [1, 1] : memref<4x8xf32> to memref<?x?xf32, #[[MAP5]]>
+// LINALG:             linalg.copy(%[[VAL_22]], %[[DEST]])
+// LINALG-SAME:            : memref<?x?xf32, #[[MAP5]]>, memref<7x8xf32, #[[MAP0]]>
+// LINALG:           }
+// LINALG:           return
+// LINALG:         }


        


More information about the Mlir-commits mailing list