[Mlir-commits] [mlir] 3937991 - [mlir] [VectorOps] Add masked load/store operations to Vector dialect

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Aug 5 16:46:36 PDT 2020


Author: aartbik
Date: 2020-08-05T16:45:24-07:00
New Revision: 39379916a7f01d907562c1b70114568dac1778a2

URL: https://github.com/llvm/llvm-project/commit/39379916a7f01d907562c1b70114568dac1778a2
DIFF: https://github.com/llvm/llvm-project/commit/39379916a7f01d907562c1b70114568dac1778a2.diff

LOG: [mlir] [VectorOps] Add masked load/store operations to Vector dialect

The intrinsics were already supported and vector.transfer_read/write lowered
directly into these operations. By providing them as individual ops, however,
clients can use them directly, and it opens up progressive lowering of transfer
operations at higher levels (rather than direct lowering to LLVM IR as done now).

Reviewed By: bkramer

Differential Revision: https://reviews.llvm.org/D85357

Added: 
    mlir/integration_test/Dialect/Vector/CPU/test-maskedload.mlir
    mlir/integration_test/Dialect/Vector/CPU/test-maskedstore.mlir

Modified: 
    mlir/include/mlir/Dialect/Vector/VectorOps.td
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
    mlir/lib/Dialect/Vector/VectorOps.cpp
    mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
    mlir/test/Dialect/Vector/invalid.mlir
    mlir/test/Dialect/Vector/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
index 89a2b1226e1e..4f98fd97df48 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -1150,6 +1150,102 @@ def Vector_TransferWriteOp :
   let hasFolder = 1;
 }
 
+def Vector_MaskedLoadOp :
+  Vector_Op<"maskedload">,
+    Arguments<(ins AnyMemRef:$base,
+               VectorOfRankAndType<[1], [I1]>:$mask,
+               VectorOfRank<[1]>:$pass_thru)>,
+    Results<(outs VectorOfRank<[1]>:$result)> {
+
+  let summary = "loads elements from memory into a vector as defined by a mask vector";
+
+  let description = [{
+    The masked load reads elements from memory into a 1-D vector as defined
+    by a base and a 1-D mask vector. When the mask is set, the element is read
+    from memory. Otherwise, the corresponding element is taken from a 1-D
+    pass-through vector. Informally, the semantics are:
+    ```
+    result[0] := mask[0] ? MEM[base+0] : pass_thru[0]
+    result[1] := mask[1] ? MEM[base+1] : pass_thru[1]
+    etc.
+    ```
+    The result element type must match the element type of the base memref,
+    the mask must have the same length as the result, and the pass-through
+    vector must have exactly the same type as the result.
+
+    The masked load can be used directly where applicable, or can be used
+    during progressive lowering to bring other memory operations closer to
+    hardware ISA support for a masked load. The semantics of the operation
+    closely correspond to those of the `llvm.masked.load`
+    [intrinsic](https://llvm.org/docs/LangRef.html#llvm-masked-load-intrinsics).
+
+    Example:
+
+    ```mlir
+    %0 = vector.maskedload %base, %mask, %pass_thru
+      : memref<?xf32>, vector<8xi1>, vector<8xf32> into vector<8xf32>
+    ```
+  }];
+  let extraClassDeclaration = [{
+    MemRefType getMemRefType() {
+      return base().getType().cast<MemRefType>();
+    }
+    VectorType getMaskVectorType() {
+      return mask().getType().cast<VectorType>();
+    }
+    VectorType getPassThruVectorType() {
+      return pass_thru().getType().cast<VectorType>();
+    }
+    VectorType getResultVectorType() {
+      return result().getType().cast<VectorType>();
+    }
+  }];
+  let assemblyFormat = "$base `,` $mask `,` $pass_thru attr-dict `:` "
+    "type($base) `,` type($mask) `,` type($pass_thru) `into` type($result)";
+}
+
+def Vector_MaskedStoreOp :
+  Vector_Op<"maskedstore">,
+    Arguments<(ins AnyMemRef:$base,
+               VectorOfRankAndType<[1], [I1]>:$mask,
+               VectorOfRank<[1]>:$value)> {
+
+  let summary = "stores elements from a vector into memory as defined by a mask vector";
+
+  let description = [{
+    The masked store operation writes elements from a 1-D vector into memory
+    as defined by a base and a 1-D mask vector. When the mask is set, the
+    corresponding element from the vector is written to memory. Otherwise,
+    no action is taken for the element. Informally, the semantics are:
+    ```
+    if (mask[0]) MEM[base+0] = value[0]
+    if (mask[1]) MEM[base+1] = value[1]
+    etc.
+    ```
+    The element type of the value vector must match the element type of the
+    base memref, and the mask must have the same length as the value vector.
+
+    The masked store can be used directly where applicable, or can be used
+    during progressive lowering to bring other memory operations closer to
+    hardware ISA support for a masked store. The semantics of the operation
+    closely correspond to those of the `llvm.masked.store`
+    [intrinsic](https://llvm.org/docs/LangRef.html#llvm-masked-store-intrinsics).
+
+    Example:
+
+    ```mlir
+    vector.maskedstore %base, %mask, %value
+      : memref<?xf32>, vector<8xi1>, vector<8xf32>
+    ```
+  }];
+  let extraClassDeclaration = [{
+    MemRefType getMemRefType() {
+      return base().getType().cast<MemRefType>();
+    }
+    VectorType getMaskVectorType() {
+      return mask().getType().cast<VectorType>();
+    }
+    VectorType getValueVectorType() {
+      return value().getType().cast<VectorType>();
+    }
+  }];
+  let assemblyFormat = "$base `,` $mask `,` $value attr-dict `:` "
+    "type($mask) `,` type($value) `into` type($base)";
+}
+
 def Vector_GatherOp :
   Vector_Op<"gather">,
     Arguments<(ins AnyMemRef:$base,

diff  --git a/mlir/integration_test/Dialect/Vector/CPU/test-maskedload.mlir b/mlir/integration_test/Dialect/Vector/CPU/test-maskedload.mlir
new file mode 100644
index 000000000000..6c6f6ead005f
--- /dev/null
+++ b/mlir/integration_test/Dialect/Vector/CPU/test-maskedload.mlir
@@ -0,0 +1,66 @@
+// RUN: mlir-opt %s -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
+// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+// Thin wrapper around vector.maskedload so the driver can call it with
+// different masks: lanes whose mask bit is set are read from %base, the
+// remaining lanes are taken from %pass_thru.
+func @maskedload16(%base: memref<?xf32>, %mask: vector<16xi1>,
+                   %pass_thru: vector<16xf32>) -> vector<16xf32> {
+  %ld = vector.maskedload %base, %mask, %pass_thru
+    : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+  return %ld : vector<16xf32>
+}
+
+// Test driver: fills %A with 0..15, then performs masked loads under four
+// different masks with a pass-through vector of all -7, printing each result.
+func @entry() {
+  // Set up memory.
+  %c0 = constant 0: index
+  %c1 = constant 1: index
+  %c16 = constant 16: index
+  %A = alloc(%c16) : memref<?xf32>
+  scf.for %i = %c0 to %c16 step %c1 {
+    %i32 = index_cast %i : index to i32
+    %fi = sitofp %i32 : i32 to f32
+    store %fi, %A[%i] : memref<?xf32>
+  }
+
+  // Set up pass thru vector.
+  %u = constant -7.0: f32
+  %pass = vector.broadcast %u : f32 to vector<16xf32>
+
+  // Set up masks.
+  %f = constant 0: i1
+  %t = constant 1: i1
+  %none = vector.constant_mask [0] : vector<16xi1>
+  %all = vector.constant_mask [16] : vector<16xi1>
+  %some = vector.constant_mask [8] : vector<16xi1>
+  // %other: start from %some (lanes 0-7), clear lane 0, set lanes 13 and 14.
+  %0 = vector.insert %f, %some[0] : i1 into vector<16xi1>
+  %1 = vector.insert %t, %0[13] : i1 into vector<16xi1>
+  %other = vector.insert %t, %1[14] : i1 into vector<16xi1>
+
+  //
+  // Masked load tests.
+  //
+
+  %l1 = call @maskedload16(%A, %none, %pass)
+    : (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> (vector<16xf32>)
+  vector.print %l1 : vector<16xf32>
+  // CHECK: ( -7, -7, -7, -7, -7, -7, -7, -7, -7, -7, -7, -7, -7, -7, -7, -7 )
+
+  %l2 = call @maskedload16(%A, %all, %pass)
+    : (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> (vector<16xf32>)
+  vector.print %l2 : vector<16xf32>
+  // CHECK: ( 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 )
+
+  %l3 = call @maskedload16(%A, %some, %pass)
+    : (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> (vector<16xf32>)
+  vector.print %l3 : vector<16xf32>
+  // CHECK: ( 0, 1, 2, 3, 4, 5, 6, 7, -7, -7, -7, -7, -7, -7, -7, -7 )
+
+  %l4 = call @maskedload16(%A, %other, %pass)
+    : (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> (vector<16xf32>)
+  vector.print %l4 : vector<16xf32>
+  // CHECK: ( -7, 1, 2, 3, 4, 5, 6, 7, -7, -7, -7, -7, -7, 13, 14, -7 )
+
+  // Release the buffer allocated above.
+  dealloc %A : memref<?xf32>
+  return
+}
+

diff  --git a/mlir/integration_test/Dialect/Vector/CPU/test-maskedstore.mlir b/mlir/integration_test/Dialect/Vector/CPU/test-maskedstore.mlir
new file mode 100644
index 000000000000..d0132f61c3e9
--- /dev/null
+++ b/mlir/integration_test/Dialect/Vector/CPU/test-maskedstore.mlir
@@ -0,0 +1,89 @@
+// RUN: mlir-opt %s -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
+// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+// Thin wrapper around vector.maskedstore: lane i of %value is written to
+// memory only where mask bit i is set; other elements are left untouched.
+func @maskedstore16(%base: memref<?xf32>,
+                    %mask: vector<16xi1>, %value: vector<16xf32>) {
+  vector.maskedstore %base, %mask, %value
+    : vector<16xi1>, vector<16xf32> into memref<?xf32>
+  return
+}
+
+// Prints %A[0..15] as a single vector<16xf32>. Each element is loaded in an
+// scf.for loop and inserted into lane i of a zero-initialized accumulator,
+// which is threaded through the loop via iter_args.
+func @printmem16(%A: memref<?xf32>) {
+  %c0 = constant 0: index
+  %c1 = constant 1: index
+  %c16 = constant 16: index
+  %z = constant 0.0: f32
+  %m = vector.broadcast %z : f32 to vector<16xf32>
+  %mem = scf.for %i = %c0 to %c16 step %c1
+    iter_args(%m_iter = %m) -> (vector<16xf32>) {
+    %c = load %A[%i] : memref<?xf32>
+    %i32 = index_cast %i : index to i32
+    %m_new = vector.insertelement %c, %m_iter[%i32 : i32] : vector<16xf32>
+    scf.yield %m_new : vector<16xf32>
+  }
+  vector.print %mem : vector<16xf32>
+  return
+}
+
+// Test driver: zero-fills %A, builds %val = (0, 1, ..., 15), then applies
+// masked stores with increasingly permissive masks and prints memory after
+// each one. Stores accumulate: memory is not reset between calls.
+func @entry() {
+  // Set up memory.
+  %f0 = constant 0.0: f32
+  %c0 = constant 0: index
+  %c1 = constant 1: index
+  %c16 = constant 16: index
+  %A = alloc(%c16) : memref<?xf32>
+  scf.for %i = %c0 to %c16 step %c1 {
+    store %f0, %A[%i] : memref<?xf32>
+  }
+
+  // Set up value vector.
+  %v = vector.broadcast %f0 : f32 to vector<16xf32>
+  %val = scf.for %i = %c0 to %c16 step %c1
+    iter_args(%v_iter = %v) -> (vector<16xf32>) {
+    %i32 = index_cast %i : index to i32
+    %fi = sitofp %i32 : i32 to f32
+    %v_new = vector.insertelement %fi, %v_iter[%i32 : i32] : vector<16xf32>
+    scf.yield %v_new : vector<16xf32>
+  }
+
+  // Set up masks.
+  %t = constant 1: i1
+  %none = vector.constant_mask [0] : vector<16xi1>
+  %some = vector.constant_mask [8] : vector<16xi1>
+  // %more: lanes 0-7 of %some plus lane 13.
+  %more = vector.insert %t, %some[13] : i1 into vector<16xi1>
+  %all = vector.constant_mask [16] : vector<16xi1>
+
+  //
+  // Masked store tests.
+  //
+
+  vector.print %val : vector<16xf32>
+  // CHECK: ( 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 )
+
+  call @printmem16(%A) : (memref<?xf32>) -> ()
+  // CHECK: ( 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 )
+
+  call @maskedstore16(%A, %none, %val)
+    : (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> ()
+  call @printmem16(%A) : (memref<?xf32>) -> ()
+  // CHECK: ( 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 )
+
+  call @maskedstore16(%A, %some, %val)
+    : (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> ()
+  call @printmem16(%A) : (memref<?xf32>) -> ()
+  // CHECK: ( 0, 1, 2, 3, 4, 5, 6, 7, 0, 0, 0, 0, 0, 0, 0, 0 )
+
+  call @maskedstore16(%A, %more, %val)
+    : (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> ()
+  call @printmem16(%A) : (memref<?xf32>) -> ()
+  // CHECK: ( 0, 1, 2, 3, 4, 5, 6, 7, 0, 0, 0, 0, 0, 13, 0, 0 )
+
+  call @maskedstore16(%A, %all, %val)
+    : (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> ()
+  call @printmem16(%A) : (memref<?xf32>) -> ()
+  // CHECK: ( 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 )
+
+  // Release the buffer allocated above.
+  dealloc %A : memref<?xf32>
+  return
+}

diff  --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 8f5ec9e785dc..64d38e4cc293 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -163,6 +163,19 @@ LogicalResult getBasePtr(ConversionPatternRewriter &rewriter, Location loc,
   return success();
 }
 
+// Helper that returns a bit-casted pointer given a memref base.
+// On success, `ptr` holds the memref's base pointer bitcast to a pointer to
+// `type`. Returns failure if `type` is null or not an LLVM type (e.g. because
+// an upstream type conversion failed), or if the base cannot be extracted.
+LogicalResult getBasePtr(ConversionPatternRewriter &rewriter, Location loc,
+                         Value memref, MemRefType memRefType, Type type,
+                         Value &ptr) {
+  // Validate the element type before emitting any IR, instead of asserting
+  // inside cast<> on a null or non-LLVM type.
+  auto llvmType = type.dyn_cast_or_null<LLVM::LLVMType>();
+  if (!llvmType)
+    return failure();
+  Value base;
+  if (failed(getBase(rewriter, loc, memref, memRefType, base)))
+    return failure();
+  auto pType = llvmType.getPointerTo();
+  base = rewriter.create<LLVM::BitcastOp>(loc, pType, base);
+  // A GEP with no indices yields the same address as its operand; the
+  // VectorToLLVM lowering tests check for this llvm.getelementptr.
+  ptr = rewriter.create<LLVM::GEPOp>(loc, pType, base);
+  return success();
+}
+
 // Helper that returns vector of pointers given a memref base and an index
 // vector.
 LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter, Location loc,
@@ -298,6 +311,72 @@ class VectorFlatTransposeOpConversion : public ConvertToLLVMPattern {
   }
 };
 
+/// Conversion pattern for a vector.maskedload.
+/// Lowers the op to llvm.intr.masked.load on the memref base pointer (bitcast
+/// to a pointer to the converted result vector type), forwarding the converted
+/// mask and pass_thru operands and the alignment resolved for the memref.
+class VectorMaskedLoadOpConversion : public ConvertToLLVMPattern {
+public:
+  explicit VectorMaskedLoadOpConversion(MLIRContext *context,
+                                        LLVMTypeConverter &typeConverter)
+      : ConvertToLLVMPattern(vector::MaskedLoadOp::getOperationName(), context,
+                             typeConverter) {}
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op->getLoc();
+    auto load = cast<vector::MaskedLoadOp>(op);
+    auto adaptor = vector::MaskedLoadOpAdaptor(operands);
+
+    // Resolve alignment.
+    unsigned align;
+    if (failed(getMemRefAlignment(typeConverter, load, align)))
+      return failure();
+
+    // convertType returns a null type when no conversion exists; fail the
+    // pattern rather than handing a null type to getBasePtr.
+    auto vtype = typeConverter.convertType(load.getResultVectorType());
+    if (!vtype)
+      return failure();
+    Value ptr;
+    if (failed(getBasePtr(rewriter, loc, adaptor.base(), load.getMemRefType(),
+                          vtype, ptr)))
+      return failure();
+
+    rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
+        load, vtype, ptr, adaptor.mask(), adaptor.pass_thru(),
+        rewriter.getI32IntegerAttr(align));
+    return success();
+  }
+};
+
+/// Conversion pattern for a vector.maskedstore.
+/// Lowers the op to llvm.intr.masked.store of the converted value operand
+/// through the memref base pointer (bitcast to a pointer to the converted
+/// value vector type), using the converted mask and the memref's alignment.
+class VectorMaskedStoreOpConversion : public ConvertToLLVMPattern {
+public:
+  explicit VectorMaskedStoreOpConversion(MLIRContext *context,
+                                         LLVMTypeConverter &typeConverter)
+      : ConvertToLLVMPattern(vector::MaskedStoreOp::getOperationName(), context,
+                             typeConverter) {}
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op->getLoc();
+    auto store = cast<vector::MaskedStoreOp>(op);
+    auto adaptor = vector::MaskedStoreOpAdaptor(operands);
+
+    // Resolve alignment.
+    unsigned align;
+    if (failed(getMemRefAlignment(typeConverter, store, align)))
+      return failure();
+
+    // convertType returns a null type when no conversion exists; fail the
+    // pattern rather than handing a null type to getBasePtr.
+    auto vtype = typeConverter.convertType(store.getValueVectorType());
+    if (!vtype)
+      return failure();
+    Value ptr;
+    if (failed(getBasePtr(rewriter, loc, adaptor.base(), store.getMemRefType(),
+                          vtype, ptr)))
+      return failure();
+
+    rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
+        store, adaptor.value(), ptr, adaptor.mask(),
+        rewriter.getI32IntegerAttr(align));
+    return success();
+  }
+};
+
 /// Conversion pattern for a vector.gather.
 class VectorGatherOpConversion : public ConvertToLLVMPattern {
 public:
@@ -1342,6 +1421,8 @@ void mlir::populateVectorToLLVMConversionPatterns(
               VectorTransferConversion<TransferReadOp>,
               VectorTransferConversion<TransferWriteOp>,
               VectorTypeCastOpConversion,
+              VectorMaskedLoadOpConversion,
+              VectorMaskedStoreOpConversion,
               VectorGatherOpConversion,
               VectorScatterOpConversion,
               VectorExpandLoadOpConversion,

diff  --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 9e64ff9af80a..e04091e8574f 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -1855,6 +1855,41 @@ Optional<SmallVector<int64_t, 4>> TransferWriteOp::getShapeForUnroll() {
   return llvm::to_vector<4>(getVectorType().getShape());
 }
 
+//===----------------------------------------------------------------------===//
+// MaskedLoadOp
+//===----------------------------------------------------------------------===//
+
+/// Checks that a vector.maskedload is well formed: the result element type
+/// agrees with the memref, there is one mask bit per result lane, and the
+/// pass-through vector has exactly the result type.
+static LogicalResult verify(MaskedLoadOp op) {
+  VectorType resultType = op.getResultVectorType();
+
+  // Loaded lanes come straight from memory, so element types must agree.
+  if (op.getMemRefType().getElementType() != resultType.getElementType())
+    return op.emitOpError("base and result element type should match");
+
+  // One mask bit per result lane.
+  if (op.getMaskVectorType().getDimSize(0) != resultType.getDimSize(0))
+    return op.emitOpError("expected result dim to match mask dim");
+  // Disabled lanes are copied from pass_thru, so it must match the result.
+  if (op.getPassThruVectorType() != resultType)
+    return op.emitOpError("expected pass_thru of same type as result type");
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// MaskedStoreOp
+//===----------------------------------------------------------------------===//
+
+/// Checks that a vector.maskedstore is well formed: the value element type
+/// agrees with the memref and there is one mask bit per value lane.
+static LogicalResult verify(MaskedStoreOp op) {
+  VectorType valueType = op.getValueVectorType();
+
+  // Stored lanes go straight to memory, so element types must agree.
+  if (op.getMemRefType().getElementType() != valueType.getElementType())
+    return op.emitOpError("base and value element type should match");
+
+  // One mask bit per stored lane.
+  if (op.getMaskVectorType().getDimSize(0) != valueType.getDimSize(0))
+    return op.emitOpError("expected value dim to match mask dim");
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // GatherOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index be70c08bc948..5254d2eef4bf 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -970,6 +970,26 @@ func @flat_transpose(%arg0: vector<16xf32>) -> vector<16xf32> {
 // CHECK-SAME:      !llvm.vec<16 x float> into !llvm.vec<16 x float>
 // CHECK:       llvm.return %[[T]] : !llvm.vec<16 x float>
 
+// vector.maskedload: the memref base is bitcast to a vector pointer and fed,
+// with the mask and pass-through operands, to llvm.intr.masked.load.
+func @masked_load_op(%arg0: memref<?xf32>, %arg1: vector<16xi1>, %arg2: vector<16xf32>) -> vector<16xf32> {
+  %0 = vector.maskedload %arg0, %arg1, %arg2 : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+  return %0 : vector<16xf32>
+}
+
+// CHECK-LABEL: func @masked_load_op
+// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[] : (!llvm.ptr<vec<16 x float>>) -> !llvm.ptr<vec<16 x float>>
+// CHECK: %[[L:.*]] = llvm.intr.masked.load %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.ptr<vec<16 x float>>, !llvm.vec<16 x i1>, !llvm.vec<16 x float>) -> !llvm.vec<16 x float>
+// CHECK: llvm.return %[[L]] : !llvm.vec<16 x float>
+
+// vector.maskedstore: same pointer computation, lowered to
+// llvm.intr.masked.store of the value operand under the mask.
+func @masked_store_op(%arg0: memref<?xf32>, %arg1: vector<16xi1>, %arg2: vector<16xf32>) {
+  vector.maskedstore %arg0, %arg1, %arg2 : vector<16xi1>, vector<16xf32> into memref<?xf32>
+  return
+}
+
+// CHECK-LABEL: func @masked_store_op
+// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[] : (!llvm.ptr<vec<16 x float>>) -> !llvm.ptr<vec<16 x float>>
+// CHECK: llvm.intr.masked.store %{{.*}}, %[[P]], %{{.*}} {alignment = 4 : i32} : !llvm.vec<16 x float>, !llvm.vec<16 x i1> into !llvm.ptr<vec<16 x float>>
+// CHECK: llvm.return
+
 func @gather_op(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xf32>) -> vector<3xf32> {
   %0 = vector.gather %arg0, %arg1, %arg2, %arg3 : (memref<?xf32>, vector<3xi32>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
   return %0 : vector<3xf32>

diff  --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 651fe27cd36c..e1d03ca480f4 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1180,6 +1180,41 @@ func @type_cast_layout(%arg0: memref<4x3xf32, affine_map<(d0, d1)[s0, s1, s2] ->
 
 // -----
 
+// Verifier diagnostics for vector.maskedload and vector.maskedstore, one
+// malformed op per split-input chunk.
+func @maskedload_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16xi1>, %pass: vector<16xf32>) {
+  // expected-error@+1 {{'vector.maskedload' op base and result element type should match}}
+  %0 = vector.maskedload %base, %mask, %pass : memref<?xf64>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+}
+
+// -----
+
+func @maskedload_dim_mask_mismatch(%base: memref<?xf32>, %mask: vector<15xi1>, %pass: vector<16xf32>) {
+  // expected-error@+1 {{'vector.maskedload' op expected result dim to match mask dim}}
+  %0 = vector.maskedload %base, %mask, %pass : memref<?xf32>, vector<15xi1>, vector<16xf32> into vector<16xf32>
+}
+
+// -----
+
+func @maskedload_pass_thru_type_mask_mismatch(%base: memref<?xf32>, %mask: vector<16xi1>, %pass: vector<16xi32>) {
+  // expected-error@+1 {{'vector.maskedload' op expected pass_thru of same type as result type}}
+  %0 = vector.maskedload %base, %mask, %pass : memref<?xf32>, vector<16xi1>, vector<16xi32> into vector<16xf32>
+}
+
+// -----
+
+func @maskedstore_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16xi1>, %value: vector<16xf32>) {
+  // expected-error@+1 {{'vector.maskedstore' op base and value element type should match}}
+  vector.maskedstore %base, %mask, %value : vector<16xi1>, vector<16xf32> into memref<?xf64>
+}
+
+// -----
+
+func @maskedstore_dim_mask_mismatch(%base: memref<?xf32>, %mask: vector<15xi1>, %value: vector<16xf32>) {
+  // expected-error@+1 {{'vector.maskedstore' op expected value dim to match mask dim}}
+  vector.maskedstore %base, %mask, %value : vector<15xi1>, vector<16xf32> into memref<?xf32>
+}
+
+// -----
+
 func @gather_base_type_mismatch(%base: memref<?xf64>, %indices: vector<16xi32>, %mask: vector<16xi1>) {
   // expected-error at +1 {{'vector.gather' op base and result element type should match}}
   %0 = vector.gather %base, %indices, %mask : (memref<?xf64>, vector<16xi32>, vector<16xi1>) -> vector<16xf32>

diff  --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index d4d1abe8e646..0381c88cc247 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -369,6 +369,15 @@ func @flat_transpose_int(%arg0: vector<16xi32>) -> vector<16xi32> {
   return %0 : vector<16xi32>
 }
 
+// Parser/printer test: the expected output shows both ops printed in the same
+// form they are written, with the load result feeding the store.
+// CHECK-LABEL: @masked_load_and_store
+func @masked_load_and_store(%base: memref<?xf32>, %mask: vector<16xi1>, %passthru: vector<16xf32>) {
+  // CHECK: %[[X:.*]] = vector.maskedload %{{.*}}, %{{.*}}, %{{.*}} : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+  %0 = vector.maskedload %base, %mask, %passthru : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+  // CHECK: vector.maskedstore %{{.*}}, %{{.*}}, %[[X]] : vector<16xi1>, vector<16xf32> into memref<?xf32>
+  vector.maskedstore %base, %mask, %0 : vector<16xi1>, vector<16xf32> into memref<?xf32>
+  return
+}
+
 // CHECK-LABEL: @gather_and_scatter
 func @gather_and_scatter(%base: memref<?xf32>, %indices: vector<16xi32>, %mask: vector<16xi1>) {
   // CHECK: %[[X:.*]] = vector.gather %{{.*}}, %{{.*}}, %{{.*}} : (memref<?xf32>, vector<16xi32>, vector<16xi1>) -> vector<16xf32>


        


More information about the Mlir-commits mailing list