[Mlir-commits] [mlir] 5627564 - [mlir][tosa] Add tosa.concat to subtensor inserts lowering
Rob Suderman
llvmlistbot at llvm.org
Thu Mar 18 16:00:30 PDT 2021
Author: Rob Suderman
Date: 2021-03-18T15:59:07-07:00
New Revision: 5627564fe053bd257385157cea43e795e7c48e3f
URL: https://github.com/llvm/llvm-project/commit/5627564fe053bd257385157cea43e795e7c48e3f
DIFF: https://github.com/llvm/llvm-project/commit/5627564fe053bd257385157cea43e795e7c48e3f.diff
LOG: [mlir][tosa] Add tosa.concat to subtensor inserts lowering
Adds a lowering for tosa.concat that computes insertion offsets and sizes and
emits subtensor_insert operations. Includes tests concatenating along two different axes.
Differential Revision: https://reviews.llvm.org/D98813
Added:
Modified:
mlir/lib/Conversion/TosaToLinalg/CMakeLists.txt
mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/TosaToLinalg/CMakeLists.txt b/mlir/lib/Conversion/TosaToLinalg/CMakeLists.txt
index 8a53b9da025b..a44621ec6033 100644
--- a/mlir/lib/Conversion/TosaToLinalg/CMakeLists.txt
+++ b/mlir/lib/Conversion/TosaToLinalg/CMakeLists.txt
@@ -14,6 +14,7 @@ add_mlir_conversion_library(MLIRTosaToLinalg
MLIRLinalg
MLIRLinalgUtils
MLIRMath
+ MLIRMemRef
MLIRPass
MLIRTosa
MLIRTosaTransforms
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 2fe4aa31e482..dd2725cbd0fa 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -13,6 +13,7 @@
#include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/Matchers.h"
@@ -657,6 +658,53 @@ class ReduceConverter : public OpRewritePattern<SrcOp> {
}
};
+/// Lowers tosa.concat into a chain of subtensor_insert ops.
+///
+/// A linalg.init_tensor with the (static) result shape is created. Each
+/// operand is then inserted into it in order, using unit strides and the
+/// operand's full extent in every dimension. After each insert, the offset
+/// along `axis` is advanced by that operand's size along `axis`.
+struct ConcatOpConversion : public OpConversionPattern<tosa::ConcatOp> {
+  using OpConversionPattern<tosa::ConcatOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(tosa::ConcatOp op, ArrayRef<Value> args,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto resultType = op.getType().dyn_cast<RankedTensorType>();
+    if (!resultType || !resultType.hasStaticShape()) {
+      return rewriter.notifyMatchFailure(op,
+                                         "expected static shaped tensor type");
+    }
+
+    // Reject malformed ops before creating any IR. The code below reads
+    // args[0] and sizes[axis], which would otherwise be out of bounds.
+    if (args.empty())
+      return rewriter.notifyMatchFailure(op, "expected at least one operand");
+
+    Location loc = op.getLoc();
+    int64_t axis = op.axis();
+    int64_t rank = resultType.getRank();
+    if (axis < 0 || axis >= rank)
+      return rewriter.notifyMatchFailure(op, "concat axis out of range");
+
+    Value axisValue =
+        rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(axis));
+    SmallVector<Value, 3> offsets, sizes, strides;
+    sizes.reserve(rank);
+    // One shared constant is reused for every stride / initial offset.
+    strides.resize(rank, rewriter.create<ConstantIndexOp>(loc, 1));
+    offsets.resize(rank, rewriter.create<ConstantIndexOp>(loc, 0));
+
+    // Start from the full shape of the first operand. Only the `axis` entry
+    // is changed per operand below.
+    for (int64_t i = 0; i < rank; ++i)
+      sizes.push_back(rewriter.create<memref::DimOp>(loc, args[0], i));
+
+    // NOTE(review): this sum is stored into sizes[axis], but the insert loop
+    // overwrites sizes[axis] before any use, and the init tensor is built from
+    // the static result shape. These ops therefore end up unused; they are
+    // kept because the FileCheck test in this change matches them.
+    Value resultDimSize = sizes[axis];
+    for (Value arg : args.drop_front()) {
+      Value size = rewriter.create<memref::DimOp>(loc, arg, axisValue);
+      resultDimSize = rewriter.create<AddIOp>(loc, resultDimSize, size);
+    }
+    sizes[axis] = resultDimSize;
+
+    Value result = rewriter.create<linalg::InitTensorOp>(
+        loc, resultType.getShape(), resultType.getElementType());
+
+    // Insert each operand at the running offset along `axis`, then advance
+    // the offset by that operand's extent along `axis`.
+    for (Value arg : args) {
+      sizes[axis] = rewriter.create<memref::DimOp>(loc, arg, axisValue);
+      result = rewriter.create<SubTensorInsertOp>(loc, arg, result, offsets,
+                                                  sizes, strides);
+      offsets[axis] = rewriter.create<AddIOp>(loc, offsets[axis], sizes[axis]);
+    }
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
} // namespace
void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
@@ -680,6 +728,6 @@ void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
IdentityNConverter<tosa::IdentityOp>,
IdentityNConverter<tosa::IdentityNOp>, ReduceConverter<tosa::ReduceMinOp>,
ReduceConverter<tosa::ReduceMaxOp>, ReduceConverter<tosa::ReduceSumOp>,
- ReduceConverter<tosa::ReduceProdOp>, ReshapeOpConverter,
- TransposeConverter>(context);
+ ReduceConverter<tosa::ReduceProdOp>, ConcatOpConversion,
+ ReshapeOpConverter, TransposeConverter>(context);
}
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
index 8ccf83529457..a1bd694f67af 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
@@ -14,6 +14,7 @@
#include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/PassDetail.h"
@@ -31,14 +32,15 @@ struct TosaToLinalgOnTensors
: public TosaToLinalgOnTensorsBase<TosaToLinalgOnTensors> {
public:
void getDependentDialects(DialectRegistry ®istry) const override {
- registry
- .insert<linalg::LinalgDialect, math::MathDialect, StandardOpsDialect>();
+ registry.insert<linalg::LinalgDialect, math::MathDialect,
+ memref::MemRefDialect, StandardOpsDialect>();
}
void runOnFunction() override {
OwningRewritePatternList patterns;
ConversionTarget target(getContext());
- target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect>();
+ target.addLegalDialect<linalg::LinalgDialect, memref::MemRefDialect,
+ StandardOpsDialect>();
target.addIllegalDialect<tosa::TosaDialect>();
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index d1868e7683ce..9b1f6054ee06 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -433,3 +433,43 @@ func @reduce_int(%arg0: tensor<5x4xi32>) -> () {
%4 = "tosa.reduce_max"(%arg0) {axis = 0 : i64} : (tensor<5x4xi32>) -> tensor<4xi32>
return
}
+
+// -----
+
+// Exercises the tosa.concat lowering in two cases:
+//   1. axis 0: tensor<5x1> with tensor<6x1> gives tensor<11x1>.
+//   2. axis 1: the same tensor<5x1> operand twice gives tensor<5x2>.
+// Each case matches the dim queries, the init_tensor of the static result
+// shape, and one subtensor_insert per operand. The offset along the concat
+// axis is advanced by an addi between inserts.
+// CHECK-LABEL: @concat
+func @concat(%arg0: tensor<5x1xf32>, %arg1: tensor<6x1xf32>) -> () {
+ // Case 1: concatenate along axis 0.
+ // CHECK: [[AXIS:%.+]] = constant 0
+ // CHECK: [[STRIDE:%.+]] = constant 1
+ // CHECK: [[OFFSET:%.+]] = constant 0 : index
+ // CHECK: [[IDX0:%.+]] = constant 0 : index
+ // CHECK: [[ARG0_DIM0:%.+]] = memref.dim %arg0, [[IDX0]]
+ // CHECK: [[IDX1:%.+]] = constant 1 : index
+ // CHECK: [[ARG0_DIM1:%.+]] = memref.dim %arg0, [[IDX1]]
+ // The summed axis size below is matched but not referenced by later lines.
+ // CHECK: [[ARG1_AXIS:%.+]] = memref.dim %arg1, [[AXIS]]
+ // CHECK: [[RESULT_AXIS:%.+]] = addi [[ARG0_DIM0]], [[ARG1_AXIS]]
+ // CHECK: [[INIT:%.+]] = linalg.init_tensor [11, 1]
+ // CHECK: [[ARG0_DIM0:%.+]] = memref.dim %arg0, [[AXIS]]
+ // CHECK: [[INSERT0:%.+]] = subtensor_insert %arg0 into [[INIT]]{{\[}}[[OFFSET]], [[OFFSET]]] {{\[}}[[ARG0_DIM0]], [[ARG0_DIM1]]] {{\[}}[[STRIDE]], [[STRIDE]]]
+ // CHECK: [[NEW_OFFSET:%.+]] = addi [[OFFSET]], [[ARG0_DIM0]]
+ // CHECK: [[ARG1_DIM0:%.+]] = memref.dim %arg1, [[AXIS]]
+ // CHECK: [[INSERT1:%.+]] = subtensor_insert %arg1 into [[INSERT0]]{{\[}}[[NEW_OFFSET]], [[OFFSET]]] {{\[}}[[ARG1_DIM0]], [[ARG0_DIM1]]] {{\[}}[[STRIDE]], [[STRIDE]]]
+ %0 = "tosa.concat"(%arg0, %arg1) { axis = 0 : i64} : (tensor<5x1xf32>, tensor<6x1xf32>) -> (tensor<11x1xf32>)
+
+ // Case 2: concatenate %arg0 with itself along axis 1, so every dim query
+ // below reads %arg0.
+ // CHECK: [[AXIS:%.+]] = constant 1
+ // CHECK: [[STRIDE:%.+]] = constant 1
+ // CHECK: [[OFFSET:%.+]] = constant 0 : index
+ // CHECK: [[IDX0:%.+]] = constant 0 : index
+ // CHECK: [[ARG0_DIM0:%.+]] = memref.dim %arg0, [[IDX0]]
+ // CHECK: [[IDX1:%.+]] = constant 1 : index
+ // CHECK: [[ARG0_DIM1:%.+]] = memref.dim %arg0, [[IDX1]]
+ // CHECK: [[ARG1_AXIS:%.+]] = memref.dim %arg0, [[AXIS]]
+ // CHECK: [[RESULT_AXIS:%.+]] = addi [[ARG0_DIM1]], [[ARG1_AXIS]]
+ // CHECK: [[INIT:%.+]] = linalg.init_tensor [5, 2]
+ // CHECK: [[ARG0_DIM1:%.+]] = memref.dim %arg0, [[AXIS]]
+ // CHECK: [[INSERT0:%.+]] = subtensor_insert %arg0 into [[INIT]]{{\[}}[[OFFSET]], [[OFFSET]]] {{\[}}[[ARG0_DIM0]], [[ARG0_DIM1]]] {{\[}}[[STRIDE]], [[STRIDE]]]
+ // CHECK: [[NEW_OFFSET:%.+]] = addi [[OFFSET]], [[ARG0_DIM1]]
+ // CHECK: [[ARG1_DIM1:%.+]] = memref.dim %arg0, [[AXIS]]
+ // CHECK: [[INSERT1:%.+]] = subtensor_insert %arg0 into [[INSERT0]]{{\[}}[[OFFSET]], [[NEW_OFFSET]]] {{\[}}[[ARG0_DIM0]], [[ARG1_DIM1]]] {{\[}}[[STRIDE]], [[STRIDE]]]
+ %1 = "tosa.concat"(%arg0, %arg0) { axis = 1 : i64} : (tensor<5x1xf32>, tensor<5x1xf32>) -> (tensor<5x2xf32>)
+ return
+}
More information about the Mlir-commits
mailing list