[llvm-branch-commits] [mlir] 7c7b55b - [mlir][vector] Extend vector unroll to all element-wise ops
Thomas Raoux via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Mon Dec 21 13:37:10 PST 2020
Author: Thomas Raoux
Date: 2020-12-21T13:31:22-08:00
New Revision: 7c7b55b985136a975223a9cefccd8fa1a5df7765
URL: https://github.com/llvm/llvm-project/commit/7c7b55b985136a975223a9cefccd8fa1a5df7765
DIFF: https://github.com/llvm/llvm-project/commit/7c7b55b985136a975223a9cefccd8fa1a5df7765.diff
LOG: [mlir][vector] Extend vector unroll to all element-wise ops
Extend unroll to support all element-wise ops and allow unrolling for ops with
vector operands with the same shape as the destination but a different element
type (like Cmp or Select).
Differential Revision: https://reviews.llvm.org/D93121
Added:
Modified:
mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
mlir/lib/Dialect/Vector/VectorTransforms.cpp
mlir/test/Dialect/Vector/vector-transforms.mlir
mlir/test/lib/Transforms/TestVectorTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index 7af44f8435ff..ba78db68214f 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -69,7 +69,9 @@ class CastOp<string mnemonic, list<OpTrait> traits = []> :
// Base class for arithmetic cast operations.
class ArithmeticCastOp<string mnemonic, list<OpTrait> traits = []> :
- CastOp<mnemonic, !listconcat(traits, [ElementwiseMappable])> {
+ CastOp<mnemonic,
+ !listconcat(traits, [ElementwiseMappable,
+ DeclareOpInterfaceMethods<VectorUnrollOpInterface>])> {
}
// Base class for unary ops. Requires single operand and result. Individual
@@ -104,6 +106,7 @@ class ArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
Op<StandardOps_Dialect, mnemonic,
!listconcat(traits, [NoSideEffect,
SameOperandsAndResultType,
+ DeclareOpInterfaceMethods<VectorUnrollOpInterface>,
ElementwiseMappable])> {
let results = (outs AnyType:$result);
@@ -992,6 +995,7 @@ def CmpFPredicateAttr : I64EnumAttr<
def CmpFOp : Std_Op<"cmpf",
[NoSideEffect, SameTypeOperands, ElementwiseMappable,
+ DeclareOpInterfaceMethods<VectorUnrollOpInterface>,
TypesMatchWith<
"result type has i1 element type and same shape as operands",
"lhs", "result", "getI1SameShape($_self)">]> {
@@ -1076,6 +1080,7 @@ def CmpIPredicateAttr : I64EnumAttr<
def CmpIOp : Std_Op<"cmpi",
[NoSideEffect, SameTypeOperands, ElementwiseMappable,
+ DeclareOpInterfaceMethods<VectorUnrollOpInterface>,
TypesMatchWith<
"result type has i1 element type and same shape as operands",
"lhs", "result", "getI1SameShape($_self)">]> {
@@ -2548,7 +2553,7 @@ def RsqrtOp : FloatUnaryOp<"rsqrt"> {
def SelectOp : Std_Op<"select", [NoSideEffect,
AllTypesMatch<["true_value", "false_value", "result"]>,
- ElementwiseMappable]> {
+ ElementwiseMappable, DeclareOpInterfaceMethods<VectorUnrollOpInterface>]> {
let summary = "select operation";
let description = [{
The `select` operation chooses one value based on a binary condition
@@ -2779,7 +2784,8 @@ def SignedShiftRightOp : IntArithmeticOp<"shift_right_signed"> {
//===----------------------------------------------------------------------===//
def SignExtendIOp : Std_Op<"sexti",
- [NoSideEffect, ElementwiseMappable]> {
+ [NoSideEffect, ElementwiseMappable,
+ DeclareOpInterfaceMethods<VectorUnrollOpInterface>]> {
let summary = "integer sign extension operation";
let description = [{
The integer sign extension operation takes an integer input of
@@ -3595,7 +3601,9 @@ def TransposeOp : Std_Op<"transpose", [NoSideEffect]>,
// TruncateIOp
//===----------------------------------------------------------------------===//
-def TruncateIOp : Std_Op<"trunci", [NoSideEffect, ElementwiseMappable]> {
+def TruncateIOp : Std_Op<"trunci",
+ [NoSideEffect, ElementwiseMappable,
+ DeclareOpInterfaceMethods<VectorUnrollOpInterface>,]> {
let summary = "integer truncation operation";
let description = [{
The integer truncation operation takes an integer input of
@@ -3862,7 +3870,9 @@ def XOrOp : IntArithmeticOp<"xor", [Commutative]> {
// ZeroExtendIOp
//===----------------------------------------------------------------------===//
-def ZeroExtendIOp : Std_Op<"zexti", [NoSideEffect, ElementwiseMappable]> {
+def ZeroExtendIOp : Std_Op<"zexti",
+ [NoSideEffect, ElementwiseMappable,
+ DeclareOpInterfaceMethods<VectorUnrollOpInterface>,]> {
let summary = "integer zero extension operation";
let description = [{
The integer zero extension operation takes an integer input of
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 1e58a759d305..5ba82b39a5a6 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -492,8 +492,9 @@ static void getVectorElementwiseOpUnrollState(Operation *op,
assert(resultType && "Expected op with vector result type");
auto resultShape = resultType.getShape();
// Verify that all operands have the same vector type as result.
- assert(llvm::all_of(op->getOperandTypes(),
- [=](Type type) { return type == resultType; }));
+ assert(llvm::all_of(op->getOperandTypes(), [=](Type type) {
+ return type.cast<VectorType>().getShape() == resultShape;
+ }));
// Create trivial elementwise identity index map based on 'resultShape'.
DenseMap<int64_t, int64_t> indexMap;
@@ -504,8 +505,9 @@ static void getVectorElementwiseOpUnrollState(Operation *op,
// Create VectorState each operand and single result.
unsigned numVectors = op->getNumOperands() + op->getNumResults();
vectors.resize(numVectors);
- for (unsigned i = 0; i < op->getNumOperands(); ++i)
- vectors[i] = {resultType, indexMap, i, false};
+ for (auto it : llvm::enumerate(op->getOperandTypes()))
+ vectors[it.index()] = {it.value().cast<VectorType>(), indexMap,
+ static_cast<int64_t>(it.index()), false};
vectors[numVectors - 1] = {resultType, indexMap, -1, false};
resultIndex = numVectors - 1;
}
diff --git a/mlir/test/Dialect/Vector/vector-transforms.mlir b/mlir/test/Dialect/Vector/vector-transforms.mlir
index 167314d36458..43a83f04dd30 100644
--- a/mlir/test/Dialect/Vector/vector-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-transforms.mlir
@@ -1,5 +1,4 @@
// RUN: mlir-opt %s -test-vector-to-vector-conversion | FileCheck %s
-// RUN: mlir-opt %s -test-vector-unrolling-patterns | FileCheck %s
// CHECK-DAG: #[[MAP1:map[0-9]+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
@@ -514,3 +513,38 @@ func @shape_cast_fold(%arg0 : vector<5x4x2xf32>, %arg1 : vector<3x4x2xf32>)
return %6, %7 : vector<5x4x2xf32>, vector<3x4x2xf32>
}
+
+// CHECK-LABEL: func @elementwise_unroll
+// CHECK-SAME: (%[[ARG0:.*]]: memref<4x4xf32>, %[[ARG1:.*]]: memref<4x4xf32>)
+// CHECK: %[[C0:.*]] = constant 0 : index
+// CHECK: %[[C2:.*]] = constant 2 : index
+// CHECK: %[[VT0:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32>
+// CHECK: %[[VT1:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C2]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32>
+// CHECK: %[[VT2:.*]] = vector.transfer_read %[[ARG0]][%[[C2]], %[[C0]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32>
+// CHECK: %[[VT3:.*]] = vector.transfer_read %[[ARG0]][%[[C2]], %[[C2]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32>
+// CHECK: %[[VT4:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32>
+// CHECK: %[[VT5:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C2]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32>
+// CHECK: %[[VT6:.*]] = vector.transfer_read %[[ARG1]][%[[C2]], %[[C0]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32>
+// CHECK: %[[VT7:.*]] = vector.transfer_read %[[ARG1]][%[[C2]], %[[C2]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32>
+// CHECK: %[[CMP0:.*]] = cmpf "ult", %[[VT0]], %[[VT4]] : vector<2x2xf32>
+// CHECK: %[[CMP1:.*]] = cmpf "ult", %[[VT1]], %[[VT5]] : vector<2x2xf32>
+// CHECK: %[[CMP2:.*]] = cmpf "ult", %[[VT2]], %[[VT6]] : vector<2x2xf32>
+// CHECK: %[[CMP3:.*]] = cmpf "ult", %[[VT3]], %[[VT7]] : vector<2x2xf32>
+// CHECK: %[[SEL0:.*]] = select %[[CMP0]], %[[VT0]], %[[VT4]] : vector<2x2xi1>, vector<2x2xf32>
+// CHECK: %[[SEL1:.*]] = select %[[CMP1]], %[[VT1]], %[[VT5]] : vector<2x2xi1>, vector<2x2xf32>
+// CHECK: %[[SEL2:.*]] = select %[[CMP2]], %[[VT2]], %[[VT6]] : vector<2x2xi1>, vector<2x2xf32>
+// CHECK: %[[SEL3:.*]] = select %[[CMP3]], %[[VT3]], %[[VT7]] : vector<2x2xi1>, vector<2x2xf32>
+// CHECK: vector.transfer_write %[[SEL0]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32>
+// CHECK: vector.transfer_write %[[SEL1]], %[[ARG0]][%[[C0]], %[[C2]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32>
+// CHECK: vector.transfer_write %[[SEL2]], %[[ARG0]][%[[C2]], %[[C0]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32>
+// CHECK: vector.transfer_write %[[SEL3]], %[[ARG0]][%[[C2]], %[[C2]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32>
+func @elementwise_unroll(%arg0 : memref<4x4xf32>, %arg1 : memref<4x4xf32>) {
+ %c0 = constant 0 : index
+ %cf0 = constant 0.0 : f32
+ %0 = vector.transfer_read %arg0[%c0, %c0], %cf0 : memref<4x4xf32>, vector<4x4xf32>
+ %1 = vector.transfer_read %arg1[%c0, %c0], %cf0 : memref<4x4xf32>, vector<4x4xf32>
+ %cond = cmpf "ult", %0, %1 : vector<4x4xf32>
+ %2 = select %cond, %0, %1 : vector<4x4xi1>, vector<4x4xf32>
+ vector.transfer_write %2, %arg0[%c0, %c0] : vector<4x4xf32>, memref<4x4xf32>
+ return
+}
diff --git a/mlir/test/lib/Transforms/TestVectorTransforms.cpp b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
index 99c336ef0565..f219ef04fce5 100644
--- a/mlir/test/lib/Transforms/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
@@ -37,8 +37,11 @@ struct TestVectorToVectorConversion
private:
// Return the target shape based on op type.
static Optional<SmallVector<int64_t, 4>> getShape(Operation *op) {
- if (isa<AddFOp>(op))
+ if (isa<AddFOp, SelectOp, CmpFOp>(op))
return SmallVector<int64_t, 4>(2, 2);
+ if (auto transferOp = dyn_cast<VectorTransferOpInterface>(op)) {
+ return SmallVector<int64_t, 4>(transferOp.getVectorType().getRank(), 2);
+ }
if (isa<vector::ContractionOp>(op))
return SmallVector<int64_t, 4>(3, 2);
return llvm::None;
More information about the llvm-branch-commits
mailing list