[Mlir-commits] [mlir] 44f7356 - [MLIR][Tensor] Add canonicalization for UnpackOp
Lorenzo Chelini
llvmlistbot at llvm.org
Thu Dec 1 06:17:55 PST 2022
Author: Lorenzo Chelini
Date: 2022-12-01T15:17:50+01:00
New Revision: 44f73560057609783575d8034cca3f40ae3960eb
URL: https://github.com/llvm/llvm-project/commit/44f73560057609783575d8034cca3f40ae3960eb
DIFF: https://github.com/llvm/llvm-project/commit/44f73560057609783575d8034cca3f40ae3960eb.diff
LOG: [MLIR][Tensor] Add canonicalization for UnpackOp
unpack(pack(x)) -> x
Reviewed By: hanchung
Differential Revision: https://reviews.llvm.org/D138917
Added:
Modified:
mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/test/Dialect/Tensor/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index c9c5fb187fe33..c20b92a01ab97 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1824,6 +1824,7 @@ def Tensor_UnPackOp : Tensor_RelayoutOp<"unpack"> {
}];
let extraClassDeclaration = commonExtraClassDeclaration;
+ let hasCanonicalizeMethod = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 23bfb232f47fa..9b7cad34c5d4f 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -3419,6 +3419,20 @@ Speculation::Speculatability UnPackOp::getSpeculatability() {
return Speculation::Speculatable;
}
+/// unpack(pack(x)) -> x
+///
+/// Folds an unpack whose source is produced by a pack back to the pack's
+/// source, provided both ops describe exactly the same relayout:
+///   * identical `inner_dims_pos` and `outer_dims_perm`,
+///   * identical inner tiles (static sizes and dynamic SSA operands), and
+///   * the unpack produces exactly the type the pack consumed.
+/// NOTE(review): for dynamically shaped source/dest tensors, equal types do
+/// not prove the runtime sizes agree; confirm producers guarantee that.
+LogicalResult UnPackOp::canonicalize(UnPackOp unpackOp,
+                                     PatternRewriter &rewriter) {
+  PackOp packOp = unpackOp.getSource().getDefiningOp<tensor::PackOp>();
+  if (!packOp || packOp.getDestType() != unpackOp.getSourceType())
+    return failure();
+  // Replacing the unpack with the pack source is only type-correct when the
+  // unpack result has exactly the type of the tensor that was packed.
+  if (packOp.getSourceType() != unpackOp.getDestType())
+    return failure();
+  if (packOp.getInnerDimsPos() != unpackOp.getInnerDimsPos())
+    return failure();
+  if (packOp.getOuterDimsPerm() != unpackOp.getOuterDimsPerm())
+    return failure();
+  // Tile sizes must match too: e.g. swapped dynamic tile operands yield the
+  // same packed type and attributes but a different data layout.
+  if (packOp.getMixedTiles() != unpackOp.getMixedTiles())
+    return failure();
+  rewriter.replaceOp(unpackOp, packOp.getSource());
+  return success();
+}
+
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 92e329d916a20..04e7207434b0f 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1678,3 +1678,62 @@ func.func @collapse_expand_fold_to_cast(%t: tensor<?xf32>) -> (tensor<?xf32>)
%1 = tensor.collapse_shape %0 [[0, 1]] : tensor<1x?xf32> into tensor<?xf32>
return %1 : tensor<?xf32>
}
+
+// -----
+
+// Chain: NC -> NCnc -> NCnc -> NC
+// Pack and unpack describe the same relayout, so the pair folds away entirely.
+// CHECK-LABEL: func.func @unpack_pack(
+// CHECK-SAME:    %[[T:.+]]: tensor<128x128xf32>)
+// CHECK-NEXT:    return %[[T]] : tensor<128x128xf32>
+func.func @unpack_pack(%t: tensor<128x128xf32>) -> tensor<128x128xf32> {
+  %tensor_empty = tensor.empty() : tensor<16x16x8x8xf32>
+  %packed = tensor.pack %t inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %tensor_empty : tensor<128x128xf32> -> tensor<16x16x8x8xf32>
+  %tensor_empty1 = tensor.empty() : tensor<128x128xf32>
+  %unpacked = tensor.unpack %packed inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %tensor_empty1 : tensor<16x16x8x8xf32> -> tensor<128x128xf32>
+  return %unpacked : tensor<128x128xf32>
+}
+
+// -----
+
+// Chain: NC -> NCcn -> NCnc -> NC
+// inner_dims_pos differs between pack and unpack, so the pair must not fold.
+// CHECK-LABEL: func.func @unpack_pack(
+// CHECK-SAME:    %[[T:.+]]: tensor<128x128xf32>)
+// CHECK:         %[[PACKED:.+]] = tensor.pack %[[T]]
+// CHECK:         %[[UNPACKED:.+]] = tensor.unpack %[[PACKED]]
+// CHECK:         return %[[UNPACKED]] : tensor<128x128xf32>
+func.func @unpack_pack(%t: tensor<128x128xf32>) -> tensor<128x128xf32> {
+  %tensor_empty = tensor.empty() : tensor<16x16x8x8xf32>
+  %packed = tensor.pack %t inner_dims_pos = [1, 0] inner_tiles = [8, 8] into %tensor_empty : tensor<128x128xf32> -> tensor<16x16x8x8xf32>
+  %tensor_empty1 = tensor.empty() : tensor<128x128xf32>
+  %unpacked = tensor.unpack %packed inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %tensor_empty1 : tensor<16x16x8x8xf32> -> tensor<128x128xf32>
+  return %unpacked : tensor<128x128xf32>
+}
+
+// -----
+
+// Chain: NC -> CNcn -> NCnc -> NC
+// outer_dims_perm and inner_dims_pos differ between pack and unpack, so the
+// pair must not fold.
+// CHECK-LABEL: func.func @unpack_pack(
+// CHECK-SAME:    %[[T:.+]]: tensor<128x128xf32>)
+// CHECK:         %[[PACKED:.+]] = tensor.pack %[[T]]
+// CHECK:         %[[UNPACKED:.+]] = tensor.unpack %[[PACKED]]
+// CHECK:         return %[[UNPACKED]] : tensor<128x128xf32>
+func.func @unpack_pack(%t: tensor<128x128xf32>) -> tensor<128x128xf32> {
+  %tensor_empty = tensor.empty() : tensor<16x16x8x8xf32>
+  %packed = tensor.pack %t outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [8, 8] into %tensor_empty : tensor<128x128xf32> -> tensor<16x16x8x8xf32>
+  %tensor_empty1 = tensor.empty() : tensor<128x128xf32>
+  %unpacked = tensor.unpack %packed inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %tensor_empty1 : tensor<16x16x8x8xf32> -> tensor<128x128xf32>
+  return %unpacked : tensor<128x128xf32>
+}
+
+// -----
+
+// Chain: NC -> NCnc -> NCnc -> NC, with dynamic inner tiles.
+// Both ops use the same tile SSA values in the same order, so the pair folds.
+// CHECK-LABEL: func.func @unpack_pack(
+// CHECK-SAME:    %[[T:.+]]: tensor<128x128xf32>,
+// CHECK-NEXT:    return %[[T]] : tensor<128x128xf32>
+func.func @unpack_pack(%t: tensor<128x128xf32>, %tile1: index, %tile2: index) -> tensor<128x128xf32> {
+  %tensor_empty = tensor.empty(%tile1, %tile2) : tensor<16x16x?x?xf32>
+  %packed = tensor.pack %t inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty : tensor<128x128xf32> -> tensor<16x16x?x?xf32>
+  %tensor_empty1 = tensor.empty() : tensor<128x128xf32>
+  %unpacked = tensor.unpack %packed inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty1 : tensor<16x16x?x?xf32> -> tensor<128x128xf32>
+  return %unpacked : tensor<128x128xf32>
+}
More information about the Mlir-commits
mailing list