[Mlir-commits] [mlir] a088bed - [mlir] VectorToSCF cleanup
Matthias Springer
llvmlistbot at llvm.org
Thu May 13 19:04:45 PDT 2021
Author: Matthias Springer
Date: 2021-05-14T11:04:37+09:00
New Revision: a088bed4e3b572f8a3f8a1f7a41942f3005e4811
URL: https://github.com/llvm/llvm-project/commit/a088bed4e3b572f8a3f8a1f7a41942f3005e4811
DIFF: https://github.com/llvm/llvm-project/commit/a088bed4e3b572f8a3f8a1f7a41942f3005e4811.diff
LOG: [mlir] VectorToSCF cleanup
Group functions and structs into namespaces (`lowering_n_d`, `lowering_n_d_unrolled`, `lowering_1_d`) for better code readability.
Depends On D102123
Reviewed By: nicolasvasilache
Differential Revision: https://reviews.llvm.org/D102124
Added:
Modified:
mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index a209bc4672e3..9972bcf5a3ae 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -49,52 +49,6 @@ struct VectorToSCFPattern : public OpRewritePattern<OpTy> {
VectorTransferToSCFOptions options;
};
-/// Given a MemRefType with VectorType element type, unpack one dimension from
-/// the VectorType into the MemRefType.
-///
-/// E.g.: memref<9xvector<5x6xf32>> --> memref<9x5xvector<6xf32>>
-static MemRefType unpackOneDim(MemRefType type) {
- auto vectorType = type.getElementType().dyn_cast<VectorType>();
- auto memrefShape = type.getShape();
- SmallVector<int64_t, 8> newMemrefShape;
- newMemrefShape.append(memrefShape.begin(), memrefShape.end());
- newMemrefShape.push_back(vectorType.getDimSize(0));
- return MemRefType::get(newMemrefShape,
- VectorType::get(vectorType.getShape().drop_front(),
- vectorType.getElementType()));
-}
-
-/// Helper data structure for data and mask buffers.
-struct BufferAllocs {
- Value dataBuffer;
- Value maskBuffer;
-};
-
-/// Allocate temporary buffers for data (vector) and mask (if present).
-/// TODO: Parallelism and threadlocal considerations.
-template <typename OpTy>
-static BufferAllocs allocBuffers(OpTy xferOp) {
- auto &b = ScopedContext::getBuilderRef();
- OpBuilder::InsertionGuard guard(b);
- Operation *scope =
- xferOp->template getParentWithTrait<OpTrait::AutomaticAllocationScope>();
- assert(scope && "Expected op to be inside automatic allocation scope");
- b.setInsertionPointToStart(&scope->getRegion(0).front());
-
- BufferAllocs result;
- auto bufferType = MemRefType::get({}, xferOp.getVectorType());
- result.dataBuffer = memref_alloca(bufferType).value;
-
- if (xferOp.mask()) {
- auto maskType = MemRefType::get({}, xferOp.mask().getType());
- Value maskBuffer = memref_alloca(maskType);
- memref_store(xferOp.mask(), maskBuffer);
- result.maskBuffer = memref_load(maskBuffer);
- }
-
- return result;
-}
-
/// Given a vector transfer op, calculate which dimension of the `source`
/// memref should be unpacked in the next application of TransferOpConversion.
/// A return value of None indicates a broadcast.
@@ -284,6 +238,54 @@ static void maybeApplyPassLabel(OpBuilder &builder, OpTy newXferOp,
newXferOp->setAttr(kPassLabel, builder.getUnitAttr());
}
+namespace lowering_n_d {
+
+/// Helper data structure for data and mask buffers.
+struct BufferAllocs {
+ Value dataBuffer;
+ Value maskBuffer;
+};
+
+/// Allocate temporary buffers for data (vector) and mask (if present).
+/// TODO: Parallelism and threadlocal considerations.
+template <typename OpTy>
+static BufferAllocs allocBuffers(OpTy xferOp) {
+ auto &b = ScopedContext::getBuilderRef();
+ OpBuilder::InsertionGuard guard(b);
+ Operation *scope =
+ xferOp->template getParentWithTrait<OpTrait::AutomaticAllocationScope>();
+ assert(scope && "Expected op to be inside automatic allocation scope");
+ b.setInsertionPointToStart(&scope->getRegion(0).front());
+
+ BufferAllocs result;
+ auto bufferType = MemRefType::get({}, xferOp.getVectorType());
+ result.dataBuffer = memref_alloca(bufferType).value;
+
+ if (xferOp.mask()) {
+ auto maskType = MemRefType::get({}, xferOp.mask().getType());
+ auto maskBuffer = memref_alloca(maskType).value;
+ memref_store(xferOp.mask(), maskBuffer);
+ result.maskBuffer = memref_load(maskBuffer);
+ }
+
+ return result;
+}
+
+/// Given a MemRefType with VectorType element type, unpack one dimension from
+/// the VectorType into the MemRefType.
+///
+/// E.g.: memref<9xvector<5x6xf32>> --> memref<9x5xvector<6xf32>>
+static MemRefType unpackOneDim(MemRefType type) {
+ auto vectorType = type.getElementType().dyn_cast<VectorType>();
+ auto memrefShape = type.getShape();
+ SmallVector<int64_t, 8> newMemrefShape;
+ newMemrefShape.append(memrefShape.begin(), memrefShape.end());
+ newMemrefShape.push_back(vectorType.getDimSize(0));
+ return MemRefType::get(newMemrefShape,
+ VectorType::get(vectorType.getShape().drop_front(),
+ vectorType.getElementType()));
+}
+
/// Given a transfer op, find the memref from which the mask is loaded. This
/// is similar to Strategy<TransferWriteOp>::getBuffer.
template <typename OpTy>
@@ -688,6 +690,10 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
}
};
+} // namespace lowering_n_d
+
+namespace lowering_n_d_unrolled {
+
/// If the original transfer op has a mask, compute the mask of the new transfer
/// op (for the current iteration `i`) and assign it.
template <typename OpTy>
@@ -954,6 +960,10 @@ struct UnrollTransferWriteConversion
}
};
+} // namespace lowering_n_d_unrolled
+
+namespace lowering_1_d {
+
/// Compute the indices into the memref for the LoadOp/StoreOp generated as
/// part of TransferOp1dConversion. Return the memref dimension on which
/// the transfer is operating. A return value of None indicates a broadcast.
@@ -1114,6 +1124,7 @@ struct TransferOp1dConversion : public VectorToSCFPattern<OpTy> {
}
};
+} // namespace lowering_1_d
} // namespace
namespace mlir {
@@ -1121,19 +1132,21 @@ namespace mlir {
void populateVectorToSCFConversionPatterns(
RewritePatternSet &patterns, const VectorTransferToSCFOptions &options) {
if (options.unroll) {
- patterns.add<UnrollTransferReadConversion, UnrollTransferWriteConversion>(
+ patterns.add<lowering_n_d_unrolled::UnrollTransferReadConversion,
+ lowering_n_d_unrolled::UnrollTransferWriteConversion>(
patterns.getContext(), options);
} else {
- patterns.add<PrepareTransferReadConversion, PrepareTransferWriteConversion,
- TransferOpConversion<TransferReadOp>,
- TransferOpConversion<TransferWriteOp>>(patterns.getContext(),
- options);
+ patterns.add<lowering_n_d::PrepareTransferReadConversion,
+ lowering_n_d::PrepareTransferWriteConversion,
+ lowering_n_d::TransferOpConversion<TransferReadOp>,
+ lowering_n_d::TransferOpConversion<TransferWriteOp>>(
+ patterns.getContext(), options);
}
if (options.targetRank == 1) {
- patterns.add<TransferOp1dConversion<TransferReadOp>,
- TransferOp1dConversion<TransferWriteOp>>(patterns.getContext(),
- options);
+ patterns.add<lowering_1_d::TransferOp1dConversion<TransferReadOp>,
+ lowering_1_d::TransferOp1dConversion<TransferWriteOp>>(
+ patterns.getContext(), options);
}
}
More information about the Mlir-commits mailing list