[Mlir-commits] [mlir] 28e6420 - [mlir][tosa] Add tosa.argmax to linalg lowering
Rob Suderman
llvmlistbot at llvm.org
Tue Mar 23 16:07:36 PDT 2021
Author: Rob Suderman
Date: 2021-03-23T16:06:55-07:00
New Revision: 28e6420744f52cc39df0d0529c09385e61ddb8ef
URL: https://github.com/llvm/llvm-project/commit/28e6420744f52cc39df0d0529c09385e61ddb8ef
DIFF: https://github.com/llvm/llvm-project/commit/28e6420744f52cc39df0d0529c09385e61ddb8ef.diff
LOG: [mlir][tosa] Add tosa.argmax to linalg lowering
Tosa's argmax lowering is representable as a linalg.indexed_generic
operation. Include this lowering for both integer and floating-point
element types.
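For example (mirroring the new test below), taking the argmax along
axis 0 of a tensor<3x2xi32> yields a tensor<2xi32> of indices:

  %0 = "tosa.argmax"(%arg0) { axis = 0 : i64} : (tensor<3x2xi32>) -> (tensor<2xi32>)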
Differential Revision: https://reviews.llvm.org/D99137
Added:
Modified:
    mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
    mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index a4b6f826feb6..2f6246e717eb 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -488,6 +488,15 @@ static Attribute createInitialValueForReduceOp(Operation *op, Type elementTy,
     return rewriter.getIntegerAttr(
         elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth()));
 
+  if (isa<tosa::ArgMaxOp>(op) && elementTy.isa<FloatType>())
+    return rewriter.getFloatAttr(
+        elementTy, APFloat::getLargest(
+                       elementTy.cast<FloatType>().getFloatSemantics(), true));
+
+  if (isa<tosa::ArgMaxOp>(op) && elementTy.isa<IntegerType>())
+    return rewriter.getIntegerAttr(
+        elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth()));
+
   return {};
 }
@@ -1233,6 +1242,131 @@ class PadConverter : public OpRewritePattern<tosa::PadOp> {
   }
 };
 
+// Tosa argmax lowering represents the ArgMax op as a linalg.indexed_generic
+// op, producing two output buffers.
+//
+// The first output buffer contains the index of the found maximum value. It is
+// initialized to 0 and has the resulting integer type.
+//
+// The second output buffer contains the maximum value found. It is initialized
+// to the minimum representable value of the input element type. After being
+// populated by indexed_generic, this buffer is discarded as only the index is
+// requested.
+//
+// The indexed_generic op updates both the maximum value and index if the
+// current value exceeds the running max.
+class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
+public:
+  using OpRewritePattern<tosa::ArgMaxOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::ArgMaxOp argmaxOp,
+                                PatternRewriter &rewriter) const final {
+    auto loc = argmaxOp.getLoc();
+    Value input = argmaxOp.input();
+    auto inputTy = input.getType().cast<ShapedType>();
+    auto resultTy = argmaxOp.output().getType().cast<ShapedType>();
+    auto inElementTy = inputTy.getElementType();
+    auto outElementTy = resultTy.getElementType();
+    int axis = argmaxOp.axis();
+    auto resultMaxTy = RankedTensorType::get(resultTy.getShape(), inElementTy);
+
+    if (!inputTy.hasStaticShape())
+      return rewriter.notifyMatchFailure(
+          argmaxOp,
+          "tosa.argmax to linalg.* requires statically shaped input");
+
+    if (!outElementTy.isa<IntegerType>())
+      return rewriter.notifyMatchFailure(
+          argmaxOp,
+          "tosa.argmax to linalg.* requires integer-like result type");
+
+    // First fill the output buffer for the index.
+    auto initTensorIdx =
+        rewriter
+            .create<linalg::InitTensorOp>(loc, ArrayRef<Value>({}),
+                                          resultTy.getShape(), outElementTy)
+            .result();
+    auto fillValueIdx = rewriter.create<ConstantOp>(
+        loc, rewriter.getIntegerAttr(outElementTy, 0));
+    auto filledTensorIdx =
+        rewriter.create<linalg::FillOp>(loc, initTensorIdx, fillValueIdx)
+            .result();
+
+    // Second fill the output buffer for the running max.
+    auto initTensorMax =
+        rewriter
+            .create<linalg::InitTensorOp>(loc, ArrayRef<Value>({}),
+                                          resultTy.getShape(), inElementTy)
+            .result();
+    auto fillValueMaxAttr =
+        createInitialValueForReduceOp(argmaxOp, inElementTy, rewriter);
+
+    if (!fillValueMaxAttr)
+      return rewriter.notifyMatchFailure(
+          argmaxOp, "unsupported tosa.argmax element type");
+
+    auto fillValueMax = rewriter.create<ConstantOp>(loc, fillValueMaxAttr);
+    auto filledTensorMax =
+        rewriter.create<linalg::FillOp>(loc, initTensorMax, fillValueMax)
+            .result();
+
+    // We need to reduce along the arg-max axis, with parallel operations along
+    // the rest.
+    SmallVector<StringRef, 4> iteratorTypes;
+    iteratorTypes.resize(inputTy.getRank(), getParallelIteratorTypeName());
+    iteratorTypes[axis] = getReductionIteratorTypeName();
+
+    SmallVector<AffineExpr, 2> srcExprs;
+    SmallVector<AffineExpr, 2> dstExprs;
+    for (int i = 0, rank = inputTy.getRank(); i != rank; ++i) {
+      srcExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
+      if (axis != i)
+        dstExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
+    }
+
+    bool didEncounterError = false;
+    auto maps = AffineMap::inferFromExprList({srcExprs, dstExprs, dstExprs});
+    auto linalgOp = rewriter.create<linalg::IndexedGenericOp>(
+        loc, ArrayRef<Type>({resultTy, resultMaxTy}), input,
+        ValueRange({filledTensorIdx, filledTensorMax}), maps, iteratorTypes,
+        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange ivs,
+            ValueRange blockArgs) {
+          auto newValue = blockArgs[0];
+          auto oldIndex = blockArgs[1];
+          auto oldValue = blockArgs[2];
+
+          Value newIndex = rewriter.create<IndexCastOp>(
+              nestedLoc, oldIndex.getType(), ivs[axis]);
+
+          Value predicate;
+          if (inElementTy.isa<FloatType>()) {
+            predicate = rewriter.create<mlir::CmpFOp>(
+                nestedLoc, CmpFPredicate::OGT, newValue, oldValue);
+          } else if (inElementTy.isa<IntegerType>()) {
+            predicate = rewriter.create<mlir::CmpIOp>(
+                nestedLoc, CmpIPredicate::sgt, newValue, oldValue);
+          } else {
+            didEncounterError = true;
+            return;
+          }
+
+          auto resultMax = rewriter.create<mlir::SelectOp>(nestedLoc, predicate,
+                                                           newValue, oldValue);
+          auto resultIndex = rewriter.create<mlir::SelectOp>(
+              nestedLoc, predicate, newIndex, oldIndex);
+          nestedBuilder.create<linalg::YieldOp>(
+              nestedLoc, ValueRange({resultIndex, resultMax}));
+        });
+
+    if (didEncounterError)
+      return rewriter.notifyMatchFailure(
+          argmaxOp, "unsupported tosa.argmax element type");
+
+    rewriter.replaceOp(argmaxOp, linalgOp.getResult(0));
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
@@ -1260,7 +1394,7 @@ void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
       IdentityNConverter<tosa::IdentityOp>,
       IdentityNConverter<tosa::IdentityNOp>, ReduceConverter<tosa::ReduceMinOp>,
       ReduceConverter<tosa::ReduceMaxOp>, ReduceConverter<tosa::ReduceSumOp>,
-      ReduceConverter<tosa::ReduceProdOp>, ConcatConverter, PadConverter,
+      ReduceConverter<tosa::ReduceProdOp>, ArgMaxConverter, ConcatConverter, PadConverter,
       ReshapeConverter, RescaleConverter, ReverseConverter, TileConverter,
       TransposeConverter, MatMulConverter, FullyConnectedConverter>(
       patterns->getContext());
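To make the generated IR concrete, here is a hand-written sketch of what
the pattern emits for the axis = 0 case in the test below, reconstructed
from its CHECK lines (SSA names are illustrative and the exact printed
form may differ):

  %init_idx = linalg.init_tensor [2] : tensor<2xi32>
  %zero = constant 0 : i32
  %fill_idx = linalg.fill(%init_idx, %zero) : tensor<2xi32>, i32 -> tensor<2xi32>
  %init_max = linalg.init_tensor [2] : tensor<2xi32>
  %min = constant -2147483648 : i32
  %fill_max = linalg.fill(%init_max, %min) : tensor<2xi32>, i32 -> tensor<2xi32>
  %res:2 = linalg.indexed_generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                       affine_map<(d0, d1) -> (d1)>,
                       affine_map<(d0, d1) -> (d1)>],
      iterator_types = ["reduction", "parallel"]}
      ins(%arg0 : tensor<3x2xi32>)
      outs(%fill_idx, %fill_max : tensor<2xi32>, tensor<2xi32>) {
  ^bb0(%i: index, %j: index, %in: i32, %idx: i32, %max: i32):
    // Cast the reduction iv to the index element type, then keep the
    // larger value and its index.
    %cast = index_cast %i : index to i32
    %cmp = cmpi sgt, %in, %max : i32
    %new_max = select %cmp, %in, %max : i32
    %new_idx = select %cmp, %cast, %idx : i32
    linalg.yield %new_idx, %new_max : i32, i32
  } -> (tensor<2xi32>, tensor<2xi32>)

Only %res#0, the index tensor, replaces the original tosa.argmax; the
running-max tensor is dropped.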
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 39a4f4122924..8dc968193829 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -745,3 +745,51 @@ func @pad_quant(%arg0 : tensor<1x2xi32>) -> (tensor<4x9xi32>) {
%1 = "tosa.pad"(%arg0, %0) { quantization_info = { input_zp = 42 : i32}} : (tensor<1x2xi32>, tensor<2x2xi32>) -> (tensor<4x9xi32>)
return %1 : tensor<4x9xi32>
}
+
+// -----
+
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d1)>
+// CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0)>
+// CHECK: #[[$MAP3:.*]] = affine_map<(d0) -> (d0)>
+// CHECK: #[[$MAP4:.*]] = affine_map<(d0) -> ()>
+
+func @argmax(%arg0 : tensor<3x2xi32>, %arg1 : tensor<6xf32>) -> () {
+  // CHECK: [[IDX_INIT:%.+]] = linalg.init_tensor [2]
+  // CHECK: [[IDX_MIN:%.+]] = constant 0 : i32
+  // CHECK: [[IDX_FILL:%.+]] = linalg.fill([[IDX_INIT]], [[IDX_MIN]])
+  // CHECK: [[VAL_INIT:%.+]] = linalg.init_tensor [2]
+  // CHECK: [[VAL_MIN:%.+]] = constant -2147483648
+  // CHECK: [[VAL_FILL:%.+]] = linalg.fill([[VAL_INIT]], [[VAL_MIN]])
+  // CHECK: linalg.indexed_generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["reduction", "parallel"]} ins(%arg0 : tensor<3x2xi32>) outs([[IDX_FILL]], [[VAL_FILL]] : tensor<2xi32>, tensor<2xi32>)
+  // CHECK: [[CAST:%.+]] = index_cast %arg2
+  // CHECK: [[CMP:%.+]] = cmpi sgt, %arg4, %arg6
+  // CHECK: [[SELECT_VAL:%.+]] = select [[CMP]], %arg4, %arg6
+  // CHECK: [[SELECT_IDX:%.+]] = select [[CMP]], [[CAST]], %arg5
+  // CHECK: linalg.yield [[SELECT_IDX]], [[SELECT_VAL]]
+  %0 = "tosa.argmax"(%arg0) { axis = 0 : i64} : (tensor<3x2xi32>) -> (tensor<2xi32>)
+
+  // CHECK: [[IDX_INIT:%.+]] = linalg.init_tensor [3]
+  // CHECK: [[IDX_MIN:%.+]] = constant 0 : i32
+  // CHECK: [[IDX_FILL:%.+]] = linalg.fill([[IDX_INIT]], [[IDX_MIN]])
+  // CHECK: [[VAL_INIT:%.+]] = linalg.init_tensor [3]
+  // CHECK: [[VAL_MIN:%.+]] = constant -2147483648
+  // CHECK: [[VAL_FILL:%.+]] = linalg.fill([[VAL_INIT]], [[VAL_MIN]])
+  // CHECK: linalg.indexed_generic {indexing_maps = [#[[$MAP0]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "reduction"]} ins(%arg0 : tensor<3x2xi32>) outs([[IDX_FILL]], [[VAL_FILL]] : tensor<3xi32>, tensor<3xi32>)
+  // CHECK: [[CAST:%.+]] = index_cast %arg3
+  // CHECK: [[CMP:%.+]] = cmpi sgt, %arg4, %arg6
+  // CHECK: [[SELECT_VAL:%.+]] = select [[CMP]], %arg4, %arg6
+  // CHECK: [[SELECT_IDX:%.+]] = select [[CMP]], [[CAST]], %arg5
+  // CHECK: linalg.yield [[SELECT_IDX]], [[SELECT_VAL]]
+  %1 = "tosa.argmax"(%arg0) { axis = 1 : i64} : (tensor<3x2xi32>) -> (tensor<3xi32>)
+
+  // CHECK: constant -3.40282347E+38 : f32
+  // CHECK: index_cast
+  // CHECK: cmpf ogt
+  // CHECK: select
+  // CHECK: select
+  // CHECK: linalg.yield
+  %2 = "tosa.argmax"(%arg1) { axis = 0 : i64} : (tensor<6xf32>) -> (tensor<i32>)
+
+  return
+}
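These cases can be exercised directly with the lowering pass; given the
RUN line at the top of this test file, something like

  mlir-opt -split-input-file -tosa-to-linalg-on-tensors tosa-to-linalg.mlir

should reproduce the lowered IR. The pass name here matches the
populateTosaToLinalgOnTensorsConversionPatterns entry point above, but the
exact flag spelling is an assumption, as the RUN line is not shown in this
diff.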