[Mlir-commits] [mlir] [mlir][vector] Deal with special patterns when emulating masked load/store (PR #75587)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Dec 15 04:09:16 PST 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-vector
Author: Hsiangkai Wang (Hsiangkai)
<details>
<summary>Changes</summary>
We can simplify the generated code when the mask is created by constant_mask or create_mask with a statically known bound. Such a mask is a contiguous run of leading ones followed by zeros, i.e. [1, 1, ..., 1, 0, ..., 0], so the masked portion is a fixed-length prefix. We can use vector.load + vector.insert_strided_slice to emulate maskedload and use vector.extract_strided_slice + vector.store to emulate maskedstore.
---
Patch is 22.50 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/75587.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorEmulateMaskedLoadStore.cpp (+137-95)
- (modified) mlir/test/Dialect/Vector/vector-emulate-masked-load-store.mlir (+122-19)
``````````diff
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateMaskedLoadStore.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateMaskedLoadStore.cpp
index 8cc7008d80b3ed..c61633d5ff8aeb 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateMaskedLoadStore.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateMaskedLoadStore.cpp
@@ -19,33 +19,22 @@ using namespace mlir;
namespace {
+std::optional<int64_t> getMaskLeadingOneLength(Value mask) {
+ if (auto m = mask.getDefiningOp<vector::ConstantMaskOp>()) {
+ ArrayAttr masks = m.getMaskDimSizes();
+ assert(masks.size() == 1 && "Only support 1-D mask.");
+ return llvm::cast<IntegerAttr>(masks[0]).getInt();
+ } else if (auto m = mask.getDefiningOp<vector::CreateMaskOp>()) {
+ auto maskOperands = m.getOperands();
+ assert(maskOperands.size() == 1 && "Only support 1-D mask.");
+ if (auto constantOp = maskOperands[0].getDefiningOp<arith::ConstantOp>()) {
+ return llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
+ }
+ }
+ return {};
+}
+
/// Convert vector.maskedload
-///
-/// Before:
-///
-/// vector.maskedload %base[%idx_0, %idx_1], %mask, %pass_thru
-///
-/// After:
-///
-/// %ivalue = %pass_thru
-/// %m = vector.extract %mask[0]
-/// %result0 = scf.if %m {
-/// %v = memref.load %base[%idx_0, %idx_1]
-/// %combined = vector.insert %v, %ivalue[0]
-/// scf.yield %combined
-/// } else {
-/// scf.yield %ivalue
-/// }
-/// %m = vector.extract %mask[1]
-/// %result1 = scf.if %m {
-/// %v = memref.load %base[%idx_0, %idx_1 + 1]
-/// %combined = vector.insert %v, %result0[1]
-/// scf.yield %combined
-/// } else {
-/// scf.yield %result0
-/// }
-/// ...
-///
struct VectorMaskedLoadOpConverter final
: OpRewritePattern<vector::MaskedLoadOp> {
using OpRewritePattern::OpRewritePattern;
@@ -58,61 +47,82 @@ struct VectorMaskedLoadOpConverter final
maskedLoadOp, "expected vector.maskedstore with 1-D mask");
Location loc = maskedLoadOp.getLoc();
- int64_t maskLength = maskVType.getShape()[0];
-
- Type indexType = rewriter.getIndexType();
- Value mask = maskedLoadOp.getMask();
- Value base = maskedLoadOp.getBase();
- Value iValue = maskedLoadOp.getPassThru();
- auto indices = llvm::to_vector_of<Value>(maskedLoadOp.getIndices());
- Value one = rewriter.create<arith::ConstantOp>(
- loc, indexType, IntegerAttr::get(indexType, 1));
- for (int64_t i = 0; i < maskLength; ++i) {
- auto maskBit = rewriter.create<vector::ExtractOp>(loc, mask, i);
-
- auto ifOp = rewriter.create<scf::IfOp>(
- loc, maskBit,
- [&](OpBuilder &builder, Location loc) {
- auto loadedValue =
- builder.create<memref::LoadOp>(loc, base, indices);
- auto combinedValue =
- builder.create<vector::InsertOp>(loc, loadedValue, iValue, i);
- builder.create<scf::YieldOp>(loc, combinedValue.getResult());
- },
- [&](OpBuilder &builder, Location loc) {
- builder.create<scf::YieldOp>(loc, iValue);
- });
- iValue = ifOp.getResult(0);
-
- indices.back() = rewriter.create<arith::AddIOp>(loc, indices.back(), one);
+ std::optional<int64_t> ones =
+ getMaskLeadingOneLength(maskedLoadOp.getMask());
+ if (ones) {
+ /// Converted to:
+ ///
+ /// %partial = vector.load %base[%idx_0, %idx_1]
+ /// %value = vector.insert_strided_slice %partial, %pass_thru
+ /// {offsets = [0], strides = [1]}
+ Type vectorType =
+ VectorType::get(*ones, maskedLoadOp.getMemRefType().getElementType());
+ auto loadOp = rewriter.create<vector::LoadOp>(
+ loc, vectorType, maskedLoadOp.getBase(), maskedLoadOp.getIndices());
+ auto insertedValue = rewriter.create<vector::InsertStridedSliceOp>(
+ loc, loadOp, maskedLoadOp.getPassThru(),
+ /*offsets=*/ArrayRef<int64_t>({0}),
+ /*strides=*/ArrayRef<int64_t>({1}));
+ rewriter.replaceOp(maskedLoadOp, insertedValue);
+ } else {
+ /// Converted to:
+ ///
+ /// %ivalue = %pass_thru
+ /// %m = vector.extract %mask[0]
+ /// %result0 = scf.if %m {
+ /// %v = memref.load %base[%idx_0, %idx_1]
+ /// %combined = vector.insert %v, %ivalue[0]
+ /// scf.yield %combined
+ /// } else {
+ /// scf.yield %ivalue
+ /// }
+ /// %m = vector.extract %mask[1]
+ /// %result1 = scf.if %m {
+ /// %v = memref.load %base[%idx_0, %idx_1 + 1]
+ /// %combined = vector.insert %v, %result0[1]
+ /// scf.yield %combined
+ /// } else {
+ /// scf.yield %result0
+ /// }
+ /// ...
+ ///
+ int64_t maskLength = maskVType.getShape()[0];
+
+ Type indexType = rewriter.getIndexType();
+ Value mask = maskedLoadOp.getMask();
+ Value base = maskedLoadOp.getBase();
+ Value iValue = maskedLoadOp.getPassThru();
+ auto indices = llvm::to_vector_of<Value>(maskedLoadOp.getIndices());
+ Value one = rewriter.create<arith::ConstantOp>(
+ loc, indexType, IntegerAttr::get(indexType, 1));
+ for (int64_t i = 0; i < maskLength; ++i) {
+ auto maskBit = rewriter.create<vector::ExtractOp>(loc, mask, i);
+
+ auto ifOp = rewriter.create<scf::IfOp>(
+ loc, maskBit,
+ [&](OpBuilder &builder, Location loc) {
+ auto loadedValue =
+ builder.create<memref::LoadOp>(loc, base, indices);
+ auto combinedValue =
+ builder.create<vector::InsertOp>(loc, loadedValue, iValue, i);
+ builder.create<scf::YieldOp>(loc, combinedValue.getResult());
+ },
+ [&](OpBuilder &builder, Location loc) {
+ builder.create<scf::YieldOp>(loc, iValue);
+ });
+ iValue = ifOp.getResult(0);
+
+ indices.back() = rewriter.create<arith::AddIOp>(loc, indices.back(), one);
+ }
+
+ rewriter.replaceOp(maskedLoadOp, iValue);
}
- rewriter.replaceOp(maskedLoadOp, iValue);
-
return success();
}
};
/// Convert vector.maskedstore
-///
-/// Before:
-///
-/// vector.maskedstore %base[%idx_0, %idx_1], %mask, %value
-///
-/// After:
-///
-/// %m = vector.extract %mask[0]
-/// scf.if %m {
-/// %extracted = vector.extract %value[0]
-/// memref.store %extracted, %base[%idx_0, %idx_1]
-/// }
-/// %m = vector.extract %mask[1]
-/// scf.if %m {
-/// %extracted = vector.extract %value[1]
-/// memref.store %extracted, %base[%idx_0, %idx_1 + 1]
-/// }
-/// ...
-///
struct VectorMaskedStoreOpConverter final
: OpRewritePattern<vector::MaskedStoreOp> {
using OpRewritePattern::OpRewritePattern;
@@ -125,29 +135,61 @@ struct VectorMaskedStoreOpConverter final
maskedStoreOp, "expected vector.maskedstore with 1-D mask");
Location loc = maskedStoreOp.getLoc();
- int64_t maskLength = maskVType.getShape()[0];
-
- Type indexType = rewriter.getIndexType();
- Value mask = maskedStoreOp.getMask();
- Value base = maskedStoreOp.getBase();
- Value value = maskedStoreOp.getValueToStore();
- auto indices = llvm::to_vector_of<Value>(maskedStoreOp.getIndices());
- Value one = rewriter.create<arith::ConstantOp>(
- loc, indexType, IntegerAttr::get(indexType, 1));
- for (int64_t i = 0; i < maskLength; ++i) {
- auto maskBit = rewriter.create<vector::ExtractOp>(loc, mask, i);
-
- auto ifOp = rewriter.create<scf::IfOp>(loc, maskBit, /*else=*/false);
- rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
- auto extractedValue = rewriter.create<vector::ExtractOp>(loc, value, i);
- rewriter.create<memref::StoreOp>(loc, extractedValue, base, indices);
-
- rewriter.setInsertionPointAfter(ifOp);
- indices.back() = rewriter.create<arith::AddIOp>(loc, indices.back(), one);
+ std::optional<int64_t> ones =
+ getMaskLeadingOneLength(maskedStoreOp.getMask());
+ if (ones) {
+ /// Converted to:
+ ///
+ /// %partial = vector.extract_strided_slice %value
+ /// {offsets = [0], sizes = [1], strides = [1]}
+ /// vector.store %partial, %base[%idx_0, %idx_1]
+ Value value = maskedStoreOp.getValueToStore();
+ Value extractedValue = rewriter.create<vector::ExtractStridedSliceOp>(
+ loc, value, /*offsets=*/ArrayRef<int64_t>({0}),
+ /*sizes=*/ArrayRef<int64_t>(*ones),
+ /*strides=*/ArrayRef<int64_t>({1}));
+ auto storeOp = rewriter.create<vector::StoreOp>(
+ loc, extractedValue, maskedStoreOp.getBase(),
+ maskedStoreOp.getIndices());
+ rewriter.replaceOp(maskedStoreOp, storeOp);
+ } else {
+ /// Converted to:
+ ///
+ /// %m = vector.extract %mask[0]
+ /// scf.if %m {
+ /// %extracted = vector.extract %value[0]
+ /// memref.store %extracted, %base[%idx_0, %idx_1]
+ /// }
+ /// %m = vector.extract %mask[1]
+ /// scf.if %m {
+ /// %extracted = vector.extract %value[1]
+ /// memref.store %extracted, %base[%idx_0, %idx_1 + 1]
+ /// }
+ /// ...
+ int64_t maskLength = maskVType.getShape()[0];
+
+ Type indexType = rewriter.getIndexType();
+ Value mask = maskedStoreOp.getMask();
+ Value base = maskedStoreOp.getBase();
+ Value value = maskedStoreOp.getValueToStore();
+ auto indices = llvm::to_vector_of<Value>(maskedStoreOp.getIndices());
+ Value one = rewriter.create<arith::ConstantOp>(
+ loc, indexType, IntegerAttr::get(indexType, 1));
+ for (int64_t i = 0; i < maskLength; ++i) {
+ auto maskBit = rewriter.create<vector::ExtractOp>(loc, mask, i);
+
+ auto ifOp = rewriter.create<scf::IfOp>(loc, maskBit, /*else=*/false);
+ rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
+ auto extractedValue = rewriter.create<vector::ExtractOp>(loc, value, i);
+ rewriter.create<memref::StoreOp>(loc, extractedValue, base, indices);
+
+ rewriter.setInsertionPointAfter(ifOp);
+ indices.back() = rewriter.create<arith::AddIOp>(loc, indices.back(), one);
+ }
+
+ rewriter.eraseOp(maskedStoreOp);
}
- rewriter.eraseOp(maskedStoreOp);
-
return success();
}
};
diff --git a/mlir/test/Dialect/Vector/vector-emulate-masked-load-store.mlir b/mlir/test/Dialect/Vector/vector-emulate-masked-load-store.mlir
index 3867f075af8e4b..473f28d0b901b4 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-masked-load-store.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-masked-load-store.mlir
@@ -1,16 +1,14 @@
// RUN: mlir-opt %s --test-vector-emulate-masked-load-store | FileCheck %s
// CHECK-LABEL: @vector_maskedload
-// CHECK-SAME: (%[[ARG0:.*]]: memref<4x5xf32>) -> vector<4xf32> {
+// CHECK-SAME: (%[[ARG0:.*]]: memref<4x5xf32>, %[[ARG1:.*]]: vector<4xi1>) -> vector<4xf32> {
// CHECK-DAG: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
// CHECK-DAG: %[[C7:.*]] = arith.constant 7 : index
// CHECK-DAG: %[[C6:.*]] = arith.constant 6 : index
// CHECK-DAG: %[[C5:.*]] = arith.constant 5 : index
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
-// CHECK-DAG: %[[S0:.*]] = vector.create_mask %[[C1]] : vector<4xi1>
-// CHECK: %[[S1:.*]] = vector.extract %[[S0]][0] : i1 from vector<4xi1>
+// CHECK: %[[S1:.*]] = vector.extract %[[ARG1]][0] : i1 from vector<4xi1>
// CHECK: %[[S2:.*]] = scf.if %[[S1]] -> (vector<4xf32>) {
// CHECK: %[[S9:.*]] = memref.load %[[ARG0]][%[[C0]], %[[C4]]] : memref<4x5xf32>
// CHECK: %[[S10:.*]] = vector.insert %[[S9]], %[[CST]] [0] : f32 into vector<4xf32>
@@ -18,7 +16,7 @@
// CHECK: } else {
// CHECK: scf.yield %[[CST]] : vector<4xf32>
// CHECK: }
-// CHECK: %[[S3:.*]] = vector.extract %[[S0]][1] : i1 from vector<4xi1>
+// CHECK: %[[S3:.*]] = vector.extract %[[ARG1]][1] : i1 from vector<4xi1>
// CHECK: %[[S4:.*]] = scf.if %[[S3]] -> (vector<4xf32>) {
// CHECK: %[[S9:.*]] = memref.load %[[ARG0]][%[[C0]], %[[C5]]] : memref<4x5xf32>
// CHECK: %[[S10:.*]] = vector.insert %[[S9]], %[[S2]] [1] : f32 into vector<4xf32>
@@ -26,7 +24,7 @@
// CHECK: } else {
// CHECK: scf.yield %[[S2]] : vector<4xf32>
// CHECK: }
-// CHECK: %[[S5:.*]] = vector.extract %[[S0]][2] : i1 from vector<4xi1>
+// CHECK: %[[S5:.*]] = vector.extract %[[ARG1]][2] : i1 from vector<4xi1>
// CHECK: %[[S6:.*]] = scf.if %[[S5]] -> (vector<4xf32>) {
// CHECK: %[[S9:.*]] = memref.load %[[ARG0]][%[[C0]], %[[C6]]] : memref<4x5xf32>
// CHECK: %[[S10:.*]] = vector.insert %[[S9]], %[[S4]] [2] : f32 into vector<4xf32>
@@ -34,7 +32,7 @@
// CHECK: } else {
// CHECK: scf.yield %[[S4]] : vector<4xf32>
// CHECK: }
-// CHECK: %[[S7:.*]] = vector.extract %[[S0]][3] : i1 from vector<4xi1>
+// CHECK: %[[S7:.*]] = vector.extract %[[ARG1]][3] : i1 from vector<4xi1>
// CHECK: %[[S8:.*]] = scf.if %[[S7]] -> (vector<4xf32>) {
// CHECK: %[[S9:.*]] = memref.load %[[ARG0]][%[[C0]], %[[C7]]] : memref<4x5xf32>
// CHECK: %[[S10:.*]] = vector.insert %[[S9]], %[[S6]] [3] : f32 into vector<4xf32>
@@ -43,53 +41,158 @@
// CHECK: scf.yield %[[S6]] : vector<4xf32>
// CHECK: }
// CHECK: return %[[S8]] : vector<4xf32>
-func.func @vector_maskedload(%arg0 : memref<4x5xf32>) -> vector<4xf32> {
+func.func @vector_maskedload(%arg0 : memref<4x5xf32>, %arg1 : vector<4xi1>) -> vector<4xf32> {
%idx_0 = arith.constant 0 : index
- %idx_1 = arith.constant 1 : index
%idx_4 = arith.constant 4 : index
- %mask = vector.create_mask %idx_1 : vector<4xi1>
%s = arith.constant 0.0 : f32
%pass_thru = vector.splat %s : vector<4xf32>
- %0 = vector.maskedload %arg0[%idx_0, %idx_4], %mask, %pass_thru : memref<4x5xf32>, vector<4xi1>, vector<4xf32> into vector<4xf32>
+ %0 = vector.maskedload %arg0[%idx_0, %idx_4], %arg1, %pass_thru : memref<4x5xf32>, vector<4xi1>, vector<4xf32> into vector<4xf32>
return %0: vector<4xf32>
}
// CHECK-LABEL: @vector_maskedstore
-// CHECK-SAME: (%[[ARG0:.*]]: memref<4x5xf32>, %[[ARG1:.*]]: vector<4xf32>) {
+// CHECK-SAME: (%[[ARG0:.*]]: memref<4x5xf32>, %[[ARG1:.*]]: vector<4xf32>, %[[ARG2:.*]]: vector<4xi1>) {
// CHECK-DAG: %[[C7:.*]] = arith.constant 7 : index
// CHECK-DAG: %[[C6:.*]] = arith.constant 6 : index
// CHECK-DAG: %[[C5:.*]] = arith.constant 5 : index
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
-// CHECK-DAG: %[[S0:.*]] = vector.create_mask %[[C1]] : vector<4xi1>
-// CHECK: %[[S1:.*]] = vector.extract %[[S0]][0] : i1 from vector<4xi1>
+// CHECK: %[[S1:.*]] = vector.extract %[[ARG2]][0] : i1 from vector<4xi1>
// CHECK: scf.if %[[S1]] {
// CHECK: %[[S5:.*]] = vector.extract %[[ARG1]][0] : f32 from vector<4xf32>
// CHECK: memref.store %[[S5]], %[[ARG0]][%[[C0]], %[[C4]]] : memref<4x5xf32>
// CHECK: }
-// CHECK: %[[S2:.*]] = vector.extract %[[S0]][1] : i1 from vector<4xi1>
+// CHECK: %[[S2:.*]] = vector.extract %[[ARG2]][1] : i1 from vector<4xi1>
// CHECK: scf.if %[[S2]] {
// CHECK: %[[S5:.*]] = vector.extract %[[ARG1]][1] : f32 from vector<4xf32>
// CHECK: memref.store %[[S5]], %[[ARG0]][%[[C0]], %[[C5]]] : memref<4x5xf32>
// CHECK: }
-// CHECK: %[[S3:.*]] = vector.extract %[[S0]][2] : i1 from vector<4xi1>
+// CHECK: %[[S3:.*]] = vector.extract %[[ARG2]][2] : i1 from vector<4xi1>
// CHECK: scf.if %[[S3]] {
// CHECK: %[[S5:.*]] = vector.extract %[[ARG1]][2] : f32 from vector<4xf32>
// CHECK: memref.store %[[S5]], %[[ARG0]][%[[C0]], %[[C6]]] : memref<4x5xf32>
// CHECK: }
-// CHECK: %[[S4:.*]] = vector.extract %[[S0]][3] : i1 from vector<4xi1>
+// CHECK: %[[S4:.*]] = vector.extract %[[ARG2]][3] : i1 from vector<4xi1>
// CHECK: scf.if %[[S4]] {
// CHECK: %[[S5:.*]] = vector.extract %[[ARG1]][3] : f32 from vector<4xf32>
// CHECK: memref.store %[[S5]], %[[ARG0]][%[[C0]], %[[C7]]] : memref<4x5xf32>
// CHECK: }
// CHECK: return
// CHECK:}
-func.func @vector_maskedstore(%arg0 : memref<4x5xf32>, %arg1 : vector<4xf32>) {
+func.func @vector_maskedstore(%arg0 : memref<4x5xf32>, %arg1 : vector<4xf32>, %arg2 : vector<4xi1>) {
+ %idx_0 = arith.constant 0 : index
+ %idx_4 = arith.constant 4 : index
+ vector.maskedstore %arg0[%idx_0, %idx_4], %arg2, %arg1 : memref<4x5xf32>, vector<4xi1>, vector<4xf32>
+ return
+}
+
+// CHECK-LABEL: @vector_maskedload_c1
+// CHECK-SAME: (%[[ARG0:.*]]: memref<4x5xf32>) -> vector<4xf32> {
+// CHECK-DAG: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+// CHECK: %[[S0:.*]] = vector.load %[[ARG0]][%[[C0]], %[[C4]]] : memref<4x5xf32>, vector<1xf32>
+// CHECK: %[[S1:.*]] = vector.insert_strided_slice %[[S0]], %[[CST]] {offsets = [0], strides = [1]} : vector<1xf32> into vector<4xf32>
+// CHECK: return %[[S1]] : vector<4xf32>
+// CHECK: }
+func.func @vector_maskedload_c1(%arg0 : memref<4x5xf32>) -> vector<4xf32> {
%idx_0 = arith.constant 0 : index
%idx_1 = arith.constant 1 : index
%idx_4 = arith.constant 4 : index
%mask = vector.create_mask %idx_1 : vector<4xi1>
+ %s = arith.constant 0.0 : f32
+ %pass_thru = vector.splat %s : vector<4xf32>
+ %0 = vector.maskedload %arg0[%idx_0, %idx_4], %mask, %pass_thru : memref<4x5xf32>, vector<4xi1>, vector<4xf32> into vector<4xf32>
+ return %0: vector<4xf32>
+}
+
+// CHECK-LABEL: @vector_maskedload_c2
+// CHECK-SAME: (%[[ARG0:.*]]: memref<4x5xf32>) -> vector<4xf32> {
+// CHECK-DAG: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+// CHECK: %[[S0:.*]] = vector.load %[[ARG0]][%[[C0]], %[[C4]]] : memref<4x5xf32>, vector<2xf32>
+// CHECK: %[[S1:.*]] = vector.insert_strided_slice %[[S0]], %[[CST]] {offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
+// CHECK: return %[[S1]] : vector<4xf32>
+// CHECK: }
+func.func @vector_maskedload_c2(%arg0 : memref<4x5xf32>) -> vector<4xf32> {
+ %idx_0 = arith.constant 0 : index
+ %idx_2 = arith.constant 2 : index
+ %idx_4 = arith.constant 4 : index
+ %mask = vector.create_mask %idx_2 : vector<4xi1>
+ %s = arith.constant 0.0 : f32
+ %pass_thru = vector.splat %s : vector<4xf32>
+ %0 = vector.maskedload %arg0[%idx_0, %idx_4], %mask, %pass_thru : memref<4x5xf32>, vector<4xi1>, vector<4xf32> into vector<4xf32>
+ return %0: vector<4xf32>
+}
+
+// CHECK-LABEL: @vector_maskedload_c3
+// CHECK-SAME: (%[[ARG0:.*]]: memref<4x5xf32>) -> vector<4xf32> {
+// CHECK-DAG: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+// CHECK: %[[S0:.*]] = vector.load %[[ARG0]][%[[C0]], %[[C4]]] : memref<4x5xf32>, vector<3xf32>
+// CHECK: %[[S1:.*]] = vector.insert_strided_slice %[[S0]], %[[CST]] {offsets = [0], strides = [1]} : vector<3xf32> into vector<4xf32>
+// CHECK: return %[[S1]] : vector<4xf32>
+// CHECK: }
+func.func @vector_maskedload_c3(%arg0 : memref<4x5xf32>) -> vector<4xf32> {
+ %idx_0 = arith.constant 0 : index
+ %idx_3 = arith.constant 3 : index
+ %idx_4 = arith.constant 4 : index
+ %mask = vector.create_mask %idx_3 : vector<4xi1>
+ %s = arith.constant 0.0 : f32
+ %pass_thru = vector.splat %s : vector<4xf32>
+ %0 = vector.maskedload %arg0[%idx_0, %idx_4], %mask, %pass_thru : memref<4x5xf32>, vector<4xi1>, vector<4xf32> into vector<4xf32>
+ return %0: vector<4xf32>
+}
+
+// CHECK-LABEL: @vector_maskedstore_c1
+// CHECK-SAME: (%[[ARG0:.*]]: memref<4x5xf32>, %[[ARG1:.*]]: vector<4xf32>) {
+// CHECK-DAG: %[[...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/75587
More information about the Mlir-commits
mailing list