[Mlir-commits] [mlir] [mlir][tensor] Make useful Tensor utilities public (PR #126802)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Feb 11 13:50:57 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-tensor
Author: Andrzej Warzyński (banach-space)
<details>
<summary>Changes</summary>
1. Extract the main logic from `foldTensorCastPrecondition` into a dedicated
helper hook, `hasFoldableTensorCastOperand`, so that the corresponding checks
can be reused.
2. Rename `getNewOperands` to `getUpdatedOperandsAfterCastOpFolding` for better
clarity and documentation of its functionality.
3. Make these updated hooks public so that they can be reused in:
  * https://github.com/llvm/llvm-project/pull/123902.
**Note:** Moving these hooks to `Tensor/Utils` is not feasible because
`MLIRTensorUtils` depends on `MLIRTensorDialect` (CMake targets). Since
`MLIRTensorDialect` itself uses these hooks (in TensorOps.cpp), moving them to
`Utils` would make `MLIRTensorDialect` depend on `MLIRTensorUtils`, creating a
circular dependency.
---
Full diff: https://github.com/llvm/llvm-project/pull/126802.diff
2 Files Affected:
- (modified) mlir/include/mlir/Dialect/Tensor/IR/Tensor.h (+12)
- (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+36-31)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
index 1bd0f6553fc8d..b3ec796a72337 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
+++ b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
@@ -116,6 +116,18 @@ bool canFoldIntoConsumerOp(CastOp castOp);
/// this method provides a check that it is worth doing the canonicalization.
bool canFoldIntoProducerOp(CastOp castOp);
+/// Return true if any of the operands of `op` is a CastOp that can be folded
+/// into its consumer, i.e. `op`. This is effectively a convenience wrapper for
+/// `canFoldIntoConsumerOp`.
+bool hasFoldableTensorCastOperand(Operation *op);
+
+/// Assuming that `op` has at least one foldable CastOp operand (i.e.
+/// `hasFoldableTensorCastOperand` returns true), return the updated operands
+/// and write the types of the updated non-memref DPS inits into `newResTy`.
+SmallVector<Value>
+getUpdatedOperandsAfterCastOpFolding(DestinationStyleOpInterface op,
+ SmallVector<Type> &newResTy);
+
/// Performs folding of any operand of `op` if it comes from a tensor::CastOp
/// that can be folded.
LogicalResult foldTensorCast(Operation *op);
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index fda6246334e15..03c2f3843f262 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -354,6 +354,35 @@ bool mlir::tensor::canFoldIntoProducerOp(CastOp castOp) {
castOp.getType());
}
+bool mlir::tensor::hasFoldableTensorCastOperand(Operation *op) {
+ return llvm::any_of(op->getOpOperands(), [&](OpOperand &opOperand) {
+ if (llvm::isa<BlockArgument>(opOperand.get()))
+ return false;
+ auto castOp = opOperand.get().getDefiningOp<tensor::CastOp>();
+ return castOp && canFoldIntoConsumerOp(castOp);
+ });
+}
+
+SmallVector<Value> mlir::tensor::getUpdatedOperandsAfterCastOpFolding(
+ DestinationStyleOpInterface op, SmallVector<Type> &newResTy) {
+ SmallVector<Value> newOperands;
+ newOperands.reserve(op->getNumOperands());
+
+ assert(hasFoldableTensorCastOperand(op) && "No foldable CastOp operands!");
+
+  // Assumes the i-th result corresponds to the i-th non-memref DPS init.
+ int64_t dpsInitIdx = 0;
+ for (OpOperand &opOperand : op->getOpOperands()) {
+ auto tensorCastOp = opOperand.get().getDefiningOp<tensor::CastOp>();
+ bool fold = canFoldIntoConsumerOp(tensorCastOp);
+ newOperands.push_back(fold ? tensorCastOp.getOperand() : opOperand.get());
+ if (op.isDpsInit(&opOperand) &&
+ !llvm::isa<MemRefType>(newOperands.back().getType()))
+ newResTy[dpsInitIdx++] = newOperands.back().getType();
+ }
+ return newOperands;
+}
+
/// Performs folding of any operand of `op` if it comes from a tensor::CastOp
/// that can be folded.
LogicalResult mlir::tensor::foldTensorCast(Operation *op) {
@@ -4777,34 +4806,7 @@ bool foldTensorCastPrecondition(DestinationStyleOpInterface op) {
isa<LoopLikeOpInterface>(op.getOperation()))
return false;
- // If no operand comes from a tensor::CastOp and can be folded then fail.
- bool hasTensorCastOperand =
- llvm::any_of(op->getOpOperands(), [&](OpOperand &opOperand) {
- if (llvm::isa<BlockArgument>(opOperand.get()))
- return false;
- auto castOp = opOperand.get().getDefiningOp<tensor::CastOp>();
- return castOp && canFoldIntoConsumerOp(castOp);
- });
-
- return hasTensorCastOperand;
-}
-
-static SmallVector<Value> getNewOperands(DestinationStyleOpInterface op,
- SmallVector<Type> &newResTy) {
- SmallVector<Value> newOperands;
- newOperands.reserve(op->getNumOperands());
-
- // Assumes that the result has dpsInits followed by nonDpsInits.
- int64_t dpsInitIdx = 0;
- for (OpOperand &opOperand : op->getOpOperands()) {
- auto tensorCastOp = opOperand.get().getDefiningOp<tensor::CastOp>();
- bool fold = canFoldIntoConsumerOp(tensorCastOp);
- newOperands.push_back(fold ? tensorCastOp.getOperand() : opOperand.get());
- if (op.isDpsInit(&opOperand) &&
- !llvm::isa<MemRefType>(newOperands.back().getType()))
- newResTy[dpsInitIdx++] = newOperands.back().getType();
- }
- return newOperands;
+ return hasFoldableTensorCastOperand(op);
}
// Given the (potentially) updated packed type, `newPackedTy`, generates an
@@ -4868,7 +4870,8 @@ struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
return failure();
SmallVector<Type> newResultTypes(op->getResultTypes());
- SmallVector<Value> newOperands = getNewOperands(op, newResultTypes);
+ SmallVector<Value> newOperands =
+ getUpdatedOperandsAfterCastOpFolding(op, newResultTypes);
// Get the updated mixed-tile-sizes attribute.
SmallVector<OpFoldResult> newMixedTileSizes =
@@ -4920,7 +4923,8 @@ struct FoldTensorCastUnPackOp : public OpRewritePattern<UnPackOp> {
return failure();
SmallVector<Type> newResultTypes(op->getResultTypes());
- SmallVector<Value> newOperands = getNewOperands(op, newResultTypes);
+ SmallVector<Value> newOperands =
+ getUpdatedOperandsAfterCastOpFolding(op, newResultTypes);
Value sourceTensor = newOperands[0];
// Get the updated mixed-tile-sizes attribute.
@@ -4980,7 +4984,8 @@ struct FoldTensorCastProducerOp
return failure();
SmallVector<Type> newResultTypes(op->getResultTypes());
- SmallVector<Value> newOperands = getNewOperands(op, newResultTypes);
+ SmallVector<Value> newOperands =
+ getUpdatedOperandsAfterCastOpFolding(op, newResultTypes);
// Clone op
auto newOp = clone(rewriter, op, newResultTypes, newOperands);
``````````
</details>
https://github.com/llvm/llvm-project/pull/126802
More information about the Mlir-commits
mailing list