[Mlir-commits] [mlir] [mlir][vector] Add patterns for vector masked load/store (PR #74834)
Jakub Kuderski
llvmlistbot at llvm.org
Thu Dec 14 07:51:35 PST 2023
================
@@ -0,0 +1,159 @@
+//=- VectorEmulateMaskedLoadStore.cpp - Emulate 'vector.maskedload/store' op =//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements target-independent rewrites and utilities to emulate the
+// 'vector.maskedload' and 'vector.maskedstore' operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+
+#define DEBUG_TYPE "vector-emulate-masked-load-store"
+
+using namespace mlir;
+
+namespace {
+
+/// Convert vector.maskedload
+///
+/// Before:
+///
+///   vector.maskedload %base[%idx_0, %idx_1], %mask, %pass_thru
+///
+/// After:
+///
+///   %ivalue = %pass_thru
+///   %m = vector.extract %mask[0]
+///   %result = scf.if %m {
+///     %v = memref.load %base[%idx_0, %idx_1]
+///     %combined = vector.insert %v, %ivalue[0]
+///     scf.yield %combined
+///   } else {
+///     scf.yield %ivalue
+///   }
+///   ...
+///
+/// The mask is fully unrolled: one scf.if per lane. Each taken branch loads a
+/// single scalar and inserts it into the running result, and the innermost
+/// index is advanced by one between lanes. Only fixed-length 1-D masks with at
+/// least one base index are handled; anything else is a match failure.
+struct VectorMaskedLoadOpConverter final
+    : OpRewritePattern<vector::MaskedLoadOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::MaskedLoadOp maskedLoadOp,
+                                PatternRewriter &rewriter) const override {
+    VectorType maskVType = maskedLoadOp.getMaskVectorType();
+    if (maskVType.getShape().size() != 1)
+      return rewriter.notifyMatchFailure(
+          maskedLoadOp, "expected vector.maskedload with 1-D mask");
+    // For a scalable mask getShape()[0] is only the minimum lane count, so
+    // unrolling over it would silently drop lanes.
+    if (maskVType.isScalable())
+      return rewriter.notifyMatchFailure(
+          maskedLoadOp, "expected vector.maskedload with fixed-length mask");
+    // The per-lane offset is applied to the innermost index (indices.back()
+    // below); with no indices at all there is nothing to advance.
+    if (maskedLoadOp.getIndices().empty())
+      return rewriter.notifyMatchFailure(
+          maskedLoadOp, "expected vector.maskedload with at least one index");
+
+    Location loc = maskedLoadOp.getLoc();
+    int64_t maskLength = maskVType.getShape()[0];
+
+    Type indexType = rewriter.getIndexType();
+    Value mask = maskedLoadOp.getMask();
+    Value base = maskedLoadOp.getBase();
+    // Running result: starts as the pass-through vector and receives one
+    // loaded lane per enabled mask bit.
+    Value iValue = maskedLoadOp.getPassThru();
+    SmallVector<Value> indices(maskedLoadOp.getIndices().begin(),
+                               maskedLoadOp.getIndices().end());
+    Value one = rewriter.create<arith::ConstantOp>(
+        loc, indexType, IntegerAttr::get(indexType, 1));
+    for (int64_t i = 0; i < maskLength; ++i) {
+      // Lane positions are compile-time constants, so use the static-position
+      // vector.extract/vector.insert rather than materializing index values.
+      auto maskBit = rewriter.create<vector::ExtractOp>(loc, mask, i);
+
+      auto ifOp = rewriter.create<scf::IfOp>(
+          loc, maskBit,
+          [&](OpBuilder &builder, Location nestedLoc) {
+            auto loadedValue =
+                builder.create<memref::LoadOp>(nestedLoc, base, indices);
+            auto combinedValue = builder.create<vector::InsertOp>(
+                nestedLoc, loadedValue, iValue, i);
+            builder.create<scf::YieldOp>(nestedLoc, combinedValue.getResult());
+          },
+          [&](OpBuilder &builder, Location nestedLoc) {
+            // Lane disabled: keep the current (pass-through) value.
+            builder.create<scf::YieldOp>(nestedLoc, iValue);
+          });
+      iValue = ifOp.getResult(0);
+
+      indices.back() = rewriter.create<arith::AddIOp>(loc, indices.back(), one);
+    }
+
+    rewriter.replaceOp(maskedLoadOp, iValue);
+
+    return success();
+  }
+};
+
+/// Convert vector.maskedstore
+///
+/// Before:
+///
+/// vector.maskedstore %base[%idx_0, %idx_1], %mask, %value
+///
+/// After:
+///
+/// %m = vector.extractelement %mask[%c0]
+/// scf.if %m {
+/// %extracted = vector.extractelement %value[%c0]
+/// memref.store %extracted, %base[%idx_0, %idx_1]
+/// }
+/// ...
+///
+struct VectorMaskedStoreOpConverter final
+ : OpRewritePattern<vector::MaskedStoreOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::MaskedStoreOp maskedStoreOp,
+ PatternRewriter &rewriter) const override {
+ VectorType maskVType = maskedStoreOp.getMaskVectorType();
+ if (maskVType.getShape().size() != 1)
+ return rewriter.notifyMatchFailure(
+ maskedStoreOp, "expected vector.maskedstore with 1-D mask");
+
+ Location loc = maskedStoreOp.getLoc();
+ int64_t maskLength = maskVType.getShape()[0];
+
+ Type indexType = rewriter.getIndexType();
+ Value mask = maskedStoreOp.getMask();
+ Value base = maskedStoreOp.getBase();
+ Value value = maskedStoreOp.getValueToStore();
+ SmallVector<Value> indices(maskedStoreOp.getIndices().begin(),
+ maskedStoreOp.getIndices().end());
+ Value one = rewriter.create<arith::ConstantOp>(
+ loc, indexType, IntegerAttr::get(indexType, 1));
+ for (int64_t i = 0; i < maskLength; ++i) {
+ Value index = rewriter.create<arith::ConstantOp>(
+ loc, indexType, IntegerAttr::get(indexType, i));
+ auto maskBit =
+ rewriter.create<vector::ExtractElementOp>(loc, mask, index);
----------------
kuhar wrote:
use `vector::ExtractOp`
https://github.com/llvm/llvm-project/pull/74834
More information about the Mlir-commits
mailing list