[Mlir-commits] [mlir] df455be - [mlir][MemRef] Fix the simplification of extract_strided_metadata(subview)
Quentin Colombet
llvmlistbot at llvm.org
Tue Oct 18 12:38:02 PDT 2022
Author: Quentin Colombet
Date: 2022-10-18T19:29:49Z
New Revision: df455beedfcd4634450e5782d6bb4986218174e2
URL: https://github.com/llvm/llvm-project/commit/df455beedfcd4634450e5782d6bb4986218174e2
DIFF: https://github.com/llvm/llvm-project/commit/df455beedfcd4634450e5782d6bb4986218174e2.diff
LOG: [mlir][MemRef] Fix the simplification of extract_strided_metadata(subview)
Prior to this patch we were wrongly applying the sub-strides to the
computation of the final offset of the subview.
Put differently, we were computing the offset as:
```
offset = baseOffset + sum(subOffset#i * baseStrides#i * subStrides#i)
```
Whereas we should be doing:
```
offset = baseOffset + sum(subOffset#i * baseStrides#i)
```
I.e., drop the subStrides#i term from the sum.
Differential Revision: https://reviews.llvm.org/D136107
Added:
Modified:
mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
index 6aa68ae249635..6e861032d35bc 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
@@ -39,7 +39,7 @@ namespace {
/// baseBuffer, baseOffset, baseSizes, baseStrides =
/// extract_strided_metadata(memref)
/// strides#i = baseStrides#i * subSizes#i
-/// offset = baseOffset + sum(subOffset#i * strides#i)
+/// offset = baseOffset + sum(subOffset#i * baseStrides#i)
/// sizes = subSizes
/// \endverbatim
///
@@ -83,8 +83,8 @@ struct ExtractStridedMetadataOpSubviewFolder
auto origStrides = newExtractStridedMetadata.getStrides();
// Hold the affine symbols and values for the computation of the offset.
- SmallVector<OpFoldResult> values(3 * sourceRank + 1);
- SmallVector<AffineExpr> symbols(3 * sourceRank + 1);
+ SmallVector<OpFoldResult> values(2 * sourceRank + 1);
+ SmallVector<AffineExpr> symbols(2 * sourceRank + 1);
detail::bindSymbolsList(rewriter.getContext(), symbols);
AffineExpr expr = symbols.front();
@@ -105,14 +105,11 @@ struct ExtractStridedMetadataOpSubviewFolder
rewriter, origLoc, s0 * s1, {subStrides[i], origStride}));
// Build up the computation of the offset.
- unsigned baseIdxForDim = 1 + 3 * i;
+ unsigned baseIdxForDim = 1 + 2 * i;
unsigned subOffsetForDim = baseIdxForDim;
- unsigned subStrideForDim = baseIdxForDim + 1;
- unsigned origStrideForDim = baseIdxForDim + 2;
- expr = expr + symbols[subOffsetForDim] * symbols[subStrideForDim] *
- symbols[origStrideForDim];
+ unsigned origStrideForDim = baseIdxForDim + 1;
+ expr = expr + symbols[subOffsetForDim] * symbols[origStrideForDim];
values[subOffsetForDim] = subOffsets[i];
- values[subStrideForDim] = subStrides[i];
values[origStrideForDim] = origStride;
}
diff --git a/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir b/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir
index 338b52b5ac427..caa7efdcc6c3a 100644
--- a/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir
+++ b/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir
@@ -24,7 +24,7 @@ func.func @extract_strided_metadata_constants(%base: memref<5x4xf32, strided<[4,
// Check that we simplify extract_strided_metadata of subview to
// base_buf, base_offset, base_sizes, base_strides = extract_strided_metadata
// strides = base_stride_i * subview_stride_i
-// offset = base_offset + sum(subview_offsets_i * strides_i).
+// offset = base_offset + sum(subview_offsets_i * base_strides_i).
//
// This test also checks that we don't create useless arith operations
// when subview_offsets_i is 0.
@@ -42,8 +42,8 @@ func.func @extract_strided_metadata_constants(%base: memref<5x4xf32, strided<[4,
//
// Final offset is:
// origOffset + (== 0)
-// base_stride0 * subview_stride0 * subview_offset0 + (== 4 * 1 * 0 == 0)
-// base_stride1 * subview_stride1 * subview_offset1 (== 1 * 1 * 2)
+// base_stride0 * subview_offset0 + (== 4 * 0 == 0)
+// base_stride1 * subview_offset1 (== 1 * 2)
// == 2
//
// Return the new tuple.
@@ -171,14 +171,14 @@ func.func @extract_strided_metadata_of_rank_reduced_subview(%base: memref<8x16x4
//
// Orig offset: 0
// Sub offsets: [3, 4, 2]
-// => Final offset: 3 * 64 + 4 * 4 * %stride + 2 * 1 + 0 == 16 * %stride + 194
+// => Final offset: 3 * 64 + 4 * 4 + 2 * 1 + 0 == 210
//
// CHECK-DAG: #[[$STRIDE1_MAP:.*]] = affine_map<()[s0] -> (s0 * 4)>
-// CHECK-DAG: #[[$OFFSET_MAP:.*]] = affine_map<()[s0] -> (s0 * 16 + 194)>
// CHECK-LABEL: func @extract_strided_metadata_of_rank_reduced_subview_w_variable_strides
// CHECK-SAME: (%[[ARG:.*]]: memref<8x16x4xf32>,
// CHECK-SAME: %[[DYN_STRIDE:.*]]: index)
//
+// CHECK-DAG: %[[C210:.*]] = arith.constant 210 : index
// CHECK-DAG: %[[C6:.*]] = arith.constant 6 : index
// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
@@ -186,9 +186,8 @@ func.func @extract_strided_metadata_of_rank_reduced_subview(%base: memref<8x16x4
// CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:3, %[[STRIDES:.*]]:3 = memref.extract_strided_metadata %[[ARG]]
//
// CHECK-DAG: %[[DIM1_STRIDE:.*]] = affine.apply #[[$STRIDE1_MAP]]()[%[[DYN_STRIDE]]]
-// CHECK-DAG: %[[FINAL_OFFSET:.*]] = affine.apply #[[$OFFSET_MAP]]()[%[[DYN_STRIDE]]]
//
-// CHECK: return %[[BASE]], %[[FINAL_OFFSET]], %[[C6]], %[[C3]], %[[DIM1_STRIDE]], %[[C1]]
+// CHECK: return %[[BASE]], %[[C210]], %[[C6]], %[[C3]], %[[DIM1_STRIDE]], %[[C1]]
func.func @extract_strided_metadata_of_rank_reduced_subview_w_variable_strides(
%base: memref<8x16x4xf32>, %stride: index)
-> (memref<f32>, index, index, index, index, index) {
@@ -262,11 +261,11 @@ func.func @extract_strided_metadata_of_subview_w_variable_offset(
//
// Orig offset: origOff
// Sub offsets: [subO0, subO1, subO2]
-// => Final offset: s0 * subS0 * subO0 + ... + s2 * subS2 * subO2 + origOff
-// ==> 1 affine map with (rank * 3 + 1) symbols
+// => Final offset: s0 * subO0 + ... + s2 * subO2 + origOff
+// ==> 1 affine map with (rank * 2 + 1) symbols
//
// CHECK-DAG: #[[$STRIDE_MAP:.*]] = affine_map<()[s0, s1] -> (s0 * s1)>
-// CHECK-DAG: #[[$OFFSET_MAP:.*]] = affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s0 + (s1 * s2) * s3 + (s4 * s5) * s6 + (s7 * s8) * s9)>
+// CHECK-DAG: #[[$OFFSET_MAP:.*]] = affine_map<()[s0, s1, s2, s3, s4, s5, s6] -> (s0 + s1 * s2 + s3 * s4 + s5 * s6)>
// CHECK-LABEL: func @extract_strided_metadata_of_subview_all_dynamic
// CHECK-SAME: (%[[ARG:.*]]: memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>, %[[DYN_OFFSET0:.*]]: index, %[[DYN_OFFSET1:.*]]: index, %[[DYN_OFFSET2:.*]]: index, %[[DYN_SIZE0:.*]]: index, %[[DYN_SIZE1:.*]]: index, %[[DYN_SIZE2:.*]]: index, %[[DYN_STRIDE0:.*]]: index, %[[DYN_STRIDE1:.*]]: index, %[[DYN_STRIDE2:.*]]: index)
//
@@ -276,7 +275,7 @@ func.func @extract_strided_metadata_of_subview_w_variable_offset(
// CHECK-DAG: %[[FINAL_STRIDE1:.*]] = affine.apply #[[$STRIDE_MAP]]()[%[[DYN_STRIDE1]], %[[STRIDES]]#1]
// CHECK-DAG: %[[FINAL_STRIDE2:.*]] = affine.apply #[[$STRIDE_MAP]]()[%[[DYN_STRIDE2]], %[[STRIDES]]#2]
//
-// CHECK-DAG: %[[FINAL_OFFSET:.*]] = affine.apply #[[$OFFSET_MAP]]()[%[[OFFSET]], %[[DYN_OFFSET0]], %[[DYN_STRIDE0]], %[[STRIDES]]#0, %[[DYN_OFFSET1]], %[[DYN_STRIDE1]], %[[STRIDES]]#1, %[[DYN_OFFSET2]], %[[DYN_STRIDE2]], %[[STRIDES]]#2]
+// CHECK-DAG: %[[FINAL_OFFSET:.*]] = affine.apply #[[$OFFSET_MAP]]()[%[[OFFSET]], %[[DYN_OFFSET0]], %[[STRIDES]]#0, %[[DYN_OFFSET1]], %[[STRIDES]]#1, %[[DYN_OFFSET2]], %[[STRIDES]]#2]
//
// CHECK: return %[[BASE]], %[[FINAL_OFFSET]], %[[DYN_SIZE0]], %[[DYN_SIZE1]], %[[DYN_SIZE2]], %[[FINAL_STRIDE0]], %[[FINAL_STRIDE1]], %[[FINAL_STRIDE2]]
func.func @extract_strided_metadata_of_subview_all_dynamic(
More information about the Mlir-commits
mailing list