[Mlir-commits] [mlir] 2d10f81 - [mlir][Vector] Introduce 'vector.mask' operation and MaskableOpInterface

Diego Caballero llvmlistbot at llvm.org
Mon Oct 10 14:43:59 PDT 2022


Author: Diego Caballero
Date: 2022-10-10T21:25:43Z
New Revision: 2d10f81d461937d220c02f0c272fe3b1232db3a5

URL: https://github.com/llvm/llvm-project/commit/2d10f81d461937d220c02f0c272fe3b1232db3a5
DIFF: https://github.com/llvm/llvm-project/commit/2d10f81d461937d220c02f0c272fe3b1232db3a5.diff

LOG: [mlir][Vector] Introduce 'vector.mask' operation and MaskableOpInterface

This patch introduces the `vector.mask` operation and the MaskableOpInterface
as described in https://discourse.llvm.org/t/rfc-vector-masking-representation-in-mlir/64964.
The `vector.mask` operation is used to predicate the execution of operations
implementing the MaskableOpInterface. This interface will be implemented by maskable
operations and provides information about their masking constraints and semantics.

For now, only vector transfer and reduction ops implement the MaskableOpInterface
for illustration and testing purposes.

Reviewed By: nicolasvasilache, rriddle

Differential Revision: https://reviews.llvm.org/D134939

Added: 
    mlir/include/mlir/Dialect/Vector/Interfaces/CMakeLists.txt
    mlir/include/mlir/Dialect/Vector/Interfaces/MaskingInterfaces.h
    mlir/include/mlir/Dialect/Vector/Interfaces/MaskingInterfaces.td
    mlir/lib/Dialect/Vector/Interfaces/CMakeLists.txt
    mlir/lib/Dialect/Vector/Interfaces/MaskingInterfaces.cpp

Modified: 
    mlir/include/mlir/Dialect/Vector/CMakeLists.txt
    mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
    mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
    mlir/lib/Dialect/Vector/CMakeLists.txt
    mlir/lib/Dialect/Vector/IR/CMakeLists.txt
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/lib/Interfaces/CMakeLists.txt
    mlir/test/Dialect/Vector/invalid.mlir
    mlir/test/Dialect/Vector/ops.mlir
    utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/CMakeLists.txt b/mlir/include/mlir/Dialect/Vector/CMakeLists.txt
index 9f57627c321fb..0e2b2af11346f 100644
--- a/mlir/include/mlir/Dialect/Vector/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Vector/CMakeLists.txt
@@ -1,2 +1,3 @@
 add_subdirectory(IR)
+add_subdirectory(Interfaces)
 add_subdirectory(Transforms)

diff  --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
index 6f1f361005904..661affd953bd6 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
@@ -13,6 +13,7 @@
 #ifndef MLIR_DIALECT_VECTOR_IR_VECTOROPS_H
 #define MLIR_DIALECT_VECTOR_IR_VECTOROPS_H
 
+#include "mlir/Dialect/Vector/Interfaces/MaskingInterfaces.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/BuiltinTypes.h"
@@ -49,6 +50,10 @@ namespace detail {
 struct BitmaskEnumStorage;
 } // namespace detail
 
+/// Default callback to build a region with a 'vector.yield' terminator with no
+/// arguments.
+void buildTerminatedBody(OpBuilder &builder, Location loc);
+
 /// Return whether `srcType` can be broadcast to `dstVectorType` under the
 /// semantics of the `vector.broadcast` op.
 enum class BroadcastableToResult {

diff  --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 575dfbb9a2114..1019ffe52a371 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -13,6 +13,7 @@
 #ifndef VECTOR_OPS
 #define VECTOR_OPS
 
+include "mlir/Dialect/Vector/Interfaces/MaskingInterfaces.td"
 include "mlir/IR/EnumAttr.td"
 include "mlir/Interfaces/ControlFlowInterfaces.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
@@ -283,6 +284,7 @@ def Vector_ReductionOp :
   Vector_Op<"reduction", [NoSideEffect,
      PredOpTrait<"source operand and result have same element type",
                  TCresVTEtIsSameAsOpBase<0, 0>>,
+     DeclareOpInterfaceMethods<MaskableOpInterface>,
      DeclareOpInterfaceMethods<VectorUnrollOpInterface,
                                ["getShapeForUnroll"]>]>,
     Arguments<(ins Vector_CombiningKindAttr:$kind,
@@ -360,7 +362,7 @@ def Vector_MultiDimReductionOp :
   }];
   let builders = [
     OpBuilder<(ins "Value":$source, "Value":$acc,
-                   "ArrayRef<bool>":$reductionMask, "CombiningKind":$kind)>                   
+                   "ArrayRef<bool>":$reductionMask, "CombiningKind":$kind)>
   ];
   let extraClassDeclaration = [{
     static StringRef getKindAttrStrName() { return "kind"; }
@@ -1050,6 +1052,7 @@ def Vector_TransferReadOp :
   Vector_Op<"transfer_read", [
       DeclareOpInterfaceMethods<VectorTransferOpInterface>,
       DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
+      DeclareOpInterfaceMethods<MaskableOpInterface>,
       DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
       AttrSizedOperandSegments
     ]>,
@@ -1246,6 +1249,12 @@ def Vector_TransferReadOp :
                    "ValueRange":$indices,
                    CArg<"Optional<ArrayRef<bool>>", "::llvm::None">:$inBounds)>,
   ];
+
+  let extraClassDeclaration = [{
+    // MaskableOpInterface methods.
+    bool supportsPassthru() { return true; }
+  }];
+
   let hasCanonicalizer = 1;
   let hasCustomAssemblyFormat = 1;
   let hasFolder = 1;
@@ -1256,6 +1265,7 @@ def Vector_TransferWriteOp :
   Vector_Op<"transfer_write", [
       DeclareOpInterfaceMethods<VectorTransferOpInterface>,
       DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
+      DeclareOpInterfaceMethods<MaskableOpInterface>,
       DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
       AttrSizedOperandSegments
   ]>,
@@ -2120,6 +2130,78 @@ def Vector_CreateMaskOp :
   let assemblyFormat = "$operands attr-dict `:` type(results)";
 }
 
+def Vector_MaskOp : Vector_Op<"mask", [
+  SingleBlockImplicitTerminator<"vector::YieldOp">, RecursiveSideEffects,
+  NoRegionArguments
+]> {
+  let summary = "Predicates a maskable vector operation";
+  let description = [{
+    The `vector.mask` operation predicates the execution of another operation.
+    It takes an `i1` vector mask and an optional pass-thru vector as arguments.
+    A `vector.yield`-terminated region encloses the operation to be masked.
+    Values used within the region are captured from above. Only one *maskable*
+    operation can be masked with a `vector.mask` operation at a time. An
+    operation is *maskable* if it implements the `MaskableOpInterface`.
+
+    The vector mask argument holds a bit for each vector lane and determines
+    which vector lanes should execute the maskable operation and which ones
+    should not. The `vector.mask` operation returns the value produced by the
+    masked execution of the nested operation, if any. The masked-off lanes in
+    the result vector are taken from the corresponding lanes of the pass-thru
+    argument, if provided, or left unmodified, otherwise.
+
+    The `vector.mask` operation does not prescribe how a maskable operation
+    should be masked or how a masked operation should be lowered. Masking
+    constraints and some semantic details are provided by each maskable
+    operation through the `MaskableOpInterface`. Lowering of masked operations
+    is implementation defined. For instance, scalarizing the masked operation
+    or executing the operation for the masked-off lanes are valid lowerings as
+    long as the execution of masked-off lanes does not change the observable
+    behavior of the program.
+
+    Examples:
+
+    ```
+      %0 = vector.mask %mask { vector.reduction <add>, %a : vector<8xi32> into i32 } : vector<8xi1> -> i32
+    ```
+
+    ```
+      %0 = vector.mask %mask, %passthru { arith.divsi %a, %b : vector<8xi32> } : vector<8xi1> -> vector<8xi32>
+    ```
+
+    ```
+      vector.mask %mask { vector.transfer_write %val, %t0[%idx] : vector<16xf32>, memref<?xf32> } : vector<16xi1>
+    ```
+  }];
+
+  // TODO: Support multiple results and passthru values.
+  let arguments = (ins VectorOf<[I1]>:$mask,
+                   Optional<AnyType>:$passthru);
+  let results = (outs Optional<AnyType>:$results);
+  let regions = (region SizedRegion<1>:$maskRegion);
+
+  let skipDefaultBuilders = 1;
+  let builders = [
+    OpBuilder<(ins "Value":$mask,
+                   CArg<"function_ref<void(OpBuilder &, Location)>",
+                        "buildTerminatedBody">:$maskRegion)>,
+    OpBuilder<(ins "Type":$resultType, "Value":$mask,
+                   CArg<"function_ref<void(OpBuilder &, Location)>",
+                        "buildTerminatedBody">:$maskRegion)>,
+    OpBuilder<(ins "Type":$resultType, "Value":$mask,
+                   "Value":$passthru,
+                   CArg<"function_ref<void(OpBuilder &, Location)>",
+                        "buildTerminatedBody">:$maskRegion)>
+  ];
+
+  let extraClassDeclaration = [{
+    static void ensureTerminator(Region &region, Builder &builder, Location loc);
+  }];
+
+  let hasCustomAssemblyFormat = 1;
+  let hasVerifier = 1;
+}
+
 def Vector_TransposeOp :
   Vector_Op<"transpose", [NoSideEffect,
     DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,

diff  --git a/mlir/include/mlir/Dialect/Vector/Interfaces/CMakeLists.txt b/mlir/include/mlir/Dialect/Vector/Interfaces/CMakeLists.txt
new file mode 100644
index 0000000000000..5d11c74684f4d
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Vector/Interfaces/CMakeLists.txt
@@ -0,0 +1 @@
+add_mlir_interface(MaskingInterfaces)

diff  --git a/mlir/include/mlir/Dialect/Vector/Interfaces/MaskingInterfaces.h b/mlir/include/mlir/Dialect/Vector/Interfaces/MaskingInterfaces.h
new file mode 100644
index 0000000000000..ae8e1daae0c4d
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Vector/Interfaces/MaskingInterfaces.h
@@ -0,0 +1,22 @@
+//===- MaskingInterfaces.h - Masking interfaces ---------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the interfaces for masking operations.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_VECTOR_INTERFACES_MASKINGINTERFACES_H_
+#define MLIR_DIALECT_VECTOR_INTERFACES_MASKINGINTERFACES_H_
+
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/OpDefinition.h"
+
+/// Include the generated interface declarations.
+#include "mlir/Dialect/Vector/Interfaces/MaskingInterfaces.h.inc"
+
+#endif // MLIR_DIALECT_VECTOR_INTERFACES_MASKINGINTERFACES_H_

diff  --git a/mlir/include/mlir/Dialect/Vector/Interfaces/MaskingInterfaces.td b/mlir/include/mlir/Dialect/Vector/Interfaces/MaskingInterfaces.td
new file mode 100644
index 0000000000000..3b8a2359cc7d9
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Vector/Interfaces/MaskingInterfaces.td
@@ -0,0 +1,52 @@
+//===- MaskingInterfaces.td - Masking Interfaces Decls === -*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This is the definition file for vector masking related interfaces.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_VECTOR_INTERFACES_MASKINGINTERFACES
+#define MLIR_DIALECT_VECTOR_INTERFACES_MASKINGINTERFACES
+
+include "mlir/IR/OpBase.td"
+
+// Interface implemented by operations that can be predicated with
+// `vector.mask`.
+def MaskableOpInterface : OpInterface<"MaskableOpInterface"> {
+  let description = [{
+    The `MaskableOpInterface` defines an operation that can be masked using the
+    `vector.mask` operation and provides information about its masking
+    constraints and semantics.
+  }];
+  let cppNamespace = "::mlir::vector";
+  let methods = [
+    InterfaceMethod<
+    /*desc=*/"Returns true if the operation may have a passthru argument when"
+             " masked.",
+    /*retTy=*/"bool",
+    /*methodName=*/"supportsPassthru",
+    /*args=*/(ins),
+    /*methodBody=*/"",
+    /*defaultImplementation=*/[{
+      return false;
+    }]>,
+    InterfaceMethod<
+    /*desc=*/"Returns the mask type expected by this operation. It requires the"
+             " operation to be vectorized.",
+    /*retTy=*/"mlir::VectorType",
+    /*methodName=*/"getExpectedMaskType",
+    /*args=*/(ins),
+    /*methodBody=*/"",
+    /*defaultImplementation=*/[{
+      // The default implementation is only intended for operations that
+      // implement the `getVectorType()` method: it clones that vector type
+      // with an `i1` element type.
+      return $_op.getVectorType().cloneWith(
+        /*shape=*/llvm::None, IntegerType::get($_op.getContext(), /*width=*/1));
+    }]>,
+  ];
+}
+
+#endif // MLIR_DIALECT_VECTOR_INTERFACES_MASKINGINTERFACES

diff  --git a/mlir/lib/Dialect/Vector/CMakeLists.txt b/mlir/lib/Dialect/Vector/CMakeLists.txt
index 31167e6af908b..6a342a7e47f54 100644
--- a/mlir/lib/Dialect/Vector/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/CMakeLists.txt
@@ -1,3 +1,4 @@
 add_subdirectory(IR)
+add_subdirectory(Interfaces)
 add_subdirectory(Transforms)
 add_subdirectory(Utils)

diff  --git a/mlir/lib/Dialect/Vector/IR/CMakeLists.txt b/mlir/lib/Dialect/Vector/IR/CMakeLists.txt
index 17a0e1d66bc66..d71ea243db7c3 100644
--- a/mlir/lib/Dialect/Vector/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/IR/CMakeLists.txt
@@ -14,6 +14,7 @@ add_mlir_dialect_library(MLIRVectorDialect
   MLIRDataLayoutInterfaces
   MLIRDialectUtils
   MLIRIR
+  MLIRMaskingInterfaces
   MLIRMemRefDialect
   MLIRSideEffectInterfaces
   MLIRTensorDialect

diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 5e0a177e4e1c2..73506463a23d7 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -96,6 +96,12 @@ static MaskFormat getMaskFormat(Value mask) {
   return MaskFormat::Unknown;
 }
 
+/// Default callback to build a region with a 'vector.yield' terminator with no
+/// arguments.
+void mlir::vector::buildTerminatedBody(OpBuilder &builder, Location loc) {
+  builder.create<vector::YieldOp>(loc);
+}
+
 // Helper for verifying combining kinds in contractions and reductions.
 static bool isSupportedCombiningKind(CombiningKind combiningKind,
                                      Type elementType) {
@@ -4860,6 +4866,172 @@ void CreateMaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<CreateMaskFolder>(context);
 }
 
+//===----------------------------------------------------------------------===//
+// MaskOp
+//===----------------------------------------------------------------------===//
+
+/// Builds a `vector.mask` op with only a mask operand and no results. A single
+/// block is created in the mask region and `maskRegionBuilder` is invoked with
+/// the builder positioned inside it; the insertion guard restores the caller's
+/// insertion point afterwards.
+void MaskOp::build(
+    OpBuilder &builder, OperationState &result, Value mask,
+    function_ref<void(OpBuilder &, Location)> maskRegionBuilder) {
+  assert(maskRegionBuilder &&
+         "builder callback for 'maskRegion' must be present");
+
+  result.addOperands(mask);
+  OpBuilder::InsertionGuard guard(builder);
+  Region *maskRegion = result.addRegion();
+  // createBlock also moves the builder's insertion point into the new block,
+  // which is where the callback emits the masked op and terminator.
+  builder.createBlock(maskRegion);
+  maskRegionBuilder(builder, result.location);
+}
+
+/// Builds a `vector.mask` op producing a single `resultType` value, without a
+/// passthru operand (forwards a null passthru to the full builder).
+void MaskOp::build(
+    OpBuilder &builder, OperationState &result, Type resultType, Value mask,
+    function_ref<void(OpBuilder &, Location)> maskRegionBuilder) {
+  build(builder, result, resultType, mask, /*passthru=*/Value(),
+        maskRegionBuilder);
+}
+
+/// Builds a `vector.mask` op producing a single `resultType` value. The mask
+/// operand and region come from the mask-only builder; `passthru`, when
+/// non-null, is appended as the second operand after the mask.
+void MaskOp::build(
+    OpBuilder &builder, OperationState &result, Type resultType, Value mask,
+    Value passthru,
+    function_ref<void(OpBuilder &, Location)> maskRegionBuilder) {
+  build(builder, result, mask, maskRegionBuilder);
+  if (passthru)
+    result.addOperands(passthru);
+  result.addTypes(resultType);
+}
+
+/// Parses the custom form emitted by `MaskOp::print`:
+///   `%mask [, %passthru] { <masked op> } attr-dict : <mask-type>
+///    [-> <result-types>]`
+/// The region's `vector.yield` terminator is implicit and is materialized by
+/// `ensureTerminator`.
+ParseResult MaskOp::parse(OpAsmParser &parser, OperationState &result) {
+  // Create the op region.
+  result.regions.reserve(1);
+  Region &maskRegion = *result.addRegion();
+
+  auto &builder = parser.getBuilder();
+
+  // Parse all the operands.
+  OpAsmParser::UnresolvedOperand mask;
+  if (parser.parseOperand(mask))
+    return failure();
+
+  // Optional passthru operand.
+  OpAsmParser::UnresolvedOperand passthru;
+  ParseResult parsePassthru = parser.parseOptionalComma();
+  if (parsePassthru.succeeded() && parser.parseOperand(passthru))
+    return failure();
+
+  // Parse op region.
+  if (parser.parseRegion(maskRegion, /*arguments=*/{}, /*argTypes=*/{}))
+    return failure();
+
+  MaskOp::ensureTerminator(maskRegion, builder, result.location);
+
+  // Parse the optional attribute list.
+  if (parser.parseOptionalAttrDict(result.attributes))
+    return failure();
+
+  // Parse all the types.
+  Type maskType;
+  if (parser.parseColonType(maskType))
+    return failure();
+
+  SmallVector<Type> resultTypes;
+  if (parser.parseOptionalArrowTypeList(resultTypes))
+    return failure();
+  result.types.append(resultTypes);
+
+  // Resolve operands.
+  if (parser.resolveOperand(mask, maskType, result.operands))
+    return failure();
+
+  if (parsePassthru.succeeded()) {
+    // The passthru operand is typed after the (single) result, so a result
+    // type list is mandatory here. Without this check `resultTypes[0]` would
+    // index an empty vector when the `-> type` suffix is omitted.
+    if (resultTypes.empty())
+      return parser.emitError(
+          parser.getNameLoc(),
+          "expects a result if passthru operand is provided");
+
+    if (parser.resolveOperand(passthru, resultTypes[0], result.operands))
+      return failure();
+  }
+
+  return success();
+}
+
+/// Prints the custom form
+///   `%mask [, %passthru] { <masked op> } attr-dict : <mask-type>
+///    [-> <result-types>]`
+/// Only the first operation of the region is printed; the trailing
+/// `vector.yield` terminator is omitted and rebuilt when parsing.
+void mlir::vector::MaskOp::print(OpAsmPrinter &p) {
+  p << " " << getMask();
+  if (getPassthru())
+    p << ", " << getPassthru();
+
+  // Print single masked operation and skip terminator.
+  p << " { ";
+  // The region is declared as SizedRegion<1>, so bind its single block by
+  // reference; a pointer taken from `front()` could never be null anyway.
+  Block &singleBlock = getMaskRegion().front();
+  if (singleBlock.getOperations().size() > 1)
+    p.printCustomOrGenericOp(&singleBlock.front());
+  p << " }";
+
+  p.printOptionalAttrDict(getOperation()->getAttrs());
+
+  p << " : " << getMask().getType();
+  if (getNumResults() > 0)
+    p << " -> " << getResultTypes();
+}
+
+/// Ensures `region` ends with a `vector.yield` terminator. The trait helper
+/// first inserts an operand-less yield if one is missing. Then, when the block
+/// holds exactly the masked operation plus the terminator, the operand-less
+/// yield is replaced by one that yields the masked operation's results.
+void MaskOp::ensureTerminator(Region &region, Builder &builder, Location loc) {
+  OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
+      MaskOp>::ensureTerminator(region, builder, loc);
+  // Keep the default yield terminator if the number of masked operations is not
+  // the expected. This case will trigger a verification failure.
+  if (region.front().getOperations().size() != 2)
+    return;
+
+  // Replace default yield terminator with a new one that returns the results
+  // from the masked operation.
+  OpBuilder opBuilder(builder.getContext());
+  Operation *maskedOp = &region.front().front();
+  Operation *oldYieldOp = &region.front().back();
+  assert(isa<vector::YieldOp>(oldYieldOp) && "Expected vector::YieldOp");
+
+  // Insert the new yield right before the old one, then drop the old one's
+  // operand references and erase it.
+  opBuilder.setInsertionPoint(oldYieldOp);
+  opBuilder.create<vector::YieldOp>(maskedOp->getLoc(), maskedOp->getResults());
+  oldYieldOp->dropAllReferences();
+  oldYieldOp->erase();
+}
+
+/// Verifies a `vector.mask` op:
+///  - the region block holds exactly one operation besides the terminator,
+///    and that operation implements `MaskableOpInterface`;
+///  - the op's result count and types equal the masked operation's;
+///  - the mask type equals `getExpectedMaskType()` of the masked operation;
+///  - a passthru, if present, is supported by the masked operation, requires a
+///    single result, and has that result's type.
+LogicalResult MaskOp::verify() {
+  // Structural checks.
+  Block &block = getMaskRegion().getBlocks().front();
+  if (block.getOperations().size() < 2)
+    return emitOpError("expects an operation to mask");
+  if (block.getOperations().size() > 2)
+    return emitOpError("expects only one operation to mask");
+
+  auto maskableOp = dyn_cast<MaskableOpInterface>(block.front());
+  if (!maskableOp)
+    return emitOpError("expects a maskable operation");
+
+  // Result checks.
+  if (maskableOp->getNumResults() != getNumResults())
+    return emitOpError("expects number of results to match maskable operation "
+                       "number of results");
+
+  if (!llvm::equal(maskableOp->getResultTypes(), getResultTypes()))
+    return emitOpError(
+        "expects result type to match maskable operation result type");
+
+  // Mask checks.
+  if (getMask().getType() != maskableOp.getExpectedMaskType())
+    return emitOpError("expects a ") << maskableOp.getExpectedMaskType()
+                                     << " mask for the maskable operation";
+
+  // Passthru checks.
+  Value passthru = getPassthru();
+  if (passthru) {
+    if (!maskableOp.supportsPassthru())
+      return emitOpError(
+          "doesn't expect a passthru argument for this maskable operation");
+
+    if (maskableOp->getNumResults() != 1)
+      return emitOpError("expects result when passthru argument is provided");
+
+    if (passthru.getType() != maskableOp->getResultTypes()[0])
+      return emitOpError("expects passthru type to match result type");
+  }
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // ScanOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Vector/Interfaces/CMakeLists.txt b/mlir/lib/Dialect/Vector/Interfaces/CMakeLists.txt
new file mode 100644
index 0000000000000..474b94044d1ed
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/Interfaces/CMakeLists.txt
@@ -0,0 +1,21 @@
+set(LLVM_OPTIONAL_SOURCES
+  MaskingInterfaces.cpp
+  )
+
+function(add_mlir_interface_library name)
+  add_mlir_library(MLIR${name}
+    ${name}.cpp
+
+    ADDITIONAL_HEADER_DIRS
+    ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Vector/Interfaces
+
+    DEPENDS
+    MLIR${name}IncGen
+
+    LINK_LIBS PUBLIC
+    MLIRIR
+    )
+endfunction(add_mlir_interface_library)
+
+add_mlir_interface_library(MaskingInterfaces)
+

diff  --git a/mlir/lib/Dialect/Vector/Interfaces/MaskingInterfaces.cpp b/mlir/lib/Dialect/Vector/Interfaces/MaskingInterfaces.cpp
new file mode 100644
index 0000000000000..7c253c6a8c999
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/Interfaces/MaskingInterfaces.cpp
@@ -0,0 +1,16 @@
+//===- MaskingInterfaces.cpp - Masking interfaces ----------====-*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Vector/Interfaces/MaskingInterfaces.h"
+
+//===----------------------------------------------------------------------===//
+// Masking Interfaces
+//===----------------------------------------------------------------------===//
+
+/// Include the definitions of the masking interfaces.
+#include "mlir/Dialect/Vector/Interfaces/MaskingInterfaces.cpp.inc"

diff  --git a/mlir/lib/Interfaces/CMakeLists.txt b/mlir/lib/Interfaces/CMakeLists.txt
index ab1fa87ed7901..83b2fabf159a1 100644
--- a/mlir/lib/Interfaces/CMakeLists.txt
+++ b/mlir/lib/Interfaces/CMakeLists.txt
@@ -40,10 +40,10 @@ add_mlir_interface_library(DataLayoutInterfaces)
 add_mlir_interface_library(DerivedAttributeOpInterface)
 add_mlir_interface_library(InferIntRangeInterface)
 add_mlir_interface_library(InferTypeOpInterface)
+add_mlir_interface_library(LoopLikeInterface)
 add_mlir_interface_library(ParallelCombiningOpInterface)
 add_mlir_interface_library(ShapedOpInterfaces)
 add_mlir_interface_library(SideEffectInterfaces)
 add_mlir_interface_library(TilingInterface)
 add_mlir_interface_library(VectorInterfaces)
 add_mlir_interface_library(ViewLikeInterface)
-add_mlir_interface_library(LoopLikeInterface)

diff  --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 5fbcefcf3414f..115cf41f5fd54 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -57,7 +57,7 @@ func.func @shuffle_rank_mismatch(%arg0: vector<2xf32>, %arg1: vector<4x2xf32>) {
 }
 
 // -----
- 
+
 func.func @shuffle_rank_mismatch_0d(%arg0: vector<f32>, %arg1: vector<1xf32>) {
   // expected-error at +1 {{'vector.shuffle' op rank mismatch}}
   %1 = vector.shuffle %arg0, %arg1 [0, 1] : vector<f32>, vector<1xf32>
@@ -1166,7 +1166,7 @@ func.func @transpose_rank_mismatch(%arg0: vector<4x16x11xf32>) {
 }
 
 // -----
- 
+
 func.func @transpose_length_mismatch_0d(%arg0: vector<f32>) {
   // expected-error at +1 {{'vector.transpose' op transposition length mismatch: 1}}
   %0 = vector.transpose %arg0, [1] : vector<f32> to vector<f32>
@@ -1586,3 +1586,49 @@ func.func @warp_mismatch_rank(%laneid: index) {
   }
   return
 }
+
+// -----
+
+func.func @vector_mask_empty(%m0: vector<16xi1>) -> i32 {
+  // expected-error at +1 {{'vector.mask' op expects an operation to mask}}
+  vector.mask %m0 { } : vector<16xi1>
+}
+
+// -----
+
+func.func @vector_mask_multiple_ops(%t0: tensor<?xf32>, %t1: tensor<?xf32>, %idx: index, %val: vector<16xf32>, %m0: vector<16xi1>) {
+  %ft0 = arith.constant 0.0 : f32
+  // expected-error at +1 {{'vector.mask' op expects only one operation to mask}}
+  vector.mask %m0 {
+    vector.transfer_write %val, %t0[%idx] : vector<16xf32>, tensor<?xf32>
+    vector.transfer_write %val, %t1[%idx] : vector<16xf32>, tensor<?xf32>
+  } : vector<16xi1>
+  return
+}
+
+// -----
+
+func.func @vector_mask_shape_mismatch(%a: vector<8xi32>, %m0: vector<16xi1>) -> i32 {
+  // expected-error at +1 {{'vector.mask' op expects a 'vector<8xi1>' mask for the maskable operation}}
+  %0 = vector.mask %m0 { vector.reduction <add>, %a : vector<8xi32> into i32 } : vector<16xi1> -> i32
+  return %0 : i32
+}
+
+// -----
+
+// The passthru operand is resolved with the result type, which conflicts with
+// the type of its block-argument definition.
+// expected-note at +1 {{prior use here}}
+func.func @vector_mask_passthru_type_mismatch(%t0: tensor<?xf32>, %idx: index, %m0: vector<16xi1>, %pt0: vector<16xi32>) -> vector<16xf32> {
+  %ft0 = arith.constant 0.0 : f32
+  // expected-error at +1 {{use of value '%pt0' expects different type than prior uses: 'vector<16xf32>' vs 'vector<16xi32>'}}
+  %0 = vector.mask %m0, %pt0 { vector.transfer_read %t0[%idx], %ft0 : tensor<?xf32>, vector<16xf32> } : vector<16xi1> -> vector<16xf32>
+  return %0 : vector<16xf32>
+}
+
+// -----
+
+func.func @vector_mask_passthru_no_return(%val: vector<16xf32>, %t0: tensor<?xf32>, %idx: index, %m0: vector<16xi1>, %pt0: vector<16xf32>) {
+  // expected-error at +1 {{'vector.mask' op expects result type to match maskable operation result type}}
+  vector.mask %m0, %pt0 { vector.transfer_write %val, %t0[%idx] : vector<16xf32>, tensor<?xf32> } : vector<16xi1> -> vector<16xf32>
+  return
+}
+

diff  --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index c74a5903a2fbe..7a6a2c3fdfe0b 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -786,7 +786,7 @@ func.func @vector_scan(%0: vector<4x8x16x32xf32>) -> vector<4x8x16x32xf32> {
 func.func @test_splat_op(%s : f32) {
   // CHECK: vector.splat [[S]] : vector<8xf32>
   %v = vector.splat %s : vector<8xf32>
-  
+
   // CHECK: vector.splat [[S]] : vector<4xf32>
   %u = "vector.splat"(%s) : (f32) -> vector<4xf32>
   return
@@ -824,4 +824,32 @@ func.func @warp_operand_result(%laneid: index, %v0 : vector<4xi32>) -> (vector<4
   return %2 : vector<4xi32>
 }
 
+// CHECK-LABEL: func @vector_mask
+func.func @vector_mask(%a: vector<8xi32>, %m0: vector<8xi1>) -> i32 {
+//  CHECK-NEXT:   %{{.*}} = vector.mask %{{.*}} { vector.reduction <add>, %{{.*}} : vector<8xi32> into i32 } : vector<8xi1> -> i32
+  %0 = vector.mask %m0 { vector.reduction <add>, %a : vector<8xi32> into i32 } : vector<8xi1> -> i32
+  return %0 : i32
+}
+
+// CHECK-LABEL: func @vector_mask_passthru
+func.func @vector_mask_passthru(%t0: tensor<?xf32>, %idx: index, %m0: vector<16xi1>, %pt0: vector<16xf32>) -> vector<16xf32> {
+  %ft0 = arith.constant 0.0 : f32
+//       CHECK:   %{{.*}} = vector.mask %{{.*}}, %{{.*}} { vector.transfer_read %{{.*}}[%{{.*}}], %{{.*}} : tensor<?xf32>, vector<16xf32> } : vector<16xi1> -> vector<16xf32>
+  %0 = vector.mask %m0, %pt0 { vector.transfer_read %t0[%idx], %ft0 : tensor<?xf32>, vector<16xf32> } : vector<16xi1> -> vector<16xf32>
+  return %0 : vector<16xf32>
+}
+
+// CHECK-LABEL: func @vector_mask_no_return
+func.func @vector_mask_no_return(%val: vector<16xf32>, %t0: memref<?xf32>, %idx: index, %m0: vector<16xi1>) {
+//  CHECK-NEXT:   vector.mask %{{.*}} { vector.transfer_write %{{.*}}, %{{.*}}[%{{.*}}] : vector<16xf32>, memref<?xf32> } : vector<16xi1>
+  vector.mask %m0 { vector.transfer_write %val, %t0[%idx] : vector<16xf32>, memref<?xf32> } : vector<16xi1>
+  return
+}
+
+// CHECK-LABEL: func @vector_mask_tensor_return
+func.func @vector_mask_tensor_return(%val: vector<16xf32>, %t0: tensor<?xf32>, %idx: index, %m0: vector<16xi1>) {
+//  CHECK-NEXT:   vector.mask %{{.*}} { vector.transfer_write %{{.*}}, %{{.*}}[%{{.*}}] : vector<16xf32>, tensor<?xf32> } : vector<16xi1> -> tensor<?xf32>
+  vector.mask %m0 { vector.transfer_write %val, %t0[%idx] : vector<16xf32>, tensor<?xf32> } : vector<16xi1> -> tensor<?xf32>
+  return
+}
 

diff  --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 39014d62040e1..87c4061504e66 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -3213,6 +3213,7 @@ cc_library(
         ":DialectUtils",
         ":IR",
         ":InferTypeOpInterface",
+        ":MaskingInterfaces",
         ":MemRefDialect",
         ":SideEffectInterfaces",
         ":Support",
@@ -7954,6 +7955,17 @@ cc_library(
     ],
 )
 
+##---------------------------------------------------------------------------##
+# Vector dialect.
+##---------------------------------------------------------------------------##
+
+td_library(
+    name = "MaskingInterfacesTdFiles",
+    srcs = ["include/mlir/Dialect/Vector/Interfaces/MaskingInterfaces.td"],
+    includes = ["include"],
+    deps = [":OpBaseTdFiles"],
+)
+
 td_library(
     name = "VectorOpsTdFiles",
     srcs = ["include/mlir/Dialect/Vector/IR/VectorOps.td"],
@@ -7961,6 +7973,7 @@ td_library(
     deps = [
         ":ControlFlowInterfacesTdFiles",
         ":InferTypeOpInterfaceTdFiles",
+        ":MaskingInterfacesTdFiles",
         ":OpBaseTdFiles",
         ":SideEffectInterfacesTdFiles",
         ":VectorInterfacesTdFiles",
@@ -7968,6 +7981,24 @@ td_library(
     ],
 )
 
+gentbl_cc_library(
+    name = "MaskingInterfacesIncGen",
+    strip_include_prefix = "include",
+    tbl_outs = [
+        (
+            ["-gen-op-interface-decls"],
+            "include/mlir/Dialect/Vector/Interfaces/MaskingInterfaces.h.inc",
+        ),
+        (
+            ["-gen-op-interface-defs"],
+            "include/mlir/Dialect/Vector/Interfaces/MaskingInterfaces.cpp.inc",
+        ),
+    ],
+    tblgen = ":mlir-tblgen",
+    td_file = "include/mlir/Dialect/Vector/Interfaces/MaskingInterfaces.td",
+    deps = [":MaskingInterfacesTdFiles"],
+)
+
 gentbl_cc_library(
     name = "VectorOpsIncGen",
     strip_include_prefix = "include",
@@ -8020,6 +8051,19 @@ gentbl_cc_library(
     deps = [":VectorOpsTdFiles"],
 )
 
+cc_library(
+    name = "MaskingInterfaces",
+    srcs = ["lib/Dialect/Vector/Interfaces/MaskingInterfaces.cpp"],
+    hdrs = ["include/mlir/Dialect/Vector/Interfaces/MaskingInterfaces.h"],
+    includes = ["include"],
+    deps = [
+        ":IR",
+        ":MaskingInterfacesIncGen",
+        ":Support",
+        "//llvm:Support",
+    ],
+)
+
 cc_library(
     name = "VectorToLLVM",
     srcs = glob([


        


More information about the Mlir-commits mailing list