[Mlir-commits] [mlir] d558540 - [mlir][Vector] Add return type inference for multi_reduction

Benjamin Kramer llvmlistbot at llvm.org
Fri Feb 18 04:00:59 PST 2022


Author: Benjamin Kramer
Date: 2022-02-18T13:00:42+01:00
New Revision: d558540fae376361fbbf9554828c7488bc1c341d

URL: https://github.com/llvm/llvm-project/commit/d558540fae376361fbbf9554828c7488bc1c341d
DIFF: https://github.com/llvm/llvm-project/commit/d558540fae376361fbbf9554828c7488bc1c341d.diff

LOG: [mlir][Vector] Add return type inference for multi_reduction

This subsumes the builder and verifier.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/test/Dialect/Vector/invalid.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 4a20ea0dc4d1..aec8dd5b6882 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -313,7 +313,8 @@ def Vector_ReductionOp :
 def Vector_MultiDimReductionOp :
   Vector_Op<"multi_reduction", [NoSideEffect,
      PredOpTrait<"source operand and result have same element type",
-                 TCresVTEtIsSameAsOpBase<0, 0>>]>,
+                 TCresVTEtIsSameAsOpBase<0, 0>>,
+     DeclareOpInterfaceMethods<InferTypeOpInterface>]>,
     Arguments<(ins Vector_CombiningKindAttr:$kind,
                    AnyVector:$source,
                    I64ArrayAttr:$reduction_dims)>,
@@ -367,31 +368,10 @@ def Vector_MultiDimReductionOp :
         res[idx] = true;
       return res;
     }
-
-    static SmallVector<int64_t> inferDestShape(
-      ArrayRef<int64_t> sourceShape, ArrayRef<bool> reducedDimsMask) {
-      assert(sourceShape.size() == reducedDimsMask.size() &&
-             "sourceShape and maks of different sizes");
-      SmallVector<int64_t> res;
-      for (auto it : llvm::zip(reducedDimsMask, sourceShape))
-        if (!std::get<0>(it))
-          res.push_back(std::get<1>(it));
-      return res;
-    }
-
-    static Type inferDestType(
-      ArrayRef<int64_t> sourceShape, ArrayRef<bool> reducedDimsMask, Type elementType) {
-      auto targetShape = inferDestShape(sourceShape, reducedDimsMask);
-      // TODO: update to also allow 0-d vectors when available.
-      if (targetShape.empty())
-        return elementType;
-      return VectorType::get(targetShape, elementType);
-    }
   }];
   let assemblyFormat =
     "$kind `,` $source attr-dict $reduction_dims `:` type($source) `to` type($dest)";
   let hasFolder = 1;
-  let hasVerifier = 1;
 }
 
 def Vector_BroadcastOp :

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 4ffb2b8c7569..ddfe0d844228 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -336,32 +336,31 @@ void vector::MultiDimReductionOp::build(OpBuilder &builder,
                                         OperationState &result, Value source,
                                         ArrayRef<bool> reductionMask,
                                         CombiningKind kind) {
-  result.addOperands(source);
-  auto sourceVectorType = source.getType().cast<VectorType>();
-  auto targetType = MultiDimReductionOp::inferDestType(
-      sourceVectorType.getShape(), reductionMask,
-      sourceVectorType.getElementType());
-  result.addTypes(targetType);
-
   SmallVector<int64_t> reductionDims;
   for (const auto &en : llvm::enumerate(reductionMask))
     if (en.value())
       reductionDims.push_back(en.index());
-  result.addAttribute(getReductionDimsAttrStrName(),
-                      builder.getI64ArrayAttr(reductionDims));
-  result.addAttribute(getKindAttrStrName(),
-                      CombiningKindAttr::get(kind, builder.getContext()));
-}
-
-LogicalResult MultiDimReductionOp::verify() {
-  auto reductionMask = getReductionMask();
-  auto targetType = MultiDimReductionOp::inferDestType(
-      getSourceVectorType().getShape(), reductionMask,
-      getSourceVectorType().getElementType());
-  // TODO: update to support 0-d vectors when available.
-  if (targetType != getDestType())
-    return emitError("invalid output vector type: ")
-           << getDestType() << " (expected: " << targetType << ")";
+  build(builder, result, kind, source, builder.getI64ArrayAttr(reductionDims));
+}
+
+LogicalResult MultiDimReductionOp::inferReturnTypes(
+    MLIRContext *, Optional<Location>, ValueRange operands,
+    DictionaryAttr attributes, RegionRange,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  MultiDimReductionOp::Adaptor op(operands, attributes);
+  auto vectorType = op.source().getType().cast<VectorType>();
+  SmallVector<int64_t> targetShape;
+  for (auto it : llvm::enumerate(vectorType.getShape()))
+    if (!llvm::any_of(op.reduction_dims().getValue(), [&](Attribute attr) {
+          return attr.cast<IntegerAttr>().getValue() == it.index();
+        }))
+      targetShape.push_back(it.value());
+  // TODO: update to also allow 0-d vectors when available.
+  if (targetShape.empty())
+    inferredReturnTypes.push_back(vectorType.getElementType());
+  else
+    inferredReturnTypes.push_back(
+        VectorType::get(targetShape, vectorType.getElementType()));
   return success();
 }
 

diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 2e224f7f58eb..c90725e5d8d7 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1137,6 +1137,13 @@ func @reduce_unsupported_rank(%arg0: vector<4x16xf32>) -> f32 {
 
 // -----
 
+func @multi_reduce_invalid_type(%arg0: vector<4x16xf32>) -> f32 {
+  // expected-error@+1 {{'vector.multi_reduction' op inferred type(s) 'vector<4xf32>' are incompatible with return type(s) of operation 'vector<16xf32>'}}
+  %0 = vector.multi_reduction <mul>, %arg0 [1] : vector<4x16xf32> to vector<16xf32>
+}
+
+// -----
+
 func @transpose_rank_mismatch(%arg0: vector<4x16x11xf32>) {
   // expected-error@+1 {{'vector.transpose' op vector result rank mismatch: 1}}
   %0 = vector.transpose %arg0, [2, 1, 0] : vector<4x16x11xf32> to vector<100xf32>


        


More information about the Mlir-commits mailing list