[Mlir-commits] [mlir] 44f7356 - [MLIR][Tensor] Add canonicalization for UnpackOp

Lorenzo Chelini llvmlistbot at llvm.org
Thu Dec 1 06:17:55 PST 2022


Author: Lorenzo Chelini
Date: 2022-12-01T15:17:50+01:00
New Revision: 44f73560057609783575d8034cca3f40ae3960eb

URL: https://github.com/llvm/llvm-project/commit/44f73560057609783575d8034cca3f40ae3960eb
DIFF: https://github.com/llvm/llvm-project/commit/44f73560057609783575d8034cca3f40ae3960eb.diff

LOG: [MLIR][Tensor] Add canonicalization for UnpackOp

unpack(pack(x)) -> x

Reviewed By: hanchung

Differential Revision: https://reviews.llvm.org/D138917

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
    mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
    mlir/test/Dialect/Tensor/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index c9c5fb187fe33..c20b92a01ab97 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1824,6 +1824,7 @@ def Tensor_UnPackOp : Tensor_RelayoutOp<"unpack"> {
   }];
 
   let extraClassDeclaration = commonExtraClassDeclaration;
+  let hasCanonicalizeMethod = 1;
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 23bfb232f47fa..9b7cad34c5d4f 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -3419,6 +3419,20 @@ Speculation::Speculatability UnPackOp::getSpeculatability() {
   return Speculation::Speculatable;
 }
 
+/// Fold unpack(pack(x)) -> x.
+///
+/// The fold is only applied when the unpack exactly undoes the pack:
+///  - `x` has the same type as the unpack result, so that `x` is a valid
+///    replacement for every user of the unpack,
+///  - the pack does not pad,
+///  - both ops use the same inner_dims_pos, outer_dims_perm and inner tiles.
+/// Returns failure() and leaves the IR untouched otherwise.
+LogicalResult UnPackOp::canonicalize(UnPackOp unpackOp,
+                                     PatternRewriter &rewriter) {
+  PackOp packOp = unpackOp.getSource().getDefiningOp<tensor::PackOp>();
+  if (!packOp)
+    return failure();
+  // The unpack source *is* the pack result, so comparing those two types can
+  // never fail. What has to match is the type of `x` and of the unpack result.
+  if (packOp.getSource().getType() != unpackOp.getType())
+    return failure();
+  // Be conservative with padded packs: the packed tensor then also carries
+  // padding elements, and with dynamic shapes the type check above does not
+  // show that the unpack extracts exactly `x`.
+  if (packOp.getPaddingValue())
+    return failure();
+  if (packOp.getInnerDimsPos() != unpackOp.getInnerDimsPos() ||
+      packOp.getOuterDimsPerm() != unpackOp.getOuterDimsPerm())
+    return failure();
+  // Tiles may be static or dynamic; dynamic tiles only compare equal when they
+  // are the same SSA values.
+  if (packOp.getMixedTiles() != unpackOp.getMixedTiles())
+    return failure();
+  rewriter.replaceOp(unpackOp, packOp.getSource());
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 92e329d916a20..04e7207434b0f 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1678,3 +1678,62 @@ func.func @collapse_expand_fold_to_cast(%t: tensor<?xf32>) -> (tensor<?xf32>)
   %1 = tensor.collapse_shape %0 [[0, 1]] : tensor<1x?xf32> into tensor<?xf32>
   return %1 : tensor<?xf32>
 }
+
+// -----
+
+// unpack(pack(%t)) where both ops use inner_dims_pos = [0, 1] and static
+// inner_tiles = [8, 8]. The CHECK lines expect the pair to fold away so that
+// the function returns its argument directly.
+// Chain: NC -> NCnc -> NCnc -> NC
+// CHECK: func.func @unpack_pack(
+// CHECK-SAME: %[[T:.+]]: tensor<128x128xf32>)
+// CHECK: return %[[T]] : tensor<128x128xf32>
+func.func @unpack_pack(%t: tensor<128x128xf32>) -> tensor<128x128xf32> {
+  %tensor_empty = tensor.empty() : tensor<16x16x8x8xf32>
+  %packed = tensor.pack %t inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %tensor_empty : tensor<128x128xf32> -> tensor<16x16x8x8xf32>
+  %tensor_empty1 = tensor.empty() : tensor<128x128xf32>
+  %unpacked = tensor.unpack %packed inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %tensor_empty1 : tensor<16x16x8x8xf32> -> tensor<128x128xf32>
+  return %unpacked : tensor<128x128xf32>
+}
+
+// -----
+
+// Negative test: the pack uses inner_dims_pos = [1, 0] but the unpack uses
+// inner_dims_pos = [0, 1], so the unpack does not undo the pack. The
+// CHECK-NOT line verifies the function does not simply return its argument.
+// Chain: NC -> NCcn -> NCnc -> NC
+// CHECK: func.func @unpack_pack(
+// CHECK-SAME: %[[T:.+]]: tensor<128x128xf32>)
+// CHECK-NOT: return %[[T]] : tensor<128x128xf32>
+func.func @unpack_pack(%t: tensor<128x128xf32>) -> tensor<128x128xf32> {
+  %tensor_empty = tensor.empty() : tensor<16x16x8x8xf32>
+  %packed = tensor.pack %t inner_dims_pos = [1, 0] inner_tiles = [8, 8] into %tensor_empty : tensor<128x128xf32> -> tensor<16x16x8x8xf32>
+  %tensor_empty1 = tensor.empty() : tensor<128x128xf32>
+  %unpacked = tensor.unpack %packed inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %tensor_empty1 : tensor<16x16x8x8xf32> -> tensor
+<128x128xf32>
+  return %unpacked : tensor<128x128xf32>
+}
+
+// -----
+
+// Negative test: the pack uses outer_dims_perm = [1, 0] and
+// inner_dims_pos = [1, 0], while the unpack has no outer_dims_perm and uses
+// inner_dims_pos = [0, 1]. The CHECK-NOT line verifies the function does not
+// simply return its argument.
+// Chain: NC -> CNcn -> NCnc -> NC
+// CHECK: func.func @unpack_pack(
+// CHECK-SAME: %[[T:.+]]: tensor<128x128xf32>)
+// CHECK-NOT: return %[[T]] : tensor<128x128xf32>
+func.func @unpack_pack(%t: tensor<128x128xf32>) -> tensor<128x128xf32> {
+  %tensor_empty = tensor.empty() : tensor<16x16x8x8xf32>
+  %packed = tensor.pack %t outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [8, 8] into %tensor_empty : tensor<128x128xf32> -> tensor<16x16x8x8xf32>
+  %tensor_empty1 = tensor.empty() : tensor<128x128xf32>
+  %unpacked = tensor.unpack %packed inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %tensor_empty1 : tensor<16x16x8x8xf32> -> tensor
+<128x128xf32>
+  return %unpacked : tensor<128x128xf32>
+}
+
+// -----
+
+// Same layout as the static-tile case, but with dynamic inner tiles: the pack
+// and the unpack both take their tile sizes from the same SSA values %tile1
+// and %tile2. The CHECK lines expect the pair to fold away so that the
+// function returns its first argument directly.
+// Chain: NC -> NCnc -> NCnc -> NC
+// CHECK: func.func @unpack_pack(
+// CHECK-SAME: %[[T:.+]]: tensor<128x128xf32>, 
+// CHECK: return %[[T]] : tensor<128x128xf32>
+func.func @unpack_pack(%t: tensor<128x128xf32>, %tile1: index, %tile2: index) -> tensor<128x128xf32> {
+  %tensor_empty = tensor.empty(%tile1, %tile2) : tensor<16x16x?x?xf32>
+  %packed = tensor.pack %t inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty : tensor<128x128xf32> -> tensor<16x16x?x?xf32>
+  %tensor_empty1 = tensor.empty() : tensor<128x128xf32>
+  %unpacked = tensor.unpack %packed inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty1 : tensor<16x16x?x?xf32> -> tensor
+<128x128xf32>
+  return %unpacked : tensor<128x128xf32>
+}


        


More information about the Mlir-commits mailing list