[Mlir-commits] [mlir] [mlir][gpu] Add pass for imitating unsupported types. (PR #138087)
Md Abdullah Shahneous Bari
llvmlistbot at llvm.org
Wed Apr 30 23:35:58 PDT 2025
https://github.com/mshahneo updated https://github.com/llvm/llvm-project/pull/138087
>From 1c99201c44baf51d5c3a55a7af665364673aac6f Mon Sep 17 00:00:00 2001
From: "Shahneous Bari, Md Abdullah" <md.abdullah.shahneous.bari at intel.com>
Date: Thu, 1 May 2025 06:00:13 +0000
Subject: [PATCH 1/2] [mlir][gpu] Add pass for imitating unsupported types.
This pass imitates (bitcast/reinterpret_cast) unsupported types
with supported types of the same bitwidth. The imitation is done
by bitcasting the unsupported types to the supported types of the same bitwidth.
Therefore, the source type and destination type must have the same bitwidth.
Scalars and vectors are reinterpreted with arith.bitcast; memref allocations
are re-viewed with memref.view.
The imitation is often needed when the GPU target (dialect/IR) does not
support a certain type but the underlying architecture does. Take SPIR-V for
example: it does not support bf16, but an underlying architecture (e.g.,
Intel PVC GPU) that uses SPIR-V for code generation does.
Therefore, bf16 is neither a valid data type to pass to a GPU kernel nor to
be used inside the kernel. To use bf16 data in a SPIR-V kernel (as a
kernel parameter or inside the kernel), bf16 values have to be bitcast (similar
to C++ reinterpret_cast) to a supported type (e.g., i16 for Intel GPUs). The
SPIR-V kernel can then use the imitated type (i16) in the computation.
However, i16 is not the same as bf16 (integer vs. float), so the computation
cannot readily use the imitated type (i16).
Therefore, this transformation pass is intended to be used in conjunction
with other transformation passes such as `EmulateUnsupportedFloats` and
`ExtendUnsupportedTypes` that extend the bitwidth of bf16 to f32 and
vice versa.
Finally, there are usually instructions available in the target
(dialect/IR) that can take advantage of these generated patterns
(bf16->i16->f32, f32->bf16->i16) and convert them to the supported
types.
For example, Intel provides SPIR-V extension ops that can
take imitated bf16 (i16) and convert it to f32 and vice versa.
https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/INTEL/SPV_INTEL_bfloat16_conversion.asciidoc
https://mlir.llvm.org/docs/Dialects/SPIR-V/#spirvintelconvertbf16tof-spirvintelconvertbf16tofop
https://mlir.llvm.org/docs/Dialects/SPIR-V/#spirvintelconvertftobf16-spirvintelconvertftobf16op
---
.../mlir/Dialect/GPU/Transforms/Passes.h | 20 +
.../mlir/Dialect/GPU/Transforms/Passes.td | 53 +
.../Transforms/ImitateUnsupportedTypes.cpp | 916 ++++++++++++++++++
.../GPU/imitate-unsupported-types.mlir | 141 +++
4 files changed, 1130 insertions(+)
create mode 100644 mlir/lib/Dialect/GPU/Transforms/ImitateUnsupportedTypes.cpp
create mode 100644 mlir/test/Dialect/GPU/imitate-unsupported-types.mlir
diff --git a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
index 6cd6f03253aea..0b7339a94b274 100644
--- a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
@@ -16,6 +16,8 @@
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Utils/GPUUtils.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include <optional>
@@ -87,6 +89,24 @@ void populateGpuLowerClusteredSubgroupReduceToDPPPatterns(
RewritePatternSet &patterns, unsigned subgroupSize, amdgpu::Chipset chipset,
PatternBenefit benefit = 1);
+/// Set up `typeConverter` to map each type in `sourceTypes` to the type at the
+/// same position in `targetTypes`. Scalar types are matched directly; memref
+/// and vector types are converted when their element type is a source type.
+/// `sourceTypes` and `targetTypes` must have the same length, and each pair is
+/// expected to have the same bitwidth because values are only reinterpreted
+/// (arith.bitcast for scalars/vectors, unrealized_conversion_cast for
+/// memrefs), never numerically converted.
+void populateImitateUnsupportedTypesTypeConverter(TypeConverter &typeConverter,
+ ArrayRef<Type> sourceTypes,
+ ArrayRef<Type> targetTypes);
+
+/// Collect the patterns needed to imitate unsupported source types using
+/// supported target types. `convertedFuncTypes` is filled by the function
+/// conversion pattern with the new signature of each converted function
+/// (keyed by symbol name) and read by the call conversion pattern, so it must
+/// outlive the conversion. NOTE(review): `typeConverter` is presumably the one
+/// configured by populateImitateUnsupportedTypesTypeConverter with the same
+/// source/target lists — confirm.
+void populateImitateUnsupportedTypesConversionPatterns(
+ RewritePatternSet &patterns, TypeConverter &typeConverter,
+ ArrayRef<Type> sourceTypes, ArrayRef<Type> targetTypes,
+ DenseMap<StringAttr, FunctionType> &convertedFuncTypes);
+
+/// Set up a dialect conversion target to reject operations that still use the
+/// unsupported source types, as determined by `typeConverter` (not only float
+/// types).
+void configureImitateUnsupportedTypesLegality(ConversionTarget &target,
+ TypeConverter &typeConverter);
+
/// Collect all patterns to rewrite ops within the GPU dialect.
inline void populateGpuRewritePatterns(RewritePatternSet &patterns) {
populateGpuAllReducePatterns(patterns);
diff --git a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.td
index 3766eb16e9429..feb1b2820abd6 100644
--- a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.td
@@ -258,4 +258,57 @@ def GpuSPIRVAttachTarget: Pass<"spirv-attach-target", ""> {
];
}
+def GpuImitateUnsupportedTypes : Pass<"imitate-unsupported-types", "::mlir::ModuleOp"> {
+  let summary = "Imitate unsupported types with supported types of the same bitwidth.";
+  let description = [{
+    This pass imitates (bitcasts/reinterpret_casts) unsupported types with
+    supported types of the same bitwidth. Because values are only
+    reinterpreted, never numerically converted, each source type and its
+    target type must have the same bitwidth. Scalar and vector values are
+    reinterpreted with `arith.bitcast`; memref allocations are rewritten into
+    a raw `i8` buffer that is viewed with `memref.view` as both the original
+    and the converted element type.
+
+    The imitation is often needed when the GPU target (dialect/IR) does not
+    support a certain type but the underlying architecture does. Take SPIR-V
+    for example: it does not support bf16, but an underlying architecture
+    (e.g., Intel PVC GPU) that uses SPIR-V for code generation does.
+    Therefore, bf16 is neither a valid data type to pass to a GPU kernel nor
+    to be used inside the kernel. To use bf16 data in a SPIR-V kernel (as a
+    kernel parameter or inside the kernel), bf16 values have to be bitcast
+    (similar to C++ reinterpret_cast) to a supported type (e.g., i16 for
+    Intel GPUs). The SPIR-V kernel can then use the imitated type (i16) in the
+    computation. However, i16 is not the same as bf16 (integer vs. float), so
+    the computation cannot readily use the imitated type (i16).
+
+    Therefore, this transformation pass is intended to be used in conjunction
+    with other transformation passes such as `EmulateUnsupportedFloats` and
+    `ExtendUnsupportedTypes` that extend the bitwidth of bf16 to f32 and
+    vice versa.
+
+    Finally, there are usually instructions available in the target
+    (dialect/IR) that can take advantage of these generated patterns
+    (bf16->i16->f32, f32->bf16->i16) and convert them to the supported
+    types. For example, Intel provides SPIR-V extension ops that can take
+    imitated bf16 (i16) and convert it to f32 and vice versa:
+    https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/INTEL/SPV_INTEL_bfloat16_conversion.asciidoc
+    https://mlir.llvm.org/docs/Dialects/SPIR-V/#spirvintelconvertbf16tof-spirvintelconvertbf16tofop
+    https://mlir.llvm.org/docs/Dialects/SPIR-V/#spirvintelconvertftobf16-spirvintelconvertftobf16op
+  }];
+
+  let options = [
+    ListOption<"sourceTypeStrs", "source-types", "std::string",
+               "MLIR types without type support on a given target">,
+    ListOption<"targetTypeStrs", "target-types", "std::string",
+               "MLIR types to convert the unsupported source types to "
+               "(same order and bitwidth as source-types)">,
+  ];
+
+  let dependentDialects = [
+    "::mlir::gpu::GPUDialect",
+    "::mlir::arith::ArithDialect",
+    "::mlir::memref::MemRefDialect"
+  ];
+}
+
+
#endif // MLIR_DIALECT_GPU_PASSES
diff --git a/mlir/lib/Dialect/GPU/Transforms/ImitateUnsupportedTypes.cpp b/mlir/lib/Dialect/GPU/Transforms/ImitateUnsupportedTypes.cpp
new file mode 100644
index 0000000000000..c83e6bec568e0
--- /dev/null
+++ b/mlir/lib/Dialect/GPU/Transforms/ImitateUnsupportedTypes.cpp
@@ -0,0 +1,916 @@
+//===- ImitateUnsupportedTypes.cpp - Unsupported Type Imitation -*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//===----------------------------------------------------------------------===//
+//
+/// \file
+/// This pass imitates (bitcast/reinterpret_cast) unsupported types
+/// with supported types of the same bitwidth. The imitation is done
+/// by bitcasting the unsupported types to the supported types of the same
+/// bitwidth. Therefore, the source type and destination type must have the
+/// same bitwidth. Scalars and vectors are reinterpreted with arith.bitcast;
+/// memref allocations are re-viewed with memref.view.
+///
+/// The imitation is often needed when the GPU target (dialect/IR) does not
+/// support a certain type but the underlying architecture does. Take SPIR-V for
+/// example: it does not support bf16, but an underlying architecture (e.g.,
+/// Intel PVC GPU) that uses SPIR-V for code generation does.
+/// Therefore, bf16 is neither a valid data type to pass to a GPU kernel nor to
+/// be used inside the kernel. To use bf16 data in a SPIR-V kernel (as a
+/// kernel parameter or inside the kernel), bf16 values have to be bitcast
+/// (similar to C++ reinterpret_cast) to a supported type (e.g., i16 for Intel
+/// GPUs). The SPIR-V kernel can then use the imitated type (i16) in the
+/// computation. However, i16 is not the same as bf16 (integer vs. float), so
+/// the computation cannot readily use the imitated type (i16).
+///
+/// Therefore, this transformation pass is intended to be used in conjunction
+/// with other transformation passes such as `EmulateUnsupportedFloats` and
+/// `ExtendUnsupportedTypes` that extend the bitwidth of bf16 to f32 and
+/// vice versa.
+///
+/// Finally, there are usually instructions available in the target
+/// (dialect/IR) that can take advantage of these generated patterns
+/// (bf16->i16->f32, f32->bf16->i16) and convert them to the supported
+/// types.
+/// For example, Intel provides SPIR-V extension ops that can
+/// take imitated bf16 (i16) and convert it to f32 and vice versa.
+/// https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/INTEL/SPV_INTEL_bfloat16_conversion.asciidoc
+/// https://mlir.llvm.org/docs/Dialects/SPIR-V/#spirvintelconvertbf16tof-spirvintelconvertbf16tofop
+/// https://mlir.llvm.org/docs/Dialects/SPIR-V/#spirvintelconvertftobf16-spirvintelconvertftobf16op
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/GPU/Transforms/Passes.h"
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/APFloat.h"
+#include "llvm/ADT/APInt.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/ErrorHandling.h"
+
+#include <optional>
+#include <type_traits>
+#include <variant>
+
+using namespace mlir;
+using namespace mlir::gpu;
+
+namespace mlir {
+#define GEN_PASS_DEF_GPUIMITATEUNSUPPORTEDTYPES
+#include "mlir/Dialect/GPU/Transforms/Passes.h.inc"
+} // namespace mlir
+
+//===----------------------------------------------------------------------===//
+// Utility functions
+//===----------------------------------------------------------------------===//
+
+/// Reinterpret the bits of `intValue` as a floating-point value described by
+/// `semantics`. No numeric conversion is performed, so the bitwidth of
+/// `intValue` must equal the storage size of `semantics`.
+static APFloat bitcastAPIntToAPFloat(const APInt &intValue,
+                                     const llvm::fltSemantics &semantics) {
+  assert(intValue.getBitWidth() == APFloat::getSizeInBits(semantics) &&
+         "Bitwidth of APInt and APFloat must match for bitcast");
+  // Build from the whole APInt rather than only its first 64-bit word so that
+  // types wider than 64 bits (e.g. f80, f128) keep all of their bits.
+  return APFloat(semantics, intValue);
+}
+
+/// Build a FloatAttr of type `dstType` (which must be a FloatType) carrying the
+/// same bit pattern as `intAttr`.
+FloatAttr getFloatAttrFromIntegerAttr(IntegerAttr intAttr, Type dstType,
+                                      ConversionPatternRewriter &rewriter) {
+  const llvm::fltSemantics &dstSemantics =
+      cast<FloatType>(dstType).getFloatSemantics();
+  return rewriter.getFloatAttr(
+      dstType, bitcastAPIntToAPFloat(intAttr.getValue(), dstSemantics));
+}
+/// Build an IntegerAttr of type `dstType` carrying the same bit pattern as
+/// `floatAttr`.
+IntegerAttr getIntegerAttrFromFloatAttr(FloatAttr floatAttr, Type dstType,
+                                        ConversionPatternRewriter &rewriter) {
+  const APInt rawBits = floatAttr.getValue().bitcastToAPInt();
+  return rewriter.getIntegerAttr(dstType, rawBits);
+}
+
+/// Emits the IR needed to treat a memref as a raw byte buffer: its total size
+/// in bytes, its per-dimension sizes, and its identity-layout strides. Every
+/// op is created through `builder` at `loc`.
+struct RawAllocator {
+  RawAllocator(OpBuilder &builder, Location loc) : builder(builder), loc(loc) {}
+
+  /// Return the number of bytes needed to hold every element of `srcType`,
+  /// with each element rounded up to a whole number of bytes. Static shapes
+  /// yield a compile-time int64_t; dynamic shapes yield an index Value built
+  /// from the dimensions of `srcMemref`.
+  std::variant<Value, int64_t> computeTotalBytes(MemRefType srcType,
+                                                 Value srcMemref) {
+    int64_t elemBitWidth = srcType.getElementTypeBitWidth();
+    int64_t elemByteWidth = (elemBitWidth + 7) / 8;
+
+    if (srcType.hasStaticShape())
+      return srcType.getNumElements() * elemByteWidth;
+
+    // A dynamic shape has rank >= 1, so `sizes` is never empty here.
+    SmallVector<Value> sizes = getSizes(srcType, srcMemref);
+    Value numElements = sizes.front();
+    for (Value size : llvm::drop_begin(sizes))
+      numElements = builder.create<arith::MulIOp>(loc, numElements, size);
+    Value elemSize = builder.create<arith::ConstantIndexOp>(loc, elemByteWidth);
+    return builder.create<arith::MulIOp>(loc, numElements, elemSize)
+        .getResult();
+  }
+
+  /// Return one index Value per dimension of `type`: a memref.dim on `memref`
+  /// for dynamic dimensions and an arith.constant for static ones.
+  SmallVector<Value> getSizes(MemRefType type, Value memref) {
+    SmallVector<Value> sizes;
+    sizes.reserve(type.getRank());
+    for (int64_t i = 0, e = type.getRank(); i < e; ++i) {
+      if (type.isDynamicDim(i))
+        sizes.push_back(builder.create<memref::DimOp>(loc, memref, i));
+      else
+        sizes.push_back(
+            builder.create<arith::ConstantIndexOp>(loc, type.getDimSize(i)));
+    }
+    return sizes;
+  }
+
+  /// Return a memref.dim on `memref` for each dynamic dimension of `type`, in
+  /// dimension order.
+  SmallVector<Value> getDynamicSizes(MemRefType type, Value memref) {
+    SmallVector<Value> sizes;
+    for (int64_t i = 0, e = type.getRank(); i < e; ++i) {
+      if (type.isDynamicDim(i))
+        sizes.push_back(builder.create<memref::DimOp>(loc, memref, i));
+    }
+    return sizes;
+  }
+
+  /// Return the identity-layout (row-major) strides of `type` as index
+  /// constants. Dynamic strides are not handled: every stride outside the
+  /// innermost dynamic dimension is reported as the sentinel -1, so callers
+  /// must not consume those values for dynamically shaped memrefs.
+  SmallVector<Value> getIdentityStrides(MemRefType type) {
+    constexpr int64_t kUnknownStride = -1;
+    SmallVector<Value> strides;
+    strides.reserve(type.getRank());
+    int64_t runningStride = 1;
+    for (int64_t dim : llvm::reverse(type.getShape())) {
+      strides.push_back(
+          builder.create<arith::ConstantIndexOp>(loc, runningStride));
+      // Once a dynamic dimension is seen every outer stride is unknown; keep
+      // the sentinel instead of multiplying it by later dimensions.
+      if (runningStride == kUnknownStride)
+        continue;
+      runningStride =
+          dim == ShapedType::kDynamic ? kUnknownStride : runningStride * dim;
+    }
+    std::reverse(strides.begin(), strides.end());
+    return strides;
+  }
+
+private:
+  OpBuilder &builder;
+  Location loc;
+};
+
+/// For every use of `originalValue`, pick the first entry of `replacements`
+/// whose predicate accepts the use and redirect the use to that entry's value.
+/// Uses accepted by no predicate are left untouched. Owning ops are updated
+/// through `rewriter` so the conversion driver can track and roll back the
+/// modification.
+template <typename OpTy>
+void replaceUsesWithPredicate(
+    OpTy originalValue,
+    ArrayRef<std::pair<std::function<bool(OpOperand &)>, Value>> replacements,
+    ConversionPatternRewriter &rewriter) {
+  for (OpOperand &use : llvm::make_early_inc_range(originalValue->getUses())) {
+    const auto *match = llvm::find_if(
+        replacements, [&](const auto &entry) { return entry.first(use); });
+    if (match == replacements.end())
+      continue;
+    Value newValue = match->second;
+    rewriter.modifyOpInPlace(use.getOwner(), [&] { use.set(newValue); });
+  }
+}
+
+//===----------------------------------------------------------------------===//
+// Conversion patterns
+//===----------------------------------------------------------------------===//
+namespace {
+
+//===----------------------------------------------------------------------===//
+// FunctionOp conversion pattern
+//===----------------------------------------------------------------------===//
+/// Convert the signature and entry-block argument types of a function-like op
+/// nested in a gpu.module, replacing unsupported source types with their
+/// imitated target types. The new signature is recorded in
+/// `convertedFuncTypes` (keyed by symbol name) so call sites can be rewritten
+/// to match.
+template <typename FuncLikeOp>
+struct ConvertFuncOp final : public OpConversionPattern<FuncLikeOp> {
+  ConvertFuncOp(MLIRContext *context, TypeConverter &typeConverter,
+                ArrayRef<Type> /*sourceTypes*/, ArrayRef<Type> /*targetTypes*/,
+                DenseMap<StringAttr, FunctionType> &convertedFuncTypes)
+      : OpConversionPattern<FuncLikeOp>(context), typeConverter(typeConverter),
+        convertedFuncTypes(convertedFuncTypes) {}
+
+  LogicalResult
+  matchAndRewrite(FuncLikeOp op, typename FuncLikeOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Only handle functions nested in a gpu.module.
+    if (!op->template getParentOfType<gpu::GPUModuleOp>())
+      return rewriter.notifyMatchFailure(op, "not nested in a gpu.module");
+
+    FunctionType oldFuncType = op.getFunctionType();
+    unsigned numInputs = oldFuncType.getNumInputs();
+
+    // The entry block may carry more arguments than the function type has
+    // inputs (e.g. gpu.func workgroup/private attributions). The signature
+    // conversion must cover every block argument, but only the leading
+    // `numInputs` of them belong to the new function type.
+    Region &oldBody = op.getBody();
+    unsigned numBlockArgs =
+        oldBody.empty() ? numInputs : oldBody.front().getNumArguments();
+    TypeConverter::SignatureConversion signatureConverter(numBlockArgs);
+    SmallVector<Type, 4> newInputTypes;
+    newInputTypes.reserve(numInputs);
+    for (unsigned i = 0; i < numBlockArgs; ++i) {
+      Type origType = i < numInputs ? oldFuncType.getInput(i)
+                                    : oldBody.front().getArgument(i).getType();
+      Type convertedType = typeConverter.convertType(origType);
+      if (!convertedType)
+        return rewriter.notifyMatchFailure(op,
+                                           "could not convert argument type");
+      signatureConverter.addInputs(i, convertedType);
+      if (i < numInputs)
+        newInputTypes.push_back(convertedType);
+    }
+
+    SmallVector<Type, 4> newResultTypes;
+    if (failed(typeConverter.convertTypes(oldFuncType.getResults(),
+                                          newResultTypes)))
+      return rewriter.notifyMatchFailure(op, "could not convert result types");
+
+    FunctionType newFuncType =
+        rewriter.getFunctionType(newInputTypes, newResultTypes);
+
+    auto newFuncOp =
+        rewriter.create<FuncLikeOp>(op.getLoc(), op.getName(), newFuncType);
+    newFuncOp.setVisibility(op.getVisibility());
+    // Copy every attribute except the symbol name and the function type; the
+    // builder already set both, the latter to the converted signature.
+    for (NamedAttribute attr : op->getAttrs()) {
+      if (attr.getName() == op.getFunctionTypeAttrName() ||
+          attr.getName() == SymbolTable::getSymbolAttrName())
+        continue;
+      newFuncOp->setAttr(attr.getName(), attr.getValue());
+    }
+
+    // Drop any entry block created by the builder and move the old body in.
+    newFuncOp.getRegion().getBlocks().clear();
+    rewriter.inlineRegionBefore(op.getBody(), newFuncOp.getBody(),
+                                newFuncOp.end());
+    // Convert block argument types using the type converter.
+    if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), typeConverter,
+                                           &signatureConverter)))
+      return rewriter.notifyMatchFailure(op, "could not convert region "
+                                             "types");
+
+    rewriter.eraseOp(op);
+    // Record the converted signature so call sites can be rewritten.
+    convertedFuncTypes[newFuncOp.getNameAttr()] = newFuncType;
+    return success();
+  }
+
+private:
+  TypeConverter &typeConverter;
+  DenseMap<StringAttr, FunctionType> &convertedFuncTypes;
+};
+
+//===----------------------------------------------------------------------===//
+// CallOp conversion pattern
+//===----------------------------------------------------------------------===//
+/// Rewrite a func.call whose callee was converted by ConvertFuncOp so that its
+/// result types match the callee's new signature. The type converter is handed
+/// to the base pattern so the adaptor provides operands already materialized
+/// to their converted types.
+struct ConvertCallOp : OpConversionPattern<func::CallOp> {
+  ConvertCallOp(MLIRContext *context, TypeConverter &typeConverter,
+                const DenseMap<StringAttr, FunctionType> &convertedFuncTypes)
+      : OpConversionPattern(typeConverter, context),
+        convertedFuncTypes(convertedFuncTypes) {}
+
+  LogicalResult
+  matchAndRewrite(func::CallOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    FlatSymbolRefAttr callee = op.getCalleeAttr();
+
+    auto it = convertedFuncTypes.find(callee.getAttr());
+    if (it == convertedFuncTypes.end())
+      return rewriter.notifyMatchFailure(
+          op, "Callee signature not converted. Perhaps the callee is not in "
+              "the same gpu module as the caller.");
+
+    auto newCall = rewriter.replaceOpWithNewOp<func::CallOp>(
+        op, callee.getValue(), it->second.getResults(), adaptor.getOperands());
+    // Keep the original call's attributes (e.g. arg/result attributes).
+    newCall->setAttrs(op->getAttrs());
+    return success();
+  }
+
+private:
+  const DenseMap<StringAttr, FunctionType> &convertedFuncTypes;
+};
+
+//===----------------------------------------------------------------------===//
+// GPULaunchFuncOp conversion pattern
+//===----------------------------------------------------------------------===//
+/// Update a gpu.launch_func in place so that its kernel operands are the
+/// converted (imitated) values from the adaptor. Grid/block/cluster sizes,
+/// dynamic shared memory size, async dependencies, the async token result,
+/// the async object and all attributes are left untouched.
+struct ConvertGPULaunchFuncOp : OpConversionPattern<gpu::LaunchFuncOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(gpu::LaunchFuncOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Rebuilding the op with the asyncObject builder would silently drop the
+    // async dependencies and the async token result; modifying only the
+    // kernel operands keeps everything else intact.
+    rewriter.modifyOpInPlace(op, [&] {
+      op.getKernelOperandsMutable().assign(adaptor.getKernelOperands());
+    });
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// AllocOp conversion pattern
+//===----------------------------------------------------------------------===//
+/// Rewrite an allocation whose memref type needs conversion into an
+/// allocation of a raw i8 buffer plus two memref.view ops over it: one with the
+/// original type (for all other uses) and one with the converted type (for
+/// gpu.launch_func operands). Deallocations are redirected to the raw buffer.
+/// Only identity-layout memrefs are supported because memref.view produces
+/// identity-layout results.
+template <typename AllocOp>
+struct ConvertAllocOp : OpConversionPattern<AllocOp> {
+  ConvertAllocOp(MLIRContext *ctx, TypeConverter &typeConverter)
+      : OpConversionPattern<AllocOp>(ctx), typeConverter(typeConverter) {}
+
+  LogicalResult
+  matchAndRewrite(AllocOp op, typename AllocOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    auto srcType = cast<MemRefType>(op.getMemref().getType());
+
+    auto dstType =
+        dyn_cast_or_null<MemRefType>(typeConverter.convertType(srcType));
+    if (!dstType || dstType == srcType)
+      return failure(); // No need to rewrite.
+
+    if (!srcType.getLayout().isIdentity())
+      return rewriter.notifyMatchFailure(
+          op, "only memrefs with identity layout are supported");
+
+    // 1. Compute the allocation size in bytes. Dynamic extents come from the
+    //    alloc's own size operands: a memref.dim on the op being replaced would
+    //    add a use that later gets redirected to the new view, making the view
+    //    depend on itself.
+    int64_t elemByteWidth = (srcType.getElementTypeBitWidth() + 7) / 8;
+    int64_t staticBytes = elemByteWidth;
+    for (int64_t dim : srcType.getShape())
+      if (!ShapedType::isDynamic(dim))
+        staticBytes *= dim;
+    ValueRange dynamicSizes = adaptor.getDynamicSizes();
+    Value dynamicBytes;
+    if (!dynamicSizes.empty()) {
+      dynamicBytes = rewriter.create<arith::ConstantIndexOp>(loc, staticBytes);
+      for (Value size : dynamicSizes)
+        dynamicBytes = rewriter.create<arith::MulIOp>(loc, dynamicBytes, size);
+    }
+
+    // 2. Allocate the raw i8 buffer in the same memory space. The attribute
+    //    form is used because memory spaces such as #gpu.address_space are
+    //    not integers.
+    MemRefType rawType = MemRefType::get(
+        {dynamicBytes ? ShapedType::kDynamic : staticBytes},
+        rewriter.getI8Type(), MemRefLayoutAttrInterface{},
+        srcType.getMemorySpace());
+    SmallVector<Value, 1> rawSizes;
+    if (dynamicBytes)
+      rawSizes.push_back(dynamicBytes);
+
+    Value rawAlloc;
+    Value newAsyncToken;
+    if constexpr (std::is_same_v<AllocOp, gpu::AllocOp>) {
+      Value oldToken = op.getAsyncToken();
+      auto newAlloc = rewriter.create<gpu::AllocOp>(
+          loc, rawType, oldToken ? oldToken.getType() : Type(),
+          adaptor.getAsyncDependencies(), rawSizes,
+          adaptor.getSymbolOperands(), op.getHostShared());
+      rawAlloc = newAlloc.getMemref();
+      newAsyncToken = newAlloc.getAsyncToken();
+    } else {
+      rawAlloc =
+          rewriter.create<AllocOp>(loc, rawType, rawSizes,
+                                   adaptor.getSymbolOperands(),
+                                   op.getAlignmentAttr());
+    }
+
+    // 3. Views of the raw buffer with the original and the converted type.
+    Value zeroOffset = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    Value originalView = rewriter.create<memref::ViewOp>(
+        loc, srcType, rawAlloc, zeroOffset, dynamicSizes);
+    Value convertedView = rewriter.create<memref::ViewOp>(
+        loc, dstType, rawAlloc, zeroOffset, dynamicSizes);
+
+    // 4. Kernel launches get the converted view and deallocations get the raw
+    //    buffer; these are updated in place through the rewriter. Only uses of
+    //    the memref result are visited, never uses of the async token.
+    for (OpOperand &use :
+         llvm::make_early_inc_range(op.getMemref().getUses())) {
+      Operation *owner = use.getOwner();
+      Value replacement;
+      if (isa<gpu::LaunchFuncOp>(owner))
+        replacement = convertedView;
+      else if (isa<memref::DeallocOp, gpu::DeallocOp>(owner))
+        replacement = rawAlloc;
+      else
+        continue;
+      rewriter.modifyOpInPlace(owner, [&] { use.set(replacement); });
+    }
+
+    // 5. Every remaining memref use sees the original-type view; a gpu.alloc
+    //    async token is replaced by the new allocation's token.
+    SmallVector<Value, 2> replacements = {originalView};
+    if (newAsyncToken)
+      replacements.push_back(newAsyncToken);
+    rewriter.replaceOp(op, replacements);
+    return success();
+  }
+
+private:
+  TypeConverter &typeConverter;
+};
+
+//===----------------------------------------------------------------------===//
+// ArithConstantOp conversion pattern
+//===----------------------------------------------------------------------===//
+/// Rewrite an arith.constant of an unsupported type into a constant of the
+/// imitated type holding the same bit pattern, followed by an arith.bitcast
+/// back to the original type.
+struct ConvertArithConstantOp : OpConversionPattern<arith::ConstantOp> {
+  ConvertArithConstantOp(MLIRContext *context, TypeConverter &typeConverter,
+                         ArrayRef<Type> /*sourceTypes*/,
+                         ArrayRef<Type> /*targetTypes*/)
+      : OpConversionPattern(context), typeConverter(typeConverter) {}
+
+  LogicalResult
+  matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Type srcType = op.getType();
+    Type dstType = typeConverter.convertType(srcType);
+    if (!dstType || dstType == srcType)
+      return failure();
+
+    Attribute value = op.getValue();
+    Attribute newAttr;
+    if (isa<IntegerAttr, FloatAttr>(value)) {
+      newAttr = reinterpretScalar(value, dstType, rewriter);
+      if (!newAttr)
+        return rewriter.notifyMatchFailure(
+            op, "expected integer or float target type for constant");
+    } else if (auto denseAttr = dyn_cast<DenseIntOrFPElementsAttr>(value)) {
+      auto dstShaped = dyn_cast<ShapedType>(dstType);
+      if (!dstShaped)
+        return rewriter.notifyMatchFailure(
+            op, "expected shaped type for dense constant");
+      Type oldEltType = denseAttr.getElementType();
+      Type newEltType = dstShaped.getElementType();
+      if (!oldEltType.isIntOrFloat() || !newEltType.isIntOrFloat() ||
+          oldEltType.getIntOrFloatBitWidth() !=
+              newEltType.getIntOrFloatBitWidth())
+        return rewriter.notifyMatchFailure(
+            op, "unsupported target element type in dense constant");
+      // Reinterpret the raw element data (also keeps splats compact).
+      newAttr = denseAttr.bitcast(newEltType);
+    }
+
+    if (!newAttr)
+      return rewriter.notifyMatchFailure(
+          op, "unsupported constant type for source to target conversion");
+
+    Value newConstOp = rewriter.create<arith::ConstantOp>(
+        op.getLoc(), dstType, cast<TypedAttr>(newAttr));
+    auto bitcastOp =
+        rewriter.create<arith::BitcastOp>(op.getLoc(), srcType, newConstOp);
+    rewriter.replaceOp(op, bitcastOp.getResult());
+    return success();
+  }
+
+private:
+  /// Reinterpret a scalar IntegerAttr/FloatAttr as an attribute of `dstType`
+  /// with the same bit pattern. Returns a null attribute when `dstType` is not
+  /// an integer or float type, or when the bitwidths differ.
+  static Attribute reinterpretScalar(Attribute value, Type dstType,
+                                     ConversionPatternRewriter &rewriter) {
+    if (!dstType.isIntOrFloat())
+      return {};
+    APInt bits;
+    if (auto intAttr = dyn_cast<IntegerAttr>(value))
+      bits = intAttr.getValue();
+    else
+      bits = cast<FloatAttr>(value).getValue().bitcastToAPInt();
+    if (bits.getBitWidth() != dstType.getIntOrFloatBitWidth())
+      return {};
+    // Float destinations (including float -> float) must reinterpret the
+    // bits; reusing the source APFloat would carry the wrong semantics.
+    if (auto dstFloat = dyn_cast<FloatType>(dstType))
+      return rewriter.getFloatAttr(
+          dstType, bitcastAPIntToAPFloat(bits, dstFloat.getFloatSemantics()));
+    return rewriter.getIntegerAttr(dstType, bits);
+  }
+
+  TypeConverter &typeConverter; // Store a reference.
+};
+
+//===----------------------------------------------------------------------===//
+// GenericOp conversion pattern
+//===----------------------------------------------------------------------===//
+/// Generic fallback: recreate any op with converted result types and the
+/// adaptor-provided (converted) operands, keeping its attributes and
+/// successors, moving its regions over and converting their block argument
+/// types.
+struct ConvertOpWithSourceType final : ConversionPattern {
+  ConvertOpWithSourceType(MLIRContext *context,
+                          const TypeConverter &typeConverter,
+                          ArrayRef<Type> /*sourceTypes*/,
+                          ArrayRef<Type> /*targetTypes*/)
+      : ConversionPattern(typeConverter, MatchAnyOpTypeTag{}, 1, context) {}
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    SmallVector<Type, 4> newResultTypes;
+    if (failed(typeConverter->convertTypes(op->getResultTypes(),
+                                           newResultTypes)))
+      return failure();
+
+    // Clone the op manually with the converted result types.
+    OperationState state(op->getLoc(), op->getName());
+    state.addOperands(operands);
+    state.addTypes(newResultTypes);
+    state.addAttributes(op->getAttrs());
+    // Branch-like ops must keep their targets to remain valid terminators.
+    state.addSuccessors(op->getSuccessors());
+    for (unsigned i = 0, e = op->getNumRegions(); i < e; ++i)
+      state.addRegion();
+
+    Operation *newOp = rewriter.create(state);
+    for (auto [oldRegion, newRegion] :
+         llvm::zip_equal(op->getRegions(), newOp->getRegions())) {
+      if (oldRegion.empty())
+        continue;
+      // Move the body through the rewriter so the conversion can roll it back.
+      rewriter.inlineRegionBefore(oldRegion, newRegion, newRegion.end());
+      if (failed(rewriter.convertRegionTypes(&newRegion, *typeConverter)))
+        return rewriter.notifyMatchFailure(op,
+                                           "region type conversion failed");
+    }
+
+    rewriter.replaceOp(op, newOp->getResults());
+    return success();
+  }
+};
+
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Type Converter
+//===----------------------------------------------------------------------===//
+
+/// Registers a conversion on `typeConverter` that replaces each type in
+/// `sourceTypes` with the type at the same position in `targetTypes`. This
+/// applies to scalar int/index/float types and to the element type of memref
+/// and vector types. Every other type converts to itself.
+///
+/// Also registers source, target and argument materializations:
+/// - memref-to-memref casts become `unrealized_conversion_cast`;
+/// - scalar/vector casts become `arith.bitcast`.
+void mlir::populateImitateUnsupportedTypesTypeConverter(
+    TypeConverter &typeConverter, ArrayRef<Type> sourceTypes,
+    ArrayRef<Type> targetTypes) {
+  assert(sourceTypes.size() == targetTypes.size() &&
+         "Source and target types must have same size");
+
+  // The conversion callback runs after this function returns, so it must own
+  // its copies of the type lists.
+  auto srcTypes = SmallVector<Type>(sourceTypes);
+  auto tgtTypes = SmallVector<Type>(targetTypes);
+
+  // Returns the imitating type registered for `type`, or a null Type if `type`
+  // is not one of the source types.
+  auto lookupTarget = [srcTypes, tgtTypes](Type type) -> Type {
+    for (auto [src, tgt] : llvm::zip_equal(srcTypes, tgtTypes))
+      if (type == src)
+        return tgt;
+    return Type();
+  };
+
+  typeConverter.addConversion([lookupTarget](Type type) -> Type {
+    if (type.isIntOrIndexOrFloat()) {
+      if (Type tgt = lookupTarget(type))
+        return tgt;
+    } else if (auto memref = llvm::dyn_cast<MemRefType>(type)) {
+      if (Type tgt = lookupTarget(memref.getElementType()))
+        return MemRefType::get(memref.getShape(), tgt, memref.getLayout(),
+                               memref.getMemorySpace());
+    } else if (auto vec = llvm::dyn_cast<VectorType>(type)) {
+      // Pass the scalable-dimension flags through. Without them,
+      // VectorType::get would build a fixed-length vector.
+      if (Type tgt = lookupTarget(vec.getElementType()))
+        return VectorType::get(vec.getShape(), tgt, vec.getScalableDims());
+    }
+    return type;
+  });
+
+  // Bridges an original type and its imitation:
+  // - memrefs with an unrealized_conversion_cast;
+  // - scalars/vectors with arith.bitcast, which requires equal bit widths.
+  // Returns null for any other combination.
+  auto materializeCast = [](OpBuilder &builder, Type resultType,
+                            ValueRange inputs, Location loc) -> Value {
+    assert(inputs.size() == 1 && "Expected single input");
+    Type inputType = inputs[0].getType();
+    if (isa<MemRefType>(resultType) && isa<MemRefType>(inputType)) {
+      return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
+          .getResult(0);
+    }
+    if ((resultType.isIntOrIndexOrFloat() || isa<VectorType>(resultType)) &&
+        (inputType.isIntOrIndexOrFloat() || isa<VectorType>(inputType))) {
+      return builder.create<arith::BitcastOp>(loc, resultType, inputs[0])
+          .getResult();
+    }
+    return nullptr;
+  };
+
+  typeConverter.addSourceMaterialization(materializeCast);
+  typeConverter.addTargetMaterialization(materializeCast);
+  typeConverter.addArgumentMaterialization(materializeCast);
+}
+
+//===----------------------------------------------------------------------===//
+// Pattern Population
+//===----------------------------------------------------------------------===//
+
+/// Adds the type-imitation conversion patterns to `patterns`. The function and
+/// call patterns share `convertedFuncTypes`.
+///
+/// NOTE: patterns may keep non-owning views of `sourceTypes`/`targetTypes`, so
+/// the storage behind them should stay alive until the conversion has run.
+void mlir::populateImitateUnsupportedTypesConversionPatterns(
+    RewritePatternSet &patterns, TypeConverter &typeConverter,
+    ArrayRef<Type> sourceTypes, ArrayRef<Type> targetTypes,
+    DenseMap<StringAttr, FunctionType> &convertedFuncTypes) {
+  assert(sourceTypes.size() == targetTypes.size() &&
+         "Source and target types must have same size");
+  MLIRContext *ctx = patterns.getContext();
+
+  // Forward the caller's lists rather than copies local to this function.
+  // Local copies are destroyed on return while a pattern may still hold a view
+  // of them.
+  patterns.add<ConvertOpWithSourceType>(ctx, typeConverter, sourceTypes,
+                                        targetTypes);
+  patterns.add<ConvertFuncOp<gpu::GPUFuncOp>, ConvertFuncOp<func::FuncOp>>(
+      ctx, typeConverter, sourceTypes, targetTypes, convertedFuncTypes);
+  patterns.add<ConvertCallOp>(ctx, typeConverter, convertedFuncTypes);
+  patterns.add<ConvertArithConstantOp>(ctx, typeConverter, sourceTypes,
+                                       targetTypes);
+  patterns.add<ConvertGPULaunchFuncOp>(ctx);
+  patterns.add<ConvertAllocOp<gpu::AllocOp>>(ctx, typeConverter);
+  patterns.add<ConvertAllocOp<memref::AllocOp>>(ctx, typeConverter);
+}
+
+//===----------------------------------------------------------------------===//
+// Conversion Legality configuration
+//===----------------------------------------------------------------------===//
+
+/// Registers the type-imitation legality rules on `target`.
+///
+/// Most rules apply only to IR nested in a gpu.module. gpu dialect ops and
+/// arith.constant are checked everywhere. A host memref.alloc is checked only
+/// when one of its users is a gpu.launch_func.
+///
+/// `typeConverter` is captured by reference, so it must outlive `target`.
+void mlir::configureImitateUnsupportedTypesLegality(
+    ConversionTarget &target, TypeConverter &typeConverter) {
+  // True when `op` has a gpu.module among its ancestors.
+  auto isDeviceOp = [](Operation *op) -> bool {
+    return static_cast<bool>(op->getParentOfType<gpu::GPUModuleOp>());
+  };
+  // Inside a gpu.module an op is legal only if the type converter accepts its
+  // types. Outside of one (host code) the op is always legal.
+  auto legalOnDevice = [&typeConverter, isDeviceOp](Operation *op) -> bool {
+    return !isDeviceOp(op) || typeConverter.isLegal(op);
+  };
+
+  target.addLegalDialect<arith::ArithDialect>();
+  target.addLegalDialect<math::MathDialect>();
+  // memref and func ops are only checked when they are device code.
+  target.addDynamicallyLegalDialect<memref::MemRefDialect>(legalOnDevice);
+  target.addDynamicallyLegalDialect<gpu::GPUDialect>(
+      [&typeConverter](Operation *op) { return typeConverter.isLegal(op); });
+  target.addDynamicallyLegalDialect<func::FuncDialect>(legalOnDevice);
+
+  // Containers and conversion glue are always legal.
+  target.addLegalOp<gpu::GPUModuleOp>();
+  target.addLegalOp<UnrealizedConversionCastOp>();
+  // Arithmetic-performing vector ops are legal regardless of element type.
+  target.addLegalOp<vector::ContractionOp, vector::ReductionOp,
+                    vector::MultiDimReductionOp, vector::FMAOp,
+                    vector::OuterProductOp, vector::MatmulOp, vector::ScanOp,
+                    vector::SplatOp>();
+
+  target.addDynamicallyLegalOp<arith::ConstantOp>(
+      [&typeConverter](arith::ConstantOp op) {
+        return typeConverter.isLegal(op.getType());
+      });
+  target.addDynamicallyLegalOp<gpu::GPUFuncOp>(
+      [&typeConverter](gpu::GPUFuncOp op) {
+        return typeConverter.isSignatureLegal(op.getFunctionType());
+      });
+  // Functions and function calls are only checked inside a gpu.module.
+  target.addDynamicallyLegalOp<func::FuncOp>(
+      [&typeConverter, isDeviceOp](func::FuncOp op) {
+        return !isDeviceOp(op.getOperation()) ||
+               typeConverter.isSignatureLegal(op.getFunctionType());
+      });
+  target.addDynamicallyLegalOp<func::CallOp>(
+      [&typeConverter, isDeviceOp](func::CallOp op) {
+        return !isDeviceOp(op.getOperation()) ||
+               typeConverter.isSignatureLegal(op.getCalleeType());
+      });
+
+  // A memref.alloc is checked when it is device code, or when a host
+  // allocation has a gpu.launch_func among its users.
+  target.addDynamicallyLegalOp<memref::AllocOp>(
+      [&typeConverter, isDeviceOp](memref::AllocOp op) {
+        auto isLaunch = [](Operation *user) {
+          return isa<gpu::LaunchFuncOp>(user);
+        };
+        if (isDeviceOp(op.getOperation()) ||
+            llvm::any_of(op->getUsers(), isLaunch))
+          return typeConverter.isLegal(op.getType());
+        return true;
+      });
+
+  // Unknown ops inside a gpu.module that take at least one memref operand
+  // must have legal types. Every other unknown op, on device or host, is
+  // legal.
+  target.markUnknownOpDynamicallyLegal(
+      [&typeConverter, isDeviceOp](Operation *op) -> bool {
+        if (!isDeviceOp(op))
+          return true;
+        bool takesMemRef = llvm::any_of(op->getOperandTypes(), [](Type type) {
+          return isa<MemRefType>(type);
+        });
+        return !takesMemRef || typeConverter.isLegal(op);
+      });
+}
+
+//===----------------------------------------------------------------------===//
+// Pass Definition
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// The pass works in three steps:
+/// 1. Parses the source/target type options.
+/// 2. Runs the type-imitation partial conversion, pairing each source type
+///    with the target type at the same position.
+/// 3. Reports any source-to-target `unrealized_conversion_cast` the
+///    conversion left behind.
+struct GpuImitateUnsupportedTypesPass
+    : public impl::GpuImitateUnsupportedTypesBase<
+          GpuImitateUnsupportedTypesPass> {
+  using Base::Base;
+
+  void runOnOperation() override {
+    MLIRContext *ctx = &getContext();
+    Operation *op = getOperation();
+
+    SmallVector<Type> sourceTypes;
+    SmallVector<Type> targetTypes;
+
+    // Parse source types.
+    for (StringRef sourceTypeStr : sourceTypeStrs) {
+      std::optional<Type> maybeSourceType =
+          arith::parseIntOrFloatType(ctx, sourceTypeStr);
+
+      if (!maybeSourceType) {
+        emitError(UnknownLoc::get(ctx),
+                  "could not map source type '" + sourceTypeStr +
+                      "' to a known integer or floating-point type.");
+        return signalPassFailure();
+      }
+      sourceTypes.push_back(*maybeSourceType);
+    }
+    if (sourceTypes.empty()) {
+      (void)emitOptionalWarning(
+          std::nullopt,
+          "no source types specified, type imitation will do nothing");
+    }
+
+    // Parse target types. A target type may not also appear as a source type.
+    for (StringRef targetTypeStr : targetTypeStrs) {
+      std::optional<Type> maybeTargetType =
+          arith::parseIntOrFloatType(ctx, targetTypeStr);
+
+      if (!maybeTargetType) {
+        emitError(UnknownLoc::get(ctx),
+                  "could not map target type '" + targetTypeStr +
+                      "' to a known integer or floating-point type");
+        return signalPassFailure();
+      }
+      targetTypes.push_back(*maybeTargetType);
+
+      if (llvm::is_contained(sourceTypes, *maybeTargetType)) {
+        emitError(UnknownLoc::get(ctx),
+                  "target type cannot be an unsupported source type");
+        return signalPassFailure();
+      }
+    }
+    if (targetTypes.empty()) {
+      (void)emitOptionalWarning(
+          std::nullopt,
+          "no target types specified, type imitation will do nothing");
+    }
+
+    // Source and target types are paired by position and zipped together
+    // below. Report a length mismatch as a diagnostic instead of relying on
+    // the debug-only assertions in the populate functions.
+    if (sourceTypes.size() != targetTypes.size()) {
+      emitError(UnknownLoc::get(ctx))
+          << "number of source types (" << sourceTypes.size()
+          << ") must match number of target types (" << targetTypes.size()
+          << ")";
+      return signalPassFailure();
+    }
+
+    // Imitation is implemented with arith.bitcast, which requires both types
+    // of a pair to have the same bit width.
+    for (auto [sourceType, targetType] :
+         llvm::zip_equal(sourceTypes, targetTypes)) {
+      if (sourceType.isIntOrFloat() && targetType.isIntOrFloat() &&
+          sourceType.getIntOrFloatBitWidth() !=
+              targetType.getIntOrFloatBitWidth()) {
+        emitError(UnknownLoc::get(ctx))
+            << "source type " << sourceType << " and target type "
+            << targetType << " must have the same bit width";
+        return signalPassFailure();
+      }
+    }
+
+    // Set up the type converter.
+    TypeConverter typeConverter;
+    populateImitateUnsupportedTypesTypeConverter(typeConverter, sourceTypes,
+                                                 targetTypes);
+
+    // Populate the conversion patterns. `sourceTypes`, `targetTypes` and
+    // `convertedFuncTypes` all stay alive until the conversion has finished.
+    RewritePatternSet patterns(ctx);
+    DenseMap<StringAttr, FunctionType> convertedFuncTypes;
+    populateImitateUnsupportedTypesConversionPatterns(
+        patterns, typeConverter, sourceTypes, targetTypes, convertedFuncTypes);
+
+    // Set up the conversion target and configure the legality of the
+    // conversion.
+    ConversionTarget target(*ctx);
+    configureImitateUnsupportedTypesLegality(target, typeConverter);
+
+    // Apply the conversion. If it fails, there is nothing left to validate.
+    if (failed(applyPartialConversion(op, target, std::move(patterns))))
+      return signalPassFailure();
+
+    // Post-conversion validation: report every remaining
+    // source->target unrealized_conversion_cast.
+    bool hasUnresolvedCast = false;
+    op->walk([&](UnrealizedConversionCastOp castOp) {
+      // Only 1:1 casts can be leftover materializations of this pass. The
+      // check also avoids indexing into empty operand or result lists.
+      if (castOp->getNumOperands() != 1 || castOp->getNumResults() != 1)
+        return;
+      Type fromElemType = getElementTypeOrSelf(castOp.getOperand(0).getType());
+      Type toElemType = getElementTypeOrSelf(castOp.getResult(0).getType());
+      for (auto [sourceType, targetType] :
+           llvm::zip_equal(sourceTypes, targetTypes)) {
+        if (fromElemType == sourceType && toElemType == targetType) {
+          castOp->emitError("unresolved unrealized_conversion_cast left in IR "
+                            "after conversion");
+          hasUnresolvedCast = true;
+          // One diagnostic per cast is enough.
+          break;
+        }
+      }
+    });
+
+    if (hasUnresolvedCast)
+      signalPassFailure();
+  }
+};
+} // namespace
diff --git a/mlir/test/Dialect/GPU/imitate-unsupported-types.mlir b/mlir/test/Dialect/GPU/imitate-unsupported-types.mlir
new file mode 100644
index 0000000000000..8279a2e4594b1
--- /dev/null
+++ b/mlir/test/Dialect/GPU/imitate-unsupported-types.mlir
@@ -0,0 +1,141 @@
+// RUN: mlir-opt -verify-diagnostics -imitate-unsupported-types="source-types=bf16 target-types=i16" --canonicalize -split-input-file %s | FileCheck %s
+
+// Kernel-side imitation. Inside the gpu.module, bf16 memref/vector arguments
+// of a gpu.func become i16. Values that are still computed in bf16 are bridged
+// to and from i16 with arith.bitcast.
+// CHECK: module @builtin_module
+module @builtin_module {
+ // CHECK: gpu.module @gpu_func_module {
+ gpu.module @gpu_func_module attributes{} {
+ // CHECK-LABEL: gpu.func @arith_and_vector_ops
+ // CHECK-SAME: (%[[ARG0:.*]]: memref<10x10xi16>, %[[ARG1:.*]]: memref<10x10xf32>, %[[ARG2:.*]]: vector<10x10xi16>, %[[ARG3:.*]]: memref<10x10xi16>, %[[ARG4:.*]]: vector<10x10xi16>) kernel
+ gpu.func @arith_and_vector_ops(%arg0: memref<10x10xbf16>, %arg1: memref<10x10xf32>, %arg2: vector<10x10xbf16>, %arg3: memref<10x10xi16>, %arg4: vector<10x10xi16>) kernel attributes {} {
+
+ %c0 = arith.constant 0 : index
+
+ // CHECK: %[[ARG2_CAST:.*]] = arith.bitcast %[[ARG2]] : vector<10x10xi16> to vector<10x10xbf16>
+ // CHECK: %[[LOAD1:.*]] = vector.load %[[ARG0]][%c0, %c0] : memref<10x10xi16>, vector<10x10xi16>
+ // CHECK: %[[BITCAST1:.*]] = arith.bitcast %[[LOAD1]] : vector<10x10xi16> to vector<10x10xbf16>
+ %2 = vector.load %arg0[%c0, %c0] : memref<10x10xbf16>, vector<10x10xbf16>
+
+ // CHECK: %[[ADDF:.*]] = arith.addf %[[BITCAST1]], %[[ARG2_CAST]] : vector<10x10xbf16>
+ %add = arith.addf %2, %arg2 : vector<10x10xbf16>
+
+ // CHECK: %[[EXTF1:.*]] = arith.extf %[[BITCAST1]] : vector<10x10xbf16> to vector<10x10xf32>
+ %3 = arith.extf %2 : vector<10x10xbf16> to vector<10x10xf32>
+
+ // CHECK: %[[EXTF2:.*]] = arith.extf %[[ADDF]] : vector<10x10xbf16> to vector<10x10xf32>
+ %4 = arith.extf %add : vector<10x10xbf16> to vector<10x10xf32>
+
+ // CHECK: %[[ADDF2:.*]] = arith.addf %[[EXTF1]], %[[EXTF2]] : vector<10x10xf32>
+ %5 = arith.addf %3, %4 : vector<10x10xf32>
+
+ // CHECK: %[[TRUNCF:.*]] = arith.truncf %[[ADDF2]] : vector<10x10xf32> to vector<10x10xbf16>
+ %6 = arith.truncf %5 : vector<10x10xf32> to vector<10x10xbf16>
+
+ // CHECK: %[[TRUNCF_CAST:.*]] = arith.bitcast %[[TRUNCF]] : vector<10x10xbf16> to vector<10x10xi16>
+ // CHECK: vector.store %[[TRUNCF_CAST]], %[[ARG0]][%c0, %c0] : memref<10x10xi16>, vector<10x10xi16>
+ vector.store %6, %arg0[%c0, %c0] : memref<10x10xbf16>, vector<10x10xbf16>
+
+ // Values whose type is already i16 pass through unchanged (no bitcasts).
+ // CHECK: %[[LOAD2:.*]] = vector.load %[[ARG3]][%c0, %c0] : memref<10x10xi16>, vector<10x10xi16>
+ %7 = vector.load %arg3[%c0, %c0] : memref<10x10xi16>, vector<10x10xi16>
+
+ // CHECK: %[[ADDI:.*]] = arith.addi %[[LOAD2]], %[[ARG4]] : vector<10x10xi16>
+ %8 = arith.addi %7, %arg4 : vector<10x10xi16>
+
+ // CHECK: vector.store %[[ADDI]], %[[ARG3]][%c0, %c0] : memref<10x10xi16>, vector<10x10xi16>
+ vector.store %8, %arg3[%c0, %c0] : memref<10x10xi16>, vector<10x10xi16>
+
+ gpu.return
+ }
+ }
+}
+
+// -----
+
+
+// Device calls, callee signatures and constants inside the gpu.module are
+// converted from bf16 to i16. Host code keeps bf16, except for the values
+// handed to gpu.launch_func.
+// CHECK: module @caller_callee_launch_func_module attributes {gpu.container_module}
+module @caller_callee_launch_func_module attributes {gpu.container_module} {
+
+ // CHECK: gpu.module @caller_callee_gpu_module {
+ gpu.module @caller_callee_gpu_module attributes{} {
+
+ // CHECK: gpu.func @caller_func(%[[ARG0:.*]]: memref<10x10xi16>, %[[ARG1:.*]]: vector<10x10xi16>) kernel {
+ gpu.func @caller_func(%arg0: memref<10x10xbf16>, %arg1: vector<10x10xbf16>) kernel attributes {} {
+ %c0 = arith.constant 0 : index
+
+ // CHECK: %[[CALL_RET:.*]] = func.call @callee_constant_return() : () -> vector<10x10xi16>
+ %func_result = func.call @callee_constant_return() : () -> vector<10x10xbf16>
+
+ // CHECK: vector.store %[[CALL_RET]], %[[ARG0]][%c0, %c0] : memref<10x10xi16>, vector<10x10xi16>
+ vector.store %func_result, %arg0[%c0, %c0] : memref<10x10xbf16>, vector<10x10xbf16>
+
+ // CHECK: func.call @callee_func(%[[CALL_RET]]) : (vector<10x10xi16>) -> ()
+ func.call @callee_func(%func_result) : (vector<10x10xbf16>) -> ()
+
+ gpu.return
+ }
+
+ // CHECK: func.func @callee_constant_return() -> vector<10x10xi16> {
+ func.func @callee_constant_return() -> vector<10x10xbf16> {
+ // CHECK: arith.constant dense<16128> : vector<10x10xi16>
+ %dense_const = arith.constant dense<5.000000e-01> : vector<10x10xbf16>
+ func.return %dense_const : vector<10x10xbf16>
+ }
+
+ // CHECK: func.func @callee_func(%[[ARG:.*]]: vector<10x10xi16>) {
+ func.func @callee_func(%arg0: vector<10x10xbf16>) {
+ return
+ }
+ }
+
+ // Host function: its signature stays bf16. The constant passed to the kernel
+ // becomes an i16 constant, and the gpu.alloc passed to the kernel becomes an
+ // i8 buffer viewed as i16. %dense_const_2, which is only stored into the
+ // allocation, keeps its bf16 type.
+ // CHECK: func.func @gpu_launch_func(%[[ARG0:.*]]: memref<10x10xbf16>, %[[ARG1:.*]]: vector<10x10xbf16>) {
+ func.func @gpu_launch_func(%arg0: memref<10x10xbf16>, %arg1: vector<10x10xbf16>) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ // CHECK: arith.constant dense<16128> : vector<10x10xi16>
+ %dense_const = arith.constant dense<5.000000e-01> : vector<10x10xbf16>
+ // CHECK: arith.constant dense<6.015630e-01> : vector<10x10xbf16>
+ %dense_const_2 = arith.constant dense<6.000000e-01> : vector<10x10xbf16>
+
+ // CHECK: %[[ALLOC:.*]] = gpu.alloc () : memref<200xi8>
+ %alloc = gpu.alloc () : memref<10x10xbf16>
+
+ vector.store %dense_const_2, %alloc[%c0, %c0] : memref<10x10xbf16>, vector<10x10xbf16>
+ // CHECK: %[[VIEW:.*]] = memref.view %[[ALLOC]][%c0][] : memref<200xi8> to memref<10x10xi16>
+ // CHECK: gpu.launch_func @caller_callee_gpu_module::@caller_func blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%[[VIEW]] : memref<10x10xi16>, %[[CST:.*]] : vector<10x10xi16>)
+ gpu.launch_func @caller_callee_gpu_module::@caller_func
+ blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1)
+ args(%alloc: memref<10x10xbf16>, %dense_const: vector<10x10xbf16>)
+ return
+ }
+}
+
+// -----
+
+// Alloc ops are only supported when they are in the same region as the launch
+// op. Otherwise, type converter materialization leaves an unresolved
+// unrealized_conversion_cast in the IR, which the pass reports as an error.
+module @unsupported_module attributes {gpu.container_module} {
+ gpu.module @unsupported_gpu_module attributes{} {
+ gpu.func @kernel(%arg0: memref<10x10xbf16>) kernel attributes {} {
+ gpu.return
+ }
+ }
+
+ // The memref reaching gpu.launch_func is a function argument, not an
+ // allocation made next to the launch, so it cannot be rewritten and a cast
+ // is left behind.
+ func.func @gpu_launch_func(%arg0: memref<10x10xbf16>) {
+ %c1 = arith.constant 1 : index
+ // expected-error@+1 {{unresolved unrealized_conversion_cast left in IR after conversion}}
+ gpu.launch_func @unsupported_gpu_module::@kernel
+ blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1)
+ args(%arg0: memref<10x10xbf16>)
+ return
+ }
+
+ // This host allocation has no gpu.launch_func user, so it keeps its bf16 type.
+ func.func @main() {
+ %alloc = memref.alloc () : memref<10x10xbf16>
+ call @gpu_launch_func(%alloc) : (memref<10x10xbf16>) -> ()
+ memref.dealloc %alloc : memref<10x10xbf16>
+ return
+ }
+}
+
+// -----
+
>From 87f80dc6a99a3089f2cf75608ac20bd3729bcbb2 Mon Sep 17 00:00:00 2001
From: "Shahneous Bari, Md Abdullah" <md.abdullah.shahneous.bari at intel.com>
Date: Thu, 1 May 2025 06:34:19 +0000
Subject: [PATCH 2/2] Add the updated CMake file, missed in the first commit.
---
mlir/lib/Dialect/GPU/CMakeLists.txt | 5 +++--
1 file changed, 3 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/GPU/CMakeLists.txt b/mlir/lib/Dialect/GPU/CMakeLists.txt
index e21fa501bae6b..6d63f0d79e7d2 100644
--- a/mlir/lib/Dialect/GPU/CMakeLists.txt
+++ b/mlir/lib/Dialect/GPU/CMakeLists.txt
@@ -23,7 +23,7 @@ add_mlir_dialect_library(MLIRGPUDialect
MLIRMemRefDialect
MLIRSideEffectInterfaces
MLIRSupport
- )
+)
add_mlir_dialect_library(MLIRGPUTransforms
Transforms/AllReduceLowering.cpp
@@ -42,6 +42,7 @@ add_mlir_dialect_library(MLIRGPUTransforms
Transforms/SPIRVAttachTarget.cpp
Transforms/SubgroupIdRewriter.cpp
Transforms/SubgroupReduceLowering.cpp
+ Transforms/ImitateUnsupportedTypes.cpp
OBJECT
@@ -76,7 +77,7 @@ add_mlir_dialect_library(MLIRGPUTransforms
MLIRROCDLTarget
MLIRTransformUtils
MLIRVectorDialect
- )
+)
add_subdirectory(TransformOps)
add_subdirectory(Pipelines)
More information about the Mlir-commits
mailing list