[Mlir-commits] [mlir] [mlir][vector] Add emulation patterns for vector masked load/store (PR #74834)

Hsiangkai Wang llvmlistbot at llvm.org
Fri Dec 15 02:16:02 PST 2023


https://github.com/Hsiangkai updated https://github.com/llvm/llvm-project/pull/74834

>From 550ce29f0a9b88cf94fe1f4614ee6806866e7a02 Mon Sep 17 00:00:00 2001
From: Hsiangkai Wang <hsiangkai.wang at arm.com>
Date: Thu, 30 Nov 2023 14:09:00 +0000
Subject: [PATCH] [mlir][vector] Add emulation patterns for vector masked
 load/store

This patch converts

vector.maskedload %base[%idx_0, %idx_1], %mask, %pass_thru

to

%ivalue = %pass_thru
%m = vector.extract %mask[0]
%result0 = scf.if %m {
  %v = memref.load %base[%idx_0, %idx_1]
  %combined = vector.insert %v, %ivalue[0]
  scf.yield %combined
} else {
  scf.yield %ivalue
}
%m = vector.extract %mask[1]
%result1 = scf.if %m {
  %v = memref.load %base[%idx_0, %idx_1 + 1]
  %combined = vector.insert %v, %result0[1]
  scf.yield %combined
} else {
  scf.yield %result0
}
...

It also converts

vector.maskedstore %base[%idx_0, %idx_1], %mask, %value

to

%m = vector.extract %mask[0]
scf.if %m {
  %extracted = vector.extract %value[0]
  memref.store %extracted, %base[%idx_0, %idx_1]
}
%m = vector.extract %mask[1]
scf.if %m {
  %extracted = vector.extract %value[1]
  memref.store %extracted, %base[%idx_0, %idx_1 + 1]
}
...
---
 .../Vector/Transforms/LoweringPatterns.h      |  10 ++
 .../Dialect/Vector/Transforms/CMakeLists.txt  |   1 +
 .../VectorEmulateMaskedLoadStore.cpp          | 161 ++++++++++++++++++
 .../vector-emulate-masked-load-store.mlir     |  95 +++++++++++
 .../Dialect/Vector/TestVectorTransforms.cpp   |  27 +++
 5 files changed, 294 insertions(+)
 create mode 100644 mlir/lib/Dialect/Vector/Transforms/VectorEmulateMaskedLoadStore.cpp
 create mode 100644 mlir/test/Dialect/Vector/vector-emulate-masked-load-store.mlir

diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 292398a3dc5a7d..57b39f5f52c6d3 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -254,6 +254,16 @@ void populateVectorGatherLoweringPatterns(RewritePatternSet &patterns,
 void populateVectorMaskLoweringPatternsForSideEffectingOps(
     RewritePatternSet &patterns);
 
+/// Populate the pattern set with the following patterns:
+///
+/// [VectorMaskedLoadOpConverter]
+/// Turns vector.maskedload to scf.if + memref.load
+///
+/// [VectorMaskedStoreOpConverter]
+/// Turns vector.maskedstore to scf.if + memref.store
+void populateVectorMaskedLoadStoreEmulationPatterns(RewritePatternSet &patterns,
+                                                    PatternBenefit benefit = 1);
+
 } // namespace vector
 } // namespace mlir
 #endif // MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H
diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
index 513340096a5c1f..daf28882976ef6 100644
--- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -13,6 +13,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
   SubsetOpInterfaceImpl.cpp
   VectorDistribute.cpp
   VectorDropLeadUnitDim.cpp
+  VectorEmulateMaskedLoadStore.cpp
   VectorEmulateNarrowType.cpp
   VectorInsertExtractStridedSliceRewritePatterns.cpp
   VectorTransferOpTransforms.cpp
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateMaskedLoadStore.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateMaskedLoadStore.cpp
new file mode 100644
index 00000000000000..8cc7008d80b3ed
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateMaskedLoadStore.cpp
@@ -0,0 +1,161 @@
+//=- VectorEmulateMaskedLoadStore.cpp - Emulate 'vector.maskedload/store' op =//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements target-independent rewrites and utilities to emulate the
+// 'vector.maskedload' and 'vector.maskedstore' operation.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+
+using namespace mlir;
+
+namespace {
+
+/// Convert vector.maskedload
+///
+/// Before:
+///
+///   vector.maskedload %base[%idx_0, %idx_1], %mask, %pass_thru
+///
+/// After:
+///
+///   %ivalue = %pass_thru
+///   %m = vector.extract %mask[0]
+///   %result0 = scf.if %m {
+///     %v = memref.load %base[%idx_0, %idx_1]
+///     %combined = vector.insert %v, %ivalue[0]
+///     scf.yield %combined
+///   } else {
+///     scf.yield %ivalue
+///   }
+///   %m = vector.extract %mask[1]
+///   %result1 = scf.if %m {
+///     %v = memref.load %base[%idx_0, %idx_1 + 1]
+///     %combined = vector.insert %v, %result0[1]
+///     scf.yield %combined
+///   } else {
+///     scf.yield %result0
+///   }
+///   ...
+///
+struct VectorMaskedLoadOpConverter final
+    : OpRewritePattern<vector::MaskedLoadOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::MaskedLoadOp maskedLoadOp,
+                                PatternRewriter &rewriter) const override {
+    VectorType maskVType = maskedLoadOp.getMaskVectorType();
+    if (maskVType.getShape().size() != 1)
+      return rewriter.notifyMatchFailure(
+          maskedLoadOp, "expected vector.maskedload with 1-D mask");
+
+    Location loc = maskedLoadOp.getLoc();
+    int64_t maskLength = maskVType.getShape()[0];
+
+    Type indexType = rewriter.getIndexType();
+    Value mask = maskedLoadOp.getMask();
+    Value base = maskedLoadOp.getBase();
+    Value iValue = maskedLoadOp.getPassThru(); // running result, seeded with pass-thru
+    auto indices = llvm::to_vector_of<Value>(maskedLoadOp.getIndices());
+    Value one = rewriter.create<arith::ConstantOp>(
+        loc, indexType, IntegerAttr::get(indexType, 1));
+    for (int64_t i = 0; i < maskLength; ++i) {
+      auto maskBit = rewriter.create<vector::ExtractOp>(loc, mask, i);
+
+      auto ifOp = rewriter.create<scf::IfOp>(
+          loc, maskBit,
+          [&](OpBuilder &builder, Location loc) {
+            auto loadedValue =
+                builder.create<memref::LoadOp>(loc, base, indices);
+            auto combinedValue =
+                builder.create<vector::InsertOp>(loc, loadedValue, iValue, i);
+            builder.create<scf::YieldOp>(loc, combinedValue.getResult());
+          },
+          [&](OpBuilder &builder, Location loc) {
+            builder.create<scf::YieldOp>(loc, iValue); // mask off: keep previous value
+          });
+      iValue = ifOp.getResult(0); // thread the result into the next iteration
+
+      indices.back() = rewriter.create<arith::AddIOp>(loc, indices.back(), one);
+    }
+
+    rewriter.replaceOp(maskedLoadOp, iValue);
+
+    return success();
+  }
+};
+
+/// Convert vector.maskedstore
+///
+/// Before:
+///
+///   vector.maskedstore %base[%idx_0, %idx_1], %mask, %value
+///
+/// After:
+///
+///   %m = vector.extract %mask[0]
+///   scf.if %m {
+///     %extracted = vector.extract %value[0]
+///     memref.store %extracted, %base[%idx_0, %idx_1]
+///   }
+///   %m = vector.extract %mask[1]
+///   scf.if %m {
+///     %extracted = vector.extract %value[1]
+///     memref.store %extracted, %base[%idx_0, %idx_1 + 1]
+///   }
+///   ...
+///
+struct VectorMaskedStoreOpConverter final
+    : OpRewritePattern<vector::MaskedStoreOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::MaskedStoreOp maskedStoreOp,
+                                PatternRewriter &rewriter) const override {
+    VectorType maskVType = maskedStoreOp.getMaskVectorType();
+    if (maskVType.getShape().size() != 1)
+      return rewriter.notifyMatchFailure(
+          maskedStoreOp, "expected vector.maskedstore with 1-D mask");
+
+    Location loc = maskedStoreOp.getLoc();
+    int64_t maskLength = maskVType.getShape()[0];
+
+    Type indexType = rewriter.getIndexType();
+    Value mask = maskedStoreOp.getMask();
+    Value base = maskedStoreOp.getBase();
+    Value value = maskedStoreOp.getValueToStore();
+    auto indices = llvm::to_vector_of<Value>(maskedStoreOp.getIndices());
+    Value one = rewriter.create<arith::ConstantOp>(
+        loc, indexType, IntegerAttr::get(indexType, 1));
+    for (int64_t i = 0; i < maskLength; ++i) {
+      auto maskBit = rewriter.create<vector::ExtractOp>(loc, mask, i); // i-th mask bit
+
+      auto ifOp = rewriter.create<scf::IfOp>(loc, maskBit, /*else=*/false);
+      rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
+      auto extractedValue = rewriter.create<vector::ExtractOp>(loc, value, i);
+      rewriter.create<memref::StoreOp>(loc, extractedValue, base, indices); // guarded scalar store
+
+      rewriter.setInsertionPointAfter(ifOp);
+      indices.back() = rewriter.create<arith::AddIOp>(loc, indices.back(), one); // advance innermost index
+    }
+
+    rewriter.eraseOp(maskedStoreOp); // store has no results to replace
+
+    return success();
+  }
+};
+
+} // namespace
+
+void mlir::vector::populateVectorMaskedLoadStoreEmulationPatterns(
+    RewritePatternSet &patterns, PatternBenefit benefit) {
+  patterns.add<VectorMaskedLoadOpConverter, VectorMaskedStoreOpConverter>(
+      patterns.getContext(), benefit); // registers both load and store emulation
+}
diff --git a/mlir/test/Dialect/Vector/vector-emulate-masked-load-store.mlir b/mlir/test/Dialect/Vector/vector-emulate-masked-load-store.mlir
new file mode 100644
index 00000000000000..3867f075af8e4b
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-emulate-masked-load-store.mlir
@@ -0,0 +1,95 @@
+// RUN: mlir-opt %s --test-vector-emulate-masked-load-store | FileCheck %s
+
+// CHECK-LABEL:  @vector_maskedload
+//  CHECK-SAME:  (%[[ARG0:.*]]: memref<4x5xf32>) -> vector<4xf32> {
+//   CHECK-DAG:  %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
+//   CHECK-DAG:  %[[C7:.*]] = arith.constant 7 : index
+//   CHECK-DAG:  %[[C6:.*]] = arith.constant 6 : index
+//   CHECK-DAG:  %[[C5:.*]] = arith.constant 5 : index
+//   CHECK-DAG:  %[[C0:.*]] = arith.constant 0 : index
+//   CHECK-DAG:  %[[C1:.*]] = arith.constant 1 : index
+//   CHECK-DAG:  %[[C4:.*]] = arith.constant 4 : index
+//   CHECK-DAG:  %[[S0:.*]] = vector.create_mask %[[C1]] : vector<4xi1>
+//       CHECK:  %[[S1:.*]] = vector.extract %[[S0]][0] : i1 from vector<4xi1>
+//       CHECK:  %[[S2:.*]] = scf.if %[[S1]] -> (vector<4xf32>) {
+//       CHECK:    %[[S9:.*]] = memref.load %[[ARG0]][%[[C0]], %[[C4]]] : memref<4x5xf32>
+//       CHECK:    %[[S10:.*]] = vector.insert %[[S9]], %[[CST]] [0] : f32 into vector<4xf32>
+//       CHECK:    scf.yield %[[S10]] : vector<4xf32>
+//       CHECK:  } else {
+//       CHECK:    scf.yield %[[CST]] : vector<4xf32>
+//       CHECK:  }
+//       CHECK:  %[[S3:.*]] = vector.extract %[[S0]][1] : i1 from vector<4xi1>
+//       CHECK:  %[[S4:.*]] = scf.if %[[S3]] -> (vector<4xf32>) {
+//       CHECK:    %[[S9:.*]] = memref.load %[[ARG0]][%[[C0]], %[[C5]]] : memref<4x5xf32>
+//       CHECK:    %[[S10:.*]] = vector.insert %[[S9]], %[[S2]] [1] : f32 into vector<4xf32>
+//       CHECK:    scf.yield %[[S10]] : vector<4xf32>
+//       CHECK:  } else {
+//       CHECK:    scf.yield %[[S2]] : vector<4xf32>
+//       CHECK:  }
+//       CHECK:  %[[S5:.*]] = vector.extract %[[S0]][2] : i1 from vector<4xi1>
+//       CHECK:  %[[S6:.*]] = scf.if %[[S5]] -> (vector<4xf32>) {
+//       CHECK:    %[[S9:.*]] = memref.load %[[ARG0]][%[[C0]], %[[C6]]] : memref<4x5xf32>
+//       CHECK:    %[[S10:.*]] = vector.insert %[[S9]], %[[S4]] [2] : f32 into vector<4xf32>
+//       CHECK:    scf.yield %[[S10]] : vector<4xf32>
+//       CHECK:  } else {
+//       CHECK:    scf.yield %[[S4]] : vector<4xf32>
+//       CHECK:  }
+//       CHECK:  %[[S7:.*]] = vector.extract %[[S0]][3] : i1 from vector<4xi1>
+//       CHECK:  %[[S8:.*]] = scf.if %[[S7]] -> (vector<4xf32>) {
+//       CHECK:    %[[S9:.*]] = memref.load %[[ARG0]][%[[C0]], %[[C7]]] : memref<4x5xf32>
+//       CHECK:    %[[S10:.*]] = vector.insert %[[S9]], %[[S6]] [3] : f32 into vector<4xf32>
+//       CHECK:    scf.yield %[[S10]] : vector<4xf32>
+//       CHECK:  } else {
+//       CHECK:    scf.yield %[[S6]] : vector<4xf32>
+//       CHECK:  }
+//       CHECK:  return %[[S8]] : vector<4xf32>
func.func @vector_maskedload(%arg0 : memref<4x5xf32>) -> vector<4xf32> { // exercises 1-D maskedload emulation
+  %idx_0 = arith.constant 0 : index
+  %idx_1 = arith.constant 1 : index
+  %idx_4 = arith.constant 4 : index
+  %mask = vector.create_mask %idx_1 : vector<4xi1>
+  %s = arith.constant 0.0 : f32
+  %pass_thru = vector.splat %s : vector<4xf32>
+  %0 = vector.maskedload %arg0[%idx_0, %idx_4], %mask, %pass_thru : memref<4x5xf32>, vector<4xi1>, vector<4xf32> into vector<4xf32>
+  return %0: vector<4xf32>
+}
+
+// CHECK-LABEL:  @vector_maskedstore
+//  CHECK-SAME:  (%[[ARG0:.*]]: memref<4x5xf32>, %[[ARG1:.*]]: vector<4xf32>) {
+//   CHECK-DAG:  %[[C7:.*]] = arith.constant 7 : index
+//   CHECK-DAG:  %[[C6:.*]] = arith.constant 6 : index
+//   CHECK-DAG:  %[[C5:.*]] = arith.constant 5 : index
+//   CHECK-DAG:  %[[C0:.*]] = arith.constant 0 : index
+//   CHECK-DAG:  %[[C1:.*]] = arith.constant 1 : index
+//   CHECK-DAG:  %[[C4:.*]] = arith.constant 4 : index
+//   CHECK-DAG:  %[[S0:.*]] = vector.create_mask %[[C1]] : vector<4xi1>
+//       CHECK:  %[[S1:.*]] = vector.extract %[[S0]][0] : i1 from vector<4xi1>
+//       CHECK:  scf.if %[[S1]] {
+//       CHECK:    %[[S5:.*]] = vector.extract %[[ARG1]][0] : f32 from vector<4xf32>
+//       CHECK:    memref.store %[[S5]], %[[ARG0]][%[[C0]], %[[C4]]] : memref<4x5xf32>
+//       CHECK:  }
+//       CHECK:  %[[S2:.*]] = vector.extract %[[S0]][1] : i1 from vector<4xi1>
+//       CHECK:  scf.if %[[S2]] {
+//       CHECK:    %[[S5:.*]] = vector.extract %[[ARG1]][1] : f32 from vector<4xf32>
+//       CHECK:    memref.store %[[S5]], %[[ARG0]][%[[C0]], %[[C5]]] : memref<4x5xf32>
+//       CHECK:  }
+//       CHECK:  %[[S3:.*]] = vector.extract %[[S0]][2] : i1 from vector<4xi1>
+//       CHECK:  scf.if %[[S3]] {
+//       CHECK:    %[[S5:.*]] = vector.extract %[[ARG1]][2] : f32 from vector<4xf32>
+//       CHECK:    memref.store %[[S5]], %[[ARG0]][%[[C0]], %[[C6]]] : memref<4x5xf32>
+//       CHECK:  }
+//       CHECK:  %[[S4:.*]] = vector.extract %[[S0]][3] : i1 from vector<4xi1>
+//       CHECK:  scf.if %[[S4]] {
+//       CHECK:    %[[S5:.*]] = vector.extract %[[ARG1]][3] : f32 from vector<4xf32>
+//       CHECK:    memref.store %[[S5]], %[[ARG0]][%[[C0]], %[[C7]]] : memref<4x5xf32>
+//       CHECK:  }
+//       CHECK:  return
+//       CHECK:}
func.func @vector_maskedstore(%arg0 : memref<4x5xf32>, %arg1 : vector<4xf32>) { // exercises 1-D maskedstore emulation
+  %idx_0 = arith.constant 0 : index
+  %idx_1 = arith.constant 1 : index
+  %idx_4 = arith.constant 4 : index
+  %mask = vector.create_mask %idx_1 : vector<4xi1>
+  vector.maskedstore %arg0[%idx_0, %idx_4], %mask, %arg1 : memref<4x5xf32>, vector<4xi1>, vector<4xf32>
+  return
+}
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index a643343e9342ad..03ddebe82344d8 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -777,6 +777,31 @@ struct TestFoldArithExtensionIntoVectorContractPatterns
     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
   }
 };
+
+struct TestVectorEmulateMaskedLoadStore final
+    : public PassWrapper<TestVectorEmulateMaskedLoadStore,
+                         OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorEmulateMaskedLoadStore)
+
+  StringRef getArgument() const override {
+    return "test-vector-emulate-masked-load-store";
+  }
+  StringRef getDescription() const override {
+    return "Test patterns that emulate the maskedload/maskedstore op by "
+           "memref.load/store and scf.if";
+  }
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry
+        .insert<arith::ArithDialect, func::FuncDialect, memref::MemRefDialect,
+                scf::SCFDialect, vector::VectorDialect>();
+  }
+
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    populateVectorMaskedLoadStoreEmulationPatterns(patterns);
+    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+  }
+};
 } // namespace
 
 namespace mlir {
@@ -817,6 +842,8 @@ void registerTestVectorLowerings() {
   PassRegistration<TestVectorGatherLowering>();
 
   PassRegistration<TestFoldArithExtensionIntoVectorContractPatterns>();
+
+  PassRegistration<TestVectorEmulateMaskedLoadStore>();
 }
 } // namespace test
 } // namespace mlir



More information about the Mlir-commits mailing list