[Mlir-commits] [mlir] [MLIR] [AMX] Fix strides used by AMX lowering for tile loads and stores. (PR #113476)
Ilya Enkovich
llvmlistbot at llvm.org
Mon Oct 28 10:22:50 PDT 2024
================
@@ -55,21 +55,25 @@ Value getStride(ConversionPatternRewriter &rewriter,
                 const LLVMTypeConverter &typeConverter, MemRefType mType,
                 Value base, Location loc) {
   assert(mType.getRank() >= 2);
-  int64_t last = mType.getRank() - 1;
+  int64_t preLast = mType.getRank() - 2;
   Type llvmInt64Type = IntegerType::get(&typeConverter.getContext(), 64);
   unsigned width = mType.getElementType().getIntOrFloatBitWidth();
   assert(llvm::isPowerOf2_64(width) && width >= 8);
   unsigned bytes = width >> 3;
-  if (mType.isDynamicDim(last)) {
-    // Dynamic size needs code to compute the stride at runtime.
+  int64_t offset;
+  SmallVector<int64_t, 4> strides;
+  getStridesAndOffset(mType, strides, offset);
----------------
ienkovich wrote:
A unit stride in the last dimension doesn't guarantee that `strides[preLast]` equals `getDimSize(last)`, because strides can be specified independently of the shape, e.g. `memref<16x32xbf16, strided<[64, 1]>>`, as in the added test. I encountered such memrefs in our TritonCPU lowerings.
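
To make the difference concrete, here is a minimal standalone sketch of the two ways of computing the row stride in bytes. It is not the code from the patch: `MemRefDesc`, `rowStrideBytesFromShape`, `rowStrideBytesFromStrides`, and `stridedExample` are hypothetical names introduced only for illustration, and the struct is a simplified stand-in for a statically shaped `MemRefType` with a strided layout.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical, simplified stand-in for a static memref type:
// shape and strides are in elements, element width is in bits.
struct MemRefDesc {
  std::vector<int64_t> shape;
  std::vector<int64_t> strides;
  unsigned elementBitWidth;
};

// Assumes rows are densely packed: row stride == size of the last dimension.
int64_t rowStrideBytesFromShape(const MemRefDesc &m) {
  assert(m.shape.size() >= 2);
  return m.shape.back() * (m.elementBitWidth >> 3);
}

// Reads the stride of the second-to-last dimension from the layout itself.
int64_t rowStrideBytesFromStrides(const MemRefDesc &m) {
  assert(m.strides.size() >= 2);
  return m.strides[m.strides.size() - 2] * (m.elementBitWidth >> 3);
}

// Models memref<16x32xbf16, strided<[64, 1]>>: 32 elements per row,
// but consecutive rows are 64 elements apart.
MemRefDesc stridedExample() { return {{16, 32}, {64, 1}, 16}; }
```

For this example the shape-based computation gives 32 * 2 = 64 bytes, while the layout says rows are 64 * 2 = 128 bytes apart, so a tile load using the shape-based value would read the wrong rows.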
https://github.com/llvm/llvm-project/pull/113476