[Mlir-commits] [mlir] [MLIR][Tensor] Fix DPS op canonicalizer with `tensor.cast` (PR #91382)
llvmlistbot at llvm.org
Wed May 8 21:04:05 PDT 2024
================
@@ -158,6 +160,76 @@ Operation *cloneWithoutRegions(OpBuilder &b, Operation *op,
SmallVector<NamedAttribute>
getPrunedAttributeList(Operation *op, ArrayRef<StringRef> elidedAttrs);
+/// Folds cast-like operations into a consuming DestinationStyleOpInterface op
+/// if `isPreservingCast` is true. If the cast appears on a 'DPS-init operand',
+/// then the tied result type is updated as well to the type of the cast source,
+/// and a new cast must be inserted on the new op's result. `createCast` is used
+/// to build such required cast ops.
+///
+/// ### Example
+/// If `isPreservingCast` returns true when the cast is a "generalizing"
+/// `tensor.cast`, then this function would behave as follows:
+///
+/// ```mlir
+/// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
+/// %2 = dps_op %1 ... : tensor<?x?xf32> ...
+/// ```
+///
+/// folds into:
+///
+/// ```mlir
+/// %2 = dps_op %0 ... : tensor<8x16xf32> ...
+/// ```
+LogicalResult foldCastProducers(
+ RewriterBase &rewriter, DestinationStyleOpInterface consumerOp,
+ llvm::function_ref<bool(Operation *)> isPreservingCast,
+ llvm::function_ref<Value(RewriterBase &rewriter, Type originalType,
+ Value replacement)>
+ createCast);
+
+/// Folds `tensor.cast` ops into a consuming DestinationStyleOpInterface op
+/// if the casts make their operands less static. See also isPreservingCast
+/// above.
+template <typename CastOpType>
+LogicalResult foldCastProducers(DestinationStyleOpInterface op,
+ RewriterBase &rewriter) {
+ return foldCastProducers(
+ rewriter, op,
+ [](Operation *castOp) -> bool {
+ auto concreteCast = dyn_cast<CastOpType>(castOp);
+ if (!concreteCast)
+ return false;
+ RankedTensorType resultType =
+ dyn_cast<RankedTensorType>(concreteCast.getType());
+ RankedTensorType sourceType =
+ dyn_cast<RankedTensorType>(concreteCast->getOperand(0).getType());
+ if (!resultType || !sourceType)
+ return false;
+ return resultType.isGeneralizationOf(sourceType);
+ },
+ [](RewriterBase &rewriter, Type resultType, Value operand) -> Value {
+ return rewriter.create<CastOpType>(operand.getLoc(), resultType,
+ operand);
+ });
+}
+
+/// A generic pattern for an Operation type that implements
+/// DestinationStyleOpInterface, allowing for absorbing cast-like operations
+/// that are producers of operands.
+template <typename OpType, typename CastOpType>
+struct FoldTensorCastIntoConsumerPattern : public OpRewritePattern<OpType> {
----------------
MaheshRavishankar wrote:
I think we try to not have template patterns be part of header files this way. I see why you have this, but is there a way to avoid doing this?
https://github.com/llvm/llvm-project/pull/91382
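
For readers without the MLIR sources at hand, the "generalizing" cast check that the quoted hunk relies on can be sketched in standalone C++. This is only an illustration under stated assumptions: `Shape`, `kDynamic`, and this free-function `isGeneralizationOf` are hypothetical stand-ins, not the MLIR API, and the sketch assumes "generalization" means same rank with each dimension either dynamic or unchanged. Element types are ignored.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for a ranked tensor shape; kDynamic marks a `?` dim.
constexpr int64_t kDynamic = -1;
using Shape = std::vector<int64_t>;

// Assumed semantics: `general` generalizes `specific` if both have the same
// rank and every dimension of `general` is either dynamic or equal to the
// corresponding dimension of `specific`.
bool isGeneralizationOf(const Shape &general, const Shape &specific) {
  if (general.size() != specific.size())
    return false;
  for (std::size_t i = 0; i < general.size(); ++i)
    if (general[i] != kDynamic && general[i] != specific[i])
      return false;
  return true;
}
```

Under these assumptions, a cast from `tensor<8x16xf32>` to `tensor<?x?xf32>` corresponds to `isGeneralizationOf({kDynamic, kDynamic}, {8, 16})`, which holds, so the cast would be folded into the consumer. The reverse direction (making a shape more static) does not hold and would be left alone.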