[Mlir-commits] [mlir] 69ddee1 - [mlir][Linalg] Introduce linalg.pooling_min/max/sum op.

Hanhan Wang llvmlistbot at llvm.org
Tue Mar 31 21:22:44 PDT 2020


Author: Hanhan Wang
Date: 2020-03-31T21:21:54-07:00
New Revision: 69ddee1d2aadaa0b9ac4549f366d1bf5701a65f0

URL: https://github.com/llvm/llvm-project/commit/69ddee1d2aadaa0b9ac4549f366d1bf5701a65f0
DIFF: https://github.com/llvm/llvm-project/commit/69ddee1d2aadaa0b9ac4549f366d1bf5701a65f0.diff

LOG: [mlir][Linalg] Introduce linalg.pooling_min/max/sum op.

Summary:
Performs an N-D pooling operation similarly to the description in the TF
documentation:
https://www.tensorflow.org/api_docs/python/tf/nn/pool

Unlike the TF op, this operation has no batch or channel dimensions: it only
takes operands of rank `N`.

```
  output[x[0], ..., x[N-1]] =
    REDUCE_{z[0], ..., z[N-1]}
      input[
            x[0] * strides[0] - pad_before[0] + dilation_rate[0]*z[0],
            ...
            x[N-1]*strides[N-1] - pad_before[N-1] + dilation_rate[N-1]*z[N-1]
            ],
```

The optional attributes are:
  - strides: an i64 array specifying the stride (i.e. step) for window
    loops.
  - dilations: an i64 array specifying the filter upsampling/input
    downsampling rate.
  - padding: an i64 array of pairs (low, high) specifying the number of
    elements to pad along a dimension.

If strides or dilations attributes are missing then the default value is
one for each of the input dimensions. Similarly, padding values are zero
for both low and high in each of the dimensions, if not specified.

Differential Revision: https://reviews.llvm.org/D76414

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
    mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
    mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
    mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp
    mlir/test/Dialect/Linalg/invalid.mlir
    mlir/test/Dialect/Linalg/loops.mlir
    mlir/test/Dialect/Linalg/roundtrip.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
index 7756a08d5cb2..77d9d9fc2631 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
@@ -29,6 +29,9 @@ namespace mlir {
 namespace linalg {
 
 class ConvOp;
+class PoolingMaxOp;
+class PoolingMinOp;
+class PoolingSumOp;
 
 /// Returns the name mangled library call name to disambiguate between different
 /// overloads at the C level. The name mangling scheme is basic and uses MLIR
@@ -60,12 +63,15 @@ std::string generateLibraryCallName(Operation *op);
 SmallVector<AffineExpr, 4> makeAffineDimExprs(unsigned num, unsigned &startIdx,
                                               MLIRContext *context);
 
-/// Builds the indexing expressions for a ConvOp `op`. Returns the vector of
-/// AffineMaps representing:
-///   `stride[i] * xs[i] + dilation[i] * zs[i] - pad_low[i]`
-SmallVector<AffineExpr, 4> weightedConvInputIndex(ConvOp op,
-                                                  ArrayRef<AffineExpr> xs,
-                                                  ArrayRef<AffineExpr> zs);
+/// Builds the input indexing expressions for a ConvOp/PoolingOp `op`. Returns
+/// one AffineExpr per entry of `outputDims`, representing:
+///   `stride[i] * outputDims[i] + dilation[i] * windowDims[i] - pad_low[i]`
+/// `outputDims` and `windowDims` must have the same size. Defined in
+/// LinalgOps.cpp, which explicitly instantiates it for each supported op.
+template <typename PoolingOp>
+SmallVector<AffineExpr, 4>
+weightedPoolingInputIndex(PoolingOp op, ArrayRef<AffineExpr> outputDims,
+                          ArrayRef<AffineExpr> windowDims);
 
 /// Returns `maybeMap.get()` if `maybeMap` is set, otherwise returns the
 /// symbol-less identity map of `rank`.

diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index ab53fc30aca8..31b89bc1b2bf 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -251,7 +251,72 @@ def MatmulOp : LinalgStructured_Op<"matmul", [NInputs<2>, NOutputs<1>]> {
   let hasFolder = 1;
 }
 
-def ConvOp : LinalgStructured_Op<"conv", [NInputs<2>, NOutputs<1>]> {
+/// A base class for pooling-like operations such as conv. The arguments must
+/// contain the optional attributes `strides`, `dilations` and `padding` with
+/// the following types:
+///   OptionalAttr<I64ArrayAttr>:$strides
+///   OptionalAttr<I64ArrayAttr>:$dilations
+///   OptionalAttr<I64ElementsAttr>:$padding
+/// `strides` denotes the step of each window along the dimension.
+class PoolingBase_Op<string mnemonic, list<OpTrait> props>
+  : LinalgStructured_Op<mnemonic, props> {
+  let description = [{
+    Performs an N-D pooling operation similarly to the description in the TF
+    documentation:
+    https://www.tensorflow.org/api_docs/python/tf/nn/pool
+
+    Unlike the TF op, this operation has no batch or channel dimensions: it
+    only takes operands of rank `N`.
+
+    ```
+      output[x[0], ..., x[N-1]] =
+        REDUCE_{z[0], ..., z[N-1]}
+          input[
+                x[0] * strides[0] - pad_before[0] + dilation_rate[0]*z[0],
+                ...
+                x[N-1]*strides[N-1] - pad_before[N-1] + dilation_rate[N-1]*z[N-1]
+                ],
+    ```
+
+    The optional attributes are:
+      - strides: an i64 array specifying the stride (i.e. step) for window
+        loops.
+      - dilations: an i64 array specifying the filter upsampling/input
+        downsampling rate.
+      - padding: an i64 array of pairs (low, high) specifying the number of
+        elements to pad along a dimension.
+
+    If strides or dilations attributes are missing then the default value is
+    one for each of the input dimensions. Similarly, padding values are zero
+    for both low and high in each of the dimensions, if not specified.
+  }];
+
+  // Accessors shared by all pooling-like ops. A missing `strides` or
+  // `dilations` attribute yields 1 and a missing `padding` attribute yields 0.
+  code commonUtils = libraryCallName # [{
+    int64_t getStride(unsigned i) {
+      assert(i < getNumWindowLoops());
+      if (!strides().hasValue()) return 1;
+      return strides()->getValue()[i]
+        .cast<IntegerAttr>().getValue().getSExtValue();
+    }
+
+    int64_t getDilation(unsigned i) {
+      assert(i < getNumWindowLoops());
+      if (!dilations().hasValue()) return 1;
+      return dilations()->getValue()[i]
+        .cast<IntegerAttr>().getValue().getSExtValue();
+    }
+
+    int64_t getLowPad(unsigned i) {
+      assert(i < getNumWindowLoops());
+      if (!padding().hasValue()) return 0;
+      return padding().getValue().getValue<int64_t>({i, 0});
+    }
+  }];
+}
+
+def ConvOp : PoolingBase_Op<"conv", [NInputs<2>, NOutputs<1>]> { // Uses commonUtils from PoolingBase_Op.
 
   let description = [{
     Generic n-D convolution as described in the TF documentation:
@@ -282,7 +344,7 @@ def ConvOp : LinalgStructured_Op<"conv", [NInputs<2>, NOutputs<1>]> {
                    OptionalAttr<I64ArrayAttr>:$dilations,
                    OptionalAttr<I64ElementsAttr>:$padding);
 
-  let extraClassDeclaration = libraryCallName # [{
+  let extraClassDeclaration = commonUtils # [{
     // TODO(ntv) extend to support more than 1 dimensions and potentially
     // grouping too.
     unsigned getNumBatchDimensions() { return 1; }
@@ -309,26 +371,6 @@ def ConvOp : LinalgStructured_Op<"conv", [NInputs<2>, NOutputs<1>]> {
       return iters;
     }
 
-    int64_t getStride(unsigned i) {
-      assert(i < getNumWindowLoops());
-      if (!strides().hasValue()) return 1;
-      return strides()->getValue()[i]
-        .cast<IntegerAttr>().getValue().getSExtValue();
-    }
-
-    int64_t getDilation(unsigned i) {
-      assert(i < getNumWindowLoops());
-      if (!dilations().hasValue()) return 1;
-      return dilations()->getValue()[i]
-        .cast<IntegerAttr>().getValue().getSExtValue();
-    }
-
-    int64_t getLowPad(unsigned i) {
-      assert(i < getNumWindowLoops());
-      if (!padding().hasValue()) return 0;
-      return padding().getValue().getValue<int64_t>({i, 0});
-    }
-
     //   F(z0, ..., zN-1, q, k) *
     //     I(b, x0 + z0 - pad_low_0, ..., xN-1 + zN-1 - pad_low_N-1, q)
     //   ->  O(b, x0, ..., xN-1, k)
@@ -358,7 +400,7 @@ def ConvOp : LinalgStructured_Op<"conv", [NInputs<2>, NOutputs<1>]> {
       // Window reduction dims: sum_{z[0], ..., z[N-1], q}
       auto zs = makeAffineDimExprs(nWin, idx, context);
       // Construct the weighedSum expression.
-      auto ws = weightedConvInputIndex(*this, xs, zs);
+      auto ws = weightedPoolingInputIndex(*this, xs, zs); // stride * x + dilation * z - pad_low, per dim
       return SmallVector<AffineMap, 8>{
         // filter[z[0], ..., z[N-1], q, k]
         AffineMap::get(idx, 0, concat(concat(zs, qs), ks)),
@@ -378,6 +420,86 @@ def ConvOp : LinalgStructured_Op<"conv", [NInputs<2>, NOutputs<1>]> {
   let hasFolder = 1;
 }
 
+class SingleInputPoolingBase_Op<string mnemonic>
+  : PoolingBase_Op<mnemonic, [NInputs<2>, NOutputs<1>]> { // inputs: input, windowDims
+  let description = [{
+    A base class for single input pooling function.
+
+    TODO: Figure out a better way to handle window dimensions, i.e., eliminate
+    the fake memref.
+    The window dimensions are specified by argument `windowDims`. The i-th
+    dimension in the shape of `windowDims` denotes the size of the window along
+    dimension i. For example, if the window size is 2x3, then a memref<2x3>
+    should be passed to the operation as `windowDims`.
+  }];
+
+  let arguments = (ins AnyStridedMemRef:$input,
+                   AnyStridedMemRef:$windowDims,
+                   AnyStridedMemRef:$output,
+                   OptionalAttr<I64ArrayAttr>:$strides,
+                   OptionalAttr<I64ArrayAttr>:$dilations,
+                   OptionalAttr<I64ElementsAttr>:$padding);
+
+  let extraClassDeclaration = commonUtils# [{
+    llvm::Optional<SmallVector<StringRef, 8>> referenceIterators() {
+      // The number of outer parallel loops equals the output rank.
+      unsigned nPar = getOutputShapedType(0).getRank();
+      // There are as many inner window loops as parallel loops.
+      unsigned nWin = nPar;
+      SmallVector<StringRef, 8> iters(nPar, getParallelIteratorTypeName());
+      iters.reserve(nPar + nWin);
+      iters.append(nWin, getWindowIteratorTypeName());
+      return iters;
+    }
+
+    llvm::Optional<SmallVector<AffineMap, 8>> referenceIndexingMaps() {
+      MLIRContext *context = getContext();
+      auto nPar = getNumParallelLoops();
+      auto nWin = getNumWindowLoops();
+      assert(nWin > 0 && "expected at least one window dimension");
+      unsigned idx = 0;
+      auto outputDims = makeAffineDimExprs(nPar, idx, context);
+      auto windowDims = makeAffineDimExprs(nWin, idx, context);
+      // Input indices: stride * outputDim + dilation * windowDim - pad_low.
+      auto inputDims =
+          weightedPoolingInputIndex(*this, outputDims, windowDims);
+      return SmallVector<AffineMap, 8>{
+        // input[stride * o + dilation * w - pad_low]
+        AffineMap::get(idx, 0, inputDims),
+        // windowDims[w] (only its shape is meaningful)
+        AffineMap::get(idx, 0, windowDims),
+        // output[o]
+        AffineMap::get(idx, 0, outputDims)
+        };
+    }
+  }];
+
+  let verifier = [{ return ::verify(*this); }];
+
+  let hasFolder = 1;
+}
+
+def PoolingMaxOp : SingleInputPoolingBase_Op<"pooling_max"> {
+  let description = [{
+    Max pooling: each output element is updated to the max of its current
+    value and its window, so `output` must be initialized beforehand.
+  }];
+}
+
+def PoolingMinOp : SingleInputPoolingBase_Op<"pooling_min"> {
+  let description = [{
+    Min pooling: each output element is updated to the min of its current
+    value and its window, so `output` must be initialized beforehand.
+  }];
+}
+
+def PoolingSumOp : SingleInputPoolingBase_Op<"pooling_sum"> {
+  let description = [{
+    Sum pooling: the window values are added to the current output element,
+    so `output` must be initialized (typically to zero) beforehand.
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // Generic Linalg ops.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
index d54791a65410..bb37bb28a18c 100644
--- a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
@@ -72,6 +72,15 @@ constexpr StringRef getFunAttrName() { return "fun"; }
 /// function that implements the structured op.
 constexpr StringRef getLibraryCallAttrName() { return "library_call"; }
 
+/// Attribute name for the strides (an I64ArrayAttr on linalg conv/pooling ops).
+constexpr StringRef getStridesAttrName() { return "strides"; }
+
+/// Attribute name for the dilations (an I64ArrayAttr on linalg conv/pooling ops).
+constexpr StringRef getDilationsAttrName() { return "dilations"; }
+
+/// Attribute name for the (low, high) padding pairs (an I64ElementsAttr).
+constexpr StringRef getPaddingAttrName() { return "padding"; }
+
 /// Use to encode that a particular iterator type has parallel semantics.
 constexpr StringRef getParallelIteratorTypeName() { return "parallel"; }
 

diff  --git a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
index 577b134fa5ed..b493aa662db0 100644
--- a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
+++ b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
@@ -524,12 +524,21 @@ populateLinalgToStandardConversionPatterns(OwningRewritePatternList &patterns,
                                            MLIRContext *ctx) {
   // TODO(ntv) ConvOp conversion needs to export a descriptor with relevant
   // attribute values such as kernel striding and dilation.
-  patterns.insert<CopyTransposeConversion, LinalgOpConversion<ConvOp>,
-                  LinalgOpConversion<CopyOp>, LinalgOpConversion<DotOp>,
-                  LinalgOpConversion<FillOp>, LinalgOpConversion<GenericOp>,
-                  LinalgOpConversion<IndexedGenericOp>,
-                  LinalgOpConversion<MatmulOp>, LinalgOpConversion<MatvecOp>>(
-      ctx);
+  // clang-format off
+  patterns.insert<
+      CopyTransposeConversion,
+      LinalgOpConversion<ConvOp>,
+      LinalgOpConversion<PoolingMaxOp>, // Pooling ops: the ConvOp TODO above applies too.
+      LinalgOpConversion<PoolingMinOp>,
+      LinalgOpConversion<PoolingSumOp>,
+      LinalgOpConversion<CopyOp>,
+      LinalgOpConversion<DotOp>,
+      LinalgOpConversion<FillOp>,
+      LinalgOpConversion<GenericOp>,
+      LinalgOpConversion<IndexedGenericOp>,
+      LinalgOpConversion<MatmulOp>,
+      LinalgOpConversion<MatvecOp>>(ctx);
+  // clang-format on
 }
 
 } // namespace

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index aa340e55e8b5..077b34c320ec 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -140,7 +140,6 @@ static void printGenericOp(OpAsmPrinter &p, GenericOpType op) {
     p.printRegion(op.region());
   p.printOptionalAttrDict(op.getAttrs(), attrNames);
   p << ": " << op.getOperandTypes();
-
   auto outputTensorTypes = op.getResultTypes();
   if (!outputTensorTypes.empty())
     p << " -> " << outputTensorTypes;
@@ -827,8 +826,10 @@ static LogicalResult verify(CopyOp op) {
   return success();
 }
 
-static LogicalResult
-verifyStrideOrDilation(ConvOp op, ArrayRef<Attribute> attrs, bool isStride) {
+template <typename LinalgPoolingOp> // ConvOp or one of the single-input pooling ops.
+static LogicalResult verifyStrideOrDilation(LinalgPoolingOp op,
+                                            ArrayRef<Attribute> attrs,
+                                            bool isStride) {
   auto strideOrDilation = isStride ? "stride" : "dilation";
   if (attrs.size() != op.getNumWindowLoops())
     return op.emitOpError("expects num ")
@@ -860,6 +861,47 @@ static LogicalResult verify(ConvOp op) {
   return success();
 }
 
+/// Verifies the operands and attributes shared by the single-input pooling
+/// ops: `input` and `output` must have the same element type; `input`,
+/// `output` and `windowDims` must have the same rank; `strides` and
+/// `dilations`, when present, are checked by verifyStrideOrDilation.
+template <typename PoolingOp>
+static LogicalResult verifySingleInputPoolingOp(PoolingOp op) {
+  auto inputType = op.input().getType().template cast<MemRefType>();
+  auto outputType = op.output().getType().template cast<MemRefType>();
+  if (outputType.getElementType() != inputType.getElementType())
+    return op.emitOpError("expects memref elemental types to match");
+
+  // `windowDims` only conveys the window shape, so only its rank is checked;
+  // its element type may differ (e.g. i32 windows over f32 data).
+  auto windowDimsType = op.windowDims().getType().template cast<MemRefType>();
+  if (outputType.getRank() != inputType.getRank() ||
+      outputType.getRank() != windowDimsType.getRank())
+    return op.emitOpError("expects memref ranks to match");
+
+  if (auto strides = op.strides()) {
+    if (failed(
+            verifyStrideOrDilation(op, strides->getValue(), /*isStride=*/true)))
+      return failure();
+  }
+  if (auto dilations = op.dilations()) {
+    if (failed(verifyStrideOrDilation(op, dilations->getValue(),
+                                      /*isStride=*/false)))
+      return failure();
+  }
+  return success();
+}
+
+static LogicalResult verify(PoolingMaxOp op) {
+  return verifySingleInputPoolingOp(op);
+}
+static LogicalResult verify(PoolingMinOp op) {
+  return verifySingleInputPoolingOp(op);
+}
+static LogicalResult verify(PoolingSumOp op) {
+  return verifySingleInputPoolingOp(op);
+}
+
 namespace mlir {
 namespace linalg {
 
@@ -894,21 +930,40 @@ mlir::linalg::makeAffineDimExprs(unsigned num, unsigned &startIdx,
   return res;
 }
 
+/// Returns one input index expression per dimension i:
+///   `stride[i] * outputDims[i] + dilation[i] * windowDims[i] - pad_low[i]`
+/// where the coefficients come from `op`'s getStride/getDilation/getLowPad.
+template <typename PoolingOp>
 SmallVector<AffineExpr, 4>
-mlir::linalg::weightedConvInputIndex(ConvOp op, ArrayRef<AffineExpr> xs,
-                                     ArrayRef<AffineExpr> zs) {
-  assert(xs.size() == zs.size());
+mlir::linalg::weightedPoolingInputIndex(PoolingOp op,
+                                        ArrayRef<AffineExpr> outputDims,
+                                        ArrayRef<AffineExpr> windowDims) {
+  assert(outputDims.size() == windowDims.size());
   SmallVector<AffineExpr, 4> res;
-  res.reserve(xs.size());
-  for (unsigned i = 0, e = xs.size(); i < e; ++i) {
+  res.reserve(outputDims.size());
+  for (unsigned i = 0, e = outputDims.size(); i < e; ++i) {
     // TODO(ntv): add a level of indirection to linalg.generic.
-    auto expr =
-        op.getStride(i) * xs[i] + op.getDilation(i) * zs[i] - op.getLowPad(i);
+    auto expr = op.getStride(i) * outputDims[i] +
+                op.getDilation(i) * windowDims[i] - op.getLowPad(i);
     res.push_back(expr);
   }
   return res;
 }
 
+// The template above is only defined in this file, so it is explicitly
+// instantiated here for every op type that uses it.
+#define INSTANTIATE_WEIGHTED_POOLING_INPUT_INDEX(OP_TYPE)                      \
+  template SmallVector<AffineExpr, 4>                                          \
+  mlir::linalg::weightedPoolingInputIndex<OP_TYPE>(                            \
+      OP_TYPE op, ArrayRef<AffineExpr> outputDims,                             \
+      ArrayRef<AffineExpr> windowDims);
+
+INSTANTIATE_WEIGHTED_POOLING_INPUT_INDEX(ConvOp)
+INSTANTIATE_WEIGHTED_POOLING_INPUT_INDEX(PoolingMaxOp)
+INSTANTIATE_WEIGHTED_POOLING_INPUT_INDEX(PoolingMinOp)
+INSTANTIATE_WEIGHTED_POOLING_INPUT_INDEX(PoolingSumOp)
+#undef INSTANTIATE_WEIGHTED_POOLING_INPUT_INDEX
+
 SmallVector<AffineExpr, 4> mlir::linalg::concat(ArrayRef<AffineExpr> a,
                                                 ArrayRef<AffineExpr> b) {
   auto rangeA = llvm::make_range(a.begin(), a.end());
@@ -959,6 +1008,18 @@ LogicalResult ConvOp::fold(ArrayRef<Attribute>,
                            SmallVectorImpl<OpFoldResult> &) {
   return foldMemRefCast(*this);
 }
+LogicalResult PoolingMaxOp::fold(ArrayRef<Attribute>,
+                                 SmallVectorImpl<OpFoldResult> &) {
+  return foldMemRefCast(*this); // Same folding as ConvOp::fold.
+}
+LogicalResult PoolingMinOp::fold(ArrayRef<Attribute>,
+                                 SmallVectorImpl<OpFoldResult> &) {
+  return foldMemRefCast(*this); // Same folding as ConvOp::fold.
+}
+LogicalResult PoolingSumOp::fold(ArrayRef<Attribute>,
+                                 SmallVectorImpl<OpFoldResult> &) {
+  return foldMemRefCast(*this); // Same folding as ConvOp::fold.
+}
 LogicalResult CopyOp::fold(ArrayRef<Attribute>,
                            SmallVectorImpl<OpFoldResult> &) {
   return foldMemRefCast(*this);

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp
index eb2a881308c4..ae589e7a6b63 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp
@@ -106,6 +106,23 @@ static void inlineRegionAndEmitStdStore(OpType op,
   }
 }
 
+// Returns the (input, output) indices accessed by pooling `op` at loop IVs
+// `allIvs`: indexing maps 0 and 2 applied; map 1 (windowDims) is not used.
+template <typename SingleInputPoolingOp>
+static std::pair<SmallVector<ValueHandle, 8>, SmallVector<ValueHandle, 8>>
+getInputAndOutputIndices(ArrayRef<Value> allIvs, SingleInputPoolingOp op) {
+  auto &b = ScopedContext::getBuilder();
+  auto loc = ScopedContext::getLocation();
+  auto mapsRange = op.indexing_maps().template getAsRange<AffineMapAttr>();
+  auto maps =
+      functional::map([](AffineMapAttr a) { return a.getValue(); }, mapsRange);
+  SmallVector<ValueHandle, 8> iIdx(
+      makeCanonicalAffineApplies(b, loc, maps[0], allIvs));
+  SmallVector<ValueHandle, 8> oIdx(
+      makeCanonicalAffineApplies(b, loc, maps[2], allIvs));
+  return {iIdx, oIdx};
+}
+
 namespace {
 template <typename IndexedValueType, typename LinalgOpType>
 class LinalgScopedEmitter {};
@@ -273,6 +290,60 @@ class LinalgScopedEmitter<IndexedValueType, ConvOp> {
   }
 };
 
+// Emits output[o] = select(output[o] > input[i], output[o], input[i]).
+template <typename IndexedValueType>
+class LinalgScopedEmitter<IndexedValueType, PoolingMaxOp> {
+public:
+  static void emitScalarImplementation(ArrayRef<Value> allIvs,
+                                       PoolingMaxOp op) {
+    auto indices = getInputAndOutputIndices(allIvs, op);
+    ValueHandleArray iIdx(indices.first);
+    ValueHandleArray oIdx(indices.second);
+
+    // Emit scalar form.
+    ValueHandle lhs = std_load(op.output(), oIdx);
+    ValueHandle rhs = std_load(op.input(), iIdx);
+    using edsc::op::operator>;
+    ValueHandle maxValue = std_select(lhs > rhs, lhs, rhs);
+    std_store(maxValue, op.output(), oIdx);
+  }
+};
+
+// Emits output[o] = select(output[o] < input[i], output[o], input[i]).
+template <typename IndexedValueType>
+class LinalgScopedEmitter<IndexedValueType, PoolingMinOp> {
+public:
+  static void emitScalarImplementation(ArrayRef<Value> allIvs,
+                                       PoolingMinOp op) {
+    auto indices = getInputAndOutputIndices(allIvs, op);
+    ValueHandleArray iIdx(indices.first);
+    ValueHandleArray oIdx(indices.second);
+
+    // Emit scalar form.
+    ValueHandle lhs = std_load(op.output(), oIdx);
+    ValueHandle rhs = std_load(op.input(), iIdx);
+    using edsc::op::operator<;
+    ValueHandle minValue = std_select(lhs < rhs, lhs, rhs);
+    std_store(minValue, op.output(), oIdx);
+  }
+};
+
+// Emits output[o] += input[i].
+template <typename IndexedValueType>
+class LinalgScopedEmitter<IndexedValueType, PoolingSumOp> {
+public:
+  static void emitScalarImplementation(ArrayRef<Value> allIvs,
+                                       PoolingSumOp op) {
+    auto indices = getInputAndOutputIndices(allIvs, op);
+    SmallVector<ValueHandle, 8> iIdx = std::move(indices.first);
+    SmallVector<ValueHandle, 8> oIdx = std::move(indices.second);
+    IndexedValueType input(op.input()), output(op.output());
+
+    // Emit scalar form.
+    output(oIdx) += input(iIdx);
+  }
+};
+
 // Emits the MLIR for the scalar part of the generic op by:
 //   1. Emitting std_load and std_store ops for each input and output
 //      view in order. This is achieved by applying the appropriate input or
@@ -688,6 +756,9 @@ INSTANTIATE_LINALG_OP_TO_LOOPS(DotOp)
 INSTANTIATE_LINALG_OP_TO_LOOPS(MatvecOp)
 INSTANTIATE_LINALG_OP_TO_LOOPS(MatmulOp)
 INSTANTIATE_LINALG_OP_TO_LOOPS(ConvOp)
+INSTANTIATE_LINALG_OP_TO_LOOPS(PoolingMaxOp) // NOTE(review): presumably dispatches to the pooling LinalgScopedEmitter specializations.
+INSTANTIATE_LINALG_OP_TO_LOOPS(PoolingMinOp)
+INSTANTIATE_LINALG_OP_TO_LOOPS(PoolingSumOp)
 INSTANTIATE_LINALG_OP_TO_LOOPS(GenericOp)
 INSTANTIATE_LINALG_OP_TO_LOOPS(IndexedGenericOp)
 

diff  --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 59e4a764afcc..7a8291504ae6 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -513,3 +513,14 @@ func @reshape(%arg0: memref<?x?x?xf32>) {
   %0 = linalg.reshape %arg0 [affine_map<(i, j, k) -> (i, j)>, affine_map<(i, j, k) -> (k)>] :
     memref<?x?x?xf32> into memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * s0 + d1)>>
 }
+
+// -----
+
+func @pooling_rank_mismatch(%arg0: memref<?x?x?xf32>,
+                            %arg1: memref<2x3xf32>, // rank 2 vs. rank-3 input/output
+                            %arg2: memref<?x?x?xf32>) {
+  // expected-error @+1 {{expects memref ranks to match}}
+  linalg.pooling_max(%arg0, %arg1, %arg2) {strides = [2, 1, 2]}:
+    memref<?x?x?xf32>, memref<2x3xf32>, memref<?x?x?xf32>
+  return
+}

diff  --git a/mlir/test/Dialect/Linalg/loops.mlir b/mlir/test/Dialect/Linalg/loops.mlir
index c8d114bee6ae..1bd0cf61dd24 100644
--- a/mlir/test/Dialect/Linalg/loops.mlir
+++ b/mlir/test/Dialect/Linalg/loops.mlir
@@ -9,6 +9,7 @@
 // CHECK-DAG: #[[strided4D:.*]] = affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3 + d3)>
 // CHECK-DAG: #[[clampMinMap:.*]] = affine_map<(d0) -> (d0, 0)>
 
+// CHECK-DAG: #[[Stride1Dilation1:.*]] = affine_map<(d0, d1) -> (d0 + d1)>
 // CHECK-DAG: #[[Stride2Dilation1:.*]] = affine_map<(d0, d1) -> (d0 * 2 + d1)>
 // CHECK-DAG: #[[Stride2Dilation4:.*]] = affine_map<(d0, d1) -> (d0 * 2 + d1 * 4)>
 // CHECK-DAG: #[[Stride3Dilation5:.*]] = affine_map<(d0, d1) -> (d0 * 3 + d1 * 5)>
@@ -251,6 +252,75 @@ func @conv_padding(%arg0: memref<?x?x?x?xf32>,
 //       CHECK:                 %{{.*}} = addf %{{.*}}, %{{.*}} : f32
 //       CHECK:                 store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}] : memref<?x?x?x?xf32>
 
+func @pooling_max(%arg0: memref<?x?xf32>,
+                  %arg1: memref<?x?xi32>, // window shape only; never loaded
+                  %arg2: memref<?x?xf32>) {
+  linalg.pooling_max(%arg0, %arg1, %arg2) { strides = [2, 1] }:
+    memref<?x?xf32>, memref<?x?xi32>, memref<?x?xf32>
+  return
+}
+// CHECK-LABEL: func @pooling_max
+//       CHECK:   %[[WX:.*]] = dim %arg1, 0 : memref<?x?xi32>
+//       CHECK:   %[[WY:.*]] = dim %arg1, 1 : memref<?x?xi32>
+//       CHECK:   %[[OX:.*]] = dim %arg2, 0 : memref<?x?xf32>
+//       CHECK:   %[[OY:.*]] = dim %arg2, 1 : memref<?x?xf32>
+//       CHECK:   loop.for %{{.*}} = %{{.*}} to %[[OX]] step %{{.*}} {
+//       CHECK:     loop.for %{{.*}} = %{{.*}} to %[[OY]] step %{{.*}} {
+//       CHECK:       loop.for %{{.*}} = %{{.*}} to %[[WX]] step %{{.*}} {
+//       CHECK:         loop.for %{{.*}} = %{{.*}} to %[[WY]] step %{{.*}} {
+//       CHECK:           %[[IX:.*]] = affine.apply #[[Stride2Dilation1]](%{{.*}}, %{{.*}})
+//       CHECK:           %[[IY:.*]] = affine.apply #[[Stride1Dilation1]](%{{.*}}, %{{.*}})
+//       CHECK:           %{{.*}} = load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xf32>
+//       CHECK:           %{{.*}} = load %{{.*}}[%[[IX]], %[[IY]]] : memref<?x?xf32>
+//       CHECK:           %[[RES:.*]] = select %{{.*}}, %{{.*}}, %{{.*}} : f32
+//       CHECK:           store %[[RES]], %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xf32>
+
+func @pooling_min(%arg0: memref<?x?xf32>,
+                  %arg1: memref<?x?xi32>, // window shape only; never loaded
+                  %arg2: memref<?x?xf32>) {
+  linalg.pooling_min(%arg0, %arg1, %arg2) { strides = [2, 1] }:
+    memref<?x?xf32>, memref<?x?xi32>, memref<?x?xf32>
+  return
+}
+// CHECK-LABEL: func @pooling_min
+//       CHECK:   %[[WX:.*]] = dim %arg1, 0 : memref<?x?xi32>
+//       CHECK:   %[[WY:.*]] = dim %arg1, 1 : memref<?x?xi32>
+//       CHECK:   %[[OX:.*]] = dim %arg2, 0 : memref<?x?xf32>
+//       CHECK:   %[[OY:.*]] = dim %arg2, 1 : memref<?x?xf32>
+//       CHECK:   loop.for %{{.*}} = %{{.*}} to %[[OX]] step %{{.*}} {
+//       CHECK:     loop.for %{{.*}} = %{{.*}} to %[[OY]] step %{{.*}} {
+//       CHECK:       loop.for %{{.*}} = %{{.*}} to %[[WX]] step %{{.*}} {
+//       CHECK:         loop.for %{{.*}} = %{{.*}} to %[[WY]] step %{{.*}} {
+//       CHECK:           %[[IX:.*]] = affine.apply #[[Stride2Dilation1]](%{{.*}}, %{{.*}})
+//       CHECK:           %[[IY:.*]] = affine.apply #[[Stride1Dilation1]](%{{.*}}, %{{.*}})
+//       CHECK:           %{{.*}} = load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xf32>
+//       CHECK:           %{{.*}} = load %{{.*}}[%[[IX]], %[[IY]]] : memref<?x?xf32>
+//       CHECK:           %[[RES:.*]] = select %{{.*}}, %{{.*}}, %{{.*}} : f32
+//       CHECK:           store %[[RES]], %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xf32>
+
+func @pooling_sum(%arg0: memref<?x?xf32>,
+                  %arg1: memref<?x?xi32>, // window shape only; never loaded
+                  %arg2: memref<?x?xf32>) {
+  linalg.pooling_sum(%arg0, %arg1, %arg2) { strides = [2, 1] }:
+    memref<?x?xf32>, memref<?x?xi32>, memref<?x?xf32>
+  return
+}
+// CHECK-LABEL: func @pooling_sum
+//       CHECK:   %[[WX:.*]] = dim %arg1, 0 : memref<?x?xi32>
+//       CHECK:   %[[WY:.*]] = dim %arg1, 1 : memref<?x?xi32>
+//       CHECK:   %[[OX:.*]] = dim %arg2, 0 : memref<?x?xf32>
+//       CHECK:   %[[OY:.*]] = dim %arg2, 1 : memref<?x?xf32>
+//       CHECK:   loop.for %{{.*}} = %{{.*}} to %[[OX]] step %{{.*}} {
+//       CHECK:     loop.for %{{.*}} = %{{.*}} to %[[OY]] step %{{.*}} {
+//       CHECK:       loop.for %{{.*}} = %{{.*}} to %[[WX]] step %{{.*}} {
+//       CHECK:         loop.for %{{.*}} = %{{.*}} to %[[WY]] step %{{.*}} {
+//       CHECK:           %[[IX:.*]] = affine.apply #[[Stride2Dilation1]](%{{.*}}, %{{.*}})
+//       CHECK:           %[[IY:.*]] = affine.apply #[[Stride1Dilation1]](%{{.*}}, %{{.*}})
+//       CHECK:           %[[RHS:.*]] = load %{{.*}}[%[[IX]], %[[IY]]] : memref<?x?xf32>
+//       CHECK:           %[[LHS:.*]] = load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xf32>
+//       CHECK:           %[[RES:.*]] = addf %[[LHS]], %[[RHS]] : f32
+//       CHECK:           store %[[RES]], %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xf32>
+
 func @foo(%0: f32, %1: f32, %2: f32) -> (f32, f32) {
   %f0 = constant 0.0 : f32
   return %f0, %f0 : f32, f32

diff  --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index 468fad45dd90..05d35f8f43e4 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -244,6 +244,48 @@ func @conv_padding(%arg0: memref<?x?x?x?xf32>,
 
 // -----
 
+func @pooling_max(%arg0: memref<?x?x?xf32>,
+                  %arg1: memref<?x?x?xi32>, // window shape; element type need not match
+                  %arg2: memref<?x?x?xf32>) {
+  linalg.pooling_max(%arg0, %arg1, %arg2) {strides = [2, 1, 2]}:
+    memref<?x?x?xf32>, memref<?x?x?xi32>, memref<?x?x?xf32>
+  return
+}
+// CHECK-LABEL: func @pooling_max
+//       CHECK:   linalg.pooling_max(%{{.*}}, %{{.*}}, %{{.*}})
+//  CHECK-SAME:   {strides = [2, 1, 2]}
+//  CHECK-SAME:   memref<?x?x?xf32>, memref<?x?x?xi32>, memref<?x?x?xf32>
+
+// -----
+
+func @pooling_min(%arg0: memref<?x?x?xf32>,
+                  %arg1: memref<?x?x?xi32>, // window shape; element type need not match
+                  %arg2: memref<?x?x?xf32>) {
+  linalg.pooling_min(%arg0, %arg1, %arg2) {strides = [2, 1, 2]}:
+    memref<?x?x?xf32>, memref<?x?x?xi32>, memref<?x?x?xf32>
+  return
+}
+// CHECK-LABEL: func @pooling_min
+//       CHECK:   linalg.pooling_min(%{{.*}}, %{{.*}}, %{{.*}})
+//  CHECK-SAME:   {strides = [2, 1, 2]}
+//  CHECK-SAME:   memref<?x?x?xf32>, memref<?x?x?xi32>, memref<?x?x?xf32>
+
+// -----
+
+func @pooling_sum(%arg0: memref<?x?x?xf32>,
+                  %arg1: memref<?x?x?xi32>, // window shape; element type need not match
+                  %arg2: memref<?x?x?xf32>) {
+  linalg.pooling_sum(%arg0, %arg1, %arg2) {strides = [2, 1, 2]}:
+    memref<?x?x?xf32>, memref<?x?x?xi32>, memref<?x?x?xf32>
+  return
+}
+// CHECK-LABEL: func @pooling_sum
+//       CHECK:   linalg.pooling_sum(%{{.*}}, %{{.*}}, %{{.*}})
+//  CHECK-SAME:   {strides = [2, 1, 2]}
+//  CHECK-SAME:   memref<?x?x?xf32>, memref<?x?x?xi32>, memref<?x?x?xf32>
+
+// -----
+
 // CHECK-DAG: #[[strided2D:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
 // CHECK-DAG: #[[strided3D:.*]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2 + d2)>
 


        


More information about the Mlir-commits mailing list