[Mlir-commits] [mlir] 6b65d79 - [mlir][linalg] Fix for invalid IR in eliminate_empty_tensors (#73513)
    llvmlistbot at llvm.org 
    llvmlistbot at llvm.org
       
    Mon Jan  1 09:12:44 PST 2024
    
    
  
Author: Spenser Bauman
Date: 2024-01-01T17:12:40Z
New Revision: 6b65d79fbb4682468333cea42b62f15c2dffd8f3
URL: https://github.com/llvm/llvm-project/commit/6b65d79fbb4682468333cea42b62f15c2dffd8f3
DIFF: https://github.com/llvm/llvm-project/commit/6b65d79fbb4682468333cea42b62f15c2dffd8f3.diff
LOG: [mlir][linalg] Fix for invalid IR in eliminate_empty_tensors (#73513)
The transform.structured.eliminate_empty_tensors op can produce mis-typed
IR when it traverses use-def chains past tensor reshaping operations in
search of sharing candidates. This results in Linalg operations whose result
types do not match the types of their 'outs' operands.
This patch filters out candidate tensor.empty operations whose types do not
match the type of the candidate input operand.
Added: 
    
Modified: 
    mlir/lib/Dialect/Linalg/Transforms/EliminateEmptyTensors.cpp
    mlir/test/Dialect/Linalg/one-shot-bufferize-empty-tensor-elimination.mlir
Removed: 
    
################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/EliminateEmptyTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/EliminateEmptyTensors.cpp
index 5a8320bdb28753..f28f8f0d34a4da 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/EliminateEmptyTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/EliminateEmptyTensors.cpp
@@ -60,7 +60,10 @@ LogicalResult linalg::linalgOpAnchoredEmptyTensorEliminationStep(
       config.alwaysIncludeLeaves = false;
       SetVector<Value> emptyTensors = state.findValueInReverseUseDefChain(
           in->get(), /*condition=*/
-          [&](Value val) { return val.getDefiningOp<tensor::EmptyOp>(); },
+          [&](Value val) {
+            return val.getDefiningOp<tensor::EmptyOp>() &&
+                   val.getType() == in->get().getType();
+          },
           config);
       if (emptyTensors.empty())
         continue;
diff  --git a/mlir/test/Dialect/Linalg/one-shot-bufferize-empty-tensor-elimination.mlir b/mlir/test/Dialect/Linalg/one-shot-bufferize-empty-tensor-elimination.mlir
index 0172760576efc5..761b75d8183732 100644
--- a/mlir/test/Dialect/Linalg/one-shot-bufferize-empty-tensor-elimination.mlir
+++ b/mlir/test/Dialect/Linalg/one-shot-bufferize-empty-tensor-elimination.mlir
@@ -42,3 +42,89 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+#map = affine_map<(d0) -> (d0)>
+
+// This test is intended to check that the produced IR does not contain any
+// type errors from sharing empty tensor operations with different types.
+// The verifiers are sufficient to lock down the intended behavior.
+
+// CHECK-LABEL: func.func @collapse_shape_prevents_reuse(
+func.func @collapse_shape_prevents_reuse(%fill_value: f32) -> tensor<56xf32>
+{
+  %init0 = tensor.empty() : tensor<56xf32>
+  %init1 = tensor.empty() : tensor<56x1xf32>
+
+  %filled_tensor = linalg.fill
+    ins(%fill_value : f32)
+    outs(%init1 : tensor<56x1xf32>) -> tensor<56x1xf32>
+
+  // The collapse shape alters the tensor rank, so the %init1 tensor.empty cannot be
+  // pushed into the output of the linalg.generic.
+  %reshaped_tensor = tensor.collapse_shape %filled_tensor [[0, 1]]
+    : tensor<56x1xf32> into tensor<56xf32>
+
+  %bias = linalg.generic {
+    indexing_maps = [#map, #map],
+    iterator_types = ["parallel"]
+  } ins(%reshaped_tensor : tensor<56xf32>)
+    outs(%init0 : tensor<56xf32>) {
+    ^bb0(%in: f32, %out: f32):
+      linalg.yield %in : f32
+  } -> tensor<56xf32>
+
+  return %bias : tensor<56xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.eliminate_empty_tensors %0 : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+
+// This test is intended to check that the produced IR does not contain any
+// type errors from sharing empty tensor operations with different types.
+// The verifiers are sufficient to lock down the intended behavior.
+
+// CHECK-LABEL: func.func @collapse_cast_prevents_reuse(
+func.func @collapse_cast_prevents_reuse(%fill_value: f32) -> tensor<56x?xf32>
+{
+  %c1 = arith.constant 1 : index
+  %init0 = tensor.empty(%c1) : tensor<56x?xf32>
+  %init1 = tensor.empty() : tensor<56x1xf32>
+
+  %filled_tensor = linalg.fill
+    ins(%fill_value : f32)
+    outs(%init1 : tensor<56x1xf32>) -> tensor<56x1xf32>
+
+  // The cast alters the number of dynamic dims, so the %init1 tensor.empty cannot be
+  // pushed into the output of the linalg.generic.
+  %cast = tensor.cast %filled_tensor : tensor<56x1xf32> to tensor<56x?xf32>
+
+  %bias = linalg.generic {
+    indexing_maps = [#map, #map],
+    iterator_types = ["parallel", "parallel"]
+  } ins(%cast : tensor<56x?xf32>)
+    outs(%init0 : tensor<56x?xf32>) {
+    ^bb0(%in: f32, %out: f32):
+      linalg.yield %in : f32
+  } -> tensor<56x?xf32>
+
+  return %bias : tensor<56x?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.eliminate_empty_tensors %0 : !transform.any_op
+    transform.yield
+  }
+}
        
    
    
More information about the Mlir-commits
mailing list