[Mlir-commits] [mlir] 7be28d8 - [mlir][linalg] Add IndexOp support to fusion on tensors.
Tobias Gysi
llvmlistbot at llvm.org
Mon Sep 20 09:00:37 PDT 2021
Author: Tobias Gysi
Date: 2021-09-20T15:59:35Z
New Revision: 7be28d82b4ce810ef662239a9dba7a1409c1ad49
URL: https://github.com/llvm/llvm-project/commit/7be28d82b4ce810ef662239a9dba7a1409c1ad49
DIFF: https://github.com/llvm/llvm-project/commit/7be28d82b4ce810ef662239a9dba7a1409c1ad49.diff
LOG: [mlir][linalg] Add IndexOp support to fusion on tensors.
This revision depends on https://reviews.llvm.org/D109761 and https://reviews.llvm.org/D109766.
Reviewed By: nicolasvasilache
Differential Revision: https://reviews.llvm.org/D109774
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
mlir/test/Dialect/Linalg/tile-and-fuse-on-tensors.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
index 3c362172a34e..eedda7ddf7be 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -181,6 +181,9 @@ static LinalgOp getTiledProducer(OpBuilder &b, OpResult producerResult,
.getTypes();
LinalgOp clonedOp = producerOp.clone(b, loc, resultTypes, tiledOperands);
+ // Shift all IndexOp results by the tile offset.
+ addTileLoopIvsToIndexOpResults(b, clonedOp, allIvs);
+
return clonedOp;
}
@@ -325,10 +328,6 @@ FailureOr<LinalgOp> TileLoopNest::fuseProducer(OpBuilder &b,
if (!producerResult || !isa<LinalgOp>(producerResult.getOwner()))
return failure();
- // TODO: support producers that have index semantics.
- if (cast<LinalgOp>(producerResult.getOwner()).hasIndexSemantics())
- return failure();
-
// Compute the slice dimensions tiled by `tileLoopNest`.
SmallVector<int64_t> tiledSliceDims =
getTiledSliceDims(producerResult, rootOpOperand, loopDims);
diff --git a/mlir/test/Dialect/Linalg/tile-and-fuse-on-tensors.mlir b/mlir/test/Dialect/Linalg/tile-and-fuse-on-tensors.mlir
index 4ead2391e4ef..1a4ca86aa81d 100644
--- a/mlir/test/Dialect/Linalg/tile-and-fuse-on-tensors.mlir
+++ b/mlir/test/Dialect/Linalg/tile-and-fuse-on-tensors.mlir
@@ -188,3 +188,45 @@ builtin.func @fuse_input_and_output(%arg0: tensor<24x12xf32>,
%2 = linalg.matmul ins(%0, %arg1 : tensor<24x12xf32>, tensor<12x25xf32>) outs(%1 : tensor<24x25xf32>) -> tensor<24x25xf32>
return %2 : tensor<24x25xf32>
}
+
+// -----
+
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0 + d1)>
+#map0 = affine_map<(d0, d1) -> (d1, d0)>
+
+// CHECK: fuse_indexed
+// CHECK-SAME: %[[ARG1:[0-9a-zA-Z]*]]: tensor<12x25xi32>
+builtin.func @fuse_indexed(%arg0: tensor<24x12xi32>,
+ %arg1: tensor<12x25xi32>,
+ %arg2: tensor<24x25xi32>) -> tensor<24x25xi32> {
+ %c0 = constant 0 : index
+ %c12 = constant 12 : index
+ %c25 = constant 25 : index
+ %c24 = constant 24 : index
+ %c4 = constant 4 : index
+ %0 = linalg.generic {indexing_maps = [#map0], iterator_types = ["parallel", "parallel"]} outs(%arg1 : tensor<12x25xi32>) {
+ ^bb0(%arg3: i32): // no predecessors
+ %6 = linalg.index 0 : index
+ %7 = linalg.index 1 : index
+ %8 = addi %6, %7 : index
+ %9 = index_cast %8 : index to i32
+ linalg.yield %9 : i32
+ } -> tensor<12x25xi32>
+
+ // CHECK: scf.for %[[IV0:[0-9a-zA-Z]*]] =
+ // CHECK: scf.for %[[IV1:[0-9a-zA-Z]*]] =
+ // CHECK: scf.for %[[IV2:[0-9a-zA-Z]*]] =
+
+ // Shift the indexes by the slice offsets and swap the offsets due to the transposed indexing map.
+ // CHECK: %[[T1:.*]] = tensor.extract_slice %[[ARG1]]
+ // CHECK-SAME: %[[IV2]], %[[IV0]]
+ // CHECK: linalg.generic {{.*}} outs(%[[T1]]
+ // CHECK: %[[IDX0:.*]] = linalg.index 0
+ // CHECK: %[[IDX0_SHIFTED:.*]] = affine.apply #[[MAP0]](%[[IDX0]], %[[IV0]])
+ // CHECK: %[[IDX1:.*]] = linalg.index 1
+ // CHECK: %[[IDX1_SHIFTED:.*]] = affine.apply #[[MAP0]](%[[IDX1]], %[[IV2]])
+ // CHECK: %{{.*}} = addi %[[IDX0_SHIFTED]], %[[IDX1_SHIFTED]]
+ %1 = linalg.matmul ins(%arg0, %0 : tensor<24x12xi32>, tensor<12x25xi32>) outs(%arg2 : tensor<24x25xi32>) -> tensor<24x25xi32>
+ return %1 : tensor<24x25xi32>
+}
+
More information about the Mlir-commits
mailing list