[Mlir-commits] [mlir] d558540 - [mlir][Vector] Add return type inference for multi_reduction
Benjamin Kramer
llvmlistbot at llvm.org
Fri Feb 18 04:00:59 PST 2022
Author: Benjamin Kramer
Date: 2022-02-18T13:00:42+01:00
New Revision: d558540fae376361fbbf9554828c7488bc1c341d
URL: https://github.com/llvm/llvm-project/commit/d558540fae376361fbbf9554828c7488bc1c341d
DIFF: https://github.com/llvm/llvm-project/commit/d558540fae376361fbbf9554828c7488bc1c341d.diff
LOG: [mlir][Vector] Add return type inference for multi_reduction
This subsumes the builder and verifier.
Added:
Modified:
mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/test/Dialect/Vector/invalid.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 4a20ea0dc4d1..aec8dd5b6882 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -313,7 +313,8 @@ def Vector_ReductionOp :
def Vector_MultiDimReductionOp :
Vector_Op<"multi_reduction", [NoSideEffect,
PredOpTrait<"source operand and result have same element type",
- TCresVTEtIsSameAsOpBase<0, 0>>]>,
+ TCresVTEtIsSameAsOpBase<0, 0>>,
+ DeclareOpInterfaceMethods<InferTypeOpInterface>]>,
Arguments<(ins Vector_CombiningKindAttr:$kind,
AnyVector:$source,
I64ArrayAttr:$reduction_dims)>,
@@ -367,31 +368,10 @@ def Vector_MultiDimReductionOp :
res[idx] = true;
return res;
}
-
- static SmallVector<int64_t> inferDestShape(
- ArrayRef<int64_t> sourceShape, ArrayRef<bool> reducedDimsMask) {
- assert(sourceShape.size() == reducedDimsMask.size() &&
- "sourceShape and maks of different sizes");
- SmallVector<int64_t> res;
- for (auto it : llvm::zip(reducedDimsMask, sourceShape))
- if (!std::get<0>(it))
- res.push_back(std::get<1>(it));
- return res;
- }
-
- static Type inferDestType(
- ArrayRef<int64_t> sourceShape, ArrayRef<bool> reducedDimsMask, Type elementType) {
- auto targetShape = inferDestShape(sourceShape, reducedDimsMask);
- // TODO: update to also allow 0-d vectors when available.
- if (targetShape.empty())
- return elementType;
- return VectorType::get(targetShape, elementType);
- }
}];
let assemblyFormat =
"$kind `,` $source attr-dict $reduction_dims `:` type($source) `to` type($dest)";
let hasFolder = 1;
- let hasVerifier = 1;
}
def Vector_BroadcastOp :
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 4ffb2b8c7569..ddfe0d844228 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -336,32 +336,31 @@ void vector::MultiDimReductionOp::build(OpBuilder &builder,
OperationState &result, Value source,
ArrayRef<bool> reductionMask,
CombiningKind kind) {
- result.addOperands(source);
- auto sourceVectorType = source.getType().cast<VectorType>();
- auto targetType = MultiDimReductionOp::inferDestType(
- sourceVectorType.getShape(), reductionMask,
- sourceVectorType.getElementType());
- result.addTypes(targetType);
-
SmallVector<int64_t> reductionDims;
for (const auto &en : llvm::enumerate(reductionMask))
if (en.value())
reductionDims.push_back(en.index());
- result.addAttribute(getReductionDimsAttrStrName(),
- builder.getI64ArrayAttr(reductionDims));
- result.addAttribute(getKindAttrStrName(),
- CombiningKindAttr::get(kind, builder.getContext()));
-}
-
-LogicalResult MultiDimReductionOp::verify() {
- auto reductionMask = getReductionMask();
- auto targetType = MultiDimReductionOp::inferDestType(
- getSourceVectorType().getShape(), reductionMask,
- getSourceVectorType().getElementType());
- // TODO: update to support 0-d vectors when available.
- if (targetType != getDestType())
- return emitError("invalid output vector type: ")
- << getDestType() << " (expected: " << targetType << ")";
+ build(builder, result, kind, source, builder.getI64ArrayAttr(reductionDims));
+}
+
+LogicalResult MultiDimReductionOp::inferReturnTypes(
+ MLIRContext *, Optional<Location>, ValueRange operands,
+ DictionaryAttr attributes, RegionRange,
+ SmallVectorImpl<Type> &inferredReturnTypes) {
+ MultiDimReductionOp::Adaptor op(operands, attributes);
+ auto vectorType = op.source().getType().cast<VectorType>();
+ SmallVector<int64_t> targetShape;
+ for (auto it : llvm::enumerate(vectorType.getShape()))
+ if (!llvm::any_of(op.reduction_dims().getValue(), [&](Attribute attr) {
+ return attr.cast<IntegerAttr>().getValue() == it.index();
+ }))
+ targetShape.push_back(it.value());
+ // TODO: update to also allow 0-d vectors when available.
+ if (targetShape.empty())
+ inferredReturnTypes.push_back(vectorType.getElementType());
+ else
+ inferredReturnTypes.push_back(
+ VectorType::get(targetShape, vectorType.getElementType()));
return success();
}
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 2e224f7f58eb..c90725e5d8d7 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1137,6 +1137,13 @@ func @reduce_unsupported_rank(%arg0: vector<4x16xf32>) -> f32 {
// -----
+func @multi_reduce_invalid_type(%arg0: vector<4x16xf32>) -> f32 {
+ // expected-error@+1 {{'vector.multi_reduction' op inferred type(s) 'vector<4xf32>' are incompatible with return type(s) of operation 'vector<16xf32>'}}
+ %0 = vector.multi_reduction <mul>, %arg0 [1] : vector<4x16xf32> to vector<16xf32>
+}
+
+// -----
+
func @transpose_rank_mismatch(%arg0: vector<4x16x11xf32>) {
// expected-error@+1 {{'vector.transpose' op vector result rank mismatch: 1}}
%0 = vector.transpose %arg0, [2, 1, 0] : vector<4x16x11xf32> to vector<100xf32>
More information about the Mlir-commits
mailing list