[Mlir-commits] [mlir] 4758e91 - [mlir] Change IteratorType in ContractionOp in Vector dialect from string to enum.
Oleg Shyshkov
llvmlistbot at llvm.org
Mon Sep 12 07:59:46 PDT 2022
Author: Oleg Shyshkov
Date: 2022-09-12T16:59:34+02:00
New Revision: 4758e916e1b34d800b03cd2ea6a0a554ce2483be
URL: https://github.com/llvm/llvm-project/commit/4758e916e1b34d800b03cd2ea6a0a554ce2483be
DIFF: https://github.com/llvm/llvm-project/commit/4758e916e1b34d800b03cd2ea6a0a554ce2483be.diff
LOG: [mlir] Change IteratorType in ContractionOp in Vector dialect from string to enum.
This is the first step in replacing iterator_types from strings with enums in the Vector and Linalg dialects. This change adds IteratorTypeAttr and uses it in ContractionOp.
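For code that builds the op programmatically, the enum-based builder overload replaces the string-based one. A rough before/after sketch (illustrative only, assuming lhs/rhs/acc values, dim exprs m/n/k, and a MapList alias for ArrayRef<ArrayRef<AffineExpr>> are in scope):

  // Before: iterator types passed as strings.
  builder.create<vector::ContractionOp>(
      loc, lhs, rhs, acc,
      /*indexingMaps=*/MapList{{m, k}, {k, n}, {m, n}},
      /*iteratorTypes=*/ArrayRef<StringRef>{"parallel", "parallel", "reduction"});

  // After: iterator types passed as vector::IteratorType enum values.
  builder.create<vector::ContractionOp>(
      loc, lhs, rhs, acc,
      /*indexingMaps=*/MapList{{m, k}, {k, n}, {m, n}},
      /*iteratorTypes=*/ArrayRef<vector::IteratorType>{
          vector::IteratorType::parallel, vector::IteratorType::parallel,
          vector::IteratorType::reduction});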
To avoid breaking all the tests, the print/parse code converts between strings and enums for now.
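The compatibility shim relies on the generated symbolizeIteratorType/stringifyIteratorType helpers; a minimal sketch of the round-trip the parser and printer perform (standalone illustration, not the exact op code):

  // Parser side: legacy string -> enum; unknown names are rejected with an error.
  Optional<vector::IteratorType> kind = vector::symbolizeIteratorType("parallel");
  assert(kind.hasValue() && kind.getValue() == vector::IteratorType::parallel);

  // Printer side: enum -> legacy string, so existing CHECK lines keep matching.
  StringRef name = vector::stringifyIteratorType(vector::IteratorType::reduction);
  assert(name == "reduction");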
There is shared code in StructuredOpsUtils.h that expects iterator types to be strings. To break this dependency, this change forks the helper functions `isParallelIterator` and `isReductionIterator` into the utils of both dialects and adds `getIteratorTypeNames()` to keep backward compatibility with StructuredGenerator.
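Call sites pick the helper that matches the attribute form they hold; a sketch based on the VectorToGPU and sparsification changes below (contract and linalgOp are assumed to be a vector::ContractionOp and a linalg::GenericOp):

  // Vector dialect: iterator_types now holds IteratorTypeAttr enums.
  auto vecIters = contract.getIteratorTypes().getValue();
  bool isMatmulLike = vector::isParallelIterator(vecIters[0]) &&
                      vector::isParallelIterator(vecIters[1]) &&
                      vector::isReductionIterator(vecIters[2]);

  // Linalg / sparse code paths: iterator types are still StringAttrs, handled
  // by the string-based helpers forked into the Linalg utils.
  bool hasReduction = llvm::any_of(linalgOp.iterator_types().getValue(),
                                   [](Attribute attr) {
                                     return linalg::isReductionIterator(attr);
                                   });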
In later changes, I plan to add a similar enum attribute to Linalg.
Differential Revision: https://reviews.llvm.org/D133696
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
mlir/lib/Conversion/VectorToGPU/NvGpuSupport.cpp
mlir/lib/Conversion/VectorToGPU/NvGpuSupport.h
mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/lib/Dialect/Linalg/Utils/Utils.cpp
mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index 8bcde8a5a2aa..513aedcb82c4 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -727,6 +727,10 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
// TODO: Remove once prefixing is flipped.
ArrayAttr getIteratorTypes() { return iterator_types(); }
+ SmallVector<StringRef> getIteratorTypeNames() {
+ return llvm::to_vector(getIteratorTypes().getAsValueRange<StringAttr>());
+ }
+
//========================================================================//
// Forwarding functions to access interface methods from the
// DestinationStyleOpInterface.
diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index 20164d7b051c..6f31e6a3abea 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -45,6 +45,12 @@ bool isElementwise(LinalgOp op);
/// `[0, permutation.size())`.
bool isPermutation(ArrayRef<int64_t> permutation);
+/// Check if `attr` has "parallel" iterator type semantics.
+bool isParallelIterator(Attribute attr);
+
+/// Check if `attr` has "reduction" iterator type semantics.
+bool isReductionIterator(Attribute attr);
+
/// Helper function that creates a memref::DimOp or tensor::DimOp depending on
/// the type of `source`.
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source, int64_t dim);
diff --git a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
index 400da2d751ae..9dfde80e793b 100644
--- a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
@@ -78,24 +78,12 @@ constexpr StringRef getPaddingAttrName() { return "padding"; }
/// Use to encode that a particular iterator type has parallel semantics.
constexpr StringRef getParallelIteratorTypeName() { return "parallel"; }
-inline bool isParallelIterator(Attribute attr) {
- auto strAttr = attr.dyn_cast_or_null<StringAttr>();
- return strAttr && strAttr.getValue() == getParallelIteratorTypeName();
-}
/// Use to encode that a particular iterator type has reduction semantics.
constexpr StringRef getReductionIteratorTypeName() { return "reduction"; }
-inline bool isReductionIterator(Attribute attr) {
- auto strAttr = attr.dyn_cast_or_null<StringAttr>();
- return strAttr && strAttr.getValue() == getReductionIteratorTypeName();
-}
/// Use to encode that a particular iterator type has window semantics.
constexpr StringRef getWindowIteratorTypeName() { return "window"; }
-inline bool isWindowIterator(Attribute attr) {
- auto strAttr = attr.dyn_cast_or_null<StringAttr>();
- return strAttr && strAttr.getValue() == getWindowIteratorTypeName();
-}
/// Use to encode that a particular iterator type has window semantics.
inline ArrayRef<StringRef> getAllIteratorTypeNames() {
@@ -122,19 +110,6 @@ inline unsigned getNumIterators(ArrayAttr iteratorTypes) {
return res;
}
-/// Typed representation for loop type strings.
-enum class IteratorType { Parallel, Reduction };
-
-inline StringRef toString(IteratorType t) {
- switch (t) {
- case IteratorType::Parallel:
- return getParallelIteratorTypeName();
- case IteratorType::Reduction:
- return getReductionIteratorTypeName();
- }
- llvm_unreachable("Unsupported IteratorType");
-}
-
/// Helper StructuredGenerator class to manipulate and rewrite ops with
/// `StructuredOpInterface`. This is templated for now because VectorOps do not
/// yet implement the StructuredOpInterface itself.
@@ -145,10 +120,7 @@ class StructuredGenerator {
struct IteratorType {
IteratorType(StringRef strRef) : strRef(strRef) {}
- bool isOfType(Attribute attr) const {
- auto sAttr = attr.dyn_cast<StringAttr>();
- return sAttr && sAttr.getValue() == strRef;
- }
+ bool isOfType(StringRef typeName) const { return typeName == strRef; }
StringRef strRef;
};
struct Par : public IteratorType {
@@ -163,7 +135,7 @@ class StructuredGenerator {
StructuredGenerator(OpBuilder &builder, StructuredOpInterface op)
: builder(builder), ctx(op.getContext()), loc(op.getLoc()),
- iterators(op.getIteratorTypes()), maps(op.getIndexingMapsArray()),
+ iterators(op.getIteratorTypeNames()), maps(op.getIndexingMapsArray()),
op(op) {}
bool iters(ArrayRef<IteratorType> its) {
@@ -185,7 +157,7 @@ class StructuredGenerator {
OpBuilder &builder;
MLIRContext *ctx;
Location loc;
- ArrayAttr iterators;
+ SmallVector<StringRef> iterators;
SmallVector<AffineMap, 4> maps;
Operation *op;
};
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
index f4316d061560..61a68449f83c 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
@@ -185,6 +185,17 @@ bool isDisjointTransferSet(VectorTransferOpInterface transferA,
/// corresponding arith operation.
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind,
Value v1, Value v2);
+
+/// Returns true if `attr` has "parallel" iterator type semantics.
+inline bool isParallelIterator(Attribute attr) {
+ return attr.cast<IteratorTypeAttr>().getValue() == IteratorType::parallel;
+}
+
+/// Returns true if `attr` has "reduction" iterator type semantics.
+inline bool isReductionIterator(Attribute attr) {
+ return attr.cast<IteratorTypeAttr>().getValue() == IteratorType::reduction;
+}
+
} // namespace vector
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index ec29d53d2ae8..07f1b21a24b8 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -63,6 +63,21 @@ def Vector_CombiningKindAttr : EnumAttr<Vector_Dialect, CombiningKind, "kind"> {
let assemblyFormat = "`<` $value `>`";
}
+def IteratorType : I32EnumAttr<"IteratorType", "Iterator type", [
+ I32EnumAttrCase<"parallel", 0>,
+ I32EnumAttrCase<"reduction", 1>
+ ]> {
+ let genSpecializedAttr = 0;
+ let cppNamespace = "::mlir::vector";
+}
+
+def IteratorTypeEnum : EnumAttr<Vector_Dialect, IteratorType, "iterator_type"> {
+ let assemblyFormat = "`<` $value `>`";
+}
+
+def IteratorTypeArrayAttr : TypedArrayAttrBase<IteratorTypeEnum,
+ "Iterator type should be an enum.">;
+
// TODO: Add an attribute to specify a different algebra with operators other
// than the current set: {*, +}.
def Vector_ContractionOp :
@@ -76,7 +91,7 @@ def Vector_ContractionOp :
Arguments<(ins AnyVector:$lhs, AnyVector:$rhs, AnyType:$acc,
Variadic<VectorOf<[I1]>>:$masks,
ArrayAttr:$indexing_maps,
- ArrayAttr:$iterator_types,
+ IteratorTypeArrayAttr:$iterator_types,
DefaultValuedAttr<Vector_CombiningKindAttr,
"CombiningKind::ADD">:$kind)>,
Results<(outs AnyType)> {
@@ -201,7 +216,7 @@ def Vector_ContractionOp :
"ArrayAttr":$indexingMaps, "ArrayAttr":$iteratorTypes)>,
OpBuilder<(ins "Value":$lhs, "Value":$rhs, "Value":$acc,
"ArrayRef<ArrayRef<AffineExpr>>":$indexingExprs,
- "ArrayRef<StringRef>":$iteratorTypes)>,
+ "ArrayRef<IteratorType>":$iteratorTypes)>,
OpBuilder<(ins "Value":$lhs, "Value":$rhs, "Value":$acc,
"ArrayAttr":$indexingMaps, "ArrayAttr":$iteratorTypes,
"CombiningKind":$kind)>
@@ -249,6 +264,14 @@ def Vector_ContractionOp :
static CombiningKind getDefaultKind() {
return CombiningKind::ADD;
}
+
+ // Returns iterator types in string format.
+ SmallVector<StringRef> getIteratorTypeNames() {
+ return llvm::to_vector(
+ llvm::map_range(getIteratorTypes(), [](Attribute a) {
+ return stringifyIteratorType(a.cast<IteratorTypeAttr>().getValue());
+ }));
+ }
}];
let hasCanonicalizer = 1;
diff --git a/mlir/lib/Conversion/VectorToGPU/NvGpuSupport.cpp b/mlir/lib/Conversion/VectorToGPU/NvGpuSupport.cpp
index 776e746eee63..9df576c4e73e 100644
--- a/mlir/lib/Conversion/VectorToGPU/NvGpuSupport.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/NvGpuSupport.cpp
@@ -217,10 +217,10 @@ FailureOr<nvgpu::LdMatrixParams> getLdMatrixParams(const WarpMatrixInfo &type,
params.targetLayout = NVVM::MMALayout::col;
}
ArrayRef<int64_t> shape = type.vectorType.getShape();
- params.contiguousDimType =
- transpose ? IteratorType::Parallel : IteratorType::Reduction;
+ params.contiguousDimType = transpose ? vector::IteratorType::parallel
+ : vector::IteratorType::reduction;
- if (params.contiguousDimType == IteratorType::Reduction) {
+ if (params.contiguousDimType == vector::IteratorType::reduction) {
params.numTiles = (shape[0] / kNumRowsPerTile) *
((shape[1] * elType.getIntOrFloatBitWidth()) / 128);
} else {
@@ -250,7 +250,7 @@ getLaneIdToLdMatrixMatrixCoord(Location loc, OpBuilder &builder,
};
// This case corresponds to row-major A|C or col-major B operands.
- if (params.contiguousDimType == IteratorType::Reduction) {
+ if (params.contiguousDimType == vector::IteratorType::reduction) {
AffineExpr row = d0 % (operandShape[0]);
AffineExpr col = d0.floorDiv(operandShape[0]) * (kElementsPer128b);
return makeMap({row, col});
@@ -258,7 +258,7 @@ getLaneIdToLdMatrixMatrixCoord(Location loc, OpBuilder &builder,
// This case Corresponds to col-major A|C or row-major B operands. The
// operandShape given is already pre-transposed (e.g. 8x16 = KxN).
- if (params.contiguousDimType == IteratorType::Parallel) {
+ if (params.contiguousDimType == vector::IteratorType::parallel) {
const int64_t num8x128bCols = (operandShape[0] * bitsPerElement) / 128;
// Threads are assigned in groups of 8 first across columns, then to
// rows. This is transpose of what `ldmatrix` expects, but when
@@ -293,9 +293,9 @@ PrepareContractToGPUMMASync::matchAndRewrite(vector::ContractionOp op,
SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();
if (iteratorTypes.size() != 3)
return failure();
- if (!(isParallelIterator(iteratorTypes[0]) &&
- isParallelIterator(iteratorTypes[1]) &&
- isReductionIterator(iteratorTypes[2])))
+ if (!(vector::isParallelIterator(iteratorTypes[0]) &&
+ vector::isParallelIterator(iteratorTypes[1]) &&
+ vector::isReductionIterator(iteratorTypes[2])))
return failure();
// The canonical form is "TNT" = A row-major, B col-major, C row-major.
diff --git a/mlir/lib/Conversion/VectorToGPU/NvGpuSupport.h b/mlir/lib/Conversion/VectorToGPU/NvGpuSupport.h
index 9902faa835a6..aee429edbad7 100644
--- a/mlir/lib/Conversion/VectorToGPU/NvGpuSupport.h
+++ b/mlir/lib/Conversion/VectorToGPU/NvGpuSupport.h
@@ -71,7 +71,7 @@ struct LdMatrixParams {
VectorType fragmentType;
bool isAccum;
int64_t numTiles;
- IteratorType contiguousDimType;
+ vector::IteratorType contiguousDimType;
NVVM::MMALayout targetLayout;
};
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index a1708dfd3b56..707a35044980 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -74,9 +74,9 @@ static bool contractSupportsMMAMatrixType(vector::ContractionOp contract,
AffineExpr m, n, k;
bindDims(contract.getContext(), m, n, k);
auto iteratorTypes = contract.getIteratorTypes().getValue();
- if (!(isParallelIterator(iteratorTypes[0]) &&
- isParallelIterator(iteratorTypes[1]) &&
- isReductionIterator(iteratorTypes[2])))
+ if (!(vector::isParallelIterator(iteratorTypes[0]) &&
+ vector::isParallelIterator(iteratorTypes[1]) &&
+ vector::isReductionIterator(iteratorTypes[2])))
return false;
// The contract needs to represent a matmul to be able to convert to
@@ -296,9 +296,9 @@ struct PrepareContractToGPUMMA
static constexpr std::array<int64_t, 2> perm = {1, 0};
auto iteratorTypes = op.getIteratorTypes().getValue();
SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();
- if (!(isParallelIterator(iteratorTypes[0]) &&
- isParallelIterator(iteratorTypes[1]) &&
- isReductionIterator(iteratorTypes[2])))
+ if (!(vector::isParallelIterator(iteratorTypes[0]) &&
+ vector::isParallelIterator(iteratorTypes[1]) &&
+ vector::isReductionIterator(iteratorTypes[2])))
return failure();
//
// Two outer parallel, one inner reduction (matmat flavor).
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 94f338a0d1de..badff4310187 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1488,13 +1488,14 @@ struct Conv1DNwcGenerator : public StructuredGenerator<LinalgOp> {
// Create a contraction: lhs{n, w, c} * rhs{c, f} -> res{n, w, f}
Value conv1dSliceAsContraction(OpBuilder &b, Location loc, Value lhs,
Value rhs, Value res) {
- StringRef par = Par().strRef, red = Red().strRef;
+ vector::IteratorType par = vector::IteratorType::parallel;
+ vector::IteratorType red = vector::IteratorType::reduction;
AffineExpr n, w, f, c;
bindDims(ctx, n, w, f, c);
return builder.create<vector::ContractionOp>(
loc, lhs, rhs, res,
/*indexingMaps=*/MapList{{n, w, c}, {c, f}, {n, w, f}},
- /*iteratorTypes=*/ArrayRef<StringRef>{par, par, par, red});
+ /*iteratorTypes=*/ArrayRef<vector::IteratorType>{par, par, par, red});
}
/// Generate a vector implementation for:
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 2c430456d866..d70e81cc6d45 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -199,6 +199,16 @@ bool isPermutation(ArrayRef<int64_t> permutation) {
return count(indexCounts, 1) == static_cast<int64_t>(permutation.size());
}
+bool isParallelIterator(Attribute attr) {
+ auto strAttr = attr.dyn_cast_or_null<StringAttr>();
+ return strAttr && strAttr.getValue() == getParallelIteratorTypeName();
+}
+
+bool isReductionIterator(Attribute attr) {
+ auto strAttr = attr.dyn_cast_or_null<StringAttr>();
+ return strAttr && strAttr.getValue() == getReductionIteratorTypeName();
+}
+
/// Helper function that creates a memref::DimOp or tensor::DimOp depending on
/// the type of `source`.
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source, int64_t dim) {
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index d81ed779cd8b..aa44a7f07878 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -350,7 +350,7 @@ static bool isAdmissableTensorExp(Merger &merger, linalg::GenericOp op,
if (isMaterializing(lhs->get())) {
unsigned nest = 0;
for (unsigned i = 0; i < numLoops; i++) {
- if (isReductionIterator(iteratorTypes[topSort[i]]))
+ if (linalg::isReductionIterator(iteratorTypes[topSort[i]]))
break; // terminate at first reduction
nest++;
}
@@ -1234,7 +1234,7 @@ static Operation *genFor(Merger &merger, CodeGen &codegen, OpBuilder &builder,
unsigned tensor = merger.tensor(fb);
assert(idx == merger.index(fb));
auto iteratorTypes = op.iterator_types().getValue();
- bool isReduction = isReductionIterator(iteratorTypes[idx]);
+ bool isReduction = linalg::isReductionIterator(iteratorTypes[idx]);
bool isSparse = merger.isDim(fb, Dim::kSparse);
bool isVector = isVectorFor(codegen, isInner, isReduction, isSparse) &&
denseUnitStrides(merger, op, idx);
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 574b4b977961..78ccb1fcf26d 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -455,14 +455,18 @@ void ReductionOp::getCanonicalizationPatterns(RewritePatternSet &results,
void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
Value lhs, Value rhs, Value acc,
ArrayRef<ArrayRef<AffineExpr>> indexingExprs,
- ArrayRef<StringRef> iteratorTypes) {
+ ArrayRef<IteratorType> iteratorTypes) {
result.addOperands({lhs, rhs, acc});
result.addTypes(acc.getType());
result.addAttribute(::mlir::getIndexingMapsAttrName(),
builder.getAffineMapArrayAttr(
AffineMap::inferFromExprList(indexingExprs)));
- result.addAttribute(::mlir::getIteratorTypesAttrName(),
- builder.getStrArrayAttr(iteratorTypes));
+ result.addAttribute(
+ ::mlir::getIteratorTypesAttrName(),
+ builder.getArrayAttr(llvm::to_vector(llvm::map_range(
+ iteratorTypes, [&](IteratorType t) -> mlir::Attribute {
+ return IteratorTypeAttr::get(builder.getContext(), t);
+ }))));
}
void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
@@ -510,6 +514,27 @@ ParseResult ContractionOp::parse(OpAsmParser &parser, OperationState &result) {
return failure();
result.attributes.assign(dictAttr.getValue().begin(),
dictAttr.getValue().end());
+
+ // Convert an array of strings into an array of IteratorType enums. This is needed,
+ // because tests still use the old format when 'iterator_types' attribute is
+ // represented as an array of strings.
+ // TODO: Remove this conversion once tests are fixed.
+ ArrayAttr iteratorTypes =
+ result.attributes.get("iterator_types").cast<ArrayAttr>();
+
+ SmallVector<Attribute> iteratorTypeAttrs;
+
+ for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
+ auto maybeIteratorType = symbolizeIteratorType(s);
+ if (!maybeIteratorType.hasValue())
+ return parser.emitError(loc) << "unexpected iterator_type (" << s << ")";
+
+ iteratorTypeAttrs.push_back(IteratorTypeAttr::get(
+ parser.getContext(), maybeIteratorType.getValue()));
+ }
+ result.attributes.set("iterator_types",
+ parser.getBuilder().getArrayAttr(iteratorTypeAttrs));
+
if (!result.attributes.get(ContractionOp::getKindAttrStrName())) {
result.addAttribute(
ContractionOp::getKindAttrStrName(),
@@ -538,9 +563,26 @@ void ContractionOp::print(OpAsmPrinter &p) {
llvm::StringSet<> traitAttrsSet;
traitAttrsSet.insert(attrNames.begin(), attrNames.end());
SmallVector<NamedAttribute, 8> attrs;
- for (auto attr : (*this)->getAttrs())
- if (traitAttrsSet.count(attr.getName().strref()) > 0)
+ for (auto attr : (*this)->getAttrs()) {
+ if (attr.getName() == getIteratorTypesAttrName()) {
+ auto iteratorTypes =
+ attr.getValue()
+ .cast<ArrayAttr>()
+ .getAsValueRange<IteratorTypeAttr, IteratorType>();
+ // Convert IteratorType enums into the string representation. This is
+ // needed, because tests still use the old format when 'iterator_types'
+ // attribute is represented as an array of strings.
+ // TODO: Remove this conversion once tests are fixed.
+ SmallVector<Attribute> iteratorTypeNames = llvm::to_vector(
+ llvm::map_range(iteratorTypes, [&](IteratorType t) -> Attribute {
+ return StringAttr::get(getContext(), stringifyIteratorType(t));
+ }));
+
+ attrs.emplace_back(getIteratorTypesAttrName(),
+ ArrayAttr::get(getContext(), iteratorTypeNames));
+ } else if (traitAttrsSet.count(attr.getName().strref()) > 0)
attrs.push_back(attr);
+ }
auto dictAttr = DictionaryAttr::get(getContext(), attrs);
p << " " << dictAttr << " " << getLhs() << ", ";
@@ -746,11 +788,11 @@ static int64_t getResultIndex(AffineMap map, AffineExpr targetExpr) {
static std::vector<std::pair<int64_t, int64_t>>
getDimMap(ArrayRef<AffineMap> indexingMaps, ArrayAttr iteratorTypes,
- StringRef targetIteratorTypeName, MLIRContext *context) {
+ IteratorType targetIteratorType, MLIRContext *context) {
std::vector<std::pair<int64_t, int64_t>> dimMap;
for (const auto &it : llvm::enumerate(iteratorTypes)) {
- auto iteratorTypeName = it.value().cast<StringAttr>().getValue();
- if (iteratorTypeName != targetIteratorTypeName)
+ auto iteratorType = it.value().cast<IteratorTypeAttr>().getValue();
+ if (iteratorType != targetIteratorType)
continue;
// Search lhs/rhs map results for 'targetExpr'.
auto targetExpr = getAffineDimExpr(it.index(), context);
@@ -771,8 +813,8 @@ void ContractionOp::getIterationBounds(
for (const auto &it : llvm::enumerate(getIteratorTypes())) {
// Search lhs/rhs map results for 'targetExpr'.
auto targetExpr = getAffineDimExpr(it.index(), getContext());
- auto iteratorTypeName = it.value().cast<StringAttr>().getValue();
- if (iteratorTypeName == getReductionIteratorTypeName()) {
+ auto iteratorType = it.value().cast<IteratorTypeAttr>().getValue();
+ if (iteratorType == IteratorType::reduction) {
// Get reduction dim size from lhs shape (same size in rhsShape).
int64_t lhsDimIndex = getResultIndex(indexingMaps[0], targetExpr);
assert(lhsDimIndex >= 0);
@@ -803,14 +845,14 @@ void ContractionOp::getIterationIndexMap(
std::vector<std::pair<int64_t, int64_t>> ContractionOp::getContractingDimMap() {
SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
- return getDimMap(indexingMaps, getIteratorTypes(),
- getReductionIteratorTypeName(), getContext());
+ return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::reduction,
+ getContext());
}
std::vector<std::pair<int64_t, int64_t>> ContractionOp::getBatchDimMap() {
SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
- return getDimMap(indexingMaps, getIteratorTypes(),
- getParallelIteratorTypeName(), getContext());
+ return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::parallel,
+ getContext());
}
Optional<SmallVector<int64_t, 4>> ContractionOp::getShapeForUnroll() {
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 3ff045031be1..2bd6756fd77a 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -22,6 +22,7 @@
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
@@ -986,13 +987,13 @@ struct MultiReduceToContract
SmallVector<bool> reductionMask = reduceOp.getReductionMask();
auto srcMap = rewriter.getMultiDimIdentityMap(reductionMask.size());
SmallVector<AffineExpr> exprs;
- SmallVector<StringRef> iteratorTypes;
+ SmallVector<vector::IteratorType> iteratorTypes;
for (const auto &isReduceDim : llvm::enumerate(reductionMask)) {
if (!isReduceDim.value()) {
- iteratorTypes.push_back(getParallelIteratorTypeName());
+ iteratorTypes.push_back(vector::IteratorType::parallel);
exprs.push_back(rewriter.getAffineDimExpr(isReduceDim.index()));
} else {
- iteratorTypes.push_back(getReductionIteratorTypeName());
+ iteratorTypes.push_back(vector::IteratorType::reduction);
}
}
auto dstMap = AffineMap::get(/*dimCount=*/reductionMask.size(),
@@ -1000,7 +1001,10 @@ struct MultiReduceToContract
rewriter.replaceOpWithNewOp<mlir::vector::ContractionOp>(
reduceOp, mulOp->getOperand(0), mulOp->getOperand(1), reduceOp.getAcc(),
rewriter.getAffineMapArrayAttr({srcMap, srcMap, dstMap}),
- rewriter.getStrArrayAttr(iteratorTypes));
+ rewriter.getArrayAttr(llvm::to_vector(llvm::map_range(
+ iteratorTypes, [&](IteratorType t) -> mlir::Attribute {
+ return IteratorTypeAttr::get(rewriter.getContext(), t);
+ }))));
return success();
}
};