[Mlir-commits] [mlir] [MLIR] [AMX] Fix strides used by AMX lowering for tile loads and stores. (PR #113476)
llvmlistbot at llvm.org
Wed Oct 23 10:15:50 PDT 2024
llvmbot wrote:
@llvm/pr-subscribers-mlir
Author: Ilya Enkovich (ienkovich)
Changes:
The lowering of tile loads and stores uses the size of the last memref dimension as the stride and ignores the actual strides specified in the memref. This produces incorrect results whenever the actual stride doesn't match the last dimension's size, e.g. for a tile that is a view into a wider buffer.
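To illustrate the problem, here is a minimal example (the function name is hypothetical; the layout mirrors the new test below). For bf16 (2 bytes per element), the old lowering would emit a byte stride of 32 * 2 = 64, while the correct row stride is 64 * 2 = 128:

```mlir
// A 16x32 bf16 tile viewed inside a wider buffer, so the row stride
// (64 elements) differs from the last dimension size (32 elements).
func.func @stride_mismatch(%src: memref<16x32xbf16, strided<[64, 1]>>) -> vector<16x32xbf16> {
  %c0 = arith.constant 0 : index
  // Old lowering: stride = 32 elements * 2 bytes = 64 bytes (wrong).
  // Fixed lowering: stride = 64 elements * 2 bytes = 128 bytes (correct).
  %t = amx.tile_load %src[%c0, %c0] : memref<16x32xbf16, strided<[64, 1]>> into vector<16x32xbf16>
  return %t : vector<16x32xbf16>
}
```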
---
Full diff: https://github.com/llvm/llvm-project/pull/113476.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp (+10-6)
- (modified) mlir/test/Dialect/AMX/legalize-for-llvm.mlir (+28)
``````````diff
diff --git a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
index c8cfcc3d945bec..14f28c436a1ff6 100644
--- a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
@@ -55,21 +55,25 @@ Value getStride(ConversionPatternRewriter &rewriter,
const LLVMTypeConverter &typeConverter, MemRefType mType,
Value base, Location loc) {
assert(mType.getRank() >= 2);
- int64_t last = mType.getRank() - 1;
+ int64_t preLast = mType.getRank() - 2;
Type llvmInt64Type = IntegerType::get(&typeConverter.getContext(), 64);
unsigned width = mType.getElementType().getIntOrFloatBitWidth();
assert(llvm::isPowerOf2_64(width) && width >= 8);
unsigned bytes = width >> 3;
- if (mType.isDynamicDim(last)) {
- // Dynamic size needs code to compute the stride at runtime.
+ int64_t offset;
+ SmallVector<int64_t, 4> strides;
+ getStridesAndOffset(mType, strides, offset);
+ if (strides[preLast] == ShapedType::kDynamic) {
+ // Dynamic stride needs code to compute the stride at runtime.
MemRefDescriptor memrefDescriptor(base);
auto attr = rewriter.getI64IntegerAttr(bytes);
Value scale = rewriter.create<LLVM::ConstantOp>(loc, llvmInt64Type, attr);
return rewriter.create<LLVM::MulOp>(
- loc, llvmInt64Type, scale, memrefDescriptor.size(rewriter, loc, last));
+ loc, llvmInt64Type, scale,
+ memrefDescriptor.stride(rewriter, loc, preLast));
}
- // Use direct constant for static size.
- auto attr = rewriter.getI64IntegerAttr(mType.getDimSize(last) * bytes);
+ // Use direct constant for static stride.
+ auto attr = rewriter.getI64IntegerAttr(strides[preLast] * bytes);
return rewriter.create<LLVM::ConstantOp>(loc, llvmInt64Type, attr);
}
diff --git a/mlir/test/Dialect/AMX/legalize-for-llvm.mlir b/mlir/test/Dialect/AMX/legalize-for-llvm.mlir
index 992203153939fe..3cacbd0044f825 100644
--- a/mlir/test/Dialect/AMX/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/AMX/legalize-for-llvm.mlir
@@ -43,3 +43,31 @@ func.func @mulf(%arg0: memref<?x?xbf16>, %arg1: memref<?x?xf32>) {
amx.tile_store %arg1[%0, %0], %4 : memref<?x?xf32>, vector<16x16xf32>
return
}
+
+// CHECK-LABEL: strides(
+// CHECK: %[[CST_64_1:.+]] = llvm.mlir.constant(64 : i64) : i64
+// CHECK: "amx.tileloadd64"(%{{.+}}, %{{.+}}, %{{.+}}, %[[CST_64_1]]
+// CHECK: %[[CST_128_1:.+]] = llvm.mlir.constant(128 : i64) : i64
+// CHECK: "amx.tileloadd64"(%{{.+}}, %{{.+}}, %{{.+}}, %[[CST_128_1]]
+// CHECK: llvm.mlir.constant(2 : i64) : i64
+// CHECK: llvm.extractvalue %{{.+}}[4, 0]
+// CHECK: %[[STRIDE_1:.+]] = llvm.mul
+// CHECK: "amx.tileloadd64"(%{{.+}}, %{{.+}}, %{{.+}}, %[[STRIDE_1]]
+// CHECK: %[[CST_64_2:.+]] = llvm.mlir.constant(64 : i64) : i64
+// CHECK: "amx.tilestored64"(%{{.+}}, %{{.+}}, %{{.+}}, %[[CST_64_2]]
+// CHECK: %[[CST_128_2:.+]] = llvm.mlir.constant(128 : i64) : i64
+// CHECK: "amx.tilestored64"(%{{.+}}, %{{.+}}, %{{.+}}, %[[CST_128_2]]
+// CHECK: llvm.mlir.constant(2 : i64) : i64
+// CHECK: llvm.extractvalue %{{.+}}[4, 0]
+// CHECK: %[[STRIDE_2:.+]] = llvm.mul
+// CHECK: "amx.tilestored64"(%{{.+}}, %{{.+}}, %{{.+}}, %[[STRIDE_2]]
+func.func @strides(%arg0: memref<16x32xbf16>, %arg1: memref<16x32xbf16, strided<[64, 1]>>, %arg2: memref<16x32xbf16, strided<[?, 1]>>) {
+ %0 = arith.constant 0 : index
+ %1 = amx.tile_load %arg0[%0, %0] : memref<16x32xbf16> into vector<16x32xbf16>
+ %2 = amx.tile_load %arg1[%0, %0] : memref<16x32xbf16, strided<[64, 1]>> into vector<16x32xbf16>
+ %3 = amx.tile_load %arg2[%0, %0] : memref<16x32xbf16, strided<[?, 1]>> into vector<16x32xbf16>
+ amx.tile_store %arg0[%0, %0], %3 : memref<16x32xbf16>, vector<16x32xbf16>
+ amx.tile_store %arg1[%0, %0], %1 : memref<16x32xbf16, strided<[64, 1]>>, vector<16x32xbf16>
+ amx.tile_store %arg2[%0, %0], %2 : memref<16x32xbf16, strided<[?, 1]>>, vector<16x32xbf16>
+ return
+}
``````````
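For the dynamic-stride case, the fix reads the stride out of the memref descriptor instead of the dimension size. A minimal sketch of the resulting IR for `memref<16x32xbf16, strided<[?, 1]>>`, assuming the standard rank-2 memref descriptor layout (allocated pointer, aligned pointer, offset, sizes array, strides array) where index [4, 0] is strides[0]; the SSA names are illustrative:

```mlir
%scale  = llvm.mlir.constant(2 : i64) : i64   // bf16 element size in bytes
// Field 4 of the descriptor is the strides array; [4, 0] is the row
// stride in elements.
%stride = llvm.extractvalue %desc[4, 0]
    : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// Byte stride passed to amx.tileloadd64 / amx.tilestored64.
%bytes  = llvm.mul %scale, %stride : i64
```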
https://github.com/llvm/llvm-project/pull/113476