[Mlir-commits] [mlir] b6204b9 - [mlir][Vector] Introduce UnrollVectorOptions to control vector unrolling.

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Oct 23 13:52:46 PDT 2020


Author: MaheshRavishankar
Date: 2020-10-23T13:52:26-07:00
New Revision: b6204b995eaa2ec771f947a2109bd2ef338e688c

URL: https://github.com/llvm/llvm-project/commit/b6204b995eaa2ec771f947a2109bd2ef338e688c
DIFF: https://github.com/llvm/llvm-project/commit/b6204b995eaa2ec771f947a2109bd2ef338e688c.diff

LOG: [mlir][Vector] Introduce UnrollVectorOptions to control vector unrolling.

The current pattern for vector unrolling takes the native shape to
unroll to at pattern instantiation time, but the native shape might
differ based on the types of the operands. Introduce an
UnrollVectorOptions struct which allows for using a function that will
return the native shape based on the operation. Move other options of
unrolling like `filterConstraint` into this struct.

Differential Revision: https://reviews.llvm.org/D89744

Added: 
    mlir/test/Dialect/Vector/vector-unroll-options.mlir

Modified: 
    mlir/include/mlir/Dialect/Vector/VectorTransforms.h
    mlir/test/lib/Transforms/TestVectorTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
index 157084a2bff1..a1cf90cb10d5 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
@@ -85,21 +85,51 @@ SmallVector<Value, 1> unrollSingleResultVectorOp(OpBuilder &builder,
 LogicalResult unrollTransferWriteOp(OpBuilder &builder, Operation *op,
                                     ArrayRef<int64_t> targetShape);
 
+/// Options that control the vector unrolling.
+struct UnrollVectorOptions {
+  using FilterConstraintFnType = std::function<LogicalResult(Operation *op)>;
+  /// Callback that decides whether vector unrolling should be attempted on an
+  /// operation (failure() skips it). If unset, every operation is attempted.
+  FilterConstraintFnType filterConstraint = nullptr;
+  UnrollVectorOptions &setFilterContraint(FilterConstraintFnType constraint) {
+    filterConstraint = std::move(constraint);
+    return *this;
+  }
+
+  using NativeShapeFnType =
+      std::function<Optional<SmallVector<int64_t, 4>>(Operation *op)>;
+  /// Function that returns the shape of the vector to unroll to for a given
+  /// operation. The unrolling is aborted if the function returns `llvm::None`.
+  NativeShapeFnType nativeShape = nullptr;
+  UnrollVectorOptions &setNativeShapeFn(NativeShapeFnType fn) {
+    nativeShape = std::move(fn);
+    return *this;
+  }
+
+  /// Set a fixed native shape, used for every operation, to unroll to.
+  UnrollVectorOptions &setNativeShape(ArrayRef<int64_t> shape) {
+    SmallVector<int64_t, 4> tsShape(shape.begin(), shape.end());
+    nativeShape = [tsShape](Operation *) -> Optional<SmallVector<int64_t, 4>> {
+      return tsShape;
+    };
+    return *this;
+  }
+};
 /// Pattern to apply `unrollSingleResultVectorOp` to a `targetShape`
 /// declaratively.
 template <typename OpTy>
 struct UnrollVectorPattern : public OpRewritePattern<OpTy> {
   using FilterConstraintType = std::function<LogicalResult(OpTy op)>;
-  UnrollVectorPattern(
-      ArrayRef<int64_t> targetShape, MLIRContext *context,
-      FilterConstraintType constraint = [](OpTy op) { return success(); })
-      : OpRewritePattern<OpTy>(context),
-        targetShape(targetShape.begin(), targetShape.end()),
-        filter(constraint) {}
+  UnrollVectorPattern(MLIRContext *context, UnrollVectorOptions options)
+      : OpRewritePattern<OpTy>(context), options(std::move(options)) {}
   LogicalResult matchAndRewrite(OpTy op,
                                 PatternRewriter &rewriter) const override {
-    if (failed(filter(op)))
+    if (options.filterConstraint && failed(options.filterConstraint(op)))
       return failure();
+    // The native shape (or a callback producing it) must be configured.
+    if (!options.nativeShape)
+      return op.emitError("vector unrolling expects the native shape or "
+                          "native shape callback function to be set");
     auto unrollableVectorOp =
         dyn_cast<VectorUnrollOpInterface>(op.getOperation());
     if (!unrollableVectorOp)
@@ -107,19 +137,22 @@ struct UnrollVectorPattern : public OpRewritePattern<OpTy> {
     auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll();
     if (!maybeUnrollShape)
       return failure();
-    auto maybeShapeRatio = shapeRatio(*maybeUnrollShape, targetShape);
+    Optional<SmallVector<int64_t, 4>> targetShape = options.nativeShape(op);
+    if (!targetShape) // `llvm::None` opts this op out; it is not an error.
+      return failure();
+    auto maybeShapeRatio = shapeRatio(*maybeUnrollShape, *targetShape);
     if (!maybeShapeRatio ||
         llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; }))
       return failure();
     if (std::is_same<OpTy, TransferWriteOp>::value) {
-      if (failed(unrollTransferWriteOp(rewriter, op, targetShape)))
+      if (failed(unrollTransferWriteOp(rewriter, op, *targetShape)))
         return failure();
       rewriter.eraseOp(op);
       return success();
     }
     if (op.getOperation()->getNumResults() != 1)
       return failure();
-    auto resultVector = unrollSingleResultVectorOp(rewriter, op, targetShape);
+    auto resultVector = unrollSingleResultVectorOp(rewriter, op, *targetShape);
     if (resultVector.size() != 1)
       return failure();
     rewriter.replaceOp(op, resultVector.front());
@@ -127,8 +160,7 @@ struct UnrollVectorPattern : public OpRewritePattern<OpTy> {
   }
 
 private:
-  SmallVector<int64_t, 4> targetShape;
-  FilterConstraintType filter;
+  UnrollVectorOptions options;
 };
 
 /// Split a vector.transfer operation into an unmasked fastpath and a slowpath.

diff  --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
new file mode 100644
index 000000000000..705d4ab65739
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
@@ -0,0 +1,35 @@
+// RUN: mlir-opt %s -test-vector-unrolling-patterns=unroll-based-on-type | FileCheck %s
+
+func @vector_contract_f32(%lhs : vector<8x8xf32>, %rhs : vector<8x8xf32>,
+                          %init : vector<8x8xf32>) -> vector<8x8xf32> {
+  %0 = vector.contract
+         {indexing_maps = [affine_map<(i, j, k) -> (i, k)>,
+                           affine_map<(i, j, k) -> (j, k)>,
+                           affine_map<(i, j, k) -> (i, j)>],
+          iterator_types = ["parallel", "parallel", "reduction"]}
+       %lhs, %rhs, %init : vector<8x8xf32>, vector<8x8xf32> into vector<8x8xf32>
+  return %0 : vector<8x8xf32>
+}
+// The f32 native shape is [4, 4, 2], so the 8x8x8 contraction is expected to
+// unroll into 2 * 2 * 4 = 16 contractions.
+// CHECK-LABEL: func @vector_contract_f32
+//  CHECK-COUNT-16:   vector.contract {{.*}}vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+//       CHECK-NOT:   vector.contract
+//           CHECK:   return
+
+func @vector_contract_f16(%lhs : vector<8x8xf16>, %rhs : vector<8x8xf16>,
+                          %init : vector<8x8xf16>) -> vector<8x8xf16> {
+  %0 = vector.contract
+         {indexing_maps = [affine_map<(i, j, k) -> (i, k)>,
+                           affine_map<(i, j, k) -> (j, k)>,
+                           affine_map<(i, j, k) -> (i, j)>],
+          iterator_types = ["parallel", "parallel", "reduction"]}
+       %lhs, %rhs, %init : vector<8x8xf16>, vector<8x8xf16> into vector<8x8xf16>
+  return %0 : vector<8x8xf16>
+}
+// The f16 native shape is [4, 4, 4], so the 8x8x8 contraction is expected to
+// unroll into 2 * 2 * 2 = 8 contractions.
+// CHECK-LABEL: func @vector_contract_f16
+//   CHECK-COUNT-8:   vector.contract {{.*}}vector<4x4xf16>, vector<4x4xf16> into vector<4x4xf16>
+//       CHECK-NOT:   vector.contract
+//           CHECK:   return

diff  --git a/mlir/test/lib/Transforms/TestVectorTransforms.cpp b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
index 52d0f7b2bb5e..5369ab51ddb0 100644
--- a/mlir/test/lib/Transforms/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
@@ -14,6 +14,7 @@
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Vector/VectorOps.h"
 #include "mlir/Dialect/Vector/VectorTransforms.h"
+#include "mlir/IR/Operation.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
 
@@ -26,9 +27,10 @@ struct TestVectorToVectorConversion
   void runOnFunction() override {
     OwningRewritePatternList patterns;
     auto *ctx = &getContext();
-    patterns.insert<UnrollVectorPattern<AddFOp>>(ArrayRef<int64_t>{2, 2}, ctx);
+    patterns.insert<UnrollVectorPattern<AddFOp>>(
+        ctx, UnrollVectorOptions().setNativeShape(ArrayRef<int64_t>{2, 2}));
     patterns.insert<UnrollVectorPattern<vector::ContractionOp>>(
-        ArrayRef<int64_t>{2, 2, 2}, ctx);
+        ctx, UnrollVectorOptions().setNativeShape(ArrayRef<int64_t>{2, 2, 2}));
     populateVectorToVectorCanonicalizationPatterns(patterns, ctx);
     populateVectorToVectorTransformationPatterns(patterns, ctx);
     applyPatternsAndFoldGreedily(getFunction(), patterns);
@@ -113,16 +115,44 @@ struct TestVectorContractionConversion
 
 struct TestVectorUnrollingPatterns
     : public PassWrapper<TestVectorUnrollingPatterns, FunctionPass> {
+  TestVectorUnrollingPatterns() = default;
+  TestVectorUnrollingPatterns(const TestVectorUnrollingPatterns &pass) {}
   void runOnFunction() override {
     MLIRContext *ctx = &getContext();
     OwningRewritePatternList patterns;
-    patterns.insert<UnrollVectorPattern<AddFOp>>(ArrayRef<int64_t>{2, 2}, ctx);
-    patterns.insert<UnrollVectorPattern<vector::ContractionOp>>(
-        ArrayRef<int64_t>{2, 2, 2}, ctx);
+    patterns.insert<UnrollVectorPattern<AddFOp>>(
+        ctx, UnrollVectorOptions().setNativeShape(ArrayRef<int64_t>{2, 2}));
+
+    if (unrollBasedOnType) {
+      // Contractions are unrolled to [4, 4, 2], or to [4, 4, 4] when the LHS
+      // element type is a 16-bit float.
+      UnrollVectorOptions::NativeShapeFnType nativeShapeFn =
+          [](Operation *op) -> Optional<SmallVector<int64_t, 4>> {
+        auto contractOp = cast<vector::ContractionOp>(op);
+        SmallVector<int64_t, 4> nativeShape = {4, 4, 2};
+        auto floatType =
+            contractOp.getLhsType().getElementType().dyn_cast<FloatType>();
+        if (floatType && floatType.getWidth() == 16)
+          nativeShape[2] = 4;
+        return nativeShape;
+      };
+      patterns.insert<UnrollVectorPattern<vector::ContractionOp>>(
+          ctx, UnrollVectorOptions().setNativeShapeFn(nativeShapeFn));
+    } else {
+      // Otherwise every contraction is unrolled to [2, 2, 2].
+      patterns.insert<UnrollVectorPattern<vector::ContractionOp>>(
+          ctx,
+          UnrollVectorOptions().setNativeShape(ArrayRef<int64_t>{2, 2, 2}));
+    }
     populateVectorToVectorCanonicalizationPatterns(patterns, ctx);
     populateVectorToVectorTransformationPatterns(patterns, ctx);
     applyPatternsAndFoldGreedily(getFunction(), patterns);
   }
+
+  Option<bool> unrollBasedOnType{
+      *this, "unroll-based-on-type",
+      llvm::cl::desc("Set the unroll factor based on type of the operation"),
+      llvm::cl::init(false)};
 };
 
 struct TestVectorDistributePatterns
@@ -165,9 +195,9 @@ struct TestVectorTransferUnrollingPatterns
     MLIRContext *ctx = &getContext();
     OwningRewritePatternList patterns;
     patterns.insert<UnrollVectorPattern<vector::TransferReadOp>>(
-        ArrayRef<int64_t>{2, 2}, ctx);
+        ctx, UnrollVectorOptions().setNativeShape(ArrayRef<int64_t>{2, 2}));
     patterns.insert<UnrollVectorPattern<vector::TransferWriteOp>>(
-        ArrayRef<int64_t>{2, 2}, ctx);
+        ctx, UnrollVectorOptions().setNativeShape(ArrayRef<int64_t>{2, 2}));
     populateVectorToVectorCanonicalizationPatterns(patterns, ctx);
     populateVectorToVectorTransformationPatterns(patterns, ctx);
     applyPatternsAndFoldGreedily(getFunction(), patterns);


        


More information about the Mlir-commits mailing list