[Mlir-commits] [mlir] 64f7fb5 - [mlir] Support masked N-D vector transfer ops in ProgressiveVectorToSCF.

Matthias Springer llvmlistbot at llvm.org
Fri Apr 23 02:24:13 PDT 2021


Author: Matthias Springer
Date: 2021-04-23T18:23:51+09:00
New Revision: 64f7fb5dfca14bead0e4b12142da2135f950034f

URL: https://github.com/llvm/llvm-project/commit/64f7fb5dfca14bead0e4b12142da2135f950034f
DIFF: https://github.com/llvm/llvm-project/commit/64f7fb5dfca14bead0e4b12142da2135f950034f.diff

LOG: [mlir] Support masked N-D vector transfer ops in ProgressiveVectorToSCF.

Mask vectors are handled in the same way as data vectors in N-D TransferWriteOp: they are copied into a temporary memory buffer, which can then be indexed with non-constant values.
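For reference, below is a sketch of what the preparation step produces for the masked 2-D read added in the new @transfer_read_2d_mask integration test. The buffer and SSA names (%data_buf, %mask_buf, %m, %v, %res) are illustrative, %A/%base1/%base2 are assumed function arguments, and the allocas are hoisted to the function entry block by the pass:

```mlir
// Input (adapted from @transfer_read_2d_mask below):
%fm42 = constant -42.0 : f32
%mask = constant dense<[[1, 0, 1, 0, 1, 1, 1, 0, 1],
                        [0, 0, 1, 1, 1, 1, 1, 0, 1],
                        [1, 1, 1, 1, 1, 1, 1, 0, 1],
                        [0, 0, 1, 0, 1, 1, 1, 0, 1]]> : vector<4x9xi1>
%f = vector.transfer_read %A[%base1, %base2], %fm42, %mask
    : memref<?x?xf32>, vector<4x9xf32>

// After PrepareTransferReadConversion (sketch):
%data_buf = memref.alloca() : memref<vector<4x9xf32>>
%mask_buf = memref.alloca() : memref<vector<4x9xi1>>
memref.store %mask, %mask_buf[] : memref<vector<4x9xi1>>
%m = memref.load %mask_buf[] : memref<vector<4x9xi1>>
%v = vector.transfer_read %A[%base1, %base2], %fm42, %m { __vector_to_scf_lowering__ }
    : memref<?x?xf32>, vector<4x9xf32>
memref.store %v, %data_buf[] : memref<vector<4x9xf32>>
%res = memref.load %data_buf[] : memref<vector<4x9xf32>>  // replaces all uses of %f
```

The round trip through the 0-d buffers is what later allows TransferOpConversion to index into the mask with a non-constant loop induction variable.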

Differential Revision: https://reviews.llvm.org/D101136

Added: 
    

Modified: 
    mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp
    mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-2d.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp
index d42bd67082e8..08aca49c7af4 100644
--- a/mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp
@@ -56,16 +56,34 @@ static MemRefType unpackOneDim(MemRefType type) {
                                          vectorType.getElementType()));
 }
 
-// TODO: Parallelism and threadlocal considerations.
-static Value setAllocAtFunctionEntry(MemRefType type, Operation *op) {
+/// Helper data structure for data and mask buffers.
+struct BufferAllocs {
+  Value dataBuffer;
+  Value maskBuffer;
+};
+
+/// Allocate temporary buffers for data (vector) and mask (if present).
+/// TODO: Parallelism and threadlocal considerations.
+template <typename OpTy>
+static BufferAllocs allocBuffers(OpTy xferOp) {
   auto &b = ScopedContext::getBuilderRef();
   OpBuilder::InsertionGuard guard(b);
   Operation *scope =
-      op->getParentWithTrait<OpTrait::AutomaticAllocationScope>();
+      xferOp->template getParentWithTrait<OpTrait::AutomaticAllocationScope>();
   assert(scope && "Expected op to be inside automatic allocation scope");
   b.setInsertionPointToStart(&scope->getRegion(0).front());
-  Value res = memref_alloca(type);
-  return res;
+
+  BufferAllocs result;
+  auto bufferType = MemRefType::get({}, xferOp.getVectorType());
+  result.dataBuffer = memref_alloca(bufferType).value;
+
+  if (xferOp.mask()) {
+    auto maskType = MemRefType::get({}, xferOp.mask().getType());
+    result.maskBuffer = memref_alloca(maskType).value;
+    memref_store(xferOp.mask(), result.maskBuffer);
+  }
+
+  return result;
 }
 
 /// Given a vector transfer op, calculate which dimension of the `source`
@@ -238,6 +256,16 @@ static ArrayAttr dropFirstElem(OpBuilder &builder, ArrayAttr attr) {
   return ArrayAttr::get(builder.getContext(), attr.getValue().drop_front());
 }
 
+/// Given a transfer op, find the memref from which the mask is loaded. This
+/// is similar to Strategy<TransferWriteOp>::getBuffer.
+template <typename OpTy>
+static Value getMaskBuffer(OpTy xferOp) {
+  assert(xferOp.mask() && "Expected that transfer op has mask");
+  auto loadOp = xferOp.mask().template getDefiningOp<memref::LoadOp>();
+  assert(loadOp && "Expected transfer op mask produced by LoadOp");
+  return loadOp.getMemRef();
+}
+
 /// Codegen strategy, depending on the operation.
 template <typename OpTy>
 struct Strategy;
@@ -266,9 +294,9 @@ struct Strategy<TransferReadOp> {
     return getStoreOp(xferOp).getMemRef();
   }
 
-  /// Retrieve the indices of the current StoreOp.
-  static void getStoreIndices(TransferReadOp xferOp,
-                             SmallVector<Value, 8> &indices) {
+  /// Retrieve the indices of the current StoreOp that stores into the buffer.
+  static void getBufferIndices(TransferReadOp xferOp,
+                               SmallVector<Value, 8> &indices) {
     auto storeOp = getStoreOp(xferOp);
     auto prevIndices = memref::StoreOpAdaptor(storeOp).indices();
     indices.append(prevIndices.begin(), prevIndices.end());
@@ -300,10 +328,11 @@ struct Strategy<TransferReadOp> {
   ///
   /// Note: The loop and type cast are generated in TransferOpConversion.
   ///       The original TransferReadOp and store op are deleted in `cleanup`.
-  static void rewriteOp(OpBuilder &builder, TransferReadOp xferOp,
-                        Value buffer, Value iv) {
+  /// Note: The `mask` operand is set in TransferOpConversion.
+  static TransferReadOp rewriteOp(OpBuilder &builder, TransferReadOp xferOp,
+                                  Value buffer, Value iv) {
     SmallVector<Value, 8> storeIndices;
-    getStoreIndices(xferOp, storeIndices);
+    getBufferIndices(xferOp, storeIndices);
     storeIndices.push_back(iv);
 
     SmallVector<Value, 8> xferIndices;
@@ -321,6 +350,7 @@ struct Strategy<TransferReadOp> {
         newXfer.getDefiningOp()->setAttr(kPassLabel, builder.getUnitAttr());
 
     memref_store(newXfer, buffer, storeIndices);
+    return newXfer.getDefiningOp<TransferReadOp>();
   }
 
   /// Handle out-of-bounds accesses on the to-be-unpacked dimension: Write
@@ -329,7 +359,7 @@ struct Strategy<TransferReadOp> {
       OpBuilder &/*builder*/, TransferReadOp xferOp, Value buffer,
       Value iv) {
     SmallVector<Value, 8> storeIndices;
-    getStoreIndices(xferOp, storeIndices);
+    getBufferIndices(xferOp, storeIndices);
     storeIndices.push_back(iv);
 
     auto bufferType = buffer.getType().dyn_cast<ShapedType>();
@@ -361,9 +391,9 @@ struct Strategy<TransferWriteOp> {
     return loadOp.getMemRef();
   }
 
-  /// Retrieve the indices of the current LoadOp.
-  static void getLoadIndices(TransferWriteOp xferOp,
-                             SmallVector<Value, 8> &indices) {
+  /// Retrieve the indices of the current LoadOp that loads from the buffer.
+  static void getBufferIndices(TransferWriteOp xferOp,
+                               SmallVector<Value, 8> &indices) {
     auto loadOp = xferOp.vector().getDefiningOp<memref::LoadOp>();
     auto prevIndices = memref::LoadOpAdaptor(loadOp).indices();
     indices.append(prevIndices.begin(), prevIndices.end());
@@ -378,10 +408,10 @@ struct Strategy<TransferWriteOp> {
   ///    to memory.
   ///
   /// Note: For more details, see comments on Strategy<TransferReadOp>.
-  static void rewriteOp(OpBuilder &builder, TransferWriteOp xferOp,
-                        Value buffer, Value iv) {
+  static TransferWriteOp rewriteOp(OpBuilder &builder, TransferWriteOp xferOp,
+                                   Value buffer, Value iv) {
     SmallVector<Value, 8> loadIndices;
-    getLoadIndices(xferOp, loadIndices);
+    getBufferIndices(xferOp, loadIndices);
     loadIndices.push_back(iv);
 
     SmallVector<Value, 8> xferIndices;
@@ -397,6 +427,8 @@ struct Strategy<TransferWriteOp> {
 
     if (vecType.getRank() > kTargetRank)
         newXfer.op->setAttr(kPassLabel, builder.getUnitAttr());
+
+    return newXfer;
   }
 
   /// Handle out-of-bounds accesses on the to-be-unpacked dimension.
@@ -416,8 +448,6 @@ LogicalResult checkPrepareXferOp(OpTy xferOp) {
       return failure();
   if (xferOp.getVectorType().getRank() <= kTargetRank)
       return failure();
-  if (xferOp.mask())
-      return failure();
   return success();
 }
 
@@ -442,6 +472,8 @@ LogicalResult checkPrepareXferOp(OpTy xferOp) {
 /// memref.store %1, %0[] : memref<vector<5x4xf32>>
 /// %vec = memref.load %0[] : memref<vector<5x4xf32>>
 /// ```
+///
+/// Note: A second temporary buffer may be allocated for the `mask` operand.
 struct PrepareTransferReadConversion
     : public OpRewritePattern<TransferReadOp> {
   using OpRewritePattern<TransferReadOp>::OpRewritePattern;
@@ -452,12 +484,16 @@ struct PrepareTransferReadConversion
       return failure();
 
     ScopedContext scope(rewriter, xferOp.getLoc());
-    auto allocType = MemRefType::get({}, xferOp.getVectorType());
-    auto buffer = setAllocAtFunctionEntry(allocType, xferOp);
+    auto buffers = allocBuffers(xferOp);
     auto *newXfer = rewriter.clone(*xferOp.getOperation());
     newXfer->setAttr(kPassLabel, rewriter.getUnitAttr());
-    memref_store(newXfer->getResult(0), buffer);
-    rewriter.replaceOpWithNewOp<memref::LoadOp>(xferOp, buffer);
+    if (xferOp.mask()) {
+      auto loadedMask = memref_load(buffers.maskBuffer);
+      dyn_cast<TransferReadOp>(newXfer).maskMutable().assign(loadedMask);
+    }
+
+    memref_store(newXfer->getResult(0), buffers.dataBuffer);
+    rewriter.replaceOpWithNewOp<memref::LoadOp>(xferOp, buffers.dataBuffer);
 
     return success();
   }
@@ -484,6 +520,8 @@ struct PrepareTransferReadConversion
 /// vector.transfer_write %1, %A[%a, %b, %c] { __vector_to_scf_lowering__ }
 ///     : vector<5x4xf32>, memref<?x?x?xf32>
 /// ```
+///
+/// Note: A second temporary buffer may be allocated for the `mask` operand.
 struct PrepareTransferWriteConversion
     : public OpRewritePattern<TransferWriteOp> {
   using OpRewritePattern<TransferWriteOp>::OpRewritePattern;
@@ -494,16 +532,20 @@ struct PrepareTransferWriteConversion
       return failure();
 
     ScopedContext scope(rewriter, xferOp.getLoc());
-    auto allocType = MemRefType::get({}, xferOp.getVectorType());
-    auto buffer = setAllocAtFunctionEntry(allocType, xferOp);
-    memref_store(xferOp.vector(), buffer);
-    auto loadedVec = memref_load(buffer);
-
+    auto buffers = allocBuffers(xferOp);
+    memref_store(xferOp.vector(), buffers.dataBuffer);
+    auto loadedVec = memref_load(buffers.dataBuffer);
     rewriter.updateRootInPlace(xferOp, [&]() {
       xferOp.vectorMutable().assign(loadedVec);
       xferOp->setAttr(kPassLabel, rewriter.getUnitAttr());
     });
 
+    if (xferOp.mask()) {
+      auto loadedMask = memref_load(buffers.maskBuffer);
+      rewriter.updateRootInPlace(
+          xferOp, [&]() { xferOp.maskMutable().assign(loadedMask); });
+    }
+
     return success();
   }
 };
@@ -535,16 +577,28 @@ struct TransferOpConversion : public OpRewritePattern<OpTy> {
         return failure();
 
     ScopedContext scope(rewriter, xferOp.getLoc());
-    // How the buffer can be found depends on OpTy.
-    auto buffer = Strategy<OpTy>::getBuffer(xferOp);
-    auto bufferType = buffer.getType().template dyn_cast<MemRefType>();
-    auto castedType = unpackOneDim(bufferType);
-    auto casted = vector_type_cast(castedType, buffer);
+
+    // Find and cast data buffer. How the buffer can be found depends on OpTy.
+    auto dataBuffer = Strategy<OpTy>::getBuffer(xferOp);
+    auto dataBufferType = dataBuffer.getType().template dyn_cast<MemRefType>();
+    auto castedDataType = unpackOneDim(dataBufferType);
+    auto castedDataBuffer = vector_type_cast(castedDataType, dataBuffer);
+
+    // If the xferOp has a mask: Find and cast mask buffer.
+    Value castedMaskBuffer;
+    if (xferOp.mask()) {
+      auto maskBuffer = getMaskBuffer(xferOp);
+      auto maskBufferType =
+          maskBuffer.getType().template dyn_cast<MemRefType>();
+      auto castedMaskType = unpackOneDim(maskBufferType);
+      castedMaskBuffer = vector_type_cast(castedMaskType, maskBuffer);
+    }
 
     // Loop bounds and step.
     auto lb = std_constant_index(0).value;
     auto ub = std_constant_index(
-        castedType.getDimSize(castedType.getRank() - 1)).value;
+                  castedDataType.getDimSize(castedDataType.getRank() - 1))
+                  .value;
     auto step = std_constant_index(1).value;
 
     // Generate for loop.
@@ -555,11 +609,31 @@ struct TransferOpConversion : public OpRewritePattern<OpTy> {
       ScopedContext scope(b, loc);
       generateInBoundsCheck(
           xferOp, iv, b, unpackedDim(xferOp),
-          /*inBoundsCase=*/[&](OpBuilder &b, Location /*loc*/) {
-        Strategy<OpTy>::rewriteOp(b, xferOp, casted, iv);
-      }, /*outOfBoundsCase=*/[&](OpBuilder &b, Location /*loc*/) {
-        Strategy<OpTy>::handleOutOfBoundsDim(b, xferOp, casted, iv);
-      });
+          /*inBoundsCase=*/
+          [&](OpBuilder &b, Location /*loc*/) {
+            // Create new transfer op.
+            OpTy newXfer =
+                Strategy<OpTy>::rewriteOp(b, xferOp, castedDataBuffer, iv);
+
+            // If old transfer op has a mask: Set mask on new transfer op.
+            if (xferOp.mask()) {
+              OpBuilder::InsertionGuard guard(b);
+              b.setInsertionPoint(newXfer); // Insert load before newXfer.
+
+              SmallVector<Value, 8> loadIndices;
+              Strategy<OpTy>::getBufferIndices(xferOp, loadIndices);
+              loadIndices.push_back(iv);
+
+              auto mask = memref_load(castedMaskBuffer, loadIndices);
+              rewriter.updateRootInPlace(
+                  newXfer, [&]() { newXfer.maskMutable().assign(mask); });
+            }
+          },
+          /*outOfBoundsCase=*/
+          [&](OpBuilder &b, Location /*loc*/) {
+            Strategy<OpTy>::handleOutOfBoundsDim(b, xferOp, castedDataBuffer,
+                                                 iv);
+          });
       b.create<scf::YieldOp>(loc);
     });
 

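After the TransferOpConversion changes above, both buffers are unpacked by one dimension with vector.type_cast, and the loop body loads the mask slice for the current iteration before attaching it to the newly created lower-rank transfer op. A rough sketch of the resulting IR for the 4x9 example from the commit message, continuing the illustrative names (%data_buf, %mask_buf, %A, %base1, %base2, %fm42) and omitting the out-of-bounds branch produced by generateInBoundsCheck:

```mlir
%c0 = constant 0 : index
%c1 = constant 1 : index
%c4 = constant 4 : index
%casted_data = vector.type_cast %data_buf : memref<vector<4x9xf32>> to memref<4xvector<9xf32>>
%casted_mask = vector.type_cast %mask_buf : memref<vector<4x9xi1>> to memref<4xvector<9xi1>>
scf.for %iv = %c0 to %c4 step %c1 {
  // Mask slice for the currently unpacked dimension.
  %mask_slice = memref.load %casted_mask[%iv] : memref<4xvector<9xi1>>
  %row = addi %base1, %iv : index   // schematic; the pass builds an affine.apply here
  %vec = vector.transfer_read %A[%row, %base2], %fm42, %mask_slice
      : memref<?x?xf32>, vector<9xf32>
  memref.store %vec, %casted_data[%iv] : memref<4xvector<9xf32>>
}
```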
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-2d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-2d.mlir
index cbe0aa52a437..f4eef8b98b76 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-2d.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-2d.mlir
@@ -1,8 +1,3 @@
-// RUN: mlir-opt %s -convert-vector-to-scf -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
-// RUN: mlir-cpu-runner -e entry -entry-point-result=void  \
-// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
-// RUN: FileCheck %s
-
 // RUN: mlir-opt %s -test-progressive-convert-vector-to-scf -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
 // RUN: mlir-cpu-runner -e entry -entry-point-result=void  \
 // RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
@@ -17,6 +12,19 @@ func @transfer_read_2d(%A : memref<?x?xf32>, %base1: index, %base2: index) {
   return
 }
 
+func @transfer_read_2d_mask(%A : memref<?x?xf32>, %base1: index, %base2: index) {
+  %fm42 = constant -42.0: f32
+  %mask = constant dense<[[1, 0, 1, 0, 1, 1, 1, 0, 1],
+                          [0, 0, 1, 1, 1, 1, 1, 0, 1],
+                          [1, 1, 1, 1, 1, 1, 1, 0, 1],
+                          [0, 0, 1, 0, 1, 1, 1, 0, 1]]> : vector<4x9xi1>
+  %f = vector.transfer_read %A[%base1, %base2], %fm42, %mask
+      {permutation_map = affine_map<(d0, d1) -> (d0, d1)>} :
+    memref<?x?xf32>, vector<4x9xf32>
+  vector.print %f: vector<4x9xf32>
+  return
+}
+
 func @transfer_read_2d_transposed(
     %A : memref<?x?xf32>, %base1: index, %base2: index) {
   %fm42 = constant -42.0: f32
@@ -80,7 +88,10 @@ func @entry() {
   call @transfer_write_2d(%A, %c3, %c1) : (memref<?x?xf32>, index, index) -> ()
   // Read shifted by 0 and pad with -42:
   call @transfer_read_2d(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
-  // Same as above, but transposed
+  // Same as above, but apply a mask
+  call @transfer_read_2d_mask(%A, %c0, %c0)
+      : (memref<?x?xf32>, index, index) -> ()
+  // Same as above, but without mask and transposed
   call @transfer_read_2d_transposed(%A, %c0, %c0)
       : (memref<?x?xf32>, index, index) -> ()
   // Second vector dimension is a broadcast
@@ -92,5 +103,6 @@ func @entry() {
 // CHECK: ( ( 12, 13, -42, -42, -42, -42, -42, -42, -42 ), ( 22, 23, -42, -42, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ) )
 // CHECK: ( ( 12, 22, -42, -42, -42, -42, -42, -42, -42 ), ( 13, 23, -42, -42, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ) )
 // CHECK: ( ( 0, 1, 2, 3, -42, -42, -42, -42, -42 ), ( 10, 11, 12, 13, -42, -42, -42, -42, -42 ), ( 20, 21, 22, 23, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ) )
+// CHECK: ( ( 0, -42, 2, -42, -42, -42, -42, -42, -42 ), ( -42, -42, 12, 13, -42, -42, -42, -42, -42 ), ( 20, 21, 22, 23, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ) )
 // CHECK: ( ( 0, 10, 20, -42, -42, -42, -42, -42, -42 ), ( 1, 11, 21, -42, -42, -42, -42, -42, -42 ), ( 2, 12, 22, -42, -42, -42, -42, -42, -42 ), ( 3, 13, 23, -42, -42, -42, -42, -42, -42 ) )
 // CHECK: ( ( 12, 12, 12, 12, 12, 12, 12, 12, 12 ), ( 13, 13, 13, 13, 13, 13, 13, 13, 13 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ) )


        

