[Mlir-commits] [mlir] [mlir][vector] Extend mask calculation for vector.contract (PR #65733)
Andrzej Warzyński
llvmlistbot at llvm.org
Fri Sep 8 02:46:37 PDT 2023
https://github.com/banach-space created https://github.com/llvm/llvm-project/pull/65733:
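Make sure that scalable dimensions are correctly taken into account when
calculating the expected mask type for `vector.contract`: each mask
dimension now inherits the scalable flag of the operand dimension it is
derived from, instead of hitting an assertion for scalable operands.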
Depends on: #65724
From 21542ceb35b42350ecd52a97719ab7d5f8ffff2b Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Fri, 8 Sep 2023 08:13:24 +0000
Subject: [PATCH 1/2] [mlir][Vector] Make `vector.contract` work with scalable
vectors
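This is a small fix that makes `vector.contract` work with scalable
vectors. When verifying the output shape, the expected accumulator/result
type was built without scalable-dimension flags, so a valid scalable
result such as `vector<8x[32]xf32>` was rejected. The expected type now
carries the scalable dims of the result type.
Rather than duplicating all the round-trip tests for `vector.contract`, I
treat scalable vectors as an edge case and add just a couple of tests to
verify that this works.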
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 3 ++-
mlir/test/Dialect/Vector/ops.mlir | 29 ++++++++++++++++++++++++
2 files changed, 31 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 2aaf1cb7e5878e4..6473c92a91aa64b 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -820,7 +820,8 @@ static LogicalResult verifyOutputShape(
return e.cast<AffineConstantExpr>().getValue();
}));
auto expected =
- VectorType::get(expectedShape, resVectorType.getElementType());
+ VectorType::get(expectedShape, resVectorType.getElementType(),
+ resVectorType.getScalableDims());
if (resVectorType != expected || accVectorType != expected)
return op.emitOpError(
"invalid accumulator/result vector shape, expected: ")
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 2154304965a5d04..f00bc6e97b350ea 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -307,6 +307,17 @@ func.func @contraction_to_scalar(%arg0: vector<10xf32>, %arg1: vector<10xf32>) -
return %0 : f32
}
+// CHECK-LABEL: @contraction_to_scalar_scalable
+func.func @contraction_to_scalar_scalable(%arg0: vector<[10]xf32>, %arg1: vector<[10]xf32>) -> f32 {
+ // CHECK: %[[C0:.*]] = arith.constant 0.000000e+00 : f32
+ %f0 = arith.constant 0.0: f32
+ // CHECK: %[[X:.*]] = vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["reduction"], kind = #vector.kind<add>} %{{.*}}, %{{.*}}, %[[C0]] : vector<[10]xf32>, vector<[10]xf32> into f32
+ %0 = vector.contract #contraction_to_scalar_trait %arg0, %arg1, %f0
+ : vector<[10]xf32>, vector<[10]xf32> into f32
+ // CHECK: return %[[X]] : f32
+ return %0 : f32
+}
+
// CHECK-LABEL: @contraction_extra_attrs
func.func @contraction_extra_attrs(%arg0: vector<10xf32>, %arg1: vector<10xf32>) -> f32 {
// CHECK: %[[C0:.*]] = arith.constant 0.000000e+00 : f32
@@ -392,6 +403,24 @@ func.func @contraction(%arg0 : vector<7x8x16x15xf32>, %arg1 : vector<8x16x7x5xf3
return
}
+#contraction_matmul_accesses = [
+ affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d2, d1)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+]
+#contraction_matmul_trait = {
+ indexing_maps = #contraction_matmul_accesses,
+ iterator_types = ["parallel", "parallel", "reduction"]
+}
+// CHECK-LABEL: @contraction_matmul_scalable
+func.func @contraction_matmul_scalable(%A: vector<8x1xf32>, %B: vector<1x[32]xf32>, %C: vector<8x[32]xf32>) -> vector<8x[32]xf32> {
+ // CHECK: %[[X:.*]] = vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} {{.*}}, {{.*}}, {{.*}} : vector<8x1xf32>, vector<1x[32]xf32> into vector<8x[32]xf32>
+ %res = vector.contract #contraction_matmul_trait %A, %B, %C
+ : vector<8x1xf32>, vector<1x[32]xf32> into vector<8x[32]xf32>
+ // CHECK: return %[[X]] : vector<8x[32]xf32>
+ return %res : vector<8x[32]xf32>
+}
+
// CHECK-LABEL: @create_vector_mask
func.func @create_vector_mask() {
// CHECK: %[[C2:.*]] = arith.constant 2 : index
From 77eec46e06e47f49881308ceb468a0c63895b426 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Fri, 8 Sep 2023 09:40:39 +0000
Subject: [PATCH 2/2] [mlir][vector] Extend mask calculation for
vector.contract
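Make sure that scalable dimensions are correctly taken into account when
calculating the expected mask type for `vector.contract`
(`ContractionOp::getExpectedMaskType`). Each mask dimension now inherits
the scalable flag of the LHS/RHS operand dimension that maps to it, which
replaces the previous assertion that neither operand is scalable.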
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 17 +++++++++++------
mlir/test/Dialect/Vector/ops.mlir | 19 +++++++++++++++++++
...contract-to-parallel-arith-transforms.mlir | 1 +
3 files changed, 31 insertions(+), 6 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 6473c92a91aa64b..e753562c3fbd3f6 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -912,22 +912,27 @@ Type ContractionOp::getExpectedMaskType() {
unsigned numVecDims = lhsIdxMap.getNumDims();
SmallVector<int64_t> maskShape(numVecDims, ShapedType::kDynamic);
+ SmallVector<bool> maskShapeScalabledims(numVecDims, false);
// Using the information in the indexing maps, extract the size of each
// dimension in the vector.contract operation from the two input operands.
- for (auto [dimIdx, dimSize] : llvm::enumerate(lhsType.getShape()))
+ for (auto [dimIdx, dimSize] : llvm::enumerate(lhsType.getShape())) {
maskShape[lhsIdxMap.getDimPosition(dimIdx)] = dimSize;
- for (auto [dimIdx, dimSize] : llvm::enumerate(rhsType.getShape()))
+ maskShapeScalabledims[lhsIdxMap.getDimPosition(dimIdx)] =
+ lhsType.getScalableDims()[dimIdx];
+ }
+ for (auto [dimIdx, dimSize] : llvm::enumerate(rhsType.getShape())) {
maskShape[rhsIdxMap.getDimPosition(dimIdx)] = dimSize;
+ maskShapeScalabledims[rhsIdxMap.getDimPosition(dimIdx)] =
+ rhsType.getScalableDims()[dimIdx];
+ }
assert(!ShapedType::isDynamicShape(maskShape) &&
"Mask shape couldn't be computed");
- // TODO: Extend the scalable vector type representation with a bit map.
- assert(!lhsType.isScalable() && !rhsType.isScalable() &&
- "Scalable vectors are not supported yet");
return VectorType::get(maskShape,
- IntegerType::get(lhsType.getContext(), /*width=*/1));
+ IntegerType::get(lhsType.getContext(), /*width=*/1),
+ maskShapeScalabledims);
}
SmallVector<StringRef> ContractionOp::getTraitAttrNames() {
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index f00bc6e97b350ea..61118a35922f457 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -979,3 +979,22 @@ func.func @vector_scalable_extract(%sv: vector<[8]xi32>) {
%2 = vector.scalable.extract %sv[4] : vector<4xi32> from vector<[8]xi32>
return
}
+
+#matmat_accesses = [
+ affine_map<(i, j, k) -> (i, k)>,
+ affine_map<(i, j, k) -> (k, j)>,
+ affine_map<(i, j, k) -> (i, j)>
+]
+#matmat_trait = {
+ indexing_maps = #matmat_accesses,
+ iterator_types = ["parallel", "parallel", "reduction"]
+}
+func.func @matmul_masked_scalable(%arg0: vector<3x4xf32>,
+ %arg1: vector<4x[8]xf32>,
+ %arg2: vector<3x[8]xf32>,
+ %m : vector<3x[8]x4xi1>) -> vector<3x[8]xf32> {
+ %0 = vector.mask %m { vector.contract #matmat_trait %arg0, %arg1, %arg2
+ : vector<3x4xf32>, vector<4x[8]xf32> into vector<3x[8]xf32> } : vector<3x[8]x4xi1> -> vector<3x[8]xf32>
+ return %0 : vector<3x[8]xf32>
+}
+
diff --git a/mlir/test/Dialect/Vector/vector-contract-to-parallel-arith-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-parallel-arith-transforms.mlir
index 147f3ae921991f5..b0e48c4e85142cd 100644
--- a/mlir/test/Dialect/Vector/vector-contract-to-parallel-arith-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-to-parallel-arith-transforms.mlir
@@ -60,3 +60,4 @@ transform.sequence failures(propagate) {
transform.apply_patterns.vector.lower_contraction lowering_strategy = "parallelarith"
} : !transform.any_op
}
+
More information about the Mlir-commits
mailing list