[Mlir-commits] [mlir] d3aa92e - [mlir][vector] Add support for scalable vectors to VectorLinearize (#86786)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Mar 28 07:53:26 PDT 2024


Author: Andrzej Warzyński
Date: 2024-03-28T14:53:21Z
New Revision: d3aa92ed142409266ebcc9cbc20e5f2c2d0209c0

URL: https://github.com/llvm/llvm-project/commit/d3aa92ed142409266ebcc9cbc20e5f2c2d0209c0
DIFF: https://github.com/llvm/llvm-project/commit/d3aa92ed142409266ebcc9cbc20e5f2c2d0209c0.diff

LOG: [mlir][vector] Add support for scalable vectors to VectorLinearize (#86786)

Adds support for scalable vectors to patterns defined in
VectorLinearize.cpp.

Linearization is disabled in two notable cases:
  * vectors with more than one scalable dimension (we cannot represent
    vscale^2),
  * vectors initialised with an arith.constant that is not a vector splat
    (such arith.constant Ops cannot be flattened).

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
    mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
    mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
    mlir/test/Dialect/Vector/linearize.mlir
    mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
index 2c548fb6740251..f88fbdf9e62765 100644
--- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
@@ -170,6 +170,16 @@ struct MaskableOpRewritePattern : OpRewritePattern<SourceOp> {
                             PatternRewriter &rewriter) const = 0;
 };
 
+/// Returns true if the input Vector type can be linearized.
+///
+/// Linearization is meant in the sense of flattening vectors, e.g.:
+///   * vector<NxMxKxi32> -> vector<N*M*Kxi32>
+/// In this sense, vectors that are either:
+///   * already linearized (rank 0 or 1), or
+///   * contain more than one scalable dimension,
+/// are not linearizable.
+bool isLinearizableVector(VectorType type);
+
 } // namespace vector
 
 /// Constructs a permutation map of invariant memref indices to vector

diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 38536de43f13f2..4fa5b8a4865b4f 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -49,6 +49,12 @@ struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
     Location loc = constOp.getLoc();
     auto resType =
         getTypeConverter()->convertType<VectorType>(constOp.getType());
     if (!resType)
       return rewriter.notifyMatchFailure(loc, "can't convert return type");
+
+    // Must follow the null check above: isScalable() needs a valid resType.
+    if (resType.isScalable() && !isa<SplatElementsAttr>(constOp.getValue()))
+      return rewriter.notifyMatchFailure(
+          loc,
+          "Cannot linearize a constant scalable vector that's not a splat");
     if (!isLessThanTargetBitWidth(constOp, targetVectorBitWidth))
@@ -104,11 +110,11 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
     ConversionTarget &target, unsigned targetBitWidth) {
 
   typeConverter.addConversion([](VectorType type) -> std::optional<Type> {
-    // Ignore scalable vectors for now.
-    if (type.getRank() <= 1 || type.isScalable())
+    if (!isLinearizableVector(type))
       return type;
 
-    return VectorType::get(type.getNumElements(), type.getElementType());
+    return VectorType::get(type.getNumElements(), type.getElementType(),
+                           type.isScalable());
   });
 
   auto materializeCast = [](OpBuilder &builder, Type type, ValueRange inputs,

diff  --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index 63ed0947cf6ce2..ebc6f5cbcaa9ed 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -317,3 +317,8 @@ SmallVector<OpFoldResult> vector::getMixedSizesXfer(bool hasTensorSemantics,
                          : memref::getMixedSizes(rewriter, loc, base);
   return mixedSourceDims;
 }
+
+bool vector::isLinearizableVector(VectorType type) {
+  // Rank 0/1 is already flat; >1 scalable dim would need vscale^N elements.
+  return type.getRank() > 1 && llvm::count(type.getScalableDims(), true) <= 1;
+}

diff  --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 1b225c7a97d233..f0e9b3a05c066e 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -1,5 +1,5 @@
-// RUN: mlir-opt %s -split-input-file -test-vector-linearize | FileCheck %s --check-prefixes=ALL,DEFAULT
-// RUN: mlir-opt %s -split-input-file -test-vector-linearize=target-vector-bitwidth=128 | FileCheck %s --check-prefixes=ALL,BW-128
+// RUN: mlir-opt %s -split-input-file -test-vector-linearize -verify-diagnostics | FileCheck %s --check-prefixes=ALL,DEFAULT
+// RUN: mlir-opt %s -split-input-file -test-vector-linearize=target-vector-bitwidth=128  -verify-diagnostics | FileCheck %s --check-prefixes=ALL,BW-128
 // RUN: mlir-opt %s -split-input-file -test-vector-linearize=target-vector-bitwidth=0 | FileCheck %s --check-prefixes=ALL,BW-0
 
 // ALL-LABEL: test_linearize
@@ -97,3 +97,60 @@ func.func @test_tensor_no_linearize(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf3
 
     return %0, %arg0 : tensor<2x2xf32>, tensor<2x2xf32>
 }
+
+// -----
+
+// ALL-LABEL:   func.func @test_scalable_linearize(
+// ALL-SAME:    %[[ARG_0:.*]]: vector<2x[2]xf32>) -> vector<2x[2]xf32> {
+func.func @test_scalable_linearize(%arg0: vector<2x[2]xf32>) -> vector<2x[2]xf32> {
+  // DEFAULT:  %[[SC:.*]] = vector.shape_cast %[[ARG_0]] : vector<2x[2]xf32> to vector<[4]xf32>
+  // DEFAULT:  %[[CST:.*]] = arith.constant dense<3.000000e+00> : vector<[4]xf32>
+  // BW-128:  %[[SC:.*]] = vector.shape_cast %[[ARG_0]] : vector<2x[2]xf32> to vector<[4]xf32>
+  // BW-128:  %[[CST:.*]] = arith.constant dense<3.000000e+00> : vector<[4]xf32>
+  // BW-0:  %[[CST:.*]] = arith.constant dense<3.000000e+00> : vector<2x[2]xf32>
+  %0 = arith.constant dense<[[3., 3.], [3., 3.]]> : vector<2x[2]xf32>
+
+  // DEFAULT: %[[SIN:.*]] = math.sin %[[SC]] : vector<[4]xf32>
+  // BW-128: %[[SIN:.*]] = math.sin %[[SC]] : vector<[4]xf32>
+  // BW-0: %[[SIN:.*]] = math.sin %[[ARG_0]] : vector<2x[2]xf32>
+  %1 = math.sin %arg0 : vector<2x[2]xf32>
+
+  // DEFAULT: %[[ADDF:.*]] = arith.addf %[[SIN]], %[[CST]] : vector<[4]xf32>
+  // BW-128: %[[ADDF:.*]] = arith.addf %[[SIN]], %[[CST]] : vector<[4]xf32>
+  // BW-0: %[[RES:.*]] = arith.addf %[[CST]], %[[SIN]] : vector<2x[2]xf32>
+  %2 = arith.addf %0, %1 : vector<2x[2]xf32>
+
+  // DEFAULT: %[[RES:.*]] = vector.shape_cast %[[ADDF]] : vector<[4]xf32> to vector<2x[2]xf32>
+  // BW-128: %[[RES:.*]] = vector.shape_cast %[[ADDF]] : vector<[4]xf32> to vector<2x[2]xf32>
+  // ALL: return %[[RES]] : vector<2x[2]xf32>
+  return %2 : vector<2x[2]xf32>
+}
+
+// -----
+
+// ALL-LABEL:   func.func @test_scalable_no_linearize(
+// ALL-SAME:     %[[VAL_0:.*]]: vector<[2]x[2]xf32>) -> vector<[2]x[2]xf32> {
+func.func @test_scalable_no_linearize(%arg0: vector<[2]x[2]xf32>) -> vector<[2]x[2]xf32> {
+  // ALL: %[[CST:.*]] = arith.constant dense<2.000000e+00> : vector<[2]x[2]xf32>
+  %0 = arith.constant dense<[[2., 2.], [2., 2.]]> : vector<[2]x[2]xf32>
+
+  // ALL: %[[SIN:.*]] = math.sin %[[VAL_0]] : vector<[2]x[2]xf32>
+  %1 = math.sin %arg0 : vector<[2]x[2]xf32>
+
+  // ALL: %[[RES:.*]] = arith.addf %[[CST]], %[[SIN]] : vector<[2]x[2]xf32>
+  %2 = arith.addf %0, %1 : vector<[2]x[2]xf32>
+
+  // ALL: return %[[RES]] : vector<[2]x[2]xf32>
+  return %2 : vector<[2]x[2]xf32>
+}
+
+// -----
+
+func.func @test_scalable_no_linearize_non_splat(%arg0: vector<2x[2]xf32>) -> vector<2x[2]xf32> {
+  // expected-error at +1 {{failed to legalize operation 'arith.constant' that was explicitly marked illegal}}
+  %0 = arith.constant dense<[[1., 1.], [3., 3.]]> : vector<2x[2]xf32>
+  %1 = math.sin %arg0 : vector<2x[2]xf32>
+  %2 = arith.addf %0, %1 : vector<2x[2]xf32>
+
+  return %2 : vector<2x[2]xf32>
+}

diff  --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index f14fb18706d1b7..00622599910567 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -489,7 +489,9 @@ struct TestFlattenVectorTransferPatterns
   Option<unsigned> targetVectorBitwidth{
       *this, "target-vector-bitwidth",
       llvm::cl::desc(
-          "Minimum vector bitwidth to enable the flattening transformation"),
+          "Minimum vector bitwidth to enable the flattening transformation. "
+          "For scalable vectors this is the base size, i.e. the size "
+          "corresponding to vscale=1."),
       llvm::cl::init(std::numeric_limits<unsigned>::max())};
 
   void runOnOperation() override {


        


More information about the Mlir-commits mailing list