[Mlir-commits] [mlir] [MLIR][Tensor] Fix DPS op canonicalizer with `tensor.cast` (PR #91382)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue May 7 12:16:46 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-core

Author: Christopher Bate (christopherbate)

<details>
<summary>Changes</summary>

Attempts to address a bug pointed out in https://github.com/llvm/llvm-project/issues/91265
by moving the FoldTensorCastProducerOp canonicalizer definition upward into
the MLIRDialectUtils library. Since the MLIRDialectUtils can't depend on any
dialect, the canonicalizer had to change slightly, and a templated version
is introduced.

Then, we need to add this canonicalization routine where it was used before,
except for places where it is incorrect as pointed out in the bug.
Based on cursory inspection of the TableGen definitions, only
`bufferization.materialize_in_destination` should *not* have the
canonicalizer, but existing tests passed if the canonicalizer was
only added for `tensor.pack|unpack|extract_slice` and the LinalgOp interface.


---
Full diff: https://github.com/llvm/llvm-project/pull/91382.diff


10 Files Affected:

- (modified) mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td (+1) 
- (modified) mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h (+72) 
- (modified) mlir/include/mlir/IR/BuiltinTypeInterfaces.td (+19) 
- (modified) mlir/include/mlir/IR/BuiltinTypes.td (+27) 
- (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+20-2) 
- (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+28-4) 
- (modified) mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp (+53) 
- (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+8-2) 
- (modified) mlir/lib/IR/BuiltinTypeInterfaces.cpp (+24) 
- (modified) mlir/test/Dialect/Bufferization/canonicalize.mlir (+16) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index a403e89a39f98..228317b4adb6c 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -818,6 +818,7 @@ def Tensor_InsertOp : Tensor_Op<"insert", [
     MutableOperandRange getDpsInitsMutable() { return getDestMutable(); }
   }];
 
+  let hasCanonicalizer = 1;
   let hasFolder = 1;
   let hasVerifier = 1;
 }
diff --git a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
index 929a2a7d39649..a1f4ddc1221fc 100644
--- a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
@@ -20,7 +20,9 @@
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/Location.h"
+#include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeRange.h"
+#include "mlir/Interfaces/DestinationStyleOpInterface.h"
 #include "mlir/Support/LLVM.h"
 
 // Pull in all enum type definitions and utility function declarations.
@@ -158,6 +160,76 @@ Operation *cloneWithoutRegions(OpBuilder &b, Operation *op,
 SmallVector<NamedAttribute>
 getPrunedAttributeList(Operation *op, ArrayRef<StringRef> elidedAttrs);
 
+/// Folds cast-like operations into a consuming DestinationStyleOpInterface op
+/// if `isPreservingCast` is true. If the cast appears on a 'DPS-init operand',
+/// then the tied result type is updated as well to the type of the cast source,
+/// and a new cast must be inserted on the new op's result. `createCast` is used
+/// to build such required cast ops.
+///
+/// ### Example
+/// If `isPreservingCast` returns true when the cast is a "generalizing"
+/// `tensor.cast`, then this function would behave as follows:
+///
+/// ```mlir
+/// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
+/// %2 = dps_op %1 ... : tensor<?x?xf32> ...
+/// ```
+///
+/// folds into:
+///
+/// ```mlir
+/// %2 = dps_op %0 ... : tensor<8x16xf32> ...
+/// ```
+LogicalResult foldCastProducers(
+    RewriterBase &rewriter, DestinationStyleOpInterface consumerOp,
+    llvm::function_ref<bool(Operation *)> isPreservingCast,
+    llvm::function_ref<Value(RewriterBase &rewriter, Type originalType,
+                             Value replacement)>
+        createCast);
+
+/// Folds `tensor.cast` ops into a consuming DestinationStyleOpInterface op
+/// if the casts make their operands less static. See also isPreservingCast
+/// above.
+template <typename CastOpType>
+LogicalResult foldCastProducers(DestinationStyleOpInterface op,
+                                RewriterBase &rewriter) {
+  return foldCastProducers(
+      rewriter, op,
+      [](Operation *castOp) -> bool {
+        auto concreteCast = dyn_cast<CastOpType>(castOp);
+        if (!concreteCast)
+          return false;
+        RankedTensorType resultType =
+            dyn_cast<RankedTensorType>(concreteCast.getType());
+        RankedTensorType sourceType =
+            dyn_cast<RankedTensorType>(concreteCast->getOperand(0).getType());
+        if (!resultType || !sourceType)
+          return false;
+        return resultType.isGeneralizationOf(sourceType);
+      },
+      [](RewriterBase &rewriter, Type resultType, Value operand) -> Value {
+        return rewriter.create<CastOpType>(operand.getLoc(), resultType,
+                                           operand);
+      });
+}
+
+/// A generic pattern for an Operation type that implements
+/// DestinationStyleOpInterface, allowing for absorbing cast-like operations
+/// that are producers of operands.
+template <typename OpType, typename CastOpType>
+struct FoldTensorCastIntoConsumerPattern : public OpRewritePattern<OpType> {
+  using OpRewritePattern<OpType>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(OpType op,
+                                PatternRewriter &rewriter) const override {
+    DestinationStyleOpInterface dpsOp =
+        llvm::dyn_cast<DestinationStyleOpInterface>(op.getOperation());
+    if (!dpsOp)
+      return failure();
+    return foldCastProducers<CastOpType>(dpsOp, rewriter);
+  }
+};
+
 } // namespace mlir
 
 #endif // MLIR_DIALECT_UTILS_STRUCTUREDOPSUTILS_H
diff --git a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
index db38e2e1bce22..cbb20f42f9fb9 100644
--- a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
+++ b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
@@ -116,6 +116,15 @@ def ShapedTypeInterface : TypeInterface<"ShapedType"> {
     auto clone(::llvm::ArrayRef<int64_t> shape) {
       return cloneWith(shape, getElementType());
     }
+
+    /// Return whether the target shape is a refinement of the source shape.
+    static bool isShapeRefinementOf(
+      ArrayRef<int64_t> source, ArrayRef<int64_t> target);
+
+    /// Return whether the target shape is a generalization of the source
+    /// shape.
+    static bool isShapeGeneralizationOf(
+      ArrayRef<int64_t> source, ArrayRef<int64_t> target);
   }];
 
   let extraSharedClassDeclaration = [{
@@ -185,6 +194,16 @@ def ShapedTypeInterface : TypeInterface<"ShapedType"> {
       return llvm::count_if($_type.getShape().take_front(index),
                             ::mlir::ShapedType::isDynamic);
     }
+
+    bool isRefinementOf(ShapedType source) {
+      return $_type.getElementType() == source.getElementType() &&
+        ShapedType::isShapeRefinementOf(source.getShape(), $_type.getShape());
+    }
+
+    bool isGeneralizationOf(ShapedType source) {
+      return $_type.getElementType() == source.getElementType() &&
+        ShapedType::isShapeGeneralizationOf(source.getShape(), $_type.getShape());
+    }
   }];
 }
 
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index 4cade83dd3c32..cf9489711aa17 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -837,6 +837,33 @@ def Builtin_RankedTensor : Builtin_Type<"RankedTensor", "tensor", [
     using ShapedType::Trait<RankedTensorType>::getDimSize;
     using ShapedType::Trait<RankedTensorType>::getDynamicDimIndex;
 
+    /// Return whether this type is a refinement of `source` with
+    /// respect to only the shape, meaning they have the same element type
+    /// and the shape of this type is the same as source, except
+    /// zero or more dynamic extents from source have been replaced with
+    /// static extents.
+    /// This method is conservative with respect to the encoding. If the
+    /// encodings are not the same, then false is returned.
+    bool isRefinementOf(RankedTensorType source) {
+
+      return getEncoding() == source.getEncoding() &&
+        ShapedType::Trait<RankedTensorType>::isRefinementOf(
+          llvm::cast<ShapedType>(source));
+    }
+
+    /// Return whether this type is a generalization of `source` with
+    /// respect to only the shape, meaning they have the same element
+    /// type and the shape of this type is the same as source, except
+    /// zero or more static extents have been replaced with unknown
+    /// extents.
+    /// This method is conservative with respect to the encoding. If the
+    /// encodings are not the same, then false is returned.
+    bool isGeneralizationOf(RankedTensorType source) {
+      return getEncoding() == source.getEncoding() &&
+        ShapedType::Trait<RankedTensorType>::isGeneralizationOf(
+          llvm::cast<ShapedType>(source));
+    }
+
     /// This is a builder type that keeps local references to arguments.
     /// Arguments that are passed into the builder must outlive the builder.
     class Builder;
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index e5f83331baf81..549fd37744d71 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2674,10 +2674,28 @@ FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
 // LinalgDialect
 //===----------------------------------------------------------------------===//
 
+namespace {
+struct LinalgAbsorbTensorCastProducersPattern
+    : public OpInterfaceRewritePattern<linalg::LinalgOp> {
+  using OpInterfaceRewritePattern<linalg::LinalgOp>::OpInterfaceRewritePattern;
+
+  LogicalResult matchAndRewrite(LinalgOp op,
+                                PatternRewriter &rewriter) const override {
+    DestinationStyleOpInterface dpsOp =
+        llvm::dyn_cast<DestinationStyleOpInterface>(op.getOperation());
+    if (!dpsOp)
+      return failure();
+    return foldCastProducers<tensor::CastOp>(dpsOp, rewriter);
+  }
+};
+} // namespace
+
 void LinalgDialect::getCanonicalizationPatterns(
     RewritePatternSet &results) const {
-  results.add<EraseDeadLinalgOp, FoldTensorCastConsumerOp,
-              InferStaticShapeOfOperands>(getContext());
+  results
+      .add<EraseDeadLinalgOp, FoldTensorCastConsumerOp,
+           InferStaticShapeOfOperands, LinalgAbsorbTensorCastProducersPattern>(
+          getContext());
 }
 
 Operation *LinalgDialect::materializeConstant(OpBuilder &builder,
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 4c65045084dc5..38f8a2a288020 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -14,6 +14,7 @@
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributeInterfaces.h"
 #include "mlir/IR/BuiltinTypeInterfaces.h"
@@ -1369,6 +1370,11 @@ OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
   return {};
 }
 
+void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                           MLIRContext *context) {
+  results.add<FoldTensorCastIntoConsumerPattern<InsertOp, CastOp>>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // GenerateOp
 //===----------------------------------------------------------------------===//
@@ -2413,7 +2419,9 @@ void ExtractSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<
       OpWithOffsetSizesAndStridesConstantArgumentFolder<
           ExtractSliceOp, SliceReturnTypeCanonicalizer, SliceCanonicalizer>,
-      ExtractSliceOpCastFolder>(context);
+      ExtractSliceOpCastFolder,
+      FoldTensorCastIntoConsumerPattern<tensor::ExtractSliceOp,
+                                        tensor::CastOp>>(context);
 }
 
 //
@@ -4154,6 +4162,15 @@ static bool inferStaticShape(PackOp packOp, SmallVectorImpl<int64_t> &srcShape,
 }
 
 LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
+  // pack(cast(x)) -> pack(x)
+  if (packOp.getSource().getDefiningOp<tensor::CastOp>() ||
+      packOp.getDest().getDefiningOp<tensor::CastOp>()) {
+    if (succeeded(foldCastProducers<CastOp>(
+            cast<DestinationStyleOpInterface>(packOp.getOperation()),
+            rewriter)))
+      return success();
+  }
+
   // Fold an unpack(pack(x)) to x.
   if (auto unPackOp = packOp.getSource().getDefiningOp<UnPackOp>()) {
     if (unPackOp.getSourceType() != packOp.getDestType())
@@ -4388,6 +4405,15 @@ static bool inferStaticShape(UnPackOp op, SmallVectorImpl<int64_t> &srcShape,
 
 LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
                                      PatternRewriter &rewriter) {
+  // unpack(cast(x)) -> unpack(x)
+  if (unPackOp.getSource().getDefiningOp<tensor::CastOp>() ||
+      unPackOp.getDest().getDefiningOp<tensor::CastOp>()) {
+    if (succeeded(foldCastProducers<CastOp>(
+            cast<DestinationStyleOpInterface>(unPackOp.getOperation()),
+            rewriter)))
+      return success();
+  }
+
   /// pack(unpack(x)) -> x
   if (PackOp packOp = unPackOp.getSource().getDefiningOp<tensor::PackOp>()) {
     if (packOp.getDestType() != unPackOp.getSourceType())
@@ -4533,9 +4559,7 @@ struct FoldTensorCastProducerOp
 //===----------------------------------------------------------------------===//
 
 void TensorDialect::getCanonicalizationPatterns(
-    RewritePatternSet &results) const {
-  results.add<FoldTensorCastProducerOp>(getContext());
-}
+    RewritePatternSet &results) const {}
 
 //===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
diff --git a/mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp b/mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp
index adde8a66d8354..131a67cede5aa 100644
--- a/mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp
@@ -228,3 +228,56 @@ mlir::getPrunedAttributeList(Operation *op, ArrayRef<StringRef> elidedAttrs) {
   }
   return attrs;
 }
+
+LogicalResult mlir::foldCastProducers(
+    RewriterBase &rewriter, DestinationStyleOpInterface op,
+    llvm::function_ref<bool(Operation *)> isPreservingCast,
+    llvm::function_ref<Value(RewriterBase &rewriter, Type originalType,
+                             Value replacement)>
+        createCast) {
+
+  auto canFoldIntoConsumerOp = [&isPreservingCast](Operation *castOp) {
+    return castOp && isPreservingCast(castOp);
+  };
+
+  // If no operand comes from a tensor::CastOp and can be folded then fail.
+  bool hasTensorCastOperand =
+      llvm::any_of(op->getOpOperands(), [&](OpOperand &opOperand) {
+        if (llvm::isa<BlockArgument>(opOperand.get()))
+          return false;
+        Operation *castOp = opOperand.get().getDefiningOp();
+        return castOp && canFoldIntoConsumerOp(castOp);
+      });
+  if (!hasTensorCastOperand)
+    return failure();
+
+  SmallVector<Type, 4> newResultTypes;
+  newResultTypes.reserve(op->getNumResults());
+  SmallVector<Value, 4> newOperands;
+  newOperands.reserve(op->getNumOperands());
+  for (OpOperand &opOperand : op->getOpOperands()) {
+    Operation *tensorCastOp = opOperand.get().getDefiningOp();
+    bool fold = canFoldIntoConsumerOp(tensorCastOp);
+    newOperands.push_back(fold ? tensorCastOp->getOperand(0) : opOperand.get());
+    if (op.isDpsInit(&opOperand) &&
+        !llvm::isa<MemRefType>(newOperands.back().getType()))
+      newResultTypes.push_back(newOperands.back().getType());
+  }
+
+  // Clone op.
+  Operation *newOp = clone(rewriter, op, newResultTypes, newOperands);
+  SmallVector<Value, 4> replacements;
+  replacements.reserve(newOp->getNumResults());
+  for (auto [oldResult, newResult] :
+       llvm::zip(op->getResults(), newOp->getResults())) {
+    if (newResult.getType() != oldResult.getType()) {
+      Value resultCast = createCast(rewriter, oldResult.getType(), newResult);
+      replacements.push_back(resultCast);
+    } else {
+      replacements.push_back(newResult);
+    }
+  }
+  rewriter.replaceOp(op, replacements);
+
+  return success();
+}
\ No newline at end of file
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index d10a31941db4f..2b18e17d8d250 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4184,7 +4184,10 @@ struct TransferReadAfterWriteToBroadcast
 
 void TransferReadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
-  results.add<TransferReadAfterWriteToBroadcast>(context);
+  results
+      .add<TransferReadAfterWriteToBroadcast,
+           FoldTensorCastIntoConsumerPattern<TransferReadOp, tensor::CastOp>>(
+          context);
 }
 
 //===----------------------------------------------------------------------===//
@@ -4636,7 +4639,10 @@ struct SwapExtractSliceOfTransferWrite
 
 void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                   MLIRContext *context) {
-  results.add<FoldWaw, SwapExtractSliceOfTransferWrite>(context);
+  results
+      .add<FoldWaw, SwapExtractSliceOfTransferWrite,
+           FoldTensorCastIntoConsumerPattern<TransferWriteOp, tensor::CastOp>>(
+          context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/BuiltinTypeInterfaces.cpp b/mlir/lib/IR/BuiltinTypeInterfaces.cpp
index ab9e65b5edfed..50a8330d01dc4 100644
--- a/mlir/lib/IR/BuiltinTypeInterfaces.cpp
+++ b/mlir/lib/IR/BuiltinTypeInterfaces.cpp
@@ -33,3 +33,27 @@ int64_t ShapedType::getNumElements(ArrayRef<int64_t> shape) {
   }
   return num;
 }
+
+bool ShapedType::isShapeRefinementOf(ArrayRef<int64_t> source,
+                                     ArrayRef<int64_t> target) {
+  if (source.size() != target.size())
+    return false;
+  for (auto [srcDim, tgtDim] : llvm::zip_equal(source, target)) {
+    // If the source dimension is dynamic, then the target dimension can be
+    // dynamic or static.
+    if (isDynamic(srcDim))
+      continue;
+    // Static source dim and dynamic result dim -> not a refinement.
+    if (isDynamic(tgtDim))
+      return false;
+    // Static source dim != static result dim -> not a refinement.
+    if (srcDim != tgtDim)
+      return false;
+  }
+  return true;
+}
+
+bool ShapedType::isShapeGeneralizationOf(ArrayRef<int64_t> source,
+                                         ArrayRef<int64_t> target) {
+  return isShapeRefinementOf(target, source);
+}
diff --git a/mlir/test/Dialect/Bufferization/canonicalize.mlir b/mlir/test/Dialect/Bufferization/canonicalize.mlir
index b6c0a0e25efe0..bae1ff805d939 100644
--- a/mlir/test/Dialect/Bufferization/canonicalize.mlir
+++ b/mlir/test/Dialect/Bufferization/canonicalize.mlir
@@ -388,3 +388,19 @@ func.func @negative_input() -> tensor<?x?x?xf16> {
   %11 = bufferization.alloc_tensor(%c10, %idx-3, %idx27) : tensor<?x?x?xf16>
   return %11 : tensor<?x?x?xf16>
 }
+
+// -----
+
+func.func @materialize_in_destination_tensor_cast(%arg0: tensor<4xf32>, %arg1: index) -> tensor<?xf32> {
+  %0 = bufferization.alloc_tensor(%arg1) : tensor<?xf32>
+  %1 = tensor.cast %arg0 : tensor<4xf32> to tensor<?xf32>
+  %2 = bufferization.materialize_in_destination %1 in %0 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
+  return %2 : tensor<?xf32>
+}
+
+// Check that a `tensor.cast` producer is not absorbed.
+
+// CHECK-LABEL: func.func @materialize_in_destination_tensor_cast
+//       CHECK:   tensor.cast
+//       CHECK:   bufferization.materialize_in_destination
+//  CHECK-SAME:    : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>

``````````

</details>


https://github.com/llvm/llvm-project/pull/91382


More information about the Mlir-commits mailing list