[llvm-branch-commits] [mlir] c95acf0 - [mlir][vector][avx512] move avx512 lowering pass into general vector lowering

Aart Bik via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Thu Dec 3 17:28:52 PST 2020


Author: Aart Bik
Date: 2020-12-03T17:23:46-08:00
New Revision: c95acf052b53e5c18e380b8632e7de24b5e65dbe

URL: https://github.com/llvm/llvm-project/commit/c95acf052b53e5c18e380b8632e7de24b5e65dbe
DIFF: https://github.com/llvm/llvm-project/commit/c95acf052b53e5c18e380b8632e7de24b5e65dbe.diff

LOG: [mlir][vector][avx512] move avx512 lowering pass into general vector lowering

A separate AVX512 lowering pass does not compose well with the regular
vector lowering pass. Keeping the two apart risks code duplication and
lowering inconsistencies. This change removes the separate AVX512 lowering
pass and makes it an option ("enable-avx512") of the regular vector
lowering pass (viz. the vector dialect "augmented" with the AVX512 dialect).

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D92614

Added: 
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp

Modified: 
    mlir/include/mlir/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.h
    mlir/include/mlir/Conversion/Passes.h
    mlir/include/mlir/Conversion/Passes.td
    mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
    mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp
    mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
    mlir/lib/Dialect/AVX512/CMakeLists.txt
    mlir/test/Conversion/AVX512ToLLVM/convert-to-llvm.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.h b/mlir/include/mlir/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.h
index aff5c4ca2c70..06f2958a2d5a 100644
--- a/mlir/include/mlir/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.h
+++ b/mlir/include/mlir/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.h
@@ -9,21 +9,15 @@
 #ifndef MLIR_CONVERSION_AVX512TOLLVM_CONVERTAVX512TOLLVM_H_
 #define MLIR_CONVERSION_AVX512TOLLVM_CONVERTAVX512TOLLVM_H_
 
-#include <memory>
-
 namespace mlir {
+
 class LLVMTypeConverter;
-class ModuleOp;
-template <typename T> class OperationPass;
 class OwningRewritePatternList;
 
 /// Collect a set of patterns to convert from the AVX512 dialect to LLVM.
 void populateAVX512ToLLVMConversionPatterns(LLVMTypeConverter &converter,
                                             OwningRewritePatternList &patterns);
 
-/// Create a pass to convert AVX512 operations to the LLVMIR dialect.
-std::unique_ptr<OperationPass<ModuleOp>> createConvertAVX512ToLLVMPass();
-
 } // namespace mlir
 
 #endif // MLIR_CONVERSION_AVX512TOLLVM_CONVERTAVX512TOLLVM_H_

diff  --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index 0d2281f99581..cc4f59c12496 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -9,7 +9,6 @@
 #ifndef MLIR_CONVERSION_PASSES_H
 #define MLIR_CONVERSION_PASSES_H
 
-#include "mlir/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.h"
 #include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
 #include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"
 #include "mlir/Conversion/GPUCommon/GPUCommonPass.h"

diff  --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index fdf01b7f93d0..6c99d84ea30a 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -73,17 +73,6 @@ def ConvertAffineToStandard : Pass<"lower-affine"> {
   ];
 }
 
-//===----------------------------------------------------------------------===//
-// AVX512ToLLVM
-//===----------------------------------------------------------------------===//
-
-def ConvertAVX512ToLLVM : Pass<"convert-avx512-to-llvm", "ModuleOp"> {
-  let summary = "Convert the operations from the avx512 dialect into the LLVM "
-                "dialect";
-  let constructor = "mlir::createConvertAVX512ToLLVMPass()";
-  let dependentDialects = ["LLVM::LLVMDialect", "LLVM::LLVMAVX512Dialect"];
-}
-
 //===----------------------------------------------------------------------===//
 // AsyncToLLVM
 //===----------------------------------------------------------------------===//
@@ -401,15 +390,28 @@ def ConvertVectorToSCF : FunctionPass<"convert-vector-to-scf"> {
 def ConvertVectorToLLVM : Pass<"convert-vector-to-llvm", "ModuleOp"> {
   let summary = "Lower the operations from the vector dialect into the LLVM "
                 "dialect";
+  let description = [{
+
+    Convert operations from the vector dialect into the LLVM IR dialect
+    operations. The lowering pass provides several options to control
+    the kind of optimizations that are allowed. It also provides options
+    that augment the architectural-neutral vector dialect with
+    architectural-specific dialects (AVX512, Neon, etc.).
+
+  }];
   let constructor = "mlir::createConvertVectorToLLVMPass()";
-  let dependentDialects = ["LLVM::LLVMDialect"];
+  let dependentDialects = ["LLVM::LLVMDialect", "LLVM::LLVMAVX512Dialect"];
   let options = [
     Option<"reassociateFPReductions", "reassociate-fp-reductions",
            "bool", /*default=*/"false",
            "Allows llvm to reassociate floating-point reductions for speed">,
     Option<"enableIndexOptimizations", "enable-index-optimizations",
            "bool", /*default=*/"true",
-           "Allows compiler to assume indices fit in 32-bit if that yields faster code">
+           "Allows compiler to assume indices fit in 32-bit if that yields "
+	   "faster code">,
+    Option<"enableAVX512", "enable-avx512",
+           "bool", /*default=*/"false",
+           "Augments the vector dialect with the AVX512 dialect during lowering">
   ];
 }
 

diff  --git a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
index 1a6fe7d166d0..435a148e2b10 100644
--- a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
+++ b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
@@ -23,6 +23,7 @@ class OperationPass;
 struct LowerVectorToLLVMOptions {
   bool reassociateFPReductions = false;
   bool enableIndexOptimizations = true;
+  bool enableAVX512 = false;
   LowerVectorToLLVMOptions &setReassociateFPReductions(bool b) {
     reassociateFPReductions = b;
     return *this;
@@ -31,6 +32,10 @@ struct LowerVectorToLLVMOptions {
     enableIndexOptimizations = b;
     return *this;
   }
+  LowerVectorToLLVMOptions &setEnableAVX512(bool b) {
+    enableAVX512 = b;
+    return *this;
+  }
 };
 
 /// Collect a set of patterns to convert from Vector contractions to LLVM Matrix

diff  --git a/mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp b/mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp
index c1daf485501e..3950562539f6 100644
--- a/mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp
+++ b/mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp
@@ -8,10 +8,7 @@
 
 #include "mlir/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.h"
 
-#include "../PassDetail.h"
 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
-#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
-#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
 #include "mlir/Dialect/AVX512/AVX512Dialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMAVX512Dialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
@@ -19,7 +16,6 @@
 #include "mlir/Dialect/Vector/VectorOps.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/PatternMatch.h"
-#include "mlir/Pass/Pass.h"
 
 using namespace mlir;
 using namespace mlir::vector;
@@ -157,32 +153,3 @@ void mlir::populateAVX512ToLLVMConversionPatterns(
                   ScaleFOpPD512Conversion>(ctx, converter);
   // clang-format on
 }
-
-namespace {
-struct ConvertAVX512ToLLVMPass
-    : public ConvertAVX512ToLLVMBase<ConvertAVX512ToLLVMPass> {
-  void runOnOperation() override;
-};
-} // namespace
-
-void ConvertAVX512ToLLVMPass::runOnOperation() {
-  // Convert to the LLVM IR dialect.
-  OwningRewritePatternList patterns;
-  LLVMTypeConverter converter(&getContext());
-  populateAVX512ToLLVMConversionPatterns(converter, patterns);
-  populateVectorToLLVMConversionPatterns(converter, patterns);
-  populateStdToLLVMConversionPatterns(converter, patterns);
-
-  ConversionTarget target(getContext());
-  target.addLegalDialect<LLVM::LLVMDialect>();
-  target.addLegalDialect<LLVM::LLVMAVX512Dialect>();
-  target.addIllegalDialect<avx512::AVX512Dialect>();
-  if (failed(applyPartialConversion(getOperation(), target,
-                                    std::move(patterns)))) {
-    signalPassFailure();
-  }
-}
-
-std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertAVX512ToLLVMPass() {
-  return std::make_unique<ConvertAVX512ToLLVMPass>();
-}

diff  --git a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
index eeefd372f85f..c53839714af7 100644
--- a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
+++ b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_conversion_library(MLIRVectorToLLVM
   ConvertVectorToLLVM.cpp
+  ConvertVectorToLLVMPass.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/VectorToLLVM
@@ -12,6 +13,9 @@ add_mlir_conversion_library(MLIRVectorToLLVM
   Core
 
   LINK_LIBS PUBLIC
+  MLIRAVX512
+  MLIRAVX512ToLLVM
+  MLIRLLVMAVX512
   MLIRLLVMIR
   MLIRStandardToLLVM
   MLIRTargetLLVMIRModuleTranslation

diff  --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 5628c555dea4..72d0d7d92c5f 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -8,25 +8,14 @@
 
 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
 
-#include "../PassDetail.h"
 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Vector/VectorOps.h"
-#include "mlir/IR/AffineMap.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/StandardTypes.h"
 #include "mlir/Target/LLVMIR/TypeTranslation.h"
 #include "mlir/Transforms/DialectConversion.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "mlir/Transforms/Passes.h"
-#include "llvm/IR/DerivedTypes.h"
-#include "llvm/IR/Module.h"
-#include "llvm/IR/Type.h"
-#include "llvm/Support/Allocator.h"
-#include "llvm/Support/ErrorHandling.h"
 
 using namespace mlir;
 using namespace mlir::vector;
@@ -1599,45 +1588,3 @@ void mlir::populateVectorToLLVMMatrixConversionPatterns(
   patterns.insert<VectorMatmulOpConversion>(ctx, converter);
   patterns.insert<VectorFlatTransposeOpConversion>(ctx, converter);
 }
-
-namespace {
-struct LowerVectorToLLVMPass
-    : public ConvertVectorToLLVMBase<LowerVectorToLLVMPass> {
-  LowerVectorToLLVMPass(const LowerVectorToLLVMOptions &options) {
-    this->reassociateFPReductions = options.reassociateFPReductions;
-    this->enableIndexOptimizations = options.enableIndexOptimizations;
-  }
-  void runOnOperation() override;
-};
-} // namespace
-
-void LowerVectorToLLVMPass::runOnOperation() {
-  // Perform progressive lowering of operations on slices and
-  // all contraction operations. Also applies folding and DCE.
-  {
-    OwningRewritePatternList patterns;
-    populateVectorToVectorCanonicalizationPatterns(patterns, &getContext());
-    populateVectorSlicesLoweringPatterns(patterns, &getContext());
-    populateVectorContractLoweringPatterns(patterns, &getContext());
-    applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
-  }
-
-  // Convert to the LLVM IR dialect.
-  LLVMTypeConverter converter(&getContext());
-  OwningRewritePatternList patterns;
-  populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
-  populateVectorToLLVMConversionPatterns(
-      converter, patterns, reassociateFPReductions, enableIndexOptimizations);
-  populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
-  populateStdToLLVMConversionPatterns(converter, patterns);
-
-  LLVMConversionTarget target(getContext());
-  if (failed(
-          applyPartialConversion(getOperation(), target, std::move(patterns))))
-    signalPassFailure();
-}
-
-std::unique_ptr<OperationPass<ModuleOp>>
-mlir::createConvertVectorToLLVMPass(const LowerVectorToLLVMOptions &options) {
-  return std::make_unique<LowerVectorToLLVMPass>(options);
-}

diff  --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
new file mode 100644
index 000000000000..4d5576ec9d9f
--- /dev/null
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -0,0 +1,73 @@
+//===- ConvertVectorToLLVMPass.cpp - Vector to LLVM lowering pass ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
+
+#include "../PassDetail.h"
+
+#include "mlir/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.h"
+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
+#include "mlir/Dialect/AVX512/AVX512Dialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMAVX512Dialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace mlir::vector;
+
+namespace {
+struct LowerVectorToLLVMPass
+    : public ConvertVectorToLLVMBase<LowerVectorToLLVMPass> {
+  LowerVectorToLLVMPass(const LowerVectorToLLVMOptions &options) {
+    this->reassociateFPReductions = options.reassociateFPReductions;
+    this->enableIndexOptimizations = options.enableIndexOptimizations;
+    this->enableAVX512 = options.enableAVX512;
+  }
+  void runOnOperation() override;
+};
+} // namespace
+
+void LowerVectorToLLVMPass::runOnOperation() {
+  // Perform progressive lowering of operations on slices and
+  // all contraction operations. Also applies folding and DCE.
+  {
+    OwningRewritePatternList patterns;
+    populateVectorToVectorCanonicalizationPatterns(patterns, &getContext());
+    populateVectorSlicesLoweringPatterns(patterns, &getContext());
+    populateVectorContractLoweringPatterns(patterns, &getContext());
+    applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+  }
+
+  // Convert to the LLVM IR dialect.
+  LLVMTypeConverter converter(&getContext());
+  OwningRewritePatternList patterns;
+  populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
+  populateVectorToLLVMConversionPatterns(
+      converter, patterns, reassociateFPReductions, enableIndexOptimizations);
+  populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
+  populateStdToLLVMConversionPatterns(converter, patterns);
+
+  // Architecture specific augmentations.
+  LLVMConversionTarget target(getContext());
+  if (enableAVX512) {
+    target.addLegalDialect<LLVM::LLVMAVX512Dialect>();
+    target.addIllegalDialect<avx512::AVX512Dialect>();
+    populateAVX512ToLLVMConversionPatterns(converter, patterns);
+  }
+
+  if (failed(
+          applyPartialConversion(getOperation(), target, std::move(patterns))))
+    signalPassFailure();
+}
+
+std::unique_ptr<OperationPass<ModuleOp>>
+mlir::createConvertVectorToLLVMPass(const LowerVectorToLLVMOptions &options) {
+  return std::make_unique<LowerVectorToLLVMPass>(options);
+}

diff  --git a/mlir/lib/Dialect/AVX512/CMakeLists.txt b/mlir/lib/Dialect/AVX512/CMakeLists.txt
index eaacc8e46c9e..008add875a19 100644
--- a/mlir/lib/Dialect/AVX512/CMakeLists.txt
+++ b/mlir/lib/Dialect/AVX512/CMakeLists.txt
@@ -10,5 +10,4 @@ add_mlir_dialect_library(MLIRAVX512
   LINK_LIBS PUBLIC
   MLIRIR
   MLIRSideEffectInterfaces
-  MLIRVectorToLLVM
   )

diff  --git a/mlir/test/Conversion/AVX512ToLLVM/convert-to-llvm.mlir b/mlir/test/Conversion/AVX512ToLLVM/convert-to-llvm.mlir
index 936819e27eb9..f6afe9c053e8 100644
--- a/mlir/test/Conversion/AVX512ToLLVM/convert-to-llvm.mlir
+++ b/mlir/test/Conversion/AVX512ToLLVM/convert-to-llvm.mlir
@@ -1,7 +1,7 @@
-// RUN: mlir-opt %s -convert-avx512-to-llvm | mlir-opt | FileCheck %s
+// RUN: mlir-opt %s -convert-vector-to-llvm="enable-avx512" | mlir-opt | FileCheck %s
 
 func @avx512_mask_rndscale(%a: vector<16xf32>, %b: vector<8xf64>, %i32: i32, %i16: i16, %i8: i8)
-  -> (vector<16xf32>, vector<8xf64>)
+  -> (vector<16xf32>, vector<8xf64>, vector<16xf32>, vector<8xf64>)
 {
   // CHECK: llvm_avx512.mask.rndscale.ps.512
   %0 = avx512.mask.rndscale %a, %i32, %a, %i16, %i32: vector<16xf32>
@@ -9,9 +9,10 @@ func @avx512_mask_rndscale(%a: vector<16xf32>, %b: vector<8xf64>, %i32: i32, %i1
   %1 = avx512.mask.rndscale %b, %i32, %b, %i8, %i32: vector<8xf64>
 
   // CHECK: llvm_avx512.mask.scalef.ps.512
-  %a0 = avx512.mask.scalef %a, %a, %a, %i16, %i32: vector<16xf32>
+  %2 = avx512.mask.scalef %a, %a, %a, %i16, %i32: vector<16xf32>
   // CHECK: llvm_avx512.mask.scalef.pd.512
-  %a1 = avx512.mask.scalef %b, %b, %b, %i8, %i32: vector<8xf64>
+  %3 = avx512.mask.scalef %b, %b, %b, %i8, %i32: vector<8xf64>
 
-  return %a0, %a1: vector<16xf32>, vector<8xf64>
+  // Keep results alive.
+  return %0, %1, %2, %3 : vector<16xf32>, vector<8xf64>, vector<16xf32>, vector<8xf64>
 }


        


More information about the llvm-branch-commits mailing list