[Mlir-commits] [mlir] [mlir][vector] Add support for scalable vectors to VectorLinearize (PR #86786)
Andrzej Warzyński
llvmlistbot at llvm.org
Thu Mar 28 04:39:36 PDT 2024
https://github.com/banach-space updated https://github.com/llvm/llvm-project/pull/86786
>From 457722c5ffd2da7b2ad457c52b9dab5a64bc49ce Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Tue, 26 Mar 2024 09:19:44 +0000
Subject: [PATCH 1/3] [mlir][vector] Add support for scalable vectors to
VectorLinearize
Adds support for scalable vectors to patterns defined in
VectorLinearize.cpp.
Linearization is disabled in two notable cases:
* vectors with more than one scalable dimension (we cannot represent
vscale^2),
* vectors initialised with an arith.constant that's not a vector splat
(such arith.constant ops cannot be flattened).
---
.../mlir/Dialect/Vector/Utils/VectorUtils.h | 10 +++++
.../Vector/Transforms/VectorLinearize.cpp | 11 +++--
mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp | 5 +++
mlir/test/Dialect/Vector/linearize.mlir | 44 +++++++++++++++++++
.../Dialect/Vector/TestVectorTransforms.cpp | 4 +-
5 files changed, 70 insertions(+), 4 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
index 2c548fb6740251..f88fbdf9e62765 100644
--- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
@@ -170,6 +170,16 @@ struct MaskableOpRewritePattern : OpRewritePattern<SourceOp> {
PatternRewriter &rewriter) const = 0;
};
+/// Returns true if the input Vector type can be linearized.
+///
+/// Linearization is meant in the sense of flattening vectors, e.g.:
+/// * vector<NxMxKxi32> -> vector<N*M*Kxi32>
+/// In this sense, vectors that either:
+/// * are already linearized, or
+/// * contain more than one scalable dimension,
+/// are not linearizable.
+bool isLinearizableVector(VectorType type);
+
} // namespace vector
/// Constructs a permutation map of invariant memref indices to vector
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 38536de43f13f2..c8043fbb7c3061 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -49,6 +49,11 @@ struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
Location loc = constOp.getLoc();
auto resType =
getTypeConverter()->convertType<VectorType>(constOp.getType());
+
+ if (resType.isScalable() && !isa<SplatElementsAttr>(constOp.getValue()))
+ return rewriter.notifyMatchFailure(
+ loc, "Cannot linearize a constant scalable vector that's not a splt");
+
if (!resType)
return rewriter.notifyMatchFailure(loc, "can't convert return type");
if (!isLessThanTargetBitWidth(constOp, targetVectorBitWidth))
@@ -104,11 +109,11 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
ConversionTarget &target, unsigned targetBitWidth) {
typeConverter.addConversion([](VectorType type) -> std::optional<Type> {
- // Ignore scalable vectors for now.
- if (type.getRank() <= 1 || type.isScalable())
+ if (!isLinearizableVector(type))
return type;
- return VectorType::get(type.getNumElements(), type.getElementType());
+ return VectorType::get(type.getNumElements(), type.getElementType(),
+ type.isScalable());
});
auto materializeCast = [](OpBuilder &builder, Type type, ValueRange inputs,
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index 63ed0947cf6ce2..a4415a80139af1 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -317,3 +317,8 @@ SmallVector<OpFoldResult> vector::getMixedSizesXfer(bool hasTensorSemantics,
: memref::getMixedSizes(rewriter, loc, base);
return mixedSourceDims;
}
+
+bool vector::isLinearizableVector(VectorType type) {
+ auto numScalableDims = llvm::count(type.getScalableDims(), true);
+ return ((type.getRank() > 1) && (numScalableDims <= 1));
+}
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 1b225c7a97d233..3ab68f19aa0c60 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -97,3 +97,47 @@ func.func @test_tensor_no_linearize(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf3
return %0, %arg0 : tensor<2x2xf32>, tensor<2x2xf32>
}
+
+// -----
+
+// ALL-LABEL: func.func @test_1_scalable_dim(
+// ALL-SAME: %[[ARG_0:.*]]: vector<2x[4]xf32>) -> vector<2x[4]xf32> {
+func.func @test_1_scalable_dim(%arg0: vector<2x[4]xf32>) -> vector<2x[4]xf32> {
+ // DEFAULT: %[[SC:.*]] = vector.shape_cast %[[ARG_0]] : vector<2x[4]xf32> to vector<[8]xf32>
+ // DEFAULT: %[[CST:.*]] = arith.constant dense<3.000000e+00> : vector<[8]xf32>
+ // BW-128: %[[CST:.*]] = arith.constant dense<3.000000e+00> : vector<2x[4]xf32>
+ // BW-0: %[[CST:.*]] = arith.constant dense<3.000000e+00> : vector<2x[4]xf32>
+ %0 = arith.constant dense<[[3., 3., 3., 3.], [3., 3., 3., 3.]]> : vector<2x[4]xf32>
+
+ // DEFAULT: %[[SIN:.*]] = math.sin %[[SC]] : vector<[8]xf32>
+ // BW-128: %[[SIN:.*]] = math.sin %[[ARG_0]] : vector<2x[4]xf32>
+ // BW-0: %[[SIN:.*]] = math.sin %[[ARG_0]] : vector<2x[4]xf32>
+ %1 = math.sin %arg0 : vector<2x[4]xf32>
+
+ // DEFAULT: %[[ADDF:.*]] = arith.addf %[[SIN]], %[[CST]] : vector<[8]xf32>
+ // BW-128: %[[RES:.*]] = arith.addf %[[CST]], %[[SIN]] : vector<2x[4]xf32>
+ // BW-0: %[[RES:.*]] = arith.addf %[[CST]], %[[SIN]] : vector<2x[4]xf32>
+ %2 = arith.addf %0, %1 : vector<2x[4]xf32>
+
+ // DEFAULT: %[[RES:.*]] = vector.shape_cast %[[ADDF]] : vector<[8]xf32> to vector<2x[4]xf32>
+ // ALL: return %[[RES]] : vector<2x[4]xf32>
+ return %2 : vector<2x[4]xf32>
+}
+
+// -----
+
+// ALL-LABEL: func.func @test_2_scalable_dims(
+// ALL-SAME: %[[VAL_0:.*]]: vector<[2]x[2]xf32>) -> vector<[2]x[2]xf32> {
+func.func @test_2_scalable_dims(%arg0: vector<[2]x[2]xf32>) -> vector<[2]x[2]xf32> {
+ // ALL: %[[CST:.*]] = arith.constant dense<2.000000e+00> : vector<[2]x[2]xf32>
+ %0 = arith.constant dense<[[2., 2.], [2., 2.]]> : vector<[2]x[2]xf32>
+
+ // ALL: %[[SIN:.*]] = math.sin %[[VAL_0]] : vector<[2]x[2]xf32>
+ %1 = math.sin %arg0 : vector<[2]x[2]xf32>
+
+ // ALL: %[[RES:.*]] = arith.addf %[[CST]], %[[SIN]] : vector<[2]x[2]xf32>
+ %2 = arith.addf %0, %1 : vector<[2]x[2]xf32>
+
+ // ALL: return %[[RES]] : vector<[2]x[2]xf32>
+ return %2 : vector<[2]x[2]xf32>
+}
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index f14fb18706d1b7..766ddae47c53b9 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -489,7 +489,9 @@ struct TestFlattenVectorTransferPatterns
Option<unsigned> targetVectorBitwidth{
*this, "target-vector-bitwidth",
llvm::cl::desc(
- "Minimum vector bitwidth to enable the flattening transformation"),
+ "Minimum vector bitwidth to enable the flattening transformation. "
+ "For scalable vectors this is the base size that's known at compile "
+ "time."),
llvm::cl::init(std::numeric_limits<unsigned>::max())};
void runOnOperation() override {
>From ef110ea47176e765373c435e627fcf5c7a34d5a1 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Wed, 27 Mar 2024 19:35:33 +0000
Subject: [PATCH 2/3] fixup! [mlir][vector] Add support for scalable vectors to
VectorLinearize
Address PR comments
---
.../Vector/Transforms/VectorLinearize.cpp | 3 ++-
mlir/test/Dialect/Vector/linearize.mlir | 23 ++++++++++++++-----
.../Dialect/Vector/TestVectorTransforms.cpp | 4 ++--
3 files changed, 21 insertions(+), 9 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index c8043fbb7c3061..4fa5b8a4865b4f 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -52,7 +52,8 @@ struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
if (resType.isScalable() && !isa<SplatElementsAttr>(constOp.getValue()))
return rewriter.notifyMatchFailure(
- loc, "Cannot linearize a constant scalable vector that's not a splt");
+ loc,
+ "Cannot linearize a constant scalable vector that's not a splat");
if (!resType)
return rewriter.notifyMatchFailure(loc, "can't convert return type");
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 3ab68f19aa0c60..e9288d5eac0411 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -1,5 +1,5 @@
-// RUN: mlir-opt %s -split-input-file -test-vector-linearize | FileCheck %s --check-prefixes=ALL,DEFAULT
-// RUN: mlir-opt %s -split-input-file -test-vector-linearize=target-vector-bitwidth=128 | FileCheck %s --check-prefixes=ALL,BW-128
+// RUN: mlir-opt %s -split-input-file -test-vector-linearize -verify-diagnostics | FileCheck %s --check-prefixes=ALL,DEFAULT
+// RUN: mlir-opt %s -split-input-file -test-vector-linearize=target-vector-bitwidth=128 -verify-diagnostics | FileCheck %s --check-prefixes=ALL,BW-128
// RUN: mlir-opt %s -split-input-file -test-vector-linearize=target-vector-bitwidth=0 | FileCheck %s --check-prefixes=ALL,BW-0
// ALL-LABEL: test_linearize
@@ -100,9 +100,9 @@ func.func @test_tensor_no_linearize(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf3
// -----
-// ALL-LABEL: func.func @test_1_scalable_dim(
+// ALL-LABEL: func.func @test_scalable_linearize(
// ALL-SAME: %[[ARG_0:.*]]: vector<2x[4]xf32>) -> vector<2x[4]xf32> {
-func.func @test_1_scalable_dim(%arg0: vector<2x[4]xf32>) -> vector<2x[4]xf32> {
+func.func @test_scalable_linearize(%arg0: vector<2x[4]xf32>) -> vector<2x[4]xf32> {
// DEFAULT: %[[SC:.*]] = vector.shape_cast %[[ARG_0]] : vector<2x[4]xf32> to vector<[8]xf32>
// DEFAULT: %[[CST:.*]] = arith.constant dense<3.000000e+00> : vector<[8]xf32>
// BW-128: %[[CST:.*]] = arith.constant dense<3.000000e+00> : vector<2x[4]xf32>
@@ -126,9 +126,9 @@ func.func @test_1_scalable_dim(%arg0: vector<2x[4]xf32>) -> vector<2x[4]xf32> {
// -----
-// ALL-LABEL: func.func @test_2_scalable_dims(
+// ALL-LABEL: func.func @test_scalable_no_linearize(
// ALL-SAME: %[[VAL_0:.*]]: vector<[2]x[2]xf32>) -> vector<[2]x[2]xf32> {
-func.func @test_2_scalable_dims(%arg0: vector<[2]x[2]xf32>) -> vector<[2]x[2]xf32> {
+func.func @test_scalable_no_linearize(%arg0: vector<[2]x[2]xf32>) -> vector<[2]x[2]xf32> {
// ALL: %[[CST:.*]] = arith.constant dense<2.000000e+00> : vector<[2]x[2]xf32>
%0 = arith.constant dense<[[2., 2.], [2., 2.]]> : vector<[2]x[2]xf32>
@@ -141,3 +141,14 @@ func.func @test_2_scalable_dims(%arg0: vector<[2]x[2]xf32>) -> vector<[2]x[2]xf3
// ALL: return %[[RES]] : vector<[2]x[2]xf32>
return %2 : vector<[2]x[2]xf32>
}
+
+// -----
+
+func.func @test_scalable_no_lineariz(%arg0: vector<2x[2]xf32>) -> vector<2x[2]xf32> {
+ // expected-error at +1 {{failed to legalize operation 'arith.constant' that was explicitly marked illegal}}
+ %0 = arith.constant dense<[[1., 1.], [3., 3.]]> : vector<2x[2]xf32>
+ %1 = math.sin %arg0 : vector<2x[2]xf32>
+ %2 = arith.addf %0, %1 : vector<2x[2]xf32>
+
+ return %2 : vector<2x[2]xf32>
+}
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 766ddae47c53b9..00622599910567 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -490,8 +490,8 @@ struct TestFlattenVectorTransferPatterns
*this, "target-vector-bitwidth",
llvm::cl::desc(
"Minimum vector bitwidth to enable the flattening transformation. "
- "For scalable vectors this is the base size that's known at compile "
- "time."),
+ "For scalable vectors this is the base size, i.e. the size "
+ "corresponding to vscale=1."),
llvm::cl::init(std::numeric_limits<unsigned>::max())};
void runOnOperation() override {
>From 1d259028ac244bd77b3d4d926945591ff4fad356 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Thu, 28 Mar 2024 11:39:02 +0000
Subject: [PATCH 3/3] fixup! fixup! [mlir][vector] Add support for scalable
vectors to VectorLinearize
Addressing PR comments
---
mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp | 2 +-
mlir/test/Dialect/Vector/linearize.mlir | 46 ++++++++++---------
2 files changed, 25 insertions(+), 23 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index a4415a80139af1..ebc6f5cbcaa9ed 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -320,5 +320,5 @@ SmallVector<OpFoldResult> vector::getMixedSizesXfer(bool hasTensorSemantics,
bool vector::isLinearizableVector(VectorType type) {
auto numScalableDims = llvm::count(type.getScalableDims(), true);
- return ((type.getRank() > 1) && (numScalableDims <= 1));
+ return (type.getRank() > 1) && (numScalableDims <= 1);
}
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index e9288d5eac0411..f0e9b3a05c066e 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -101,27 +101,29 @@ func.func @test_tensor_no_linearize(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf3
// -----
// ALL-LABEL: func.func @test_scalable_linearize(
-// ALL-SAME: %[[ARG_0:.*]]: vector<2x[4]xf32>) -> vector<2x[4]xf32> {
-func.func @test_scalable_linearize(%arg0: vector<2x[4]xf32>) -> vector<2x[4]xf32> {
- // DEFAULT: %[[SC:.*]] = vector.shape_cast %[[ARG_0]] : vector<2x[4]xf32> to vector<[8]xf32>
- // DEFAULT: %[[CST:.*]] = arith.constant dense<3.000000e+00> : vector<[8]xf32>
- // BW-128: %[[CST:.*]] = arith.constant dense<3.000000e+00> : vector<2x[4]xf32>
- // BW-0: %[[CST:.*]] = arith.constant dense<3.000000e+00> : vector<2x[4]xf32>
- %0 = arith.constant dense<[[3., 3., 3., 3.], [3., 3., 3., 3.]]> : vector<2x[4]xf32>
-
- // DEFAULT: %[[SIN:.*]] = math.sin %[[SC]] : vector<[8]xf32>
- // BW-128: %[[SIN:.*]] = math.sin %[[ARG_0]] : vector<2x[4]xf32>
- // BW-0: %[[SIN:.*]] = math.sin %[[ARG_0]] : vector<2x[4]xf32>
- %1 = math.sin %arg0 : vector<2x[4]xf32>
-
- // DEFAULT: %[[ADDF:.*]] = arith.addf %[[SIN]], %[[CST]] : vector<[8]xf32>
- // BW-128: %[[RES:.*]] = arith.addf %[[CST]], %[[SIN]] : vector<2x[4]xf32>
- // BW-0: %[[RES:.*]] = arith.addf %[[CST]], %[[SIN]] : vector<2x[4]xf32>
- %2 = arith.addf %0, %1 : vector<2x[4]xf32>
-
- // DEFAULT: %[[RES:.*]] = vector.shape_cast %[[ADDF]] : vector<[8]xf32> to vector<2x[4]xf32>
- // ALL: return %[[RES]] : vector<2x[4]xf32>
- return %2 : vector<2x[4]xf32>
+// ALL-SAME: %[[ARG_0:.*]]: vector<2x[2]xf32>) -> vector<2x[2]xf32> {
+func.func @test_scalable_linearize(%arg0: vector<2x[2]xf32>) -> vector<2x[2]xf32> {
+ // DEFAULT: %[[SC:.*]] = vector.shape_cast %[[ARG_0]] : vector<2x[2]xf32> to vector<[4]xf32>
+ // DEFAULT: %[[CST:.*]] = arith.constant dense<3.000000e+00> : vector<[4]xf32>
+ // BW-128: %[[SC:.*]] = vector.shape_cast %[[ARG_0]] : vector<2x[2]xf32> to vector<[4]xf32>
+ // BW-128: %[[CST:.*]] = arith.constant dense<3.000000e+00> : vector<[4]xf32>
+ // BW-0: %[[CST:.*]] = arith.constant dense<3.000000e+00> : vector<2x[2]xf32>
+ %0 = arith.constant dense<[[3., 3.], [3., 3.]]> : vector<2x[2]xf32>
+
+ // DEFAULT: %[[SIN:.*]] = math.sin %[[SC]] : vector<[4]xf32>
+ // BW-128: %[[SIN:.*]] = math.sin %[[SC]] : vector<[4]xf32>
+ // BW-0: %[[SIN:.*]] = math.sin %[[ARG_0]] : vector<2x[2]xf32>
+ %1 = math.sin %arg0 : vector<2x[2]xf32>
+
+ // DEFAULT: %[[ADDF:.*]] = arith.addf %[[SIN]], %[[CST]] : vector<[4]xf32>
+ // BW-128: %[[ADDF:.*]] = arith.addf %[[SIN]], %[[CST]] : vector<[4]xf32>
+ // BW-0: %[[RES:.*]] = arith.addf %[[CST]], %[[SIN]] : vector<2x[2]xf32>
+ %2 = arith.addf %0, %1 : vector<2x[2]xf32>
+
+ // DEFAULT: %[[RES:.*]] = vector.shape_cast %[[ADDF]] : vector<[4]xf32> to vector<2x[2]xf32>
+ // BW-128: %[[RES:.*]] = vector.shape_cast %[[ADDF]] : vector<[4]xf32> to vector<2x[2]xf32>
+ // ALL: return %[[RES]] : vector<2x[2]xf32>
+ return %2 : vector<2x[2]xf32>
}
// -----
@@ -144,7 +146,7 @@ func.func @test_scalable_no_linearize(%arg0: vector<[2]x[2]xf32>) -> vector<[2]x
// -----
-func.func @test_scalable_no_lineariz(%arg0: vector<2x[2]xf32>) -> vector<2x[2]xf32> {
+func.func @test_scalable_no_linearize(%arg0: vector<2x[2]xf32>) -> vector<2x[2]xf32> {
// expected-error at +1 {{failed to legalize operation 'arith.constant' that was explicitly marked illegal}}
%0 = arith.constant dense<[[1., 1.], [3., 3.]]> : vector<2x[2]xf32>
%1 = math.sin %arg0 : vector<2x[2]xf32>
More information about the Mlir-commits
mailing list