[mlir] [llvm] [mlir][tensor] Fold linalg transpose with tensor pack (PR #74206)

via llvm-commits llvm-commits at lists.llvm.org
Sat Dec 2 12:13:33 PST 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-tensor

Author: Prathamesh Tagore (meshtag)

<details>
<summary>Changes</summary>

Partial fix for https://github.com/openxla/iree/issues/15367

---
Full diff: https://github.com/llvm/llvm-project/pull/74206.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp (+78-1) 
- (modified) mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir (+50) 


``````````diff
diff --git a/mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp
index 9eac3e5c7ef91..47d85a6f4f9a5 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp
@@ -6,6 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tensor/Transforms/Transforms.h"
 #include "mlir/IR/PatternMatch.h"
@@ -81,10 +82,86 @@ struct FoldUnpackWithExtractSliceOp : public OpRewritePattern<ExtractSliceOp> {
     return success();
   }
 };
+
+/// Fold 'pack' -> 'transpose' into a single 'pack'. 'pack' can already
+/// permute its outer (tiled) dimensions via `outer_dims_perm`, so a consumer
+/// transpose that only permutes those outer dimensions is absorbed into it.
+/// Transposes that move any of the trailing inner-tile dimensions cannot be
+/// expressed by 'pack' and are left untouched.
+struct FoldProducerPackWithConsumerLinalgTransposeOp
+    : public OpRewritePattern<linalg::TransposeOp> {
+  using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
+                                PatternRewriter &rewriter) const override {
+    auto packOp = transposeOp.getInput().getDefiningOp<PackOp>();
+    if (!packOp)
+      return failure();
+
+    // The new 'pack' writes directly into the transpose's init operand, so
+    // the transpose must produce a tensor result.
+    if (!isa<RankedTensorType>(transposeOp.getInit().getType()))
+      return failure();
+
+    // A packed tensor holds one outer dimension per source dimension,
+    // followed by the inner-tile dimensions.
+    int64_t numOuterDims = packOp.getSourceRank();
+    ArrayRef<int64_t> transposePerm = transposeOp.getPermutation();
+
+    // 'outer_dims_perm' cannot reorder inner-tile dimensions: bail out unless
+    // the transpose keeps each of them in place. This also guarantees that
+    // transposePerm[i] < numOuterDims for every outer index i below.
+    int64_t packedRank = static_cast<int64_t>(transposePerm.size());
+    for (int64_t i = numOuterDims; i < packedRank; ++i) {
+      if (transposePerm[i] != i)
+        return failure();
+    }
+
+    // Outer dim `i` of the transpose result is outer dim `transposePerm[i]`
+    // of the pack result, which tiles source dim
+    // `packOuterDimsPerm[transposePerm[i]]`. An absent `outer_dims_perm` is
+    // the identity permutation.
+    ArrayRef<int64_t> packOuterDimsPerm = packOp.getOuterDimsPerm();
+    SmallVector<int64_t> newOuterDimsPerm;
+    newOuterDimsPerm.reserve(numOuterDims);
+    for (int64_t i = 0; i < numOuterDims; ++i) {
+      int64_t packedDim = transposePerm[i];
+      newOuterDimsPerm.push_back(packOuterDimsPerm.empty()
+                                     ? packedDim
+                                     : packOuterDimsPerm[packedDim]);
+    }
+
+    // Reuse the transpose init as destination: it already has the folded
+    // result type (including any dynamic sizes). The padding value of the
+    // original 'pack' must be carried over.
+    rewriter.replaceOpWithNewOp<PackOp>(
+        transposeOp, packOp.getSource(), transposeOp.getInit(),
+        packOp.getInnerDimsPos(), packOp.getMixedTiles(),
+        packOp.getPaddingValue(), newOuterDimsPerm);
+    return success();
+  }
+};
+
+/// Fold 'transpose' -> 'pack' into a single 'pack' that reads the transpose
+/// input directly. Every source-dimension index stored on the 'pack'
+/// (`outer_dims_perm` and `inner_dims_pos`) is remapped through the
+/// transpose permutation.
+struct FoldConsumerPackWithProducerLinalgTransposeOp
+    : public OpRewritePattern<PackOp> {
+  using OpRewritePattern<PackOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(PackOp packOp,
+                                PatternRewriter &rewriter) const override {
+    auto transposeOp = packOp.getSource().getDefiningOp<linalg::TransposeOp>();
+    if (!transposeOp)
+      return failure();
+
+    // The transpose input becomes the new 'pack' source, so it must be a
+    // tensor.
+    if (!isa<RankedTensorType>(transposeOp.getInput().getType()))
+      return failure();
+
+    // Dim `d` of the pack source (the transposed tensor) is dim
+    // `transposePerm[d]` of the transpose input.
+    ArrayRef<int64_t> transposePerm = transposeOp.getPermutation();
+    int64_t srcRank = packOp.getSourceRank();
+
+    // An absent `outer_dims_perm` is the identity permutation.
+    ArrayRef<int64_t> packOuterDimsPerm = packOp.getOuterDimsPerm();
+    SmallVector<int64_t> newOuterDimsPerm;
+    newOuterDimsPerm.reserve(srcRank);
+    for (int64_t i = 0; i < srcRank; ++i) {
+      int64_t srcDim = packOuterDimsPerm.empty() ? i : packOuterDimsPerm[i];
+      newOuterDimsPerm.push_back(transposePerm[srcDim]);
+    }
+
+    // `inner_dims_pos` also indexes the pack source, so it is remapped too;
+    // otherwise the wrong dimensions of the transpose input would be tiled.
+    SmallVector<int64_t> newInnerDimsPos;
+    newInnerDimsPos.reserve(packOp.getInnerDimsPos().size());
+    for (int64_t pos : packOp.getInnerDimsPos())
+      newInnerDimsPos.push_back(transposePerm[pos]);
+
+    // The folded 'pack' yields exactly the original result type, so the
+    // original destination (which may have dynamic sizes) and padding value
+    // are reused as-is.
+    rewriter.replaceOpWithNewOp<PackOp>(
+        packOp, transposeOp.getInput(), packOp.getDest(), newInnerDimsPos,
+        packOp.getMixedTiles(), packOp.getPaddingValue(), newOuterDimsPerm);
+    return success();
+  }
+};
 } // namespace
 
 void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns) {
-  patterns.insert<FoldUnpackWithExtractSliceOp, FoldPadWithPackOp>(
+  patterns.add<FoldUnpackWithExtractSliceOp, FoldPadWithPackOp,
+               FoldProducerPackWithConsumerLinalgTransposeOp,
+               FoldConsumerPackWithProducerLinalgTransposeOp>(
       patterns.getContext());
 }
 
diff --git a/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir b/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
index 5c75789665742..0b00c7fa7feb9 100644
--- a/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
+++ b/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
@@ -114,3 +114,53 @@ func.func @pad_pack_different_padding_value(%src: tensor<16641x16xf32>) -> tenso
 // CHECK-LABEL: func.func @pad_pack_different_padding_value
 // CHECK:         tensor.pad
 // CHECK:         tensor.pack
+
+// -----
+
+// The producer transpose is absorbed into the pack: permutation [2, 0, 1, 3]
+// composed with outer_dims_perm [0, 3, 1, 2] gives [2, 3, 0, 1], and
+// inner_dims_pos [3] is unchanged because permutation[3] == 3.
+func.func @linalg_transpose_tensor_pack_fold(%arg0: tensor<56x57x1x64xf32>) -> tensor<1x2x56x57x32xf32> {
+  %0 = tensor.empty() : tensor<1x56x57x64xf32>
+  %transposed = linalg.transpose
+    ins(%arg0 : tensor<56x57x1x64xf32>)
+    outs(%0 : tensor<1x56x57x64xf32>)
+    permutation = [2, 0, 1, 3]
+
+  %1 = tensor.empty() : tensor<1x2x56x57x32xf32>
+  %pack = tensor.pack %transposed
+    outer_dims_perm = [0, 3, 1, 2]
+    inner_dims_pos = [3]
+    inner_tiles = [32]
+    into %1 : tensor<1x56x57x64xf32> -> tensor<1x2x56x57x32xf32>
+  return %pack : tensor<1x2x56x57x32xf32>
+}
+// CHECK-LABEL: func @linalg_transpose_tensor_pack_fold(
+//  CHECK-SAME:     %[[ARG0:.+]]: tensor<56x57x1x64xf32>)
+//       CHECK:   %[[INIT:.+]] = tensor.empty() : tensor<1x2x56x57x32xf32>
+//       CHECK:   %[[PACK:.+]] = tensor.pack %[[ARG0]]
+//  CHECK-SAME:      outer_dims_perm = [2, 3, 0, 1]
+//  CHECK-SAME:      inner_dims_pos = [3] inner_tiles = [32]
+//  CHECK-SAME:       into %[[INIT]]
+//       CHECK:   return %[[PACK]]
+
+// -----
+
+// The consumer transpose [2, 3, 0, 1, 4] keeps the trailing inner-tile
+// dimension (index 4) in place, so it is absorbed into the pack:
+// outer_dims_perm [0, 1, 2, 3] permuted by [2, 3, 0, 1] gives [2, 3, 0, 1].
+func.func @tensor_pack_linalg_transpose_fold(%arg0: tensor<56x57x1x64xf32>) -> tensor<1x2x56x57x32xf32> {
+  %0 = tensor.empty() : tensor<56x57x1x2x32xf32>
+  %pack = tensor.pack %arg0
+    outer_dims_perm = [0, 1, 2, 3]
+    inner_dims_pos = [3]
+    inner_tiles = [32]
+    into %0 : tensor<56x57x1x64xf32> -> tensor<56x57x1x2x32xf32>
+
+  %1 = tensor.empty() : tensor<1x2x56x57x32xf32>
+  %transposed = linalg.transpose
+    ins(%pack : tensor<56x57x1x2x32xf32>)
+    outs(%1 : tensor<1x2x56x57x32xf32>)
+    permutation = [2, 3, 0, 1, 4]
+  return %transposed : tensor<1x2x56x57x32xf32>
+}
+// CHECK-LABEL: func @tensor_pack_linalg_transpose_fold(
+//  CHECK-SAME:     %[[ARG0:.+]]: tensor<56x57x1x64xf32>)
+//       CHECK:   %[[INIT:.+]] = tensor.empty() : tensor<1x2x56x57x32xf32>
+//       CHECK:   %[[PACK:.+]] = tensor.pack %[[ARG0]]
+//  CHECK-SAME:      outer_dims_perm = [2, 3, 0, 1]
+//  CHECK-SAME:      inner_dims_pos = [3] inner_tiles = [32]
+//  CHECK-SAME:       into %[[INIT]]
+//       CHECK:   return %[[PACK]]

``````````

</details>


https://github.com/llvm/llvm-project/pull/74206


More information about the llvm-commits mailing list