[Mlir-commits] [mlir] f69175b - [mlir][vector] Add unrolling pattern for multidim_reduce op

Thomas Raoux llvmlistbot at llvm.org
Mon Mar 14 08:22:50 PDT 2022


Author: Thomas Raoux
Date: 2022-03-14T15:22:24Z
New Revision: f69175b1e6aba63ad349644256c58c0e3b3316f1

URL: https://github.com/llvm/llvm-project/commit/f69175b1e6aba63ad349644256c58c0e3b3316f1
DIFF: https://github.com/llvm/llvm-project/commit/f69175b1e6aba63ad349644256c58c0e3b3316f1.diff

LOG: [mlir][vector] Add unrolling pattern for multidim_reduce op

Implement the VectorUnrollOpInterface for MultiDimReductionOp and add a
pattern to do the unrolling, following the same interface as the other
vector unroll patterns.

Differential Revision: https://reviews.llvm.org/D121263

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp
    mlir/test/Dialect/Vector/vector-unroll-options.mlir
    mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 2e7f06903824f..b0012924e5bae 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -314,7 +314,9 @@ def Vector_MultiDimReductionOp :
   Vector_Op<"multi_reduction", [NoSideEffect,
      PredOpTrait<"source operand and result have same element type",
                  TCresVTEtIsSameAsOpBase<0, 0>>,
-     DeclareOpInterfaceMethods<InferTypeOpInterface>]>,
+     DeclareOpInterfaceMethods<InferTypeOpInterface>,
+     DeclareOpInterfaceMethods<VectorUnrollOpInterface,
+                               ["getShapeForUnroll"]>]>,
     Arguments<(ins Vector_CombiningKindAttr:$kind,
                    AnyVector:$source,
                    I64ArrayAttr:$reduction_dims)>,

diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index f6547e46d5418..61b5c7aac9f6f 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -371,6 +371,10 @@ OpFoldResult MultiDimReductionOp::fold(ArrayRef<Attribute> operands) {
   return {};
 }
 
+Optional<SmallVector<int64_t, 4>> MultiDimReductionOp::getShapeForUnroll() {
+  return llvm::to_vector<4>(getSourceVectorType().getShape());
+} // Unroll over the full source shape, reduced dims included.
+
 //===----------------------------------------------------------------------===//
 // ReductionOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp
index 1da965e3fafc9..7ec69f006dba5 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp
@@ -207,23 +207,23 @@ struct UnrollTransferWritePattern
   vector::UnrollVectorOptions options;
 };
 
-struct UnrollContractionPattern
-    : public OpRewritePattern<vector::ContractionOp> {
-  struct OffsetMapInfo {
-    static SmallVector<int64_t> getEmptyKey() { return {int64_t(-1)}; }
+struct OffsetMapInfo { // DenseMapInfo traits for SmallVector<int64_t> keys.
+  static SmallVector<int64_t> getEmptyKey() { return {int64_t(-1)}; }
 
-    static SmallVector<int64_t> getTombstoneKey() { return {int64_t(-2)}; }
+  static SmallVector<int64_t> getTombstoneKey() { return {int64_t(-2)}; }
 
-    static unsigned getHashValue(const SmallVector<int64_t> &v) {
-      return static_cast<unsigned>(
-          llvm::hash_combine_range(v.begin(), v.end()));
-    }
+  static unsigned getHashValue(const SmallVector<int64_t> &v) {
+    return static_cast<unsigned>(llvm::hash_combine_range(v.begin(), v.end()));
+  }
 
-    static bool isEqual(const SmallVector<int64_t> &lhs,
-                        const SmallVector<int64_t> &rhs) {
-      return lhs == rhs;
-    }
-  };
+  static bool isEqual(const SmallVector<int64_t> &lhs,
+                      const SmallVector<int64_t> &rhs) {
+    return lhs == rhs;
+  }
+};
+
+struct UnrollContractionPattern
+    : public OpRewritePattern<vector::ContractionOp> {
   UnrollContractionPattern(MLIRContext *context,
                            const vector::UnrollVectorOptions &options)
       : OpRewritePattern<vector::ContractionOp>(context, /*benefit=*/1),
@@ -320,6 +320,74 @@ struct UnrollContractionPattern
   vector::UnrollVectorOptions options;
 };
 
+struct UnrollMultiReductionPattern
+    : public OpRewritePattern<vector::MultiDimReductionOp> {
+  UnrollMultiReductionPattern(MLIRContext *context,
+                              const vector::UnrollVectorOptions &options)
+      : OpRewritePattern<vector::MultiDimReductionOp>(context, /*benefit=*/1),
+        options(options) {}
+
+  LogicalResult matchAndRewrite(vector::MultiDimReductionOp reductionOp,
+                                PatternRewriter &rewriter) const override {
+    Optional<SmallVector<int64_t, 4>> targetShape =
+        getTargetShape(options, reductionOp);
+    if (!targetShape)
+      return failure();
+    SmallVector<int64_t, 4> originalSize = *reductionOp.getShapeForUnroll();
+    SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape);
+    llvm::MapVector<
+        SmallVector<int64_t>, Value,
+        llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
+        accCache;
+    // Number of target-shaped slices needed to tile the whole source vector.
+    int64_t sliceCount = computeMaxLinearIndex(ratio);
+    Location loc = reductionOp.getLoc();
+    for (int64_t i = 0; i < sliceCount; i++) {
+      SmallVector<int64_t, 4> offsets =
+          getVectorOffset(originalSize, *targetShape, i);
+      // Extract the i-th target-shaped slice of the source vector.
+      SmallVector<int64_t, 4> operandStrides(offsets.size(), 1);
+      Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
+          loc, reductionOp.getOperand(), offsets, *targetShape, operandStrides);
+      // The partial result keeps only the non-reduced dimensions.
+      SmallVector<int64_t> dstShape;
+      SmallVector<int64_t> destOffset;
+      for (size_t dim : llvm::seq(size_t(0), targetShape->size())) {
+        if (!reductionOp.isReducedDim(dim)) {
+          destOffset.push_back(offsets[dim]);
+          dstShape.push_back((*targetShape)[dim]);
+        }
+      }
+      auto targetType = VectorType::get(
+          dstShape, reductionOp.getSourceVectorType().getElementType());
+      Operation *newOp = cloneOpWithOperandsAndTypes(rewriter, loc, reductionOp,
+                                                     slicedOperand, targetType);
+      Value result = newOp->getResult(0);
+      // Slices that share the same non-reduced offsets reduce into the same
+      // accumulator; combine with any partial result already cached for it.
+      auto accIt = accCache.find(destOffset);
+      if (accIt != accCache.end())
+        result = makeArithReduction(rewriter, loc, reductionOp.kind(), result,
+                                    accIt->second);
+      accCache[destOffset] = result;
+    }
+    // Assemble back the accumulator into a single vector.
+    Value result = rewriter.create<arith::ConstantOp>(
+        loc, reductionOp.getDestType(),
+        rewriter.getZeroAttr(reductionOp.getDestType()));
+    for (const auto &it : accCache) {
+      SmallVector<int64_t> dstStrides(it.first.size(), 1);
+      result = rewriter.create<vector::InsertStridedSliceOp>(
+          loc, it.second, result, it.first, dstStrides);
+    }
+    rewriter.replaceOp(reductionOp, result);
+    return success();
+  }
+
+private:
+  vector::UnrollVectorOptions options;
+};
+
 struct UnrollElementwisePattern : public RewritePattern {
   UnrollElementwisePattern(MLIRContext *context,
                            const vector::UnrollVectorOptions &options)
@@ -568,8 +636,8 @@ struct TransferWriteInsertPattern
 void mlir::vector::populateVectorUnrollPatterns(
     RewritePatternSet &patterns, const UnrollVectorOptions &options) {
   patterns.add<UnrollTransferReadPattern, UnrollTransferWritePattern,
-               UnrollContractionPattern, UnrollElementwisePattern>(
-      patterns.getContext(), options);
+               UnrollContractionPattern, UnrollMultiReductionPattern,
+               UnrollElementwisePattern>(patterns.getContext(), options);
 }
 
 void mlir::vector::populatePropagateVectorDistributionPatterns(

diff  --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
index 581039c48cb73..dd1a6fd781e47 100644
--- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir
+++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
@@ -80,3 +80,29 @@ func @vector_fma(%a: vector<4x4xf32>, %b: vector<4x4xf32>, %c: vector<4x4xf32>)
 }
 //   CHECK-LABEL: func @vector_fma
 // CHECK-COUNT-4: vector.fma %{{.+}}, %{{.+}}, %{{.+}} : vector<2x2xf32>
+
+func @vector_multi_reduction(%v : vector<4x6xf32>) -> vector<4xf32> {
+  %0 = vector.multi_reduction #vector.kind<add>, %v [1] : vector<4x6xf32> to vector<4xf32>
+  return %0 : vector<4xf32>
+}
+// CHECK-LABEL: func @vector_multi_reduction
+//       CHECK:   %[[V0:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
+//       CHECK:   %[[E0:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
+//       CHECK:   %[[R0:.*]] = vector.multi_reduction <add>, %[[E0]] [1] : vector<2x2xf32> to vector<2xf32>
+//       CHECK:   %[[E1:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
+//       CHECK:   %[[R1:.*]] = vector.multi_reduction <add>, %[[E1]] [1] : vector<2x2xf32> to vector<2xf32>
+//       CHECK:   %[[A0:.*]] = arith.addf %[[R1]], %[[R0]] : vector<2xf32>
+//       CHECK:   %[[E2:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 4], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
+//       CHECK:   %[[R2:.*]] = vector.multi_reduction <add>, %[[E2]] [1] : vector<2x2xf32> to vector<2xf32>
+//       CHECK:   %[[A1:.*]] = arith.addf %[[R2]], %[[A0]] : vector<2xf32>
+//       CHECK:   %[[E3:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
+//       CHECK:   %[[R3:.*]] = vector.multi_reduction <add>, %[[E3]] [1] : vector<2x2xf32> to vector<2xf32>
+//       CHECK:   %[[E4:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
+//       CHECK:   %[[R4:.*]] = vector.multi_reduction <add>, %[[E4]] [1] : vector<2x2xf32> to vector<2xf32>
+//       CHECK:   %[[A2:.*]] = arith.addf %[[R4]], %[[R3]] : vector<2xf32>
+//       CHECK:   %[[E5:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 4], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
+//       CHECK:   %[[R5:.*]] = vector.multi_reduction <add>, %[[E5]] [1] : vector<2x2xf32> to vector<2xf32>
+//       CHECK:   %[[A3:.*]] = arith.addf %[[R5]], %[[A2]] : vector<2xf32>
+//       CHECK:   %[[V1:.*]] = vector.insert_strided_slice %[[A1]], %[[V0]] {offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
+//       CHECK:   %[[V2:.*]] = vector.insert_strided_slice %[[A3]], %[[V1]] {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
+//       CHECK:   return %[[V2]] : vector<4xf32>

diff  --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 67e33d3aa0b4a..2bf5e3f1a8e7d 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -265,7 +265,8 @@ struct TestVectorUnrollingPatterns
         patterns, UnrollVectorOptions()
                       .setNativeShape(ArrayRef<int64_t>{2, 2})
                       .setFilterConstraint([](Operation *op) {
-                        return success(isa<arith::AddFOp, vector::FMAOp>(op));
+                        return success(isa<arith::AddFOp, vector::FMAOp,
+                                           vector::MultiDimReductionOp>(op));
                       }));
 
     if (unrollBasedOnType) {


        


More information about the Mlir-commits mailing list