[Mlir-commits] [mlir] [mlir][vector] Add patterns for vector masked load/store (PR #74834)

Jakub Kuderski llvmlistbot at llvm.org
Thu Dec 14 07:51:33 PST 2023


================
@@ -0,0 +1,159 @@
+//=- VectorEmulateMaskedLoadStore.cpp - Emulate 'vector.maskedload/store' op =//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements target-independent rewrites and utilities to emulate the
+// 'vector.maskedload' and 'vector.maskedstore' operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+
+#define DEBUG_TYPE "vector-emulate-masked-load-store"
+
+using namespace mlir;
+
+namespace {
+
+/// Convert vector.maskedload
+///
+/// Before:
+///
+///   vector.maskedload %base[%idx_0, %idx_1], %mask, %pass_thru
+///
+/// After:
+///
+///   %ivalue = %pass_thru
+///   %m = vector.extractelement %mask[%c0]
+///   %result = scf.if %m {
+///     %v = memref.load %base[%idx_0, %idx_1]
+///     %combined = vector.insertelement %v, %ivalue[%c0]
+///     scf.yield %combined
+///   } else {
+///     scf.yield %ivalue
+///   }
+///   ...
+///
+struct VectorMaskedLoadOpConverter final
+    : OpRewritePattern<vector::MaskedLoadOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::MaskedLoadOp maskedLoadOp,
+                                PatternRewriter &rewriter) const override {
+    VectorType maskVType = maskedLoadOp.getMaskVectorType();
+    if (maskVType.getShape().size() != 1)
+      return rewriter.notifyMatchFailure(
+          maskedLoadOp, "expected vector.maskedstore with 1-D mask");
+
+    Location loc = maskedLoadOp.getLoc();
+    int64_t maskLength = maskVType.getShape()[0];
+
+    Type indexType = rewriter.getIndexType();
+    Value mask = maskedLoadOp.getMask();
+    Value base = maskedLoadOp.getBase();
+    Value iValue = maskedLoadOp.getPassThru();
+    SmallVector<Value> indices(maskedLoadOp.getIndices().begin(),
+                               maskedLoadOp.getIndices().end());
+    Value one = rewriter.create<arith::ConstantOp>(
+        loc, indexType, IntegerAttr::get(indexType, 1));
+    for (int64_t i = 0; i < maskLength; ++i) {
+      Value index = rewriter.create<arith::ConstantOp>(
+          loc, indexType, IntegerAttr::get(indexType, i));
+      auto maskBit =
+          rewriter.create<vector::ExtractElementOp>(loc, mask, index);
----------------
kuhar wrote:

Use `vector::ExtractOp` -- this is the non-dynamic (static-index) version.

I think the builder would look like this:
```suggestion
      auto maskBit = rewriter.create<vector::ExtractOp>(loc, mask, i);
```

https://github.com/llvm/llvm-project/pull/74834


More information about the Mlir-commits mailing list