[Mlir-commits] [mlir] e33f301 - [mlir] Add support for moving reductions to outer most dimensions in vector.multi_reduction

llvmlistbot at llvm.org
Fri Aug 13 13:16:45 PDT 2021


Author: harsh-nod
Date: 2021-08-13T12:59:50-07:00
New Revision: e33f301ec220bf5349692126b4cf5597e08185dd

URL: https://github.com/llvm/llvm-project/commit/e33f301ec220bf5349692126b4cf5597e08185dd
DIFF: https://github.com/llvm/llvm-project/commit/e33f301ec220bf5349692126b4cf5597e08185dd.diff

LOG: [mlir] Add support for moving reductions to outer most dimensions in vector.multi_reduction

The approach for handling reductions in the outermost
dimensions follows that for the innermost dimensions,
outlined below.

First, transpose to move the reduction dims, if needed.
Then, convert the reduction from n-D to 2-D canonical form.
Finally, for outer reductions, emit the appropriate op
(add/mul/min/max/or/and/xor) for each row and combine the results.
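
For example (mirroring the first new test case below; the SSA names here
are illustrative), a multiplicative reduction over the inner dim of a
vector<2x4xf32> lowers roughly as follows when outer reductions are used:

    %0 = vector.multi_reduction #vector.kind<mul>, %arg0 [1] : vector<2x4xf32> to vector<2xf32>

    // becomes (transpose the reduction dim to the front, then unroll and combine rows):
    %t  = vector.transpose %arg0, [1, 0] : vector<2x4xf32> to vector<4x2xf32>
    %v0 = vector.extract %t[0] : vector<4x2xf32>
    %v1 = vector.extract %t[1] : vector<4x2xf32>
    %m0 = mulf %v1, %v0 : vector<2xf32>
    %v2 = vector.extract %t[2] : vector<4x2xf32>
    %m1 = mulf %v2, %m0 : vector<2xf32>
    %v3 = vector.extract %t[3] : vector<4x2xf32>
    %m2 = mulf %v3, %m1 : vector<2xf32>
    // %m2 : vector<2xf32> is the reduced result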

Differential Revision: https://reviews.llvm.org/D107675

Added: 
    mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir

Modified: 
    mlir/include/mlir/Dialect/Vector/VectorOps.h
    mlir/lib/Dialect/Vector/VectorTransforms.cpp
    mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/VectorOps.h b/mlir/include/mlir/Dialect/Vector/VectorOps.h
index cf53e8fcff97c..9bc2cd4e35acf 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.h
@@ -81,7 +81,8 @@ void populateVectorMaskMaterializationPatterns(RewritePatternSet &patterns,
 
 // Collect a set of patterns to convert vector.multi_reduction op into
 // a sequence of vector.reduction ops.
-void populateVectorMultiReductionLoweringPatterns(RewritePatternSet &patterns);
+void populateVectorMultiReductionLoweringPatterns(
+    RewritePatternSet &patterns, bool useInnerDimsForReduction = false);
 
 /// Collect a set of patterns to propagate insert_map/extract_map in the ssa
 /// chain.

diff  --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 278cd6d639cfa..f3ad31c042f6f 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -3490,12 +3490,18 @@ class VectorCreateMaskOpConversion
   const bool enableIndexOptimizations;
 };
 
-// Converts vector.multi_reduction into inner-most reduction form by inserting
-// vector.transpose
-struct InnerDimReductionConversion
+// Converts vector.multi_reduction into inner-most/outer-most reduction form
+// by using vector.transpose
+class InnerOuterDimReductionConversion
     : public OpRewritePattern<vector::MultiDimReductionOp> {
+public:
   using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;
 
+  explicit InnerOuterDimReductionConversion(MLIRContext *context,
+                                            bool useInnerDimsForReduction)
+      : mlir::OpRewritePattern<vector::MultiDimReductionOp>(context),
+        useInnerDimsForReduction(useInnerDimsForReduction) {}
+
   LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
                                 PatternRewriter &rewriter) const override {
     auto src = multiReductionOp.source();
@@ -3516,87 +3522,116 @@ struct InnerDimReductionConversion
         parallelDims.push_back(i);
     }
 
-    // Add transpose only if inner-most dimensions are not reductions
-    if (parallelDims ==
-        llvm::to_vector<4>(llvm::seq<int64_t>(0, parallelDims.size())))
+    // Add transpose only if inner-most/outer-most dimensions are not parallel
+    if (useInnerDimsForReduction &&
+        (parallelDims ==
+         llvm::to_vector<4>(llvm::seq<int64_t>(0, parallelDims.size()))))
+      return failure();
+
+    if (!useInnerDimsForReduction &&
+        (parallelDims !=
+         llvm::to_vector<4>(llvm::seq<int64_t>(0, parallelDims.size()))))
       return failure();
 
     SmallVector<int64_t, 4> indices;
-    indices.append(parallelDims.begin(), parallelDims.end());
-    indices.append(reductionDims.begin(), reductionDims.end());
+    if (useInnerDimsForReduction) {
+      indices.append(parallelDims.begin(), parallelDims.end());
+      indices.append(reductionDims.begin(), reductionDims.end());
+    } else {
+      indices.append(reductionDims.begin(), reductionDims.end());
+      indices.append(parallelDims.begin(), parallelDims.end());
+    }
     auto transposeOp = rewriter.create<vector::TransposeOp>(loc, src, indices);
     SmallVector<bool> reductionMask(srcRank, false);
     for (int i = 0; i < reductionSize; ++i) {
-      reductionMask[srcRank - i - 1] = true;
+      if (useInnerDimsForReduction)
+        reductionMask[srcRank - i - 1] = true;
+      else
+        reductionMask[i] = true;
     }
     rewriter.replaceOpWithNewOp<vector::MultiDimReductionOp>(
         multiReductionOp, transposeOp.result(), reductionMask,
         multiReductionOp.kind());
     return success();
   }
+
+private:
+  const bool useInnerDimsForReduction;
 };
 
 // Reduces the rank of vector.mult_reduction nd -> 2d given all reduction
-// dimensions are inner most.
-struct ReduceMultiDimReductionRank
+// dimensions are either inner most or outer most.
+class ReduceMultiDimReductionRank
     : public OpRewritePattern<vector::MultiDimReductionOp> {
+public:
   using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;
 
+  explicit ReduceMultiDimReductionRank(MLIRContext *context,
+                                       bool useInnerDimsForReduction)
+      : mlir::OpRewritePattern<vector::MultiDimReductionOp>(context),
+        useInnerDimsForReduction(useInnerDimsForReduction) {}
+
   LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
                                 PatternRewriter &rewriter) const override {
     auto srcRank = multiReductionOp.getSourceVectorType().getRank();
     auto srcShape = multiReductionOp.getSourceVectorType().getShape();
+    auto loc = multiReductionOp.getLoc();
     if (srcRank == 2)
       return failure();
 
-    auto loc = multiReductionOp.getLoc();
-    auto reductionDims = llvm::to_vector<4>(
-        llvm::map_range(multiReductionOp.reduction_dims().cast<ArrayAttr>(),
-                        [](Attribute attr) -> int64_t {
-                          return attr.cast<IntegerAttr>().getInt();
-                        }));
-    llvm::sort(reductionDims);
-
-    // Fails if not inner most reduction.
-    int64_t reductionSize = reductionDims.size();
-    bool innerMostReduction = true;
-    for (int i = 0; i < reductionSize; ++i) {
-      if (reductionDims[reductionSize - i - 1] != srcRank - i - 1) {
-        innerMostReduction = false;
+    // Separate reduction and parallel dims
+    auto reductionDimsRange =
+        multiReductionOp.reduction_dims().getAsValueRange<IntegerAttr>();
+    auto reductionDims = llvm::to_vector<4>(llvm::map_range(
+        reductionDimsRange, [](APInt a) { return a.getZExtValue(); }));
+    llvm::SmallDenseSet<int64_t> reductionDimsSet(reductionDims.begin(),
+                                                  reductionDims.end());
+    SmallVector<int64_t, 4> parallelDims, parallelShapes;
+    int canonicalReductionDim = 1;
+    int canonicalParallelDim = 1;
+    for (int64_t i = 0; i < srcRank; i++) {
+      if (!reductionDimsSet.contains(i)) {
+        parallelDims.push_back(i);
+        parallelShapes.push_back(srcShape[i]);
+        canonicalParallelDim *= srcShape[i];
+      } else {
+        canonicalReductionDim *= srcShape[i];
       }
     }
-    if (!innerMostReduction)
+
+    // Fail if reduction dims are not either inner-most or outer-most
+    if (useInnerDimsForReduction &&
+        (parallelDims !=
+         llvm::to_vector<4>(llvm::seq<int64_t>(0, parallelDims.size()))))
       return failure();
 
-    // Extracts 2d rank reduction shape.
-    int innerDims = 1;
-    int outterDims = 1;
-    SmallVector<int64_t> innerDimsShape;
-    for (int i = 0; i < srcRank; ++i) {
-      if (i < (srcRank - reductionSize)) {
-        innerDims *= srcShape[i];
-        innerDimsShape.push_back(srcShape[i]);
-      } else {
-        outterDims *= srcShape[i];
-      }
-    }
+    if (!useInnerDimsForReduction &&
+        (parallelDims ==
+         llvm::to_vector<4>(llvm::seq<int64_t>(0, parallelDims.size()))))
+      return failure();
 
     // Creates shape cast for the inputs n_d -> 2d
+    int64_t outerDim =
+        useInnerDimsForReduction ? canonicalParallelDim : canonicalReductionDim;
+    int64_t innerDim =
+        useInnerDimsForReduction ? canonicalReductionDim : canonicalParallelDim;
+
     auto castedType = VectorType::get(
-        {innerDims, outterDims},
+        ArrayRef<int64_t>{outerDim, innerDim},
         multiReductionOp.getSourceVectorType().getElementType());
     auto castedOp = rewriter.create<vector::ShapeCastOp>(
         loc, castedType, multiReductionOp.source());
 
-    // Creates the canonical form of 2d vector.multi_reduction with inner most
-    // dim as reduction.
+    // Creates the canonical form of 2d vector.multi_reduction with inner/outer
+    // most dim as reduction.
+    SmallVector<bool, 2> mask{!useInnerDimsForReduction,
+                              useInnerDimsForReduction};
     auto newOp = rewriter.create<vector::MultiDimReductionOp>(
-        loc, castedOp.result(), ArrayRef<bool>{false, true},
-        multiReductionOp.kind());
+        loc, castedOp.result(), mask, multiReductionOp.kind());
 
     // Creates shape cast for the output 2d -> nd
-    auto outputCastedType = VectorType::get(
-        innerDimsShape,
+    VectorType outputCastedType = VectorType::get(
+        parallelShapes,
         multiReductionOp.getSourceVectorType().getElementType());
     Value castedOutputOp = rewriter.create<vector::ShapeCastOp>(
         loc, outputCastedType, newOp.dest());
@@ -3604,6 +3639,88 @@ struct ReduceMultiDimReductionRank
     rewriter.replaceOp(multiReductionOp, castedOutputOp);
     return success();
   }
+
+private:
+  const bool useInnerDimsForReduction;
+};
+
+// Unrolls vector.multi_reduction with outermost reductions
+// and combines results
+struct UnrollOuterMultiReduction
+    : public OpRewritePattern<vector::MultiDimReductionOp> {
+  using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
+                                PatternRewriter &rewriter) const override {
+    auto srcRank = multiReductionOp.getSourceVectorType().getRank();
+    if (srcRank != 2)
+      return failure();
+
+    if (multiReductionOp.getReductionMask()[1] ||
+        !multiReductionOp.getReductionMask()[0])
+      return failure();
+
+    auto loc = multiReductionOp.getLoc();
+    ArrayRef<int64_t> srcShape =
+        multiReductionOp.getSourceVectorType().getShape();
+
+    Type elementType = multiReductionOp.getDestVectorType().getElementType();
+    if (!elementType.isIntOrIndexOrFloat())
+      return failure();
+
+    Value condition;
+    Value result =
+        rewriter.create<vector::ExtractOp>(loc, multiReductionOp.source(), 0)
+            .getResult();
+    for (int64_t i = 1; i < srcShape[0]; i++) {
+      auto operand =
+          rewriter.create<vector::ExtractOp>(loc, multiReductionOp.source(), i);
+      switch (multiReductionOp.kind()) {
+      case vector::CombiningKind::ADD:
+        if (elementType.isIntOrIndex())
+          result = rewriter.create<AddIOp>(loc, operand, result);
+        else
+          result = rewriter.create<AddFOp>(loc, operand, result);
+        break;
+      case vector::CombiningKind::MUL:
+        if (elementType.isIntOrIndex())
+          result = rewriter.create<MulIOp>(loc, operand, result);
+        else
+          result = rewriter.create<MulFOp>(loc, operand, result);
+        break;
+      case vector::CombiningKind::MIN:
+        if (elementType.isIntOrIndex())
+          condition =
+              rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, operand, result);
+        else
+          condition =
+              rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, operand, result);
+        result = rewriter.create<SelectOp>(loc, condition, operand, result);
+        break;
+      case vector::CombiningKind::MAX:
+        if (elementType.isIntOrIndex())
+          condition =
+              rewriter.create<CmpIOp>(loc, CmpIPredicate::sge, operand, result);
+        else
+          condition =
+              rewriter.create<CmpFOp>(loc, CmpFPredicate::OGE, operand, result);
+        result = rewriter.create<SelectOp>(loc, condition, operand, result);
+        break;
+      case vector::CombiningKind::AND:
+        result = rewriter.create<AndOp>(loc, operand, result);
+        break;
+      case vector::CombiningKind::OR:
+        result = rewriter.create<OrOp>(loc, operand, result);
+        break;
+      case vector::CombiningKind::XOR:
+        result = rewriter.create<XOrOp>(loc, operand, result);
+        break;
+      }
+    }
+
+    rewriter.replaceOp(multiReductionOp, result);
+    return success();
+  }
 };
 
 // Converts 2d vector.multi_reduction with inner most reduction dimension into a
@@ -3747,9 +3864,13 @@ void mlir::vector::populateVectorTransferLoweringPatterns(
 }
 
 void mlir::vector::populateVectorMultiReductionLoweringPatterns(
-    RewritePatternSet &patterns) {
-  patterns.add<InnerDimReductionConversion, ReduceMultiDimReductionRank,
-               TwoDimMultiReductionToReduction>(patterns.getContext());
+    RewritePatternSet &patterns, bool useInnerDimsForReduction) {
+  patterns.add<InnerOuterDimReductionConversion, ReduceMultiDimReductionRank>(
+      patterns.getContext(), useInnerDimsForReduction);
+  if (useInnerDimsForReduction)
+    patterns.add<TwoDimMultiReductionToReduction>(patterns.getContext());
+  else
+    patterns.add<UnrollOuterMultiReduction>(patterns.getContext());
 }
 
 void mlir::vector::populateVectorUnrollPatterns(

diff  --git a/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir
new file mode 100644
index 0000000000000..91dcc2e0172f7
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir
@@ -0,0 +1,161 @@
+// RUN: mlir-opt %s -test-vector-multi-reduction-lowering-patterns="use-outer-reductions" | FileCheck %s
+
+func @vector_multi_reduction(%arg0: vector<2x4xf32>) -> vector<2xf32> {
+    %0 = vector.multi_reduction #vector.kind<mul>, %arg0 [1] : vector<2x4xf32> to vector<2xf32>
+    return %0 : vector<2xf32>
+}
+
+// CHECK-LABEL: func @vector_multi_reduction
+//  CHECK-SAME:   %[[INPUT:.+]]: vector<2x4xf32>
+//       CHECK:   %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xf32> to vector<4x2xf32>
+//       CHECK:   %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<4x2xf32>
+//       CHECK:   %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<4x2xf32>
+//       CHECK:   %[[RV01:.+]] = mulf %[[V1]], %[[V0]] : vector<2xf32>
+//       CHECK:   %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<4x2xf32>
+//       CHECK:   %[[RV012:.+]] = mulf %[[V2]], %[[RV01]] : vector<2xf32>
+//       CHECK:   %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<4x2xf32>
+//       CHECK:   %[[RESULT_VEC:.+]] = mulf %[[V3]], %[[RV012]] : vector<2xf32>
+//       CHECK:   return %[[RESULT_VEC]] : vector<2xf32>
+
+func @vector_multi_reduction_min(%arg0: vector<2x4xf32>) -> vector<2xf32> {
+    %0 = vector.multi_reduction #vector.kind<min>, %arg0 [1] : vector<2x4xf32> to vector<2xf32>
+    return %0 : vector<2xf32>
+}
+
+// CHECK-LABEL: func @vector_multi_reduction_min
+//  CHECK-SAME:   %[[INPUT:.+]]: vector<2x4xf32>
+//       CHECK:   %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xf32> to vector<4x2xf32>
+//       CHECK:   %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<4x2xf32>
+//       CHECK:   %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<4x2xf32>
+//       CHECK:   %[[C0:.+]] = cmpf olt, %[[V1]], %[[V0]] : vector<2xf32>
+//       CHECK:   %[[RV01:.+]] = select %[[C0]], %[[V1]], %[[V0]] : vector<2xi1>, vector<2xf32>
+//       CHECK:   %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<4x2xf32>
+//       CHECK:   %[[C1:.+]] = cmpf olt, %[[V2]], %[[RV01]] : vector<2xf32>
+//       CHECK:   %[[RV012:.+]] = select %[[C1]], %[[V2]], %[[RV01]] : vector<2xi1>, vector<2xf32>
+//       CHECK:   %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<4x2xf32>
+//       CHECK:   %[[C2:.+]] = cmpf olt, %[[V3]], %[[RV012]] : vector<2xf32>
+//       CHECK:   %[[RESULT_VEC:.+]] = select %[[C2]], %[[V3]], %[[RV012]] : vector<2xi1>, vector<2xf32>
+//       CHECK:   return %[[RESULT_VEC]] : vector<2xf32>
+
+func @vector_multi_reduction_max(%arg0: vector<2x4xf32>) -> vector<2xf32> {
+    %0 = vector.multi_reduction #vector.kind<max>, %arg0 [1] : vector<2x4xf32> to vector<2xf32>
+    return %0 : vector<2xf32>
+}
+
+// CHECK-LABEL: func @vector_multi_reduction_max
+//  CHECK-SAME:   %[[INPUT:.+]]: vector<2x4xf32>
+//       CHECK:   %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xf32> to vector<4x2xf32>
+//       CHECK:   %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<4x2xf32>
+//       CHECK:   %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<4x2xf32>
+//       CHECK:   %[[C0:.+]] = cmpf oge, %[[V1]], %[[V0]] : vector<2xf32>
+//       CHECK:   %[[RV01:.+]] = select %[[C0]], %[[V1]], %[[V0]] : vector<2xi1>, vector<2xf32>
+//       CHECK:   %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<4x2xf32>
+//       CHECK:   %[[C1:.+]] = cmpf oge, %[[V2]], %[[RV01]] : vector<2xf32>
+//       CHECK:   %[[RV012:.+]] = select %[[C1]], %[[V2]], %[[RV01]] : vector<2xi1>, vector<2xf32>
+//       CHECK:   %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<4x2xf32>
+//       CHECK:   %[[C2:.+]] = cmpf oge, %[[V3]], %[[RV012]] : vector<2xf32>
+//       CHECK:   %[[RESULT_VEC:.+]] = select %[[C2]], %[[V3]], %[[RV012]] : vector<2xi1>, vector<2xf32>
+//       CHECK:   return %[[RESULT_VEC]] : vector<2xf32>
+
+func @vector_multi_reduction_and(%arg0: vector<2x4xi32>) -> vector<2xi32> {
+    %0 = vector.multi_reduction #vector.kind<and>, %arg0 [1] : vector<2x4xi32> to vector<2xi32>
+    return %0 : vector<2xi32>
+}
+
+// CHECK-LABEL: func @vector_multi_reduction_and
+//  CHECK-SAME:   %[[INPUT:.+]]: vector<2x4xi32>
+//       CHECK:   %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xi32> to vector<4x2xi32>
+//       CHECK:   %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<4x2xi32>
+//       CHECK:   %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<4x2xi32>
+//       CHECK:   %[[RV01:.+]] = and %[[V1]], %[[V0]] : vector<2xi32>
+//       CHECK:   %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<4x2xi32>
+//       CHECK:   %[[RV012:.+]] = and %[[V2]], %[[RV01]] : vector<2xi32>
+//       CHECK:   %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<4x2xi32>
+//       CHECK:   %[[RESULT_VEC:.+]] = and %[[V3]], %[[RV012]] : vector<2xi32>
+//       CHECK:   return %[[RESULT_VEC]] : vector<2xi32>
+
+func @vector_multi_reduction_or(%arg0: vector<2x4xi32>) -> vector<2xi32> {
+    %0 = vector.multi_reduction #vector.kind<or>, %arg0 [1] : vector<2x4xi32> to vector<2xi32>
+    return %0 : vector<2xi32>
+}
+
+// CHECK-LABEL: func @vector_multi_reduction_or
+//  CHECK-SAME:   %[[INPUT:.+]]: vector<2x4xi32>
+//       CHECK:   %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xi32> to vector<4x2xi32>
+//       CHECK:   %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<4x2xi32>
+//       CHECK:   %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<4x2xi32>
+//       CHECK:   %[[RV01:.+]] = or %[[V1]], %[[V0]] : vector<2xi32>
+//       CHECK:   %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<4x2xi32>
+//       CHECK:   %[[RV012:.+]] = or %[[V2]], %[[RV01]] : vector<2xi32>
+//       CHECK:   %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<4x2xi32>
+//       CHECK:   %[[RESULT_VEC:.+]] = or %[[V3]], %[[RV012]] : vector<2xi32>
+//       CHECK:   return %[[RESULT_VEC]] : vector<2xi32>
+
+func @vector_multi_reduction_xor(%arg0: vector<2x4xi32>) -> vector<2xi32> {
+    %0 = vector.multi_reduction #vector.kind<xor>, %arg0 [1] : vector<2x4xi32> to vector<2xi32>
+    return %0 : vector<2xi32>
+}
+
+// CHECK-LABEL: func @vector_multi_reduction_xor
+//  CHECK-SAME:   %[[INPUT:.+]]: vector<2x4xi32>
+//       CHECK:   %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xi32> to vector<4x2xi32>
+//       CHECK:   %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<4x2xi32>
+//       CHECK:   %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<4x2xi32>
+//       CHECK:   %[[RV01:.+]] = xor %[[V1]], %[[V0]] : vector<2xi32>
+//       CHECK:   %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<4x2xi32>
+//       CHECK:   %[[RV012:.+]] = xor %[[V2]], %[[RV01]] : vector<2xi32>
+//       CHECK:   %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<4x2xi32>
+//       CHECK:   %[[RESULT_VEC:.+]] = xor %[[V3]], %[[RV012]] : vector<2xi32>
+//       CHECK:   return %[[RESULT_VEC]] : vector<2xi32>
+
+
+func @vector_reduction_outer(%arg0: vector<2x3x4x5xi32>) -> vector<2x3xi32> {
+    %0 = vector.multi_reduction #vector.kind<add>, %arg0 [2, 3] : vector<2x3x4x5xi32> to vector<2x3xi32>
+    return %0 : vector<2x3xi32>
+}
+
+// CHECK-LABEL: func @vector_reduction_outer
+//  CHECK-SAME:   %[[INPUT:.+]]: vector<2x3x4x5xi32>
+//       CHECK:   %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [2, 3, 0, 1] : vector<2x3x4x5xi32> to vector<4x5x2x3xi32>
+//       CHECK:   %[[RESHAPED:.+]] = vector.shape_cast %[[TRANSPOSED]] : vector<4x5x2x3xi32> to vector<20x6xi32>
+//       CHECK:   %[[V0:.+]] = vector.extract %[[RESHAPED]][0] : vector<20x6xi32>
+//       CHECK:   %[[V1:.+]] = vector.extract %[[RESHAPED]][1] : vector<20x6xi32>
+//       CHECK:   %[[R0:.+]] = addi %[[V1]], %[[V0]] : vector<6xi32>
+//       CHECK:   %[[V2:.+]] = vector.extract %[[RESHAPED]][2] : vector<20x6xi32>
+//       CHECK:   %[[R1:.+]] = addi %[[V2]], %[[R0]] : vector<6xi32>
+//       CHECK:   %[[V3:.+]] = vector.extract %[[RESHAPED]][3] : vector<20x6xi32>
+//       CHECK:   %[[R2:.+]] = addi %[[V3]], %[[R1]] : vector<6xi32>
+//       CHECK:   %[[V4:.+]] = vector.extract %[[RESHAPED]][4] : vector<20x6xi32>
+//       CHECK:   %[[R3:.+]] = addi %[[V4]], %[[R2]] : vector<6xi32>
+//       CHECK:   %[[V5:.+]] = vector.extract %[[RESHAPED]][5] : vector<20x6xi32>
+//       CHECK:   %[[R4:.+]] = addi %[[V5]], %[[R3]] : vector<6xi32>
+//       CHECK:   %[[V6:.+]] = vector.extract %[[RESHAPED]][6] : vector<20x6xi32>
+//       CHECK:   %[[R5:.+]] = addi %[[V6]], %[[R4]] : vector<6xi32>
+//       CHECK:   %[[V7:.+]] = vector.extract %[[RESHAPED]][7] : vector<20x6xi32>
+//       CHECK:   %[[R6:.+]] = addi %[[V7]], %[[R5]] : vector<6xi32>
+//       CHECK:   %[[V8:.+]] = vector.extract %[[RESHAPED]][8] : vector<20x6xi32>
+//       CHECK:   %[[R7:.+]] = addi %[[V8]], %[[R6]] : vector<6xi32>
+//       CHECK:   %[[V9:.+]] = vector.extract %[[RESHAPED]][9] : vector<20x6xi32>
+//       CHECK:   %[[R8:.+]] = addi %[[V9]], %[[R7]] : vector<6xi32>
+//       CHECK:   %[[V10:.+]] = vector.extract %[[RESHAPED]][10] : vector<20x6xi32>
+//       CHECK:   %[[R9:.+]] = addi %[[V10]], %[[R8]] : vector<6xi32>
+//       CHECK:   %[[V11:.+]] = vector.extract %[[RESHAPED]][11] : vector<20x6xi32>
+//       CHECK:   %[[R10:.+]] = addi %[[V11]], %[[R9]] : vector<6xi32>
+//       CHECK:   %[[V12:.+]] = vector.extract %[[RESHAPED]][12] : vector<20x6xi32>
+//       CHECK:   %[[R11:.+]] = addi %[[V12]], %[[R10]] : vector<6xi32>
+//       CHECK:   %[[V13:.+]] = vector.extract %[[RESHAPED]][13] : vector<20x6xi32>
+//       CHECK:   %[[R12:.+]] = addi %[[V13]], %[[R11]] : vector<6xi32>
+//       CHECK:   %[[V14:.+]] = vector.extract %[[RESHAPED]][14] : vector<20x6xi32>
+//       CHECK:   %[[R13:.+]] = addi %[[V14]], %[[R12]] : vector<6xi32>
+//       CHECK:   %[[V15:.+]] = vector.extract %[[RESHAPED]][15] : vector<20x6xi32>
+//       CHECK:   %[[R14:.+]] = addi %[[V15]], %[[R13]] : vector<6xi32>
+//       CHECK:   %[[V16:.+]] = vector.extract %[[RESHAPED]][16] : vector<20x6xi32>
+//       CHECK:   %[[R15:.+]] = addi %[[V16]], %[[R14]] : vector<6xi32>
+//       CHECK:   %[[V17:.+]] = vector.extract %[[RESHAPED]][17] : vector<20x6xi32>
+//       CHECK:   %[[R16:.+]] = addi %[[V17]], %[[R15]] : vector<6xi32>
+//       CHECK:   %[[V18:.+]] = vector.extract %[[RESHAPED]][18] : vector<20x6xi32>
+//       CHECK:   %[[R17:.+]] = addi %[[V18]], %[[R16]] : vector<6xi32>
+//       CHECK:   %[[V19:.+]] = vector.extract %[[RESHAPED]][19] : vector<20x6xi32>
+//       CHECK:   %[[R18:.+]] = addi %[[V19]], %[[R17]] : vector<6xi32>
+//       CHECK:   %[[RESULT_VEC:.+]] = vector.shape_cast %[[R18]] : vector<6xi32> to vector<2x3xi32>
+//       CHECK:   return %[[RESULT_VEC]] : vector<2x3xi32>

diff  --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 11b56a583cc83..907f9aedfdb17 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -444,6 +444,9 @@ struct TestVectorTransferLoweringPatterns
 struct TestVectorMultiReductionLoweringPatterns
     : public PassWrapper<TestVectorMultiReductionLoweringPatterns,
                          FunctionPass> {
+  TestVectorMultiReductionLoweringPatterns() = default;
+  TestVectorMultiReductionLoweringPatterns(
+      const TestVectorMultiReductionLoweringPatterns &pass) {}
   void getDependentDialects(DialectRegistry &registry) const override {
     registry.insert<memref::MemRefDialect>();
   }
@@ -454,9 +457,13 @@ struct TestVectorMultiReductionLoweringPatterns
     return "Test conversion patterns to lower vector.multi_reduction to other "
            "vector ops";
   }
+  Option<bool> useOuterReductions{
+      *this, "use-outer-reductions",
+      llvm::cl::desc("Move reductions to outer most dimensions"),
+      llvm::cl::init(false)};
   void runOnFunction() override {
     RewritePatternSet patterns(&getContext());
-    populateVectorMultiReductionLoweringPatterns(patterns);
+    populateVectorMultiReductionLoweringPatterns(patterns, !useOuterReductions);
     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
   }
 };


        

