[Mlir-commits] [mlir] [mlir][ArmSME] Verify ops on tile types post LLVM conversion (PR #92076)

Cullen Rhodes llvmlistbot at llvm.org
Wed May 15 06:50:36 PDT 2024


https://github.com/c-rhodes updated https://github.com/llvm/llvm-project/pull/92076

>From d350c493bded5812a93e11b4e3687d9b551e2d4d Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Tue, 14 May 2024 06:58:02 +0000
Subject: [PATCH 1/2] [mlir][ArmSME] Verify ops on tile types post LLVM
 conversion

Unsupported ops on tile types can become dead after
`-convert-arm-sme-to-llvm`, resulting in incorrect results. Verify that no
such operations remain after conversion, and fail the pass if any do.

Based on discussion from
https://discourse.llvm.org/t/on-improving-arm-sme-lowering-resilience-in-mlir/78543
---
 .../Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp  |  21 +++
 .../ArmSMEToLLVM/arm-sme-to-llvm.mlir         | 149 ++++++++++++------
 .../ArmSMEToLLVM/tile-spills-and-fills.mlir   |   2 +-
 .../Conversion/ArmSMEToLLVM/unsupported.mlir  |  14 +-
 4 files changed, 132 insertions(+), 54 deletions(-)

diff --git a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
index 3dbc8e9916df6..32c18c84d04b1 100644
--- a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
+++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
@@ -18,6 +18,7 @@
 #include "mlir/Dialect/ArmSME/IR/ArmSME.h"
 #include "mlir/Dialect/ArmSME/Transforms/Transforms.h"
 #include "mlir/Dialect/ArmSME/Utils/Utils.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
@@ -877,6 +878,26 @@ struct ConvertArmSMEToLLVMPass
 
     if (failed(applyPartialConversion(function, target, std::move(patterns))))
       signalPassFailure();
+
+    // Walk the function and fail if there are unexpected operations on SME
+    // tile types after conversion.
+    function->walk([&](Operation *op) {
+      // These ops are legal post conversion, skip these.
+      if (isa<arm_sme::CopyTileOp, arm_sme::GetTileOp, cf::BranchOp>(op) ||
+          !op->isRegistered())
+        return;
+      if (llvm::any_of(op->getResultTypes(),
+                       [](Type type) {
+                         return arm_sme::isValidSMETileVectorType(type);
+                       }) ||
+          llvm::any_of(op->getOperandTypes(), [](Type type) {
+            return arm_sme::isValidSMETileVectorType(type);
+          })) {
+        op->emitOpError("unexpected operation with SME tile type after "
+                        "conversion to LLVM");
+        signalPassFailure();
+      }
+    });
   }
 };
 
diff --git a/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir b/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir
index 14b1f323da3a2..ef85f3d069d74 100644
--- a/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir
+++ b/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir
@@ -629,18 +629,18 @@ func.func @arm_sme_streaming_vl_double_words() -> index {
 
 // CHECK-LABEL: arm_sme_fmopa_2way_f16f16_to_f32
 // CHECK: "arm_sme.intr.mopa.wide"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xf16>, vector<[8]xf16>) -> ()
-func.func @arm_sme_fmopa_2way_f16f16_to_f32(%vecA: vector<[8]xf16>, %vecB: vector<[8]xf16>) -> vector<[4]x[4]xf32> {
+func.func @arm_sme_fmopa_2way_f16f16_to_f32(%vecA: vector<[8]xf16>, %vecB: vector<[8]xf16>) {
   %result = arm_sme.fmopa_2way %vecA, %vecB : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
-  return %result : vector<[4]x[4]xf32>
+  "test.some_use"(%result) : (vector<[4]x[4]xf32>) -> ()
 }
 
 // -----
 
 // CHECK-LABEL: arm_sme_fmopa_2way_bf16bf16_to_f32
 // CHECK: "arm_sme.intr.mopa.wide"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xbf16>, vector<[8]xbf16>) -> ()
-func.func @arm_sme_fmopa_2way_bf16bf16_to_f32(%vecA: vector<[8]xbf16>, %vecB: vector<[8]xbf16>) -> vector<[4]x[4]xf32> {
+func.func @arm_sme_fmopa_2way_bf16bf16_to_f32(%vecA: vector<[8]xbf16>, %vecB: vector<[8]xbf16>) {
   %result = arm_sme.fmopa_2way %vecA, %vecB : vector<[8]xbf16>, vector<[8]xbf16> into vector<[4]x[4]xf32>
-  return %result : vector<[4]x[4]xf32>
+  "test.some_use"(%result) : (vector<[4]x[4]xf32>) -> ()
 }
 
 //===----------------------------------------------------------------------===//
@@ -651,18 +651,18 @@ func.func @arm_sme_fmopa_2way_bf16bf16_to_f32(%vecA: vector<[8]xbf16>, %vecB: ve
 
 // CHECK-LABEL: arm_sme_fmops_2way_f16f16_to_f32
 // CHECK: "arm_sme.intr.mops.wide"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xf16>, vector<[8]xf16>) -> ()
-func.func @arm_sme_fmops_2way_f16f16_to_f32(%vecA: vector<[8]xf16>, %vecB: vector<[8]xf16>) -> vector<[4]x[4]xf32> {
+func.func @arm_sme_fmops_2way_f16f16_to_f32(%vecA: vector<[8]xf16>, %vecB: vector<[8]xf16>) {
   %result = arm_sme.fmops_2way %vecA, %vecB : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
-  return %result : vector<[4]x[4]xf32>
+  "test.some_use"(%result) : (vector<[4]x[4]xf32>) -> ()
 }
 
 // -----
 
 // CHECK-LABEL: arm_sme_fmops_2way_bf16bf16_to_f32
 // CHECK: "arm_sme.intr.mops.wide"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xbf16>, vector<[8]xbf16>) -> ()
-func.func @arm_sme_fmops_2way_bf16bf16_to_f32(%vecA: vector<[8]xbf16>, %vecB: vector<[8]xbf16>) -> vector<[4]x[4]xf32> {
+func.func @arm_sme_fmops_2way_bf16bf16_to_f32(%vecA: vector<[8]xbf16>, %vecB: vector<[8]xbf16>) {
   %result = arm_sme.fmops_2way %vecA, %vecB : vector<[8]xbf16>, vector<[8]xbf16> into vector<[4]x[4]xf32>
-  return %result : vector<[4]x[4]xf32>
+  "test.some_use"(%result) : (vector<[4]x[4]xf32>) -> ()
 }
 
 //===----------------------------------------------------------------------===//
@@ -673,9 +673,9 @@ func.func @arm_sme_fmops_2way_bf16bf16_to_f32(%vecA: vector<[8]xbf16>, %vecB: ve
 
 // CHECK-LABEL: arm_sme_smopa_2way_i16i16_to_i32
 // CHECK: "arm_sme.intr.smopa.za32"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
-func.func @arm_sme_smopa_2way_i16i16_to_i32(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[4]x[4]xi32> {
+func.func @arm_sme_smopa_2way_i16i16_to_i32(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) {
   %result = arm_sme.smopa_2way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
-  return %result : vector<[4]x[4]xi32>
+  "test.some_use"(%result) : (vector<[4]x[4]xi32>) -> ()
 }
 
 //===----------------------------------------------------------------------===//
@@ -686,9 +686,9 @@ func.func @arm_sme_smopa_2way_i16i16_to_i32(%vecA: vector<[8]xi16>, %vecB: vecto
 
 // CHECK-LABEL: arm_sme_smops_2way_i16i16_to_i32
 // CHECK: "arm_sme.intr.smops.za32"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
-func.func @arm_sme_smops_2way_i16i16_to_i32(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[4]x[4]xi32> {
+func.func @arm_sme_smops_2way_i16i16_to_i32(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) {
   %result = arm_sme.smops_2way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
-  return %result : vector<[4]x[4]xi32>
+  "test.some_use"(%result) : (vector<[4]x[4]xi32>) -> ()
 }
 
 //===----------------------------------------------------------------------===//
@@ -699,9 +699,10 @@ func.func @arm_sme_smops_2way_i16i16_to_i32(%vecA: vector<[8]xi16>, %vecB: vecto
 
 // CHECK-LABEL: arm_sme_umopa_2way_i16i16_to_i32
 // CHECK: "arm_sme.intr.umopa.za32"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
-func.func @arm_sme_umopa_2way_i16i16_to_i32(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[4]x[4]xi32> {
+func.func @arm_sme_umopa_2way_i16i16_to_i32(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) {
   %result = arm_sme.umopa_2way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
-  return %result : vector<[4]x[4]xi32>
+  "test.some_use"(%result) : (vector<[4]x[4]xi32>) -> ()
+  return
 }
 
 //===----------------------------------------------------------------------===//
@@ -712,9 +713,10 @@ func.func @arm_sme_umopa_2way_i16i16_to_i32(%vecA: vector<[8]xi16>, %vecB: vecto
 
 // CHECK-LABEL: arm_sme_umops_2way_i16i16_to_i32
 // CHECK: "arm_sme.intr.umops.za32"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
-func.func @arm_sme_umops_2way_i16i16_to_i32(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[4]x[4]xi32> {
+func.func @arm_sme_umops_2way_i16i16_to_i32(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) {
   %result = arm_sme.umops_2way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
-  return %result : vector<[4]x[4]xi32>
+  "test.some_use"(%result) : (vector<[4]x[4]xi32>) -> ()
+  return
 }
 
 //===----------------------------------------------------------------------===//
@@ -725,18 +727,20 @@ func.func @arm_sme_umops_2way_i16i16_to_i32(%vecA: vector<[8]xi16>, %vecB: vecto
 
 // CHECK-LABEL: arm_sme_smopa_4way_i8i8_to_i32
 // CHECK: "arm_sme.intr.smopa.wide"({{.*}}) <{tile_id = 0 : i32}> : (vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
-func.func @arm_sme_smopa_4way_i8i8_to_i32(%vecA: vector<[16]xi8>, %vecB: vector<[16]xi8>) -> vector<[4]x[4]xi32> {
+func.func @arm_sme_smopa_4way_i8i8_to_i32(%vecA: vector<[16]xi8>, %vecB: vector<[16]xi8>) {
   %result = arm_sme.smopa_4way %vecA, %vecB : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
-  return %result : vector<[4]x[4]xi32>
+  "test.some_use"(%result) : (vector<[4]x[4]xi32>) -> ()
+  return
 }
 
 // -----
 
 // CHECK-LABEL: arm_sme_smopa_4way_i16i16_to_i64
 // CHECK: "arm_sme.intr.smopa.wide"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
-func.func @arm_sme_smopa_4way_i16i16_to_i64(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[2]x[2]xi64> {
+func.func @arm_sme_smopa_4way_i16i16_to_i64(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) {
   %result = arm_sme.smopa_4way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
-  return %result : vector<[2]x[2]xi64>
+  "test.some_use"(%result) : (vector<[2]x[2]xi64>) -> ()
+  return
 }
 
 //===----------------------------------------------------------------------===//
@@ -747,18 +751,20 @@ func.func @arm_sme_smopa_4way_i16i16_to_i64(%vecA: vector<[8]xi16>, %vecB: vecto
 
 // CHECK-LABEL: arm_sme_smops_4way_i8i8_to_i32
 // CHECK: "arm_sme.intr.smops.wide"({{.*}}) <{tile_id = 0 : i32}> : (vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
-func.func @arm_sme_smops_4way_i8i8_to_i32(%vecA: vector<[16]xi8>, %vecB: vector<[16]xi8>) -> vector<[4]x[4]xi32> {
+func.func @arm_sme_smops_4way_i8i8_to_i32(%vecA: vector<[16]xi8>, %vecB: vector<[16]xi8>) {
   %result = arm_sme.smops_4way %vecA, %vecB : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
-  return %result : vector<[4]x[4]xi32>
+  "test.some_use"(%result) : (vector<[4]x[4]xi32>) -> ()
+  return
 }
 
 // -----
 
 // CHECK-LABEL: arm_sme_smops_4way_i16i16_to_i64
 // CHECK: "arm_sme.intr.smops.wide"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
-func.func @arm_sme_smops_4way_i16i16_to_i64(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[2]x[2]xi64> {
+func.func @arm_sme_smops_4way_i16i16_to_i64(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) {
   %result = arm_sme.smops_4way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
-  return %result : vector<[2]x[2]xi64>
+  "test.some_use"(%result) : (vector<[2]x[2]xi64>) -> ()
+  return
 }
 
 //===----------------------------------------------------------------------===//
@@ -769,18 +775,20 @@ func.func @arm_sme_smops_4way_i16i16_to_i64(%vecA: vector<[8]xi16>, %vecB: vecto
 
 // CHECK-LABEL: arm_sme_umopa_4way_i8i8_to_i32
 // CHECK: "arm_sme.intr.umopa.wide"({{.*}}) <{tile_id = 0 : i32}> : (vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
-func.func @arm_sme_umopa_4way_i8i8_to_i32(%vecA: vector<[16]xi8>, %vecB: vector<[16]xi8>) -> vector<[4]x[4]xi32> {
+func.func @arm_sme_umopa_4way_i8i8_to_i32(%vecA: vector<[16]xi8>, %vecB: vector<[16]xi8>) {
   %result = arm_sme.umopa_4way %vecA, %vecB : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
-  return %result : vector<[4]x[4]xi32>
+  "test.some_use"(%result) : (vector<[4]x[4]xi32>) -> ()
+  return
 }
 
 // -----
 
 // CHECK-LABEL: arm_sme_umopa_4way_i16i16_to_i64
 // CHECK: "arm_sme.intr.umopa.wide"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
-func.func @arm_sme_umopa_4way_i16i16_to_i64(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[2]x[2]xi64> {
+func.func @arm_sme_umopa_4way_i16i16_to_i64(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) {
   %result = arm_sme.umopa_4way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
-  return %result : vector<[2]x[2]xi64>
+  "test.some_use"(%result) : (vector<[2]x[2]xi64>) -> ()
+  return
 }
 
 //===----------------------------------------------------------------------===//
@@ -791,18 +799,20 @@ func.func @arm_sme_umopa_4way_i16i16_to_i64(%vecA: vector<[8]xi16>, %vecB: vecto
 
 // CHECK-LABEL: arm_sme_umops_4way_i8i8_to_i32
 // CHECK: "arm_sme.intr.umops.wide"({{.*}}) <{tile_id = 0 : i32}> : (vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
-func.func @arm_sme_umops_4way_i8i8_to_i32(%vecA: vector<[16]xi8>, %vecB: vector<[16]xi8>) -> vector<[4]x[4]xi32> {
+func.func @arm_sme_umops_4way_i8i8_to_i32(%vecA: vector<[16]xi8>, %vecB: vector<[16]xi8>) {
   %result = arm_sme.umops_4way %vecA, %vecB : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
-  return %result : vector<[4]x[4]xi32>
+  "test.some_use"(%result) : (vector<[4]x[4]xi32>) -> ()
+  return
 }
 
 // -----
 
 // CHECK-LABEL: arm_sme_umops_4way_i16i16_to_i64
 // CHECK: "arm_sme.intr.umops.wide"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
-func.func @arm_sme_umops_4way_i16i16_to_i64(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[2]x[2]xi64> {
+func.func @arm_sme_umops_4way_i16i16_to_i64(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) {
   %result = arm_sme.umops_4way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
-  return %result : vector<[2]x[2]xi64>
+  "test.some_use"(%result) : (vector<[2]x[2]xi64>) -> ()
+  return
 }
 
 //===----------------------------------------------------------------------===//
@@ -813,18 +823,20 @@ func.func @arm_sme_umops_4way_i16i16_to_i64(%vecA: vector<[8]xi16>, %vecB: vecto
 
 // CHECK-LABEL: arm_sme_sumopa_4way_i8i8_to_i32
 // CHECK: "arm_sme.intr.sumopa.wide"({{.*}}) <{tile_id = 0 : i32}> : (vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
-func.func @arm_sme_sumopa_4way_i8i8_to_i32(%vecA: vector<[16]xi8>, %vecB: vector<[16]xi8>) -> vector<[4]x[4]xi32> {
+func.func @arm_sme_sumopa_4way_i8i8_to_i32(%vecA: vector<[16]xi8>, %vecB: vector<[16]xi8>) {
   %result = arm_sme.sumopa_4way %vecA, %vecB : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
-  return %result : vector<[4]x[4]xi32>
+  "test.some_use"(%result) : (vector<[4]x[4]xi32>) -> ()
+  return
 }
 
 // -----
 
 // CHECK-LABEL: arm_sme_sumopa_4way_i16i16_to_i64
 // CHECK: "arm_sme.intr.sumopa.wide"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
-func.func @arm_sme_sumopa_4way_i16i16_to_i64(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[2]x[2]xi64> {
+func.func @arm_sme_sumopa_4way_i16i16_to_i64(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) {
   %result = arm_sme.sumopa_4way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
-  return %result : vector<[2]x[2]xi64>
+  "test.some_use"(%result) : (vector<[2]x[2]xi64>) -> ()
+  return
 }
 
 //===----------------------------------------------------------------------===//
@@ -835,18 +847,20 @@ func.func @arm_sme_sumopa_4way_i16i16_to_i64(%vecA: vector<[8]xi16>, %vecB: vect
 
 // CHECK-LABEL: arm_sme_sumops_4way_i8i8_to_i32
 // CHECK: "arm_sme.intr.sumops.wide"({{.*}}) <{tile_id = 0 : i32}> : (vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
-func.func @arm_sme_sumops_4way_i8i8_to_i32(%vecA: vector<[16]xi8>, %vecB: vector<[16]xi8>) -> vector<[4]x[4]xi32> {
+func.func @arm_sme_sumops_4way_i8i8_to_i32(%vecA: vector<[16]xi8>, %vecB: vector<[16]xi8>) {
   %result = arm_sme.sumops_4way %vecA, %vecB : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
-  return %result : vector<[4]x[4]xi32>
+  "test.some_use"(%result) : (vector<[4]x[4]xi32>) -> ()
+  return
 }
 
 // -----
 
 // CHECK-LABEL: arm_sme_sumops_4way_i16i16_to_i64
 // CHECK: "arm_sme.intr.sumops.wide"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
-func.func @arm_sme_sumops_4way_i16i16_to_i64(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[2]x[2]xi64> {
+func.func @arm_sme_sumops_4way_i16i16_to_i64(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) {
   %result = arm_sme.sumops_4way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
-  return %result : vector<[2]x[2]xi64>
+  "test.some_use"(%result) : (vector<[2]x[2]xi64>) -> ()
+  return
 }
 
 //===----------------------------------------------------------------------===//
@@ -857,18 +871,20 @@ func.func @arm_sme_sumops_4way_i16i16_to_i64(%vecA: vector<[8]xi16>, %vecB: vect
 
 // CHECK-LABEL: arm_sme_usmopa_4way_i8i8_to_i32
 // CHECK: "arm_sme.intr.usmopa.wide"({{.*}}) <{tile_id = 0 : i32}> : (vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
-func.func @arm_sme_usmopa_4way_i8i8_to_i32(%vecA: vector<[16]xi8>, %vecB: vector<[16]xi8>) -> vector<[4]x[4]xi32> {
-  %reuslt = arm_sme.usmopa_4way %vecA, %vecB : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
-  return %reuslt : vector<[4]x[4]xi32>
+func.func @arm_sme_usmopa_4way_i8i8_to_i32(%vecA: vector<[16]xi8>, %vecB: vector<[16]xi8>) {
+  %result = arm_sme.usmopa_4way %vecA, %vecB : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+  "test.some_use"(%result) : (vector<[4]x[4]xi32>) -> ()
+  return
 }
 
 // -----
 
 // CHECK-LABEL: arm_sme_usmopa_4way_i16i16_to_i64
 // CHECK: "arm_sme.intr.usmopa.wide"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
-func.func @arm_sme_usmopa_4way_i16i16_to_i64(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[2]x[2]xi64> {
-  %reuslt = arm_sme.usmopa_4way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
-  return %reuslt : vector<[2]x[2]xi64>
+func.func @arm_sme_usmopa_4way_i16i16_to_i64(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) {
+  %result = arm_sme.usmopa_4way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+  "test.some_use"(%result) : (vector<[2]x[2]xi64>) -> ()
+  return
 }
 
 //===----------------------------------------------------------------------===//
@@ -879,16 +895,45 @@ func.func @arm_sme_usmopa_4way_i16i16_to_i64(%vecA: vector<[8]xi16>, %vecB: vect
 
 // CHECK-LABEL: arm_sme_usmops_4way_i8i8_to_i32
 // CHECK: "arm_sme.intr.usmops.wide"({{.*}}) <{tile_id = 0 : i32}> : (vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
-func.func @arm_sme_usmops_4way_i8i8_to_i32(%vecA: vector<[16]xi8>, %vecB: vector<[16]xi8>) -> vector<[4]x[4]xi32> {
-  %reuslt = arm_sme.usmops_4way %vecA, %vecB : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
-  return %reuslt : vector<[4]x[4]xi32>
+func.func @arm_sme_usmops_4way_i8i8_to_i32(%vecA: vector<[16]xi8>, %vecB: vector<[16]xi8>) {
+  %result = arm_sme.usmops_4way %vecA, %vecB : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+  "test.some_use"(%result) : (vector<[4]x[4]xi32>) -> ()
+  return
 }
 
 // -----
 
 // CHECK-LABEL: arm_sme_usmops_4way_i16i16_to_i64
 // CHECK: "arm_sme.intr.usmops.wide"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
-func.func @arm_sme_usmops_4way_i16i16_to_i64(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[2]x[2]xi64> {
-  %reuslt = arm_sme.usmops_4way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
-  return %reuslt : vector<[2]x[2]xi64>
+func.func @arm_sme_usmops_4way_i16i16_to_i64(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) {
+  %result = arm_sme.usmops_4way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+  "test.some_use"(%result) : (vector<[2]x[2]xi64>) -> ()
+  return
+}
+
+//===----------------------------------------------------------------------===//
+// Operations on SME tile types allowed after conversion
+//===----------------------------------------------------------------------===//
+
+// -----
+
+// The following operations on SME tile types are permitted after conversion:
+//
+//   - arm_sme.copy_tile
+//   - arm_sme.get_tile
+//   - cf.br
+//   - any unregistered op such as 'test.some_use'.
+//
+// This test verifies that these are accepted. Any other operation on SME tile
+// types causes conversion to fail; this is tested in 'unsupported.mlir'.
+
+func.func @ops_on_tiles_legal_post_conversion(%ub : index) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %tile = arm_sme.get_tile : vector<[4]x[4]xf32>
+  %copy = arm_sme.copy_tile %tile : vector<[4]x[4]xf32>
+  cf.br ^bb1(%copy : vector<[4]x[4]xf32>)
+^bb1(%x : vector<[4]x[4]xf32>):
+  "test.some_use"(%x) : (vector<[4]x[4]xf32>) -> ()
+  return
 }
diff --git a/mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir b/mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir
index 2c3868d7f25cb..91c1b92b01224 100644
--- a/mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir
+++ b/mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir
@@ -141,7 +141,7 @@ func.func @very_excessive_spills(%useAllTiles : vector<[16]x[16]xi8>, %memref: m
   // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
   %loadSlice = arm_sme.load_tile_slice %memref[%c0, %c0], %mask, %tile, %c0 : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
   "test.some_use"(%useAllTiles) : (vector<[16]x[16]xi8>) -> ()
-  return %loadSlice : vector<[4]x[4]xf32>
+  "test.some_use"(%loadSlice) : (vector<[4]x[4]xf32>) -> ()
 }
 // AFTER-TILE-ALLOC-LABEL: @very_excessive_spills
 //      AFTER-TILE-ALLOC: arm_sme.load_tile_slice
diff --git a/mlir/test/Conversion/ArmSMEToLLVM/unsupported.mlir b/mlir/test/Conversion/ArmSMEToLLVM/unsupported.mlir
index a62ca080ab8d9..b2c41f284fb86 100644
--- a/mlir/test/Conversion/ArmSMEToLLVM/unsupported.mlir
+++ b/mlir/test/Conversion/ArmSMEToLLVM/unsupported.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --pass-pipeline="builtin.module(func.func(convert-arm-sme-to-llvm))" -verify-diagnostics
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(func.func(convert-arm-sme-to-llvm))" -verify-diagnostics -split-input-file
 
 //===----------------------------------------------------------------------===//
 // arm_sme.outerproduct
@@ -6,9 +6,21 @@
 
 func.func @arm_sme_outerproduct_unsupported_type(%lhs : vector<[16]xi8>, %rhs : vector<[16]xi8>) {
   %acc = arm_sme.get_tile : vector<[16]x[16]xi8>
+  // expected-error at below {{unexpected operation with SME tile type after conversion to LLVM}}
   // expected-error at +2 {{failed to legalize operation 'arm_sme.outerproduct'}}
   // expected-error at +1 {{unsupported type}}
   %0 = arm_sme.outerproduct %lhs, %rhs  acc(%acc) : vector<[16]xi8>, vector<[16]xi8>
   "test.some_use"(%0) : (vector<[16]x[16]xi8>) -> ()
 }
 
+//===----------------------------------------------------------------------===//
+// Unsupported operations on SME tile types
+//===----------------------------------------------------------------------===//
+
+// -----
+
+func.func @unsupported_arith_op(%a : vector<[4]x[4]xf32>, %b : vector<[4]x[4]xf32>) {
+  // expected-error at below {{unexpected operation with SME tile type after conversion to LLVM}}
+  %0 = arith.addf %a, %b : vector<[4]x[4]xf32>
+  "test.some_use"(%0) : (vector<[4]x[4]xf32>) -> ()
+}

>From 0830d2feea76e63ae50869a9004c357753699417 Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Wed, 15 May 2024 13:48:21 +0000
Subject: [PATCH 2/2] factor lambda out

---
 mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp | 12 +++++-------
 1 file changed, 5 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
index 32c18c84d04b1..7e77067936743 100644
--- a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
+++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
@@ -886,13 +886,11 @@ struct ConvertArmSMEToLLVMPass
       if (isa<arm_sme::CopyTileOp, arm_sme::GetTileOp, cf::BranchOp>(op) ||
           !op->isRegistered())
         return;
-      if (llvm::any_of(op->getResultTypes(),
-                       [](Type type) {
-                         return arm_sme::isValidSMETileVectorType(type);
-                       }) ||
-          llvm::any_of(op->getOperandTypes(), [](Type type) {
-            return arm_sme::isValidSMETileVectorType(type);
-          })) {
+      auto isSMETileType = [](Type type) {
+        return arm_sme::isValidSMETileVectorType(type);
+      };
+      if (llvm::any_of(op->getResultTypes(), isSMETileType) ||
+          llvm::any_of(op->getOperandTypes(), isSMETileType)) {
         op->emitOpError("unexpected operation with SME tile type after "
                         "conversion to LLVM");
         signalPassFailure();



More information about the Mlir-commits mailing list