[Mlir-commits] [mlir] 89aaa2d - [mlir][vector] Add new lowering mode to vector.contractionOp

Thomas Raoux llvmlistbot at llvm.org
Tue May 24 07:23:50 PDT 2022


Author: Thomas Raoux
Date: 2022-05-24T14:19:08Z
New Revision: 89aaa2d033270d6eeeb82429c0bb88a78ae030fa

URL: https://github.com/llvm/llvm-project/commit/89aaa2d033270d6eeeb82429c0bb88a78ae030fa
DIFF: https://github.com/llvm/llvm-project/commit/89aaa2d033270d6eeeb82429c0bb88a78ae030fa.diff

LOG: [mlir][vector] Add new lowering mode to vector.contractionOp

Add a lowering for contractions whose reduction dimensions are fully
unrolled (every reduction dimension has size 1). Since it is common to
unroll the reduction dimension, such contractions should lower directly
to elementwise vector ops.

Differential Revision: https://reviews.llvm.org/D126120

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
    mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
    mlir/test/Dialect/Vector/vector-contract-transforms.mlir
    mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index 0522ef58bc812..e7226f4a6ac0f 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -49,6 +49,9 @@ enum class VectorContractLowering {
   Matmul = 1,
   /// Lower to `vector.outerproduct`.
   OuterProduct = 2,
+  /// Lower contractions whose reduction dimensions are all unrolled to size 1
+  /// into vector elementwise operations.
+  ParallelArith = 3,
 };
 /// Enum to control the splitting of `vector.transfer` operations into
 /// in-bounds and out-of-bounds variants.

diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index f84fc2da08f23..e1f4cffb93552 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -144,6 +144,59 @@ static SmallVector<IntType, 4> extractVector(ArrayAttr arrayAttr) {
       [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
 }
 
+/// Emit the multiply(-accumulate) for a single contraction step of `kind`.
+/// Returns an empty Optional when `kind` does not apply to the element type.
+static Optional<Value> createContractArithOp(Location loc, Value x, Value y,
+                                             Value acc,
+                                             vector::CombiningKind kind,
+                                             PatternRewriter &rewriter,
+                                             bool isInt) {
+  using vector::CombiningKind;
+  Value product;
+  if (!isInt) {
+    // Integer-only kinds cannot combine floating-point products.
+    if (kind == CombiningKind::AND || kind == CombiningKind::MINUI ||
+        kind == CombiningKind::MINSI || kind == CombiningKind::MAXUI ||
+        kind == CombiningKind::MAXSI || kind == CombiningKind::OR ||
+        kind == CombiningKind::XOR)
+      return llvm::None;
+    // A vector multiply followed by an add becomes one fused multiply-add.
+    if (kind == CombiningKind::ADD && acc && acc.getType().isa<VectorType>())
+      return Optional<Value>(rewriter.create<vector::FMAOp>(loc, x, y, acc));
+    product = rewriter.create<arith::MulFOp>(loc, x, y);
+  } else {
+    // Float-only kinds cannot combine integer products.
+    if (kind == CombiningKind::MINF || kind == CombiningKind::MAXF)
+      return llvm::None;
+    product = rewriter.create<arith::MulIOp>(loc, x, y);
+  }
+  // Fold the accumulator in with `kind`; without one, the product is final.
+  if (acc)
+    return makeArithReduction(rewriter, loc, kind, product, acc);
+  return Optional<Value>(product);
+}
+
+/// Collect, in increasing order, the result positions of `map` whose loop
+/// dimension is marked as a reduction in `iteratorTypes`.
+static SmallVector<int64_t> getReductionIndex(AffineMap map,
+                                              ArrayAttr iteratorTypes) {
+  SmallVector<int64_t> positions;
+  for (unsigned pos = 0; pos < map.getNumResults(); ++pos)
+    if (isReductionIterator(iteratorTypes[map.getDimPosition(pos)]))
+      positions.push_back(pos);
+  return positions;
+}
+
+/// Return the index of the first result of `map` that is the loop dimension
+/// `dim`, or llvm::None when no result refers to `dim`.
+static llvm::Optional<unsigned> getDimPosition(AffineMap map, unsigned dim) {
+  unsigned numResults = map.getNumResults();
+  for (unsigned pos = 0; pos != numResults; ++pos)
+    if (map.getDimPosition(pos) == dim)
+      return pos;
+  return llvm::None;
+}
+
 namespace {
 
 /// ShapeCastOpFolder folds cancelling ShapeCastOps away.
@@ -498,9 +551,8 @@ class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
     if (!rhsType) {
       // Special case: AXPY operation.
       Value b = rewriter.create<vector::BroadcastOp>(loc, lhsType, op.getRhs());
-      Optional<Value> mult =
-          isInt ? genMultI(loc, op.getLhs(), b, acc, kind, rewriter)
-                : genMultF(loc, op.getLhs(), b, acc, kind, rewriter);
+      Optional<Value> mult = createContractArithOp(loc, op.getLhs(), b, acc,
+                                                   kind, rewriter, isInt);
       if (!mult.hasValue())
         return failure();
       rewriter.replaceOp(op, mult.getValue());
@@ -518,8 +570,7 @@ class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
       if (acc)
         r = rewriter.create<vector::ExtractOp>(loc, rhsType, acc, pos);
       Optional<Value> m =
-          isInt ? genMultI(loc, a, op.getRhs(), r, kind, rewriter)
-                : genMultF(loc, a, op.getRhs(), r, kind, rewriter);
+          createContractArithOp(loc, a, op.getRhs(), r, kind, rewriter, isInt);
       if (!m.hasValue())
         return failure();
       result = rewriter.create<vector::InsertOp>(loc, resType, m.getValue(),
@@ -528,48 +579,159 @@ class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
     rewriter.replaceOp(op, result);
     return success();
   }
+};
 
-private:
-  static Optional<Value> genMultI(Location loc, Value x, Value y, Value acc,
-                                  vector::CombiningKind kind,
-                                  PatternRewriter &rewriter) {
-    using vector::CombiningKind;
-
-    auto mul = rewriter.create<arith::MulIOp>(loc, x, y);
-    if (!acc)
-      return Optional<Value>(mul);
-
-    if (kind == CombiningKind::MINF || kind == CombiningKind::MAXF)
-      // Only valid for floating point types.
-      return Optional<Value>();
-
-    return makeArithReduction(rewriter, loc, kind, mul, acc);
+/// Lower vector.contract with all size one reduction dimensions to
+/// elementwise ops when possible.
+///
+/// Each operand is broadcast along the parallel dimensions it lacks and
+/// transposed so that its reduction dimensions come first, followed by its
+/// parallel dimensions in accumulator order. Extracting index 0 along each
+/// unit reduction dimension then yields values shaped like the accumulator,
+/// which are multiplied and combined into it with the contraction's kind.
+struct ContractOpToElementwise
+    : public OpRewritePattern<vector::ContractionOp> {
+  using OpRewritePattern::OpRewritePattern;
+  using FilterConstraintType =
+      std::function<LogicalResult(vector::ContractionOp op)>;
+  static LogicalResult defaultFilter(vector::ContractionOp op) {
+    return success();
   }
+  ContractOpToElementwise(
+      vector::VectorTransformsOptions vectorTransformOptions,
+      MLIRContext *context,
+      const FilterConstraintType &constraint = defaultFilter)
+      : OpRewritePattern<vector::ContractionOp>(context),
+        vectorTransformOptions(vectorTransformOptions), filter(constraint) {}
 
-  static Optional<Value> genMultF(Location loc, Value x, Value y, Value acc,
-                                  vector::CombiningKind kind,
-                                  PatternRewriter &rewriter) {
-    using vector::CombiningKind;
-
-    // Special case for fused multiply-add.
-    if (acc && kind == CombiningKind::ADD) {
-      return Optional<Value>(rewriter.create<vector::FMAOp>(loc, x, y, acc));
-    }
-
-    auto mul = rewriter.create<arith::MulFOp>(loc, x, y);
-
-    if (!acc)
-      return Optional<Value>(mul);
+  /// Succeeds only when the lowering mode is ParallelArith, the op has no
+  /// masks, the filter accepts it, the combining kind is valid for the
+  /// element type, and every reduction dimension of both operands has size 1.
+  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
+                                PatternRewriter &rewriter) const override {
+    // TODO: implement masks
+    if (llvm::size(contractOp.getMasks()) != 0)
+      return failure();
 
-    if (kind == CombiningKind::ADD || kind == CombiningKind::AND ||
-        kind == CombiningKind::MINUI || kind == CombiningKind::MINSI ||
-        kind == CombiningKind::MAXUI || kind == CombiningKind::MAXSI ||
-        kind == CombiningKind::OR || kind == CombiningKind::XOR)
-      // Already handled or only valid for integer types.
-      return Optional<Value>();
+    if (failed(filter(contractOp)))
+      return failure();
 
-    return makeArithReduction(rewriter, loc, kind, mul, acc);
+    if (vectorTransformOptions.vectorContractLowering !=
+        vector::VectorContractLowering::ParallelArith)
+      return failure();
+    // Reject kinds createContractArithOp cannot lower before creating any IR,
+    // so that a failed match leaves the function untouched.
+    bool isInt = contractOp.getLhsType().getElementType().isIntOrIndex();
+    if (!isSupportedKind(contractOp.getKind(), isInt))
+      return failure();
+    ArrayRef<int64_t> lhsShape = contractOp.getLhsType().getShape();
+    ArrayRef<int64_t> rhsShape = contractOp.getRhsType().getShape();
+    AffineMap lhsMap = contractOp.getIndexingMaps()[0];
+    AffineMap rhsMap = contractOp.getIndexingMaps()[1];
+    SmallVector<int64_t> lhsReductionDims =
+        getReductionIndex(lhsMap, contractOp.getIteratorTypes());
+    SmallVector<int64_t> rhsReductionDims =
+        getReductionIndex(rhsMap, contractOp.getIteratorTypes());
+    // All the reduction dimensions must be a size 1.
+    for (int64_t dim : lhsReductionDims) {
+      if (lhsShape[dim] != 1)
+        return failure();
+    }
+    for (int64_t dim : rhsReductionDims) {
+      if (rhsShape[dim] != 1)
+        return failure();
+    }
+    AffineMap accMap = contractOp.getIndexingMaps()[2];
+    unsigned numParallelDims = accMap.getNumResults();
+    // Number of accumulator (parallel) dimensions each operand lacks; these
+    // are prepended to the operand by a broadcast.
+    unsigned numLhsDimToBroadcast =
+        numParallelDims - (lhsMap.getNumResults() - lhsReductionDims.size());
+    unsigned numRhsDimToBroadcast =
+        numParallelDims - (rhsMap.getNumResults() - rhsReductionDims.size());
+    SmallVector<int64_t> lhsDims;
+    SmallVector<int64_t> lhsTranspose;
+    SmallVector<int64_t> rhsDims;
+    SmallVector<int64_t> rhsTranspose;
+    // Reduction dimensions are permuted to the front so the extract below
+    // drops them.
+    for (int64_t dim : lhsReductionDims)
+      lhsTranspose.push_back(numLhsDimToBroadcast + dim);
+    for (int64_t dim : rhsReductionDims)
+      rhsTranspose.push_back(numRhsDimToBroadcast + dim);
+    // Loop through the parallel dimensions to calculate the dimensions to
+    // broadcast and to permute in order to extract only parallel dimensions.
+    for (unsigned i = 0; i < numParallelDims; i++) {
+      // The accumulator has parallel dimensions, so the result is a vector.
+      int64_t parallelDimSize =
+          contractOp.getResultType().cast<VectorType>().getDimSize(i);
+      llvm::Optional<unsigned> lhsDim =
+          getDimPosition(lhsMap, accMap.getDimPosition(i));
+      if (lhsDim) {
+        lhsTranspose.push_back(numLhsDimToBroadcast + *lhsDim);
+      } else {
+        // If the parallel dimension doesn't exist we will have to broadcast it.
+        lhsDims.push_back(parallelDimSize);
+        lhsTranspose.push_back(lhsDims.size() - 1);
+      }
+      llvm::Optional<unsigned> rhsDim =
+          getDimPosition(rhsMap, accMap.getDimPosition(i));
+      if (rhsDim) {
+        rhsTranspose.push_back(numRhsDimToBroadcast + *rhsDim);
+      } else {
+        // If the parallel dimension doesn't exist we will have to broadcast it.
+        rhsDims.push_back(parallelDimSize);
+        rhsTranspose.push_back(rhsDims.size() - 1);
+      }
+    }
+    Value newLhs = contractOp.getLhs();
+    Value newRhs = contractOp.getRhs();
+    Location loc = contractOp.getLoc();
+    if (!lhsDims.empty()) {
+      lhsDims.append(lhsShape.begin(), lhsShape.end());
+      auto expandedType =
+          VectorType::get(lhsDims, contractOp.getLhsType().getElementType());
+      newLhs = rewriter.create<vector::BroadcastOp>(loc, expandedType, newLhs);
+    }
+    if (!rhsDims.empty()) {
+      rhsDims.append(rhsShape.begin(), rhsShape.end());
+      auto expandedType =
+          VectorType::get(rhsDims, contractOp.getRhsType().getElementType());
+      newRhs = rewriter.create<vector::BroadcastOp>(loc, expandedType, newRhs);
+    }
+    newLhs = rewriter.create<vector::TransposeOp>(loc, newLhs, lhsTranspose);
+    newRhs = rewriter.create<vector::TransposeOp>(loc, newRhs, rhsTranspose);
+    SmallVector<int64_t, 4> lhsOffsets(lhsReductionDims.size(), 0);
+    SmallVector<int64_t, 4> rhsOffsets(rhsReductionDims.size(), 0);
+    newLhs = rewriter.create<vector::ExtractOp>(
+        loc, newLhs, rewriter.getI64ArrayAttr(lhsOffsets));
+    newRhs = rewriter.create<vector::ExtractOp>(
+        loc, newRhs, rewriter.getI64ArrayAttr(rhsOffsets));
+    Optional<Value> result =
+        createContractArithOp(loc, newLhs, newRhs, contractOp.getAcc(),
+                              contractOp.getKind(), rewriter, isInt);
+    assert(result.hasValue() && "combining kind was validated above");
+    rewriter.replaceOp(contractOp, {*result});
+    return success();
   }
+
+private:
+  /// Mirror of the type checks in createContractArithOp: return true if
+  /// `kind` can be lowered for an integer (`isInt`) or floating-point type.
+  static bool isSupportedKind(vector::CombiningKind kind, bool isInt) {
+    using vector::CombiningKind;
+    if (isInt)
+      return kind != CombiningKind::MINF && kind != CombiningKind::MAXF;
+    return kind != CombiningKind::AND && kind != CombiningKind::MINUI &&
+           kind != CombiningKind::MINSI && kind != CombiningKind::MAXUI &&
+           kind != CombiningKind::MAXSI && kind != CombiningKind::OR &&
+           kind != CombiningKind::XOR;
+  }
+
+  /// Options to control the vector patterns.
+  vector::VectorTransformsOptions vectorTransformOptions;
+  /// Predicate that must succeed for the pattern to apply.
+  FilterConstraintType filter;
 };
 
 /// Progressive lowering of ConstantMaskOp.
@@ -1594,6 +1724,9 @@ ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
   ContractionOpToDotLowering pat3(vectorTransformOptions, ctx);
   if (succeeded(pat3.matchAndRewrite(op, rewriter)))
     return success();
+  ContractOpToElementwise pat4(vectorTransformOptions, ctx);
+  if (succeeded(pat4.matchAndRewrite(op, rewriter)))
+    return success();
 
   // Find first batch dimension in LHS/RHS, and lower when found.
   std::vector<std::pair<int64_t, int64_t>> batchDimMap = op.getBatchDimMap();

diff  --git a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
index 989dee42f8a1c..e725d1883e9bd 100644
--- a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
@@ -2,6 +2,7 @@
 // RUN: mlir-opt %s -test-vector-contraction-lowering=vector-lower-matrix-intrinsics=1 | FileCheck %s --check-prefix=MATRIX
 // RUN: mlir-opt %s -test-vector-contraction-lowering=vector-outerproduct=1 | FileCheck %s --check-prefix=OUTERPRODUCT
 // RUN: mlir-opt %s -test-vector-contraction-lowering=vector-filter-outerproduct=1 | FileCheck %s --check-prefix=FILTEROUTERPRODUCT
+// RUN: mlir-opt %s -test-vector-contraction-lowering=vector-parallel-arith=1 | FileCheck %s --check-prefix=PARALLEL
 
 #dotp_accesses = [
   affine_map<(i) -> (i)>,
@@ -1104,3 +1105,54 @@ func.func @matmul_4_not_filtered(%arg0: vector<3x4xf32>, %arg1: vector<4x4xf32>,
     : vector<3x4xf32>, vector<4x4xf32> into vector<3x4xf32>
   return %0 : vector<3x4xf32>
 }
+
+// PARALLEL-LABEL: func @parrallel_contract_lowering
+//       PARALLEL:   %[[LHS:.*]] = vector.extract %{{.*}}[0, 0] : vector<1x1x4xf32>
+//       PARALLEL:   %[[RHS:.*]] = vector.extract %{{.*}}[0, 0] : vector<1x1x4xf32>
+//       PARALLEL:   %[[FMA:.*]] = vector.fma %[[LHS]], %[[RHS]], %{{.*}} : vector<4xf32>
+//       PARALLEL:   return %[[FMA]] : vector<4xf32>
+func.func @parrallel_contract_lowering(%arg0: vector<1x1x4xf32>, %arg1: vector<1x1x4xf32>, %arg2: vector<4xf32>) -> vector<4xf32> {
+  %0 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2, d0)>, affine_map<(d0, d1, d2) -> (d1, d2, d0)>, affine_map<(d0, d1, d2) -> (d0)>], iterator_types = ["parallel", "reduction", "reduction"], kind = #vector.kind<add>} %arg0, %arg1, %arg2 : vector<1x1x4xf32>, vector<1x1x4xf32> into vector<4xf32>
+  return %0 : vector<4xf32>
+}
+
+// PARALLEL-LABEL: func @parrallel_contract_lowering_broadcast
+//       PARALLEL:   %[[BCAST:.*]] = vector.broadcast %{{.*}} : vector<1x1xf32> to vector<4x1x1xf32>
+//       PARALLEL:   %[[TRANS:.*]] = vector.transpose %[[BCAST]], [1, 2, 0] : vector<4x1x1xf32> to vector<1x1x4xf32>
+//       PARALLEL:   %[[LHS:.*]] = vector.extract %[[TRANS]][0, 0] : vector<1x1x4xf32>
+//       PARALLEL:   %[[RHS:.*]] = vector.extract %{{.*}}[0, 0] : vector<1x1x4xf32>
+//       PARALLEL:   %[[FMA:.*]] = vector.fma %[[LHS]], %[[RHS]], %{{.*}} : vector<4xf32>
+//       PARALLEL:   return %[[FMA]] : vector<4xf32>
+func.func @parrallel_contract_lowering_broadcast(%arg0: vector<1x1xf32>, %arg1: vector<1x1x4xf32>, %arg2: vector<4xf32>) -> vector<4xf32> {
+  %0 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d1, d2, d0)>, affine_map<(d0, d1, d2) -> (d0)>], iterator_types = ["parallel", "reduction", "reduction"], kind = #vector.kind<add>} %arg0, %arg1, %arg2 : vector<1x1xf32>, vector<1x1x4xf32> into vector<4xf32>
+  return %0 : vector<4xf32>
+}
+
+// PARALLEL-LABEL: func @parrallel_contract_lowering_transpose
+//       PARALLEL:   %[[B:.*]] = vector.broadcast %{{.*}} : vector<1x1xf32> to vector<4x1x1xf32>
+//       PARALLEL:   %[[T0:.*]] = vector.transpose %[[B]], [1, 2, 0] : vector<4x1x1xf32> to vector<1x1x4xf32>
+//       PARALLEL:   %[[T1:.*]] = vector.transpose %{{.*}}, [0, 2, 1] : vector<1x4x1xf32> to vector<1x1x4xf32>
+//       PARALLEL:   %[[E0:.*]] = vector.extract %[[T0]][0, 0] : vector<1x1x4xf32>
+//       PARALLEL:   %[[E1:.*]] = vector.extract %[[T1]][0, 0] : vector<1x1x4xf32>
+//       PARALLEL:   %[[F:.*]] = vector.fma %[[E0]], %[[E1]], %arg2 : vector<4xf32>
+//       PARALLEL:   return %[[F]] : vector<4xf32>
+func.func @parrallel_contract_lowering_transpose(%arg0: vector<1x1xf32>, %arg1: vector<1x4x1xf32>, %arg2: vector<4xf32>) -> vector<4xf32> {
+  %0 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d1, d0, d2)>, affine_map<(d0, d1, d2) -> (d0)>], iterator_types = ["parallel", "reduction", "reduction"], kind = #vector.kind<add>} %arg0, %arg1, %arg2 : vector<1x1xf32>, vector<1x4x1xf32> into vector<4xf32>
+  return %0 : vector<4xf32>
+}
+
+// PARALLEL-LABEL: func @parrallel_contract_lowering_scalar
+//       PARALLEL:   %[[LHS:.*]] = vector.extract %{{.*}}[0, 0] : vector<1x1xf32>
+//       PARALLEL:   %[[RHS:.*]] = vector.extract %{{.*}}[0, 0] : vector<1x1xf32>
+//       PARALLEL:   %[[MUL:.*]] = arith.mulf %[[LHS]], %[[RHS]] : f32
+//       PARALLEL:   %[[ADD:.*]] = arith.addf %[[MUL]], %{{.*}} : f32
+//       PARALLEL:   return %[[ADD]] : f32
+func.func @parrallel_contract_lowering_scalar(%arg0: vector<1x1xf32>, %arg1: vector<1x1xf32>, %arg2: f32) -> f32 {
+  %0 = vector.contract {
+    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                     affine_map<(d0, d1) -> (d0, d1)>,
+                     affine_map<(d0, d1) -> ()>],
+    iterator_types = ["reduction", "reduction"], kind = #vector.kind<add>}
+  %arg0, %arg1, %arg2 : vector<1x1xf32>, vector<1x1xf32> into f32
+  return %0 : f32
+}

diff  --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index f2514fc932727..a81aa536df4ad 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -135,6 +135,10 @@ struct TestVectorContractionLowering
       llvm::cl::desc("Lower vector.contract to vector.outerproduct but not for "
                      "vectors of size 4."),
       llvm::cl::init(false)};
+  Option<bool> lowerToParallelArith{*this, "vector-parallel-arith",
+                                    llvm::cl::desc("Lower vector.contract to "
+                                                   "elementwise vector ops."),
+                                    llvm::cl::init(false)};
 
   void runOnOperation() override {
     RewritePatternSet patterns(&getContext());
@@ -165,6 +169,15 @@ struct TestVectorContractionLowering
       return;
     }
 
+    if (lowerToParallelArith) {
+      vector::populateVectorContractLoweringPatterns(
+          patterns,
+          vector::VectorTransformsOptions().setVectorTransformsOptions(
+              vector::VectorContractLowering::ParallelArith));
+      (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+      return;
+    }
+
     // Test on all contract lowering patterns.
     VectorContractLowering contractLowering = VectorContractLowering::Dot;
     if (lowerToFlatMatrix)


        


More information about the Mlir-commits mailing list