[Mlir-commits] [mlir] [mlir][vector] Extend mask calculation for vector.contract (PR #65733)

Andrzej Warzyński llvmlistbot at llvm.org
Fri Sep 8 02:46:37 PDT 2023


https://github.com/banach-space created https://github.com/llvm/llvm-project/pull/65733:

Make sure that when calculating the expected mask for `vector.contract`,
scalable sizes are correctly taken into account.

Depends on: #65724

>From 21542ceb35b42350ecd52a97719ab7d5f8ffff2b Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Fri, 8 Sep 2023 08:13:24 +0000
Subject: [PATCH 1/2] [mlir][Vector] Make `vector.contract` work with scalable
 vectors

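This is a small fix that makes `vector.contract` work with scalable
vectors: the expected accumulator/result type computed during
verification now preserves the scalable dimensions of the result type.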

Rather than duplicating all the round-trip tests for `vector.contract`, I
treat scalable vectors as an edge case and only add a couple of tests to
verify that this works.
---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp |  3 ++-
 mlir/test/Dialect/Vector/ops.mlir        | 29 ++++++++++++++++++++++++
 2 files changed, 31 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 2aaf1cb7e5878e4..6473c92a91aa64b 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -820,7 +820,8 @@ static LogicalResult verifyOutputShape(
           return e.cast<AffineConstantExpr>().getValue();
         }));
     auto expected =
-        VectorType::get(expectedShape, resVectorType.getElementType());
+        VectorType::get(expectedShape, resVectorType.getElementType(),
+                        resVectorType.getScalableDims());
     if (resVectorType != expected || accVectorType != expected)
       return op.emitOpError(
                  "invalid accumulator/result vector shape, expected: ")
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 2154304965a5d04..f00bc6e97b350ea 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -307,6 +307,17 @@ func.func @contraction_to_scalar(%arg0: vector<10xf32>, %arg1: vector<10xf32>) -
   return %0 : f32
 }
 
+// CHECK-LABEL: @contraction_to_scalar_scalable
+func.func @contraction_to_scalar_scalable(%arg0: vector<[10]xf32>, %arg1: vector<[10]xf32>) -> f32 {
+  // CHECK:      %[[C0:.*]] = arith.constant 0.000000e+00 : f32
+  %f0 = arith.constant 0.0: f32
+  // CHECK:      %[[X:.*]] = vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["reduction"], kind = #vector.kind<add>} %{{.*}}, %{{.*}}, %[[C0]] : vector<[10]xf32>, vector<[10]xf32> into f32
+  %0 = vector.contract #contraction_to_scalar_trait %arg0, %arg1, %f0
+    : vector<[10]xf32>, vector<[10]xf32> into f32
+  // CHECK:      return %[[X]] : f32
+  return %0 : f32
+}
+
 // CHECK-LABEL: @contraction_extra_attrs
 func.func @contraction_extra_attrs(%arg0: vector<10xf32>, %arg1: vector<10xf32>) -> f32 {
   // CHECK:      %[[C0:.*]] = arith.constant 0.000000e+00 : f32
@@ -392,6 +403,24 @@ func.func @contraction(%arg0 : vector<7x8x16x15xf32>, %arg1 : vector<8x16x7x5xf3
   return
 }
 
+#contraction_matmul_accesses = [
+  affine_map<(d0, d1, d2) -> (d0, d2)>,
+  affine_map<(d0, d1, d2) -> (d2, d1)>,
+  affine_map<(d0, d1, d2) -> (d0, d1)>
+]
+#contraction_matmul_trait = {
+  indexing_maps = #contraction_matmul_accesses,
+  iterator_types = ["parallel", "parallel", "reduction"]
+}
+// CHECK-LABEL: @contraction_matmul_scalable
+func.func @contraction_matmul_scalable(%A: vector<8x1xf32>, %B: vector<1x[32]xf32>, %C: vector<8x[32]xf32>) -> vector<8x[32]xf32> {
+  // CHECK:   %[[X:.*]] = vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} {{.*}}, {{.*}}, {{.*}} : vector<8x1xf32>, vector<1x[32]xf32> into vector<8x[32]xf32>
+  %res = vector.contract #contraction_matmul_trait %A, %B, %C
+    : vector<8x1xf32>, vector<1x[32]xf32> into vector<8x[32]xf32>
+  // CHECK:   return %[[X]] : vector<8x[32]xf32>
+  return %res : vector<8x[32]xf32>
+}
+
 // CHECK-LABEL: @create_vector_mask
 func.func @create_vector_mask() {
   // CHECK:      %[[C2:.*]] = arith.constant 2 : index

>From 77eec46e06e47f49881308ceb468a0c63895b426 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Fri, 8 Sep 2023 09:40:39 +0000
Subject: [PATCH 2/2] [mlir][vector] Extend mask calculation for
 vector.contract

Make sure that when calculating the expected mask for `vector.contract`,
scalable sizes are correctly taken into account.
---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      | 17 +++++++++++------
 mlir/test/Dialect/Vector/ops.mlir             | 19 +++++++++++++++++++
 ...contract-to-parallel-arith-transforms.mlir |  1 +
 3 files changed, 31 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 6473c92a91aa64b..e753562c3fbd3f6 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -912,22 +912,27 @@ Type ContractionOp::getExpectedMaskType() {
 
   unsigned numVecDims = lhsIdxMap.getNumDims();
   SmallVector<int64_t> maskShape(numVecDims, ShapedType::kDynamic);
+  SmallVector<bool> maskShapeScalabledims(numVecDims, false);
 
   // Using the information in the indexing maps, extract the size of each
   // dimension in the vector.contract operation from the two input operands.
-  for (auto [dimIdx, dimSize] : llvm::enumerate(lhsType.getShape()))
+  for (auto [dimIdx, dimSize] : llvm::enumerate(lhsType.getShape())) {
     maskShape[lhsIdxMap.getDimPosition(dimIdx)] = dimSize;
-  for (auto [dimIdx, dimSize] : llvm::enumerate(rhsType.getShape()))
+    maskShapeScalabledims[lhsIdxMap.getDimPosition(dimIdx)] =
+        lhsType.getScalableDims()[dimIdx];
+  }
+  for (auto [dimIdx, dimSize] : llvm::enumerate(rhsType.getShape())) {
     maskShape[rhsIdxMap.getDimPosition(dimIdx)] = dimSize;
+    maskShapeScalabledims[rhsIdxMap.getDimPosition(dimIdx)] =
+        rhsType.getScalableDims()[dimIdx];
+  }
 
   assert(!ShapedType::isDynamicShape(maskShape) &&
          "Mask shape couldn't be computed");
-  // TODO: Extend the scalable vector type representation with a bit map.
-  assert(!lhsType.isScalable() && !rhsType.isScalable() &&
-         "Scalable vectors are not supported yet");
 
   return VectorType::get(maskShape,
-                         IntegerType::get(lhsType.getContext(), /*width=*/1));
+                         IntegerType::get(lhsType.getContext(), /*width=*/1),
+                         maskShapeScalabledims);
 }
 
 SmallVector<StringRef> ContractionOp::getTraitAttrNames() {
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index f00bc6e97b350ea..61118a35922f457 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -979,3 +979,22 @@ func.func @vector_scalable_extract(%sv: vector<[8]xi32>) {
   %2 = vector.scalable.extract %sv[4] : vector<4xi32> from vector<[8]xi32>
   return
  }
+
+#matmat_accesses = [
+  affine_map<(i, j, k) -> (i, k)>,
+  affine_map<(i, j, k) -> (k, j)>,
+  affine_map<(i, j, k) -> (i, j)>
+]
+#matmat_trait = {
+  indexing_maps = #matmat_accesses,
+  iterator_types = ["parallel", "parallel", "reduction"]
+}
+func.func @matmul_masked_scalable(%arg0: vector<3x4xf32>,
+                                    %arg1: vector<4x[8]xf32>,
+                                    %arg2: vector<3x[8]xf32>,
+                                    %m : vector<3x[8]x4xi1>) -> vector<3x[8]xf32> {
+  %0 = vector.mask %m { vector.contract #matmat_trait %arg0, %arg1, %arg2
+  : vector<3x4xf32>, vector<4x[8]xf32> into vector<3x[8]xf32> } : vector<3x[8]x4xi1> -> vector<3x[8]xf32>
+  return %0 : vector<3x[8]xf32>
+}
+
diff --git a/mlir/test/Dialect/Vector/vector-contract-to-parallel-arith-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-parallel-arith-transforms.mlir
index 147f3ae921991f5..b0e48c4e85142cd 100644
--- a/mlir/test/Dialect/Vector/vector-contract-to-parallel-arith-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-to-parallel-arith-transforms.mlir
@@ -60,3 +60,4 @@ transform.sequence failures(propagate) {
     transform.apply_patterns.vector.lower_contraction lowering_strategy = "parallelarith"
   } : !transform.any_op
 }
+



More information about the Mlir-commits mailing list