[Mlir-commits] [mlir] ffdf0a3 - [mlir][vector] Fix bug in vector-transfer-full-partial-split

Matthias Springer llvmlistbot at llvm.org
Mon Sep 27 02:19:17 PDT 2021


Author: Matthias Springer
Date: 2021-09-27T18:12:17+09:00
New Revision: ffdf0a370db0c92cb6e8b0b0dbed3b18a0bfb897

URL: https://github.com/llvm/llvm-project/commit/ffdf0a370db0c92cb6e8b0b0dbed3b18a0bfb897
DIFF: https://github.com/llvm/llvm-project/commit/ffdf0a370db0c92cb6e8b0b0dbed3b18a0bfb897.diff

LOG: [mlir][vector] Fix bug in vector-transfer-full-partial-split

When splitting with linalg.copy, the copy cannot write into the destination alloc directly. Instead, it must write into a subview of the alloc.

Differential Revision: https://reviews.llvm.org/D110512

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/VectorTransforms.cpp
    mlir/test/Dialect/Vector/vector-transfer-full-partial-split.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index f3ad31c042f6..7d6e6d2ba53d 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -1835,9 +1835,8 @@ static MemRefType getCastCompatibleMemRefType(MemRefType aT, MemRefType bT) {
 /// Operates under a scoped context to build the intersection between the
 /// view `xferOp.source()` @ `xferOp.indices()` and the view `alloc`.
 // TODO: view intersection/union/differences should be a proper std op.
-static Value createSubViewIntersection(OpBuilder &b,
-                                       VectorTransferOpInterface xferOp,
-                                       Value alloc) {
+static std::pair<Value, Value> createSubViewIntersection(
+    OpBuilder &b, VectorTransferOpInterface xferOp, Value alloc) {
   ImplicitLocOpBuilder lb(xferOp.getLoc(), b);
   int64_t memrefRank = xferOp.getShapedType().getRank();
   // TODO: relax this precondition, will require rank-reducing subviews.
@@ -1864,11 +1863,15 @@ static Value createSubViewIntersection(OpBuilder &b,
     sizes.push_back(affineMin);
   });
 
-  SmallVector<OpFoldResult, 4> indices = llvm::to_vector<4>(llvm::map_range(
+  SmallVector<OpFoldResult> srcIndices = llvm::to_vector<4>(llvm::map_range(
       xferOp.indices(), [](Value idx) -> OpFoldResult { return idx; }));
-  return lb.create<memref::SubViewOp>(
-      isaWrite ? alloc : xferOp.source(), indices, sizes,
-      SmallVector<OpFoldResult>(memrefRank, OpBuilder(xferOp).getIndexAttr(1)));
+  SmallVector<OpFoldResult> destIndices(memrefRank, b.getIndexAttr(0));
+  SmallVector<OpFoldResult> strides(memrefRank, b.getIndexAttr(1));
+  auto copySrc = lb.create<memref::SubViewOp>(
+      isaWrite ? alloc : xferOp.source(), srcIndices, sizes, strides);
+  auto copyDest = lb.create<memref::SubViewOp>(
+      isaWrite ? xferOp.source() : alloc, destIndices, sizes, strides);
+  return std::make_pair(copySrc, copyDest);
 }
 
 /// Given an `xferOp` for which:
@@ -1877,14 +1880,15 @@ static Value createSubViewIntersection(OpBuilder &b,
 /// Produce IR resembling:
 /// ```
 ///    %1:3 = scf.if (%inBounds) {
-///      memref.cast %A: memref<A...> to compatibleMemRefType
+///      %view = memref.cast %A: memref<A...> to compatibleMemRefType
 ///      scf.yield %view, ... : compatibleMemRefType, index, index
 ///    } else {
 ///      %2 = linalg.fill(%pad, %alloc)
 ///      %3 = subview %view [...][...][...]
-///      linalg.copy(%3, %alloc)
-///      memref.cast %alloc: memref<B...> to compatibleMemRefType
-///      scf.yield %4, ... : compatibleMemRefType, index, index
+///      %4 = subview %alloc [0, 0] [...] [...]
+///      linalg.copy(%3, %4)
+///      %5 = memref.cast %alloc: memref<B...> to compatibleMemRefType
+///      scf.yield %5, ... : compatibleMemRefType, index, index
 ///   }
 /// ```
 /// Return the produced scf::IfOp.
@@ -1910,9 +1914,9 @@ createFullPartialLinalgCopy(OpBuilder &b, vector::TransferReadOp xferOp,
         b.create<linalg::FillOp>(loc, xferOp.padding(), alloc);
         // Take partial subview of memref which guarantees no dimension
         // overflows.
-        Value memRefSubView = createSubViewIntersection(
+        std::pair<Value, Value> copyArgs = createSubViewIntersection(
             b, cast<VectorTransferOpInterface>(xferOp.getOperation()), alloc);
-        b.create<linalg::CopyOp>(loc, memRefSubView, alloc);
+        b.create<linalg::CopyOp>(loc, copyArgs.first, copyArgs.second);
         Value casted =
             b.create<memref::CastOp>(loc, alloc, compatibleMemRefType);
         scf::ValueVector viewAndIndices{casted};
@@ -2030,7 +2034,8 @@ getLocationToWriteFullVec(OpBuilder &b, vector::TransferWriteOp xferOp,
 ///    %notInBounds = xor %inBounds, %true
 ///    scf.if (%notInBounds) {
 ///      %3 = subview %alloc [...][...][...]
-///      linalg.copy(%3, %view)
+///      %4 = subview %view [0, 0][...][...]
+///      linalg.copy(%3, %4)
 ///   }
 /// ```
 static void createFullPartialLinalgCopy(OpBuilder &b,
@@ -2040,9 +2045,9 @@ static void createFullPartialLinalgCopy(OpBuilder &b,
   auto notInBounds =
       lb.create<XOrOp>(inBoundsCond, lb.create<ConstantIntOp>(true, 1));
   lb.create<scf::IfOp>(notInBounds, [&](OpBuilder &b, Location loc) {
-    Value memRefSubView = createSubViewIntersection(
+    std::pair<Value, Value> copyArgs = createSubViewIntersection(
         b, cast<VectorTransferOpInterface>(xferOp.getOperation()), alloc);
-    b.create<linalg::CopyOp>(loc, memRefSubView, xferOp.source());
+    b.create<linalg::CopyOp>(loc, copyArgs.first, copyArgs.second);
     b.create<scf::YieldOp>(loc, ValueRange{});
   });
 }

diff --git a/mlir/test/Dialect/Vector/vector-transfer-full-partial-split.mlir b/mlir/test/Dialect/Vector/vector-transfer-full-partial-split.mlir
index 77a64476ca58..b98bcab92657 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-full-partial-split.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-full-partial-split.mlir
@@ -81,7 +81,8 @@ func @split_vector_transfer_read_2d(%A: memref<?x8xf32>, %i: index, %j: index) -
   //      LINALG:   %[[sv1:.*]] = affine.min #[[$bounds_map_8]](%[[c8]], %[[j]], %[[c8]])
   //      LINALG:   %[[sv:.*]] = memref.subview %[[A]][%[[i]], %[[j]]] [%[[sv0]], %[[sv1]]] [1, 1]
   // LINALG-SAME:     memref<?x8xf32> to memref<?x?xf32, #[[$map_2d_stride_8x1]]>
-  //      LINALG:   linalg.copy(%[[sv]], %[[alloc]]) : memref<?x?xf32, #[[$map_2d_stride_8x1]]>, memref<4x8xf32>
+  //      LINALG:   %[[alloc_view:.*]] = memref.subview %[[alloc]][0, 0] [%[[sv0]], %[[sv1]]] [1, 1]
+  //      LINALG:   linalg.copy(%[[sv]], %[[alloc_view]]) : memref<?x?xf32, #[[$map_2d_stride_8x1]]>, memref<?x?xf32, #{{.*}}>
   //      LINALG:   %[[yielded:.*]] = memref.cast %[[alloc]] :
   // LINALG-SAME:     memref<4x8xf32> to memref<?x8xf32>
   //      LINALG:   scf.yield %[[yielded]], %[[c0]], %[[c0]] :
@@ -172,7 +173,8 @@ func @split_vector_transfer_read_strided_2d(
   //      LINALG:   %[[sv1:.*]] = affine.min #[[$bounds_map_8]](%[[c8]], %[[j]], %[[c8]])
   //      LINALG:   %[[sv:.*]] = memref.subview %[[A]][%[[i]], %[[j]]] [%[[sv0]], %[[sv1]]] [1, 1]
   // LINALG-SAME:     memref<7x8xf32, #[[$map_2d_stride_1]]> to memref<?x?xf32, #[[$map_2d_stride_1]]>
-  //      LINALG:   linalg.copy(%[[sv]], %[[alloc]]) : memref<?x?xf32, #[[$map_2d_stride_1]]>, memref<4x8xf32>
+  //      LINALG:   %[[alloc_view:.*]] = memref.subview %[[alloc]][0, 0] [%[[sv0]], %[[sv1]]] [1, 1]
+  //      LINALG:   linalg.copy(%[[sv]], %[[alloc_view]]) : memref<?x?xf32, #[[$map_2d_stride_1]]>, memref<?x?xf32, #{{.*}}>
   //      LINALG:   %[[yielded:.*]] = memref.cast %[[alloc]] :
   // LINALG-SAME:     memref<4x8xf32> to memref<?x8xf32, #[[$map_2d_stride_1]]>
   //      LINALG:   scf.yield %[[yielded]], %[[c0]], %[[c0]] :
@@ -276,8 +278,9 @@ func @split_vector_transfer_write_2d(%V: vector<4x8xf32>, %A: memref<?x8xf32>, %
 // LINALG:             %[[VAL_22:.*]] = memref.subview %[[TEMP]]
 // LINALG-SAME:            [%[[I]], %[[J]]] [%[[VAL_20]], %[[VAL_21]]]
 // LINALG-SAME:            [1, 1] : memref<4x8xf32> to memref<?x?xf32, #[[MAP4]]>
-// LINALG:             linalg.copy(%[[VAL_22]], %[[DEST]])
-// LINALG-SAME:            : memref<?x?xf32, #[[MAP4]]>, memref<?x8xf32>
+// LINALG:             %[[DEST_VIEW:.*]] = memref.subview %[[DEST]][0, 0] [%[[VAL_20]], %[[VAL_21]]] [1, 1]
+// LINALG:             linalg.copy(%[[VAL_22]], %[[DEST_VIEW]])
+// LINALG-SAME:            : memref<?x?xf32, #[[MAP4]]>, memref<?x?xf32, #{{.*}}>
 // LINALG:           }
 // LINALG:           return
 // LINALG:         }
@@ -384,8 +387,9 @@ func @split_vector_transfer_write_strided_2d(
 // LINALG:             %[[VAL_22:.*]] = memref.subview %[[TEMP]]
 // LINALG-SAME:            [%[[I]], %[[J]]] [%[[VAL_20]], %[[VAL_21]]]
 // LINALG-SAME:            [1, 1] : memref<4x8xf32> to memref<?x?xf32, #[[MAP5]]>
-// LINALG:             linalg.copy(%[[VAL_22]], %[[DEST]])
-// LINALG-SAME:            : memref<?x?xf32, #[[MAP5]]>, memref<7x8xf32, #[[MAP0]]>
+// LINALG:             %[[DEST_VIEW:.*]] = memref.subview %[[DEST]][0, 0] [%[[VAL_20]], %[[VAL_21]]] [1, 1]
+// LINALG:             linalg.copy(%[[VAL_22]], %[[DEST_VIEW]])
+// LINALG-SAME:            : memref<?x?xf32, #[[MAP5]]>, memref<?x?xf32, #[[MAP0]]>
 // LINALG:           }
 // LINALG:           return
 // LINALG:         }


        


More information about the Mlir-commits mailing list