[Mlir-commits] [mlir] [MLIR][VectorToLLVM] Handle scalable dim in createVectorLengthValue() (PR #93361)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri May 24 18:14:56 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Zhaoshi Zheng (zhaoshiz)
<details>
<summary>Changes</summary>
LLVM's Vector Predication Intrinsics require an explicit vector length parameter: https://llvm.org/docs/LangRef.html#vector-predication-intrinsics.
For a scalable vector type, this should be calculated as the result of VectorScaleOp multiplied by the base vector length, e.g., for <[4]xf32> we should return vscale * 4.
---
Full diff: https://github.com/llvm/llvm-project/pull/93361.diff
2 Files Affected:
- (modified) mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp (+12-2)
- (modified) mlir/test/Conversion/VectorToLLVM/vector-reduction-to-llvm.mlir (+38)
``````````diff
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index fe6bcc1c8b667..18bd9660525b4 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -523,7 +523,7 @@ static Value getOrCreateAccumulator(ConversionPatternRewriter &rewriter,
llvmType);
}
-/// Creates a constant value with the 1-D vector shape provided in `llvmType`.
+/// Creates a value with the 1-D vector shape provided in `llvmType`.
/// This is used as effective vector length by some intrinsics supporting
/// dynamic vector lengths at runtime.
static Value createVectorLengthValue(ConversionPatternRewriter &rewriter,
@@ -532,9 +532,19 @@ static Value createVectorLengthValue(ConversionPatternRewriter &rewriter,
auto vShape = vType.getShape();
assert(vShape.size() == 1 && "Unexpected multi-dim vector type");
- return rewriter.create<LLVM::ConstantOp>(
+ Value vLen = rewriter.create<LLVM::ConstantOp>(
loc, rewriter.getI32Type(),
rewriter.getIntegerAttr(rewriter.getI32Type(), vShape[0]));
+
+ if (!vType.getScalableDims()[0])
+ return vLen;
+
+ // Create VScale*vShape[0] and return it as vector length.
+ Value vScale = rewriter.create<vector::VectorScaleOp>(loc);
+ vScale = rewriter.create<arith::IndexCastOp>(
+ loc, rewriter.getI32Type(), vScale);
+ vLen = rewriter.create<arith::MulIOp>(loc, vLen, vScale);
+ return vLen;
}
/// Helper method to lower a `vector.reduction` op that performs an arithmetic
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-reduction-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-reduction-to-llvm.mlir
index f98a05f8d17e2..209afa217437b 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-reduction-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-reduction-to-llvm.mlir
@@ -79,6 +79,25 @@ func.func @masked_reduce_add_f32(%arg0: vector<16xf32>, %mask : vector<16xi1>) -
// CHECK: "llvm.intr.vp.reduce.fadd"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL]]) : (f32, vector<16xf32>, vector<16xi1>, i32) -> f32
+// -----
+
+func.func @masked_reduce_add_f32_scalable(%arg0: vector<[4]xf32>, %mask : vector<[4]xi1>) -> f32 {
+ %0 = vector.mask %mask { vector.reduction <add>, %arg0 : vector<[4]xf32> into f32 } : vector<[4]xi1> -> f32
+ return %0 : f32
+}
+
+// CHECK-LABEL: func.func @masked_reduce_add_f32_scalable(
+// CHECK-SAME: %[[INPUT:.*]]: vector<[4]xf32>,
+// CHECK-SAME: %[[MASK:.*]]: vector<[4]xi1>) -> f32 {
+// CHECK: %[[NEUTRAL:.*]] = llvm.mlir.constant(0.000000e+00 : f32) : f32
+// CHECK: %[[VL_BASE:.*]] = llvm.mlir.constant(4 : i32) : i32
+// CHECK: %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64
+// CHECK: %[[CAST_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index
+// CHECK: %[[CAST_I32:.*]] = arith.index_cast %[[CAST_IDX]] : index to i32
+// CHECK: %[[VL_MUL:.*]] = arith.muli %[[VL_BASE]], %[[CAST_I32]] : i32
+// CHECK: "llvm.intr.vp.reduce.fadd"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL_MUL]]) : (f32, vector<[4]xf32>, vector<[4]xi1>, i32) -> f32
+
+
// -----
func.func @masked_reduce_mul_f32(%arg0: vector<16xf32>, %mask : vector<16xi1>) -> f32 {
@@ -167,6 +186,25 @@ func.func @masked_reduce_add_i8(%arg0: vector<32xi8>, %mask : vector<32xi1>) ->
// CHECK: "llvm.intr.vp.reduce.add"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL]]) : (i8, vector<32xi8>, vector<32xi1>, i32) -> i8
+// -----
+
+func.func @masked_reduce_add_i8_scalable(%arg0: vector<[16]xi8>, %mask : vector<[16]xi1>) -> i8 {
+ %0 = vector.mask %mask { vector.reduction <add>, %arg0 : vector<[16]xi8> into i8 } : vector<[16]xi1> -> i8
+ return %0 : i8
+}
+
+// CHECK-LABEL: func.func @masked_reduce_add_i8_scalable(
+// CHECK-SAME: %[[INPUT:.*]]: vector<[16]xi8>,
+// CHECK-SAME: %[[MASK:.*]]: vector<[16]xi1>) -> i8 {
+// CHECK: %[[NEUTRAL:.*]] = llvm.mlir.constant(0 : i8) : i8
+// CHECK: %[[VL_BASE:.*]] = llvm.mlir.constant(16 : i32) : i32
+// CHECK: %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64
+// CHECK: %[[CAST_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index
+// CHECK: %[[CAST_I32:.*]] = arith.index_cast %[[CAST_IDX]] : index to i32
+// CHECK: %[[VL_MUL:.*]] = arith.muli %[[VL_BASE]], %[[CAST_I32]] : i32
+// CHECK: "llvm.intr.vp.reduce.add"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL_MUL]]) : (i8, vector<[16]xi8>, vector<[16]xi1>, i32) -> i8
+
+
// -----
func.func @masked_reduce_mul_i8(%arg0: vector<32xi8>, %mask : vector<32xi1>) -> i8 {
``````````
</details>
https://github.com/llvm/llvm-project/pull/93361
More information about the Mlir-commits
mailing list