[Mlir-commits] [mlir] 9047825 - [mlir][tosa] Tosa reverse to linalg supporting dynamic shapes

Rob Suderman llvmlistbot at llvm.org
Thu Aug 26 13:26:07 PDT 2021


Author: Rob Suderman
Date: 2021-08-26T13:23:59-07:00
New Revision: 90478251c736ce335fe8d45e46a09d9bec889583

URL: https://github.com/llvm/llvm-project/commit/90478251c736ce335fe8d45e46a09d9bec889583
DIFF: https://github.com/llvm/llvm-project/commit/90478251c736ce335fe8d45e46a09d9bec889583.diff

LOG: [mlir][tosa] Tosa reverse to linalg supporting dynamic shapes

Needed to switch to tensor.extract to support tosa.reverse on dynamic shapes.

Reviewed By: NatashaKnk

Differential Revision: https://reviews.llvm.org/D108744

Added: 
    

Modified: 
    mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
    mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index beab93bc8bd0b..74239fed3f010 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -2043,40 +2043,48 @@ class ReverseConverter : public OpRewritePattern<tosa::ReverseOp> {
     Value input = op.input();
     auto inputTy = input.getType().template cast<ShapedType>();
     auto resultTy = op.getType().template cast<ShapedType>();
-    auto rank = resultTy.getRank();
     auto axis = op.axis();
 
-    if (!inputTy.hasStaticShape())
-      return rewriter.notifyMatchFailure(
-          op, "No initial value found for reduction operation");
+    SmallVector<Value> dynDims;
+    for (int i = 0; i < inputTy.getRank(); i++) {
+      if (inputTy.isDynamicDim(i)) {
+        dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
+      }
+    }
+
+    Value axisDimSize = rewriter.create<tensor::DimOp>(loc, input, axis);
 
     // First fill the output buffer with the init value.
     auto initTensor = rewriter
                           .create<linalg::InitTensorOp>(
-                              loc, ArrayRef<Value>({}), inputTy.getShape(),
-                              inputTy.getElementType())
+                              loc, ArrayRef<Value>({dynDims}),
+                              inputTy.getShape(), inputTy.getElementType())
                           .result();
-
-    SmallVector<AffineExpr, 2> inputExprs;
-    inputExprs.resize(resultTy.getRank());
-
-    for (int i = 0; i < rank; i++)
-      inputExprs[i] = rewriter.getAffineDimExpr(i);
-
-    inputExprs[axis] =
-        rewriter.getAffineConstantExpr(inputTy.getDimSize(axis) - 1) -
-        inputExprs[axis];
-
     SmallVector<AffineMap, 2> affineMaps = {
-        AffineMap::get(resultTy.getRank(), /*symbolCount=*/0, inputExprs,
-                       rewriter.getContext()),
         rewriter.getMultiDimIdentityMap(resultTy.getRank())};
 
     rewriter.replaceOpWithNewOp<linalg::GenericOp>(
-        op, resultTy, op.input(), ValueRange{initTensor}, affineMaps,
+        op, resultTy, ArrayRef<Value>({}), ValueRange{initTensor}, affineMaps,
         getNParallelLoopsAttrs(resultTy.getRank()),
         [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
-          nestedBuilder.create<linalg::YieldOp>(op.getLoc(), *args.begin());
+          llvm::SmallVector<Value> indices;
+          for (unsigned int i = 0; i < inputTy.getRank(); i++) {
+            auto index =
+                rewriter.create<linalg::IndexOp>(nestedLoc, i).getResult();
+            if (i == axis) {
+              auto one = rewriter.create<ConstantIndexOp>(nestedLoc, 1);
+              auto sizeMinusOne =
+                  rewriter.create<SubIOp>(nestedLoc, axisDimSize, one);
+              index = rewriter.create<SubIOp>(nestedLoc, sizeMinusOne, index);
+            }
+
+            indices.push_back(index);
+          }
+
+          auto extract = nestedBuilder.create<tensor::ExtractOp>(
+              nestedLoc, input, indices);
+          nestedBuilder.create<linalg::YieldOp>(op.getLoc(),
+                                                extract.getResult());
         });
     return success();
   }

diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 8a3cf62405386..50e4c78bb2d45 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -881,28 +881,62 @@ func @rescaleUnnecessaryDoubleRound(%arg0 : tensor<2xi8>) -> (tensor<2xi8>) {
 
 // -----
 
-// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (-d0 + 4, d1)>
-// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0, d1)>
-// CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0, -d1 + 3)>
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
 
 // CHECK-LABEL: @reverse
 func @reverse(%arg0: tensor<5x4xi32>) -> () {
-  // CHECK: [[INIT:%.+]] = linalg.init_tensor [5, 4]
-  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<5x4xi32>) outs([[INIT]] : tensor<5x4xi32>) {
-  // CHECK: ^bb0(%arg1: i32, %arg2: i32):
-  // CHECK:   linalg.yield %arg1 : i32
+  // CHECK: %[[C0:.+]] = constant 0
+  // CHECK: %[[RDIM:.+]] = tensor.dim %arg0, %[[C0]]
+  // CHECK: %[[INIT:.+]] = linalg.init_tensor [5, 4]
+  // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]]], iterator_types = ["parallel", "parallel"]} outs(%[[INIT]] : tensor<5x4xi32>)
+  // CHECK-DAG:   %[[I0:.+]] = linalg.index 0
+  // CHECK-DAG:   %[[I1:.+]] = linalg.index 1
+  // CHECK-DAG:   %[[SUB1:.+]] = constant 1
+  // CHECK-DAG:   %[[RDIM_MINUS_C1:.+]] = subi %[[RDIM]], %[[SUB1]]
+  // CHECK-DAG:   %[[READ_DIM:.+]] = subi %[[RDIM_MINUS_C1]], %[[I0]]
+  // CHECK-DAG:   %[[EXTRACT:.+]] = tensor.extract %arg0[%[[READ_DIM]], %[[I1]]] : tensor<5x4xi32>
+  // CHECK:   linalg.yield %[[EXTRACT]]
   %0 = "tosa.reverse"(%arg0) {axis = 0 : i64} : (tensor<5x4xi32>) -> tensor<5x4xi32>
 
-  // CHECK: [[INIT:%.+]] = linalg.init_tensor [5, 4]
-  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<5x4xi32>) outs([[INIT]] : tensor<5x4xi32>) {
-  // CHECK: ^bb0(%arg1: i32, %arg2: i32):
-  // CHECK:   linalg.yield %arg1 : i32
+  // CHECK: %[[C1:.+]] = constant 1
+  // CHECK: %[[RDIM:.+]] = tensor.dim %arg0, %[[C1]]
+  // CHECK: %[[INIT:.+]] = linalg.init_tensor [5, 4]
+  // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]]], iterator_types = ["parallel", "parallel"]} outs(%[[INIT]] : tensor<5x4xi32>)
+  // CHECK-DAG:   %[[I0:.+]] = linalg.index 0
+  // CHECK-DAG:   %[[I1:.+]] = linalg.index 1
+  // CHECK-DAG:   %[[SUB1:.+]] = constant 1
+  // CHECK-DAG:   %[[RDIM_MINUS_C1:.+]] = subi %[[RDIM]], %[[SUB1]]
+  // CHECK-DAG:   %[[READ_DIM:.+]] = subi %[[RDIM_MINUS_C1]], %[[I1]]
+  // CHECK-DAG:   %[[EXTRACT:.+]] = tensor.extract %arg0[%[[I0]], %[[READ_DIM]]] : tensor<5x4xi32>
+  // CHECK:   linalg.yield %[[EXTRACT]]
   %1 = "tosa.reverse"(%arg0) {axis = 1 : i64} : (tensor<5x4xi32>) -> tensor<5x4xi32>
   return
 }
 
 // -----
 
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
+
+// CHECK-LABEL: @reverse_dyn
+func @reverse_dyn(%arg0: tensor<?xi32>) -> () {
+  // CHECK: %[[C0_1:.+]] = constant 0
+  // CHECK: %[[D0_1:.+]] = tensor.dim %arg0, %[[C0_1]]
+  // CHECK: %[[C0_2:.+]] = constant 0
+  // CHECK: %[[D0_2:.+]] = tensor.dim %arg0, %[[C0_2]]
+  // CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[D0_1]]]
+  // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]]], iterator_types = ["parallel"]} outs(%[[INIT]] : tensor<?xi32>)
+  // CHECK-DAG:   %[[I0:.+]] = linalg.index 0
+  // CHECK-DAG:   %[[SUB1:.+]] = constant 1
+  // CHECK-DAG:   %[[RDIM_MINUS_C1:.+]] = subi %[[D0_2]], %[[SUB1]]
+  // CHECK-DAG:   %[[READ_DIM:.+]] = subi %[[RDIM_MINUS_C1]], %[[I0]]
+  // CHECK-DAG:   %[[EXTRACT:.+]] = tensor.extract %arg0[%[[READ_DIM]]] : tensor<?xi32>
+  // CHECK:   linalg.yield %[[EXTRACT]]
+  %0 = "tosa.reverse"(%arg0) {axis = 0 : i64} : (tensor<?xi32>) -> tensor<?xi32>
+  return
+}
+
+// -----
+
 // CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d3)>
 // CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
 


        


More information about the Mlir-commits mailing list