[Mlir-commits] [mlir] f011d32 - [mlir][vector] Fix contraction op lowering with mixed types
Thomas Raoux
llvmlistbot at llvm.org
Thu Jun 16 09:41:12 PDT 2022
Author: Thomas Raoux
Date: 2022-06-16T16:40:56Z
New Revision: f011d32c3a625eb86d1e33a70100b0a031f5fcd4
URL: https://github.com/llvm/llvm-project/commit/f011d32c3a625eb86d1e33a70100b0a031f5fcd4
DIFF: https://github.com/llvm/llvm-project/commit/f011d32c3a625eb86d1e33a70100b0a031f5fcd4.diff
LOG: [mlir][vector] Fix contraction op lowering with mixed types
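The contraction op can have mixed element types; add support for this case to the
pattern that lowers a contraction op to outer products. The pattern lowering a
contraction op to matmul now bails out when the element types differ.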
Differential Revision: https://reviews.llvm.org/D127926
Added:
Modified:
mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
mlir/test/Dialect/Vector/vector-contract-transforms.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index e1f4cffb93552..87e3ea3d53890 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1324,6 +1324,12 @@ ContractionOpToMatmulOpLowering::matchAndRewrite(vector::ContractionOp op,
if (!elementType.isIntOrFloat())
return failure();
+ Type dstElementType = op.getType();
+ if (auto vecType = dstElementType.dyn_cast<VectorType>())
+ dstElementType = vecType.getElementType();
+ if (elementType != dstElementType)
+ return failure();
+
// Perform lhs + rhs transpositions to conform to matmul row-major semantics.
// Bail out if the contraction cannot be put in this form.
MLIRContext *ctx = op.getContext();
@@ -1416,11 +1422,29 @@ struct UnrolledOuterProductGenerator
return builder.create<vector::TransposeOp>(loc, v, perm);
}
+ /// Converts `v` so that its (element) type is `dstElementType`.
+ /// `v` may be a scalar or a vector; for a vector the result keeps the same
+ /// shape and only the element type is replaced.
+ /// Returns `v` unchanged when its (element) type already equals
+ /// `dstElementType`. Otherwise emits arith.extf when the destination is a
+ /// FloatType and arith.extsi for any other destination.
+ /// NOTE(review): nothing here checks that the destination is wider than, or
+ /// of the same kind (int/float) as, the source; presumably callers guarantee
+ /// this — confirm. Integers are always sign-extended (never zero-extended).
+ Value promote(Value v, Type dstElementType) {
+ Type elementType = v.getType();
+ auto vecType = elementType.dyn_cast<VectorType>();
+ if (vecType)
+ elementType = vecType.getElementType();
+ if (elementType == dstElementType)
+ return v;
+ // Result type: dstElementType, wrapped in a vector of v's shape if v is a vector.
+ Type promotedType = dstElementType;
+ if (vecType)
+ promotedType = VectorType::get(vecType.getShape(), promotedType);
+ if (dstElementType.isa<FloatType>())
+ return builder.create<arith::ExtFOp>(loc, promotedType, v);
+ // Non-float destination: treated as an integer and sign-extended.
+ return builder.create<arith::ExtSIOp>(loc, promotedType, v);
+ }
+
Value outerProd(Value lhs, Value rhs, Value res, int reductionSize) {
assert(reductionSize > 0);
+ Type resElementType = res.getType().cast<VectorType>().getElementType();
for (int64_t k = 0; k < reductionSize; ++k) {
Value a = builder.create<vector::ExtractOp>(loc, lhs, k);
Value b = builder.create<vector::ExtractOp>(loc, rhs, k);
+ a = promote(a, resElementType);
+ b = promote(b, resElementType);
res = builder.create<vector::OuterProductOp>(loc, res.getType(), a, b,
res, kind);
}
diff --git a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
index e725d1883e9bd..70f86fd4dc6dd 100644
--- a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
@@ -891,6 +891,25 @@ func.func @matmul_0(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vecto
return %0 : vector<2x3xf32>
}
+// OUTERPRODUCT-LABEL: func @matmul_0_mixed
+// OUTERPRODUCT-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf16>,
+// OUTERPRODUCT-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf16>,
+// OUTERPRODUCT-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32>
+// OUTERPRODUCT: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
+// OUTERPRODUCT: %[[a0:.*]] = vector.extract %[[At]][0] : vector<1x2xf16>
+// OUTERPRODUCT: %[[b0:.*]] = vector.extract %[[B]][0] : vector<1x3xf16>
+// OUTERPRODUCT: %[[a1:.*]] = arith.extf %[[a0]] : vector<2xf16> to vector<2xf32>
+// OUTERPRODUCT: %[[b1:.*]] = arith.extf %[[b0]] : vector<3xf16> to vector<3xf32>
+// OUTERPRODUCT: %[[c0:.*]] = vector.outerproduct %[[a1]], %[[b1]], %[[C]]
+// OUTERPRODUCT: return %[[c0]] : vector<2x3xf32>
+// Mixed-precision contraction: f16 operands accumulated into an f32 result.
+// The OUTERPRODUCT checks for this function expect each extracted operand
+// slice to be widened to f32 with arith.extf before vector.outerproduct.
+func.func @matmul_0_mixed(%arg0: vector<2x1xf16>, %arg1: vector<1x3xf16>, %arg2: vector<2x3xf32>)
+-> vector<2x3xf32>
+{
+ %0 = vector.contract #matmat_trait_0 %arg0, %arg1, %arg2
+ : vector<2x1xf16>, vector<1x3xf16> into vector<2x3xf32>
+ return %0 : vector<2x3xf32>
+}
+
#matmat_accesses_1 = [
affine_map<(m, n, k) -> (m, k)>,
affine_map<(m, n, k) -> (n, k)>,
More information about the Mlir-commits
mailing list