[Mlir-commits] [mlir] 6b4b63a - Lowering for 'tosa.scatter'
Eric Kunze
llvmlistbot at llvm.org
Tue May 30 14:34:15 PDT 2023
Author: Rafael Ubal Tena
Date: 2023-05-30T14:28:52-07:00
New Revision: 6b4b63a832f105039442fc983d0b309abe5261d5
URL: https://github.com/llvm/llvm-project/commit/6b4b63a832f105039442fc983d0b309abe5261d5
DIFF: https://github.com/llvm/llvm-project/commit/6b4b63a832f105039442fc983d0b309abe5261d5.diff
LOG: Lowering for 'tosa.scatter'
This patch adds support for `tosa.scatter` lowering in the `--tosa-to-scf` pass. Here's an example for this lowering:
func.func @tosa(
%valuesIn : tensor<3x7x5xi32>,
%indices : tensor<3x6xi32>,
%input : tensor<3x6x5xi32>) ->
tensor<3x7x5xi32> {
%0 = "tosa.scatter"(%valuesIn, %indices, %input) :
(tensor<3x7x5xi32>, tensor<3x6xi32>, tensor<3x6x5xi32>) ->
(tensor<3x7x5xi32>)
return %0 : tensor<3x7x5xi32>
}
translates to
func.func @tosa(%arg0: tensor<3x7x5xi32>, %arg1: tensor<3x6xi32>, %arg2: tensor<3x6x5xi32>) -> tensor<3x7x5xi32> {
%c0 = arith.constant 0 : index
%c3 = arith.constant 3 : index
%c1 = arith.constant 1 : index
%c6 = arith.constant 6 : index
%c2 = arith.constant 2 : index
%c5 = arith.constant 5 : index
%c0_0 = arith.constant 0 : index
%c1_1 = arith.constant 1 : index
%0 = scf.for %arg3 = %c0_0 to %c3 step %c1_1 iter_args(%arg4 = %arg0) -> (tensor<3x7x5xi32>) {
%1 = scf.for %arg5 = %c0_0 to %c6 step %c1_1 iter_args(%arg6 = %arg4) -> (tensor<3x7x5xi32>) {
%extracted = tensor.extract %arg1[%arg3, %arg5] : tensor<3x6xi32>
%2 = arith.index_cast %extracted : i32 to index
%extracted_slice = tensor.extract_slice %arg2[%arg3, %arg5, %c0_0] [%c1_1, %c1_1, %c5] [%c1_1, %c1_1, %c1_1] : tensor<3x6x5xi32> to tensor<?x?x?xi32>
%inserted_slice = tensor.insert_slice %extracted_slice into %arg6[%arg3, %2, %c0_0] [%c1_1, %c1_1, %c5] [%c1_1, %c1_1, %c1_1] : tensor<?x?x?xi32> into tensor<3x7x5xi32>
scf.yield %inserted_slice : tensor<3x7x5xi32>
}
scf.yield %1 : tensor<3x7x5xi32>
}
return %0 : tensor<3x7x5xi32>
}
We have attempted an alternative lowering pass that uses `tensor.scatter` as an intermediate step. However, we opted to aim straight at the `scf` dialect for the following reasons:
- The `tensor.scatter` op doesn't seem to be used anywhere. There is no available lowering pass for this op (although we have one that we'll upstream soon).
- The `tosa.scatter` and `tensor.scatter` ops have different indexing semantics. The `indices` argument of `tosa.scatter` must be non-trivially modified and restructured (e.g. with a `linalg.generic` op) to adapt to the needs of `tensor.scatter`. While this overhead may be simplified and fused after a subsequent `tensor.scatter` lowering, it adds complex logic and an obscure intermediate state. Unless there is a good reason to go through the `tensor` dialect that we're missing, this additional complexity may not be justified.
Reviewed By: eric-k256
Differential Revision: https://reviews.llvm.org/D151117
diff --git a/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp b/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp
index 8f10497d99c32..9139bf191fdf1 100644
--- a/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp
+++ b/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp
@@ -82,6 +82,75 @@ class IfOpConverter : public OpRewritePattern<tosa::IfOp> {
+class ScatterOpConverter : public OpRewritePattern<tosa::ScatterOp> {
+ static Value createTensorDim(OpBuilder &builder, Location loc, Value tensor, // Returns a Value holding the size of `tensor` along dimension `dim`.
+ int64_t dim) {
+ return builder.createOrFold<tensor::DimOp>(loc, tensor, dim); // createOrFold: with the static shapes in the accompanying test the sizes show up as arith.constant index values rather than tensor.dim ops.
+ }
+ static Value createIndexConst(OpBuilder &builder, Location loc, // Emits an arith::ConstantIndexOp at `loc` holding `value` and returns it as a Value.
+ int64_t value) {
+ return builder.create<arith::ConstantIndexOp>(loc, value);
+ }
+ using OpRewritePattern<tosa::ScatterOp>::OpRewritePattern;
+ LogicalResult matchAndRewrite(tosa::ScatterOp scatter, // Replaces `scatter` with a 2-deep loop nest that copies one slice of `input` into the running result per (n, w) iteration; always returns success().
+ PatternRewriter &rewriter) const final {
+ auto valuesIn = scatter.getValuesIn(); // tensor being scattered into; seeds the loop-carried accumulator below
+ auto indices = scatter.getIndices(); // read at [n, w] to pick the destination position along dimension 1
+ auto input = scatter.getInput(); // source of the slices that get written
+ auto loc = scatter.getLoc();
+ // N, W, C are chosen to match the TOSA spec; here they are the sizes of dimensions 0, 1 and 2 of `input`.
+ auto dimN = createTensorDim(rewriter, loc, input, 0);
+ auto dimW = createTensorDim(rewriter, loc, input, 1);
+ auto dimC = createTensorDim(rewriter, loc, input, 2);
+ auto zero = createIndexConst(rewriter, loc, 0);
+ auto one = createIndexConst(rewriter, loc, 1);
+ // Loop bounds: two loops, n in [0, dimN) and w in [0, dimW), both with step 1.
+ auto lbs = llvm::SmallVector<Value>(2, zero);
+ auto steps = llvm::SmallVector<Value>(2, one);
+ auto ubs = llvm::SmallVector<Value>{{dimN, dimW}};
+ auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs, // body of the innermost loop: `ivs` = (n, w), `args[0]` = accumulator tensor for this iteration
+ ValueRange args) -> scf::ValueVector {
+ auto n = ivs[0]; // induction variable of the outer (N) loop
+ // Read the index stored at indices[n, w] and cast it to index type; no range check on the value is emitted here.
+ auto index = builder.create<tensor::ExtractOp>(loc, indices, ivs);
+ auto castIndex = builder.create<arith::IndexCastOp>(
+ loc, builder.getIndexType(), index);
+ // Offset, sizes, and strides for the input tensor: offset [n, w, 0], sizes [1, 1, dimC], unit strides.
+ auto inputOffset = llvm::to_vector(ivs);
+ inputOffset.push_back(zero);
+ llvm::SmallVector<Value> sizes = {one, one, dimC}; // passed as Values, so the slice type in the example output is tensor<?x?x?xi32>
+ llvm::SmallVector<Value> strides = {one, one, one};
+ auto slice = builder.create<tensor::ExtractSliceOp>(
+ loc, input, inputOffset, sizes, strides);
+ // Insert the slice into the output accumulator tensor at [n, castIndex, 0], reusing the same sizes and strides.
+ llvm::SmallVector<Value> outputOffset = {n, castIndex, zero};
+ auto updated = builder.create<tensor::InsertSliceOp>(
+ loc, slice, args[0], outputOffset, sizes, strides);
+ return {updated}; // becomes the scf.yield operand, i.e. the accumulator for the next iteration
+ };
+ auto loops = scf::buildLoopNest(rewriter, loc, lbs, ubs, steps, // `valuesIn` is the initial iter_args value of the outer loop
+ ValueRange{valuesIn}, buildBody);
+ rewriter.replaceOp(scatter, loops.results); // the loop nest's result takes the place of the tosa.scatter result
+ return success();
+ }
class WhileOpConverter : public OpRewritePattern<tosa::WhileOp> {
using OpRewritePattern<tosa::WhileOp>::OpRewritePattern;
@@ -106,6 +175,6 @@ class WhileOpConverter : public OpRewritePattern<tosa::WhileOp> {
void mlir::tosa::populateTosaToSCFConversionPatterns(
RewritePatternSet *patterns) {
- patterns->add<IfOpConverter>(patterns->getContext());
- patterns->add<WhileOpConverter>(patterns->getContext());
+ patterns->add<IfOpConverter, ScatterOpConverter, WhileOpConverter>(
+ patterns->getContext());
diff --git a/mlir/lib/Conversion/TosaToSCF/TosaToSCFPass.cpp b/mlir/lib/Conversion/TosaToSCF/TosaToSCFPass.cpp
index 759b730556d7a..d14535029132f 100644
--- a/mlir/lib/Conversion/TosaToSCF/TosaToSCFPass.cpp
+++ b/mlir/lib/Conversion/TosaToSCF/TosaToSCFPass.cpp
@@ -37,7 +37,7 @@ struct TosaToSCF : public impl::TosaToSCFBase<TosaToSCF> {
RewritePatternSet patterns(&getContext());
ConversionTarget target(getContext());
target.addLegalDialect<tensor::TensorDialect, scf::SCFDialect>();
- target.addIllegalOp<tosa::IfOp, tosa::WhileOp>();
+ target.addIllegalOp<tosa::IfOp, tosa::ScatterOp, tosa::WhileOp>();
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
auto *op = getOperation();
diff --git a/mlir/test/Conversion/TosaToSCF/tosa-to-scf.mlir b/mlir/test/Conversion/TosaToSCF/tosa-to-scf.mlir
index 59931137cdf5b..4f0e29539b6e4 100644
--- a/mlir/test/Conversion/TosaToSCF/tosa-to-scf.mlir
+++ b/mlir/test/Conversion/TosaToSCF/tosa-to-scf.mlir
@@ -56,3 +56,33 @@ func.func @if_test(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tensor<i1>)
return %0 : tensor<f32>
+// -----
+// CHECK-LABEL: func @scatter_test
+// CHECK-SAME: ([[VALUES_IN:%.+]]: tensor<3x7x5xi32>, [[INDICES:%.+]]: tensor<3x6xi32>, [[INPUT:%.+]]: tensor<3x6x5xi32>)
+func.func @scatter_test(%values_in: tensor<3x7x5xi32>, %indices : tensor<3x6xi32>, %input: tensor<3x6x5xi32>) -> tensor<3x7x5xi32> {
+ // CHECK-DAG: [[C_0:%.+]] = arith.constant 0 : index
+ // CHECK-DAG: [[C_1:%.+]] = arith.constant 1 : index
+ // CHECK-DAG: [[C_2:%.+]] = arith.constant 2 : index
+ // CHECK-DAG: [[C_3:%.+]] = arith.constant 3 : index
+ // CHECK-DAG: [[C_5:%.+]] = arith.constant 5 : index
+ // CHECK-DAG: [[C_6:%.+]] = arith.constant 6 : index
+ // CHECK-DAG: [[C_0_0:%.+]] = arith.constant 0 : index
+ // CHECK-DAG: [[C_1_0:%.+]] = arith.constant 1 : index
+ // CHECK: [[RESULT_0:%.+]] = scf.for [[ITER_VAR_0:%.+]] = [[C_0_0]] to [[C_3]] step [[C_1_0]] iter_args([[ITER_ARG_0:%.+]] = [[VALUES_IN]]) -> (tensor<3x7x5xi32>) {
+ // CHECK: [[RESULT_1:%.+]] = scf.for [[ITER_VAR_1:%.+]] = [[C_0_0]] to [[C_6]] step [[C_1_0]] iter_args([[ITER_ARG_1:%.+]] = [[ITER_ARG_0]]) -> (tensor<3x7x5xi32>) {
+ // CHECK-DAG: [[EXTRACTED:%.+]] = tensor.extract [[INDICES]][[[ITER_VAR_0]], [[ITER_VAR_1]]] : tensor<3x6xi32>
+ // CHECK-DAG: [[EXTRACTED_CAST:%.+]] = arith.index_cast [[EXTRACTED]] : i32 to index
+ // CHECK-DAG: [[EXTRACTED_SLICE:%.+]] = tensor.extract_slice [[INPUT]][[[ITER_VAR_0]], [[ITER_VAR_1]], [[C_0_0]]] [[[C_1_0]], [[C_1_0]], [[C_5]]] [[[C_1_0]], [[C_1_0]], [[C_1_0]]] : tensor<3x6x5xi32> to tensor<?x?x?xi32>
+ // CHECK-DAG: [[INSERTED_SLICE:%.+]] = tensor.insert_slice [[EXTRACTED_SLICE]] into [[ITER_ARG_1]][[[ITER_VAR_0]], [[EXTRACTED_CAST]], [[C_0_0]]] [[[C_1_0]], [[C_1_0]], [[C_5]]] [[[C_1_0]], [[C_1_0]], [[C_1_0]]] : tensor<?x?x?xi32> into tensor<3x7x5xi32>
+ // CHECK: scf.yield [[INSERTED_SLICE]] : tensor<3x7x5xi32>
+ // CHECK: }
+ // CHECK: scf.yield [[RESULT_1]] : tensor<3x7x5xi32>
+ // CHECK: }
+ %0 = "tosa.scatter"(%values_in, %indices, %input) : (tensor<3x7x5xi32>, tensor<3x6xi32>, tensor<3x6x5xi32>) -> (tensor<3x7x5xi32>)
+ // CHECK: return [[RESULT_0]] : tensor<3x7x5xi32>
+ return %0 : tensor<3x7x5xi32>
