[Mlir-commits] [mlir] 2ecf608 - [mlir]Fix compose subview (#80551)
llvmlistbot at llvm.org
Wed Feb 7 11:49:32 PST 2024
Author: lonely eagle
Date: 2024-02-07T20:49:27+01:00
New Revision: 2ecf608829252d7d5b530a03b87817cd948a3386
URL: https://github.com/llvm/llvm-project/commit/2ecf608829252d7d5b530a03b87817cd948a3386
DIFF: https://github.com/llvm/llvm-project/commit/2ecf608829252d7d5b530a03b87817cd948a3386.diff
LOG: [mlir]Fix compose subview (#80551)
I found a bug in `test-compose-subview`; the example below reproduces it.
```
#map = affine_map<() -> ()>
module {
  func.func private @fun(%arg0: memref<10x10xf32>, %arg1: memref<5x5xf32>) -> memref<5x5xf32> {
    %c0 = arith.constant 0 : index
    %c5 = arith.constant 5 : index
    %c1 = arith.constant 1 : index
    %subview = memref.subview %arg0[0, 0] [5, 5] [1, 1] : memref<10x10xf32> to memref<5x5xf32, strided<[10, 1]>>
    %alloc = memref.alloc() : memref<5x5xf32>
    scf.for %arg2 = %c0 to %c5 step %c1 {
      scf.for %arg3 = %c0 to %c5 step %c1 {
        %subview_0 = memref.subview %subview[%arg2, %arg3] [1, 1] [1, 1] : memref<5x5xf32, strided<[10, 1]>> to memref<f32, strided<[], offset: ?>>
        %subview_1 = memref.subview %arg1[%arg2, %arg3] [1, 1] [1, 1] : memref<5x5xf32> to memref<f32, strided<[], offset: ?>>
        %alloc_2 = memref.alloc() : memref<f32>
        linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = []} ins(%subview_0, %subview_1 : memref<f32, strided<[], offset: ?>>, memref<f32, strided<[], offset: ?>>) outs(%alloc_2 : memref<f32>) {
        ^bb0(%in: f32, %in_4: f32, %out: f32):
          %0 = arith.addf %in, %in_4 : f32
          linalg.yield %0 : f32
        }
        %subview_3 = memref.subview %alloc[%arg2, %arg3] [1, 1] [1, 1] : memref<5x5xf32> to memref<f32, strided<[], offset: ?>>
        memref.copy %alloc_2, %subview_3 : memref<f32> to memref<f32, strided<[], offset: ?>>
      }
    }
    return %alloc : memref<5x5xf32>
  }
  func.func @test(%arg0: memref<10x10xf32>, %arg1: memref<5x5xf32>) -> memref<5x5xf32> {
    %0 = call @fun(%arg0, %arg1) : (memref<10x10xf32>, memref<5x5xf32>) -> memref<5x5xf32>
    return %0 : memref<5x5xf32>
  }
}
```
When I run `mlir-opt test.mlir -test-compose-subview`, I get:
```
test.mlir:14:9: error: 'linalg.generic' op expected operand rank (2) to match the result rank of indexing_map #0 (0)
linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = []} ins(%subview_0, %subview_1 : memref<f32, strided<[], offset: ?>>, memref<f32, strided<[], offset: ?>>) outs(%alloc_2 : memref<f32>) {
^
test.mlir:14:9: note: see current operation:
"linalg.generic"(%4, %5, %6) <{indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>, affine_map<() -> ()>], iterator_types = [], operandSegmentSizes = array<i32: 2, 1>}> ({
^bb0(%arg4: f32, %arg5: f32, %arg6: f32):
%8 = "arith.addf"(%arg4, %arg5) <{fastmath = #arith.fastmath<none>}> : (f32, f32) -> f32
"linalg.yield"(%8) : (f32) -> ()
}) : (memref<1x1xf32, strided<[10, 1], offset: ?>>, memref<f32, strided<[], offset: ?>>, memref<f32>) -> ()
```
This PR fixes that by creating the composed `memref.subview` with the original op's result type (see the `op.getType()` change below), so rank-reducing subviews keep their type instead of getting one inferred from the new offsets, sizes, and strides. I have also extended the pattern to handle strides greater than 1:
```
func.func private @Unknown0(%arg0: memref<10x10xf32>, %arg1: memref<5x5xf32>) -> memref<5x5xf32> {
  %c0 = arith.constant 0 : index
  %c5 = arith.constant 5 : index
  %c1 = arith.constant 1 : index
  %subview = memref.subview %arg0[0, 0] [5, 5] [2, 2] : memref<10x10xf32> to memref<5x5xf32, strided<[20, 2]>>
  %alloc = memref.alloc() : memref<5x5xf32>
  scf.for %arg2 = %c0 to %c5 step %c1 {
    scf.for %arg3 = %c0 to %c5 step %c1 {
      %subview_0 = memref.subview %subview[%arg2, %arg3] [1, 1] [1, 1] : memref<5x5xf32, strided<[20, 2]>> to memref<f32, strided<[], offset: ?>>
      %subview_1 = memref.subview %arg1[%arg2, %arg3] [1, 1] [1, 1] : memref<5x5xf32> to memref<f32, strided<[], offset: ?>>
      %alloc_2 = memref.alloc() : memref<f32>
      linalg.generic {indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>, affine_map<() -> ()>], iterator_types = []} ins(%subview_0, %subview_1 : memref<f32, strided<[], offset: ?>>, memref<f32, strided<[], offset: ?>>) outs(%alloc_2 : memref<f32>) {
      ^bb0(%in: f32, %in_4: f32, %out: f32):
        %0 = arith.addf %in, %in_4 : f32
        linalg.yield %0 : f32
      }
      %subview_3 = memref.subview %alloc[%arg2, %arg3] [1, 1] [1, 1] : memref<5x5xf32> to memref<f32, strided<[], offset: ?>>
      memref.copy %alloc_2, %subview_3 : memref<f32> to memref<f32, strided<[], offset: ?>>
    }
  }
  return %alloc : memref<5x5xf32>
}

$ mlir-opt test.mlir -test-compose-subview

#map = affine_map<()[s0] -> (s0 * 2)>
#map1 = affine_map<() -> ()>
module {
  func.func private @Unknown0(%arg0: memref<10x10xf32>, %arg1: memref<5x5xf32>) -> memref<5x5xf32> {
    %c0 = arith.constant 0 : index
    %c5 = arith.constant 5 : index
    %c1 = arith.constant 1 : index
    %alloc = memref.alloc() : memref<5x5xf32>
    scf.for %arg2 = %c0 to %c5 step %c1 {
      scf.for %arg3 = %c0 to %c5 step %c1 {
        %0 = affine.apply #map()[%arg2]
        %1 = affine.apply #map()[%arg3]
        %subview = memref.subview %arg0[%0, %1] [1, 1] [2, 2] : memref<10x10xf32> to memref<f32, strided<[], offset: ?>>
        %subview_0 = memref.subview %arg1[%arg2, %arg3] [1, 1] [1, 1] : memref<5x5xf32> to memref<f32, strided<[], offset: ?>>
        %alloc_1 = memref.alloc() : memref<f32>
        linalg.generic {indexing_maps = [#map1, #map1, #map1], iterator_types = []} ins(%subview, %subview_0 : memref<f32, strided<[], offset: ?>>, memref<f32, strided<[], offset: ?>>) outs(%alloc_1 : memref<f32>) {
        ^bb0(%in: f32, %in_3: f32, %out: f32):
          %2 = arith.addf %in, %in_3 : f32
          linalg.yield %2 : f32
        }
        %subview_2 = memref.subview %alloc[%arg2, %arg3] [1, 1] [1, 1] : memref<5x5xf32> to memref<f32, strided<[], offset: ?>>
        memref.copy %alloc_1, %subview_2 : memref<f32> to memref<f32, strided<[], offset: ?>>
      }
    }
    return %alloc : memref<5x5xf32>
  }
}
```
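The composition arithmetic can be checked by hand on one of the new test cases: strides multiply dimension-wise, and each offset composes as `sourceOffset + opOffset * sourceStride`. A worked instance, taken from the updated test file (the `%composed` name is illustrative):
```
%0 = memref.subview %input[2, 256] [2, 256] [2, 2] : memref<4x1024xf32> to memref<2x256xf32, strided<[2048, 2], offset: 2304>>
%1 = memref.subview %0[1, 64] [1, 64] [2, 2] : memref<2x256xf32, strided<[2048, 2], offset: 2304>> to memref<1x64xf32, strided<[4096, 4], offset: 4480>>

// -test-compose-subview folds the chain into a single subview:
//   offsets: [2 + 1*2, 256 + 64*2] = [4, 384]
//   strides: [2*2, 2*2]            = [4, 4]
%composed = memref.subview %input[4, 384] [1, 64] [4, 4] : memref<4x1024xf32> to memref<1x64xf32, strided<[4096, 4], offset: 4480>>
```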
Added:
Modified:
mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp
mlir/test/Transforms/compose-subview.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp b/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp
index 431d270b0a2cb7..79c3277c1280d8 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp
@@ -24,8 +24,7 @@ using namespace mlir;
namespace {
-// Replaces a subview of a subview with a single subview. Only supports subview
-// ops with static sizes and static strides of 1 (both static and dynamic
+// Replaces a subview of a subview with a single subview (both static and dynamic
// offsets are supported).
struct ComposeSubViewOpPattern : public OpRewritePattern<memref::SubViewOp> {
using OpRewritePattern::OpRewritePattern;
@@ -51,64 +50,78 @@ struct ComposeSubViewOpPattern : public OpRewritePattern<memref::SubViewOp> {
}
// Offsets, sizes and strides OpFoldResult for the combined 'SubViewOp'.
- SmallVector<OpFoldResult> offsets, sizes, strides;
-
- // Because we only support input strides of 1, the output stride is also
- // always 1.
- if (llvm::all_of(strides, [](OpFoldResult &valueOrAttr) {
- Attribute attr = llvm::dyn_cast_if_present<Attribute>(valueOrAttr);
- return attr && cast<IntegerAttr>(attr).getInt() == 1;
- })) {
- strides = SmallVector<OpFoldResult>(sourceOp.getMixedStrides().size(),
- rewriter.getI64IntegerAttr(1));
- } else {
- return failure();
+ SmallVector<OpFoldResult> offsets, sizes, strides,
+ opStrides = op.getMixedStrides(),
+ sourceStrides = sourceOp.getMixedStrides();
+
+ // The output stride in each dimension is equal to the product of the
+ // strides of 'sourceOp' and 'op' in that dimension.
+ int64_t sourceStrideValue;
+ for (auto &&[opStride, sourceStride] :
+ llvm::zip(opStrides, sourceStrides)) {
+ Attribute opStrideAttr = dyn_cast_if_present<Attribute>(opStride);
+ Attribute sourceStrideAttr = dyn_cast_if_present<Attribute>(sourceStride);
+ if (!opStrideAttr || !sourceStrideAttr)
+ return failure();
+ sourceStrideValue = cast<IntegerAttr>(sourceStrideAttr).getInt();
+ strides.push_back(rewriter.getI64IntegerAttr(
+ cast<IntegerAttr>(opStrideAttr).getInt() * sourceStrideValue));
}
// The rules for calculating the new offsets and sizes are:
// * Multiple subview offsets for a given dimension compose additively.
- // ("Offset by m" followed by "Offset by n" == "Offset by m + n")
+ // ("Offset by m and Stride by k" followed by "Offset by n" == "Offset by
+ // m + n * k")
// * Multiple sizes for a given dimension compose by taking the size of the
// final subview and ignoring the rest. ("Take m values" followed by "Take
// n values" == "Take n values") This size must also be the smallest one
// by definition (a subview needs to be the same size as or smaller than
// its source along each dimension; presumably subviews that are larger
// than their sources are disallowed by validation).
- for (auto it : llvm::zip(op.getMixedOffsets(), sourceOp.getMixedOffsets(),
- op.getMixedSizes())) {
- auto opOffset = std::get<0>(it);
- auto sourceOffset = std::get<1>(it);
- auto opSize = std::get<2>(it);
-
+ for (auto &&[opOffset, sourceOffset, sourceStride, opSize] :
+ llvm::zip(op.getMixedOffsets(), sourceOp.getMixedOffsets(),
+ sourceOp.getMixedStrides(), op.getMixedSizes())) {
// We only support static sizes.
if (opSize.is<Value>()) {
return failure();
}
-
sizes.push_back(opSize);
Attribute opOffsetAttr = llvm::dyn_cast_if_present<Attribute>(opOffset),
sourceOffsetAttr =
- llvm::dyn_cast_if_present<Attribute>(sourceOffset);
-
+ llvm::dyn_cast_if_present<Attribute>(sourceOffset),
+ sourceStrideAttr =
+ llvm::dyn_cast_if_present<Attribute>(sourceStride);
if (opOffsetAttr && sourceOffsetAttr) {
+
// If both offsets are static we can simply calculate the combined
// offset statically.
offsets.push_back(rewriter.getI64IntegerAttr(
- cast<IntegerAttr>(opOffsetAttr).getInt() +
+ cast<IntegerAttr>(opOffsetAttr).getInt() *
+ cast<IntegerAttr>(sourceStrideAttr).getInt() +
cast<IntegerAttr>(sourceOffsetAttr).getInt()));
} else {
- // When either offset is dynamic, we must emit an additional affine
- // transformation to add the two offsets together dynamically.
- AffineExpr expr = rewriter.getAffineConstantExpr(0);
+ AffineExpr expr;
SmallVector<Value> affineApplyOperands;
- for (auto valueOrAttr : {opOffset, sourceOffset}) {
- if (auto attr = llvm::dyn_cast_if_present<Attribute>(valueOrAttr)) {
- expr = expr + cast<IntegerAttr>(attr).getInt();
- } else {
- expr =
- expr + rewriter.getAffineSymbolExpr(affineApplyOperands.size());
- affineApplyOperands.push_back(valueOrAttr.get<Value>());
- }
+
+ // Make 'expr' add 'sourceOffset'.
+ if (auto attr = llvm::dyn_cast_if_present<Attribute>(sourceOffset)) {
+ expr =
+ rewriter.getAffineConstantExpr(cast<IntegerAttr>(attr).getInt());
+ } else {
+ expr = rewriter.getAffineSymbolExpr(affineApplyOperands.size());
+ affineApplyOperands.push_back(sourceOffset.get<Value>());
+ }
+
+ // Multiply 'opOffset' by 'sourceStride' and make the 'expr' add the
+ // result.
+ if (auto attr = llvm::dyn_cast_if_present<Attribute>(opOffset)) {
+ expr = expr + cast<IntegerAttr>(attr).getInt() *
+ cast<IntegerAttr>(sourceStrideAttr).getInt();
+ } else {
+ expr =
+ expr + rewriter.getAffineSymbolExpr(affineApplyOperands.size()) *
+ cast<IntegerAttr>(sourceStrideAttr).getInt();
+ affineApplyOperands.push_back(opOffset.get<Value>());
}
AffineMap map = AffineMap::get(0, affineApplyOperands.size(), expr);
@@ -120,8 +133,8 @@ struct ComposeSubViewOpPattern : public OpRewritePattern<memref::SubViewOp> {
// This replaces 'op' but leaves 'sourceOp' alone; if it no longer has any
// uses it can be removed by a (separate) dead code elimination pass.
- rewriter.replaceOpWithNewOp<memref::SubViewOp>(op, sourceOp.getSource(),
- offsets, sizes, strides);
+ rewriter.replaceOpWithNewOp<memref::SubViewOp>(
+ op, op.getType(), sourceOp.getSource(), offsets, sizes, strides);
return success();
}
};
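Passing `op.getType()` to the rebuilt op is what fixes the original report: the composed subview keeps the rank-reduced result type of `op` instead of a type inferred from the new offsets, sizes, and strides. A minimal sketch of the rank-reducing case from the failing example above (the `%composed` name is illustrative):
```
// The inner subview drops both unit dims, so its type is
// memref<f32, strided<[], offset: ?>>, not memref<1x1xf32, ...>.
%subview = memref.subview %arg0[0, 0] [5, 5] [1, 1] : memref<10x10xf32> to memref<5x5xf32, strided<[10, 1]>>
%subview_0 = memref.subview %subview[%arg2, %arg3] [1, 1] [1, 1] : memref<5x5xf32, strided<[10, 1]>> to memref<f32, strided<[], offset: ?>>

// With the explicit result type, the composed op stays rank-reduced;
// type inference would instead yield the memref<1x1xf32, strided<[10, 1], offset: ?>>
// that tripped the linalg.generic verifier above.
%composed = memref.subview %arg0[%arg2, %arg3] [1, 1] [1, 1] : memref<10x10xf32> to memref<f32, strided<[], offset: ?>>
```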
diff --git a/mlir/test/Transforms/compose-subview.mlir b/mlir/test/Transforms/compose-subview.mlir
index cb8ebcb2bf9e9b..22ffd836c68ed7 100644
--- a/mlir/test/Transforms/compose-subview.mlir
+++ b/mlir/test/Transforms/compose-subview.mlir
@@ -1,8 +1,9 @@
// RUN: mlir-opt %s -test-compose-subview -split-input-file | FileCheck %s
-func.func @main(%input: memref<4x1024xf32>) -> memref<1x128xf32, strided<[1024, 1], offset: 3456>> {
- // CHECK: subview %arg0[3, 384] [1, 128] [1, 1]
- // CHECK-SAME: memref<4x1024xf32> to memref<1x128xf32, strided<[1024, 1], offset: 3456>>
+// CHECK-LABEL: func.func @subview_strided(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<4x1024xf32>) -> memref<1x128xf32, strided<[1024, 1], offset: 3456>> {
+func.func @subview_strided(%input: memref<4x1024xf32>) -> memref<1x128xf32, strided<[1024, 1], offset: 3456>> {
+ // CHECK: %[[VAL_1:.*]] = memref.subview %[[VAL_0]][3, 384] [1, 128] [1, 1] : memref<4x1024xf32> to memref<1x128xf32, strided<[1024, 1], offset: 3456>>
%0 = memref.subview %input[2, 256] [2, 256] [1, 1] : memref<4x1024xf32> to memref<2x256xf32, strided<[1024, 1], offset: 2304>>
%1 = memref.subview %0[1, 128] [1, 128] [1, 1] : memref<2x256xf32, strided<[1024, 1], offset: 2304>> to memref<1x128xf32, strided<[1024, 1], offset: 3456>>
return %1 : memref<1x128xf32, strided<[1024, 1], offset: 3456>>
@@ -10,9 +11,10 @@ func.func @main(%input: memref<4x1024xf32>) -> memref<1x128xf32, strided<[1024,
// -----
-func.func @main(%input: memref<4x1024xf32>) -> memref<1x10xf32, strided<[1024, 1], offset: 3745>> {
- // CHECK: subview %arg0[3, 673] [1, 10] [1, 1]
- // CHECK-SAME: memref<4x1024xf32> to memref<1x10xf32, strided<[1024, 1], offset: 3745>>
+// CHECK-LABEL: func.func @subview_strided(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<4x1024xf32>) -> memref<1x10xf32, strided<[1024, 1], offset: 3745>> {
+func.func @subview_strided(%input: memref<4x1024xf32>) -> memref<1x10xf32, strided<[1024, 1], offset: 3745>> {
+ // CHECK: %[[VAL_1:.*]] = memref.subview %[[VAL_0]][3, 673] [1, 10] [1, 1] : memref<4x1024xf32> to memref<1x10xf32, strided<[1024, 1], offset: 3745>>
%0 = memref.subview %input[1, 512] [3, 256] [1, 1] : memref<4x1024xf32> to memref<3x256xf32, strided<[1024, 1], offset: 1536>>
%1 = memref.subview %0[1, 128] [2, 128] [1, 1] : memref<3x256xf32, strided<[1024, 1], offset: 1536>> to memref<2x128xf32, strided<[1024, 1], offset: 2688>>
%2 = memref.subview %1[1, 33] [1, 10] [1, 1] : memref<2x128xf32, strided<[1024, 1], offset: 2688>> to memref<1x10xf32, strided<[1024, 1], offset: 3745>>
@@ -21,12 +23,13 @@ func.func @main(%input: memref<4x1024xf32>) -> memref<1x10xf32, strided<[1024, 1
// -----
-func.func @main(%input: memref<4x1024xf32>) -> memref<1x128xf32, strided<[1024, 1], offset: ?>> {
- // CHECK: [[CST_3:%.*]] = arith.constant 3 : index
+// CHECK-LABEL: func.func @subview_strided(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<4x1024xf32>) -> memref<1x128xf32, strided<[1024, 1], offset: ?>> {
+func.func @subview_strided(%input: memref<4x1024xf32>) -> memref<1x128xf32, strided<[1024, 1], offset: ?>> {
+ // CHECK: %[[VAL_1:.*]] = arith.constant 3 : index
%cst_1 = arith.constant 1 : index
%cst_2 = arith.constant 2 : index
- // CHECK: subview %arg0{{\[}}[[CST_3]], 384] [1, 128] [1, 1]
- // CHECK-SAME: memref<4x1024xf32> to memref<1x128xf32, strided<[1024, 1], offset: ?>>
+ // CHECK: %[[VAL_2:.*]] = memref.subview %[[VAL_0]]{{\[}}%[[VAL_1]], 384] [1, 128] [1, 1] : memref<4x1024xf32> to memref<1x128xf32, strided<[1024, 1], offset: ?>>
%0 = memref.subview %input[%cst_2, 256] [2, 256] [1, 1] : memref<4x1024xf32> to memref<2x256xf32, strided<[1024, 1], offset: ?>>
%1 = memref.subview %0[%cst_1, 128] [1, 128] [1, 1] : memref<2x256xf32, strided<[1024, 1], offset: ?>> to memref<1x128xf32, strided<[1024, 1], offset: ?>>
return %1 : memref<1x128xf32, strided<[1024, 1], offset: ?>>
@@ -34,14 +37,67 @@ func.func @main(%input: memref<4x1024xf32>) -> memref<1x128xf32, strided<[1024,
// -----
-func.func @main(%input: memref<4x1024xf32>) -> memref<1x128xf32, strided<[1024, 1], offset: ?>> {
- // CHECK: [[CST_3:%.*]] = arith.constant 3 : index
+// CHECK-LABEL: func.func @subview_strided(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<4x1024xf32>) -> memref<1x128xf32, strided<[1024, 1], offset: ?>> {
+func.func @subview_strided(%input: memref<4x1024xf32>) -> memref<1x128xf32, strided<[1024, 1], offset: ?>> {
+ // CHECK: %[[VAL_1:.*]] = arith.constant 3 : index
%cst_2 = arith.constant 2 : index
- // CHECK: [[CST_384:%.*]] = arith.constant 384 : index
+ // CHECK: %[[VAL_2:.*]] = arith.constant 384 : index
%cst_128 = arith.constant 128 : index
- // CHECK: subview %arg0{{\[}}[[CST_3]], [[CST_384]]] [1, 128] [1, 1]
- // CHECK-SAME: memref<4x1024xf32> to memref<1x128xf32, strided<[1024, 1], offset: ?>>
+ // CHECK: %[[VAL_3:.*]] = memref.subview %[[VAL_0]]{{\[}}%[[VAL_1]], %[[VAL_2]]] [1, 128] [1, 1] : memref<4x1024xf32> to memref<1x128xf32, strided<[1024, 1], offset: ?>>
%0 = memref.subview %input[%cst_2, 256] [2, 256] [1, 1] : memref<4x1024xf32> to memref<2x256xf32, strided<[1024, 1], offset: ?>>
%1 = memref.subview %0[1, %cst_128] [1, 128] [1, 1] : memref<2x256xf32, strided<[1024, 1], offset: ?>> to memref<1x128xf32, strided<[1024, 1], offset: ?>>
return %1 : memref<1x128xf32, strided<[1024, 1], offset: ?>>
}
+
+// -----
+
+// CHECK-LABEL: func.func @subview_strided(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<4x1024xf32>) -> memref<1x64xf32, strided<[4096, 4], offset: 4480>> {
+func.func @subview_strided(%input: memref<4x1024xf32>) -> memref<1x64xf32, strided<[4096, 4], offset: 4480>> {
+ // CHECK: %[[VAL_1:.*]] = memref.subview %[[VAL_0]][4, 384] [1, 64] [4, 4] : memref<4x1024xf32> to memref<1x64xf32, strided<[4096, 4], offset: 4480>>
+ %0 = memref.subview %input[2, 256] [2, 256] [2, 2] : memref<4x1024xf32> to memref<2x256xf32, strided<[2048, 2], offset: 2304>>
+ %1 = memref.subview %0[1, 64] [1, 64] [2, 2] : memref<2x256xf32, strided<[2048, 2], offset: 2304>> to memref<1x64xf32, strided<[4096, 4], offset: 4480>>
+ return %1 : memref<1x64xf32, strided<[4096, 4], offset: 4480>>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @subview_strided(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<30x30xf32>) -> memref<2x2xf32, strided<[240, 8], offset: 217>> {
+func.func @subview_strided(%input: memref<30x30xf32>) -> memref<2x2xf32, strided<[240, 8], offset: 217>> {
+ // CHECK: %[[VAL_1:.*]] = memref.subview %[[VAL_0]][7, 7] [2, 2] [8, 8] : memref<30x30xf32> to memref<2x2xf32, strided<[240, 8], offset: 217>>
+ %0 = memref.subview %input[1, 1] [12, 12] [2, 2] : memref<30x30xf32> to memref<12x12xf32, strided<[60, 2], offset: 31>>
+ %1 = memref.subview %0[1, 1] [5, 5] [2, 2] : memref<12x12xf32, strided<[60, 2], offset: 31>> to memref<5x5xf32, strided<[120, 4], offset: 93>>
+ %2 = memref.subview %1[1, 1] [2, 2] [2, 2] : memref<5x5xf32, strided<[120, 4], offset: 93>> to memref<2x2xf32, strided<[240, 8], offset: 217>>
+ return %2 : memref<2x2xf32, strided<[240, 8], offset: 217>>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @subview_strided(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<4x1024xf32>) -> memref<1x64xf32, strided<[4096, 4], offset: ?>> {
+func.func @subview_strided(%input: memref<4x1024xf32>) -> memref<1x64xf32, strided<[4096, 4], offset: ?>> {
+ // CHECK: %[[VAL_1:.*]] = arith.constant 4 : index
+ %cst_2 = arith.constant 2 : index
+ // CHECK: %[[VAL_2:.*]] = arith.constant 384 : index
+ %cst_64 = arith.constant 64 : index
+ // CHECK: %[[VAL_3:.*]] = memref.subview %[[VAL_0]]{{\[}}%[[VAL_1]], %[[VAL_2]]] [1, 64] [4, 4] : memref<4x1024xf32> to memref<1x64xf32, strided<[4096, 4], offset: ?>>
+ %0 = memref.subview %input[%cst_2, 256] [2, 256] [2, 2] : memref<4x1024xf32> to memref<2x256xf32, strided<[2048, 2], offset: ?>>
+ %1 = memref.subview %0[1, %cst_64] [1, 64] [2, 2] : memref<2x256xf32, strided<[2048, 2], offset: ?>> to memref<1x64xf32, strided<[4096, 4], offset: ?>>
+ return %1 : memref<1x64xf32, strided<[4096, 4], offset: ?>>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @subview_strided(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<4x1024xf32>) -> memref<1x64xf32, strided<[4096, 4], offset: ?>> {
+func.func @subview_strided(%input: memref<4x1024xf32>) -> memref<1x64xf32, strided<[4096, 4], offset: ?>> {
+ // CHECK: %[[VAL_1:.*]] = arith.constant 4 : index
+ %cst_1 = arith.constant 1 : index
+ %cst_2 = arith.constant 2 : index
+ // CHECK: %[[VAL_2:.*]] = memref.subview %[[VAL_0]]{{\[}}%[[VAL_1]], 384] [1, 64] [4, 4] : memref<4x1024xf32> to memref<1x64xf32, strided<[4096, 4], offset: ?>>
+ %0 = memref.subview %input[%cst_2, 256] [2, 256] [2, 2] : memref<4x1024xf32> to memref<2x256xf32, strided<[2048, 2], offset: ?>>
+ %1 = memref.subview %0[%cst_1, 64] [1, 64] [2, 2] : memref<2x256xf32, strided<[2048, 2], offset: ?>> to memref<1x64xf32, strided<[4096, 4], offset: ?>>
+ return %1 : memref<1x64xf32, strided<[4096, 4], offset: ?>>
+}
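When an offset is dynamic, the pattern emits an `affine.apply` evaluating `sourceOffset + opOffset * sourceStride` before building the combined subview, as in the `affine_map<()[s0] -> (s0 * 2)>` output shown earlier. A hypothetical 1-D sketch of the mixed case (dynamic source offset `%i` with source stride 2, static op offset 3; not taken from the commit):
```
%0 = memref.subview %src[%i] [10] [2] : memref<100xf32> to memref<10xf32, strided<[2], offset: ?>>
%1 = memref.subview %0[3] [1] [1] : memref<10xf32, strided<[2], offset: ?>> to memref<1xf32, strided<[2], offset: ?>>

// Composed form: offset %i + 3 * 2, stride 2 * 1 = 2.
%off = affine.apply affine_map<()[s0] -> (s0 + 6)>()[%i]
%composed = memref.subview %src[%off] [1] [2] : memref<100xf32> to memref<1xf32, strided<[2], offset: ?>>
```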