[Mlir-commits] [mlir] [mlir][vector] Deal with special patterns when emulating masked load/store (PR #75587)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Dec 15 04:09:16 PST 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-vector

Author: Hsiangkai Wang (Hsiangkai)

<details>
<summary>Changes</summary>

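We can simplify the generated code when the mask is created by `vector.constant_mask`, or by `vector.create_mask` with a constant operand. Such a 1-D mask is a run of leading ones followed by zeros, e.g. [1, 1, 1, 0, ...], so the number of active lanes is known statically. In that case we can use vector.load + vector.insert_strided_slice to emulate maskedload, and vector.extract_strided_slice + vector.store to emulate maskedstore, instead of emitting one scf.if per lane.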

---

Patch is 22.50 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/75587.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Vector/Transforms/VectorEmulateMaskedLoadStore.cpp (+137-95) 
- (modified) mlir/test/Dialect/Vector/vector-emulate-masked-load-store.mlir (+122-19) 


``````````diff
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateMaskedLoadStore.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateMaskedLoadStore.cpp
index 8cc7008d80b3ed..c61633d5ff8aeb 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateMaskedLoadStore.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateMaskedLoadStore.cpp
@@ -19,33 +19,22 @@ using namespace mlir;
 
 namespace {
 
+/// Returns the statically known number of leading set lanes of `mask`, or
+/// std::nullopt when the fast path must not be used.
+///
+/// A length is only reported when `mask` is a fixed-size 1-D vector defined by
+/// either
+///   * vector.constant_mask, whose single dim size is the lane count, or
+///   * vector.create_mask whose single operand is an arith.constant.
+///
+/// The returned value is always in the range [1, number of mask lanes]:
+///   * create_mask sets the lanes in [0, operand), so a constant larger than
+///     the mask is clamped to the mask length; otherwise callers would build a
+///     slice wider than the pass-through / stored vector.
+///   * A non-positive length (all-false mask) yields std::nullopt, because the
+///     callers would otherwise create a zero-length (or negatively sized)
+///     vector type, which is not a valid VectorType. The generic per-lane
+///     lowering handles that case correctly.
+std::optional<int64_t> getMaskLeadingOneLength(Value mask) {
+  // Bail out gracefully instead of asserting: asserts vanish in release builds
+  // and would let a non-1-D mask silently use only its first dimension.
+  auto maskType = llvm::dyn_cast<VectorType>(mask.getType());
+  if (!maskType || maskType.getRank() != 1)
+    return std::nullopt;
+  // The lane count of a scalable vector is not known statically, so a constant
+  // bound cannot be clamped against it; keep the generic lowering.
+  if (maskType.isScalable())
+    return std::nullopt;
+
+  std::optional<int64_t> length;
+  if (auto constantMask = mask.getDefiningOp<vector::ConstantMaskOp>()) {
+    ArrayAttr dimSizes = constantMask.getMaskDimSizes();
+    length = llvm::cast<IntegerAttr>(dimSizes[0]).getInt();
+  } else if (auto createMask = mask.getDefiningOp<vector::CreateMaskOp>()) {
+    Value bound = createMask.getOperands()[0];
+    if (auto constantOp = bound.getDefiningOp<arith::ConstantOp>())
+      length = llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
+  }
+
+  // Unknown producer, dynamic bound, or an all-false mask.
+  if (!length || *length <= 0)
+    return std::nullopt;
+
+  // Lanes past the end of the mask do not exist; clamp to the mask length.
+  int64_t numLanes = maskType.getShape()[0];
+  return *length < numLanes ? *length : numLanes;
+}
+
 /// Convert vector.maskedload
-///
-/// Before:
-///
-///   vector.maskedload %base[%idx_0, %idx_1], %mask, %pass_thru
-///
-/// After:
-///
-///   %ivalue = %pass_thru
-///   %m = vector.extract %mask[0]
-///   %result0 = scf.if %m {
-///     %v = memref.load %base[%idx_0, %idx_1]
-///     %combined = vector.insert %v, %ivalue[0]
-///     scf.yield %combined
-///   } else {
-///     scf.yield %ivalue
-///   }
-///   %m = vector.extract %mask[1]
-///   %result1 = scf.if %m {
-///     %v = memref.load %base[%idx_0, %idx_1 + 1]
-///     %combined = vector.insert %v, %result0[1]
-///     scf.yield %combined
-///   } else {
-///     scf.yield %result0
-///   }
-///   ...
-///
 struct VectorMaskedLoadOpConverter final
     : OpRewritePattern<vector::MaskedLoadOp> {
   using OpRewritePattern::OpRewritePattern;
@@ -58,61 +47,82 @@ struct VectorMaskedLoadOpConverter final
           maskedLoadOp, "expected vector.maskedstore with 1-D mask");
 
     Location loc = maskedLoadOp.getLoc();
-    int64_t maskLength = maskVType.getShape()[0];
-
-    Type indexType = rewriter.getIndexType();
-    Value mask = maskedLoadOp.getMask();
-    Value base = maskedLoadOp.getBase();
-    Value iValue = maskedLoadOp.getPassThru();
-    auto indices = llvm::to_vector_of<Value>(maskedLoadOp.getIndices());
-    Value one = rewriter.create<arith::ConstantOp>(
-        loc, indexType, IntegerAttr::get(indexType, 1));
-    for (int64_t i = 0; i < maskLength; ++i) {
-      auto maskBit = rewriter.create<vector::ExtractOp>(loc, mask, i);
-
-      auto ifOp = rewriter.create<scf::IfOp>(
-          loc, maskBit,
-          [&](OpBuilder &builder, Location loc) {
-            auto loadedValue =
-                builder.create<memref::LoadOp>(loc, base, indices);
-            auto combinedValue =
-                builder.create<vector::InsertOp>(loc, loadedValue, iValue, i);
-            builder.create<scf::YieldOp>(loc, combinedValue.getResult());
-          },
-          [&](OpBuilder &builder, Location loc) {
-            builder.create<scf::YieldOp>(loc, iValue);
-          });
-      iValue = ifOp.getResult(0);
-
-      indices.back() = rewriter.create<arith::AddIOp>(loc, indices.back(), one);
+    std::optional<int64_t> ones =
+        getMaskLeadingOneLength(maskedLoadOp.getMask());
+    if (ones) {
+      /// Converted to:
+      ///
+      ///   %partial = vector.load %base[%idx_0, %idx_1]
+      ///   %value = vector.insert_strided_slice %partial, %pass_thru
+      ///            {offsets = [0], strides = [1]}
+      Type vectorType =
+          VectorType::get(*ones, maskedLoadOp.getMemRefType().getElementType());
+      auto loadOp = rewriter.create<vector::LoadOp>(
+          loc, vectorType, maskedLoadOp.getBase(), maskedLoadOp.getIndices());
+      auto insertedValue = rewriter.create<vector::InsertStridedSliceOp>(
+          loc, loadOp, maskedLoadOp.getPassThru(),
+          /*offsets=*/ArrayRef<int64_t>({0}),
+          /*strides=*/ArrayRef<int64_t>({1}));
+      rewriter.replaceOp(maskedLoadOp, insertedValue);
+    } else {
+      /// Converted to:
+      ///
+      ///   %ivalue = %pass_thru
+      ///   %m = vector.extract %mask[0]
+      ///   %result0 = scf.if %m {
+      ///     %v = memref.load %base[%idx_0, %idx_1]
+      ///     %combined = vector.insert %v, %ivalue[0]
+      ///     scf.yield %combined
+      ///   } else {
+      ///     scf.yield %ivalue
+      ///   }
+      ///   %m = vector.extract %mask[1]
+      ///   %result1 = scf.if %m {
+      ///     %v = memref.load %base[%idx_0, %idx_1 + 1]
+      ///     %combined = vector.insert %v, %result0[1]
+      ///     scf.yield %combined
+      ///   } else {
+      ///     scf.yield %result0
+      ///   }
+      ///   ...
+      ///
+      int64_t maskLength = maskVType.getShape()[0];
+
+      Type indexType = rewriter.getIndexType();
+      Value mask = maskedLoadOp.getMask();
+      Value base = maskedLoadOp.getBase();
+      Value iValue = maskedLoadOp.getPassThru();
+      auto indices = llvm::to_vector_of<Value>(maskedLoadOp.getIndices());
+      Value one = rewriter.create<arith::ConstantOp>(
+          loc, indexType, IntegerAttr::get(indexType, 1));
+      for (int64_t i = 0; i < maskLength; ++i) {
+        auto maskBit = rewriter.create<vector::ExtractOp>(loc, mask, i);
+
+        auto ifOp = rewriter.create<scf::IfOp>(
+            loc, maskBit,
+            [&](OpBuilder &builder, Location loc) {
+              auto loadedValue =
+                  builder.create<memref::LoadOp>(loc, base, indices);
+              auto combinedValue =
+                  builder.create<vector::InsertOp>(loc, loadedValue, iValue, i);
+              builder.create<scf::YieldOp>(loc, combinedValue.getResult());
+            },
+            [&](OpBuilder &builder, Location loc) {
+              builder.create<scf::YieldOp>(loc, iValue);
+            });
+        iValue = ifOp.getResult(0);
+
+        indices.back() = rewriter.create<arith::AddIOp>(loc, indices.back(), one);
+      }
+
+      rewriter.replaceOp(maskedLoadOp, iValue);
     }
 
-    rewriter.replaceOp(maskedLoadOp, iValue);
-
     return success();
   }
 };
 
 /// Convert vector.maskedstore
-///
-/// Before:
-///
-///   vector.maskedstore %base[%idx_0, %idx_1], %mask, %value
-///
-/// After:
-///
-///   %m = vector.extract %mask[0]
-///   scf.if %m {
-///     %extracted = vector.extract %value[0]
-///     memref.store %extracted, %base[%idx_0, %idx_1]
-///   }
-///   %m = vector.extract %mask[1]
-///   scf.if %m {
-///     %extracted = vector.extract %value[1]
-///     memref.store %extracted, %base[%idx_0, %idx_1 + 1]
-///   }
-///   ...
-///
 struct VectorMaskedStoreOpConverter final
     : OpRewritePattern<vector::MaskedStoreOp> {
   using OpRewritePattern::OpRewritePattern;
@@ -125,29 +135,61 @@ struct VectorMaskedStoreOpConverter final
           maskedStoreOp, "expected vector.maskedstore with 1-D mask");
 
     Location loc = maskedStoreOp.getLoc();
-    int64_t maskLength = maskVType.getShape()[0];
-
-    Type indexType = rewriter.getIndexType();
-    Value mask = maskedStoreOp.getMask();
-    Value base = maskedStoreOp.getBase();
-    Value value = maskedStoreOp.getValueToStore();
-    auto indices = llvm::to_vector_of<Value>(maskedStoreOp.getIndices());
-    Value one = rewriter.create<arith::ConstantOp>(
-        loc, indexType, IntegerAttr::get(indexType, 1));
-    for (int64_t i = 0; i < maskLength; ++i) {
-      auto maskBit = rewriter.create<vector::ExtractOp>(loc, mask, i);
-
-      auto ifOp = rewriter.create<scf::IfOp>(loc, maskBit, /*else=*/false);
-      rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
-      auto extractedValue = rewriter.create<vector::ExtractOp>(loc, value, i);
-      rewriter.create<memref::StoreOp>(loc, extractedValue, base, indices);
-
-      rewriter.setInsertionPointAfter(ifOp);
-      indices.back() = rewriter.create<arith::AddIOp>(loc, indices.back(), one);
+    std::optional<int64_t> ones =
+      getMaskLeadingOneLength(maskedStoreOp.getMask());
+    if (ones) {
+      /// Converted to:
+      ///
+      ///   %partial = vector.extract_strided_slice %value
+      ///              {offsets = [0], sizes = [1], strides = [1]}
+      ///   vector.store %partial, %base[%idx_0, %idx_1]
+      Value value = maskedStoreOp.getValueToStore();
+      Value extractedValue = rewriter.create<vector::ExtractStridedSliceOp>(
+	loc, value, /*offsets=*/ArrayRef<int64_t>({0}),
+	/*sizes=*/ArrayRef<int64_t>(*ones),
+	/*strides=*/ArrayRef<int64_t>({1}));
+      auto storeOp = rewriter.create<vector::StoreOp>(
+	loc, extractedValue, maskedStoreOp.getBase(),
+	maskedStoreOp.getIndices());
+      rewriter.replaceOp(maskedStoreOp, storeOp);
+    } else {
+      /// Converted to:
+      ///
+      ///   %m = vector.extract %mask[0]
+      ///   scf.if %m {
+      ///     %extracted = vector.extract %value[0]
+      ///     memref.store %extracted, %base[%idx_0, %idx_1]
+      ///   }
+      ///   %m = vector.extract %mask[1]
+      ///   scf.if %m {
+      ///     %extracted = vector.extract %value[1]
+      ///     memref.store %extracted, %base[%idx_0, %idx_1 + 1]
+      ///   }
+      ///   ...
+      int64_t maskLength = maskVType.getShape()[0];
+
+      Type indexType = rewriter.getIndexType();
+      Value mask = maskedStoreOp.getMask();
+      Value base = maskedStoreOp.getBase();
+      Value value = maskedStoreOp.getValueToStore();
+      auto indices = llvm::to_vector_of<Value>(maskedStoreOp.getIndices());
+      Value one = rewriter.create<arith::ConstantOp>(
+          loc, indexType, IntegerAttr::get(indexType, 1));
+      for (int64_t i = 0; i < maskLength; ++i) {
+        auto maskBit = rewriter.create<vector::ExtractOp>(loc, mask, i);
+
+        auto ifOp = rewriter.create<scf::IfOp>(loc, maskBit, /*else=*/false);
+        rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
+        auto extractedValue = rewriter.create<vector::ExtractOp>(loc, value, i);
+        rewriter.create<memref::StoreOp>(loc, extractedValue, base, indices);
+
+        rewriter.setInsertionPointAfter(ifOp);
+        indices.back() = rewriter.create<arith::AddIOp>(loc, indices.back(), one);
+      }
+
+      rewriter.eraseOp(maskedStoreOp);
     }
 
-    rewriter.eraseOp(maskedStoreOp);
-
     return success();
   }
 };
diff --git a/mlir/test/Dialect/Vector/vector-emulate-masked-load-store.mlir b/mlir/test/Dialect/Vector/vector-emulate-masked-load-store.mlir
index 3867f075af8e4b..473f28d0b901b4 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-masked-load-store.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-masked-load-store.mlir
@@ -1,16 +1,14 @@
 // RUN: mlir-opt %s --test-vector-emulate-masked-load-store | FileCheck %s
 
 // CHECK-LABEL:  @vector_maskedload
-//  CHECK-SAME:  (%[[ARG0:.*]]: memref<4x5xf32>) -> vector<4xf32> {
+//  CHECK-SAME:  (%[[ARG0:.*]]: memref<4x5xf32>, %[[ARG1:.*]]: vector<4xi1>) -> vector<4xf32> {
 //   CHECK-DAG:  %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
 //   CHECK-DAG:  %[[C7:.*]] = arith.constant 7 : index
 //   CHECK-DAG:  %[[C6:.*]] = arith.constant 6 : index
 //   CHECK-DAG:  %[[C5:.*]] = arith.constant 5 : index
 //   CHECK-DAG:  %[[C0:.*]] = arith.constant 0 : index
-//   CHECK-DAG:  %[[C1:.*]] = arith.constant 1 : index
 //   CHECK-DAG:  %[[C4:.*]] = arith.constant 4 : index
-//   CHECK-DAG:  %[[S0:.*]] = vector.create_mask %[[C1]] : vector<4xi1>
-//       CHECK:  %[[S1:.*]] = vector.extract %[[S0]][0] : i1 from vector<4xi1>
+//       CHECK:  %[[S1:.*]] = vector.extract %[[ARG1]][0] : i1 from vector<4xi1>
 //       CHECK:  %[[S2:.*]] = scf.if %[[S1]] -> (vector<4xf32>) {
 //       CHECK:    %[[S9:.*]] = memref.load %[[ARG0]][%[[C0]], %[[C4]]] : memref<4x5xf32>
 //       CHECK:    %[[S10:.*]] = vector.insert %[[S9]], %[[CST]] [0] : f32 into vector<4xf32>
@@ -18,7 +16,7 @@
 //       CHECK:  } else {
 //       CHECK:    scf.yield %[[CST]] : vector<4xf32>
 //       CHECK:  }
-//       CHECK:  %[[S3:.*]] = vector.extract %[[S0]][1] : i1 from vector<4xi1>
+//       CHECK:  %[[S3:.*]] = vector.extract %[[ARG1]][1] : i1 from vector<4xi1>
 //       CHECK:  %[[S4:.*]] = scf.if %[[S3]] -> (vector<4xf32>) {
 //       CHECK:    %[[S9:.*]] = memref.load %[[ARG0]][%[[C0]], %[[C5]]] : memref<4x5xf32>
 //       CHECK:    %[[S10:.*]] = vector.insert %[[S9]], %[[S2]] [1] : f32 into vector<4xf32>
@@ -26,7 +24,7 @@
 //       CHECK:  } else {
 //       CHECK:    scf.yield %[[S2]] : vector<4xf32>
 //       CHECK:  }
-//       CHECK:  %[[S5:.*]] = vector.extract %[[S0]][2] : i1 from vector<4xi1>
+//       CHECK:  %[[S5:.*]] = vector.extract %[[ARG1]][2] : i1 from vector<4xi1>
 //       CHECK:  %[[S6:.*]] = scf.if %[[S5]] -> (vector<4xf32>) {
 //       CHECK:    %[[S9:.*]] = memref.load %[[ARG0]][%[[C0]], %[[C6]]] : memref<4x5xf32>
 //       CHECK:    %[[S10:.*]] = vector.insert %[[S9]], %[[S4]] [2] : f32 into vector<4xf32>
@@ -34,7 +32,7 @@
 //       CHECK:  } else {
 //       CHECK:    scf.yield %[[S4]] : vector<4xf32>
 //       CHECK:  }
-//       CHECK:  %[[S7:.*]] = vector.extract %[[S0]][3] : i1 from vector<4xi1>
+//       CHECK:  %[[S7:.*]] = vector.extract %[[ARG1]][3] : i1 from vector<4xi1>
 //       CHECK:  %[[S8:.*]] = scf.if %[[S7]] -> (vector<4xf32>) {
 //       CHECK:    %[[S9:.*]] = memref.load %[[ARG0]][%[[C0]], %[[C7]]] : memref<4x5xf32>
 //       CHECK:    %[[S10:.*]] = vector.insert %[[S9]], %[[S6]] [3] : f32 into vector<4xf32>
@@ -43,53 +41,158 @@
 //       CHECK:    scf.yield %[[S6]] : vector<4xf32>
 //       CHECK:  }
 //       CHECK:  return %[[S8]] : vector<4xf32>
-func.func @vector_maskedload(%arg0 : memref<4x5xf32>) -> vector<4xf32> {
+func.func @vector_maskedload(%arg0 : memref<4x5xf32>, %arg1 : vector<4xi1>) -> vector<4xf32> {
   %idx_0 = arith.constant 0 : index
-  %idx_1 = arith.constant 1 : index
   %idx_4 = arith.constant 4 : index
-  %mask = vector.create_mask %idx_1 : vector<4xi1>
   %s = arith.constant 0.0 : f32
   %pass_thru = vector.splat %s : vector<4xf32>
-  %0 = vector.maskedload %arg0[%idx_0, %idx_4], %mask, %pass_thru : memref<4x5xf32>, vector<4xi1>, vector<4xf32> into vector<4xf32>
+  %0 = vector.maskedload %arg0[%idx_0, %idx_4], %arg1, %pass_thru : memref<4x5xf32>, vector<4xi1>, vector<4xf32> into vector<4xf32>
   return %0: vector<4xf32>
 }
 
 // CHECK-LABEL:  @vector_maskedstore
-//  CHECK-SAME:  (%[[ARG0:.*]]: memref<4x5xf32>, %[[ARG1:.*]]: vector<4xf32>) {
+//  CHECK-SAME:  (%[[ARG0:.*]]: memref<4x5xf32>, %[[ARG1:.*]]: vector<4xf32>, %[[ARG2:.*]]: vector<4xi1>) {
 //   CHECK-DAG:  %[[C7:.*]] = arith.constant 7 : index
 //   CHECK-DAG:  %[[C6:.*]] = arith.constant 6 : index
 //   CHECK-DAG:  %[[C5:.*]] = arith.constant 5 : index
 //   CHECK-DAG:  %[[C0:.*]] = arith.constant 0 : index
-//   CHECK-DAG:  %[[C1:.*]] = arith.constant 1 : index
 //   CHECK-DAG:  %[[C4:.*]] = arith.constant 4 : index
-//   CHECK-DAG:  %[[S0:.*]] = vector.create_mask %[[C1]] : vector<4xi1>
-//       CHECK:  %[[S1:.*]] = vector.extract %[[S0]][0] : i1 from vector<4xi1>
+//       CHECK:  %[[S1:.*]] = vector.extract %[[ARG2]][0] : i1 from vector<4xi1>
 //       CHECK:  scf.if %[[S1]] {
 //       CHECK:    %[[S5:.*]] = vector.extract %[[ARG1]][0] : f32 from vector<4xf32>
 //       CHECK:    memref.store %[[S5]], %[[ARG0]][%[[C0]], %[[C4]]] : memref<4x5xf32>
 //       CHECK:  }
-//       CHECK:  %[[S2:.*]] = vector.extract %[[S0]][1] : i1 from vector<4xi1>
+//       CHECK:  %[[S2:.*]] = vector.extract %[[ARG2]][1] : i1 from vector<4xi1>
 //       CHECK:  scf.if %[[S2]] {
 //       CHECK:    %[[S5:.*]] = vector.extract %[[ARG1]][1] : f32 from vector<4xf32>
 //       CHECK:    memref.store %[[S5]], %[[ARG0]][%[[C0]], %[[C5]]] : memref<4x5xf32>
 //       CHECK:  }
-//       CHECK:  %[[S3:.*]] = vector.extract %[[S0]][2] : i1 from vector<4xi1>
+//       CHECK:  %[[S3:.*]] = vector.extract %[[ARG2]][2] : i1 from vector<4xi1>
 //       CHECK:  scf.if %[[S3]] {
 //       CHECK:    %[[S5:.*]] = vector.extract %[[ARG1]][2] : f32 from vector<4xf32>
 //       CHECK:    memref.store %[[S5]], %[[ARG0]][%[[C0]], %[[C6]]] : memref<4x5xf32>
 //       CHECK:  }
-//       CHECK:  %[[S4:.*]] = vector.extract %[[S0]][3] : i1 from vector<4xi1>
+//       CHECK:  %[[S4:.*]] = vector.extract %[[ARG2]][3] : i1 from vector<4xi1>
 //       CHECK:  scf.if %[[S4]] {
 //       CHECK:    %[[S5:.*]] = vector.extract %[[ARG1]][3] : f32 from vector<4xf32>
 //       CHECK:    memref.store %[[S5]], %[[ARG0]][%[[C0]], %[[C7]]] : memref<4x5xf32>
 //       CHECK:  }
 //       CHECK:  return
 //       CHECK:}
-func.func @vector_maskedstore(%arg0 : memref<4x5xf32>, %arg1 : vector<4xf32>) {
+func.func @vector_maskedstore(%arg0 : memref<4x5xf32>, %arg1 : vector<4xf32>, %arg2 : vector<4xi1>) {
+  %idx_0 = arith.constant 0 : index
+  %idx_4 = arith.constant 4 : index
+  vector.maskedstore %arg0[%idx_0, %idx_4], %arg2, %arg1 : memref<4x5xf32>, vector<4xi1>, vector<4xf32>
+  return
+}
+
+// CHECK-LABEL:  @vector_maskedload_c1
+//  CHECK-SAME:  (%[[ARG0:.*]]: memref<4x5xf32>) -> vector<4xf32> {
+//   CHECK-DAG:   %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
+//   CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
+//   CHECK-DAG:   %[[C4:.*]] = arith.constant 4 : index
+//       CHECK:   %[[S0:.*]] = vector.load %[[ARG0]][%[[C0]], %[[C4]]] : memref<4x5xf32>, vector<1xf32>
+//       CHECK:   %[[S1:.*]] = vector.insert_strided_slice %[[S0]], %[[CST]] {offsets = [0], strides = [1]} : vector<1xf32> into vector<4xf32>
+//       CHECK:   return %[[S1]] : vector<4xf32>
+//       CHECK: }
+func.func @vector_maskedload_c1(%arg0 : memref<4x5xf32>) -> vector<4xf32> {
   %idx_0 = arith.constant 0 : index
   %idx_1 = arith.constant 1 : index
   %idx_4 = arith.constant 4 : index
   %mask = vector.create_mask %idx_1 : vector<4xi1>
+  %s = arith.constant 0.0 : f32
+  %pass_thru = vector.splat %s : vector<4xf32>
+  %0 = vector.maskedload %arg0[%idx_0, %idx_4], %mask, %pass_thru : memref<4x5xf32>, vector<4xi1>, vector<4xf32> into vector<4xf32>
+  return %0: vector<4xf32>
+}
+
+// CHECK-LABEL:  @vector_maskedload_c2
+//  CHECK-SAME:  (%[[ARG0:.*]]: memref<4x5xf32>) -> vector<4xf32> {
+//   CHECK-DAG:   %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
+//   CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
+//   CHECK-DAG:   %[[C4:.*]] = arith.constant 4 : index
+//       CHECK:   %[[S0:.*]] = vector.load %[[ARG0]][%[[C0]], %[[C4]]] : memref<4x5xf32>, vector<2xf32>
+//       CHECK:   %[[S1:.*]] = vector.insert_strided_slice %[[S0]], %[[CST]] {offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
+//       CHECK:   return %[[S1]] : vector<4xf32>
+//       CHECK: }
+func.func @vector_maskedload_c2(%arg0 : memref<4x5xf32>) -> vector<4xf32> {
+  %idx_0 = arith.constant 0 : index
+  %idx_2 = arith.constant 2 : index
+  %idx_4 = arith.constant 4 : index
+  %mask = vector.create_mask %idx_2 : vector<4xi1>
+  %s = arith.constant 0.0 : f32
+  %pass_thru = vector.splat %s : vector<4xf32>
+  %0 = vector.maskedload %arg0[%idx_0, %idx_4], %mask, %pass_thru : memref<4x5xf32>, vector<4xi1>, vector<4xf32> into vector<4xf32>
+  return %0: vector<4xf32>
+}
+
+// CHECK-LABEL:  @vector_maskedload_c3
+//  CHECK-SAME:  (%[[ARG0:.*]]: memref<4x5xf32>) -> vector<4xf32> {
+//   CHECK-DAG:   %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
+//   CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
+//   CHECK-DAG:   %[[C4:.*]] = arith.constant 4 : index
+//       CHECK:   %[[S0:.*]] = vector.load %[[ARG0]][%[[C0]], %[[C4]]] : memref<4x5xf32>, vector<3xf32>
+//       CHECK:   %[[S1:.*]] = vector.insert_strided_slice %[[S0]], %[[CST]] {offsets = [0], strides = [1]} : vector<3xf32> into vector<4xf32>
+//       CHECK:   return %[[S1]] : vector<4xf32>
+//       CHECK: }
+func.func @vector_maskedload_c3(%arg0 : memref<4x5xf32>) -> vector<4xf32> {
+  %idx_0 = arith.constant 0 : index
+  %idx_3 = arith.constant 3 : index
+  %idx_4 = arith.constant 4 : index
+  %mask = vector.create_mask %idx_3 : vector<4xi1>
+  %s = arith.constant 0.0 : f32
+  %pass_thru = vector.splat %s : vector<4xf32>
+  %0 = vector.maskedload %arg0[%idx_0, %idx_4], %mask, %pass_thru : memref<4x5xf32>, vector<4xi1>, vector<4xf32> into vector<4xf32>
+  return %0: vector<4xf32>
+}
+
+// CHECK-LABEL:  @vector_maskedstore_c1
+//  CHECK-SAME:  (%[[ARG0:.*]]: memref<4x5xf32>, %[[ARG1:.*]]: vector<4xf32>) {
+//   CHECK-DAG:   %[[...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/75587


More information about the Mlir-commits mailing list