[Mlir-commits] [mlir] bebf5d5 - [mlir][Linalg] Add support for vector.transfer ops to comprehensive bufferization (2/n).

Nicolas Vasilache llvmlistbot at llvm.org
Thu May 13 15:36:11 PDT 2021


Author: Nicolas Vasilache
Date: 2021-05-13T22:26:28Z
New Revision: bebf5d56bff75cd5b74b58cbdcb965885a82916f

URL: https://github.com/llvm/llvm-project/commit/bebf5d56bff75cd5b74b58cbdcb965885a82916f
DIFF: https://github.com/llvm/llvm-project/commit/bebf5d56bff75cd5b74b58cbdcb965885a82916f.diff

LOG: [mlir][Linalg] Add support for vector.transfer ops to comprehensive bufferization (2/n).

Differential revision: https://reviews.llvm.org/D102395

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
    mlir/test/Dialect/Linalg/comprehensive-func-bufferize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
index 2534eeeb7dcd..acac96ea88d7 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
@@ -82,8 +82,8 @@
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/Dialect/Linalg/Passes.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Vector/VectorOps.h"
 #include "mlir/IR/Operation.h"
-#include "mlir/Interfaces/LoopLikeInterface.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/BufferUtils.h"
 
@@ -128,16 +128,25 @@ OpResult getMatchingOpResult(LinalgOp linalgOp, OpOperand &opOperand) {
   return linalgOp->getResult(outputOperandIndex - numOutputBuffers);
 }
 
+/// Only a tensor transfer_write result may alias its source; reads never do.
+OpResult getMatchingOpResult(VectorTransferOpInterface op,
+                             OpOperand &opOperand) {
+  bool inplaceable = isa<vector::TransferWriteOp>(op.getOperation()) &&
+                     opOperand.get() == op.source() && op->getNumResults() == 1;
+  return inplaceable ? op->getResult(0) : OpResult();
+}
+
 /// Determine which results may be reused inplace by the bufferization
 /// patterns of `bufferizeFuncOpInternals`.
 /// The inplace analysis uses this information along with interfering read
 /// analysis to determine which op results reuse the same buffer as some
 /// operand.
 OpResult getMatchingOpResult(OpOperand &opOperand) {
-  OpResult res =
-      llvm::TypeSwitch<Operation *, OpResult>(opOperand.getOwner())
-          .Case([&](LinalgOp op) { return getMatchingOpResult(op, opOperand); })
-          .Default([&](Operation *op) { return OpResult(); });
+  // Forward to the op-specific overload; any other op has no matching result.
+  auto match = [&](auto op) { return getMatchingOpResult(op, opOperand); };
+  OpResult res = llvm::TypeSwitch<Operation *, OpResult>(opOperand.getOwner())
+                     .Case<LinalgOp, VectorTransferOpInterface>(match)
+                     .Default([](Operation *) { return OpResult(); });
   return res;
 }
 
@@ -708,6 +717,54 @@ static LogicalResult convertReturnOp(OpBuilder &b, ReturnOp returnOp,
   return success();
 }
 
+/// Bufferize a vector.transfer op on tensors, reading/updating `bvm`.
+static LogicalResult convertTransferOp(OpBuilder &b,
+                                       VectorTransferOpInterface op,
+                                       BlockAndValueMapping &bvm) {
+  // Take a guard before anything else.
+  OpBuilder::InsertionGuard g(b);
+  b.setInsertionPoint(op);
+  Location loc = op.getLoc();
+
+  if (op.getShapedType().isa<MemRefType>())
+    return failure();
+
+  LLVM_DEBUG(DBGS() << "convert: " << *op << "\n");
+
+  // transfer_read: rewire the source operand to the source tensor's buffer.
+  if (auto readOp = dyn_cast<vector::TransferReadOp>(op.getOperation())) {
+    readOp.sourceMutable().assign(lookup(bvm, op.source()));
+    return success();
+  }
+
+  auto writeOp = cast<vector::TransferWriteOp>(op.getOperation());
+  // If transfer_write is not inPlace, allocate a new buffer and copy the
+  // source into it: elements not covered by the write must be preserved.
+  Value newInputBuffer;
+  if (getInPlace(op->getResult(0)) != InPlaceSpec::True) {
+    newInputBuffer =
+        createNewAllocDeallocPairForShapedValue(b, loc, writeOp.result());
+    b.setInsertionPointAfter(newInputBuffer.getDefiningOp());
+    b.create<linalg::CopyOp>(loc, lookup(bvm, writeOp.source()),
+                             newInputBuffer);
+  } else {
+    // InPlace write will result in memref.tensor_load(x) which must
+    // canonicalize away with one of its uses.
+    newInputBuffer = lookup(bvm, writeOp.source());
+  }
+
+  // Create a new transfer_write on buffer that doesn't have a return value.
+  // Leave the previous transfer_write to dead code as it still has uses at
+  // this point.
+  b.create<vector::TransferWriteOp>(
+      loc, writeOp.vector(), newInputBuffer, writeOp.indices(),
+      writeOp.permutation_map(),
+      writeOp.in_bounds() ? *writeOp.in_bounds() : ArrayAttr());
+
+  map(bvm, op->getResult(0), newInputBuffer);
+  return success();
+}
+
 static void inPlaceAnalysisFuncOpInternals(FuncOp funcOp,
                                            const DominanceInfo &domInfo) {
   assert(funcOp && funcOp->getNumRegions() > 0 && !funcOp.body().empty() &&
@@ -733,6 +790,9 @@ static LogicalResult bufferizeFuncOpInternals(
             .Case([&](memref::DimOp op) { return convertDimOp(b, op, bvm); })
             .Case([&](LinalgOp op) { return convertAnyLinalgOp(b, op, bvm); })
             .Case([&](ReturnOp op) { return convertReturnOp(b, op, bvm); })
+            .Case([&](VectorTransferOpInterface op) {
+              return convertTransferOp(b, op, bvm);
+            })
             .Default([&](Operation *op) {
               auto isaTensor = [](Type t) { return t.isa<TensorType>(); };
               if (llvm::any_of(op->getOperandTypes(), isaTensor) ||

diff  --git a/mlir/test/Dialect/Linalg/comprehensive-func-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-func-bufferize.mlir
index 69c4e3fe5919..a5636329fd3a 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-func-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-func-bufferize.mlir
@@ -81,3 +81,39 @@ func @not_inplace(%A : tensor<?x?xf32> {linalg.inplaceable = true}) -> tensor<?x
     -> tensor<?x?xf32>
   return %r: tensor<?x?xf32>
 }
+// -----
+
+// CHECK-LABEL: func @vec_inplace
+func @vec_inplace(%A : tensor<?xf32> {linalg.inplaceable = true}, %vec : vector<4xf32>)
+    -> tensor<?xf32>
+{
+  %c0 = constant 0 : index
+  // CHECK-NOT: alloc
+  %r = vector.transfer_write %vec, %A[%c0] : vector<4xf32>, tensor<?xf32>
+  return %r: tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @vec_not_inplace
+//  CHECK-SAME:   %[[A:[a-zA-Z0-9]*]]: tensor<?xf32> {linalg.inplaceable = true}
+func @vec_not_inplace(%A : tensor<?xf32> {linalg.inplaceable = true}, %vec : vector<4xf32>)
+    -> (tensor<?xf32>, tensor<?xf32>)
+{
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+
+  //       CHECK: %[[BUFFER_CAST:.*]] = memref.buffer_cast %[[A]] : memref<?xf32, #[[$map_2d_dyn]]>
+
+  /// Cross-op multiple uses of %A, the first vector.transfer which has interfering reads must alloc and copy %A.
+  //      CHECK: %[[ALLOC:.*]] = memref.alloc
+  // CHECK-NEXT: linalg.copy(%[[BUFFER_CAST]], %[[ALLOC]])
+  // CHECK-NEXT: vector.transfer_write {{.*}}, %[[ALLOC]]
+  %r0 = vector.transfer_write %vec, %A[%c0] : vector<4xf32>, tensor<?xf32>
+
+  /// The second vector.transfer has no interfering reads and can reuse the buffer.
+  //  CHECK-NOT: alloc
+  // CHECK-NEXT: vector.transfer_write {{.*}}, %[[BUFFER_CAST]]
+  %r1 = vector.transfer_write %vec, %A[%c1] : vector<4xf32>, tensor<?xf32>
+  return %r0, %r1: tensor<?xf32>, tensor<?xf32>
+}


        


More information about the Mlir-commits mailing list