[Mlir-commits] [mlir] a3dd4e7 - Drop transfer_read inner most unit dimensions
Ahmed S. Taei
llvmlistbot at llvm.org
Wed Oct 20 12:27:10 PDT 2021
Author: Ahmed S. Taei
Date: 2021-10-20T19:27:04Z
New Revision: a3dd4e777095f9668215a3babab1041025819f64
URL: https://github.com/llvm/llvm-project/commit/a3dd4e777095f9668215a3babab1041025819f64
DIFF: https://github.com/llvm/llvm-project/commit/a3dd4e777095f9668215a3babab1041025819f64.diff
LOG: Drop transfer_read inner most unit dimensions
Add a pattern to take a rank-reducing subview and drop inner most
contiguous unit dim.
This is useful when lowering vector to backends with 1d vector types.
Reviewed By: ThomasRaoux
Differential Revision: https://reviews.llvm.org/D111561
Added:
mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
Modified:
mlir/include/mlir/Dialect/Vector/VectorTransforms.h
mlir/lib/Dialect/Vector/VectorTransforms.cpp
mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
index d51ba185e86e0..59e6ac07bbca3 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
@@ -97,6 +97,13 @@ struct UnrollVectorOptions {
void populateVectorUnrollPatterns(RewritePatternSet &patterns,
const UnrollVectorOptions &options);
+/// Collect a set of patterns to reduce the rank of the operands of vector
+/// transfer ops to operate on the largest contiguous vector.
+/// These patterns are useful when lowering to dialects with 1d vector type
+/// such as llvm and it will result in fewer memory reads.
+void populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
+ RewritePatternSet &patterns);
+
/// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds
/// masking) fastpath and a slowpath.
/// If `ifOp` is not null and the result is `success, the `ifOp` points to the
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 7f73e367356bc..51ef008a2fa21 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -3529,6 +3529,80 @@ class VectorCreateMaskOpConversion
const bool enableIndexOptimizations;
};
+/// Rewrites a `vector.transfer_read` whose statically shaped memref source has
+/// trailing contiguous unit dimensions into a read of lower rank.
+///
+/// The rewrite emits three ops:
+///   1. a rank-reducing `memref.subview` that drops the trailing unit dims,
+///   2. a `vector.transfer_read` of the correspondingly smaller vector type
+///      that reuses the original padding value and `in_bounds` flags,
+///   3. a `vector.shape_cast` back to the original vector type.
+///
+/// The pattern does not apply when the source is not a statically shaped
+/// memref, the read is masked, the permutation map is not a minor identity,
+/// the vector has rank <= 1, or no trailing dimension can be dropped.
+class DropInnerMostUnitDims : public OpRewritePattern<vector::TransferReadOp> {
+  using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
+                                PatternRewriter &rewriter) const override {
+    // `cast` asserts on a type mismatch, so a tensor source would abort before
+    // reaching the null check; `dyn_cast` makes that check meaningful.
+    auto srcType = readOp.source().getType().dyn_cast<MemRefType>();
+    if (!srcType || !srcType.hasStaticShape())
+      return failure();
+
+    // The reduced read built below carries no mask, so rewriting a masked
+    // read would change which elements are read.
+    if (readOp.mask())
+      return failure();
+
+    if (!readOp.permutation_map().isMinorIdentity())
+      return failure();
+
+    auto targetType = readOp.getVectorType();
+    if (targetType.getRank() <= 1)
+      return failure();
+
+    SmallVector<int64_t> srcStrides;
+    int64_t srcOffset;
+    if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset)))
+      return failure();
+
+    // Count the trailing dimensions that can be dropped. Source dimension
+    // `srcRank - i` is droppable when the stride of the next-outer dimension
+    // is 1, the dimension itself has size 1 (needed by the rank-reducing
+    // subview) and the matching vector dimension has size 1 (needed by the
+    // final shape_cast). The `i < vecRank` bound keeps at least one vector
+    // dimension so the reduced vector type never has an empty shape.
+    const int64_t srcRank = srcType.getRank();
+    const int64_t vecRank = targetType.getRank();
+    int64_t dimsToDrop = 0;
+    for (int64_t i = 1; i < srcRank && i < vecRank; ++i) {
+      if (srcStrides[srcRank - i - 1] != 1 ||
+          srcType.getDimSize(srcRank - i) != 1 ||
+          targetType.getDimSize(vecRank - i) != 1)
+        break;
+      ++dimsToDrop;
+    }
+    if (dimsToDrop == 0)
+      return failure();
+
+    auto resultTargetVecType =
+        VectorType::get(targetType.getShape().drop_back(dimsToDrop),
+                        targetType.getElementType());
+
+    MemRefType resultMemrefType;
+    if (srcType.getLayout().getAffineMap().isIdentity()) {
+      resultMemrefType = MemRefType::get(
+          srcType.getShape().drop_back(dimsToDrop), srcType.getElementType(),
+          {}, srcType.getMemorySpaceAsInt());
+    } else {
+      // Substitute 0 for every dropped dimension in the layout map so the map
+      // only refers to the dimensions that remain.
+      AffineMap map = srcType.getLayout().getAffineMap();
+      int numResultDims = map.getNumDims() - dimsToDrop;
+      int numSymbols = map.getNumSymbols();
+      for (int64_t i = 0; i < dimsToDrop; ++i) {
+        int dim = srcRank - i - 1;
+        map = map.replace(rewriter.getAffineDimExpr(dim),
+                          rewriter.getAffineConstantExpr(0), numResultDims,
+                          numSymbols);
+      }
+      resultMemrefType = MemRefType::get(
+          srcType.getShape().drop_back(dimsToDrop), srcType.getElementType(),
+          map, srcType.getMemorySpaceAsInt());
+    }
+
+    // Keep the in_bounds flags of the vector dimensions that survive; without
+    // this the reduced read would lose the in-bounds guarantee of the original.
+    ArrayAttr inBoundsAttr =
+        readOp.in_bounds()
+            ? rewriter.getArrayAttr(
+                  readOp.in_boundsAttr().getValue().drop_back(dimsToDrop))
+            : ArrayAttr();
+
+    auto loc = readOp.getLoc();
+    SmallVector<int64_t> offsets(srcRank, 0);
+    SmallVector<int64_t> strides(srcRank, 1);
+    Value rankedReducedView = rewriter.create<memref::SubViewOp>(
+        loc, resultMemrefType, readOp.source(), offsets, srcType.getShape(),
+        strides);
+    AffineMap permMap = getTransferMinorIdentityMap(
+        rankedReducedView.getType().cast<ShapedType>(), resultTargetVecType);
+    // Forward the original padding value instead of letting the builder
+    // synthesize a new one.
+    Value result = rewriter.create<vector::TransferReadOp>(
+        loc, resultTargetVecType, rankedReducedView,
+        readOp.indices().drop_back(dimsToDrop), permMap, readOp.padding(),
+        inBoundsAttr);
+    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(readOp, targetType,
+                                                     result);
+
+    return success();
+  }
+};
+
void mlir::vector::populateVectorMaskMaterializationPatterns(
RewritePatternSet &patterns, bool enableIndexOptimizations) {
patterns.add<VectorCreateMaskOpConversion,
@@ -3617,3 +3691,9 @@ void mlir::vector::populateVectorUnrollPatterns(
UnrollContractionPattern, UnrollElementwisePattern>(
patterns.getContext(), options);
}
+
+// Adds DropInnerMostUnitDims to `patterns`, constructed with the context that
+// `patterns` itself reports.
+void mlir::vector::
+    populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
+        RewritePatternSet &patterns) {
+  MLIRContext *context = patterns.getContext();
+  patterns.add<DropInnerMostUnitDims>(context);
+}
diff --git a/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir b/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
new file mode 100644
index 0000000000000..6ebfeebd81314
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
@@ -0,0 +1,33 @@
+// RUN: mlir-opt %s -test-vector-transfer-collapse-inner-most-dims -split-input-file | FileCheck %s
+
+// The source memref has one trailing unit dimension (d3). The CHECK lines
+// expect a rank-reducing memref.subview to memref<1x1x8xf32>, a transfer_read
+// of vector<1x8xf32> from that view, and a vector.shape_cast whose result is
+// returned.
+#map1 = affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 3072 + s0 + d1 * 8 + d2 + d3)>
+func @contiguous_inner_most_view(%in: memref<1x1x8x1xf32, #map1>) -> vector<1x8x1xf32>{
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.0 : f32
+  %0 = vector.transfer_read %in[%c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<1x1x8x1xf32, #map1>, vector<1x8x1xf32>
+  return %0 : vector<1x8x1xf32>
+}
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 3072 + s0 + d1 * 8 + d2 + d3)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2)[s0] -> (d0 * 3072 + s0 + d1 * 8 + d2)>
+//      CHECK: func @contiguous_inner_most_view(%[[SRC:.+]]: memref<1x1x8x1xf32, #[[MAP0]]>
+//      CHECK:   %[[SRC_0:.+]] = memref.subview %[[SRC]]
+// CHECK-SAME:    memref<1x1x8x1xf32, #[[MAP0]]> to memref<1x1x8xf32, #[[MAP1]]>
+//      CHECK:   %[[VEC:.+]] = vector.transfer_read %[[SRC_0]]
+// CHECK-SAME:    memref<1x1x8xf32, #[[MAP1]]>, vector<1x8xf32>
+//      CHECK:   %[[RESULT:.+]] = vector.shape_cast %[[VEC]]
+//      CHECK:   return %[[RESULT]]
+
+// -----
+
+// Identity-layout source with a trailing unit dimension: the read is expected
+// to go through a memref<16xf32> subview and be shape_cast back to
+// vector<8x1xf32>.
+func @contiguous_inner_most_dim(%A: memref<16x1xf32>, %i:index, %j:index) -> (vector<8x1xf32>) {
+  %c0 = arith.constant 0 : index
+  %f0 = arith.constant 0.0 : f32
+  %1 = vector.transfer_read %A[%i, %j], %f0 : memref<16x1xf32>, vector<8x1xf32>
+  return %1 : vector<8x1xf32>
+}
+//      CHECK: func @contiguous_inner_most_dim(%[[SRC:.+]]: memref<16x1xf32>, %[[I:.+]]: index, %[[J:.+]]: index) -> vector<8x1xf32>
+//      CHECK:   %[[SRC_0:.+]] = memref.subview %[[SRC]]
+// CHECK-SAME:    memref<16x1xf32> to memref<16xf32>
+//      CHECK:   %[[V:.+]] = vector.transfer_read %[[SRC_0]]
+// RESULT is (re)defined here so the match does not depend on the SSA name
+// captured by an earlier test case in this file.
+//      CHECK:   %[[RESULT:.+]] = vector.shape_cast %[[V]] : vector<8xf32> to vector<8x1xf32>
+//      CHECK:   return %[[RESULT]]
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 85c5700798c61..06c64d4fc42b9 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -466,6 +466,33 @@ struct TestVectorMultiReductionLoweringPatterns
}
};
+/// Test pass, registered under the argument
+/// "test-vector-transfer-collapse-inner-most-dims", that greedily applies the
+/// patterns collected by
+/// populateVectorTransferCollapseInnerMostContiguousDimsPatterns to the
+/// current function.
+struct TestVectorTransferCollapseInnerMostContiguousDims
+    : public PassWrapper<TestVectorTransferCollapseInnerMostContiguousDims,
+                         FunctionPass> {
+  TestVectorTransferCollapseInnerMostContiguousDims() = default;
+  TestVectorTransferCollapseInnerMostContiguousDims(
+      const TestVectorTransferCollapseInnerMostContiguousDims &pass) {}
+
+  // Declares the dialects this pass inserts into the registry.
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<memref::MemRefDialect, AffineDialect>();
+  }
+
+  StringRef getArgument() const final {
+    return "test-vector-transfer-collapse-inner-most-dims";
+  }
+
+  StringRef getDescription() const final {
+    return "Test conversion patterns that reducedes the rank of the vector "
+           "transfer memory and vector operands.";
+  }
+
+  void runOnFunction() override {
+    RewritePatternSet patterns(&getContext());
+    populateVectorTransferCollapseInnerMostContiguousDimsPatterns(patterns);
+    // The result of the greedy driver is deliberately ignored.
+    (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
+  }
+};
+
} // end anonymous namespace
namespace mlir {
@@ -490,6 +517,8 @@ void registerTestVectorConversions() {
PassRegistration<TestVectorTransferLoweringPatterns>();
PassRegistration<TestVectorMultiReductionLoweringPatterns>();
+
+ PassRegistration<TestVectorTransferCollapseInnerMostContiguousDims>();
}
} // namespace test
} // namespace mlir
More information about the Mlir-commits
mailing list