[Mlir-commits] [mlir] [mlir][linalg] Fix for invalid IR in eliminate_empty_tensors (PR #73513)

Spenser Bauman llvmlistbot at llvm.org
Tue Nov 28 06:28:33 PST 2023


https://github.com/sabauma updated https://github.com/llvm/llvm-project/pull/73513

>From a9cff82ff60cf5ee020faa0d7acc30603928eb38 Mon Sep 17 00:00:00 2001
From: Spenser Bauman <sbauman at mathworks.com>
Date: Sun, 26 Nov 2023 17:47:53 -0500
Subject: [PATCH] [mlir][linalg] Fix for invalid IR in eliminate_empty_tensors

The transform.structured.eliminate_empty_tensors transform can produce
mis-typed IR when traversing use-def chains past tensor reshaping
operations in search of sharing candidates. This results in Linalg
operations whose output types do not match their 'outs' arguments.

This patch filters out candidate tensor.empty operations when their types
do not match the candidate input operand.
---
 .../Transforms/EliminateEmptyTensors.cpp      |  5 +-
 ...ot-bufferize-empty-tensor-elimination.mlir | 86 +++++++++++++++++++
 2 files changed, 90 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/EliminateEmptyTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/EliminateEmptyTensors.cpp
index 5a8320bdb287533..f28f8f0d34a4da5 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/EliminateEmptyTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/EliminateEmptyTensors.cpp
@@ -60,7 +60,10 @@ LogicalResult linalg::linalgOpAnchoredEmptyTensorEliminationStep(
       config.alwaysIncludeLeaves = false;
       SetVector<Value> emptyTensors = state.findValueInReverseUseDefChain(
           in->get(), /*condition=*/
-          [&](Value val) { return val.getDefiningOp<tensor::EmptyOp>(); },
+          [&](Value val) {
+            return val.getDefiningOp<tensor::EmptyOp>() &&
+                   val.getType() == in->get().getType();
+          },
           config);
       if (emptyTensors.empty())
         continue;
diff --git a/mlir/test/Dialect/Linalg/one-shot-bufferize-empty-tensor-elimination.mlir b/mlir/test/Dialect/Linalg/one-shot-bufferize-empty-tensor-elimination.mlir
index 0172760576efc51..761b75d81837321 100644
--- a/mlir/test/Dialect/Linalg/one-shot-bufferize-empty-tensor-elimination.mlir
+++ b/mlir/test/Dialect/Linalg/one-shot-bufferize-empty-tensor-elimination.mlir
@@ -42,3 +42,89 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+#map = affine_map<(d0) -> (d0)>
+
+// This test is intended to check that the produced IR does not contain any
+// type errors from sharing empty tensor operations with different types.
+// The verifiers are sufficient to lock down the intended behavior.
+
+// CHECK-LABEL: func.func @collapse_shape_prevents_reuse(
+func.func @collapse_shape_prevents_reuse(%fill_value: f32) -> tensor<56xf32>
+{
+  %init0 = tensor.empty() : tensor<56xf32>
+  %init1 = tensor.empty() : tensor<56x1xf32>
+
+  %filled_tensor = linalg.fill
+    ins(%fill_value : f32)
+    outs(%init1 : tensor<56x1xf32>) -> tensor<56x1xf32>
+
+  // The collapse shape alters the tensor rank, so the %init1 tensor.empty cannot be
+  // pushed into the output of the linalg.generic.
+  %reshaped_tensor = tensor.collapse_shape %filled_tensor [[0, 1]]
+    : tensor<56x1xf32> into tensor<56xf32>
+
+  %bias = linalg.generic {
+    indexing_maps = [#map, #map],
+    iterator_types = ["parallel"]
+  } ins(%reshaped_tensor : tensor<56xf32>)
+    outs(%init0 : tensor<56xf32>) {
+    ^bb0(%in: f32, %out: f32):
+      linalg.yield %in : f32
+  } -> tensor<56xf32>
+
+  return %bias : tensor<56xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.eliminate_empty_tensors %0 : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+
+// This test is intended to check that the produced IR does not contain any
+// type errors from sharing empty tensor operations with different types.
+// The verifiers are sufficient to lock down the intended behavior.
+
+// CHECK-LABEL: func.func @collapse_cast_prevents_reuse(
+func.func @collapse_cast_prevents_reuse(%fill_value: f32) -> tensor<56x?xf32>
+{
+  %c1 = arith.constant 1 : index
+  %init0 = tensor.empty(%c1) : tensor<56x?xf32>
+  %init1 = tensor.empty() : tensor<56x1xf32>
+
+  %filled_tensor = linalg.fill
+    ins(%fill_value : f32)
+    outs(%init1 : tensor<56x1xf32>) -> tensor<56x1xf32>
+
+  // The cast alters the number of dynamic dims, so the %init1 tensor.empty cannot be
+  // pushed into the output of the linalg.generic.
+  %cast = tensor.cast %filled_tensor : tensor<56x1xf32> to tensor<56x?xf32>
+
+  %bias = linalg.generic {
+    indexing_maps = [#map, #map],
+    iterator_types = ["parallel", "parallel"]
+  } ins(%cast : tensor<56x?xf32>)
+    outs(%init0 : tensor<56x?xf32>) {
+    ^bb0(%in: f32, %out: f32):
+      linalg.yield %in : f32
+  } -> tensor<56x?xf32>
+
+  return %bias : tensor<56x?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.eliminate_empty_tensors %0 : !transform.any_op
+    transform.yield
+  }
+}



More information about the Mlir-commits mailing list