[Mlir-commits] [mlir] 8ea5d19 - [mlir][linalg] update tiling to support linalg index operations.
Tobias Gysi
llvmlistbot at llvm.org
Tue Apr 13 07:37:43 PDT 2021
Author: Tobias Gysi
Date: 2021-04-13T14:36:01Z
New Revision: 8ea5d190ecc701e5a2df7661f21d4b0ad0fbdc29
URL: https://github.com/llvm/llvm-project/commit/8ea5d190ecc701e5a2df7661f21d4b0ad0fbdc29
DIFF: https://github.com/llvm/llvm-project/commit/8ea5d190ecc701e5a2df7661f21d4b0ad0fbdc29.diff
LOG: [mlir][linalg] update tiling to support linalg index operations.
The patch updates the tiling pass to add the tile offsets to the indices returned by the linalg operations.
Differential Revision: https://reviews.llvm.org/D100379
Added:
mlir/test/Dialect/Linalg/tile-indexed.mlir
Modified:
mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
mlir/test/Dialect/Linalg/tile.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index 0c29bc05cb66..becbd36e9fdb 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -172,6 +172,85 @@ static void transformIndexedGenericOpIndices(
}
}
+// All indices returned by IndexOp should be invariant with respect to tiling.
+// Therefore, if an operation is tiled, we have to transform the indices
+// accordingly, i.e. offset them by the values of the corresponding induction
+// variables that are captured implicitly in the body of the op.
+//
+// Example. `linalg.generic` before tiling:
+//
+// #id_2d = (i, j) -> (i, j)
+// #pointwise_2d_trait = {
+// indexing_maps = [#id_2d, #id_2d],
+// iterator_types = ["parallel", "parallel"]
+// }
+// linalg.generic #pointwise_2d_trait %operand, %result {
+// ^bb0(%operand_in: f32, %result_in: f32):
+// %i = linalg.index 0 : index
+// %j = linalg.index 1 : index
+// <some operations that use %i, %j>
+// }: memref<50x100xf32>, memref<50x100xf32>
+//
+// After the tiling pass with tile sizes 10 and 25:
+//
+// #strided = (i, j)[s0, s1, s2] -> (i * s1 + s0 + j * s2)
+//
+// %c1 = constant 1 : index
+// %c0 = constant 0 : index
+// %c25 = constant 25 : index
+// %c10 = constant 10 : index
+// %operand_dim_0 = dim %operand, 0 : memref<50x100xf32>
+// %operand_dim_1 = dim %operand, 1 : memref<50x100xf32>
+// scf.for %k = %c0 to %operand_dim_0 step %c10 {
+// scf.for %l = %c0 to %operand_dim_1 step %c25 {
+// %4 = std.subview %operand[%k, %l][%c10, %c25][%c1, %c1]
+// : memref<50x100xf32> to memref<?x?xf32, #strided>
+// %5 = std.subview %result[%k, %l][%c10, %c25][%c1, %c1]
+// : memref<50x100xf32> to memref<?x?xf32, #strided>
+// linalg.generic #pointwise_2d_trait %4, %5 {
+// ^bb0(%operand_in: f32, %result_in: f32):
+// %i = linalg.index 0 : index
+// %j = linalg.index 1 : index
+// // Indices `k` and `l` are implicitly captured in the body.
+// %transformed_i = affine.apply (d0, d1) -> (d0 + d1)(%i, %k) // %i + %k
+// %transformed_j = affine.apply (d0, d1) -> (d0 + d1)(%j, %l) // %j + %l
+// // Every use of %i, %j is replaced with %transformed_i, %transformed_j
+// <some operations that use %transformed_i, %transformed_j>
+// }: memref<?x?xf32, #strided>, memref<?x?xf32, #strided>
+// }
+// }
+//
+// TODO: Investigate whether mixing implicit and explicit indices
+// leads to a loss of information.
+static void
+transformIndexOps(OpBuilder &b, LinalgOp op, SmallVectorImpl<Value> &ivs,
+ const LoopIndexToRangeIndexMap &loopIndexToRangeIndex) {
+ // An op without a region has no body to rewrite; nothing to do.
+ if (op->getNumRegions() == 0)
+ return;
+ assert(op->getNumRegions() == 1 && op->getRegion(0).getBlocks().size() == 1 &&
+ "expected linalg operation to have one block.");
+ Block &block = op->getRegion(0).front();
+
+ for (IndexOp indexOp :
+ llvm::make_early_inc_range(block.getOps<linalg::IndexOp>())) {
+ auto rangeIndex = loopIndexToRangeIndex.find(indexOp.dim());
+ if (rangeIndex == loopIndexToRangeIndex.end())
+ continue;
+ // Offset the index by the corresponding induction variable and replace all
+ // uses of the original index except the one in the new affine.apply itself.
+ OpBuilder::InsertionGuard g(b);
+ b.setInsertionPointAfter(indexOp);
+ AffineExpr index, iv;
+ bindDims(b.getContext(), index, iv);
+ AffineApplyOp applyOp = b.create<AffineApplyOp>(
+ indexOp.getLoc(), index + iv,
+ ValueRange{indexOp.getResult(), ivs[rangeIndex->second]});
+ indexOp.getResult().replaceAllUsesExcept(
+ applyOp.getResult(), SmallPtrSet<Operation *, 1>{applyOp});
+ }
+}
+
template <typename LoopTy>
static Optional<TiledLinalgOp>
tileLinalgOpImpl(OpBuilder &b, LinalgOp op, ValueRange tileSizes,
@@ -299,8 +378,10 @@ tileLinalgOpImpl(OpBuilder &b, LinalgOp op, ValueRange tileSizes,
},
options.distribution);
- // 3. Transforms index arguments of `linalg.generic` w.r.t. to the tiling.
+ // 3a. Transforms index arguments of `linalg.generic` w.r.t. the tiling.
transformIndexedGenericOpIndices(b, res, ivs, loopIndexToRangeIndex);
+ // 3b. Transform IndexOp results w.r.t. the tiling.
+ transformIndexOps(b, res, ivs, loopIndexToRangeIndex);
// 4. Gather the newly created loops and return them with the new op.
SmallVector<Operation *, 8> loops;
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 00d6a3a1b95f..ccb4b0e3baa4 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -246,8 +246,7 @@ mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern(
LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewriteBase(
Operation *op, PatternRewriter &rewriter, TiledLinalgOp &result) const {
LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
- // TODO: remove hasIndexSemantics check once index ops are supported.
- if (!linalgOp || linalgOp.hasIndexSemantics())
+ if (!linalgOp)
return failure();
if (failed(filter.checkAndNotify(rewriter, linalgOp)))
return failure();
diff --git a/mlir/test/Dialect/Linalg/tile-indexed.mlir b/mlir/test/Dialect/Linalg/tile-indexed.mlir
new file mode 100644
index 000000000000..cd8a85ccd21b
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/tile-indexed.mlir
@@ -0,0 +1,85 @@
+// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=10,25" -split-input-file | FileCheck %s -check-prefix=TILE-10n25
+// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=25,0" -split-input-file | FileCheck %s -check-prefix=TILE-25n0
+// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=0,25" -split-input-file | FileCheck %s -check-prefix=TILE-0n25
+
+func @indexed_vector(%arg0: memref<50xindex>) {
+ linalg.generic {indexing_maps = [affine_map<(i) -> (i)>],
+ iterator_types = ["parallel"]}
+ outs(%arg0 : memref<50xindex>) {
+ ^bb0(%a: index):
+ %i = linalg.index 0 : index
+ linalg.yield %i : index
+ }
+ return
+}
+// TILE-10n25-DAG: [[$MAP:#[a-zA-Z0-9_]*]] = affine_map<(d0, d1) -> (d0 + d1)>
+// TILE-10n25-LABEL: func @indexed_vector
+// TILE-10n25: %[[C10:.*]] = constant 10 : index
+// TILE-10n25: scf.for %[[J:.*]] = {{.*}} step %[[C10]]
+// TILE-10n25: linalg.generic
+// TILE-10n25: %[[I:.*]] = linalg.index 0 : index
+// TILE-10n25: %[[NEW_I:.*]] = affine.apply [[$MAP]](%[[I]], %[[J]])
+// TILE-10n25: linalg.yield %[[NEW_I]] : index
+
+// TILE-25n0-DAG: [[$MAP:#[a-zA-Z0-9_]*]] = affine_map<(d0, d1) -> (d0 + d1)>
+// TILE-25n0-LABEL: func @indexed_vector
+// TILE-25n0: %[[C25:.*]] = constant 25 : index
+// TILE-25n0: scf.for %[[J:.*]] = {{.*}} step %[[C25]]
+// TILE-25n0: linalg.generic
+// TILE-25n0: %[[I:.*]] = linalg.index 0 : index
+// TILE-25n0: %[[NEW_I:.*]] = affine.apply [[$MAP]](%[[I]], %[[J]])
+// TILE-25n0: linalg.yield %[[NEW_I]] : index
+
+// TILE-0n25-LABEL: func @indexed_vector
+// TILE-0n25-NOT: scf.for %[[J:.*]] = {{.*}} step %
+// TILE-0n25: linalg.generic
+
+// -----
+
+func @indexed_matrix(%arg0: memref<50x50xindex>) {
+ linalg.generic {indexing_maps = [affine_map<(i, j) -> (i, j)>],
+ iterator_types = ["parallel", "parallel"]}
+ outs(%arg0 : memref<50x50xindex>) {
+ ^bb0(%a: index):
+ %i = linalg.index 0 : index
+ %j = linalg.index 1 : index
+ %sum = addi %i, %j : index
+ linalg.yield %sum : index
+ }
+ return
+}
+// TILE-10n25-DAG: [[$MAP:#[a-zA-Z0-9_]*]] = affine_map<(d0, d1) -> (d0 + d1)>
+// TILE-10n25-LABEL: func @indexed_matrix
+// TILE-10n25-DAG: %[[C25:.*]] = constant 25 : index
+// TILE-10n25-DAG: %[[C10:.*]] = constant 10 : index
+// TILE-10n25: scf.for %[[K:.*]] = {{.*}} step %[[C10]]
+// TILE-10n25: scf.for %[[L:.*]] = {{.*}} step %[[C25]]
+// TILE-10n25: linalg.generic
+// TILE-10n25: %[[I:.*]] = linalg.index 0 : index
+// TILE-10n25: %[[NEW_I:.*]] = affine.apply [[$MAP]](%[[I]], %[[K]])
+// TILE-10n25: %[[J:.*]] = linalg.index 1 : index
+// TILE-10n25: %[[NEW_J:.*]] = affine.apply [[$MAP]](%[[J]], %[[L]])
+// TILE-10n25: %[[SUM:.*]] = addi %[[NEW_I]], %[[NEW_J]] : index
+// TILE-10n25: linalg.yield %[[SUM]] : index
+
+// TILE-25n0-DAG: [[$MAP:#[a-zA-Z0-9_]*]] = affine_map<(d0, d1) -> (d0 + d1)>
+// TILE-25n0-LABEL: func @indexed_matrix
+// TILE-25n0: %[[C25:.*]] = constant 25 : index
+// TILE-25n0: scf.for %[[L:.*]] = {{.*}} step %[[C25]]
+// TILE-25n0: linalg.generic
+// TILE-25n0: %[[I:.*]] = linalg.index 0 : index
+// TILE-25n0: %[[NEW_I:.*]] = affine.apply [[$MAP]](%[[I]], %[[L]])
+// TILE-25n0: %[[J:.*]] = linalg.index 1 : index
+// TILE-25n0: %[[SUM:.*]] = addi %[[NEW_I]], %[[J]] : index
+// TILE-25n0: linalg.yield %[[SUM]] : index
+
+// TILE-0n25-DAG: [[$MAP:#[a-zA-Z0-9_]*]] = affine_map<(d0, d1) -> (d0 + d1)>
+// TILE-0n25-LABEL: func @indexed_matrix
+// TILE-0n25: %[[C25:.*]] = constant 25 : index
+// TILE-0n25: scf.for %[[L:.*]] = {{.*}} step %[[C25]]
+// TILE-0n25: linalg.generic
+// TILE-0n25: %[[I:.*]] = linalg.index 0 : index
+// TILE-0n25: %[[J:.*]] = linalg.index 1 : index
+// TILE-0n25: %[[NEW_J:.*]] = affine.apply [[$MAP]](%[[J]], %[[L]])
+// TILE-0n25: %[[SUM:.*]] = addi %[[I]], %[[NEW_J]] : index
+// TILE-0n25: linalg.yield %[[SUM]] : index
diff --git a/mlir/test/Dialect/Linalg/tile.mlir b/mlir/test/Dialect/Linalg/tile.mlir
index d8b20904f751..c761bd6cd57e 100644
--- a/mlir/test/Dialect/Linalg/tile.mlir
+++ b/mlir/test/Dialect/Linalg/tile.mlir
@@ -377,18 +377,3 @@ func @pointwise(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>, %arg1: memre
// TILE-234: for
// TILE-234-NOT: for
// TILE-234: linalg.generic
-
-// TILE-2-LABEL: func @index_op
-// TILE-2-NOT: for
-// TILE-2: linalg.generic
-func @index_op(%arg0: memref<?x?xindex>) {
- linalg.generic {
- indexing_maps = [affine_map<(i, j) -> (i, j)>],
- iterator_types = ["parallel", "parallel"]}
- outs(%arg0 : memref<?x?xindex>) {
- ^bb0(%arg1: index): // no predecessors
- %0 = linalg.index 1 : index
- linalg.yield %0 : index
- }
- return
-}
More information about the Mlir-commits
mailing list