[llvm-branch-commits] [mlir] 7c7b55b - [mlir][vector] Extend vector unroll to all element-wise ops

Thomas Raoux via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Mon Dec 21 13:37:10 PST 2020


Author: Thomas Raoux
Date: 2020-12-21T13:31:22-08:00
New Revision: 7c7b55b985136a975223a9cefccd8fa1a5df7765

URL: https://github.com/llvm/llvm-project/commit/7c7b55b985136a975223a9cefccd8fa1a5df7765
DIFF: https://github.com/llvm/llvm-project/commit/7c7b55b985136a975223a9cefccd8fa1a5df7765.diff

LOG: [mlir][vector] Extend vector unroll to all element-wise ops

Extend unroll to support all element-wise ops and allow unrolling for ops with
vector operands that have the same shape as the destination but a different
element type (like Cmp or Select).

Differential Revision: https://reviews.llvm.org/D93121

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
    mlir/lib/Dialect/Vector/VectorTransforms.cpp
    mlir/test/Dialect/Vector/vector-transforms.mlir
    mlir/test/lib/Transforms/TestVectorTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index 7af44f8435ff..ba78db68214f 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -69,7 +69,9 @@ class CastOp<string mnemonic, list<OpTrait> traits = []> :
 
 // Base class for arithmetic cast operations.
 class ArithmeticCastOp<string mnemonic, list<OpTrait> traits = []> :
-    CastOp<mnemonic, !listconcat(traits, [ElementwiseMappable])> {
+    CastOp<mnemonic,
+    !listconcat(traits, [ElementwiseMappable,
+                         DeclareOpInterfaceMethods<VectorUnrollOpInterface>])> {
 }
 
 // Base class for unary ops. Requires single operand and result. Individual
@@ -104,6 +106,7 @@ class ArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
     Op<StandardOps_Dialect, mnemonic,
        !listconcat(traits, [NoSideEffect,
                             SameOperandsAndResultType,
+                            DeclareOpInterfaceMethods<VectorUnrollOpInterface>,
                             ElementwiseMappable])> {
 
   let results = (outs AnyType:$result);
@@ -992,6 +995,7 @@ def CmpFPredicateAttr : I64EnumAttr<
 
 def CmpFOp : Std_Op<"cmpf",
     [NoSideEffect, SameTypeOperands, ElementwiseMappable,
+     DeclareOpInterfaceMethods<VectorUnrollOpInterface>,
      TypesMatchWith<
        "result type has i1 element type and same shape as operands",
        "lhs", "result", "getI1SameShape($_self)">]> {
@@ -1076,6 +1080,7 @@ def CmpIPredicateAttr : I64EnumAttr<
 
 def CmpIOp : Std_Op<"cmpi",
     [NoSideEffect, SameTypeOperands, ElementwiseMappable,
+     DeclareOpInterfaceMethods<VectorUnrollOpInterface>,
      TypesMatchWith<
        "result type has i1 element type and same shape as operands",
        "lhs", "result", "getI1SameShape($_self)">]> {
@@ -2548,7 +2553,7 @@ def RsqrtOp : FloatUnaryOp<"rsqrt"> {
 
 def SelectOp : Std_Op<"select", [NoSideEffect,
      AllTypesMatch<["true_value", "false_value", "result"]>,
-     ElementwiseMappable]> {
+     ElementwiseMappable, DeclareOpInterfaceMethods<VectorUnrollOpInterface>]> {
   let summary = "select operation";
   let description = [{
     The `select` operation chooses one value based on a binary condition
@@ -2779,7 +2784,8 @@ def SignedShiftRightOp : IntArithmeticOp<"shift_right_signed"> {
 //===----------------------------------------------------------------------===//
 
 def SignExtendIOp : Std_Op<"sexti",
-    [NoSideEffect, ElementwiseMappable]> {
+    [NoSideEffect, ElementwiseMappable,
+    DeclareOpInterfaceMethods<VectorUnrollOpInterface>]> {
   let summary = "integer sign extension operation";
   let description = [{
     The integer sign extension operation takes an integer input of
@@ -3595,7 +3601,9 @@ def TransposeOp : Std_Op<"transpose", [NoSideEffect]>,
 // TruncateIOp
 //===----------------------------------------------------------------------===//
 
-def TruncateIOp : Std_Op<"trunci", [NoSideEffect, ElementwiseMappable]> {
+def TruncateIOp : Std_Op<"trunci",
+  [NoSideEffect, ElementwiseMappable,
+   DeclareOpInterfaceMethods<VectorUnrollOpInterface>,]> {
   let summary = "integer truncation operation";
   let description = [{
     The integer truncation operation takes an integer input of
@@ -3862,7 +3870,9 @@ def XOrOp : IntArithmeticOp<"xor", [Commutative]> {
 // ZeroExtendIOp
 //===----------------------------------------------------------------------===//
 
-def ZeroExtendIOp : Std_Op<"zexti", [NoSideEffect, ElementwiseMappable]> {
+def ZeroExtendIOp : Std_Op<"zexti",
+  [NoSideEffect, ElementwiseMappable,
+   DeclareOpInterfaceMethods<VectorUnrollOpInterface>,]> {
   let summary = "integer zero extension operation";
   let description = [{
     The integer zero extension operation takes an integer input of

diff  --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 1e58a759d305..5ba82b39a5a6 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -492,8 +492,9 @@ static void getVectorElementwiseOpUnrollState(Operation *op,
   assert(resultType && "Expected op with vector result type");
   auto resultShape = resultType.getShape();
   // Verify that all operands have the same vector type as result.
-  assert(llvm::all_of(op->getOperandTypes(),
-                      [=](Type type) { return type == resultType; }));
+  assert(llvm::all_of(op->getOperandTypes(), [=](Type type) {
+    return type.cast<VectorType>().getShape() == resultShape;
+  }));
 
   // Create trivial elementwise identity index map based on 'resultShape'.
   DenseMap<int64_t, int64_t> indexMap;
@@ -504,8 +505,9 @@ static void getVectorElementwiseOpUnrollState(Operation *op,
   // Create VectorState each operand and single result.
   unsigned numVectors = op->getNumOperands() + op->getNumResults();
   vectors.resize(numVectors);
-  for (unsigned i = 0; i < op->getNumOperands(); ++i)
-    vectors[i] = {resultType, indexMap, i, false};
+  for (auto it : llvm::enumerate(op->getOperandTypes()))
+    vectors[it.index()] = {it.value().cast<VectorType>(), indexMap,
+                           static_cast<int64_t>(it.index()), false};
   vectors[numVectors - 1] = {resultType, indexMap, -1, false};
   resultIndex = numVectors - 1;
 }

diff  --git a/mlir/test/Dialect/Vector/vector-transforms.mlir b/mlir/test/Dialect/Vector/vector-transforms.mlir
index 167314d36458..43a83f04dd30 100644
--- a/mlir/test/Dialect/Vector/vector-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-transforms.mlir
@@ -1,5 +1,4 @@
 // RUN: mlir-opt %s -test-vector-to-vector-conversion | FileCheck %s
-// RUN: mlir-opt %s -test-vector-unrolling-patterns | FileCheck %s
 
 // CHECK-DAG: #[[MAP1:map[0-9]+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
 
@@ -514,3 +513,38 @@ func @shape_cast_fold(%arg0 : vector<5x4x2xf32>, %arg1 : vector<3x4x2xf32>)
 
   return %6, %7 : vector<5x4x2xf32>, vector<3x4x2xf32>
 }
+
+// CHECK-LABEL: func @elementwise_unroll
+//  CHECK-SAME: (%[[ARG0:.*]]: memref<4x4xf32>, %[[ARG1:.*]]: memref<4x4xf32>)
+//       CHECK:   %[[C0:.*]] = constant 0 : index
+//       CHECK:   %[[C2:.*]] = constant 2 : index
+//       CHECK:   %[[VT0:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32>
+//       CHECK:   %[[VT1:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C2]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32>
+//       CHECK:   %[[VT2:.*]] = vector.transfer_read %[[ARG0]][%[[C2]], %[[C0]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32>
+//       CHECK:   %[[VT3:.*]] = vector.transfer_read %[[ARG0]][%[[C2]], %[[C2]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32>
+//       CHECK:   %[[VT4:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32>
+//       CHECK:   %[[VT5:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C2]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32>
+//       CHECK:   %[[VT6:.*]] = vector.transfer_read %[[ARG1]][%[[C2]], %[[C0]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32>
+//       CHECK:   %[[VT7:.*]] = vector.transfer_read %[[ARG1]][%[[C2]], %[[C2]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32>
+//       CHECK:   %[[CMP0:.*]] = cmpf "ult", %[[VT0]], %[[VT4]] : vector<2x2xf32>
+//       CHECK:   %[[CMP1:.*]] = cmpf "ult", %[[VT1]], %[[VT5]] : vector<2x2xf32>
+//       CHECK:   %[[CMP2:.*]] = cmpf "ult", %[[VT2]], %[[VT6]] : vector<2x2xf32>
+//       CHECK:   %[[CMP3:.*]] = cmpf "ult", %[[VT3]], %[[VT7]] : vector<2x2xf32>
+//       CHECK:   %[[SEL0:.*]] = select %[[CMP0]], %[[VT0]], %[[VT4]] : vector<2x2xi1>, vector<2x2xf32>
+//       CHECK:   %[[SEL1:.*]] = select %[[CMP1]], %[[VT1]], %[[VT5]] : vector<2x2xi1>, vector<2x2xf32>
+//       CHECK:   %[[SEL2:.*]] = select %[[CMP2]], %[[VT2]], %[[VT6]] : vector<2x2xi1>, vector<2x2xf32>
+//       CHECK:   %[[SEL3:.*]] = select %[[CMP3]], %[[VT3]], %[[VT7]] : vector<2x2xi1>, vector<2x2xf32>
+//       CHECK:   vector.transfer_write %[[SEL0]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32>
+//       CHECK:   vector.transfer_write %[[SEL1]], %[[ARG0]][%[[C0]], %[[C2]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32>
+//       CHECK:   vector.transfer_write %[[SEL2]], %[[ARG0]][%[[C2]], %[[C0]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32>
+//       CHECK:   vector.transfer_write %[[SEL3]], %[[ARG0]][%[[C2]], %[[C2]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32>
+func @elementwise_unroll(%arg0 : memref<4x4xf32>, %arg1 : memref<4x4xf32>) {
+  %c0 = constant 0 : index
+  %cf0 = constant 0.0 : f32
+  %0 = vector.transfer_read %arg0[%c0, %c0], %cf0 : memref<4x4xf32>, vector<4x4xf32>
+  %1 = vector.transfer_read %arg1[%c0, %c0], %cf0 : memref<4x4xf32>, vector<4x4xf32>
+  %cond = cmpf "ult", %0, %1 : vector<4x4xf32>
+  %2 = select %cond, %0, %1 : vector<4x4xi1>, vector<4x4xf32>
+  vector.transfer_write %2, %arg0[%c0, %c0] : vector<4x4xf32>, memref<4x4xf32>
+  return
+}

diff  --git a/mlir/test/lib/Transforms/TestVectorTransforms.cpp b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
index 99c336ef0565..f219ef04fce5 100644
--- a/mlir/test/lib/Transforms/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
@@ -37,8 +37,11 @@ struct TestVectorToVectorConversion
 private:
   // Return the target shape based on op type.
   static Optional<SmallVector<int64_t, 4>> getShape(Operation *op) {
-    if (isa<AddFOp>(op))
+    if (isa<AddFOp, SelectOp, CmpFOp>(op))
       return SmallVector<int64_t, 4>(2, 2);
+    if (auto transferOp = dyn_cast<VectorTransferOpInterface>(op)) {
+      return SmallVector<int64_t, 4>(transferOp.getVectorType().getRank(), 2);
+    }
     if (isa<vector::ContractionOp>(op))
       return SmallVector<int64_t, 4>(3, 2);
     return llvm::None;


        


More information about the llvm-branch-commits mailing list