[Mlir-commits] [mlir] 589eac6 - [mlir] Add canonicalizations for op -> tensor.cast folding.
Mahesh Ravishankar
llvmlistbot at llvm.org
Tue Mar 8 10:27:16 PST 2022
Author: Mahesh Ravishankar
Date: 2022-03-08T18:26:55Z
New Revision: 589eac6524d6ba080a51107757c1f356c365d047
URL: https://github.com/llvm/llvm-project/commit/589eac6524d6ba080a51107757c1f356c365d047
DIFF: https://github.com/llvm/llvm-project/commit/589eac6524d6ba080a51107757c1f356c365d047.diff
LOG: [mlir] Add canonicalizations for op -> tensor.cast folding.
A `tensor.cast` consumer can be folded with its producer. This is
beneficial only if the result of the tensor cast is more static than
the source. This patch adds a utility function to check that this is
the case, and adds a couple of canonicalization patterns that fold an
operation with its `tensor.cast` consumers.
Reviewed By: gysit
Differential Revision: https://reviews.llvm.org/D120950
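For illustration, here is a minimal before/after sketch of the two folds this patch adds, distilled from the doc comments and tests in the diff below (the SSA value names are illustrative):

```mlir
// `linalg.init_tensor` producer: the static sizes of the cast result are
// folded directly into the producer.
%0 = linalg.init_tensor [%d0, %d1] : tensor<?x?xf32>
%1 = tensor.cast %0 : tensor<?x?xf32> to tensor<4x?xf32>
// ... becomes ...
%1 = linalg.init_tensor [4, %d1] : tensor<4x?xf32>

// General `LinalgOp` producer: the cast is moved onto the `outs` operand, so
// the op itself produces the more static type.
%0 = linalg.matmul ins(%a, %b : tensor<?x?xf32>, tensor<?x?xf32>)
                   outs(%c : tensor<?x?xf32>) -> tensor<?x?xf32>
%1 = tensor.cast %0 : tensor<?x?xf32> to tensor<4x8xf32>
// ... becomes ...
%c_static = tensor.cast %c : tensor<?x?xf32> to tensor<4x8xf32>
%1 = linalg.matmul ins(%a, %b : tensor<?x?xf32>, tensor<?x?xf32>)
                   outs(%c_static : tensor<4x8xf32>) -> tensor<4x8xf32>
```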
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/test/Dialect/Linalg/canonicalize.mlir
mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index 0f896df15119a..be6f8c323cbe7 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -85,6 +85,9 @@ def Linalg_InitTensorOp : Linalg_Op<"init_tensor",
});
}
+ // Return both static and dynamic sizes as a list of `OpFoldResult`.
+ SmallVector<OpFoldResult> getMixedSizes();
+
// Return the Value of the dynamic size of the tensor at dimension
// `idx`. Asserts that the shape is dynamic at that `idx`.
Value getDynamicSize(unsigned idx) {
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
index 7b0d62e502e58..e623b8cb03116 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
+++ b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
@@ -85,6 +85,28 @@ bool preservesStaticInformation(Type source, Type target);
/// ```
bool canFoldIntoConsumerOp(CastOp castOp);
+/// Determines whether the tensor::CastOp casts to a more static version of the
+/// source tensor. This is useful to fold into a producing op and implement
+/// canonicalization patterns with the `tensor.cast` op as the root, but the
+/// producer being from different dialects. Returns true when all conditions
+/// are met:
+/// 1. source and result are ranked tensors with the same element type and rank.
+/// 2. the result type has more static information than the source.
+///
+/// Example:
+/// ```mlir
+/// %1 = producer ... : tensor<?x?xf32>
+/// %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<8x16xf32>
+/// ```
+///
+/// can be canonicalized to:
+///
+/// ```mlir
+/// %2 = producer ... : tensor<8x16xf32>
+/// ```
+/// Not all ops can be canonicalized this way, but for those that can be, this
+/// method provides a check that the canonicalization is worth doing.
+bool canFoldIntoProducerOp(CastOp castOp);
+
/// Performs folding of any operand of `op` if it comes from a tensor::CastOp
/// that can be folded.
LogicalResult foldTensorCast(Operation *op);
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 53ff45a531049..010695172518c 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1072,6 +1072,21 @@ Type InitTensorOp::inferResultType(ArrayRef<int64_t> staticSizes,
return RankedTensorType::get(staticSizes, elementType, encoding);
}
+SmallVector<OpFoldResult> InitTensorOp::getMixedSizes() {
+ SmallVector<OpFoldResult> mixedSizes;
+ mixedSizes.reserve(getType().getRank());
+ unsigned dynamicValIndex = 0;
+ for (Attribute attr : static_sizes()) {
+ auto intAttr = attr.cast<IntegerAttr>();
+ if (!ShapedType::isDynamic(intAttr.getInt())) {
+ mixedSizes.push_back(intAttr);
+ continue;
+ }
+ mixedSizes.push_back(sizes()[dynamicValIndex++]);
+ }
+ return mixedSizes;
+}
+
namespace {
/// Change the type of the result of a `linalg.init_tensor` by making the result
/// type statically sized along dimension that in the original operation where
@@ -1193,11 +1208,86 @@ struct FoldInitTensorWithDimOp : public OpRewritePattern<tensor::DimOp> {
return success();
}
};
+
+/// Canonicalize
+///
+/// ```mlir
+/// %0 = linalg.init_tensor [%d0, %d1] : tensor<?x?xf32>
+/// %1 = tensor.cast %0 : tensor<?x?xf32> to tensor<4x?xf32>
+/// ```
+///
+/// into
+///
+/// ```mlir
+/// %0 = linalg.init_tensor [4, %d1] : tensor<4x?xf32>
+/// ```
+///
+/// This assumes the input program is correct in terms of its shape. So it
+/// is safe to assume that `%d0` is in fact 4. If that were not the case, the
+/// input program is wrong to begin with, so it is undefined behavior anyway
+/// (i.e. this optimization can still trigger without violating program
+/// semantics).
+struct FoldInitTensorWithTensorCastOp
+ : public OpRewritePattern<tensor::CastOp> {
+ using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(tensor::CastOp castOp,
+ PatternRewriter &rewriter) const override {
+ if (!canFoldIntoProducerOp(castOp))
+ return failure();
+ auto producer = castOp.source().getDefiningOp<InitTensorOp>();
+ if (!producer)
+ return failure();
+
+ auto resultType = castOp->getResult(0).getType().cast<RankedTensorType>();
+ ArrayRef<int64_t> resultShape = resultType.getShape();
+ SmallVector<OpFoldResult> currMixedSizes = producer.getMixedSizes();
+ SmallVector<OpFoldResult> newMixedSizes;
+ newMixedSizes.reserve(currMixedSizes.size());
+ assert(resultShape.size() == currMixedSizes.size() &&
+ "mismatch in result shape and sizes of init_tensor op");
+ for (auto it : llvm::zip(resultShape, currMixedSizes)) {
+ int64_t newDim = std::get<0>(it);
+ OpFoldResult currDim = std::get<1>(it);
+ // Case 1: The init tensor dim is static. Check that the tensor cast
+ // result dim matches.
+ if (auto attr = currDim.dyn_cast<Attribute>()) {
+ if (ShapedType::isDynamic(newDim) ||
+ newDim != attr.cast<IntegerAttr>().getInt()) {
+ // Something is off: the cast result shape cannot be more dynamic than
+ // the init tensor result shape (enforced by `canFoldIntoProducerOp`).
+ // Abort for now.
+ return rewriter.notifyMatchFailure(
+ producer, "mismatch in static value of shape of init "
+ "tensor result and cast result");
+ }
+ newMixedSizes.push_back(attr);
+ continue;
+ }
+
+ // Case 2: The tensor cast shape is static, but the init tensor result
+ // shape is dynamic.
+ if (!ShapedType::isDynamic(newDim)) {
+ newMixedSizes.push_back(rewriter.getIndexAttr(newDim));
+ continue;
+ }
+
+ // Case 3: Both the tensor cast shape and the init tensor result shape
+ // are dynamic. Use the dynamic value from the init tensor op.
+ newMixedSizes.push_back(currDim);
+ }
+
+ rewriter.replaceOpWithNewOp<InitTensorOp>(castOp, newMixedSizes,
+ resultType.getElementType());
+ return success();
+ }
+};
+
} // namespace
void InitTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<FoldInitTensorWithDimOp, FoldInitTensorWithExtractSliceOp,
+ results.add<FoldInitTensorWithTensorCastOp, FoldInitTensorWithDimOp,
+ FoldInitTensorWithExtractSliceOp,
FoldInitTensorWithTensorReshapeOp<tensor::ExpandShapeOp>,
FoldInitTensorWithTensorReshapeOp<tensor::CollapseShapeOp>,
ReplaceStaticShapeDims>(context);
@@ -1608,7 +1698,7 @@ struct EraseDeadLinalgOp : public OpInterfaceRewritePattern<LinalgOp> {
}
};
-struct FoldTensorCastOp : public OpInterfaceRewritePattern<LinalgOp> {
+struct FoldTensorCastProducerOp : public OpInterfaceRewritePattern<LinalgOp> {
using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
LogicalResult matchAndRewrite(LinalgOp op,
@@ -1664,6 +1754,63 @@ struct FoldTensorCastOp : public OpInterfaceRewritePattern<LinalgOp> {
}
};
+/// Fold LinalgOps with a `tensor.cast` consumer if the `tensor.cast` result
+/// is more static than the corresponding linalg op result.
+struct FoldTensorCastConsumerOp : public OpRewritePattern<tensor::CastOp> {
+ using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(tensor::CastOp castOp,
+ PatternRewriter &rewriter) const override {
+ if (!tensor::canFoldIntoProducerOp(castOp))
+ return failure();
+ auto linalgOp = castOp.source().getDefiningOp<LinalgOp>();
+ if (!linalgOp)
+ return failure();
+
+ OpBuilder::InsertionGuard guard(rewriter);
+ rewriter.setInsertionPoint(linalgOp);
+
+ Location loc = linalgOp.getLoc();
+ OpResult resultValue = castOp.source().cast<OpResult>();
+ unsigned resultNumber = resultValue.getResultNumber();
+ auto resultType = castOp->getResult(0).getType().cast<RankedTensorType>();
+ // Replace the `outs` for the result with a `tensor.cast`. This cast is now
+ // going from a more dynamic shape to a less dynamic shape. If the producer
+ // for this cast, i.e. the producer of the out operand, is also an
+ // operation that folds with a tensor.cast consumer (like this pattern),
+ // the cast will continue to propagate as far up the stack as it can go.
+ OpOperand *outOperand = linalgOp.getOutputOperand(resultNumber);
+ Value newOperand =
+ rewriter.create<tensor::CastOp>(loc, resultType, outOperand->get());
+ SmallVector<Value> newOperands = linalgOp.getInputOperands();
+ SmallVector<Value> outputOperands = linalgOp.getOutputOperands();
+ outputOperands[resultNumber] = newOperand;
+ newOperands.append(outputOperands.begin(), outputOperands.end());
+
+ SmallVector<Type> resultTypes(linalgOp->result_type_begin(),
+ linalgOp->result_type_end());
+ resultTypes[resultNumber] = resultType;
+ Operation *newOp = linalgOp.clone(rewriter, loc, resultTypes, newOperands);
+
+ if (!resultValue.hasOneUse()) {
+ SmallVector<Value> results(newOp->result_begin(), newOp->result_end());
+ // Create a tensor.cast operation back to the original type.
+ Value castBack = rewriter.create<tensor::CastOp>(
+ loc, resultValue.getType(), newOp->getResult(resultNumber));
+ results[resultNumber] = castBack;
+ // Replace all uses except the use in the cast op that is matched by the
+ // pattern. Note that this cast is from a more static shape to a more
+ // dynamic shape. These are expected to be pulled into their consumers.
+ rewriter.replaceOpWithIf(linalgOp, results,
+ [&castOp](OpOperand &use) -> bool {
+ return use.getOwner() != castOp.getOperation();
+ });
+ }
+ rewriter.replaceOp(castOp, newOp->getResult(resultNumber));
+ return success();
+ }
+};
+
} // namespace
#define LINALGOP_FOLDERS(XXX) \
@@ -1684,7 +1831,8 @@ LINALGOP_FOLDERS(GenericOp)
void LinalgDialect::getCanonicalizationPatterns(
RewritePatternSet &results) const {
- results.add<EraseDeadLinalgOp, FoldTensorCastOp>(getContext());
+ results.add<EraseDeadLinalgOp, FoldTensorCastConsumerOp,
+ FoldTensorCastProducerOp>(getContext());
}
Operation *LinalgDialect::materializeConstant(OpBuilder &builder,
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 0c5d182c82d3e..e93d619bce74e 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -98,6 +98,33 @@ bool mlir::tensor::canFoldIntoConsumerOp(CastOp castOp) {
castOp.source().getType());
}
+/// Determines whether the tensor::CastOp casts to a more static version of the
+/// source tensor. This is useful to fold into a producing op and implement
+/// canonicalization patterns with the `tensor.cast` op as the root, but the
+/// producer being from different dialects. Returns true when all conditions
+/// are met:
+/// 1. source and result are ranked tensors with the same element type and rank.
+/// 2. the result type has more static information than the source.
+///
+/// Example:
+/// ```mlir
+/// %1 = producer ... : tensor<?x?xf32>
+/// %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<8x16xf32>
+/// ```
+///
+/// can be canonicalized to:
+///
+/// ```mlir
+/// %2 = producer ... : tensor<8x16xf32>
+/// ```
+/// Not all ops can be canonicalized this way, but for those that can be, this
+/// method provides a check that the canonicalization is worth doing.
+bool mlir::tensor::canFoldIntoProducerOp(CastOp castOp) {
+ if (!castOp)
+ return false;
+ return preservesStaticInformation(castOp.source().getType(),
+ castOp.getType());
+}
+
/// Performs folding of any operand of `op` if it comes from a tensor::CastOp
/// that can be folded.
LogicalResult mlir::tensor::foldTensorCast(Operation *op) {
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index a2be638fd8078..b24a3e78e32b1 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -244,6 +244,17 @@ func @fold_init_tensor_with_slice
// -----
+func @fold_init_tensor_with_cast(%arg0 : index) -> tensor<1x12xf32> {
+ %0 = linalg.init_tensor [%arg0, 12] : tensor<?x12xf32>
+ %1 = tensor.cast %0 : tensor<?x12xf32> to tensor<1x12xf32>
+ return %1 : tensor<1x12xf32>
+}
+// CHECK: func @fold_init_tensor_with_cast(%[[ARG0:.+]]: index)
+// CHECK: %[[T0:.+]] = linalg.init_tensor [1, 12] : tensor<1x12xf32>
+// CHECK: return %[[T0]] : tensor<1x12xf32>
+
+// -----
+
#accesses = [
affine_map<(i, j) -> (i, j)>
]
@@ -747,3 +758,23 @@ func @multi_insert_pad_into_fill_mismatch(%input: tensor<7x123x124xf32>, %a: ten
%2 = tensor.insert_slice %pad into %1 [0, 0, 256] [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32>
return %2: tensor<8x384x384xf32>
}
+
+// -----
+
+func @fold_linalgop_with_cast_consumer(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
+ %arg2 : tensor<?x?xf32>) -> (tensor<4x8xf32>, tensor<?x?xf32>) {
+ %0 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
+ %1 = tensor.cast %0 : tensor<?x?xf32> to tensor<4x8xf32>
+ return %1, %0 : tensor<4x8xf32>, tensor<?x?xf32>
+}
+// CHECK: func @fold_linalgop_with_cast_consumer(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<?x?xf32>)
+// CHECK: %[[OUT_CAST:.+]] = tensor.cast %[[ARG2]] : tensor<?x?xf32> to tensor<4x8xf32>
+// CHECK: %[[MATMUL:.+]] = linalg.matmul
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] :
+// CHECK-SAME: outs(%[[OUT_CAST]] :
+// CHECK: %[[RESULT_CAST:.+]] = tensor.cast %[[MATMUL]]
+// CHECK: return %[[MATMUL]], %[[RESULT_CAST]]
diff --git a/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir b/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
index ac94261a153f0..716e38a32e03c 100644
--- a/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
+++ b/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
@@ -47,7 +47,7 @@ func @matmul_tensors(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tens
// CHECK: scf.for %[[I:[0-9a-z]*]]
// CHECK: %[[sizeA0:.*]] = affine.min #[[BOUND2_MAP]](%[[I]])[%[[dA0]]]
// CHECK: %[[stA:.*]] = tensor.extract_slice %[[A]][%[[I]], 0] [%[[sizeA0]], %[[dA1]]] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
-// CHECK-NEXT: scf.for %[[J:[0-9a-z]*]]
+// CHECK: scf.for %[[J:[0-9a-z]*]]
// CHECK-NEXT: scf.for %[[K:[0-9a-z]*]] {{.*}} iter_args(%[[RES:[0-9a-z]*]]
// CHECK-DAG: %[[stB1:.*]] = tensor.extract_slice %[[B]][%[[K]], %[[J]]] [4, 3] [1, 1] : tensor<?x?xf32> to tensor<4x3xf32>
// CHECK-DAG: %[[stF:.*]] = tensor.extract_slice %[[RES]][%[[I]], %[[J]]] [2, 3] [1, 1] : tensor<?x?xf32> to tensor<2x3xf32>
@@ -56,9 +56,9 @@ func @matmul_tensors(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tens
// CHECK: %[[sizeB1:.*]] = affine.min #[[BOUND4_MAP]](%[[K]])[%[[dB1]]]
// CHECK: %[[stB2:.*]] = tensor.extract_slice %[[B]][0, %[[K]]] [%[[dB0]], %[[sizeB1]]] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
// CHECK: %[[stC:.*]] = tensor.extract_slice %[[C]][%[[I]], %[[K]]] [%[[sizeA0]], %[[sizeB1]]] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
-// CHECK: %[[stD:.*]] = linalg.matmul ins(%[[stA]], %[[stB2]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[stC]] : tensor<?x?xf32>) -> tensor<?x?xf32>
-// CHECK: %[[CAST:.*]] = tensor.cast %[[stD]] : tensor<?x?xf32> to tensor<2x4xf32>
-// CHECK-NEXT: %[[stG:.*]] = linalg.matmul ins(%[[CAST]], %[[stB1]] : tensor<2x4xf32>, tensor<4x3xf32>) outs(%[[stF]] : tensor<2x3xf32>) -> tensor<2x3xf32>
+// CHECK-DAG: %[[castC:.+]] = tensor.cast %[[stC]] : tensor<?x?xf32> to tensor<2x4xf32>
+// CHECK: %[[stD:.*]] = linalg.matmul ins(%[[stA]], %[[stB2]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[castC]] : tensor<2x4xf32>) -> tensor<2x4xf32>
+// CHECK-NEXT: %[[stG:.*]] = linalg.matmul ins(%[[stD]], %[[stB1]] : tensor<2x4xf32>, tensor<4x3xf32>) outs(%[[stF]] : tensor<2x3xf32>) -> tensor<2x3xf32>
// CHECK-NEXT: tensor.insert_slice %[[stG]] into %[[RES]][%[[I]], %[[J]]]
// -----
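As a rough usage sketch (assuming the standard `mlir-opt` driver with the `-canonicalize` pass; the function name and CHECK lines here are illustrative, not part of the patch), the new folds can be exercised on a standalone file:

```mlir
// RUN: mlir-opt %s -canonicalize -split-input-file | FileCheck %s

func @fold_matmul_with_cast(%a : tensor<?x?xf32>, %b : tensor<?x?xf32>,
                            %c : tensor<?x?xf32>) -> tensor<4x8xf32> {
  %0 = linalg.matmul ins(%a, %b : tensor<?x?xf32>, tensor<?x?xf32>)
                     outs(%c : tensor<?x?xf32>) -> tensor<?x?xf32>
  %1 = tensor.cast %0 : tensor<?x?xf32> to tensor<4x8xf32>
  return %1 : tensor<4x8xf32>
}
// The cast is expected to move onto the `outs` operand, and the matmul then
// directly produces the static type.
// CHECK: tensor.cast %{{.*}} : tensor<?x?xf32> to tensor<4x8xf32>
// CHECK: linalg.matmul
// CHECK-SAME: outs(%{{.*}} : tensor<4x8xf32>)
```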