[Mlir-commits] [mlir] [mlir][amx] Simplify intrinsic generation (PR #140559)
Tobias Gysi
llvmlistbot at llvm.org
Mon May 19 11:53:46 PDT 2025
================
@@ -60,24 +64,168 @@ static LogicalResult verifyMultShape(Operation *op, amx::TileType atp,
return success();
}
+/// Returns a pointer into the buffer described by the memref descriptor
+/// `buffer` of type `type`.
+///
+/// The starting address comes from `MemRefDescriptor::bufferPtr`. When
+/// `indices` is non-empty it must contain exactly one index per memref
+/// dimension. The indices are then linearized with the memref strides:
+/// - unit strides add the index unchanged;
+/// - static strides are materialized as index-typed constants;
+/// - dynamic strides are read from the descriptor.
+/// The base pointer is advanced by the resulting element count with a GEP.
+static Value getBufferPtr(Location loc, MemRefType type, Value buffer,
+                          ValueRange indices,
+                          const LLVMTypeConverter &typeConverter,
+                          RewriterBase &rewriter) {
+  auto [memrefStrides, memrefOffset] = type.getStridesAndOffset();
+
+  MemRefDescriptor descriptor(buffer);
+  Value basePtr = descriptor.bufferPtr(rewriter, loc, typeConverter, type);
+
+  // No indices: the base pointer itself is the result.
+  if (indices.empty())
+    return basePtr;
+
+  assert(type.getRank() == static_cast<int64_t>(indices.size()) &&
+         "expects number of indices equal to memref rank");
+
+  Type llvmIndexType = typeConverter.getIndexType();
+  Value linearIndex;
+  for (unsigned dim = 0, rank = indices.size(); dim < rank; ++dim) {
+    int64_t dimStride = memrefStrides[dim];
+    Value term = indices[dim];
+    // A unit stride contributes the index as-is; no multiply is emitted.
+    if (dimStride != 1) {
+      Value strideVal;
+      if (ShapedType::isDynamic(dimStride))
+        strideVal = descriptor.stride(rewriter, loc, dim);
+      else
+        strideVal = rewriter.create<LLVM::ConstantOp>(
+            loc, llvmIndexType, rewriter.getIndexAttr(dimStride));
+      term = rewriter.create<LLVM::MulOp>(loc, term, strideVal);
+    }
+    // The first term seeds the sum; later terms are added onto it.
+    if (linearIndex)
+      linearIndex = rewriter.create<LLVM::AddOp>(loc, linearIndex, term);
+    else
+      linearIndex = term;
+  }
+
+  Type elemType = typeConverter.convertType(type.getElementType());
+  return rewriter.create<LLVM::GEPOp>(loc, descriptor.getElementPtrType(),
+                                      elemType, basePtr, linearIndex);
+}
+
+/// Converts the 2-D shape of an AMX tile type into its two 16-bit tile-size
+/// values, which are returned in this order:
+/// 1. the number of rows (dimension 0);
+/// 2. the row width in bytes (dimension 1 times the element size in bytes).
+/// Both values are materialized as `i16` LLVM constants.
+static SmallVector<Value> getTileSizes(Location loc, amx::TileType tType,
+                                       RewriterBase &rewriter) {
+  // The element width must be a power of two and at least one byte.
+  unsigned elemBits = tType.getElementType().getIntOrFloatBitWidth();
+  assert(llvm::isPowerOf2_64(elemBits) && elemBits >= 8);
+  unsigned elemBytes = elemBits / 8;
+
+  Type i16Type = rewriter.getIntegerType(16);
+  auto makeI16Constant = [&](int64_t value) -> Value {
+    return rewriter.create<LLVM::ConstantOp>(loc, i16Type,
+                                             rewriter.getI16IntegerAttr(value));
+  };
+
+  SmallVector<Value> sizes;
+  sizes.push_back(makeI16Constant(tType.getDimSize(0)));
+  sizes.push_back(makeI16Constant(tType.getDimSize(1) * elemBytes));
+  return sizes;
+}
+
+/// Maps the 2-dim memref shape to the 64-bit stride. Note that the buffer
+/// shape may "envelop" the actual tile shape, and may be dynamically sized.
+/// Returns failure if proper stride couldn't be found.
----------------
gysit wrote:
It seems returning failure has been replaced by an assertion?
https://github.com/llvm/llvm-project/pull/140559
More information about the Mlir-commits
mailing list