[Mlir-commits] [mlir] [mlir][linalg] Lower unpack - capture handle to created copy op (PR #183744)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Feb 27 06:35:26 PST 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-linalg
Author: Adam Siemieniuk (adam-smnk)
<details>
<summary>Changes</summary>
Adds the previously missing `linalg.copy` op to the results returned by the unpack lowering. The corresponding transform op is also updated to return a handle to the created copy op as a new result.
---
Full diff: https://github.com/llvm/llvm-project/pull/183744.diff
7 Files Affected:
- (modified) mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td (+4-3)
- (modified) mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h (+2-1)
- (modified) mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp (+1)
- (modified) mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp (+4-2)
- (modified) mlir/test/Dialect/Linalg/transform-lower-pack.mlir (+18-9)
- (modified) mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir (+4-2)
- (modified) mlir/test/Integration/Dialect/Linalg/CPU/pack-unpack-mmt4d.mlir (+2-1)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 70d424bae9285..caec229207ea6 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -717,7 +717,7 @@ def LowerUnPackOp : Op<Transform_Dialect, "structured.lower_unpack", [
ReportTrackingListenerFailuresOpTrait]> {
let description = [{
Lower a linalg.unpack into empty + linalg.transpose + tensor.collapse_shape +
- tensor.extract_slice.
+ tensor.extract_slice + linalg.copy.
#### Return modes
@@ -725,7 +725,7 @@ def LowerUnPackOp : Op<Transform_Dialect, "structured.lower_unpack", [
operation produces a silenceable failure if the rewrite fails for any
reason. If all the operations referred to by the `target` are rewritten,
the transform succeeds. Return handles to the newly produced empty,
- transpose, collapse_shape and extract_slice ops.
+ transpose, collapse_shape, extract_slice and copy ops.
}];
let arguments = (ins Transform_ConcreteOpType<"linalg.unpack">:$target,
@@ -733,7 +733,8 @@ def LowerUnPackOp : Op<Transform_Dialect, "structured.lower_unpack", [
let results = (outs Transform_ConcreteOpType<"tensor.empty">:$empty_op,
Transform_ConcreteOpType<"linalg.transpose">:$transpose_op,
Transform_ConcreteOpType<"tensor.collapse_shape">:$collapse_shape_op,
- Transform_ConcreteOpType<"tensor.extract_slice">:$extract_slice_op);
+ Transform_ConcreteOpType<"tensor.extract_slice">:$extract_slice_op,
+ Transform_ConcreteOpType<"linalg.copy">:$copy_op);
let assemblyFormat = [{
$target attr-dict `:` functional-type(operands, results)
}];
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index d1f313098a2c1..fb9cede670801 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1360,9 +1360,10 @@ struct LowerUnPackOpResult {
linalg::TransposeOp transposeOp;
tensor::CollapseShapeOp collapseShapeOp;
tensor::ExtractSliceOp extractSliceOp;
+ linalg::CopyOp copyOp;
};
-/// Rewrite pack as empty + transpose + reshape + extract_slice.
+/// Rewrite pack as empty + transpose + reshape + extract_slice + copy.
FailureOr<LowerUnPackOpResult>
lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp,
bool lowerUnpadLikeWithExtractSlice = true);
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index e945a15476b3a..309a4d989465d 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1556,6 +1556,7 @@ DiagnosedSilenceableFailure transform::LowerUnPackOp::applyToOne(
transformResults.push_back(res->transposeOp);
transformResults.push_back(res->collapseShapeOp);
transformResults.push_back(res->extractSliceOp);
+ transformResults.push_back(res->copyOp);
return DiagnosedSilenceableFailure::success();
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index eb3eb48a7fe34..2b4986aeac14f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -382,7 +382,8 @@ linalg::lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp,
rewriter.replaceOp(unPackOp, extractSliceOp->getResults());
return LowerUnPackOpResult{/*emptyOp=*/nullptr, /*transposeOp=*/nullptr,
- /*reshapeOp=*/nullptr, extractSliceOp};
+ /*reshapeOp=*/nullptr, extractSliceOp,
+ /*copyOp=*/nullptr};
}
// 1. Compute the permutation vector to shuffle packed shape into the shape
@@ -444,7 +445,8 @@ linalg::lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp,
// 7. Replace unPackOp by copyOp.
rewriter.replaceOp(unPackOp, copyOp->getResults());
- return LowerUnPackOpResult{emptyOp, transposeOp, reshapeOp, extractSliceOp};
+ return LowerUnPackOpResult{emptyOp, transposeOp, reshapeOp, extractSliceOp,
+ copyOp};
}
SmallVector<int64_t>
diff --git a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
index 9e7681d1a1b7d..b6fe67a9ae1f3 100644
--- a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
+++ b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
@@ -185,7 +185,8 @@ module attributes {transform.with_named_sequence} {
-> (!transform.op<"tensor.empty">,
!transform.op<"linalg.transpose">,
!transform.op<"tensor.collapse_shape">,
- !transform.op<"tensor.extract_slice">)
+ !transform.op<"tensor.extract_slice">,
+ !transform.op<"linalg.copy">)
transform.yield
}
}
@@ -220,7 +221,8 @@ module attributes {transform.with_named_sequence} {
-> (!transform.op<"tensor.empty">,
!transform.op<"linalg.transpose">,
!transform.op<"tensor.collapse_shape">,
- !transform.op<"tensor.extract_slice">)
+ !transform.op<"tensor.extract_slice">,
+ !transform.op<"linalg.copy">)
transform.yield
}
}
@@ -254,7 +256,8 @@ module attributes {transform.with_named_sequence} {
-> (!transform.op<"tensor.empty">,
!transform.op<"linalg.transpose">,
!transform.op<"tensor.collapse_shape">,
- !transform.op<"tensor.extract_slice">)
+ !transform.op<"tensor.extract_slice">,
+ !transform.op<"linalg.copy">)
transform.yield
}
}
@@ -286,7 +289,8 @@ module attributes {transform.with_named_sequence} {
-> (!transform.op<"tensor.empty">,
!transform.op<"linalg.transpose">,
!transform.op<"tensor.collapse_shape">,
- !transform.op<"tensor.extract_slice">)
+ !transform.op<"tensor.extract_slice">,
+ !transform.op<"linalg.copy">)
transform.yield
}
}
@@ -554,7 +558,8 @@ module attributes {transform.with_named_sequence} {
-> (!transform.op<"tensor.empty">,
!transform.op<"linalg.transpose">,
!transform.op<"tensor.collapse_shape">,
- !transform.op<"tensor.extract_slice">)
+ !transform.op<"tensor.extract_slice">,
+ !transform.op<"linalg.copy">)
transform.yield
}
}
@@ -594,7 +599,8 @@ module attributes {transform.with_named_sequence} {
-> (!transform.op<"tensor.empty">,
!transform.op<"linalg.transpose">,
!transform.op<"tensor.collapse_shape">,
- !transform.op<"tensor.extract_slice">)
+ !transform.op<"tensor.extract_slice">,
+ !transform.op<"linalg.copy">)
transform.yield
}
}
@@ -637,7 +643,8 @@ module attributes {transform.with_named_sequence} {
-> (!transform.op<"tensor.empty">,
!transform.op<"linalg.transpose">,
!transform.op<"tensor.collapse_shape">,
- !transform.op<"tensor.extract_slice">)
+ !transform.op<"tensor.extract_slice">,
+ !transform.op<"linalg.copy">)
transform.yield
}
}
@@ -677,7 +684,8 @@ module attributes {transform.with_named_sequence} {
-> (!transform.op<"tensor.empty">,
!transform.op<"linalg.transpose">,
!transform.op<"tensor.collapse_shape">,
- !transform.op<"tensor.extract_slice">)
+ !transform.op<"tensor.extract_slice">,
+ !transform.op<"linalg.copy">)
transform.yield
}
}
@@ -711,7 +719,8 @@ module attributes {transform.with_named_sequence} {
-> (!transform.op<"tensor.empty">,
!transform.op<"linalg.transpose">,
!transform.op<"tensor.collapse_shape">,
- !transform.op<"tensor.extract_slice">)
+ !transform.op<"tensor.extract_slice">,
+ !transform.op<"linalg.copy">)
transform.yield
}
}
diff --git a/mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir b/mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir
index d72ab080f3c5c..dc4d2e434c0db 100644
--- a/mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir
+++ b/mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir
@@ -159,7 +159,8 @@ module {
-> (!transform.op<"tensor.empty">,
!transform.op<"linalg.transpose">,
!transform.op<"tensor.collapse_shape">,
- !transform.op<"tensor.extract_slice">)
+ !transform.op<"tensor.extract_slice">,
+ !transform.op<"linalg.copy">)
%root = transform.structured.match ops{["linalg.generic"]} in %arg1
: (!transform.any_op) -> !transform.any_op
@@ -220,7 +221,8 @@ module {
-> (!transform.op<"tensor.empty">,
!transform.op<"linalg.transpose">,
!transform.op<"tensor.collapse_shape">,
- !transform.op<"tensor.extract_slice">)
+ !transform.op<"tensor.extract_slice">,
+ !transform.op<"linalg.copy">)
%root = transform.structured.match ops{["linalg.generic"]} in %arg1
: (!transform.any_op) -> !transform.any_op
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/pack-unpack-mmt4d.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/pack-unpack-mmt4d.mlir
index a7bb039b04102..08dbe7c0ef345 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/pack-unpack-mmt4d.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/pack-unpack-mmt4d.mlir
@@ -159,7 +159,8 @@ module @transforms attributes { transform.with_named_sequence } {
-> (!transform.op<"tensor.empty">,
!transform.op<"linalg.transpose">,
!transform.op<"tensor.collapse_shape">,
- !transform.op<"tensor.extract_slice">)
+ !transform.op<"tensor.extract_slice">,
+ !transform.op<"linalg.copy">)
transform.yield
}
``````````
</details>
https://github.com/llvm/llvm-project/pull/183744
More information about the Mlir-commits mailing list