[Mlir-commits] [mlir] a367c57 - [mlir][linalg] Relax tiling constraint when there are multiple destination operands

Guray Ozen llvmlistbot at llvm.org
Fri Sep 9 07:38:43 PDT 2022


Author: Guray Ozen
Date: 2022-09-09T16:38:33+02:00
New Revision: a367c571412d885be453b6adce37589b1ed0e504

URL: https://github.com/llvm/llvm-project/commit/a367c571412d885be453b6adce37589b1ed0e504
DIFF: https://github.com/llvm/llvm-project/commit/a367c571412d885be453b6adce37589b1ed0e504.diff

LOG: [mlir][linalg] Relax tiling constraint when there are multiple destination operands

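This revision relaxes the tiling constraint that rejected ops with more than one destination operand, so such ops can now be tiled to `scf.foreach_thread`. It also adds a test that tiles a `linalg.generic` with two outputs.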

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D132937

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
    mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index 1110b0551f1fe..4b3e260ac8d2e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -235,10 +235,7 @@ static FailureOr<ForeachThreadTilingResult> tileToForeachThreadOpImpl(
   auto hasStrideOne = [](Range r) { return !isConstantIntValue(r.stride, 1); };
   if (llvm::any_of(loopRanges, hasStrideOne))
     return op->emitOpError("only stride-1 supported atm");
-  // TODO: support `getTiledImplementation` with >1 produced tiled ops.
-  auto dest = op.getDestinationOperands(b);
-  if (dest.size() != 1)
-    return op->emitOpError("only single dest operand supported atm");
+  auto destOperands = op.getDestinationOperands(b);
 
   SmallVector<OpFoldResult> nonZeroNumThreads =
       llvm::to_vector(llvm::make_filter_range(numThreads, [](OpFoldResult ofr) {

diff  --git a/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir b/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir
index bf1a2cdeae41c..4c5607e89f1be 100644
--- a/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir
+++ b/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir
@@ -249,3 +249,61 @@ transform.with_pdl_patterns {
     %1:2 = transform.structured.tile_to_foreach_thread_op %0 tile_sizes [%sz, 20]
   }
 }
+
+// -----
+
+// CHECK-DAG: #[[$map0:.+]] = affine_map<(d0) -> (d0 * -15 + 100, 15)>
+// CHECK-DAG: #[[$map1:.+]] = affine_map<(d0) -> (0, d0)>
+// CHECK-DAG: #[[$map2:.+]] = affine_map<(d0) -> (d0 * 15)>
+// CHECK-DAG: #[[$map3:.+]] = affine_map<(d0) -> (d0)>
+
+// CHECK-LABEL: tile_output_multi_1d_static(
+//  CHECK-SAME:   %[[IN1:[0-9a-z]+]]: tensor<100xf32>
+//  CHECK-SAME:   %[[IN2:[0-9a-z]+]]: tensor<100xf32>
+//  CHECK-SAME:   %[[OUT1:[0-9a-z]+]]: tensor<100xf32>
+//  CHECK-SAME:   %[[OUT2:[0-9a-z]+]]: tensor<100xf32>
+  func.func @tile_output_multi_1d_static(%IN1: tensor<100xf32>, %IN2: tensor<100xf32>, 
+                                         %OUT1: tensor<100xf32>, %OUT2: tensor<100xf32>)
+                                         -> (tensor<100xf32>, tensor<100xf32>) {
+//  CHECK-DAG: %[[c0:.+]] = arith.constant 7 :
+//      CHECK: scf.foreach_thread (%[[IV0:.+]]) in (%[[c0]])
+//      CHECK:   %[[TSMIN:.+]] = affine.min #[[$map0]](%[[IV0]])
+//      CHECK:   %[[TS:.+]] = affine.max #[[$map1]](%[[TSMIN]])
+//  CHECK-NOT:   affine.min
+//  CHECK-NOT:   affine.max
+//      CHECK:   %[[LB:.+]] = affine.apply #[[$map2]](%[[IV0]])
+//      CHECK:   %[[tIN1:.+]] = tensor.extract_slice %[[IN1]][%[[LB]]] [%[[TS]]] [1] :
+//      CHECK:   %[[tIN2:.+]] = tensor.extract_slice %[[IN2]][%[[LB]]] [%[[TS]]] [1] :
+//      CHECK:   %[[tOUT1:.+]] = tensor.extract_slice %[[OUT1]][%[[LB]]] [%[[TS]]] [1] :
+//      CHECK:   %[[tOUT2:.+]] = tensor.extract_slice %[[OUT2]][%[[LB]]] [%[[TS]]] [1] :
+//      CHECK:   %[[RES1:[0-9]+]]:[[RES2:[0-9]+]] = linalg.generic
+//      CHECK:   scf.foreach_thread.perform_concurrently
+// CHECK-NEXT:    tensor.parallel_insert_slice %[[RES1]]#1 into %[[OUT2]][%[[LB]]] [%[[TS]]] [1] :
+// CHECK-NEXT:    tensor.parallel_insert_slice %[[RES1]]#0 into %[[OUT1]][%[[LB]]] [%[[TS]]] [1] :
+    %res1, %res2 = linalg.generic
+    {
+      indexing_maps = [affine_map<(d0) -> (d0)>,
+                       affine_map<(d0) -> (d0)>,
+                       affine_map<(d0) -> (d0)>,
+                       affine_map<(d0) -> (d0)>],
+      iterator_types = ["parallel"]
+    } ins(%IN1, %IN2 : tensor<100xf32>, tensor<100xf32>)
+      outs(%OUT1, %OUT2 : tensor<100xf32>, tensor<100xf32>) 
+    {
+      ^bb0(%a1: f32, %a2: f32, %a3: f32, %a4: f32):
+        %1 = arith.addf %a1, %a3 : f32
+        %2 = arith.addf %a2, %a4 : f32
+        linalg.yield %1, %2 : f32,f32
+    } -> (tensor<100xf32>, tensor<100xf32>)
+    return %res1, %res2 : tensor<100xf32>, tensor<100xf32>
+  }
+
+  transform.with_pdl_patterns {
+    ^bb0(%arg0: !pdl.operation):
+    transform.sequence %arg0 failures(propagate) {
+      ^bb1(%arg1: !pdl.operation):
+      %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+      %foreach_thread, %tiled_generic = transform.structured.tile_to_foreach_thread_op %0 num_threads [7]
+    }
+  }
+


        


More information about the Mlir-commits mailing list