[Mlir-commits] [mlir] a65fb19 - Add a "kind" attribute to ContractionOp and OuterProductOp.

Mehdi Amini llvmlistbot at llvm.org
Fri Feb 12 12:24:09 PST 2021


Author: Praveen Narayanan
Date: 2021-02-12T20:23:59Z
New Revision: a65fb1916cb49b2f1ab3e7016fa54c04b2b048b4

URL: https://github.com/llvm/llvm-project/commit/a65fb1916cb49b2f1ab3e7016fa54c04b2b048b4
DIFF: https://github.com/llvm/llvm-project/commit/a65fb1916cb49b2f1ab3e7016fa54c04b2b048b4.diff

LOG: Add a "kind" attribute to ContractionOp and OuterProductOp.

Currently, vector.contract joins the intermediate result and the accumulator
argument (of rank K) using summation. Additional joining operations, such as
max, would help vector.contract express reductions. This change extends
Vector_ContractionOp to take an optional attribute (called "kind", of enum type
CombiningKind) that specifies the joining operation: add/mul/min/max for int/fp,
and and/or/xor for int only. By default this attribute has the value "add".
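
For illustration, a contraction that reduces with max instead of add (this
mirrors the example added to VectorOps.td and ops.mlir below; %a, %b, and
%init are placeholder SSA names):

```mlir
#contraction_accesses = [
  affine_map<(i) -> (i)>,
  affine_map<(i) -> (i)>,
  affine_map<(i) -> ()>
]
#contraction_trait = {
  indexing_maps = #contraction_accesses,
  iterator_types = ["reduction"],
  kind = #vector.kind<max>
}
%max = vector.contract #contraction_trait %a, %b, %init
  : vector<10xf32>, vector<10xf32> into f32
```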

To implement this, we also need to extend vector.outerproduct, since
vector.contract is lowered to vector.outerproduct (and that, in turn, to
vector.fma). The extension to vector.outerproduct is likewise an optional kind
attribute that uses the same enum type and possible values, with "add" as the
default. For max/min, vector.outerproduct is lowered to a combination of
compare and select.
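
As a rough sketch (not verbatim from the patch), the per-element combine
produced for a floating-point "max" looks like the following, shown on scalar
f32 for brevity; the actual pattern (genMultF in VectorTransforms.cpp) builds
the equivalent MulFOp/CmpFOp/SelectOp on the extracted vector rows:

```mlir
%mul = mulf %x, %y : f32
// "ogt" keeps %mul when it is strictly greater, otherwise the accumulator.
%cmp = cmpf "ogt", %mul, %acc : f32
%res = select %cmp, %mul, %acc : f32
```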

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D93280

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/CMakeLists.txt
    mlir/include/mlir/Dialect/Vector/VectorOps.h
    mlir/include/mlir/Dialect/Vector/VectorOps.td
    mlir/include/mlir/IR/OpBase.td
    mlir/lib/Dialect/Vector/CMakeLists.txt
    mlir/lib/Dialect/Vector/VectorOps.cpp
    mlir/lib/Dialect/Vector/VectorTransforms.cpp
    mlir/test/Dialect/Linalg/vectorization.mlir
    mlir/test/Dialect/Vector/ops.mlir
    mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir
    mlir/test/Dialect/Vector/vector-transforms.mlir
    mlir/tools/mlir-tblgen/EnumsGen.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/CMakeLists.txt b/mlir/include/mlir/Dialect/Vector/CMakeLists.txt
index 23ad74e0cb72..b16570052f79 100644
--- a/mlir/include/mlir/Dialect/Vector/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Vector/CMakeLists.txt
@@ -1,2 +1,8 @@
 add_mlir_dialect(VectorOps vector)
 add_mlir_doc(VectorOps -gen-op-doc VectorOps Dialects/)
+
+set(LLVM_TARGET_DEFINITIONS VectorOps.td)
+mlir_tablegen(VectorOpsEnums.h.inc -gen-enum-decls)
+mlir_tablegen(VectorOpsEnums.cpp.inc -gen-enum-defs)
+add_public_tablegen_target(MLIRVectorOpsEnumsIncGen)
+add_dependencies(mlir-headers MLIRVectorOpsEnumsIncGen)

diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.h b/mlir/include/mlir/Dialect/Vector/VectorOps.h
index afc55c1911ba..3650470bc7be 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.h
@@ -21,11 +21,21 @@
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Interfaces/VectorInterfaces.h"
 #include "mlir/Interfaces/ViewLikeInterface.h"
+#include "llvm/ADT/StringExtras.h"
+
+// Pull in all enum type definitions and utility function declarations.
+#include "mlir/Dialect/Vector/VectorOpsEnums.h.inc"
 
 namespace mlir {
 class MLIRContext;
 class OwningRewritePatternList;
+
 namespace vector {
+class VectorDialect;
+
+namespace detail {
+struct BitmaskEnumStorage;
+} // namespace detail
 
 /// Collect a set of vector-to-vector canonicalization patterns.
 void populateVectorToVectorCanonicalizationPatterns(
@@ -63,6 +73,22 @@ void populateBubbleVectorBitCastOpPatterns(OwningRewritePatternList &patterns,
 void populateVectorSlicesLoweringPatterns(OwningRewritePatternList &patterns,
                                           MLIRContext *context);
 
+/// An attribute that specifies the combining function for `vector.contract`,
+/// and `vector.reduction`.
+class CombiningKindAttr
+    : public Attribute::AttrBase<CombiningKindAttr, Attribute,
+                                 detail::BitmaskEnumStorage> {
+public:
+  using Base::Base;
+
+  static CombiningKindAttr get(CombiningKind kind, MLIRContext *context);
+
+  CombiningKind getKind() const;
+
+  void print(DialectAsmPrinter &p) const;
+  static Attribute parse(DialectAsmParser &parser);
+};
+
 /// Enum to control the lowering of `vector.contract` operations.
 enum class VectorContractLowering {
   /// Progressively lower to finer grained `vector.contract` and dot-products.

diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
index 3da4cd25dc62..0f80b753b2c2 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -37,6 +37,35 @@ class Vector_Op<string mnemonic, list<OpTrait> traits = []> :
   let parser = [{ return ::parse$cppClass(parser, result); }];
 }
 
+// The "kind" of combining function for contractions and reductions.
+def COMBINING_KIND_ADD : BitEnumAttrCase<"ADD", 0x1,  "add">;
+def COMBINING_KIND_MUL : BitEnumAttrCase<"MUL", 0x2,  "mul">;
+def COMBINING_KIND_MIN : BitEnumAttrCase<"MIN", 0x4,  "min">;
+def COMBINING_KIND_MAX : BitEnumAttrCase<"MAX", 0x8,  "max">;
+def COMBINING_KIND_AND : BitEnumAttrCase<"AND", 0x10, "and">;
+def COMBINING_KIND_OR  : BitEnumAttrCase<"OR",  0x20, "or">;
+def COMBINING_KIND_XOR : BitEnumAttrCase<"XOR", 0x40, "xor">;
+
+def CombiningKind : BitEnumAttr<
+    "CombiningKind",
+    "Kind of combining function for contractions and reductions",
+    [COMBINING_KIND_ADD, COMBINING_KIND_MUL, COMBINING_KIND_MIN,
+     COMBINING_KIND_MAX, COMBINING_KIND_AND, COMBINING_KIND_OR,
+     COMBINING_KIND_XOR]> {
+  let cppNamespace = "::mlir::vector";
+}
+
+def Vector_CombiningKindAttr : DialectAttr<
+    Vector_Dialect,
+    CPred<"$_self.isa<::mlir::vector::CombiningKindAttr>()">,
+    "Kind of combining function for contractions and reductions"> {
+  let storageType = "::mlir::vector::CombiningKindAttr";
+  let returnType = "::mlir::vector::CombiningKind";
+  let convertFromStorage = "$_self.getKind()";
+  let constBuilderCall =
+          "::mlir::vector::CombiningKindAttr::get($0, $_builder.getContext())";
+}
+
 // TODO: Add an attribute to specify a different algebra with operators other
 // than the current set: {*, +}.
 def Vector_ContractionOp :
@@ -49,7 +78,9 @@ def Vector_ContractionOp :
     ]>,
     Arguments<(ins AnyVector:$lhs, AnyVector:$rhs, AnyType:$acc,
                Variadic<VectorOf<[I1]>>:$masks,
-               AffineMapArrayAttr:$indexing_maps, ArrayAttr:$iterator_types)>,
+               AffineMapArrayAttr:$indexing_maps, ArrayAttr:$iterator_types,
+               DefaultValuedAttr<Vector_CombiningKindAttr,
+                                 "CombiningKind::ADD">:$kind)>,
     Results<(outs AnyType)> {
   let summary = "vector contraction operation";
   let description = [{
@@ -88,6 +119,11 @@ def Vector_ContractionOp :
     and acc arguments. An indexing map attribute specifies a mapping from each
     iterator in the iterator type list, to each dimension of an N-D vector.
 
+    An optional kind attribute may be used to specify the combining function
+    between the intermediate result and accumulator argument of rank K. This
+    attribute can take the values add/mul/min/max for int/fp, and/or/xor for
+    int only. The default is "add".
+
     Example:
 
     ```mlir
@@ -146,6 +182,20 @@ def Vector_ContractionOp :
     // types than accumulator/result.
     %6 = vector.contract #contraction_trait %0, %1, %2
       : vector<10xf16>, vector<10xf16> into f32
+
+    // Contract with max (K = 0).
+    #contraction_accesses = [
+     affine_map<(i) -> (i)>,
+     affine_map<(i) -> (i)>,
+     affine_map<(i) -> ()>
+    ]
+    #contraction_trait = {
+      indexing_maps = #contraction_accesses,
+      iterator_types = ["reduction"],
+      kind = #vector.kind<max>
+    }
+    %7 = vector.contract #contraction_trait %0, %1, %2
+      : vector<10xf32>, vector<10xf32> into f32
     ```
   }];
   let builders = [
@@ -189,6 +239,12 @@ def Vector_ContractionOp :
 
     std::vector<std::pair<int64_t, int64_t>> getContractingDimMap();
     std::vector<std::pair<int64_t, int64_t>> getBatchDimMap();
+
+    static constexpr StringRef getKindAttrName() { return "kind"; }
+
+    static CombiningKind getDefaultKind() {
+      return CombiningKind::ADD;
+    }
   }];
 }
 
@@ -820,7 +876,9 @@ def Vector_OuterProductOp :
                 TCresVTEtIsSameAsOpBase<0, 0>>,
     PredOpTrait<"rhs operand and result have same element type",
                 TCresVTEtIsSameAsOpBase<0, 1>>]>,
-    Arguments<(ins AnyVector:$lhs, AnyType:$rhs, Variadic<AnyVector>:$acc)>,
+    Arguments<(ins AnyVector:$lhs, AnyType:$rhs,
+               Variadic<AnyVector>:$acc,
+               DefaultValuedAttr<Vector_CombiningKindAttr, "CombiningKind::ADD">:$kind)>,
     Results<(outs AnyVector)> {
   let summary = "vector outerproduct with optional fused add";
   let description = [{
@@ -846,6 +904,12 @@ def Vector_OuterProductOp :
     lowered to the LLVMIR dialect, this form emits `llvm.intr.fma`, which
     is guaranteed to lower to actual `fma` instructions on x86.
 
+    An optional kind attribute may be specified to be add/mul/min/max
+    for int/fp, and and/or/xor for int only. The default is "add", in which
+    case the operation returns a fused multiply-add. In other cases it returns
+    a multiply followed by the appropriate operation (for example, a compare and
+    select for "max").
+
     Example:
 
     ```
@@ -856,6 +920,10 @@ def Vector_OuterProductOp :
       vector<4xf32>, vector<8xf32>, vector<4x8xf32>
     return %3: vector<4x8xf32>
 
+    %4 = vector.outerproduct %0, %1, %2 {kind = #vector.kind<max>}:
+      vector<4xf32>, vector<8xf32>, vector<4x8xf32>
+    return %4: vector<4x8xf32>
+
     %6 = vector.outerproduct %4, %5: vector<10xf32>, f32
     return %6: vector<10xf32>
 
@@ -880,6 +948,12 @@ def Vector_OuterProductOp :
     VectorType getVectorType() {
       return getResult().getType().cast<VectorType>();
     }
+    static constexpr StringRef getKindAttrName() {
+      return "kind";
+    }
+    static CombiningKind getDefaultKind() {
+      return CombiningKind::ADD;
+    }
   }];
 }
 

diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 31598c143211..4a5731cb2f86 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -1131,9 +1131,9 @@ class I64EnumAttrCase<string sym, int val, string str = sym>
 // A bit enum case stored with 32-bit IntegerAttr. `val` here is *not* the
 // ordinal number of the bit that is set. It is the 32-bit integer with only
 // one bit set.
-class BitEnumAttrCase<string sym, int val> :
-    EnumAttrCaseInfo<sym, val, sym>,
-    SignlessIntegerAttrBase<I32, "case " # sym> {
+class BitEnumAttrCase<string sym, int val, string str = sym> :
+    EnumAttrCaseInfo<sym, val, str>,
+    SignlessIntegerAttrBase<I32, "case " # str> {
   let predicate = CPred<
     "$_self.cast<::mlir::IntegerAttr>().getValue().getZExtValue() & "
     # val # "u">;

diff --git a/mlir/lib/Dialect/Vector/CMakeLists.txt b/mlir/lib/Dialect/Vector/CMakeLists.txt
index 5c345fec7204..957ea6650854 100644
--- a/mlir/lib/Dialect/Vector/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/CMakeLists.txt
@@ -10,6 +10,7 @@ add_mlir_dialect_library(MLIRVector
 
   DEPENDS
   MLIRVectorOpsIncGen
+  MLIRVectorOpsEnumsIncGen
 
   LINK_LIBS PUBLIC
   MLIRAffineEDSC

diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 2c65a00092ea..678205d0b5d2 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -19,6 +19,7 @@
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
@@ -28,6 +29,9 @@
 #include "llvm/ADT/bit.h"
 #include <numeric>
 
+// Pull in all enum type and utility function definitions.
+#include "mlir/Dialect/Vector/VectorOpsEnums.cpp.inc"
+
 using namespace mlir;
 using namespace mlir::vector;
 
@@ -77,11 +81,30 @@ static MaskFormat get1DMaskFormat(Value mask) {
   return MaskFormat::Unknown;
 }
 
+// Helper for verifying combining kinds in contractions and reductions.
+static bool isSupportedCombiningKind(CombiningKind combiningKind,
+                                     Type elementType) {
+  switch (combiningKind) {
+  case CombiningKind::ADD:
+  case CombiningKind::MUL:
+  case CombiningKind::MIN:
+  case CombiningKind::MAX:
+    return elementType.isIntOrIndexOrFloat();
+  case CombiningKind::AND:
+  case CombiningKind::OR:
+  case CombiningKind::XOR:
+    return elementType.isIntOrIndex();
+  }
+  return false;
+}
+
 //===----------------------------------------------------------------------===//
 // VectorDialect
 //===----------------------------------------------------------------------===//
 
 void VectorDialect::initialize() {
+  addAttributes<CombiningKindAttr>();
+
   addOperations<
 #define GET_OP_LIST
 #include "mlir/Dialect/Vector/VectorOps.cpp.inc"
@@ -105,6 +128,106 @@ ArrayAttr vector::getVectorSubscriptAttr(Builder &builder,
   return builder.getI64ArrayAttr(values);
 }
 
+//===----------------------------------------------------------------------===//
+// CombiningKindAttr
+//===----------------------------------------------------------------------===//
+
+namespace mlir {
+namespace vector {
+namespace detail {
+struct BitmaskEnumStorage : public AttributeStorage {
+  using KeyTy = uint64_t;
+
+  BitmaskEnumStorage(KeyTy val) : value(val) {}
+
+  bool operator==(const KeyTy &key) const { return value == key; }
+
+  static BitmaskEnumStorage *construct(AttributeStorageAllocator &allocator,
+                                       const KeyTy &key) {
+    return new (allocator.allocate<BitmaskEnumStorage>())
+        BitmaskEnumStorage(key);
+  }
+
+  KeyTy value = 0;
+};
+} // namespace detail
+} // namespace vector
+} // namespace mlir
+
+CombiningKindAttr CombiningKindAttr::get(CombiningKind kind,
+                                         MLIRContext *context) {
+  return Base::get(context, static_cast<uint64_t>(kind));
+}
+
+CombiningKind CombiningKindAttr::getKind() const {
+  return static_cast<CombiningKind>(getImpl()->value);
+}
+
+static constexpr const CombiningKind combiningKindsList[] = {
+    // clang-format off
+    CombiningKind::ADD,
+    CombiningKind::MUL,
+    CombiningKind::MIN,
+    CombiningKind::MAX,
+    CombiningKind::AND,
+    CombiningKind::OR,
+    CombiningKind::XOR,
+    // clang-format on
+};
+
+void CombiningKindAttr::print(DialectAsmPrinter &printer) const {
+  printer << "kind<";
+  auto kinds = llvm::make_filter_range(combiningKindsList, [&](auto kind) {
+    return bitEnumContains(this->getKind(), kind);
+  });
+  llvm::interleaveComma(kinds, printer,
+                        [&](auto kind) { printer << stringifyEnum(kind); });
+  printer << ">";
+}
+
+Attribute CombiningKindAttr::parse(DialectAsmParser &parser) {
+  if (failed(parser.parseLess()))
+    return {};
+
+  StringRef elemName;
+  if (failed(parser.parseKeyword(&elemName)))
+    return {};
+
+  auto kind = symbolizeCombiningKind(elemName);
+  if (!kind) {
+    parser.emitError(parser.getNameLoc(), "Unknown combining kind: ")
+        << elemName;
+    return {};
+  }
+
+  if (failed(parser.parseGreater()))
+    return {};
+
+  return CombiningKindAttr::get(kind.getValue(),
+                                parser.getBuilder().getContext());
+}
+
+Attribute VectorDialect::parseAttribute(DialectAsmParser &parser,
+                                        Type type) const {
+  StringRef attrKind;
+  if (parser.parseKeyword(&attrKind))
+    return {};
+
+  if (attrKind == "kind")
+    return CombiningKindAttr::parse(parser);
+
+  parser.emitError(parser.getNameLoc(), "Unknown attribute type: ") << attrKind;
+  return {};
+}
+
+void VectorDialect::printAttribute(Attribute attr,
+                                   DialectAsmPrinter &os) const {
+  if (auto ck = attr.dyn_cast<CombiningKindAttr>())
+    ck.print(os);
+  else
+    llvm_unreachable("Unknown attribute type");
+}
+
 //===----------------------------------------------------------------------===//
 // ReductionOp
 //===----------------------------------------------------------------------===//
@@ -193,6 +316,9 @@ void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
   result.addTypes(acc.getType());
   result.addAttribute(getIndexingMapsAttrName(), indexingMaps);
   result.addAttribute(getIteratorTypesAttrName(), iteratorTypes);
+  result.addAttribute(ContractionOp::getKindAttrName(),
+                      CombiningKindAttr::get(ContractionOp::getDefaultKind(),
+                                             builder.getContext()));
 }
 
 static ParseResult parseContractionOp(OpAsmParser &parser,
@@ -221,6 +347,11 @@ static ParseResult parseContractionOp(OpAsmParser &parser,
     return failure();
   result.attributes.assign(dictAttr.getValue().begin(),
                            dictAttr.getValue().end());
+  if (!result.attributes.get(ContractionOp::getKindAttrName())) {
+    result.addAttribute(ContractionOp::getKindAttrName(),
+                        CombiningKindAttr::get(ContractionOp::getDefaultKind(),
+                                               result.getContext()));
+  }
   if (masksInfo.empty())
     return success();
   if (masksInfo.size() != 2)
@@ -421,12 +552,20 @@ static LogicalResult verify(ContractionOp op) {
         rhsMaskType.getShape().size() != rhsType.getShape().size())
       return op.emitOpError("invalid vector mask rank");
   }
+
+  // Verify supported combining kind.
+  auto vectorType = resType.dyn_cast<VectorType>();
+  auto elementType = vectorType ? vectorType.getElementType() : resType;
+  if (!isSupportedCombiningKind(op.kind(), elementType))
+    return op.emitOpError("unsupported contraction type");
+
   return success();
 }
 
 ArrayRef<StringRef> ContractionOp::getTraitAttrNames() {
-  static constexpr StringRef names[2] = {getIndexingMapsAttrName(),
-                                         getIteratorTypesAttrName()};
+  static constexpr StringRef names[3] = {getIndexingMapsAttrName(),
+                                         getIteratorTypesAttrName(),
+                                         ContractionOp::getKindAttrName()};
   return llvm::makeArrayRef(names);
 }
 
@@ -1497,8 +1636,10 @@ void OuterProductOp::build(OpBuilder &builder, OperationState &result,
 
 static void print(OpAsmPrinter &p, OuterProductOp op) {
   p << op.getOperationName() << " " << op.lhs() << ", " << op.rhs();
-  if (!op.acc().empty())
+  if (!op.acc().empty()) {
     p << ", " << op.acc();
+    p.printOptionalAttrDict(op.getAttrs());
+  }
   p << " : " << op.lhs().getType() << ", " << op.rhs().getType();
 }
 
@@ -1506,8 +1647,10 @@ static ParseResult parseOuterProductOp(OpAsmParser &parser,
                                        OperationState &result) {
   SmallVector<OpAsmParser::OperandType, 3> operandsInfo;
   Type tLHS, tRHS;
-  if (parser.parseOperandList(operandsInfo) || parser.parseColonType(tLHS) ||
-      parser.parseComma() || parser.parseType(tRHS))
+  if (parser.parseOperandList(operandsInfo) ||
+      parser.parseOptionalAttrDict(result.attributes) ||
+      parser.parseColonType(tLHS) || parser.parseComma() ||
+      parser.parseType(tRHS))
     return failure();
   if (operandsInfo.size() < 2)
     return parser.emitError(parser.getNameLoc(),
@@ -1521,6 +1664,14 @@ static ParseResult parseOuterProductOp(OpAsmParser &parser,
       vRHS ? VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)},
                              vLHS.getElementType())
            : VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType());
+
+  if (!result.attributes.get(OuterProductOp::getKindAttrName())) {
+    result.attributes.append(
+        OuterProductOp::getKindAttrName(),
+        CombiningKindAttr::get(OuterProductOp::getDefaultKind(),
+                               result.getContext()));
+  }
+
   return failure(
       parser.resolveOperand(operandsInfo[0], tLHS, result.operands) ||
       parser.resolveOperand(operandsInfo[1], tRHS, result.operands) ||
@@ -1558,6 +1709,11 @@ static LogicalResult verify(OuterProductOp op) {
 
   if (vACC && vACC != vRES)
     return op.emitOpError("expected operand #3 of same type as result type");
+
+  // Verify supported combining kind.
+  if (!isSupportedCombiningKind(op.kind(), vRES.getElementType()))
+    return op.emitOpError("unsupported outerproduct type");
+
   return success();
 }
 

diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index a6cb0a9e6ba4..dd3ea8ead746 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -1354,11 +1354,17 @@ class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
     Type eltType = resType.getElementType();
     bool isInt = eltType.isa<IntegerType>();
     Value acc = (op.acc().empty()) ? nullptr : op.acc()[0];
+    vector::CombiningKind kind = op.kind();
 
     if (!rhsType) {
       // Special case: AXPY operation.
       Value b = rewriter.create<vector::BroadcastOp>(loc, lhsType, op.rhs());
-      rewriter.replaceOp(op, genMult(loc, op.lhs(), b, acc, isInt, rewriter));
+      Optional<Value> mult =
+          isInt ? genMultI(loc, op.lhs(), b, acc, kind, rewriter)
+                : genMultF(loc, op.lhs(), b, acc, kind, rewriter);
+      if (!mult.hasValue())
+        return failure();
+      rewriter.replaceOp(op, mult.getValue());
       return success();
     }
 
@@ -1371,25 +1377,95 @@ class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
       Value r = nullptr;
       if (acc)
         r = rewriter.create<vector::ExtractOp>(loc, rhsType, acc, pos);
-      Value m = genMult(loc, a, op.rhs(), r, isInt, rewriter);
-      result = rewriter.create<vector::InsertOp>(loc, resType, m, result, pos);
+      Optional<Value> m = isInt ? genMultI(loc, a, op.rhs(), r, kind, rewriter)
+                                : genMultF(loc, a, op.rhs(), r, kind, rewriter);
+      if (!m.hasValue())
+        return failure();
+      result = rewriter.create<vector::InsertOp>(loc, resType, m.getValue(),
+                                                 result, pos);
     }
     rewriter.replaceOp(op, result);
     return success();
   }
 
 private:
-  static Value genMult(Location loc, Value x, Value y, Value acc, bool isInt,
-                       PatternRewriter &rewriter) {
-    if (acc) {
-      if (isInt)
-        return rewriter.create<AddIOp>(loc, rewriter.create<MulIOp>(loc, x, y),
-                                       acc);
-      return rewriter.create<vector::FMAOp>(loc, x, y, acc);
+  static Optional<Value> genMultI(Location loc, Value x, Value y, Value acc,
+                                  vector::CombiningKind kind,
+                                  PatternRewriter &rewriter) {
+    using vector::CombiningKind;
+
+    MulIOp mul = rewriter.create<MulIOp>(loc, x, y);
+    if (!acc)
+      return Optional<Value>(mul);
+
+    Value combinedResult;
+    switch (kind) {
+    case CombiningKind::ADD:
+      combinedResult = rewriter.create<AddIOp>(loc, mul, acc);
+      break;
+    case CombiningKind::MUL:
+      combinedResult = rewriter.create<MulIOp>(loc, mul, acc);
+      break;
+    case CombiningKind::MIN:
+      combinedResult = rewriter.create<SelectOp>(
+          loc, rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, mul, acc), mul,
+          acc);
+      break;
+    case CombiningKind::MAX:
+      combinedResult = rewriter.create<SelectOp>(
+          loc, rewriter.create<CmpIOp>(loc, CmpIPredicate::sge, mul, acc), mul,
+          acc);
+      break;
+    case CombiningKind::AND:
+      combinedResult = rewriter.create<AndOp>(loc, mul, acc);
+      break;
+    case CombiningKind::OR:
+      combinedResult = rewriter.create<OrOp>(loc, mul, acc);
+      break;
+    case CombiningKind::XOR:
+      combinedResult = rewriter.create<XOrOp>(loc, mul, acc);
+      break;
+    }
+    return Optional<Value>(combinedResult);
+  }
+
+  static Optional<Value> genMultF(Location loc, Value x, Value y, Value acc,
+                                  vector::CombiningKind kind,
+                                  PatternRewriter &rewriter) {
+    using vector::CombiningKind;
+
+    // Special case for fused multiply-add.
+    if (acc && kind == CombiningKind::ADD) {
+      return Optional<Value>(rewriter.create<vector::FMAOp>(loc, x, y, acc));
+    }
+
+    MulFOp mul = rewriter.create<MulFOp>(loc, x, y);
+
+    if (!acc)
+      return Optional<Value>(mul);
+
+    Value combinedResult;
+    switch (kind) {
+    case CombiningKind::MUL:
+      combinedResult = rewriter.create<MulFOp>(loc, mul, acc);
+      break;
+    case CombiningKind::MIN:
+      combinedResult = rewriter.create<SelectOp>(
+          loc, rewriter.create<CmpFOp>(loc, CmpFPredicate::OLE, mul, acc), mul,
+          acc);
+      break;
+    case CombiningKind::MAX:
+      combinedResult = rewriter.create<SelectOp>(
+          loc, rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, mul, acc), mul,
+          acc);
+      break;
+    case CombiningKind::ADD: // Already handled this special case above.
+    case CombiningKind::AND: // Only valid for integer types.
+    case CombiningKind::OR:  // Only valid for integer types.
+    case CombiningKind::XOR: // Only valid for integer types.
+      return Optional<Value>();
     }
-    if (isInt)
-      return rewriter.create<MulIOp>(loc, x, y);
-    return rewriter.create<MulFOp>(loc, x, y);
+    return Optional<Value>(combinedResult);
   }
 };
 
@@ -1804,7 +1880,8 @@ LogicalResult ContractionOpToOuterProductOpLowering::matchAndRewrite(
   for (int64_t k = 0; k < reductionSize; ++k) {
     Value a = rewriter.create<vector::ExtractOp>(op.getLoc(), lhs, k);
     Value b = rewriter.create<vector::ExtractOp>(op.getLoc(), rhs, k);
-    res = rewriter.create<vector::OuterProductOp>(op.getLoc(), a, b, res);
+    res = rewriter.create<vector::OuterProductOp>(op.getLoc(), res.getType(), a,
+                                                  b, res, op.kind());
   }
   rewriter.replaceOp(op, res);
   return success();

diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index 6d56c6a44bf5..bb532b2a550c 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -355,7 +355,7 @@ func @matmul_tensors(
   //
   // linalg contraction lowers to %tmp = vector.contract %a, %b, %c0 followed by addf %c, %tmp.
   // a later canonicalization fuses the add into vector.contract.
-  //       CHECK:   %[[C:.*]] = vector.contract {{.*}} iterator_types = ["parallel", "parallel", "reduction"]} %[[V0]], %[[V1]], %[[VEC_C0]] : vector<8x4xf32>, vector<4x12xf32> into vector<8x12xf32>
+  //       CHECK:   %[[C:.*]] = vector.contract {{.*}} iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[V0]], %[[V1]], %[[VEC_C0]] : vector<8x4xf32>, vector<4x12xf32> into vector<8x12xf32>
   //       CHECK:   %[[C2:.*]] = addf %[[V2]], %[[C]] : vector<8x12xf32>
   //       CHECK:   %[[W:.*]] = vector.transfer_write %[[C2]], %[[ARG2]][%[[C0]], %[[C0]]] {masked = [false, false]} : vector<8x12xf32>, tensor<8x12xf32>
   %0 = linalg.matmul  ins(%arg0, %arg1: tensor<8x4xf32>, tensor<4x12xf32>)
@@ -380,7 +380,7 @@ func @matmul_i8_i8_i32(%a: memref<4x6xi8>, %b: memref<6x12xi8>, %c: memref<4x12x
   //
   // linalg contraction lowers to %tmp = vector.contract %a, %b, %c0 followed by addf %c, %tmp.
   // a later canonicalization fuses the add into vector.contract.
-  //       CHECK:   %[[C:.*]] = vector.contract {{.*}} iterator_types = ["parallel", "parallel", "reduction"]} %[[V0]], %[[V1]], %[[VEC_C0]]
+  //       CHECK:   %[[C:.*]] = vector.contract {{.*}} iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[V0]], %[[V1]], %[[VEC_C0]]
   //  CHECK-SAME:     vector<4x6xi8>, vector<6x12xi8> into vector<4x12xi8>
   //       CHECK:   %[[C32:.*]] = sexti %[[C]] : vector<4x12xi8> to vector<4x12xi32>
   //       CHECK:   %[[RES:.*]] = addi %[[V2]], %[[C32]] : vector<4x12xi32>

diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 11197f1e0bee..eeca040844e4 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -198,13 +198,34 @@ func @extract_strided_slice(%arg0: vector<4x8x16xf32>) -> vector<2x2x16xf32> {
 func @contraction_to_scalar(%arg0: vector<10xf32>, %arg1: vector<10xf32>) -> f32 {
   // CHECK:      %[[C0:.*]] = constant 0.000000e+00 : f32
   %f0 = constant 0.0: f32
-  // CHECK:      %[[X:.*]] = vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["reduction"]} %{{.*}}, %{{.*}}, %[[C0]] : vector<10xf32>, vector<10xf32> into f32
+  // CHECK:      %[[X:.*]] = vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["reduction"], kind = #vector.kind<add>} %{{.*}}, %{{.*}}, %[[C0]] : vector<10xf32>, vector<10xf32> into f32
   %0 = vector.contract #contraction_to_scalar_trait %arg0, %arg1, %f0
     : vector<10xf32>, vector<10xf32> into f32
   // CHECK:      return %[[X]] : f32
   return %0 : f32
 }
 
+#contraction_to_scalar_max_accesses = [
+  affine_map<(i) -> (i)>,
+  affine_map<(i) -> (i)>,
+  affine_map<(i) -> ()>
+]
+#contraction_to_scalar_max_trait = {
+  indexing_maps = #contraction_to_scalar_max_accesses,
+  iterator_types = ["reduction"],
+  kind = #vector.kind<max>
+}
+// CHECK-LABEL: @contraction_to_scalar_with_max
+func @contraction_to_scalar_with_max(%arg0: vector<10xf32>, %arg1: vector<10xf32>) -> f32 {
+  // CHECK:      %[[C0:.*]] = constant 0.000000e+00 : f32
+  %f0 = constant 0.0: f32
+  // CHECK:      %[[X:.*]] = vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["reduction"], kind = #vector.kind<max>} %{{.*}}, %{{.*}}, %[[C0]] : vector<10xf32>, vector<10xf32> into f32
+  %0 = vector.contract #contraction_to_scalar_max_trait %arg0, %arg1, %f0
+    : vector<10xf32>, vector<10xf32> into f32
+  // CHECK:      return %[[X]] : f32
+  return %0 : f32
+}
+
 #contraction_accesses0 = [
   affine_map<(b0, f0, f1, c0, c1) -> (c0, b0, c1, f0)>,
   affine_map<(b0, f0, f1, c0, c1) -> (b0, c1, c0, f1)>,
@@ -221,36 +242,46 @@ func @contraction_to_scalar(%arg0: vector<10xf32>, %arg1: vector<10xf32>) -> f32
                                         // 8,  8, 15,  5
   affine_map<(f0, f1, f2, f3, c0, c1) -> (f0, f1, f2, f3)>
 ]
+#iterator_types1 = ["parallel", "parallel", "parallel", "parallel", "reduction",
+                    "reduction"]
 #contraction_trait1 = {
   indexing_maps = #contraction_accesses1,
-  iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction",
-                    "reduction"]
+  iterator_types = #iterator_types1
+}
+#contraction_trait2 = {
+  indexing_maps = #contraction_accesses1,
+  iterator_types = #iterator_types1,
+  kind = #vector.kind<max>
 }
 // CHECK-LABEL: @contraction
 func @contraction(%arg0 : vector<7x8x16x15xf32>, %arg1 : vector<8x16x7x5xf32>,
                   %arg2 : vector<8x15x5xf32>, %arg3 : vector<8x8x15x5xf32>,
                   %arg4 : vector<7x8x16x15xf16>, %arg5 : vector<8x16x7x5xf16>) {
   // Test contraction with batch and contracting dims.
-  // CHECK: vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]} {{.*}}, {{.*}}, {{.*}} : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x5xf32>
+  // CHECK: vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"], kind = #vector.kind<add>} {{.*}}, {{.*}}, {{.*}} : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x5xf32>
   %0 = vector.contract #contraction_trait0 %arg0, %arg1, %arg2
       : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x5xf32>
   // Test contraction with only contracting dims. In this case the lhs/rhs
   // dimension of size 8 will be considered a parallel dim for lhs/rhs and will
   // appear twice in the output.
-  // CHECK: vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} {{.*}}, {{.*}}, {{.*}} : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x8x15x5xf32>
+  // CHECK: vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"], kind = #vector.kind<add>} {{.*}}, {{.*}}, {{.*}} : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x8x15x5xf32>
   %1 = vector.contract #contraction_trait1 %arg0, %arg1, %arg3
       : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x8x15x5xf32>
   // Test contraction with optional vector mask arguments.
   %lhs_mask = vector.constant_mask [7, 8, 16, 15] : vector<7x8x16x15xi1>
   %rhs_mask = vector.constant_mask [8, 16, 7, 5] : vector<8x16x7x5xi1>
-  // CHECK: vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}} : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x8x15x5xf32>
+  // CHECK: vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"], kind = #vector.kind<add>} {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}} : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x8x15x5xf32>
   %2 = vector.contract #contraction_trait1 %arg0, %arg1, %arg3, %lhs_mask,
                                            %rhs_mask
       : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x8x15x5xf32>
   // Test contraction with mixed type.
-  // CHECK: vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} {{.*}}, {{.*}}, {{.*}} : vector<7x8x16x15xf16>, vector<8x16x7x5xf16> into vector<8x8x15x5xf32>
+  // CHECK: vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"], kind = #vector.kind<add>} {{.*}}, {{.*}}, {{.*}} : vector<7x8x16x15xf16>, vector<8x16x7x5xf16> into vector<8x8x15x5xf32>
   %3 = vector.contract #contraction_trait1 %arg4, %arg5, %arg3
       : vector<7x8x16x15xf16>, vector<8x16x7x5xf16> into vector<8x8x15x5xf32>
+  // Test contraction with "max" instead of "add".
+  // CHECK: vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"], kind = #vector.kind<max>} {{.*}}, {{.*}}, {{.*}} : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x8x15x5xf32>
+  %4 = vector.contract #contraction_trait2 %arg0, %arg1, %arg3
+      : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x8x15x5xf32>
   return
 }
 

diff --git a/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir
index 3e2896c82bbe..c02de28b427e 100644
--- a/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir
@@ -9,6 +9,11 @@
   indexing_maps = #matvec_accesses,
   iterator_types = ["parallel", "reduction"]
 }
+#matvecmax_trait = {
+  indexing_maps = #matvec_accesses,
+  iterator_types = ["parallel", "reduction"],
+  kind = #vector.kind<max>
+}
 
 #mattransvec_accesses = [
   affine_map<(i, j) -> (j, i)>,
@@ -50,10 +55,10 @@
 // CHECK: %[[T3:.*]] = vector.transpose %[[T0]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
 // CHECK: %[[T4:.*]] = vector.extract %[[T3]][0] : vector<2x2xf32>
 // CHECK: %[[T5:.*]] = vector.extract %[[T1]][0] : vector<2xf32>
-// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[T2]] : vector<2xf32>, f32
+// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[T2]] {kind = #vector.kind<add>} : vector<2xf32>, f32
 // CHECK: %[[T7:.*]] = vector.extract %[[T3]][1] : vector<2x2xf32>
 // CHECK: %[[T8:.*]] = vector.extract %[[T1]][1] : vector<2xf32>
-// CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] : vector<2xf32>, f32
+// CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] {kind = #vector.kind<add>} : vector<2xf32>, f32
 // CHECK: store %[[T9]], %[[C]][] : memref<vector<2xf32>>
 // CHECK: return
 func @matvec2x2(%arg0: memref<vector<2x2xf32>>, %arg1: memref<vector<2xf32>>,
@@ -66,6 +71,32 @@ func @matvec2x2(%arg0: memref<vector<2x2xf32>>, %arg1: memref<vector<2xf32>>,
   return
 }
 
+// CHECK-LABEL: func @matvecmax2x2
+// CHECK-SAME: %[[A:.*0]]: memref<vector<2x2xf32>>
+// CHECK-SAME: %[[B:.*1]]: memref<vector<2xf32>>
+// CHECK-SAME: %[[C:.*2]]: memref<vector<2xf32>>
+// CHECK: %[[T0:.*]] = load %[[A]][] : memref<vector<2x2xf32>>
+// CHECK: %[[T1:.*]] = load %[[B]][] : memref<vector<2xf32>>
+// CHECK: %[[T2:.*]] = load %[[C]][] : memref<vector<2xf32>>
+// CHECK: %[[T3:.*]] = vector.transpose %[[T0]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
+// CHECK: %[[T4:.*]] = vector.extract %[[T3]][0] : vector<2x2xf32>
+// CHECK: %[[T5:.*]] = vector.extract %[[T1]][0] : vector<2xf32>
+// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[T2]] {kind = #vector.kind<max>} : vector<2xf32>, f32
+// CHECK: %[[T7:.*]] = vector.extract %[[T3]][1] : vector<2x2xf32>
+// CHECK: %[[T8:.*]] = vector.extract %[[T1]][1] : vector<2xf32>
+// CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] {kind = #vector.kind<max>} : vector<2xf32>, f32
+// CHECK: store %[[T9]], %[[C]][] : memref<vector<2xf32>>
+// CHECK: return
+func @matvecmax2x2(%arg0: memref<vector<2x2xf32>>, %arg1: memref<vector<2xf32>>,
+                                                   %arg2: memref<vector<2xf32>>) {
+  %A = load %arg0[] : memref<vector<2x2xf32>>
+  %x = load %arg1[] : memref<vector<2xf32>>
+  %b = load %arg2[] : memref<vector<2xf32>>
+  %0 = vector.contract #matvecmax_trait %A, %x, %b : vector<2x2xf32>, vector<2xf32> into vector<2xf32>
+  store %0, %arg2[] : memref<vector<2xf32>>
+  return
+}
+
 // CHECK-LABEL: func @mattransvec2x2
 // CHECK-SAME: %[[A:.*0]]: memref<vector<2x2xf32>>
 // CHECK-SAME: %[[B:.*1]]: memref<vector<2xf32>>
@@ -75,10 +106,10 @@ func @matvec2x2(%arg0: memref<vector<2x2xf32>>, %arg1: memref<vector<2xf32>>,
 // CHECK: %[[T2:.*]] = load %[[C]][] : memref<vector<2xf32>>
 // CHECK: %[[T3:.*]] = vector.extract %[[T0]][0] : vector<2x2xf32>
 // CHECK: %[[T4:.*]] = vector.extract %[[T1]][0] : vector<2xf32>
-// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[T2]] : vector<2xf32>, f32
+// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[T2]] {kind = #vector.kind<add>} : vector<2xf32>, f32
 // CHECK: %[[T6:.*]] = vector.extract %[[T0]][1] : vector<2x2xf32>
 // CHECK: %[[T7:.*]] = vector.extract %[[T1]][1] : vector<2xf32>
-// CHECK: %[[T8:.*]] = vector.outerproduct %[[T6]], %[[T7]], %[[T5]] : vector<2xf32>, f32
+// CHECK: %[[T8:.*]] = vector.outerproduct %[[T6]], %[[T7]], %[[T5]] {kind = #vector.kind<add>} : vector<2xf32>, f32
 // CHECK: store %[[T8]], %[[C]][] : memref<vector<2xf32>>
 // CHECK: return
 func @mattransvec2x2(%arg0: memref<vector<2x2xf32>>, %arg1: memref<vector<2xf32>>,
@@ -101,10 +132,10 @@ func @mattransvec2x2(%arg0: memref<vector<2x2xf32>>, %arg1: memref<vector<2xf32>
 // CHECK: %[[T3:.*]] = vector.transpose %[[T0]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
 // CHECK: %[[T4:.*]] = vector.extract %[[T3]][0] : vector<2x2xf32>
 // CHECK: %[[T5:.*]] = vector.extract %[[T1]][0] : vector<2xf32>
-// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[T2]] : vector<2xf32>, f32
+// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[T2]] {kind = #vector.kind<add>} : vector<2xf32>, f32
 // CHECK: %[[T7:.*]] = vector.extract %[[T3]][1] : vector<2x2xf32>
 // CHECK: %[[T8:.*]] = vector.extract %[[T1]][1] : vector<2xf32>
-// CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] : vector<2xf32>, f32
+// CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] {kind = #vector.kind<add>} : vector<2xf32>, f32
 // CHECK: store %[[T9]], %[[C]][] : memref<vector<2xf32>>
 // CHECK: return
 func @vecmat2x2(%arg0: memref<vector<2x2xf32>>, %arg1: memref<vector<2xf32>>,
@@ -126,10 +157,10 @@ func @vecmat2x2(%arg0: memref<vector<2x2xf32>>, %arg1: memref<vector<2xf32>>,
 // CHECK: %[[T2:.*]] = load %[[C]][] : memref<vector<2xf32>>
 // CHECK: %[[T3:.*]] = vector.extract %[[T0]][0] : vector<2x2xf32>
 // CHECK: %[[T4:.*]] = vector.extract %[[T1]][0] : vector<2xf32>
-// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[T2]] : vector<2xf32>, f32
+// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[T2]] {kind = #vector.kind<add>} : vector<2xf32>, f32
 // CHECK: %[[T6:.*]] = vector.extract %[[T0]][1] : vector<2x2xf32>
 // CHECK: %[[T7:.*]] = vector.extract %[[T1]][1] : vector<2xf32>
-// CHECK: %[[T8:.*]] = vector.outerproduct %[[T6]], %[[T7]], %[[T5]] : vector<2xf32>, f32
+// CHECK: %[[T8:.*]] = vector.outerproduct %[[T6]], %[[T7]], %[[T5]] {kind = #vector.kind<add>} : vector<2xf32>, f32
 // CHECK: store %[[T8]], %[[C]][] : memref<vector<2xf32>>
 // CHECK: return
 func @vecmattrans2x2(%arg0: memref<vector<2x2xf32>>, %arg1: memref<vector<2xf32>>,

diff --git a/mlir/test/Dialect/Vector/vector-transforms.mlir b/mlir/test/Dialect/Vector/vector-transforms.mlir
index 20c91882871d..f57e5f6e7ea9 100644
--- a/mlir/test/Dialect/Vector/vector-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-transforms.mlir
@@ -92,56 +92,56 @@ func @add4x4(%0: vector<4x4xf32>, %1: vector<4x4xf32>) -> vector<4x4xf32> {
 // CHECK-NEXT: %[[TG3:.*]] = vector.tuple_get %[[ES3]], 0 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
 // CHECK-NEXT: %[[TG4:.*]] = vector.tuple_get %[[ES4]], 0 : tuple<vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>>
 // CHECK-NEXT: %[[TG5:.*]] = vector.tuple_get %[[ES5]], 0 : tuple<vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>>
-// CHECK-NEXT: %[[R1S00:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[TG1]], %[[TG2]], %[[TG3]], %[[TG4]], %[[TG5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+// CHECK-NEXT: %[[R1S00:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[TG1]], %[[TG2]], %[[TG3]], %[[TG4]], %[[TG5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
 
 // CHECK-NEXT: %[[TG6:.*]] = vector.tuple_get %[[ES1]], 1 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
 // CHECK-NEXT: %[[TG7:.*]] = vector.tuple_get %[[ES2]], 2 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
 // CHECK-NEXT: %[[TG8:.*]] = vector.tuple_get %[[ES4]], 1 : tuple<vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>>
 // CHECK-NEXT: %[[TG9:.*]] = vector.tuple_get %[[ES5]], 2 : tuple<vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>>
-// CHECK-NEXT: %[[R2S00:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[TG6]], %[[TG7]], %[[R1S00]], %[[TG8]], %[[TG9]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+// CHECK-NEXT: %[[R2S00:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[TG6]], %[[TG7]], %[[R1S00]], %[[TG8]], %[[TG9]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
 
 // CHECK-NEXT: %[[TG10:.*]] = vector.tuple_get %[[ES1]], 2 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
 // CHECK-NEXT: %[[TG11:.*]] = vector.tuple_get %[[ES2]], 4 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
 // CHECK-NEXT: %[[TG12:.*]] = vector.tuple_get %[[ES4]], 2 : tuple<vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>>
 // CHECK-NEXT: %[[TG13:.*]] = vector.tuple_get %[[ES5]], 4 : tuple<vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>>
-// CHECK-NEXT: %[[R3S00:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[TG10]], %[[TG11]], %[[R2S00]], %[[TG12]], %[[TG13]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+// CHECK-NEXT: %[[R3S00:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[TG10]], %[[TG11]], %[[R2S00]], %[[TG12]], %[[TG13]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
 
 // Reducing output vector [0, 2]
 
 // CHECK-NEXT: %[[TG14:.*]] = vector.tuple_get %[[ES2]], 1 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
 // CHECK-NEXT: %[[TG15:.*]] = vector.tuple_get %[[ES3]], 1 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
 // CHECK-NEXT: %[[TG16:.*]] = vector.tuple_get %[[ES5]], 1 : tuple<vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>>
-// CHECK-NEXT: %[[R1S02:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[TG1]], %[[TG14]], %[[TG15]], %[[TG4]], %[[TG16]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+// CHECK-NEXT: %[[R1S02:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[TG1]], %[[TG14]], %[[TG15]], %[[TG4]], %[[TG16]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
 
 // CHECK-NEXT: %[[TG17:.*]] = vector.tuple_get %[[ES2]], 3 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
 // CHECK-NEXT: %[[TG18:.*]] = vector.tuple_get %[[ES5]], 3 : tuple<vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>>
-// CHECK-NEXT: %[[R2S02:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[TG6]], %[[TG17]], %[[R1S02]], %[[TG8]], %[[TG18]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+// CHECK-NEXT: %[[R2S02:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[TG6]], %[[TG17]], %[[R1S02]], %[[TG8]], %[[TG18]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
 
 // CHECK-NEXT: %[[TG19:.*]] = vector.tuple_get %[[ES2]], 5 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
 // CHECK-NEXT: %[[TG20:.*]] = vector.tuple_get %[[ES5]], 5 : tuple<vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>>
-// CHECK-NEXT: %[[R3S02:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[TG10]], %[[TG19]], %[[R2S02]], %[[TG12]], %[[TG20]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+// CHECK-NEXT: %[[R3S02:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[TG10]], %[[TG19]], %[[R2S02]], %[[TG12]], %[[TG20]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
 
 // Reducing output vector [2, 0]
 
 // CHECK-NEXT: %[[TG21:.*]] = vector.tuple_get %[[ES1]], 3 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
 // CHECK-NEXT: %[[TG22:.*]] = vector.tuple_get %[[ES3]], 2 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
 // CHECK-NEXT: %[[TG23:.*]] = vector.tuple_get %[[ES4]], 3 : tuple<vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>>
-// CHECK-NEXT: %[[R1S20:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[TG21]], %[[TG2]], %[[TG22]], %[[TG23]], %[[TG5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+// CHECK-NEXT: %[[R1S20:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[TG21]], %[[TG2]], %[[TG22]], %[[TG23]], %[[TG5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
 
 // CHECK-NEXT: %[[TG24:.*]] = vector.tuple_get %[[ES1]], 4 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
 // CHECK-NEXT: %[[TG25:.*]] = vector.tuple_get %[[ES4]], 4 : tuple<vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>>
-// CHECK-NEXT:  %[[R2S20:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[TG24]], %[[TG7]], %[[R1S20]], %[[TG25]], %[[TG9]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+// CHECK-NEXT:  %[[R2S20:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[TG24]], %[[TG7]], %[[R1S20]], %[[TG25]], %[[TG9]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
 
 // CHECK-NEXT: %[[TG26:.*]] = vector.tuple_get %[[ES1]], 5 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
 // CHECK-NEXT: %[[TG27:.*]] = vector.tuple_get %[[ES4]], 5 : tuple<vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>>
-// CHECK-NEXT:  %[[R3S20:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[TG26]], %[[TG11]], %[[R2S20]], %[[TG27]], %[[TG13]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+// CHECK-NEXT:  %[[R3S20:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[TG26]], %[[TG11]], %[[R2S20]], %[[TG27]], %[[TG13]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
 
 // Reducing output vector [2, 2]
 
 // CHECK-NEXT: %[[TG28:.*]] = vector.tuple_get %[[ES3]], 3 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
-// CHECK-NEXT: %[[R1S22:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[TG21]], %[[TG14]], %[[TG28]], %[[TG23]], %[[TG16]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
-// CHECK-NEXT: %[[R2S22:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[TG24]], %[[TG17]], %[[R1S22]], %[[TG25]], %[[TG18]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
-// CHECK-NEXT: %[[R3S22:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[TG26]], %[[TG19]], %[[R2S22]], %[[TG27]], %[[TG20]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+// CHECK-NEXT: %[[R1S22:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[TG21]], %[[TG14]], %[[TG28]], %[[TG23]], %[[TG16]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+// CHECK-NEXT: %[[R2S22:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[TG24]], %[[TG17]], %[[R1S22]], %[[TG25]], %[[TG18]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+// CHECK-NEXT: %[[R3S22:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[TG26]], %[[TG19]], %[[R2S22]], %[[TG27]], %[[TG20]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
 
 // CHECK-NEXT: %[[RES0:.*]] = vector.tuple %[[R3S00]], %[[R3S02]], %[[R3S20]], %[[R3S22]] : vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>
 // CHECK-NEXT: %[[RES1:.*]] = vector.insert_slices %[[RES0]], [2, 2], [1, 1] : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> into vector<4x4xf32>
@@ -187,26 +187,26 @@ func @contraction4x4_ijk(%arg0 : vector<4x6xf32>, %arg1 : vector<6x4xf32>,
 // CHECK-NEXT: %[[TG3:.*]] = vector.tuple_get %[[ES3]], 0 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
 // CHECK-NEXT: %[[TG4:.*]] = vector.tuple_get %[[ES4]], 0 : tuple<vector<2x2xi1>, vector<2x2xi1>>
 // CHECK-NEXT: %[[TG5:.*]] = vector.tuple_get %[[ES5]], 0 : tuple<vector<2x2xi1>, vector<2x2xi1>>
-// CHECK-NEXT:  %[[R1S00:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[TG1]], %[[TG2]], %[[TG3]], %[[TG4]], %[[TG5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+// CHECK-NEXT:  %[[R1S00:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind<add>} %[[TG1]], %[[TG2]], %[[TG3]], %[[TG4]], %[[TG5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
 
 // Reducing output vector [0, 2]
 
 // CHECK-NEXT: %[[TG6:.*]] = vector.tuple_get %[[ES2]], 1 : tuple<vector<2x2xf32>, vector<2x2xf32>>
 // CHECK-NEXT: %[[TG7:.*]] = vector.tuple_get %[[ES3]], 1 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
 // CHECK-NEXT: %[[TG8:.*]] = vector.tuple_get %[[ES5]], 1 : tuple<vector<2x2xi1>, vector<2x2xi1>>
-// CHECK-NEXT:  %[[R1S02:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[TG1]], %[[TG6]], %[[TG7]], %[[TG4]], %[[TG8]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+// CHECK-NEXT:  %[[R1S02:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind<add>} %[[TG1]], %[[TG6]], %[[TG7]], %[[TG4]], %[[TG8]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
 
 // Reducing output vector [2, 0]
 
 // CHECK-NEXT: %[[TG9:.*]] = vector.tuple_get %[[ES1]], 1 : tuple<vector<2x2xf32>, vector<2x2xf32>>
 // CHECK-NEXT: %[[TG10:.*]] = vector.tuple_get %[[ES3]], 2 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
 // CHECK-NEXT: %[[TG11:.*]] = vector.tuple_get %[[ES4]], 1 : tuple<vector<2x2xi1>, vector<2x2xi1>>
-// CHECK-NEXT:  %[[R1S20:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[TG9]], %[[TG2]], %[[TG10]], %[[TG11]], %[[TG5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+// CHECK-NEXT:  %[[R1S20:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind<add>} %[[TG9]], %[[TG2]], %[[TG10]], %[[TG11]], %[[TG5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
 
 // Reducing output vector [2, 2]
 
 // CHECK-NEXT: %[[TG12:.*]] = vector.tuple_get %[[ES3]], 3 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
-// CHECK-NEXT:  %[[R1S22:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[TG9]], %[[TG6]], %[[TG12]], %[[TG11]], %[[TG8]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+// CHECK-NEXT:  %[[R1S22:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind<add>} %[[TG9]], %[[TG6]], %[[TG12]], %[[TG11]], %[[TG8]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
 
 // CHECK-NEXT: %[[RES0:.*]] = vector.tuple %[[R1S00]], %[[R1S02]], %[[R1S20]], %[[R1S22]] : vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>
 // CHECK-NEXT: %[[RES1:.*]] = vector.insert_slices %[[RES0]], [2, 2], [1, 1] : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> into vector<4x4xf32>
@@ -241,10 +241,10 @@ func @contraction4x4_ikj(%arg0 : vector<4x2xf32>, %arg1 : vector<2x4xf32>,
 // CHECK-NEXT: %[[VTR6:.*]] = vector.transfer_read %{{.*}}[%[[C2]], %[[C0]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32>
 // CHECK-NEXT: %[[VTR7:.*]] = vector.transfer_read %{{.*}}[%[[C2]], %[[C2]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32>
 
-// CHECK-NEXT: %[[R0:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[VTR0]], %[[VTR2]], %[[VTR4]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
-// CHECK-NEXT: %[[R1:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[VTR0]], %[[VTR3]], %[[VTR5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
-// CHECK-NEXT: %[[R2:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[VTR1]], %[[VTR2]], %[[VTR6]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
-// CHECK-NEXT: %[[R3:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[VTR1]], %[[VTR3]], %[[VTR7]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+// CHECK-NEXT: %[[R0:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind<add>} %[[VTR0]], %[[VTR2]], %[[VTR4]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+// CHECK-NEXT: %[[R1:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind<add>} %[[VTR0]], %[[VTR3]], %[[VTR5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+// CHECK-NEXT: %[[R2:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind<add>} %[[VTR1]], %[[VTR2]], %[[VTR6]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+// CHECK-NEXT: %[[R3:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind<add>} %[[VTR1]], %[[VTR3]], %[[VTR7]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
 
 // CHECK-NEXT: vector.transfer_write %[[R0]], %{{.*}}[%[[C0]], %[[C0]]] {masked = [false, false]} : vector<2x2xf32>, memref<4x4xf32>
 // CHECK-NEXT: vector.transfer_write %[[R1]], %{{.*}}[%[[C0]], %[[C2]]] {masked = [false, false]} : vector<2x2xf32>, memref<4x4xf32>
@@ -572,10 +572,10 @@ func @elementwise_unroll(%arg0 : memref<4x4xf32>, %arg1 : memref<4x4xf32>) {
 // CHECK-NEXT: %[[VTR6:.*]] = vector.transfer_read %{{.*}}[%[[C2]], %[[C0]]], %{{.*}} : tensor<4x4xf32>, vector<2x2xf32>
 // CHECK-NEXT: %[[VTR7:.*]] = vector.transfer_read %{{.*}}[%[[C2]], %[[C2]]], %{{.*}} : tensor<4x4xf32>, vector<2x2xf32>
 
-// CHECK-NEXT: %[[R0:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[VTR0]], %[[VTR2]], %[[VTR4]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
-// CHECK-NEXT: %[[R1:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[VTR0]], %[[VTR3]], %[[VTR5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
-// CHECK-NEXT: %[[R2:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[VTR1]], %[[VTR2]], %[[VTR6]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
-// CHECK-NEXT: %[[R3:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[VTR1]], %[[VTR3]], %[[VTR7]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+// CHECK-NEXT: %[[R0:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind<add>} %[[VTR0]], %[[VTR2]], %[[VTR4]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+// CHECK-NEXT: %[[R1:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind<add>} %[[VTR0]], %[[VTR3]], %[[VTR5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+// CHECK-NEXT: %[[R2:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind<add>} %[[VTR1]], %[[VTR2]], %[[VTR6]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+// CHECK-NEXT: %[[R3:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind<add>} %[[VTR1]], %[[VTR3]], %[[VTR7]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
 
 // CHECK-NEXT: %[[VTW0:.*]] = vector.transfer_write %[[R0]], %{{.*}}[%[[C0]], %[[C0]]] {masked = [false, false]} : vector<2x2xf32>, tensor<4x4xf32>
 // CHECK-NEXT: %[[VTW1:.*]] = vector.transfer_write %[[R1]], %[[VTW0]][%[[C0]], %[[C2]]] {masked = [false, false]} : vector<2x2xf32>, tensor<4x4xf32>

diff --git a/mlir/tools/mlir-tblgen/EnumsGen.cpp b/mlir/tools/mlir-tblgen/EnumsGen.cpp
index 629bedb3e03f..e207e313b730 100644
--- a/mlir/tools/mlir-tblgen/EnumsGen.cpp
+++ b/mlir/tools/mlir-tblgen/EnumsGen.cpp
@@ -199,7 +199,7 @@ static void emitSymToStrFnForBitEnum(const Record &enumDef, raw_ostream &os) {
     if (auto val = enumerant.getValue())
       os << formatv("  if ({0}u & val) {{ strs.push_back(\"{1}\"); "
                     "val &= ~{0}u; }\n",
-                    val, enumerant.getSymbol());
+                    val, enumerant.getStr());
   }
   // If we have unknown bit set, return an empty string to signal errors.
   os << "\n  if (val) return \"\";\n";
@@ -261,8 +261,7 @@ static void emitStrToSymFnForBitEnum(const Record &enumDef, raw_ostream &os) {
   for (const auto &enumerant : enumerants) {
     // Skip the special enumerant for None.
     if (auto val = enumerant.getValue())
-      os.indent(6) << formatv(".Case(\"{0}\", {1})\n", enumerant.getSymbol(),
-                              val);
+      os.indent(6) << formatv(".Case(\"{0}\", {1})\n", enumerant.getStr(), val);
   }
   os.indent(6) << ".Default(::llvm::None);\n";
 
