[Mlir-commits] [mlir] bbf1d80 - [mlir][tensor] Fix transpose permutation in tensor.pack generalization pattern

Quinn Dawkins llvmlistbot at llvm.org
Wed Feb 22 11:53:56 PST 2023


Author: Quinn Dawkins
Date: 2023-02-22T14:49:49-05:00
New Revision: bbf1d80d67db5076e4cd02caa754c0688239fe76

URL: https://github.com/llvm/llvm-project/commit/bbf1d80d67db5076e4cd02caa754c0688239fe76
DIFF: https://github.com/llvm/llvm-project/commit/bbf1d80d67db5076e4cd02caa754c0688239fe76.diff

LOG: [mlir][tensor] Fix transpose permutation in tensor.pack generalization pattern

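The generalization pattern for tensor.pack used the normalized inner
permutation directly, but normalizing inverts the innerDimsPos
permutation. As a result, the transpose op produced by the
generalization was incorrect whenever innerDimsPos is not its own
inverse (e.g. inner_dims_pos = [1, 2, 0] yielded permutation = [2, 0, 1]
instead of [1, 2, 0]).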

Differential Revision: https://reviews.llvm.org/D144425

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
    mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 3b2cd0d29d95..96c29e464aea 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -578,6 +578,13 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
   // 2. Transpose the tile to match the inner tile order.
   SmallVector<int64_t> perm =
       getPackUnpackNormalizedInnerPerm(srcRank, packOp.getInnerDimsPos());
+  // The permutation is inverted when normalizing so invert back to match the
+  // ordering in the pack op.
+  perm = invertPermutationVector(perm);
+
+  LLVM_DEBUG(DBGS() << "Pack permutation: " << packOp << "\n";
+             llvm::interleaveComma(perm, DBGS() << "perm: "); DBGSNL(););
+
   SmallVector<int64_t> transpShape = readShape;
   applyPermutationToVector<int64_t>(transpShape, perm);
 

diff  --git a/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir b/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir
index 3feb1657aad0..8e9b77ed6f67 100644
--- a/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir
@@ -58,3 +58,21 @@ func.func @simple_NC_to_CNnc(%arg0: tensor<32x8xf32>, %arg1: tensor<1x1x32x8xf32
 // CHECK:         %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]]
 // CHECK-SAME:      [0, 0, 0, 0] [1, 1, 32, 8] [1, 1, 1, 1]
 // CHECK:         return %[[INSERT]]
+
+// -----
+
+func.func @simple_CHW_to_CHWhwc(%arg0: tensor<3x5x7xf32>, %arg1: tensor<1x1x1x5x7x3xf32>) -> tensor<1x1x1x5x7x3xf32> {
+  %0 = tensor.pack %arg0 inner_dims_pos = [1, 2, 0] inner_tiles = [5, 7, 3] into %arg1 : tensor<3x5x7xf32> -> tensor<1x1x1x5x7x3xf32>
+  return %0 : tensor<1x1x1x5x7x3xf32>
+}
+// CHECK-LABEL: func.func @simple_CHW_to_CHWhwc
+// CHECK-SAME:    %[[SRC:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[DEST:[a-zA-Z0-9]+]]
+// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<5x7x3xf32>
+// CHECK:         %[[TRANSP:.+]] =  linalg.transpose
+// CHECK-SAME:      ins(%[[SRC]] : tensor<3x5x7xf32>)
+// CHECK-SAME:      outs(%[[EMPTY]] : tensor<5x7x3xf32>)
+// CHECK-SAME:      permutation = [1, 2, 0]
+// CHECK:         %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]]
+// CHECK-SAME:      [0, 0, 0, 0, 0, 0] [1, 1, 1, 5, 7, 3] [1, 1, 1, 1, 1, 1]
+// CHECK:         return %[[INSERT]]

