[Mlir-commits] [mlir] 9a807b8 - [mlir][tosa] Fix start index in slice canonicalization

Jacques Pienaar llvmlistbot at llvm.org
Fri Jul 28 16:57:19 PDT 2023


Author: Jacques Pienaar
Date: 2023-07-28T16:57:12-07:00
New Revision: 9a807b8763d644b3c7da018c12bf45937d4f7ea5

URL: https://github.com/llvm/llvm-project/commit/9a807b8763d644b3c7da018c12bf45937d4f7ea5
DIFF: https://github.com/llvm/llvm-project/commit/9a807b8763d644b3c7da018c12bf45937d4f7ea5.diff

LOG: [mlir][tosa] Fix start index in slice canonicalization

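The updated start indices weren't being used: the replacement slice was built with the original slice's start (`sliceOp.getStart()`, relative to the concat result) instead of `sliceStart`, which is adjusted to be relative to the selected concat input.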

Differential Revision: https://reviews.llvm.org/D156567

Added: 
    

Modified: 
    mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
    mlir/test/Dialect/Tosa/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 152b8857393bba..e69c40f2b05239 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -406,13 +406,12 @@ struct ConcatSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
 
       if (sliceStart[axis] >= 0 &&
           (sliceStart[axis] + sliceSize[axis]) <= inputType.getDimSize(axis)) {
-        replaceWithSlice =
-            rewriter
-                .create<tosa::SliceOp>(
-                    sliceOp.getLoc(), sliceOp.getType(), input,
-                    rewriter.getDenseI64ArrayAttr(sliceOp.getStart()),
-                    rewriter.getDenseI64ArrayAttr(sliceSize))
-                .getResult();
+        replaceWithSlice = rewriter
+                               .create<tosa::SliceOp>(
+                                   sliceOp.getLoc(), sliceOp.getType(), input,
+                                   rewriter.getDenseI64ArrayAttr(sliceStart),
+                                   rewriter.getDenseI64ArrayAttr(sliceSize))
+                               .getResult();
         break;
       }
       sliceStart[axis] -= inputType.getDimSize(axis);

diff  --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 1f9ce1f4c7b929..57c52b7fa163ad 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -542,7 +542,7 @@ func.func @canonicalize_cross_concat_inputs(%arg0 : tensor<1x12x12xf32>, %arg1 :
 // CHECK-LABEL: @canonicalize_concat_slice_on_non_concat_axis
 // CHECK-SAME: %[[VAL_0:.*]]: tensor<1x12x12xf32>, %[[VAL_1:.*]]: tensor<1x12x12xf32>
 // CHECK: %[[VAL_2:.*]] = "tosa.slice"(%[[VAL_0]]) <{size = array<i64: 1, 6, 12>, start = array<i64: 0, 0, 0>}> : (tensor<1x12x12xf32>) -> tensor<1x6x12xf32>
-// CHECK: %[[VAL_3:.*]] = "tosa.slice"(%[[VAL_1]]) <{size = array<i64: 1, 3, 12>, start = array<i64: 1, 3, 12>}> : (tensor<1x12x12xf32>) -> tensor<1x3x12xf32>
+// CHECK: %[[VAL_3:.*]] = "tosa.slice"(%[[VAL_1]]) <{size = array<i64: 1, 3, 12>, start = array<i64: 1, 3, 0>}> : (tensor<1x12x12xf32>) -> tensor<1x3x12xf32>
 // CHECK: return %[[VAL_2]], %[[VAL_3]] : tensor<1x6x12xf32>, tensor<1x3x12xf32>
 func.func @canonicalize_concat_slice_on_non_concat_axis(%arg0 : tensor<1x12x12xf32>, %arg1 : tensor<1x12x12xf32>) -> (tensor<1x6x12xf32>, tensor<1x3x12xf32>) {
   %0 = "tosa.concat"(%arg0, %arg1) {axis = 2 : i64} : (tensor<1x12x12xf32>, tensor<1x12x12xf32>) -> tensor<1x12x24xf32>


        


More information about the Mlir-commits mailing list