[Mlir-commits] [mlir] 718af88 - [mlir][vector] Extend mask calculation for vector.contract (#65733)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Sep 11 03:34:51 PDT 2023
Author: Andrzej Warzyński
Date: 2023-09-11T11:34:47+01:00
New Revision: 718af8837639b6a47ae9bf911f668437f0ce0e3c
URL: https://github.com/llvm/llvm-project/commit/718af8837639b6a47ae9bf911f668437f0ce0e3c
DIFF: https://github.com/llvm/llvm-project/commit/718af8837639b6a47ae9bf911f668437f0ce0e3c.diff
LOG: [mlir][vector] Extend mask calculation for vector.contract (#65733)
Make sure that when calculating the expected mask for `vector.contract`,
scalable sizes are correctly taken into account.
Depends on: #65724
Added:
Modified:
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/test/Dialect/Vector/ops.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 6473c92a91aa64b..1222542ee39fd6a 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -912,22 +912,27 @@ Type ContractionOp::getExpectedMaskType() {
unsigned numVecDims = lhsIdxMap.getNumDims();
SmallVector<int64_t> maskShape(numVecDims, ShapedType::kDynamic);
+ SmallVector<bool> maskShapeScalableDims(numVecDims, false);
// Using the information in the indexing maps, extract the size of each
// dimension in the vector.contract operation from the two input operands.
- for (auto [dimIdx, dimSize] : llvm::enumerate(lhsType.getShape()))
+ for (auto [dimIdx, dimSize] : llvm::enumerate(lhsType.getShape())) {
maskShape[lhsIdxMap.getDimPosition(dimIdx)] = dimSize;
- for (auto [dimIdx, dimSize] : llvm::enumerate(rhsType.getShape()))
+ maskShapeScalableDims[lhsIdxMap.getDimPosition(dimIdx)] =
+ lhsType.getScalableDims()[dimIdx];
+ }
+ for (auto [dimIdx, dimSize] : llvm::enumerate(rhsType.getShape())) {
maskShape[rhsIdxMap.getDimPosition(dimIdx)] = dimSize;
+ maskShapeScalableDims[rhsIdxMap.getDimPosition(dimIdx)] =
+ rhsType.getScalableDims()[dimIdx];
+ }
assert(!ShapedType::isDynamicShape(maskShape) &&
"Mask shape couldn't be computed");
- // TODO: Extend the scalable vector type representation with a bit map.
- assert(!lhsType.isScalable() && !rhsType.isScalable() &&
- "Scalable vectors are not supported yet");
return VectorType::get(maskShape,
- IntegerType::get(lhsType.getContext(), /*width=*/1));
+ IntegerType::get(lhsType.getContext(), /*width=*/1),
+ maskShapeScalableDims);
}
SmallVector<StringRef> ContractionOp::getTraitAttrNames() {
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index f00bc6e97b350ea..d41cee5ea67b0c5 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -979,3 +979,27 @@ func.func @vector_scalable_extract(%sv: vector<[8]xi32>) {
%2 = vector.scalable.extract %sv[4] : vector<4xi32> from vector<[8]xi32>
return
}
+
+#matmat_accesses = [
+ affine_map<(i, j, k) -> (i, k)>,
+ affine_map<(i, j, k) -> (k, j)>,
+ affine_map<(i, j, k) -> (i, j)>
+]
+#matmat_trait = {
+ indexing_maps = #matmat_accesses,
+ iterator_types = ["parallel", "parallel", "reduction"]
+}
+// CHECK-LABEL: func.func @contraction_masked_scalable(
+// CHECK-SAME: %[[A:.*]]: vector<3x4xf32>,
+// CHECK-SAME: %[[B:.*]]: vector<4x[8]xf32>,
+// CHECK-SAME: %[[C:.*]]: vector<3x[8]xf32>,
+// CHECK-SAME: %[[M:.*]]: vector<3x[8]x4xi1>) -> vector<3x[8]xf32> {
+func.func @contraction_masked_scalable(%A: vector<3x4xf32>,
+ %B: vector<4x[8]xf32>,
+ %C: vector<3x[8]xf32>,
+ %M : vector<3x[8]x4xi1>) -> vector<3x[8]xf32> {
+ // CHECK: vector.mask %[[M]] { vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[A]], %[[B]], %[[C]] : vector<3x4xf32>, vector<4x[8]xf32> into vector<3x[8]xf32> } : vector<3x[8]x4xi1> -> vector<3x[8]xf32>
+ %0 = vector.mask %M { vector.contract #matmat_trait %A, %B, %C : vector<3x4xf32>, vector<4x[8]xf32> into vector<3x[8]xf32> }
+ : vector<3x[8]x4xi1> -> vector<3x[8]xf32>
+ return %0 : vector<3x[8]xf32>
+}
More information about the Mlir-commits
mailing list