[Mlir-commits] [mlir] [mlir][vector] Support index type in ND to 1D vector linearization (PR #118404)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Dec 2 14:09:13 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Amy Zhuang (ayzhuang)

<details>
<summary>Changes</summary>

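Currently, vectors with `index` element type are not supported, because `getElementTypeBitWidth` aborts for the `index` type. This patch adds an `indexBitWidth` parameter to the vector linearization patterns. When it is nonzero, it is used as the bit width of `index` elements; when it is 0 (the default), vectors of `index` are still rejected, as before.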

---
Full diff: https://github.com/llvm/llvm-project/pull/118404.diff


4 Files Affected:

- (modified) mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h (+2-2) 
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp (+59-26) 
- (modified) mlir/test/Dialect/Vector/linearize.mlir (+12-3) 
- (modified) mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp (+6-2) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index a59f06f3c1ef1b..e3c19a078c18b0 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -399,13 +399,13 @@ void populateVectorTransposeNarrowTypeRewritePatterns(
 /// the ops to get converted properly.
 void populateVectorLinearizeTypeConversionsAndLegality(
     TypeConverter &typeConverter, RewritePatternSet &patterns,
-    ConversionTarget &target, unsigned targetBitWidth);
+    ConversionTarget &target, unsigned indexBitWidth, unsigned targetBitWidth);
 
 /// Populates patterns for linearizing ND (N >= 2) vector operations to 1D
 /// vector shuffle operations.
 void populateVectorLinearizeShuffleLikeOpsPatterns(
     const TypeConverter &typeConverter, RewritePatternSet &patterns,
-    ConversionTarget &target, unsigned targetBitWidth);
+    ConversionTarget &target, unsigned indexBitWidth, unsigned targetBitWidth);
 
 } // namespace vector
 } // namespace mlir
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 757631944f224f..f0bf6276f0e659 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -25,34 +25,44 @@
 
 using namespace mlir;
 
-static bool isLessThanTargetBitWidth(Operation *op, unsigned targetBitWidth) {
+static bool isLessThanTargetBitWidth(Operation *op, unsigned indexBitWidth,
+                                     unsigned targetBitWidth) {
   auto resultTypes = op->getResultTypes();
   for (auto resType : resultTypes) {
     VectorType vecType = dyn_cast<VectorType>(resType);
-    // Reject index since getElementTypeBitWidth will abort for Index types.
-    if (!vecType || vecType.getElementType().isIndex())
+    if (!vecType)
+      return false;
+    bool isIndexTy = vecType.getElementType().isIndex();
+    // Reject index if `indexBitWidth` is not supplied.
+    if (isIndexTy && indexBitWidth == 0)
       return false;
     // There are no dimension to fold if it is a 0-D vector.
     if (vecType.getRank() == 0)
       return false;
     unsigned trailingVecDimBitWidth =
-        vecType.getShape().back() * vecType.getElementTypeBitWidth();
+        vecType.getShape().back() *
+        (isIndexTy ? indexBitWidth : vecType.getElementTypeBitWidth());
     if (trailingVecDimBitWidth >= targetBitWidth)
       return false;
   }
   return true;
 }
 
-static bool isLessThanOrEqualTargetBitWidth(Type t, unsigned targetBitWidth) {
+static bool isLessThanOrEqualTargetBitWidth(Type t, unsigned indexBitWidth,
+                                            unsigned targetBitWidth) {
   VectorType vecType = dyn_cast<VectorType>(t);
-  // Reject index since getElementTypeBitWidth will abort for Index types.
-  if (!vecType || vecType.getElementType().isIndex())
+  if (!vecType)
+    return false;
+  bool isIndexTy = vecType.getElementType().isIndex();
+  // Reject index if `indexBitWidth` is not supplied.
+  if (isIndexTy && indexBitWidth == 0)
     return false;
   // There are no dimension to fold if it is a 0-D vector.
   if (vecType.getRank() == 0)
     return false;
   unsigned trailingVecDimBitWidth =
-      vecType.getShape().back() * vecType.getElementTypeBitWidth();
+      vecType.getShape().back() *
+      (isIndexTy ? indexBitWidth : vecType.getElementTypeBitWidth());
   return trailingVecDimBitWidth <= targetBitWidth;
 }
 
@@ -61,10 +71,12 @@ struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
   using OpConversionPattern::OpConversionPattern;
   LinearizeConstant(
       const TypeConverter &typeConverter, MLIRContext *context,
+      unsigned indexBitWidth = 0,
       unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
       PatternBenefit benefit = 1)
       : OpConversionPattern(typeConverter, context, benefit),
-        targetVectorBitWidth(targetVectBitWidth) {}
+        indexBitWidth(indexBitWidth), targetVectorBitWidth(targetVectBitWidth) {
+  }
   LogicalResult
   matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
@@ -79,7 +91,7 @@ struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
 
     if (!resType)
       return rewriter.notifyMatchFailure(loc, "can't convert return type");
-    if (!isLessThanTargetBitWidth(constOp, targetVectorBitWidth))
+    if (!isLessThanTargetBitWidth(constOp, indexBitWidth, targetVectorBitWidth))
       return rewriter.notifyMatchFailure(
           loc, "Can't flatten since targetBitWidth <= OpSize");
     auto dstElementsAttr = dyn_cast<DenseElementsAttr>(constOp.getValue());
@@ -93,6 +105,7 @@ struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
   }
 
 private:
+  unsigned indexBitWidth;
   unsigned targetVectorBitWidth;
 };
 
@@ -103,14 +116,16 @@ struct LinearizeVectorizable final
 public:
   LinearizeVectorizable(
       const TypeConverter &typeConverter, MLIRContext *context,
+      unsigned indexBitWidth = 0,
       unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
       PatternBenefit benefit = 1)
       : OpTraitConversionPattern(typeConverter, context, benefit),
-        targetVectorBitWidth(targetVectBitWidth) {}
+        indexBitWidth(indexBitWidth), targetVectorBitWidth(targetVectBitWidth) {
+  }
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    if (!isLessThanTargetBitWidth(op, targetVectorBitWidth))
+    if (!isLessThanTargetBitWidth(op, indexBitWidth, targetVectorBitWidth))
       return rewriter.notifyMatchFailure(
           op->getLoc(), "Can't flatten since targetBitWidth <= OpSize");
     FailureOr<Operation *> newOp =
@@ -123,6 +138,7 @@ struct LinearizeVectorizable final
   }
 
 private:
+  unsigned indexBitWidth;
   unsigned targetVectorBitWidth;
 };
 
@@ -142,10 +158,12 @@ struct LinearizeVectorExtractStridedSlice final
   using OpConversionPattern::OpConversionPattern;
   LinearizeVectorExtractStridedSlice(
       const TypeConverter &typeConverter, MLIRContext *context,
+      unsigned indexBitWidth = 0,
       unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
       PatternBenefit benefit = 1)
       : OpConversionPattern(typeConverter, context, benefit),
-        targetVectorBitWidth(targetVectBitWidth) {}
+        indexBitWidth(indexBitWidth), targetVectorBitWidth(targetVectBitWidth) {
+  }
 
   LogicalResult
   matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
@@ -156,7 +174,8 @@ struct LinearizeVectorExtractStridedSlice final
     if (extractOp.getVector().getType().isScalable() || dstType.isScalable())
       return rewriter.notifyMatchFailure(extractOp,
                                          "scalable vectors are not supported.");
-    if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth))
+    if (!isLessThanTargetBitWidth(extractOp, indexBitWidth,
+                                  targetVectorBitWidth))
       return rewriter.notifyMatchFailure(
           extractOp, "Can't flatten since targetBitWidth <= OpSize");
 
@@ -237,6 +256,7 @@ struct LinearizeVectorExtractStridedSlice final
   }
 
 private:
+  unsigned indexBitWidth;
   unsigned targetVectorBitWidth;
 };
 
@@ -256,10 +276,12 @@ struct LinearizeVectorShuffle final
   using OpConversionPattern::OpConversionPattern;
   LinearizeVectorShuffle(
       const TypeConverter &typeConverter, MLIRContext *context,
+      unsigned indexBitWidth = 0,
       unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
       PatternBenefit benefit = 1)
       : OpConversionPattern(typeConverter, context, benefit),
-        targetVectorBitWidth(targetVectBitWidth) {}
+        indexBitWidth(indexBitWidth), targetVectorBitWidth(targetVectBitWidth) {
+  }
 
   LogicalResult
   matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
@@ -273,7 +295,8 @@ struct LinearizeVectorShuffle final
              shuffleOp.getV2VectorType().isScalable() ||
              dstType.isScalable()) &&
            "scalable vectors are not supported.");
-    if (!isLessThanTargetBitWidth(shuffleOp, targetVectorBitWidth))
+    if (!isLessThanTargetBitWidth(shuffleOp, indexBitWidth,
+                                  targetVectorBitWidth))
       return rewriter.notifyMatchFailure(
           shuffleOp, "Can't flatten since targetBitWidth <= OpSize");
 
@@ -312,6 +335,7 @@ struct LinearizeVectorShuffle final
   }
 
 private:
+  unsigned indexBitWidth;
   unsigned targetVectorBitWidth;
 };
 
@@ -329,10 +353,12 @@ struct LinearizeVectorExtract final
   using OpConversionPattern::OpConversionPattern;
   LinearizeVectorExtract(
       const TypeConverter &typeConverter, MLIRContext *context,
+      unsigned indexBitWidth = 0,
       unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
       PatternBenefit benefit = 1)
       : OpConversionPattern(typeConverter, context, benefit),
-        targetVectorBitWidth(targetVectBitWidth) {}
+        indexBitWidth(indexBitWidth), targetVectorBitWidth(targetVectBitWidth) {
+  }
   LogicalResult
   matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
@@ -345,7 +371,8 @@ struct LinearizeVectorExtract final
         cast<VectorType>(dstTy).isScalable())
       return rewriter.notifyMatchFailure(extractOp,
                                          "scalable vectors are not supported.");
-    if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth))
+    if (!isLessThanTargetBitWidth(extractOp, indexBitWidth,
+                                  targetVectorBitWidth))
       return rewriter.notifyMatchFailure(
           extractOp, "Can't flatten since targetBitWidth <= OpSize");
 
@@ -374,6 +401,7 @@ struct LinearizeVectorExtract final
   }
 
 private:
+  unsigned indexBitWidth;
   unsigned targetVectorBitWidth;
 };
 
@@ -392,10 +420,12 @@ struct LinearizeVectorInsert final
   using OpConversionPattern::OpConversionPattern;
   LinearizeVectorInsert(
       const TypeConverter &typeConverter, MLIRContext *context,
+      unsigned indexBitWidth = 0,
       unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
       PatternBenefit benefit = 1)
       : OpConversionPattern(typeConverter, context, benefit),
-        targetVectorBitWidth(targetVectBitWidth) {}
+        indexBitWidth(indexBitWidth), targetVectorBitWidth(targetVectBitWidth) {
+  }
   LogicalResult
   matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
@@ -407,7 +437,7 @@ struct LinearizeVectorInsert final
                                          "scalable vectors are not supported.");
 
     if (!isLessThanOrEqualTargetBitWidth(insertOp.getSourceType(),
-                                         targetVectorBitWidth))
+                                         indexBitWidth, targetVectorBitWidth))
       return rewriter.notifyMatchFailure(
           insertOp, "Can't flatten since targetBitWidth < OpSize");
 
@@ -457,13 +487,14 @@ struct LinearizeVectorInsert final
   }
 
 private:
+  unsigned indexBitWidth;
   unsigned targetVectorBitWidth;
 };
 } // namespace
 
 void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
     TypeConverter &typeConverter, RewritePatternSet &patterns,
-    ConversionTarget &target, unsigned targetBitWidth) {
+    ConversionTarget &target, unsigned indexBitWidth, unsigned targetBitWidth) {
 
   typeConverter.addConversion([](VectorType type) -> std::optional<Type> {
     if (!isLinearizableVector(type))
@@ -488,7 +519,7 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
       [=](Operation *op) -> std::optional<bool> {
         if ((isa<arith::ConstantOp>(op) ||
              op->hasTrait<OpTrait::Vectorizable>())) {
-          return (isLessThanTargetBitWidth(op, targetBitWidth)
+          return (isLessThanTargetBitWidth(op, indexBitWidth, targetBitWidth)
                       ? typeConverter.isLegal(op)
                       : true);
         }
@@ -496,15 +527,17 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
       });
 
   patterns.add<LinearizeConstant, LinearizeVectorizable>(
-      typeConverter, patterns.getContext(), targetBitWidth);
+      typeConverter, patterns.getContext(), indexBitWidth, targetBitWidth);
 }
 
 void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
     const TypeConverter &typeConverter, RewritePatternSet &patterns,
-    ConversionTarget &target, unsigned int targetBitWidth) {
+    ConversionTarget &target, unsigned indexBitWidth,
+    unsigned int targetBitWidth) {
   target.addDynamicallyLegalOp<vector::ShuffleOp>(
       [=](vector::ShuffleOp shuffleOp) -> bool {
-        return isLessThanTargetBitWidth(shuffleOp, targetBitWidth)
+        return isLessThanTargetBitWidth(shuffleOp, indexBitWidth,
+                                        targetBitWidth)
                    ? (typeConverter.isLegal(shuffleOp) &&
                       cast<mlir::VectorType>(shuffleOp.getResult().getType())
                               .getRank() == 1)
@@ -512,5 +545,5 @@ void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
       });
   patterns.add<LinearizeVectorShuffle, LinearizeVectorExtract,
                LinearizeVectorInsert, LinearizeVectorExtractStridedSlice>(
-      typeConverter, patterns.getContext(), targetBitWidth);
+      typeConverter, patterns.getContext(), indexBitWidth, targetBitWidth);
 }
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 543e76b5b26e0c..fe169d3e16d683 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -1,6 +1,7 @@
 // RUN: mlir-opt %s -split-input-file -test-vector-linearize -verify-diagnostics | FileCheck %s --check-prefixes=ALL,DEFAULT
 // RUN: mlir-opt %s -split-input-file -test-vector-linearize=target-vector-bitwidth=128  -verify-diagnostics | FileCheck %s --check-prefixes=ALL,BW-128
 // RUN: mlir-opt %s -split-input-file -test-vector-linearize=target-vector-bitwidth=0 | FileCheck %s --check-prefixes=ALL,BW-0
+// RUN: mlir-opt %s -split-input-file -test-vector-linearize=index-bitwidth=64 | FileCheck %s --check-prefixes=ALL,INDEX-BW-64
 
 // ALL-LABEL: test_linearize
 // ALL-SAME: (%[[ORIG_ARG:.*]]: vector<2x2xf32>)
@@ -14,6 +15,8 @@ func.func @test_linearize(%arg0: vector<2x2xf32>) -> vector<2x2xf32> {
   // BW-128: %[[RES:.*]] = vector.shape_cast %[[CST]] : vector<4xf32> to vector<2x2xf32>
 
   // BW-0: %[[RES:.*]] = arith.constant dense<{{.*}}> : vector<2x2xf32>
+
+  // INDEX-BW-64: %[[RES:.*]] = vector.shape_cast %{{.*}} : vector<4xf32> to vector<2x2xf32>
   %0 = arith.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : vector<2x2xf32>
 
   // DEFAULT: %{{.*}} =  math.sin %[[ARG]] : vector<4xf32>
@@ -45,6 +48,8 @@ func.func @test_partial_linearize(%arg0: vector<2x2xf32>, %arg1: vector<4x4xf32>
   // BW-128: %[[RES:.*]] = vector.shape_cast %[[CST]] : vector<4xf32> to vector<2x2xf32>
 
   // BW-0: %[[RES:.*]] = arith.constant dense<{{.*}}> : vector<2x2xf32>
+
+  // INDEX-BW-64: %[[RES:.*]] = vector.shape_cast %{{.*}} : vector<4xf32> to vector<2x2xf32>
   %0 = arith.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : vector<2x2xf32>
 
   // DEFAULT: %[[C2:.*]] = arith.constant dense<{{.*}}> : vector<16xf32>
@@ -79,9 +84,12 @@ func.func @test_partial_linearize(%arg0: vector<2x2xf32>, %arg1: vector<4x4xf32>
 
 // -----
 
-// ALL-LABEL: test_index_no_linearize
-func.func @test_index_no_linearize(%arg0: vector<2x2xindex>, %arg1: vector<2x2xindex>) -> vector<2x2xindex> {
-    // ALL: %[[ADD:.*]] = arith.addi {{.*}} : vector<2x2xindex>
+// ALL-LABEL: test_index_linearize
+func.func @test_index_linearize(%arg0: vector<2x2xindex>, %arg1: vector<2x2xindex>) -> vector<2x2xindex> {
+    // DEFAULT: %[[ADD:.*]] = arith.addi {{.*}} : vector<2x2xindex>
+    // BW-128: %[[ADD:.*]] = arith.addi {{.*}} : vector<2x2xindex>
+    // BW-0: %[[ADD:.*]] = arith.addi {{.*}} : vector<2x2xindex>
+    // INDEX-BW-64: %[[ADD:.*]] = arith.addi {{.*}} : vector<4xindex>
     %0 = arith.addi %arg0, %arg1 : vector<2x2xindex>
     return %0 : vector<2x2xindex>
 }
@@ -122,6 +130,7 @@ func.func @test_scalable_linearize(%arg0: vector<2x[2]xf32>) -> vector<2x[2]xf32
 
   // DEFAULT: %[[RES:.*]] = vector.shape_cast %[[ADDF]] : vector<[4]xf32> to vector<2x[2]xf32>
   // BW-128: %[[RES:.*]] = vector.shape_cast %[[ADDF]] : vector<[4]xf32> to vector<2x[2]xf32>
+  // INDEX-BW-64: %[[RES:.*]] = vector.shape_cast %{{.*}} : vector<[4]xf32> to vector<2x[2]xf32>
   // ALL: return %[[RES]] : vector<2x[2]xf32>
   return %2 : vector<2x[2]xf32>
 }
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index f67a24755ac09a..2589782aee1449 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -853,6 +853,10 @@ struct TestVectorLinearize final
     registry.insert<vector::VectorDialect>();
   }
 
+  Option<unsigned> indexBitwidth{*this, "index-bitwidth",
+                                 llvm::cl::desc("Bitwidth of the index type"),
+                                 llvm::cl::init(0)};
+
   Option<unsigned> targetVectorBitwidth{
       *this, "target-vector-bitwidth",
       llvm::cl::desc(
@@ -866,9 +870,9 @@ struct TestVectorLinearize final
     ConversionTarget target(*context);
 
     vector::populateVectorLinearizeTypeConversionsAndLegality(
-        typeConverter, patterns, target, targetVectorBitwidth);
+        typeConverter, patterns, target, indexBitwidth, targetVectorBitwidth);
     vector::populateVectorLinearizeShuffleLikeOpsPatterns(
-        typeConverter, patterns, target, targetVectorBitwidth);
+        typeConverter, patterns, target, indexBitwidth, targetVectorBitwidth);
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns))))
       return signalPassFailure();

``````````

</details>


https://github.com/llvm/llvm-project/pull/118404


More information about the Mlir-commits mailing list