[Mlir-commits] [mlir] 83df43f - [mlir] use strided layouts in vector transfer on memrefs
Alex Zinenko
llvmlistbot at llvm.org
Fri Sep 16 23:11:42 PDT 2022
Author: Alex Zinenko
Date: 2022-09-17T08:11:30+02:00
New Revision: 83df43f3a204c529167caeedc26edae5faebbd31
URL: https://github.com/llvm/llvm-project/commit/83df43f3a204c529167caeedc26edae5faebbd31
DIFF: https://github.com/llvm/llvm-project/commit/83df43f3a204c529167caeedc26edae5faebbd31.diff
LOG: [mlir] use strided layouts in vector transfer on memrefs
One of the vector transformation patterns has been indiscriminately
converting layouts to affine maps. Leverage the strided form when
possible.
Reviewed By: nicolasvasilache, dcaballe
Differential Revision: https://reviews.llvm.org/D134047
Added:
Modified:
mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index fded10629b307..18f6c5a154e55 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -24,6 +24,7 @@
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Matchers.h"
@@ -2631,22 +2632,29 @@ class DropInnerMostUnitDims : public OpRewritePattern<vector::TransferReadOp> {
targetType.getElementType());
MemRefType resultMemrefType;
- if (srcType.getLayout().getAffineMap().isIdentity()) {
+ MemRefLayoutAttrInterface layout = srcType.getLayout();
+ if (layout.isa<AffineMapAttr>() && layout.isIdentity()) {
resultMemrefType = MemRefType::get(
srcType.getShape().drop_back(dimsToDrop), srcType.getElementType(),
- {}, srcType.getMemorySpaceAsInt());
+ nullptr, srcType.getMemorySpace());
} else {
- AffineMap map = srcType.getLayout().getAffineMap();
- int numSymbols = map.getNumSymbols();
- for (size_t i = 0; i < dimsToDrop; ++i) {
- int dim = srcType.getRank() - i - 1;
- map = map.replace(rewriter.getAffineDimExpr(dim),
- rewriter.getAffineConstantExpr(0),
- map.getNumDims() - 1, numSymbols);
+ MemRefLayoutAttrInterface updatedLayout;
+ if (auto strided = layout.dyn_cast<StridedLayoutAttr>()) {
+ auto strides = llvm::to_vector(strided.getStrides().drop_back(dimsToDrop));
+ updatedLayout = StridedLayoutAttr::get(strided.getContext(), strided.getOffset(), strides);
+ } else {
+ AffineMap map = srcType.getLayout().getAffineMap();
+ int numSymbols = map.getNumSymbols();
+ for (size_t i = 0; i < dimsToDrop; ++i) {
+ int dim = srcType.getRank() - i - 1;
+ map = map.replace(rewriter.getAffineDimExpr(dim),
+ rewriter.getAffineConstantExpr(0),
+ map.getNumDims() - 1, numSymbols);
+ }
}
resultMemrefType = MemRefType::get(
srcType.getShape().drop_back(dimsToDrop), srcType.getElementType(),
- map, srcType.getMemorySpaceAsInt());
+ updatedLayout, srcType.getMemorySpace());
}
auto loc = readOp.getLoc();
diff --git a/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir b/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
index 25f38452e3b4a..ef0bd9ddf8abf 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
@@ -8,9 +8,9 @@ func.func @contiguous_inner_most_view(%in: memref<1x1x8x1xf32, strided<[3072, 8,
}
// CHECK: func @contiguous_inner_most_view(%[[SRC:.+]]: memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>>
// CHECK: %[[SRC_0:.+]] = memref.subview %[[SRC]]
-// CHECK-SAME: memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>> to memref<1x1x8xf32
+// CHECK-SAME: memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>> to memref<1x1x8xf32, strided<[3072, 8, 1], offset: ?>>
// CHECK: %[[VEC:.+]] = vector.transfer_read %[[SRC_0]]
-// CHECK-SAME: memref<1x1x8xf32, {{.*}}>, vector<1x8xf32>
+// CHECK-SAME: memref<1x1x8xf32, strided<[3072, 8, 1], offset: ?>>, vector<1x8xf32>
// CHECK: %[[RESULT:.+]] = vector.shape_cast %[[VEC]]
// CHECK: return %[[RESULT]]
@@ -34,8 +34,8 @@ func.func @contiguous_inner_most_dim(%A: memref<16x1xf32>, %i:index, %j:index) -
func.func @contiguous_inner_most_dim_bounds(%A: memref<1000x1xf32>, %i:index, %ii:index) -> (vector<4x1xf32>) {
%c0 = arith.constant 0 : index
%cst = arith.constant 0.0 : f32
- %0 = memref.subview %A[%i, 0] [40, 1] [1, 1] : memref<1000x1xf32> to memref<40x1xf32, affine_map<(d0, d1)[s0] -> (d0 + s0 + d1)>>
- %1 = vector.transfer_read %0[%ii, %c0], %cst {in_bounds = [true, true]} : memref<40x1xf32, affine_map<(d0, d1)[s0] -> (d0 + s0 + d1)>>, vector<4x1xf32>
+ %0 = memref.subview %A[%i, 0] [40, 1] [1, 1] : memref<1000x1xf32> to memref<40x1xf32, strided<[1, 1], offset: ?>>
+ %1 = vector.transfer_read %0[%ii, %c0], %cst {in_bounds = [true, true]} : memref<40x1xf32, strided<[1, 1], offset: ?>>, vector<4x1xf32>
return %1 : vector<4x1xf32>
}
// CHECK: func @contiguous_inner_most_dim_bounds(%[[SRC:.+]]: memref<1000x1xf32>, %[[II:.+]]: index, %[[J:.+]]: index) -> vector<4x1xf32>
@@ -50,8 +50,8 @@ func.func @contiguous_inner_most_dim_bounds(%A: memref<1000x1xf32>, %i:index, %i
func.func @contiguous_inner_most_dim_bounds_2d(%A: memref<1000x1x1xf32>, %i:index, %ii:index) -> (vector<4x1x1xf32>) {
%c0 = arith.constant 0 : index
%cst = arith.constant 0.0 : f32
- %0 = memref.subview %A[%i, 0, 0] [40, 1, 1] [1, 1, 1] : memref<1000x1x1xf32> to memref<40x1x1xf32, affine_map<(d0, d1, d2)[s0] -> (d0 + s0 + d1 + d2)>>
- %1 = vector.transfer_read %0[%ii, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<40x1x1xf32, affine_map<(d0, d1, d2)[s0] -> (d0 + s0 + d1 + d2)>>, vector<4x1x1xf32>
+ %0 = memref.subview %A[%i, 0, 0] [40, 1, 1] [1, 1, 1] : memref<1000x1x1xf32> to memref<40x1x1xf32, strided<[1, 1, 1], offset: ?>>
+ %1 = vector.transfer_read %0[%ii, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<40x1x1xf32, strided<[1, 1, 1], offset: ?>>, vector<4x1x1xf32>
return %1 : vector<4x1x1xf32>
}
// CHECK: func @contiguous_inner_most_dim_bounds_2d(%[[SRC:.+]]: memref<1000x1x1xf32>, %[[II:.+]]: index, %[[J:.+]]: index) -> vector<4x1x1xf32>
More information about the Mlir-commits mailing list