[Mlir-commits] [mlir] d2f2ef8 - [MLIR][Linalg] Respect DPS in `lower_unpack`
Lorenzo Chelini
llvmlistbot at llvm.org
Tue Aug 22 00:38:53 PDT 2023
Author: Lorenzo Chelini
Date: 2023-08-22T09:38:48+02:00
New Revision: d2f2ef84e82f2fa4cc47fabbb2d7a0ab011a654d
URL: https://github.com/llvm/llvm-project/commit/d2f2ef84e82f2fa4cc47fabbb2d7a0ab011a654d
DIFF: https://github.com/llvm/llvm-project/commit/d2f2ef84e82f2fa4cc47fabbb2d7a0ab011a654d.diff
LOG: [MLIR][Linalg] Respect DPS in `lower_unpack`
`tensor.unpack` implements the DPS (Destination Passing Style) interface
and expects its result to be "stored" in the `outs` operand. The current
decomposition does not honor this: its final operation is a
`tensor.extract_slice`, which does not implement DPS. Add a `linalg.copy`
into the unpack destination to fix the problem.
Reviewed By: springerm
Differential Revision: https://reviews.llvm.org/D158393
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
mlir/test/Dialect/Linalg/transform-lower-pack.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index e39f8470e9c7a5..12f0bed76031af 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -454,7 +454,7 @@ FailureOr<LowerUnPackOpResult> linalg::lowerUnPack(RewriterBase &rewriter,
loc, collapsedType, transposeOp->getResult(0),
packingMetadata.reassociations);
- // 6. ExtractSlice
+ // 6. ExtractSlice.
int64_t destRank = destTensorType.getRank();
auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
loc, destTensorType, reshapeOp->getResult(0),
@@ -462,8 +462,12 @@ FailureOr<LowerUnPackOpResult> linalg::lowerUnPack(RewriterBase &rewriter,
tensor::getMixedSizes(rewriter, loc, unPackOp->getResult(0)),
SmallVector<OpFoldResult>(destRank, one));
- // 7. Replace unPackOp by extractSliceOp.
- rewriter.replaceOp(unPackOp, extractSliceOp->getResults());
+ // 7. Inject a copy to preserve DPS.
+ auto copyOp = rewriter.create<linalg::CopyOp>(
+ loc, extractSliceOp->getResult(0), unPackOp.getDest());
+
+ // 8. Replace unPackOp by extractSliceOp.
+ rewriter.replaceOp(unPackOp, copyOp->getResults());
return LowerUnPackOpResult{emptyOp, transposeOp, reshapeOp, extractSliceOp};
}
diff --git a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
index 374ea994ed496b..c11d301140039a 100644
--- a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
+++ b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
@@ -123,16 +123,18 @@ transform.sequence failures(propagate) {
// CHECK-LABEL: func.func @unpack(
func.func @unpack(%arg0: tensor<17x2x16x16x32x8xf32>, %arg1: tensor<129x47x16x16xf32>) -> tensor<129x47x16x16xf32> {
%cst_0 = arith.constant 0.0 : f32
-
- // CHECK: tensor.empty() : tensor<17x8x2x32x16x16xf32>
- // CHECK: linalg.transpose
- // CHECK-SAME: ins(%{{.*}} : tensor<17x2x16x16x32x8xf32>)
- // CHECK-SAME: outs(%{{.*}} : tensor<17x8x2x32x16x16xf32>)
+ // CHECK-SAME: %[[ARG0:.*]]: tensor<17x2x16x16x32x8xf32>, %[[ARG1:.*]]: tensor<129x47x16x16xf32>
+ // CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<17x8x2x32x16x16xf32>
+ // CHECK: %[[TRAN:.*]] = linalg.transpose
+ // CHECK-SAME: ins(%[[ARG0]] : tensor<17x2x16x16x32x8xf32>)
+ // CHECK-SAME: outs(%[[EMPTY]] : tensor<17x8x2x32x16x16xf32>)
// CHECK-SAME: permutation = [0, 5, 1, 4, 2, 3]
- // CHECK: tensor.collapse_shape {{.*}}[0, 1], [2, 3], [4], [5]]
+ // CHECK: %[[CLP:.*]] = tensor.collapse_shape %[[TRAN]] {{\[}}[0, 1], [2, 3], [4], [5]]
// CHECK-SAME: : tensor<17x8x2x32x16x16xf32> into tensor<136x64x16x16xf32>
- // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0, 0] [129, 47, 16, 16] [1, 1, 1, 1]
+ // CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[CLP]][0, 0, 0, 0] [129, 47, 16, 16] [1, 1, 1, 1]
// CHECK-SAME: : tensor<136x64x16x16xf32> to tensor<129x47x16x16xf32>
+ // CHECK: linalg.copy ins(%[[SLICE]] : tensor<129x47x16x16xf32>)
+ // CHECK-SAME: outs(%[[ARG1]] : tensor<129x47x16x16xf32>)
%pack = tensor.unpack %arg0 inner_dims_pos = [1, 0] inner_tiles = [32, 8] into %arg1
: tensor<17x2x16x16x32x8xf32> -> tensor<129x47x16x16xf32>
return %pack : tensor<129x47x16x16xf32>
More information about the Mlir-commits
mailing list