[Mlir-commits] [mlir] [mlir][vector] Address linearization comments (post commit) (PR #138075)

James Newling llvmlistbot at llvm.org
Mon May 5 17:04:10 PDT 2025


https://github.com/newling updated https://github.com/llvm/llvm-project/pull/138075

>From 569c56013c8fe28fb8754ce124cffab3da946f7d Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Wed, 30 Apr 2025 18:02:04 -0700
Subject: [PATCH 1/4] first commit (needs refinement)

---
 .../Vector/Transforms/VectorRewritePatterns.h |  9 +++++--
 .../Transforms/StructuralTypeConversions.cpp  |  2 +-
 .../Vector/Transforms/VectorLinearize.cpp     | 25 ++++++++++++-------
 .../Dialect/Vector/TestVectorTransforms.cpp   | 14 ++++++++---
 4 files changed, 34 insertions(+), 16 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index f1100d5cf8b68..c45ecb3bebc1c 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -407,8 +407,13 @@ void populateVectorTransposeNarrowTypeRewritePatterns(
     RewritePatternSet &patterns, PatternBenefit benefit = 1);
 
 /// Initialize `typeConverter` and `conversionTarget` for vector linearization.
-/// This registers (1) which operations are legal and hence should not be
-/// linearized, (2) what converted types are (rank-1 vectors) and how to
+///
+/// Definition: here 'linearization' means converting a single operation with
+/// 1+ vector operands and results of rank>1, into a single operation whose
+/// vector operands are all of rank<=1.
+///
+/// This function registers (1) which operations are legal, and hence should not
+/// be linearized, (2) what the converted types are (rank-1 vectors) and how to
 /// materialze the conversion (with shape_cast)
 ///
 /// Note: the set of legal operations can be extended by a user if for example
diff --git a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
index 40d2e254fb7dd..09326242eec2a 100644
--- a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
@@ -99,7 +99,7 @@ class ConvertForOpTypes
     // PR47938 tracks this issue, but it seems hard to fix. Instead, we need
     // to clone the op.
     //
-    // 2. We need to resue the original region instead of cloning it, otherwise
+    // 2. We need to reuse the original region instead of cloning it, otherwise
     // the dialect conversion framework thinks that we just inserted all the
     // cloned child ops. But what we want is to "take" the child regions and let
     // the dialect conversion framework continue recursively into ops inside
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index b9cef003fa365..8dee2454d53c2 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -134,9 +134,6 @@ struct LinearizeVectorExtractStridedSlice final
     VectorType dstType =
         getTypeConverter()->convertType<VectorType>(extractOp.getType());
     assert(dstType && "vector type destination expected.");
-    if (extractOp.getVector().getType().isScalable() || dstType.isScalable())
-      return rewriter.notifyMatchFailure(extractOp,
-                                         "scalable vectors are not supported.");
 
     ArrayAttr offsets = extractOp.getOffsets();
     ArrayAttr sizes = extractOp.getSizes();
@@ -447,18 +444,21 @@ struct LinearizeVectorSplat final
 
 } // namespace
 
-/// Return true if the operation `op` does not support scalable vectors and
-/// has at least 1 scalable vector result. These ops should all eventually
-/// support scalable vectors, and this function should be removed.
+/// Some operations currently cannot be linearized if they have scalable vector
+/// results. This function returns true if `op` is such an operation.
 static bool isNotLinearizableBecauseScalable(Operation *op) {
 
   bool unsupported =
       isa<vector::ExtractStridedSliceOp, vector::ExtractOp, vector::InsertOp>(
           op);
+
+  // Case where linearization is possible even when there are scalable vector
+  // results.
   if (!unsupported)
     return false;
 
-  // Check if any of the results is a scalable vector type.
+  // Check if any of the results is a scalable vector type, and if so,
+  // return true (not linearizable).
   auto types = op->getResultTypes();
   bool containsScalableResult =
       std::any_of(types.begin(), types.end(), [](Type type) {
@@ -469,10 +469,17 @@ static bool isNotLinearizableBecauseScalable(Operation *op) {
   return containsScalableResult;
 }
 
+/// This method defines a set of operations that are not linearizable,
+/// and hence considered legal for the conversion target. These ops are
+/// currently
+///
+/// 1) Ops that are not in the vector dialect, are not ConstantLike, and are not
+///    Vectorizable.
+///
+/// 2) Certain ops with scalable vector results, for which support has not yet
+///    been added.
 static bool isNotLinearizable(Operation *op) {
 
-  // Only ops that are in the vector dialect, are ConstantLike, or
-  // are Vectorizable might be linearized currently.
   StringLiteral vectorDialect = vector::VectorDialect::getDialectNamespace();
   StringRef opDialect = op->getDialect()->getNamespace();
   bool unsupported = (opDialect != vectorDialect) &&
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index eda2594fbc7c7..c3b036b6c9bcf 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Transforms/Patterns.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
@@ -836,8 +837,7 @@ struct TestVectorEmulateMaskedLoadStore final
   }
 };
 
-// TODO: move this code into the user project.
-namespace vendor {
+namespace bit_width_constrained_vector_linearize {
 
 /// Get the set of operand/result types to check for sufficiently
 /// small inner-most dimension size.
@@ -960,7 +960,7 @@ struct TestVectorBitWidthLinearize final
   }
 };
 
-} // namespace vendor
+} // namespace bit_width_constrained_vector_linearize
 
 struct TestVectorLinearize final
     : public PassWrapper<TestVectorLinearize, OperationPass<>> {
@@ -982,12 +982,17 @@ struct TestVectorLinearize final
     RewritePatternSet patterns(&context);
     ConversionTarget target(context);
 
+    SmallVector<Operation *> ops;
+
     vector::populateForVectorLinearize(converter, target);
 
     vector::populateVectorLinearizeBasePatterns(converter, target, patterns);
     vector::populateVectorLinearizeShuffleLikeOpsPatterns(converter, target,
                                                           patterns);
 
+    mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
+        converter, patterns, target);
+
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns))))
       return signalPassFailure();
@@ -1067,7 +1072,8 @@ void registerTestVectorLowerings() {
 
   PassRegistration<TestVectorLinearize>();
 
-  PassRegistration<vendor::TestVectorBitWidthLinearize>();
+  PassRegistration<
+      bit_width_constrained_vector_linearize::TestVectorBitWidthLinearize>();
 
   PassRegistration<TestEliminateVectorMasks>();
 }

>From 9c3d1a1d8093982f8c38502984f7b5c9b3f17a72 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Wed, 30 Apr 2025 19:09:40 -0700
Subject: [PATCH 2/4] further enhancements

---
 .../Vector/Transforms/VectorRewritePatterns.h | 10 +++++---
 .../Vector/Transforms/VectorLinearize.cpp     | 22 ++++++++++--------
 mlir/test/Dialect/Vector/linearize.mlir       | 23 +++++++++++++++++++
 .../Dialect/Vector/TestVectorTransforms.cpp   |  7 +++---
 4 files changed, 46 insertions(+), 16 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index c45ecb3bebc1c..34a94e6ea7051 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -409,16 +409,20 @@ void populateVectorTransposeNarrowTypeRewritePatterns(
 /// Initialize `typeConverter` and `conversionTarget` for vector linearization.
 ///
 /// Definition: here 'linearization' means converting a single operation with
-/// 1+ vector operands and results of rank>1, into a single operation whose
-/// vector operands are all of rank<=1.
+/// one or more vector operands/results of rank>1 into a new single operation
+/// whose vector operands and results are all of rank<=1.
 ///
 /// This function registers (1) which operations are legal, and hence should not
 /// be linearized, (2) what the converted types are (rank-1 vectors) and how to
 /// materialze the conversion (with shape_cast)
 ///
 /// Note: the set of legal operations can be extended by a user if for example
-/// certain rank>1 vectors are considered valid, but adding additional
+/// certain rank>1 vectors are considered valid, by adding additional
 /// dynamically legal ops to `conversionTarget`.
+///
+/// Further note: linearization is designed as a dialect conversion so that
+/// generic structural type conversions can easily be reused to linearize
+/// scf/cf/func operations.
 void populateForVectorLinearize(TypeConverter &typeConverter,
                                 ConversionTarget &conversionTarget);
 
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 8dee2454d53c2..63fe63e4d7c5c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -444,8 +444,9 @@ struct LinearizeVectorSplat final
 
 } // namespace
 
-/// Some operations currently cannot be linearized if they have scalable vector
-/// results. This function returns true if `op` is such an operation.
+/// Some operations currently will not be linearized if they have scalable
+/// vector results, although support should be added in the future. This
+/// function returns true if `op` is such an operation.
 static bool isNotLinearizableBecauseScalable(Operation *op) {
 
   bool unsupported =
@@ -469,15 +470,14 @@ static bool isNotLinearizableBecauseScalable(Operation *op) {
   return containsScalableResult;
 }
 
-/// This method defines a set of operations that are not linearizable,
-/// and hence considered legal for the conversion target. These ops are
-/// currently
+/// This method defines a set of operations that are not linearizable, and
+/// hence are considered legal for the conversion target. Besides
+/// vector.shape_cast, these ops are currently:
 ///
-/// 1) Ops that are not in the vector dialect, are not ConstantLike, and are not
-///    Vectorizable.
+/// 1) ones that are not in the vector dialect, are not ConstantLike, and are
+///    not Vectorizable, or
 ///
-/// 2) Certain ops with scalable vector results, for which support has not yet
-///    been added.
+/// 2) certain ones with scalable vector results (support not yet added).
 static bool isNotLinearizable(Operation *op) {
 
   StringLiteral vectorDialect = vector::VectorDialect::getDialectNamespace();
@@ -488,6 +488,10 @@ static bool isNotLinearizable(Operation *op) {
   if (unsupported)
     return true;
 
+  // vector.shape_cast cannot be linearized.
+  if (isa<vector::ShapeCastOp>(op))
+    return true;
+
   // Some ops currently don't support scalable vectors.
   if (isNotLinearizableBecauseScalable(op))
     return true;
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 01ad1ac48b012..33d648d7163cf 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -320,6 +320,28 @@ func.func @test_vector_bitcast(%arg0: vector<[4]x2xf32>) -> vector<[4]x4xf16> {
   return %1 : vector<[4]x4xf16>
 }
 
+// ----- 
+
+// CHECK-LABEL: test_linearize_across_for
+func.func @test_linearize_across_for(%arg0 : vector<4xi8>) -> vector<4xi8> {
+  %0 = vector.shape_cast %arg0 : vector<4xi8> to vector<2x2xi8>
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+
+  // CHECK:  scf.for {{.*}} -> (vector<4xi8>)
+  %1 = scf.for %i = %c0 to %c4 step %c1 iter_args(%arg1 = %0) -> (vector<2x2xi8>) {
+
+    // CHECK:  arith.addi {{.*}} : vector<4xi8>
+    %2 = arith.addi %arg1, %0 : vector<2x2xi8>
+
+    // CHECK:  scf.yield {{.*}} : vector<4xi8>
+    scf.yield %2 : vector<2x2xi8>
+  }
+  %3 = vector.shape_cast %1 : vector<2x2xi8> to vector<4xi8>
+  return %3 : vector<4xi8>
+}
+
 // -----
 
 // CHECK-LABEL: linearize_vector_splat
@@ -344,4 +366,5 @@ func.func @linearize_scalable_vector_splat(%arg0: i32) -> vector<4x[2]xi32> {
   // CHECK: return %[[CAST]] : vector<4x[2]xi32>
   %0 = vector.splat %arg0 : vector<4x[2]xi32>
   return %0 : vector<4x[2]xi32>
+
 }
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index c3b036b6c9bcf..318c5a70e7919 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -837,7 +837,7 @@ struct TestVectorEmulateMaskedLoadStore final
   }
 };
 
-namespace bit_width_constrained_vector_linearize {
+namespace bit_width_constrained_linearization {
 
 /// Get the set of operand/result types to check for sufficiently
 /// small inner-most dimension size.
@@ -960,7 +960,7 @@ struct TestVectorBitWidthLinearize final
   }
 };
 
-} // namespace bit_width_constrained_vector_linearize
+} // namespace bit_width_constrained_linearization
 
 struct TestVectorLinearize final
     : public PassWrapper<TestVectorLinearize, OperationPass<>> {
@@ -989,7 +989,6 @@ struct TestVectorLinearize final
     vector::populateVectorLinearizeBasePatterns(converter, target, patterns);
     vector::populateVectorLinearizeShuffleLikeOpsPatterns(converter, target,
                                                           patterns);
-
     mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
         converter, patterns, target);
 
@@ -1073,7 +1072,7 @@ void registerTestVectorLowerings() {
   PassRegistration<TestVectorLinearize>();
 
   PassRegistration<
-      bit_width_constrained_vector_linearize::TestVectorBitWidthLinearize>();
+      bit_width_constrained_linearization::TestVectorBitWidthLinearize>();
 
   PassRegistration<TestEliminateVectorMasks>();
 }

>From bbf3d6b391b84a49c5d073a5042d5f4e1bdaf203 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Wed, 30 Apr 2025 19:25:30 -0700
Subject: [PATCH 3/4] remove dead code

---
 mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp | 2 --
 1 file changed, 2 deletions(-)

diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 318c5a70e7919..03f8a04a0ba7a 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -982,8 +982,6 @@ struct TestVectorLinearize final
     RewritePatternSet patterns(&context);
     ConversionTarget target(context);
 
-    SmallVector<Operation *> ops;
-
     vector::populateForVectorLinearize(converter, target);
 
     vector::populateVectorLinearizeBasePatterns(converter, target, patterns);

>From 004672710ca6f7aed6a67de9b1dfe238b1da897b Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Mon, 5 May 2025 17:03:55 -0700
Subject: [PATCH 4/4] whitespace

---
 mlir/test/Dialect/Vector/linearize.mlir | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 33d648d7163cf..97a4acaefbcee 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -320,7 +320,7 @@ func.func @test_vector_bitcast(%arg0: vector<[4]x2xf32>) -> vector<[4]x4xf16> {
   return %1 : vector<[4]x4xf16>
 }
 
-// ----- 
+// -----
 
 // CHECK-LABEL: test_linearize_across_for
 func.func @test_linearize_across_for(%arg0 : vector<4xi8>) -> vector<4xi8> {



More information about the Mlir-commits mailing list