[Mlir-commits] [mlir] 640973f - [tosa] Lower tosa.slice to tensor.slice for dynamic case

Rob Suderman llvmlistbot at llvm.org
Wed Jun 15 10:01:36 PDT 2022


Author: Rob Suderman
Date: 2022-06-15T09:54:36-07:00
New Revision: 640973f2b99b9b9eb85be096626fd0a7fc7d1dfe

URL: https://github.com/llvm/llvm-project/commit/640973f2b99b9b9eb85be096626fd0a7fc7d1dfe
DIFF: https://github.com/llvm/llvm-project/commit/640973f2b99b9b9eb85be096626fd0a7fc7d1dfe.diff

LOG: [tosa] Lower tosa.slice to tensor.slice for dynamic case

The existing slice lowering only supported static shapes.

Reviewed By: NatashaKnk

Differential Revision: https://reviews.llvm.org/D127704

Added: 
    

Modified: 
    mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
    mlir/lib/Conversion/TosaToTensor/TosaToTensorPass.cpp
    mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
    utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
index c02108eee265e..c8c326d35d33c 100644
--- a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
+++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
@@ -11,6 +11,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Conversion/TosaToTensor/TosaToTensor.h"
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
 #include "mlir/IR/PatternMatch.h"
@@ -27,14 +28,34 @@ class SliceOpConverter : public OpRewritePattern<tosa::SliceOp> {

   LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
                                 PatternRewriter &rewriter) const final {
+    Location loc = sliceOp.getLoc();
     Value input = sliceOp.input();
     SmallVector<int64_t> strides;
+    auto starts = sliceOp.start();
+    auto sizes = sliceOp.size();
     strides.resize(sliceOp.getType().template cast<ShapedType>().getRank(), 1);
 
-    rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
-        sliceOp, sliceOp.getType(), input, ValueRange({}), ValueRange({}),
-        ValueRange({}), sliceOp.start(), sliceOp.size(),
-        rewriter.getI64ArrayAttr(strides));
+    // Each size equal to ShapedType::kDynamicSize becomes a runtime operand of
+    // extract_slice, computed as dim(input, i) - start[i].
+    SmallVector<Value> dynSizes;
+    for (auto i : llvm::enumerate(sizes)) {
+      int64_t size = i.value().cast<IntegerAttr>().getInt();
+      size_t index = i.index();
+      if (size != ShapedType::kDynamicSize)
+        continue;
+
+      auto dim = rewriter.create<tensor::DimOp>(loc, input, index);
+      auto offset = rewriter.create<arith::ConstantOp>(
+          loc,
+          rewriter.getIndexAttr(starts[index].cast<IntegerAttr>().getInt()));
+      dynSizes.push_back(rewriter.create<arith::SubIOp>(loc, dim, offset));
+    }
+
+    auto newSliceOp = rewriter.create<tensor::ExtractSliceOp>(
+        loc, sliceOp.getType(), input, ValueRange({}), dynSizes,
+        ValueRange({}), starts, sizes, rewriter.getI64ArrayAttr(strides));
+
+    rewriter.replaceOp(sliceOp, newSliceOp.getResult());
     return success();
   }
 };

diff  --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensorPass.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensorPass.cpp
index 6fe862b46a2ac..08d5c7d50640b 100644
--- a/mlir/lib/Conversion/TosaToTensor/TosaToTensorPass.cpp
+++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensorPass.cpp
@@ -12,6 +12,7 @@
 
 #include "../PassDetail.h"
 #include "mlir/Conversion/TosaToTensor/TosaToTensor.h"
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
 #include "mlir/Dialect/Tosa/Transforms/PassDetail.h"
@@ -31,6 +32,7 @@ struct TosaToTensor : public TosaToTensorBase<TosaToTensor> {
     RewritePatternSet patterns(&getContext());
     ConversionTarget target(getContext());
     target.addIllegalOp<tosa::SliceOp>();
+    target.addLegalDialect<arith::ArithmeticDialect>();
     target.addLegalDialect<tensor::TensorDialect>();
 
     mlir::tosa::populateTosaToTensorConversionPatterns(&patterns);

diff  --git a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
index 12eb51f58abed..15a4bcd7498fd 100644
--- a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
+++ b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
@@ -6,3 +6,16 @@ func.func @slice(%arg0: tensor<6xf32>) ->() {
   %0 = "tosa.slice"(%arg0) {start = [2], size = [1]} : (tensor<6xf32>)  -> (tensor<1xf32>)
   return
 }
+
+// -----
+
+// CHECK-LABEL: func @slice_dyn
+func.func @slice_dyn(%arg0: tensor<?xf32>) -> (tensor<?xf32>) {
+  // CHECK: %[[C0:.+]] = arith.constant 0 : index
+  // CHECK: %[[DIM:.+]] = tensor.dim %arg0, %[[C0]]
+  // CHECK: %[[C2:.+]] = arith.constant 2 : index
+  // CHECK: %[[SUB:.+]] = arith.subi %[[DIM]], %[[C2]]
+  // CHECK: %[[SLICE:.+]] = tensor.extract_slice %arg0[2] [%[[SUB]]] [1]
+  %0 = "tosa.slice"(%arg0) {start = [2], size = [-1]} : (tensor<?xf32>)  -> (tensor<?xf32>)
+  return %0 : tensor<?xf32>
+}

diff  --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 23308f9a0bc01..e918670288e2b 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -7897,6 +7897,7 @@ cc_library(
         "lib/Conversion/TosaToTensor",
     ],
     deps = [
+        ":ArithmeticDialect",
         ":ConversionPassIncGen",
         ":FuncDialect",
         ":IR",


        


More information about the Mlir-commits mailing list