[Mlir-commits] [mlir] e377520 - [mlir] Move tosa.concat lowering from TosaToLinalg to TosaToTensor
Maya Amrami
llvmlistbot at llvm.org
Tue Mar 14 02:24:08 PDT 2023
Author: Maya Amrami
Date: 2023-03-14T11:24:01+02:00
New Revision: e377520a47e6ced171fe1a4a39b91297326a1817
URL: https://github.com/llvm/llvm-project/commit/e377520a47e6ced171fe1a4a39b91297326a1817
DIFF: https://github.com/llvm/llvm-project/commit/e377520a47e6ced171fe1a4a39b91297326a1817.diff
LOG: [mlir] Move tosa.concat lowering from TosaToLinalg to TosaToTensor
tosa.concat is lowered to tensor.insert_slice, so its lowering belongs in
TosaToTensor rather than in TosaToLinalg.
Reviewed By: rsuderman
Differential Revision: https://reviews.llvm.org/D145952
Added:
Modified:
mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 82cadf2a07daa..f6ca01949632a 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1556,68 +1556,6 @@ class ReduceConverter : public OpRewritePattern<SrcOp> {
}
};
-struct ConcatConverter : public OpConversionPattern<tosa::ConcatOp> {
- using OpConversionPattern<tosa::ConcatOp>::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(tosa::ConcatOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- auto inputType = op.getOperand(0).getType().template cast<ShapedType>();
- auto resultType = op.getType().dyn_cast<RankedTensorType>();
-
- Location loc = op.getLoc();
- int axis = op.getAxis();
- Value axisValue = rewriter.createOrFold<arith::ConstantOp>(
- loc, rewriter.getIndexAttr(axis));
- int rank = resultType.getRank();
- SmallVector<Value, 3> offsets, sizes, strides;
- sizes.reserve(rank);
- strides.resize(rank, rewriter.create<arith::ConstantIndexOp>(loc, 1));
- offsets.resize(rank, rewriter.create<arith::ConstantIndexOp>(loc, 0));
-
- SmallVector<Value> dynDims;
- for (int i = 0; i < rank; ++i) {
- sizes.push_back(rewriter.createOrFold<tensor::DimOp>(
- loc, adaptor.getOperands()[0], i));
- if (inputType.isDynamicDim(i)) {
- dynDims.push_back(
- rewriter.create<tensor::DimOp>(loc, op.getOperand(0), i));
- }
- }
-
- Value resultDimSize = sizes[axis];
- for (auto arg : adaptor.getOperands().drop_front()) {
- auto size = rewriter.createOrFold<tensor::DimOp>(loc, arg, axisValue);
- resultDimSize =
- rewriter.createOrFold<arith::AddIOp>(loc, resultDimSize, size);
- }
- sizes[axis] = resultDimSize;
-
- Value emptyTensor = rewriter.create<tensor::EmptyOp>(
- loc, resultType.getShape(), resultType.getElementType(), dynDims);
-
- auto toOpFoldResult = [](Value v) -> OpFoldResult {
- auto op = v.getDefiningOp<arith::ConstantIndexOp>();
- if (!op)
- return v;
- return op.getValue();
- };
- Value result = emptyTensor;
- for (auto arg : adaptor.getOperands()) {
- sizes[axis] = rewriter.createOrFold<tensor::DimOp>(loc, arg, axisValue);
- result = rewriter.createOrFold<tensor::InsertSliceOp>(
- loc, arg, result,
- llvm::to_vector(llvm::map_range(offsets, toOpFoldResult)),
- llvm::to_vector(llvm::map_range(sizes, toOpFoldResult)),
- llvm::to_vector(llvm::map_range(strides, toOpFoldResult)));
- offsets[axis] =
- rewriter.createOrFold<arith::AddIOp>(loc, offsets[axis], sizes[axis]);
- }
- rewriter.replaceOp(op, result);
- return success();
- }
-};
-
class ReverseConverter : public OpRewritePattern<tosa::ReverseOp> {
public:
using OpRewritePattern<tosa::ReverseOp>::OpRewritePattern;
@@ -2110,7 +2048,6 @@ void mlir::tosa::populateTosaToLinalgConversionPatterns(
ReduceConverter<tosa::ReduceSumOp>,
ReduceConverter<tosa::ReduceProdOp>,
ArgMaxConverter,
- ConcatConverter,
GatherConverter,
RescaleConverter,
ReverseConverter,
diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
index b276c773d55ba..1ef31df2defd7 100644
--- a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
+++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
@@ -349,11 +349,74 @@ class PadConverter : public OpRewritePattern<tosa::PadOp> {
}
};
+struct ConcatConverter : public OpConversionPattern<tosa::ConcatOp> {
+ using OpConversionPattern<tosa::ConcatOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(tosa::ConcatOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto inputType = op.getOperand(0).getType().template cast<ShapedType>();
+ auto resultType = op.getType().dyn_cast<RankedTensorType>();
+
+ Location loc = op.getLoc();
+ int axis = op.getAxis();
+ Value axisValue = rewriter.createOrFold<arith::ConstantOp>(
+ loc, rewriter.getIndexAttr(axis));
+ int rank = resultType.getRank();
+ SmallVector<Value, 3> offsets, sizes, strides;
+ sizes.reserve(rank);
+ strides.resize(rank, rewriter.create<arith::ConstantIndexOp>(loc, 1));
+ offsets.resize(rank, rewriter.create<arith::ConstantIndexOp>(loc, 0));
+
+ SmallVector<Value> dynDims;
+ for (int i = 0; i < rank; ++i) {
+ sizes.push_back(rewriter.createOrFold<tensor::DimOp>(
+ loc, adaptor.getOperands()[0], i));
+ if (inputType.isDynamicDim(i)) {
+ dynDims.push_back(
+ rewriter.create<tensor::DimOp>(loc, op.getOperand(0), i));
+ }
+ }
+
+ Value resultDimSize = sizes[axis];
+ for (auto arg : adaptor.getOperands().drop_front()) {
+ auto size = rewriter.createOrFold<tensor::DimOp>(loc, arg, axisValue);
+ resultDimSize =
+ rewriter.createOrFold<arith::AddIOp>(loc, resultDimSize, size);
+ }
+ sizes[axis] = resultDimSize;
+
+ Value emptyTensor = rewriter.create<tensor::EmptyOp>(
+ loc, resultType.getShape(), resultType.getElementType(), dynDims);
+
+ auto toOpFoldResult = [](Value v) -> OpFoldResult {
+ auto op = v.getDefiningOp<arith::ConstantIndexOp>();
+ if (!op)
+ return v;
+ return op.getValue();
+ };
+ Value result = emptyTensor;
+ for (auto arg : adaptor.getOperands()) {
+ sizes[axis] = rewriter.createOrFold<tensor::DimOp>(loc, arg, axisValue);
+ result = rewriter.createOrFold<tensor::InsertSliceOp>(
+ loc, arg, result,
+ llvm::to_vector(llvm::map_range(offsets, toOpFoldResult)),
+ llvm::to_vector(llvm::map_range(sizes, toOpFoldResult)),
+ llvm::to_vector(llvm::map_range(strides, toOpFoldResult)));
+ offsets[axis] =
+ rewriter.createOrFold<arith::AddIOp>(loc, offsets[axis], sizes[axis]);
+ }
+ rewriter.replaceOp(op, result);
+ return success();
+ }
+};
+
} // namespace
void mlir::tosa::populateTosaToTensorConversionPatterns(
RewritePatternSet *patterns) {
- patterns->add<SliceConverter, PadConverter>(patterns->getContext());
+ patterns->add<SliceConverter, PadConverter, ConcatConverter>(
+ patterns->getContext());
patterns->add<ReshapeConverterCollapse>(patterns->getContext(),
/*benefit=*/100);
patterns->add<ReshapeConverterExpand>(patterns->getContext(),
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 138a30fa837d4..427fe6b2f16b2 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -823,79 +823,6 @@ func.func @reduce_bool(%arg0: tensor<5x4xi1>) -> () {
return
}
-// -----
-
-// CHECK-LABEL: @concat
-// CHECK-SAME: %[[ARG0:.+]]: tensor<5x1xf32>
-// CHECK-SAME: %[[ARG1:.+]]: tensor<6x1xf32>
-func.func @concat(%arg0: tensor<5x1xf32>, %arg1: tensor<6x1xf32>) -> () {
- // CHECK: [[AXIS:%.+]] = arith.constant 0
- // CHECK: [[STRIDE:%.+]] = arith.constant 1
- // CHECK: [[OFFSET:%.+]] = arith.constant 0 : index
- // CHECK: [[IDX0:%.+]] = arith.constant 0 : index
- // CHECK: [[IDX1:%.+]] = arith.constant 1 : index
- // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<11x1xf32>
- // CHECK: [[INSERT0:%.+]] = tensor.insert_slice %[[ARG0]] into [[INIT]][0, 0] [5, 1] [1, 1]
- // CHECK: [[INSERT1:%.+]] = tensor.insert_slice %[[ARG1]] into [[INSERT0]][5, 0] [6, 1] [1, 1]
- %0 = "tosa.concat"(%arg0, %arg1) { axis = 0 : i64} : (tensor<5x1xf32>, tensor<6x1xf32>) -> (tensor<11x1xf32>)
-
- // CHECK: [[AXIS:%.+]] = arith.constant 1
- // CHECK: [[STRIDE:%.+]] = arith.constant 1
- // CHECK: [[OFFSET:%.+]] = arith.constant 0 : index
- // CHECK: [[IDX0:%.+]] = arith.constant 0 : index
- // CHECK: [[IDX1:%.+]] = arith.constant 1 : index
- // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<5x2xf32>
- // CHECK: [[INSERT0:%.+]] = tensor.insert_slice %[[ARG0]] into [[INIT]][0, 0] [5, 1] [1, 1]
- // CHECK: [[INSERT1:%.+]] = tensor.insert_slice %[[ARG0]] into [[INSERT0]][0, 1] [5, 1] [1, 1]
- %1 = "tosa.concat"(%arg0, %arg0) { axis = 1 : i64} : (tensor<5x1xf32>, tensor<5x1xf32>) -> (tensor<5x2xf32>)
- return
-}
-
-// -----
-
-// CHECK-LABEL: @concat_non_axis_dyn
-// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
-// CHECK-SAME: %[[ARG1:[0-9a-zA-Z_]*]]
-func.func @concat_non_axis_dyn(%arg0: tensor<5x?xf32>, %arg1: tensor<6x?xf32>) -> () {
- // CHECK: %[[AXIS:.+]] = arith.constant 0
- // CHECK: %[[STRIDE:.+]] = arith.constant 1
- // CHECK: %[[OFFSET:.+]] = arith.constant 0 : index
- // CHECK: %[[IDX0:.+]] = arith.constant 0 : index
- // CHECK: %[[IDX1:.+]] = arith.constant 1 : index
- // CHECK: %[[SIZE:.+]] = tensor.dim %[[ARG0]], %[[IDX1]]
- // CHECK: %[[IDX1_2:.+]] = arith.constant 1 : index
- // CHECK: %[[DYN:.+]] = tensor.dim %[[ARG0]], %[[IDX1_2]]
- // CHECK: %[[INIT:.+]] = tensor.empty(%[[DYN]]) : tensor<11x?xf32>
- // CHECK: %[[INSERT0:.+]] = tensor.insert_slice %[[ARG0]] into %[[INIT]][0, 0] [5, %[[SIZE]]] [1, 1]
- // CHECK: %[[INSERT1:.+]] = tensor.insert_slice %[[ARG1]] into %[[INSERT0]][5, 0] [6, %[[SIZE]]] [1, 1]
- %0 = "tosa.concat"(%arg0, %arg1) { axis = 0 : i64} : (tensor<5x?xf32>, tensor<6x?xf32>) -> (tensor<11x?xf32>)
- return
-}
-
-// -----
-
-// CHECK-LABEL: @concat_axis_dyn
-// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
-// CHECK-SAME: %[[ARG1:[0-9a-zA-Z_]*]]:
-func.func @concat_axis_dyn(%arg0: tensor<?x3xf32>, %arg1: tensor<?x3xf32>) -> () {
- // CHECK: %[[AXIS:.+]] = arith.constant 0
- // CHECK: %[[STRIDE:.+]] = arith.constant 1
- // CHECK: %[[OFFSET:.+]] = arith.constant 0 : index
- // CHECK: %[[IDX0:.+]] = arith.constant 0 : index
- // CHECK: %[[SIZE:.+]] = tensor.dim %[[ARG0]], %[[IDX0]]
- // CHECK: %[[IDX0_2:.+]] = arith.constant 0 : index
- // CHECK: %[[DYN:.+]] = tensor.dim %[[ARG0]], %[[IDX0_2]]
- // CHECK: %[[IDX1:.+]] = arith.constant 1 : index
- // CHECK: %[[INIT:.+]] = tensor.empty(%[[DYN]]) : tensor<?x3xf32>
- // CHECK: %[[DYN1:.+]] = tensor.dim %[[ARG0]], %[[AXIS]]
- // CHECK: %[[INSERT0:.+]] = tensor.insert_slice %[[ARG0]] into %[[INIT]][0, 0] [%[[DYN1]], 3] [1, 1]
- // CHECK: %[[SUM:.+]] = arith.addi %[[OFFSET]], %[[DYN1]]
- // CHECK: %[[DYN2:.+]] = tensor.dim %[[ARG1]], %[[AXIS]]
- // CHECK: %[[INSERT1:.+]] = tensor.insert_slice %[[ARG1]] into %[[INSERT0]][%[[SUM]], 0] [%[[DYN2]], 3] [1, 1]
- %0 = "tosa.concat"(%arg0, %arg1) { axis = 0 : i64} : (tensor<?x3xf32>, tensor<?x3xf32>) -> (tensor<?x3xf32>)
- return
-}
-
// -----
// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
diff --git a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
index 34084ebf5d3ce..d96c7848ed5b1 100644
--- a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
+++ b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
@@ -195,3 +195,76 @@ func.func @pad_dyn_padding(%arg0 : tensor<1x2xf32>) -> (tensor<?x9xf32>) {
%1 = "tosa.pad"(%arg0, %0) : (tensor<1x2xf32>, tensor<2x2xi32>) -> (tensor<?x9xf32>)
return %1 : tensor<?x9xf32>
}
+
+// -----
+
+// CHECK-LABEL: @concat
+// CHECK-SAME: %[[ARG0:.+]]: tensor<5x1xf32>
+// CHECK-SAME: %[[ARG1:.+]]: tensor<6x1xf32>
+func.func @concat(%arg0: tensor<5x1xf32>, %arg1: tensor<6x1xf32>) -> () {
+ // CHECK: [[AXIS:%.+]] = arith.constant 0
+ // CHECK: [[STRIDE:%.+]] = arith.constant 1
+ // CHECK: [[OFFSET:%.+]] = arith.constant 0 : index
+ // CHECK: [[IDX0:%.+]] = arith.constant 0 : index
+ // CHECK: [[IDX1:%.+]] = arith.constant 1 : index
+ // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<11x1xf32>
+ // CHECK: [[INSERT0:%.+]] = tensor.insert_slice %[[ARG0]] into [[INIT]][0, 0] [5, 1] [1, 1]
+ // CHECK: [[INSERT1:%.+]] = tensor.insert_slice %[[ARG1]] into [[INSERT0]][5, 0] [6, 1] [1, 1]
+ %0 = "tosa.concat"(%arg0, %arg1) { axis = 0 : i64} : (tensor<5x1xf32>, tensor<6x1xf32>) -> (tensor<11x1xf32>)
+
+ // CHECK: [[AXIS:%.+]] = arith.constant 1
+ // CHECK: [[STRIDE:%.+]] = arith.constant 1
+ // CHECK: [[OFFSET:%.+]] = arith.constant 0 : index
+ // CHECK: [[IDX0:%.+]] = arith.constant 0 : index
+ // CHECK: [[IDX1:%.+]] = arith.constant 1 : index
+ // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<5x2xf32>
+ // CHECK: [[INSERT0:%.+]] = tensor.insert_slice %[[ARG0]] into [[INIT]][0, 0] [5, 1] [1, 1]
+ // CHECK: [[INSERT1:%.+]] = tensor.insert_slice %[[ARG0]] into [[INSERT0]][0, 1] [5, 1] [1, 1]
+ %1 = "tosa.concat"(%arg0, %arg0) { axis = 1 : i64} : (tensor<5x1xf32>, tensor<5x1xf32>) -> (tensor<5x2xf32>)
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @concat_non_axis_dyn
+// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
+// CHECK-SAME: %[[ARG1:[0-9a-zA-Z_]*]]
+func.func @concat_non_axis_dyn(%arg0: tensor<5x?xf32>, %arg1: tensor<6x?xf32>) -> () {
+ // CHECK: %[[AXIS:.+]] = arith.constant 0
+ // CHECK: %[[STRIDE:.+]] = arith.constant 1
+ // CHECK: %[[OFFSET:.+]] = arith.constant 0 : index
+ // CHECK: %[[IDX0:.+]] = arith.constant 0 : index
+ // CHECK: %[[IDX1:.+]] = arith.constant 1 : index
+ // CHECK: %[[SIZE:.+]] = tensor.dim %[[ARG0]], %[[IDX1]]
+ // CHECK: %[[IDX1_2:.+]] = arith.constant 1 : index
+ // CHECK: %[[DYN:.+]] = tensor.dim %[[ARG0]], %[[IDX1_2]]
+ // CHECK: %[[INIT:.+]] = tensor.empty(%[[DYN]]) : tensor<11x?xf32>
+ // CHECK: %[[INSERT0:.+]] = tensor.insert_slice %[[ARG0]] into %[[INIT]][0, 0] [5, %[[SIZE]]] [1, 1]
+ // CHECK: %[[INSERT1:.+]] = tensor.insert_slice %[[ARG1]] into %[[INSERT0]][5, 0] [6, %[[SIZE]]] [1, 1]
+ %0 = "tosa.concat"(%arg0, %arg1) { axis = 0 : i64} : (tensor<5x?xf32>, tensor<6x?xf32>) -> (tensor<11x?xf32>)
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @concat_axis_dyn
+// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
+// CHECK-SAME: %[[ARG1:[0-9a-zA-Z_]*]]:
+func.func @concat_axis_dyn(%arg0: tensor<?x3xf32>, %arg1: tensor<?x3xf32>) -> () {
+ // CHECK: %[[AXIS:.+]] = arith.constant 0
+ // CHECK: %[[STRIDE:.+]] = arith.constant 1
+ // CHECK: %[[OFFSET:.+]] = arith.constant 0 : index
+ // CHECK: %[[IDX0:.+]] = arith.constant 0 : index
+ // CHECK: %[[SIZE:.+]] = tensor.dim %[[ARG0]], %[[IDX0]]
+ // CHECK: %[[IDX0_2:.+]] = arith.constant 0 : index
+ // CHECK: %[[DYN:.+]] = tensor.dim %[[ARG0]], %[[IDX0_2]]
+ // CHECK: %[[IDX1:.+]] = arith.constant 1 : index
+ // CHECK: %[[INIT:.+]] = tensor.empty(%[[DYN]]) : tensor<?x3xf32>
+ // CHECK: %[[DYN1:.+]] = tensor.dim %[[ARG0]], %[[AXIS]]
+ // CHECK: %[[INSERT0:.+]] = tensor.insert_slice %[[ARG0]] into %[[INIT]][0, 0] [%[[DYN1]], 3] [1, 1]
+ // CHECK: %[[SUM:.+]] = arith.addi %[[OFFSET]], %[[DYN1]]
+ // CHECK: %[[DYN2:.+]] = tensor.dim %[[ARG1]], %[[AXIS]]
+ // CHECK: %[[INSERT1:.+]] = tensor.insert_slice %[[ARG1]] into %[[INSERT0]][%[[SUM]], 0] [%[[DYN2]], 3] [1, 1]
+ %0 = "tosa.concat"(%arg0, %arg1) { axis = 0 : i64} : (tensor<?x3xf32>, tensor<?x3xf32>) -> (tensor<?x3xf32>)
+ return
+}
More information about the Mlir-commits
mailing list