[Mlir-commits] [mlir] [mlir][vector] Add patterns for vector masked load/store (PR #74834)

Diego Caballero llvmlistbot at llvm.org
Wed Dec 13 01:59:38 PST 2023


================
@@ -0,0 +1,122 @@
+//===- LowerVectorMaskedLoadStore.cpp - Lower 'vector.maskedload/store' op ===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements target-independent rewrites and utilities to lower the
+// 'vector.maskedload' and 'vector.maskedstore' operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+
+#define DEBUG_TYPE "vector-masked-load-store-lowering"
+
+using namespace mlir;
+
+namespace {
+
+/// Rewrites `vector.maskedload` as an unmasked `vector.load` whose result is
+/// blended with the pass-through value by an `arith.select` on the mask.
+///
+/// Before:
+///
+///   vector.maskedload %base[%idx_0, %idx_1], %mask, %pass_thru
+///
+/// After:
+///
+///   %value = vector.load %base[%idx_0, %idx_1]
+///   arith.select %mask, %value, %pass_thru
+///
+/// NOTE(review): the emitted `vector.load` is unconditional, so lanes that are
+/// masked off are still read from memory -- confirm that this is acceptable
+/// for every user of this pattern.
+struct VectorMaskedLoadOpConverter : OpRewritePattern<vector::MaskedLoadOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::MaskedLoadOp maskedLoadOp,
+                                PatternRewriter &rewriter) const override {
+    Location loc = maskedLoadOp.getLoc();
+
+    // Load the full vector, ignoring the mask.
+    Value unmaskedLoad = rewriter.create<vector::LoadOp>(
+        loc, maskedLoadOp.getType(), maskedLoadOp.getBase(),
+        maskedLoadOp.getIndices());
+
+    // Keep the loaded lanes where the mask is set and the pass-through lanes
+    // elsewhere; the select takes the place of the original op.
+    rewriter.replaceOpWithNewOp<arith::SelectOp>(
+        maskedLoadOp, maskedLoadOp.getMask(), unmaskedLoad,
+        maskedLoadOp.getPassThru());
+    return success();
+  }
+};
+
+/// Emits an `arith.constant` of type i32 holding `value` at `loc` and returns
+/// the result of that op.
+Value createConstantInteger(PatternRewriter &rewriter, Location loc,
+                            int32_t value) {
+  IntegerType constantType = rewriter.getI32Type();
+  IntegerAttr constantAttr = IntegerAttr::get(constantType, value);
+  Value constant =
+      rewriter.create<arith::ConstantOp>(loc, constantType, constantAttr);
+  return constant;
+}
+
+/// Convert vector.maskedstore
+///
+/// Before:
+///
+///   vector.maskedstore %base[%idx_0, %idx_1], %mask, %value
+///
+/// After:
+///
+///   scf.for %iv = %c0 to %vector_len step %c1 {
+///     %m = vector.extractelement %mask[%iv]
+///     scf.if %m {
+///       %v = vector.extractelement %value[%iv]
+///       memref.store %v, %base[%idx_0, %idx_1]
+///     }
+///   }
+///
+struct VectorMaskedStoreOpConverter : OpRewritePattern<vector::MaskedStoreOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::MaskedStoreOp maskedStoreOp,
+                                PatternRewriter &rewriter) const override {
+    VectorType maskVType = maskedStoreOp.getMaskVectorType();
+    if (maskVType.getShape().size() != 1)
+      return rewriter.notifyMatchFailure(
+          maskedStoreOp, "expected vector.maskedstore with 1-D mask");
+
+    auto loc = maskedStoreOp.getLoc();
+    Value zero = createConstantInteger(rewriter, loc, 0);
+    Value one = createConstantInteger(rewriter, loc, 1);
+    Value maskLength =
+        createConstantInteger(rewriter, loc, maskVType.getShape()[0]);
+
+    auto loopOp = rewriter.create<scf::ForOp>(loc, zero, maskLength, one);
+    rewriter.setInsertionPointToStart(loopOp.getBody());
----------------
dcaballe wrote:

Is an insertion guard (`OpBuilder::InsertionGuard`) needed here to restore the insertion point to where it was?

https://github.com/llvm/llvm-project/pull/74834


More information about the Mlir-commits mailing list