[Mlir-commits] [mlir] 843f1fc - [mlir][scf] Add scf.for + tensor.cast canonicalization pattern
Nicolas Vasilache
llvmlistbot at llvm.org
Fri Apr 16 09:55:07 PDT 2021
Author: Nicolas Vasilache
Date: 2021-04-16T16:50:21Z
New Revision: 843f1fc82598216a4be672ba51820b037dae106b
URL: https://github.com/llvm/llvm-project/commit/843f1fc82598216a4be672ba51820b037dae106b
DIFF: https://github.com/llvm/llvm-project/commit/843f1fc82598216a4be672ba51820b037dae106b.diff
LOG: [mlir][scf] Add scf.for + tensor.cast canonicalization pattern
Fold scf.for iter_arg/result pairs that go through an incoming/outgoing
tensor.cast op pair, so as to pull the tensor.cast inside the scf.for:
```
%0 = tensor.cast %t0 : tensor<32x1024xf32> to tensor<?x?xf32>
%1 = scf.for %i = %c0 to %c1024 step %c32 iter_args(%iter_t0 = %0)
-> (tensor<?x?xf32>) {
%2 = call @do(%iter_t0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
scf.yield %2 : tensor<?x?xf32>
}
%2 = tensor.cast %1 : tensor<?x?xf32> to tensor<32x1024xf32>
use_of(%2)
```
folds into:
```
%0 = scf.for %arg2 = %c0 to %c1024 step %c32 iter_args(%arg3 = %arg0)
-> (tensor<32x1024xf32>) {
%2 = tensor.cast %arg3 : tensor<32x1024xf32> to tensor<?x?xf32>
%3 = call @do(%2) : (tensor<?x?xf32>) -> tensor<?x?xf32>
%4 = tensor.cast %3 : tensor<?x?xf32> to tensor<32x1024xf32>
scf.yield %4 : tensor<32x1024xf32>
}
use_of(%0)
```
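The pattern is registered with the other scf.for canonicalizations (see the
change to getCanonicalizationPatterns below), so the standard canonicalizer
picks it up automatically. For illustration only, a minimal sketch of driving
these patterns directly with the greedy rewriter; the helper name
`canonicalizeSCFForOps` is hypothetical and not part of this change:
```
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Hypothetical driver: collect scf.for's canonicalization patterns
// (which now include the tensor.cast folder) and apply them greedily
// to all regions nested under `op`.
static LogicalResult canonicalizeSCFForOps(Operation *op) {
  MLIRContext *ctx = op->getContext();
  RewritePatternSet patterns(ctx);
  scf::ForOp::getCanonicalizationPatterns(patterns, ctx);
  return applyPatternsAndFoldGreedily(op, std::move(patterns));
}
```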
Differential Revision: https://reviews.llvm.org/D100661
Added:
Modified:
mlir/lib/Dialect/SCF/SCF.cpp
mlir/test/Dialect/SCF/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
index efbc87273ca82..fa4fb9ffef338 100644
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -9,6 +9,7 @@
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/MathExtras.h"
@@ -578,6 +579,140 @@ struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
}
};
+/// Replace one iter OpOperand of an scf.for with the `replacement` value,
+/// which is expected to be the source of a tensor.cast. tensor.cast ops are
+/// inserted inside the block to account for the type difference.
+static ForOp replaceTensorCastForOpIterArg(PatternRewriter &rewriter,
+ OpOperand &operand,
+ Value replacement) {
+ Type oldType = operand.get().getType(), newType = replacement.getType();
+ assert(oldType.isa<RankedTensorType>() && newType.isa<RankedTensorType>() &&
+ "expected ranked tensor types");
+
+ // 1. Create the new iter operands; exactly one is replaced.
+ ForOp forOp = cast<ForOp>(operand.getOwner());
+ assert(operand.getOperandNumber() >= forOp.getNumControlOperands() &&
+ "expected an iter OpOperand");
+ if (operand.get().getType() == replacement.getType())
+ return forOp;
+ SmallVector<Value> newIterOperands;
+ for (OpOperand &opOperand : forOp.getIterOpOperands()) {
+ if (opOperand.getOperandNumber() == operand.getOperandNumber()) {
+ newIterOperands.push_back(replacement);
+ continue;
+ }
+ newIterOperands.push_back(opOperand.get());
+ }
+
+ // 2. Create the new forOp shell.
+ scf::ForOp newForOp = rewriter.create<scf::ForOp>(
+ forOp.getLoc(), forOp.lowerBound(), forOp.upperBound(), forOp.step(),
+ newIterOperands);
+ Block &newBlock = newForOp.region().front();
+ SmallVector<Value, 4> newBlockTransferArgs(newBlock.getArguments().begin(),
+ newBlock.getArguments().end());
+
+ // 3. Inject an incoming cast op at the beginning of the block for the bbArg
+ // corresponding to the `replacement` value.
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(&newBlock, newBlock.begin());
+ BlockArgument newRegionIterArg = newForOp.getRegionIterArgForOpOperand(
+ newForOp->getOpOperand(operand.getOperandNumber()));
+ Value castIn = rewriter.create<tensor::CastOp>(newForOp.getLoc(), oldType,
+ newRegionIterArg);
+ newBlockTransferArgs[newRegionIterArg.getArgNumber()] = castIn;
+
+ // 4. Steal the old block ops, mapping to the newBlockTransferArgs.
+ Block &oldBlock = forOp.region().front();
+ rewriter.mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
+
+ // 5. Inject an outgoing cast op at the end of the block and yield it instead.
+ auto clonedYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
+ rewriter.setInsertionPoint(clonedYieldOp);
+ unsigned yieldIdx =
+ newRegionIterArg.getArgNumber() - forOp.getNumInductionVars();
+ Value castOut = rewriter.create<tensor::CastOp>(
+ newForOp.getLoc(), newType, clonedYieldOp.getOperand(yieldIdx));
+ SmallVector<Value> newYieldOperands = clonedYieldOp.getOperands();
+ newYieldOperands[yieldIdx] = castOut;
+ rewriter.create<scf::YieldOp>(newForOp.getLoc(), newYieldOperands);
+ rewriter.eraseOp(clonedYieldOp);
+
+ // 6. Inject an outgoing cast op after the forOp.
+ rewriter.setInsertionPointAfter(newForOp);
+ SmallVector<Value> newResults = newForOp.getResults();
+ newResults[yieldIdx] = rewriter.create<tensor::CastOp>(
+ newForOp.getLoc(), oldType, newResults[yieldIdx]);
+
+ return newForOp;
+}
+
+/// Fold scf.for iter_arg/result pairs that go through an incoming/outgoing
+/// tensor.cast op pair, so as to pull the tensor.cast inside the scf.for:
+///
+/// ```
+/// %0 = tensor.cast %t0 : tensor<32x1024xf32> to tensor<?x?xf32>
+/// %1 = scf.for %i = %c0 to %c1024 step %c32 iter_args(%iter_t0 = %0)
+/// -> (tensor<?x?xf32>) {
+/// %2 = call @do(%iter_t0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
+/// scf.yield %2 : tensor<?x?xf32>
+/// }
+/// %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<32x1024xf32>
+/// use_of(%2)
+/// ```
+///
+/// folds into:
+///
+/// ```
+/// %0 = scf.for %arg2 = %c0 to %c1024 step %c32 iter_args(%arg3 = %arg0)
+/// -> (tensor<32x1024xf32>) {
+/// %2 = tensor.cast %arg3 : tensor<32x1024xf32> to tensor<?x?xf32>
+/// %3 = call @do(%2) : (tensor<?x?xf32>) -> tensor<?x?xf32>
+/// %4 = tensor.cast %3 : tensor<?x?xf32> to tensor<32x1024xf32>
+/// scf.yield %4 : tensor<32x1024xf32>
+/// }
+/// use_of(%0)
+/// ```
+struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> {
+ using OpRewritePattern<ForOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(ForOp op,
+ PatternRewriter &rewriter) const override {
+ for (auto it : llvm::zip(op.getIterOpOperands(), op.getResults())) {
+ OpOperand &iterOpOperand = std::get<0>(it);
+ auto incomingCast = iterOpOperand.get().getDefiningOp<tensor::CastOp>();
+ if (!incomingCast)
+ continue;
+ if (!std::get<1>(it).hasOneUse())
+ continue;
+ auto outgoingCastOp =
+ dyn_cast<tensor::CastOp>(*std::get<1>(it).user_begin());
+ if (!outgoingCastOp)
+ continue;
+
+ // Must be a tensor.cast op pair with matching types.
+ if (outgoingCastOp.getResult().getType() !=
+ incomingCast.source().getType())
+ continue;
+
+ // Create a new ForOp with that iter operand replaced.
+ auto newForOp = replaceTensorCastForOpIterArg(rewriter, iterOpOperand,
+ incomingCast.source());
+
+ // Insert outgoing cast and use it to replace the corresponding result.
+ rewriter.setInsertionPointAfter(newForOp);
+ SmallVector<Value> replacements = newForOp.getResults();
+ unsigned returnIdx =
+ iterOpOperand.getOperandNumber() - op.getNumControlOperands();
+ replacements[returnIdx] = rewriter.create<tensor::CastOp>(
+ op.getLoc(), incomingCast.dest().getType(), replacements[returnIdx]);
+ rewriter.replaceOp(op, replacements);
+ return success();
+ }
+ return failure();
+ }
+};
+
/// Canonicalize the iter_args of an scf::ForOp that involve a tensor_load and
/// for which only the last loop iteration is actually visible outside of the
/// loop. The canonicalization looks for a pattern such as:
@@ -706,7 +841,7 @@ struct LastTensorLoadCanonicalization : public OpRewritePattern<ForOp> {
void ForOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<ForOpIterArgsFolder, SimplifyTrivialLoops,
- LastTensorLoadCanonicalization>(context);
+ LastTensorLoadCanonicalization, ForOpTensorCastFolder>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index 3964f85ba3d2a..d0d9e9c9a847f 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -580,3 +580,33 @@ func @fold_away_iter_and_result_with_no_use(%arg0 : i32,
// CHECK: return %[[FOR_RES]] : i32
return %0#0 : i32
}
+
+// -----
+
+func private @do(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32>
+
+// CHECK-LABEL: matmul_on_tensors
+// CHECK-SAME: %[[T0:[0-9a-z]*]]: tensor<32x1024xf32>
+// CHECK-SAME: %[[T1:[0-9a-z]*]]: tensor<1024x1024xf32>
+func @matmul_on_tensors(%t0: tensor<32x1024xf32>, %t1: tensor<1024x1024xf32>) -> tensor<1024x1024xf32> {
+ %c0 = constant 0 : index
+ %c32 = constant 32 : index
+ %c1024 = constant 1024 : index
+// CHECK-NOT: tensor.cast
+// CHECK: %[[FOR_RES:.*]] = scf.for {{.*}} iter_args(%[[ITER_T0:.*]] = %[[T0]]) -> (tensor<32x1024xf32>) {
+// CHECK: %[[CAST:.*]] = tensor.cast %[[ITER_T0]] : tensor<32x1024xf32> to tensor<?x?xf32>
+// CHECK: %[[DONE:.*]] = call @do(%[[CAST]]) : (tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: %[[UNCAST:.*]] = tensor.cast %[[DONE]] : tensor<?x?xf32> to tensor<32x1024xf32>
+// CHECK: scf.yield %[[UNCAST]] : tensor<32x1024xf32>
+ %0 = tensor.cast %t0 : tensor<32x1024xf32> to tensor<?x?xf32>
+ %1 = scf.for %i = %c0 to %c1024 step %c32 iter_args(%iter_t0 = %0) -> (tensor<?x?xf32>) {
+ %2 = call @do(%iter_t0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
+ scf.yield %2 : tensor<?x?xf32>
+ }
+// CHECK-NOT: tensor.cast
+// CHECK: %[[RES:.*]] = subtensor_insert %[[FOR_RES]] into %[[T1]][0, 0] [32, 1024] [1, 1] : tensor<32x1024xf32> into tensor<1024x1024xf32>
+// CHECK: return %[[RES]] : tensor<1024x1024xf32>
+ %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<32x1024xf32>
+ %res = subtensor_insert %2 into %t1[0, 0] [32, 1024] [1, 1] : tensor<32x1024xf32> into tensor<1024x1024xf32>
+ return %res : tensor<1024x1024xf32>
+}
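As with the other tests in this file, the CHECK lines above are exercised by
running the canonicalizer over the file, e.g. `mlir-opt -canonicalize
-split-input-file mlir/test/Dialect/SCF/canonicalize.mlir | FileCheck
mlir/test/Dialect/SCF/canonicalize.mlir` (the exact pass flags follow the
file's existing RUN line).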