[Mlir-commits] [mlir] bbf1d80 - [mlir][tensor] Fix transpose permutation in tensor.pack generalization pattern
Quinn Dawkins
llvmlistbot at llvm.org
Wed Feb 22 11:53:56 PST 2023
Author: Quinn Dawkins
Date: 2023-02-22T14:49:49-05:00
New Revision: bbf1d80d67db5076e4cd02caa754c0688239fe76
URL: https://github.com/llvm/llvm-project/commit/bbf1d80d67db5076e4cd02caa754c0688239fe76
DIFF: https://github.com/llvm/llvm-project/commit/bbf1d80d67db5076e4cd02caa754c0688239fe76.diff
LOG: [mlir][tensor] Fix transpose permutation in tensor.pack generalization pattern
The generalization pattern for tensor.pack was inverting the
innerDimsPos permutation when normalizing. As a result, the transpose op
produced by the generalization was incorrect whenever the permutation
was not its own inverse.
Differential Revision: https://reviews.llvm.org/D144425
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 3b2cd0d29d95..96c29e464aea 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -578,6 +578,13 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
// 2. Transpose the tile to match the inner tile order.
SmallVector<int64_t> perm =
getPackUnpackNormalizedInnerPerm(srcRank, packOp.getInnerDimsPos());
+ // The permutation is inverted when normalizing so invert back to match the
+ // ordering in the pack op.
+ perm = invertPermutationVector(perm);
+
+ LLVM_DEBUG(DBGS() << "Pack permutation: " << packOp << "\n";
+ llvm::interleaveComma(perm, DBGS() << "perm: "); DBGSNL(););
+
SmallVector<int64_t> transpShape = readShape;
applyPermutationToVector<int64_t>(transpShape, perm);
diff --git a/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir b/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir
index 3feb1657aad0..8e9b77ed6f67 100644
--- a/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir
@@ -58,3 +58,21 @@ func.func @simple_NC_to_CNnc(%arg0: tensor<32x8xf32>, %arg1: tensor<1x1x32x8xf32
// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]]
// CHECK-SAME: [0, 0, 0, 0] [1, 1, 32, 8] [1, 1, 1, 1]
// CHECK: return %[[INSERT]]
+
+// -----
+
+func.func @simple_CHW_to_CHWhwc(%arg0: tensor<3x5x7xf32>, %arg1: tensor<1x1x1x5x7x3xf32>) -> tensor<1x1x1x5x7x3xf32> {
+ %0 = tensor.pack %arg0 inner_dims_pos = [1, 2, 0] inner_tiles = [5, 7, 3] into %arg1 : tensor<3x5x7xf32> -> tensor<1x1x1x5x7x3xf32>
+ return %0 : tensor<1x1x1x5x7x3xf32>
+}
+// CHECK-LABEL: func.func @simple_CHW_to_CHWhwc
+// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<5x7x3xf32>
+// CHECK: %[[TRANSP:.+]] = linalg.transpose
+// CHECK-SAME: ins(%[[SRC]] : tensor<3x5x7xf32>)
+// CHECK-SAME: outs(%[[EMPTY]] : tensor<5x7x3xf32>)
+// CHECK-SAME: permutation = [1, 2, 0]
+// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]]
+// CHECK-SAME: [0, 0, 0, 0, 0, 0] [1, 1, 1, 5, 7, 3] [1, 1, 1, 1, 1, 1]
+// CHECK: return %[[INSERT]]
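
The new test's expected shapes can be checked by hand. The following is a minimal standalone sketch, not the MLIR implementation: invertPerm and applyPerm are hypothetical stand-ins written to mirror the semantics assumed here for invertPermutationVector (inv[perm[i]] = i) and applyPermutationToVector (out[i] = in[perm[i]]). It also assumes that, before the fix, the normalized permutation for inner_dims_pos = [1, 2, 0] came out as the inverse, [2, 0, 1], as the commit message implies.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for an inverse-permutation helper:
// the result satisfies inv[perm[i]] == i for every i.
std::vector<int64_t> invertPerm(const std::vector<int64_t> &perm) {
  std::vector<int64_t> inv(perm.size());
  for (std::size_t i = 0; i < perm.size(); ++i)
    inv[perm[i]] = static_cast<int64_t>(i);
  return inv;
}

// Hypothetical stand-in for applying a permutation to a shape:
// output dimension i takes its size from input dimension perm[i].
std::vector<int64_t> applyPerm(const std::vector<int64_t> &in,
                               const std::vector<int64_t> &perm) {
  std::vector<int64_t> out;
  out.reserve(perm.size());
  for (int64_t idx : perm)
    out.push_back(in[idx]);
  return out;
}
```

Under these assumptions, applyPerm({3, 5, 7}, {1, 2, 0}) yields {5, 7, 3}, which matches the tensor<5x7x3xf32> tile in the CHECK lines, while the inverse {2, 0, 1} yields {7, 3, 5}. A rank-2 swap such as {1, 0} is its own inverse, which would explain why the existing simple_NC_to_CNnc test did not expose the problem.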