[Mlir-commits] [mlir] [mlir][vector] Add patterns for vector masked load/store (PR #74834)
Hsiangkai Wang
llvmlistbot at llvm.org
Wed Dec 13 06:01:14 PST 2023
https://github.com/Hsiangkai updated https://github.com/llvm/llvm-project/pull/74834
>From dd83d7ec6bf95b33350b7b069337c95639b7b7e6 Mon Sep 17 00:00:00 2001
From: Hsiangkai Wang <hsiangkai.wang at arm.com>
Date: Thu, 30 Nov 2023 14:09:00 +0000
Subject: [PATCH] [mlir][vector] Add patterns for vector masked load/store
This patch adds patterns that convert
vector.maskedload %base[%idx_0, %idx_1], %mask, %pass_thru
to
%ret = scf.for %iv = %c0 to %vector_len step %c1 iter_args(%vectorBuf = %pass_thru) {
%m = vector.extractelement %mask[%iv]
%value = scf.if %m {
%v = memref.load %base[%idx_0, %idx_1 + %iv]
%combined = vector.insertelement %v, %vectorBuf
yield %combined
} else {
yield %vectorBuf
}
yield %value
}
It will convert
vector.maskedstore %base[%idx_0, %idx_1], %mask, %value
to
scf.for 0 to %mask_len step 1 {
%m = vector.extractelement %mask
scf.if(%m) {
%v = vector.extractelement %value
memref.store %v, %base[%idx_0, %idx_1]
}
}
---
.../Vector/Transforms/LoweringPatterns.h | 10 +
.../Dialect/Vector/Transforms/CMakeLists.txt | 1 +
.../VectorEmulateMaskedLoadStore.cpp | 171 ++++++++++++++++++
.../vector-emulate-masked-load-store.mlir | 72 ++++++++
.../Dialect/Vector/TestVectorTransforms.cpp | 27 +++
5 files changed, 281 insertions(+)
create mode 100644 mlir/lib/Dialect/Vector/Transforms/VectorEmulateMaskedLoadStore.cpp
create mode 100644 mlir/test/Dialect/Vector/vector-emulate-masked-load-store.mlir
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 292398a3dc5a7d..fa93bdaf248b7e 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -254,6 +254,16 @@ void populateVectorGatherLoweringPatterns(RewritePatternSet &patterns,
void populateVectorMaskLoweringPatternsForSideEffectingOps(
RewritePatternSet &patterns);
+/// Populate the pattern set with the following patterns:
+///
+/// [VectorMaskedLoadOpConverter]
+/// Turns vector.maskedload to scf.for + scf.if + memref.load
+///
+/// [VectorMaskedStoreOpConverter]
+/// Turns vector.maskedstore to scf.for + scf.if + memref.store
+void populateVectorMaskedLoadStoreEmulationPatterns(RewritePatternSet &patterns,
+ PatternBenefit benefit = 1);
+
} // namespace vector
} // namespace mlir
#endif // MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H
diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
index 513340096a5c1f..daf28882976ef6 100644
--- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -13,6 +13,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
SubsetOpInterfaceImpl.cpp
VectorDistribute.cpp
VectorDropLeadUnitDim.cpp
+ VectorEmulateMaskedLoadStore.cpp
VectorEmulateNarrowType.cpp
VectorInsertExtractStridedSliceRewritePatterns.cpp
VectorTransferOpTransforms.cpp
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateMaskedLoadStore.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateMaskedLoadStore.cpp
new file mode 100644
index 00000000000000..e35e13a1853397
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateMaskedLoadStore.cpp
@@ -0,0 +1,171 @@
+//=- VectorEmulateMaskedLoadStore.cpp - Emulate 'vector.maskedload/store' op =//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements target-independent rewrites and utilities to emulate the
+// 'vector.maskedload' and 'vector.maskedstore' operation.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+
+#define DEBUG_TYPE "vector-emulate-masked-load-store"
+
+using namespace mlir;
+
+namespace {
+
+/// Convert vector.maskedload
+///
+/// Before:
+///
+///   vector.maskedload %base[%idx_0, %idx_1], %mask, %pass_thru
+///
+/// After:
+///
+///   %ret = scf.for %iv = %c0 to %vector_len step %c1 iter_args(%vectorBuf =
+///   %pass_thru) {
+///     %m = vector.extractelement %mask[%iv]
+///     %value = scf.if %m {
+///       %v = memref.load %base[%idx_0, %idx_1 + %iv]
+///       %combined = vector.insertelement %v, %vectorBuf
+///       yield %combined
+///     } else {
+///       yield %vectorBuf
+///     }
+///     yield %value
+///   }
+///
+struct VectorMaskedLoadOpConverter final
+    : OpRewritePattern<vector::MaskedLoadOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::MaskedLoadOp maskedLoadOp,
+                                PatternRewriter &rewriter) const override {
+    VectorType maskVType = maskedLoadOp.getMaskVectorType();
+    if (maskVType.getShape().size() != 1)
+      return rewriter.notifyMatchFailure(
+          maskedLoadOp, "expected vector.maskedload with 1-D mask");
+
+    Location loc = maskedLoadOp.getLoc();
+    // Loop bounds are i32 constants; the induction variable is cast to
+    // 'index' only where memref/vector ops require it.
+    Type i32Type = rewriter.getI32Type();
+    Value zero = rewriter.create<arith::ConstantOp>(
+        loc, i32Type, IntegerAttr::get(i32Type, 0));
+    Value one = rewriter.create<arith::ConstantOp>(
+        loc, i32Type, IntegerAttr::get(i32Type, 1));
+    Value maskLength = rewriter.create<arith::ConstantOp>(
+        loc, i32Type, IntegerAttr::get(i32Type, maskVType.getShape()[0]));
+
+    // The result vector is carried through the loop as an iter_arg, seeded
+    // with the pass-through value so disabled lanes keep pass-through data.
+    auto loopOp = rewriter.create<scf::ForOp>(
+        loc, zero, maskLength, one, ValueRange{maskedLoadOp.getPassThru()});
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPointToStart(loopOp.getBody());
+
+    Value indVar = loopOp.getInductionVar();
+    Value vectorBuf = loopOp.getRegionIterArg(0);
+
+    auto maskBit = rewriter.create<vector::ExtractElementOp>(
+        loc, maskedLoadOp.getMask(), indVar);
+    auto ifOp = rewriter.create<scf::IfOp>(
+        loc, maskBit,
+        [&](OpBuilder &builder, Location loc) {
+          // Enabled lane: load the scalar at base[..., lastIdx + iv] and
+          // insert it into the carried result vector.
+          SmallVector<Value> newIndices(maskedLoadOp.getIndices().begin(),
+                                        maskedLoadOp.getIndices().end());
+          indVar = builder.create<arith::IndexCastOp>(
+              loc, rewriter.getIndexType(), indVar);
+          newIndices.back() =
+              builder.create<arith::AddIOp>(loc, newIndices.back(), indVar);
+          auto loadedValue = builder.create<memref::LoadOp>(
+              loc, maskedLoadOp.getBase(), newIndices);
+          auto combinedValue = builder.create<vector::InsertElementOp>(
+              loc, loadedValue, vectorBuf, indVar);
+          builder.create<scf::YieldOp>(loc, combinedValue.getResult());
+        },
+        [&](OpBuilder &builder, Location loc) {
+          // Disabled lane: forward the carried vector unchanged.
+          builder.create<scf::YieldOp>(loc, vectorBuf);
+        });
+
+    rewriter.setInsertionPointToEnd(loopOp.getBody());
+    rewriter.create<scf::YieldOp>(loc, ifOp.getResult(0));
+
+    rewriter.replaceOp(maskedLoadOp, loopOp);
+
+    return success();
+  }
+};
+
+/// Convert vector.maskedstore
+///
+/// Before:
+///
+///   vector.maskedstore %base[%idx_0, %idx_1], %mask, %value
+///
+/// After:
+///
+///   scf.for %iv = %c0 to %vector_len step %c1 {
+///     %m = vector.extractelement %mask[%iv]
+///     scf.if %m {
+///       %v = vector.extractelement %value[%iv]
+///       memref.store %v, %base[%idx_0, %idx_1 + %iv]
+///     }
+///   }
+///
+struct VectorMaskedStoreOpConverter final
+    : OpRewritePattern<vector::MaskedStoreOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::MaskedStoreOp maskedStoreOp,
+                                PatternRewriter &rewriter) const override {
+    // Only 1-D masks are supported; n-D masked stores must be linearized
+    // before this pattern applies.
+    VectorType maskVType = maskedStoreOp.getMaskVectorType();
+    if (maskVType.getShape().size() != 1)
+      return rewriter.notifyMatchFailure(
+          maskedStoreOp, "expected vector.maskedstore with 1-D mask");
+
+    Location loc = maskedStoreOp.getLoc();
+    // Loop bounds are i32 constants; the induction variable is cast to
+    // 'index' only where memref requires it.
+    Type i32Type = rewriter.getI32Type();
+    Value zero = rewriter.create<arith::ConstantOp>(
+        loc, i32Type, IntegerAttr::get(i32Type, 0));
+    Value one = rewriter.create<arith::ConstantOp>(
+        loc, i32Type, IntegerAttr::get(i32Type, 1));
+    Value maskLength = rewriter.create<arith::ConstantOp>(
+        loc, i32Type, IntegerAttr::get(i32Type, maskVType.getShape()[0]));
+
+    // No iter_args: a store produces no value, so the loop carries nothing.
+    auto loopOp = rewriter.create<scf::ForOp>(loc, zero, maskLength, one);
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPointToStart(loopOp.getBody());
+
+    Value indVar = loopOp.getInductionVar();
+    auto maskBit = rewriter.create<vector::ExtractElementOp>(
+        loc, maskedStoreOp.getMask(), indVar);
+    // Guard the scalar store with the lane's mask bit (no else branch).
+    auto ifOp = rewriter.create<scf::IfOp>(loc, maskBit, /*else=*/false);
+    rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
+    auto value = rewriter.create<vector::ExtractElementOp>(
+        loc, maskedStoreOp.getValueToStore(), indVar);
+    // Offset the innermost index by the induction variable so lane i stores
+    // to base[..., lastIdx + i].
+    SmallVector<Value> newIndices(maskedStoreOp.getIndices().begin(),
+                                  maskedStoreOp.getIndices().end());
+    indVar = rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(),
+                                                 indVar);
+    newIndices.back() =
+        rewriter.create<arith::AddIOp>(loc, newIndices.back(), indVar);
+    rewriter.create<memref::StoreOp>(loc, value, maskedStoreOp.getBase(),
+                                     newIndices);
+
+    rewriter.replaceOp(maskedStoreOp, loopOp);
+
+    return success();
+  }
+};
+
+} // namespace
+
+// Registers the 1-D vector.maskedload/maskedstore emulation patterns
+// declared in LoweringPatterns.h.
+void mlir::vector::populateVectorMaskedLoadStoreEmulationPatterns(
+    RewritePatternSet &patterns, PatternBenefit benefit) {
+  patterns.add<VectorMaskedLoadOpConverter, VectorMaskedStoreOpConverter>(
+      patterns.getContext(), benefit);
+}
diff --git a/mlir/test/Dialect/Vector/vector-emulate-masked-load-store.mlir b/mlir/test/Dialect/Vector/vector-emulate-masked-load-store.mlir
new file mode 100644
index 00000000000000..bfe7ba5932d44e
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-emulate-masked-load-store.mlir
@@ -0,0 +1,72 @@
+// RUN: mlir-opt %s --test-vector-emulate-masked-load-store | FileCheck %s
+
+module attributes {
+ spirv.target_env = #spirv.target_env<
+ #spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>
+ } {
+// CHECK-LABEL: @vector_maskedload
+// CHECK-SAME: (%[[ARG0:.*]]: memref<4x5xf32, #spirv.storage_class<StorageBuffer>>) -> vector<4xf32> {
+// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
+// CHECK: %[[C0_i32:.*]] = arith.constant 0 : i32
+// CHECK: %[[C1_i32:.*]] = arith.constant 1 : i32
+// CHECK: %[[C4_i32:.*]] = arith.constant 4 : i32
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[C4:.*]] = arith.constant 4 : index
+// CHECK: %[[S0:.*]] = vector.create_mask %[[C1]] : vector<4xi1>
+// CHECK: %[[S1:.*]] = scf.for %[[ARG1:.*]] = %[[C0_i32]] to %[[C4_i32]] step %[[C1_i32]] iter_args(%[[ARG2:.*]] = %[[CST]]) -> (vector<4xf32>) : i32 {
+// CHECK: %[[S2:.*]] = vector.extractelement %[[S0]][%[[ARG1]] : i32] : vector<4xi1>
+// CHECK: %[[S3:.*]] = scf.if %[[S2]] -> (vector<4xf32>) {
+// CHECK: %[[S4:.*]] = arith.index_cast %[[ARG1]] : i32 to index
+// CHECK: %[[S5:.*]] = arith.addi %[[S4]], %[[C4]] : index
+// CHECK: %[[S6:.*]] = memref.load %[[ARG0]][%[[C0]], %[[S5]]] : memref<4x5xf32, #spirv.storage_class<StorageBuffer>>
+// CHECK: %[[S7:.*]] = vector.insertelement %[[S6]], %[[ARG2]][%[[S4]] : index] : vector<4xf32>
+// CHECK: scf.yield %[[S7]] : vector<4xf32>
+// CHECK: } else {
+// CHECK: scf.yield %[[ARG2]] : vector<4xf32>
+// CHECK: }
+// CHECK: scf.yield %[[S3]] : vector<4xf32>
+// CHECK: }
+// CHECK: return %[[S1]] : vector<4xf32>
+// CHECK: }
+// Masked load of 4 elements from a 4x5 memref at [%idx_0, %idx_4]. The mask
+// (vector.create_mask %idx_1, i.e. length 1) enables only lane 0; the other
+// lanes keep the 0.0 pass-through splat.
func.func @vector_maskedload(%arg0 : memref<4x5xf32, #spirv.storage_class<StorageBuffer>>) -> vector<4xf32> {
+  %idx_0 = arith.constant 0 : index
+  %idx_1 = arith.constant 1 : index
+  %idx_4 = arith.constant 4 : index
+  %mask = vector.create_mask %idx_1 : vector<4xi1>
+  %s = arith.constant 0.0 : f32
+  %pass_thru = vector.splat %s : vector<4xf32>
+  %0 = vector.maskedload %arg0[%idx_0, %idx_4], %mask, %pass_thru : memref<4x5xf32, #spirv.storage_class<StorageBuffer>>, vector<4xi1>, vector<4xf32> into vector<4xf32>
+  return %0: vector<4xf32>
+}
+
+// CHECK-LABEL: @vector_maskedstore
+// CHECK-SAME: (%[[ARG0:.*]]: memref<4x5xf32, #spirv.storage_class<StorageBuffer>>, %[[ARG1:.*]]: vector<4xf32>) {
+// CHECK: %[[C0_I32:.*]] = arith.constant 0 : i32
+// CHECK: %[[C1_I32:.*]] = arith.constant 1 : i32
+// CHECK: %[[C4_I32:.*]] = arith.constant 4 : i32
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[C4:.*]] = arith.constant 4 : index
+// CHECK: %[[S0:.*]] = vector.create_mask %[[C1]] : vector<4xi1>
+// CHECK: scf.for %[[ARG2:.*]] = %[[C0_I32]] to %[[C4_I32]] step %[[C1_I32]] : i32 {
+// CHECK: %[[S1:.*]] = vector.extractelement %[[S0]][%[[ARG2]] : i32] : vector<4xi1>
+// CHECK: scf.if %[[S1]] {
+// CHECK: %[[S2:.*]] = vector.extractelement %[[ARG1]][%[[ARG2]] : i32] : vector<4xf32>
+// CHECK: %[[S3:.*]] = arith.index_cast %[[ARG2]] : i32 to index
+// CHECK: %[[S4:.*]] = arith.addi %[[S3]], %[[C4]] : index
+// CHECK: memref.store %[[S2]], %[[ARG0]][%[[C0]], %[[S4]]] : memref<4x5xf32, #spirv.storage_class<StorageBuffer>>
+// CHECK: }
+// CHECK: }
+// CHECK: return
+// CHECK:}
+// Masked store of %arg1 into a 4x5 memref at [%idx_0, %idx_4]. The mask
+// (length 1) enables only lane 0, so a single guarded scalar store results.
func.func @vector_maskedstore(%arg0 : memref<4x5xf32, #spirv.storage_class<StorageBuffer>>, %arg1 : vector<4xf32>) {
+  %idx_0 = arith.constant 0 : index
+  %idx_1 = arith.constant 1 : index
+  %idx_4 = arith.constant 4 : index
+  %mask = vector.create_mask %idx_1 : vector<4xi1>
+  vector.maskedstore %arg0[%idx_0, %idx_4], %mask, %arg1 : memref<4x5xf32, #spirv.storage_class<StorageBuffer>>, vector<4xi1>, vector<4xf32>
+  return
+}
+
+} // end module
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index e593c0defcd29e..39ca597ce63c8c 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -776,6 +776,31 @@ struct TestFoldArithExtensionIntoVectorContractPatterns
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
};
+
+// Test-only pass exposing populateVectorMaskedLoadStoreEmulationPatterns
+// under the -test-vector-emulate-masked-load-store flag.
+struct TestVectorEmulateMaskedLoadStore final
+    : public PassWrapper<TestVectorEmulateMaskedLoadStore,
+                         OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorEmulateMaskedLoadStore)
+
+  StringRef getArgument() const override {
+    return "test-vector-emulate-masked-load-store";
+  }
+  StringRef getDescription() const override {
+    // The emulation lowers to scalar memref.load/store guarded by scf ops,
+    // not to vector.load/store (see LoweringPatterns.h).
+    return "Test patterns that emulate the maskedload/maskedstore op by "
+           "memref.load/store wrapped in scf.for + scf.if";
+  }
+  void getDependentDialects(DialectRegistry &registry) const override {
+    // All dialects the emulated IR may contain must be pre-loaded.
+    registry
+        .insert<arith::ArithDialect, func::FuncDialect, memref::MemRefDialect,
+                scf::SCFDialect, vector::VectorDialect>();
+  }
+
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    populateVectorMaskedLoadStoreEmulationPatterns(patterns);
+    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+  }
+};
} // namespace
namespace mlir {
@@ -816,6 +841,8 @@ void registerTestVectorLowerings() {
PassRegistration<TestVectorGatherLowering>();
PassRegistration<TestFoldArithExtensionIntoVectorContractPatterns>();
+
+ PassRegistration<TestVectorEmulateMaskedLoadStore>();
}
} // namespace test
} // namespace mlir
More information about the Mlir-commits
mailing list