[Mlir-commits] [mlir] 6f5c4f2 - [mlir][vector]Add Vector bitwidth target to Linearize Vectorizable and Constant Ops (#83314)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Mar 4 17:17:54 PST 2024


Author: Balaji V. Iyer
Date: 2024-03-04T19:17:51-06:00
New Revision: 6f5c4f2eacf24cecfc0faf0a204137ce65eecc2d

URL: https://github.com/llvm/llvm-project/commit/6f5c4f2eacf24cecfc0faf0a204137ce65eecc2d
DIFF: https://github.com/llvm/llvm-project/commit/6f5c4f2eacf24cecfc0faf0a204137ce65eecc2d.diff

LOG: [mlir][vector]Add Vector bitwidth target to Linearize Vectorizable and Constant Ops (#83314)

Added a new flag `targetVectorBitwidth` to capture bit-width input.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
    mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
    mlir/test/Dialect/Vector/linearize.mlir
    mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index 46bb3ddec0baf6..453fa73429dd1a 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -387,7 +387,7 @@ void populateVectorTransposeNarrowTypeRewritePatterns(
 /// the ops to get converted properly.
 void populateVectorLinearizeTypeConversionsAndLegality(
     TypeConverter &typeConverter, RewritePatternSet &patterns,
-    ConversionTarget &target);
+    ConversionTarget &target, unsigned targetBitWidth);
 
 } // namespace vector
 } // namespace mlir

diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index c5352043955579..7ca03537049812 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -19,10 +19,30 @@
 
 using namespace mlir;
 
+/// Whether all results are non-index vectors with trailing-dim bits < target.
+static bool isLessThanTargetBitWidth(Operation *op, unsigned targetBitWidth) {
+  for (Type resType : op->getResultTypes()) {
+    // Skip scalars, 0-D vectors, and index (getElementTypeBitWidth aborts).
+    auto vecTy = dyn_cast<VectorType>(resType);
+    if (!vecTy || vecTy.getRank() == 0 || vecTy.getElementType().isIndex())
+      return false;
+    unsigned trailingVecDimBitWidth =
+        vecTy.getShape().back() * vecTy.getElementTypeBitWidth();
+    if (trailingVecDimBitWidth >= targetBitWidth)
+      return false;
+  }
+  return true;
+}
+
 namespace {
 struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
   using OpConversionPattern::OpConversionPattern;
-
+  LinearizeConstant(
+      const TypeConverter &typeConverter, MLIRContext *context,
+      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+      PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit),
+        targetVectorBitWidth(targetVectBitWidth) {}
   LogicalResult
   matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
@@ -31,7 +51,9 @@ struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
         getTypeConverter()->convertType<VectorType>(constOp.getType());
     if (!resType)
       return rewriter.notifyMatchFailure(loc, "can't convert return type");
-
+    if (!isLessThanTargetBitWidth(constOp, targetVectorBitWidth))
+      return rewriter.notifyMatchFailure(
+          loc, "Can't flatten since targetBitWidth <= OpSize");
     auto dstElementsAttr = dyn_cast<DenseElementsAttr>(constOp.getValue());
     if (!dstElementsAttr)
       return rewriter.notifyMatchFailure(loc, "unsupported attr type");
@@ -41,15 +63,28 @@ struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
                                                    dstElementsAttr);
     return success();
   }
+
+private:
+  unsigned targetVectorBitWidth;
 };
 
 struct LinearizeVectorizable final
     : OpTraitConversionPattern<OpTrait::Vectorizable> {
   using OpTraitConversionPattern::OpTraitConversionPattern;
 
+public:
+  LinearizeVectorizable(
+      const TypeConverter &typeConverter, MLIRContext *context,
+      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+      PatternBenefit benefit = 1)
+      : OpTraitConversionPattern(typeConverter, context, benefit),
+        targetVectorBitWidth(targetVectBitWidth) {}
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
+    if (!isLessThanTargetBitWidth(op, targetVectorBitWidth))
+      return rewriter.notifyMatchFailure(
+          op->getLoc(), "Can't flatten since targetBitWidth <= OpSize");
     FailureOr<Operation *> newOp =
         convertOpResultTypes(op, operands, *getTypeConverter(), rewriter);
     if (failed(newOp))
@@ -58,12 +93,16 @@ struct LinearizeVectorizable final
     rewriter.replaceOp(op, (*newOp)->getResults());
     return success();
   }
+
+private:
+  unsigned targetVectorBitWidth;
 };
 } // namespace
 
 void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
     TypeConverter &typeConverter, RewritePatternSet &patterns,
-    ConversionTarget &target) {
+    ConversionTarget &target, unsigned targetBitWidth) {
+
   typeConverter.addConversion([](VectorType type) -> std::optional<Type> {
     // Ignore scalable vectors for now.
     if (type.getRank() <= 1 || type.isScalable())
@@ -83,15 +122,17 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
   typeConverter.addArgumentMaterialization(materializeCast);
   typeConverter.addSourceMaterialization(materializeCast);
   typeConverter.addTargetMaterialization(materializeCast);
-
   target.markUnknownOpDynamicallyLegal(
-      [&](Operation *op) -> std::optional<bool> {
-        if (isa<arith::ConstantOp>(op) || op->hasTrait<OpTrait::Vectorizable>())
-          return typeConverter.isLegal(op);
-
+      [=](Operation *op) -> std::optional<bool> {
+        if ((isa<arith::ConstantOp>(op) ||
+             op->hasTrait<OpTrait::Vectorizable>())) {
+          return (isLessThanTargetBitWidth(op, targetBitWidth)
+                      ? typeConverter.isLegal(op)
+                      : true);
+        }
         return std::nullopt;
       });
 
-  patterns.add<LinearizeConstant, LinearizeVectorizable>(typeConverter,
-                                                         patterns.getContext());
+  patterns.add<LinearizeConstant, LinearizeVectorizable>(
+      typeConverter, patterns.getContext(), targetBitWidth);
 }

diff  --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 85e23103eaedb7..2cbf9bec7a4136 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -1,19 +1,92 @@
 // RUN: mlir-opt %s -split-input-file -test-vector-linearize | FileCheck %s
+// RUN: mlir-opt %s -split-input-file -test-vector-linearize=target-vector-bitwidth=128 | FileCheck %s --check-prefix=CHECK128
+// RUN: mlir-opt %s -split-input-file -test-vector-linearize=target-vector-bitwidth=0 | FileCheck %s --check-prefix=CHECK0
 
 // CHECK-LABEL: test_linearize
+// CHECK128-LABEL: test_linearize
+// CHECK0-LABEL: test_linearize
 //  CHECK-SAME: (%[[ORIG_ARG:.*]]: vector<2x2xf32>)
+//  CHECK128-SAME: (%[[ORIG_ARG:.*]]: vector<2x2xf32>)
 //       CHECK: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<2x2xf32> to vector<4xf32>
+//       CHECK128: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<2x2xf32> to vector<4xf32>
 func.func @test_linearize(%arg0: vector<2x2xf32>) -> vector<2x2xf32> {
 //       CHECK: %[[C1:.*]] = arith.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : vector<4xf32>
+//       CHECK128: %[[C1:.*]] = arith.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : vector<4xf32>
+//       CHECK0: %[[C1:.*]] = arith.constant dense<{{.*}}> : vector<2x2xf32>
+
   %0 = arith.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : vector<2x2xf32>
 //       CHECK: %[[RES:.*]] = vector.shape_cast %[[C1]] : vector<4xf32> to vector<2x2xf32>
-
+//       CHECK128: %[[RES:.*]] = vector.shape_cast %[[C1]] : vector<4xf32> to vector<2x2xf32>
 // Arith and math ops are handled in generic way, check some of them
 //       CHECK: %{{.*}} =  math.sin %[[ARG]] : vector<4xf32>
+//       CHECK128: %{{.*}} =  math.sin %[[ARG]] : vector<4xf32>
+//       CHECK0: %{{.*}} =  math.sin %{{.*}} : vector<2x2xf32>
+  %1 = math.sin %arg0 : vector<2x2xf32>
+//       CHECK: %{{.*}} = arith.addf %[[ARG]], %[[C1]] : vector<4xf32>
+//       CHECK128: %{{.*}} = arith.addf %[[ARG]], %[[C1]] : vector<4xf32>
+//       CHECK0: %{{.*}} = arith.addf %{{.*}} : vector<2x2xf32>
+
+  %2 = arith.addf %arg0, %0 :  vector<2x2xf32>
+
+//       CHECK: return %[[RES]] : vector<2x2xf32>
+//       CHECK128: return %[[RES]] : vector<2x2xf32>
+  return %0 : vector<2x2xf32>
+}
+
+// CHECK-LABEL: test_partial_linearize
+// CHECK128-LABEL: test_partial_linearize
+// CHECK0-LABEL: test_partial_linearize
+//  CHECK-SAME: (%[[ORIG_ARG:.*]]: vector<2x2xf32>, %[[ORIG_ARG2:.*]]: vector<4x4xf32>)
+//  CHECK128-SAME: (%[[ORIG_ARG:.*]]: vector<2x2xf32>, %[[ORIG_ARG2:.*]]: vector<4x4xf32>)
+//  CHECK0-SAME: (%[[ORIG_ARG:.*]]: vector<2x2xf32>, %[[ORIG_ARG2:.*]]: vector<4x4xf32>)
+//       CHECK: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<2x2xf32> to vector<4xf32>
+//       CHECK128: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<2x2xf32> to vector<4xf32>
+//       CHECK: %[[ARG2:.*]] = vector.shape_cast %[[ORIG_ARG2]] : vector<4x4xf32> to vector<16xf32>
+func.func @test_partial_linearize(%arg0: vector<2x2xf32>, %arg1: vector<4x4xf32>) -> vector<2x2xf32> {
+//       CHECK: %[[C1:.*]] = arith.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : vector<4xf32>
+//       CHECK128: %[[C1:.*]] = arith.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : vector<4xf32>
+//       CHECK0: %[[C1:.*]] = arith.constant dense<{{.*}}> : vector<2x2xf32>
+
+  %0 = arith.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : vector<2x2xf32>
+//       CHECK: %[[RES:.*]] = vector.shape_cast %[[C1]] : vector<4xf32> to vector<2x2xf32>
+//       CHECK128: %[[RES:.*]] = vector.shape_cast %[[C1]] : vector<4xf32> to vector<2x2xf32>
+
+  // CHECK: %[[C2:.*]] = arith.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 1.000000e+00, 2.000000e+00, 5.000000e+00, 6.000000e+00]> : vector<16xf32>
+  // CHECK128: %[[C2:.*]] = arith.constant dense<{{.*}}> : vector<4x4xf32>
+  // CHECK0: %[[C2:.*]] = arith.constant dense<{{.*}}> : vector<4x4xf32>
+  %5 = arith.constant dense<[[1.0, 2.0, 3.0, 4.0], [1.0, 2.0,3.0, 4.0], [1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 5.0, 6.0]]> : vector<4x4xf32>
+// Arith and math ops are handled in generic way, check some of them
+//       CHECK: %[[SIN:.*]] =  math.sin %[[ARG]] : vector<4xf32>
+//       CHECK128: %[[SIN:.*]] =  math.sin %[[ARG]] : vector<4xf32>
+//       CHECK0: %[[SIN:.*]] =  math.sin %[[ORIG_ARG]] : vector<2x2xf32>
   %1 = math.sin %arg0 : vector<2x2xf32>
+
+  //     CHECK: %[[SIN1:.*]] =  math.sin %[[ARG2]] : vector<16xf32>
+//       CHECK128: %[[SIN1:.*]] =  math.sin %[[ORIG_ARG2]] : vector<4x4xf32>
+//       CHECK0: %[[SIN1:.*]] =  math.sin %[[ORIG_ARG2]] : vector<4x4xf32>
+  %6 = math.sin %arg1 : vector<4x4xf32>
 //       CHECK: %{{.*}} = arith.addf %[[ARG]], %[[C1]] : vector<4xf32>
+//       CHECK128: %{{.*}} = arith.addf %[[ARG]], %[[C1]] : vector<4xf32>
+//       CHECK0: %{{.*}} = arith.addf %{{.*}} : vector<2x2xf32>
+
   %2 = arith.addf %arg0, %0 :  vector<2x2xf32>
 
+  // CHECK: %[[ADD2:.*]] = arith.addf %[[ARG2]], %[[C2]] : vector<16xf32>
+  // CHECK128: %[[ADD2:.*]] = arith.addf %[[ORIG_ARG2]], %[[C2]] : vector<4x4xf32>
+  // CHECK0: %[[ADD2:.*]] = arith.addf %[[ORIG_ARG2]], %[[C2]] : vector<4x4xf32>
+  %7 = arith.addf %arg1, %5 : vector<4x4xf32>
 //       CHECK: return %[[RES]] : vector<2x2xf32>
+//       CHECK128: return %[[RES]] : vector<2x2xf32>
   return %0 : vector<2x2xf32>
 }
+
+// CHECK-LABEL: test_index_no_linearize
+// CHECK128-LABEL: test_index_no_linearize
+// CHECK0-LABEL: test_index_no_linearize
+func.func @test_index_no_linearize(%arg0: vector<2x2xindex>, %arg1: vector<2x2xindex>) -> vector<2x2xindex> {
+    // CHECK: %[[ADD:.*]] = arith.addi {{.*}} : vector<2x2xindex>
+    // CHECK128: %[[ADD:.*]] = arith.addi {{.*}} : vector<2x2xindex>
+    // CHECK0: %[[ADD:.*]] = arith.addi {{.*}} : vector<2x2xindex>
+    %0 = arith.addi %arg0, %arg1 : vector<2x2xindex>
+    return %0 : vector<2x2xindex>
+}

diff  --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 915f713f7047be..f14fb18706d1b7 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -840,6 +840,9 @@ struct TestVectorLinearize final
     : public PassWrapper<TestVectorLinearize, OperationPass<>> {
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorLinearize)
 
+  TestVectorLinearize() = default;
+  TestVectorLinearize(const TestVectorLinearize &pass) : PassWrapper(pass) {}
+
   StringRef getArgument() const override { return "test-vector-linearize"; }
   StringRef getDescription() const override {
     return "Linearizes ND vectors for N >= 2 into 1D vectors";
@@ -848,6 +851,11 @@ struct TestVectorLinearize final
     registry.insert<vector::VectorDialect>();
   }
 
+  Option<unsigned> targetVectorBitwidth{
+      *this, "target-vector-bitwidth",
+      llvm::cl::desc(
+          "Minimum vector bitwidth to enable the flattening transformation"),
+      llvm::cl::init(std::numeric_limits<unsigned>::max())};
   void runOnOperation() override {
     auto *context = &getContext();
 
@@ -855,8 +863,8 @@ struct TestVectorLinearize final
     RewritePatternSet patterns(context);
     ConversionTarget target(*context);
 
-    vector::populateVectorLinearizeTypeConversionsAndLegality(typeConverter,
-                                                              patterns, target);
+    vector::populateVectorLinearizeTypeConversionsAndLegality(
+        typeConverter, patterns, target, targetVectorBitwidth);
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns))))
       return signalPassFailure();


        


More information about the Mlir-commits mailing list