[Mlir-commits] [mlir] 9f6ba4b - [mlir][vector] Extend transfer_write to read propagation

Thomas Raoux llvmlistbot at llvm.org
Fri Jul 22 10:11:26 PDT 2022


Author: Thomas Raoux
Date: 2022-07-22T17:11:06Z
New Revision: 9f6ba4be26858a8c581b212efebe993e3c22e8d7

URL: https://github.com/llvm/llvm-project/commit/9f6ba4be26858a8c581b212efebe993e3c22e8d7
DIFF: https://github.com/llvm/llvm-project/commit/9f6ba4be26858a8c581b212efebe993e3c22e8d7.diff

LOG: [mlir][vector] Extend transfer_write to read propagation

Folding of transfer_write into transfer_read is already supported, but
this requires the read and write to have the same permutation map.
After linalg vectorization it is common to have different permutation
maps for a write followed by a read, even though the written value could
still be propagated.
This canonicalization handles cases where the permutation maps are
different but the data read and written match, and replaces the transfer
ops with a broadcast and a permutation.

Differential Revision: https://reviews.llvm.org/D130135

Added: 
    

Modified: 
    mlir/include/mlir/Interfaces/VectorInterfaces.td
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/test/Dialect/Vector/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Interfaces/VectorInterfaces.td b/mlir/include/mlir/Interfaces/VectorInterfaces.td
index 063ddb31ca363..832528bfa78ad 100644
--- a/mlir/include/mlir/Interfaces/VectorInterfaces.td
+++ b/mlir/include/mlir/Interfaces/VectorInterfaces.td
@@ -243,6 +243,36 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {
         fun(resultIdx, indicesIdx);
       }]
     >,
+    InterfaceMethod<
+      /*desc=*/[{
+      Return an upper-bound shape accessed by the transfer op within the
+      tensor/memref operand.
+      For example:
+      ```
+        vector.transfer %w0[%i, %j] {
+          permutation_map = affine_map<(d0, d1) -> (d1, d0, 0)>} :
+          tensor<?x?xf32>, vector<4x2x6xf32>
+      ```
+      returns a shape [2, 4].
+      }],
+      /*retTy=*/"SmallVector<int64_t>",
+      /*methodName=*/"getTransferChunkAccessed",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        SmallVector<int64_t> dimSizes($_op.getPermutationMap().getNumDims(), 0);
+        for (auto vecDims : llvm::zip($_op.getPermutationMap().getResults(),
+                                      $_op.getVectorType().getShape())) {
+          AffineExpr dim = std::get<0>(vecDims);
+          int64_t size = std::get<1>(vecDims);
+          // Skip broadcast.
+          if (dim.isa<AffineConstantExpr>())
+            continue;
+          dimSizes[dim.cast<AffineDimExpr>().getPosition()] = size;
+        }
+        return dimSizes;
+      }]
+    >,
   ];
 }
 

diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index c474922b0444e..764ecc4eb53cb 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -3284,11 +3284,91 @@ struct FoldExtractSliceIntoTransferRead
     return success();
   }
 };
+
+/// Store to load forwarding for transfer operations with permutation maps.
+/// Even if the permutation maps are different we can still propagate the store
+/// into the load if the size of the dimensions read and written match. Then we
+/// can replace the transfer_read + transfer_write by vector.broadcast and
+/// vector.transpose.
+/// Example:
+/// ```
+/// %w0 = vector.transfer_write %v0, %arg0[%c0, %c0, %c0]
+///  {in_bounds = [true, true],
+///   permutation_map = affine_map<(d0, d1, d2) -> (d2, d1)>} :
+///   vector<4x1xf32>, tensor<4x4x4xf32>
+///  %r = vector.transfer_read %w0[%c0, %c0, %c0], %cf0
+///   {in_bounds = [true, true, true, true],
+///   permutation_map = affine_map<(d0, d1, d2) -> (d1, 0, d2, 0)>} :
+///   tensor<4x4x4xf32>, vector<1x100x4x5xf32>
+/// ```
+/// To:
+/// ```
+/// %0 = vector.broadcast %v0 : vector<4x1xf32> to vector<100x5x4x1xf32>
+/// %r = vector.transpose %0, [3, 0, 2, 1] :
+///   vector<100x5x4x1xf32> to vector<1x100x4x5xf32>
+/// ```
+struct TransferReadAfterWriteToBroadcast
+    : public OpRewritePattern<TransferReadOp> {
+  using OpRewritePattern<TransferReadOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(TransferReadOp readOp,
+                                PatternRewriter &rewriter) const override {
+    if (readOp.hasOutOfBoundsDim() ||
+        !readOp.getShapedType().isa<RankedTensorType>())
+      return failure();
+    auto defWrite = readOp.getSource().getDefiningOp<vector::TransferWriteOp>();
+    if (!defWrite)
+      return failure();
+
+    SmallVector<int64_t> readDims = readOp.getTransferChunkAccessed();
+    Value vec;
+    if (readOp.getIndices() == defWrite.getIndices() &&
+        readOp.getMask() == defWrite.getMask()) {
+      SmallVector<int64_t> writeDims = defWrite.getTransferChunkAccessed();
+      // TODO: If writeDims is a superset of readDims we could do an
+      // extract_strided_slice.
+      if (writeDims == readDims)
+        vec = defWrite.getVector();
+    }
+    // TODO: loop through the chain of transfer_write if we can prove that they
+    // don't overlap with the transfer_read. This requires improving
+    // `isDisjointTransferIndices` helper.
+    if (!vec)
+      return failure();
+    SmallVector<unsigned> permutation;
+    AffineMap readMap = compressUnusedDims(readOp.getPermutationMap());
+    AffineMap writeMap = compressUnusedDims(defWrite.getPermutationMap());
+    AffineMap map = readMap.compose(writeMap);
+    if (map.getNumResults() == 0)
+      return failure();
+    // Calculate the permutation to apply to go from the vector stored to the
+    // vector read.
+    if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation))
+      return failure();
+
+    Location loc = readOp.getLoc();
+    // Calculate the broadcast shape by applying the reverse permutation to the
+    // final shape we want.
+    ArrayRef<int64_t> destShape = readOp.getVectorType().getShape();
+    SmallVector<int64_t> broadcastShape(destShape.size());
+    for (const auto &pos : llvm::enumerate(permutation))
+      broadcastShape[pos.value()] = destShape[pos.index()];
+    VectorType broadcastedType = VectorType::get(
+        broadcastShape, defWrite.getVectorType().getElementType());
+    vec = rewriter.create<vector::BroadcastOp>(loc, broadcastedType, vec);
+    SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
+    rewriter.replaceOpWithNewOp<vector::TransposeOp>(readOp, vec,
+                                                     transposePerm);
+    return success();
+  }
+};
 } // namespace
 
 void TransferReadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
-  results.add<FoldExtractSliceIntoTransferRead>(context);
+  results
+      .add<FoldExtractSliceIntoTransferRead, TransferReadAfterWriteToBroadcast>(
+          context);
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 54025a626f002..be86fe91c17c1 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1058,6 +1058,45 @@ func.func @store_to_load_negative_tensor(%arg0 : tensor<4x4xf32>,
 
 // -----
 
+// CHECK-LABEL: func @store_to_load_tensor_broadcast
+//  CHECK-SAME: (%[[ARG:.*]]: tensor<4x4xf32>, %[[V0:.*]]: vector<4x2xf32>)
+//       CHECK:   %[[B:.*]] = vector.broadcast %[[V0]] : vector<4x2xf32> to vector<6x4x2xf32>
+//       CHECK:   %[[T:.*]] = vector.transpose %[[B]], [1, 2, 0] : vector<6x4x2xf32> to vector<4x2x6xf32>
+//       CHECK:   return %[[T]] : vector<4x2x6xf32>
+func.func @store_to_load_tensor_broadcast(%arg0 : tensor<4x4xf32>,
+  %v0 : vector<4x2xf32>) -> vector<4x2x6xf32> {
+  %c0 = arith.constant 0 : index
+  %cf0 = arith.constant 0.0 : f32
+  %w0 = vector.transfer_write %v0, %arg0[%c0, %c0] {in_bounds = [true, true]} :
+    vector<4x2xf32>, tensor<4x4xf32>
+  %0 = vector.transfer_read %w0[%c0, %c0], %cf0 {in_bounds = [true, true, true],
+  permutation_map = affine_map<(d0, d1) -> (d0, d1, 0)>} :
+    tensor<4x4xf32>, vector<4x2x6xf32>
+  return %0 : vector<4x2x6xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @store_to_load_tensor_perm_broadcast
+//  CHECK-SAME: (%[[ARG:.*]]: tensor<4x4x4xf32>, %[[V0:.*]]: vector<4x1xf32>)
+//       CHECK:   %[[B:.*]] = vector.broadcast %[[V0]] : vector<4x1xf32> to vector<100x5x4x1xf32>
+//       CHECK:   %[[T:.*]] = vector.transpose %[[B]], [3, 0, 2, 1] : vector<100x5x4x1xf32> to vector<1x100x4x5xf32>
+//       CHECK:   return %[[T]] : vector<1x100x4x5xf32>
+func.func @store_to_load_tensor_perm_broadcast(%arg0 : tensor<4x4x4xf32>,
+  %v0 : vector<4x1xf32>) -> vector<1x100x4x5xf32> {
+  %c0 = arith.constant 0 : index
+  %cf0 = arith.constant 0.0 : f32
+  %w0 = vector.transfer_write %v0, %arg0[%c0, %c0, %c0] {in_bounds = [true, true],
+  permutation_map = affine_map<(d0, d1, d2) -> (d2, d1)>} :
+    vector<4x1xf32>, tensor<4x4x4xf32>
+  %0 = vector.transfer_read %w0[%c0, %c0, %c0], %cf0 {in_bounds = [true, true, true, true],
+  permutation_map = affine_map<(d0, d1, d2) -> (d1, 0, d2, 0)>} :
+    tensor<4x4x4xf32>, vector<1x100x4x5xf32>
+  return %0 : vector<1x100x4x5xf32>
+}
+
+// -----
+
 
 // CHECK-LABEL: func @dead_store_tensor
 //   CHECK-DAG:      %[[C0:.*]] = arith.constant 0 : index


        


More information about the Mlir-commits mailing list