[Mlir-commits] [mlir] 499e89f - Add patterns to lower vector.multi_reduction into a sequence of vector.reduction

Ahmed Taei llvmlistbot at llvm.org
Fri Apr 30 10:53:20 PDT 2021


Author: Ahmed Taei
Date: 2021-04-30T10:52:21-07:00
New Revision: 499e89fc9119d901132bcc8ab460b1c161c22acc

URL: https://github.com/llvm/llvm-project/commit/499e89fc9119d901132bcc8ab460b1c161c22acc
DIFF: https://github.com/llvm/llvm-project/commit/499e89fc9119d901132bcc8ab460b1c161c22acc.diff

LOG: Add patterns to lower vector.multi_reduction into a sequence of vector.reduction

Three patterns are added to convert vector.multi_reduction into a
sequence of vector.reduction ops as follows:

- Transpose the inputs so that the innermost dimensions are always the reduction dimensions.
- Reduce the rank of vector.multi_reduction to 2-D with the innermost
dimension as the reduction dimension (the 2-D canonical form).
- Convert the 2-D canonical form into a sequence of vector.reduction ops.

There are two things that might be worth doing in a follow-up diff:

- An scf.for loop (possibly optional) around vector.reduction instead of unrolling it.
- Break down the vector.reduction into a sequence of vector.reduction ops
(e.g. a tree-based reduction) instead of relying on how downstream dialects
handle it.
  Note: this will require passing the target vector length.

Differential Revision: https://reviews.llvm.org/D101570

Added: 
    mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir

Modified: 
    mlir/include/mlir/Dialect/Vector/VectorOps.h
    mlir/lib/Dialect/Vector/VectorTransforms.cpp
    mlir/test/lib/Transforms/TestVectorTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/VectorOps.h b/mlir/include/mlir/Dialect/Vector/VectorOps.h
index c11e8112b2e8e..399a90645df22 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.h
@@ -92,6 +92,10 @@ void populateVectorTransferLoweringPatterns(RewritePatternSet &patterns);
 void populateVectorMaskMaterializationPatterns(RewritePatternSet &patterns,
                                                bool enableIndexOptimizations);
 
+/// Collect a set of patterns to convert vector.multi_reduction ops into a
+/// sequence of vector.reduction ops.
+void populateVectorMultiReductionLoweringPatterns(RewritePatternSet &patterns);
+
 /// An attribute that specifies the combining function for `vector.contract`,
 /// and `vector.reduction`.
 class CombiningKindAttr

diff  --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 58e2c1d3a83dc..dab1c14e51979 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -3575,6 +3575,198 @@ class VectorCreateMaskOpConversion
   const bool enableIndexOptimizations;
 };
 
+// Converts vector.multi_reduction into innermost-reduction form by inserting a
+// vector.transpose that moves every reduced dimension after all parallel
+// dimensions. Both groups keep their original relative order, so the rewritten
+// op produces exactly the same result type as the original one.
+struct InnerDimReductionConversion
+    : public OpRewritePattern<vector::MultiDimReductionOp> {
+  using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
+                                PatternRewriter &rewriter) const override {
+    auto src = multiReductionOp.source();
+    auto loc = multiReductionOp.getLoc();
+    int64_t srcRank = multiReductionOp.getSourceVectorType().getRank();
+
+    auto reductionDims = llvm::to_vector<4>(
+        llvm::map_range(multiReductionOp.reduction_dims().cast<ArrayAttr>(),
+                        [](Attribute attr) -> int64_t {
+                          return attr.cast<IntegerAttr>().getInt();
+                        }));
+    llvm::sort(reductionDims);
+    int64_t reductionSize = reductionDims.size();
+
+    // Fails if already innermost reduction, i.e. the sorted reduced dims are
+    // exactly the trailing `reductionSize` dims of the source.
+    bool innerMostReduction = true;
+    for (int64_t i = 0; i < reductionSize; ++i) {
+      if (reductionDims[reductionSize - i - 1] != srcRank - i - 1) {
+        innerMostReduction = false;
+        break;
+      }
+    }
+    if (innerMostReduction)
+      return failure();
+
+    // Builds the transpose permutation: parallel dims first, in their original
+    // order, then the (sorted) reduced dims. Swapping each reduced dim into a
+    // trailing slot in place would reorder the parallel dims whenever the
+    // reduced dims are not contiguous (e.g. reducing [0, 3] of a rank-4 vector
+    // yields [2, 1, 0, 3]), producing a wrongly-shaped result.
+    SmallVector<int64_t> indices;
+    indices.reserve(srcRank);
+    for (int64_t dim = 0; dim < srcRank; ++dim) {
+      if (!llvm::is_contained(reductionDims, dim))
+        indices.push_back(dim);
+    }
+    indices.append(reductionDims.begin(), reductionDims.end());
+
+    // After the transpose, the reduced dims are the trailing ones.
+    SmallVector<bool> reductionMask(srcRank, false);
+    for (int64_t i = 0; i < reductionSize; ++i)
+      reductionMask[srcRank - i - 1] = true;
+
+    auto transposeOp = rewriter.create<vector::TransposeOp>(loc, src, indices);
+    rewriter.replaceOpWithNewOp<vector::MultiDimReductionOp>(
+        multiReductionOp, transposeOp.result(), reductionMask,
+        multiReductionOp.kind());
+    return success();
+  }
+};
+
+// Reduces the rank of vector.multi_reduction nd -> 2d given all reduction
+// dimensions are innermost: the leading parallel dims are collapsed into one
+// dim and the trailing reduced dims into another via vector.shape_cast, and
+// the 1d result is shape_cast back to the original result shape.
+struct ReduceMultiDimReductionRank
+    : public OpRewritePattern<vector::MultiDimReductionOp> {
+  using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
+                                PatternRewriter &rewriter) const override {
+    auto srcType = multiReductionOp.getSourceVectorType();
+    int64_t srcRank = srcType.getRank();
+    auto srcShape = srcType.getShape();
+    if (srcRank == 2)
+      return failure();
+
+    auto loc = multiReductionOp.getLoc();
+    auto reductionDims = llvm::to_vector<4>(
+        llvm::map_range(multiReductionOp.reduction_dims().cast<ArrayAttr>(),
+                        [](Attribute attr) -> int64_t {
+                          return attr.cast<IntegerAttr>().getInt();
+                        }));
+    llvm::sort(reductionDims);
+    int64_t reductionSize = reductionDims.size();
+
+    // Fails when every dim is reduced: no parallel dim would remain to form
+    // the shape of the output vector, and a rank-0 VectorType is not valid.
+    if (reductionSize == srcRank)
+      return failure();
+
+    // Fails if not innermost reduction.
+    bool innerMostReduction = true;
+    for (int64_t i = 0; i < reductionSize; ++i) {
+      if (reductionDims[reductionSize - i - 1] != srcRank - i - 1) {
+        innerMostReduction = false;
+        break;
+      }
+    }
+    if (!innerMostReduction)
+      return failure();
+
+    // Extracts the 2d shape: the leading parallel dims collapse into
+    // `parallelSize`, the trailing reduced dims into `reducedSize`. The
+    // parallel dims also form the shape of the final result.
+    int64_t numParallelDims = srcRank - reductionSize;
+    int64_t parallelSize = 1;
+    int64_t reducedSize = 1;
+    SmallVector<int64_t> parallelShape;
+    for (int64_t i = 0; i < srcRank; ++i) {
+      if (i < numParallelDims) {
+        parallelSize *= srcShape[i];
+        parallelShape.push_back(srcShape[i]);
+      } else {
+        reducedSize *= srcShape[i];
+      }
+    }
+
+    // Creates shape cast for the inputs nd -> 2d.
+    auto castedType =
+        VectorType::get({parallelSize, reducedSize}, srcType.getElementType());
+    auto castedOp = rewriter.create<vector::ShapeCastOp>(
+        loc, castedType, multiReductionOp.source());
+
+    // Creates the canonical form of 2d vector.multi_reduction with innermost
+    // dim as reduction.
+    auto newOp = rewriter.create<vector::MultiDimReductionOp>(
+        loc, castedOp.result(), ArrayRef<bool>{false, true},
+        multiReductionOp.kind());
+
+    // Creates shape cast for the output 1d -> (n - reduced)d.
+    auto outputCastedType =
+        VectorType::get(parallelShape, srcType.getElementType());
+    Value castedOutputOp = rewriter.create<vector::ShapeCastOp>(
+        loc, outputCastedType, newOp.dest());
+
+    rewriter.replaceOp(multiReductionOp, castedOutputOp);
+    return success();
+  }
+};
+
+// Converts a 2d vector.multi_reduction whose only reduced dimension is the
+// innermost one into a sequence of vector.reduction ops: each row of the source
+// is extracted, reduced to a scalar, and inserted into the 1d result.
+struct TwoDimMultiReductionToReduction
+    : public OpRewritePattern<vector::MultiDimReductionOp> {
+  using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
+                                PatternRewriter &rewriter) const override {
+    auto srcRank = multiReductionOp.getSourceVectorType().getRank();
+    if (srcRank != 2)
+      return failure();
+
+    // Only the [parallel, reduced] mask is handled by this pattern.
+    if (multiReductionOp.getReductionMask()[0] ||
+        !multiReductionOp.getReductionMask()[1])
+      return failure();
+
+    auto loc = multiReductionOp.getLoc();
+    auto destType = multiReductionOp.getDestVectorType();
+    auto elementType = destType.getElementType();
+
+    // Zero-initialized accumulator that the loop below fills row by row.
+    // getZeroAttr produces a splat of the exact element type; building the
+    // splat from a C++ `int`/`float` literal only matches 32-bit element types
+    // and trips DenseElementsAttr's width check for e.g. i64, index, f16, f64.
+    Value result = rewriter.create<ConstantOp>(loc, destType,
+                                               rewriter.getZeroAttr(destType));
+
+    int64_t outerDim = multiReductionOp.getSourceVectorType().getShape()[0];
+
+    // TODO: Add vector::CombiningKind attribute instead of string to
+    // vector.reduction.
+    auto getKindStr = [](vector::CombiningKind kind) {
+      switch (kind) {
+      case vector::CombiningKind::ADD:
+        return "add";
+      case vector::CombiningKind::MUL:
+        return "mul";
+      case vector::CombiningKind::MIN:
+        return "min";
+      case vector::CombiningKind::MAX:
+        return "max";
+      case vector::CombiningKind::AND:
+        return "and";
+      case vector::CombiningKind::OR:
+        return "or";
+      case vector::CombiningKind::XOR:
+        return "xor";
+      }
+      llvm_unreachable("unknown vector::CombiningKind");
+    };
+    // The kind is the same for every row, so build the attribute once.
+    auto kindAttr = rewriter.getStringAttr(getKindStr(multiReductionOp.kind()));
+
+    for (int64_t i = 0; i < outerDim; ++i) {
+      auto v = rewriter.create<vector::ExtractOp>(
+          loc, multiReductionOp.source(), ArrayRef<int64_t>{i});
+      auto reducedValue = rewriter.create<vector::ReductionOp>(
+          loc, elementType, kindAttr, v, ValueRange{});
+      result = rewriter.create<vector::InsertElementOp>(loc, reducedValue,
+                                                        result, i);
+    }
+    rewriter.replaceOp(multiReductionOp, result);
+    return success();
+  }
+};
+
 void mlir::vector::populateVectorMaskMaterializationPatterns(
     RewritePatternSet &patterns, bool enableIndexOptimizations) {
   patterns.add<VectorCreateMaskOpConversion,
@@ -3645,3 +3837,9 @@ void mlir::vector::populateVectorTransferLoweringPatterns(
            TransferReadPermutationLowering, TransferOpReduceRank>(
           patterns.getContext());
 }
+
+void mlir::vector::populateVectorMultiReductionLoweringPatterns(
+    RewritePatternSet &patterns) {
+  // Lowering stages: move reduced dims innermost (transpose), collapse to the
+  // 2d canonical form (shape_cast), then unroll into vector.reduction ops.
+  MLIRContext *ctx = patterns.getContext();
+  patterns.add<InnerDimReductionConversion>(ctx);
+  patterns.add<ReduceMultiDimReductionRank>(ctx);
+  patterns.add<TwoDimMultiReductionToReduction>(ctx);
+}

diff  --git a/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
new file mode 100644
index 0000000000000..6cfc4e035719d
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
@@ -0,0 +1,66 @@
+// RUN: mlir-opt %s -test-vector-multi-reduction-lowering-patterns | FileCheck %s
+
+func @vector_multi_reduction(%arg0: vector<2x4xf32>) -> vector<2xf32> {
+    %0 = vector.multi_reduction #vector.kind<mul>, %arg0 [1] : vector<2x4xf32> to vector<2xf32>
+    return %0 : vector<2xf32>
+}
+// CHECK-LABEL: func @vector_multi_reduction
+//  CHECK-SAME:   %[[INPUT:.+]]: vector<2x4xf32>
+//       CHECK:       %[[RESULT_VEC_0:.+]] = constant dense<{{.*}}> : vector<2xf32>
+//       CHECK:       %[[C0:.+]] = constant 0 : i32
+//       CHECK:       %[[C1:.+]] = constant 1 : i32
+//       CHECK:       %[[V0:.+]] = vector.extract %[[INPUT]][0]
+//       CHECK:       %[[RV0:.+]] = vector.reduction "mul", %[[V0]] : vector<4xf32> into f32
+//       CHECK:       %[[RESULT_VEC_1:.+]] = vector.insertelement %[[RV0]], %[[RESULT_VEC_0]][%[[C0]] : i32] : vector<2xf32>
+//       CHECK:       %[[V1:.+]] = vector.extract %[[INPUT]][1]
+//       CHECK:       %[[RV1:.+]] = vector.reduction "mul", %[[V1]] : vector<4xf32> into f32
+//       CHECK:       %[[RESULT_VEC:.+]] = vector.insertelement %[[RV1]], %[[RESULT_VEC_1]][%[[C1]] : i32] : vector<2xf32>
+//       CHECK:       return %[[RESULT_VEC]]
+
+func @vector_reduction_inner(%arg0: vector<2x3x4x5xi32>) -> vector<2x3xi32> {
+    %0 = vector.multi_reduction #vector.kind<add>, %arg0 [2, 3] : vector<2x3x4x5xi32> to vector<2x3xi32>
+    return %0 : vector<2x3xi32>
+}
+// CHECK-LABEL: func @vector_reduction_inner
+//  CHECK-SAME:   %[[INPUT:.+]]: vector<2x3x4x5xi32>
+//       CHECK:       %[[FLAT_RESULT_VEC_0:.+]] = constant dense<0> : vector<6xi32>
+//   CHECK-DAG:       %[[C0:.+]] = constant 0 : i32
+//   CHECK-DAG:       %[[C1:.+]] = constant 1 : i32
+//   CHECK-DAG:       %[[C2:.+]] = constant 2 : i32
+//   CHECK-DAG:       %[[C3:.+]] = constant 3 : i32
+//   CHECK-DAG:       %[[C4:.+]] = constant 4 : i32
+//   CHECK-DAG:       %[[C5:.+]] = constant 5 : i32
+//       CHECK:       %[[RESHAPED_INPUT:.+]] = vector.shape_cast %[[INPUT]] : vector<2x3x4x5xi32> to vector<6x20xi32>
+//       CHECK:       %[[V0:.+]] = vector.extract %[[RESHAPED_INPUT]][0] : vector<6x20xi32>
+//       CHECK:       %[[V0R:.+]] = vector.reduction "add", %[[V0]] : vector<20xi32> into i32
+//       CHECK:       %[[FLAT_RESULT_VEC_1:.+]] = vector.insertelement %[[V0R]], %[[FLAT_RESULT_VEC_0]][%[[C0]] : i32] : vector<6xi32>
+//       CHECK:       %[[V1:.+]] = vector.extract %[[RESHAPED_INPUT]][1] : vector<6x20xi32>
+//       CHECK:       %[[V1R:.+]] = vector.reduction "add", %[[V1]] : vector<20xi32> into i32
+//       CHECK:       %[[FLAT_RESULT_VEC_2:.+]] = vector.insertelement %[[V1R]], %[[FLAT_RESULT_VEC_1]][%[[C1]] : i32] : vector<6xi32>
+//       CHECK:       %[[V2:.+]] = vector.extract %[[RESHAPED_INPUT]][2] : vector<6x20xi32>
+//       CHECK:       %[[V2R:.+]] = vector.reduction "add", %[[V2]] : vector<20xi32> into i32
+//       CHECK:       %[[FLAT_RESULT_VEC_3:.+]] = vector.insertelement %[[V2R]], %[[FLAT_RESULT_VEC_2]][%[[C2]] : i32] : vector<6xi32>
+//       CHECK:       %[[V3:.+]] = vector.extract %[[RESHAPED_INPUT]][3] : vector<6x20xi32>
+//       CHECK:       %[[V3R:.+]] = vector.reduction "add", %[[V3]] : vector<20xi32> into i32
+//       CHECK:       %[[FLAT_RESULT_VEC_4:.+]] = vector.insertelement %[[V3R]], %[[FLAT_RESULT_VEC_3]][%[[C3]] : i32] : vector<6xi32>
+//       CHECK:       %[[V4:.+]] = vector.extract %[[RESHAPED_INPUT]][4] : vector<6x20xi32>
+//       CHECK:       %[[V4R:.+]] = vector.reduction "add", %[[V4]] : vector<20xi32> into i32
+//       CHECK:       %[[FLAT_RESULT_VEC_5:.+]] = vector.insertelement %[[V4R]], %[[FLAT_RESULT_VEC_4]][%[[C4]] : i32] : vector<6xi32>
+//       CHECK:       %[[V5:.+]] = vector.extract %[[RESHAPED_INPUT]][5] : vector<6x20xi32>
+//       CHECK:       %[[V5R:.+]] = vector.reduction "add", %[[V5]] : vector<20xi32> into i32
+//       CHECK:       %[[FLAT_RESULT_VEC:.+]] = vector.insertelement %[[V5R]], %[[FLAT_RESULT_VEC_5]][%[[C5]] : i32] : vector<6xi32>
+//       CHECK:       %[[RESULT:.+]] = vector.shape_cast %[[FLAT_RESULT_VEC]] : vector<6xi32> to vector<2x3xi32>
+//       CHECK:       return %[[RESULT]]
+
+
+func @vector_multi_reduction_transposed(%arg0: vector<2x3x4x5xf32>) -> vector<2x5xf32> {
+    %0 = vector.multi_reduction #vector.kind<add>, %arg0 [1, 2] : vector<2x3x4x5xf32> to vector<2x5xf32>
+    return %0 : vector<2x5xf32>
+}
+
+// CHECK-LABEL: func @vector_multi_reduction_transposed
+//  CHECK-SAME:    %[[INPUT:.+]]: vector<2x3x4x5xf32>
+//       CHECK:     %[[TRANSPOSED_INPUT:.+]] = vector.transpose %[[INPUT]], [0, 3, 1, 2] : vector<2x3x4x5xf32> to vector<2x5x3x4xf32>
+//       CHECK:     vector.shape_cast %[[TRANSPOSED_INPUT]] : vector<2x5x3x4xf32> to vector<10x12xf32>
+//       CHECK:     %[[RESULT:.+]] = vector.shape_cast %{{.*}} : vector<10xf32> to vector<2x5xf32>
+//       CHECK:       return %[[RESULT]]

diff  --git a/mlir/test/lib/Transforms/TestVectorTransforms.cpp b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
index b125680efa70c..d78ce3db4bca5 100644
--- a/mlir/test/lib/Transforms/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
@@ -376,6 +376,19 @@ struct TestVectorTransferLoweringPatterns
   }
 };
 
+struct TestVectorMultiReductionLoweringPatterns
+    : public PassWrapper<TestVectorMultiReductionLoweringPatterns,
+                         FunctionPass> {
+  // Registers the dialects this pass declares as dependencies.
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<memref::MemRefDialect>();
+  }
+
+  // Greedily applies the vector.multi_reduction lowering patterns to the
+  // current function; the convergence result is intentionally discarded.
+  void runOnFunction() override {
+    auto *ctx = &getContext();
+    RewritePatternSet loweringPatterns(ctx);
+    populateVectorMultiReductionLoweringPatterns(loweringPatterns);
+    auto funcOp = getFunction();
+    (void)applyPatternsAndFoldGreedily(funcOp, std::move(loweringPatterns));
+  }
+};
+
 struct TestProgressiveVectorToSCFLoweringPatterns
     : public PassWrapper<TestProgressiveVectorToSCFLoweringPatterns,
                          FunctionPass> {
@@ -439,6 +452,12 @@ void registerTestVectorConversions() {
   PassRegistration<TestProgressiveVectorToSCFLoweringPatterns> transferOpToSCF(
       "test-progressive-convert-vector-to-scf",
       "Test conversion patterns to progressively lower transfer ops to SCF");
+
+  PassRegistration<TestVectorMultiReductionLoweringPatterns>
+      multiDimReductionOpLoweringPass(
+          "test-vector-multi-reduction-lowering-patterns",
+          "Test conversion patterns to lower vector.multi_reduction to other "
+          "vector ops");
 }
 } // namespace test
 } // namespace mlir


        


More information about the Mlir-commits mailing list