[Mlir-commits] [mlir] 843f1fc - [mlir][scf] Add scf.for + tensor.cast canonicalization pattern

Nicolas Vasilache llvmlistbot at llvm.org
Fri Apr 16 09:55:07 PDT 2021


Author: Nicolas Vasilache
Date: 2021-04-16T16:50:21Z
New Revision: 843f1fc82598216a4be672ba51820b037dae106b

URL: https://github.com/llvm/llvm-project/commit/843f1fc82598216a4be672ba51820b037dae106b
DIFF: https://github.com/llvm/llvm-project/commit/843f1fc82598216a4be672ba51820b037dae106b.diff

LOG: [mlir][scf] Add scf.for + tensor.cast canonicalization pattern

Fold scf.for iter_arg/result pairs that go through an incoming/outgoing
tensor.cast op pair so as to pull the tensor.cast inside the scf.for:

```
  %0 = tensor.cast %t0 : tensor<32x1024xf32> to tensor<?x?xf32>
  %1 = scf.for %i = %c0 to %c1024 step %c32 iter_args(%iter_t0 = %0)
     -> (tensor<?x?xf32>) {
    %2 = call @do(%iter_t0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
    scf.yield %2 : tensor<?x?xf32>
  }
  %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<32x1024xf32>
  use_of(%2)
```

folds into:

```
  %0 = scf.for %i = %c0 to %c1024 step %c32 iter_args(%iter_t0 = %t0)
      -> (tensor<32x1024xf32>) {
    %1 = tensor.cast %iter_t0 : tensor<32x1024xf32> to tensor<?x?xf32>
    %2 = call @do(%1) : (tensor<?x?xf32>) -> tensor<?x?xf32>
    %3 = tensor.cast %2 : tensor<?x?xf32> to tensor<32x1024xf32>
    scf.yield %3 : tensor<32x1024xf32>
  }
  use_of(%0)
```

Differential Revision: https://reviews.llvm.org/D100661

Added: 
    

Modified: 
    mlir/lib/Dialect/SCF/SCF.cpp
    mlir/test/Dialect/SCF/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
index efbc87273ca82..fa4fb9ffef338 100644
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -9,6 +9,7 @@
 #include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Support/MathExtras.h"
@@ -578,6 +579,140 @@ struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
   }
 };
 
+/// Build a new scf.for in which the iter OpOperand `operand` is replaced by
+/// `replacement`, which is expected to be the source of a tensor.cast.
+/// tensor.cast ops are inserted inside the new loop body to account for the
+/// type change: one after the region iter_arg and one before the yield.
+/// The old loop's body is moved into the new loop, but the old loop itself is
+/// not replaced and no cast is created for the new result: the caller must
+/// cast that result back to the old type and replace the old loop.
+/// If the types already match, the owning scf.for is returned unchanged.
+static ForOp replaceTensorCastForOpIterArg(PatternRewriter &rewriter,
+                                           OpOperand &operand,
+                                           Value replacement) {
+  Type oldType = operand.get().getType(), newType = replacement.getType();
+  assert(oldType.isa<RankedTensorType>() && newType.isa<RankedTensorType>() &&
+         "expected ranked tensor types");
+
+  // 1. Create new iter operands, exactly 1 is replaced.
+  ForOp forOp = cast<ForOp>(operand.getOwner());
+  assert(operand.getOperandNumber() >= forOp.getNumControlOperands() &&
+         "expected an iter OpOperand");
+  // Nothing to rewrite; hand back the original loop.
+  if (oldType == newType)
+    return forOp;
+  SmallVector<Value> newIterOperands;
+  for (OpOperand &opOperand : forOp.getIterOpOperands()) {
+    if (opOperand.getOperandNumber() == operand.getOperandNumber()) {
+      newIterOperands.push_back(replacement);
+      continue;
+    }
+    newIterOperands.push_back(opOperand.get());
+  }
+
+  // 2. Create the new forOp shell.
+  scf::ForOp newForOp = rewriter.create<scf::ForOp>(
+      forOp.getLoc(), forOp.lowerBound(), forOp.upperBound(), forOp.step(),
+      newIterOperands);
+  Block &newBlock = newForOp.region().front();
+  SmallVector<Value, 4> newBlockTransferArgs(newBlock.getArguments().begin(),
+                                             newBlock.getArguments().end());
+
+  // 3. Inject an incoming cast op at the beginning of the block for the bbArg
+  // corresponding to the `replacement` value.
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(&newBlock, newBlock.begin());
+  BlockArgument newRegionIterArg = newForOp.getRegionIterArgForOpOperand(
+      newForOp->getOpOperand(operand.getOperandNumber()));
+  Value castIn = rewriter.create<tensor::CastOp>(newForOp.getLoc(), oldType,
+                                                 newRegionIterArg);
+  newBlockTransferArgs[newRegionIterArg.getArgNumber()] = castIn;
+
+  // 4. Steal the old block ops, mapping to the newBlockTransferArgs.
+  Block &oldBlock = forOp.region().front();
+  rewriter.mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
+
+  // 5. Inject an outgoing cast op at the end of the block and yield it instead.
+  auto clonedYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
+  rewriter.setInsertionPoint(clonedYieldOp);
+  unsigned yieldIdx =
+      newRegionIterArg.getArgNumber() - forOp.getNumInductionVars();
+  Value castOut = rewriter.create<tensor::CastOp>(
+      newForOp.getLoc(), newType, clonedYieldOp.getOperand(yieldIdx));
+  SmallVector<Value> newYieldOperands = clonedYieldOp.getOperands();
+  newYieldOperands[yieldIdx] = castOut;
+  rewriter.create<scf::YieldOp>(newForOp.getLoc(), newYieldOperands);
+  rewriter.eraseOp(clonedYieldOp);
+
+  return newForOp;
+}
+
+/// Fold scf.for iter_arg/result pairs that go through an incoming/outgoing
+/// tensor.cast op pair so as to pull the tensor.cast inside the scf.for.
+/// Only distinct ranked tensor types and single-use loop results are folded:
+///
+/// ```
+///   %0 = tensor.cast %t0 : tensor<32x1024xf32> to tensor<?x?xf32>
+///   %1 = scf.for %i = %c0 to %c1024 step %c32 iter_args(%iter_t0 = %0)
+///      -> (tensor<?x?xf32>) {
+///     %2 = call @do(%iter_t0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
+///     scf.yield %2 : tensor<?x?xf32>
+///   }
+///   %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<32x1024xf32>
+///   use_of(%2)
+/// ```
+///
+/// folds into:
+///
+/// ```
+///   %0 = scf.for %i = %c0 to %c1024 step %c32 iter_args(%iter_t0 = %t0)
+///       -> (tensor<32x1024xf32>) {
+///     %1 = tensor.cast %iter_t0 : tensor<32x1024xf32> to tensor<?x?xf32>
+///     %2 = call @do(%1) : (tensor<?x?xf32>) -> tensor<?x?xf32>
+///     %3 = tensor.cast %2 : tensor<?x?xf32> to tensor<32x1024xf32>
+///     scf.yield %3 : tensor<32x1024xf32>
+///   }
+///   use_of(%0)
+/// ```
+struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> {
+  using OpRewritePattern<ForOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ForOp op,
+                                PatternRewriter &rewriter) const override {
+    for (auto it : llvm::zip(op.getIterOpOperands(), op.getResults())) {
+      OpOperand &iterOpOperand = std::get<0>(it);
+      OpResult result = std::get<1>(it);
+      auto incomingCast = iterOpOperand.get().getDefiningOp<tensor::CastOp>();
+      if (!incomingCast || !result.hasOneUse())
+        continue;
+      // replaceTensorCastForOpIterArg asserts on unranked types and, for an
+      // identity cast, returns `op` itself: replaceOp below would then break.
+      Type srcType = incomingCast.source().getType();
+      Type dstType = incomingCast.dest().getType();
+      if (srcType == dstType || !srcType.isa<RankedTensorType>() ||
+          !dstType.isa<RankedTensorType>())
+        continue;
+      // Must be a tensor.cast op pair with matching types.
+      auto outgoingCastOp = dyn_cast<tensor::CastOp>(*result.user_begin());
+      if (!outgoingCastOp || outgoingCastOp.getResult().getType() != srcType)
+        continue;
+
+      // Create a new ForOp with that iter operand replaced.
+      auto newForOp = replaceTensorCastForOpIterArg(rewriter, iterOpOperand,
+                                                    incomingCast.source());
+      // Insert outgoing cast and use it to replace the corresponding result.
+      rewriter.setInsertionPointAfter(newForOp);
+      SmallVector<Value> replacements = newForOp.getResults();
+      unsigned returnIdx = result.getResultNumber();
+      replacements[returnIdx] = rewriter.create<tensor::CastOp>(
+          op.getLoc(), dstType, replacements[returnIdx]);
+      rewriter.replaceOp(op, replacements);
+      return success();
+    }
+    return failure();
+  }
+};
+
 /// Canonicalize the iter_args of an scf::ForOp that involve a tensor_load and
 /// for which only the last loop iteration is actually visible outside of the
 /// loop. The canonicalization looks for a pattern such as:
@@ -706,7 +841,8 @@ struct LastTensorLoadCanonicalization : public OpRewritePattern<ForOp> {
 void ForOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
+  // ForOpTensorCastFolder pulls iter_arg tensor.cast pairs into the loop.
   results.add<ForOpIterArgsFolder, SimplifyTrivialLoops,
-              LastTensorLoadCanonicalization>(context);
+              LastTensorLoadCanonicalization, ForOpTensorCastFolder>(context);
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index 3964f85ba3d2a..d0d9e9c9a847f 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -580,3 +580,33 @@ func @fold_away_iter_and_result_with_no_use(%arg0 : i32,
   // CHECK: return %[[FOR_RES]] : i32
   return %0#0 : i32
 }
+
+// -----
+// Tests that a tensor.cast pair around scf.for is pulled into the loop body.
+func private @do(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32>
+
+// CHECK-LABEL: fold_tensor_cast_into_for_iter_arg
+//  CHECK-SAME:   %[[T0:[0-9a-z]*]]: tensor<32x1024xf32>
+//  CHECK-SAME:   %[[T1:[0-9a-z]*]]: tensor<1024x1024xf32>
+func @fold_tensor_cast_into_for_iter_arg(%t0: tensor<32x1024xf32>, %t1: tensor<1024x1024xf32>) -> tensor<1024x1024xf32> {
+  %c0 = constant 0 : index
+  %c32 = constant 32 : index
+  %c1024 = constant 1024 : index
+//   CHECK-NOT: tensor.cast
+//       CHECK: %[[FOR_RES:.*]] = scf.for {{.*}} iter_args(%[[ITER_T0:.*]] = %[[T0]]) -> (tensor<32x1024xf32>) {
+//       CHECK:   %[[CAST:.*]] = tensor.cast %[[ITER_T0]] : tensor<32x1024xf32> to tensor<?x?xf32>
+//       CHECK:   %[[DONE:.*]] = call @do(%[[CAST]]) : (tensor<?x?xf32>) -> tensor<?x?xf32>
+//       CHECK:   %[[UNCAST:.*]] = tensor.cast %[[DONE]] : tensor<?x?xf32> to tensor<32x1024xf32>
+//       CHECK:   scf.yield %[[UNCAST]] : tensor<32x1024xf32>
+  %0 = tensor.cast %t0 : tensor<32x1024xf32> to tensor<?x?xf32>
+  %1 = scf.for %i = %c0 to %c1024 step %c32 iter_args(%iter_t0 = %0) -> (tensor<?x?xf32>) {
+    %2 = call @do(%iter_t0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
+    scf.yield %2 : tensor<?x?xf32>
+  }
+//   CHECK-NOT: tensor.cast
+//       CHECK: %[[RES:.*]] = subtensor_insert %[[FOR_RES]] into %[[T1]][0, 0] [32, 1024] [1, 1] : tensor<32x1024xf32> into tensor<1024x1024xf32>
+//       CHECK: return %[[RES]] : tensor<1024x1024xf32>
+  %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<32x1024xf32>
+  %res = subtensor_insert %2 into %t1[0, 0] [32, 1024] [1, 1] : tensor<32x1024xf32> into tensor<1024x1024xf32>
+  return %res : tensor<1024x1024xf32>
+}


        


More information about the Mlir-commits mailing list