[Mlir-commits] [mlir] 05c65dc - [mlir][Vector] Add a VectorUnrollInterface and expose UnrollVectorPattern.
Nicolas Vasilache
llvmlistbot at llvm.org
Mon Jul 6 05:10:07 PDT 2020
Author: Nicolas Vasilache
Date: 2020-07-06T08:09:06-04:00
New Revision: 05c65dc0fee4dbb6afdcf76bc1990c46fac06efe
URL: https://github.com/llvm/llvm-project/commit/05c65dc0fee4dbb6afdcf76bc1990c46fac06efe
DIFF: https://github.com/llvm/llvm-project/commit/05c65dc0fee4dbb6afdcf76bc1990c46fac06efe.diff
LOG: [mlir][Vector] Add a VectorUnrollInterface and expose UnrollVectorPattern.
The UnrollVectorPattern can be used in a programmable fashion as follows:
```
OwningRewritePatternList patterns;
patterns.insert<UnrollVectorPattern<AddFOp>>(ArrayRef<int64_t>{2, 2}, ctx);
patterns.insert<UnrollVectorPattern<vector::ContractionOp>>(
ArrayRef<int64_t>{2, 2, 2}, ctx);
...
applyPatternsAndFoldGreedily(getFunction(), patterns);
```
Differential revision: https://reviews.llvm.org/D83064
Added:
mlir/include/mlir/Interfaces/VectorUnrollInterface.h
mlir/include/mlir/Interfaces/VectorUnrollInterface.td
mlir/lib/Interfaces/VectorUnrollInterface.cpp
Modified:
mlir/docs/OpDefinitions.md
mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
mlir/include/mlir/Dialect/Vector/VectorOps.h
mlir/include/mlir/Dialect/Vector/VectorOps.td
mlir/include/mlir/Dialect/Vector/VectorTransformPatterns.td
mlir/include/mlir/Dialect/Vector/VectorTransforms.h
mlir/include/mlir/Interfaces/CMakeLists.txt
mlir/lib/Dialect/StandardOps/CMakeLists.txt
mlir/lib/Dialect/Vector/CMakeLists.txt
mlir/lib/Dialect/Vector/VectorOps.cpp
mlir/lib/Dialect/Vector/VectorTransforms.cpp
mlir/lib/Interfaces/CMakeLists.txt
mlir/test/Dialect/Vector/vector-transforms.mlir
mlir/test/lib/Transforms/TestVectorTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md
index 3416b456e093..01dcb722ad07 100644
--- a/mlir/docs/OpDefinitions.md
+++ b/mlir/docs/OpDefinitions.md
@@ -444,7 +444,7 @@ def MyInterface : OpInterface<"MyInterface"> {
// Note: `ConcreteOp` corresponds to the derived operation typename.
InterfaceMethod<"/*insert doc here*/",
"unsigned", "getNumWithDefault", (ins), /*methodBody=*/[{}], [{
- ConcreteOp op = cast<ConcreteOp>(getOperation());
+ ConcreteOp op = cast<ConcreteOp>(this->getOperation());
return op.getNumInputs() + op.getNumOutputs();
}]>,
];
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
index 8005ecbbdc49..7599988bdefc 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
@@ -21,6 +21,7 @@
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Interfaces/VectorUnrollInterface.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
// Pull in all enum type definitions and utility function declarations.
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index 8440b9b3d60b..2019db4a956f 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -17,6 +17,7 @@ include "mlir/IR/OpAsmInterface.td"
include "mlir/Interfaces/CallInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/Interfaces/VectorUnrollInterface.td"
include "mlir/Interfaces/ViewLikeInterface.td"
def StandardOps_Dialect : Dialect {
@@ -82,7 +83,9 @@ class UnaryOpSameOperandAndResultType<string mnemonic,
}
class FloatUnaryOp<string mnemonic, list<OpTrait> traits = []> :
- UnaryOpSameOperandAndResultType<mnemonic, traits>,
+ UnaryOpSameOperandAndResultType<mnemonic,
+ !listconcat(traits,
+ [DeclareOpInterfaceMethods<VectorUnrollOpInterface>])>,
Arguments<(ins FloatLike:$operand)>;
// Base class for standard arithmetic operations. Requires operands and
@@ -112,7 +115,9 @@ class ArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
// <op>i %0, %1 : i32
//
class IntArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
- ArithmeticOp<mnemonic, traits>,
+ ArithmeticOp<mnemonic,
+ !listconcat(traits,
+ [DeclareOpInterfaceMethods<VectorUnrollOpInterface>])>,
Arguments<(ins SignlessIntegerLike:$lhs, SignlessIntegerLike:$rhs)>;
// Base class for standard arithmetic binary operations on floats, vectors and
@@ -125,7 +130,9 @@ class IntArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
// <op>f %0, %1 : f32
//
class FloatArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
- ArithmeticOp<mnemonic, traits>,
+ ArithmeticOp<mnemonic,
+ !listconcat(traits,
+ [DeclareOpInterfaceMethods<VectorUnrollOpInterface>])>,
Arguments<(ins FloatLike:$lhs, FloatLike:$rhs)>;
// Base class for standard arithmetic operations on complex numbers with a
diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.h b/mlir/include/mlir/Dialect/Vector/VectorOps.h
index dd79b2986963..29c320903aec 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.h
@@ -1,4 +1,4 @@
-//===- VectorOps.h - MLIR Super Vectorizer Operations -----------*- C++ -*-===//
+//===- VectorOps.h - MLIR Vector Dialect Operations -------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -19,6 +19,7 @@
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Interfaces/VectorUnrollInterface.h"
namespace mlir {
class MLIRContext;
diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
index 70ee272c8cef..8ca9baf2e0d0 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -15,6 +15,7 @@
include "mlir/Dialect/Affine/IR/AffineOpsBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/Interfaces/VectorUnrollInterface.td"
def Vector_Dialect : Dialect {
let name = "vector";
@@ -39,10 +40,13 @@ class Vector_Op<string mnemonic, list<OpTrait> traits = []> :
// TODO(andydavis, ntv) Add an attribute to specify a different algebra
// with operators other than the current set: {*, +}.
def Vector_ContractionOp :
- Vector_Op<"contract", [NoSideEffect,
- PredOpTrait<"lhs and rhs have same element type", TCopVTEtIsSameAs<0, 1>>,
- PredOpTrait<"third operand acc and result have same element type",
- TCresVTEtIsSameAsOpBase<0, 2>>]>,
+ Vector_Op<"contract", [
+ NoSideEffect,
+ PredOpTrait<"lhs and rhs have same element type", TCopVTEtIsSameAs<0, 1>>,
+ PredOpTrait<"third operand acc and result have same element type",
+ TCresVTEtIsSameAsOpBase<0, 2>>,
+ DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
+ ]>,
Arguments<(ins AnyVector:$lhs, AnyVector:$rhs, AnyType:$acc,
Variadic<VectorOf<[I1]>>:$masks,
AffineMapArrayAttr:$indexing_maps, ArrayAttr:$iterator_types)>,
@@ -896,7 +900,9 @@ def Vector_TransferOpUtils {
}
def Vector_TransferReadOp :
- Vector_Op<"transfer_read">,
+ Vector_Op<"transfer_read", [
+ DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
+ ]>,
Arguments<(ins AnyMemRef:$memref, Variadic<Index>:$indices,
AffineMapAttr:$permutation_map, AnyType:$padding,
OptionalAttr<BoolArrayAttr>:$masked)>,
@@ -1068,7 +1074,9 @@ def Vector_TransferReadOp :
}
def Vector_TransferWriteOp :
- Vector_Op<"transfer_write">,
+ Vector_Op<"transfer_write", [
+ DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
+ ]>,
Arguments<(ins AnyVector:$vector, AnyMemRef:$memref,
Variadic<Index>:$indices,
AffineMapAttr:$permutation_map,
diff --git a/mlir/include/mlir/Dialect/Vector/VectorTransformPatterns.td b/mlir/include/mlir/Dialect/Vector/VectorTransformPatterns.td
index 5f5c90521a7d..ef8118ec6470 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorTransformPatterns.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorTransformPatterns.td
@@ -20,7 +20,7 @@ class HasShape<list<int> shape> :
StrJoinInt<shape>.result # "})">;
class UnrollVectorOp<list<int> factors> : NativeCodeCall<
- "unrollSingleResultOpMatchingType($_builder, $0.getDefiningOp(), " #
+ "unrollSingleResultVectorOp($_builder, $0.getDefiningOp(), " #
"{" # StrJoinInt<factors>.result # "})">;
#endif // VECTOR_TRANSFORM_PATTERNS
diff --git a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
index 1864d45ac552..ab69a8246587 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
@@ -10,6 +10,8 @@
#define DIALECT_VECTOR_VECTORTRANSFORMS_H_
#include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/Dialect/Vector/VectorUtils.h"
+#include "mlir/IR/Function.h"
#include "mlir/IR/PatternMatch.h"
namespace mlir {
@@ -25,42 +27,82 @@ void populateVectorToVectorConversionPatterns(
namespace vector {
-// Entry point for unrolling declarative pattern rewrites.
-// `op` is unrolled to the `targetShape` as follows, for each of its operands:
-// 1. the unrolled type `unrolledVectorType` and number of unrolled instances
-// `numUnrolledInstances` are computed from the `targetShape`. For now it is
-// assumed the unrolling factors divide the vector sizes.
-// 2. a fakeFork cast op is inserted that takes the operand and returns
-// `numUnrolledInstances` results of type `unrolledVectorType`.
-// 3. the original op is cloned `numUnrolledInstances` times, once for each
-// result of the fakeFork cast op.
-// 4. a fakeJoin cast op takes all these results and merges them into a single
-// aggregate vector result whose size matches the original non-unrolled op
-// operand types.
-//
-// Example:
-//
-// opA(operand0, operand1) // numUnrolledInstances = 3
-//
-// operand0 operand1
-// | |
-// fork fork
-// <----------gather all fork ops --------->
-// /|\ /|\
-// f00 f01 f02 f10 f11 f12
-// <---------- clone op 3 times --------->
-// opA0(f00, f10), opA1(f01, f11), opA2(f02, f12)
-// \ | /
-// <-------------------- join ------------------------->
-//
-// Other local patterns then kick in iteratively (including DCE) and compose
-// until all the fakeFork and fakeJoin ops are removed.
-//
-// This will be extended in the future to support more advanced use cases than
-// simple pointwise ops.
-SmallVector<Value, 1>
-unrollSingleResultOpMatchingType(OpBuilder &builder, Operation *op,
- ArrayRef<int64_t> targetShape);
+/// Entry point for unrolling declarative pattern rewrites.
+/// `op` is unrolled to the `targetShape` as follows, for each of its operands:
+/// 1. the unrolled type `unrolledVectorType` and number of unrolled instances
+/// `numUnrolledInstances` are computed from the `targetShape`. For now it is
+/// assumed the unrolling factors divide the vector sizes.
+/// 2. a fakeFork cast op is inserted that takes the operand and returns
+/// `numUnrolledInstances` results of type `unrolledVectorType`.
+/// 3. the original op is cloned `numUnrolledInstances` times, once for each
+/// result of the fakeFork cast op.
+/// 4. a fakeJoin cast op takes all these results and merges them into a
+/// single aggregate vector result whose size matches the original
+/// non-unrolled op operand types.
+///
+/// Example:
+///
+/// opA(operand0, operand1) // numUnrolledInstances = 3
+///
+/// operand0 operand1
+/// | |
+/// fork fork
+/// <----------gather all fork ops --------->
+/// /|\ /|\
+/// f00 f01 f02 f10 f11 f12
+/// <---------- clone op 3 times --------->
+/// opA0(f00, f10), opA1(f01, f11), opA2(f02, f12)
+/// \ | /
+/// <-------------------- join ------------------------->
+///
+/// Other local patterns then kick in iteratively (including DCE) and compose
+/// until all the fakeFork and fakeJoin ops are removed.
+///
+/// This will be extended in the future to support more advanced use cases than
+/// simple pointwise ops.
+SmallVector<Value, 1> unrollSingleResultVectorOp(OpBuilder &builder,
+ Operation *op,
+ ArrayRef<int64_t> targetShape);
+
+/// Pattern to apply `unrollSingleResultVectorOp` to a `targetShape`
+/// declaratively.
+template <typename OpTy>
+struct UnrollVectorPattern : public OpRewritePattern<OpTy> {
+ using FilterConstraintType = std::function<LogicalResult(OpTy op)>;
+ UnrollVectorPattern(
+ ArrayRef<int64_t> targetShape, MLIRContext *context,
+ FilterConstraintType constraint = [](OpTy op) { return success(); })
+ : OpRewritePattern<OpTy>(context),
+ targetShape(targetShape.begin(), targetShape.end()),
+ filter(constraint) {}
+ LogicalResult matchAndRewrite(OpTy op,
+ PatternRewriter &rewriter) const override {
+ if (failed(filter(op)))
+ return failure();
+ auto unrollableVectorOp =
+ dyn_cast<VectorUnrollOpInterface>(op.getOperation());
+ if (!unrollableVectorOp)
+ return failure();
+ auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll();
+ if (!maybeUnrollShape)
+ return failure();
+ auto maybeShapeRatio = shapeRatio(*maybeUnrollShape, targetShape);
+ if (!maybeShapeRatio ||
+ llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; }))
+ return failure();
+ if (op.getOperation()->getNumResults() != 1)
+ return failure();
+ auto resultVector = unrollSingleResultVectorOp(rewriter, op, targetShape);
+ if (resultVector.size() != 1)
+ return failure();
+ rewriter.replaceOp(op, resultVector.front());
+ return success();
+ }
+
+private:
+ SmallVector<int64_t, 4> targetShape;
+ FilterConstraintType filter;
+};
} // namespace vector
diff --git a/mlir/include/mlir/Interfaces/CMakeLists.txt b/mlir/include/mlir/Interfaces/CMakeLists.txt
index 51f3f8ac1be6..0de2b5a8688b 100644
--- a/mlir/include/mlir/Interfaces/CMakeLists.txt
+++ b/mlir/include/mlir/Interfaces/CMakeLists.txt
@@ -5,5 +5,6 @@ add_mlir_interface(DerivedAttributeOpInterface)
add_mlir_interface(InferTypeOpInterface)
add_mlir_interface(LoopLikeInterface)
add_mlir_interface(SideEffectInterfaces)
+add_mlir_interface(VectorUnrollInterface)
add_mlir_interface(ViewLikeInterface)
diff --git a/mlir/include/mlir/Interfaces/VectorUnrollInterface.h b/mlir/include/mlir/Interfaces/VectorUnrollInterface.h
new file mode 100644
index 000000000000..a1cf39c17ebe
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/VectorUnrollInterface.h
@@ -0,0 +1,26 @@
+//===- VectorUnrollInterface.h - Vector unrolling interface ---------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the operation interface for vector ops that can be
+// unrolled.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_VECTORUNROLLINTERFACE_H
+#define MLIR_INTERFACES_VECTORUNROLLINTERFACE_H
+
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/StandardTypes.h"
+
+namespace mlir {
+
+#include "mlir/Interfaces/VectorUnrollInterface.h.inc"
+
+} // namespace mlir
+
+#endif // MLIR_INTERFACES_VECTORUNROLLINTERFACE_H
diff --git a/mlir/include/mlir/Interfaces/VectorUnrollInterface.td b/mlir/include/mlir/Interfaces/VectorUnrollInterface.td
new file mode 100644
index 000000000000..b9cff8bdab1d
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/VectorUnrollInterface.td
@@ -0,0 +1,45 @@
+//===- VectorUnrollInterface.td - VectorUnroll interface ---*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Defines the interface for operations on vectors that can be unrolled.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_VECTORUNROLLINTERFACE
+#define MLIR_INTERFACES_VECTORUNROLLINTERFACE
+
+include "mlir/IR/OpBase.td"
+
+def VectorUnrollOpInterface : OpInterface<"VectorUnrollOpInterface"> {
+ let description = [{
+ Encodes properties of an operation on vectors that can be unrolled.
+ }];
+
+ let methods = [
+ InterfaceMethod<[{
+ Returns the shape ratio of unrolling to the target vector shape
+ `targetShape`. Returns `None` if the op cannot be unrolled to the target
+ vector shape.
+ }],
+ "Optional<SmallVector<int64_t, 4>>",
+ "getShapeForUnroll",
+ (ins),
+ /*methodBody=*/[{}],
+ [{
+ auto vt = this->getOperation()->getResult(0).getType().
+ template dyn_cast<VectorType>();
+ if (!vt)
+ return None;
+ SmallVector<int64_t, 4> res(vt.getShape().begin(), vt.getShape().end());
+ return res;
+ }]
+ >,
+ ];
+}
+
+#endif // MLIR_INTERFACES_VECTORUNROLLINTERFACE
diff --git a/mlir/lib/Dialect/StandardOps/CMakeLists.txt b/mlir/lib/Dialect/StandardOps/CMakeLists.txt
index f3b93d6013ce..7d61aea3116e 100644
--- a/mlir/lib/Dialect/StandardOps/CMakeLists.txt
+++ b/mlir/lib/Dialect/StandardOps/CMakeLists.txt
@@ -15,6 +15,7 @@ add_mlir_dialect_library(MLIRStandardOps
MLIREDSC
MLIRIR
MLIRSideEffectInterfaces
+ MLIRVectorUnrollInterface
MLIRViewLikeInterface
)
diff --git a/mlir/lib/Dialect/Vector/CMakeLists.txt b/mlir/lib/Dialect/Vector/CMakeLists.txt
index 7a5ed49cd9ce..69a329917228 100644
--- a/mlir/lib/Dialect/Vector/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/CMakeLists.txt
@@ -19,4 +19,5 @@ add_mlir_dialect_library(MLIRVector
MLIRSCF
MLIRLoopAnalysis
MLIRSideEffectInterfaces
+ MLIRVectorUnrollInterface
)
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 5d3a916d02ea..184aed2ee1cd 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -469,6 +469,12 @@ SmallVector<AffineMap, 4> ContractionOp::getIndexingMaps() {
return res;
}
+Optional<SmallVector<int64_t, 4>> ContractionOp::getShapeForUnroll() {
+ SmallVector<int64_t, 4> shape;
+ getIterationBounds(shape);
+ return shape;
+}
+
//===----------------------------------------------------------------------===//
// ExtractElementOp
//===----------------------------------------------------------------------===//
@@ -1522,6 +1528,11 @@ OpFoldResult TransferReadOp::fold(ArrayRef<Attribute>) {
return OpFoldResult();
}
+Optional<SmallVector<int64_t, 4>> TransferReadOp::getShapeForUnroll() {
+ auto s = getVectorType().getShape();
+ return SmallVector<int64_t, 4>{s.begin(), s.end()};
+}
+
//===----------------------------------------------------------------------===//
// TransferWriteOp
//===----------------------------------------------------------------------===//
@@ -1612,6 +1623,11 @@ LogicalResult TransferWriteOp::fold(ArrayRef<Attribute>,
return foldMemRefCast(*this);
}
+Optional<SmallVector<int64_t, 4>> TransferWriteOp::getShapeForUnroll() {
+ auto s = getVectorType().getShape();
+ return SmallVector<int64_t, 4>{s.begin(), s.end()};
+}
+
//===----------------------------------------------------------------------===//
// ShapeCastOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index b841580433f9..c7cf2937939c 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -30,6 +30,7 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Types.h"
+#include "mlir/Interfaces/VectorUnrollInterface.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
@@ -357,7 +358,7 @@ struct VectorState {
// (removable with DCE).
// TODO(andydavis) Generalize this to support structured ops beyond
-// vector ContractionOp, and merge it with 'unrollSingleResultOpMatchingType'
+// vector ContractionOp, and merge it with 'unrollSingleResultVectorOp'
static Value unrollSingleResultStructuredOp(Operation *op,
ArrayRef<int64_t> iterationBounds,
std::vector<VectorState> &vectors,
@@ -450,11 +451,7 @@ static Value unrollSingleResultStructuredOp(Operation *op,
static void getVectorContractionOpUnrollState(
vector::ContractionOp contractionOp, ArrayRef<int64_t> targetShape,
- SmallVectorImpl<int64_t> &iterationBounds,
std::vector<VectorState> &vectors, unsigned &resultIndex) {
- // Get contraction op iteration bounds.
- contractionOp.getIterationBounds(iterationBounds);
- assert(iterationBounds.size() == targetShape.size());
// Get map from iteration space index to lhs/rhs/result shape index.
std::vector<DenseMap<int64_t, int64_t>> iterationIndexMapList;
contractionOp.getIterationIndexMap(iterationIndexMapList);
@@ -476,17 +473,15 @@ static void getVectorContractionOpUnrollState(
vectors.push_back({contractionOp.getRHSVectorMaskType(),
vectors[1].indexMap, accOperandIndex + 2, false});
}
- // Unroll 'op' 'iterationBounds' to 'targetShape'.
// TODO(andydavis) Use linalg style 'args_in'/'args_out' to partition
// 'vectors' instead of 'resultIndex'.
resultIndex = accOperandIndex;
}
-static void
-getVectorElementwiseOpUnrollState(Operation *op, ArrayRef<int64_t> targetShape,
- SmallVectorImpl<int64_t> &iterationBounds,
- std::vector<VectorState> &vectors,
- unsigned &resultIndex) {
+static void getVectorElementwiseOpUnrollState(Operation *op,
+ ArrayRef<int64_t> targetShape,
+ std::vector<VectorState> &vectors,
+ unsigned &resultIndex) {
// Verify that operation and operands all have the same vector shape.
auto resultType = op->getResult(0).getType().dyn_cast_or_null<VectorType>();
assert(resultType && "Expected op with vector result type");
@@ -494,8 +489,6 @@ getVectorElementwiseOpUnrollState(Operation *op, ArrayRef<int64_t> targetShape,
// Verify that all operands have the same vector type as result.
assert(llvm::all_of(op->getOperandTypes(),
[=](Type type) { return type == resultType; }));
- // Populate 'iterationBounds' with 'resultShape' for elementwise operations.
- iterationBounds.assign(resultShape.begin(), resultShape.end());
// Create trivial elementwise identity index map based on 'resultShape'.
DenseMap<int64_t, int64_t> indexMap;
@@ -513,28 +506,32 @@ getVectorElementwiseOpUnrollState(Operation *op, ArrayRef<int64_t> targetShape,
}
// Entry point for unrolling declarative pattern rewrites.
-SmallVector<Value, 1> mlir::vector::unrollSingleResultOpMatchingType(
- OpBuilder &builder, Operation *op, ArrayRef<int64_t> targetShape) {
+SmallVector<Value, 1>
+mlir::vector::unrollSingleResultVectorOp(OpBuilder &builder, Operation *op,
+ ArrayRef<int64_t> targetShape) {
assert(op->getNumResults() == 1 && "Expected single result operation");
// Populate 'iterationBounds', 'vectors' and 'resultIndex' to unroll 'op'.
SmallVector<int64_t, 6> iterationBounds;
+ auto unrollableVectorOp = cast<VectorUnrollOpInterface>(op);
+ auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll();
+ assert(maybeUnrollShape && "Trying to unroll an incorrect vector op");
+
std::vector<VectorState> vectors;
unsigned resultIndex;
if (auto contractionOp = dyn_cast<vector::ContractionOp>(op)) {
// Populate state for vector ContractionOp.
- getVectorContractionOpUnrollState(contractionOp, targetShape,
- iterationBounds, vectors, resultIndex);
+ getVectorContractionOpUnrollState(contractionOp, targetShape, vectors,
+ resultIndex);
} else {
// Populate state for vector elementwise op.
- getVectorElementwiseOpUnrollState(op, targetShape, iterationBounds, vectors,
- resultIndex);
+ getVectorElementwiseOpUnrollState(op, targetShape, vectors, resultIndex);
}
// Unroll 'op' with 'iterationBounds' to 'targetShape'.
return SmallVector<Value, 1>{unrollSingleResultStructuredOp(
- op, iterationBounds, vectors, resultIndex, targetShape, builder)};
+ op, *maybeUnrollShape, vectors, resultIndex, targetShape, builder)};
}
/// Generates slices of 'vectorType' according to 'sizes' and 'strides, and
diff --git a/mlir/lib/Interfaces/CMakeLists.txt b/mlir/lib/Interfaces/CMakeLists.txt
index 19b4e0af626d..b8498e224f25 100644
--- a/mlir/lib/Interfaces/CMakeLists.txt
+++ b/mlir/lib/Interfaces/CMakeLists.txt
@@ -6,6 +6,7 @@ set(LLVM_OPTIONAL_SOURCES
InferTypeOpInterface.cpp
LoopLikeInterface.cpp
SideEffectInterfaces.cpp
+ VectorUnrollInterface.cpp
ViewLikeInterface.cpp
)
@@ -32,5 +33,6 @@ add_mlir_interface_library(DerivedAttributeOpInterface)
add_mlir_interface_library(InferTypeOpInterface)
add_mlir_interface_library(LoopLikeInterface)
add_mlir_interface_library(SideEffectInterfaces)
+add_mlir_interface_library(VectorUnrollInterface)
add_mlir_interface_library(ViewLikeInterface)
diff --git a/mlir/lib/Interfaces/VectorUnrollInterface.cpp b/mlir/lib/Interfaces/VectorUnrollInterface.cpp
new file mode 100644
index 000000000000..6d3d432a7061
--- /dev/null
+++ b/mlir/lib/Interfaces/VectorUnrollInterface.cpp
@@ -0,0 +1,18 @@
+//===- VectorUnrollInterface.cpp - Unrollable vector operations -*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Interfaces/VectorUnrollInterface.h"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// VectorUnroll Interfaces
+//===----------------------------------------------------------------------===//
+
+/// Include the definitions of the VectorUntoll interfaces.
+#include "mlir/Interfaces/VectorUnrollInterface.cpp.inc"
diff --git a/mlir/test/Dialect/Vector/vector-transforms.mlir b/mlir/test/Dialect/Vector/vector-transforms.mlir
index 8de153adf731..0bd6c3c43b59 100644
--- a/mlir/test/Dialect/Vector/vector-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-transforms.mlir
@@ -1,4 +1,5 @@
// RUN: mlir-opt %s -test-vector-to-vector-conversion | FileCheck %s
+// RUN: mlir-opt %s -test-vector-unrolling-patterns | FileCheck %s
// CHECK-DAG: #[[MAP0:map[0-9]+]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-DAG: #[[MAP1:map[0-9]+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
diff --git a/mlir/test/lib/Transforms/TestVectorTransforms.cpp b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
index c6cf45e824d7..1af6c3564b80 100644
--- a/mlir/test/lib/Transforms/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
@@ -92,6 +92,20 @@ struct TestVectorContractionConversion
}
};
+struct TestVectorUnrollingPatterns
+ : public PassWrapper<TestVectorUnrollingPatterns, FunctionPass> {
+ void runOnFunction() override {
+ MLIRContext *ctx = &getContext();
+ OwningRewritePatternList patterns;
+ patterns.insert<UnrollVectorPattern<AddFOp>>(ArrayRef<int64_t>{2, 2}, ctx);
+ patterns.insert<UnrollVectorPattern<vector::ContractionOp>>(
+ ArrayRef<int64_t>{2, 2, 2}, ctx);
+ populateVectorToVectorCanonicalizationPatterns(patterns, ctx);
+ populateVectorToVectorTransformationPatterns(patterns, ctx);
+ applyPatternsAndFoldGreedily(getFunction(), patterns);
+ }
+};
+
} // end anonymous namespace
namespace mlir {
@@ -107,5 +121,9 @@ void registerTestVectorConversions() {
PassRegistration<TestVectorContractionConversion> contractionPass(
"test-vector-contraction-conversion",
"Test conversion patterns that lower contract ops in the vector dialect");
+
+ PassRegistration<TestVectorUnrollingPatterns> contractionUnrollingPass(
+ "test-vector-unrolling-patterns",
+ "Test conversion patterns to unroll contract ops in the vector dialect");
}
} // namespace mlir
More information about the Mlir-commits
mailing list