[Mlir-commits] [mlir] [MLIR][Tensor] Fix DPS op canonicalizer with `tensor.cast` (PR #91382)
Christopher Bate
llvmlistbot at llvm.org
Thu May 9 08:09:01 PDT 2024
================
@@ -158,6 +160,76 @@ Operation *cloneWithoutRegions(OpBuilder &b, Operation *op,
SmallVector<NamedAttribute>
getPrunedAttributeList(Operation *op, ArrayRef<StringRef> elidedAttrs);
+/// Folds cast-like operations into a consuming DestinationStyleOpInterface op
+/// if `isPreservingCast` returns true for the cast. If the cast feeds a
+/// 'DPS-init operand', the tied result type is also updated to the type of the
+/// cast source, and a new cast back to the original type must be inserted on
+/// the new op's result. `createCast` is used to build such required cast ops.
+///
+/// ### Example
+/// If `isPreservingCast` returns true when the cast is a "generalizing"
+/// `tensor.cast`, then this function behaves as follows:
+///
+/// ```mlir
+/// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
+/// %2 = dps_op %1 ... : tensor<?x?xf32> ...
+/// ```
+///
+/// folds into:
+///
+/// ```mlir
+/// %2 = dps_op %0 ... : tensor<8x16xf32> ...
+/// ```
+LogicalResult foldCastProducers(
+ RewriterBase &rewriter, DestinationStyleOpInterface consumerOp,
+ llvm::function_ref<bool(Operation *)> isPreservingCast,
+ llvm::function_ref<Value(RewriterBase &rewriter, Type originalType,
+ Value replacement)>
+ createCast);
+
+/// Folds `tensor.cast` ops into a consuming DestinationStyleOpInterface op
+/// if the casts make their operands less static. See also isPreservingCast
+/// above.
+template <typename CastOpType>
+LogicalResult foldCastProducers(DestinationStyleOpInterface op,
+ RewriterBase &rewriter) {
+ return foldCastProducers(
+ rewriter, op,
+ [](Operation *castOp) -> bool {
+ auto concreteCast = dyn_cast<CastOpType>(castOp);
+ if (!concreteCast)
+ return false;
+ RankedTensorType resultType =
+ dyn_cast<RankedTensorType>(concreteCast.getType());
+ RankedTensorType sourceType =
+ dyn_cast<RankedTensorType>(concreteCast->getOperand(0).getType());
+ if (!resultType || !sourceType)
+ return false;
+ return resultType.isGeneralizationOf(sourceType);
+ },
+ [](RewriterBase &rewriter, Type resultType, Value operand) -> Value {
+ return rewriter.create<CastOpType>(operand.getLoc(), resultType,
+ operand);
+ });
+}
+
+/// A generic pattern for an Operation type that implements
+/// DestinationStyleOpInterface, allowing for absorbing cast-like operations
+/// that are producers of operands.
+template <typename OpType, typename CastOpType>
+struct FoldTensorCastIntoConsumerPattern : public OpRewritePattern<OpType> {
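The templated overload above accepts a cast only when the cast's result type `isGeneralizationOf` its source type. The diff does not spell out that method's exact rules, so the following is only a standalone approximation, not MLIR code. `RankedTensor` and `kDynamic` are hypothetical stand-ins (MLIR's `ShapedType::kDynamic` plays a similar role). The sketch also assumes that element types must match and that encodings can be ignored.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <string>
#include <vector>

// Hypothetical stand-in for a ranked tensor type such as tensor<8x16xf32>.
// An extent equal to kDynamic plays the role of `?` in tensor<?x16xf32>.
constexpr std::int64_t kDynamic = std::numeric_limits<std::int64_t>::min();

struct RankedTensor {
  std::vector<std::int64_t> shape;
  std::string elementType;
};

// Approximates the check used by the templated `foldCastProducers` overload:
// a cast from `source` to `result` is foldable only if `result` is a
// generalization of `source`. Here that means the same element type, the same
// rank, and every extent of `result` is either dynamic or equal to the
// corresponding extent of `source`. Identical types count as a generalization.
bool isGeneralizationOf(const RankedTensor &result,
                        const RankedTensor &source) {
  if (result.elementType != source.elementType ||
      result.shape.size() != source.shape.size())
    return false;
  for (std::size_t i = 0; i < result.shape.size(); ++i) {
    // A static extent in the result must match the source exactly; a dynamic
    // extent accepts anything.
    if (result.shape[i] != kDynamic && result.shape[i] != source.shape[i])
      return false;
  }
  return true;
}
```

Under these assumptions, the `tensor<8x16xf32> to tensor<?x?xf32>` cast from the doc comment is accepted. The reverse cast, which adds static information, is rejected, so folding it would change the consumer's operand type in a way the op may not accept.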
----------------
christopherbate wrote:
We can duplicate it in dialects where it is used. It's not much code, so that would seem fine. I do recall seeing templated patterns like this in ReshapeOpsUtils.h though.
https://github.com/llvm/llvm-project/pull/91382
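The general `foldCastProducers` overload in the quoted diff splits the fold into a policy, `isPreservingCast`, and the rewrite mechanics. The following is a toy, non-MLIR sketch of those mechanics, meant only to illustrate the contract described in the doc comment. `ToyOperand`, `ToyDpsOp`, `CastPredicate`, and `PendingCast` are hypothetical names. The sketch also departs from the real API: instead of calling a `createCast` callback, it returns the casts that would have to be created.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <optional>
#include <utility>
#include <vector>

// Hypothetical toy IR, not MLIR. How extents are interpreted (for example,
// which value marks a dynamic dimension) is left entirely to the predicate.
using Shape = std::vector<std::int64_t>;

// An operand of the consumer op. If it is produced by a cast-like op,
// `castSource` holds the type of that cast's input.
struct ToyOperand {
  Shape type;
  std::optional<Shape> castSource;
};

// A destination-style op: resultTypes[i] is tied to inits[i].
struct ToyDpsOp {
  std::vector<ToyOperand> inputs;
  std::vector<ToyOperand> inits;
  std::vector<Shape> resultTypes;
};

// Plays the role of `isPreservingCast`: (castResultType, castSourceType).
using CastPredicate = std::function<bool(const Shape &, const Shape &)>;

// A cast that must follow the new op: the result index and the type that
// existing users of that result still expect.
using PendingCast = std::pair<std::size_t, Shape>;

// Makes `operand` consume the cast's input directly if the predicate accepts
// the cast. Returns true if the operand changed.
static bool absorbCast(ToyOperand &operand,
                       const CastPredicate &isPreservingCast) {
  if (!operand.castSource ||
      !isPreservingCast(operand.type, *operand.castSource))
    return false;
  operand.type = *operand.castSource;
  operand.castSource.reset();
  return true;
}

// Returns std::nullopt when nothing was folded, mirroring `failure()`.
std::optional<std::vector<PendingCast>>
foldCastProducers(ToyDpsOp &op, const CastPredicate &isPreservingCast) {
  assert(op.resultTypes.size() == op.inits.size() && "one result per init");
  bool changed = false;
  for (ToyOperand &input : op.inputs)
    changed = absorbCast(input, isPreservingCast) || changed;

  std::vector<PendingCast> pendingCasts;
  for (std::size_t i = 0; i < op.inits.size(); ++i) {
    Shape originalResultType = op.resultTypes[i];
    if (!absorbCast(op.inits[i], isPreservingCast))
      continue;
    changed = true;
    // The tied result takes the type of the cast source, so users of the old
    // result need a cast back to `originalResultType`.
    op.resultTypes[i] = op.inits[i].type;
    pendingCasts.emplace_back(i, std::move(originalResultType));
  }
  if (!changed)
    return std::nullopt;
  return pendingCasts;
}
```

Because the predicate is a parameter, the same mechanics work for any cast op type. In the diff, the templated overload binds that predicate and `createCast` to a concrete `CastOpType`. That split is also what makes it cheap either to share a templated pattern or to duplicate a thin wrapper per dialect, as discussed in the review.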