[Mlir-commits] [mlir] [mlir][vector] Add patterns for vector masked load/store (PR #74834)

Jakub Kuderski llvmlistbot at llvm.org
Tue Dec 12 12:33:26 PST 2023


================
@@ -0,0 +1,122 @@
+//===- LowerVectorMaskedLoadStore.cpp - Lower 'vector.maskedload/store' op ===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements target-independent rewrites and utilities to lower the
+// 'vector.maskedload' and 'vector.maskedstore' operation.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+
+#define DEBUG_TYPE "vector-masked-load-store-lowering"
+
+using namespace mlir;
+
+namespace {
+
+/// Convert vector.maskedload
+///
+/// Lowers a 1-D, fixed-length `vector.maskedload` into a chain of per-lane
+/// `scf.if` ops so that memory is only read for lanes whose mask bit is set.
+/// (An unconditional `vector.load` followed by `arith.select` is NOT
+/// equivalent: masked-off lanes may address memory past the end of the
+/// buffer, and reading them could fault.)
+///
+/// Before:
+///
+///   vector.maskedload %base[%idx_0, %idx_1], %mask, %pass_thru
+///
+/// After (first two lanes shown):
+///
+///   %m0 = vector.extract %mask[0]
+///   %r0 = scf.if %m0 {
+///     %v0 = memref.load %base[%idx_0, %idx_1]
+///     %ins0 = vector.insert %v0, %pass_thru[0]
+///     scf.yield %ins0
+///   } else {
+///     scf.yield %pass_thru
+///   }
+///   %next = arith.addi %idx_1, %c1
+///   %m1 = vector.extract %mask[1]
+///   %r1 = scf.if %m1 {
+///     %v1 = memref.load %base[%idx_0, %next]
+///     %ins1 = vector.insert %v1, %r0[1]
+///     scf.yield %ins1
+///   } else {
+///     scf.yield %r0
+///   }
+///   ...
+///
+struct VectorMaskedLoadOpConverter : OpRewritePattern<vector::MaskedLoadOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::MaskedLoadOp maskedLoadOp,
+                                PatternRewriter &rewriter) const override {
+    // All match checks happen before any IR is created, so a failed match
+    // leaves the IR untouched.
+    VectorType maskVType = maskedLoadOp.getMaskVectorType();
+    // The lowering unrolls over mask lanes, so the lane count must be a
+    // single static dimension.
+    if (maskVType.getRank() != 1 || maskVType.isScalable())
+      return rewriter.notifyMatchFailure(
+          maskedLoadOp,
+          "expected vector.maskedload with a fixed-length 1-D mask");
+
+    auto indices = llvm::to_vector(maskedLoadOp.getIndices());
+    // Consecutive lanes map to consecutive elements along the innermost
+    // index, so there must be one to advance.
+    if (indices.empty())
+      return rewriter.notifyMatchFailure(maskedLoadOp,
+                                         "expected at least one index");
+
+    Location loc = maskedLoadOp.getLoc();
+    Value mask = maskedLoadOp.getMask();
+    Value base = maskedLoadOp.getBase();
+    // Running result: starts as the pass-through vector and gets one lane
+    // overwritten for every enabled mask bit.
+    Value result = maskedLoadOp.getPassThru();
+    Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+
+    int64_t numLanes = maskVType.getDimSize(0);
+    for (int64_t lane = 0; lane < numLanes; ++lane) {
+      Value maskBit = rewriter.create<vector::ExtractOp>(loc, mask, lane);
+      // The region builders run synchronously inside create(), so capturing
+      // `result` by reference sees the value from the previous lane.
+      auto ifOp = rewriter.create<scf::IfOp>(
+          loc, maskBit,
+          /*thenBuilder=*/
+          [&](OpBuilder &builder, Location nestedLoc) {
+            Value loaded =
+                builder.create<memref::LoadOp>(nestedLoc, base, indices);
+            Value inserted = builder.create<vector::InsertOp>(
+                nestedLoc, loaded, result, lane);
+            builder.create<scf::YieldOp>(nestedLoc, inserted);
+          },
+          /*elseBuilder=*/
+          [&](OpBuilder &builder, Location nestedLoc) {
+            builder.create<scf::YieldOp>(nestedLoc, result);
+          });
+      result = ifOp.getResult(0);
+
+      // Advance the innermost index to the next lane's element (skipped
+      // after the last lane to avoid emitting a dead add).
+      if (lane + 1 < numLanes)
+        indices.back() =
+            rewriter.create<arith::AddIOp>(loc, indices.back(), one);
+    }
+
+    rewriter.replaceOp(maskedLoadOp, result);
+    return success();
+  }
+};
+
+/// Materializes `value` as a signless i32 `arith.constant` at `loc` and
+/// returns its result.
+Value createConstantInteger(PatternRewriter &rewriter, Location loc,
+                            int32_t value) {
+  IntegerAttr i32Attr = rewriter.getI32IntegerAttr(value);
+  return rewriter.create<arith::ConstantOp>(loc, i32Attr);
+}
+
+/// Convert vector.maskedstore
+///
+/// Before:
+///
+///   vector.maskedstore %base[%idx_0, %idx_1], %mask, %value
+///
+/// After:
+///
+///   scf.for %iv = %c0 to %vector_len step %c1 {
+///     %m = vector.extractelement %mask[%iv]
+///     scf.if %m {
+///       %v = vector.extractelement %value[%iv]
+///       memref.store %v, %base[%idx_0, %idx_1]
+///     }
+///   }
+///
+struct VectorMaskedStoreOpConverter : OpRewritePattern<vector::MaskedStoreOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::MaskedStoreOp maskedStoreOp,
+                                PatternRewriter &rewriter) const override {
+    VectorType maskVType = maskedStoreOp.getMaskVectorType();
+    if (maskVType.getShape().size() != 1)
+      return rewriter.notifyMatchFailure(
+          maskedStoreOp, "expected vector.maskedstore with 1-D mask");
+
+    auto loc = maskedStoreOp.getLoc();
+    Value zero = createConstantInteger(rewriter, loc, 0);
+    Value one = createConstantInteger(rewriter, loc, 1);
+    Value maskLength =
+        createConstantInteger(rewriter, loc, maskVType.getShape()[0]);
+
+    auto loopOp = rewriter.create<scf::ForOp>(loc, zero, maskLength, one);
+    rewriter.setInsertionPointToStart(loopOp.getBody());
----------------
kuhar wrote:

Do we prefer `scf.for` here over a sequence of `scf.if` ops? I'd think that dynamic indexing into the vector would lead to poor codegen when we have to break these vectors down into smaller ones when lowering to SPIR-V. But maybe we rely on this loop being unrolled by something else later on?

https://github.com/llvm/llvm-project/pull/74834


More information about the Mlir-commits mailing list