[Mlir-commits] [mlir] [mlir][vector] Add patterns for vector masked load/store (PR #74834)
Jakub Kuderski
llvmlistbot at llvm.org
Tue Dec 12 12:33:25 PST 2023
================
@@ -0,0 +1,122 @@
+//===- LowerVectorMaskedLoadStore.cpp - Lower 'vector.maskedload/store' op ===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements target-independent rewrites and utilities to lower the
+// 'vector.maskedload' and 'vector.maskedstore' operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+
+#define DEBUG_TYPE "vector-masked-load-store-lowering"
+
+using namespace mlir;
+
+namespace {
+
+/// Convert vector.maskedload
+///
+/// Before:
+///
+/// vector.maskedload %base[%idx_0, %idx_1], %mask, %pass_thru
+///
+/// After:
+///
+/// %value = vector.load %base[%idx_0, %idx_1]
+/// arith.select %mask, %value, %pass_thru
+///
+struct VectorMaskedLoadOpConverter : OpRewritePattern<vector::MaskedLoadOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::MaskedLoadOp maskedLoadOp,
+                                PatternRewriter &rewriter) const override {
+    auto loc = maskedLoadOp.getLoc();
+    auto loadAll = rewriter.create<vector::LoadOp>(loc, maskedLoadOp.getType(),
+                                                   maskedLoadOp.getBase(),
+                                                   maskedLoadOp.getIndices());
+    auto selectedLoad = rewriter.create<arith::SelectOp>(
+        loc, maskedLoadOp.getMask(), loadAll, maskedLoadOp.getPassThru());
----------------
kuhar wrote:
Is the load guaranteed to be in bounds? Looking at the spec, this lowering seems invalid to me:
> If a mask bit is set and the corresponding index is out-of-bounds for the given base, the behavior is undefined. If a mask bit is not set, the value comes from the pass-through vector regardless of the index, and the index is allowed to be out-of-bounds.
We need to insert some control flow here; `arith.select` seems insufficient to avoid UB on out-of-bounds indices, since the unconditional `vector.load` reads every lane before the mask is ever applied.
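To make the concern concrete, here is a scalar C++ model (not MLIR code) of the two behaviors, assuming a 1-D base, unit-stride lanes, and `int` elements; `maskedLoadGuarded` and `loadThenSelect` are hypothetical names used only for illustration. The guarded version touches memory only for set lanes, as the spec allows. The load-then-select version is only well-defined when the whole vector is in bounds, so it returns `std::nullopt` otherwise.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <optional>
#include <vector>

// Hypothetical scalar model of vector.maskedload on a 1-D base (sketch only).
// Lane i reads base[idx + i] only if mask[i] is set; masked-off lanes take
// passThru[i] and their indices are never dereferenced, so they may be
// out of bounds.
template <std::size_t N>
std::array<int, N> maskedLoadGuarded(const std::vector<int> &base,
                                     std::size_t idx,
                                     const std::array<bool, N> &mask,
                                     const std::array<int, N> &passThru) {
  std::array<int, N> result = passThru;
  for (std::size_t i = 0; i < N; ++i) {
    if (!mask[i])
      continue; // Per-lane control flow: no memory access for unset lanes.
    // A set lane that is out of bounds is UB per the spec; model it as an
    // assertion failure.
    assert(idx + i < base.size() && "set lane is out of bounds");
    result[i] = base[idx + i];
  }
  return result;
}

// Hypothetical model of the proposed lowering: an unconditional full-vector
// load followed by a per-lane select. Returns std::nullopt when the full load
// would read past the end of `base`, i.e. when this lowering is not valid.
template <std::size_t N>
std::optional<std::array<int, N>>
loadThenSelect(const std::vector<int> &base, std::size_t idx,
               const std::array<bool, N> &mask,
               const std::array<int, N> &passThru) {
  if (idx + N > base.size())
    return std::nullopt; // vector.load would touch every lane, masked or not.
  std::array<int, N> loaded{};
  for (std::size_t i = 0; i < N; ++i)
    loaded[i] = base[idx + i];
  std::array<int, N> result{};
  for (std::size_t i = 0; i < N; ++i)
    result[i] = mask[i] ? loaded[i] : passThru[i]; // Models arith.select.
  return result;
}
```

For example, with a 4-element base, index 2 and mask [1, 1, 0, 0], the guarded model returns [base[2], base[3], pass, pass], while load-then-select would have to read base[4] and base[5].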
https://github.com/llvm/llvm-project/pull/74834