[Mlir-commits] [mlir] Make createReadOrMaskedRead a utility (PR #89119)
Matthias Springer
llvmlistbot at llvm.org
Thu Apr 18 01:24:23 PDT 2024
================
@@ -1593,3 +1593,40 @@ void linalg::populateDecomposeConvolutionPatterns(RewritePatternSet &patterns,
DownscaleSizeOneWindowed2DConvolution<PoolingNchwMaxOp, PoolingNcwMaxOp>>(
patterns.getContext(), benefit);
}
+
+Value mlir::linalg::createReadOrMaskedRead(OpBuilder &builder, Location loc,
+ Value source,
+ ArrayRef<int64_t> readShape,
+ Value padValue, bool doMasking) {
+ assert(llvm::none_of(readShape,
+ [](int64_t s) { return s == ShapedType::kDynamic; }));
+ auto sourceShape = dyn_cast<ShapedType>(source.getType()).getShape();
+ assert(sourceShape.size() == readShape.size());
+ auto maskType = VectorType::get(readShape, builder.getI1Type());
+ auto vectorType = VectorType::get(readShape, padValue.getType());
+ int64_t readRank = readShape.size();
+ auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+ SmallVector<bool> inBoundsVal(readRank, true);
+ if (!doMasking) {
+ // Update the inBounds attribute.
+ for (unsigned i = 0; i < readRank; i++)
+ inBoundsVal[i] = sourceShape[i] == readShape[i];
+ }
+ auto transferReadOp = builder.create<vector::TransferReadOp>(
+ loc,
+ /*vectorType=*/vectorType,
+ /*source=*/source,
+ /*indices=*/SmallVector<Value>(readRank, zero),
+ /*padding=*/padValue,
+ /*inBounds=*/inBoundsVal);
+
+ if (llvm::equal(readShape, sourceShape) || !doMasking) {
----------------
matthias-springer wrote:
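nit: The braces around this single-statement `if` body are not needed (LLVM style omits braces on simple single-statement bodies).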
https://github.com/llvm/llvm-project/pull/89119
More information about the Mlir-commits mailing list