[Mlir-commits] [mlir] [mlir][vector][spirv] Lower vector.maskedload and vector.maskedstore to SPIR-V (PR #74834)

Hsiangkai Wang llvmlistbot at llvm.org
Tue Dec 12 11:35:16 PST 2023


https://github.com/Hsiangkai updated https://github.com/llvm/llvm-project/pull/74834

>From 04c1e28422f918e907f4dd6a92f78a04a5121f88 Mon Sep 17 00:00:00 2001
From: Hsiangkai Wang <hsiangkai.wang at arm.com>
Date: Thu, 30 Nov 2023 14:09:00 +0000
Subject: [PATCH] [mlir][vector] Add patterns for vector masked load/store

This patch converts

vector.maskedload %base[%idx_0, %idx_1], %mask, %pass_thru

into

%value = vector.load %base[%idx_0, %idx_1]
arith.select %mask, %value, %pass_thru

and it converts

vector.maskedstore %base[%idx_0, %idx_1], %mask, %value

into

scf.for %iv = %c0 to %mask_len step %c1 {
  %m = vector.extractelement %mask[%iv]
  scf.if %m {
    %v = vector.extractelement %value[%iv]
    memref.store %v, %base[%idx_0, %idx_1 + %iv]
  }
}
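
Downstream users apply these patterns with the standard greedy rewrite driver, the same way the test pass added below does. A minimal sketch (the helper name lowerMaskedLoadStore is made up for illustration):

#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Sketch only: lower vector.maskedload/maskedstore inside 'op'
// (e.g. a func.func) using the new patterns.
static mlir::LogicalResult lowerMaskedLoadStore(mlir::Operation *op) {
  mlir::RewritePatternSet patterns(op->getContext());
  mlir::vector::populateVectorMaskedLoadStoreLoweringPatterns(patterns);
  return mlir::applyPatternsAndFoldGreedily(op, std::move(patterns));
}

Running mlir-opt --test-vector-masked-load-store-lowering on the new test file exercises both patterns.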
---
 .../Vector/Transforms/LoweringPatterns.h      |  10 ++
 .../Dialect/Vector/Transforms/CMakeLists.txt  |   1 +
 .../Transforms/LowerVectorMaskedLoadStore.cpp | 122 ++++++++++++++++++
 .../vector-masked-load-store-lowering.mlir    |  58 +++++++++
 .../Dialect/Vector/TestVectorTransforms.cpp   |  28 ++++
 5 files changed, 219 insertions(+)
 create mode 100644 mlir/lib/Dialect/Vector/Transforms/LowerVectorMaskedLoadStore.cpp
 create mode 100644 mlir/test/Dialect/Vector/vector-masked-load-store-lowering.mlir

diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 292398a3dc5a7..3d4b1af55342f 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -254,6 +254,16 @@ void populateVectorGatherLoweringPatterns(RewritePatternSet &patterns,
 void populateVectorMaskLoweringPatternsForSideEffectingOps(
     RewritePatternSet &patterns);
 
+/// Populate the pattern set with the following patterns:
+///
+/// [VectorMaskedLoadOpConverter]
+/// Turns vector.maskedload into vector.load + arith.select
+///
+/// [VectorMaskedStoreOpConverter]
+/// Turns vector.maskedstore into scf.for + scf.if + memref.store
+void populateVectorMaskedLoadStoreLoweringPatterns(RewritePatternSet &patterns,
+                                                   PatternBenefit benefit = 1);
+
 } // namespace vector
 } // namespace mlir
 #endif // MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H
diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
index 513340096a5c1..a0457adaf7f53 100644
--- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -5,6 +5,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
   LowerVectorContract.cpp
   LowerVectorGather.cpp
   LowerVectorMask.cpp
+  LowerVectorMaskedLoadStore.cpp
   LowerVectorMultiReduction.cpp
   LowerVectorScan.cpp
   LowerVectorShapeCast.cpp
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMaskedLoadStore.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMaskedLoadStore.cpp
new file mode 100644
index 0000000000000..6735da921ab18
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMaskedLoadStore.cpp
@@ -0,0 +1,122 @@
+//===- LowerVectorMaskedLoadStore.cpp - Lower 'vector.maskedload/store' op ===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements target-independent rewrites and utilities to lower the
+// 'vector.maskedload' and 'vector.maskedstore' operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+
+#define DEBUG_TYPE "vector-masked-load-store-lowering"
+
+using namespace mlir;
+
+namespace {
+
+/// Convert vector.maskedload
+///
+/// Before:
+///
+///   vector.maskedload %base[%idx_0, %idx_1], %mask, %pass_thru
+///
+/// After:
+///
+///   %value = vector.load %base[%idx_0, %idx_1]
+///   arith.select %mask, %value, %pass_thru
+///
+struct VectorMaskedLoadOpConverter : OpRewritePattern<vector::MaskedLoadOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::MaskedLoadOp maskedLoadOp,
+                                PatternRewriter &rewriter) const override {
+    auto loc = maskedLoadOp.getLoc();
+    auto loadAll = rewriter.create<vector::LoadOp>(loc, maskedLoadOp.getType(),
+                                                   maskedLoadOp.getBase(),
+                                                   maskedLoadOp.getIndices());
+    auto selectedLoad = rewriter.create<arith::SelectOp>(
+        loc, maskedLoadOp.getMask(), loadAll, maskedLoadOp.getPassThru());
+    rewriter.replaceOp(maskedLoadOp, selectedLoad);
+
+    return success();
+  }
+};
+
+Value createConstantInteger(PatternRewriter &rewriter, Location loc,
+                            int32_t value) {
+  auto i32Type = rewriter.getI32Type();
+  return rewriter.create<arith::ConstantOp>(loc, i32Type,
+                                            IntegerAttr::get(i32Type, value));
+}
+
+/// Convert vector.maskedstore
+///
+/// Before:
+///
+///   vector.maskedstore %base[%idx_0, %idx_1], %mask, %value
+///
+/// After:
+///
+///   scf.for %iv = %c0 to %vector_len step %c1 {
+///     %m = vector.extractelement %mask[%iv]
+///     scf.if %m {
+///       %v = vector.extractelement %value[%iv]
+///       memref.store %v, %base[%idx_0, %idx_1 + %iv]
+///     }
+///   }
+///
+struct VectorMaskedStoreOpConverter : OpRewritePattern<vector::MaskedStoreOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::MaskedStoreOp maskedStoreOp,
+                                PatternRewriter &rewriter) const override {
+    VectorType maskVType = maskedStoreOp.getMaskVectorType();
+    if (maskVType.getRank() != 1)
+      return rewriter.notifyMatchFailure(
+          maskedStoreOp, "expected vector.maskedstore with 1-D mask");
+
+    auto loc = maskedStoreOp.getLoc();
+    Value zero = createConstantInteger(rewriter, loc, 0);
+    Value one = createConstantInteger(rewriter, loc, 1);
+    Value maskLength =
+        createConstantInteger(rewriter, loc, maskVType.getShape()[0]);
+
+    auto loopOp = rewriter.create<scf::ForOp>(loc, zero, maskLength, one);
+    rewriter.setInsertionPointToStart(loopOp.getBody());
+
+    auto indVar = loopOp.getInductionVar();
+    auto maskBit = rewriter.create<vector::ExtractElementOp>(
+        loc, maskedStoreOp.getMask(), indVar);
+    auto ifOp = rewriter.create<scf::IfOp>(loc, maskBit, /*else=*/false);
+    rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
+    auto value = rewriter.create<vector::ExtractElementOp>(
+        loc, maskedStoreOp.getValueToStore(), indVar);
+    SmallVector<Value> newIndices(maskedStoreOp.getIndices().begin(),
+                                  maskedStoreOp.getIndices().end());
+    indVar = rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(),
+                                                 indVar);
+    newIndices.back() =
+        rewriter.create<arith::AddIOp>(loc, newIndices.back(), indVar);
+    rewriter.create<memref::StoreOp>(loc, value, maskedStoreOp.getBase(),
+                                     newIndices);
+
+    rewriter.replaceOp(maskedStoreOp, loopOp);
+
+    return success();
+  }
+};
+
+} // namespace
+
+void mlir::vector::populateVectorMaskedLoadStoreLoweringPatterns(
+    RewritePatternSet &patterns, PatternBenefit benefit) {
+  patterns.add<VectorMaskedLoadOpConverter, VectorMaskedStoreOpConverter>(
+      patterns.getContext(), benefit);
+}
diff --git a/mlir/test/Dialect/Vector/vector-masked-load-store-lowering.mlir b/mlir/test/Dialect/Vector/vector-masked-load-store-lowering.mlir
new file mode 100644
index 0000000000000..c173fbad9f524
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-masked-load-store-lowering.mlir
@@ -0,0 +1,58 @@
+// RUN: mlir-opt %s --test-vector-masked-load-store-lowering | FileCheck %s
+
+module attributes {
+  spirv.target_env = #spirv.target_env<
+    #spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>
+  } {
+// CHECK-LABEL:  @vector_maskedload
+//  CHECK-SAME:  (%[[ARG0:.*]]: memref<4x5xf32, #spirv.storage_class<StorageBuffer>>) -> vector<4xf32> {
+//       CHECK:    %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
+//       CHECK:    %[[C0:.*]] = arith.constant 0 : index
+//       CHECK:    %[[C1:.*]] = arith.constant 1 : index
+//       CHECK:    %[[C4:.*]] = arith.constant 4 : index
+//       CHECK:    %[[S0:.*]] = vector.create_mask %[[C1]] : vector<4xi1>
+//       CHECK:    %[[S1:.*]] = vector.load %[[ARG0]][%[[C0]], %[[C4]]] : memref<4x5xf32, #spirv.storage_class<StorageBuffer>>, vector<4xf32>
+//       CHECK:    %[[S2:.*]] = arith.select %[[S0]], %[[S1]], %[[CST]] : vector<4xi1>, vector<4xf32>
+//       CHECK:    return %[[S2]] : vector<4xf32>
+//       CHECK:  }
+func.func @vector_maskedload(%arg0 : memref<4x5xf32, #spirv.storage_class<StorageBuffer>>) -> vector<4xf32> {
+  %idx_0 = arith.constant 0 : index
+  %idx_1 = arith.constant 1 : index
+  %idx_4 = arith.constant 4 : index
+  %mask = vector.create_mask %idx_1 : vector<4xi1>
+  %s = arith.constant 0.0 : f32
+  %pass_thru = vector.splat %s : vector<4xf32>
+  %0 = vector.maskedload %arg0[%idx_0, %idx_4], %mask, %pass_thru : memref<4x5xf32, #spirv.storage_class<StorageBuffer>>, vector<4xi1>, vector<4xf32> into vector<4xf32>
+  return %0: vector<4xf32>
+}
+
+// CHECK-LABEL:  @vector_maskedstore
+//  CHECK-SAME:  (%[[ARG0:.*]]: memref<4x5xf32, #spirv.storage_class<StorageBuffer>>, %[[ARG1:.*]]: vector<4xf32>) {
+//       CHECK:  %[[C0_I32:.*]] = arith.constant 0 : i32
+//       CHECK:  %[[C1_I32:.*]] = arith.constant 1 : i32
+//       CHECK:  %[[C4_I32:.*]] = arith.constant 4 : i32
+//       CHECK:  %[[C0:.*]] = arith.constant 0 : index
+//       CHECK:  %[[C1:.*]] = arith.constant 1 : index
+//       CHECK:  %[[C4:.*]] = arith.constant 4 : index
+//       CHECK:  %[[S0:.*]] = vector.create_mask %[[C1]] : vector<4xi1>
+//       CHECK:  scf.for %[[ARG2:.*]] = %[[C0_I32]] to %[[C4_I32]] step %[[C1_I32]]  : i32 {
+//       CHECK:    %[[S1:.*]] = vector.extractelement %[[S0]][%[[ARG2]] : i32] : vector<4xi1>
+//       CHECK:    scf.if %[[S1]] {
+//       CHECK:      %[[S2:.*]] = vector.extractelement %[[ARG1]][%[[ARG2]] : i32] : vector<4xf32>
+//       CHECK:      %[[S3:.*]] = arith.index_cast %[[ARG2]] : i32 to index
+//       CHECK:      %[[S4:.*]] = arith.addi %[[S3]], %[[C4]] : index
+//       CHECK:      memref.store %[[S2]], %[[ARG0]][%[[C0]], %[[S4]]] : memref<4x5xf32, #spirv.storage_class<StorageBuffer>>
+//       CHECK:    }
+//       CHECK:  }
+//       CHECK:  return
+//       CHECK:  }
+func.func @vector_maskedstore(%arg0 : memref<4x5xf32, #spirv.storage_class<StorageBuffer>>, %arg1 : vector<4xf32>) {
+  %idx_0 = arith.constant 0 : index
+  %idx_1 = arith.constant 1 : index
+  %idx_4 = arith.constant 4 : index
+  %mask = vector.create_mask %idx_1 : vector<4xi1>
+  vector.maskedstore %arg0[%idx_0, %idx_4], %mask, %arg1 : memref<4x5xf32, #spirv.storage_class<StorageBuffer>>, vector<4xi1>, vector<4xf32>
+  return
+}
+
+} // end module
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index e593c0defcd29..bfb3bda79e57e 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -776,6 +776,32 @@ struct TestFoldArithExtensionIntoVectorContractPatterns
     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
   }
 };
+
+struct TestVectorMaskedLoadStoreLowering
+    : public PassWrapper<TestVectorMaskedLoadStoreLowering,
+                         OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+      TestVectorMaskedLoadStoreLowering)
+
+  StringRef getArgument() const final {
+    return "test-vector-masked-load-store-lowering";
+  }
+  StringRef getDescription() const final {
+    return "Test patterns that lower the maskedload/maskedstore op to "
+           " vector.load/store op";
+  }
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry
+        .insert<arith::ArithDialect, func::FuncDialect, memref::MemRefDialect,
+                scf::SCFDialect, vector::VectorDialect>();
+  }
+
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    populateVectorMaskedLoadStoreLoweringPatterns(patterns);
+    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+  }
+};
 } // namespace
 
 namespace mlir {
@@ -816,6 +842,8 @@ void registerTestVectorLowerings() {
   PassRegistration<TestVectorGatherLowering>();
 
   PassRegistration<TestFoldArithExtensionIntoVectorContractPatterns>();
+
+  PassRegistration<TestVectorMaskedLoadStoreLowering>();
 }
 } // namespace test
 } // namespace mlir


