[Mlir-commits] [mlir] 4b2ba5a - [mlir][sve] Add an e2e for linalg.matmul with mixed types (#73773)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Nov 29 13:21:15 PST 2023


Author: Andrzej Warzyński
Date: 2023-11-29T21:21:10Z
New Revision: 4b2ba5a61aaa9b7bea493425699f8a8d32a191b0

URL: https://github.com/llvm/llvm-project/commit/4b2ba5a61aaa9b7bea493425699f8a8d32a191b0
DIFF: https://github.com/llvm/llvm-project/commit/4b2ba5a61aaa9b7bea493425699f8a8d32a191b0.diff

LOG: [mlir][sve] Add an e2e for linalg.matmul with mixed types (#73773)

Apart from the test itself, this patch also updates a few patterns to
fix how new VectorType(s) are created: they now use `VectorType::clone`,
so that "scalability" (the scalable-dimension flags of the source vector
type) is correctly propagated.

Regression tests will be updated separately while auditing Vector
dialect tests in the context of scalable vectors:
  * https://github.com/orgs/llvm/projects/23

Added: 
    mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/matmul_mixed_ty.mlir

Modified: 
    mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 582d627d1ce4ac0..6e7fab293d3a1ca 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -455,7 +455,7 @@ struct ReorderCastOpsOnBroadcast
 
     Type castResTy = getElementTypeOrSelf(op->getResult(0));
     if (auto vecTy = dyn_cast<VectorType>(bcastOp.getSourceType()))
-      castResTy = VectorType::get(vecTy.getShape(), castResTy);
+      castResTy = vecTy.clone(castResTy);
     auto *castOp =
         rewriter.create(op->getLoc(), op->getName().getIdentifier(),
                         bcastOp.getSource(), castResTy, op->getAttrs());
@@ -527,16 +527,14 @@ struct ReorderElementwiseOpsOnTranspose final
         srcValues.push_back(transposeOp.getVector());
       } else {
         // This is a constant. Create a reverse transpose op for it.
-        auto vectorType = VectorType::get(
-            srcType.getShape(),
-            cast<VectorType>(operand.getType()).getElementType());
+        auto vectorType =
+            srcType.clone(cast<VectorType>(operand.getType()).getElementType());
         srcValues.push_back(rewriter.create<vector::TransposeOp>(
             operand.getLoc(), vectorType, operand, invOrder));
       }
     }
 
-    auto vectorType = VectorType::get(
-        srcType.getShape(),
+    auto vectorType = srcType.clone(
         cast<VectorType>(op->getResultTypes()[0]).getElementType());
     Operation *elementwiseOp =
         rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues,
@@ -1314,8 +1312,8 @@ struct CanonicalizeContractMatmulToMMT final
         Value trans =
             rewriter.create<vector::TransposeOp>(loc, sext.getIn(), perm);
         VectorType newType =
-            VectorType::get(cast<VectorType>(trans.getType()).getShape(),
-                            cast<VectorType>(mat.getType()).getElementType());
+            cast<VectorType>(trans.getType())
+                .clone(cast<VectorType>(mat.getType()).getElementType());
         return rewriter.create<arith::ExtSIOp>(loc, newType, trans);
       }
       if (auto zext = mat.getDefiningOp<arith::ExtUIOp>()) {

diff  --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/matmul_mixed_ty.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/matmul_mixed_ty.mlir
new file mode 100644
index 000000000000000..f4f2d87b4d0b42c
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/matmul_mixed_ty.mlir
@@ -0,0 +1,83 @@
+// DEFINE: %{compile} =  mlir-opt %s \
+// DEFINE:    -transform-interpreter -test-transform-dialect-erase-schedule \
+// DEFINE:    -one-shot-bufferize -func-bufferize -cse -canonicalize -convert-vector-to-scf -arm-sve-legalize-vector-storage \
+// DEFINE:    -convert-vector-to-llvm="enable-arm-sve" -test-lower-to-llvm -o %t
+// DEFINE: %{entry_point} = matmul_mixed_ty
+// DEFINE: %{run} = %mcr_aarch64_cmd %t -e %{entry_point} -entry-point-result=void --march=aarch64 --mattr="+sve"\
+// DEFINE:    -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils
+
+// RUN: %{compile}
+
+// RUN: %{run} | FileCheck %s
+
+func.func @matmul_mixed_ty() {
+  // Matrix dimensions: A is MxK, B is KxN, C is MxN
+  %K = arith.constant 3 : index
+  %M = arith.constant 5 : index
+  %N = arith.constant 15 : index
+  %c0_i8 = arith.constant 0 : i8
+  %c0_i32 = arith.constant 0 : i32
+
+  // Allocate the matrices
+  %A_alloc = bufferization.alloc_tensor(%M, %K) : tensor<?x?xi8>
+  %B_alloc = bufferization.alloc_tensor(%K, %N) : tensor<?x?xi8>
+  %C_alloc = bufferization.alloc_tensor(%M, %N) : tensor<?x?xi32>
+
+  // Initialise the matrices: every element of A and B is 123 (i8), C starts at 0 (i32)
+  %pi = arith.constant  123 : i8
+  %A = linalg.fill ins(%pi : i8) outs(%A_alloc : tensor<?x?xi8>) -> tensor<?x?xi8>
+  %B = linalg.fill ins(%pi : i8) outs(%B_alloc : tensor<?x?xi8>) -> tensor<?x?xi8>
+  %C_in = linalg.fill ins(%c0_i32 : i32) outs(%C_alloc : tensor<?x?xi32>) -> tensor<?x?xi32>
+
+  // Matmul with mixed element types: i8 inputs accumulated into an i32 output
+  %C_out = linalg.matmul ins(%A, %B: tensor<?x?xi8>, tensor<?x?xi8>) outs(%C_in: tensor<?x?xi32>) -> tensor<?x?xi32>
+
+  // Print and verify the output: each element of C is K * 123 * 123 = 3 * 15129 = 45387
+  // CHECK-LABEL: SVE: START OF TEST OUTPUT
+  vector.print str "SVE: START OF TEST OUTPUT"
+
+  // CHECK-NEXT: Unranked Memref {{.*}} rank = 2 offset = 0 sizes = [5, 15] strides = [15, 1] data =
+  // CHECK-COUNT-5: [45387,   45387,   45387,   45387,   45387,   45387,   45387,   45387,   45387,   45387,   45387,   45387,   45387,   45387,   45387]
+  %xf = tensor.cast %C_out : tensor<?x?xi32> to tensor<*xi32>
+  call @printMemrefI32(%xf) : (tensor<*xi32>) -> ()
+
+  // CHECK-NEXT: SVE: END OF TEST OUTPUT
+  vector.print str "SVE: END OF TEST OUTPUT"
+
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+transform.named_sequence @__transform_main(%module: !transform.any_op {transform.readonly}) {
+    %matmul = transform.structured.match ops{["linalg.matmul"]} in %module
+      : (!transform.any_op) -> !transform.any_op
+
+    // Step 1: Tile the (M, N, K) loops by [2, [4], 1]; [4] is a scalable size (4 x vscale)
+    %module_with_tiled_loops, %loops:3 = transform.structured.tile_using_for %matmul [2, [4], 1]
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+
+    // Step 2: Vectorize with vector sizes matching the tile sizes (N is scalable)
+    %tiled_matmul = transform.structured.match ops{["linalg.matmul"]} in %module_with_tiled_loops
+      : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %tiled_matmul vector_sizes [2, [4], 1] : !transform.any_op
+
+    // Step 3: Lower vector.multi_reduction to vector.contract (+ transfer permutation and masked-transfer patterns)
+    %func = transform.structured.match ops{["func.func"]} in %module
+      : (!transform.any_op) -> !transform.op<"func.func">
+    transform.apply_patterns to %func {
+      transform.apply_patterns.vector.reduction_to_contract
+      transform.apply_patterns.vector.transfer_permutation_patterns
+      transform.apply_patterns.vector.lower_masked_transfers
+    } : !transform.op<"func.func">
+
+    // Step 4: Lower vector.contract to vector.outerproduct, then vector.outerproduct to vector.fma
+    transform.apply_patterns to %func {
+      transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
+      transform.apply_patterns.vector.lower_outerproduct
+    } : !transform.op<"func.func">
+
+    transform.yield
+  }
+}
+
+func.func private @printMemrefI32(%ptr : tensor<*xi32>)


        


More information about the Mlir-commits mailing list