[Mlir-commits] [mlir] 6f03a10 - [mlir][TilingInterface] Add a method to generate scalar implementation of the op.

Mahesh Ravishankar llvmlistbot at llvm.org
Thu Jul 28 09:41:00 PDT 2022


Author: Mahesh Ravishankar
Date: 2022-07-28T16:37:15Z
New Revision: 6f03a10e4fdb4f22651a9dcc5d6ab318724235e8

URL: https://github.com/llvm/llvm-project/commit/6f03a10e4fdb4f22651a9dcc5d6ab318724235e8
DIFF: https://github.com/llvm/llvm-project/commit/6f03a10e4fdb4f22651a9dcc5d6ab318724235e8.diff

LOG: [mlir][TilingInterface] Add a method to generate scalar implementation of the op.

While the tiling interface provides a mechanism for operations to be
tiled into a tiled version of the op (or another op at the same level of
abstraction), the `generateScalarImplementation` method added here is
the "exit point" after all transformations have been done. Ops that
implement this method are expected to generate IR that is directly
lowerable to backend dialects like the LLVM or SPIR-V dialects.

Differential Revision: https://reviews.llvm.org/D130612

Added: 
    mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir

Modified: 
    mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
    mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
    mlir/include/mlir/Interfaces/TilingInterface.td
    mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
    mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
    mlir/lib/Dialect/Utils/StaticValueUtils.cpp
    mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index 52cc52325eb9a..a56b6b44e4657 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -141,6 +141,23 @@ struct TileConsumerAndFuseProducersUsingSCFForOp
   TileUsingSCFForOp tilingPattern;
 };
 
+/// Pattern that lowers a result-less `TilingInterface` op to a perfect nest of
+/// `scf.for` loops (one per iteration-domain dimension) around its scalar code.
+struct LowerToLoopsUsingSCFForOp
+    : public OpInterfaceRewritePattern<TilingInterface> {
+  using OpInterfaceRewritePattern<TilingInterface>::OpInterfaceRewritePattern;
+
+  /// `matchAndRewrite` implementation that also returns the generated loops,
+  /// outermost first (empty when the iteration domain has no dimensions).
+  FailureOr<SmallVector<scf::ForOp>>
+  returningMatchAndRewrite(TilingInterface op, PatternRewriter &rewriter) const;
+
+  LogicalResult matchAndRewrite(TilingInterface op,
+                                PatternRewriter &rewriter) const override {
+    return returningMatchAndRewrite(op, rewriter);
+  }
+};
+
 } // namespace scf
 } // namespace mlir
 

diff  --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
index b0c71f514a585..bde7f476bde50 100644
--- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
@@ -48,7 +48,7 @@ OpFoldResult getAsOpFoldResult(Value val);
 
 /// Given an array of values, try to extract a constant Attribute from each
 /// value. If this fails, return the original value.
-SmallVector<OpFoldResult> getAsOpFoldResult(ArrayRef<Value> values);
+SmallVector<OpFoldResult> getAsOpFoldResult(ValueRange values);
 
 /// Convert `arrayAttr` to a vector of OpFoldResult.
 SmallVector<OpFoldResult> getAsOpFoldResult(ArrayAttr arrayAttr);

diff  --git a/mlir/include/mlir/Interfaces/TilingInterface.td b/mlir/include/mlir/Interfaces/TilingInterface.td
index ee998530d3d8e..099e8f7eaac74 100644
--- a/mlir/include/mlir/Interfaces/TilingInterface.td
+++ b/mlir/include/mlir/Interfaces/TilingInterface.td
@@ -167,6 +167,28 @@ def TilingInterface : OpInterface<"TilingInterface"> {
         /*defaultImplementation=*/[{
           return failure();
         }]
+      >,
+      InterfaceMethod<
+        /*desc=*/[{
+          Generates the scalar implementation of the operation.
+
+          Given `ivs`, one value per dimension of the iteration space (as
+          specified by `getIterationDomain()`) that together identify a single
+          point, uses `b` to create the scalar operations that compute the op
+          at that point. Returns failure if this is not supported (the default
+          implementation always fails). Typically used as the "exit path": once
+          all transformations are done, it lowers to scalar code that can then
+          be lowered to the LLVM or SPIR-V dialects.
+        }],
+        /*retType=*/"LogicalResult",
+        /*methodName=*/"generateScalarImplementation",
+        /*args=*/(ins
+            "OpBuilder &":$b,
+            "Location ":$loc,
+            "ValueRange ":$ivs),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/[{
+          return failure();
       >
   ];  
 }

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index 9f838585d5742..518fdde4223fb 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -13,14 +13,68 @@
 #include "mlir/Dialect/Arithmetic/Utils/Utils.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/Interfaces/TilingInterface.h"
 
 using namespace mlir;
 using namespace mlir::linalg;
 
-namespace {
+//===----------------------------------------------------------------------===//
+// Utility methods for implementation of Tiling Interface for Linalg ops
+//===----------------------------------------------------------------------===//
+
+/// Return one `affine.apply` per result of `indexingMap`, evaluated at the
+/// point `ivs`; only `ivs` are passed as operands, so no symbols are expected.
+static SmallVector<Value> getIndicesForAccess(OpBuilder &b, Location loc,
+                                              AffineMap indexingMap,
+                                              ValueRange ivs) {
+  SmallVector<Value> indices;
+  indices.reserve(indexingMap.getNumResults());
+  for (auto result : indexingMap.getResults()) {
+    AffineMap m = AffineMap::get(indexingMap.getNumDims(),
+                                 indexingMap.getNumSymbols(), result);
+    Value v = b.create<AffineApplyOp>(loc, m, ivs);
+    indices.push_back(v);
+  }
+  return indices;
+}
+
+/// Inline the payload of `linalgOp` at iteration-space point `ivs`, binding
+/// `argValues` to the block arguments, and store the yielded values.
+static LogicalResult inlinePayload(OpBuilder &b, LinalgOp linalgOp,
+                                   ValueRange ivs, ValueRange argValues) {
+  Block *body = linalgOp.getBlock();
+  BlockAndValueMapping map;
+  map.map(body->getArguments(), argValues);
+  for (auto &op : body->without_terminator()) {
+    // `linalg.index` is not cloned: it maps to the IV of its loop dimension.
+    if (auto indexOp = dyn_cast<IndexOp>(&op)) {
+      map.map(indexOp.getResult(), ivs[indexOp.dim()]);
+      continue;
+    }
+    b.clone(op, map);
+  }
+
+  // The i-th yielded value is stored into the i-th output operand's buffer.
+  Operation *terminator = body->getTerminator();
+  Location loc = terminator->getLoc();
+  for (auto operand : llvm::enumerate(terminator->getOperands())) {
+    Value toStore = map.lookupOrDefault(operand.value());
+    OpOperand *storeInto = linalgOp.getOutputOperand(operand.index());
+    auto indices = getIndicesForAccess(
+        b, loc, linalgOp.getTiedIndexingMap(storeInto), ivs);
+    b.create<memref::StoreOp>(loc, toStore, storeInto->get(), indices);
+  }
+  return success();
+}
 
+//===----------------------------------------------------------------------===//
+// External Model for implementing `TilingInterface` for `LinalgOp`s.
+//===----------------------------------------------------------------------===//
+
+namespace {
 /// External model implementation of TilingInterface for LinalgOps. An external
 /// model implementation is used for now till the use of `TilingInterface` is
 /// on-par with the current Linalg tiling + fusion patterns. Once it is
@@ -167,6 +221,38 @@ struct LinalgOpTilingInterface
 
     return tiledOp[0]->getResult(resultNumber);
   }
+
+  LogicalResult generateScalarImplementation(Operation *op, OpBuilder &builder,
+                                             Location loc,
+                                             ValueRange ivs) const {
+    auto linalgOp = cast<LinalgOp>(op);
+    if (!linalgOp.hasBufferSemantics())
+      return op->emitOpError("expected operation to have buffer semantics");
+
+    SmallVector<Value> indexedValues;
+    indexedValues.reserve(linalgOp.getNumInputsAndOutputs());
+    Location linalgOpLoc = op->getLoc(); // `loc` argument is unused.
+    // Payload args for all input *and* output operands: null if unused by the
+    // payload, the operand itself if scalar, else a `memref.load` at `ivs`.
+    for (OpOperand *operand : linalgOp.getInputAndOutputOperands()) {
+      if (!linalgOp.payloadUsesValueFromOperand(operand)) {
+        indexedValues.push_back(nullptr);
+        continue;
+      }
+      if (linalgOp.isScalar(operand)) {
+        indexedValues.push_back(operand->get());
+        continue;
+      }
+      SmallVector<Value> indices = getIndicesForAccess(
+          builder, linalgOpLoc, linalgOp.getTiedIndexingMap(operand), ivs);
+      Value load =
+          builder.create<memref::LoadOp>(linalgOpLoc, operand->get(), indices);
+      indexedValues.push_back(load);
+    }
+
+    // Inline the payload at `ivs`; yielded values are stored to the outputs.
+    return inlinePayload(builder, linalgOp, ivs, indexedValues);
+  }
 };
 
 } // namespace

diff  --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index e2ca296616644..8d304fc5775a2 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -494,3 +494,41 @@ scf::TileConsumerAndFuseProducersUsingSCFForOp::returningMatchAndRewrite(
                   tileAndFuseResult.loops.back(), rewriter);
   return tileAndFuseResult;
 }
+
+//===----------------------------------------------------------------------===//
+// LowerToLoopsUsingSCFForOp
+//===----------------------------------------------------------------------===//
+
+FailureOr<SmallVector<scf::ForOp>>
+scf::LowerToLoopsUsingSCFForOp::returningMatchAndRewrite(
+    TilingInterface op, PatternRewriter &rewriter) const {
+  // TODO: Handle cases where the op has results if needed.
+  if (op->getNumResults() > 0) {
+    return rewriter.notifyMatchFailure(
+        op, "unable to lower to loops operations with return values");
+  }
+  // getIterationDomain may create IR, so call it only once the op matched.
+  SmallVector<Range> domain = op.getIterationDomain(rewriter);
+
+  SmallVector<Value> ivs;
+  SmallVector<scf::ForOp> loops;
+  Location loc = op.getLoc();
+  for (const Range &loopRange : domain) {
+    Value offsetVal =
+        getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.offset);
+    Value sizeVal =
+        getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.size);
+    Value strideVal =
+        getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.stride);
+    auto loop = rewriter.create<scf::ForOp>(loc, offsetVal, sizeVal,
+                                            strideVal, ValueRange{});
+    loops.push_back(loop);
+    ivs.push_back(loop.getInductionVar());
+    rewriter.setInsertionPoint(loop.getBody()->getTerminator());
+  }
+  // NOTE(review): on failure the loops created above are not rolled back.
+  if (failed(op.generateScalarImplementation(rewriter, loc, ivs)))
+    return failure();
+  rewriter.eraseOp(op);
+  return loops;
+}

diff  --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
index db59141d57da2..80e93553858fd 100644
--- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
@@ -62,7 +62,7 @@ OpFoldResult getAsOpFoldResult(Value val) {
 
 /// Given an array of values, try to extract a constant Attribute from each
 /// value. If this fails, return the original value.
-SmallVector<OpFoldResult> getAsOpFoldResult(ArrayRef<Value> values) {
+SmallVector<OpFoldResult> getAsOpFoldResult(ValueRange values) {
   return llvm::to_vector<4>(
       llvm::map_range(values, [](Value v) { return getAsOpFoldResult(v); }));
 }

diff  --git a/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir
new file mode 100644
index 0000000000000..519364963e7cb
--- /dev/null
+++ b/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir
@@ -0,0 +1,161 @@
+// RUN: mlir-opt -test-tiling-interface=lower-to-scalar-using-scf-for -split-input-file %s | FileCheck %s
+
+func.func @gemm(%arg0 : memref<?x?xf32>, %arg1 : memref<?x?xf32>,
+  %arg2 : memref<?x?xf32>) {
+  linalg.matmul ins(%arg0, %arg1 : memref<?x?xf32>, memref<?x?xf32>)
+      outs(%arg2 : memref<?x?xf32>)
+  return
+}
+// CHECK-LABEL: func @gemm
+//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: memref<?x?xf32>
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: memref<?x?xf32>
+//  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: memref<?x?xf32>
+//   CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//   CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+//   CHECK-DAG:   %[[M:.+]] = memref.dim %[[ARG0]], %[[C0]]
+//   CHECK-DAG:   %[[K:.+]] = memref.dim %[[ARG0]], %[[C1]]
+//   CHECK-DAG:   %[[N:.+]] = memref.dim %[[ARG1]], %[[C1]]
+//       CHECK:   scf.for %[[IV0:[a-zA-Z0-9]+]] = %[[C0]] to %[[M]] step %[[C1]]
+//       CHECK:     scf.for %[[IV1:[a-zA-Z0-9]+]] = %[[C0]] to %[[N]] step %[[C1]]
+//       CHECK:       scf.for %[[IV2:[a-zA-Z0-9]+]] = %[[C0]] to %[[K]] step %[[C1]]
+//   CHECK-DAG:         %[[LHS:.+]] = memref.load %[[ARG0]][%[[IV0]], %[[IV2]]]
+//   CHECK-DAG:         %[[RHS:.+]] = memref.load %[[ARG1]][%[[IV2]], %[[IV1]]]
+//   CHECK-DAG:         %[[OUT:.+]] = memref.load %[[ARG2]][%[[IV0]], %[[IV1]]]
+//       CHECK:         %[[MULF:.+]] = arith.mulf %[[LHS]], %[[RHS]]
+//       CHECK:         %[[ADDF:.+]] = arith.addf %[[OUT]], %[[MULF]]
+//       CHECK:         memref.store %[[ADDF]], %[[ARG2]][%[[IV0]], %[[IV1]]]
+
+// -----
+
+func.func @indexed_generic(%arg0 : memref<200x300xi32>, %arg1 : memref<300xi16>,
+    %arg2 : memref<200xi8>, %arg3 : memref<300x200xi64>) {
+  linalg.generic {
+      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1)>,
+                       affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d1, d0)>],
+      iterator_types = ["parallel", "parallel"]}
+      ins(%arg0, %arg1, %arg2 : memref<200x300xi32>, memref<300xi16>, memref<200xi8>)
+      outs(%arg3 : memref<300x200xi64>) {
+    ^bb0(%b0 : i32, %b1 : i16, %b2 : i8, %b3 : i64):
+      %0 = linalg.index 0 : index
+      %1 = arith.index_cast %0 : index to i16
+      %2 = arith.muli %b1, %1 : i16
+      %3 = linalg.index 1 : index
+      %4 = arith.index_cast %3 : index to i8
+      %5 = arith.muli %b2, %4 : i8
+      %6 = arith.extsi %2 : i16 to i32
+      %7 = arith.extsi %5 : i8 to i32
+      %8 = arith.addi %6, %7 : i32
+      %9 = arith.addi %8, %b0 : i32
+      %10 = arith.extsi %9 : i32 to i64
+      linalg.yield %10 : i64
+    }
+  return
+}
+// CHECK-LABEL: func @indexed_generic
+//  CHECK-SAME:     %[[ARG0:.+]]: memref<200x300xi32>
+//  CHECK-SAME:     %[[ARG1:.+]]: memref<300xi16>
+//  CHECK-SAME:     %[[ARG2:.+]]: memref<200xi8>
+//  CHECK-SAME:     %[[ARG3:.+]]: memref<300x200xi64>
+//   CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//   CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+//   CHECK-DAG:   %[[C200:.+]] = arith.constant 200 : index
+//   CHECK-DAG:   %[[C300:.+]] = arith.constant 300 : index
+//       CHECK:   scf.for %[[IV0:[a-zA-Z0-9]+]] = %[[C0]] to %[[C200]] step %[[C1]]
+//       CHECK:     scf.for %[[IV1:[a-zA-Z0-9]+]] = %[[C0]] to %[[C300]] step %[[C1]]
+//   CHECK-DAG:       %[[B0:.+]] = memref.load %[[ARG0]][%[[IV0]], %[[IV1]]]
+//   CHECK-DAG:       %[[B1:.+]] = memref.load %[[ARG1]][%[[IV1]]]
+//   CHECK-DAG:       %[[B2:.+]] = memref.load %[[ARG2]][%[[IV0]]]
+//       CHECK:       %[[T1:.+]] = arith.index_cast %[[IV0]]
+//       CHECK:       %[[T2:.+]] = arith.muli %[[B1]], %[[T1]]
+//       CHECK:       %[[T4:.+]] = arith.index_cast %[[IV1]]
+//       CHECK:       %[[T5:.+]] = arith.muli %[[B2]], %[[T4]]
+//       CHECK:       %[[T6:.+]] = arith.extsi %[[T2]]
+//       CHECK:       %[[T7:.+]] = arith.extsi %[[T5]]
+//       CHECK:       %[[T8:.+]] = arith.addi %[[T6]], %[[T7]]
+//       CHECK:       %[[T9:.+]] = arith.addi %[[T8]], %[[B0]]
+//       CHECK:       %[[T10:.+]] = arith.extsi %[[T9]]
+//       CHECK:       memref.store %[[T10]], %[[ARG3]][%[[IV1]], %[[IV0]]]
+
+// -----
+
+func.func @conv_strides_and_dilation(%arg0 : memref<?x?x?x?xf32>, %arg1 : memref<?x?x?x?xf32>,
+  %arg2 : memref<?x?x?x?xf32>) {
+  linalg.conv_2d_nhwc_hwcf {
+      strides = dense<[1, 2]> : tensor<2xi64>,
+      dilations = dense<[3, 4]> : tensor<2xi64>}
+      ins(%arg0, %arg1 : memref<?x?x?x?xf32>, memref<?x?x?x?xf32>)
+      outs(%arg2 : memref<?x?x?x?xf32>)
+  return
+}
+//  CHECK-DAG:  #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1 + d4 * 3)>
+//  CHECK-DAG:  #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d2 * 2 + d5 * 4)>
+//       CHECK: func @conv_strides_and_dilation(
+//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: memref<?x?x?x?xf32>
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: memref<?x?x?x?xf32>
+//  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: memref<?x?x?x?xf32>
+//   CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//   CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+//   CHECK-DAG:   %[[C2:.+]] = arith.constant 2 : index
+//   CHECK-DAG:   %[[C3:.+]] = arith.constant 3 : index
+//   CHECK-DAG:   %[[N:.+]] = memref.dim %[[ARG0]], %[[C0]]
+//   CHECK-DAG:   %[[C:.+]] = memref.dim %[[ARG0]], %[[C3]]
+//   CHECK-DAG:   %[[H:.+]] = memref.dim %[[ARG1]], %[[C0]]
+//   CHECK-DAG:   %[[W:.+]] = memref.dim %[[ARG1]], %[[C1]]
+//   CHECK-DAG:   %[[F:.+]] = memref.dim %[[ARG1]], %[[C3]]
+//   CHECK-DAG:   %[[P:.+]] = memref.dim %[[ARG2]], %[[C1]]
+//   CHECK-DAG:   %[[Q:.+]] = memref.dim %[[ARG2]], %[[C2]]
+//       CHECK:   scf.for %[[IV0:[a-zA-Z0-9]+]] = %[[C0]] to %[[N]] step %[[C1]]
+//       CHECK:     scf.for %[[IV1:[a-zA-Z0-9]+]] = %[[C0]] to %[[P]] step %[[C1]]
+//       CHECK:       scf.for %[[IV2:[a-zA-Z0-9]+]] = %[[C0]] to %[[Q]] step %[[C1]]
+//       CHECK:         scf.for %[[IV3:[a-zA-Z0-9]+]] = %[[C0]] to %[[F]] step %[[C1]]
+//       CHECK:           scf.for %[[IV4:[a-zA-Z0-9]+]] = %[[C0]] to %[[H]] step %[[C1]]
+//       CHECK:             scf.for %[[IV5:[a-zA-Z0-9]+]] = %[[C0]] to %[[W]] step %[[C1]]
+//       CHECK:               scf.for %[[IV6:[a-zA-Z0-9]+]] = %[[C0]] to %[[C]] step %[[C1]]
+//   CHECK-DAG:                 %[[I:.+]] = affine.apply #[[MAP0]](%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]], %[[IV4]], %[[IV5]], %[[IV6]])
+//   CHECK-DAG:                 %[[J:.+]] = affine.apply #[[MAP1]](%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]], %[[IV4]], %[[IV5]], %[[IV6]])
+//   CHECK-DAG:                 %[[T9:.+]] = memref.load %[[ARG0]][%[[IV0]], %[[I]], %[[J]], %[[IV6]]]
+//   CHECK-DAG:                 %[[T10:.+]] = memref.load %[[ARG1]][%[[IV4]], %[[IV5]], %[[IV6]], %[[IV3]]]
+//   CHECK-DAG:                 %[[T11:.+]] = memref.load %[[ARG2]][%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]]]
+//       CHECK:                 %[[T12:.+]] = arith.mulf %[[T9]], %[[T10]]
+//       CHECK:                 %[[T13:.+]] = arith.addf %[[T11]], %[[T12]]
+//       CHECK:                 memref.store %[[T13]], %[[ARG2]][%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]]]
+
+// -----
+
+func.func @pool_strides_and_dilation(%arg0 : memref<?x?x?x?xf32>, %arg1 : memref<?x?xf32>,
+  %arg2 : memref<?x?x?x?xf32>) {
+  linalg.pooling_nhwc_max {
+      strides = dense<[1, 2]> : tensor<2xi64>,
+      dilations = dense<[3, 4]> : tensor<2xi64>}
+      ins(%arg0, %arg1 : memref<?x?x?x?xf32>, memref<?x?xf32>)
+      outs(%arg2 : memref<?x?x?x?xf32>)
+  return
+}
+//  CHECK-DAG:  #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1 + d4 * 3)>
+//  CHECK-DAG:  #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2 * 2 + d5 * 4)>
+//       CHECK: func @pool_strides_and_dilation
+//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: memref<?x?x?x?xf32>
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: memref<?x?xf32>
+//  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: memref<?x?x?x?xf32>
+//   CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//   CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+//   CHECK-DAG:   %[[C2:.+]] = arith.constant 2 : index
+//   CHECK-DAG:   %[[C3:.+]] = arith.constant 3 : index
+//   CHECK-DAG:   %[[N:.+]] = memref.dim %[[ARG0]], %[[C0]]
+//   CHECK-DAG:   %[[C:.+]] = memref.dim %[[ARG0]], %[[C3]]
+//   CHECK-DAG:   %[[H:.+]] = memref.dim %[[ARG1]], %[[C0]]
+//   CHECK-DAG:   %[[W:.+]] = memref.dim %[[ARG1]], %[[C1]]
+//   CHECK-DAG:   %[[P:.+]] = memref.dim %[[ARG2]], %[[C1]]
+//   CHECK-DAG:   %[[Q:.+]] = memref.dim %[[ARG2]], %[[C2]]
+//       CHECK:   scf.for %[[IV0:[a-zA-Z0-9]+]] = %[[C0]] to %[[N]] step %[[C1]]
+//       CHECK:     scf.for %[[IV1:[a-zA-Z0-9]+]] = %[[C0]] to %[[P]] step %[[C1]]
+//       CHECK:       scf.for %[[IV2:[a-zA-Z0-9]+]] = %[[C0]] to %[[Q]] step %[[C1]]
+//       CHECK:         scf.for %[[IV3:[a-zA-Z0-9]+]] = %[[C0]] to %[[C]] step %[[C1]]
+//       CHECK:           scf.for %[[IV4:[a-zA-Z0-9]+]] = %[[C0]] to %[[H]] step %[[C1]]
+//       CHECK:             scf.for %[[IV5:[a-zA-Z0-9]+]] = %[[C0]] to %[[W]] step %[[C1]]
+//   CHECK-DAG:               %[[I:.+]] = affine.apply #[[MAP0]](%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]], %[[IV4]], %[[IV5]])
+//   CHECK-DAG:               %[[J:.+]] = affine.apply #[[MAP1]](%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]], %[[IV4]], %[[IV5]])
+//   CHECK-DAG:               %[[T8:.+]] = memref.load %[[ARG0]][%[[IV0]], %[[I]], %[[J]], %[[IV3]]]
+//   CHECK-DAG:               %[[T9:.+]] = memref.load %[[ARG2]][%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]]]
+//       CHECK:               %[[T10:.+]] = arith.maxf %[[T9]], %[[T8]]
+//       CHECK:               memref.store %[[T10]], %[[ARG2]][%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]]]

diff  --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
index 5c603a55d7419..c535fca1f506c 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
@@ -65,7 +65,7 @@ struct TestTileUsingSCFForOpWithFilter : public scf::TileUsingSCFForOp {
   linalg::LinalgTransformationFilter filter;
 };
 
-/// Pattern for testing `TileConsumerAndFUseProducersUsingSCFForOp` pattern
+/// Pattern for testing `TileConsumerAndFuseProducersUsingSCFForOp` pattern
 /// (that tiles and fuses operations using the `TilingInterface` with `scf.for`
 /// ops for iterating over the tiles) while using a `filter` to avoid recursive
 /// application.
@@ -138,6 +138,12 @@ struct TestTilingInterfacePass
                      "with scf.for operations"),
       llvm::cl::init(false)};
 
+  Option<bool> testLoweringToScalar{
+      *this, "lower-to-scalar-using-scf-for",
+      llvm::cl::desc("Test lowering to scalar implementation using "
+                     "TilingInterface with scf.for operations"),
+      llvm::cl::init(false)};
+
   void runOnOperation() override;
 
 private:
@@ -199,6 +205,9 @@ void TestTilingInterfacePass::addTestPatterns(MLIRContext *context,
         context, patterns, "gemm_sequence_fusion", {10});
     return;
   }
+  if (testLoweringToScalar) {
+    patterns.add<scf::LowerToLoopsUsingSCFForOp>(context);
+  }
 }
 
 void TestTilingInterfacePass::runOnOperation() {


        


More information about the Mlir-commits mailing list