[Mlir-commits] [mlir] 640973f - [tosa] Lower tosa.slice to tensor.slice for dynamic case
Rob Suderman
llvmlistbot at llvm.org
Wed Jun 15 10:01:36 PDT 2022
Author: Rob Suderman
Date: 2022-06-15T09:54:36-07:00
New Revision: 640973f2b99b9b9eb85be096626fd0a7fc7d1dfe
URL: https://github.com/llvm/llvm-project/commit/640973f2b99b9b9eb85be096626fd0a7fc7d1dfe
DIFF: https://github.com/llvm/llvm-project/commit/640973f2b99b9b9eb85be096626fd0a7fc7d1dfe.diff
LOG: [tosa] Lower tosa.slice to tensor.slice for dynamic case
The existing slice lowering only supported static shapes.
Reviewed By: NatashaKnk
Differential Revision: https://reviews.llvm.org/D127704
Added:
Modified:
mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
mlir/lib/Conversion/TosaToTensor/TosaToTensorPass.cpp
mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
index c02108eee265e..c8c326d35d33c 100644
--- a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
+++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
@@ -11,6 +11,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/TosaToTensor/TosaToTensor.h"
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/PatternMatch.h"
@@ -27,14 +28,32 @@ class SliceOpConverter : public OpRewritePattern<tosa::SliceOp> {
LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
PatternRewriter &rewriter) const final {
+ // Lowers tosa.slice to tensor.extract_slice with unit strides. A size entry
+ // of -1 selects everything from `start` to the end of that dimension; that
+ // extent is only known at runtime, so it is materialized as
+ // dim(input, i) - start[i] and passed as a dynamic size operand.
+ Location loc = sliceOp.getLoc();
Value input = sliceOp.input();
SmallVector<int64_t> strides;
+ auto starts = sliceOp.start();
+ auto sizes = sliceOp.size();
strides.resize(sliceOp.getType().template cast<ShapedType>().getRank(), 1);
- rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
- sliceOp, sliceOp.getType(), input, ValueRange({}), ValueRange({}),
- ValueRange({}), sliceOp.start(), sliceOp.size(),
- rewriter.getI64ArrayAttr(strides));
+ // tosa.slice spells a dynamic size as -1, while tensor.extract_slice marks
+ // dynamic entries with ShapedType::kDynamicSize. Translate explicitly rather
+ // than relying on the two sentinels happening to share the same value.
+ constexpr int64_t kTosaDynamicSize = -1;
+ SmallVector<int64_t> staticSizes;
+ SmallVector<Value> dynSizes;
+ for (auto it : llvm::enumerate(sizes)) {
+ int64_t size = it.value().cast<IntegerAttr>().getInt();
+ size_t index = it.index();
+ if (size != kTosaDynamicSize) {
+ staticSizes.push_back(size);
+ continue;
+ }
+
+ staticSizes.push_back(ShapedType::kDynamicSize);
+ // Remaining extent of this dimension: dim(input, index) - start[index].
+ auto dim = rewriter.create<tensor::DimOp>(loc, input, index);
+ auto offset = rewriter.create<arith::ConstantOp>(
+ loc,
+ rewriter.getIndexAttr(starts[index].cast<IntegerAttr>().getInt()));
+ dynSizes.push_back(rewriter.create<arith::SubIOp>(loc, dim, offset));
+ }
+
+ // Offsets and strides are fully static; only sizes may carry SSA operands.
+ rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
+ sliceOp, sliceOp.getType(), input, ValueRange({}), dynSizes,
+ ValueRange({}), starts, rewriter.getI64ArrayAttr(staticSizes),
+ rewriter.getI64ArrayAttr(strides));
return success();
}
};
diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensorPass.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensorPass.cpp
index 6fe862b46a2ac..08d5c7d50640b 100644
--- a/mlir/lib/Conversion/TosaToTensor/TosaToTensorPass.cpp
+++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensorPass.cpp
@@ -12,6 +12,7 @@
#include "../PassDetail.h"
#include "mlir/Conversion/TosaToTensor/TosaToTensor.h"
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/PassDetail.h"
@@ -31,6 +32,7 @@ struct TosaToTensor : public TosaToTensorBase<TosaToTensor> {
RewritePatternSet patterns(&getContext());
ConversionTarget target(getContext());
target.addIllegalOp<tosa::SliceOp>();
+ target.addLegalDialect<arith::ArithmeticDialect>();
target.addLegalDialect<tensor::TensorDialect>();
mlir::tosa::populateTosaToTensorConversionPatterns(&patterns);
diff --git a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
index 12eb51f58abed..15a4bcd7498fd 100644
--- a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
+++ b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
@@ -6,3 +6,16 @@ func.func @slice(%arg0: tensor<6xf32>) ->() {
%0 = "tosa.slice"(%arg0) {start = [2], size = [1]} : (tensor<6xf32>) -> (tensor<1xf32>)
return
}
+
+// -----
+
+// A size of -1 slices from `start` to the end of the dimension, so the
+// lowering computes the extent at runtime as dim(%arg0, 0) - 2.
+// CHECK-LABEL: func @slice_dyn
+func.func @slice_dyn(%arg0: tensor<?xf32>) -> (tensor<?xf32>) {
+ // CHECK: %[[C0:.+]] = arith.constant 0 : index
+ // CHECK: %[[DIM:.+]] = tensor.dim %arg0, %[[C0]]
+ // CHECK: %[[C2:.+]] = arith.constant 2 : index
+ // CHECK: %[[SUB:.+]] = arith.subi %[[DIM]], %[[C2]]
+ // CHECK: %[[SLICE:.+]] = tensor.extract_slice %arg0[2] [%[[SUB]]] [1]
+ // CHECK: return %[[SLICE]]
+ %0 = "tosa.slice"(%arg0) {start = [2], size = [-1]} : (tensor<?xf32>) -> (tensor<?xf32>)
+ return %0 : tensor<?xf32>
+}
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 23308f9a0bc01..e918670288e2b 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -7897,6 +7897,7 @@ cc_library(
"lib/Conversion/TosaToTensor",
],
deps = [
+ ":ArithmeticDialect",
":ConversionPassIncGen",
":FuncDialect",
":IR",
More information about the Mlir-commits
mailing list