[Mlir-commits] [mlir] cb89457 - [mlir][vector] Constrain `ContractionOpToMatmulOpLowering` (#102225)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Aug 7 02:18:18 PDT 2024


Author: Andrzej Warzyński
Date: 2024-08-07T10:18:14+01:00
New Revision: cb89457ff825926f0004711bef3d534df1f5576d

URL: https://github.com/llvm/llvm-project/commit/cb89457ff825926f0004711bef3d534df1f5576d
DIFF: https://github.com/llvm/llvm-project/commit/cb89457ff825926f0004711bef3d534df1f5576d.diff

LOG: [mlir][vector] Constrain `ContractionOpToMatmulOpLowering` (#102225)

Disables `ContractionOpToMatmulOpLowering` for scalable vectors. This
pattern exists to enable lowering to `llvm.matrix.multiply`, and I'm not
aware of any use of that intrinsic in the context of scalable vectors.

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
    mlir/test/Dialect/Vector/vector-contract-to-matrix-intrinsics-transforms.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
index 3a799ce8e0bce..21261478f0648 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
@@ -1283,6 +1283,8 @@ class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
 /// This only kicks in when VectorTransformsOptions is set to `Matmul`.
 /// vector.transpose operations are inserted if the vector.contract op is not a
 /// row-major matrix multiply.
+///
+/// Scalable vectors are not supported.
 FailureOr<Value> ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp(
     vector::ContractionOp op, MaskingOpInterface maskOp,
     PatternRewriter &rew) const {
@@ -1302,13 +1304,18 @@ FailureOr<Value> ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp(
       !isReductionIterator(iteratorTypes[2]))
     return failure();
 
+  Type opResType = op.getType();
+  VectorType vecType = dyn_cast<VectorType>(opResType);
+  if (vecType && vecType.isScalable()) {
+    // Note - this is sufficient to reject all cases with scalable vectors.
+    return failure();
+  }
+
   Type elementType = op.getLhsType().getElementType();
   if (!elementType.isIntOrFloat())
     return failure();
 
-  Type dstElementType = op.getType();
-  if (auto vecType = dyn_cast<VectorType>(dstElementType))
-    dstElementType = vecType.getElementType();
+  Type dstElementType = vecType ? vecType.getElementType() : opResType;
   if (elementType != dstElementType)
     return failure();
 

diff  --git a/mlir/test/Dialect/Vector/vector-contract-to-matrix-intrinsics-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-matrix-intrinsics-transforms.mlir
index 78cf82e1ab6c1..4867a416e5d14 100644
--- a/mlir/test/Dialect/Vector/vector-contract-to-matrix-intrinsics-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-to-matrix-intrinsics-transforms.mlir
@@ -36,13 +36,23 @@
 //      CHECK:  %[[mm5:.*]] = vector.insert %[[mm4]], %[[mm3]] [1] : vector<3xf32> into vector<2x3xf32>
 //      CHECK:  %[[mm6:.*]] = arith.addf %[[C]], %[[mm5]] : vector<2x3xf32>
 func.func @matmul(%arg0: vector<2x4xf32>,
-                          %arg1: vector<4x3xf32>,
-                          %arg2: vector<2x3xf32>) -> vector<2x3xf32> {
+                  %arg1: vector<4x3xf32>,
+                  %arg2: vector<2x3xf32>) -> vector<2x3xf32> {
   %0 = vector.contract #matmat_trait %arg0, %arg1, %arg2
     : vector<2x4xf32>, vector<4x3xf32> into vector<2x3xf32>
   return %0 : vector<2x3xf32>
 }
 
+// CHECK-LABEL: func @matmul_scalable
+// CHECK-NOT: vector.matrix_multiply
+func.func @matmul_scalable(%arg0: vector<2x4xf32>,
+                           %arg1: vector<4x[3]xf32>,
+                           %arg2: vector<2x[3]xf32>) -> vector<2x[3]xf32> {
+  %0 = vector.contract #matmat_trait %arg0, %arg1, %arg2
+    : vector<2x4xf32>, vector<4x[3]xf32> into vector<2x[3]xf32>
+  return %0 : vector<2x[3]xf32>
+}
+
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
     %f = transform.structured.match ops{["func.func"]} in %module_op


        


More information about the Mlir-commits mailing list