[Mlir-commits] [mlir] Allow empty dimension arrays in `linalg::inferContractionDims` (PR #69496)
llvmlistbot at llvm.org
Wed Oct 18 11:45:10 PDT 2023
llvmbot wrote:
@llvm/pr-subscribers-mlir-linalg
Author: None (bjacob)
<details>
<summary>Changes</summary>
This function returned failure whenever any of the intersection sets was empty, but an empty set is legitimate in "matrix times vector" cases: some operands have lower dimensionality, which implies unit-dimension semantics for the "missing" dimensions.
Example:
```mlir
func.func @transpose_extend_batch_matmul(
    %vec: tensor<32x128xi16>,
    %mat: tensor<11008x32x128xi4>) -> tensor<11008x32xi32> {
  %c0_i32 = arith.constant 0 : i32
  %cst_0 = arith.constant 0.000000e+00 : f32
  %0 = tensor.empty() : tensor<11008x32xi32>
  %1 = linalg.fill ins(%c0_i32 : i32) outs(%0 : tensor<11008x32xi32>) -> tensor<11008x32xi32>
  %2 = tensor.empty() : tensor<11008xf32>
  %3 = linalg.fill ins(%cst_0 : f32) outs(%2 : tensor<11008xf32>) -> tensor<11008xf32>
  %batch_matmul_result = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2)>,
                                                          affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
                                                          affine_map<(d0, d1, d2) -> (d0, d1)>],
                                         iterator_types = ["parallel", "parallel", "reduction"]}
      ins(%vec, %mat : tensor<32x128xi16>, tensor<11008x32x128xi4>)
      outs(%1 : tensor<11008x32xi32>) {
  ^bb0(%in: i16, %in_3: i4, %out: i32):
    %19 = arith.extsi %in : i16 to i32
    %20 = arith.extui %in_3 : i4 to i32
    %21 = arith.muli %19, %20 : i32
    %22 = arith.addi %21, %out : i32
    linalg.yield %22 : i32
  } -> tensor<11008x32xi32>
  return %batch_matmul_result : tensor<11008x32xi32>
}
```
Here, we were returning failure because `ac` (the set of candidate `m` dimensions) is empty. With this PR, we return this useful information:
```
batch: [ 1 ]
m: [ ]
n: [ 0 ]
k: [ 2 ]
```
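To see why `ac` comes out empty, here is a minimal C++ sketch of the set algebra. It is a simplified, hypothetical model rather than the MLIR implementation: each indexing map is reduced to the list of loop dimensions it uses, `std::set` stands in for `llvm::SmallDenseSet`, and the names `inferContractionDimsSketch`, `ContractionDimsSketch` and `rejectEmpty` are made up for illustration. The `rejectEmpty` flag reproduces the check that this PR deletes.

```cpp
#include <cassert>
#include <set>
#include <vector>

// Hypothetical, simplified model of the set algebra in
// mlir::linalg::inferContractionDims (not the MLIR code itself).
// Each indexing map is reduced to the list of loop dimensions it uses, and
// isReduction[d] says whether loop dimension d is a reduction iterator.
using DimSet = std::set<unsigned>;

struct ContractionDimsSketch {
  std::vector<unsigned> batch, m, n, k;
};

// Dimensions used by an operand whose iterator kind matches `wantReduction`.
static DimSet dimsOfKind(const std::vector<unsigned> &operandDims,
                         const std::vector<bool> &isReduction,
                         bool wantReduction) {
  DimSet result;
  for (unsigned d : operandDims) {
    assert(d < isReduction.size() && "dimension out of range");
    if (isReduction[d] == wantReduction)
      result.insert(d);
  }
  return result;
}

static DimSet intersect(const DimSet &x, const DimSet &y) {
  DimSet result;
  for (unsigned d : x)
    if (y.count(d))
      result.insert(d);
  return result;
}

static DimSet subtract(const DimSet &x, const DimSet &y) {
  DimSet result;
  for (unsigned d : x)
    if (!y.count(d))
      result.insert(d);
  return result;
}

// Returns false ("failure") only when rejectEmpty is set and one of the
// m, n, k sets is empty, i.e. the check this PR removes.
static bool inferContractionDimsSketch(const std::vector<unsigned> &lhs,
                                       const std::vector<unsigned> &rhs,
                                       const std::vector<unsigned> &out,
                                       const std::vector<bool> &isReduction,
                                       bool rejectEmpty,
                                       ContractionDimsSketch &dims) {
  DimSet a = dimsOfKind(lhs, isReduction, false);
  DimSet b = dimsOfKind(rhs, isReduction, false);
  DimSet c = dimsOfKind(out, isReduction, false);
  DimSet ac = subtract(intersect(a, c), b);        // m: LHS and output, not RHS
  DimSet bc = subtract(intersect(b, c), a);        // n: RHS and output, not LHS
  DimSet batches = intersect(intersect(a, b), c);  // batch: all three operands
  DimSet ra = intersect(dimsOfKind(lhs, isReduction, true),
                        dimsOfKind(rhs, isReduction, true));  // k
  if (rejectEmpty && (ac.empty() || bc.empty() || ra.empty()))
    return false;
  // std::set iterates in ascending order, so every array comes out sorted.
  dims.batch.assign(batches.begin(), batches.end());
  dims.m.assign(ac.begin(), ac.end());
  dims.n.assign(bc.begin(), bc.end());
  dims.k.assign(ra.begin(), ra.end());
  return true;
}
```

For the example above, with `lhs = {1, 2}`, `rhs = {0, 1, 2}`, `out = {0, 1}` and only `d2` marked as a reduction, the call fails when `rejectEmpty` is true and otherwise produces batch `{1}`, an empty `m`, n `{0}` and k `{2}`, matching the output listed above.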
---
Full diff: https://github.com/llvm/llvm-project/pull/69496.diff
1 File Affected:
- (modified) mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp (-3)
``````````diff
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index ea50e1232a4c74a..5fde8d71cac3e75 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -227,9 +227,6 @@ mlir::linalg::inferContractionDims(LinalgOp linalgOp) {
       linalgOp, linalgOp.getDpsInputOperand(1), red);
   llvm::set_intersect(ra, rb);
 
-  if (ac.empty() || bc.empty() || ra.empty())
-    return failure();
-
   // Return each set in sorted order.
   ContractionDimensions dimensions{
       SmallVector<unsigned, 2>(batches.begin(), batches.end()),
``````````
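Because the emptiness check is gone, callers of `linalg::inferContractionDims` may now receive an empty `m`, `n` or `k` array and should not assume at least one entry. One way to honour the unit-dimension semantics described in the summary is to treat an empty group as a single dimension of extent 1. The helper below is a hypothetical sketch of that idea (`collapsedExtent` is not an MLIR API) and assumes static, positive loop ranges.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper (not an MLIR API): collapse a group of loop dimensions,
// such as the m, n or k array of the inferred contraction dimensions, into a
// single extent by multiplying their loop ranges. An empty group yields the
// empty product 1, i.e. a unit dimension. Assumes static, positive ranges.
std::int64_t collapsedExtent(const std::vector<unsigned> &group,
                             const std::vector<std::int64_t> &loopRanges) {
  std::int64_t extent = 1;
  for (unsigned d : group) {
    assert(d < loopRanges.size() && loopRanges[d] > 0 &&
           "expected a static, positive loop range");
    extent *= loopRanges[d];
  }
  return extent;
}
```

For the example at the top, with loop ranges `{11008, 32, 128}`, the empty `m` collapses to 1, `n` to 11008, `k` to 128 and the batch to 32: 32 independent products of an 11008x128 matrix with a 128-element vector.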
</details>
https://github.com/llvm/llvm-project/pull/69496