[Mlir-commits] [mlir] 6f720d5 - [mlir][tosa] Add tosa.gather lowering to linalg.indexed_generic
Rob Suderman
llvmlistbot at llvm.org
Fri Apr 23 22:45:07 PDT 2021
Author: natashaknk
Date: 2021-04-23T22:42:56-07:00
New Revision: 6f720d5eca2e5a152e21e6e1d97c6e7df12e40af
URL: https://github.com/llvm/llvm-project/commit/6f720d5eca2e5a152e21e6e1d97c6e7df12e40af
DIFF: https://github.com/llvm/llvm-project/commit/6f720d5eca2e5a152e21e6e1d97c6e7df12e40af.diff
LOG: [mlir][tosa] Add tosa.gather lowering to linalg.indexed_generic
Lowering gather operation to linalg dialect.
Reviewed By: rsuderman
Differential Revision: https://reviews.llvm.org/D101200
Added:
Modified:
mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 21ef0b84a3dd5..042626e2a4776 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1781,6 +1781,59 @@ class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
}
};
+/// Lowers `tosa.gather` to a `linalg.indexed_generic` op.
+///
+/// The generic op iterates over every element of the result. Its only input
+/// is the indices operand, addressed by the first two loop dimensions; the
+/// value read from it is cast to `index` and used as the middle coordinate of
+/// a `tensor.extract` on the gather input:
+///
+///   result[d0, d1, d2] = input[d0, indices[d0, d1], d2]
+///
+/// Only statically shaped input, indices and result types are handled; any
+/// other combination is reported as a match failure. No bounds check is
+/// emitted for the gathered index.
+class GatherConverter : public OpConversionPattern<tosa::GatherOp> {
+public:
+  using OpConversionPattern<tosa::GatherOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(tosa::GatherOp op, ArrayRef<Value> args,
+                  ConversionPatternRewriter &rewriter) const final {
+    Value input = args[0];
+    Value indices = args[1];
+
+    auto inputTy = input.getType().cast<ShapedType>();
+    auto indicesTy = indices.getType().cast<ShapedType>();
+    auto resultTy = op.getType().cast<ShapedType>();
+
+    // The init tensor below is created from `resultTy.getShape()` without any
+    // dynamic size operands, so the result shape must be static as well.
+    if (!inputTy.hasStaticShape() || !indicesTy.hasStaticShape() ||
+        !resultTy.hasStaticShape())
+      return rewriter.notifyMatchFailure(
+          op, "require input, indices, and result types to have static shape");
+
+    Type resultElementTy = resultTy.getElementType();
+    Location loc = op.getLoc();
+    int64_t resultRank = resultTy.getRank();
+
+    Value initTensor =
+        rewriter
+            .create<linalg::InitTensorOp>(loc, ArrayRef<Value>{},
+                                          resultTy.getShape(), resultElementTy)
+            .result();
+
+    // The indices operand is addressed by the two leading loop dimensions
+    // only; the result uses the identity map over all loop dimensions.
+    SmallVector<AffineMap, 2> affineMaps = {
+        AffineMap::get(
+            /*dimCount=*/resultRank, /*symbolCount=*/0,
+            {rewriter.getAffineDimExpr(0), rewriter.getAffineDimExpr(1)},
+            rewriter.getContext()),
+        rewriter.getMultiDimIdentityMap(resultRank)};
+
+    auto genericOp = rewriter.create<linalg::IndexedGenericOp>(
+        loc, ArrayRef<Type>({resultTy}), ValueRange{indices},
+        ValueRange{initTensor}, affineMaps,
+        getNParallelLoopsAttrs(resultRank),
+        [&](OpBuilder &b, Location nestedLoc, ValueRange loopIndices,
+            ValueRange blockArgs) {
+          // Build the body with the builder handed to the callback rather
+          // than the captured rewriter.
+          // blockArgs[0] is the current element of the indices operand.
+          Value gatherIndex = b.create<IndexCastOp>(
+              nestedLoc, b.getIndexType(), blockArgs[0]);
+          Value extract = b.create<tensor::ExtractOp>(
+              nestedLoc, input,
+              ValueRange{loopIndices[0], gatherIndex, loopIndices[2]});
+          b.create<linalg::YieldOp>(nestedLoc, extract);
+        });
+
+    rewriter.replaceOp(op, genericOp.getResult(0));
+    return success();
+  }
+};
+
// Lowers the TableOp to a series of gathers and numerical operations. This
// includes interpolation between the high/low values. For the I8 variant, this
// simplifies to a single gather operation.
@@ -2085,6 +2138,7 @@ void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
ArgMaxConverter,
ConcatConverter,
Conv2DConverter,
+ GatherConverter,
PadConverter,
ReshapeConverter,
RescaleConverter,
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index ff4dbf4ac0529..489bdd3cd94fd 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -833,6 +833,32 @@ func @argmax(%arg0 : tensor<3x2xi32>, %arg1 : tensor<6xf32>) -> () {
// -----
+// Verifies the expected lowering of tosa.gather on f32 values with i32
+// indices: a linalg.init_tensor for the result, followed by a
+// linalg.indexed_generic over the indices operand whose body index_casts the
+// gathered index and uses it as the middle coordinate of a tensor.extract on
+// the input tensor.
+// CHECK-LABEL: @gather_float
+func @gather_float(%arg0: tensor<2x3x2xf32>, %arg1: tensor<2x3xi32>) -> () {
+ // CHECK: %[[INIT:.+]] = linalg.init_tensor [2, 3, 2]
+ // CHECK: %[[GENERIC:.+]] = linalg.indexed_generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg1 : tensor<2x3xi32>) outs(%[[INIT]] : tensor<2x3x2xf32>)
+ // CHECK: ^bb0(%[[IDX0:.+]]: index, %[[IDX1:.+]]: index, %[[IDX2:.+]]: index, %[[ARG0:.+]]: i32, %[[ARG1:.+]]: f32)
+ // CHECK: %[[CAST:.+]] = index_cast %[[ARG0]]
+ // CHECK: %[[EXTRACT:.+]] = tensor.extract %arg0[%[[IDX0]], %[[CAST]], %[[IDX2]]] : tensor<2x3x2xf32>
+ // CHECK: linalg.yield %[[EXTRACT]]
+ %0 = "tosa.gather"(%arg0, %arg1) : (tensor<2x3x2xf32>, tensor<2x3xi32>) -> (tensor<2x3x2xf32>)
+ return
+}
+
+// Same expectations as the f32 gather case, but with i32 values: the
+// init tensor, the indexed_generic block argument for the output, and the
+// tensor.extract source are all expected to use an i32 element type.
+// CHECK-LABEL: @gather_int
+func @gather_int(%arg0: tensor<2x3x2xi32>, %arg1: tensor<2x3xi32>) -> () {
+ // CHECK: %[[INIT:.+]] = linalg.init_tensor [2, 3, 2]
+ // CHECK: %[[GENERIC:.+]] = linalg.indexed_generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg1 : tensor<2x3xi32>) outs(%[[INIT]] : tensor<2x3x2xi32>)
+ // CHECK: ^bb0(%[[IDX0:.+]]: index, %[[IDX1:.+]]: index, %[[IDX2:.+]]: index, %[[ARG0:.+]]: i32, %[[ARG1:.+]]: i32)
+ // CHECK: %[[CAST:.+]] = index_cast %[[ARG0]]
+ // CHECK: %[[EXTRACT:.+]] = tensor.extract %arg0[%[[IDX0]], %[[CAST]], %[[IDX2]]] : tensor<2x3x2xi32>
+ // CHECK: linalg.yield %[[EXTRACT]]
+ %0 = "tosa.gather"(%arg0, %arg1) : (tensor<2x3x2xi32>, tensor<2x3xi32>) -> (tensor<2x3x2xi32>)
+ return
+}
+
+// -----
+
// CHECK-LABEL: @table8
func @table8(%arg0: tensor<6xi8>, %arg1: tensor<513xi8>) -> () {
// CHECK: %[[INIT:.+]] = linalg.init_tensor [6]
More information about the Mlir-commits
mailing list