[Mlir-commits] [mlir] 3a3ab21 - [mlir][Vector] Add support for high-order masked contractions

Diego Caballero llvmlistbot at llvm.org
Tue Feb 21 23:00:10 PST 2023


Author: Diego Caballero
Date: 2023-02-22T06:54:02Z
New Revision: 3a3ab2147d7a6673c7b22bd14d8bf33ab1276d85

URL: https://github.com/llvm/llvm-project/commit/3a3ab2147d7a6673c7b22bd14d8bf33ab1276d85
DIFF: https://github.com/llvm/llvm-project/commit/3a3ab2147d7a6673c7b22bd14d8bf33ab1276d85.diff

LOG: [mlir][Vector] Add support for high-order masked contractions

This patch adds support for masked vector.contract ops that need to be
decomposed using the ContractionOpLowering pattern. It slices the mask in
the same way the rest of the lowering slices the operands.

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D144427

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
    mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
    mlir/test/Dialect/Vector/vector-contract-transforms.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index fb2d07ebf413d..775bfbf0241bf 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -510,12 +510,12 @@ class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
   vector::VectorTransformsOptions vectorTransformOptions;
   FilterConstraintType filter;
   // Lower one parallel dimension.
-  FailureOr<Value> lowerParallel(vector::ContractionOp op, int64_t lhsIndex,
-                                 int64_t rhsIndex,
-                                 PatternRewriter &rewriter) const;
+  FailureOr<Value> lowerParallel(PatternRewriter &rewriter,
+                                 vector::ContractionOp op, int64_t lhsIndex,
+                                 int64_t rhsIndex, Value mask) const;
   // Lower one reduction dimension.
-  FailureOr<Value> lowerReduction(vector::ContractionOp op,
-                                  PatternRewriter &rewriter) const;
+  FailureOr<Value> lowerReduction(PatternRewriter &rewriter,
+                                  vector::ContractionOp op, Value mask) const;
 };
 
 } // namespace vector

diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index eecf9701e3cab..0844fda09328d 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1904,11 +1904,6 @@ ContractionOpToDotLowering::matchAndRewrite(vector::ContractionOp op,
 LogicalResult
 ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
                                        PatternRewriter &rewriter) const {
-  // TODO: Support vector.mask.
-  auto maskableOp = cast<MaskableOpInterface>(op.getOperation());
-  if (maskableOp.isMasked())
-    return failure();
-
   // TODO: Remove native masks from contraction op?
   if (!op.getMasks().empty())
     return failure();
@@ -1944,15 +1939,25 @@ ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
   if (succeeded(pat4.matchAndRewrite(op, rewriter)))
     return success();
 
+  // Vector mask setup.
+  OpBuilder::InsertionGuard guard(rewriter);
+  Operation *rootOp = op;
+  Value mask;
+  if (op.isMasked()) {
+    rewriter.setInsertionPoint(op.getMaskingOp());
+    rootOp = op.getMaskingOp();
+    mask = op.getMaskingOp().getMask();
+  }
+
   // Find first batch dimension in LHS/RHS, and lower when found.
   std::vector<std::pair<int64_t, int64_t>> batchDimMap = op.getBatchDimMap();
   if (!batchDimMap.empty()) {
     int64_t lhsIndex = batchDimMap[0].first;
     int64_t rhsIndex = batchDimMap[0].second;
-    auto newOp = lowerParallel(op, lhsIndex, rhsIndex, rewriter);
+    auto newOp = lowerParallel(rewriter, op, lhsIndex, rhsIndex, mask);
     if (failed(newOp))
       return failure();
-    rewriter.replaceOp(op, *newOp);
+    rewriter.replaceOp(rootOp, *newOp);
     return success();
   }
 
@@ -1970,10 +1975,10 @@ ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
   VectorType lhsType = op.getLhsType();
   for (int64_t lhsIndex = 0, e = lhsType.getRank(); lhsIndex < e; ++lhsIndex) {
     if (lhsContractingDimSet.count(lhsIndex) == 0) {
-      auto newOp = lowerParallel(op, lhsIndex, /*rhsIndex=*/-1, rewriter);
+      auto newOp = lowerParallel(rewriter, op, lhsIndex, /*rhsIndex=*/-1, mask);
       if (failed(newOp))
         return failure();
-      rewriter.replaceOp(op, *newOp);
+      rewriter.replaceOp(rootOp, *newOp);
       return success();
     }
   }
@@ -1982,20 +1987,20 @@ ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
   VectorType rhsType = op.getRhsType();
   for (int64_t rhsIndex = 0, e = rhsType.getRank(); rhsIndex < e; ++rhsIndex) {
     if (rhsContractingDimSet.count(rhsIndex) == 0) {
-      auto newOp = lowerParallel(op, /*lhsIndex=*/-1, rhsIndex, rewriter);
+      auto newOp = lowerParallel(rewriter, op, /*lhsIndex=*/-1, rhsIndex, mask);
       if (failed(newOp))
         return failure();
-      rewriter.replaceOp(op, *newOp);
+      rewriter.replaceOp(rootOp, *newOp);
       return success();
     }
   }
 
   // Lower the first remaining reduction dimension.
   if (!contractingDimMap.empty()) {
-    auto newOp = lowerReduction(op, rewriter);
+    auto newOp = lowerReduction(rewriter, op, mask);
     if (failed(newOp))
       return failure();
-    rewriter.replaceOp(op, *newOp);
+    rewriter.replaceOp(rootOp, *newOp);
     return success();
   }
 
@@ -2005,10 +2010,11 @@ ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
 // Lower one parallel dimension.
 // Incidentally also tolerates unit-size (hence trivial) reduction dimensions.
 // TODO: consider reusing existing contract unrolling
-FailureOr<Value>
-ContractionOpLowering::lowerParallel(vector::ContractionOp op, int64_t lhsIndex,
-                                     int64_t rhsIndex,
-                                     PatternRewriter &rewriter) const {
+FailureOr<Value> ContractionOpLowering::lowerParallel(PatternRewriter &rewriter,
+                                                      vector::ContractionOp op,
+                                                      int64_t lhsIndex,
+                                                      int64_t rhsIndex,
+                                                      Value mask) const {
   VectorType lhsType = op.getLhsType();
   VectorType rhsType = op.getRhsType();
   VectorType resType = op.getResultType().cast<VectorType>();
@@ -2046,6 +2052,7 @@ ContractionOpLowering::lowerParallel(vector::ContractionOp op, int64_t lhsIndex,
       diag << "expected the dimension for iterIndex=" << iterIndex
            << " to either appear in the result map, or to be a unit dimension";
     });
+
   // Construct new iterator types and affine map array attribute.
   std::array<AffineMap, 3> lowIndexingMaps = {
       adjustMap(iMap[0], iterIndex, rewriter),
@@ -2058,22 +2065,29 @@ ContractionOpLowering::lowerParallel(vector::ContractionOp op, int64_t lhsIndex,
   Location loc = op.getLoc();
   Value result = rewriter.create<arith::ConstantOp>(
       loc, resType, rewriter.getZeroAttr(resType));
+
   for (int64_t d = 0; d < dimSize; ++d) {
     auto lhs = reshapeLoad(loc, op.getLhs(), lhsType, lhsIndex, d, rewriter);
     auto rhs = reshapeLoad(loc, op.getRhs(), rhsType, rhsIndex, d, rewriter);
     auto acc = reshapeLoad(loc, op.getAcc(), resType, resIndex, d, rewriter);
-    Value lowContract = rewriter.create<vector::ContractionOp>(
+
+    Value lowMask;
+    if (mask)
+      lowMask = reshapeLoad(loc, mask, cast<VectorType>(mask.getType()),
+                            iterIndex, d, rewriter);
+
+    Operation *lowContract = rewriter.create<vector::ContractionOp>(
         loc, lhs, rhs, acc, lowAffine, lowIter);
-    result =
-        reshapeStore(loc, lowContract, result, resType, resIndex, d, rewriter);
+    lowContract = maskOperation(rewriter, lowContract, lowMask);
+    result = reshapeStore(loc, lowContract->getResult(0), result, resType,
+                          resIndex, d, rewriter);
   }
   return result;
 }
 
 // Lower one reduction dimension.
-FailureOr<Value>
-ContractionOpLowering::lowerReduction(vector::ContractionOp op,
-                                      PatternRewriter &rewriter) const {
+FailureOr<Value> ContractionOpLowering::lowerReduction(
+    PatternRewriter &rewriter, vector::ContractionOp op, Value mask) const {
   auto loc = op.getLoc();
   VectorType lhsType = op.getLhsType();
   VectorType rhsType = op.getRhsType();
@@ -2110,10 +2124,12 @@ ContractionOpLowering::lowerReduction(vector::ContractionOp op,
           op, "When LHS has rank 1, expected also RHS to have rank 1");
     Value m = createMul(loc, op.getLhs(), op.getRhs(), isInt, rewriter);
     auto kind = vector::CombiningKind::ADD;
-    if (auto acc = op.getAcc())
-      return rewriter.create<vector::ReductionOp>(loc, kind, m, acc)
-          .getResult();
-    return rewriter.create<vector::ReductionOp>(loc, kind, m).getResult();
+
+    Value acc = op.getAcc();
+    Operation *reductionOp =
+        acc ? rewriter.create<vector::ReductionOp>(loc, kind, m, acc)
+            : rewriter.create<vector::ReductionOp>(loc, kind, m);
+    return maskOperation(rewriter, reductionOp, mask)->getResult(0);
   }
   // Construct new iterator types and affine map array attribute.
   std::array<AffineMap, 3> lowIndexingMaps = {
@@ -2131,8 +2147,14 @@ ContractionOpLowering::lowerReduction(vector::ContractionOp op,
   for (int64_t d = 0; d < dimSize; ++d) {
     auto lhs = reshapeLoad(loc, op.getLhs(), lhsType, lhsIndex, d, rewriter);
     auto rhs = reshapeLoad(loc, op.getRhs(), rhsType, rhsIndex, d, rewriter);
-    result = rewriter.create<vector::ContractionOp>(loc, lhs, rhs, result,
-                                                    lowAffine, lowIter);
+    Value newMask;
+    if (mask)
+      newMask = reshapeLoad(loc, mask, cast<VectorType>(mask.getType()),
+                            iterIndex, d, rewriter);
+
+    Operation *newContract = rewriter.create<vector::ContractionOp>(
+        loc, lhs, rhs, result, lowAffine, lowIter);
+    result = maskOperation(rewriter, newContract, newMask)->getResult(0);
   }
   return result;
 }

diff  --git a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
index 6ad8a096df20f..2cbd604759edc 100644
--- a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
@@ -28,6 +28,18 @@ func.func @extract_contract1(%arg0: vector<4xf32>, %arg1: vector<4xf32>, %arg2:
   return %0 : f32
 }
 
+// CHECK-LABEL: func @masked_extract_contract1
+//  CHECK-SAME:   %[[A:.*0]]: vector<4xf32>, %[[B:.*1]]: vector<4xf32>, %[[C:.*2]]: f32
+//  CHECK-SAME:   %[[M:.*]]: vector<4xi1>
+//       CHECK:   %[[F:.*]] = arith.mulf %[[A]], %[[B]] : vector<4xf32>
+//       CHECK:   %[[R:.*]] = vector.mask %[[M]] { vector.reduction <add>, %[[F]], %[[C]] : vector<4xf32> into f32 } : vector<4xi1> -> f32
+//       CHECK:   return %[[R]] : f32
+
+func.func @masked_extract_contract1(%arg0: vector<4xf32>, %arg1: vector<4xf32>, %arg2: f32, %mask: vector<4xi1>) -> f32 {
+  %0 = vector.mask %mask { vector.contract #dotp_trait %arg0, %arg1, %arg2 : vector<4xf32>, vector<4xf32> into f32 } : vector<4xi1> -> f32
+  return %0 : f32
+}
+
 // CHECK-LABEL: func @extract_contract1_int
 // CHECK-SAME: %[[A:.*0]]: vector<4xi32>,
 // CHECK-SAME: %[[B:.*1]]: vector<4xi32>,


        


More information about the Mlir-commits mailing list