[Mlir-commits] [mlir] [MLIR] [AMX] Fix strides used by AMX lowering for tile loads and stores. (PR #113476)

Ilya Enkovich llvmlistbot at llvm.org
Mon Oct 28 12:59:51 PDT 2024


https://github.com/ienkovich updated https://github.com/llvm/llvm-project/pull/113476

From 0c0f415e4b3b34768102dbea162bcb6e29f53de1 Mon Sep 17 00:00:00 2001
From: Ilya Enkovich <ilya.enkovich at intel.com>
Date: Wed, 23 Oct 2024 16:59:28 +0000
Subject: [PATCH 1/2] Fix strides used by AMX lowering.

Lowering of tile loads and stores uses the size of the last memref
dimension as the stride and ignores the actual strides specified in
the memref. This causes unexpected results when the actual stride
doesn't match the last dimension's size.
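
For example, for a memref with an explicit row stride, such as the one
exercised by the test added below (a sketch, not part of the patch):

  %t = amx.tile_load %arg1[%c0, %c0]
      : memref<16x32xbf16, strided<[64, 1]>> into vector<16x32xbf16>

the old lowering computed the stride as 32 * 2 = 64 bytes (last
dimension size times bf16 element size), while the correct row stride
is 64 * 2 = 128 bytes, taken from the memref's leading stride.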

Signed-off-by: Ilya Enkovich <ilya.enkovich at intel.com>
---
 .../AMX/Transforms/LegalizeForLLVMExport.cpp  | 16 +++++++----
 mlir/test/Dialect/AMX/legalize-for-llvm.mlir  | 28 +++++++++++++++++++
 2 files changed, 38 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
index c8cfcc3d945bec..14f28c436a1ff6 100644
--- a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
@@ -55,21 +55,25 @@ Value getStride(ConversionPatternRewriter &rewriter,
                 const LLVMTypeConverter &typeConverter, MemRefType mType,
                 Value base, Location loc) {
   assert(mType.getRank() >= 2);
-  int64_t last = mType.getRank() - 1;
+  int64_t preLast = mType.getRank() - 2;
   Type llvmInt64Type = IntegerType::get(&typeConverter.getContext(), 64);
   unsigned width = mType.getElementType().getIntOrFloatBitWidth();
   assert(llvm::isPowerOf2_64(width) && width >= 8);
   unsigned bytes = width >> 3;
-  if (mType.isDynamicDim(last)) {
-    // Dynamic size needs code to compute the stride at runtime.
+  int64_t offset;
+  SmallVector<int64_t, 4> strides;
+  getStridesAndOffset(mType, strides, offset);
+  if (strides[preLast] == ShapedType::kDynamic) {
+    // Dynamic stride needs code to compute the stride at runtime.
     MemRefDescriptor memrefDescriptor(base);
     auto attr = rewriter.getI64IntegerAttr(bytes);
     Value scale = rewriter.create<LLVM::ConstantOp>(loc, llvmInt64Type, attr);
     return rewriter.create<LLVM::MulOp>(
-        loc, llvmInt64Type, scale, memrefDescriptor.size(rewriter, loc, last));
+        loc, llvmInt64Type, scale,
+        memrefDescriptor.stride(rewriter, loc, preLast));
   }
-  // Use direct constant for static size.
-  auto attr = rewriter.getI64IntegerAttr(mType.getDimSize(last) * bytes);
+  // Use direct constant for static stride.
+  auto attr = rewriter.getI64IntegerAttr(strides[preLast] * bytes);
   return rewriter.create<LLVM::ConstantOp>(loc, llvmInt64Type, attr);
 }
 
diff --git a/mlir/test/Dialect/AMX/legalize-for-llvm.mlir b/mlir/test/Dialect/AMX/legalize-for-llvm.mlir
index 992203153939fe..3cacbd0044f825 100644
--- a/mlir/test/Dialect/AMX/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/AMX/legalize-for-llvm.mlir
@@ -43,3 +43,31 @@ func.func @mulf(%arg0: memref<?x?xbf16>, %arg1: memref<?x?xf32>) {
   amx.tile_store %arg1[%0, %0], %4 : memref<?x?xf32>, vector<16x16xf32>
   return
 }
+
+// CHECK-LABEL: strides(
+// CHECK: %[[CST_64_1:.+]] = llvm.mlir.constant(64 : i64) : i64
+// CHECK: "amx.tileloadd64"(%{{.+}}, %{{.+}}, %{{.+}}, %[[CST_64_1]]
+// CHECK: %[[CST_128_1:.+]] = llvm.mlir.constant(128 : i64) : i64
+// CHECK: "amx.tileloadd64"(%{{.+}}, %{{.+}}, %{{.+}}, %[[CST_128_1]]
+// CHECK: llvm.mlir.constant(2 : i64) : i64
+// CHECK: llvm.extractvalue %{{.+}}[4, 0]
+// CHECK: %[[STRIDE_1:.+]] = llvm.mul
+// CHECK: "amx.tileloadd64"(%{{.+}}, %{{.+}}, %{{.+}}, %[[STRIDE_1]]
+// CHECK: %[[CST_64_2:.+]] = llvm.mlir.constant(64 : i64) : i64
+// CHECK: "amx.tilestored64"(%{{.+}}, %{{.+}}, %{{.+}}, %[[CST_64_2]]
+// CHECK: %[[CST_128_2:.+]] = llvm.mlir.constant(128 : i64) : i64
+// CHECK: "amx.tilestored64"(%{{.+}}, %{{.+}}, %{{.+}}, %[[CST_128_2]]
+// CHECK: llvm.mlir.constant(2 : i64) : i64
+// CHECK: llvm.extractvalue %{{.+}}[4, 0]
+// CHECK: %[[STRIDE_2:.+]] = llvm.mul
+// CHECK: "amx.tilestored64"(%{{.+}}, %{{.+}}, %{{.+}}, %[[STRIDE_2]]
+func.func @strides(%arg0: memref<16x32xbf16>, %arg1: memref<16x32xbf16, strided<[64, 1]>>, %arg2: memref<16x32xbf16, strided<[?, 1]>>) {
+  %0 = arith.constant 0 : index
+  %1 = amx.tile_load %arg0[%0, %0] : memref<16x32xbf16> into vector<16x32xbf16>
+  %2 = amx.tile_load %arg1[%0, %0] : memref<16x32xbf16, strided<[64, 1]>> into vector<16x32xbf16>
+  %3 = amx.tile_load %arg2[%0, %0] : memref<16x32xbf16, strided<[?, 1]>> into vector<16x32xbf16>
+  amx.tile_store %arg0[%0, %0], %3 : memref<16x32xbf16>, vector<16x32xbf16>
+  amx.tile_store %arg1[%0, %0], %1 : memref<16x32xbf16, strided<[64, 1]>>, vector<16x32xbf16>
+  amx.tile_store %arg2[%0, %0], %2 : memref<16x32xbf16, strided<[?, 1]>>, vector<16x32xbf16>
+  return
+}

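For the dynamic-stride operand (%arg2 above), the lowering now reads the
leading stride from the memref descriptor instead of using the last
dimension size. The emitted LLVM dialect code looks roughly like the
following hand-written sketch (value names are hypothetical; it mirrors
the CHECK lines above):

  // bf16 element size in bytes.
  %elt_bytes = llvm.mlir.constant(2 : i64) : i64
  // Leading (row) stride from the rank-2 memref descriptor.
  %row_stride = llvm.extractvalue %desc[4, 0]
      : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
  // Byte stride passed to the AMX intrinsic.
  %stride = llvm.mul %elt_bytes, %row_stride : i64
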
From 41f66cd2c60ab3af59178b23f5038cd8f6fd9074 Mon Sep 17 00:00:00 2001
From: Ilya Enkovich <ilya.enkovich at intel.com>
Date: Mon, 28 Oct 2024 19:59:43 +0000
Subject: [PATCH 2/2] Fix review comments.

Signed-off-by: Ilya Enkovich <ilya.enkovich at intel.com>
---
 .../AMX/Transforms/LegalizeForLLVMExport.cpp  | 52 ++++++++-----------
 1 file changed, 23 insertions(+), 29 deletions(-)

diff --git a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
index 14f28c436a1ff6..46c7bfbf3ffcc2 100644
--- a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
@@ -37,24 +37,14 @@ std::pair<Value, Value> getTileSizes(ConversionPatternRewriter &rewriter,
       rewriter.create<LLVM::ConstantOp>(loc, llvmInt16Type, nattr));
 }
 
-/// Verifies if the stride matches proper tile access.
-LogicalResult verifyStride(MemRefType mType) {
-  if (mType.getRank() < 2)
-    return failure();
-  int64_t last = mType.getRank() - 1;
-  int64_t offset;
-  SmallVector<int64_t, 4> strides;
-  if (failed(getStridesAndOffset(mType, strides, offset)) || strides[last] != 1)
-    return failure();
-  return success();
-}
-
 /// Maps the 2-dim memref shape to the 64-bit stride. Note that the buffer
 /// shape may "envelop" the actual tile shape, and may be dynamically sized.
-Value getStride(ConversionPatternRewriter &rewriter,
-                const LLVMTypeConverter &typeConverter, MemRefType mType,
-                Value base, Location loc) {
-  assert(mType.getRank() >= 2);
+/// Returns failure if a proper stride couldn't be found.
+FailureOr<Value> getStride(ConversionPatternRewriter &rewriter,
+                           const LLVMTypeConverter &typeConverter,
+                           MemRefType mType, Value base, Location loc) {
+  if (mType.getRank() < 2)
+    return failure();
   int64_t preLast = mType.getRank() - 2;
   Type llvmInt64Type = IntegerType::get(&typeConverter.getContext(), 64);
   unsigned width = mType.getElementType().getIntOrFloatBitWidth();
@@ -62,19 +52,23 @@ Value getStride(ConversionPatternRewriter &rewriter,
   unsigned bytes = width >> 3;
   int64_t offset;
   SmallVector<int64_t, 4> strides;
-  getStridesAndOffset(mType, strides, offset);
+  if (failed(getStridesAndOffset(mType, strides, offset)) ||
+      strides.back() != 1)
+    return failure();
   if (strides[preLast] == ShapedType::kDynamic) {
     // Dynamic stride needs code to compute the stride at runtime.
     MemRefDescriptor memrefDescriptor(base);
     auto attr = rewriter.getI64IntegerAttr(bytes);
     Value scale = rewriter.create<LLVM::ConstantOp>(loc, llvmInt64Type, attr);
-    return rewriter.create<LLVM::MulOp>(
-        loc, llvmInt64Type, scale,
-        memrefDescriptor.stride(rewriter, loc, preLast));
+    return rewriter
+        .create<LLVM::MulOp>(loc, llvmInt64Type, scale,
+                             memrefDescriptor.stride(rewriter, loc, preLast))
+        .getResult();
   }
   // Use direct constant for static stride.
   auto attr = rewriter.getI64IntegerAttr(strides[preLast] * bytes);
-  return rewriter.create<LLVM::ConstantOp>(loc, llvmInt64Type, attr);
+  return rewriter.create<LLVM::ConstantOp>(loc, llvmInt64Type, attr)
+      .getResult();
 }
 
 struct TileZeroConversion : public ConvertOpToLLVMPattern<TileZeroOp> {
@@ -106,16 +100,16 @@ struct TileLoadConversion : public ConvertOpToLLVMPattern<TileLoadOp> {
     std::pair<Value, Value> tsz =
         getTileSizes(rewriter, *getTypeConverter(), vType, op.getLoc());
     // Determine stride.
-    if (failed(verifyStride(mType)))
+    auto stride = getStride(rewriter, *getTypeConverter(), mType,
+                            adaptor.getBase(), op.getLoc());
+    if (failed(stride))
       return failure();
-    Value stride = getStride(rewriter, *getTypeConverter(), mType,
-                             adaptor.getBase(), op.getLoc());
     // Replace operation with intrinsic.
     Value ptr = getStridedElementPtr(op.getLoc(), mType, adaptor.getBase(),
                                      adaptor.getIndices(), rewriter);
     Type resType = typeConverter->convertType(vType);
     rewriter.replaceOpWithNewOp<amx::x86_amx_tileloadd64>(
-        op, resType, tsz.first, tsz.second, ptr, stride);
+        op, resType, tsz.first, tsz.second, ptr, stride.value());
     return success();
   }
 };
@@ -132,15 +126,15 @@ struct TileStoreConversion : public ConvertOpToLLVMPattern<TileStoreOp> {
     std::pair<Value, Value> tsz =
         getTileSizes(rewriter, *getTypeConverter(), vType, op.getLoc());
     // Determine stride.
-    if (failed(verifyStride(mType)))
+    auto stride = getStride(rewriter, *getTypeConverter(), mType,
+                            adaptor.getBase(), op.getLoc());
+    if (failed(stride))
       return failure();
-    Value stride = getStride(rewriter, *getTypeConverter(), mType,
-                             adaptor.getBase(), op.getLoc());
     // Replace operation with intrinsic.
     Value ptr = getStridedElementPtr(op.getLoc(), mType, adaptor.getBase(),
                                      adaptor.getIndices(), rewriter);
     rewriter.replaceOpWithNewOp<amx::x86_amx_tilestored64>(
-        op, tsz.first, tsz.second, ptr, stride, adaptor.getVal());
+        op, tsz.first, tsz.second, ptr, stride.value(), adaptor.getVal());
     return success();
   }
 };

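With the second patch, the old verifyStride/getStride pair is folded
into a single helper returning FailureOr<Value>, so the strides are
computed only once and callers handle a single failure path. A minimal
caller sketch (hypothetical surrounding code, mirroring the
TileLoadConversion change above):

  // Sketch: consuming the refactored FailureOr<Value> getStride helper.
  FailureOr<Value> stride = getStride(rewriter, *getTypeConverter(), mType,
                                      adaptor.getBase(), op.getLoc());
  if (failed(stride))
    return failure(); // Non-strided layout or trailing stride != 1.
  Value strideBytes = *stride; // Byte stride fed to the AMX intrinsic.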

