[Mlir-commits] [mlir] [mlir][vector] Extend mask calculation for vector.contract (PR #65733)

Andrzej Warzyński llvmlistbot at llvm.org
Mon Sep 11 02:17:25 PDT 2023


https://github.com/banach-space updated https://github.com/llvm/llvm-project/pull/65733:

>From 988b67ec707896080216e07ed29609531f307cce Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Fri, 8 Sep 2023 09:40:39 +0000
Subject: [PATCH 1/2] [mlir][vector] Extend mask calculation for
 vector.contract

Make sure that when calculating the expected mask for `vector.contract`,
scalable sizes are correctly taken into account.
---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      | 17 +++++++++++------
 mlir/test/Dialect/Vector/ops.mlir             | 19 +++++++++++++++++++
 ...contract-to-parallel-arith-transforms.mlir |  1 +
 3 files changed, 31 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 6473c92a91aa64b..e753562c3fbd3f6 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -912,22 +912,27 @@ Type ContractionOp::getExpectedMaskType() {
 
   unsigned numVecDims = lhsIdxMap.getNumDims();
   SmallVector<int64_t> maskShape(numVecDims, ShapedType::kDynamic);
+  SmallVector<bool> maskShapeScalabledims(numVecDims, false);
 
   // Using the information in the indexing maps, extract the size of each
   // dimension in the vector.contract operation from the two input operands.
-  for (auto [dimIdx, dimSize] : llvm::enumerate(lhsType.getShape()))
+  for (auto [dimIdx, dimSize] : llvm::enumerate(lhsType.getShape())) {
     maskShape[lhsIdxMap.getDimPosition(dimIdx)] = dimSize;
-  for (auto [dimIdx, dimSize] : llvm::enumerate(rhsType.getShape()))
+    maskShapeScalabledims[lhsIdxMap.getDimPosition(dimIdx)] =
+        lhsType.getScalableDims()[dimIdx];
+  }
+  for (auto [dimIdx, dimSize] : llvm::enumerate(rhsType.getShape())) {
     maskShape[rhsIdxMap.getDimPosition(dimIdx)] = dimSize;
+    maskShapeScalabledims[rhsIdxMap.getDimPosition(dimIdx)] =
+        rhsType.getScalableDims()[dimIdx];
+  }
 
   assert(!ShapedType::isDynamicShape(maskShape) &&
          "Mask shape couldn't be computed");
-  // TODO: Extend the scalable vector type representation with a bit map.
-  assert(!lhsType.isScalable() && !rhsType.isScalable() &&
-         "Scalable vectors are not supported yet");
 
   return VectorType::get(maskShape,
-                         IntegerType::get(lhsType.getContext(), /*width=*/1));
+                         IntegerType::get(lhsType.getContext(), /*width=*/1),
+                         maskShapeScalabledims);
 }
 
 SmallVector<StringRef> ContractionOp::getTraitAttrNames() {
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index f00bc6e97b350ea..61118a35922f457 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -979,3 +979,22 @@ func.func @vector_scalable_extract(%sv: vector<[8]xi32>) {
   %2 = vector.scalable.extract %sv[4] : vector<4xi32> from vector<[8]xi32>
   return
  }
+
+#matmat_accesses = [
+  affine_map<(i, j, k) -> (i, k)>,
+  affine_map<(i, j, k) -> (k, j)>,
+  affine_map<(i, j, k) -> (i, j)>
+]
+#matmat_trait = {
+  indexing_maps = #matmat_accesses,
+  iterator_types = ["parallel", "parallel", "reduction"]
+}
+func.func @matmul_masked_scalable(%arg0: vector<3x4xf32>,
+                                    %arg1: vector<4x[8]xf32>,
+                                    %arg2: vector<3x[8]xf32>,
+                                    %m : vector<3x[8]x4xi1>) -> vector<3x[8]xf32> {
+  %0 = vector.mask %m { vector.contract #matmat_trait %arg0, %arg1, %arg2
+  : vector<3x4xf32>, vector<4x[8]xf32> into vector<3x[8]xf32> } : vector<3x[8]x4xi1> -> vector<3x[8]xf32>
+  return %0 : vector<3x[8]xf32>
+}
+
diff --git a/mlir/test/Dialect/Vector/vector-contract-to-parallel-arith-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-parallel-arith-transforms.mlir
index 147f3ae921991f5..b0e48c4e85142cd 100644
--- a/mlir/test/Dialect/Vector/vector-contract-to-parallel-arith-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-to-parallel-arith-transforms.mlir
@@ -60,3 +60,4 @@ transform.sequence failures(propagate) {
     transform.apply_patterns.vector.lower_contraction lowering_strategy = "parallelarith"
   } : !transform.any_op
 }
+

>From 77b8395f88777b7f6f49ebd61f9faf4f6d5e040c Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Mon, 11 Sep 2023 09:08:43 +0000
Subject: [PATCH 2/2] fixup! [mlir][vector] Extend mask calculation for
 vector.contract

---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      |  8 ++++----
 mlir/test/Dialect/Vector/ops.mlir             | 19 ++++++++++++-------
 ...contract-to-parallel-arith-transforms.mlir |  1 -
 3 files changed, 16 insertions(+), 12 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index e753562c3fbd3f6..1222542ee39fd6a 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -912,18 +912,18 @@ Type ContractionOp::getExpectedMaskType() {
 
   unsigned numVecDims = lhsIdxMap.getNumDims();
   SmallVector<int64_t> maskShape(numVecDims, ShapedType::kDynamic);
-  SmallVector<bool> maskShapeScalabledims(numVecDims, false);
+  SmallVector<bool> maskShapeScalableDims(numVecDims, false);
 
   // Using the information in the indexing maps, extract the size of each
   // dimension in the vector.contract operation from the two input operands.
   for (auto [dimIdx, dimSize] : llvm::enumerate(lhsType.getShape())) {
     maskShape[lhsIdxMap.getDimPosition(dimIdx)] = dimSize;
-    maskShapeScalabledims[lhsIdxMap.getDimPosition(dimIdx)] =
+    maskShapeScalableDims[lhsIdxMap.getDimPosition(dimIdx)] =
         lhsType.getScalableDims()[dimIdx];
   }
   for (auto [dimIdx, dimSize] : llvm::enumerate(rhsType.getShape())) {
     maskShape[rhsIdxMap.getDimPosition(dimIdx)] = dimSize;
-    maskShapeScalabledims[rhsIdxMap.getDimPosition(dimIdx)] =
+    maskShapeScalableDims[rhsIdxMap.getDimPosition(dimIdx)] =
         rhsType.getScalableDims()[dimIdx];
   }
 
@@ -932,7 +932,7 @@ Type ContractionOp::getExpectedMaskType() {
 
   return VectorType::get(maskShape,
                          IntegerType::get(lhsType.getContext(), /*width=*/1),
-                         maskShapeScalabledims);
+                         maskShapeScalableDims);
 }
 
 SmallVector<StringRef> ContractionOp::getTraitAttrNames() {
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 61118a35922f457..d41cee5ea67b0c5 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -989,12 +989,17 @@ func.func @vector_scalable_extract(%sv: vector<[8]xi32>) {
   indexing_maps = #matmat_accesses,
   iterator_types = ["parallel", "parallel", "reduction"]
 }
-func.func @matmul_masked_scalable(%arg0: vector<3x4xf32>,
-                                    %arg1: vector<4x[8]xf32>,
-                                    %arg2: vector<3x[8]xf32>,
-                                    %m : vector<3x[8]x4xi1>) -> vector<3x[8]xf32> {
-  %0 = vector.mask %m { vector.contract #matmat_trait %arg0, %arg1, %arg2
-  : vector<3x4xf32>, vector<4x[8]xf32> into vector<3x[8]xf32> } : vector<3x[8]x4xi1> -> vector<3x[8]xf32>
+// CHECK-LABEL:   func.func @contraction_masked_scalable(
+// CHECK-SAME:    %[[A:.*]]: vector<3x4xf32>,
+// CHECK-SAME:    %[[B:.*]]: vector<4x[8]xf32>,
+// CHECK-SAME:    %[[C:.*]]: vector<3x[8]xf32>,
+// CHECK-SAME:    %[[M:.*]]: vector<3x[8]x4xi1>) -> vector<3x[8]xf32> {
+func.func @contraction_masked_scalable(%A: vector<3x4xf32>,
+                                    %B: vector<4x[8]xf32>,
+                                    %C: vector<3x[8]xf32>,
+                                    %M : vector<3x[8]x4xi1>) -> vector<3x[8]xf32> {
+ // CHECK:  vector.mask %[[M]] { vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[A]], %[[B]], %[[C]] : vector<3x4xf32>, vector<4x[8]xf32> into vector<3x[8]xf32> } : vector<3x[8]x4xi1> -> vector<3x[8]xf32>
+  %0 = vector.mask %M { vector.contract #matmat_trait %A, %B, %C : vector<3x4xf32>, vector<4x[8]xf32> into vector<3x[8]xf32> } 
+    : vector<3x[8]x4xi1> -> vector<3x[8]xf32>
   return %0 : vector<3x[8]xf32>
 }
-
diff --git a/mlir/test/Dialect/Vector/vector-contract-to-parallel-arith-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-parallel-arith-transforms.mlir
index b0e48c4e85142cd..147f3ae921991f5 100644
--- a/mlir/test/Dialect/Vector/vector-contract-to-parallel-arith-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-to-parallel-arith-transforms.mlir
@@ -60,4 +60,3 @@ transform.sequence failures(propagate) {
     transform.apply_patterns.vector.lower_contraction lowering_strategy = "parallelarith"
   } : !transform.any_op
 }
-



More information about the Mlir-commits mailing list