[Mlir-commits] [mlir] c214cee - [mlir] improve error handling in Linalg op splitting

Alex Zinenko llvmlistbot at llvm.org
Fri Jan 6 09:35:15 PST 2023


Author: Alex Zinenko
Date: 2023-01-06T18:35:08+01:00
New Revision: c214cee772f7ce9c9128384b2dd1640f12c9feac

URL: https://github.com/llvm/llvm-project/commit/c214cee772f7ce9c9128384b2dd1640f12c9feac
DIFF: https://github.com/llvm/llvm-project/commit/c214cee772f7ce9c9128384b2dd1640f12c9feac.diff

LOG: [mlir] improve error handling in Linalg op splitting

In several cases, the splitting may be known to be a no-op, i.e., to produce
no second part. Thread this information through the transform utilities
to the transform dialect, and differentiate it from the error state.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D141138

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
    mlir/lib/Dialect/Linalg/Transforms/Split.cpp
    mlir/test/Dialect/Linalg/transform-op-split.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index cc2b63d2786ee..3be4ee72cd5a5 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -162,6 +162,9 @@ FailureOr<Operation *> fuseElementwiseOps(RewriterBase &rewriter,
 
 /// Split the given `op` into two parts along the given iteration space
 /// `dimension` at the specified `splitPoint`, and return the two parts.
+/// If the second part is statically known to be empty, do not create it
+/// and return nullptr instead. Error state is signalled by returning
+/// a pair of nullptrs.
 ///
 /// For example, the following op:
 ///

diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index dddee38b32710..f170d0bd1199b 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1043,6 +1043,7 @@ DiagnosedSilenceableFailure SplitOp::apply(TransformResults &results,
 
   // Split each target operation.
   SmallVector<Operation *> first, second;
+  Operation *noSecondPart = nullptr;
   for (const auto &pair : llvm::zip(payload, splitPoints)) {
     Operation *target = std::get<0>(pair);
     auto linalgOp = dyn_cast<LinalgOp>(target);
@@ -1067,6 +1068,32 @@ DiagnosedSilenceableFailure SplitOp::apply(TransformResults &results,
     std::tie(first.emplace_back(), second.emplace_back()) = linalg::splitOp(
         rewriter, cast<TilingInterface>(linalgOp.getOperation()),
         getDimension(), std::get<1>(pair));
+
+    // Propagate errors.
+    if (!first.back() && !second.back()) {
+      auto diag = emitDefiniteFailure() << "internal failure in splitting";
+      diag.attachNote(target->getLoc()) << "target op";
+      return diag;
+    }
+
+    // Do not add null second parts.
+    if (!second.back()) {
+      noSecondPart = target;
+      second.pop_back();
+    }
+  }
+
+  if (second.size() != first.size() && !second.empty()) {
+    results.set(getFirst().cast<OpResult>(), {});
+    results.set(getSecond().cast<OpResult>(), {});
+    auto diag =
+        emitSilenceableError()
+        << "splitting does not produce the second part for a subset of targets";
+    diag.attachNote() << "expected splitting to produce the second part of all "
+                         "or none of the targets";
+    diag.attachNote(noSecondPart->getLoc())
+        << "first target with no second part";
+    return diag;
   }
 
   results.set(getFirst().cast<OpResult>(), first);

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Split.cpp b/mlir/lib/Dialect/Linalg/Transforms/Split.cpp
index f58182e71c96e..c8c9c0bd4af89 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Split.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Split.cpp
@@ -128,6 +128,10 @@ linalg::splitOp(RewriterBase &rewriter, TilingInterface op, unsigned dimension,
       createSplitPart(rewriter, op.getLoc(), op, offsets, sizes, firstResults,
                       dimension, remainingSize, totalOffset, secondResults);
 
+  // Propagate any errors in part creation.
+  if (!firstPart || !secondPart)
+    return {TilingInterface(), TilingInterface()};
+
   // Replace the original op with the results of the two newly created ops.
   rewriter.replaceOp(op, secondResults);
   return {firstPart, secondPart};

diff  --git a/mlir/test/Dialect/Linalg/transform-op-split.mlir b/mlir/test/Dialect/Linalg/transform-op-split.mlir
index ba5ff6ec6bfec..1d7f15efe73cb 100644
--- a/mlir/test/Dialect/Linalg/transform-op-split.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-split.mlir
@@ -46,6 +46,16 @@ func.func @one_d_static(%arg0: tensor<100xf32>, %arg1: tensor<100xf32>) -> tenso
   return %0 : tensor<100xf32>
 }
 
+// -----
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+  %1:2 = transform.structured.split %0 after 42 { dimension = 0 }
+}
+
+func.func private @elem(%arg0: f32, %arg1: index, %arg2: index) -> f32
+
 // CHECK-LABEL: @one_d_static_overflow
 // CHECK-SAME:  %[[IN:.+]]: tensor<10xf32>, %[[OUT:.+]]: tensor<10xf32>
 func.func @one_d_static_overflow(%arg0: tensor<10xf32>, %arg1: tensor<10xf32>) -> tensor<10xf32> {
@@ -268,3 +278,45 @@ func.func @one_d_static(%arg0: tensor<100xf32>, %arg1: tensor<100xf32>) -> tenso
   return %0 : tensor<100xf32>
 }
 
+// -----
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+  // expected-error @below {{splitting does not produce the second part for a subset of targets}}
+  // expected-note @below {{expected splitting to produce the second part of all or none of the targets}}
+  %1:2 = transform.structured.split %0 after 142 { dimension = 0 }
+}
+
+func.func private @elem(%arg0: f32, %arg1: index, %arg2: index) -> f32
+
+func.func @split_one_but_not_other(
+    %arg0: tensor<100xf32>, %arg1: tensor<100xf32>,
+    %arg2: tensor<200xf32>, %arg3: tensor<200xf32>)
+    -> (tensor<100xf32>, tensor<200xf32>) {
+  // expected-note @below {{first target with no second part}}
+  %0 = linalg.generic {
+    indexing_maps = [affine_map<(i) -> (i)>, affine_map<(i) -> (i)>],
+    iterator_types = ["parallel"]
+  }
+  ins(%arg0: tensor<100xf32>) outs(%arg1: tensor<100xf32>) {
+  ^bb0(%arg4: f32, %arg5: f32):
+    %i = linalg.index 0 : index
+    %call_res = func.call @elem(%arg4, %i, %i) : (f32, index, index) -> f32
+    linalg.yield %call_res : f32
+  } -> tensor<100xf32>
+
+  %1 = linalg.generic {
+    indexing_maps = [affine_map<(i) -> (i)>, affine_map<(i) -> (i)>],
+    iterator_types = ["parallel"]
+  }
+  ins(%arg2: tensor<200xf32>) outs(%arg3: tensor<200xf32>) {
+  ^bb0(%arg4: f32, %arg5: f32):
+    %i = linalg.index 0 : index
+    %call_res = func.call @elem(%arg4, %i, %i) : (f32, index, index) -> f32
+    linalg.yield %call_res : f32
+  } -> tensor<200xf32>
+
+  return %0, %1 : tensor<100xf32>, tensor<200xf32>
+}
+


        


More information about the Mlir-commits mailing list