[Mlir-commits] [mlir] [mlir][linalg] Fix for invalid IR eliminate_empty_tensors (PR #73513)

Spenser Bauman llvmlistbot at llvm.org
Mon Nov 27 05:30:52 PST 2023


https://github.com/sabauma created https://github.com/llvm/llvm-project/pull/73513

The transform.structured.eliminate_empty_tensors transform can produce mis-typed IR when it traverses use-def chains past tensor reshaping operations while searching for sharing candidates. This results in Linalg operations whose output types do not match their 'outs' arguments.

This patch filters out candidate tensor.empty operations whose types do not match the type of the candidate input operand.
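
For illustration, a minimal sketch of the kind of mis-typed IR the rewrite could previously produce. The SSA names and shapes are borrowed from the regression test added below; this is an assumption about the failure shape rather than verbatim transform output:

    // Invalid: the 4-D 'outs' tensor was substituted for the original 5-D
    // tensor.empty, but the fill still carries the 5-D result type, so the
    // verifier rejects the op.
    %filled = linalg.fill
        ins(%fill_value : f32)
        outs(%init0 : tensor<1x128x128x56xf32>) -> tensor<1x128x128x56x1xf32>

With the type check in place, the 5-D tensor.empty is no longer considered a sharing candidate for the 4-D operand, so the elimination simply does not fire in this case.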

From 06d14a7e473563d4842ec42a8ab28177bece0d0d Mon Sep 17 00:00:00 2001
From: Spenser Bauman <sbauman at mathworks.com>
Date: Sun, 26 Nov 2023 17:47:53 -0500
Subject: [PATCH] [mlir][linalg] Fix for invalid IR eliminate_empty_tensors

The transform.structured.eliminate_empty_tensors transform can produce
mis-typed IR when it traverses use-def chains past tensor reshaping
operations while searching for sharing candidates. This results in
Linalg operations whose output types do not match their 'outs'
arguments.

This patch filters out candidate tensor.empty operations whose types
do not match the type of the candidate input operand.
---
 .../Transforms/EliminateEmptyTensors.cpp      |  5 ++-
 ...ot-bufferize-empty-tensor-elimination.mlir | 41 +++++++++++++++++++
 2 files changed, 45 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/EliminateEmptyTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/EliminateEmptyTensors.cpp
index 5a8320bdb287533..f28f8f0d34a4da5 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/EliminateEmptyTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/EliminateEmptyTensors.cpp
@@ -60,7 +60,10 @@ LogicalResult linalg::linalgOpAnchoredEmptyTensorEliminationStep(
       config.alwaysIncludeLeaves = false;
       SetVector<Value> emptyTensors = state.findValueInReverseUseDefChain(
           in->get(), /*condition=*/
-          [&](Value val) { return val.getDefiningOp<tensor::EmptyOp>(); },
+          [&](Value val) {
+            return val.getDefiningOp<tensor::EmptyOp>() &&
+                   val.getType() == in->get().getType();
+          },
           config);
       if (emptyTensors.empty())
         continue;
diff --git a/mlir/test/Dialect/Linalg/one-shot-bufferize-empty-tensor-elimination.mlir b/mlir/test/Dialect/Linalg/one-shot-bufferize-empty-tensor-elimination.mlir
index 0172760576efc51..7b575119c9cc44b 100644
--- a/mlir/test/Dialect/Linalg/one-shot-bufferize-empty-tensor-elimination.mlir
+++ b/mlir/test/Dialect/Linalg/one-shot-bufferize-empty-tensor-elimination.mlir
@@ -42,3 +42,44 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+
+// This test is intended to check that the produced IR does not contain any
+// type errors from sharing empty tensor operations with different types.
+// The verifiers are sufficient to lock down the intended behavior.
+
+// CHECK-LABEL: func.func @collapse_shape_prevents_reuse(
+func.func @collapse_shape_prevents_reuse(%fill_value: f32) -> tensor<1x128x128x56xf32>
+{
+  %init0 = tensor.empty() : tensor<1x128x128x56xf32>
+  %init1 = tensor.empty() : tensor<1x128x128x56x1xf32>
+
+  %filled_tensor = linalg.fill
+    ins(%fill_value : f32)
+    outs(%init1 : tensor<1x128x128x56x1xf32>) -> tensor<1x128x128x56x1xf32>
+
+  %reshaped_tensor = tensor.collapse_shape %filled_tensor [[0], [1], [2], [3, 4]]
+    : tensor<1x128x128x56x1xf32> into tensor<1x128x128x56xf32>
+
+  %bias = linalg.generic {
+    indexing_maps = [#map, #map],
+    iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+  } ins(%reshaped_tensor : tensor<1x128x128x56xf32>)
+    outs(%init0 : tensor<1x128x128x56xf32>) {
+    ^bb0(%in: f32, %out: f32):
+      linalg.yield %in : f32
+  } -> tensor<1x128x128x56xf32>
+
+  return %bias : tensor<1x128x128x56xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.eliminate_empty_tensors %0 : !transform.any_op
+    transform.yield
+  }
+}


