[Mlir-commits] [mlir] [mlir][gpu] Add pass for emulating unsupported types. (PR #138087)

Md Abdullah Shahneous Bari llvmlistbot at llvm.org
Mon May 5 16:22:24 PDT 2025


================
@@ -0,0 +1,141 @@
+// RUN: mlir-opt -verify-diagnostics -imitate-unsupported-types="source-types=bf16 target-types=i16" --canonicalize -split-input-file %s | FileCheck %s
+
+// CHECK: module @builtin_module
+module @builtin_module {
+  // CHECK: gpu.module @gpu_func_module {
+  gpu.module @gpu_func_module attributes{} {
+    // CHECK-LABEL: gpu.func @arith_and_vector_ops
+    // CHECK-SAME: (%[[ARG0:.*]]: memref<10x10xi16>, %[[ARG1:.*]]: memref<10x10xf32>, %[[ARG2:.*]]: vector<10x10xi16>, %[[ARG3:.*]]: memref<10x10xi16>, %[[ARG4:.*]]: vector<10x10xi16>) kernel
+    gpu.func @arith_and_vector_ops(%arg0: memref<10x10xbf16>, %arg1: memref<10x10xf32>, %arg2: vector<10x10xbf16>, %arg3: memref<10x10xi16>, %arg4: vector<10x10xi16>) kernel attributes {} {
+
+      %c0 = arith.constant 0 : index
+
+      // CHECK: %[[ARG2_CAST:.*]] = arith.bitcast %[[ARG2]] : vector<10x10xi16> to vector<10x10xbf16>
+      // CHECK: %[[LOAD1:.*]] = vector.load %[[ARG0]][%c0, %c0] : memref<10x10xi16>, vector<10x10xi16>
+      // CHECK: %[[BITCAST1:.*]] = arith.bitcast %[[LOAD1]] : vector<10x10xi16> to vector<10x10xbf16>
+      %2 = vector.load %arg0[%c0, %c0] : memref<10x10xbf16>, vector<10x10xbf16>
+
+      // CHECK: %[[ADDF:.*]] = arith.addf %[[BITCAST1]], %[[ARG2_CAST]] : vector<10x10xbf16>
+      %add = arith.addf %2, %arg2 : vector<10x10xbf16>
+
+      // CHECK: %[[EXTF1:.*]] = arith.extf %[[BITCAST1]] : vector<10x10xbf16> to vector<10x10xf32>
+      %3 = arith.extf %2 : vector<10x10xbf16> to vector<10x10xf32>
+
+      // CHECK: %[[EXTF2:.*]] = arith.extf %[[ADDF]] : vector<10x10xbf16> to vector<10x10xf32>
+      %4 = arith.extf %add : vector<10x10xbf16> to vector<10x10xf32>
+
+      // CHECK: %[[ADDF2:.*]] = arith.addf %[[EXTF1]], %[[EXTF2]] : vector<10x10xf32>
+      %5 = arith.addf %3, %4 : vector<10x10xf32>
+
+      // CHECK: %[[TRUNCF:.*]] = arith.truncf %[[ADDF2]] : vector<10x10xf32> to vector<10x10xbf16>
+      %6 = arith.truncf %5 : vector<10x10xf32> to vector<10x10xbf16>
+
+      // CHECK: %[[TRUNCF_CAST:.*]] = arith.bitcast %[[TRUNCF]] : vector<10x10xbf16> to vector<10x10xi16>
+      // CHECK: vector.store %[[TRUNCF_CAST]], %[[ARG0]][%c0, %c0] : memref<10x10xi16>, vector<10x10xi16>
+      vector.store %6, %arg0[%c0, %c0] : memref<10x10xbf16>, vector<10x10xbf16>
+
+      // CHECK: %[[LOAD2:.*]] = vector.load %[[ARG3]][%c0, %c0] : memref<10x10xi16>, vector<10x10xi16>
+      %7 = vector.load %arg3[%c0, %c0] : memref<10x10xi16>, vector<10x10xi16>
+
+      // CHECK: %[[ADDI:.*]] = arith.addi %[[LOAD2]], %[[ARG4]] : vector<10x10xi16>
+      %8 = arith.addi %7, %arg4 : vector<10x10xi16>
+
+      // CHECK: vector.store %[[ADDI]], %[[ARG3]][%c0, %c0] : memref<10x10xi16>, vector<10x10xi16>
+      vector.store %8, %arg3[%c0, %c0] : memref<10x10xi16>, vector<10x10xi16>
+
+      gpu.return
+    }
+  }
+}
+
+// -----
+
+
+// CHECK: module @caller_callee_launch_func_module attributes {gpu.container_module}
+module @caller_callee_launch_func_module attributes {gpu.container_module} {
+
+  // CHECK: gpu.module @caller_callee_gpu_module {
+  gpu.module @caller_callee_gpu_module attributes{} {
+
+    // CHECK: gpu.func @caller_func(%[[ARG0:.*]]: memref<10x10xi16>, %[[ARG1:.*]]: vector<10x10xi16>) kernel {
+    gpu.func @caller_func(%arg0: memref<10x10xbf16>, %arg1: vector<10x10xbf16>) kernel attributes {} {
+      %c0 = arith.constant 0 : index
+
+      // CHECK: %[[CALL_RET:.*]] = func.call @callee_constant_return() : () -> vector<10x10xi16>
+      %func_result = func.call @callee_constant_return() : () -> vector<10x10xbf16>
+
+      // CHECK: vector.store %[[CALL_RET]], %[[ARG0]][%c0, %c0] : memref<10x10xi16>, vector<10x10xi16>
+      vector.store %func_result, %arg0[%c0, %c0] : memref<10x10xbf16>, vector<10x10xbf16>
+
+      // CHECK: func.call @callee_func(%[[CALL_RET]]) : (vector<10x10xi16>) -> ()
+      func.call @callee_func(%func_result) : (vector<10x10xbf16>) -> ()
+
+      gpu.return
+    }
+
+    // CHECK: func.func @callee_constant_return() -> vector<10x10xi16> {
+    func.func @callee_constant_return() -> vector<10x10xbf16> {
+      // CHECK: arith.constant dense<16128> : vector<10x10xi16>
+      %dense_const = arith.constant dense<5.000000e-01> : vector<10x10xbf16>
+      func.return %dense_const : vector<10x10xbf16>
+    }
+
+    // CHECK: func.func @callee_func(%[[ARG:.*]]: vector<10x10xi16>) {
+    func.func @callee_func(%arg0: vector<10x10xbf16>) {
+      return
+    }
+  }
+
+  // CHECK: func.func @gpu_launch_func(%[[ARG0:.*]]: memref<10x10xbf16>, %[[ARG1:.*]]: vector<10x10xbf16>) {
+  func.func @gpu_launch_func(%arg0: memref<10x10xbf16>, %arg1: vector<10x10xbf16>) {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    // CHECK: arith.constant dense<16128> : vector<10x10xi16>
+    %dense_const = arith.constant dense<5.000000e-01> : vector<10x10xbf16>
+    // CHECK: arith.constant dense<6.015630e-01> : vector<10x10xbf16>
+    %dense_const_2 = arith.constant dense<6.000000e-01> : vector<10x10xbf16>
+
+    // CHECK: %[[ALLOC:.*]] = gpu.alloc () : memref<200xi8>
+    %alloc = gpu.alloc () : memref<10x10xbf16>
----------------
mshahneo wrote:

I removed the need for alloc linearization; the pass now uses arith.bitcast to cast memref types.

https://github.com/llvm/llvm-project/pull/138087


More information about the Mlir-commits mailing list