[Mlir-commits] [mlir] 3d6d5df - [mlir][vector] Address linearization comments (post commit) (#138075)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu May 15 07:52:57 PDT 2025


Author: James Newling
Date: 2025-05-15T07:52:53-07:00
New Revision: 3d6d5dfed2b303e9fba74586993df3fa85058991

URL: https://github.com/llvm/llvm-project/commit/3d6d5dfed2b303e9fba74586993df3fa85058991
DIFF: https://github.com/llvm/llvm-project/commit/3d6d5dfed2b303e9fba74586993df3fa85058991.diff

LOG: [mlir][vector] Address linearization comments (post commit) (#138075)

This PR adds some documentation to address post-commit review comments on
https://github.com/llvm/llvm-project/pull/136581.

This PR also adds a test for linearization across scf.for. More
experienced MLIR users might consider this test redundant, but it should
help newer users understand how to easily linearize scf/cf/func
operations.

The documentation added in this PR also tightens our definition of
linearization so that it now excludes unrolling (which creates multiple
ops from a single op). Previously, we had not specified what
linearization meant.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
    mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
    mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
    mlir/test/Dialect/Vector/linearize.mlir
    mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index f1100d5cf8b68..34a94e6ea7051 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -407,13 +407,22 @@ void populateVectorTransposeNarrowTypeRewritePatterns(
     RewritePatternSet &patterns, PatternBenefit benefit = 1);
 
 /// Initialize `typeConverter` and `conversionTarget` for vector linearization.
-/// This registers (1) which operations are legal and hence should not be
-/// linearized, (2) what converted types are (rank-1 vectors) and how to
+///
+/// Definition: here 'linearization' means converting a single operation with
+/// 1+ vector operand/result of rank>1, into a new single operation whose
+/// vector operands and results are all of rank<=1.
+///
+/// This function registers (1) which operations are legal, and hence should not
+/// be linearized, (2) what the converted types are (rank-1 vectors) and how to
 /// materialze the conversion (with shape_cast)
 ///
 /// Note: the set of legal operations can be extended by a user if for example
-/// certain rank>1 vectors are considered valid, but adding additional
+/// certain rank>1 vectors are considered valid, by adding additional
 /// dynamically legal ops to `conversionTarget`.
+///
+/// Further note: the choice to use a dialect conversion design for
+/// linearization is to make it easy to reuse generic structural type
+/// conversions for linearizing scf/cf/func operations
 void populateForVectorLinearize(TypeConverter &typeConverter,
                                 ConversionTarget &conversionTarget);
 

diff  --git a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
index 40d2e254fb7dd..09326242eec2a 100644
--- a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
@@ -99,7 +99,7 @@ class ConvertForOpTypes
     // PR47938 tracks this issue, but it seems hard to fix. Instead, we need
     // to clone the op.
     //
-    // 2. We need to resue the original region instead of cloning it, otherwise
+    // 2. We need to reuse the original region instead of cloning it, otherwise
     // the dialect conversion framework thinks that we just inserted all the
     // cloned child ops. But what we want is to "take" the child regions and let
     // the dialect conversion framework continue recursively into ops inside

diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 90e0479a515d5..060ce7d1d6643 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -626,45 +626,49 @@ struct LinearizeVectorCreateMask final
 
 } // namespace
 
-/// Return true if the operation `op` does not support scalable vectors and
-/// has at least 1 scalable vector result. These ops should all eventually
-/// support scalable vectors, and this function should be removed.
-static bool isNotLinearizableBecauseScalable(Operation *op) {
-
-  bool unsupported =
-      isa<vector::ExtractStridedSliceOp, vector::InsertStridedSliceOp,
-          vector::ExtractOp, vector::InsertOp>(op);
-  if (!unsupported)
-    return false;
-
-  // Check if any of the results is a scalable vector type.
-  auto types = op->getResultTypes();
-  bool containsScalableResult =
-      std::any_of(types.begin(), types.end(), [](Type type) {
-        auto vecType = dyn_cast<VectorType>(type);
-        return vecType && vecType.isScalable();
-      });
-
-  return containsScalableResult;
-}
-
-static bool isNotLinearizable(Operation *op) {
+/// This method defines the set of operations that are linearizable, and hence
+/// that are considered illegal for the conversion target.
+static bool isLinearizable(Operation *op) {
 
   // Only ops that are in the vector dialect, are ConstantLike, or
   // are Vectorizable might be linearized currently.
   StringLiteral vectorDialect = vector::VectorDialect::getDialectNamespace();
   StringRef opDialect = op->getDialect()->getNamespace();
-  bool unsupported = (opDialect != vectorDialect) &&
-                     !op->hasTrait<OpTrait::ConstantLike>() &&
-                     !op->hasTrait<OpTrait::Vectorizable>();
-  if (unsupported)
-    return true;
-
-  // Some ops currently don't support scalable vectors.
-  if (isNotLinearizableBecauseScalable(op))
-    return true;
+  bool supported = (opDialect == vectorDialect) ||
+                   op->hasTrait<OpTrait::ConstantLike>() ||
+                   op->hasTrait<OpTrait::Vectorizable>();
+  if (!supported)
+    return false;
 
-  return false;
+  return TypeSwitch<Operation *, bool>(op)
+      // As type legalization is done with vector.shape_cast, shape_cast
+      // itself cannot be linearized (will create new shape_casts to linearize
+      // ad infinitum).
+      .Case<vector::ShapeCastOp>([&](auto) { return false; })
+      // The operations
+      // - vector.extract_strided_slice
+      // - vector.extract
+      // - vector.insert_strided_slice
+      // - vector.insert
+      // are linearized to a rank-1 vector.shuffle by the current patterns.
+      // vector.shuffle only supports fixed size vectors, so it is impossible to
+      // use this approach to linearize these ops if they operate on scalable
+      // vectors.
+      .Case<vector::ExtractStridedSliceOp>(
+          [&](vector::ExtractStridedSliceOp extractOp) {
+            return !extractOp.getType().isScalable();
+          })
+      .Case<vector::InsertStridedSliceOp>(
+          [&](vector::InsertStridedSliceOp insertOp) {
+            return !insertOp.getType().isScalable();
+          })
+      .Case<vector::InsertOp>([&](vector::InsertOp insertOp) {
+        return !insertOp.getType().isScalable();
+      })
+      .Case<vector::ExtractOp>([&](vector::ExtractOp extractOp) {
+        return !extractOp.getSourceVectorType().isScalable();
+      })
+      .Default([&](auto) { return true; });
 }
 
 void mlir::vector::populateForVectorLinearize(TypeConverter &typeConverter,
@@ -698,7 +702,7 @@ void mlir::vector::populateForVectorLinearize(TypeConverter &typeConverter,
 
   target.markUnknownOpDynamicallyLegal(
       [=](Operation *op) -> std::optional<bool> {
-        if (isNotLinearizable(op))
+        if (!isLinearizable(op))
           return true;
         // This will return true if, for all operand and result types `t`,
         // convertType(t) = t. This is true if there are no rank>=2 vectors.

diff  --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 40445d3781228..9cbf319ffddb2 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -392,6 +392,28 @@ func.func @test_vector_bitcast(%arg0: vector<[4]x2xf32>) -> vector<[4]x4xf16> {
 
 // -----
 
+// CHECK-LABEL: test_linearize_across_for
+func.func @test_linearize_across_for(%arg0 : vector<4xi8>) -> vector<4xi8> {
+  %0 = vector.shape_cast %arg0 : vector<4xi8> to vector<2x2xi8>
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+
+  // CHECK:  scf.for {{.*}} -> (vector<4xi8>)
+  %1 = scf.for %i = %c0 to %c4 step %c1 iter_args(%arg1 = %0) -> (vector<2x2xi8>) {
+
+    // CHECK:  arith.addi {{.*}} : vector<4xi8>
+    %2 = arith.addi %arg1, %0 : vector<2x2xi8>
+
+    // CHECK:  scf.yield {{.*}} : vector<4xi8>
+    scf.yield %2 : vector<2x2xi8>
+  }
+  %3 = vector.shape_cast %1 : vector<2x2xi8> to vector<4xi8>
+  return %3 : vector<4xi8>
+}
+
+// -----
+
 // CHECK-LABEL: linearize_vector_splat
 // CHECK-SAME: (%[[ARG:.*]]: i32) -> vector<4x2xi32>
 func.func @linearize_vector_splat(%arg0: i32) -> vector<4x2xi32> {
@@ -414,6 +436,7 @@ func.func @linearize_scalable_vector_splat(%arg0: i32) -> vector<4x[2]xi32> {
   // CHECK: return %[[CAST]] : vector<4x[2]xi32>
   %0 = vector.splat %arg0 : vector<4x[2]xi32>
   return %0 : vector<4x[2]xi32>
+
 }
 
 // -----

diff  --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 54defd949c264..ccba2e2806862 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Transforms/Patterns.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
@@ -836,9 +837,6 @@ struct TestVectorEmulateMaskedLoadStore final
   }
 };
 
-// TODO: move this code into the user project.
-namespace vendor {
-
 /// Get the set of operand/result types to check for sufficiently
 /// small inner-most dimension size.
 static SmallVector<std::pair<Type, unsigned>>
@@ -960,8 +958,6 @@ struct TestVectorBitWidthLinearize final
   }
 };
 
-} // namespace vendor
-
 struct TestVectorLinearize final
     : public PassWrapper<TestVectorLinearize, OperationPass<>> {
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorLinearize)
@@ -987,6 +983,8 @@ struct TestVectorLinearize final
     vector::populateVectorLinearizeBasePatterns(converter, target, patterns);
     vector::populateVectorLinearizeShuffleLikeOpsPatterns(converter, target,
                                                           patterns);
+    mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
+        converter, patterns, target);
 
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns))))
@@ -1067,7 +1065,7 @@ void registerTestVectorLowerings() {
 
   PassRegistration<TestVectorLinearize>();
 
-  PassRegistration<vendor::TestVectorBitWidthLinearize>();
+  PassRegistration<TestVectorBitWidthLinearize>();
 
   PassRegistration<TestEliminateVectorMasks>();
 }


        


More information about the Mlir-commits mailing list