[Mlir-commits] [mlir] 9a5092b - [MLIR][Tensor] Add canonicalization patterns for `tensor.pack`
Lorenzo Chelini
llvmlistbot at llvm.org
Wed Jan 11 23:46:52 PST 2023
Author: Lorenzo Chelini
Date: 2023-01-12T08:46:45+01:00
New Revision: 9a5092b358ce52dd7c7dc8927529dac22523b6a2
URL: https://github.com/llvm/llvm-project/commit/9a5092b358ce52dd7c7dc8927529dac22523b6a2
DIFF: https://github.com/llvm/llvm-project/commit/9a5092b358ce52dd7c7dc8927529dac22523b6a2.diff
LOG: [MLIR][Tensor] Add canonicalization patterns for `tensor.pack`
- Fold unpack(pack(x)) to x.
- Rewrite a `tensor.pack` to a `tensor.expand_shape` if only one
dimension is packed.
Reviewed By: tyb0807, hanchung, mravishankar
Differential Revision: https://reviews.llvm.org/D141123
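
Note: because `PackOp` now opts into `hasCanonicalizeMethod`, the
unpack(pack(x)) fold fires under the regular canonicalizer, while the
pack-to-expand_shape rewrite is exposed separately through
populateSimplifyTensorPack. A minimal sketch of exercising the fold
programmatically (the pass-manager wiring below is illustrative, not part
of this commit):

  #include "mlir/IR/BuiltinOps.h"
  #include "mlir/Pass/PassManager.h"
  #include "mlir/Transforms/Passes.h"

  // Run the canonicalizer so PackOp::canonicalize gets a chance to fold
  // unpack(pack(x)) chains away.
  mlir::LogicalResult canonicalizeModule(mlir::ModuleOp module) {
    mlir::PassManager pm(module.getContext());
    pm.addPass(mlir::createCanonicalizerPass());
    return pm.run(module);
  }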
Added:
mlir/test/Dialect/Tensor/simplify-tensor-pack.mlir
Modified:
mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/test/Dialect/Tensor/canonicalize.mlir
mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
index 77be053ce0e17..b6e6024719728 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
+++ b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
@@ -142,10 +142,10 @@ FailureOr<Value> getOrCreateDestination(OpBuilder &b, Location loc,
LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op,
SmallVector<Value> &result);
-/// Function to control the folding of constant and extract slice
+/// Function to control the folding of constant and extract slice.
using ControlConstantExtractSliceFusionFn = std::function<bool(ExtractSliceOp)>;
-/// Patterns to fold the extract slice op with its constant operand
+/// Patterns to fold the extract slice op with its constant operand.
void populateFoldConstantExtractSlicePatterns(
RewritePatternSet &patterns,
const ControlConstantExtractSliceFusionFn &controlFn =
@@ -155,6 +155,9 @@ void populateFoldConstantExtractSlicePatterns(
return false;
});
+/// Patterns to simplify tensor.pack.
+void populateSimplifyTensorPack(RewritePatternSet &patterns);
+
} // namespace tensor
} // namespace mlir
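
A minimal usage sketch for the new entry point, mirroring the
applySimplifyPackPatterns helper in the test pass further down (the
function name here is hypothetical):

  #include "mlir/Dialect/Tensor/IR/Tensor.h"
  #include "mlir/Transforms/GreedyPatternRewriteDriver.h"

  // Collect the pack-simplification patterns and apply them greedily.
  static void simplifyTensorPacks(mlir::Operation *rootOp) {
    mlir::RewritePatternSet patterns(rootOp->getContext());
    mlir::tensor::populateSimplifyTensorPack(patterns);
    // Best-effort application; a convergence failure is ignored here.
    (void)mlir::applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
  }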
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 0107c37904e59..3535146e3c05c 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1790,6 +1790,8 @@ def Tensor_PackOp : Tensor_RelayoutOp<"pack", [
Value source, ArrayRef<OpFoldResult> innerTileSizes,
ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm);
}];
+
+ let hasCanonicalizeMethod = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index cd962456ba424..5746b6eaffb54 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -3017,6 +3017,44 @@ OpFoldResult SplatOp::fold(ArrayRef<Attribute> operands) {
// PackOp/UnPackOp Common
//===----------------------------------------------------------------------===//
+namespace {
+
+/// Packing a one-dimensional tensor can be expressed as an expand shape op.
+struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> {
+ using OpRewritePattern<PackOp>::OpRewritePattern;
+
+ Value insertExpand(RewriterBase &rewriter, Location loc, Value operand,
+ Type newOperandType, ArrayAttr reassociation) const {
+ if (operand.getType() == newOperandType)
+ return operand;
+ return rewriter.create<tensor::ExpandShapeOp>(loc, newOperandType, operand,
+ reassociation);
+ }
+
+ LogicalResult matchAndRewrite(PackOp packOp,
+ PatternRewriter &rewriter) const override {
+ RankedTensorType sourceType = packOp.getSourceType();
+ RankedTensorType destType = packOp.getDestType();
+ if (sourceType.getRank() != 1 || packOp.getPaddingValue())
+ return failure();
+ auto reassociation =
+ getReassociationIndicesForReshape(sourceType, destType);
+ if (!reassociation)
+ return failure();
+ Value expanded = insertExpand(
+ rewriter, packOp.getLoc(), packOp.getSource(), destType,
+ getReassociationIndicesAttribute(rewriter, *reassociation));
+ rewriter.replaceOp(packOp, expanded);
+ return success();
+ }
+};
+
+} // namespace
+
+void mlir::tensor::populateSimplifyTensorPack(RewritePatternSet &patterns) {
+  patterns.add<SimplifyPackToExpandShape>(patterns.getContext());
+}
+
template <typename OpTy>
static LogicalResult
reifyResultShapesImpl(OpTy op, OpBuilder &builder,
@@ -3376,6 +3414,41 @@ Speculation::Speculatability PackOp::getSpeculatability() {
return Speculation::Speculatable;
}
+// Return true if `inner_dims_pos` and `outer_dims_perm` target the same
+// dimensions for pack and unpack.
+static bool hasSameInnerOuterAttribute(PackOp packOp, UnPackOp unPackOp) {
+ if (packOp.getInnerDimsPos() != unPackOp.getInnerDimsPos())
+ return false;
+ return packOp.getOuterDimsPerm() == unPackOp.getOuterDimsPerm();
+}
+
+// Return true if pack and unpack have the same tiles.
+// Same SSA values or same integer constants.
+static bool haveSameTiles(PackOp packOp, UnPackOp unPackOp) {
+ auto packTiles = packOp.getMixedTiles();
+ auto unPackTiles = unPackOp.getMixedTiles();
+ if (packTiles.size() != unPackTiles.size())
+ return false;
+ for (size_t i = 0, e = packTiles.size(); i < e; i++) {
+ if (!isEqualConstantIntOrValue(packTiles[i], unPackTiles[i]))
+ return false;
+ }
+ return true;
+}
+
+/// Fold unpack(pack(x)) to x.
+LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
+ UnPackOp unPackOp = packOp.getSource().getDefiningOp<UnPackOp>();
+ if (!unPackOp || unPackOp.getSourceType() != packOp.getDestType())
+ return failure();
+ if (packOp.getPaddingValue() ||
+ !hasSameInnerOuterAttribute(packOp, unPackOp) ||
+ !haveSameTiles(packOp, unPackOp))
+ return failure();
+ rewriter.replaceOp(packOp, unPackOp.getSource());
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// UnPackOp
//===----------------------------------------------------------------------===//
@@ -3433,16 +3506,16 @@ void UnPackOp::build(OpBuilder &builder, OperationState &state, Value source,
}
/// pack(unpack(x)) -> x
-LogicalResult UnPackOp::canonicalize(UnPackOp unpackOp,
+LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
PatternRewriter &rewriter) {
- PackOp packOp = unpackOp.getSource().getDefiningOp<tensor::PackOp>();
- if (!packOp || packOp.getDestType() != unpackOp.getSourceType())
- return failure();
- if (packOp.getInnerDimsPos() != unpackOp.getInnerDimsPos())
+ PackOp packOp = unPackOp.getSource().getDefiningOp<tensor::PackOp>();
+ if (!packOp || packOp.getDestType() != unPackOp.getSourceType())
return failure();
- if (packOp.getOuterDimsPerm() != unpackOp.getOuterDimsPerm())
+ if (packOp.getPaddingValue() ||
+ !hasSameInnerOuterAttribute(packOp, unPackOp) ||
+ !haveSameTiles(packOp, unPackOp))
return failure();
- rewriter.replaceOp(unpackOp, packOp.getSource());
+ rewriter.replaceOp(unPackOp, packOp.getSource());
return success();
}
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 6267c269ab0b7..f4706fc439b9e 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1704,3 +1704,73 @@ func.func @unpack_pack(%t: tensor<128x128xf32>, %tile1: index, %tile2: index) ->
<128x128xf32>
return %unpacked : tensor<128x128xf32>
}
+
+// -----
+
+// Chain NCnc -> NC -> NC -> NCnc
+// CHECK: func.func @pack_unpack(
+// CHECK-SAME: %[[T:.+]]: tensor<16x16x?x?xf32>,
+// CHECK: return %[[T]] : tensor<16x16x?x?xf32>
+func.func @pack_unpack(%t: tensor<16x16x?x?xf32>, %tile1: index, %tile2: index) -> tensor<16x16x?x?xf32> {
+ %tensor_empty = tensor.empty() : tensor<128x128xf32>
+ %unpacked = tensor.unpack %t inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty : tensor<16x16x?x?xf32> -> tensor<128x128xf32>
+ %tensor_empty1 = tensor.empty(%tile1, %tile2) : tensor<16x16x?x?xf32>
+ %packed = tensor.pack %unpacked inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty1 : tensor<128x128xf32> -> tensor<16x16x?x?xf32>
+ return %packed : tensor<16x16x?x?xf32>
+}
+
+// -----
+
+// Chain NCnc -> NC -> NC -> NCnc
+// CHECK: func.func @pack_unpack(
+// CHECK-SAME: %[[T:.+]]: tensor<16x16x8x8xf32>
+// CHECK: return %[[T]] : tensor<16x16x8x8xf32>
+func.func @pack_unpack(%t: tensor<16x16x8x8xf32>) -> tensor<16x16x8x8xf32> {
+ %tensor_empty = tensor.empty() : tensor<128x128xf32>
+ %unpacked = tensor.unpack %t inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %tensor_empty : tensor<16x16x8x8xf32> -> tensor<128x128xf32>
+ %tensor_empty1 = tensor.empty() : tensor<16x16x8x8xf32>
+ %packed = tensor.pack %unpacked inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %tensor_empty1 : tensor<128x128xf32> -> tensor<16x16x8x8xf32>
+ return %packed : tensor<16x16x8x8xf32>
+}
+
+// -----
+
+// CHECK: func.func @pack_unpack_same_tiles(
+// CHECK-SAME: %[[T:.+]]: tensor<?x?x?x?xf32>,
+// CHECK: return %[[T]] : tensor<?x?x?x?xf32>
+func.func @pack_unpack_same_tiles(%t: tensor<?x?x?x?xf32>, %dim1: index, %dim2: index, %dim3: index, %dim4: index, %dim5: index, %dim6: index,
+ %tile1: index, %tile2: index) -> tensor<?x?x?x?xf32> {
+ %tensor_empty = tensor.empty(%dim1, %dim2) : tensor<?x?xf32>
+ %unpacked = tensor.unpack %t inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty : tensor<?x?x?x?xf32> -> tensor<?x?xf32>
+ %tensor_empty1 = tensor.empty(%dim3, %dim4, %dim5, %dim6) : tensor<?x?x?x?xf32>
+ %packed = tensor.pack %unpacked inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty1 : tensor<?x?xf32> -> tensor<?x?x?x?xf32>
+ return %packed : tensor<?x?x?x?xf32>
+}
+
+// -----
+
+// CHECK: func.func @pack_unpack_different_tiles(
+// CHECK-SAME: %[[T:.+]]: tensor<?x?x?x?xf32>,
+// CHECK-NOT: return %[[T]] : tensor<?x?x?x?xf32>
+func.func @pack_unpack_different_tiles(%t: tensor<?x?x?x?xf32>, %dim1: index, %dim2: index, %dim3: index, %dim4: index, %dim5: index, %dim6: index,
+ %tile1: index, %tile2: index) -> tensor<?x?x?x?xf32> {
+ %tensor_empty = tensor.empty(%dim1, %dim2) : tensor<?x?xf32>
+ %unpacked = tensor.unpack %t inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty : tensor<?x?x?x?xf32> -> tensor<?x?xf32>
+ %tensor_empty1 = tensor.empty(%dim3, %dim4, %dim5, %dim6) : tensor<?x?x?x?xf32>
+ %packed = tensor.pack %unpacked inner_dims_pos = [0, 1] inner_tiles = [%tile2, %tile1] into %tensor_empty1 : tensor<?x?xf32> -> tensor<?x?x?x?xf32>
+ return %packed : tensor<?x?x?x?xf32>
+}
+
+// -----
+
+// CHECK: func.func @pack_unpack_dynamic_with_padding(
+// CHECK-SAME: %[[T:.+]]: tensor<?x?x?x?xf32>,
+// CHECK-NOT: return %[[T]] : tensor<?x?x?x?xf32>
+func.func @pack_unpack_dynamic_with_padding(%t: tensor<?x?x?x?xf32>, %dim1: index, %dim2: index, %dim3: index, %dim4: index, %dim5: index, %dim6: index,
+ %tile1: index, %tile2: index, %pad: f32) -> tensor<?x?x?x?xf32> {
+ %tensor_empty = tensor.empty(%dim1, %dim2) : tensor<?x?xf32>
+ %unpacked = tensor.unpack %t inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty : tensor<?x?x?x?xf32> -> tensor<?x?xf32>
+ %tensor_empty1 = tensor.empty(%dim3, %dim4, %dim5, %dim6) : tensor<?x?x?x?xf32>
+ %packed = tensor.pack %unpacked padding_value(%pad: f32) inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty1 : tensor<?x?xf32> -> tensor<?x?x?x?xf32>
+ return %packed : tensor<?x?x?x?xf32>
+}
diff --git a/mlir/test/Dialect/Tensor/simplify-tensor-pack.mlir b/mlir/test/Dialect/Tensor/simplify-tensor-pack.mlir
new file mode 100644
index 0000000000000..75eb33ed033b9
--- /dev/null
+++ b/mlir/test/Dialect/Tensor/simplify-tensor-pack.mlir
@@ -0,0 +1,24 @@
+// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns="test-simplify-pack-patterns" %s | FileCheck %s
+
+// CHECK: func.func @single_dim_packing(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<256xf32>)
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1]] : tensor<256xf32> into tensor<8x32xf32>
+// CHECK: return %[[EXPANDED]] : tensor<8x32xf32>
+func.func @single_dim_packing(%arg0: tensor<256xf32>) -> tensor<8x32xf32> {
+ %empty = tensor.empty() : tensor<8x32xf32>
+ %0 = tensor.pack %arg0 inner_dims_pos = [0] inner_tiles = [32] into %empty : tensor<256xf32> -> tensor<8x32xf32>
+ return %0 : tensor<8x32xf32>
+}
+
+// -----
+
+// CHECK: func.func @single_dim_packing_with_padding(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<255xf32>)
+// CHECK-NOT: tensor.expand_shape
+// CHECK: tensor.pack
+func.func @single_dim_packing_with_padding(%arg0: tensor<255xf32>) -> tensor<8x32xf32> {
+ %empty = tensor.empty() : tensor<8x32xf32>
+ %cst = arith.constant 0.000000e+00 : f32
+ %0 = tensor.pack %arg0 padding_value(%cst : f32) inner_dims_pos = [0] inner_tiles = [32] into %empty : tensor<255xf32> -> tensor<8x32xf32>
+ return %0 : tensor<8x32xf32>
+}
diff --git a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
index a87547035e75b..25039a5c41f92 100644
--- a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
@@ -85,6 +85,11 @@ struct TestTensorTransforms
"Use the scf.foreach_thread operation when generating loop nests for "
"the extract_slice of collapse_shape pattern"),
llvm::cl::init(false)};
+
+ Option<bool> testSimplifyPackPatterns{
+ *this, "test-simplify-pack-patterns",
+ llvm::cl::desc("Test patterns to simplify tensor.pack"),
+ llvm::cl::init(false)};
};
} // namespace
@@ -134,6 +139,12 @@ static void applyFoldConsecutiveInsertExtractSlicePatterns(Operation *rootOp) {
(void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
}
+static void applySimplifyPackPatterns(Operation *rootOp) {
+ RewritePatternSet patterns(rootOp->getContext());
+ tensor::populateSimplifyTensorPack(patterns);
+ (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
+}
+
namespace {
/// Base pattern to rewrite a `tensor.collapse_shape -> tensor.extract_slice`.
/// The `tensor.extract_slice` is replaced by a loop or gather operation that
@@ -277,6 +288,8 @@ applyRewriteExtractFromCollapseShapePatterns(Operation *rootOp,
void TestTensorTransforms::runOnOperation() {
Operation *rootOp = getOperation();
+ if (testSimplifyPackPatterns)
+ applySimplifyPackPatterns(rootOp);
if (testSplitPaddingPatterns)
applySplitPaddingPatterns(rootOp);
if (testFoldConstantExtractSlice)