[Mlir-commits] [mlir] [MLIR][Vector] Allow Scalable Dim in OneDimMultiReductionToTwoDim (PR #89978)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Apr 24 12:30:13 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Zhaoshi Zheng (zhaoshiz)
<details>
<summary>Changes</summary>
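This change lets OneDimMultiReductionToTwoDim correctly lower a multi_reduction of a 1-D scalable vector, e.g. vector<[4]xf32>, by preserving the scalable dimension in the casted source and mask vector types.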
---
Full diff: https://github.com/llvm/llvm-project/pull/89978.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp (+6-3)
- (modified) mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir (+18)
``````````diff
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index 2f21c50c63473b..240fedf4ccb651 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -438,7 +438,9 @@ struct OneDimMultiReductionToTwoDim
auto srcVectorType = multiReductionOp.getSourceVectorType();
auto srcShape = srcVectorType.getShape();
auto castedType = VectorType::get(ArrayRef<int64_t>{1, srcShape.back()},
- srcVectorType.getElementType());
+ srcVectorType.getElementType(),
+ ArrayRef<bool>{false, srcVectorType.getScalableDims().back()});
+
auto accType =
VectorType::get(ArrayRef<int64_t>{1}, srcVectorType.getElementType());
assert(!llvm::isa<VectorType>(multiReductionOp.getDestType()) &&
@@ -455,10 +457,11 @@ struct OneDimMultiReductionToTwoDim
loc, accType, multiReductionOp.getAcc());
Value castMask;
if (maskableOp.isMasked()) {
- auto maskType = llvm::cast<ShapedType>(mask.getType());
+ auto maskType = llvm::cast<VectorType>(mask.getType());
auto castMaskType =
VectorType::get(ArrayRef<int64_t>{1, maskType.getShape().back()},
- maskType.getElementType());
+ maskType.getElementType(),
+ ArrayRef<bool>{false, maskType.getScalableDims().back()});
castMask = rewriter.create<vector::BroadcastOp>(loc, castMaskType, mask);
}
diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
index 22808aa7d6acc3..a5fe83118a0cf2 100644
--- a/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
@@ -253,6 +253,7 @@ func.func private @scalable_dims(%A : vector<8x[4]x2xf32>, %B: vector<8x[4]xf32>
%0 = vector.multi_reduction <add>, %A, %B [2] : vector<8x[4]x2xf32> to vector<8x[4]xf32>
return %0 : vector<8x[4]xf32>
}
+
// CHECK-LABEL: func.func private @scalable_dims(
// CHECK-SAME: %[[VAL_0:.*]]: vector<8x[4]x2xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: vector<8x[4]xf32>) -> vector<8x[4]xf32> {
@@ -281,6 +282,23 @@ func.func private @scalable_dims(%A : vector<8x[4]x2xf32>, %B: vector<8x[4]xf32>
// CHECK: %[[VAL_163:.*]] = vector.shape_cast %[[VAL_162]] : vector<[32]xf32> to vector<8x[4]xf32>
// CHECK: return %[[VAL_163]] : vector<8x[4]xf32>
+// Check that OneDimMultiReductionToTwoDim handles scalable dim
+func.func private @scalable_dim_1d(%A: vector<[4]xf32>, %B: f32, %C: vector<[4]xi1>) -> f32 {
+ %0 = vector.mask %C { vector.multi_reduction <add>, %A, %B [0] : vector<[4]xf32> to f32 } : vector<[4]xi1> -> f32
+ return %0 : f32
+}
+
+// CHECK-LABEL: func.func private @scalable_dim_1d(
+// CHECK-SAME: %[[ARG_0:.*]]: vector<[4]xf32>,
+// CHECK-SAME: %[[ARG_1:.*]]: f32,
+// CHECK-SAME: %[[ARG_2:.*]]: vector<[4]xi1>) -> f32 {
+// CHECK-DAG: %[[VAL_0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[VAL_1:.*]] = arith.constant dense<0.000000e+00> : vector<1xf32>
+// CHECK: %[[VAL_2:.*]] = vector.mask %[[ARG_2]] { vector.reduction <add>, %[[ARG_0]], %[[ARG_1]] : vector<[4]xf32> into f32 } : vector<[4]xi1> -> f32
+// CHECK: %[[VAL_3:.*]] = vector.insertelement %[[VAL_2]], %[[VAL_1]][%[[VAL_0]] : index] : vector<1xf32>
+// CHECK: %[[VAL_4:.*]] = vector.extract %[[VAL_3]][0] : f32 from vector<1xf32>
+// CHECK: return %[[VAL_4]] : f32
+
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) {
%func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
``````````
</details>
https://github.com/llvm/llvm-project/pull/89978
More information about the Mlir-commits
mailing list