[llvm] [mlir] [AMDGPU][MLIR]Add shmem-optimization as an op using transform dialect (PR #81550)

Oleksandr Alex Zinenko via llvm-commits llvm-commits at lists.llvm.org
Wed Feb 14 01:49:55 PST 2024


================
@@ -0,0 +1,64 @@
+// RUN: mlir-opt  %s -transform-interpreter  | FileCheck %s
+
+  // CHECK: @optimize_shmem([[arg0:%.+]]: memref<{{.*}}>, [[readRow:%.+]]: index, [[readCol:%.+]]: index, [[writeRow:%.+]]: index, [[writeCol:%.+]]: index, [[fragRow:%.+]]: index, [[fragCol:%.+]]: index, [[fragColPerm:%.+]]: index, [[stRow:%.+]]: index, [[stCol:%.+]]: index)
+  func.func @optimize_shmem(%arg0: memref<4096x4096xf16>, 
+                    %readRow: index, %readCol: index,
+                    %writeRow: index, %writeCol: index,
+                    %fragRow: index, %fragCol: index, 
+                    %fragColPerm: index,
+                    %stRow: index, %stCol: index) {
+    // CHECK:    %[[cst:.+]] = arith.constant 0.000000e+00 : f16                  
+    %cst = arith.constant 0.000000e+00 : f16
+
+    // CHECK: [[shmA:%.+]] = memref.alloc
+    // CHECK: [[shmB:%.+]] = memref.alloc
+    %shmA = memref.alloc() {alignment = 64 : i64} : memref<128x32xf16, 3>
+    %shmB = memref.alloc() {alignment = 64 : i64} : memref<256x32xf16, 3>
+
+    // CHECK: %[[D0:.+]] = vector.transfer_read [[arg0:%.+]][[[readRow:%.+]], [[readCol:%.+]]], [[cst:.+]] {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<1x8xf16>
+    %0 = vector.transfer_read %arg0[%readRow, %readCol], %cst {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<1x8xf16>
+    // CHECK: [[c7:%.+]] = arith.constant 7 : index                  
+    // CHECK: [[srcBits:%.+]] = arith.andi [[stRow:%.+]], [[c7]]       
+    // CHECK: [[c2:%.+]] = arith.constant 2 : index                 
+    // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]     
+    // CHECK: [[stColPerm:%.+]] = arith.xori [[stCol:%.+]], [[xorBits]]  
+    // CHECK: vector.transfer_write %[[D0:.+]], [[shmB]][[[writeRow:%.+]], [[writeCol:%.+]]] {in_bounds = [true, true]} : vector<1x8xf16>, memref<256x32xf16, 3>
----------------
ftynse wrote:

How much of this actually exercises the functionality added here rather than generic printing/parsing, such as the in_bounds attribute or the types? Excessive tests increase the maintenance burden; please reduce the tests to check only the intended functionality.
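A reduced version of these checks could keep only what the new op actually produces: the XOR-based index permutation and its use in the shared-memory access. Every other operand is matched with a regex. This is a sketch under stated assumptions, not the PR's actual test. It assumes the op replaces the column index of the write into shmB with the permuted value, and the variable names C7, SRC_BITS, XOR_BITS and COL_PERM are illustrative.

```mlir
// CHECK-LABEL: func.func @optimize_shmem
// Only the swizzle arithmetic emitted by the op is matched; operand types,
// in_bounds and the alloc/read lines are left to generic printer tests.
// CHECK:         %[[C7:.+]] = arith.constant 7 : index
// CHECK:         %[[SRC_BITS:.+]] = arith.andi %{{.+}}, %[[C7]]
// CHECK:         %[[C2:.+]] = arith.constant 2 : index
// CHECK:         %[[XOR_BITS:.+]] = arith.shli %[[SRC_BITS]], %[[C2]]
// CHECK:         %[[COL_PERM:.+]] = arith.xori %{{.+}}, %[[XOR_BITS]]
// Assumption: the permuted column replaces the original one in the write to shmB.
// CHECK:         vector.transfer_write %{{.+}}, %{{.+}}[%{{.+}}, %[[COL_PERM]]]
```

Reusing captured values (for example %[[SRC_BITS]] and %[[COL_PERM]]) instead of redefining them with a fresh pattern ties each line to the previous one. This makes the test fail if the permutation is computed but never used, while staying insensitive to unrelated printer changes.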

https://github.com/llvm/llvm-project/pull/81550
