[Mlir-commits] [mlir] [MLIR] [AMX] Fix strides used by AMX lowering for tile loads and stores. (PR #113476)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Oct 23 10:15:50 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Ilya Enkovich (ienkovich)

<details>
<summary>Changes</summary>

Lowering of tile loads and stores uses the size of the last memref dimension as the stride and ignores the actual strides specified in the memref. This causes unexpected results when the actual stride doesn't match the last dimension size.

---
Full diff: https://github.com/llvm/llvm-project/pull/113476.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp (+10-6) 
- (modified) mlir/test/Dialect/AMX/legalize-for-llvm.mlir (+28) 


``````````diff
diff --git a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
index c8cfcc3d945bec..14f28c436a1ff6 100644
--- a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
@@ -55,21 +55,25 @@ Value getStride(ConversionPatternRewriter &rewriter,
                 const LLVMTypeConverter &typeConverter, MemRefType mType,
                 Value base, Location loc) {
   assert(mType.getRank() >= 2);
-  int64_t last = mType.getRank() - 1;
+  int64_t preLast = mType.getRank() - 2;
   Type llvmInt64Type = IntegerType::get(&typeConverter.getContext(), 64);
   unsigned width = mType.getElementType().getIntOrFloatBitWidth();
   assert(llvm::isPowerOf2_64(width) && width >= 8);
   unsigned bytes = width >> 3;
-  if (mType.isDynamicDim(last)) {
-    // Dynamic size needs code to compute the stride at runtime.
+  int64_t offset;
+  SmallVector<int64_t, 4> strides;
+  getStridesAndOffset(mType, strides, offset);
+  if (strides[preLast] == ShapedType::kDynamic) {
+    // Dynamic stride needs code to compute the stride at runtime.
     MemRefDescriptor memrefDescriptor(base);
     auto attr = rewriter.getI64IntegerAttr(bytes);
     Value scale = rewriter.create<LLVM::ConstantOp>(loc, llvmInt64Type, attr);
     return rewriter.create<LLVM::MulOp>(
-        loc, llvmInt64Type, scale, memrefDescriptor.size(rewriter, loc, last));
+        loc, llvmInt64Type, scale,
+        memrefDescriptor.stride(rewriter, loc, preLast));
   }
-  // Use direct constant for static size.
-  auto attr = rewriter.getI64IntegerAttr(mType.getDimSize(last) * bytes);
+  // Use direct constant for static stride.
+  auto attr = rewriter.getI64IntegerAttr(strides[preLast] * bytes);
   return rewriter.create<LLVM::ConstantOp>(loc, llvmInt64Type, attr);
 }
 
diff --git a/mlir/test/Dialect/AMX/legalize-for-llvm.mlir b/mlir/test/Dialect/AMX/legalize-for-llvm.mlir
index 992203153939fe..3cacbd0044f825 100644
--- a/mlir/test/Dialect/AMX/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/AMX/legalize-for-llvm.mlir
@@ -43,3 +43,31 @@ func.func @mulf(%arg0: memref<?x?xbf16>, %arg1: memref<?x?xf32>) {
   amx.tile_store %arg1[%0, %0], %4 : memref<?x?xf32>, vector<16x16xf32>
   return
 }
+
+// CHECK-LABEL: strides(
+// CHECK: %[[CST_64_1:.+]] = llvm.mlir.constant(64 : i64) : i64
+// CHECK: "amx.tileloadd64"(%{{.+}}, %{{.+}}, %{{.+}}, %[[CST_64_1]]
+// CHECK: %[[CST_128_1:.+]] = llvm.mlir.constant(128 : i64) : i64
+// CHECK: "amx.tileloadd64"(%{{.+}}, %{{.+}}, %{{.+}}, %[[CST_128_1]]
+// CHECK: llvm.mlir.constant(2 : i64) : i64
+// CHECK: llvm.extractvalue %{{.+}}[4, 0]
+// CHECK: %[[STRIDE_1:.+]] = llvm.mul
+// CHECK: "amx.tileloadd64"(%{{.+}}, %{{.+}}, %{{.+}}, %[[STRIDE_1]]
+// CHECK: %[[CST_64_2:.+]] = llvm.mlir.constant(64 : i64) : i64
+// CHECK: "amx.tilestored64"(%{{.+}}, %{{.+}}, %{{.+}}, %[[CST_64_2]]
+// CHECK: %[[CST_128_2:.+]] = llvm.mlir.constant(128 : i64) : i64
+// CHECK: "amx.tilestored64"(%{{.+}}, %{{.+}}, %{{.+}}, %[[CST_128_2]]
+// CHECK: llvm.mlir.constant(2 : i64) : i64
+// CHECK: llvm.extractvalue %{{.+}}[4, 0]
+// CHECK: %[[STRIDE_2:.+]] = llvm.mul
+// CHECK: "amx.tilestored64"(%{{.+}}, %{{.+}}, %{{.+}}, %[[STRIDE_2]]
+func.func @strides(%arg0: memref<16x32xbf16>, %arg1: memref<16x32xbf16, strided<[64, 1]>>, %arg2: memref<16x32xbf16, strided<[?, 1]>>) {
+  %0 = arith.constant 0 : index
+  %1 = amx.tile_load %arg0[%0, %0] : memref<16x32xbf16> into vector<16x32xbf16>
+  %2 = amx.tile_load %arg1[%0, %0] : memref<16x32xbf16, strided<[64, 1]>> into vector<16x32xbf16>
+  %3 = amx.tile_load %arg2[%0, %0] : memref<16x32xbf16, strided<[?, 1]>> into vector<16x32xbf16>
+  amx.tile_store %arg0[%0, %0], %3 : memref<16x32xbf16>, vector<16x32xbf16>
+  amx.tile_store %arg1[%0, %0], %1 : memref<16x32xbf16, strided<[64, 1]>>, vector<16x32xbf16>
+  amx.tile_store %arg2[%0, %0], %2 : memref<16x32xbf16, strided<[?, 1]>>, vector<16x32xbf16>
+  return
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/113476


More information about the Mlir-commits mailing list