[Mlir-commits] [mlir] [mlir][linalg] Fix for invalid IR eliminate_empty_tensors (PR #73513)

llvmlistbot at llvm.org
Mon Nov 27 05:31:22 PST 2023


llvmbot wrote:



@llvm/pr-subscribers-mlir

Author: Spenser Bauman (sabauma)

<details>
<summary>Changes</summary>

The transform.structured.eliminate_empty_tensors transform can produce mis-typed IR when it traverses use-def chains past tensor reshaping operations while searching for sharing candidates. This results in Linalg operations whose result types do not match their 'outs' operands.

This patch filters out candidate tensor.empty operations whose type does not match the type of the candidate input operand.
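
The effect of the added type check can be illustrated with a small standalone model. This is only a sketch and does not use the MLIR API: `ToyValue`, `collect`, `findCandidates`, and `countCandidatesInExample` are hypothetical names, types are plain strings, and the walk follows every operand, whereas the real `findValueInReverseUseDefChain` follows aliasing operands only. The chain it builds mirrors the new test case (`tensor.empty` -> `linalg.fill` -> `tensor.collapse_shape`).

```cpp
#include <cassert>
#include <string>
#include <vector>

// Toy stand-in for an SSA value (NOT the MLIR API): a type spelled as a
// string, a flag marking results of tensor.empty, and the operands of the
// defining op.
struct ToyValue {
  std::string type;
  bool isEmptyTensor = false;
  std::vector<const ToyValue *> operands;
};

// Walk the reverse use-def chain starting at `v`. A value that satisfies
// `pred` is collected and the walk stops there; otherwise the walk continues
// into the operands of the defining op.
template <typename Pred>
void collect(const ToyValue *v, Pred pred,
             std::vector<const ToyValue *> &out) {
  assert(v && "expected a non-null value");
  if (pred(*v)) {
    out.push_back(v);
    return;
  }
  for (const ToyValue *operand : v->operands)
    collect(operand, pred, out);
}

// Find tensor.empty candidates reachable from `input`. With
// `requireTypeMatch` set, a candidate must also have the same type as
// `input`, which models the condition added by this patch.
std::vector<const ToyValue *> findCandidates(const ToyValue &input,
                                             bool requireTypeMatch) {
  std::vector<const ToyValue *> result;
  collect(
      &input,
      [&](const ToyValue &v) {
        return v.isEmptyTensor &&
               (!requireTypeMatch || v.type == input.type);
      },
      result);
  return result;
}

// Build the chain from the new test case and count the candidates found
// when starting from the collapsed (4-D) value.
int countCandidatesInExample(bool requireTypeMatch) {
  ToyValue init1{"tensor<1x128x128x56x1xf32>", /*isEmptyTensor=*/true, {}};
  ToyValue filled{"tensor<1x128x128x56x1xf32>", false, {&init1}};
  ToyValue reshaped{"tensor<1x128x128x56xf32>", false, {&filled}};
  return static_cast<int>(findCandidates(reshaped, requireTypeMatch).size());
}
```

In this model, the walk without the type check reaches the 5-D `tensor.empty` behind the reshape and reports it as a candidate for a 4-D operand. With the check enabled, that candidate is rejected and no sharing takes place.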

---
Full diff: https://github.com/llvm/llvm-project/pull/73513.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Linalg/Transforms/EliminateEmptyTensors.cpp (+4-1) 
- (modified) mlir/test/Dialect/Linalg/one-shot-bufferize-empty-tensor-elimination.mlir (+41) 


``````````diff
diff --git a/mlir/lib/Dialect/Linalg/Transforms/EliminateEmptyTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/EliminateEmptyTensors.cpp
index 5a8320bdb287533..f28f8f0d34a4da5 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/EliminateEmptyTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/EliminateEmptyTensors.cpp
@@ -60,7 +60,10 @@ LogicalResult linalg::linalgOpAnchoredEmptyTensorEliminationStep(
       config.alwaysIncludeLeaves = false;
       SetVector<Value> emptyTensors = state.findValueInReverseUseDefChain(
           in->get(), /*condition=*/
-          [&](Value val) { return val.getDefiningOp<tensor::EmptyOp>(); },
+          [&](Value val) {
+            return val.getDefiningOp<tensor::EmptyOp>() &&
+                   val.getType() == in->get().getType();
+          },
           config);
       if (emptyTensors.empty())
         continue;
diff --git a/mlir/test/Dialect/Linalg/one-shot-bufferize-empty-tensor-elimination.mlir b/mlir/test/Dialect/Linalg/one-shot-bufferize-empty-tensor-elimination.mlir
index 0172760576efc51..7b575119c9cc44b 100644
--- a/mlir/test/Dialect/Linalg/one-shot-bufferize-empty-tensor-elimination.mlir
+++ b/mlir/test/Dialect/Linalg/one-shot-bufferize-empty-tensor-elimination.mlir
@@ -42,3 +42,44 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+
+// This test is intended to check that the produced IR does not contain any
+// type errors from sharing empty tensor operations with different types.
+// The verifiers are sufficient to lock down the intended behavior.
+
+// CHECK-LABEL: func.func @collapse_shape_prevents_reuse(
+func.func @collapse_shape_prevents_reuse(%fill_value: f32) -> tensor<1x128x128x56xf32>
+{
+  %init0 = tensor.empty() : tensor<1x128x128x56xf32>
+  %init1 = tensor.empty() : tensor<1x128x128x56x1xf32>
+
+  %filled_tensor = linalg.fill
+    ins(%fill_value : f32)
+    outs(%init1 : tensor<1x128x128x56x1xf32>) -> tensor<1x128x128x56x1xf32>
+
+  %reshaped_tensor = tensor.collapse_shape %filled_tensor [[0], [1], [2], [3, 4]]
+    : tensor<1x128x128x56x1xf32> into tensor<1x128x128x56xf32>
+
+  %bias = linalg.generic {
+    indexing_maps = [#map, #map],
+    iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+  } ins(%reshaped_tensor : tensor<1x128x128x56xf32>)
+    outs(%init0 : tensor<1x128x128x56xf32>) {
+    ^bb0(%in: f32, %out: f32):
+      linalg.yield %in : f32
+  } -> tensor<1x128x128x56xf32>
+
+  return %bias : tensor<1x128x128x56xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.eliminate_empty_tensors %0 : !transform.any_op
+    transform.yield
+  }
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/73513

