[llvm-branch-commits] [mlir] f4bb076 - [mlir][tosa] Add tosa.slice to std.subtensor lowering

Rob Suderman via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Wed Mar 17 17:29:02 PDT 2021


Author: Rob Suderman
Date: 2021-03-17T17:28:18-07:00
New Revision: f4bb076a4419767cf35a17e3c08f392505a5acd2

URL: https://github.com/llvm/llvm-project/commit/f4bb076a4419767cf35a17e3c08f392505a5acd2
DIFF: https://github.com/llvm/llvm-project/commit/f4bb076a4419767cf35a17e3c08f392505a5acd2.diff

LOG: [mlir][tosa] Add tosa.slice to std.subtensor lowering

Adds a lowering from the tosa.slice operator to std.subtensor.

Differential Revision: https://reviews.llvm.org/D98825

Added: 
    

Modified: 
    mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp
    mlir/lib/Conversion/TosaToStandard/TosaToStandardPass.cpp
    mlir/test/Conversion/TosaToStandard/tosa-to-standard.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp b/mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp
index 21a8da291aee..6e5411dd5ecb 100644
--- a/mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp
+++ b/mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp
@@ -32,9 +32,35 @@ class ConstOpConverter : public OpRewritePattern<tosa::ConstOp> {
   }
 };
 
+/// Lowers `tosa.slice` to `std.subtensor`.
+///
+/// The `start` and `size` attributes become the subtensor's static offsets and
+/// sizes. tosa.slice has no stride, so every stride is 1, and no dynamic
+/// offset/size/stride operands are created.
+class SliceOpConverter : public OpRewritePattern<tosa::SliceOp> {
+public:
+  using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
+                                PatternRewriter &rewriter) const final {
+    Value input = sliceOp.input();
+    // One unit stride per dimension of the result.
+    int64_t rank = sliceOp.getType().cast<ShapedType>().getRank();
+    SmallVector<int64_t> strides(rank, 1);
+
+    // Offsets, sizes and strides are all static: no dynamic operands.
+    rewriter.replaceOpWithNewOp<SubTensorOp>(
+        sliceOp, sliceOp.getType(), input, ValueRange({}), ValueRange({}),
+        ValueRange({}), sliceOp.start(), sliceOp.size(),
+        rewriter.getI64ArrayAttr(strides));
+
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::tosa::populateTosaToStandardConversionPatterns(
     MLIRContext *context, OwningRewritePatternList *patterns) {
-  patterns->insert<ConstOpConverter>(context);
+  patterns->insert<ConstOpConverter, SliceOpConverter>(context);
 }

diff  --git a/mlir/lib/Conversion/TosaToStandard/TosaToStandardPass.cpp b/mlir/lib/Conversion/TosaToStandard/TosaToStandardPass.cpp
index 225855e78bda..78a0e65da81b 100644
--- a/mlir/lib/Conversion/TosaToStandard/TosaToStandardPass.cpp
+++ b/mlir/lib/Conversion/TosaToStandard/TosaToStandardPass.cpp
@@ -32,7 +32,8 @@ struct TosaToStandard : public TosaToStandardBase<TosaToStandard> {
     OwningRewritePatternList patterns;
     ConversionTarget target(getContext());
     target.addIllegalOp<tosa::ConstOp>();
-    target.addLegalOp<ConstantOp>();
+    target.addIllegalOp<tosa::SliceOp>();
+    target.addLegalDialect<StandardOpsDialect>();
 
     auto *op = getOperation();
     mlir::tosa::populateTosaToStandardConversionPatterns(op->getContext(),

diff  --git a/mlir/test/Conversion/TosaToStandard/tosa-to-standard.mlir b/mlir/test/Conversion/TosaToStandard/tosa-to-standard.mlir
index 86304dcba862..94925aec15c7 100644
--- a/mlir/test/Conversion/TosaToStandard/tosa-to-standard.mlir
+++ b/mlir/test/Conversion/TosaToStandard/tosa-to-standard.mlir
@@ -8,3 +8,12 @@ func @const_test() -> (tensor<i32>) {
   // CHECK: return [[C3]]
   return %0 : tensor<i32>
 }
+
+// -----
+
+// CHECK-LABEL: func @slice
+func @slice(%arg0: tensor<6xf32>) -> () {
+  // CHECK: [[SLICE:%.+]] = subtensor %arg0[2] [1] [1] : tensor<6xf32> to tensor<1xf32>
+  %0 = "tosa.slice"(%arg0) {start = [2], size = [1]} : (tensor<6xf32>) -> (tensor<1xf32>)
+  return
+}


        


More information about the llvm-branch-commits mailing list