[Mlir-commits] [mlir] [MLIR][Tensor] Fix DPS op canonicalizer with `tensor.cast` (PR #91382)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed May 8 21:04:05 PDT 2024


================
@@ -158,6 +160,76 @@ Operation *cloneWithoutRegions(OpBuilder &b, Operation *op,
 SmallVector<NamedAttribute>
 getPrunedAttributeList(Operation *op, ArrayRef<StringRef> elidedAttrs);
 
+/// Folds cast-like operations into a consuming DestinationStyleOpInterface op
+/// when `isPreservingCast` returns true for the cast. If the cast appears on a
+/// 'DPS-init operand', the tied result type is also updated to the type of the
+/// cast source, and a new cast must be inserted on the new op's result.
+/// `createCast` is used to build such required cast ops.
+///
+/// ### Example
+/// If `isPreservingCast` returns true when the cast is a "generalizing"
+/// `tensor.cast`, then this function would behave as follows:
+///
+/// ```mlir
+/// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
+/// %2 = dps_op %1 ... : tensor<?x?xf32> ...
+/// ```
+///
+/// folds into:
+///
+/// ```mlir
+/// %2 = dps_op %0 ... : tensor<8x16xf32> ...
+/// ```
+LogicalResult foldCastProducers(
+    RewriterBase &rewriter, DestinationStyleOpInterface consumerOp,
+    llvm::function_ref<bool(Operation *)> isPreservingCast,
+    llvm::function_ref<Value(RewriterBase &rewriter, Type originalType,
+                             Value replacement)>
+        createCast);
+
+/// Folds `tensor.cast` ops into a consuming DestinationStyleOpInterface op
+/// if the casts make their operands less static. See also isPreservingCast
+/// above.
+template <typename CastOpType>
+LogicalResult foldCastProducers(DestinationStyleOpInterface op,
+                                RewriterBase &rewriter) {
+  return foldCastProducers(
+      rewriter, op,
+      [](Operation *castOp) -> bool {
+        auto concreteCast = dyn_cast<CastOpType>(castOp);
+        if (!concreteCast)
+          return false;
+        RankedTensorType resultType =
+            dyn_cast<RankedTensorType>(concreteCast.getType());
+        RankedTensorType sourceType =
+            dyn_cast<RankedTensorType>(concreteCast->getOperand(0).getType());
+        if (!resultType || !sourceType)
+          return false;
+        return resultType.isGeneralizationOf(sourceType);
+      },
+      [](RewriterBase &rewriter, Type resultType, Value operand) -> Value {
+        return rewriter.create<CastOpType>(operand.getLoc(), resultType,
+                                           operand);
+      });
+}
+
+/// A generic pattern for an Operation type that implements
+/// DestinationStyleOpInterface, allowing for absorbing cast-like operations
+/// that are producers of operands.
+template <typename OpType, typename CastOpType>
+struct FoldTensorCastIntoConsumerPattern : public OpRewritePattern<OpType> {
----------------
MaheshRavishankar wrote:

I think we try not to have template patterns in header files like this. I see why you have this, but is there a way to avoid it?
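
For context on the predicate in the hunk above: the templated overload only folds a cast when `resultType.isGeneralizationOf(sourceType)` holds, i.e. when the cast makes its operand less static. Below is a minimal standalone sketch of that check, not the MLIR implementation. It assumes "generalization" means equal rank with every result dimension either dynamic or equal to the source dimension; `Shape` and `kDynamicDim` are hypothetical stand-ins, and element types are ignored.

```cpp
#include <cstddef>
#include <cstdint>
#include <limits>
#include <vector>

// Hypothetical marker for a dynamic dimension ("?"); not MLIR's own constant.
constexpr int64_t kDynamicDim = std::numeric_limits<int64_t>::min();

// Hypothetical, simplified model of a ranked tensor shape (no element type).
using Shape = std::vector<int64_t>;

// Assumed semantics: `result` generalizes `source` when both have the same
// rank and each dimension of `result` is either dynamic or identical to the
// corresponding dimension of `source`.
bool isGeneralizationOf(const Shape &result, const Shape &source) {
  if (result.size() != source.size())
    return false;
  for (std::size_t i = 0; i < result.size(); ++i) {
    if (result[i] != kDynamicDim && result[i] != source[i])
      return false;
  }
  return true;
}
```

Under these assumptions, the cast in the doc comment's example (`tensor<8x16xf32>` to `tensor<?x?xf32>`) passes the check, while the reverse cast from `?x?` to `8x16` does not, so only the former would be folded into the consumer.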

https://github.com/llvm/llvm-project/pull/91382

