[Mlir-commits] [mlir] [mlir][scf]-Fix reverse iterator overflow in loop traversal (PR #128421)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Feb 23 06:55:51 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Amir Bishara (amirBish)

<details>
<summary>Changes</summary>

Fix a bug in `getUntiledProducerFromSliceSource` where AddressSanitizer reports a heap-buffer-overflow
because the reverse loop iterator is dereferenced past the end of the loop range.

This PR bounds the loop traversal by `loops.rend()`, asserts that the loop list is non-empty, and adds a lit test that reproduces the issue.

---
Full diff: https://github.com/llvm/llvm-project/pull/128421.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp (+3-1) 
- (modified) mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir (+54) 


``````````diff
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index b548f8ce8b560..af87fb7a79d04 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -1119,8 +1119,10 @@ static std::tuple<OpResult, std::optional<OpOperand *>>
 getUntiledProducerFromSliceSource(OpOperand *source,
                                   ArrayRef<LoopLikeOpInterface> loops) {
   std::optional<OpOperand *> destinationIterArg;
+  assert(!loops.empty() && "expected non empty loops container");
   auto loopIt = loops.rbegin();
-  while (auto iterArg = dyn_cast<BlockArgument>(source->get())) {
+  while (loopIt != loops.rend() && isa<BlockArgument>(source->get())) {
+    auto iterArg = cast<BlockArgument>(source->get());
     auto loop = *loopIt;
     if (iterArg.getOwner()->getParentOp() != loop)
       break;
diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir
index bc27840fdf5e9..8a0390a4379cf 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir
@@ -634,3 +634,57 @@ module attributes {transform.with_named_sequence} {
 //       CHECK:     %[[INSERT_SLICE:.+]] = tensor.insert_slice %[[GENERIC]] into %[[ITER_ARG]]
 //       CHECK:     scf.yield %[[INSERT_SLICE]]
 //       CHECK:   return %[[FOR_RESULT]]
+
+// -----
+// Regression test: walking the loop nest in getUntiledProducerFromSliceSource must stop at loops.rend() (it previously read past the end; ASan heap-buffer-overflow).
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2, d1)>
+module {
+  func.func private @tile_one_consumer_using_tile_and_fuse(%arg0: tensor<16x128x48x96xf32>, %arg1: tensor<16x96x48x128xf32>) -> tensor<16x96x48x128xf32> {
+    %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<16x128x48x96xf32>) outs(%arg1 : tensor<16x96x48x128xf32>) {
+    ^bb0(%in: f32, %out: f32):
+      linalg.yield %in : f32
+    } -> tensor<16x96x48x128xf32>
+    return %0 : tensor<16x96x48x128xf32>
+  }
+}
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %generic = transform.structured.match ops{["linalg.generic"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %a, %loops:4 = transform.structured.fuse %generic {tile_sizes = [1, 16, 16, 16], tile_interchange = [0, 1, 2, 3], apply_cleanup = false}
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// CHECK:           func.func private @tile_one_consumer_using_tile_and_fuse(%[[VAL_0:.*]]: tensor<16x128x48x96xf32>, %[[VAL_1:.*]]: tensor<16x96x48x128xf32>) -> tensor<16x96x48x128xf32> {
+// CHECK:             %[[VAL_2:.*]] = arith.constant 0 : index
+// CHECK:             %[[VAL_3:.*]] = arith.constant 16 : index
+// CHECK:             %[[VAL_4:.*]] = arith.constant 128 : index
+// CHECK:             %[[VAL_5:.*]] = arith.constant 48 : index
+// CHECK:             %[[VAL_6:.*]] = arith.constant 96 : index
+// CHECK:             %[[VAL_7:.*]] = arith.constant 1 : index
+// CHECK:             %[[VAL_8:.*]] = scf.for %[[VAL_9:.*]] = %[[VAL_2]] to %[[VAL_3]] step %[[VAL_7]] iter_args(%[[VAL_10:.*]] = %[[VAL_1]]) -> (tensor<16x96x48x128xf32>) {
+// CHECK:               %[[VAL_11:.*]] = scf.for %[[VAL_12:.*]] = %[[VAL_2]] to %[[VAL_4]] step %[[VAL_3]] iter_args(%[[VAL_13:.*]] = %[[VAL_10]]) -> (tensor<16x96x48x128xf32>) {
+// CHECK:                 %[[VAL_14:.*]] = scf.for %[[VAL_15:.*]] = %[[VAL_2]] to %[[VAL_5]] step %[[VAL_3]] iter_args(%[[VAL_16:.*]] = %[[VAL_13]]) -> (tensor<16x96x48x128xf32>) {
+// CHECK:                   %[[VAL_17:.*]] = scf.for %[[VAL_18:.*]] = %[[VAL_2]] to %[[VAL_6]] step %[[VAL_3]] iter_args(%[[VAL_19:.*]] = %[[VAL_16]]) -> (tensor<16x96x48x128xf32>) {
+// CHECK:                     %[[VAL_20:.*]] = tensor.extract_slice %[[VAL_0]]{{\[}}%[[VAL_9]], %[[VAL_12]], %[[VAL_15]], %[[VAL_18]]] [1, 16, 16, 16] [1, 1, 1, 1] : tensor<16x128x48x96xf32> to tensor<1x16x16x16xf32>
+// CHECK:                     %[[VAL_21:.*]] = tensor.extract_slice %[[VAL_19]]{{\[}}%[[VAL_9]], %[[VAL_18]], %[[VAL_15]], %[[VAL_12]]] [1, 16, 16, 16] [1, 1, 1, 1] : tensor<16x96x48x128xf32> to tensor<1x16x16x16xf32>
+// CHECK:                     %[[VAL_22:.*]] = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[VAL_20]] : tensor<1x16x16x16xf32>) outs(%[[VAL_21]] : tensor<1x16x16x16xf32>) {
+// CHECK:                     ^bb0(%[[VAL_23:.*]]: f32, %[[VAL_24:.*]]: f32):
+// CHECK:                       linalg.yield %[[VAL_23]] : f32
+// CHECK:                     } -> tensor<1x16x16x16xf32>
+// CHECK:                     %[[VAL_25:.*]] = tensor.insert_slice %[[VAL_22]] into %[[VAL_19]]{{\[}}%[[VAL_9]], %[[VAL_18]], %[[VAL_15]], %[[VAL_12]]] [1, 16, 16, 16] [1, 1, 1, 1] : tensor<1x16x16x16xf32> into tensor<16x96x48x128xf32>
+// CHECK:                     scf.yield %[[VAL_25]] : tensor<16x96x48x128xf32>
+// CHECK:                   }
+// CHECK:                   scf.yield %[[VAL_17]] : tensor<16x96x48x128xf32>
+// CHECK:                 }
+// CHECK:                 scf.yield %[[VAL_14]] : tensor<16x96x48x128xf32>
+// CHECK:               }
+// CHECK:               scf.yield %[[VAL_11]] : tensor<16x96x48x128xf32>
+// CHECK:             }
+// CHECK:             return %[[VAL_8]] : tensor<16x96x48x128xf32>
+// CHECK:           }
+// CHECK:         }
+

``````````

</details>


https://github.com/llvm/llvm-project/pull/128421


More information about the Mlir-commits mailing list