[Mlir-commits] [mlir] [mlir][tensor] Make useful Tensor utilities public (PR #126802)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Feb 11 13:50:57 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-tensor

Author: Andrzej WarzyƄski (banach-space)

<details>
<summary>Changes</summary>

1. Extract the main logic from `foldTensorCastPrecondition` into a dedicated
   helper hook, `hasFoldableTensorCastOperand`, so that the underlying checks
   can be reused.

2. Rename `getNewOperands` to `getUpdatedOperandsAfterCastOpFolding` to better
   document what the hook actually does.

3. Make both hooks public so that they can be reused in
   https://github.com/llvm/llvm-project/pull/123902 (see the sketch after the
   note below).

**Note:** Moving these hooks to `Tensor/Utils` is not feasible: the
`MLIRTensorUtils` CMake target already depends on `MLIRTensorDialect`, so
moving the hooks to `Utils` would require `MLIRTensorDialect` to depend on
`MLIRTensorUtils`, creating a circular dependency.
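
Below is a minimal sketch (not part of this PR) of how a downstream pattern
could use the two hooks once they are public. The pattern name
`FoldTensorCastIntoDpsOp` is hypothetical; the hook calls mirror
`FoldTensorCastProducerOp` in the diff below.

```cpp
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"

using namespace mlir;

namespace {
/// Hypothetical pattern: fold producing tensor.cast ops into a DPS op by
/// rebuilding the op with the pre-cast operands. A real pattern would also
/// exclude ops that already have dedicated cast-folding patterns, as
/// foldTensorCastPrecondition does.
struct FoldTensorCastIntoDpsOp
    : public OpInterfaceRewritePattern<DestinationStyleOpInterface> {
  using OpInterfaceRewritePattern::OpInterfaceRewritePattern;

  LogicalResult matchAndRewrite(DestinationStyleOpInterface op,
                                PatternRewriter &rewriter) const override {
    // Reuse the public precondition check instead of re-implementing it.
    if (!tensor::hasFoldableTensorCastOperand(op))
      return failure();

    // Compute the operands with the casts folded away; the DPS-init result
    // types in `newResultTypes` are updated in place.
    SmallVector<Type> newResultTypes(op->getResultTypes());
    SmallVector<Value> newOperands =
        tensor::getUpdatedOperandsAfterCastOpFolding(op, newResultTypes);

    // Rebuild the op with the updated operands/result types, using the same
    // free `clone` helper as FoldTensorCastProducerOp in the diff below.
    Operation *newOp = clone(rewriter, op, newResultTypes, newOperands);
    rewriter.replaceOp(op, newOp->getResults());
    return success();
  }
};
} // namespace
```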


---
Full diff: https://github.com/llvm/llvm-project/pull/126802.diff


2 Files Affected:

- (modified) mlir/include/mlir/Dialect/Tensor/IR/Tensor.h (+12) 
- (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+36-31) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
index 1bd0f6553fc8d..b3ec796a72337 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
+++ b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
@@ -116,6 +116,18 @@ bool canFoldIntoConsumerOp(CastOp castOp);
 /// this method provides a check that it is worth doing the canonicalization.
 bool canFoldIntoProducerOp(CastOp castOp);
 
+/// Return true if any of the operands of `op` is a CastOp that can be folded
+/// into its consumer, i.e. `op`. This is effectively a convenience wrapper for
+/// `canFoldIntoConsumerOp`.
+bool hasFoldableTensorCastOperand(Operation *op);
+
+/// Assuming that `op` contains at least one operand that is a foldable CastOp
+/// (i.e. `hasFoldableTensorCastOperand` returns true), calculate the updated
+/// operands, updating the DPS-init entries of `newResTy` in place.
+SmallVector<Value>
+getUpdatedOperandsAfterCastOpFolding(DestinationStyleOpInterface op,
+                                     SmallVector<Type> &newResTy);
+
 /// Performs folding of any operand of `op` if it comes from a tensor::CastOp
 /// that can be folded.
 LogicalResult foldTensorCast(Operation *op);
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index fda6246334e15..03c2f3843f262 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -354,6 +354,35 @@ bool mlir::tensor::canFoldIntoProducerOp(CastOp castOp) {
                                     castOp.getType());
 }
 
+bool mlir::tensor::hasFoldableTensorCastOperand(Operation *op) {
+  return llvm::any_of(op->getOpOperands(), [&](OpOperand &opOperand) {
+    if (llvm::isa<BlockArgument>(opOperand.get()))
+      return false;
+    auto castOp = opOperand.get().getDefiningOp<tensor::CastOp>();
+    return castOp && canFoldIntoConsumerOp(castOp);
+  });
+}
+
+SmallVector<Value> mlir::tensor::getUpdatedOperandsAfterCastOpFolding(
+    DestinationStyleOpInterface op, SmallVector<Type> &newResTy) {
+  SmallVector<Value> newOperands;
+  newOperands.reserve(op->getNumOperands());
+
+  assert(hasFoldableTensorCastOperand(op) && "No foldable CastOp operands!");
+
+  // Assumes that the result has dpsInits followed by nonDpsInits.
+  int64_t dpsInitIdx = 0;
+  for (OpOperand &opOperand : op->getOpOperands()) {
+    auto tensorCastOp = opOperand.get().getDefiningOp<tensor::CastOp>();
+    bool fold = canFoldIntoConsumerOp(tensorCastOp);
+    newOperands.push_back(fold ? tensorCastOp.getOperand() : opOperand.get());
+    if (op.isDpsInit(&opOperand) &&
+        !llvm::isa<MemRefType>(newOperands.back().getType()))
+      newResTy[dpsInitIdx++] = newOperands.back().getType();
+  }
+  return newOperands;
+}
+
 /// Performs folding of any operand of `op` if it comes from a tensor::CastOp
 /// that can be folded.
 LogicalResult mlir::tensor::foldTensorCast(Operation *op) {
@@ -4777,34 +4806,7 @@ bool foldTensorCastPrecondition(DestinationStyleOpInterface op) {
       isa<LoopLikeOpInterface>(op.getOperation()))
     return false;
 
-  // If no operand comes from a tensor::CastOp and can be folded then fail.
-  bool hasTensorCastOperand =
-      llvm::any_of(op->getOpOperands(), [&](OpOperand &opOperand) {
-        if (llvm::isa<BlockArgument>(opOperand.get()))
-          return false;
-        auto castOp = opOperand.get().getDefiningOp<tensor::CastOp>();
-        return castOp && canFoldIntoConsumerOp(castOp);
-      });
-
-  return hasTensorCastOperand;
-}
-
-static SmallVector<Value> getNewOperands(DestinationStyleOpInterface op,
-                                         SmallVector<Type> &newResTy) {
-  SmallVector<Value> newOperands;
-  newOperands.reserve(op->getNumOperands());
-
-  // Assumes that the result has dpsInits followed by nonDpsInits.
-  int64_t dpsInitIdx = 0;
-  for (OpOperand &opOperand : op->getOpOperands()) {
-    auto tensorCastOp = opOperand.get().getDefiningOp<tensor::CastOp>();
-    bool fold = canFoldIntoConsumerOp(tensorCastOp);
-    newOperands.push_back(fold ? tensorCastOp.getOperand() : opOperand.get());
-    if (op.isDpsInit(&opOperand) &&
-        !llvm::isa<MemRefType>(newOperands.back().getType()))
-      newResTy[dpsInitIdx++] = newOperands.back().getType();
-  }
-  return newOperands;
+  return hasFoldableTensorCastOperand(op);
 }
 
 // Given the (potentially) updated packed type, `newPackedTy`, generates an
@@ -4868,7 +4870,8 @@ struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
       return failure();
 
     SmallVector<Type> newResultTypes(op->getResultTypes());
-    SmallVector<Value> newOperands = getNewOperands(op, newResultTypes);
+    SmallVector<Value> newOperands =
+        getUpdatedOperandsAfterCastOpFolding(op, newResultTypes);
 
     // Get the updated mixed-tile-sizes attribute.
     SmallVector<OpFoldResult> newMixedTileSizes =
@@ -4920,7 +4923,8 @@ struct FoldTensorCastUnPackOp : public OpRewritePattern<UnPackOp> {
       return failure();
 
     SmallVector<Type> newResultTypes(op->getResultTypes());
-    SmallVector<Value> newOperands = getNewOperands(op, newResultTypes);
+    SmallVector<Value> newOperands =
+        getUpdatedOperandsAfterCastOpFolding(op, newResultTypes);
     Value sourceTensor = newOperands[0];
 
     // Get the updated mixed-tile-sizes attribute.
@@ -4980,7 +4984,8 @@ struct FoldTensorCastProducerOp
       return failure();
 
     SmallVector<Type> newResultTypes(op->getResultTypes());
-    SmallVector<Value> newOperands = getNewOperands(op, newResultTypes);
+    SmallVector<Value> newOperands =
+        getUpdatedOperandsAfterCastOpFolding(op, newResultTypes);
 
     // Clone op
     auto newOp = clone(rewriter, op, newResultTypes, newOperands);

``````````

</details>


https://github.com/llvm/llvm-project/pull/126802


More information about the Mlir-commits mailing list