[Mlir-commits] [mlir] 8ea5d19 - [mlir][linalg] update tiling to support linalg index operations.

Tobias Gysi llvmlistbot at llvm.org
Tue Apr 13 07:37:43 PDT 2021


Author: Tobias Gysi
Date: 2021-04-13T14:36:01Z
New Revision: 8ea5d190ecc701e5a2df7661f21d4b0ad0fbdc29

URL: https://github.com/llvm/llvm-project/commit/8ea5d190ecc701e5a2df7661f21d4b0ad0fbdc29
DIFF: https://github.com/llvm/llvm-project/commit/8ea5d190ecc701e5a2df7661f21d4b0ad0fbdc29.diff

LOG: [mlir][linalg] update tiling to support linalg index operations.

The patch updates the tiling pass to add the tile offsets to the indices returned by the linalg index operations.

Differential Revision: https://reviews.llvm.org/D100379

Added: 
    mlir/test/Dialect/Linalg/tile-indexed.mlir

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
    mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
    mlir/test/Dialect/Linalg/tile.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index 0c29bc05cb66..becbd36e9fdb 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -172,6 +172,85 @@ static void transformIndexedGenericOpIndices(
   }
 }
 
+// All indices returned by IndexOp should be invariant with respect to tiling.
+// Therefore, if an operation is tiled, we have to transform the indices
+// accordingly, i.e. offset them by the values of the corresponding induction
+// variables that are captured implicitly in the body of the op.
+//
+// Example. `linalg.generic` before tiling:
+//
+// #id_2d = (i, j) -> (i, j)
+// #pointwise_2d_trait = {
+//   indexing_maps = [#id_2d, #id_2d],
+//   iterator_types = ["parallel", "parallel"]
+// }
+// linalg.generic #pointwise_2d_trait %operand, %result {
+//   ^bb0(%operand_in: f32, %result_in: f32):
+//     %i = linalg.index 0 : index
+//     %j = linalg.index 1 : index
+//     <some operations that use %i, %j>
+// }: memref<50x100xf32>, memref<50x100xf32>
+//
+// After the tiling pass with tile sizes 10 and 25:
+//
+// #strided = (i, j)[s0, s1, s2] -> (i * s1 + s0 + j * s2)
+// #add = (d0, d1) -> (d0 + d1)
+// %c1 = constant 1 : index
+// %c0 = constant 0 : index
+// %c25 = constant 25 : index
+// %c10 = constant 10 : index
+// %operand_dim_0 = dim %operand, 0 : memref<50x100xf32>
+// %operand_dim_1 = dim %operand, 1 : memref<50x100xf32>
+// scf.for %k = %c0 to %operand_dim_0 step %c10 {
+//   scf.for %l = %c0 to %operand_dim_1 step %c25 {
+//     %4 = std.subview %operand[%k, %l][%c10, %c25][%c1, %c1]
+//       : memref<50x100xf32> to memref<?x?xf32, #strided>
+//     %5 = std.subview %result[%k, %l][%c10, %c25][%c1, %c1]
+//       : memref<50x100xf32> to memref<?x?xf32, #strided>
+//     linalg.generic #pointwise_2d_trait %4, %5 {
+//     ^bb0(%operand_in: f32, %result_in: f32):
+//       %i = linalg.index 0 : index
+//       %j = linalg.index 1 : index
+//       // The loop IVs %k and %l are implicitly captured in the body.
+//       %transformed_i = affine.apply #add(%i, %k) // offset by %k
+//       %transformed_j = affine.apply #add(%j, %l) // offset by %l
+//       // Other uses of %i, %j are replaced by %transformed_i, %transformed_j
+//       <some operations that use %transformed_i, %transformed_j>
+//     }: memref<?x?xf32, #strided>, memref<?x?xf32, #strided>
+//   }
+// }
+//
+// TODO: Investigate whether mixing implicit and explicit indices
+// does not lead to losing information.
+static void
+transformIndexOps(OpBuilder &b, LinalgOp op, SmallVectorImpl<Value> &ivs,
+                  const LoopIndexToRangeIndexMap &loopIndexToRangeIndex) {
+  // Skip operations that have no region attached.
+  if (op->getNumRegions() == 0)
+    return;
+  assert(op->getNumRegions() == 1 && op->getRegion(0).getBlocks().size() == 1 &&
+         "expected linalg operation to have one block.");
+  Block &block = op->getRegion(0).front();
+
+  for (IndexOp indexOp :
+       llvm::make_early_inc_range(block.getOps<linalg::IndexOp>())) {
+    auto rangeIndex = loopIndexToRangeIndex.find(indexOp.dim());
+    if (rangeIndex == loopIndexToRangeIndex.end())
+      continue;
+    // Offset the index by the corresponding induction variable and redirect
+    // all uses except the new affine.apply, which must keep the raw index.
+    OpBuilder::InsertionGuard g(b);
+    b.setInsertionPointAfter(indexOp);
+    AffineExpr index, iv;
+    bindDims(b.getContext(), index, iv);
+    AffineApplyOp applyOp = b.create<AffineApplyOp>(
+        indexOp.getLoc(), index + iv,
+        ValueRange{indexOp.getResult(), ivs[rangeIndex->second]});
+    indexOp.getResult().replaceAllUsesExcept(
+        applyOp.getResult(), SmallPtrSet<Operation *, 1>{applyOp});
+  }
+}
+
 template <typename LoopTy>
 static Optional<TiledLinalgOp>
 tileLinalgOpImpl(OpBuilder &b, LinalgOp op, ValueRange tileSizes,
@@ -299,8 +378,10 @@ tileLinalgOpImpl(OpBuilder &b, LinalgOp op, ValueRange tileSizes,
       },
       options.distribution);
 
-  // 3. Transforms index arguments of `linalg.generic` w.r.t. to the tiling.
+  // 3a. Transforms index arguments of `linalg.generic` w.r.t. to the tiling.
   transformIndexedGenericOpIndices(b, res, ivs, loopIndexToRangeIndex);
+  // 3b. Transform IndexOp results w.r.t. the tiling.
+  transformIndexOps(b, res, ivs, loopIndexToRangeIndex);
 
   // 4. Gather the newly created loops and return them with the new op.
   SmallVector<Operation *, 8> loops;

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 00d6a3a1b95f..ccb4b0e3baa4 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -246,8 +246,7 @@ mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern(
 LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewriteBase(
     Operation *op, PatternRewriter &rewriter, TiledLinalgOp &result) const {
   LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
-  // TODO: remove hasIndexSemantics check once index ops are supported.
-  if (!linalgOp || linalgOp.hasIndexSemantics())
+  if (!linalgOp)
     return failure();
   if (failed(filter.checkAndNotify(rewriter, linalgOp)))
     return failure();

diff  --git a/mlir/test/Dialect/Linalg/tile-indexed.mlir b/mlir/test/Dialect/Linalg/tile-indexed.mlir
new file mode 100644
index 000000000000..cd8a85ccd21b
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/tile-indexed.mlir
@@ -0,0 +1,85 @@
+// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=10,25" -split-input-file | FileCheck %s -check-prefix=TILE-10n25
+// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=25,0" -split-input-file | FileCheck %s -check-prefix=TILE-25n0
+// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=0,25" -split-input-file | FileCheck %s -check-prefix=TILE-0n25
+// 1-D case: tiled linalg.index results are offset by the tile loop IV.
+func @indexed_vector(%arg0: memref<50xindex>) {
+  linalg.generic {indexing_maps = [affine_map<(i) -> (i)>],
+                  iterator_types = ["parallel"]}
+     outs(%arg0 : memref<50xindex>) {
+    ^bb0(%a: index):
+      %i = linalg.index 0 : index
+      linalg.yield %i : index
+  }
+  return
+}
+// TILE-10n25-DAG: [[$MAP:#[a-zA-Z0-9_]*]] = affine_map<(d0, d1) -> (d0 + d1)>
+// TILE-10n25-LABEL: func @indexed_vector
+// TILE-10n25: %[[C10:.*]] = constant 10 : index
+// TILE-10n25: scf.for %[[J:.*]] = {{.*}} step %[[C10]]
+// TILE-10n25:   linalg.generic
+// TILE-10n25:     %[[I:.*]] = linalg.index 0 : index
+// TILE-10n25:     %[[NEW_I:.*]] = affine.apply [[$MAP]](%[[I]], %[[J]])
+// TILE-10n25:     linalg.yield %[[NEW_I]] : index
+
+// TILE-25n0-DAG: [[$MAP:#[a-zA-Z0-9_]*]] = affine_map<(d0, d1) -> (d0 + d1)>
+// TILE-25n0-LABEL: func @indexed_vector
+// TILE-25n0: %[[C25:.*]] = constant 25 : index
+// TILE-25n0: scf.for %[[J:.*]] = {{.*}} step %[[C25]]
+// TILE-25n0:   linalg.generic
+// TILE-25n0:     %[[I:.*]] = linalg.index 0 : index
+// TILE-25n0:     %[[NEW_I:.*]] = affine.apply [[$MAP]](%[[I]], %[[J]])
+// TILE-25n0:     linalg.yield %[[NEW_I]] : index
+
+// TILE-0n25-LABEL: func @indexed_vector
+// TILE-0n25-NOT: scf.for %[[J:.*]] = {{.*}} step %
+// TILE-0n25: linalg.generic
+
+// -----
+// 2-D case: only tiled dimensions get their linalg.index result offset.
+func @indexed_matrix(%arg0: memref<50x50xindex>) {
+  linalg.generic {indexing_maps = [affine_map<(i, j) -> (i, j)>],
+                  iterator_types = ["parallel", "parallel"]}
+    outs(%arg0 : memref<50x50xindex>) {
+    ^bb0(%a: index):
+      %i = linalg.index 0 : index
+      %j = linalg.index 1 : index
+      %sum = addi %i, %j : index
+      linalg.yield %sum : index
+  }
+  return
+}
+// TILE-10n25-DAG: [[$MAP:#[a-zA-Z0-9_]*]] = affine_map<(d0, d1) -> (d0 + d1)>
+// TILE-10n25-LABEL: func @indexed_matrix
+// TILE-10n25-DAG: %[[C25:.*]] = constant 25 : index
+// TILE-10n25-DAG: %[[C10:.*]] = constant 10 : index
+// TILE-10n25: scf.for %[[K:.*]] = {{.*}} step %[[C10]]
+// TILE-10n25:   scf.for %[[L:.*]] = {{.*}} step %[[C25]]
+// TILE-10n25:     linalg.generic
+// TILE-10n25:       %[[I:.*]] = linalg.index 0 : index
+// TILE-10n25:       %[[NEW_I:.*]] = affine.apply [[$MAP]](%[[I]], %[[K]])
+// TILE-10n25:       %[[J:.*]] = linalg.index 1 : index
+// TILE-10n25:       %[[NEW_J:.*]] = affine.apply [[$MAP]](%[[J]], %[[L]])
+// TILE-10n25:       %[[SUM:.*]] = addi %[[NEW_I]], %[[NEW_J]] : index
+// TILE-10n25:       linalg.yield %[[SUM]] : index
+
+// TILE-25n0-DAG: [[$MAP:#[a-zA-Z0-9_]*]] = affine_map<(d0, d1) -> (d0 + d1)>
+// TILE-25n0-LABEL: func @indexed_matrix
+// TILE-25n0: %[[C25:.*]] = constant 25 : index
+// TILE-25n0: scf.for %[[L:.*]] = {{.*}} step %[[C25]]
+// TILE-25n0:   linalg.generic
+// TILE-25n0:     %[[I:.*]] = linalg.index 0 : index
+// TILE-25n0:     %[[NEW_I:.*]] = affine.apply [[$MAP]](%[[I]], %[[L]])
+// TILE-25n0:     %[[J:.*]] = linalg.index 1 : index
+// TILE-25n0:     %[[SUM:.*]] = addi %[[NEW_I]], %[[J]] : index
+// TILE-25n0:     linalg.yield %[[SUM]] : index
+
+// TILE-0n25-DAG: [[$MAP:#[a-zA-Z0-9_]*]] = affine_map<(d0, d1) -> (d0 + d1)>
+// TILE-0n25-LABEL: func @indexed_matrix
+// TILE-0n25: %[[C25:.*]] = constant 25 : index
+// TILE-0n25: scf.for %[[L:.*]] = {{.*}} step %[[C25]]
+// TILE-0n25:   linalg.generic
+// TILE-0n25:     %[[I:.*]] = linalg.index 0 : index
+// TILE-0n25:     %[[J:.*]] = linalg.index 1 : index
+// TILE-0n25:     %[[NEW_J:.*]] = affine.apply [[$MAP]](%[[J]], %[[L]])
+// TILE-0n25:     %[[SUM:.*]] = addi %[[I]], %[[NEW_J]] : index
+// TILE-0n25:     linalg.yield %[[SUM]] : index

diff  --git a/mlir/test/Dialect/Linalg/tile.mlir b/mlir/test/Dialect/Linalg/tile.mlir
index d8b20904f751..c761bd6cd57e 100644
--- a/mlir/test/Dialect/Linalg/tile.mlir
+++ b/mlir/test/Dialect/Linalg/tile.mlir
@@ -377,18 +377,3 @@ func @pointwise(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>, %arg1: memre
 //       TILE-234:     for
 //   TILE-234-NOT:   for
 //       TILE-234:       linalg.generic
-
-// TILE-2-LABEL: func @index_op
-//   TILE-2-NOT:   for
-//       TILE-2:   linalg.generic
-func @index_op(%arg0: memref<?x?xindex>) {
-  linalg.generic {
-    indexing_maps = [affine_map<(i, j) -> (i, j)>],
-    iterator_types = ["parallel", "parallel"]}
-  outs(%arg0 : memref<?x?xindex>) {
-  ^bb0(%arg1: index):   // no predecessors
-    %0 = linalg.index 1 : index
-    linalg.yield %0 : index
-  }
-  return
-}


        


More information about the Mlir-commits mailing list