[Mlir-commits] [mlir] [MLIR] [AMX] Fix strides used by AMX lowering for tile loads and stores. (PR #113476)
Ilya Enkovich
llvmlistbot at llvm.org
Mon Oct 28 12:59:51 PDT 2024
https://github.com/ienkovich updated https://github.com/llvm/llvm-project/pull/113476
>From 0c0f415e4b3b34768102dbea162bcb6e29f53de1 Mon Sep 17 00:00:00 2001
From: Ilya Enkovich <ilya.enkovich at intel.com>
Date: Wed, 23 Oct 2024 16:59:28 +0000
Subject: [PATCH 1/2] Fix strides used by AMX lowering.
Lowering of tile stores and loads uses the size of the last memref
dimension as a stride and ignores the actual strides specified in the
memref. This causes unexpected results when the actual stride doesn't
match the last dimension size.
Signed-off-by: Ilya Enkovich <ilya.enkovich at intel.com>
---
.../AMX/Transforms/LegalizeForLLVMExport.cpp | 16 +++++++----
mlir/test/Dialect/AMX/legalize-for-llvm.mlir | 28 +++++++++++++++++++
2 files changed, 38 insertions(+), 6 deletions(-)
diff --git a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
index c8cfcc3d945bec..14f28c436a1ff6 100644
--- a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
@@ -55,21 +55,25 @@ Value getStride(ConversionPatternRewriter &rewriter,
const LLVMTypeConverter &typeConverter, MemRefType mType,
Value base, Location loc) {
assert(mType.getRank() >= 2);
- int64_t last = mType.getRank() - 1;
+ int64_t preLast = mType.getRank() - 2;
Type llvmInt64Type = IntegerType::get(&typeConverter.getContext(), 64);
unsigned width = mType.getElementType().getIntOrFloatBitWidth();
assert(llvm::isPowerOf2_64(width) && width >= 8);
unsigned bytes = width >> 3;
- if (mType.isDynamicDim(last)) {
- // Dynamic size needs code to compute the stride at runtime.
+ int64_t offset;
+ SmallVector<int64_t, 4> strides;
+ getStridesAndOffset(mType, strides, offset);
+ if (strides[preLast] == ShapedType::kDynamic) {
+ // Dynamic stride needs code to compute the stride at runtime.
MemRefDescriptor memrefDescriptor(base);
auto attr = rewriter.getI64IntegerAttr(bytes);
Value scale = rewriter.create<LLVM::ConstantOp>(loc, llvmInt64Type, attr);
return rewriter.create<LLVM::MulOp>(
- loc, llvmInt64Type, scale, memrefDescriptor.size(rewriter, loc, last));
+ loc, llvmInt64Type, scale,
+ memrefDescriptor.stride(rewriter, loc, preLast));
}
- // Use direct constant for static size.
- auto attr = rewriter.getI64IntegerAttr(mType.getDimSize(last) * bytes);
+ // Use direct constant for static stride.
+ auto attr = rewriter.getI64IntegerAttr(strides[preLast] * bytes);
return rewriter.create<LLVM::ConstantOp>(loc, llvmInt64Type, attr);
}
diff --git a/mlir/test/Dialect/AMX/legalize-for-llvm.mlir b/mlir/test/Dialect/AMX/legalize-for-llvm.mlir
index 992203153939fe..3cacbd0044f825 100644
--- a/mlir/test/Dialect/AMX/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/AMX/legalize-for-llvm.mlir
@@ -43,3 +43,31 @@ func.func @mulf(%arg0: memref<?x?xbf16>, %arg1: memref<?x?xf32>) {
amx.tile_store %arg1[%0, %0], %4 : memref<?x?xf32>, vector<16x16xf32>
return
}
+
+// CHECK-LABEL: strides(
+// CHECK: %[[CST_64_1:.+]] = llvm.mlir.constant(64 : i64) : i64
+// CHECK: "amx.tileloadd64"(%{{.+}}, %{{.+}}, %{{.+}}, %[[CST_64_1]]
+// CHECK: %[[CST_128_1:.+]] = llvm.mlir.constant(128 : i64) : i64
+// CHECK: "amx.tileloadd64"(%{{.+}}, %{{.+}}, %{{.+}}, %[[CST_128_1]]
+// CHECK: llvm.mlir.constant(2 : i64) : i64
+// CHECK: llvm.extractvalue %{{.+}}[4, 0]
+// CHECK: %[[STRIDE_1:.+]] = llvm.mul
+// CHECK: "amx.tileloadd64"(%{{.+}}, %{{.+}}, %{{.+}}, %[[STRIDE_1]]
+// CHECK: %[[CST_64_2:.+]] = llvm.mlir.constant(64 : i64) : i64
+// CHECK: "amx.tilestored64"(%{{.+}}, %{{.+}}, %{{.+}}, %[[CST_64_2]]
+// CHECK: %[[CST_128_2:.+]] = llvm.mlir.constant(128 : i64) : i64
+// CHECK: "amx.tilestored64"(%{{.+}}, %{{.+}}, %{{.+}}, %[[CST_128_2]]
+// CHECK: llvm.mlir.constant(2 : i64) : i64
+// CHECK: llvm.extractvalue %{{.+}}[4, 0]
+// CHECK: %[[STRIDE_2:.+]] = llvm.mul
+// CHECK: "amx.tilestored64"(%{{.+}}, %{{.+}}, %{{.+}}, %[[STRIDE_2]]
+func.func @strides(%arg0: memref<16x32xbf16>, %arg1: memref<16x32xbf16, strided<[64, 1]>>, %arg2: memref<16x32xbf16, strided<[?, 1]>>) {
+ %0 = arith.constant 0 : index
+ %1 = amx.tile_load %arg0[%0, %0] : memref<16x32xbf16> into vector<16x32xbf16>
+ %2 = amx.tile_load %arg1[%0, %0] : memref<16x32xbf16, strided<[64, 1]>> into vector<16x32xbf16>
+ %3 = amx.tile_load %arg2[%0, %0] : memref<16x32xbf16, strided<[?, 1]>> into vector<16x32xbf16>
+ amx.tile_store %arg0[%0, %0], %3 : memref<16x32xbf16>, vector<16x32xbf16>
+ amx.tile_store %arg1[%0, %0], %1 : memref<16x32xbf16, strided<[64, 1]>>, vector<16x32xbf16>
+ amx.tile_store %arg2[%0, %0], %2 : memref<16x32xbf16, strided<[?, 1]>>, vector<16x32xbf16>
+ return
+}
>From 41f66cd2c60ab3af59178b23f5038cd8f6fd9074 Mon Sep 17 00:00:00 2001
From: Ilya Enkovich <ilya.enkovich at intel.com>
Date: Mon, 28 Oct 2024 19:59:43 +0000
Subject: [PATCH 2/2] Fix review comments.
Signed-off-by: Ilya Enkovich <ilya.enkovich at intel.com>
---
.../AMX/Transforms/LegalizeForLLVMExport.cpp | 52 ++++++++-----------
1 file changed, 23 insertions(+), 29 deletions(-)
diff --git a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
index 14f28c436a1ff6..46c7bfbf3ffcc2 100644
--- a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
@@ -37,24 +37,14 @@ std::pair<Value, Value> getTileSizes(ConversionPatternRewriter &rewriter,
rewriter.create<LLVM::ConstantOp>(loc, llvmInt16Type, nattr));
}
-/// Verifies if the stride matches proper tile access.
-LogicalResult verifyStride(MemRefType mType) {
- if (mType.getRank() < 2)
- return failure();
- int64_t last = mType.getRank() - 1;
- int64_t offset;
- SmallVector<int64_t, 4> strides;
- if (failed(getStridesAndOffset(mType, strides, offset)) || strides[last] != 1)
- return failure();
- return success();
-}
-
/// Maps the 2-dim memref shape to the 64-bit stride. Note that the buffer
/// shape may "envelop" the actual tile shape, and may be dynamically sized.
-Value getStride(ConversionPatternRewriter &rewriter,
- const LLVMTypeConverter &typeConverter, MemRefType mType,
- Value base, Location loc) {
- assert(mType.getRank() >= 2);
+/// Returns failure if proper stride couldn't be found.
+FailureOr<Value> getStride(ConversionPatternRewriter &rewriter,
+ const LLVMTypeConverter &typeConverter,
+ MemRefType mType, Value base, Location loc) {
+ if (mType.getRank() < 2)
+ return failure();
int64_t preLast = mType.getRank() - 2;
Type llvmInt64Type = IntegerType::get(&typeConverter.getContext(), 64);
unsigned width = mType.getElementType().getIntOrFloatBitWidth();
@@ -62,19 +52,23 @@ Value getStride(ConversionPatternRewriter &rewriter,
unsigned bytes = width >> 3;
int64_t offset;
SmallVector<int64_t, 4> strides;
- getStridesAndOffset(mType, strides, offset);
+ if (failed(getStridesAndOffset(mType, strides, offset)) ||
+ strides.back() != 1)
+ return failure();
if (strides[preLast] == ShapedType::kDynamic) {
// Dynamic stride needs code to compute the stride at runtime.
MemRefDescriptor memrefDescriptor(base);
auto attr = rewriter.getI64IntegerAttr(bytes);
Value scale = rewriter.create<LLVM::ConstantOp>(loc, llvmInt64Type, attr);
- return rewriter.create<LLVM::MulOp>(
- loc, llvmInt64Type, scale,
- memrefDescriptor.stride(rewriter, loc, preLast));
+ return rewriter
+ .create<LLVM::MulOp>(loc, llvmInt64Type, scale,
+ memrefDescriptor.stride(rewriter, loc, preLast))
+ .getResult();
}
// Use direct constant for static stride.
auto attr = rewriter.getI64IntegerAttr(strides[preLast] * bytes);
- return rewriter.create<LLVM::ConstantOp>(loc, llvmInt64Type, attr);
+ return rewriter.create<LLVM::ConstantOp>(loc, llvmInt64Type, attr)
+ .getResult();
}
struct TileZeroConversion : public ConvertOpToLLVMPattern<TileZeroOp> {
@@ -106,16 +100,16 @@ struct TileLoadConversion : public ConvertOpToLLVMPattern<TileLoadOp> {
std::pair<Value, Value> tsz =
getTileSizes(rewriter, *getTypeConverter(), vType, op.getLoc());
// Determine stride.
- if (failed(verifyStride(mType)))
+ auto stride = getStride(rewriter, *getTypeConverter(), mType,
+ adaptor.getBase(), op.getLoc());
+ if (failed(stride))
return failure();
- Value stride = getStride(rewriter, *getTypeConverter(), mType,
- adaptor.getBase(), op.getLoc());
// Replace operation with intrinsic.
Value ptr = getStridedElementPtr(op.getLoc(), mType, adaptor.getBase(),
adaptor.getIndices(), rewriter);
Type resType = typeConverter->convertType(vType);
rewriter.replaceOpWithNewOp<amx::x86_amx_tileloadd64>(
- op, resType, tsz.first, tsz.second, ptr, stride);
+ op, resType, tsz.first, tsz.second, ptr, stride.value());
return success();
}
};
@@ -132,15 +126,15 @@ struct TileStoreConversion : public ConvertOpToLLVMPattern<TileStoreOp> {
std::pair<Value, Value> tsz =
getTileSizes(rewriter, *getTypeConverter(), vType, op.getLoc());
// Determine stride.
- if (failed(verifyStride(mType)))
+ auto stride = getStride(rewriter, *getTypeConverter(), mType,
+ adaptor.getBase(), op.getLoc());
+ if (failed(stride))
return failure();
- Value stride = getStride(rewriter, *getTypeConverter(), mType,
- adaptor.getBase(), op.getLoc());
// Replace operation with intrinsic.
Value ptr = getStridedElementPtr(op.getLoc(), mType, adaptor.getBase(),
adaptor.getIndices(), rewriter);
rewriter.replaceOpWithNewOp<amx::x86_amx_tilestored64>(
- op, tsz.first, tsz.second, ptr, stride, adaptor.getVal());
+ op, tsz.first, tsz.second, ptr, stride.value(), adaptor.getVal());
return success();
}
};
More information about the Mlir-commits
mailing list