[Mlir-commits] [mlir] ee260c1 - [mlir] [VectorOps] Multi-dim reductions for lowering vector.contract
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Feb 20 14:17:17 PST 2020
Author: aartbik
Date: 2020-02-20T14:16:50-08:00
New Revision: ee260c1a0f1c0a8fd1179cdab9fb4312086dcc54
URL: https://github.com/llvm/llvm-project/commit/ee260c1a0f1c0a8fd1179cdab9fb4312086dcc54
DIFF: https://github.com/llvm/llvm-project/commit/ee260c1a0f1c0a8fd1179cdab9fb4312086dcc54.diff
LOG: [mlir] [VectorOps] Multi-dim reductions for lowering vector.contract
Summary:
This implements the last step for lowering vector.contract progressively
to LLVM IR (except for masks). Multi-dimensional reductions that remain
after expanding all parallel dimensions are lowered into simpler
vector.contract operations until a trivial 1-dim reduction remains.
Reviewers: nicolasvasilache, andydavis1
Reviewed By: andydavis1
Subscribers: mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, nicolasvasilache, arpith-jacob, mgester, lucyrfox, liufengdb, Joonsoo, llvm-commits
Tags: #llvm
Differential Revision: https://reviews.llvm.org/D74880
Added:
Modified:
mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
mlir/test/Dialect/VectorOps/vector-contract-transforms.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
index 61a3d556a70b..923f1c215583 100644
--- a/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
@@ -929,21 +929,9 @@ class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
}
}
- // Lower the only remaining contraction dimensions.
- // TODO(ajcbik): handle multi-dim reductions
- auto loc = op.getLoc();
- Type resType = op.getResultType();
- if (!resType.isa<VectorType>() && lhsType.getRank() == 1 &&
- rhsType.getRank() == 1) {
-
- Value zero = rewriter.create<ConstantOp>(loc, resType,
- rewriter.getZeroAttr(resType));
- Value splat = rewriter.create<SplatOp>(loc, lhsType, zero);
- Value fma =
- rewriter.create<vector::FMAOp>(loc, op.lhs(), op.rhs(), splat);
- StringAttr kind = rewriter.getStringAttr("add");
- rewriter.replaceOpWithNewOp<vector::ReductionV2Op>(op, resType, kind, fma,
- op.acc());
+ // Lower the first remaining reduction dimension.
+ if (!contractingDimMap.empty()) {
+ rewriter.replaceOp(op, lowerReduction(op, rewriter));
return matchSuccess();
}
@@ -981,27 +969,14 @@ class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
Optional<int64_t> lookup = getResultIndex(iMap[2], iterIndex);
assert(lookup.hasValue() && "parallel index not listed in reduction");
int64_t resIndex = lookup.getValue();
- // Construct new iterator types.
- ArrayAttr iteratorTypes = op.iterator_types();
- SmallVector<Attribute, 4> lowIterTypes;
- for (auto it : llvm::enumerate(iteratorTypes)) {
- int64_t idx = it.index();
- if (idx == iterIndex) {
- assert(it.value().cast<StringAttr>().getValue() ==
- getParallelIteratorTypeName() &&
- "parallel index not marked as such");
- continue;
- }
- lowIterTypes.push_back(it.value());
- }
- // Construct new affine map array attribute.
+ // Construct new iterator types and affine map array attribute.
SmallVector<AffineMap, 4> lowIndexingMaps;
lowIndexingMaps.push_back(adjustMap(iMap[0], iterIndex, rewriter));
lowIndexingMaps.push_back(adjustMap(iMap[1], iterIndex, rewriter));
lowIndexingMaps.push_back(adjustMap(iMap[2], iterIndex, rewriter));
auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
- // Construct new iterator types array attribute.
- auto lowIter = rewriter.getArrayAttr(lowIterTypes);
+ auto lowIter =
+ rewriter.getArrayAttr(adjustIter(op.iterator_types(), iterIndex));
// Unroll into a series of lower dimensional vector.contract ops.
Location loc = op.getLoc();
Value result = zeroVector(loc, resType, rewriter);
@@ -1017,6 +992,56 @@ class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
return result;
}
+ // Lower one reduction dimension.
+ Value lowerReduction(vector::ContractionOp op,
+ PatternRewriter &rewriter) const {
+ auto loc = op.getLoc();
+ VectorType lhsType = op.getLhsType();
+ VectorType rhsType = op.getRhsType();
+ Type resType = op.getResultType();
+ assert(!resType.isa<VectorType>());
+ // Use iterator index 0.
+ int64_t iterIndex = 0;
+ SmallVector<AffineMap, 4> iMap = op.getIndexingMaps();
+ Optional<int64_t> lookupLhs = getResultIndex(iMap[0], iterIndex);
+ Optional<int64_t> lookupRhs = getResultIndex(iMap[1], iterIndex);
+ assert(lookupLhs.hasValue() && "missing LHS parallel index");
+ assert(lookupRhs.hasValue() && "missing RHS parallel index");
+ int64_t lhsIndex = lookupLhs.getValue();
+ int64_t rhsIndex = lookupRhs.getValue();
+ int64_t dimSize = lhsType.getDimSize(lhsIndex);
+ assert(dimSize == rhsType.getDimSize(rhsIndex) && "corrupt shape");
+ // Base case.
+ if (lhsType.getRank() == 1) {
+ assert(rhsType.getRank() == 1 && "corrupt contraction");
+ Value zero = zeroVector(loc, lhsType, rewriter);
+ Value fma = rewriter.create<vector::FMAOp>(loc, op.lhs(), op.rhs(), zero);
+ StringAttr kind = rewriter.getStringAttr("add");
+ return rewriter.create<vector::ReductionV2Op>(loc, resType, kind, fma,
+ op.acc());
+ }
+ // Construct new iterator types and affine map array attribute.
+ SmallVector<AffineMap, 4> lowIndexingMaps;
+ lowIndexingMaps.push_back(adjustMap(iMap[0], iterIndex, rewriter));
+ lowIndexingMaps.push_back(adjustMap(iMap[1], iterIndex, rewriter));
+ lowIndexingMaps.push_back(adjustMap(iMap[2], iterIndex, rewriter));
+ auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
+ auto lowIter =
+ rewriter.getArrayAttr(adjustIter(op.iterator_types(), iterIndex));
+ // Unroll into a series of lower dimensional vector.contract ops.
+ // By feeding the initial accumulator into the first contraction,
+ // and the result of each contraction into the next, eventually
+ // the sum of all reductions is computed.
+ Value result = op.acc();
+ for (int64_t d = 0; d < dimSize; ++d) {
+ auto lhs = reshapeLoad(loc, op.lhs(), lhsType, lhsIndex, d, rewriter);
+ auto rhs = reshapeLoad(loc, op.rhs(), rhsType, rhsIndex, d, rewriter);
+ result = rewriter.create<vector::ContractionOp>(loc, lhs, rhs, result,
+ lowAffine, lowIter);
+ }
+ return result;
+ }
+
// Helper method to construct a zero vector.
static Value zeroVector(Location loc, VectorType vType,
PatternRewriter &rewriter) {
@@ -1036,6 +1061,20 @@ class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
return None;
}
+ // Helper to construct iterator types with one index removed.
+ static SmallVector<Attribute, 4> adjustIter(ArrayAttr iteratorTypes,
+ int64_t index) {
+ SmallVector<Attribute, 4> results;
+ for (auto it : llvm::enumerate(iteratorTypes)) {
+ int64_t idx = it.index();
+ if (idx == index) {
+ continue;
+ }
+ results.push_back(it.value());
+ }
+ return results;
+ }
+
// Helper to construct an affine map with one index removed.
static AffineMap adjustMap(AffineMap map, int64_t index,
PatternRewriter &rewriter) {
diff --git a/mlir/test/Dialect/VectorOps/vector-contract-transforms.mlir b/mlir/test/Dialect/VectorOps/vector-contract-transforms.mlir
index f781e37d586b..362c85a38d09 100644
--- a/mlir/test/Dialect/VectorOps/vector-contract-transforms.mlir
+++ b/mlir/test/Dialect/VectorOps/vector-contract-transforms.mlir
@@ -169,3 +169,84 @@ func @extract_contract4(%arg0: vector<2x2xf32>,
return %0 : vector<2x2xf32>
}
+#contraction2d_accesses = [
+ affine_map<(i, j) -> (i, j)>,
+ affine_map<(i, j) -> (i, j)>,
+ affine_map<(i, j) -> ()>
+]
+#contraction2d_trait = {
+ indexing_maps = #contraction2d_accesses,
+ iterator_types = ["reduction", "reduction"]
+}
+
+// CHECK-LABEL: func @full_contract1
+// CHECK-SAME: %[[A:.*0]]: vector<2x3xf32>,
+// CHECK-SAME: %[[B:.*1]]: vector<2x3xf32>,
+// CHECK-SAME: %[[C:.*2]]: f32
+// CHECK: %[[Z:.*]] = constant dense<0.000000e+00> : vector<3xf32>
+// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2x3xf32>
+// CHECK: %[[T1:.*]] = vector.extract %[[B]][0] : vector<2x3xf32>
+// CHECK: %[[T2:.*]] = vector.fma %[[T0]], %[[T1]], %[[Z]] : vector<3xf32>
+// CHECK: %[[T3:.*]] = vector.reductionv2 "add", %[[T2]], %[[C]] : vector<3xf32>, f32 into f32
+// CHECK: %[[T4:.*]] = vector.extract %[[A]][1] : vector<2x3xf32>
+// CHECK: %[[T5:.*]] = vector.extract %[[B]][1] : vector<2x3xf32>
+// CHECK: %[[T6:.*]] = vector.fma %[[T4]], %[[T5]], %[[Z]] : vector<3xf32>
+// CHECK: %[[T7:.*]] = vector.reductionv2 "add", %[[T6]], %[[T3]] : vector<3xf32>, f32 into f32
+// CHECK: return %[[T7]] : f32
+
+func @full_contract1(%arg0: vector<2x3xf32>,
+ %arg1: vector<2x3xf32>,
+ %arg2: f32) -> f32 {
+ %0 = vector.contract #contraction2d_trait %arg0, %arg1, %arg2
+ : vector<2x3xf32>, vector<2x3xf32> into f32
+ return %0 : f32
+}
+
+#contraction2d_trans_accesses = [
+ affine_map<(i, j) -> (i, j)>,
+ affine_map<(i, j) -> (j, i)>,
+ affine_map<(i, j) -> ()>
+]
+#contraction2d_trans_trait = {
+ indexing_maps = #contraction2d_trans_accesses,
+ iterator_types = ["reduction", "reduction"]
+}
+
+// CHECK-LABEL: func @full_contract2
+// CHECK-SAME: %[[A:.*0]]: vector<2x3xf32>,
+// CHECK-SAME: %[[B:.*1]]: vector<3x2xf32>,
+// CHECK-SAME: %[[C:.*2]]: f32
+// CHECK: %[[Z:.*]] = constant dense<0.000000e+00> : vector<3xf32>
+// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2x3xf32>
+// CHECK: %[[T1:.*]] = vector.extract %[[B]][0] : vector<3x2xf32>
+// CHECK: %[[T2:.*]] = vector.extract %[[T1]][0] : vector<2xf32>
+// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[Z]] [0] : f32 into vector<3xf32>
+// CHECK: %[[T4:.*]] = vector.extract %[[B]][1] : vector<3x2xf32>
+// CHECK: %[[T5:.*]] = vector.extract %[[T4]][0] : vector<2xf32>
+// CHECK: %[[T6:.*]] = vector.insert %[[T5]], %[[T3]] [1] : f32 into vector<3xf32>
+// CHECK: %[[T7:.*]] = vector.extract %[[B]][2] : vector<3x2xf32>
+// CHECK: %[[T8:.*]] = vector.extract %[[T7]][0] : vector<2xf32>
+// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T6]] [2] : f32 into vector<3xf32>
+// CHECK: %[[T10:.*]] = vector.fma %[[T0]], %[[T9]], %[[Z]] : vector<3xf32>
+// CHECK: %[[T11:.*]] = vector.reductionv2 "add", %[[T10]], %[[C]] : vector<3xf32>, f32 into f32
+// CHECK: %[[T12:.*]] = vector.extract %[[A]][1] : vector<2x3xf32>
+// CHECK: %[[T13:.*]] = vector.extract %[[B]][0] : vector<3x2xf32>
+// CHECK: %[[T14:.*]] = vector.extract %[[T13]][1] : vector<2xf32>
+// CHECK: %[[T15:.*]] = vector.insert %[[T14]], %[[Z]] [0] : f32 into vector<3xf32>
+// CHECK: %[[T16:.*]] = vector.extract %[[B]][1] : vector<3x2xf32>
+// CHECK: %[[T17:.*]] = vector.extract %[[T16]][1] : vector<2xf32>
+// CHECK: %[[T18:.*]] = vector.insert %[[T17]], %[[T15]] [1] : f32 into vector<3xf32>
+// CHECK: %[[T19:.*]] = vector.extract %[[B]][2] : vector<3x2xf32>
+// CHECK: %[[T20:.*]] = vector.extract %[[T19]][1] : vector<2xf32>
+// CHECK: %[[T21:.*]] = vector.insert %[[T20]], %[[T18]] [2] : f32 into vector<3xf32>
+// CHECK: %[[T22:.*]] = vector.fma %[[T12]], %[[T21]], %[[Z]] : vector<3xf32>
+// CHECK: %[[T23:.*]] = vector.reductionv2 "add", %[[T22]], %[[T11]] : vector<3xf32>, f32 into f32
+// CHECK: return %[[T23]] : f32
+
+func @full_contract2(%arg0: vector<2x3xf32>,
+ %arg1: vector<3x2xf32>,
+ %arg2: f32) -> f32 {
+ %0 = vector.contract #contraction2d_trans_trait %arg0, %arg1, %arg2
+ : vector<2x3xf32>, vector<3x2xf32> into f32
+ return %0 : f32
+}
More information about the Mlir-commits
mailing list