[Mlir-commits] [mlir] 4525d52 - Revert "[mlir][Vector] NFC - Compress vector to outerproduct lowering."

Mehdi Amini llvmlistbot at llvm.org
Fri Jul 2 10:55:40 PDT 2021


Author: Mehdi Amini
Date: 2021-07-02T17:55:06Z
New Revision: 4525d52c73de592dad73518a54d8925c66b20549

URL: https://github.com/llvm/llvm-project/commit/4525d52c73de592dad73518a54d8925c66b20549
DIFF: https://github.com/llvm/llvm-project/commit/4525d52c73de592dad73518a54d8925c66b20549.diff

LOG: Revert "[mlir][Vector] NFC - Compress vector to outerproduct lowering."

This reverts commit db188adfb12f6783c5419d5165a1123b9f5b56b0.

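Breaks the tests when built with GCC, likely because of an order-of-evaluation
difference between Clang and GCC: the compressed lowering creates ops inside
function arguments (e.g. `e.outer_prod(e.t(lhs), e.t(rhs), ...)`), and C++
leaves the evaluation order of function arguments unspecified, so the emitted
TransposeOps can appear in a different order depending on the compiler.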

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/VectorTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 39a6c39059fba..1a7d2e80d56f7 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -1816,72 +1816,6 @@ ContractionOpToMatmulOpLowering::matchAndRewrite(vector::ContractionOp op,
   return success();
 }
 
-namespace {
-struct IteratorType {
-  IteratorType(StringRef strRef) : strRef(strRef) {}
-  bool isOfType(Attribute attr) const {
-    auto sAttr = attr.dyn_cast<StringAttr>();
-    return sAttr && sAttr.getValue() == strRef;
-  }
-  StringRef strRef;
-};
-struct Par : public IteratorType {
-  Par() : IteratorType(getParallelIteratorTypeName()) {}
-};
-struct Red : public IteratorType {
-  Red() : IteratorType(getReductionIteratorTypeName()) {}
-};
-
-// Unroll outer-products along reduction.
-struct UnrolledOuterProductEmitter {
-  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
-
-  UnrolledOuterProductEmitter(PatternRewriter &rewriter,
-                              vector::ContractionOp op)
-      : rewriter(rewriter), loc(op.getLoc()), kind(op.kind()),
-        iterators(op.iterator_types()), maps(op.getIndexingMaps()), op(op) {}
-
-  Value t(Value v) {
-    static constexpr std::array<int64_t, 2> perm = {1, 0};
-    return rewriter.create<vector::TransposeOp>(loc, v, perm);
-  }
-
-  bool iters(ArrayRef<IteratorType> its) {
-    if (its.size() != iterators.size())
-      return false;
-    for (int i = 0, e = its.size(); i != e; ++i) {
-      if (!its[i].isOfType(iterators[i]))
-        return false;
-    }
-    return true;
-  }
-
-  bool layout(MapList l) {
-    auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
-    return maps == infer(l);
-  }
-
-  LogicalResult outer_prod(Value lhs, Value rhs, Value res, int reductionSize) {
-    assert(reductionSize > 0);
-    for (int64_t k = 0; k < reductionSize; ++k) {
-      Value a = rewriter.create<vector::ExtractOp>(loc, lhs, k);
-      Value b = rewriter.create<vector::ExtractOp>(loc, rhs, k);
-      res = rewriter.create<vector::OuterProductOp>(loc, res.getType(), a, b,
-                                                    res, kind);
-    }
-    rewriter.replaceOp(op, res);
-    return success();
-  }
-
-  PatternRewriter &rewriter;
-  Location loc;
-  vector::CombiningKind kind;
-  ArrayAttr iterators;
-  SmallVector<AffineMap, 4> maps;
-  Operation *op;
-};
-} // namespace
-
 /// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
 /// semantics to a reduction_size-unrolled sequence:
 /// ```
@@ -1910,64 +1844,104 @@ LogicalResult ContractionOpToOuterProductOpLowering::matchAndRewrite(
   if (failed(filter(op)))
     return failure();
 
+  Location loc = op.getLoc();
+  int64_t reductionSize = 0;
   VectorType lhsType = op.getLhsType();
   Value lhs = op.lhs(), rhs = op.rhs(), res = op.acc();
 
   // Set up the parallel/reduction structure in right form.
+  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
+  auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
   AffineExpr m, n, k;
   bindDims(rewriter.getContext(), m, n, k);
-
-  //
-  // Two outer parallel, one inner reduction (matmat flavor).
-  //
-  UnrolledOuterProductEmitter e(rewriter, op);
-  if (e.iters({Par(), Par(), Red()})) {
-    // Classical row-major matmul:  Just permute the lhs.
-    if (e.layout({{m, k}, {k, n}, {m, n}}))
-      return e.outer_prod(e.t(lhs), rhs, res, lhsType.getDimSize(1));
-    // TODO: may be better to fail and use some vector<k> -> scalar reduction.
-    if (e.layout({{m, k}, {n, k}, {m, n}}))
-      return e.outer_prod(e.t(lhs), e.t(rhs), res, lhsType.getDimSize(1));
-    // No need to permute anything.
-    if (e.layout({{k, m}, {k, n}, {m, n}}))
-      return e.outer_prod(lhs, rhs, res, lhsType.getDimSize(0));
-    // Just permute the rhs.
-    if (e.layout({{k, m}, {n, k}, {m, n}}))
-      return e.outer_prod(lhs, e.t(rhs), res, lhsType.getDimSize(0));
-    // Transposed output: swap RHS and LHS.
-    // Classical row-major matmul: permute the lhs.
-    if (e.layout({{m, k}, {k, n}, {n, m}}))
-      return e.outer_prod(rhs, e.t(lhs), res, lhsType.getDimSize(1));
-    // TODO: may be better to fail and use some vector<k> -> scalar reduction.
-    if (e.layout({{m, k}, {n, k}, {n, m}}))
-      return e.outer_prod(e.t(rhs), e.t(lhs), res, lhsType.getDimSize(1));
-    if (e.layout({{k, m}, {k, n}, {n, m}}))
-      return e.outer_prod(rhs, lhs, res, lhsType.getDimSize(0));
-    if (e.layout({{k, m}, {n, k}, {n, m}}))
-      return e.outer_prod(e.t(rhs), lhs, res, lhsType.getDimSize(0));
+  static constexpr std::array<int64_t, 2> perm = {1, 0};
+  auto iteratorTypes = op.iterator_types().getValue();
+  SmallVector<AffineMap, 4> maps = op.getIndexingMaps();
+  if (isParallelIterator(iteratorTypes[0]) &&
+      isParallelIterator(iteratorTypes[1]) &&
+      isReductionIterator(iteratorTypes[2])) {
+    //
+    // Two outer parallel, one inner reduction (matmat flavor).
+    //
+    if (maps == infer({{m, k}, {k, n}, {m, n}})) {
+      // This is the classical row-major matmul. Just permute the lhs.
+      reductionSize = lhsType.getDimSize(1);
+      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
+    } else if (maps == infer({{m, k}, {n, k}, {m, n}})) {
+      // TODO: may be better to fail and use some vector<k> -> scalar reduction.
+      reductionSize = lhsType.getDimSize(1);
+      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
+      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
+    } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
+      // No need to permute anything.
+      reductionSize = lhsType.getDimSize(0);
+    } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
+      // Just permute the rhs.
+      reductionSize = lhsType.getDimSize(0);
+      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
+    } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
+      // This is the classical row-major matmul. Just permute the lhs.
+      reductionSize = lhsType.getDimSize(1);
+      Value tmp = rhs;
+      rhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
+      lhs = tmp;
+    } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
+      // TODO: may be better to fail and use some vector<k> -> scalar reduction.
+      reductionSize = lhsType.getDimSize(1);
+      Value tmp = rhs;
+      rhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
+      lhs = rewriter.create<vector::TransposeOp>(loc, tmp, perm);
+    } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
+      // No need to permute anything, but still swap lhs and rhs.
+      reductionSize = lhsType.getDimSize(0);
+      std::swap(lhs, rhs);
+    } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
+      // Just permute the rhs.
+      reductionSize = lhsType.getDimSize(0);
+      Value tmp = lhs;
+      lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
+      rhs = tmp;
+    } else {
+      return failure();
+    }
+  } else if (isParallelIterator(iteratorTypes[0]) &&
+             isReductionIterator(iteratorTypes[1])) {
+    //
+    // One outer parallel, one inner reduction (matvec flavor)
+    //
+    if (maps == infer({{m, n}, {n}, {m}})) {
+      // Case mat-vec: transpose.
+      reductionSize = lhsType.getDimSize(1);
+      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
+    } else if (maps == infer({{n, m}, {n}, {m}})) {
+      // Case mat-trans-vec: ready to go.
+      reductionSize = lhsType.getDimSize(0);
+    } else if (maps == infer({{n}, {m, n}, {m}})) {
+      // Case vec-mat: swap and transpose.
+      reductionSize = lhsType.getDimSize(0);
+      std::swap(lhs, rhs);
+      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
+    } else if (maps == infer({{n}, {n, m}, {m}})) {
+      // Case vec-mat-trans: swap and ready to go.
+      reductionSize = lhsType.getDimSize(0);
+      std::swap(lhs, rhs);
+    } else {
+      return failure();
+    }
+  } else {
     return failure();
   }
-
-  //
-  // One outer parallel, one inner reduction (matvec flavor)
-  //
-  if (e.iters({Par(), Red()})) {
-    // Case mat-vec: transpose.
-    if (e.layout({{m, n}, {n}, {m}}))
-      return e.outer_prod(e.t(lhs), rhs, res, lhsType.getDimSize(1));
-    // Case mat-trans-vec: ready to go.
-    if (e.layout({{n, m}, {n}, {m}}))
-      return e.outer_prod(lhs, rhs, res, lhsType.getDimSize(0));
-    // Case vec-mat: swap and transpose.
-    if (e.layout({{n}, {m, n}, {m}}))
-      return e.outer_prod(e.t(rhs), lhs, res, lhsType.getDimSize(0));
-    // Case vec-mat-trans: swap and ready to go.
-    if (e.layout({{n}, {n, m}, {m}}))
-      return e.outer_prod(rhs, lhs, res, lhsType.getDimSize(0));
-    return failure();
+  assert(reductionSize > 0);
+
+  // Unroll outer-products along reduction.
+  for (int64_t k = 0; k < reductionSize; ++k) {
+    Value a = rewriter.create<vector::ExtractOp>(op.getLoc(), lhs, k);
+    Value b = rewriter.create<vector::ExtractOp>(op.getLoc(), rhs, k);
+    res = rewriter.create<vector::OuterProductOp>(op.getLoc(), res.getType(), a,
+                                                  b, res, op.kind());
   }
-
-  return failure();
+  rewriter.replaceOp(op, res);
+  return success();
 }
 
 LogicalResult


        


More information about the Mlir-commits mailing list