[Mlir-commits] [mlir] db188ad - [mlir][Vector] NFC - Compress vector to outerproduct lowering.

Nicolas Vasilache llvmlistbot at llvm.org
Fri Jul 2 09:41:59 PDT 2021


Author: Nicolas Vasilache
Date: 2021-07-02T16:41:51Z
New Revision: db188adfb12f6783c5419d5165a1123b9f5b56b0

URL: https://github.com/llvm/llvm-project/commit/db188adfb12f6783c5419d5165a1123b9f5b56b0
DIFF: https://github.com/llvm/llvm-project/commit/db188adfb12f6783c5419d5165a1123b9f5b56b0.diff

LOG: [mlir][Vector] NFC - Compress vector to outerproduct lowering.

The implementation has become too unwieldy and cognitive overhead wins.
Instead compress the implementation in preparation for additional lowering paths.

Differential Revision: https://reviews.llvm.org/D105359

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/VectorTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 1a7d2e80d56f7..39a6c39059fba 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -1816,6 +1816,72 @@ ContractionOpToMatmulOpLowering::matchAndRewrite(vector::ContractionOp op,
   return success();
 }
 
+namespace {
+struct IteratorType { // Iterator-type name matchable against an Attribute.
+  IteratorType(StringRef strRef) : strRef(strRef) {}
+  bool isOfType(Attribute attr) const { // True iff attr is a StringAttr == strRef.
+    auto sAttr = attr.dyn_cast<StringAttr>();
+    return sAttr && sAttr.getValue() == strRef; // Non-StringAttr never matches.
+  }
+  StringRef strRef; // Name that isOfType() compares against.
+};
+struct Par : public IteratorType { // Named by getParallelIteratorTypeName().
+  Par() : IteratorType(getParallelIteratorTypeName()) {}
+};
+struct Red : public IteratorType { // Named by getReductionIteratorTypeName().
+  Red() : IteratorType(getReductionIteratorTypeName()) {}
+};
+
+// Rewrites a ContractionOp as outer-products unrolled along the reduction.
+struct UnrolledOuterProductEmitter {
+  using MapList = ArrayRef<ArrayRef<AffineExpr>>; // One expr list per map.
+
+  UnrolledOuterProductEmitter(PatternRewriter &rewriter,
+                              vector::ContractionOp op)
+      : rewriter(rewriter), loc(op.getLoc()), kind(op.kind()),
+        iterators(op.iterator_types()), maps(op.getIndexingMaps()), op(op) {}
+
+  Value t(Value v) { // Creates a TransposeOp of `v` with permutation {1, 0}.
+    static constexpr std::array<int64_t, 2> perm = {1, 0};
+    return rewriter.create<vector::TransposeOp>(loc, v, perm);
+  }
+
+  bool iters(ArrayRef<IteratorType> its) { // True iff its matches iterators.
+    if (its.size() != iterators.size()) // Counts must agree before indexing.
+      return false;
+    for (int i = 0, e = its.size(); i != e; ++i) {
+      if (!its[i].isOfType(iterators[i])) // Position-wise name comparison.
+        return false;
+    }
+    return true;
+  }
+
+  bool layout(MapList l) { // True iff maps equal the maps inferred from `l`.
+    auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
+    return maps == infer(l);
+  }
+
+  LogicalResult outer_prod(Value lhs, Value rhs, Value res, int reductionSize) {
+    assert(reductionSize > 0); // NOTE(review): int param; callers pass int64_t.
+    for (int64_t k = 0; k < reductionSize; ++k) { // One OuterProductOp per k.
+      Value a = rewriter.create<vector::ExtractOp>(loc, lhs, k);
+      Value b = rewriter.create<vector::ExtractOp>(loc, rhs, k);
+      res = rewriter.create<vector::OuterProductOp>(loc, res.getType(), a, b,
+                                                    res, kind); // Chains res.
+    }
+    rewriter.replaceOp(op, res); // Replaces op with the last value of res.
+    return success();
+  }
+
+  PatternRewriter &rewriter;
+  Location loc; // op.getLoc(); passed to every op created here.
+  vector::CombiningKind kind; // op.kind(); forwarded to each OuterProductOp.
+  ArrayAttr iterators; // op.iterator_types().
+  SmallVector<AffineMap, 4> maps; // op.getIndexingMaps().
+  Operation *op; // The contraction that outer_prod() replaces.
+};
+} // namespace
+
 /// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
 /// semantics to a reduction_size-unrolled sequence:
 /// ```
@@ -1844,104 +1910,64 @@ LogicalResult ContractionOpToOuterProductOpLowering::matchAndRewrite(
   if (failed(filter(op)))
     return failure();
 
-  Location loc = op.getLoc();
-  int64_t reductionSize = 0;
   VectorType lhsType = op.getLhsType();
   Value lhs = op.lhs(), rhs = op.rhs(), res = op.acc();
 
   // Set up the parallel/reduction structure in right form.
-  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
-  auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
   AffineExpr m, n, k;
   bindDims(rewriter.getContext(), m, n, k);
-  static constexpr std::array<int64_t, 2> perm = {1, 0};
-  auto iteratorTypes = op.iterator_types().getValue();
-  SmallVector<AffineMap, 4> maps = op.getIndexingMaps();
-  if (isParallelIterator(iteratorTypes[0]) &&
-      isParallelIterator(iteratorTypes[1]) &&
-      isReductionIterator(iteratorTypes[2])) {
-    //
-    // Two outer parallel, one inner reduction (matmat flavor).
-    //
-    if (maps == infer({{m, k}, {k, n}, {m, n}})) {
-      // This is the classical row-major matmul. Just permute the lhs.
-      reductionSize = lhsType.getDimSize(1);
-      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
-    } else if (maps == infer({{m, k}, {n, k}, {m, n}})) {
-      // TODO: may be better to fail and use some vector<k> -> scalar reduction.
-      reductionSize = lhsType.getDimSize(1);
-      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
-      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
-    } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
-      // No need to permute anything.
-      reductionSize = lhsType.getDimSize(0);
-    } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
-      // Just permute the rhs.
-      reductionSize = lhsType.getDimSize(0);
-      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
-    } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
-      // This is the classical row-major matmul. Just permute the lhs.
-      reductionSize = lhsType.getDimSize(1);
-      Value tmp = rhs;
-      rhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
-      lhs = tmp;
-    } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
-      // TODO: may be better to fail and use some vector<k> -> scalar reduction.
-      reductionSize = lhsType.getDimSize(1);
-      Value tmp = rhs;
-      rhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
-      lhs = rewriter.create<vector::TransposeOp>(loc, tmp, perm);
-    } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
-      // No need to permute anything, but still swap lhs and rhs.
-      reductionSize = lhsType.getDimSize(0);
-      std::swap(lhs, rhs);
-    } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
-      // Just permute the rhs.
-      reductionSize = lhsType.getDimSize(0);
-      Value tmp = lhs;
-      lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
-      rhs = tmp;
-    } else {
-      return failure();
-    }
-  } else if (isParallelIterator(iteratorTypes[0]) &&
-             isReductionIterator(iteratorTypes[1])) {
-    //
-    // One outer parallel, one inner reduction (matvec flavor)
-    //
-    if (maps == infer({{m, n}, {n}, {m}})) {
-      // Case mat-vec: transpose.
-      reductionSize = lhsType.getDimSize(1);
-      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
-    } else if (maps == infer({{n, m}, {n}, {m}})) {
-      // Case mat-trans-vec: ready to go.
-      reductionSize = lhsType.getDimSize(0);
-    } else if (maps == infer({{n}, {m, n}, {m}})) {
-      // Case vec-mat: swap and transpose.
-      reductionSize = lhsType.getDimSize(0);
-      std::swap(lhs, rhs);
-      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
-    } else if (maps == infer({{n}, {n, m}, {m}})) {
-      // Case vec-mat-trans: swap and ready to go.
-      reductionSize = lhsType.getDimSize(0);
-      std::swap(lhs, rhs);
-    } else {
-      return failure();
-    }
-  } else {
+
+  //
+  // Two outer parallel, one inner reduction (matmat flavor).
+  //
+  UnrolledOuterProductEmitter e(rewriter, op);
+  if (e.iters({Par(), Par(), Red()})) {
+    // Classical row-major matmul:  Just permute the lhs.
+    if (e.layout({{m, k}, {k, n}, {m, n}}))
+      return e.outer_prod(e.t(lhs), rhs, res, lhsType.getDimSize(1));
+    // TODO: may be better to fail and use some vector<k> -> scalar reduction.
+    if (e.layout({{m, k}, {n, k}, {m, n}}))
+      return e.outer_prod(e.t(lhs), e.t(rhs), res, lhsType.getDimSize(1));
+    // No need to permute anything.
+    if (e.layout({{k, m}, {k, n}, {m, n}}))
+      return e.outer_prod(lhs, rhs, res, lhsType.getDimSize(0));
+    // Just permute the rhs.
+    if (e.layout({{k, m}, {n, k}, {m, n}}))
+      return e.outer_prod(lhs, e.t(rhs), res, lhsType.getDimSize(0));
+    // Transposed output: swap RHS and LHS.
+    // Classical row-major matmul: permute the lhs.
+    if (e.layout({{m, k}, {k, n}, {n, m}}))
+      return e.outer_prod(rhs, e.t(lhs), res, lhsType.getDimSize(1));
+    // TODO: may be better to fail and use some vector<k> -> scalar reduction.
+    if (e.layout({{m, k}, {n, k}, {n, m}}))
+      return e.outer_prod(e.t(rhs), e.t(lhs), res, lhsType.getDimSize(1));
+    if (e.layout({{k, m}, {k, n}, {n, m}}))
+      return e.outer_prod(rhs, lhs, res, lhsType.getDimSize(0));
+    if (e.layout({{k, m}, {n, k}, {n, m}}))
+      return e.outer_prod(e.t(rhs), lhs, res, lhsType.getDimSize(0));
     return failure();
   }
-  assert(reductionSize > 0);
-
-  // Unroll outer-products along reduction.
-  for (int64_t k = 0; k < reductionSize; ++k) {
-    Value a = rewriter.create<vector::ExtractOp>(op.getLoc(), lhs, k);
-    Value b = rewriter.create<vector::ExtractOp>(op.getLoc(), rhs, k);
-    res = rewriter.create<vector::OuterProductOp>(op.getLoc(), res.getType(), a,
-                                                  b, res, op.kind());
+
+  //
+  // One outer parallel, one inner reduction (matvec flavor)
+  //
+  if (e.iters({Par(), Red()})) {
+    // Case mat-vec: transpose.
+    if (e.layout({{m, n}, {n}, {m}}))
+      return e.outer_prod(e.t(lhs), rhs, res, lhsType.getDimSize(1));
+    // Case mat-trans-vec: ready to go.
+    if (e.layout({{n, m}, {n}, {m}}))
+      return e.outer_prod(lhs, rhs, res, lhsType.getDimSize(0));
+    // Case vec-mat: swap and transpose.
+    if (e.layout({{n}, {m, n}, {m}}))
+      return e.outer_prod(e.t(rhs), lhs, res, lhsType.getDimSize(0));
+    // Case vec-mat-trans: swap and ready to go.
+    if (e.layout({{n}, {n, m}, {m}}))
+      return e.outer_prod(rhs, lhs, res, lhsType.getDimSize(0));
+    return failure();
   }
-  rewriter.replaceOp(op, res);
-  return success();
+
+  return failure();
 }
 
 LogicalResult


        


More information about the Mlir-commits mailing list