[Mlir-commits] [mlir] 44865e9 - [mlir][Linalg] Lower padding attribute for pooling ops

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Oct 13 14:11:59 PDT 2020


Author: Alberto Magni
Date: 2020-10-13T14:11:02-07:00
New Revision: 44865e9169f625b78b7198c0f3f11d9330b01f06

URL: https://github.com/llvm/llvm-project/commit/44865e9169f625b78b7198c0f3f11d9330b01f06
DIFF: https://github.com/llvm/llvm-project/commit/44865e9169f625b78b7198c0f3f11d9330b01f06.diff

LOG: [mlir][Linalg] Lower padding attribute for pooling ops

Update linalg-to-loops lowering for pooling operations to perform
padding of the input when specified by the corresponding attribute.

Reviewed By: hanchung

Differential Revision: https://reviews.llvm.org/D88911

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
    mlir/test/Dialect/Linalg/loops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
index 2abf0aed37b7..ee1512c8ec89 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
@@ -222,22 +222,24 @@ static void emitScalarImplementation(ArrayRef<Value> allIvs, FillOp fillOp) {
   nPar > 0 ? O(ivs) = fillOp.value() : O() = fillOp.value();
 }
 
+// Create a padded view into the given `input` tensor using the `indices`
+// to access the tensor. `skipPadding` lists the dimensions for which no padding
+// is needed, e.g. the non-spatial dimensions for convolutions.
 template <typename IndexedValueType>
-static Value getConvOpInput(ConvOp convOp, StdIndexedValue im,
-                            MutableArrayRef<Value> imIdx) {
+Value getPaddedInput(Value input, ArrayRef<Value> indices,
+                     ArrayRef<int> skipPadding, Value padValue) {
   // TODO: add a level of indirection to linalg.generic.
-  if (!convOp.padding())
-    return im(imIdx);
+
+  IndexedValueType indexedInput(input);
 
   auto *context = ScopedContext::getContext();
   Value zeroIndex = std_constant_index(0);
   SmallVector<Value, 8> conds;
   SmallVector<Value, 8> clampedImIdx;
-  for (auto iter : llvm::enumerate(imIdx)) {
+  for (auto iter : llvm::enumerate(indices)) {
     int idx = iter.index();
     auto dim = iter.value();
-    // Only need to iterate over the window dimensions.
-    if (idx == 0 || idx == static_cast<int>(imIdx.size()) - 1) {
+    if (is_contained(skipPadding, idx)) {
       clampedImIdx.push_back(dim);
       continue;
     }
@@ -250,7 +252,7 @@ static Value getConvOpInput(ConvOp convOp, StdIndexedValue im,
       conds.push_back(leftOutOfBound);
     else
       conds.push_back(conds.back() || leftOutOfBound);
-    Value rightBound = std_dim(convOp.input(), idx);
+    Value rightBound = std_dim(input, idx);
     conds.push_back(conds.back() || (sge(dim, rightBound)));
 
     // When padding is involved, the indices will only be shifted to negative,
@@ -262,14 +264,73 @@ static Value getConvOpInput(ConvOp convOp, StdIndexedValue im,
     clampedImIdx.push_back(affine_max(dim.getType(), maxMap, ValueRange{dim}));
   }
 
-  auto &b = ScopedContext::getBuilderRef();
-  Type type = convOp.input().getType().cast<MemRefType>().getElementType();
-  Value zero = std_constant(type, b.getZeroAttr(type));
-  Value readInput = im(clampedImIdx);
+  Value readInput = indexedInput(clampedImIdx);
   return conds.empty() ? readInput
-                       : (Value)std_select(conds.back(), zero, readInput);
+                       : (Value)std_select(conds.back(), padValue, readInput);
+}
+
+namespace {
+
+/// The padding value for a given Op depends on the semantics of the Op.
+/// The identity value for ConvOp and PoolingSumOp is 0, for PoolingMaxOp is
+/// -inf or minInt and for PoolingMinOp is inf or maxInt.
+template <typename OpType>
+Attribute getPadValueAttr(Type type) {
+  llvm_unreachable("Unexpected op type for getPadValueAttr");
+  return {};
+}
+
+template <>
+Attribute getPadValueAttr<PoolingMaxOp>(Type type) {
+  auto &b = ScopedContext::getBuilderRef();
+  if (auto floatType = type.dyn_cast<FloatType>()) {
+    return b.getFloatAttr(
+        floatType,
+        APFloat::getInf(floatType.getFloatSemantics(), /*Negative*/ true));
+  }
+  if (auto intType = type.dyn_cast<IntegerType>()) {
+    unsigned width = intType.getWidth();
+    // The select instruction used to lower the PoolingMax uses a signed
+    // comparison, use a signed constant irrespective of the signedness of the
+    // integer type.
+    return b.getIntegerAttr(intType, APInt::getSignedMinValue(width));
+  }
+  llvm_unreachable("Unsupported data type for PoolingMaxOp");
+  return {};
 }
 
+template <>
+Attribute getPadValueAttr<PoolingMinOp>(Type type) {
+  auto &b = ScopedContext::getBuilderRef();
+  if (auto floatType = type.dyn_cast<FloatType>()) {
+    return b.getFloatAttr(floatType,
+                          APFloat::getInf(floatType.getFloatSemantics()));
+  }
+  if (auto intType = type.dyn_cast<IntegerType>()) {
+    unsigned width = intType.getWidth();
+    // The select instruction used to lower the PoolingMin uses a signed
+    // comparison, use a signed constant irrespective of the signedness of the
+    // integer type.
+    return b.getIntegerAttr(intType, APInt::getSignedMaxValue(width));
+  }
+  llvm_unreachable("Unsupported data type for PoolingMinOp");
+  return {};
+}
+
+template <>
+Attribute getPadValueAttr<PoolingSumOp>(Type type) {
+  auto &b = ScopedContext::getBuilderRef();
+  return b.getZeroAttr(type);
+}
+
+template <>
+Attribute getPadValueAttr<ConvOp>(Type type) {
+  auto &b = ScopedContext::getBuilderRef();
+  return b.getZeroAttr(type);
+}
+
+} // namespace
+
 /// Returns true if `convOp` has a non-zero padding.
 static bool hasPadding(ConvOp convOp) {
   for (unsigned i = 0, e = convOp.getNumSpatialDimensions(); i < e; ++i) {
@@ -301,8 +362,12 @@ static void emitScalarImplementation(ArrayRef<Value> allIvs, ConvOp convOp) {
   // which is not allowed by affine.load. Override to use an StdIndexedValue
   // when there is non-zero padding.
   if (hasPadding(convOp)) {
-    StdIndexedValue I(convOp.input());
-    Value paddedInput = getConvOpInput<IndexedValueType>(convOp, I, imIdx);
+    Type type = convOp.input().getType().cast<MemRefType>().getElementType();
+    Value padValue = std_constant(type, getPadValueAttr<ConvOp>(type));
+    Value paddedInput = getPaddedInput<StdIndexedValue>(
+        convOp.input(), imIdx,
+        /* skipPadding: only the window dimensions need padding */
+        {0, static_cast<int>(imIdx.size()) - 1}, padValue);
     O(oIdx) += F(fIdx) * paddedInput;
   } else {
     IndexedValueType I(convOp.input());
@@ -310,15 +375,36 @@ static void emitScalarImplementation(ArrayRef<Value> allIvs, ConvOp convOp) {
   }
 }
 
+template <typename PoolingOp>
+static bool hasPadding(PoolingOp poolingOp) {
+  for (unsigned i = 0, e = poolingOp.getNumWindowLoops(); i < e; ++i) {
+    if (poolingOp.getLowPad(i) > 0 || poolingOp.getHighPad(i) > 0)
+      return true;
+  }
+  return false;
+}
+
+template <typename IndexedValueType, typename PoolingOp>
+static Value getPoolingInput(PoolingOp op, ArrayRef<Value> inputIndices) {
+  if (hasPadding(op)) {
+    Type type =
+        op.input().getType().template cast<MemRefType>().getElementType();
+    Value padValue = std_constant(type, getPadValueAttr<PoolingOp>(type));
+    return getPaddedInput<StdIndexedValue>(op.input(), inputIndices,
+                                           /*Pad every dimension*/ {},
+                                           padValue);
+  }
+  IndexedValueType input(op.input());
+  return input(inputIndices);
+}
+
 template <typename IndexedValueType, typename OpType>
-static void emitPoolingMinMaxScalarImplementation(ArrayRef<Value> allIvs,
-                                                  OpType op) {
+void emitPoolingMinMaxScalarImplementation(ArrayRef<Value> allIvs, OpType op) {
   InputAndOutputIndices indices = getInputAndOutputIndices(allIvs, op);
   // Emit scalar form.
   IndexedValueType output(op.output());
-  IndexedValueType input(op.input());
   Value lhs = output(indices.outputs);
-  Value rhs = input(indices.inputs);
+  Value rhs = getPoolingInput<IndexedValueType>(op, indices.inputs);
   using edsc::op::sgt;
   using edsc::op::slt;
   Value value = std::is_same<OpType, PoolingMinOp>()
@@ -342,10 +428,11 @@ static void emitScalarImplementation(ArrayRef<Value> allIvs, PoolingMinOp op) {
 template <typename IndexedValueType>
 static void emitScalarImplementation(ArrayRef<Value> allIvs, PoolingSumOp op) {
   auto indices = getInputAndOutputIndices(allIvs, op);
-  IndexedValueType input(op.input()), output(op.output());
+  IndexedValueType output(op.output());
 
   // Emit scalar form.
-  output(indices.outputs) += input(indices.inputs);
+  output(indices.outputs) +=
+      getPoolingInput<IndexedValueType>(op, indices.inputs);
 }
 
 /// Emits the MLIR for the scalar part of the indexed generic op by:

diff  --git a/mlir/test/Dialect/Linalg/loops.mlir b/mlir/test/Dialect/Linalg/loops.mlir
index 348468ad105e..60042556ecac 100644
--- a/mlir/test/Dialect/Linalg/loops.mlir
+++ b/mlir/test/Dialect/Linalg/loops.mlir
@@ -17,6 +17,8 @@
 // CHECKLOOP-DAG: #[[$convLowerBound:.*]] = affine_map<()[s0] -> (s0 floordiv 2)>
 // CHECKLOOP-DAG: #[[$convUpperBound:.*]] = affine_map<()[s0, s1] -> (s1 + s0 floordiv 2 - s0 + 1)>
 // CHECKLOOP-DAG: #[[$convMap:.*]] = affine_map<(d0, d1)[s0] -> (d0 + d1 - s0 floordiv 2)>
+// CHECKLOOP-DAG: #[[$stride1Dilation1Padding1:.*]] = affine_map<(d0, d1) -> (d0 + d1 - 1)>
+// CHECKLOOP-DAG: #[[$stride1Dilation1Padding2:.*]] = affine_map<(d0, d1) -> (d0 + d1 - 2)>
 
 // CHECKPARALLEL-DAG: #[[$strided1D:.*]] = affine_map<(d0)[s0] -> (d0 + s0)>
 // CHECKPARALLEL-DAG: #[[$strided2D:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
@@ -31,6 +33,8 @@
 // CHECKPARALLEL-DAG: #[[$convLowerBound:.*]] = affine_map<()[s0] -> (s0 floordiv 2)>
 // CHECKPARALLEL-DAG: #[[$convUpperBound:.*]] = affine_map<()[s0, s1] -> (s1 + s0 floordiv 2 - s0 + 1)>
 // CHECKPARALLEL-DAG: #[[$convMap:.*]] = affine_map<(d0, d1)[s0] -> (d0 + d1 - s0 floordiv 2)>
+// CHECKPARALLEL-DAG: #[[$stride1Dilation1Padding1:.*]] = affine_map<(d0, d1) -> (d0 + d1 - 1)>
+// CHECKPARALLEL-DAG: #[[$stride1Dilation1Padding2:.*]] = affine_map<(d0, d1) -> (d0 + d1 - 2)>
 
 
 func @matmul(%arg0: memref<?xi8>, %M: index, %N: index, %K: index) {
@@ -470,6 +474,102 @@ func @pooling_max(%arg0: memref<?x?xf32>,
 //       CHECKPARALLEL:         %[[RES:.*]] = select %{{.*}}, %{{.*}}, %{{.*}} : f32
 //       CHECKPARALLEL:         store %[[RES]], %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xf32>
 
+func @pooling_max_padding(%arg0: memref<?x?xf32>,
+                          %arg1: memref<?x?xi32>,
+                          %arg2: memref<?x?xf32>) {
+  linalg.pooling_max(%arg0, %arg1, %arg2) { padding = dense<[[2, 2], [1, 1]]> : tensor<2x2xi64> } :
+    memref<?x?xf32>, memref<?x?xi32>, memref<?x?xf32>
+  return
+}
+// CHECKLOOP-LABEL: func @pooling_max_padding
+//       CHECKLOOP:   %[[PAD:.*]] = constant 0xFF800000 : f32
+//       CHECKLOOP:   %[[WX:.*]] = dim %arg1, %c0 : memref<?x?xi32>
+//       CHECKLOOP:   %[[WY:.*]] = dim %arg1, %c1 : memref<?x?xi32>
+//       CHECKLOOP:   %[[OX:.*]] = dim %arg2, %c0 : memref<?x?xf32>
+//       CHECKLOOP:   %[[OY:.*]] = dim %arg2, %c1 : memref<?x?xf32>
+//       CHECKLOOP:   scf.for %{{.*}} = %{{.*}} to %[[OX]] step %{{.*}} {
+//       CHECKLOOP:     scf.for %{{.*}} = %{{.*}} to %[[OY]] step %{{.*}} {
+//       CHECKLOOP:       scf.for %{{.*}} = %{{.*}} to %[[WX]] step %{{.*}} {
+//       CHECKLOOP:         scf.for %{{.*}} = %{{.*}} to %[[WY]] step %{{.*}} {
+//       CHECKLOOP:           %[[IX:.*]] = affine.apply #[[$stride1Dilation1Padding2]](%{{.*}}, %{{.*}})
+//       CHECKLOOP:           %[[IY:.*]] = affine.apply #[[$stride1Dilation1Padding1]](%{{.*}}, %{{.*}})
+//       CHECKLOOP:           %[[RHS:.*]] = load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xf32>
+//       CHECKLOOP:           %[[IDX:.*]] = affine.max #[[$clampMinMap]](%[[IX]])
+//       CHECKLOOP:           %[[IDY:.*]] = affine.max #[[$clampMinMap]](%[[IY]])
+//       CHECKLOOP:           %[[LHS:.*]] = load %{{.*}}[%[[IDX]], %[[IDY]]] : memref<?x?xf32>
+//       CHECKLOOP:           %[[SEL:.*]] = select %{{.*}}, %[[PAD]], %[[LHS]] : f32
+//       CHECKLOOP:           %[[CMP:.*]] = cmpf "ogt", %[[RHS]], %[[SEL]] : f32
+//       CHECKLOOP:           %[[RES:.*]] = select %{{.*}}, %[[RHS]], %[[SEL]] : f32
+//       CHECKLOOP:           store %[[RES]], %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xf32>
+
+// CHECKPARALLEL-LABEL: func @pooling_max_padding
+//       CHECKPARALLEL:   %[[PAD:.*]] = constant 0xFF800000 : f32
+//       CHECKPARALLEL:   %[[WX:.*]] = dim %arg1, %c0 : memref<?x?xi32>
+//       CHECKPARALLEL:   %[[WY:.*]] = dim %arg1, %c1 : memref<?x?xi32>
+//       CHECKPARALLEL:   %[[OX:.*]] = dim %arg2, %c0 : memref<?x?xf32>
+//       CHECKPARALLEL:   %[[OY:.*]] = dim %arg2, %c1 : memref<?x?xf32>
+//       CHECKPARALLEL:   scf.parallel (%{{.*}}, %{{.*}}) = (%{{.*}}, %{{.*}}) to (%[[OX]], %[[OY]]) step (%{{.*}}, %{{.*}}) {
+//       CHECKPARALLEL:     scf.for %{{.*}} = %{{.*}} to %[[WX]] step %{{.*}} {
+//       CHECKPARALLEL:       scf.for %{{.*}} = %{{.*}} to %[[WY]] step %{{.*}} {
+//       CHECKPARALLEL:         %[[IX:.*]] = affine.apply #[[$stride1Dilation1Padding2]](%{{.*}}, %{{.*}})
+//       CHECKPARALLEL:         %[[IY:.*]] = affine.apply #[[$stride1Dilation1Padding1]](%{{.*}}, %{{.*}})
+//       CHECKPARALLEL:         %[[RHS:.*]] = load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xf32>
+//       CHECKPARALLEL:         %[[IDX:.*]] = affine.max #[[$clampMinMap]](%[[IX]])
+//       CHECKPARALLEL:         %[[IDY:.*]] = affine.max #[[$clampMinMap]](%[[IY]])
+//       CHECKPARALLEL:         %[[LHS:.*]] = load %{{.*}}[%[[IDX]], %[[IDY]]] : memref<?x?xf32>
+//       CHECKPARALLEL:         %[[SEL:.*]] = select %{{.*}}, %[[PAD]], %[[LHS]] : f32
+//       CHECKPARALLEL:         %[[CMP:.*]] = cmpf "ogt", %[[RHS]], %[[SEL]] : f32
+//       CHECKPARALLEL:         %[[RES:.*]] = select %{{.*}}, %[[RHS]], %[[SEL]] : f32
+//       CHECKPARALLEL:         store %[[RES]], %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xf32>
+
+func @pooling_max_padding_i32(%arg0: memref<?x?xi32>,
+                              %arg1: memref<?x?xi32>,
+                              %arg2: memref<?x?xi32>) {
+  linalg.pooling_max(%arg0, %arg1, %arg2) { padding = dense<[[2, 2], [1, 1]]> : tensor<2x2xi64> } :
+    memref<?x?xi32>, memref<?x?xi32>, memref<?x?xi32>
+  return
+}
+// CHECKLOOP-LABEL: func @pooling_max_padding_i32
+//       CHECKLOOP:   %[[PAD:.*]] = constant -2147483648 : i32
+//       CHECKLOOP:   %[[WX:.*]] = dim %arg1, %c0 : memref<?x?xi32>
+//       CHECKLOOP:   %[[WY:.*]] = dim %arg1, %c1 : memref<?x?xi32>
+//       CHECKLOOP:   %[[OX:.*]] = dim %arg2, %c0 : memref<?x?xi32>
+//       CHECKLOOP:   %[[OY:.*]] = dim %arg2, %c1 : memref<?x?xi32>
+//       CHECKLOOP:   scf.for %{{.*}} = %{{.*}} to %[[OX]] step %{{.*}} {
+//       CHECKLOOP:     scf.for %{{.*}} = %{{.*}} to %[[OY]] step %{{.*}} {
+//       CHECKLOOP:       scf.for %{{.*}} = %{{.*}} to %[[WX]] step %{{.*}} {
+//       CHECKLOOP:         scf.for %{{.*}} = %{{.*}} to %[[WY]] step %{{.*}} {
+//       CHECKLOOP:           %[[IX:.*]] = affine.apply #[[$stride1Dilation1Padding2]](%{{.*}}, %{{.*}})
+//       CHECKLOOP:           %[[IY:.*]] = affine.apply #[[$stride1Dilation1Padding1]](%{{.*}}, %{{.*}})
+//       CHECKLOOP:           %[[RHS:.*]] = load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi32>
+//       CHECKLOOP:           %[[IDX:.*]] = affine.max #[[$clampMinMap]](%[[IX]])
+//       CHECKLOOP:           %[[IDY:.*]] = affine.max #[[$clampMinMap]](%[[IY]])
+//       CHECKLOOP:           %[[LHS:.*]] = load %{{.*}}[%[[IDX]], %[[IDY]]] : memref<?x?xi32>
+//       CHECKLOOP:           %[[SEL:.*]] = select %{{.*}}, %[[PAD]], %[[LHS]] : i32
+//       CHECKLOOP:           %[[CMP:.*]] = cmpi "sgt", %[[RHS]], %[[SEL]] : i32
+//       CHECKLOOP:           %[[RES:.*]] = select %{{.*}}, %[[RHS]], %[[SEL]] : i32
+//       CHECKLOOP:           store %[[RES]], %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi32>
+
+// CHECKPARALLEL-LABEL: func @pooling_max_padding_i32
+//       CHECKPARALLEL:   %[[PAD:.*]] = constant -2147483648 : i32
+//       CHECKPARALLEL:   %[[WX:.*]] = dim %arg1, %c0 : memref<?x?xi32>
+//       CHECKPARALLEL:   %[[WY:.*]] = dim %arg1, %c1 : memref<?x?xi32>
+//       CHECKPARALLEL:   %[[OX:.*]] = dim %arg2, %c0 : memref<?x?xi32>
+//       CHECKPARALLEL:   %[[OY:.*]] = dim %arg2, %c1 : memref<?x?xi32>
+//       CHECKPARALLEL:   scf.parallel (%{{.*}}, %{{.*}}) = (%{{.*}}, %{{.*}}) to (%[[OX]], %[[OY]]) step (%{{.*}}, %{{.*}}) {
+//       CHECKPARALLEL:     scf.for %{{.*}} = %{{.*}} to %[[WX]] step %{{.*}} {
+//       CHECKPARALLEL:       scf.for %{{.*}} = %{{.*}} to %[[WY]] step %{{.*}} {
+//       CHECKPARALLEL:         %[[IX:.*]] = affine.apply #[[$stride1Dilation1Padding2]](%{{.*}}, %{{.*}})
+//       CHECKPARALLEL:         %[[IY:.*]] = affine.apply #[[$stride1Dilation1Padding1]](%{{.*}}, %{{.*}})
+//       CHECKPARALLEL:         %[[RHS:.*]] = load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi32>
+//       CHECKPARALLEL:         %[[IDX:.*]] = affine.max #[[$clampMinMap]](%[[IX]])
+//       CHECKPARALLEL:         %[[IDY:.*]] = affine.max #[[$clampMinMap]](%[[IY]])
+//       CHECKPARALLEL:         %[[LHS:.*]] = load %{{.*}}[%[[IDX]], %[[IDY]]] : memref<?x?xi32>
+//       CHECKPARALLEL:         %[[SEL:.*]] = select %{{.*}}, %[[PAD]], %[[LHS]] : i32
+//       CHECKPARALLEL:         %[[CMP:.*]] = cmpi "sgt", %[[RHS]], %[[SEL]] : i32
+//       CHECKPARALLEL:         %[[RES:.*]] = select %{{.*}}, %[[RHS]], %[[SEL]] : i32
+//       CHECKPARALLEL:         store %[[RES]], %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi32>
+
 func @pooling_min(%arg0: memref<?x?xf32>,
                   %arg1: memref<?x?xi32>,
                   %arg2: memref<?x?xf32>) {
@@ -508,6 +608,102 @@ func @pooling_min(%arg0: memref<?x?xf32>,
 //       CHECKPARALLEL:         %[[RES:.*]] = select %{{.*}}, %{{.*}}, %{{.*}} : f32
 //       CHECKPARALLEL:         store %[[RES]], %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xf32>
 
+func @pooling_min_padding(%arg0: memref<?x?xf32>,
+                          %arg1: memref<?x?xi32>,
+                          %arg2: memref<?x?xf32>) {
+  linalg.pooling_min(%arg0, %arg1, %arg2) { padding = dense<[[2, 2], [1, 1]]> : tensor<2x2xi64> } :
+    memref<?x?xf32>, memref<?x?xi32>, memref<?x?xf32>
+  return
+}
+// CHECKLOOP-LABEL: func @pooling_min_padding
+//       CHECKLOOP:   %[[PAD:.*]] = constant 0x7F800000 : f32
+//       CHECKLOOP:   %[[WX:.*]] = dim %arg1, %c0 : memref<?x?xi32>
+//       CHECKLOOP:   %[[WY:.*]] = dim %arg1, %c1 : memref<?x?xi32>
+//       CHECKLOOP:   %[[OX:.*]] = dim %arg2, %c0 : memref<?x?xf32>
+//       CHECKLOOP:   %[[OY:.*]] = dim %arg2, %c1 : memref<?x?xf32>
+//       CHECKLOOP:   scf.for %{{.*}} = %{{.*}} to %[[OX]] step %{{.*}} {
+//       CHECKLOOP:     scf.for %{{.*}} = %{{.*}} to %[[OY]] step %{{.*}} {
+//       CHECKLOOP:       scf.for %{{.*}} = %{{.*}} to %[[WX]] step %{{.*}} {
+//       CHECKLOOP:         scf.for %{{.*}} = %{{.*}} to %[[WY]] step %{{.*}} {
+//       CHECKLOOP:           %[[IX:.*]] = affine.apply #[[$stride1Dilation1Padding2]](%{{.*}}, %{{.*}})
+//       CHECKLOOP:           %[[IY:.*]] = affine.apply #[[$stride1Dilation1Padding1]](%{{.*}}, %{{.*}})
+//       CHECKLOOP:           %[[RHS:.*]] = load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xf32>
+//       CHECKLOOP:           %[[IDX:.*]] = affine.max #[[$clampMinMap]](%[[IX]])
+//       CHECKLOOP:           %[[IDY:.*]] = affine.max #[[$clampMinMap]](%[[IY]])
+//       CHECKLOOP:           %[[LHS:.*]] = load %{{.*}}[%[[IDX]], %[[IDY]]] : memref<?x?xf32>
+//       CHECKLOOP:           %[[SEL:.*]] = select %{{.*}}, %[[PAD]], %[[LHS]] : f32
+//       CHECKLOOP:           %[[CMP:.*]] = cmpf "olt", %[[RHS]], %[[SEL]] : f32
+//       CHECKLOOP:           %[[RES:.*]] = select %{{.*}}, %[[RHS]], %[[SEL]] : f32
+//       CHECKLOOP:           store %[[RES]], %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xf32>
+
+// CHECKPARALLEL-LABEL: func @pooling_min_padding
+//       CHECKPARALLEL:   %[[PAD:.*]] = constant 0x7F800000 : f32
+//       CHECKPARALLEL:   %[[WX:.*]] = dim %arg1, %c0 : memref<?x?xi32>
+//       CHECKPARALLEL:   %[[WY:.*]] = dim %arg1, %c1 : memref<?x?xi32>
+//       CHECKPARALLEL:   %[[OX:.*]] = dim %arg2, %c0 : memref<?x?xf32>
+//       CHECKPARALLEL:   %[[OY:.*]] = dim %arg2, %c1 : memref<?x?xf32>
+//       CHECKPARALLEL:   scf.parallel (%{{.*}}, %{{.*}}) = (%{{.*}}, %{{.*}}) to (%[[OX]], %[[OY]]) step (%{{.*}}, %{{.*}}) {
+//       CHECKPARALLEL:     scf.for %{{.*}} = %{{.*}} to %[[WX]] step %{{.*}} {
+//       CHECKPARALLEL:       scf.for %{{.*}} = %{{.*}} to %[[WY]] step %{{.*}} {
+//       CHECKPARALLEL:         %[[IX:.*]] = affine.apply #[[$stride1Dilation1Padding2]](%{{.*}}, %{{.*}})
+//       CHECKPARALLEL:         %[[IY:.*]] = affine.apply #[[$stride1Dilation1Padding1]](%{{.*}}, %{{.*}})
+//       CHECKPARALLEL:         %[[RHS:.*]] = load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xf32>
+//       CHECKPARALLEL:         %[[IDX:.*]] = affine.max #[[$clampMinMap]](%[[IX]])
+//       CHECKPARALLEL:         %[[IDY:.*]] = affine.max #[[$clampMinMap]](%[[IY]])
+//       CHECKPARALLEL:         %[[LHS:.*]] = load %{{.*}}[%[[IDX]], %[[IDY]]] : memref<?x?xf32>
+//       CHECKPARALLEL:         %[[SEL:.*]] = select %{{.*}}, %[[PAD]], %[[LHS]] : f32
+//       CHECKPARALLEL:         %[[CMP:.*]] = cmpf "olt", %[[RHS]], %[[SEL]] : f32
+//       CHECKPARALLEL:         %[[RES:.*]] = select %{{.*}}, %[[RHS]], %[[SEL]] : f32
+//       CHECKPARALLEL:         store %[[RES]], %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xf32>
+
+func @pooling_min_padding_i32(%arg0: memref<?x?xi32>,
+                              %arg1: memref<?x?xi32>,
+                              %arg2: memref<?x?xi32>) {
+  linalg.pooling_min(%arg0, %arg1, %arg2) { padding = dense<[[2, 2], [1, 1]]> : tensor<2x2xi64> } :
+    memref<?x?xi32>, memref<?x?xi32>, memref<?x?xi32>
+  return
+}
+// CHECKLOOP-LABEL: func @pooling_min_padding_i32
+//       CHECKLOOP:   %[[PAD:.*]] = constant 2147483647 : i32
+//       CHECKLOOP:   %[[WX:.*]] = dim %arg1, %c0 : memref<?x?xi32>
+//       CHECKLOOP:   %[[WY:.*]] = dim %arg1, %c1 : memref<?x?xi32>
+//       CHECKLOOP:   %[[OX:.*]] = dim %arg2, %c0 : memref<?x?xi32>
+//       CHECKLOOP:   %[[OY:.*]] = dim %arg2, %c1 : memref<?x?xi32>
+//       CHECKLOOP:   scf.for %{{.*}} = %{{.*}} to %[[OX]] step %{{.*}} {
+//       CHECKLOOP:     scf.for %{{.*}} = %{{.*}} to %[[OY]] step %{{.*}} {
+//       CHECKLOOP:       scf.for %{{.*}} = %{{.*}} to %[[WX]] step %{{.*}} {
+//       CHECKLOOP:         scf.for %{{.*}} = %{{.*}} to %[[WY]] step %{{.*}} {
+//       CHECKLOOP:           %[[IX:.*]] = affine.apply #[[$stride1Dilation1Padding2]](%{{.*}}, %{{.*}})
+//       CHECKLOOP:           %[[IY:.*]] = affine.apply #[[$stride1Dilation1Padding1]](%{{.*}}, %{{.*}})
+//       CHECKLOOP:           %[[RHS:.*]] = load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi32>
+//       CHECKLOOP:           %[[IDX:.*]] = affine.max #[[$clampMinMap]](%[[IX]])
+//       CHECKLOOP:           %[[IDY:.*]] = affine.max #[[$clampMinMap]](%[[IY]])
+//       CHECKLOOP:           %[[LHS:.*]] = load %{{.*}}[%[[IDX]], %[[IDY]]] : memref<?x?xi32>
+//       CHECKLOOP:           %[[SEL:.*]] = select %{{.*}}, %[[PAD]], %[[LHS]] : i32
+//       CHECKLOOP:           %[[CMP:.*]] = cmpi "slt", %[[RHS]], %[[SEL]] : i32
+//       CHECKLOOP:           %[[RES:.*]] = select %{{.*}}, %[[RHS]], %[[SEL]] : i32
+//       CHECKLOOP:           store %[[RES]], %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi32>
+
+// CHECKPARALLEL-LABEL: func @pooling_min_padding_i32
+//       CHECKPARALLEL:   %[[PAD:.*]] = constant 2147483647 : i32
+//       CHECKPARALLEL:   %[[WX:.*]] = dim %arg1, %c0 : memref<?x?xi32>
+//       CHECKPARALLEL:   %[[WY:.*]] = dim %arg1, %c1 : memref<?x?xi32>
+//       CHECKPARALLEL:   %[[OX:.*]] = dim %arg2, %c0 : memref<?x?xi32>
+//       CHECKPARALLEL:   %[[OY:.*]] = dim %arg2, %c1 : memref<?x?xi32>
+//       CHECKPARALLEL:   scf.parallel (%{{.*}}, %{{.*}}) = (%{{.*}}, %{{.*}}) to (%[[OX]], %[[OY]]) step (%{{.*}}, %{{.*}}) {
+//       CHECKPARALLEL:     scf.for %{{.*}} = %{{.*}} to %[[WX]] step %{{.*}} {
+//       CHECKPARALLEL:       scf.for %{{.*}} = %{{.*}} to %[[WY]] step %{{.*}} {
+//       CHECKPARALLEL:         %[[IX:.*]] = affine.apply #[[$stride1Dilation1Padding2]](%{{.*}}, %{{.*}})
+//       CHECKPARALLEL:         %[[IY:.*]] = affine.apply #[[$stride1Dilation1Padding1]](%{{.*}}, %{{.*}})
+//       CHECKPARALLEL:         %[[RHS:.*]] = load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi32>
+//       CHECKPARALLEL:         %[[IDX:.*]] = affine.max #[[$clampMinMap]](%[[IX]])
+//       CHECKPARALLEL:         %[[IDY:.*]] = affine.max #[[$clampMinMap]](%[[IY]])
+//       CHECKPARALLEL:         %[[LHS:.*]] = load %{{.*}}[%[[IDX]], %[[IDY]]] : memref<?x?xi32>
+//       CHECKPARALLEL:         %[[SEL:.*]] = select %{{.*}}, %[[PAD]], %[[LHS]] : i32
+//       CHECKPARALLEL:         %[[CMP:.*]] = cmpi "slt", %[[RHS]], %[[SEL]] : i32
+//       CHECKPARALLEL:         %[[RES:.*]] = select %{{.*}}, %[[RHS]], %[[SEL]] : i32
+//       CHECKPARALLEL:         store %[[RES]], %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi32>
+
 func @pooling_sum(%arg0: memref<?x?xf32>,
                   %arg1: memref<?x?xi32>,
                   %arg2: memref<?x?xf32>) {
@@ -546,6 +742,98 @@ func @pooling_sum(%arg0: memref<?x?xf32>,
 //       CHECKPARALLEL:         %[[RES:.*]] = addf %[[LHS]], %[[RHS]] : f32
 //       CHECKPARALLEL:         store %[[RES]], %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xf32>
 
+func @pooling_sum_padding(%arg0: memref<?x?xf32>,
+                          %arg1: memref<?x?xi32>,
+                          %arg2: memref<?x?xf32>) {
+  linalg.pooling_sum(%arg0, %arg1, %arg2) { padding = dense<[[2, 2], [1, 1]]> : tensor<2x2xi64> } :
+    memref<?x?xf32>, memref<?x?xi32>, memref<?x?xf32>
+  return
+}
+// CHECKLOOP-LABEL: func @pooling_sum_padding
+//       CHECKLOOP:   %[[PAD:.*]] = constant 0.000000e+00 : f32
+//       CHECKLOOP:   %[[WX:.*]] = dim %arg1, %c0 : memref<?x?xi32>
+//       CHECKLOOP:   %[[WY:.*]] = dim %arg1, %c1 : memref<?x?xi32>
+//       CHECKLOOP:   %[[OX:.*]] = dim %arg2, %c0 : memref<?x?xf32>
+//       CHECKLOOP:   %[[OY:.*]] = dim %arg2, %c1 : memref<?x?xf32>
+//       CHECKLOOP:   scf.for %{{.*}} = %{{.*}} to %[[OX]] step %{{.*}} {
+//       CHECKLOOP:     scf.for %{{.*}} = %{{.*}} to %[[OY]] step %{{.*}} {
+//       CHECKLOOP:       scf.for %{{.*}} = %{{.*}} to %[[WX]] step %{{.*}} {
+//       CHECKLOOP:         scf.for %{{.*}} = %{{.*}} to %[[WY]] step %{{.*}} {
+//       CHECKLOOP:           %[[IX:.*]] = affine.apply #[[$stride1Dilation1Padding2]](%{{.*}}, %{{.*}})
+//       CHECKLOOP:           %[[IY:.*]] = affine.apply #[[$stride1Dilation1Padding1]](%{{.*}}, %{{.*}})
+//       CHECKLOOP:           %[[IDX:.*]] = affine.max #[[$clampMinMap]](%[[IX]])
+//       CHECKLOOP:           %[[IDY:.*]] = affine.max #[[$clampMinMap]](%[[IY]])
+//       CHECKLOOP:           %[[LHS:.*]] = load %{{.*}}[%[[IDX]], %[[IDY]]] : memref<?x?xf32>
+//       CHECKLOOP:           %[[SEL:.*]] = select %{{.*}}, %[[PAD]], %[[LHS]] : f32
+//       CHECKLOOP:           %[[RHS:.*]] = load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xf32>
+//       CHECKLOOP:           %[[RES:.*]] = addf %[[RHS]], %[[SEL]] : f32
+//       CHECKLOOP:           store %[[RES]], %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xf32>
+
+// CHECKPARALLEL-LABEL: func @pooling_sum_padding
+//       CHECKPARALLEL:   %[[PAD:.*]] = constant 0.000000e+00 : f32
+//       CHECKPARALLEL:   %[[WX:.*]] = dim %arg1, %c0 : memref<?x?xi32>
+//       CHECKPARALLEL:   %[[WY:.*]] = dim %arg1, %c1 : memref<?x?xi32>
+//       CHECKPARALLEL:   %[[OX:.*]] = dim %arg2, %c0 : memref<?x?xf32>
+//       CHECKPARALLEL:   %[[OY:.*]] = dim %arg2, %c1 : memref<?x?xf32>
+//       CHECKPARALLEL:   scf.parallel (%{{.*}}, %{{.*}}) = (%{{.*}}, %{{.*}}) to (%[[OX]], %[[OY]]) step (%{{.*}}, %{{.*}}) {
+//       CHECKPARALLEL:     scf.for %{{.*}} = %{{.*}} to %[[WX]] step %{{.*}} {
+//       CHECKPARALLEL:       scf.for %{{.*}} = %{{.*}} to %[[WY]] step %{{.*}} {
+//       CHECKPARALLEL:         %[[IX:.*]] = affine.apply #[[$stride1Dilation1Padding2]](%{{.*}}, %{{.*}})
+//       CHECKPARALLEL:         %[[IY:.*]] = affine.apply #[[$stride1Dilation1Padding1]](%{{.*}}, %{{.*}})
+//       CHECKPARALLEL:         %[[IDX:.*]] = affine.max #[[$clampMinMap]](%[[IX]])
+//       CHECKPARALLEL:         %[[IDY:.*]] = affine.max #[[$clampMinMap]](%[[IY]])
+//       CHECKPARALLEL:         %[[LHS:.*]] = load %{{.*}}[%[[IDX]], %[[IDY]]] : memref<?x?xf32>
+//       CHECKPARALLEL:         %[[SEL:.*]] = select %{{.*}}, %[[PAD]], %[[LHS]] : f32
+//       CHECKPARALLEL:         %[[RHS:.*]] = load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xf32>
+//       CHECKPARALLEL:         %[[RES:.*]] = addf %[[RHS]], %[[SEL]] : f32
+//       CHECKPARALLEL:         store %[[RES]], %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xf32>
+
+func @pooling_sum_padding_i32(%arg0: memref<?x?xi32>,
+                              %arg1: memref<?x?xi32>,
+                              %arg2: memref<?x?xi32>) {
+  linalg.pooling_sum(%arg0, %arg1, %arg2) { padding = dense<[[2, 2], [1, 1]]> : tensor<2x2xi64> } :
+    memref<?x?xi32>, memref<?x?xi32>, memref<?x?xi32>
+  return
+}
+// CHECKLOOP-LABEL: func @pooling_sum_padding_i32
+//       CHECKLOOP:   %[[PAD:.*]] = constant 0 : i32
+//       CHECKLOOP:   %[[WX:.*]] = dim %arg1, %c0 : memref<?x?xi32>
+//       CHECKLOOP:   %[[WY:.*]] = dim %arg1, %c1 : memref<?x?xi32>
+//       CHECKLOOP:   %[[OX:.*]] = dim %arg2, %c0 : memref<?x?xi32>
+//       CHECKLOOP:   %[[OY:.*]] = dim %arg2, %c1 : memref<?x?xi32>
+//       CHECKLOOP:   scf.for %{{.*}} = %{{.*}} to %[[OX]] step %{{.*}} {
+//       CHECKLOOP:     scf.for %{{.*}} = %{{.*}} to %[[OY]] step %{{.*}} {
+//       CHECKLOOP:       scf.for %{{.*}} = %{{.*}} to %[[WX]] step %{{.*}} {
+//       CHECKLOOP:         scf.for %{{.*}} = %{{.*}} to %[[WY]] step %{{.*}} {
+//       CHECKLOOP:           %[[IX:.*]] = affine.apply #[[$stride1Dilation1Padding2]](%{{.*}}, %{{.*}})
+//       CHECKLOOP:           %[[IY:.*]] = affine.apply #[[$stride1Dilation1Padding1]](%{{.*}}, %{{.*}})
+//       CHECKLOOP:           %[[IDX:.*]] = affine.max #[[$clampMinMap]](%[[IX]])
+//       CHECKLOOP:           %[[IDY:.*]] = affine.max #[[$clampMinMap]](%[[IY]])
+//       CHECKLOOP:           %[[LHS:.*]] = load %{{.*}}[%[[IDX]], %[[IDY]]] : memref<?x?xi32>
+//       CHECKLOOP:           %[[SEL:.*]] = select %{{.*}}, %[[PAD]], %[[LHS]] : i32
+//       CHECKLOOP:           %[[RHS:.*]] = load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi32>
+//       CHECKLOOP:           %[[RES:.*]] = addi %[[RHS]], %[[SEL]] : i32
+//       CHECKLOOP:           store %[[RES]], %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi32>
+
+// CHECKPARALLEL-LABEL: func @pooling_sum_padding_i32
+//       CHECKPARALLEL:   %[[PAD:.*]] = constant 0 : i32
+//       CHECKPARALLEL:   %[[WX:.*]] = dim %arg1, %c0 : memref<?x?xi32>
+//       CHECKPARALLEL:   %[[WY:.*]] = dim %arg1, %c1 : memref<?x?xi32>
+//       CHECKPARALLEL:   %[[OX:.*]] = dim %arg2, %c0 : memref<?x?xi32>
+//       CHECKPARALLEL:   %[[OY:.*]] = dim %arg2, %c1 : memref<?x?xi32>
+//       CHECKPARALLEL:   scf.parallel (%{{.*}}, %{{.*}}) = (%{{.*}}, %{{.*}}) to (%[[OX]], %[[OY]]) step (%{{.*}}, %{{.*}}) {
+//       CHECKPARALLEL:     scf.for %{{.*}} = %{{.*}} to %[[WX]] step %{{.*}} {
+//       CHECKPARALLEL:       scf.for %{{.*}} = %{{.*}} to %[[WY]] step %{{.*}} {
+//       CHECKPARALLEL:         %[[IX:.*]] = affine.apply #[[$stride1Dilation1Padding2]](%{{.*}}, %{{.*}})
+//       CHECKPARALLEL:         %[[IY:.*]] = affine.apply #[[$stride1Dilation1Padding1]](%{{.*}}, %{{.*}})
+//       CHECKPARALLEL:         %[[IDX:.*]] = affine.max #[[$clampMinMap]](%[[IX]])
+//       CHECKPARALLEL:         %[[IDY:.*]] = affine.max #[[$clampMinMap]](%[[IY]])
+//       CHECKPARALLEL:         %[[LHS:.*]] = load %{{.*}}[%[[IDX]], %[[IDY]]] : memref<?x?xi32>
+//       CHECKPARALLEL:         %[[SEL:.*]] = select %{{.*}}, %[[PAD]], %[[LHS]] : i32
+//       CHECKPARALLEL:         %[[RHS:.*]] = load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi32>
+//       CHECKPARALLEL:         %[[RES:.*]] = addi %[[RHS]], %[[SEL]] : i32
+//       CHECKPARALLEL:         store %[[RES]], %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi32>
+
 #accesses = [
   affine_map<(i, j, k) -> (i, j)>,
   affine_map<(i, j, k) -> (i, j, k)>,


        


More information about the Mlir-commits mailing list