[Mlir-commits] [mlir] [MLIR] [AMX] Fix strides used by AMX lowering for tile loads and stores. (PR #113476)

Renato Golin llvmlistbot at llvm.org
Mon Oct 28 02:38:01 PDT 2024


================
@@ -55,21 +55,25 @@ Value getStride(ConversionPatternRewriter &rewriter,
                 const LLVMTypeConverter &typeConverter, MemRefType mType,
                 Value base, Location loc) {
   assert(mType.getRank() >= 2);
-  int64_t last = mType.getRank() - 1;
+  int64_t preLast = mType.getRank() - 2;
   Type llvmInt64Type = IntegerType::get(&typeConverter.getContext(), 64);
   unsigned width = mType.getElementType().getIntOrFloatBitWidth();
   assert(llvm::isPowerOf2_64(width) && width >= 8);
   unsigned bytes = width >> 3;
-  if (mType.isDynamicDim(last)) {
-    // Dynamic size needs code to compute the stride at runtime.
+  int64_t offset;
+  SmallVector<int64_t, 4> strides;
+  getStridesAndOffset(mType, strides, offset);
----------------
rengolin wrote:

If you're assuming above that `stride(last) == 1`, then `strides[preLast] == mType.getDimSize(last)`, so maybe we don't need to call `getStridesAndOffset` at all: just check `isDynamicDim(preLast)` here and keep `mType.getDimSize(last)` below.
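
To illustrate the relation this suggestion relies on, here is a minimal standalone sketch in plain C++ (no MLIR dependency). The helpers `contiguousStrides` and `rowStrideInBytes` are hypothetical names introduced only for this example; they are not part of the patch or of any MLIR API. The sketch assumes a densely packed, row-major layout. Under that assumption the stride of the second-to-last dimension equals the size of the last dimension. For a layout that is not densely packed (for example, a view with padded rows), the two values can differ, and the actual stride would be needed.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper (not an MLIR API): strides, in elements, of a
// densely packed row-major shape. The innermost stride is always 1.
std::vector<int64_t> contiguousStrides(const std::vector<int64_t> &shape) {
  std::vector<int64_t> strides(shape.size(), 1);
  for (int64_t i = static_cast<int64_t>(shape.size()) - 2; i >= 0; --i)
    strides[i] = strides[i + 1] * shape[i + 1];
  return strides;
}

// Hypothetical helper: distance in bytes between two consecutive rows,
// mirroring the `bytes = width >> 3` scaling in the hunk above.
int64_t rowStrideInBytes(const std::vector<int64_t> &shape,
                         unsigned elementBitWidth) {
  assert(shape.size() >= 2 && "need at least a 2-D shape");
  assert(elementBitWidth >= 8 && "sub-byte element types are not handled");
  const int64_t preLast = static_cast<int64_t>(shape.size()) - 2;
  const int64_t last = preLast + 1;
  const std::vector<int64_t> strides = contiguousStrides(shape);
  // Holds only because the layout is assumed to be densely packed.
  assert(strides[preLast] == shape[last]);
  return strides[preLast] * static_cast<int64_t>(elementBitWidth >> 3);
}
```

For a 16x32 shape of 32-bit elements, `rowStrideInBytes` multiplies the row length of 32 elements by 4 bytes per element.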

https://github.com/llvm/llvm-project/pull/113476

