[Mlir-commits] [mlir] 193cefd - [mlir][tensor] Adapt FoldTensorCastProducerOp pattern on DPS interface.
Hanhan Wang
llvmlistbot at llvm.org
Tue Dec 6 12:13:44 PST 2022
Author: Hanhan Wang
Date: 2022-12-06T12:13:37-08:00
New Revision: 193cefd1b1f25eec3ea9e9ebd3faaa2e16caabc0
URL: https://github.com/llvm/llvm-project/commit/193cefd1b1f25eec3ea9e9ebd3faaa2e16caabc0
DIFF: https://github.com/llvm/llvm-project/commit/193cefd1b1f25eec3ea9e9ebd3faaa2e16caabc0.diff
LOG: [mlir][tensor] Adapt FoldTensorCastProducerOp pattern on DPS interface.
This revision adapts the pattern in Linalg to work on the DPS
(destination-passing style) interface and adds it to the tensor
dialect's canonicalization patterns. InsertSliceOp is skipped in the
pattern because it has its own logic for folding tensor.cast ops. An
illustrative before/after sketch follows below.
Reviewed By: pifon2a
Differential Revision: https://reviews.llvm.org/D139375
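For context, a minimal before/after sketch of the generalized folding
(linalg.fill stands in for any DestinationStyleOpInterface op; the values
and types are illustrative, not taken from this commit):

```mlir
// Before: the init operand is routed through a cast to a less static type.
%c = tensor.cast %init : tensor<8x16xf32> to tensor<?x?xf32>
%0 = linalg.fill ins(%cst : f32) outs(%c : tensor<?x?xf32>) -> tensor<?x?xf32>

// After: the cast is folded into the consumer, which now produces the more
// static type; a cast back to the original result type is inserted so that
// existing users remain valid.
%1 = linalg.fill ins(%cst : f32) outs(%init : tensor<8x16xf32>) -> tensor<8x16xf32>
%2 = tensor.cast %1 : tensor<8x16xf32> to tensor<?x?xf32>
```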
Added:
Modified:
mlir/include/mlir/Dialect/Tensor/IR/TensorBase.td
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/test/Dialect/Tensor/tiling.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorBase.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorBase.td
index 1c380bdc1717f..fe49f8db9810d 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorBase.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorBase.td
@@ -45,6 +45,7 @@ def Tensor_Dialect : Dialect {
}];
+ let hasCanonicalizer = 1;
let hasConstantMaterializer = 1;
let dependentDialects = [
"AffineDialect",
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index a6f42c9577ef7..98b1406d98482 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1745,61 +1745,6 @@ struct EraseDeadLinalgOp : public OpInterfaceRewritePattern<LinalgOp> {
}
};
-struct FoldTensorCastProducerOp : public OpInterfaceRewritePattern<LinalgOp> {
- using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
-
- LogicalResult matchAndRewrite(LinalgOp op,
- PatternRewriter &rewriter) const override {
- // If no operand comes from a tensor::CastOp and can be folded then fail.
- bool hasTensorCastOperand =
- llvm::any_of(op->getOpOperands(), [&](OpOperand &opOperand) {
- if (opOperand.get().isa<BlockArgument>())
- return false;
- auto castOp = opOperand.get().getDefiningOp<tensor::CastOp>();
- return castOp && canFoldIntoConsumerOp(castOp);
- });
- if (!hasTensorCastOperand)
- return failure();
-
- SmallVector<Type, 4> newResultTypes;
- newResultTypes.reserve(op->getNumResults());
- SmallVector<Value, 4> newOperands;
- newOperands.reserve(op->getNumOperands());
- // Inputs may fold.
- for (auto *input : op.getDpsInputOperands()) {
- auto tensorCastOp = input->get().getDefiningOp<tensor::CastOp>();
- newOperands.push_back(canFoldIntoConsumerOp(tensorCastOp)
- ? tensorCastOp.getSource()
- : input->get());
- }
- // Init tensors may fold, in which case the resultType must also change.
- for (auto *output : op.getDpsInitOperands()) {
- auto tensorCastOp = output->get().getDefiningOp<tensor::CastOp>();
- bool fold = canFoldIntoConsumerOp(tensorCastOp);
- newOperands.push_back(fold ? tensorCastOp.getOperand() : output->get());
- if (!newOperands.back().getType().isa<MemRefType>())
- newResultTypes.push_back(newOperands.back().getType());
- }
- // Clone op.
- Operation *newOp = clone(rewriter, op, newResultTypes, newOperands);
- SmallVector<Value, 4> replacements;
- replacements.reserve(newOp->getNumResults());
- for (auto result : llvm::zip(op->getResults(), newOp->getResults())) {
- Value oldResult = std::get<0>(result);
- Value newResult = std::get<1>(result);
- if (newResult.getType() != oldResult.getType()) {
- replacements.push_back(rewriter.create<tensor::CastOp>(
- op->getLoc(), oldResult.getType(), newResult));
- } else {
- replacements.push_back(newResult);
- }
- }
- rewriter.replaceOp(op, replacements);
-
- return success();
- }
-};
-
/// Fold LinalgOps with `tensor.cast` consumer if the `tensor.cast` has a
/// result that is more static than the linalg op.
struct FoldTensorCastConsumerOp : public OpRewritePattern<tensor::CastOp> {
@@ -2023,8 +1968,7 @@ struct InferStaticShapeOfOperands : public OpInterfaceRewritePattern<LinalgOp> {
void LinalgDialect::getCanonicalizationPatterns(
RewritePatternSet &results) const {
results.add<EraseDeadLinalgOp, FoldTensorCastConsumerOp,
- FoldTensorCastProducerOp, InferStaticShapeOfOperands>(
- getContext());
+ InferStaticShapeOfOperands>(getContext());
}
Operation *LinalgDialect::materializeConstant(OpBuilder &builder,
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 4e95243c20ed2..0b24149a3e98a 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -3433,6 +3433,89 @@ LogicalResult UnPackOp::canonicalize(UnPackOp unpackOp,
return success();
}
+//===----------------------------------------------------------------------===//
+// Common Canonicalizers and Folders.
+//===----------------------------------------------------------------------===//
+
+/// Folds a tensor.cast op into a consuming DestinationStyleOpInterface op if
+/// the `tensor.cast` has a source that is more static than the consuming op.
+///
+/// Example:
+/// ```mlir
+/// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
+/// %2 = consumer %1 ... : tensor<?x?xf32> ...
+/// ```
+///
+/// folds into:
+///
+/// ```mlir
+/// %2 = consumer %0 ... : tensor<8x16xf32> ...
+/// ```
+/// TODO: Move the pattern to a proper place, so all other DestinationStyleOp
+/// implementations can add the pattern to their canonicalizers.
+struct FoldTensorCastProducerOp
+ : public OpInterfaceRewritePattern<DestinationStyleOpInterface> {
+ using OpInterfaceRewritePattern<
+ DestinationStyleOpInterface>::OpInterfaceRewritePattern;
+
+ LogicalResult matchAndRewrite(DestinationStyleOpInterface op,
+ PatternRewriter &rewriter) const override {
+ // InsertSliceOp has its own logic about folding tensor.cast ops.
+ if (isa<InsertSliceOp>(op.getOperation()))
+ return failure();
+
+ // Fail if no operand comes from a tensor::CastOp that can be folded.
+ bool hasTensorCastOperand =
+ llvm::any_of(op->getOpOperands(), [&](OpOperand &opOperand) {
+ if (opOperand.get().isa<BlockArgument>())
+ return false;
+ auto castOp = opOperand.get().getDefiningOp<tensor::CastOp>();
+ return castOp && canFoldIntoConsumerOp(castOp);
+ });
+ if (!hasTensorCastOperand)
+ return failure();
+
+ SmallVector<Type, 4> newResultTypes;
+ newResultTypes.reserve(op->getNumResults());
+ SmallVector<Value, 4> newOperands;
+ newOperands.reserve(op->getNumOperands());
+ for (OpOperand &opOperand : op->getOpOperands()) {
+ auto tensorCastOp = opOperand.get().getDefiningOp<tensor::CastOp>();
+ bool fold = canFoldIntoConsumerOp(tensorCastOp);
+ newOperands.push_back(fold ? tensorCastOp.getOperand() : opOperand.get());
+ if (op.isDpsInit(&opOperand) &&
+ !newOperands.back().getType().isa<MemRefType>())
+ newResultTypes.push_back(newOperands.back().getType());
+ }
+
+ // Clone op.
+ Operation *newOp = clone(rewriter, op, newResultTypes, newOperands);
+ SmallVector<Value, 4> replacements;
+ replacements.reserve(newOp->getNumResults());
+ for (auto [oldResult, newResult] :
+ llvm::zip(op->getResults(), newOp->getResults())) {
+ if (newResult.getType() != oldResult.getType()) {
+ replacements.push_back(rewriter.create<tensor::CastOp>(
+ op->getLoc(), oldResult.getType(), newResult));
+ } else {
+ replacements.push_back(newResult);
+ }
+ }
+ rewriter.replaceOp(op, replacements);
+
+ return success();
+ }
+};
+
+//===----------------------------------------------------------------------===//
+// TensorDialect
+//===----------------------------------------------------------------------===//
+
+void TensorDialect::getCanonicalizationPatterns(
+ RewritePatternSet &results) const {
+ results.add<FoldTensorCastProducerOp>(getContext());
+}
+
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
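Note the early bail-out on InsertSliceOp in the pattern above:
tensor.insert_slice already has its own cast-folding canonicalizations, so
the generic DPS rewrite is skipped for it to avoid competing patterns. A
sketch of the kind of IR left to that dedicated logic (illustrative values
and types, not from this commit):

```mlir
// Handled by insert_slice's own canonicalization rather than by
// FoldTensorCastProducerOp:
%c = tensor.cast %src : tensor<4x4xf32> to tensor<?x?xf32>
%r = tensor.insert_slice %c into %dest[0, 0] [4, 4] [1, 1]
    : tensor<?x?xf32> into tensor<8x8xf32>
```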
diff --git a/mlir/test/Dialect/Tensor/tiling.mlir b/mlir/test/Dialect/Tensor/tiling.mlir
index 612367916bb4c..1f87101131b44 100644
--- a/mlir/test/Dialect/Tensor/tiling.mlir
+++ b/mlir/test/Dialect/Tensor/tiling.mlir
@@ -18,11 +18,9 @@
// CHECK-DAG: %[[IN_C_SZ:.*]] = affine.min #[[MAP2]]
// CHECK: %[[SUB_IN:.*]] = tensor.extract_slice %[[IN]][%[[IN_N]], %[[IN_C]]] [%[[IN_N_SZ]], %[[IN_C_SZ]]] [1, 1] : tensor<128x256xf32> to tensor<?x?xf32>
// CHECK: %[[SUB_OUT:.*]] = tensor.extract_slice %[[ITER1]][%[[N]], %[[C]], 0, 0] [2, 4, 32, 32] [1, 1, 1, 1] : tensor<4x8x32x32xf32> to tensor<2x4x32x32xf32>
-// CHECK: %[[CAST_OUT:.*]] = tensor.cast %[[SUB_OUT]]
// CHECK: %[[SUB_RES:.*]] = tensor.pack
-// CHECK-SAME: %[[SUB_IN]] inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %[[CAST_OUT]]
-// CHECK: %[[CAST_RES:.*]] = tensor.cast %[[SUB_RES]]
-// CHECK: %[[INSERT:.*]] = tensor.insert_slice %[[CAST_RES]] into %[[ITER1]]
+// CHECK-SAME: %[[SUB_IN]] inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %[[SUB_OUT]]
+// CHECK: %[[INSERT:.*]] = tensor.insert_slice %[[SUB_RES]] into %[[ITER1]]
// CHECK: scf.yield %[[INSERT]] : tensor<4x8x32x32xf32>
// CHECK: }
// CHECK: scf.yield %[[RES1:.*]] : tensor<4x8x32x32xf32>
@@ -55,12 +53,10 @@ transform.sequence failures(propagate) {
// CHECK-DAG: %[[IN_C_SZ:.+]] = affine.min #[[MAP1]](%[[C]])
// CHECK: %[[INPUT_SLICE:.+]] = tensor.extract_slice %[[IN]]
// CHECK-SAME: [0, %[[IN_C]]] [128, %[[IN_C_SZ]]]
-// CHECK: %[[CAST_IN:.+]] = tensor.cast %[[INPUT_SLICE]]
// CHECK: %[[OUTPUT_SLICE:.+]] = tensor.extract_slice %{{.+}}[%[[C]], 0, 0, 0] [2, 4, 32, 8]
-// CHECK: %[[CAST_OUT:.+]] = tensor.cast %[[OUTPUT_SLICE]]
// CHECK: tensor.pack
-// CHECK-SAME: %[[CAST_IN]] outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 8]
-// CHECK-SAME: into %[[CAST_OUT]]
+// CHECK-SAME: %[[INPUT_SLICE]] outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 8]
+// CHECK-SAME: into %[[OUTPUT_SLICE]]
func.func @KC_to_CKkc(%arg0: tensor<128x256xf32>, %arg1: tensor<32x4x32x8xf32>) -> tensor<32x4x32x8xf32> {
%0 = tensor.pack %arg0 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 8] into %arg1 : tensor<128x256xf32> -> tensor<32x4x32x8xf32>
return %0 : tensor<32x4x32x8xf32>
@@ -87,14 +83,11 @@ transform.sequence failures(propagate) {
// CHECK-DAG: %[[IN_J:.*]] = affine.apply #[[MAP0]](%[[J]])
// CHECK-DAG: %[[IN_J_SZ:.*]] = affine.min #[[MAP1]](%[[J]])
// CHECK: %[[SUB_IN:.*]] = tensor.extract_slice %[[IN]][0, %[[IN_J]]] [13, %[[IN_J_SZ]]] [1, 1]
-// CHECK: %[[CAST_IN:.*]] = tensor.cast %[[SUB_IN]]
// CHECK: %[[SUB_OUT:.*]] = tensor.extract_slice %[[ITER1]][0, %[[J]], 0, 0] [2, 4, 8, 2] [1, 1, 1, 1]
-// CHECK: %[[CAST_OUT:.*]] = tensor.cast %[[SUB_OUT]]
// CHECK: %[[SUB_RES:.*]] = tensor.pack
-// CHECK-SAME: %[[CAST_IN]] padding_value(%[[PAD]] : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 2]
-// CHECK-SAME: into %[[CAST_OUT]]
-// CHECK: %[[CAST_RES:.*]] = tensor.cast %[[SUB_RES]]
-// CHECK: %[[INSERT:.*]] = tensor.insert_slice %[[CAST_RES]] into %[[ITER1]]
+// CHECK-SAME: %[[SUB_IN]] padding_value(%[[PAD]] : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 2]
+// CHECK-SAME: into %[[SUB_OUT]]
+// CHECK: %[[INSERT:.*]] = tensor.insert_slice %[[SUB_RES]] into %[[ITER1]]
// CHECK: scf.yield %[[INSERT]] : tensor<2x8x8x2xf32>
// CHECK: }
// CHECK: return %[[RES0:.*]] : tensor<2x8x8x2xf32>
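The test updates above reflect the new behavior: the casts that used to
bracket the tiled tensor.pack are folded away, so pack consumes the
extracted slices directly. Roughly (a sketch with illustrative types; the
actual CHECK lines elide them):

```mlir
// Previously emitted: a cast round-trip around the packed tiles (plus a
// cast of the result back before the insert_slice).
%ci = tensor.cast %sub_in : tensor<128x?xf32> to tensor<?x?xf32>
%co = tensor.cast %sub_out : tensor<2x4x32x8xf32> to tensor<?x?x32x8xf32>
%p  = tensor.pack %ci outer_dims_perm = [1, 0] inner_dims_pos = [0, 1]
      inner_tiles = [32, 8] into %co : tensor<?x?xf32> -> tensor<?x?x32x8xf32>

// Now emitted: the casts are folded into the pack itself.
%p2 = tensor.pack %sub_in outer_dims_perm = [1, 0] inner_dims_pos = [0, 1]
      inner_tiles = [32, 8] into %sub_out
      : tensor<128x?xf32> -> tensor<2x4x32x8xf32>
```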