[Mlir-commits] [mlir] [mlir][tosa] Implement dynamic shape support for tosa.max_pool2d lowering (PR #87538)

Spenser Bauman llvmlistbot at llvm.org
Mon Apr 15 15:57:32 PDT 2024


https://github.com/sabauma updated https://github.com/llvm/llvm-project/pull/87538

>From 4cc3a552f4cfbc5400d7ea7176d7d915778cc9ea Mon Sep 17 00:00:00 2001
From: Spenser Bauman <sbauman at mathworks.com>
Date: Tue, 2 Apr 2024 17:54:38 -0400
Subject: [PATCH] [mlir][tosa] Implement dynamic shape support for
 tosa.max_pool2d lowering

The existing lowering for tosa.max_pool2d only supports dynamic
dimensions when the dynamic dimension is the batch dimension.
This change updates the lowering to support arbitrary dynamic dimensions
on the inputs and outputs of the tosa.max_pool2d operation.

This change also fixes a bug in the implementation of implicit
broadcasting in the tosa-to-linalg pass, which was introducing uses of
constant ops that violated dominance requirements.
---
 .../mlir/Dialect/Tosa/IR/TosaTypesBase.td     |  10 +-
 .../Conversion/TosaToLinalg/TosaToLinalg.cpp  |  12 +-
 .../TosaToLinalg/TosaToLinalgNamed.cpp        |  88 ++++++++++----
 .../TosaToLinalg/tosa-to-linalg-named.mlir    |  54 +++++++++
 .../TosaToLinalg/tosa-to-linalg.mlir          |  12 +-
 .../Tosa/CPU/test-maxpool-dynamic.mlir        | 112 ++++++++++++++++++
 6 files changed, 251 insertions(+), 37 deletions(-)
 create mode 100644 mlir/test/Integration/Dialect/Tosa/CPU/test-maxpool-dynamic.mlir

diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index cff3de0a69af95..3687891fe4b7cf 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -130,11 +130,11 @@ def Tosa_ScalarTensor : TensorRankOf<[Tosa_AnyNumber], [0]>;
 // to not include any remaining unranked tensors.
 def Tosa_UnrankedTensor : UnrankedTensorOf<[Tosa_AnyNumber]>;
 
-def Tosa_Tensor1D : AnyTypeOf<[Tosa_UnrankedTensor, 1DTensorOf<[Tosa_AnyNumber]>]>;
-def Tosa_Tensor2D : AnyTypeOf<[Tosa_UnrankedTensor, 2DTensorOf<[Tosa_AnyNumber]>]>;
-def Tosa_Tensor3D : AnyTypeOf<[Tosa_UnrankedTensor, 3DTensorOf<[Tosa_AnyNumber]>]>;
-def Tosa_Tensor4D : AnyTypeOf<[Tosa_UnrankedTensor, 4DTensorOf<[Tosa_AnyNumber]>]>;
-def Tosa_Tensor5D : AnyTypeOf<[Tosa_UnrankedTensor, TensorRankOf<[Tosa_AnyNumber], [5]>]>;
+def Tosa_Tensor1D : AnyTypeOf<[Tosa_UnrankedTensor, 1DTensorOf<[Tosa_AnyNumber]>], "1-d tensor", "::mlir::TensorType">;
+def Tosa_Tensor2D : AnyTypeOf<[Tosa_UnrankedTensor, 2DTensorOf<[Tosa_AnyNumber]>], "2-d tensor", "::mlir::TensorType">;
+def Tosa_Tensor3D : AnyTypeOf<[Tosa_UnrankedTensor, 3DTensorOf<[Tosa_AnyNumber]>], "3-d tensor", "::mlir::TensorType">;
+def Tosa_Tensor4D : AnyTypeOf<[Tosa_UnrankedTensor, 4DTensorOf<[Tosa_AnyNumber]>], "4-d tensor", "::mlir::TensorType">;
+def Tosa_Tensor5D : AnyTypeOf<[Tosa_UnrankedTensor, TensorRankOf<[Tosa_AnyNumber], [5]>], "5-d tensor", "::mlir::TensorType">;
 
 // Ranked tensors up to given rank.
 def Tosa_Tensor1Dto4D : AnyTypeOf<[
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 7c477f2e1412be..d8dd1c93722b09 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -766,11 +766,15 @@ static Value broadcastDynamicDimension(PatternRewriter &rewriter, Location loc,
 
   // Emit 'then' region of 'scf.if'
   auto emitThenRegion = [&](OpBuilder &opBuilder, Location loc) {
+    // It is not safe to cache constants across regions.
+    // New constants could potentially violate dominance requirements.
+    IndexPool localPool;
+
     // Emit 'tensor.empty' op
     SmallVector<OpFoldResult> outputTensorShape;
     for (auto index : llvm::seq<int64_t>(0, rank)) {
       auto size = index == dim ? targetSize
-                               : getOrFoldTensorDim(rewriter, loc, indexPool,
+                               : getOrFoldTensorDim(rewriter, loc, localPool,
                                                     operand, index);
       outputTensorShape.push_back(size);
     }
@@ -812,9 +816,9 @@ static Value broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc,
                                         IndexPool &indexPool, Value operand,
                                         ArrayRef<OpFoldResult> targetShape,
                                         ArrayRef<Value> masterOperands) {
-  size_t rank = operand.getType().cast<RankedTensorType>().getRank();
-  assert(targetShape.size() == rank);
-  assert(masterOperands.size() == rank);
+  int64_t rank = operand.getType().cast<RankedTensorType>().getRank();
+  assert((int64_t)targetShape.size() == rank);
+  assert((int64_t)masterOperands.size() == rank);
   for (auto index : llvm::seq<int64_t>(0, rank))
     operand =
         broadcastDynamicDimension(rewriter, loc, indexPool, operand, index,
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index 3f39cbf03a9a80..8fb8d16486560c 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -26,6 +26,8 @@
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
+#include "mlir/Interfaces/InferTypeOpInterface.h"
+
 #include <numeric>
 #include <type_traits>
 
@@ -34,7 +36,7 @@ using namespace mlir::tosa;
 
 static mlir::Value applyPad(Location loc, Value input, ArrayRef<int64_t> pad,
                             TypedAttr padAttr, OpBuilder &rewriter) {
-  // Input should be padded if necessary.
+  // Input should be padded only if necessary.
   if (llvm::all_of(pad, [](int64_t p) { return p == 0; }))
     return input;
 
@@ -47,7 +49,7 @@ static mlir::Value applyPad(Location loc, Value input, ArrayRef<int64_t> pad,
   SmallVector<int64_t, 4> paddedShape;
   SmallVector<OpFoldResult, 8> lowIndices;
   SmallVector<OpFoldResult, 8> highIndices;
-  for (int i = 0, s = inputShape.size(); i < s; i++) {
+  for (size_t i : llvm::seq(inputShape.size())) {
     auto lowPad = pad[i * 2];
     auto highPad = pad[i * 2 + 1];
     if (ShapedType::isDynamic(inputShape[i]))
@@ -131,20 +133,19 @@ static mlir::Value linalgBroadcastAndMaybeExtSI(PatternRewriter &rewriter,
 
 static mlir::Value reifyConstantDim(int64_t attr,
                                     ImplicitLocOpBuilder &builder) {
-  return builder.createOrFold<arith::IndexCastOp>(
-      builder.getIndexType(),
-      builder.create<arith::ConstantOp>(builder.getI64IntegerAttr(attr)));
+  return builder.create<arith::ConstantIndexOp>(attr);
 }
 
 // Calculating the output width/height using the formula:
 // H = ((IH+pad_top+pad_bottom-(dilation_y*(KH-1)+1))/stride_y)+1
 // W = ((IW+pad_left+pad_right-(dilation_x*(KW-1)+1))/stride_x)+1
 
-static mlir::Value getConvOutputDim(Location loc, Value inputDim,
-                                    int64_t padBeforeAttr, int64_t padAfterAttr,
-                                    Value kernelDim, int64_t strideAttr,
-                                    int64_t dilationAttr, Type inputETy,
-                                    OpBuilder &rewriter) {
+static mlir::Value getConvOrPoolOutputDim(Location loc, Value inputDim,
+                                          int64_t padBeforeAttr,
+                                          int64_t padAfterAttr, Value kernelDim,
+                                          int64_t strideAttr,
+                                          int64_t dilationAttr,
+                                          OpBuilder &rewriter) {
   ImplicitLocOpBuilder builder(loc, rewriter);
   auto one = rewriter.create<arith::ConstantOp>(
       loc, IntegerAttr::get(inputDim.getType(), 1));
@@ -171,7 +172,6 @@ static SmallVector<Value> inferDynamicDimsForConv(
     ArrayRef<int64_t> dilationAttr, ArrayRef<int64_t> inputSizeDims,
     ArrayRef<int64_t> kernelSizeDims, OpBuilder &rewriter) {
   ShapedType inputTy = cast<ShapedType>(input.getType());
-  Type inputETy = inputTy.getElementType();
   int64_t inputRank = inputTy.getRank();
 
   SmallVector<Value> dynDims;
@@ -190,8 +190,8 @@ static SmallVector<Value> inferDynamicDimsForConv(
           rewriter.create<tensor::DimOp>(loc, weight, kernelDim);
       // H = F(IH, pad_top, pad_bottom, dilation_y, KH, stride_y)
       dynDims[inputDim] =
-          getConvOutputDim(loc, initDynDim, padTop, padBottom, kernelDynDim,
-                           stride, dilation, inputETy, rewriter);
+          getConvOrPoolOutputDim(loc, initDynDim, padTop, padBottom,
+                                 kernelDynDim, stride, dilation, rewriter);
     }
   }
 
@@ -685,20 +685,61 @@ class MaxPool2dConverter : public OpRewritePattern<tosa::MaxPool2dOp> {
 public:
   using OpRewritePattern<tosa::MaxPool2dOp>::OpRewritePattern;
 
+  // Compute the dynamic output sizes of the maxpool operation.
+  static SmallVector<Value>
+  computeDynamicOutputSizes(tosa::MaxPool2dOp op, PatternRewriter &rewriter) {
+    TensorType resultTy = op.getType();
+    Location loc = op.getLoc();
+
+    TypedValue<TensorType> input = op.getInput();
+    ArrayRef<int64_t> kernel = op.getKernel();
+    ArrayRef<int64_t> pad = op.getPad();
+    ArrayRef<int64_t> stride = op.getStride();
+
+    SmallVector<Value> dynamicDims;
+
+    // Batch dimension
+    if (resultTy.isDynamicDim(0))
+      dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 0));
+
+    // Height/width dimensions
+    for (int64_t dim : {1, 2}) {
+      if (!resultTy.isDynamicDim(dim))
+        continue;
+
+      // Index into the attribute arrays
+      int64_t index = dim - 1;
+
+      // Input height/width
+      Value ihw = rewriter.create<tensor::DimOp>(loc, input, dim);
+
+      // Kernel height/width
+      Value khw = rewriter.create<arith::ConstantIndexOp>(loc, kernel[index]);
+
+      // Output height/width
+      Value ohw = getConvOrPoolOutputDim(loc, ihw, pad[index * 2],
+                                         pad[index * 2 + 1], khw, stride[index],
+                                         /*dilationAttr=*/1, rewriter);
+      dynamicDims.push_back(ohw);
+    }
+
+    // Channel dimension
+    if (resultTy.isDynamicDim(3))
+      dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 3));
+
+    return dynamicDims;
+  }
+
   LogicalResult matchAndRewrite(tosa::MaxPool2dOp op,
                                 PatternRewriter &rewriter) const final {
     Location loc = op.getLoc();
-    Value input = op.getInput();
-    ShapedType inputTy = cast<ShapedType>(input.getType());
+    TypedValue<TensorType> input = op.getInput();
+    ShapedType inputTy = input.getType();
 
-    ShapedType resultTy = cast<ShapedType>(op.getType());
+    ShapedType resultTy = op.getType();
     Type resultETy = inputTy.getElementType();
 
-    auto dynamicDimsOr =
-        checkHasDynamicBatchDims(rewriter, op, {input, op.getOutput()});
-    if (!dynamicDimsOr.has_value())
-      return failure();
-    SmallVector<Value> dynamicDims = *dynamicDimsOr;
+    SmallVector<Value> dynamicDims = computeDynamicOutputSizes(op, rewriter);
 
     // Determine what the initial value needs to be for the max pool op.
     TypedAttr initialAttr;
@@ -721,6 +762,7 @@ class MaxPool2dConverter : public OpRewritePattern<tosa::MaxPool2dOp> {
     pad.resize(2, 0);
     llvm::append_range(pad, op.getPad());
     pad.resize(pad.size() + 2, 0);
+
     Value paddedInput = applyPad(loc, input, pad, initialAttr, rewriter);
 
     Value initialValue = rewriter.create<arith::ConstantOp>(loc, initialAttr);
@@ -736,9 +778,7 @@ class MaxPool2dConverter : public OpRewritePattern<tosa::MaxPool2dOp> {
         loc, resultTy.getShape(), resultTy.getElementType(), dynamicDims);
 
     Value filledEmptyTensor =
-        rewriter
-            .create<linalg::FillOp>(loc, ValueRange{initialValue},
-                                    ValueRange{emptyTensor})
+        rewriter.create<linalg::FillOp>(loc, initialValue, emptyTensor)
             .result();
 
     Value fakeWindowDims =
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index e64903671e599f..b4049000c50dc8 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -1,5 +1,6 @@
 // RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg-named))" %s -verify-diagnostics -o -| FileCheck %s
 // RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg-named{prefer-conv2d-kernel-layout-hwcf=true}))" %s -verify-diagnostics -o -| FileCheck --check-prefix="HWCF" %s
+// RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg-named,cse))" %s -verify-diagnostics -o -| FileCheck --check-prefix="CHECK-CSE" %s
 
 // CHECK-LABEL: @matmul
 func.func @matmul(%arg0: tensor<1x5x3xf32>, %arg1: tensor<1x3x6xf32>) -> (tensor<1x5x6xf32>) {
@@ -215,6 +216,59 @@ func.func @max_pool_i32(%arg0: tensor<1x6x34x62xi32>) -> () {
   return
 }
 
+// CHECK-CSE-LABEL: @max_pool_all_dynamic
+func.func @max_pool_all_dynamic(%arg0: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  // Batch size
+  // CHECK-CSE: %[[C0:.+]] = arith.constant 0 : index
+  // CHECK-CSE: %[[BATCH:.+]] = tensor.dim %arg0, %[[C0]] : tensor<?x?x?x?xf32>
+
+  // Compute output height
+  // CHECK-CSE: %[[C1:.+]] = arith.constant 1 : index
+  // CHECK-CSE: %[[IH:.+]] = tensor.dim %arg0, %[[C1]] : tensor<?x?x?x?xf32>
+  // CHECK-CSE: %[[C2:.+]] = arith.constant 2 : index
+  // CHECK-CSE: %[[PADDED_BEFORE:.+]] = arith.addi %[[IH]], %[[C0]] : index
+  // CHECK-CSE: %[[PADDED_AFTER:.+]] = arith.addi %[[PADDED_BEFORE]], %[[C0]] : index
+  // CHECK-CSE: %[[SUB_ONE:.+]] = arith.subi %[[C2]], %[[C1]] : index
+  // CHECK-CSE: %[[DILATED:.+]] = arith.muli %[[C1]], %[[SUB_ONE]] : index
+  // CHECK-CSE: %[[ADD_ONE:.+]] = arith.addi %[[DILATED]], %[[C1]] : index
+  // CHECK-CSE: %[[SUBTRACT:.+]] = arith.subi %[[PADDED_AFTER]], %[[ADD_ONE]] : index
+  // CHECK-CSE: %[[DIVIDE:.+]] = arith.divui %[[SUBTRACT]], %[[C1]] : index
+  // CHECK-CSE: %[[HEIGHT:.+]] = arith.addi %[[DIVIDE]], %[[C1]] : index
+
+  // Compute output width
+  // CHECK-CSE: %[[IW:.+]] = tensor.dim %arg0, %[[C2]] : tensor<?x?x?x?xf32>
+  // CHECK-CSE: %[[C5:.+]] = arith.constant 5 : index
+  // CHECK-CSE: %[[PADDED_BEFORE:.+]] = arith.addi %[[IW]], %[[C2]] : index
+  // CHECK-CSE: %[[PADDED_AFTER:.+]] = arith.addi %[[PADDED_BEFORE]], %[[C2]] : index
+  // CHECK-CSE: %[[SUB_ONE:.+]] = arith.subi %[[C5]], %[[C1]] : index
+  // CHECK-CSE: %[[DILATED:.+]] = arith.muli %[[C1]], %[[SUB_ONE]] : index
+  // CHECK-CSE: %[[ADD_ONE:.+]] = arith.addi %[[DILATED]], %[[C1]] : index
+  // CHECK-CSE: %[[SUBTRACT:.+]] = arith.subi %[[PADDED_AFTER]], %[[ADD_ONE]] : index
+  // CHECK-CSE: %[[DIVIDE:.+]] = arith.divui %[[SUBTRACT]], %[[C1]] : index
+  // CHECK-CSE: %[[WIDTH:.+]] = arith.addi %[[DIVIDE]], %[[C1]] : index
+
+  // Channel size
+  // CHECK-CSE: %[[C3:.+]] = arith.constant 3 : index
+  // CHECK-CSE: %[[CHANNEL:.+]] = tensor.dim %arg0, %[[C3]] : tensor<?x?x?x?xf32>
+
+  // Pad the input
+  // CHECK-CSE: %[[FLOAT_MIN:.+]] = arith.constant -3.40282347E+38 : f32
+  // CHECK-CSE: %[[PADDED:.+]] = tensor.pad %arg0 low[0, 0, 2, 0] high[0, 0, 2, 0] {
+  // CHECK-CSE:   tensor.yield %[[FLOAT_MIN]] : f32
+
+  // Allocate the output and fill with minimum value
+  // CHECK-CSE: %[[INIT:.+]] = tensor.empty(%[[BATCH]], %[[HEIGHT]], %[[WIDTH]], %[[CHANNEL]]) : tensor<?x?x?x?xf32>
+  // CHECK-CSE: %[[FILL:.+]] = linalg.fill ins(%[[FLOAT_MIN]] : f32) outs(%[[INIT]] : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+  // CHECK-CSE: %[[FAKE_WINDOW:.+]] = tensor.empty() : tensor<2x5xf32>
+
+  // Compute max pool
+  // CHECK-CSE: %[[OUT:.+]] = linalg.pooling_nhwc_max {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%[[PADDED]], %[[FAKE_WINDOW]] : tensor<?x?x?x?xf32>, tensor<2x5xf32>) outs(%[[FILL]] : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+  // CHECK-CSE: return %[[OUT]]
+
+  %0 = tosa.max_pool2d %arg0 {kernel = array<i64: 2, 5>, pad = array<i64: 0, 0, 2, 2>, stride = array<i64: 1, 1>} : (tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?xf32>
+}
+
 // -----
 
 // CHECK-LABEL: @avg_pool_f32
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 1fa783f05f04ee..445e8be47678d5 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -270,7 +270,8 @@ func.func @test_add_2d_all_dynamic(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32
   // CHECK: %[[VAL_0:.*]] = tensor.dim %[[ARG0]], %[[CONST0]] : tensor<?x?xf32>
   // CHECK: %[[VAL_1:.*]] = arith.cmpi eq, %[[VAL_0]], %[[CONST1]] : index
   // CHECK: %[[ARG0_DIM0_BROADCAST:.*]] = scf.if %[[VAL_1]] -> (tensor<?x?xf32>) {
-  // CHECK:   %[[VAL_2:.*]] = tensor.dim %[[ARG0]], %[[CONST1]] : tensor<?x?xf32>
+  // CHECK:   %[[LOCAL_CONST1:.*]] = arith.constant 1 : index
+  // CHECK:   %[[VAL_2:.*]] = tensor.dim %[[ARG0]], %[[LOCAL_CONST1]] : tensor<?x?xf32>
   // CHECK:   %[[VAL_3:.*]] = tensor.empty(%[[MAX_DIM0]], %[[VAL_2]]) : tensor<?x?xf32>
   // CHECK:   %[[VAL_4:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG0]] : tensor<?x?xf32>) outs(%[[VAL_3]] : tensor<?x?xf32>) {
   // CHECK:   ^bb0(%[[VAL_5:.*]]: f32, %[[VAL_6:.*]]: f32):
@@ -284,7 +285,8 @@ func.func @test_add_2d_all_dynamic(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32
   // CHECK: %[[VAL_7:.*]] = tensor.dim %[[ARG0_DIM0_BROADCAST]], %[[CONST1]] : tensor<?x?xf32>
   // CHECK: %[[VAL_8:.*]] = arith.cmpi eq, %[[VAL_7]], %[[CONST1]] : index
   // CHECK: %[[ARG0_DIM1_BROADCAST:.*]] = scf.if %[[VAL_8]] -> (tensor<?x?xf32>) {
-  // CHECK:   %[[VAL_9:.*]] = tensor.dim %[[ARG0_DIM0_BROADCAST]], %[[CONST0]] : tensor<?x?xf32>
+  // CHECK:   %[[LOCAL_CONST0:.*]] = arith.constant 0 : index
+  // CHECK:   %[[VAL_9:.*]] = tensor.dim %[[ARG0_DIM0_BROADCAST]], %[[LOCAL_CONST0]] : tensor<?x?xf32>
   // CHECK:   %[[VAL_10:.*]] = tensor.empty(%[[VAL_9]], %[[MAX_DIM1]]) : tensor<?x?xf32>
   // CHECK:   %[[VAL_11:.*]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG0_DIM0_BROADCAST]] : tensor<?x?xf32>) outs(%[[VAL_10]] : tensor<?x?xf32>) {
   // CHECK:   ^bb0(%[[VAL_12:.*]]: f32, %[[VAL_13:.*]]: f32):
@@ -298,7 +300,8 @@ func.func @test_add_2d_all_dynamic(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32
   // CHECK: %[[VAL_14:.*]] = tensor.dim %[[ARG1]], %[[CONST0]] : tensor<?x?xf32>
   // CHECK: %[[VAL_15:.*]] = arith.cmpi eq, %[[VAL_14]], %[[CONST1]] : index
   // CHECK: %[[ARG1_DIM0_BROADCAST:.*]] = scf.if %[[VAL_15]] -> (tensor<?x?xf32>) {
-  // CHECK:   %[[VAL_16:.*]] = tensor.dim %[[ARG1]], %[[CONST1]] : tensor<?x?xf32>
+  // CHECK:   %[[LOCAL_CONST1:.*]] = arith.constant 1 : index
+  // CHECK:   %[[VAL_16:.*]] = tensor.dim %[[ARG1]], %[[LOCAL_CONST1]] : tensor<?x?xf32>
   // CHECK:   %[[VAL_17:.*]] = tensor.empty(%[[MAX_DIM0]], %[[VAL_16]]) : tensor<?x?xf32>
   // CHECK:   %[[VAL_18:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG1]] : tensor<?x?xf32>) outs(%[[VAL_17]] : tensor<?x?xf32>) {
   // CHECK:   ^bb0(%[[VAL_19:.*]]: f32, %[[VAL_20:.*]]: f32):
@@ -312,7 +315,8 @@ func.func @test_add_2d_all_dynamic(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32
   // CHECK: %[[VAL_21:.*]] = tensor.dim %[[ARG1_DIM0_BROADCAST]], %[[CONST1]] : tensor<?x?xf32>
   // CHECK: %[[VAL_22:.*]] = arith.cmpi eq, %[[VAL_21]], %[[CONST1]] : index
   // CHECK: %[[ARG1_DIM1_BROADCAST:.*]] = scf.if %[[VAL_22]] -> (tensor<?x?xf32>) {
-  // CHECK:   %[[VAL_23:.*]] = tensor.dim %[[ARG1_DIM0_BROADCAST]], %[[CONST0]] : tensor<?x?xf32>
+  // CHECK:   %[[LOCAL_CONST0:.*]] = arith.constant 0 : index
+  // CHECK:   %[[VAL_23:.*]] = tensor.dim %[[ARG1_DIM0_BROADCAST]], %[[LOCAL_CONST0]] : tensor<?x?xf32>
   // CHECK:   %[[VAL_24:.*]] = tensor.empty(%[[VAL_23]], %[[MAX_DIM1]]) : tensor<?x?xf32>
   // CHECK:   %[[VAL_25:.*]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG1_DIM0_BROADCAST]] : tensor<?x?xf32>) outs(%[[VAL_24]] : tensor<?x?xf32>) {
   // CHECK:   ^bb0(%[[VAL_26:.*]]: f32, %[[VAL_27:.*]]: f32):
diff --git a/mlir/test/Integration/Dialect/Tosa/CPU/test-maxpool-dynamic.mlir b/mlir/test/Integration/Dialect/Tosa/CPU/test-maxpool-dynamic.mlir
new file mode 100644
index 00000000000000..05a78e32b9e115
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Tosa/CPU/test-maxpool-dynamic.mlir
@@ -0,0 +1,112 @@
+// DEFINE: %{tosa-to-linalg-pipeline} = -pass-pipeline="builtin.module(func.func(tosa-infer-shapes,tosa-to-linalg-named,tosa-to-linalg,tosa-to-arith))"
+
+// RUN:   mlir-opt %s \
+// RUN:     %{tosa-to-linalg-pipeline} \
+// RUN: | mlir-opt \
+// RUN:     -one-shot-bufferize="bufferize-function-boundaries" \
+// RUN:     -buffer-deallocation-pipeline \
+// RUN:     -test-lower-to-llvm \
+// RUN: | mlir-cpu-runner \
+// RUN:     -entry-point-result=void \
+// RUN:     -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils \
+// RUN: | FileCheck %s
+
+// Validate that the TOSA lowering for tosa.max_pool2d produces the same results
+// for fully static and fully dynamic inputs.
+
+!tensor_type = tensor<1x4x4x1xf32>
+!memref_type = memref<1x4x4x1xf32>
+
+// Utility functions
+func.func private @printMemrefF32(memref<*xf32>) attributes { llvm.emit_c_interface }
+
+func.func @max_pool_static(%arg0: !tensor_type) -> (!tensor_type) {
+  %0 = tosa.max_pool2d %arg0 {
+    pad = array<i64: 1, 1, 1, 1>,
+    kernel = array<i64: 3, 3>,
+    stride = array<i64: 1, 1>
+  } : (tensor<1x4x4x1xf32>) -> tensor<1x4x4x1xf32>
+  return %0 : tensor<1x4x4x1xf32>
+}
+
+func.func @max_pool_dynamic(%arg0: tensor<?x?x?x?xf32>) -> (tensor<?x?x?x?xf32>) {
+  %0 = tosa.max_pool2d %arg0 {
+    pad = array<i64: 1, 1, 1, 1>,
+    kernel = array<i64: 3, 3>,
+    stride = array<i64: 1, 1>
+  } : (tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?xf32>
+}
+
+// Test harness to compare the results of a fully statically shaped max_pool2d with
+// a fully dynamically shaped max_pool2d on the same inputs.
+func.func @main() {
+  %A = arith.constant dense<[[
+    [[0.0], [0.1], [0.2], [0.3]], // H = 0
+    [[1.0], [1.1], [1.2], [1.3]], // H = 1
+    [[2.0], [2.1], [2.2], [2.3]], // H = 2
+    [[3.0], [3.1], [3.2], [3.3]]  // H = 3
+  ]]> : tensor<1x4x4x1xf32>
+
+  %A_dynamic = tensor.cast %A : !tensor_type to tensor<?x?x?x?xf32>
+
+  // Call both static and dynamically sized variants
+  %result_static  = func.call @max_pool_static(%A) : (!tensor_type) -> !tensor_type
+  %result_dynamic = func.call @max_pool_dynamic(%A_dynamic) : (tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+
+  %static_buffer = bufferization.to_memref %result_static : !memref_type
+  %unranked_static_buffer = memref.cast %static_buffer : !memref_type to memref<*xf32>
+
+  // CHECK: Unranked Memref base@ = {{.*}} rank = 4 offset = 0 sizes = [1, 4, 4, 1] strides = [16, 4, 1, 1] data =
+
+  // CHECK-NEXT: 1.1
+  // CHECK-NEXT: 1.2
+  // CHECK-NEXT: 1.3
+  // CHECK-NEXT: 1.3
+
+  // CHECK-NEXT: 2.1
+  // CHECK-NEXT: 2.2
+  // CHECK-NEXT: 2.3
+  // CHECK-NEXT: 2.3
+
+  // CHECK-NEXT: 3.1
+  // CHECK-NEXT: 3.2
+  // CHECK-NEXT: 3.3
+  // CHECK-NEXT: 3.3
+
+  // CHECK-NEXT: 3.1
+  // CHECK-NEXT: 3.2
+  // CHECK-NEXT: 3.3
+  // CHECK-NEXT: 3.3
+
+  func.call @printMemrefF32(%unranked_static_buffer) : (memref<*xf32>) -> ()
+
+  %dynamic_buffer = bufferization.to_memref %result_dynamic : memref<?x?x?x?xf32>
+  %unranked_dynamic_buffer = memref.cast %dynamic_buffer : memref<?x?x?x?xf32> to memref<*xf32>
+
+  // CHECK: Unranked Memref base@ = {{.*}} rank = 4 offset = 0 sizes = [1, 4, 4, 1] strides = [16, 4, 1, 1] data =
+  // CHECK-NEXT: 1.1
+  // CHECK-NEXT: 1.2
+  // CHECK-NEXT: 1.3
+  // CHECK-NEXT: 1.3
+
+  // CHECK-NEXT: 2.1
+  // CHECK-NEXT: 2.2
+  // CHECK-NEXT: 2.3
+  // CHECK-NEXT: 2.3
+
+  // CHECK-NEXT: 3.1
+  // CHECK-NEXT: 3.2
+  // CHECK-NEXT: 3.3
+  // CHECK-NEXT: 3.3
+
+  // CHECK-NEXT: 3.1
+  // CHECK-NEXT: 3.2
+  // CHECK-NEXT: 3.3
+  // CHECK-NEXT: 3.3
+
+  func.call @printMemrefF32(%unranked_dynamic_buffer) : (memref<*xf32>) -> ()
+
+  return
+}
+



More information about the Mlir-commits mailing list