[Mlir-commits] [mlir] 2fae787 - [mlir][Vector] Mostly-NFC - Restructure options for lowering to LLVM Matrix Intrinsics

Nicolas Vasilache llvmlistbot at llvm.org
Tue Mar 17 20:01:05 PDT 2020


Author: Nicolas Vasilache
Date: 2020-03-17T22:58:02-04:00
New Revision: 2fae7878d552605311e2b62846af73c242b6a22a

URL: https://github.com/llvm/llvm-project/commit/2fae7878d552605311e2b62846af73c242b6a22a
DIFF: https://github.com/llvm/llvm-project/commit/2fae7878d552605311e2b62846af73c242b6a22a.diff

LOG: [mlir][Vector] Mostly-NFC - Restructure options for lowering to LLVM Matrix Intrinsics

Summary:
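This revision restructures how vector transforms are configured, making it more flexible to request lowering through LLVM matrix intrinsics: the global `-vector-lower-matrix-intrinsics` command-line flag is replaced by a `VectorTransformsOptions` struct that callers pass to `populateVectorContractLoweringPatterns`.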
This also makes sure we bail out in degenerate cases (i.e. when the LHS row count, LHS column count, or RHS column count is 1), in which LLVM complains about not being able to scalarize.

Differential Revision: https://reviews.llvm.org/D76266

Added: 
    

Modified: 
    mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
    mlir/include/mlir/Dialect/Vector/VectorOps.h
    mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
    mlir/lib/Dialect/Vector/VectorTransforms.cpp
    mlir/test/Dialect/Vector/vector-contract-transforms.mlir
    mlir/test/lib/Transforms/TestVectorTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
index 8c95d13a2922..a92906d0c2c3 100644
--- a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
+++ b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
@@ -26,7 +26,7 @@ void populateVectorToLLVMConversionPatterns(LLVMTypeConverter &converter,
                                             OwningRewritePatternList &patterns);
 
 /// Create a pass to convert vector operations to the LLVMIR dialect.
-OpPassBase<ModuleOp> *createLowerVectorToLLVMPass();
+std::unique_ptr<OpPassBase<ModuleOp>> createConvertVectorToLLVMPass();
 
 } // namespace mlir
 

diff  --git a/mlir/include/mlir/Dialect/Vector/VectorOps.h b/mlir/include/mlir/Dialect/Vector/VectorOps.h
index 1aaa75290931..50fa0150ba53 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.h
@@ -24,6 +24,13 @@ class MLIRContext;
 class OwningRewritePatternList;
 namespace vector {
 
+/// Structure to control the behavior of vector transform patterns.
+struct VectorTransformsOptions {
+  /// Let vector.contract lower to vector.matrix_multiply and LLVM matrix
+  /// intrinsics.
+  bool lowerToLLVMMatrixIntrinsics = false;
+};
+
 /// Collect a set of vector-to-vector canonicalization patterns.
 void populateVectorToVectorCanonicalizationPatterns(
     OwningRewritePatternList &patterns, MLIRContext *context);
@@ -50,8 +57,9 @@ void populateVectorSlicesLoweringPatterns(OwningRewritePatternList &patterns,
 ///   OuterproductOpLowering
 /// These transformation express higher level vector ops in terms of more
 /// elementary extraction, insertion, reduction, product, and broadcast ops.
-void populateVectorContractLoweringPatterns(OwningRewritePatternList &patterns,
-                                            MLIRContext *context);
+void populateVectorContractLoweringPatterns(
+    OwningRewritePatternList &patterns, MLIRContext *context,
+    VectorTransformsOptions vectorTransformOptions = VectorTransformsOptions());
 
 /// Returns the integer type required for subscripts in the vector dialect.
 IntegerType getVectorSubscriptType(Builder &builder);

diff  --git a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
index 497cb448ef42..05e593ba300c 100644
--- a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
+++ b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
@@ -562,6 +562,7 @@ void ConvertLinalgToLLVMPass::runOnModule() {
   populateLoopToStdConversionPatterns(patterns, &getContext());
   populateStdToLLVMConversionPatterns(converter, patterns, /*useAlloca=*/false,
                                       /*emitCWrappers=*/true);
+  populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
   populateVectorToLLVMConversionPatterns(converter, patterns);
   populateLinalgToStandardConversionPatterns(patterns, &getContext());
   populateLinalgToLLVMConversionPatterns(converter, patterns, &getContext());

diff  --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 290fa4c3b109..459c7243fd46 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1150,8 +1150,8 @@ void LowerVectorToLLVMPass::runOnModule() {
   }
 }
 
-OpPassBase<ModuleOp> *mlir::createLowerVectorToLLVMPass() {
-  return new LowerVectorToLLVMPass();
+std::unique_ptr<OpPassBase<ModuleOp>> mlir::createConvertVectorToLLVMPass() {
+  return std::make_unique<LowerVectorToLLVMPass>();
 }
 
 static PassRegistration<LowerVectorToLLVMPass>

diff  --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index fb5749cfe727..38a83e01bcbf 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -42,13 +42,6 @@ using namespace mlir;
 using llvm::dbgs;
 using mlir::functional::zipMap;
 
-static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options");
-
-static llvm::cl::opt<bool> lowerToLLVMMatrixIntrinsics(
-    "vector-lower-matrix-intrinsics",
-    llvm::cl::desc("Lower vector.contract to llvm.intr.matrix.multiply"),
-    llvm::cl::init(false), llvm::cl::cat(clOptionsCategory));
-
 /// Given a shape with sizes greater than 0 along all dimensions,
 /// returns the distance, in number of elements, between a slice in a dimension
 /// and the next slice in the same dimension.
@@ -936,6 +929,11 @@ class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
 public:
   using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
 
+  ContractionOpLowering(vector::VectorTransformsOptions vectorTransformsOptions,
+                        MLIRContext *context)
+      : OpRewritePattern<vector::ContractionOp>(context),
+        vectorTransformsOptions(vectorTransformsOptions) {}
+
   PatternMatchResult matchAndRewrite(vector::ContractionOp op,
                                      PatternRewriter &rewriter) const override {
     // TODO(ajcbik): implement masks
@@ -946,33 +944,41 @@ class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
     // a new pattern.
     // TODO(ntv, fhahn): once row-major mode is available in LLVM's matrix
     // intrinsics, use that.
-    if (lowerToLLVMMatrixIntrinsics &&
+    if (vectorTransformsOptions.lowerToLLVMMatrixIntrinsics &&
         isColumnMajorMatmul(op.indexing_maps())) {
       VectorType lhsType = op.getLhsType();
       VectorType rhsType = op.getRhsType();
-      Type flattenedLHSType =
-          VectorType::get(lhsType.getNumElements(), lhsType.getElementType());
-      Type flattenedRHSType =
-          VectorType::get(rhsType.getNumElements(), rhsType.getElementType());
-      auto lhs = rewriter.create<vector::ShapeCastOp>(
-          op.getLoc(), flattenedLHSType, op.lhs());
-      auto rhs = rewriter.create<vector::ShapeCastOp>(
-          op.getLoc(), flattenedRHSType, op.rhs());
-
       unsigned lhsRows = op.getLhsType().getShape()[0];
       unsigned lhsColumns = op.getLhsType().getShape()[1];
       unsigned rhsColumns = op.getRhsType().getShape()[1];
-      Value mul = rewriter.create<vector::MatmulOp>(
-          op.getLoc(), lhs, rhs, lhsRows, lhsColumns, rhsColumns);
-      mul = rewriter.create<vector::ShapeCastOp>(op.getLoc(),
-                                                 op.acc().getType(), mul);
-      Type elementType = op.getLhsType().getElementType();
-      assert(elementType.isIntOrFloat());
-      if (elementType.isa<IntegerType>())
-        rewriter.replaceOpWithNewOp<AddIOp>(op, op.acc(), mul);
-      else
-        rewriter.replaceOpWithNewOp<AddFOp>(op, op.acc(), mul);
-      return matchSuccess();
+
+      // In cases where matrices are degenerate, scalarization issues occur in
+      // the backend. Avoid all LLVM scalarization issues for now.
+      // For more details, see: https://bugs.llvm.org/show_bug.cgi?id=45227 and
+      // https://bugs.llvm.org/show_bug.cgi?id=45229
+      // TODO(ntv, fhahn): Relax once above bugs are fixed.
+      if (lhsRows != 1 && lhsColumns != 1 && rhsColumns != 1) {
+        Type flattenedLHSType =
+            VectorType::get(lhsType.getNumElements(), lhsType.getElementType());
+        Type flattenedRHSType =
+            VectorType::get(rhsType.getNumElements(), rhsType.getElementType());
+        auto lhs = rewriter.create<vector::ShapeCastOp>(
+            op.getLoc(), flattenedLHSType, op.lhs());
+        auto rhs = rewriter.create<vector::ShapeCastOp>(
+            op.getLoc(), flattenedRHSType, op.rhs());
+
+        Value mul = rewriter.create<vector::MatmulOp>(
+            op.getLoc(), lhs, rhs, lhsRows, lhsColumns, rhsColumns);
+        mul = rewriter.create<vector::ShapeCastOp>(op.getLoc(),
+                                                   op.acc().getType(), mul);
+        Type elementType = op.getLhsType().getElementType();
+        assert(elementType.isIntOrFloat());
+        if (elementType.isa<IntegerType>())
+          rewriter.replaceOpWithNewOp<AddIOp>(op, op.acc(), mul);
+        else
+          rewriter.replaceOpWithNewOp<AddFOp>(op, op.acc(), mul);
+        return matchSuccess();
+      }
     }
 
     // Find first batch dimension in LHS/RHS, and lower when found.
@@ -1255,6 +1261,8 @@ class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
     }
     return result;
   }
+
+  vector::VectorTransformsOptions vectorTransformsOptions;
 };
 
 /// ShapeOp 2D -> 1D downcast serves the purpose of flattening 2-D to 1-D
@@ -1342,8 +1350,10 @@ void mlir::vector::populateVectorSlicesLoweringPatterns(
 }
 
 void mlir::vector::populateVectorContractLoweringPatterns(
-    OwningRewritePatternList &patterns, MLIRContext *context) {
-  patterns.insert<ContractionOpLowering, ShapeCastOp2DDownCastRewritePattern,
+    OwningRewritePatternList &patterns, MLIRContext *context,
+    VectorTransformsOptions parameters) {
+  patterns.insert<ShapeCastOp2DDownCastRewritePattern,
                   ShapeCastOp2DUpCastRewritePattern, OuterProductOpLowering>(
       context);
+  patterns.insert<ContractionOpLowering>(parameters, context);
 }

diff  --git a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
index 2440c7b4a566..bed90d6341d9 100644
--- a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
@@ -1,5 +1,5 @@
 // RUN: mlir-opt %s -test-vector-contraction-conversion | FileCheck %s
-// RUN: mlir-opt %s -test-vector-contraction-conversion -vector-lower-matrix-intrinsics | FileCheck %s --check-prefix=MATRIX
+// RUN: mlir-opt %s -test-vector-contraction-conversion=vector-lower-matrix-intrinsics=1 | FileCheck %s --check-prefix=MATRIX
 
 #dotp_accesses = [
   affine_map<(i) -> (i)>,

diff  --git a/mlir/test/lib/Transforms/TestVectorTransforms.cpp b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
index fc6095fb1fb1..8f2f64e5f60a 100644
--- a/mlir/test/lib/Transforms/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
@@ -16,7 +16,6 @@
 
 using namespace mlir;
 using namespace mlir::vector;
-
 namespace {
 
 #include "TestVectorTransformPatterns.h.inc"
@@ -44,9 +43,20 @@ struct TestVectorSlicesConversion
 
 struct TestVectorContractionConversion
     : public FunctionPass<TestVectorContractionConversion> {
+  TestVectorContractionConversion() = default;
+  TestVectorContractionConversion(const TestVectorContractionConversion &pass) {
+  }
+
+  Option<bool> lowerToLLVMMatrixIntrinsics{
+      *this, "vector-lower-matrix-intrinsics",
+      llvm::cl::desc("Lower vector.contract to llvm.intr.matrix.multiply"),
+      llvm::cl::init(false)};
+
   void runOnFunction() override {
     OwningRewritePatternList patterns;
-    populateVectorContractLoweringPatterns(patterns, &getContext());
+    VectorTransformsOptions options{
+        /*lowerToLLVMMatrixIntrinsics=*/lowerToLLVMMatrixIntrinsics};
+    populateVectorContractLoweringPatterns(patterns, &getContext(), options);
     applyPatternsGreedily(getFunction(), patterns);
   }
 };


        


More information about the Mlir-commits mailing list