[Mlir-commits] [mlir] [mlir][Vector] Infer mask and pass_thru types for maskedload/store (PR #131482)
llvmlistbot at llvm.org
Sat Mar 15 17:20:36 PDT 2025
llvmbot wrote:
@llvm/pr-subscribers-mlir-memref
Author: Kunwar Grover (Groverkss)
Changes:
---
Patch is 84.25 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/131482.diff
20 Files Affected:
- (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.td (+25-8)
- (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+6-23)
- (modified) mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir (+1-1)
- (modified) mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir (+8-8)
- (modified) mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir (+8-8)
- (modified) mlir/test/Dialect/SparseTensor/sparse_vector.mlir (+16-16)
- (modified) mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir (+2-2)
- (modified) mlir/test/Dialect/SparseTensor/sparse_vector_index.mlir (+3-3)
- (modified) mlir/test/Dialect/SparseTensor/sparse_vector_peeled.mlir (+2-2)
- (modified) mlir/test/Dialect/SparseTensor/vectorize_reduction.mlir (+7-7)
- (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+5-5)
- (modified) mlir/test/Dialect/Vector/emulate-narrow-type-unsupported.mlir (+3-3)
- (modified) mlir/test/Dialect/Vector/invalid.mlir (+15-12)
- (modified) mlir/test/Dialect/Vector/ops.mlir (+8-8)
- (modified) mlir/test/Dialect/Vector/vector-emulate-masked-load-store.mlir (+2-2)
- (modified) mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir (+8-8)
- (modified) mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir (+32-32)
- (modified) mlir/test/Dialect/Vector/vector-mem-transforms.mlir (+6-6)
- (modified) mlir/test/mlir-tblgen/op-result.td (+7-11)
- (modified) mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp (+23-28)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index fbbf817ecff98..fd77249402934 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -1840,7 +1840,16 @@ def Vector_StoreOp : Vector_Op<"store"> {
}
def Vector_MaskedLoadOp :
- Vector_Op<"maskedload">,
+ Vector_Op<"maskedload", [
+ AllTypesMatch<["result", "pass_thru"]>,
+ TypesMatchWith<"mask shape should match result shape",
+ "result",
+ "mask",
+ "VectorType::get(::llvm::cast<VectorType>($_self).getShape(),"
+ "IntegerType::get($_ctxt, 1),"
+ "::llvm::cast<VectorType>($_self).getScalableDims())">,
+ AllElementTypesMatch<["result", "base"]>
+ ]>,
Arguments<(ins Arg<AnyMemRef, "", [MemRead]>:$base,
Variadic<Index>:$indices,
VectorOfNonZeroRankOf<[I1]>:$mask,
@@ -1875,10 +1884,10 @@ def Vector_MaskedLoadOp :
```mlir
%0 = vector.maskedload %base[%i], %mask, %pass_thru
- : memref<?xf32>, vector<8xi1>, vector<8xf32> into vector<8xf32>
+ : memref<?xf32>, vector<8xf32>
%1 = vector.maskedload %base[%i, %j], %mask, %pass_thru
- : memref<?x?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+ : memref<?x?xf32>, vector<16xf32>
```
}];
let extraClassDeclaration = [{
@@ -1896,14 +1905,22 @@ def Vector_MaskedLoadOp :
}
}];
let assemblyFormat = "$base `[` $indices `]` `,` $mask `,` $pass_thru attr-dict `:` "
- "type($base) `,` type($mask) `,` type($pass_thru) `into` type($result)";
+ "type($base) `,` type($result)";
let hasCanonicalizer = 1;
let hasFolder = 1;
let hasVerifier = 1;
}
def Vector_MaskedStoreOp :
- Vector_Op<"maskedstore">,
+ Vector_Op<"maskedstore", [
+ TypesMatchWith<"mask shape should match result shape",
+ "valueToStore",
+ "mask",
+ "VectorType::get(::llvm::cast<VectorType>($_self).getShape(),"
+ "IntegerType::get($_ctxt, 1),"
+ "::llvm::cast<VectorType>($_self).getScalableDims())">,
+ AllElementTypesMatch<["valueToStore", "base"]>
+ ]>,
Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
Variadic<Index>:$indices,
VectorOfNonZeroRankOf<[I1]>:$mask,
@@ -1937,10 +1954,10 @@ def Vector_MaskedStoreOp :
```mlir
vector.maskedstore %base[%i], %mask, %value
- : memref<?xf32>, vector<8xi1>, vector<8xf32>
+ : memref<?xf32>, vector<8xf32>
vector.maskedstore %base[%i, %j], %mask, %value
- : memref<?x?xf32>, vector<16xi1>, vector<16xf32>
+ : memref<?x?xf32>, vector<16xf32>
```
}];
let extraClassDeclaration = [{
@@ -1956,7 +1973,7 @@ def Vector_MaskedStoreOp :
}];
let assemblyFormat =
"$base `[` $indices `]` `,` $mask `,` $valueToStore "
- "attr-dict `:` type($base) `,` type($mask) `,` type($valueToStore)";
+ "attr-dict `:` type($base) `,` type($valueToStore)";
let hasCanonicalizer = 1;
let hasFolder = 1;
let hasVerifier = 1;
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 8e0e723cf4ed3..83b962e54110a 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5127,19 +5127,9 @@ LogicalResult StoreOp::fold(FoldAdaptor adaptor,
//===----------------------------------------------------------------------===//
LogicalResult MaskedLoadOp::verify() {
- VectorType maskVType = getMaskVectorType();
- VectorType passVType = getPassThruVectorType();
- VectorType resVType = getVectorType();
- MemRefType memType = getMemRefType();
-
- if (resVType.getElementType() != memType.getElementType())
- return emitOpError("base and result element type should match");
- if (llvm::size(getIndices()) != memType.getRank())
- return emitOpError("requires ") << memType.getRank() << " indices";
- if (resVType.getShape() != maskVType.getShape())
- return emitOpError("expected result shape to match mask shape");
- if (resVType != passVType)
- return emitOpError("expected pass_thru of same type as result type");
+ int64_t memRank = getMemRefType().getRank();
+ if (llvm::size(getIndices()) != memRank)
+ return emitOpError("requires ") << memRank << " indices";
return success();
}
@@ -5181,16 +5171,9 @@ OpFoldResult MaskedLoadOp::fold(FoldAdaptor) {
//===----------------------------------------------------------------------===//
LogicalResult MaskedStoreOp::verify() {
- VectorType maskVType = getMaskVectorType();
- VectorType valueVType = getVectorType();
- MemRefType memType = getMemRefType();
-
- if (valueVType.getElementType() != memType.getElementType())
- return emitOpError("base and valueToStore element type should match");
- if (llvm::size(getIndices()) != memType.getRank())
- return emitOpError("requires ") << memType.getRank() << " indices";
- if (valueVType.getShape() != maskVType.getShape())
- return emitOpError("expected valueToStore shape to match mask shape");
+ int64_t memRank = getMemRefType().getRank();
+ if (llvm::size(getIndices()) != memRank)
+ return emitOpError("requires ") << memRank << " indices";
return success();
}
diff --git a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
index 4ae710aa29113..10224aec95d48 100644
--- a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
+++ b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
@@ -88,7 +88,7 @@ func.func @arm_sme_tile_load_hor_with_mask_and_pad_zero(%src : memref<?x?xi32>)
// CHECK-NEXT: %[[MASK_1D:.*]] = vector.create_mask %[[MASK_INDEX]] : vector<[4]xi1>
// CHECK-NEXT: %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index
// CHECK: %[[PAD_1D:.*]] = vector.splat %[[PAD]] : vector<[4]xi32>
-// CHECK: %[[LOAD_SLICE:.*]] = vector.maskedload %[[SRC]]{{\[}}%[[OFFSET]], %[[C0]]], %[[MASK_1D]], %[[PAD_1D]] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]xi32> into vector<[4]xi32>
+// CHECK: %[[LOAD_SLICE:.*]] = vector.maskedload %[[SRC]]{{\[}}%[[OFFSET]], %[[C0]]], %[[MASK_1D]], %[[PAD_1D]] : memref<?x?xi32>, vector<[4]xi32>
// CHECK: %[[TILE_UPDATE:.*]] = arm_sme.insert_tile_slice %[[LOAD_SLICE]], %[[CURRENT_TILE]][%[[TILE_SLICE_INDEX]]] : vector<[4]xi32> into vector<[4]x[4]xi32>
// CHECK-NEXT: scf.yield %[[TILE_UPDATE]] : vector<[4]x[4]xi32>
func.func @arm_sme_tile_load_hor_with_mask_and_nonzero_pad(%src : memref<?x?xi32>, %pad : i32) {
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
index c3f06dd4d5dd1..6dcae67abda57 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
@@ -1891,7 +1891,7 @@ func.func @store_0d(%memref : memref<200x100xf32>, %i : index, %j : index) {
func.func @masked_load(%arg0: memref<?xf32>, %arg1: vector<16xi1>, %arg2: vector<16xf32>) -> vector<16xf32> {
%c0 = arith.constant 0: index
- %0 = vector.maskedload %arg0[%c0], %arg1, %arg2 : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+ %0 = vector.maskedload %arg0[%c0], %arg1, %arg2 : memref<?xf32>, vector<16xf32>
return %0 : vector<16xf32>
}
@@ -1906,7 +1906,7 @@ func.func @masked_load(%arg0: memref<?xf32>, %arg1: vector<16xi1>, %arg2: vector
func.func @masked_load_scalable(%arg0: memref<?xf32>, %arg1: vector<[16]xi1>, %arg2: vector<[16]xf32>) -> vector<[16]xf32> {
%c0 = arith.constant 0: index
- %0 = vector.maskedload %arg0[%c0], %arg1, %arg2 : memref<?xf32>, vector<[16]xi1>, vector<[16]xf32> into vector<[16]xf32>
+ %0 = vector.maskedload %arg0[%c0], %arg1, %arg2 : memref<?xf32>, vector<[16]xf32>
return %0 : vector<[16]xf32>
}
@@ -1921,7 +1921,7 @@ func.func @masked_load_scalable(%arg0: memref<?xf32>, %arg1: vector<[16]xi1>, %a
func.func @masked_load_index(%arg0: memref<?xindex>, %arg1: vector<16xi1>, %arg2: vector<16xindex>) -> vector<16xindex> {
%c0 = arith.constant 0: index
- %0 = vector.maskedload %arg0[%c0], %arg1, %arg2 : memref<?xindex>, vector<16xi1>, vector<16xindex> into vector<16xindex>
+ %0 = vector.maskedload %arg0[%c0], %arg1, %arg2 : memref<?xindex>, vector<16xindex>
return %0 : vector<16xindex>
}
// CHECK-LABEL: func @masked_load_index
@@ -1931,7 +1931,7 @@ func.func @masked_load_index(%arg0: memref<?xindex>, %arg1: vector<16xi1>, %arg2
func.func @masked_load_index_scalable(%arg0: memref<?xindex>, %arg1: vector<[16]xi1>, %arg2: vector<[16]xindex>) -> vector<[16]xindex> {
%c0 = arith.constant 0: index
- %0 = vector.maskedload %arg0[%c0], %arg1, %arg2 : memref<?xindex>, vector<[16]xi1>, vector<[16]xindex> into vector<[16]xindex>
+ %0 = vector.maskedload %arg0[%c0], %arg1, %arg2 : memref<?xindex>, vector<[16]xindex>
return %0 : vector<[16]xindex>
}
// CHECK-LABEL: func @masked_load_index_scalable
@@ -1945,7 +1945,7 @@ func.func @masked_load_index_scalable(%arg0: memref<?xindex>, %arg1: vector<[16]
func.func @masked_store(%arg0: memref<?xf32>, %arg1: vector<16xi1>, %arg2: vector<16xf32>) {
%c0 = arith.constant 0: index
- vector.maskedstore %arg0[%c0], %arg1, %arg2 : memref<?xf32>, vector<16xi1>, vector<16xf32>
+ vector.maskedstore %arg0[%c0], %arg1, %arg2 : memref<?xf32>, vector<16xf32>
return
}
@@ -1959,7 +1959,7 @@ func.func @masked_store(%arg0: memref<?xf32>, %arg1: vector<16xi1>, %arg2: vecto
func.func @masked_store_scalable(%arg0: memref<?xf32>, %arg1: vector<[16]xi1>, %arg2: vector<[16]xf32>) {
%c0 = arith.constant 0: index
- vector.maskedstore %arg0[%c0], %arg1, %arg2 : memref<?xf32>, vector<[16]xi1>, vector<[16]xf32>
+ vector.maskedstore %arg0[%c0], %arg1, %arg2 : memref<?xf32>, vector<[16]xf32>
return
}
@@ -1973,7 +1973,7 @@ func.func @masked_store_scalable(%arg0: memref<?xf32>, %arg1: vector<[16]xi1>, %
func.func @masked_store_index(%arg0: memref<?xindex>, %arg1: vector<16xi1>, %arg2: vector<16xindex>) {
%c0 = arith.constant 0: index
- vector.maskedstore %arg0[%c0], %arg1, %arg2 : memref<?xindex>, vector<16xi1>, vector<16xindex>
+ vector.maskedstore %arg0[%c0], %arg1, %arg2 : memref<?xindex>, vector<16xindex>
return
}
// CHECK-LABEL: func @masked_store_index
@@ -1983,7 +1983,7 @@ func.func @masked_store_index(%arg0: memref<?xindex>, %arg1: vector<16xi1>, %arg
func.func @masked_store_index_scalable(%arg0: memref<?xindex>, %arg1: vector<[16]xi1>, %arg2: vector<[16]xindex>) {
%c0 = arith.constant 0: index
- vector.maskedstore %arg0[%c0], %arg1, %arg2 : memref<?xindex>, vector<[16]xi1>, vector<[16]xindex>
+ vector.maskedstore %arg0[%c0], %arg1, %arg2 : memref<?xindex>, vector<[16]xindex>
return
}
// CHECK-LABEL: func @masked_store_index_scalable
diff --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
index 327cacf7d9a20..7246ae4884a19 100644
--- a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
+++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
@@ -837,7 +837,7 @@ func.func @fold_vector_load_subview(
func.func @fold_vector_maskedload_subview(
%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3: vector<32xi1>, %arg4: vector<32xf32>) -> vector<32xf32> {
%0 = memref.subview %arg0[%arg1, %arg2][1, 1][1, 1] : memref<12x32xf32> to memref<f32, strided<[], offset: ?>>
- %1 = vector.maskedload %0[], %arg3, %arg4 : memref<f32, strided<[], offset: ?>>, vector<32xi1>, vector<32xf32> into vector<32xf32>
+ %1 = vector.maskedload %0[], %arg3, %arg4 : memref<f32, strided<[], offset: ?>>, vector<32xf32>
return %1 : vector<32xf32>
}
@@ -847,7 +847,7 @@ func.func @fold_vector_maskedload_subview(
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: vector<32xi1>
// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: vector<32xf32>
-// CHECK: vector.maskedload %[[ARG0]][%[[ARG1]], %[[ARG2]]], %[[ARG3]], %[[ARG4]] : memref<12x32xf32>, vector<32xi1>, vector<32xf32> into vector<32xf32>
+// CHECK: vector.maskedload %[[ARG0]][%[[ARG1]], %[[ARG2]]], %[[ARG3]], %[[ARG4]] : memref<12x32xf32>, vector<32xf32>
// -----
@@ -871,7 +871,7 @@ func.func @fold_vector_store_subview(
func.func @fold_vector_maskedstore_subview(
%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3: vector<32xi1>, %arg4: vector<32xf32>) -> () {
%0 = memref.subview %arg0[%arg1, %arg2][1, 1][1, 1] : memref<12x32xf32> to memref<f32, strided<[], offset: ?>>
- vector.maskedstore %0[], %arg3, %arg4 : memref<f32, strided<[], offset: ?>>, vector<32xi1>, vector<32xf32>
+ vector.maskedstore %0[], %arg3, %arg4 : memref<f32, strided<[], offset: ?>>, vector<32xf32>
return
}
@@ -881,7 +881,7 @@ func.func @fold_vector_maskedstore_subview(
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: vector<32xi1>
// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: vector<32xf32>
-// CHECK: vector.maskedstore %[[ARG0]][%[[ARG1]], %[[ARG2]]], %[[ARG3]], %[[ARG4]] : memref<12x32xf32>, vector<32xi1>, vector<32xf32>
+// CHECK: vector.maskedstore %[[ARG0]][%[[ARG1]], %[[ARG2]]], %[[ARG3]], %[[ARG4]] : memref<12x32xf32>, vector<32xf32>
// CHECK: return
// -----
@@ -907,7 +907,7 @@ func.func @fold_vector_maskedload_expand_shape(
%arg0 : memref<32xf32>, %arg1 : index, %arg3: vector<8xi1>, %arg4: vector<8xf32>) -> vector<8xf32> {
%c0 = arith.constant 0 : index
%0 = memref.expand_shape %arg0 [[0, 1]] output_shape [4, 8] : memref<32xf32> into memref<4x8xf32>
- %1 = vector.maskedload %0[%arg1, %c0], %arg3, %arg4 : memref<4x8xf32>, vector<8xi1>, vector<8xf32> into vector<8xf32>
+ %1 = vector.maskedload %0[%arg1, %c0], %arg3, %arg4 : memref<4x8xf32>, vector<8xf32>
return %1 : vector<8xf32>
}
@@ -943,7 +943,7 @@ func.func @fold_vector_maskedstore_expand_shape(
%arg0 : memref<32xf32>, %arg1 : index, %arg3: vector<8xi1>, %arg4: vector<8xf32>) {
%c0 = arith.constant 0 : index
%0 = memref.expand_shape %arg0 [[0, 1]] output_shape [4, 8] : memref<32xf32> into memref<4x8xf32>
- vector.maskedstore %0[%arg1, %c0], %arg3, %arg4 : memref<4x8xf32>, vector<8xi1>, vector<8xf32>
+ vector.maskedstore %0[%arg1, %c0], %arg3, %arg4 : memref<4x8xf32>, vector<8xf32>
return
}
@@ -979,7 +979,7 @@ func.func @fold_vector_load_collapse_shape(
func.func @fold_vector_maskedload_collapse_shape(
%arg0 : memref<4x8xf32>, %arg1 : index, %arg3: vector<8xi1>, %arg4: vector<8xf32>) -> vector<8xf32> {
%0 = memref.collapse_shape %arg0 [[0, 1]] : memref<4x8xf32> into memref<32xf32>
- %1 = vector.maskedload %0[%arg1], %arg3, %arg4 : memref<32xf32>, vector<8xi1>, vector<8xf32> into vector<8xf32>
+ %1 = vector.maskedload %0[%arg1], %arg3, %arg4 : memref<32xf32>, vector<8xf32>
return %1 : vector<8xf32>
}
@@ -1017,7 +1017,7 @@ func.func @fold_vector_store_collapse_shape(
func.func @fold_vector_maskedstore_collapse_shape(
%arg0 : memref<4x8xf32>, %arg1 : index, %arg3: vector<8xi1>, %arg4: vector<8xf32>) {
%0 = memref.collapse_shape %arg0 [[0, 1]] : memref<4x8xf32> into memref<32xf32>
- vector.maskedstore %0[%arg1], %arg3, %arg4 : memref<32xf32>, vector<8xi1>, vector<8xf32>
+ vector.maskedstore %0[%arg1], %arg3, %arg4 : memref<32xf32>, vector<8xf32>
return
}
diff --git a/mlir/test/Dialect/SparseTensor/sparse_vector.mlir b/mlir/test/Dialect/SparseTensor/sparse_vector.mlir
index 364ba6e71ff3b..c50d44f7faa1e 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_vector.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_vector.mlir
@@ -65,10 +65,10 @@
// CHECK-VEC4-SVE: scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[step]] {
// CHECK-VEC4-SVE: %[[sub:.*]] = affine.min #[[$map]](%[[c1024]], %[[i]])[%[[step]]]
// CHECK-VEC4-SVE: %[[mask:.*]] = vector.create_mask %[[sub]] : vector<[4]xi1>
-// CHECK-VEC4-SVE: %[[val:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %[[v0]] : memref<?xf32>, vector<[4]xi1>, vector<[4]xf32> into vector<[4]xf32>
+// CHECK-VEC4-SVE: %[[val:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %[[v0]] : memref<?xf32>, vector<[4]xf32>
// CHECK-VEC4-SVE: %[[scalev:.*]] = vector.broadcast %{{.*}} : f32 to vector<[4]xf32>
// CHECK-VEC4-SVE: %[[scaled:.*]] = arith.mulf %[[val]], %[[scalev]] : vector<[4]xf32>
-// CHECK-VEC4-SVE: vector.maskedstore %{{.*}}[%[[i]]], %[[mask]], %[[scaled]] : memref<1024xf32>, vector<[4]xi1>, vector<[4]xf32>
+// CHECK-VEC4-SVE: vector.maskedstore %{{.*}}[%[[i]]], %[[mask]], %[[scaled]] : memref<1024xf32>, vector<[4]xf32>
// CHECK-VEC4-SVE: }
// CHECK-VEC4-SVE: return
//
@@ -136,9 +136,9 @@ func.func @scale_d(%arga: tensor<1024xf32, #DenseVector>, %b: f32, %argx: tensor
// CHECK-VEC16: scf.for %[[i:.*]] = %[[q]] to %[[s]] step %[[c16]] {
// CHECK-VEC16: %[[sub:.*]] = affine.min #[[$map]](%[[s]], %[[i]])[%[[c16]]]
// CHECK-VEC16: %[[mask:.*]] = vector.create_mask %[[sub]] : vector<16xi1>
-// CHECK-VEC16: %[[li:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref<?xi32>, vector<16xi1>, vector<16xi32> into vector<16xi32>
+// CHECK-VEC16: %[[li:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref<?xi32>, vector<16xi32>
// CHECK-VEC16: %[[zi:.*]] = arith.extui %[[li]] : vector<16xi32> to vector<16xi64>
-// CHECK-VEC16: %[[la:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+// CHECK-VEC16: %[[la:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref<?xf32>, vector<16xf32>
// CHECK-VEC16: %[[lb:.*]] = vector.gather %{{.*}}[%[[c0]]] [%[[zi]]], %[[mask]], %{{.*}} : memref<1024xf32>, vector<16xi64>, vector<16xi1>, vector<16xf32> into vector<16xf32>
// CHECK-VEC16: %[[m:.*]] = arith.mulf %[[la]], %[[lb]] : vector<16xf32>
// CHECK-VEC16: vector.scatter %{{.*}}[%[[c0]]] [%[[zi]]], %[[mask]], %[[m]] : memref<1024xf32>, vector<16xi64>, vector<16xi1>, vector<16xf32>
@@ -159,8 +159,8 @@ func.func @scale_d(%arga: tensor<1024xf32, #DenseVector>, %b: f32, %argx: tensor
// CHECK-VEC16-IDX32: scf.for %[[i:.*]] = %[[q]] to %[[s]] step %[[c16]] {
// CHECK-VEC16-IDX32: %[[sub:.*]] = affine.min #[[$map]](%[[s]], %[[i]])[%[[c16]]]
// CHECK-VEC16-IDX32: %[[mask:.*]] = vector.create_mask %[[sub]] : vector<16xi1>
-// CHECK-VEC16-IDX32: %[[li:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref<?xi32>, vector<16xi1>, vector<16xi32> into vector<16xi32>
-// CHECK-VEC16-IDX32: %[[la:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+// CHECK-VEC16-IDX32: %[[li:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref<?xi32>, vector<16xi32>
+// CHECK-VEC16-IDX32: %[[la:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref<?xf32>, vector<16xf32>
// CHECK-VEC16-IDX32: %[[lb:.*]] = vector.gather %{{.*}}[%[[c0]]] [%[[li]]], %[[mask]], %{{.*}} : memref<1024xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
// CHECK-VEC16-IDX32: %[[m:.*]] = arith.mulf %[[la]], %[[lb]] : vector<16xf32>
// CHECK-VEC16-IDX32: vector.scatter %{{.*}}[%[[c0]]] [%[[li]]], %[[mask]], %[[m]] : memref<1024xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
@@ -185,9 +185,9 @@ func.func @scale_d(%arga: tensor<1024xf32, #DenseVector>, %b: f32, %argx: tensor
// CHECK-VEC4-SVE: scf.for %[[i:.*]] = %[[q]] to %[[s]] step %[[step]] {
// CHECK-VEC4-SVE: %[[sub:.*]] = affine.min #[[$map]](%[[s]], %[[i]])[%[[step]]]
// CHECK-VEC4-SVE: %[[mask:.*]] = vector.create_mask %[[sub]] : vector<[4]xi1>
-// CHECK-VEC4-SVE: %[[li:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %[[v0i]] : memref<?xi32>, vector<[4]xi1>, vector<[4]xi32> into vector<[4]xi32>
+// CHECK-VEC4-SVE: %[[li:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %[[v0i]] : memref<?xi32>, vector<[4]xi32>
// CHECK-VEC4-SVE: %[[lii64:.*]] = arith.extui %[[li]] : vector<[4]xi32> to vector<[4]xi64>
-// CHECK-VEC4-SVE: %[[la:.*]] = vector.maskedload %{{.*...
[truncated]
``````````
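For reviewers skimming the truncated diff, the user-visible change is the shorter assembly format: the mask type and, for loads, the pass_thru type are now inferred from the single vector type via the `AllTypesMatch`/`TypesMatchWith` constraints added in VectorOps.td. A minimal before/after sketch, adapted from the updated op documentation in the patch (names like `%base`, `%i`, `%mask`, `%pass_thru`, `%value` are illustrative):

```mlir
// Before this patch: mask, pass_thru, and result types are all spelled out.
%0 = vector.maskedload %base[%i], %mask, %pass_thru
    : memref<?xf32>, vector<8xi1>, vector<8xf32> into vector<8xf32>
vector.maskedstore %base[%i], %mask, %value
    : memref<?xf32>, vector<8xi1>, vector<8xf32>

// After this patch: only the base memref type and the vector type are written;
// the vector<8xi1> mask type (and the pass_thru type for loads) is inferred.
%1 = vector.maskedload %base[%i], %mask, %pass_thru
    : memref<?xf32>, vector<8xf32>
vector.maskedstore %base[%i], %mask, %value
    : memref<?xf32>, vector<8xf32>
```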
https://github.com/llvm/llvm-project/pull/131482
More information about the Mlir-commits mailing list