[Mlir-commits] [mlir] [mlir][Tensor] Fold destination-style ops into `tensor.unpack` operation. (PR #71468)
Han-Chung Wang
llvmlistbot at llvm.org
Mon Nov 6 16:51:33 PST 2023
================
@@ -3925,15 +3925,25 @@ UnPackOp UnPackOp::createTransposedClone(OpBuilder &b, Location loc,
/// pack(unpack(x)) -> x
LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
PatternRewriter &rewriter) {
- PackOp packOp = unPackOp.getSource().getDefiningOp<tensor::PackOp>();
- if (!packOp || packOp.getDestType() != unPackOp.getSourceType())
- return failure();
- if (packOp.getPaddingValue() ||
- !hasSameInnerOuterAttribute(packOp, unPackOp) ||
- !haveSameTiles(packOp, unPackOp))
- return failure();
- rewriter.replaceOp(unPackOp, packOp.getSource());
- return success();
+ if (PackOp packOp = unPackOp.getSource().getDefiningOp<tensor::PackOp>()) {
+ if (packOp.getDestType() != unPackOp.getSourceType())
+ return failure();
+ if (packOp.getPaddingValue() ||
+ !hasSameInnerOuterAttribute(packOp, unPackOp) ||
+ !haveSameTiles(packOp, unPackOp))
+ return failure();
+ rewriter.replaceOp(unPackOp, packOp.getSource());
+ return success();
+ }
+ if (DestinationStyleOpInterface dstStyleOp =
+ unPackOp.getDest().getDefiningOp<DestinationStyleOpInterface>()) {
+ OpResult destValue = unPackOp.getDest().cast<OpResult>();
+ Value newDest = dstStyleOp.getDpsInits()[destValue.getResultNumber()];
----------------
hanhanW wrote:
[optional] We can use the `getDpsInitOperand` method instead, so it does not return the full list of init operands just to index into it.
```suggestion
Value newDest = dstStyleOp.getDpsInitOperand(destValue.getResultNumber())->get();
```
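The new branch in this patch relies on the destination-passing-style (DPS) contract: result *i* of a DPS op is tied to its *i*-th init operand. Assuming, as this fold does, that `tensor.unpack` writes every element of its destination, only the destination's shape matters. That means the unpack can take its destination directly from the tied init operand and skip the producing op. The C++ sketch below models this contract with hypothetical, simplified `Value`/`Op` structs and a hypothetical `foldUnpackDest` helper. These are not the real MLIR classes. Its `getDpsInitOperand` is a stand-in that mirrors the reviewer's suggestion of fetching one init operand by index.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical, simplified stand-ins for MLIR values and ops. They are NOT
// the real MLIR classes; they only model the DPS contract that result i of a
// destination-style op is tied to its i-th init operand.
struct Op;

struct Value {
  std::string name;                // debug name, e.g. "%empty"
  const Op *definingOp = nullptr;  // null for block arguments
  std::size_t resultNumber = 0;    // meaningful only if definingOp != nullptr
};

struct Op {
  bool isDestinationStyle = false;
  std::vector<const Value *> inits;  // DPS init operands, one per result

  // Mirrors the review suggestion: fetch a single init operand by index
  // rather than building the whole list first. Throws if i is out of range.
  const Value *getDpsInitOperand(std::size_t i) const { return inits.at(i); }
};

// Sketch of the fold: if the unpack destination is produced by a DPS op,
// return the init operand tied to that result (the new destination).
// Returns nullptr when the fold does not apply.
const Value *foldUnpackDest(const Value &dest) {
  if (!dest.definingOp || !dest.definingOp->isDestinationStyle)
    return nullptr;
  return dest.definingOp->getDpsInitOperand(dest.resultNumber);
}
```

For example, with a DPS "fill" op whose single init is `%empty`, an unpack into the fill's result would be rewritten to unpack into `%empty` directly. A block argument, or a result of a non-DPS op, is left untouched.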
https://github.com/llvm/llvm-project/pull/71468