[Mlir-commits] [mlir] [mlir][bufferization]-Add enforce immutable func args pass (PR #113130)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Oct 20 23:55:39 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Amir Bishara (amirBish)

<details>
<summary>Changes</summary>

Adding a pass which allocates a new buffer for each input argument of the function it operates on that is being written to, and copies the argument into the allocated buffer with a `memref.copy`.

---

Patch is 20.27 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/113130.diff


5 Files Affected:

- (modified) mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h (+4) 
- (modified) mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td (+14) 
- (modified) mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt (+1) 
- (added) mlir/lib/Dialect/Bufferization/Transforms/EnforceImmutableFuncArgs.cpp (+101) 
- (added) mlir/test/Dialect/Bufferization/Transforms/enforce-immutable-func-args.mlir (+248) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
index 72abb5b3f1f94e..e17914fbbd5840 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
@@ -229,6 +229,10 @@ createPromoteBuffersToStackPass(std::function<bool(Value)> isSmallAlloc);
 /// insert_slice ops.
 std::unique_ptr<Pass> createEmptyTensorEliminationPass();
 
+/// Create a pass that enforces read-only buffers for the function arguments
+/// which are being written to.
+std::unique_ptr<Pass> createEnforceImmutableFuncArgsPass();
+
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
index cc5463ea968fc3..fb2b4d3a305f4a 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
@@ -595,4 +595,18 @@ def EmptyTensorElimination : Pass<"eliminate-empty-tensors"> {
   let constructor = "mlir::bufferization::createEmptyTensorEliminationPass()";
 }
 
+def EnforceImmutableFuncArgs : Pass<"enforce-immutable-func-args", "func::FuncOp"> {
+  let summary = "Enforcing function's arguments immutability by inserting allocOps and copy";
+  let description = [{
+    This pass allocates a new buffer for each input argument of the function
+    which is being written to, also copying the argument into the allocated
+    buffer and redirecting the argument's uses to that buffer.
+    This will avoid in place memory updates for the function's arguments and
+    make them immutable/read-only buffers.
+  }];
+  let constructor = "mlir::bufferization::createEnforceImmutableFuncArgsPass()";
+  let dependentDialects = ["memref::MemRefDialect"];
+}
+
+
 #endif // MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt
index 50104e8f8346b4..25de31c179a31d 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt
@@ -16,6 +16,7 @@ add_mlir_dialect_library(MLIRBufferizationTransforms
   OwnershipBasedBufferDeallocation.cpp
   TensorCopyInsertion.cpp
   OptimizeAllocationLiveness.cpp
+  EnforceImmutableFuncArgs.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Bufferization
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/EnforceImmutableFuncArgs.cpp b/mlir/lib/Dialect/Bufferization/Transforms/EnforceImmutableFuncArgs.cpp
new file mode 100644
index 00000000000000..84f201c141a3d1
--- /dev/null
+++ b/mlir/lib/Dialect/Bufferization/Transforms/EnforceImmutableFuncArgs.cpp
@@ -0,0 +1,101 @@
+//===- EnforceImmutableFuncArgs.cpp - Enforce immutable func args ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass for enforcing immutable function arguments.
+// Each memref argument that is written to is copied into a newly allocated
+// buffer, and all uses of the argument are redirected to that buffer.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/Operation.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "enforce-immutable-func-args"
+#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
+#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
+
+using namespace mlir;
+
+namespace mlir {
+namespace bufferization {
+#define GEN_PASS_DEF_ENFORCEIMMUTABLEFUNCARGS
+#include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
+} // namespace bufferization
+} // namespace mlir
+
+// Returns true if any user of `buffer` has a write effect on it, or if the
+// result of a view-like user of `buffer` is (recursively) written to.
+// This method assumes buffer has `MemRefType`.
+static bool isWrittenTo(Value buffer);
+
+namespace {
+/// This pass allocates a new buffer for each input argument of the function
+/// which is being written to, copies the argument into the allocated buffer
+/// and redirects the argument's uses to it. This will avoid in place memory
+/// updates for the function's arguments and make them immutable/read-only.
+struct EnforceImmutableFuncArgsPass
+    : public bufferization::impl::EnforceImmutableFuncArgsBase<
+          EnforceImmutableFuncArgsPass> {
+  void runOnOperation() final;
+};
+} // end anonymous namespace.
+
+static bool isWrittenTo(Value buffer) {
+  assert(isa<MemRefType>(buffer.getType()));
+  // The buffer counts as written if some user writes to it directly, or is a
+  // view-like op whose single result is itself written to.
+  return llvm::any_of(buffer.getUsers(), [&](Operation *consumer) {
+    if (hasEffect<MemoryEffects::Write>(consumer, buffer))
+      return true;
+    auto aliasingOp = dyn_cast<ViewLikeOpInterface>(consumer);
+    if (!aliasingOp)
+      return false;
+    assert(aliasingOp->getNumResults() == 1);
+    return isWrittenTo(aliasingOp->getResult(0));
+  });
+}
+
+void EnforceImmutableFuncArgsPass::runOnOperation() {
+  func::FuncOp funcOp = getOperation();
+  LDBG("enforcing immutable function arguments in func " << funcOp.getName());
+  // A declaration has no entry block to insert into and no operation that
+  // could write to its arguments, so there is nothing to do for it.
+  if (funcOp.isExternal())
+    return;
+  // Emit all alloc/copy pairs at the top of the entry block.
+  IRRewriter rewriter(funcOp->getContext());
+  rewriter.setInsertionPointToStart(&funcOp.getBody().front());
+  for (auto argument : funcOp.getArguments()) {
+    auto argType = dyn_cast<MemRefType>(argument.getType());
+    if (!argType) {
+      emitError(argument.getLoc(),
+                "function has argument with non memref type");
+      return signalPassFailure();
+    }
+    if (!isWrittenTo(argument))
+      continue;
+    LDBG("Found a function argument is being written to " << argument);
+    Value allocatedMemref =
+        rewriter.create<memref::AllocOp>(funcOp.getLoc(), argType);
+    // Redirect the users first so that the copy is the argument's only user.
+    rewriter.replaceAllUsesWith(argument, allocatedMemref);
+    rewriter.create<memref::CopyOp>(funcOp.getLoc(), argument, allocatedMemref);
+  }
+}
+
+//===----------------------------------------------------------------------===//
+// EnforceImmutableFuncArgs construction
+//===----------------------------------------------------------------------===//
+/// Creates a new instance of `EnforceImmutableFuncArgsPass`.
+std::unique_ptr<Pass>
+mlir::bufferization::createEnforceImmutableFuncArgsPass() {
+  return std::make_unique<EnforceImmutableFuncArgsPass>();
+}
diff --git a/mlir/test/Dialect/Bufferization/Transforms/enforce-immutable-func-args.mlir b/mlir/test/Dialect/Bufferization/Transforms/enforce-immutable-func-args.mlir
new file mode 100644
index 00000000000000..13019d2fbf5af4
--- /dev/null
+++ b/mlir/test/Dialect/Bufferization/Transforms/enforce-immutable-func-args.mlir
@@ -0,0 +1,248 @@
+// RUN: mlir-opt --split-input-file --enforce-immutable-func-args %s -o - | FileCheck %s
+
+
+// CHECK-LABEL:   func.func @func_no_input() {
+// CHECK:           return
+// CHECK:         }
+
+func.func @func_no_input() {
+  return
+}
+
+// -----
+
+// CHECK-LABEL:   func.func private @func_with_returned_argument(
+// CHECK-SAME:                                                   %[[VAL_0:.*]]: memref<1x13x21x3xf32>) -> memref<1x13x21x3xf32> {
+// CHECK:           return %[[VAL_0]] : memref<1x13x21x3xf32>
+// CHECK:         }
+
+func.func private @func_with_returned_argument(%arg0: memref<1x13x21x3xf32>) -> (memref<1x13x21x3xf32>) {
+  return %arg0 : memref<1x13x21x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL:   func.func private @func_with_modified_argument_directly(
+// CHECK-SAME:                                                            %[[VAL_0:.*]]: memref<1x13x21x3xf32>, %[[VAL_1:.*]]: memref<1x13x21x3xf32>) -> memref<1x13x21x3xf32> {
+// CHECK:           %[[VAL_2:.*]] = memref.alloc() : memref<1x13x21x3xf32>
+// CHECK:           memref.copy %[[VAL_0]], %[[VAL_2]] : memref<1x13x21x3xf32> to memref<1x13x21x3xf32>
+// CHECK:           %[[VAL_3:.*]] = memref.alloc() {alignment = 64 : i64} : memref<1x13x21x3xf32>
+// CHECK:           linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[VAL_2]], %[[VAL_1]] : memref<1x13x21x3xf32>, memref<1x13x21x3xf32>) outs(%[[VAL_2]] : memref<1x13x21x3xf32>) {
+// CHECK:           ^bb0(%[[VAL_4:.*]]: f32, %[[VAL_5:.*]]: f32, %[[VAL_6:.*]]: f32):
+// CHECK:             %[[VAL_7:.*]] = arith.addf %[[VAL_4]], %[[VAL_5]] : f32
+// CHECK:             linalg.yield %[[VAL_7]] : f32
+// CHECK:           }
+// CHECK:           return %[[VAL_3]] : memref<1x13x21x3xf32>
+// CHECK:         }
+
+func.func private @func_with_modified_argument_directly(%arg0: memref<1x13x21x3xf32>, %arg1: memref<1x13x21x3xf32>) -> (memref<1x13x21x3xf32>){
+  %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x13x21x3xf32>
+  linalg.generic {
+    indexing_maps = [
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+    ],
+    iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+  }
+  ins(%arg0, %arg1 : memref<1x13x21x3xf32>, memref<1x13x21x3xf32>)
+  outs(%arg0 : memref<1x13x21x3xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %0 = arith.addf %in, %in_0 : f32
+    linalg.yield %0 : f32
+  }
+  return %alloc : memref<1x13x21x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL:   func.func private @func_with_modified_argument_directly_and_returned(
+// CHECK-SAME:                                                                         %[[VAL_0:.*]]: memref<1x13x21x3xf32>, %[[VAL_1:.*]]: memref<1x13x21x3xf32>) -> memref<1x13x21x3xf32> {
+// CHECK:           %[[VAL_2:.*]] = memref.alloc() : memref<1x13x21x3xf32>
+// CHECK:           memref.copy %[[VAL_0]], %[[VAL_2]] : memref<1x13x21x3xf32> to memref<1x13x21x3xf32>
+// CHECK:           linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[VAL_2]], %[[VAL_1]] : memref<1x13x21x3xf32>, memref<1x13x21x3xf32>) outs(%[[VAL_2]] : memref<1x13x21x3xf32>) {
+// CHECK:           ^bb0(%[[VAL_3:.*]]: f32, %[[VAL_4:.*]]: f32, %[[VAL_5:.*]]: f32):
+// CHECK:             %[[VAL_6:.*]] = arith.addf %[[VAL_3]], %[[VAL_4]] : f32
+// CHECK:             linalg.yield %[[VAL_6]] : f32
+// CHECK:           }
+// CHECK:           return %[[VAL_2]] : memref<1x13x21x3xf32>
+// CHECK:         }
+
+func.func private @func_with_modified_argument_directly_and_returned(%arg0: memref<1x13x21x3xf32>, %arg1: memref<1x13x21x3xf32>) -> (memref<1x13x21x3xf32>){
+  linalg.generic {
+    indexing_maps = [
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+    ],
+    iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+  }
+  ins(%arg0, %arg1 : memref<1x13x21x3xf32>, memref<1x13x21x3xf32>)
+  outs(%arg0 : memref<1x13x21x3xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %0 = arith.addf %in, %in_0 : f32
+    linalg.yield %0 : f32
+  }
+  return %arg0 : memref<1x13x21x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL:   func.func private @func_with_modified_argument_directly_twice(
+// CHECK-SAME:                                                                  %[[VAL_0:.*]]: memref<1x13x21x3xf32>, %[[VAL_1:.*]]: memref<1x13x21x3xf32>) -> memref<1x13x21x3xf32> {
+// CHECK:           %[[VAL_2:.*]] = memref.alloc() : memref<1x13x21x3xf32>
+// CHECK:           memref.copy %[[VAL_0]], %[[VAL_2]] : memref<1x13x21x3xf32> to memref<1x13x21x3xf32>
+// CHECK:           %[[VAL_3:.*]] = memref.alloc() {alignment = 64 : i64} : memref<1x13x21x3xf32>
+// CHECK:           linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[VAL_2]], %[[VAL_1]] : memref<1x13x21x3xf32>, memref<1x13x21x3xf32>) outs(%[[VAL_2]] : memref<1x13x21x3xf32>) {
+// CHECK:           ^bb0(%[[VAL_4:.*]]: f32, %[[VAL_5:.*]]: f32, %[[VAL_6:.*]]: f32):
+// CHECK:             %[[VAL_7:.*]] = arith.addf %[[VAL_4]], %[[VAL_5]] : f32
+// CHECK:             linalg.yield %[[VAL_7]] : f32
+// CHECK:           }
+// CHECK:           linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[VAL_2]], %[[VAL_1]] : memref<1x13x21x3xf32>, memref<1x13x21x3xf32>) outs(%[[VAL_2]] : memref<1x13x21x3xf32>) {
+// CHECK:           ^bb0(%[[VAL_8:.*]]: f32, %[[VAL_9:.*]]: f32, %[[VAL_10:.*]]: f32):
+// CHECK:             %[[VAL_11:.*]] = arith.addf %[[VAL_8]], %[[VAL_9]] : f32
+// CHECK:             linalg.yield %[[VAL_11]] : f32
+// CHECK:           }
+// CHECK:           return %[[VAL_3]] : memref<1x13x21x3xf32>
+// CHECK:         }
+
+func.func private @func_with_modified_argument_directly_twice(%arg0: memref<1x13x21x3xf32>, %arg1: memref<1x13x21x3xf32>) -> (memref<1x13x21x3xf32>){
+  %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x13x21x3xf32>
+  linalg.generic {
+    indexing_maps = [
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+    ],
+    iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+  }
+  ins(%arg0, %arg1 : memref<1x13x21x3xf32>, memref<1x13x21x3xf32>)
+  outs(%arg0 : memref<1x13x21x3xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %0 = arith.addf %in, %in_0 : f32
+    linalg.yield %0 : f32
+  }
+  linalg.generic {
+    indexing_maps = [
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+    ],
+    iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+  }
+  ins(%arg0, %arg1 : memref<1x13x21x3xf32>, memref<1x13x21x3xf32>)
+  outs(%arg0 : memref<1x13x21x3xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %0 = arith.addf %in, %in_0 : f32
+    linalg.yield %0 : f32
+  }
+  return %alloc : memref<1x13x21x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL:   func.func private @func_with_modified_argument_directly(
+// CHECK-SAME:                                                            %[[VAL_0:.*]]: memref<5xi32, 1>, %[[VAL_1:.*]]: memref<5xi32, 1>, %[[VAL_2:.*]]: memref<5xi32, 1>) -> memref<5xi32, 1> {
+// CHECK:           %[[VAL_3:.*]] = memref.alloc() : memref<5xi32, 1>
+// CHECK:           memref.copy %[[VAL_2]], %[[VAL_3]] : memref<5xi32, 1> to memref<5xi32, 1>
+// CHECK:           %[[VAL_4:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_5:.*]] = arith.constant 5 : index
+// CHECK:           %[[VAL_6:.*]] = arith.constant 0 : index
+// CHECK:           scf.for %[[VAL_7:.*]] = %[[VAL_6]] to %[[VAL_5]] step %[[VAL_4]] {
+// CHECK:             %[[VAL_8:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_7]]] : memref<5xi32, 1>
+// CHECK:             %[[VAL_9:.*]] = arith.index_cast %[[VAL_8]] : i32 to index
+// CHECK:             %[[VAL_10:.*]] = memref.load %[[VAL_1]]{{\[}}%[[VAL_7]]] : memref<5xi32, 1>
+// CHECK:             %[[VAL_11:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_9]]] : memref<5xi32, 1>
+// CHECK:             %[[VAL_12:.*]] = arith.addi %[[VAL_10]], %[[VAL_11]] : i32
+// CHECK:             memref.store %[[VAL_12]], %[[VAL_3]]{{\[}}%[[VAL_9]]] : memref<5xi32, 1>
+// CHECK:           }
+// CHECK:           %[[VAL_13:.*]] = memref.alloc() : memref<5xi32, 1>
+// CHECK:           memref.copy %[[VAL_3]], %[[VAL_13]] : memref<5xi32, 1> to memref<5xi32, 1>
+// CHECK:           return %[[VAL_13]] : memref<5xi32, 1>
+// CHECK:         }
+
+func.func private @func_with_modified_argument_directly(%arg0: memref<5xi32, 1>, %arg1: memref<5xi32, 1>, %arg2: memref<5xi32, 1>) -> (memref<5xi32, 1>){
+  %c1 = arith.constant 1 : index
+  %c5 = arith.constant 5 : index
+  %c0 = arith.constant 0 : index
+  scf.for %arg3 = %c0 to %c5 step %c1 {
+    %0 = memref.load %arg0[%arg3] : memref<5xi32, 1>
+    %1 = arith.index_cast %0 : i32 to index
+    %2 = memref.load %arg1[%arg3] : memref<5xi32, 1>
+    %3 = memref.load %arg2[%1] : memref<5xi32, 1>
+    %4 = arith.addi %2, %3 : i32
+    memref.store %4, %arg2[%1] : memref<5xi32, 1>
+  }
+  %alloc = memref.alloc() : memref<5xi32, 1>
+  memref.copy %arg2, %alloc : memref<5xi32, 1> to memref<5xi32, 1>
+  return %alloc : memref<5xi32, 1>
+}
+
+// -----
+
+// CHECK-LABEL:   func.func private @func_with_modified_argument_indirectly(
+// CHECK-SAME:                                                              %[[VAL_0:.*]]: memref<3x3x4xf32, 1>) -> memref<3x3x4xf32, 1> {
+// CHECK:           %[[VAL_1:.*]] = memref.alloc() : memref<3x3x4xf32, 1>
+// CHECK:           memref.copy %[[VAL_0]], %[[VAL_1]] : memref<3x3x4xf32, 1> to memref<3x3x4xf32, 1>
+// CHECK:           %[[VAL_2:.*]] = memref.collapse_shape %[[VAL_1]] {{\[\[}}0, 1], [2]] : memref<3x3x4xf32, 1> into memref<9x4xf32, 1>
+// CHECK:           %[[VAL_3:.*]] = memref.expand_shape %[[VAL_2]] {{\[\[}}0, 1], [2]] output_shape [3, 3, 4] : memref<9x4xf32, 1> into memref<3x3x4xf32, 1>
+// CHECK:           linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel", "parallel"]} outs(%[[VAL_3]] : memref<3x3x4xf32, 1>) {
+// CHECK:           ^bb0(%[[VAL_4:.*]]: f32):
+// CHECK:             %[[VAL_5:.*]] = arith.addf %[[VAL_4]], %[[VAL_4]] : f32
+// CHECK:             linalg.yield %[[VAL_5]] : f32
+// CHECK:           }
+// CHECK:           return %[[VAL_3]] : memref<3x3x4xf32, 1>
+// CHECK:         }
+
+func.func private @func_with_modified_argument_indirectly(%arg0: memref<3x3x4xf32, 1>) -> (memref<3x3x4xf32, 1>) {
+  %collapse_arg = memref.collapse_shape %arg0 [[0, 1], [2]] : memref<3x3x4xf32, 1> into memref<9x4xf32, 1>
+  %expand_arg = memref.expand_shape %collapse_arg [[0, 1], [2]] output_shape [3, 3, 4] : memref<9x4xf32, 1> into memref<3x3x4xf32, 1>
+  linalg.generic {
+    indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
+    iterator_types = ["parallel", "parallel", "parallel"]
+  }
+  outs(%expand_arg : memref<3x3x4xf32, 1>) {
+  ^bb0(%out: f32):
+    %0 = arith.addf %out, %out : f32
+    linalg.yield %0 : f32
+  }
+  return %expand_arg: memref<3x3x4xf32, 1>
+}
+
+// -----
+
+// CHECK-LABEL:   func.func private @func_with_modified_argument_subview(
+// CHECK-SAME:                                                           %[[VAL_0:.*]]: memref<2x4x4xi32, 1>) -> memref<4x4xi32, 1> {
+// CHECK:           %[[VAL_1:.*]] = memref.alloc() : memref<2x4x4xi32, 1>
+// CHECK:           memref.copy %[[VAL_0]], %[[VAL_1]] : memref<2x4x4xi32, 1> to memref<2x4x4xi32, 1>
+// CHECK:           %[[VAL_2:.*]] = memref.subview %[[VAL_1]][0, 0, 0] [1, 4, 4] [1, 1, 1] : memref<2x4x4xi32, 1> to memref<1x4x4xi32, strided<[16, 4, 1]>, 1>
+// CHECK:           %[[VAL_3:.*]] = memref.collapse_shape %[[VAL_2]] {{\[\[}}0, 1], [2]] : memref<1x4x4xi32, strided<[16, 4, 1]>, 1> into memref<4x4xi32, strided<[4, 1]>, 1>
+// CHECK:           %[[VAL_4:.*]] = memref.cast %[[VAL_3]] : memref<4x4xi32, strided<[4, 1]>, 1> to memref<4x4xi32, 1>
+// CHECK:           linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel"]} outs(%[[VAL_4]] : memref<4x4xi32, 1>) {
+// CHECK:           ^bb0(%[[VAL_5:.*]]: i32):
+// CHECK:             %[[VAL_6:.*]] = arith.addi %[[VAL_5]], %[[VAL_5]] : i32
+// CHECK:             linalg.yield %[[VAL_6]] : i32
+// CHECK:           }
+// CHECK:           %[[VAL_7:.*]] = memref.alloc() : memref<4x4xi32, 1>
+// CHECK:           memref.copy %[[VAL_4]], %[[VAL_7]] : memref<4x4xi32, 1> to memref<4x4xi32, 1>
+// CHECK:           return %[[VAL_7]] : memref<4x4xi32, 1>
+// CHECK:         }
+
+func.func private @func_with_modified_argument_subview(%arg0: memref<2x4x4xi32, 1>) -> ( memref<4x4xi32, 1>){
+  %subview = memref.subview %arg0[0, 0, 0] [1, 4, 4] [1, 1, 1] : memref<2x4x4xi32, 1> to memref<1x4x4xi32, strided<[16, 4, 1]>, 1>
+  %collapse_shape = memref.collapse_shape %subview [[0, 1], [2]] : memref<1x4x4xi32, strided<[16, 4, 1]>, 1> into memref<4x4xi32, strided<[4, 1]>, 1>
+  %cast = memref.cast %collapse_shape : memref<4x4xi32, strided<[4, 1]>, 1> to memref<4x4xi32, 1>
+  linalg.generic {
+    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>],
+    iterator_types = ["parallel", "parallel"]
+  }
+  outs(%cas...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/113130


More information about the Mlir-commits mailing list