[Mlir-commits] [mlir] 9e3ca79 - [mlir][tosa] Canonicalize concatenate->slice sequence
Robert Suderman
llvmlistbot at llvm.org
Wed Mar 22 10:01:30 PDT 2023
Author: Luke Hutton
Date: 2023-03-22T16:52:44Z
New Revision: 9e3ca7987a4dc33cdf847b79a6304b117651d21f
URL: https://github.com/llvm/llvm-project/commit/9e3ca7987a4dc33cdf847b79a6304b117651d21f
DIFF: https://github.com/llvm/llvm-project/commit/9e3ca7987a4dc33cdf847b79a6304b117651d21f.diff
LOG: [mlir][tosa] Canonicalize concatenate->slice sequence
Adds a canonicalizer for the concatenate->slice sequence: when a slice
reads from only one input of the concatenate, it is rewritten to slice
that input directly, bypassing the concatenate.
This is useful in the context of operations with complex inputs
and outputs that are legalized from a framework such as TFL.
For example, a TFL graph (FFT->FFT) will be legalized to the
following TOSA graph:
           <complex input>
              /      \
          slice      slice
              \      /
                FFT
              /      \         -+
           concatenate           |
              /      \           | Redundant
          slice      slice       |
              \      /         -+
                FFT
              /      \
           concatenate
                 |
          <complex output>
Concatenate and slice operations at the boundaries of the graph are
useful as they maintain the correct correspondence of input/output
tensors to the original TFL graph. However, consecutive
complex operations will result in redundant concatenate->slice
sequences which should be removed from the final TOSA graph.
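The canonicalization does not currently handle dynamically shaped
tensors: the concatenate result and all of its inputs must be ranked
tensors with static shapes.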
Signed-off-by: Luke Hutton <luke.hutton at arm.com>
Reviewed By: rsuderman
Differential Revision: https://reviews.llvm.org/D144545
Added:
Modified:
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
mlir/test/Dialect/Tosa/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 7c8018ad6460..b6127f1ffa3c 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1556,6 +1556,7 @@ def Tosa_SliceOp: Tosa_Op<"slice", [
Tosa_Tensor1Dto6D:$output
);
+ let hasCanonicalizer = 1;
let hasFolder = 1;
}
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 1a8a5782e11f..16f23e4798c0 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -519,6 +519,65 @@ void ClampOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.add<ClampClampOptimization>(context);
}
+/// Rewrites `tosa.slice(tosa.concat(x0, x1, ...))` into `tosa.slice(xi)` when,
+/// along the concatenation axis, the slice lies entirely within a single
+/// concat operand `xi`. The slice start along that axis is rebased into
+/// `xi`'s coordinates. Only ranked tensors with static shapes are handled.
+struct ConcatSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
+  using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
+                                PatternRewriter &rewriter) const override {
+    Value sliceInput = sliceOp.getInput();
+    auto concatOp = sliceInput.getDefiningOp<tosa::ConcatOp>();
+    if (!concatOp)
+      return rewriter.notifyMatchFailure(
+          sliceOp, "slice input must be concat operation");
+
+    OperandRange inputs = concatOp.getInput1();
+    auto concatType = dyn_cast<RankedTensorType>(concatOp.getType());
+    if (!concatType || !concatType.hasStaticShape())
+      return rewriter.notifyMatchFailure(
+          sliceOp, "slice input must be a static ranked tensor");
+    int32_t axis = concatOp.getAxis();
+
+    // Mutable copy of the start offsets: along `axis` it is rebased from the
+    // concatenated tensor's coordinates into each candidate operand's
+    // coordinates as the loop below walks the concat operands.
+    llvm::SmallVector<int64_t> sliceStart(sliceOp.getStart());
+    llvm::ArrayRef<int64_t> sliceSize = sliceOp.getSize();
+
+    // Guard the `[axis]` indexing below against malformed IR.
+    const int64_t rank = concatType.getRank();
+    if (axis < 0 || axis >= rank ||
+        static_cast<int64_t>(sliceStart.size()) != rank ||
+        static_cast<int64_t>(sliceSize.size()) != rank)
+      return rewriter.notifyMatchFailure(
+          sliceOp, "slice start/size and concat axis must match concat rank");
+
+    // Validate slice on the concatenated axis. Slicing along this
+    // axis should span only one of the inputs to the concatenate
+    // operation.
+    std::optional<Value> replaceWithSlice;
+    for (auto input : inputs) {
+      auto inputType = dyn_cast<RankedTensorType>(input.getType());
+      if (!inputType || !inputType.hasStaticShape())
+        return rewriter.notifyMatchFailure(
+            sliceOp, "concat input must be a static ranked tensor");
+
+      if (sliceStart[axis] >= 0 &&
+          (sliceStart[axis] + sliceSize[axis]) <= inputType.getDimSize(axis)) {
+        // Use the rebased start: the slice's own start attribute indexes
+        // into the concatenated tensor, not into `input`, and would be out
+        // of bounds for any operand after the first.
+        replaceWithSlice =
+            rewriter
+                .create<tosa::SliceOp>(
+                    sliceOp.getLoc(), sliceOp.getType(), input,
+                    rewriter.getDenseI64ArrayAttr(sliceStart),
+                    rewriter.getDenseI64ArrayAttr(sliceSize))
+                .getResult();
+        break;
+      }
+      // Move the start past this operand before testing the next one.
+      sliceStart[axis] -= inputType.getDimSize(axis);
+    }
+
+    if (!replaceWithSlice)
+      return rewriter.notifyMatchFailure(
+          sliceOp, "corresponding concat input not found for slice");
+
+    rewriter.replaceOp(sliceOp, replaceWithSlice.value());
+    return success();
+  }
+};
+
+/// Registers the canonicalization patterns for tosa.slice; currently this is
+/// only ConcatSliceOptimization.
+void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {
+ results.add<ConcatSliceOptimization>(context);
+}
+
//===----------------------------------------------------------------------===//
// Operator Folders.
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index e16a614c7cd0..77627d8c8ba6 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -434,3 +434,56 @@ func.func @fold_resize_bilinear(%arg0 : tensor<1x15x13x1xi8>) -> tensor<1x15x13x
%resize = "tosa.resize"(%arg0) {mode = "BILINEAR", scale = array<i64: 2, 2, 1, 1>, offset = array<i64: 0, 0>, border = array<i64: 0, 0>} : (tensor<1x15x13x1xi8>) -> tensor<1x15x13x1xi8>
return %resize : tensor<1x15x13x1xi8>
}
+
+// -----
+
+// Each slice covers exactly one concat operand along axis 3, so both results
+// are expected to be replaced by the corresponding function arguments.
+// CHECK-LABEL: @canonicalize_concat_slice_final_axis
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x12x12x1xf32>, %[[VAL_1:.*]]: tensor<1x12x12x1xf32>
+// CHECK: return %[[VAL_0]], %[[VAL_1]] : tensor<1x12x12x1xf32>, tensor<1x12x12x1xf32>
+func.func @canonicalize_concat_slice_final_axis(%arg0 : tensor<1x12x12x1xf32>, %arg1 : tensor<1x12x12x1xf32>) -> (tensor<1x12x12x1xf32>, tensor<1x12x12x1xf32>) {
+  %concat = "tosa.concat"(%arg0, %arg1) {axis = 3 : i64} : (tensor<1x12x12x1xf32>, tensor<1x12x12x1xf32>) -> tensor<1x12x12x2xf32>
+  %slice_lhs = "tosa.slice"(%concat) {size = array<i64: 1, 12, 12, 1>, start = array<i64: 0, 0, 0, 0>} : (tensor<1x12x12x2xf32>) -> tensor<1x12x12x1xf32>
+  %slice_rhs = "tosa.slice"(%concat) {size = array<i64: 1, 12, 12, 1>, start = array<i64: 0, 0, 0, 1>} : (tensor<1x12x12x2xf32>) -> tensor<1x12x12x1xf32>
+  return %slice_lhs, %slice_rhs : tensor<1x12x12x1xf32>, tensor<1x12x12x1xf32>
+}
+
+// -----
+
+// Each slice covers exactly one concat operand along axis 1, so both results
+// are expected to be replaced by the corresponding function arguments.
+// CHECK-LABEL: @canonicalize_concat_slice_middle_axis
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x12x12xf32>, %[[VAL_1:.*]]: tensor<1x12x12xf32>
+// CHECK: return %[[VAL_0]], %[[VAL_1]] : tensor<1x12x12xf32>, tensor<1x12x12xf32>
+func.func @canonicalize_concat_slice_middle_axis(%arg0 : tensor<1x12x12xf32>, %arg1 : tensor<1x12x12xf32>) -> (tensor<1x12x12xf32>, tensor<1x12x12xf32>) {
+  %concat = "tosa.concat"(%arg0, %arg1) {axis = 1 : i64} : (tensor<1x12x12xf32>, tensor<1x12x12xf32>) -> tensor<1x24x12xf32>
+  %slice_lhs = "tosa.slice"(%concat) {size = array<i64: 1, 12, 12>, start = array<i64: 0, 0, 0>} : (tensor<1x24x12xf32>) -> tensor<1x12x12xf32>
+  %slice_rhs = "tosa.slice"(%concat) {size = array<i64: 1, 12, 12>, start = array<i64: 0, 12, 0>} : (tensor<1x24x12xf32>) -> tensor<1x12x12xf32>
+  return %slice_lhs, %slice_rhs : tensor<1x12x12xf32>, tensor<1x12x12xf32>
+}
+
+// -----
+
+// Each slice straddles the boundary between the two concat operands along
+// axis 2, so the pattern must not fire and the IR is expected unchanged.
+// CHECK-LABEL: @canonicalize_cross_concat_inputs
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x12x12xf32>, %[[VAL_1:.*]]: tensor<1x12x12xf32>
+// CHECK: %[[VAL_2:.*]] = "tosa.concat"(%[[VAL_0]], %[[VAL_1]]) {axis = 2 : i64} : (tensor<1x12x12xf32>, tensor<1x12x12xf32>) -> tensor<1x12x24xf32>
+// CHECK: %[[VAL_3:.*]] = "tosa.slice"(%[[VAL_2]]) {size = array<i64: 1, 12, 15>, start = array<i64: 0, 0, 0>} : (tensor<1x12x24xf32>) -> tensor<1x12x15xf32>
+// CHECK: %[[VAL_4:.*]] = "tosa.slice"(%[[VAL_2]]) {size = array<i64: 1, 12, 20>, start = array<i64: 0, 0, 4>} : (tensor<1x12x24xf32>) -> tensor<1x12x20xf32>
+// CHECK: return %[[VAL_3]], %[[VAL_4]] : tensor<1x12x15xf32>, tensor<1x12x20xf32>
+func.func @canonicalize_cross_concat_inputs(%arg0 : tensor<1x12x12xf32>, %arg1 : tensor<1x12x12xf32>) -> (tensor<1x12x15xf32>, tensor<1x12x20xf32>) {
+  %concat = "tosa.concat"(%arg0, %arg1) {axis = 2 : i64} : (tensor<1x12x12xf32>, tensor<1x12x12xf32>) -> tensor<1x12x24xf32>
+  %slice_head = "tosa.slice"(%concat) {size = array<i64: 1, 12, 15>, start = array<i64: 0, 0, 0>} : (tensor<1x12x24xf32>) -> tensor<1x12x15xf32>
+  %slice_tail = "tosa.slice"(%concat) {size = array<i64: 1, 12, 20>, start = array<i64: 0, 0, 4>} : (tensor<1x12x24xf32>) -> tensor<1x12x20xf32>
+  return %slice_head, %slice_tail : tensor<1x12x15xf32>, tensor<1x12x20xf32>
+}
+
+// -----
+
+// Each slice lies within a single concat operand along axis 2 but is partial
+// along the other axes, so it is rewritten to slice that operand directly.
+// The start along axis 2 must be rebased into the operand's coordinates
+// (12 in the concatenated tensor becomes 0 in the second operand).
+// CHECK-LABEL: @canonicalize_concat_slice_on_non_concat_axis
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x12x12xf32>, %[[VAL_1:.*]]: tensor<1x12x12xf32>
+// CHECK: %[[VAL_2:.*]] = "tosa.slice"(%[[VAL_0]]) {size = array<i64: 1, 6, 12>, start = array<i64: 0, 0, 0>} : (tensor<1x12x12xf32>) -> tensor<1x6x12xf32>
+// CHECK: %[[VAL_3:.*]] = "tosa.slice"(%[[VAL_1]]) {size = array<i64: 1, 3, 12>, start = array<i64: 0, 3, 0>} : (tensor<1x12x12xf32>) -> tensor<1x3x12xf32>
+// CHECK: return %[[VAL_2]], %[[VAL_3]] : tensor<1x6x12xf32>, tensor<1x3x12xf32>
+func.func @canonicalize_concat_slice_on_non_concat_axis(%arg0 : tensor<1x12x12xf32>, %arg1 : tensor<1x12x12xf32>) -> (tensor<1x6x12xf32>, tensor<1x3x12xf32>) {
+  %0 = "tosa.concat"(%arg0, %arg1) {axis = 2 : i64} : (tensor<1x12x12xf32>, tensor<1x12x12xf32>) -> tensor<1x12x24xf32>
+  %1 = "tosa.slice"(%0) {size = array<i64: 1, 6, 12>, start = array<i64: 0, 0, 0>} : (tensor<1x12x24xf32>) -> tensor<1x6x12xf32>
+  %2 = "tosa.slice"(%0) {size = array<i64: 1, 3, 12>, start = array<i64: 0, 3, 12>} : (tensor<1x12x24xf32>) -> tensor<1x3x12xf32>
+  return %1, %2 : tensor<1x6x12xf32>, tensor<1x3x12xf32>
+}
More information about the Mlir-commits
mailing list