[Mlir-commits] [mlir] 5759e47 - [mlir][Linalg] Avoid using scf.parallel for non-parallel loops in Linalg ops.

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue May 26 21:53:48 PDT 2020


Author: MaheshRavishankar
Date: 2020-05-26T21:33:57-07:00
New Revision: 5759e4731635e1f28fef2c4619491a1b4a2bc305

URL: https://github.com/llvm/llvm-project/commit/5759e4731635e1f28fef2c4619491a1b4a2bc305
DIFF: https://github.com/llvm/llvm-project/commit/5759e4731635e1f28fef2c4619491a1b4a2bc305.diff

LOG: [mlir][Linalg] Avoid using scf.parallel for non-parallel loops in Linalg ops.

Modifying the loop nest builder for generating scf.parallel loops to
not generate scf.parallel loops for non-parallel iterator types in
Linalg operations. The existing implementation incorrectly generated
scf.parallel for all tiled loops. It is rectified by refactoring logic
used while lowering to loops that accounted for this.

Differential Revision: https://reviews.llvm.org/D80188

Added: 
    mlir/test/Dialect/Linalg/tile_parallel_reduce.mlir

Modified: 
    mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
    mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
    mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
    mlir/lib/Dialect/Linalg/Utils/Utils.cpp
    mlir/test/Dialect/Linalg/parallel_loops.mlir
    mlir/test/Dialect/Linalg/transform-patterns.mlir
    mlir/test/lib/Transforms/TestLinalgTransforms.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index 9caec484659e..c8a5d83438f5 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -9,14 +9,21 @@
 #ifndef MLIR_DIALECT_LINALG_UTILS_H_
 #define MLIR_DIALECT_LINALG_UTILS_H_
 
+#include "mlir/Dialect/Affine/EDSC/Intrinsics.h"
+#include "mlir/Dialect/Linalg/EDSC/Builders.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 
 #include "llvm/ADT/SetVector.h"
 
+using mlir::edsc::intrinsics::AffineIndexedValue;
+using mlir::edsc::intrinsics::StdIndexedValue;
+
 namespace mlir {
 class AffineExpr;
+class AffineForOp;
 class AffineMap;
 class OperationFolder;
 class PatternRewriter;
@@ -49,6 +56,15 @@ struct RegionMatcher {
   static Optional<BinaryOpKind> matchAsScalarBinaryOp(GenericOp op);
 };
 
+/// Checks if an iterator_type attribute is parallel.
+bool isParallelIteratorType(Attribute attr);
+
+/// Checks if an iterator_type attribute is reduction.
+bool isReductionIteratorType(Attribute attr);
+
+/// Checks if an iterator_type attribute is window.
+bool isWindowIteratorType(Attribute attr);
+
 /// Checks whether the specific `producer` is the last write to exactly the
 /// whole `consumedView`. This checks structural dominance, that the dependence
 /// is a RAW without any interleaved write to any piece of `consumedView`.
@@ -141,6 +157,21 @@ void applyPermutationToVector(SmallVector<T, N> &inVec,
   inVec = auxVec;
 }
 
+/// Utility class used to generate nested loops with ranges described by
+/// `loopRanges` and loop type described by the `iteratorTypes`. `allIvs` is
+/// populated with induction variables for all generated loops on return, with
+/// `fun` used to generate the body of the innermost loop.
+template <typename LoopTy>
+struct GenerateLoopNest {
+  using IndexedValueTy =
+      typename std::conditional<std::is_same<LoopTy, AffineForOp>::value,
+                                AffineIndexedValue, StdIndexedValue>::type;
+  static void doit(MutableArrayRef<Value> allIvs,
+                   ArrayRef<SubViewOp::Range> loopRanges,
+                   ArrayRef<Attribute> iteratorTypes,
+                   std::function<void(void)> fun);
+};
+
 } // namespace linalg
 } // namespace mlir
 

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
index 74da63dafee3..910078875f57 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
@@ -487,80 +487,9 @@ class LinalgScopedEmitter<IndexedValueType, IndexedGenericOp> {
   }
 };
 
-namespace {
-/// Helper struct to generate the loop nest for the op. This factored out here
-/// to be able to partially specialize this for different LoopTy.
-template <typename LoopTy, typename ConcreteOpTy>
-class GenerateLoopNest {
-public:
-  using IndexedValueTy =
-      typename std::conditional<std::is_same<LoopTy, AffineForOp>::value,
-                                AffineIndexedValue, StdIndexedValue>::type;
-  static void doit(ConcreteOpTy linalgOp, ArrayRef<SubViewOp::Range> loopRanges,
-                   MutableArrayRef<Value> allIvs) {
-    GenericLoopNestRangeBuilder<LoopTy>(allIvs, loopRanges)([&] {
-      SmallVector<Value, 4> allIvValues(allIvs.begin(), allIvs.end());
-      LinalgScopedEmitter<IndexedValueTy,
-                          ConcreteOpTy>::emitScalarImplementation(allIvValues,
-                                                                  linalgOp);
-    });
-  }
-};
-
-/// Generates loop nest using scf.parallel. scf.parallel is only used for the
-/// outer parallel loops. All other loops are generated using scf.for
-/// operation.
-template <typename ConcreteOpTy>
-class GenerateLoopNest<scf::ParallelOp, ConcreteOpTy> {
-public:
-  using IndexedValueTy = StdIndexedValue;
-
-  static void doit(ConcreteOpTy linalgOp, ArrayRef<SubViewOp::Range> loopRanges,
-                   MutableArrayRef<Value> allIvs) {
-    // Only generate scf.parallel for outer consecutive "parallel"
-    // iterator_types.
-    // TODO(ravishankarm): Generate scf.parallel for all "parallel" iterator
-    // types, not just the outer most ones. Also handle "reduction" iterator
-    // types.
-    auto nOuterPar = linalgOp.iterator_types()
-                         .getValue()
-                         .take_while([](Attribute attr) {
-                           return attr.cast<StringAttr>().getValue() ==
-                                  getParallelIteratorTypeName();
-                         })
-                         .size();
-    // If there are no outer parallel loops, then number of loop ops is same as
-    // the number of loops, and they are all scf.for ops.
-    if (nOuterPar) {
-      GenericLoopNestRangeBuilder<scf::ParallelOp>(
-          allIvs.take_front(nOuterPar), loopRanges.take_front(nOuterPar))([&] {
-        GenericLoopNestRangeBuilder<scf::ForOp>(
-            allIvs.drop_front(nOuterPar),
-            loopRanges.drop_front(nOuterPar))([&] {
-          SmallVector<Value, 4> allIvValues(allIvs.begin(), allIvs.end());
-          LinalgScopedEmitter<StdIndexedValue, ConcreteOpTy>::
-              emitScalarImplementation(allIvValues, linalgOp);
-        });
-      });
-    } else {
-      // If there are no parallel loops then fallback to generating all scf.for
-      // operations.
-      GenericLoopNestRangeBuilder<scf::ForOp>(allIvs, loopRanges)([&] {
-        SmallVector<Value, 4> allIvValues(allIvs.begin(), allIvs.end());
-        LinalgScopedEmitter<StdIndexedValue,
-                            ConcreteOpTy>::emitScalarImplementation(allIvValues,
-                                                                    linalgOp);
-      });
-    }
-  }
-};
-} // namespace
-
 template <typename LoopTy, typename ConcreteOpTy>
 Optional<LinalgLoops> linalgOpToLoopsImpl(Operation *op, OpBuilder &builder) {
-  using Impl = GenerateLoopNest<LoopTy, ConcreteOpTy>;
-  using IndexedValueTy =
-      typename GenerateLoopNest<LoopTy, ConcreteOpTy>::IndexedValueTy;
+  using IndexedValueTy = typename GenerateLoopNest<LoopTy>::IndexedValueTy;
 
   ScopedContext scope(builder, op->getLoc());
 
@@ -591,7 +520,13 @@ Optional<LinalgLoops> linalgOpToLoopsImpl(Operation *op, OpBuilder &builder) {
       emitLoopRanges(scope.getBuilderRef(), scope.getLocation(), invertedMap,
                      getViewSizes(builder, linalgOp));
   assert(loopRanges.size() == allIvs.size());
-  Impl::doit(linalgOp, loopRanges, allIvs);
+  GenerateLoopNest<LoopTy>::doit(
+      allIvs, loopRanges, linalgOp.iterator_types().getValue(), [&] {
+        SmallVector<Value, 4> allIvValues(allIvs.begin(), allIvs.end());
+        LinalgScopedEmitter<IndexedValueTy,
+                            ConcreteOpTy>::emitScalarImplementation(allIvValues,
+                                                                    linalgOp);
+      });
   // Number of loop ops might be different from the number of ivs since some
   // loops like affine.parallel and scf.parallel have multiple ivs.
   llvm::SetVector<Operation *> loopSet;

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index 2d875d4e95e4..5b4fec4bbf20 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -376,7 +376,11 @@ Optional<TiledLinalgOp> static tileLinalgOpImpl(
   // 3. Create the tiled loops.
   LinalgOp res = op;
   SmallVector<Value, 4> ivs(loopRanges.size());
-  GenericLoopNestRangeBuilder<LoopTy>(ivs, loopRanges)([&] {
+  SmallVector<Attribute, 4> iteratorTypes =
+      llvm::to_vector<4>(op.iterator_types().cast<ArrayAttr>().getValue());
+  if (!options.interchangeVector.empty())
+    applyPermutationToVector(iteratorTypes, options.interchangeVector);
+  GenerateLoopNest<LoopTy>::doit(ivs, loopRanges, iteratorTypes, [&] {
     auto &b = ScopedContext::getBuilderRef();
     auto loc = ScopedContext::getLocation();
     SmallVector<Value, 4> ivValues(ivs.begin(), ivs.end());
@@ -384,8 +388,8 @@ Optional<TiledLinalgOp> static tileLinalgOpImpl(
     // If we have to apply a permutation to the tiled loop nest, we have to
     // reorder the induction variables This permutation is the right one
     // assuming that loopRanges have previously been permuted by
-    // (i,j,k)->(k,i,j) So this permutation should be the inversePermutation of
-    // that one: (d0,d1,d2)->(d2,d0,d1)
+    // (i,j,k)->(k,i,j) So this permutation should be the inversePermutation
+    // of that one: (d0,d1,d2)->(d2,d0,d1)
     if (!options.interchangeVector.empty())
       ivValues = applyMapToValues(b, loc, invPermutationMap, ivValues);
 

diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 4f86b934172b..cd8b17650bb1 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -14,6 +14,7 @@
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
+#include "mlir/Dialect/SCF/EDSC/Builders.h"
 #include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/IR/AffineExpr.h"
@@ -101,3 +102,91 @@ mlir::linalg::getAssumedNonViewOperands(LinalgOp linalgOp) {
   }
   return res;
 }
+
+bool mlir::linalg::isParallelIteratorType(Attribute attr) {
+  if (auto strAttr = attr.dyn_cast<StringAttr>()) {
+    return strAttr.getValue() == getParallelIteratorTypeName();
+  }
+  return false;
+}
+
+bool mlir::linalg::isReductionIteratorType(Attribute attr) {
+  if (auto strAttr = attr.dyn_cast<StringAttr>()) {
+    return strAttr.getValue() == getReductionIteratorTypeName();
+  }
+  return false;
+}
+
+bool mlir::linalg::isWindowIteratorType(Attribute attr) {
+  if (auto strAttr = attr.dyn_cast<StringAttr>()) {
+    return strAttr.getValue() == getWindowIteratorTypeName();
+  }
+  return false;
+}
+
+/// Explicit instantiation of loop nest generator for different loop types.
+template struct mlir::linalg::GenerateLoopNest<scf::ForOp>;
+template struct mlir::linalg::GenerateLoopNest<scf::ParallelOp>;
+template struct mlir::linalg::GenerateLoopNest<AffineForOp>;
+
+/// Specializations of the loop nest generator. The scf.for and affine.for
+/// versions ignore `iteratorTypes`; the scf.parallel version generates
+/// sequential loops for iterator types that are not parallel.
+template <>
+void mlir::linalg::GenerateLoopNest<scf::ForOp>::doit(
+    MutableArrayRef<Value> allIvs, ArrayRef<SubViewOp::Range> loopRanges,
+    ArrayRef<Attribute> iteratorTypes, std::function<void(void)> fun) {
+  edsc::GenericLoopNestRangeBuilder<scf::ForOp>(allIvs, loopRanges)(fun);
+}
+
+template <>
+void mlir::linalg::GenerateLoopNest<AffineForOp>::doit(
+    MutableArrayRef<Value> allIvs, ArrayRef<SubViewOp::Range> loopRanges,
+    ArrayRef<Attribute> iteratorTypes, std::function<void(void)> fun) {
+  edsc::GenericLoopNestRangeBuilder<AffineForOp>(allIvs, loopRanges)(fun);
+}
+
+template <>
+void mlir::linalg::GenerateLoopNest<scf::ParallelOp>::doit(
+    MutableArrayRef<Value> allIvs, ArrayRef<SubViewOp::Range> loopRanges,
+    ArrayRef<Attribute> iteratorTypes, std::function<void(void)> fun) {
+  // Check if there is nothing to do here. This is also the recursion
+  // termination.
+  if (loopRanges.empty())
+    return;
+  size_t nOuterPar = iteratorTypes.take_front(loopRanges.size())
+                         .take_while(isParallelIteratorType)
+                         .size();
+  if (nOuterPar == 0 && loopRanges.size() == 1)
+    // Generate the sequential for loop for the remaining non-parallel loop.
+    return GenerateLoopNest<scf::ForOp>::doit(allIvs, loopRanges, iteratorTypes,
+                                              fun);
+  if (nOuterPar == 0) {
+    // The immediate outer loop is not parallel. Generate a scf.for op for this
+    // loop, but there might be subsequent loops that are parallel. Use
+    // recursion to find those.
+    auto nestedFn = [&]() {
+      GenerateLoopNest<scf::ParallelOp>::doit(allIvs.drop_front(),
+                                              loopRanges.drop_front(),
+                                              iteratorTypes.drop_front(), fun);
+    };
+    return GenerateLoopNest<scf::ForOp>::doit(allIvs[0], loopRanges[0],
+                                              iteratorTypes[0], nestedFn);
+  }
+  if (nOuterPar == loopRanges.size()) {
+    // All loops are parallel, so generate the scf.parallel op.
+    return edsc::GenericLoopNestRangeBuilder<scf::ParallelOp>(allIvs,
+                                                              loopRanges)(fun);
+  }
+  // Generate scf.parallel for the outer parallel loops. The next inner loop is
+  // sequential, but there might be more parallel loops after that. So recurse
+  // into the same method.
+  auto nestedFn = [&]() {
+    GenerateLoopNest<scf::ParallelOp>::doit(
+        allIvs.drop_front(nOuterPar), loopRanges.drop_front(nOuterPar),
+        iteratorTypes.drop_front(nOuterPar), fun);
+  };
+  return GenerateLoopNest<scf::ParallelOp>::doit(
+      allIvs.take_front(nOuterPar), loopRanges.take_front(nOuterPar),
+      iteratorTypes.take_front(nOuterPar), nestedFn);
+}

diff --git a/mlir/test/Dialect/Linalg/parallel_loops.mlir b/mlir/test/Dialect/Linalg/parallel_loops.mlir
index abe9cccc8b75..2174ddc3c269 100644
--- a/mlir/test/Dialect/Linalg/parallel_loops.mlir
+++ b/mlir/test/Dialect/Linalg/parallel_loops.mlir
@@ -57,6 +57,42 @@ func @lower_outer_parallel(%A: memref<?x?x?x?xf32>, %B: memref<?x?x?xf32>) {
 //   CHECK-DAG: %[[D3:.*]] = dim %{{.*}}, 3
 //       CHECK: scf.parallel (%[[IV0:.*]], %[[IV1:.*]]) = (%[[C0]], %[[C0]]) to (%[[D0]], %[[D1]]) step (%[[C1]], %[[C1]])
 //       CHECK:   scf.for %[[IV2:.*]] = %[[C0]] to %[[D2]] step %[[C1]]
-//       CHECK:     scf.for %[[IV3:.*]] = %[[C0]] to %[[D3]] step %[[C1]]
+//       CHECK:     scf.parallel (%[[IV3:.*]]) = (%[[C0]]) to (%[[D3]]) step (%[[C1]])
 //       CHECK:       load %{{.*}}[%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]]]
 //       CHECK:       store %{{.*}}, %{{.*}}[%[[IV0]], %[[IV1]], %[[IV3]]]
+
+// -----
+
+#accesses = [
+  affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>,
+  affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d4, d5)>
+]
+#trait = {
+  args_in = 1,
+  args_out = 1,
+  iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"],
+  indexing_maps = #accesses
+}
+
+func @lower_mixed_parallel(%A: memref<?x?x?x?x?x?xf32>, %B: memref<?x?x?x?xf32>) {
+  linalg.generic #trait %A, %B {
+    ^bb0(%a: f32, %b: f32):
+      linalg.yield %a: f32
+  } : memref<?x?x?x?x?x?xf32>, memref<?x?x?x?xf32>
+  return
+}
+// CHECK-LABEL: @lower_mixed_parallel
+//   CHECK-DAG: %[[C0:.*]] = constant 0
+//   CHECK-DAG: %[[C1:.*]] = constant 1
+//   CHECK-DAG: %[[D0:.*]] = dim %{{.*}}, 0
+//   CHECK-DAG: %[[D1:.*]] = dim %{{.*}}, 1
+//   CHECK-DAG: %[[D2:.*]] = dim %{{.*}}, 2
+//   CHECK-DAG: %[[D3:.*]] = dim %{{.*}}, 3
+//   CHECK-DAG: %[[D4:.*]] = dim %{{.*}}, 4
+//   CHECK-DAG: %[[D5:.*]] = dim %{{.*}}, 5
+//       CHECK: scf.parallel (%[[IV0:.*]], %[[IV1:.*]]) = (%[[C0]], %[[C0]]) to (%[[D0]], %[[D1]]) step (%[[C1]], %[[C1]])
+//       CHECK:   scf.for %[[IV2:.*]] = %[[C0]] to %[[D2]] step %[[C1]]
+//       CHECK:     scf.parallel (%[[IV3:.*]], %[[IV4:.*]]) = (%[[C0]], %[[C0]]) to (%[[D3]], %[[D4]]) step (%[[C1]], %[[C1]])
+//       CHECK:       scf.for %[[IV5:.*]] = %[[C0]] to %[[D5]] step %[[C1]]
+//       CHECK:       load %{{.*}}[%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]], %[[IV4]], %[[IV5]]]
+//       CHECK:       store %{{.*}}, %{{.*}}[%[[IV0]], %[[IV2]], %[[IV4]], %[[IV5]]]

diff --git a/mlir/test/Dialect/Linalg/tile_parallel_reduce.mlir b/mlir/test/Dialect/Linalg/tile_parallel_reduce.mlir
new file mode 100644
index 000000000000..bfa14570aef1
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/tile_parallel_reduce.mlir
@@ -0,0 +1,108 @@
+// RUN: mlir-opt %s -linalg-tile-to-parallel-loops="linalg-tile-sizes=2,4,8" -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -linalg-tile-to-parallel-loops="linalg-tile-sizes=2" -split-input-file | FileCheck %s -check-prefix=TILE1
+// RUN: mlir-opt %s -linalg-tile-to-parallel-loops="linalg-tile-sizes=2,4" -split-input-file | FileCheck %s -check-prefix=TILE2
+
+func @gemm(%arg0 : memref<?x?xf32>,
+           %arg1 : memref<?x?xf32>,
+           %arg2 : memref<?x?xf32>)
+{
+  linalg.matmul(%arg0, %arg1, %arg2)
+    : memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>
+  return
+}
+// CHECK-LABEL: func @gemm
+//   CHECK-DAG:   %[[C2:.*]] = constant 2 : index
+//   CHECK-DAG:   %[[C4:.*]] = constant 4 : index
+//   CHECK-DAG:   %[[C8:.*]] = constant 8 : index
+//       CHECK:   scf.parallel (%[[ARG3:.*]], %[[ARG4:.*]]) =
+//  CHECK-SAME:     step (%[[C2]], %[[C4]])
+//       CHECK:     scf.for %[[ARG5:.*]] =
+//  CHECK-SAME:       step %[[C8]]
+//       CHECK:       %[[SV1:.*]] = subview %{{.*}}[%[[ARG3]], %[[ARG5]]]
+//       CHECK:       %[[SV2:.*]] = subview %{{.*}}[%[[ARG5]], %[[ARG4]]]
+//       CHECK:       %[[SV3:.*]] = subview %{{.*}}[%[[ARG3]], %[[ARG4]]]
+//       CHECK:       linalg.matmul(%[[SV1]], %[[SV2]], %[[SV3]])
+
+// TILE1-LABEL: func @gemm
+//   TILE1-DAG:   %[[C2:.*]] = constant 2 : index
+//       TILE1:   scf.parallel (%[[ARG3:.*]]) =
+//  TILE1-SAME:     step (%[[C2]])
+//       TILE1:     %[[SV1:.*]] = subview %{{.*}}[%[[ARG3]], 0]
+//       TILE1:     %[[SV3:.*]] = subview %{{.*}}[%[[ARG3]], 0]
+//   TILE1-NOT:     subview
+//       TILE1:     linalg.matmul(%[[SV1]], %{{.*}}, %[[SV3]])
+
+// TILE2-LABEL: func @gemm
+//   TILE2-DAG:   %[[C2:.*]] = constant 2 : index
+//   TILE2-DAG:   %[[C4:.*]] = constant 4 : index
+//       TILE2:   scf.parallel (%[[ARG3:.*]], %[[ARG4:.*]]) =
+//  TILE2-SAME:     step (%[[C2]], %[[C4]])
+//       TILE2:       %[[SV1:.*]] = subview %{{.*}}[%[[ARG3]], 0]
+//       TILE2:       %[[SV2:.*]] = subview %{{.*}}[0, %[[ARG4]]]
+//       TILE2:       %[[SV3:.*]] = subview %{{.*}}[%[[ARG3]], %[[ARG4]]]
+//       TILE2:       linalg.matmul(%[[SV1]], %[[SV2]], %[[SV3]])
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1)>
+#accesses = [#map0, #map1, #map2]
+#trait = {
+  args_in = 2 : i64,
+  args_out = 1 : i64,
+  iterator_types = ["reduction", "parallel", "reduction"],
+  indexing_maps = #accesses
+}
+
+func @reduction(%arg0 : memref<?x?x?xf32>,
+                %arg1 : memref<?x?xf32>,
+                %arg2 : memref<?xf32>)
+{
+  linalg.generic #trait %arg0, %arg1, %arg2 {
+  ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32):
+    %0 = addf %arg3, %arg4 : f32
+    %1 = addf %0, %arg5 : f32
+    linalg.yield %1 : f32
+  } : memref<?x?x?xf32>, memref<?x?xf32>, memref<?xf32>
+  return
+}
+
+// CHECK-LABEL: func @reduction
+//   CHECK-DAG:   %[[C2:.*]] = constant 2 : index
+//   CHECK-DAG:   %[[C4:.*]] = constant 4 : index
+//   CHECK-DAG:   %[[C8:.*]] = constant 8 : index
+//       CHECK:   scf.for %[[ARG3:.*]] =
+//  CHECK-SAME:     step %[[C2]]
+//       CHECK:     scf.parallel (%[[ARG4:.*]]) =
+//  CHECK-SAME:       step (%[[C4]])
+//       CHECK:       scf.for %[[ARG5:.*]] =
+//  CHECK-SAME:         step %[[C8]]
+//       CHECK:         %[[SV1:.*]] = subview %{{.*}}[%[[ARG3]], %[[ARG4]], %[[ARG5]]]
+//       CHECK:         %[[SV2:.*]] = subview %{{.*}}[%[[ARG3]], %[[ARG5]]]
+//       CHECK:         %[[SV3:.*]] = subview %{{.*}}[%[[ARG4]]]
+//       CHECK:         linalg.generic
+//  CHECK-SAME:           %[[SV1]], %[[SV2]], %[[SV3]]
+
+// TILE1-LABEL: func @reduction
+//   TILE1-DAG:   %[[C2:.*]] = constant 2 : index
+//       TILE1:   scf.for %[[ARG3:.*]] =
+//  TILE1-SAME:     step %[[C2]]
+//       TILE1:         %[[SV1:.*]] = subview %{{.*}}[%[[ARG3]], 0, 0]
+//       TILE1:         %[[SV2:.*]] = subview %{{.*}}[%[[ARG3]], 0]
+//   TILE1-NOT:         subview
+//       TILE1:         linalg.generic
+//  TILE1-SAME:           %[[SV1]], %[[SV2]], %{{.*}}
+
+// TILE2-LABEL: func @reduction
+//   TILE2-DAG:   %[[C2:.*]] = constant 2 : index
+//   TILE2-DAG:   %[[C4:.*]] = constant 4 : index
+//       TILE2:   scf.for %[[ARG3:.*]] =
+//  TILE2-SAME:     step %[[C2]]
+//       TILE2:     scf.parallel (%[[ARG4:.*]]) =
+//  TILE2-SAME:       step (%[[C4]])
+//       TILE2:         %[[SV1:.*]] = subview %{{.*}}[%[[ARG3]], %[[ARG4]], 0]
+//       TILE2:         %[[SV2:.*]] = subview %{{.*}}[%[[ARG3]], 0]
+//       TILE2:         %[[SV3:.*]] = subview %{{.*}}[%[[ARG4]]]
+//       TILE2:         linalg.generic
+//  TILE2-SAME:           %[[SV1]], %[[SV2]], %[[SV3]]

diff --git a/mlir/test/Dialect/Linalg/transform-patterns.mlir b/mlir/test/Dialect/Linalg/transform-patterns.mlir
index ce868d156f6d..4c46c74fe490 100644
--- a/mlir/test/Dialect/Linalg/transform-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/transform-patterns.mlir
@@ -44,7 +44,8 @@ func @matvec(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
 // CHECK-DAG:     %[[c0:.*]] = constant 0 : index
 // CHECK-DAG:     %[[c5:.*]] = constant 5 : index
 // CHECK-DAG:     %[[c6:.*]] = constant 6 : index
-// CHECK:         scf.parallel {{.*}} step (%[[c5]], %[[c6]])
+// CHECK:         scf.parallel {{.*}} step (%[[c5]])
+// CHECK:           scf.for {{.*}} step %[[c6]]
 // CHECK:             linalg.matvec({{.*}}, {{.*}}, {{.*}}) : memref<?x?xf32, #[[STRIDED_2D]]>, memref<?xf32, #[[STRIDED_1D]]>, memref<?xf32, #[[STRIDED_1D]]>
 
 func @matmul(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
@@ -364,3 +365,25 @@ func @aligned_promote_fill(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>) {
 // CHECK:         linalg.fill(%[[v0]], {{%.*}}) : memref<?x?xf32>, f32
 // CHECK:         linalg.copy(%[[s0]], %[[l0]]) : memref<?x?xf32, #map{{.*}}>, memref<?x?xf32, #map{{.*}}>
 // CHECK:         linalg.fill(%[[v0]], %[[cf]]) : memref<?x?xf32>, f32
+
+func @tile_permute_parallel_loop(%arg0: memref<?x?xf32>,
+                                 %arg1: memref<?x?xf32>,
+                                 %arg2: memref<?x?xf32>) {
+  linalg.matmul(%arg0, %arg1, %arg2) {__internal_linalg_transform__ = "par__with_perm__"}
+    : memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>
+  return
+}
+// CHECK-LABEL: func @tile_permute_parallel_loop
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<?x?xf32>
+//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: memref<?x?xf32>
+//  CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: memref<?x?xf32>
+//   CHECK-DAG:   %[[C16:.*]] = constant 16 : index
+//   CHECK-DAG:   %[[C8:.*]] = constant 8 : index
+//   CHECK-DAG:   %[[C4:.*]] = constant 4 : index
+//   CHECK-DAG:   %[[C0:.*]] = constant 0 : index
+//   CHECK-DAG:   %[[D0:.*]] = dim %[[ARG0]], 0
+//   CHECK-DAG:   %[[D1:.*]] = dim %[[ARG0]], 1
+//   CHECK-DAG:   %[[D2:.*]] = dim %[[ARG1]], 1
+//       CHECK:   scf.parallel (%{{.*}}) = (%[[C0]]) to (%[[D2]]) step (%[[C8]])
+//       CHECK:     scf.for %{{.*}} = %[[C0]] to %[[D1]] step %[[C4]]
+//       CHECK:       scf.parallel (%{{.*}}) = (%[[C0]]) to (%[[D0]]) step (%[[C16]])

diff --git a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
index e38153058419..7547e2953ef2 100644
--- a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
@@ -101,6 +101,14 @@ static void applyPatterns(FuncOp funcOp) {
       ctx, LinalgTilingOptions().setTileSizes({5, 6}).setInterchange({1, 0}),
       LinalgMarker({"__with_perm__"}, "L1__with_perm__"));
 
+  patterns.insert<LinalgTilingPattern<MatmulOp>>(
+      ctx,
+      LinalgTilingOptions()
+          .setTileSizes({16, 8, 4})
+          .setInterchange({1, 2, 0})
+          .setLoopType(LinalgTilingLoopType::ParallelLoops),
+      LinalgMarker({"par__with_perm__"}, "after_par__with_perm__"));
+
   //===--------------------------------------------------------------------===//
   // Linalg to loops patterns.
   //===--------------------------------------------------------------------===//


        


More information about the Mlir-commits mailing list