[Mlir-commits] [mlir] [mlir][gpu] Add pass for emulating unsupported types. (PR #138087)

Md Abdullah Shahneous Bari llvmlistbot at llvm.org
Thu May 1 13:15:40 PDT 2025


================
@@ -0,0 +1,915 @@
+//===- ImitateUnsupportedTypes.cpp - Unsupported Type Imitation -*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//===----------------------------------------------------------------------===//
+//
+/// \file
+/// This pass imitates unsupported types with supported types of the same
+/// bitwidth. The imitation is done by bitcasting (similar to C++
+/// reinterpret_cast) the unsupported types to the supported types, so the
+/// source type and destination type must have the same bitwidth. Scalar and
+/// vector values are reinterpreted with `arith.bitcast`; memref allocations
+/// are reinterpreted through `memref.view`s of a shared raw `i8` buffer.
+///
+/// The imitation is often needed when the GPU target (dialect/IR) does not
+/// support a certain type but the underlying architecture does. Take SPIR-V for
+/// example, it does not support bf16, but an underlying architecture (e.g.,
+/// intel pvc gpu) that uses SPIR-V for code-generation does.
+/// Therefore, bf16 is neither a valid data type to pass to a GPU kernel nor a
+/// valid type to use inside the kernel. To use bf16 data in a SPIR-V kernel
+/// (as a kernel parameter or inside the kernel), bf16 values have to be
+/// bitcast (similar to C++ reinterpret_cast) to a supported type of the same
+/// bitwidth (e.g., i16 for Intel GPUs). The SPIR-V kernel can then use the
+/// imitated type (i16) in the computation. However, i16 is not the same as
+/// bf16 (integer vs. float), so the computation cannot readily use the
+/// imitated type (i16).
+///
+/// Therefore, this transformation pass is intended to be used in conjunction
+/// with other transformation passes such as `EmulateUnsupportedFloats` and
+/// `ExtendUnsupportedTypes` that extend bf16 values to f32 and truncate them
+/// back.
+///
+/// Finally, usually, there are instructions available in the target
+/// (dialect/IR) that can take advantage of these generated patterns
+/// (bf16->i16->f32, f32->bf16->i16), and convert them to the supported
+/// types.
+/// For example, Intel provides SPIR-V extension ops that can
+/// take imitated bf16 (i16) and convert them to f32 and vice-versa.
+/// https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/INTEL/SPV_INTEL_bfloat16_conversion.asciidoc
+/// https://mlir.llvm.org/docs/Dialects/SPIR-V/#spirvintelconvertbf16tof-spirvintelconvertbf16tofop
+/// https://mlir.llvm.org/docs/Dialects/SPIR-V/#spirvintelconvertftobf16-spirvintelconvertftobf16op
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/GPU/Transforms/Passes.h"
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/APFloat.h"
+#include "llvm/ADT/APInt.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/ErrorHandling.h"
+
+#include <optional>
+#include <type_traits>
+#include <variant>
+
+using namespace mlir;
+using namespace mlir::gpu;
+
+namespace mlir {
+#define GEN_PASS_DEF_GPUIMITATEUNSUPPORTEDTYPES
+#include "mlir/Dialect/GPU/Transforms/Passes.h.inc"
+} // namespace mlir
+
+//===----------------------------------------------------------------------===//
+// Utility functions
+//===----------------------------------------------------------------------===//
+
+/// Reinterprets the bits of `intValue` as a floating-point value with the
+/// given `semantics` (the APInt/APFloat analogue of `std::bit_cast`).
+///
+/// The bit width of `intValue` must equal the storage size of `semantics`.
+/// Every word of the APInt is used, so types wider than 64 bits (f80, f128,
+/// ppc_fp128) are reinterpreted without truncation.
+APFloat bitcastAPIntToAPFloat(const APInt &intValue,
+                              const llvm::fltSemantics &semantics) {
+  // A bitcast is only meaningful when both sides have the same width.
+  assert(intValue.getBitWidth() == APFloat::getSizeInBits(semantics) &&
+         "Bitwidth of APInt and APFloat must match for bitcast");
+  // APFloat's (semantics, APInt) constructor consumes the full bit pattern;
+  // building an APInt from only the first raw word would drop upper bits.
+  return APFloat(semantics, intValue);
+}
+
+// Get FloatAttr from IntegerAttr: reinterprets the integer's bit pattern as a
+// float attribute of type `dstType`, which must be a FloatType whose bit width
+// equals that of `intAttr`.
+FloatAttr getFloatAttrFromIntegerAttr(IntegerAttr intAttr, Type dstType,
+                                      ConversionPatternRewriter &rewriter) {
+  const llvm::fltSemantics &dstSemantics =
+      cast<FloatType>(dstType).getFloatSemantics();
+  return rewriter.getFloatAttr(
+      dstType, bitcastAPIntToAPFloat(intAttr.getValue(), dstSemantics));
+}
+// Get IntegerAttr from FloatAttr: reinterprets the float's bit pattern as an
+// integer attribute of type `dstType` (the bit widths must match).
+IntegerAttr getIntegerAttrFromFloatAttr(FloatAttr floatAttr, Type dstType,
+                                        ConversionPatternRewriter &rewriter) {
+  const APInt rawBits = floatAttr.getValue().bitcastToAPInt();
+  return rewriter.getIntegerAttr(dstType, rawBits);
+}
+
+/// Emits the `index` arithmetic needed to describe a memref allocation as a
+/// flat byte buffer. All IR is created through `builder` at `loc`.
+struct RawAllocator {
+  RawAllocator(OpBuilder &builder, Location loc) : builder(builder), loc(loc) {}
+
+  /// Returns the number of bytes occupied by a memref of `srcType`.
+  ///
+  /// Each element is rounded up to a whole number of bytes. A fully static
+  /// shape yields a plain `int64_t`; otherwise IR computing the size from the
+  /// extents of `srcMemref` is emitted and its `index` result is returned.
+  std::variant<Value, int64_t> computeTotalBytes(MemRefType srcType,
+                                                 Value srcMemref) {
+    const int64_t bytesPerElement = (srcType.getElementTypeBitWidth() + 7) / 8;
+
+    if (!srcType.hasStaticShape()) {
+      SmallVector<Value> extents = getSizes(srcType, srcMemref);
+      Value elementCount = extents.front();
+      for (Value extent : llvm::drop_begin(extents))
+        elementCount =
+            builder.create<arith::MulIOp>(loc, elementCount, extent);
+      Value elementBytes =
+          builder.create<arith::ConstantIndexOp>(loc, bytesPerElement);
+      Value totalBytes =
+          builder.create<arith::MulIOp>(loc, elementCount, elementBytes);
+      return totalBytes;
+    }
+
+    int64_t elementCount = 1;
+    for (int64_t extent : srcType.getShape())
+      elementCount *= extent;
+    return elementCount * bytesPerElement;
+  }
+
+  /// Returns one `index` value per dimension of `type`: a `memref.dim` of
+  /// `memref` for dynamic dimensions and an `arith.constant` for static ones.
+  SmallVector<Value> getSizes(MemRefType type, Value memref) {
+    SmallVector<Value> extents;
+    ArrayRef<int64_t> shape = type.getShape();
+    for (unsigned dimIdx = 0; dimIdx < type.getRank(); ++dimIdx) {
+      Value extent;
+      if (!type.isDynamicDim(dimIdx))
+        extent = builder.create<arith::ConstantIndexOp>(loc, shape[dimIdx]);
+      else
+        extent = builder.create<memref::DimOp>(loc, memref, dimIdx);
+      extents.push_back(extent);
+    }
+    return extents;
+  }
+
+  /// Returns a `memref.dim` of `memref` for each dynamic dimension of `type`,
+  /// in dimension order; static dimensions are skipped.
+  SmallVector<Value> getDynamicSizes(MemRefType type, Value memref) {
+    SmallVector<Value> dynamicExtents;
+    for (unsigned dimIdx = 0; dimIdx < type.getRank(); ++dimIdx) {
+      if (!type.isDynamicDim(dimIdx))
+        continue;
+      dynamicExtents.push_back(
+          builder.create<memref::DimOp>(loc, memref, dimIdx));
+    }
+    return dynamicExtents;
+  }
+
+  /// Returns constant row-major (identity layout) strides for `type`.
+  ///
+  /// NOTE(review): dynamic dimensions are not handled. Once a dynamic
+  /// dimension is seen (walking from the innermost one outwards) the running
+  /// stride becomes -1, so every stride outside it is a meaningless constant.
+  SmallVector<Value> getIdentityStrides(MemRefType type) {
+    ArrayRef<int64_t> shape = type.getShape();
+    SmallVector<Value> strides(shape.size());
+    int64_t stride = 1;
+    for (int64_t idx = static_cast<int64_t>(shape.size()) - 1; idx >= 0;
+         --idx) {
+      strides[idx] = builder.create<arith::ConstantIndexOp>(loc, stride);
+      stride = ShapedType::isDynamic(shape[idx]) ? -1 : stride * shape[idx];
+    }
+    return strides;
+  }
+
+private:
+  OpBuilder &builder;
+  Location loc;
+};
+
+/// Redirects each use of `originalValue` to the value paired with the first
+/// predicate (in `replacements` order) that accepts that use. Uses accepted by
+/// no predicate are left untouched.
+///
+/// Every operand update goes through `rewriter.modifyOpInPlace`, so the
+/// dialect-conversion driver is notified of the change and can roll it back.
+template <typename OpTy>
+void replaceUsesWithPredicate(
+    OpTy originalValue,
+    ArrayRef<std::pair<std::function<bool(OpOperand &)>, Value>> replacements,
+    ConversionPatternRewriter &rewriter) {
+  // Early-inc: setting an operand unlinks it from the use list being walked.
+  for (OpOperand &use : llvm::make_early_inc_range(originalValue->getUses())) {
+    for (const auto &replacement : replacements) {
+      if (!replacement.first(use))
+        continue;
+      // Plain local (not a structured binding) so the C++17 lambda can capture it.
+      Value newValue = replacement.second;
+      rewriter.modifyOpInPlace(use.getOwner(), [&] { use.set(newValue); });
+      break;
+    }
+  }
+}
+
+//===----------------------------------------------------------------------===//
+// Conversion patterns
+//===----------------------------------------------------------------------===//
+namespace {
+
+//===----------------------------------------------------------------------===//
+// FunctionOp conversion pattern
+//===----------------------------------------------------------------------===//
+/// Rewrites a function-like op (func.func / gpu.func) nested in a gpu.module
+/// so that its signature and entry-block arguments use the converted types.
+/// The converted FunctionType is recorded in `convertedFuncTypes`, keyed by
+/// the function's symbol name, for use by other patterns.
+template <typename FuncLikeOp>
+struct ConvertFuncOp final : public OpConversionPattern<FuncLikeOp> {
+  ConvertFuncOp(MLIRContext *context, TypeConverter &typeConverter,
+                ArrayRef<Type> sourceTypes, ArrayRef<Type> targetTypes,
+                DenseMap<StringAttr, FunctionType> &convertedFuncTypes)
+      : OpConversionPattern<FuncLikeOp>(context), typeConverter(typeConverter),
+        sourceTypes(sourceTypes), targetTypes(targetTypes),
+        convertedFuncTypes(convertedFuncTypes) {}
+
+  LogicalResult
+  matchAndRewrite(FuncLikeOp op, typename FuncLikeOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Only handle functions nested in a gpu.module.
+    if (!op->template getParentOfType<gpu::GPUModuleOp>())
+      return failure();
+    FunctionType oldFuncType = op.getFunctionType();
+
+    // Convert the argument types.
+    TypeConverter::SignatureConversion signatureConverter(
+        oldFuncType.getNumInputs());
+    for (const auto &argType : llvm::enumerate(oldFuncType.getInputs())) {
+      Type convertedType = typeConverter.convertType(argType.value());
+      if (!convertedType)
+        return failure();
+      signatureConverter.addInputs(argType.index(), convertedType);
+    }
+    // Convert the result types.
+    SmallVector<Type, 4> newResultTypes;
+    for (Type resultType : oldFuncType.getResults()) {
+      Type convertedType = typeConverter.convertType(resultType);
+      if (!convertedType)
+        return failure();
+      newResultTypes.push_back(convertedType);
+    }
+
+    FunctionType newFuncType = rewriter.getFunctionType(
+        signatureConverter.getConvertedTypes(), newResultTypes);
+    if (!newFuncType)
+      return rewriter.notifyMatchFailure(op, "could not convert function "
+                                             "type");
+
+    // Create the replacement function with the converted type.
+    auto newFuncOp =
+        rewriter.create<FuncLikeOp>(op.getLoc(), op.getName(), newFuncType);
+    newFuncOp.setVisibility(op.getVisibility());
+    // Copy every attribute except the function type (already set from
+    // `newFuncType`) and the symbol name (already set by the builder).
+    for (auto attr : op->getAttrs()) {
+      if (attr.getName() != op.getFunctionTypeAttrName() &&
+          attr.getName() != SymbolTable::getSymbolAttrName())
+        newFuncOp->setAttr(attr.getName(), attr.getValue());
+    }
+
+    // Drop any blocks the builder created, then move the old body over and
+    // convert its entry-block argument types.
+    newFuncOp.getRegion().getBlocks().clear();
+    rewriter.inlineRegionBefore(op.getBody(), newFuncOp.getBody(),
+                                newFuncOp.end());
+    if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), typeConverter,
+                                           &signatureConverter)))
+      return rewriter.notifyMatchFailure(op, "could not convert region "
+                                             "types");
+
+    // Function ops produce no SSA results; callers reference them by symbol,
+    // so erasing the old op leaves no dangling SSA uses.
+    rewriter.eraseOp(op);
+    convertedFuncTypes[newFuncOp.getNameAttr()] = newFuncType;
+    return success();
+  }
+
+private:
+  TypeConverter &typeConverter;
+  // Owning copies: the caller's arrays are not guaranteed to outlive this
+  // pattern.
+  SmallVector<Type> sourceTypes;
+  SmallVector<Type> targetTypes;
+  DenseMap<StringAttr, FunctionType> &convertedFuncTypes;
+};
+
+//===----------------------------------------------------------------------===//
+// CallOp conversion pattern
+//===----------------------------------------------------------------------===//
+/// Rewrites func.call so that its result types match the converted callee.
+///
+/// If the callee has already been converted, its recorded signature is used.
+/// Otherwise the call's result types are converted directly with the type
+/// converter, which yields the same signature. This means the pattern does not
+/// depend on whether the callee func.func was visited before the call.
+struct ConvertCallOp : OpConversionPattern<func::CallOp> {
+  ConvertCallOp(MLIRContext *context, TypeConverter &typeConverter,
+                const DenseMap<StringAttr, FunctionType> &convertedFuncTypes)
+      : OpConversionPattern(typeConverter, context),
+        convertedFuncTypes(convertedFuncTypes) {}
+
+  LogicalResult
+  matchAndRewrite(func::CallOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    FlatSymbolRefAttr callee = op.getCalleeAttr();
+
+    SmallVector<Type> newResultTypes;
+    auto it = convertedFuncTypes.find(callee.getAttr());
+    if (it != convertedFuncTypes.end()) {
+      llvm::append_range(newResultTypes, it->second.getResults());
+    } else if (failed(getTypeConverter()->convertTypes(op.getResultTypes(),
+                                                       newResultTypes))) {
+      return rewriter.notifyMatchFailure(
+          op, "could not convert the result types of the call");
+    }
+
+    // The base pattern owns the type converter, so the adaptor operands are
+    // already materialized to the converted types.
+    rewriter.replaceOpWithNewOp<func::CallOp>(
+        op, callee.getValue(), newResultTypes, adaptor.getOperands());
+    return success();
+  }
+
+private:
+  const DenseMap<StringAttr, FunctionType> &convertedFuncTypes;
+};
+
+//===----------------------------------------------------------------------===//
+// GPULaunchFuncOp conversion pattern
+//===----------------------------------------------------------------------===//
+/// Updates gpu.launch_func to consume the remapped (converted) operands.
+///
+/// Operands are replaced in place rather than rebuilding the op. This keeps
+/// the kernel symbol, attributes, async token result, async dependencies,
+/// async object, and cluster sizes exactly as they were. Rebuilding via the
+/// `asyncObject` builder would drop the async token and dependencies of an
+/// `async` launch.
+struct ConvertGPULaunchFuncOp : OpConversionPattern<gpu::LaunchFuncOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(gpu::LaunchFuncOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Operand count and segment structure are unchanged by a 1:1 remap, so
+    // the operand segment sizes stay valid.
+    rewriter.modifyOpInPlace(op,
+                             [&] { op->setOperands(adaptor.getOperands()); });
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// AllocOp conversion pattern
+//===----------------------------------------------------------------------===//
+/// Replaces an allocation of an unsupported element type with a raw `i8`
+/// buffer plus two `memref.view`s of it: one with the original type and one
+/// with the converted type.
+///
+///  - gpu.launch_func uses receive the converted-type view.
+///  - memref.dealloc / gpu.dealloc uses receive the raw buffer.
+///  - All other uses receive the original-type view.
+///  - For gpu.alloc, the async token is forwarded from the new allocation.
+template <typename AllocOp>
+struct ConvertAllocOp : OpConversionPattern<AllocOp> {
+  ConvertAllocOp(MLIRContext *ctx, TypeConverter &typeConverter)
+      : OpConversionPattern<AllocOp>(ctx), typeConverter(typeConverter) {}
+
+  LogicalResult
+  matchAndRewrite(AllocOp op, typename AllocOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    MemRefType srcType = llvm::cast<MemRefType>(op.getMemref().getType());
+    auto dstType =
+        dyn_cast_or_null<MemRefType>(typeConverter.convertType(srcType));
+    if (!dstType || dstType == srcType)
+      return failure(); // No need to rewrite.
+
+    // memref.view only supports identity layouts, so nothing else can be
+    // reinterpreted this way.
+    if (!srcType.getLayout().isIdentity())
+      return op.emitError("only memrefs with identity layout is supported");
+
+    // 1. Compute the allocation size in bytes. Dynamic extents come from the
+    //    alloc's own (remapped) size operands. Never take them from `memref.dim`
+    //    of the result being replaced: those dims would be redirected to the
+    //    view that consumes them and form a use cycle.
+    int64_t elemByteWidth = (srcType.getElementTypeBitWidth() + 7) / 8;
+    ValueRange dynamicSizes = adaptor.getDynamicSizes();
+    Type i8Type = rewriter.getI8Type();
+    MemRefType rawType;
+    SmallVector<Value, 1> rawDynamicSizes;
+    if (srcType.hasStaticShape()) {
+      rawType = MemRefType::get({srcType.getNumElements() * elemByteWidth},
+                                i8Type, MemRefLayoutAttrInterface{},
+                                srcType.getMemorySpace());
+    } else {
+      Value totalBytes =
+          rewriter.create<arith::ConstantIndexOp>(loc, elemByteWidth);
+      unsigned dynIdx = 0;
+      for (int64_t dim : srcType.getShape()) {
+        Value extent;
+        if (ShapedType::isDynamic(dim))
+          extent = dynamicSizes[dynIdx++];
+        else
+          extent = rewriter.create<arith::ConstantIndexOp>(loc, dim);
+        totalBytes = rewriter.create<arith::MulIOp>(loc, totalBytes, extent);
+      }
+      rawDynamicSizes.push_back(totalBytes);
+      rawType = MemRefType::get({ShapedType::kDynamic}, i8Type,
+                                MemRefLayoutAttrInterface{},
+                                srcType.getMemorySpace());
+    }
+
+    // 2. Allocate the raw i8 buffer with the same kind of allocation op.
+    Value rawAlloc;
+    Value rawAsyncToken;
+    if constexpr (std::is_same_v<AllocOp, gpu::AllocOp>) {
+      auto rawGpuAlloc = rewriter.create<gpu::AllocOp>(
+          loc, rawType,
+          op.getAsyncToken() ? op.getAsyncToken().getType() : nullptr,
+          adaptor.getAsyncDependencies(), rawDynamicSizes,
+          adaptor.getSymbolOperands(), op.getHostShared());
+      rawAlloc = rawGpuAlloc.getMemref();
+      rawAsyncToken = rawGpuAlloc.getAsyncToken();
+    } else {
+      rawAlloc = rewriter.create<memref::AllocOp>(
+          loc, rawType, rawDynamicSizes, adaptor.getSymbolOperands());
+    }
+
+    // 3. Views of the raw buffer with the original and the converted type.
+    Value zeroOffset = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    Value originalView = rewriter.create<memref::ViewOp>(
+        loc, srcType, rawAlloc, zeroOffset, dynamicSizes);
+    Value convertedView = rewriter.create<memref::ViewOp>(
+        loc, dstType, rawAlloc, zeroOffset, dynamicSizes);
+
+    // 4. Redirect kernel launches and deallocations. Only the memref result
+    //    is scanned, so async-token uses are never handed a memref.
+    SmallVector<OpOperand *> launchFuncUses;
+    SmallVector<OpOperand *> deallocUses;
+    for (OpOperand &use : op.getMemref().getUses()) {
+      Operation *user = use.getOwner();
+      if (isa<gpu::LaunchFuncOp>(user))
+        launchFuncUses.push_back(&use);
+      else if (isa<memref::DeallocOp, gpu::DeallocOp>(user))
+        deallocUses.push_back(&use);
+    }
+    for (OpOperand *use : launchFuncUses)
+      rewriter.modifyOpInPlace(use->getOwner(),
+                               [&] { use->set(convertedView); });
+    for (OpOperand *use : deallocUses)
+      rewriter.modifyOpInPlace(use->getOwner(), [&] { use->set(rawAlloc); });
+
+    // 5. Any remaining use sees the original type through `originalView`. A
+    //    gpu.alloc async token is replaced by the new allocation's token.
+    SmallVector<Value, 2> replacements = {originalView};
+    if (op->getNumResults() > 1)
+      replacements.push_back(rawAsyncToken);
+    rewriter.replaceOp(op, replacements);
+    return success();
+  }
+
+private:
+  TypeConverter &typeConverter;
+};
+
+//===----------------------------------------------------------------------===//
+// ArithConstantOp conversion pattern
+//===----------------------------------------------------------------------===//
+/// Rematerializes an arith.constant of an unsupported type as a constant of
+/// the converted type with the same bit pattern. The result is then
+/// `arith.bitcast` back to the original type for existing users.
+///
+/// Scalars (IntegerAttr/FloatAttr) and DenseElementsAttr are supported. Every
+/// int<->float and float<->float pairing is a pure bit reinterpretation;
+/// int->int keeps the integer value.
+struct ConvertArithConstantOp : OpConversionPattern<arith::ConstantOp> {
+  ConvertArithConstantOp(MLIRContext *context, TypeConverter &typeConverter,
+                         ArrayRef<Type> sourceTypes, ArrayRef<Type> targetTypes)
+      : OpConversionPattern(context),
+        typeConverter(typeConverter), // Store the reference.
+        sourceTypes(sourceTypes), targetTypes(targetTypes) {}
+
+  LogicalResult
+  matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Type srcType = op.getType();
+    Type dstType = typeConverter.convertType(srcType);
+    if (!dstType || dstType == srcType)
+      return failure();
+
+    Attribute value = op.getValue();
+    Attribute newAttr;
+    if (auto denseAttr = dyn_cast<DenseElementsAttr>(value)) {
+      auto shapedDstType = dyn_cast<ShapedType>(dstType);
+      if (!shapedDstType)
+        return rewriter.notifyMatchFailure(
+            op, "expected shaped type for dense constant");
+      Type newEltType = shapedDstType.getElementType();
+
+      SmallVector<Attribute> newValues;
+      newValues.reserve(denseAttr.getNumElements());
+      for (Attribute attr : denseAttr.getValues<Attribute>()) {
+        Attribute converted = bitcastScalarAttr(attr, newEltType, rewriter);
+        if (!converted)
+          return rewriter.notifyMatchFailure(
+              op, "unsupported target element type in dense constant");
+        newValues.push_back(converted);
+      }
+      newAttr = DenseElementsAttr::get(shapedDstType, newValues);
+    } else if (isa<IntegerAttr, FloatAttr>(value)) {
+      newAttr = bitcastScalarAttr(value, dstType, rewriter);
+      if (!newAttr)
+        return rewriter.notifyMatchFailure(
+            op, "expected integer or float target type for constant");
+    } else {
+      return rewriter.notifyMatchFailure(
+          op, "unsupported constant type for source to target conversion");
+    }
+
+    Value newConstOp = rewriter.create<arith::ConstantOp>(
+        op.getLoc(), dstType, cast<TypedAttr>(newAttr));
+    rewriter.replaceOpWithNewOp<arith::BitcastOp>(op, srcType, newConstOp);
+    return success();
+  }
+
+private:
+  /// Reinterprets the bits of an IntegerAttr/FloatAttr as an attribute of
+  /// `dstElemType`, which must be an IntegerType or FloatType of the same bit
+  /// width. Returns a null attribute for any other combination.
+  static Attribute bitcastScalarAttr(Attribute attr, Type dstElemType,
+                                     ConversionPatternRewriter &rewriter) {
+    APInt bits;
+    if (auto intAttr = dyn_cast<IntegerAttr>(attr))
+      bits = intAttr.getValue();
+    else if (auto floatAttr = dyn_cast<FloatAttr>(attr))
+      bits = floatAttr.getValue().bitcastToAPInt();
+    else
+      return Attribute();
+
+    // float destination: reinterpret the bits under the target semantics.
+    // Reusing the source APFloat would carry the wrong fltSemantics
+    // (e.g. bf16 -> f16).
+    if (auto floatType = dyn_cast<FloatType>(dstElemType))
+      return rewriter.getFloatAttr(
+          floatType,
+          bitcastAPIntToAPFloat(bits, floatType.getFloatSemantics()));
+    if (isa<IntegerType>(dstElemType))
+      return rewriter.getIntegerAttr(dstElemType, bits);
+    return Attribute();
+  }
+
+  TypeConverter &typeConverter; // Store a reference.
+  // Owning copies: the caller's arrays are not guaranteed to outlive this
+  // pattern.
+  SmallVector<Type> sourceTypes;
+  SmallVector<Type> targetTypes;
+};
+
+//===----------------------------------------------------------------------===//
+// GenericOp conversion pattern
+//===----------------------------------------------------------------------===//
+/// Generic fallback: recreates any illegal op with converted operands, result
+/// types, and nested region argument types. Attributes and successors are
+/// carried over unchanged.
+struct ConvertOpWithSourceType final : ConversionPattern {
+  ConvertOpWithSourceType(MLIRContext *context,
+                          const TypeConverter &typeConverter,
+                          ArrayRef<Type> sourceTypes,
+                          ArrayRef<Type> targetTypes)
+      : ConversionPattern(typeConverter, MatchAnyOpTypeTag{}, 1, context),
+        sourceTypes(sourceTypes), targetTypes(targetTypes) {}
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    SmallVector<Type, 4> newResultTypes;
+    for (Type t : op->getResultTypes()) {
+      Type converted = typeConverter->convertType(t);
+      if (!converted)
+        return failure();
+      newResultTypes.push_back(converted);
+    }
+
+    // Rebuild the op with the converted operands/result types. Successors
+    // must be copied too, otherwise branch terminators become invalid.
+    // Successor operands are already part of `operands`.
+    OperationState state(op->getLoc(), op->getName().getStringRef());
+    state.addOperands(operands);
+    state.addTypes(newResultTypes);
+    state.addAttributes(op->getAttrs());
+    state.addSuccessors(op->getSuccessors());
+    for (unsigned i = 0, e = op->getNumRegions(); i < e; ++i)
+      state.addRegion();
+
+    Operation *newOp = rewriter.create(state);
+    // Move each region through the rewriter (so the move can be rolled back)
+    // and convert its block argument types.
+    for (auto [oldRegion, newRegion] :
+         llvm::zip(op->getRegions(), newOp->getRegions())) {
+      if (oldRegion.empty())
+        continue;
+      rewriter.inlineRegionBefore(oldRegion, newRegion, newRegion.end());
+      if (failed(rewriter.convertRegionTypes(&newRegion, *typeConverter)))
+        return rewriter.notifyMatchFailure(op,
+                                           "region type conversion failed");
+    }
+
+    rewriter.replaceOp(op, newOp->getResults());
+    return success();
+  }
+
+private:
+  // Owning copies: the caller's arrays are not guaranteed to outlive this
+  // pattern.
+  SmallVector<Type> sourceTypes;
+  SmallVector<Type> targetTypes;
+};
+
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Type Converter
+//===----------------------------------------------------------------------===//
+
+/// Registers on `typeConverter` the mapping `sourceTypes[i] -> targetTypes[i]`
+/// for scalars, and for the element type of memrefs and vectors. Shape,
+/// layout, memory space, and vector scalable dims are preserved. Also
+/// registers source/target materializations: `unrealized_conversion_cast`
+/// between memrefs, and `arith.bitcast` between int/index/float/vector types.
+void mlir::populateImitateUnsupportedTypesTypeConverter(
+    TypeConverter &typeConverter, ArrayRef<Type> sourceTypes,
+    ArrayRef<Type> targetTypes) {
+  assert(sourceTypes.size() == targetTypes.size() &&
+         "Source and target types must have same size");
+  // Owning copies: the conversion callback outlives the caller's arrays.
+  auto srcTypes = SmallVector<Type>(sourceTypes);
+  auto tgtTypes = SmallVector<Type>(targetTypes);
+
+  typeConverter.addConversion([srcTypes, tgtTypes](Type type) -> Type {
+    // Returns the target type registered for `elemType`, or null if none.
+    auto lookup = [&](Type elemType) -> Type {
+      for (auto [src, tgt] : llvm::zip_equal(srcTypes, tgtTypes))
+        if (elemType == src)
+          return tgt;
+      return Type();
+    };
+
+    if (type.isIntOrIndexOrFloat()) {
+      if (Type tgt = lookup(type))
+        return tgt;
+    } else if (auto memref = llvm::dyn_cast<MemRefType>(type)) {
+      if (Type tgt = lookup(memref.getElementType()))
+        return MemRefType::get(memref.getShape(), tgt, memref.getLayout(),
+                               memref.getMemorySpace());
+    } else if (auto vec = llvm::dyn_cast<VectorType>(type)) {
+      // Keep the scalable-dimension flags; dropping them would turn e.g.
+      // vector<[4]xbf16> into a fixed-length vector<4xi16>.
+      if (Type tgt = lookup(vec.getElementType()))
+        return VectorType::get(vec.getShape(), tgt, vec.getScalableDims());
+    }
+    return type;
+  });
+
+  auto materializeCast = [](OpBuilder &builder, Type resultType,
+                            ValueRange inputs, Location loc) -> Value {
+    assert(inputs.size() == 1 && "Expected single input");
+    Type inputType = inputs[0].getType();
+    if (isa<MemRefType>(resultType) && isa<MemRefType>(inputType)) {
+      return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
+          .getResult(0);
+    }
+    if ((resultType.isIntOrIndexOrFloat() || isa<VectorType>(resultType)) &&
+        (inputType.isIntOrIndexOrFloat() || isa<VectorType>(inputType))) {
+      return builder.create<arith::BitcastOp>(loc, resultType, inputs[0])
+          .getResult();
+    }
+    return nullptr;
+  };
+
+  typeConverter.addSourceMaterialization(materializeCast);
+  typeConverter.addTargetMaterialization(materializeCast);
+}
+
+//===----------------------------------------------------------------------===//
+// Pattern Population
+//===----------------------------------------------------------------------===//
+
+/// Adds the type-imitation conversion patterns to `patterns`.
+///
+/// `sourceTypes`/`targetTypes` are forwarded as-is. The previous version
+/// handed out views of function-local copies, which dangled once this
+/// function returned. `convertedFuncTypes` is held by reference by the
+/// function/call patterns, so it must outlive the conversion.
+void mlir::populateImitateUnsupportedTypesConversionPatterns(
+    RewritePatternSet &patterns, TypeConverter &typeConverter,
+    ArrayRef<Type> sourceTypes, ArrayRef<Type> targetTypes,
+    DenseMap<StringAttr, FunctionType> &convertedFuncTypes) {
+  assert(sourceTypes.size() == targetTypes.size() &&
+         "Source and target types must have same size");
+  MLIRContext *ctx = patterns.getContext();
+
+  patterns.add<ConvertOpWithSourceType>(ctx, typeConverter, sourceTypes,
+                                        targetTypes);
+  patterns.add<ConvertFuncOp<gpu::GPUFuncOp>, ConvertFuncOp<func::FuncOp>>(
+      ctx, typeConverter, sourceTypes, targetTypes, convertedFuncTypes);
+  patterns.add<ConvertCallOp>(ctx, typeConverter, convertedFuncTypes);
+  patterns.add<ConvertArithConstantOp>(ctx, typeConverter, sourceTypes,
+                                       targetTypes);
+  patterns.add<ConvertGPULaunchFuncOp>(ctx);
+  patterns.add<ConvertAllocOp<gpu::AllocOp>>(ctx, typeConverter);
+  patterns.add<ConvertAllocOp<memref::AllocOp>>(ctx, typeConverter);
+}
+
+//===----------------------------------------------------------------------===//
+// Conversion Legality configuration
+//===----------------------------------------------------------------------===//
+
+/// Configures `target` so that, inside gpu.module, ops using types that
+/// `typeConverter` would convert are illegal. Host-side code is only touched
+/// for gpu ops and for memref.alloc ops that feed a gpu.launch_func.
+///
+/// All callbacks keep a reference to `typeConverter`, which must outlive the
+/// conversion.
+void mlir::configureImitateUnsupportedTypesLegality(
+    ConversionTarget &target, TypeConverter &typeConverter) {
+  // Legal on the host; inside a gpu.module legal only once fully converted.
+  auto legalOnHostOrIfConverted = [&typeConverter](Operation *op) -> bool {
+    if (!op->getParentOfType<gpu::GPUModuleOp>())
+      return true;
+    return typeConverter.isLegal(op);
+  };
+
+  target.addLegalDialect<arith::ArithDialect>();
+  target.addLegalDialect<math::MathDialect>();
+  // memref and func ops are only converted inside gpu.module.
+  target.addDynamicallyLegalDialect<memref::MemRefDialect>(
+      legalOnHostOrIfConverted);
+
+  // gpu ops must use converted types everywhere (host and device).
+  target.addDynamicallyLegalDialect<gpu::GPUDialect>(
+      [&typeConverter](Operation *op) { return typeConverter.isLegal(op); });
+
+  target.addDynamicallyLegalDialect<func::FuncDialect>(
+      legalOnHostOrIfConverted);
+
+  target.addLegalOp<gpu::GPUModuleOp>();
+  target.addLegalOp<UnrealizedConversionCastOp>();
+  // Manually mark arithmetic-performing vector instructions.
+  target.addLegalOp<vector::ContractionOp, vector::ReductionOp,
+                    vector::MultiDimReductionOp, vector::FMAOp,
+                    vector::OuterProductOp, vector::MatmulOp, vector::ScanOp,
+                    vector::SplatOp>();
+  target.addDynamicallyLegalOp<arith::ConstantOp>(
+      [&typeConverter](arith::ConstantOp op) {
+        return typeConverter.isLegal(op.getType());
+      });
+  target.addDynamicallyLegalOp<gpu::GPUFuncOp>(
+      [&typeConverter](gpu::GPUFuncOp op) {
+        return typeConverter.isSignatureLegal(op.getFunctionType());
+      });
+  // Only convert functions and function calls in gpu.module.
+  target.addDynamicallyLegalOp<func::FuncOp>([&typeConverter](func::FuncOp op) {
+    if (op->getParentOfType<gpu::GPUModuleOp>())
+      return typeConverter.isSignatureLegal(op.getFunctionType());
+    return true;
+  });
+  target.addDynamicallyLegalOp<func::CallOp>([&typeConverter](func::CallOp op) {
+    if (op->getParentOfType<gpu::GPUModuleOp>())
+      return typeConverter.isSignatureLegal(op.getCalleeType());
+    return true;
+  });
+
+  // Only convert alloc ops in gpu.module, or host alloc ops that have a use
+  // in gpu.launch_func.
+  target.addDynamicallyLegalOp<memref::AllocOp>(
+      [&typeConverter](memref::AllocOp op) {
+        bool inGpuModule =
+            static_cast<bool>(op->getParentOfType<gpu::GPUModuleOp>());
+        bool feedsKernelLaunch =
+            llvm::any_of(op->getUsers(), [](Operation *user) {
+              return isa<gpu::LaunchFuncOp>(user);
+            });
+        if (inGpuModule || feedsKernelLaunch)
+          return typeConverter.isLegal(op.getType());
+        return true;
+      });
+
+  // Unknown ops inside gpu.module that have a memref operand are legal only
+  // once their types are converted. All other unknown ops (including any
+  // outside gpu.module, or without memref operands) are legal.
+  target.markUnknownOpDynamicallyLegal([&typeConverter](Operation *op) -> bool {
+    if (!op->getParentOfType<gpu::GPUModuleOp>())
+      return true;
+    bool hasMemRefOperand = llvm::any_of(
+        op->getOperandTypes(), [](Type type) { return isa<MemRefType>(type); });
+    return !hasMemRefOperand || typeConverter.isLegal(op);
+  });
+}
+
+//===----------------------------------------------------------------------===//
+// Pass Definition
+//===----------------------------------------------------------------------===//
+
+namespace {
+struct GpuImitateUnsupportedTypesPass
+    : public impl::GpuImitateUnsupportedTypesBase<
+          GpuImitateUnsupportedTypesPass> {
+  using Base::Base;
+
+  void runOnOperation() override {
+    MLIRContext *ctx = &getContext();
+    Operation *op = getOperation();
+
+    SmallVector<Type> sourceTypes;
+    SmallVector<Type> targetTypes;
+
+    // Parse source types
+    for (StringRef sourceTypeStr : sourceTypeStrs) {
+      std::optional<Type> maybeSourceType =
+          arith::parseIntOrFloatType(ctx, sourceTypeStr);
+
+      if (!maybeSourceType) {
+        emitError(UnknownLoc::get(ctx),
+                  "could not map source type '" + sourceTypeStr +
+                      "' to a known integer or floating-point type.");
+        return signalPassFailure();
+      }
+      sourceTypes.push_back(*maybeSourceType);
+    }
+    if (sourceTypes.empty()) {
+      (void)emitOptionalWarning(std::nullopt, "no source types "
+                                              "specified, type "
+                                              "imitation will do "
+                                              "nothing");
+    }
+
+    // Parse target types
+    for (StringRef targetTypeStr : targetTypeStrs) {
+      std::optional<Type> maybeTargetType =
+          arith::parseIntOrFloatType(ctx, targetTypeStr);
+
+      if (!maybeTargetType) {
+        emitError(UnknownLoc::get(ctx),
+                  "could not map target type '" + targetTypeStr +
+                      "' to a known integer or floating-point type");
+        return signalPassFailure();
+      }
+      targetTypes.push_back(*maybeTargetType);
+
+      if (llvm::is_contained(sourceTypes, *maybeTargetType)) {
+        emitError(UnknownLoc::get(ctx),
+                  "target type cannot be an unsupported source type");
+        return signalPassFailure();
+      }
+    }
+    if (targetTypes.empty()) {
+      (void)emitOptionalWarning(
+          std::nullopt,
+          "no target types specified, type imitation will do nothing");
+    }
+
+    // Set up the type converter
+    TypeConverter typeConverter;
+    populateImitateUnsupportedTypesTypeConverter(typeConverter, sourceTypes,
+                                                 targetTypes);
+
+    // Populate the conversion patterns
+    RewritePatternSet patterns(ctx);
+    DenseMap<StringAttr, FunctionType> convertedFuncTypes;
+    populateImitateUnsupportedTypesConversionPatterns(
+        patterns, typeConverter, sourceTypes, targetTypes, convertedFuncTypes);
+
+    // Set up conversion target and configure the legality of the conversion
+    ConversionTarget target(*ctx);
+    configureImitateUnsupportedTypesLegality(target, typeConverter);
+
+    // Apply the conversion
+    if (failed(applyPartialConversion(op, target, std::move(patterns))))
+      signalPassFailure();
+
+    // Post-conversion validation: check for any remaining
+    // unrealized_conversion_cast
+    bool hasUnresolvedCast = false;
+    op->walk([&](UnrealizedConversionCastOp op) {
+      // Check if the cast is from a source type to a target type
+      for (auto [sourceType, targetType] :
+           llvm::zip_equal(sourceTypes, targetTypes)) {
+        if (getElementTypeOrSelf(op.getOperand(0).getType()) == sourceType &&
+            getElementTypeOrSelf(op.getResult(0).getType()) == targetType) {
+          op->emitError("unresolved unrealized_conversion_cast left in IR "
+                        "after conversion");
+          hasUnresolvedCast = true;
----------------
mshahneo wrote:

Thanks, fixed.

https://github.com/llvm/llvm-project/pull/138087


More information about the Mlir-commits mailing list