[Mlir-commits] [mlir] 6080271 - [mlir][vector] NFC - Refactor and extract a helper StructuredGenerator class

Nicolas Vasilache llvmlistbot at llvm.org
Fri Oct 15 09:02:56 PDT 2021


Author: Nicolas Vasilache
Date: 2021-10-15T16:01:59Z
New Revision: 60802715d1d0f06b49d7d0d3f4b38bd7aae0fe9c

URL: https://github.com/llvm/llvm-project/commit/60802715d1d0f06b49d7d0d3f4b38bd7aae0fe9c
DIFF: https://github.com/llvm/llvm-project/commit/60802715d1d0f06b49d7d0d3f4b38bd7aae0fe9c.diff

LOG: [mlir][vector] NFC - Refactor and extract a helper StructuredGenerator class

Differential Revision: https://reviews.llvm.org/D111893
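
The intent of the extraction is that contraction-lowering patterns read as a sequence of structural queries (`iters`, `layout`) followed by IR emission. Below is a minimal sketch of that usage, not part of this commit: the `MatvecGenerator` class and its `matchRowMajorMatvec` method are hypothetical, while the pieces it relies on (`Par`, `Red`, `iters`, `layout`, and the protected `rewriter`/`ctx` members) are exactly what the header diff further down introduces; include paths assume this revision's tree layout.

```cpp
// Hypothetical illustration (not part of this commit): a generator that
// recognizes one specific contraction shape using the new helpers.
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LogicalResult.h"

using namespace mlir;

namespace {
struct MatvecGenerator : public StructuredGenerator<vector::ContractionOp> {
  MatvecGenerator(PatternRewriter &rewriter, vector::ContractionOp op)
      : StructuredGenerator<vector::ContractionOp>(rewriter, op),
        lhs(op.lhs()), rhs(op.rhs()), acc(op.acc()) {}

  /// Succeeds only on a {parallel, reduction} contraction whose indexing
  /// maps match the classical (m, k) x (k) -> (m) matvec layout.
  LogicalResult matchRowMajorMatvec() {
    if (!iters({Par(), Red()}))
      return failure();
    AffineExpr m, k;
    bindDims(ctx, m, k);
    if (!layout({{m, k}, {k}, {m}}))
      return failure();
    // At this point a real pattern would emit the lowered IR via `rewriter`.
    return success();
  }

  Value lhs, rhs, acc;
};
} // namespace
```

The `UnrolledOuterProductGenerator` in the VectorTransforms.cpp diff below follows exactly this shape, with one method per contraction flavor (matmat, matvec, tmatvec).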

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
    mlir/lib/Dialect/Vector/VectorTransforms.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
index 14f8b29f66892..e318bd5cf3e1a 100644
--- a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
@@ -19,11 +19,14 @@
 
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/Location.h"
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/StringRef.h"
 
 namespace mlir {
 
+class PatternRewriter;
+
 /// Tests whether the given maps describe a row major matmul. The test is
 /// permutation-invariant. Note that this only checks the affine maps from an
 /// operation, so does not perform any checks on the math being performed within
@@ -132,6 +135,60 @@ inline StringRef toString(IteratorType t) {
   llvm_unreachable("Unsupported IteratorType");
 }
 
+/// Helper StructuredGenerator class to manipulate and rewrite ops with
+/// `StructuredOpInterface`. This is templated for now because VectorOps do not
+/// yet implement the StructuredOpInterface itself.
+template <typename StructuredOpInterface>
+class StructuredGenerator {
+public:
+  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
+
+  struct IteratorType {
+    IteratorType(StringRef strRef) : strRef(strRef) {}
+    bool isOfType(Attribute attr) const {
+      auto sAttr = attr.dyn_cast<StringAttr>();
+      return sAttr && sAttr.getValue() == strRef;
+    }
+    StringRef strRef;
+  };
+  struct Par : public IteratorType {
+    Par() : IteratorType(getParallelIteratorTypeName()) {}
+  };
+  struct Red : public IteratorType {
+    Red() : IteratorType(getReductionIteratorTypeName()) {}
+  };
+  struct Win : public IteratorType {
+    Win() : IteratorType(getWindowIteratorTypeName()) {}
+  };
+
+  StructuredGenerator(PatternRewriter &rewriter, StructuredOpInterface op)
+      : rewriter(rewriter), ctx(op.getContext()), loc(op.getLoc()),
+        iterators(op.iterator_types()), maps(op.getIndexingMaps()), op(op) {}
+
+  bool iters(ArrayRef<IteratorType> its) {
+    if (its.size() != iterators.size())
+      return false;
+    for (int i = 0, e = its.size(); i != e; ++i) {
+      if (!its[i].isOfType(iterators[i]))
+        return false;
+    }
+    return true;
+  }
+
+  bool layout(MapList l) {
+    auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
+    return maps == infer(l);
+  }
+
+protected:
+  PatternRewriter &rewriter;
+  MLIRContext *ctx;
+  Location loc;
+  ArrayAttr iterators;
+  SmallVector<AffineMap, 4> maps;
+  Operation *op;
+};
+
 } // end namespace mlir
 
 #endif // MLIR_UTILS_STRUCTUREDOPSUTILS_H
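
A note on the `layout` helper above: it compares the maps captured from `op.getIndexingMaps()` against maps inferred from lists of affine expressions. The free-standing sketch below (hypothetical helper name, headers as of this revision) spells out that comparison for the row-major matmul case:

```cpp
// Sketch of what StructuredGenerator::layout() compares: maps inferred from
// affine-expression lists vs. the indexing maps captured from the op.
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/Support/LLVM.h"

using namespace mlir;

// Hypothetical helper: builds the maps for A(m, k) * B(k, n) -> C(m, n).
static SmallVector<AffineMap, 4> rowMajorMatmulMaps(MLIRContext *ctx) {
  AffineExpr m, n, k;
  bindDims(ctx, m, n, k); // d0 = m, d1 = n, d2 = k
  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
  auto infer = [](MapList l) { return AffineMap::inferFromExprList(l); };
  return infer({{m, k}, {k, n}, {m, n}});
}
// layout({{m, k}, {k, n}, {m, n}}) succeeds exactly when these maps equal
// the protected `maps` member, i.e. the result of op.getIndexingMaps().
```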

diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 2b83c23730aad..d1b9835452b7d 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -1252,35 +1252,22 @@ struct Red : public IteratorType {
   Red() : IteratorType(getReductionIteratorTypeName()) {}
 };
 
-// Unroll outer-products along reduction.
-struct UnrolledOuterProductEmitter {
-  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
+/// Generate a vector implementation for matmat, matvec and tmatvec.
+/// This unrolls outer-products along the reduction dimension.
+struct UnrolledOuterProductGenerator
+    : public StructuredGenerator<vector::ContractionOp> {
 
-  UnrolledOuterProductEmitter(PatternRewriter &rewriter,
-                              vector::ContractionOp op)
-      : rewriter(rewriter), loc(op.getLoc()), kind(op.kind()),
-        iterators(op.iterator_types()), maps(op.getIndexingMaps()), op(op) {}
+  UnrolledOuterProductGenerator(PatternRewriter &rewriter,
+                                vector::ContractionOp op)
+      : StructuredGenerator<vector::ContractionOp>(rewriter, op),
+        kind(op.kind()), lhs(op.lhs()), rhs(op.rhs()), res(op.acc()),
+        lhsType(op.getLhsType()) {}
 
   Value t(Value v) {
     static constexpr std::array<int64_t, 2> perm = {1, 0};
     return rewriter.create<vector::TransposeOp>(loc, v, perm);
   }
 
-  bool iters(ArrayRef<IteratorType> its) {
-    if (its.size() != iterators.size())
-      return false;
-    for (int i = 0, e = its.size(); i != e; ++i) {
-      if (!its[i].isOfType(iterators[i]))
-        return false;
-    }
-    return true;
-  }
-
-  bool layout(MapList l) {
-    auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
-    return maps == infer(l);
-  }
-
   LogicalResult outer_prod(Value lhs, Value rhs, Value res, int reductionSize) {
     assert(reductionSize > 0);
     for (int64_t k = 0; k < reductionSize; ++k) {
@@ -1293,128 +1280,132 @@ struct UnrolledOuterProductEmitter {
     return success();
   }
 
-  PatternRewriter &rewriter;
-  Location loc;
-  vector::CombiningKind kind;
-  ArrayAttr iterators;
-  SmallVector<AffineMap, 4> maps;
-  Operation *op;
-};
-} // namespace
-
-/// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
-/// semantics to a reduction_size-unrolled sequence:
-/// ```
-///    %at = vector.transpose %a, [1, 0]
-///    %bRow0 = vector.extract %b[0]
-///    %atRow0 = vector.extract %at[0]
-///    %c0 = vector.outerproduct %atRow0, %bRow0, %c
-///    ...
-///    %bRowK = vector.extract %b[K]
-///    %atRowK = vector.extract %at[K]
-///    %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
-/// ```
-///
-/// This only kicks in when VectorTransformsOptions is set to OuterProduct but
-/// otherwise supports any layout permutation of the matrix-multiply.
-LogicalResult ContractionOpToOuterProductOpLowering::matchAndRewrite(
-    vector::ContractionOp op, PatternRewriter &rewriter) const {
-  // TODO: implement masks
-  if (llvm::size(op.masks()) != 0)
-    return failure();
-
-  if (vectorTransformOptions.vectorContractLowering !=
-      vector::VectorContractLowering::OuterProduct)
-    return failure();
-
-  if (failed(filter(op)))
-    return failure();
-
-  VectorType lhsType = op.getLhsType();
-  Value lhs = op.lhs(), rhs = op.rhs(), res = op.acc();
-
-  //
-  // Two outer parallel, one inner reduction (matmat flavor).
-  //
-  UnrolledOuterProductEmitter e(rewriter, op);
-  if (e.iters({Par(), Par(), Red()})) {
-    // Set up the parallel/reduction structure in right form.
+  /// Two outer parallel, one inner reduction (matmat flavor).
+  LogicalResult matmat() {
+    if (!iters({Par(), Par(), Red()}))
+      return failure();
+    // Set up the parallel/reduction structure in the right form.
     AffineExpr m, n, k;
     bindDims(rewriter.getContext(), m, n, k);
     // Classical row-major matmul:  Just permute the lhs.
-    if (e.layout({{m, k}, {k, n}, {m, n}}))
-      return e.outer_prod(e.t(lhs), rhs, res, lhsType.getDimSize(1));
+    if (layout({{m, k}, {k, n}, {m, n}}))
+      return outer_prod(t(lhs), rhs, res, lhsType.getDimSize(1));
     // TODO: may be better to fail and use some vector<k> -> scalar reduction.
-    if (e.layout({{m, k}, {n, k}, {m, n}})) {
-      Value tlhs = e.t(lhs);
-      return e.outer_prod(tlhs, e.t(rhs), res, lhsType.getDimSize(1));
+    if (layout({{m, k}, {n, k}, {m, n}})) {
+      Value tlhs = t(lhs);
+      return outer_prod(tlhs, t(rhs), res, lhsType.getDimSize(1));
     }
     // No need to permute anything.
-    if (e.layout({{k, m}, {k, n}, {m, n}}))
-      return e.outer_prod(lhs, rhs, res, lhsType.getDimSize(0));
+    if (layout({{k, m}, {k, n}, {m, n}}))
+      return outer_prod(lhs, rhs, res, lhsType.getDimSize(0));
     // Just permute the rhs.
-    if (e.layout({{k, m}, {n, k}, {m, n}}))
-      return e.outer_prod(lhs, e.t(rhs), res, lhsType.getDimSize(0));
+    if (layout({{k, m}, {n, k}, {m, n}}))
+      return outer_prod(lhs, t(rhs), res, lhsType.getDimSize(0));
     // Transposed output: swap RHS and LHS.
     // Classical row-major matmul: permute the lhs.
-    if (e.layout({{m, k}, {k, n}, {n, m}}))
-      return e.outer_prod(rhs, e.t(lhs), res, lhsType.getDimSize(1));
+    if (layout({{m, k}, {k, n}, {n, m}}))
+      return outer_prod(rhs, t(lhs), res, lhsType.getDimSize(1));
     // TODO: may be better to fail and use some vector<k> -> scalar reduction.
-    if (e.layout({{m, k}, {n, k}, {n, m}})) {
-      Value trhs = e.t(rhs);
-      return e.outer_prod(trhs, e.t(lhs), res, lhsType.getDimSize(1));
+    if (layout({{m, k}, {n, k}, {n, m}})) {
+      Value trhs = t(rhs);
+      return outer_prod(trhs, t(lhs), res, lhsType.getDimSize(1));
     }
-    if (e.layout({{k, m}, {k, n}, {n, m}}))
-      return e.outer_prod(rhs, lhs, res, lhsType.getDimSize(0));
-    if (e.layout({{k, m}, {n, k}, {n, m}}))
-      return e.outer_prod(e.t(rhs), lhs, res, lhsType.getDimSize(0));
+    if (layout({{k, m}, {k, n}, {n, m}}))
+      return outer_prod(rhs, lhs, res, lhsType.getDimSize(0));
+    if (layout({{k, m}, {n, k}, {n, m}}))
+      return outer_prod(t(rhs), lhs, res, lhsType.getDimSize(0));
     return failure();
   }
 
-  //
-  // One outer parallel, one inner reduction (matvec flavor)
-  //
-  if (e.iters({Par(), Red()})) {
+  /// One outer parallel, one inner reduction (matvec flavor)
+  LogicalResult matvec() {
+    if (!iters({Par(), Red()}))
+      return failure();
     AffineExpr m, k;
     bindDims(rewriter.getContext(), m, k);
 
     // Case mat-vec: transpose.
-    if (e.layout({{m, k}, {k}, {m}}))
-      return e.outer_prod(e.t(lhs), rhs, res, lhsType.getDimSize(1));
+    if (layout({{m, k}, {k}, {m}}))
+      return outer_prod(t(lhs), rhs, res, lhsType.getDimSize(1));
     // Case mat-trans-vec: ready to go.
-    if (e.layout({{k, m}, {k}, {m}}))
-      return e.outer_prod(lhs, rhs, res, lhsType.getDimSize(0));
+    if (layout({{k, m}, {k}, {m}}))
+      return outer_prod(lhs, rhs, res, lhsType.getDimSize(0));
     // Case vec-mat: swap and transpose.
-    if (e.layout({{k}, {m, k}, {m}}))
-      return e.outer_prod(e.t(rhs), lhs, res, lhsType.getDimSize(0));
+    if (layout({{k}, {m, k}, {m}}))
+      return outer_prod(t(rhs), lhs, res, lhsType.getDimSize(0));
     // Case vec-mat-trans: swap and ready to go.
-    if (e.layout({{k}, {k, m}, {m}}))
-      return e.outer_prod(rhs, lhs, res, lhsType.getDimSize(0));
+    if (layout({{k}, {k, m}, {m}}))
+      return outer_prod(rhs, lhs, res, lhsType.getDimSize(0));
     return failure();
   }
 
   //
   // One outer reduction, one inner parallel (tmatvec flavor)
   //
-  if (e.iters({Red(), Par()})) {
+  LogicalResult tmatvec() {
+    if (!iters({Red(), Par()}))
+      return failure();
     AffineExpr k, m;
     bindDims(rewriter.getContext(), k, m);
 
     // Case mat-vec: transpose.
-    if (e.layout({{m, k}, {k}, {m}}))
-      return e.outer_prod(e.t(lhs), rhs, res, lhsType.getDimSize(1));
+    if (layout({{m, k}, {k}, {m}}))
+      return outer_prod(t(lhs), rhs, res, lhsType.getDimSize(1));
     // Case mat-trans-vec: ready to go.
-    if (e.layout({{k, m}, {k}, {m}}))
-      return e.outer_prod(lhs, rhs, res, lhsType.getDimSize(0));
+    if (layout({{k, m}, {k}, {m}}))
+      return outer_prod(lhs, rhs, res, lhsType.getDimSize(0));
     // Case vec-mat: swap and transpose.
-    if (e.layout({{k}, {m, k}, {m}}))
-      return e.outer_prod(e.t(rhs), lhs, res, lhsType.getDimSize(0));
+    if (layout({{k}, {m, k}, {m}}))
+      return outer_prod(t(rhs), lhs, res, lhsType.getDimSize(0));
     // Case vec-mat-trans: swap and ready to go.
-    if (e.layout({{k}, {k, m}, {m}}))
-      return e.outer_prod(rhs, lhs, res, lhsType.getDimSize(0));
+    if (layout({{k}, {k, m}, {m}}))
+      return outer_prod(rhs, lhs, res, lhsType.getDimSize(0));
     return failure();
   }
 
+private:
+  vector::CombiningKind kind;
+  Value lhs, rhs, res;
+  VectorType lhsType;
+};
+} // namespace
+
+/// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
+/// semantics to a reduction_size-unrolled sequence:
+/// ```
+///    %at = vector.transpose %a, [1, 0]
+///    %bRow0 = vector.extract %b[0]
+///    %atRow0 = vector.extract %at[0]
+///    %c0 = vector.outerproduct %atRow0, %bRow0, %c
+///    ...
+///    %bRowK = vector.extract %b[K]
+///    %atRowK = vector.extract %at[K]
+///    %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
+/// ```
+///
+/// This only kicks in when VectorTransformsOptions is set to OuterProduct but
+/// otherwise supports any layout permutation of the matrix-multiply.
+LogicalResult ContractionOpToOuterProductOpLowering::matchAndRewrite(
+    vector::ContractionOp op, PatternRewriter &rewriter) const {
+  // TODO: implement masks
+  if (llvm::size(op.masks()) != 0)
+    return failure();
+
+  if (vectorTransformOptions.vectorContractLowering !=
+      vector::VectorContractLowering::OuterProduct)
+    return failure();
+
+  if (failed(filter(op)))
+    return failure();
+
+  UnrolledOuterProductGenerator e(rewriter, op);
+  if (succeeded(e.matmat()))
+    return success();
+  if (succeeded(e.matvec()))
+    return success();
+  if (succeeded(e.tmatvec()))
+    return success();
+
   return failure();
 }
 


        

