[Mlir-commits] [mlir] fa596c6 - [mlir][Vector] Fix reordering of floating point adds during lowering of `vector.contract`.

Mahesh Ravishankar llvmlistbot at llvm.org
Mon Jun 27 22:27:13 PDT 2022


Author: Mahesh Ravishankar
Date: 2022-06-28T05:26:39Z
New Revision: fa596c6921159af50e69cc3be189d951521a9eb9

URL: https://github.com/llvm/llvm-project/commit/fa596c6921159af50e69cc3be189d951521a9eb9
DIFF: https://github.com/llvm/llvm-project/commit/fa596c6921159af50e69cc3be189d951521a9eb9.diff

LOG: [mlir][Vector] Fix reordering of floating point adds during lowering of `vector.contract`.

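Adding the accumulator value after the `vector.reduction` produced when
lowering `vector.contract` reorders the floating-point additions and thus
changes the precision of the operation. This change instead passes the
accumulator as the optional fused accumulator operand of `vector.reduction`,
so it is carried through the reduction (and down to LLVM).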

Differential Revision: https://reviews.llvm.org/D128674

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
    mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
    mlir/test/Dialect/Vector/invalid.mlir
    mlir/test/Dialect/Vector/vector-contract-transforms.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 152e5387c2e0a..57c02c9a35ba3 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -91,7 +91,7 @@ def Vector_ContractionOp :
     Arguments<(ins AnyVector:$lhs, AnyVector:$rhs, AnyType:$acc,
                Variadic<VectorOf<[I1]>>:$masks,
                Vector_AffineMapArrayAttr:$indexing_maps,
-	       ArrayAttr:$iterator_types,
+               ArrayAttr:$iterator_types,
                DefaultValuedAttr<Vector_CombiningKindAttr,
                                  "CombiningKind::ADD">:$kind)>,
     Results<(outs AnyType)> {
@@ -280,8 +280,7 @@ def Vector_ReductionOp :
   let description = [{
     Reduces an 1-D vector "horizontally" into a scalar using the given
     operation (add/mul/min/max for int/fp and and/or/xor for int only).
-    Some reductions (add/mul for fp) also allow an optional fused
-    accumulator.
+    Reductions also allow an optional fused accumulator.
 
     Note that these operations are restricted to 1-D vectors to remain
     close to the corresponding LLVM intrinsics:
@@ -1760,7 +1759,7 @@ def Vector_GatherOp :
   Vector_Op<"gather">,
     Arguments<(ins Arg<AnyMemRef, "", [MemRead]>:$base,
                Variadic<Index>:$indices,
-	       VectorOfRankAndType<[1], [AnyInteger, Index]>:$index_vec,
+               VectorOfRankAndType<[1], [AnyInteger, Index]>:$index_vec,
                VectorOfRankAndType<[1], [I1]>:$mask,
                VectorOfRank<[1]>:$pass_thru)>,
     Results<(outs VectorOfRank<[1]>:$result)> {
@@ -1826,7 +1825,7 @@ def Vector_ScatterOp :
   Vector_Op<"scatter">,
     Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
                Variadic<Index>:$indices,
-	       VectorOfRankAndType<[1], [AnyInteger, Index]>:$index_vec,
+               VectorOfRankAndType<[1], [AnyInteger, Index]>:$index_vec,
                VectorOfRankAndType<[1], [I1]>:$mask,
                VectorOfRank<[1]>:$valueToStore)> {
 

diff  --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index a164c7d167dc6..fa4920486aad8 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -362,6 +362,37 @@ class VectorCompressStoreOpConversion
   }
 };
 
+/// Helper to lower an integer `vector.reduction` whose kind maps to a binary
+/// arithmetic/bitwise op (add, mul, and, or, xor). `VectorOp` is the LLVM
+/// vector intrinsic to use and `ScalarOp` is the scalar op that combines the
+/// accumulator (as LHS) with the reduced value when `accumulator` is non-null.
+template <class VectorOp, class ScalarOp>
+static Value createIntegerReductionArithmeticOpLowering(
+    ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
+    Value vectorOperand, Value accumulator) {
+  Value result = rewriter.create<VectorOp>(loc, llvmType, vectorOperand);
+  if (accumulator)
+    result = rewriter.create<ScalarOp>(loc, accumulator, result);
+  return result;
+}
+
+/// Helper to lower an integer min/max `vector.reduction`. `VectorOp` is the
+/// LLVM vector intrinsic to use. If `accumulator` is non-null, it is combined
+/// with the reduced value via `select(icmp predicate(acc, red), acc, red)`, so
+/// `predicate` must be the min/max ordering matching `VectorOp` (e.g. ule).
+template <class VectorOp>
+static Value createIntegerReductionComparisonOpLowering(
+    ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
+    Value vectorOperand, Value accumulator, LLVM::ICmpPredicate predicate) {
+  Value result = rewriter.create<VectorOp>(loc, llvmType, vectorOperand);
+  if (accumulator) {
+    Value cmp =
+        rewriter.create<LLVM::ICmpOp>(loc, predicate, accumulator, result);
+    result = rewriter.create<LLVM::SelectOp>(loc, cmp, accumulator, result);
+  }
+  return result;
+}
+
 /// Conversion pattern for all vector reductions.
 class VectorReductionOpConversion
     : public ConvertOpToLLVMPattern<vector::ReductionOp> {
@@ -377,38 +408,68 @@ class VectorReductionOpConversion
     auto kind = reductionOp.getKind();
     Type eltType = reductionOp.getDest().getType();
     Type llvmType = typeConverter->convertType(eltType);
-    Value operand = adaptor.getOperands()[0];
+    Value operand = adaptor.getVector();
+    Value acc = adaptor.getAcc();
+    Location loc = reductionOp.getLoc();
     if (eltType.isIntOrIndex()) {
       // Integer reductions: add/mul/min/max/and/or/xor.
-      if (kind == vector::CombiningKind::ADD)
-        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_add>(reductionOp,
-                                                             llvmType, operand);
-      else if (kind == vector::CombiningKind::MUL)
-        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_mul>(reductionOp,
-                                                             llvmType, operand);
-      else if (kind == vector::CombiningKind::MINUI)
-        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umin>(
-            reductionOp, llvmType, operand);
-      else if (kind == vector::CombiningKind::MINSI)
-        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smin>(
-            reductionOp, llvmType, operand);
-      else if (kind == vector::CombiningKind::MAXUI)
-        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umax>(
-            reductionOp, llvmType, operand);
-      else if (kind == vector::CombiningKind::MAXSI)
-        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smax>(
-            reductionOp, llvmType, operand);
-      else if (kind == vector::CombiningKind::AND)
-        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_and>(reductionOp,
-                                                             llvmType, operand);
-      else if (kind == vector::CombiningKind::OR)
-        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_or>(reductionOp,
-                                                            llvmType, operand);
-      else if (kind == vector::CombiningKind::XOR)
-        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_xor>(reductionOp,
-                                                             llvmType, operand);
-      else
+      Value result;
+      switch (kind) {
+      case vector::CombiningKind::ADD:
+        result =
+            createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_add,
+                                                       LLVM::AddOp>(
+                rewriter, loc, llvmType, operand, acc);
+        break;
+      case vector::CombiningKind::MUL:
+        result =
+            createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_mul,
+                                                       LLVM::MulOp>(
+                rewriter, loc, llvmType, operand, acc);
+        break;
+      case vector::CombiningKind::MINUI:
+        result = createIntegerReductionComparisonOpLowering<
+            LLVM::vector_reduce_umin>(rewriter, loc, llvmType, operand, acc,
+                                      LLVM::ICmpPredicate::ule);
+        break;
+      case vector::CombiningKind::MINSI:
+        result = createIntegerReductionComparisonOpLowering<
+            LLVM::vector_reduce_smin>(rewriter, loc, llvmType, operand, acc,
+                                      LLVM::ICmpPredicate::sle);
+        break;
+      case vector::CombiningKind::MAXUI:
+        result = createIntegerReductionComparisonOpLowering<
+            LLVM::vector_reduce_umax>(rewriter, loc, llvmType, operand, acc,
+                                      LLVM::ICmpPredicate::uge);
+        break;
+      case vector::CombiningKind::MAXSI:
+        result = createIntegerReductionComparisonOpLowering<
+            LLVM::vector_reduce_smax>(rewriter, loc, llvmType, operand, acc,
+                                      LLVM::ICmpPredicate::sge);
+        break;
+      case vector::CombiningKind::AND:
+        result =
+            createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_and,
+                                                       LLVM::AndOp>(
+                rewriter, loc, llvmType, operand, acc);
+        break;
+      case vector::CombiningKind::OR:
+        result =
+            createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_or,
+                                                       LLVM::OrOp>(
+                rewriter, loc, llvmType, operand, acc);
+        break;
+      case vector::CombiningKind::XOR:
+        result =
+            createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_xor,
+                                                       LLVM::XOrOp>(
+                rewriter, loc, llvmType, operand, acc);
+        break;
+      default:
         return failure();
+      }
+      rewriter.replaceOp(reductionOp, result);
+
       return success();
     }
 

diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index ebf36627b6308..8332c0b8b260b 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -403,15 +403,6 @@ LogicalResult ReductionOp::verify() {
            << eltType << "' for kind '" << stringifyCombiningKind(getKind())
            << "'";
 
-  // Verify optional accumulator.
-  if (getAcc()) {
-    if (getKind() != CombiningKind::ADD && getKind() != CombiningKind::MUL)
-      return emitOpError("no accumulator for reduction kind: ")
-             << stringifyCombiningKind(getKind());
-    if (!eltType.isa<FloatType>())
-      return emitOpError("no accumulator for type: ") << eltType;
-  }
-
   return success();
 }
 
@@ -1969,7 +1960,7 @@ LogicalResult InsertOp::verify() {
       (static_cast<unsigned>(srcVectorType.getRank()) + positionAttr.size() !=
        static_cast<unsigned>(destVectorType.getRank())))
     return emitOpError("expected position attribute rank + source rank to "
-                          "match dest vector rank");
+                       "match dest vector rank");
   if (!srcVectorType &&
       (positionAttr.size() != static_cast<unsigned>(destVectorType.getRank())))
     return emitOpError(
@@ -2302,8 +2293,7 @@ LogicalResult ReshapeOp::verify() {
   int64_t numFixedVectorSizes = fixedVectorSizes.size();
 
   if (inputVectorType.getRank() != inputShapeRank + numFixedVectorSizes)
-    return emitError("invalid input shape for vector type ")
-           << inputVectorType;
+    return emitError("invalid input shape for vector type ") << inputVectorType;
 
   if (outputVectorType.getRank() != outputShapeRank + numFixedVectorSizes)
     return emitError("invalid output shape for vector type ")
@@ -2396,24 +2386,29 @@ LogicalResult ExtractStridedSliceOp::verify() {
   auto sizes = getSizesAttr();
   auto strides = getStridesAttr();
   if (offsets.size() != sizes.size() || offsets.size() != strides.size())
-    return emitOpError("expected offsets, sizes and strides attributes of same size");
+    return emitOpError(
+        "expected offsets, sizes and strides attributes of same size");
 
   auto shape = type.getShape();
   auto offName = getOffsetsAttrName();
   auto sizesName = getSizesAttrName();
   auto stridesName = getStridesAttrName();
-  if (failed(isIntegerArrayAttrSmallerThanShape(*this, offsets, shape, offName)) ||
-      failed(isIntegerArrayAttrSmallerThanShape(*this, sizes, shape, sizesName)) ||
+  if (failed(
+          isIntegerArrayAttrSmallerThanShape(*this, offsets, shape, offName)) ||
+      failed(
+          isIntegerArrayAttrSmallerThanShape(*this, sizes, shape, sizesName)) ||
       failed(isIntegerArrayAttrSmallerThanShape(*this, strides, shape,
                                                 stridesName)) ||
-      failed(isIntegerArrayAttrConfinedToShape(*this, offsets, shape, offName)) ||
+      failed(
+          isIntegerArrayAttrConfinedToShape(*this, offsets, shape, offName)) ||
       failed(isIntegerArrayAttrConfinedToShape(*this, sizes, shape, sizesName,
                                                /*halfOpen=*/false,
                                                /*min=*/1)) ||
-      failed(isIntegerArrayAttrConfinedToRange(*this, strides, 1, 1, stridesName,
+      failed(isIntegerArrayAttrConfinedToRange(*this, strides, 1, 1,
+                                               stridesName,
                                                /*halfOpen=*/false)) ||
-      failed(isSumOfIntegerArrayAttrConfinedToShape(*this, offsets, sizes, shape,
-                                                    offName, sizesName,
+      failed(isSumOfIntegerArrayAttrConfinedToShape(*this, offsets, sizes,
+                                                    shape, offName, sizesName,
                                                     /*halfOpen=*/false)))
     return failure();
 
@@ -4223,7 +4218,7 @@ LogicalResult BitCastOp::verify() {
   if (sourceVectorType.getRank() == 0) {
     if (sourceElementBits != resultElementBits)
       return emitOpError("source/result bitwidth of the 0-D vector element "
-                            "types must be equal");
+                         "types must be equal");
   } else if (sourceElementBits * sourceVectorType.getShape().back() !=
              resultElementBits * resultVectorType.getShape().back()) {
     return emitOpError(

diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 67635d69ddeaf..338ffd053486a 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1875,10 +1875,9 @@ Value ContractionOpLowering::lowerReduction(vector::ContractionOp op,
     assert(rhsType.getRank() == 1 && "corrupt contraction");
     Value m = createMul(loc, op.getLhs(), op.getRhs(), isInt, rewriter);
     auto kind = vector::CombiningKind::ADD;
-    Value res = rewriter.create<vector::ReductionOp>(loc, kind, m);
     if (auto acc = op.getAcc())
-      res = createAdd(op.getLoc(), res, acc, isInt, rewriter);
-    return res;
+      return rewriter.create<vector::ReductionOp>(loc, kind, m, acc);
+    return rewriter.create<vector::ReductionOp>(loc, kind, m);
   }
   // Construct new iterator types and affine map array attribute.
   std::array<AffineMap, 3> lowIndexingMaps = {

diff  --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index adad1ea016ad2..0d8406c793d67 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1178,6 +1178,206 @@ func.func @reduce_i32(%arg0: vector<16xi32>) -> i32 {
 
 // -----
 
+func.func @reduce_acc_i32(%arg0: vector<16xi32>, %arg1 : i32) -> i32 {
+  %0 = vector.reduction <add>, %arg0, %arg1 : vector<16xi32> into i32
+  return %0 : i32
+}
+// CHECK-LABEL: @reduce_acc_i32(
+//  CHECK-SAME: %[[A:.*]]: vector<16xi32>, %[[ACC:.*]]: i32)
+//       CHECK: %[[R:.*]] = "llvm.intr.vector.reduce.add"(%[[A]])
+//       CHECK: %[[V:.*]] = llvm.add %[[ACC]], %[[R]]
+//       CHECK: return %[[V]] : i32
+
+// -----
+
+func.func @reduce_mul_i32(%arg0: vector<16xi32>) -> i32 {
+  %0 = vector.reduction <mul>, %arg0 : vector<16xi32> into i32
+  return %0 : i32
+}
+// CHECK-LABEL: @reduce_mul_i32(
+//  CHECK-SAME: %[[A:.*]]: vector<16xi32>)
+//       CHECK: %[[V:.*]] = "llvm.intr.vector.reduce.mul"(%[[A]])
+//       CHECK: return %[[V]] : i32
+
+// -----
+
+func.func @reduce_mul_acc_i32(%arg0: vector<16xi32>, %arg1 : i32) -> i32 {
+  %0 = vector.reduction <mul>, %arg0, %arg1 : vector<16xi32> into i32
+  return %0 : i32
+}
+// CHECK-LABEL: @reduce_mul_acc_i32(
+//  CHECK-SAME: %[[A:.*]]: vector<16xi32>, %[[ACC:.*]]: i32)
+//       CHECK: %[[R:.*]] = "llvm.intr.vector.reduce.mul"(%[[A]])
+//       CHECK: %[[V:.*]] = llvm.mul %[[ACC]], %[[R]]
+//       CHECK: return %[[V]] : i32
+
+// -----
+
+func.func @reduce_minui_i32(%arg0: vector<16xi32>) -> i32 {
+  %0 = vector.reduction <minui>, %arg0 : vector<16xi32> into i32
+  return %0 : i32
+}
+// CHECK-LABEL: @reduce_minui_i32(
+//  CHECK-SAME: %[[A:.*]]: vector<16xi32>)
+//       CHECK: %[[V:.*]] = "llvm.intr.vector.reduce.umin"(%[[A]])
+//       CHECK: return %[[V]] : i32
+
+// -----
+
+func.func @reduce_minui_acc_i32(%arg0: vector<16xi32>, %arg1 : i32) -> i32 {
+  %0 = vector.reduction <minui>, %arg0, %arg1 : vector<16xi32> into i32
+  return %0 : i32
+}
+// CHECK-LABEL: @reduce_minui_acc_i32(
+//  CHECK-SAME: %[[A:.*]]: vector<16xi32>, %[[ACC:.*]]: i32)
+//       CHECK: %[[R:.*]] = "llvm.intr.vector.reduce.umin"(%[[A]])
+//       CHECK: %[[S:.*]] = llvm.icmp "ule" %[[ACC]], %[[R]]
+//       CHECK: %[[V:.*]] = llvm.select %[[S]], %[[ACC]], %[[R]]
+//       CHECK: return %[[V]] : i32
+
+// -----
+
+func.func @reduce_maxui_i32(%arg0: vector<16xi32>) -> i32 {
+  %0 = vector.reduction <maxui>, %arg0 : vector<16xi32> into i32
+  return %0 : i32
+}
+// CHECK-LABEL: @reduce_maxui_i32(
+//  CHECK-SAME: %[[A:.*]]: vector<16xi32>)
+//       CHECK: %[[V:.*]] = "llvm.intr.vector.reduce.umax"(%[[A]])
+//       CHECK: return %[[V]] : i32
+
+// -----
+
+func.func @reduce_maxui_acc_i32(%arg0: vector<16xi32>, %arg1 : i32) -> i32 {
+  %0 = vector.reduction <maxui>, %arg0, %arg1 : vector<16xi32> into i32
+  return %0 : i32
+}
+// CHECK-LABEL: @reduce_maxui_acc_i32(
+//  CHECK-SAME: %[[A:.*]]: vector<16xi32>, %[[ACC:.*]]: i32)
+//       CHECK: %[[R:.*]] = "llvm.intr.vector.reduce.umax"(%[[A]])
+//       CHECK: %[[S:.*]] = llvm.icmp "uge" %[[ACC]], %[[R]]
+//       CHECK: %[[V:.*]] = llvm.select %[[S]], %[[ACC]], %[[R]]
+//       CHECK: return %[[V]] : i32
+
+// -----
+
+func.func @reduce_minsi_i32(%arg0: vector<16xi32>) -> i32 {
+  %0 = vector.reduction <minsi>, %arg0 : vector<16xi32> into i32
+  return %0 : i32
+}
+// CHECK-LABEL: @reduce_minsi_i32(
+//  CHECK-SAME: %[[A:.*]]: vector<16xi32>)
+//       CHECK: %[[V:.*]] = "llvm.intr.vector.reduce.smin"(%[[A]])
+//       CHECK: return %[[V]] : i32
+
+// -----
+
+func.func @reduce_minsi_acc_i32(%arg0: vector<16xi32>, %arg1 : i32) -> i32 {
+  %0 = vector.reduction <minsi>, %arg0, %arg1 : vector<16xi32> into i32
+  return %0 : i32
+}
+// CHECK-LABEL: @reduce_minsi_acc_i32(
+//  CHECK-SAME: %[[A:.*]]: vector<16xi32>, %[[ACC:.*]]: i32)
+//       CHECK: %[[R:.*]] = "llvm.intr.vector.reduce.smin"(%[[A]])
+//       CHECK: %[[S:.*]] = llvm.icmp "sle" %[[ACC]], %[[R]]
+//       CHECK: %[[V:.*]] = llvm.select %[[S]], %[[ACC]], %[[R]]
+//       CHECK: return %[[V]] : i32
+
+// -----
+
+func.func @reduce_maxsi_i32(%arg0: vector<16xi32>) -> i32 {
+  %0 = vector.reduction <maxsi>, %arg0 : vector<16xi32> into i32
+  return %0 : i32
+}
+// CHECK-LABEL: @reduce_maxsi_i32(
+//  CHECK-SAME: %[[A:.*]]: vector<16xi32>)
+//       CHECK: %[[V:.*]] = "llvm.intr.vector.reduce.smax"(%[[A]])
+//       CHECK: return %[[V]] : i32
+
+// -----
+
+func.func @reduce_maxsi_acc_i32(%arg0: vector<16xi32>, %arg1 : i32) -> i32 {
+  %0 = vector.reduction <maxsi>, %arg0, %arg1 : vector<16xi32> into i32
+  return %0 : i32
+}
+// CHECK-LABEL: @reduce_maxsi_acc_i32(
+//  CHECK-SAME: %[[A:.*]]: vector<16xi32>, %[[ACC:.*]]: i32)
+//       CHECK: %[[R:.*]] = "llvm.intr.vector.reduce.smax"(%[[A]])
+//       CHECK: %[[S:.*]] = llvm.icmp "sge" %[[ACC]], %[[R]]
+//       CHECK: %[[V:.*]] = llvm.select %[[S]], %[[ACC]], %[[R]]
+//       CHECK: return %[[V]] : i32
+
+// -----
+
+func.func @reduce_and_i32(%arg0: vector<16xi32>) -> i32 {
+  %0 = vector.reduction <and>, %arg0 : vector<16xi32> into i32
+  return %0 : i32
+}
+// CHECK-LABEL: @reduce_and_i32(
+//  CHECK-SAME: %[[A:.*]]: vector<16xi32>)
+//       CHECK: %[[V:.*]] = "llvm.intr.vector.reduce.and"(%[[A]])
+//       CHECK: return %[[V]] : i32
+
+// -----
+
+func.func @reduce_and_acc_i32(%arg0: vector<16xi32>, %arg1 : i32) -> i32 {
+  %0 = vector.reduction <and>, %arg0, %arg1 : vector<16xi32> into i32
+  return %0 : i32
+}
+// CHECK-LABEL: @reduce_and_acc_i32(
+//  CHECK-SAME: %[[A:.*]]: vector<16xi32>, %[[ACC:.*]]: i32)
+//       CHECK: %[[R:.*]] = "llvm.intr.vector.reduce.and"(%[[A]])
+//       CHECK: %[[V:.*]] = llvm.and %[[ACC]], %[[R]]
+//       CHECK: return %[[V]] : i32
+
+// -----
+
+func.func @reduce_or_i32(%arg0: vector<16xi32>) -> i32 {
+  %0 = vector.reduction <or>, %arg0 : vector<16xi32> into i32
+  return %0 : i32
+}
+// CHECK-LABEL: @reduce_or_i32(
+//  CHECK-SAME: %[[A:.*]]: vector<16xi32>)
+//       CHECK: %[[V:.*]] = "llvm.intr.vector.reduce.or"(%[[A]])
+//       CHECK: return %[[V]] : i32
+
+// -----
+
+func.func @reduce_or_acc_i32(%arg0: vector<16xi32>, %arg1 : i32) -> i32 {
+  %0 = vector.reduction <or>, %arg0, %arg1 : vector<16xi32> into i32
+  return %0 : i32
+}
+// CHECK-LABEL: @reduce_or_acc_i32(
+//  CHECK-SAME: %[[A:.*]]: vector<16xi32>, %[[ACC:.*]]: i32)
+//       CHECK: %[[R:.*]] = "llvm.intr.vector.reduce.or"(%[[A]])
+//       CHECK: %[[V:.*]] = llvm.or %[[ACC]], %[[R]]
+//       CHECK: return %[[V]] : i32
+
+// -----
+
+func.func @reduce_xor_i32(%arg0: vector<16xi32>) -> i32 {
+  %0 = vector.reduction <xor>, %arg0 : vector<16xi32> into i32
+  return %0 : i32
+}
+// CHECK-LABEL: @reduce_xor_i32(
+//  CHECK-SAME: %[[A:.*]]: vector<16xi32>)
+//       CHECK: %[[V:.*]] = "llvm.intr.vector.reduce.xor"(%[[A]])
+//       CHECK: return %[[V]] : i32
+
+// -----
+
+func.func @reduce_xor_acc_i32(%arg0: vector<16xi32>, %arg1 : i32) -> i32 {
+  %0 = vector.reduction <xor>, %arg0, %arg1 : vector<16xi32> into i32
+  return %0 : i32
+}
+// CHECK-LABEL: @reduce_xor_acc_i32(
+//  CHECK-SAME: %[[A:.*]]: vector<16xi32>, %[[ACC:.*]]: i32)
+//       CHECK: %[[R:.*]] = "llvm.intr.vector.reduce.xor"(%[[A]])
+//       CHECK: %[[V:.*]] = llvm.xor %[[ACC]], %[[R]]
+//       CHECK: return %[[V]] : i32
+
+// -----
+
 func.func @reduce_i64(%arg0: vector<16xi64>) -> i64 {
   %0 = vector.reduction <add>, %arg0 : vector<16xi64> into i64
   return %0 : i64

diff  --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 4dd9388b5cdb6..243e83e8ceb6a 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1116,27 +1116,6 @@ func.func @reduce_unsupported_third_argument(%arg0: vector<16xf32>, %arg1: f32)
 
 // -----
 
-func.func @reduce_unsupported_accumulator_kind(%arg0: vector<16xf32>, %arg1: f32) -> f32 {
-  // expected-error at +1 {{'vector.reduction' op no accumulator for reduction kind: min}}
-  %0 = vector.reduction <minf>, %arg0, %arg1 : vector<16xf32> into f32
-}
-
-// -----
-
-func.func @reduce_unsupported_accumulator_type(%arg0: vector<16xi32>, %arg1: i32) -> i32 {
-  // expected-error at +1 {{'vector.reduction' op no accumulator for type: 'i32'}}
-  %0 = vector.reduction <add>, %arg0, %arg1 : vector<16xi32> into i32
-}
-
-// -----
-
-func.func @reduce_unsupported_type(%arg0: vector<16xf32>) -> f32 {
-  // expected-error at +1 {{'vector.reduction' op unsupported reduction type}}
-  %0 = vector.reduction <xor>, %arg0 : vector<16xf32> into f32
-}
-
-// -----
-
 func.func @reduce_unsupported_rank(%arg0: vector<4x16xf32>) -> f32 {
   // expected-error at +1 {{'vector.reduction' op unsupported reduction rank: 2}}
   %0 = vector.reduction <add>, %arg0 : vector<4x16xf32> into f32

diff  --git a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
index 70f86fd4dc6dd..4123ef3b75135 100644
--- a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
@@ -19,9 +19,8 @@
 // CHECK-SAME: %[[B:.*1]]: vector<4xf32>,
 // CHECK-SAME: %[[C:.*2]]: f32
 // CHECK:      %[[F:.*]] = arith.mulf %[[A]], %[[B]] : vector<4xf32>
-// CHECK:      %[[R:.*]] = vector.reduction <add>, %[[F]] : vector<4xf32> into f32
-// CHECK:      %[[ACC:.*]] = arith.addf %[[R]], %[[C]] : f32
-// CHECK:      return %[[ACC]] : f32
+// CHECK:      %[[R:.*]] = vector.reduction <add>, %[[F]], %[[C]] : vector<4xf32> into f32
+// CHECK:      return %[[R]] : f32
 
 func.func @extract_contract1(%arg0: vector<4xf32>, %arg1: vector<4xf32>, %arg2: f32) -> f32 {
   %0 = vector.contract #dotp_trait %arg0, %arg1, %arg2
@@ -34,9 +33,8 @@ func.func @extract_contract1(%arg0: vector<4xf32>, %arg1: vector<4xf32>, %arg2:
 // CHECK-SAME: %[[B:.*1]]: vector<4xi32>,
 // CHECK-SAME: %[[C:.*2]]: i32
 // CHECK:      %[[F:.*]] = arith.muli %[[A]], %[[B]] : vector<4xi32>
-// CHECK:      %[[R:.*]] = vector.reduction <add>, %[[F]] : vector<4xi32> into i32
-// CHECK:      %[[ACC:.*]] = arith.addi %[[R]], %[[C]] : i32
-// CHECK:      return %[[ACC]] : i32
+// CHECK:      %[[R:.*]] = vector.reduction <add>, %[[F]], %[[C]] : vector<4xi32> into i32
+// CHECK:      return %[[R]] : i32
 
 func.func @extract_contract1_int(%arg0: vector<4xi32>, %arg1: vector<4xi32>, %arg2: i32) -> i32 {
   %0 = vector.contract #dotp_trait %arg0, %arg1, %arg2
@@ -72,7 +70,7 @@ func.func @extract_contract1_int(%arg0: vector<4xi32>, %arg1: vector<4xi32>, %ar
 
 func.func @extract_contract2(%arg0: vector<2x3xf32>,
                         %arg1: vector<3xf32>,
-			%arg2: vector<2xf32>) -> vector<2xf32> {
+                        %arg2: vector<2xf32>) -> vector<2xf32> {
   %0 = vector.contract #matvec_trait %arg0, %arg1, %arg2
     : vector<2x3xf32>, vector<3xf32> into vector<2xf32>
   return %0 : vector<2xf32>
@@ -95,7 +93,7 @@ func.func @extract_contract2(%arg0: vector<2x3xf32>,
 // CHECK:      return %[[T10]] : vector<2xi32>
 func.func @extract_contract2_int(%arg0: vector<2x3xi32>,
                         %arg1: vector<3xi32>,
-			%arg2: vector<2xi32>) -> vector<2xi32> {
+                        %arg2: vector<2xi32>) -> vector<2xi32> {
   %0 = vector.contract #matvec_trait %arg0, %arg1, %arg2
     : vector<2x3xi32>, vector<3xi32> into vector<2xi32>
   return %0 : vector<2xi32>
@@ -201,18 +199,16 @@ func.func @extract_contract4(%arg0: vector<2x2xf32>,
 // CHECK:      %[[T0:.*]] = vector.extract %[[A]][0] : vector<2x3xf32>
 // CHECK:      %[[T1:.*]] = vector.extract %[[B]][0] : vector<2x3xf32>
 // CHECK:      %[[T2:.*]] = arith.mulf %[[T0]], %[[T1]] : vector<3xf32>
-// CHECK:      %[[T3:.*]] = vector.reduction <add>, %[[T2]] : vector<3xf32> into f32
-// CHECK:      %[[T4:.*]] = arith.addf %[[T3]], %[[C]] : f32
+// CHECK:      %[[T3:.*]] = vector.reduction <add>, %[[T2]], %[[C]] : vector<3xf32> into f32
 // CHECK:      %[[T5:.*]] = vector.extract %[[A]][1] : vector<2x3xf32>
 // CHECK:      %[[T6:.*]] = vector.extract %[[B]][1] : vector<2x3xf32>
 // CHECK:      %[[T7:.*]] = arith.mulf %[[T5]], %[[T6]] : vector<3xf32>
-// CHECK:      %[[T8:.*]] = vector.reduction <add>, %[[T7]] : vector<3xf32> into f32
-// CHECK:      %[[T9:.*]] = arith.addf %[[T8]], %[[T4]] : f32
-// CHECK:      return %[[T9]] : f32
+// CHECK:      %[[T8:.*]] = vector.reduction <add>, %[[T7]], %[[T3]] : vector<3xf32> into f32
+// CHECK:      return %[[T8]] : f32
 
 func.func @full_contract1(%arg0: vector<2x3xf32>,
                      %arg1: vector<2x3xf32>,
-		     %arg2: f32) -> f32 {
+                     %arg2: f32) -> f32 {
   %0 = vector.contract #contraction2d_trait %arg0, %arg1, %arg2
     : vector<2x3xf32>, vector<2x3xf32> into f32
   return %0 : f32
@@ -241,8 +237,7 @@ func.func @full_contract1(%arg0: vector<2x3xf32>,
 // CHECK:      %[[T7:.*]] = vector.extract %[[B]][2, 0] : vector<3x2xf32>
 // CHECK:      %[[T9:.*]] = vector.insert %[[T7]], %[[T6]] [2] : f32 into vector<3xf32>
 // CHECK:      %[[T10:.*]] = arith.mulf %[[T0]], %[[T9]] : vector<3xf32>
-// CHECK:      %[[T11:.*]] = vector.reduction <add>, %[[T10]] : vector<3xf32> into f32
-// CHECK:      %[[ACC0:.*]] = arith.addf %[[T11]], %[[C]] : f32
+// CHECK:      %[[T11:.*]] = vector.reduction <add>, %[[T10]], %[[C]] : vector<3xf32> into f32
 //
 // CHECK:      %[[T12:.*]] = vector.extract %[[A]][1] : vector<2x3xf32>
 // CHECK:      %[[T13:.*]] = vector.extract %[[B]][0, 1] : vector<3x2xf
@@ -252,13 +247,12 @@ func.func @full_contract1(%arg0: vector<2x3xf32>,
 // CHECK:      %[[T19:.*]] = vector.extract %[[B]][2, 1] : vector<3x2xf32>
 // CHECK:      %[[T21:.*]] = vector.insert %[[T19]], %[[T18]] [2] : f32 into vector<3xf32>
 // CHECK:      %[[T22:.*]] = arith.mulf %[[T12]], %[[T21]] : vector<3xf32>
-// CHECK:      %[[T23:.*]] = vector.reduction <add>, %[[T22]] : vector<3xf32> into f32
-// CHECK:      %[[ACC1:.*]] = arith.addf %[[T23]], %[[ACC0]] : f32
-// CHECK:      return %[[ACC1]] : f32
+// CHECK:      %[[T23:.*]] = vector.reduction <add>, %[[T22]], %[[T11]] : vector<3xf32> into f32
+// CHECK:      return %[[T23]] : f32
 
 func.func @full_contract2(%arg0: vector<2x3xf32>,
                      %arg1: vector<3x2xf32>,
-		     %arg2: f32) -> f32 {
+                     %arg2: f32) -> f32 {
   %0 = vector.contract #contraction2d_trans_trait %arg0, %arg1, %arg2
     : vector<2x3xf32>, vector<3x2xf32> into f32
   return %0 : f32


        


More information about the Mlir-commits mailing list