[Mlir-commits] [mlir] 5627564 - [mlir][tosa] Add tosa.concat to subtensor inserts lowering

Rob Suderman llvmlistbot at llvm.org
Thu Mar 18 16:00:30 PDT 2021


Author: Rob Suderman
Date: 2021-03-18T15:59:07-07:00
New Revision: 5627564fe053bd257385157cea43e795e7c48e3f

URL: https://github.com/llvm/llvm-project/commit/5627564fe053bd257385157cea43e795e7c48e3f
DIFF: https://github.com/llvm/llvm-project/commit/5627564fe053bd257385157cea43e795e7c48e3f.diff

LOG: [mlir][tosa] Add tosa.concat to subtensor inserts lowering

Includes lowering for tosa.concat with index computation with subtensor insert
operations. Includes tests along two different axes.

Differential Revision: https://reviews.llvm.org/D98813

Added: 
    

Modified: 
    mlir/lib/Conversion/TosaToLinalg/CMakeLists.txt
    mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
    mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
    mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/TosaToLinalg/CMakeLists.txt b/mlir/lib/Conversion/TosaToLinalg/CMakeLists.txt
index 8a53b9da025b..a44621ec6033 100644
--- a/mlir/lib/Conversion/TosaToLinalg/CMakeLists.txt
+++ b/mlir/lib/Conversion/TosaToLinalg/CMakeLists.txt
@@ -14,6 +14,7 @@ add_mlir_conversion_library(MLIRTosaToLinalg
   MLIRLinalg
   MLIRLinalgUtils
   MLIRMath
+  MLIRMemRef
   MLIRPass
   MLIRTosa
   MLIRTosaTransforms

diff  --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 2fe4aa31e482..dd2725cbd0fa 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -13,6 +13,7 @@
 #include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
 #include "mlir/IR/Matchers.h"
@@ -657,6 +658,78 @@ class ReduceConverter : public OpRewritePattern<SrcOp> {
   }
 };
 
+/// Lowers a statically shaped tosa.concat to a linalg.init_tensor of the
+/// result shape followed by one subtensor_insert per operand, in operand
+/// order. Each operand is inserted at a running offset along `axis` (the sum
+/// of the preceding operands' extents on that axis), at offset 0 on every
+/// other axis, with unit strides.
+struct ConcatOpConversion : public OpConversionPattern<tosa::ConcatOp> {
+  using OpConversionPattern<tosa::ConcatOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(tosa::ConcatOp op, ArrayRef<Value> args,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto resultType = op.getType().dyn_cast<RankedTensorType>();
+    if (!resultType || !resultType.hasStaticShape()) {
+      return rewriter.notifyMatchFailure(op,
+                                         "expected static shaped tensor type");
+    }
+
+    // args[0] is read unconditionally below.
+    if (args.empty())
+      return rewriter.notifyMatchFailure(op, "expected at least one operand");
+
+    // `axis` indexes `sizes` and `offsets`, which hold `rank` entries. The
+    // attribute is an i64, so compare it as unsigned to reject negative and
+    // too-large values alike before any IR is created.
+    int rank = resultType.getRank();
+    uint64_t rawAxis = op.axis();
+    if (rawAxis >= static_cast<uint64_t>(rank))
+      return rewriter.notifyMatchFailure(op, "concat axis out of range");
+    int axis = static_cast<int>(rawAxis);
+
+    Location loc = op.getLoc();
+    Value axisValue =
+        rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(axis));
+    SmallVector<Value, 3> offsets, sizes, strides;
+    sizes.reserve(rank);
+    strides.resize(rank, rewriter.create<ConstantIndexOp>(loc, 1));
+    offsets.resize(rank, rewriter.create<ConstantIndexOp>(loc, 0));
+
+    // Start from the full shape of the first operand; only the `axis` entry
+    // is replaced per operand in the insertion loop.
+    for (int i = 0; i < rank; ++i) {
+      sizes.push_back(rewriter.create<memref::DimOp>(loc, args[0], i));
+    }
+
+    // NOTE(review): this sum is never consumed -- the init tensor below is
+    // built from the static result shape, and the insertion loop overwrites
+    // `sizes[axis]` for every operand. Presumably kept for a dynamic-shape
+    // lowering (TODO confirm). The @concat test in tosa-to-linalg.mlir
+    // matches these ops, so removing them needs a matching test update.
+    Value resultDimSize = sizes[axis];
+    for (auto arg : args.drop_front()) {
+      auto size = rewriter.create<memref::DimOp>(loc, arg, axisValue);
+      resultDimSize = rewriter.create<AddIOp>(loc, resultDimSize, size);
+    }
+    sizes[axis] = resultDimSize;
+
+    Value result = rewriter.create<linalg::InitTensorOp>(
+        loc, resultType.getShape(), resultType.getElementType());
+
+    // Insert each operand at the running offset along `axis`, then advance
+    // that offset by the operand's extent along `axis`.
+    for (auto arg : args) {
+      sizes[axis] = rewriter.create<memref::DimOp>(loc, arg, axisValue);
+      result = rewriter.create<SubTensorInsertOp>(loc, arg, result, offsets,
+                                                  sizes, strides);
+      offsets[axis] = rewriter.create<AddIOp>(loc, offsets[axis], sizes[axis]);
+    }
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
@@ -680,6 +728,6 @@ void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
       IdentityNConverter<tosa::IdentityOp>,
       IdentityNConverter<tosa::IdentityNOp>, ReduceConverter<tosa::ReduceMinOp>,
       ReduceConverter<tosa::ReduceMaxOp>, ReduceConverter<tosa::ReduceSumOp>,
-      ReduceConverter<tosa::ReduceProdOp>, ReshapeOpConverter,
-      TransposeConverter>(context);
+      ReduceConverter<tosa::ReduceProdOp>, ConcatOpConversion,
+      ReshapeOpConverter, TransposeConverter>(context);
 }

diff  --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
index 8ccf83529457..a1bd694f67af 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
@@ -14,6 +14,7 @@
 #include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
 #include "mlir/Dialect/Tosa/Transforms/PassDetail.h"
@@ -31,14 +32,15 @@ struct TosaToLinalgOnTensors
     : public TosaToLinalgOnTensorsBase<TosaToLinalgOnTensors> {
 public:
   void getDependentDialects(DialectRegistry &registry) const override {
-    registry
-        .insert<linalg::LinalgDialect, math::MathDialect, StandardOpsDialect>();
+    registry.insert<linalg::LinalgDialect, math::MathDialect,
+                    memref::MemRefDialect, StandardOpsDialect>();
   }
 
   void runOnFunction() override {
     OwningRewritePatternList patterns;
     ConversionTarget target(getContext());
-    target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect>();
+    target.addLegalDialect<linalg::LinalgDialect, memref::MemRefDialect,
+                           StandardOpsDialect>();
     target.addIllegalDialect<tosa::TosaDialect>();
     target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
 

diff  --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index d1868e7683ce..9b1f6054ee06 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -433,3 +433,45 @@ func @reduce_int(%arg0: tensor<5x4xi32>) -> () {
   %4 = "tosa.reduce_max"(%arg0) {axis = 0 : i64} : (tensor<5x4xi32>) -> tensor<4xi32>
   return
 }
+
+// -----
+
+// CHECK-LABEL: @concat
+func @concat(%arg0: tensor<5x1xf32>, %arg1: tensor<6x1xf32>) -> () {
+  // Axis 0: concatenate 5x1 and 6x1 into 11x1.
+  // CHECK: [[AXIS:%.+]] = constant 0 : index
+  // CHECK: [[STRIDE:%.+]] = constant 1 : index
+  // CHECK: [[OFFSET:%.+]] = constant 0 : index
+  // CHECK: [[IDX0:%.+]] = constant 0 : index
+  // CHECK: [[ARG0_DIM0:%.+]] = memref.dim %arg0, [[IDX0]]
+  // CHECK: [[IDX1:%.+]] = constant 1 : index
+  // CHECK: [[ARG0_DIM1:%.+]] = memref.dim %arg0, [[IDX1]]
+  // CHECK: [[ARG1_AXIS:%.+]] = memref.dim %arg1, [[AXIS]]
+  // CHECK: [[RESULT_AXIS:%.+]] = addi [[ARG0_DIM0]], [[ARG1_AXIS]]
+  // CHECK: [[INIT:%.+]] = linalg.init_tensor [11, 1]
+  // CHECK: [[ARG0_DIM0:%.+]] = memref.dim %arg0, [[AXIS]]
+  // CHECK: [[INSERT0:%.+]] = subtensor_insert %arg0 into [[INIT]]{{\[}}[[OFFSET]], [[OFFSET]]] {{\[}}[[ARG0_DIM0]], [[ARG0_DIM1]]] {{\[}}[[STRIDE]], [[STRIDE]]]
+  // CHECK: [[NEW_OFFSET:%.+]] = addi [[OFFSET]], [[ARG0_DIM0]]
+  // CHECK: [[ARG1_DIM0:%.+]] = memref.dim %arg1, [[AXIS]]
+  // CHECK: [[INSERT1:%.+]] = subtensor_insert %arg1 into [[INSERT0]]{{\[}}[[NEW_OFFSET]], [[OFFSET]]] {{\[}}[[ARG1_DIM0]], [[ARG0_DIM1]]] {{\[}}[[STRIDE]], [[STRIDE]]]
+  %0 = "tosa.concat"(%arg0, %arg1) { axis = 0 : i64} : (tensor<5x1xf32>, tensor<6x1xf32>)  -> (tensor<11x1xf32>)
+
+  // Axis 1, %arg0 as both operands (IN1_* = second operand): 5x1 and 5x1 into 5x2.
+  // CHECK: [[AXIS:%.+]] = constant 1 : index
+  // CHECK: [[STRIDE:%.+]] = constant 1 : index
+  // CHECK: [[OFFSET:%.+]] = constant 0 : index
+  // CHECK: [[IDX0:%.+]] = constant 0 : index
+  // CHECK: [[ARG0_DIM0:%.+]] = memref.dim %arg0, [[IDX0]]
+  // CHECK: [[IDX1:%.+]] = constant 1 : index
+  // CHECK: [[ARG0_DIM1:%.+]] = memref.dim %arg0, [[IDX1]]
+  // CHECK: [[IN1_AXIS:%.+]] = memref.dim %arg0, [[AXIS]]
+  // CHECK: [[RESULT_AXIS:%.+]] = addi [[ARG0_DIM1]], [[IN1_AXIS]]
+  // CHECK: [[INIT:%.+]] = linalg.init_tensor [5, 2]
+  // CHECK: [[ARG0_DIM1:%.+]] = memref.dim %arg0, [[AXIS]]
+  // CHECK: [[INSERT0:%.+]] = subtensor_insert %arg0 into [[INIT]]{{\[}}[[OFFSET]], [[OFFSET]]] {{\[}}[[ARG0_DIM0]], [[ARG0_DIM1]]] {{\[}}[[STRIDE]], [[STRIDE]]]
+  // CHECK: [[NEW_OFFSET:%.+]] = addi [[OFFSET]], [[ARG0_DIM1]]
+  // CHECK: [[IN1_DIM1:%.+]] = memref.dim %arg0, [[AXIS]]
+  // CHECK: [[INSERT1:%.+]] = subtensor_insert %arg0 into [[INSERT0]]{{\[}}[[OFFSET]], [[NEW_OFFSET]]] {{\[}}[[ARG0_DIM0]], [[IN1_DIM1]]] {{\[}}[[STRIDE]], [[STRIDE]]]
+  %1 = "tosa.concat"(%arg0, %arg0) { axis = 1 : i64} : (tensor<5x1xf32>, tensor<5x1xf32>)  -> (tensor<5x2xf32>)
+  return
+}


        


More information about the Mlir-commits mailing list