[Mlir-commits] [mlir] 6870a50 - lowerParallel is also called on unit-size, one-sided reduction dims

Benoit Jacob llvmlistbot at llvm.org
Wed Jul 13 09:21:23 PDT 2022


Author: Benoit Jacob
Date: 2022-07-13T16:21:12Z
New Revision: 6870a50f43721d070436eed52b8c311f62818d7c

URL: https://github.com/llvm/llvm-project/commit/6870a50f43721d070436eed52b8c311f62818d7c
DIFF: https://github.com/llvm/llvm-project/commit/6870a50f43721d070436eed52b8c311f62818d7c.diff

LOG: lowerParallel is also called on unit-size, one-sided reduction dims

See: https://gist.github.com/bjacob/d8be8ec7e70ed0be4b3a5794ced2a7e8

Differential Revision: https://reviews.llvm.org/D129096

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
    mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
    mlir/test/Dialect/Vector/vector-contract-transforms.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index e215be49b74ef..ba4f6b3788c32 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -527,11 +527,12 @@ class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
   vector::VectorTransformsOptions vectorTransformOptions;
   FilterConstraintType filter;
   // Lower one parallel dimension.
-  Value lowerParallel(vector::ContractionOp op, int64_t lhsIndex,
-                      int64_t rhsIndex, PatternRewriter &rewriter) const;
+  FailureOr<Value> lowerParallel(vector::ContractionOp op, int64_t lhsIndex,
+                                 int64_t rhsIndex,
+                                 PatternRewriter &rewriter) const;
   // Lower one reduction dimension.
-  Value lowerReduction(vector::ContractionOp op,
-                       PatternRewriter &rewriter) const;
+  FailureOr<Value> lowerReduction(vector::ContractionOp op,
+                                  PatternRewriter &rewriter) const;
 };
 
 } // namespace vector

diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index a62f90693a5c9..97b603bcd3f5d 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1794,7 +1794,10 @@ ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
   if (!batchDimMap.empty()) {
     int64_t lhsIndex = batchDimMap[0].first;
     int64_t rhsIndex = batchDimMap[0].second;
-    rewriter.replaceOp(op, lowerParallel(op, lhsIndex, rhsIndex, rewriter));
+    auto newOp = lowerParallel(op, lhsIndex, rhsIndex, rewriter);
+    if (failed(newOp))
+      return failure();
+    rewriter.replaceOp(op, newOp.value());
     return success();
   }
 
@@ -1812,8 +1815,10 @@ ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
   VectorType lhsType = op.getLhsType();
   for (int64_t lhsIndex = 0, e = lhsType.getRank(); lhsIndex < e; ++lhsIndex) {
     if (lhsContractingDimSet.count(lhsIndex) == 0) {
-      rewriter.replaceOp(
-          op, lowerParallel(op, lhsIndex, /*rhsIndex=*/-1, rewriter));
+      auto newOp = lowerParallel(op, lhsIndex, /*rhsIndex=*/-1, rewriter);
+      if (failed(newOp))
+        return failure();
+      rewriter.replaceOp(op, newOp.value());
       return success();
     }
   }
@@ -1822,15 +1827,20 @@ ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
   VectorType rhsType = op.getRhsType();
   for (int64_t rhsIndex = 0, e = rhsType.getRank(); rhsIndex < e; ++rhsIndex) {
     if (rhsContractingDimSet.count(rhsIndex) == 0) {
-      rewriter.replaceOp(
-          op, lowerParallel(op, /*lhsIndex=*/-1, rhsIndex, rewriter));
+      auto newOp = lowerParallel(op, /*lhsIndex=*/-1, rhsIndex, rewriter);
+      if (failed(newOp))
+        return failure();
+      rewriter.replaceOp(op, newOp.value());
       return success();
     }
   }
 
   // Lower the first remaining reduction dimension.
   if (!contractingDimMap.empty()) {
-    rewriter.replaceOp(op, lowerReduction(op, rewriter));
+    auto newOp = lowerReduction(op, rewriter);
+    if (failed(newOp))
+      return failure();
+    rewriter.replaceOp(op, newOp.value());
     return success();
   }
 
@@ -1838,10 +1848,12 @@ ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
 }
 
 // Lower one parallel dimension.
+// Incidentally also tolerates unit-size (hence trivial) reduction dimensions.
 // TODO: consider reusing existing contract unrolling
-Value ContractionOpLowering::lowerParallel(vector::ContractionOp op,
-                                           int64_t lhsIndex, int64_t rhsIndex,
-                                           PatternRewriter &rewriter) const {
+FailureOr<Value>
+ContractionOpLowering::lowerParallel(vector::ContractionOp op, int64_t lhsIndex,
+                                     int64_t rhsIndex,
+                                     PatternRewriter &rewriter) const {
   VectorType lhsType = op.getLhsType();
   VectorType rhsType = op.getRhsType();
   VectorType resType = op.getResultType().cast<VectorType>();
@@ -1851,18 +1863,34 @@ Value ContractionOpLowering::lowerParallel(vector::ContractionOp op,
   int64_t dimSize = -1;
   if (lhsIndex >= 0) {
     iterIndex = iMap[0].getDimPosition(lhsIndex);
-    assert((rhsIndex < 0 || iterIndex == iMap[1].getDimPosition(rhsIndex)) &&
-           "parallel index should be free in LHS or batch in LHS/RHS");
+    if (rhsIndex >= 0 && iterIndex != iMap[1].getDimPosition(rhsIndex))
+      return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
+        diag << "expected lhsIndex=" << lhsIndex << " and rhsIndex=" << rhsIndex
+             << " to map to the same dimension";
+      });
     dimSize = lhsType.getDimSize(lhsIndex);
-  } else {
-    assert(rhsIndex >= 0 && "missing parallel index");
+  } else if (rhsIndex >= 0) {
     iterIndex = iMap[1].getDimPosition(rhsIndex);
     dimSize = rhsType.getDimSize(rhsIndex);
   }
-  assert(iterIndex >= 0 && "parallel index not listed in operand mapping");
-  Optional<int64_t> lookup = getResultIndex(iMap[2], iterIndex);
-  assert(lookup.has_value() && "parallel index not listed in reduction");
-  int64_t resIndex = lookup.getValue();
+  if (iterIndex < 0)
+    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
+      diag << "expected either lhsIndex=" << lhsIndex
+           << " or rhsIndex=" << rhsIndex << " to be nonnegative";
+    });
+  // getValueOr(-1) means that we tolerate a dimension not appearing
+  // in the result map. That can't happen for actual parallel iterators, but
+  // the caller ContractionOpLowering::matchAndRewrite is currently calling
+  // lowerParallel also for the case of unit-size reduction dims appearing only
+  // on one of LHS or RHS, not both. At the moment, such cases are created by
+  // CastAwayContractionLeadingOneDim, so we need to either support that or
+  // modify that pattern.
+  int64_t resIndex = getResultIndex(iMap[2], iterIndex).getValueOr(-1);
+  if (resIndex == -1 && dimSize != 1)
+    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
+      diag << "expected the dimension for iterIndex=" << iterIndex
+           << " to either appear in the result map, or to be a unit dimension";
+    });
   // Construct new iterator types and affine map array attribute.
   std::array<AffineMap, 3> lowIndexingMaps = {
       adjustMap(iMap[0], iterIndex, rewriter),
@@ -1888,33 +1916,49 @@ Value ContractionOpLowering::lowerParallel(vector::ContractionOp op,
 }
 
 // Lower one reduction dimension.
-Value ContractionOpLowering::lowerReduction(vector::ContractionOp op,
-                                            PatternRewriter &rewriter) const {
+FailureOr<Value>
+ContractionOpLowering::lowerReduction(vector::ContractionOp op,
+                                      PatternRewriter &rewriter) const {
   auto loc = op.getLoc();
   VectorType lhsType = op.getLhsType();
   VectorType rhsType = op.getRhsType();
   Type resType = op.getResultType();
-  assert(!resType.isa<VectorType>());
+  if (resType.isa<VectorType>())
+    return rewriter.notifyMatchFailure(op,
+                                       "did not expect a VectorType result");
   bool isInt = resType.isa<IntegerType>();
   // Use iterator index 0.
   int64_t iterIndex = 0;
   SmallVector<AffineMap, 4> iMap = op.getIndexingMaps();
   Optional<int64_t> lookupLhs = getResultIndex(iMap[0], iterIndex);
   Optional<int64_t> lookupRhs = getResultIndex(iMap[1], iterIndex);
-  assert(lookupLhs.has_value() && "missing LHS parallel index");
-  assert(lookupRhs.has_value() && "missing RHS parallel index");
+  if (!lookupLhs.hasValue())
+    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
+      diag << "expected iterIndex=" << iterIndex << "to map to a LHS dimension";
+    });
+  if (!lookupRhs.hasValue())
+    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
+      diag << "expected iterIndex=" << iterIndex << "to map to a RHS dimension";
+    });
   int64_t lhsIndex = lookupLhs.getValue();
   int64_t rhsIndex = lookupRhs.getValue();
   int64_t dimSize = lhsType.getDimSize(lhsIndex);
-  assert(dimSize == rhsType.getDimSize(rhsIndex) && "corrupt shape");
+  if (dimSize != rhsType.getDimSize(rhsIndex))
+    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
+      diag << "expect LHS dimension " << lhsIndex
+           << " to have the same size as RHS dimension " << rhsIndex;
+    });
   // Base case.
   if (lhsType.getRank() == 1) {
-    assert(rhsType.getRank() == 1 && "corrupt contraction");
+    if (rhsType.getRank() != 1)
+      return rewriter.notifyMatchFailure(
+          op, "When LHS has rank 1, expected also RHS to have rank 1");
     Value m = createMul(loc, op.getLhs(), op.getRhs(), isInt, rewriter);
     auto kind = vector::CombiningKind::ADD;
     if (auto acc = op.getAcc())
-      return rewriter.create<vector::ReductionOp>(loc, kind, m, acc);
-    return rewriter.create<vector::ReductionOp>(loc, kind, m);
+      return rewriter.create<vector::ReductionOp>(loc, kind, m, acc)
+          .getResult();
+    return rewriter.create<vector::ReductionOp>(loc, kind, m).getResult();
   }
   // Construct new iterator types and affine map array attribute.
   std::array<AffineMap, 3> lowIndexingMaps = {

diff  --git a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
index 4123ef3b75135..72bfdd6e580b2 100644
--- a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
@@ -858,6 +858,34 @@ func.func @genbool_var_3d(%arg0: index, %arg1: index, %arg2: index) -> vector<2x
   return %0 : vector<2x1x7xi1>
 }
 
+// CHECK-LABEL: @contract_one_sided_unit_reduction_dim
+// CHECK-SAME: (%[[A0:.+]]: vector<1x2xi32>, %[[A1:.+]]: vector<2x2xi32>, %[[A2:.+]]: vector<2xi32>)
+// CHECK-DAG: %[[C:.+]] = arith.constant dense<0> : vector<2xi32>
+// CHECK-DAG: %[[E00:.+]] = vector.extract %[[A0]][0] : vector<1x2xi32>
+// CHECK-DAG: %[[E10:.+]] = vector.extract %[[A1]][0] : vector<2x2xi32>
+// CHECK:     %[[M0:.+]] = arith.muli %[[E10]], %[[E00]] : vector<2xi32>
+// CHECK:     %[[R0:.+]] = vector.reduction <add>, %[[M0]] : vector<2xi32> into i32
+// CHECK:     %[[I0:.+]] = vector.insert %[[R0]], %[[C]] [0] : i32 into vector<2xi32>
+// CHECK:     %[[E11:.+]] = vector.extract %[[A1]][1] : vector<2x2xi32>
+// CHECK:     %[[M1:.+]] = arith.muli %[[E11]], %[[E00]] : vector<2xi32>
+// CHECK:     %[[R1:.+]] = vector.reduction <add>, %[[M1]] : vector<2xi32> into i32
+// CHECK:     %[[I1:.+]] = vector.insert %[[R1]], %[[I0]] [1] : i32 into vector<2xi32>
+// CHECK:     %[[S:.+]] = arith.addi %[[I1]], %[[A2]] : vector<2xi32>
+// CHECK:     return %[[S]] : vector<2xi32>
+
+func.func @contract_one_sided_unit_reduction_dim(%arg0 : vector<1x2xi32>, %arg1 : vector<2x2xi32>, %arg2 : vector<2xi32>) -> vector<2xi32> {
+  %res = vector.contract {
+    indexing_maps = [
+      affine_map<(d0, d1, d2) -> (d0, d2)>,
+      affine_map<(d0, d1, d2) -> (d1, d2)>,
+      affine_map<(d0, d1, d2) -> (d1)>
+    ],
+    iterator_types = ["reduction", "parallel", "reduction"],
+    kind = #vector.kind<add>
+  } %arg0, %arg1, %arg2 : vector<1x2xi32>, vector<2x2xi32>, vector<2xi32> into vector<2xi32>
+  return %res : vector<2xi32>
+}
+
 #matmat_accesses_0 = [
   affine_map<(m, n, k) -> (m, k)>,
   affine_map<(m, n, k) -> (k, n)>,


        


More information about the Mlir-commits mailing list