[Mlir-commits] [mlir] f3cc854 - [mlir][Vector] Add folding of vector transfers from/into tensor producing ops.

Nicolas Vasilache llvmlistbot at llvm.org
Thu Mar 4 06:21:55 PST 2021


Author: Nicolas Vasilache
Date: 2021-03-04T14:17:42Z
New Revision: f3cc8543647cfcfd3ea383e6738bc58a258d6f74

URL: https://github.com/llvm/llvm-project/commit/f3cc8543647cfcfd3ea383e6738bc58a258d6f74
DIFF: https://github.com/llvm/llvm-project/commit/f3cc8543647cfcfd3ea383e6738bc58a258d6f74.diff

LOG: [mlir][Vector] Add folding of vector transfers from/into tensor producing ops.

Add a folder to rewrite a sequence such as:

```
   %t1 = ...
   %v = vector.transfer_read %t0[%c0...], {masked = [false...]} :
     tensor<static_sizesxf32>, vector<static_sizesxf32>
   %t2 = vector.transfer_write %v, %t1[%c0...] {masked = [false...]} :
     vector<static_sizesxf32>, tensor<static_sizesxf32>
```

into:

```
   %t0
```

The producer of %t1 may or may not be DCE'd, depending on whether %t1 is a
block argument or its producer has side effects.
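
For illustration, a concrete instance of the rewrite with made-up shapes (a
sketch only, not taken from this commit; `%t0` and `%t1` are assumed to be
`tensor<4x8xf32>` values):

```
  %c0 = constant 0 : index
  %pad = constant 0.0 : f32
  // Reads all of %t0 as a vector...
  %v = vector.transfer_read %t0[%c0, %c0], %pad {masked = [false, false]} :
    tensor<4x8xf32>, vector<4x8xf32>
  // ...and writes it over all of %t1, so the result is exactly %t0.
  %t2 = vector.transfer_write %v, %t1[%c0, %c0] {masked = [false, false]} :
    vector<4x8xf32>, tensor<4x8xf32>
  // After folding, every use of %t2 is replaced with %t0.
```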

Differential revision: https://reviews.llvm.org/D97934

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/VectorOps.cpp
    mlir/test/Dialect/Vector/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index b08a696281faf..ebe3250a09afb 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -2593,8 +2593,70 @@ static LogicalResult verify(TransferWriteOp op) {
                               [&op](Twine t) { return op.emitOpError(t); });
 }
 
-LogicalResult TransferWriteOp::fold(ArrayRef<Attribute>,
-                                    SmallVectorImpl<OpFoldResult> &) {
+/// Fold:
+/// ```
+///    %t1 = ...
+///    %v = vector.transfer_read %t0[%c0...], {masked = [false...]} :
+///      tensor<static_sizesxf32>, vector<static_sizesxf32>
+///    %t2 = vector.transfer_write %v, %t1[%c0...] {masked = [false...]} :
+///      vector<static_sizesxf32>, tensor<static_sizesxf32>
+/// ```
+///
+/// into:
+///
+/// ```
+///    %t0
+/// ```
+///
+/// The producer of t1 may or may not be DCE'd depending on whether it is a
+/// block argument or has side effects.
+static LogicalResult foldReadInitWrite(TransferWriteOp write,
+                                       ArrayRef<Attribute>,
+                                       SmallVectorImpl<OpFoldResult> &results) {
+  auto rankedTensorType = write.source().getType().dyn_cast<RankedTensorType>();
+  // If not operating on tensors, bail.
+  if (!rankedTensorType)
+    return failure();
+  // If no read, bail.
+  auto read = write.vector().getDefiningOp<vector::TransferReadOp>();
+  if (!read)
+    return failure();
+  // For now, only accept minor identity. Future: composition is minor identity.
+  if (!read.permutation_map().isMinorIdentity() ||
+      !write.permutation_map().isMinorIdentity())
+    return failure();
+  // Bail on mismatching ranks.
+  if (read.getTransferRank() != write.getTransferRank())
+    return failure();
+  // Bail on masked.
+  if (read.hasMaskedDim() || write.hasMaskedDim())
+    return failure();
+  // Tensor types must be the same.
+  if (read.source().getType() != rankedTensorType)
+    return failure();
+  // Vector types must be the same.
+  if (read.getVectorType() != write.getVectorType())
+    return failure();
+  // Vector and Tensor shapes must match.
+  if (read.getVectorType().getShape() != rankedTensorType.getShape())
+    return failure();
+  // If any index is nonzero.
+  auto isNotConstantZero = [](Value v) {
+    auto cstOp = v.getDefiningOp<ConstantIndexOp>();
+    return !cstOp || cstOp.getValue() != 0;
+  };
+  if (llvm::any_of(read.indices(), isNotConstantZero) ||
+      llvm::any_of(write.indices(), isNotConstantZero))
+    return failure();
+  // Success.
+  results.push_back(read.source());
+  return success();
+}
+
+LogicalResult TransferWriteOp::fold(ArrayRef<Attribute> operands,
+                                    SmallVectorImpl<OpFoldResult> &results) {
+  if (succeeded(foldReadInitWrite(*this, operands, results)))
+    return success();
   if (succeeded(foldTransferMaskAttribute(*this)))
     return success();
   return foldMemRefCast(*this);
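
For contrast, a sketch of a read/write pair that this fold deliberately leaves
untouched (illustrative only, not part of the commit; `%t0`, `%t1`, and `%pad`
are assumed values): the indices are nonzero and the vector does not cover the
whole tensor, so rewriting the result to `%t0` would be incorrect.

```
  %c1 = constant 1 : index
  // Not folded: nonzero indices and vector shape != tensor shape,
  // so the write only updates a 2x2 sub-block of %t1.
  %v = vector.transfer_read %t0[%c1, %c1], %pad {masked = [false, false]} :
    tensor<4x8xf32>, vector<2x2xf32>
  %t2 = vector.transfer_write %v, %t1[%c1, %c1] {masked = [false, false]} :
    vector<2x2xf32>, tensor<4x8xf32>
```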

diff  --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 45f5d7d2b9018..17da655f72cb5 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -pass-pipeline='func(canonicalize)' -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -pass-pipeline='func(canonicalize)' -split-input-file -allow-unregistered-dialect | FileCheck %s
 
 // -----
 
@@ -770,3 +770,32 @@ func @contractions(%a: vector<2x3xf32>, %b: vector<3x4xf32>, %c: vector<2x4xf32>
   return %1, %i8_1: vector<2x4xf32>, vector<2x4xi8>
 }
 
+// -----
+
+// CHECK-LABEL: func @transfer_folding_1
+//  CHECK-SAME:   %[[T0:[0-9a-zA-Z]+]]: tensor<2x3x4xf32>
+//  CHECK-SAME:   %[[T1:[0-9a-zA-Z]+]]: tensor<2x3x4xf32>
+func @transfer_folding_1(%t0: tensor<2x3x4xf32>, %t1: tensor<2x3x4xf32>)
+  -> (tensor<2x3x4xf32>, tensor<2x3x4xf32>, tensor<2x3x4xf32>)
+{
+  %c0 = constant 0 : index
+  %pad = constant 0.0 : f32
+  %v = vector.transfer_read %t0[%c0, %c0, %c0], %pad {masked = [false, false, false]} :
+    tensor<2x3x4xf32>, vector<2x3x4xf32>
+
+  %r0 = vector.transfer_write %v, %t1[%c0, %c0, %c0] {masked = [false, false, false]} :
+    vector<2x3x4xf32>, tensor<2x3x4xf32>
+
+  %t2 = "test.constant"() { value = dense<6.0> : tensor<2x3x4xf32>} : () -> (tensor<2x3x4xf32>)
+  %r1 = vector.transfer_write %v, %t2[%c0, %c0, %c0] {masked = [false, false, false]} :
+    vector<2x3x4xf32>, tensor<2x3x4xf32>
+
+
+  // CHECK-NEXT: some_op_that_may_have_side_effects
+  %t3 = "some_op_that_may_have_side_effects"() : () -> (tensor<2x3x4xf32>)
+  %r2 = vector.transfer_write %v, %t0[%c0, %c0, %c0] {masked = [false, false, false]} :
+    vector<2x3x4xf32>, tensor<2x3x4xf32>
+
+  // CHECK-NEXT: return %[[T0]], %[[T0]], %[[T0]]
+  return %r0, %r1, %r2: tensor<2x3x4xf32>, tensor<2x3x4xf32>, tensor<2x3x4xf32>
+}
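
For reference, after canonicalization the test function above is expected to
reduce to roughly the following (a sketch inferred from the CHECK lines, not
verbatim compiler output; SSA names are kept for readability):

```
func @transfer_folding_1(%t0: tensor<2x3x4xf32>, %t1: tensor<2x3x4xf32>)
  -> (tensor<2x3x4xf32>, tensor<2x3x4xf32>, tensor<2x3x4xf32>)
{
  // Kept: unregistered op that may have side effects, so it is not DCE'd.
  %t3 = "some_op_that_may_have_side_effects"() : () -> (tensor<2x3x4xf32>)
  // All three transfer_write results fold to %t0; the now-dead reads,
  // writes, and constants are removed.
  return %t0, %t0, %t0 : tensor<2x3x4xf32>, tensor<2x3x4xf32>, tensor<2x3x4xf32>
}
```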