[Mlir-commits] [mlir] 4758e91 - [mlir] Change IteratorType in ContractionOp in Vector dialect from string to enum.

Oleg Shyshkov llvmlistbot at llvm.org
Mon Sep 12 07:59:46 PDT 2022


Author: Oleg Shyshkov
Date: 2022-09-12T16:59:34+02:00
New Revision: 4758e916e1b34d800b03cd2ea6a0a554ce2483be

URL: https://github.com/llvm/llvm-project/commit/4758e916e1b34d800b03cd2ea6a0a554ce2483be
DIFF: https://github.com/llvm/llvm-project/commit/4758e916e1b34d800b03cd2ea6a0a554ce2483be.diff

LOG: [mlir] Change IteratorType in ContractionOp in Vector dialect from string to enum.

This is the first step in changing iterator_type from strings to enums in the Vector and Linalg dialects. This change adds IteratorTypeAttr and uses it in ContractionOp.

To avoid breaking all the tests, print/parse code has conversion between string and enum for now.

There is shared code in StructuredOpsUtils.h that expects iterator types to be strings. To break this dependency, this change forks the helper functions `isParallelIterator` and `isReductionIterator` into utils in both dialects and adds `getIteratorTypeNames()` to support backward compatibility with StructuredGenerator.

In later changes, I plan to add a similar enum attribute to Linalg.

Differential Revision: https://reviews.llvm.org/D133696

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
    mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
    mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
    mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
    mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
    mlir/lib/Conversion/VectorToGPU/NvGpuSupport.cpp
    mlir/lib/Conversion/VectorToGPU/NvGpuSupport.h
    mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
    mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
    mlir/lib/Dialect/Linalg/Utils/Utils.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index 8bcde8a5a2aa..513aedcb82c4 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -727,6 +727,10 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
     // TODO: Remove once prefixing is flipped.
     ArrayAttr getIteratorTypes() { return iterator_types(); }
 
+    SmallVector<StringRef> getIteratorTypeNames() {
+      return llvm::to_vector(getIteratorTypes().getAsValueRange<StringAttr>());
+    }
+
     //========================================================================//
     // Forwarding functions to access interface methods from the
     // DestinationStyleOpInterface.

diff  --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index 20164d7b051c..6f31e6a3abea 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -45,6 +45,12 @@ bool isElementwise(LinalgOp op);
 /// `[0, permutation.size())`.
 bool isPermutation(ArrayRef<int64_t> permutation);
 
+/// Check if `attr` has "parallel" iterator type semantics.
+bool isParallelIterator(Attribute attr);
+
+/// Check if `attr` has "reduction" iterator type semantics.
+bool isReductionIterator(Attribute attr);
+
 /// Helper function that creates a memref::DimOp or tensor::DimOp depending on
 /// the type of `source`.
 Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source, int64_t dim);

diff  --git a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
index 400da2d751ae..9dfde80e793b 100644
--- a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
@@ -78,24 +78,12 @@ constexpr StringRef getPaddingAttrName() { return "padding"; }
 
 /// Use to encode that a particular iterator type has parallel semantics.
 constexpr StringRef getParallelIteratorTypeName() { return "parallel"; }
-inline bool isParallelIterator(Attribute attr) {
-  auto strAttr = attr.dyn_cast_or_null<StringAttr>();
-  return strAttr && strAttr.getValue() == getParallelIteratorTypeName();
-}
 
 /// Use to encode that a particular iterator type has reduction semantics.
 constexpr StringRef getReductionIteratorTypeName() { return "reduction"; }
-inline bool isReductionIterator(Attribute attr) {
-  auto strAttr = attr.dyn_cast_or_null<StringAttr>();
-  return strAttr && strAttr.getValue() == getReductionIteratorTypeName();
-}
 
 /// Use to encode that a particular iterator type has window semantics.
 constexpr StringRef getWindowIteratorTypeName() { return "window"; }
-inline bool isWindowIterator(Attribute attr) {
-  auto strAttr = attr.dyn_cast_or_null<StringAttr>();
-  return strAttr && strAttr.getValue() == getWindowIteratorTypeName();
-}
 
 /// Use to encode that a particular iterator type has window semantics.
 inline ArrayRef<StringRef> getAllIteratorTypeNames() {
@@ -122,19 +110,6 @@ inline unsigned getNumIterators(ArrayAttr iteratorTypes) {
   return res;
 }
 
-/// Typed representation for loop type strings.
-enum class IteratorType { Parallel, Reduction };
-
-inline StringRef toString(IteratorType t) {
-  switch (t) {
-  case IteratorType::Parallel:
-    return getParallelIteratorTypeName();
-  case IteratorType::Reduction:
-    return getReductionIteratorTypeName();
-  }
-  llvm_unreachable("Unsupported IteratorType");
-}
-
 /// Helper StructuredGenerator class to manipulate and rewrite ops with
 /// `StructuredOpInterface`. This is templated for now because VectorOps do not
 /// yet implement the StructuredOpInterface itself.
@@ -145,10 +120,7 @@ class StructuredGenerator {
 
   struct IteratorType {
     IteratorType(StringRef strRef) : strRef(strRef) {}
-    bool isOfType(Attribute attr) const {
-      auto sAttr = attr.dyn_cast<StringAttr>();
-      return sAttr && sAttr.getValue() == strRef;
-    }
+    bool isOfType(StringRef typeName) const { return typeName == strRef; }
     StringRef strRef;
   };
   struct Par : public IteratorType {
@@ -163,7 +135,7 @@ class StructuredGenerator {
 
   StructuredGenerator(OpBuilder &builder, StructuredOpInterface op)
       : builder(builder), ctx(op.getContext()), loc(op.getLoc()),
-        iterators(op.getIteratorTypes()), maps(op.getIndexingMapsArray()),
+        iterators(op.getIteratorTypeNames()), maps(op.getIndexingMapsArray()),
         op(op) {}
 
   bool iters(ArrayRef<IteratorType> its) {
@@ -185,7 +157,7 @@ class StructuredGenerator {
   OpBuilder &builder;
   MLIRContext *ctx;
   Location loc;
-  ArrayAttr iterators;
+  SmallVector<StringRef> iterators;
   SmallVector<AffineMap, 4> maps;
   Operation *op;
 };

diff  --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
index f4316d061560..61a68449f83c 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
@@ -185,6 +185,17 @@ bool isDisjointTransferSet(VectorTransferOpInterface transferA,
 /// corresponding arith operation.
 Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind,
                          Value v1, Value v2);
+
+/// Returns true if `attr` has "parallel" iterator type semantics.
+inline bool isParallelIterator(Attribute attr) {
+  return attr.cast<IteratorTypeAttr>().getValue() == IteratorType::parallel;
+}
+
+/// Returns true if `attr` has "reduction" iterator type semantics.
+inline bool isReductionIterator(Attribute attr) {
+  return attr.cast<IteratorTypeAttr>().getValue() == IteratorType::reduction;
+}
+
 } // namespace vector
 } // namespace mlir
 

diff  --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index ec29d53d2ae8..07f1b21a24b8 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -63,6 +63,21 @@ def Vector_CombiningKindAttr : EnumAttr<Vector_Dialect, CombiningKind, "kind"> {
   let assemblyFormat = "`<` $value `>`";
 }
 
+def IteratorType : I32EnumAttr<"IteratorType", "Iterator type", [
+  I32EnumAttrCase<"parallel", 0>,
+  I32EnumAttrCase<"reduction", 1>
+  ]> {
+    let genSpecializedAttr = 0;
+    let cppNamespace = "::mlir::vector";
+}
+
+def IteratorTypeEnum : EnumAttr<Vector_Dialect, IteratorType, "iterator_type"> {
+  let assemblyFormat = "`<` $value `>`";
+}
+
+def IteratorTypeArrayAttr : TypedArrayAttrBase<IteratorTypeEnum,
+    "Iterator type should be an enum.">;
+
 // TODO: Add an attribute to specify a different algebra with operators other
 // than the current set: {*, +}.
 def Vector_ContractionOp :
@@ -76,7 +91,7 @@ def Vector_ContractionOp :
     Arguments<(ins AnyVector:$lhs, AnyVector:$rhs, AnyType:$acc,
                Variadic<VectorOf<[I1]>>:$masks,
                ArrayAttr:$indexing_maps,
-               ArrayAttr:$iterator_types,
+               IteratorTypeArrayAttr:$iterator_types,
                DefaultValuedAttr<Vector_CombiningKindAttr,
                                  "CombiningKind::ADD">:$kind)>,
     Results<(outs AnyType)> {
@@ -201,7 +216,7 @@ def Vector_ContractionOp :
       "ArrayAttr":$indexingMaps, "ArrayAttr":$iteratorTypes)>,
     OpBuilder<(ins "Value":$lhs, "Value":$rhs, "Value":$acc,
       "ArrayRef<ArrayRef<AffineExpr>>":$indexingExprs,
-      "ArrayRef<StringRef>":$iteratorTypes)>,
+      "ArrayRef<IteratorType>":$iteratorTypes)>,
     OpBuilder<(ins "Value":$lhs, "Value":$rhs, "Value":$acc,
       "ArrayAttr":$indexingMaps, "ArrayAttr":$iteratorTypes,
       "CombiningKind":$kind)>
@@ -249,6 +264,14 @@ def Vector_ContractionOp :
     static CombiningKind getDefaultKind() {
       return CombiningKind::ADD;
     }
+
+    // Returns iterator types in string format.
+    SmallVector<StringRef> getIteratorTypeNames() {
+      return llvm::to_vector(
+          llvm::map_range(getIteratorTypes(), [](Attribute a) {
+            return stringifyIteratorType(a.cast<IteratorTypeAttr>().getValue());
+          }));
+    }
   }];
 
   let hasCanonicalizer = 1;

diff  --git a/mlir/lib/Conversion/VectorToGPU/NvGpuSupport.cpp b/mlir/lib/Conversion/VectorToGPU/NvGpuSupport.cpp
index 776e746eee63..9df576c4e73e 100644
--- a/mlir/lib/Conversion/VectorToGPU/NvGpuSupport.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/NvGpuSupport.cpp
@@ -217,10 +217,10 @@ FailureOr<nvgpu::LdMatrixParams> getLdMatrixParams(const WarpMatrixInfo &type,
     params.targetLayout = NVVM::MMALayout::col;
   }
   ArrayRef<int64_t> shape = type.vectorType.getShape();
-  params.contiguousDimType =
-      transpose ? IteratorType::Parallel : IteratorType::Reduction;
+  params.contiguousDimType = transpose ? vector::IteratorType::parallel
+                                       : vector::IteratorType::reduction;
 
-  if (params.contiguousDimType == IteratorType::Reduction) {
+  if (params.contiguousDimType == vector::IteratorType::reduction) {
     params.numTiles = (shape[0] / kNumRowsPerTile) *
                       ((shape[1] * elType.getIntOrFloatBitWidth()) / 128);
   } else {
@@ -250,7 +250,7 @@ getLaneIdToLdMatrixMatrixCoord(Location loc, OpBuilder &builder,
   };
 
   // This case corresponds to row-major A|C or col-major B operands.
-  if (params.contiguousDimType == IteratorType::Reduction) {
+  if (params.contiguousDimType == vector::IteratorType::reduction) {
     AffineExpr row = d0 % (operandShape[0]);
     AffineExpr col = d0.floorDiv(operandShape[0]) * (kElementsPer128b);
     return makeMap({row, col});
@@ -258,7 +258,7 @@ getLaneIdToLdMatrixMatrixCoord(Location loc, OpBuilder &builder,
 
   // This case Corresponds to col-major A|C or row-major B operands. The
   // operandShape given is already pre-transposed (e.g. 8x16 = KxN).
-  if (params.contiguousDimType == IteratorType::Parallel) {
+  if (params.contiguousDimType == vector::IteratorType::parallel) {
     const int64_t num8x128bCols = (operandShape[0] * bitsPerElement) / 128;
     // Threads are assigned in groups of 8 first across columns, then to
     // rows. This is transpose of what `ldmatrix` expects, but when
@@ -293,9 +293,9 @@ PrepareContractToGPUMMASync::matchAndRewrite(vector::ContractionOp op,
   SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();
   if (iteratorTypes.size() != 3)
     return failure();
-  if (!(isParallelIterator(iteratorTypes[0]) &&
-        isParallelIterator(iteratorTypes[1]) &&
-        isReductionIterator(iteratorTypes[2])))
+  if (!(vector::isParallelIterator(iteratorTypes[0]) &&
+        vector::isParallelIterator(iteratorTypes[1]) &&
+        vector::isReductionIterator(iteratorTypes[2])))
     return failure();
 
   // The canonical form is "TNT" = A row-major, B col-major, C row-major.

diff  --git a/mlir/lib/Conversion/VectorToGPU/NvGpuSupport.h b/mlir/lib/Conversion/VectorToGPU/NvGpuSupport.h
index 9902faa835a6..aee429edbad7 100644
--- a/mlir/lib/Conversion/VectorToGPU/NvGpuSupport.h
+++ b/mlir/lib/Conversion/VectorToGPU/NvGpuSupport.h
@@ -71,7 +71,7 @@ struct LdMatrixParams {
   VectorType fragmentType;
   bool isAccum;
   int64_t numTiles;
-  IteratorType contiguousDimType;
+  vector::IteratorType contiguousDimType;
   NVVM::MMALayout targetLayout;
 };
 

diff  --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index a1708dfd3b56..707a35044980 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -74,9 +74,9 @@ static bool contractSupportsMMAMatrixType(vector::ContractionOp contract,
   AffineExpr m, n, k;
   bindDims(contract.getContext(), m, n, k);
   auto iteratorTypes = contract.getIteratorTypes().getValue();
-  if (!(isParallelIterator(iteratorTypes[0]) &&
-        isParallelIterator(iteratorTypes[1]) &&
-        isReductionIterator(iteratorTypes[2])))
+  if (!(vector::isParallelIterator(iteratorTypes[0]) &&
+        vector::isParallelIterator(iteratorTypes[1]) &&
+        vector::isReductionIterator(iteratorTypes[2])))
     return false;
 
   // The contract needs to represent a matmul to be able to convert to
@@ -296,9 +296,9 @@ struct PrepareContractToGPUMMA
     static constexpr std::array<int64_t, 2> perm = {1, 0};
     auto iteratorTypes = op.getIteratorTypes().getValue();
     SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();
-    if (!(isParallelIterator(iteratorTypes[0]) &&
-          isParallelIterator(iteratorTypes[1]) &&
-          isReductionIterator(iteratorTypes[2])))
+    if (!(vector::isParallelIterator(iteratorTypes[0]) &&
+          vector::isParallelIterator(iteratorTypes[1]) &&
+          vector::isReductionIterator(iteratorTypes[2])))
       return failure();
     //
     // Two outer parallel, one inner reduction (matmat flavor).

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 94f338a0d1de..badff4310187 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1488,13 +1488,14 @@ struct Conv1DNwcGenerator : public StructuredGenerator<LinalgOp> {
   // Create a contraction: lhs{n, w, c} * rhs{c, f} -> res{n, w, f}
   Value conv1dSliceAsContraction(OpBuilder &b, Location loc, Value lhs,
                                  Value rhs, Value res) {
-    StringRef par = Par().strRef, red = Red().strRef;
+    vector::IteratorType par = vector::IteratorType::parallel;
+    vector::IteratorType red = vector::IteratorType::reduction;
     AffineExpr n, w, f, c;
     bindDims(ctx, n, w, f, c);
     return builder.create<vector::ContractionOp>(
         loc, lhs, rhs, res,
         /*indexingMaps=*/MapList{{n, w, c}, {c, f}, {n, w, f}},
-        /*iteratorTypes=*/ArrayRef<StringRef>{par, par, par, red});
+        /*iteratorTypes=*/ArrayRef<vector::IteratorType>{par, par, par, red});
   }
 
   /// Generate a vector implementation for:

diff  --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 2c430456d866..d70e81cc6d45 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -199,6 +199,16 @@ bool isPermutation(ArrayRef<int64_t> permutation) {
   return count(indexCounts, 1) == static_cast<int64_t>(permutation.size());
 }
 
+bool isParallelIterator(Attribute attr) {
+  auto strAttr = attr.dyn_cast_or_null<StringAttr>();
+  return strAttr && strAttr.getValue() == getParallelIteratorTypeName();
+}
+
+bool isReductionIterator(Attribute attr) {
+  auto strAttr = attr.dyn_cast_or_null<StringAttr>();
+  return strAttr && strAttr.getValue() == getReductionIteratorTypeName();
+}
+
 /// Helper function that creates a memref::DimOp or tensor::DimOp depending on
 /// the type of `source`.
 Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source, int64_t dim) {

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index d81ed779cd8b..aa44a7f07878 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -350,7 +350,7 @@ static bool isAdmissableTensorExp(Merger &merger, linalg::GenericOp op,
   if (isMaterializing(lhs->get())) {
     unsigned nest = 0;
     for (unsigned i = 0; i < numLoops; i++) {
-      if (isReductionIterator(iteratorTypes[topSort[i]]))
+      if (linalg::isReductionIterator(iteratorTypes[topSort[i]]))
         break; // terminate at first reduction
       nest++;
     }
@@ -1234,7 +1234,7 @@ static Operation *genFor(Merger &merger, CodeGen &codegen, OpBuilder &builder,
   unsigned tensor = merger.tensor(fb);
   assert(idx == merger.index(fb));
   auto iteratorTypes = op.iterator_types().getValue();
-  bool isReduction = isReductionIterator(iteratorTypes[idx]);
+  bool isReduction = linalg::isReductionIterator(iteratorTypes[idx]);
   bool isSparse = merger.isDim(fb, Dim::kSparse);
   bool isVector = isVectorFor(codegen, isInner, isReduction, isSparse) &&
                   denseUnitStrides(merger, op, idx);

diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 574b4b977961..78ccb1fcf26d 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -455,14 +455,18 @@ void ReductionOp::getCanonicalizationPatterns(RewritePatternSet &results,
 void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
                                   Value lhs, Value rhs, Value acc,
                                   ArrayRef<ArrayRef<AffineExpr>> indexingExprs,
-                                  ArrayRef<StringRef> iteratorTypes) {
+                                  ArrayRef<IteratorType> iteratorTypes) {
   result.addOperands({lhs, rhs, acc});
   result.addTypes(acc.getType());
   result.addAttribute(::mlir::getIndexingMapsAttrName(),
                       builder.getAffineMapArrayAttr(
                           AffineMap::inferFromExprList(indexingExprs)));
-  result.addAttribute(::mlir::getIteratorTypesAttrName(),
-                      builder.getStrArrayAttr(iteratorTypes));
+  result.addAttribute(
+      ::mlir::getIteratorTypesAttrName(),
+      builder.getArrayAttr(llvm::to_vector(llvm::map_range(
+          iteratorTypes, [&](IteratorType t) -> mlir::Attribute {
+            return IteratorTypeAttr::get(builder.getContext(), t);
+          }))));
 }
 
 void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
@@ -510,6 +514,27 @@ ParseResult ContractionOp::parse(OpAsmParser &parser, OperationState &result) {
     return failure();
   result.attributes.assign(dictAttr.getValue().begin(),
                            dictAttr.getValue().end());
+
+  // Convert the array of strings into an array of IteratorType enums. This is
+  // needed because tests still use the old format, in which the
+  // 'iterator_types' attribute is represented as an array of strings.
+  // TODO: Remove this conversion once tests are fixed.
+  ArrayAttr iteratorTypes =
+      result.attributes.get("iterator_types").cast<ArrayAttr>();
+
+  SmallVector<Attribute> iteratorTypeAttrs;
+
+  for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
+    auto maybeIteratorType = symbolizeIteratorType(s);
+    if (!maybeIteratorType.hasValue())
+      return parser.emitError(loc) << "unexpected iterator_type (" << s << ")";
+
+    iteratorTypeAttrs.push_back(IteratorTypeAttr::get(
+        parser.getContext(), maybeIteratorType.getValue()));
+  }
+  result.attributes.set("iterator_types",
+                        parser.getBuilder().getArrayAttr(iteratorTypeAttrs));
+
   if (!result.attributes.get(ContractionOp::getKindAttrStrName())) {
     result.addAttribute(
         ContractionOp::getKindAttrStrName(),
@@ -538,9 +563,26 @@ void ContractionOp::print(OpAsmPrinter &p) {
   llvm::StringSet<> traitAttrsSet;
   traitAttrsSet.insert(attrNames.begin(), attrNames.end());
   SmallVector<NamedAttribute, 8> attrs;
-  for (auto attr : (*this)->getAttrs())
-    if (traitAttrsSet.count(attr.getName().strref()) > 0)
+  for (auto attr : (*this)->getAttrs()) {
+    if (attr.getName() == getIteratorTypesAttrName()) {
+      auto iteratorTypes =
+          attr.getValue()
+              .cast<ArrayAttr>()
+              .getAsValueRange<IteratorTypeAttr, IteratorType>();
+      // Convert IteratorType enums into the string representation. This is
+      // needed, because tests still use the old format when 'iterator_types'
+      // attribute is represented as an array of strings.
+      // TODO: Remove this conversion once tests are fixed.
+      SmallVector<Attribute> iteratorTypeNames = llvm::to_vector(
+          llvm::map_range(iteratorTypes, [&](IteratorType t) -> Attribute {
+            return StringAttr::get(getContext(), stringifyIteratorType(t));
+          }));
+
+      attrs.emplace_back(getIteratorTypesAttrName(),
+                         ArrayAttr::get(getContext(), iteratorTypeNames));
+    } else if (traitAttrsSet.count(attr.getName().strref()) > 0)
       attrs.push_back(attr);
+  }
 
   auto dictAttr = DictionaryAttr::get(getContext(), attrs);
   p << " " << dictAttr << " " << getLhs() << ", ";
@@ -746,11 +788,11 @@ static int64_t getResultIndex(AffineMap map, AffineExpr targetExpr) {
 
 static std::vector<std::pair<int64_t, int64_t>>
 getDimMap(ArrayRef<AffineMap> indexingMaps, ArrayAttr iteratorTypes,
-          StringRef targetIteratorTypeName, MLIRContext *context) {
+          IteratorType targetIteratorType, MLIRContext *context) {
   std::vector<std::pair<int64_t, int64_t>> dimMap;
   for (const auto &it : llvm::enumerate(iteratorTypes)) {
-    auto iteratorTypeName = it.value().cast<StringAttr>().getValue();
-    if (iteratorTypeName != targetIteratorTypeName)
+    auto iteratorType = it.value().cast<IteratorTypeAttr>().getValue();
+    if (iteratorType != targetIteratorType)
       continue;
     // Search lhs/rhs map results for 'targetExpr'.
     auto targetExpr = getAffineDimExpr(it.index(), context);
@@ -771,8 +813,8 @@ void ContractionOp::getIterationBounds(
   for (const auto &it : llvm::enumerate(getIteratorTypes())) {
     // Search lhs/rhs map results for 'targetExpr'.
     auto targetExpr = getAffineDimExpr(it.index(), getContext());
-    auto iteratorTypeName = it.value().cast<StringAttr>().getValue();
-    if (iteratorTypeName == getReductionIteratorTypeName()) {
+    auto iteratorType = it.value().cast<IteratorTypeAttr>().getValue();
+    if (iteratorType == IteratorType::reduction) {
       // Get reduction dim size from lhs shape (same size in rhsShape).
       int64_t lhsDimIndex = getResultIndex(indexingMaps[0], targetExpr);
       assert(lhsDimIndex >= 0);
@@ -803,14 +845,14 @@ void ContractionOp::getIterationIndexMap(
 
 std::vector<std::pair<int64_t, int64_t>> ContractionOp::getContractingDimMap() {
   SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
-  return getDimMap(indexingMaps, getIteratorTypes(),
-                   getReductionIteratorTypeName(), getContext());
+  return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::reduction,
+                   getContext());
 }
 
 std::vector<std::pair<int64_t, int64_t>> ContractionOp::getBatchDimMap() {
   SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
-  return getDimMap(indexingMaps, getIteratorTypes(),
-                   getParallelIteratorTypeName(), getContext());
+  return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::parallel,
+                   getContext());
 }
 
 Optional<SmallVector<int64_t, 4>> ContractionOp::getShapeForUnroll() {

diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 3ff045031be1..2bd6756fd77a 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -22,6 +22,7 @@
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/ImplicitLocOpBuilder.h"
@@ -986,13 +987,13 @@ struct MultiReduceToContract
     SmallVector<bool> reductionMask = reduceOp.getReductionMask();
     auto srcMap = rewriter.getMultiDimIdentityMap(reductionMask.size());
     SmallVector<AffineExpr> exprs;
-    SmallVector<StringRef> iteratorTypes;
+    SmallVector<vector::IteratorType> iteratorTypes;
     for (const auto &isReduceDim : llvm::enumerate(reductionMask)) {
       if (!isReduceDim.value()) {
-        iteratorTypes.push_back(getParallelIteratorTypeName());
+        iteratorTypes.push_back(vector::IteratorType::parallel);
         exprs.push_back(rewriter.getAffineDimExpr(isReduceDim.index()));
       } else {
-        iteratorTypes.push_back(getReductionIteratorTypeName());
+        iteratorTypes.push_back(vector::IteratorType::reduction);
       }
     }
     auto dstMap = AffineMap::get(/*dimCount=*/reductionMask.size(),
@@ -1000,7 +1001,10 @@ struct MultiReduceToContract
     rewriter.replaceOpWithNewOp<mlir::vector::ContractionOp>(
         reduceOp, mulOp->getOperand(0), mulOp->getOperand(1), reduceOp.getAcc(),
         rewriter.getAffineMapArrayAttr({srcMap, srcMap, dstMap}),
-        rewriter.getStrArrayAttr(iteratorTypes));
+        rewriter.getArrayAttr(llvm::to_vector(llvm::map_range(
+            iteratorTypes, [&](IteratorType t) -> mlir::Attribute {
+              return IteratorTypeAttr::get(rewriter.getContext(), t);
+            }))));
     return success();
   }
 };


        


More information about the Mlir-commits mailing list