[Mlir-commits] [mlir] Check linalg.generic arguments to prevent crashing when they are deleted (PR #119110)

Renat Idrisov llvmlistbot at llvm.org
Sat Dec 7 19:43:44 PST 2024


https://github.com/parsifal-47 created https://github.com/llvm/llvm-project/pull/119110

During the remove-dead-values pass, dead values are deleted before the dead operations that use them. If such a value appears as an argument to a dead operation, that argument may therefore be null until the operation itself is deleted. This is a temporary state while the IR is being processed.
When remove-dead-values tries to delete a linalg.generic, it calls `isMemoryEffectFree`, which iterates over the operation's arguments, some of which have already been deleted. This causes the crash reported in https://github.com/llvm/llvm-project/issues/118450, which is added as a test.
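As a rough illustration of the guard this patch adds, here is a minimal standalone sketch that does not use MLIR. `ToyValueImpl`, `ToyValue`, and `collectReadEffects` are hypothetical stand-ins invented for this example. The only point is that the null check has to come before the type query, so that short-circuit evaluation skips arguments that are already gone.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-ins for illustration only; these are not MLIR classes.
struct ToyValueImpl {
  bool isMemRef; // true if the value has a memref-like type
};

// A nullable handle, loosely modelled on an operand that can be
// temporarily null while a pass is deleting values.
struct ToyValue {
  const ToyValueImpl *impl = nullptr;

  ToyValue() = default;
  explicit ToyValue(const ToyValueImpl *i) : impl(i) {}

  explicit operator bool() const { return impl != nullptr; }

  // Querying the type of a null handle is invalid; the assert makes
  // that misuse visible instead of silently reading through nullptr.
  bool hasMemRefType() const {
    assert(impl && "type queried on a null value");
    return impl->isMemRef;
  }
};

// Returns the indices of the operands that would receive a read effect.
// Null operands are skipped before their type is inspected.
std::vector<std::size_t>
collectReadEffects(const std::vector<ToyValue> &operands) {
  std::vector<std::size_t> reads;
  for (std::size_t i = 0; i < operands.size(); ++i) {
    const ToyValue &operand = operands[i];
    if (!operand || !operand.hasMemRefType())
      continue;
    reads.push_back(i);
  }
  return reads;
}
```

Swapping the two conditions in the `if` would trigger the assertion for any null operand, which is the toy analogue of the crash described above.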
Thank you!

From 25e6576899633f370ed7793c20ace8f5ed76789d Mon Sep 17 00:00:00 2001
From: Renat Idrisov <parsifal-47 at users.noreply.github.com>
Date: Sun, 8 Dec 2024 03:36:52 +0000
Subject: [PATCH] Check linalg.generic arguments to prevent crashing when they
 are deleted

---
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp     |  4 +--
 mlir/test/Transforms/remove-dead-values.mlir | 26 ++++++++++++++++++++
 2 files changed, 28 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index d9840e3923c4f7..c5b1a3b55126c8 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1199,7 +1199,7 @@ static void getGenericEffectsImpl(
         &effects,
     LinalgOp linalgOp) {
   for (auto [index, operand] : llvm::enumerate(linalgOp.getDpsInputs())) {
-    if (!llvm::isa<MemRefType>(operand.getType()))
+    if (!operand || !llvm::isa<MemRefType>(operand.getType()))
       continue;
     effects.emplace_back(
         MemoryEffects::Read::get(), &linalgOp->getOpOperand(index), /*stage=*/0,
@@ -1207,7 +1207,7 @@ static void getGenericEffectsImpl(
   }
 
   for (OpOperand &operand : linalgOp.getDpsInitsMutable()) {
-    if (!llvm::isa<MemRefType>(operand.get().getType()))
+    if (!operand.get() || !llvm::isa<MemRefType>(operand.get().getType()))
       continue;
     if (linalgOp.payloadUsesValueFromOperand(&operand)) {
       effects.emplace_back(MemoryEffects::Read::get(), &operand, /*stage=*/0,
diff --git a/mlir/test/Transforms/remove-dead-values.mlir b/mlir/test/Transforms/remove-dead-values.mlir
index 9273ac01e7ccec..fe7bcbc7c490b6 100644
--- a/mlir/test/Transforms/remove-dead-values.mlir
+++ b/mlir/test/Transforms/remove-dead-values.mlir
@@ -73,6 +73,32 @@ func.func @cleanable_loop_iter_args_value(%arg0: index) -> index {
 
 // -----
 
+// Checking that the arguments of linalg.generic are properly handled
+// All code below is removed as a result of the pass
+//
+#map = affine_map<(d0, d1, d2) -> (0, d1, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+module {
+  func.func @main() {
+    %cst_3 = arith.constant dense<54> : tensor<1x25x13xi32>
+    %cst_7 = arith.constant dense<11> : tensor<1x25x13xi32>
+    // CHECK-NOT: arith.constant
+    %0 = tensor.empty() : tensor<1x25x13xi32>
+    // CHECK-NOT: tensor
+    %1 = linalg.generic {indexing_maps = [#map, #map, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_3, %cst_7 : tensor<1x25x13xi32>, tensor<1x25x13xi32>) outs(%0 : tensor<1x25x13xi32>) {
+    // CHECK-NOT: linalg.generic
+    ^bb0(%in: i32, %in_15: i32, %out: i32):
+      %29 = arith.xori %in, %in_15 : i32
+      // CHECK-NOT: arith.xori
+      linalg.yield %29 : i32
+      // CHECK-NOT: linalg.yield
+    } -> tensor<1x25x13xi32>
+    return
+  }
+}
+
+// -----
+
 // Note that this cleanup cannot be done by the `canonicalize` pass.
 //
 // CHECK-LABEL: func.func private @clean_func_op_remove_argument_and_return_value() {


