[Mlir-commits] [mlir] [mlir][vector] Extend mask calculation for vector.contract (PR #65733)
Andrzej Warzyński
llvmlistbot at llvm.org
Mon Sep 11 02:17:25 PDT 2023
https://github.com/banach-space updated https://github.com/llvm/llvm-project/pull/65733:
>From 988b67ec707896080216e07ed29609531f307cce Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Fri, 8 Sep 2023 09:40:39 +0000
Subject: [PATCH 1/2] [mlir][vector] Extend mask calculation for
vector.contract
Make sure that when calculating the expected mask for `vector.contract`,
scalable dimensions (e.g. `[8]` in `vector<4x[8]xf32>`) are correctly
propagated to the mask type.
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 17 +++++++++++------
mlir/test/Dialect/Vector/ops.mlir | 19 +++++++++++++++++++
...contract-to-parallel-arith-transforms.mlir | 1 +
3 files changed, 31 insertions(+), 6 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 6473c92a91aa64b..e753562c3fbd3f6 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -912,22 +912,27 @@ Type ContractionOp::getExpectedMaskType() {
unsigned numVecDims = lhsIdxMap.getNumDims();
SmallVector<int64_t> maskShape(numVecDims, ShapedType::kDynamic);
+ SmallVector<bool> maskShapeScalabledims(numVecDims, false);
// Using the information in the indexing maps, extract the size of each
// dimension in the vector.contract operation from the two input operands.
- for (auto [dimIdx, dimSize] : llvm::enumerate(lhsType.getShape()))
+ for (auto [dimIdx, dimSize] : llvm::enumerate(lhsType.getShape())) {
maskShape[lhsIdxMap.getDimPosition(dimIdx)] = dimSize;
- for (auto [dimIdx, dimSize] : llvm::enumerate(rhsType.getShape()))
+ maskShapeScalabledims[lhsIdxMap.getDimPosition(dimIdx)] =
+ lhsType.getScalableDims()[dimIdx];
+ }
+ for (auto [dimIdx, dimSize] : llvm::enumerate(rhsType.getShape())) {
maskShape[rhsIdxMap.getDimPosition(dimIdx)] = dimSize;
+ maskShapeScalabledims[rhsIdxMap.getDimPosition(dimIdx)] =
+ rhsType.getScalableDims()[dimIdx];
+ }
assert(!ShapedType::isDynamicShape(maskShape) &&
"Mask shape couldn't be computed");
- // TODO: Extend the scalable vector type representation with a bit map.
- assert(!lhsType.isScalable() && !rhsType.isScalable() &&
- "Scalable vectors are not supported yet");
return VectorType::get(maskShape,
- IntegerType::get(lhsType.getContext(), /*width=*/1));
+ IntegerType::get(lhsType.getContext(), /*width=*/1),
+ maskShapeScalabledims);
}
SmallVector<StringRef> ContractionOp::getTraitAttrNames() {
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index f00bc6e97b350ea..61118a35922f457 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -979,3 +979,22 @@ func.func @vector_scalable_extract(%sv: vector<[8]xi32>) {
%2 = vector.scalable.extract %sv[4] : vector<4xi32> from vector<[8]xi32>
return
}
+
+#matmat_accesses = [
+ affine_map<(i, j, k) -> (i, k)>,
+ affine_map<(i, j, k) -> (k, j)>,
+ affine_map<(i, j, k) -> (i, j)>
+]
+#matmat_trait = {
+ indexing_maps = #matmat_accesses,
+ iterator_types = ["parallel", "parallel", "reduction"]
+}
+func.func @matmul_masked_scalable(%arg0: vector<3x4xf32>,
+ %arg1: vector<4x[8]xf32>,
+ %arg2: vector<3x[8]xf32>,
+ %m : vector<3x[8]x4xi1>) -> vector<3x[8]xf32> {
+ %0 = vector.mask %m { vector.contract #matmat_trait %arg0, %arg1, %arg2
+ : vector<3x4xf32>, vector<4x[8]xf32> into vector<3x[8]xf32> } : vector<3x[8]x4xi1> -> vector<3x[8]xf32>
+ return %0 : vector<3x[8]xf32>
+}
+
diff --git a/mlir/test/Dialect/Vector/vector-contract-to-parallel-arith-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-parallel-arith-transforms.mlir
index 147f3ae921991f5..b0e48c4e85142cd 100644
--- a/mlir/test/Dialect/Vector/vector-contract-to-parallel-arith-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-to-parallel-arith-transforms.mlir
@@ -60,3 +60,4 @@ transform.sequence failures(propagate) {
transform.apply_patterns.vector.lower_contraction lowering_strategy = "parallelarith"
} : !transform.any_op
}
+
>From 77b8395f88777b7f6f49ebd61f9faf4f6d5e040c Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Mon, 11 Sep 2023 09:08:43 +0000
Subject: [PATCH 2/2] fixup! [mlir][vector] Extend mask calculation for
vector.contract
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 8 ++++----
mlir/test/Dialect/Vector/ops.mlir | 19 ++++++++++++-------
...contract-to-parallel-arith-transforms.mlir | 1 -
3 files changed, 16 insertions(+), 12 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index e753562c3fbd3f6..1222542ee39fd6a 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -912,18 +912,18 @@ Type ContractionOp::getExpectedMaskType() {
unsigned numVecDims = lhsIdxMap.getNumDims();
SmallVector<int64_t> maskShape(numVecDims, ShapedType::kDynamic);
- SmallVector<bool> maskShapeScalabledims(numVecDims, false);
+ SmallVector<bool> maskShapeScalableDims(numVecDims, false);
// Using the information in the indexing maps, extract the size of each
// dimension in the vector.contract operation from the two input operands.
for (auto [dimIdx, dimSize] : llvm::enumerate(lhsType.getShape())) {
maskShape[lhsIdxMap.getDimPosition(dimIdx)] = dimSize;
- maskShapeScalabledims[lhsIdxMap.getDimPosition(dimIdx)] =
+ maskShapeScalableDims[lhsIdxMap.getDimPosition(dimIdx)] =
lhsType.getScalableDims()[dimIdx];
}
for (auto [dimIdx, dimSize] : llvm::enumerate(rhsType.getShape())) {
maskShape[rhsIdxMap.getDimPosition(dimIdx)] = dimSize;
- maskShapeScalabledims[rhsIdxMap.getDimPosition(dimIdx)] =
+ maskShapeScalableDims[rhsIdxMap.getDimPosition(dimIdx)] =
rhsType.getScalableDims()[dimIdx];
}
@@ -932,7 +932,7 @@ Type ContractionOp::getExpectedMaskType() {
return VectorType::get(maskShape,
IntegerType::get(lhsType.getContext(), /*width=*/1),
- maskShapeScalabledims);
+ maskShapeScalableDims);
}
SmallVector<StringRef> ContractionOp::getTraitAttrNames() {
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 61118a35922f457..d41cee5ea67b0c5 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -989,12 +989,17 @@ func.func @vector_scalable_extract(%sv: vector<[8]xi32>) {
indexing_maps = #matmat_accesses,
iterator_types = ["parallel", "parallel", "reduction"]
}
-func.func @matmul_masked_scalable(%arg0: vector<3x4xf32>,
- %arg1: vector<4x[8]xf32>,
- %arg2: vector<3x[8]xf32>,
- %m : vector<3x[8]x4xi1>) -> vector<3x[8]xf32> {
- %0 = vector.mask %m { vector.contract #matmat_trait %arg0, %arg1, %arg2
- : vector<3x4xf32>, vector<4x[8]xf32> into vector<3x[8]xf32> } : vector<3x[8]x4xi1> -> vector<3x[8]xf32>
+// CHECK-LABEL: func.func @contraction_masked_scalable(
+// CHECK-SAME: %[[A:.*]]: vector<3x4xf32>,
+// CHECK-SAME: %[[B:.*]]: vector<4x[8]xf32>,
+// CHECK-SAME: %[[C:.*]]: vector<3x[8]xf32>,
+// CHECK-SAME: %[[M:.*]]: vector<3x[8]x4xi1>) -> vector<3x[8]xf32> {
+func.func @contraction_masked_scalable(%A: vector<3x4xf32>,
+ %B: vector<4x[8]xf32>,
+ %C: vector<3x[8]xf32>,
+ %M : vector<3x[8]x4xi1>) -> vector<3x[8]xf32> {
+ // CHECK: vector.mask %[[M]] { vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[A]], %[[B]], %[[C]] : vector<3x4xf32>, vector<4x[8]xf32> into vector<3x[8]xf32> } : vector<3x[8]x4xi1> -> vector<3x[8]xf32>
+ %0 = vector.mask %M { vector.contract #matmat_trait %A, %B, %C : vector<3x4xf32>, vector<4x[8]xf32> into vector<3x[8]xf32> }
+ : vector<3x[8]x4xi1> -> vector<3x[8]xf32>
return %0 : vector<3x[8]xf32>
}
-
diff --git a/mlir/test/Dialect/Vector/vector-contract-to-parallel-arith-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-parallel-arith-transforms.mlir
index b0e48c4e85142cd..147f3ae921991f5 100644
--- a/mlir/test/Dialect/Vector/vector-contract-to-parallel-arith-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-to-parallel-arith-transforms.mlir
@@ -60,4 +60,3 @@ transform.sequence failures(propagate) {
transform.apply_patterns.vector.lower_contraction lowering_strategy = "parallelarith"
} : !transform.any_op
}
-
More information about the Mlir-commits
mailing list