[Mlir-commits] [mlir] 9a807b8 - [mlir][tosa] Fix start index in slice canonicalization
Jacques Pienaar
llvmlistbot at llvm.org
Fri Jul 28 16:57:19 PDT 2023
Author: Jacques Pienaar
Date: 2023-07-28T16:57:12-07:00
New Revision: 9a807b8763d644b3c7da018c12bf45937d4f7ea5
URL: https://github.com/llvm/llvm-project/commit/9a807b8763d644b3c7da018c12bf45937d4f7ea5
DIFF: https://github.com/llvm/llvm-project/commit/9a807b8763d644b3c7da018c12bf45937d4f7ea5.diff
LOG: [mlir][tosa] Fix start index in slice canonicalization
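The updated start indices weren't being used: the replacement slice was built from the original op's start attribute (`sliceOp.getStart()`) rather than from the adjusted `sliceStart` values.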
Differential Revision: https://reviews.llvm.org/D156567
Added:
Modified:
mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
mlir/test/Dialect/Tosa/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 152b8857393bba..e69c40f2b05239 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -406,13 +406,12 @@ struct ConcatSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
if (sliceStart[axis] >= 0 &&
(sliceStart[axis] + sliceSize[axis]) <= inputType.getDimSize(axis)) {
- replaceWithSlice =
- rewriter
- .create<tosa::SliceOp>(
- sliceOp.getLoc(), sliceOp.getType(), input,
- rewriter.getDenseI64ArrayAttr(sliceOp.getStart()),
- rewriter.getDenseI64ArrayAttr(sliceSize))
- .getResult();
+ replaceWithSlice = rewriter
+ .create<tosa::SliceOp>(
+ sliceOp.getLoc(), sliceOp.getType(), input,
+ rewriter.getDenseI64ArrayAttr(sliceStart),
+ rewriter.getDenseI64ArrayAttr(sliceSize))
+ .getResult();
break;
}
sliceStart[axis] -= inputType.getDimSize(axis);
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 1f9ce1f4c7b929..57c52b7fa163ad 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -542,7 +542,7 @@ func.func @canonicalize_cross_concat_inputs(%arg0 : tensor<1x12x12xf32>, %arg1 :
// CHECK-LABEL: @canonicalize_concat_slice_on_non_concat_axis
// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x12x12xf32>, %[[VAL_1:.*]]: tensor<1x12x12xf32>
// CHECK: %[[VAL_2:.*]] = "tosa.slice"(%[[VAL_0]]) <{size = array<i64: 1, 6, 12>, start = array<i64: 0, 0, 0>}> : (tensor<1x12x12xf32>) -> tensor<1x6x12xf32>
-// CHECK: %[[VAL_3:.*]] = "tosa.slice"(%[[VAL_1]]) <{size = array<i64: 1, 3, 12>, start = array<i64: 1, 3, 12>}> : (tensor<1x12x12xf32>) -> tensor<1x3x12xf32>
+// CHECK: %[[VAL_3:.*]] = "tosa.slice"(%[[VAL_1]]) <{size = array<i64: 1, 3, 12>, start = array<i64: 1, 3, 0>}> : (tensor<1x12x12xf32>) -> tensor<1x3x12xf32>
// CHECK: return %[[VAL_2]], %[[VAL_3]] : tensor<1x6x12xf32>, tensor<1x3x12xf32>
func.func @canonicalize_concat_slice_on_non_concat_axis(%arg0 : tensor<1x12x12xf32>, %arg1 : tensor<1x12x12xf32>) -> (tensor<1x6x12xf32>, tensor<1x3x12xf32>) {
%0 = "tosa.concat"(%arg0, %arg1) {axis = 2 : i64} : (tensor<1x12x12xf32>, tensor<1x12x12xf32>) -> tensor<1x12x24xf32>
More information about the Mlir-commits mailing list