[Mlir-commits] [mlir] [mlir][bufferization]-Add enforce immutable func args pass (PR #113130)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sun Oct 20 23:55:39 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Amir Bishara (amirBish)
<details>
<summary>Changes</summary>
Adding a pass which allocates a new buffer for each input argument of the function it operates on that is being written to, and copies the argument into the allocated buffer with a `memref.copy`.
---
Patch is 20.27 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/113130.diff
5 Files Affected:
- (modified) mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h (+4)
- (modified) mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td (+14)
- (modified) mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt (+1)
- (added) mlir/lib/Dialect/Bufferization/Transforms/EnforceImmutableFuncArgs.cpp (+101)
- (added) mlir/test/Dialect/Bufferization/Transforms/enforce-immutable-func-args.mlir (+248)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
index 72abb5b3f1f94e..e17914fbbd5840 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
@@ -229,6 +229,10 @@ createPromoteBuffersToStackPass(std::function<bool(Value)> isSmallAlloc);
/// insert_slice ops.
std::unique_ptr<Pass> createEmptyTensorEliminationPass();
+/// Create a pass that enforces read-only buffers for the
+/// relevant function arguments.
+std::unique_ptr<Pass> createEnforceImmutableFuncArgsPass();
+
//===----------------------------------------------------------------------===//
// Registration
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
index cc5463ea968fc3..fb2b4d3a305f4a 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
@@ -595,4 +595,18 @@ def EmptyTensorElimination : Pass<"eliminate-empty-tensors"> {
let constructor = "mlir::bufferization::createEmptyTensorEliminationPass()";
}
+def EnforceImmutableFuncArgs : Pass<"enforce-immutable-func-args", "func::FuncOp"> {
+ let summary = "Enforcing function's arguments immutability by inserting allocOps and copy";
+ let description = [{
+ This pass allocates a new buffer for each input argument of the function
+ which is being written to and marked to be enforced, also copying it into the
+ allocated buffer.
+ This will avoid in-place memory updates for the function's arguments and
+ make them immutable/read-only buffers.
+ }];
+ let constructor = "mlir::bufferization::createEnforceImmutableFuncArgsPass()";
+ let dependentDialects = ["memref::MemRefDialect"];
+}
+
+
#endif // MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt
index 50104e8f8346b4..25de31c179a31d 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt
@@ -16,6 +16,7 @@ add_mlir_dialect_library(MLIRBufferizationTransforms
OwnershipBasedBufferDeallocation.cpp
TensorCopyInsertion.cpp
OptimizeAllocationLiveness.cpp
+ EnforceImmutableFuncArgs.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Bufferization
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/EnforceImmutableFuncArgs.cpp b/mlir/lib/Dialect/Bufferization/Transforms/EnforceImmutableFuncArgs.cpp
new file mode 100644
index 00000000000000..84f201c141a3d1
--- /dev/null
+++ b/mlir/lib/Dialect/Bufferization/Transforms/EnforceImmutableFuncArgs.cpp
@@ -0,0 +1,101 @@
+//===- EnforceImmutableFuncArgs.cpp - Enforce immutable func args pass ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass that enforces the immutability of function
+// arguments: it allocates a new buffer for each argument that is written to
+// and copies the original contents into that buffer with a memref.copy.
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/Operation.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "enforce-immutable-func-args"
+#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
+#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
+
+using namespace mlir;
+
+namespace mlir {
+namespace bufferization {
+#define GEN_PASS_DEF_ENFORCEIMMUTABLEFUNCARGS
+#include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
+} // namespace bufferization
+} // namespace mlir
+
+// Checks if there is any operation which tries to write
+// into `buffer`.
+// This method assumes buffer has `MemRefType`.
+static bool isWrittenTo(Value buffer);
+
+namespace {
+/// This pass allocates a new buffer for each input argument of the function
+/// which is being written to, also copying it into the allocated buffer.
+/// This will avoid in place memory updates for the kernel's arguments and
+/// make them immutable/read-only buffers.
+struct EnforceImmutableFuncArgsPass
+ : public bufferization::impl::EnforceImmutableFuncArgsBase<
+ EnforceImmutableFuncArgsPass> {
+ void runOnOperation() final;
+};
+} // end anonymous namespace.
+
+static bool isWrittenTo(Value buffer) {
+ assert(isa<MemRefType>(buffer.getType()));
+
+ for (auto user : buffer.getUsers()) {
+ if (hasEffect<MemoryEffects::Write>(user, buffer))
+ return true;
+ if (auto viewLikeOp = dyn_cast<ViewLikeOpInterface>(user)) {
+ assert(viewLikeOp->getNumResults() == 1);
+ if (isWrittenTo(viewLikeOp->getResult(0)))
+ return true;
+ }
+ }
+ return false;
+}
+
+void EnforceImmutableFuncArgsPass::runOnOperation() {
+
+ func::FuncOp funcOp = getOperation();
+
+ LDBG("enforcing immutable function arguments in func " << funcOp.getName());
+
+ IRRewriter rewriter(funcOp->getContext());
+ rewriter.setInsertionPointToStart(&funcOp.getBody().front());
+ for (auto argument : funcOp.getArguments()) {
+
+ auto argType = dyn_cast<MemRefType>(argument.getType());
+ if (!argType) {
+ emitError(argument.getLoc(),
+ "function has argument with non memref type");
+ return signalPassFailure();
+ }
+
+ if (!isWrittenTo(argument))
+ continue;
+
+ LDBG("Found a function argument is being written to " << argument);
+ Value allocatedMemref =
+ rewriter.create<memref::AllocOp>(funcOp.getLoc(), argType);
+ rewriter.replaceAllUsesWith(argument, allocatedMemref);
+ rewriter.create<memref::CopyOp>(funcOp.getLoc(), argument, allocatedMemref);
+ }
+}
+
+//===----------------------------------------------------------------------===//
+// EnforceImmutableFuncArgs construction
+//===----------------------------------------------------------------------===//
+
+std::unique_ptr<Pass>
+mlir::bufferization::createEnforceImmutableFuncArgsPass() {
+ return std::make_unique<EnforceImmutableFuncArgsPass>();
+}
diff --git a/mlir/test/Dialect/Bufferization/Transforms/enforce-immutable-func-args.mlir b/mlir/test/Dialect/Bufferization/Transforms/enforce-immutable-func-args.mlir
new file mode 100644
index 00000000000000..13019d2fbf5af4
--- /dev/null
+++ b/mlir/test/Dialect/Bufferization/Transforms/enforce-immutable-func-args.mlir
@@ -0,0 +1,248 @@
+// RUN: mlir-opt --split-input-file --enforce-immutable-func-args %s -o - | FileCheck %s
+
+
+// CHECK-LABEL: func.func @func_no_input() {
+// CHECK: return
+// CHECK: }
+
+func.func @func_no_input() {
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func.func private @func_with_returned_argument(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<1x13x21x3xf32>) -> memref<1x13x21x3xf32> {
+// CHECK: return %[[VAL_0]] : memref<1x13x21x3xf32>
+// CHECK: }
+
+func.func private @func_with_returned_argument(%arg0: memref<1x13x21x3xf32>) -> (memref<1x13x21x3xf32>) {
+ return %arg0 : memref<1x13x21x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func private @func_with_modified_argument_directly(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<1x13x21x3xf32>, %[[VAL_1:.*]]: memref<1x13x21x3xf32>) -> memref<1x13x21x3xf32> {
+// CHECK: %[[VAL_2:.*]] = memref.alloc() : memref<1x13x21x3xf32>
+// CHECK: memref.copy %[[VAL_0]], %[[VAL_2]] : memref<1x13x21x3xf32> to memref<1x13x21x3xf32>
+// CHECK: %[[VAL_3:.*]] = memref.alloc() {alignment = 64 : i64} : memref<1x13x21x3xf32>
+// CHECK: linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[VAL_2]], %[[VAL_1]] : memref<1x13x21x3xf32>, memref<1x13x21x3xf32>) outs(%[[VAL_2]] : memref<1x13x21x3xf32>) {
+// CHECK: ^bb0(%[[VAL_4:.*]]: f32, %[[VAL_5:.*]]: f32, %[[VAL_6:.*]]: f32):
+// CHECK: %[[VAL_7:.*]] = arith.addf %[[VAL_4]], %[[VAL_5]] : f32
+// CHECK: linalg.yield %[[VAL_7]] : f32
+// CHECK: }
+// CHECK: return %[[VAL_3]] : memref<1x13x21x3xf32>
+// CHECK: }
+
+func.func private @func_with_modified_argument_directly(%arg0: memref<1x13x21x3xf32>, %arg1: memref<1x13x21x3xf32>) -> (memref<1x13x21x3xf32>){
+ %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x13x21x3xf32>
+ linalg.generic {
+ indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+ ],
+ iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+ }
+ ins(%arg0, %arg1 : memref<1x13x21x3xf32>, memref<1x13x21x3xf32>)
+ outs(%arg0 : memref<1x13x21x3xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %0 = arith.addf %in, %in_0 : f32
+ linalg.yield %0 : f32
+ }
+ return %alloc : memref<1x13x21x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func private @func_with_modified_argument_directly_and_returned(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<1x13x21x3xf32>, %[[VAL_1:.*]]: memref<1x13x21x3xf32>) -> memref<1x13x21x3xf32> {
+// CHECK: %[[VAL_2:.*]] = memref.alloc() : memref<1x13x21x3xf32>
+// CHECK: memref.copy %[[VAL_0]], %[[VAL_2]] : memref<1x13x21x3xf32> to memref<1x13x21x3xf32>
+// CHECK: linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[VAL_2]], %[[VAL_1]] : memref<1x13x21x3xf32>, memref<1x13x21x3xf32>) outs(%[[VAL_2]] : memref<1x13x21x3xf32>) {
+// CHECK: ^bb0(%[[VAL_3:.*]]: f32, %[[VAL_4:.*]]: f32, %[[VAL_5:.*]]: f32):
+// CHECK: %[[VAL_6:.*]] = arith.addf %[[VAL_3]], %[[VAL_4]] : f32
+// CHECK: linalg.yield %[[VAL_6]] : f32
+// CHECK: }
+// CHECK: return %[[VAL_2]] : memref<1x13x21x3xf32>
+// CHECK: }
+
+func.func private @func_with_modified_argument_directly_and_returned(%arg0: memref<1x13x21x3xf32>, %arg1: memref<1x13x21x3xf32>) -> (memref<1x13x21x3xf32>){
+ linalg.generic {
+ indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+ ],
+ iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+ }
+ ins(%arg0, %arg1 : memref<1x13x21x3xf32>, memref<1x13x21x3xf32>)
+ outs(%arg0 : memref<1x13x21x3xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %0 = arith.addf %in, %in_0 : f32
+ linalg.yield %0 : f32
+ }
+ return %arg0 : memref<1x13x21x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func private @func_with_modified_argument_directly_twice(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<1x13x21x3xf32>, %[[VAL_1:.*]]: memref<1x13x21x3xf32>) -> memref<1x13x21x3xf32> {
+// CHECK: %[[VAL_2:.*]] = memref.alloc() : memref<1x13x21x3xf32>
+// CHECK: memref.copy %[[VAL_0]], %[[VAL_2]] : memref<1x13x21x3xf32> to memref<1x13x21x3xf32>
+// CHECK: %[[VAL_3:.*]] = memref.alloc() {alignment = 64 : i64} : memref<1x13x21x3xf32>
+// CHECK: linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[VAL_2]], %[[VAL_1]] : memref<1x13x21x3xf32>, memref<1x13x21x3xf32>) outs(%[[VAL_2]] : memref<1x13x21x3xf32>) {
+// CHECK: ^bb0(%[[VAL_4:.*]]: f32, %[[VAL_5:.*]]: f32, %[[VAL_6:.*]]: f32):
+// CHECK: %[[VAL_7:.*]] = arith.addf %[[VAL_4]], %[[VAL_5]] : f32
+// CHECK: linalg.yield %[[VAL_7]] : f32
+// CHECK: }
+// CHECK: linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[VAL_2]], %[[VAL_1]] : memref<1x13x21x3xf32>, memref<1x13x21x3xf32>) outs(%[[VAL_2]] : memref<1x13x21x3xf32>) {
+// CHECK: ^bb0(%[[VAL_8:.*]]: f32, %[[VAL_9:.*]]: f32, %[[VAL_10:.*]]: f32):
+// CHECK: %[[VAL_11:.*]] = arith.addf %[[VAL_8]], %[[VAL_9]] : f32
+// CHECK: linalg.yield %[[VAL_11]] : f32
+// CHECK: }
+// CHECK: return %[[VAL_3]] : memref<1x13x21x3xf32>
+// CHECK: }
+
+func.func private @func_with_modified_argument_directly_twice(%arg0: memref<1x13x21x3xf32>, %arg1: memref<1x13x21x3xf32>) -> (memref<1x13x21x3xf32>){
+ %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x13x21x3xf32>
+ linalg.generic {
+ indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+ ],
+ iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+ }
+ ins(%arg0, %arg1 : memref<1x13x21x3xf32>, memref<1x13x21x3xf32>)
+ outs(%arg0 : memref<1x13x21x3xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %0 = arith.addf %in, %in_0 : f32
+ linalg.yield %0 : f32
+ }
+ linalg.generic {
+ indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+ ],
+ iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+ }
+ ins(%arg0, %arg1 : memref<1x13x21x3xf32>, memref<1x13x21x3xf32>)
+ outs(%arg0 : memref<1x13x21x3xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %0 = arith.addf %in, %in_0 : f32
+ linalg.yield %0 : f32
+ }
+ return %alloc : memref<1x13x21x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func private @func_with_modified_argument_directly(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<5xi32, 1>, %[[VAL_1:.*]]: memref<5xi32, 1>, %[[VAL_2:.*]]: memref<5xi32, 1>) -> memref<5xi32, 1> {
+// CHECK: %[[VAL_3:.*]] = memref.alloc() : memref<5xi32, 1>
+// CHECK: memref.copy %[[VAL_2]], %[[VAL_3]] : memref<5xi32, 1> to memref<5xi32, 1>
+// CHECK: %[[VAL_4:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_5:.*]] = arith.constant 5 : index
+// CHECK: %[[VAL_6:.*]] = arith.constant 0 : index
+// CHECK: scf.for %[[VAL_7:.*]] = %[[VAL_6]] to %[[VAL_5]] step %[[VAL_4]] {
+// CHECK: %[[VAL_8:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_7]]] : memref<5xi32, 1>
+// CHECK: %[[VAL_9:.*]] = arith.index_cast %[[VAL_8]] : i32 to index
+// CHECK: %[[VAL_10:.*]] = memref.load %[[VAL_1]]{{\[}}%[[VAL_7]]] : memref<5xi32, 1>
+// CHECK: %[[VAL_11:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_9]]] : memref<5xi32, 1>
+// CHECK: %[[VAL_12:.*]] = arith.addi %[[VAL_10]], %[[VAL_11]] : i32
+// CHECK: memref.store %[[VAL_12]], %[[VAL_3]]{{\[}}%[[VAL_9]]] : memref<5xi32, 1>
+// CHECK: }
+// CHECK: %[[VAL_13:.*]] = memref.alloc() : memref<5xi32, 1>
+// CHECK: memref.copy %[[VAL_3]], %[[VAL_13]] : memref<5xi32, 1> to memref<5xi32, 1>
+// CHECK: return %[[VAL_13]] : memref<5xi32, 1>
+// CHECK: }
+
+func.func private @func_with_modified_argument_directly(%arg0: memref<5xi32, 1>, %arg1: memref<5xi32, 1>, %arg2: memref<5xi32, 1>) -> (memref<5xi32, 1>){
+ %c1 = arith.constant 1 : index
+ %c5 = arith.constant 5 : index
+ %c0 = arith.constant 0 : index
+ scf.for %arg3 = %c0 to %c5 step %c1 {
+ %0 = memref.load %arg0[%arg3] : memref<5xi32, 1>
+ %1 = arith.index_cast %0 : i32 to index
+ %2 = memref.load %arg1[%arg3] : memref<5xi32, 1>
+ %3 = memref.load %arg2[%1] : memref<5xi32, 1>
+ %4 = arith.addi %2, %3 : i32
+ memref.store %4, %arg2[%1] : memref<5xi32, 1>
+ }
+ %alloc = memref.alloc() : memref<5xi32, 1>
+ memref.copy %arg2, %alloc : memref<5xi32, 1> to memref<5xi32, 1>
+ return %alloc : memref<5xi32, 1>
+}
+
+// -----
+
+// CHECK-LABEL: func.func private @func_with_modified_argument_indirectly(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<3x3x4xf32, 1>) -> memref<3x3x4xf32, 1> {
+// CHECK: %[[VAL_1:.*]] = memref.alloc() : memref<3x3x4xf32, 1>
+// CHECK: memref.copy %[[VAL_0]], %[[VAL_1]] : memref<3x3x4xf32, 1> to memref<3x3x4xf32, 1>
+// CHECK: %[[VAL_2:.*]] = memref.collapse_shape %[[VAL_1]] {{\[\[}}0, 1], [2]] : memref<3x3x4xf32, 1> into memref<9x4xf32, 1>
+// CHECK: %[[VAL_3:.*]] = memref.expand_shape %[[VAL_2]] {{\[\[}}0, 1], [2]] output_shape [3, 3, 4] : memref<9x4xf32, 1> into memref<3x3x4xf32, 1>
+// CHECK: linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel", "parallel"]} outs(%[[VAL_3]] : memref<3x3x4xf32, 1>) {
+// CHECK: ^bb0(%[[VAL_4:.*]]: f32):
+// CHECK: %[[VAL_5:.*]] = arith.addf %[[VAL_4]], %[[VAL_4]] : f32
+// CHECK: linalg.yield %[[VAL_5]] : f32
+// CHECK: }
+// CHECK: return %[[VAL_3]] : memref<3x3x4xf32, 1>
+// CHECK: }
+
+func.func private @func_with_modified_argument_indirectly(%arg0: memref<3x3x4xf32, 1>) -> (memref<3x3x4xf32, 1>) {
+ %collapse_arg = memref.collapse_shape %arg0 [[0, 1], [2]] : memref<3x3x4xf32, 1> into memref<9x4xf32, 1>
+ %expand_arg = memref.expand_shape %collapse_arg [[0, 1], [2]] output_shape [3, 3, 4] : memref<9x4xf32, 1> into memref<3x3x4xf32, 1>
+ linalg.generic {
+ indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
+ iterator_types = ["parallel", "parallel", "parallel"]
+ }
+ outs(%expand_arg : memref<3x3x4xf32, 1>) {
+ ^bb0(%out: f32):
+ %0 = arith.addf %out, %out : f32
+ linalg.yield %0 : f32
+ }
+ return %expand_arg: memref<3x3x4xf32, 1>
+}
+
+// -----
+
+// CHECK-LABEL: func.func private @func_with_modified_argument_subview(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<2x4x4xi32, 1>) -> memref<4x4xi32, 1> {
+// CHECK: %[[VAL_1:.*]] = memref.alloc() : memref<2x4x4xi32, 1>
+// CHECK: memref.copy %[[VAL_0]], %[[VAL_1]] : memref<2x4x4xi32, 1> to memref<2x4x4xi32, 1>
+// CHECK: %[[VAL_2:.*]] = memref.subview %[[VAL_1]][0, 0, 0] [1, 4, 4] [1, 1, 1] : memref<2x4x4xi32, 1> to memref<1x4x4xi32, strided<[16, 4, 1]>, 1>
+// CHECK: %[[VAL_3:.*]] = memref.collapse_shape %[[VAL_2]] {{\[\[}}0, 1], [2]] : memref<1x4x4xi32, strided<[16, 4, 1]>, 1> into memref<4x4xi32, strided<[4, 1]>, 1>
+// CHECK: %[[VAL_4:.*]] = memref.cast %[[VAL_3]] : memref<4x4xi32, strided<[4, 1]>, 1> to memref<4x4xi32, 1>
+// CHECK: linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel"]} outs(%[[VAL_4]] : memref<4x4xi32, 1>) {
+// CHECK: ^bb0(%[[VAL_5:.*]]: i32):
+// CHECK: %[[VAL_6:.*]] = arith.addi %[[VAL_5]], %[[VAL_5]] : i32
+// CHECK: linalg.yield %[[VAL_6]] : i32
+// CHECK: }
+// CHECK: %[[VAL_7:.*]] = memref.alloc() : memref<4x4xi32, 1>
+// CHECK: memref.copy %[[VAL_4]], %[[VAL_7]] : memref<4x4xi32, 1> to memref<4x4xi32, 1>
+// CHECK: return %[[VAL_7]] : memref<4x4xi32, 1>
+// CHECK: }
+
+func.func private @func_with_modified_argument_subview(%arg0: memref<2x4x4xi32, 1>) -> ( memref<4x4xi32, 1>){
+ %subview = memref.subview %arg0[0, 0, 0] [1, 4, 4] [1, 1, 1] : memref<2x4x4xi32, 1> to memref<1x4x4xi32, strided<[16, 4, 1]>, 1>
+ %collapse_shape = memref.collapse_shape %subview [[0, 1], [2]] : memref<1x4x4xi32, strided<[16, 4, 1]>, 1> into memref<4x4xi32, strided<[4, 1]>, 1>
+ %cast = memref.cast %collapse_shape : memref<4x4xi32, strided<[4, 1]>, 1> to memref<4x4xi32, 1>
+ linalg.generic {
+ indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"]
+ }
+ outs(%cas...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/113130
More information about the Mlir-commits
mailing list