[Mlir-commits] [mlir] d2f2ef8 - [MLIR][Linalg] Respect DPS in `lower_unpack`

Lorenzo Chelini llvmlistbot at llvm.org
Tue Aug 22 00:38:53 PDT 2023


Author: Lorenzo Chelini
Date: 2023-08-22T09:38:48+02:00
New Revision: d2f2ef84e82f2fa4cc47fabbb2d7a0ab011a654d

URL: https://github.com/llvm/llvm-project/commit/d2f2ef84e82f2fa4cc47fabbb2d7a0ab011a654d
DIFF: https://github.com/llvm/llvm-project/commit/d2f2ef84e82f2fa4cc47fabbb2d7a0ab011a654d.diff

LOG: [MLIR][Linalg] Respect DPS in `lower_unpack`

`tensor.unpack` implements the DPS (Destination Passing Style) interface
and expects the result to be "stored" in the `outs` operand, but this is
not the case with the current decomposition as the final operation is a
`tensor.extract_slice` that does not implement DPS. Add a `linalg.copy`
to fix the problem.

Reviewed By: springerm

Differential Revision: https://reviews.llvm.org/D158393

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
    mlir/test/Dialect/Linalg/transform-lower-pack.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index e39f8470e9c7a5..12f0bed76031af 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -454,7 +454,7 @@ FailureOr<LowerUnPackOpResult> linalg::lowerUnPack(RewriterBase &rewriter,
       loc, collapsedType, transposeOp->getResult(0),
       packingMetadata.reassociations);
 
-  // 6. ExtractSlice
+  // 6. ExtractSlice.
   int64_t destRank = destTensorType.getRank();
   auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
       loc, destTensorType, reshapeOp->getResult(0),
@@ -462,8 +462,12 @@ FailureOr<LowerUnPackOpResult> linalg::lowerUnPack(RewriterBase &rewriter,
       tensor::getMixedSizes(rewriter, loc, unPackOp->getResult(0)),
       SmallVector<OpFoldResult>(destRank, one));
 
-  // 7. Replace unPackOp by extractSliceOp.
-  rewriter.replaceOp(unPackOp, extractSliceOp->getResults());
+  // 7. Inject a copy to preserve DPS.
+  auto copyOp = rewriter.create<linalg::CopyOp>(
+      loc, extractSliceOp->getResult(0), unPackOp.getDest());
+
+  // 8. Replace unPackOp by copyOp.
+  rewriter.replaceOp(unPackOp, copyOp->getResults());
 
   return LowerUnPackOpResult{emptyOp, transposeOp, reshapeOp, extractSliceOp};
 }

diff  --git a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
index 374ea994ed496b..c11d301140039a 100644
--- a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
+++ b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
@@ -123,16 +123,18 @@ transform.sequence failures(propagate) {
 // CHECK-LABEL: func.func @unpack(
 func.func @unpack(%arg0: tensor<17x2x16x16x32x8xf32>, %arg1: tensor<129x47x16x16xf32>) -> tensor<129x47x16x16xf32> {
   %cst_0 = arith.constant 0.0 : f32
-
-  //      CHECK: tensor.empty() : tensor<17x8x2x32x16x16xf32>
-  //      CHECK: linalg.transpose
-  // CHECK-SAME:    ins(%{{.*}} : tensor<17x2x16x16x32x8xf32>)
-  // CHECK-SAME:   outs(%{{.*}} : tensor<17x8x2x32x16x16xf32>)
+  // CHECK-SAME: %[[ARG0:.*]]: tensor<17x2x16x16x32x8xf32>, %[[ARG1:.*]]: tensor<129x47x16x16xf32>
+  //      CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<17x8x2x32x16x16xf32>
+  //      CHECK: %[[TRAN:.*]] = linalg.transpose
+  // CHECK-SAME:    ins(%[[ARG0]] : tensor<17x2x16x16x32x8xf32>)
+  // CHECK-SAME:   outs(%[[EMPTY]] : tensor<17x8x2x32x16x16xf32>)
   // CHECK-SAME:   permutation = [0, 5, 1, 4, 2, 3]
-  //      CHECK: tensor.collapse_shape {{.*}}[0, 1], [2, 3], [4], [5]]
+  //      CHECK: %[[CLP:.*]] = tensor.collapse_shape %[[TRAN]] {{\[}}[0, 1], [2, 3], [4], [5]]
   // CHECK-SAME:   : tensor<17x8x2x32x16x16xf32> into tensor<136x64x16x16xf32>
-  //      CHECK: tensor.extract_slice %{{.*}}[0, 0, 0, 0] [129, 47, 16, 16] [1, 1, 1, 1]
+  //      CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[CLP]][0, 0, 0, 0] [129, 47, 16, 16] [1, 1, 1, 1]
   // CHECK-SAME:   : tensor<136x64x16x16xf32> to tensor<129x47x16x16xf32>
+  //      CHECK: linalg.copy ins(%[[SLICE]] : tensor<129x47x16x16xf32>)
+  // CHECK-SAME:        outs(%[[ARG1]] : tensor<129x47x16x16xf32>)
   %pack = tensor.unpack %arg0 inner_dims_pos = [1, 0] inner_tiles = [32, 8] into %arg1
     : tensor<17x2x16x16x32x8xf32> -> tensor<129x47x16x16xf32>
   return %pack : tensor<129x47x16x16xf32>


        


More information about the Mlir-commits mailing list