[Mlir-commits] [mlir] 9527d77 - [mlir][spirv] Restructure code in `SPIRVConversion.cpp`. NFC. (#99393)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Jul 18 11:31:18 PDT 2024


Author: Angel Zhang
Date: 2024-07-18T14:31:15-04:00
New Revision: 9527d77aefcf214944a4c8bd284dde3ffe9dff60

URL: https://github.com/llvm/llvm-project/commit/9527d77aefcf214944a4c8bd284dde3ffe9dff60
DIFF: https://github.com/llvm/llvm-project/commit/9527d77aefcf214944a4c8bd284dde3ffe9dff60.diff

LOG: [mlir][spirv] Restructure code in `SPIRVConversion.cpp`. NFC. (#99393)

Added: 
    

Modified: 
    mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index e3a09ef1ff684..bf5044437fd09 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -40,6 +40,8 @@
 
 using namespace mlir;
 
+namespace {
+
 //===----------------------------------------------------------------------===//
 // Utility functions
 //===----------------------------------------------------------------------===//
@@ -171,18 +173,6 @@ static spirv::ScalarType getIndexType(MLIRContext *ctx,
       IntegerType::get(ctx, options.use64bitIndex ? 64 : 32));
 }
 
-Type SPIRVTypeConverter::getIndexType() const {
-  return ::getIndexType(getContext(), options);
-}
-
-MLIRContext *SPIRVTypeConverter::getContext() const {
-  return targetEnv.getAttr().getContext();
-}
-
-bool SPIRVTypeConverter::allows(spirv::Capability capability) const {
-  return targetEnv.allows(capability);
-}
-
 // TODO: This is a utility function that should probably be exposed by the
 // SPIR-V dialect. Keeping it local till the use case arises.
 static std::optional<int64_t>
@@ -673,9 +663,9 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
 /// This function is meant to handle the **compute** side; so it does not
 /// involve storage classes in its logic. The storage side is expected to be
 /// handled by MemRef conversion logic.
-std::optional<Value> castToSourceType(const spirv::TargetEnv &targetEnv,
-                                      OpBuilder &builder, Type type,
-                                      ValueRange inputs, Location loc) {
+static std::optional<Value> castToSourceType(const spirv::TargetEnv &targetEnv,
+                                             OpBuilder &builder, Type type,
+                                             ValueRange inputs, Location loc) {
   // We can only cast one value in SPIR-V.
   if (inputs.size() != 1) {
     auto castOp = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
@@ -731,140 +721,185 @@ std::optional<Value> castToSourceType(const spirv::TargetEnv &targetEnv,
 }
 
 //===----------------------------------------------------------------------===//
-// SPIRVTypeConverter
+// Builtin Variables
 //===----------------------------------------------------------------------===//
 
-SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr,
-                                       const SPIRVConversionOptions &options)
-    : targetEnv(targetAttr), options(options) {
-  // Add conversions. The order matters here: later ones will be tried earlier.
+static spirv::GlobalVariableOp getBuiltinVariable(Block &body,
+                                                  spirv::BuiltIn builtin) {
+  // Look through all global variables in the given `body` block and check if
+  // there is a spirv.GlobalVariable that has the same `builtin` attribute.
+  for (auto varOp : body.getOps<spirv::GlobalVariableOp>()) {
+    if (auto builtinAttr = varOp->getAttrOfType<StringAttr>(
+            spirv::SPIRVDialect::getAttributeName(
+                spirv::Decoration::BuiltIn))) {
+      auto varBuiltIn = spirv::symbolizeBuiltIn(builtinAttr.getValue());
+      if (varBuiltIn && *varBuiltIn == builtin) {
+        return varOp;
+      }
+    }
+  }
+  return nullptr;
+}
 
-  // Allow all SPIR-V dialect specific types. This assumes all builtin types
-  // adopted in the SPIR-V dialect (i.e., IntegerType, FloatType, VectorType)
-  // were tried before.
-  //
-  // TODO: This assumes that the SPIR-V types are valid to use in the given
-  // target environment, which should be the case if the whole pipeline is
-  // driven by the same target environment. Still, we probably still want to
-  // validate and convert to be safe.
-  addConversion([](spirv::SPIRVType type) { return type; });
+/// Gets the global variable name for `builtin`: `prefix` + name + `suffix`.
+static std::string getBuiltinVarName(spirv::BuiltIn builtin, StringRef prefix,
+                                     StringRef suffix) {
+  return Twine(prefix).concat(stringifyBuiltIn(builtin)).concat(suffix).str();
+}
 
-  addConversion([this](IndexType /*indexType*/) { return getIndexType(); });
+/// Gets or inserts a global variable for a builtin within `body` block.
+static spirv::GlobalVariableOp
+getOrInsertBuiltinVariable(Block &body, Location loc, spirv::BuiltIn builtin,
+                           Type integerType, OpBuilder &builder,
+                           StringRef prefix, StringRef suffix) {
+  if (auto varOp = getBuiltinVariable(body, builtin))
+    return varOp;
 
-  addConversion([this](IntegerType intType) -> std::optional<Type> {
-    if (auto scalarType = dyn_cast<spirv::ScalarType>(intType))
-      return convertScalarType(this->targetEnv, this->options, scalarType);
-    if (intType.getWidth() < 8)
-      return convertSubByteIntegerType(this->options, intType);
-    return Type();
-  });
+  OpBuilder::InsertionGuard guard(builder);
+  builder.setInsertionPointToStart(&body);
 
-  addConversion([this](FloatType floatType) -> std::optional<Type> {
-    if (auto scalarType = dyn_cast<spirv::ScalarType>(floatType))
-      return convertScalarType(this->targetEnv, this->options, scalarType);
-    return Type();
-  });
+  spirv::GlobalVariableOp newVarOp;
+  switch (builtin) {
+  case spirv::BuiltIn::NumWorkgroups:
+  case spirv::BuiltIn::WorkgroupSize:
+  case spirv::BuiltIn::WorkgroupId:
+  case spirv::BuiltIn::LocalInvocationId:
+  case spirv::BuiltIn::GlobalInvocationId: {
+    auto ptrType = spirv::PointerType::get(VectorType::get({3}, integerType),
+                                           spirv::StorageClass::Input);
+    std::string name = getBuiltinVarName(builtin, prefix, suffix);
+    newVarOp =
+        builder.create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin);
+    break;
+  }
+  case spirv::BuiltIn::SubgroupId:
+  case spirv::BuiltIn::NumSubgroups:
+  case spirv::BuiltIn::SubgroupSize: {
+    auto ptrType =
+        spirv::PointerType::get(integerType, spirv::StorageClass::Input);
+    std::string name = getBuiltinVarName(builtin, prefix, suffix);
+    newVarOp =
+        builder.create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin);
+    break;
+  }
+  default:
+    emitError(loc, "unimplemented builtin variable generation for ")
+        << stringifyBuiltIn(builtin);
+  }
+  return newVarOp;
+}
 
-  addConversion([this](ComplexType complexType) {
-    return convertComplexType(this->targetEnv, this->options, complexType);
-  });
+//===----------------------------------------------------------------------===//
+// Push constant storage
+//===----------------------------------------------------------------------===//
 
-  addConversion([this](VectorType vectorType) {
-    return convertVectorType(this->targetEnv, this->options, vectorType);
-  });
+/// Returns the pointer type for the push constant storage containing
+/// `elementCount` 32-bit integer values.
+static spirv::PointerType getPushConstantStorageType(unsigned elementCount,
+                                                     Builder &builder,
+                                                     Type indexType) {
+  auto arrayType = spirv::ArrayType::get(indexType, elementCount,
+                                         /*stride=*/4);
+  auto structType = spirv::StructType::get({arrayType}, /*offsetInfo=*/0);
+  return spirv::PointerType::get(structType, spirv::StorageClass::PushConstant);
+}
 
-  addConversion([this](TensorType tensorType) {
-    return convertTensorType(this->targetEnv, this->options, tensorType);
-  });
+/// Returns the push constant variable containing `elementCount` 32-bit integer
+/// values in `body`. Returns a null op if no such op exists.
+static spirv::GlobalVariableOp getPushConstantVariable(Block &body,
+                                                       unsigned elementCount) {
+  for (auto varOp : body.getOps<spirv::GlobalVariableOp>()) {
+    auto ptrType = dyn_cast<spirv::PointerType>(varOp.getType());
+    if (!ptrType)
+      continue;
 
-  addConversion([this](MemRefType memRefType) {
-    return convertMemrefType(this->targetEnv, this->options, memRefType);
-  });
+    // Note that Vulkan requires "There must be no more than one push constant
+    // block statically used per shader entry point." So we should always reuse
+    // the existing one.
+    if (ptrType.getStorageClass() == spirv::StorageClass::PushConstant) {
+      auto numElements = cast<spirv::ArrayType>(
+                             cast<spirv::StructType>(ptrType.getPointeeType())
+                                 .getElementType(0))
+                             .getNumElements();
+      if (numElements == elementCount)
+        return varOp;
+    }
+  }
+  return nullptr;
+}
 
-  // Register some last line of defense casting logic.
-  addSourceMaterialization(
-      [this](OpBuilder &builder, Type type, ValueRange inputs, Location loc) {
-        return castToSourceType(this->targetEnv, builder, type, inputs, loc);
-      });
-  addTargetMaterialization([](OpBuilder &builder, Type type, ValueRange inputs,
-                              Location loc) {
-    auto cast = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
-    return std::optional<Value>(cast.getResult(0));
-  });
+/// Gets or inserts a global variable for push constant storage containing
+/// `elementCount` 32-bit integer values in `block`.
+static spirv::GlobalVariableOp
+getOrInsertPushConstantVariable(Location loc, Block &block,
+                                unsigned elementCount, OpBuilder &b,
+                                Type indexType) {
+  if (auto varOp = getPushConstantVariable(block, elementCount))
+    return varOp;
+
+  auto builder = OpBuilder::atBlockBegin(&block, b.getListener());
+  auto type = getPushConstantStorageType(elementCount, builder, indexType);
+  const char *name = "__push_constant_var__";
+  return builder.create<spirv::GlobalVariableOp>(loc, type, name,
+                                                 /*initializer=*/nullptr);
 }
 
 //===----------------------------------------------------------------------===//
 // func::FuncOp Conversion Patterns
 //===----------------------------------------------------------------------===//
 
-namespace {
 /// A pattern for rewriting function signature to convert arguments of functions
 /// to be of valid SPIR-V types.
-class FuncOpConversion final : public OpConversionPattern<func::FuncOp> {
-public:
+struct FuncOpConversion final : OpConversionPattern<func::FuncOp> {
   using OpConversionPattern<func::FuncOp>::OpConversionPattern;
 
   LogicalResult
   matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override;
-};
-} // namespace
-
-LogicalResult
-FuncOpConversion::matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
-                                  ConversionPatternRewriter &rewriter) const {
-  auto fnType = funcOp.getFunctionType();
-  if (fnType.getNumResults() > 1)
-    return failure();
-
-  TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs());
-  for (const auto &argType : enumerate(fnType.getInputs())) {
-    auto convertedType = getTypeConverter()->convertType(argType.value());
-    if (!convertedType)
-      return failure();
-    signatureConverter.addInputs(argType.index(), convertedType);
-  }
-
-  Type resultType;
-  if (fnType.getNumResults() == 1) {
-    resultType = getTypeConverter()->convertType(fnType.getResult(0));
-    if (!resultType)
+                  ConversionPatternRewriter &rewriter) const override {
+    FunctionType fnType = funcOp.getFunctionType();
+    if (fnType.getNumResults() > 1)
       return failure();
-  }
-
-  // Create the converted spirv.func op.
-  auto newFuncOp = rewriter.create<spirv::FuncOp>(
-      funcOp.getLoc(), funcOp.getName(),
-      rewriter.getFunctionType(signatureConverter.getConvertedTypes(),
-                               resultType ? TypeRange(resultType)
-                                          : TypeRange()));
 
-  // Copy over all attributes other than the function name and type.
-  for (const auto &namedAttr : funcOp->getAttrs()) {
-    if (namedAttr.getName() != funcOp.getFunctionTypeAttrName() &&
-        namedAttr.getName() != SymbolTable::getSymbolAttrName())
-      newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue());
-  }
+    TypeConverter::SignatureConversion signatureConverter(
+        fnType.getNumInputs());
+    for (const auto &argType : enumerate(fnType.getInputs())) {
+      auto convertedType = getTypeConverter()->convertType(argType.value());
+      if (!convertedType)
+        return failure();
+      signatureConverter.addInputs(argType.index(), convertedType);
+    }
 
-  rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
-                              newFuncOp.end());
-  if (failed(rewriter.convertRegionTypes(
-          &newFuncOp.getBody(), *getTypeConverter(), &signatureConverter)))
-    return failure();
-  rewriter.eraseOp(funcOp);
-  return success();
-}
+    Type resultType;
+    if (fnType.getNumResults() == 1) {
+      resultType = getTypeConverter()->convertType(fnType.getResult(0));
+      if (!resultType)
+        return failure();
+    }
 
-void mlir::populateBuiltinFuncToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
-                                              RewritePatternSet &patterns) {
-  patterns.add<FuncOpConversion>(typeConverter, patterns.getContext());
-}
+    // Create the converted spirv.func op.
+    auto newFuncOp = rewriter.create<spirv::FuncOp>(
+        funcOp.getLoc(), funcOp.getName(),
+        rewriter.getFunctionType(signatureConverter.getConvertedTypes(),
+                                 resultType ? TypeRange(resultType)
+                                            : TypeRange()));
+
+    // Copy over all attributes other than the function name and type.
+    for (const auto &namedAttr : funcOp->getAttrs()) {
+      if (namedAttr.getName() != funcOp.getFunctionTypeAttrName() &&
+          namedAttr.getName() != SymbolTable::getSymbolAttrName())
+        newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue());
+    }
 
-//===----------------------------------------------------------------------===//
-// func::FuncOp Conversion Patterns
-//===----------------------------------------------------------------------===//
+    rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
+                                newFuncOp.end());
+    if (failed(rewriter.convertRegionTypes(
+            &newFuncOp.getBody(), *getTypeConverter(), &signatureConverter)))
+      return failure();
+    rewriter.eraseOp(funcOp);
+    return success();
+  }
+};
 
-namespace {
 /// A pattern for rewriting function signature to convert vector arguments of
 /// functions to be of valid types
 struct FuncOpVectorUnroll final : OpRewritePattern<func::FuncOp> {
@@ -1015,17 +1050,11 @@ struct FuncOpVectorUnroll final : OpRewritePattern<func::FuncOp> {
     return success();
   }
 };
-} // namespace
-
-void mlir::populateFuncOpVectorRewritePatterns(RewritePatternSet &patterns) {
-  patterns.add<FuncOpVectorUnroll>(patterns.getContext());
-}
 
 //===----------------------------------------------------------------------===//
 // func::ReturnOp Conversion Patterns
 //===----------------------------------------------------------------------===//
 
-namespace {
 /// A pattern for rewriting function signature and the return op to convert
 /// vectors to be of valid types.
 struct ReturnOpVectorUnroll final : OpRewritePattern<func::ReturnOp> {
@@ -1097,81 +1126,13 @@ struct ReturnOpVectorUnroll final : OpRewritePattern<func::ReturnOp> {
     return success();
   }
 };
-} // namespace
 
-void mlir::populateReturnOpVectorRewritePatterns(RewritePatternSet &patterns) {
-  patterns.add<ReturnOpVectorUnroll>(patterns.getContext());
-}
+} // namespace
 
 //===----------------------------------------------------------------------===//
-// Builtin Variables
+// Public function for builtin variables
 //===----------------------------------------------------------------------===//
 
-static spirv::GlobalVariableOp getBuiltinVariable(Block &body,
-                                                  spirv::BuiltIn builtin) {
-  // Look through all global variables in the given `body` block and check if
-  // there is a spirv.GlobalVariable that has the same `builtin` attribute.
-  for (auto varOp : body.getOps<spirv::GlobalVariableOp>()) {
-    if (auto builtinAttr = varOp->getAttrOfType<StringAttr>(
-            spirv::SPIRVDialect::getAttributeName(
-                spirv::Decoration::BuiltIn))) {
-      auto varBuiltIn = spirv::symbolizeBuiltIn(builtinAttr.getValue());
-      if (varBuiltIn && *varBuiltIn == builtin) {
-        return varOp;
-      }
-    }
-  }
-  return nullptr;
-}
-
-/// Gets name of global variable for a builtin.
-static std::string getBuiltinVarName(spirv::BuiltIn builtin, StringRef prefix,
-                                     StringRef suffix) {
-  return Twine(prefix).concat(stringifyBuiltIn(builtin)).concat(suffix).str();
-}
-
-/// Gets or inserts a global variable for a builtin within `body` block.
-static spirv::GlobalVariableOp
-getOrInsertBuiltinVariable(Block &body, Location loc, spirv::BuiltIn builtin,
-                           Type integerType, OpBuilder &builder,
-                           StringRef prefix, StringRef suffix) {
-  if (auto varOp = getBuiltinVariable(body, builtin))
-    return varOp;
-
-  OpBuilder::InsertionGuard guard(builder);
-  builder.setInsertionPointToStart(&body);
-
-  spirv::GlobalVariableOp newVarOp;
-  switch (builtin) {
-  case spirv::BuiltIn::NumWorkgroups:
-  case spirv::BuiltIn::WorkgroupSize:
-  case spirv::BuiltIn::WorkgroupId:
-  case spirv::BuiltIn::LocalInvocationId:
-  case spirv::BuiltIn::GlobalInvocationId: {
-    auto ptrType = spirv::PointerType::get(VectorType::get({3}, integerType),
-                                           spirv::StorageClass::Input);
-    std::string name = getBuiltinVarName(builtin, prefix, suffix);
-    newVarOp =
-        builder.create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin);
-    break;
-  }
-  case spirv::BuiltIn::SubgroupId:
-  case spirv::BuiltIn::NumSubgroups:
-  case spirv::BuiltIn::SubgroupSize: {
-    auto ptrType =
-        spirv::PointerType::get(integerType, spirv::StorageClass::Input);
-    std::string name = getBuiltinVarName(builtin, prefix, suffix);
-    newVarOp =
-        builder.create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin);
-    break;
-  }
-  default:
-    emitError(loc, "unimplemented builtin variable generation for ")
-        << stringifyBuiltIn(builtin);
-  }
-  return newVarOp;
-}
-
 Value mlir::spirv::getBuiltinVariableValue(Operation *op,
                                            spirv::BuiltIn builtin,
                                            Type integerType, OpBuilder &builder,
@@ -1190,60 +1151,9 @@ Value mlir::spirv::getBuiltinVariableValue(Operation *op,
 }
 
 //===----------------------------------------------------------------------===//
-// Push constant storage
+// Public function for push constant storage
 //===----------------------------------------------------------------------===//
 
-/// Returns the pointer type for the push constant storage containing
-/// `elementCount` 32-bit integer values.
-static spirv::PointerType getPushConstantStorageType(unsigned elementCount,
-                                                     Builder &builder,
-                                                     Type indexType) {
-  auto arrayType = spirv::ArrayType::get(indexType, elementCount,
-                                         /*stride=*/4);
-  auto structType = spirv::StructType::get({arrayType}, /*offsetInfo=*/0);
-  return spirv::PointerType::get(structType, spirv::StorageClass::PushConstant);
-}
-
-/// Returns the push constant varible containing `elementCount` 32-bit integer
-/// values in `body`. Returns null op if such an op does not exit.
-static spirv::GlobalVariableOp getPushConstantVariable(Block &body,
-                                                       unsigned elementCount) {
-  for (auto varOp : body.getOps<spirv::GlobalVariableOp>()) {
-    auto ptrType = dyn_cast<spirv::PointerType>(varOp.getType());
-    if (!ptrType)
-      continue;
-
-    // Note that Vulkan requires "There must be no more than one push constant
-    // block statically used per shader entry point." So we should always reuse
-    // the existing one.
-    if (ptrType.getStorageClass() == spirv::StorageClass::PushConstant) {
-      auto numElements = cast<spirv::ArrayType>(
-                             cast<spirv::StructType>(ptrType.getPointeeType())
-                                 .getElementType(0))
-                             .getNumElements();
-      if (numElements == elementCount)
-        return varOp;
-    }
-  }
-  return nullptr;
-}
-
-/// Gets or inserts a global variable for push constant storage containing
-/// `elementCount` 32-bit integer values in `block`.
-static spirv::GlobalVariableOp
-getOrInsertPushConstantVariable(Location loc, Block &block,
-                                unsigned elementCount, OpBuilder &b,
-                                Type indexType) {
-  if (auto varOp = getPushConstantVariable(block, elementCount))
-    return varOp;
-
-  auto builder = OpBuilder::atBlockBegin(&block, b.getListener());
-  auto type = getPushConstantStorageType(elementCount, builder, indexType);
-  const char *name = "__push_constant_var__";
-  return builder.create<spirv::GlobalVariableOp>(loc, type, name,
-                                                 /*initializer=*/nullptr);
-}
-
 Value spirv::getPushConstantValue(Operation *op, unsigned elementCount,
                                   unsigned offset, Type integerType,
                                   OpBuilder &builder) {
@@ -1267,7 +1177,7 @@ Value spirv::getPushConstantValue(Operation *op, unsigned elementCount,
 }
 
 //===----------------------------------------------------------------------===//
-// Index calculation
+// Public functions for index calculation
 //===----------------------------------------------------------------------===//
 
 Value mlir::spirv::linearizeIndex(ValueRange indices, ArrayRef<int64_t> strides,
@@ -1375,6 +1285,81 @@ Value mlir::spirv::getElementPtr(const SPIRVTypeConverter &typeConverter,
                              builder);
 }
 
+//===----------------------------------------------------------------------===//
+// SPIR-V TypeConverter
+//===----------------------------------------------------------------------===//
+
+SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr,
+                                       const SPIRVConversionOptions &options)
+    : targetEnv(targetAttr), options(options) {
+  // Add conversions. The order matters here: later ones will be tried earlier.
+
+  // Allow all SPIR-V dialect specific types. This assumes all builtin types
+  // adopted in the SPIR-V dialect (i.e., IntegerType, FloatType, VectorType)
+  // were tried before.
+  //
+  // TODO: This assumes that the SPIR-V types are valid to use in the given
+  // target environment, which should be the case if the whole pipeline is
+  // driven by the same target environment. Still, we probably want to
+  // validate and convert to be safe.
+  addConversion([](spirv::SPIRVType type) { return type; });
+
+  addConversion([this](IndexType /*indexType*/) { return getIndexType(); });
+
+  addConversion([this](IntegerType intType) -> std::optional<Type> {
+    if (auto scalarType = dyn_cast<spirv::ScalarType>(intType))
+      return convertScalarType(this->targetEnv, this->options, scalarType);
+    if (intType.getWidth() < 8)
+      return convertSubByteIntegerType(this->options, intType);
+    return Type();
+  });
+
+  addConversion([this](FloatType floatType) -> std::optional<Type> {
+    if (auto scalarType = dyn_cast<spirv::ScalarType>(floatType))
+      return convertScalarType(this->targetEnv, this->options, scalarType);
+    return Type();
+  });
+
+  addConversion([this](ComplexType complexType) {
+    return convertComplexType(this->targetEnv, this->options, complexType);
+  });
+
+  addConversion([this](VectorType vectorType) {
+    return convertVectorType(this->targetEnv, this->options, vectorType);
+  });
+
+  addConversion([this](TensorType tensorType) {
+    return convertTensorType(this->targetEnv, this->options, tensorType);
+  });
+
+  addConversion([this](MemRefType memRefType) {
+    return convertMemrefType(this->targetEnv, this->options, memRefType);
+  });
+
+  // Register some last line of defense casting logic.
+  addSourceMaterialization(
+      [this](OpBuilder &builder, Type type, ValueRange inputs, Location loc) {
+        return castToSourceType(this->targetEnv, builder, type, inputs, loc);
+      });
+  addTargetMaterialization([](OpBuilder &builder, Type type, ValueRange inputs,
+                              Location loc) {
+    auto cast = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
+    return std::optional<Value>(cast.getResult(0));
+  });
+}
+
+Type SPIRVTypeConverter::getIndexType() const {
+  return ::getIndexType(getContext(), options);
+}
+
+MLIRContext *SPIRVTypeConverter::getContext() const {
+  return targetEnv.getAttr().getContext();
+}
+
+bool SPIRVTypeConverter::allows(spirv::Capability capability) const {
+  return targetEnv.allows(capability);
+}
+
 //===----------------------------------------------------------------------===//
 // SPIR-V ConversionTarget
 //===----------------------------------------------------------------------===//
@@ -1468,3 +1453,20 @@ bool SPIRVConversionTarget::isLegalOp(Operation *op) {
 
   return true;
 }
+
+//===----------------------------------------------------------------------===//
+// Public functions for populating patterns
+//===----------------------------------------------------------------------===//
+
+void mlir::populateBuiltinFuncToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
+                                              RewritePatternSet &patterns) {
+  patterns.add<FuncOpConversion>(typeConverter, patterns.getContext());
+}
+
+void mlir::populateFuncOpVectorRewritePatterns(RewritePatternSet &patterns) {
+  patterns.add<FuncOpVectorUnroll>(patterns.getContext());
+}
+
+void mlir::populateReturnOpVectorRewritePatterns(RewritePatternSet &patterns) {
+  patterns.add<ReturnOpVectorUnroll>(patterns.getContext());
+}


        


More information about the Mlir-commits mailing list