[Mlir-commits] [mlir] 05c65dc - [mlir][Vector] Add a VectorUnrollInterface and expose UnrollVectorPattern.

Nicolas Vasilache llvmlistbot at llvm.org
Mon Jul 6 05:10:07 PDT 2020


Author: Nicolas Vasilache
Date: 2020-07-06T08:09:06-04:00
New Revision: 05c65dc0fee4dbb6afdcf76bc1990c46fac06efe

URL: https://github.com/llvm/llvm-project/commit/05c65dc0fee4dbb6afdcf76bc1990c46fac06efe
DIFF: https://github.com/llvm/llvm-project/commit/05c65dc0fee4dbb6afdcf76bc1990c46fac06efe.diff

LOG: [mlir][Vector] Add a VectorUnrollInterface and expose UnrollVectorPattern.

The UnrollVectorPattern can be used in a programmable fashion as follows:
```
OwningRewritePatternList patterns;
patterns.insert<UnrollVectorPattern<AddFOp>>(ArrayRef<int64_t>{2, 2}, ctx);
patterns.insert<UnrollVectorPattern<vector::ContractionOp>>(
    ArrayRef<int64_t>{2, 2, 2}, ctx);
...
applyPatternsAndFoldGreedily(getFunction(), patterns);
```

Differential revision: https://reviews.llvm.org/D83064

Added: 
    mlir/include/mlir/Interfaces/VectorUnrollInterface.h
    mlir/include/mlir/Interfaces/VectorUnrollInterface.td
    mlir/lib/Interfaces/VectorUnrollInterface.cpp

Modified: 
    mlir/docs/OpDefinitions.md
    mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
    mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
    mlir/include/mlir/Dialect/Vector/VectorOps.h
    mlir/include/mlir/Dialect/Vector/VectorOps.td
    mlir/include/mlir/Dialect/Vector/VectorTransformPatterns.td
    mlir/include/mlir/Dialect/Vector/VectorTransforms.h
    mlir/include/mlir/Interfaces/CMakeLists.txt
    mlir/lib/Dialect/StandardOps/CMakeLists.txt
    mlir/lib/Dialect/Vector/CMakeLists.txt
    mlir/lib/Dialect/Vector/VectorOps.cpp
    mlir/lib/Dialect/Vector/VectorTransforms.cpp
    mlir/lib/Interfaces/CMakeLists.txt
    mlir/test/Dialect/Vector/vector-transforms.mlir
    mlir/test/lib/Transforms/TestVectorTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md
index 3416b456e093..01dcb722ad07 100644
--- a/mlir/docs/OpDefinitions.md
+++ b/mlir/docs/OpDefinitions.md
@@ -444,7 +444,7 @@ def MyInterface : OpInterface<"MyInterface"> {
     // Note: `ConcreteOp` corresponds to the derived operation typename.
     InterfaceMethod<"/*insert doc here*/",
       "unsigned", "getNumWithDefault", (ins), /*methodBody=*/[{}], [{
-        ConcreteOp op = cast<ConcreteOp>(getOperation());
+        ConcreteOp op = cast<ConcreteOp>(this->getOperation());
         return op.getNumInputs() + op.getNumOutputs();
     }]>,
   ];

diff  --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
index 8005ecbbdc49..7599988bdefc 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
@@ -21,6 +21,7 @@
 #include "mlir/Interfaces/CallInterfaces.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Interfaces/VectorUnrollInterface.h"
 #include "mlir/Interfaces/ViewLikeInterface.h"
 
 // Pull in all enum type definitions and utility function declarations.

diff  --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index 8440b9b3d60b..2019db4a956f 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -17,6 +17,7 @@ include "mlir/IR/OpAsmInterface.td"
 include "mlir/Interfaces/CallInterfaces.td"
 include "mlir/Interfaces/ControlFlowInterfaces.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/Interfaces/VectorUnrollInterface.td"
 include "mlir/Interfaces/ViewLikeInterface.td"
 
 def StandardOps_Dialect : Dialect {
@@ -82,7 +83,9 @@ class UnaryOpSameOperandAndResultType<string mnemonic,
 }
 
 class FloatUnaryOp<string mnemonic, list<OpTrait> traits = []> :
-    UnaryOpSameOperandAndResultType<mnemonic, traits>,
+    UnaryOpSameOperandAndResultType<mnemonic,
+      !listconcat(traits,
+                  [DeclareOpInterfaceMethods<VectorUnrollOpInterface>])>,
     Arguments<(ins FloatLike:$operand)>;
 
 // Base class for standard arithmetic operations.  Requires operands and
@@ -112,7 +115,9 @@ class ArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
 //     <op>i %0, %1 : i32
 //
 class IntArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
-    ArithmeticOp<mnemonic, traits>,
+    ArithmeticOp<mnemonic,
+      !listconcat(traits,
+                  [DeclareOpInterfaceMethods<VectorUnrollOpInterface>])>,
     Arguments<(ins SignlessIntegerLike:$lhs, SignlessIntegerLike:$rhs)>;
 
 // Base class for standard arithmetic binary operations on floats, vectors and
@@ -125,7 +130,9 @@ class IntArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
 //     <op>f %0, %1 : f32
 //
 class FloatArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
-    ArithmeticOp<mnemonic, traits>,
+    ArithmeticOp<mnemonic,
+      !listconcat(traits,
+                  [DeclareOpInterfaceMethods<VectorUnrollOpInterface>])>,
     Arguments<(ins FloatLike:$lhs, FloatLike:$rhs)>;
 
 // Base class for standard arithmetic operations on complex numbers with a

diff  --git a/mlir/include/mlir/Dialect/Vector/VectorOps.h b/mlir/include/mlir/Dialect/Vector/VectorOps.h
index dd79b2986963..29c320903aec 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.h
@@ -1,4 +1,4 @@
-//===- VectorOps.h - MLIR Super Vectorizer Operations -----------*- C++ -*-===//
+//===- VectorOps.h - MLIR Vector Dialect Operations -------------*- C++ -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -19,6 +19,7 @@
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/StandardTypes.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Interfaces/VectorUnrollInterface.h"
 
 namespace mlir {
 class MLIRContext;

diff  --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
index 70ee272c8cef..8ca9baf2e0d0 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -15,6 +15,7 @@
 
 include "mlir/Dialect/Affine/IR/AffineOpsBase.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/Interfaces/VectorUnrollInterface.td"
 
 def Vector_Dialect : Dialect {
   let name = "vector";
@@ -39,10 +40,13 @@ class Vector_Op<string mnemonic, list<OpTrait> traits = []> :
 // TODO(andydavis, ntv) Add an attribute to specify a different algebra
 // with operators other than the current set: {*, +}.
 def Vector_ContractionOp :
-  Vector_Op<"contract", [NoSideEffect,
-     PredOpTrait<"lhs and rhs have same element type", TCopVTEtIsSameAs<0, 1>>,
-     PredOpTrait<"third operand acc and result have same element type",
-                 TCresVTEtIsSameAsOpBase<0, 2>>]>,
+  Vector_Op<"contract", [
+      NoSideEffect,
+      PredOpTrait<"lhs and rhs have same element type", TCopVTEtIsSameAs<0, 1>>,
+      PredOpTrait<"third operand acc and result have same element type",
+                  TCresVTEtIsSameAsOpBase<0, 2>>,
+      DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
+    ]>,
     Arguments<(ins AnyVector:$lhs, AnyVector:$rhs, AnyType:$acc,
                Variadic<VectorOf<[I1]>>:$masks,
                AffineMapArrayAttr:$indexing_maps, ArrayAttr:$iterator_types)>,
@@ -896,7 +900,9 @@ def Vector_TransferOpUtils {
 }
 
 def Vector_TransferReadOp :
-  Vector_Op<"transfer_read">,
+  Vector_Op<"transfer_read", [
+      DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
+    ]>,
     Arguments<(ins AnyMemRef:$memref, Variadic<Index>:$indices,
                AffineMapAttr:$permutation_map, AnyType:$padding,
                OptionalAttr<BoolArrayAttr>:$masked)>,
@@ -1068,7 +1074,9 @@ def Vector_TransferReadOp :
 }
 
 def Vector_TransferWriteOp :
-  Vector_Op<"transfer_write">,
+  Vector_Op<"transfer_write", [
+      DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
+    ]>,
     Arguments<(ins AnyVector:$vector, AnyMemRef:$memref,
                Variadic<Index>:$indices,
                AffineMapAttr:$permutation_map,

diff  --git a/mlir/include/mlir/Dialect/Vector/VectorTransformPatterns.td b/mlir/include/mlir/Dialect/Vector/VectorTransformPatterns.td
index 5f5c90521a7d..ef8118ec6470 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorTransformPatterns.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorTransformPatterns.td
@@ -20,7 +20,7 @@ class HasShape<list<int> shape> :
     StrJoinInt<shape>.result # "})">;
 
 class UnrollVectorOp<list<int> factors> : NativeCodeCall<
-  "unrollSingleResultOpMatchingType($_builder, $0.getDefiningOp(), " #
+  "unrollSingleResultVectorOp($_builder, $0.getDefiningOp(), " #
     "{" # StrJoinInt<factors>.result # "})">;
 
 #endif // VECTOR_TRANSFORM_PATTERNS

diff  --git a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
index 1864d45ac552..ab69a8246587 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
@@ -10,6 +10,8 @@
 #define DIALECT_VECTOR_VECTORTRANSFORMS_H_
 
 #include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/Dialect/Vector/VectorUtils.h"
+#include "mlir/IR/Function.h"
 #include "mlir/IR/PatternMatch.h"
 
 namespace mlir {
@@ -25,42 +27,82 @@ void populateVectorToVectorConversionPatterns(
 
 namespace vector {
 
-// Entry point for unrolling declarative pattern rewrites.
-// `op` is unrolled to the `targetShape` as follows, for each of its operands:
-//   1. the unrolled type `unrolledVectorType` and number of unrolled instances
-//   `numUnrolledInstances` are computed from the `targetShape`. For now it is
-//   assumed the unrolling factors divide the vector sizes.
-//   2. a fakeFork cast op is inserted that takes the operand and returns
-//   `numUnrolledInstances` results of type `unrolledVectorType`.
-//   3. the original op is cloned `numUnrolledInstances` times, once for each
-//   result of the fakeFork cast op.
-//   4. a fakeJoin cast op takes all these results and merges them into a single
-//   aggregate vector result whose size matches the original non-unrolled op
-//   operand types.
-//
-// Example:
-//
-//    opA(operand0, operand1)  // numUnrolledInstances = 3
-//
-//            operand0                   operand1
-//               |                          |
-//             fork                       fork
-//        <----------gather all fork ops --------->
-//              /|\                        /|\
-//          f00 f01 f02                f10 f11 f12
-//        <---------- clone op 3 times --------->
-//          opA0(f00, f10), opA1(f01, f11), opA2(f02, f12)
-//                 \            |            /
-//      <-------------------- join ------------------------->
-//
-// Other local patterns then kick in iteratively (including DCE) and compose
-// until all the fakeFork and fakeJoin ops are removed.
-//
-// This will be extended in the future to support more advanced use cases than
-// simple pointwise ops.
-SmallVector<Value, 1>
-unrollSingleResultOpMatchingType(OpBuilder &builder, Operation *op,
-                                 ArrayRef<int64_t> targetShape);
+/// Entry point for unrolling declarative pattern rewrites.
+/// `op` is unrolled to the `targetShape` as follows, for each of its operands:
+///   1. the unrolled type `unrolledVectorType` and number of unrolled instances
+///   `numUnrolledInstances` are computed from the `targetShape`. For now it is
+///   assumed the unrolling factors divide the vector sizes.
+///   2. a fakeFork cast op is inserted that takes the operand and returns
+///   `numUnrolledInstances` results of type `unrolledVectorType`.
+///   3. the original op is cloned `numUnrolledInstances` times, once for each
+///   result of the fakeFork cast op.
+///   4. a fakeJoin cast op takes all these results and merges them into a
+///   single aggregate vector result whose size matches the original
+///   non-unrolled op operand types.
+///
+/// Example:
+///
+///    opA(operand0, operand1)  // numUnrolledInstances = 3
+///
+///            operand0                   operand1
+///               |                          |
+///             fork                       fork
+///        <----------gather all fork ops --------->
+///              /|\                        /|\
+///          f00 f01 f02                f10 f11 f12
+///        <---------- clone op 3 times --------->
+///          opA0(f00, f10), opA1(f01, f11), opA2(f02, f12)
+///                 \            |            /
+///      <-------------------- join ------------------------->
+///
+/// Other local patterns then kick in iteratively (including DCE) and compose
+/// until all the fakeFork and fakeJoin ops are removed.
+///
+/// This will be extended in the future to support more advanced use cases than
+/// simple pointwise ops.
+SmallVector<Value, 1> unrollSingleResultVectorOp(OpBuilder &builder,
+                                                 Operation *op,
+                                                 ArrayRef<int64_t> targetShape);
+
+/// Pattern to apply `unrollSingleResultVectorOp` to a `targetShape`
+/// declaratively.
+template <typename OpTy>
+struct UnrollVectorPattern : public OpRewritePattern<OpTy> {
+  using FilterConstraintType = std::function<LogicalResult(OpTy op)>;
+  UnrollVectorPattern(
+      ArrayRef<int64_t> targetShape, MLIRContext *context,
+      FilterConstraintType constraint = [](OpTy op) { return success(); })
+      : OpRewritePattern<OpTy>(context),
+        targetShape(targetShape.begin(), targetShape.end()),
+        filter(constraint) {}
+  LogicalResult matchAndRewrite(OpTy op,
+                                PatternRewriter &rewriter) const override {
+    if (failed(filter(op)))
+      return failure();
+    auto unrollableVectorOp =
+        dyn_cast<VectorUnrollOpInterface>(op.getOperation());
+    if (!unrollableVectorOp)
+      return failure();
+    auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll();
+    if (!maybeUnrollShape)
+      return failure();
+    auto maybeShapeRatio = shapeRatio(*maybeUnrollShape, targetShape);
+    if (!maybeShapeRatio ||
+        llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; }))
+      return failure();
+    if (op.getOperation()->getNumResults() != 1)
+      return failure();
+    auto resultVector = unrollSingleResultVectorOp(rewriter, op, targetShape);
+    if (resultVector.size() != 1)
+      return failure();
+    rewriter.replaceOp(op, resultVector.front());
+    return success();
+  }
+
+private:
+  SmallVector<int64_t, 4> targetShape;
+  FilterConstraintType filter;
+};
 
 } // namespace vector
 

diff  --git a/mlir/include/mlir/Interfaces/CMakeLists.txt b/mlir/include/mlir/Interfaces/CMakeLists.txt
index 51f3f8ac1be6..0de2b5a8688b 100644
--- a/mlir/include/mlir/Interfaces/CMakeLists.txt
+++ b/mlir/include/mlir/Interfaces/CMakeLists.txt
@@ -5,5 +5,6 @@ add_mlir_interface(DerivedAttributeOpInterface)
 add_mlir_interface(InferTypeOpInterface)
 add_mlir_interface(LoopLikeInterface)
 add_mlir_interface(SideEffectInterfaces)
+add_mlir_interface(VectorUnrollInterface)
 add_mlir_interface(ViewLikeInterface)
 

diff  --git a/mlir/include/mlir/Interfaces/VectorUnrollInterface.h b/mlir/include/mlir/Interfaces/VectorUnrollInterface.h
new file mode 100644
index 000000000000..a1cf39c17ebe
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/VectorUnrollInterface.h
@@ -0,0 +1,26 @@
+//===- VectorUnrollInterface.h - Vector unrolling interface ---------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the operation interface for vector ops that can be
+// unrolled.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_VECTORUNROLLINTERFACE_H
+#define MLIR_INTERFACES_VECTORUNROLLINTERFACE_H
+
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/StandardTypes.h"
+
+namespace mlir {
+
+#include "mlir/Interfaces/VectorUnrollInterface.h.inc"
+
+} // namespace mlir
+
+#endif // MLIR_INTERFACES_VECTORUNROLLINTERFACE_H

diff  --git a/mlir/include/mlir/Interfaces/VectorUnrollInterface.td b/mlir/include/mlir/Interfaces/VectorUnrollInterface.td
new file mode 100644
index 000000000000..b9cff8bdab1d
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/VectorUnrollInterface.td
@@ -0,0 +1,45 @@
+//===- VectorUnrollInterface.td - VectorUnroll interface ---*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Defines the interface for operations on vectors that can be unrolled.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_VECTORUNROLLINTERFACE
+#define MLIR_INTERFACES_VECTORUNROLLINTERFACE
+
+include "mlir/IR/OpBase.td"
+
+def VectorUnrollOpInterface : OpInterface<"VectorUnrollOpInterface"> {
+  let description = [{
+    Encodes properties of an operation on vectors that can be unrolled.
+  }];
+
+  let methods = [
+    InterfaceMethod<[{
+        Returns the vector shape over which the op is unrolled (e.g. the
+        iteration bounds of a contraction). Returns `None` if the op has no
+        vector shape suitable for unrolling.
+      }],
+      "Optional<SmallVector<int64_t, 4>>",
+      "getShapeForUnroll",
+      (ins),
+      /*methodBody=*/[{}],
+      [{
+        auto vt = this->getOperation()->getResult(0).getType().
+          template dyn_cast<VectorType>();
+        if (!vt)
+          return None;
+        SmallVector<int64_t, 4> res(vt.getShape().begin(), vt.getShape().end());
+        return res;
+      }]
+    >,
+  ];
+}
+
+#endif // MLIR_INTERFACES_VECTORUNROLLINTERFACE

diff  --git a/mlir/lib/Dialect/StandardOps/CMakeLists.txt b/mlir/lib/Dialect/StandardOps/CMakeLists.txt
index f3b93d6013ce..7d61aea3116e 100644
--- a/mlir/lib/Dialect/StandardOps/CMakeLists.txt
+++ b/mlir/lib/Dialect/StandardOps/CMakeLists.txt
@@ -15,6 +15,7 @@ add_mlir_dialect_library(MLIRStandardOps
   MLIREDSC
   MLIRIR
   MLIRSideEffectInterfaces
+  MLIRVectorUnrollInterface
   MLIRViewLikeInterface
   )
 

diff  --git a/mlir/lib/Dialect/Vector/CMakeLists.txt b/mlir/lib/Dialect/Vector/CMakeLists.txt
index 7a5ed49cd9ce..69a329917228 100644
--- a/mlir/lib/Dialect/Vector/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/CMakeLists.txt
@@ -19,4 +19,5 @@ add_mlir_dialect_library(MLIRVector
   MLIRSCF
   MLIRLoopAnalysis
   MLIRSideEffectInterfaces
+  MLIRVectorUnrollInterface
   )

diff  --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 5d3a916d02ea..184aed2ee1cd 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -469,6 +469,12 @@ SmallVector<AffineMap, 4> ContractionOp::getIndexingMaps() {
   return res;
 }
 
+Optional<SmallVector<int64_t, 4>> ContractionOp::getShapeForUnroll() {
+  SmallVector<int64_t, 4> shape;
+  getIterationBounds(shape);
+  return shape;
+}
+
 //===----------------------------------------------------------------------===//
 // ExtractElementOp
 //===----------------------------------------------------------------------===//
@@ -1522,6 +1528,11 @@ OpFoldResult TransferReadOp::fold(ArrayRef<Attribute>) {
   return OpFoldResult();
 }
 
+Optional<SmallVector<int64_t, 4>> TransferReadOp::getShapeForUnroll() {
+  auto s = getVectorType().getShape();
+  return SmallVector<int64_t, 4>{s.begin(), s.end()};
+}
+
 //===----------------------------------------------------------------------===//
 // TransferWriteOp
 //===----------------------------------------------------------------------===//
@@ -1612,6 +1623,11 @@ LogicalResult TransferWriteOp::fold(ArrayRef<Attribute>,
   return foldMemRefCast(*this);
 }
 
+Optional<SmallVector<int64_t, 4>> TransferWriteOp::getShapeForUnroll() {
+  auto s = getVectorType().getShape();
+  return SmallVector<int64_t, 4>{s.begin(), s.end()};
+}
+
 //===----------------------------------------------------------------------===//
 // ShapeCastOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index b841580433f9..c7cf2937939c 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -30,6 +30,7 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/IR/Types.h"
+#include "mlir/Interfaces/VectorUnrollInterface.h"
 
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
@@ -357,7 +358,7 @@ struct VectorState {
 //    (removable with DCE).
 
 // TODO(andydavis) Generalize this to support structured ops beyond
-// vector ContractionOp, and merge it with 'unrollSingleResultOpMatchingType'
+// vector ContractionOp, and merge it with 'unrollSingleResultVectorOp'
 static Value unrollSingleResultStructuredOp(Operation *op,
                                             ArrayRef<int64_t> iterationBounds,
                                             std::vector<VectorState> &vectors,
@@ -450,11 +451,7 @@ static Value unrollSingleResultStructuredOp(Operation *op,
 
 static void getVectorContractionOpUnrollState(
     vector::ContractionOp contractionOp, ArrayRef<int64_t> targetShape,
-    SmallVectorImpl<int64_t> &iterationBounds,
     std::vector<VectorState> &vectors, unsigned &resultIndex) {
-  // Get contraction op iteration bounds.
-  contractionOp.getIterationBounds(iterationBounds);
-  assert(iterationBounds.size() == targetShape.size());
   // Get map from iteration space index to lhs/rhs/result shape index.
   std::vector<DenseMap<int64_t, int64_t>> iterationIndexMapList;
   contractionOp.getIterationIndexMap(iterationIndexMapList);
@@ -476,17 +473,15 @@ static void getVectorContractionOpUnrollState(
     vectors.push_back({contractionOp.getRHSVectorMaskType(),
                        vectors[1].indexMap, accOperandIndex + 2, false});
   }
-  // Unroll 'op' 'iterationBounds' to 'targetShape'.
   // TODO(andydavis) Use linalg style 'args_in'/'args_out' to partition
   // 'vectors' instead of 'resultIndex'.
   resultIndex = accOperandIndex;
 }
 
-static void
-getVectorElementwiseOpUnrollState(Operation *op, ArrayRef<int64_t> targetShape,
-                                  SmallVectorImpl<int64_t> &iterationBounds,
-                                  std::vector<VectorState> &vectors,
-                                  unsigned &resultIndex) {
+static void getVectorElementwiseOpUnrollState(Operation *op,
+                                              ArrayRef<int64_t> targetShape,
+                                              std::vector<VectorState> &vectors,
+                                              unsigned &resultIndex) {
   // Verify that operation and operands all have the same vector shape.
   auto resultType = op->getResult(0).getType().dyn_cast_or_null<VectorType>();
   assert(resultType && "Expected op with vector result type");
@@ -494,8 +489,6 @@ getVectorElementwiseOpUnrollState(Operation *op, ArrayRef<int64_t> targetShape,
   // Verify that all operands have the same vector type as result.
   assert(llvm::all_of(op->getOperandTypes(),
                       [=](Type type) { return type == resultType; }));
-  // Populate 'iterationBounds' with 'resultShape' for elementwise operations.
-  iterationBounds.assign(resultShape.begin(), resultShape.end());
 
   // Create trivial elementwise identity index map based on 'resultShape'.
   DenseMap<int64_t, int64_t> indexMap;
@@ -513,28 +506,32 @@ getVectorElementwiseOpUnrollState(Operation *op, ArrayRef<int64_t> targetShape,
 }
 
 // Entry point for unrolling declarative pattern rewrites.
-SmallVector<Value, 1> mlir::vector::unrollSingleResultOpMatchingType(
-    OpBuilder &builder, Operation *op, ArrayRef<int64_t> targetShape) {
+SmallVector<Value, 1>
+mlir::vector::unrollSingleResultVectorOp(OpBuilder &builder, Operation *op,
+                                         ArrayRef<int64_t> targetShape) {
   assert(op->getNumResults() == 1 && "Expected single result operation");
 
   // Populate 'iterationBounds', 'vectors' and 'resultIndex' to unroll 'op'.
   SmallVector<int64_t, 6> iterationBounds;
+  auto unrollableVectorOp = cast<VectorUnrollOpInterface>(op);
+  auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll();
+  assert(maybeUnrollShape && "Trying to unroll an incorrect vector op");
+
   std::vector<VectorState> vectors;
   unsigned resultIndex;
 
   if (auto contractionOp = dyn_cast<vector::ContractionOp>(op)) {
     // Populate state for vector ContractionOp.
-    getVectorContractionOpUnrollState(contractionOp, targetShape,
-                                      iterationBounds, vectors, resultIndex);
+    getVectorContractionOpUnrollState(contractionOp, targetShape, vectors,
+                                      resultIndex);
   } else {
     // Populate state for vector elementwise op.
-    getVectorElementwiseOpUnrollState(op, targetShape, iterationBounds, vectors,
-                                      resultIndex);
+    getVectorElementwiseOpUnrollState(op, targetShape, vectors, resultIndex);
   }
 
   // Unroll 'op' with 'iterationBounds' to 'targetShape'.
   return SmallVector<Value, 1>{unrollSingleResultStructuredOp(
-      op, iterationBounds, vectors, resultIndex, targetShape, builder)};
+      op, *maybeUnrollShape, vectors, resultIndex, targetShape, builder)};
 }
 
 /// Generates slices of 'vectorType' according to 'sizes' and 'strides, and

diff  --git a/mlir/lib/Interfaces/CMakeLists.txt b/mlir/lib/Interfaces/CMakeLists.txt
index 19b4e0af626d..b8498e224f25 100644
--- a/mlir/lib/Interfaces/CMakeLists.txt
+++ b/mlir/lib/Interfaces/CMakeLists.txt
@@ -6,6 +6,7 @@ set(LLVM_OPTIONAL_SOURCES
   InferTypeOpInterface.cpp
   LoopLikeInterface.cpp
   SideEffectInterfaces.cpp
+  VectorUnrollInterface.cpp
   ViewLikeInterface.cpp
   )
 
@@ -32,5 +33,6 @@ add_mlir_interface_library(DerivedAttributeOpInterface)
 add_mlir_interface_library(InferTypeOpInterface)
 add_mlir_interface_library(LoopLikeInterface)
 add_mlir_interface_library(SideEffectInterfaces)
+add_mlir_interface_library(VectorUnrollInterface)
 add_mlir_interface_library(ViewLikeInterface)
 

diff  --git a/mlir/lib/Interfaces/VectorUnrollInterface.cpp b/mlir/lib/Interfaces/VectorUnrollInterface.cpp
new file mode 100644
index 000000000000..6d3d432a7061
--- /dev/null
+++ b/mlir/lib/Interfaces/VectorUnrollInterface.cpp
@@ -0,0 +1,18 @@
+//===- VectorUnrollInterface.cpp - Unrollable vector operations -*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Interfaces/VectorUnrollInterface.h"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// VectorUnroll Interfaces
+//===----------------------------------------------------------------------===//
+
+/// Include the definitions of the VectorUnroll interfaces.
+#include "mlir/Interfaces/VectorUnrollInterface.cpp.inc"

diff  --git a/mlir/test/Dialect/Vector/vector-transforms.mlir b/mlir/test/Dialect/Vector/vector-transforms.mlir
index 8de153adf731..0bd6c3c43b59 100644
--- a/mlir/test/Dialect/Vector/vector-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-transforms.mlir
@@ -1,4 +1,5 @@
 // RUN: mlir-opt %s -test-vector-to-vector-conversion | FileCheck %s
+// RUN: mlir-opt %s -test-vector-unrolling-patterns | FileCheck %s
 
 // CHECK-DAG: #[[MAP0:map[0-9]+]] = affine_map<(d0, d1) -> (d0, d1)>
 // CHECK-DAG: #[[MAP1:map[0-9]+]] = affine_map<(d0, d1, d2) -> (d1, d2)>

diff  --git a/mlir/test/lib/Transforms/TestVectorTransforms.cpp b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
index c6cf45e824d7..1af6c3564b80 100644
--- a/mlir/test/lib/Transforms/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
@@ -92,6 +92,20 @@ struct TestVectorContractionConversion
   }
 };
 
+struct TestVectorUnrollingPatterns
+    : public PassWrapper<TestVectorUnrollingPatterns, FunctionPass> {
+  void runOnFunction() override {
+    MLIRContext *ctx = &getContext();
+    OwningRewritePatternList patterns;
+    patterns.insert<UnrollVectorPattern<AddFOp>>(ArrayRef<int64_t>{2, 2}, ctx);
+    patterns.insert<UnrollVectorPattern<vector::ContractionOp>>(
+        ArrayRef<int64_t>{2, 2, 2}, ctx);
+    populateVectorToVectorCanonicalizationPatterns(patterns, ctx);
+    populateVectorToVectorTransformationPatterns(patterns, ctx);
+    applyPatternsAndFoldGreedily(getFunction(), patterns);
+  }
+};
+
 } // end anonymous namespace
 
 namespace mlir {
@@ -107,5 +121,9 @@ void registerTestVectorConversions() {
   PassRegistration<TestVectorContractionConversion> contractionPass(
       "test-vector-contraction-conversion",
       "Test conversion patterns that lower contract ops in the vector dialect");
+
+  PassRegistration<TestVectorUnrollingPatterns> contractionUnrollingPass(
+      "test-vector-unrolling-patterns",
+      "Test conversion patterns to unroll contract ops in the vector dialect");
 }
 } // namespace mlir


        


More information about the Mlir-commits mailing list