[Mlir-commits] [mlir] ee260c1 - [mlir] [VectorOps] Multi-dim reductions for lowering vector.contract

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Feb 20 14:17:17 PST 2020


Author: aartbik
Date: 2020-02-20T14:16:50-08:00
New Revision: ee260c1a0f1c0a8fd1179cdab9fb4312086dcc54

URL: https://github.com/llvm/llvm-project/commit/ee260c1a0f1c0a8fd1179cdab9fb4312086dcc54
DIFF: https://github.com/llvm/llvm-project/commit/ee260c1a0f1c0a8fd1179cdab9fb4312086dcc54.diff

LOG: [mlir] [VectorOps] Multi-dim reductions for lowering vector.contract

Summary:
This implements the last step for lowering vector.contract progressively
to LLVM IR (except for masks). Multi-dimensional reductions that remain
after expanding all parallel dimensions are lowered into simpler
vector.contract operations until a trivial 1-dim reduction remains.

Reviewers: nicolasvasilache, andydavis1

Reviewed By: andydavis1

Subscribers: mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, nicolasvasilache, arpith-jacob, mgester, lucyrfox, liufengdb, Joonsoo, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D74880

Added: 
    

Modified: 
    mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
    mlir/test/Dialect/VectorOps/vector-contract-transforms.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
index 61a3d556a70b..923f1c215583 100644
--- a/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
@@ -929,21 +929,9 @@ class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
       }
     }
 
-    // Lower the only remaining contraction dimensions.
-    // TODO(ajcbik): handle multi-dim reductions
-    auto loc = op.getLoc();
-    Type resType = op.getResultType();
-    if (!resType.isa<VectorType>() && lhsType.getRank() == 1 &&
-        rhsType.getRank() == 1) {
-
-      Value zero = rewriter.create<ConstantOp>(loc, resType,
-                                               rewriter.getZeroAttr(resType));
-      Value splat = rewriter.create<SplatOp>(loc, lhsType, zero);
-      Value fma =
-          rewriter.create<vector::FMAOp>(loc, op.lhs(), op.rhs(), splat);
-      StringAttr kind = rewriter.getStringAttr("add");
-      rewriter.replaceOpWithNewOp<vector::ReductionV2Op>(op, resType, kind, fma,
-                                                         op.acc());
+    // Lower the first remaining reduction dimension.
+    if (!contractingDimMap.empty()) {
+      rewriter.replaceOp(op, lowerReduction(op, rewriter));
       return matchSuccess();
     }
 
@@ -981,27 +969,14 @@ class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
     Optional<int64_t> lookup = getResultIndex(iMap[2], iterIndex);
     assert(lookup.hasValue() && "parallel index not listed in reduction");
     int64_t resIndex = lookup.getValue();
-    // Construct new iterator types.
-    ArrayAttr iteratorTypes = op.iterator_types();
-    SmallVector<Attribute, 4> lowIterTypes;
-    for (auto it : llvm::enumerate(iteratorTypes)) {
-      int64_t idx = it.index();
-      if (idx == iterIndex) {
-        assert(it.value().cast<StringAttr>().getValue() ==
-                   getParallelIteratorTypeName() &&
-               "parallel index not marked as such");
-        continue;
-      }
-      lowIterTypes.push_back(it.value());
-    }
-    // Construct new affine map array attribute.
+    // Construct new iterator types and affine map array attribute.
     SmallVector<AffineMap, 4> lowIndexingMaps;
     lowIndexingMaps.push_back(adjustMap(iMap[0], iterIndex, rewriter));
     lowIndexingMaps.push_back(adjustMap(iMap[1], iterIndex, rewriter));
     lowIndexingMaps.push_back(adjustMap(iMap[2], iterIndex, rewriter));
     auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
-    // Construct new iterator types array attribute.
-    auto lowIter = rewriter.getArrayAttr(lowIterTypes);
+    auto lowIter =
+        rewriter.getArrayAttr(adjustIter(op.iterator_types(), iterIndex));
     // Unroll into a series of lower dimensional vector.contract ops.
     Location loc = op.getLoc();
     Value result = zeroVector(loc, resType, rewriter);
@@ -1017,6 +992,56 @@ class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
     return result;
   }
 
+  // Lower one reduction dimension (iterator 0) of a scalar-result contraction.
+  Value lowerReduction(vector::ContractionOp op,
+                       PatternRewriter &rewriter) const {
+    auto loc = op.getLoc();
+    VectorType lhsType = op.getLhsType();
+    VectorType rhsType = op.getRhsType();
+    Type resType = op.getResultType();
+    assert(!resType.isa<VectorType>() && "expected scalar result");
+    // Peel off iterator 0 (assumes parallel dims were already expanded).
+    int64_t iterIndex = 0;
+    SmallVector<AffineMap, 4> iMap = op.getIndexingMaps();
+    Optional<int64_t> lookupLhs = getResultIndex(iMap[0], iterIndex);
+    Optional<int64_t> lookupRhs = getResultIndex(iMap[1], iterIndex);
+    assert(lookupLhs.hasValue() && "missing LHS reduction index");
+    assert(lookupRhs.hasValue() && "missing RHS reduction index");
+    int64_t lhsIndex = lookupLhs.getValue();
+    int64_t rhsIndex = lookupRhs.getValue();
+    int64_t dimSize = lhsType.getDimSize(lhsIndex);
+    assert(dimSize == rhsType.getDimSize(rhsIndex) && "corrupt shape");
+    // Base case: 1-D operands, so emit fma(lhs, rhs, 0) + "add" reduction.
+    if (lhsType.getRank() == 1) {
+      assert(rhsType.getRank() == 1 && "corrupt contraction");
+      Value zero = zeroVector(loc, lhsType, rewriter);
+      Value fma = rewriter.create<vector::FMAOp>(loc, op.lhs(), op.rhs(), zero);
+      StringAttr kind = rewriter.getStringAttr("add");
+      return rewriter.create<vector::ReductionV2Op>(loc, resType, kind, fma,
+                                                    op.acc());
+    }
+    // Construct new iterator types and affine map array attribute.
+    SmallVector<AffineMap, 4> lowIndexingMaps;
+    lowIndexingMaps.push_back(adjustMap(iMap[0], iterIndex, rewriter));
+    lowIndexingMaps.push_back(adjustMap(iMap[1], iterIndex, rewriter));
+    lowIndexingMaps.push_back(adjustMap(iMap[2], iterIndex, rewriter));
+    auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
+    auto lowIter =
+        rewriter.getArrayAttr(adjustIter(op.iterator_types(), iterIndex));
+    // Unroll into a series of lower dimensional vector.contract ops.
+    // By feeding the initial accumulator into the first contraction,
+    // and the result of each contraction into the next, eventually
+    // the sum of all reductions is computed.
+    Value result = op.acc();
+    for (int64_t d = 0; d < dimSize; ++d) {
+      auto lhs = reshapeLoad(loc, op.lhs(), lhsType, lhsIndex, d, rewriter);
+      auto rhs = reshapeLoad(loc, op.rhs(), rhsType, rhsIndex, d, rewriter);
+      result = rewriter.create<vector::ContractionOp>(loc, lhs, rhs, result,
+                                                      lowAffine, lowIter);
+    }
+    return result;
+  }
+
   // Helper method to construct a zero vector.
   static Value zeroVector(Location loc, VectorType vType,
                           PatternRewriter &rewriter) {
@@ -1036,6 +1061,20 @@ class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
     return None;
   }
 
+  // Helper returning the iterator types with the entry at `index` dropped;
+  // all other entries are kept in their original relative order.
+  static SmallVector<Attribute, 4> adjustIter(ArrayAttr iteratorTypes,
+                                              int64_t index) {
+    SmallVector<Attribute, 4> kept;
+    int64_t pos = 0;
+    for (Attribute attr : iteratorTypes) {
+      if (pos != index)
+        kept.push_back(attr);
+      ++pos;
+    }
+    return kept;
+  }
+
   // Helper to construct an affine map with one index removed.
   static AffineMap adjustMap(AffineMap map, int64_t index,
                              PatternRewriter &rewriter) {

diff  --git a/mlir/test/Dialect/VectorOps/vector-contract-transforms.mlir b/mlir/test/Dialect/VectorOps/vector-contract-transforms.mlir
index f781e37d586b..362c85a38d09 100644
--- a/mlir/test/Dialect/VectorOps/vector-contract-transforms.mlir
+++ b/mlir/test/Dialect/VectorOps/vector-contract-transforms.mlir
@@ -169,3 +169,84 @@ func @extract_contract4(%arg0: vector<2x2xf32>,
   return %0 : vector<2x2xf32>
 }
 
+#contraction2d_accesses = [
+  affine_map<(i, j) -> (i, j)>,
+  affine_map<(i, j) -> (i, j)>,
+  affine_map<(i, j) -> ()>
+]
+#contraction2d_trait = {
+  indexing_maps = #contraction2d_accesses,
+  iterator_types = ["reduction", "reduction"]
+}
+
+// CHECK-LABEL: func @full_contract1
+// CHECK-SAME: %[[A:.*0]]: vector<2x3xf32>,
+// CHECK-SAME: %[[B:.*1]]: vector<2x3xf32>,
+// CHECK-SAME: %[[C:.*2]]: f32
+// CHECK:      %[[Z:.*]] = constant dense<0.000000e+00> : vector<3xf32>
+// CHECK:      %[[T0:.*]] = vector.extract %[[A]][0] : vector<2x3xf32>
+// CHECK:      %[[T1:.*]] = vector.extract %[[B]][0] : vector<2x3xf32>
+// CHECK:      %[[T2:.*]] = vector.fma %[[T0]], %[[T1]], %[[Z]] : vector<3xf32>
+// CHECK:      %[[T3:.*]] = vector.reductionv2 "add", %[[T2]], %[[C]] : vector<3xf32>, f32 into f32
+// CHECK:      %[[T4:.*]] = vector.extract %[[A]][1] : vector<2x3xf32>
+// CHECK:      %[[T5:.*]] = vector.extract %[[B]][1] : vector<2x3xf32>
+// CHECK:      %[[T6:.*]] = vector.fma %[[T4]], %[[T5]], %[[Z]] : vector<3xf32>
+// CHECK:      %[[T7:.*]] = vector.reductionv2 "add", %[[T6]], %[[T3]] : vector<3xf32>, f32 into f32
+// CHECK:      return %[[T7]] : f32
+
+func @full_contract1(%arg0: vector<2x3xf32>,
+                     %arg1: vector<2x3xf32>,
+                     %arg2: f32) -> f32 {
+  %0 = vector.contract #contraction2d_trait %arg0, %arg1, %arg2
+    : vector<2x3xf32>, vector<2x3xf32> into f32
+  return %0 : f32
+}
+
+#contraction2d_trans_accesses = [
+  affine_map<(i, j) -> (i, j)>,
+  affine_map<(i, j) -> (j, i)>,
+  affine_map<(i, j) -> ()>
+]
+#contraction2d_trans_trait = {
+  indexing_maps = #contraction2d_trans_accesses,
+  iterator_types = ["reduction", "reduction"]
+}
+
+// CHECK-LABEL: func @full_contract2
+// CHECK-SAME: %[[A:.*0]]: vector<2x3xf32>,
+// CHECK-SAME: %[[B:.*1]]: vector<3x2xf32>,
+// CHECK-SAME: %[[C:.*2]]: f32
+// CHECK:      %[[Z:.*]] = constant dense<0.000000e+00> : vector<3xf32>
+// CHECK:      %[[T0:.*]] = vector.extract %[[A]][0] : vector<2x3xf32>
+// CHECK:      %[[T1:.*]] = vector.extract %[[B]][0] : vector<3x2xf32>
+// CHECK:      %[[T2:.*]] = vector.extract %[[T1]][0] : vector<2xf32>
+// CHECK:      %[[T3:.*]] = vector.insert %[[T2]], %[[Z]] [0] : f32 into vector<3xf32>
+// CHECK:      %[[T4:.*]] = vector.extract %[[B]][1] : vector<3x2xf32>
+// CHECK:      %[[T5:.*]] = vector.extract %[[T4]][0] : vector<2xf32>
+// CHECK:      %[[T6:.*]] = vector.insert %[[T5]], %[[T3]] [1] : f32 into vector<3xf32>
+// CHECK:      %[[T7:.*]] = vector.extract %[[B]][2] : vector<3x2xf32>
+// CHECK:      %[[T8:.*]] = vector.extract %[[T7]][0] : vector<2xf32>
+// CHECK:      %[[T9:.*]] = vector.insert %[[T8]], %[[T6]] [2] : f32 into vector<3xf32>
+// CHECK:      %[[T10:.*]] = vector.fma %[[T0]], %[[T9]], %[[Z]] : vector<3xf32>
+// CHECK:      %[[T11:.*]] = vector.reductionv2 "add", %[[T10]], %[[C]] : vector<3xf32>, f32 into f32
+// CHECK:      %[[T12:.*]] = vector.extract %[[A]][1] : vector<2x3xf32>
+// CHECK:      %[[T13:.*]] = vector.extract %[[B]][0] : vector<3x2xf32>
+// CHECK:      %[[T14:.*]] = vector.extract %[[T13]][1] : vector<2xf32>
+// CHECK:      %[[T15:.*]] = vector.insert %[[T14]], %[[Z]] [0] : f32 into vector<3xf32>
+// CHECK:      %[[T16:.*]] = vector.extract %[[B]][1] : vector<3x2xf32>
+// CHECK:      %[[T17:.*]] = vector.extract %[[T16]][1] : vector<2xf32>
+// CHECK:      %[[T18:.*]] = vector.insert %[[T17]], %[[T15]] [1] : f32 into vector<3xf32>
+// CHECK:      %[[T19:.*]] = vector.extract %[[B]][2] : vector<3x2xf32>
+// CHECK:      %[[T20:.*]] = vector.extract %[[T19]][1] : vector<2xf32>
+// CHECK:      %[[T21:.*]] = vector.insert %[[T20]], %[[T18]] [2] : f32 into vector<3xf32>
+// CHECK:      %[[T22:.*]] = vector.fma %[[T12]], %[[T21]], %[[Z]] : vector<3xf32>
+// CHECK:      %[[T23:.*]] = vector.reductionv2 "add", %[[T22]], %[[T11]] : vector<3xf32>, f32 into f32
+// CHECK:      return %[[T23]] : f32
+
+func @full_contract2(%arg0: vector<2x3xf32>,
+                     %arg1: vector<3x2xf32>,
+                     %arg2: f32) -> f32 {
+  %0 = vector.contract #contraction2d_trans_trait %arg0, %arg1, %arg2
+    : vector<2x3xf32>, vector<3x2xf32> into f32
+  return %0 : f32
+}


        


More information about the Mlir-commits mailing list