[Mlir-commits] [mlir] 310e963 - [tosa][mlir] Support dynamic batch dimension for ops where the batch dim is explicit

Rob Suderman llvmlistbot at llvm.org
Wed Jan 12 14:20:09 PST 2022


Author: natashaknk
Date: 2022-01-12T14:16:50-08:00
New Revision: 310e9636caeb2f3f02f3cc5bc2f180248061bbe5

URL: https://github.com/llvm/llvm-project/commit/310e9636caeb2f3f02f3cc5bc2f180248061bbe5
DIFF: https://github.com/llvm/llvm-project/commit/310e9636caeb2f3f02f3cc5bc2f180248061bbe5.diff

LOG: [tosa][mlir] Support dynamic batch dimension for ops where the batch dim is explicit

Support a dynamic batch dimension for rescale, gather, resize, max_pool, avg_pool, conv2D and depthwise_conv2D. Move the shared helper functions into a separate utility header and source file.

Reviewed By: rsuderman

Differential Revision: https://reviews.llvm.org/D117031

Added: 
    mlir/include/mlir/Dialect/Tosa/Utils/CoversionUtils.h
    mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp

Modified: 
    mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
    mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
    mlir/lib/Dialect/Tosa/CMakeLists.txt
    mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
    mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
    utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Tosa/Utils/CoversionUtils.h b/mlir/include/mlir/Dialect/Tosa/Utils/CoversionUtils.h
new file mode 100644
index 0000000000000..bbf149865f6e9
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Tosa/Utils/CoversionUtils.h
@@ -0,0 +1,104 @@
+//===- ConversionUtils.h - Helper functions for tosa conversion -*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Utility functions for TOSA lowering
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef DIALECT_TOSA_UTILS_COVERSION_UTILS_H_
+#define DIALECT_TOSA_UTILS_COVERSION_UTILS_H_
+
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/IR/PatternMatch.h"
+
+namespace mlir {
+namespace tosa {
+
+// Returns `nParallelLoops` copies of the parallel iterator type name, i.e. the
+// iterator types of a linalg op whose loops are all parallel.
+SmallVector<StringRef> getNParallelLoopsAttrs(unsigned nParallelLoops);
+
+// Returns the non-null entries of `values`, preserving their relative order.
+SmallVector<Value> condenseValues(const SmallVector<Value> &values);
+
+// Clamps `arg` to [min, max] by emitting two comparisons of type `T` with
+// predicate `pred` and two selects; returns the final select. `pred` must be a
+// "less than" predicate so that T(pred, arg, min) means `arg < min`.
+template <typename T, typename P>
+mlir::SelectOp clampHelper(Location loc, Value arg, arith::ConstantOp min,
+                           arith::ConstantOp max, P pred, OpBuilder &rewriter) {
+  auto smallerThanMin = rewriter.create<T>(loc, pred, arg, min);
+  auto minOrArg =
+      rewriter.create<mlir::SelectOp>(loc, smallerThanMin, min, arg);
+  auto largerThanMax = rewriter.create<T>(loc, pred, max, arg);
+  return rewriter.create<mlir::SelectOp>(loc, largerThanMax, max, minOrArg);
+}
+
+// Appends the elements of `attr`, sign-extended and converted to `T`, to
+// `arrayValues`. Every element of `attr` must be an IntegerAttr.
+template <typename T>
+void getValuesFromIntArrayAttribute(ArrayAttr attr,
+                                    SmallVector<T> &arrayValues) {
+  for (Attribute val : attr.getValue()) {
+    arrayValues.push_back(val.cast<IntegerAttr>().getValue().getSExtValue());
+  }
+}
+
+// Verifies that `params` are ranked shaped values whose only possibly dynamic
+// dimension is the batch dimension (#0).
+//
+// On success, returns the dynamic sizes to use when creating a tensor of the
+// op's result type: the batch size of params[0] if the first result of `op`
+// has a dynamic batch dimension, otherwise an empty vector. On failure, a
+// match failure is reported on `op` and llvm::None is returned.
+template <typename Op>
+Optional<SmallVector<Value>> checkHasDynamicBatchDims(PatternRewriter &rewriter,
+                                                      Op op,
+                                                      ArrayRef<Value> params) {
+  Operation *operation = op;
+  SmallVector<Value> dynamicDims;
+
+  for (Value param : params) {
+    auto paramTy = param.getType().cast<ShapedType>();
+    if (paramTy.hasStaticShape())
+      continue;
+
+    // getShape() asserts on unranked types, so reject them up front.
+    if (!paramTy.hasRank()) {
+      (void)rewriter.notifyMatchFailure(operation, "input must be ranked");
+      return llvm::None;
+    }
+
+    if (llvm::any_of(paramTy.getShape().drop_front(), ShapedType::isDynamic)) {
+      (void)rewriter.notifyMatchFailure(
+          operation, "input can only be dynamic for batch size");
+      return llvm::None;
+    }
+  }
+
+  // Callers build their init tensors from the result shape, so a size operand
+  // is only valid when the result's batch dimension is dynamic; passing one
+  // for a static batch would give linalg.init_tensor more dynamic sizes than
+  // its shape has dynamic dimensions.
+  if (params.empty() || operation->getNumResults() == 0)
+    return dynamicDims;
+  auto resultTy = operation->getResult(0).getType().dyn_cast<ShapedType>();
+  if (resultTy && resultTy.hasRank() && resultTy.getRank() > 0 &&
+      resultTy.isDynamicDim(0))
+    dynamicDims.push_back(
+        rewriter.create<tensor::DimOp>(operation->getLoc(), params[0], 0));
+  return dynamicDims;
+}
+
+} // namespace tosa
+} // namespace mlir
+
+#endif // DIALECT_TOSA_UTILS_COVERSION_UTILS_H_

diff  --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index e9a5c37708e67..1fab060a6b62f 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -18,6 +18,7 @@
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/Dialect/Tosa/Utils/CoversionUtils.h"
 #include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
@@ -27,10 +28,7 @@
 #include <numeric>
 
 using namespace mlir;
-
-static SmallVector<StringRef> getNParallelLoopsAttrs(unsigned nParallelLoops) {
-  return SmallVector<StringRef>(nParallelLoops, getParallelIteratorTypeName());
-}
+using namespace mlir::tosa;
 
 template <typename T>
 static arith::ConstantOp
@@ -42,33 +40,6 @@ createConstFromIntAttribute(Operation *op, const std::string &attrName,
       op->getLoc(), IntegerAttr::get(requiredAttrType, castedN));
 }
 
-template <typename T>
-static void getValuesFromIntArrayAttribute(ArrayAttr attr,
-                                           SmallVector<T> &arrayValues) {
-  for (Attribute val : attr.getValue()) {
-    arrayValues.push_back(val.cast<IntegerAttr>().getValue().getSExtValue());
-  }
-}
-
-template <typename T, typename P>
-static mlir::SelectOp clampHelper(Location loc, Value arg,
-                                  arith::ConstantOp min, arith::ConstantOp max,
-                                  P pred, OpBuilder &rewriter) {
-  auto smallerThanMin = rewriter.create<T>(loc, pred, arg, min);
-  auto minOrArg =
-      rewriter.create<mlir::SelectOp>(loc, smallerThanMin, min, arg);
-  auto largerThanMax = rewriter.create<T>(loc, pred, max, arg);
-  return rewriter.create<mlir::SelectOp>(loc, largerThanMax, max, minOrArg);
-}
-
-static SmallVector<Value> filterDynamicDims(const SmallVector<Value> &dynDims) {
-  SmallVector<Value> filteredDims;
-  for (auto dim : dynDims)
-    if (dim)
-      filteredDims.push_back(dim);
-  return filteredDims;
-}
-
 static Value
 createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
                                             ArrayRef<Type> resultTypes,
@@ -665,7 +636,7 @@ elementwiseMatchAndRewriteHelper(Operation *operation,
     }
   }
 
-  SmallVector<Value> filteredDims = filterDynamicDims(dynDims);
+  SmallVector<Value> filteredDims = condenseValues(dynDims);
 
   for (auto result : results) {
     auto resultTy = result.getType().template cast<ShapedType>();
@@ -1184,7 +1155,7 @@ class TransposeConverter : public OpRewritePattern<tosa::TransposeOp> {
       inputExprs[value] = rewriter.getAffineDimExpr(index);
     }
 
-    SmallVector<Value> filteredDims = filterDynamicDims(dynDims);
+    SmallVector<Value> filteredDims = condenseValues(dynDims);
 
     auto initTensor = rewriter.create<linalg::InitTensorOp>(
         loc, filteredDims, resultTy.getShape(), resultTy.getElementType());
@@ -1221,9 +1192,11 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
       return rewriter.notifyMatchFailure(
           op, "tosa.rescale requires scale32 for double_round to be true");
 
-    if (!outputTy.hasStaticShape())
-      return rewriter.notifyMatchFailure(
-          op, "tosa to linalg conversion expects statically shaped tensors");
+    auto dynamicDimsOr =
+        checkHasDynamicBatchDims(rewriter, op, {input, op.output()});
+    if (!dynamicDimsOr.hasValue())
+      return failure();
+    SmallVector<Value> dynamicDims = dynamicDimsOr.getValue();
 
     // The shift and multiplier values.
     SmallVector<int32_t> multiplierValues;
@@ -1299,8 +1272,7 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
 
     // Construct the indexing maps needed for linalg.generic ops.
     Value initTensor = rewriter.create<linalg::InitTensorOp>(
-        loc, ArrayRef<Value>({}), outputTy.getShape(),
-        outputTy.getElementType());
+        loc, dynamicDims, outputTy.getShape(), outputTy.getElementType());
 
     auto linalgOp = rewriter.create<linalg::GenericOp>(
         loc, outputTy, genericInputs, ValueRange{initTensor}, indexingMaps,
@@ -1412,16 +1384,17 @@ class ResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
     auto imageH = inputTy.getShape()[1];
     auto imageW = inputTy.getShape()[2];
 
-    if (!resultTy.hasStaticShape())
+    auto dynamicDimsOr =
+        checkHasDynamicBatchDims(rewriter, op, {input, op.output()});
+    if (!dynamicDimsOr.hasValue())
       return failure();
+    SmallVector<Value> dynamicDims = dynamicDimsOr.getValue();
+
     if (op.mode() != "NEAREST_NEIGHBOR" && op.mode() != "BILINEAR")
       return failure();
 
-    auto initTensor =
-        rewriter
-            .create<linalg::InitTensorOp>(loc, ArrayRef<Value>{},
-                                          resultTy.getShape(), resultElementTy)
-            .result();
+    auto initTensor = rewriter.create<linalg::InitTensorOp>(
+        loc, dynamicDims, resultTy.getShape(), resultElementTy);
 
     SmallVector<AffineMap, 2> affineMaps = {
         rewriter.getMultiDimIdentityMap(resultTy.getRank())};
@@ -2098,13 +2071,13 @@ class GatherConverter : public OpConversionPattern<tosa::GatherOp> {
     auto input = adaptor.getOperands()[0];
     auto indices = adaptor.getOperands()[1];
 
-    auto inputTy = input.getType().cast<ShapedType>();
-    auto indicesTy = indices.getType().cast<ShapedType>();
     auto resultTy = op.getType().cast<ShapedType>();
 
-    if (!inputTy.hasStaticShape() || !indicesTy.hasStaticShape())
-      return rewriter.notifyMatchFailure(
-          op, "require input type to have static shape");
+    auto dynamicDimsOr =
+        checkHasDynamicBatchDims(rewriter, op, {input, indices, op.output()});
+    if (!dynamicDimsOr.hasValue())
+      return failure();
+    SmallVector<Value> dynamicDims = dynamicDimsOr.getValue();
 
     auto resultElementTy = resultTy.getElementType();
 
@@ -2112,8 +2085,8 @@ class GatherConverter : public OpConversionPattern<tosa::GatherOp> {
 
     auto initTensor =
         rewriter
-            .create<linalg::InitTensorOp>(loc, ArrayRef<Value>{},
-                                          resultTy.getShape(), resultElementTy)
+            .create<linalg::InitTensorOp>(loc, dynamicDims, resultTy.getShape(),
+                                          resultElementTy)
             .result();
 
     SmallVector<AffineMap, 2> affineMaps = {

diff  --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index a9c525f43aa37..54012c9760eb2 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -18,6 +18,7 @@
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/Dialect/Tosa/Utils/CoversionUtils.h"
 #include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
@@ -27,29 +28,7 @@
 #include <numeric>
 
 using namespace mlir;
-
-static SmallVector<StringRef> getNParallelLoopsAttrs(unsigned nParallelLoops) {
-  return SmallVector<StringRef>(nParallelLoops, getParallelIteratorTypeName());
-}
-
-template <typename T>
-static void getValuesFromIntArrayAttribute(ArrayAttr attr,
-                                           SmallVector<T> &arrayValues) {
-  for (Attribute val : attr.getValue()) {
-    arrayValues.push_back(val.cast<IntegerAttr>().getValue().getSExtValue());
-  }
-}
-
-template <typename T, typename P>
-static mlir::SelectOp clampHelper(Location loc, Value arg,
-                                  arith::ConstantOp min, arith::ConstantOp max,
-                                  P pred, OpBuilder &rewriter) {
-  auto smallerThanMin = rewriter.create<T>(loc, pred, arg, min);
-  auto minOrArg =
-      rewriter.create<mlir::SelectOp>(loc, smallerThanMin, min, arg);
-  auto largerThanMax = rewriter.create<T>(loc, pred, max, arg);
-  return rewriter.create<mlir::SelectOp>(loc, largerThanMax, max, minOrArg);
-}
+using namespace mlir::tosa;
 
 static mlir::Value applyPad(Location loc, Value input, ArrayRef<int64_t> pad,
                             Attribute padAttr, OpBuilder &rewriter) {
@@ -82,14 +61,6 @@ static mlir::Value applyPad(Location loc, Value input, ArrayRef<int64_t> pad,
       .result();
 }
 
-static SmallVector<Value> filterDynamicDims(const SmallVector<Value> &dynDims) {
-  SmallVector<Value> filteredDims;
-  for (auto dim : dynDims)
-    if (dim)
-      filteredDims.push_back(dim);
-  return filteredDims;
-}
-
 namespace {
 
 class ConvConverter : public OpConversionPattern<tosa::Conv2DOp> {
@@ -116,10 +87,15 @@ class ConvConverter : public OpConversionPattern<tosa::Conv2DOp> {
     auto dilationTosaAttr = op->getAttr("dilation").cast<ArrayAttr>();
     bool isQuantized = op->hasAttr("quantization_info");
 
-    if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
-        !biasTy.hasStaticShape() || !resultTy.hasStaticShape())
-      return rewriter.notifyMatchFailure(op,
-                                         "tosa.conv ops require static shapes");
+    if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape())
+      return rewriter.notifyMatchFailure(
+          op, "tosa.conv ops require static shapes for weight and bias");
+
+    auto dynamicDimsOr =
+        checkHasDynamicBatchDims(rewriter, op, {input, op.output()});
+    if (!dynamicDimsOr.hasValue())
+      return failure();
+    SmallVector<Value> dynamicDims = dynamicDimsOr.getValue();
 
     if (inputETy.isUnsignedInteger())
       return rewriter.notifyMatchFailure(
@@ -172,7 +148,7 @@ class ConvConverter : public OpConversionPattern<tosa::Conv2DOp> {
 
     Attribute resultZeroAttr = rewriter.getZeroAttr(resultETy);
     Value initTensor = rewriter.create<linalg::InitTensorOp>(
-        loc, resultTy.getShape(), resultETy);
+        loc, dynamicDims, resultTy.getShape(), resultETy);
     Value zero = rewriter.create<arith::ConstantOp>(loc, resultZeroAttr);
     Value zeroTensor =
         rewriter.create<linalg::FillOp>(loc, zero, initTensor).getResult(0);
@@ -197,7 +173,7 @@ class ConvConverter : public OpConversionPattern<tosa::Conv2DOp> {
     indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultTy.getRank()));
 
     Value biasInitTensor = rewriter.create<linalg::InitTensorOp>(
-        loc, resultTy.getShape(), resultETy);
+        loc, dynamicDims, resultTy.getShape(), resultETy);
 
     if (isQuantized) {
       auto quantizationInfo =
@@ -292,10 +268,15 @@ class DepthwiseConvConverter
           quantizationInfo.weight_zp().getValue().getSExtValue());
     }
 
-    if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
-        !biasTy.hasStaticShape() || !resultTy.hasStaticShape())
-      return rewriter.notifyMatchFailure(op,
-                                         "tosa.conv ops require static shapes");
+    if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape())
+      return rewriter.notifyMatchFailure(
+          op, "tosa.depthwise_conv ops require static shapes");
+
+    auto dynamicDimsOr =
+        checkHasDynamicBatchDims(rewriter, op, {input, op.output()});
+    if (!dynamicDimsOr.hasValue())
+      return failure();
+    SmallVector<Value> dynamicDims = dynamicDimsOr.getValue();
 
     auto weightShape = weightTy.getShape();
     auto resultShape = resultTy.getShape();
@@ -354,13 +335,13 @@ class DepthwiseConvConverter
 
     Attribute resultZeroAttr = rewriter.getZeroAttr(resultETy);
     Value initTensor = rewriter.create<linalg::InitTensorOp>(
-        loc, linalgConvTy.getShape(), resultETy);
+        loc, dynamicDims, linalgConvTy.getShape(), resultETy);
     Value zero = rewriter.create<arith::ConstantOp>(loc, resultZeroAttr);
     Value zeroTensor =
         rewriter.create<linalg::FillOp>(loc, zero, initTensor).getResult(0);
 
     Value biasInitTensor = rewriter.create<linalg::InitTensorOp>(
-        loc, resultTy.getShape(), resultETy);
+        loc, dynamicDims, resultTy.getShape(), resultETy);
     if (!isQuantized) {
       Value conv = rewriter
                        .create<linalg::DepthwiseConv2DNhwcHwcmOp>(
@@ -442,7 +423,7 @@ class MatMulConverter : public OpConversionPattern<tosa::MatMulOp> {
       dynDims[2] = rewriter.create<tensor::DimOp>(loc, op->getOperand(1), 2);
     }
 
-    SmallVector<Value> filteredDims = filterDynamicDims(dynDims);
+    SmallVector<Value> filteredDims = condenseValues(dynDims);
 
     auto zeroAttr = rewriter.getZeroAttr(outputElementTy);
     Value zero = rewriter.create<arith::ConstantOp>(loc, zeroAttr);
@@ -503,7 +484,7 @@ class FullyConnectedConverter
       dynDims[1] = rewriter.create<tensor::DimOp>(loc, weight, 0);
     }
 
-    SmallVector<Value> filteredDims = filterDynamicDims(dynDims);
+    SmallVector<Value> filteredDims = condenseValues(dynDims);
 
     // Creating maps for the output of MatMul and the bias
     SmallVector<AffineMap, 4> indexingMaps;
@@ -611,8 +592,11 @@ class MaxPool2dConverter : public OpRewritePattern<tosa::MaxPool2dOp> {
     ShapedType resultTy = op.getType().template cast<ShapedType>();
     Type resultETy = inputTy.getElementType();
 
-    if (!inputTy.hasStaticShape())
+    auto dynamicDimsOr =
+        checkHasDynamicBatchDims(rewriter, op, {input, op.output()});
+    if (!dynamicDimsOr.hasValue())
       return failure();
+    SmallVector<Value> dynamicDims = dynamicDimsOr.getValue();
 
     // Determine what the initial value needs to be for the max pool op.
     Attribute initialAttr;
@@ -649,7 +633,7 @@ class MaxPool2dConverter : public OpRewritePattern<tosa::MaxPool2dOp> {
 
     // Create the linalg op that performs pooling.
     Value initTensor = rewriter.create<linalg::InitTensorOp>(
-        loc, resultTy.getShape(), resultTy.getElementType());
+        loc, dynamicDims, resultTy.getShape(), resultTy.getElementType());
 
     Value filledInitTensor =
         rewriter.create<linalg::FillOp>(loc, initialValue, initTensor).result();
@@ -682,8 +666,11 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
         inElementTy.isa<IntegerType>() ? rewriter.getI32Type() : inElementTy;
     ShapedType accTy = resultTy.clone(accETy);
 
-    if (!inputTy.hasStaticShape())
+    auto dynamicDimsOr =
+        checkHasDynamicBatchDims(rewriter, op, {input, op.output()});
+    if (!dynamicDimsOr.hasValue())
       return failure();
+    SmallVector<Value> dynamicDims = dynamicDimsOr.getValue();
 
     // Apply padding as necessary.
     llvm::SmallVector<int64_t> pad;
@@ -704,8 +691,8 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
     Attribute dilationAttr = rewriter.getI64VectorAttr({1, 1});
 
     // Create the linalg op that performs pooling.
-    Value poolInitTensor =
-        rewriter.create<linalg::InitTensorOp>(loc, accTy.getShape(), accETy);
+    Value poolInitTensor = rewriter.create<linalg::InitTensorOp>(
+        loc, dynamicDims, accTy.getShape(), accETy);
 
     Value filledInitTensor =
         rewriter.create<linalg::FillOp>(loc, initialValue, poolInitTensor)
@@ -728,7 +715,7 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
     auto affineMap = rewriter.getMultiDimIdentityMap(resultTy.getRank());
 
     Value genericInitTensor = rewriter.create<linalg::InitTensorOp>(
-        loc, resultTy.getShape(), resultETy);
+        loc, dynamicDims, resultTy.getShape(), resultETy);
 
     auto genericOp = rewriter.create<linalg::GenericOp>(
         loc, ArrayRef<Type>({resultTy}), ValueRange{poolingOp},
@@ -770,7 +757,7 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
           auto kH2 = padFn(kH1, y1, pad[3]);
           auto kHCmp = rewriter.create<arith::CmpIOp>(
               loc, arith::CmpIPredicate::slt, kH2, one);
-          auto kH3 = rewriter.create<SelectOp>(loc, kHCmp, one, kH2);
+          auto kH3 = rewriter.create<mlir::SelectOp>(loc, kHCmp, one, kH2);
 
           // compute the horizontal component of coverage.
           auto kW0 = rewriter.create<arith::ConstantIndexOp>(loc, kernel[1]);
@@ -778,7 +765,7 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
           auto kW2 = padFn(kW1, x1, pad[5]);
           auto kWCmp = rewriter.create<arith::CmpIOp>(
               loc, arith::CmpIPredicate::slt, kW2, one);
-          auto kW3 = rewriter.create<SelectOp>(loc, kWCmp, one, kW2);
+          auto kW3 = rewriter.create<mlir::SelectOp>(loc, kWCmp, one, kW2);
 
           // Compute the total number of elements and normalize.
           Value count = rewriter.create<arith::MulIOp>(loc, kH3, kW3);

diff  --git a/mlir/lib/Dialect/Tosa/CMakeLists.txt b/mlir/lib/Dialect/Tosa/CMakeLists.txt
index cae6d14ed9633..9a9c80f933aec 100644
--- a/mlir/lib/Dialect/Tosa/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tosa/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_mlir_dialect_library(MLIRTosa
+  Utils/ConversionUtils.cpp
   Utils/QuantUtils.cpp
   IR/TosaOps.cpp
 

diff  --git a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
new file mode 100644
index 0000000000000..e994adb29bf5c
--- /dev/null
+++ b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
@@ -0,0 +1,37 @@
+//===- ConversionUtils.cpp ------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Utility functions for TOSA lowering
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Tosa/Utils/CoversionUtils.h"
+
+using namespace mlir;
+using namespace mlir::tosa;
+
+// Returns `nParallelLoops` copies of the parallel iterator type name.
+SmallVector<StringRef>
+mlir::tosa::getNParallelLoopsAttrs(unsigned nParallelLoops) {
+  SmallVector<StringRef> iteratorTypes;
+  iteratorTypes.assign(nParallelLoops, getParallelIteratorTypeName());
+  return iteratorTypes;
+}
+
+// Drops null entries from `values`; the surviving entries keep their
+// original relative order.
+SmallVector<Value>
+mlir::tosa::condenseValues(const SmallVector<Value> &values) {
+  SmallVector<Value> nonNullValues;
+  for (Value candidate : values) {
+    if (!candidate)
+      continue;
+    nonNullValues.push_back(candidate);
+  }
+  return nonNullValues;
+}

diff  --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index f5814883cc49c..9db6c58cd113d 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -164,6 +164,19 @@ func @max_pool_padded(%arg0: tensor<1x6x34x62xf32>) -> () {
   return
 }
 
+// CHECK-LABEL: @max_pool_dyn
+func @max_pool_dyn(%arg0: tensor<?x6x34x62xf32>) -> () {
+  // CHECK: %[[C0:.+]] = arith.constant 0
+  // CHECK: %[[BATCH:.+]] = tensor.dim %arg0, %[[C0]]
+  // CHECK: %[[CONST:.+]] = arith.constant -3.40282347E+38
+  // CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[BATCH]], 4, 32, 62]
+  // CHECK: %[[FILL:.+]] = linalg.fill(%[[CONST]], %[[INIT]])
+  // CHECK: %[[KERNEL:.+]] = linalg.init_tensor [3, 3]
+  // CHECK: linalg.pooling_nhwc_max {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%arg0, %[[KERNEL]] : tensor<?x6x34x62xf32>, tensor<3x3xf32>) outs(%[[FILL]] : tensor<?x4x32x62xf32>)
+  %0 = "tosa.max_pool2d"(%arg0) {pad = [0, 0, 0, 0], kernel = [3, 3], stride = [1, 1]} : (tensor<?x6x34x62xf32>)  -> (tensor<?x4x32x62xf32>)
+  return
+}
+
 // CHECK-LABEL: @max_pool_i8
 func @max_pool_i8(%arg0: tensor<1x6x34x62xi8>) -> () {
   // CHECK: arith.constant -128
@@ -250,6 +263,24 @@ func @avg_pool(%arg0: tensor<1x6x34x62xf32>) -> (tensor<1x5x33x62xf32>) {
 
 // -----
 
+// CHECK-LABEL: @avg_pool_dyn
+func @avg_pool_dyn(%arg0: tensor<?x6x34x62xf32>) -> (tensor<?x5x33x62xf32>) {
+  // The calculations remain the same as above, only testing for dyn behavior
+  // CHECK: %[[C0:.+]] = arith.constant 0
+  // CHECK: %[[BATCH:.+]] = tensor.dim %arg0, %[[C0]]
+  // CHECK: %[[PAD:.+]] = linalg.pad_tensor %arg0 low[0, 1, 1, 0] high[0, 1, 1, 0]
+  // CHECK: %[[POOLINIT:.+]] = linalg.init_tensor [%[[BATCH]], 5, 33, 62]
+  // CHECK: %[[FILL:.+]] = linalg.fill
+  // CHECK: %[[KERNEL:.+]] = linalg.init_tensor [4, 4]
+  // CHECK: %[[POOL:.+]] = linalg.pooling_nhwc_sum {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%[[PAD]], %[[KERNEL]] : tensor<?x8x36x62xf32>, tensor<4x4xf32>) outs(%[[FILL]] : tensor<?x5x33x62xf32>)
+  // CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[BATCH]], 5, 33, 62]
+  // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[POOL]] : tensor<?x5x33x62xf32>) outs(%[[INIT]] : tensor<?x5x33x62xf32>)
+  %0 = "tosa.avg_pool2d"(%arg0) {pad = [1, 1, 1, 1], kernel = [4, 4], stride = [1, 1]} : (tensor<?x6x34x62xf32>)  -> (tensor<?x5x33x62xf32>)
+  return %0 : tensor<?x5x33x62xf32>
+}
+
+// -----
+
 // CHECK-LABEL: @avg_pool_i8
 func @avg_pool_i8(%arg0 : tensor<1x128x128x2xi8>) -> () {
 
@@ -329,6 +360,29 @@ func @conv2d_f32(%input: tensor<1x49x42x27xf32>, %weights: tensor<28x3x3x27xf32>
 
 // -----
 
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)>
+// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+
+// CHECK-LABEL: @conv2d_dyn
+func @conv2d_dyn(%input: tensor<?x49x42x27xf32>, %weights: tensor<28x3x3x27xf32>, %bias: tensor<28xf32>) -> () {
+  // CHECK: %[[C0:.+]] = arith.constant 0
+  // CHECK: %[[BATCH:.+]] = tensor.dim %arg0, %[[C0]]
+  // CHECK: %[[PERM:.+]] = arith.constant dense<[1, 2, 3, 0]>
+  // CHECK: %[[W:.+]] = "tosa.transpose"(%arg1, %[[PERM]])
+  // CHECK: %[[M_IN:.+]] = linalg.init_tensor [%[[BATCH]], 45, 40, 28]
+  // CHECK: %[[CST:.+]] = arith.constant 0
+  // CHECK: %[[FILL:.+]] = linalg.fill
+  // CHECK: %[[B_IN:.+]] = linalg.init_tensor [%[[BATCH]], 45, 40, 28]
+  // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[W]] : tensor<?x49x42x27xf32>, tensor<3x3x27x28xf32>) outs(%[[FILL]] : tensor<?x45x40x28xf32>)
+  // CHECK: %[[B:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[CONV]] : tensor<28xf32>, tensor<?x45x40x28xf32>) outs(%[[B_IN]] : tensor<?x45x40x28xf32>)
+  // CHECK:   %[[ADD:.+]] = arith.addf
+  // CHECK:   linalg.yield %[[ADD]] : f32
+  %0 = "tosa.conv2d"(%input, %weights, %bias) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [2, 1]} : (tensor<?x49x42x27xf32>, tensor<28x3x3x27xf32>, tensor<28xf32>)  -> (tensor<?x45x40x28xf32>)
+  return
+}
+
+// -----
+
 // CHECK-LABEL: @conv2d_padded_f32
 func @conv2d_padded_f32(%input: tensor<1x47x40x28xf32>, %weights: tensor<28x3x3x28xf32>, %bias: tensor<28xf32>) -> () {
   // CHECK: %[[C0:.+]] = arith.constant 0
@@ -378,6 +432,30 @@ func @depthwise_conv(%arg0 : tensor<1x7x5x3xf32>, %arg1 : tensor<3x1x3x11xf32>,
 // CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d3)>
 // CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
 
+// CHECK-LABEL: @depthwise_conv_dyn
+func @depthwise_conv_dyn(%arg0 : tensor<?x7x5x3xf32>, %arg1 : tensor<3x1x3x11xf32>, %arg2 : tensor<33xf32>) -> () {
+  // CHECK: %[[C0:.+]] = arith.constant 0
+  // CHECK: %[[BATCH:.+]] = tensor.dim %arg0, %[[C0]]
+  // CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[BATCH]], 5, 5, 3, 11]
+  // CHECK: %[[CST0:.+]] = arith.constant 0
+  // CHECK: %[[FILL:.+]] = linalg.fill
+  // CHECK: %[[OUT:.+]] = linalg.init_tensor [%[[BATCH]], 5, 5, 33]
+  // CHECK: %[[DEPTH:.+]] = linalg.depthwise_conv_2d_nhwc_hwcm {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<?x7x5x3xf32>, tensor<3x1x3x11xf32>) outs(%[[FILL]] : tensor<?x5x5x3x11xf32>)
+  // CHECK: %[[COLLAPSED:.+]] = "tosa.reshape"(%[[DEPTH]]) {new_shape = [-1, 5, 5, 33]}
+  // CHECK: %[[BIAS:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[COLLAPSED]] : tensor<33xf32>, tensor<?x5x5x33xf32>) outs(%[[OUT]] : tensor<?x5x5x33xf32>) {
+  // CHECK: ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):  // no predecessors
+  // CHECK:   %[[ADD:.+]] = arith.addf %arg3, %arg4 : f32
+  // CHECK:   linalg.yield %[[ADD]] : f32
+  // CHECK: } -> tensor<?x5x5x33xf32>
+  %2 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) { pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1] } : (tensor<?x7x5x3xf32>, tensor<3x1x3x11xf32>, tensor<33xf32>)  -> (tensor<?x5x5x33xf32>)
+  return
+}
+
+// -----
+
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d3)>
+// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+
 // CHECK-LABEL: @depthwise_conv_strides
 func @depthwise_conv_strides(%arg0 : tensor<1x11x9x3xf32>, %arg1 : tensor<3x1x3x11xf32>, %arg2 : tensor<33xf32>) -> () {
   // CHECK: [[INIT:%.+]] = linalg.init_tensor [1, 5, 5, 3, 11]

diff  --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index e68e76c67ef98..3706c4131dcbb 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -897,6 +897,26 @@ func @rescale_i8(%arg0 : tensor<2xi8>) -> () {
 
 // -----
 
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+
+// CHECK-LABEL: @rescale_i8_dyn
+func @rescale_i8_dyn(%arg0 : tensor<?x2xi8>) -> () {
+  // CHECK: %[[C0:.+]] = arith.constant 0
+  // CHECK: %[[BATCH:.+]] = tensor.dim %arg0, %[[C0]]
+  // CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[BATCH]], 2]
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<?x2xi8>) outs(%[[INIT]] : tensor<?x2xi8>)
+  %0 = "tosa.rescale"(%arg0) {input_zp = 17 : i32, output_zp = 22 : i32, multiplier = [19689 : i32], shift = [15 : i32], scale32 = false, double_round = false, per_channel = false} : (tensor<?x2xi8>)  -> (tensor<?x2xi8>)
+
+  // CHECK: %[[C0:.+]] = arith.constant 0
+  // CHECK: %[[BATCH:.+]] = tensor.dim %arg0, %[[C0]]
+  // CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[BATCH]], 2]
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<?x2xi8>) outs(%[[INIT]] : tensor<?x2xui8>)
+  %1 = "tosa.rescale"(%arg0) {input_zp = 17 : i32, output_zp = 22 : i32, multiplier = [19689 : i32], shift = [15 : i32], scale32 = false, double_round = false, per_channel = false} : (tensor<?x2xi8>)  -> (tensor<?x2xui8>)
+
+  return
+}
+// -----
+
 // CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
 
 // CHECK-LABEL: @rescale_ui8
@@ -1184,6 +1204,22 @@ func @gather_float(%arg0: tensor<2x3x2xf32>, %arg1: tensor<2x3xi32>) -> () {
   return
 }
 
+// CHECK-LABEL: @gather_float_dyn
+func @gather_float_dyn(%arg0: tensor<?x3x2xf32>, %arg1: tensor<?x3xi32>) -> () {
+  // CHECK: %[[C0:.+]] = arith.constant 0
+  // CHECK: %[[BATCH:.+]] = tensor.dim %arg0, %[[C0]]
+  // CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[BATCH]], 3, 2]
+  // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg1 : tensor<?x3xi32>) outs(%[[INIT]] : tensor<?x3x2xf32>)
+  // CHECK: ^bb0(%[[ARG0:.+]]: i32, %[[ARG1:.+]]: f32)
+  // CHECK:   %[[IDX0:.+]] = linalg.index 0
+  // CHECK:   %[[CAST:.+]] = arith.index_cast %[[ARG0]]
+  // CHECK:   %[[IDX2:.+]] = linalg.index 2
+  // CHECK:   %[[EXTRACT:.+]] = tensor.extract %arg0[%[[IDX0]], %[[CAST]], %[[IDX2]]] : tensor<?x3x2xf32>
+  // CHECK:   linalg.yield %[[EXTRACT]]
+  %0 = "tosa.gather"(%arg0, %arg1)  : (tensor<?x3x2xf32>, tensor<?x3xi32>)  -> (tensor<?x3x2xf32>)
+  return
+}
+
 // CHECK-LABEL: @gather_int
 func @gather_int(%arg0: tensor<2x3x2xi32>, %arg1: tensor<2x3xi32>) -> () {
   // CHECK: %[[INIT:.+]] = linalg.init_tensor [2, 3, 2]
@@ -1548,3 +1584,15 @@ func @resize_bilinear_int(%input: tensor<1x2x2x1xi8>) -> () {
   %output = "tosa.resize"(%input) { output_size = [4, 4], stride = [128, 128], offset = [1, 2], stride_fp = [0. : f32, 0. : f32], offset_fp = [0. : f32, 0. : f32], shift = 8 : i32, mode = "BILINEAR" } : (tensor<1x2x2x1xi8>)  -> (tensor<1x4x4x1xi32>)
   return
 }
+
+// -----
+
+// CHECK-LABEL: @resize_dyn
+func @resize_dyn(%input: tensor<?x2x2x1xi8>) -> () {
+  // CHECK: %[[C0:.+]] = arith.constant 0
+  // CHECK: %[[BATCH:.+]] = tensor.dim %arg0, %[[C0]]
+  // CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[BATCH]], 4, 4, 1]
+  // CHECK: %[[GENERIC:.+]] = linalg.generic
+  %output = "tosa.resize"(%input) { output_size = [4, 4], stride = [128, 128], offset = [1, 2], stride_fp = [0. : f32, 0. : f32], offset_fp = [0. : f32, 0. : f32], shift = 8 : i32, mode = "BILINEAR" } : (tensor<?x2x2x1xi8>)  -> (tensor<?x4x4x1xi32>)
+  return
+}

diff  --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 4399f2f6a5eab..0b144d9920f26 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -7071,13 +7071,14 @@ cc_library(
     includes = ["include"],
     deps = [
         ":Analysis",
+        ":ArithmeticDialect",
         ":Dialect",
+        ":DialectUtils",
         ":IR",
         ":InferTypeOpInterface",
         ":LoopLikeInterface",
         ":Pass",
         ":QuantOps",
-        ":SideEffectInterfaces",
         ":StandardOps",
         ":TensorDialect",
         ":TosaDialectIncGen",


        


More information about the Mlir-commits mailing list