[Mlir-commits] [mlir] 90b7817 - [mlir][linalg] Add helper to update IndexOps after tiling (NFC).

Tobias Gysi llvmlistbot at llvm.org
Fri Sep 17 08:19:20 PDT 2021


Author: Tobias Gysi
Date: 2021-09-17T15:17:33Z
New Revision: 90b7817e03af19a8fdc7f32f82e39d0fbf8a9791

URL: https://github.com/llvm/llvm-project/commit/90b7817e03af19a8fdc7f32f82e39d0fbf8a9791
DIFF: https://github.com/llvm/llvm-project/commit/90b7817e03af19a8fdc7f32f82e39d0fbf8a9791.diff

LOG: [mlir][linalg] Add helper to update IndexOps after tiling (NFC).

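Add the addTileLoopIvsToIndexOpResults helper function to shift the IndexOp results by the tile loop induction variables after tiling, and use it in both the tiling and the fusion transformations.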

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D109761

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
    mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
    mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
    mlir/lib/Dialect/Linalg/Utils/Utils.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index fd0a0befcbc29..8776b74045422 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -127,6 +127,11 @@ SmallVector<Value, 4> makeTiledShapes(OpBuilder &builder, Location loc,
                                       ValueRange ivs, ValueRange tileSizes,
                                       ArrayRef<Value> sizeBounds);
 
+/// Offset each IndexOp result in `tiledOp`'s body by the tile loop induction
+/// variable `ivs[dim]` of its dimension; null entries leave the index as is.
+void addTileLoopIvsToIndexOpResults(OpBuilder &b, LinalgOp tiledOp,
+                                    ArrayRef<Value> ivs);
+
 using FusableOpDependencesTy = llvm::MapVector<
     Operation *,
     SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 1>>;

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index 71556903245bd..adce2c74b78c2 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -222,26 +222,12 @@ static LinalgOp fuse(OpBuilder &b, LinalgOp producer,
   }
 
   Operation *clonedOp = producer.clone(b, loc, resultTypes, clonedShapes);
-  // When the producer has index semantics, we have to transform the indices of
-  // the producer according to the tiling of the consumer, i.e. offset them by
-  // the values computed in `loopRanges`.
-  if (producer.hasIndexSemantics()) {
-    assert(clonedOp->getNumRegions() == 1 &&
-           clonedOp->getRegion(0).getBlocks().size() == 1 &&
-           "expected producer to have one block.");
-    // Shift all indices by the tile offset.
-    Block &block = clonedOp->getRegion(0).front();
-    for (IndexOp indexOp : block.getOps<IndexOp>()) {
-      OpBuilder::InsertionGuard g(b);
-      b.setInsertionPointAfter(indexOp);
-      AffineExpr index, offset;
-      bindDims(b.getContext(), index, offset);
-      AffineApplyOp applyOp = b.create<AffineApplyOp>(
-          indexOp.getLoc(), index + offset,
-          ValueRange{indexOp.getResult(), loopRanges[indexOp.dim()].offset});
-      indexOp.getResult().replaceAllUsesExcept(applyOp, applyOp);
-    }
-  }
+
+  // Shift all IndexOp results by the tile offset.
+  SmallVector<Value> allIvs;
+  transform(loopRanges, std::back_inserter(allIvs),
+            [](Range range) { return range.offset; });
+  addTileLoopIvsToIndexOpResults(b, clonedOp, allIvs);
 
   return clonedOp;
 }

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index 57e77c0bec059..f3b9516bacf00 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -129,28 +129,17 @@ makeTiledLoopRanges(OpBuilder &b, Location loc, AffineMap map,
 static void
 transformIndexOps(OpBuilder &b, LinalgOp op, SmallVectorImpl<Value> &ivs,
                   const LoopIndexToRangeIndexMap &loopIndexToRangeIndex) {
-  // Skip operations that have no region attached.
-  if (op->getNumRegions() == 0)
-    return;
-  assert(op->getNumRegions() == 1 && op->getRegion(0).getBlocks().size() == 1 &&
-         "expected linalg operation to have one block.");
-  Block &block = op->getRegion(0).front();
-
-  for (IndexOp indexOp : block.getOps<linalg::IndexOp>()) {
-    auto rangeIndex = loopIndexToRangeIndex.find(indexOp.dim());
+  // Gather, for every loop dimension of `op`, the induction variable of the
+  // tile loop that iterates over it. Untiled dimensions keep a null entry so
+  // addTileLoopIvsToIndexOpResults leaves their IndexOp results unshifted.
+  SmallVector<Value> allIvs(op.getNumLoops(), nullptr);
+  for (unsigned dim = 0, numLoops = allIvs.size(); dim < numLoops; ++dim) {
+    auto rangeIndex = loopIndexToRangeIndex.find(dim);
     if (rangeIndex == loopIndexToRangeIndex.end())
       continue;
-    // Offset the index by the value of the corresponding induction variable and
-    // replace all uses of the previous value.
-    OpBuilder::InsertionGuard g(b);
-    b.setInsertionPointAfter(indexOp);
-    AffineExpr index, iv;
-    bindDims(b.getContext(), index, iv);
-    AffineApplyOp applyOp = b.create<AffineApplyOp>(
-        indexOp.getLoc(), index + iv,
-        ValueRange{indexOp.getResult(), ivs[rangeIndex->second]});
-    indexOp.getResult().replaceAllUsesExcept(applyOp, applyOp);
+    allIvs[dim] = ivs[rangeIndex->second];
   }
+  addTileLoopIvsToIndexOpResults(b, op, allIvs);
 }
 
 // Insert a tile `source` into the destination tensor `dest`. The position at

diff  --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 296956d7425eb..e0b1a4840b41f 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -731,5 +731,37 @@ SmallVector<Value, 4> makeTiledShapes(OpBuilder &b, Location loc,
   return tiledShapes;
 }
 
+/// Shift every IndexOp result in the body of `tiledOp` by the induction
+/// variable `ivs[dim]` of the tile loop that iterates over its dimension.
+/// Entries of `ivs` that are null denote untiled dimensions and are skipped.
+void addTileLoopIvsToIndexOpResults(OpBuilder &b, LinalgOp tiledOp,
+                                    ArrayRef<Value> ivs) {
+  // Nothing to shift if the body never reads the iteration indices.
+  if (!tiledOp.hasIndexSemantics())
+    return;
+  assert(tiledOp->getNumRegions() == 1 &&
+         tiledOp->getRegion(0).getBlocks().size() == 1 &&
+         "expected linalg operation to have one block.");
+  Block &block = tiledOp->getRegion(0).front();
+  for (IndexOp indexOp : block.getOps<IndexOp>()) {
+    assert(indexOp.dim() < ivs.size() &&
+           "expected an entry in `ivs` for every loop dimension");
+    Value iv = ivs[indexOp.dim()];
+    if (!iv)
+      continue;
+    // Insert the shift directly after the IndexOp, ahead of its users.
+    OpBuilder::InsertionGuard guard(b);
+    b.setInsertionPointAfter(indexOp);
+    AffineExpr index, offset;
+    bindDims(b.getContext(), index, offset);
+    AffineApplyOp applyOp = makeComposedAffineApply(
+        b, indexOp.getLoc(), index + offset,
+        ValueRange{indexOp.getResult(), iv});
+    // Redirect all other uses to the shifted value; the apply itself must keep
+    // consuming the original index.
+    indexOp.getResult().replaceAllUsesExcept(applyOp, applyOp);
+  }
+}
+
 } // namespace linalg
 } // namespace mlir


        


More information about the Mlir-commits mailing list