[Mlir-commits] [mlir] 6870a50 - lowerParallel is also called on unit-size, one-sided reduction dims
Benoit Jacob
llvmlistbot at llvm.org
Wed Jul 13 09:21:23 PDT 2022
Author: Benoit Jacob
Date: 2022-07-13T16:21:12Z
New Revision: 6870a50f43721d070436eed52b8c311f62818d7c
URL: https://github.com/llvm/llvm-project/commit/6870a50f43721d070436eed52b8c311f62818d7c
DIFF: https://github.com/llvm/llvm-project/commit/6870a50f43721d070436eed52b8c311f62818d7c.diff
LOG: lowerParallel is also called on unit-size, one-sided reduction dims
See: https://gist.github.com/bjacob/d8be8ec7e70ed0be4b3a5794ced2a7e8
Differential Revision: https://reviews.llvm.org/D129096
Added:
Modified:
mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
mlir/test/Dialect/Vector/vector-contract-transforms.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index e215be49b74ef..ba4f6b3788c32 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -527,11 +527,12 @@ class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
vector::VectorTransformsOptions vectorTransformOptions;
FilterConstraintType filter;
// Lower one parallel dimension.
- Value lowerParallel(vector::ContractionOp op, int64_t lhsIndex,
- int64_t rhsIndex, PatternRewriter &rewriter) const;
+ FailureOr<Value> lowerParallel(vector::ContractionOp op, int64_t lhsIndex,
+ int64_t rhsIndex,
+ PatternRewriter &rewriter) const;
// Lower one reduction dimension.
- Value lowerReduction(vector::ContractionOp op,
- PatternRewriter &rewriter) const;
+ FailureOr<Value> lowerReduction(vector::ContractionOp op,
+ PatternRewriter &rewriter) const;
};
} // namespace vector
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index a62f90693a5c9..97b603bcd3f5d 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1794,7 +1794,10 @@ ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
if (!batchDimMap.empty()) {
int64_t lhsIndex = batchDimMap[0].first;
int64_t rhsIndex = batchDimMap[0].second;
- rewriter.replaceOp(op, lowerParallel(op, lhsIndex, rhsIndex, rewriter));
+ auto newOp = lowerParallel(op, lhsIndex, rhsIndex, rewriter);
+ if (failed(newOp))
+ return failure();
+ rewriter.replaceOp(op, newOp.value());
return success();
}
@@ -1812,8 +1815,10 @@ ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
VectorType lhsType = op.getLhsType();
for (int64_t lhsIndex = 0, e = lhsType.getRank(); lhsIndex < e; ++lhsIndex) {
if (lhsContractingDimSet.count(lhsIndex) == 0) {
- rewriter.replaceOp(
- op, lowerParallel(op, lhsIndex, /*rhsIndex=*/-1, rewriter));
+ auto newOp = lowerParallel(op, lhsIndex, /*rhsIndex=*/-1, rewriter);
+ if (failed(newOp))
+ return failure();
+ rewriter.replaceOp(op, newOp.value());
return success();
}
}
@@ -1822,15 +1827,20 @@ ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
VectorType rhsType = op.getRhsType();
for (int64_t rhsIndex = 0, e = rhsType.getRank(); rhsIndex < e; ++rhsIndex) {
if (rhsContractingDimSet.count(rhsIndex) == 0) {
- rewriter.replaceOp(
- op, lowerParallel(op, /*lhsIndex=*/-1, rhsIndex, rewriter));
+ auto newOp = lowerParallel(op, /*lhsIndex=*/-1, rhsIndex, rewriter);
+ if (failed(newOp))
+ return failure();
+ rewriter.replaceOp(op, newOp.value());
return success();
}
}
// Lower the first remaining reduction dimension.
if (!contractingDimMap.empty()) {
- rewriter.replaceOp(op, lowerReduction(op, rewriter));
+ auto newOp = lowerReduction(op, rewriter);
+ if (failed(newOp))
+ return failure();
+ rewriter.replaceOp(op, newOp.value());
return success();
}
@@ -1838,10 +1848,12 @@ ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
}
// Lower one parallel dimension.
+// Incidentally also tolerates unit-size (hence trivial) reduction dimensions.
// TODO: consider reusing existing contract unrolling
-Value ContractionOpLowering::lowerParallel(vector::ContractionOp op,
- int64_t lhsIndex, int64_t rhsIndex,
- PatternRewriter &rewriter) const {
+FailureOr<Value>
+ContractionOpLowering::lowerParallel(vector::ContractionOp op, int64_t lhsIndex,
+ int64_t rhsIndex,
+ PatternRewriter &rewriter) const {
VectorType lhsType = op.getLhsType();
VectorType rhsType = op.getRhsType();
VectorType resType = op.getResultType().cast<VectorType>();
@@ -1851,18 +1863,34 @@ Value ContractionOpLowering::lowerParallel(vector::ContractionOp op,
int64_t dimSize = -1;
if (lhsIndex >= 0) {
iterIndex = iMap[0].getDimPosition(lhsIndex);
- assert((rhsIndex < 0 || iterIndex == iMap[1].getDimPosition(rhsIndex)) &&
- "parallel index should be free in LHS or batch in LHS/RHS");
+ if (rhsIndex >= 0 && iterIndex != iMap[1].getDimPosition(rhsIndex))
+ return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
+ diag << "expected lhsIndex=" << lhsIndex << " and rhsIndex=" << rhsIndex
+ << " to map to the same dimension";
+ });
dimSize = lhsType.getDimSize(lhsIndex);
- } else {
- assert(rhsIndex >= 0 && "missing parallel index");
+ } else if (rhsIndex >= 0) {
iterIndex = iMap[1].getDimPosition(rhsIndex);
dimSize = rhsType.getDimSize(rhsIndex);
}
- assert(iterIndex >= 0 && "parallel index not listed in operand mapping");
- Optional<int64_t> lookup = getResultIndex(iMap[2], iterIndex);
- assert(lookup.has_value() && "parallel index not listed in reduction");
- int64_t resIndex = lookup.getValue();
+ if (iterIndex < 0)
+ return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
+ diag << "expected either lhsIndex=" << lhsIndex
+ << " or rhsIndex=" << rhsIndex << " to be nonnegative";
+ });
+ // getValueOr(-1) means that we tolerate a dimension not appearing
+ // in the result map. That can't happen for actual parallel iterators, but
+ // the caller ContractionOpLowering::matchAndRewrite is currently calling
+ // lowerParallel also for the case of unit-size reduction dims appearing only
+ // on one of LHS or RHS, not both. At the moment, such cases are created by
+ // CastAwayContractionLeadingOneDim, so we need to either support that or
+ // modify that pattern.
+ int64_t resIndex = getResultIndex(iMap[2], iterIndex).getValueOr(-1);
+ if (resIndex == -1 && dimSize != 1)
+ return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
+ diag << "expected the dimension for iterIndex=" << iterIndex
+ << " to either appear in the result map, or to be a unit dimension";
+ });
// Construct new iterator types and affine map array attribute.
std::array<AffineMap, 3> lowIndexingMaps = {
adjustMap(iMap[0], iterIndex, rewriter),
@@ -1888,33 +1916,49 @@ Value ContractionOpLowering::lowerParallel(vector::ContractionOp op,
}
// Lower one reduction dimension.
-Value ContractionOpLowering::lowerReduction(vector::ContractionOp op,
- PatternRewriter &rewriter) const {
+FailureOr<Value>
+ContractionOpLowering::lowerReduction(vector::ContractionOp op,
+ PatternRewriter &rewriter) const {
auto loc = op.getLoc();
VectorType lhsType = op.getLhsType();
VectorType rhsType = op.getRhsType();
Type resType = op.getResultType();
- assert(!resType.isa<VectorType>());
+ if (resType.isa<VectorType>())
+ return rewriter.notifyMatchFailure(op,
+ "did not expect a VectorType result");
bool isInt = resType.isa<IntegerType>();
// Use iterator index 0.
int64_t iterIndex = 0;
SmallVector<AffineMap, 4> iMap = op.getIndexingMaps();
Optional<int64_t> lookupLhs = getResultIndex(iMap[0], iterIndex);
Optional<int64_t> lookupRhs = getResultIndex(iMap[1], iterIndex);
- assert(lookupLhs.has_value() && "missing LHS parallel index");
- assert(lookupRhs.has_value() && "missing RHS parallel index");
+ if (!lookupLhs.hasValue())
+ return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
+ diag << "expected iterIndex=" << iterIndex << "to map to a LHS dimension";
+ });
+ if (!lookupRhs.hasValue())
+ return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
+ diag << "expected iterIndex=" << iterIndex << "to map to a RHS dimension";
+ });
int64_t lhsIndex = lookupLhs.getValue();
int64_t rhsIndex = lookupRhs.getValue();
int64_t dimSize = lhsType.getDimSize(lhsIndex);
- assert(dimSize == rhsType.getDimSize(rhsIndex) && "corrupt shape");
+ if (dimSize != rhsType.getDimSize(rhsIndex))
+ return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
+ diag << "expect LHS dimension " << lhsIndex
+ << " to have the same size as RHS dimension " << rhsIndex;
+ });
// Base case.
if (lhsType.getRank() == 1) {
- assert(rhsType.getRank() == 1 && "corrupt contraction");
+ if (rhsType.getRank() != 1)
+ return rewriter.notifyMatchFailure(
+ op, "When LHS has rank 1, expected also RHS to have rank 1");
Value m = createMul(loc, op.getLhs(), op.getRhs(), isInt, rewriter);
auto kind = vector::CombiningKind::ADD;
if (auto acc = op.getAcc())
- return rewriter.create<vector::ReductionOp>(loc, kind, m, acc);
- return rewriter.create<vector::ReductionOp>(loc, kind, m);
+ return rewriter.create<vector::ReductionOp>(loc, kind, m, acc)
+ .getResult();
+ return rewriter.create<vector::ReductionOp>(loc, kind, m).getResult();
}
// Construct new iterator types and affine map array attribute.
std::array<AffineMap, 3> lowIndexingMaps = {
diff --git a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
index 4123ef3b75135..72bfdd6e580b2 100644
--- a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
@@ -858,6 +858,34 @@ func.func @genbool_var_3d(%arg0: index, %arg1: index, %arg2: index) -> vector<2x
return %0 : vector<2x1x7xi1>
}
+// CHECK-LABEL: @contract_one_sided_unit_reduction_dim
+// CHECK-SAME: (%[[A0:.+]]: vector<1x2xi32>, %[[A1:.+]]: vector<2x2xi32>, %[[A2:.+]]: vector<2xi32>)
+// CHECK-DAG: %[[C:.+]] = arith.constant dense<0> : vector<2xi32>
+// CHECK-DAG: %[[E00:.+]] = vector.extract %[[A0]][0] : vector<1x2xi32>
+// CHECK-DAG: %[[E10:.+]] = vector.extract %[[A1]][0] : vector<2x2xi32>
+// CHECK: %[[M0:.+]] = arith.muli %[[E10]], %[[E00]] : vector<2xi32>
+// CHECK: %[[R0:.+]] = vector.reduction <add>, %[[M0]] : vector<2xi32> into i32
+// CHECK: %[[I0:.+]] = vector.insert %[[R0]], %[[C]] [0] : i32 into vector<2xi32>
+// CHECK: %[[E11:.+]] = vector.extract %[[A1]][1] : vector<2x2xi32>
+// CHECK: %[[M1:.+]] = arith.muli %[[E11]], %[[E00]] : vector<2xi32>
+// CHECK: %[[R1:.+]] = vector.reduction <add>, %[[M1]] : vector<2xi32> into i32
+// CHECK: %[[I1:.+]] = vector.insert %[[R1]], %[[I0]] [1] : i32 into vector<2xi32>
+// CHECK: %[[S:.+]] = arith.addi %[[I1]], %[[A2]] : vector<2xi32>
+// CHECK: return %[[S]] : vector<2xi32>
+
+// Regression test for a unit-size reduction iterator that appears on only one
+// side of the contraction. Here d0 is a "reduction" iterator of size 1 that is
+// present only in the LHS indexing map (dim 0 of the vector<1x2xi32> LHS); it
+// is absent from both the RHS map and the accumulator/result map.
+// ContractionOpLowering::matchAndRewrite routes such one-sided unit-size
+// reduction dims through lowerParallel, which therefore has to tolerate an
+// iterator missing from the result map when its size is 1 rather than assert.
+// d1 is the only parallel iterator (indexing the vector<2xi32> result), and d2
+// is the ordinary reduction shared by LHS and RHS that becomes one
+// `vector.reduction <add>` per result element.
+func.func @contract_one_sided_unit_reduction_dim(%arg0 : vector<1x2xi32>, %arg1 : vector<2x2xi32>, %arg2 : vector<2xi32>) -> vector<2xi32> {
+ %res = vector.contract {
+ indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d1)>
+ ],
+ iterator_types = ["reduction", "parallel", "reduction"],
+ kind = #vector.kind<add>
+ } %arg0, %arg1, %arg2 : vector<1x2xi32>, vector<2x2xi32>, vector<2xi32> into vector<2xi32>
+ return %res : vector<2xi32>
+}
+
#matmat_accesses_0 = [
affine_map<(m, n, k) -> (m, k)>,
affine_map<(m, n, k) -> (k, n)>,
More information about the Mlir-commits mailing list