[Mlir-commits] [mlir] 9a5092b - [MLIR][Tensor] Add canonicalization patterns for `tensor.pack`

Lorenzo Chelini llvmlistbot at llvm.org
Wed Jan 11 23:46:52 PST 2023


Author: Lorenzo Chelini
Date: 2023-01-12T08:46:45+01:00
New Revision: 9a5092b358ce52dd7c7dc8927529dac22523b6a2

URL: https://github.com/llvm/llvm-project/commit/9a5092b358ce52dd7c7dc8927529dac22523b6a2
DIFF: https://github.com/llvm/llvm-project/commit/9a5092b358ce52dd7c7dc8927529dac22523b6a2.diff

LOG: [MLIR][Tensor] Add canonicalization patterns for `tensor.pack`

- Fold pack(unpack(x)) to x, i.e., fold a `tensor.pack` whose source is
  a matching `tensor.unpack`.

- Rewrite a `tensor.pack` to a `tensor.expand_shape` if only one
  dimension is packed.
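
For illustration, a sketch of both rewrites on hypothetical IR (the
values, shapes, and tiles below are made up for the example):

  // (1) pack(unpack(x)) -> x: a round-trip with matching attributes.
  %u = tensor.unpack %x inner_dims_pos = [0, 1] inner_tiles = [8, 8]
      into %e0 : tensor<16x16x8x8xf32> -> tensor<128x128xf32>
  %p = tensor.pack %u inner_dims_pos = [0, 1] inner_tiles = [8, 8]
      into %e1 : tensor<128x128xf32> -> tensor<16x16x8x8xf32>
  // All uses of %p are replaced with %x.

  // (2) A rank-1 pack without padding is a pure reshape.
  %0 = tensor.pack %src inner_dims_pos = [0] inner_tiles = [32]
      into %dst : tensor<256xf32> -> tensor<8x32xf32>
  // ... becomes:
  %1 = tensor.expand_shape %src [[0, 1]] : tensor<256xf32> into tensor<8x32xf32>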

Reviewed By: tyb0807, hanchung, mravishankar

Differential Revision: https://reviews.llvm.org/D141123

Added: 
    mlir/test/Dialect/Tensor/simplify-tensor-pack.mlir

Modified: 
    mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
    mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
    mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
    mlir/test/Dialect/Tensor/canonicalize.mlir
    mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
index 77be053ce0e17..b6e6024719728 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
+++ b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
@@ -142,10 +142,10 @@ FailureOr<Value> getOrCreateDestination(OpBuilder &b, Location loc,
 LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op,
                                       SmallVector<Value> &result);
 
-/// Function to control the folding of constant and extract slice
+/// Function to control the folding of constant and extract slice.
 using ControlConstantExtractSliceFusionFn = std::function<bool(ExtractSliceOp)>;
 
-/// Patterns to fold the extract slice op with its constant operand
+/// Patterns to fold the extract slice op with its constant operand.
 void populateFoldConstantExtractSlicePatterns(
     RewritePatternSet &patterns,
     const ControlConstantExtractSliceFusionFn &controlFn =
@@ -155,6 +155,9 @@ void populateFoldConstantExtractSlicePatterns(
           return false;
         });
 
+/// Patterns to simplify tensor.pack.
+void populateSimplifyTensorPack(RewritePatternSet &patterns);
+
 } // namespace tensor
 } // namespace mlir
 

diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 0107c37904e59..3535146e3c05c 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1790,6 +1790,8 @@ def Tensor_PackOp : Tensor_RelayoutOp<"pack", [
         Value source, ArrayRef<OpFoldResult> innerTileSizes,
         ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm);
   }];
+
+  let hasCanonicalizeMethod = 1;
 }
 
 //===----------------------------------------------------------------------===//

diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index cd962456ba424..5746b6eaffb54 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -3017,6 +3017,44 @@ OpFoldResult SplatOp::fold(ArrayRef<Attribute> operands) {
 // PackOp/UnPackOp Common
 //===----------------------------------------------------------------------===//
 
+namespace {
+
+/// Packing a one-dimensional tensor can be expressed as an expand_shape op.
+struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> {
+  using OpRewritePattern<PackOp>::OpRewritePattern;
+
+  Value insertExpand(RewriterBase &rewriter, Location loc, Value operand,
+                     Type newOperandType, ArrayAttr reassociation) const {
+    if (operand.getType() == newOperandType)
+      return operand;
+    return rewriter.create<tensor::ExpandShapeOp>(loc, newOperandType, operand,
+                                                  reassociation);
+  }
+
+  LogicalResult matchAndRewrite(PackOp packOp,
+                                PatternRewriter &rewriter) const override {
+    RankedTensorType sourceType = packOp.getSourceType();
+    RankedTensorType destType = packOp.getDestType();
+    if (sourceType.getRank() != 1 || packOp.getPaddingValue())
+      return failure();
+    auto reassociation =
+        getReassociationIndicesForReshape(sourceType, destType);
+    if (!reassociation)
+      return failure();
+    Value expanded = insertExpand(
+        rewriter, packOp.getLoc(), packOp.getSource(), destType,
+        getReassociationIndicesAttribute(rewriter, *reassociation));
+    rewriter.replaceOp(packOp, expanded);
+    return success();
+  }
+};
+
+} // namespace
+
+void mlir::tensor::populateSimplifyTensorPack(RewritePatternSet &patterns) {
+  patterns.add<SimplifyPackToExpandShape>(patterns.getContext());
+}
+
 template <typename OpTy>
 static LogicalResult
 reifyResultShapesImpl(OpTy op, OpBuilder &builder,
@@ -3376,6 +3414,41 @@ Speculation::Speculatability PackOp::getSpeculatability() {
   return Speculation::Speculatable;
 }
 
+// Return true if `inner_dims_pos` and `outer_dims_perm` target the same
+// dimensions for pack and unpack.
+static bool hasSameInnerOuterAttribute(PackOp packOp, UnPackOp unPackOp) {
+  if (packOp.getInnerDimsPos() != unPackOp.getInnerDimsPos())
+    return false;
+  return packOp.getOuterDimsPerm() == unPackOp.getOuterDimsPerm();
+}
+
+// Return true if pack and unpack have the same tiles, i.e., the same SSA
+// values or equal integer constants.
+static bool haveSameTiles(PackOp packOp, UnPackOp unPackOp) {
+  auto packTiles = packOp.getMixedTiles();
+  auto unPackTiles = unPackOp.getMixedTiles();
+  if (packTiles.size() != unPackTiles.size())
+    return false;
+  for (size_t i = 0, e = packTiles.size(); i < e; i++) {
+    if (!isEqualConstantIntOrValue(packTiles[i], unPackTiles[i]))
+      return false;
+  }
+  return true;
+}
+
+/// Fold pack(unpack(x)) to x.
+LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
+  UnPackOp unPackOp = packOp.getSource().getDefiningOp<UnPackOp>();
+  if (!unPackOp || unPackOp.getSourceType() != packOp.getDestType())
+    return failure();
+  if (packOp.getPaddingValue() ||
+      !hasSameInnerOuterAttribute(packOp, unPackOp) ||
+      !haveSameTiles(packOp, unPackOp))
+    return failure();
+  rewriter.replaceOp(packOp, unPackOp.getSource());
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // UnPackOp
 //===----------------------------------------------------------------------===//
@@ -3433,16 +3506,16 @@ void UnPackOp::build(OpBuilder &builder, OperationState &state, Value source,
 }
 
 /// unpack(pack(x)) -> x
-LogicalResult UnPackOp::canonicalize(UnPackOp unpackOp,
+LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
                                      PatternRewriter &rewriter) {
-  PackOp packOp = unpackOp.getSource().getDefiningOp<tensor::PackOp>();
-  if (!packOp || packOp.getDestType() != unpackOp.getSourceType())
-    return failure();
-  if (packOp.getInnerDimsPos() != unpackOp.getInnerDimsPos())
+  PackOp packOp = unPackOp.getSource().getDefiningOp<tensor::PackOp>();
+  if (!packOp || packOp.getDestType() != unPackOp.getSourceType())
     return failure();
-  if (packOp.getOuterDimsPerm() != unpackOp.getOuterDimsPerm())
+  if (packOp.getPaddingValue() ||
+      !hasSameInnerOuterAttribute(packOp, unPackOp) ||
+      !haveSameTiles(packOp, unPackOp))
     return failure();
-  rewriter.replaceOp(unpackOp, packOp.getSource());
+  rewriter.replaceOp(unPackOp, packOp.getSource());
   return success();
 }
 

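A note on haveSameTiles: tiles are compared with isEqualConstantIntOrValue,
so two distinct SSA values that both hold the constant 8 should still
compare equal. A minimal sketch on hypothetical IR (the function name and
values are made up for the example):

  func.func @fold_equal_constant_tiles(%t: tensor<16x16x?x?xf32>) -> tensor<16x16x?x?xf32> {
    %c8 = arith.constant 8 : index
    %c8_bis = arith.constant 8 : index
    %e0 = tensor.empty() : tensor<128x128xf32>
    %u = tensor.unpack %t inner_dims_pos = [0, 1] inner_tiles = [%c8, %c8]
        into %e0 : tensor<16x16x?x?xf32> -> tensor<128x128xf32>
    %e1 = tensor.empty(%c8_bis, %c8_bis) : tensor<16x16x?x?xf32>
    // Should fold to %t: the tile lists differ as SSA values but hold the
    // same constant value 8.
    %p = tensor.pack %u inner_dims_pos = [0, 1] inner_tiles = [%c8_bis, %c8_bis]
        into %e1 : tensor<128x128xf32> -> tensor<16x16x?x?xf32>
    return %p : tensor<16x16x?x?xf32>
  }
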
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 6267c269ab0b7..f4706fc439b9e 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1704,3 +1704,73 @@ func.func @unpack_pack(%t: tensor<128x128xf32>, %tile1: index, %tile2: index) ->
 <128x128xf32>
   return %unpacked : tensor<128x128xf32>
 }
+
+// -----
+
+// Chain: NCnc -> NC -> NCnc (unpack then pack).
+// CHECK: func.func @pack_unpack(
+// CHECK-SAME: %[[T:.+]]: tensor<16x16x?x?xf32>,
+// CHECK: return %[[T]] : tensor<16x16x?x?xf32>
+func.func @pack_unpack(%t: tensor<16x16x?x?xf32>, %tile1: index, %tile2: index) -> tensor<16x16x?x?xf32> {
+  %tensor_empty = tensor.empty() : tensor<128x128xf32>
+  %unpacked = tensor.unpack %t inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty : tensor<16x16x?x?xf32> -> tensor<128x128xf32>
+  %tensor_empty1 = tensor.empty(%tile1, %tile2) : tensor<16x16x?x?xf32>
+  %packed = tensor.pack %unpacked inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty1 : tensor<128x128xf32> -> tensor<16x16x?x?xf32>
+  return %packed : tensor<16x16x?x?xf32>
+}
+
+// -----
+
+// Chain: NCnc -> NC -> NCnc (unpack then pack).
+// CHECK: func.func @pack_unpack(
+// CHECK-SAME: %[[T:.+]]: tensor<16x16x8x8xf32>
+// CHECK: return %[[T]] : tensor<16x16x8x8xf32>
+func.func @pack_unpack(%t: tensor<16x16x8x8xf32>) -> tensor<16x16x8x8xf32> {
+  %tensor_empty = tensor.empty() : tensor<128x128xf32>
+  %unpacked = tensor.unpack %t inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %tensor_empty : tensor<16x16x8x8xf32> -> tensor<128x128xf32>
+  %tensor_empty1 = tensor.empty() : tensor<16x16x8x8xf32>
+  %packed = tensor.pack %unpacked inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %tensor_empty1 : tensor<128x128xf32> -> tensor<16x16x8x8xf32>
+  return %packed : tensor<16x16x8x8xf32>
+}
+
+// -----
+
+// CHECK: func.func @pack_unpack_same_tiles(
+// CHECK-SAME:  %[[T:.+]]: tensor<?x?x?x?xf32>,
+// CHECK: return %[[T]] : tensor<?x?x?x?xf32>
+func.func @pack_unpack_same_tiles(%t: tensor<?x?x?x?xf32>, %dim1: index, %dim2: index, %dim3: index, %dim4: index, %dim5: index, %dim6: index,
+                       %tile1: index, %tile2: index) -> tensor<?x?x?x?xf32> {
+  %tensor_empty = tensor.empty(%dim1, %dim2) : tensor<?x?xf32>
+  %unpacked = tensor.unpack %t inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty : tensor<?x?x?x?xf32> -> tensor<?x?xf32>
+  %tensor_empty1 = tensor.empty(%dim3, %dim4, %dim5, %dim6) : tensor<?x?x?x?xf32>
+  %packed = tensor.pack %unpacked inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty1 : tensor<?x?xf32> -> tensor<?x?x?x?xf32>
+  return %packed : tensor<?x?x?x?xf32>
+}
+
+// -----
+
+// CHECK: func.func @pack_unpack_different_tiles(
+// CHECK-SAME:  %[[T:.+]]: tensor<?x?x?x?xf32>,
+// CHECK-NOT: return %[[T]] : tensor<?x?x?x?xf32>
+func.func @pack_unpack_different_tiles(%t: tensor<?x?x?x?xf32>, %dim1: index, %dim2: index, %dim3: index, %dim4: index, %dim5: index, %dim6: index,
+                       %tile1: index, %tile2: index) -> tensor<?x?x?x?xf32> {
+  %tensor_empty = tensor.empty(%dim1, %dim2) : tensor<?x?xf32>
+  %unpacked = tensor.unpack %t inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty : tensor<?x?x?x?xf32> -> tensor<?x?xf32>
+  %tensor_empty1 = tensor.empty(%dim3, %dim4, %dim5, %dim6) : tensor<?x?x?x?xf32>
+  %packed = tensor.pack %unpacked inner_dims_pos = [0, 1] inner_tiles = [%tile2, %tile1] into %tensor_empty1 : tensor<?x?xf32> -> tensor<?x?x?x?xf32>
+  return %packed : tensor<?x?x?x?xf32>
+}
+
+// -----
+
+// CHECK: func.func @pack_unpack_dynamic_with_padding(
+// CHECK-SAME:  %[[T:.+]]: tensor<?x?x?x?xf32>,
+// CHECK-NOT: return %[[T]] : tensor<?x?x?x?xf32>
+func.func @pack_unpack_dynamic_with_padding(%t: tensor<?x?x?x?xf32>, %dim1: index, %dim2: index, %dim3: index, %dim4: index, %dim5: index, %dim6: index,
+                       %tile1: index, %tile2: index, %pad: f32) -> tensor<?x?x?x?xf32> {
+  %tensor_empty = tensor.empty(%dim1, %dim2) : tensor<?x?xf32>
+  %unpacked = tensor.unpack %t inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty : tensor<?x?x?x?xf32> -> tensor<?x?xf32>
+  %tensor_empty1 = tensor.empty(%dim3, %dim4, %dim5, %dim6) : tensor<?x?x?x?xf32>
+  %packed = tensor.pack %unpacked padding_value(%pad: f32) inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty1 : tensor<?x?xf32> -> tensor<?x?x?x?xf32>
+  return %packed : tensor<?x?x?x?xf32>
+}

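One guard the tests above do not exercise: hasSameInnerOuterAttribute also
rejects chains whose outer_dims_perm differ. A hedged sketch on
hypothetical IR (this case should be left untouched because only the
unpack permutes the outer dimensions):

  func.func @pack_unpack_different_perm(%t: tensor<16x16x8x8xf32>) -> tensor<16x16x8x8xf32> {
    %e0 = tensor.empty() : tensor<128x128xf32>
    %u = tensor.unpack %t outer_dims_perm = [1, 0] inner_dims_pos = [0, 1]
        inner_tiles = [8, 8] into %e0 : tensor<16x16x8x8xf32> -> tensor<128x128xf32>
    %e1 = tensor.empty() : tensor<16x16x8x8xf32>
    // Not folded: the unpack permutes the outer dims, the pack does not.
    %p = tensor.pack %u inner_dims_pos = [0, 1] inner_tiles = [8, 8]
        into %e1 : tensor<128x128xf32> -> tensor<16x16x8x8xf32>
    return %p : tensor<16x16x8x8xf32>
  }
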
diff --git a/mlir/test/Dialect/Tensor/simplify-tensor-pack.mlir b/mlir/test/Dialect/Tensor/simplify-tensor-pack.mlir
new file mode 100644
index 0000000000000..75eb33ed033b9
--- /dev/null
+++ b/mlir/test/Dialect/Tensor/simplify-tensor-pack.mlir
@@ -0,0 +1,24 @@
+// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns="test-simplify-pack-patterns" %s | FileCheck %s
+
+// CHECK: func.func @single_dim_packing(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<256xf32>)
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1]] : tensor<256xf32> into tensor<8x32xf32>
+// CHECK: return %[[EXPANDED]] : tensor<8x32xf32>
+func.func @single_dim_packing(%arg0: tensor<256xf32>) -> tensor<8x32xf32> {
+  %empty = tensor.empty() : tensor<8x32xf32>
+  %0 = tensor.pack %arg0 inner_dims_pos = [0] inner_tiles = [32] into %empty : tensor<256xf32> -> tensor<8x32xf32>
+  return %0 : tensor<8x32xf32>
+}
+
+// -----
+
+// CHECK: func.func @single_dim_packing_with_padding(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<255xf32>)
+// CHECK-NOT: tensor.expand_shape
+// CHECK: tensor.pack
+func.func @single_dim_packing_with_padding(%arg0: tensor<255xf32>) -> tensor<8x32xf32> {
+  %empty = tensor.empty() : tensor<8x32xf32>
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = tensor.pack %arg0 padding_value(%cst : f32) inner_dims_pos = [0] inner_tiles = [32] into %empty : tensor<255xf32> -> tensor<8x32xf32>
+  return %0 : tensor<8x32xf32>
+}

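For contrast, the rank guard in SimplifyPackToExpandShape means a
multi-dimensional pack is left alone even when its shapes divide evenly; a
hypothetical example, not part of the commit's tests:

  func.func @multi_dim_packing(%arg0: tensor<128x256xf32>) -> tensor<4x8x32x32xf32> {
    %empty = tensor.empty() : tensor<4x8x32x32xf32>
    // Source rank is 2, so the pattern does not rewrite this pack; the
    // interleaved NCnc layout is not a pure expand_shape anyway.
    %0 = tensor.pack %arg0 inner_dims_pos = [0, 1] inner_tiles = [32, 32]
        into %empty : tensor<128x256xf32> -> tensor<4x8x32x32xf32>
    return %0 : tensor<4x8x32x32xf32>
  }
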
diff --git a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
index a87547035e75b..25039a5c41f92 100644
--- a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
@@ -85,6 +85,11 @@ struct TestTensorTransforms
           "Use the scf.foreach_thread operation when generating loop nests for "
           "the extract_slice of collapse_shape pattern"),
       llvm::cl::init(false)};
+
+  Option<bool> testSimplifyPackPatterns{
+      *this, "test-simplify-pack-patterns",
+      llvm::cl::desc("Test patterns to simplify tensor.pack"),
+      llvm::cl::init(false)};
 };
 } // namespace
 
@@ -134,6 +139,12 @@ static void applyFoldConsecutiveInsertExtractSlicePatterns(Operation *rootOp) {
   (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
 }
 
+static void applySimplifyPackPatterns(Operation *rootOp) {
+  RewritePatternSet patterns(rootOp->getContext());
+  tensor::populateSimplifyTensorPack(patterns);
+  (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
+}
+
 namespace {
 /// Base pattern to rewrite a `tensor.collapse_shape -> tensor.extract_slice`.
 /// The `tensor.extract_slice` is replaced by a loop or gather operation that
@@ -277,6 +288,8 @@ applyRewriteExtractFromCollapseShapePatterns(Operation *rootOp,
 
 void TestTensorTransforms::runOnOperation() {
   Operation *rootOp = getOperation();
+  if (testSimplifyPackPatterns)
+    applySimplifyPackPatterns(rootOp);
   if (testSplitPaddingPatterns)
     applySplitPaddingPatterns(rootOp);
   if (testFoldConstantExtractSlice)


More information about the Mlir-commits mailing list