[Mlir-commits] [mlir] f7e1797 - [mlir][tosa] Lowering of tosa.gather operations with dynamic dimensions
Robert Suderman
llvmlistbot at llvm.org
Mon Apr 10 09:01:51 PDT 2023
Author: Spenser Bauman
Date: 2023-04-10T15:56:57Z
New Revision: f7e17975804f4736c660a5163e71c5be6395631f
URL: https://github.com/llvm/llvm-project/commit/f7e17975804f4736c660a5163e71c5be6395631f
DIFF: https://github.com/llvm/llvm-project/commit/f7e17975804f4736c660a5163e71c5be6395631f.diff
LOG: [mlir][tosa] Lowering of tosa.gather operations with dynamic dimensions
The existing TOSA->Linalg lowering of tosa.gather only supports gathers
with either a static shape or a single dynamic batch dimension.
This change extends support to an arbitrary number of dynamic dimensions on
both the values and indices of the gather operation.
Reviewed By: rsuderman
Differential Revision: https://reviews.llvm.org/D147810
Added:
Modified:
mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index be24f5ee5feb4..b2e59b24979cb 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1829,19 +1829,19 @@ class GatherConverter : public OpConversionPattern<tosa::GatherOp> {
auto input = adaptor.getOperands()[0];
auto indices = adaptor.getOperands()[1];
+ auto valuesTy =
+ op.getValues().getType().dyn_cast_or_null<RankedTensorType>();
auto resultTy = op.getType().cast<ShapedType>();
- auto dynamicDimsOr = checkHasDynamicBatchDims(
- rewriter, op, {input, indices, op.getOutput()});
- if (!dynamicDimsOr.has_value())
- return rewriter.notifyMatchFailure(
- op, "tosa.gather currently only supports dynamic batch dimensions");
- SmallVector<Value> dynamicDims = *dynamicDimsOr;
+ if (!valuesTy)
+ return rewriter.notifyMatchFailure(op, "unranked tensors not supported");
+
+ auto dynamicDims = inferDynamicDimsForGather(
+ rewriter, op.getLoc(), adaptor.getValues(), adaptor.getIndices());
auto resultElementTy = resultTy.getElementType();
auto loc = op.getLoc();
-
auto emptyTensor =
rewriter
.create<tensor::EmptyOp>(loc, resultTy.getShape(), resultElementTy,
@@ -1872,6 +1872,24 @@ class GatherConverter : public OpConversionPattern<tosa::GatherOp> {
rewriter.replaceOp(op, genericOp.getResult(0));
return success();
}
+
+ static llvm::SmallVector<Value> inferDynamicDimsForGather(OpBuilder &builder,
+ Location loc,
+ Value values,
+ Value indices) {
+ llvm::SmallVector<Value> results;
+
+ auto addDynamicDimension = [&](Value source, int64_t dim) {
+ auto dynamicDim = tensor::createDimValue(builder, loc, source, dim);
+ if (auto dimValue = dynamicDim.value().dyn_cast<Value>())
+ results.push_back(dimValue);
+ };
+
+ addDynamicDimension(values, 0);
+ addDynamicDimension(indices, 1);
+ addDynamicDimension(values, 2);
+ return results;
+ }
};
// Lowerings the TableOp to a series of gathers and numerica operations. This
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 476131b262fb9..e9e9037ebcdb5 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -1267,6 +1267,30 @@ func.func @gather_float_dyn(%arg0: tensor<?x3x2xf32>, %arg1: tensor<?x3xi32>) ->
// -----
+// CHECK-LABEL: @gather_float_all_dynamic
+// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]
+// CHECK-SAME: %[[ARG1:[0-9a-zA-Z_]*]]
+func.func @gather_float_all_dynamic(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?xi32>) -> () {
+ // CHECK: %[[C0:.+]] = arith.constant 0
+ // CHECK: %[[BATCH:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+ // CHECK: %[[C1:.+]] = arith.constant 1
+ // CHECK: %[[INDEX:.+]] = tensor.dim %[[ARG1]], %[[C1]]
+ // CHECK: %[[C2:.+]] = arith.constant 2
+ // CHECK: %[[CHANNEL:.+]] = tensor.dim %[[ARG0]], %[[C2]]
+ // CHECK: %[[INIT:.+]] = tensor.empty(%[[BATCH]], %[[INDEX]], %[[CHANNEL]])
+ // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[ARG1]] : tensor<?x?xi32>) outs(%[[INIT]] : tensor<?x?x?xf32>)
+ // CHECK: ^bb0(%[[BBARG0:.+]]: i32, %[[BBARG1:.+]]: f32)
+ // CHECK: %[[IDX0:.+]] = linalg.index 0
+ // CHECK: %[[CAST:.+]] = arith.index_cast %[[BBARG0]]
+ // CHECK: %[[IDX2:.+]] = linalg.index 2
+ // CHECK: %[[EXTRACT:.+]] = tensor.extract %[[ARG0]][%[[IDX0]], %[[CAST]], %[[IDX2]]] : tensor<?x?x?xf32>
+ // CHECK: linalg.yield %[[EXTRACT]]
+ %0 = "tosa.gather"(%arg0, %arg1) : (tensor<?x?x?xf32>, tensor<?x?xi32>) -> (tensor<?x?x?xf32>)
+ return
+}
+
+// -----
+
// CHECK-LABEL: @gather_int
// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]
// CHECK-SAME: %[[ARG1:[0-9a-zA-Z_]*]]
More information about the Mlir-commits
mailing list