[Mlir-commits] [mlir] [mlir][Vector] Fix bug in vector xfer op flattening transformation (PR #81964)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Feb 15 19:24:58 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Diego Caballero (dcaballe)
<details>
<summary>Changes</summary>
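The affine map generated to compute the index into the collapsed dimensions used the wrong dimension size. For indices `[idx0][idx1]`, where `size0` and `size1` are the sizes of the corresponding source dimensions, we computed the collapsed index as `idx0*size0 + idx1` instead of `idx0*size1 + idx1`; that is, each index was scaled by the size of its own dimension rather than by the size of the next one. This led to correctness issues in convolution tests when this transformation was enabled internally.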
---
Full diff: https://github.com/llvm/llvm-project/pull/81964.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp (+6-2)
- (modified) mlir/test/Dialect/Vector/vector-transfer-flatten.mlir (+31-3)
``````````diff
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index b761d1ed888973..5f150be0dd8cb6 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -615,10 +615,14 @@ class FlattenContiguousRowMajorTransferReadPattern
OpFoldResult offset =
rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult();
+ auto srcType = dyn_cast<ShapedType>(source.getType());
for (int64_t i = firstDimToCollapse; i < outputRank; ++i) {
- int64_t dim = dyn_cast<ShapedType>(source.getType()).getDimSize(i);
+ // Multiply each index by the size of the next dimension. The last
+ // dimension (contiguous) is multiplied by one.
+ int64_t nextDimSize =
+ (i == outputRank - 1) ? 1 : srcType.getDimSize(i + 1);
offset = affine::makeComposedFoldedAffineApply(
- rewriter, loc, offsetExpr + dim * idxExpr,
+ rewriter, loc, offsetExpr + nextDimSize * idxExpr,
{offset, transferReadOp.getIndices()[i]});
}
if (offset.is<Value>()) {
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index 9976048a3320b6..3025d22eef3623 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -66,14 +66,14 @@ func.func @transfer_read_dims_mismatch_non_zero_indices(
%m_out: memref<1x2x6xi32>) {
%c0 = arith.constant 0 : index
%c0_i32 = arith.constant 0 : i32
- %2 = vector.transfer_read %m_in[%c0, %idx_1, %idx_2, %c0], %c0_i32 {in_bounds = [true, true, true]} :
+ %2 = vector.transfer_read %m_in[%c0, %idx_1, %idx_2, %c0], %c0_i32 {in_bounds = [true, true, true]} :
memref<1x43x4x6xi32>, vector<1x2x6xi32>
vector.transfer_write %2, %m_out[%c0, %c0, %c0] {in_bounds = [true, true, true]} :
vector<1x2x6xi32>, memref<1x2x6xi32>
return
}
-// CHECK: #[[$ATTR_0:.+]] = affine_map<()[s0, s1] -> (s0 * 4 + s1 * 43)>
+// CHECK: #[[$ATTR_0:.+]] = affine_map<()[s0, s1] -> (s0 * 6 + s1 * 4)>
// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_zero_indices(
// CHECK-SAME: %[[IDX_1:.*]]: index, %[[IDX_2:.*]]: index,
@@ -99,7 +99,7 @@ func.func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes(
%m_out: memref<1x2x6xi32>) {
%c0 = arith.constant 0 : index
%c0_i32 = arith.constant 0 : i32
- %2 = vector.transfer_read %m_in[%c0, %idx_1, %idx_2, %c0], %c0_i32 {in_bounds = [true, true, true]} :
+ %2 = vector.transfer_read %m_in[%c0, %idx_1, %idx_2, %c0], %c0_i32 {in_bounds = [true, true, true]} :
memref<1x?x4x6xi32>, vector<1x2x6xi32>
vector.transfer_write %2, %m_out[%c0, %c0, %c0] {in_bounds = [true, true, true]} :
vector<1x2x6xi32>, memref<1x2x6xi32>
@@ -389,3 +389,31 @@ func.func @fold_unit_dims_entirely(%arg0 : vector<8xi32>,
// CHECK: %[[VAL_3:.*]] = arith.muli %[[VAL_0]], %[[VAL_1]] : vector<8xi32>
// CHECK: %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_2]] : vector<8xi32>
// CHECK: return %[[VAL_4]] : vector<8xi32>
+
+// -----
+
+func.func @regression_non_contiguous_dim_read(%subview : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>,
+ %idx0 : index, %idx1 : index) -> vector<2x2xf32> {
+ %c0 = arith.constant 0 : index
+ %cst_1 = arith.constant 0.000000e+00 : f32
+ %8 = vector.transfer_read %subview[%c0, %idx0, %idx1, %c0], %cst_1 {in_bounds = [true, true]} : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>, vector<2x2xf32>
+ return %8 : vector<2x2xf32>
+}
+
+// CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 2)>
+// CHECK-LABEL: func.func @regression_non_contiguous_dim_read(
+// CHECK: %[[COLLAPSE:.+]] = memref.collapse_shape %{{.*}} {{\[}}[0], [1], [2, 3]] : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>> into memref<1x3x6xf32, strided<[40, 10, 1], offset: ?>>
+// CHECK: %[[APPLY:.*]] = affine.apply #[[$MAP]]()
+
+// -----
+
+func.func @unsupported_non_contiguous_dim_write(%value : vector<2x2xf32>,
+ %subview : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>,
+ %idx0 : index, %idx1 : index) {
+ %c0 = arith.constant 0 : index
+ vector.transfer_write %value, %subview[%c0, %idx0, %idx1, %c0] {in_bounds = [true, true]} : vector<2x2xf32>, memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>
+ return
+}
+
+// CHECK-LABEL: func.func @unsupported_non_contiguous_dim_write(
+// CHECK-NOT: memref.collapse_shape
``````````
</details>
https://github.com/llvm/llvm-project/pull/81964
More information about the Mlir-commits mailing list