[Mlir-commits] [mlir] [mlir][Vector] Fix bug in vector xfer op flattening transformation (PR #81964)
Diego Caballero
llvmlistbot at llvm.org
Wed Feb 21 09:25:01 PST 2024
https://github.com/dcaballe updated https://github.com/llvm/llvm-project/pull/81964
>From e6f82891e9f661c7673d526889a36ae6c711a549 Mon Sep 17 00:00:00 2001
From: Diego Caballero <diegocaballero at google.com>
Date: Fri, 16 Feb 2024 03:17:02 +0000
Subject: [PATCH 1/3] [mlir][Vector] Fix bug in vector xfer op flattening
The affine map generated to compute the index into the collapsed dimension
scaled each index by the wrong dimension size. Each index must be scaled by the
product of the sizes of the collapsed dimensions that follow it. For indices
`[idx0][idx1]` we computed the collapsed index as `idx0*size0 + idx1` instead of
`idx0*size1 + idx1`.
This led to correctness issues in convolution tests when enabling this
transformation internally.
---
.../Transforms/VectorTransferOpTransforms.cpp | 8 +++--
.../Vector/vector-transfer-flatten.mlir | 34 +++++++++++++++++--
2 files changed, 37 insertions(+), 5 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index b761d1ed888973..5f150be0dd8cb6 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -615,10 +615,14 @@ class FlattenContiguousRowMajorTransferReadPattern
OpFoldResult offset =
rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult();
+ auto srcType = dyn_cast<ShapedType>(source.getType());
for (int64_t i = firstDimToCollapse; i < outputRank; ++i) {
- int64_t dim = dyn_cast<ShapedType>(source.getType()).getDimSize(i);
+ // Multiply each index by the size of the next dimension. The last
+ // dimension (contiguous) is multiplied by one.
+ int64_t nextDimSize =
+ (i == outputRank - 1) ? 1 : srcType.getDimSize(i + 1);
offset = affine::makeComposedFoldedAffineApply(
- rewriter, loc, offsetExpr + dim * idxExpr,
+ rewriter, loc, offsetExpr + nextDimSize * idxExpr,
{offset, transferReadOp.getIndices()[i]});
}
if (offset.is<Value>()) {
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index 9976048a3320b6..3025d22eef3623 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -66,14 +66,14 @@ func.func @transfer_read_dims_mismatch_non_zero_indices(
%m_out: memref<1x2x6xi32>) {
%c0 = arith.constant 0 : index
%c0_i32 = arith.constant 0 : i32
- %2 = vector.transfer_read %m_in[%c0, %idx_1, %idx_2, %c0], %c0_i32 {in_bounds = [true, true, true]} :
+ %2 = vector.transfer_read %m_in[%c0, %idx_1, %idx_2, %c0], %c0_i32 {in_bounds = [true, true, true]} :
memref<1x43x4x6xi32>, vector<1x2x6xi32>
vector.transfer_write %2, %m_out[%c0, %c0, %c0] {in_bounds = [true, true, true]} :
vector<1x2x6xi32>, memref<1x2x6xi32>
return
}
-// CHECK: #[[$ATTR_0:.+]] = affine_map<()[s0, s1] -> (s0 * 4 + s1 * 43)>
+// CHECK: #[[$ATTR_0:.+]] = affine_map<()[s0, s1] -> (s0 * 6 + s1 * 4)>
// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_zero_indices(
// CHECK-SAME: %[[IDX_1:.*]]: index, %[[IDX_2:.*]]: index,
@@ -99,7 +99,7 @@ func.func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes(
%m_out: memref<1x2x6xi32>) {
%c0 = arith.constant 0 : index
%c0_i32 = arith.constant 0 : i32
- %2 = vector.transfer_read %m_in[%c0, %idx_1, %idx_2, %c0], %c0_i32 {in_bounds = [true, true, true]} :
+ %2 = vector.transfer_read %m_in[%c0, %idx_1, %idx_2, %c0], %c0_i32 {in_bounds = [true, true, true]} :
memref<1x?x4x6xi32>, vector<1x2x6xi32>
vector.transfer_write %2, %m_out[%c0, %c0, %c0] {in_bounds = [true, true, true]} :
vector<1x2x6xi32>, memref<1x2x6xi32>
@@ -389,3 +389,31 @@ func.func @fold_unit_dims_entirely(%arg0 : vector<8xi32>,
// CHECK: %[[VAL_3:.*]] = arith.muli %[[VAL_0]], %[[VAL_1]] : vector<8xi32>
// CHECK: %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_2]] : vector<8xi32>
// CHECK: return %[[VAL_4]] : vector<8xi32>
+
+// -----
+
+func.func @regression_non_contiguous_dim_read(%subview : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>,
+ %idx0 : index, %idx1 : index) -> vector<2x2xf32> {
+ %c0 = arith.constant 0 : index
+ %cst_1 = arith.constant 0.000000e+00 : f32
+ %8 = vector.transfer_read %subview[%c0, %idx0, %idx1, %c0], %cst_1 {in_bounds = [true, true]} : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>, vector<2x2xf32>
+ return %8 : vector<2x2xf32>
+}
+
+// CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 2)>
+// CHECK-LABEL: func.func @regression_non_contiguous_dim_read(
+// CHECK: %[[COLLAPSE:.+]] = memref.collapse_shape %{{.*}} {{\[}}[0], [1], [2, 3]] : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>> into memref<1x3x6xf32, strided<[40, 10, 1], offset: ?>>
+// CHECK: %[[APPLY:.*]] = affine.apply #[[$MAP]]()
+
+// -----
+
+func.func @unsupported_non_contiguous_dim_write(%value : vector<2x2xf32>,
+ %subview : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>,
+ %idx0 : index, %idx1 : index) {
+ %c0 = arith.constant 0 : index
+ vector.transfer_write %value, %subview[%c0, %idx0, %idx1, %c0] {in_bounds = [true, true]} : vector<2x2xf32>, memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>
+ return
+}
+
+// CHECK-LABEL: func.func @unsupported_non_contiguous_dim_write(
+// CHECK-NOT: memref.collapse_shape
>From ae87cfd47c47a8fbb9da34212ce28e51457b7c2f Mon Sep 17 00:00:00 2001
From: Diego Caballero <diegocaballero at google.com>
Date: Fri, 16 Feb 2024 19:56:42 +0000
Subject: [PATCH 2/3] Use index utils.
---
.../mlir/Dialect/Utils/IndexingUtils.h | 3 ++
mlir/lib/Dialect/Utils/IndexingUtils.cpp | 11 ++++-
.../Transforms/VectorTransferOpTransforms.cpp | 45 ++++++++++---------
.../Vector/vector-transfer-flatten.mlir | 4 +-
4 files changed, 37 insertions(+), 26 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
index 2453d841f633e4..9892253df2bff1 100644
--- a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
@@ -257,6 +257,9 @@ SmallVector<int64_t> getI64SubArray(ArrayAttr arrayAttr, unsigned dropFront = 0,
std::pair<AffineExpr, SmallVector<OpFoldResult>>
computeLinearIndex(OpFoldResult sourceOffset, ArrayRef<OpFoldResult> strides,
ArrayRef<OpFoldResult> indices);
+std::pair<AffineExpr, SmallVector<OpFoldResult>>
+computeLinearIndex(OpFoldResult sourceOffset, ArrayRef<int64_t> strides,
+ ArrayRef<Value> indices);
//===----------------------------------------------------------------------===//
// Utilities for decomposing larger shapes
diff --git a/mlir/lib/Dialect/Utils/IndexingUtils.cpp b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
index 2765d1eb1000da..d31591b2d0435a 100644
--- a/mlir/lib/Dialect/Utils/IndexingUtils.cpp
+++ b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
@@ -7,13 +7,12 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Utils/IndexingUtils.h"
-
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/MLIRContext.h"
#include "llvm/ADT/STLExtras.h"
-
#include <numeric>
#include <optional>
@@ -307,6 +306,14 @@ mlir::computeLinearIndex(OpFoldResult sourceOffset,
return {expr, values};
}
+std::pair<AffineExpr, SmallVector<OpFoldResult>>
+mlir::computeLinearIndex(OpFoldResult sourceOffset, ArrayRef<int64_t> strides,
+ ArrayRef<Value> indices) {
+ return computeLinearIndex(
+ sourceOffset, getAsIndexOpFoldResult(sourceOffset.getContext(), strides),
+ getAsOpFoldResult(ValueRange(indices)));
+}
+
//===----------------------------------------------------------------------===//
// TileOffsetRange
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 5f150be0dd8cb6..b6e52088bfa490 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -15,11 +15,11 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
-#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "llvm/ADT/STLExtras.h"
@@ -564,7 +564,6 @@ class FlattenContiguousRowMajorTransferReadPattern
if (transferReadOp.getMask())
return failure();
- SmallVector<Value> collapsedIndices;
int64_t firstDimToCollapse = sourceType.getRank() - vectorType.getRank();
// 1. Collapse the source memref
@@ -586,12 +585,14 @@ class FlattenContiguousRowMajorTransferReadPattern
// 2.2 New indices
// If all the collapsed indices are zero then no extra logic is needed.
// Otherwise, a new offset/index has to be computed.
+ SmallVector<Value> collapsedIndices;
if (failed(checkAndCollapseInnerZeroIndices(transferReadOp.getIndices(),
firstDimToCollapse,
collapsedIndices))) {
- // Copy all the leading indices
- collapsedIndices = transferReadOp.getIndices();
- collapsedIndices.resize(firstDimToCollapse);
+ // Copy all the leading indices.
+ SmallVector<Value> indices = transferReadOp.getIndices();
+ collapsedIndices.append(indices.begin(),
+ indices.begin() + firstDimToCollapse);
// Compute the remaining trailing index/offset required for reading from
// the collapsed memref:
@@ -608,28 +609,27 @@ class FlattenContiguousRowMajorTransferReadPattern
// memref<1x86xi32>, vector<2xi32>
// one would get the following offset:
// %offset = %arg0 * 43
- AffineExpr offsetExpr, idxExpr;
- bindSymbols(rewriter.getContext(), offsetExpr, idxExpr);
-
int64_t outputRank = transferReadOp.getIndices().size();
- OpFoldResult offset =
+ OpFoldResult collapsedOffset =
rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult();
- auto srcType = dyn_cast<ShapedType>(source.getType());
- for (int64_t i = firstDimToCollapse; i < outputRank; ++i) {
- // Multiply each index by the size of the next dimension. The last
- // dimension (contiguous) is multiplied by one.
- int64_t nextDimSize =
- (i == outputRank - 1) ? 1 : srcType.getDimSize(i + 1);
- offset = affine::makeComposedFoldedAffineApply(
- rewriter, loc, offsetExpr + nextDimSize * idxExpr,
- {offset, transferReadOp.getIndices()[i]});
- }
- if (offset.is<Value>()) {
- collapsedIndices.push_back(offset.get<Value>());
+ auto sourceShape = sourceType.getShape();
+ auto collapsedStrides = computeSuffixProduct(ArrayRef<int64_t>(
+ sourceShape.begin() + firstDimToCollapse, sourceShape.end()));
+
+ // Compute the collapsed offset.
+ ArrayRef<Value> indicesToCollapse(indices.begin() + firstDimToCollapse,
+ indices.end());
+ auto &&[collapsedExpr, collapsedVals] = computeLinearIndex(
+ collapsedOffset, collapsedStrides, indicesToCollapse);
+ collapsedOffset = affine::makeComposedFoldedAffineApply(
+ rewriter, loc, collapsedExpr, collapsedVals);
+
+ if (collapsedOffset.is<Value>()) {
+ collapsedIndices.push_back(collapsedOffset.get<Value>());
} else {
collapsedIndices.push_back(rewriter.create<arith::ConstantIndexOp>(
- loc, *getConstantIntValue(offset)));
+ loc, *getConstantIntValue(collapsedOffset)));
}
}
@@ -685,6 +685,7 @@ class FlattenContiguousRowMajorTransferWritePattern
firstContiguousInnerDim,
collapsedIndices)))
return failure();
+
Value collapsedSource =
collapseInnerDims(rewriter, loc, source, firstContiguousInnerDim);
MemRefType collapsedSourceType =
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index 3025d22eef3623..317fc33ea55ceb 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -73,7 +73,7 @@ func.func @transfer_read_dims_mismatch_non_zero_indices(
return
}
-// CHECK: #[[$ATTR_0:.+]] = affine_map<()[s0, s1] -> (s0 * 6 + s1 * 4)>
+// CHECK: #[[$ATTR_0:.+]] = affine_map<()[s0, s1] -> (s0 * 24 + s1 * 6)>
// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_zero_indices(
// CHECK-SAME: %[[IDX_1:.*]]: index, %[[IDX_2:.*]]: index,
@@ -82,7 +82,7 @@ func.func @transfer_read_dims_mismatch_non_zero_indices(
// CHECK: %[[C_0:.*]] = arith.constant 0 : i32
// CHECK: %[[C_0_IDX:.*]] = arith.constant 0 : index
// CHECK: %[[COLLAPSED_IN:.*]] = memref.collapse_shape %[[M_IN]] {{\[}}[0], [1, 2, 3]] : memref<1x43x4x6xi32> into memref<1x1032xi32>
-// CHECK: %[[COLLAPSED_IDX:.*]] = affine.apply #[[$ATTR_0]]()[%[[IDX_2]], %[[IDX_1]]]
+// CHECK: %[[COLLAPSED_IDX:.*]] = affine.apply #[[$ATTR_0]]()[%[[IDX_1]], %[[IDX_2]]]
// CHECK: %[[READ:.*]] = vector.transfer_read %[[COLLAPSED_IN]][%[[C_0_IDX]], %[[COLLAPSED_IDX]]], %[[C_0]] {in_bounds = [true]} : memref<1x1032xi32>, vector<12xi32>
// CHECK: %[[COLLAPSED_OUT:.*]] = memref.collapse_shape %[[M_OUT]] {{\[}}[0, 1, 2]] : memref<1x2x6xi32> into memref<12xi32>
// CHECK: vector.transfer_write %[[READ]], %[[COLLAPSED_OUT]][%[[C_0_IDX]]] {in_bounds = [true]} : vector<12xi32>, memref<12xi32>
>From 78b7c75698a8ee0b3636b2fd1cfdfba3f6bcf007 Mon Sep 17 00:00:00 2001
From: Diego Caballero <diegocaballero at google.com>
Date: Wed, 21 Feb 2024 01:57:21 +0000
Subject: [PATCH 3/3] Remove unused variable
---
.../lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp | 1 -
1 file changed, 1 deletion(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index b6e52088bfa490..02972359149fd6 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -609,7 +609,6 @@ class FlattenContiguousRowMajorTransferReadPattern
// memref<1x86xi32>, vector<2xi32>
// one would get the following offset:
// %offset = %arg0 * 43
- int64_t outputRank = transferReadOp.getIndices().size();
OpFoldResult collapsedOffset =
rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult();
More information about the Mlir-commits
mailing list