[Mlir-commits] [mlir] [mlir][Vector] Infer mask and pass_thru types for maskedload/store (PR #131482)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat Mar 15 17:20:36 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-memref

Author: Kunwar Grover (Groverkss)

<details>
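<summary>Changes</summary>

Infer the `mask` and `pass_thru` types of `vector.maskedload` and `vector.maskedstore` from the result / stored-value vector type, so the assembly format only spells out the memref type and the vector type (e.g. `: memref<?xf32>, vector<8xf32>`).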

---

Patch is 84.25 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/131482.diff


20 Files Affected:

- (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.td (+25-8) 
- (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+6-23) 
- (modified) mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir (+1-1) 
- (modified) mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir (+8-8) 
- (modified) mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir (+8-8) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_vector.mlir (+16-16) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir (+2-2) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_vector_index.mlir (+3-3) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_vector_peeled.mlir (+2-2) 
- (modified) mlir/test/Dialect/SparseTensor/vectorize_reduction.mlir (+7-7) 
- (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+5-5) 
- (modified) mlir/test/Dialect/Vector/emulate-narrow-type-unsupported.mlir (+3-3) 
- (modified) mlir/test/Dialect/Vector/invalid.mlir (+15-12) 
- (modified) mlir/test/Dialect/Vector/ops.mlir (+8-8) 
- (modified) mlir/test/Dialect/Vector/vector-emulate-masked-load-store.mlir (+2-2) 
- (modified) mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir (+8-8) 
- (modified) mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir (+32-32) 
- (modified) mlir/test/Dialect/Vector/vector-mem-transforms.mlir (+6-6) 
- (modified) mlir/test/mlir-tblgen/op-result.td (+7-11) 
- (modified) mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp (+23-28) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index fbbf817ecff98..fd77249402934 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -1840,7 +1840,16 @@ def Vector_StoreOp : Vector_Op<"store"> {
 }
 
 def Vector_MaskedLoadOp :
-  Vector_Op<"maskedload">,
+    Vector_Op<"maskedload", [
+      AllTypesMatch<["result", "pass_thru"]>,
+      TypesMatchWith<"mask shape should match result shape",
+        "result",
+        "mask",
+        "VectorType::get(::llvm::cast<VectorType>($_self).getShape(),"
+          "IntegerType::get($_ctxt, 1),"
+          "::llvm::cast<VectorType>($_self).getScalableDims())">,
+      AllElementTypesMatch<["result", "base"]>
+    ]>,
     Arguments<(ins Arg<AnyMemRef, "", [MemRead]>:$base,
                Variadic<Index>:$indices,
                VectorOfNonZeroRankOf<[I1]>:$mask,
@@ -1875,10 +1884,10 @@ def Vector_MaskedLoadOp :
 
     ```mlir
     %0 = vector.maskedload %base[%i], %mask, %pass_thru
-       : memref<?xf32>, vector<8xi1>, vector<8xf32> into vector<8xf32>
+       : memref<?xf32>, vector<8xf32>
 
     %1 = vector.maskedload %base[%i, %j], %mask, %pass_thru
-       : memref<?x?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+       : memref<?x?xf32>, vector<16xf32>
     ```
   }];
   let extraClassDeclaration = [{
@@ -1896,14 +1905,22 @@ def Vector_MaskedLoadOp :
     }
   }];
   let assemblyFormat = "$base `[` $indices `]` `,` $mask `,` $pass_thru attr-dict `:` "
-    "type($base) `,` type($mask) `,` type($pass_thru) `into` type($result)";
+    "type($base) `,` type($result)";
   let hasCanonicalizer = 1;
   let hasFolder = 1;
   let hasVerifier = 1;
 }
 
 def Vector_MaskedStoreOp :
-  Vector_Op<"maskedstore">,
+    Vector_Op<"maskedstore", [
+      TypesMatchWith<"mask shape should match result shape",
+        "valueToStore",
+        "mask",
+        "VectorType::get(::llvm::cast<VectorType>($_self).getShape(),"
+          "IntegerType::get($_ctxt, 1),"
+          "::llvm::cast<VectorType>($_self).getScalableDims())">,
+      AllElementTypesMatch<["valueToStore", "base"]>
+    ]>,
     Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
                Variadic<Index>:$indices,
                VectorOfNonZeroRankOf<[I1]>:$mask,
@@ -1937,10 +1954,10 @@ def Vector_MaskedStoreOp :
 
     ```mlir
     vector.maskedstore %base[%i], %mask, %value
-      : memref<?xf32>, vector<8xi1>, vector<8xf32>
+      : memref<?xf32>, vector<8xf32>
 
     vector.maskedstore %base[%i, %j], %mask, %value
-      : memref<?x?xf32>, vector<16xi1>, vector<16xf32>
+      : memref<?x?xf32>, vector<16xf32>
     ```
   }];
   let extraClassDeclaration = [{
@@ -1956,7 +1973,7 @@ def Vector_MaskedStoreOp :
   }];
   let assemblyFormat =
       "$base `[` $indices `]` `,` $mask `,` $valueToStore "
-      "attr-dict `:` type($base) `,` type($mask) `,` type($valueToStore)";
+      "attr-dict `:` type($base) `,` type($valueToStore)";
   let hasCanonicalizer = 1;
   let hasFolder = 1;
   let hasVerifier = 1;
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 8e0e723cf4ed3..83b962e54110a 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5127,19 +5127,9 @@ LogicalResult StoreOp::fold(FoldAdaptor adaptor,
 //===----------------------------------------------------------------------===//
 
 LogicalResult MaskedLoadOp::verify() {
-  VectorType maskVType = getMaskVectorType();
-  VectorType passVType = getPassThruVectorType();
-  VectorType resVType = getVectorType();
-  MemRefType memType = getMemRefType();
-
-  if (resVType.getElementType() != memType.getElementType())
-    return emitOpError("base and result element type should match");
-  if (llvm::size(getIndices()) != memType.getRank())
-    return emitOpError("requires ") << memType.getRank() << " indices";
-  if (resVType.getShape() != maskVType.getShape())
-    return emitOpError("expected result shape to match mask shape");
-  if (resVType != passVType)
-    return emitOpError("expected pass_thru of same type as result type");
+  int64_t memRank = getMemRefType().getRank();
+  if (llvm::size(getIndices()) != memRank)
+    return emitOpError("requires ") << memRank << " indices";
   return success();
 }
 
@@ -5181,16 +5171,9 @@ OpFoldResult MaskedLoadOp::fold(FoldAdaptor) {
 //===----------------------------------------------------------------------===//
 
 LogicalResult MaskedStoreOp::verify() {
-  VectorType maskVType = getMaskVectorType();
-  VectorType valueVType = getVectorType();
-  MemRefType memType = getMemRefType();
-
-  if (valueVType.getElementType() != memType.getElementType())
-    return emitOpError("base and valueToStore element type should match");
-  if (llvm::size(getIndices()) != memType.getRank())
-    return emitOpError("requires ") << memType.getRank() << " indices";
-  if (valueVType.getShape() != maskVType.getShape())
-    return emitOpError("expected valueToStore shape to match mask shape");
+  int64_t memRank = getMemRefType().getRank();
+  if (llvm::size(getIndices()) != memRank)
+    return emitOpError("requires ") << memRank << " indices";
   return success();
 }
 
diff --git a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
index 4ae710aa29113..10224aec95d48 100644
--- a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
+++ b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
@@ -88,7 +88,7 @@ func.func @arm_sme_tile_load_hor_with_mask_and_pad_zero(%src : memref<?x?xi32>)
 // CHECK-NEXT:        %[[MASK_1D:.*]] = vector.create_mask %[[MASK_INDEX]] : vector<[4]xi1>
 // CHECK-NEXT:        %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index
 // CHECK:             %[[PAD_1D:.*]] = vector.splat %[[PAD]] : vector<[4]xi32>
-// CHECK:             %[[LOAD_SLICE:.*]] = vector.maskedload %[[SRC]]{{\[}}%[[OFFSET]], %[[C0]]], %[[MASK_1D]], %[[PAD_1D]] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]xi32> into vector<[4]xi32>
+// CHECK:             %[[LOAD_SLICE:.*]] = vector.maskedload %[[SRC]]{{\[}}%[[OFFSET]], %[[C0]]], %[[MASK_1D]], %[[PAD_1D]] : memref<?x?xi32>, vector<[4]xi32>
 // CHECK:             %[[TILE_UPDATE:.*]] = arm_sme.insert_tile_slice %[[LOAD_SLICE]], %[[CURRENT_TILE]][%[[TILE_SLICE_INDEX]]] : vector<[4]xi32> into vector<[4]x[4]xi32>
 // CHECK-NEXT:        scf.yield %[[TILE_UPDATE]] : vector<[4]x[4]xi32>
 func.func @arm_sme_tile_load_hor_with_mask_and_nonzero_pad(%src : memref<?x?xi32>, %pad : i32) {
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
index c3f06dd4d5dd1..6dcae67abda57 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
@@ -1891,7 +1891,7 @@ func.func @store_0d(%memref : memref<200x100xf32>, %i : index, %j : index) {
 
 func.func @masked_load(%arg0: memref<?xf32>, %arg1: vector<16xi1>, %arg2: vector<16xf32>) -> vector<16xf32> {
   %c0 = arith.constant 0: index
-  %0 = vector.maskedload %arg0[%c0], %arg1, %arg2 : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+  %0 = vector.maskedload %arg0[%c0], %arg1, %arg2 : memref<?xf32>, vector<16xf32>
   return %0 : vector<16xf32>
 }
 
@@ -1906,7 +1906,7 @@ func.func @masked_load(%arg0: memref<?xf32>, %arg1: vector<16xi1>, %arg2: vector
 
 func.func @masked_load_scalable(%arg0: memref<?xf32>, %arg1: vector<[16]xi1>, %arg2: vector<[16]xf32>) -> vector<[16]xf32> {
   %c0 = arith.constant 0: index
-  %0 = vector.maskedload %arg0[%c0], %arg1, %arg2 : memref<?xf32>, vector<[16]xi1>, vector<[16]xf32> into vector<[16]xf32>
+  %0 = vector.maskedload %arg0[%c0], %arg1, %arg2 : memref<?xf32>, vector<[16]xf32>
   return %0 : vector<[16]xf32>
 }
 
@@ -1921,7 +1921,7 @@ func.func @masked_load_scalable(%arg0: memref<?xf32>, %arg1: vector<[16]xi1>, %a
 
 func.func @masked_load_index(%arg0: memref<?xindex>, %arg1: vector<16xi1>, %arg2: vector<16xindex>) -> vector<16xindex> {
   %c0 = arith.constant 0: index
-  %0 = vector.maskedload %arg0[%c0], %arg1, %arg2 : memref<?xindex>, vector<16xi1>, vector<16xindex> into vector<16xindex>
+  %0 = vector.maskedload %arg0[%c0], %arg1, %arg2 : memref<?xindex>, vector<16xindex>
   return %0 : vector<16xindex>
 }
 // CHECK-LABEL: func @masked_load_index
@@ -1931,7 +1931,7 @@ func.func @masked_load_index(%arg0: memref<?xindex>, %arg1: vector<16xi1>, %arg2
 
 func.func @masked_load_index_scalable(%arg0: memref<?xindex>, %arg1: vector<[16]xi1>, %arg2: vector<[16]xindex>) -> vector<[16]xindex> {
   %c0 = arith.constant 0: index
-  %0 = vector.maskedload %arg0[%c0], %arg1, %arg2 : memref<?xindex>, vector<[16]xi1>, vector<[16]xindex> into vector<[16]xindex>
+  %0 = vector.maskedload %arg0[%c0], %arg1, %arg2 : memref<?xindex>, vector<[16]xindex>
   return %0 : vector<[16]xindex>
 }
 // CHECK-LABEL: func @masked_load_index_scalable
@@ -1945,7 +1945,7 @@ func.func @masked_load_index_scalable(%arg0: memref<?xindex>, %arg1: vector<[16]
 
 func.func @masked_store(%arg0: memref<?xf32>, %arg1: vector<16xi1>, %arg2: vector<16xf32>) {
   %c0 = arith.constant 0: index
-  vector.maskedstore %arg0[%c0], %arg1, %arg2 : memref<?xf32>, vector<16xi1>, vector<16xf32>
+  vector.maskedstore %arg0[%c0], %arg1, %arg2 : memref<?xf32>, vector<16xf32>
   return
 }
 
@@ -1959,7 +1959,7 @@ func.func @masked_store(%arg0: memref<?xf32>, %arg1: vector<16xi1>, %arg2: vecto
 
 func.func @masked_store_scalable(%arg0: memref<?xf32>, %arg1: vector<[16]xi1>, %arg2: vector<[16]xf32>) {
   %c0 = arith.constant 0: index
-  vector.maskedstore %arg0[%c0], %arg1, %arg2 : memref<?xf32>, vector<[16]xi1>, vector<[16]xf32>
+  vector.maskedstore %arg0[%c0], %arg1, %arg2 : memref<?xf32>, vector<[16]xf32>
   return
 }
 
@@ -1973,7 +1973,7 @@ func.func @masked_store_scalable(%arg0: memref<?xf32>, %arg1: vector<[16]xi1>, %
 
 func.func @masked_store_index(%arg0: memref<?xindex>, %arg1: vector<16xi1>, %arg2: vector<16xindex>) {
   %c0 = arith.constant 0: index
-  vector.maskedstore %arg0[%c0], %arg1, %arg2 : memref<?xindex>, vector<16xi1>, vector<16xindex>
+  vector.maskedstore %arg0[%c0], %arg1, %arg2 : memref<?xindex>, vector<16xindex>
   return
 }
 // CHECK-LABEL: func @masked_store_index
@@ -1983,7 +1983,7 @@ func.func @masked_store_index(%arg0: memref<?xindex>, %arg1: vector<16xi1>, %arg
 
 func.func @masked_store_index_scalable(%arg0: memref<?xindex>, %arg1: vector<[16]xi1>, %arg2: vector<[16]xindex>) {
   %c0 = arith.constant 0: index
-  vector.maskedstore %arg0[%c0], %arg1, %arg2 : memref<?xindex>, vector<[16]xi1>, vector<[16]xindex>
+  vector.maskedstore %arg0[%c0], %arg1, %arg2 : memref<?xindex>, vector<[16]xindex>
   return
 }
 // CHECK-LABEL: func @masked_store_index_scalable
diff --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
index 327cacf7d9a20..7246ae4884a19 100644
--- a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
+++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
@@ -837,7 +837,7 @@ func.func @fold_vector_load_subview(
 func.func @fold_vector_maskedload_subview(
   %arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3: vector<32xi1>, %arg4: vector<32xf32>) -> vector<32xf32> {
   %0 = memref.subview %arg0[%arg1, %arg2][1, 1][1, 1] : memref<12x32xf32> to memref<f32, strided<[], offset: ?>>
-  %1 = vector.maskedload %0[], %arg3, %arg4 : memref<f32, strided<[], offset: ?>>, vector<32xi1>, vector<32xf32> into vector<32xf32>
+  %1 = vector.maskedload %0[], %arg3, %arg4 : memref<f32, strided<[], offset: ?>>, vector<32xf32>
   return %1 : vector<32xf32>
 }
 
@@ -847,7 +847,7 @@ func.func @fold_vector_maskedload_subview(
 // CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: index
 // CHECK-SAME:   %[[ARG3:[a-zA-Z0-9_]+]]: vector<32xi1>
 // CHECK-SAME:   %[[ARG4:[a-zA-Z0-9_]+]]: vector<32xf32>
-//      CHECK:   vector.maskedload %[[ARG0]][%[[ARG1]], %[[ARG2]]], %[[ARG3]], %[[ARG4]] : memref<12x32xf32>, vector<32xi1>, vector<32xf32> into vector<32xf32>
+//      CHECK:   vector.maskedload %[[ARG0]][%[[ARG1]], %[[ARG2]]], %[[ARG3]], %[[ARG4]] : memref<12x32xf32>, vector<32xf32>
 
 // -----
 
@@ -871,7 +871,7 @@ func.func @fold_vector_store_subview(
 func.func @fold_vector_maskedstore_subview(
   %arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3: vector<32xi1>, %arg4: vector<32xf32>) -> () {
   %0 = memref.subview %arg0[%arg1, %arg2][1, 1][1, 1] : memref<12x32xf32> to memref<f32, strided<[], offset: ?>>
-  vector.maskedstore %0[], %arg3, %arg4 : memref<f32, strided<[], offset: ?>>, vector<32xi1>, vector<32xf32>
+  vector.maskedstore %0[], %arg3, %arg4 : memref<f32, strided<[], offset: ?>>, vector<32xf32>
   return
 }
 
@@ -881,7 +881,7 @@ func.func @fold_vector_maskedstore_subview(
 // CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: index
 // CHECK-SAME:   %[[ARG3:[a-zA-Z0-9_]+]]: vector<32xi1>
 // CHECK-SAME:   %[[ARG4:[a-zA-Z0-9_]+]]: vector<32xf32>
-//      CHECK:   vector.maskedstore %[[ARG0]][%[[ARG1]], %[[ARG2]]], %[[ARG3]], %[[ARG4]] : memref<12x32xf32>, vector<32xi1>, vector<32xf32>
+//      CHECK:   vector.maskedstore %[[ARG0]][%[[ARG1]], %[[ARG2]]], %[[ARG3]], %[[ARG4]] : memref<12x32xf32>, vector<32xf32>
 //      CHECK:   return
 
 // -----
@@ -907,7 +907,7 @@ func.func @fold_vector_maskedload_expand_shape(
   %arg0 : memref<32xf32>, %arg1 : index, %arg3: vector<8xi1>, %arg4: vector<8xf32>) -> vector<8xf32> {
   %c0 = arith.constant 0 : index
   %0 = memref.expand_shape %arg0 [[0, 1]] output_shape [4, 8] : memref<32xf32> into memref<4x8xf32>
-  %1 = vector.maskedload %0[%arg1, %c0], %arg3, %arg4 : memref<4x8xf32>, vector<8xi1>, vector<8xf32> into vector<8xf32>
+  %1 = vector.maskedload %0[%arg1, %c0], %arg3, %arg4 : memref<4x8xf32>, vector<8xf32>
   return %1 : vector<8xf32>
 }
 
@@ -943,7 +943,7 @@ func.func @fold_vector_maskedstore_expand_shape(
   %arg0 : memref<32xf32>, %arg1 : index, %arg3: vector<8xi1>, %arg4: vector<8xf32>) {
   %c0 = arith.constant 0 : index
   %0 = memref.expand_shape %arg0 [[0, 1]] output_shape [4, 8] : memref<32xf32> into memref<4x8xf32>
-  vector.maskedstore %0[%arg1, %c0], %arg3, %arg4 : memref<4x8xf32>, vector<8xi1>, vector<8xf32>
+  vector.maskedstore %0[%arg1, %c0], %arg3, %arg4 : memref<4x8xf32>, vector<8xf32>
   return
 }
 
@@ -979,7 +979,7 @@ func.func @fold_vector_load_collapse_shape(
 func.func @fold_vector_maskedload_collapse_shape(
   %arg0 : memref<4x8xf32>, %arg1 : index, %arg3: vector<8xi1>, %arg4: vector<8xf32>) -> vector<8xf32> {
   %0 = memref.collapse_shape %arg0 [[0, 1]] : memref<4x8xf32> into memref<32xf32>
-  %1 = vector.maskedload %0[%arg1], %arg3, %arg4 : memref<32xf32>, vector<8xi1>, vector<8xf32> into vector<8xf32>
+  %1 = vector.maskedload %0[%arg1], %arg3, %arg4 : memref<32xf32>, vector<8xf32>
   return %1 : vector<8xf32>
 }
 
@@ -1017,7 +1017,7 @@ func.func @fold_vector_store_collapse_shape(
 func.func @fold_vector_maskedstore_collapse_shape(
   %arg0 : memref<4x8xf32>, %arg1 : index, %arg3: vector<8xi1>, %arg4: vector<8xf32>) {
   %0 = memref.collapse_shape %arg0 [[0, 1]] : memref<4x8xf32> into memref<32xf32>
-  vector.maskedstore %0[%arg1], %arg3, %arg4 : memref<32xf32>, vector<8xi1>, vector<8xf32>
+  vector.maskedstore %0[%arg1], %arg3, %arg4 : memref<32xf32>, vector<8xf32>
   return
 }
 
diff --git a/mlir/test/Dialect/SparseTensor/sparse_vector.mlir b/mlir/test/Dialect/SparseTensor/sparse_vector.mlir
index 364ba6e71ff3b..c50d44f7faa1e 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_vector.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_vector.mlir
@@ -65,10 +65,10 @@
 // CHECK-VEC4-SVE:       scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[step]] {
 // CHECK-VEC4-SVE:         %[[sub:.*]] = affine.min #[[$map]](%[[c1024]], %[[i]])[%[[step]]]
 // CHECK-VEC4-SVE:         %[[mask:.*]] = vector.create_mask %[[sub]] : vector<[4]xi1>
-// CHECK-VEC4-SVE:         %[[val:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %[[v0]] : memref<?xf32>, vector<[4]xi1>, vector<[4]xf32> into vector<[4]xf32>
+// CHECK-VEC4-SVE:         %[[val:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %[[v0]] : memref<?xf32>, vector<[4]xf32>
 // CHECK-VEC4-SVE:         %[[scalev:.*]] = vector.broadcast %{{.*}} : f32 to vector<[4]xf32>
 // CHECK-VEC4-SVE:         %[[scaled:.*]] = arith.mulf %[[val]], %[[scalev]] : vector<[4]xf32>
-// CHECK-VEC4-SVE:         vector.maskedstore %{{.*}}[%[[i]]], %[[mask]], %[[scaled]] : memref<1024xf32>, vector<[4]xi1>, vector<[4]xf32>
+// CHECK-VEC4-SVE:         vector.maskedstore %{{.*}}[%[[i]]], %[[mask]], %[[scaled]] : memref<1024xf32>, vector<[4]xf32>
 // CHECK-VEC4-SVE:       }
 // CHECK-VEC4-SVE:       return
 //
@@ -136,9 +136,9 @@ func.func @scale_d(%arga: tensor<1024xf32, #DenseVector>, %b: f32, %argx: tensor
 // CHECK-VEC16:       scf.for %[[i:.*]] = %[[q]] to %[[s]] step %[[c16]] {
 // CHECK-VEC16:         %[[sub:.*]] = affine.min #[[$map]](%[[s]], %[[i]])[%[[c16]]]
 // CHECK-VEC16:         %[[mask:.*]] = vector.create_mask %[[sub]] : vector<16xi1>
-// CHECK-VEC16:         %[[li:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref<?xi32>, vector<16xi1>, vector<16xi32> into vector<16xi32>
+// CHECK-VEC16:         %[[li:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref<?xi32>, vector<16xi32>
 // CHECK-VEC16:         %[[zi:.*]] = arith.extui %[[li]] : vector<16xi32> to vector<16xi64>
-// CHECK-VEC16:         %[[la:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+// CHECK-VEC16:         %[[la:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref<?xf32>, vector<16xf32>
 // CHECK-VEC16:         %[[lb:.*]] = vector.gather %{{.*}}[%[[c0]]] [%[[zi]]], %[[mask]], %{{.*}} : memref<1024xf32>, vector<16xi64>, vector<16xi1>, vector<16xf32> into vector<16xf32>
 // CHECK-VEC16:         %[[m:.*]] = arith.mulf %[[la]], %[[lb]] : vector<16xf32>
 // CHECK-VEC16:         vector.scatter %{{.*}}[%[[c0]]] [%[[zi]]], %[[mask]], %[[m]] : memref<1024xf32>, vector<16xi64>, vector<16xi1>, vector<16xf32>
@@ -159,8 +159,8 @@ func.func @scale_d(%arga: tensor<1024xf32, #DenseVector>, %b: f32, %argx: tensor
 // CHECK-VEC16-IDX32:       scf.for %[[i:.*]] = %[[q]] to %[[s]] step %[[c16]] {
 // CHECK-VEC16-IDX32:         %[[sub:.*]] = affine.min #[[$map]](%[[s]], %[[i]])[%[[c16]]]
 // CHECK-VEC16-IDX32:         %[[mask:.*]] = vector.create_mask %[[sub]] : vector<16xi1>
-// CHECK-VEC16-IDX32:         %[[li:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref<?xi32>, vector<16xi1>, vector<16xi32> into vector<16xi32>
-// CHECK-VEC16-IDX32:         %[[la:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+// CHECK-VEC16-IDX32:         %[[li:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref<?xi32>, vector<16xi32>
+// CHECK-VEC16-IDX32:         %[[la:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref<?xf32>, vector<16xf32>
 // CHECK-VEC16-IDX32:         %[[lb:.*]] = vector.gather %{{.*}}[%[[c0]]] [%[[li]]], %[[mask]], %{{.*}} : memref<1024xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
 // CHECK-VEC16-IDX32:         %[[m:.*]] = arith.mulf %[[la]], %[[lb]] : vector<16xf32>
 // CHECK-VEC16-IDX32:         vector.scatter %{{.*}}[%[[c0]]] [%[[li]]], %[[mask]], %[[m]] : memref<1024xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
@@ -185,9 +185,9 @@ func.func @scale_d(%arga: tensor<1024xf32, #DenseVector>, %b: f32, %argx: tensor
 // CHECK-VEC4-SVE:       scf.for %[[i:.*]] = %[[q]] to %[[s]] step %[[step]] {
 // CHECK-VEC4-SVE:         %[[sub:.*]] = affine.min #[[$map]](%[[s]], %[[i]])[%[[step]]]
 // CHECK-VEC4-SVE:         %[[mask:.*]] = vector.create_mask %[[sub]] : vector<[4]xi1>
-// CHECK-VEC4-SVE:         %[[li:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %[[v0i]] : memref<?xi32>, vector<[4]xi1>, vector<[4]xi32> into vector<[4]xi32>
+// CHECK-VEC4-SVE:         %[[li:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %[[v0i]] : memref<?xi32>, vector<[4]xi32>
 // CHECK-VEC4-SVE:         %[[lii64:.*]] = arith.extui %[[li]] : vector<[4]xi32> to vector<[4]xi64>
-// CHECK-VEC4-SVE:         %[[la:.*]] = vector.maskedload %{{.*...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/131482


More information about the Mlir-commits mailing list