[Mlir-commits] [mlir] 718af88 - [mlir][vector] Extend mask calculation for vector.contract (#65733)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Sep 11 03:34:51 PDT 2023


Author: Andrzej Warzyński
Date: 2023-09-11T11:34:47+01:00
New Revision: 718af8837639b6a47ae9bf911f668437f0ce0e3c

URL: https://github.com/llvm/llvm-project/commit/718af8837639b6a47ae9bf911f668437f0ce0e3c
DIFF: https://github.com/llvm/llvm-project/commit/718af8837639b6a47ae9bf911f668437f0ce0e3c.diff

LOG: [mlir][vector] Extend mask calculation for vector.contract (#65733)

Make sure that when calculating the expected mask for `vector.contract`,
scalable sizes are correctly taken into account.

Depends on: #65724

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/test/Dialect/Vector/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 6473c92a91aa64b..1222542ee39fd6a 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -912,22 +912,27 @@ Type ContractionOp::getExpectedMaskType() {
 
   unsigned numVecDims = lhsIdxMap.getNumDims();
   SmallVector<int64_t> maskShape(numVecDims, ShapedType::kDynamic);
+  SmallVector<bool> maskShapeScalableDims(numVecDims, false);
 
   // Using the information in the indexing maps, extract the size of each
   // dimension in the vector.contract operation from the two input operands.
-  for (auto [dimIdx, dimSize] : llvm::enumerate(lhsType.getShape()))
+  for (auto [dimIdx, dimSize] : llvm::enumerate(lhsType.getShape())) {
     maskShape[lhsIdxMap.getDimPosition(dimIdx)] = dimSize;
-  for (auto [dimIdx, dimSize] : llvm::enumerate(rhsType.getShape()))
+    maskShapeScalableDims[lhsIdxMap.getDimPosition(dimIdx)] =
+        lhsType.getScalableDims()[dimIdx];
+  }
+  for (auto [dimIdx, dimSize] : llvm::enumerate(rhsType.getShape())) {
     maskShape[rhsIdxMap.getDimPosition(dimIdx)] = dimSize;
+    maskShapeScalableDims[rhsIdxMap.getDimPosition(dimIdx)] =
+        rhsType.getScalableDims()[dimIdx];
+  }
 
   assert(!ShapedType::isDynamicShape(maskShape) &&
          "Mask shape couldn't be computed");
-  // TODO: Extend the scalable vector type representation with a bit map.
-  assert(!lhsType.isScalable() && !rhsType.isScalable() &&
-         "Scalable vectors are not supported yet");
 
   return VectorType::get(maskShape,
-                         IntegerType::get(lhsType.getContext(), /*width=*/1));
+                         IntegerType::get(lhsType.getContext(), /*width=*/1),
+                         maskShapeScalableDims);
 }
 
 SmallVector<StringRef> ContractionOp::getTraitAttrNames() {

diff  --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index f00bc6e97b350ea..d41cee5ea67b0c5 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -979,3 +979,27 @@ func.func @vector_scalable_extract(%sv: vector<[8]xi32>) {
   %2 = vector.scalable.extract %sv[4] : vector<4xi32> from vector<[8]xi32>
   return
  }
+
+#matmat_accesses = [
+  affine_map<(i, j, k) -> (i, k)>,
+  affine_map<(i, j, k) -> (k, j)>,
+  affine_map<(i, j, k) -> (i, j)>
+]
+#matmat_trait = {
+  indexing_maps = #matmat_accesses,
+  iterator_types = ["parallel", "parallel", "reduction"]
+}
+// CHECK-LABEL:   func.func @contraction_masked_scalable(
+// CHECK-SAME:    %[[A:.*]]: vector<3x4xf32>,
+// CHECK-SAME:    %[[B:.*]]: vector<4x[8]xf32>,
+// CHECK-SAME:    %[[C:.*]]: vector<3x[8]xf32>,
+// CHECK-SAME:    %[[M:.*]]: vector<3x[8]x4xi1>) -> vector<3x[8]xf32> {
+func.func @contraction_masked_scalable(%A: vector<3x4xf32>,
+                                    %B: vector<4x[8]xf32>,
+                                    %C: vector<3x[8]xf32>,
+                                    %M : vector<3x[8]x4xi1>) -> vector<3x[8]xf32> {
+ // CHECK:  vector.mask %[[M]] { vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[A]], %[[B]], %[[C]] : vector<3x4xf32>, vector<4x[8]xf32> into vector<3x[8]xf32> } : vector<3x[8]x4xi1> -> vector<3x[8]xf32>
+  %0 = vector.mask %M { vector.contract #matmat_trait %A, %B, %C : vector<3x4xf32>, vector<4x[8]xf32> into vector<3x[8]xf32> } 
+    : vector<3x[8]x4xi1> -> vector<3x[8]xf32>
+  return %0 : vector<3x[8]xf32>
+}


        


More information about the Mlir-commits mailing list