[Mlir-commits] [mlir] [mlir][Tensor] Fold destination-style ops into `tensor.unpack` operation. (PR #71468)

Han-Chung Wang llvmlistbot at llvm.org
Mon Nov 6 16:51:33 PST 2023


================
@@ -3925,15 +3925,25 @@ UnPackOp UnPackOp::createTransposedClone(OpBuilder &b, Location loc,
 /// pack(unpack(x)) -> x
 LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
                                      PatternRewriter &rewriter) {
-  PackOp packOp = unPackOp.getSource().getDefiningOp<tensor::PackOp>();
-  if (!packOp || packOp.getDestType() != unPackOp.getSourceType())
-    return failure();
-  if (packOp.getPaddingValue() ||
-      !hasSameInnerOuterAttribute(packOp, unPackOp) ||
-      !haveSameTiles(packOp, unPackOp))
-    return failure();
-  rewriter.replaceOp(unPackOp, packOp.getSource());
-  return success();
+  if (PackOp packOp = unPackOp.getSource().getDefiningOp<tensor::PackOp>()) {
+    if (packOp.getDestType() != unPackOp.getSourceType())
+      return failure();
+    if (packOp.getPaddingValue() ||
+        !hasSameInnerOuterAttribute(packOp, unPackOp) ||
+        !haveSameTiles(packOp, unPackOp))
+      return failure();
+    rewriter.replaceOp(unPackOp, packOp.getSource());
+    return success();
+  }
+  if (DestinationStyleOpInterface dstStyleOp =
+          unPackOp.getDest().getDefiningOp<DestinationStyleOpInterface>()) {
+    OpResult destValue = unPackOp.getDest().cast<OpResult>();
+    Value newDest = dstStyleOp.getDpsInits()[destValue.getResultNumber()];
----------------
hanhanW wrote:

[optional] We can use the `getDpsInitOperand` method so that we do not materialize the full list of DPS init operands just to index into it.

```suggestion
    Value newDest = dstStyleOp.getDpsInitOperand(destValue.getResultNumber())->get();
```



https://github.com/llvm/llvm-project/pull/71468


More information about the Mlir-commits mailing list