[Mlir-commits] [mlir] 1d8cc45 - [mlir][vector] Add patterns to convert multidimreduce to vector.contract

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Oct 21 14:04:08 PDT 2021


Author: thomasraoux
Date: 2021-10-21T14:03:32-07:00
New Revision: 1d8cc45b0e4e316dbc4d4724964f1115fdb3f909

URL: https://github.com/llvm/llvm-project/commit/1d8cc45b0e4e316dbc4d4724964f1115fdb3f909
DIFF: https://github.com/llvm/llvm-project/commit/1d8cc45b0e4e316dbc4d4724964f1115fdb3f909.diff

LOG: [mlir][vector] Add patterns to convert multidimreduce to vector.contract

Add several patterns that will simplify contraction vectorization in the
future. With those canonicalizations we will be able to remove the special
case for contraction during vectorization and rely on those transformations to
avoid materializing broadcast ops.

Differential Revision: https://reviews.llvm.org/D112121

Added: 
    mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir

Modified: 
    mlir/include/mlir/Dialect/Vector/VectorOps.h
    mlir/lib/Dialect/Vector/VectorTransforms.cpp
    mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/VectorOps.h b/mlir/include/mlir/Dialect/Vector/VectorOps.h
index 694875f5e143e..5a5eadf6ad7d1 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.h
@@ -215,6 +215,10 @@ void populateVectorTransposeLoweringPatterns(
     RewritePatternSet &patterns,
     VectorTransformsOptions options = VectorTransformsOptions());
 
+/// Collect patterns to convert a multi_reduction<add> of a multiply into
+/// vector.contract and to fold transpose/broadcast operands into the contract.
+void populateVetorReductionToContractPatterns(RewritePatternSet &patterns);
+
 /// Returns the integer type required for subscripts in the vector dialect.
 IntegerType getVectorSubscriptType(Builder &builder);
 

diff  --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 51ef008a2fa21..75ee23b5ffb3b 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -240,6 +240,13 @@ sliceTransferIndices(int64_t index, ArrayRef<int64_t> originalShape,
   return slicedIndices;
 }
 
+template <typename IntType>
+static SmallVector<IntType, 4> extractVector(ArrayAttr arrayAttr) {
+  auto elements = arrayAttr.getAsRange<IntegerAttr>(); // Typed element view.
+  auto toInt = [](IntegerAttr a) { return static_cast<IntType>(a.getInt()); };
+  return llvm::to_vector<4>(llvm::map_range(elements, toInt));
+}
+
 namespace {
 
 struct UnrollTransferReadPattern
@@ -1114,6 +1121,193 @@ class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
   }
 };
 
+/// Fold arith.muli/mulf feeding an `add` multi_reduction into a ContractionOp.
+/// Ex:
+/// ```
+///   %0 = arith.mulf %arg0, %arg1 : vector<8x32x16xf32>
+///   %1 = vector.multi_reduction #vector.kind<add>, %0 [1]
+///     : vector<8x32x16xf32> to vector<8x16xf32>
+/// ```
+/// Gets converted to (with %cst_f0 a zero constant of the result type):
+/// ```
+///   %1 = vector.contract {indexing_maps = [
+///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+///         affine_map<(d0, d1, d2) -> (d0, d2)>],
+///    iterator_types = ["parallel", "reduction", "parallel"],
+///    kind = #vector.kind<add>} %arg0, %arg1, %cst_f0
+///    : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x16xf32>
+///  ```
+struct MultiReduceToContract
+    : public OpRewritePattern<vector::MultiDimReductionOp> {
+  using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::MultiDimReductionOp reduceOp,
+                                PatternRewriter &rewriter) const override {
+    if (reduceOp.kind() != vector::CombiningKind::ADD)
+      return failure();
+    Operation *mulOp = reduceOp.source().getDefiningOp();
+    if (!mulOp || !isa<arith::MulIOp, arith::MulFOp>(mulOp))
+      return failure();
+    SmallVector<bool> reductionMask = reduceOp.getReductionMask();
+    unsigned rank = reductionMask.size();
+    // Parallel dims, in order, form the accumulator/result indexing map.
+    SmallVector<AffineExpr> parallelDims;
+    SmallVector<StringRef> iteratorTypes;
+    for (unsigned dim = 0; dim < rank; ++dim) {
+      bool isReduced = reductionMask[dim];
+      iteratorTypes.push_back(isReduced ? getReductionIteratorTypeName()
+                                        : getParallelIteratorTypeName());
+      if (!isReduced)
+        parallelDims.push_back(rewriter.getAffineDimExpr(dim));
+    }
+    AffineMap srcMap = rewriter.getMultiDimIdentityMap(rank);
+    auto dstMap = AffineMap::get(rank, 0, parallelDims, reduceOp.getContext());
+    auto dstType = reduceOp.getDestType();
+    Value zero = rewriter.create<arith::ConstantOp>(
+        reduceOp.getLoc(), dstType, rewriter.getZeroAttr(dstType));
+    ArrayAttr maps = rewriter.getAffineMapArrayAttr({srcMap, srcMap, dstMap});
+    rewriter.replaceOpWithNewOp<mlir::vector::ContractionOp>(
+        reduceOp, mulOp->getOperand(0), mulOp->getOperand(1), zero, maps,
+        rewriter.getStrArrayAttr(iteratorTypes));
+    return success();
+  }
+};
+
+/// Fold a TransposeOp on lhs/rhs into the user ContractionOp's indexing map.
+/// Ex:
+/// ```
+///   %0 = vector.transpose %arg0, [2, 0, 1]
+///     : vector<32x16x8xf32> to vector<8x32x16xf32>
+///   %1 = vector.contract {indexing_maps = [
+///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+///         affine_map<(d0, d1, d2) -> (d0, d1)>],
+///    iterator_types = ["parallel", "parallel", "reduction"],
+///    kind = #vector.kind<add>} %0, %arg1, %cst_f0
+///    : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
+/// ```
+/// Gets converted to (only `add` contracts are rewritten):
+/// ```
+///   %1 = vector.contract {indexing_maps = [
+///         affine_map<(d0, d1, d2) -> (d1, d2, d0)>,
+///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+///         affine_map<(d0, d1, d2) -> (d0, d1)>],
+///    iterator_types = ["parallel", "parallel", "reduction"],
+///    kind = #vector.kind<add>} %arg0, %arg1, %cst_f0
+///    : vector<32x16x8xf32>, vector<8x32x16xf32> into vector<8x32xf32>
+///  ```
+struct CombineContractTranspose
+    : public OpRewritePattern<vector::ContractionOp> {
+  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
+                                PatternRewriter &rewriter) const override {
+    // The replacement below is built without a `kind`, so only `add` is safe.
+    if (contractOp.kind() != vector::CombiningKind::ADD)
+      return failure();
+    auto maps = llvm::to_vector<4>(contractOp.getIndexingMaps());
+    Value lhs = contractOp.lhs();
+    Value rhs = contractOp.rhs();
+    size_t index = 0;
+    bool changed = false;
+    for (Value *operand : {&lhs, &rhs}) {
+      AffineMap &map = maps[index++];
+      auto transposeOp = operand->getDefiningOp<vector::TransposeOp>();
+      if (!transposeOp)
+        continue;
+      AffineMap permutationMap = AffineMap::getPermutationMap(
+          extractVector<unsigned>(transposeOp.transp()),
+          contractOp.getContext());
+      map = inversePermutation(permutationMap).compose(map);
+      *operand = transposeOp.vector();
+      changed = true;
+    }
+    if (!changed)
+      return failure();
+    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
+        contractOp, lhs, rhs, contractOp.acc(),
+        rewriter.getAffineMapArrayAttr(maps), contractOp.iterator_types());
+    return success();
+  }
+};
+
+/// Fold a rank-extending BroadcastOp on lhs/rhs into the user ContractionOp.
+/// Ex:
+/// ```
+///   %0 = vector.broadcast %arg0 : vector<32x16xf32> to vector<8x32x16xf32>
+///   %1 = vector.contract {indexing_maps = [
+///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+///         affine_map<(d0, d1, d2) -> (d0, d1)>],
+///    iterator_types = ["parallel", "parallel", "reduction"],
+///    kind = #vector.kind<add>} %0, %arg1, %cst_f0
+///    : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
+/// ```
+/// Gets converted to (only `add` contracts are rewritten):
+/// ```
+///   %1 = vector.contract {indexing_maps = [
+///         affine_map<(d0, d1, d2) -> (d1, d2)>,
+///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+///         affine_map<(d0, d1, d2) -> (d0, d1)>],
+///    iterator_types = ["parallel", "parallel", "reduction"],
+///    kind = #vector.kind<add>} %arg0, %arg1, %cst_f0
+///    : vector<32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
+///  ```
+struct CombineContractBroadcast
+    : public OpRewritePattern<vector::ContractionOp> {
+  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
+                                PatternRewriter &rewriter) const override {
+    // The replacement below is built without a `kind`, so only `add` is safe.
+    if (contractOp.kind() != vector::CombiningKind::ADD)
+      return failure();
+    auto maps = llvm::to_vector<4>(contractOp.getIndexingMaps());
+    Value lhs = contractOp.lhs();
+    Value rhs = contractOp.rhs();
+    size_t index = 0;
+    bool changed = false;
+    for (Value *operand : {&lhs, &rhs}) {
+      AffineMap &map = maps[index++];
+      auto broadcast = operand->getDefiningOp<vector::BroadcastOp>();
+      if (!broadcast)
+        continue;
+      // Skip scalar sources (contract needs vectors) and same-rank broadcasts.
+      auto srcType = broadcast.getSourceType().dyn_cast<VectorType>();
+      VectorType dstType = broadcast.getVectorType();
+      if (!srcType || srcType.getRank() == dstType.getRank())
+        continue;
+      int64_t rankDiff = dstType.getRank() - srcType.getRank();
+      bool innerDimBroadcast = false;
+      SmallVector<AffineExpr> originalDims;
+      for (auto dim : llvm::enumerate(srcType.getShape())) {
+        if (dim.value() != dstType.getDimSize(rankDiff + dim.index())) {
+          innerDimBroadcast = true;
+          break;
+        }
+        originalDims.push_back(
+            rewriter.getAffineDimExpr(dim.index() + rankDiff));
+      }
+      // Contract doesn't support inner dimension broadcast. Once this is
+      // relaxed we can remove this case.
+      if (innerDimBroadcast)
+        continue;
+      AffineMap broadcastMap = AffineMap::get(
+          dstType.getRank(), 0, originalDims, contractOp.getContext());
+      map = broadcastMap.compose(map);
+      *operand = broadcast.source();
+      changed = true;
+    }
+    if (!changed)
+      return failure();
+    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
+        contractOp, lhs, rhs, contractOp.acc(),
+        rewriter.getAffineMapArrayAttr(maps), contractOp.iterator_types());
+    return success();
+  }
+};
+
 } // namespace
 
 /// Creates an AddIOp if `isInt` is true otherwise create an arith::AddFOp using
@@ -3668,6 +3862,12 @@ void mlir::vector::populateVectorTransposeLoweringPatterns(
   patterns.add<TransposeOpLowering>(options, patterns.getContext());
 }
 
+void mlir::vector::populateVetorReductionToContractPatterns(
+    RewritePatternSet &patterns) { // Appends the three rewrites listed below.
+  patterns.add<MultiReduceToContract, CombineContractBroadcast,
+               CombineContractTranspose>(patterns.getContext());
+}
+
 void mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(
     RewritePatternSet &patterns) {
   patterns.add<TransferReadPermutationLowering,

diff  --git a/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir b/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir
new file mode 100644
index 0000000000000..ccab0617e241e
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir
@@ -0,0 +1,87 @@
+// RUN: mlir-opt %s -test-vector-reduction-to-contract-patterns -split-input-file | FileCheck %s
+
+// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+
+// CHECK-LABEL: multidimreduction_contract
+//  CHECK-NEXT:   %[[C0:.+]] = arith.constant dense<0.000000e+00> : vector<8x16xf32>
+//  CHECK-NEXT:   %[[R:.+]] = vector.contract {indexing_maps = [#[[$map0]], #[[$map0]], #[[$map1]]],
+//  CHECK-SAME:   iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind<add>}
+//  CHECK-SAME:   %{{.*}}, %{{.*}}, %[[C0]] : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x16xf32>
+//  CHECK-NEXT:   return %[[R]] : vector<8x16xf32>
+func @multidimreduction_contract(
+  %arg0: vector<8x32x16xf32>,%arg1: vector<8x32x16xf32>) -> vector<8x16xf32> {
+  %0 = arith.mulf %arg0, %arg1 : vector<8x32x16xf32>
+  %1 = vector.multi_reduction #vector.kind<add>, %0 [1] : vector<8x32x16xf32> to vector<8x16xf32>
+  return %1 : vector<8x16xf32>
+}
+
+// -----
+
+// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+
+// CHECK-LABEL: multidimreduction_contract_int
+//  CHECK-NEXT:   %[[C0:.+]] = arith.constant dense<0> : vector<8x16xi32>
+//  CHECK-NEXT:   %[[R:.+]] = vector.contract {indexing_maps = [#[[$map0]], #[[$map0]], #[[$map1]]],
+//  CHECK-SAME:   iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind<add>}
+//  CHECK-SAME:   %{{.*}}, %{{.*}}, %[[C0]] : vector<8x32x16xi32>, vector<8x32x16xi32> into vector<8x16xi32>
+//  CHECK-NEXT:   return %[[R]] : vector<8x16xi32>
+func @multidimreduction_contract_int(
+  %arg0: vector<8x32x16xi32>,%arg1: vector<8x32x16xi32>) -> vector<8x16xi32> {
+  %0 = arith.muli %arg0, %arg1 : vector<8x32x16xi32>
+  %1 = vector.multi_reduction #vector.kind<add>, %0 [1] : vector<8x32x16xi32> to vector<8x16xi32>
+  return %1 : vector<8x16xi32>
+}
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
+// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: contract_transpose
+//  CHECK-SAME: (%[[ARG0:.+]]: vector<32x16x8xf32>,
+//  CHECK-NEXT:   %[[C0:.+]] = arith.constant dense<0.000000e+00> : vector<8x32xf32>
+//  CHECK-NEXT:   %[[R:.+]] = vector.contract {indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]],
+//  CHECK-SAME:   iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+//  CHECK-SAME:   %[[ARG0]], %{{.*}}, %[[C0]] : vector<32x16x8xf32>, vector<8x32x16xf32> into vector<8x32xf32>
+//  CHECK-NEXT:   return %[[R]] : vector<8x32xf32>
+func @contract_transpose(
+  %arg0: vector<32x16x8xf32>, %arg1: vector<8x32x16xf32>) -> vector<8x32xf32> {
+  %cst = arith.constant dense<0.000000e+00> : vector<8x32xf32>
+  %0 = vector.transpose %arg0, [2, 0, 1] : vector<32x16x8xf32> to vector<8x32x16xf32>
+  %1 = vector.contract {indexing_maps = [#map0, #map0, #map1],
+    iterator_types = ["parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>} %0, %arg1, %cst : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
+  return %1 : vector<8x32xf32>
+}
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: contract_broadcast
+//  CHECK-SAME: (%[[ARG0:.+]]: vector<32x16xf32>,
+//  CHECK-NEXT:   %[[C0:.+]] = arith.constant dense<0.000000e+00> : vector<8x32xf32>
+//  CHECK-NEXT:   %[[R:.+]] = vector.contract {indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]],
+//  CHECK-SAME:   iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+//  CHECK-SAME:   %[[ARG0]], %{{.*}}, %[[C0]] : vector<32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
+//  CHECK-NEXT:   return %[[R]] : vector<8x32xf32>
+func @contract_broadcast(
+  %arg0: vector<32x16xf32>, %arg1: vector<8x32x16xf32>) -> vector<8x32xf32> {
+  %cst = arith.constant dense<0.000000e+00> : vector<8x32xf32>
+  %0 = vector.broadcast %arg0 : vector<32x16xf32> to vector<8x32x16xf32>
+  %1 = vector.contract {indexing_maps = [#map0, #map0, #map1],
+    iterator_types = ["parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>} %0, %arg1, %cst : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
+  return %1 : vector<8x32xf32>
+}

diff  --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 06c64d4fc42b9..d5538b09dadda 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -493,6 +493,23 @@ struct TestVectorTransferCollapseInnerMostContiguousDims
   }
 };
 
+struct TestVectorReduceToContractPatternsPatterns
+    : public PassWrapper<TestVectorReduceToContractPatternsPatterns,
+                         FunctionPass> { // Test-only function pass.
+  StringRef getArgument() const final {
+    return "test-vector-reduction-to-contract-patterns";
+  }
+  StringRef getDescription() const final {
+    return "Test patterns to convert multireduce op to contract and combine "
+           "broadcast/transpose to contract";
+  }
+  void runOnFunction() override { // Result of the greedy driver is ignored.
+    RewritePatternSet patterns(&getContext());
+    populateVetorReductionToContractPatterns(patterns);
+    (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
+  }
+};
+
 } // end anonymous namespace
 
 namespace mlir {
@@ -519,6 +536,8 @@ void registerTestVectorConversions() {
   PassRegistration<TestVectorMultiReductionLoweringPatterns>();
 
   PassRegistration<TestVectorTransferCollapseInnerMostContiguousDims>();
+
+  PassRegistration<TestVectorReduceToContractPatternsPatterns>();
 }
 } // namespace test
 } // namespace mlir


        


More information about the Mlir-commits mailing list