[Mlir-commits] [mlir] 7bc8ad5 - [mlir][vector][nfc] Rename index optimizations option

Javier Setoain llvmlistbot at llvm.org
Tue Mar 29 03:34:58 PDT 2022


Author: Javier Setoain
Date: 2022-03-29T11:33:22+01:00
New Revision: 7bc8ad5109eb955b8da9b279955bae098e1bd669

URL: https://github.com/llvm/llvm-project/commit/7bc8ad5109eb955b8da9b279955bae098e1bd669
DIFF: https://github.com/llvm/llvm-project/commit/7bc8ad5109eb955b8da9b279955bae098e1bd669.diff

LOG: [mlir][vector][nfc] Rename index optimizations option

We are using "enable-index-optimizations" and "indexOptimizations" as
names for an optimization that uses i32 for indices within a vector
(for instance, when building a vector comparison for mask generation).
The name is confusing and suggests a scope beyond these vector indices.
This change makes the function of the option explicit in its name.

Differential Revision: https://reviews.llvm.org/D122415

Added: 
    

Modified: 
    mlir/include/mlir/Conversion/Passes.td
    mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
    mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
    mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
    mlir/test/Conversion/VectorToLLVM/vector-mask-to-llvm.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index bdef5615ff2a1..f4bd4d8c65ff2 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -819,10 +819,10 @@ def ConvertVectorToLLVM : Pass<"convert-vector-to-llvm", "ModuleOp"> {
     Option<"reassociateFPReductions", "reassociate-fp-reductions",
            "bool", /*default=*/"false",
            "Allows llvm to reassociate floating-point reductions for speed">,
-    Option<"indexOptimizations", "enable-index-optimizations",
+    Option<"force32BitVectorIndices", "force-32bit-vector-indices",
            "bool", /*default=*/"true",
-           "Allows compiler to assume indices fit in 32-bit if that yields "
-	   "faster code">,
+           "Allows compiler to assume vector indices fit in 32-bit if that "
+     "yields faster code">,
     Option<"amx", "enable-amx",
            "bool", /*default=*/"false",
            "Enables the use of AMX dialect while lowering the vector "

diff  --git a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
index 94cb53f9300b8..80a6454a42498 100644
--- a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
+++ b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
@@ -28,7 +28,7 @@ struct LowerVectorToLLVMOptions {
     return *this;
   }
   LowerVectorToLLVMOptions &enableIndexOptimizations(bool b = true) {
-    indexOptimizations = b;
+    force32BitVectorIndices = b;
     return *this;
   }
   LowerVectorToLLVMOptions &enableArmNeon(bool b = true) {
@@ -49,7 +49,7 @@ struct LowerVectorToLLVMOptions {
   }
 
   bool reassociateFPReductions{false};
-  bool indexOptimizations{true};
+  bool force32BitVectorIndices{true};
   bool armNeon{false};
   bool armSVE{false};
   bool amx{false};
@@ -63,10 +63,9 @@ void populateVectorToLLVMMatrixConversionPatterns(LLVMTypeConverter &converter,
                                                   RewritePatternSet &patterns);
 
 /// Collect a set of patterns to convert from the Vector dialect to LLVM.
-/// If `indexOptimizations` is set, assume indices fit in 32-bit.
 void populateVectorToLLVMConversionPatterns(
     LLVMTypeConverter &converter, RewritePatternSet &patterns,
-    bool reassociateFPReductions = false, bool indexOptimizations = false);
+    bool reassociateFPReductions = false, bool force32BitVectorIndices = false);
 
 /// Create a pass to convert vector operations to the LLVMIR dialect.
 std::unique_ptr<OperationPass<ModuleOp>> createConvertVectorToLLVMPass(

diff  --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
index 91378a8b33e55..b5e9f25c710e2 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
@@ -102,7 +102,7 @@ void populateVectorTransferLoweringPatterns(
 
 /// These patterns materialize masks for various vector ops such as transfers.
 void populateVectorMaskMaterializationPatterns(RewritePatternSet &patterns,
-                                               bool indexOptimizations);
+                                               bool force32BitVectorIndices);
 
 /// Collect a set of patterns to propagate insert_map/extract_map in the ssa
 /// chain.

diff  --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 3f6b3524b8965..a164c7d167dc6 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -909,7 +909,7 @@ class VectorCreateMaskOpRewritePattern
   explicit VectorCreateMaskOpRewritePattern(MLIRContext *context,
                                             bool enableIndexOpt)
       : OpRewritePattern<vector::CreateMaskOp>(context),
-        indexOptimizations(enableIndexOpt) {}
+        force32BitVectorIndices(enableIndexOpt) {}
 
   LogicalResult matchAndRewrite(vector::CreateMaskOp op,
                                 PatternRewriter &rewriter) const override {
@@ -917,7 +917,7 @@ class VectorCreateMaskOpRewritePattern
     if (dstType.getRank() != 1 || !dstType.cast<VectorType>().isScalable())
       return failure();
     IntegerType idxType =
-        indexOptimizations ? rewriter.getI32Type() : rewriter.getI64Type();
+        force32BitVectorIndices ? rewriter.getI32Type() : rewriter.getI64Type();
     auto loc = op->getLoc();
     Value indices = rewriter.create<LLVM::StepVectorOp>(
         loc, LLVM::getVectorType(idxType, dstType.getShape()[0],
@@ -932,7 +932,7 @@ class VectorCreateMaskOpRewritePattern
   }
 
 private:
-  const bool indexOptimizations;
+  const bool force32BitVectorIndices;
 };
 
 class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
@@ -1192,15 +1192,14 @@ struct VectorSplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
 } // namespace
 
 /// Populate the given list with patterns that convert from Vector to LLVM.
-void mlir::populateVectorToLLVMConversionPatterns(LLVMTypeConverter &converter,
-                                                  RewritePatternSet &patterns,
-                                                  bool reassociateFPReductions,
-                                                  bool indexOptimizations) {
+void mlir::populateVectorToLLVMConversionPatterns(
+    LLVMTypeConverter &converter, RewritePatternSet &patterns,
+    bool reassociateFPReductions, bool force32BitVectorIndices) {
   MLIRContext *ctx = converter.getDialect()->getContext();
   patterns.add<VectorFMAOpNDRewritePattern>(ctx);
   populateVectorInsertExtractStridedSliceTransforms(patterns);
   patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions);
-  patterns.add<VectorCreateMaskOpRewritePattern>(ctx, indexOptimizations);
+  patterns.add<VectorCreateMaskOpRewritePattern>(ctx, force32BitVectorIndices);
   patterns
       .add<VectorBitCastOpConversion, VectorShuffleOpConversion,
            VectorExtractElementOpConversion, VectorExtractOpConversion,

diff  --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 68edc23e82375..3493e2c751654 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -34,7 +34,7 @@ struct LowerVectorToLLVMPass
     : public ConvertVectorToLLVMBase<LowerVectorToLLVMPass> {
   LowerVectorToLLVMPass(const LowerVectorToLLVMOptions &options) {
     this->reassociateFPReductions = options.reassociateFPReductions;
-    this->indexOptimizations = options.indexOptimizations;
+    this->force32BitVectorIndices = options.force32BitVectorIndices;
     this->armNeon = options.armNeon;
     this->armSVE = options.armSVE;
     this->amx = options.amx;
@@ -77,11 +77,11 @@ void LowerVectorToLLVMPass::runOnOperation() {
   // Convert to the LLVM IR dialect.
   LLVMTypeConverter converter(&getContext());
   RewritePatternSet patterns(&getContext());
-  populateVectorMaskMaterializationPatterns(patterns, indexOptimizations);
+  populateVectorMaskMaterializationPatterns(patterns, force32BitVectorIndices);
   populateVectorTransferLoweringPatterns(patterns);
   populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
   populateVectorToLLVMConversionPatterns(
-      converter, patterns, reassociateFPReductions, indexOptimizations);
+      converter, patterns, reassociateFPReductions, force32BitVectorIndices);
   populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
 
   // Architecture specific augmentations.

diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 2ca6481b18998..dd008056bb5fb 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -2185,22 +2185,22 @@ struct BubbleUpBitCastForStridedSliceInsert
 //       generates more elaborate instructions for this intrinsic since it
 //       is very conservative on the boundary conditions.
 static Value buildVectorComparison(PatternRewriter &rewriter, Operation *op,
-                                   bool indexOptimizations, int64_t dim,
+                                   bool force32BitVectorIndices, int64_t dim,
                                    Value b, Value *off = nullptr) {
   auto loc = op->getLoc();
   // If we can assume all indices fit in 32-bit, we perform the vector
   // comparison in 32-bit to get a higher degree of SIMD parallelism.
   // Otherwise we perform the vector comparison using 64-bit indices.
   Type idxType =
-      indexOptimizations ? rewriter.getI32Type() : rewriter.getI64Type();
+      force32BitVectorIndices ? rewriter.getI32Type() : rewriter.getI64Type();
   DenseIntElementsAttr indicesAttr;
-  if (dim == 0 && indexOptimizations) {
+  if (dim == 0 && force32BitVectorIndices) {
     indicesAttr = DenseIntElementsAttr::get(
         VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int32_t>{0});
   } else if (dim == 0) {
     indicesAttr = DenseIntElementsAttr::get(
         VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int64_t>{0});
-  } else if (indexOptimizations) {
+  } else if (force32BitVectorIndices) {
     indicesAttr = rewriter.getI32VectorAttr(
         llvm::to_vector<4>(llvm::seq<int32_t>(0, dim)));
   } else {
@@ -2227,7 +2227,7 @@ struct MaterializeTransferMask : public OpRewritePattern<ConcreteOp> {
 public:
   explicit MaterializeTransferMask(MLIRContext *context, bool enableIndexOpt)
       : mlir::OpRewritePattern<ConcreteOp>(context),
-        indexOptimizations(enableIndexOpt) {}
+        force32BitVectorIndices(enableIndexOpt) {}
 
   LogicalResult matchAndRewrite(ConcreteOp xferOp,
                                 PatternRewriter &rewriter) const override {
@@ -2270,7 +2270,7 @@ struct MaterializeTransferMask : public OpRewritePattern<ConcreteOp> {
   }
 
 private:
-  const bool indexOptimizations;
+  const bool force32BitVectorIndices;
 };
 
 /// Conversion pattern for a `vector.create_mask` (0-D and 1-D only).
@@ -2280,7 +2280,7 @@ class VectorCreateMaskOpConversion
   explicit VectorCreateMaskOpConversion(MLIRContext *context,
                                         bool enableIndexOpt)
       : mlir::OpRewritePattern<vector::CreateMaskOp>(context),
-        indexOptimizations(enableIndexOpt) {}
+        force32BitVectorIndices(enableIndexOpt) {}
 
   LogicalResult matchAndRewrite(vector::CreateMaskOp op,
                                 PatternRewriter &rewriter) const override {
@@ -2291,14 +2291,14 @@ class VectorCreateMaskOpConversion
     if (rank > 1)
       return failure();
     rewriter.replaceOp(
-        op, buildVectorComparison(rewriter, op, indexOptimizations,
+        op, buildVectorComparison(rewriter, op, force32BitVectorIndices,
                                   rank == 0 ? 0 : dstType.getDimSize(0),
                                   op.getOperand(0)));
     return success();
   }
 
 private:
-  const bool indexOptimizations;
+  const bool force32BitVectorIndices;
 };
 
 // Drop inner most contiguous unit dimensions from transfer_read operand.
@@ -2592,11 +2592,11 @@ struct ScanToArithOps : public OpRewritePattern<vector::ScanOp> {
 } // namespace
 
 void mlir::vector::populateVectorMaskMaterializationPatterns(
-    RewritePatternSet &patterns, bool indexOptimizations) {
+    RewritePatternSet &patterns, bool force32BitVectorIndices) {
   patterns.add<VectorCreateMaskOpConversion,
                MaterializeTransferMask<vector::TransferReadOp>,
                MaterializeTransferMask<vector::TransferWriteOp>>(
-      patterns.getContext(), indexOptimizations);
+      patterns.getContext(), force32BitVectorIndices);
 }
 
 void mlir::vector::populateShapeCastFoldingPatterns(

diff  --git a/mlir/test/Conversion/VectorToLLVM/vector-mask-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-mask-to-llvm.mlir
index 3c2ac46613310..0c44cea440faa 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-mask-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-mask-to-llvm.mlir
@@ -1,5 +1,5 @@
-// RUN: mlir-opt %s --convert-vector-to-llvm='enable-index-optimizations=1' | FileCheck %s --check-prefix=CMP32
-// RUN: mlir-opt %s --convert-vector-to-llvm='enable-index-optimizations=0' | FileCheck %s --check-prefix=CMP64
+// RUN: mlir-opt %s --convert-vector-to-llvm='force-32bit-vector-indices=1' | FileCheck %s --check-prefix=CMP32
+// RUN: mlir-opt %s --convert-vector-to-llvm='force-32bit-vector-indices=0' | FileCheck %s --check-prefix=CMP64
 
 // CMP32-LABEL: @genbool_var_1d(
 // CMP32-SAME: %[[ARG:.*]]: index)


        


More information about the Mlir-commits mailing list