[llvm-branch-commits] [mlir] f4bb076 - [mlir][tosa] Add tosa.slice to std.subtensor lowering
Rob Suderman via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Wed Mar 17 17:29:02 PDT 2021
Author: Rob Suderman
Date: 2021-03-17T17:28:18-07:00
New Revision: f4bb076a4419767cf35a17e3c08f392505a5acd2
URL: https://github.com/llvm/llvm-project/commit/f4bb076a4419767cf35a17e3c08f392505a5acd2
DIFF: https://github.com/llvm/llvm-project/commit/f4bb076a4419767cf35a17e3c08f392505a5acd2.diff
LOG: [mlir][tosa] Add tosa.slice to std.subtensor lowering
Lowering to subtensor is added for tosa.slice operator.
Differential Revision: https://reviews.llvm.org/D98825
Added:
Modified:
mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp
mlir/lib/Conversion/TosaToStandard/TosaToStandardPass.cpp
mlir/test/Conversion/TosaToStandard/tosa-to-standard.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp b/mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp
index 21a8da291aee..6e5411dd5ecb 100644
--- a/mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp
+++ b/mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp
@@ -32,9 +32,28 @@ class ConstOpConverter : public OpRewritePattern<tosa::ConstOp> {
}
};
+/// Lowers `tosa.slice` to `subtensor`. The slice's `start` and `size`
+/// attributes are forwarded as-is, the strides are all 1, and the new op
+/// reuses the slice's result type.
+class SliceOpConverter : public OpRewritePattern<tosa::SliceOp> {
+public:
+ using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;
+
+ /// Unconditionally replaces `sliceOp`; there is no failure path.
+ LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
+ PatternRewriter &rewriter) const final {
+ Value input = sliceOp.input();
+ // One unit stride per dimension of the result type.
+ // NOTE(review): cast<ShapedType>() and getRank() assert (rather than return
+ // failure) if the result is not a ranked shaped type; presumably the
+ // tosa.slice verifier guarantees a ranked tensor here -- confirm.
+ SmallVector<int64_t> strides;
+ strides.resize(sliceOp.getType().template cast<ShapedType>().getRank(), 1);
+
+ // NOTE(review): the three empty ValueRanges appear to be the dynamic
+ // offset/size/stride operands, left empty because offsets, sizes and strides
+ // are all passed statically via the attributes that follow -- verify
+ // against the SubTensorOp builder signature.
+ rewriter.replaceOpWithNewOp<SubTensorOp>(
+ sliceOp, sliceOp.getType(), input, ValueRange({}), ValueRange({}),
+ ValueRange({}), sliceOp.start(), sliceOp.size(),
+ rewriter.getI64ArrayAttr(strides));
+
+ return success();
+ }
+};
+
} // namespace
+/// Registers the TOSA-to-Standard rewrite patterns (ConstOpConverter and
+/// SliceOpConverter) into `patterns`.
void mlir::tosa::populateTosaToStandardConversionPatterns(
MLIRContext *context, OwningRewritePatternList *patterns) {
- patterns->insert<ConstOpConverter>(context);
+ patterns->insert<ConstOpConverter, SliceOpConverter>(context);
}
diff --git a/mlir/lib/Conversion/TosaToStandard/TosaToStandardPass.cpp b/mlir/lib/Conversion/TosaToStandard/TosaToStandardPass.cpp
index 225855e78bda..78a0e65da81b 100644
--- a/mlir/lib/Conversion/TosaToStandard/TosaToStandardPass.cpp
+++ b/mlir/lib/Conversion/TosaToStandard/TosaToStandardPass.cpp
@@ -32,7 +32,8 @@ struct TosaToStandard : public TosaToStandardBase<TosaToStandard> {
OwningRewritePatternList patterns;
ConversionTarget target(getContext());
target.addIllegalOp<tosa::ConstOp>();
- target.addLegalOp<ConstantOp>();
+ target.addIllegalOp<tosa::SliceOp>();
+ target.addLegalDialect<StandardOpsDialect>();
auto *op = getOperation();
mlir::tosa::populateTosaToStandardConversionPatterns(op->getContext(),
diff --git a/mlir/test/Conversion/TosaToStandard/tosa-to-standard.mlir b/mlir/test/Conversion/TosaToStandard/tosa-to-standard.mlir
index 86304dcba862..94925aec15c7 100644
--- a/mlir/test/Conversion/TosaToStandard/tosa-to-standard.mlir
+++ b/mlir/test/Conversion/TosaToStandard/tosa-to-standard.mlir
@@ -8,3 +8,11 @@ func @const_test() -> (tensor<i32>) {
// CHECK: return [[C3]]
return %0 : tensor<i32>
}
+
+// ----
+
+// A 1-D tosa.slice with start = [2], size = [1] is expected to lower to a
+// subtensor with offset 2, size 1 and stride 1.
+func @slice(%arg0: tensor<6xf32>) ->() {
+ // CHECK: [[SLICE:%.+]] = subtensor %arg0[2] [1] [1]
+ %0 = "tosa.slice"(%arg0) {start = [2], size = [1]} : (tensor<6xf32>) -> (tensor<1xf32>)
+ return
+}
More information about the llvm-branch-commits mailing list