[Mlir-commits] [mlir] 28e6420 - [mlir][tosa] Add tosa.argmax to linalg lowering

Rob Suderman llvmlistbot at llvm.org
Tue Mar 23 16:07:36 PDT 2021


Author: Rob Suderman
Date: 2021-03-23T16:06:55-07:00
New Revision: 28e6420744f52cc39df0d0529c09385e61ddb8ef

URL: https://github.com/llvm/llvm-project/commit/28e6420744f52cc39df0d0529c09385e61ddb8ef
DIFF: https://github.com/llvm/llvm-project/commit/28e6420744f52cc39df0d0529c09385e61ddb8ef.diff

LOG: [mlir][tosa] Add tosa.argmax to linalg lowering

Tosa's argmax lowering is representable as a linalg.indexed_generic
operation. Include the lowering to this type for both integer and
floating point types.

Differential Revision: https://reviews.llvm.org/D99137

Added: 
    

Modified: 
    mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
    mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index a4b6f826feb6..2f6246e717eb 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -488,6 +488,15 @@ static Attribute createInitialValueForReduceOp(Operation *op, Type elementTy,
     return rewriter.getIntegerAttr(
         elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth()));
 
+  if (isa<tosa::ArgMaxOp>(op) && elementTy.isa<FloatType>())
+    return rewriter.getFloatAttr(
+        elementTy, APFloat::getLargest(
+                       elementTy.cast<FloatType>().getFloatSemantics(), true));
+
+  if (isa<tosa::ArgMaxOp>(op) && elementTy.isa<IntegerType>())
+    return rewriter.getIntegerAttr(
+        elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth()));
+
   return {};
 }
 
@@ -1233,6 +1242,131 @@ class PadConverter : public OpRewritePattern<tosa::PadOp> {
   }
 };
 
+// Tosa argmax lowering represents the ArgMax op as a linalg.indexed_generic
+// op, producing two output tensors.
+//
+// The first output tensor contains the index of the found maximum value. It is
+// initialized to 0 and has the integer element type of the op's result.
+//
+// The second output tensor contains the maximum value found. It is initialized
+// to the most negative finite float, or the signed-minimum integer, of the
+// input element type. After being populated by indexed_generic, this tensor is
+// discarded, as only the index result replaces the original op.
+//
+// The indexed_generic op updates both the maximum value and the index only if
+// the current value is strictly greater (ogt / sgt) than the running max.
+class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
+public:
+  using OpRewritePattern<tosa::ArgMaxOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::ArgMaxOp argmaxOp,
+                                PatternRewriter &rewriter) const final {
+    auto loc = argmaxOp.getLoc();
+    Value input = argmaxOp.input();
+    auto inputTy = input.getType().cast<ShapedType>();
+    auto resultTy = argmaxOp.output().getType().cast<ShapedType>();
+    auto inElementTy = inputTy.getElementType();
+    auto outElementTy = resultTy.getElementType();
+    int axis = argmaxOp.axis();
+    auto resultMaxTy = RankedTensorType::get(resultTy.getShape(), inElementTy);
+
+    if (!inputTy.hasStaticShape())
+      return rewriter.notifyMatchFailure(
+          argmaxOp,
+          "tosa.arg_max to linalg.* requires statically shaped input");
+
+    if (!outElementTy.isa<IntegerType>())
+      return rewriter.notifyMatchFailure(
+          argmaxOp,
+          "tosa.arg_max to linalg.* requires integer-like result type");
+
+    // First, create the index accumulator tensor and fill it with zero.
+    auto initTensorIdx =
+        rewriter
+            .create<linalg::InitTensorOp>(loc, ArrayRef<Value>({}),
+                                          resultTy.getShape(), outElementTy)
+            .result();
+    auto fillValueIdx = rewriter.create<ConstantOp>(
+        loc, rewriter.getIntegerAttr(outElementTy, 0));
+    auto filledTensorIdx =
+        rewriter.create<linalg::FillOp>(loc, initTensorIdx, fillValueIdx)
+            .result();
+
+    // Second, create the running-max accumulator filled with its initial value.
+    auto initTensorMax =
+        rewriter
+            .create<linalg::InitTensorOp>(loc, ArrayRef<Value>({}),
+                                          resultTy.getShape(), inElementTy)
+            .result();
+    auto fillValueMaxAttr =
+        createInitialValueForReduceOp(argmaxOp, inElementTy, rewriter);
+
+    if (!fillValueMaxAttr)
+      return rewriter.notifyMatchFailure(
+          argmaxOp, "unsupported tosa.argmax element type");
+
+    auto fillValueMax = rewriter.create<ConstantOp>(loc, fillValueMaxAttr);
+    auto filledTensorMax =
+        rewriter.create<linalg::FillOp>(loc, initTensorMax, fillValueMax)
+            .result();
+
+    // Only the arg-max axis is a reduction dimension; every other dimension
+    // iterates in parallel.
+    SmallVector<StringRef, 4> iteratorTypes;
+    iteratorTypes.resize(inputTy.getRank(), getParallelIteratorTypeName());
+    iteratorTypes[axis] = getReductionIteratorTypeName();
+
+    SmallVector<AffineExpr, 2> srcExprs;
+    SmallVector<AffineExpr, 2> dstExprs;
+    for (int i = 0, rank = inputTy.getRank(); i != rank; ++i) {
+      srcExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
+      if (axis != i)
+        dstExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
+    }
+
+    bool didEncounterError = false;
+    auto maps = AffineMap::inferFromExprList({srcExprs, dstExprs, dstExprs});
+    auto linalgOp = rewriter.create<linalg::IndexedGenericOp>(
+        loc, ArrayRef<Type>({resultTy, resultMaxTy}), input,
+        ValueRange({filledTensorIdx, filledTensorMax}), maps, iteratorTypes,
+        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange ivs,
+            ValueRange blockArgs) {
+          auto newValue = blockArgs[0];
+          auto oldIndex = blockArgs[1];
+          auto oldValue = blockArgs[2];
+
+          Value newIndex = rewriter.create<IndexCastOp>(
+              nestedLoc, oldIndex.getType(), ivs[axis]);
+
+          Value predicate;
+          if (inElementTy.isa<FloatType>()) {
+            predicate = rewriter.create<mlir::CmpFOp>(
+                nestedLoc, CmpFPredicate::OGT, newValue, oldValue);
+          } else if (inElementTy.isa<IntegerType>()) {
+            predicate = rewriter.create<mlir::CmpIOp>(
+                nestedLoc, CmpIPredicate::sgt, newValue, oldValue);
+          } else {
+            didEncounterError = true;
+            return;
+          }
+
+          auto resultMax = rewriter.create<mlir::SelectOp>(nestedLoc, predicate,
+                                                           newValue, oldValue);
+          auto resultIndex = rewriter.create<mlir::SelectOp>(
+              nestedLoc, predicate, newIndex, oldIndex);
+          nestedBuilder.create<linalg::YieldOp>(
+              nestedLoc, ValueRange({resultIndex, resultMax}));
+        });
+
+    if (didEncounterError)
+      return rewriter.notifyMatchFailure(
+          argmaxOp, "unsupported tosa.argmax element type");
+
+    rewriter.replaceOp(argmaxOp, linalgOp.getResult(0));
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
@@ -1260,7 +1394,7 @@ void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
       IdentityNConverter<tosa::IdentityOp>,
       IdentityNConverter<tosa::IdentityNOp>, ReduceConverter<tosa::ReduceMinOp>,
       ReduceConverter<tosa::ReduceMaxOp>, ReduceConverter<tosa::ReduceSumOp>,
-      ReduceConverter<tosa::ReduceProdOp>, ConcatConverter, PadConverter,
+      ReduceConverter<tosa::ReduceProdOp>, ArgMaxConverter, ConcatConverter, PadConverter,
       ReshapeConverter, RescaleConverter, ReverseConverter, TileConverter,
       TransposeConverter, MatMulConverter, FullyConnectedConverter>(
         patterns->getContext());

diff  --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 39a4f4122924..8dc968193829 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -745,3 +745,51 @@ func @pad_quant(%arg0 : tensor<1x2xi32>) -> (tensor<4x9xi32>) {
   %1 = "tosa.pad"(%arg0, %0) { quantization_info = { input_zp = 42 : i32}} : (tensor<1x2xi32>, tensor<2x2xi32>)  -> (tensor<4x9xi32>)
   return %1 : tensor<4x9xi32>
 }
+
+// -----
+
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d1)>
+// CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0)>
+// CHECK: #[[$MAP3:.*]] = affine_map<(d0) -> (d0)>
+// CHECK: #[[$MAP4:.*]] = affine_map<(d0) -> ()>
+
+func @argmax(%arg0 : tensor<3x2xi32>, %arg1 : tensor<6xf32>) -> () {
+  // CHECK: [[IDX_INIT:%.+]] = linalg.init_tensor [2]
+  // CHECK: [[IDX_MIN:%.+]] = constant 0 : i32
+  // CHECK: [[IDX_FILL:%.+]] = linalg.fill([[IDX_INIT]], [[IDX_MIN]])
+  // CHECK: [[VAL_INIT:%.+]] = linalg.init_tensor [2]
+  // CHECK: [[VAL_MIN:%.+]] = constant -2147483648
+  // CHECK: [[VAL_FILL:%.+]] = linalg.fill([[VAL_INIT]], [[VAL_MIN]])
+  // CHECK: linalg.indexed_generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["reduction", "parallel"]} ins(%arg0 : tensor<3x2xi32>) outs([[IDX_FILL]], [[VAL_FILL]] : tensor<2xi32>, tensor<2xi32>)
+  // CHECK:   [[CAST:%.+]] = index_cast %arg2
+  // CHECK:   [[CMP:%.+]] = cmpi sgt, %arg4, %arg6
+  // CHECK:   [[SELECT_VAL:%.+]] = select [[CMP]], %arg4, %arg6
+  // CHECK:   [[SELECT_IDX:%.+]] = select [[CMP]], [[CAST]], %arg5
+  // CHECK:   linalg.yield [[SELECT_IDX]], [[SELECT_VAL]]
+  %0 = "tosa.argmax"(%arg0) { axis = 0 : i64} : (tensor<3x2xi32>)  -> (tensor<2xi32>)

+  // CHECK: [[IDX_INIT:%.+]] = linalg.init_tensor [3]
+  // CHECK: [[IDX_MIN:%.+]] = constant 0 : i32
+  // CHECK: [[IDX_FILL:%.+]] = linalg.fill([[IDX_INIT]], [[IDX_MIN]])
+  // CHECK: [[VAL_INIT:%.+]] = linalg.init_tensor [3]
+  // CHECK: [[VAL_MIN:%.+]] = constant -2147483648
+  // CHECK: [[VAL_FILL:%.+]] = linalg.fill([[VAL_INIT]], [[VAL_MIN]])
+  // CHECK: linalg.indexed_generic {indexing_maps = [#[[$MAP0]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "reduction"]} ins(%arg0 : tensor<3x2xi32>) outs([[IDX_FILL]], [[VAL_FILL]] : tensor<3xi32>, tensor<3xi32>)
+  // CHECK:   [[CAST:%.+]] = index_cast %arg3
+  // CHECK:   [[CMP:%.+]] = cmpi sgt, %arg4, %arg6
+  // CHECK:   [[SELECT_VAL:%.+]] = select [[CMP]], %arg4, %arg6
+  // CHECK:   [[SELECT_IDX:%.+]] = select [[CMP]], [[CAST]], %arg5
+  // CHECK:   linalg.yield [[SELECT_IDX]], [[SELECT_VAL]]
+  %1 = "tosa.argmax"(%arg0) { axis = 1 : i64} : (tensor<3x2xi32>)  -> (tensor<3xi32>)

+  // CHECK: constant -3.40282347E+38 : f32
+  // CHECK: index_cast
+  // CHECK: cmpf ogt
+  // CHECK: select
+  // CHECK: select
+  // CHECK: linalg.yield
+  %2 = "tosa.argmax"(%arg1) { axis = 0 : i64} : (tensor<6xf32>)  -> (tensor<i32>)

+  return
+}


        


More information about the Mlir-commits mailing list