[Mlir-commits] [mlir] 6b4b63a - Lowering for 'tosa.scatter'

Eric Kunze llvmlistbot at llvm.org
Tue May 30 14:34:15 PDT 2023


Author: Rafael Ubal Tena
Date: 2023-05-30T14:28:52-07:00
New Revision: 6b4b63a832f105039442fc983d0b309abe5261d5

URL: https://github.com/llvm/llvm-project/commit/6b4b63a832f105039442fc983d0b309abe5261d5
DIFF: https://github.com/llvm/llvm-project/commit/6b4b63a832f105039442fc983d0b309abe5261d5.diff

LOG: Lowering for 'tosa.scatter'

This patch adds support for lowering `tosa.scatter` in the `--tosa-to-scf` pass. Here is an example of this lowering:

```
func.func @tosa(
    %valuesIn : tensor<3x7x5xi32>,
    %indices : tensor<3x6xi32>,
    %input : tensor<3x6x5xi32>) -> tensor<3x7x5xi32> {
  %0 = "tosa.scatter"(%valuesIn, %indices, %input) :
      (tensor<3x7x5xi32>, tensor<3x6xi32>, tensor<3x6x5xi32>) ->
      (tensor<3x7x5xi32>)
  return %0 : tensor<3x7x5xi32>
}
```

translates to:

```
  func.func @tosa(%arg0: tensor<3x7x5xi32>, %arg1: tensor<3x6xi32>, %arg2: tensor<3x6x5xi32>) -> tensor<3x7x5xi32> {
    %c0 = arith.constant 0 : index
    %c3 = arith.constant 3 : index
    %c1 = arith.constant 1 : index
    %c6 = arith.constant 6 : index
    %c2 = arith.constant 2 : index
    %c5 = arith.constant 5 : index
    %c0_0 = arith.constant 0 : index
    %c1_1 = arith.constant 1 : index
    %0 = scf.for %arg3 = %c0_0 to %c3 step %c1_1 iter_args(%arg4 = %arg0) -> (tensor<3x7x5xi32>) {
      %1 = scf.for %arg5 = %c0_0 to %c6 step %c1_1 iter_args(%arg6 = %arg4) -> (tensor<3x7x5xi32>) {
        %extracted = tensor.extract %arg1[%arg3, %arg5] : tensor<3x6xi32>
        %2 = arith.index_cast %extracted : i32 to index
        %extracted_slice = tensor.extract_slice %arg2[%arg3, %arg5, %c0_0] [%c1_1, %c1_1, %c5] [%c1_1, %c1_1, %c1_1] : tensor<3x6x5xi32> to tensor<?x?x?xi32>
        %inserted_slice = tensor.insert_slice %extracted_slice into %arg6[%arg3, %2, %c0_0] [%c1_1, %c1_1, %c5] [%c1_1, %c1_1, %c1_1] : tensor<?x?x?xi32> into tensor<3x7x5xi32>
        scf.yield %inserted_slice : tensor<3x7x5xi32>
      }
      scf.yield %1 : tensor<3x7x5xi32>
    }
    return %0 : tensor<3x7x5xi32>
  }
```
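The loop nest above starts from `values_in` (shape N x K x C). For every batch `n` and input row `w`, it copies the row `input[n][w][0:C]` into `result[n][indices[n][w]][0:C]`. The plain C++ sketch below spells out those semantics without MLIR, which may help when reading the generated IR. `Tensor3D` and `scatterReference` are hypothetical helpers written only for illustration. The bounds check is an addition of this sketch; the generated IR does not perform one.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical helper: a dense, row-major 3-D tensor of i32 values.
struct Tensor3D {
  std::size_t d0, d1, d2;
  std::vector<int32_t> data;

  Tensor3D(std::size_t a, std::size_t b, std::size_t c)
      : d0(a), d1(b), d2(c), data(a * b * c, 0) {}

  int32_t &at(std::size_t i, std::size_t j, std::size_t k) {
    return data[(i * d1 + j) * d2 + k];
  }
  int32_t at(std::size_t i, std::size_t j, std::size_t k) const {
    return data[(i * d1 + j) * d2 + k];
  }
};

// Reference semantics of the generated loop nest.
//   valuesIn: N x K x C, indices: N x W, input: N x W x C -> result: N x K x C
// valuesIn is taken by value, mirroring the loop-carried tensor (iter_args).
Tensor3D scatterReference(Tensor3D valuesIn,
                          const std::vector<std::vector<int32_t>> &indices,
                          const Tensor3D &input) {
  const std::size_t N = input.d0, W = input.d1, C = input.d2;
  assert(valuesIn.d0 == N && valuesIn.d2 == C && indices.size() == N);

  for (std::size_t n = 0; n < N; ++n) {    // outer scf.for over N
    assert(indices[n].size() == W);
    for (std::size_t w = 0; w < W; ++w) {  // inner scf.for over W
      // tensor.extract + arith.index_cast: read the destination row k.
      const int32_t k = indices[n][w];
      // Added by this sketch only: the generated IR performs no bounds check.
      if (k < 0 || static_cast<std::size_t>(k) >= valuesIn.d1)
        throw std::out_of_range("scatter index outside [0, K)");
      // extract_slice input[n, w, 0:C] / insert_slice into result[n, k, 0:C].
      for (std::size_t c = 0; c < C; ++c)
        valuesIn.at(n, static_cast<std::size_t>(k), c) = input.at(n, w, c);
    }
  }
  return valuesIn;  // value yielded by the outermost loop
}
```

Taking `valuesIn` by value mirrors the tensor value semantics of the IR: the caller's original tensor is left unchanged, and the updated copy plays the role of the yielded result.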

We also tried an alternative lowering that uses `tensor.scatter` as an intermediate step. However, we chose to target the `scf` dialect directly, for the following reasons:

- The `tensor.scatter` op doesn't seem to be used anywhere. There is no available lowering pass for this op (although we have one that we'll upstream soon).
- The `tosa.scatter` and `tensor.scatter` ops have different indexing semantics. The `indices` argument of `tosa.scatter` must be non-trivially modified and restructured (e.g. with a `linalg.generic` op) to match what `tensor.scatter` expects. While this overhead might be simplified and fused away by a subsequent `tensor.scatter` lowering, it adds complex logic and an obscure intermediate state. Unless we are missing a good reason to go through the `tensor` dialect, this additional complexity may not be justified.
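The restructuring mentioned in the second point can be illustrated roughly as follows. A `tosa.scatter` index stores only the `K` coordinate for each `(n, w)` pair, and the batch coordinate is implicit. A `tensor.scatter`-style index instead carries one explicit coordinate per scattered dimension. The plain C++ sketch below shows that expansion. `expandTosaIndices` is a hypothetical helper. The exact indices layout `tensor.scatter` would require, including its `scatter_dims` and any matching reshape of `input`, is an assumption here and is not taken from this patch.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper: turn TOSA-style indices (N x W, one K coordinate per
// entry, batch coordinate implicit) into explicit (n, k) coordinate pairs,
// flattened row-major over (n, w), i.e. an (N*W) x 2 table.
std::vector<std::array<std::size_t, 2>>
expandTosaIndices(const std::vector<std::vector<int32_t>> &indices) {
  std::vector<std::array<std::size_t, 2>> coords;
  for (std::size_t n = 0; n < indices.size(); ++n) {
    for (int32_t k : indices[n]) {
      assert(k >= 0 && "negative scatter index");
      // Materialize the implicit batch coordinate and widen i32 to an
      // index-like unsigned type.
      coords.push_back({n, static_cast<std::size_t>(k)});
    }
  }
  return coords;
}
```

Even in this simplified form, the batch coordinate has to be materialized and every index widened. This is the kind of extra intermediate state that the direct `scf` lowering avoids.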

Reviewed By: eric-k256

Differential Revision: https://reviews.llvm.org/D151117

Added: 
    

Modified: 
    mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp
    mlir/lib/Conversion/TosaToSCF/TosaToSCFPass.cpp
    mlir/test/Conversion/TosaToSCF/tosa-to-scf.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp b/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp
index 8f10497d99c32..9139bf191fdf1 100644
--- a/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp
+++ b/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp
@@ -82,6 +82,75 @@ class IfOpConverter : public OpRewritePattern<tosa::IfOp> {
   }
 };
 
+class ScatterOpConverter : public OpRewritePattern<tosa::ScatterOp> {
+  static Value createTensorDim(OpBuilder &builder, Location loc, Value tensor,
+                               int64_t dim) {
+    return builder.createOrFold<tensor::DimOp>(loc, tensor, dim);
+  }
+
+  static Value createIndexConst(OpBuilder &builder, Location loc,
+                                int64_t value) {
+    return builder.create<arith::ConstantIndexOp>(loc, value);
+  }
+
+public:
+  using OpRewritePattern<tosa::ScatterOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::ScatterOp scatter,
+                                PatternRewriter &rewriter) const final {
+    auto valuesIn = scatter.getValuesIn();
+    auto indices = scatter.getIndices();
+    auto input = scatter.getInput();
+    auto loc = scatter.getLoc();
+
+    // N, W, C are chosen to match the TOSA spec
+    auto dimN = createTensorDim(rewriter, loc, input, 0);
+    auto dimW = createTensorDim(rewriter, loc, input, 1);
+    auto dimC = createTensorDim(rewriter, loc, input, 2);
+
+    auto zero = createIndexConst(rewriter, loc, 0);
+    auto one = createIndexConst(rewriter, loc, 1);
+
+    // Loop bounds
+    auto lbs = llvm::SmallVector<Value>(2, zero);
+    auto steps = llvm::SmallVector<Value>(2, one);
+    auto ubs = llvm::SmallVector<Value>{{dimN, dimW}};
+
+    auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs,
+                         ValueRange args) -> scf::ValueVector {
+      auto n = ivs[0];
+
+      // Read the index and cast it to index type
+      auto index = builder.create<tensor::ExtractOp>(loc, indices, ivs);
+      auto castIndex = builder.create<arith::IndexCastOp>(
+          loc, builder.getIndexType(), index);
+
+      // Offset, sizes, and strides for the input tensor
+      auto inputOffset = llvm::to_vector(ivs);
+      inputOffset.push_back(zero);
+
+      llvm::SmallVector<Value> sizes = {one, one, dimC};
+      llvm::SmallVector<Value> strides = {one, one, one};
+
+      auto slice = builder.create<tensor::ExtractSliceOp>(
+          loc, input, inputOffset, sizes, strides);
+
+      // Insert the slice into the output accumulator tensor.
+      llvm::SmallVector<Value> outputOffset = {n, castIndex, zero};
+      auto updated = builder.create<tensor::InsertSliceOp>(
+          loc, slice, args[0], outputOffset, sizes, strides);
+
+      return {updated};
+    };
+
+    auto loops = scf::buildLoopNest(rewriter, loc, lbs, ubs, steps,
+                                    ValueRange{valuesIn}, buildBody);
+    rewriter.replaceOp(scatter, loops.results);
+
+    return success();
+  }
+};
+
 class WhileOpConverter : public OpRewritePattern<tosa::WhileOp> {
 public:
   using OpRewritePattern<tosa::WhileOp>::OpRewritePattern;
@@ -106,6 +175,6 @@ class WhileOpConverter : public OpRewritePattern<tosa::WhileOp> {
 
 void mlir::tosa::populateTosaToSCFConversionPatterns(
     RewritePatternSet *patterns) {
-  patterns->add<IfOpConverter>(patterns->getContext());
-  patterns->add<WhileOpConverter>(patterns->getContext());
+  patterns->add<IfOpConverter, ScatterOpConverter, WhileOpConverter>(
+      patterns->getContext());
 }

diff  --git a/mlir/lib/Conversion/TosaToSCF/TosaToSCFPass.cpp b/mlir/lib/Conversion/TosaToSCF/TosaToSCFPass.cpp
index 759b730556d7a..d14535029132f 100644
--- a/mlir/lib/Conversion/TosaToSCF/TosaToSCFPass.cpp
+++ b/mlir/lib/Conversion/TosaToSCF/TosaToSCFPass.cpp
@@ -37,7 +37,7 @@ struct TosaToSCF : public impl::TosaToSCFBase<TosaToSCF> {
     RewritePatternSet patterns(&getContext());
     ConversionTarget target(getContext());
     target.addLegalDialect<tensor::TensorDialect, scf::SCFDialect>();
-    target.addIllegalOp<tosa::IfOp, tosa::WhileOp>();
+    target.addIllegalOp<tosa::IfOp, tosa::ScatterOp, tosa::WhileOp>();
     target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
 
     auto *op = getOperation();

diff  --git a/mlir/test/Conversion/TosaToSCF/tosa-to-scf.mlir b/mlir/test/Conversion/TosaToSCF/tosa-to-scf.mlir
index 59931137cdf5b..4f0e29539b6e4 100644
--- a/mlir/test/Conversion/TosaToSCF/tosa-to-scf.mlir
+++ b/mlir/test/Conversion/TosaToSCF/tosa-to-scf.mlir
@@ -56,3 +56,33 @@ func.func @if_test(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tensor<i1>)
 
   return %0 : tensor<f32>
 }
+
+// -----
+
+// CHECK-LABEL: func @scatter_test
+// CHECK-SAME: ([[VALUES_IN:%.+]]: tensor<3x7x5xi32>, [[INDICES:%.+]]: tensor<3x6xi32>, [[INPUT:%.+]]: tensor<3x6x5xi32>)
+func.func @scatter_test(%values_in: tensor<3x7x5xi32>, %indices : tensor<3x6xi32>, %input: tensor<3x6x5xi32>) -> tensor<3x7x5xi32> {
+
+  // CHECK-DAG: [[C_0:%.+]] = arith.constant 0 : index
+  // CHECK-DAG: [[C_1:%.+]] = arith.constant 1 : index
+  // CHECK-DAG: [[C_2:%.+]] = arith.constant 2 : index
+  // CHECK-DAG: [[C_3:%.+]] = arith.constant 3 : index
+  // CHECK-DAG: [[C_5:%.+]] = arith.constant 5 : index
+  // CHECK-DAG: [[C_6:%.+]] = arith.constant 6 : index
+  // CHECK-DAG: [[C_0_0:%.+]] = arith.constant 0 : index
+  // CHECK-DAG: [[C_1_0:%.+]] = arith.constant 1 : index
+  // CHECK: [[RESULT_0:%.+]] = scf.for [[ITER_VAR_0:%.+]] = [[C_0_0]] to [[C_3]] step [[C_1_0]] iter_args([[ITER_ARG_0:%.+]] = [[VALUES_IN]]) -> (tensor<3x7x5xi32>) {
+    // CHECK: [[RESULT_1:%.+]] = scf.for [[ITER_VAR_1:%.+]] = [[C_0_0]] to [[C_6]] step [[C_1_0]] iter_args([[ITER_ARG_1:%.+]] = [[ITER_ARG_0]]) -> (tensor<3x7x5xi32>) {
+      // CHECK-DAG: [[EXTRACTED:%.+]] = tensor.extract [[INDICES]][[[ITER_VAR_0]], [[ITER_VAR_1]]] : tensor<3x6xi32>
+      // CHECK-DAG: [[EXTRACTED_CAST:%.+]] = arith.index_cast [[EXTRACTED]] : i32 to index
+      // CHECK-DAG: [[EXTRACTED_SLICE:%.+]] = tensor.extract_slice [[INPUT]][[[ITER_VAR_0]], [[ITER_VAR_1]], [[C_0_0]]] [[[C_1_0]], [[C_1_0]], [[C_5]]] [[[C_1_0]], [[C_1_0]], [[C_1_0]]] : tensor<3x6x5xi32> to tensor<?x?x?xi32>
+      // CHECK-DAG: [[INSERTED_SLICE:%.+]] = tensor.insert_slice [[EXTRACTED_SLICE]] into [[ITER_ARG_1]][[[ITER_VAR_0]], [[EXTRACTED_CAST]], [[C_0_0]]] [[[C_1_0]], [[C_1_0]], [[C_5]]] [[[C_1_0]], [[C_1_0]], [[C_1_0]]] : tensor<?x?x?xi32> into tensor<3x7x5xi32>
+      // CHECK: scf.yield [[INSERTED_SLICE]] : tensor<3x7x5xi32>
+    // CHECK: }
+    // CHECK: scf.yield [[RESULT_1]] : tensor<3x7x5xi32>
+  // CHECK: }
+	%0 = "tosa.scatter"(%values_in, %indices, %input) : (tensor<3x7x5xi32>, tensor<3x6xi32>, tensor<3x6x5xi32>) -> (tensor<3x7x5xi32>)
+
+  // CHECK: return [[RESULT_0]] : tensor<3x7x5xi32>
+	return %0 : tensor<3x7x5xi32>
+}
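`ScatterOpConverter` builds its two loops (over `N` and `W`) with `scf::buildLoopNest` and passes `valuesIn` as the single loop-carried value. Each iteration's `tensor.insert_slice` result is therefore yielded to the next iteration, as the `iter_args`/`scf.yield` pairs in `@scatter_test` show. The following plain C++ analogue mimics that structure for readers unfamiliar with `iter_args`. `loopNest` is a hypothetical helper, not MLIR's API, and it assumes positive steps.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <utility>
#include <vector>

// Hypothetical plain-C++ analogue of a loop nest with one loop-carried value:
// one loop per (lb, ub, step) triple, the accumulator threaded through every
// level the way iter_args / scf.yield thread the tensor. Assumes steps > 0.
template <typename Acc, typename Body>
Acc loopNest(const std::vector<long> &lbs, const std::vector<long> &ubs,
             const std::vector<long> &steps, Acc init, Body body) {
  assert(lbs.size() == ubs.size() && ubs.size() == steps.size());
  std::vector<long> ivs(lbs.size());

  // Level d iterates its induction variable; each inner nest's result becomes
  // the accumulator of the next iteration (the "yield" of the inner loop).
  std::function<Acc(std::size_t, Acc)> build = [&](std::size_t d,
                                                   Acc acc) -> Acc {
    if (d == lbs.size())
      return body(ivs, std::move(acc));  // innermost body
    for (long iv = lbs[d]; iv < ubs[d]; iv += steps[d]) {
      ivs[d] = iv;
      acc = build(d + 1, std::move(acc));
    }
    return acc;
  };
  return build(0, std::move(init));
}
```

With lower bounds `{0, 0}`, upper bounds `{3, 6}`, and unit steps, this visits the same 3 x 6 iteration space as `@scatter_test`, with the innermost induction variable varying fastest.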


        

