[Mlir-commits] [mlir] 346b9d1 - [mlir][Linalg] Canonicalize TensorCastOp away when it feeds a LinalgOp.
Nicolas Vasilache
llvmlistbot at llvm.org
Mon Oct 5 07:50:18 PDT 2020
Author: Nicolas Vasilache
Date: 2020-10-05T14:48:21Z
New Revision: 346b9d17720a0ccd920cd02b81811a4d2ddc67d6
URL: https://github.com/llvm/llvm-project/commit/346b9d17720a0ccd920cd02b81811a4d2ddc67d6
DIFF: https://github.com/llvm/llvm-project/commit/346b9d17720a0ccd920cd02b81811a4d2ddc67d6.diff
LOG: [mlir][Linalg] Canonicalize TensorCastOp away when it feeds a LinalgOp.
This canonicalization is the tensor counterpart of the existing MemRefCastOp -> LinalgOp folding.
It is needed to properly canonicalize the IR produced by Linalg tiling on tensors.
Differential Revision: https://reviews.llvm.org/D88729
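A minimal sketch of how these canonicalizations could be exercised after tiling, e.g. from a cleanup step in a pass. This is illustrative only and not part of the commit: the function name is hypothetical, linalg::MatmulOp is just one of the named ops that register the patterns via the CANONICALIZERS_AND_FOLDERS macro below, and the standard greedy pattern rewrite driver is assumed.

```cpp
// Hypothetical cleanup (not part of this commit): gather the canonicalization
// patterns of a Linalg named op -- which, after this change, include
// FoldTensorCastOp in addition to EraseDeadLinalgOp -- and apply them
// greedily over a function, e.g. right after tiling on tensors.
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

static void cleanupAfterTilingOnTensors(FuncOp func) {
  OwningRewritePatternList patterns;
  // Any Linalg named op registers the same canonicalizers; matmul is only an
  // illustrative choice.
  linalg::MatmulOp::getCanonicalizationPatterns(patterns, func.getContext());
  (void)applyPatternsAndFoldGreedily(func, patterns);
}
```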
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/lib/Dialect/StandardOps/IR/Ops.cpp
mlir/test/Dialect/Linalg/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
index f51f7b913027..44c6b77ee404 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
@@ -404,6 +404,19 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
return getInitTensors()[i];
}]
>,
+ InterfaceMethod<
+ /*desc=*/[{
+ Return the number of input, output buffer and init tensor operands.
+ }],
+ /*retTy=*/"unsigned",
+ /*methodName=*/"getNumShapedOperands",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return getNumInputsAndOutputBuffers() + $_op.getNumInitTensors();
+ }]
+ >,
InterfaceMethod<
/*desc=*/[{
Return the range over inputs, output buffers and init tensors.
@@ -414,7 +427,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
/*methodBody=*/"",
/*defaultImplementation=*/[{
auto range = this->getOperation()->getOperands();
- return {range.begin(), range.begin() + getNumInputsAndOutputs()};
+ return {range.begin(), range.begin() + getNumShapedOperands()};
}]
>,
InterfaceMethod<
@@ -621,6 +634,27 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
}]
>
];
+
+ let extraClassDeclaration = [{
+ /// Returns all operands past the inputs, output_buffers and init_tensors
+ /// operands. Asserts that these operands have value semantics (scalar or
+ /// vector types), so that transformations like tiling can simply reuse the
+ /// values when cloning `linalgOp`.
+ SmallVector<Value, 4> getAssumedNonShapedOperands() {
+ unsigned numShapedOperands = getNumInputsAndOutputs();
+ unsigned nExtraOperands =
+ getOperation()->getNumOperands() - numShapedOperands;
+ SmallVector<Value, 4> res;
+ res.reserve(nExtraOperands);
+ for (unsigned i = 0; i < nExtraOperands; ++i) {
+ res.push_back(getOperation()->getOperand(numShapedOperands + i));
+ assert((res.back().getType().isSignlessIntOrIndexOrFloat()
+ || res.back().getType().isa<VectorType>()) &&
+ "expected scalar or vector type");
+ }
+ return res;
+ }
+ }];
}
#endif // LINALG_IR_STRUCTURED_OPS_INTERFACE
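The new `getNumShapedOperands` and `getAssumedNonShapedOperands` helpers let a transformation rebuild a structured op's operand list when cloning it. Below is a minimal sketch of that usage, not part of the commit, with a hypothetical helper name and assuming the caller supplies one replacement per shaped operand.

```cpp
// Hypothetical sketch (not part of this commit): rebuild a LinalgOp's operand
// vector from replacement shaped operands (e.g. tiled subtensors/subviews)
// followed by the original trailing scalar/vector operands, then clone the op.
#include <cassert>

#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/IR/Builders.h"

using namespace mlir;

static Operation *cloneWithShapedOperands(OpBuilder &b, linalg::LinalgOp op,
                                          ArrayRef<Type> resultTypes,
                                          ArrayRef<Value> newShapedOperands) {
  assert(newShapedOperands.size() == op.getNumShapedOperands() &&
         "expected one replacement per shaped operand");
  SmallVector<Value, 8> operands(newShapedOperands.begin(),
                                 newShapedOperands.end());
  // The remaining operands are asserted to be scalars or vectors and can be
  // reused unchanged.
  auto extraOperands = op.getAssumedNonShapedOperands();
  operands.append(extraOperands.begin(), extraOperands.end());
  return op.clone(b, op.getOperation()->getLoc(), resultTypes, operands);
}
```

This mirrors what the FoldTensorCastOp canonicalization added to LinalgOps.cpp below does when it re-creates the op with cast-free operands.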
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
index fbe735e31cff..409f54384aca 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
@@ -350,6 +350,31 @@ ParseResult parseDimAndSymbolList(OpAsmParser &parser,
/// ```
bool canFoldIntoConsumerOp(MemRefCastOp castOp);
+/// Counterpart of `canFoldIntoConsumerOp(MemRefCastOp castOp)` for tensors.
+/// Determines whether TensorCastOp casts to a more dynamic version of the
+/// source tensor. This is useful to fold a tensor_cast into a consuming op and
+/// implement canonicalization patterns for ops in different dialects that may
+/// consume the results of tensor_cast operations. Such foldable tensor_cast
+/// operations are typically inserted as `subtensor` ops and are canonicalized,
+/// to preserve the type compatibility of their uses.
+///
+/// Returns true when all conditions are met:
+/// 1. source and result are ranked tensors with the same element type and rank.
+/// 2. the source tensor type has more static information than the result type.
+///
+/// Example:
+/// ```mlir
+/// %1 = tensor_cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
+/// %2 = consumer %1 ... : tensor<?x?xf32> ...
+/// ```
+///
+/// folds into:
+///
+/// ```mlir
+/// %2 = consumer %0 ... : tensor<8x16xf32> ...
+/// ```
+bool canFoldIntoConsumerOp(TensorCastOp castOp);
+
/// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
/// comparison predicates.
bool applyCmpPredicate(CmpIPredicate predicate, const APInt &lhs,
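Outside of Linalg, a consumer-side pattern can use the new helper through a small wrapper. The sketch below is hypothetical and not an MLIR API; only `TensorCastOp`, its `source()` accessor (added above) and `canFoldIntoConsumerOp` come from this commit.

```cpp
// Hypothetical convenience wrapper (not part of this commit): if `value` is
// produced by a tensor_cast that only erases static shape information, return
// the cast's source so a consumer pattern can operate on the more static
// operand; otherwise return `value` unchanged.
#include "mlir/Dialect/StandardOps/IR/Ops.h"

using namespace mlir;

static Value skipFoldableTensorCast(Value value) {
  if (auto castOp = value.getDefiningOp<TensorCastOp>())
    if (canFoldIntoConsumerOp(castOp))
      return castOp.source();
  return value;
}
```

This is essentially the per-operand decision that the Linalg FoldTensorCastOp pattern below makes for inputs and init tensors.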
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index 69c979ae9e38..ab7b599dffba 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -3334,7 +3334,7 @@ def TensorCastOp : CastOp<"tensor_cast"> {
```
}];
- let arguments = (ins AnyTensor);
+ let arguments = (ins AnyTensor:$source);
let results = (outs AnyTensor);
let extraClassDeclaration = [{
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index e9cdb3391f4a..26aa75955e3c 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -25,6 +25,7 @@
#include "mlir/IR/StandardTypes.h"
#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MathExtras.h"
@@ -1498,12 +1499,65 @@ struct EraseDeadLinalgOp : public RewritePattern {
return failure();
}
};
+
+struct FoldTensorCastOp : public RewritePattern {
+ FoldTensorCastOp(PatternBenefit benefit = 1)
+ : RewritePattern(benefit, MatchAnyOpTypeTag()) {}
+
+ LogicalResult matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const override {
+ auto linalgOp = dyn_cast<LinalgOp>(op);
+ if (!linalgOp)
+ return failure();
+
+ // If no operand comes from a TensorCastOp that can be folded, fail.
+ bool hasTensorCastOperand =
+ llvm::any_of(linalgOp.getShapedOperands(), [&](Value v) {
+ if (v.isa<BlockArgument>())
+ return false;
+ auto castOp = v.getDefiningOp<TensorCastOp>();
+ return castOp && canFoldIntoConsumerOp(castOp);
+ });
+ if (!hasTensorCastOperand)
+ return failure();
+
+ SmallVector<Type, 4> newResultTypes;
+ newResultTypes.reserve(op->getNumResults());
+ SmallVector<Value, 4> newOperands;
+ newOperands.reserve(op->getNumOperands());
+ // Inputs may fold.
+ for (Value v : linalgOp.getInputs()) {
+ auto tensorCastOp = v.getDefiningOp<TensorCastOp>();
+ newOperands.push_back(
+ canFoldIntoConsumerOp(tensorCastOp) ? tensorCastOp.source() : v);
+ }
+ // Output buffers are memrefs, they don't fold.
+ newOperands.append(linalgOp.getOutputBuffers().begin(),
+ linalgOp.getOutputBuffers().end());
+ // Init tensors may fold, in which case the resultType must also change.
+ for (Value v : linalgOp.getInitTensors()) {
+ auto tensorCastOp = v.getDefiningOp<TensorCastOp>();
+ bool fold = canFoldIntoConsumerOp(tensorCastOp);
+ newOperands.push_back(fold ? tensorCastOp.getOperand() : v);
+ newResultTypes.push_back(newOperands.back().getType());
+ }
+ auto extraOperands = linalgOp.getAssumedNonShapedOperands();
+ newOperands.append(extraOperands.begin(), extraOperands.end());
+ // Clone op.
+ Operation *newOp =
+ linalgOp.clone(rewriter, op->getLoc(), newResultTypes, newOperands);
+ rewriter.replaceOp(op, newOp->getResults());
+
+ return success();
+ }
+};
} // namespace
#define CANONICALIZERS_AND_FOLDERS(XXX) \
void XXX::getCanonicalizationPatterns(OwningRewritePatternList &results, \
MLIRContext *context) { \
results.insert<EraseDeadLinalgOp>(); \
+ results.insert<FoldTensorCastOp>(); \
} \
\
LogicalResult XXX::fold(ArrayRef<Attribute>, \
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index a4d739135aea..f2823c564cce 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -3157,6 +3157,60 @@ bool mlir::canFoldIntoConsumerOp(MemRefCastOp castOp) {
return true;
}
+/// Counterpart of `canFoldIntoConsumerOp(MemRefCastOp castOp)` for tensors.
+/// Determines whether TensorCastOp casts to a more dynamic version of the
+/// source tensor. This is useful to fold a tensor_cast into a consuming op and
+/// implement canonicalization patterns for ops in different dialects that may
+/// consume the results of tensor_cast operations. Such foldable tensor_cast
+/// operations are typically inserted as `subtensor` ops and are canonicalized,
+/// to preserve the type compatibility of their uses.
+///
+/// Returns true when all conditions are met:
+/// 1. source and result are ranked tensors with the same element type and rank.
+/// 2. the source tensor type has more static information than the result type.
+///
+/// Example:
+/// ```mlir
+/// %1 = tensor_cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
+/// %2 = consumer %1 ... : tensor<?x?xf32> ...
+/// ```
+///
+/// folds into:
+///
+/// ```mlir
+/// %2 = consumer %0 ... : tensor<8x16xf32> ...
+/// ```
+bool mlir::canFoldIntoConsumerOp(TensorCastOp castOp) {
+ if (!castOp)
+ return false;
+
+ RankedTensorType sourceType =
+ castOp.source().getType().dyn_cast<RankedTensorType>();
+ RankedTensorType resultType = castOp.getType().dyn_cast<RankedTensorType>();
+
+ // Requires RankedTensorType.
+ if (!sourceType || !resultType)
+ return false;
+
+ // Requires same elemental type.
+ if (sourceType.getElementType() != resultType.getElementType())
+ return false;
+
+ // Requires same rank.
+ if (sourceType.getRank() != resultType.getRank())
+ return false;
+
+ // If cast is towards more static sizes along any dimension, don't fold.
+ for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
+ auto ss = std::get<0>(it), st = std::get<1>(it);
+ if (ss != st)
+ if (ShapedType::isDynamic(ss) && !ShapedType::isDynamic(st))
+ return false;
+ }
+
+ return true;
+}
+
namespace {
/// Pattern to rewrite a subview op with MemRefCast arguments.
/// This essentially pushes memref_cast past its consuming subview when
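The per-dimension rule implemented by the loop in `canFoldIntoConsumerOp` above can be stated in isolation. The helper below is a sketch for illustration only (not part of the commit), reusing the ShapedType dynamic-size convention.

```cpp
// Sketch (not part of this commit): a source/result dimension pair permits
// folding unless a dynamic source size faces a static result size, i.e. the
// cast may only erase static shape information, never introduce it.
#include "mlir/IR/StandardTypes.h"

static bool dimPairAllowsFold(int64_t sourceSize, int64_t resultSize) {
  if (sourceSize == resultSize)
    return true;
  return !(mlir::ShapedType::isDynamic(sourceSize) &&
           !mlir::ShapedType::isDynamic(resultSize));
}
```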
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 5e0890fa4bb5..cf86a97f4fcd 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -259,3 +259,23 @@ func @reshape_splat_constant_float64() -> tensor<2x4x2xf64>
// CHECK: %[[CST:.*]] = constant dense<{{.*}}> : tensor<2x4x2xf64>
// CHECK-NOT: linalg.tensor_reshape
// CHECK: return %[[CST]]
+
+// -----
+
+// CHECK-LABEL: func @tensor_cast(
+func @tensor_cast(%a : tensor<3x4xf32>, %b : tensor<4x?xf32>, %c : tensor<3x?xf32>)
+ -> tensor<3x?xf32>
+{
+ %ta = tensor_cast %a : tensor<3x4xf32> to tensor<?x?xf32>
+ %tb = tensor_cast %b : tensor<4x?xf32> to tensor<?x?xf32>
+ %tc = tensor_cast %c : tensor<3x?xf32> to tensor<?x?xf32>
+
+ // CHECK: linalg.matmul ins({{.*}}tensor<3x4xf32>, tensor<4x?xf32>)
+ // CHECK-SAME: init({{.*}}tensor<3x?xf32>) -> tensor<3x?xf32>
+ %0 = linalg.matmul ins(%ta, %tb: tensor<?x?xf32>, tensor<?x?xf32>)
+ init(%tc: tensor<?x?xf32>) -> tensor<?x?xf32>
+
+ %1 = tensor_cast %0 : tensor<?x?xf32> to tensor<3x?xf32>
+
+ return %1: tensor<3x?xf32>
+}