[Mlir-commits] [mlir] 074ab23 - [mlir][Linalg] Refactor Linalg creation of loops to allow passing iterArgs - NFC

Nicolas Vasilache llvmlistbot at llvm.org
Tue Sep 29 06:51:41 PDT 2020


Author: Nicolas Vasilache
Date: 2020-09-29T09:51:11-04:00
New Revision: 074ab233ed620c1afa44e5bc2d86ab448a9ce1ed

URL: https://github.com/llvm/llvm-project/commit/074ab233ed620c1afa44e5bc2d86ab448a9ce1ed
DIFF: https://github.com/llvm/llvm-project/commit/074ab233ed620c1afa44e5bc2d86ab448a9ce1ed.diff

LOG: [mlir][Linalg] Refactor Linalg creation of loops to allow passing iterArgs - NFC

This revision changes the signatures of the helper functions that Linalg uses to create loops so that they can also take iterArgs.
iterArgs are asserted empty to ensure no functional change.
This is a mechanical change in preparation for tiling of Linalg on tensors, to avoid polluting that implementation with an NFC change.

Differential Revision: https://reviews.llvm.org/D88480

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
    mlir/include/mlir/Dialect/SCF/EDSC/Builders.h
    mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
    mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
    mlir/lib/Dialect/Linalg/Utils/Utils.cpp
    mlir/lib/Dialect/SCF/EDSC/Builders.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index 35353adf11ed..aca5a981b003 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -214,10 +214,11 @@ struct GenerateLoopNest {
       typename std::conditional<std::is_same<LoopTy, AffineForOp>::value,
                                 AffineIndexedValue, StdIndexedValue>::type;
 
-  static void doit(ArrayRef<SubViewOp::Range> loopRanges,
-                   ArrayRef<Attribute> iteratorTypes,
-                   function_ref<void(ValueRange)> bodyBuilderFn,
-                   Optional<LinalgLoopDistributionOptions> = None);
+  static void
+  doit(ArrayRef<SubViewOp::Range> loopRanges, ValueRange iterArgInitValues,
+       ArrayRef<Attribute> iteratorTypes,
+       function_ref<scf::ValueVector(ValueRange, ValueRange)> bodyBuilderFn,
+       Optional<LinalgLoopDistributionOptions> = None);
 };
 
 } // namespace linalg

diff --git a/mlir/include/mlir/Dialect/SCF/EDSC/Builders.h b/mlir/include/mlir/Dialect/SCF/EDSC/Builders.h
index 50adec2f9b8b..fe8df4c2d0e4 100644
--- a/mlir/include/mlir/Dialect/SCF/EDSC/Builders.h
+++ b/mlir/include/mlir/Dialect/SCF/EDSC/Builders.h
@@ -32,6 +32,10 @@ scf::ValueVector loopNestBuilder(Value lb, Value ub, Value step,
 scf::ValueVector loopNestBuilder(
     Value lb, Value ub, Value step, ValueRange iterArgInitValues,
     function_ref<scf::ValueVector(Value, ValueRange)> fun = nullptr);
+scf::ValueVector loopNestBuilder(
+    ValueRange lbs, ValueRange ubs, ValueRange steps,
+    ValueRange iterArgInitValues,
+    function_ref<scf::ValueVector(ValueRange, ValueRange)> fun = nullptr);
 
 /// Adapters for building if conditions using the builder and the location
 /// stored in ScopedContext. 'thenBody' is mandatory, 'elseBody' can be omitted

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
index d3c90ffab06f..eb452cc40305 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
@@ -515,9 +515,12 @@ Optional<LinalgLoops> linalgOpToLoopsImpl(Operation *op, OpBuilder &builder) {
                                    map, getViewSizes(builder, linalgOp));
   SmallVector<Value, 4> allIvs;
   GenerateLoopNest<LoopTy>::doit(
-      loopRanges, linalgOp.iterator_types().getValue(), [&](ValueRange ivs) {
+      loopRanges, /*iterInitArgs*/ {}, linalgOp.iterator_types().getValue(),
+      [&](ValueRange ivs, ValueRange iterArgs) -> scf::ValueVector {
+        assert(iterArgs.empty() && "unexpected iterArgs");
         allIvs.append(ivs.begin(), ivs.end());
         emitScalarImplementation<IndexedValueTy>(allIvs, linalgOp);
+        return scf::ValueVector{};
       });
   // Number of loop ops might be different from the number of ivs since some
   // loops like affine.parallel and scf.parallel have multiple ivs.

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index daaad2e6fa4b..676caa145c3a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -386,8 +386,8 @@ Optional<TiledLinalgOp> static tileLinalgOpImpl(
   if (!options.interchangeVector.empty())
     applyPermutationToVector(iteratorTypes, options.interchangeVector);
   GenerateLoopNest<LoopTy>::doit(
-      loopRanges, iteratorTypes,
-      [&](ValueRange localIvs) {
+      loopRanges, /*iterArgInitValues*/ {}, iteratorTypes,
+      [&](ValueRange localIvs, ValueRange iterArgs) -> scf::ValueVector {
         auto &b = ScopedContext::getBuilderRef();
         auto loc = ScopedContext::getLocation();
         ivs.assign(localIvs.begin(), localIvs.end());
@@ -406,6 +406,7 @@ Optional<TiledLinalgOp> static tileLinalgOpImpl(
         auto operands = getAssumedNonViewOperands(op);
         views.append(operands.begin(), operands.end());
         res = op.clone(b, loc, views);
+        return scf::ValueVector{};
       },
       options.distribution);
 

diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 585b00189964..204716b40746 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -194,20 +194,23 @@ getLoopRanges(OpBuilder &builder, LinalgOp linalgOp, OperationFolder *folder) {
 /// Specialization to build an scf "for" nest.
 template <>
 void GenerateLoopNest<scf::ForOp>::doit(
-    ArrayRef<SubViewOp::Range> loopRanges, ArrayRef<Attribute> iteratorTypes,
-    function_ref<void(ValueRange)> bodyBuilderFn,
+    ArrayRef<SubViewOp::Range> loopRanges, ValueRange iterArgInitValues,
+    ArrayRef<Attribute> iteratorTypes,
+    function_ref<scf::ValueVector(ValueRange, ValueRange)> bodyBuilderFn,
     Optional<LinalgLoopDistributionOptions>) {
   SmallVector<Value, 4> lbs, ubs, steps;
   unpackRanges(loopRanges, lbs, ubs, steps);
-  edsc::loopNestBuilder(lbs, ubs, steps, bodyBuilderFn);
+  edsc::loopNestBuilder(lbs, ubs, steps, iterArgInitValues, bodyBuilderFn);
 }
 
 /// Specialization to build affine "for" nest.
 template <>
 void GenerateLoopNest<AffineForOp>::doit(
-    ArrayRef<SubViewOp::Range> loopRanges, ArrayRef<Attribute> iteratorTypes,
-    function_ref<void(ValueRange)> bodyBuilderFn,
+    ArrayRef<SubViewOp::Range> loopRanges, ValueRange iterArgInitValues,
+    ArrayRef<Attribute> iteratorTypes,
+    function_ref<scf::ValueVector(ValueRange, ValueRange)> bodyBuilderFn,
     Optional<LinalgLoopDistributionOptions>) {
+  assert(iterArgInitValues.empty() && "unexpected AffineForOp init values");
   SmallVector<Value, 4> lbs, ubs, steps;
   unpackRanges(loopRanges, lbs, ubs, steps);
 
@@ -220,7 +223,11 @@ void GenerateLoopNest<AffineForOp>::doit(
     constantSteps.push_back(op.getValue());
   }
 
-  edsc::affineLoopNestBuilder(lbs, ubs, constantSteps, bodyBuilderFn);
+  auto bodyBuilderWithoutIterArgsFn = [&](ValueRange ivs) {
+    bodyBuilderFn(ivs, {});
+  };
+  edsc::affineLoopNestBuilder(lbs, ubs, constantSteps,
+                              bodyBuilderWithoutIterArgsFn);
 }
 
 /// Update the `lb`, `ub` and `step` to get per processor `lb`, `ub` and `step`.
@@ -357,9 +364,11 @@ generateParallelLoopNest(ValueRange lbs, ValueRange ubs, ValueRange steps,
 /// Specialization for generating a mix of parallel and sequential scf loops.
 template <>
 void GenerateLoopNest<scf::ParallelOp>::doit(
-    ArrayRef<SubViewOp::Range> loopRanges, ArrayRef<Attribute> iteratorTypes,
-    function_ref<void(ValueRange)> bodyBuilderFn,
+    ArrayRef<SubViewOp::Range> loopRanges, ValueRange iterArgInitValues,
+    ArrayRef<Attribute> iteratorTypes,
+    function_ref<scf::ValueVector(ValueRange, ValueRange)> bodyBuilderFn,
     Optional<LinalgLoopDistributionOptions> distributionOptions) {
+  assert(iterArgInitValues.empty() && "unexpected ParallelOp init values");
   // This function may be passed more iterator types than ranges.
   assert(iteratorTypes.size() >= loopRanges.size() &&
          "expected iterator type for all ranges");
@@ -405,7 +414,11 @@ void GenerateLoopNest<scf::ParallelOp>::doit(
     }
   }
   ValueRange lbs(lbsStorage), ubs(ubsStorage), steps(stepsStorage);
-  generateParallelLoopNest(lbs, ubs, steps, iteratorTypes, bodyBuilderFn, ivs,
+  auto bodyBuilderWithoutIterArgsFn = [&](ValueRange ivs) {
+    bodyBuilderFn(ivs, {});
+  };
+  generateParallelLoopNest(lbs, ubs, steps, iteratorTypes,
+                           bodyBuilderWithoutIterArgsFn, ivs,
                            distributionMethod);
 
   assert(ivs.size() == iteratorTypes.size() && "did not generate enough loops");

diff --git a/mlir/lib/Dialect/SCF/EDSC/Builders.cpp b/mlir/lib/Dialect/SCF/EDSC/Builders.cpp
index 2098ca1bf7d0..45097186a248 100644
--- a/mlir/lib/Dialect/SCF/EDSC/Builders.cpp
+++ b/mlir/lib/Dialect/SCF/EDSC/Builders.cpp
@@ -61,6 +61,25 @@ mlir::scf::ValueVector mlir::edsc::loopNestBuilder(
       });
 }
 
+mlir::scf::ValueVector mlir::edsc::loopNestBuilder(
+    ValueRange lbs, ValueRange ubs, ValueRange steps,
+    ValueRange iterArgInitValues,
+    function_ref<scf::ValueVector(ValueRange, ValueRange)> fun) {
+  // Delegates actual construction to scf::buildLoopNest by wrapping `fun` into
+  // the expected function interface.
+  assert(ScopedContext::getContext() && "EDSC ScopedContext not set up");
+  return mlir::scf::buildLoopNest(
+      ScopedContext::getBuilderRef(), ScopedContext::getLocation(), lbs, ubs,
+      steps, iterArgInitValues,
+      [&](OpBuilder &builder, Location loc, ValueRange ivs, ValueRange args) {
+        ScopedContext context(builder, loc);
+        if (fun)
+          return fun(ivs, args);
+        return scf::ValueVector(iterArgInitValues.begin(),
+                                iterArgInitValues.end());
+      });
+}
+
 static std::function<void(OpBuilder &, Location)>
 wrapIfBody(function_ref<scf::ValueVector()> body, TypeRange expectedTypes) {
   (void)expectedTypes;


        


More information about the Mlir-commits mailing list